import pdb

from transformers import GPT2Config, AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from lsp_model.modeling_gpt2 import GPT2LMHeadPrefixModel
import random


import os
import re
import math
import time
from tqdm import tqdm
import numpy as np
import argparse, csv, json

def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's RNGs (and every CUDA device when present) from ``args.seed``."""
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ---------------------------------------------------------------------------
# Command-line interface. ``args`` is read as a module-level global by the
# functions below.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_file", nargs='+', type=str, required=True,
                    help="Paths to the CSV data files (with a JSON 'utt_list' column) for which we want to make DGPT predictions")
parser.add_argument("-o", "--output_file", type=str, required=True,
                    help="Path to the CSV file where we will save DGPT outputs")
parser.add_argument("-m", "--model_dir", type=str, required=True,
                    help="Path to the directory containing DGPT-prefix model")
parser.add_argument("-bm", "--base_model_dir", type=str, required=True,
                    help="Path to the directory containing pretrained DGPT model")
parser.add_argument("-n", "--num_samples", type=int, default=5,
                    help="Number of samples for each input")
parser.add_argument("-bs", "--batch_size", type=int, default=32,
                    help="Specifies the number of sentences that should be predicted at once")
parser.add_argument("-s", "--seed", type=int, default=42, help="fixed random seed")
# NOTE(review): prompt length + generation length (+ prefix length) may exceed the
# model's maximum number of positions with these defaults — confirm for the model used.
parser.add_argument("-msl", "--max_seq_length", type=int, default=1000,
                    help="max token length of the EOS-joined dialogue context (longer contexts are truncated)")
parser.add_argument("-mgl", "--max_generation_length", type=int, default=1000,
                    help="max length of a generation")

# Prefix (control code) hyper-parameters, copied onto the GPT-2 config.
parser.add_argument("-pl", "--prefix_length", type=int, default=10, help="the length of each prefix")
parser.add_argument("-ph", "--prefix_hidden_size", type=int, default=800, help="the size of the prefix hidden layer")
parser.add_argument("-pn", "--prefix_num", type=int, default=2, help="the number of prefixes")

# When omitted, all prefixes 0..prefix_num-1 are used.
parser.add_argument("-ci", '--control_indexes', nargs='+', type=int, default=None,
                    help="the indexes of the desired control code")

args = parser.parse_args()
set_seed(args)

import logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

if torch.cuda.is_available():
    device = torch.device("cuda")
    logging.info("Using GPU %s to make predictions", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    logging.info("Using CPU to make predictions")


def load_from_file(file_path_list, tokenizer):
    """Read dialogue contexts from CSV files and tokenize them into a TensorDataset.

    Every CSV row must have a ``utt_list`` column holding a JSON list of utterances.
    The utterances are stripped, joined with ``tokenizer.eos_token`` and terminated
    with one more EOS token. All contexts are then padded/truncated to
    ``args.max_seq_length`` tokens, or to the tokenizer's maximum length when that
    argument is ``None``.

    Args:
        file_path_list: iterable of CSV file paths, read in order.
        tokenizer: Hugging Face style tokenizer (callable, exposes ``eos_token``);
            it must have a pad token set for ``padding='max_length'``.

    Returns:
        TensorDataset of ``(input_ids, attention_mask, token_type_ids)`` long tensors,
        one row per CSV row in file order.
    """
    if args.max_seq_length is None:
        # ``max_len`` was replaced by ``model_max_length`` in newer transformers releases;
        # fall back to the old attribute for older tokenizers.
        max_length = getattr(tokenizer, "model_max_length", None)
        if max_length is None:
            max_length = tokenizer.max_len
    else:
        max_length = args.max_seq_length

    texts = []
    for filepath in file_path_list:
        # The csv module requires newline="" so quoted fields containing newlines parse correctly.
        with open(filepath, "r", newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                utt_list = [u.strip() for u in json.loads(row["utt_list"])]
                texts.append(tokenizer.eos_token.join(utt_list) + tokenizer.eos_token)

    batch_encoding = tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_token_type_ids=True,
    )
    return TensorDataset(
        torch.tensor(batch_encoding["input_ids"], dtype=torch.long),
        torch.tensor(batch_encoding["attention_mask"], dtype=torch.long),
        torch.tensor(batch_encoding["token_type_ids"], dtype=torch.long),
    )

def top_p_filtering(logits, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits with nucleus (top-p) filtering.

    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).

    Args:
        logits: logits over the vocabulary, either ``(vocab_size,)`` or batched
            ``(..., vocab_size)``; filtering is applied along the last dimension.
        top_p: if > 0.0, keep the highest-probability tokens up to and including the
            first one at which the cumulative probability exceeds ``top_p`` (so the most
            likely token is always kept). If <= 0.0, ``logits`` is returned unchanged.
        filter_value: value written into the removed positions.

    Returns:
        Logits with removed positions set to ``filter_value``; the input tensor is not
        modified in place.
    """
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the token that crosses the threshold is also kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask from sorted order back to vocabulary order (last dim works for any rank)
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits


def _sample_batch(model, input_ids, input_lengths, code_indexes, eos_token_id, max_generation_length, top_p=0.9):
    """Draw one nucleus-sampled continuation for every row of a right-padded batch.

    Decoding starts at the shortest prompt length and reuses ``past_key_values``.
    At each step, a row whose prompt is longer than the current position is fed its own
    next prompt token, while every other row samples. A sampling row that emits EOS is
    marked finished and receives EOS (used as padding) from then on. The loop stops
    when every row is finished or after ``max_generation_length`` steps.

    Returns:
        LongTensor ``(batch, min_prompt_len + steps)``: prompt tokens followed by sampled
        tokens; column ``j`` of row ``i`` was sampled iff ``j >= input_lengths[i]``.
    """
    pad_token_id = eos_token_id
    device = input_ids.device
    max_seq_len = int(input_lengths.max().item())
    min_seq_len = int(input_lengths.min().item())
    eos_not_in_sents = torch.ones(input_ids.size(0), dtype=torch.long, device=device)

    past = None
    step = min_seq_len
    current_input_ids = input_ids[:, :min_seq_len]
    generation_ids = current_input_ids.clone()
    while eos_not_in_sents.sum().item() != 0 and step - min_seq_len < max_generation_length:
        outputs = model(current_input_ids, past_key_values=past, prefix_indexes=code_indexes)
        next_token_logits = outputs[0][:, -1, :]
        past = outputs[1]

        probabilities = F.softmax(top_p_filtering(next_token_logits, top_p=top_p), dim=-1)
        next_tokens = torch.multinomial(probabilities, 1).squeeze(1)

        # 1 while a row is still being fed its own prompt, 0 once it is sampling.
        copy_or_decode_flag = (input_lengths > step).long()
        next_input_tokens = input_ids[:, step] if step < max_seq_len else pad_token_id

        # Only sampling rows can finish; once finished a row stays finished.
        eos_not_in_sents.mul_(
            next_tokens.ne(eos_token_id).long() * (1 - copy_or_decode_flag) + copy_or_decode_flag)

        # Copy the prompt token, or the sampled token (padding for finished rows).
        sampled = next_tokens * eos_not_in_sents + pad_token_id * (1 - eos_not_in_sents)
        tokens_to_add = next_input_tokens * copy_or_decode_flag + (1 - copy_or_decode_flag) * sampled

        generation_ids = torch.cat([generation_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
        current_input_ids = tokens_to_add.unsqueeze(-1)
        step += 1
    return generation_ids


def _decode_generations(generation_ids, input_lengths, tokenizer):
    """Decode the sampled suffix of each row and keep the text before its first EOS."""
    generations = []
    for row, prompt_length in zip(generation_ids, input_lengths.tolist()):
        # Decoding only the suffix guarantees that an empty sample yields "" instead of
        # the last context utterance.
        text = tokenizer.decode(row[prompt_length:], skip_special_tokens=False).replace("\n", " ")
        generations.append(text.split(tokenizer.eos_token)[0])
    return generations


def get_nucleus_sampling_generations_from_model(file_list, code_indexes, model, tokenizer, device):
    """Sample ``args.num_samples`` nucleus (top-p = 0.9) responses for every dialogue context.

    Contexts are loaded with ``load_from_file`` and decoded batch by batch with a custom
    batched sampling loop that reuses the model's ``past_key_values`` cache
    (Ref: https://github.com/huggingface/transformers/issues/3021).

    Args:
        file_list: CSV paths forwarded to ``load_from_file``.
        code_indexes: prefix (control code) indexes passed to the model as
            ``prefix_indexes``; ``None`` selects ``range(args.prefix_num)``.
        model: prefix model; assumed to already be on ``device``.
        tokenizer: tokenizer whose EOS token separates utterances and pads finished rows.
        device: torch device for all tensors.

    Returns:
        One list per input context (in file order), each holding ``args.num_samples``
        strings: the text generated after the context up to its first EOS, with newlines
        replaced by spaces.
    """
    if code_indexes is None:
        code_indexes = list(range(args.prefix_num))
    code_indexes = torch.LongTensor([code_indexes]).to(device)

    dataset = load_from_file(file_list, tokenizer)
    generation_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=args.batch_size)
    all_input_generations = []
    # NOTE: assuming that model is already in the given device
    model.eval()

    with torch.no_grad():
        for batch in generation_dataloader:
            input_ids, attn_mask, _token_type_ids = (t.to(device) for t in batch)
            # Prompt lengths come from the right-padded attention mask; they are the same
            # for every sample, so compute them once per batch.
            input_lengths = attn_mask.sum(dim=-1)
            current_batch_saved_generations = [[] for _ in range(input_ids.size(0))]
            for _ in range(args.num_samples):
                generation_ids = _sample_batch(
                    model, input_ids, input_lengths, code_indexes,
                    tokenizer.eos_token_id, args.max_generation_length)
                generations = _decode_generations(generation_ids, input_lengths, tokenizer)
                for saved, generation in zip(current_batch_saved_generations, generations):
                    saved.append(generation)
            logging.debug("Batch generations: %s", current_batch_saved_generations)
            all_input_generations.extend(current_batch_saved_generations)

    return all_input_generations

def main():
    """Load the prefix model, sample generations for ``args.data_file`` and write them as CSV.

    The tokenizer and the GPT-2 config come from ``args.base_model_dir``. The config is
    extended with the prefix hyper-parameters, and the model weights are loaded from
    ``args.model_dir``. The output CSV has a single ``generation`` column. Each row is
    the JSON-encoded list of samples for one input context, in input order.
    """
    logging.info(f"Loading DialoGPT-prefix model from {args.model_dir} ...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir)
    # GPT-2 has no dedicated pad token; padding in load_from_file reuses EOS.
    tokenizer.pad_token = tokenizer.eos_token

    config = GPT2Config.from_pretrained(args.base_model_dir)
    config.prefix_length = args.prefix_length
    config.prefix_hidden_size = args.prefix_hidden_size
    config.prefix_num = args.prefix_num
    # NOTE(review): assumes args.model_dir holds the trained prefix model saved with
    # save_pretrained() — confirm against the training script.
    model = GPT2LMHeadPrefixModel.from_pretrained(args.model_dir, config=config)
    model.to(device)
    logging.info(f"Model loaded to device:{device}")

    test_generations = get_nucleus_sampling_generations_from_model(
        args.data_file, args.control_indexes, model, tokenizer, device)

    # newline="" stops the csv module from emitting blank lines on platforms that translate "\n".
    with open(args.output_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["generation"])
        # writerow expects a sequence of fields: wrap the JSON string so it becomes one
        # cell rather than one cell per character.
        writer.writerows([json.dumps(row)] for row in test_generations)


if __name__ == "__main__":
    main()