from scipy import special

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import os
import re

# Directory holding one "<category>_pos*.csv" file per category, and the image
# the combined plot is written to.
CSV_DIR = "pos_dist"
OUT_PNG = "pos_dist.png"

# Category name = the non-empty text before the first "_pos" in a file name.
_CATEGORY_RE = re.compile(r"(.+?)_pos")


def category_from_filename(fname: str) -> str:
    """Return the category prefix of a POS-distribution file name.

    The category is the non-empty text before the first ``_pos``,
    e.g. ``"Drugs_pos.csv"`` -> ``"Drugs"``.

    Raises:
        ValueError: if *fname* has no non-empty prefix before ``_pos``.
    """
    match = _CATEGORY_RE.search(fname)
    if match is None:
        raise ValueError(f"no category prefix before '_pos' in {fname!r}")
    return match.group(1)


def list_pos_csvs(csv_dir: str) -> list:
    """Return, sorted by name, the files in *csv_dir* that carry a category.

    ``os.listdir`` yields entries in arbitrary order; sorting makes the legend
    order and line colours reproducible between runs. Names containing
    ``_pos`` without a category prefix (e.g. ``"_pos.csv"``) are skipped
    rather than crashing the category extraction.
    """
    return sorted(f for f in os.listdir(csv_dir) if _CATEGORY_RE.search(f))


def read_pos_ratio(path: str) -> tuple:
    """Read the CSV at *path* and return its first two columns as lists.

    Returns:
        ``(pos, ratio)`` -- the first column (POS labels) and the second
        column (their ratios). Header names are ignored; columns are taken
        by position.

    Raises:
        ValueError: if the file has fewer than two columns.
    """
    df = pd.read_csv(path)
    if df.shape[1] < 2:
        raise ValueError(f"{path}: expected at least two columns (POS, ratio)")
    return df.iloc[:, 0].tolist(), df.iloc[:, 1].tolist()


def plot_pos_distribution(csv_dir: str = CSV_DIR, out_path: str = OUT_PNG) -> list:
    """Plot every category's POS-ratio curve on one axes and save it.

    Args:
        csv_dir: directory scanned for ``<category>_pos*`` CSV files.
        out_path: image file the figure is written to.

    Returns:
        The category labels, in the order their lines were plotted.
    """
    fnames = list_pos_csvs(csv_dir)
    categories = [category_from_filename(f) for f in fnames]

    fig, ax = plt.subplots(figsize=(14, 8))
    for fname, category in zip(fnames, categories):
        pos, ratio = read_pos_ratio(os.path.join(csv_dir, fname))
        ax.plot(pos, ratio, "-o", linewidth=0.6, markersize=3, label=category)

    # Build the legend once, after every line exists; skip it when nothing
    # was plotted to avoid matplotlib's "no artists with labels" warning.
    if categories:
        ax.legend(ncol=4)
    ax.set_yscale("linear")
    ax.set_title("POS Distribution of each Category in DUTA-10K")
    ax.tick_params(axis="x", labelrotation=75, labelsize=7)
    ax.set_xlabel("POS")
    ax.set_ylabel("POS Ratio")
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)  # release the figure; the script only saves, never shows
    return categories


if __name__ == "__main__":
    plot_pos_distribution()