#!/usr/bin/env python3
"""
TEST 1: ALL BIGRAM SIMILARITY
Computes cosine similarity between Neville's 1600 Confession and every play
dated 1590-1615 that has at least 1,000 tokens.
Uses ALL bigrams (not filtered for rarity).

Output: Ranking of plays by similarity to the Confession.
"""

import csv
import re
import sqlite3
from collections import Counter, defaultdict
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Paths - adjust if needed (relative paths resolve against the current working directory)
DB_PLAYS = "early_modern_plays.db"  # SQLite database queried for the `plays` and `words` tables
CONFESSION_XML = "Neville_Confession_1600.xml"  # XML file whose lemma attributes are extracted
OUTPUT_FILE = "results_all_bigrams.csv"  # CSV destination for the full similarity ranking


def has_alpha(token):
    """Return True when *token* contains at least one ASCII letter (A-Z or a-z).

    Non-ASCII letters (e.g. accented characters) do not count.
    """
    letter_match = re.search(r"[A-Za-z]", token)
    return bool(letter_match)


def extract_xml_lemmas(path):
    """Return the normalised ``lemma`` attribute values found in an XML file.

    The file is scanned with a regular expression rather than parsed as XML,
    so every ``lemma="..."`` / ``lemma='...'`` attribute (case-insensitive
    attribute name) is collected in document order.  Each value is
    entity-decoded (``&amp;``, ``&apos;`` ...), lower-cased and stripped;
    values without at least one ASCII letter are dropped.

    Args:
        path: Path to the XML file.  It is read as UTF-8 and undecodable
            bytes are ignored.

    Returns:
        List of normalised lemma strings.
    """
    import html  # local import: only this function needs entity decoding

    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        xml = f.read()

    # Match each quote style separately so that an apostrophe inside a
    # double-quoted value (e.g. lemma="o'er") does not truncate the value,
    # and allow the whitespace around "=" that XML permits.
    pattern = r"""lemma\s*=\s*(?:"([^"]*)"|'([^']*)')"""
    lemmas = []
    for double_quoted, single_quoted in re.findall(pattern, xml, flags=re.IGNORECASE):
        lemma = html.unescape(double_quoted or single_quoted).lower().strip()
        if has_alpha(lemma):
            lemmas.append(lemma)
    return lemmas


def build_bigrams(tokens):
    """Count adjacent token pairs in *tokens*.

    Args:
        tokens: List of tokens in reading order.

    Returns:
        Counter mapping each ``(tokens[i], tokens[i + 1])`` pair to its number
        of occurrences; empty when there are fewer than two tokens.
    """
    successors = tokens[1:]
    return Counter(zip(tokens, successors))


def _load_plays(cursor, start_year, end_year):
    """Return ``(PLAY_ID, TITLE, CREATION_YEAR)`` rows for plays whose
    CREATION_YEAR lies between *start_year* and *end_year* inclusive,
    ordered by CREATION_YEAR."""
    cursor.execute(
        """
        SELECT PLAY_ID, TITLE, CREATION_YEAR
        FROM plays
        WHERE CREATION_YEAR BETWEEN ? AND ?
        ORDER BY CREATION_YEAR
        """,
        (start_year, end_year),
    )
    return cursor.fetchall()


def _load_play_tokens(cursor, start_year, end_year):
    """Return ``{PLAY_ID: [token, ...]}`` for plays in the year range.

    Tokens come from ``words.A0`` in ``(PLAY_ID, WORD_ID)`` order.  NULLs and
    tokens without an ASCII letter are skipped; the rest are stripped and
    lower-cased, mirroring the normalisation of the Confession lemmas.
    """
    # Restrict the scan to in-range plays instead of streaming every word in
    # the database into memory.  IN (subquery) rather than JOIN so that a
    # duplicated PLAY_ID in `plays` cannot duplicate words.
    cursor.execute(
        """
        SELECT PLAY_ID, A0
        FROM words
        WHERE A0 IS NOT NULL
          AND PLAY_ID IN (
              SELECT PLAY_ID FROM plays WHERE CREATION_YEAR BETWEEN ? AND ?
          )
        ORDER BY PLAY_ID, WORD_ID
        """,
        (start_year, end_year),
    )
    tokens = defaultdict(list)
    for pid, word in cursor:
        if word and has_alpha(word):
            tokens[pid].append(word.strip().lower())
    return tokens


def _sparse_cosine(a, b):
    """Cosine similarity of two sparse count vectors given as mappings.

    Equivalent to densifying both over the union of their keys and taking
    the normalised dot product, but only touches non-zero entries.  Returns
    0.0 when either vector is empty or all-zero.
    """
    norm_a = sum(c * c for c in a.values()) ** 0.5
    norm_b = sum(c * c for c in b.values()) ** 0.5
    if norm_a == 0 or norm_b == 0:
        return 0.0
    # Iterate over the smaller mapping; keys absent from the other contribute 0.
    small, large = (a, b) if len(a) <= len(b) else (b, a)
    dot = sum(count * large.get(key, 0) for key, count in small.items())
    return dot / (norm_a * norm_b)


def _write_results(results, output_file):
    """Write the ranked *results* (list of row dicts) to *output_file* as CSV."""
    fieldnames = ['Rank', 'PLAY_ID', 'Title', 'Year', 'Tokens', 'Bigrams', 'Similarity']
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)


def main(db_path=None, confession_xml=None, output_file=None, *,
         start_year=1590, end_year=1615, min_tokens=1000, top_n=40):
    """Rank plays by all-bigram cosine similarity to Neville's 1600 Confession.

    Extracts lemma bigrams from the Confession XML and loads tokens for every
    play whose CREATION_YEAR is in [start_year, end_year].  Plays with fewer
    than *min_tokens* tokens are skipped.  Each remaining play is scored by
    the cosine similarity of raw bigram counts.  The *top_n* plays are
    printed and the full ranking is written to CSV.

    Args:
        db_path: SQLite database path (defaults to ``DB_PLAYS``).
        confession_xml: Confession XML path (defaults to ``CONFESSION_XML``).
        output_file: CSV output path (defaults to ``OUTPUT_FILE``).
        start_year: First CREATION_YEAR included (inclusive).
        end_year: Last CREATION_YEAR included (inclusive).
        min_tokens: Minimum token count for a play to be scored.
        top_n: Number of top-ranked plays to print.
    """
    # Resolve defaults at call time so module constants are read lazily.
    db_path = DB_PLAYS if db_path is None else db_path
    confession_xml = CONFESSION_XML if confession_xml is None else confession_xml
    output_file = OUTPUT_FILE if output_file is None else output_file

    print("=" * 70)
    print("TEST 1: ALL BIGRAM SIMILARITY")
    print(f"Neville's 1600 Confession vs All Plays ({start_year}-{end_year})")
    print("=" * 70)

    # Extract Confession
    print("\n1. Extracting Confession bigrams...")
    confession_lemmas = extract_xml_lemmas(confession_xml)
    print(f"   Tokens: {len(confession_lemmas)}")
    confession_bigrams = build_bigrams(confession_lemmas)
    print(f"   Unique bigrams: {len(confession_bigrams)}")
    if not confession_bigrams:
        print("   WARNING: no Confession bigrams; every similarity will be 0.")

    # NOTE(review): Confession tokens are XML `lemma` attributes while play
    # tokens come from words.A0 -- confirm A0 also holds lemmas, otherwise
    # the two sides are not directly comparable.
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        plays = _load_plays(cursor, start_year, end_year)
        print(f"\n2. Loading {len(plays)} plays...")
        play_tokens = _load_play_tokens(cursor, start_year, end_year)
    finally:
        conn.close()

    print("3. Computing bigram similarity to each play...")
    play_bigram_counts = {}
    vocab = set(confession_bigrams)  # only needed for the size report below
    for pid, _title, _year in plays:
        tokens = play_tokens.get(pid, [])
        if len(tokens) >= min_tokens:
            play_bigram_counts[pid] = build_bigrams(tokens)
            vocab.update(play_bigram_counts[pid])
    print(f"   Total vocabulary: {len(vocab)} bigrams")
    del vocab  # free the union set before scoring

    # Sparse cosine: no dense vectors over the (potentially huge) union vocabulary.
    results = []
    for pid, title, year in plays:
        bigrams = play_bigram_counts.get(pid)
        if bigrams is None:
            continue
        results.append({
            'PLAY_ID': pid,
            'Title': title,
            'Year': year,
            'Tokens': len(play_tokens[pid]),
            'Bigrams': len(bigrams),
            'Similarity': _sparse_cosine(confession_bigrams, bigrams),
        })

    results.sort(key=lambda r: r['Similarity'], reverse=True)
    for rank, row in enumerate(results, 1):
        row['Rank'] = rank

    print("\n" + "=" * 70)
    print(f"TOP {top_n} PLAYS MOST SIMILAR TO NEVILLE'S 1600 CONFESSION")
    print("=" * 70)
    print(f"{'Rank':<5} {'Year':<6} {'Similarity':<12} {'Title'}")
    print("-" * 70)
    for row in results[:top_n]:
        title = row['Title'] or ''  # guard against NULL titles
        print(f"{row['Rank']:<5} {row['Year']:<6} {row['Similarity']:<12.4f} {title[:45]}")

    print(f"\n4. Writing results to {output_file}...")
    _write_results(results, output_file)

    # Summary
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print("Text analyzed: Neville's 1600 Confession")
    print(f"Confession tokens: {len(confession_lemmas)}")
    print(f"Confession bigrams: {len(confession_bigrams)}")
    print(f"Plays analyzed: {len(results)}")
    if results:
        print(f"Top play: {results[0]['Title']} ({results[0]['Similarity']:.4f})")
    else:
        # Previously an IndexError when no play met the selection criteria.
        print("Top play: n/a (no play met the selection criteria)")

    print("\n✓ Analysis complete!")


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not on import.
    main()
