from .db import get_client_db as get_db
import traceback
import json
import os
import time
import hashlib
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import openai

# ─────────────────────────────────────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────────────────────────────────────

def _run(db_name: str, sql: str, *params):
    try:
        with get_db(db_name) as conn:
            cur = conn.cursor()
            cur.execute(sql, *params) if params else cur.execute(sql)
            cols = [c[0] for c in cur.description]
            return [dict(zip(cols, row)) for row in cur.fetchall()]
    except Exception:
        traceback.print_exc()
        return []

def _scalar(db_name: str, sql: str, *params):
    try:
        with get_db(db_name) as conn:
            cur = conn.cursor()
            cur.execute(sql, *params) if params else cur.execute(sql)
            row = cur.fetchone()
            return row[0] if row else 0
    except Exception:
        traceback.print_exc()
        return 0


# ─────────────────────────────────────────────────────────────────────────────
# IN-MEMORY CACHE  (per client_db, TTL-based)
# Two separate caches:
#   _data_cache    → DB query results  (TTL: 10 min)
#   _insight_cache → LLM insights      (TTL: 30 min)
#
# Keys: _data_cache uses the client_db name; _insight_cache uses
#   "<client_db>:<fingerprint>" so insights are regenerated when the data changes.
# Each entry: { "data": <payload>, "ts": <epoch float> }
# ─────────────────────────────────────────────────────────────────────────────

DATA_TTL    = 10 * 60   # 10 minutes
INSIGHT_TTL = 30 * 60   # 30 minutes

_data_cache    = {}
_insight_cache = {}
_cache_lock    = threading.Lock()


def _cache_get(store: dict, key: str, ttl: int):
    """Return cached value if fresh, else None."""
    entry = store.get(key)
    if entry and (time.time() - entry["ts"]) < ttl:
        return entry["data"]
    return None


def _cache_set(store: dict, key: str, data):
    with _cache_lock:
        store[key] = {"data": data, "ts": time.time()}


def invalidate_client_cache(client_db: str):
    """
    Drop every cached dashboard payload and LLM insight for ``client_db``.

    ``_data_cache`` is keyed by the bare client name. Insight entries are
    keyed as ``"<client_db>:<fingerprint>"``, so they are matched by that
    prefix. Popping only the bare name never removed them.

    Args:
        client_db: Client database name whose cached entries should be dropped.
    """
    prefix = f"{client_db}:"
    with _cache_lock:
        _data_cache.pop(client_db, None)
        _insight_cache.pop(client_db, None)
        # Snapshot the matching keys first; a dict cannot change size while
        # it is being iterated.
        for key in [k for k in _insight_cache if k.startswith(prefix)]:
            del _insight_cache[key]


# ─────────────────────────────────────────────────────────────────────────────
# KPI CARDS
# ─────────────────────────────────────────────────────────────────────────────

def get_kpi_cards(db_name: str) -> dict:
    """
    Return the headline project KPIs shown on the dashboard cards.

    A single query over Project (rows with ProjectType 'G' excluded) also pulls
    two other values:
      - the count of distinct SupplierIds in ProjectSupplierMapping with
        IsRemoved = 0;
      - the current calendar month's CompleteCount total from MonthlyStatsTrend.

    Args:
        db_name: Client database to query.

    Returns:
        Dict with the following keys:
          - total_projects and active_projects (FlagId = 1);
          - total_completes, which is always 0 here; get_dashboard_data
            overwrites it with the value from get_survey_stats_combined;
          - avg_ir and avg_loi, rounded to 1 decimal place;
          - avg_cpi, rounded to 2 decimal places;
          - total_suppliers and this_month_completes.
        NULL or missing values (an empty table or a failed query) are
        reported as 0.
    """
    # total_completes is not queried here: get_survey_stats_combined computes
    # it as part of its UserSurvey aggregate query.
    sql = """
        SELECT
            COUNT(*)                                                            AS total_projects,
            SUM(CASE WHEN FlagId = 1 THEN 1 ELSE 0 END)                        AS active_projects,
            ISNULL(AVG(CASE WHEN FlagId=1 AND IR>0
                            THEN CAST(IR AS FLOAT) END), 0)                    AS avg_ir,
            ISNULL(AVG(CASE WHEN FlagId=1 AND AverageLOI>0
                            THEN CAST(AverageLOI AS FLOAT) END), 0)            AS avg_loi,
            ISNULL(AVG(CASE WHEN FlagId=1 AND ProjectCPI>0
                            THEN CAST(ProjectCPI AS FLOAT) END), 0)            AS avg_cpi,
            (SELECT COUNT(DISTINCT SupplierId)
             FROM ProjectSupplierMapping WHERE IsRemoved = 0)                  AS total_suppliers,
            (SELECT ISNULL(SUM(CompleteCount), 0)
             FROM MonthlyStatsTrend
             WHERE Month = MONTH(GETDATE()) AND Year = YEAR(GETDATE()))        AS this_month
        FROM Project
        WHERE ProjectType <> 'G'
    """
    row = (_run(db_name, sql) or [{}])[0]

    def _num(column: str):
        # SUM() over zero matching rows is NULL (e.g. active_projects on an
        # empty table), and a failed query leaves row == {}; report both as 0.
        return row.get(column) or 0

    return {
        "total_projects":       _num("total_projects"),
        "active_projects":      _num("active_projects"),
        "total_completes":      0,   # overwritten by get_dashboard_data
        "avg_ir":               round(float(_num("avg_ir")), 1),
        "avg_loi":              round(float(_num("avg_loi")), 1),
        "avg_cpi":              round(float(_num("avg_cpi")), 2),
        "total_suppliers":      _num("total_suppliers"),
        "this_month_completes": _num("this_month"),
    }


# ─────────────────────────────────────────────────────────────────────────────
# COMBINED UserSurvey QUERY
# Before: 4 separate queries → 4 full UserSurvey scans.
# Now:    one aggregate query (totals, funnel, devices) plus two top-N
#         queries (countries, suppliers), all run in parallel.
# ─────────────────────────────────────────────────────────────────────────────

def get_survey_stats_combined(db_name: str) -> dict:
    """
    Build every UserSurvey-based dashboard metric from three parallel queries.

    Query 1 aggregates the non-test UserSurvey rows once, for:
      - total_completes  (KPI card)
      - quality_funnel   (counts for statuses C / T / D / Q)
      - device_split     (completes per device type)
    Query 2 returns the top 10 countries by completes (via Project -> Country).
    Query 3 returns the top 8 suppliers by completes, with their average CPI.

    Args:
        db_name: Client database to query.

    Returns:
        Dict with the keys "total_completes", "quality_funnel",
        "device_split", "countries" and "top_suppliers".
          - Funnel and device entries with a zero count are omitted.
          - device_split is ordered by count, highest first.
          - A failed query contributes zeros or empty lists, because _run
            returns [] on error.
    """
    # Query 1: totals, status funnel and device split, all from one pass.
    agg_sql = """
        SELECT
            -- KPI
            SUM(CASE WHEN Status='C' AND IsTestLink=0 THEN 1 ELSE 0 END) AS total_completes,

            -- Quality funnel
            SUM(CASE WHEN Status='C' AND IsTestLink=0 THEN 1 ELSE 0 END) AS funnel_C,
            SUM(CASE WHEN Status='T' AND IsTestLink=0 THEN 1 ELSE 0 END) AS funnel_T,
            SUM(CASE WHEN Status='D' AND IsTestLink=0 THEN 1 ELSE 0 END) AS funnel_D,
            SUM(CASE WHEN Status='Q' AND IsTestLink=0 THEN 1 ELSE 0 END) AS funnel_Q,

            -- Device split (completes only)
            SUM(CASE WHEN Status='C' AND IsTestLink=0
                      AND ISNULL(DeviceType,'Unknown')='Desktop'    THEN 1 ELSE 0 END) AS dev_Desktop,
            SUM(CASE WHEN Status='C' AND IsTestLink=0
                      AND ISNULL(DeviceType,'Unknown')='Mobile'     THEN 1 ELSE 0 END) AS dev_Mobile,
            SUM(CASE WHEN Status='C' AND IsTestLink=0
                      AND ISNULL(DeviceType,'Unknown')='Tablet'     THEN 1 ELSE 0 END) AS dev_Tablet,
            SUM(CASE WHEN Status='C' AND IsTestLink=0
                      AND ISNULL(DeviceType,'Unknown') NOT IN
                          ('Desktop','Mobile','Tablet')             THEN 1 ELSE 0 END) AS dev_Unknown
        FROM UserSurvey
        WHERE IsTestLink = 0
    """

    # Query 2: top 10 countries; a survey's country comes from its project.
    country_sql = """
        SELECT TOP 10
               c.CountryName AS country,
               COUNT(*)      AS completes
        FROM   UserSurvey us  WITH (NOLOCK)
        JOIN   Project    p   WITH (NOLOCK) ON p.ProjectId  = us.ProjectId
        JOIN   Country    c   WITH (NOLOCK) ON c.CountryId  = p.CountryId
        WHERE  us.Status = 'C' AND us.IsTestLink = 0
        GROUP  BY c.CountryName
        ORDER  BY completes DESC
    """

    # Query 3: top 8 suppliers, with the mean SupplierCPI of their completes.
    supplier_sql = """
        SELECT TOP 8
               s.SupplierName                             AS supplier,
               COUNT(*)                                   AS completes,
               AVG(CAST(us.SupplierCPI AS FLOAT))         AS avg_cpi
        FROM   UserSurvey us  WITH (NOLOCK)
        JOIN   Supplier   s   WITH (NOLOCK) ON s.SupplierId = us.SupplierId
        WHERE  us.Status = 'C' AND us.IsTestLink = 0
        GROUP  BY s.SupplierName
        ORDER  BY completes DESC
    """

    # Run the three queries concurrently; map() yields results in input order.
    with ThreadPoolExecutor(max_workers=3) as pool:
        agg_rows, country_rows, supplier_rows = pool.map(
            lambda query: _run(db_name, query),
            (agg_sql, country_sql, supplier_sql),
        )

    summary = (agg_rows or [{}])[0]

    def _count(column: str):
        # SUM() over zero rows is NULL, and a failed query leaves {}; both become 0.
        return summary.get(column, 0) or 0

    funnel_counts = [(label, _count("funnel_" + label)) for label in ("C", "T", "D", "Q")]
    device_counts = [
        (device, _count("dev_" + device))
        for device in ("Desktop", "Mobile", "Tablet", "Unknown")
    ]

    # sorted() is stable, so equal counts keep the Desktop/Mobile/Tablet/Unknown order.
    device_split = sorted(
        ({"device": device, "cnt": n} for device, n in device_counts if n > 0),
        key=lambda entry: entry["cnt"],
        reverse=True,
    )

    return {
        "total_completes": _count("total_completes"),
        "quality_funnel":  [{"status_label": label, "cnt": n} for label, n in funnel_counts if n > 0],
        "device_split":    device_split,
        "countries":       country_rows or [],
        "top_suppliers":   supplier_rows or [],
    }


# ─────────────────────────────────────────────────────────────────────────────
# CHART QUERIES
# ─────────────────────────────────────────────────────────────────────────────

def get_monthly_trend(db_name: str) -> list:
    """
    Return per-month ClickCount/CompleteCount totals for the 12 latest months.

    The query selects the newest 12 (Month, Year) groups. The rows are then
    flipped so that the returned list runs oldest month first.
    """
    query = """
        SELECT TOP 12
               MonthName, Month, Year,
               SUM(ClickCount)    AS ClickCount,
               SUM(CompleteCount) AS CompleteCount
        FROM   MonthlyStatsTrend
        GROUP  BY MonthName, Month, Year
        ORDER  BY Year DESC, Month DESC
    """
    newest_first = _run(db_name, query)
    return newest_first[::-1]

def get_project_status_breakdown(db_name: str) -> list:
    """Count non-'G' projects per ProjectFlag.FlagStatus, most common status first."""
    query = """
        SELECT pf.FlagStatus AS [status], COUNT(*) AS cnt
        FROM   Project p
        INNER JOIN ProjectFlag pf ON p.FlagId = pf.FlagId
        WHERE  p.ProjectType <> 'G'
        GROUP  BY pf.FlagStatus
        ORDER  BY cnt DESC
    """
    status_rows = _run(db_name, query)
    return status_rows

def get_cpi_ir_scatter(db_name: str) -> list:
    """
    Return (name, ir, cpi, ptype) points for the CPI-vs-IR scatter chart.

    Only the 50 most recently created non-'G' projects with FlagId = 1 and
    positive IR and ProjectCPI are included.
    """
    query = """
        SELECT TOP 50
               ProjectName               AS name,
               CAST(IR AS FLOAT)         AS ir,
               CAST(ProjectCPI AS FLOAT) AS cpi,
               ProjectType               AS ptype
        FROM   Project
        WHERE  FlagId = 1 AND IR > 0 AND ProjectCPI > 0 AND ProjectType <> 'G'
        ORDER  BY CreateDate DESC
    """
    points = _run(db_name, query)
    return points

def get_project_type_mix(db_name: str) -> list:
    """
    Count projects with FlagId = 1 per ProjectType, largest group first.

    A NULL ProjectType is reported as 'Unknown'.
    """
    query = """
        SELECT ISNULL(ProjectType,'Unknown') AS ptype, COUNT(*) AS cnt
        FROM   Project
        WHERE  FlagId = 1
        GROUP  BY ProjectType
        ORDER  BY cnt DESC
    """
    mix = _run(db_name, query)
    return mix

def get_loi_distribution(db_name: str) -> list:
    """
    Bucket projects with FlagId = 1 and a positive AverageLOI into 5-minute bands.

    Bands are 0-5, 6-10, 11-15, 16-20 and 20+ minutes. Rows are ordered by the
    smallest LOI in each bucket, i.e. shortest band first.
    """
    query = """
        SELECT
          CASE
            WHEN AverageLOI <= 5  THEN '0-5 min'
            WHEN AverageLOI <= 10 THEN '6-10 min'
            WHEN AverageLOI <= 15 THEN '11-15 min'
            WHEN AverageLOI <= 20 THEN '16-20 min'
            ELSE '20+ min'
          END AS bucket,
          COUNT(*) AS cnt
        FROM Project
        WHERE FlagId = 1 AND AverageLOI > 0
        GROUP BY
          CASE
            WHEN AverageLOI <= 5  THEN '0-5 min'
            WHEN AverageLOI <= 10 THEN '6-10 min'
            WHEN AverageLOI <= 15 THEN '11-15 min'
            WHEN AverageLOI <= 20 THEN '16-20 min'
            ELSE '20+ min'
          END
        ORDER BY MIN(AverageLOI)
    """
    buckets = _run(db_name, query)
    return buckets

def get_recent_projects(db_name: str) -> list:
    """
    Return the 10 newest non-'G' projects with FlagId = 1.

    Each row's CreateDate, when present, is converted to a string and
    truncated to its first 10 characters.
    """
    query = """
        SELECT TOP 10
               ProjectCode, ProjectName,
               ProjectType, IR, AverageLOI, ProjectCPI,
               SampleSize, CreateDate
        FROM   Project
        WHERE  FlagId = 1 AND ProjectType <> 'G'
        ORDER  BY CreateDate DESC
    """
    projects = _run(db_name, query)
    for project in projects:
        created = project.get("CreateDate")
        if not created:
            continue
        project["CreateDate"] = str(created)[:10]
    return projects


# ─────────────────────────────────────────────────────────────────────────────
# LLM INSIGHT GENERATOR  (calls the OpenAI chat-completions REST endpoint via urllib; no SDK)
# ─────────────────────────────────────────────────────────────────────────────

LLM_MODEL: str = "gpt-4o-mini"  # OpenAI chat-completions model name


import urllib.request  # already in stdlib, no install needed

def generate_chart_insight(chart_name: str, data, kpis: dict = None, model: str = None) -> str:
    """
    Ask the OpenAI chat-completions endpoint for a short insight about one chart.

    The request is sent with urllib directly. The API key is read from the
    OPENAI_API_KEY environment variable.

    Args:
        chart_name: Chart title passed to the model as context.
        data: Chart rows. Lists are truncated to the first 20 items to keep
            the prompt small; any other value is sent as-is.
        kpis: Optional KPI dict added to the prompt as extra context.
        model: Chat model name. When omitted, the module-level LLM_MODEL is
            used, so that constant actually controls the model.

    Returns:
        The model's reply with surrounding whitespace stripped. Returns ""
        when no API key is configured, or when the request or response
        parsing fails (the traceback is printed in that case).
    """
    api_key = os.getenv("OPENAI_API_KEY", "")
    if not api_key:
        # An empty bearer token cannot authenticate, so skip the network call.
        return ""

    try:
        payload = {
            "chart": chart_name,
            "data":  data[:20] if isinstance(data, list) else data,
        }
        if kpis:
            payload["kpis"] = kpis

        system = (
            "You are a senior market research analyst. "
            "Given the chart name and data, write 2-3 concise, business-focused sentences "
            "highlighting the most important trend, anomaly, or opportunity. "
            "Use plain language. No bullet points. No markdown. Keep under 60 words."
        )

        # DB rows may carry Decimal/date values that json cannot encode
        # natively. Stringify them instead of failing and losing the insight.
        user_content = json.dumps(payload, default=str)

        request_body = json.dumps({
            "model": model or LLM_MODEL,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user",   "content": user_content},
            ],
            "max_tokens": 120,
            "temperature": 0.4,
        }).encode("utf-8")

        req = urllib.request.Request(
            "https://api.openai.com/v1/chat/completions",
            data=request_body,
            headers={
                "Content-Type":  "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            method="POST"
        )
        with urllib.request.urlopen(req, timeout=15) as resp:
            result = json.loads(resp.read().decode("utf-8"))
            return result["choices"][0]["message"]["content"].strip()

    except Exception:
        traceback.print_exc()
        return ""


# ─────────────────────────────────────────────────────────────────────────────
# PUBLIC: get_dashboard_data  — with per-client caching (TTL: 10 min)
# ─────────────────────────────────────────────────────────────────────────────

def get_dashboard_data(db_name: str, force: bool = False) -> dict:
    """
    Return all chart and KPI data for one client database.

    Caching:
      - The result is cached in memory per db_name for DATA_TTL seconds.
      - force=True skips the cache lookup (e.g. on a manual refresh); the
        fresh result is still stored.
      - The returned dict is the cached object itself, so callers should
        treat it as read-only.

    All UserSurvey-based metrics come from get_survey_stats_combined. It
    replaces four separate UserSurvey queries with one aggregate query plus
    two top-N queries.

    Args:
        db_name: Client database to query.
        force: Bypass the cache lookup when True.
    """
    if not force:
        cached = _cache_get(_data_cache, db_name, DATA_TTL)
        if cached is not None:
            return cached

    # One task per independent DB fetch.
    fetchers = {
        "kpis":       get_kpi_cards,
        "survey":     get_survey_stats_combined,
        "monthly":    get_monthly_trend,
        "status_pie": get_project_status_breakdown,
        "cpi_ir":     get_cpi_ir_scatter,
        "ptype_mix":  get_project_type_mix,
        "loi_dist":   get_loi_distribution,
        "recent":     get_recent_projects,
    }
    # One worker per fetch, so no query sits queued behind another.
    with ThreadPoolExecutor(max_workers=len(fetchers)) as ex:
        futures = {name: ex.submit(fn, db_name) for name, fn in fetchers.items()}
        results = {name: future.result() for name, future in futures.items()}

    kpis = results["kpis"]
    survey = results["survey"]

    # get_kpi_cards leaves total_completes at 0; take it from the survey stats.
    kpis["total_completes"] = survey.get("total_completes", 0)

    result = {
        "kpis":            kpis,
        "monthly":         results["monthly"],
        "status_pie":      results["status_pie"],
        "countries":       survey.get("countries",      []),
        "quality_funnel":  survey.get("quality_funnel", []),
        "device_split":    survey.get("device_split",   []),
        "top_suppliers":   survey.get("top_suppliers",  []),
        "cpi_ir":          results["cpi_ir"],
        "ptype_mix":       results["ptype_mix"],
        "loi_dist":        results["loi_dist"],
        "recent_projects": results["recent"],
    }

    _cache_set(_data_cache, db_name, result)
    return result


# ─────────────────────────────────────────────────────────────────────────────
# PUBLIC: get_dashboard_insights  — with per-client caching (TTL: 30 min)
# ─────────────────────────────────────────────────────────────────────────────

def _insight_cache_key(chart_data: dict, client_db: str) -> str:
    """
    Return the insight-cache key ``"<client_db>:<fingerprint>"`` for chart_data.

    The fingerprint is a short MD5 digest (not used for security) of three
    parts of chart_data:
      - the KPI cards;
      - the three most recent monthly rows;
      - the first three status rows.
    When any of those change, the key changes and stale insights are bypassed.
    """
    fingerprint_src = json.dumps(
        {
            # Monthly rows arrive oldest-first (get_monthly_trend reverses its
            # DESC query), so the tail holds the latest, still-changing months.
            "monthly": chart_data.get("monthly",    [])[-3:],
            "kpis":    chart_data.get("kpis",       {}),
            "status":  chart_data.get("status_pie", [])[:3],
        },
        default=str,
    )
    fingerprint = hashlib.md5(fingerprint_src.encode()).hexdigest()[:12]
    return f"{client_db}:{fingerprint}"


def _build_insight_tasks(chart_data: dict) -> dict:
    """Map each chart key to its (title, rows, kpis-or-None) arguments for generate_chart_insight."""
    kpis = chart_data.get("kpis", {})
    return {
        "monthly":        ("Monthly Clicks & Completes Trend", chart_data.get("monthly",        []), kpis),
        "status_pie":     ("Project Status Breakdown",         chart_data.get("status_pie",     []), None),
        "countries":      ("Top Countries by Completes",       chart_data.get("countries",      []), None),
        "quality_funnel": ("Survey Quality & Fraud Funnel",    chart_data.get("quality_funnel", []), None),
        "device_split":   ("Device Type Split",                chart_data.get("device_split",   []), None),
        "top_suppliers":  ("Top Suppliers by Completes",       chart_data.get("top_suppliers",  []), None),
        "cpi_ir":         ("CPI vs IR Scatter",                chart_data.get("cpi_ir",         []), None),
        "ptype_mix":      ("Project Type Mix",                 chart_data.get("ptype_mix",      []), None),
        "loi_dist":       ("LOI Distribution",                 chart_data.get("loi_dist",       []), None),
    }


def _iter_insights(tasks: dict):
    """
    Run generate_chart_insight for every task in parallel.

    Yields ``(chart_key, text)`` pairs in completion order. A task that
    raises yields "" after its traceback is printed.
    """
    with ThreadPoolExecutor(max_workers=max(len(tasks), 1)) as ex:
        futures = {
            ex.submit(generate_chart_insight, *args): key
            for key, args in tasks.items()
        }
        for future in as_completed(futures):
            key = futures[future]
            try:
                text = future.result() or ""
            except Exception:
                traceback.print_exc()
                text = ""
            yield key, text


def get_dashboard_insights(chart_data: dict, client_db: str = "", force: bool = False) -> dict:
    """
    Generate one LLM insight per dashboard chart.

    Args:
        chart_data: Chart payload, presumably the dict returned by
            get_dashboard_data (same keys are read).
        client_db: Client identifier used to namespace the cache key.
        force: Skip the cache lookup and always call the LLM.

    Returns:
        Dict mapping each chart key (e.g. "monthly", "countries") to its
        insight text, in chart order. A chart whose LLM call failed maps
        to "".

    Caching:
      - Results are cached for INSIGHT_TTL seconds under a key that includes
        a fingerprint of the data.
      - A result set containing any empty insight is not cached, so a
        transient LLM failure is retried on the next call instead of being
        served for INSIGHT_TTL.
    """
    cache_key = _insight_cache_key(chart_data, client_db)

    if not force:
        cached = _cache_get(_insight_cache, cache_key, INSIGHT_TTL)
        if cached is not None:
            return cached

    tasks = _build_insight_tasks(chart_data)
    finished = dict(_iter_insights(tasks))
    # Re-key in chart order so the result does not depend on which LLM call
    # finished first.
    insights = {key: finished.get(key, "") for key in tasks}

    # generate_chart_insight returns "" on failure; only cache a complete set.
    if all(insights.values()):
        _cache_set(_insight_cache, cache_key, insights)
    return insights


# ─────────────────────────────────────────────────────────────────────────────
# PUBLIC: stream_dashboard_insights  — SSE generator
# Yields each insight as soon as LLM responds (no waiting for all 9).
# Use this in a Flask Response(stream_with_context(...)) route.
# ─────────────────────────────────────────────────────────────────────────────

def stream_dashboard_insights(chart_data: dict, client_db: str = "", force: bool = False):
    """
    Generator that yields Server-Sent-Event strings, one per chart insight.

    Each event looks like ``data: {"key": "<chart_key>", "text": "<insight>",
    "cached": <bool>}``. The stream ends with ``data: {"done": true}``.

    Cache handling:
      - On a fresh cache hit, the stored insights are streamed immediately
        with no LLM calls.
      - Otherwise the LLM is called in parallel and each insight is yielded
        as it arrives.
      - As in get_dashboard_insights, the set is cached only when every
        insight is non-empty.

    Usage in a Flask route:
        return Response(stream_with_context(
            stream_dashboard_insights(chart_data, client_db=client_db)
        ), mimetype="text/event-stream")
    """
    cache_key = _insight_cache_key(chart_data, client_db)

    # ── Cache hit: stream the stored insights right away ──────────────────
    if not force:
        cached = _cache_get(_insight_cache, cache_key, INSIGHT_TTL)
        if cached is not None:
            for key, text in cached.items():
                yield f"data: {json.dumps({'key': key, 'text': text, 'cached': True})}\n\n"
            yield "data: {\"done\": true}\n\n"
            return

    # ── Cache miss: call the LLM in parallel, yield each reply as it lands ─
    tasks = _build_insight_tasks(chart_data)
    collected = {}
    for key, text in _iter_insights(tasks):
        collected[key] = text
        yield f"data: {json.dumps({'key': key, 'text': text, 'cached': False})}\n\n"

    insights = {key: collected.get(key, "") for key in tasks}
    # generate_chart_insight returns "" on failure; only cache a complete set.
    if all(insights.values()):
        _cache_set(_insight_cache, cache_key, insights)
    yield "data: {\"done\": true}\n\n"