• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mendersoftware / deviceauth / 826658230

pending completion
826658230

Pull #638

gitlab-ci

Peter Grzybowski
chore: moving to single db
Pull Request #638: chore: moving to single db

334 of 405 new or added lines in 5 files covered. (82.47%)

38 existing lines in 3 files now uncovered.

4669 of 5588 relevant lines covered (83.55%)

75.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.0
/cache/cache.go
1
// Copyright 2021 Northern.tech AS
2
//
3
//    Licensed under the Apache License, Version 2.0 (the "License");
4
//    you may not use this file except in compliance with the License.
5
//    You may obtain a copy of the License at
6
//
7
//        http://www.apache.org/licenses/LICENSE-2.0
8
//
9
//    Unless required by applicable law or agreed to in writing, software
10
//    distributed under the License is distributed on an "AS IS" BASIS,
11
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
//    See the License for the specific language governing permissions and
13
//    limitations under the License.
14

15
// Package cache introduces API throttling based
16
// on redis, and functions for auth token management.
17
//
18
// Throttling mechanisms
19
//
20
// 1. Quota enforcement
21
//
22
// Based on https://redislabs.com/redis-best-practices/basic-rate-limiting/, but with a flexible
23
// interval (ratelimits.ApiQuota.IntervalSec).
24
// Current usage for a device lives under key:
25
//
26
// `tenant:<tid>:device:<did>:quota:<interval_num>: <num_reqs>`
27
//
28
// expiring in the defined time window.
29
//
30
// 2. Burst control
31
//
32
// Implemented with a simple single key:
33
//
34
// `tenant:<tid>:device:<did>:burst:<url>:<action>:<interval_num>: <last_req_ts>`
35
//
36
// expiring in ratelimits.ApiBurst.MinIntervalSec.
37
// The value is not really important, just the existence of the key
38
// means the burst was exceeded.
39
//
40
// Token Management
41
//
42
// Tokens are expected at:
43
// `tenant:<tid>:device:<did>:tok: <token>`
44
//
45

46
package cache
47

48
import (
49
        "context"
50
        "encoding/json"
51
        "fmt"
52
        "strconv"
53
        "time"
54

55
        "github.com/pkg/errors"
56

57
        "github.com/go-redis/redis/v8"
58
        "github.com/mendersoftware/go-lib-micro/ratelimits"
59

60
        "github.com/mendersoftware/deviceauth/utils"
61
)
62

63
// Identity types intended for the idtype argument of Cache methods; the
// chosen value becomes one segment of every cache key built by the Key*
// helpers.
const (
	// IdTypeDevice tags keys that belong to a device.
	IdTypeDevice = "device"
	// IdTypeUser tags keys that belong to a user.
	IdTypeUser = "user"
)
67

68
var (
69
        ErrTooManyRequests = errors.New("too many requests")
70
)
71

72
//go:generate ../utils/mockgen.sh
73
type Cache interface {
74
        // Throttle applies desired api limits and retrieves a cached token.
75
        // These ops are bundled because the implementation will pipeline them for a single network
76
        // roundtrip for max performance.
77
        // Returns:
78
        // - the token (if any)
79
        // - potentially ErrTooManyRequests (other errors: internal)
80
        Throttle(
81
                ctx context.Context,
82
                rawToken string,
83
                l ratelimits.ApiLimits,
84
                tid,
85
                id,
86
                idtype,
87
                url,
88
                action string,
89
        ) (string, error)
90

91
        // CacheToken caches the token under designated key, with expiration
92
        CacheToken(ctx context.Context, tid, id, idtype, token string, expireSec time.Duration) error
93

94
        // DeleteToken deletes the token for 'id'
95
        DeleteToken(ctx context.Context, tid, id, idtype string) error
96

97
        // GetLimits fetches limits for 'id'
98
        GetLimits(ctx context.Context, tid, id, idtype string) (*ratelimits.ApiLimits, error)
99

100
        // CacheLimits saves limits for 'id'
101
        CacheLimits(ctx context.Context, l ratelimits.ApiLimits, tid, id, idtype string) error
102

103
        // FlushDB clears the whole db asynchronously (FLUSHDB ASYNC)
104
        // TODO: replace with more fine grained key removal (per tenant)
105
        FlushDB(ctx context.Context) error
106
}
107

108
type RedisCache struct {
109
        c               *redis.Client
110
        LimitsExpireSec int
111
        clock           utils.Clock
112
}
113

114
func NewRedisCache(
115
        addr,
116
        user,
117
        pass string,
118
        db int,
119
        timeoutSec,
120
        limitsExpireSec int,
121
) (*RedisCache, error) {
14✔
122
        c := redis.NewClient(&redis.Options{
14✔
123
                Addr:     addr,
14✔
124
                Username: user,
14✔
125
                Password: pass,
14✔
126
                DB:       db,
14✔
127
        })
14✔
128

14✔
129
        c = c.WithTimeout(time.Duration(timeoutSec) * time.Second)
14✔
130

14✔
131
        _, err := c.Ping(context.TODO()).Result()
14✔
132
        return &RedisCache{
14✔
133
                c:               c,
14✔
134
                LimitsExpireSec: limitsExpireSec,
14✔
135
                clock:           utils.NewClock(),
14✔
136
        }, err
14✔
137
}
14✔
138

139
func (rl *RedisCache) WithClock(c utils.Clock) *RedisCache {
4✔
140
        rl.clock = c
4✔
141
        return rl
4✔
142
}
4✔
143

144
func (rl *RedisCache) Throttle(
145
        ctx context.Context,
146
        rawToken string,
147
        l ratelimits.ApiLimits,
148
        tid,
149
        id,
150
        idtype,
151
        url,
152
        action string,
153
) (string, error) {
250✔
154
        now := rl.clock.Now().Unix()
250✔
155

250✔
156
        var tokenGet *redis.StringCmd
250✔
157
        var quotaInc *redis.IntCmd
250✔
158
        var quotaExp *redis.BoolCmd
250✔
159
        var burstGet *redis.StringCmd
250✔
160
        var burstSet *redis.StatusCmd
250✔
161

250✔
162
        pipe := rl.c.TxPipeline()
250✔
163

250✔
164
        // queue quota/burst control and token fetching
250✔
165
        // for piped execution
250✔
166
        quotaInc, quotaExp = rl.pipeQuota(ctx, pipe, l, tid, id, idtype, now)
250✔
167
        tokenGet = rl.pipeToken(ctx, pipe, tid, id, idtype)
250✔
168

250✔
169
        burstGet, burstSet = rl.pipeBurst(ctx,
250✔
170
                pipe,
250✔
171
                l,
250✔
172
                tid, id, idtype,
250✔
173
                url, action,
250✔
174
                now)
250✔
175

250✔
176
        _, err := pipe.Exec(ctx)
250✔
177
        if err != nil && !isErrRedisNil(err) {
250✔
UNCOV
178
                return "", err
×
UNCOV
179
        }
×
180

181
        // collect quota/burst control and token fetch results
182
        tok, err := rl.checkToken(tokenGet, rawToken)
250✔
183
        if err != nil {
250✔
UNCOV
184
                return "", err
×
UNCOV
185
        }
×
186

187
        err = rl.checkQuota(l, quotaInc, quotaExp)
250✔
188
        if err != nil {
290✔
189
                return "", err
40✔
190
        }
40✔
191

192
        err = rl.checkBurst(burstGet, burstSet)
210✔
193
        if err != nil {
276✔
194
                return "", err
66✔
195
        }
66✔
196

197
        return tok, nil
144✔
198
}
199

200
func (rl *RedisCache) pipeToken(
201
        ctx context.Context,
202
        pipe redis.Pipeliner,
203
        tid,
204
        id,
205
        idtype string,
206
) *redis.StringCmd {
250✔
207
        key := KeyToken(tid, id, idtype)
250✔
208
        return pipe.Get(ctx, key)
250✔
209
}
250✔
210

211
func (rl *RedisCache) checkToken(cmd *redis.StringCmd, raw string) (string, error) {
250✔
212
        err := cmd.Err()
250✔
213

250✔
214
        if err != nil {
490✔
215
                if isErrRedisNil(err) {
480✔
216
                        return "", nil
240✔
217
                }
240✔
UNCOV
218
                return "", err
×
219
        }
220

221
        token := cmd.Val()
10✔
222
        if token == raw {
18✔
223
                return token, nil
8✔
224
        } else {
10✔
225
                // must be a stale token - we don't want to use it
2✔
226
                // let it expire in the background
2✔
227
                return "", nil
2✔
228
        }
2✔
229
}
230

231
func (rl *RedisCache) pipeQuota(
232
        ctx context.Context,
233
        pipe redis.Pipeliner,
234
        l ratelimits.ApiLimits,
235
        tid,
236
        id,
237
        idtype string,
238
        now int64,
239
) (*redis.IntCmd, *redis.BoolCmd) {
250✔
240
        var incr *redis.IntCmd
250✔
241
        var expire *redis.BoolCmd
250✔
242

250✔
243
        // not a default/empty quota
250✔
244
        if l.ApiQuota.MaxCalls != 0 {
398✔
245
                intvl := int64(now / int64(l.ApiQuota.IntervalSec))
148✔
246
                keyQuota := KeyQuota(tid, id, idtype, strconv.FormatInt(intvl, 10))
148✔
247
                incr = pipe.Incr(ctx, keyQuota)
148✔
248
                expire = pipe.Expire(ctx, keyQuota, time.Duration(l.ApiQuota.IntervalSec)*time.Second)
148✔
249
        }
148✔
250

251
        return incr, expire
250✔
252
}
253

254
func (rl *RedisCache) checkQuota(
255
        l ratelimits.ApiLimits,
256
        incr *redis.IntCmd,
257
        expire *redis.BoolCmd,
258
) error {
250✔
259
        if incr == nil && expire == nil {
352✔
260
                return nil
102✔
261
        }
102✔
262

263
        err := incr.Err()
148✔
264
        if err != nil && !isErrRedisNil(err) {
148✔
UNCOV
265
                return err
×
UNCOV
266
        }
×
267

268
        err = expire.Err()
148✔
269
        if err != nil {
148✔
UNCOV
270
                return err
×
UNCOV
271
        }
×
272

273
        quota := incr.Val()
148✔
274
        if quota > int64(l.ApiQuota.MaxCalls) {
188✔
275
                return ErrTooManyRequests
40✔
276
        }
40✔
277

278
        return nil
108✔
279
}
280

281
// pipeBurst queues burst-control commands on pipe. Each burst rule in l that
// matches action and url and has a non-zero MinIntervalSec gets a GET and a
// SET of that rule's key for the current interval bucket
// (now / MinIntervalSec). The SET stores now with a TTL of MinIntervalSec
// seconds.
//
// Only the commands of the last matching rule are returned, or (nil, nil) when
// no rule matched. Commands of earlier matches still run with the pipeline.
func (rl *RedisCache) pipeBurst(ctx context.Context,
	pipe redis.Pipeliner,
	l ratelimits.ApiLimits,
	tid, id, idtype, url, action string,
	now int64) (*redis.StringCmd, *redis.StatusCmd) {
	var (
		get *redis.StringCmd
		set *redis.StatusCmd
	)

	for _, rule := range l.ApiBursts {
		if rule.Action != action || rule.Uri != url || rule.MinIntervalSec == 0 {
			continue
		}

		bucket := strconv.FormatInt(now/int64(rule.MinIntervalSec), 10)
		key := KeyBurst(tid, id, idtype, url, action, bucket)
		ttl := time.Duration(rule.MinIntervalSec) * time.Second

		get = pipe.Get(ctx, key)
		set = pipe.Set(ctx, key, now, ttl)
	}

	return get, set
}
304

305
func (rl *RedisCache) checkBurst(get *redis.StringCmd, set *redis.StatusCmd) error {
210✔
306
        if get != nil && set != nil {
338✔
307
                err := get.Err()
128✔
308

128✔
309
                // no error means burst was found/hit
128✔
310
                if err == nil {
194✔
311
                        return ErrTooManyRequests
66✔
312
                }
66✔
313

314
                if isErrRedisNil(err) {
124✔
315
                        return nil
62✔
316
                }
62✔
317

UNCOV
318
                return err
×
319
        }
320

321
        return nil
82✔
322
}
323

324
func (rl *RedisCache) CacheToken(
325
        ctx context.Context,
326
        tid,
327
        id,
328
        idtype,
329
        token string,
330
        expire time.Duration,
331
) error {
10✔
332
        res := rl.c.Set(ctx, KeyToken(tid, id, idtype),
10✔
333
                token,
10✔
334
                expire)
10✔
335
        return res.Err()
10✔
336
}
10✔
337

338
func (rl *RedisCache) DeleteToken(ctx context.Context, tid, id, idtype string) error {
6✔
339
        res := rl.c.Del(ctx, KeyToken(tid, id, idtype))
6✔
340
        return res.Err()
6✔
341
}
6✔
342

343
func (rl *RedisCache) GetLimits(
344
        ctx context.Context,
345
        tid,
346
        id,
347
        idtype string,
348
) (*ratelimits.ApiLimits, error) {
6✔
349
        res := rl.c.Get(ctx, KeyLimits(tid, id, idtype))
6✔
350

6✔
351
        if res.Err() != nil {
10✔
352
                if isErrRedisNil(res.Err()) {
8✔
353
                        return nil, nil
4✔
354
                }
4✔
UNCOV
355
                return nil, res.Err()
×
356
        }
357

358
        var limits ratelimits.ApiLimits
2✔
359

2✔
360
        err := json.Unmarshal([]byte(res.Val()), &limits)
2✔
361
        if err != nil {
2✔
UNCOV
362
                return nil, err
×
UNCOV
363
        }
×
364

365
        return &limits, nil
2✔
366
}
367

368
func (rl *RedisCache) CacheLimits(
369
        ctx context.Context,
370
        l ratelimits.ApiLimits,
371
        tid,
372
        id,
373
        idtype string,
374
) error {
2✔
375
        enc, err := json.Marshal(l)
2✔
376
        if err != nil {
2✔
UNCOV
377
                return err
×
UNCOV
378
        }
×
379

380
        res := rl.c.Set(
2✔
381
                ctx,
2✔
382
                KeyLimits(tid, id, idtype),
2✔
383
                enc,
2✔
384
                time.Duration(rl.LimitsExpireSec)*time.Second,
2✔
385
        )
2✔
386

2✔
387
        return res.Err()
2✔
388
}
389

390
func (rl *RedisCache) FlushDB(ctx context.Context) error {
2✔
391
        return rl.c.FlushDBAsync(ctx).Err()
2✔
392
}
2✔
393

394
// KeyQuota returns the key counting requests of id (of type idtype) in tenant
// tid during quota interval number intvlNum:
// "tenant:<tid>:<idtype>:<id>:quota:<intvlNum>".
func KeyQuota(tid, id, idtype, intvlNum string) string {
	const format = "tenant:%s:%s:%s:quota:%s"
	return fmt.Sprintf(format, tid, idtype, id, intvlNum)
}
148✔
397

398
func KeyBurst(tid, id, idtype, url, action, intvlNum string) string {
138✔
399
        return fmt.Sprintf("tenant:%s:%s:%s:burst:%s:%s:%s", tid, idtype, id, url, action, intvlNum)
138✔
400
}
138✔
401

402
func KeyToken(tid, id, idtype string) string {
274✔
403
        return fmt.Sprintf("tenant:%s:%s:%s:tok", tid, idtype, id)
274✔
404
}
274✔
405

406
func KeyLimits(tid, id, idtype string) string {
8✔
407
        return fmt.Sprintf("tenant:%s:%s:%s:limits", tid, idtype, id)
8✔
408
}
8✔
409

410
// isErrRedisNil reports whether e is the very common non-error "redis: nil".
// That value just means a key was not found, which is normal; it is routinely
// returned e.g. by GET, or by pipelines containing such a GET.
// A nil e is not a "redis: nil" error and yields false instead of panicking
// on the nil interface.
func isErrRedisNil(e error) bool {
	return e != nil && e.Error() == "redis: nil"
}
546✔
416

417
// TODO: move to go-lib-micro/ratelimits
UNCOV
418
func LimitsEmpty(l *ratelimits.ApiLimits) bool {
×
UNCOV
419
        return l.ApiQuota.MaxCalls == 0 &&
×
UNCOV
420
                l.ApiQuota.IntervalSec == 0 &&
×
UNCOV
421
                len(l.ApiBursts) == 0
×
UNCOV
422
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc