import numpy as np
import json
from pattern.en import tenses, wordnet, lexeme, pluralize, singularize, comparative, superlative
import csv
def transform(tok, cand_tok):
    """Inflect a replacement token so it takes the grammatical form of ``tok``.

    The fine-grained tag of ``tok`` (``tok.tag_``) selects the inflection that
    is applied to ``cand_tok``:

    * ``NN`` / ``NNP`` / ``JJ``  -> the candidate's lemma
    * ``NNS`` / ``NNPS``         -> pluralised candidate text
    * ``JJR`` / ``JJS``          -> comparative / superlative of the candidate
    * ``VB*``                    -> candidate conjugated like ``tok``
    * anything else              -> candidate text unchanged (the pair is
      printed so unexpected tags are visible)

    Args:
        tok: Token being replaced. ``tag_``, ``text`` and ``pos_`` are read
            (presumably a spaCy ``Token`` -- confirm against the caller).
        cand_tok: Replacement token. ``text`` and ``lemma_`` are read.

    Returns:
        The string to use in place of ``tok``.
    """
    tag = tok.tag_
    if tag in ("NN", "NNP", "JJ"):
        return cand_tok.lemma_
    if tag in ("NNS", "NNPS"):
        return pluralize(cand_tok.text)
    if tag == "JJR":
        return comparative(cand_tok.text)
    if tag == "JJS":
        return superlative(cand_tok.text)
    if tag.startswith("VB"):
        # The file only does ``from pattern.en import ...``, so neither
        # ``collections`` nor ``pattern`` is bound at module level. Import what
        # this branch needs here, outside the ``try``, so that a missing
        # dependency is reported instead of silently degrading to the lemma.
        from collections import Counter
        from pattern.en import conjugate
        try:
            observed = tenses(tok.text)
            if not observed:
                # Nothing to copy the inflection from.
                return cand_tok.lemma_
            # Use the most frequent tense among all readings of the original
            # verb, and take the remaining conjugation parameters from the
            # first reading.
            tense = Counter(reading[0] for reading in observed).most_common(1)[0][0]
            params = list(observed[0])
            params[0] = tense
            cand = conjugate(cand_tok.text, *params)
        except Exception:
            # Best effort: any failure inside the inflection library falls back
            # to the lemma. Unlike a bare ``except`` this lets
            # KeyboardInterrupt / SystemExit propagate.
            return cand_tok.lemma_
        return cand_tok.lemma_ if cand is None else cand
    print(tok.text, tok.pos_, tok.tag_, cand_tok.text)
    return cand_tok.text

class coherence():
    """Build corrupted variants of texts, keyed by example id.

    Two kinds of corruption are produced by ``construct``:

    * token substitution -- tokens are replaced by tokens with the same
      fine-grained tag taken from another randomly chosen text;
    * sentence substitution -- one sentence is replaced by a sentence drawn
      at random from the data.

    Attributes:
        nlp: spaCy ``en_core_web_sm`` pipeline whose tokenizer keeps the
            placeholders ``[MALE]``, ``[FEMALE]`` and ``[NEUTRAL]`` as single
            tokens.
        id_dict: maps an example's ``"id"`` to ``{"text": ..., "type": ...}``.
        id_list: sorted keys of ``id_dict``; refreshed by ``construct``.
    """

    # Fine-grained tags whose tokens are eligible for substitution.
    _SWAP_TAGS = frozenset(
        ["NNS", "NNP", "NNPS", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ"])

    def __init__(self):
        import spacy
        from spacy.symbols import ORTH
        nlp = spacy.load('en_core_web_sm')
        for placeholder in ("[MALE]", "[FEMALE]", "[NEUTRAL]"):
            nlp.tokenizer.add_special_case(placeholder, [{ORTH: placeholder}])
        self.nlp = nlp
        self.id_dict = {}
        self.id_list = []

    @staticmethod
    def _pick(items):
        """Return one uniformly random element of the non-empty list ``items``."""
        # Index sampling avoids converting the whole list to a numpy array on
        # every draw and hands back the original Python object.
        return items[np.random.randint(len(items))]

    def construct(self, origin_data):
        """Create at most one corrupted text per example and record it.

        Args:
            origin_data: list of dicts, each with a ``"truth"`` string and an
                ``"id"`` key.

        Results are added to ``self.id_dict`` (entries from earlier calls are
        kept) and ``self.id_list`` is rebuilt. An example gets no entry when
        it has no usable sentence or when the sampled replacement sentence is
        identical to the sentence it would replace.
        """
        for k, d in enumerate(origin_data):
            if k % 100 == 0:
                print("processing %d lines"%k)
            toks = self.nlp(d["truth"])
            rep_toks = self.nlp(self._pick(origin_data)["truth"])

            # Candidate replacement tokens from the other text, grouped by tag.
            pos_map = {}
            for rt in rep_toks:
                if rt.tag_ in self._SWAP_TAGS:
                    pos_map.setdefault(rt.tag_, []).append(rt)

            text_word_list, sub = [], []
            for t in toks:
                if t.tag_ in pos_map and np.random.random() < 0.25:
                    cand = transform(t, self._pick(pos_map[t.tag_]))
                    text_word_list.append(cand)
                    sub.append([t.text, cand])
                else:
                    text_word_list.append(t.text)

            # Keep the token-level corruption only if enough tokens changed
            # (and then only half of the time); otherwise swap a sentence.
            if (len(sub) >= 3) and np.random.random() < 0.5:
                text = " ".join(text_word_list)
                self.id_dict[d["id"]] = {"text":text, "type":"type_substitute_token:%s"%",".join([" ".join(s) for s in sub])}
                continue

            sen_list = [sen for sen in
                        (part.strip() for part in d["truth"].split("."))
                        if len(sen) > 1]
            if not sen_list:
                continue
            # Draw until a non-empty sentence is found. This terminates because
            # the current example itself contains at least one such sentence.
            while True:
                cand_st = [sen.strip() for sen in self._pick(origin_data)["truth"].split(".")]
                cand_sen = self._pick(cand_st)
                if cand_sen != "":
                    break
            idx = np.random.randint(len(sen_list))
            # Skip replacements that would leave the text unchanged. Comparing
            # the joined strings cannot detect this, because the rebuilt text
            # always carries a trailing " ." that the joined original lacks.
            if cand_sen == sen_list[idx]:
                continue
            sen_list[idx] = cand_sen
            text = " . ".join(sen_list) + " ."
            self.id_dict[d["id"]] = {"text":text, "type":"type_substitute_sen:%d"%(idx)}
        self.id_list = sorted(self.id_dict.keys())


# Build one JSON dataset per split: every original text is a positive example
# (label 1) and every corrupted text produced by ``coherence`` is a negative
# example (label 0).
for name in ["val", "test", "train"]:
    origin_data, all_data, tmp_data = [], {"data": []}, []
    # Read the split completely and close the file before the slow tagging
    # step, rather than holding the handle open for the whole run.
    with open("../ini_data/%s.txt"%name) as fin:
        for i, line in enumerate(fin):
            if (i+1) % 6 == 0:
                # Every sixth line ends a record and is itself discarded. The
                # first collected line is the input ("ipt"); the remaining
                # lines are joined into the reference text ("truth").
                origin_data.append({"truth": " ".join(tmp_data[1:]), "ipt": tmp_data[0], "id": len(origin_data)})
                all_data["data"].append({"text": " ".join(tmp_data), "label": 1})
                tmp_data = []
            else:
                tmp_data.append(line.strip())

    # A fresh instance per split keeps ``id_dict`` from leaking across splits.
    cohe = coherence()
    cohe.construct(origin_data)
    for id_ in cohe.id_list:
        all_data["data"].append({"text": origin_data[id_]["ipt"].strip() + " " + cohe.id_dict[id_]["text"].strip(), "label": 0, "type":cohe.id_dict[id_]["type"]})

    # Shuffle so positives and negatives are interleaved in the output file.
    all_data["data"] = np.random.permutation(all_data["data"]).tolist()
    with open("./%s.json"%name, "w") as fout:
        json.dump(all_data, fout, indent=4, ensure_ascii=False)
