import argparse
import os
import random
import datasets
import numpy as np
import torch
from typing import List, Dict
from datasets import load_dataset
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, GPT2Tokenizer
from transformers import get_scheduler


def parse_args():
    """Declare the command-line options of the training script and parse them.

    Returns:
        argparse.Namespace: the parsed options (model and data paths,
        optimisation hyper-parameters, logging/checkpoint intervals and the
        distributed-training / remote-debugging settings).
    """
    data_dir = os.path.join(os.getcwd(), "data")

    # One (flag, add_argument keyword arguments) entry per option, listed in
    # the order in which the options show up in ``--help``.
    option_specs = (
        ("--path_to_model", {
            "default": "ai-forever/FRED-T5-large",
            "type": str,
            "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
        }),
        ("--path_to_train", {
            "default": os.path.join(data_dir, "train_udata_wiki.csv"),
            "type": str,
            "help": "The input training data file (a text file).",
        }),
        ("--path_to_news_dev", {
            "default": os.path.join(data_dir, "synthetic_dev_v2.csv"),
            "type": str,
            "help": "The input synthetic dev file (a text file).",
        }),
        ("--path_to_wiki_dev", {
            "default": os.path.join(data_dir, "synthetic_wiki_dev.csv"),
            "type": str,
            "help": "The input wiki dev data file (a text file).",
        }),
        ("--path_to_assembled_train_short", {
            "default": os.path.join(data_dir, "train_dataset_fredt5_v2"),
            "type": str,
            "help": "The input udata wiki train ds data file (a text file).",
        }),
        ("--path_to_checkpoints", {
            "default": os.path.join(os.getcwd(), 'pretrained_checkpoints_fredt5_v2'),
            "type": str,
            "help": "Path to save checkpoints.",
        }),
        # Batching, optimisation and reporting hyper-parameters.
        ("--batch_size", {"type": int, "default": 8}),
        ("--train_shuffle", {"action": "store_true"}),
        ("--no_cuda", {"action": "store_true"}),
        ("--num_training_epochs", {"type": int, "default": 5}),
        ("--learning_rate", {"type": float, "default": 3e-04}),
        ("--weight_decay", {"type": float, "default": 0.001}),
        ("--scheduler_type", {"type": str, "default": "linear"}),
        ("--num_warmup_steps", {"type": int, "default": 0}),
        ("--track_every_num_steps", {"type": int, "default": 5_000}),
        ("--num_iter_to_accumulate", {"type": int, "default": 2}),
        ("--dev_batch_size", {"type": int, "default": 16}),
        ("--evaluation_num_steps", {"type": int, "default": 35_000}),
        ("--task_prefix", {"type": str, "default": "Исправь: "}),
        # Reproducibility, distributed training and remote debugging.
        ("--seed", {"type": int, "default": 1234}),
        ("--local_rank", {"type": int, "default": -1, "help": "For distributed training: local_rank"}),
        ("--server_ip", {"type": str, "default": "", "help": "For distant debugging."}),
        ("--server_port", {"type": str, "default": "", "help": "For distant debugging."}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def set_seed(args):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        args: object exposing ``seed`` (the value passed to every generator)
            and ``n_gpu`` (the CUDA generators are seeded only when it is
            greater than zero).
    """
    seed_value = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if args.n_gpu > 0:
        # Seeds the generator of every visible CUDA device.
        torch.cuda.manual_seed_all(seed_value)


def _merge(src, crr):
    """Pair source records with corrected records in one two-column dataset.

    Args:
        src: iterable of records whose ``'text'`` entry is the source text.
        crr: iterable of records whose ``'text'`` entry is the correction.

    Returns:
        datasets.Dataset with the columns ``'text'`` and ``'correction'``.
        Pairing is done with ``zip``, so it stops at the shorter input.
    """
    pairs = [(source['text'], fixed['text']) for source, fixed in zip(src, crr)]
    columns = {
        'text': [source_text for source_text, _ in pairs],
        'correction': [fixed_text for _, fixed_text in pairs],
    }
    return datasets.Dataset.from_dict(columns)

def _preprocess(samples, tokenizer, task_prefix, **kwargs):
    texts = [task_prefix + sample.strip('\ufeff') for sample in samples["text"]]
    corrections = [sample.strip('\ufeff') for sample in samples["correction"]]
    
    text_encodings = tokenizer(texts, **kwargs, return_tensors="pt")
    corr_encodings = tokenizer(text_target=corrections, **kwargs, return_tensors="pt")
    
    # replace 0 with -100
    labels = corr_encodings['input_ids']
    corr_encodings['input_ids'] = torch.where(
        labels == tokenizer.pad_token_id, torch.full(labels.shape, -100), labels)
    return {
        'input_ids': text_encodings['input_ids'],
        'attention_mask': text_encodings['attention_mask'],
        'labels': corr_encodings['input_ids'],
        # "decoder_attention_mask": corr_encodings["attention_mask"] 
    }


class T2TDataCollator:
    """Collate variable-length text-to-text samples into one padded batch.

    Every sample is a mapping with 1-D tensors under ``"input_ids"``,
    ``"attention_mask"`` and ``"labels"``. Each field is right-padded to the
    longest sequence of that field within the batch.

    Args:
        pad_token_id: fill value for ``input_ids`` padding. Defaults to 0,
            the value that used to be hard-coded.
        label_pad_token_id: fill value for ``labels`` padding. Defaults to
            -100, which is the default ``ignore_index`` of
            ``torch.nn.CrossEntropyLoss``.
    """

    def __init__(self, pad_token_id: int = 0, label_pad_token_id: int = -100):
        self.pad_token_id = pad_token_id
        self.label_pad_token_id = label_pad_token_id

    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
        """
        Take a list of samples from a Dataset and collate them into a batch.

        Args:
            batch: non-empty list of sample mappings (see the class docstring).

        Returns: A dictionary of tensors with the keys ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``, each of shape
            (len(batch), longest sequence of that field). Dtypes (and devices)
            of the input tensors are preserved.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            raise ValueError("T2TDataCollator received an empty batch")

        # pad_sequence allocates the output like the input tensors, so it keeps
        # their dtype and device (the former torch.full padding was CPU-only).
        pad = torch.nn.utils.rnn.pad_sequence
        # (field name, padding value); the attention mask is padded with 0 so
        # that padded positions are masked out.
        fields = (
            ("input_ids", self.pad_token_id),
            ("attention_mask", 0),
            ("labels", self.label_pad_token_id),
        )
        return {
            name: pad([sample[name] for sample in batch], batch_first=True, padding_value=fill)
            for name, fill in fields
        }



def setup_dataset(args):
    """Load, tokenize and wrap the train / news-dev / wiki-dev splits.

    The three CSV files named in ``args`` are loaded with ``load_dataset``.
    The tokenized training split is read from
    ``args.path_to_assembled_train_short`` when that path exists; otherwise it
    is built with ``_preprocess`` and, on ``local_rank`` -1 or 0, saved there.
    The two dev splits are always tokenized from scratch.

    Args:
        args: parsed options; additionally expected to carry ``tokenizer``.

    Returns:
        tuple: ``(train_dataloader, news_dev_dataloader, wiki_dev_dataloader)``.
        The training loader draws from a ``RandomSampler`` (or a shuffling
        ``DistributedSampler`` when ``args.local_rank != -1``) and pads each
        batch with ``T2TDataCollator``; the dev loaders iterate in order with
        the default collation.
    """
    split_names = ("train", "news_dev", "wiki_dev")
    raw_train, raw_news_dev, raw_wiki_dev = load_dataset(
        "csv",
        data_files={
            "train": [args.path_to_train],
            "news_dev": args.path_to_news_dev,
            "wiki_dev": args.path_to_wiki_dev
        },
        split=tuple(datasets.ReadInstruction(name) for name in split_names),
    )

    def tokenize(samples):
        # Tokenizer settings shared by every split.
        return _preprocess(
            samples, args.tokenizer, args.task_prefix,
            max_length=512, padding=True, truncation=True,
        )

    def tokenize_split(split, batch_size, drop_last_batch):
        # Replace the raw text columns by the tokenized model inputs.
        return split.map(
            tokenize,
            batched=True,
            batch_size=batch_size,
            remove_columns=['text', 'correction'],
            drop_last_batch=drop_last_batch,
        )

    assembled_path = args.path_to_assembled_train_short
    if os.path.exists(assembled_path):
        train = datasets.load_from_disk(assembled_path)
    else:
        train = tokenize_split(raw_train, args.batch_size, True)
        if args.local_rank in [-1, 0]:
            train.save_to_disk(assembled_path)

    news_dev = tokenize_split(raw_news_dev, args.dev_batch_size, False)
    wiki_dev = tokenize_split(raw_wiki_dev, args.dev_batch_size, False)

    for split in (train, news_dev, wiki_dev):
        split.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    if args.local_rank == -1:
        train_sampler = RandomSampler(train)
    else:
        train_sampler = DistributedSampler(train, shuffle=True)

    train_dataloader = DataLoader(
        train,
        batch_size=args.batch_size,
        sampler=train_sampler,
        collate_fn=T2TDataCollator(),
        drop_last=True,
    )

    def make_dev_loader(split):
        # Dev loaders reuse the batch size that was used for tokenization and
        # keep the original sample order.
        return DataLoader(
            split, batch_size=args.dev_batch_size,
            shuffle=False,
            drop_last=False,
        )

    news_dev_dataloader = make_dev_loader(news_dev)
    wiki_dev_dataloader = make_dev_loader(wiki_dev)

    return train_dataloader, news_dev_dataloader, wiki_dev_dataloader


def setup():
    """Parse the CLI options and build everything the training loop needs.

    Steps, in order: optional ptvsd remote-debugger attach, device and
    distributed initialisation, RNG seeding, model and tokenizer loading,
    dataloader construction, AdamW optimizer with two weight-decay groups,
    learning-rate scheduler, gradient scaler, and finally the
    DataParallel / DistributedDataParallel wrapping of the model.

    Side effects on ``args``: sets ``n_gpu``, ``device``, ``tokenizer`` and
    ``training_steps`` (the total number of optimizer updates).

    Returns:
        tuple: ``(args, train_dataloader, news_dev_dataloader,
        wiki_dev_dataloader, model, tokenizer, optimizer, scheduler, scaler)``.
    """
    args = parse_args()
    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # With --no_cuda no GPU may be used: reporting the physical device
        # count here would later wrap the CPU model in DataParallel.
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device
    set_seed(args)

    model = T5ForConditionalGeneration.from_pretrained(args.path_to_model)
    model = model.to(args.device)
    args.tokenizer = GPT2Tokenizer.from_pretrained(args.path_to_model, eos_token='</s>')

    train_dataloader, news_dev_dataloader, wiki_dev_dataloader = setup_dataset(args)

    # The training loop performs an optimizer update every
    # ``num_iter_to_accumulate`` batches AND on the last batch of each epoch,
    # i.e. ceil(batches / accumulate) updates per epoch. Flooring here would
    # end the LR schedule (and the progress bar) too early whenever the number
    # of batches is not a multiple of the accumulation factor.
    batches_per_epoch = len(train_dataloader)
    updates_per_epoch = (
        (batches_per_epoch + args.num_iter_to_accumulate - 1) // args.num_iter_to_accumulate
    )
    args.training_steps = updates_per_epoch * args.num_training_epochs

    # Parameters whose name contains one of these fragments get no weight decay.
    no_decay = ['bias', "layer_norm.weight"]
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(fragment in name for fragment in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {'params': no_decay_params, 'weight_decay': 0.},
        {'params': decay_params, 'weight_decay': args.weight_decay},
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    optimizer.zero_grad()
    scheduler = get_scheduler(
        args.scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.training_steps,
    )

    scaler = torch.cuda.amp.GradScaler()

    # Single-process multi-GPU training; the optimizer above already holds
    # the parameters, which the wrapper shares with the underlying model.
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training: one process per GPU.
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )
    return (
        args,
        train_dataloader,
        news_dev_dataloader,
        wiki_dev_dataloader,
        model,
        args.tokenizer,
        optimizer,
        scheduler,
        scaler
    )


def main():
    """Run the fine-tuning loop with gradient accumulation and checkpointing.

    For every epoch the model is trained on ``train_dataloader`` under bfloat16
    autocast. Gradients are accumulated over ``args.num_iter_to_accumulate``
    batches (and flushed on the last batch of the epoch), clipped to a norm of
    1.0 and applied, after which the scheduler advances by one step.

    On the main process (``local_rank`` -1 or 0):
      * every ``args.track_every_num_steps`` batches the running loss is printed;
      * every ``args.evaluation_num_steps`` batches and at the end of each
        epoch the model, optimizer and scheduler states are written to
        ``args.path_to_checkpoints/model-epoch_<epoch>-step_<step>``.
    """
    (args, train_dataloader,
     news_dev_dataloader,   # built by setup() but not used by this loop
     wiki_dev_dataloader,   # built by setup() but not used by this loop
     model,
     tokenizer,
     optimizer,
     scheduler,
     scaler
     ) = setup()

    # local_rank is -1 for a non-distributed run and 0 for the first process
    # of a distributed one; only that process logs and writes checkpoints.
    # (Checking ``== 0`` alone silently disabled both for non-distributed runs.)
    is_main_process = args.local_rank in [-1, 0]
    last_step = len(train_dataloader) - 1

    progress_bar = tqdm(range(args.training_steps))
    for epoch in range(args.num_training_epochs):
        # DistributedSampler derives its shuffle from the epoch number;
        # without set_epoch every epoch replays the same sample order.
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)

        model.train()
        running_loss = 0.
        for step, batch in enumerate(train_dataloader):
            batch = {key: value.to(args.device) for key, value in batch.items()}

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                outputs = model(**batch)
                # .mean() is a no-op for a scalar loss; presumably kept for the
                # DataParallel case, where one loss per replica is returned.
                loss = outputs.loss.mean()
                # Scale so that the accumulated gradient is an average.
                loss = loss / args.num_iter_to_accumulate
            scaler.scale(loss).backward()

            running_loss += loss.item()

            if (step + 1) % args.num_iter_to_accumulate == 0 or step == last_step:
                # Unscale before clipping so the threshold applies to the
                # true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)

            if (step + 1) % args.track_every_num_steps == 0:
                if is_main_process:
                    print()
                    print("Iteration {}/{} of epoch {} complete. Loss : {} "
                        .format(step + 1, len(train_dataloader), epoch + 1, running_loss / args.track_every_num_steps))

                running_loss = 0.0

            if (step + 1) % args.evaluation_num_steps == 0 or step == last_step:
                if is_main_process:
                    current_folder = 'model-epoch_{}-step_{}'.format(epoch, step)
                    checkpoint_dir = os.path.join(args.path_to_checkpoints, current_folder)

                    # DataParallel / DistributedDataParallel expose the real
                    # model as ``.module``; a bare model has no such attribute.
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(checkpoint_dir)

                    torch.save(
                        optimizer.state_dict(),
                        os.path.join(checkpoint_dir, "optimizer.pth.tar"))
                    torch.save(
                        scheduler.state_dict(),
                        os.path.join(checkpoint_dir, "scheduler.pth.tar"))
                    model.train()


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
