from scipy import special

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
#import seaborn as sns

import os
import re

# sns.set_theme(style="darkgrid")
# tips = sns.load_dataset("tips")
# sns.relplot(x="total_bill", y="tip", data=tips).savefig("test.png")

csv_path = "common_words"
csv_surf_path = "../lexstats_surface/common_words"

# Category name = everything before the first "_nostops" (at least one char).
_NOSTOPS_RE = re.compile(r"(.+?)_nostops")


def _list_nostops_csvs(directory):
    """Return ``(filenames, categories)`` for the ``*_nostops*`` files in *directory*.

    Filenames are sorted because ``os.listdir`` order is arbitrary; sorting makes
    the category order (and therefore the colour each series is given later)
    reproducible. Names where ``_nostops`` has no preceding text are skipped
    instead of crashing on a failed regex match.
    """
    fnames = sorted(f for f in os.listdir(directory) if _NOSTOPS_RE.search(f))
    cats = [_NOSTOPS_RE.search(f).group(1) for f in fnames]
    return fnames, cats


csvfiles, categories = _list_nostops_csvs(csv_path)
csvsurfs, categories_surf = _list_nostops_csvs(csv_surf_path)

# Number of non-zero ranks of each plotted file, in plotting order.
rank_lens = []

# A single 8 x 5 inch figure with one set of axes.
fig, ax = plt.subplots(figsize=(8, 5))

# 13 colours (indices 0-12); each plotted series is coloured by its index.
colorpalette = [
    '#332288',
    '#88ccee',
    '#44aa99',
    '#117733',
    '#999933',
    '#ddcc77',
    '#500c00',
    '#cc6677',
    '#ff0029',
    '#aa4499',
    '#fbb4ae',
    '#b3cde3',
    '#ccebc5',
]
#ax = fig.add_axes([0.1, 0.1, 0.75, 0.8])

def _plot_rank_frequency(ax, csv_file, label, color, **plot_kwargs):
    """Plot one word-count CSV as a rank vs. frequency-ratio curve on *ax*.

    The CSV's first column is renamed to ``term``. The file must also have a
    ``count`` column. Zero counts are dropped, the remaining counts are sorted
    in descending order and ranked from 1 (rank 0 cannot be placed on a log
    axis). Each count is divided by the file's total word count.

    Extra keyword arguments (e.g. ``linestyle``) are forwarded to ``ax.plot``.
    Returns the number of plotted (non-zero) ranks.
    """
    df = pd.read_csv(csv_file)
    df = df.rename(columns={df.columns[0]: "term"})

    total_wordcnt = df["count"].sum()
    # Sort explicitly rather than trusting the file order; a stable sort keeps
    # ties in file order.
    nonzero = df.loc[df["count"] > 0, "count"].sort_values(
        ascending=False, kind="mergesort"
    )
    rank = np.arange(1, len(nonzero) + 1)
    counts = nonzero.to_numpy() / total_wordcnt

    ax.plot(rank, counts, c=color, linewidth=1.25, label=label, alpha=0.7,
            **plot_kwargs)
    return len(rank)


# Colours cycle if there are more series than palette entries, instead of
# raising IndexError. Surface-web series follow the common_words series.
for i, (fname, label) in enumerate(zip(csvfiles, categories)):
    color = colorpalette[i % len(colorpalette)]
    rank_lens.append(
        _plot_rank_frequency(ax, os.path.join(csv_path, fname), label, color)
    )

for j, (fname, label) in enumerate(zip(csvsurfs, categories_surf)):
    color = colorpalette[(j + len(csvfiles)) % len(colorpalette)]
    rank_lens.append(
        _plot_rank_frequency(ax, os.path.join(csv_surf_path, fname), label,
                             color, linestyle="--")
    )

# Axis scales and legend only need to be configured once, after all series.
ax.set_xscale("log")
ax.set_yscale("log")
ax.legend(ncol=2, fontsize="small")

# Optional reference Zipf curve (disabled). NOTE(review): the normaliser should
# be special.zeta(a); special.zetac(a) is zeta(a) - 1. The difference is
# harmless here only because the curve is rescaled by its maximum.
# a = 1.32
# x = np.arange(1, max(rank_lens))
# y = x**(-a) / special.zeta(a)
# ax.plot(x, y / max(y), linewidth=2, color='r')
# ax.set_title('Rank vs. Frequency Ratio of words in each CoDA Category and Surface Web')

ax.set_xlabel('Word rank (log)')
ax.set_ylabel('Word frequency ratio (log)')
fig.tight_layout()

# The figure is written to PDF rather than shown interactively.
fig.savefig('coda_dist_w_surf.pdf')