import collections
from typing import Type, List, Dict

from overrides import overrides

from allennlp.common import JsonDict
from allennlp.data import Instance
from allennlp.models.model import Model
from allennlp.predictors.predictor import Predictor
from allennlp.common.util import sanitize
from allennlp.data.fields import MetadataField
from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers import MultiTaskDatasetReader


@Predictor.register("seq2seq_multihead")
class Seq2SeqMultiHeadPredictor(Predictor):
    """
    Predictor for multitask models.

    This predictor is tightly coupled to `MultiTaskDatasetReader` and `MultiTaskModel`, and will not work if
    used with other readers or models.

    JSON inputs must provide `"source"` and `"source_lang"`. An optional `"task"` key selects which
    sub-reader of the `MultiTaskDatasetReader` builds the instance; when it is absent, `"sql"` is used.
    """

    _WRONG_READER_ERROR = (
        "MultitaskPredictor is designed to work with MultiTaskDatasetReader. "
        + "If you have a different DatasetReader, you have to write your own "
        + "Predictor, but you can use MultiTaskPredictor as a starting point."
    )

    _WRONG_FIELD_ERROR = (
        "MultiTaskPredictor expects instances that have a MetadataField "
        + "with the name 'task', containing the name of the task the instance is for."
    )

    # Task used when a JSON input does not name one (the previously hard-coded value).
    _DEFAULT_TASK = "sql"

    def __init__(self, model: Model, dataset_reader: MultiTaskDatasetReader) -> None:
        if not isinstance(dataset_reader, MultiTaskDatasetReader):
            raise ConfigurationError(self._WRONG_READER_ERROR)

        super().__init__(model, dataset_reader)

    @overrides
    def predict_instance(self, instance: Instance) -> JsonDict:
        """
        Run the model on a single instance and return sanitized outputs.

        The instance must carry a `MetadataField` named `"task"`. Its value selects the
        sub-reader whose `apply_token_indexers` is called on the instance before
        `forward_on_instance`.

        Raises
        ------
        ValueError
            If the instance has no `"task"` field, or that field is not a `MetadataField`.
        ConfigurationError
            If the predictor's dataset reader is not a `MultiTaskDatasetReader`.
        """
        # `.get` so that a missing field yields the descriptive error below, not a bare KeyError.
        task_field = instance.fields.get("task")
        if not isinstance(task_field, MetadataField):
            raise ValueError(self._WRONG_FIELD_ERROR)
        task: str = task_field.metadata
        # Already enforced in __init__; repeated so `_dataset_reader` is narrowed for `.readers`.
        if not isinstance(self._dataset_reader, MultiTaskDatasetReader):
            raise ConfigurationError(self._WRONG_READER_ERROR)
        # NOTE(review): presumably the instance lacks token indexers until the owning
        # task's reader attaches them (AllenNLP 2.x convention) — confirm.
        self._dataset_reader.readers[task].apply_token_indexers(instance)
        outputs = self._model.forward_on_instance(instance)
        return sanitize(outputs)

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Convert a JSON input into an `Instance` tagged with its task.

        Reads `"source"` and `"source_lang"` and passes them to the selected task reader's
        `text_to_instance`. The task comes from an optional `"task"` key and defaults
        to `"sql"`.

        Raises
        ------
        KeyError
            If `"source"` or `"source_lang"` is missing, or the task has no reader.
        ConfigurationError
            If the predictor's dataset reader is not a `MultiTaskDatasetReader`.
        """
        task = json_dict.get("task", self._DEFAULT_TASK)
        if not isinstance(self._dataset_reader, MultiTaskDatasetReader):
            raise ConfigurationError(self._WRONG_READER_ERROR)
        source = json_dict["source"]
        source_lang = json_dict["source_lang"]
        instance = self._dataset_reader.readers[task].text_to_instance(
            source_string=source, source_lang=source_lang
        )
        # predict_instance dispatches on this field, so every instance must carry it.
        instance.add_field("task", MetadataField(task))
        return instance
