import argparse
import shutil
import os
import random
import datasets
import numpy as np
import torch
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import get_scheduler

import tensorflow as tf
from spell_check_evaluation.evaluate_v2.evaluate import evaluate


def validation(
    model,
    args,
    dev_dataset,
    penalize_ngramms_num=None,
):
    """Generate a correction for every dev row and score the predictions.

    Args:
        model: Model exposing ``eval()``, ``device`` and ``generate()``.
        args: Namespace carrying the ``tokenizer`` used to encode and decode.
        dev_dataset: Iterable of rows with ``'source'`` and ``'correction'``
            entries.
        penalize_ngramms_num: Forwarded to ``generate`` as
            ``no_repeat_ngram_size``.

    Returns:
        The result of ``evaluate(sources, corrections, answers)``.
    """
    model.eval()
    target_device = model.device
    tokenizer = args.tokenizer

    sources, corrections, answers = [], [], []
    # Rows are generated one at a time, so each call sees a batch of size one.
    for example in tqdm(dev_dataset):
        with torch.no_grad():
            model_inputs = {
                name: tensor.to(target_device)
                for name, tensor in tokenizer(example['source'], return_tensors='pt').items()
            }
            output_ids = model.generate(
                **model_inputs,
                forced_bos_token_id=tokenizer.get_lang_id("ru"),
                no_repeat_ngram_size=penalize_ngramms_num,
            ).cpu()
            decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        sources.append(example['source'])
        corrections.append(example['correction'])
        # Only the first decoded sequence is kept for the row.
        answers.append(decoded[0])

    return evaluate(sources, corrections, answers)

def parse_args():
    """Build the command-line parser for the training script and parse ``sys.argv``.

    Flags and defaults are unchanged; only help strings that described the
    wrong option were corrected.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    working_dir = os.getcwd()
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--path_to_model",
        default=None,
        type=str,
        help="The model checkpoint for weights initialization (local path or hub id).",
    )
    parser.add_argument(
        "--path_to_tokenizer",
        default=os.path.join(working_dir, 'tokenizer'),
        type=str,
        help="Path to the tokenizer to load.",
    )
    parser.add_argument(
        "--path_to_data",
        default=os.path.join(working_dir, "spellcheck_benchmark", "russian_spellcheck_benchmark.py"),
        type=str,
        help="The path to training data.",
    )
    parser.add_argument(
        "--name_dataset",
        default="RUSpellRU",
        type=str,
        help="The name of the training dataset.",
    )
    parser.add_argument(
        "--device_name",
        default="cuda:0",
        type=str,
        help="The device name.",
    )
    parser.add_argument(
        "--aug_suffix",
        default="",
        type=str,
        help="Suffix appended to the training file name to select an augmented training set.",
    )
    parser.add_argument(
        "--path_to_checkpoints",
        default=os.path.join(working_dir, 'finetuned_checkpoints', 'm2m418m'),
        type=str,
        help="Path to save checkpoints.",
    )
    parser.add_argument(
        "--path_to_evaluation_module",
        default=os.path.join(working_dir, 'spell_check_evaluation', 'spellcheck_metric.py'),
        type=str,
        help="Path to evaluation module.",
    )
    parser.add_argument("--optimizer_type", type=str, default="adamw")
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--train_shuffle", action="store_true")
    parser.add_argument("--num_training_epochs", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--weight_decay", type=float, default=None)
    parser.add_argument("--scheduler_type", type=str, default="linear")
    parser.add_argument("--num_warmup_steps", type=int, default=None)
    parser.add_argument("--track_every_num_steps", type=int, default=50)
    parser.add_argument("--dev_batch_size", type=int, default=8)
    parser.add_argument("--task_prefix", type=str, default="")
    parser.add_argument("--seed", type=int, default=134)
    return parser.parse_args()


def get_optimizer(optimizer_grouped_parameters, optimizer_type, learning_rate):
    """Instantiate the torch optimizer selected by ``optimizer_type``.

    Args:
        optimizer_grouped_parameters: Parameters or parameter groups handed to
            the optimizer constructor.
        optimizer_type: One of ``"adamw"``, ``"adam"`` or ``"sgd"``.
        learning_rate: Value passed to the optimizer as ``lr``.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``optimizer_type`` is not one of the supported names.
    """
    optimizer_classes = {
        "adamw": torch.optim.AdamW,
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
    }
    try:
        optimizer_cls = optimizer_classes[optimizer_type]
    except KeyError:
        # Fail with a clear message instead of an unbound local at the return.
        raise ValueError(
            "Unsupported optimizer_type {!r}; expected one of {}".format(
                optimizer_type, sorted(optimizer_classes)
            )
        ) from None
    return optimizer_cls(optimizer_grouped_parameters, lr=learning_rate)

def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators from ``args.seed``.

    When ``args.n_gpu`` is positive the CUDA generators are seeded as well.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

def _preprocess(samples, tokenizer, task_prefix, **kwargs):
    """Tokenize a batch of source/correction pairs into model inputs and labels.

    Args:
        samples: Mapping with ``"source"`` and ``"correction"`` lists of strings.
        tokenizer: Callable tokenizer exposing ``pad_token_id``; it is called
            once with the sources and once with ``text_target=`` corrections.
        task_prefix: String prepended to every source text.
        **kwargs: Extra options forwarded to both tokenizer calls.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    bom = '\ufeff'
    # Leading/trailing byte-order marks are stripped before tokenizing.
    prefixed_sources = [task_prefix + text.strip(bom) for text in samples["source"]]
    targets = [text.strip(bom) for text in samples["correction"]]

    source_enc = tokenizer(prefixed_sources, **kwargs, return_tensors="pt")
    target_enc = tokenizer(text_target=targets, **kwargs, return_tensors="pt")

    # Padding positions in the targets become -100, the label value that the
    # collator in this file also uses for padding.
    target_ids = target_enc['input_ids']
    ignore_index = torch.full(target_ids.shape, -100)
    labels = torch.where(target_ids != tokenizer.pad_token_id, target_ids, ignore_index)

    return {
        'input_ids': source_enc['input_ids'],
        'attention_mask': source_enc['attention_mask'],
        'labels': labels,
    }

def _load_json_splits(corpus, aug_suffix):
    """Load the train/dev JSON files of one corpus from the ``datasets/`` folder."""
    return load_dataset(
        "json",
        data_files={
            "train": f"datasets/{corpus}_train{aug_suffix}.json",
            "dev": f"datasets/{corpus}_dev.json",
        },
    )


def setup_dataset(args):
    """Load the train and dev splits selected by ``args.name_dataset``.

    Args:
        args: Namespace with ``name_dataset`` (``"RUSpellRU"``,
            ``"MultidomainGold"`` or ``"ALL"``) and ``aug_suffix``, which is
            inserted into the training file name only.

    Returns:
        Tuple ``(train_dataset, dev_dataset)``. For ``"ALL"`` each split is the
        concatenation of RUSpellRU followed by MultidomainGold.

    Raises:
        ValueError: If ``args.name_dataset`` is not a supported name.
    """
    corpora_by_name = {
        "RUSpellRU": ("RUSpellRU",),
        "MultidomainGold": ("MultidomainGold",),
        "ALL": ("RUSpellRU", "MultidomainGold"),
    }
    # Validate before touching the disk so a typo fails immediately and clearly.
    if args.name_dataset not in corpora_by_name:
        raise ValueError(
            "Unknown name_dataset {!r}; expected one of {}".format(
                args.name_dataset, sorted(corpora_by_name)
            )
        )

    loaded = [
        _load_json_splits(corpus, args.aug_suffix)
        for corpus in corpora_by_name[args.name_dataset]
    ]
    if len(loaded) == 1:
        return loaded[0]['train'], loaded[0]['dev']

    train_dataset = concatenate_datasets([splits['train'] for splits in loaded])
    dev_dataset = concatenate_datasets([splits['dev'] for splits in loaded])
    return train_dataset, dev_dataset
    
def preprocess_dataset(args, model, train_dataset, dev_dataset):
    """Tokenize both splits and wrap them in dataloaders.

    Args:
        args: Namespace providing ``tokenizer``, ``task_prefix``,
            ``batch_size``, ``dev_batch_size`` and ``train_shuffle``.
        model: Passed to the seq2seq data collator.
        train_dataset: Split exposing ``map``; its ``source``, ``correction``
            and ``domain`` columns are removed after tokenization.
        dev_dataset: Same as ``train_dataset`` for the dev split.

    Returns:
        Tuple ``(train_dataloader, dev_dataloader)``.
    """
    tokenizer_options = dict(max_length=None, padding='longest', truncation=False)
    raw_columns = ['source', 'correction', "domain"]

    def encode_batch(batch):
        return _preprocess(batch, args.tokenizer, args.task_prefix, **tokenizer_options)

    def tokenize_split(split, map_batch_size):
        # Padding is 'longest' per map batch, so the map batch size is set to
        # the same value as the matching loader batch size.
        return split.map(
            encode_batch,
            batched=True,
            batch_size=map_batch_size,
            remove_columns=list(raw_columns),
            drop_last_batch=False,
        )

    train_encoded = tokenize_split(train_dataset, args.batch_size)
    dev_encoded = tokenize_split(dev_dataset, args.dev_batch_size)

    collate = DataCollatorForSeq2Seq(
        tokenizer=args.tokenizer,
        model=model,
        padding=True,
        label_pad_token_id=-100,
        return_tensors='pt',
    )

    train_dataloader = DataLoader(
        train_encoded,
        batch_size=args.batch_size,
        shuffle=args.train_shuffle,
        collate_fn=collate,
        drop_last=False,
    )
    dev_dataloader = DataLoader(
        dev_encoded,
        batch_size=args.dev_batch_size,
        shuffle=False,
        collate_fn=collate,
        drop_last=False,
    )
    return train_dataloader, dev_dataloader

def setup(args):
    """Attach runtime device information to ``args`` and seed the RNGs.

    Sets ``args.device`` from ``args.device_name`` and ``args.n_gpu`` from the
    CUDA device count, then calls ``set_seed(args)``.
    """
    selected_device = torch.device(args.device_name)
    gpu_count = torch.cuda.device_count()
    args.n_gpu, args.device = gpu_count, selected_device
    set_seed(args)

def setup_optimizer_scheduler(model, args):
    """Create the optimizer and learning-rate scheduler for ``model``.

    Parameters whose name contains ``bias`` or ``layer_norm.weight`` go into a
    group with zero weight decay; all others use ``args.weight_decay``.

    Args:
        model: Module exposing ``named_parameters()``.
        args: Namespace with ``weight_decay``, ``optimizer_type``,
            ``learning_rate``, ``scheduler_type``, ``num_warmup_steps`` and
            ``training_steps``.

    Returns:
        Tuple ``(optimizer, scheduler)``.
    """
    exempt_markers = ('bias', "layer_norm.weight")

    exempt_params, decayed_params = [], []
    for param_name, param in model.named_parameters():
        is_exempt = any(marker in param_name for marker in exempt_markers)
        (exempt_params if is_exempt else decayed_params).append(param)

    param_groups = [
        {'params': exempt_params, 'weight_decay': 0.},
        {'params': decayed_params, 'weight_decay': args.weight_decay},
    ]

    optimizer = get_optimizer(param_groups, args.optimizer_type, args.learning_rate)
    optimizer.zero_grad()

    scheduler = get_scheduler(
        args.scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.training_steps,
    )
    return optimizer, scheduler

def train_epoch(args, model, train_dataloader, optimizer, scheduler, epoch):
    """Run one optimization pass over ``train_dataloader``.

    Each batch is moved to ``args.device``; the mean of the model's ``loss`` is
    back-propagated, followed by an optimizer step, a scheduler step and a
    gradient reset. Every ``args.track_every_num_steps`` steps the average
    loss since the previous report is printed.

    Args:
        args: Namespace with ``device`` and ``track_every_num_steps``.
        model: Callable accepting the batch as keyword arguments and returning
            an object with a ``loss`` tensor.
        train_dataloader: Sized iterable of mappings from names to tensors.
        optimizer: Optimizer stepped after every batch.
        scheduler: Learning-rate scheduler stepped after every batch.
        epoch: Zero-based epoch index, used only in the progress message.

    Returns:
        float: Mean loss over the processed batches, or NaN when the
        dataloader yields no batches.
    """
    window_loss = 0.
    total_loss = 0.
    num_batches = 0
    for step, batch in enumerate(train_dataloader):
        inputs = {name: tensor.to(args.device) for name, tensor in batch.items()}
        loss = model(**inputs).loss.mean()
        # Read the scalar once; each .item() call copies it back to the host.
        loss_value = loss.item()
        window_loss += loss_value
        total_loss += loss_value
        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        num_batches = step + 1
        if num_batches % args.track_every_num_steps == 0:
            print("Iteration {}/{} of epoch {} complete. Loss : {} "
                  .format(num_batches, len(train_dataloader), epoch + 1,
                          window_loss / args.track_every_num_steps))
            window_loss = 0.0

    if num_batches == 0:
        # No batches: the mean is undefined, so report NaN instead of crashing
        # on an unbound loop variable.
        return float("nan")
    return total_loss / num_batches

def eval_model(args, model, dev_dataloader, epoch):
    """Compute the mean loss of ``model`` over ``dev_dataloader`` without gradients.

    The function does not switch the model to evaluation mode; the caller is
    expected to have done so.

    Args:
        args: Namespace with ``device``.
        model: Callable accepting the batch as keyword arguments and returning
            an object with a ``loss`` tensor.
        dev_dataloader: Sized iterable of mappings from names to tensors.
        epoch: Zero-based epoch index, used only in the summary message.

    Returns:
        float: Mean loss over the batches, or NaN when the dataloader yields
        no batches.
    """
    total_loss = 0.
    num_batches = 0
    with torch.no_grad():
        for batch in tqdm(dev_dataloader, total=len(dev_dataloader)):
            inputs = {name: tensor.to(args.device) for name, tensor in batch.items()}
            total_loss += model(**inputs).loss.mean().item()
            num_batches += 1

    # An empty loader has no defined mean; NaN avoids both a crash on an
    # unbound loop variable and a misleading perfect score of 0.
    score = total_loss / num_batches if num_batches else float("nan")
    print("Epoch {} complete. Eval_loss: {}".format(epoch + 1, score))
    return score

def main(args):
    """Fine-tune the M2M100 model and keep the checkpoint with the best dev F1.

    After each training epoch the model is scored with ``validation``. When the
    dev F1 is at least the best seen so far, the checkpoint directory is wiped
    and the model is saved into a sub-folder named after the epoch and metrics.

    Args:
        args: Parsed command-line namespace; it is extended in place with
            ``tokenizer`` and ``training_steps``.
    """
    setup(args)
    model = M2M100ForConditionalGeneration.from_pretrained(args.path_to_model).to(args.device)
    args.tokenizer = M2M100Tokenizer.from_pretrained(
        args.path_to_tokenizer, src_lang="ru", tgt_lang="ru"
    )

    train_dataset, dev_dataset = setup_dataset(args)
    # Only the training loader is consumed here; validation reads the raw dev rows.
    train_dataloader, _ = preprocess_dataset(args, model, train_dataset, dev_dataset)
    args.training_steps = len(train_dataloader) * args.num_training_epochs

    optimizer, scheduler = setup_optimizer_scheduler(model, args)

    best_f1 = 0.
    for epoch in range(args.num_training_epochs):
        model.train()
        train_epoch(args, model, train_dataloader, optimizer, scheduler, epoch)

        model.eval()
        dev_metrics = validation(model, args, dev_dataset)
        print('\n'.join(f'metrics {name} = {value}' for name, value in dev_metrics.items()))

        current_f1 = dev_metrics["F1"]
        if current_f1 >= best_f1:
            # NOTE(review): the old checkpoint is removed before the new one is
            # written, so a failed save leaves no checkpoint on disk.
            shutil.rmtree(args.path_to_checkpoints, ignore_errors=True)
            precision = dev_metrics["Precision"]
            recall = dev_metrics["Recall"]
            checkpoint_name = (
                f'model-epoch-{epoch + 1}-devPr-{precision}-devR-{recall}-devF1-{current_f1}'
            )
            model.save_pretrained(os.path.join(args.path_to_checkpoints, checkpoint_name))
            best_f1 = current_f1
