import torch
import transformers
import itertools
from transformers import AutoModel, AutoTokenizer
import ipdb
import argparse
import os

import numpy as np
import torch
from tqdm import trange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, SequentialSampler

from  phrase_extract import phrase_extraction
import stanza
import string
from tqdm import tqdm
import random
from awesome_align.configuration_bert import BertConfig
from awesome_align.modeling import BertForMaskedLM
from nltk.translate.bleu_score import sentence_bleu
import multiprocessing
import pickle

def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators from ``args.seed``.

    Nothing is seeded unless ``args.seed >= 0``, so a negative seed leaves
    every generator in its current state.
    """
    seed = args.seed
    if not seed >= 0:
        return
    # Same order as always: stdlib, NumPy, torch CPU, then every CUDA device.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def helper(x):
    """Label every whitespace-separated token of a tagged extraction.

    The extraction marks its parts with the tag pairs ``<r>…</r>`` (REL),
    ``<a2>…</a2>`` (ARG2), ``<l>…</l>`` (LOC), ``<t>…</t>`` (TIME) and
    ``<a1>…</a1>`` (ARG1).  Tag pairs are resolved in that order, so a token
    covered by more than one pair keeps the label of the pair resolved last.

    Args:
        x: One extraction as a string whose tags are separate tokens.

    Returns:
        A list with one label per token of ``x.split()``.  Tokens strictly
        between an opening and its closing tag get that tag's label; every
        other position (including the tag tokens themselves) is ``'NONE'``.

    Raises:
        ValueError: If a tag pair encloses a tag token that has not been
            consumed yet (overlapping or unbalanced tags).
    """
    tokens = x.split()
    res = ['NONE'] * len(tokens)
    tag_specs = (
        ('<r>', '</r>', 'REL'),
        ('<a2>', '</a2>', 'ARG2'),
        ('<l>', '</l>', 'LOC'),
        ('<t>', '</t>', 'TIME'),
        ('<a1>', '</a1>', 'ARG1'),
    )
    tag_tokens = {tag for opening, closing, _ in tag_specs for tag in (opening, closing)}
    for opening, closing, label in tag_specs:
        # A tag may occur several times; consume one open/close pair per pass.
        while opening in tokens and closing in tokens:
            start_index = tokens.index(opening)
            end_index = tokens.index(closing)
            # Blank the consumed tags so the next pass finds the next pair.
            tokens[start_index] = ''
            tokens[end_index] = ''
            for ind in range(start_index + 1, end_index):
                if tokens[ind] in tag_tokens:
                    # Previously a leftover debugger breakpoint; fail loudly
                    # instead of blocking non-interactive runs.
                    raise ValueError(
                        f'Tag {tokens[ind]!r} found inside {opening} ... {closing} '
                        f'in extraction: {x!r}'
                    )
                res[ind] = label
    return res

def get_spans(tgt):
    """Split a tagged target line into extractions and locate their relations.

    ``tgt`` holds extractions that are each terminated by ``<e>``; whatever
    follows the last ``<e>`` is ignored.  Each extraction is labelled with
    ``helper`` and its tag tokens are removed.

    Args:
        tgt: One line of tagged extractions.

    Returns:
        A tuple ``(all_exts, new_tgt, all_spans)`` of parallel lists with one
        entry per extraction:

        * ``all_exts`` - the stripped extraction text, tags included.
        * ``new_tgt`` - the extraction's words with all tag tokens removed.
        * ``all_spans`` - a dict with keys ``'start_rel'`` and ``'end_rel'``:
          the half-open ``(start, end)`` index ranges, into the tag-free word
          list, of the first and the last contiguous run of REL words.  When
          the extraction has no REL word they stay at the sentinels
          ``(10000, 10000)`` and ``(-1, -1)``.

    Raises:
        ValueError: If a word of an extraction is not enclosed by any tag
            pair (its label is ``'NONE'``).
    """
    tag_tokens = {'<r>', '</r>', '<a2>', '</a2>', '<l>', '</l>', '<t>', '</t>', '<a1>', '</a1>'}
    all_exts = []
    new_tgt = []
    all_spans = []
    for ext in tgt.split('<e>')[:-1]:
        ext = ext.strip()
        all_exts.append(ext)

        token_labels = helper(ext)
        words = []
        word_labels = []
        for pos, word in enumerate(ext.split()):
            if word in tag_tokens:
                continue
            if token_labels[pos] == 'NONE':
                # Previously a leftover debugger breakpoint; fail loudly
                # instead of blocking non-interactive runs.
                raise ValueError(f'Word {word!r} is outside every tag pair in extraction: {ext!r}')
            words.append(word)
            word_labels.append(token_labels[pos])
        new_tgt.append(words)

        # Walk over maximal runs of identical labels; only REL runs matter.
        first_rel = (10000, 10000)
        last_rel = (-1, -1)
        cursor = 0
        while cursor < len(word_labels):
            run_start = cursor
            run_label = word_labels[cursor]
            while cursor < len(word_labels) and word_labels[cursor] == run_label:
                cursor += 1
            if run_label != 'REL':
                continue
            run = (run_start, cursor)  # half-open range
            if first_rel[0] > run[0]:
                first_rel = run
            if last_rel[1] < run[1]:
                last_rel = run
        all_spans.append({'start_rel': first_rel, 'end_rel': last_rel})
    return all_exts, new_tgt, all_spans

class LineByLineTextDataset(Dataset):
    """Pairs every extraction of a target line with its source sentence.

    ``args.inp1`` holds one source sentence per line and ``args.inp2`` the
    matching line of tagged extractions (parsed with ``get_spans``).  One
    example is produced per extraction; each example is the tuple::

        (ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt,
         complete_sent_tgt, labels_dict, sent_src, sent_tgt, idx)

    where ``ids_*`` are the ``input_ids`` returned by
    ``tokenizer.prepare_for_model``, ``bpe2word_map_*`` maps each sub-word
    position to the index of the word it came from, ``complete_sent_tgt`` is
    the extraction text with tags, ``labels_dict`` is the span dict from
    ``get_spans``, ``sent_src`` / ``sent_tgt`` are the word lists, and ``idx``
    is the 0-based line number shared by all extractions of that line.

    Side effect: with ``args.debug`` set, only the first 200 lines are used
    and ``'.debug'`` is appended to ``args.output_file``.

    Raises:
        FileNotFoundError: If either input file does not exist.
        ValueError: If the files differ in line count, a line is blank, or a
            target line yields no extraction.
    """

    def __init__(self, tokenizer, args):
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        for path in (args.inp1, args.inp2):
            if not os.path.isfile(path):
                raise FileNotFoundError(f'Input file not found: {path}')
        print('Loading the dataset...')
        self.examples = []
        with open(args.inp1) as f1, open(args.inp2) as f2:
            data1 = f1.readlines()
            data2 = f2.readlines()
        if args.debug:
            data1 = data1[:200]
            data2 = data2[:200]
            args.output_file = args.output_file + '.debug'
        if len(data1) != len(data2):
            raise ValueError(
                f'Input files differ in length: {len(data1)} vs {len(data2)} lines'
            )

        for idx, (line1, line2) in enumerate(zip(tqdm(data1), data2)):
            src = line1.strip()
            tgt = line2.strip()
            if not src or not tgt:
                raise ValueError(f'Line {idx+1} is not in the correct format!')
            sent_src = src.split()
            complete_sent_tgt, sent_tgt, labels_dict = get_spans(tgt)
            if len(sent_tgt) == 0:
                # Previously ``assert ..., ipdb.set_trace()``, which dropped
                # into a debugger instead of reporting the bad line.
                raise ValueError(f'Line {idx+1} contains no extraction!')

            # The source side is identical for every extraction of this line,
            # so tokenize it once instead of once per extraction.
            ids_src, bpe2word_map_src = self._encode(tokenizer, sent_src)
            for ind in range(len(sent_tgt)):
                ids_tgt, bpe2word_map_tgt = self._encode(tokenizer, sent_tgt[ind])
                self.examples.append((
                    ids_src,
                    ids_tgt,
                    list(bpe2word_map_src),  # own copy per example
                    bpe2word_map_tgt,
                    complete_sent_tgt[ind],
                    labels_dict[ind],
                    sent_src,
                    sent_tgt[ind],
                    idx,
                ))

    @staticmethod
    def _encode(tokenizer, words):
        """Return ``(input_ids, bpe2word_map)`` for a list of words."""
        token_lists = [tokenizer.tokenize(word) for word in words]
        id_lists = [tokenizer.convert_tokens_to_ids(tokens) for tokens in token_lists]
        input_ids = tokenizer.prepare_for_model(
            list(itertools.chain(*id_lists)),
            return_tensors='pt',
            model_max_length=tokenizer.model_max_length,
            truncation=True,
        )['input_ids']
        # One entry per sub-word, holding the index of the originating word.
        bpe2word_map = [word_idx for word_idx, tokens in enumerate(token_lists) for _ in tokens]
        return input_ids, bpe2word_map

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]

def word_align(args, model, tokenizer, shuffle):
    """Order the extractions of every input line and write them to a file.

    Every extraction is word-aligned with its source sentence through
    ``model.get_aligned_word``.  The first and last relation spans of the
    extraction are then projected onto the source sentence, and the
    extractions of a line are sorted by the projected ``(end, start)``
    positions.  With ``shuffle`` set the projection is skipped and the
    extractions of each line are shuffled randomly instead.

    The output file gets one line per input line: the extractions joined by
    ``' <e> '`` and terminated by ``<e>``.  A line consisting only of
    ``<a1> empty </a1> <e>`` is written as ``<e>``.

    Args:
        args: Namespace providing ``inp1``, ``inp2``, ``debug``,
            ``output_file``, ``batch_size`` and ``device``.
        model: Model exposing ``get_aligned_word``; it is moved to
            ``args.device`` and put in eval mode.
        tokenizer: Tokenizer handed to ``LineByLineTextDataset``; its
            ``pad_token_id`` is used for batch padding.
        shuffle: If true, emit extractions in random order.

    Raises:
        RuntimeError: If one of the projection worker processes fails.
    """

    def collate(examples):
        ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, complete_sent_tgt, label_dict, sent_src, sent_tgt, idx = zip(*examples)
        ids_src = pad_sequence(ids_src, batch_first=True, padding_value=tokenizer.pad_token_id)
        ids_tgt = pad_sequence(ids_tgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        return ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, complete_sent_tgt, label_dict, sent_src, sent_tgt, idx

    dataset = LineByLineTextDataset(tokenizer, args)
    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(
        dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate
    )

    model.to(args.device)
    model.eval()
    tqdm_iterator = trange(len(dataset), desc="Extracting")

    def clean_phrases(phrases):
        """Keep, per distinct ``phr[0]``, the phrase whose ``phr[1]`` range is shortest.

        The first phrase seen wins ties; output follows first-seen order.
        """
        shortest = {}
        for phr in phrases:
            kept = shortest.get(phr[0])
            if kept is None or kept[0][1] - kept[0][0] > phr[1][1] - phr[1][0]:
                shortest[phr[0]] = (phr[1], phr[2], phr[3])
        return [(key, val[0], val[1], val[2]) for key, val in shortest.items()]

    def clean_aligns(word_aligns):
        """Drop every alignment whose first index is aligned more than once.

        Falls back to the unfiltered ``word_aligns`` when nothing survives.
        """
        partner = {}
        for word_pair in word_aligns:
            if word_pair[0] in partner:
                partner[word_pair[0]] = -1  # -1 marks an ambiguous word
            else:
                partner[word_pair[0]] = word_pair[1]
        new_aligns = [(word, other) for word, other in partner.items() if other != -1]
        if len(new_aligns) == 0:
            return word_aligns
        return new_aligns

    def projection(phrases, label, sent_tgt):
        """Project the extraction span ``label`` onto the other side of ``phrases``.

        Among the phrases whose ``phr[0]`` range overlaps ``label``, the one
        whose ``phr[2]`` text scores the highest BLEU against the span text is
        chosen; ties prefer the shorter ``phr[1]`` range, then the text whose
        length is closest to the span text.

        Returns:
            ``phr[1][1] - 1`` of the chosen phrase - presumably the index of
            its last source-sentence word (TODO confirm against
            ``phrase_extraction``) - or ``10000`` when no phrase overlaps or
            the best score is 0.
        """
        max_score, span = -1, -1
        label_text = " ".join(sent_tgt[label[0]:label[1]])
        reference = label_text.split()
        label_positions = set(range(label[0], label[1]))
        for phr in phrases:
            if not label_positions.intersection(range(phr[0][0], phr[0][1])):
                continue
            hypothesis = phr[2].split()
            phrase_len = min(len(reference), len(hypothesis))
            # Start with the highest n-gram order the shorter text supports
            # and fall back to lower orders while the score rounds to 0.
            score = 0
            if phrase_len > 3:
                weight = (0.25, 0.25, 0.25, 0.25)
                score = round(sentence_bleu([reference], hypothesis, weights=weight), 3)
            if score == 0 or phrase_len == 3:
                weight = (1/3, 1/3, 1/3)
                score = round(sentence_bleu([reference], hypothesis, weights=weight), 3)
            if score == 0 or phrase_len == 2:
                weight = (0.5, 0.5)
                score = round(sentence_bleu([reference], hypothesis, weights=weight), 3)
            if score == 0 or phrase_len == 1:
                weight = (1,)
                score = round(sentence_bleu([reference], hypothesis, weights=weight), 3)

            if score > max_score:
                max_score = score
                span = phr
            elif score == max_score:
                if (phr[1][1] - phr[1][0]) < (span[1][1] - span[1][0]):
                    span = phr
                elif (phr[1][1] - phr[1][0]) == (span[1][1] - span[1][0]):
                    if abs(len(phr[2]) - len(label_text)) <= abs(len(span[2]) - len(label_text)):
                        span = phr
        if max_score == 0 or max_score == -1:
            return 10000
        return span[1][1] - 1

    all_sent_src, all_sent_tgt, all_word_aligns = [], [], []
    all_label_dict, all_complete_sent_tgt, all_idx = [], [], []
    with open(args.output_file, 'w') as writer:
        for batch in dataloader:
            with torch.no_grad():
                ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, complete_sent_tgt, label_dict, sent_src, sent_tgt, idx = batch
                # The extraction (tgt) is passed first, so the first element
                # of each returned pair refers to the extraction side.
                word_aligns_list = model.get_aligned_word(ids_tgt, ids_src, bpe2word_map_tgt, bpe2word_map_src, args.device, None, None, test=True)
                for ind, word_aligns in enumerate(word_aligns_list):
                    all_word_aligns.append(clean_aligns(word_aligns))
                    all_sent_src.append(sent_src[ind])
                    all_sent_tgt.append(sent_tgt[ind])
                    all_label_dict.append(label_dict[ind])
                    all_complete_sent_tgt.append(complete_sent_tgt[ind])
                    all_idx.append(idx[ind])
                tqdm_iterator.update(len(ids_src))
        tqdm_iterator.close()

        def parallel_sorter(sent_src_list, sent_tgt_list, label_dict_list, word_aligns_list, worker_id, shuffle):
            """Compute the (start, end) sort key of each extraction and pickle the list."""
            sorting_indices_list = []
            for ind in range(len(sent_src_list)):
                label = label_dict_list[ind]
                sent_tgt = sent_tgt_list[ind]
                sent_src = sent_src_list[ind]
                word_aligns = word_aligns_list[ind]
                # 'start_rel' is a (start, end) tuple; (10000, 10000) means
                # the extraction has no relation span.  The old comparison
                # against the bare int 10000 could never be true.
                if shuffle or label['start_rel'][0] == 10000:
                    sorting_indices_list.append((10000, 10000))
                    continue
                phrases = list(phrase_extraction(" ".join(sent_tgt), " ".join(sent_src), word_aligns))
                phrases = clean_phrases(phrases)
                start_rel_index = projection(phrases, label['start_rel'], sent_tgt)
                end_rel_index = projection(phrases, label['end_rel'], sent_tgt)
                sorting_indices_list.append((start_rel_index, end_rel_index))

            with open(args.output_file + '_' + str(worker_id), 'wb') as f:
                pickle.dump(sorting_indices_list, f)

        # NOTE: the target is a nested function, so this relies on a process
        # start method that does not pickle the target (e.g. 'fork').
        workers = 10
        each_load = int(len(all_sent_src) / workers) + 1
        processes = []
        for i in range(workers):
            chunk = slice(i * each_load, (i + 1) * each_load)
            processes.append(multiprocessing.Process(
                target=parallel_sorter,
                args=(all_sent_src[chunk], all_sent_tgt[chunk],
                      all_label_dict[chunk], all_word_aligns[chunk], i, shuffle),
            ))
        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()
        failed = [i for i, proc in enumerate(processes) if proc.exitcode != 0]
        if failed:
            # Without this check a stale result file from an earlier run
            # could be read in place of the failed worker's output.
            raise RuntimeError(f'Projection worker(s) {failed} exited with an error')

        all_sorting_indices = []
        for i in range(workers):
            worker_file = args.output_file + '_' + str(i)
            # These pickles were written by our own workers just above.
            with open(worker_file, 'rb') as f:
                all_sorting_indices.extend(pickle.load(f))
            os.remove(worker_file)

        # Group (sort key, extraction text) pairs by input line number.
        extractions_tuples = {}
        max_idx = -1
        for ind, index in enumerate(all_idx):
            if max_idx < index:
                max_idx = index
            extractions_tuples.setdefault(index, []).append(
                (all_sorting_indices[ind], all_complete_sent_tgt[ind])
            )

        for ind in range(max_idx + 1):
            extractions = extractions_tuples[ind]
            if shuffle:
                random.shuffle(extractions)
                sorted_extractions = extractions
            else:
                # Projected end index first, then start; 10000 sorts last.
                sorted_extractions = sorted(extractions, key=lambda x: (x[0][1], x[0][0]))
            new_extractions = [ext[1] for ext in sorted_extractions]
            complete_extraction = " <e> ".join(new_extractions) + ' <e> '
            if complete_extraction.strip() == "<a1> empty </a1> <e>":
                complete_extraction = "<e>"
            writer.write(complete_extraction.strip() + '\n')
        

def main():
    """Parse the command line, load the alignment model and run ``word_align``.

    The model and its config are loaded from ``--model_name_or_path``; the
    tokenizer is always ``bert-base-multilingual-cased``.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--inp1", default=None, type=str, required=True,
        help="Input text file with one source sentence per line.",
    )
    parser.add_argument(
        "--inp2", default=None, type=str, required=True,
        help="Input text file with the tagged, <e>-terminated extractions of each --inp1 line.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="File the ordered extractions are written to ('.debug' is appended with --debug).",
    )
    # The next four options are accepted but not read by this function.
    parser.add_argument(
        "--softmax_threshold", type=float, default=0.001
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--align_layer", type=int, default=8, help="layer for alignment extraction")
    parser.add_argument(
        "--extraction", default='softmax', type=str, help='softmax or entmax15'
    )
    parser.add_argument(
        "--lang", type=str, help='es or zh or hi'
    )
    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--shuffle", action='store_true')
    # Required: the value is used unconditionally below, so omitting it used
    # to fail with a TypeError on the ".ckpt" membership test.
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="The model checkpoint (name or path) to load the alignment model from.",
    )

    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device
    set_seed(args)
    config_class, model_class = BertConfig, BertForMaskedLM
    config = config_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config
    )
    tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
    word_align(args, model, tokenizer, args.shuffle)


if __name__ == "__main__":
    main()
