• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mozilla / fx-private-relay / d3128616-238d-446e-82c5-ab66cd38ceaf

09 May 2024 06:22PM CUT coverage: 84.07% (-0.6%) from 84.64%
d3128616-238d-446e-82c5-ab66cd38ceaf

push

circleci

web-flow
Merge pull request #4684 from mozilla/enable-flak8-bandit-checks-mpp-3802

fix MPP-3802: stop ignoring bandit security checks

3601 of 4734 branches covered (76.07%)

Branch coverage included in aggregate %.

74 of 158 new or added lines in 24 files covered. (46.84%)

5 existing lines in 5 files now uncovered.

14686 of 17018 relevant lines covered (86.3%)

10.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.39
/emails/utils.py
1
from __future__ import annotations
1✔
2

3
import base64
1✔
4
import contextlib
1✔
5
import json
1✔
6
import logging
1✔
7
import pathlib
1✔
8
import re
1✔
9
from collections.abc import Callable
1✔
10
from email.errors import InvalidHeaderDefect
1✔
11
from email.headerregistry import Address, AddressHeader
1✔
12
from email.message import EmailMessage
1✔
13
from email.utils import formataddr, parseaddr
1✔
14
from functools import cache
1✔
15
from typing import Any, Literal, TypeVar, cast
1✔
16
from urllib.parse import quote_plus, urlparse
1✔
17

18
from django.conf import settings
1✔
19
from django.contrib.auth.models import Group, User
1✔
20
from django.template.defaultfilters import linebreaksbr, urlize
1✔
21
from django.template.loader import render_to_string
1✔
22
from django.utils.text import Truncator
1✔
23

24
import jwcrypto.jwe
1✔
25
import jwcrypto.jwk
1✔
26
import markus
1✔
27
import requests
1✔
28
from allauth.socialaccount.models import SocialAccount
1✔
29
from botocore.exceptions import ClientError
1✔
30
from cryptography.hazmat.primitives import hashes
1✔
31
from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand
1✔
32
from mypy_boto3_ses.type_defs import ContentTypeDef, SendRawEmailResponseTypeDef
1✔
33

34
from privaterelay.plans import get_bundle_country_language_mapping
1✔
35
from privaterelay.utils import get_countries_info_from_lang_and_mapping
1✔
36

37
from .apps import s3_client, ses_client
1✔
38

39
logger = logging.getLogger("events")
1✔
40
info_logger = logging.getLogger("eventsinfo")
1✔
41
study_logger = logging.getLogger("studymetrics")
1✔
42
metrics = markus.get_metrics("fx-private-relay")
1✔
43

44
shavar_prod_lists_url = (
1✔
45
    "https://raw.githubusercontent.com/mozilla-services/shavar-prod-lists/"
46
    "master/disconnect-blacklist.json"
47
)
48
EMAILS_FOLDER_PATH = pathlib.Path(__file__).parent
1✔
49
TRACKER_FOLDER_PATH = EMAILS_FOLDER_PATH / "tracker_lists"
1✔
50

51

52
def ses_message_props(data: str) -> ContentTypeDef:
1✔
53
    return {"Charset": "UTF-8", "Data": data}
1✔
54

55

56
def get_domains_from_settings() -> (
1✔
57
    dict[Literal["RELAY_FIREFOX_DOMAIN", "MOZMAIL_DOMAIN"], str]
58
):
59
    # HACK: detect if code is running in django tests
60
    if "testserver" in settings.ALLOWED_HOSTS:
1!
61
        return {"RELAY_FIREFOX_DOMAIN": "default.com", "MOZMAIL_DOMAIN": "test.com"}
1✔
62
    return {
×
63
        "RELAY_FIREFOX_DOMAIN": settings.RELAY_FIREFOX_DOMAIN,
64
        "MOZMAIL_DOMAIN": settings.MOZMAIL_DOMAIN,
65
    }
66

67

68
def get_trackers(level):
1✔
69
    category = "Email"
×
70
    tracker_list_name = "level-one-trackers"
×
71
    if level == 2:
×
72
        category = "EmailAggressive"
×
73
        tracker_list_name = "level-two-trackers"
×
74

75
    trackers = []
×
76
    file_name = f"{tracker_list_name}.json"
×
77
    try:
×
78
        with open(TRACKER_FOLDER_PATH / file_name) as f:
×
79
            trackers = json.load(f)
×
80
    except FileNotFoundError:
×
81
        trackers = download_trackers(shavar_prod_lists_url, category)
×
82
        store_trackers(trackers, TRACKER_FOLDER_PATH, file_name)
×
83
    return trackers
×
84

85

86
def download_trackers(repo_url, category="Email"):
1✔
87
    # email tracker lists from shavar-prod-list as per agreed use under license:
NEW
88
    resp = requests.get(repo_url, timeout=10)
×
89
    json_resp = resp.json()
×
90
    formatted_trackers = json_resp["categories"][category]
×
91
    trackers = []
×
92
    for entity in formatted_trackers:
×
93
        for _, resources in entity.items():
×
94
            for _, domains in resources.items():
×
95
                trackers.extend(domains)
×
96
    return trackers
×
97

98

99
def store_trackers(trackers, path, file_name):
1✔
100
    with open(path / file_name, "w+") as f:
×
101
        json.dump(trackers, f, indent=4)
×
102

103

104
@cache
1✔
105
def general_trackers():
1✔
106
    return get_trackers(level=1)
×
107

108

109
@cache
1✔
110
def strict_trackers():
1✔
111
    return get_trackers(level=2)
×
112

113

114
_TimedFunction = TypeVar("_TimedFunction", bound=Callable[..., Any])
1✔
115

116

117
def time_if_enabled(name: str) -> Callable[[_TimedFunction], _TimedFunction]:
1✔
118
    def timing_decorator(func: _TimedFunction) -> _TimedFunction:
1✔
119
        def func_wrapper(*args, **kwargs):
1✔
120
            ctx_manager = (
1✔
121
                metrics.timer(name)
122
                if settings.STATSD_ENABLED
123
                else contextlib.nullcontext()
124
            )
125
            with ctx_manager:
1✔
126
                return func(*args, **kwargs)
1✔
127

128
        return cast(_TimedFunction, func_wrapper)
1✔
129

130
    return timing_decorator
1✔
131

132

133
def incr_if_enabled(name, value=1, tags=None):
1✔
134
    if settings.STATSD_ENABLED:
1✔
135
        metrics.incr(name, value, tags)
1✔
136

137

138
def histogram_if_enabled(name, value, tags=None):
1✔
139
    if settings.STATSD_ENABLED:
1✔
140
        metrics.histogram(name, value=value, tags=tags)
1✔
141

142

143
def gauge_if_enabled(name, value, tags=None):
1✔
144
    if settings.STATSD_ENABLED:
1✔
145
        metrics.gauge(name, value, tags)
1✔
146

147

148
def get_email_domain_from_settings() -> str:
1✔
149
    email_network_locality = str(urlparse(settings.SITE_ORIGIN).netloc)
1✔
150
    # on dev server we need to add "mail" prefix
151
    # because we can’t publish MX records on Heroku
152
    if settings.RELAY_CHANNEL == "dev":
1✔
153
        email_network_locality = f"mail.{email_network_locality}"
1✔
154
    return email_network_locality
1✔
155

156

157
def parse_email_header(header_value: str) -> list[tuple[str, str]]:
1✔
158
    """
159
    Extract the (display name, email address) pairs from a header value.
160

161
    This is useful when working with header values provided by a
162
    AWS SES delivery notification.
163

164
    email.utils.parseaddr() works with well-formed emails, but fails in
165
    cases with badly formed emails where an email address could still
166
    be extracted.
167
    """
168
    address_list = AddressHeader.value_parser(header_value)
1✔
169
    pairs: list[tuple[str, str]] = []
1✔
170
    for address in address_list.addresses:
1✔
171
        for mailbox in address.all_mailboxes:
1✔
172
            addr_spec = mailbox.addr_spec
1✔
173
            if addr_spec and addr_spec.count("@") == 1:
1✔
174
                pairs.append((mailbox.display_name or "", addr_spec))
1✔
175
    return pairs
1✔
176

177

178
def _get_hero_img_src(lang_code):
1✔
179
    img_locale = "en"
1✔
180
    avail_l10n_image_codes = [
1✔
181
        "cs",
182
        "de",
183
        "en",
184
        "es",
185
        "fi",
186
        "fr",
187
        "hu",
188
        "id",
189
        "it",
190
        "ja",
191
        "nl",
192
        "pt",
193
        "ru",
194
        "sv",
195
        "zh",
196
    ]
197
    major_lang = lang_code.split("-")[0]
1✔
198
    if major_lang in avail_l10n_image_codes:
1!
199
        img_locale = major_lang
1✔
200

201
    if not settings.SITE_ORIGIN:
1!
NEW
202
        raise ValueError("settings.SITE_ORIGIN must have a value")
×
203
    return (
1✔
204
        settings.SITE_ORIGIN
205
        + f"/static/images/email-images/first-time-user/hero-image-{img_locale}.png"
206
    )
207

208

209
def get_welcome_email(user: User, format: str) -> str:
1✔
210
    sa = SocialAccount.objects.get(user=user)
1✔
211
    bundle_plans = get_countries_info_from_lang_and_mapping(
1✔
212
        sa.extra_data.get("locale", "en"), get_bundle_country_language_mapping()
213
    )
214
    lang_code = user.profile.language
1✔
215
    hero_img_src = _get_hero_img_src(lang_code)
1✔
216
    return render_to_string(
1✔
217
        f"emails/first_time_user.{format}",
218
        {
219
            "in_bundle_country": bundle_plans["available_in_country"],
220
            "SITE_ORIGIN": settings.SITE_ORIGIN,
221
            "hero_img_src": hero_img_src,
222
            "language": lang_code,
223
        },
224
    )
225

226

227
@time_if_enabled("ses_send_raw_email")
1✔
228
def ses_send_raw_email(
1✔
229
    source_address: str,
230
    destination_address: str,
231
    message: EmailMessage,
232
) -> SendRawEmailResponseTypeDef:
233
    client = ses_client()
1✔
234
    if client is None:
1!
NEW
235
        raise ValueError("client must have a value")
×
236
    if not settings.AWS_SES_CONFIGSET:
1!
NEW
237
        raise ValueError("settings.AWS_SES_CONFIGSET must have a value")
×
238

239
    data = message.as_string()
1✔
240
    try:
1✔
241
        ses_response = client.send_raw_email(
1✔
242
            Source=source_address,
243
            Destinations=[destination_address],
244
            RawMessage={"Data": data},
245
            ConfigurationSetName=settings.AWS_SES_CONFIGSET,
246
        )
247
        incr_if_enabled("ses_send_raw_email", 1)
1✔
248
        return ses_response
1✔
249
    except ClientError as e:
1✔
250
        logger.error("ses_client_error_raw_email", extra=e.response["Error"])
1✔
251
        raise
1✔
252

253

254
def urlize_and_linebreaks(text, autoescape=True):
1✔
255
    return linebreaksbr(urlize(text, autoescape=autoescape), autoescape=autoescape)
1✔
256

257

258
def get_reply_to_address(premium: bool = True) -> str:
1✔
259
    """Return the address that relays replies."""
260
    if premium:
1!
261
        _, reply_to_address = parseaddr(
1✔
262
            "replies@{}".format(get_domains_from_settings().get("RELAY_FIREFOX_DOMAIN"))
263
        )
264
    else:
265
        _, reply_to_address = parseaddr(settings.RELAY_FROM_ADDRESS)
×
266
    return reply_to_address
1✔
267

268

269
def truncate(max_length: int, value: str) -> str:
1✔
270
    """
271
    Truncate a string to a maximum length.
272

273
    If the value is all ASCII, the truncation suffix will be ...
274
    If the value is non-ASCII, the truncation suffix will be … (Unicode ellipsis)
275
    """
276
    if len(value) <= max_length:
1✔
277
        return value
1✔
278
    ellipsis = "..."  # ASCII Ellipsis
1✔
279
    try:
1✔
280
        value.encode("ascii")
1✔
281
    except UnicodeEncodeError:
1✔
282
        ellipsis = "…"
1✔
283
    return Truncator(value).chars(max_length, truncate=ellipsis)
1✔
284

285

286
class InvalidFromHeader(Exception):
1✔
287
    pass
1✔
288

289

290
def generate_from_header(original_from_address: str, relay_mask: str) -> str:
1✔
291
    """
292
    Return a From: header str using the original sender and a display name that
293
    refers to Relay.
294

295
    This format was introduced in June 2023 with MPP-2117.
296
    """
297
    oneline_from_address = (
1✔
298
        original_from_address.replace("\u2028", "").replace("\r", "").replace("\n", "")
299
    )
300
    display_name, original_address = parseaddr(oneline_from_address)
1✔
301
    try:
1✔
302
        parsed_address = Address(addr_spec=original_address)
1✔
303
    except (InvalidHeaderDefect, IndexError) as e:
1✔
304
        # TODO: MPP-3407, MPP-3417 - Determine how to handle these
305
        raise InvalidFromHeader from e
1✔
306

307
    # Truncate the display name to 71 characters, so the sender portion fits on the
308
    # first line of a multi-line "From:" header, if it is ASCII. A utf-8 encoded header
309
    # will be 226 chars, still below the 998 limit of RFC 5322 2.1.1.
310
    max_length = 71
1✔
311

312
    if display_name:
1✔
313
        short_name = truncate(max_length, display_name)
1✔
314
        short_address = truncate(max_length, parsed_address.addr_spec)
1✔
315
        sender = f"{short_name} <{short_address}>"
1✔
316
    else:
317
        # Use the email address if the display name was not originally set
318
        display_name = parsed_address.addr_spec
1✔
319
        sender = truncate(max_length, display_name)
1✔
320
    return formataddr((f"{sender} [via Relay]", relay_mask))
1✔
321

322

323
def get_message_id_bytes(message_id_str: str) -> bytes:
1✔
324
    message_id = message_id_str.split("@", 1)[0].rsplit("<", 1)[-1].strip()
1✔
325
    return message_id.encode()
1✔
326

327

328
def b64_lookup_key(lookup_key: bytes) -> str:
1✔
329
    return base64.urlsafe_b64encode(lookup_key).decode("ascii")
1✔
330

331

332
def derive_reply_keys(message_id: bytes) -> tuple[bytes, bytes]:
1✔
333
    """Derive the lookup key and encrytion key from an aliased message id."""
334
    algorithm = hashes.SHA256()
1✔
335
    hkdf = HKDFExpand(algorithm=algorithm, length=16, info=b"replay replies lookup key")
1✔
336
    lookup_key = hkdf.derive(message_id)
1✔
337
    hkdf = HKDFExpand(
1✔
338
        algorithm=algorithm, length=32, info=b"replay replies encryption key"
339
    )
340
    encryption_key = hkdf.derive(message_id)
1✔
341
    return (lookup_key, encryption_key)
1✔
342

343

344
def encrypt_reply_metadata(key: bytes, payload: dict[str, str]) -> str:
1✔
345
    """Encrypt the given payload into a JWE, using the given key."""
346
    # This is a bit dumb, we have to base64-encode the key in order to load it :-/
347
    k = jwcrypto.jwk.JWK(
1✔
348
        kty="oct", k=base64.urlsafe_b64encode(key).rstrip(b"=").decode("ascii")
349
    )
350
    e = jwcrypto.jwe.JWE(
1✔
351
        json.dumps(payload), json.dumps({"alg": "dir", "enc": "A256GCM"}), recipient=k
352
    )
353
    return cast(str, e.serialize(compact=True))
1✔
354

355

356
def decrypt_reply_metadata(key, jwe):
1✔
357
    """Decrypt the given JWE into a json payload, using the given key."""
358
    # This is a bit dumb, we have to base64-encode the key in order to load it :-/
359
    k = jwcrypto.jwk.JWK(
1✔
360
        kty="oct", k=base64.urlsafe_b64encode(key).rstrip(b"=").decode("ascii")
361
    )
362
    e = jwcrypto.jwe.JWE()
1✔
363
    e.deserialize(jwe)
1✔
364
    e.decrypt(k)
1✔
365
    return e.plaintext
1✔
366

367

368
def _get_bucket_and_key_from_s3_json(message_json):
1✔
369
    # Only Received notifications have S3-stored data
370
    notification_type = message_json.get("notificationType")
1✔
371
    if notification_type != "Received":
1✔
372
        return None, None
1✔
373

374
    if "receipt" in message_json and "action" in message_json["receipt"]:
1!
375
        message_json_receipt = message_json["receipt"]
1✔
376
    else:
377
        logger.error(
×
378
            "sns_inbound_message_without_receipt",
379
            extra={"message_json_keys": message_json.keys()},
380
        )
381
        return None, None
×
382

383
    bucket = None
1✔
384
    object_key = None
1✔
385
    try:
1✔
386
        if "S3" in message_json_receipt["action"]["type"]:
1✔
387
            bucket = message_json_receipt["action"]["bucketName"]
1✔
388
            object_key = message_json_receipt["action"]["objectKey"]
1✔
389
    except (KeyError, TypeError):
×
390
        logger.error(
×
391
            "sns_inbound_message_receipt_malformed",
392
            extra={"receipt_action": message_json_receipt["action"]},
393
        )
394
    return bucket, object_key
1✔
395

396

397
@time_if_enabled("s3_get_message_content")
1✔
398
def get_message_content_from_s3(bucket, object_key):
1✔
399
    if bucket and object_key:
×
NEW
400
        client = s3_client()
×
NEW
401
        if client is None:
×
NEW
402
            raise ValueError("client must not be None")
×
UNCOV
403
        streamed_s3_object = client.get_object(Bucket=bucket, Key=object_key).get(
×
404
            "Body"
405
        )
406
        return streamed_s3_object.read()
×
407

408

409
@time_if_enabled("s3_remove_message_from")
1✔
410
def remove_message_from_s3(bucket, object_key):
1✔
411
    if bucket is None or object_key is None:
1!
412
        return False
1✔
413
    try:
×
NEW
414
        client = s3_client()
×
NEW
415
        if client is None:
×
NEW
416
            raise ValueError("client must not be None")
×
417
        response = client.delete_object(Bucket=bucket, Key=object_key)
×
418
        return response.get("DeleteMarker")
×
419
    except ClientError as e:
×
420
        if e.response["Error"].get("Code", "") == "NoSuchKey":
×
421
            logger.error("s3_delete_object_does_not_exist", extra=e.response["Error"])
×
422
        else:
423
            logger.error("s3_client_error_delete_email", extra=e.response["Error"])
×
424
        incr_if_enabled("message_not_removed_from_s3", 1)
×
425
    return False
×
426

427

428
def set_user_group(user):
1✔
429
    if "@" not in user.email:
1✔
430
        return None
1✔
431
    email_domain = user.email.split("@")[1]
1✔
432
    group_attribute = {
1✔
433
        "mozilla.com": "mozilla_corporation",
434
        "mozillafoundation.org": "mozilla_foundation",
435
        "getpocket.com": "pocket",
436
    }
437
    group_name = group_attribute.get(email_domain)
1✔
438
    if not group_name:
1!
439
        return None
1✔
440
    internal_group_qs = Group.objects.filter(name=group_name)
×
441
    internal_group = internal_group_qs.first()
×
442
    if internal_group is None:
×
443
        return None
×
444
    internal_group.user_set.add(user)
×
445

446

447
def convert_domains_to_regex_patterns(domain_pattern):
1✔
448
    return r"""(["'])(\S*://(\S*\.)*""" + re.escape(domain_pattern) + r"\S*)\1"
1✔
449

450

451
def count_tracker(html_content, trackers):
1✔
452
    tracker_total = 0
1✔
453
    details = {}
1✔
454
    # html_content needs to be str for count()
455
    for tracker in trackers:
1✔
456
        pattern = convert_domains_to_regex_patterns(tracker)
1✔
457
        html_content, count = re.subn(pattern, "", html_content)
1✔
458
        if count:
1✔
459
            tracker_total += count
1✔
460
            details[tracker] = count
1✔
461
    return {"count": tracker_total, "trackers": details}
1✔
462

463

464
def count_all_trackers(html_content):
1✔
465
    general_detail = count_tracker(html_content, general_trackers())
×
466
    strict_detail = count_tracker(html_content, strict_trackers())
×
467

468
    incr_if_enabled("tracker.general_count", general_detail["count"])
×
469
    incr_if_enabled("tracker.strict_count", strict_detail["count"])
×
470
    study_logger.info(
×
471
        "email_tracker_summary",
472
        extra={"level_one": general_detail, "level_two": strict_detail},
473
    )
474

475

476
def remove_trackers(html_content, from_address, datetime_now, level="general"):
1✔
477
    trackers = general_trackers() if level == "general" else strict_trackers()
1✔
478
    tracker_removed = 0
1✔
479
    changed_content = html_content
1✔
480

481
    for tracker in trackers:
1✔
482
        pattern = convert_domains_to_regex_patterns(tracker)
1✔
483

484
        def convert_to_tracker_warning_link(matchobj):
1✔
485
            quote, original_link, _ = matchobj.groups()
1✔
486
            tracker_link_details = {
1✔
487
                "sender": from_address,
488
                "received_at": datetime_now,
489
                "original_link": original_link,
490
            }
491
            anchor = quote_plus(json.dumps(tracker_link_details, separators=(",", ":")))
1✔
492
            url = f"{settings.SITE_ORIGIN}/contains-tracker-warning/#{anchor}"
1✔
493
            return f"{quote}{url}{quote}"
1✔
494

495
        changed_content, matched = re.subn(
1✔
496
            pattern, convert_to_tracker_warning_link, changed_content
497
        )
498
        tracker_removed += matched
1✔
499

500
    level_one_detail = count_tracker(html_content, general_trackers())
1✔
501
    level_two_detail = count_tracker(html_content, strict_trackers())
1✔
502

503
    tracker_details = {
1✔
504
        "tracker_removed": tracker_removed,
505
        "level_one": level_one_detail,
506
    }
507
    logger_details = {"level": level, "level_two": level_two_detail}
1✔
508
    logger_details.update(tracker_details)
1✔
509
    info_logger.info(
1✔
510
        "email_tracker_summary",
511
        extra=logger_details,
512
    )
513
    return changed_content, tracker_details
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc