import os
import json
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import defaultdict
from nltk.tokenize import word_tokenize
import numpy as np
from collections import Counter
from nltk.corpus import stopwords
import math
import os
from data_utils import get_stemmed_corpus
from nltk.stem import PorterStemmer


def get_features(data_dir):
    """Fit a TF-IDF vectorizer over the stemmed transcripts in *data_dir*.

    Each transcript's ``stemmed_transcript`` entries are joined with single
    spaces into one document. The vectorizer removes English stop words, uses
    2- and 3-gram features, and keeps only n-grams that appear in at least 10
    documents. The learned feature names are written as one JSON array to
    ``features.jsonl`` in *data_dir* (despite the extension, the file holds a
    single JSON value, not JSON Lines).

    Args:
        data_dir: Directory passed to ``get_stemmed_corpus``; also where
            ``features.jsonl`` is written.

    Returns:
        The fitted ``TfidfVectorizer``.
    """
    corpus = get_stemmed_corpus(data_dir)
    docs = [' '.join(val['stemmed_transcript']) for val in corpus.values()]

    vectorizer = TfidfVectorizer(stop_words='english', ngram_range=(2, 3), min_df=10)
    X = vectorizer.fit_transform(docs)
    print(X.shape)

    # ``get_feature_names`` was deprecated in scikit-learn 1.0 and removed in
    # 1.2; use ``get_feature_names_out`` when available, falling back for
    # older releases. Names are materialised once as plain ``str`` for JSON.
    if hasattr(vectorizer, 'get_feature_names_out'):
        feature_names = [str(name) for name in vectorizer.get_feature_names_out()]
    else:
        feature_names = list(vectorizer.get_feature_names())

    with open(os.path.join(data_dir, 'features.jsonl'), 'w') as f:
        json.dump(feature_names, f)
    print(len(feature_names))

    return vectorizer


def prep_co_occurence_matrix_v0(data_dir, window=10):
    """Build a tool-by-feature co-occurrence count matrix and save it as JSON.

    Feature n-grams are read from ``features.jsonl`` in *data_dir* (the file
    ``get_features`` writes). For every transcript whose ``.tools.tsv`` file
    has at least one entry, each tool event is compared with every stemmed
    transcript line whose time lies within *window* of the tool time; every
    feature occurring as a substring of such a line adds 1 to that tool's row.
    Rows are summed over all events and transcripts and written to
    ``co_occurence_matrix.jsonl`` as ``{tool: [count per feature]}``.

    Tool files are tab-separated; column 1 is parsed as the event time and
    column 2 (lower-cased) as the tool name. Blank lines are skipped.

    Args:
        data_dir: Directory holding the transcripts, tool files and
            ``features.jsonl``; the matrix is written here too.
        window: Maximum absolute time difference for a line to count as
            co-occurring with a tool event.
    """
    with open(os.path.join(data_dir, 'features.jsonl')) as f:
        features = json.load(f)
    matrix = {}

    # NOTE(review): the original called ``get_corpus``, which is neither
    # defined nor imported in this module. ``get_stemmed_corpus`` provides the
    # same ``stemmed_transcript`` field — confirm it also carries
    # ``transcript_times``.
    corpus = get_stemmed_corpus(data_dir)
    for key, val in tqdm(corpus.items()):
        tool_path = os.path.join(data_dir, key.replace('.trans.tsv', '.tools.tsv'))
        tools = []
        tooltimes = []
        with open(tool_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                tokens = line.split('\t')
                tools.append(tokens[2].lower())
                tooltimes.append(float(tokens[1]))

        if not tools:
            continue

        for tool, tool_time in zip(tools, tooltimes):
            # Accumulate over *all* in-window lines. Previously the vector was
            # recreated per line and only the last one was kept; with no
            # in-window line it was unbound or stale from the previous tool.
            feat_vector = np.zeros(len(features))
            for line_text, line_time in zip(val['stemmed_transcript'], val['transcript_times']):
                if abs(tool_time - line_time) <= window:
                    for i, feature in enumerate(features):
                        if feature in line_text:
                            feat_vector[i] += 1
            if tool in matrix:
                matrix[tool] += feat_vector
            else:
                matrix[tool] = feat_vector

    matrix = {k: v.tolist() for k, v in matrix.items()}
    with open(os.path.join(data_dir, 'co_occurence_matrix.jsonl'), 'w') as f:
        json.dump(matrix, f)


def prep_co_occurence_matrix_v1(data_dir, window=5, top_n=100):
    """Print the top TF-IDF features of the text co-occurring with each tool.

    For every tool, all stemmed transcript lines whose time lies within
    *window* of one of that tool's events are concatenated (each followed by
    a space) into one "tool document". These documents are cached as
    ``{tool: text}`` in ``co_occurence_doc.jsonl`` in *data_dir* and reused if
    that file exists. The documents are transformed with the vectorizer from
    ``get_features`` (which is refit on every call), and for each tool the
    tool name followed by its up-to-*top_n* highest-scoring features (score
    > 0) is printed.

    Args:
        data_dir: Directory holding transcripts and tool files; the cache is
            written here.
        window: Maximum absolute time difference for a line to count as
            co-occurring with a tool event.
        top_n: Maximum number of features printed per tool.
    """
    vectorizer = get_features(data_dir)

    doc_path = os.path.join(data_dir, 'co_occurence_doc.jsonl')
    if os.path.exists(doc_path):
        with open(doc_path) as f:
            matrix = json.load(f)
    else:
        tool_lines = defaultdict(list)
        # NOTE(review): the original called ``get_corpus``, which is neither
        # defined nor imported in this module. ``get_stemmed_corpus`` provides
        # the same ``stemmed_transcript`` field — confirm it also carries
        # ``transcript_times``.
        corpus = get_stemmed_corpus(data_dir)
        for key, val in tqdm(corpus.items()):
            tool_path = os.path.join(data_dir, key.replace('.trans.tsv', '.tools.tsv'))
            tools = []
            tooltimes = []
            with open(tool_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    tokens = line.split('\t')
                    tools.append(tokens[2].lower())
                    tooltimes.append(float(tokens[1]))

            if not tools:
                continue
            for tool, tool_time in zip(tools, tooltimes):
                for line_text, line_time in zip(val['stemmed_transcript'], val['transcript_times']):
                    if abs(tool_time - line_time) <= window:
                        tool_lines[tool].append(line_text + ' ')

        # Join once per tool instead of repeated string concatenation.
        matrix = {tool: ''.join(parts) for tool, parts in tool_lines.items()}
        with open(doc_path, 'w') as f:
            json.dump(matrix, f)

    tools = list(matrix.keys())
    X = vectorizer.transform(list(matrix.values()))
    print(X.shape)

    # ``get_feature_names`` was removed in scikit-learn 1.2.
    if hasattr(vectorizer, 'get_feature_names_out'):
        feature_names = [str(name) for name in vectorizer.get_feature_names_out()]
    else:
        feature_names = list(vectorizer.get_feature_names())

    # Rank features per tool row. The previous code flattened the per-row
    # argsort of the whole matrix, mixing indices from different tools.
    scores = X.toarray()
    for tool, row in zip(tools, scores):
        print(tool)
        for idx in np.argsort(row)[::-1][:top_n]:
            if row[idx] <= 0:
                break  # remaining features never occur in this tool's text
            print(feature_names[idx])

# prep_co_occurence_matrix_v1('./data/streams-2021-02-03')

def get_ranked_phrase(data_dir, window=5, total_docs=3159):
    """Rank stemmed action-object phrases over the whole corpus by TF-IDF.

    Intermediate and output files live in the parent directory of *data_dir*
    (``root_dir``); per-transcript ``.tools.tsv`` files are read from
    *data_dir*.

    Unless ``action_obj_count.jsonl`` and ``action_obj_doc_count.jsonl``
    both already exist in ``root_dir``, they are (re)computed:

    * Stemmed phrases come from ``action_obj_stemmed_pairs.json``. If that is
      missing, it is built from ``action_obj_pairs.json`` by lower-casing,
      word-tokenizing and Porter-stemming ``p[0] + ' ' + p[1]`` for every
      pair ``p`` in ``v[1]`` of each ``{transcript: {time: v}}`` entry.
    * ``action_obj_count.jsonl``: total occurrences of each phrase over all
      transcripts.
    * ``action_obj_doc_count.jsonl``: number of transcripts containing each
      phrase, counting only transcripts whose tool file has an entry.

    Each phrase is then scored with ``tf = count / total_count`` and
    ``idf = log(total_docs / doc_count + 1)``. Phrases without a document
    count (they occur only in transcripts without tool entries) are skipped.
    The result, sorted by descending tf-idf, is written to
    ``action_obj_ranked_tf_idf.jsonl``.

    Args:
        data_dir: Directory holding the ``.tools.tsv`` files.
        window: Unused; kept for signature compatibility.
        total_docs: Document total used in the idf numerator. Defaults to
            3159, the value previously hard-coded here.
    """
    root_dir = os.path.dirname(data_dir)
    stemmed_path = os.path.join(root_dir, 'action_obj_stemmed_pairs.json')
    count_path = os.path.join(root_dir, 'action_obj_count.jsonl')
    doc_count_path = os.path.join(root_dir, 'action_obj_doc_count.jsonl')

    if not (os.path.exists(count_path) and os.path.exists(doc_count_path)):
        # Use action object pairs
        print("Reading corpus")
        all_action_object_pairs = []
        if not os.path.exists(stemmed_path):
            # ``stemmer`` was referenced but never created in this module.
            stemmer = PorterStemmer()
            with open(os.path.join(root_dir, 'action_obj_pairs.json')) as f:
                corpus = json.load(f)
            stemmed_corpus = {}
            for key, val in tqdm(corpus.items()):
                stemmed_corpus[key] = {}
                for t, v in val.items():
                    tokenized_phrases = [word_tokenize(p[0].lower() + ' ' + p[1].lower()) for p in v[1]]
                    stemmed_phrases = [' '.join(stemmer.stem(w) for w in phrase) for phrase in tokenized_phrases]
                    all_action_object_pairs.extend(stemmed_phrases)
                    stemmed_corpus[key][t] = stemmed_phrases

            with open(stemmed_path, 'w') as fout:
                json.dump(stemmed_corpus, fout)
        else:
            with open(stemmed_path) as f:
                stemmed_corpus = json.load(f)
            for val in tqdm(stemmed_corpus.values()):
                for phrases in val.values():
                    all_action_object_pairs.extend(phrases)

        print("Found %s action-object phrases" % (len(all_action_object_pairs)))
        phrase_counts = dict(Counter(all_action_object_pairs))
        print("Found %s unique action-object phrases" % len(phrase_counts))

        total_transcripts = 0
        phrase_doc_count = defaultdict(int)
        for key, val in tqdm(stemmed_corpus.items()):
            tool_path = os.path.join(data_dir, key.replace('.trans.tsv', '.tools.tsv'))
            with open(tool_path, 'r') as f:
                tools = [line.strip().split('\t')[2].lower() for line in f if line.strip()]

            if not tools:
                continue

            total_transcripts += 1
            # Count each phrase at most once per transcript.
            for phrase in {p for phrases in val.values() for p in phrases}:
                phrase_doc_count[phrase] += 1

        # Printed for reference; the idf below uses ``total_docs``.
        print(total_transcripts)

        with open(count_path, 'w') as f:
            json.dump(dict(sorted(phrase_counts.items(), key=lambda item: item[1], reverse=True)), f, indent=2)
        with open(doc_count_path, 'w') as f:
            json.dump(dict(sorted(phrase_doc_count.items(), key=lambda item: item[1], reverse=True)), f, indent=2)

    else:
        with open(count_path) as f:
            phrase_counts = json.load(f)
        # Previously this re-read action_obj_count.jsonl, so idf silently used
        # occurrence counts instead of document counts.
        with open(doc_count_path) as f:
            phrase_doc_count = json.load(f)

    total_phrases = sum(phrase_counts.values())
    phrase_tf_idf = {}
    for p, count in phrase_counts.items():
        doc_count = phrase_doc_count.get(p, 0)
        if not doc_count:
            # Phrase never occurs in a transcript with tool entries; the old
            # code divided by zero (or raised KeyError) here.
            continue
        tf = float(count) / total_phrases
        idf = math.log(float(total_docs) / doc_count + 1)
        phrase_tf_idf[p] = {'tf': tf, 'idf': idf, 'tf-idf': tf * idf}

    sorted_k_v = dict(sorted(phrase_tf_idf.items(), key=lambda item: item[1]['tf-idf'], reverse=True))
    with open(os.path.join(root_dir, 'action_obj_ranked_tf_idf.jsonl'), 'w') as f:
        json.dump(sorted_k_v, f, indent=2)

# get_ranked_phrase('./data/streams-2021-02-03')


def tf_idf_by_doc(data_dir, total_docs=3159):
    """Rank the stemmed phrases of each transcript by TF-IDF.

    Reads ``action_obj_stemmed_pairs.json`` and ``action_obj_doc_count.jsonl``
    (the latter written by ``get_ranked_phrase``) from the parent directory of
    *data_dir*. For every transcript whose ``.tools.tsv`` file in *data_dir*
    has at least one entry, each phrase is scored with
    ``tf = occurrences in this transcript / phrases in this transcript`` and
    ``idf = log(total_docs / doc_count)``. The per-transcript rankings, sorted
    by descending tf-idf, are written to
    ``action_obj_ranked_tf_idf_by_doc.jsonl`` in the parent directory.

    Args:
        data_dir: Directory holding the ``.tools.tsv`` files.
        total_docs: Document total used in the idf numerator. Defaults to
            3159, the value previously hard-coded here.

    Raises:
        KeyError: If a phrase of a tool-bearing transcript is missing from
            the document-count file.
    """
    root_dir = os.path.dirname(data_dir)
    with open(os.path.join(root_dir, 'action_obj_stemmed_pairs.json')) as f:
        stemmed_corpus = json.load(f)

    all_action_object_pairs = []
    for val in tqdm(stemmed_corpus.values()):
        for phrases in val.values():
            all_action_object_pairs.extend(phrases)

    print("Found %s action-object phrases" % (len(all_action_object_pairs)))
    print("Found %s unique action-object phrases" % len(set(all_action_object_pairs)))

    with open(os.path.join(root_dir, 'action_obj_doc_count.jsonl')) as f:
        phrase_doc_count = json.load(f)

    sorted_phrase_tf_idf_by_doc = {}
    for key, val in tqdm(stemmed_corpus.items()):
        tool_path = os.path.join(data_dir, key.replace('.trans.tsv', '.tools.tsv'))
        with open(tool_path, 'r') as f:
            tools = [line.strip().split('\t')[2].lower() for line in f if line.strip()]

        if not tools:
            continue

        # Occurrences of each phrase within this transcript only.
        term_counts = Counter(p for phrases in val.values() for p in phrases)
        total_terms = sum(term_counts.values())

        phrase_tf_idf = {}
        for p, count in term_counts.items():
            tf = float(count) / total_terms
            idf = math.log(float(total_docs) / phrase_doc_count[p])
            phrase_tf_idf[p] = {'tf': tf, 'idf': idf, 'tf-idf': tf * idf}

        sorted_phrase_tf_idf_by_doc[key] = dict(
            sorted(phrase_tf_idf.items(), key=lambda item: item[1]['tf-idf'], reverse=True)
        )

    with open(os.path.join(root_dir, 'action_obj_ranked_tf_idf_by_doc.jsonl'), 'w') as f:
        json.dump(sorted_phrase_tf_idf_by_doc, f, indent=2)

# tf_idf_by_doc('./data/streams-2021-02-03')


def prep_co_occurence_matrix_v2(data_dir, window=5):
    """Count co-occurrences of tools and stemmed action-object phrases.

    Stemmed phrases are loaded from ``action_obj_stemmed_pairs.json`` in the
    parent directory of *data_dir* (``root_dir``). If that file is missing, it
    is built from ``action_obj_pairs.json`` by lower-casing, word-tokenizing
    and Porter-stemming ``p[0] + ' ' + p[1]`` for every pair ``p`` in ``v[1]``.

    For every transcript whose ``.tools.tsv`` file in *data_dir* has at least
    one entry, each tool event is paired with every phrase whose timestamp key
    (parsed with ``float``) lies within *window* of the tool time. Pairs are
    keyed ``tool + '-' + phrase``. Two files are written to ``root_dir``:
    ``tool_phrase_count.jsonl`` (total pair occurrences) and
    ``tool_phrase_doc_count.jsonl`` (number of transcripts containing each
    pair). The 500 most common phrases and 100 most common pairs are printed.

    Args:
        data_dir: Directory holding the ``.tools.tsv`` files.
        window: Maximum absolute time difference for a phrase to count as
            co-occurring with a tool event.
    """
    root_dir = os.path.dirname(data_dir)
    stemmed_path = os.path.join(root_dir, 'action_obj_stemmed_pairs.json')
    # Use action object pairs + tools
    print("Reading corpus")
    all_action_object_pairs = []
    if not os.path.exists(stemmed_path):
        # ``stemmer`` was referenced but never created in this module.
        stemmer = PorterStemmer()
        with open(os.path.join(root_dir, 'action_obj_pairs.json')) as f:
            corpus = json.load(f)
        stemmed_corpus = {}
        for key, val in tqdm(corpus.items()):
            stemmed_corpus[key] = {}
            for t, v in val.items():
                tokenized_phrases = [word_tokenize(p[0].lower() + ' ' + p[1].lower()) for p in v[1]]
                stemmed_phrases = [' '.join(stemmer.stem(w) for w in phrase) for phrase in tokenized_phrases]
                all_action_object_pairs.extend(stemmed_phrases)
                stemmed_corpus[key][t] = stemmed_phrases

        with open(stemmed_path, 'w') as fout:
            json.dump(stemmed_corpus, fout)
    else:
        with open(stemmed_path) as f:
            stemmed_corpus = json.load(f)
        for val in tqdm(stemmed_corpus.values()):
            for phrases in val.values():
                all_action_object_pairs.extend(phrases)

    print("Found %s action-object phrases" % (len(all_action_object_pairs)))
    phrase_counter = Counter(all_action_object_pairs)
    for tup in phrase_counter.most_common(500):
        print(tup)
    print("Found %s unique action-object phrases" % len(phrase_counter))

    tool_phrase_count = defaultdict(int)
    total_transcripts = 0
    tool_phrase_doc_count = defaultdict(int)
    for key, val in tqdm(stemmed_corpus.items()):
        # Tool files are read from data_dir, like every other function in this
        # module (the consumer of tool_phrase_doc_count.jsonl included);
        # previously this read them from root_dir.
        tool_path = os.path.join(data_dir, key.replace('.trans.tsv', '.tools.tsv'))
        tools = []
        tooltimes = []
        with open(tool_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                tokens = line.split('\t')
                tools.append(tokens[2].lower())
                tooltimes.append(float(tokens[1]))

        if not tools:
            continue

        total_transcripts += 1
        tool_phrases_in_doc = set()
        for tool, tool_time in zip(tools, tooltimes):
            for t, phrases in val.items():
                if abs(tool_time - float(t)) <= window:
                    for p in phrases:
                        tool_phrase = tool + '-' + p
                        tool_phrase_count[tool_phrase] += 1
                        tool_phrases_in_doc.add(tool_phrase)

        # Document frequency: each pair counted once per transcript.
        for tool_phrase in tool_phrases_in_doc:
            tool_phrase_doc_count[tool_phrase] += 1

    print(total_transcripts)

    with open(os.path.join(root_dir, 'tool_phrase_count.jsonl'), 'w') as f:
        json.dump(tool_phrase_count, f)
    with open(os.path.join(root_dir, 'tool_phrase_doc_count.jsonl'), 'w') as f:
        json.dump(tool_phrase_doc_count, f)

    print(Counter(tool_phrase_count).most_common(100))

def tool_phrase_tf_idf_by_doc(data_dir, window=5, total_docs=3159):
    """Rank the tool-phrase pairs of each transcript by TF-IDF.

    Reads ``action_obj_stemmed_pairs.json`` and
    ``tool_phrase_doc_count.jsonl`` from the parent directory of *data_dir*.
    For every transcript whose ``.tools.tsv`` file in *data_dir* has at least
    one entry, each ``tool + '-' + phrase`` pair whose phrase timestamp key
    (parsed with ``float``) lies within *window* of the tool time is counted.
    Each pair is then scored with
    ``tf = pair count / pairs in this transcript`` and
    ``idf = log(total_docs / doc_count)``. The per-transcript rankings, by
    descending tf-idf, are written to
    ``tool_phrase_ranked_tf_idf_by_doc.jsonl`` in the parent directory.

    Args:
        data_dir: Directory holding the ``.tools.tsv`` files.
        window: Maximum absolute time difference for a phrase to count as
            co-occurring with a tool event.
        total_docs: Document total used in the idf numerator. Defaults to
            3159, the value previously hard-coded here.

    Raises:
        KeyError: If a pair is missing from the document-count file.
    """
    root_dir = os.path.dirname(data_dir)
    with open(os.path.join(root_dir, 'action_obj_stemmed_pairs.json')) as f:
        stemmed_corpus = json.load(f)

    all_action_object_pairs = []
    for val in tqdm(stemmed_corpus.values()):
        for phrases in val.values():
            all_action_object_pairs.extend(phrases)

    print("Found %s action-object phrases" % (len(all_action_object_pairs)))
    print("Found %s unique action-object phrases" % len(set(all_action_object_pairs)))

    with open(os.path.join(root_dir, 'tool_phrase_doc_count.jsonl')) as f:
        tool_phrase_doc_count = json.load(f)

    sorted_tf_idf_by_doc = {}
    for key, val in tqdm(stemmed_corpus.items()):
        tool_path = os.path.join(data_dir, key.replace('.trans.tsv', '.tools.tsv'))
        tools = []
        tooltimes = []
        with open(tool_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                tokens = line.split('\t')
                tools.append(tokens[2].lower())
                tooltimes.append(float(tokens[1]))

        if not tools:
            continue

        # Occurrences of each tool-phrase pair within this transcript only.
        pair_counts = Counter()
        for tool, tool_time in zip(tools, tooltimes):
            for t, phrases in val.items():
                if abs(tool_time - float(t)) <= window:
                    for p in phrases:
                        pair_counts[tool + '-' + p] += 1

        total_terms = sum(pair_counts.values())
        phrase_tf_idf = {}
        for p, count in pair_counts.items():
            tf = float(count) / total_terms
            idf = math.log(float(total_docs) / tool_phrase_doc_count[p])
            phrase_tf_idf[p] = {'tf': tf, 'idf': idf, 'tf-idf': tf * idf}

        sorted_tf_idf_by_doc[key] = dict(
            sorted(phrase_tf_idf.items(), key=lambda item: item[1]['tf-idf'], reverse=True)
        )

    with open(os.path.join(root_dir, 'tool_phrase_ranked_tf_idf_by_doc.jsonl'), 'w') as f:
        json.dump(sorted_tf_idf_by_doc, f, indent=2)

if __name__ == '__main__':
    # Run the per-transcript ranking only when executed as a script, not as a
    # side effect of importing this module.
    tool_phrase_tf_idf_by_doc('./data/streams-2021-02-03')

def get_most_common_phrases(data_dir, idf=False, total_docs=3159):
    """Rank tool-phrase pairs by corpus count, optionally also by TF-IDF.

    Reads ``tool_phrase_count.jsonl`` from the parent directory of *data_dir*.
    It sorts the ``"<tool>-<phrase>"`` keys by descending count and drops
    pairs whose tool is ``color``, ``select brush``, ``hide``, ``show`` or
    ``select``. The result is written to
    ``tool_phrase_count_ranked_filtered.jsonl``.

    When *idf* is true, each remaining pair is also scored as
    ``(count / total_count) * log(total_docs / (doc_count + 1))``:

    * ``total_count`` sums the counts of *all* pairs, including filtered ones.
    * ``doc_count`` comes from ``tool_phrase_doc_count.jsonl``.

    The scored pairs are written, by descending score, to
    ``tool_phrase_tf_idf_ranked.jsonl``.

    Args:
        data_dir: Path whose parent directory holds the input/output files.
        idf: Whether to also compute and write the TF-IDF ranking.
        total_docs: Document total used in the idf numerator. Defaults to
            3159, the value previously hard-coded here.

    Raises:
        KeyError: If *idf* is true and a kept pair has no document count.
    """
    root_dir = os.path.dirname(data_dir)
    with open(os.path.join(root_dir, 'tool_phrase_count.jsonl')) as f:
        tool_phrase_counts = json.load(f)

    excluded_tools = {'color', 'select brush', 'hide', 'show', 'select'}
    ranked = sorted(tool_phrase_counts.items(), key=lambda item: item[1], reverse=True)
    # Keys are expected to look like "<tool>-<phrase>"; the tool is taken as
    # the text before the first '-', so a tool name containing '-' would be
    # truncated here.
    filtered_k_v = {k: v for k, v in ranked if k.split('-')[0] not in excluded_tools}
    with open(os.path.join(root_dir, 'tool_phrase_count_ranked_filtered.jsonl'), 'w') as f:
        json.dump(filtered_k_v, f, indent=2)

    if not idf:
        return

    with open(os.path.join(root_dir, 'tool_phrase_doc_count.jsonl')) as f:
        tool_phrase_doc_counts = json.load(f)
    total_terms = sum(tool_phrase_counts.values())

    tool_phrase_tf_idf = {}
    for k, v in filtered_k_v.items():
        # +1 smoothing on the document count.
        doc_count = tool_phrase_doc_counts[k] + 1
        tool_phrase_tf_idf[k] = (float(v) / total_terms) * math.log(float(total_docs) / doc_count)

    sorted_k_v = dict(sorted(tool_phrase_tf_idf.items(), key=lambda item: item[1], reverse=True))
    with open(os.path.join(root_dir, 'tool_phrase_tf_idf_ranked.jsonl'), 'w') as f:
        json.dump(sorted_k_v, f, indent=2)

# prep_co_occurence_matrix_v2('./data/streams-2021-02-03')
# get_most_common_phrases('./data/streams-2021-02-03', idf=True)

