import os
import sys
import json
import torch

import numpy as np
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score

current_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(current_path, ".."))

from data_utils.net import ConvNet
from data_utils.sensor_data_loader import SensorDataset


class AuxTrainer:
    """Trainer class for training models on sensor data.

    ``data_dir`` must contain ``meta_data.json`` (read for its ``classes`` and
    ``num_channels`` entries) plus the ``HF/train`` and ``HF/test`` splits that
    are loaded through ``SensorDataset``.
    """

    def __init__(self, data_dir, model="ConvNet", device="cuda"):
        """Load dataset metadata and build the model.

        Args:
            data_dir: directory holding ``meta_data.json`` and the ``HF`` splits.
            model: architecture name; only ``"ConvNet"`` is supported.
            device: torch device for the model and batches. Defaults to
                ``"cuda"``, equivalent to the previous hard-coded ``.cuda()``.

        Raises:
            ValueError: if ``model`` is not a supported architecture.
        """
        self.data_dir = data_dir
        self.device = torch.device(device)
        metadata_path = os.path.join(data_dir, "meta_data.json")
        # Context manager so the metadata file handle is closed promptly.
        with open(metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)
        classes = self.metadata["classes"]
        # Raw class label -> integer index (its position in the metadata list), and back.
        self.mapping = {label: idx for idx, label in enumerate(classes)}
        self.mapping_reversed = dict(enumerate(classes))
        # The test split is loaded on first evaluation and reused afterwards,
        # since test() is called repeatedly during training.
        self._test_dataset = None

        num_channels = self.metadata["num_channels"]
        num_classes = len(classes)
        if model == "ConvNet":
            self.model = ConvNet(num_channels, num_classes).to(self.device)
        else:
            raise ValueError("Model not supported.")

    def _encode_labels(self, labels) -> torch.Tensor:
        """Map a batch of raw class labels to a long tensor of indices on ``self.device``."""
        return torch.tensor(
            [self.mapping[label] for label in labels],
            dtype=torch.long,
            device=self.device,
        )

    def _get_test_dataset(self):
        """Return the test split, constructing ``SensorDataset`` only on first use."""
        if getattr(self, "_test_dataset", None) is None:
            self._test_dataset = SensorDataset(os.path.join(self.data_dir, "HF", "test"))
        return self._test_dataset

    def _predict(self, dataloader):
        """Return ``(ground_truth, predictions)`` index arrays over all batches.

        The model is put in eval mode for the pass and restored to whatever
        mode it was in before, so evaluating mid-training does not leave any
        mode-dependent layers (e.g. dropout / batch-norm) in inference mode.
        Both arrays are empty if ``dataloader`` yields no batches.
        """
        was_training = self.model.training
        self.model.eval()
        preds, targets = [], []
        try:
            with torch.no_grad():
                for inputs, labels in dataloader:
                    inputs = inputs.float().to(self.device)
                    targets.append(self._encode_labels(labels).cpu().numpy())
                    preds.append(self.model(inputs).argmax(1).cpu().numpy())
        finally:
            self.model.train(was_training)

        if not preds:
            # np.concatenate([]) would raise; report "no samples" explicitly.
            return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
        return np.concatenate(targets), np.concatenate(preds)

    def train(
        self,
        lr: float = 0.001,
        batch_size: int = 32,
        num_epochs: int = 100,
        log_interval: int = 50,
    ) -> None:
        """Train the model on the training split with Adam and cross-entropy loss.

        Every ``log_interval`` batches (including the first batch of each epoch)
        the current batch loss/accuracy are printed and ``test()`` is run on
        the test split; the model is returned to training mode afterwards.

        Raises:
            ValueError: if ``log_interval`` is not positive.
        """
        if log_interval <= 0:
            raise ValueError("log_interval must be a positive integer.")

        self.model.train()
        train_dataset = SensorDataset(os.path.join(self.data_dir, "HF", "train"))
        train_dataloader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )
        num_batches = len(train_dataloader)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        for epoch in range(num_epochs):
            for i, (inputs, labels) in enumerate(train_dataloader):
                inputs = inputs.float().to(self.device)
                labels = self._encode_labels(labels)

                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                if i % log_interval == 0:
                    # Batch accuracy is only needed for logging, so compute it
                    # here instead of copying to the CPU on every step.
                    acc = (outputs.argmax(1) == labels).float().mean().item()
                    print(
                        f"[Epoch {epoch+1} ({i+1}/{num_batches})] loss: {loss.item():.3f}\tacc: {acc:.3f}"
                    )
                    self.test()
                    # Evaluation must not leave the model in eval mode while
                    # training continues (also guards overridden test()).
                    self.model.train()

    def test(self, batch_size: int = 32) -> None:
        """Print macro F1 score and accuracy of the model on the test set.

        The model's train/eval mode is the same after the call as before it.

        Args:
            batch_size: evaluation batch size (default 32).
        """
        test_dataloader = DataLoader(
            self._get_test_dataset(), batch_size=batch_size, shuffle=False
        )
        gt, pred = self._predict(test_dataloader)
        if gt.size == 0:
            print("Test set is empty; no metrics computed.")
            return
        print(f"F1 score: {f1_score(gt, pred, average='macro')}")
        print(f"Accuracy: {accuracy_score(gt, pred)}")

    def store_model(self, out_dir: str) -> None:
        """Save the model's ``state_dict`` to ``<out_dir>/model.pth``, creating ``out_dir`` if needed."""
        os.makedirs(out_dir, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(out_dir, "model.pth"))


if __name__ == "__main__":
    # Root directory with one sub-directory per dataset. Defaults to the
    # placeholder below; pass a real root as the first command-line argument.
    data_root = sys.argv[1] if len(sys.argv) > 1 else "path_to_data"
    datasets = ["HHAR"]
    for dataset in datasets:
        dataset_dir = f"{data_root}/{dataset}"
        at = AuxTrainer(dataset_dir)
        at.train(num_epochs=2, lr=0.001, batch_size=64)
        at.test()
        # Weights are saved as model.pth inside the dataset's own directory.
        at.store_model(dataset_dir)
