import threading
from multiprocessing.pool import ThreadPool

from typing import List

import torch
from torch.nn import Module, LayerNorm

from coli.torch_extra.layers import ContextualUnits
from coli.torch_extra.sentence import SentenceEmbeddings
from coli.torch_extra.utils import pad_and_stack_1d, broadcast_gather
from coli.torch_span.data_loader import SentenceFeatures, span_starts, span_ends
from coli.basic_tools.common_utils import NoPickle, AttrDict
from coli.torch_span.data_loader import span_count_on_length_k, span_ids
from coli.span.const_tree import ConstTreeStatistics
from coli.torch_span.config import SpanParserOptions

import pyximport

pyximport.install(build_dir="/tmp/cython_cache/")
from coli.span.cdecoder import decoder_types


class AttentiveSpanNetwork(Module):
    """Span-scoring constituency parser network.

    Pipeline: ``SentenceEmbeddings`` -> a ``ContextualUnits`` encoder ->
    one feature vector per span (see :meth:`get_span_features`) -> two
    scorers: ``f_span`` (one unlabeled score per span) and ``f_label`` (one
    score per span and label). Trees are decoded sentence by sentence by a
    decoder from ``coli.span.cdecoder.decoder_types``, run on a thread pool.

    In training mode :meth:`forward` returns a margin-based structured loss
    (score of the decoded tree minus score of the gold tree, both under
    loss-augmented scores). In evaluation mode it returns the decoded trees.
    """

    def __init__(self,
                 args,
                 hparams: SpanParserOptions,
                 statistics: ConstTreeStatistics,
                 plugins
                 ):
        """
        :param args: runtime options. This class reads ``concurrent_count``,
            ``use_rules`` and ``restrict_root_rule``.
        :param hparams: model hyper-parameters (decoder type, embedding and
            contextual-unit options, hidden sizes, span feature type, loss
            margin / hamming factor, ``train_with_rules``).
        :param statistics: treebank statistics (labels, rules, maximum
            sentence length, ...) shared with the decoder.
        :param plugins: forwarded unchanged to ``SentenceEmbeddings``.
        """
        super().__init__()
        self.args = args
        self.hparams = hparams
        self.statistics = statistics
        self.decode_type = decoder_types[self.hparams.decoder]
        self.embedding = SentenceEmbeddings(hparams.sentence_embedding,
                                            statistics, plugins)

        self.contextual_unit = ContextualUnits.get(hparams.contextual,
                                                   input_size=self.embedding.output_dim)

        self.d_model = self.contextual_unit.output_dim

        # Unlabeled span scorer: span feature vector -> a single score.
        self.f_span = torch.nn.Sequential(
            torch.nn.Linear(self.d_model, hparams.d_span_hidden),
            LayerNorm(hparams.d_span_hidden),
            torch.nn.LeakyReLU(0.1),
            torch.nn.Linear(hparams.d_span_hidden, 1, bias=False)
        )

        # Label scorer: span feature vector -> one score per label.
        self.f_label = torch.nn.Sequential(
            torch.nn.Linear(self.d_model, hparams.d_label_hidden),
            LayerNorm(hparams.d_label_hidden),
            torch.nn.LeakyReLU(0.1),
            torch.nn.Linear(hparams.d_label_hidden, len(statistics.labels)),
        )

        self.create_pool()

    def create_pool(self):
        """Create the thread pool used for decoding.

        Each of the ``args.concurrent_count`` workers runs
        :meth:`create_decoder` once at start-up, so every worker thread owns
        its own decoder stored in ``self.thread_local``. Both objects are
        wrapped in ``NoPickle`` (presumably so the module stays picklable
        without them -- confirm against ``NoPickle``).
        """
        self.thread_local = NoPickle(threading.local())
        self.pool = NoPickle(ThreadPool(self.args.concurrent_count, initializer=self.create_decoder))

    def create_decoder(self):
        """Build the calling thread's decoder (ThreadPool initializer).

        The decoder is pre-allocated for ``max_sentence_length + 10``
        positions (the extra 10 is presumably head-room -- confirm) and for
        the full label set.
        """
        # pre allocate memory for decoders
        self.thread_local.decoder = self.decode_type(
            self.statistics.max_sentence_length + 10,
            len(self.statistics.labels),
            use_rules=self.args.use_rules)

    def decode_func(self, span_scores, label_scores, leaftag_scores,
                    use_rules, return_item):
        """Decode one sentence with the current worker thread's decoder.

        :param span_scores: unlabeled span chart of one sentence.
        :param label_scores: labeled span chart of one sentence.
        :param leaftag_scores: optional leaf-tag scores, or ``None``.
        :param use_rules: whether the decoder should apply ``statistics.rules``.
        :param return_item: forwarded to the decoder. :meth:`forward` passes
            ``True`` in evaluation, where the result is converted with
            ``to_const_tree``. It passes ``False`` in training, where the
            result is iterated as ``(start, end, label)`` triples.
        :return: whatever the decoder returns.
        """
        decoder = self.thread_local.decoder
        result = decoder(
            self.statistics.rules,
            span_scores, label_scores,
            leaftag_scores,
            self.statistics.leaftag_to_label,
            self.statistics.internal_labels,
            self.statistics.root_rules if self.args.restrict_root_rule else None,
            use_rules,
            return_item
        )
        return result

    def get_span_features(self, contextual_outputs,
                          start_indices, end_indices, *, repeat_indices=True):
        """Build one feature vector per span from the encoder outputs.

        :param contextual_outputs: encoder outputs, positions along dim 1 and
            features along the last dim. The ``split_*`` and ``lstm-minus``
            variants split the last dim into halves of ``self.d_model // 2``.
        :param start_indices: original span starts. They need +1 to skip the
            START symbol; the legacy ``split_concat`` variant deliberately
            omits it for backward compatibility.
        :param end_indices: span ends (``lstm-minus`` also reads ``+ 1``).
        :param repeat_indices: if True, one index tensor is shared by every
            batch element (``x[:, idx]``). If False, the indices are per batch
            element and are gathered along dim 1 with ``broadcast_gather``.
        :return: span features built as selected by
            ``hparams.span_feature_type``.
        :raises ValueError: on an unknown ``span_feature_type``.
        """
        if repeat_indices:
            def select(tensor, idx):
                return tensor[:, idx]
        else:
            def select(tensor, idx):
                return broadcast_gather(tensor, 1, idx)

        feature_type = self.hparams.span_feature_type
        half_dim = self.d_model // 2

        if feature_type == "split_concat":
            # the start index is wrong, just use it for backward compatibility
            start_features, end_features = contextual_outputs.split(half_dim, dim=-1)
            span_features = torch.cat(
                [select(start_features, start_indices),
                 select(end_features, end_indices)], -1)
        elif feature_type == "split_concat_start_stop":
            start_features, end_features = contextual_outputs.split(half_dim, dim=-1)
            span_features = torch.cat(
                [select(start_features, start_indices + 1),
                 select(end_features, end_indices)], -1)
        elif feature_type == "minus":
            span_features = select(contextual_outputs, end_indices) \
                            - select(contextual_outputs, start_indices)
        elif feature_type == "lstm-minus":
            # Difference features on each direction of a bidirectional encoder.
            forward_outputs, backward_outputs = contextual_outputs.split(half_dim, dim=-1)
            forward_features = select(forward_outputs, end_indices) \
                               - select(forward_outputs, start_indices)
            backward_features = select(backward_outputs, start_indices + 1) \
                                - select(backward_outputs, end_indices + 1)
            span_features = torch.cat([forward_features, backward_features], -1)
        else:
            # Explicit error instead of `assert`, which is stripped under -O.
            raise ValueError(
                "unknown span_feature_type: {!r}".format(feature_type))
        return span_features

    def _contextualize(self, inputs):
        """Embed ``inputs`` and run the contextual encoder over every position."""
        embedded = self.embedding(inputs)

        if self.hparams.contextual.type == "transformer":
            # Mask out positions whose word id is 0 (presumably padding).
            sentence_mask = inputs.words.gt(0)
            contextual_outputs = self.contextual_unit(embedded, sentence_mask, use_mask=True)
            if self.hparams.contextual.transformer_options.d_positional is not None:
                # Rearrange the annotations to ensure that the transition to
                # fenceposts captures an even split between position and content.
                # TODO(nikita): try alternatives, such as omitting position entirely
                contextual_outputs = torch.cat([
                    contextual_outputs[:, :, 0::2],
                    contextual_outputs[:, :, 1::2],
                ], -1)
        else:
            contextual_outputs = self.contextual_unit(embedded, inputs.sent_lengths,
                                                      return_all=True)
        return contextual_outputs

    def forward(self, sentences: List[SentenceFeatures], inputs,
                *,
                leaftag_scores=None,
                return_span_features=False,
                return_chart=False
                ):
        """Score every span of the batch, then decode and/or compute the loss.

        :param sentences: per-sentence features. The code reads
            ``sent_length`` (2 is subtracted to get the number of words),
            ``span_starts`` / ``span_ends`` (only with
            ``return_span_features``) and ``original_obj`` (evaluation only).
        :param inputs: batched inputs. ``inputs.words`` is 2-D
            ``(batch, max_sent_length)``. Training also reads the gold spans
            ``span_batch_indices_1d``, ``span_ids_1d`` and ``span_labels_1d``.
        :param leaftag_scores: optional leaf-tag scores indexed
            ``[sent_idx, :n_words]`` and forwarded to the decoder.
        :param return_span_features: if set, return the features of each
            sentence's own spans and stop.
        :param return_chart: if set, return a list of
            ``(span_chart, label_chart)`` per sentence and stop. In training
            mode these charts already include the loss augmentation.
        :return: otherwise an ``AttrDict`` with ``loss`` (``None`` in
            evaluation), ``sent_count`` and ``decoded_list`` (decoded trees
            in evaluation, empty in training).
        """
        contextual_outputs = self._contextualize(inputs)
        device = contextual_outputs.device

        if return_span_features:
            # noinspection PyCallingNonCallable
            all_starts = pad_and_stack_1d([torch.tensor(i.span_starts) for i in sentences])
            # noinspection PyCallingNonCallable
            all_ends = pad_and_stack_1d([torch.tensor(i.span_ends) for i in sentences])
            # Follow the encoder's device rather than a separate GPU flag, so
            # the indices always live next to the tensor they index.
            return self.get_span_features(
                contextual_outputs,
                all_starts.to(device), all_ends.to(device),
                repeat_indices=False)

        batch_size, max_sent_length = inputs.words.shape
        # Two positions of every padded row are not words (the code always
        # subtracts 2; presumably START/STOP symbols -- confirm).
        raw_sent_length = max_sent_length - 2
        span_count = span_count_on_length_k[raw_sent_length]
        start_indices = span_starts[:span_count]
        end_indices = span_ends[:span_count]
        # this_span_ids[start, end] -> flat span id (see the p_span_ids lookup below).
        this_span_ids = span_ids[
                        :raw_sent_length, :(raw_sent_length + 1)].to(device)

        # (batch_size, span_count, feature_count)
        span_features_3d = self.get_span_features(
            contextual_outputs, start_indices, end_indices)
        # (batch_size, span_count, label_count)
        label_scores_3d = self.f_label(span_features_3d)
        # (batch_size, span_count)
        span_scores_2d = self.f_span(span_features_3d).view(batch_size, -1)

        if self.training:
            gold_span_selector = inputs.span_batch_indices_1d, inputs.span_ids_1d
            gold_label_selector = inputs.span_batch_indices_1d, inputs.span_ids_1d, inputs.span_labels_1d

            # Loss augmentation: raise every score, then lower the gold ones,
            # so decoding favours trees that differ from the gold tree.
            span_scores_2d += self.hparams.hamming_factor
            label_scores_3d += self.hparams.hamming_factor
            span_scores_2d[gold_span_selector] -= self.hparams.loss_margin
            label_scores_3d[gold_label_selector] -= self.hparams.loss_margin
            use_rules = self.args.use_rules and self.hparams.train_with_rules
        else:
            use_rules = self.args.use_rules

        # Scatter flat span scores into charts of shape
        # (batch_size, raw_sent_length, raw_sent_length + 1):
        # chart[b, s, e] = span_scores_2d[b, this_span_ids[s, e]].
        span_scores_broadcast = span_scores_2d.unsqueeze(-1).expand(-1, -1, raw_sent_length + 1)
        span_selector = this_span_ids.expand(batch_size, -1, -1)
        span_scores_padded = span_scores_broadcast.gather(1, span_selector)

        # Same for labels, with a trailing label dimension.
        label_scores_broadcast = label_scores_3d.unsqueeze(-2).expand(-1, -1, raw_sent_length + 1, -1)
        labels_selector = span_selector.unsqueeze(-1).expand(-1, -1, -1, label_scores_3d.size(-1))
        label_scores_padded = label_scores_broadcast.gather(1, labels_selector)

        if return_chart:
            chart_list = []
            for sent_idx, sentence in enumerate(sentences):
                sent_length = sentence.sent_length - 2
                chart_list.append((span_scores_padded[sent_idx, :sent_length, :(sent_length + 1)],
                                   label_scores_padded[sent_idx, :sent_length, :(sent_length + 1)]))
            return chart_list

        # The decoder receives NumPy copies, outside of autograd.
        span_scores_padded_detach = span_scores_padded.detach().cpu()
        label_scores_padded_detach = label_scores_padded.detach().cpu()

        pending_results = []
        for sent_idx, sentence in enumerate(sentences):
            sent_length = sentence.sent_length - 2
            # NOTE(review): unlike the charts, leaftag_scores is forwarded
            # without .detach().cpu().numpy(); confirm the decoder accepts
            # whatever type the caller passes.
            pending_results.append(
                self.pool.apply_async(self.decode_func,
                                      (span_scores_padded_detach[sent_idx, :sent_length, :(sent_length + 1)].numpy(),
                                       label_scores_padded_detach[sent_idx, :sent_length, :(sent_length + 1)].numpy(),
                                       leaftag_scores[sent_idx, :sent_length] if leaftag_scores is not None else None,
                                       use_rules, not self.training)))

        decoded_list = []
        p_batch_indices = []
        p_starts = []
        p_ends = []
        p_labels = []
        for sent_idx, (sentence, pending_result) in enumerate(
                zip(sentences, pending_results)):
            decoded = pending_result.get()
            if not self.training:
                decoded_list.append(decoded.to_const_tree(
                    self.statistics.labels.int_to_word,
                    list(sentence.original_obj.generate_words())
                ).expanded_unary_chain())
            else:
                for start, end, label in decoded:
                    p_batch_indices.append(sent_idx)
                    p_starts.append(start)
                    p_ends.append(end)
                    p_labels.append(label)

        if self.training:
            p_span_ids = this_span_ids[p_starts, p_ends]
            predict_span_selector = p_batch_indices, p_span_ids
            predict_label_selector = p_batch_indices, p_span_ids, p_labels
            pred_score = span_scores_2d[predict_span_selector].sum() + \
                         label_scores_3d[predict_label_selector].sum()
            if torch.isnan(pred_score):
                # Drop a NaN predicted-tree term instead of propagating it.
                pred_score = 0.0
            # noinspection PyUnboundLocalVariable
            gold_score = span_scores_2d[gold_span_selector].sum() + \
                         label_scores_3d[gold_label_selector].sum()
            loss = (pred_score - gold_score) / batch_size
        else:
            loss = None

        return AttrDict({"loss": loss,
                         "sent_count": inputs.words.shape[0],
                         "decoded_list": decoded_list})
