import pdb

from transformers import GPT2Config, AutoModelForCausalLM, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from lsp_model.modeling_gpt2 import GPT2LMHeadPrefixModel
import random
from tqdm import tqdm, trange
from typing import Tuple

import os
import re
import math
import time
from tqdm import tqdm
import numpy as np
import argparse, csv, json

from transformers.generation_logits_process import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

def set_seed(args):
    """Seed every random number generator this script relies on.

    Python's ``random``, NumPy and PyTorch are all seeded with
    ``args.seed``; when CUDA is available, every CUDA device is seeded too.
    """
    seed_value = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed_value)


# Command-line interface. Flags are declared in the same order as before so
# that the generated --help output is unchanged.
parser = argparse.ArgumentParser()

# --- task, data and model locations ---
parser.add_argument(
    "--task", default="toxicity", type=str,
    help="task name, toxicity or stance",
)
parser.add_argument(
    "-td", "--train_data_file", type=str, nargs="+", required=True,
    help="Paths to the training data file",
)
parser.add_argument(
    "-ed", "--eval_data_file", type=str, nargs="+",
    help="Paths to the evaluation data file",
)
parser.add_argument(
    "-to", "--train_output_dir", type=str, required=True,
    help="Path to where we will save checkpoints",
)
parser.add_argument(
    "-eo", "--eval_output_dir", type=str,
    help="Path to where we will save eval outputs",
)
parser.add_argument(
    "--balanced", action="store_true",
    help="use balanced dataset for training",
)
parser.add_argument(
    "--one_utt", action="store_true",
    help="use only one utt in the history.",
)
parser.add_argument(
    "--withstance", action="store_true",
    help="use both stance and offensiveness for label.",
)
parser.add_argument(
    "-m", "--model_dir", type=str,
    help="Path to the directory containing DGPT-prefix model",
)
parser.add_argument(
    "-bm", "--base_model_dir", type=str, required=True,
    help="Path to the directory containing pretrained DGPT model",
)

# --- batching, lengths and runtime ---
parser.add_argument(
    "-n", "--num_samples", type=int, default=5,
    help="Number of samples for each input",
)
parser.add_argument(
    "-bs", "--batch_size", type=int, default=32,
    help="Specifies the number of sentences that should be predicted at once",
)
parser.add_argument(
    "-s", "--seed", type=int, default=42,
    help="fixed random seed",
)
parser.add_argument(
    "-msl", "--max_seq_length", type=int, default=1000,
    help="max sequence length of an utterance",
)
parser.add_argument(
    "-mgl", "--max_generation_length", type=int, default=1000,
    help="max length of a generation",
)
# The string default is converted to int by argparse through ``type``.
parser.add_argument("--device", default="0", type=int)

# --- prefix configuration ---
parser.add_argument(
    "-pl", "--prefix_length", type=int, default=10,
    help="the length of each prefix",
)
parser.add_argument(
    "-ph", "--prefix_hidden_size", type=int, default=800,
    help="the size of the prefix hidden layer",
)
parser.add_argument(
    "-pn", "--prefix_num", type=int, default=2,
    help="the number of prefixes",
)
parser.add_argument(
    "-ci", "--control_indexes", type=int, nargs="+", default=None,
    help="the indexes of the desired control code",
)

# --- loss weighting ---
parser.add_argument(
    "--gen_weight", type=float, default=0.0,
    help="weight of the generative loss",
)
parser.add_argument(
    "--disc_weight", type=float, default=0.0,
    help="weight of the discriminative loss",
)

# --- run modes ---
parser.add_argument(
    "--do_train", action="store_true",
    help="Whether to run training.",
)
parser.add_argument(
    "--do_eval", action="store_true",
    help="Whether to run eval on the dev set.",
)
parser.add_argument(
    "--evaluate_during_training", action="store_true",
    help="Run evaluation during training at each logging step.",
)

# --- optimisation ---
parser.add_argument(
    "--learning_rate", type=float, default=5e-5,
    help="The initial learning rate for Adam.",
)
parser.add_argument(
    "--weight_decay", type=float, default=0.0,
    help="Weight decay if we apply some.",
)
parser.add_argument(
    "--adam_epsilon", type=float, default=1e-8,
    help="Epsilon for Adam optimizer.",
)
parser.add_argument(
    "--max_grad_norm", type=float, default=1.0,
    help="Max gradient norm.",
)

# --- schedule, logging and checkpointing ---
parser.add_argument(
    "--num_train_epochs", type=float, default=3.0,
    help="Total number of training epochs to perform.",
)
parser.add_argument(
    "--max_steps", type=int, default=-1,
    help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument(
    "--warmup_steps", type=int, default=0,
    help="Linear warmup over warmup_steps.",
)
parser.add_argument(
    "--logging_steps", default=500, type=int,
    help="Log every X updates steps.",
)
parser.add_argument(
    "--save_steps", default=500, type=int,
    help="Save checkpoint every X updates steps.",
)
parser.add_argument(
    "--overwrite_output_dir", action="store_true",
    help="Overwrite the content of the output directory",
)

# --- sampling ---
parser.add_argument("--top_k", default=50, type=int)
parser.add_argument("--top_p", default=0.9, type=float)
parser.add_argument("--temperature", default=None, type=float)
parser.add_argument("--min_length", default=10, type=int)

# Parse the command line once at import time; the functions below read the
# resulting module-level ``args``.
args = parser.parse_args()
set_seed(args)

# Generation output is written by --do_eval and, while training, by
# --evaluate_during_training. Both need the evaluation input files and an
# output directory, so reject the run up front instead of failing later with
# a TypeError on a None path.
_eval_requested = args.do_eval or (args.do_train and args.evaluate_during_training)
if _eval_requested and not (args.eval_data_file and args.eval_output_dir):
    parser.error(
        "--eval_data_file and --eval_output_dir are required with "
        "--do_eval or --evaluate_during_training"
    )

# Refuse to clobber an existing, non-empty checkpoint directory.
if (
        os.path.exists(args.train_output_dir)
        and os.listdir(args.train_output_dir)
        and args.do_train
        and not args.overwrite_output_dir
):
    raise ValueError(
        "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
            args.train_output_dir
        )
    )
# Same protection for the evaluation output directory. --eval_output_dir is
# optional, so it must be checked for None before touching the filesystem.
if (
        args.eval_output_dir is not None
        and os.path.exists(args.eval_output_dir)
        and os.listdir(args.eval_output_dir)
        and (args.do_eval or args.evaluate_during_training)
        and not args.overwrite_output_dir
):
    raise ValueError(
        "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
            args.eval_output_dir
        )
    )

import logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

# Select the compute device. With CUDA, ``args.device`` picks the GPU index
# and becomes the current device that ``torch.device("cuda")`` resolves to.
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.set_device(args.device)
    # Report the GPU that was actually selected, not always GPU 0.
    logging.info(f"Using GPU{torch.cuda.get_device_name(args.device)} to make predictions")
else:
    device = torch.device("cpu")
    logging.info(f"Using CPU to make predictions")


def train(train_data, model, tokenizer, eval_data=None):
    """Train the prefix parameters of ``model`` and checkpoint periodically.

    The parameters under ``model.transformer`` and ``model.lm_head`` are
    frozen. Each batch is scored twice: once with the prefix selected by the
    example's label and once with the opposite prefix (``1 - label``). The
    first score gives the generative loss; the pair gives a two-way
    discriminative cross-entropy loss. The two are mixed with
    ``args.gen_weight`` and ``args.disc_weight``.

    Args:
        train_data: dataset whose items are ``(input_ids, lm_labels, labels,
            attention_mask, token_type_ids)`` tensors, in that order.
        model: prefix LM; called with ``input_ids``, ``lm_labels`` and
            ``prefix_indexes``. Assumed to return an unreduced per-token loss
            (per the original author's note) -- TODO confirm.
        tokenizer: saved next to each checkpoint and passed to ``generate``.
        eval_data: unused here; evaluation during training regenerates from
            ``args.eval_data_file``.

    Returns:
        ``(global_step, mean training loss per step)``; the mean is ``0.0``
        when no step was executed.

    Raises:
        FloatingPointError: if the combined loss becomes NaN.
    """
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader)) + 1
    else:
        t_total = len(train_dataloader) * args.num_train_epochs

    # Only the prefix parameters are trained: freeze the base LM.
    for param in model.transformer.parameters():
        param.requires_grad = False
    for param in model.lm_head.parameters():
        param.requires_grad = False

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", len(train_data))
    logging.info("  Num Epochs = %d", args.num_train_epochs)
    logging.info("  Train batch size = %d", args.batch_size)
    logging.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    tr_loss, logging_loss = 0.0, 0.0
    tr_disc_loss, logging_disc_loss = 0.0, 0.0
    tr_gen_loss, logging_gen_loss = 0.0, 0.0
    # The criterion is stateless, so build it once instead of every step.
    loss_fn = torch.nn.CrossEntropyLoss()
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch",
    )

    for epoch_ in train_iterator:

        for step, batch in enumerate(train_dataloader):

            # generate() switches the model to eval mode, so re-enable
            # training mode on every step.
            model.train()
            batch = tuple(t.to(device) for t in batch)

            input_ids = batch[0]
            lm_labels = batch[1]
            off_labels = batch[2]
            attention_mask = batch[3]
            token_type_ids = batch[4]

            bsz = input_ids.shape[0]

            # want to compute LM loss here so feeding inputs as labels
            inputs_a = {"input_ids": input_ids,
                        "lm_labels": lm_labels,
                        'prefix_indexes': off_labels.unsqueeze(1)}
            inputs_b = {"input_ids": input_ids,
                        "lm_labels": lm_labels,
                        'prefix_indexes': 1 - off_labels.unsqueeze(1)}

            loss_agreeable = model(**inputs_a)  # modeling_gpt2.py modified to have none reduction
            loss_notagreeable = model(**inputs_b)

            # Zero out padded positions and normalise by the number of
            # attended tokens in each sequence.
            loss_mask = attention_mask[:, 1:].to(torch.float32).to(device)
            loss_lengths = torch.sum(loss_mask, 1, keepdim=True)

            loss_agreeable *= loss_mask
            loss_notagreeable *= loss_mask

            gen_loss = loss_agreeable / loss_lengths
            gen_loss = torch.sum(gen_loss) / bsz

            loss_agreeable = (loss_agreeable / loss_lengths).sum(dim=1)
            loss_notagreeable = (loss_notagreeable / loss_lengths).sum(dim=1)

            # Negative losses act as logits; index 1 (the matching prefix)
            # is always the target class.
            class_logits = torch.stack((-loss_notagreeable, -loss_agreeable), dim=1)  # (bsz, 2) dimensional

            ce_labels = input_ids.new_ones(bsz)
            disc_loss = loss_fn(class_logits, ce_labels)
            loss = args.gen_weight * gen_loss + args.disc_weight * disc_loss

            # Stop with a clear error instead of dropping into a debugger,
            # which stalls or aborts non-interactive jobs.
            if torch.isnan(loss).item():
                raise FloatingPointError(
                    "Training loss became NaN at global step {}".format(global_step + 1)
                )

            loss.backward()

            tr_loss += loss.item()
            tr_disc_loss += disc_loss.item()
            tr_gen_loss += gen_loss.item()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                if args.evaluate_during_training:
                    results = generate(
                        args.eval_data_file, args.control_indexes, model, tokenizer, device,
                        one_utt=args.one_utt)
                    os.makedirs(args.eval_output_dir, exist_ok=True)
                    # newline='' is what the csv module expects of files it writes.
                    with open(os.path.join(args.eval_output_dir, str(global_step) + '.csv'), 'w',
                              newline='') as f:
                        writer = csv.writer(f)
                        writer.writerow(['generation'])
                        for row in results:
                            writer.writerow([json.dumps(row)])

                loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                loss_disc_scalar = (tr_disc_loss - logging_disc_loss) / args.logging_steps
                loss_gen_scalar = (tr_gen_loss - logging_gen_loss) / args.logging_steps
                # get_last_lr() is the public accessor for the current rate.
                learning_rate_scalar = scheduler.get_last_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                logs['disc_loss'] = loss_disc_scalar
                logs['gen_loss'] = loss_gen_scalar
                logging_loss = tr_loss
                logging_disc_loss = tr_disc_loss
                logging_gen_loss = tr_gen_loss

                print(json.dumps({**logs, **{"step": global_step}}))

            if args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.train_output_dir, "checkpoint-{}".format(global_step))
                os.makedirs(output_dir, exist_ok=True)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logging.info("Saving model checkpoint to %s", output_dir)

            # Stop exactly at max_steps. The DataLoader has no close()
            # method, so simply leave the loop.
            if args.max_steps > 0 and global_step >= args.max_steps:
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    # Guard against an empty dataset, where no step was taken.
    return global_step, tr_loss / max(global_step, 1)


def load_stance_from_file(file_path_list, tokenizer, evaluate=False, one_utt=False):
    """Build a ``TensorDataset`` for the stance task from CSV files.

    Each CSV row must provide ``utt_list`` (a JSON list of utterances) and
    ``id``; training rows also need ``response`` and ``res_stance``.

    Args:
        file_path_list: iterable of CSV file paths.
        tokenizer: callable tokenizer exposing ``eos_token`` and
            ``eos_token_id``.
        evaluate: when True, only the dialogue context is encoded.
        one_utt: when True, only the last utterance is used as context;
            otherwise all utterances are joined with the EOS token.

    Returns:
        With ``evaluate``: ``TensorDataset(input_ids, attention_mask,
        token_type_ids)``. Otherwise ``TensorDataset(input_ids, label_ids,
        stance_labels, attention_mask, token_type_ids)``, where
        ``label_ids`` are the input ids shifted left by one with context
        positions set to ``-1``, and the stance label is 1 iff
        ``res_stance == 1``. Rows with an empty response are skipped in
        training mode.
    """
    if args.max_seq_length is None:
        # ``model_max_length`` is the current tokenizer attribute; fall back
        # to the legacy ``max_len`` for tokenizers that only offer that.
        max_length = getattr(tokenizer, 'model_max_length', None)
        if max_length is None:
            max_length = tokenizer.max_len
    else:
        max_length = args.max_seq_length

    eos = tokenizer.eos_token
    data = []
    for filepath in file_path_list:
        # newline='' lets the csv module handle line endings itself.
        with open(filepath, 'r', newline='') as f:
            for row in csv.DictReader(f):
                utt_list = [u.strip() for u in json.loads(row['utt_list'])]
                if not evaluate:
                    response = row['response'].strip()
                    if not response:
                        continue
                context = utt_list[-1] if one_utt else eos.join(utt_list)
                if evaluate:
                    data.append([context + eos, row['id']])
                else:
                    label = 1 if int(row['res_stance']) == 1 else 0
                    data.append([context + eos + '[GPT]: ' + response + eos,
                                 context + eos,
                                 label, row['id']])

    batch_encoding = tokenizer(
        [example[0] for example in data],
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
    )
    if evaluate:
        # Force the final position of every sequence to be the EOS token.
        eos_id = tokenizer.eos_token_id
        all_input_ids = torch.tensor(
            [ids[:-1] + [eos_id] for ids in batch_encoding['input_ids']], dtype=torch.long)
        all_attention_mask = torch.tensor(batch_encoding['attention_mask'], dtype=torch.long)
        all_token_type_ids = torch.tensor(batch_encoding['token_type_ids'], dtype=torch.long)
        return TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids)

    # TODO: encoding truncation is problematic. Correct this
    all_input_ids = torch.tensor(batch_encoding['input_ids'], dtype=torch.long)
    all_attention_mask = torch.tensor(batch_encoding['attention_mask'], dtype=torch.long)
    all_token_type_ids = torch.tensor(batch_encoding['token_type_ids'], dtype=torch.long)

    # Encode the context alone: its attention mask marks how many leading
    # positions are context. This assumes the context tokenizes to a prefix
    # of the full example -- TODO confirm.
    batch_context_encoding = tokenizer(
        [example[1] for example in data],
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
    )
    label_mask = torch.tensor(batch_context_encoding['attention_mask'], dtype=torch.long)
    pad_width = all_input_ids.size(1) - label_mask.size(1)
    label_mask = torch.cat(
        [label_mask, torch.zeros(label_mask.size(0), pad_width, dtype=torch.long)], dim=-1)

    # -1 on context positions, the token id everywhere else ...
    all_label_ids = label_mask * (-1) + (1 - label_mask) * all_input_ids
    # ... then shift left by one so position t holds the target for t + 1.
    all_label_ids = torch.cat(
        [all_label_ids[:, 1:], torch.full((all_input_ids.size(0), 1), -1, dtype=torch.long)], dim=-1)
    all_offense_labels = torch.tensor([example[2] for example in data], dtype=torch.long)

    return TensorDataset(all_input_ids, all_label_ids, all_offense_labels, all_attention_mask, all_token_type_ids)


def load_from_file(file_path_list, tokenizer, evaluate=False, one_utt=False, withstance=False):
    """Build a ``TensorDataset`` for the toxicity task from CSV files.

    Each CSV row must provide ``utt_list`` (a JSON list of utterances) and
    ``id``; training rows also need ``response`` and ``res_label``, plus
    ``res_stance`` and ``utt_label`` (a JSON list) when ``withstance`` is
    set.

    Args:
        file_path_list: iterable of CSV file paths.
        tokenizer: callable tokenizer exposing ``eos_token`` and
            ``eos_token_id``.
        evaluate: when True, only the dialogue context is encoded.
        one_utt: when True, only the last utterance is used as context;
            otherwise all utterances are joined with the EOS token.
        withstance: when True, the label is 1 if ``res_label == 1`` or if
            ``res_stance == 1`` while the last ``utt_label`` entry is 1,
            and 0 otherwise. When False the label is ``int(res_label)``.

    Returns:
        With ``evaluate``: ``TensorDataset(input_ids, attention_mask,
        token_type_ids)``. Otherwise ``TensorDataset(input_ids, label_ids,
        labels, attention_mask, token_type_ids)``, where ``label_ids`` are
        the input ids shifted left by one with context positions set to
        ``-1``. Rows with an empty response are skipped in training mode.
    """
    if args.max_seq_length is None:
        # ``model_max_length`` is the current tokenizer attribute; fall back
        # to the legacy ``max_len`` for tokenizers that only offer that.
        max_length = getattr(tokenizer, 'model_max_length', None)
        if max_length is None:
            max_length = tokenizer.max_len
    else:
        max_length = args.max_seq_length

    eos = tokenizer.eos_token
    data = []
    for filepath in file_path_list:
        # newline='' lets the csv module handle line endings itself.
        with open(filepath, 'r', newline='') as f:
            for row in csv.DictReader(f):
                utt_list = [u.strip() for u in json.loads(row['utt_list'])]
                if not evaluate:
                    response = row['response'].strip()
                    if not response:
                        continue
                context = utt_list[-1] if one_utt else eos.join(utt_list)
                if evaluate:
                    data.append([context + eos, row['id']])
                    continue
                if withstance:
                    # utt_label is only needed (and only read) in this mode.
                    last_offense = json.loads(row['utt_label'])[-1]
                    label = 0
                    if int(row['res_label']) == 1 or (int(row['res_stance']) == 1 and last_offense == 1):
                        label = 1
                else:
                    label = int(row['res_label'])
                data.append([context + eos + '[GPT]: ' + response + eos,
                             context + eos,
                             label, row['id']])

    batch_encoding = tokenizer(
        [example[0] for example in data],
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
    )
    if evaluate:
        # Force the final position of every sequence to be the EOS token.
        eos_id = tokenizer.eos_token_id
        all_input_ids = torch.tensor(
            [ids[:-1] + [eos_id] for ids in batch_encoding['input_ids']], dtype=torch.long)
        all_attention_mask = torch.tensor(batch_encoding['attention_mask'], dtype=torch.long)
        all_token_type_ids = torch.tensor(batch_encoding['token_type_ids'], dtype=torch.long)
        return TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids)

    # TODO: encoding truncation is problematic. Correct this
    all_input_ids = torch.tensor(batch_encoding['input_ids'], dtype=torch.long)
    all_attention_mask = torch.tensor(batch_encoding['attention_mask'], dtype=torch.long)
    all_token_type_ids = torch.tensor(batch_encoding['token_type_ids'], dtype=torch.long)

    # Encode the context alone: its attention mask marks how many leading
    # positions are context. This assumes the context tokenizes to a prefix
    # of the full example -- TODO confirm.
    batch_context_encoding = tokenizer(
        [example[1] for example in data],
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
    )
    label_mask = torch.tensor(batch_context_encoding['attention_mask'], dtype=torch.long)
    pad_width = all_input_ids.size(1) - label_mask.size(1)
    label_mask = torch.cat(
        [label_mask, torch.zeros(label_mask.size(0), pad_width, dtype=torch.long)], dim=-1)

    # -1 on context positions, the token id everywhere else ...
    all_label_ids = label_mask * (-1) + (1 - label_mask) * all_input_ids
    # ... then shift left by one so position t holds the target for t + 1.
    all_label_ids = torch.cat(
        [all_label_ids[:, 1:], torch.full((all_input_ids.size(0), 1), -1, dtype=torch.long)], dim=-1)
    all_offense_labels = torch.tensor([example[2] for example in data], dtype=torch.long)

    return TensorDataset(all_input_ids, all_label_ids, all_offense_labels, all_attention_mask, all_token_type_ids)


def init_sequence_length_for_generation(
            input_ids: torch.LongTensor, max_length: int
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Create the per-sequence bookkeeping tensors used while sampling.

    Returns ``(sequence_lengths, unfinished_sequences, cur_len)``: a vector
    filled with ``max_length``, a vector of ones (one flag per row of
    ``input_ids``), and the current length taken from the last dimension of
    ``input_ids``. Both vectors share dtype and device with ``input_ids``.
    """
    batch_size, cur_len = input_ids.shape[0], input_ids.shape[-1]
    sequence_lengths = input_ids.new_full((batch_size,), max_length)
    unfinished_sequences = input_ids.new_ones(batch_size)
    return sequence_lengths, unfinished_sequences, cur_len

def update_seq_length_for_generation(
        sequence_lengths: torch.LongTensor,
        unfinished_sequences: torch.LongTensor,
        cur_len: int,
        is_eos_in_next_token: torch.BoolTensor,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """Record the length of sequences that just emitted EOS.

    A sequence that was still unfinished and whose newest token is EOS gets
    ``cur_len`` as its length and is flagged finished (0) from now on. Other
    lengths are left untouched. Returns the updated
    ``(sequence_lengths, unfinished_sequences)``.
    """
    eos_hits = is_eos_in_next_token.long()

    # sequences that were still running and end on this step
    just_finished = (unfinished_sequences * eos_hits).bool()

    # stamp the final length on those rows only
    stamped = torch.full_like(sequence_lengths, cur_len)
    sequence_lengths = torch.where(just_finished, stamped, sequence_lengths)

    # a row stays unfinished only if it did not produce EOS
    still_running = unfinished_sequences * (~is_eos_in_next_token).long()
    return sequence_lengths, still_running

def generate(file_list, code_indexes, model, tokenizer, device, one_utt=False):
    """Sample responses from ``model`` for every dialogue in ``file_list``.

    Each CSV row must provide ``utt_list`` (a JSON list of utterances). The
    prompt is the context (the last utterance only when ``one_utt`` is set,
    otherwise all utterances joined with EOS), followed by EOS and the
    ``[GPT]:`` marker. ``args.num_samples`` continuations are sampled per
    prompt, using ``args.top_k``, ``args.top_p``, ``args.temperature`` and
    ``args.min_length``.

    Args:
        file_list: iterable of CSV file paths.
        code_indexes: list of prefix indexes passed to the model as
            ``prefix_indexes``, or None.
        model: prefix LM; ``model(...)[0]`` is read as the logits.
        tokenizer: tokenizer used for encoding and decoding; its
            ``pad_token_id`` fills finished sequences.
        device: torch device the inputs are moved to.
        one_utt: use only the last utterance as context.

    Returns:
        A list with one entry per input row, each a list of
        ``args.num_samples`` decoded strings.
    """
    all_input_generations = []
    code_num = 0
    if code_indexes is not None:
        code_num = len(code_indexes)
        code_indexes = torch.LongTensor([code_indexes]).to(device)

    eos = tokenizer.eos_token
    testing_data = []
    for infilepath in file_list:
        # newline='' lets the csv module handle line endings itself.
        with open(infilepath, 'r', newline='') as f:
            for row in csv.DictReader(f):
                utt_list = [u.strip() for u in json.loads(row['utt_list'])]
                context = utt_list[-1] if one_utt else eos.join(utt_list)
                testing_data.append(context + eos)
    model.eval()

    top_k = 50 if args.top_k is None else args.top_k
    top_p = 0.9 if args.top_p is None else args.top_p
    temperature = 1.0 if args.temperature is None else args.temperature
    min_length = 10 if args.min_length is None else args.min_length
    max_length = args.max_generation_length

    logits_warper = LogitsProcessorList()
    if top_k is not None and top_k != 0:
        logits_warper.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
    if top_p is not None and top_p < 1.0:
        logits_warper.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
    if temperature is not None and temperature != 1.0:
        logits_warper.append(TemperatureLogitsWarper(temperature))

    logits_processor = LogitsProcessorList()
    if min_length is not None and min_length > -1:
        logits_processor.append(MinLengthLogitsProcessor(min_length, tokenizer.eos_token_id))

    # The response marker is the same for every prompt: encode it once.
    added_input_ids = torch.LongTensor(tokenizer.encode('[GPT]:'))

    with torch.no_grad():
        for prompt_text in testing_data:
            input_ids = tokenizer.encode(prompt_text)
            # Over-long contexts keep their first tokens and are closed with EOS.
            if len(input_ids) > args.max_seq_length:
                input_ids = input_ids[:(args.max_seq_length - 1)] + [tokenizer.eos_token_id]
            input_ids = torch.cat([torch.tensor(input_ids, dtype=torch.long), added_input_ids], dim=-1)
            input_ids = input_ids.unsqueeze(0).to(device)
            # The mask also covers the prefix positions
            # (prefix_length per control code).
            attention_mask_input = input_ids.new_ones(
                input_ids.size(0), input_ids.size(1) + args.prefix_length * code_num)

            input_length = input_ids.size(-1)

            # Replicate the single prompt num_samples times along the batch axis.
            expanded_return_idx = (
                torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, args.num_samples).view(-1).to(
                    input_ids.device)
            )

            input_ids = input_ids.index_select(0, expanded_return_idx)
            attention_mask_input = attention_mask_input.index_select(0, expanded_return_idx)

            # init sequence length for generation
            sequence_lengths, unfinished_sequences, cur_len = init_sequence_length_for_generation(
                input_ids, max_length
            )

            for _ in range(max_length):
                outputs = model(input_ids=input_ids,
                                attention_mask=attention_mask_input,
                                prefix_indexes=code_indexes)
                next_token_logits = outputs[0][:, -1, :].detach().clone()

                scores = logits_processor(input_ids, next_token_logits)
                scores = logits_warper(input_ids, scores)

                probs = F.softmax(scores, dim=-1)

                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                # Finished sequences keep receiving the pad token.
                next_tokens = next_tokens * unfinished_sequences + (tokenizer.pad_token_id) * (
                            1 - unfinished_sequences)

                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
                cur_len = cur_len + 1

                sequence_lengths, unfinished_sequences = update_seq_length_for_generation(
                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == tokenizer.eos_token_id
                )
                if unfinished_sequences.max() == 0:
                    break
                attention_mask_input = torch.cat(
                    [attention_mask_input, attention_mask_input.new_ones((attention_mask_input.shape[0], 1))],
                    dim=-1
                )

            # Keep only the newly generated part of every sample.
            output_list = list(input_ids[:, input_length:].detach().cpu().numpy())
            generated_list = [
                tokenizer.decode(sample_ids,
                                 clean_up_tokenization_spaces=True,
                                 skip_special_tokens=True)
                for sample_ids in output_list
            ]
            all_input_generations.append(generated_list)
    return all_input_generations


def main():
    """Load the prefix model, then train and/or generate as requested.

    With ``--do_train`` the data for ``args.task`` is loaded, ``train`` is
    run and the final model, tokenizer and arguments are saved to
    ``args.train_output_dir``. With ``--do_eval`` the model is reloaded from
    ``args.train_output_dir`` and its generations for
    ``args.eval_data_file`` are written as one CSV in
    ``args.eval_output_dir``.

    Raises:
        ValueError: if training is requested with an unknown ``--task``.
    """
    logging.info(f"Loading DialoGPT-prefix model from {args.base_model_dir} ...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir)
    config = GPT2Config.from_pretrained(args.base_model_dir)
    config.prefix_length = args.prefix_length
    config.prefix_hidden_size = args.prefix_hidden_size
    config.prefix_num = args.prefix_num
    model = GPT2LMHeadPrefixModel.from_pretrained(args.base_model_dir, config=config)
    model.to(device)
    logging.info(f"Model loaded to device:{device}")
    logging.info(f"Training/evaluation parameters:\n%s", args)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    if args.do_train:
        # Validate explicitly: an assert would be stripped under ``python -O``.
        if args.task not in ('toxicity', 'stance'):
            raise ValueError(
                "Unknown task {!r}; expected 'toxicity' or 'stance'".format(args.task)
            )

        def load_split(paths, evaluate):
            # Dispatch to the loader that matches the selected task.
            if args.task == 'toxicity':
                return load_from_file(paths, tokenizer, evaluate=evaluate, one_utt=args.one_utt,
                                      withstance=args.withstance)
            return load_stance_from_file(paths, tokenizer, evaluate=evaluate, one_utt=args.one_utt)

        train_data = load_split(args.train_data_file, False)
        eval_data = load_split(args.eval_data_file, True) if args.evaluate_during_training else None
        train(train_data, model, tokenizer, eval_data)

        os.makedirs(args.train_output_dir, exist_ok=True)

        logging.info("Saving model checkpoint to %s", args.train_output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.train_output_dir)
        tokenizer.save_pretrained(args.train_output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.train_output_dir, "training_args.bin"))

    if args.do_eval:
        # NOTE(review): --model_dir is parsed but evaluation always reloads
        # from train_output_dir -- confirm this is intended.
        checkpoints = [args.train_output_dir]
        logging.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            # Name the output file after the step number when the path ends
            # in "-<digits>", otherwise after the last path component.
            global_step = checkpoint.split("-")[-1]
            if not global_step.isdigit():
                trimmed = checkpoint[:-1] if global_step[-1] == '/' else checkpoint
                global_step = trimmed.split('/')[-1] + '_end'

            model = GPT2LMHeadPrefixModel.from_pretrained(checkpoint)
            model.to(device)

            test_generations = generate(
                args.eval_data_file, args.control_indexes, model, tokenizer, device, one_utt=args.one_utt)
            os.makedirs(args.eval_output_dir, exist_ok=True)
            # newline='' is what the csv module expects of files it writes.
            with open(os.path.join(args.eval_output_dir, global_step + '.csv'), 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(['generation'])
                for row in test_generations:
                    writer.writerow([json.dumps(row)])



# Run the pipeline only when this file is executed as a script. Note that the
# command line is still parsed at module level, above, on any import.
if __name__ == "__main__":
    main()