• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mendersoftware / mender-server / 2029630506

09 Sep 2025 08:06AM UTC coverage: 65.283% (-0.07%) from 65.357%
2029630506

Pull #941

gitlab-ci

web-flow
Merge pull request #904 from bahaa-ghazal/MEN-8649

MEN-8649: Clean OS services from tenantadm related code 🧹
Pull Request #941: Merge MEN-8649 into main

18 of 26 new or added lines in 4 files covered. (69.23%)

11 existing lines in 3 files now uncovered.

31634 of 48457 relevant lines covered (65.28%)

1.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

54.15
/backend/services/deviceauth/cache/cache.go
1
// Copyright 2023 Northern.tech AS
2
//
3
//    Licensed under the Apache License, Version 2.0 (the "License");
4
//    you may not use this file except in compliance with the License.
5
//    You may obtain a copy of the License at
6
//
7
//        http://www.apache.org/licenses/LICENSE-2.0
8
//
9
//    Unless required by applicable law or agreed to in writing, software
10
//    distributed under the License is distributed on an "AS IS" BASIS,
11
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
//    See the License for the specific language governing permissions and
13
//    limitations under the License.
14

15
// Token Management
16
//
17
// Tokens are expected at:
18
// `<prefix>:tenant:<tid>:version:<tenant_key_version>:device:<did>:tok`: <token>
19
//
20
// Cache invalidation.
21
// We achieve cache invalidation by incrementing tenant key version.
22
// Each tenant related key in the cache has to contain tenant key version.
23
// This way, by incrementing tenant key version, we invalidate all tenant
24
// related keys.
25

26
package cache
27

28
import (
29
        "context"
30
        "encoding/json"
31
        "fmt"
32
        "time"
33

34
        "github.com/pkg/errors"
35
        "github.com/redis/go-redis/v9"
36

37
        "github.com/mendersoftware/mender-server/pkg/identity"
38
        "github.com/mendersoftware/mender-server/pkg/log"
39

40
        "github.com/mendersoftware/mender-server/services/deviceauth/model"
41
)
42

43
const (
	// IdTypeDevice is the identity type value ("idtype") for devices. It is
	// the idtype used when building device check-in time keys.
	IdTypeDevice = "device"
	// IdTypeUser is the identity type value ("idtype") for users.
	IdTypeUser = "user"
	// CheckInTimeExpiration is the expiration applied to a cached device
	// check-in time: one week.
	CheckInTimeExpiration = 7 * 24 * time.Hour
)
49

50
var (
	// ErrNoPositiveInteger reports a value that must be a positive integer
	// but is not.
	ErrNoPositiveInteger = errors.New("must be a positive integer")
	// ErrNegativeInteger reports a value that must not be negative.
	ErrNegativeInteger = errors.New("cannot be a negative integer")
)
54

55
//go:generate ../../../utils/mockgen.sh
56
type Cache interface {
57
        // Throttle retrieves a cached token.
58
        // These ops are bundled because the implementation will pipeline them for a single network
59
        // roundtrip for max performance.
60
        // Returns:
61
        // - the token (if any)
62
        Throttle(
63
                ctx context.Context,
64
                rawToken string,
65
                tid,
66
                id,
67
                idtype,
68
                url,
69
                action string,
70
        ) (string, error)
71

72
        // CacheToken caches the token under designated key, with expiration
73
        CacheToken(ctx context.Context, tid, id, idtype, token string, expireSec time.Duration) error
74

75
        // DeleteToken deletes the token for 'id'
76
        DeleteToken(ctx context.Context, tid, id, idtype string) error
77

78
        // GetLimit gets a limit from cache (see store.Datastore.GetLimit)
79
        GetLimit(ctx context.Context, name string) (*model.Limit, error)
80
        // SetLimit writes a limit to cache (see store.Datastore.SetLimit)
81
        SetLimit(ctx context.Context, limit *model.Limit) error
82
        // DeleteLimit evicts the limit with the given name from cache
83
        DeleteLimit(ctx context.Context, name string) error
84

85
        // CacheCheckInTime caches the last device check in time
86
        CacheCheckInTime(ctx context.Context, t *time.Time, tid, id string) error
87

88
        // GetCheckInTime gets the last device check in time from cache
89
        GetCheckInTime(ctx context.Context, tid, id string) (*time.Time, error)
90

91
        // GetCheckInTimes gets the last device check in time from cache
92
        // for each device with id from the list of ids
93
        GetCheckInTimes(ctx context.Context, tid string, ids []string) ([]*time.Time, error)
94

95
        // SuspendTenant increment tenant key version
96
        // tenant key is used in all cache keys, this way, when we increment the key version,
97
        // all the keys are no longer accessible - in other words, be incrementing tenant key version
98
        // we invalidate all tenant keys
99
        SuspendTenant(ctx context.Context, tid string) error
100
}
101

102
type RedisCache struct {
103
        c               redis.Cmdable
104
        prefix          string
105
        LimitsExpireSec int
106
        DefaultExpire   time.Duration
107
}
108

109
func NewRedisCache(
110
        redisClient redis.Cmdable,
111
        prefix string,
112
        limitsExpireSec int,
113
) *RedisCache {
1✔
114
        return &RedisCache{
1✔
115
                c:               redisClient,
1✔
116
                LimitsExpireSec: limitsExpireSec,
1✔
117
                prefix:          prefix,
1✔
118
                DefaultExpire:   time.Hour * 3,
1✔
119
        }
1✔
120
}
1✔
121

UNCOV
122
func (rl *RedisCache) keyLimit(tenantID, name string) string {
×
UNCOV
123
        if tenantID == "" {
×
124
                tenantID = "default"
×
125
        }
×
126
        return fmt.Sprintf("%s:tenant:%s:limit:%s", rl.prefix, tenantID, name)
×
127
}
128

129
func (rl *RedisCache) GetLimit(ctx context.Context, name string) (*model.Limit, error) {
×
130
        var tenantID string
×
131
        id := identity.FromContext(ctx)
×
132
        if id != nil {
×
133
                tenantID = id.Tenant
×
134
        }
×
135
        value, err := rl.c.Get(ctx, rl.keyLimit(tenantID, name)).Uint64()
×
136
        if err != nil {
×
137
                if errors.Is(err, redis.Nil) {
×
138
                        return nil, nil
×
139
                }
×
140
                return nil, err
×
141
        }
142
        return &model.Limit{
×
143
                TenantID: tenantID,
×
144
                Value:    value,
×
145
                Name:     name,
×
146
        }, nil
×
147
}
148

149
func (rl *RedisCache) SetLimit(ctx context.Context, limit *model.Limit) error {
×
150
        if limit == nil {
×
151
                return nil
×
152
        }
×
153
        var tenantID string
×
154
        id := identity.FromContext(ctx)
×
155
        if id != nil {
×
156
                tenantID = id.Tenant
×
157
        }
×
158
        key := rl.keyLimit(tenantID, limit.Name)
×
159
        return rl.c.SetEx(ctx, key, limit.Value, rl.DefaultExpire).Err()
×
160
}
161

162
func (rl *RedisCache) DeleteLimit(ctx context.Context, name string) error {
×
163
        var tenantID string
×
164
        id := identity.FromContext(ctx)
×
165
        if id != nil {
×
166
                tenantID = id.Tenant
×
167
        }
×
168
        key := rl.keyLimit(tenantID, name)
×
169
        return rl.c.Del(ctx, key).Err()
×
170
}
171

172
func (rl *RedisCache) Throttle(
173
        ctx context.Context,
174
        rawToken string,
175
        tid,
176
        id,
177
        idtype,
178
        url,
179
        action string,
180
) (string, error) {
1✔
181
        var tokenGet *redis.StringCmd
1✔
182

1✔
183
        pipe := rl.c.Pipeline()
1✔
184

1✔
185
        version, err := rl.getTenantKeyVersion(ctx, tid)
1✔
186
        if err != nil {
1✔
187
                return "", err
×
188
        }
×
189

190
        // queue token fetching
191
        // for piped execution
192
        tokenGet = rl.pipeToken(ctx, pipe, tid, id, idtype, version)
1✔
193

1✔
194
        _, err = pipe.Exec(ctx)
1✔
195
        if err != nil && !isErrRedisNil(err) {
1✔
196
                return "", err
×
197
        }
×
198

199
        // collect quota/burst control and token fetch results
200
        tok, err := rl.checkToken(tokenGet, rawToken)
1✔
201
        if err != nil {
1✔
202
                return "", err
×
203
        }
×
204

205
        return tok, nil
1✔
206
}
207

208
func (rl *RedisCache) pipeToken(
209
        ctx context.Context,
210
        pipe redis.Pipeliner,
211
        tid,
212
        id,
213
        idtype string,
214
        version int,
215
) *redis.StringCmd {
1✔
216
        key := rl.KeyToken(tid, id, idtype, version)
1✔
217
        return pipe.Get(ctx, key)
1✔
218
}
1✔
219

220
func (rl *RedisCache) checkToken(cmd *redis.StringCmd, raw string) (string, error) {
1✔
221
        err := cmd.Err()
1✔
222
        if err != nil {
2✔
223
                if isErrRedisNil(err) {
2✔
224
                        return "", nil
1✔
225
                }
1✔
226
                return "", err
×
227
        }
228

229
        token := cmd.Val()
1✔
230
        if token == raw {
2✔
231
                return token, nil
1✔
232
        } else {
2✔
233
                // must be a stale token - we don't want to use it
1✔
234
                // let it expire in the background
1✔
235
                return "", nil
1✔
236
        }
1✔
237
}
238

239
func (rl *RedisCache) CacheToken(
240
        ctx context.Context,
241
        tid,
242
        id,
243
        idtype,
244
        token string,
245
        expire time.Duration,
246
) error {
1✔
247
        version, err := rl.getTenantKeyVersion(ctx, tid)
1✔
248
        if err != nil {
1✔
249
                return err
×
250
        }
×
251
        res := rl.c.Set(ctx, rl.KeyToken(tid, id, idtype, version),
1✔
252
                token,
1✔
253
                expire)
1✔
254
        return res.Err()
1✔
255
}
256

257
func (rl *RedisCache) DeleteToken(ctx context.Context, tid, id, idtype string) error {
1✔
258
        version, err := rl.getTenantKeyVersion(ctx, tid)
1✔
259
        if err != nil {
1✔
260
                return err
×
261
        }
×
262
        res := rl.c.Del(ctx, rl.KeyToken(tid, id, idtype, version))
1✔
263
        return res.Err()
1✔
264
}
265

266
func (rl *RedisCache) KeyToken(tid, id, idtype string, version int) string {
1✔
267
        return fmt.Sprintf(
1✔
268
                "%s:tenant:%s:version:%d:%s:%s:tok",
1✔
269
                rl.prefix, tid, version, idtype, id)
1✔
270
}
1✔
271

272
func (rl *RedisCache) KeyCheckInTime(tid, id, idtype string, version int) string {
1✔
273
        return fmt.Sprintf(
1✔
274
                "%s:tenant:%s:version:%d:%s:%s:checkInTime",
1✔
275
                rl.prefix, tid, version, idtype, id)
1✔
276
}
1✔
277

278
func (rl *RedisCache) KeyTenantVersion(tid string) string {
1✔
279
        return fmt.Sprintf("%s:tenant:%s:version", rl.prefix, tid)
1✔
280
}
1✔
281

282
// isErrRedisNil checks for a very common non-error, "redis: nil",
283
// which just means the key was not found, and is normal
284
// it's routinely returned e.g. from GET, or pipelines containing it
285
func isErrRedisNil(e error) bool {
1✔
286
        return errors.Is(e, redis.Nil)
1✔
287
}
1✔
288

289
func (rl *RedisCache) CacheCheckInTime(
290
        ctx context.Context,
291
        t *time.Time,
292
        tid,
293
        id string,
294
) error {
1✔
295
        tj, err := json.Marshal(t)
1✔
296
        if err != nil {
1✔
297
                return err
×
298
        }
×
299

300
        version, err := rl.getTenantKeyVersion(ctx, tid)
1✔
301
        if err != nil {
1✔
302
                return err
×
303
        }
×
304

305
        res := rl.c.Set(
1✔
306
                ctx,
1✔
307
                rl.KeyCheckInTime(tid, id, IdTypeDevice, version),
1✔
308
                tj,
1✔
309
                CheckInTimeExpiration,
1✔
310
        )
1✔
311

1✔
312
        return res.Err()
1✔
313
}
314

315
func (rl *RedisCache) GetCheckInTime(
316
        ctx context.Context,
317
        tid,
318
        id string,
319
) (*time.Time, error) {
1✔
320
        version, err := rl.getTenantKeyVersion(ctx, tid)
1✔
321
        if err != nil {
1✔
322
                return nil, err
×
323
        }
×
324

325
        res := rl.c.Get(ctx, rl.KeyCheckInTime(tid, id, IdTypeDevice, version))
1✔
326

1✔
327
        if res.Err() != nil {
2✔
328
                if isErrRedisNil(res.Err()) {
2✔
329
                        return nil, nil
1✔
330
                }
1✔
331
                return nil, res.Err()
×
332
        }
333

334
        var checkInTime time.Time
1✔
335

1✔
336
        err = json.Unmarshal([]byte(res.Val()), &checkInTime)
1✔
337
        if err != nil {
1✔
338
                return nil, err
×
339
        }
×
340

341
        return &checkInTime, nil
1✔
342
}
343

344
func (rl *RedisCache) GetCheckInTimes(
345
        ctx context.Context,
346
        tid string,
347
        ids []string,
348
) ([]*time.Time, error) {
1✔
349
        l := log.FromContext(ctx)
1✔
350

1✔
351
        version, err := rl.getTenantKeyVersion(ctx, tid)
1✔
352
        if err != nil {
1✔
353
                return nil, err
×
354
        }
×
355
        checkInTimes := make([]*time.Time, len(ids))
1✔
356
        if _, ok := rl.c.(*redis.ClusterClient); ok {
1✔
357
                pipe := rl.c.Pipeline()
×
358
                for _, id := range ids {
×
359
                        pipe.Get(ctx, rl.KeyCheckInTime(tid, id, IdTypeDevice, version))
×
360
                }
×
361
                results, err := pipe.Exec(ctx)
×
362
                if err != nil && !errors.Is(err, redis.Nil) {
×
363
                        return nil, fmt.Errorf("failed to fetch check in times: %w", err)
×
364
                }
×
365
                for i, result := range results {
×
366
                        cmd, ok := result.(*redis.StringCmd)
×
367
                        if !ok {
×
368
                                continue // should never happen
×
369
                        }
370
                        b, err := cmd.Bytes()
×
371
                        if err != nil {
×
372
                                if errors.Is(err, redis.Nil) {
×
373
                                        continue
×
374
                                } else {
×
375
                                        l.Errorf("failed to get device: %s", err.Error())
×
376
                                }
×
377
                        } else {
×
378
                                checkInTime := new(time.Time)
×
379
                                err = json.Unmarshal(b, checkInTime)
×
380
                                if err != nil {
×
381
                                        l.Errorf("failed to deserialize check in time: %s", err.Error())
×
382
                                } else {
×
383
                                        checkInTimes[i] = checkInTime
×
384
                                }
×
385

386
                        }
387
                }
388
        } else {
1✔
389
                keys := make([]string, len(ids))
1✔
390
                for i, id := range ids {
2✔
391
                        keys[i] = rl.KeyCheckInTime(tid, id, IdTypeDevice, version)
1✔
392
                }
1✔
393
                res := rl.c.MGet(ctx, keys...)
1✔
394

1✔
395
                for i, v := range res.Val() {
2✔
396
                        if v != nil {
2✔
397
                                b, ok := v.(string)
1✔
398
                                if !ok {
1✔
399
                                        continue
×
400
                                }
401
                                var checkInTime time.Time
1✔
402
                                err := json.Unmarshal([]byte(b), &checkInTime)
1✔
403
                                if err != nil {
1✔
404
                                        l.Errorf("failed to unmarshal check-in time: %s", err.Error())
×
405
                                        continue
×
406
                                }
407
                                checkInTimes[i] = &checkInTime
1✔
408
                        }
409
                }
410
        }
411

412
        return checkInTimes, nil
1✔
413
}
414

415
func (rl *RedisCache) SuspendTenant(
416
        ctx context.Context,
417
        tid string,
418
) error {
×
419
        res := rl.c.Incr(ctx, rl.KeyTenantVersion(tid))
×
420
        return res.Err()
×
421
}
×
422

423
func (rl *RedisCache) getTenantKeyVersion(ctx context.Context, tid string) (int, error) {
1✔
424
        res := rl.c.Get(ctx, rl.KeyTenantVersion(tid))
1✔
425
        if res.Err() != nil {
2✔
426
                if isErrRedisNil(res.Err()) {
2✔
427
                        return 0, nil
1✔
428
                }
1✔
429
                return 0, res.Err()
×
430
        }
431

432
        var version int
×
433

×
434
        err := json.Unmarshal([]byte(res.Val()), &version)
×
435
        if err != nil {
×
436
                return 0, err
×
437
        }
×
438

439
        return version, nil
×
440
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc