import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl

import torchmetrics
from pathlib import Path

from finetune_datasets import FinetuneDataset
from dpc_rnn import DPC_RNN_Finetune
from transforms import *


def neq_load_customized(model, pretrained_model_path):
    """Load the matching subset of a pretrained checkpoint into ``model``.

    The checkpoint must be a mapping with a ``"state_dict"`` entry. A leading
    ``"model."`` prefix is stripped from every checkpoint key, and the result
    is matched against ``model.state_dict()``. Entries whose name and shape
    both match are copied into the model; every other model entry keeps its
    current value. A report of unused and unloaded keys is printed.

    Args:
        model: ``torch.nn.Module`` to load weights into (modified in place).
        pretrained_model_path: path (or file-like) handed to ``torch.load``.

    Returns:
        The same ``model`` object that was passed in.
    """
    prefix = "model."

    # map_location="cpu" lets a checkpoint written from a GPU be read on any
    # machine; load_state_dict() below copies into the model's own tensors.
    ckpt = torch.load(pretrained_model_path, map_location="cpu")

    pretrained_dict = {}
    for k, v in ckpt["state_dict"].items():
        # Strip only the leading wrapper prefix. str.replace() would also
        # delete "model." from the middle of a key (e.g. "model.model.weight"
        # or "my_model.weight") and break the name match.
        name = k[len(prefix):] if k.startswith(prefix) else k
        pretrained_dict[name] = v

    model_dict = model.state_dict()
    tmp = {}
    print("\n=======Check Weights Loading======")
    print("Weights not used from pretrained file:")
    for k, v in pretrained_dict.items():
        if k not in model_dict:
            print(k)
        elif tuple(v.shape) != tuple(model_dict[k].shape):
            # Same name but different shape: copying it would make
            # load_state_dict() raise, so skip it and report it.
            print(
                f"{k} (shape mismatch: checkpoint {tuple(v.shape)} "
                f"vs model {tuple(model_dict[k].shape)})"
            )
        else:
            tmp[k] = v
    print("---------------------------")
    print("Weights not loaded into new model:")
    for k in model_dict:
        if k not in tmp:
            print(k)
    print("===================================\n")
    del pretrained_dict
    model_dict.update(tmp)
    del tmp
    model.load_state_dict(model_dict)
    return model


class PosePretrainingModel(pl.LightningModule):
    """LightningModule that fine-tunes ``DPC_RNN_Finetune`` as a classifier.

    Despite the class name, the network built here is ``DPC_RNN_Finetune``,
    initialised from a pretrained checkpoint and trained with cross-entropy
    against the labels yielded by ``FinetuneDataset``.

    The constructor builds the transforms, both datasets, the model and the
    checkpoint callback; ``fit()`` creates the ``pl.Trainer`` and runs it.

    Recognised ``params`` keys (all optional): ``lr``, ``max_epochs``,
    ``num_workers``, ``batch_size``, ``output_path``, ``gradient_clip_val``.
    """

    def __init__(self, params):
        """Build datasets, model and callbacks from the ``params`` mapping.

        Args:
            params: dict-like configuration; only ``.get`` is used on it.
        """
        super().__init__()
        self.params = params

        # Training adds shear/rotation augmentation on top of the
        # select + normalise steps that validation also uses.
        self.train_transforms = Compose(
            [
                PoseSelect(preset="mediapipe_holistic_minimal_27"),
                CenterAndScaleNormalize(
                    reference_points_preset="shoulder_mediapipe_holistic_minimal_27"
                ),
                ShearTransform(shear_std=0.1),
                RotatationTransform(rotation_std=0.1),
            ]
        )

        self.valid_transforms = Compose(
            [
                PoseSelect(preset="mediapipe_holistic_minimal_27"),
                CenterAndScaleNormalize(
                    reference_points_preset="shoulder_mediapipe_holistic_minimal_27"
                ),
            ]
        )

        # NOTE(review): the paths point at AUTSL data while dataset="include"
        # is passed — confirm this combination is intended.
        self.train_dataset = FinetuneDataset(
            "/home/username/data-disk/datasets/AUTSL/train_labels.csv",
            "/home/username/data-disk/datasets/AUTSL/poses_pickle/train_poses/new_train_poses",
            dataset="include",
            splits=["train"],
            transforms=self.train_transforms,
        )
        # Named "test" but built from the validation labels and served by
        # val_dataloader().
        self.test_dataset = FinetuneDataset(
            "/home/username/data-disk/datasets/AUTSL/validation_labels.csv",
            "/home/username/data-disk/datasets/AUTSL/poses_pickle/val_poses/new_val_poses/",
            dataset="include",
            splits=["val"],
            transforms=self.valid_transforms,
        )
        self.learning_rate = params.get("lr", 1e-4)
        self.max_epochs = params.get("max_epochs", 1)
        self.num_workers = params.get("num_workers", 0)
        self.batch_size = params.get("batch_size", 2)

        self.output_path = Path.cwd() / params.get("output_path", "model-outputs")
        # parents=True so a nested output_path (e.g. "runs/exp1") also works.
        self.output_path.mkdir(parents=True, exist_ok=True)

        self.loss = nn.CrossEntropyLoss()
        self.accuracy_metric = torchmetrics.functional.accuracy

        self.model = DPC_RNN_Finetune(num_class=self.train_dataset.num_class)
        # Initialise from the pretrained checkpoint; keys that do not match
        # are reported by neq_load_customized.
        self.model = neq_load_customized(
            self.model, "DPC-outputs-augs-new-pred-4/epoch=114-step=70954.ckpt"
        )

        self.checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=self.output_path,
            verbose=True,
        )

    def _loss_and_probs(self, batch):
        """Run the model on a batch; return ``(loss, softmax probs, target)``."""
        input_seq, target = batch
        y_hat = self.model(input_seq.float())
        loss = self.loss(y_hat, target)
        return loss, F.softmax(y_hat, dim=-1), target

    def training_step(self, batch, batch_idx):
        """Compute and log training loss and top-1 accuracy for one batch.

        Args:
            batch: ``(input_seq, target)`` pair from the train dataloader.
            batch_idx: index of the batch (unused).

        Returns:
            Dict with ``"loss"`` (used for backprop) and ``"train_acc"``.
        """
        loss, probs, target = self._loss_and_probs(batch)
        acc = self.accuracy_metric(probs, target)
        self.log("train_loss", loss)
        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "train_acc": acc}

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and top-1/3/5 accuracy.

        Args:
            batch: ``(input_seq, target)`` pair from the val dataloader.
            batch_idx: index of the batch (unused).

        Returns:
            Dict with ``"valid_loss"`` and ``"valid_acc"`` (top-1).
        """
        loss, preds, target = self._loss_and_probs(batch)
        acc_top1 = self.accuracy_metric(preds, target)
        acc_top3 = self.accuracy_metric(preds, target, top_k=3)
        acc_top5 = self.accuracy_metric(preds, target, top_k=5)
        self.log("val_loss", loss)
        self.log("val_acc", acc_top1, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_acc_top3", acc_top3, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_acc_top5", acc_top5, on_step=False, on_epoch=True, prog_bar=True)
        return {"valid_loss": loss, "valid_acc": acc_top1}

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training dataset."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return the (unshuffled) DataLoader over the validation dataset."""
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def configure_optimizers(self):
        """Return AdamW over the model parameters plus a cosine LR schedule.

        Returns:
            ``([optimizer], [scheduler_config])`` as expected by Lightning.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)
        # T_max is fixed at 10 and does not follow max_epochs.
        lr_scheduler = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=10, verbose=True
            )
        }
        return [optimizer], [lr_scheduler]

    def fit(self):
        """Create a single-GPU, 16-bit ``pl.Trainer`` and train this module.

        ``gradient_clip_val`` comes from the constructor's ``params``
        (default 1); the trainer is kept on ``self.trainer``.
        """
        self.trainer = pl.Trainer(
            gpus=1,
            precision=16,
            max_epochs=self.max_epochs,
            default_root_dir=self.output_path,
            logger=pl.loggers.WandbLogger(),
            # Read from self.params: self.hparams is only populated by
            # save_hyperparameters(), which this class never calls, so the
            # user-supplied value would otherwise be ignored.
            gradient_clip_val=self.params.get("gradient_clip_val", 1),
            callbacks=[self.checkpoint_callback],
        )
        self.trainer.fit(self)


# Run configuration; PosePretrainingModel reads these keys via params.get().
# NOTE(review): the statements below execute at import time (there is no
# `if __name__ == "__main__":` guard) — confirm that is intended.
hparams = dict(
    lr=1e-3,
    max_epochs=500,
    batch_size=16,
    num_workers=6,
    output_path="DPC-Finetune-outputs",
)

# `trainer` is the LightningModule itself; its fit() builds the pl.Trainer.
trainer = PosePretrainingModel(hparams)
trainer.fit()
