from transformers import TapasTokenizer, TapasConfig, TapasForQuestionAnswering, AdamW, TapasModel
from torch.utils.data import DataLoader
import torch.nn as nn
from argparse import ArgumentParser

from qa.tapas.dataloader import TableDataset
from qa.tapas.utils import *

# INVALID_TABLE_IDS = ['80', '1401', '1243', '456',  # FIXME
#                      '345', '1245', '1130', '1495',
#                      '1224', '501', '358', '1014',
#                      '1092']


def _train_one_epoch(model, dataloader, optimizer, device, epoch):
    """Run one optimisation pass over `dataloader`, printing the loss of every batch."""
    model.train()
    for idx, (batch, _) in enumerate(tqdm(dataloader)):
        optimizer.zero_grad()
        outputs = train_step(model, batch, device)
        loss = outputs.loss
        # The loss may be a 0-dim scalar or a vector (e.g. one entry per replica
        # when the model is wrapped in nn.DataParallel). `loss.size(0)` raises on
        # a 0-dim tensor, so check the rank instead before averaging.
        if loss.dim() > 0:
            loss = loss.mean()
        loss.backward()
        optimizer.step()
        print(f"epoch#{epoch}, batch#{idx}, loss:{loss}")


def _predict_dev(model, dataloader, tokenizer, device):
    """Return `{sample id: [answer coordinates, aggregation index]}` for every batch of `dataloader`.

    Only the first prediction of each batch is kept, so the dataloader is
    expected to yield batches of size 1 (as the dev dataloader in `main` does).
    """
    model.eval()
    answers_info = {}
    # Gradients are never used during evaluation; disabling them avoids
    # building the autograd graph for every dev sample.
    with torch.no_grad():
        for batch, sample_id in tqdm(dataloader):
            sample_id = int(sample_id.cpu().detach().item())
            outputs = test_step(model, batch, device)
            answer_coordinates, aggregation_indices = tokenizer.convert_logits_to_predictions(
                batch,
                outputs.logits.cpu().detach(),
                outputs.logits_aggregation.cpu().detach()
            )
            answers_info[sample_id] = [answer_coordinates[0], aggregation_indices[0]]
    return answers_info


def main():
    """Fine-tune a TAPAS question-answering model and evaluate it after every epoch.

    All settings are read from the module-level `args` namespace. A checkpoint
    is written every 10 epochs (`model.<epoch>.pt`) and whenever the dev
    accuracy improves (`model.best.pt`); each new best accuracy is also
    recorded in the log file inside the checkpoint directory.
    """
    # load file
    tokenizer = TapasTokenizer.from_pretrained(args.model_name)
    # tokenizer.update_answer_coordinates = True
    data_root = os.path.join(args.root_dir, args.data_dir)
    table_dir = os.path.join(data_root, args.table_dir)
    ckpt_dir = os.path.join(data_root, args.ckpt_dir)
    # Checkpoints and the log file are written here; create the directory up
    # front so a missing folder does not abort the run.
    os.makedirs(ckpt_dir, exist_ok=True)

    train_data = pd.read_csv(os.path.join(data_root, args.train_file), sep='\t')
    train_dataset = TableDataset(train_data, tokenizer, table_dir, phase='train', supervise=args.supervise)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn_skip_none)

    dev_data = pd.read_csv(os.path.join(data_root, args.dev_file), sep='\t')
    dev_dataset = TableDataset(dev_data, tokenizer, table_dir, phase='test')
    dev_dataloader = DataLoader(dev_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn_skip_none)

    # config
    config = TapasConfig(
        num_aggregation_labels=args.num_aggregation_labels,
        use_answer_as_supervision=args.use_answer_as_supervision,
        answer_loss_cutoff=args.answer_loss_cutoff,
        cell_selection_preference=args.cell_selection_preference,
        huber_loss_delta=args.huber_loss_delta,
        init_cell_selection_weights_to_zero=args.init_cell_selection_weights_to_zero,
        select_one_column=args.select_one_column,
        allow_empty_column_selection=args.allow_empty_column_selection,
        temperature=args.temperature,
    )
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TapasForQuestionAnswering.from_pretrained(args.model_name, config=config).to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    if args.ckpt_file:
        # The checkpoint is loaded into the (possibly DataParallel-wrapped) model.
        model = load_checkpoint(model, os.path.join(ckpt_dir, args.ckpt_file))
    optimizer = AdamW(model.parameters(), lr=args.lr)

    # train
    best_val_acc = 0
    # The context manager guarantees the log is closed even if training crashes.
    with open(os.path.join(ckpt_dir, args.log_file), 'w') as log_file:
        for epoch in range(1, args.max_epochs + 1):
            _train_one_epoch(model, train_dataloader, optimizer, device, epoch)

            if epoch % 10 == 0:
                print(f"save model at epoch#{epoch}")
                save_checkpoint(model, os.path.join(ckpt_dir, f'model.{epoch}.pt'))

            answers_info = _predict_dev(model, dev_dataloader, tokenizer, device)
            acc, info = evaluate(answers_info, dev_data, table_dir)
            if acc > best_val_acc:
                best_val_acc = acc
                save_checkpoint(model, os.path.join(ckpt_dir, "model.best.pt"))
                print(f"new best model at epoch#{epoch}, acc={acc}", file=log_file)
                # Flush so the record survives an interrupted run.
                log_file.flush()
                print(f"new best model at epoch#{epoch}, acc={acc}")
            print(json.dumps(info['aggr_numbers'], indent=2))
            print(f"Accuracy: {acc}, best accuracy: {best_val_acc}")


def str2bool(value):
    """Convert a command-line token to a real boolean.

    argparse's `type=bool` calls `bool()` on the raw string, so any non-empty
    value -- including "False" -- becomes True. This converter understands the
    usual spellings instead.

    Args:
        value: A bool (returned unchanged) or a string such as "true", "False",
            "yes", "0". Matching is case-insensitive and ignores surrounding
            whitespace. An empty string maps to False, which is what
            `type=bool` produced for it.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If the text is not a recognised boolean spelling. argparse
            reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = ArgumentParser()
    # path
    parser.add_argument('--root_dir', type=str, default='/data/home/hdd3000/USER/HMT/')
    parser.add_argument('--data_dir', type=str, default='qa/data/')
    parser.add_argument('--ckpt_dir', type=str, default='raw_input/tapas_data/checkpoints/')
    parser.add_argument('--ckpt_file', type=str, default=None)
    parser.add_argument('--table_dir', type=str, default='raw_input/tapas_data/tables/')
    parser.add_argument('--train_file', type=str, default='raw_input/tapas_data/train_samples.tsv')
    parser.add_argument('--dev_file', type=str, default='raw_input/tapas_data/dev_samples.tsv')
    parser.add_argument('--log_file', type=str, default='log.txt')
    # config
    parser.add_argument('--model_name', type=str, default='google/tapas-base')
    parser.add_argument('--supervise', action='store_true')
    parser.add_argument('--num_aggregation_labels', type=int, default=4)
    # Boolean options go through str2bool so that e.g. "--select_one_column False"
    # really yields False.
    parser.add_argument('--use_answer_as_supervision', type=str2bool, default=True)
    parser.add_argument('--answer_loss_cutoff', type=float, default=0.664694)
    parser.add_argument('--cell_selection_preference', type=float, default=0.207951)
    parser.add_argument('--huber_loss_delta', type=float, default=0.121194)
    parser.add_argument('--init_cell_selection_weights_to_zero', type=str2bool, default=True)
    parser.add_argument('--select_one_column', type=str2bool, default=True)
    parser.add_argument('--allow_empty_column_selection', type=str2bool, default=False)
    parser.add_argument('--temperature', type=float, default=0.0352513)
    # learning
    parser.add_argument('--max_epochs', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=5e-5)
    # `args` stays a module-level name because main() reads it as a global.
    args = parser.parse_args()
    if args.supervise:
        # --supervise overrides both --use_answer_as_supervision and --train_file.
        args.use_answer_as_supervision = False
        args.train_file = 'raw_input/tapas_data/train_samples_sup.tsv'

    main()
