import logging
import argparse
import torch
import numpy as np
import os

from collections import defaultdict
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam
# from tools.ioFn import readJsonl
from models import LinearClassifier
from tools.metrics import precision_at_k

# Console logging for the whole script: every record sent to the root logger
# is echoed to the stream handler as "time | level | logger name | message".
formatter = logging.Formatter(
    fmt='%(asctime)s | %(levelname)s | %(name)s | %(message)s'
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)  # the formatter attribute could also be assigned directly

logger = logging.getLogger()
logger.addHandler(console_handler)

def load_x_data(input_dir, system, embfile, split, lang):
    """Load one embedding matrix from disk as a float32 tensor.

    The file is read from ``<input_dir>/<system>_<split>_<lang>/<embfile>``.

    Args:
        input_dir: Root directory that contains the per-split sub-directories.
        system: Prefix of the sub-directory name.
        embfile: Name of the ``.npy`` file inside the sub-directory.
        split: Split name used in the sub-directory name (e.g. "dev").
        lang: Language code used in the sub-directory name.

    Returns:
        A ``torch.float32`` tensor holding the array stored in the file.
    """
    data_dir = os.path.join(
        input_dir,
        "{}_{}_{}".format(system, split, lang)
    )
    data_file = os.path.join(data_dir, embfile)
    # Use a context manager so the file handle is closed as soon as the
    # array has been read, instead of being left for the garbage collector.
    with open(data_file, 'rb') as data_fh:
        embeds = np.load(data_fh)
    return torch.tensor(embeds, dtype=torch.float32)


class Trainer:
    """Train and evaluate a classifier with cross-entropy loss and Adam.

    When CUDA is available, every batch (and the loss module) is moved to
    GPU 0. The trainer never moves the model itself, so the caller must have
    placed ``model`` on the matching device beforehand.
    """

    def __init__(self, args, model):
        """Store the model and build the loss function and optimizer.

        Args:
            args: Object exposing an ``lr`` attribute (Adam learning rate).
            model: Module mapping an input batch to per-class logits.
        """
        self.args = args
        self.loss_fn = CrossEntropyLoss()
        if torch.cuda.is_available():
            self.loss_fn.to(0)
        self.optimizer = Adam(
            model.parameters(),
            lr=args.lr
        )
        self.model = model

    @staticmethod
    def _to_device(*tensors):
        """Return the tensors moved to GPU 0 if CUDA is available, else as is."""
        if torch.cuda.is_available():
            return tuple(t.to(0) for t in tensors)
        return tensors

    def train_epoch(self, dataloader):
        """Run one optimisation pass over ``dataloader``.

        Args:
            dataloader: Iterable of batches whose first two items are the
                inputs and the integer class labels.

        Returns:
            A ``(model, loss_sum)`` tuple. ``loss_sum`` is the sum of the
            per-batch losses (a detached tensor, or the int ``0`` when the
            dataloader yields no batch).
        """
        self.model.train()
        loss_sum = 0
        for items in dataloader:
            x, y = self._to_device(*items[:2])
            logits = self.model(x)
            loss = self.loss_fn(logits, y)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Detach so the running sum does not keep autograd graphs alive.
            loss_sum += loss.detach()
        return self.model, loss_sum

    def eval_model(self, dataloader):
        """Compute classification accuracy over ``dataloader``.

        Args:
            dataloader: Iterable of batches whose first two items are the
                inputs and the integer class labels.

        Returns:
            The fraction of correctly predicted samples as a Python float,
            or ``0.0`` when the dataloader yields no samples.
        """
        self.model.eval()
        correct = 0
        total = 0
        # Gradients are not needed for evaluation; skipping them avoids
        # building autograd graphs for every forward pass.
        with torch.no_grad():
            for items in dataloader:
                xs, ys = self._to_device(*items[:2])
                preds = torch.argmax(self.model(xs), dim=-1).reshape(-1)
                # Vectorised comparison; also works for a single sample,
                # where squeezing would collapse the tensor to 0-d.
                correct += (preds == ys.reshape(-1)).sum().item()
                total += preds.numel()

        if total == 0:
            return 0.0
        return correct / total

# Number of leading rows per language used for training; the remaining rows
# of the same embedding file form the dev set.
_TRAIN_ROWS_PER_LANG = 450


def _parse_args():
    """Parse the command-line options of the training script."""
    parser = argparse.ArgumentParser()

    # The input directory is joined into every data path, so it is mandatory.
    parser.add_argument("-i", help="input directory", type=str, required=True)
    parser.add_argument("--emb", help="embedding file", type=str, default="document_embedding.npy")
    parser.add_argument("--log", help="log file", type=str, default="proj_debug.log")
    parser.add_argument("--system", help="the prefix of system", type=str, default="proj")
    # Numeric options need an explicit type: without it, values given on the
    # command line arrive as strings and break range()/DataLoader/the model.
    parser.add_argument("-id", help="the dimension of input embedding", type=int, default=1024)
    parser.add_argument("--langs", nargs="+", type=str, default=["en", "zh", "fr"])
    # training arguments
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--epoch", type=int, default=100)
    # NOTE(review): --early-stop is parsed but not acted upon by the
    # training loop below.
    parser.add_argument(
        "--early-stop", type=int, default=0,
        help="If 0 (by default), there is no early stopping "
    )
    parser.add_argument("-lr", type=float, default=1e-3)
    return parser.parse_args()


def _build_dataloaders(args):
    """Build the train and dev dataloaders; labels are indices into ``args.langs``."""
    all_train_xs, all_train_ys = [], []
    all_dev_xs, all_dev_ys = [], []
    for label, lang in enumerate(args.langs):
        # NOTE: both partitions are currently carved out of the "dev" split
        # file (instead of loading separate "train" and "dev" files): the
        # first _TRAIN_ROWS_PER_LANG rows train the model, the rest evaluate
        # it. A language with no more rows than that contributes nothing to
        # the dev set.
        xs = load_x_data(
            input_dir=args.i,
            system=args.system,
            embfile=args.emb,
            split="dev",
            lang=lang
        )
        # Every row of this file gets the same label: the language's index.
        ys = torch.full((xs.size(0),), label, dtype=torch.long)

        all_train_xs.append(xs[:_TRAIN_ROWS_PER_LANG])
        all_train_ys.append(ys[:_TRAIN_ROWS_PER_LANG])
        all_dev_xs.append(xs[_TRAIN_ROWS_PER_LANG:])
        all_dev_ys.append(ys[_TRAIN_ROWS_PER_LANG:])

    train_dataset = TensorDataset(
        torch.cat(all_train_xs, dim=0), torch.cat(all_train_ys, dim=0)
    )
    dev_dataset = TensorDataset(
        torch.cat(all_dev_xs, dim=0), torch.cat(all_dev_ys, dim=0)
    )

    train_dataloader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True
    )
    dev_dataloader = DataLoader(
        dev_dataset, batch_size=args.batch_size
    )
    return train_dataloader, dev_dataloader


def main():
    """Train a linear language classifier and log the dev accuracy per epoch."""
    args = _parse_args()

    # The root logger defaults to WARNING, which would silently drop the
    # INFO progress messages below.
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(args.log)
    file_handler.setFormatter(formatter)  # the formatter attribute could also be assigned directly
    logger.addHandler(file_handler)

    # load dataset
    logger.info("Loading datasets...")
    train_dataloader, dev_dataloader = _build_dataloaders(args)
    logger.info("Dataloader is created !")

    # define model
    model = LinearClassifier(
        input_dim=args.id,
        output_dim=len(args.langs)
    )
    if torch.cuda.is_available():
        model = model.to(0)

    # define optimizer
    trainer = Trainer(args, model)
    # training
    best_acc = 0.0
    for epoch_i in range(args.epoch):
        _, loss = trainer.train_epoch(train_dataloader)
        dev_acc = trainer.eval_model(dev_dataloader)
        best_acc = max(dev_acc, best_acc)
        logging_str = "Epoch {} / {}; loss: {:.3f} Acc: {:.2f}; Best Acc: {:.2f}".format(
            epoch_i,
            args.epoch,
            loss,
            dev_acc * 100,
            best_acc * 100
        )
        logger.warning(logging_str)


if __name__ == "__main__":
    main()


# Example invocation:
# python3 main.py -i /home/tiger/xgiga_dumpEmb_proj/embedding --log proj_debug.log --system proj_ln_3 --langs en zh