• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mendersoftware / deviceauth / 834337996

pending completion
834337996

push

gitlab-ci

GitHub
Merge pull request #636 from alfrunes/MEN-6399

11 of 11 new or added lines in 1 file covered. (100.0%)

20 existing lines in 1 file now uncovered.

4531 of 5374 relevant lines covered (84.31%)

47.21 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.24
/cache/cache.go
1
// Copyright 2023 Northern.tech AS
2
//
3
//    Licensed under the Apache License, Version 2.0 (the "License");
4
//    you may not use this file except in compliance with the License.
5
//    You may obtain a copy of the License at
6
//
7
//        http://www.apache.org/licenses/LICENSE-2.0
8
//
9
//    Unless required by applicable law or agreed to in writing, software
10
//    distributed under the License is distributed on an "AS IS" BASIS,
11
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
//    See the License for the specific language governing permissions and
13
//    limitations under the License.
14

15
// Package cache introduces API throttling based
16
// on redis, and functions for auth token management.
17
//
18
// Throttling mechanisms
19
//
20
// 1. Quota enforcement
21
//
22
// Based on https://redislabs.com/redis-best-practices/basic-rate-limiting/, but with a flexible
23
// interval (ratelimits.ApiQuota.IntervalSec).
24
// Current usage for a device lives under key:
25
//
26
// `tenant:<tid>:<idtype>:<id>:quota:<interval_num>: <num_reqs>`
27
//
28
// expiring in the defined time window.
29
//
30
// 2. Burst control
31
//
32
// Implemented with a simple single key:
33
//
34
// `tenant:<tid>:<idtype>:<id>:burst:<url>:<action>:<interval_num>: <last_req_ts>`
35
//
36
// expiring in ratelimits.ApiBurst.MinIntervalSec.
37
// The value is not really important, just the existence of the key
38
// means the burst was exceeded.
39
//
40
// Token Management
41
//
42
// Tokens are expected at:
43
// `tenant:<tid>:<idtype>:<id>:tok: <token>`
44
//
45

46
package cache
47

48
import (
49
        "context"
50
        "encoding/json"
51
        "fmt"
52
        "net"
53
        "strconv"
54
        "time"
55

56
        "github.com/pkg/errors"
57

58
        "github.com/go-redis/redis/v8"
59
        "github.com/mendersoftware/go-lib-micro/ratelimits"
60

61
        "github.com/mendersoftware/deviceauth/utils"
62
)
63

64
// Identity types used as the `<idtype>` segment of every cache key.
const (
	// IdTypeDevice marks keys that belong to a device.
	IdTypeDevice = "device"
	// IdTypeUser marks keys that belong to a user.
	IdTypeUser = "user"
)
68

69
// Sentinel errors returned by this package.
var (
	// ErrNoPositiveInteger is reported by NewRedisCache for values that
	// must be strictly greater than zero.
	ErrNoPositiveInteger = errors.New("must be a positive integer")
	// ErrNegativeInteger is reported by NewRedisCache for values that
	// must not be below zero.
	ErrNegativeInteger = errors.New("cannot be a negative integer")

	// ErrTooManyRequests is returned by Throttle when a quota or burst
	// limit has been exceeded.
	ErrTooManyRequests = errors.New("too many requests")
)
75

76
//go:generate ../utils/mockgen.sh
77
type Cache interface {
78
        // Throttle applies desired api limits and retrieves a cached token.
79
        // These ops are bundled because the implementation will pipeline them for a single network
80
        // roundtrip for max performance.
81
        // Returns:
82
        // - the token (if any)
83
        // - potentially ErrTooManyRequests (other errors: internal)
84
        Throttle(
85
                ctx context.Context,
86
                rawToken string,
87
                l ratelimits.ApiLimits,
88
                tid,
89
                id,
90
                idtype,
91
                url,
92
                action string,
93
        ) (string, error)
94

95
        // CacheToken caches the token under designated key, with expiration
96
        CacheToken(ctx context.Context, tid, id, idtype, token string, expireSec time.Duration) error
97

98
        // DeleteToken deletes the token for 'id'
99
        DeleteToken(ctx context.Context, tid, id, idtype string) error
100

101
        // GetLimits fetches limits for 'id'
102
        GetLimits(ctx context.Context, tid, id, idtype string) (*ratelimits.ApiLimits, error)
103

104
        // CacheLimits saves limits for 'id'
105
        CacheLimits(ctx context.Context, l ratelimits.ApiLimits, tid, id, idtype string) error
106

107
        // FlushDB clears the whole db asynchronously (FLUSHDB ASYNC)
108
        // TODO: replace with more fine grained key removal (per tenant)
109
        FlushDB(ctx context.Context) error
110
}
111

112
type RedisCache struct {
113
        c               *redis.Client
114
        LimitsExpireSec int
115
        clock           utils.Clock
116
}
117

118
func NewRedisCache(
119
        addr,
120
        user,
121
        pass string,
122
        db int,
123
        timeoutSec,
124
        limitsExpireSec int,
125
) (*RedisCache, error) {
24✔
126
        _, _, err := net.SplitHostPort(addr)
24✔
127
        if err != nil {
26✔
128
                return nil, errors.WithMessage(err, "redis: invalid address")
2✔
129
        } else if db < 0 {
26✔
130
                return nil, errors.WithMessage(ErrNegativeInteger, "redis: database")
2✔
131
        } else if timeoutSec <= 0 {
24✔
132
                return nil, errors.WithMessage(ErrNoPositiveInteger, "redis: timeout seconds")
2✔
133
        } else if limitsExpireSec <= 0 {
22✔
134
                return nil, errors.WithMessage(ErrNoPositiveInteger, "redis: limit expire seconds")
2✔
135
        }
2✔
136
        c := redis.NewClient(&redis.Options{
16✔
137
                Addr:     addr,
16✔
138
                Username: user,
16✔
139
                Password: pass,
16✔
140
                DB:       db,
16✔
141
        }).WithTimeout(time.Duration(timeoutSec) * time.Second)
16✔
142
        return &RedisCache{
16✔
143
                c:               c,
16✔
144
                LimitsExpireSec: limitsExpireSec,
16✔
145
                clock:           utils.NewClock(),
16✔
146
        }, err
16✔
147
}
148

149
func (rl *RedisCache) WithClock(c utils.Clock) *RedisCache {
4✔
150
        rl.clock = c
4✔
151
        return rl
4✔
152
}
4✔
153

154
func (rl *RedisCache) Throttle(
155
        ctx context.Context,
156
        rawToken string,
157
        l ratelimits.ApiLimits,
158
        tid,
159
        id,
160
        idtype,
161
        url,
162
        action string,
163
) (string, error) {
250✔
164
        now := rl.clock.Now().Unix()
250✔
165

250✔
166
        var tokenGet *redis.StringCmd
250✔
167
        var quotaInc *redis.IntCmd
250✔
168
        var quotaExp *redis.BoolCmd
250✔
169
        var burstGet *redis.StringCmd
250✔
170
        var burstSet *redis.StatusCmd
250✔
171

250✔
172
        pipe := rl.c.TxPipeline()
250✔
173

250✔
174
        // queue quota/burst control and token fetching
250✔
175
        // for piped execution
250✔
176
        quotaInc, quotaExp = rl.pipeQuota(ctx, pipe, l, tid, id, idtype, now)
250✔
177
        tokenGet = rl.pipeToken(ctx, pipe, tid, id, idtype)
250✔
178

250✔
179
        burstGet, burstSet = rl.pipeBurst(ctx,
250✔
180
                pipe,
250✔
181
                l,
250✔
182
                tid, id, idtype,
250✔
183
                url, action,
250✔
184
                now)
250✔
185

250✔
186
        _, err := pipe.Exec(ctx)
250✔
187
        if err != nil && !isErrRedisNil(err) {
250✔
UNCOV
188
                return "", err
×
UNCOV
189
        }
×
190

191
        // collect quota/burst control and token fetch results
192
        tok, err := rl.checkToken(tokenGet, rawToken)
250✔
193
        if err != nil {
250✔
UNCOV
194
                return "", err
×
UNCOV
195
        }
×
196

197
        err = rl.checkQuota(l, quotaInc, quotaExp)
250✔
198
        if err != nil {
290✔
199
                return "", err
40✔
200
        }
40✔
201

202
        err = rl.checkBurst(burstGet, burstSet)
210✔
203
        if err != nil {
276✔
204
                return "", err
66✔
205
        }
66✔
206

207
        return tok, nil
144✔
208
}
209

210
func (rl *RedisCache) pipeToken(
211
        ctx context.Context,
212
        pipe redis.Pipeliner,
213
        tid,
214
        id,
215
        idtype string,
216
) *redis.StringCmd {
250✔
217
        key := KeyToken(tid, id, idtype)
250✔
218
        return pipe.Get(ctx, key)
250✔
219
}
250✔
220

221
func (rl *RedisCache) checkToken(cmd *redis.StringCmd, raw string) (string, error) {
250✔
222
        err := cmd.Err()
250✔
223

250✔
224
        if err != nil {
490✔
225
                if isErrRedisNil(err) {
480✔
226
                        return "", nil
240✔
227
                }
240✔
UNCOV
228
                return "", err
×
229
        }
230

231
        token := cmd.Val()
10✔
232
        if token == raw {
18✔
233
                return token, nil
8✔
234
        } else {
10✔
235
                // must be a stale token - we don't want to use it
2✔
236
                // let it expire in the background
2✔
237
                return "", nil
2✔
238
        }
2✔
239
}
240

241
func (rl *RedisCache) pipeQuota(
242
        ctx context.Context,
243
        pipe redis.Pipeliner,
244
        l ratelimits.ApiLimits,
245
        tid,
246
        id,
247
        idtype string,
248
        now int64,
249
) (*redis.IntCmd, *redis.BoolCmd) {
250✔
250
        var incr *redis.IntCmd
250✔
251
        var expire *redis.BoolCmd
250✔
252

250✔
253
        // not a default/empty quota
250✔
254
        if l.ApiQuota.MaxCalls != 0 {
398✔
255
                intvl := int64(now / int64(l.ApiQuota.IntervalSec))
148✔
256
                keyQuota := KeyQuota(tid, id, idtype, strconv.FormatInt(intvl, 10))
148✔
257
                incr = pipe.Incr(ctx, keyQuota)
148✔
258
                expire = pipe.Expire(ctx, keyQuota, time.Duration(l.ApiQuota.IntervalSec)*time.Second)
148✔
259
        }
148✔
260

261
        return incr, expire
250✔
262
}
263

264
func (rl *RedisCache) checkQuota(
265
        l ratelimits.ApiLimits,
266
        incr *redis.IntCmd,
267
        expire *redis.BoolCmd,
268
) error {
250✔
269
        if incr == nil && expire == nil {
352✔
270
                return nil
102✔
271
        }
102✔
272

273
        err := incr.Err()
148✔
274
        if err != nil && !isErrRedisNil(err) {
148✔
UNCOV
275
                return err
×
UNCOV
276
        }
×
277

278
        err = expire.Err()
148✔
279
        if err != nil {
148✔
UNCOV
280
                return err
×
UNCOV
281
        }
×
282

283
        quota := incr.Val()
148✔
284
        if quota > int64(l.ApiQuota.MaxCalls) {
188✔
285
                return ErrTooManyRequests
40✔
286
        }
40✔
287

288
        return nil
108✔
289
}
290

291
func (rl *RedisCache) pipeBurst(ctx context.Context,
292
        pipe redis.Pipeliner,
293
        l ratelimits.ApiLimits,
294
        tid, id, idtype, url, action string,
295
        now int64) (*redis.StringCmd, *redis.StatusCmd) {
250✔
296
        var get *redis.StringCmd
250✔
297
        var set *redis.StatusCmd
250✔
298

250✔
299
        for _, b := range l.ApiBursts {
418✔
300
                if b.Action == action &&
168✔
301
                        b.Uri == url &&
168✔
302
                        b.MinIntervalSec != 0 {
306✔
303

138✔
304
                        intvl := int64(now / int64(b.MinIntervalSec))
138✔
305
                        keyBurst := KeyBurst(tid, id, idtype, url, action, strconv.FormatInt(intvl, 10))
138✔
306

138✔
307
                        get = pipe.Get(ctx, keyBurst)
138✔
308
                        set = pipe.Set(ctx, keyBurst, now, time.Duration(b.MinIntervalSec)*time.Second)
138✔
309
                }
138✔
310
        }
311

312
        return get, set
250✔
313
}
314

315
func (rl *RedisCache) checkBurst(get *redis.StringCmd, set *redis.StatusCmd) error {
210✔
316
        if get != nil && set != nil {
338✔
317
                err := get.Err()
128✔
318

128✔
319
                // no error means burst was found/hit
128✔
320
                if err == nil {
194✔
321
                        return ErrTooManyRequests
66✔
322
                }
66✔
323

324
                if isErrRedisNil(err) {
124✔
325
                        return nil
62✔
326
                }
62✔
327

UNCOV
328
                return err
×
329
        }
330

331
        return nil
82✔
332
}
333

334
func (rl *RedisCache) CacheToken(
335
        ctx context.Context,
336
        tid,
337
        id,
338
        idtype,
339
        token string,
340
        expire time.Duration,
341
) error {
10✔
342
        res := rl.c.Set(ctx, KeyToken(tid, id, idtype),
10✔
343
                token,
10✔
344
                expire)
10✔
345
        return res.Err()
10✔
346
}
10✔
347

348
func (rl *RedisCache) DeleteToken(ctx context.Context, tid, id, idtype string) error {
6✔
349
        res := rl.c.Del(ctx, KeyToken(tid, id, idtype))
6✔
350
        return res.Err()
6✔
351
}
6✔
352

353
func (rl *RedisCache) GetLimits(
354
        ctx context.Context,
355
        tid,
356
        id,
357
        idtype string,
358
) (*ratelimits.ApiLimits, error) {
6✔
359
        res := rl.c.Get(ctx, KeyLimits(tid, id, idtype))
6✔
360

6✔
361
        if res.Err() != nil {
10✔
362
                if isErrRedisNil(res.Err()) {
8✔
363
                        return nil, nil
4✔
364
                }
4✔
UNCOV
365
                return nil, res.Err()
×
366
        }
367

368
        var limits ratelimits.ApiLimits
2✔
369

2✔
370
        err := json.Unmarshal([]byte(res.Val()), &limits)
2✔
371
        if err != nil {
2✔
UNCOV
372
                return nil, err
×
UNCOV
373
        }
×
374

375
        return &limits, nil
2✔
376
}
377

378
func (rl *RedisCache) CacheLimits(
379
        ctx context.Context,
380
        l ratelimits.ApiLimits,
381
        tid,
382
        id,
383
        idtype string,
384
) error {
2✔
385
        enc, err := json.Marshal(l)
2✔
386
        if err != nil {
2✔
UNCOV
387
                return err
×
UNCOV
388
        }
×
389

390
        res := rl.c.Set(
2✔
391
                ctx,
2✔
392
                KeyLimits(tid, id, idtype),
2✔
393
                enc,
2✔
394
                time.Duration(rl.LimitsExpireSec)*time.Second,
2✔
395
        )
2✔
396

2✔
397
        return res.Err()
2✔
398
}
399

400
func (rl *RedisCache) FlushDB(ctx context.Context) error {
2✔
401
        return rl.c.FlushDBAsync(ctx).Err()
2✔
402
}
2✔
403

404
// KeyQuota returns the key of the quota counter for the given tenant,
// identity type, identity and interval number:
// "tenant:<tid>:<idtype>:<id>:quota:<intvlNum>".
func KeyQuota(tid, id, idtype, intvlNum string) string {
	const layout = "tenant:%s:%s:%s:quota:%s"
	return fmt.Sprintf(layout, tid, idtype, id, intvlNum)
}
407

408
// KeyBurst returns the key of the burst marker for the given tenant,
// identity type, identity, URL, action and interval number:
// "tenant:<tid>:<idtype>:<id>:burst:<url>:<action>:<intvlNum>".
func KeyBurst(tid, id, idtype, url, action, intvlNum string) string {
	const layout = "tenant:%s:%s:%s:burst:%s:%s:%s"
	return fmt.Sprintf(layout, tid, idtype, id, url, action, intvlNum)
}
411

412
// KeyToken returns the key of the cached token for the given tenant,
// identity type and identity: "tenant:<tid>:<idtype>:<id>:tok".
func KeyToken(tid, id, idtype string) string {
	const layout = "tenant:%s:%s:%s:tok"
	return fmt.Sprintf(layout, tid, idtype, id)
}
415

416
// KeyLimits returns the key of the cached limits for the given tenant,
// identity type and identity: "tenant:<tid>:<idtype>:<id>:limits".
func KeyLimits(tid, id, idtype string) string {
	const layout = "tenant:%s:%s:%s:limits"
	return fmt.Sprintf(layout, tid, idtype, id)
}
419

420
// isErrRedisNil reports whether e is the very common non-error
// "redis: nil", which just means the key was not found. It is routinely
// returned e.g. from GET, or from pipelines containing it.
//
// The match is made on the exact error message. A nil error is never
// "redis: nil", so it yields false instead of panicking on e.Error().
func isErrRedisNil(e error) bool {
	if e == nil {
		return false
	}
	return e.Error() == "redis: nil"
}
426

427
// TODO: move to go-lib-micro/ratelimits
UNCOV
428
func LimitsEmpty(l *ratelimits.ApiLimits) bool {
×
UNCOV
429
        return l.ApiQuota.MaxCalls == 0 &&
×
UNCOV
430
                l.ApiQuota.IntervalSec == 0 &&
×
UNCOV
431
                len(l.ApiBursts) == 0
×
UNCOV
432
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc