"""Heuristic risk scoring. This is intentionally simple — a weighted feature scorer. Each feature returns a 0-100 sub-score; the overall score is a weighted sum. We keep memory of seen fingerprints per guest so subsequent accesses can be compared against the baseline established by the first one. """ from __future__ import annotations import hashlib import math from dataclasses import dataclass, field from datetime import datetime from typing import Any from uuid import UUID from app.geo import GeoLocation from app.schemas import AccessAttempted LOW = "low" MEDIUM = "medium" HIGH = "high" BLOCK = "block" def risk_band(score: int) -> str: if score <= 30: return LOW if score <= 60: return MEDIUM if score <= 85: return HIGH return BLOCK @dataclass class GuestBaseline: fingerprint_digest: str | None = None ip_prefix: str | None = None accesses: int = 0 # Tier 2 Block G — geo_jump. Stash the most recent coordinates + the # timestamp of the access they came from so the next access can be # compared against them. last_lat: float | None = None last_lon: float | None = None last_geo_at: datetime | None = None last_country: str | None = None @dataclass class ScoringResult: score: int reasons: list[str] geo: GeoLocation | None = None @dataclass class HeuristicScorer: weights: dict[str, float] = field( default_factory=lambda: { "fingerprint_mismatch": 0.40, "ip_change": 0.25, "missing_signals": 0.10, "repeated_access": 0.10, "no_user_agent": 0.15, # Tier 2 Block G — geo_jump. Implausibly fast travel between # two accesses (>500 km in <1h) carries the heaviest weight # alongside fingerprint mismatch. Note the weights are not # required to sum to 1; the final score is clamped to # [0, 100] so the relative magnitudes are what matters. "geo_jump": 0.40, } ) baselines: dict[UUID, GuestBaseline] = field(default_factory=dict) def score(self, evt: AccessAttempted, geo: GeoLocation | None = None) -> ScoringResult: reasons: list[str] = [] sub: dict[str, int] = {} baseline = self.baselines.get(evt.guest_id, GuestBaseline()) current_digest = _fingerprint_digest(evt.fingerprint) current_prefix = _ip_prefix(evt.ip_address) if baseline.fingerprint_digest is None: sub["fingerprint_mismatch"] = 0 elif baseline.fingerprint_digest == current_digest: sub["fingerprint_mismatch"] = 0 else: sub["fingerprint_mismatch"] = 100 reasons.append("fingerprint differs from baseline") if baseline.ip_prefix is None: sub["ip_change"] = 0 elif baseline.ip_prefix == current_prefix: sub["ip_change"] = 0 else: sub["ip_change"] = 80 reasons.append("ip address changed since first access") if not evt.fingerprint: sub["missing_signals"] = 70 reasons.append("no device fingerprint provided") else: sub["missing_signals"] = 0 sub["repeated_access"] = min(baseline.accesses * 10, 60) if baseline.accesses >= 5: reasons.append(f"token accessed {baseline.accesses + 1} times") if not evt.user_agent: sub["no_user_agent"] = 80 reasons.append("missing user agent") else: sub["no_user_agent"] = 0 # geo_jump — implausibly fast travel between this access and the # previous one. 500 km in under an hour means either a private # jet or, far more likely, a stolen invitation being opened by # someone in a different country. Spec threshold from # docs/TIER2_PLAN.md Block G. sub["geo_jump"] = 0 if ( geo is not None and geo.lat is not None and geo.lon is not None and baseline.last_lat is not None and baseline.last_lon is not None and baseline.last_geo_at is not None ): km = _haversine_km( baseline.last_lat, baseline.last_lon, geo.lat, geo.lon ) dt = (evt.occurred_at - baseline.last_geo_at).total_seconds() if km > 500 and 0 < dt < 3600: sub["geo_jump"] = 100 where_now = geo.city or geo.country or "elsewhere" where_before = baseline.last_country or "another location" mins = max(int(dt / 60), 1) reasons.append( f"accessed from {where_before} and {where_now} within {mins} minutes" ) elif km > 500 and dt < 21600: # within 6h is still suspicious-but-possible sub["geo_jump"] = 50 reasons.append(f"large geographic jump ({int(km)} km)") weighted = sum(sub[k] * self.weights.get(k, 0) for k in sub) final = int(round(min(max(weighted, 0), 100))) # Tier 2 Block G — tighten the consecutive-fingerprint false # positive. Pre-Block-G, a guest opening their invitation a second # time with even a slightly-shifted device fingerprint (browser # update, different network) would score ~60 (HIGH band): the # fingerprint_mismatch sub-score of 100 × 0.40 weight = 40, plus a # tiny baseline of repeated_access, easily tipped them over. # # The rule: a single signal can't push the score into HIGH (>= # configured high threshold). It takes at least *two* sub-scores # of >= 70 to escalate. The API re-bands using per-event # thresholds, but we still cap at 55 here so a single signal # caps at MEDIUM regardless of how strict the host has set their # band boundaries. strong_signals = sum(1 for v in sub.values() if v >= 70) if strong_signals < 2 and final > 55: final = 55 reasons.append("single-signal cap applied (need ≥2 signals for HIGH)") # Update baseline AFTER scoring so the first access sets it without # being penalised against itself. if baseline.fingerprint_digest is None: baseline.fingerprint_digest = current_digest if baseline.ip_prefix is None: baseline.ip_prefix = current_prefix if geo is not None and geo.lat is not None and geo.lon is not None: baseline.last_lat = geo.lat baseline.last_lon = geo.lon baseline.last_geo_at = evt.occurred_at baseline.last_country = geo.city or geo.country baseline.accesses += 1 self.baselines[evt.guest_id] = baseline return ScoringResult(score=final, reasons=reasons, geo=geo) def _haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float: """Great-circle distance in kilometres. Earth radius 6371 km.""" rlat1, rlat2 = math.radians(lat1), math.radians(lat2) dlat = math.radians(lat2 - lat1) dlon = math.radians(lon2 - lon1) a = math.sin(dlat / 2) ** 2 + math.cos(rlat1) * math.cos(rlat2) * math.sin(dlon / 2) ** 2 return 2 * 6371.0 * math.asin(math.sqrt(a)) def _fingerprint_digest(fp: dict[str, Any] | None) -> str | None: if not fp: return None items = sorted((str(k), str(v)) for k, v in fp.items()) h = hashlib.sha256() for k, v in items: h.update(k.encode()) h.update(b"=") h.update(v.encode()) h.update(b";") return h.hexdigest() def _ip_prefix(ip: str | None) -> str | None: if not ip: return None if ":" in ip: # IPv6 — keep first 4 hextets parts = ip.split(":")[:4] return ":".join(parts) parts = ip.split(".") if len(parts) == 4: return ".".join(parts[:3]) return ip