• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

divviup / divviup-api / 8666760397

12 Apr 2024 07:02PM UTC coverage: 56.289% (+0.2%) from 56.083%
8666760397

Pull #968

github

web-flow
Merge 8c0857084 into a6cdbab81
Pull Request #968: Support for time bucketed fixed size

58 of 86 new or added lines in 8 files covered. (67.44%)

6 existing lines in 5 files now uncovered.

3692 of 6559 relevant lines covered (56.29%)

102.51 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.26
/src/entity/task/new_task.rs
1
use super::*;
2
use crate::{
3
    clients::aggregator_client::api_types::{AggregatorVdaf, QueryType},
4
    entity::{
5
        aggregator::{Feature, Role},
6
        Account, CollectorCredential, Protocol,
7
    },
8
    handler::Error,
9
};
10
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
11
use rand::Rng;
12
use sea_orm::{ColumnTrait, QueryFilter};
13
use sha2::{Digest, Sha256};
14
use validator::{ValidationErrors, ValidationErrorsKind};
15

16
#[derive(Deserialize, Validate, Debug, Clone, Default)]
220✔
17
pub struct NewTask {
18
    #[validate(required, length(min = 1))]
19
    pub name: Option<String>,
20

21
    #[validate(required)]
22
    pub leader_aggregator_id: Option<String>,
23

24
    #[validate(required)]
25
    pub helper_aggregator_id: Option<String>,
26

27
    #[validate(required, nested)]
28
    pub vdaf: Option<Vdaf>,
29

30
    #[validate(required, range(min = 100))]
31
    pub min_batch_size: Option<u64>,
32

33
    #[validate(range(min = 0))]
34
    pub max_batch_size: Option<u64>,
35

36
    #[validate(range(min = 0))]
37
    pub batch_time_window_size_seconds: Option<u64>,
38

39
    #[validate(
40
        required,
41
        range(
42
            min = 60,
43
            max = 2592000,
44
            message = "must be between 1 minute and 4 weeks"
45
        )
46
    )]
47
    pub time_precision_seconds: Option<u64>,
48

49
    #[validate(required)]
50
    pub collector_credential_id: Option<String>,
51
}
52

53
async fn load_aggregator(
40✔
54
    account: &Account,
40✔
55
    id: Option<&str>,
40✔
56
    db: &impl ConnectionTrait,
40✔
57
) -> Result<Option<Aggregator>, Error> {
40✔
58
    let Some(id) = id.map(Uuid::parse_str).transpose()? else {
40✔
59
        return Ok(None);
6✔
60
    };
61

62
    let aggregator = Aggregators::find_by_id(id)
34✔
63
        .filter(AggregatorColumn::DeletedAt.is_null())
34✔
64
        .one(db)
34✔
65
        .await?;
72✔
66

67
    let Some(aggregator) = aggregator else {
34✔
68
        return Ok(None);
2✔
69
    };
70

71
    if aggregator.account_id.is_none() || aggregator.account_id == Some(account.id) {
32✔
72
        Ok(Some(aggregator))
32✔
73
    } else {
74
        Ok(None)
×
75
    }
76
}
40✔
77

78
const VDAF_BYTES: usize = 16;
79
fn generate_vdaf_verify_key_and_expected_task_id() -> (String, String) {
7✔
80
    let mut verify_key = [0; VDAF_BYTES];
7✔
81
    rand::thread_rng().fill(&mut verify_key);
7✔
82
    (
7✔
83
        URL_SAFE_NO_PAD.encode(verify_key),
7✔
84
        URL_SAFE_NO_PAD.encode(Sha256::digest(verify_key)),
7✔
85
    )
7✔
86
}
7✔
87

88
impl NewTask {
89
    fn validate_min_lte_max(&self, errors: &mut ValidationErrors) {
20✔
90
        let min = self.min_batch_size;
20✔
91
        let max = self.max_batch_size;
20✔
92
        if matches!((min, max), (Some(min), Some(max)) if min > max) {
20✔
93
            let error = ValidationError::new("min_greater_than_max");
2✔
94
            errors.add("min_batch_size", error.clone());
2✔
95
            errors.add("max_batch_size", error);
2✔
96
        }
18✔
97
    }
20✔
98

99
    fn validate_batch_time_window_size(&self, errors: &mut ValidationErrors) {
20✔
100
        let window = self.batch_time_window_size_seconds;
20✔
101
        if let Some(window) = window {
20✔
102
            if self.max_batch_size.is_none() {
5✔
103
                errors.add(
1✔
104
                    "batch_time_window_size_seconds",
1✔
105
                    ValidationError::new("missing-max-batch-size"),
1✔
106
                );
1✔
107
            }
4✔
108
            if let Some(precision) = self.time_precision_seconds {
5✔
109
                if window % precision != 0 {
5✔
110
                    errors.add(
2✔
111
                        "batch_time_window_size_seconds",
2✔
112
                        ValidationError::new("not-multiple-of-time-precision"),
2✔
113
                    );
2✔
114
                }
3✔
NEW
115
            }
×
116
        }
15✔
117
    }
20✔
118

119
    async fn load_collector_credential(
20✔
120
        &self,
20✔
121
        account: &Account,
20✔
122
        db: &impl ConnectionTrait,
20✔
123
    ) -> Option<CollectorCredential> {
20✔
124
        let id = Uuid::parse_str(self.collector_credential_id.as_deref()?).ok()?;
20✔
125
        CollectorCredentials::find_by_id(id)
9✔
126
            .filter(CollectorCredentialColumn::AccountId.eq(account.id))
9✔
127
            .one(db)
9✔
128
            .await
16✔
129
            .ok()
9✔
130
            .flatten()
9✔
131
    }
20✔
132

133
    async fn validate_collector_credential(
20✔
134
        &self,
20✔
135
        account: &Account,
20✔
136
        leader: Option<&Aggregator>,
20✔
137
        db: &impl ConnectionTrait,
20✔
138
        errors: &mut ValidationErrors,
20✔
139
    ) -> Option<CollectorCredential> {
20✔
140
        match self.load_collector_credential(account, db).await {
20✔
141
            None => {
142
                errors.add("collector_credential_id", ValidationError::new("required"));
11✔
143
                None
11✔
144
            }
145

146
            Some(collector_credential) => {
9✔
147
                let leader_needs_token_hash =
9✔
148
                    leader.map_or(false, |leader| leader.features.token_hash_enabled());
9✔
149

9✔
150
                if leader_needs_token_hash && collector_credential.token_hash.is_none() {
9✔
151
                    errors.add(
×
152
                        "collector_credential_id",
×
153
                        ValidationError::new("missing-token-hash"),
×
154
                    );
×
155
                    None
×
156
                } else {
157
                    Some(collector_credential)
9✔
158
                }
159
            }
160
        }
161
    }
20✔
162

163
    async fn validate_aggregators(
20✔
164
        &self,
20✔
165
        account: &Account,
20✔
166
        db: &impl ConnectionTrait,
20✔
167
        errors: &mut ValidationErrors,
20✔
168
    ) -> Option<(Aggregator, Aggregator, Protocol)> {
20✔
169
        let leader = load_aggregator(account, self.leader_aggregator_id.as_deref(), db)
20✔
170
            .await
36✔
171
            .ok()
20✔
172
            .flatten();
20✔
173
        if leader.is_none() {
20✔
174
            errors.add("leader_aggregator_id", ValidationError::new("required"));
4✔
175
        }
16✔
176

177
        let helper = load_aggregator(account, self.helper_aggregator_id.as_deref(), db)
20✔
178
            .await
36✔
179
            .ok()
20✔
180
            .flatten();
20✔
181
        if helper.is_none() {
20✔
182
            errors.add("helper_aggregator_id", ValidationError::new("required"));
4✔
183
        }
16✔
184

185
        let (Some(leader), Some(helper)) = (leader, helper) else {
20✔
186
            return None;
5✔
187
        };
188

189
        if leader == helper {
15✔
190
            errors.add("leader_aggregator_id", ValidationError::new("same"));
×
191
            errors.add("helper_aggregator_id", ValidationError::new("same"));
×
192
        }
15✔
193

194
        if !leader.is_first_party && !helper.is_first_party {
15✔
195
            errors.add(
×
196
                "leader_aggregator_id",
×
197
                ValidationError::new("no-first-party"),
×
198
            );
×
199
            errors.add(
×
200
                "helper_aggregator_id",
×
201
                ValidationError::new("no-first-party"),
×
202
            );
×
203
        }
15✔
204

205
        let resolved_protocol = if leader.protocol == helper.protocol {
15✔
206
            leader.protocol
15✔
207
        } else {
208
            errors.add("leader_aggregator_id", ValidationError::new("protocol"));
×
209
            errors.add("helper_aggregator_id", ValidationError::new("protocol"));
×
210
            return None;
×
211
        };
212

213
        if leader.role == Role::Helper {
15✔
214
            errors.add("leader_aggregator_id", ValidationError::new("role"))
1✔
215
        }
14✔
216

217
        if helper.role == Role::Leader {
15✔
218
            errors.add("helper_aggregator_id", ValidationError::new("role"))
1✔
219
        }
14✔
220

221
        if self.batch_time_window_size_seconds.is_some()
15✔
222
            && !leader.features.contains(&Feature::TimeBucketedFixedSize)
5✔
223
        {
224
            errors.add(
1✔
225
                "leader_aggregator_id",
1✔
226
                ValidationError::new("time-bucketed-fixed-size-unsupported"),
1✔
227
            )
1✔
228
        }
14✔
229

230
        if errors.is_empty() {
15✔
231
            Some((leader, helper, resolved_protocol))
7✔
232
        } else {
233
            None
8✔
234
        }
235
    }
20✔
236

237
    fn validate_vdaf_is_supported(
7✔
238
        &self,
7✔
239
        leader: &Aggregator,
7✔
240
        helper: &Aggregator,
7✔
241
        protocol: &Protocol,
7✔
242
        errors: &mut ValidationErrors,
7✔
243
    ) -> Option<AggregatorVdaf> {
7✔
244
        let Some(vdaf) = self.vdaf.as_ref() else {
7✔
245
            return None;
×
246
        };
247

248
        let name = vdaf.name();
7✔
249
        let aggregator_vdaf = match vdaf.representation_for_protocol(protocol) {
7✔
250
            Ok(vdaf) => vdaf,
7✔
251
            Err(e) => {
×
252
                let errors = errors.errors_mut().entry("vdaf").or_insert_with(|| {
×
253
                    ValidationErrorsKind::Struct(Box::new(ValidationErrors::new()))
×
254
                });
×
255
                match errors {
×
256
                    ValidationErrorsKind::Struct(errors) => {
×
257
                        errors.errors_mut().extend(e.into_errors())
×
258
                    }
259
                    other => *other = ValidationErrorsKind::Struct(Box::new(e)),
×
260
                };
261
                return None;
×
262
            }
263
        };
264

265
        if !leader.vdafs.contains(&name) || !helper.vdafs.contains(&name) {
7✔
266
            let errors = errors
×
267
                .errors_mut()
×
268
                .entry("vdaf")
×
269
                .or_insert_with(|| ValidationErrorsKind::Struct(Box::new(ValidationErrors::new())));
×
270
            match errors {
×
271
                ValidationErrorsKind::Struct(errors) => {
×
272
                    errors.add("type", ValidationError::new("not-supported"));
×
273
                }
×
274
                other => {
×
275
                    let mut e = ValidationErrors::new();
×
276
                    e.add("type", ValidationError::new("not-supported"));
×
277
                    *other = ValidationErrorsKind::Struct(Box::new(e));
×
278
                }
×
279
            };
280
        }
7✔
281

282
        Some(aggregator_vdaf)
7✔
283
    }
7✔
284

285
    fn populate_chunk_length(&mut self, protocol: &Protocol) {
7✔
286
        if let Some(vdaf) = &mut self.vdaf {
7✔
287
            vdaf.populate_chunk_length(protocol);
7✔
288
        }
7✔
289
    }
7✔
290

291
    fn validate_query_type_is_supported(
7✔
292
        &self,
7✔
293
        leader: &Aggregator,
7✔
294
        helper: &Aggregator,
7✔
295
        errors: &mut ValidationErrors,
7✔
296
    ) {
7✔
297
        let name = QueryType::from(self.clone()).name();
7✔
298
        if !leader.query_types.contains(&name) || !helper.query_types.contains(&name) {
7✔
299
            errors.add("max_batch_size", ValidationError::new("not-supported"));
×
300
        }
7✔
301
    }
7✔
302

303
    pub async fn normalize_and_validate(
20✔
304
        &mut self,
20✔
305
        account: Account,
20✔
306
        db: &impl ConnectionTrait,
20✔
307
    ) -> Result<ProvisionableTask, ValidationErrors> {
20✔
308
        let mut errors = Validate::validate(self).err().unwrap_or_default();
20✔
309
        self.validate_min_lte_max(&mut errors);
20✔
310
        self.validate_batch_time_window_size(&mut errors);
20✔
311
        let aggregators = self.validate_aggregators(&account, db, &mut errors).await;
72✔
312
        let collector_credential = self
20✔
313
            .validate_collector_credential(
20✔
314
                &account,
20✔
315
                aggregators.as_ref().map(|(leader, ..)| leader),
20✔
316
                db,
20✔
317
                &mut errors,
20✔
318
            )
20✔
319
            .await;
16✔
320

321
        let aggregator_vdaf = if let Some((leader, helper, protocol)) = aggregators.as_ref() {
20✔
322
            self.validate_query_type_is_supported(leader, helper, &mut errors);
7✔
323
            self.populate_chunk_length(protocol);
7✔
324
            self.validate_vdaf_is_supported(leader, helper, protocol, &mut errors)
7✔
325
        } else {
326
            None
13✔
327
        };
328

329
        if errors.is_empty() {
20✔
330
            // Unwrap safety: All of these unwraps below have previously
331
            // been checked by the above validations. The fact that we
332
            // have to check them twice is a consequence of the
333
            // disharmonious combination of Validate and the fact that we
334
            // need to use options for all fields so serde doesn't bail on
335
            // the first error.
336
            let (leader_aggregator, helper_aggregator, protocol) = aggregators.unwrap();
7✔
337

7✔
338
            let (vdaf_verify_key, id) = generate_vdaf_verify_key_and_expected_task_id();
7✔
339

7✔
340
            Ok(ProvisionableTask {
7✔
341
                account,
7✔
342
                id,
7✔
343
                vdaf_verify_key,
7✔
344
                name: self.name.clone().unwrap(),
7✔
345
                leader_aggregator,
7✔
346
                helper_aggregator,
7✔
347
                vdaf: self.vdaf.clone().unwrap(),
7✔
348
                aggregator_vdaf: aggregator_vdaf.unwrap(),
7✔
349
                min_batch_size: self.min_batch_size.unwrap(),
7✔
350
                max_batch_size: self.max_batch_size,
7✔
351
                batch_time_window_size_seconds: self.batch_time_window_size_seconds,
7✔
352
                expiration: Some(OffsetDateTime::now_utc() + DEFAULT_EXPIRATION_DURATION),
7✔
353
                time_precision_seconds: self.time_precision_seconds.unwrap(),
7✔
354
                collector_credential: collector_credential.unwrap(),
7✔
355
                aggregator_auth_token: None,
7✔
356
                protocol,
7✔
357
            })
7✔
358
        } else {
359
            Err(errors)
13✔
360
        }
361
    }
20✔
362
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc