import argparse
import os
import ipdb
import random
import stanza
from tqdm import tqdm
import time

# Seed Python's RNG at module load: the random.shuffle calls that build the
# train/valid splits below rely on it (several branches re-seed with 0 again).
random.seed(0)
if __name__ == '__main__':
    # Command-line interface. The first positional argument of ArgumentParser
    # is `prog`, so the description is passed by keyword.
    parser = argparse.ArgumentParser(description='clean input file')
    parser.add_argument('--fp1', type=str,
                        help='sentences file (derived from --lang/--data_type when omitted)')
    parser.add_argument('--fp2', type=str,
                        help='targets file (derived from --lang/--data_type when omitted)')
    parser.add_argument('--model_type', type=str, required=True, help='model type')
    parser.add_argument('--data_type', type=str, required=True, help='data type')
    parser.add_argument('--lang', type=str, required=True, help='language')
    parser.add_argument('--lemma', action='store_true')
    parser.add_argument('--no_verb_removal', action='store_true')
    parser.add_argument('--no_pos_tags', action='store_true')
    parser.add_argument('--out', type=str,
                        help='output directory (derived from --lang/--model_type/--data_type when omitted)')
    args = parser.parse_args()

    # Default locations under ../data and ../models. Explicit --fp1/--fp2/--out
    # values take precedence over these defaults.
    if args.lang == 'en':
        default_fp1 = "../data/en/train.sentences"
        default_fp2 = "../data/en/train.target"
        default_out = f"../models/en/{args.model_type}/{args.data_type}-data"
    else:
        if 'saoke' in args.data_type:
            default_fp1 = f"../data/{args.lang}/{args.data_type}/train.sentences"
        elif 'ctranslate' in args.data_type:
            default_fp1 = f"../data/{args.lang}/mbart/consistent/train.sentences"
        else:
            default_fp1 = f"../data/{args.lang}/mbart/train.sentences"
        default_fp2 = f"../data/{args.lang}/{args.data_type}/train.target"
        default_out = f"../models/{args.lang}/{args.model_type}/{args.data_type}-data"
    args.fp1 = args.fp1 or default_fp1
    args.fp2 = args.fp2 or default_fp2
    args.out = args.out or default_out

    if args.lang == 'hi':
        # 'hi' always uses the lemma-based pipelines.
        args.lemma = True
    # For 'ud_<code>' languages stanza is run with the segment after the first '_'.
    proc_lang = args.lang.split('_')[1] if 'ud_' in args.lang else args.lang

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.out, exist_ok=True)
    def helper(x, y1, y2):
        """Return the tokens found between ``y1`` and ``y2`` marker tokens in ``x``.

        ``x`` is split on whitespace; markers must be whole tokens. Each loop
        iteration takes the first remaining occurrence of ``y1`` and of ``y2``,
        blanks both, and collects the tokens strictly between them (nothing when
        ``y2`` comes first). All collected tokens are joined with single spaces,
        so several spans are concatenated in order. Returns '' when ``x`` lacks
        either marker.

        Raises:
            ValueError: if a markup tag (e.g. '<a1>', '</r>') appears inside a
                span, i.e. the tagged extraction is malformed.
        """
        markup_tags = frozenset(['<r>', '</r>', '<a2>', '</a2>', '<l>', '</l>',
                                 '<t>', '</t>', '<a1>', '</a1>'])
        tokens = x.split()
        span_tokens = []
        while y1 in tokens and y2 in tokens:
            start_index = tokens.index(y1)
            end_index = tokens.index(y2)
            # Blank the consumed markers so the next iteration finds the next pair.
            tokens[start_index] = ''
            tokens[end_index] = ''
            for token in tokens[start_index + 1:end_index]:
                if token in markup_tags:
                    raise ValueError(f'nested markup tag {token!r} inside {y1} ... {y2} span: {x!r}')
                token = token.strip()
                # Skip markers blanked by an earlier iteration (avoids double spaces).
                if token:
                    span_tokens.append(token)
        return " ".join(span_tokens)
    if args.model_type == 'rerank':
        # Reranker data: one (sentence, extraction) example per extraction. After
        # shuffling, the first 10% of examples form the validation split.
        with open(args.fp1, 'r') as f:
            sentences = f.readlines()
        with open(args.fp2, 'r') as f:
            extractions = f.readlines()
        if len(sentences) != len(extractions):
            raise ValueError(f'{args.fp1} has {len(sentences)} lines but '
                             f'{args.fp2} has {len(extractions)}')

        data = []
        for sentence, extraction_line in zip(sentences, extractions):
            sentence = sentence.strip()
            # Extractions are '<e>'-separated. Skip empty pieces (e.g. after a
            # trailing '<e>') rather than dropping the last piece blindly, which
            # lost a real extraction on lines without a trailing '<e>'.
            for ext in extraction_line.strip().split('<e>'):
                ext = ext.strip()
                if ext:
                    data.append((sentence, ext))

        random.shuffle(data)
        valid_data_len = int(len(data) * 0.1)
        splits = (('train', data[valid_data_len:]), ('valid', data[:valid_data_len]))
        for split_name, pairs in splits:
            with open(args.out + '/' + split_name + '.input', 'w') as input_file:
                with open(args.out + '/' + split_name + '.target', 'w') as target_file:
                    for source, target in pairs:
                        input_file.write(source.strip() + '\n')
                        target_file.write(target.strip() + '\n')
    if args.model_type == 'genoie':
        # GenOIE data: one (sentence, full extraction line) example per input line.
        # 'ud' languages keep file order and use the first 298 examples for
        # validation; other languages are shuffled and use the first 1000.
        with open(args.fp1, 'r') as f:
            sentences = f.readlines()
        with open(args.fp2, 'r') as f:
            extractions = f.readlines()
        if len(sentences) != len(extractions):
            raise ValueError(f'{args.fp1} has {len(sentences)} lines but '
                             f'{args.fp2} has {len(extractions)}')

        data = [(sentence.strip(), extraction_line.strip())
                for sentence, extraction_line in zip(sentences, extractions)]

        if 'ud' in args.lang:
            valid_data, train_data = data[:298], data[298:]
        else:
            random.shuffle(data)
            valid_data, train_data = data[:1000], data[1000:]

        for split_name, pairs in (('train', train_data), ('valid', valid_data)):
            with open(args.out + '/' + split_name + '.input', 'w') as input_file:
                with open(args.out + '/' + split_name + '.target', 'w') as target_file:
                    for source, target in pairs:
                        input_file.write(source.strip() + '\n')
                        target_file.write(target.strip() + '\n')
    if args.model_type == 'gen2oie_s1':
        # Stage-1 (relation generation) data, one example per input line.
        #   input : the sentence with every kept token written as
        #           '# <UPOS> ## <token>' (bare tokens with --no_pos_tags).
        #           Unless --no_verb_removal, AUX/VERB tokens not found in the
        #           gold relation string are dropped; with --lemma that check
        #           compares stanza lemmas instead of surface tokens.
        #   target: the sentence's relations joined by ' <r> ' and terminated
        #           by ' <r>', or just '<r>' when it has none.
        # --no_verb_removal and --no_pos_tags apply in --lemma mode as well.
        processors = 'tokenize,pos,lemma' if args.lemma else 'tokenize,pos'
        nlp = stanza.Pipeline(proc_lang, tokenize_pretokenized=True, tokenize_no_ssplit=True,
                              processors=processors, tokenize_batch_size=4096)
        random.seed(0)
        with open(args.fp1, 'r') as f:
            sentences = f.readlines()
        with open(args.fp2, 'r') as f:
            extractions = f.readlines()
        if len(sentences) != len(extractions):
            raise ValueError(f'{args.fp1} has {len(sentences)} lines but '
                             f'{args.fp2} has {len(extractions)}')

        def run_stanza(token_lists):
            """Tag pre-tokenised inputs; return each input's word dicts, aligned 1:1 with its tokens.

            Raises ValueError if stanza's token text differs from the given token,
            because the tags would then be attached to the wrong words.
            """
            print('started stanza processing')
            start_time = time.time()
            docs = [sent.to_dict() for sent in nlp(token_lists).sentences]
            end_time = time.time()
            print('time taken for stanza is ', end_time - start_time)
            aligned = []
            for ind, tokens in enumerate(token_lists):
                words = docs[ind][:len(tokens)]
                for ind1, token in enumerate(tokens):
                    if words[ind1]['text'] != token:
                        raise ValueError(f"stanza token mismatch in input {ind} at {ind1}: "
                                         f"{words[ind1]['text']!r} != {token!r}")
                aligned.append(words)
            return aligned

        def relation_target(extraction_line):
            """Join the '<r> ... </r>' spans of every extraction into the stage-1 target."""
            relations = []
            for ext in extraction_line.split('<e>'):
                rel = helper(ext.strip(), '<r>', '</r>')
                if rel != "":
                    relations.append(rel)
            if not relations:
                return '<r>'
            return " <r> ".join(relations).strip() + ' <r>'

        def make_input(tokens, tags, keys, relation_key):
            """Build one stage-1 input line.

            tokens/tags/keys are parallel per-token lists (surface token, UPOS,
            and the form compared against `relation_key` for verb filtering).
            """
            kept_tokens = []
            tagged_tokens = []
            for token, tag, key in zip(tokens, tags, keys):
                token = token.strip()
                tag = tag.strip()
                # NOTE(review): `not in relation_key` is a substring test on the
                # relation string (e.g. 'go' matches 'going'), not token
                # membership. Kept to avoid changing generated data; confirm intent.
                if (not args.no_verb_removal and tag in ('AUX', 'VERB')
                        and key.strip() not in relation_key):
                    continue
                kept_tokens.append(token)
                tagged_tokens.append('# ' + tag + ' ## ' + token)
            if args.no_pos_tags:
                return " ".join(kept_tokens).strip()
            return " ".join(tagged_tokens).strip()

        all_sentences = [sentence.strip().split() for sentence in sentences]
        relations_data = [relation_target(line) for line in extractions]

        sentence_words = run_stanza(all_sentences)
        pos_tags = [[word['upos'] for word in words] for words in sentence_words]
        if args.lemma:
            # Compare sentence-token lemmas against the lemmatised relation
            # string; the '<r>' separators are kept verbatim.
            relation_tokens = [relation.split() for relation in relations_data]
            relation_words = run_stanza(relation_tokens)
            verb_keys = [[word['lemma'] for word in words] for words in sentence_words]
            relation_keys = [
                " ".join('<r>' if token == '<r>' else word['lemma']
                         for token, word in zip(tokens, words)).strip()
                for tokens, words in zip(relation_tokens, relation_words)
            ]
        else:
            verb_keys = all_sentences
            relation_keys = relations_data

        input_output = []
        for i in tqdm(range(len(all_sentences))):
            input_output.append((make_input(all_sentences[i], pos_tags[i], verb_keys[i], relation_keys[i]),
                                 relations_data[i].strip()))

        if 'ud' in args.lang:
            valid_data, train_data = input_output[:298], input_output[298:]
        else:
            random.shuffle(input_output)
            valid_data, train_data = input_output[:1000], input_output[1000:]

        for split_name, pairs in (('train', train_data), ('valid', valid_data)):
            with open(args.out + '/' + split_name + '.input', 'w') as input_file:
                with open(args.out + '/' + split_name + '.target', 'w') as target_file:
                    for source, target in pairs:
                        input_file.write(source.strip() + '\n')
                        target_file.write(target.strip() + '\n')

    if args.model_type == 'gen2oie_s2':
        # Stage-2 (argument generation) data: one example per distinct relation
        # of each sentence.
        #   input : '<corrupted relation> <r> <sentence>'
        #   target: every extraction of that sentence with this relation, each
        #           followed by ' <e>'.
        # Relations are randomly corrupted token by token (see the corrupt_*
        # helpers), presumably so stage 2 tolerates imperfect stage-1 output.
        # With --lemma, sentences without relations are skipped. Without it,
        # they become an empty-relation example whose target is '<e>'.
        if args.lemma:
            nlp = stanza.Pipeline(proc_lang, tokenize_pretokenized=True, tokenize_no_ssplit=True,
                                  processors='tokenize,lemma', tokenize_batch_size=4096)
        random.seed(0)
        with open(args.fp1, 'r') as f:
            sentences = f.readlines()
        with open(args.fp2, 'r') as f:
            extractions = f.readlines()
        if len(sentences) != len(extractions):
            raise ValueError(f'{args.fp1} has {len(sentences)} lines but '
                             f'{args.fp2} has {len(extractions)}')

        def group_by_relation(extraction_line):
            """Map each distinct non-empty relation (first-seen order) to its extractions."""
            grouped = {}
            for ext in extraction_line.split('<e>'):
                ext = ext.strip()
                rel = helper(ext, '<r>', '</r>')
                if rel != "":
                    grouped.setdefault(rel, []).append(ext)
            return grouped

        def corrupt_with_lemmas(relation, base_relation, sentence):
            """Corrupt `relation`, mixing in lemmas from the parallel `base_relation`.

            Per token: use the exact or the lemma form (p=0.5 each). Then keep it
            (p<=0.8), drop it (0.8<p<0.9), or keep it followed by a random sentence
            word that is in neither relation form (p>=0.9). If every token was
            dropped, return the uncorrupted mixed relation.
            """
            relation_words = relation.strip().split()
            base_relation_words = base_relation.strip().split()
            if len(relation_words) != len(base_relation_words):
                raise ValueError(f'relation {relation!r} and lemmas {base_relation!r} differ in length')
            sentence_words = sentence.strip().split()
            # sorted(): iteration order of a set of str depends on PYTHONHASHSEED,
            # so sampling from an unsorted list is not reproducible across runs.
            add_words = sorted(set(sentence_words) - set(relation_words) - set(base_relation_words))
            res = []
            new_rel = []
            for exact_word, base_word in zip(relation_words, base_relation_words):
                word = exact_word.strip() if random.random() <= 0.5 else base_word.strip()
                new_rel.append(word)
                p = random.random()
                if p <= 0.8:
                    res.append(word)
                elif p >= 0.9:
                    res.append(word)
                    if add_words:
                        res.append(random.choice(add_words))
            if not res:
                return " ".join(new_rel)
            return " ".join(res).strip()

        def corrupt_plain(relation, sentence):
            """Corrupt `relation` using surface tokens only.

            Returned unchanged when the sentence has no word outside the relation
            or the relation has at most one word. Otherwise, per token: keep
            (p<=0.8), drop (0.8<p<0.9), or keep followed by a random non-relation
            sentence word (p>=0.9). May return ''.
            """
            relation_words = relation.strip().split()
            sentence_words = sentence.strip().split()
            # sorted() for reproducibility; see corrupt_with_lemmas.
            add_words = sorted(set(sentence_words) - set(relation_words))
            if not add_words:
                return " ".join(relation_words)
            if len(relation_words) <= 1:
                return relation.strip()
            res = []
            for word in relation_words:
                word = word.strip()
                p = random.random()
                if p <= 0.8:
                    res.append(word)
                elif p >= 0.9:
                    res.append(word)
                    res.append(random.choice(add_words))
            return " ".join(res).strip()

        rel_sentence_pairs = []  # (relation, sentence) per example
        extractions_data = []    # target string per example
        for sentence, extraction_line in zip(sentences, extractions):
            sentence = sentence.strip()
            grouped = group_by_relation(extraction_line)
            if not grouped and not args.lemma:
                rel_sentence_pairs.append(('', sentence))
                extractions_data.append('<e>')
            for relation, exts in grouped.items():
                rel_sentence_pairs.append((relation.strip(), sentence))
                extractions_data.append(" ".join(ext + ' <e>' for ext in exts))

        if args.lemma:
            all_relations = [relation.split() for relation, _sentence in rel_sentence_pairs]
            print('started stanza processing')
            start_time = time.time()
            all_relations_tags = [sent.to_dict() for sent in nlp(all_relations).sentences]
            end_time = time.time()
            print('time taken for stanza is ', end_time - start_time)
            base_relations = []
            for ind, relation_tokens in enumerate(all_relations):
                lemmas = []
                for ind1, token in enumerate(relation_tokens):
                    word = all_relations_tags[ind][ind1]
                    if word['text'] != token:
                        raise ValueError(f"stanza token mismatch in relation {ind} at {ind1}: "
                                         f"{word['text']!r} != {token!r}")
                    lemmas.append(word['lemma'])
                base_relations.append(" ".join(lemmas).strip())

        # Corrupt in example order so the random sequence is reproducible.
        corrupt_input_ext_data = []
        for i, (relation, sentence) in enumerate(rel_sentence_pairs):
            if args.lemma:
                corrupted = corrupt_with_lemmas(relation, base_relations[i], sentence)
            else:
                corrupted = corrupt_plain(relation, sentence)
            prefix = corrupted + ' <r> ' if corrupted else '<r> '
            corrupt_input_ext_data.append(((prefix + sentence).strip(), extractions_data[i].strip()))

        if 'ud' in args.lang:
            valid_data, train_data = corrupt_input_ext_data[:298], corrupt_input_ext_data[298:]
        else:
            random.shuffle(corrupt_input_ext_data)
            valid_data, train_data = corrupt_input_ext_data[:2000], corrupt_input_ext_data[2000:]

        for split_name, pairs in (('train', train_data), ('valid', valid_data)):
            with open(args.out + '/' + split_name + '.input', 'w') as input_file:
                with open(args.out + '/' + split_name + '.target', 'w') as target_file:
                    for source, target in pairs:
                        input_file.write(source.strip() + '\n')
                        target_file.write(target.strip() + '\n')


    