• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mendersoftware / deviceauth / 1507843008

13 Sep 2024 11:01AM UTC coverage: 81.326%. Remained the same
1507843008

push

gitlab-ci

web-flow
Merge pull request #727 from mzedel/chore/deprecate

Chore/deprecate

4834 of 5944 relevant lines covered (81.33%)

42.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.63
/cache/cache.go
1
// Copyright 2023 Northern.tech AS
2
//
3
//    Licensed under the Apache License, Version 2.0 (the "License");
4
//    you may not use this file except in compliance with the License.
5
//    You may obtain a copy of the License at
6
//
7
//        http://www.apache.org/licenses/LICENSE-2.0
8
//
9
//    Unless required by applicable law or agreed to in writing, software
10
//    distributed under the License is distributed on an "AS IS" BASIS,
11
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
//    See the License for the specific language governing permissions and
13
//    limitations under the License.
14

15
// Package cache introduces API throttling based
16
// on redis, and functions for auth token management.
17
//
18
// Throttling mechanisms
19
//
20
// 1. Quota enforcement
21
//
22
// Based on https://redislabs.com/redis-best-practices/basic-rate-limiting/, but with a flexible
23
// interval (ratelimits.ApiQuota.IntervalSec).
24
// Current usage for a device lives under key:
25
//
26
// `<prefix>:tenant:<tid>:version:<tenant_key_version>:device:<did>:quota:<interval_num>: <num_reqs>`
27
//
28
// expiring in the defined time window.
29
//
30
// 2. Burst control
31
//
32
// Implemented with a simple single key:
33
//
34
// `<prefix>:tenant:<tid>:version:<tenant_key_version>:device:<did>:burst:<url>:<action>:<interval_num>: <last_req_ts>`
35
//
36
// expiring in ratelimits.ApiBurst.MinIntervalSec.
37
// The value is not really important, just the existence of the key
38
// means the burst was exceeded.
39
//
40
// Token Management
41
//
42
// Tokens are expected at:
43
// `<prefix>:tenant:<tid>:version:<tenant_key_version>:device:<did>:tok: <token>`
44
//
45
// Cache invalidation.
46
// We achieve cache invalidation by incrementing the tenant key version.
47
// Each tenant related key in the cache has to contain tenant key version.
48
// This way, by incrementing tenant key version, we invalidate all tenant
49
// related keys.
50

51
package cache
52

53
import (
54
        "context"
55
        "encoding/json"
56
        "fmt"
57
        "strconv"
58
        "time"
59

60
        "github.com/pkg/errors"
61

62
        "github.com/mendersoftware/go-lib-micro/log"
63
        "github.com/mendersoftware/go-lib-micro/ratelimits"
64
        mredis "github.com/mendersoftware/go-lib-micro/redis"
65
        "github.com/redis/go-redis/v9"
66

67
        "github.com/mendersoftware/deviceauth/utils"
68
)
69

70
const (
	// IdTypeDevice is the id type used in cache keys that belong to a device.
	IdTypeDevice = "device"
	// IdTypeUser is the id type used in cache keys that belong to a user.
	IdTypeUser = "user"
	// CheckInTimeExpiration is the time-to-live of a cached device
	// check-in time: one week.
	CheckInTimeExpiration = 7 * 24 * time.Hour
)
76

77
var (
	// ErrNoPositiveInteger reports a value that must be a positive integer.
	ErrNoPositiveInteger = errors.New("must be a positive integer")
	// ErrNegativeInteger reports a value that must not be negative.
	ErrNegativeInteger = errors.New("cannot be a negative integer")

	// ErrTooManyRequests is returned when a quota or burst limit is exceeded.
	ErrTooManyRequests = errors.New("too many requests")
)
83

84
//go:generate ../utils/mockgen.sh
85
type Cache interface {
86
        // Throttle applies desired api limits and retrieves a cached token.
87
        // These ops are bundled because the implementation will pipeline them for a single network
88
        // roundtrip for max performance.
89
        // Returns:
90
        // - the token (if any)
91
        // - potentially ErrTooManyRequests (other errors: internal)
92
        Throttle(
93
                ctx context.Context,
94
                rawToken string,
95
                l ratelimits.ApiLimits,
96
                tid,
97
                id,
98
                idtype,
99
                url,
100
                action string,
101
        ) (string, error)
102

103
        // CacheToken caches the token under designated key, with expiration
104
        CacheToken(ctx context.Context, tid, id, idtype, token string, expireSec time.Duration) error
105

106
        // DeleteToken deletes the token for 'id'
107
        DeleteToken(ctx context.Context, tid, id, idtype string) error
108

109
        // GetLimits fetches limits for 'id'
110
        GetLimits(ctx context.Context, tid, id, idtype string) (*ratelimits.ApiLimits, error)
111

112
        // CacheLimits saves limits for 'id'
113
        CacheLimits(ctx context.Context, l ratelimits.ApiLimits, tid, id, idtype string) error
114

115
        // CacheCheckInTime caches the last device check in time
116
        CacheCheckInTime(ctx context.Context, t *time.Time, tid, id string) error
117

118
        // GetCheckInTime gets the last device check in time from cache
119
        GetCheckInTime(ctx context.Context, tid, id string) (*time.Time, error)
120

121
        // GetCheckInTimes gets the last device check in time from cache
122
        // for each device with id from the list of ids
123
        GetCheckInTimes(ctx context.Context, tid string, ids []string) ([]*time.Time, error)
124

125
        // SuspendTenant increment tenant key version
126
        // tenant key is used in all cache keys, this way, when we increment the key version,
127
        // all the keys are no longer accessible - in other words, be incrementing tenant key version
128
        // we invalidate all tenant keys
129
        SuspendTenant(ctx context.Context, tid string) error
130
}
131

132
type RedisCache struct {
133
        c               redis.Cmdable
134
        prefix          string
135
        LimitsExpireSec int
136
        clock           utils.Clock
137
}
138

139
func NewRedisCache(
140
        ctx context.Context,
141
        connectionString string,
142
        prefix string,
143
        limitsExpireSec int,
144
) (*RedisCache, error) {
7✔
145
        c, err := mredis.ClientFromConnectionString(ctx, connectionString)
7✔
146
        if err != nil {
7✔
147
                return nil, err
×
148
        }
×
149

150
        return &RedisCache{
7✔
151
                c:               c,
7✔
152
                LimitsExpireSec: limitsExpireSec,
7✔
153
                prefix:          prefix,
7✔
154
                clock:           utils.NewClock(),
7✔
155
        }, err
7✔
156
}
157

158
func (rl *RedisCache) WithClock(c utils.Clock) *RedisCache {
2✔
159
        rl.clock = c
2✔
160
        return rl
2✔
161
}
2✔
162

163
func (rl *RedisCache) Throttle(
164
        ctx context.Context,
165
        rawToken string,
166
        l ratelimits.ApiLimits,
167
        tid,
168
        id,
169
        idtype,
170
        url,
171
        action string,
172
) (string, error) {
123✔
173
        now := rl.clock.Now().Unix()
123✔
174

123✔
175
        var tokenGet *redis.StringCmd
123✔
176
        var quotaInc *redis.IntCmd
123✔
177
        var quotaExp *redis.BoolCmd
123✔
178
        var burstGet *redis.StringCmd
123✔
179
        var burstSet *redis.StatusCmd
123✔
180

123✔
181
        pipe := rl.c.Pipeline()
123✔
182

123✔
183
        version, err := rl.getTenantKeyVersion(ctx, tid)
123✔
184
        if err != nil {
123✔
185
                return "", err
×
186
        }
×
187

188
        // queue quota/burst control and token fetching
189
        // for piped execution
190
        quotaInc, quotaExp = rl.pipeQuota(ctx, pipe, l, tid, id, idtype, now, version)
123✔
191
        tokenGet = rl.pipeToken(ctx, pipe, tid, id, idtype, version)
123✔
192

123✔
193
        burstGet, burstSet = rl.pipeBurst(ctx,
123✔
194
                pipe,
123✔
195
                l,
123✔
196
                tid, id, idtype,
123✔
197
                url, action,
123✔
198
                now, version)
123✔
199

123✔
200
        _, err = pipe.Exec(ctx)
123✔
201
        if err != nil && !isErrRedisNil(err) {
123✔
202
                return "", err
×
203
        }
×
204

205
        // collect quota/burst control and token fetch results
206
        tok, err := rl.checkToken(tokenGet, rawToken)
123✔
207
        if err != nil {
123✔
208
                return "", err
×
209
        }
×
210

211
        err = rl.checkQuota(l, quotaInc, quotaExp)
123✔
212
        if err != nil {
143✔
213
                return "", err
20✔
214
        }
20✔
215

216
        err = rl.checkBurst(burstGet, burstSet)
103✔
217
        if err != nil {
136✔
218
                return "", err
33✔
219
        }
33✔
220

221
        return tok, nil
70✔
222
}
223

224
func (rl *RedisCache) pipeToken(
225
        ctx context.Context,
226
        pipe redis.Pipeliner,
227
        tid,
228
        id,
229
        idtype string,
230
        version int,
231
) *redis.StringCmd {
123✔
232
        key := rl.KeyToken(tid, id, idtype, version)
123✔
233
        return pipe.Get(ctx, key)
123✔
234
}
123✔
235

236
func (rl *RedisCache) checkToken(cmd *redis.StringCmd, raw string) (string, error) {
123✔
237
        err := cmd.Err()
123✔
238

123✔
239
        if err != nil {
241✔
240
                if isErrRedisNil(err) {
236✔
241
                        return "", nil
118✔
242
                }
118✔
243
                return "", err
×
244
        }
245

246
        token := cmd.Val()
5✔
247
        if token == raw {
9✔
248
                return token, nil
4✔
249
        } else {
5✔
250
                // must be a stale token - we don't want to use it
1✔
251
                // let it expire in the background
1✔
252
                return "", nil
1✔
253
        }
1✔
254
}
255

256
func (rl *RedisCache) pipeQuota(
257
        ctx context.Context,
258
        pipe redis.Pipeliner,
259
        l ratelimits.ApiLimits,
260
        tid,
261
        id,
262
        idtype string,
263
        now int64,
264
        version int,
265
) (*redis.IntCmd, *redis.BoolCmd) {
123✔
266
        var incr *redis.IntCmd
123✔
267
        var expire *redis.BoolCmd
123✔
268

123✔
269
        // not a default/empty quota
123✔
270
        if l.ApiQuota.MaxCalls != 0 {
197✔
271
                intvl := int64(now / int64(l.ApiQuota.IntervalSec))
74✔
272
                keyQuota := rl.KeyQuota(tid, id, idtype, strconv.FormatInt(intvl, 10), version)
74✔
273
                incr = pipe.Incr(ctx, keyQuota)
74✔
274
                expire = pipe.Expire(ctx, keyQuota, time.Duration(l.ApiQuota.IntervalSec)*time.Second)
74✔
275
        }
74✔
276

277
        return incr, expire
123✔
278
}
279

280
func (rl *RedisCache) checkQuota(
281
        l ratelimits.ApiLimits,
282
        incr *redis.IntCmd,
283
        expire *redis.BoolCmd,
284
) error {
123✔
285
        if incr == nil && expire == nil {
172✔
286
                return nil
49✔
287
        }
49✔
288

289
        err := incr.Err()
74✔
290
        if err != nil && !isErrRedisNil(err) {
74✔
291
                return err
×
292
        }
×
293

294
        err = expire.Err()
74✔
295
        if err != nil {
74✔
296
                return err
×
297
        }
×
298

299
        quota := incr.Val()
74✔
300
        if quota > int64(l.ApiQuota.MaxCalls) {
94✔
301
                return ErrTooManyRequests
20✔
302
        }
20✔
303

304
        return nil
54✔
305
}
306

307
func (rl *RedisCache) pipeBurst(ctx context.Context,
308
        pipe redis.Pipeliner,
309
        l ratelimits.ApiLimits,
310
        tid, id, idtype, url, action string,
311
        now int64, version int) (*redis.StringCmd, *redis.StatusCmd) {
123✔
312
        var get *redis.StringCmd
123✔
313
        var set *redis.StatusCmd
123✔
314

123✔
315
        for _, b := range l.ApiBursts {
207✔
316
                if b.Action == action &&
84✔
317
                        b.Uri == url &&
84✔
318
                        b.MinIntervalSec != 0 {
153✔
319

69✔
320
                        intvl := int64(now / int64(b.MinIntervalSec))
69✔
321
                        keyBurst := rl.KeyBurst(
69✔
322
                                tid, id, idtype, url, action, strconv.FormatInt(intvl, 10), version)
69✔
323

69✔
324
                        get = pipe.Get(ctx, keyBurst)
69✔
325
                        set = pipe.Set(ctx, keyBurst, now, time.Duration(b.MinIntervalSec)*time.Second)
69✔
326
                }
69✔
327
        }
328

329
        return get, set
123✔
330
}
331

332
func (rl *RedisCache) checkBurst(get *redis.StringCmd, set *redis.StatusCmd) error {
103✔
333
        if get != nil && set != nil {
167✔
334
                err := get.Err()
64✔
335

64✔
336
                // no error means burst was found/hit
64✔
337
                if err == nil {
97✔
338
                        return ErrTooManyRequests
33✔
339
                }
33✔
340

341
                if isErrRedisNil(err) {
62✔
342
                        return nil
31✔
343
                }
31✔
344

345
                return err
×
346
        }
347

348
        return nil
39✔
349
}
350

351
func (rl *RedisCache) CacheToken(
352
        ctx context.Context,
353
        tid,
354
        id,
355
        idtype,
356
        token string,
357
        expire time.Duration,
358
) error {
3✔
359
        version, err := rl.getTenantKeyVersion(ctx, tid)
3✔
360
        if err != nil {
3✔
361
                return err
×
362
        }
×
363
        res := rl.c.Set(ctx, rl.KeyToken(tid, id, idtype, version),
3✔
364
                token,
3✔
365
                expire)
3✔
366
        return res.Err()
3✔
367
}
368

369
func (rl *RedisCache) DeleteToken(ctx context.Context, tid, id, idtype string) error {
2✔
370
        version, err := rl.getTenantKeyVersion(ctx, tid)
2✔
371
        if err != nil {
2✔
372
                return err
×
373
        }
×
374
        res := rl.c.Del(ctx, rl.KeyToken(tid, id, idtype, version))
2✔
375
        return res.Err()
2✔
376
}
377

378
func (rl *RedisCache) GetLimits(
379
        ctx context.Context,
380
        tid,
381
        id,
382
        idtype string,
383
) (*ratelimits.ApiLimits, error) {
3✔
384

3✔
385
        version, err := rl.getTenantKeyVersion(ctx, tid)
3✔
386
        if err != nil {
3✔
387
                return nil, err
×
388
        }
×
389

390
        res := rl.c.Get(ctx, rl.KeyLimits(tid, id, idtype, version))
3✔
391

3✔
392
        if res.Err() != nil {
5✔
393
                if isErrRedisNil(res.Err()) {
4✔
394
                        return nil, nil
2✔
395
                }
2✔
396
                return nil, res.Err()
×
397
        }
398

399
        var limits ratelimits.ApiLimits
1✔
400

1✔
401
        err = json.Unmarshal([]byte(res.Val()), &limits)
1✔
402
        if err != nil {
1✔
403
                return nil, err
×
404
        }
×
405

406
        return &limits, nil
1✔
407
}
408

409
func (rl *RedisCache) CacheLimits(
410
        ctx context.Context,
411
        l ratelimits.ApiLimits,
412
        tid,
413
        id,
414
        idtype string,
415
) error {
1✔
416
        enc, err := json.Marshal(l)
1✔
417
        if err != nil {
1✔
418
                return err
×
419
        }
×
420

421
        version, err := rl.getTenantKeyVersion(ctx, tid)
1✔
422
        if err != nil {
1✔
423
                return err
×
424
        }
×
425

426
        res := rl.c.Set(
1✔
427
                ctx,
1✔
428
                rl.KeyLimits(tid, id, idtype, version),
1✔
429
                enc,
1✔
430
                time.Duration(rl.LimitsExpireSec)*time.Second,
1✔
431
        )
1✔
432

1✔
433
        return res.Err()
1✔
434
}
435

436
func (rl *RedisCache) KeyQuota(tid, id, idtype, intvlNum string, version int) string {
74✔
437
        return fmt.Sprintf(
74✔
438
                "%s:tenant:%s:version:%d:%s:%s:quota:%s",
74✔
439
                rl.prefix, tid, version, idtype, id, intvlNum)
74✔
440
}
74✔
441

442
func (rl *RedisCache) KeyBurst(
443
        tid, id, idtype, url, action, intvlNum string, version int) string {
69✔
444
        return fmt.Sprintf(
69✔
445
                "%s:tenant:%s:version:%d:%s:%s:burst:%s:%s:%s",
69✔
446
                rl.prefix, tid, version, idtype, id, url, action, intvlNum)
69✔
447
}
69✔
448

449
func (rl *RedisCache) KeyToken(tid, id, idtype string, version int) string {
132✔
450
        return fmt.Sprintf(
132✔
451
                "%s:tenant:%s:version:%d:%s:%s:tok",
132✔
452
                rl.prefix, tid, version, idtype, id)
132✔
453
}
132✔
454

455
func (rl *RedisCache) KeyLimits(tid, id, idtype string, version int) string {
4✔
456
        return fmt.Sprintf(
4✔
457
                "%s:tenant:%s:version:%d:%s:%s:limits",
4✔
458
                rl.prefix, tid, version, idtype, id)
4✔
459
}
4✔
460

461
func (rl *RedisCache) KeyCheckInTime(tid, id, idtype string, version int) string {
4✔
462
        return fmt.Sprintf(
4✔
463
                "%s:tenant:%s:version:%d:%s:%s:checkInTime",
4✔
464
                rl.prefix, tid, version, idtype, id)
4✔
465
}
4✔
466

467
func (rl *RedisCache) KeyTenantVersion(tid string) string {
136✔
468
        return fmt.Sprintf("%s:tenant:%s:version", rl.prefix, tid)
136✔
469
}
136✔
470

471
// isErrRedisNil checks for a very common non-error, "redis: nil",
472
// which just means the key was not found, and is normal
473
// it's routinely returned e.g. from GET, or pipelines containing it
474
func isErrRedisNil(e error) bool {
406✔
475
        return errors.Is(e, redis.Nil)
406✔
476
}
406✔
477

478
// TODO: move to go-lib-micro/ratelimits
479
func LimitsEmpty(l *ratelimits.ApiLimits) bool {
×
480
        return l.ApiQuota.MaxCalls == 0 &&
×
481
                l.ApiQuota.IntervalSec == 0 &&
×
482
                len(l.ApiBursts) == 0
×
483
}
×
484

485
func (rl *RedisCache) CacheCheckInTime(
486
        ctx context.Context,
487
        t *time.Time,
488
        tid,
489
        id string,
490
) error {
1✔
491
        tj, err := json.Marshal(t)
1✔
492
        if err != nil {
1✔
493
                return err
×
494
        }
×
495

496
        version, err := rl.getTenantKeyVersion(ctx, tid)
1✔
497
        if err != nil {
1✔
498
                return err
×
499
        }
×
500

501
        res := rl.c.Set(
1✔
502
                ctx,
1✔
503
                rl.KeyCheckInTime(tid, id, IdTypeDevice, version),
1✔
504
                tj,
1✔
505
                CheckInTimeExpiration,
1✔
506
        )
1✔
507

1✔
508
        return res.Err()
1✔
509
}
510

511
func (rl *RedisCache) GetCheckInTime(
512
        ctx context.Context,
513
        tid,
514
        id string,
515
) (*time.Time, error) {
2✔
516

2✔
517
        version, err := rl.getTenantKeyVersion(ctx, tid)
2✔
518
        if err != nil {
2✔
519
                return nil, err
×
520
        }
×
521

522
        res := rl.c.Get(ctx, rl.KeyCheckInTime(tid, id, IdTypeDevice, version))
2✔
523

2✔
524
        if res.Err() != nil {
3✔
525
                if isErrRedisNil(res.Err()) {
2✔
526
                        return nil, nil
1✔
527
                }
1✔
528
                return nil, res.Err()
×
529
        }
530

531
        var checkInTime time.Time
1✔
532

1✔
533
        err = json.Unmarshal([]byte(res.Val()), &checkInTime)
1✔
534
        if err != nil {
1✔
535
                return nil, err
×
536
        }
×
537

538
        return &checkInTime, nil
1✔
539
}
540

541
func (rl *RedisCache) GetCheckInTimes(
542
        ctx context.Context,
543
        tid string,
544
        ids []string,
545
) ([]*time.Time, error) {
1✔
546
        l := log.FromContext(ctx)
1✔
547

1✔
548
        version, err := rl.getTenantKeyVersion(ctx, tid)
1✔
549
        if err != nil {
1✔
550
                return nil, err
×
551
        }
×
552
        checkInTimes := make([]*time.Time, len(ids))
1✔
553
        if _, ok := rl.c.(*redis.ClusterClient); ok {
1✔
554
                pipe := rl.c.Pipeline()
×
555
                for _, id := range ids {
×
556
                        pipe.Get(ctx, rl.KeyCheckInTime(tid, id, IdTypeDevice, version))
×
557
                }
×
558
                results, err := pipe.Exec(ctx)
×
559
                if err != nil && !errors.Is(err, redis.Nil) {
×
560
                        return nil, fmt.Errorf("failed to fetch check in times: %w", err)
×
561
                }
×
562
                for i, result := range results {
×
563
                        cmd, ok := result.(*redis.StringCmd)
×
564
                        if !ok {
×
565
                                continue // should never happen
×
566
                        }
567
                        b, err := cmd.Bytes()
×
568
                        if err != nil {
×
569
                                if errors.Is(err, redis.Nil) {
×
570
                                        continue
×
571
                                } else {
×
572
                                        l.Errorf("failed to get device: %s", err.Error())
×
573
                                }
×
574
                        } else {
×
575
                                checkInTime := new(time.Time)
×
576
                                err = json.Unmarshal(b, checkInTime)
×
577
                                if err != nil {
×
578
                                        l.Errorf("failed to deserialize check in time: %s", err.Error())
×
579
                                } else {
×
580
                                        checkInTimes[i] = checkInTime
×
581
                                }
×
582

583
                        }
584
                }
585
        } else {
1✔
586
                keys := make([]string, len(ids))
1✔
587
                for i, id := range ids {
2✔
588
                        keys[i] = rl.KeyCheckInTime(tid, id, IdTypeDevice, version)
1✔
589
                }
1✔
590
                res := rl.c.MGet(ctx, keys...)
1✔
591

1✔
592
                for i, v := range res.Val() {
2✔
593
                        if v != nil {
2✔
594
                                b, ok := v.(string)
1✔
595
                                if !ok {
1✔
596
                                        continue
×
597
                                }
598
                                var checkInTime time.Time
1✔
599
                                err := json.Unmarshal([]byte(b), &checkInTime)
1✔
600
                                if err != nil {
1✔
601
                                        l.Errorf("failed to unmarshal check-in time: %s", err.Error())
×
602
                                        continue
×
603
                                }
604
                                checkInTimes[i] = &checkInTime
1✔
605
                        }
606
                }
607
        }
608

609
        return checkInTimes, nil
1✔
610
}
611

612
func (rl *RedisCache) SuspendTenant(
613
        ctx context.Context,
614
        tid string,
615
) error {
×
616
        res := rl.c.Incr(ctx, rl.KeyTenantVersion(tid))
×
617
        return res.Err()
×
618
}
×
619

620
func (rl *RedisCache) getTenantKeyVersion(ctx context.Context, tid string) (int, error) {
136✔
621
        res := rl.c.Get(ctx, rl.KeyTenantVersion(tid))
136✔
622
        if res.Err() != nil {
272✔
623
                if isErrRedisNil(res.Err()) {
272✔
624
                        return 0, nil
136✔
625
                }
136✔
626
                return 0, res.Err()
×
627
        }
628

629
        var version int
×
630

×
631
        err := json.Unmarshal([]byte(res.Val()), &version)
×
632
        if err != nil {
×
633
                return 0, err
×
634
        }
×
635

636
        return version, nil
×
637
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc