import logging
import argparse
import torch
import numpy as np

from collections import defaultdict
from torch.nn.modules.loss import BCEWithLogitsLoss
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam

from models import LinearClassifier
from data.predata import load_data
from tools.metrics import precision_at_k

# All output goes through the root logger, so records from every module share
# these handlers. A console handler is attached here; the script entry point
# adds a file handler using the same formatter once the log path is known.
formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(name)s | %(message)s')

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)

logger = logging.getLogger()
logger.addHandler(console_handler)


class Trainer:
    """Train and evaluate a model with ``BCEWithLogitsLoss`` and ``Adam``.

    Args:
        args: parsed options; ``args.lr`` is used as the Adam learning rate
            and the whole object is kept on ``self.args``.
        model: a ``torch.nn.Module`` mapping a batch of inputs to logits.
            Batches are moved to the device holding the model's parameters.
    """

    def __init__(self, args, model):
        self.args = args
        self.model = model
        # Built first so a parameter-less model still fails with Adam's own
        # "empty parameter list" error, as before.
        self.optimizer = Adam(model.parameters(), lr=args.lr)
        # Follow the model's placement instead of hard-coding device 0, so a
        # CPU model keeps working on a machine where CUDA happens to exist.
        self.device = next(model.parameters()).device
        self.loss_fn = BCEWithLogitsLoss().to(self.device)

    def _to_device(self, *tensors):
        """Return ``tensors`` as a tuple, each moved to ``self.device``."""
        return tuple(t.to(self.device) for t in tensors)

    def train_epoch(self, dataloader):
        """Run one optimisation pass over ``dataloader``.

        Each batch must start with ``(inputs, targets)``; any further items
        (e.g. document ids) are ignored. Targets must be accepted by
        ``BCEWithLogitsLoss`` against the model's logits -- presumably float
        tensors shaped like the logits (TODO confirm against ``load_data``).

        Returns:
            ``(model, loss_sum)``: the trained model and the sum of the
            per-batch losses as a detached tensor (``0`` if the dataloader
            yields no batches).
        """
        self.model.train()
        loss_sum = 0
        for items in dataloader:
            x, y = self._to_device(*items[:2])
            logits = self.model(x)
            loss = self.loss_fn(logits, y)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # detach() so the running sum does not keep each batch's graph alive.
            loss_sum += loss.detach()
        return self.model, loss_sum

    def eval_model(self, dataloader):
        """Return the mean per-document ``precision_at_k`` score.

        Each batch is ``(inputs, targets, doc_ids)``. Predictions are grouped
        by integer document id. ``precision_at_k`` is applied to each
        document's list of ``[logit, target]`` pairs, and the per-document
        scores are averaged with equal weight. Returns ``0.0`` when the
        dataloader contains no documents, instead of NaN.
        """
        self.model.eval()
        results = defaultdict(list)
        # Inference only: no_grad avoids building autograd graphs and keeps the
        # stored logits from holding references to them.
        with torch.no_grad():
            for xs, ys, doc_ids in dataloader:
                xs, ys = self._to_device(xs, ys)
                logits = self.model(xs)
                for logit, y, di in zip(logits, ys, doc_ids):
                    results[int(di)].append([logit, y])

        if not results:
            return 0.0
        p3 = [precision_at_k(pairs) for pairs in results.values()]
        return float(np.average(p3))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("-i", help="input directory", type=str)
    parser.add_argument("--emb", help="embedding file", type=str, default="sent_embedding.jsonl")
    parser.add_argument("--doc", help="document file", type=str, default="document.jsonl")
    parser.add_argument("--log", help="log file", type=str, default="debug.log")
    # type=int on the numeric options: without it a value given on the command
    # line arrives as a str and breaks range(), DataLoader and the model.
    parser.add_argument("--valid-size", help="the size of valid dataset", type=int, default=1000)

    parser.add_argument("-id", help="the dimension of input embedding", type=int, default=1024)

    # training arguments
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--centering", action="store_true")
    parser.add_argument(
        "--early-stop", type=int, default=0,
        help="Stop after this many consecutive epochs without a dev P@3 "
             "improvement. If 0 (by default), there is no early stopping."
    )
    parser.add_argument("-lr", type=float, default=1e-3)
    args = parser.parse_args()

    # The root logger defaults to WARNING, which would silently drop every
    # logger.info() progress message below.
    logger.setLevel(logging.INFO)
    # Mirror console output into the log file with the same format.
    file_handler = logging.FileHandler(args.log, encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # load dataset
    logger.info("Loading datasets...")
    train_xs, train_ys, train_doc_ids, dev_xs, dev_ys, dev_doc_ids = load_data(
        valid_size=args.valid_size,
        input_dir=args.i,
        embfile=args.emb,
        datafile=args.doc,
        centering=args.centering
    )
    train_dataset = TensorDataset(train_xs, train_ys, train_doc_ids)
    dev_dataset = TensorDataset(dev_xs, dev_ys, dev_doc_ids)

    train_dataloader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True
    )
    dev_dataloader = DataLoader(
        dev_dataset, batch_size=args.batch_size
    )
    logger.info("Dataloader is created !")

    # define model
    model = LinearClassifier(args.id)
    if torch.cuda.is_available():
        model = model.to(0)

    # Trainer owns the loss function and the optimizer.
    trainer = Trainer(args, model)

    # training
    best_p3 = 0.0
    epochs_without_improvement = 0
    for epoch_i in range(1, args.epoch + 1):  # 1-based for display
        _, loss = trainer.train_epoch(train_dataloader)
        dev_p3 = trainer.eval_model(dev_dataloader)
        if dev_p3 > best_p3:
            best_p3 = dev_p3
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        logger.info(
            "Epoch {} / {}; loss: {:.3f} P@3: {:.2f}; Best P@3: {:.2f}".format(
                epoch_i,
                args.epoch,
                loss,
                dev_p3 * 100,
                best_p3 * 100
            )
        )
        if args.early_stop > 0 and epochs_without_improvement >= args.early_stop:
            logger.info(
                "No dev P@3 improvement for {} epochs; stopping early.".format(
                    args.early_stop
                )
            )
            break
