import math

import transformers
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import BertForMaskedLM, BertTokenizer, AutoConfig
import argparse
import torch
import os
from tqdm import tqdm
import json
import argparse
import pandas as pd
from os.path import join
from glob import glob
import torch.nn.functional as F


def load_vocab(vocab_filename):
    """Read a vocabulary file and return its entries.

    Args:
        vocab_filename: path of a text file holding one entry per line.

    Returns:
        list of str: every line of the file, in file order, with leading and
        trailing whitespace removed (blank lines become empty strings).
    """
    with open(vocab_filename, "r") as handle:
        return [entry.strip() for entry in handle]

# Path of the common-vocabulary file; main_test() reads it through load_vocab()
# to restrict both the evaluated facts and the candidate predictions.
vocab_filename = "../common_vocabs/common_vocab_cased_be_ro_al.txt"

def get_text(template, sub, obj, tokenizer):
    """Turn a '*'-delimited template into a flat list of token ids.

    The template is split on '*' and each piece is handled as follows:
      * 'cls', 'mask', 'sep', 'sep+' and '<extra_id_0>' .. '<extra_id_9>'
        become the corresponding single special-token id ('cls' is skipped
        when the tokenizer class name contains 'T5');
      * a piece starting with 'label' becomes the encoding of ' ' + obj;
      * a piece starting with 'sent_' becomes the encoding of ' ' + sub;
      * anything else is literal text in which '_' stands for a space.

    Args:
        template: template string, e.g. '*cls**sent_0**<extra_id_0>**label*'.
        sub: subject text substituted for 'sent_*' pieces.
        obj: object text substituted for 'label*' pieces.
        tokenizer: object providing encode(), convert_tokens_to_ids() and the
            cls/mask/eos token id attributes.

    Returns:
        list: token ids in template order.
    """
    def encode_plain(text):
        return tokenizer.encode(text, add_special_tokens=False)

    specials = {
        'cls': tokenizer.cls_token_id,
        'mask': tokenizer.mask_token_id,
        'sep': tokenizer.eos_token_id,
        'sep+': tokenizer.eos_token_id,
    }
    for n in range(10):
        sentinel = "<extra_id_%d>" % n
        specials[sentinel] = tokenizer.convert_tokens_to_ids(sentinel)

    # T5 does not have a cls token, so 'cls' pieces are dropped for it.
    skip_cls = 'T5' in type(tokenizer).__name__

    token_ids = []
    for piece in template.split('*'):
        if piece in specials:
            if piece == 'cls' and skip_cls:
                continue
            token_ids.append(specials[piece])
        elif piece.startswith('label'):
            token_ids.extend(encode_plain(' ' + obj))
        elif piece.startswith('sent_'):
            token_ids.extend(encode_plain(" " + sub))
        else:
            # Commands cannot contain spaces, so '_' is used in their place.
            literal = piece.replace('_', ' ')
            if len(literal) == 1:
                # Look a single character up directly: the t5 tokenizer might
                # add an extra space when encoding it.
                token_ids.append(tokenizer.convert_tokens_to_ids(literal))
            else:
                token_ids.extend(encode_plain(literal))
    return token_ids


def generate(dataset, template, model, tokenizer, target_number, beam, label=None, length_limit=None):
    """
    Generate templates based on given inputs.

    Every item of `dataset` is rendered with `get_text`, the decoder logits are
    averaged over all rendered inputs, and a beam search is run on that
    averaged distribution. A hypothesis counts one finished span each time it
    emits the next expected sentinel token or EOS, and stops being extended
    once its span counter exceeds `target_number`.

    Args:
        dataset: iterable of dicts with 'sub' and 'obj' keys.
        template: template string understood by `get_text`.
        model: seq2seq model, already on the GPU; it is called as
            model(input_ids, attention_mask=..., decoder_input_ids=...) and
            the first element of its output is used as the logits.
        tokenizer: tokenizer matching `model`.
        target_number: number of spans to generate.
        beam: beam width; `beam + 3` candidates are expanded per hypothesis.
        label: unused, kept for interface compatibility.
        length_limit: optional sequence; span number k may only be closed once
            it holds at least length_limit[k - 1] tokens.

    Returns:
        list of str: for each surviving hypothesis (highest score first), the
        concatenation of its generated tokens.

    Note:
        Inputs are moved with `.cuda()`, so a CUDA device is required.
    """
    input_tensors = []
    max_length = 0
    for item in dataset:
        input_text = get_text(template, item['sub'], item['obj'], tokenizer)
        input_ids = torch.tensor(input_text).long()
        max_length = max(max_length, input_ids.size(-1))
        input_tensors.append(input_ids)

    # Concatenate inputs as a right-padded batch
    input_ids = torch.ones((len(input_tensors), max_length)).long() * tokenizer.pad_token_id
    attention_mask = torch.zeros((len(input_tensors), max_length)).long()
    for row, tensor in enumerate(input_tensors):
        input_ids[row, :tensor.size(-1)] = tensor
        attention_mask[row, :tensor.size(-1)] = 1

    input_ids = input_ids.cuda()
    attention_mask = attention_mask.cuda()
    num_inputs = input_ids.size(0)

    # Maximum generate content length
    max_length = 20
    # Number of inputs per forward pass
    batch_size = 32

    # The check `start_mask - output_id` below assumes sentinel ids decrease by
    # one per index (<extra_id_1> == <extra_id_0> - 1), as in the T5 vocabulary.
    start_mask = tokenizer.convert_tokens_to_ids('<extra_id_0>')
    ori_decoder_input_ids = torch.zeros((num_inputs, max_length)).long()
    ori_decoder_input_ids[..., 0] = model.config.decoder_start_token_id

    current_output = [{'decoder_input_ids': ori_decoder_input_ids, 'll': 0, 'output_id': 1, 'output': [], 'last_length': -1}]
    for i in tqdm(range(max_length - 2)):
        new_current_output = []
        for item in current_output:
            if item['output_id'] > target_number:
                # Enough contents
                new_current_output.append(item)
                continue

            # Forward, in mini-batches over the dataset
            decoder_input_ids = item['decoder_input_ids']
            decoder_input_ids_gpu = decoder_input_ids.cuda()
            aggr_output = []
            for start in range(0, num_inputs, batch_size):
                end = min(start + batch_size, num_inputs)
                with torch.no_grad():
                    aggr_output.append(model(input_ids[start:end], attention_mask=attention_mask[start:end],
                                             decoder_input_ids=decoder_input_ids_gpu[start:end])[0])
            aggr_output = torch.cat(aggr_output, 0)

            # Gather results across all input sentences and keep the most
            # likely next tokens at decoding position i.
            step_logits = aggr_output.mean(0)[i]
            log_denominator = torch.logsumexp(step_logits, -1).item()
            candidate_logits = step_logits[:model.config.vocab_size]
            num_candidates = min(beam + 3, candidate_logits.size(-1))
            ids = torch.topk(candidate_logits, num_candidates).indices.tolist()

            for word_id in ids:
                # Every candidate extends the same parent hypothesis, so the
                # span counter must restart from the parent's value; otherwise
                # a sentinel/EOS candidate would bump it for all later ones.
                output_id = item['output_id']
                if word_id == start_mask - output_id or word_id == tokenizer.eos_token_id:
                    # Finish one part, unless the span is still too short
                    check = not (length_limit is not None
                                 and item['last_length'] < length_limit[output_id - 1])
                    output_id += 1
                    last_length = 0
                else:
                    last_length = item['last_length'] + 1
                    check = True

                output_text = item['output'] + [word_id]
                ll = item['ll'] + step_logits[word_id] - log_denominator
                new_decoder_input_ids = decoder_input_ids.clone()
                new_decoder_input_ids[..., i + 1] = word_id

                # Forbid single space token, "....", and ".........."
                # (hard-coded vocabulary ids)
                if word_id in [3, 19794, 22354]:
                    check = False

                # Forbid continuous "."
                if len(output_text) > 1 and output_text[-2] == 5 and output_text[-1] == 5:
                    check = False

                if check:
                    # Add new results to beam search pool
                    new_current_output.append({
                        'decoder_input_ids': new_decoder_input_ids,
                        'll': ll,
                        'output_id': output_id,
                        'output': output_text,
                        'last_length': last_length,
                    })

        if len(new_current_output) == 0:
            break

        new_current_output.sort(key=lambda x: x['ll'], reverse=True)
        current_output = new_current_output[:beam]

    result = []
    print("####### generated results #######")
    for item in current_output:
        generate_text = ''.join(tokenizer.convert_ids_to_tokens(token) for token in item['output'])
        print('--------------')
        print('score:', item['ll'])
        print('generated ids', item['output'])
        print('generated text', generate_text)
        result.append(generate_text)
    print("####### generated results #######\n")
    return result

def load_data(file):
    """Load (subject, object) pairs from a JSON-lines file.

    Args:
        file: path of a file with one JSON object per line; each object must
            provide the keys "sub_label" and "obj_label".

    Returns:
        list of dict: one {"sub": ..., "obj": ...} entry per non-blank line,
        in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks "sub_label" or "obj_label".
    """
    samples = []
    # The context manager closes the handle even if a line fails to parse.
    with open(file, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            item = json.loads(line)
            samples.append({"sub": item["sub_label"], "obj": item["obj_label"]})
    return samples

def search_template(model, tokenizer, task_name, seed, beam, output_dir, data_dir):
    """Generate templates for one relation and save them to disk.

    Loads `<data_dir>/<task_name>/train.jsonl`, runs `generate` on it, converts
    each generated string to the '*'-delimited template format, prints it and
    writes it as one line of `<output_dir>/<task_name>.txt`.

    Args:
        model: seq2seq model passed through to `generate`.
        tokenizer: tokenizer passed through to `generate`.
        task_name: relation name; used as data sub-directory and output stem.
        seed: unused, kept for interface compatibility.
        beam: beam width passed to `generate`.
        output_dir: directory for the template file (created if missing).
        data_dir: root directory holding one sub-directory per relation.
    """
    dataset = load_data(join(data_dir, task_name, "train.jsonl"))
    template = "*cls**sent_0**<extra_id_0>**label**<extra_id_1>**sep+*"
    generate_text = generate(dataset, template, model, tokenizer, target_number=2, beam=beam, label=None)
    print("####### generated templates #######")

    os.makedirs(output_dir, exist_ok=True)

    # Transform T5 outputs to our template format (applied in this order).
    replacements = (
        ('<extra_id_0>', '*cls**sent_0*'),
        ('<extra_id_1>', '*mask*'),
        ('<extra_id_2>', '*sep+*'),
        ('</s>', '*sep+*'),
        ('▁', '_'),
    )

    # The context manager guarantees the templates are flushed and the handle
    # is closed before the next relation is processed.
    with open(os.path.join(output_dir, task_name + ".txt"), 'w') as f:
        print("%s" % task_name)
        for text in generate_text:
            for old, new in replacements:
                text = text.replace(old, new)
            print(text)
            f.write(text + '\n')
    print("####### generated templates #######\n")


def main_train():
    """Command-line entry point for template search.

    Parses the options below, loads the seq2seq model and its tokenizer, moves
    the model to the GPU in eval mode, and calls `search_template` once for
    every sub-directory of `--data_dir`.
    """
    options = (
        ('--model', dict(type=str, default='t5-3b', help='pre-trained model')),
        ('--seed', dict(type=int, default=13, help="seeds")),
        ('--output_dir', dict(type=str, default='./baseline_Output')),
        ('--data_dir', dict(type=str, default="../LAMA_data/autoprompt_data", help="Data directory")),
        ('--beam', dict(type=int, default=10, help="Beam search width")),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model).cuda()
    model.eval()

    # Each relation lives in its own sub-directory; the glob pattern ends with
    # '/', so the relation name is the second-to-last path component.
    for relation_dir in glob(args.data_dir + "/*/"):
        search_template(
            model=model,
            tokenizer=tokenizer,
            task_name=relation_dir.split("/")[-2],
            seed=args.seed,
            beam=args.beam,
            output_dir=args.output_dir,
            data_dir=args.data_dir,
        )



def init_indices_for_filter_logprobs(vocab_subset, tokenizer, logger=None):
    """Map a word list to the model-vocabulary ids of its single-token words.

    A word is kept only if `tokenizer.tokenize(' ' + word)` yields exactly one
    token that is not the unknown token; every other word is reported with a
    warning and skipped.

    Args:
        vocab_subset: iterable of words.
        tokenizer: object providing tokenize(), convert_tokens_to_ids() and
            the `unk_token` attribute.
        logger: optional logger; when None, warnings are printed instead.

    Returns:
        tuple: (indices, index_list) where `index_list` is the list of kept
        token ids in input order and `indices` is the same data as a tensor.
    """
    index_list = []
    for word in vocab_subset:
        tokens = tokenizer.tokenize(' ' + word)
        if len(tokens) == 1 and tokens[0] != tokenizer.unk_token:
            index_list.append(tokenizer.convert_tokens_to_ids(tokens)[0])
            continue

        msg = "word {} from vocab_subset not in model vocabulary!".format(word)
        if logger is not None:
            logger.warning(msg)
        else:
            # No logger supplied: fall back to stdout rather than calling a
            # method on None.
            print("WARNING: {}".format(msg))
    indices = torch.as_tensor(index_list)
    return indices, index_list

def main_test():
    """Evaluate the generated templates with bert-base-cased.

    For every relation directory under the data directory this function:
      1. loads `test.jsonl`, keeps distinct (sub, obj) pairs whose object is
         in the common vocabulary;
      2. reads the first line of `<templateDir>/<relation>.txt` as template;
      3. fills the template with each subject, runs the masked LM in batches
         and ranks the common-vocabulary tokens at the first mask position;
      4. prints `<relation> <accuracy> <correct> <total>` and writes the top-k
         predictions to `baseline_topK/<relation>_predictions.jsonl`.

    If a filled template contains no mask token, the template pieces are
    printed and the function returns without evaluating further relations.

    Note:
        The model and tensors are moved with `.cuda()`, so a CUDA device is
        required.
    """
    templateDir = "./baseline_Output"
    dataDir = "../LAMA_data/autoprompt_data"
    output_topk = "baseline_topK"

    def enc(text):
        return tokenizer.encode(text, add_special_tokens=False)

    batch_size = 32
    language_model_name = "bert-base-cased"
    language_vocab_name = language_model_name
    mlm_config = AutoConfig.from_pretrained(language_model_name)
    tokenizer = BertTokenizer.from_pretrained(language_vocab_name)
    model = BertForMaskedLM.from_pretrained(language_model_name, config=mlm_config)
    model.eval()
    model = model.cuda()

    relation_directories = glob(dataDir + "/*/")
    special_token_mapping = {'cls': tokenizer.cls_token_id, 'mask': tokenizer.mask_token_id,
                             'sep': tokenizer.sep_token_id, 'sep+': tokenizer.sep_token_id}
    vocab = list(tokenizer.get_vocab().keys())
    vocab_subset = load_vocab(vocab_filename)
    filter_indices, index_list = init_indices_for_filter_logprobs(vocab_subset, tokenizer)

    # Set copy of the common vocabulary for O(1) membership tests.
    allowed_objects = set(vocab_subset)
    # Model-vocabulary id -> position inside the filtered distribution.
    vocab_to_common_vocab = {idx: cid for cid, idx in enumerate(index_list)}

    k = 5
    for relation_dir in relation_directories:
        task_name = relation_dir.split("/")[-2]
        print(task_name)
        data_file = join(dataDir, task_name, "test.jsonl")

        all_samples = []
        distinct_facts = set()
        for data_sample in load_data(data_file):
            # follow the LAMA setting, only keep distinct (sub, obj) pairs
            fact = (data_sample['sub'], data_sample['obj'])
            if fact in distinct_facts:
                continue
            if data_sample['obj'] not in allowed_objects:
                continue
            distinct_facts.add(fact)
            all_samples.append(data_sample)

        # Only the first line of the template file is evaluated.
        template_file = join(templateDir, task_name + ".txt")
        with open(template_file, "r") as template_handle:
            template = template_handle.readlines()[0].strip()

        template_list = template.split('*')

        result = {}
        list_of_predictions = {}
        steps = math.ceil(len(all_samples) * 1.0 / batch_size)
        for step in range(steps):
            start = step * batch_size
            end = min((step + 1) * batch_size, len(all_samples))
            samples = all_samples[start:end]

            labels = []
            max_len = 0
            input_ids_tensor = []
            masked_indices = []
            for item in samples:
                input_ids = []
                for part in template_list:
                    new_tokens = []
                    if part in special_token_mapping:
                        new_tokens.append(special_token_mapping[part])
                    elif part[:5] == 'sent_':
                        new_tokens += enc(" " + item["sub"])
                    else:
                        part = part.replace('_', ' ')  # there cannot be space in command, so use '_' to replace space
                        # handle special case when t5 tokenizer might add an extra space
                        if len(part) == 1:
                            new_tokens.append(tokenizer.convert_tokens_to_ids(part))
                        else:
                            new_tokens += enc(part)
                    input_ids += new_tokens
                try:
                    masked_indice = input_ids.index(tokenizer.mask_token_id)
                except ValueError:
                    # The template has no mask slot, so nothing can be
                    # predicted; show the offending template and stop.
                    print(template_list)
                    return
                masked_indices.append(masked_indice)
                max_len = max(max_len, len(input_ids))
                input_ids_tensor.append(input_ids)
                labels.append(tokenizer.convert_tokens_to_ids(item["obj"]))

            labels = torch.tensor(labels).long().cuda()
            masked_indices = torch.tensor(masked_indices).long().cuda()
            attention_mask = torch.zeros([len(labels), max_len]).long()
            # -100 marks positions that carry no label
            mlm_labels = torch.ones([len(labels), max_len]).long() * (-100)
            for i in range(len(labels)):
                seq_len = len(input_ids_tensor[i])
                attention_mask[i, :seq_len] = 1
                pad_len = max_len - seq_len
                if pad_len > 0:
                    input_ids_tensor[i] += [0] * pad_len
                mlm_labels[i, masked_indices[i]] = labels[i]
            input_ids_tensor = torch.tensor(input_ids_tensor).long().cuda()
            attention_mask = attention_mask.cuda()
            mlm_labels = mlm_labels.cuda()

            with torch.no_grad():
                _loss, logits = model(input_ids=input_ids_tensor,
                                      attention_mask=attention_mask,
                                      labels=mlm_labels,
                                      return_dict=False)
            log_probs = F.log_softmax(logits, dim=-1).cpu()

            preds = []
            topk = []
            common_vocab_loss = []

            # During testing, return accuracy and top-k predictions
            for i in range(log_probs.shape[0]):
                masked_index = masked_indices[i]
                log_prob = log_probs[i, masked_index]
                mlm_label = labels[i].item()

                # Restrict the distribution to the common vocabulary.
                log_prob = log_prob.index_select(dim=0, index=filter_indices)
                pred_common_vocab = torch.argmax(log_prob)
                pred = index_list[pred_common_vocab]
                # get top-k predictions
                topk_preds = []
                topk_log_prob, topk_ids = torch.topk(log_prob, k)
                for log_prob_i, idx in zip(topk_log_prob, topk_ids):
                    ori_idx = index_list[idx]
                    token = vocab[ori_idx]
                    topk_preds.append({'token': token, 'log_prob': log_prob_i.item()})
                topk.append(topk_preds)

                # compute entropy on common vocab
                common_logits = logits[i][masked_index].cpu().index_select(dim=0, index=filter_indices)
                common_log_prob = -F.log_softmax(common_logits, dim=-1)
                common_label_id = vocab_to_common_vocab[mlm_label]
                common_vocab_loss.append(common_log_prob[common_label_id].item())

                # mlm_label is the label stored at the mask position.
                preds.append(1 if pred == mlm_label else 0)

            for pred, sample, sample_topk, vocab_loss in zip(preds, samples, topk, common_vocab_loss):
                if task_name not in result:
                    result[task_name] = (0, 0, 0, 0.0)
                    list_of_predictions[task_name] = []
                rel_cor, rel_tot, _, rel_tot_loss = result[task_name]
                rel_tot += 1
                rel_cor += pred
                rel_tot_loss += vocab_loss
                # Accuracy over the samples counted so far for this relation.
                result[task_name] = (rel_cor, rel_tot, rel_cor / rel_tot if rel_tot > 0 else 0.0, rel_tot_loss)
                list_of_predictions[task_name].append({
                    'sub_label': sample['sub'],
                    'obj_label': sample['obj'],
                    'topk': sample_topk})

        # A relation without any usable sample has no entry; report zeros.
        rel_cor, rel_tot, rel_acc, _ = result.get(task_name, (0, 0, 0.0, 0.0))
        print("%s %f %d %d" % (task_name, rel_acc, rel_cor, rel_tot))

        print('Output top-k prediction to %s..' % output_topk)
        os.makedirs(output_topk, exist_ok=True)
        for rel in list_of_predictions:
            with open(os.path.join(output_topk, '%s_predictions.jsonl' % rel), 'w') as f:
                f.write('\n'.join([json.dumps(x) for x in list_of_predictions[rel]]))

if __name__ == '__main__':
    # Script entry point: only the evaluation step (main_test) runs.
    # Uncomment main_train() to run the template search step as well.
    # main_train()
    main_test()
