from transformers import TapasTokenizer, TapasConfig, TapasForQuestionAnswering, AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn as nn
from torch.utils.data import DataLoader
from argparse import ArgumentParser

from qa.tapas.dataloader import TableDataset
from qa.tapas.utils import *


def main():
    """Run TAPAS inference on the test split and print evaluation results.

    Configuration is read from the module-level ``args`` namespace that the
    ``__main__`` entry point populates; it is not passed in as a parameter.

    Steps:
      1. Load the tokenizer and the test TSV, wrapped in a ``TableDataset`` and
         a ``DataLoader`` with ``batch_size=1``. The prediction loop relies on
         this: it keeps only element ``[0]`` of each batch's predictions.
      2. Build a ``TapasConfig`` from the CLI hyper-parameters and load the
         pretrained model onto the GPU if one is available. Wrap it in
         ``nn.DataParallel`` when more than one GPU is visible, and optionally
         restore a checkpoint.
      3. Predict answer coordinates and aggregation indices for every sample.
         Store them as ``answers_info[sample_id] = [coords, aggregation_index]``.
      4. Score the predictions with ``evaluate`` and print the
         ``info['aggr_numbers']`` breakdown and the overall accuracy.
    """
    # load data
    tokenizer = TapasTokenizer.from_pretrained(args.model_name)
    test_file = os.path.join(args.root_dir, args.data_dir, args.test_file)
    table_dir = os.path.join(args.root_dir, args.data_dir, args.table_dir)
    test_data = pd.read_csv(test_file, sep='\t')
    test_dataset = TableDataset(test_data, tokenizer, table_dir, phase='test')
    # batch_size=1: the loop below indexes predictions with [0].
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # config
    config = TapasConfig(
        num_aggregation_labels=args.num_aggregation_labels,
        use_answer_as_supervision=args.use_answer_as_supervision,
        answer_loss_cutoff=args.answer_loss_cutoff,
        cell_selection_preference=args.cell_selection_preference,
        huber_loss_delta=args.huber_loss_delta,
        init_cell_selection_weights_to_zero=args.init_cell_selection_weights_to_zero,
        select_one_column=args.select_one_column,
        allow_empty_column_selection=args.allow_empty_column_selection,
        temperature=args.temperature,
    )
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TapasForQuestionAnswering.from_pretrained(args.model_name, config=config).to(device)
    if torch.cuda.device_count() > 1:
        # model = DDP(model)
        model = nn.DataParallel(model)
    if args.ckpt_file:
        model = load_checkpoint(model, os.path.join(args.root_dir, args.data_dir, args.ckpt_dir, args.ckpt_file))

    # inference
    model.eval()
    answers_info = {}
    # Gradients are never used here, so disable autograd. This avoids building
    # a graph and retaining activations on every forward pass.
    # NOTE(review): assumes the external ``test_step`` does not itself need
    # gradients (e.g. no backward pass) — confirm against its definition.
    with torch.no_grad():
        for batch, sample_ids in tqdm(test_dataloader):
            # With batch_size=1, ``sample_ids`` holds exactly one id.
            sample_id = int(sample_ids.item())
            outputs = test_step(model, batch, device)
            predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
                batch,
                outputs.logits.detach().cpu(),
                outputs.logits_aggregation.detach().cpu(),
            )

            answers_info[sample_id] = [predicted_answer_coordinates[0], predicted_aggregation_indices[0]]

    # evaluate
    acc, info = evaluate(answers_info, test_data, table_dir)
    print(json.dumps(info['aggr_numbers'], indent=2))
    print(f"Accuracy: {acc}")


def str2bool(value):
    """Parse a command-line boolean flag value.

    With ``type=bool``, argparse calls ``bool(str)``. That is ``True`` for every
    non-empty string, including ``"False"`` and ``"0"``, so such flags could
    never be switched off from the command line. This converter accepts the
    usual spellings instead.

    Args:
        value: The raw argument string. A ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is not a recognised boolean spelling. argparse
            reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f"invalid boolean value: {value!r}")


def build_parser():
    """Build the command-line parser for TAPAS inference.

    Returns:
        An ``ArgumentParser`` with two groups of options. The path options
        (root/data/checkpoint/table dirs, checkpoint and test file) are joined
        under ``root_dir/data_dir`` by ``main``. The ``TapasConfig``
        hyper-parameter options cover the model name and the aggregation and
        cell-selection settings. Boolean options use ``str2bool`` so that
        ``False`` / ``0`` / ``no`` are honoured.
    """
    parser = ArgumentParser()
    # path
    parser.add_argument('--root_dir', type=str, default='/data/home/hdd3000/USER/HMT/')
    parser.add_argument('--data_dir', type=str, default='qa/data/')
    parser.add_argument('--ckpt_dir', type=str, default='raw_input/tapas_data/checkpoints/')
    parser.add_argument('--ckpt_file', type=str, default=None)
    parser.add_argument('--table_dir', type=str, default='raw_input/tapas_data/tables/')
    parser.add_argument('--test_file', type=str, default='raw_input/tapas_data/test_samples.tsv')
    # config
    parser.add_argument('--model_name', type=str, default='google/tapas-base')
    parser.add_argument('--num_aggregation_labels', type=int, default=4)
    parser.add_argument('--use_answer_as_supervision', type=str2bool, default=True)
    parser.add_argument('--answer_loss_cutoff', type=float, default=0.664694)
    parser.add_argument('--cell_selection_preference', type=float, default=0.207951)
    parser.add_argument('--huber_loss_delta', type=float, default=0.121194)
    parser.add_argument('--init_cell_selection_weights_to_zero', type=str2bool, default=True)
    parser.add_argument('--select_one_column', type=str2bool, default=True)
    parser.add_argument('--allow_empty_column_selection', type=str2bool, default=False)
    parser.add_argument('--temperature', type=float, default=0.0352513)
    return parser


if __name__ == '__main__':
    # ``args`` is deliberately bound at module level: ``main`` reads it as a global.
    args = build_parser().parse_args()

    main()
