import torch
import torch.nn as nn
import random
import numpy as np
import itertools


def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators from ``args.seed``.

    When CUDA is available, every CUDA device generator is seeded as well.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# MINE : https://github.com/sungyubkim/MINE-Mutual-Information-Neural-Estimation-/blob/master/GAN_MINE.ipynb
# Paper https://arxiv.org/pdf/1806.04498.pdf
def update_target(ma_net, net, update_rate=1e-1):
    """Blend the parameters of ``net`` into ``ma_net`` in place.

    Each parameter of ``ma_net`` becomes
    ``(1 - update_rate) * ma_param + update_rate * net_param``, pairing the
    parameters of both networks in iteration order. ``net`` is left untouched.
    """
    keep_rate = 1.0 - update_rate
    for averaged, current in zip(ma_net.parameters(), net.parameters()):
        blended = keep_rate * averaged.data + update_rate * current.data
        averaged.data.copy_(blended)


# control which parameters are frozen / free for optimization
def free_params(module: nn.Module):
    """Mark every parameter of ``module`` as trainable (``requires_grad=True``)."""
    for param in module.parameters():
        param.requires_grad_(True)


def comput_gradient_norm(model):
    """Return the global L2 norm of the gradients of ``model`` as a 0-dim tensor.

    Parameters whose ``.grad`` is ``None`` are skipped; if no parameter has a
    gradient the result is ``tensor(0.)``.
    """
    squared_norms = (
        param.grad.data.norm(2).item() ** 2
        for param in model.parameters()
        if param.grad is not None
    )
    # sum() starts from the int 0 and adds in iteration order.
    sum_of_squares = sum(squared_norms)
    return torch.tensor(sum_of_squares ** (1. / 2))


def frozen_params(module: nn.Module):
    """Freeze every parameter of ``module`` (``requires_grad=False``)."""
    for param in module.parameters():
        param.requires_grad_(False)


def corrupt_input(model, input_tensor):
    """Return a noisy copy of a batch of token ids.

    Two corruptions are applied, both controlled by ``model.args.noise_p``:

    1. Random replacement: every token that is not the tokenizer's pad, sep
       or cls token is replaced, independently with probability ``noise_p``,
       by an id drawn uniformly from ``range(len(model.args.tokenizer))``.
    2. Random swap: with probability ``noise_p`` (a single draw for the whole
       batch), two distinct interior positions of every row are swapped. The
       first and last non-pad positions are never moved.

    Args:
        model: object exposing ``args.noise_p`` and ``args.tokenizer``; the
            tokenizer must provide ``pad_token_id``, ``sep_token_id``,
            ``cls_token_id`` and ``__len__``.
        input_tensor: 2-D integer tensor of token ids indexed as
            ``[row, position]``.

    Returns:
        A new tensor with the same shape, dtype and device as
        ``input_tensor``; the input itself is not modified.
    """
    tokenizer = model.args.tokenizer
    noise_p = model.args.noise_p
    device = input_tensor.device

    # Special tokens must survive the random replacement untouched. Ids that
    # the tokenizer does not define (None) are simply ignored.
    protected = torch.zeros_like(input_tensor, dtype=torch.bool)
    for token_id in (tokenizer.pad_token_id,
                     tokenizer.sep_token_id,
                     tokenizer.cls_token_id):
        if token_id is not None:
            protected |= input_tensor == token_id

    ### RANDOMLY CHANGE SOME TOKENS ###
    # A bool mask is used because uint8 mask indexing is deprecated in PyTorch.
    replace_mask = (torch.rand(input_tensor.shape, device=device) < noise_p) & ~protected
    random_tokens = torch.randint(len(tokenizer), input_tensor.shape,
                                  dtype=input_tensor.dtype, device=device)
    corrupted = torch.where(replace_mask, random_tokens, input_tensor)

    ### RANDOMLY SWAP ORDER OF TOKENS ###
    if random.random() < noise_p:
        # Sequence length is taken as the number of non-pad tokens of the
        # *original* row; this assumes padding sits at the end of each row.
        lengths = (input_tensor != tokenizer.pad_token_id).sum(dim=1).tolist()
        # Iterate over the real batch dimension so that a smaller final batch
        # does not raise.
        for row, length in enumerate(lengths):
            if length < 4:
                # Fewer than two interior tokens: nothing to swap.
                continue
            # Pick two distinct interior positions without enumerating every
            # possible pair.
            first, second = random.sample(range(1, length - 1), 2)
            corrupted[row, [first, second]] = corrupted[row, [second, first]]

    return corrupted
