import os
import random
import argparse
from collections import Counter
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer
import transformers
import torch
from transformers import AutoModelForSeq2SeqLM
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType
import torch
from datasets import load_dataset
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import default_data_collator, get_linear_schedule_with_warmup
from tqdm import tqdm
from datasets import load_dataset

# Optional Hugging Face Hub login. Access tokens must never be hard-coded in
# source: read them from the environment instead. Any token that was previously
# pasted here should be treated as leaked and revoked.
# from huggingface_hub import login
# login(token=os.environ["HF_TOKEN"])


import sys
class Unbuffered(object):
    """Stream wrapper that flushes the wrapped stream after every write.

    Any attribute that is not defined on the wrapper itself is looked up on
    the wrapped stream, so an instance can stand in for the stream it wraps.
    """

    def __init__(self, stream):
        """Wrap *stream*; it must provide ``write``, ``writelines`` and ``flush``."""
        self.stream = stream

    def write(self, data):
        """Write *data*, flush, and return whatever the wrapped ``write`` returned."""
        written = self.stream.write(data)
        self.stream.flush()
        return written

    def writelines(self, datas):
        """Write every item of *datas* to the wrapped stream, then flush."""
        self.stream.writelines(datas)
        self.stream.flush()

    def __getattr__(self, attr):
        """Delegate attributes missing on the wrapper to the wrapped stream."""
        # __getattr__ only runs when normal lookup fails. If ``stream`` itself
        # is missing (instance created without __init__, e.g. by copy/pickle),
        # reading self.stream below would re-enter this method forever.
        if attr == "stream":
            raise AttributeError(attr)
        return getattr(self.stream, attr)
sys.stdout = Unbuffered(sys.stdout)  # flush stdout after every write


def _str2bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``type=bool`` must not be used with argparse: every non-empty string is
    truthy, so ``--flag False`` would be parsed as ``True``.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface. Arguments are parsed at import time and exposed as
# the module-level ``args`` namespace.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='t5_small', help='model name')
parser.add_argument('--load_model_from_disk', type=_str2bool, default=False,
                    help='boolean flag; accepts true/false, yes/no, 1/0')
parser.add_argument('--dataset', type=str, default='scifact_oracle', help='dataset name')
parser.add_argument('--dataset_dir', type=str, help='directory')
parser.add_argument('--direction', type=str, default='ce', help='direction of the model')
parser.add_argument('--output_dir', type=str, default=".")
parser.add_argument('--num_epochs', type=int, default=20)

args = parser.parse_args()

def read_data(dataset, data_dir="data"):
    """Load the train and test CSV splits of *dataset* as plain Python lists.

    Reads ``<data_dir>/<dataset>/train.csv`` and ``<data_dir>/<dataset>/test.csv``.
    Each file must contain ``index``, ``claim``, ``evidence`` and ``label``
    columns; ``index`` is used as the DataFrame index.

    Args:
        dataset: Name of the sub-directory holding the two CSV files.
        data_dir: Root directory containing the dataset folders. Defaults to
            ``"data"``, the location that used to be hard-coded.

    Returns:
        A 6-tuple of lists: train claims, train evidence, train labels,
        test claims, test evidence, test labels.
    """
    columns = ("claim", "evidence", "label")
    fields = []
    for split in ("train", "test"):
        frame = pd.read_csv(
            os.path.join(data_dir, dataset, f"{split}.csv"), index_col='index'
        )
        fields.extend(frame[column].tolist() for column in columns)
    return tuple(fields)

def sample_t(labels_train, t=10, seed = 123):
    """Pick *t* example indexes per class from *labels_train*.

    The three classes are ``SUPPORTS``, ``NOT_ENOUGH_INFO`` (also spelled
    ``NOT ENOUGH INFO``) and ``REFUTES``; labels with any other value are
    ignored.

    Args:
        labels_train: Sequence of label strings.
        t: Number of indexes to draw, without replacement, from each class.
        seed: Seed applied to the module-level ``random`` generator before
            sampling, which makes the selection reproducible.

    Returns:
        A list of ``3 * t`` indexes into *labels_train*: the SUPPORTS picks
        first, then NOT_ENOUGH_INFO, then REFUTES.

    Raises:
        ValueError: if any class has fewer than *t* examples.
    """
    random.seed(seed)

    supports, not_enough_info, refutes = [], [], []
    for position, label in enumerate(labels_train):
        if label == 'SUPPORTS':
            supports.append(position)
        elif label in ('NOT_ENOUGH_INFO', 'NOT ENOUGH INFO'):
            not_enough_info.append(position)
        elif label == 'REFUTES':
            refutes.append(position)

    pools = (
        ('SUPPORTS', supports),
        ('NOT_ENOUGH_INFO', not_enough_info),
        ('REFUTES', refutes),
    )

    # Fail with a message that names the short class instead of the generic
    # "Sample larger than population" error raised by random.sample.
    for name, pool in pools:
        if t > len(pool):
            raise ValueError(
                f"cannot sample {t} examples of class {name!r}: "
                f"only {len(pool)} available"
            )

    all_indexes = []
    for _, pool in pools:
        all_indexes.extend(random.sample(pool, t))
    return all_indexes

def generate_response(prompt_input):
    """Sample one short completion for *prompt_input* and return it cleaned up.

    Relies on two module-level globals: ``pipeline``, which is called with the
    prompt plus the sampling arguments below and must return a list of dicts
    with a ``generated_text`` key, and ``tokenizer``, which supplies
    ``eos_token_id``.

    NOTE(review): ``pipeline`` is not defined in the visible part of this
    file, so calling this function as-is raises ``NameError`` -- confirm where
    it is meant to come from.

    Args:
        prompt_input: Prompt passed to ``pipeline``.

    Returns:
        The first generated text with surrounding whitespace, an optional
        leading ``label:`` prefix, and all newlines removed.
    """
    # TODO: add batch size for inference
    sequences = pipeline(
        prompt_input,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        return_full_text=False,
        max_new_tokens=10,
    )
    text = sequences[0]['generated_text'].strip()
    # str.strip("label: ") treats its argument as a *set of characters* and
    # would also eat leading/trailing letters of the answer itself, so the
    # prefix is removed explicitly instead.
    prefix = "label:"
    if text.startswith(prefix):
        text = text[len(prefix):].strip()
    return text.replace('\n', '')


if __name__ == '__main__':
    # Script entry point: LoRA fine-tuning of a seq2seq model that maps one
    # text column of the dataset to the other. After every epoch the
    # teacher-forced predictions for both splits are written to JSON files.

    # Without this guard the dataset path built below would silently become
    # "None/<dataset>".
    if args.dataset_dir is None:
        parser.error("--dataset_dir is required")

    use_cuda = torch.cuda.is_available()
    print(use_cuda)
    device = "cuda" if use_cuda else "cpu"

    # Underscores in the CLI model name are turned into hyphens before the
    # name is used as the checkpoint identifier (e.g. "t5_small" -> "t5-small").
    model_name_or_path = args.model.replace("_", "-")
    tokenizer_name_or_path = model_name_or_path

    # "ec": evidence is the input and claim the target; any other value:
    # claim is the input and evidence the target.
    if args.direction == 'ec':
        text_column, label_column = "evidence", "claim"
    else:
        text_column, label_column = "claim", "evidence"

    max_length = 512
    lr = 1e-4
    num_epochs = args.num_epochs
    train_batch_size = 16
    eval_batch_size = 128

    # creating model
    peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # loading dataset
    data_files = {"train": "train.csv", "test": "test.csv"}
    dataset = load_dataset(f"{args.dataset_dir}/{args.dataset}", data_files=data_files)
    print(dataset)
    print(dataset["train"][0])

    # data preprocessing
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

    def preprocess_function(examples):
        """Tokenize one batch of examples into padded model inputs and labels."""
        inputs = examples[text_column]
        targets = examples[label_column]
        model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
        labels = tokenizer(targets, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
        labels = labels["input_ids"]
        # -100 marks padded target positions; they are masked out again in
        # decode_predictions() before the predictions are turned into text.
        labels[labels == tokenizer.pad_token_id] = -100
        model_inputs["labels"] = labels
        return model_inputs

    def decode_predictions(logits, labels):
        """Turn per-position argmax predictions into text, skipping padded targets.

        Positions whose label is -100 are replaced by the pad token before
        decoding so that predictions made at padded target positions do not
        end up in the saved text.
        """
        token_ids = torch.argmax(logits, -1)
        token_ids = token_ids.masked_fill(labels == -100, tokenizer.pad_token_id)
        return tokenizer.batch_decode(token_ids.detach().cpu().numpy(), skip_special_tokens=True)

    processed_datasets = dataset.map(
        preprocess_function,
        batched=True,
        num_proc=1,
        remove_columns=dataset["train"].column_names,
        load_from_cache_file=False,
        desc="Running tokenizer on dataset",
    )

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["test"]

    # Pinned host memory is only requested when CUDA is available.
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=train_batch_size, pin_memory=use_cuda
    )
    eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=eval_batch_size, pin_memory=use_cuda)

    # optimizer and lr scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(train_dataloader) * num_epochs),
    )

    # The output location does not depend on the epoch, so create it once up
    # front; an unwritable path then fails before any training time is spent.
    file_output_dir = "{}/{}/lora_{}_{}".format(args.output_dir, args.direction, args.model, args.dataset)
    os.makedirs(file_output_dir, exist_ok=True)

    # training and evaluation
    model = model.to(device)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        train_preds = []
        for batch in tqdm(train_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            train_preds.extend(decode_predictions(outputs.logits, batch["labels"]))
            loss = outputs.loss
            total_loss += loss.detach().float()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        model.eval()
        eval_loss = 0
        eval_preds = []
        for batch in tqdm(eval_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
            loss = outputs.loss
            eval_loss += loss.detach().float()
            eval_preds.extend(decode_predictions(outputs.logits, batch["labels"]))

        eval_epoch_loss = eval_loss / len(eval_dataloader)
        eval_ppl = torch.exp(eval_epoch_loss)
        train_epoch_loss = total_loss / len(train_dataloader)
        train_ppl = torch.exp(train_epoch_loss)
        print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

        # Save the decoded text as JSON records, one object per example with
        # the single key "generated_text".
        train_preds_df = pd.DataFrame(train_preds, columns=['generated_text'])
        print(train_preds_df)
        eval_preds_df = pd.DataFrame(eval_preds, columns=['generated_text'])
        print(eval_preds_df)

        # NOTE(review): the "val_split" file holds predictions for the
        # *training* split, in the shuffled order produced by the training
        # DataLoader -- confirm that downstream consumers expect this.
        train_preds_df.to_json(f"{file_output_dir}/epoch_{epoch}_val_split_predictions.json", orient='records')
        eval_preds_df.to_json(f"{file_output_dir}/epoch_{epoch}_test_split_predictions.json", orient='records')

        # TODO: the saved text comes from teacher-forced argmax, not from free
        # generation; consider switching to model.generate(...).