
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
#import seaborn as sns

import os
import re

# sns.set_theme(style="darkgrid")
# tips = sns.load_dataset("tips")
# sns.relplot(x="total_bill", y="tip", data=tips).savefig("test.png")

csv_path = "pos_dist_nonproc"
csv_surf_path = "../lexstats_surface/pos_dist"

csvfiles = os.listdir(csv_path)
csvfiles = [fname for fname in csvfiles if not "_all" in fname]
categories = [re.search('(.+?)_pos', fname).group(1) for fname in csvfiles]

csvsurfs = os.listdir(csv_surf_path)
csvsurfs = [fname for fname in csvsurfs if not "_all" in fname]
categories_surf = [re.search('(.+?)_pos', fname).group(1) for fname in csvsurfs]

#print(csvfiles)
#print(categories)

# indices from 0 to 12

rank_lens = list()

fig = plt.figure()
fig.set_size_inches(8, 5)
#fig.tight_layout()

colorpalette = [
    '#332288',
    '#88ccee',
    '#44aa99',
    '#117733',
    '#999933',
    '#ddcc77',
    '#500c00',
    '#cc6677',
    '#ff0029',
    '#aa4499',
    '#fbb4ae',
    '#b3cde3',
    '#ccebc5'
]

ax = fig.subplots()

for i in range(len(csvfiles)):
    test_file = os.path.join(csv_path, csvfiles[i])

    df = pd.read_csv(test_file)
    df = df.rename(columns={df.columns[0]: "pos", df.columns[1]: "ratio"})

    # for row in df.itertuples():
    #     print(row.mask_id)

    pos = [row.pos for row in df.itertuples()]
    ratio = [row.ratio for row in df.itertuples()]

    ax.plot(pos, ratio, '-o', c=colorpalette[i], linewidth=0.6, markersize=3, label=categories[i])
    ax.legend()
    #plt.xscale('linear')
    plt.yscale('linear')

for i in range(len(csvsurfs)):
    test_file = os.path.join(csv_surf_path, csvsurfs[i])

    df = pd.read_csv(test_file)
    df = df.rename(columns={df.columns[0]: "pos", df.columns[1]: "ratio"})

    #print(df['term'][0])

    pos = [row.pos for row in df.itertuples()]
    ratio = [row.ratio for row in df.itertuples()]

    #ax.scatter(rank, counts, s=2, linewidths=1, label=categories[i])
    ax.plot(pos, ratio, '--+', c=colorpalette[i + len(csvfiles)], linewidth=0.7, markersize=5, label=categories_surf[i])
    #ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    ax.legend(ncol=1, fontsize='small')
    #plt.xscale('log')
    plt.yscale('linear')

#plt.title('Mean POS Distribution of each Category in CoDA and the Surface Web')
plt.xticks(rotation=75, fontsize=8)
plt.xlabel('Part of Speech (UPOS Tag in spaCy)')
plt.ylabel('Mean PoS Ratio')
plt.tight_layout()

#plt.show()
plt.savefig('pos_dist_w_surf.pdf')