import json
from pathlib import Path

import numpy as np
import torch
from pytorch_metric_learning import miners, losses
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

from utils.metrics import build_main_report


def run_epoch(model,
              dataloader,
              optimizer=None,
              criterion=None,
              scheduler=None,
              is_trainable: bool = True,
              is_test: bool = False,
              new_intents=None,
              parameters=None,
              epoch_number: int = 0,
              best_metric_valid: dict = None,
              return_metric: bool = True,
              logging_steps: int = 25,
              logger=None,
              save_path: str = 'checkpoint.pt',
              device=None,
              predict_action: bool = False,
              predict_concept: bool = False,
              metric_learning: bool = False
              ):
    """Run one pass over ``dataloader`` for the multi-class intent model.

    Depending on the flags the pass trains the model, validates it (saving a
    checkpoint when the macro F1 improves) or evaluates it on a test set.

    Args:
        model: callable invoked as ``model(batch=..., intent_description=...,
            concept_description=..., action_description=...)``; its output is
            indexed with ``'intent_prediction'`` and, depending on the flags,
            ``'action_prediction'``, ``'concept_prediction'`` and
            ``'utt_embeddings'``.
        dataloader: iterable of batches indexed with ``'input_ids'``,
            ``'attention_mask'``, ``'labels'`` and, depending on the flags,
            ``'labels_actions'`` / ``'labels_concepts'``.
        optimizer: stepped after every batch when ``is_trainable`` is true.
        criterion: loss applied to every prediction head.
        scheduler: optional; stepped after every optimizer step.
        is_trainable: train (``True``) or evaluate (``False``).
        is_test: marks a test pass; disables checkpointing and enables the
            ``new_intents`` report.
        new_intents: optional container of label values; on a test pass the
            samples whose target is in it get a separate ``'new_'`` report.
        parameters: mapping with ``'descriptions_names'`` and, depending on
            the flags, ``'actions'`` / ``'concepts'``; each is a mapping of
            names to objects that support ``.to(device)``.
        epoch_number: epoch index used in log lines and in the checkpoint.
        best_metric_valid: mapping with the best validation ``'f1_score'`` so
            far; ``None`` means there is no baseline yet.
        return_metric: return the metrics dict when true, otherwise ``None``.
        logging_steps: print running metrics every this many batches.
        logger: unused.
        save_path: checkpoint destination used on validation passes.
        device: target device; defaults to CUDA when available, else CPU.
        predict_action: add the action head loss.
        predict_concept: add the concept head loss.
        metric_learning: add a triplet-margin loss on ``'utt_embeddings'``.

    Returns:
        A dict with ``'f1_score'``, ``'accuracy'`` and ``'confusion_matrix'``
        (plus the ``'new_'`` entries on a test pass with ``new_intents``) when
        ``return_metric`` is true, otherwise ``None``.
    """
    if is_trainable:
        model.train()
    else:
        model.eval()
    print("=" * 50)
    print(f"Epoch {epoch_number} Trainable {is_trainable} STARTS")

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The description inputs do not depend on the batch, so they are moved to
    # the device once instead of on every iteration.
    intent_description = {
        name: value.to(device)
        for name, value in parameters['descriptions_names'].items()}
    action_description = None
    if predict_action:
        action_description = {name: value.to(device)
                              for name, value in parameters['actions'].items()}
    concept_description = None
    if predict_concept:
        concept_description = {name: value.to(device)
                               for name, value in parameters['concepts'].items()}

    # NOTE(review): the original also built a MultiSimilarityMiner that was
    # never passed to the loss; it is dropped here. Confirm whether mined
    # pairs were meant to be used.
    loss_func = losses.TripletMarginLoss() if metric_learning else None

    # Targets are collected from the batches themselves so they stay aligned
    # with the predictions even when the dataloader shuffles or drops samples.
    # assumes batch['labels'] holds one class index per sample — TODO confirm
    prediction_chunks = []
    target_chunks = []
    total_loss = 0.0

    for batch_id, batch in enumerate(dataloader):
        labels = batch['labels'].to(device)
        output = model(batch={key: batch[key].to(device)
                              for key in ('input_ids', 'attention_mask')},
                       intent_description=intent_description,
                       concept_description=concept_description,
                       action_description=action_description,
                       )

        loss = criterion(output['intent_prediction'], labels)
        if predict_action:
            loss = loss + criterion(
                output['action_prediction'], batch['labels_actions'].to(device))
        if predict_concept:
            loss = loss + criterion(
                output['concept_prediction'], batch['labels_concepts'].to(device))
        if metric_learning:
            loss = loss + loss_func(output['utt_embeddings'], labels)

        prediction_chunks.append(
            output['intent_prediction'].max(axis=1).indices.cpu())
        target_chunks.append(labels.cpu())
        total_loss += loss.item()

        if is_trainable:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()
        if batch_id % logging_steps == 0 and batch_id > 0:
            seen_target = torch.cat(target_chunks)
            seen_prediction = torch.cat(prediction_chunks)
            f1_score_ = f1_score(seen_target, seen_prediction, average='macro')
            accuracy_ = accuracy_score(seen_target, seen_prediction)
            # batch_id is zero-based, so batch_id + 1 losses were accumulated.
            print(
                f"Epoch {epoch_number} Trainable {is_trainable} Step {batch_id} out of {len(dataloader)} Loss {total_loss / (batch_id + 1)}\tf1-score {np.round(f1_score_, 3)}\taccuracy {np.round(accuracy_, 3)}")

    all_target = torch.cat(target_chunks)
    all_prediction = torch.cat(prediction_chunks)
    mean_loss = total_loss / len(dataloader)

    f1_score_ = f1_score(all_target, all_prediction, average='macro')
    accuracy_ = accuracy_score(all_target, all_prediction)
    print(
        f"Epoch {epoch_number}, Loss {mean_loss}\tf1-score {np.round(f1_score_, 3)}\taccuracy {np.round(accuracy_, 3)}")

    metrics = dict()
    if is_test and new_intents is not None:
        # Separate report restricted to samples whose target is a new intent.
        selected = [(tar, pred)
                    for tar, pred in zip(all_target.tolist(), all_prediction.tolist())
                    if tar in new_intents]
        new_intent_target = [tar for tar, _ in selected]
        new_intent_pred = [pred for _, pred in selected]
        metrics.update(build_main_report(new_intent_target, new_intent_pred, prefix='new_'))
        metrics.update({'new_confusion_matrix': confusion_matrix(new_intent_target, new_intent_pred).tolist()})

    if not is_trainable and not is_test:
        # Checkpoint on the first epoch, when no baseline exists, or when the
        # validation F1 improved. The cheap checks come first so a missing
        # baseline is never subscripted.
        improved = (epoch_number == 0
                    or best_metric_valid is None
                    or f1_score_ > best_metric_valid['f1_score'])
        if improved:
            print(
                f"Epoch {epoch_number} Trainable {is_trainable},\t Loss {mean_loss}\tf1-score {f1_score_}\taccuracy {accuracy_}")
            torch.save({
                'epoch': epoch_number,
                'model_state_dict': model.state_dict(),
                'f1_score': f1_score_,
                'accuracy': accuracy_,
            }, save_path)

    if return_metric:
        metrics.update({'f1_score': f1_score_,
                        'accuracy': accuracy_,
                        'confusion_matrix': confusion_matrix(all_target, all_prediction).tolist()}
                       )
        return metrics


def get_f1_score(statistics):
    """Return the macro-averaged F1 score computed from per-class counts.

    Args:
        statistics: mapping from a class key to a mapping holding the counts
            ``'tp'``, ``'fp'`` and ``'fn'`` for that class.

    Returns:
        The unweighted mean of the per-class F1 scores. A precision, recall
        or F1 whose denominator is zero counts as ``0.0`` instead of raising,
        and an empty ``statistics`` yields ``0.0``.
    """
    def _ratio(numerator, denominator):
        # Undefined ratios (nothing predicted / nothing to find) score zero.
        return numerator / denominator if denominator else 0.0

    if not statistics:
        return 0.0

    per_class_f1 = []
    for counts in statistics.values():
        precision = _ratio(counts['tp'], counts['tp'] + counts['fp'])
        recall = _ratio(counts['tp'], counts['tp'] + counts['fn'])
        per_class_f1.append(
            _ratio(2 * precision * recall, precision + recall))
    return np.mean(per_class_f1)


def run_epoch_binary(model,
                     dataloader,
                     optimizer=None,
                     criterion=None,
                     scheduler=None,
                     is_trainable: bool = True,
                     is_test: bool = False,
                     parameters=None,
                     epoch_number: int = 0,
                     best_metric_valid: dict = None,
                     return_metric: bool = True,
                     logging_steps: int = 25,
                     logger=None,
                     save_path: str = 'checkpoint.pt',
                     device=None,
                     ):
    """Run one pass over ``dataloader`` for the binary classification model.

    Depending on the flags the pass trains the model, validates it (saving a
    checkpoint when the macro F1 improves) or evaluates it on a test set.

    Args:
        model: callable invoked as ``model(batch=...)``; its output is
            indexed with ``'logits'``.
        dataloader: iterable of batches indexed with ``'input_ids'``,
            ``'attention_mask'`` and ``'binary_label'``.
        optimizer: stepped after every batch when ``is_trainable`` is true.
        criterion: loss applied to the logits and ``'binary_label'``.
        scheduler: optional; stepped after every optimizer step.
        is_trainable: train (``True``) or evaluate (``False``).
        is_test: marks a test pass; disables checkpointing.
        parameters: unused.
        epoch_number: epoch index used in log lines and in the checkpoint.
        best_metric_valid: mapping with the best validation ``'f1_score'`` so
            far; ``None`` means there is no baseline yet.
        return_metric: return the metrics dict when true, otherwise ``None``.
        logging_steps: print the running F1 every this many batches.
        logger: unused.
        save_path: checkpoint destination used on validation passes.
        device: target device; defaults to CUDA when available, else CPU.

    Returns:
        ``{'f1_score': ..., 'accuracy': ...}`` when ``return_metric`` is
        true, otherwise ``None``.
    """
    if is_trainable:
        model.train()
    else:
        model.eval()
    print("=" * 50)
    print(f"Epoch {epoch_number} Trainable {is_trainable} STARTS")

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Predictions and targets are collected per batch; the original referenced
    # accumulators that were never created and a statistics dict that was
    # never filled, so no metric could be computed.
    prediction_chunks = []
    target_chunks = []
    total_loss = 0.0

    for batch_id, batch in enumerate(dataloader):
        output = model(batch={key: batch[key].to(device)
                              for key in ('input_ids', 'attention_mask')},)

        loss = criterion(
            output['logits'],
            batch['binary_label'].to(device))
        predictions = output['logits'].max(axis=1).indices.cpu()

        # NOTE(review): the layout of batch['binary_label'] is not visible
        # here. A 2-D target is presumed to hold per-class scores (e.g.
        # one-hot, as BCEWithLogitsLoss needs a target shaped like the
        # logits) and is reduced to class indices; a 1-D target is presumed
        # to already hold class indices. Verify against the dataset.
        target_classes = batch['binary_label'].cpu()
        if target_classes.dim() > 1:
            target_classes = target_classes.max(dim=1).indices
        target_classes = target_classes.long()

        prediction_chunks.append(predictions)
        target_chunks.append(target_classes)
        total_loss += loss.item()

        if is_trainable:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()
        if batch_id % logging_steps == 0 and batch_id > 0:
            f1_score_ = f1_score(torch.cat(target_chunks),
                                 torch.cat(prediction_chunks),
                                 average='macro')
            print(f"Epoch {epoch_number} Trainable {is_trainable} Step {batch_id} out of {len(dataloader)} Loss {loss.item()}\tf1-score {np.round(f1_score_, 3)}")

    all_target = torch.cat(target_chunks)
    all_prediction = torch.cat(prediction_chunks)
    mean_loss = total_loss / len(dataloader)

    # Macro F1 via scikit-learn, the same metric run_epoch reports.
    f1_score_ = f1_score(all_target, all_prediction, average='macro')
    accuracy_ = accuracy_score(all_target, all_prediction)

    print(
        f"Epoch {epoch_number}, Loss {mean_loss}\tf1-score {np.round(f1_score_, 3)}")

    if not is_trainable and not is_test:
        # Checkpoint on the first epoch, when no baseline exists, or when the
        # validation F1 improved. The cheap checks come first so a missing
        # baseline is never subscripted.
        improved = (epoch_number == 0
                    or best_metric_valid is None
                    or f1_score_ > best_metric_valid['f1_score'])
        if improved:
            print(f"Epoch {epoch_number} Trainable {is_trainable},\t Loss {mean_loss}\tf1-score {f1_score_}\taccuracy {accuracy_}")
            torch.save({
                'epoch': epoch_number,
                'model_state_dict': model.state_dict(),
                'f1_score': f1_score_,
                'accuracy': accuracy_,
            }, save_path)
    if return_metric:
        return {'f1_score': f1_score_, 'accuracy': accuracy_}


def run_training(model, optimizer, parameters,
                 criterion=torch.nn.CrossEntropyLoss(), train_dataloader=None, valid_dataloader=None,
                 test_dataloader=None,
                 scheduler=None, number_of_epochs: int = 10, logging_file: str = 'log.txt',
                 save_path: str = 'checkpoint.pt', binary_classification: bool = False,
                 new_intents=None,
                 predict_action=False,
                 predict_concept=False,
                 metric_learning=False,
                 ):
    """Train for ``number_of_epochs`` epochs with validation and optional testing.

    Every epoch runs a training pass, then a validation pass (which saves a
    checkpoint to ``save_path`` when the validation F1 improves) and, when
    ``test_dataloader`` is given, a test pass whose metrics are appended to
    ``test_history.json`` in the current working directory.

    Args:
        model: model handed to the epoch function.
        optimizer: optimizer used on training passes.
        parameters: mapping with the keys ``'train'``, ``'valid'`` and
            ``'infer'``; each value is forwarded to the matching pass.
        criterion: loss function; replaced by ``BCEWithLogitsLoss`` when
            ``binary_classification`` is true.
        train_dataloader: batches for the training pass.
        valid_dataloader: batches for the validation pass.
        test_dataloader: optional batches for the per-epoch test pass.
        scheduler: optional learning-rate scheduler for the training pass.
        number_of_epochs: number of epochs to run.
        logging_file: unused.
        save_path: checkpoint destination for the validation pass.
        binary_classification: use ``run_epoch_binary`` instead of
            ``run_epoch``.
        new_intents: forwarded to the test pass (multi-class mode only).
        predict_action: forwarded to train/valid passes (multi-class only).
        predict_concept: forwarded to train/valid passes (multi-class only).
        metric_learning: forwarded to train/valid passes (multi-class only).
    """
    if binary_classification:
        epoch_function = run_epoch_binary
        criterion = torch.nn.BCEWithLogitsLoss()
        # run_epoch_binary accepts none of the multi-class-only switches, so
        # passing them would raise TypeError.
        task_kwargs = {}
        test_kwargs = {}
    else:
        epoch_function = run_epoch
        task_kwargs = {'predict_action': predict_action,
                       'predict_concept': predict_concept,
                       'metric_learning': metric_learning}
        # NOTE(review): the test pass has never received the predict_* /
        # metric_learning switches (unlike run_test); kept as is — confirm
        # whether that is intentional.
        test_kwargs = {'new_intents': new_intents}

    best_metric = {'f1_score': 0}
    test_metric = dict()
    print('Start Training')
    for epoch_id in range(number_of_epochs):
        epoch_function(model=model, dataloader=train_dataloader,
                       optimizer=optimizer, criterion=criterion,
                       scheduler=scheduler, is_trainable=True,
                       parameters=parameters['train'], epoch_number=epoch_id,
                       best_metric_valid=best_metric, return_metric=False,
                       save_path=save_path,
                       **task_kwargs,
                       )
        with torch.no_grad():
            # Only the validation pass writes checkpoints, so this is the
            # call that needs the caller's save_path.
            current_metric = epoch_function(model=model, dataloader=valid_dataloader, criterion=criterion,
                                            is_trainable=False,
                                            parameters=parameters['valid'], epoch_number=epoch_id,
                                            best_metric_valid=best_metric, return_metric=True,
                                            save_path=save_path,
                                            **task_kwargs,
                                            )
        if epoch_id == 0 or current_metric['f1_score'] > best_metric['f1_score']:
            best_metric = current_metric

        if test_dataloader is not None:
            # Evaluation needs no autograd graph.
            with torch.no_grad():
                test_metric_current = epoch_function(model=model, dataloader=test_dataloader, criterion=criterion,
                                                     is_trainable=False, is_test=True,
                                                     parameters=parameters['infer'], epoch_number=epoch_id,
                                                     best_metric_valid=best_metric, return_metric=True,
                                                     **test_kwargs,
                                                     )
            test_metric[epoch_id] = test_metric_current
            # Rewrite the full history each epoch; the context manager closes
            # (and flushes) the file.
            with Path('test_history.json').open('w') as history_file:
                json.dump(test_metric, history_file)


def run_test(model, parameters,
             criterion=torch.nn.CrossEntropyLoss(), test_dataloader=None,
             logging_file: str = 'log.txt',
             save_test_metrics_file: str = None,
             binary_classification: bool = False,
             new_intents=None,
             predict_action=False,
             predict_concept=False,
             metric_learning=False,
             ):
    """Evaluate ``model`` once on ``test_dataloader``.

    Args:
        model: model handed to the epoch function.
        parameters: mapping whose ``'infer'`` entry is forwarded to the pass.
        criterion: loss function; replaced by ``BCEWithLogitsLoss`` when
            ``binary_classification`` is true.
        test_dataloader: batches to evaluate.
        logging_file: unused.
        save_test_metrics_file: optional path; when given, the metrics are
            written there as JSON.
        binary_classification: use ``run_epoch_binary`` instead of
            ``run_epoch``.
        new_intents: forwarded to the pass (multi-class mode only).
        predict_action: forwarded to the pass (multi-class mode only).
        predict_concept: forwarded to the pass (multi-class mode only).
        metric_learning: forwarded to the pass (multi-class mode only).

    Returns:
        The metrics dict produced by the epoch function, so the result is
        available even when no output file is requested.
    """
    print('Start Testing')
    if binary_classification:
        epoch_function = run_epoch_binary
        criterion = torch.nn.BCEWithLogitsLoss()
        # run_epoch_binary accepts none of the multi-class-only arguments, so
        # passing them would raise TypeError.
        task_kwargs = {}
    else:
        epoch_function = run_epoch
        task_kwargs = {'new_intents': new_intents,
                       'predict_action': predict_action,
                       'predict_concept': predict_concept,
                       'metric_learning': metric_learning}

    with torch.no_grad():
        metric = epoch_function(model=model, dataloader=test_dataloader,
                                criterion=criterion,
                                is_trainable=False, is_test=True,
                                parameters=parameters['infer'], epoch_number=0,
                                return_metric=True, logging_steps=50,
                                **task_kwargs,
                                )
    if save_test_metrics_file is not None:
        # The context manager closes (and flushes) the file.
        with Path(save_test_metrics_file).open('w') as metrics_file:
            json.dump(metric, metrics_file)
    return metric
