import ast
import collections.abc
import json
import math
import os
import re
from statistics import mean, stdev, variance, median

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from icecream import ic
from tqdm import tqdm

from qa.table.utils import *

# Placeholder answer substituted when no cell is predicted for a sample.
EMPTY_TK = '<EMPTY>'

# CONSTANTS copied from PyTorch's ``default_collate`` implementation: its error
# message, the legacy type aliases it checks against, and its numpy dtype test.
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")
inf = math.inf
nan = math.nan
string_classes = (str, bytes)
int_classes = int
container_abcs = collections.abc
# Matches numpy dtype codes for string (S/a, U) and object (O) arrays, which
# cannot be converted to tensors.
np_str_obj_array_pattern = re.compile(r'[SaUO]')


def str2tuple(string):
    """ Convert string format coordinates into a list of tuples.

    ``string`` is a Python literal for a sequence whose items are either the
    string form of a coordinate tuple or the tuple itself, e.g.
    ``"['(0, 1)', '(2, 3)']"`` or ``"[(0, 1), (2, 3)]"``.

    Parsing uses ``ast.literal_eval`` rather than ``eval`` so that a malformed
    or malicious data file cannot execute arbitrary code.

    Returns:
        list of tuples, e.g. ``[(0, 1), (2, 3)]``.
    """
    return [ast.literal_eval(item) if isinstance(item, str) else item
            for item in ast.literal_eval(string)]


def load_checkpoint(model, checkpoint_path, map_location=None):
    """ Load checkpoint.

    Restores the ``'state_dict'`` entry of a checkpoint file into ``model``.
    ``nn.DataParallel`` / ``DistributedDataParallel`` wrappers are unwrapped
    so the weights go into the inner module, whose keys carry no ``module.``
    prefix.

    Args:
        model: model to load weights into; modified in place.
        checkpoint_path: path of the checkpoint file.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine. ``None`` keeps
            ``torch.load``'s default behavior.
    Returns:
        the same ``model`` object.
    """
    model_ckpt = torch.load(checkpoint_path, map_location=map_location)
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, wrappers) else model
    target.load_state_dict(model_ckpt['state_dict'])
    return model


def save_checkpoint(model, checkpoint_path):
    """ Save checkpoint.

    Writes ``{'state_dict': ...}`` to ``checkpoint_path``. For models wrapped
    in ``nn.DataParallel`` / ``DistributedDataParallel``, the inner module's
    state dict is saved, so keys carry no ``module.`` prefix and the file can
    be loaded into an unwrapped model.
    """
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    inner = model.module if isinstance(model, wrappers) else model
    torch.save({'state_dict': inner.state_dict()}, checkpoint_path)


def select_by_coord(table, coord):
    """ Select cell value given coord.

    Args:
        table: a ``pd.DataFrame``.
        coord: ``(row, col)`` pair of integer positions (tuple or list). A row
            of ``-1`` selects the column header ``table.columns[col]`` instead
            of a cell.
    Returns:
        the header label or the cell value.
    """
    row, col = coord
    if row == -1:  # column line
        return table.columns[col]
    # ``iat`` needs a scalar row/col pair; unpacking first also accepts lists.
    return table.iat[row, col]


def drop_invalid_samples(data: pd.DataFrame,
                         invalid_files,
                         single_answer=False):
    """ Drop invalid samples in given data.
    Args:
        data: sample data in pd.DataFrame format; needs a ``table_file``
            column and, when ``single_answer`` is set, an
            ``answer_coordinates`` column holding a Python-literal list
            stored as a string.
        invalid_files: invalid table names without the ``.csv`` suffix; rows
            whose ``table_file`` equals ``name + '.csv'`` are dropped.
        single_answer: whether to only allow single cell answer (rows with
            more than one answer coordinate are dropped).
    Return:
        processed sample data as a new DataFrame; the original index labels
        and row order of the kept rows are preserved.
    """
    # Boolean masks instead of repeated ``drop`` by index label: a single pass,
    # and valid rows sharing an index label with an invalid one are kept.
    invalid_tables = {name + '.csv' for name in invalid_files}
    kept = data.loc[~data['table_file'].isin(invalid_tables)]
    if single_answer:
        # literal_eval parses the stored list without executing arbitrary code.
        n_answers = kept['answer_coordinates'].map(lambda s: len(ast.literal_eval(s)))
        # astype(bool) keeps the mask boolean even when ``kept`` is empty.
        kept = kept.loc[(n_answers <= 1).astype(bool)]
    return kept.copy()


def train_step(model, batch, device):
    """ A tapas train forward step given a batch. Return model outputs.

    ``batch`` must provide ``input_ids``, ``attention_mask``,
    ``token_type_ids``, ``labels``, ``numeric_values`` and
    ``numeric_values_scale``. It must also provide either ``float_answer``
    (then the float-cast numeric tensors and the float answer are forwarded)
    or ``aggregation_labels`` (forwarded on its own). Every tensor used is
    moved to ``device`` first.

    Raises:
        ValueError: if neither ``float_answer`` nor ``aggregation_labels`` is
            present in ``batch``.
    """
    def on_device(key):
        return batch[key].to(device)

    model_inputs = {
        "input_ids": on_device("input_ids"),
        "attention_mask": on_device("attention_mask"),
        "token_type_ids": on_device("token_type_ids"),
        "labels": on_device("labels"),
    }
    # Read (and cast to float) unconditionally, even though only the
    # float_answer branch forwards them.
    numeric_values = on_device("numeric_values").float()
    numeric_values_scale = on_device("numeric_values_scale").float()

    if "float_answer" in batch:
        model_inputs.update(
            numeric_values=numeric_values,
            numeric_values_scale=numeric_values_scale,
            float_answer=on_device("float_answer"),
        )
    elif "aggregation_labels" in batch:
        model_inputs["aggregation_labels"] = on_device("aggregation_labels")
    else:
        raise ValueError("Neither weak supervision nor supervision is available in train_step.")
    return model(**model_inputs)


def test_step(model, batch, device):
    """ A tapas test forward step given a batch. Return model outputs.

    Only ``input_ids``, ``attention_mask`` and ``token_type_ids`` are moved to
    ``device`` and passed to ``model``; any other keys in ``batch`` are
    ignored.
    """
    inputs = {key: batch[key].to(device)
              for key in ("input_ids", "attention_mask", "token_type_ids")}
    return model(**inputs)


# The ``test_`` prefix makes pytest collect this helper as a test whenever it
# is imported into a test module (e.g. via ``from ... import *``), where it
# then errors on the unknown ``model``/``batch``/``device`` fixtures.
test_step.__test__ = False


def _aggregate_prediction(table, coords, aggr):
    """Turn predicted cell coordinates into an answer under operator ``aggr``."""
    cells = [naive_str_to_float(select_by_coord(table, coord)) for coord in coords]
    if aggr == 'COUNT':
        # Count the selected cells directly, so an empty selection yields 0
        # instead of counting the EMPTY_TK placeholder added below.
        return len(cells)
    if not cells:
        cells.append(EMPTY_TK)
    try:
        if aggr == 'SUM':
            return sum(cells)
        if aggr == 'AVERAGE':
            return mean(cells)
    except Exception as e:
        # Best effort: non-numeric cells cannot be aggregated; score the raw
        # cell list instead.
        print(f"Error in answer aggregation: {e}")
    return cells


def evaluate(answers_info, data, table_dir):
    """ Evaluate tapas answers.

    Args:
        answers_info: mapping from sample id to ``(coords, aggr_idx)``, where
            ``coords`` are predicted cell coordinates (row ``-1`` selects a
            column header) and ``aggr_idx`` is one of 0 (NONE), 1 (SUM),
            2 (AVERAGE), 3 (COUNT).
        data: sample DataFrame with ``id``, ``table_file``, ``float_answer``
            and ``answer_text`` columns.
        table_dir: prefix prepended verbatim to ``table_file`` to locate the
            table csv.
    Returns:
        ``(accuracy, info)``: the fraction of samples whose prediction gets an
        ``hmt_score`` of 1 (0.0 when ``answers_info`` is empty), and a dict
        whose ``'aggr_numbers'`` counts those hits per aggregation operator.
    """
    id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}
    aggr_numbers = {v: 0 for v in id2aggregation.values()}
    n_hit = 0
    for sample_id, (coords, aggr_idx) in answers_info.items():
        sample = data[data['id'] == sample_id]
        table = pd.read_csv(table_dir + sample['table_file'].values[0]).astype(str)
        # Target answer: the float answer when present, else the answer text.
        float_answer = sample['float_answer'].values[0]
        if np.isnan(float_answer):
            # answer_text is stored as a Python literal; literal_eval avoids
            # executing arbitrary code from the data file.
            target_answer = ast.literal_eval(sample['answer_text'].values[0])
        else:
            target_answer = float_answer
        aggr = id2aggregation[aggr_idx]
        pred_answer = _aggregate_prediction(table, coords, aggr)
        if hmt_score(pred_answer, target_answer) == 1:
            n_hit += 1
            aggr_numbers[aggr] += 1
    info = {
        'aggr_numbers': aggr_numbers
    }
    accuracy = n_hit / len(answers_info) if answers_info else 0.0
    return accuracy, info



def collate_fn_skip_none(batch):
    """ Add "skip none" on default collate_fn.

    Drops ``None`` samples from the top level of ``batch``, then collates the
    rest with PyTorch's ``default_collate``. That puts each data field into a
    tensor with outer dimension batch size.

    Only top-level ``None`` samples are skipped. A ``None`` nested inside a
    sample (e.g. a dict value) raises ``TypeError`` instead of being silently
    dropped, because dropping it would misalign that field with the rest of
    the batch.

    Raises:
        IndexError: if every sample in ``batch`` is ``None``.
        TypeError: if a sample contains a type ``default_collate`` cannot
            handle.
    """
    from torch.utils.data.dataloader import default_collate

    batch = [item for item in batch if item is not None]
    if not batch:
        raise IndexError("collate_fn_skip_none: every sample in the batch is None")
    return default_collate(batch)


def print_if_valid_answers_in_table(samples, table_dir):
    """ DEBUG helper: for each sample row, print (via ``ic``) its stored
    ``answer_coordinates`` and ``answer_text``, followed by the values found
    at those coordinates in its table, so the two can be compared by eye."""
    for _, row in samples.iterrows():
        table = pd.read_csv(table_dir + row['table_file'])
        # ``ic`` echoes the argument expression itself, so the names ``row``
        # and ``actual_answer`` are part of the printed output.
        ic(row.answer_coordinates)
        ic(row.answer_text)
        coordinates = str2tuple(row['answer_coordinates'])
        actual_answer = [select_by_coord(table, coord) for coord in coordinates]
        ic(actual_answer)