from scipy import special

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
#import seaborn as sns

import os
import re

# sns.set_theme(style="darkgrid")
# tips = sns.load_dataset("tips")
# sns.relplot(x="total_bill", y="tip", data=tips).savefig("test.png")

# Plot, for every "<category>_mask*" CSV file found in `csv_path`, the values
# of its second column against its first column as one line per category,
# and save the combined figure to 'mask_dist.pdf'.

# Directory that is scanned for the per-category CSV files.
csv_path = "mask_dist"

# The category name is the (non-greedy) text preceding the first "_mask"
# in the file name.
_category_pattern = re.compile(r"(.+?)_mask")

# os.listdir() returns entries in arbitrary order; sorting keeps the
# colour/legend assignment reproducible between runs. The filter and the
# category extraction share one regex match, so a name that contains
# "_mask" without any preceding text is skipped instead of crashing on
# `None.group(1)`.
csvfiles = []
categories = []
for fname in sorted(os.listdir(csv_path)):
    match = _category_pattern.search(fname)
    if match is None:
        continue
    csvfiles.append(fname)
    categories.append(match.group(1))

# NOTE(review): not referenced anywhere in this section; kept in case code
# outside this view relies on the name.
rank_lens = []

fig = plt.figure(figsize=(7, 6))

# X11 colors for plots
x11_colors = [
    'red', 'black', 'blue', 'purple', 'darkorange',
    'gold', 'greenyellow', 'darkgreen', 'lightseagreen', 'grey',
    'fuchsia', 'cyan', 'springgreen', 'silver', 'slateblue', 'darkseagreen',
    'sandybrown',
]

ax = fig.subplots()

for i, (fname, category) in enumerate(zip(csvfiles, categories)):
    test_file = os.path.join(csv_path, fname)
    df = pd.read_csv(test_file)

    # Columns are addressed by position (first = mask id, second = ratio),
    # so the header names used in the CSV do not matter.
    mask_id = df.iloc[:, 0].tolist()
    ratio = df.iloc[:, 1].tolist()

    ax.plot(
        mask_id,
        ratio,
        '-o',
        # Wrap around the palette so more files than colours cannot raise
        # an IndexError.
        c=x11_colors[i % len(x11_colors)],
        linewidth=0.6,
        markersize=3,
        label=category,
        alpha=0.95,
    )

# The legend and axis scale only need to be set once, after all lines are
# drawn. The legend is skipped when nothing was plotted.
if csvfiles:
    ax.legend(ncol=2, fontsize='small')
ax.set_yscale('linear')

plt.xticks(rotation=90, fontsize=6.5)
plt.xlabel('Mask ID')
plt.ylabel('Mask Frequency Ratio')
plt.tight_layout()

plt.savefig('mask_dist.pdf')