import torch
from typing import Optional

from dataclasses import dataclass
from torch import Tensor

from bilm.load_vocab import BiLMVocabLoader
from coli.basic_tools.common_utils import add_slots
from coli.data_utils.dataset import SentenceFeaturesBase
from coli.torch_extra.dataset import lookup_list, lookup_characters
from coli.span.const_tree import ConstTree, ConstTreeStatistics


def load_span_ids(max_sent_len=512):
    """Build the lookup tables that map ``(start, end)`` spans to flat ids.

    Spans satisfy ``0 <= start < end < max_sent_len``. They are numbered from
    1, ordered by ``end`` first and ``start`` second; id 0 is reserved for
    padding.

    Args:
        max_sent_len: Exclusive upper bound on span end positions. Defaults
            to 512, the value that used to be hard-coded.

    Returns:
        A 4-tuple of ``torch.long`` tensors:

        * ``span_count_on_length_k`` -- shape ``(max_sent_len,)``. Entry 0 is
          0; entry ``k >= 1`` is one plus the number of spans whose end is
          ``<= k`` (i.e. the first span id not used by those spans).
        * ``span_to_idx`` -- shape ``(max_sent_len, max_sent_len)``. Entry
          ``[start, end]`` is the span id, or 0 where ``start >= end``.
        * ``span_starts`` / ``span_ends`` -- shape
          ``(max_sent_len * max_sent_len,)``. Entry ``i`` is the start / end
          of span id ``i``. The padding entry 0 holds 1 in both tables and
          entries past the last span id stay 0.

    Raises:
        ValueError: If ``max_sent_len`` is smaller than 1.
    """
    if max_sent_len < 1:
        raise ValueError("max_sent_len must be at least 1, got %r" % (max_sent_len,))

    # Enumerate the spans in plain lists first: one bulk copy into the tensors
    # is far cheaper than writing every tensor element individually.
    start_list = []
    end_list = []
    for end in range(1, max_sent_len):
        start_list.extend(range(end))
        end_list.extend([end] * end)
    span_total = len(start_list)

    span_to_idx = torch.zeros((max_sent_len, max_sent_len), dtype=torch.long)
    span_starts = torch.zeros((max_sent_len * max_sent_len,), dtype=torch.long)
    span_ends = torch.zeros((max_sent_len * max_sent_len,), dtype=torch.long)

    # reserve position 0 for padding
    span_starts[0] = 1
    span_ends[0] = 1

    if span_total:
        start_column = torch.tensor(start_list, dtype=torch.long)
        end_column = torch.tensor(end_list, dtype=torch.long)
        span_to_idx[start_column, end_column] = torch.arange(
            1, span_total + 1, dtype=torch.long)
        span_starts[1:span_total + 1] = start_column
        span_ends[1:span_total + 1] = end_column

    # After handling every span ending at `end`, the next free id is
    # 1 (padding) + end * (end + 1) / 2 (spans seen so far).
    span_count_on_length_k = torch.tensor(
        [0] + [1 + end * (end + 1) // 2 for end in range(1, max_sent_len)],
        dtype=torch.long)

    return span_count_on_length_k, span_to_idx, span_starts, span_ends


# Build the span lookup tables once, at import time. `span_ids` is the
# (start, end) -> span id matrix indexed by SentenceFeatures.from_sentence_obj;
# `span_starts` / `span_ends` are the inverse tables (span id -> start / end).
span_count_on_length_k, span_ids, span_starts, span_ends = load_span_ids()


@add_slots
@dataclass(eq=True, unsafe_hash=True)
class SentenceFeatures(SentenceFeaturesBase[ConstTree]):
    """Per-sentence model inputs derived from a ``ConstTree``.

    Instances are created by :meth:`from_sentence_obj`; a list of instances is
    merged into batch tensors by :meth:`get_feed_dict`.
    """
    original_idx: int = -1
    original_obj: ConstTree = None
    words: Tensor = None  # [int] (pad_length, )
    characters: Tensor = None  # [int] (pad_length, word_pad_length)
    postags: Tensor = None  # [int] (pad_length, )
    sent_length: int = None  # (include start and stop)
    char_lengths: Tensor = None  # [int] (pad_length, )
    # NOTE: from_sentence_obj stores span_starts / span_ends as tuples (the
    # output of zip), not as tensors.
    span_starts: Tensor = None  # [int] (span_count, )
    span_ends: Tensor = None  # [int] (span_count, )
    span_labels: Tensor = None  # [int] (span_count, )
    span_ids: Tensor = None  # [int] (span_count, ) ids from the module-level span_ids table

    @classmethod
    def from_sentence_obj(cls, original_idx, sent: ConstTree,
                          statistics: ConstTreeStatistics,
                          padded_length=None, lower=True,
                          include_postags=False,
                          plugins=None,
                          strip_top="TOP"
                          ):
        """Create the features of one sentence.

        Args:
            original_idx: Stored unchanged as ``original_idx``.
            sent: Tree to featurize. The tree passed in is stored as
                ``original_obj``; all features are computed from
                ``sent.condensed_unary_chain(...)``.
            statistics: Vocabulary holder. ``words.word_to_int`` and
                ``labels.word_to_int`` are always read; ``postags`` and
                ``characters`` are used only when the attribute exists.
            padded_length: Length handed to the lookup helpers. Defaults to
                ``sent_length`` (number of words plus 2).
            lower: If true, each word is lower-cased before the word
                vocabulary lookup. The character lookup always receives the
                words unchanged.
            include_postags: Forwarded to ``condensed_unary_chain`` as
                ``include_postag``.
            plugins: Optional mapping; ``process_sentence_feature(sent, ret,
                padded_length)`` is called on every value, where ``sent`` is
                the condensed tree.
            strip_top: Forwarded to ``condensed_unary_chain``.

        Returns:
            The populated ``SentenceFeatures`` instance.
        """
        ret = cls(original_idx=original_idx, original_obj=sent)

        sent = sent.condensed_unary_chain(include_postag=include_postags,
                                          strip_top=strip_top)
        words = sent.words
        ret.sent_length = len(words) + 2

        # Both lambdas must be parenthesised: without the parentheses the
        # conditional expression becomes the body of a single lambda, which
        # returns a function object instead of the word when `lower` is false.
        normalize = (lambda x: x.lower()) if lower else (lambda x: x)

        if padded_length is None:
            padded_length = ret.sent_length

        ret.words = lookup_list(
            (normalize(word) for word in words), statistics.words.word_to_int,
            padded_length=padded_length, default=1,
            start_and_stop=True
        )

        if hasattr(statistics, "postags"):
            ret.postags = lookup_list(sent.postags, statistics.postags.word_to_int,
                                      padded_length=padded_length, default=1, start_and_stop=True)

        if hasattr(statistics, "characters"):
            # NOTE(review): words are padded to `padded_length`, but this call
            # passes `padded_length + 2` -- confirm against the contract of
            # lookup_characters.
            ret.char_lengths, ret.characters = lookup_characters(
                words, statistics.characters.word_to_int,
                padded_length + 2, 1, return_lengths=True,
                sentence_start_and_stop=True,
            )

        all_spans = list(sent.generate_scoreable_spans())
        if all_spans:
            ret.span_starts, ret.span_ends, labels_1d = zip(*all_spans)
        else:
            # zip(*[]) yields nothing to unpack; fall back to empty columns.
            ret.span_starts, ret.span_ends, labels_1d = (), (), ()

        # Labels absent from the vocabulary are encoded as -1. The explicit
        # dtype keeps an empty span list from producing a float tensor.
        # noinspection PyCallingNonCallable
        ret.span_labels = torch.tensor(
            [statistics.labels.word_to_int.get(label, -1) for label in labels_1d],
            dtype=torch.long)

        # Look up every (start, end) pair in the module-level id table with a
        # single indexing operation.
        # noinspection PyCallingNonCallable
        start_index = torch.tensor(ret.span_starts, dtype=torch.long)
        # noinspection PyCallingNonCallable
        end_index = torch.tensor(ret.span_ends, dtype=torch.long)
        ret.span_ids = span_ids[start_index, end_index]

        if plugins:
            for plugin in plugins.values():
                plugin.process_sentence_feature(sent, ret, padded_length)

        return ret

    @classmethod
    def get_feed_dict(cls, pls, batch_sentences, plugins=None):
        """Merge the features of several sentences into batch tensors.

        Args:
            pls: Object whose attributes (``words``, ``chars``,
                ``sent_lengths``, ``word_lengths``, ``span_ids_1d``,
                ``span_labels_1d``, ``span_batch_indices_1d`` and, when
                postags are present, ``postags``) are used as dictionary keys.
            batch_sentences: Non-empty sequence of ``SentenceFeatures``.
                Per-position tensors are stacked, so they must share a shape.
            plugins: Optional mapping; ``process_batch(pls, ret,
                batch_sentences)`` is called on every value.

        Returns:
            Dict from ``pls`` attributes to tensors. Per-position features are
            stacked along a new first dimension; span features of all
            sentences are concatenated into 1-D tensors, with
            ``span_batch_indices_1d`` giving the batch position of each span.
        """
        # One row of repeated batch positions per sentence, one entry per span.
        batch_indices = [
            torch.full((len(sentence.span_starts),), idx, dtype=torch.long)
            for idx, sentence in enumerate(batch_sentences)
        ]

        # noinspection PyCallingNonCallable
        ret = {pls.words: torch.stack([i.words for i in batch_sentences]),
               pls.chars: torch.stack([i.characters for i in batch_sentences]),
               pls.sent_lengths: torch.tensor([i.sent_length for i in batch_sentences]),
               pls.word_lengths: torch.stack([i.char_lengths for i in batch_sentences]),
               pls.span_ids_1d: torch.cat([i.span_ids for i in batch_sentences]),
               pls.span_labels_1d: torch.cat([i.span_labels for i in batch_sentences]),
               pls.span_batch_indices_1d: torch.cat(batch_indices)
               }

        # Only the first sentence is inspected to decide whether postags exist.
        if batch_sentences[0].postags is not None:
            ret[pls.postags] = torch.stack([i.postags for i in batch_sentences])

        if plugins:
            for plugin in plugins.values():
                plugin.process_batch(pls, ret, batch_sentences)

        return ret
