import math
import spacy
import re
import sys
import os, json
import csv
import random
from tqdm import tqdm
from copy import copy
from collections import defaultdict, Counter
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
from data_utils import read_tool_timeline, read_transcript, get_intent_candidates, add_url_to_sample


stemmer = PorterStemmer()

# Domain stopword list, one word per line. Verbs that signal a creative intent
# are kept OUT of the list so that phrases built around them survive filtering.
with open('../rkgraph/data/behance-stopwords.txt', 'r') as _stopword_file:
    stopwords = [w.strip() for w in _stopword_file.readlines()]
_kept_intent_verbs = {'create', 'creates', 'feel', 'feels', 'make', 'makes', 'convey', 'conveys', 'vibe', 'invoke'}
# Filter out every occurrence; list.remove() would only drop the first one if
# the file contains duplicates.
stopwords = [w for w in stopwords if w not in _kept_intent_verbs]
# Also match stemmed tokens, then de-duplicate.
stopwords.extend([stemmer.stem(w) for w in stopwords])
stopwords = list(set(stopwords))

# Load spAcy model for tokenization, stemming and dependency graph parsing
nlp = spacy.load('en_core_web_sm')
data_dir = './data/streams-2021-02-03/streams-2021-02-03'
# Every tool-timeline file in data_dir; transcripts share the basename with '.trans.tsv'.
files = [f for f in os.listdir(data_dir) if f.endswith('.tools.tsv')]
# BIO tag ids used by set_labels / complete_labels.
tag_to_id = {"O": 0, "B-INTENT": 1, "I-INTENT": 2}
# Two directory levels above data_dir; all intermediate/output JSON files live here.
root_dir = os.path.dirname(os.path.dirname(data_dir))

# Regular expressions for phrases that express a creative intent.
# In this file only `re_make_it_into` is referenced (by regex_for_creative_intents);
# NOTE(review): confirm whether the other patterns are imported elsewhere.

# <feeling verb, optionally inflected> [an/a/the/some/more/extra] <feeling noun>
# followed by 1-9 more words, e.g. "creates a sense of depth".
re_convey = "([Ii]nspire(s|d)?|[Ee]voke(s|d)?|[Ii]nstill(s|ed)?|[Ss]uggest(s|ed)?|[Gg](ives?|ave)( it| us| you)?|[Aa]dd(s|ed)?|[Cc]reate(s|ed)?|[Cc]onvey(s|ed)?|[Pp]roduce(s|ed)?)( an?| the| some| more| extra)? (ambience|tension|drama|emotion|feeling|mood|sense|vibe|impression|sensation|hint)( [0-9A-Za-z,']+){1,9}"
# Like re_convey, but only uninflected verbs, and "illusion"/"effect" are also
# accepted as the noun.
re_convey_2 = "([Ii]nspire|[Ee]voke|[Ii]nstill|[Ss]uggest|[Gg]ive( it| us| you)?|[Aa]dd|[Cc]reate|[Cc]onvey|[Pp]roduce)( an?| the| some| more| extra)? (ambience|tension|drama|emotion|feeling|mood|sense|vibe|impression|sensation|hint|illusion|effect)( [0-9A-Za-z,']+){1,9}"
# "more of" / "kind of" / "sort of" / "like" + lowercase words + "vibe" or "vibes".
re_vibe = "(more of|kind of|sort of|like) ([a-z]+ )+vibes?"
# give/make/add/create/convey/produce (optionally followed by "s") + lowercase
# words + "vibe" or "vibes".
re_vibe_2 = "([Gg]ives?( it| you| me)?|[Mm]ake|[Aa]dd|[Cc]reate|[Cc]onvey|[Pp]roduce)s?( [a-z]+)+ vibes?"
# "create" / "make" + "a" / "the" followed by 1-6 words.
re_create = "([Cc]reate|[Mm]ake) (a|the)( [0-9A-Za-z']+){1,6}"
# "make it into" followed by 1-6 words (lowercase only).
re_make_it_into = "(make it into)( [0-9A-Za-z']+){1,6}"


def merge_samples(tool_sample, creative_sample):
    """Merge the tool-intent tags of ``tool_sample`` into ``creative_sample``.

    Tag ids follow ``convert_tags``: 1/3 = begin/inside of a tool intent,
    2/4 = begin/inside of a creative intent, 0 = outside. Position by position
    (over the shorter of the two tag lists) a non-zero creative tag wins;
    otherwise the tool tag is used, except that an inside-tool tag (3) that does
    not directly follow a tool tag (1 or 3) is reset to 0.

    Every contiguous tool span in the merged tags is appended to the sample's
    ``spans`` as ``[start, end)`` (end exclusive, as consumed by
    ``convert_tags``), with one parallel ``"tool"`` entry in ``label``.
    The merge is printed for inspection.

    Args:
        tool_sample: sample with ``"str_words"`` and ``"tags"``.
        creative_sample: sample with ``"tags"`` and optionally ``"label"`` and
            ``"spans"``. It is not modified.

    Returns:
        A shallow copy of ``creative_sample`` with a fresh ``"tags"`` list and,
        when tool spans exist, fresh ``"label"`` and ``"spans"`` lists.
    """
    merged_tags = []
    for tool_tag, creative_tag in zip(tool_sample["tags"], creative_sample["tags"]):
        if creative_tag != 0:
            merged_tags.append(creative_tag)
        elif tool_tag == 3 and merged_tags and merged_tags[-1] not in (1, 3):
            # Drop an inside-tool tag orphaned by a preceding creative/O tag.
            merged_tags.append(0)
        else:
            merged_tags.append(tool_tag)

    print('-------------------------------')
    print("Words: ", list(zip(range(len(tool_sample["str_words"])), tool_sample["str_words"])))
    print('Tool tags: ', tool_sample["tags"])
    print('Creative tags: ', creative_sample["tags"])
    print('Merged tags: ', merged_tags)

    # Collect every contiguous tool span: a 1 or 3 followed by any run of 3s.
    # The end index is exclusive (one past the last inside tag).
    tool_spans = []
    pos = 0
    while pos < len(merged_tags):
        if merged_tags[pos] in (1, 3):
            start = pos
            pos += 1
            while pos < len(merged_tags) and merged_tags[pos] == 3:
                pos += 1
            tool_spans.append([start, pos])
        else:
            pos += 1

    new_sample = copy(creative_sample)
    if tool_spans:
        # Build new lists so the caller's creative_sample is left untouched.
        new_sample["label"] = list(creative_sample.get("label", [])) + ["tool"] * len(tool_spans)
        new_sample["spans"] = [list(s) for s in creative_sample.get("spans", [])] + tool_spans
    new_sample["tags"] = merged_tags
    return new_sample

def convert_tags(tags, labels, spans):
    """Write begin/inside tag ids for labelled spans into ``tags`` and return it.

    Each span is ``[start, end)``. Its first token gets the "begin" id and the
    remaining tokens get the "inside" id: 2/4 for ``'creative'`` spans and 1/3
    for any other label. ``labels`` and ``spans`` are paired with ``zip``, so
    surplus entries in the longer list are ignored. ``tags`` is modified in
    place, and the same list object is returned.
    """
    creative_ids, other_ids = (2, 4), (1, 3)
    for label, span in zip(labels, spans):
        begin_id, inside_id = creative_ids if label == 'creative' else other_ids
        tags[span[0]] = begin_id
        for position in range(span[0] + 1, span[1]):
            tags[position] = inside_id
    return tags


def regex_for_creative_intents(line, pattern=None):
    """Return the character span of the first creative-intent regex match in ``line``.

    Args:
        line: text to search (the visible caller passes it lower-cased).
        pattern: regex to search for. Defaults to the module-level
            ``re_make_it_into`` ("make it into" followed by 1-6 words).

    Returns:
        ``(start, end)`` with ``end`` exclusive, so that ``line[start:end]`` is
        the matched text, or None when nothing matches.
    """
    if pattern is None:
        pattern = re_make_it_into
    match = re.search(pattern, line)
    if match is None:
        return None
    # Use the match's own offsets. Re-locating the last group with
    # line.index() finds its FIRST occurrence in the line, which can lie
    # before the match end (e.g. " it" inside "make it into ... it").
    return match.span()


def subtrees(node):
    """Return ``(dep, children, child_indices)`` for ``node`` and every descendant that has children.

    Entries are produced in pre-order. ``dep`` is ``node.dep``; in spaCy this
    is the integer label ID, while ``dep_`` is the string. ``children`` is the
    list of child tokens, and ``child_indices`` their ``.i`` positions.
    Leaf nodes contribute no entry.
    """
    # spaCy exposes ``children`` as a generator, which is always truthy, so it
    # must be materialised before it can be tested for emptiness.
    children = list(node.children)
    if not children:
        return []
    result = [(node.dep, children, [child.i for child in children])]
    for child in children:
        result.extend(subtrees(child))
    return result

def get_prep_branch(node):
    """Collect the prepositional branch hanging off ``node``.

    Descends only through children whose dependency label is ``prep`` or
    ``pobj`` and returns one ``[dep, None, [index]]`` entry per visited token,
    in pre-order. ``dep`` is the token's ``dep`` attribute, which mirrors the
    first element produced by ``subtrees``.
    """
    result = []
    for child in node.children:
        # Compare the string label ``dep_``: spaCy's ``dep`` is an integer ID,
        # so comparing it against strings never matches.
        if child.dep_ in ('prep', 'pobj'):
            result.append([child.dep, None, [child.i]])
            result.extend(get_prep_branch(child))
    return result


def complete_labels(tags, action_start_idx, obj_idxs):
    """Fill the gap between an action and its object with I-INTENT tags.

    Every tag strictly between ``action_start_idx`` and the first object index
    is overwritten with I-INTENT. A sequence like B-INTENT O O I-INTENT thus
    becomes one contiguous B-INTENT I-INTENT ... span.

    Args:
        tags: tag ids of the sentence; modified in place.
        action_start_idx: index of the action (B-INTENT) token.
        obj_idxs: object token indices, expected sorted ascending (only the
            first one is read).

    Returns:
        ``tags`` (the same list) on success, or False when there is no object
        or the object does not start after the action.
    """
    if not obj_idxs:
        return False
    object_start_idx = obj_idxs[0]
    if object_start_idx <= action_start_idx:
        return False
    for i in range(action_start_idx + 1, object_start_idx):
        tags[i] = tag_to_id["I-INTENT"]
    return tags


def set_labels(tags, action_idx, object_idxs, words):
    """Return a copy of ``tags`` with one action-object pair tagged as an intent.

    The action token gets B-INTENT and every object token gets I-INTENT.
    ``complete_labels`` then fills any tokens between the action and the first
    object with I-INTENT.

    Args:
        tags: current tag ids of the sentence; not modified.
        action_idx: token index of the verb.
        object_idxs: token indices of the object phrase, in any order; not modified.
        words: sentence tokens, used only for the skip log.

    Returns:
        The new tag list, or False when the pair cannot be tagged. This happens
        when one of its tokens is already tagged by an earlier pair (the skip is
        logged to ``log.txt``), or when the object does not start after the action.
    """
    object_idxs = sorted(object_idxs)
    pair_idxs = [action_idx] + object_idxs
    if any(tags[i] != tag_to_id["O"] for i in pair_idxs):
        # Overlaps an already-tagged pair: log it and tell the caller to skip.
        with open('log.txt', 'a+') as f:
            f.write("Skipped the sentence: " + ' '.join(words) + '\n')
            f.write("New tags: " + ' '.join([str(n) for n in pair_idxs]) + '\n')
            f.write("Current tags: " + ' '.join([str(n) for n in tags]) + '\n\n')
        return False
    new_tags = list(tags)
    new_tags[action_idx] = tag_to_id["B-INTENT"]
    for o_idx in object_idxs:
        # Tag the copy, never the caller's list.
        new_tags[o_idx] = tag_to_id["I-INTENT"]
    return complete_labels(new_tags, action_idx, object_idxs)


def get_ranked_intents(data_dir, action_obj_file, tfidf_by_doc=False, total_n_transcripts=3159):
    """Rank stemmed action-object phrases by TF-IDF and write the rankings as JSON.

    ``action_obj_file`` (relative to the module-level ``root_dir``) maps
    transcript path -> timestamp -> sample, where each sample has an
    ``'action-obj-pairs'`` list of ``(action, object)`` string pairs. Derived
    files are written next to it:

    * ``*_stemmed.json``: stemmed phrases per transcript/timestamp (cached).
    * ``*_count.json``: corpus frequency of each stemmed phrase (cached).
    * ``*_doc_count.json``: number of transcripts with a non-empty tool timeline
      that contain each phrase (cached).
    * ``*_ranked_by_tf_idf.json``: corpus-level tf / idf / tf-idf per phrase.
    * ``*_ranked_by_tf_idf_by_doc.json``: per-transcript scores, written only
      when ``tfidf_by_doc`` is true.

    Two kinds of phrase are excluded from ranking. The first is phrases whose
    every word is a stopword. The second is phrases that never occur in a
    transcript with a tool timeline, because they have no document frequency.

    Args:
        data_dir: unused; kept for backward compatibility. Paths resolve
            against the module-level ``root_dir``.
        action_obj_file: file name of the action-object dataset.
        tfidf_by_doc: also compute and write per-transcript rankings.
        total_n_transcripts: N in the idf formulas (default 3159).
    """
    corpus_count_file = action_obj_file.replace('.json', '_count.json')
    doc_count_file = action_obj_file.replace('.json', '_doc_count.json')
    stemmed_file = action_obj_file.replace('.json', '_stemmed.json')
    corpus_count_path = os.path.join(root_dir, corpus_count_file)
    doc_count_path = os.path.join(root_dir, doc_count_file)
    stemmed_path = os.path.join(root_dir, stemmed_file)

    if not (os.path.exists(corpus_count_path) and os.path.exists(doc_count_path)):
        # Use action object pairs
        print("Reading corpus")
        # Prepare stemmed file for better counts
        if not os.path.exists(stemmed_path):
            with open(os.path.join(root_dir, action_obj_file)) as f:
                corpus = json.load(f)
            stemmed_corpus = {}
            all_action_object_pairs = []
            for key, val in tqdm(corpus.items(), desc="Stemming action-object phrases"):
                stemmed_corpus[key] = {}
                for t, v in val.items():
                    tokenized_phrases = [word_tokenize(p[0].lower() + ' ' + p[1].lower()) for p in v['action-obj-pairs']]
                    stemmed_phrases = [' '.join([stemmer.stem(w) for w in phrase]) for phrase in tokenized_phrases]
                    all_action_object_pairs.extend(stemmed_phrases)
                    stemmed_corpus[key][t] = stemmed_phrases
            with open(stemmed_path, 'w') as fout:
                json.dump(stemmed_corpus, fout)
        else:
            with open(stemmed_path) as f:
                stemmed_corpus = json.load(f)
            all_action_object_pairs = []
            for key, val in tqdm(stemmed_corpus.items()):
                for t, v in val.items():
                    all_action_object_pairs.extend(v)

        print("Found %s action-object phrases" % (len(all_action_object_pairs)))
        phrase_counts = dict(Counter(all_action_object_pairs))
        print("Found %s unique action-object phrases" % len(phrase_counts))

        # Document frequency, counted only over transcripts that have a tool timeline.
        total_transcripts = 0
        phrase_doc_count = defaultdict(lambda: 0)
        for key in tqdm(stemmed_corpus):
            tools, _ = read_tool_timeline(key.replace('.trans.tsv', '.tools.tsv'))
            if not tools:
                continue
            total_transcripts += 1
            phrases_in_doc = {p for phrases in stemmed_corpus[key].values() for p in phrases}
            for phrase in phrases_in_doc:
                phrase_doc_count[phrase] += 1

        print(total_transcripts)

        with open(corpus_count_path, 'w') as f:
            json.dump({k: v for k, v in sorted(phrase_counts.items(), key=lambda item: item[1], reverse=True)}, f, indent=2)
        with open(doc_count_path, 'w') as f:
            json.dump({k: v for k, v in sorted(phrase_doc_count.items(), key=lambda item: item[1], reverse=True)}, f, indent=2)

    else:
        with open(corpus_count_path) as f:
            phrase_counts = json.load(f)
        with open(doc_count_path) as f:
            phrase_doc_count = json.load(f)

    # Keep phrases with at least one non-stopword word.
    stopword_set = set(stopwords)
    filtered_phrase_counts = {k: v for k, v in phrase_counts.items()
                              if any(w not in stopword_set for w in k.split())}

    total_phrase_count = sum(filtered_phrase_counts.values())
    print("Found %s phrases in corpus" % len(phrase_counts))
    print("Retained %s unique phrases in %s docs after filtering most frequent unigrams"
          % (len(filtered_phrase_counts), total_n_transcripts))

    phrase_tf_idf = {}
    for p, count in filtered_phrase_counts.items():
        doc_freq = phrase_doc_count.get(p, 0)
        if doc_freq == 0:
            # The phrase occurs only in transcripts without a tool timeline, so
            # N / df is undefined; it cannot be ranked.
            continue
        tf = float(count / total_phrase_count)
        idf = math.log(float(total_n_transcripts) / doc_freq + 1)
        phrase_tf_idf[p] = {'tf': tf, 'idf': idf, 'tf-idf': tf * idf}

    tfidf_file = action_obj_file.replace('.json', '_ranked_by_tf_idf.json')
    sorted_k_v = dict(sorted(phrase_tf_idf.items(), key=lambda item: item[1]['tf-idf'], reverse=True))
    with open(os.path.join(root_dir, tfidf_file), 'w') as f:
        json.dump(sorted_k_v, f, indent=2)

    if tfidf_by_doc:
        with open(stemmed_path) as f:
            stemmed_corpus = json.load(f)
        sorted_phrase_tf_idf_by_doc = {}
        for key in tqdm(stemmed_corpus):
            tools, _ = read_tool_timeline(key.replace('.trans.tsv', '.tools.tsv'))
            if not tools:
                continue

            # Term frequency of each retained phrase within this transcript.
            term_counts = Counter(p for phrases in stemmed_corpus[key].values()
                                  for p in phrases if p in filtered_phrase_counts)
            total_terms = sum(term_counts.values())
            doc_scores = {}
            for p, count in term_counts.items():
                doc_freq = phrase_doc_count.get(p, 0)
                if doc_freq == 0:
                    continue
                tf = float(count) / total_terms
                # NOTE: unlike the corpus-level score there is no "+1" inside the
                # log, so a phrase found in every transcript gets idf 0. Kept as-is.
                idf = math.log(float(total_n_transcripts) / doc_freq)
                doc_scores[p] = {'tf': tf, 'idf': idf, 'tf-idf': tf * idf}

            sorted_phrase_tf_idf_by_doc[key] = dict(
                sorted(doc_scores.items(), key=lambda item: item[1]['tf-idf'], reverse=True))
        tfidf_by_doc_file = action_obj_file.replace('.json', '_ranked_by_tf_idf_by_doc.json')
        with open(os.path.join(root_dir, tfidf_by_doc_file), 'w') as f:
            json.dump(sorted_phrase_tf_idf_by_doc, f, indent=2)


def get_helpx_intent_candidates(intent_file):
    """Return the stemmed, non-stopword tokens of every intent listed in ``intent_file``.

    The file holds one intent per line. Each line is tokenized with NLTK's
    ``word_tokenize``. Tokens found in the module-level ``stopwords`` are
    dropped; the check is case-sensitive and runs before stemming. The tokens
    of all intents are returned as one flat list in file order, duplicates
    included.
    """
    with open(intent_file) as f:
        intents = [line.strip() for line in f]
    stopword_set = set(stopwords)
    candidates = []
    for intent in intents:
        candidates.extend(stemmer.stem(w) for w in word_tokenize(intent) if w not in stopword_set)
    return candidates


def rank_tool_phrase_co_occurence(data_dir, action_obj_file, window=5, use_bag_of_tools=False):
    """Count tool / phrase co-occurrences within a time window and write them to JSON.

    Reads two files under ``dirname(dirname(data_dir))``:

    * ``*_stemmed.json``: transcript -> timestamp -> stemmed phrases.
    * ``*_ranked_by_tf_idf_by_doc.json``: per-transcript ranked phrases, as
      written by ``get_ranked_intents``.

    Only phrases present in that transcript's ranking are counted.
    Transcripts with an empty tool timeline are skipped, and so are the tools
    'color', 'undo', 'redo', 'hide' and 'show'.

    * ``use_bag_of_tools=False``: key ``"<tool>||<phrase>"`` for every tool event
      within ``window`` of the phrase's timestamp.
    * ``use_bag_of_tools=True``: key ``"<sorted unique tools joined by '/'>||<phrase>"``,
      built from all tool events within ``window`` of the timestamp.

    Counts are written, in descending order, to ``*_tool_count_win_<2*window>.json``
    or ``*_bot_count_win_<2*window>.json``.
    """
    ignored_tools = {'color', 'undo', 'redo', 'hide', 'show'}
    root_dir = os.path.dirname(os.path.dirname(data_dir))
    stemmed_file = action_obj_file.replace('.json', '_stemmed.json')
    with open(os.path.join(root_dir, stemmed_file)) as f:
        stemmed_corpus = json.load(f)

    tfidf_by_doc_file = action_obj_file.replace('.json', '_ranked_by_tf_idf_by_doc.json')
    with open(os.path.join(root_dir, tfidf_by_doc_file)) as f:
        phrase_tf_idf_by_doc = json.load(f)

    tool_phrase_count = defaultdict(int)
    for key, doc_phrases in tqdm(stemmed_corpus.items(), desc='Computing co-occurence stats'):
        tools, tooltimes = read_tool_timeline(key.replace('.trans.tsv', '.tools.tsv'))
        if not tools:
            continue

        # Set for O(1) membership tests inside the nested loops below.
        ranked_phrases = set(phrase_tf_idf_by_doc[key])
        # Timestamps are JSON object keys (strings); convert them once.
        timed_phrases = [(float(t), phrases) for t, phrases in doc_phrases.items()]

        if use_bag_of_tools:
            for t, phrases in timed_phrases:
                bag_of_tools = sorted({tool for tool, tool_time in zip(tools, tooltimes)
                                       if abs(tool_time - t) <= window and tool not in ignored_tools})
                if not bag_of_tools:
                    continue
                bag_key = '/'.join(bag_of_tools)
                for p in phrases:
                    if p in ranked_phrases:
                        tool_phrase_count[bag_key + '||' + p] += 1
        else:
            for tool, tool_time in zip(tools, tooltimes):
                if tool in ignored_tools:
                    continue
                for t, phrases in timed_phrases:
                    if abs(tool_time - t) <= window:
                        for p in phrases:
                            if p in ranked_phrases:
                                tool_phrase_count[tool + '||' + p] += 1

    if use_bag_of_tools:
        out_file = action_obj_file.replace('.json', '_bot_count_win_%s.json' % (window*2))
    else:
        out_file = action_obj_file.replace('.json', '_tool_count_win_%s.json' % (window*2))
    with open(os.path.join(root_dir, out_file), 'w') as f:
        json.dump({k: v for k, v in sorted(tool_phrase_count.items(), key=lambda item: item[1], reverse=True)}, f, indent=2)


def get_action_object_tagged_dataset():
    """Extract verb-object pairs from every transcript and tag them as intents.

    Processes each ``*.tools.tsv`` file in ``data_dir`` whose tool timeline and
    matching ``*.trans.tsv`` transcript are both non-empty. Every transcript
    line is parsed with spaCy. Each VERB token with a ``dobj`` child to its
    right yields an action-object pair; the object is the ``dobj`` plus the
    indices of its whole subtree. If the verb is not one of
    create/convey/vibe/make, pairs made only of stopwords are dropped.
    ``set_labels`` converts an accepted pair into B-INTENT/I-INTENT tags, and
    pairs it rejects (falsy result) are skipped.

    A line with at least one pair is stored as
    ``parse_corpus[transcript_path][timestamp] = {"text", "str_words", "tags",
    "action-obj-pairs", "spans"}``, where each span is a ``[start, end)`` token
    range. The whole corpus is finally written to
    ``root_dir/intent_v1_dataset.json``.
    """
    count = 0
    parse_corpus = {}
    f_idx = 0
    for f in tqdm(files):
        f_idx += 1
        # Periodic checkpoint every 1000 files.
        # NOTE(review): the checkpoint is written to intent_v2_dataset.json while
        # the final dump at the end of the function uses intent_v1_dataset.json.
        # Confirm which name is intended.
        if f_idx % 1000 == 0:
            with open(os.path.join(root_dir, 'intent_v2_dataset.json'), 'w') as fout:
                json.dump(parse_corpus, fout)

        # 1. Read tool timeline and skip to next file if the file has zero content
        tools, tool_times = read_tool_timeline(os.path.join(data_dir, f))

        # 2. Read the transcript and skip to next file if the file has zero content
        transcript_file = os.path.join(data_dir, f.replace('.tools.tsv', '.trans.tsv'))
        trans, trans_times = read_transcript(transcript_file)

        if not trans or not tools:
            continue

        # `t` is the line's timestamp. The comprehensions below reuse `t` as
        # their own loop variable, but in Python 3 that does not rebind this `t`.
        for t, text in zip(trans_times, trans):
            doc = nlp(text)

            action_obj_pairs = []
            spans = []
            words = [t.text for t in doc]
            tags = [tag_to_id["O"]] * len(words)
            for idx, token in enumerate(doc):
                # if token.head.pos_ == 'VERB':
                if token.pos_ == 'VERB':
                    object_idxs = []
                    obj = None
                    # Only the first direct object to the right of the verb is used.
                    for child in token.children:
                        if child.dep_ == 'dobj' and child.i > token.i:
                            # if text == "I'll either duplicate the layer to make a backup copy, which I have.":
                            #     print(child.text)
                            object_idxs.append(child.i)
                            subtree = subtrees(child)
                            for dep, children, idxs in subtree:
                                object_idxs.extend(idxs)
                            # prep_phrase = get_prep_branch(token)
                            # for dep, children, idxs in prep_phrase:
                            #     object_idxs.extend(idxs)
                            object_idxs = list(set(object_idxs))
                            object_idxs.sort()
                            obj = [doc[i].text for i in object_idxs]
                            # if prep_phrase:
                            #     print(obj)
                            break

                    if obj:
                        # Filter 1: Retain action words convey, vibe, create
                        if token.text in ['create', 'convey', 'vibe', 'make']:
                            pass
                        # Filter 1: Remove stop words
                        else:
                            phrase = [token.text] + obj
                            filt_phrase = [t for t in phrase if t.lower() not in stopwords]
                            if len(filt_phrase) == 0:
                                # if text == "I'll either duplicate the layer to make a backup copy, which I have.":
                                #     print("Skipped: ", phrase, filt_phrase)
                                continue
                        new_tags = set_labels(tags, idx, object_idxs, words)
                        if not new_tags:
                            # if text == "I'll either duplicate the layer to make a backup copy, which I have.":
                            #     print("did not find tags??")
                            #     print(words)
                            #     print(idx, object_idxs)
                            #     print(tags)
                            #     print(new_tags)
                            continue
                        action_obj_pairs.append((token.text, ' '.join(obj)))
                        # End-exclusive token span from the verb to the last object token.
                        spans.append([idx, object_idxs[-1]+1])
                        # print(text, (token.text, ' '.join(obj)))
                        tags = new_tags

            # TODO: Write another function for VERB --> prep --> pobj and the rest

            if action_obj_pairs:
                # print(text, action_obj_pairs)
                count += 1
                if transcript_file not in parse_corpus:
                    parse_corpus[transcript_file] = {}
                parse_corpus[transcript_file][t] = {"text": text,
                                                    "str_words": words,
                                                    "tags": tags,
                                                    "action-obj-pairs": action_obj_pairs,
                                                    "spans": spans}

        if f_idx%100 == 0:
            print(f_idx, count)

    with open(os.path.join(root_dir, 'intent_v1_dataset.json'), 'w') as fout:
        json.dump(parse_corpus, fout)


def create_creative_dataset(dataset_file, action_obj_file, window=15):
    """Label creative intents in the action-object dataset and write positive/negative samples.

    ``dataset_file`` (relative to ``root_dir``) maps transcript path ->
    timestamp -> sample with ``text``, ``str_words``, ``action-obj-pairs`` and
    ``spans`` keys. A sample is dropped unless some tool other than
    undo/redo/hide/show occurs within 30 time units of its timestamp
    (presumably seconds; TODO confirm the timeline's unit).

    For the remaining samples a creative span is searched with three
    strategies; the first that succeeds wins:

    1. the first candidate intent from ``get_intent_candidates`` that occurs
       (case-insensitively) in the text and covers at least one token start;
    2. the ``regex_for_creative_intents`` match (a sample whose match covers no
       token start is reported and dropped);
    3. every existing action-object span, except "make(s) sense", that is a
       substring of, or contains, an entry of the intent pool.

    Strategies 1 and 2 re-tokenize the text with spaCy and replace
    ``str_words``. Labeled samples are written to
    ``<dataset>_v4_creative_positives.json`` and the rest, with all-zero tags,
    to ``<dataset>_v4_creative_negatives.json``. Samples that contain a
    whitespace token are dropped from both files. Every 200 transcripts the
    positives collected so far are snapshotted to ``temp.json``.
    """
    with open(os.path.join(root_dir, dataset_file), 'r') as f:
        dataset = json.load(f)

    ner_dataset = []
    tool_intents = 0  # not incremented in this function; only shown in the progress log
    creative_intents = 0
    counter = 0
    _, (intents, intent_pool) = get_intent_candidates(action_obj_file, window)
    print(intent_pool)

    negative_samples = []
    for key, val in tqdm(dataset.items(), desc="Filtering dataset"):

        tools, tooltimes = read_tool_timeline(key.replace('.trans', '.tools'))
        for t, v in val.items():

            bag_of_tools = [tool for tool, tool_time in zip(tools, tooltimes) if
                            abs(tool_time - float(t)) <= 30 and tool not in
                            ['undo', 'redo', 'hide', 'show']]
            if not bag_of_tools:
                continue

            labels, action_obj_pairs, spans = [], [], []
            found = False

            # Strategy 1: first candidate intent contained in the text.
            text_lower = v['text'].lower()
            for c in intents:
                if c.lower() not in text_lower:
                    continue
                start = text_lower.index(c.lower())
                end = start + len(c)
                doc = nlp(v['text'])
                word_idxs = [token.i for token in doc if (token.idx >= start and token.idx < end)]
                if not word_idxs:
                    # The match begins inside a token (e.g. "art" in "start"),
                    # so no token starts within it. Try the next candidate
                    # instead of failing on word_idxs[0].
                    continue
                labels.append("creative")
                v["str_words"] = [token.text for token in doc]
                spans.append([word_idxs[0], word_idxs[-1]+1])
                tags = [0]*len(v["str_words"])
                action_obj_pairs.append(c)
                found = True
                break

            # Strategy 2: regex match.
            if not found:
                span = regex_for_creative_intents(v["text"].lower())
                if span:
                    start, end = span
                    doc = nlp(v['text'])
                    v["str_words"] = [token.text for token in doc]
                    word_idxs = [token.i for token in doc if (token.idx >= start and token.idx < end)]
                    if not word_idxs:
                        # No token starts inside the match: report it and drop the sample.
                        print(span, v["text"][span[0]:span[1]])
                        print(v["text"])
                        continue
                    labels.append("creative")
                    spans.append([word_idxs[0], word_idxs[-1]+1])
                    tags = [0]*len(v["str_words"])
                    action_obj_pairs.append(v['text'][start:end])
                    found = True

            # Strategy 3: existing action-object spans that overlap the intent pool.
            if not found:
                for aop, span in zip(v["action-obj-pairs"], v["spans"]):
                    phrase = ' '.join(v['str_words'][span[0]:span[1]])
                    if phrase == 'make sense' or phrase == 'makes sense':
                        continue
                    tags = [0]*len(v["str_words"])
                    for intent in intent_pool:
                        if intent in phrase or phrase in intent:
                            labels.append("creative")
                            action_obj_pairs.append(aop)
                            spans.append(span)
                            break

            if labels:
                ner_dataset.append({"str_words": v["str_words"],
                                    "tags": convert_tags(tags, labels, spans),
                                    "label": labels,
                                    "video_id": key.replace(root_dir, "").replace('\\', ''),
                                    "timestamp": t,
                                    "action-obj-pairs": action_obj_pairs,
                                    "spans": spans})
                creative_intents += sum([label == 'creative' for label in labels])
            else:
                negative_samples.append({"str_words": v["str_words"],
                                         "tags": [0]*len(v["str_words"]),
                                         "video_id": key.replace(root_dir, "").replace('\\', ''),
                                         "timestamp": t,
                                         })

        counter += 1
        if counter % 50 == 0:
            print("%s tool and %s creative intents" % (tool_intents, creative_intents))
        if counter % 200 == 0:
            with open(os.path.join(root_dir, 'temp.json'), 'w') as f:
                json.dump(ner_dataset, f)

    ner_dataset = [s for s in ner_dataset if not any([w.isspace() for w in s["str_words"]])]
    negative_samples = [s for s in negative_samples if not any([w.isspace() for w in s["str_words"]])]
    print("Found %s samples in entire corpus with atleast one intent" % len(ner_dataset))
    print("Found %s samples in entire corpus with no intents" % len(negative_samples))

    with open(os.path.join(root_dir, dataset_file.replace('.json', '_v4_creative_negatives.json')), 'w') as f:
        json.dump(negative_samples, f)
    with open(os.path.join(root_dir, dataset_file.replace('.json', '_v4_creative_positives.json')), 'w') as f:
        json.dump(ner_dataset, f)


def create_tool_dataset(dataset_file, action_obj_file, window=15):
    """Label tool intents in the action-object dataset and write positive/negative samples.

    ``dataset_file`` (relative to ``root_dir``) maps transcript path ->
    timestamp -> sample with ``str_words``, ``tags``, ``action-obj-pairs`` and
    ``spans``. The first run adds a ``stemmed-phrases`` list to every sample and
    caches the result. Only samples with some tool other than undo/redo/hide/show
    within 2.5 time units of their timestamp are considered. An action-object
    span is labelled ``"tool"`` when one of the ``tool_candidates`` keys from
    ``get_intent_candidates(action_obj_file, window)`` is a substring of its
    stemmed phrase; tags of the other spans are reset to 0.

    Positives go to ``<dataset>_v2_tool_positives.json`` and the rest to
    ``<dataset>_v2_tool_negatives.json``. Samples containing a whitespace
    token are dropped from both.
    """
    # NOTE(review): this yields names like "x._stemmed.json" (extra dot). It is
    # kept as-is, since renaming would invalidate existing cache files.
    stemmed_dataset_file = dataset_file.replace('.json', '._stemmed.json')
    if not os.path.exists(os.path.join(root_dir, stemmed_dataset_file)):
        with open(os.path.join(root_dir, dataset_file), 'r') as f:
            dataset = json.load(f)
        for key, val in tqdm(dataset.items(), desc="Stemming the dataset"):
            for t, v in val.items():
                tokenized_phrases = [word_tokenize(p[0].lower() + ' ' + p[1].lower()) for p in v["action-obj-pairs"]]
                stemmed_phrases = [' '.join([stemmer.stem(w) for w in phrase]) for phrase in tokenized_phrases]
                dataset[key][t]["stemmed-phrases"] = stemmed_phrases

        with open(os.path.join(root_dir, stemmed_dataset_file), 'w') as f:
            json.dump(dataset, f)
    else:
        with open(os.path.join(root_dir, stemmed_dataset_file), 'r') as f:
            dataset = json.load(f)

    ner_dataset = []
    tool_intents = 0
    counter = 0
    # tool_candidates is used as a mapping: its keys are matched against the
    # stemmed phrases, and its values are stored as "tfidf_tools".
    tool_candidates, _ = get_intent_candidates(action_obj_file, window)
    print("Found %s tool candidates" % (len(tool_candidates)))
    # tool_candidates = [set(c.split()) for c in tool_candidates]
    # creative_candidates = [set([stemmer.stem(w) for w in c.split()]) for c in creative_candidates]
    negative_samples = []
    for key, val in tqdm(dataset.items(), desc="Filtering dataset"):

        tools, tooltimes = read_tool_timeline(key.replace('.trans', '.tools'))

        for t, v in val.items():

            # The co-occurrence window here is a fixed 2.5; the `window`
            # argument is only passed on to get_intent_candidates.
            bag_of_tools = [tool for tool, tool_time in zip(tools, tooltimes) if
                            abs(tool_time - float(t)) <= 2.5 and tool not in
                            ['undo', 'redo', 'hide', 'show']]
            labels, stemmed_phrases, action_obj_pairs, spans, key_phrases = [], [], [], [], []
            if bag_of_tools:
                # Aliases v["tags"]: the loaded dataset is modified in place.
                tags = v["tags"]
                for p, aop, span in zip(v["stemmed-phrases"], v["action-obj-pairs"], v["spans"]):
                    # p_tokens = set(p.split())
                    # Plain substring match on the stemmed phrase string.
                    for c, _ in tool_candidates.items():
                        if c in p:
                            labels.append("tool")
                            stemmed_phrases.append(p)
                            action_obj_pairs.append(aop)
                            spans.append(span)
                            key_phrases.append(c)
                            # print(p, '| from |', c)
                            break
                    # for-else: this branch runs only if no candidate matched (no break).
                    else:
                        # change labels to other (supports multi-intent sentences and is also used downstream)
                        for i in range(span[0], span[1]):
                            tags[i] = 0

            if labels:
                ner_dataset.append({"str_words": v["str_words"],
                                    "tags": convert_tags(tags, labels, spans),
                                    "label": labels,
                                    "video_id": key.replace(root_dir, "").replace('\\', ''),
                                    "timestamp": t,
                                    "stemmed-phrases": stemmed_phrases,
                                    "action-obj-pairs": action_obj_pairs,
                                    "spans": spans,
                                    "tfidf_tools": [tool_candidates[c] for c in key_phrases],
                                    "timeline_tools": bag_of_tools
                                    })
                tool_intents += sum([label == 'tool' for label in labels])
            else:
                negative_samples.append({"str_words": v["str_words"],
                                         "tags": [0]*len(v["str_words"]),
                                         "video_id": key.replace(root_dir, "").replace('\\', ''),
                                         "timestamp": t,
                                         "timeline_tools": bag_of_tools
                                         })

        counter += 1
        if counter % 50 == 0:
            print("Found %s tool intents in entire corpus" % tool_intents)

        # Periodic snapshot of the positives collected so far.
        if counter % 200 == 0:
            with open(os.path.join(root_dir, 'temp.json'), 'w') as f:
                json.dump(ner_dataset, f)

    ner_dataset = [s for s in ner_dataset if not any([w.isspace() for w in s["str_words"]])]
    negative_samples = [s for s in negative_samples if not any([w.isspace() for w in s["str_words"]])]
    print("Found %s samples in entire corpus with atleast one tool intent" % len(ner_dataset))
    print("Found %s samples in entire corpus with no tool intents" % len(negative_samples))

    with open(os.path.join(root_dir, dataset_file.replace('.json', '_v2_tool_negatives.json')), 'w') as f:
        json.dump(negative_samples, f)
    with open(os.path.join(root_dir, dataset_file.replace('.json', '_v2_tool_positives.json')), 'w') as f:
        json.dump(ner_dataset, f)


def _stratify_by_tool(positive_samples, eval_fraction=0.45):
    """Split positive-sample indices so that dev/test are dominated by rare tools.

    Tools (from each sample's ``timeline_tools``) are ranked from least to most
    frequent. The rarest tools are assigned to an evaluation pool until their
    cumulative occurrence count reaches ``eval_fraction`` of the number of
    positive samples. A sample goes to train if it uses any non-evaluation tool,
    otherwise to the evaluation pool if it uses any evaluation tool. The pool is
    shuffled and halved into dev and test.

    Returns:
        (train_idxs, dev_idxs, test_idxs): lists of indices into ``positive_samples``.
    """
    tool_counts = Counter(t for s in positive_samples for t in s["timeline_tools"])
    rarest_first = tool_counts.most_common()[::-1]

    eval_count = 0
    eval_i = 0
    for i, (_, count) in enumerate(rarest_first):
        eval_count += count
        if eval_count >= int(len(positive_samples) * eval_fraction):
            eval_i = i + 1
            break
    print("%s tools and %s positive samples included in evaluation set" % (eval_i, eval_count))

    # Sets give O(1) membership in the per-sample loop below.
    eval_keys = {t for t, _ in rarest_first[:eval_i]}
    train_keys = {t for t, _ in rarest_first[eval_i:]}

    train_idxs, eval_idxs = [], []
    no_category = 0
    for i, s in tqdm(enumerate(positive_samples)):
        tools = s["timeline_tools"]
        if any(t in train_keys for t in tools):
            train_idxs.append(i)
        elif any(t in eval_keys for t in tools):
            eval_idxs.append(i)
        else:
            no_category += 1
    print("Could not categorize %s samples" % no_category)

    random.shuffle(eval_idxs)
    half = int(len(eval_idxs) / 2)
    return train_idxs, eval_idxs[:half], eval_idxs[half:]


def _stratify_by_session(positive_samples):
    """Split positive-sample indices 70/15/15 by ``session_id`` so no session spans splits.

    Returns:
        (train_idxs, dev_idxs, test_idxs): lists of indices into ``positive_samples``.
    """
    session_ids = list(set(s["session_id"] for s in positive_samples))
    random.shuffle(session_ids)
    n = len(session_ids)
    train_keys = set(session_ids[:int(0.7 * n)])
    dev_keys = set(session_ids[int(0.7 * n):int(0.85 * n)])
    test_keys = set(session_ids[int(0.85 * n):])

    train_idxs, dev_idxs, test_idxs = [], [], []
    no_category = 0
    for i, s in tqdm(enumerate(positive_samples)):
        if s["session_id"] in train_keys:
            train_idxs.append(i)
        elif s["session_id"] in dev_keys:
            dev_idxs.append(i)
        elif s["session_id"] in test_keys:
            test_idxs.append(i)
        else:
            no_category += 1
    print("Could not categorize %s samples" % no_category)
    return train_idxs, dev_idxs, test_idxs


def _stratify_by_intent(positive_samples):
    """Split positive-sample indices 80/10/10 by unique action-object intent string.

    A sample with several intents may be returned in more than one split; the
    caller resolves such overlaps.

    Returns:
        (train_idxs, dev_idxs, test_idxs): lists of indices into ``positive_samples``.
    """
    intent2idx = defaultdict(list)
    for i, s in tqdm(enumerate(positive_samples)):
        for p in s["action-obj-pairs"]:
            intent2idx[' '.join(p)].append(i)

    all_intents = list(intent2idx.keys())
    print("Found %s unique intents in corpus" % len(all_intents))
    random.shuffle(all_intents)
    n = len(all_intents)

    def gather(intents):
        return [idx for intent in intents for idx in intent2idx[intent]]

    return (gather(all_intents[:int(0.8 * n)]),
            gather(all_intents[int(0.8 * n):int(0.9 * n)]),
            gather(all_intents[int(0.9 * n):]))


def _write_split_v4(mode, split, positive_samples, idxs, k, exclude_keys):
    """Write ``intent_dataset_v4_<mode>_<split>.json`` = sampled positives + v3 negatives.

    Negatives are the samples of ``intent_dataset_v3_<mode>_<split>.json`` whose
    tags are all zero and whose ``session_id|timestamp`` key is not in
    ``exclude_keys``. ``k`` positives are drawn without replacement from ``idxs``.
    """
    # Read the v3 source before opening the output, so a missing input does not truncate the output.
    with open(os.path.join(root_dir, 'intent_dataset_v3_%s_%s.json' % (mode, split)), 'r') as f:
        negatives_subset = [s for s in json.load(f)
                            if not any(s["tags"])
                            and s["session_id"] + '|' + s["timestamp"] not in exclude_keys]
    print(len(negatives_subset))
    # random.sample() rejects sets since Python 3.11 (deprecated in 3.9), so sample a sorted list.
    positives_subset = [positive_samples[idx] for idx in random.sample(sorted(idxs), k=k)]
    with open(os.path.join(root_dir, 'intent_dataset_v4_%s_%s.json' % (mode, split)), 'w') as f:
        json.dump(positives_subset + negatives_subset, f)


def stratified_split(mode, positives_file, negatives_file, positives_count=(64000, 4000), negatives_count=(16000, 4000)):
    """Split positive samples into train/dev/test and write the v4 dataset files.

    Args:
        mode: ``'tool'`` stratifies by rare timeline tools, ``'creative'`` by
            session id, and any other value by unique action-object intent.
        positives_file: JSON list of positive samples, relative to ``root_dir``.
            Each sample is passed through ``add_url_to_sample``; the samples are
            assumed to then carry ``session_id`` and ``timestamp`` (string) keys.
        negatives_file: unused; kept for backward compatibility.
        positives_count: ``positives_count[0]`` caps the number of train positives.
            ``positives_count[1]`` is unused. Dev and test keep all their positives,
            in shuffled order.
        negatives_count: unused; kept for backward compatibility.

    Negatives are taken from the matching ``intent_dataset_v3_<mode>_<split>.json``
    files (all-zero tags, key not shared with any positive). The output is written
    to ``intent_dataset_v4_<mode>_{train,dev,test}.json`` under ``root_dir``.
    """
    with open(os.path.join(root_dir, positives_file), 'r') as f:
        positive_samples = [add_url_to_sample(s) for s in json.load(f)]

    # Keys of every positive sample, used to drop colliding v3 negatives. This is needed
    # for every mode (it used to be defined only in 'creative' mode, so other modes crashed).
    all_keys = {s["session_id"] + '|' + s["timestamp"] for s in positive_samples}

    if mode == 'tool':
        train_idxs, dev_idxs, test_idxs = _stratify_by_tool(positive_samples)
    elif mode == 'creative':
        train_idxs, dev_idxs, test_idxs = _stratify_by_session(positive_samples)
    else:
        train_idxs, dev_idxs, test_idxs = _stratify_by_intent(positive_samples)

    # Resolve overlaps with priority train > dev > test.
    train_set_idxs = set(train_idxs)
    dev_set_idxs = set(dev_idxs) - train_set_idxs
    test_set_idxs = set(test_idxs) - train_set_idxs - dev_set_idxs

    print("Found %s / %s / %s positives samples in the dataset" % (len(train_set_idxs), len(dev_set_idxs), len(test_set_idxs)))

    _write_split_v4(mode, 'train', positive_samples, train_set_idxs,
                    min(len(train_set_idxs), positives_count[0]), all_keys)
    _write_split_v4(mode, 'dev', positive_samples, dev_set_idxs, len(dev_set_idxs), all_keys)
    _write_split_v4(mode, 'test', positive_samples, test_set_idxs, len(test_set_idxs), all_keys)


def combined_split(tool_file, creative_file, out_file):
    """Merge the per-split tool and creative datasets into a single dataset per split.

    For each of train/dev/test, reads ``<tool_file stem>_<split>.json`` and then
    ``<creative_file stem>_<split>.json`` from ``root_dir``. Samples are keyed on
    ``session_id|timestamp``. A creative sample whose key already holds a tool
    sample is combined with it via ``merge_samples``; otherwise it is added as is.
    The result is written to ``<out_file stem>_<split>.json``. No extra negatives
    are added.
    """
    splits = ('train', 'dev', 'test')

    def split_path(base_file, split):
        return os.path.join(root_dir, base_file.replace('.json', '_%s.json' % split))

    merged = {}
    for split in splits:
        path = split_path(tool_file, split)
        with open(path, 'r') as f:
            loaded = json.load(f)
        print('%s samples in ' % len(loaded), path)
        merged[split] = {item['session_id'] + '|' + item['timestamp']: item for item in loaded}

    for split in splits:
        path = split_path(creative_file, split)
        with open(path, 'r') as f:
            loaded = json.load(f)
        print('%s samples in ' % len(loaded), path)
        bucket = merged[split]
        for item in loaded:
            sample_key = item['session_id'] + '|' + item['timestamp']
            bucket[sample_key] = merge_samples(bucket[sample_key], item) if sample_key in bucket else item

    print("Found %s / %s / %s samples in the dataset" % (len(merged['train']), len(merged['dev']), len(merged['test'])))

    for split in splits:
        with open(split_path(out_file, split), 'w') as f:
            json.dump(list(merged[split].values()), f)


def prepare_corpus_dataset(dataset_file):
    """Flatten the nested ``{video: {timestamp: entry}}`` dataset into a flat sample list.

    Entries whose word and tag lists differ in length, or that contain a
    whitespace-only token, are skipped; the number skipped is printed. The result
    is written next to the input as ``<stem>_corpus.json`` under ``root_dir``.
    """
    with open(os.path.join(root_dir, dataset_file), 'r') as f:
        nested = json.load(f)

    rows = []
    n_skipped = 0
    for video_key, by_timestamp in tqdm(nested.items(), desc="Preparing dataset"):
        video_id = video_key.replace("./data/streams-2021-02-03/streams-2021-02-03", "").replace('\\', '')
        for stamp, entry in by_timestamp.items():
            words = entry["str_words"]
            malformed = len(words) != len(entry["tags"]) or any(w.isspace() for w in words)
            if malformed:
                n_skipped += 1
                continue
            rows.append({
                "str_words": words,
                "tags": entry["tags"],
                "video_id": video_id,
                "timestamp": stamp,
                "action-obj-pairs": entry["action-obj-pairs"],
            })
    print(n_skipped)

    with open(os.path.join(root_dir, dataset_file.replace('.json', '_corpus.json')), 'w') as f:
        json.dump(rows, f)


def prepare_use_dataset(dataset_file, mode="sentence"):
    """Separate dataset entries into creative-intent and other (non-tool) phrase records.

    An entry is creative when ``regex_for_creative_intents`` matches its joined
    words. Non-creative phrases that mention a tool word are dropped.

    Args:
        dataset_file: nested ``{video: {timestamp: entry}}`` JSON file under ``root_dir``.
        mode: ``"phrase"`` emits one record per action-object phrase. Any other
            value emits one record per entry, with ``phrase`` holding the list of
            phrases. In that mode a non-creative entry is dropped entirely if any
            of its phrases mentions a tool word.

    Writes ``<stem>_creative_<mode>.json`` and ``<stem>_other_<mode>.json`` under ``root_dir``.
    """
    # Vocabulary of words occurring in tooltip names. Generic words that should not
    # by themselves mark a phrase as tool-related are then removed.
    with open('./data/tooltip_counts.json', 'r') as f:
        tooltip_counts = json.load(f)
    tool_words = {w.lower() for name in tooltip_counts for w in name.split()}
    # Set difference instead of list.remove(): a generic word missing from the vocabulary
    # no longer raises ValueError.
    tool_words -= {"select", "tool", "brush", "selection", "group", "color", "mask", "move", "add",
                   "paste", "type", "panel", "history", "hide", "show", "erase", "copy", "change", "make", "fill"}

    def mentions_tool(phrase):
        # NOTE(review): substring test, not token match (a tool word inside a longer
        # word also matches); kept as-is, confirm this is intended.
        return any(w in phrase for w in tool_words)

    def record(phrase, key, t, v):
        return {"phrase": phrase,
                "sentence": ' '.join(v["str_words"]),
                "video_id": key.replace("./data/streams-2021-02-03/streams-2021-02-03", "").replace('\\', ''),
                "timestamp": t,
                "str_words": v["str_words"]}

    with open(os.path.join(root_dir, dataset_file), 'r') as f:
        dataset = json.load(f)

    creative_phrases = []
    all_other_phrases = []
    for key, val in tqdm(dataset.items(), desc="Filtering dataset"):
        for t, v in val.items():
            phrases = [' '.join(p) for p in v["action-obj-pairs"]]
            if regex_for_creative_intents(' '.join(v["str_words"])):
                if mode == "phrase":
                    creative_phrases.extend(record(p, key, t, v) for p in phrases)
                else:
                    creative_phrases.append(record(phrases, key, t, v))
            elif mode == "phrase":
                all_other_phrases.extend(record(p, key, t, v) for p in phrases if not mentions_tool(p))
            elif not any(mentions_tool(p) for p in phrases):
                all_other_phrases.append(record(phrases, key, t, v))

    with open(os.path.join(root_dir, dataset_file.replace('.json', '_creative_%s.json' % mode)), 'w') as f:
        json.dump(creative_phrases, f)

    with open(os.path.join(root_dir, dataset_file.replace('.json', '_other_%s.json' % mode)), 'w') as f:
        json.dump(all_other_phrases, f)

    print("Found %s creative phrases and %s other phrases" % (len(creative_phrases), len(all_other_phrases)))


def dump_intents_to_csv(in_file='../rkgraph/data/2021-07-07/behance_intent_data.json.txt',
                        out_file='./out/regex_creative_intents_annotation.csv'):
    """Write the unique seed creative intents to a two-column CSV for annotation.

    Args:
        in_file: JSON-lines file. Each non-blank line is an object with a
            ``'c-intent'`` key (values assumed to be strings).
        out_file: destination CSV with header ``intent,label``. Its parent
            directory is created if missing. (The previous hard-coded path
            contained a literal, never-formatted ``%s``.)
    """
    with open(in_file, 'r') as f:
        # Skip blank lines (e.g. a trailing newline), which json.loads would reject.
        samples = [json.loads(line) for line in f if line.strip()]
    print(len(samples))
    # Sorted so the row order is stable across runs (set order depends on string hashing).
    seed_intents = sorted({s['c-intent'] for s in samples})

    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_file, 'w', newline='') as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(['intent', 'label'])
        csvwriter.writerows([intent, ''] for intent in seed_intents)


def main():
    """Entry point for the dataset-building pipeline.

    Each pipeline stage is enabled by uncommenting its call. Currently only the
    merge of the tool and creative splits runs.
    """
    # Step 1: collect verb -> dobj action-object pairs over the corpus (for statistics).
    # get_action_object_pairs('./data/streams-2021-02-03/streams-2021-02-03')

    # Step 2: rank the collected pairs by tf-idf.
    # get_ranked_intents('./data/streams-2021-02-03/streams-2021-02-03', 'action_obj_pairs_dobj_only.json', True)

    # Step 3: rank tool/phrase co-occurrence, with and without a bag of tools.
    # rank_tool_phrase_co_occurence('./data/streams-2021-02-03/streams-2021-02-03',
    #                               'action_obj_pairs_dobj_only.json', window=15, use_bag_of_tools=False)
    # rank_tool_phrase_co_occurence('./data/streams-2021-02-03/streams-2021-02-03', 'action_obj_pairs_dobj_only.json',
    #                               window=15, use_bag_of_tools=True)

    # Step 4: build the tagged dataset using the expanded action objects
    # (could be folded into Step 1 in later iterations).
    # get_action_object_tagged_dataset()

    # TODO: Also filter dataset by tool words

    # Step 5: post-process into creative/tool datasets, split them, then merge the splits.
    # create_creative_dataset('intent_dataset.json', 'action_obj_pairs_dobj_only.json', window=5)
    # create_tool_dataset('intent_dataset.json', 'action_obj_pairs_dobj_only.json', window=5)
    # stratified_split('tool', 'intent_dataset_v2_tool_positives.json', 'intent_dataset_v2_tool_negatives.json', positives_count=(21000, 2500), negatives_count=(5000,1000))
    # stratified_split('creative', 'intent_dataset_v4_creative_positives.json', 'intent_dataset_v4_creative_negatives.json', positives_count=(8000, 4000), negatives_count=(2000,400))
    tool_splits = 'intent_dataset_v4_tool.json'
    creative_splits = 'intent_dataset_v4_creative.json'
    combined_splits = 'intent_dataset_both_v6.json'
    combined_split(tool_splits, creative_splits, combined_splits)

    # Step 7: build the corpus-wide dataset.
    # prepare_corpus_dataset('intent_dataset.json')

    # Step 8: build the bootstrapping datasets (sentence- and phrase-level).
    # prepare_use_dataset('intent_dataset.json')
    # prepare_use_dataset('intent_dataset.json', mode="phrase")

    # Step 9: export intents to CSV for annotation or cleaning.
    # dump_intents_to_csv()


if __name__ == '__main__':
    main()
