# -*- coding: UTF-8 -*-
import torch
import os
from transformers import BertModel,BertTokenizer,BertConfig
from metric import cal_hr
# from dataset_self import get_data
from dataset import get_data
from setup_seed import setup_seed
from SGTA import SGTA, IntentionClassifier
from path import ROOT_DIR

import datetime
from tqdm import tqdm


def train(model, optimizer, criterion, data_loader, device, verbose=False):
    """Run one training epoch over ``data_loader``.

    For every batch the model is called with ``{'seq': seq}`` and must return an
    ``(outputs, loss)`` pair; the optimised objective is
    ``loss + criterion(outputs, labels)``.

    Args:
        model: module whose forward takes a dict with key ``'seq'`` and returns
            ``(outputs, loss)``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, labels)``.
        data_loader: iterable of ``(seq, labels)`` tensor pairs. ``len()`` is
            only required when ``verbose`` is True (for the progress bar).
        device: device each batch is moved to.
        verbose: if True, show a tqdm progress bar and print the
            sample-weighted total loss at the end of the epoch.

    Returns:
        float: mean per-sample loss over the epoch (0.0 if there were no batches).
    """
    model.train()
    loss_sum = 0
    n_samples = 0

    batches = enumerate(data_loader)
    if verbose:
        # enumerate() hides the loader's length, so pass it to tqdm explicitly.
        batches = tqdm(batches, total=len(data_loader))

    for _, (seq, labels) in batches:
        seq, labels = seq.to(device), labels.to(device)
        optimizer.zero_grad()

        # forward & backward
        outputs, loss = model({'seq': seq})
        loss = loss + criterion(outputs.to(device), labels)
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so the epoch mean is per sample even
        # when the last batch is smaller than the others.
        batch_size = seq.shape[0]
        loss_sum += loss.item() * batch_size
        n_samples += batch_size

    if verbose:
        print(f"loss: {loss_sum}")
    return loss_sum / n_samples if n_samples else 0.0


def eval(model, data_loader, device, metrics=None, verbose=False, valid=True):
    """Compute hit-rate@k of ``model`` over ``data_loader``.

    NOTE: this name shadows the built-in ``eval`` inside this module; it is kept
    unchanged because callers use it.

    Args:
        model: module called as ``model({'seq': seq})`` returning ``(outputs, _)``;
            ``outputs`` is ranked along dim 1 via ``torch.topk``.
        data_loader: iterable of ``(seq, labels)`` batches exposing ``.dataset``,
            whose length is used as the normaliser.
        device: device each batch is moved to.
        metrics: cutoffs k to evaluate; defaults to ``[1, 3, 5, 10]``. The list
            does not need to be sorted.
        verbose: print each HR@k when True.
        valid: only chooses the ``valid``/``test`` prefix of printed lines.

    Returns:
        list: one HR@k value per entry of ``metrics``. Each value is 0 when the
        dataset is empty.
    """
    if metrics is None:
        metrics = [1, 3, 5, 10]

    model.eval()
    hrs = [0 for _ in metrics]
    with torch.no_grad():
        for seq, labels in data_loader:
            seq, labels = seq.to(device), labels.to(device)
            outputs, _ = model({'seq': seq})

            # Retrieve enough candidates for the largest cutoff (metrics need not
            # be sorted), but never more than exist, or topk raises.
            max_k = min(max(metrics), outputs.shape[1])
            # topk returns indices sorted by score, so a prefix slice is the top-k.
            ranked = torch.topk(outputs, k=max_k, dim=1)[1].cpu().numpy()
            labels_np = labels.cpu().numpy()
            for i, k in enumerate(metrics):
                # cal_hr presumably returns the hit count for the batch (the sum is
                # divided by the dataset size below) - confirm in metric.py.
                hrs[i] += cal_hr(ranked[:, :k], labels_np)

        n_samples = len(data_loader.dataset)
        if n_samples:
            hrs = [hr / n_samples for hr in hrs]

        if verbose:
            split = 'valid' if valid else 'test'
            for k, hr in zip(metrics, hrs):
                print(f'{split}, HR@{k}: {hr:.4f}')
    return hrs


def trainer(config):
    """Train SGTA and report test HR@k at the best validation epoch.

    After every epoch the model is evaluated on the validation and test splits.
    Whenever the selected validation metric improves, the whole model is saved
    with ``torch.save`` and that epoch's test HRs are remembered.

    Config keys read here: ``device``, ``verbose``, ``metrics``, ``seq_length``,
    ``front_padding``, ``batch_size``, ``lr``, ``reg`` (Adam weight decay) and
    ``epoch``. There are also optional keys:

    - ``seed``: default 2021.
    - ``select_metric_idx``: index into ``metrics`` used for model selection;
      default 1, which is HR@3 for ``[1, 3, 5, 10]``.
    - ``model_path``: default ``{ROOT_DIR}/model.pth``.

    The full dict is also passed to ``SGTA``.

    Returns:
        list: test HR@k values (aligned with ``config["metrics"]``) from the best
        validation epoch; empty when no epoch ran.
    """
    setup_seed(config.get("seed", 2021))
    device = config["device"]
    verbose = config["verbose"]
    metrics = config["metrics"]
    select_idx = config.get("select_metric_idx", 1)
    model_path = config.get("model_path", f"{ROOT_DIR}/model.pth")
    n_epochs = config["epoch"]
    batch_size = config["batch_size"]

    def timestamp():
        return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    # get dataset
    train_dataset, valid_dataset, test_dataset = get_data(seq_length=config["seq_length"],
                                                          front_padding=config["front_padding"])

    # get dataloader
    train_loader = train_dataset.get_data_loader(device=device, batch_size=batch_size, shuffle=True)
    valid_loader = valid_dataset.get_data_loader(device=device, batch_size=batch_size, shuffle=False)
    test_loader = test_dataset.get_data_loader(device=device, batch_size=batch_size, shuffle=False)

    # model (IntentionClassifier(config) is an alternative model also imported here)
    model = SGTA(config).to(device)

    # optimizer & loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], weight_decay=config["reg"], amsgrad=True)
    criterion = torch.nn.CrossEntropyLoss()

    # Start below any attainable HR so the first epoch always yields a checkpoint
    # and a result row (previously an all-zero validation HR left no results and
    # the final report raised IndexError).
    best_hr = float("-inf")
    best_test_hrs = []
    for epoch in range(n_epochs):
        if verbose:
            print(f"epoch: {epoch}/{n_epochs}")
            print("Start Training:" + timestamp())

        train(model, optimizer, criterion, train_loader, device, verbose=verbose)

        if verbose:
            print("Start Validating:" + timestamp())
        valid_hrs = eval(model, valid_loader, device, metrics=metrics, verbose=verbose, valid=True)

        if verbose:
            print("Start Testing:" + timestamp())
        test_hrs = eval(model, test_loader, device, metrics=metrics, verbose=verbose, valid=False)

        if valid_hrs[select_idx] > best_hr:
            best_hr = valid_hrs[select_idx]
            # Pickles the whole module, so loading it later needs SGTA importable;
            # kept (rather than a state_dict) so existing loaders keep working.
            torch.save(model, model_path)
            best_test_hrs = test_hrs

    if verbose:
        print(f'\nresults:')
        for k, hr in zip(metrics, best_test_hrs):
            print(f'test, HR@{k}: {hr:.4f} ')
    return best_test_hrs


if __name__ == '__main__':
    # Uncomment to restrict which GPUs this process can see.
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

    config = {
        "item_num": 2570 + 1,  # presumably 2570 items plus one padding id - TODO confirm
        "seq_length": 20,  # passed to get_data
        "epoch": 50,  # number of training epochs
        "batch_size": 128,
        "metrics": [1, 3, 5, 10],  # HR@k cutoffs reported by eval()
        "front_padding": True,  # passed to get_data
        "hidden_size": 768,
        "lr": 1e-5,  # Adam learning rate
        "reg": 0,  # Adam weight_decay
        "n_layers": 2,
        "n_heads": 1,
        "dropout_prob": 0.1,
        # Fall back to CPU so the script still starts on machines without CUDA.
        "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
        "verbose": True,
        "report": False,
        "bert_path": "pretrain_model/wwm_ext/",
        "bert_config": "pretrain_model/wwm_ext/config.json"
    }
    print(config)
    trainer(config)
