from typing import Optional, Any

from dataclasses import dataclass, field

from coli.basic_tools.dataclass_argparse import argfield
from coli.torch_extra.layers import ExternalContextualEmbedding, AdvancedLearningOptions, \
    ContextualUnits
from coli.torch_extra.parser_base import SimpleParser
from coli.torch_extra.sentence import SentenceEmbeddings


@dataclass
class SpanParserOptions(SimpleParser.HParams):
    """Hyper-parameters for the span parser, layered on ``SimpleParser.HParams``.

    Besides the scalar settings, the class nests the option objects of the
    sentence embedding, the external contextual embedding, the contextual
    units and the learning configuration.  Two presets are available as
    alternate constructors: :meth:`get_default` and :meth:`get_lstm_options`.
    """

    # Options of the contextual units (adjusted by both presets below).
    contextual: ContextualUnits.Options = field(default_factory=ContextualUnits.Options)

    train_iters: Any = 100000
    strip_top: Optional[str] = argfield(None, type=str)
    predict_postags: bool = False
    # Restricted to the three values listed in ``choices``.
    span_feature_type: str = argfield(
        default="split_concat_start_stop",
        choices=[
            "split_concat_start_stop",
            "minus",
            "lstm-minus",
        ],
    )
    decoder: str = "rulefree3"
    train_with_rules: bool = False

    hamming_factor: float = 0.4
    loss_margin: float = 1.0

    # Hidden sizes, presumably of the label and span scorers -- confirm
    # against the model code that consumes them.
    d_label_hidden: int = 256
    d_span_hidden: int = 256

    # char_lstm_input_dropout: float = 0.2

    # NOTE(review): no annotation here, so ``@dataclass`` treats this as a
    # plain class attribute rather than a field (it gets no ``__init__``
    # parameter).  Confirm whether that is intended.
    word_threshold = 0

    # Nested option groups; each instance receives its own fresh object.
    sentence_embedding: SentenceEmbeddings.Options = field(
        default_factory=SentenceEmbeddings.Options
    )
    pretrained_contextual: ExternalContextualEmbedding.Options = field(
        default_factory=ExternalContextualEmbedding.Options
    )
    learning: AdvancedLearningOptions = field(default_factory=AdvancedLearningOptions)

    @classmethod
    def get_default(cls):
        """Return the default preset.

        Starts from ``cls()`` and then selects ``"elmo"`` as the pretrained
        contextual embedding and ``"transformer"`` as the contextual unit,
        filling in their sub-options, the sentence-embedding settings and
        the learning rate.
        """
        options = cls()

        # External contextual embedding: choose the type, then tune its options.
        pretrained = options.pretrained_contextual
        pretrained.type = "elmo"
        elmo = pretrained.elmo_options
        elmo.keep_sentence_boundaries = True
        elmo.dropout = 0.5
        elmo.feature_dropout = 0.2
        elmo.project_to = 512

        # Sentence embedding settings.
        embedding = options.sentence_embedding
        embedding.dim_word = 512
        embedding.dim_postag = 0
        embedding.mode = "add"

        # Contextual unit: choose the type, then tune its options.
        encoder = options.contextual
        encoder.type = "transformer"
        transformer = encoder.transformer_options
        transformer.num_layers = 8
        transformer.num_heads = 8
        transformer.d_kv = 64
        transformer.d_ff = 2048
        transformer.d_positional = 512
        transformer.max_sent_len = 300
        transformer.attention_dropout = 0.2
        transformer.timing_dropout = 0.0
        transformer.timing_layer_norm = True
        transformer.relu_dropout = 0.1
        transformer.residual_dropout = 0.2

        options.learning.learning_rate = 0.0008
        return options

    @classmethod
    def get_lstm_options(cls):
        """Return a preset using ``"lstm"`` contextual units.

        Starts from ``cls()``, sets the LSTM hidden size and layer count, and
        switches ``span_feature_type`` to ``"lstm-minus"``.
        """
        options = cls()

        encoder = options.contextual
        encoder.type = "lstm"
        lstm = encoder.lstm_options
        lstm.hidden_size = 256
        lstm.num_layers = 2

        options.span_feature_type = "lstm-minus"
        return options

