import traceback
import decimal
import datetime
from typing import Any

from .db import get_client_db as get_db

# ─────────────────────────────────────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────────────────────────────────────

def _run(db_name: str, sql: str, params: tuple = ()) -> list:
    """
    Execute ``sql`` against the client database ``db_name`` and return the
    result rows as a list of ``{column_name: value}`` dicts.

    ``params`` is handed to ``cursor.execute`` only when it is non-empty.
    A statement that produces no result set yields ``[]``.  Any exception is
    printed via ``traceback.print_exc()`` and also yields ``[]`` — this helper
    never raises.
    """
    try:
        with get_db(db_name) as connection:
            cursor = connection.cursor()
            # Only forward a parameter sequence when there is something to bind.
            execute_args = (sql, params) if params else (sql,)
            cursor.execute(*execute_args)
            if cursor.description is None:
                # No column metadata means the statement returned no rows.
                return []
            column_names = [column[0] for column in cursor.description]
            records = []
            for record in cursor.fetchall():
                records.append(dict(zip(column_names, record)))
            return records
    except Exception:
        traceback.print_exc()
        return []


def _serialize_row(row: dict) -> dict:
    """
    Return a copy of ``row`` whose values are JSON-safe primitives.

    ``Decimal`` values become ``float``; ``date`` / ``datetime`` values become
    the first 10 characters of their ``str()`` form (the ``YYYY-MM-DD`` part,
    so any time-of-day component is dropped).  Every other value is passed
    through untouched.
    """
    def _to_primitive(value):
        if isinstance(value, decimal.Decimal):
            return float(value)
        if isinstance(value, (datetime.date, datetime.datetime)):
            return str(value)[:10]
        return value

    return {key: _to_primitive(value) for key, value in row.items()}


def _call_proc(
    db_name: str,
    project_id = None,
    parent_project_id = None,
):
    """
    Execute dbo.GetProjectDetail_Dashboard and return its rows as a list of
    dicts with all values serialised through ``_serialize_row``.

    Only the FIRST result set that carries column metadata is read; leading
    result sets without a description are skipped and any later result sets
    are not fetched.  Returns ``[]`` when no such result set exists or when
    any exception occurs (the traceback is printed, nothing is raised).
    """
    statement = "EXEC dbo.GetProjectDetail_Dashboard @ProjectId=?, @ParentProjectId=?"
    try:
        with get_db(db_name) as connection:
            cursor = connection.cursor()
            cursor.execute(statement, project_id, parent_project_id)
            # Step over result sets that expose no columns until one does.
            while cursor.description is None:
                has_more = cursor.nextset()
                if not has_more:
                    return []
            headers = [meta[0] for meta in cursor.description]
            serialised = []
            for record in cursor.fetchall():
                serialised.append(_serialize_row(dict(zip(headers, record))))
            return serialised
    except Exception:
        traceback.print_exc()
        return []


# ─────────────────────────────────────────────────────────────────────────────
# PROJECT SEARCH  (kept — reads Project table, not the dropped detail table)
# ─────────────────────────────────────────────────────────────────────────────

def search_projects_for_analytics(db_name: str, q: str, limit: int = 20) -> list:
    """
    Search dbo.Project by project code or project name.

    Args:
        db_name: Client database name, forwarded to ``_run``.
        q: Search term.  Surrounding whitespace is ignored and ``None`` is
            treated as an empty term.  The term is matched literally as a
            substring: LIKE wildcard characters typed by the user are escaped.
        limit: Maximum number of rows (``TOP``); coerced with ``int()``.

    Returns:
        A list of dicts with keys ``id``, ``code``, ``name`` and
        ``project_type``, restricted to project types 'G', 'C' and 'S'.
        Rows whose code equals the term exactly sort first, the rest by name.
        ``_run`` returns ``[]`` on any database error.
    """
    term = "" if q is None else str(q).strip()
    # T-SQL LIKE treats %, _ and [ as wildcards.  Wrapping each one in
    # brackets makes it match literally without needing an ESCAPE clause.
    # '[' is handled first so the brackets added for % and _ are not
    # themselves re-escaped.
    escaped = term.replace("[", "[[]").replace("%", "[%]").replace("_", "[_]")
    like = f"%{escaped}%"
    # int() guarantees that only an integer literal is interpolated into the
    # statement text; all user-supplied text goes through bound parameters.
    sql = f"""
        SELECT TOP ({int(limit)})
            ProjectId   AS id,
            ProjectCode AS code,
            ProjectName AS name,
            ProjectType AS project_type
        FROM dbo.Project
        WHERE (CAST(ProjectCode AS NVARCHAR) LIKE ? OR ProjectName LIKE ?)
          AND ProjectType IN ('G','C','S')
        ORDER BY
            CASE WHEN CAST(ProjectCode AS NVARCHAR) = ? THEN 0 ELSE 1 END,
            ProjectName
    """
    # The exact-match ordering uses '=' rather than LIKE, so it gets the raw term.
    return _run(db_name, sql, (like, like, term))


# ─────────────────────────────────────────────────────────────────────────────
# PROC-BASED ANALYTICS  — all aggregation done from stored-proc rows
# ─────────────────────────────────────────────────────────────────────────────

def _safe_div(num: float, den: float, pct: bool = False, decimals: int = 2):
    """
    Divide ``num`` by ``den`` and round the result to ``decimals`` places.

    Returns ``None`` when ``den`` is falsy (``0``, ``0.0`` or ``None``) so
    callers never hit a division error.  With ``pct=True`` the quotient is
    multiplied by 100 before rounding.
    """
    if den:
        quotient = num / den
        scaled = quotient * 100 if pct else quotient
        return round(scaled, decimals)
    return None


def get_analytics_from_proc(
    db_name: str,
    project_id = None,
    parent_project_id = None,
) -> dict:
    """
    Single entry-point for the analytics tab.

    Calls the stored procedure once (through ``_call_proc``) and derives every
    KPI / chart dataset the frontend needs from the returned rows in Python —
    no further queries are issued.

    Args:
        db_name: Client database name, forwarded to ``_call_proc``.
        project_id: Forwarded to the procedure as ``@ProjectId``.
        parent_project_id: Forwarded to the procedure as ``@ParentProjectId``.

    Returns:
        ``{"error": "no_data", "rows": []}`` when ``_call_proc`` yields no
        rows (which is also what it returns when the call itself fails).
        Otherwise a dict with the keys ``kpis``, ``fraud_breakdown``,
        ``financial``, ``loi``, ``quality``, ``funnel``,
        ``status_distribution``, ``device_distribution``,
        ``country_distribution``, ``supplier_performance``,
        ``supplier_modal``, ``top_projects``, ``time_trends``,
        ``time_series``, ``country_device``, ``failures`` and ``total_rows``.
        Ratios are ``None`` whenever their denominator is zero.
    """
    rows = _call_proc(db_name, project_id, parent_project_id)
    if not rows:
        return {"error": "no_data", "rows": []}

    def _status(row):
        # The status column is read under either capitalisation.
        return row.get("Status") or row.get("status")

    def _country(row):
        # Row-level country first, then the project's country, else "Unknown".
        raw = row.get("Country") or row.get("project_country") or "Unknown"
        return raw.strip() or "Unknown"

    total       = len(rows)
    completes   = [r for r in rows if r.get("is_complete") == 1]
    terminates  = [r for r in rows if r.get("is_terminate") == 1]
    drops       = [r for r in rows if r.get("is_drop") == 1]
    frauds      = [r for r in rows if r.get("is_fraud") == 1]

    n_comp  = len(completes)
    n_term  = len(terminates)
    n_drop  = len(drops)
    n_fraud = len(frauds)
    # Conversion is measured against rows that reached one of these three
    # outcomes; fraud-flagged rows are not part of this denominator.
    denom   = n_comp + n_term + n_drop

    # ── KPIs ─────────────────────────────────────────────────────────────────
    project_codes = {str(r.get("project_code", "")).strip() for r in rows if r.get("project_code")}

    kpis = {
        "total_projects":   len(project_codes),
        "total_responses":  total,
        "completes":        n_comp,
        "terminates":       n_term,
        "drops":            n_drop,
        "fraud_events":     n_fraud,
        "conversion_rate":  _safe_div(n_comp, denom, pct=True),
        "fraud_rate":       _safe_div(n_fraud, total, pct=True),
        "avg_loi_min":      None,  # filled in once the LOI section has run
    }

    # ── Fraud breakdown ───────────────────────────────────────────────────────
    FRAUD_ORDER = [
        "Duplicate IP", "Duplicate Supplier User", "Proxy Validation",
        "RDText Failure", "GEO IP Mismatch", "RD Failure",
        "Fraud User", "Rejected Identifier",
    ]
    fraud_status_map = {}
    for r in frauds:
        s = _status(r) or "Unknown"
        fraud_status_map[s] = fraud_status_map.get(s, 0) + 1

    # Only the statuses in FRAUD_ORDER are listed (in that order, with 0 for
    # absent ones).  Fraud rows with any other status still count towards
    # total_flagged, so the listed counts may sum to less than the total.
    fraud_breakdown = {
        "total_flagged": n_fraud,
        "fraud_pct":     _safe_div(n_fraud, total, pct=True),
        "breakdown":     [
            {"status": s, "count": fraud_status_map.get(s, 0)}
            for s in FRAUD_ORDER
        ],
    }

    # ── Financial ────────────────────────────────────────────────────────────
    def _fsum(row_list, field):
        # Missing / NULL / zero values all contribute 0 to the sum.
        return sum(float(r.get(field) or 0) for r in row_list)

    proj_spend  = _fsum(completes, "project_cpi")
    sup_cost    = _fsum(completes, "supplier_cpi")
    fraud_waste = _fsum(frauds,    "supplier_cpi")
    margin      = proj_spend - sup_cost
    # margin_pct is a plain ratio (not multiplied by 100), kept to 4 decimals.
    margin_pct  = _safe_div(margin, proj_spend, decimals=4)
    avg_proj_cpi = _safe_div(proj_spend, n_comp)
    avg_sup_cpi  = _safe_div(sup_cost,   n_comp)
    eff_cost     = _safe_div(sup_cost + fraud_waste, n_comp, decimals=4)

    financial = {
        "total_project_spend":               round(proj_spend,  2),
        "total_supplier_cost":               round(sup_cost,    2),
        "gross_margin":                      round(margin,      2),
        "margin_pct":                        margin_pct,
        "avg_project_cpi":                   avg_proj_cpi,
        "avg_supplier_cpi":                  avg_sup_cpi,
        "cost_wasted_on_fraud":              round(fraud_waste, 2),
        "effective_cost_per_clean_complete": eff_cost,
    }

    # ── LOI ──────────────────────────────────────────────────────────────────
    # Completes and terminates only count strictly positive LOI values.
    loi_vals  = [float(r["loi_min"]) for r in completes if (r.get("loi_min") or 0) > 0]
    term_lois = [float(r["loi_min"]) for r in terminates if (r.get("loi_min") or 0) > 0]
    # Drops also keep a zero LOI, but a missing / NULL value must be skipped:
    # float(None) raises TypeError and would abort the whole payload.
    drop_lois = [float(r["loi_min"]) for r in drops
                 if r.get("loi_min") is not None and r["loi_min"] >= 0]

    avg_loi = _safe_div(sum(loi_vals), len(loi_vals), decimals=1) if loi_vals else None
    kpis["avg_loi_min"] = avg_loi

    sorted_loi = sorted(loi_vals)
    n_loi = len(sorted_loi)
    if n_loi:
        mid = n_loi // 2
        if n_loi % 2:
            median_raw = sorted_loi[mid]
        else:
            median_raw = (sorted_loi[mid - 1] + sorted_loi[mid]) / 2
        # Rounded in both cases so the median matches avg / min / max precision.
        median_loi = round(median_raw, 1)
    else:
        median_loi = None

    # "Speeders" are completes with an LOI below 3 (presumably minutes, going
    # by the loi_min field name).
    speeders = [v for v in loi_vals if v < 3]
    loi_kpi = {
        "avg_loi_completes":    avg_loi,
        "median_loi_completes": median_loi,
        "max_loi_complete":     round(max(loi_vals), 1) if loi_vals else None,
        "min_loi_complete":     round(min(loi_vals), 1) if loi_vals else None,
        "speeders":             len(speeders),
        "speeder_rate_pct":     _safe_div(len(speeders), len(loi_vals), pct=True),
        "avg_loi_terminates":   _safe_div(sum(term_lois), len(term_lois), decimals=1) if term_lois else None,
        "avg_loi_drops":        _safe_div(sum(drop_lois), len(drop_lois), decimals=1) if drop_lois else None,
    }

    # ── Quality ───────────────────────────────────────────────────────────────
    def _scores(row_list, field):
        # Non-NULL values of `field`, as floats.
        return [float(r[field]) for r in row_list if r.get(field) is not None]

    comp_scores  = _scores(completes, "composite_score")
    fraud_scores = _scores(frauds,    "composite_score")
    thr_scores   = _scores(rows,      "threat_score")            # all rows
    thr_p_scores = _scores(frauds,    "threat_potential_score")  # fraud rows only
    rd_fail      = [r for r in rows if _status(r) == "RD Failure"]
    rd_thr       = _scores(rd_fail, "threat_score")

    quality = {
        "avg_composite_completes":       _safe_div(sum(comp_scores),  len(comp_scores)),
        "avg_composite_fraud":           _safe_div(sum(fraud_scores), len(fraud_scores)),
        "avg_threat_score_rd_failures":  _safe_div(sum(rd_thr),       len(rd_thr)),
        "avg_threat_potential_flagged":  _safe_div(sum(thr_p_scores), len(thr_p_scores)),
        "high_threat_count":             sum(1 for v in thr_scores   if v >= 50),
        "high_threat_potential_count":   sum(1 for v in thr_p_scores if v >= 75),
        "quality_pass_rate":             _safe_div(n_comp, n_comp + n_fraud, pct=True),
    }

    # ── Funnel ────────────────────────────────────────────────────────────────
    after_fraud     = total - n_fraud
    reached_survey  = n_comp + n_term
    funnel = {
        "gross_traffic":     total,
        "after_fraud":       after_fraud,
        "reached_survey":    reached_survey,
        # Reported with the same value as reached_survey.
        "passed_prescreen":  reached_survey,
        "completed_survey":  n_comp,
        "incidence_rate":    _safe_div(n_comp, n_comp + n_term, pct=True, decimals=1),
        "effective_yield":   _safe_div(n_comp, total, pct=True, decimals=1),
    }

    # ── Status distribution ───────────────────────────────────────────────────
    status_map = {}
    for r in rows:
        s = _status(r) or "Unknown"
        status_map[s] = status_map.get(s, 0) + 1
    status_distribution = sorted(
        [{"status": s, "count": c, "pct": _safe_div(c, total, pct=True)}
         for s, c in status_map.items()],
        key=lambda x: x["count"], reverse=True,
    )

    # ── Device distribution ───────────────────────────────────────────────────
    dev_map = {}
    for r in rows:
        d = (r.get("device_type") or "Unknown").strip() or "Unknown"
        if d not in dev_map:
            dev_map[d] = {"device_type": d, "total": 0, "completes": 0}
        dev_map[d]["total"] += 1
        if r.get("is_complete") == 1:
            dev_map[d]["completes"] += 1
    device_distribution = sorted(dev_map.values(), key=lambda x: x["total"], reverse=True)
    for d in device_distribution:
        d["complete_rate"] = _safe_div(d["completes"], d["total"], pct=True)

    # ── Country distribution ──────────────────────────────────────────────────
    ctry_map = {}
    for r in rows:
        c = _country(r)
        if c not in ctry_map:
            ctry_map[c] = {"country": c, "total": 0, "completes": 0, "terminates": 0, "drops": 0}
        ctry_map[c]["total"] += 1
        if r.get("is_complete")  == 1: ctry_map[c]["completes"]  += 1
        if r.get("is_terminate") == 1: ctry_map[c]["terminates"] += 1
        if r.get("is_drop")      == 1: ctry_map[c]["drops"]      += 1
    # Top 10 countries by number of completes.
    country_list = sorted(ctry_map.values(), key=lambda x: x["completes"], reverse=True)[:10]
    for c in country_list:
        c_denom = c["completes"] + c["terminates"] + c["drops"]
        c["conversion_rate"] = _safe_div(c["completes"], c_denom, pct=True)

    # ── Supplier performance ──────────────────────────────────────────────────
    sup_map = {}
    for r in rows:
        s = (r.get("supplier_name") or "Unknown").strip() or "Unknown"
        if s not in sup_map:
            sup_map[s] = {"supplier_name": s, "total": 0, "completes": 0,
                          "terminates": 0, "drops": 0, "fraud_events": 0}
        sup_map[s]["total"] += 1
        if r.get("is_complete")  == 1: sup_map[s]["completes"]   += 1
        if r.get("is_terminate") == 1: sup_map[s]["terminates"]  += 1
        if r.get("is_drop")      == 1: sup_map[s]["drops"]       += 1
        if r.get("is_fraud")     == 1: sup_map[s]["fraud_events"] += 1
    supplier_list = sorted(sup_map.values(), key=lambda x: x["total"], reverse=True)
    for s in supplier_list:
        s_denom = s["completes"] + s["terminates"] + s["drops"]
        s["conv_pct"]  = _safe_div(s["completes"],   s_denom,    pct=True)
        s["fraud_pct"] = _safe_div(s["fraud_events"], s["total"], pct=True)

    # ── Top projects by completes ─────────────────────────────────────────────
    proj_map = {}
    for r in rows:
        pc = str(r.get("project_code") or "Unknown").strip() or "Unknown"
        if pc not in proj_map:
            # The country shown is the one on the first row seen for this code.
            proj_map[pc] = {"project_code": pc, "country": _country(r),
                            "completes": 0, "terminates": 0, "drops": 0, "total": 0}
        proj_map[pc]["total"] += 1
        if r.get("is_complete")  == 1: proj_map[pc]["completes"]  += 1
        if r.get("is_terminate") == 1: proj_map[pc]["terminates"] += 1
        if r.get("is_drop")      == 1: proj_map[pc]["drops"]      += 1
    top_projects = sorted(proj_map.values(), key=lambda x: x["completes"], reverse=True)[:10]
    for p in top_projects:
        p_denom = p["completes"] + p["terminates"] + p["drops"]
        p["conv_pct"] = _safe_div(p["completes"], p_denom, pct=True, decimals=1)

    # ── Time trends ───────────────────────────────────────────────────────────
    date_set = set()
    daily_map = {}
    daily_comp = {}
    for r in rows:
        rd = r.get("response_date")
        if rd:
            # Keep only the first 10 characters (the date part) as the day key.
            d = str(rd)[:10]
            date_set.add(d)
            daily_map[d] = daily_map.get(d, 0) + 1
            if r.get("is_complete") == 1:
                daily_comp[d] = daily_comp.get(d, 0) + 1

    field_start = min(date_set) if date_set else "—"
    field_end   = max(date_set) if date_set else "—"
    # Number of distinct days that have at least one response.
    n_days      = len(date_set)
    peak_day    = max(daily_map, key=daily_map.get) if daily_map else "—"
    peak_count  = daily_map.get(peak_day, 0)

    time_trends = {
        "field_start":         field_start,
        "field_end":           field_end,
        "total_field_days":    n_days,
        "avg_daily_traffic":   round(total / n_days, 1) if n_days else 0,
        "avg_daily_completes": round(n_comp / n_days, 1) if n_days else 0,
        "peak_traffic_day":    peak_day,
        "peak_traffic_count":  peak_count,
    }

    # ── Daily time series (for charts) ────────────────────────────────────────
    time_series = sorted(
        [{"date": d, "total": daily_map[d], "completes": daily_comp.get(d, 0)}
         for d in date_set],
        key=lambda x: x["date"],
    )

    # ── Country × Device ──────────────────────────────────────────────────────
    cd_map = {}
    for r in rows:
        c  = _country(r)
        dt = (r.get("device_type") or "Unknown").lower()
        if c not in cd_map:
            cd_map[c] = {"country": c, "desktop_traffic": 0, "mobile_traffic": 0,
                         "tablet_traffic": 0, "desktop_completes": 0,
                         "mobile_completes": 0, "tablet_completes": 0, "total_traffic": 0}
        cd_map[c]["total_traffic"] += 1
        is_c = r.get("is_complete") == 1
        # Substring match on the lower-cased device type; rows matching none
        # of the three buckets only count towards total_traffic.
        if "desktop" in dt:
            cd_map[c]["desktop_traffic"] += 1
            if is_c: cd_map[c]["desktop_completes"] += 1
        elif "mobile" in dt:
            cd_map[c]["mobile_traffic"] += 1
            if is_c: cd_map[c]["mobile_completes"] += 1
        elif "tablet" in dt:
            cd_map[c]["tablet_traffic"] += 1
            if is_c: cd_map[c]["tablet_completes"] += 1
    country_device = sorted(cd_map.values(), key=lambda x: x["total_traffic"], reverse=True)
    for cd in country_device:
        devs = {"Desktop": cd["desktop_traffic"],
                "Mobile":  cd["mobile_traffic"],
                "Tablet":  cd["tablet_traffic"]}
        # "Best" is the bucket with the most traffic; ties go to the first
        # entry in the order Desktop, Mobile, Tablet.
        cd["best_device"] = max(devs, key=devs.get)

    # ── Supplier performance modal (with TOTAL row) ───────────────────────────
    sup_total = {
        "supplier_name": "TOTAL",
        "total":         sum(s["total"]        for s in supplier_list),
        "completes":     sum(s["completes"]     for s in supplier_list),
        "terminates":    sum(s["terminates"]    for s in supplier_list),
        "drops":         sum(s["drops"]         for s in supplier_list),
        "fraud_events":  sum(s["fraud_events"]  for s in supplier_list),
    }
    t_denom = sup_total["completes"] + sup_total["terminates"] + sup_total["drops"]
    sup_total["conv_pct"]  = _safe_div(sup_total["completes"],    t_denom,             pct=True)
    sup_total["fraud_pct"] = _safe_div(sup_total["fraud_events"], sup_total["total"],  pct=True)

    # ── Failures (non-complete statuses only) ─────────────────────────────────
    # status_distribution is already sorted by count, so this is the 10 most
    # frequent statuses other than "complete" / "completes".
    failures_raw = [s for s in status_distribution
                    if s["status"].lower() not in ("complete", "completes")][:10]
    failures = [{"status": f["status"], "count": f["count"],
                 "pct_of_traffic": _safe_div(f["count"], total, pct=True, decimals=1)}
                for f in failures_raw]

    return {
        "kpis":                kpis,
        "fraud_breakdown":     fraud_breakdown,
        "financial":           financial,
        "loi":                 loi_kpi,
        "quality":             quality,
        "funnel":              funnel,
        "status_distribution": status_distribution,
        "device_distribution": device_distribution,
        "country_distribution":country_list,
        "supplier_performance":supplier_list,
        "supplier_modal":      {"suppliers": supplier_list, "total": sup_total},
        "top_projects":        top_projects,
        "time_trends":         time_trends,
        "time_series":         time_series,
        "country_device":      country_device,
        "failures":            failures,
        "total_rows":          total,
    }

# ─────────────────────────────────────────────────────────────────────────────
# get_project_detail_proc  
# ─────────────────────────────────────────────────────────────────────────────

def get_project_detail_proc(
    db_name: str,
    project_id = None,
    parent_project_id = None,
) -> list:
    """
    Return the stored-procedure rows for a project without any aggregation.

    Thin public wrapper that delegates to ``_call_proc`` with the same
    arguments and returns its result unchanged.
    """
    detail_rows = _call_proc(
        db_name,
        project_id=project_id,
        parent_project_id=parent_project_id,
    )
    return detail_rows