import json
import os
import random
import logging
import sys

import numpy as np
import torch
from torch.nn import CosineSimilarity
from torch import nn

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


# Path prefix constant; empty string by default.
PREFIX_DIR = ""
# Special token string "[UTTR]" — presumably the separator placed between
# dialogue utterances; TODO confirm against the tokenizer setup.
UTTR_TOKEN = "[UTTR]"
# Short language-model keys -> pretrained checkpoint names (the values look
# like Hugging Face hub model ids).
LMDICT = dict(
    bert="bert-base-uncased",
    roberta="roberta-base",
    electra="google/electra-base-discriminator",
)


def set_logger(name: str) -> logging.Logger:
    """
    set and return the logger object

    Sets the level to INFO and attaches a stdout handler that prefixes each
    message with a timestamp. ``logging.getLogger`` returns the same object
    for the same name, so the handler is attached only once: calling this
    again with the same name does not duplicate every log line.

    :param name: logger name
    :return: logger object
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Tag our handler by name so repeated calls can detect it and skip
    # adding a second (duplicate-output) handler.
    handler_name = f"set_logger.stdout.{name}"
    if not any(h.get_name() == handler_name for h in logger.handlers):
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.set_name(handler_name)
        stream_handler.setFormatter(logging.Formatter("[%(asctime)s] %(message)s"))
        logger.addHandler(stream_handler)

    return logger


def set_random_seed(seed):
    """
    seed the python, numpy and torch (CPU and CUDA) RNGs with one value and
    put cuDNN into deterministic mode

    :param seed: seed value
    """
    # NOTE: PYTHONHASHSEED only affects interpreters started after this point
    # (e.g. worker subprocesses), not hash randomization of the running one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    # cuDNN is switched off entirely, trading speed for reproducibility.
    cudnn.enabled = False


def dump_config(args):
    """
    save model parameters to ``<args.exp_path>/config.json``

    :param args: model parameters (arguments); must have an ``exp_path``
        attribute, and ``vars(args)`` must be JSON-serializable
    """
    config_path = os.path.join(args.exp_path, "config.json")
    # Explicit encoding so the file does not depend on the platform locale;
    # indent keeps the saved config readable by humans (json.load is unaffected).
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=2)


def make_soft_labels(device, batch_size, num_of_hard_negs, smoothing_value) -> torch.Tensor:
    """
    make soft labels with smoothing value

    Every row is ``[1 - s] + [s / H] * H + [0] * H`` with ``s`` the smoothing
    value and ``H`` the number of hard negatives. The smoothing mass is spread
    evenly over the H hard-negative slots, and the trailing H slots
    (presumably random negatives) get 0, so each row sums to 1 for H > 0.

    :param device: device
    :param batch_size: current batch size
    :param num_of_hard_negs: number of hard negatives in the batch size
    :param smoothing_value: smoothing value for label smoothing
    :return: labels with smoothing value, shape ``(batch_size, 1 + 2 * num_of_hard_negs)``
    """
    # Divide by the actual hard-negative count (not a fixed 5) so rows always
    # sum to 1; guard H == 0 to avoid a ZeroDivisionError.
    per_hard_neg = smoothing_value / num_of_hard_negs if num_of_hard_negs else 0.0
    row = [1 - smoothing_value] + [per_hard_neg] * num_of_hard_negs + [0.0] * num_of_hard_negs
    # Explicit float dtype: an all-integer row would otherwise yield int64 labels.
    return (
        torch.tensor([row], dtype=torch.get_default_dtype())
        .repeat(batch_size, 1)
        .to(device)
    )


def get_marginal_loss(output, margin):
    """
    triplet-margin loss (Euclidean distance, p=2) with candidate 0 as the
    anchor, the mean of candidates 1-5 as the positive and the mean of the
    remaining candidates as the negative

    :param output: tensor whose dim 1 indexes candidates — presumably
        (batch, candidates, dim); TODO confirm with the caller
    :param margin: margin of the triplet loss
    :return: scalar loss (mean reduction)
    """
    anchor = output[:, 0]
    positive = output[:, 1:6].mean(dim=1)
    negative = output[:, 6:].mean(dim=1)
    return nn.TripletMarginLoss(margin=margin, p=2)(anchor, positive, negative)


def get_ntxent_loss(output, temp: float = 0.1, device=None):
    """
    temperature-scaled cross-entropy over cosine similarities between
    candidate 0 and every other candidate

    Candidate 0 is compared (cosine similarity along dim 2) with candidates
    1-5 ("hard") and with all remaining candidates ("rand"). The target
    distribution puts uniform mass on the hard columns and zero on the rand
    columns.

    :param output: 3-D tensor (batch, candidates, dim) — dim 2 is the feature axis
    :param temp: softmax temperature
    :param device: kept for backward compatibility; the target is always
        created on the device of ``output``
    :return: scalar loss averaged over the batch
    """
    anchor = output[:, 0].unsqueeze(1)  # (batch, 1, dim), broadcasts over candidates
    hard, rand = output[:, 1:6], output[:, 6:]
    cos_sim_func = CosineSimilarity(dim=2)
    pos_hard_sim = cos_sim_func(anchor, hard) / temp
    pos_rand_sim = cos_sim_func(anchor, rand) / temp
    total_sim = torch.cat([pos_hard_sim, pos_rand_sim], dim=1)
    scores = torch.nn.LogSoftmax(dim=1)(total_sim)

    # Build the target from the actual shapes (not a fixed 5 + 5) so any
    # number of trailing candidates works; zeros_like matches device/dtype.
    true_dist = torch.zeros_like(scores)
    num_hard = pos_hard_sim.shape[1]
    if num_hard:
        true_dist[:, :num_hard] = 1.0 / num_hard
    return torch.mean(torch.sum(-true_dist * scores, dim=-1))


def get_linear_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
) -> LambdaLR:
    """Build a LambdaLR whose multiplier rises linearly from 0 to 1 over the
    warmup steps, then falls linearly to 0 at ``num_training_steps`` and stays
    at 0 afterwards.

    :param optimizer: The optimizer for which to schedule the learning rate.
    :param num_warmup_steps: The number of steps for the warmup phase.
    :param num_training_steps: The total number of training steps.
    :param last_epoch: The index of the last epoch when resuming training, defaults to -1
    :return: LambdaLR with the appropriate schedule.
    """
    # max(1, ...) keeps both denominators non-zero for degenerate step counts.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step: int) -> float:
        if current_step >= num_warmup_steps:
            remaining = float(num_training_steps - current_step) / decay_span
            return max(0.0, remaining)
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def write2tensorboard(writer, score_dict, setname, global_step):
    """
    log every score of ``score_dict`` under its own main tag, keyed by
    ``setname`` (so train/valid/test curves share one chart per metric),
    then flush the writer

    :param writer: writer (anything with ``add_scalars`` and ``flush``)
    :param score_dict: score dictionary, metric name -> value
    :param setname: setname (train, valid, test)
    :param global_step: current global step
    """
    scalar_groups = ((tag, {setname: value}) for tag, value in score_dict.items())
    for main_tag, tag_scalar_dict in scalar_groups:
        writer.add_scalars(main_tag, tag_scalar_dict, global_step)
    writer.flush()


def save_model(model, epoch, model_path, is_distributed):
    """
    save the model's state dict to ``<model_path>/epoch-<epoch>.pth``

    In the DistributedDataParallel (DDP) setting the wrapped network is
    reached through ``model.module``, and its own state dict is saved
    instead of the wrapper's.

    :param model: model for saving
    :param epoch: current epoch (used in the file name)
    :param model_path: directory to save into (must already exist)
    :param is_distributed: if true, distributed training setup
    """
    network = model.module if is_distributed else model
    checkpoint_file = os.path.join(model_path, f"epoch-{epoch}.pth")
    torch.save(network.state_dict(), checkpoint_file)


def load_model(model, model_path, epoch, len_tokenizer):
    """
    load the weights stored at ``<model_path>/epoch-<epoch>.pth`` into ``model``

    When ``model_path`` contains the substring ``"select"``, the token
    embeddings of ``model.bert`` are resized to ``len_tokenizer`` first so the
    embedding shape matches the checkpoint (presumably one trained with extra
    special tokens added to the tokenizer — TODO confirm).

    :param model: model to load the weights into; its state-dict keys must
        match the checkpoint
    :param model_path: directory holding the ``epoch-<epoch>.pth`` files
    :param epoch: epoch of the checkpoint to load
    :param len_tokenizer: vocabulary size for ``model.bert``'s embeddings
        (only used when ``"select"`` is in ``model_path``)
    :return: the same model object, with the weights loaded
    """
    if "select" in model_path:
        model.bert.resize_token_embeddings(len_tokenizer)
    checkpoint_file = os.path.join(model_path, f"epoch-{epoch}.pth")
    # Map to CPU: load_state_dict copies the values into the model's existing
    # parameters (keeping their device), so GPU-saved checkpoints also load
    # on CPU-only hosts.
    state_dict = torch.load(checkpoint_file, map_location="cpu")
    model.load_state_dict(state_dict)

    return model


def recall_x_at_k(score_list, x, k, answer_index):
    """
    R_x@k for one example: 1 if the candidate at ``answer_index`` is among
    the ``k`` highest-scoring of the ``x`` candidates, else 0

    :param score_list: one score per candidate (length must equal ``x``)
    :param x: number of candidates
    :param k: cutoff rank
    :param answer_index: index of the correct candidate in ``score_list``
    :return: 1 or 0
    """
    assert len(score_list) == x
    # candidate indices ordered from highest to lowest score
    ranking = np.argsort(np.array(score_list))[::-1]
    assert answer_index in ranking
    top_k = ranking[:k]
    return 1 if answer_index in top_k else 0


def mrr(score_list, answer_index=0):
    """
    reciprocal rank of the candidate at ``answer_index`` when candidates are
    ordered by descending score (average over examples to get MRR)

    :param score_list: one score per candidate
    :param answer_index: index of the correct candidate, defaults to 0
    :return: 1 / rank, where the top-scored candidate has rank 1
    """
    ranking = list(np.array(score_list).argsort()[::-1])
    assert answer_index in ranking
    return 1 / (ranking.index(answer_index) + 1)