import json
import glob
import os
import numpy as np
import pandas as pd

import torch
import pickle
from sklearn.preprocessing import LabelEncoder


class FinetuneDataset(torch.utils.data.Dataset):
    """Pose-keypoint dataset that yields fixed-length clip stacks with labels.

    An index file is parsed by one of the dataset-specific readers
    ("include", "wlasl", "autsl" or "csl"), which fill ``self.data`` with
    ``(relative_path, class_index)`` tuples and ``self.glosses`` with the
    class vocabulary. Each item is loaded from a pickled pose file, padded or
    truncated to ``seq_len * num_seq`` frames and split into ``num_seq``
    consecutive clips of ``seq_len`` frames.

    Args:
        split_file: Path to the index file (CSV or JSON, depending on
            ``dataset``).
        root_dir: Directory that the stored sample paths are relative to.
        dataset: Which reader to use: "include", "wlasl", "autsl" or "csl".
        seq_len: Number of frames per clip.
        num_seq: Number of clips per sample.
        splits: Split names to keep (used by the "wlasl" and "csl" readers).
        modality: "rgb" and/or "pose"; selects file names / existence checks
            in the readers. Items are always loaded as pose data.
        class_mappings_file_path: CSV with ``ClassId`` and ``TR`` columns;
            only read by the "autsl" reader.
        transforms: Optional callable applied to ``{"frames": tensor}`` with
            the tensor laid out as (C, T, V, M).

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """

    def __init__(
        self,
        split_file,
        root_dir,
        dataset="include",
        seq_len=10,
        num_seq=6,
        splits=("train",),
        modality="pose",
        class_mappings_file_path=None,
        transforms=None,
    ):
        self.data = []
        self.glosses = []
        self.root_dir = root_dir
        self.class_mappings_file_path = class_mappings_file_path
        self.seq_len = seq_len
        self.num_seq = num_seq
        self.transforms = transforms

        readers = {
            "include": self.read_index_file_include,
            "wlasl": self.read_index_file_wlasl,
            "autsl": self.read_index_file_autsl,
            "csl": self.read_index_file_csl,
        }
        if dataset not in readers:
            # Fail loudly instead of silently building an empty dataset.
            raise ValueError(
                f"Unknown dataset {dataset!r}; expected one of {sorted(readers)}"
            )
        readers[dataset](split_file, splits, modality)

    def read_index_file_include(self, index_file_path, splits, modality="rgb"):
        """Read a CSV index with ``Word`` and ``FilePath`` columns.

        `splits` is not used here as we pass the split-specific CSV directly.

        Rows are skipped when the video file is missing (only checked if
        "rgb" is in ``modality``) or when the path contains
        "/Second (Number)/".

        Raises:
            SystemExit: If no usable row remains.
        """
        df = pd.read_csv(index_file_path)

        # Strip once and reuse, so that the vocabulary and the per-row lookup
        # agree even when the CSV has stray whitespace around a word.
        words = [word.strip() for word in df["Word"]]
        self.glosses = sorted(set(words))
        # Class index == position in the sorted, de-duplicated vocabulary.
        gloss_to_index = {gloss: idx for idx, gloss in enumerate(self.glosses)}

        for word, file_path in zip(words, df["FilePath"]):
            video_path = os.path.join(self.root_dir, file_path)
            if "rgb" in modality and not os.path.isfile(video_path):
                print(f"Video not found: {video_path}")
                continue
            if "/Second (Number)/" in video_path:
                print(f"WARNING: Skipping {video_path} assuming no present")
                continue

            self.data.append((file_path, gloss_to_index[word]))

        if not self.data:
            raise SystemExit("No data found")

    def read_index_file_autsl(self, index_file_path, splits, modality):
        """Read a header-less CSV of ``(sample_name, class_id)`` rows.

        The class vocabulary comes from ``self.class_mappings_file_path``
        (columns ``ClassId`` and ``TR``). `splits` is not used here.

        Raises:
            NotImplementedError: If ``modality`` contains neither "rgb" nor
                "pose".
        """
        class_mappings_df = pd.read_csv(self.class_mappings_file_path)
        self.id_to_glosses = dict(
            zip(class_mappings_df["ClassId"], class_mappings_df["TR"])
        )
        self.glosses = sorted(self.id_to_glosses.values())

        if "rgb" in modality:
            file_suffix = "color.mp4"
        elif "pose" in modality:
            file_suffix = "color.pkl"
        else:
            raise NotImplementedError(f"Unsupported modality: {modality!r}")

        df = pd.read_csv(index_file_path, header=None)
        for sample_name, class_id in zip(df[0], df[1]):
            self.data.append((sample_name + "_" + file_suffix, class_id))

    def read_index_file_wlasl(self, index_file_path, splits, modality="rgb"):
        """Read a JSON list of ``{"gloss": ..., "instances": [...]}`` entries.

        Only instances whose ``"split"`` is in ``splits`` are kept; each is
        stored as ``(video_id, class_index)``. `modality` is not used here.

        Raises:
            SystemExit: If no instance matches ``splits``.
        """
        with open(index_file_path, "r", encoding="utf-8") as f:
            content = json.load(f)

        self.glosses = sorted(gloss_entry["gloss"] for gloss_entry in content)
        # Index by the unique sorted glosses so a repeated gloss entry maps to
        # a single class index.
        gloss_to_index = {
            gloss: idx for idx, gloss in enumerate(sorted(set(self.glosses)))
        }

        for gloss_entry in content:
            gloss_cat = gloss_to_index[gloss_entry["gloss"]]
            for instance in gloss_entry["instances"]:
                if instance["split"] not in splits:
                    continue
                self.data.append((instance["video_id"], gloss_cat))

        if not self.data:
            raise SystemExit(f"ERROR: No {splits} data found")

    def read_index_file_csl(self, index_file_path, splits, modality="rgb"):
        """Discover samples by globbing ``root_dir`` recursively.

        Check the file "DEVISIGN Technical Report.pdf" inside `Documents/`
        folder for dataset format (page 12) and splits (page 15).

        The parent directory of every matched file is parsed as
        ``P<signer>_<gloss>_...``. Signers 1-4 are kept when "train" is in
        ``splits`` and signers above 4 when "test" is in ``splits``.
        `index_file_path` is not used here.

        Raises:
            NotImplementedError: If ``modality`` contains neither "rgb" nor
                "pose".
            SystemExit: If the glob matches no file.
        """
        self.glosses = list(range(2000))

        if "rgb" in modality:
            common_filename = "color.avi"
        elif "pose" in modality:
            common_filename = "pose.pkl"
        else:
            raise NotImplementedError(f"Unsupported modality: {modality!r}")

        video_files_path = os.path.join(self.root_dir, "**", common_filename)
        # Sorted so that sample order does not depend on the filesystem.
        video_files = sorted(glob.glob(video_files_path, recursive=True))
        if not video_files:
            raise SystemExit(f"No videos files found for: {video_files_path}")

        for video_file in video_files:
            naming_parts = video_file.replace("\\", "/").split("/")[-2].split("_")
            gloss_id = int(naming_parts[1])
            signer_id = int(naming_parts[0].replace("P", ""))

            if (signer_id <= 4 and "train" in splits) or (
                signer_id > 4 and "test" in splits
            ):
                # glob results already include root_dir, and read_pose_data
                # joins root_dir again; store the path relative to root_dir
                # so a relative root_dir is not duplicated in the final path.
                relative_path = os.path.relpath(video_file, self.root_dir)
                self.data.append((relative_path, gloss_id))

    def load_pose_from_path(self, path):
        """
        Load dumped pose keypoints.
        Should contain: {
            "keypoints" of shape (T, V, C),
            "confidences" of shape (T, V),
            "vid_shape" of shape (W, H)
        }
        """
        # NOTE(review): pickle can execute arbitrary code while loading; only
        # point this at pose dumps from a trusted source.
        with open(path, "rb") as pose_file:
            return pickle.load(pose_file)

    def read_pose_data(self, index):
        """Load the pose dict of sample ``index`` and attach its label.

        Returns:
            ``(pose_data, pose_path)`` where ``pose_data["label"]`` is a
            ``torch.long`` tensor and ``pose_path`` is the ``.pkl`` file that
            was read.
        """
        video_name, label = self.data[index]
        video_path = os.path.join(self.root_dir, video_name)
        # A directory gets ".pkl" appended; a file has its extension replaced.
        pose_path = (
            video_path if os.path.isdir(video_path) else os.path.splitext(video_path)[0]
        )
        pose_path = pose_path + ".pkl"
        pose_data = self.load_pose_from_path(pose_path)
        pose_data["label"] = torch.tensor(label, dtype=torch.long)
        return pose_data, pose_path

    def __getitem_pose(self, index):
        """Build the clip stack for sample ``index``.

        Only the first two keypoint channels are kept. The sequence is
        zero-padded at the front or truncated at the end to exactly
        ``seq_len * num_seq`` frames, then cut into consecutive clips.

        Returns:
            ``(t_seq, label)`` where ``t_seq`` has shape
            ``(num_seq, seq_len, V, C)`` and ``label`` is the class tensor.
            C is 2 unless ``transforms`` changes the channel dimension.
        """
        pose_data, _ = self.read_pose_data(index)
        label = pose_data["label"]

        # Stored keypoints: (T, V, C); keep the first two channels only.
        kps = pose_data["keypoints"][:, :, :2]

        # Expand to 4 dim for person dim
        if kps.ndim == 3:
            kps = np.expand_dims(kps, axis=-1)

        kps = np.asarray(kps, dtype=np.float32)
        sample = {
            "frames": torch.tensor(kps).permute(2, 0, 1, 3),  # (C, T, V, M)
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        # Back to (T, V, C) for padding / clipping.
        kps = sample["frames"].squeeze(-1).permute(1, 2, 0).numpy()

        target_len = self.seq_len * self.num_seq
        if kps.shape[0] < target_len:
            # Pad with the keypoint dtype: a default float64 pad would promote
            # short samples to float64 while long ones stay float32.
            pad_kps = np.zeros(
                (target_len - kps.shape[0], *kps.shape[1:]), dtype=kps.dtype
            )
            kps = np.concatenate([pad_kps, kps])
        elif kps.shape[0] > target_len:
            kps = kps[:target_len, ...]

        clips = [
            torch.tensor(kps[start : start + self.seq_len])
            for start in range(0, target_len, self.seq_len)
        ]
        return torch.stack(clips, 0), label

    def __getitem__(self, index):
        """Return ``(clips, label)`` for sample ``index``."""
        return self.__getitem_pose(index)

    @property
    def num_class(self):
        """Number of entries in the class vocabulary ``self.glosses``."""
        return len(self.glosses)

    def __len__(self):
        """Number of samples collected by the index reader."""
        return len(self.data)
