import sys

from qa.tapas.utils import *
from qa.table.utils import *


class TableDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one (question, table) pair per item.

    Each row of ``data`` names a CSV file through its ``table_file`` field;
    the file is read from ``table_path + table_file`` on every access and
    passed, together with the row's ``question``, to ``tokenizer``.

    Args:
        data: pandas DataFrame with one question per row. Fields read here:
            ``id``, ``table_file``, ``question``; in the ``'train'`` phase also
            ``answer_coordinates``, ``answer_text`` and either ``float_answer``
            (``supervise=False``) or ``aggregation_labels`` (``supervise=True``).
        tokenizer: callable invoked with ``table=``, ``queries=`` and, when
            training, ``answer_coordinates=`` / ``answer_text=``; must return a
            mapping of tensors carrying a leading batch dimension.
            NOTE(review): presumably a Hugging Face ``TapasTokenizer`` - confirm.
        table_path: string prefix concatenated (not path-joined) with
            ``table_file``, so it must already end with a path separator.
        phase: ``'train'`` or ``'test'``.
        supervise: when true, training items carry ``aggregation_labels``
            instead of ``float_answer``.
    """

    def __init__(self, data, tokenizer, table_path, phase, supervise=False):
        self.data = data
        self.tokenizer = tokenizer
        self.table_path = table_path
        self.phase = phase
        self.supervise = supervise

    def __getitem__(self, idx):
        """Return ``(encoding, item_id)`` for row ``idx``.

        Returns:
            A tuple of the encoding dict (batch dimension squeezed away) and
            the row's ``id``, or ``None`` when reading or tokenizing the table
            fails; the failure is printed and callers are expected to skip
            ``None`` items.

        Raises:
            TypeError: if ``self.phase`` is neither ``'train'`` nor ``'test'``.
            IndexError: if ``idx`` is out of range for ``data``.
        """
        # A wrong phase is a configuration error, not a per-table failure, so
        # it is raised outside the best-effort ``try`` below instead of being
        # swallowed and reported as a broken table for every sample.
        if self.phase not in ('train', 'test'):
            raise TypeError("Only ['train','test'] phase is available.")

        # Looked up outside the ``try`` so that an out-of-range index surfaces
        # as IndexError, and so ``item`` is always bound when the handler below
        # formats its message.
        item = self.data.iloc[idx]

        try:
            # Cast every cell to ``str`` so the tokenizer sees text-only tables.
            table = pd.read_csv(self.table_path + item.table_file).astype(str)

            if self.phase == 'train':
                # Parsed into a local instead of being written back onto the row.
                answer_coordinates = str2tuple(item.answer_coordinates)
                encoding = self.tokenizer(
                    table=table,
                    queries=item.question,
                    answer_coordinates=answer_coordinates,
                    answer_text=item.answer_text,  # string, matched against table cells
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt",
                )
            else:
                encoding = self.tokenizer(
                    table=table,
                    queries=item.question,
                    truncation=True,
                    padding='max_length',
                    return_tensors="pt",
                )

            # Remove the batch dimension which the tokenizer adds by default.
            encoding = {key: val.squeeze(0) for key, val in encoding.items()}

            if self.phase == 'train':
                if not self.supervise:
                    # Weak supervision for the aggregation case. Always stored
                    # as a 0-d tensor (NaN when the value is not a float) so the
                    # value type under this key is the same for every sample.
                    float_answer = item.float_answer
                    encoding["float_answer"] = torch.tensor(
                        float_answer if isinstance(float_answer, float) else np.nan
                    )
                else:
                    encoding["aggregation_labels"] = torch.tensor(item.aggregation_labels)

            return encoding, item.id
        except Exception as e:
            # Deliberate best-effort: report the broken sample and hand back
            # ``None`` rather than aborting the whole run.
            print(f"Error occurs on table {item.table_file}: {e} ")
            return None

    def __len__(self):
        """Return the number of rows in ``data``."""
        return len(self.data)
