import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import os
import sys
import json
from scipy.stats import ttest_rel


# Extend the import search path before the project import below.
# `src_root` is the directory three levels above the directory that contains
# this file (four nested dirname() calls on the file's absolute path).
# NOTE(review): presumably this is the directory that holds the
# `collaborative_storm` package -- confirm against the repository layout.
src_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(src_root)

from collaborative_storm.modules.encoder import get_text_embeddings

def load_embeddings_from_file(article_name, output_dir="data/cache"):
    """Return the cached embeddings for *article_name*, or None if absent.

    The cache entry is looked up at ``<output_dir>/<article_name>.npy`` and
    read with ``np.load`` when that path exists.
    """
    cache_path = os.path.join(output_dir, f"{article_name}.npy")
    return np.load(cache_path) if os.path.exists(cache_path) else None
    
def save_embeddings_to_file(article_name, embeddings, output_dir="data/cache"):
    """Persist *embeddings* to ``<output_dir>/<article_name>.npy``.

    Missing directories are created on demand. This includes any
    sub-directory implied by a path separator inside *article_name*
    (for example ``"AC/DC"``), so that ``np.save`` does not fail on a
    non-existent parent directory.

    Args:
        article_name: Cache key; used as the file name without extension.
        embeddings: Array-like object accepted by ``np.save``.
        output_dir: Directory that holds the cache files.
    """
    file_path = os.path.join(output_dir, f"{article_name}.npy")

    # Create the parent of the final path rather than checking existence
    # first: exist_ok=True is race-free and also covers nested article names.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    np.save(file_path, embeddings)


import numpy as np
from sklearn.neighbors import KernelDensity
from scipy.stats import entropy
def calculate_entropy(embeddings):
    """Return the mean row-wise entropy of the cosine-similarity matrix.

    Each row of the pairwise cosine-similarity matrix of *embeddings* is
    divided by its own sum, ``scipy.stats.entropy`` is evaluated on every
    resulting row, and the average of those per-row values is returned.

    NOTE(review): a row containing negative similarities is not a valid
    probability distribution after normalisation -- confirm that callers
    only pass embeddings whose pairwise similarities are non-negative.
    """
    pairwise_sims = cosine_similarity(embeddings)
    # Normalise every row so that it sums to one.
    row_totals = pairwise_sims.sum(axis=1, keepdims=True)
    row_distributions = pairwise_sims / row_totals
    # One entropy value per row, then average over rows.
    per_row_entropy = [entropy(row) for row in row_distributions]
    return np.mean(per_row_entropy)

def _mean_pairwise_diversity(embeddings):
    """Return 1 minus the mean cosine similarity over all distinct row pairs."""
    similarities = cosine_similarity(embeddings)
    # The strict upper triangle (k=1) holds every unordered pair exactly once
    # and leaves out the diagonal of self-similarities.
    pair_similarities = similarities[np.triu_indices_from(similarities, k=1)]
    # Cast to a built-in float: NumPy scalars such as float32 are rejected by
    # json.dump, so storing the raw NumPy value can break serialisation.
    return float(1 - np.mean(pair_similarities))


def _average_scores_by_method(diversity_scores):
    """Return ``{method: mean score}`` over every article that has the method."""
    totals = {}
    counts = {}
    for article_scores in diversity_scores.values():
        for method, score in article_scores.items():
            totals[method] = totals.get(method, 0) + score
            counts[method] = counts.get(method, 0) + 1
    return {method: totals[method] / counts[method] for method in totals}


def _report_paired_ttests(diversity_scores):
    """Print a paired t-test of ``new_method`` against every other method.

    Only articles that contain a ``new_method`` score contribute. A method is
    tested only if it has a score in every one of those articles; otherwise
    the samples cannot be paired one-to-one and a notice is printed instead.
    """
    new_method_scores = []
    method_scores = {}

    for article_scores in diversity_scores.values():
        if "new_method" not in article_scores:
            continue
        new_method_scores.append(article_scores["new_method"])
        for method, score in article_scores.items():
            if method != "new_method":
                method_scores.setdefault(method, []).append(score)

    for method, scores in sorted(method_scores.items()):
        # Equal length means the method is present in every contributing
        # article, so both lists are aligned article by article.
        if len(scores) == len(new_method_scores):
            t_stat, p_value = ttest_rel(new_method_scores, scores)
            print(f"\nPaired t-test results for new_method vs {method}:")
            print(f"t-statistic: {t_stat:.3f}")
            print(f"p-value: {p_value:.3f}")
        else:
            print(f"\nCannot perform t-test for new_method vs {method} due to unequal sample sizes.")


def calculate_diversity_scores():
    """Compute, store and report per-method snippet diversity scores.

    Reads ``snippets_to_grade.json``, a mapping from article name to an
    object with a ``"snippets"`` list and a ``"method_to_index_mapping"``
    dict (method name -> indices into the snippet embeddings). For every
    article the snippet embeddings are taken from the on-disk cache or
    computed with ``get_text_embeddings`` and then cached. The diversity of a
    method is ``1 - mean pairwise cosine similarity`` of its snippets;
    methods with fewer than two snippets are skipped.

    The per-article scores are written to ``diversity_score_output.json``.
    The per-method averages (scaled by 100) and paired t-tests of
    ``new_method`` against each other method are printed to stdout.
    """
    diversity_score_output_path = "diversity_score_output.json"
    # Placeholder passed to get_text_embeddings(api_key=...); it must be
    # replaced by a real key before embeddings can be computed.
    api_key = ...  # FIXME
    evaluation_data_path = "snippets_to_grade.json"

    with open(evaluation_data_path, encoding="utf-8") as f:
        all_snippets = json.load(f)

    diversity_scores = {}

    for article_name, data in tqdm(all_snippets.items()):
        snippets = data["snippets"]
        method_to_index_mapping = data["method_to_index_mapping"]

        # Reuse cached embeddings when available; otherwise compute and cache.
        embeddings = load_embeddings_from_file(article_name=article_name)
        if embeddings is None:
            embeddings, _ = get_text_embeddings(texts=snippets, max_workers=8, api_key=api_key)
            save_embeddings_to_file(article_name=article_name, embeddings=embeddings)

        article_scores = {}
        print(f"article: {article_name}")
        for method, indices in method_to_index_mapping.items():
            if len(indices) < 2:
                # With fewer than 2 snippets there is no pair to compare.
                continue
            article_scores[method] = _mean_pairwise_diversity(embeddings[indices])
        diversity_scores[article_name] = article_scores

    with open(diversity_score_output_path, "w", encoding="utf-8") as f:
        json.dump(diversity_scores, f, indent=2)

    # The in-memory scores are plain floats keyed by strings, i.e. exactly
    # what reading the file back would yield, so they are reported directly.
    method_aggregate_scores = _average_scores_by_method(diversity_scores)
    for method, avg_score in sorted(method_aggregate_scores.items()):
        print(f"{method}: {avg_score * 100:.2f}")

    _report_paired_ttests(diversity_scores)

# Run the evaluation only when this file is executed as a script, so that
# importing the module does not start the whole pipeline as a side effect.
if __name__ == "__main__":
    calculate_diversity_scores()