import plotly.graph_objects as go
from datasets import load_dataset
import numpy as np
from scipy.stats import gaussian_kde

# Load the dataset and keep only its "test" split; every later step reads
# score columns from this split.
dataset = load_dataset("preference-agents-experiments/kay-mann-20-split")["test"]

# Score columns to plot: every generation variant, first for the "8b"
# evaluator suffix and then for the "70b" one (order matters, because each
# column is paired positionally with an entry of `colors`).
columns = [
    f"{prefix}_llmeval_{size}"
    for size in ("8b", "70b")
    for prefix in (
        "baseline_generation",
        "gold_rule_email",
        "naiveft_baseline",
        "rulegen_rules_generated_email",
    )
]

# One line colour per entry of `columns`, matched by index.
colors = [
    "green",
    "purple",
    "yellow",
    "blue",
    "red",
    "orange",
    "brown",
    "pink",
]

fig = go.Figure()

# Evaluation grid shared by every curve. It does not depend on the column,
# so it is built once instead of on every loop iteration.
x = np.linspace(0, 5, 500)  # Generate X values from 0 to 5

# Generate the KDE and plot the distributions: one line trace per score column
for i, column in enumerate(columns):
    data = dataset[column]
    # NOTE(review): missing scores are replaced with 0, which adds density at
    # x=0 to the curve. Confirm this is intended rather than dropping them.
    data = np.array([0 if v is None else v for v in data])  # Replace None values with 0
    print(f"Column: {column}, Length: {len(data)}")
    print(data[:10])

    # gaussian_kde cannot be fitted to fewer than two distinct values (the
    # sample covariance is singular) and raises, which would abort the whole
    # figure. Skip such a column so the remaining curves are still drawn.
    if np.unique(data).size < 2:
        print(f"Skipping column {column}: need at least two distinct values for a KDE")
        continue

    # Gaussian KDE evaluated on the shared grid
    kde = gaussian_kde(data)
    y = kde(x)

    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            mode='lines',
            line=dict(color=colors[i]),
            name=column
        )
    )

# Figure-level appearance: title and a fixed 600x400 canvas.
fig.update_layout(
    title_text="Distribution Curves of Scores",
    width=600,
    height=400,
)

# Axis appearance: the X axis is pinned to 0..5 with one tick per integer.
fig.update_yaxes(title_text="Density")
fig.update_xaxes(
    title_text="Score",
    range=[0, 5],
    tickvals=list(range(6)),
)

# Display the figure first, then save it as a PNG in the working directory.
fig.show()
fig.write_image("distribution_curves.png")
