#!/usr/bin/env python3
"""
Generator pliku CDA XML (e-skierowanie MP, profil v0.3) z minimalnego JSON-a.

  python generate_skierowanie_mp.py --input examples/minimal_skierowanie.json --output out.xml

OID-y i domyślne root identyfikatorów: test_profile_defaults.json (scalane z wejściem).
"""

from __future__ import annotations

import argparse
import json
import re
import sys
from datetime import date, datetime
from pathlib import Path
from typing import Any
from xml.sax.saxutils import escape

try:
    import jsonschema
except ImportError:
    jsonschema = None  # type: ignore[misc, assignment]

from jinja2 import Environment, FileSystemLoader, select_autoescape

# Polish month names in the form used in written dates, indexed by month
# number (1-12). Index 0 is an empty placeholder so that a month number can be
# used as the index directly.
MONTHS_PL = (
    "",
    "stycznia",
    "lutego",
    "marca",
    "kwietnia",
    "maja",
    "czerwca",
    "lipca",
    "sierpnia",
    "września",
    "października",
    "listopada",
    "grudnia",
)

# Display labels used in the narrative table for selected unit codes.
# Unit codes that are not listed here are displayed unchanged.
TABLE_UNIT_LABEL = {
    "Cel": "°C",
    "dB": "dB",
    "mg/m3": "mg/m³",
    "h/d": "h/dobę",
}


def xesc(s: str) -> str:
    """Escape *s* for use in XML text nodes and attribute values.

    In addition to the ``&``, ``<`` and ``>`` replacements performed by
    ``xml.sax.saxutils.escape``, both quote characters are replaced with
    their entities.
    """
    quote_entities = {'"': "&quot;", "'": "&apos;"}
    return escape(s, entities=quote_entities)


def format_number_pl(n: float) -> str:
    """Format *n* for display with a decimal comma.

    The number is rendered with up to ten decimal places; trailing zeros
    (and a then-dangling decimal point) are removed.
    """
    trimmed = format(n, ".10f").rstrip("0").rstrip(".")
    whole, separator, fraction = trimmed.partition(".")
    if not separator:
        # No fractional part left after trimming.
        return trimmed
    return f"{whole},{fraction}"


def format_number_xml(n: float) -> str:
    """Format *n* with a decimal point for machine-readable XML values.

    Uses up to ten decimal places and drops trailing zeros as well as a
    then-dangling decimal point.
    """
    fixed_point = format(n, ".10f")
    without_zeros = fixed_point.rstrip("0")
    return without_zeros.rstrip(".")


def parse_date(iso: str) -> date:
    """Parse an ISO date string, ignoring surrounding whitespace.

    Raises:
        ValueError: if the text is not accepted by ``date.fromisoformat``.
    """
    cleaned = iso.strip()
    return date.fromisoformat(cleaned)


def format_date_pl(iso: str) -> str:
    """Render an ISO date as Polish text: day, month name, year and ``r.``."""
    parsed = parse_date(iso)
    month_name = MONTHS_PL[parsed.month]
    return "{} {} {} r.".format(parsed.day, month_name, parsed.year)


def date_to_hl7_d8(iso: str) -> str:
    """Convert an ISO date string to the eight-digit ``YYYYMMDD`` form."""
    # date.__format__ with a non-empty spec delegates to strftime.
    return f"{parse_date(iso):%Y%m%d}"


def iso_to_hl7_effective(iso: str) -> str:
    """Convert an ISO date-time with a UTC offset to ``YYYYMMDDHHMMSS±HHMM``.

    A trailing ``Z`` is accepted as a synonym for ``+00:00``. Fractional
    seconds, if present, are not included in the result.

    Raises:
        ValueError: if the text cannot be parsed or carries no time zone.
    """
    text = iso.strip()
    if text.endswith("Z"):
        # Rewritten so that datetime.fromisoformat receives a numeric offset.
        text = f"{text[:-1]}+00:00"
    moment = datetime.fromisoformat(text)
    if moment.tzinfo is None:
        raise ValueError("effective_time musi zawierać strefę czasową, np. +01:00 lub Z")

    offset = moment.utcoffset()
    assert offset is not None
    offset_seconds = int(offset.total_seconds())
    sign = "-" if offset_seconds < 0 else "+"
    # Whole minutes of the absolute offset, split into hours and minutes.
    hours, minutes = divmod(abs(offset_seconds) // 60, 60)
    return moment.strftime("%Y%m%d%H%M%S") + f"{sign}{hours:02d}{minutes:02d}"


def _validate_required_keys(d: dict[str, Any], keys: list[str], ctx: str) -> None:
    """Ensure every name in *keys* is present in *d*.

    Raises:
        ValueError: naming the first missing key, prefixed with *ctx*.
    """
    first_missing = next((name for name in keys if name not in d), None)
    if first_missing is not None:
        raise ValueError(f"Brak pola {ctx}.{first_missing}")


def validate_input_basic(data: dict[str, Any]) -> None:
    """Run minimal structural checks on the input before rendering.

    Checks that the required top-level sections and the required document
    fields are present, that ``exposure_groups`` and ``additional_exams``
    are non-empty, and that ``patient.gender`` is ``"M"`` or ``"F"``.

    Raises:
        ValueError: if any of the checks fails.
    """
    _validate_required_keys(
        data,
        [
            "document",
            "patient",
            "author",
            "employer",
            "purpose",
            "position",
            "exposure_groups",
            "additional_exams",
        ],
        "root",
    )
    _validate_required_keys(data["document"], ["id_extension", "effective_time"], "document")
    if not data["exposure_groups"]:
        raise ValueError("exposure_groups nie może być puste")
    if not data["additional_exams"]:
        raise ValueError("additional_exams nie może być puste")
    # Report a missing gender like every other missing field instead of
    # letting the lookup below fail with a bare KeyError.
    _validate_required_keys(data["patient"], ["gender"], "patient")
    if data["patient"]["gender"] not in ("M", "F"):
        raise ValueError("patient.gender musi być M lub F")


def validate_with_jsonschema(data: dict[str, Any], schema_path: Path) -> None:
    """Validate *data* against the JSON Schema stored at *schema_path*.

    Does nothing when the optional ``jsonschema`` package could not be
    imported. Errors raised by ``jsonschema.validate`` propagate to the caller.
    """
    if jsonschema is not None:
        schema = json.loads(schema_path.read_text(encoding="utf-8"))
        jsonschema.validate(instance=data, schema=schema)


def unit_label_table(unit: str) -> str:
    """Return the table display label for *unit*, or *unit* itself if unmapped."""
    if unit in TABLE_UNIT_LABEL:
        return TABLE_UNIT_LABEL[unit]
    return unit


def build_measured_narrative_lines(
    measured_display: str | None,
    measured_value: dict[str, Any] | None,
    exposure_time: dict[str, Any] | None,
) -> list[str]:
    """Build the plain-text lines of the "measured value" table cell.

    The first line is *measured_display* when it is non-empty; otherwise it
    is the formatted *measured_value* (``value`` + ``unit``) when that mapping
    carries a value. A parenthesised exposure-time line follows when
    *exposure_time* carries a value.

    Returns:
        The lines in display order, not XML-escaped; may be empty.
    """
    lines: list[str] = []
    if measured_display:
        lines.append(measured_display)
    elif measured_value and measured_value.get("value") is not None:
        amount = float(measured_value["value"])
        unit = str(measured_value["unit"])
        lines.append(f"{format_number_pl(amount)} {unit_label_table(unit)}")

    # Same "value is not None" guard as for measured_value: an exposure_time
    # mapping without a value is skipped instead of crashing in float().
    if exposure_time and exposure_time.get("value") is not None:
        duration = float(exposure_time["value"])
        duration_unit = str(exposure_time["unit"])
        label = "h/dobę" if duration_unit == "h/d" else unit_label_table(duration_unit)
        lines.append(f"({format_number_pl(duration)} {label})")
    return lines


def _quantity_xml_fields(quantity: dict[str, Any] | None) -> tuple[bool, str, str]:
    """Return ``(present, value, unit)`` strings for an optional value/unit mapping."""
    if not quantity or quantity.get("value") is None:
        return False, "", ""
    # The unit is escaped like the other input strings placed in the context.
    return True, format_number_xml(float(quantity["value"])), xesc(str(quantity["unit"]))


def build_exposure_bundle(data: dict[str, Any]) -> dict[str, Any]:
    """Build the structures for the exposure table narrative, footnotes and entries.

    Factors are numbered consecutively across all groups (``OBS_CZN_<n>``).
    A factor with a non-blank ``extra_description`` additionally gets a
    consecutively numbered footnote (``UWAGA_<m>``); its narrative norm cell
    then refers to that footnote, while the entry data still carries the norm
    value when one is given.

    Args:
        data: input mapping; ``data["exposure_groups"]`` is iterated, and each
            group must provide ``group_code``, ``group_display`` and
            ``factors``.

    Returns:
        A dict with the keys ``exposure_table_rows``, ``exposure_footnotes``,
        ``exposure_groups_entry`` and ``exposure_factor_count``. String values
        taken from the input are XML-escaped.
    """
    factor_no = 0
    table_rows: list[dict[str, Any]] = []
    footnotes: list[dict[str, Any]] = []
    groups_entry: list[dict[str, Any]] = []

    for grp in data["exposure_groups"]:
        g_factors: list[dict[str, Any]] = []
        for fac in grp["factors"]:
            factor_no += 1
            content_id = f"OBS_CZN_{factor_no}"
            desc_id = f"OBS_CZN_{factor_no}_DESC"

            # A non-blank extra description becomes a numbered footnote.
            extra = (fac.get("extra_description") or "").strip()
            has_extra = bool(extra)
            footnote_num: int | None = None
            footnote_el = ""
            if has_extra:
                footnote_num = len(footnotes) + 1
                fn_id = f"UWAGA_{footnote_num}"
                footnote_el = f' <footnote ID="{fn_id}">{footnote_num}</footnote>'
                footnotes.append(
                    {
                        "footnote_id": fn_id,
                        "footnote_num": footnote_num,
                        "desc_id": desc_id,
                        "text": xesc(extra),
                    }
                )

            measured_val = fac.get("measured_value")
            norm = fac.get("norm")
            exposure_time = fac.get("exposure_time")
            mdate = fac.get("measurement_date")

            # --- narrative table cells ---
            measured_lines = build_measured_narrative_lines(
                fac.get("measured_display"),
                measured_val,
                exposure_time,
            )
            if measured_lines:
                measured_cell = "<br/>".join(xesc(x) for x in measured_lines)
            else:
                measured_cell = "—"

            if footnote_num is not None:
                norm_cell = xesc(f"patrz uwaga {footnote_num}")
            elif norm and norm.get("value") is not None:
                norm_amount = float(norm["value"])
                norm_unit = str(norm["unit"])
                norm_cell = xesc(f"{format_number_pl(norm_amount)} {unit_label_table(norm_unit)}")
            else:
                norm_cell = "—"

            date_cell = xesc(format_date_pl(mdate)) if mdate else "—"

            table_rows.append(
                {
                    "group_display": xesc(grp["group_display"]),
                    "content_id": content_id,
                    "factor_name": xesc(fac["display_name"]),
                    "footnote_html": footnote_el,
                    "measured_cell": measured_cell,
                    "norm_cell": norm_cell,
                    "date_cell": date_cell,
                }
            )

            # --- machine-readable entry values ---
            eff = date_to_hl7_d8(mdate) if mdate else ""
            has_value, pq_val, pq_unit = _quantity_xml_fields(measured_val)
            has_norm, norm_xml_val, norm_xml_unit = _quantity_xml_fields(norm)
            has_et, et_xml_val, et_xml_unit = _quantity_xml_fields(exposure_time)

            g_factors.append(
                {
                    "content_id": content_id,
                    "desc_id": desc_id,
                    "factor_code": xesc(fac["factor_code"]),
                    "display_name": xesc(fac["display_name"]),
                    "effective_time": eff,
                    "has_effective": bool(eff),
                    "has_value": has_value,
                    "pq_value": pq_val,
                    "pq_unit": pq_unit,
                    "has_norm": has_norm,
                    "norm_value": norm_xml_val,
                    "norm_unit": norm_xml_unit,
                    "has_exposure_time": has_et,
                    "et_value": et_xml_val,
                    "et_unit": et_xml_unit,
                    "has_extra": has_extra,
                    "extra_text": xesc(extra),
                }
            )

        groups_entry.append(
            {
                "group_code": xesc(grp["group_code"]),
                "group_display": xesc(grp["group_display"]),
                "factors": g_factors,
            }
        )

    return {
        "exposure_table_rows": table_rows,
        "exposure_footnotes": footnotes,
        "exposure_groups_entry": groups_entry,
        "exposure_factor_count": factor_no,
    }


def build_context(data: dict[str, Any]) -> dict[str, Any]:
    """Assemble the flat rendering context for the referral template.

    Reads the ``identifiers``, ``document``, ``patient``, ``author``,
    ``employer``, ``purpose``, ``position`` and ``additional_exams`` sections
    of *data*, converts dates to their display/HL7 forms, XML-escapes the
    strings and merges in the result of ``build_exposure_bundle``.

    Raises:
        KeyError: if a mandatory field is absent from *data*.
        ValueError: if a date or the effective time cannot be converted.
    """
    identifiers = data["identifiers"]
    document = data["document"]
    patient = data["patient"]
    author = data["author"]
    employer = data["employer"]
    purpose = data["purpose"]
    position = data["position"]

    effective_hl7 = iso_to_hl7_effective(document["effective_time"])
    doc_id = document["id_extension"]
    # The set id falls back to one derived from the document id.
    set_id = document.get("set_id_extension") or f"{doc_id}-SET"
    version_number = int(document.get("version_number") or 1)

    # Optional purpose dates: empty strings when not provided.
    last_exam = purpose.get("last_exam_date")
    last_exam_pl = format_date_pl(last_exam) if last_exam else ""
    last_exam_hl7 = date_to_hl7_d8(last_exam) if last_exam else ""

    valid_until = purpose.get("current_exam_valid_until")
    valid_until_pl = format_date_pl(valid_until) if valid_until else ""
    valid_until_hl7 = date_to_hl7_d8(valid_until) if valid_until else ""

    patient_addr = patient["address"]
    employer_addr = employer["address"]

    exams = [
        {
            "idx": number,
            "content_id": f"OBS_DOD_{number}",
            "icd9": xesc(exam["icd9_code"]),
            "display_name": xesc(exam["display_name"]),
            "narrative_suffix": xesc(exam["narrative_line"]),
        }
        for number, exam in enumerate(data["additional_exams"], start=1)
    ]

    exposure = build_exposure_bundle(data)

    extra_info = data.get("additional_info")
    extra_info_text = "" if extra_info is None else str(extra_info)

    position_code_system = identifiers.get("position_code_system") or data.get("position_code_system")

    # Optional free-text fields, normalised to "" when absent.
    hr_extension = patient.get("hr_id_extension") or ""
    unit_id = patient_addr.get("unit_id") or ""
    author_prefix = author.get("prefix") or ""

    return {
        # Document header
        "document_id_extension": xesc(doc_id),
        "document_id_root": xesc(identifiers["document_id_root"]),
        "set_id_extension": xesc(set_id),
        "effective_hl7": effective_hl7,
        "version_number": version_number,
        # Patient
        "patient_hr_extension": xesc(hr_extension),
        "patient_hr_root": xesc(identifiers["patient_hr_id_root"]),
        "has_patient_hr": bool(hr_extension.strip()),
        "patient_pesel": xesc(patient["pesel"]),
        "pesel_root": xesc(identifiers["pesel_root"]),
        "addr_p_city": xesc(patient_addr["city"]),
        "addr_p_postal": xesc(patient_addr["postal_code"]),
        "addr_p_street": xesc(patient_addr["street_name"]),
        "addr_p_house": xesc(patient_addr["house_number"]),
        "has_unit_id": bool(unit_id.strip()),
        "addr_p_unit": xesc(unit_id),
        "patient_given": xesc(patient["given"]),
        "patient_family": xesc(patient["family"]),
        "patient_gender": patient["gender"],
        "birth_hl7": date_to_hl7_d8(patient["birth_date"]),
        # Author
        "author_pesel": xesc(author["pesel"]),
        "has_author_prefix": bool(author_prefix.strip()),
        "author_prefix": xesc(author_prefix),
        "author_given": xesc(author["given"]),
        "author_family": xesc(author["family"]),
        # Employer
        "employer_regon": xesc(employer["regon"]),
        "regon_root": xesc(identifiers["regon_root"]),
        "employer_internal": xesc(employer["internal_id"]),
        "employer_internal_root": xesc(identifiers["employer_internal_id_root"]),
        "employer_name": xesc(employer["name"]),
        "employer_telecom": xesc(employer["telecom"]),
        "addr_e_country": xesc(employer_addr["country"]),
        "addr_e_city": xesc(employer_addr["city"]),
        "addr_e_postal": xesc(employer_addr["postal_code"]),
        "addr_e_street": xesc(employer_addr["street_name"]),
        "addr_e_house": xesc(employer_addr["house_number"]),
        # Purpose of the examination
        "purpose_exam_code": xesc(purpose["exam_type_code"]),
        "purpose_exam_display": xesc(purpose["exam_type_display"]),
        "has_last_exam": bool(last_exam),
        "last_exam_pl": last_exam_pl,
        "last_exam_hl7": last_exam_hl7,
        "has_valid_exam": bool(valid_until),
        "valid_exam_pl": valid_until_pl,
        "valid_exam_hl7": valid_until_hl7,
        # Position
        "position_code": xesc(position["position_code"]),
        "position_display": xesc(position["position_display"]),
        "position_description": xesc(position["description"]),
        "position_code_system": xesc(position_code_system or ""),
        # Additional exams, free-text information and exposure data
        "additional_exams": exams,
        "has_additional_info": bool(extra_info_text.strip()),
        "additional_info_html": xesc(extra_info_text),
        **exposure,
    }


def load_merged_input(path: Path) -> dict[str, Any]:
    """Load the JSON input and fill in default identifiers from the profile.

    The profile file ``test_profile_defaults.json`` is read from the directory
    that contains this module. Only its ``identifiers`` section is used: each
    entry is added to the input's ``identifiers`` unless the input already
    defines that key. A missing or ``null`` ``identifiers`` section in the
    input is treated as empty.

    Args:
        path: location of the user's JSON input file.

    Returns:
        The parsed input with a completed ``identifiers`` mapping.

    Raises:
        ValueError: if the input is not a JSON object, or its ``identifiers``
            is neither an object nor ``null``.
        OSError: if either file cannot be read.
    """
    tools_dir = Path(__file__).resolve().parent
    profile_path = tools_dir / "test_profile_defaults.json"
    with profile_path.open(encoding="utf-8") as f:
        profile = json.load(f)
    with path.open(encoding="utf-8") as f:
        # Freshly parsed data is not shared with anything, so no copy is needed.
        merged = json.load(f)
    if not isinstance(merged, dict):
        raise ValueError("Plik wejściowy musi zawierać obiekt JSON")

    identifiers = merged.get("identifiers")
    if identifiers is None:
        # Covers both an absent key and an explicit JSON null.
        identifiers = {}
    elif not isinstance(identifiers, dict):
        raise ValueError("identifiers musi być obiektem JSON")

    for key, value in (profile.get("identifiers") or {}).items():
        identifiers.setdefault(key, value)
    merged["identifiers"] = identifiers
    return merged


def check_reference_consistency(xml_text: str) -> None:
    """Verify that every ``<reference value="#…">`` has a matching ``<content ID="…">``.

    The check is a textual scan of *xml_text* with regular expressions; it
    does not parse the XML.

    Raises:
        RuntimeError: listing, in sorted order, the referenced ids for which
            no ``<content ID>`` was found.
    """
    defined = {m.group(1) for m in re.finditer(r'<content ID="([^"]+)"', xml_text)}
    referenced = {m.group(1) for m in re.finditer(r'<reference value="#([^"]+)"', xml_text)}
    dangling = sorted(referenced.difference(defined))
    if dangling:
        raise RuntimeError(f"Referencje bez pasującego <content ID>: {dangling}")
def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: read the JSON input, render the XML, write it.

    Args:
        argv: argument list to parse instead of the process arguments; the
            default ``None`` makes argparse read ``sys.argv[1:]``.

    Returns:
        0 once the output file has been written. Validation, rendering and
        I/O errors are not caught here and propagate as exceptions.
    """
    ap = argparse.ArgumentParser(description="Generuj skierowanie MP v0.3 (CDA XML) z pliku JSON.")
    ap.add_argument("--input", "-i", type=Path, required=True, help="Plik JSON z danymi")
    ap.add_argument("--output", "-o", type=Path, required=True, help="Ścieżka wyjściowego XML")
    ap.add_argument(
        "--no-validate-schema",
        action="store_true",
        help="Pomiń walidację jsonschema (jeśli biblioteka jest dostępna)",
    )
    args = ap.parse_args(argv)

    data = load_merged_input(args.input.resolve())
    validate_input_basic(data)

    # Schema validation is optional: it is skipped on request or when the
    # schema file is not present next to this module.
    tools_dir = Path(__file__).resolve().parent
    schema_path = tools_dir / "skierowanie_mp_input.schema.json"
    if not args.no_validate_schema and schema_path.is_file():
        validate_with_jsonschema(data, schema_path)

    ctx = build_context(data)

    # No file extension is enabled for autoescaping; the context values are
    # escaped while the context is built.
    env = Environment(
        loader=FileSystemLoader(tools_dir / "templates"),
        autoescape=select_autoescape(enabled_extensions=()),
        trim_blocks=True,
        lstrip_blocks=True,
    )
    tpl = env.get_template("skierowanie_mp_v0.3.xml.j2")
    xml_out = tpl.render(**ctx)

    # Fail before writing anything if the rendered document has dangling references.
    check_reference_consistency(xml_out)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(xml_out, encoding="utf-8")
    return 0


# Run the generator when executed as a script; main()'s return value is used
# as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
