• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mozilla / fx-private-relay / c261d5a6-4482-49d5-8f64-a949e5295e9d

18 Apr 2024 02:57PM UTC coverage: 75.479% (-0.1%) from 75.611%
c261d5a6-4482-49d5-8f64-a949e5295e9d

Pull #4612

circleci

rafeerahman
Linter and more test fixes
Pull Request #4612: MPP3779: E2E test fixes and additions

2443 of 3406 branches covered (71.73%)

Branch coverage included in aggregate %.

6767 of 8796 relevant lines covered (76.93%)

20.09 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.99
/emails/utils.py
1
from __future__ import annotations
1✔
2
import base64
1✔
3
import contextlib
1✔
4
from email.errors import InvalidHeaderDefect
1✔
5
from email.headerregistry import Address, AddressHeader
1✔
6
from email.message import EmailMessage
1✔
7
from email.utils import formataddr, parseaddr
1✔
8
from functools import cache
1✔
9
from typing import cast, Any, Callable, TypeVar
1✔
10
import json
1✔
11
import pathlib
1✔
12
import re
1✔
13
from django.template.loader import render_to_string
1✔
14
from django.utils.text import Truncator
1✔
15
import requests
1✔
16

17
from botocore.exceptions import ClientError
1✔
18
from cryptography.hazmat.primitives import hashes
1✔
19
from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand
1✔
20
from mypy_boto3_ses.type_defs import ContentTypeDef, SendRawEmailResponseTypeDef
1✔
21
import jwcrypto.jwe
1✔
22
import jwcrypto.jwk
1✔
23
import markus
1✔
24
import logging
1✔
25
from urllib.parse import quote_plus, urlparse
1✔
26

27
from django.conf import settings
1✔
28
from django.contrib.auth.models import Group, User
1✔
29
from django.template.defaultfilters import linebreaksbr, urlize
1✔
30

31
from allauth.socialaccount.models import SocialAccount
1✔
32

33
from privaterelay.plans import get_bundle_country_language_mapping
1✔
34
from privaterelay.utils import get_countries_info_from_lang_and_mapping
1✔
35

36
from .apps import s3_client, ses_client
1✔
37

38
# Log streams used by this module: general events, informational events and
# study metrics are routed to separately named loggers.
logger = logging.getLogger("events")
info_logger = logging.getLogger("eventsinfo")
study_logger = logging.getLogger("studymetrics")
metrics = markus.get_metrics("fx-private-relay")

# Source of the tracker block lists, fetched only when no local copy exists.
shavar_prod_lists_url = (
    "https://raw.githubusercontent.com"
    "/mozilla-services/shavar-prod-lists"
    "/master/disconnect-blacklist.json"
)
# Directory containing this module, and the folder that caches tracker lists.
EMAILS_FOLDER_PATH = pathlib.Path(__file__).parent
TRACKER_FOLDER_PATH = EMAILS_FOLDER_PATH.joinpath("tracker_lists")
49

50

51
def ses_message_props(data: str) -> ContentTypeDef:
    """Wrap *data* as an SES content payload that declares a UTF-8 charset."""
    props: ContentTypeDef = {"Charset": "UTF-8", "Data": data}
    return props
53

54

55
def get_domains_from_settings():
    """Return the Relay mask domains, keyed by their setting names.

    When running under Django tests ("testserver" is in ALLOWED_HOSTS), fixed
    placeholder domains are returned instead of the configured settings.
    """
    # HACK: detect if code is running in django tests
    running_django_tests = "testserver" in settings.ALLOWED_HOSTS
    if not running_django_tests:
        return {
            "RELAY_FIREFOX_DOMAIN": settings.RELAY_FIREFOX_DOMAIN,
            "MOZMAIL_DOMAIN": settings.MOZMAIL_DOMAIN,
        }
    return {"RELAY_FIREFOX_DOMAIN": "default.com", "MOZMAIL_DOMAIN": "test.com"}
63

64

65
def get_trackers(level):
    """Return the tracker domain list for *level*.

    Level 2 selects the "EmailAggressive" category (level-two-trackers.json).
    Any other value selects "Email" (level-one-trackers.json). The list is read
    from TRACKER_FOLDER_PATH. If that file does not exist, the list is
    downloaded from shavar_prod_lists_url and saved there first.
    """
    if level == 2:
        category, tracker_list_name = "EmailAggressive", "level-two-trackers"
    else:
        category, tracker_list_name = "Email", "level-one-trackers"
    file_name = f"{tracker_list_name}.json"

    try:
        with open(TRACKER_FOLDER_PATH / file_name, "r") as tracker_file:
            return json.load(tracker_file)
    except FileNotFoundError:
        # No cached copy yet: fetch it and cache it for later calls.
        trackers = download_trackers(shavar_prod_lists_url, category)
        store_trackers(trackers, TRACKER_FOLDER_PATH, file_name)
        return trackers
81

82

83
def download_trackers(repo_url, category="Email", timeout=30):
    """Download a tracker list and flatten one category into a list of domains.

    The email tracker lists come from shavar-prod-lists and are used under the
    agreed license terms.

    :param repo_url: URL of the JSON document (e.g. shavar_prod_lists_url)
    :param category: key under the document's "categories" to extract, such as
        "Email" or "EmailAggressive"
    :param timeout: seconds to wait for the server. Without a timeout,
        requests can wait indefinitely on an unresponsive host.
    :returns: list of tracker domains from every entity in the category
    """
    # email tracker lists from shavar-prod-list as per agreed use under license:
    resp = requests.get(repo_url, timeout=timeout)
    entities = resp.json()["categories"][category]
    trackers = []
    for entity in entities:
        # Each entity is a dict of dicts whose innermost values hold domains
        # (presumably {org name: {org URL: [domains]}} in the disconnect format).
        for resources in entity.values():
            for domains in resources.values():
                # Only list values hold domains. Calling extend() with a string
                # value would add its individual characters as "domains".
                if isinstance(domains, list):
                    trackers.extend(domains)
    return trackers
94

95

96
def store_trackers(trackers, path, file_name):
    """Write *trackers* as 4-space-indented JSON to ``path / file_name``.

    Any existing file at that location is overwritten.
    """
    destination = path / file_name
    with open(destination, "w+") as handle:
        json.dump(trackers, handle, indent=4)
99

100

101
@cache
def general_trackers():
    """Return the level-one (general) tracker domains, memoized by @cache."""
    level_one = get_trackers(level=1)
    return level_one
104

105

106
@cache
def strict_trackers():
    """Return the level-two (strict) tracker domains, memoized by @cache."""
    level_two = get_trackers(level=2)
    return level_two
109

110

111
_TimedFunction = TypeVar("_TimedFunction", bound=Callable[..., Any])


def time_if_enabled(name: str) -> Callable[[_TimedFunction], _TimedFunction]:
    """Build a decorator that times each call of the wrapped function as *name*.

    ``settings.STATSD_ENABLED`` is checked on every call, not at decoration
    time. When it is falsy, a no-op context manager is used instead of the
    timer. ``functools.wraps`` copies the wrapped function's name, docstring
    and ``__wrapped__`` onto the wrapper, so tracebacks, logs and
    introspection still see the original function.
    """
    from functools import wraps

    def timing_decorator(func: _TimedFunction) -> _TimedFunction:
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            ctx_manager = (
                metrics.timer(name)
                if settings.STATSD_ENABLED
                else contextlib.nullcontext()
            )
            with ctx_manager:
                return func(*args, **kwargs)

        return cast(_TimedFunction, func_wrapper)

    return timing_decorator
128

129

130
def incr_if_enabled(name, value=1, tags=None):
    """Increment counter *name* by *value* when statsd metrics are enabled."""
    if not settings.STATSD_ENABLED:
        return
    metrics.incr(name, value, tags)
133

134

135
def histogram_if_enabled(name, value, tags=None):
    """Record *value* in histogram *name* when statsd metrics are enabled."""
    if not settings.STATSD_ENABLED:
        return
    metrics.histogram(name, value=value, tags=tags)
138

139

140
def gauge_if_enabled(name, value, tags=None):
    """Set gauge *name* to *value* when statsd metrics are enabled."""
    if not settings.STATSD_ENABLED:
        return
    metrics.gauge(name, value, tags)
143

144

145
def get_email_domain_from_settings() -> str:
    """Return the mail domain: SITE_ORIGIN's host, with "mail." prepended on dev."""
    netloc = str(urlparse(settings.SITE_ORIGIN).netloc)
    # on dev server we need to add "mail" prefix
    # because we can't publish MX records on Heroku
    return f"mail.{netloc}" if settings.RELAY_CHANNEL == "dev" else netloc
152

153

154
def parse_email_header(header_value: str) -> list[tuple[str, str]]:
    """
    Extract the (display name, email address) pairs from a header value.

    This is useful for header values supplied in an AWS SES delivery
    notification.

    email.utils.parseaddr() handles well-formed addresses but gives up on
    some malformed ones from which an address could still be extracted. The
    header-registry parser used here is more forgiving. Mailboxes whose
    addr-spec is empty or does not contain exactly one "@" are skipped. A
    missing display name becomes "".
    """
    parsed = AddressHeader.value_parser(header_value)
    return [
        (mailbox.display_name or "", spec)
        for group_or_mailbox in parsed.addresses
        for mailbox in group_or_mailbox.all_mailboxes
        if (spec := mailbox.addr_spec) and spec.count("@") == 1
    ]
173

174

175
def _get_hero_img_src(lang_code):
    """Return the absolute URL of the first-time-user hero image for *lang_code*.

    Only the part before the first "-" is compared, case-sensitively, against
    the locales that have a localized image. Anything else falls back to "en".
    """
    localized_codes = {
        "cs", "de", "en", "es", "fi", "fr", "hu", "id",
        "it", "ja", "nl", "pt", "ru", "sv", "zh",
    }
    primary_lang = lang_code.split("-")[0]
    img_locale = primary_lang if primary_lang in localized_codes else "en"

    assert settings.SITE_ORIGIN
    image_path = (
        "/static/images/email-images/first-time-user/"
        f"hero-image-{img_locale}.png"
    )
    return settings.SITE_ORIGIN + image_path
203

204

205
def get_welcome_email(user: User, format: str) -> str:
    """Render the first-time-user welcome email for *user*.

    :param user: the user; their SocialAccount and profile are read
    :param format: template extension, selecting emails/first_time_user.<format>
    :returns: the rendered template
    """
    social_account = SocialAccount.objects.get(user=user)
    account_locale = social_account.extra_data.get("locale", "en")
    bundle_plans = get_countries_info_from_lang_and_mapping(
        account_locale, get_bundle_country_language_mapping()
    )
    language = user.profile.language
    hero_img_src = _get_hero_img_src(language)
    template_name = f"emails/first_time_user.{format}"
    context = {
        "in_bundle_country": bundle_plans["available_in_country"],
        "SITE_ORIGIN": settings.SITE_ORIGIN,
        "hero_img_src": hero_img_src,
        "language": language,
    }
    return render_to_string(template_name, context)
221

222

223
@time_if_enabled("ses_send_raw_email")
def ses_send_raw_email(
    source_address: str,
    destination_address: str,
    message: EmailMessage,
) -> SendRawEmailResponseTypeDef:
    """Send *message* as a raw email through SES, using AWS_SES_CONFIGSET.

    On success, the "ses_send_raw_email" counter is incremented and the SES
    response is returned. A botocore ClientError is logged with its "Error"
    details and then re-raised.
    """
    client = ses_client()
    assert client is not None
    assert settings.AWS_SES_CONFIGSET

    raw_data = message.as_string()
    try:
        ses_response = client.send_raw_email(
            Source=source_address,
            Destinations=[destination_address],
            RawMessage={"Data": raw_data},
            ConfigurationSetName=settings.AWS_SES_CONFIGSET,
        )
    except ClientError as error:
        logger.error("ses_client_error_raw_email", extra=error.response["Error"])
        raise
    else:
        incr_if_enabled("ses_send_raw_email", 1)
        return ses_response
245

246

247
def urlize_and_linebreaks(text, autoescape=True):
    """Turn URLs in *text* into links, then line breaks into <br> tags."""
    linked = urlize(text, autoescape=autoescape)
    return linebreaksbr(linked, autoescape=autoescape)
249

250

251
def get_reply_to_address(premium: bool = True) -> str:
    """Return the address that relays replies.

    Premium callers get "replies@<RELAY_FIREFOX_DOMAIN>". Everyone else gets
    the address part of settings.RELAY_FROM_ADDRESS.
    """
    if not premium:
        _, reply_to_address = parseaddr(settings.RELAY_FROM_ADDRESS)
        return reply_to_address
    relay_domain = get_domains_from_settings().get("RELAY_FIREFOX_DOMAIN")
    _, reply_to_address = parseaddr(f"replies@{relay_domain}")
    return reply_to_address
260

261

262
def truncate(max_length: int, value: str) -> str:
    """
    Truncate a string to at most *max_length* characters.

    Values that already fit are returned unchanged. Otherwise the truncation
    suffix is "..." for all-ASCII values and "…" (Unicode ellipsis) for
    values containing any non-ASCII character.
    """
    if len(value) <= max_length:
        return value
    suffix = "..." if value.isascii() else "…"
    return Truncator(value).chars(max_length, truncate=suffix)
277

278

279
class InvalidFromHeader(Exception):
    """Raised when the sender address in a From: header cannot be parsed."""


def generate_from_header(original_from_address: str, relay_mask: str) -> str:
    """
    Return a From: header str using the original sender and a display name that
    refers to Relay.

    This format was introduced in June 2023 with MPP-2117.

    :param original_from_address: the original sender's From: header value
    :param relay_mask: the Relay mask address used as the new sender address
    :returns: a header value whose display name is "<sender> [via Relay]"
    :raises InvalidFromHeader: when the sender's address cannot be parsed
    """
    from email.errors import HeaderParseError

    # Remove line separators so the value cannot span multiple header lines.
    oneline_from_address = (
        original_from_address.replace("\u2028", "").replace("\r", "").replace("\n", "")
    )
    display_name, original_address = parseaddr(oneline_from_address)
    try:
        parsed_address = Address(addr_spec=original_address)
    except (ValueError, IndexError, HeaderParseError) as e:
        # Address() re-raises the first defect found while parsing. Every email
        # defect (InvalidHeaderDefect, NonPrintableDefect, ...) subclasses
        # ValueError via MessageDefect. Address() also raises a plain ValueError
        # when text remains after the addr-spec. The header parser can raise
        # HeaderParseError or IndexError on some malformed input.
        # TODO: MPP-3407, MPP-3417 - Determine how to handle these
        raise InvalidFromHeader from e

    # Truncate the display name to 71 characters, so the sender portion fits on the
    # first line of a multi-line "From:" header, if it is ASCII. A utf-8 encoded header
    # will be 226 chars, still below the 998 limit of RFC 5322 2.1.1.
    max_length = 71

    if display_name:
        short_name = truncate(max_length, display_name)
        short_address = truncate(max_length, parsed_address.addr_spec)
        sender = f"{short_name} <{short_address}>"
    else:
        # Use the email address if the display name was not originally set
        display_name = parsed_address.addr_spec
        sender = truncate(max_length, display_name)
    return formataddr((f"{sender} [via Relay]", relay_mask))
314

315

316
def get_message_id_bytes(message_id_str: str) -> bytes:
    """Return the local part of a Message-ID, UTF-8 encoded.

    Keeps the text before the first "@" and after the last "<" in that text,
    with surrounding whitespace stripped.
    """
    before_at = message_id_str.partition("@")[0]
    local_part = before_at.rpartition("<")[2]
    return local_part.strip().encode()
319

320

321
def b64_lookup_key(lookup_key: bytes) -> str:
    """Encode *lookup_key* as URL-safe base64 text, keeping any "=" padding."""
    encoded = base64.urlsafe_b64encode(lookup_key)
    return str(encoded, "ascii")
323

324

325
def derive_reply_keys(message_id: bytes) -> tuple[bytes, bytes]:
    """Derive the lookup key and encryption key from an aliased message id.

    Both keys come from HKDF-Expand (SHA-256) over *message_id* with different
    info strings. The lookup key is 16 bytes; the encryption key is 32 bytes.
    """

    def _expand(length: int, info: bytes) -> bytes:
        hkdf = HKDFExpand(algorithm=hashes.SHA256(), length=length, info=info)
        return hkdf.derive(message_id)

    # NOTE: "replay" in the info strings looks like a typo for "relay", but the
    # bytes are part of the derivation. Changing them would change every key.
    lookup_key = _expand(16, b"replay replies lookup key")
    encryption_key = _expand(32, b"replay replies encryption key")
    return lookup_key, encryption_key
335

336

337
def encrypt_reply_metadata(key: bytes, payload: dict[str, str]) -> str:
    """Encrypt *payload* (as JSON) into a compact JWE using *key* directly.

    Uses direct key agreement ("dir") with A256GCM content encryption.
    """
    # A JWK "oct" key carries its bytes as unpadded base64url text in "k".
    encoded_key = base64.urlsafe_b64encode(key).rstrip(b"=").decode("ascii")
    jwk = jwcrypto.jwk.JWK(kty="oct", k=encoded_key)
    plaintext = json.dumps(payload)
    protected_header = json.dumps({"alg": "dir", "enc": "A256GCM"})
    token = jwcrypto.jwe.JWE(plaintext, protected_header, recipient=jwk)
    return cast(str, token.serialize(compact=True))
347

348

349
def decrypt_reply_metadata(key, jwe):
    """Decrypt the serialized JWE *jwe* with the symmetric *key*.

    Returns the JWE's plaintext. Presumably this is the JSON text written by
    encrypt_reply_metadata; callers are expected to parse it.
    """
    # A JWK "oct" key carries its bytes as unpadded base64url text in "k".
    encoded_key = base64.urlsafe_b64encode(key).rstrip(b"=").decode("ascii")
    jwk = jwcrypto.jwk.JWK(kty="oct", k=encoded_key)
    token = jwcrypto.jwe.JWE()
    token.deserialize(jwe)
    token.decrypt(jwk)
    return token.plaintext
359

360

361
def _get_bucket_and_key_from_s3_json(message_json):
1✔
362
    # Only Received notifications have S3-stored data
363
    notification_type = message_json.get("notificationType")
1✔
364
    if notification_type != "Received":
1✔
365
        return None, None
1✔
366

367
    if "receipt" in message_json and "action" in message_json["receipt"]:
1!
368
        message_json_receipt = message_json["receipt"]
1✔
369
    else:
370
        logger.error(
×
371
            "sns_inbound_message_without_receipt",
372
            extra={"message_json_keys": message_json.keys()},
373
        )
374
        return None, None
×
375

376
    bucket = None
1✔
377
    object_key = None
1✔
378
    try:
1✔
379
        if "S3" in message_json_receipt["action"]["type"]:
1✔
380
            bucket = message_json_receipt["action"]["bucketName"]
1✔
381
            object_key = message_json_receipt["action"]["objectKey"]
1✔
382
    except (KeyError, TypeError):
×
383
        logger.error(
×
384
            "sns_inbound_message_receipt_malformed",
385
            extra={"receipt_action": message_json_receipt["action"]},
386
        )
387
    return bucket, object_key
1✔
388

389

390
@time_if_enabled("s3_get_message_content")
def get_message_content_from_s3(bucket, object_key):
    """Read and return the body of the S3 object at *bucket*/*object_key*.

    Returns None without contacting S3 when either value is empty.
    """
    if not (bucket and object_key):
        return None
    client = s3_client()
    assert client is not None
    s3_object = client.get_object(Bucket=bucket, Key=object_key)
    body_stream = s3_object.get("Body")
    return body_stream.read()
398

399

400
@time_if_enabled("s3_remove_message_from")
def remove_message_from_s3(bucket, object_key):
    """Delete the S3 object at *bucket*/*object_key*.

    Returns False when either value is None, or when S3 reports a ClientError;
    the error is logged and counted in "message_not_removed_from_s3".
    Otherwise returns the response's "DeleteMarker" value, which is None when
    the response has no such key.
    """
    if bucket is None or object_key is None:
        return False
    try:
        client = s3_client()
        assert client is not None
        response = client.delete_object(Bucket=bucket, Key=object_key)
        return response.get("DeleteMarker")
    except ClientError as error:
        error_details = error.response["Error"]
        log_event = (
            "s3_delete_object_does_not_exist"
            if error_details.get("Code", "") == "NoSuchKey"
            else "s3_client_error_delete_email"
        )
        logger.error(log_event, extra=error_details)
        incr_if_enabled("message_not_removed_from_s3", 1)
    return False
415

416

417
def set_user_group(user):
    """Add *user* to the internal group that matches their email domain.

    Always returns None. The user is left unchanged when the address has no
    "@", when the domain is not one of the internal domains below, or when the
    matching Group does not exist.
    """
    if "@" not in user.email:
        return None
    # The domain is whatever follows the LAST "@". A quoted local part may
    # itself contain "@" (e.g. '"x@mozilla.com@"@example.com'). Taking the
    # second "@"-separated field could then match an internal domain that is
    # not the address's real domain. Domains are compared case-insensitively.
    email_domain = user.email.rpartition("@")[2].lower()
    group_attribute = {
        "mozilla.com": "mozilla_corporation",
        "mozillafoundation.org": "mozilla_foundation",
        "getpocket.com": "pocket",
    }
    group_name = group_attribute.get(email_domain)
    if not group_name:
        return None
    internal_group = Group.objects.filter(name=group_name).first()
    if internal_group is None:
        return None
    internal_group.user_set.add(user)
434

435

436
def convert_domains_to_regex_patterns(domain_pattern):
    """Build a regex matching a quoted URL whose host is *domain_pattern* or a subdomain.

    Group 1 is the opening quote (the closing quote must match it), group 2 is
    the URL without quotes, and group 3 is the last subdomain label matched.
    """
    escaped_domain = re.escape(domain_pattern)
    return rf"""(["'])(\S*://(\S*\.)*{escaped_domain}\S*)\1"""
438

439

440
def count_tracker(html_content, trackers):
    """Count quoted URLs in *html_content* that point at each tracker domain.

    Matches are removed from a working copy once counted, so a URL matched by
    one tracker is not counted again for a later one. Returns
    ``{"count": total, "trackers": {domain: count}}``, listing only domains
    with at least one match.
    """
    remaining = html_content
    total = 0
    per_tracker = {}
    for domain in trackers:
        remaining, hits = re.subn(
            convert_domains_to_regex_patterns(domain), "", remaining
        )
        if not hits:
            continue
        total += hits
        per_tracker[domain] = hits
    return {"count": total, "trackers": per_tracker}
451

452

453
def count_all_trackers(html_content):
    """Count general and strict trackers in *html_content*; emit metrics and a study log."""
    general_detail = count_tracker(html_content, general_trackers())
    strict_detail = count_tracker(html_content, strict_trackers())

    for metric_name, detail in (
        ("tracker.general_count", general_detail),
        ("tracker.strict_count", strict_detail),
    ):
        incr_if_enabled(metric_name, detail["count"])

    summary = {"level_one": general_detail, "level_two": strict_detail}
    study_logger.info("email_tracker_summary", extra=summary)
463

464

465
def remove_trackers(html_content, from_address, datetime_now, level="general"):
    """Replace quoted tracker URLs in *html_content* with Relay warning links.

    *level* "general" uses the general tracker list; any other value uses the
    strict list. Each matched URL becomes
    ``{SITE_ORIGIN}/contains-tracker-warning/#<anchor>``, where the anchor is
    the URL-quoted compact JSON of the sender, receipt time and original link.
    Tracker counts for both levels, measured on the unmodified content, are
    logged to info_logger.

    :returns: ``(changed_content, {"tracker_removed": n, "level_one": detail})``
    """
    trackers = general_trackers() if level == "general" else strict_trackers()

    # Does not depend on the tracker, so it is defined once, outside the loop.
    def _warning_link(match):
        quote = match.group(1)
        link_details = {
            "sender": from_address,
            "received_at": datetime_now,
            "original_link": match.group(2),
        }
        anchor = quote_plus(json.dumps(link_details, separators=(",", ":")))
        warning_url = f"{settings.SITE_ORIGIN}/contains-tracker-warning/#{anchor}"
        return f"{quote}{warning_url}{quote}"

    changed_content = html_content
    tracker_removed = 0
    for tracker in trackers:
        changed_content, replaced = re.subn(
            convert_domains_to_regex_patterns(tracker), _warning_link, changed_content
        )
        tracker_removed += replaced

    level_one_detail = count_tracker(html_content, general_trackers())
    level_two_detail = count_tracker(html_content, strict_trackers())

    tracker_details = {
        "tracker_removed": tracker_removed,
        "level_one": level_one_detail,
    }
    logger_details = {"level": level, "level_two": level_two_detail, **tracker_details}
    info_logger.info("email_tracker_summary", extra=logger_details)
    return changed_content, tracker_details
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc