from typing import Dict, List, Tuple, Any
import copy
import sys
import json
import numpy as np
from pathlib import Path

from qa.table_bert.config import TableBertConfig, BERT_CONFIGS
from qa.table_bert.table_bert import TableBertModel
from qa.table_bert.vertical.config import VerticalAttentionTableBertConfig
from qa.table_bert.table import Column, Table
from qa.table_bert.vertical.vertical_attention_table_bert import VerticalAttentionTableBert
from qa.table_bert.vanilla_table_bert import VanillaTableBert

from qa.nsm.parser_module.sequence_util import StringMatchUtil


def get_table_bert_model(config: Dict, use_proxy=False, master=None):
    """Load the TableBERT model described by ``config``.

    The model location is read from the first of these keys whose value is
    neither ``None`` nor ``''``: ``'table_bert_model_or_config'``,
    ``'table_bert_config_file'``, ``'table_bert_model'``. If all of them are
    empty, the value of the last key is passed on unchanged.

    Args:
        config: Dict-like settings. Besides the keys above, the optional
            ``'table_bert_extra_config'`` dict is forwarded as keyword
            arguments to ``TableBertModel.from_pretrained`` and
            ``'column_representation'`` (default ``'mean_pool_column_name'``)
            is written onto the config of a ``VanillaTableBert`` model.
        use_proxy: When true, the loaded model is dropped and a
            ``TableBertProxy`` built from a deep copy of its config is
            returned in its place.
        master: Passed to ``TableBertProxy`` as ``actor_id``; only read when
            ``use_proxy`` is true.

    Returns:
        The loaded model, or the ``TableBertProxy`` standing in for it.
    """
    location_keys = (
        'table_bert_model_or_config',
        'table_bert_config_file',
        'table_bert_model',
    )
    model_name_or_path = None
    for location_key in location_keys:
        model_name_or_path = config.get(location_key)
        if model_name_or_path not in {None, ''}:
            break

    extra_kwargs = config.get('table_bert_extra_config', dict())

    # TODO: restore the stderr message announcing which model is being loaded.
    model = TableBertModel.from_pretrained(model_name_or_path, **extra_kwargs)

    # Exact type match on purpose: ``isinstance`` would also match every
    # subclass of VanillaTableBert, which the original check does not.
    if type(model) is VanillaTableBert:
        column_representation = config.get('column_representation', 'mean_pool_column_name')
        model.config.column_representation = column_representation

    if use_proxy:
        from qa.nsm.parser_module.table_bert_proxy import TableBertProxy
        tb_config = copy.deepcopy(model.config)
        # Release the real model before the proxy is constructed.
        del model
        model = TableBertProxy(actor_id=master, table_bert_config=tb_config)

    # TODO: restore the stderr dump of the resolved config, i.e.
    # json.dumps(vars(model.config), indent=2). Sample of such a dump:
    #
    # Table Bert Config
    #     {
    #       "base_model_name": "bert-base-uncased",
    #       "column_delimiter": "[SEP]",
    #       "context_first": true,
    #       "column_representation": "mean_pool_column_name",
    #       "max_cell_len": 5,
    #       "max_sequence_len": 512,
    #       "max_context_len": 256,
    #       "do_lower_case": true,
    #       "cell_input_template": [
    #         "column",
    #         "|",
    #         "type",
    #         "|",
    #         "value"
    #       ],
    #       "masked_context_prob": 0.15,
    #       "masked_column_prob": 0.2,
    #       "max_predictions_per_seq": 100,
    #       "context_sample_strategy": "nearest",
    #       "table_mask_strategy": "column",
    #       "vocab_size": 30522,
    #       "hidden_size": 768,
    #       "num_hidden_layers": 12,
    #       "num_attention_heads": 12,
    #       "hidden_act": "gelu",
    #       "intermediate_size": 3072,
    #       "hidden_dropout_prob": 0.1,
    #       "attention_probs_dropout_prob": 0.1,
    #       "max_position_embeddings": 512,
    #       "type_vocab_size": 2,
    #       "initializer_range": 0.02
    #     }

    return model


def get_table_bert_model_deprecated(config: Dict, use_proxy=False, master=None):
    """Older loader for a TableBERT model or an untrained model from a config.

    The path is read from the first of ``'table_bert_model_or_config'``,
    ``'table_bert_config_file'`` and ``'table_bert_model'`` whose value is
    neither ``None`` nor ``''`` and must exist on disk.

    * A path ending in ``.json`` is treated as a config file: a model of the
      class inferred from that file is created without pre-trained weights.
    * Any other path is treated as a model checkpoint and loaded through
      ``TableBertModel.from_pretrained``; its config is expected next to it
      as ``tb_config.json`` (only read when ``use_proxy`` is true).

    Args:
        config: Dict-like settings; also consulted for
            ``'table_bert_extra_config'`` and ``'column_representation'``.
        use_proxy: When true, no model is loaded; a ``TableBertProxy`` built
            from the config file is returned instead.
        master: Passed to ``TableBertProxy`` as ``actor_id``.

    Returns:
        The model instance, or the ``TableBertProxy``.

    Raises:
        AssertionError: If the resolved path does not exist.
        TypeError: If no path is configured (``Path(None)``).
    """
    tb_path = None
    for path_key in ('table_bert_model_or_config', 'table_bert_config_file', 'table_bert_model'):
        tb_path = config.get(path_key)
        if not (tb_path is None or tb_path == ''):
            break

    tb_path = Path(tb_path)
    assert tb_path.exists()

    if tb_path.suffix == '.json':
        # Config only: there is no checkpoint to load weights from.
        tb_config_file, tb_path = tb_path, None
    else:
        print(f'Loading table BERT model {tb_path}', file=sys.stderr)
        tb_config_file = tb_path.parent / 'tb_config.json'

    if use_proxy:
        from qa.nsm.parser_module.table_bert_proxy import TableBertProxy
        tb_config = TableBertConfig.from_file(tb_config_file)
        return TableBertProxy(actor_id=master, table_bert_config=tb_config)

    extra_kwargs = config.get('table_bert_extra_config', dict())

    if tb_path is not None:
        table_bert_model = TableBertModel.from_pretrained(tb_path, **extra_kwargs)
    else:
        # Not a pre-trained model: build one with default parameters.
        table_bert_cls = TableBertConfig.infer_model_class_from_config_file(tb_config_file)
        print(f'Creating a default {table_bert_cls.__name__} without pre-trained parameters!', file=sys.stderr)

        default_config = table_bert_cls.CONFIG_CLASS.from_file(tb_config_file, **extra_kwargs)
        table_bert_model = table_bert_cls(config=default_config)

    # Exact type match on purpose: subclasses of VanillaTableBert are skipped.
    if type(table_bert_model) is VanillaTableBert:
        table_bert_model.config.column_representation = config.get(
            'column_representation', 'mean_pool_column_name'
        )

    print('Table Bert Config', file=sys.stderr)
    print(json.dumps(vars(table_bert_model.config), indent=2), file=sys.stderr)

    return table_bert_model


def model_use_vertical_attention(bert_model):
    """Tell whether ``bert_model`` is configured for vertical attention.

    Returns:
        ``True`` when ``bert_model.config`` is an instance of
        ``VerticalAttentionTableBertConfig``, ``False`` otherwise.
    """
    model_config = bert_model.config
    return isinstance(model_config, VerticalAttentionTableBertConfig)


def _longest_context_ngram_len(context, cell, max_ngram_num):
    """Return the size of the largest slice of ``cell`` matched in ``context``.

    Slice sizes are tried from ``max_ngram_num`` down to 1 and, for each size,
    from left to right. The first slice that is not made up only of stop words
    and that ``StringMatchUtil.contains`` finds in ``context`` decides the
    result. Returns 0 when no slice qualifies.
    """
    for ngram_len in range(max_ngram_num, 0, -1):
        for start in range(len(cell) - ngram_len + 1):
            ngram = cell[start:start + ngram_len]
            if not StringMatchUtil.all_stop_words(ngram) and StringMatchUtil.contains(context, ngram):
                return ngram_len
    return 0


def get_question_biased_sampled_rows(context, table, num_rows=3):
    """Pick up to ``num_rows`` rows of ``table`` that best overlap ``context``.

    Rows are scored in up to three stages, each one only run when the previous
    stages produced fewer than ``num_rows`` candidates:

    1. Full-cell match: a row scores the length of its longest non-empty cell
       that is contained in ``context`` and is not all stop words.
    2. Partial match: each remaining row scores the size (at most 3) of the
       largest cell slice found in ``context``.
    3. Padding: the lowest-numbered rows not selected so far are added with
       score 0.

    The ``num_rows`` highest-scoring candidates are kept (ties keep the order
    in which candidates were found) and returned in table order.

    Args:
        context: The question; passed as-is to ``StringMatchUtil``.
        table: Object exposing ``data`` (rows as dicts or sequences of cells)
            and supporting ``len()``.
        num_rows: Maximum number of rows to return.

    Returns:
        A list of row objects taken from ``table.data``.
    """
    def cells_of(row):
        return list(row.values() if isinstance(row, dict) else row)

    match_score = {}

    # Stage 1: whole cells mentioned in the context.
    for row_id, row in enumerate(table.data):
        for cell in cells_of(row):
            if len(cell) == 0:
                continue
            if StringMatchUtil.contains(context, cell) and not StringMatchUtil.all_stop_words(cell):
                match_score[row_id] = max(match_score.get(row_id, 0), len(cell))

    selected_ids = [row_id for row_id, score in match_score.items() if score > 0]

    if len(selected_ids) < num_rows:
        # Stage 2: n-gram overlap for rows without a whole-cell match.
        max_ngram_num = 3
        fully_matched = set(selected_ids)
        for row_id, row in enumerate(table.data):
            if row_id in fully_matched:
                continue

            for cell in cells_of(row):
                if len(cell) > 0:
                    ngram_len = _longest_context_ngram_len(context, cell, max_ngram_num)
                    if ngram_len > 0:
                        match_score[row_id] = max(ngram_len, match_score.get(row_id, 0))

        selected_ids = [row_id for row_id, score in match_score.items() if score > 0]

        if len(selected_ids) < num_rows:
            # Stage 3: pad with the first rows that matched nothing.
            taken = set(selected_ids)
            unmatched_ids = [idx for idx in range(len(table)) if idx not in taken]
            padding = unmatched_ids[:num_rows - len(selected_ids)]
            for idx in padding:
                match_score[idx] = 0
            selected_ids = selected_ids + padding

    # Stable descending sort: equal scores keep their discovery order.
    ranked_ids = sorted(selected_ids, key=match_score.__getitem__, reverse=True)
    kept_ids = sorted(ranked_ids[:num_rows])

    return [table.data[idx] for idx in kept_ids]


def get_question_biased_sampled_cells(context, table):
    """Pick, for every column, the cell that best overlaps ``context``.

    For each column in ``table.header``:

    * Full-cell match: every non-empty cell contained in ``context`` that is
      not all stop words is a candidate, scored by its length.
    * Partial match (only when there is no full-cell candidate): a cell is a
      candidate when one of its slices of size 3, 2 or 1 is found in
      ``context``; its score is the largest such size.
    * The highest-scoring candidate wins (the earliest row on ties). Without
      any candidate the column's ``sample_value_tokens`` is used.

    Args:
        context: The question; passed as-is to ``StringMatchUtil``.
        table: Object exposing ``header`` (columns with ``name`` and
            ``sample_value_tokens``) and ``data`` (rows as dicts keyed by
            column name, or sequences indexed by column position).

    Returns:
        A list with one entry per column, in header order.
    """
    def cell_at(row, col_idx):
        # Dict rows are keyed by column name; a missing key counts as empty.
        if isinstance(row, dict):
            return row.get(table.header[col_idx].name, [])
        return row[col_idx]

    def matches_context(ngram):
        return not StringMatchUtil.all_stop_words(ngram) and StringMatchUtil.contains(context, ngram)

    sampled_cells = []

    for col_idx, column in enumerate(table.header):
        scored_cells = []

        for row in table.data:
            cell = cell_at(row, col_idx)
            if len(cell) > 0 and StringMatchUtil.contains(context, cell) and not StringMatchUtil.all_stop_words(cell):
                scored_cells.append((cell, len(cell)))

        if not scored_cells:
            # No whole cell is mentioned: fall back to n-gram overlap.
            max_ngram_num = 3

            for row in table.data:
                cell = cell_at(row, col_idx)
                if len(cell) == 0:
                    continue

                for ngram_len in range(max_ngram_num, 0, -1):
                    windows = (
                        cell[start:start + ngram_len]
                        for start in range(len(cell) - ngram_len + 1)
                    )
                    # any() stops at the first matching window.
                    if any(matches_context(window) for window in windows):
                        scored_cells.append((cell, ngram_len))
                        break

        if scored_cells:
            # max() returns the first of several top-scoring candidates.
            best_cell, _ = max(scored_cells, key=lambda scored: scored[1])
        else:
            best_cell = column.sample_value_tokens

        sampled_cells.append(best_cell)

    return sampled_cells


def get_table_bert_input_from_context(
    env_context: List[Dict],
    bert_model: TableBertModel,
    is_training: bool,
    **kwargs
) -> Tuple[List[Any], List[Table]]:
    """Split environment contexts into parallel question and table lists.

    Args:
        env_context: Examples, each providing ``'question_tokens'`` and
            ``'table'``.
        bert_model: Accepted for interface compatibility; not used here.
        is_training: Accepted for interface compatibility; not used here.
        **kwargs: Only ``content_snapshot_strategy`` is looked at. When
            truthy it must be ``'sampled_rows'`` or ``'synthetic_row'``; it is
            validated but does not otherwise affect the result.

    Returns:
        ``(contexts, tables)``: the ``'question_tokens'`` and ``'table'``
        entries of every example, in input order.

    Raises:
        AssertionError: On an unsupported ``content_snapshot_strategy``.
        KeyError: If an example lacks one of the two required keys.
    """
    snapshot_strategy = kwargs.get('content_snapshot_strategy')
    assert not snapshot_strategy or snapshot_strategy in ('sampled_rows', 'synthetic_row')

    contexts, tables = [], []
    for example in env_context:
        question_tokens, table = example['question_tokens'], example['table']
        contexts.append(question_tokens)
        tables.append(table)

    return contexts, tables
