from torch import nn
import torch
from random import shuffle
from torch.nn.modules.loss import BCELoss, BCEWithLogitsLoss, MSELoss
import numpy as np
import os
import random
import sys
from numpy import indices

sys.path.append('../')

from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.masking import LengthMask
from autoencoders.torch_utils import PositionalEncoder
from emb2emb.utils import Namespace
from torch_utils import add_vectors, remove_vectors
from emb2emb.utils import write_to_csv


class BinaryClassifier(nn.Module):
    """Two-layer MLP that maps a fixed-size embedding to a single logit.

    The network is ``Dropout -> Linear -> ReLU -> Dropout -> Linear`` with a
    scalar output.  While the module is in training mode, zero-mean Gaussian
    noise with standard deviation ``gaussian_noise_std`` is added to the
    inputs before they are fed to the network.
    """

    def __init__(self, input_size, hidden_size, dropout=0., gaussian_noise_std=0.):
        super().__init__()

        # Layers are listed in application order; the positions inside the
        # Sequential determine the parameter names in the state dict.
        stack = [
            nn.Dropout(dropout),
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 1),
        ]
        self.classifier = nn.Sequential(*stack)
        self.gaussian_noise_std = gaussian_noise_std

    def forward(self, inputs):
        """Return the raw (un-squashed) output of the MLP for ``inputs``."""
        add_noise = self.training and self.gaussian_noise_std > 0.
        if add_noise:
            # Noise is only injected during training, never at eval time.
            noise = torch.randn_like(inputs)
            inputs = inputs + noise * self.gaussian_noise_std

        return self.classifier(inputs)


def freeze(m):
    """Turn off gradient tracking for every parameter yielded by ``m.parameters()``."""
    for param in m.parameters():
        param.requires_grad = False


def _get_bovclassifier(params, dropout, gaussian):
    """Build a ``BoVBinaryClassifier`` from the experiment ``params``.

    A fresh ``Namespace`` is populated with the fields the classifier
    constructor reads: some are copied from ``params`` under a different
    name, some under the same name, and ``dropout`` / ``gaussian`` come
    from the explicit arguments.
    """
    # config field name -> name of the attribute on ``params`` it is read from
    renamed_fields = {
        "n_layers": "n_layers_binary",
        "heads": "n_heads_binary",
        "hidden_size": "binary_dense_layer_size",
        "input_dim": "hidden_size_binary",
    }
    # fields that keep their name when copied over
    shared_fields = (
        "embedding_dim",
        "learned_positional_embeddings",
        "vector_distortion_rate",
        "vector_distortion_probability",
    )

    config = Namespace()
    for target, source in renamed_fields.items():
        setattr(config, target, getattr(params, source))
    for field in shared_fields:
        setattr(config, field, getattr(params, field))
    config.dropout = dropout
    config.gaussian_noise = gaussian

    return BoVBinaryClassifier(config)


def train_binary_classifier(true_inputs, false_inputs, encoder, params, num_val_samples=1000, regress=False):
    """Train (or load) a classifier / regressor on top of a frozen encoder.

    If ``params.load_binary_clf`` is set, a classifier is constructed without
    dropout or noise, its weights are loaded from
    ``<params.outputdir>/<params.outputmodelname>_binary_clf`` and it is
    returned immediately.

    Otherwise ``true_inputs + false_inputs`` are split at random into a
    training part and a validation part (the last ``num_val_samples`` of a
    shuffled index list), and the classifier is trained for
    ``params.n_epochs_binary`` epochs with Adam.  After every epoch the
    scores are appended via ``write_to_csv`` and, whenever the validation
    score improves on the best one seen so far (starting from the score of
    the untrained model), the weights are saved to the checkpoint path.

    Args:
        true_inputs: list of positive examples (label 1).
        false_inputs: list of negative examples (label 0).
        encoder: module exposing ``.device``; it is frozen, switched to eval
            mode and called on slices of the input array.
        params: experiment settings; see the attributes read below.
        num_val_samples: number of examples held out for validation.
        regress: if True, the target is the standardized ``len()`` of each
            input and MSE is optimized instead of binary cross-entropy.

    Returns:
        The classifier.  In training mode this is the model as it is after
        the final epoch; the best-scoring weights are only written to disk.
    """
    bov_modes = ("bovtobov", "bovidentity", "bovoracle")
    outputmodelname = params.outputmodelname + "_binary_clf"
    checkpoint_path = os.path.join(params.outputdir, outputmodelname)

    if params.load_binary_clf:
        # NOTE(review): "simplebov" selects the BoV classifier here but NOT in
        # the training branch below, so a checkpoint trained with "simplebov"
        # would not match the architecture built here -- confirm which of the
        # two is intended.
        if params.emb2emb in bov_modes or params.emb2emb == "simplebov":
            binary_classifier = _get_bovclassifier(
                params, 0., 0.).to(encoder.device)
        else:
            binary_classifier = BinaryClassifier(
                params.embedding_dim, 512, 0., 0.).to(encoder.device)

        checkpoint = torch.load(checkpoint_path, map_location=params.device)
        binary_classifier.load_state_dict(checkpoint["model_state_dict"])
        return binary_classifier

    inputs = true_inputs + false_inputs

    if regress:
        t = [len(s) for s in inputs]
    else:
        t = ([1] * len(true_inputs)) + ([0] * len(false_inputs))

    inputs, t = np.array(inputs), np.array(t)

    # normalize targets
    if regress:
        t = (t - t.mean()) / t.std()

    # Hold out the tail of a random permutation as the validation set.
    order = list(range(len(inputs)))
    shuffle(order)
    val_inputs = inputs[order[-num_val_samples:]]
    val_targets = t[order[-num_val_samples:]]
    inputs = inputs[order[:-num_val_samples]]
    t = t[order[:-num_val_samples]]
    order = list(range(len(inputs)))

    if params.emb2emb in bov_modes:
        binary_classifier = _get_bovclassifier(
            params, params.dropout_binary, params.gaussian_noise_binary).to(encoder.device)
    else:
        binary_classifier = BinaryClassifier(params.embedding_dim,
                                             512,
                                             params.dropout_binary,
                                             params.gaussian_noise_binary).to(encoder.device)

    opt = torch.optim.Adam(binary_classifier.parameters(), lr=params.lr_bclf)
    freeze(encoder)
    encoder.eval()
    loss_f = MSELoss() if regress else BCEWithLogitsLoss()

    def save_clf():
        checkpoint = {"model_state_dict": binary_classifier.state_dict()}
        torch.save(checkpoint, checkpoint_path)

    # Score of the untrained model; an epoch must beat it to be saved.
    best_acc = evaluate(val_inputs, val_targets, encoder,
                        binary_classifier, params, regress)
    bsize = params.batch_size
    for e in range(params.n_epochs_binary):

        # shuffle data in each epoch
        shuffle(order)
        inputs = inputs[order]
        t = t[order]

        binary_classifier.train()
        losses = []
        total_len = 0
        correct = 0.
        seen = 0  # number of training samples processed so far this epoch
        for idx in range(0, len(inputs), bsize):
            ib = inputs[idx: idx + bsize]
            tb = t[idx: idx + bsize]
            seen += len(ib)

            tb = torch.tensor(tb, device=encoder.device).view(-1, 1).float()
            with torch.no_grad():
                embeddings = encoder(ib)

            # Only encoders that return an (embeddings, lengths) pair carry
            # length information; a plain tensor must not be unpacked here.
            if isinstance(embeddings, (tuple, list)):
                total_len = total_len + embeddings[1].sum().item()

            preds = binary_clf_predict(binary_classifier, embeddings, params)
            loss = loss_f(preds, tb)

            if not regress:
                # NOTE(review): preds are raw logits (the loss is
                # BCEWithLogitsLoss), so 0.5 is a threshold on the logit, not
                # on the probability.  Kept as-is to match evaluate().
                correct += ((preds > 0.5) == tb).sum().item()

            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())

            if (idx / bsize) % params.log_freq == 0:
                avg_loss = np.array(losses[-params.log_freq:]).mean()
                # Divide by the samples actually seen, so that a short final
                # batch does not deflate the accuracy.
                print_acc = correct / float(seen) if not regress else correct
                print("Binary classification step {}<->{}: loss {} ; t-acc: {}, v-acc: {}".format(e,
                                                                                                  idx,
                                                                                                  avg_loss,
                                                                                                  print_acc,
                                                                                                  best_acc))

        val_acc = evaluate(val_inputs, val_targets, encoder,
                           binary_classifier, params, regress)
        if regress:
            train_acc = np.array(losses[-params.log_freq:]).mean()
        else:
            train_acc = correct / float(seen) if seen else 0.
        scores = {"binary_clf_epoch": e,
                  "training_acc": train_acc,
                  "validation_acc": val_acc}
        write_to_csv(scores,
                     params, output_file_path=params.output_file + "_binaryclf")

        # Higher is better for accuracy, lower is better for the MSE.
        if (not regress and val_acc > best_acc) or (regress and val_acc < best_acc):
            best_acc = val_acc
            save_clf()
        print("Loss in epoch {}: {}".format(e, np.array(losses).mean()))
        print("Print total sequence length:", total_len)

    return binary_classifier


class BoVBinaryClassifier(nn.Module):
    """Transformer classifier over a variable-length bag of vectors.

    Each input vector is linearly projected, optionally given a positional
    encoding, passed through a transformer encoder, and the outputs at the
    valid positions are averaged.  A small MLP maps that average to two
    output scores.
    """

    def __init__(self, config):
        super().__init__()

        # Settings read from the config object.
        self.layers = config.n_layers
        self.heads = config.heads
        self.hidden_size = config.hidden_size
        self.input_size = config.input_dim
        self.ff_dimension = config.input_dim
        self.gaussian_noise = config.gaussian_noise
        self.vector_distortion_rate = config.vector_distortion_rate
        self.vector_distortion_probability = config.vector_distortion_probability

        # Hard-wired architectural choices.
        self.reduction = "mean"
        self.att_type = "full"
        self.positional_embeddings = True

        # Sub-modules, created in the order: projection, classifier head,
        # transformer, positional encoder.
        self.input_projection = nn.Linear(
            config.embedding_dim, config.input_dim)
        self.classifier = nn.Sequential(
            nn.Linear(self.input_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, 2),
        )

        # Query and value vectors share the same per-head width.
        per_head_dim = int(self.input_size / self.heads)
        self.transformer = TransformerEncoderBuilder.from_kwargs(
            attention_type=self.att_type,
            n_layers=self.layers,
            n_heads=self.heads,
            query_dimensions=per_head_dim,
            value_dimensions=per_head_dim,
            feed_forward_dimensions=self.ff_dimension,
            dropout=config.dropout
        ).get()

        if self.positional_embeddings:
            self.pe = PositionalEncoder(
                self.input_size, mul_by_sqrt=False, learned_embeddings=config.learned_positional_embeddings)

    def forward(self, X, X_lens):
        """Score a batch of vector bags.

        Args:
            X: batch of vector sequences; dimension 0 is treated as the
                batch, dimension 1 as the sequence position and the last
                dimension is fed to the input projection.
            X_lens: number of valid positions of every batch element.

        Returns:
            The classifier output with an extra axis inserted at position 1,
            i.e. the two scores end up in the last dimension.
        """
        # With the configured probability, distort the bag during training
        # by either adding or removing vectors (picked with equal chance).
        if (self.training
                and self.vector_distortion_probability > 0
                and random.random() < self.vector_distortion_probability):
            distort = add_vectors if random.random() < 0.5 else remove_vectors

            rate_is_zero = self.vector_distortion_rate == 0.
            if rate_is_zero:
                X_before, lens_before = X, X_lens

            X, X_lens = distort(self.vector_distortion_rate, (X, X_lens))

            # Sanity check: with a distortion rate of zero the distortion
            # functions must leave their input untouched.
            if rate_is_zero:
                assert torch.equal(X_before, X)
                assert torch.equal(lens_before, X_lens)

        if self.training and self.gaussian_noise > 0.:
            X = X + torch.randn_like(X) * self.gaussian_noise

        X = self.input_projection(X)
        if self.positional_embeddings:
            X = self.pe(X)

        length_mask = LengthMask(X_lens, max_len=X.size(1),
                                 device=X.device)
        outs = self.transformer(X, length_mask=length_mask)
        X_lens = X_lens.to(outs.device)

        # Zero out the positions beyond each element's length ...
        valid = self._make_mask(outs.size(0), outs.size(1), X_lens)
        outs = outs * valid.unsqueeze(-1)

        # ... so that sum / length is the mean over the valid positions.
        pooled = outs.sum(dim=1) / X_lens.unsqueeze(1)

        scores = self.classifier(pooled)
        return scores.unsqueeze(1)

    def _make_mask(self, bsize, max_lens, lens):
        """Return a ``(bsize, max_lens)`` boolean mask that is True at the positions below ``lens``."""
        positions = torch.arange(max_lens, device=lens.device)
        positions = positions.unsqueeze(0).expand(bsize, -1)
        return positions < lens.unsqueeze(1)


def evaluate(val_inputs, val_targets, encoder, binary_classifier, params, regress):
    """Score ``binary_classifier`` on a held-out set.

    The classifier is switched to eval mode (and left in it) and the data is
    processed in batches of ``params.batch_size``.

    Args:
        val_inputs: sliceable collection of inputs understood by ``encoder``.
        val_targets: sliceable collection of targets, aligned with
            ``val_inputs``.
        encoder: module exposing ``.device``, called on each input slice.
        binary_classifier: the model being evaluated.
        params: settings object; ``batch_size`` is read here and the object
            is forwarded to ``binary_clf_predict``.
        regress: if True return the mean squared error, otherwise the
            fraction of predictions that agree with the target.

    Returns:
        A Python float: accuracy (higher is better) or, when ``regress`` is
        set, the per-sample squared error (lower is better).

    Raises:
        ZeroDivisionError: if ``val_inputs`` is empty.
    """
    bsize = params.batch_size
    squared_error = MSELoss(reduction='sum')

    total = 0.
    binary_classifier.eval()

    # Nothing is optimized here, so run the entire loop (classifier forward
    # included) without building autograd graphs, and accumulate plain
    # Python numbers so no graph or device tensor is kept alive.
    with torch.no_grad():
        for idx in range(0, len(val_inputs), bsize):
            ib = val_inputs[idx: idx + bsize]
            tb = val_targets[idx: idx + bsize]
            tb = torch.tensor(tb, device=encoder.device).view(-1, 1).float()

            embeddings = encoder(ib)
            preds = binary_clf_predict(binary_classifier, embeddings, params)

            if regress:
                total += squared_error(preds, tb).item()
            else:
                # NOTE(review): preds are raw classifier outputs; 0.5 is the
                # same threshold the training loop uses for its accuracy.
                total += ((preds > 0.5) == tb).sum().item()

    return float(total) / len(val_inputs)


def binary_clf_predict(binary_classifier, embeddings, params=None):
    """Run ``binary_classifier`` on encoder output and return its predictions.

    For any classifier other than ``BoVBinaryClassifier`` the embeddings are
    passed through unchanged.  For a ``BoVBinaryClassifier`` the embeddings
    are expected to be a pair ``(vectors, lengths)``; the value at index 1 of
    the classifier's last output dimension is returned.  If ``params.gates``
    is truthy, the second-to-last feature of every vector is dropped before
    classification.
    """
    if not isinstance(binary_classifier, BoVBinaryClassifier):
        return binary_classifier(embeddings)

    vectors = embeddings[0]
    if params and params.gates:
        # Keep everything except the last two features, then re-append the
        # very last one.
        leading = vectors[..., :-2].contiguous()
        last = vectors[..., -1].contiguous().unsqueeze(-1)
        vectors = torch.cat([leading, last], dim=-1)

    scores = binary_classifier(vectors, embeddings[1])
    return scores[..., 1]
