import glob
import json
import pathlib
import traceback
import uuid
import os

import click
import numpy
import pandas
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar, ModelCheckpoint
from torch import nn

import wandb
from datasets import Dataset
from pytorch_lightning.loggers import WandbLogger
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, \
    DebertaForSequenceClassification, AutoConfig, BartForSequenceClassification
import torch
from dotenv import load_dotenv

from hyperparams import HYPERPARAMS
from models.CrossEncoderAdjusted import CrossEncoderAdjusted
from models.classification import FineTuningClassification, MeanFineTuningClassification, CLSFineTuningClassification
from utils.composition import compose_samples
from utils.dataset import FineTuningDataset
from utils.training import check_run_done

def compute_metrics(eval_preds):
    """Compute accuracy and macro-F1 from an ``(logits, labels)`` evaluation pair.

    Args:
        eval_preds: pair of ``(logits, labels)``. ``logits`` may itself be a
            tuple whose first element holds the class scores (the BART case).

    Returns:
        dict with ``"accuracy"`` and ``"f1-macro"`` entries.
    """
    logits, labels = eval_preds

    if isinstance(logits, tuple):
        # BART models return a tuple of outputs; the classification scores come first.
        logits = logits[0]

    # assumes logits are (n_samples, n_classes) -- argmax over the last axis picks the class.
    predictions = numpy.argmax(logits, axis=-1)

    return {
        "accuracy": accuracy_score(y_true=labels, y_pred=predictions),
        "f1-macro": f1_score(y_true=labels, y_pred=predictions, average="macro"),
    }

def truncate_sentence(sentence, truncation_length, tokenizer):
    """Cut ``sentence`` down to at most ``truncation_length`` tokenizer tokens.

    The text is encoded without special tokens, truncated, and decoded again, so
    the result is whatever ``tokenizer.decode`` produces (it may differ from the
    input in whitespace or normalization).

    Args:
        sentence: text to truncate; ``None`` is replaced by the literal ``"None"``.
        truncation_length: maximum number of tokens to keep.
        tokenizer: object exposing HF-style ``encode``/``decode``.

    Returns:
        The truncated text.
    """
    if sentence is None:
        # Keep a placeholder so missing texts still tokenize to something.
        sentence = "None"
    tokens = tokenizer.encode(
        text=sentence,
        max_length=truncation_length,
        truncation=True,
        add_special_tokens=False,
    )
    return tokenizer.decode(tokens)

def _read_samples(path):
    """Load one split file into a DataFrame sorted by its index.

    The file is first parsed as a single JSON document; if pandas rejects it
    (it raises ``ValueError``, e.g. on JSON Lines input) it is re-read with
    ``lines=True``.
    """
    try:
        samples = pandas.read_json(path)
    except ValueError:
        samples = pandas.read_json(path, lines=True)
    return samples.sort_index()


def _load_task_samples(task_id):
    """Load the train/dev/test splits and every extra ``test_<name>.jsonl`` split of a task.

    Returns:
        ``(train, dev, test, other_tests)`` where ``other_tests`` maps ``<name>``
        to its DataFrame.
    """
    task_dir = "../tasks/" + task_id
    train_samples = _read_samples(task_dir + "/train.jsonl")
    dev_samples = _read_samples(task_dir + "/dev.jsonl")
    test_samples = _read_samples(task_dir + "/test.jsonl")
    other_test_samples = {
        # "test_<name>.jsonl" -> "<name>"; pathlib keeps this independent of the path separator.
        pathlib.Path(file).stem[len("test_"):]: _read_samples(file)
        for file in glob.glob(task_dir + "/test_*.jsonl")
    }
    return train_samples, dev_samples, test_samples, other_test_samples


def _sort_by_text_length(samples):
    """Return ``samples`` ordered by the character length of their ``text`` column."""
    return samples.sort_values("text", key=lambda texts: texts.str.len())


def _attach_predictions(samples, predictions):
    """Store ``predictions`` in a ``pred`` column and drop the ``input_ids`` column if present."""
    samples["pred"] = predictions
    if "input_ids" in samples.columns:
        del samples["input_ids"]
    return samples


def _mark_run_done(strategy, hyperparameter):
    """Tag the active wandb run with its strategy, ``status=done`` and stringified hyperparameters."""
    wandb.config["strategy"] = strategy
    wandb.config["status"] = "done"
    wandb.config.update(
        {k: str(v) for k, v in hyperparameter.items()},
        allow_val_change=True,
    )


def _log_extra_test_set(task, test_set, samples, num_classes, hyperparameter):
    """Log predictions and metrics of one additional test split in its own wandb run."""
    wandb.init(id=str(uuid.uuid4()), project="new-" + task)

    wandb.log({
        test_set + "_test_predictions": wandb.Table(dataframe=samples)
    })

    # NOTE(review): with two classes this is sklearn's default binary F1
    # (positive label 1) logged under the "f1-macro" key -- confirm intended.
    average = "macro" if num_classes > 2 else "binary"
    wandb.log({
        "test/f1-macro": f1_score(samples["label"], samples["pred"], average=average),
        "test/accuracy": accuracy_score(samples["label"], samples["pred"]),
    })

    _mark_run_done(test_set, hyperparameter)
    wandb.finish()


@click.command()
@click.option('--task', type=str, default="x-stance-fr")
@click.option('--model_name', type=str, default="deepset/gbert-base")
@click.option('--fold', type=int, default=0)
@click.option('--setup', type=str, default="it")
@click.option('--pooling', type=str, default="cls")
@click.option('--seed', type=int, default=20)
@click.option('--batch_size', type=int, default=-1)
@click.option('--epochs', type=int, default=5)
@click.option('--dropout_rate', type=float, default=0.1)
@click.option('--learning_rate', type=float, default=0.00002)
@click.option('--dump', type=bool, default=False)
@click.option('--dump_path', type=str, default="/root/cache-7/cifs/gender/")
def main(task, model_name, fold, setup, pooling, seed, batch_size, epochs, dropout_rate, learning_rate, dump, dump_path):
    """Fine-tune MODEL_NAME on one fold of TASK and log the results to Weights & Biases.

    Per-task/model entries in HYPERPARAMS override --learning_rate and
    --batch_size. Runs that check_run_done reports as finished are skipped.
    The best checkpoint (by eval/f1-macro) predicts the main test split and
    every extra test_<name>.jsonl split; each extra split is logged as its own
    wandb run. With --dump True the best checkpoint is copied to
    <dump_path>/<run_id>/dump.ckpt. The local checkpoint directory is removed
    at the end.
    """
    import shutil  # only this entry point copies/removes checkpoint files

    if task in HYPERPARAMS and model_name in HYPERPARAMS[task]:
        learning_rate = HYPERPARAMS[task][model_name]["learning_rate"]
        batch_size = HYPERPARAMS[task][model_name]["batch_size"]

    if batch_size == -1:
        batch_size = 16

    load_dotenv()

    task_id = task + "-" + setup + "-fold-" + str(fold)
    # Few-shot task names carry the underlying task after the last "@".
    base_task = task.split("@")[-1] if "few-shot" in task else task

    mode = os.getenv('MODE')
    accelerator = "cuda" if mode == "prod" else "cpu"
    training = "FINE_TUNING"

    train_samples, dev_samples, test_samples, other_test_samples = _load_task_samples(task_id)

    if "text" in dev_samples.columns:
        # Length-sorted evaluation batches need less padding.
        dev_samples = _sort_by_text_length(dev_samples)
        test_samples = _sort_by_text_length(test_samples)
        other_test_samples = {
            test_set: _sort_by_text_length(samples)
            for test_set, samples in other_test_samples.items()
        }

    num_classes = len(train_samples["label"].unique())

    hyperparameter = {
        "mode": mode,
        "model_name": model_name,
        "pooling": pooling,
        "fold": fold,
        "setup": setup,
        "training": training,
        "batch_size": batch_size,
        "dropout_rate": dropout_rate,
        "learning_rate": learning_rate,
        "seed": seed,
    }
    # Derived from the configured batch size (before any gradient-accumulation split below).
    hyperparameter["warmup_steps"] = int(train_samples.shape[0] * epochs / batch_size * 0.1)
    hyperparameter["training_steps"] = int(train_samples.shape[0] * epochs / batch_size)

    if check_run_done(task, hyperparameter):
        print("Run already done")
        return

    if pooling == "mean":
        model = MeanFineTuningClassification(hyperparameter=hyperparameter, num_classes=num_classes)
    elif pooling == "cls":
        model = CLSFineTuningClassification(hyperparameter=hyperparameter, num_classes=num_classes)
    else:
        # Fail fast instead of reaching trainer.fit without a model.
        raise click.BadParameter(
            f"unsupported pooling {pooling!r}; expected 'cls' or 'mean'",
            param_hint="--pooling",
        )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def shorten_text(samples):
        samples["text"] = samples["text"].apply(lambda sentence: truncate_sentence(sentence, 300, tokenizer))
        return samples

    if "review" in task or "sentiment" in task or "germeval" in task:
        train_samples = shorten_text(train_samples)
        dev_samples = shorten_text(dev_samples)
        test_samples = shorten_text(test_samples)
        other_test_samples = {
            test_set: shorten_text(samples)
            for test_set, samples in other_test_samples.items()
        }

    def tokenize_function(samples):
        composed_samples = compose_samples(samples, task=base_task, sep_token=tokenizer.sep_token)
        samples["input_ids"] = [
            tokenizer.encode(composed_sample, truncation=True)
            for composed_sample in composed_samples
        ]
        return samples

    train_samples = tokenize_function(train_samples)
    dev_samples = tokenize_function(dev_samples)
    test_samples = tokenize_function(test_samples)
    other_test_samples = {
        test_set: tokenize_function(samples)
        for test_set, samples in other_test_samples.items()
    }

    train_dataset = FineTuningDataset(train_samples)
    dev_dataset = FineTuningDataset(dev_samples)
    test_dataset = FineTuningDataset(test_samples)
    other_test_datasets = {
        test_set: FineTuningDataset(samples)
        for test_set, samples in other_test_samples.items()
    }

    run_id = str(uuid.uuid4())
    checkpoint_dir = "./" + run_id + "-checkpoints"

    wandb_logger = WandbLogger(project="new-" + task, id=run_id)

    if any(name in task for name in ("xnli", "x-sentiment", "germeval")):
        # Split each batch into four micro-batches and accumulate gradients so the
        # effective batch size stays roughly the same; never go below one sample.
        batch_size = max(1, batch_size // 4)
        accumulate_grad_batches = 4
    else:
        accumulate_grad_batches = 1

    trainer = Trainer(
        max_epochs=epochs, gradient_clip_val=1.0, logger=wandb_logger, accelerator=accelerator, num_sanity_val_steps=0, accumulate_grad_batches=accumulate_grad_batches,
        callbacks=[RichProgressBar(), ModelCheckpoint(monitor="eval/f1-macro", mode="max", dirpath=checkpoint_dir)]
    )

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, max_length=512, padding="longest")

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
    dev_dataloader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
    other_test_dataloader = {
        test_set: DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
        for test_set, dataset in other_test_datasets.items()
    }

    try:
        trainer.fit(model=model, train_dataloaders=[train_dataloader], val_dataloaders=[dev_dataloader])
        trainer.test(ckpt_path="best", dataloaders=[test_dataloader])
        test_prediction = trainer.predict(ckpt_path="best", dataloaders=[test_dataloader], return_predictions=True)

        other_test_predictions = {
            test_set: trainer.predict(ckpt_path="best", dataloaders=[dataloader], return_predictions=True)
            for test_set, dataloader in other_test_dataloader.items()
        }

    except Exception as e:
        # Top-level boundary: report, close the wandb run and drop partial checkpoints.
        print(e)
        traceback.print_exc()
        wandb.finish()
        shutil.rmtree(checkpoint_dir, ignore_errors=True)
        return

    test_prediction = torch.concat(test_prediction).numpy()
    other_test_predictions = {
        test_set: torch.concat(predictions).numpy()
        for test_set, predictions in other_test_predictions.items()
    }

    test_samples = _attach_predictions(test_samples, test_prediction)
    other_test_samples = {
        test_set: _attach_predictions(other_test_samples[test_set], predictions)
        for test_set, predictions in other_test_predictions.items()
    }

    wandb.log({
        "test_predictions": wandb.Table(dataframe=test_samples)
    })
    _mark_run_done("full", hyperparameter)
    wandb.finish()

    for test_set, samples in other_test_samples.items():
        _log_extra_test_set(task, test_set, samples, num_classes, hyperparameter)

    best_model_path = trainer.checkpoint_callback.best_model_path
    if dump and best_model_path:
        run_dump_path = f"{dump_path}/{run_id}/"
        print(f"--- ---dump to {run_dump_path}", best_model_path)
        os.makedirs(run_dump_path, exist_ok=True)
        shutil.copy(best_model_path, os.path.join(run_dump_path, "dump.ckpt"))

    shutil.rmtree(checkpoint_dir, ignore_errors=True)



if __name__ == "__main__":
    # Script entry point: the click command parses the CLI options and runs main.
    main()