• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mendersoftware / deviceauth / 857655394

pending completion
857655394

Pull #644

gitlab-ci

Krzysztof Jaskiewicz
chore: reduce cyclomatic complexity of VerifyToken method
Pull Request #644: feat: handle device check-in time

106 of 160 new or added lines in 3 files covered. (66.25%)

99 existing lines in 3 files now uncovered.

4627 of 5519 relevant lines covered (83.84%)

46.04 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.54
/cache/cache.go
1
// Copyright 2023 Northern.tech AS
2
//
3
//    Licensed under the Apache License, Version 2.0 (the "License");
4
//    you may not use this file except in compliance with the License.
5
//    You may obtain a copy of the License at
6
//
7
//        http://www.apache.org/licenses/LICENSE-2.0
8
//
9
//    Unless required by applicable law or agreed to in writing, software
10
//    distributed under the License is distributed on an "AS IS" BASIS,
11
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
//    See the License for the specific language governing permissions and
13
//    limitations under the License.
14

15
// Package cache introduces API throttling based
16
// on redis, and functions for auth token management.
17
//
18
// Throttling mechanisms
19
//
20
// 1. Quota enforcement
21
//
22
// Based on https://redislabs.com/redis-best-practices/basic-rate-limiting/, but with a flexible
23
// interval (ratelimits.ApiQuota.IntervalSec).
24
// Current usage for a device lives under key:
25
//
26
// `tenant:<tid>:device:<did>:quota:<interval_num>: <num_reqs>`
27
//
28
// expiring in the defined time window.
29
//
30
// 2. Burst control
31
//
32
// Implemented with a simple single key:
33
//
34
// `tenant:<tid>:device:<did>:burst:<url>:<action>:<interval_num>: <last_req_ts>`
35
//
36
// expiring in ratelimits.ApiBurst.MinIntervalSec.
37
// The value is not really important, just the existence of the key
38
// means the burst was exceeded.
39
//
40
// Token Management
41
//
42
// Tokens are expected at:
43
// `tenant:<tid>:device:<did>:tok: <token>`
44
//
45

46
package cache
47

48
import (
49
        "context"
50
        "encoding/json"
51
        "fmt"
52
        "strconv"
53
        "time"
54

55
        "github.com/pkg/errors"
56

57
        "github.com/go-redis/redis/v8"
58
        "github.com/mendersoftware/go-lib-micro/log"
59
        "github.com/mendersoftware/go-lib-micro/ratelimits"
60

61
        "github.com/mendersoftware/deviceauth/utils"
62
)
63

64
const (
	// IdTypeDevice is the identity type used in cache keys for devices
	// (e.g. by the check-in time functions).
	IdTypeDevice = "device"
	// IdTypeUser is the identity type for users.
	IdTypeUser   = "user"

	// CheckInTimeExpiration is how long a cached device check-in time is
	// kept before Redis expires it: one week.
	CheckInTimeExpiration = 7 * 24 * time.Hour
)
70

71
// ErrTooManyRequests is returned by Throttle when either the quota or a
// matching burst limit has been exceeded.
var ErrTooManyRequests = errors.New("too many requests")
74

75
//go:generate ../utils/mockgen.sh
76
type Cache interface {
77
        // Throttle applies desired api limits and retrieves a cached token.
78
        // These ops are bundled because the implementation will pipeline them for a single network
79
        // roundtrip for max performance.
80
        // Returns:
81
        // - the token (if any)
82
        // - potentially ErrTooManyRequests (other errors: internal)
83
        Throttle(
84
                ctx context.Context,
85
                rawToken string,
86
                l ratelimits.ApiLimits,
87
                tid,
88
                id,
89
                idtype,
90
                url,
91
                action string,
92
        ) (string, error)
93

94
        // CacheToken caches the token under designated key, with expiration
95
        CacheToken(ctx context.Context, tid, id, idtype, token string, expireSec time.Duration) error
96

97
        // DeleteToken deletes the token for 'id'
98
        DeleteToken(ctx context.Context, tid, id, idtype string) error
99

100
        // GetLimits fetches limits for 'id'
101
        GetLimits(ctx context.Context, tid, id, idtype string) (*ratelimits.ApiLimits, error)
102

103
        // CacheLimits saves limits for 'id'
104
        CacheLimits(ctx context.Context, l ratelimits.ApiLimits, tid, id, idtype string) error
105

106
        // FlushDB clears the whole db asynchronously (FLUSHDB ASYNC)
107
        // TODO: replace with more fine grained key removal (per tenant)
108
        FlushDB(ctx context.Context) error
109

110
        // CacheCheckInTime caches the last device check in time
111
        CacheCheckInTime(ctx context.Context, t *time.Time, tid, id string) error
112

113
        // GetCheckInTime gets the last device check in time from cache
114
        GetCheckInTime(ctx context.Context, tid, id string) (*time.Time, error)
115

116
        // GetCheckInTimes gets the last device check in time from cache
117
        // for each device with id from the list of ids
118
        GetCheckInTimes(ctx context.Context, tid string, ids []string) ([]*time.Time, error)
119
}
120

121
type RedisCache struct {
122
        c               *redis.Client
123
        LimitsExpireSec int
124
        clock           utils.Clock
125
}
126

127
func NewRedisCache(
128
        addr,
129
        user,
130
        pass string,
131
        db int,
132
        timeoutSec,
133
        limitsExpireSec int,
134
) (*RedisCache, error) {
16✔
135
        c := redis.NewClient(&redis.Options{
16✔
136
                Addr:     addr,
16✔
137
                Username: user,
16✔
138
                Password: pass,
16✔
139
                DB:       db,
16✔
140
        })
16✔
141

16✔
142
        c = c.WithTimeout(time.Duration(timeoutSec) * time.Second)
16✔
143

16✔
144
        _, err := c.Ping(context.TODO()).Result()
16✔
145
        return &RedisCache{
16✔
146
                c:               c,
16✔
147
                LimitsExpireSec: limitsExpireSec,
16✔
148
                clock:           utils.NewClock(),
16✔
149
        }, err
16✔
150
}
16✔
151

152
func (rl *RedisCache) WithClock(c utils.Clock) *RedisCache {
4✔
153
        rl.clock = c
4✔
154
        return rl
4✔
155
}
4✔
156

157
func (rl *RedisCache) Throttle(
158
        ctx context.Context,
159
        rawToken string,
160
        l ratelimits.ApiLimits,
161
        tid,
162
        id,
163
        idtype,
164
        url,
165
        action string,
166
) (string, error) {
250✔
167
        now := rl.clock.Now().Unix()
250✔
168

250✔
169
        var tokenGet *redis.StringCmd
250✔
170
        var quotaInc *redis.IntCmd
250✔
171
        var quotaExp *redis.BoolCmd
250✔
172
        var burstGet *redis.StringCmd
250✔
173
        var burstSet *redis.StatusCmd
250✔
174

250✔
175
        pipe := rl.c.TxPipeline()
250✔
176

250✔
177
        // queue quota/burst control and token fetching
250✔
178
        // for piped execution
250✔
179
        quotaInc, quotaExp = rl.pipeQuota(ctx, pipe, l, tid, id, idtype, now)
250✔
180
        tokenGet = rl.pipeToken(ctx, pipe, tid, id, idtype)
250✔
181

250✔
182
        burstGet, burstSet = rl.pipeBurst(ctx,
250✔
183
                pipe,
250✔
184
                l,
250✔
185
                tid, id, idtype,
250✔
186
                url, action,
250✔
187
                now)
250✔
188

250✔
189
        _, err := pipe.Exec(ctx)
250✔
190
        if err != nil && !isErrRedisNil(err) {
250✔
UNCOV
191
                return "", err
×
UNCOV
192
        }
×
193

194
        // collect quota/burst control and token fetch results
195
        tok, err := rl.checkToken(tokenGet, rawToken)
250✔
196
        if err != nil {
250✔
UNCOV
197
                return "", err
×
UNCOV
198
        }
×
199

200
        err = rl.checkQuota(l, quotaInc, quotaExp)
250✔
201
        if err != nil {
290✔
202
                return "", err
40✔
203
        }
40✔
204

205
        err = rl.checkBurst(burstGet, burstSet)
210✔
206
        if err != nil {
276✔
207
                return "", err
66✔
208
        }
66✔
209

210
        return tok, nil
144✔
211
}
212

213
func (rl *RedisCache) pipeToken(
214
        ctx context.Context,
215
        pipe redis.Pipeliner,
216
        tid,
217
        id,
218
        idtype string,
219
) *redis.StringCmd {
250✔
220
        key := KeyToken(tid, id, idtype)
250✔
221
        return pipe.Get(ctx, key)
250✔
222
}
250✔
223

224
func (rl *RedisCache) checkToken(cmd *redis.StringCmd, raw string) (string, error) {
250✔
225
        err := cmd.Err()
250✔
226

250✔
227
        if err != nil {
490✔
228
                if isErrRedisNil(err) {
480✔
229
                        return "", nil
240✔
230
                }
240✔
UNCOV
231
                return "", err
×
232
        }
233

234
        token := cmd.Val()
10✔
235
        if token == raw {
18✔
236
                return token, nil
8✔
237
        } else {
10✔
238
                // must be a stale token - we don't want to use it
2✔
239
                // let it expire in the background
2✔
240
                return "", nil
2✔
241
        }
2✔
242
}
243

244
func (rl *RedisCache) pipeQuota(
245
        ctx context.Context,
246
        pipe redis.Pipeliner,
247
        l ratelimits.ApiLimits,
248
        tid,
249
        id,
250
        idtype string,
251
        now int64,
252
) (*redis.IntCmd, *redis.BoolCmd) {
250✔
253
        var incr *redis.IntCmd
250✔
254
        var expire *redis.BoolCmd
250✔
255

250✔
256
        // not a default/empty quota
250✔
257
        if l.ApiQuota.MaxCalls != 0 {
398✔
258
                intvl := int64(now / int64(l.ApiQuota.IntervalSec))
148✔
259
                keyQuota := KeyQuota(tid, id, idtype, strconv.FormatInt(intvl, 10))
148✔
260
                incr = pipe.Incr(ctx, keyQuota)
148✔
261
                expire = pipe.Expire(ctx, keyQuota, time.Duration(l.ApiQuota.IntervalSec)*time.Second)
148✔
262
        }
148✔
263

264
        return incr, expire
250✔
265
}
266

267
func (rl *RedisCache) checkQuota(
268
        l ratelimits.ApiLimits,
269
        incr *redis.IntCmd,
270
        expire *redis.BoolCmd,
271
) error {
250✔
272
        if incr == nil && expire == nil {
352✔
273
                return nil
102✔
274
        }
102✔
275

276
        err := incr.Err()
148✔
277
        if err != nil && !isErrRedisNil(err) {
148✔
UNCOV
278
                return err
×
UNCOV
279
        }
×
280

281
        err = expire.Err()
148✔
282
        if err != nil {
148✔
UNCOV
283
                return err
×
UNCOV
284
        }
×
285

286
        quota := incr.Val()
148✔
287
        if quota > int64(l.ApiQuota.MaxCalls) {
188✔
288
                return ErrTooManyRequests
40✔
289
        }
40✔
290

291
        return nil
108✔
292
}
293

294
func (rl *RedisCache) pipeBurst(ctx context.Context,
295
        pipe redis.Pipeliner,
296
        l ratelimits.ApiLimits,
297
        tid, id, idtype, url, action string,
298
        now int64) (*redis.StringCmd, *redis.StatusCmd) {
250✔
299
        var get *redis.StringCmd
250✔
300
        var set *redis.StatusCmd
250✔
301

250✔
302
        for _, b := range l.ApiBursts {
418✔
303
                if b.Action == action &&
168✔
304
                        b.Uri == url &&
168✔
305
                        b.MinIntervalSec != 0 {
306✔
306

138✔
307
                        intvl := int64(now / int64(b.MinIntervalSec))
138✔
308
                        keyBurst := KeyBurst(tid, id, idtype, url, action, strconv.FormatInt(intvl, 10))
138✔
309

138✔
310
                        get = pipe.Get(ctx, keyBurst)
138✔
311
                        set = pipe.Set(ctx, keyBurst, now, time.Duration(b.MinIntervalSec)*time.Second)
138✔
312
                }
138✔
313
        }
314

315
        return get, set
250✔
316
}
317

318
func (rl *RedisCache) checkBurst(get *redis.StringCmd, set *redis.StatusCmd) error {
210✔
319
        if get != nil && set != nil {
338✔
320
                err := get.Err()
128✔
321

128✔
322
                // no error means burst was found/hit
128✔
323
                if err == nil {
194✔
324
                        return ErrTooManyRequests
66✔
325
                }
66✔
326

327
                if isErrRedisNil(err) {
124✔
328
                        return nil
62✔
329
                }
62✔
330

UNCOV
331
                return err
×
332
        }
333

334
        return nil
82✔
335
}
336

337
func (rl *RedisCache) CacheToken(
338
        ctx context.Context,
339
        tid,
340
        id,
341
        idtype,
342
        token string,
343
        expire time.Duration,
344
) error {
10✔
345
        res := rl.c.Set(ctx, KeyToken(tid, id, idtype),
10✔
346
                token,
10✔
347
                expire)
10✔
348
        return res.Err()
10✔
349
}
10✔
350

351
func (rl *RedisCache) DeleteToken(ctx context.Context, tid, id, idtype string) error {
6✔
352
        res := rl.c.Del(ctx, KeyToken(tid, id, idtype))
6✔
353
        return res.Err()
6✔
354
}
6✔
355

356
func (rl *RedisCache) GetLimits(
357
        ctx context.Context,
358
        tid,
359
        id,
360
        idtype string,
361
) (*ratelimits.ApiLimits, error) {
6✔
362
        res := rl.c.Get(ctx, KeyLimits(tid, id, idtype))
6✔
363

6✔
364
        if res.Err() != nil {
10✔
365
                if isErrRedisNil(res.Err()) {
8✔
366
                        return nil, nil
4✔
367
                }
4✔
UNCOV
368
                return nil, res.Err()
×
369
        }
370

371
        var limits ratelimits.ApiLimits
2✔
372

2✔
373
        err := json.Unmarshal([]byte(res.Val()), &limits)
2✔
374
        if err != nil {
2✔
UNCOV
375
                return nil, err
×
UNCOV
376
        }
×
377

378
        return &limits, nil
2✔
379
}
380

381
func (rl *RedisCache) CacheLimits(
382
        ctx context.Context,
383
        l ratelimits.ApiLimits,
384
        tid,
385
        id,
386
        idtype string,
387
) error {
2✔
388
        enc, err := json.Marshal(l)
2✔
389
        if err != nil {
2✔
UNCOV
390
                return err
×
UNCOV
391
        }
×
392

393
        res := rl.c.Set(
2✔
394
                ctx,
2✔
395
                KeyLimits(tid, id, idtype),
2✔
396
                enc,
2✔
397
                time.Duration(rl.LimitsExpireSec)*time.Second,
2✔
398
        )
2✔
399

2✔
400
        return res.Err()
2✔
401
}
402

403
func (rl *RedisCache) FlushDB(ctx context.Context) error {
2✔
404
        return rl.c.FlushDBAsync(ctx).Err()
2✔
405
}
2✔
406

407
// KeyQuota returns the key of the request counter for quota interval
// number intvlNum:
//
//	tenant:<tid>:<idtype>:<id>:quota:<intvlNum>
//
// Built by concatenation (one allocation, no fmt boxing) because it is
// computed on the Throttle path.
func KeyQuota(tid, id, idtype, intvlNum string) string {
	return "tenant:" + tid + ":" + idtype + ":" + id + ":quota:" + intvlNum
}
410

411
// KeyBurst returns the burst-control key for url/action in window number
// intvlNum:
//
//	tenant:<tid>:<idtype>:<id>:burst:<url>:<action>:<intvlNum>
//
// Built by concatenation (one allocation, no fmt boxing) because it is
// computed on the Throttle path.
func KeyBurst(tid, id, idtype, url, action, intvlNum string) string {
	return "tenant:" + tid + ":" + idtype + ":" + id +
		":burst:" + url + ":" + action + ":" + intvlNum
}
414

415
// KeyToken returns the key under which an identity's token is cached:
//
//	tenant:<tid>:<idtype>:<id>:tok
//
// Built by concatenation (one allocation, no fmt boxing) because it is
// computed on every Throttle call.
func KeyToken(tid, id, idtype string) string {
	return "tenant:" + tid + ":" + idtype + ":" + id + ":tok"
}
418

419
// KeyLimits returns the key under which an identity's API limits are cached:
//
//	tenant:<tid>:<idtype>:<id>:limits
//
// Built by concatenation, consistent with the other key builders.
func KeyLimits(tid, id, idtype string) string {
	return "tenant:" + tid + ":" + idtype + ":" + id + ":limits"
}
422

423
// KeyCheckInTime returns the key holding a device's last check-in time:
//
//	tenant:<tid>:<idtype>:<id>:checkInTime
//
// The explicit argument indexes make the tid/idtype/id ordering visible.
func KeyCheckInTime(tid, id, idtype string) string {
	const layout = "tenant:%[1]s:%[3]s:%[2]s:checkInTime"
	return fmt.Sprintf(layout, tid, id, idtype)
}
426

427
// isErrRedisNil reports whether e is the very common non-error "redis: nil",
// which just means the key was not found, and is normal.
// It's routinely returned e.g. from GET, or pipelines containing it.
//
// A nil e is not such an error (and no longer panics). The check compares
// the error text exactly, so an error that wraps "redis: nil" with extra
// context is not recognised.
func isErrRedisNil(e error) bool {
	return e != nil && e.Error() == "redis: nil"
}
433

434
// TODO: move to go-lib-micro/ratelimits
UNCOV
435
func LimitsEmpty(l *ratelimits.ApiLimits) bool {
×
UNCOV
436
        return l.ApiQuota.MaxCalls == 0 &&
×
UNCOV
437
                l.ApiQuota.IntervalSec == 0 &&
×
UNCOV
438
                len(l.ApiBursts) == 0
×
UNCOV
439
}
×
440

441
func (rl *RedisCache) CacheCheckInTime(
442
        ctx context.Context,
443
        t *time.Time,
444
        tid,
445
        id string,
446
) error {
2✔
447
        tj, err := json.Marshal(t)
2✔
448
        if err != nil {
2✔
NEW
449
                return err
×
NEW
450
        }
×
451

452
        res := rl.c.Set(
2✔
453
                ctx,
2✔
454
                KeyCheckInTime(tid, id, IdTypeDevice),
2✔
455
                tj,
2✔
456
                time.Duration(CheckInTimeExpiration),
2✔
457
        )
2✔
458

2✔
459
        return res.Err()
2✔
460
}
461

462
func (rl *RedisCache) GetCheckInTime(
463
        ctx context.Context,
464
        tid,
465
        id string,
466
) (*time.Time, error) {
4✔
467
        res := rl.c.Get(ctx, KeyCheckInTime(tid, id, IdTypeDevice))
4✔
468

4✔
469
        if res.Err() != nil {
6✔
470
                if isErrRedisNil(res.Err()) {
4✔
471
                        return nil, nil
2✔
472
                }
2✔
NEW
473
                return nil, res.Err()
×
474
        }
475

476
        var checkInTime time.Time
2✔
477

2✔
478
        err := json.Unmarshal([]byte(res.Val()), &checkInTime)
2✔
479
        if err != nil {
2✔
NEW
480
                return nil, err
×
NEW
481
        }
×
482

483
        return &checkInTime, nil
2✔
484
}
485

486
func (rl *RedisCache) GetCheckInTimes(
487
        ctx context.Context,
488
        tid string,
489
        ids []string,
490
) ([]*time.Time, error) {
2✔
491
        keys := make([]string, len(ids))
2✔
492
        for i, id := range ids {
4✔
493
                keys[i] = KeyCheckInTime(tid, id, IdTypeDevice)
2✔
494
        }
2✔
495

496
        res := rl.c.MGet(ctx, keys...)
2✔
497

2✔
498
        checkInTimes := make([]*time.Time, len(ids))
2✔
499

2✔
500
        for i, v := range res.Val() {
4✔
501
                if v != nil {
4✔
502
                        b, ok := v.(string)
2✔
503
                        if !ok {
2✔
NEW
504
                                continue
×
505
                        }
506
                        var checkInTime time.Time
2✔
507
                        err := json.Unmarshal([]byte(b), &checkInTime)
2✔
508
                        checkInTimes[i] = &checkInTime
2✔
509
                        if err != nil {
2✔
NEW
510
                                l := log.FromContext(ctx)
×
NEW
511
                                l.Errorf("failed to unmarshal check-in time: %s", err.Error())
×
NEW
512
                                continue
×
513
                        }
514
                }
515
        }
516

517
        return checkInTimes, nil
2✔
518
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc