• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mozilla / fx-private-relay / 23182427-2247-476e-8849-1b3b4acceb2c

15 Oct 2024 04:08PM CUT coverage: 84.491% (+0.09%) from 84.402%
23182427-2247-476e-8849-1b3b4acceb2c

push

circleci

web-flow
Merge pull request #5090 from mozilla/add-developer-mode-flag-mpp-3932

MPP-3932: Add flag 'developer_mode', use to simulate complaint and log notifications

2372 of 3515 branches covered (67.48%)

Branch coverage included in aggregate %.

143 of 145 new or added lines in 4 files covered. (98.62%)

1 existing line in 1 file now uncovered.

16456 of 18769 relevant lines covered (87.68%)

10.16 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.8
/emails/utils.py
1
from __future__ import annotations
1✔
2

3
import base64
1✔
4
import contextlib
1✔
5
import json
1✔
6
import logging
1✔
7
import pathlib
1✔
8
import re
1✔
9
import zlib
1✔
10
from collections.abc import Callable
1✔
11
from email.errors import InvalidHeaderDefect
1✔
12
from email.headerregistry import Address, AddressHeader
1✔
13
from email.message import EmailMessage
1✔
14
from email.utils import formataddr, parseaddr
1✔
15
from functools import cache
1✔
16
from typing import Any, Literal, TypeVar, cast
1✔
17
from urllib.parse import quote_plus, urlparse
1✔
18

19
from django.conf import settings
1✔
20
from django.contrib.auth.models import Group, User
1✔
21
from django.template.defaultfilters import linebreaksbr, urlize
1✔
22
from django.template.loader import render_to_string
1✔
23
from django.utils.text import Truncator
1✔
24

25
import jwcrypto.jwe
1✔
26
import jwcrypto.jwk
1✔
27
import markus
1✔
28
import requests
1✔
29
from allauth.socialaccount.models import SocialAccount
1✔
30
from botocore.exceptions import ClientError
1✔
31
from cryptography.hazmat.primitives import hashes
1✔
32
from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand
1✔
33
from mypy_boto3_ses.type_defs import ContentTypeDef, SendRawEmailResponseTypeDef
1✔
34

35
from privaterelay.plans import get_bundle_country_language_mapping
1✔
36
from privaterelay.utils import get_countries_info_from_lang_and_mapping
1✔
37

38
from .apps import s3_client, ses_client
1✔
39

40
# Module loggers: "events" receives the error events logged in this module,
# "eventsinfo" the per-email tracker summary, and "studymetrics" the tracker
# study summary.
logger = logging.getLogger("events")
info_logger = logging.getLogger("eventsinfo")
study_logger = logging.getLogger("studymetrics")

# Metrics client; the *_if_enabled helpers only use it when settings.STATSD_ENABLED.
metrics = markus.get_metrics("fx-private-relay")

# Upstream Disconnect list, downloaded by download_trackers() when no local
# tracker list file exists yet.
shavar_prod_lists_url = (
    "https://raw.githubusercontent.com/mozilla-services/shavar-prod-lists/"
    "master/disconnect-blacklist.json"
)

# Folder holding this module, and the sub-folder caching the tracker lists.
EMAILS_FOLDER_PATH = pathlib.Path(__file__).parent
TRACKER_FOLDER_PATH = EMAILS_FOLDER_PATH / "tracker_lists"
51

52

53
def ses_message_props(data: str) -> ContentTypeDef:
    """Wrap *data* as an SES content mapping with a UTF-8 charset."""
    return dict(Charset="UTF-8", Data=data)
55

56

57
def get_domains_from_settings() -> (
    dict[Literal["RELAY_FIREFOX_DOMAIN", "MOZMAIL_DOMAIN"], str]
):
    """
    Return the Relay and Mozmail email domains from settings.

    When running under Django tests (detected by "testserver" being in
    settings.ALLOWED_HOSTS), the fixed placeholders "default.com" and "test.com"
    are returned instead.
    """
    # HACK: detect if code is running in django tests
    running_in_tests = "testserver" in settings.ALLOWED_HOSTS
    if running_in_tests:
        relay_domain, mozmail_domain = "default.com", "test.com"
    else:
        relay_domain = settings.RELAY_FIREFOX_DOMAIN
        mozmail_domain = settings.MOZMAIL_DOMAIN
    return {"RELAY_FIREFOX_DOMAIN": relay_domain, "MOZMAIL_DOMAIN": mozmail_domain}
67

68

69
def get_trackers(level):
    """
    Return the list of tracker domains for a tracker-blocking level.

    Level 2 selects the "EmailAggressive" category (level-two-trackers.json).
    Any other level selects the "Email" category (level-one-trackers.json).
    The list is read from the JSON file in TRACKER_FOLDER_PATH. If that file
    does not exist, the list is downloaded from shavar-prod-lists and stored
    there for subsequent calls.
    """
    if level == 2:
        category, tracker_list_name = "EmailAggressive", "level-two-trackers"
    else:
        category, tracker_list_name = "Email", "level-one-trackers"

    file_name = f"{tracker_list_name}.json"
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(TRACKER_FOLDER_PATH / file_name, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        trackers = download_trackers(shavar_prod_lists_url, category)
        store_trackers(trackers, TRACKER_FOLDER_PATH, file_name)
        return trackers
85

86

87
def download_trackers(repo_url, category="Email"):
    """
    Download the tracker list at *repo_url* and return the domains in *category*.

    The email tracker lists come from shavar-prod-lists, as per the agreed use
    under its license. The JSON is read as
    ``{"categories": {category: [{...: {...: [domain, ...]}}, ...]}}``, and every
    innermost domain list is flattened into a single list.

    Raises:
        requests.HTTPError: the server answered with an error status. Without
            this check, an error page would fail later with a confusing JSON
            decode error or KeyError.
    """
    resp = requests.get(repo_url, timeout=10)
    resp.raise_for_status()
    formatted_trackers = resp.json()["categories"][category]
    trackers = []
    for entity in formatted_trackers:
        for resources in entity.values():
            for domains in resources.values():
                trackers.extend(domains)
    return trackers
98

99

100
def store_trackers(trackers, path, file_name):
    """
    Write *trackers* as indented JSON to ``path / file_name``.

    Any existing file is overwritten. *path* must support the ``/`` operator
    (e.g. a pathlib.Path). The file is written as UTF-8 explicitly, so it can be
    read back the same way regardless of the platform locale.
    """
    with open(path / file_name, "w", encoding="utf-8") as f:
        json.dump(trackers, f, indent=4)
103

104

105
@cache
def general_trackers():
    """Return the level-one ("Email") tracker domains; computed once, then cached."""
    level_one = 1
    return get_trackers(level=level_one)
108

109

110
@cache
def strict_trackers():
    """Return the level-two ("EmailAggressive") tracker domains; computed once, then cached."""
    level_two = 2
    return get_trackers(level=level_two)
113

114

115
_TimedFunction = TypeVar("_TimedFunction", bound=Callable[..., Any])


def time_if_enabled(name: str) -> Callable[[_TimedFunction], _TimedFunction]:
    """
    Return a decorator that times each call of the decorated function.

    The duration is recorded with metrics.timer(*name*) when
    settings.STATSD_ENABLED is true at call time. Otherwise the call runs
    inside contextlib.nullcontext() and nothing is recorded.
    """
    from functools import wraps

    def timing_decorator(func: _TimedFunction) -> _TimedFunction:
        # wraps() copies __name__, __doc__, __module__ and sets __wrapped__, so
        # decorated functions stay identifiable in logs, tracebacks and inspect.
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            ctx_manager = (
                metrics.timer(name)
                if settings.STATSD_ENABLED
                else contextlib.nullcontext()
            )
            with ctx_manager:
                return func(*args, **kwargs)

        return cast(_TimedFunction, func_wrapper)

    return timing_decorator
132

133

134
def incr_if_enabled(name, value=1, tags=None):
    """Increment counter *name* by *value* (with optional *tags*) when StatsD is enabled."""
    if not settings.STATSD_ENABLED:
        return
    metrics.incr(name, value, tags)
137

138

139
def histogram_if_enabled(name, value, tags=None):
    """Record *value* in histogram *name* (with optional *tags*) when StatsD is enabled."""
    if not settings.STATSD_ENABLED:
        return
    metrics.histogram(name, value=value, tags=tags)
142

143

144
def gauge_if_enabled(name, value, tags=None):
    """Set gauge *name* to *value* (with optional *tags*) when StatsD is enabled."""
    if not settings.STATSD_ENABLED:
        return
    metrics.gauge(name, value, tags)
147

148

149
def get_email_domain_from_settings() -> str:
    """
    Return the email domain: the network location of settings.SITE_ORIGIN.

    On the "dev" channel the domain gets a "mail." prefix, because MX records
    can't be published on Heroku.
    """
    netloc = str(urlparse(settings.SITE_ORIGIN).netloc)
    prefix = "mail." if settings.RELAY_CHANNEL == "dev" else ""
    return f"{prefix}{netloc}"
156

157

158
def parse_email_header(header_value: str) -> list[tuple[str, str]]:
    """
    Extract the (display name, email address) pairs from a header value.

    This is useful when working with header values provided by an AWS SES
    delivery notification.

    email.utils.parseaddr() works with well-formed emails, but fails in cases
    with badly formed emails where an email address could still be extracted.

    Every mailbox of every parsed address (via ``all_mailboxes``) is considered.
    Mailboxes whose address is empty or does not contain exactly one "@" are
    skipped. A missing display name becomes "".
    """
    parsed = AddressHeader.value_parser(header_value)
    candidates = (
        (mailbox.display_name or "", mailbox.addr_spec)
        for address in parsed.addresses
        for mailbox in address.all_mailboxes
    )
    return [
        (name, spec) for name, spec in candidates if spec and spec.count("@") == 1
    ]
177

178

179
def _get_hero_img_src(lang_code):
    """
    Return the absolute URL of the first-time-user hero image for *lang_code*.

    Only the part of *lang_code* before the first "-" is used. If no localized
    image exists for it, the English image is used.

    Raises:
        ValueError: settings.SITE_ORIGIN is empty.
    """
    localized_codes = frozenset(
        (
            "cs", "de", "en", "es", "fi", "fr", "hu", "id",
            "it", "ja", "nl", "pt", "ru", "sv", "zh",
        )
    )
    primary = lang_code.partition("-")[0]
    img_locale = primary if primary in localized_codes else "en"

    if not settings.SITE_ORIGIN:
        raise ValueError("settings.SITE_ORIGIN must have a value")
    image_path = (
        "/static/images/email-images/first-time-user/"
        f"hero-image-{img_locale}.png"
    )
    return settings.SITE_ORIGIN + image_path
208

209

210
def get_welcome_email(user: User, format: str) -> str:
    """
    Render the first-time-user welcome email for *user*.

    *format* is the template file extension: the template rendered is
    ``emails/first_time_user.<format>``. Bundle availability is looked up from
    the "locale" stored on the user's SocialAccount (defaulting to "en"). The
    hero image uses the user's profile language.
    """
    social_account = SocialAccount.objects.get(user=user)
    account_locale = social_account.extra_data.get("locale", "en")
    bundle_plans = get_countries_info_from_lang_and_mapping(
        account_locale, get_bundle_country_language_mapping()
    )
    lang_code = user.profile.language
    hero_img_src = _get_hero_img_src(lang_code)
    template_name = f"emails/first_time_user.{format}"
    context = {
        "in_bundle_country": bundle_plans["available_in_country"],
        "SITE_ORIGIN": settings.SITE_ORIGIN,
        "hero_img_src": hero_img_src,
        "language": lang_code,
    }
    return render_to_string(template_name, context)
226

227

228
@time_if_enabled("ses_send_raw_email")
def ses_send_raw_email(
    source_address: str,
    destination_address: str,
    message: EmailMessage,
) -> SendRawEmailResponseTypeDef:
    """
    Send *message* through SES as a raw email to a single destination.

    The send uses settings.AWS_SES_CONFIGSET as the configuration set. On
    success, the "ses_send_raw_email" counter is incremented and the SES
    response is returned.

    Raises:
        ValueError: the SES client or settings.AWS_SES_CONFIGSET is not configured.
        ClientError: SES rejected the request. The error details are logged as
            "ses_client_error_raw_email" before re-raising.
    """
    client = ses_client()
    if client is None:
        raise ValueError("client must have a value")
    if not settings.AWS_SES_CONFIGSET:
        raise ValueError("settings.AWS_SES_CONFIGSET must have a value")

    raw_message = {"Data": message.as_string()}
    try:
        response = client.send_raw_email(
            Source=source_address,
            Destinations=[destination_address],
            RawMessage=raw_message,
            ConfigurationSetName=settings.AWS_SES_CONFIGSET,
        )
    except ClientError as error:
        logger.error("ses_client_error_raw_email", extra=error.response["Error"])
        raise
    incr_if_enabled("ses_send_raw_email", 1)
    return response
253

254

255
def urlize_and_linebreaks(text, autoescape=True):
    """
    Turn URLs in *text* into links, then newlines into HTML line breaks.

    *autoescape* is passed to both Django filters.
    """
    linked = urlize(text, autoescape=autoescape)
    return linebreaksbr(linked, autoescape=autoescape)
257

258

259
def get_reply_to_address(premium: bool = True) -> str:
    """
    Return the address that relays replies.

    For premium, this is ``replies@<RELAY_FIREFOX_DOMAIN>``. Otherwise it is the
    bare address parsed from settings.RELAY_FROM_ADDRESS.
    """
    if not premium:
        return parseaddr(settings.RELAY_FROM_ADDRESS)[1]
    relay_domain = get_domains_from_settings().get("RELAY_FIREFOX_DOMAIN")
    return parseaddr(f"replies@{relay_domain}")[1]
268

269

270
def truncate(max_length: int, value: str) -> str:
    """
    Truncate a string to a maximum length.

    Strings no longer than *max_length* are returned unchanged. Longer strings
    are shortened with Django's Truncator.chars(). The truncation suffix is
    "..." if the value is all ASCII, or "…" (Unicode ellipsis) otherwise.
    """
    if len(value) <= max_length:
        return value
    suffix = "..." if value.isascii() else "…"
    return Truncator(value).chars(max_length, truncate=suffix)
285

286

287
class InvalidFromHeader(Exception):
    """Raised when no usable email address can be parsed from a From: header."""


def generate_from_header(original_from_address: str, relay_mask: str) -> str:
    """
    Return a From: header str using the original sender and a display name that
    refers to Relay.

    The result is *relay_mask* with the display name ``"<sender> [via Relay]"``.
    *sender* is ``"<name> <address>"`` when the original had a display name, or
    just the address otherwise. The name and the address are each truncated
    to 71 characters.

    This format was introduced in June 2023 with MPP-2117.

    Raises:
        InvalidFromHeader: no valid address could be parsed from
            *original_from_address*.
    """
    from email.errors import HeaderParseError

    # Remove line breaks (CR, LF and U+2028) so the value parses as one line.
    oneline_from_address = (
        original_from_address.replace("\u2028", "").replace("\r", "").replace("\n", "")
    )
    display_name, original_address = parseaddr(oneline_from_address)
    if not original_address:
        # parseaddr() returns ("", "") when parsing fails. Address(addr_spec="")
        # would then raise HeaderParseError, which was not reported as
        # InvalidFromHeader.
        raise InvalidFromHeader("no email address found in From header")
    try:
        parsed_address = Address(addr_spec=original_address)
    except (InvalidHeaderDefect, HeaderParseError, ValueError, IndexError) as e:
        # Address() raises HeaderParseError for an unparseable local part, and
        # ValueError (InvalidHeaderDefect is a subclass) for trailing junk or
        # other header defects.
        # TODO: MPP-3407, MPP-3417 - Determine how to handle these
        raise InvalidFromHeader from e

    # Truncate to 71 characters, so the sender portion fits on the first line of
    # a multi-line "From:" header, if it is ASCII. A utf-8 encoded header will be
    # 226 chars, still below the 998 limit of RFC 5322 2.1.1.
    max_length = 71

    if display_name:
        short_name = truncate(max_length, display_name)
        short_address = truncate(max_length, parsed_address.addr_spec)
        sender = f"{short_name} <{short_address}>"
    else:
        # Use the email address if the display name was not originally set
        sender = truncate(max_length, parsed_address.addr_spec)
    return formataddr((f"{sender} [via Relay]", relay_mask))
322

323

324
def get_message_id_bytes(message_id_str: str) -> bytes:
    """
    Return the id portion of a message id string, encoded as bytes.

    The steps are:
    - keep the text before the first "@";
    - drop everything up to and including its last "<";
    - strip surrounding whitespace.

    For example, "<abc@example.com>" becomes b"abc".
    """
    before_at = message_id_str.partition("@")[0]
    local_part = before_at.rpartition("<")[2]
    return local_part.strip().encode()
327

328

329
def b64_lookup_key(lookup_key: bytes) -> str:
    """Return *lookup_key* as URL-safe base64 text (with "=" padding kept)."""
    encoded = base64.urlsafe_b64encode(lookup_key)
    return encoded.decode("ascii")
331

332

333
def derive_reply_keys(message_id: bytes) -> tuple[bytes, bytes]:
    """
    Derive the lookup key and encryption key from an aliased message id.

    Both keys use HKDF-Expand with SHA-256 and *message_id* as the key material.
    They differ only in the info string: the lookup key is 16 bytes and the
    encryption key is 32 bytes.
    """

    # NOTE: the info strings below say "replay", not "relay". They are inputs to
    # the derivation, so editing them would change every derived key.
    def expand(length: int, info: bytes) -> bytes:
        kdf = HKDFExpand(algorithm=hashes.SHA256(), length=length, info=info)
        return kdf.derive(message_id)

    lookup_key = expand(16, b"replay replies lookup key")
    encryption_key = expand(32, b"replay replies encryption key")
    return lookup_key, encryption_key
343

344

345
def encrypt_reply_metadata(key: bytes, payload: dict[str, str]) -> str:
    """
    Encrypt the JSON-serialized *payload* into a compact JWE using *key*.

    The JWE header is ``{"alg": "dir", "enc": "A256GCM"}``: direct use of *key*
    with AES-256-GCM content encryption.
    """
    # jwcrypto loads symmetric keys from unpadded base64url text, hence the round-trip.
    key_b64 = base64.urlsafe_b64encode(key).rstrip(b"=").decode("ascii")
    jwk = jwcrypto.jwk.JWK(kty="oct", k=key_b64)
    plaintext = json.dumps(payload)
    protected_header = json.dumps({"alg": "dir", "enc": "A256GCM"})
    token = jwcrypto.jwe.JWE(plaintext, protected_header, recipient=jwk)
    return cast(str, token.serialize(compact=True))
355

356

357
def decrypt_reply_metadata(key, jwe):
    """
    Decrypt the compact JWE *jwe* with *key* and return its plaintext.

    This reverses encrypt_reply_metadata(). The plaintext is the JSON text that
    was encrypted there.
    """
    # jwcrypto loads symmetric keys from unpadded base64url text, hence the round-trip.
    key_b64 = base64.urlsafe_b64encode(key).rstrip(b"=").decode("ascii")
    secret = jwcrypto.jwk.JWK(kty="oct", k=key_b64)
    token = jwcrypto.jwe.JWE()
    token.deserialize(jwe)
    token.decrypt(secret)
    return token.plaintext
367

368

369
def _get_bucket_and_key_from_s3_json(message_json):
1✔
370
    # Only Received notifications have S3-stored data
371
    notification_type = message_json.get("notificationType")
1✔
372
    if notification_type != "Received":
1✔
373
        return None, None
1✔
374

375
    if "receipt" in message_json and "action" in message_json["receipt"]:
1!
376
        message_json_receipt = message_json["receipt"]
1✔
377
    else:
378
        logger.error(
×
379
            "sns_inbound_message_without_receipt",
380
            extra={"message_json_keys": message_json.keys()},
381
        )
382
        return None, None
×
383

384
    bucket = None
1✔
385
    object_key = None
1✔
386
    try:
1✔
387
        if "S3" in message_json_receipt["action"]["type"]:
1✔
388
            bucket = message_json_receipt["action"]["bucketName"]
1✔
389
            object_key = message_json_receipt["action"]["objectKey"]
1✔
390
    except (KeyError, TypeError):
×
391
        logger.error(
×
392
            "sns_inbound_message_receipt_malformed",
393
            extra={"receipt_action": message_json_receipt["action"]},
394
        )
395
    return bucket, object_key
1✔
396

397

398
@time_if_enabled("s3_get_message_content")
def get_message_content_from_s3(bucket, object_key):
    """
    Download and return the body of the S3 object *object_key* in *bucket*.

    Returns None, without calling S3, when either argument is falsy.

    Raises:
        ValueError: the S3 client is not configured.
    """
    if not (bucket and object_key):
        return None
    client = s3_client()
    if client is None:
        raise ValueError("client must not be None")
    s3_object = client.get_object(Bucket=bucket, Key=object_key)
    return s3_object.get("Body").read()
408

409

410
@time_if_enabled("s3_remove_message_from")
def remove_message_from_s3(bucket, object_key):
    """
    Delete the S3 object *object_key* from *bucket*.

    Returns:
        - False if either argument is None.
        - On success, the response's "DeleteMarker" value (None if absent).
        - False on a ClientError. The error is logged ("NoSuchKey" gets its own
          event name) and "message_not_removed_from_s3" is incremented.

    Raises:
        ValueError: the S3 client is not configured.
    """
    if bucket is None or object_key is None:
        return False
    try:
        client = s3_client()
        if client is None:
            raise ValueError("client must not be None")
        response = client.delete_object(Bucket=bucket, Key=object_key)
    except ClientError as error:
        details = error.response["Error"]
        if details.get("Code", "") == "NoSuchKey":
            event = "s3_delete_object_does_not_exist"
        else:
            event = "s3_client_error_delete_email"
        logger.error(event, extra=details)
        incr_if_enabled("message_not_removed_from_s3", 1)
        return False
    return response.get("DeleteMarker")
427

428

429
def set_user_group(user):
    """
    Add *user* to the internal Django group matching their email domain.

    mozilla.com, mozillafoundation.org and getpocket.com map to the
    "mozilla_corporation", "mozilla_foundation" and "pocket" groups. Nothing
    happens in these cases:
    - the email has no "@";
    - the domain is not one of those listed;
    - the group does not exist.

    Always returns None.
    """
    email = user.email
    if "@" not in email:
        return None
    # The domain is what follows the LAST "@", and domains are case-insensitive.
    # split("@")[1] would have mapped "x@mozilla.com@evil.example" to Mozilla.
    email_domain = email.rpartition("@")[2].lower()
    group_attribute = {
        "mozilla.com": "mozilla_corporation",
        "mozillafoundation.org": "mozilla_foundation",
        "getpocket.com": "pocket",
    }
    group_name = group_attribute.get(email_domain)
    if not group_name:
        return None
    internal_group = Group.objects.filter(name=group_name).first()
    if internal_group is None:
        return None
    internal_group.user_set.add(user)
    return None
446

447

448
def convert_domains_to_regex_patterns(domain_pattern):
    """
    Return a regex matching a quoted URL that contains *domain_pattern*.

    The pattern matches, in order:
    - a quote (' or ");
    - non-whitespace text, "://" and optional dotted labels;
    - *domain_pattern* (escaped, so it is matched literally);
    - any further non-whitespace text;
    - the same closing quote.

    Groups: 1 = the quote, 2 = the URL, 3 = the last dotted label before the
    domain.
    """
    escaped_domain = re.escape(domain_pattern)
    opening_quote = r"""(["'])"""
    url_prefix = r"(\S*://(\S*\.)*"
    url_suffix = r"\S*)\1"
    return opening_quote + url_prefix + escaped_domain + url_suffix
450

451

452
def count_tracker(html_content, trackers):
    """
    Count quoted tracker URLs in *html_content* (a str), per tracker domain.

    After each tracker is counted, its matches are removed from a working copy
    before the next tracker is checked.

    Returns:
        ``{"count": total, "trackers": {domain: n}}``. Only domains with at
        least one match are listed.
    """
    remaining = html_content
    total = 0
    per_tracker = {}
    for tracker in trackers:
        remaining, hits = re.subn(
            convert_domains_to_regex_patterns(tracker), "", remaining
        )
        if hits:
            total += hits
            per_tracker[tracker] = hits
    return {"count": total, "trackers": per_tracker}
463

464

465
def count_all_trackers(html_content):
    """
    Count level-one and level-two trackers in *html_content* and report them.

    Increments "tracker.general_count" and "tracker.strict_count" by the
    respective totals. Logs "email_tracker_summary" with both details to the
    study logger. Returns None.
    """
    summary = {
        "level_one": count_tracker(html_content, general_trackers()),
        "level_two": count_tracker(html_content, strict_trackers()),
    }
    incr_if_enabled("tracker.general_count", summary["level_one"]["count"])
    incr_if_enabled("tracker.strict_count", summary["level_two"]["count"])
    study_logger.info("email_tracker_summary", extra=summary)
475

476

477
def remove_trackers(html_content, from_address, datetime_now, level="general"):
    """
    Replace tracker URLs in *html_content* with links to Relay's tracker warning page.

    Args:
        html_content: the HTML body (str).
        from_address: sender, embedded in each warning link.
        datetime_now: received time, embedded in each warning link. It is passed
            to json.dumps, so it must be JSON-serializable.
        level: "general" uses the level-one list; any other value uses the
            level-two (strict) list.

    Each matched URL, keeping its original quote characters, becomes
    ``<SITE_ORIGIN>/contains-tracker-warning/#<quote_plus-encoded JSON>``. The
    JSON holds the sender, the received time and the original link. Level-one
    and level-two tracker counts of the unmodified content are logged to the
    info logger as "email_tracker_summary".

    Returns:
        ``(changed_content, tracker_details)``, where tracker_details is
        ``{"tracker_removed": <replacements made>, "level_one": <count_tracker result>}``.
    """
    trackers = general_trackers() if level == "general" else strict_trackers()

    # Defined once: it does not depend on the loop variable, so re-creating it
    # for every tracker (as before) was wasted work.
    def convert_to_tracker_warning_link(matchobj):
        quote, original_link, _ = matchobj.groups()
        tracker_link_details = {
            "sender": from_address,
            "received_at": datetime_now,
            "original_link": original_link,
        }
        anchor = quote_plus(json.dumps(tracker_link_details, separators=(",", ":")))
        url = f"{settings.SITE_ORIGIN}/contains-tracker-warning/#{anchor}"
        return f"{quote}{url}{quote}"

    tracker_removed = 0
    changed_content = html_content
    for tracker in trackers:
        pattern = convert_domains_to_regex_patterns(tracker)
        changed_content, matched = re.subn(
            pattern, convert_to_tracker_warning_link, changed_content
        )
        tracker_removed += matched

    # Counts are taken on the original, unmodified content.
    level_one_detail = count_tracker(html_content, general_trackers())
    level_two_detail = count_tracker(html_content, strict_trackers())

    tracker_details = {
        "tracker_removed": tracker_removed,
        "level_one": level_one_detail,
    }
    logger_details = {"level": level, "level_two": level_two_detail}
    logger_details.update(tracker_details)
    info_logger.info(
        "email_tracker_summary",
        extra=logger_details,
    )
    return changed_content, tracker_details
515

516

517
def encode_dict_gza85(data: dict[str, Any]) -> str:
    """
    Encode a dict to the compressed Ascii85 format.

    The steps are:
    - JSON-encode the dict;
    - zlib-compress it;
    - Ascii85-encode it with padding;
    - split the result by newlines into 1024-character lines.

    This can be used to ensure it fits into a GCP log entry, which has a 64KB
    limit per label value. decode_dict_gza85() reverses the encoding.
    """
    json_bytes = json.dumps(data).encode()
    compressed = zlib.compress(json_bytes)
    encoded = base64.a85encode(compressed, wrapcol=1024, pad=True)
    return encoded.decode("ascii")
528

529

530
def decode_dict_gza85(encoded_data: str) -> dict[str, Any]:
    """
    Decode a dict encoded with encode_dict_gza85().

    The steps are:
    - Ascii85-decode the text (line breaks are ignored);
    - zlib-decompress it (any trailing pad bytes are ignored);
    - parse the result as JSON.

    Raises:
        ValueError: the decoded JSON is not an object, or has a non-str key.
            Malformed input also raises ValueError subclasses (encoding,
            Ascii85 or JSON errors) or zlib.error.
    """
    data = json.loads(
        zlib.decompress(base64.a85decode(encoded_data.encode("ascii"))).decode()
    )
    if not isinstance(data, dict):
        raise ValueError("Encoded data is not a dict")
    # Defensive: json.loads only produces str object keys. The check keeps the
    # declared dict[str, Any] return type honest if the decoder ever changes.
    if any(not isinstance(key, str) for key in data):
        raise ValueError("Encoded data has non-str key")
    return data
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc