import argparse
import json
import logging
import os
import glob
import re
from seqeval import metrics as seqeval_metrics
from sklearn import metrics as sklearn_metrics
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from fastprogress.fastprogress import master_bar, progress_bar
from attrdict import AttrDict
from transformers import AutoConfig, AutoTokenizer,BertConfig
from transformers.trainer_utils import is_main_process
import logging
from transformers import (
    AdamW,
    get_linear_schedule_with_warmup
)
from processor import seq_cls_load_and_cache_examples as load_and_cache_examples
from processor import MultiProcessor
from model.bert import BertForRelationAwareClassification
import random

def init_logger():
    """Configure the root logger: INFO level, timestamped one-line records."""
    record_format = "%(asctime)s - %(levelname)s - %(name)s -   %(message)s"
    timestamp_format = "%m/%d/%Y %H:%M:%S"
    logging.basicConfig(
        level=logging.INFO,
        format=record_format,
        datefmt=timestamp_format,
    )


# Module-level logger. Output format/level come from the root logger, which
# init_logger() configures via logging.basicConfig().
logger = logging.getLogger(__name__)

def simple_accuracy(labels, preds):
    """Return the fraction of positions where ``labels`` equals ``preds``.

    Both arguments may be numpy arrays or plain sequences (lists/tuples).
    They are converted with ``np.asarray`` so ``==`` is element-wise; on plain
    lists ``==`` would return a single bool that has no ``.mean()``.
    """
    return (np.asarray(labels) == np.asarray(preds)).mean()
def acc_score(labels, preds):
    """Wrap :func:`simple_accuracy` in a metrics dict under the key ``"acc"``."""
    accuracy = simple_accuracy(labels, preds)
    return dict(acc=accuracy)
def f1_pre_rec(labels, preds):
    """Return the macro-averaged F1 score as ``{"f1": value}``.

    Despite the name, only F1 is computed (no precision/recall entries).
    """
    macro_f1 = sklearn_metrics.f1_score(labels, preds, average="macro")
    return dict(f1=macro_f1)

def train(args,
          model, tokenizer,
          train_dataset,
          dev_dataset=None,
          test_dataset=None):
    """Fine-tune ``model`` with AdamW and a linear warmup/decay schedule.

    After every epoch the model is optionally evaluated (test, then dev) and a
    checkpoint is written to ``<args.output_dir>/epoch-<n>``.

    Args:
        args: AttrDict-style config. Reads ``train_batch_size``, ``max_steps``,
            ``gradient_accumulation_steps``, ``num_train_epochs``,
            ``weight_decay``, ``learning_rate``, ``adam_epsilon``,
            ``warmup_proportion``, ``model_name_or_path``, ``device``,
            ``model_type``, ``max_grad_norm``, ``evaluate_test_during_training``,
            ``output_dir`` and ``save_optimizer``; ``logging_steps`` and
            ``save_steps`` are only logged. ``num_train_epochs`` is overwritten
            when ``max_steps > 0``.
        model: module whose forward accepts the keyword inputs built below and
            returns an object with a ``loss`` attribute.
        tokenizer: saved next to each checkpoint.
        train_dataset: dataset of tensor tuples; presumably laid out as
            (input_ids, attention_mask, token_type_ids, class_label, p_mask,
            span_label) judging by the indexing below — TODO confirm against
            the processor.
        dev_dataset, test_dataset: optional evaluation splits; a ``None``
            split is skipped.

    Returns:
        ``(global_step, tr_loss / global_step)``, where ``tr_loss`` sums the
        (accumulation-scaled) micro-batch losses. The average is ``0.0`` when
        no optimizer step was taken.
    """
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Optimizer updates per epoch, clamped to 1. With fewer batches than
    # gradient_accumulation_steps the loop below still steps once per epoch;
    # an unclamped 0 would divide by zero (max_steps branch) or give
    # t_total == 0, which makes the linear schedule's learning rate 0.
    updates_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(t_total * args.warmup_proportion), num_training_steps=t_total)

    # Resume optimizer/scheduler state when the starting checkpoint has it.
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        # map_location="cpu" lets GPU-saved state load on CPU-only hosts;
        # Optimizer.load_state_dict moves the state onto each parameter's device.
        optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
        scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Total train batch size = %d", args.train_batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  Logging steps = %d", args.logging_steps)
    logger.info("  Save steps = %d", args.save_steps)

    global_step = 0
    tr_loss = 0.0

    model.zero_grad()
    mb = master_bar(range(int(args.num_train_epochs)))
    for epoch in mb:
        # evaluate() at the end of the previous epoch leaves the model in eval mode.
        model.train()
        epoch_iterator = progress_bar(train_dataloader, parent=mb)
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "class_label": batch[3],
                "span_label": batch[5],
                "p_mask": batch[4],
            }
            if args.model_type not in ["distilkobert", "xlm-roberta"]:
                inputs["token_type_ids"] = batch[2]  # Distilkobert, XLM-Roberta don't use segment_ids
            outputs = model(**inputs)
            loss = outputs.loss

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()
            tr_loss += loss.item()
            # Step every `gradient_accumulation_steps` micro-batches, or at the
            # end of an epoch that is shorter than one accumulation window.
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    len(train_dataloader) <= args.gradient_accumulation_steps
                    and (step + 1) == len(train_dataloader)
            ):
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1
            # Stop exactly after max_steps optimizer updates (the schedule spans t_total = max_steps).
            if args.max_steps > 0 and global_step >= args.max_steps:
                break

        if args.evaluate_test_during_training:
            if test_dataset is not None:
                evaluate(args, model, test_dataset, "test", epoch + 1)
            if dev_dataset is not None:
                evaluate(args, model, dev_dataset, "dev", epoch + 1)

        output_dir = os.path.join(args.output_dir, "epoch-{}".format(epoch + 1))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # Unwrap DataParallel-style wrappers before saving.
        model_to_save = model.module if hasattr(model, "module") else model
        model_to_save.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, "training_args.bin"))
        logger.info("Saving model checkpoint to {}".format(output_dir))
        if args.save_optimizer:
            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
            logger.info("Saving optimizer and scheduler states to {}".format(output_dir))

        mb.write("Epoch {} done".format(epoch + 1))

        if args.max_steps > 0 and global_step >= args.max_steps:
            break

    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss



def evaluate(args, model, eval_dataset, mode, epoch=None):
    """Evaluate ``model`` on ``eval_dataset`` and write a metrics file.

    Sentence-level accuracy and macro F1 are computed from ``class_logits``.
    Token-level macro F1 is computed from ``span_logits`` over positions
    where ``p_mask == 0``. Results are logged and written to
    ``<args.output_dir>/<mode>/<mode>-<epoch>.txt``, or ``<mode>.txt`` when
    ``epoch`` is falsy.

    Args:
        args: config; reads ``eval_batch_size``, ``device``, ``model_type``
            and ``output_dir``.
        model: module returning ``loss``, ``class_logits`` and ``span_logits``.
        eval_dataset: dataset of tensor tuples, indexed like in ``train``;
            when ``None``, evaluation is skipped and empty dicts are returned.
        mode: split name used in log messages and output paths.
        epoch: optional tag used in log messages and the output file name.

    Returns:
        ``(results, f1_results, token_f1_results)``: dicts keyed ``"acc"``,
        ``"f1"`` and ``"f1"`` respectively.

    Raises:
        ValueError: if ``eval_dataset`` yields no batches.
    """
    if eval_dataset is None:
        logger.warning("No {} dataset given; skipping evaluation".format(mode))
        return {}, {}, {}

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Eval!
    if epoch is not None:
        logger.info("***** Running evaluation on {} dataset ({} epoch) *****".format(mode, epoch))
    else:
        logger.info("***** Running evaluation on {} dataset *****".format(mode))
    logger.info("  Num examples = {}".format(len(eval_dataset)))
    logger.info("  Eval Batch size = {}".format(args.eval_batch_size))

    eval_loss = 0.0
    nb_eval_steps = 0
    # Per-batch chunks, concatenated once after the loop (np.append per batch is quadratic).
    class_logit_chunks = []
    class_label_chunks = []
    span_pred_chunks = []
    span_label_chunks = []

    model.eval()
    for batch in progress_bar(eval_dataloader):
        batch = tuple(t.to(args.device) for t in batch)
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "class_label": batch[3],
            "span_label": batch[5],
            "p_mask": batch[4],
        }
        if args.model_type not in ["distilkobert", "xlm-roberta"]:
            inputs["token_type_ids"] = batch[2]  # Distilkobert, XLM-Roberta don't use segment_ids
        with torch.no_grad():
            outputs = model(**inputs)

        eval_loss += outputs.loss.mean().item()
        nb_eval_steps += 1

        class_logit_chunks.append(outputs.class_logits.detach().cpu().numpy())
        class_label_chunks.append(inputs["class_label"].detach().cpu().numpy())

        # Flatten to (num_positions, num_span_labels) and score only the
        # positions where p_mask == 0. The number of span classes is taken
        # from the logits' last dimension instead of being hard-coded.
        span_logits = outputs.span_logits
        flat_span_logits = span_logits.reshape(-1, span_logits.size(-1))
        keep = inputs["p_mask"].reshape(-1) == 0
        span_pred_chunks.append(flat_span_logits.argmax(dim=-1)[keep].detach().cpu().numpy())
        span_label_chunks.append(inputs["span_label"].reshape(-1)[keep].detach().cpu().numpy())

    if nb_eval_steps == 0:
        raise ValueError("{} dataset is empty; nothing to evaluate".format(mode))

    eval_loss = eval_loss / nb_eval_steps
    preds = np.argmax(np.concatenate(class_logit_chunks, axis=0), axis=1)
    out_label_ids = np.concatenate(class_label_chunks, axis=0)
    check_span_preds = np.concatenate(span_pred_chunks)
    check_out_label_span = np.concatenate(span_label_chunks)

    results = dict(acc_score(out_label_ids, preds))
    f1_results = dict(f1_pre_rec(out_label_ids, preds))
    token_f1_results = dict(f1_pre_rec(check_out_label_span, check_span_preds))

    output_dir = os.path.join(args.output_dir, mode)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    output_eval_file = os.path.join(output_dir, "{}-{}.txt".format(mode, epoch) if epoch else "{}.txt".format(mode))
    with open(output_eval_file, "w") as f_w:
        logger.info("***** Eval results on {} dataset *****".format(mode))
        logger.info("  eval_loss = {}".format(eval_loss))
        for key in sorted(results.keys()):
            logger.info("  {} = {}".format(key, str(results[key])))
            f_w.write("  {} = {}\n".format(key, str(results[key])))
        for key in sorted(f1_results.keys()):
            logger.info("  {} = {}".format(key, str(f1_results[key])))
            f_w.write("  {} = {}\n".format(key, str(f1_results[key])))
        for key in sorted(token_f1_results.keys()):
            logger.info("  {} = {}".format(key, str(token_f1_results[key])))
            f_w.write("token_{} = {}\n".format(key, str(token_f1_results[key])))

    return results, f1_results, token_f1_results

def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's RNGs from ``args.seed``.

    CUDA generators are also seeded unless ``args.no_cuda`` is set or no GPU
    is available.
    """
    seed_value = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if args.no_cuda or not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed_value)

def _checkpoint_sort_key(path):
    """Order checkpoint dirs by the integer after the last '-' in their basename (epoch-10 after epoch-9)."""
    match = re.search(r"-(\d+)$", os.path.basename(path))
    return (int(match.group(1)) if match else -1, path)


def _find_checkpoints(output_dir):
    """Return directories under ``output_dir`` holding saved weights, sorted by epoch number."""
    weight_files = []
    # Older transformers save pytorch_model.bin; newer ones default to model.safetensors.
    for weights_name in ("pytorch_model.bin", "model.safetensors"):
        weight_files.extend(glob.glob(os.path.join(output_dir, "**", weights_name), recursive=True))
    checkpoint_dirs = {os.path.dirname(path) for path in weight_files}
    return sorted(checkpoint_dirs, key=_checkpoint_sort_key)


def _suffix_keys(metrics, epoch):
    """Return a copy of ``metrics`` with ``_<epoch>`` appended to every key."""
    return {"{}_{}".format(key, epoch): value for key, value in metrics.items()}


def _write_eval_summary(path, acc_results, f1_results, token_f1_results):
    """Write accumulated accuracy, F1 and token-F1 metrics to ``path``, one ``key = value`` per line."""
    with open(path, "w") as f_w:
        for key in sorted(acc_results.keys()):
            f_w.write("{} = {}\n".format(key, str(acc_results[key])))
        for key in sorted(f1_results.keys()):
            f_w.write("{} = {}\n".format(key, str(f1_results[key])))
        for key in sorted(token_f1_results.keys()):
            f_w.write("token_{} = {}\n".format(key, str(token_f1_results[key])))


def main(cli_args):
    """Load the JSON config, build model/tokenizer/datasets, then train and/or evaluate.

    Args:
        cli_args: parsed CLI namespace with ``config_dir`` and ``config_file``.
    """
    # Configure logging first so the parameter dump below is actually emitted.
    init_logger()

    # Read from config file and make args
    with open(os.path.join(cli_args.config_dir, cli_args.config_file)) as f:
        args = AttrDict(json.load(f))
    logger.info("Training/evaluation parameters {}".format(args))

    args.output_dir = os.path.join(args.ckpt_dir, args.output_dir)

    set_seed(args)

    processor = MultiProcessor(args)
    labels = processor.get_labels()
    span_labels = processor.get_labels_span()
    config = BertConfig.from_pretrained(
        args.model_name_or_path,
        id2label={str(i): label for i, label in enumerate(labels)},
        label2id={label: i for i, label in enumerate(labels)},
        span_labels={label: i for i, label in enumerate(span_labels)}
    )
    model = BertForRelationAwareClassification.from_pretrained(
        args.model_name_or_path,
        config=config
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        do_lower_case=args.do_lower_case
    )
    # Register the speaker marker as an extra special token (keeping any existing ones).
    tokenizer.add_special_tokens({'additional_special_tokens': tokenizer.additional_special_tokens + ['[SPEAKER]']})
    span_token_id = tokenizer.additional_special_tokens_ids[tokenizer.additional_special_tokens.index('[SPEAKER]')]
    logger.warning(
        f'SPAN_TOKEN "[SPEAKER]" was added as "{span_token_id}". You can safely ignore'
        ' this warning if you are training a model from pretrained LMs.')
    model.resize_token_embeddings(len(tokenizer))
    # GPU or CPU
    args.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

    model.to(args.device)

    # Load dataset
    train_dataset = load_and_cache_examples(args=args, tokenizer=tokenizer, mode="train", span_token_id=span_token_id) if args.train_file else None
    dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev", span_token_id=span_token_id) if args.dev_file else None
    test_dataset = load_and_cache_examples(args, tokenizer, mode="test", span_token_id=span_token_id) if args.test_file else None

    if dev_dataset is None:
        args.evaluate_test_during_training = True  # If there is no dev dataset, only use testset

    if args.do_train:
        global_step, tr_loss = train(args, model, tokenizer, train_dataset, dev_dataset, test_dataset)
        logger.info(" global_step = {}, average loss = {}".format(global_step, tr_loss))

    if not args.do_eval:
        return

    checkpoints = _find_checkpoints(args.output_dir)
    if not args.eval_all_checkpoints:
        checkpoints = checkpoints[-1:]
    else:
        logging.getLogger("transformers.configuration_utils").setLevel(logging.WARN)  # Reduce logging
        logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
    logger.info("Evaluate the following checkpoints: %s", checkpoints)
    if not checkpoints:
        logger.warning("No checkpoints found under {}".format(args.output_dir))

    test_acc, test_f1, test_token_f1 = {}, {}, {}
    dev_acc, dev_f1, dev_token_f1 = {}, {}, {}
    for checkpoint in checkpoints:
        # Use only the directory name so hyphens in parent dirs cannot leak
        # path separators into the epoch tag (it is used in output file names).
        epoch = os.path.basename(checkpoint).split("-")[-1]
        model = BertForRelationAwareClassification.from_pretrained(checkpoint)
        model.to(args.device)
        for dataset, mode, acc_out, f1_out, token_out in (
                (test_dataset, "test", test_acc, test_f1, test_token_f1),
                (dev_dataset, "dev", dev_acc, dev_f1, dev_token_f1),
        ):
            if dataset is None:
                continue
            acc_result, f1_result, token_result = evaluate(args, model, dataset, mode=mode, epoch=epoch)
            acc_out.update(_suffix_keys(acc_result, epoch))
            f1_out.update(_suffix_keys(f1_result, epoch))
            token_out.update(_suffix_keys(token_result, epoch))

    _write_eval_summary(os.path.join(args.output_dir, "eval_results_with_tokenf1.txt"),
                        test_acc, test_f1, test_token_f1)
    _write_eval_summary(os.path.join(args.output_dir, "eval_dev_results_with_tokenf1.txt"),
                        dev_acc, dev_f1, dev_token_f1)

if __name__ == '__main__':
    # Command-line entry point: only the location of the JSON config is taken
    # from the CLI; every other setting lives in that config file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_dir", type=str, default="config")
    parser.add_argument("--config_file", type=str, required=True)
    main(parser.parse_args())