import pandas as pd
from sqlalchemy import create_engine
from .sentiment_analysis import SentimentAnalyzer
from tqdm import tqdm
from scipy.spatial import distance
from pathlib import Path
import json
from sqlalchemy import create_engine, MetaData, Table, Index
from sqlalchemy.types import String, Integer
import argparse


def filter_df_by_keywords(df, keywords, regex=False):
    """Return a boolean row mask marking comments that mention any keyword.

    Matching is a case-insensitive substring search: both the ``c_body`` text
    and each keyword are lower-cased before comparison.

    Args:
        df: DataFrame with a ``c_body`` text column.
        keywords: Iterable of keyword strings.
        regex: When False (default) keywords are matched literally, so
            characters such as ``+``, ``?`` or ``.`` have no special meaning.
            Set True to treat each keyword as a regular expression.

    Returns:
        Boolean Series aligned with ``df.index``. A row is True where at least
        one keyword occurs. Missing comment bodies count as non-matching, and
        an empty ``keywords`` yields an all-False mask (never ``None``), so the
        result is always usable for ``df[mask]``.
    """
    comments = df.c_body.str.lower()
    # Start from "nothing matched" so the result stays a usable boolean mask
    # even when no keywords are given.
    mask = pd.Series(False, index=df.index)
    for keyword in keywords:
        # na=False: NaN bodies would otherwise put NaN into the mask and make
        # boolean indexing fail.
        mask |= comments.str.contains(keyword.lower(), regex=regex, na=False)
    return mask


def form_sqlite_str(path):
    """Build an SQLAlchemy SQLite URL for the database file at ``path``.

    ``path`` may be a ``str`` or a ``pathlib.Path``. It is formatted as text
    and appended directly after the ``sqlite:///`` scheme prefix.
    """
    return 'sqlite:///{}'.format(path)


def _sample_balanced(topic_df, n, random_state=None):
    """Draw an equal number (at most ``n``) of ``c_polarity`` 1 and 0 rows.

    If either class has fewer than ``n`` rows, both are capped to the smaller
    class size so the result stays balanced. Rows with any other polarity
    value are dropped.
    """
    pos = topic_df[topic_df.c_polarity == 1]
    neg = topic_df[topic_df.c_polarity == 0]
    n = min(n, len(pos), len(neg))
    return pd.concat(
        (pos.sample(n=n, random_state=random_state),
         neg.sample(n=n, random_state=random_state)),
        axis=0,
    )


def gen_topic_df(dbs, sample_per_sub, filter_words, tables=('test', 'val'),
                 random_state=None):
    """Collect a keyword-filtered, polarity-balanced sample from SQLite databases.

    For every table in ``tables`` and every database in ``dbs`` (tables in
    the outer loop), the table is read and restricted to rows whose ``c_body``
    matches ``filter_words``. Up to ``sample_per_sub // 2`` rows are then
    sampled for each of ``c_polarity == 1`` and ``c_polarity == 0``.

    Args:
        dbs: Iterable of SQLite database paths.
        sample_per_sub: Target sample size per (database, table), split
            evenly between the two polarity classes.
        filter_words: Keywords forwarded to ``filter_df_by_keywords``.
        tables: Names of the tables to read from each database.
        random_state: Optional seed or numpy random generator for
            reproducible sampling. An int seed is turned into one shared
            generator, so successive draws differ but are repeatable.

    Returns:
        A DataFrame of all samples, concatenated in (table, database) order,
        with a fresh 0..n-1 index and an ``id`` column equal to that index.

    Raises:
        ValueError: if ``dbs`` or ``tables`` is empty.
    """
    sample_per_class = sample_per_sub // 2
    if isinstance(random_state, int):
        import numpy as np  # local: only needed to turn a seed into a shared generator
        random_state = np.random.RandomState(random_state)

    frames = []
    for table in tables:
        for db in dbs:
            print(db)
            engine = create_engine(form_sqlite_str(db))
            try:
                df = pd.read_sql_table(table, engine)
            finally:
                # Release pooled SQLite connections; one engine is created per read.
                engine.dispose()
            topic_df = df[filter_df_by_keywords(df, filter_words)].reset_index(drop=True)
            topic_df = _sample_balanced(topic_df, sample_per_class, random_state)
            print(topic_df.shape)
            frames.append(topic_df)

    if not frames:
        raise ValueError('gen_topic_df needs at least one database and one table')
    # A single concat instead of growing the frame inside the loop (quadratic copying).
    res_all = pd.concat(frames, axis=0).reset_index(drop=True)
    res_all['id'] = res_all.index
    return res_all


def get_sentiment_score(ser, st):
    """Run ``st.sentiment_pipeline`` on each text in ``ser`` and tabulate the output.

    The pipeline output for each element is read as ``out[0]['label']`` and
    ``out[0]['score']``. It is therefore expected to return a sequence whose
    first item is a mapping with those keys. This presumably matches a
    Hugging Face style pipeline; confirm against ``SentimentAnalyzer``.

    Args:
        ser: Series of texts to score.
        st: Object exposing a ``sentiment_pipeline`` callable.

    Returns:
        DataFrame with columns ``sent_label`` and ``sent_score``. It carries
        ``ser``'s index, so ``pd.concat((frame, result), axis=1)`` lines the
        scores up with the rows they came from even when that index is not
        0..n-1.
    """
    top = [st.sentiment_pipeline(text)[0] for text in ser]
    return pd.DataFrame(
        {
            'sent_label': [r['label'] for r in top],
            'sent_score': [r['score'] for r in top],
        },
        index=ser.index,
    )


def form_sqlite_str(path):
    """Return the ``sqlite:///`` SQLAlchemy URL that points at ``path``.

    Note: this redefinition replaces the identical ``form_sqlite_str``
    defined earlier in this module; both produce the same URL.
    """
    scheme = 'sqlite:///'
    return scheme + format(path)


def create_id_index(table_name, engine, column='id'):
    """Create an index named ``<table_name>_idx`` on ``table_name.<column>``.

    The table definition is reflected from the live database through
    ``engine``. ``autoload_with`` is used instead of the ``MetaData(bind=...)``
    plus ``autoload=True`` pair, which SQLAlchemy 1.4 deprecated and 2.0
    removed.

    Args:
        table_name: Name of an existing table in the database.
        engine: SQLAlchemy engine connected to that database.
        column: Column to index (defaults to ``'id'``).
    """
    table = Table(table_name, MetaData(), autoload_with=engine)
    id_idx = Index(f'{table_name}_idx', table.c[column])
    id_idx.create(bind=engine)


# SQL column types handed to ``DataFrame.to_sql(dtype=...)`` when the output
# table is written. Columns not listed here (e.g. the sentiment columns added
# later) are left to pandas to type.
dtypes = {
    # ``p_``-prefixed columns (presumably the parent post's fields)
    'p_title': String(),
    'p_created_utc': Integer(),
    'p_subreddit': String(),
    'p_is_self': String(),
    'p_score': Integer(),
    'p_author': String(),
    'p_author_flair_text': String(),
    'p_selftext': String(),
    'link_id': String(),
    # ``c_``-prefixed columns (presumably the comment's fields)
    'c_created_utc': Integer(),
    'c_score': Integer(),
    'c_author': String(),
    'c_body': String(),
    'c_polarity': Integer(),
    # bookkeeping columns
    'year': Integer(),
    'month': Integer(),
    'id': Integer(),
}

def main(argv=None):
    """Build a keyword-filtered, sentiment-scored sample and save it.

    Reads a JSON list of input SQLite paths (``--sentiment_dbs_json``) and a
    JSON list of filter keywords (``--filter_words_json``). It samples a
    polarity-balanced set of matching comments and scores them with
    ``SentimentAnalyzer``. It then writes ``<out_dir>/<prefix>sentiment.json``
    and ``<out_dir>/<prefix>sentiment.db`` (table ``test``, indexed on ``id``).

    Args:
        argv: Optional argument list for testing; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--prefix", type=str, default="")
    parser.add_argument("--filter_words_json", type=str, required=True)
    parser.add_argument("--sentiment_dbs_json", type=str, required=True)
    parser.add_argument(
        "--sample_per_sub", type=int, default=3000,
        help="target rows per input database and table, split evenly between polarities",
    )
    args = parser.parse_args(argv)

    out_dir = Path(args.out_dir)
    prefix = args.prefix
    # Both outputs below are written into out_dir, so make sure it exists.
    out_dir.mkdir(parents=True, exist_ok=True)

    # JSON text is UTF-8 by specification; don't depend on the locale encoding.
    with open(Path(args.sentiment_dbs_json), encoding="utf-8") as f:
        dbs = [Path(e) for e in json.load(f)]
    with open(Path(args.filter_words_json), encoding="utf-8") as f:
        filter_words = json.load(f)

    df = gen_topic_df(dbs, sample_per_sub=args.sample_per_sub, filter_words=filter_words)
    st = SentimentAnalyzer()
    df = pd.concat((df, get_sentiment_score(df.c_body, st)), axis=1)
    # Show the per-subreddit share of each sentiment label in the run log
    # (this summary was previously computed and thrown away).
    print(df.groupby('p_subreddit')['sent_label'].value_counts(normalize=True).sort_index().sort_values())

    df.to_json(out_dir / f'{prefix}sentiment.json')
    engine = create_engine(form_sqlite_str(out_dir / f'{prefix}sentiment.db'))
    try:
        df.to_sql('test', engine, dtype=dtypes)
        create_id_index('test', engine)
    finally:
        engine.dispose()


if __name__ == "__main__":
    main()