import plotly.graph_objects as go
from datasets import load_dataset

# Load the dataset and keep only its "test" split.
dataset = load_dataset("preference-agents-experiments/kay-mann-20-split")
dataset = dataset["test"]

# Score columns to compare; each one becomes its own histogram trace.
columns = [
    "baseline_generation_llmeval_8b",
    "gold_rule_email_llmeval_8b",
    "naiveft_baseline_llmeval_8b",
    "rulegen_rules_generated_email_llmeval_8b",
]

# Trace colors, paired positionally with `columns`.
colors = ["green", "purple", "yellow", "blue"]

# Scores are plotted on a 0-5 scale.
# NOTE(review): assumes the score columns hold integer values -- confirm.
score_min, score_max = 0, 5

# Bin edges are placed on half-integers so that every integer score gets its
# own unit-wide bin centered on its tick. With edges on the integers
# (start=0, end=5) the bin for score k spans [k, k+1): bars are drawn to the
# right of their tick and the top score has no bin at all.
half_bin = 0.5
bin_start = score_min - half_bin
bin_end = score_max + half_bin

fig = go.Figure()

for column, color in zip(columns, colors):
    # One semi-transparent histogram per score column.
    fig.add_trace(
        go.Histogram(
            x=dataset[column],
            xbins=dict(start=bin_start, end=bin_end, size=1),
            marker=dict(color=color),
            name=column,
            opacity=0.5,
        )
    )

# Axis range matches the bin edges so the outermost bars are not clipped;
# ticks stay on the integer score values.
fig.update_xaxes(
    title_text="Score",
    range=[bin_start, bin_end],
    tickvals=list(range(score_min, score_max + 1)),
)
fig.update_yaxes(title_text="Frequency")

fig.update_layout(
    title_text="Histograms of Scores",
    height=400,
    width=600,
    bargap=0.2,  # gap between bars of adjacent location coordinates
    bargroupgap=0.1,  # gap between bars of the same location coordinates
)

fig.show()
# Save the figure as a static image in the current working directory.
fig.write_image("histogram.png")