from scipy import special

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
#import seaborn as sns

import os
import re

# sns.set_theme(style="darkgrid")
# tips = sns.load_dataset("tips")
# sns.relplot(x="total_bill", y="tip", data=tips).savefig("test.png")

csv_path = "mask_dist"

# A per-category distribution file is named "<category>_mask...": the text
# before the first "_mask" is the category label.
_MASK_FILE_RE = re.compile(r'(.+?)_mask')

# Filter with the same regex used to extract the label, so a name such as
# "_mask.csv" (nothing before "_mask") is skipped instead of crashing on
# `None.group(1)`. os.listdir() returns entries in arbitrary,
# filesystem-dependent order; sort so the category order on the x axis is
# reproducible between runs.
csvfiles = sorted(
    fname for fname in os.listdir(csv_path) if _MASK_FILE_RE.search(fname)
)

# categories[i] is the label for csvfiles[i] (parallel lists).
categories = [_MASK_FILE_RE.search(fname).group(1) for fname in csvfiles]

# NOTE(review): mask ids are presumably 0..12 (13 masks) -- confirm against
# the CSV contents.

# NOTE(review): not read by any of the visible code.
rank_lens = []

# One colour per stacked mask segment, indexed by the mask's position.
colorpalette = [
    '#ff0029',
    '#377eb8',
    '#66a61e',
    '#984ea3',
    '#00d2d5',
    '#ff7f00',
    '#af8d00',
    '#7f80cd',
    '#b3e900',
    '#c42e60',
    '#a65628',
    '#f781bf',
    '#8dd3c7',
    '#bebada',
    '#fb8072',
    '#80b1d3',
    '#fdb462',
]

# Single 8x5-inch figure holding one Axes for the stacked bar chart.
fig = plt.figure(figsize=(8, 5))
ax = fig.subplots()

# Map each category label to its distribution DataFrame. The first two
# columns are renamed to "mask_id" and "ratio", whatever their header text was.
df_list = {}

for category, fname in zip(categories, csvfiles):
    frame = pd.read_csv(os.path.join(csv_path, fname))
    id_column, ratio_column = frame.columns[0], frame.columns[1]
    df_list[category] = frame.rename(
        columns={id_column: "mask_id", ratio_column: "ratio"}
    )

if not categories:
    raise FileNotFoundError(f"no '*_mask*' distribution files found in {csv_path!r}")

# The mask order (stacking order and legend order) comes from the 'Arms' file.
# If there is no 'Arms' file, the first category is used instead.
reference_category = 'Arms' if 'Arms' in df_list else categories[0]
mask_ids = df_list[reference_category]['mask_id'].tolist()

# Build mask_ratios[i][j]: the ratio of mask_ids[i] within categories[j].
# Ratios are looked up by mask_id rather than by row position, so files whose
# rows are ordered differently, or that omit a mask, still line up correctly.
# A missing mask contributes 0 to the stack.
# NOTE(review): assumes mask_id is unique within each file (reindex raises
# otherwise).
ratios_per_category = [
    df_list[category]
    .set_index('mask_id')['ratio']
    .reindex(mask_ids)
    .fillna(0.0)
    .tolist()
    for category in categories
]
mask_ratios = [list(row) for row in zip(*ratios_per_category)]

# Draw one stacked segment per mask id. Each segment sits on the running total
# of the segments drawn before it. Those offsets are computed once with a
# cumulative sum, which adds the rows in the same order as summing the prefix
# per iteration, without rebuilding the array each time.
mask_array = np.asarray(mask_ratios, dtype=float)
stack_bottoms = np.zeros_like(mask_array)
stack_bottoms[1:] = np.cumsum(mask_array, axis=0)[:-1]

for i, (current_mask, ratios) in enumerate(zip(mask_ids, mask_ratios)):
    # Cycle the palette so that more masks than colours does not raise IndexError.
    color = colorpalette[i % len(colorpalette)]
    ax.bar(
        categories,
        ratios,
        color=color,
        bottom=stack_bottoms[i].tolist(),
        label=current_mask,
    )

# The segments were added bottom-up, so reverse the legend entries to list
# them top-down, matching how the stack appears.
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(
    list(reversed(legend_handles)),
    list(reversed(legend_labels)),
    ncol=1,
    fontsize='small',
    # Put the legend's upper-left corner at the Axes' top-right corner,
    # i.e. just outside the plot on the right.
    bbox_to_anchor=(1, 1),
    loc='upper left',
)
ax.set_yscale('linear')

plt.xticks(rotation=75, fontsize=8)
ax.set_xlabel('Categories')
ax.set_ylabel('Mask Frequency Ratio')
plt.tight_layout()

plt.savefig('mask_dist_bar.pdf')