#!/usr/bin/env python3
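"""Normalize raw JSON (sessions, topics, messages, events) into a canonical
document: assign sequential canonical ids, resolve cross-references, and drop
duplicates matched by external id or a stable content hash.
"""
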
import argparse
import copy
import hashlib
import json
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple


def _now_iso() -> str:
    return datetime.now(timezone.utc).replace(microsecond=0).isoformat()


def _norm_str(v: Any) -> Optional[str]:
    if v is None:
        return None
    s = str(v).strip()
    return s if s else None


def _norm_ts(v: Any) -> Optional[str]:
    s = _norm_str(v)
    if not s:
        return None
    try:
        # fromisoformat() on Python < 3.11 rejects a trailing "Z"; map it to
        # an explicit UTC offset before parsing. Only the trailing "Z" is
        # replaced, so a "Z" elsewhere in a malformed string is left alone.
        s2 = s[:-1] + "+00:00" if s.endswith("Z") else s
        dt = datetime.fromisoformat(s2)
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc).replace(microsecond=0).isoformat()
    except ValueError:
        # Keep the original (already stripped) value when parsing fails.
        return s


def _stable_hash(parts: List[Any]) -> str:
    blob = "|".join(json.dumps(p, sort_keys=True, ensure_ascii=False) for p in parts)
    return hashlib.sha1(blob.encode("utf-8")).hexdigest()[:12]


def _obj(*candidates, default=None):
    for c in candidates:
        if c is not None:
            return c
    return default


def _flatten(raw: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
    sessions = list(raw.get("sessions", []))
    topics = list(raw.get("topics", []))
    messages = list(raw.get("messages", []))
    events = list(raw.get("events", []))

    # Hoist topics nested under sessions onto the top-level list
    for s in raw.get("sessions", []):
        sid = _obj(s.get("id"), s.get("session_id"), s.get("external_id"))
        for t in s.get("topics") or []:
            tc = copy.deepcopy(t)
            tc.setdefault("session_id", sid)
            topics.append(tc)

    # Hoist messages nested under topics, carrying the topic's session id when known
    for t in topics:
        tid = _obj(t.get("id"), t.get("topic_id"), t.get("external_id"))
        tsid = _obj(t.get("session_id"), t.get("parent_session_id"))
        for m in t.get("messages") or []:
            mc = copy.deepcopy(m)
            mc.setdefault("topic_id", tid)
            if tsid is not None:
                mc.setdefault("session_id", tsid)
            messages.append(mc)

    # Hoist events nested under messages, carrying the message's session and topic ids
    for m in messages:
        mid = _obj(m.get("id"), m.get("message_id"), m.get("external_id"))
        msid = m.get("session_id")
        mtid = m.get("topic_id")
        for e in m.get("events") or []:
            ec = copy.deepcopy(e)
            ec.setdefault("message_id", mid)
            if msid is not None:
                ec.setdefault("session_id", msid)
            if mtid is not None:
                ec.setdefault("topic_id", mtid)
            events.append(ec)

    return {
        "sessions": sessions,
        "topics": topics,
        "messages": messages,
        "events": events,
    }


def normalize(raw: Dict[str, Any]) -> Dict[str, Any]:
    flat = _flatten(raw)

    out_sessions: List[Dict[str, Any]] = []
    out_topics: List[Dict[str, Any]] = []
    out_messages: List[Dict[str, Any]] = []
    out_events: List[Dict[str, Any]] = []

    # Ref maps: original id/external_id -> canonical id, for resolving cross-references
    session_ref_map: Dict[str, str] = {}
    topic_ref_map: Dict[str, str] = {}
    message_ref_map: Dict[str, str] = {}

    # Dedup indexes
    seen_session: Dict[str, str] = {}
    seen_topic: Dict[str, str] = {}
    seen_message: Dict[str, str] = {}
    seen_event: Dict[str, str] = {}

    raw_counts = {k: len(v) for k, v in flat.items()}
    dedup_dropped = {"sessions": 0, "topics": 0, "messages": 0, "events": 0}

    # Sessions
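    # Dedup keys: prefer external_id when present; always fall back to a
    # content hash of (title, start_ts). The same pattern applies to the
    # topic, message, and event loops below.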
    for s in flat["sessions"]:
        ext = _norm_str(s.get("external_id"))
        title = _norm_str(_obj(s.get("title"), s.get("name")))
        start_ts = _norm_ts(_obj(s.get("start_ts"), s.get("start_time"), s.get("started_at"), s.get("created_at")))
        end_ts = _norm_ts(_obj(s.get("end_ts"), s.get("end_time"), s.get("ended_at")))

        hash_key = _stable_hash([title, start_ts])
        keys = [f"ext:{ext}"] if ext else []
        keys.append(f"hash:{hash_key}")

        existing_id = next((seen_session[k] for k in keys if k in seen_session), None)
        if existing_id:
            dedup_dropped["sessions"] += 1
            canon_id = existing_id
        else:
            canon_id = f"ses_{len(out_sessions)+1:04d}"
            out_sessions.append({
                "id": canon_id,
                "external_id": ext,
                "title": title,
                "start_ts": start_ts,
                "end_ts": end_ts,
                "attrs": s.get("attrs", {}),
            })

        for k in keys:
            seen_session[k] = canon_id

        for ref in [s.get("id"), s.get("session_id"), s.get("external_id")]:
            r = _norm_str(ref)
            if r:
                session_ref_map[r] = canon_id

    # Topics
    for t in flat["topics"]:
        ext = _norm_str(t.get("external_id"))
        raw_sref = _norm_str(_obj(t.get("session_id"), t.get("parent_session_id")))
        session_id = session_ref_map.get(raw_sref) if raw_sref else None
        title = _norm_str(_obj(t.get("title"), t.get("name"), t.get("subject")))
        created_ts = _norm_ts(_obj(t.get("created_at"), t.get("created_ts"), t.get("ts")))

        hash_key = _stable_hash([session_id, title])
        keys = [f"ext:{ext}"] if ext else []
        keys.append(f"hash:{hash_key}")

        existing_id = next((seen_topic[k] for k in keys if k in seen_topic), None)
        if existing_id:
            dedup_dropped["topics"] += 1
            canon_id = existing_id
        else:
            canon_id = f"top_{len(out_topics)+1:04d}"
            out_topics.append({
                "id": canon_id,
                "external_id": ext,
                "session_id": session_id,
                "title": title,
                "created_ts": created_ts,
                "attrs": t.get("attrs", {}),
            })

        for k in keys:
            seen_topic[k] = canon_id

        for ref in [t.get("id"), t.get("topic_id"), t.get("external_id")]:
            r = _norm_str(ref)
            if r:
                topic_ref_map[r] = canon_id

    # Messages
    for m in flat["messages"]:
        ext = _norm_str(m.get("external_id"))
        raw_sref = _norm_str(m.get("session_id"))
        raw_tref = _norm_str(m.get("topic_id"))
        session_id = session_ref_map.get(raw_sref) if raw_sref else None
        topic_id = topic_ref_map.get(raw_tref) if raw_tref else None

        author_id = _norm_str(_obj(m.get("author_id"), m.get("user_id"), m.get("sender_id")))
        text = _norm_str(_obj(m.get("text"), m.get("body"), m.get("content")))
        ts = _norm_ts(_obj(m.get("ts"), m.get("created_at"), m.get("timestamp")))

        hash_key = _stable_hash([session_id, topic_id, author_id, ts, text])
        keys = [f"ext:{ext}"] if ext else []
        keys.append(f"hash:{hash_key}")

        existing_id = next((seen_message[k] for k in keys if k in seen_message), None)
        if existing_id:
            dedup_dropped["messages"] += 1
            canon_id = existing_id
        else:
            canon_id = f"msg_{len(out_messages)+1:06d}"
            out_messages.append({
                "id": canon_id,
                "external_id": ext,
                "session_id": session_id,
                "topic_id": topic_id,
                "author_id": author_id,
                "ts": ts,
                "text": text,
                "attrs": m.get("attrs", {}),
            })

        for k in keys:
            seen_message[k] = canon_id

        for ref in [m.get("id"), m.get("message_id"), m.get("external_id")]:
            r = _norm_str(ref)
            if r:
                message_ref_map[r] = canon_id

    # Events
    for e in flat["events"]:
        ext = _norm_str(e.get("external_id"))
        raw_sref = _norm_str(e.get("session_id"))
        raw_tref = _norm_str(e.get("topic_id"))
        raw_mref = _norm_str(_obj(e.get("message_id"), e.get("message_ref")))

        session_id = session_ref_map.get(raw_sref) if raw_sref else None
        topic_id = topic_ref_map.get(raw_tref) if raw_tref else None
        message_id = message_ref_map.get(raw_mref) if raw_mref else None

        event_type = _norm_str(_obj(e.get("event_type"), e.get("type"), e.get("name")))
        ts = _norm_ts(_obj(e.get("ts"), e.get("created_at"), e.get("timestamp")))
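        # Use an explicit payload dict when provided; otherwise treat every
        # key that is not a known envelope field as part of the payload.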
        payload = e.get("payload") if isinstance(e.get("payload"), dict) else {
            k: v for k, v in e.items() if k not in {
                "id", "event_id", "external_id", "session_id", "topic_id", "message_id", "message_ref",
                "event_type", "type", "name", "ts", "created_at", "timestamp", "payload"
            }
        }

        hash_key = _stable_hash([event_type, message_id, ts, payload])
        keys = [f"ext:{ext}"] if ext else []
        keys.append(f"hash:{hash_key}")
        if any(k in seen_event for k in keys):
            dedup_dropped["events"] += 1
            continue

        canon_id = f"evt_{len(out_events)+1:06d}"
        for k in keys:
            seen_event[k] = canon_id
        out_events.append({
            "id": canon_id,
            "external_id": ext,
            "session_id": session_id,
            "topic_id": topic_id,
            "message_id": message_id,
            "event_type": event_type,
            "ts": ts,
            "payload": payload,
        })

    canonical = {
        "meta": {
            "source": _norm_str(raw.get("source")) or "unknown",
            "generated_at": _now_iso(),
            "counts_raw": raw_counts,
            "counts_canonical": {
                "sessions": len(out_sessions),
                "topics": len(out_topics),
                "messages": len(out_messages),
                "events": len(out_events),
            },
            "dedup_dropped": dedup_dropped,
        },
        "sessions": sorted(out_sessions, key=lambda x: x["id"]),
        "topics": sorted(out_topics, key=lambda x: x["id"]),
        "messages": sorted(out_messages, key=lambda x: x["id"]),
        "events": sorted(out_events, key=lambda x: x["id"]),
    }
    return canonical


def main() -> None:
    parser = argparse.ArgumentParser(description="Normalize raw JSON into canonical entities with basic dedup")
    parser.add_argument("--input", "-i", required=True, help="Path to raw input JSON")
    parser.add_argument("--output", "-o", required=True, help="Path to canonical output JSON")
    parser.add_argument("--pretty", action="store_true", help="Pretty print JSON output")
    args = parser.parse_args()

    with open(args.input, "r", encoding="utf-8") as f:
        raw = json.load(f)

    canonical = normalize(raw)

    with open(args.output, "w", encoding="utf-8") as f:
        if args.pretty:
            json.dump(canonical, f, ensure_ascii=False, indent=2)
            f.write("\n")
        else:
            json.dump(canonical, f, ensure_ascii=False, separators=(",", ":"))


if __name__ == "__main__":
    main()
