import logging
from dataclasses import dataclass
from typing import List, Optional

from stanza.models.common.doc import Document

from contrastive_conditioning.parser import Constituent, ParseTree
from contrastive_conditioning.utils import load_stanza_parser
from translation_models import ScoringModel

try:
    from dict_to_dataclass import DataclassFromDict, field_from_dict
except Exception:  # Python 3.7 compatibility
    # `Exception` instead of a bare `except:` so that KeyboardInterrupt and
    # SystemExit raised while importing are not silently swallowed; any
    # ordinary import-time failure still selects the fallback below.
    DataclassFromDict = object

    def field_from_dict(default=None):
        """Fallback stand-in used when `dict_to_dataclass` cannot be imported.

        Returns `default` unchanged so that it serves as a plain dataclass
        field default.
        """
        return default


@dataclass
class CoverageError(DataclassFromDict):
    """Base class for a coverage error tied to a single parse constituent."""

    # The constituent that was flagged as an error.
    constituent: Constituent = field_from_dict()

    def __str__(self):
        # The string form of an error is the `removed` attribute of its
        # constituent (presumably the text of the flagged span -- confirm
        # against `Constituent`).
        flagged = self.constituent
        return flagged.removed


@dataclass
class AdditionError(CoverageError):
    """Coverage error built from a constituent of the translation.

    Adds no fields or behavior of its own; it exists so that addition errors
    can be told apart from omission errors by type.
    """


@dataclass
class OmissionError(CoverageError):
    """Coverage error built from a constituent of the source sequence.

    Adds no fields or behavior of its own; it exists so that omission errors
    can be told apart from addition errors by type.
    """


@dataclass
class CoverageResult(DataclassFromDict):
    """Outcome of a coverage check for one source/translation pair.

    `addition_errors` and `omission_errors` are each either a list (possibly
    empty) of detected errors, or `None` when that direction was not
    evaluated.
    """

    # Errors found in the translation; None if that check was not performed.
    addition_errors: Optional[List[AdditionError]] = field_from_dict()
    # Errors found in the source; None if that check was not performed.
    omission_errors: Optional[List[OmissionError]] = field_from_dict()
    src_language: Optional[str] = field_from_dict()
    tgt_language: Optional[str] = field_from_dict()
    # True when a check was skipped because its input had more than one sentence.
    is_multi_sentence_input: Optional[bool] = field_from_dict(default=None)

    @property
    def contains_addition_error(self) -> bool:
        """Return True if at least one addition error was detected.

        Returns False (instead of raising TypeError) when `addition_errors`
        is None, i.e. when the addition check was not performed.
        """
        return bool(self.addition_errors)

    @property
    def contains_omission_error(self) -> bool:
        """Return True if at least one omission error was detected.

        Returns False (instead of raising TypeError) when `omission_errors`
        is None, i.e. when the omission check was not performed.
        """
        return bool(self.omission_errors)

    def __str__(self):
        # One line per non-empty error list; an empty string if there are none.
        lines = []
        if self.addition_errors:
            lines.append(f"Addition errors: {' | '.join(map(str, self.addition_errors))}")
        if self.omission_errors:
            lines.append(f"Omission errors: {' | '.join(map(str, self.omission_errors))}")
        return "\n".join(lines)


class CoverageDetector:
    """Detect addition and omission errors in a translation.

    For each constituent of a parsed sequence, the sequence is re-scored
    against the other side with `constituent.remainder` in place of the full
    sequence. A constituent is reported when that score is strictly higher
    than the score obtained with the full sequence.
    """

    def __init__(self,
                 src_language: Optional[str] = None,
                 tgt_language: Optional[str] = None,
                 forward_evaluator: Optional[ScoringModel] = None,
                 backward_evaluator: Optional[ScoringModel] = None,
                 batch_size: int = 16,
                 ):
        """Store the configuration and load a parser for each given language.

        :param src_language: Source language passed to `load_stanza_parser`;
            no source parser is loaded when None.
        :param tgt_language: Target language passed to `load_stanza_parser`;
            no target parser is loaded when None.
        :param forward_evaluator: Scoring model used for omission detection;
            omission detection is skipped when None.
        :param backward_evaluator: Scoring model used for addition detection;
            addition detection is skipped when None.
        :param batch_size: Batch size forwarded to the evaluators' `score`.
        """
        self.src_language = src_language
        self.tgt_language = tgt_language
        self.src_parser = load_stanza_parser(src_language) if src_language is not None else None
        self.tgt_parser = load_stanza_parser(tgt_language) if tgt_language is not None else None
        self.forward_evaluator = forward_evaluator
        self.backward_evaluator = backward_evaluator
        self.batch_size = batch_size

    def _get_error_constituents(self, src_sequence: str, tgt_sequence: str, evaluator, parser=None, src_doc: Optional[Document] = None) -> Optional[List[Constituent]]:
        """Return the constituents of `src_sequence` flagged by `evaluator`.

        :param src_sequence: Sequence whose constituents are examined.
        :param tgt_sequence: Sequence used as the hypothesis for every score.
        :param evaluator: Object providing
            `score(source_sentences, hypothesis_sentences, batch_size)`.
        :param parser: Callable turning `src_sequence` into a parsed document;
            only used when `src_doc` is None.
        :param src_doc: Pre-parsed document for `src_sequence`, if available.
        :return: The flagged constituents (possibly empty), or None if the
            parsed document has more than one sentence.
        """
        if src_doc is None:
            src_doc = parser(src_sequence)
        sentences = src_doc.sentences
        if len(sentences) > 1:
            logging.warning("Coverage detector does not handle multi-sentence inputs yet; skipping ...")
            return None
        if not sentences:
            # A document without any sentence has no constituents to examine;
            # without this guard the indexing below raises IndexError.
            return []
        tree = ParseTree(sentences[0])
        constituents = list(tree.iter_constituents())
        if not constituents:
            return []

        # One batched call: the first entry is the full sequence (baseline),
        # followed by one entry per constituent using its `remainder`.
        scores = evaluator.score(
            source_sentences=[src_sequence] + [constituent.remainder for constituent in constituents],
            hypothesis_sentences=(1 + len(constituents)) * [tgt_sequence],
            batch_size=self.batch_size,
        )
        base_score = scores[0]
        returned_constituents = []
        for score, constituent in zip(scores[1:], constituents):
            # Flag the constituent when the remainder scores strictly higher
            # than the full sequence.
            if score > base_score:
                # Record both scores on the constituent for later inspection.
                constituent.constituent_score = score
                constituent.base_score = base_score
                returned_constituents.append(constituent)
        return returned_constituents

    def detect_errors(self, src: str, translation: str, src_doc: Optional[Document] = None, translation_doc: Optional[Document] = None) -> CoverageResult:
        """Check a source/translation pair for addition and omission errors.

        :param src: Source sequence.
        :param translation: Translation of `src`.
        :param src_doc: Optional pre-parsed document for `src`.
        :param translation_doc: Optional pre-parsed document for `translation`.
        :return: A `CoverageResult`. An error list is left as None when the
            corresponding evaluator is not configured or when the parsed
            input has more than one sentence; in the latter case
            `is_multi_sentence_input` is True.
        """
        is_multi_sentence_input = False

        # Addition errors: examine constituents of the translation, scored
        # against the source with the backward evaluator.
        addition_errors = None
        if self.backward_evaluator is not None:
            tgt_constituents = self._get_error_constituents(translation, src, self.backward_evaluator, self.tgt_parser, translation_doc)
            if tgt_constituents is None:
                is_multi_sentence_input = True
            else:
                addition_errors = [AdditionError(constituent=constituent) for constituent in tgt_constituents]

        # Omission errors: examine constituents of the source, scored against
        # the translation with the forward evaluator.
        omission_errors = None
        if self.forward_evaluator is not None:
            src_constituents = self._get_error_constituents(src, translation, self.forward_evaluator, self.src_parser, src_doc)
            if src_constituents is None:
                is_multi_sentence_input = True
            else:
                omission_errors = [OmissionError(constituent=constituent) for constituent in src_constituents]

        return CoverageResult(
            src_language=self.src_language,
            tgt_language=self.tgt_language,
            addition_errors=addition_errors,
            omission_errors=omission_errors,
            is_multi_sentence_input=is_multi_sentence_input,
        )
