#!/usr/bin/env python3
"""
Function-word bigram analysis with length-matched bootstrapping.
Baseline: Neville letters XML. Compare to plays 1590-1615 (with H8 sections).
"""

import csv
import re
import sqlite3
from collections import Counter, defaultdict

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# SQLite database opened by main(); queried through its `plays` and `words` tables.
DB_PLAYS = "../early_modern_plays.db"
# File scanned for lemma="..." attributes to obtain the Neville baseline tokens.
NEVILLE_XML = "../Neville_Letters_Corpus_v3.xml"

# CSV result files written by main().
OUTPUT_FILE = "Neville_Bigram_Functionwords_Bootstrap_1590_1615.csv"
OUTPUT_SORTED = "Neville_Bigram_Functionwords_Bootstrap_1590_1615_Sorted.csv"

# Inclusive CREATION_YEAR range of the plays that get scored.
YEAR_START = 1590
YEAR_END = 1615
# Number of most-frequent corpus tokens kept as the working vocabulary.
MFW_COUNT = 200
# Minimum number of in-vocabulary tokens a text needs: plays below it are
# skipped, and main() stops early if the Neville text is below it.
MIN_FUNC_TOKENS = 5000
# Length, in in-vocabulary tokens, of each sampled window.
WINDOW_SIZE = 10000
# Number of random windows drawn from the Neville text and from each scored text.
NEVILLE_WINDOWS = 50
PLAY_WINDOWS = 50
# Seed for the NumPy random generator that picks window start positions.
RANDOM_SEED = 123

# PLAY_ID and DIVISION_ID values main() passes to get_division_tokens() for the
# rows it labels "Henry VIII [Shakespeare Section]" / "Henry VIII [Fletcher Section]".
HENRY_VIII_PLAY_ID = 502
SHAKESPEARE_DIV = 54
FLETCHER_DIV = 109


def has_alpha(token):
    """Report whether *token* contains at least one ASCII letter.

    Only characters in the ranges A-Z and a-z count; digits, punctuation and
    non-ASCII letters do not.

    Returns:
        bool: True if a letter was found, otherwise False.
    """
    first_letter = re.search(r"[A-Za-z]", token)
    return first_letter is not None


def extract_neville_lemmas(path):
    """Return the normalized ``lemma`` attribute values found in the file at *path*.

    The file is scanned as raw text (it is not parsed as XML) for
    ``lemma="..."`` / ``lemma='...'`` attributes, matched case-insensitively.
    Values without any ASCII letter are dropped; the remaining ones are
    lower-cased and stripped of surrounding whitespace, in document order.

    Args:
        path: Path of the file to read. It is decoded as UTF-8 and bytes that
            cannot be decoded are silently skipped.

    Returns:
        list[str]: The lemma strings, one entry per matching attribute.
    """
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        xml = f.read()
    # The closing quote must be the same character as the opening one (\1), so
    # a value such as lemma="o'er" is captured whole instead of being cut off
    # at the apostrophe. DOTALL keeps multi-line values matchable.
    pattern = re.compile(r'lemma=(["\'])(.*?)\1', flags=re.IGNORECASE | re.DOTALL)
    values = (match.group(2) for match in pattern.finditer(xml))
    return [value.lower().strip() for value in values if value and has_alpha(value)]


def get_corpus_mfw(cursor, top_n):
    """Return the set of the most frequent letter-bearing tokens in the corpus.

    Frequencies are counted over the lower-cased ``A0`` column of ``words``
    for plays whose ``CREATION_YEAR`` lies between 1580 and 1620. The query
    fetches ``top_n * 5`` candidates in descending frequency order so that
    enough remain after tokens without an ASCII letter are discarded.

    Args:
        cursor: Database cursor used to run the query.
        top_n: Number of tokens wanted.

    Returns:
        set[str]: The selected tokens (fewer than ``top_n`` if the candidate
        list runs out).
    """
    query = """
        SELECT LOWER(w.A0) AS token, COUNT(*) as c
        FROM words w
        JOIN plays p ON w.PLAY_ID = p.PLAY_ID
        WHERE p.CREATION_YEAR BETWEEN 1580 AND 1620
        AND w.A0 IS NOT NULL
        GROUP BY token
        ORDER BY c DESC
        LIMIT ?
        """
    candidate_limit = top_n * 5
    cursor.execute(query, (candidate_limit,))

    selected = []
    for candidate, _frequency in cursor.fetchall():
        keep = bool(candidate) and has_alpha(candidate)
        if keep:
            selected.append(candidate)
        # Checked after every row, kept or not, exactly like the cut-off rule.
        if len(selected) >= top_n:
            break
    return set(selected)


def get_play_metadata(cursor):
    """Fetch the plays whose creation year falls in the configured range.

    Args:
        cursor: Database cursor used to run the query.

    Returns:
        list: One ``(PLAY_ID, TITLE, CREATION_YEAR)`` row per play with
        ``CREATION_YEAR`` between ``YEAR_START`` and ``YEAR_END`` inclusive.
    """
    year_bounds = (YEAR_START, YEAR_END)
    query = """
        SELECT PLAY_ID, TITLE, CREATION_YEAR
        FROM plays
        WHERE CREATION_YEAR BETWEEN ? AND ?
        """
    cursor.execute(query, year_bounds)
    rows = cursor.fetchall()
    return rows


def get_play_tokens(cursor, play_ids, vocab):
    """Collect, per play, the in-vocabulary tokens in ``TWN`` order.

    Args:
        cursor: Database cursor used to run the query.
        play_ids: Play identifiers to load; they are interpolated into the
            SQL ``IN (...)`` list.
        vocab: Container of lower-case tokens to keep.

    Returns:
        dict: Maps ``PLAY_ID`` to its list of lower-cased tokens. An empty
        dict is returned when ``play_ids`` is empty; plays with no retained
        token have no entry.
    """
    if not play_ids:
        return {}

    ids_str = ",".join(str(play_id) for play_id in play_ids)
    cursor.execute(
        f"SELECT PLAY_ID, A0 FROM words WHERE PLAY_ID IN ({ids_str}) AND A0 IS NOT NULL ORDER BY PLAY_ID, TWN"
    )

    per_play = defaultdict(list)
    for play_id, raw in cursor:
        # Skip empty values and values without any ASCII letter.
        if not raw or not has_alpha(raw):
            continue
        lowered = raw.lower()
        if lowered in vocab:
            per_play[play_id].append(lowered)
    return per_play


def get_division_tokens(cursor, play_id, division_id, vocab):
    """Return the in-vocabulary tokens of one division of one play.

    Args:
        cursor: Database cursor used to run the query.
        play_id: Value matched against ``words.PLAY_ID``.
        division_id: Value matched against ``words.DIVISION_ID``.
        vocab: Container of lower-case tokens to keep.

    Returns:
        list[str]: Lower-cased tokens in ``TWN`` order, restricted to values
        that contain an ASCII letter and belong to ``vocab``.
    """
    query = """
        SELECT A0 FROM words
        WHERE PLAY_ID = ? AND DIVISION_ID = ? AND A0 IS NOT NULL
        ORDER BY TWN
        """
    cursor.execute(query, (play_id, division_id))

    # First normalise the usable rows, then keep only vocabulary members.
    lowered = (raw.lower() for (raw,) in cursor if raw and has_alpha(raw))
    return [token for token in lowered if token in vocab]


def build_bigram_counts(tokens):
    """Count every pair of adjacent tokens in *tokens*.

    Args:
        tokens: Sequence of tokens.

    Returns:
        Counter: Maps ``(token, next_token)`` tuples to their number of
        occurrences; empty when there are fewer than two tokens.
    """
    # Pairing the sequence with itself shifted by one yields adjacent pairs.
    adjacent_pairs = zip(tokens, tokens[1:])
    return Counter(adjacent_pairs)


def vectorize(counts, bigram_index):
    """Turn bigram counts into a relative-frequency vector.

    Args:
        counts: Mapping of bigram to count.
        bigram_index: Mapping of bigram to its position in the vector.

    Returns:
        numpy.ndarray: Float vector of length ``len(bigram_index)``. Each
        indexed bigram holds its count divided by the sum of all counts
        (including bigrams missing from the index); other positions are 0.
    """
    denominator = sum(counts.values())
    if not denominator:
        # Avoid dividing by zero when there is nothing to count.
        denominator = 1

    vector = np.zeros(len(bigram_index), dtype=float)
    for bigram, frequency in counts.items():
        position = bigram_index.get(bigram)
        if position is None:
            continue
        vector[position] = frequency / denominator
    return vector


def sample_windows(tokens, window_size, n_windows, rng):
    """Draw random fixed-length windows from *tokens*.

    Args:
        tokens: Sequence to sample from.
        window_size: Number of tokens per window.
        n_windows: Number of windows to draw.
        rng: Generator providing ``integers(low, high, size=...)``.

    Returns:
        list: ``n_windows`` slices of ``window_size`` tokens whose start
        positions are drawn independently (windows may overlap or repeat).
        If ``tokens`` is not longer than ``window_size``, a one-element list
        holding ``tokens`` itself is returned and ``rng`` is not used.
    """
    last_start = len(tokens) - window_size
    if last_start <= 0:
        return [tokens]

    # The upper bound is exclusive, hence + 1 to allow the final position.
    offsets = rng.integers(0, last_start + 1, size=n_windows)
    return [tokens[offset : offset + window_size] for offset in offsets]


def _write_results_csv(path, rows):
    """Write the ranked result rows to *path* as a CSV file with a header line."""
    fieldnames = [
        "Rank",
        "PLAY_ID",
        "Title",
        "Year",
        "Func_Tokens",
        "Window_Size",
        "Windows",
        "Mean_Sim",
        "Std_Sim",
    ]
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main():
    """Score each play against the Neville baseline and write the rankings.

    Steps:
      1. Load the vocabulary, play metadata and token streams from the
         database, and the Neville lemmas from the XML file.
      2. Build one bigram index from Neville and every text that has at least
         ``MIN_FUNC_TOKENS`` in-vocabulary tokens.
      3. Average the bigram vectors of random Neville windows into a centroid.
      4. For every sufficiently long play (and the two Henry VIII sections),
         compute the cosine similarity between the centroid and each random
         window, and record the mean and standard deviation.
      5. Sort by mean similarity, write both CSV files and print the top 20.

    Returns early, after printing a message, if the Neville text has fewer
    than ``MIN_FUNC_TOKENS`` in-vocabulary tokens.
    """
    rng = np.random.default_rng(RANDOM_SEED)

    conn = sqlite3.connect(DB_PLAYS)
    # try/finally guarantees the connection is released on the early return
    # below and on any exception raised while loading data.
    try:
        cursor = conn.cursor()

        vocab = get_corpus_mfw(cursor, MFW_COUNT)

        plays_meta = get_play_metadata(cursor)
        play_ids = [p[0] for p in plays_meta]
        titles = {p[0]: p[1] for p in plays_meta}
        years = {p[0]: p[2] for p in plays_meta}

        neville_tokens_all = [t for t in extract_neville_lemmas(NEVILLE_XML) if t in vocab]
        if len(neville_tokens_all) < MIN_FUNC_TOKENS:
            print("Neville has too few function-word tokens.")
            return

        play_tokens = get_play_tokens(cursor, play_ids, vocab)

        # Henry VIII sections
        h8_shakespeare_tokens = get_division_tokens(cursor, HENRY_VIII_PLAY_ID, SHAKESPEARE_DIV, vocab)
        h8_fletcher_tokens = get_division_tokens(cursor, HENRY_VIII_PLAY_ID, FLETCHER_DIV, vocab)
    finally:
        conn.close()

    # Build bigram vocabulary from all plays + Neville + H8 sections
    bigram_vocab = Counter()
    bigram_vocab.update(build_bigram_counts(neville_tokens_all))
    for toks in play_tokens.values():
        if len(toks) >= MIN_FUNC_TOKENS:
            bigram_vocab.update(build_bigram_counts(toks))
    if len(h8_shakespeare_tokens) >= MIN_FUNC_TOKENS:
        bigram_vocab.update(build_bigram_counts(h8_shakespeare_tokens))
    if len(h8_fletcher_tokens) >= MIN_FUNC_TOKENS:
        bigram_vocab.update(build_bigram_counts(h8_fletcher_tokens))

    # Vector positions follow descending overall bigram frequency.
    bigram_index = {bg: i for i, (bg, _c) in enumerate(bigram_vocab.most_common())}

    # Neville windows and centroid
    neville_windows = sample_windows(neville_tokens_all, WINDOW_SIZE, NEVILLE_WINDOWS, rng)
    neville_vecs = [vectorize(build_bigram_counts(win), bigram_index) for win in neville_windows]
    neville_centroid = np.mean(neville_vecs, axis=0)
    centroid_row = neville_centroid.reshape(1, -1)

    results = []

    def score_tokens(name_id, title, year, tokens):
        """Append one result row for *tokens*; texts that are too short are skipped."""
        if len(tokens) < MIN_FUNC_TOKENS:
            return
        windows = sample_windows(tokens, WINDOW_SIZE, PLAY_WINDOWS, rng)
        sims = []
        for win in windows:
            vec = vectorize(build_bigram_counts(win), bigram_index)
            sims.append(float(cosine_similarity(centroid_row, vec.reshape(1, -1))[0][0]))
        results.append(
            {
                "PLAY_ID": name_id,
                "Title": title,
                "Year": year,
                "Func_Tokens": len(tokens),
                "Window_Size": min(WINDOW_SIZE, len(tokens)),
                "Windows": len(windows),
                "Mean_Sim": float(np.mean(sims)),
                "Std_Sim": float(np.std(sims)),
            }
        )

    # The scoring order is fixed (plays, then the two H8 sections) because
    # every call consumes values from the shared random generator.
    for pid in play_ids:
        score_tokens(pid, titles.get(pid, ""), years.get(pid), play_tokens.get(pid, []))

    score_tokens("502-S", "Henry VIII [Shakespeare Section]", 1613, h8_shakespeare_tokens)
    score_tokens("502-F", "Henry VIII [Fletcher Section]", 1613, h8_fletcher_tokens)

    results.sort(key=lambda r: r["Mean_Sim"], reverse=True)
    for rank, row in enumerate(results, 1):
        row["Rank"] = rank

    # NOTE(review): both files receive the same rows in the same (sorted)
    # order; confirm whether OUTPUT_FILE was meant to use a different order.
    _write_results_csv(OUTPUT_FILE, results)
    _write_results_csv(OUTPUT_SORTED, results)

    print(f"✓ Wrote {OUTPUT_FILE}")
    print(f"✓ Wrote {OUTPUT_SORTED}")
    print("Top 20 (mean similarity):")
    for i, r in enumerate(results[:20], 1):
        print(f"{i}. {r['Title']} ({r['PLAY_ID']}) {r['Mean_Sim']:.4f} ± {r['Std_Sim']:.4f}")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
