import json
import os
import sys, getopt
from tqdm import tqdm
from time import time
import random
import torch
import torch.utils.data as data
from torch.utils.data import TensorDataset, SequentialSampler, DataLoader, RandomSampler
from utils import *
from pprint import pprint
from time import time
from transformers import BartTokenizerFast

# load data
def load_data():
    """Read the preprocessed training and evaluation splits.

    Returns:
        A ``(train_data, eval_data)`` tuple holding whatever ``read_json``
        returns for ``preprocess/train_data.json`` and
        ``preprocess/eval_data.json`` (paths are relative to the current
        working directory).
    """
    train_split, eval_split = (
        read_json(json_path)
        for json_path in ('preprocess/train_data.json', 'preprocess/eval_data.json')
    )
    return train_split, eval_split

def _pick_infilling_negatives(candidate_groups, negative_number, primary_limit=5):
    """Draw negative event pieces for one masked event without mutating the input.

    ``candidate_groups`` is indexed as three lists.  At most ``primary_limit``
    items are drawn from group 0; the rest of the ``negative_number`` budget is
    filled from groups 1 and 2 pooled together.  Fewer items are returned when
    the groups do not hold enough candidates.
    """
    primary = candidate_groups[0]
    # Never take more from the primary group than the overall budget allows
    # (a budget below ``primary_limit`` previously produced a negative slice
    # bound and therefore too many negatives).
    primary_quota = max(0, min(primary_limit, negative_number))
    picked = random.sample(primary, min(primary_quota, len(primary)))

    fallback = candidate_groups[1] + candidate_groups[2]
    still_needed = max(0, negative_number - len(picked))
    picked += random.sample(fallback, min(still_needed, len(fallback)))
    return picked


def infilling_process(i, max_length=None, event_max_length=None, other_sent_end_id=2, tkr=None, negative_number=10):
    """Build one event-infilling example (masked context, positive, negatives).

    One event span of the sample is replaced by the tokenizer's mask token.
    If the sample has an "end" event (``i['end_number'] > 0`` and a matching
    entry in ``i['end_list']``), the last event of that sentence is masked from
    its start to the end of the sentence and all following sentences are
    dropped.  Otherwise a random event of a random sentence is masked in place.

    Args:
        i: Sample dict.  The keys read here are ``'end_number'``,
            ``'end_list'``, ``'para'`` (list of sentences), ``'token'``
            (sentence-index string -> token list) and ``'event'``
            (sentence-index string -> list of ``[span, negative_groups]``).
            NOTE(review): the negative groups are presumably lists of strings,
            since they are passed to the tokenizer as a batch -- confirm
            against the preprocessing code.
        max_length: Fixed length of the encoded masked context.
        event_max_length: Fixed length of the labels / positive / negatives.
        other_sent_end_id: Minimum sentence index an end event must have when
            it is not located in the last sentence.
        tkr: Tokenizer; called directly and queried for ``mask_token``.
        negative_number: Maximum number of negative pieces to sample.

    Returns:
        ``(batch_dict, positive_dict, negative_dict)``.  ``batch_dict`` is the
        encoded masked context with two extra entries: ``'mask_loc'`` (one-hot
        list marking the mask position) and ``'labels'`` (positive piece ids
        padded with ``-100`` to ``event_max_length``).

    Raises:
        ValueError: If the mask token is no longer present after truncating
            the context to ``max_length``.
        IndexError: If the sample has no event to mask.
        KeyError: If an expected key is missing from ``i``.
    """
    mask_token = tkr.mask_token

    # Decide whether an "end" event should be masked.
    end_para_id = None
    if i['end_number'] > 0:
        last_para_id = len(i['para']) - 1
        if int(i['end_list'][-1][0]) == last_para_id:
            end_para_id = last_para_id
        else:
            # No early exit on purpose: the LAST qualifying entry wins.
            for end_entry in i['end_list']:
                if int(end_entry[0]) >= other_sent_end_id:
                    end_para_id = int(end_entry[0])

    if end_para_id is not None:
        para_key = str(end_para_id)
        event = i['event'][para_key][-1]
        end_head, _ = event[0]
        tokens = i['token'][para_key]
        # Everything from the event start to the end of the sentence is masked.
        mask_piece = tokens[end_head:]
        masked_sentence = list2string(tokens[:end_head] + [mask_token])
        string = list2string(i['para'][:end_para_id] + [masked_sentence])
    else:
        para_key = random.choice(list(i['event']))
        event = random.choice(i['event'][para_key])
        span = event[0]
        head, tail = span[0], span[1]
        tokens = i['token'][para_key]
        mask_piece = tokens[head:tail]
        masked_sentence = list2string(tokens[:head] + [mask_token] + tokens[tail:])
        para_idx = int(para_key)
        string = list2string(i['para'][:para_idx] + [masked_sentence] + i['para'][para_idx + 1:])

    positive_piece = list2string(mask_piece)
    # Sampling works on copies so the dataset's candidate lists keep their order.
    negative_piece = _pick_infilling_negatives(event[1], negative_number)

    batch_dict = tkr(string, padding='max_length', truncation=True, max_length=max_length)
    input_ids = batch_dict['input_ids']
    try:
        mask_position = input_ids.index(tkr.convert_tokens_to_ids(mask_token))
    except ValueError:
        raise ValueError('mask token was truncated out of the infilling input') from None
    mask_loc = [0] * len(input_ids)
    mask_loc[mask_position] = 1
    batch_dict['mask_loc'] = mask_loc

    # Truncation is requested explicitly so the -100 padding below can never
    # be asked to pad a sequence that is already longer than the target width.
    labels = tkr(positive_piece, truncation=True, max_length=event_max_length)['input_ids']
    batch_dict['labels'] = labels + [-100] * (event_max_length - len(labels))

    positive_dict = tkr(positive_piece, padding='max_length', truncation=True, max_length=event_max_length)
    negative_dict = tkr(negative_piece, padding='max_length', truncation=True, max_length=event_max_length)

    return batch_dict, positive_dict, negative_dict

def class_tagging_process(i, max_length=None, event_max_length=None, tkr=None, negative_cls_number=5, event_span_number=5):
    """Build the event-classification and event-tagging examples for one sample.

    Up to ``event_span_number`` events are sampled from ``i['event']``; one of
    them is the target.  Two encoded examples are produced from the paragraph:

    * classification: the target event is replaced by the mask token and the
      text is followed by ``"Options:"`` plus a shuffled, numbered list holding
      the true event and sampled negatives; the label is the true event.
    * position (tagging): the target event is replaced by one sampled negative
      and the text is followed by ``"Event:" <mask> "is wrong"``; the label is
      that negative wrapped in cls/eos ids.

    Args:
        i: Sample dict.  Keys read here: ``'event'`` (sentence-index string ->
            list of ``(event_loc, ng_event)``), ``'para'`` (list of sentences)
            and ``'token'`` (sentence-index string -> token list).
        max_length: Width the input ids / attention mask are padded to.
        event_max_length: Width the labels are padded to (with ``-100``).
        tkr: Tokenizer; called directly and queried for special tokens.
        negative_cls_number: Maximum number of negatives sampled for the target.
        event_span_number: Maximum number of events sampled as spans.

    Returns:
        ``(cls_batch_dict, pos_batch_dict)``; each is a plain dict with
        ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` lists.

    Raises:
        IndexError: If no event of the sample has at least one entry in
            ``ng_event[0]`` (nothing to pick as target).
    """
    cls_token, mask_token, eos_token = tkr.cls_token, tkr.mask_token, tkr.eos_token
    cls_id = tkr.convert_tokens_to_ids(cls_token)
    eos_id = tkr.convert_tokens_to_ids(eos_token)
    mask_id = tkr.convert_tokens_to_ids(mask_token)
    
    # sampling
    # Entries are [sentence key, event index in that sentence, tier]; tier 2
    # means ng_event[0] has >= 2 items, tier 1 means it has exactly 1.
    pos2_event_para_id_list = []
    pos1_event_para_id_list = []
    for k,v in i['event'].items():
        for event_frag_id, (event_loc, ng_event) in enumerate(v):
            if len(ng_event[0]) >= 2:
                pos2_event_para_id_list.append([k, event_frag_id, 2])
            elif len(ng_event[0]) >= 1:
                pos1_event_para_id_list.append([k, event_frag_id, 1])

    random.shuffle(pos1_event_para_id_list)
    random.shuffle(pos2_event_para_id_list)

    # Prefer tier-2 events; top up with tier-1 events when there are too few.
    all_event_para_id_list = pos2_event_para_id_list[:event_span_number]
    if len(all_event_para_id_list) < event_span_number:
        all_event_para_id_list += pos1_event_para_id_list[:event_span_number-len(all_event_para_id_list)]
    
    # The first sampled event (before sorting) is the target; the list is then
    # put into document order and the target's new position is one-hot marked.
    target_event_loc = all_event_para_id_list[0]
    all_event_para_id_list.sort(key=lambda x:(int(x[0]), x[1]))
    target_event_list = [0]*event_span_number
    target_event_list[all_event_para_id_list.index(target_event_loc)] = 1

    # Parallel lists, one entry per text segment:
    #   sent_piece_list - segment text
    #   sent_piece_sign - 1 if the segment is a sampled event span
    #   sent_neg_sign   - 1 if the segment is the target event
    # sent_neg_list holds the tokenized negatives of the target event.
    sent_piece_list = []
    sent_piece_sign = []
    sent_neg_sign = []
    sent_neg_list = []

    for k, v in enumerate(i['para']):
        # Sampled events located in sentence k, plus their target flags.
        tmp_list = []
        tmp_sign = []
        for mk in range(len(all_event_para_id_list)):
            if all_event_para_id_list[mk][0] == str(k):
                tmp_list.append(all_event_para_id_list[mk][1])
                if target_event_list[mk] == 1:
                    tmp_sign.append(1)
                else:
                    tmp_sign.append(0)

        if len(tmp_list) == 0:
            # Sentence without sampled events is kept whole.
            sent_piece_list.append(v)
            sent_piece_sign.append(0)
            sent_neg_sign.append(0)
        else:
            sent_start_id = 0
            for mk in range(len(tmp_list)):
                event_loc = i['event'][str(k)][tmp_list[mk]][0]
                # Text between the previous event (or sentence start) and this event.
                tmp = list2string(i['token'][str(k)][sent_start_id:event_loc[0]])
                if len(tmp) > 0:
                    sent_piece_list.append(tmp)
                    sent_piece_sign.append(0)
                    sent_neg_sign.append(0)
                # The event span itself.
                tmp = list2string(i['token'][str(k)][event_loc[0]:event_loc[1]])
                if len(tmp) > 0:
                    sent_piece_list.append(tmp)
                    sent_piece_sign.append(1)
                    if tmp_sign[mk] == 1:
                        sent_neg_sign.append(1)
                        ng_1 = i['event'][str(k)][tmp_list[mk]][1][0]
                        ng_2 = i['event'][str(k)][tmp_list[mk]][1][1] + i['event'][str(k)][tmp_list[mk]][1][2]
                        # NOTE(review): ng_1 is the list stored inside `i`, so this
                        # shuffle reorders the caller's data in place.
                        random.shuffle(ng_1)
                        random.shuffle(ng_2)
                        # Up to negative_cls_number-2 from group 0, the rest from groups 1+2.
                        sent_neg_list = ng_1[:negative_cls_number-2]
                        need_neg_num = negative_cls_number-len(sent_neg_list)
                        sent_neg_list += ng_2[:need_neg_num]
                        sent_neg_list = [tkr(mtmp, max_length=event_max_length-2, add_special_tokens=False)['input_ids'] for mtmp in sent_neg_list]
                    else:
                        sent_neg_sign.append(0)
                sent_start_id = event_loc[1]
            # NOTE(review): tokens after the last event of the sentence
            # (from sent_start_id onward) are never appended here, whereas
            # conj_class_tagging_process keeps the tail -- confirm whether
            # dropping it is intended.

    ### classification
    cls_list = [cls_id]
    options = []
    for mj_id, mj in enumerate(sent_piece_list):
        if sent_neg_sign[mj_id] == 0:
            cls_list += tkr(mj, add_special_tokens=False)['input_ids']
        else:
            # Target event: masked in the text, kept as the gold option/label.
            cls_list += [mask_id]
            options.append(tkr(mj, max_length=event_max_length-2, add_special_tokens=False)['input_ids'])
            cls_label = tkr(mj, max_length=event_max_length)['input_ids']

    cls_list += tkr("Options:", add_special_tokens=False)['input_ids']
    # Token budget left for negative options after reserving the gold option.
    remain_length = max_length - len(cls_list) - len(options[0]) - 3
    
    for mj in sent_neg_list:
        # Each option costs its tokens plus 2 (budgeted for the "<n>:" prefix).
        if len(mj) + 2 < remain_length:
            options.append(mj)
            remain_length -= (len(mj) + 2)
    
    random.shuffle(options)
    for mj_id, mj in enumerate(options):
        options_id = '%s:' % str(int(mj_id+1))
        cls_list = cls_list + tkr(options_id, add_special_tokens=False)['input_ids'] + mj
    cls_list += [eos_id]

    # NOTE(review): nothing truncates cls_list; if it is longer than
    # max_length the pad count below is negative and the result stays longer
    # than max_length.
    cls_batch_dict = {}
    attention_mask_len = len(cls_list)
    cls_batch_dict['input_ids'] = cls_list + [tkr.pad_token_id]*(max_length-len(cls_list))
    cls_batch_dict['attention_mask'] = [1]*attention_mask_len + [0]*(max_length-attention_mask_len)
    cls_batch_dict['labels'] = cls_label + [-100]*(event_max_length-len(cls_label))

    ### position
    pos_list = [cls_id]
    for mj_id, mj in enumerate(sent_piece_list):
        if sent_piece_sign[mj_id] == 0:
            pos_list += tkr(mj, add_special_tokens=False)['input_ids']
        else:
            if sent_neg_sign[mj_id] == 0:
                pos_list += tkr(mj, add_special_tokens=False)['input_ids']
            else:
                # Target event is swapped for a random negative; that negative
                # (wrapped in cls/eos) is what the model must output.
                random.shuffle(sent_neg_list)
                pos_list += sent_neg_list[0]
                pos_label = [cls_id] + sent_neg_list[0] + [eos_id]

    pos_list += tkr("Event:", add_special_tokens=False)['input_ids']
    pos_list += [mask_id]
    pos_list += tkr("is wrong", add_special_tokens=False)['input_ids']
    pos_list += [eos_id]

    # NOTE(review): same missing max_length guard as for cls_list above.
    pos_batch_dict = {}
    attention_mask_len = len(pos_list)
    pos_batch_dict['input_ids'] = pos_list + [tkr.pad_token_id]*(max_length-len(pos_list))
    pos_batch_dict['attention_mask'] = [1]*attention_mask_len + [0]*(max_length-attention_mask_len)
    pos_batch_dict['labels'] = pos_label + [-100]*(event_max_length-len(pos_label))
    
    return cls_batch_dict, pos_batch_dict

def _conj_pad_example(input_ids, labels, max_length, label_length, pad_id):
    """Pad ids and labels to fixed widths and build the matching attention mask."""
    if len(input_ids) > max_length:
        # An over-long sequence cannot be padded to a fixed width and would
        # later break batching, so fail here where the caller can skip it.
        raise ValueError(
            'encoded sequence has %d tokens, exceeding max_length=%d'
            % (len(input_ids), max_length))
    pad_count = max_length - len(input_ids)
    return {
        'input_ids': input_ids + [pad_id] * pad_count,
        'attention_mask': [1] * len(input_ids) + [0] * pad_count,
        'labels': labels + [-100] * (label_length - len(labels)),
    }


def conj_class_tagging_process(i, max_length=None, event_max_length=None, tkr=None, negative_cls_number=5, event_span_number=5):
    """Build the conjunction classification and tagging examples for one sample.

    One conjunction of the sample that has replacement candidates (according
    to ``generate_conj_can_retrieval_conj()``) is chosen at random.  Two
    encoded examples are produced from the paragraph:

    * classification: the conjunction is replaced by the mask token and the
      text is followed by ``"Options:"`` plus a shuffled, numbered list holding
      the true conjunction and sampled candidates; the label is the true
      conjunction.
    * position (tagging): the conjunction is replaced by one sampled candidate
      and the text is followed by ``"Event:" <mask> "is wrong"``; the label is
      that candidate wrapped in cls/eos ids.

    Args:
        i: Sample dict.  Keys read here: ``'conj'`` (list of
            ``[sentence index, (start, end), conjunction key]``), ``'para'``
            (list of sentences) and ``'token'`` (sentence-index string ->
            token list).  The sample is not modified.
        max_length: Width the input ids / attention mask are padded to.
        event_max_length: Width the labels are padded to (with ``-100``).
        tkr: Tokenizer; called directly and queried for special tokens.
        negative_cls_number: Maximum number of replacement candidates sampled.
        event_span_number: Unused; kept for signature parity with
            ``class_tagging_process``.

    Returns:
        ``(cls_batch_dict, pos_batch_dict)``; each is a plain dict with
        ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` lists.

    Raises:
        ValueError: If the sample has no conjunction with candidates, the
            chosen conjunction points outside the paragraph, or an encoded
            sequence is longer than ``max_length``.
        IndexError: If the chosen conjunction has no candidates to substitute.
    """
    cls_id = tkr.convert_tokens_to_ids(tkr.cls_token)
    eos_id = tkr.convert_tokens_to_ids(tkr.eos_token)
    mask_id = tkr.convert_tokens_to_ids(tkr.mask_token)

    def encode(text, limit=None):
        """Token ids of ``text`` without special tokens, optionally truncated."""
        if limit is None:
            return tkr(text, add_special_tokens=False)['input_ids']
        return tkr(text, max_length=limit, truncation=True, add_special_tokens=False)['input_ids']

    # Pick a conjunction uniformly among those that have candidates.  The
    # sample's own list and the shared candidate table are left untouched.
    retrievable = generate_conj_can_retrieval_conj()
    eligible = [loc for loc in i['conj'] if loc[2] in retrievable]
    if not eligible:
        raise ValueError('sample has no conjunction with retrievable replacement candidates')
    conj_para_loc = random.choice(eligible)

    candidates = list(retrievable[conj_para_loc[2]])
    wanted = max(0, min(negative_cls_number, len(candidates)))
    sent_neg_list = [encode(cand, event_max_length - 2)
                     for cand in random.sample(candidates, wanted)]

    target_idx = int(conj_para_loc[0])
    if not 0 <= target_idx < len(i['para']):
        raise ValueError('conjunction refers to sentence %d, which is outside the paragraph' % target_idx)
    span_start, span_end = conj_para_loc[1][0], conj_para_loc[1][1]
    target_tokens = i['token'][str(target_idx)]
    target_text = list2string(target_tokens[span_start:span_end])

    # Tokenize every context segment once; ``None`` marks the conjunction slot.
    context_ids = []
    for sent_idx, sentence in enumerate(i['para']):
        if sent_idx != target_idx:
            context_ids.append(encode(sentence))
            continue
        if target_tokens[:span_start]:
            context_ids.append(encode(list2string(target_tokens[:span_start])))
        context_ids.append(None)
        if target_tokens[span_end:]:
            context_ids.append(encode(list2string(target_tokens[span_end:])))

    ### classification
    cls_list = [cls_id]
    for piece_ids in context_ids:
        if piece_ids is None:
            cls_list.append(mask_id)
        else:
            cls_list += piece_ids
    gold_option = encode(target_text, event_max_length - 2)
    cls_label = tkr(target_text, max_length=event_max_length, truncation=True)['input_ids']

    cls_list += encode("Options:")
    # Token budget left for extra options after reserving the gold option.
    remain_length = max_length - len(cls_list) - len(gold_option) - 3
    options = [gold_option]
    for neg_ids in sent_neg_list:
        # Each option costs its tokens plus 2 (budgeted for the "<n>:" prefix).
        if len(neg_ids) + 2 < remain_length:
            options.append(neg_ids)
            remain_length -= len(neg_ids) + 2

    random.shuffle(options)
    for option_no, option_ids in enumerate(options, start=1):
        cls_list = cls_list + encode('%s:' % option_no) + option_ids
    cls_list.append(eos_id)

    cls_batch_dict = _conj_pad_example(cls_list, cls_label, max_length, event_max_length, tkr.pad_token_id)

    ### position
    # The conjunction is swapped for one random candidate the model must spot.
    wrong_ids = random.choice(sent_neg_list)
    pos_list = [cls_id]
    for piece_ids in context_ids:
        pos_list += wrong_ids if piece_ids is None else piece_ids
    pos_label = [cls_id] + wrong_ids + [eos_id]

    pos_list += encode("Event:")
    pos_list.append(mask_id)
    pos_list += encode("is wrong")
    pos_list.append(eos_id)

    pos_batch_dict = _conj_pad_example(pos_list, pos_label, max_length, event_max_length, tkr.pad_token_id)

    return cls_batch_dict, pos_batch_dict

def collate_fn(data):
    """Stack per-sample tensor tuples into batch tensors, skipping failed samples.

    A sample whose first field is ``None`` (the placeholder the dataset returns
    when a sample could not be built) is dropped before stacking.

    Args:
        data: Iterable of equally long tuples of tensors (or of ``None``s).

    Returns:
        A tuple with one tensor per field, each stacked along a new dim 0.

    Raises:
        IndexError: If no sample in the batch is valid.
    """
    # Identity check: ``== None`` on a tensor goes through tensor comparison.
    valid_samples = [sample for sample in data if sample[0] is not None]
    if not valid_samples:
        raise IndexError('collate_fn received a batch with no valid samples')
    # zip(*) regroups the samples field by field, whatever the field count is.
    return tuple(torch.stack(field, 0) for field in zip(*valid_samples))

class EventDataset(data.Dataset):
    """Dataset yielding the tensors for infilling, event and conjunction tasks.

    Each item is a 20-tuple, in this order: infilling input ids, attention
    mask, labels, mask location; positive piece ids and mask; negative piece
    ids and mask; event classification ids, mask, labels; event tagging ids,
    mask, labels; conjunction classification ids, mask, labels; conjunction
    tagging ids, mask, labels.  When a sample cannot be built, a 20-tuple of
    ``None`` is returned instead so that ``collate_fn`` can drop it.
    """

    # Number of fields per item; the failure placeholder must match it.
    _NUM_FIELDS = 20

    def __init__(self, data, tkr, inf_max_length=None, cls_max_length=None, event_max_length=None, conj_max_length=None, other_sent_end_id=2):
        """Store the raw samples, the tokenizer and the length limits."""
        self.data = data
        self.tkr = tkr
        self.inf_max_length = inf_max_length
        self.cls_max_length = cls_max_length
        self.event_max_length = event_max_length
        self.conj_max_length = conj_max_length
        self.other_sent_end_id = other_sent_end_id

    @staticmethod
    def _long_tensors(encoding, keys=('input_ids', 'attention_mask', 'labels')):
        """Convert the requested entries of ``encoding`` to LongTensors."""
        return tuple(torch.LongTensor(encoding[key]) for key in keys)

    def __getitem__(self, index):
        """Return the 20 tensors for sample ``index`` or a 20-tuple of ``None``."""
        i = self.data[index]

        try:
            ##### Event Infilling
            infilling_input_dict, infilling_positive_dict, infilling_negative_dict = infilling_process(
                i, self.inf_max_length, self.event_max_length, self.other_sent_end_id, self.tkr)
            infilling = self._long_tensors(infilling_input_dict)
            infilling_mask_loc = torch.FloatTensor(infilling_input_dict['mask_loc'])
            positive_piece = self._long_tensors(infilling_positive_dict, ('input_ids', 'attention_mask'))
            negative_piece = self._long_tensors(infilling_negative_dict, ('input_ids', 'attention_mask'))

            ##### Classify and Tagging
            cls_batch_dict, pos_batch_dict = class_tagging_process(
                i, self.cls_max_length, self.event_max_length, self.tkr)
            event_cls = self._long_tensors(cls_batch_dict)
            event_pos = self._long_tensors(pos_batch_dict)

            ##### Conjunction Classify and Tagging
            cls_batch_dict, pos_batch_dict = conj_class_tagging_process(
                i, self.inf_max_length + 2 * self.conj_max_length, self.conj_max_length, self.tkr)
            conj_cls = self._long_tensors(cls_batch_dict)
            conj_pos = self._long_tensors(pos_batch_dict)

            return (infilling + (infilling_mask_loc,)
                    + positive_piece + negative_piece
                    + event_cls + event_pos
                    + conj_cls + conj_pos)
        except Exception:
            # Best effort: a sample that cannot be built is skipped by
            # collate_fn.  ``Exception`` (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate; the reason is logged
            # at debug level so skipped samples can be diagnosed.
            import logging
            logging.getLogger(__name__).debug(
                'skipping sample %s', index, exc_info=True)
            return (None,) * self._NUM_FIELDS

    def __len__(self):
        """Number of raw samples."""
        return len(self.data)

def get_loader(batch_size, tkr, num_workers, inf_max_length=168, cls_max_length=224, event_max_length=30, conj_max_length=10, test=False):
    """Create the training and evaluation data loaders.

    Both splits come from ``load_data()``, are wrapped in ``EventDataset`` and
    served through a ``DistributedSampler`` with ``pin_memory=True``,
    ``drop_last=True`` and ``collate_fn`` as the collate function.  The
    evaluation loader uses four times the training batch size.

    NOTE(review): the samplers are built without explicit rank / replica
    arguments, so this presumably requires an initialised torch.distributed
    process group -- confirm against the training script.

    Args:
        batch_size: Training batch size.
        tkr: Tokenizer handed to the datasets.
        num_workers: Worker process count for both loaders.
        inf_max_length, cls_max_length, event_max_length, conj_max_length:
            Length limits forwarded to ``EventDataset``.
        test: Not used by this function.

    Returns:
        ``(train_data_loader, eval_data_loader)``.
    """
    def make_loader(samples, loader_batch_size):
        # Dataset -> distributed sampler -> loader, identical for both splits.
        dataset = EventDataset(tkr=tkr, data=samples, inf_max_length=inf_max_length,
                               cls_max_length=cls_max_length, event_max_length=event_max_length,
                               conj_max_length=conj_max_length)
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=loader_batch_size,
                                           num_workers=num_workers,
                                           sampler=sampler,
                                           pin_memory=True,
                                           drop_last=True,
                                           collate_fn=collate_fn)

    train_samples, eval_samples = load_data()
    train_data_loader = make_loader(train_samples, batch_size)
    eval_data_loader = make_loader(eval_samples, batch_size * 4)

    return train_data_loader, eval_data_loader


if __name__ == '__main__':
    # Script entry point: only instantiates the BART-large fast tokenizer;
    # no loader is built and nothing else runs when the file is executed.
    tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large')

