# from datasets import load_dataset
# from collections import defaultdict
# DATASET = "preference-agents-experiments/enron-standardized-jeff-dasovich-20-split-llmeval"
# COLUMN = "llmwinrate_70b_raw_cleaned"

# mapper = {
#     "1" : "large_baseline",
#     "2" : "large_w_70b_rules",
#     "3" : "large_w_8b_rules",
#     "4" : "large_w_nob_rules",
#     "5" : "naive_ft_baseline"
# }
# n = 0
# counter = defaultdict(int)

# def get_win_percentage(data):
#     for row in data:
#         try:
#             winner = int(row)
#             counter[winner] += 1
#             n += 1
#         except Exception as e:
#             print("Hit a Exception:", e)
#     # convert to %
#     for key in counter:
#         counter[key] = counter[key] / n * 100

# dataset = load_dataset(DATASET)
# get_win_percentage(dataset["train"][COLUMN])
# print(counter)

from datasets import load_dataset
import matplotlib.pyplot as plt

# Load the dataset by name through `datasets.load_dataset`
dataset = load_dataset("preference-agents-experiments/enron-standardized-jeff-dasovich-20-split-llmeval")

# Map each value found in the column (compared as a string) to a readable label
mapper = {
    "1": "large_baseline",
    "2": "large_w_70b_rules",
    "3": "large_w_8b_rules",
    "4": "large_w_nob_rules",
    "5": "naive_ft_baseline",
}

# Aggregate the column from both splits: 'test' rows first, then 'train' rows.
# Each split's column is materialized with list() before concatenating because
# newer `datasets` releases return a lazy column object for
# dataset[split][column] instead of a plain list, and that object cannot be
# joined with `+`. list() is a harmless copy on older releases.
column_data = (
    list(dataset['test']['llmwinrate_70b_raw_cleaned'])
    + list(dataset['train']['llmwinrate_70b_raw_cleaned'])
)

# Tally how often each mapped label occurs. Null entries are skipped entirely;
# non-null entries without a mapper key still count toward the total, so the
# percentages are relative to all non-null rows.
counts = {}
total_count = 0

for raw_value in column_data:
    if raw_value is None:
        continue
    total_count += 1
    label = mapper.get(str(raw_value))
    if label is None:
        continue
    counts[label] = counts.get(label, 0) + 1

# Convert the tallies into percentages of the non-null total
percentages = {}
for label, count in counts.items():
    percentages[label] = (count / total_count) * 100

# Order the labels by percentage, smallest first (ascending)
ranked_pairs = sorted(percentages.items(), key=lambda pair: pair[1])
sorted_percentages = dict(ranked_pairs)
print(sorted_percentages)
# Draw the percentages as a horizontal bar chart, one bar per label
labels = list(sorted_percentages)
values = [sorted_percentages[name] for name in labels]

fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(labels, values)
ax.set_xlabel('Percentage')
ax.set_ylabel('Label')
ax.set_title('Percentage of Occurrences in llmwinrate_70b_raw_cleaned (Aggregated, Null Values Ignored)')

# Annotate each bar with its value, placed just past the end of the bar
for row, pct in enumerate(values):
    ax.text(pct + 0.1, row, f'{pct:.2f}%', va='center')

# Save the figure to disk before opening the interactive window
fig.tight_layout()
fig.savefig('llmwinrate_70b_raw_cleaned.png')
plt.show()