"""Heuristic risk scoring. This is intentionally simple — a weighted feature scorer. Each feature returns a 0-100 sub-score; the overall score is a weighted sum. We keep memory of seen fingerprints per guest so subsequent accesses can be compared against the baseline established by the first one. """ from __future__ import annotations import hashlib from dataclasses import dataclass, field from typing import Any from uuid import UUID from app.schemas import AccessAttempted LOW = "low" MEDIUM = "medium" HIGH = "high" BLOCK = "block" def risk_band(score: int) -> str: if score <= 30: return LOW if score <= 60: return MEDIUM if score <= 85: return HIGH return BLOCK @dataclass class GuestBaseline: fingerprint_digest: str | None = None ip_prefix: str | None = None accesses: int = 0 @dataclass class ScoringResult: score: int reasons: list[str] @dataclass class HeuristicScorer: weights: dict[str, float] = field( default_factory=lambda: { "fingerprint_mismatch": 0.40, "ip_change": 0.25, "missing_signals": 0.10, "repeated_access": 0.10, "no_user_agent": 0.15, } ) baselines: dict[UUID, GuestBaseline] = field(default_factory=dict) def score(self, evt: AccessAttempted) -> ScoringResult: reasons: list[str] = [] sub: dict[str, int] = {} baseline = self.baselines.get(evt.guest_id, GuestBaseline()) current_digest = _fingerprint_digest(evt.fingerprint) current_prefix = _ip_prefix(evt.ip_address) if baseline.fingerprint_digest is None: sub["fingerprint_mismatch"] = 0 elif baseline.fingerprint_digest == current_digest: sub["fingerprint_mismatch"] = 0 else: sub["fingerprint_mismatch"] = 100 reasons.append("fingerprint differs from baseline") if baseline.ip_prefix is None: sub["ip_change"] = 0 elif baseline.ip_prefix == current_prefix: sub["ip_change"] = 0 else: sub["ip_change"] = 80 reasons.append("ip address changed since first access") if not evt.fingerprint: sub["missing_signals"] = 70 reasons.append("no device fingerprint provided") else: sub["missing_signals"] = 0 sub["repeated_access"] = min(baseline.accesses * 10, 60) if baseline.accesses >= 5: reasons.append(f"token accessed {baseline.accesses + 1} times") if not evt.user_agent: sub["no_user_agent"] = 80 reasons.append("missing user agent") else: sub["no_user_agent"] = 0 weighted = sum(sub[k] * self.weights[k] for k in self.weights) final = int(round(min(max(weighted, 0), 100))) # Tier 2 Block G — tighten the consecutive-fingerprint false # positive. Pre-Block-G, a guest opening their invitation a second # time with even a slightly-shifted device fingerprint (browser # update, different network) would score ~60 (HIGH band): the # fingerprint_mismatch sub-score of 100 × 0.40 weight = 40, plus a # tiny baseline of repeated_access, easily tipped them over. # # The rule: a single signal can't push the score into HIGH (>= # configured high threshold). It takes at least *two* sub-scores # of >= 70 to escalate. The API re-bands using per-event # thresholds, but we still cap at 55 here so a single signal # caps at MEDIUM regardless of how strict the host has set their # band boundaries. strong_signals = sum(1 for v in sub.values() if v >= 70) if strong_signals < 2 and final > 55: final = 55 reasons.append("single-signal cap applied (need ≥2 signals for HIGH)") # Update baseline AFTER scoring so the first access sets it without # being penalised against itself. if baseline.fingerprint_digest is None: baseline.fingerprint_digest = current_digest if baseline.ip_prefix is None: baseline.ip_prefix = current_prefix baseline.accesses += 1 self.baselines[evt.guest_id] = baseline return ScoringResult(score=final, reasons=reasons) def _fingerprint_digest(fp: dict[str, Any] | None) -> str | None: if not fp: return None items = sorted((str(k), str(v)) for k, v in fp.items()) h = hashlib.sha256() for k, v in items: h.update(k.encode()) h.update(b"=") h.update(v.encode()) h.update(b";") return h.hexdigest() def _ip_prefix(ip: str | None) -> str | None: if not ip: return None if ":" in ip: # IPv6 — keep first 4 hextets parts = ip.split(":")[:4] return ":".join(parts) parts = ip.split(".") if len(parts) == 4: return ".".join(parts[:3]) return ip