import logging
import os
import time
import numpy as np
import torch
from sklearn.metrics import accuracy_score


def train(model, criterion, optimizer, scheduler, n_labels,
          train_loader, val_loader, device, writer, checkpoints_dir,
          accum_steps=1, print_every=1000, n_epoch=10,
          save_from_epoch=1, start_epoch=1, best_acc=0., model_type=None):
    """Run the train/validate loop and write checkpoints.

    For every epoch from ``start_epoch`` to ``n_epoch`` (inclusive) the model
    is trained on ``train_loader``, evaluated on ``val_loader`` via
    ``validate`` and, depending on the epoch number and validation accuracy,
    saved with ``torch.save`` into ``checkpoints_dir``.

    Args:
        model: Model passed to ``run_batch``; must provide ``train()``.
        criterion: Loss callable forwarded to ``run_batch``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        scheduler: LR scheduler; stepped once per optimizer step and queried
            with ``get_last_lr()`` for logging.
        n_labels: Number of labels, forwarded to ``validate``.
        train_loader: Iterable of batches; each batch must contain a
            ``'label'`` entry exposing ``.numpy()`` (used for train accuracy).
        val_loader: Loader forwarded to ``validate``.
        device: Device forwarded to ``run_batch``.
        writer: Object exposing ``add_scalar(tag, value, step)``
            (presumably a TensorBoard ``SummaryWriter`` -- verify at caller).
        checkpoints_dir: Directory for checkpoints; created if missing.
        accum_steps: Number of batches whose gradients are accumulated before
            each optimizer step. Must be >= 1.
        print_every: Log running statistics every ``print_every`` batches.
        n_epoch: Last epoch number (inclusive).
        save_from_epoch: First epoch for which a per-epoch checkpoint
            ``checkpoint_e{epoch}.pt`` is written.
        start_epoch: First epoch number.
        best_acc: Best validation accuracy seen so far; an epoch whose
            accuracy is >= this value overwrites ``checkpoint_best_acc.pt``.
        model_type: Model type string forwarded to ``run_batch``.

    Raises:
        ValueError: If ``accum_steps`` is smaller than 1.
    """
    if accum_steps < 1:
        raise ValueError('accum_steps must be >= 1')
    # Fail early (before a full epoch of training) if the directory is unusable.
    os.makedirs(checkpoints_dir, exist_ok=True)

    n_batches = len(train_loader)
    step = 0
    logging.info("Start training...")
    for epoch in range(start_epoch, n_epoch + 1):
        model.train()
        epoch_loss_train = 0
        start = time.time()
        # Gradients are cleared only after an optimizer step so that they
        # really accumulate over ``accum_steps`` batches.
        optimizer.zero_grad()
        for batch_id, batch in enumerate(train_loader):
            loss, logits = run_batch(model, criterion, batch, device, model_type)
            # Scale so the accumulated gradient is the mean over the window
            # (assumes the criterion already averages within a batch).
            (loss / accum_steps).backward()

            is_window_end = (batch_id + 1) % accum_steps == 0
            is_last_batch = batch_id + 1 == n_batches
            if is_window_end or is_last_batch:
                # Also step on the last batch so a trailing partial window
                # is not silently discarded.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                writer.add_scalar('lr', np.array(scheduler.get_last_lr()), step)

            loss_value = loss.item()
            epoch_loss_train += loss_value

            if batch_id % print_every == 0 and batch_id > 0:
                # Accuracy needs a device->host copy, so compute it only when logged.
                preds = np.argmax(logits.detach().cpu().numpy(), axis=1)
                labels = batch['label'].numpy()
                current_acc = accuracy_score(labels, preds)

                # batch_id is zero-based: batch_id + 1 batches were summed so far.
                logging.info('EPOCH {} BATCH {} of {}\t TRAIN_LOSS {:.3f}'.format(
                    epoch, batch_id, n_batches, epoch_loss_train / (batch_id + 1))
                )
                elapsed = time.time() - start
                logging.info(
                    f'EPOCH TIME: {elapsed // 60} min {round(elapsed % 60, 1)} sec')
                writer.add_scalar('train/loss', loss_value, step)
                writer.add_scalar('train/acc', current_acc, step)
            step += 1
        logging.info(
            'EPOCH {} TRAIN_LOSS {:.3f}'.format(epoch, epoch_loss_train / n_batches))
        elapsed = time.time() - start
        logging.info(f'EPOCH TIME: {elapsed // 60} min {round(elapsed % 60, 1)} sec')

        y_val, pred_val, _, loss_val = validate(model, criterion, val_loader, n_labels, device, model_type)
        epoch_accuracy = accuracy_score(y_val, pred_val)
        logging.info('-' * 100)
        logging.info('EVAL EPOCH {}\t EVAL_LOSS {:.3f}\tACCURACY {:.3f}'.format(epoch, loss_val / len(val_loader),
                                                                                epoch_accuracy))
        # Measured from the start of the epoch, i.e. training + evaluation time.
        elapsed = time.time() - start
        logging.info(f'EVAL EPOCH TIME: {elapsed // 60} min {round(elapsed % 60, 1)} sec')

        checkpoint = {'model': model, 'loss_val': loss_val, 'acc_val': epoch_accuracy, 'epoch': epoch}
        saved = False
        if epoch >= save_from_epoch:
            torch.save(checkpoint, os.path.join(checkpoints_dir, f'checkpoint_e{epoch}.pt'))
            saved = True
        if epoch_accuracy >= best_acc:
            # Track the new best so later, worse epochs do not overwrite it.
            best_acc = epoch_accuracy
            torch.save(checkpoint, os.path.join(checkpoints_dir, 'checkpoint_best_acc.pt'))
            saved = True
        if saved:
            logging.info('model saved')


def validate(model, criterion, loader, n_labels, device, model_type=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        model: Model passed to ``run_batch``; must provide ``eval()``.
        criterion: Loss callable forwarded to ``run_batch``.
        loader: Data loader exposing ``dataset`` (with ``len``); every batch
            must contain a ``'label'`` entry.
        n_labels: Width of the logits returned by ``run_batch``.
        device: Device forwarded to ``run_batch``.
        model_type: Model type string forwarded to ``run_batch``.

    Returns:
        Tuple ``(y_val, pred_val, logits_val, loss_val)``: true labels,
        argmax predictions and raw logits as float NumPy arrays covering the
        samples actually yielded by ``loader``, plus the *sum* (not the mean)
        of the per-batch losses.
    """
    model.eval()

    n_samples = len(loader.dataset)
    loss_val = 0
    pred_val = np.zeros(n_samples)
    y_val = np.zeros(n_samples)
    logits_val = np.zeros((n_samples, n_labels))

    # A running offset is used instead of ``i * loader.batch_size`` because
    # ``batch_size`` is None when the loader is built from a batch_sampler
    # and batches are not guaranteed to be equally sized.
    offset = 0
    with torch.no_grad():
        for batch in loader:
            loss, logits = run_batch(model, criterion, batch, device, model_type)
            loss_val += loss.item()

            logits_np = logits.detach().cpu().numpy()
            end = offset + logits_np.shape[0]
            logits_val[offset:end, :] = logits_np
            pred_val[offset:end] = np.argmax(logits_np, axis=1)
            y_val[offset:end] = batch['label']
            offset = end

    # Drop rows that were never filled (e.g. drop_last=True) so zero padding
    # is not mistaken for correctly predicted class-0 samples.
    return y_val[:offset], pred_val[:offset], logits_val[:offset], loss_val


def encode(model, loader, n_labels, device, features_dim):
    """Compute ``model.encode`` features for every sample in ``loader``.

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        model: Object providing ``eval()`` and
            ``encode(pair_ids, pair_attention_mask)``.
        loader: Data loader exposing ``dataset`` (with ``len``); every batch
            must contain ``'label'``, ``'pair_ids'`` and
            ``'pair_attention_mask'`` entries.
        n_labels: Size of the second axis of the feature array.
        device: Device the batch tensors are moved to.
        features_dim: Size of the last axis of the feature array.

    Returns:
        Tuple ``(y_val, features_val)``: labels with shape ``(n,)`` and
        features with shape ``(n, n_labels, features_dim)`` as float NumPy
        arrays, where ``n`` is the number of samples yielded by ``loader``.
    """
    model.eval()

    n_samples = len(loader.dataset)
    y_val = np.zeros(n_samples)
    features_val = np.zeros((n_samples, n_labels, features_dim))

    # A running offset is used instead of ``i * loader.batch_size`` because
    # ``batch_size`` is None when the loader is built from a batch_sampler
    # and batches are not guaranteed to be equally sized.
    offset = 0
    with torch.no_grad():
        for batch in loader:
            feats_batch = model.encode(batch['pair_ids'].to(device),
                                       batch['pair_attention_mask'].to(device))
            feats_np = feats_batch.detach().cpu().numpy()
            end = offset + feats_np.shape[0]
            features_val[offset:end, :] = feats_np
            y_val[offset:end] = batch['label']
            offset = end

    # Drop rows that were never filled (e.g. drop_last=True).
    return y_val[:offset], features_val[:offset]


# Legacy version of run_batch, kept for reference only: here the model itself
# returned (loss, logits). Superseded by the criterion-based run_batch below.
# def run_batch(model, batch, device, model_type):
#     if model_type == 'nli_ca':
#         loss, logits = model(batch['uttr'].to(device), batch['label_concept_enc'].to(device),
#                              batch['label_action_enc'].to(device))
#     elif model_type == 'nli_strict':
#         loss, logits = model(batch['uttr'].to(device), batch['label_enc'].to(device))
#     elif model_type == 'nli_contrastive':
#         loss, logits = model(batch['pairs'].to(device), batch['labels'].to(device))
#     else:
#         raise ValueError('Unknown model type')
#     return loss, logits

def run_batch(model, criterion, batch, device, model_type):
    """Run one forward pass and compute the loss for ``batch``.

    Args:
        model: Callable model. Called as ``model(pair_ids, attention_mask)``
            for ``'nli_ca'`` / ``'nli_strict'`` and as ``model(pair_ids)``
            for ``'nli_contrastive'``.
        criterion: Loss callable. Called as
            ``criterion(logits, label_action_enc, label_concept_enc)`` for
            ``'nli_ca'`` and as ``criterion(logits, targets)`` otherwise.
        batch: Mapping of tensors; the keys read depend on ``model_type``
            (``'pair_ids'`` always, plus ``'pair_attention_mask'`` and the
            label keys of the selected branch).
        device: Device the batch tensors are moved to.
        model_type: One of ``'nli_ca'``, ``'nli_strict'``,
            ``'nli_contrastive'``.

    Returns:
        Tuple ``(loss, logits)``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    # Validate first so an unknown type fails before any batch key is read.
    if model_type not in ('nli_ca', 'nli_strict', 'nli_contrastive'):
        raise ValueError('Unknown model type')

    pair_ids = batch['pair_ids'].to(device)

    if model_type == 'nli_contrastive':
        logits = model(pair_ids)
        # The loss must be a function of the model output; the criterion was
        # previously called with the labels only, which left it disconnected
        # from the model.
        # NOTE(review): assumes the contrastive criterion takes
        # (logits, labels) like the other branches -- confirm.
        loss = criterion(logits, batch['labels'].to(device))
        return loss, logits

    logits = model(pair_ids, batch['pair_attention_mask'].to(device))
    if model_type == 'nli_ca':
        loss = criterion(logits,
                         batch['label_action_enc'].to(device),
                         batch['label_concept_enc'].to(device))
    else:  # 'nli_strict'
        loss = criterion(logits, batch['label_enc'].to(device))
    return loss, logits
