mozilla / fx-private-relay · commit b25a085d-c11f-433e-91c3-49db81c1ae49 (push, via circleci, committed by web-flow)

13 May 2024 03:49PM UTC · coverage: 84.377% (+0.3%) from 84.07%

Merge pull request #4702 from mozilla/handle-broken-email-processing-mpp-3815
MPP-3815: Handle broken email processing

Branch coverage: 3611 of 4741 branches covered (76.17%); branch coverage is included in the aggregate %.
New code: 187 of 187 new or added lines in 4 files covered (100.0%).
Newly uncovered: 3 existing lines in 1 file.
Overall: 14779 of 17054 relevant lines covered (86.66%), 10.84 hits per line.
Source File (96.62% covered): /emails/management/commands/process_emails_from_sqs.py
"""
Process the SQS email queue.

The SQS queue is processed using the long poll method, which waits until at
least one message is available, or wait_seconds is reached.

See:
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-short-and-long-polling.html#sqs-long-polling
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs.html#SQS.Queue.receive_messages
"""

import gc
import json
import logging
import shlex
import time
from datetime import UTC, datetime
from multiprocessing import Pool
from typing import Any
from urllib.parse import urlsplit

from django import setup
from django.core.management.base import CommandError
from django.http import HttpResponse

import boto3
import OpenSSL
from botocore.exceptions import ClientError
from codetiming import Timer
from markus.utils import generate_tag
from mypy_boto3_sqs.service_resource import Message as SQSMessage
from mypy_boto3_sqs.service_resource import Queue as SQSQueue
from sentry_sdk import capture_exception

from emails.management.command_from_django_settings import (
    CommandFromDjangoSettings,
    SettingToLocal,
)
from emails.sns import verify_from_sns
from emails.utils import gauge_if_enabled, incr_if_enabled
from emails.views import _sns_inbound_logic, validate_sns_arn_and_type

logger = logging.getLogger("eventsinfo.process_emails_from_sqs")


class Command(CommandFromDjangoSettings):
    help = "Fetch email tasks from SQS and process them."

    settings_to_locals = [
        SettingToLocal(
            "PROCESS_EMAIL_BATCH_SIZE",
            "batch_size",
            "Number of SQS messages to fetch at a time.",
            lambda batch_size: 0 < batch_size <= 10,
        ),
        SettingToLocal(
            "PROCESS_EMAIL_WAIT_SECONDS",
            "wait_seconds",
            "Time to wait for messages with long polling.",
            lambda wait_seconds: wait_seconds > 0,
        ),
        SettingToLocal(
            "PROCESS_EMAIL_VISIBILITY_SECONDS",
            "visibility_seconds",
            "Time to mark a message as reserved for this process.",
            lambda visibility_seconds: visibility_seconds > 0,
        ),
        SettingToLocal(
            "PROCESS_EMAIL_HEALTHCHECK_PATH",
            "healthcheck_path",
            "Path to file to write healthcheck data.",
            lambda healthcheck_path: healthcheck_path is not None,
        ),
        SettingToLocal(
            "PROCESS_EMAIL_DELETE_FAILED_MESSAGES",
            "delete_failed_messages",
            (
                "If a message fails to process, delete it from the queue,"
                " instead of letting SQS resend it or move it to a dead-letter"
                " queue."
            ),
            lambda delete_failed_messages: delete_failed_messages in (True, False),
        ),
        SettingToLocal(
            "PROCESS_EMAIL_MAX_SECONDS",
            "max_seconds",
            "Maximum time to process before exiting, or None to run forever.",
            lambda max_seconds: max_seconds is None or max_seconds > 0.0,
        ),
        SettingToLocal(
            "PROCESS_EMAIL_MAX_SECONDS_PER_MESSAGE",
            "max_seconds_per_message",
            "Maximum time to process a message before cancelling.",
            lambda max_seconds: max_seconds > 0.0,
        ),
        SettingToLocal(
            "AWS_REGION",
            "aws_region",
            "AWS region of the SQS queue.",
            lambda aws_region: bool(aws_region),
        ),
        SettingToLocal(
            "AWS_SQS_EMAIL_QUEUE_URL",
            "sqs_url",
            "URL of the SQS queue.",
            lambda sqs_url: bool(sqs_url),
        ),
        SettingToLocal(
            "PROCESS_EMAIL_VERBOSITY",
            "verbosity",
            "Default verbosity of the process logs.",
            lambda verbosity: verbosity in range(5),
        ),
    ]

    # Added by CommandFromDjangoSettings.init_from_settings
    batch_size: int
    wait_seconds: int
    visibility_seconds: int
    healthcheck_path: str
    delete_failed_messages: bool
    max_seconds: float | None
    max_seconds_per_message: float
    aws_region: str
    sqs_url: str
    verbosity: int

    def handle(self, verbosity: int, *args: Any, **kwargs: Any) -> None:
        """Handle call from command line (called by BaseCommand)"""
        self.init_from_settings(verbosity)
        self.init_locals()
        logger.info(
            "Starting process_emails_from_sqs",
            extra={
                "batch_size": self.batch_size,
                "wait_seconds": self.wait_seconds,
                "visibility_seconds": self.visibility_seconds,
                "healthcheck_path": self.healthcheck_path,
                "delete_failed_messages": self.delete_failed_messages,
                "max_seconds": self.max_seconds,
                "max_seconds_per_message": self.max_seconds_per_message,
                "aws_region": self.aws_region,
                "sqs_url": self.sqs_url,
                "verbosity": self.verbosity,
            },
        )

        try:
            self.queue = self.create_client()
        except ClientError as e:
            raise CommandError("Unable to connect to SQS") from e

        process_data = self.process_queue()
        logger.info("Exiting process_emails_from_sqs", extra=process_data)

    def init_locals(self) -> None:
        """Initialize command attributes that don't come from settings."""
        self.queue_name = urlsplit(self.sqs_url).path.split("/")[-1]
        self.halt_requested = False
        self.start_time: float = 0.0
        self.cycles: int = 0
        self.total_messages: int = 0
        self.failed_messages: int = 0
        self.pause_count: int = 0
        self.queue_count: int = 0
        self.queue_count_delayed: int = 0
        self.queue_count_not_visible: int = 0

    def create_client(self) -> SQSQueue:
        """Create the SQS client."""
        if not self.aws_region:
            raise ValueError("self.aws_region must be a truthy value.")
        if not self.sqs_url:
            raise ValueError("self.sqs_url must be a truthy value.")
        sqs_client = boto3.resource("sqs", region_name=self.aws_region)
        return sqs_client.Queue(self.sqs_url)

    def process_queue(self) -> dict[str, Any]:
        """
        Process the SQS email queue until an exit condition is reached.

        Return is a dict suitable for logging context, with these keys:
        * exit_on: Why processing exited - "interrupt", "max_seconds", "unknown"
        * cycles: How many polling cycles completed
        * total_s: The total execution time, in seconds with millisecond precision
        * total_messages: The number of messages processed, with and without errors
        * failed_messages: The number of messages that failed with errors,
          omitted if none
        * pause_count: The number of 1-second pauses due to temporary errors
        """
        exit_on = "unknown"
        self.cycles = 0
        self.total_messages = 0
        self.failed_messages = 0
        self.pause_count = 0
        self.start_time = time.monotonic()

        while not self.halt_requested:
            try:
                cycle_data: dict[str, Any] = {
                    "cycle_num": self.cycles,
                    "cycle_s": 0.0,
                }
                cycle_data.update(self.refresh_and_emit_queue_count_metrics())
                self.write_healthcheck()

                # Check if we should exit due to time limit
                if self.max_seconds is not None:
                    elapsed = time.monotonic() - self.start_time
                    if elapsed >= self.max_seconds:
                        exit_on = "max_seconds"
                        break

                # Request and process a chunk of messages
                with Timer(logger=None) as cycle_timer:
                    message_batch, queue_data = self.poll_queue_for_messages()
                    cycle_data.update(queue_data)
                    cycle_data.update(self.process_message_batch(message_batch))

                # Collect data and log progress
                self.total_messages += len(message_batch)
                self.failed_messages += int(cycle_data.get("failed_count", 0))
                self.pause_count += int(cycle_data.get("pause_count", 0))
                cycle_data["message_total"] = self.total_messages
                cycle_data["cycle_s"] = round(cycle_timer.last, 3)
                logger.log(
                    (
                        logging.INFO
                        if (message_batch or self.verbosity > 1)
                        else logging.DEBUG
                    ),
                    (
                        f"Cycle {self.cycles}: processed"
                        f" {self.pluralize(len(message_batch), 'message')}"
                    ),
                    extra=cycle_data,
                )

                self.cycles += 1
                gc.collect()  # Force garbage collection of boto3 SQS client resources

            except KeyboardInterrupt:
                self.halt_requested = True
                exit_on = "interrupt"

        process_data = {
            "exit_on": exit_on,
            "cycles": self.cycles,
            "total_s": round(time.monotonic() - self.start_time, 3),
            "total_messages": self.total_messages,
        }
        if self.failed_messages:
            process_data["failed_messages"] = self.failed_messages
        if self.pause_count:
            process_data["pause_count"] = self.pause_count  # uncovered in this build
        return process_data

    def refresh_and_emit_queue_count_metrics(self) -> dict[str, float | int]:
        """
        Query SQS queue attributes, store backlog metrics, and emit them as gauge stats.

        Return is a dict suitable for logging context, with these keys:
        * queue_load_s: How long, in seconds (millisecond precision) it took to
          load attributes
        * queue_count: Approximate number of messages in queue
        * queue_count_delayed: Approx. messages not yet ready for receiving
        * queue_count_not_visible: Approx. messages reserved by other receivers
        """
        # Load attributes from SQS
        with Timer(logger=None) as attribute_timer:
            self.queue.load()

        # Save approximate queue counts
        self.queue_count = int(self.queue.attributes["ApproximateNumberOfMessages"])
        self.queue_count_delayed = int(
            self.queue.attributes["ApproximateNumberOfMessagesDelayed"]
        )
        self.queue_count_not_visible = int(
            self.queue.attributes["ApproximateNumberOfMessagesNotVisible"]
        )

        # Emit gauges for approximate queue counts
        queue_tag = generate_tag("queue", self.queue_name)
        gauge_if_enabled("email_queue_count", self.queue_count, tags=[queue_tag])
        gauge_if_enabled(
            "email_queue_count_delayed", self.queue_count_delayed, tags=[queue_tag]
        )
        gauge_if_enabled(
            "email_queue_count_not_visible",
            self.queue_count_not_visible,
            tags=[queue_tag],
        )

        return {
            "queue_load_s": round(attribute_timer.last, 3),
            "queue_count": self.queue_count,
            "queue_count_delayed": self.queue_count_delayed,
            "queue_count_not_visible": self.queue_count_not_visible,
        }

    def poll_queue_for_messages(
        self,
    ) -> tuple[list[SQSMessage], dict[str, float | int]]:
        """Request a batch of messages, using the long-poll method.

        Return is a tuple:
        * message_batch: a list of messages, which may be empty
        * data: A dict suitable for logging context, with these keys:
            - message_count: the number of messages
            - sqs_poll_s: The poll time, in seconds with millisecond precision
        """
        with Timer(logger=None) as poll_timer:
            message_batch = self.queue.receive_messages(
                MaxNumberOfMessages=self.batch_size,
                VisibilityTimeout=self.visibility_seconds,
                WaitTimeSeconds=self.wait_seconds,
            )
        return (
            message_batch,
            {
                "message_count": len(message_batch),
                "sqs_poll_s": round(poll_timer.last, 3),
            },
        )

    def process_message_batch(self, message_batch: list[SQSMessage]) -> dict[str, Any]:
        """
        Process a batch of messages.

        Arguments:
        * message_batch - a list of SQS messages, possibly empty

        Return is a dict suitable for logging context, with these keys:
        * process_s: How long processing took, omitted if no messages
        * pause_count: How many pauses were taken for temporary errors, omitted if 0
        * pause_s: How long pauses took, omitted if no pauses
        * failed_count: How many messages failed to process, omitted if 0

        Times are in seconds, with millisecond precision.
        """
        if not message_batch:
            return {}
        failed_count = 0
        pause_time = 0.0
        pause_count = 0
        process_time = 0.0
        for message in message_batch:
            self.write_healthcheck()
            with Timer(logger=None) as message_timer:
                message_data = self.process_message(message)
                if not message_data["success"]:
                    failed_count += 1
                if message_data["success"] or self.delete_failed_messages:
                    message.delete()
                pause_time += message_data.get("pause_s", 0.0)
                pause_count += message_data.get("pause_count", 0)

            message_data["message_process_time_s"] = round(message_timer.last, 3)
            process_time += message_timer.last
            logger.log(logging.INFO, "Message processed", extra=message_data)

        batch_data = {"process_s": round((process_time - pause_time), 3)}
        if pause_count:
            batch_data["pause_count"] = pause_count  # uncovered in this build
            batch_data["pause_s"] = round(pause_time, 3)  # uncovered in this build
        if failed_count:
            batch_data["failed_count"] = failed_count
        return batch_data

    def process_message(self, message: SQSMessage) -> dict[str, Any]:
        """
        Process an SQS message, which may include sending an email.

        Return is a dict suitable for logging context, with these keys:
        * success: True if message was processed successfully
        * error: The processing error, omitted on success
        * message_body_quoted: Set if the message was non-JSON, omitted for valid JSON
        * pause_count: Set to 1 if paused due to temporary error, or omitted
          with no error
        * pause_s: The pause in seconds (ms precision) for temp error, or omitted
        * pause_error: The temporary error, or omitted if no temp error
        * client_error_code: The error code for non-temp or retry error,
          omitted on success
        """
        incr_if_enabled("process_message_from_sqs", 1)
        results = {"success": True, "sqs_message_id": message.message_id}
        raw_body = message.body
        try:
            json_body = json.loads(raw_body)
        except ValueError as e:
            results["success"] = False
            results["error"] = f"Failed to load message.body: {e}"
            results["message_body_quoted"] = shlex.quote(raw_body)
            return results
        try:
            verified_json_body = verify_from_sns(json_body)
        except (KeyError, OpenSSL.crypto.Error) as e:
            logger.error("Failed SNS verification", extra={"error": str(e)})
            results["success"] = False
            results["error"] = f"Failed SNS verification: {e}"
            return results

        topic_arn = verified_json_body["TopicArn"]
        message_type = verified_json_body["Type"]
        error_details = validate_sns_arn_and_type(topic_arn, message_type)
        if error_details:
            results["success"] = False
            results.update(error_details)
            return results

        def success_callback(result: HttpResponse) -> None:
            """Handle return from successful call to _sns_inbound_logic"""
            # TODO: extract data from _sns_inbound_logic return
            pass

        def error_callback(exc_info: BaseException) -> None:
            """Handle exception raised by _sns_inbound_logic"""
            capture_exception(exc_info)
            results["success"] = False
            if isinstance(exc_info, ClientError):
                incr_if_enabled("message_from_sqs_error")
                err = exc_info.response["Error"]
                logger.error("sqs_client_error", extra=err)
                results["error"] = err
                results["client_error_code"] = err["Code"].lower()
            else:
                incr_if_enabled("email_processing_failure")
                results["error"] = str(exc_info)
                results["error_type"] = type(exc_info).__name__

        # Run in a multiprocessing Pool.
        # This will start a subprocess, which needs to run django.setup.
        # The benefit is that the subprocess can be terminated.
        # The penalty is that it is slower to start.
        pool_start_time = time.monotonic()
        with Pool(1, initializer=setup) as pool:
            future = pool.apply_async(
                _sns_inbound_logic,
                [topic_arn, message_type, verified_json_body],
                callback=success_callback,
                error_callback=error_callback,
            )
            setup_time = time.monotonic() - pool_start_time
            results["subprocess_setup_time_s"] = round(setup_time, 3)

            message_start_time = time.monotonic()
            message_duration = 0.0
            while message_duration < self.max_seconds_per_message:
                self.write_healthcheck()
                future.wait(1.0)
                message_duration = time.monotonic() - message_start_time
                if future.ready():
                    break

            results["message_process_time_s"] = round(message_duration, 3)
            if not future.ready():
                error = f"Timed out after {self.max_seconds_per_message:0.1f} seconds."
                results["success"] = False
                results["error"] = error
        return results

    def write_healthcheck(self) -> None:
        """Update the healthcheck file with operations data, if path is set."""
        data: dict[str, str | int] = {
            "timestamp": datetime.now(tz=UTC).isoformat(),
            "cycles": self.cycles,
            "total_messages": self.total_messages,
            "failed_messages": self.failed_messages,
            "pause_count": self.pause_count,
            "queue_count": int(self.queue.attributes["ApproximateNumberOfMessages"]),
            "queue_count_delayed": int(
                self.queue.attributes["ApproximateNumberOfMessagesDelayed"]
            ),
            "queue_count_not_visible": int(
                self.queue.attributes["ApproximateNumberOfMessagesNotVisible"]
            ),
        }
        with open(self.healthcheck_path, "w", encoding="utf-8") as healthcheck_file:
            json.dump(data, healthcheck_file)

    def pluralize(self, value: int, singular: str, plural: str | None = None) -> str:
        """Return the count and noun, pluralized if needed, like '2 messages'."""
        if value == 1:
            return f"{value} {singular}"
        else:
            return f"{value} {plural or (singular + 's')}"
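
process_message runs _sns_inbound_logic in a single-worker Pool so that a message that hangs can be abandoned after max_seconds_per_message. The same pattern, reduced to a self-contained sketch (slow_task and the 3-second limit are illustrative, not project code):

# Per-task timeout with a single-worker Pool: poll in short waits so other
# housekeeping can run between checks, then let the `with` block terminate
# a worker that is still busy.
import time
from multiprocessing import Pool


def slow_task(seconds: float) -> str:
    time.sleep(seconds)
    return "done"


if __name__ == "__main__":
    max_seconds_per_task = 3.0
    with Pool(1) as pool:
        future = pool.apply_async(slow_task, [10.0])
        start = time.monotonic()
        while time.monotonic() - start < max_seconds_per_task:
            future.wait(1.0)  # block for at most one second per iteration
            if future.ready():
                break
        if future.ready():
            print(future.get())
        else:
            # Exiting the `with` block calls pool.terminate(), killing the worker.
            print(f"Timed out after {max_seconds_per_task:0.1f} seconds.")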
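
write_healthcheck refreshes a JSON file before each poll and before each message, so an external monitor can detect a stalled process from the file's timestamp. A hypothetical companion checker (the path and threshold are illustrative, not repository settings):

# Hypothetical liveness check: exit nonzero if the healthcheck file is stale.
import json
import sys
from datetime import UTC, datetime

HEALTHCHECK_PATH = "/tmp/emails_healthcheck.json"  # illustrative path
MAX_AGE_SECONDS = 120  # illustrative threshold

with open(HEALTHCHECK_PATH, encoding="utf-8") as healthcheck_file:
    data = json.load(healthcheck_file)

written = datetime.fromisoformat(data["timestamp"])
age_s = (datetime.now(tz=UTC) - written).total_seconds()
if age_s > MAX_AGE_SECONDS:
    sys.exit(f"Stale healthcheck: last written {age_s:.0f}s ago")
print(f"OK: {data['total_messages']} messages over {data['cycles']} cycles")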

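CommandFromDjangoSettings and SettingToLocal are project-internal helpers whose implementation is not part of this file. A plausible reading of what init_from_settings does with each settings_to_locals entry, offered only as a hypothetical sketch:

# Hypothetical sketch; the real logic lives in
# emails/management/command_from_django_settings.py and may differ.
from typing import Any, Callable, NamedTuple

from django.conf import settings
from django.core.management.base import CommandError


class SettingToLocal(NamedTuple):
    setting_name: str
    local_name: str
    help_str: str
    validator: Callable[[Any], bool]


def init_from_settings(command: Any, settings_to_locals: list[SettingToLocal]) -> None:
    """Copy each Django setting to a command attribute after validating it."""
    for entry in settings_to_locals:
        value = getattr(settings, entry.setting_name)
        if not entry.validator(value):
            raise CommandError(
                f"{entry.setting_name} has invalid value {value!r}: {entry.help_str}"
            )
        setattr(command, entry.local_name, value)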
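
As with any Django management command, the file name determines the invocation: python manage.py process_emails_from_sqs. The command defines no extra command-line flags here; configuration comes from the PROCESS_EMAIL_* and AWS_* settings validated in settings_to_locals.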