from dataclasses import dataclass, field
import json
import math
import logging
import os
from typing import Dict, Optional, List
import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
import cv2
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import random

import torch.distributed as dist
import torch.nn as nn
import re
import torch.nn.functional as F
import pandas as pd
from decord import VideoReader, gpu

import videoclipxl

def get_latest_checkpoint(folder_path):
    """Locate the newest ``checkpoint-<N>`` directory inside ``folder_path``.

    Only immediate subdirectories whose name matches ``checkpoint-<digits>``
    are considered; the one with the largest number wins, and on a tie the
    one listed first by ``os.listdir`` is kept.

    Args:
        folder_path: Directory to scan (e.g. the Trainer ``output_dir``).

    Returns:
        ``os.path.join(folder_path, <name>)`` for the winning directory, or
        ``None`` when ``folder_path`` does not exist or contains no match.
    """
    if not os.path.exists(folder_path):
        return None

    best_name, best_step = None, None
    for name in os.listdir(folder_path):
        if not os.path.isdir(os.path.join(folder_path, name)):
            continue
        found = re.match(r'^checkpoint-(\d+)$', name)
        if not found:
            continue
        step = int(found.group(1))
        # Strict '>' keeps the first-listed directory on equal step numbers.
        if best_step is None or step > best_step:
            best_name, best_step = name, step

    return None if best_name is None else os.path.join(folder_path, best_name)
    
    
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` plus the VideoCLIP-XL options.

    The fields below are the extra attributes this script reads from
    ``training_args``. They default to ``None`` only because dataclass fields
    without defaults may not follow the defaulted fields of the base class;
    ``__post_init__`` rejects any left unset so a missing flag fails at
    startup instead of deep inside training.
    """

    # Passed to frame_sample_func when sampling frames from each video.
    frame_num: Optional[int] = field(
        default=None,
        metadata={"help": "Number of frames requested from frame_sample_func per video (required)."},
    )
    # Initial value of the learnable logit scale; similarities are multiplied by exp(log_scale).
    log_scale: Optional[float] = field(
        default=None,
        metadata={"help": "Initial logit scale (applied through exp()) (required)."},
    )
    # Weight of the short-caption contrastive loss.
    ratio_short_caption: Optional[float] = field(
        default=None,
        metadata={"help": "Weight of the short-caption loss (required)."},
    )
    # Weight of the HDR ranking loss.
    hdr_loss_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "Weight of the HDR ranking loss (required)."},
    )
    # Weight of the DDR ranking loss.
    ddr_loss_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "Weight of the DDR ranking loss (required)."},
    )

    def __post_init__(self):
        """Fail fast when a required VideoCLIP-XL option was not supplied.

        Raises:
            ValueError: if any of the extra options is still ``None``.
        """
        missing = [
            name
            for name in (
                "frame_num",
                "log_scale",
                "ratio_short_caption",
                "hdr_loss_ratio",
                "ddr_loss_ratio",
            )
            if getattr(self, name) is None
        ]
        if missing:
            raise ValueError(
                "Missing required training arguments: "
                + ", ".join(f"--{name}" for name in missing)
            )
        super().__post_init__()

    
def is_dist_avail_and_initialized():
    """Report whether a torch.distributed process group is ready to use.

    Returns:
        ``True`` only when the distributed package is available in this
        PyTorch build *and* the default process group has been initialised;
        ``False`` otherwise.
    """
    # Short-circuit: is_initialized() is only queried when the package exists.
    return dist.is_available() and dist.is_initialized()


@torch.no_grad()
def concat_all_gather(tensor):
    """Gather ``tensor`` from every process and concatenate along dim 0.

    Runs under ``torch.no_grad()``, so the result carries no gradient; callers
    that need gradients must keep using their local tensor for that part.
    When torch.distributed is unavailable or uninitialised, ``tensor`` is
    returned unchanged.

    The receive buffers are shaped like the local ``tensor``, so every rank
    is assumed to contribute a tensor of the same shape.
    """
    if not is_dist_avail_and_initialized():
        return tensor

    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)

# Local rank of this process; assigned by train() from training_args.local_rank.
local_rank = None


def rank0_print(*args):
    """``print(*args)`` on the main local process only.

    Hugging Face ``TrainingArguments`` leave ``local_rank`` at -1 when the
    script is not launched in distributed mode, so -1 counts as the main
    process too; checking only for 0 silenced every message in
    single-process runs. Nothing is printed while ``local_rank`` is unset.
    """
    if local_rank in (0, -1):
        print(*args)


class train_dataset(Dataset):
    """Video-text training dataset read from JSON-lines annotation files.

    Records come from the module-level ``train_dataset_list`` mapping, which
    is not defined in this file (NOTE(review): TODO confirm where it is
    provided). For every ``(key, value)`` item, ``value[0]`` is opened as a
    JSON-lines file and each non-blank line is stored as ``(record, key)``.

    Each record is read for the keys ``video_path``, ``captions``,
    ``short_captions``, ``hdr_captions`` and ``ddr_captions``.
    """

    # Maximum number of samples __getitem__ tries (the requested one plus
    # random replacements) before giving up on undecodable videos.
    max_load_retries = 100

    def __init__(self, training_args,):
        self.training_args = training_args
        self.json_data = []
        for k, v in train_dataset_list.items():
            json_name = v[0]
            with open(json_name, 'r', encoding='utf8') as fp:
                lines = fp.read().splitlines()
            # Skip blank lines: json.loads('') would raise and abort loading.
            self.json_data += [(json.loads(q), k) for q in lines if q.strip()]

    def __len__(self):
        return len(self.json_data)

    def pre_text(self, text,):
        """Drop newlines, collapse runs of 2+ whitespace chars into one space, strip spaces."""
        text = text.replace('\n', '')
        text = re.sub(r"\s{2,}", ' ', text)
        text = text.strip(' ')
        return text

    def _load_video(self, video_path):
        """Decode, resize, augment and normalise the sampled frames of one video.

        Frames are resized to 224x224 and stacked along a new leading frame
        axis, with the last array axis moved ahead of height/width (assumes
        ``normalize`` returns H x W x C arrays — TODO confirm). Any exception
        raised while decoding or transforming propagates to the caller.
        """
        # NOTE(review): frame_sample_func, augmentation_func and normalize are
        # not defined in this file — TODO confirm where they are provided.
        vr = VideoReader(video_path, num_threads=2,)
        frame_idx = frame_sample_func(vr, self.training_args.frame_num)
        frames = vr.get_batch(frame_idx).asnumpy()
        del vr  # release the decoder before the per-frame processing

        vid_tube = []
        for now in frames:
            now = cv2.resize(now, (224, 224))
            now = augmentation_func(now)
            now = np.expand_dims(normalize(now), axis=(0, 1))
            vid_tube.append(now)
        vid_tube = np.concatenate(vid_tube, axis=1)
        vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3))
        return torch.from_numpy(vid_tube)[0]

    def __getitem__(self, index):
        """Return one tokenised sample, replacing undecodable videos at random.

        If loading the video fails, a random index is tried instead (its
        captions are re-drawn and re-tokenised too). Errors raised outside
        video loading (e.g. missing keys, tokenisation) propagate.

        Raises:
            RuntimeError: after ``max_load_retries`` consecutive video-loading
                failures, instead of recursing without bound.
        """
        for _ in range(self.max_load_retries):
            data = self.json_data[index][0]
            video_path = data["video_path"]

            caption = self.pre_text(random.choice(data["captions"]))  # long caption
            caption_short = self.pre_text(random.choice(data["short_captions"]))  # short caption

            caption = videoclipxl.tokenize(caption, truncate=True)[0]  # t
            caption_short = videoclipxl.tokenize(caption_short, truncate=True)[0]  # t

            # HDR
            hdr_all_texts = videoclipxl.tokenize(data["hdr_captions"], truncate=True)  # m, t
            # DDR
            ddr_all_texts = videoclipxl.tokenize(data["ddr_captions"], truncate=True)  # m, t

            try:
                vid_tube = self._load_video(video_path)
            except Exception as error:
                print(f"Load Video Error: {error}. Select a new one.")
                index = random.randint(0, len(self) - 1)
                continue

            return {
                "images": vid_tube,
                "caption": caption,
                "short_caption": caption_short,
                "hdr_all_texts": hdr_all_texts,
                "ddr_all_texts": ddr_all_texts,
            }

        raise RuntimeError(
            f"Failed to load a video after {self.max_load_retries} attempts."
        )


def make_supervised_data_module(training_args,) -> Dict:
    """Create the training dataset and package it as ``Trainer`` keyword arguments.

    Returns:
        ``{"train_dataset": train_dataset(training_args)}``. No data collator
        or evaluation dataset is included.
    """
    rank0_print("Loading data...")
    return {"train_dataset": train_dataset(training_args,)}


class VideoCLIPXL(nn.Module):
    """Training wrapper that computes the VideoCLIP-XL losses.

    ``forward`` combines:

    * a symmetric contrastive loss between videos and long captions;
    * a contrastive loss between PCA-reconstructed video features and short
      captions, weighted by ``training_args.ratio_short_caption``;
    * HDR and DDR pairwise ranking losses, weighted by
      ``training_args.hdr_loss_ratio`` / ``training_args.ddr_loss_ratio``.

    Args:
        training_args: Parsed training arguments; ``log_scale`` initialises
            the learnable logit scale and the ratio fields weight the
            auxiliary losses.
    """

    def __init__(self, training_args):
        super().__init__()
        self.training_args = training_args
        # NOTE(review): load_model_func is not defined in this file — TODO
        # confirm where it is provided.
        self.model = load_model_func()
        # Learnable temperature: similarities are scaled by exp(logit_scale).
        self.model.logit_scale = torch.nn.Parameter(torch.ones([]) * training_args.log_scale)

    def forward(
        self,
        images, caption, short_caption=None,
        hdr_all_texts=None, ddr_all_texts=None,
    ):
        """Compute the total loss for one batch.

        Args:
            images: Video batch passed to ``self.model.encode_video``.
            caption: Tokenised long captions, one per video.
            short_caption: Optional tokenised short captions.
            hdr_all_texts, ddr_all_texts: Optional tokenised caption lists of
                shape (batch, n, tokens), ordered from best to worst match.

        Returns:
            dict with ``loss`` and the L2-normalised ``text_features`` and
            ``image_features`` of the long-caption pass.

        The auxiliary terms are best-effort: a term is skipped when its input
        is ``None``, and an exception while computing it is reported (in
        training mode) and the term dropped.
        """
        loss, image_features, text_features = self.inference(images, caption)

        if short_caption is not None:
            try:
                loss_short = self.inference_short(image_features, short_caption, text_features)
                loss = loss + self.training_args.ratio_short_caption * loss_short
            except Exception as e:
                if self.training:
                    print(f"Skip short caption loss with the Error: {e}")

        if hdr_all_texts is not None:
            try:
                loss_hdr = self.training_args.hdr_loss_ratio * self.inference_rank(image_features, hdr_all_texts,)
                loss = loss + loss_hdr
            except Exception as e:
                if self.training:
                    print(f"Skip HDR loss with the Error: {e}")

        if ddr_all_texts is not None:
            try:
                loss_ddr = self.training_args.ddr_loss_ratio * self.inference_rank(image_features, ddr_all_texts,)
                loss = loss + loss_ddr
            except Exception as e:
                if self.training:
                    print(f"Skip DDR loss with the Error: {e}")

        return dict(
            loss=loss,
            text_features=text_features,
            image_features=image_features,
        )

    def PCA(self, input_tensor, avg_sim_txt=None):
        """Reconstruct the rows of ``input_tensor`` from their leading principal components.

        Components are taken in order of decreasing covariance eigenvalue
        until their cumulative share of the total eigenvalue mass reaches
        ``avg_sim_txt``. The centred rows are projected onto those components
        and mapped back, and the mean is added again.

        Args:
            input_tensor: 2-D tensor, one feature vector per row.
            avg_sim_txt: Threshold on the cumulative explained-variance ratio
                (float or 0-d tensor). Required despite the ``None`` default.

        Returns:
            float32 tensor with the same shape as ``input_tensor``.

        Raises:
            ValueError: if ``avg_sim_txt`` is ``None``.
        """
        if avg_sim_txt is None:
            raise ValueError("PCA requires an explained-variance threshold (avg_sim_txt).")

        mean = torch.mean(input_tensor, dim=0)
        X_centered = (input_tensor - mean.unsqueeze(0)).float()
        cov_matrix = torch.mm(X_centered.T, X_centered)

        # cov_matrix is symmetric, so eigh yields real eigenpairs directly
        # (linalg.eig returns complex tensors whose imaginary part a .float()
        # cast silently drops). eigh sorts ascending; flip to descending.
        eigenvalues, eigenvectors = torch.linalg.eigh(cov_matrix)
        eigenvalues = eigenvalues.flip(0)
        eigenvectors = eigenvectors.flip(1)

        # Smallest k whose cumulative eigenvalue share reaches the threshold.
        # Round-off negatives are clamped to 0, so the share is non-decreasing
        # and the entries failing the test form a prefix whose length + 1 is k.
        # A zero total gives NaN shares: nothing passes and all components
        # are kept. One host sync instead of one per dimension.
        with torch.no_grad():
            variances = eigenvalues.clamp(min=0)
            explained = torch.cumsum(variances, dim=0) / variances.sum()
            n_below = int((~(explained >= avg_sim_txt)).sum())
        pca_dim = min(n_below + 1, eigenvalues.shape[0])

        principal_components = eigenvectors[:, :pca_dim]
        X_transformed = torch.mm(X_centered, principal_components)
        return torch.mm(X_transformed, principal_components.T) + mean

    def _contrastive_loss(self, image_features, text_features, device):
        """Symmetric, label-smoothed InfoNCE between local and all-rank features.

        Local rows keep their gradients; rows gathered from other ranks do not
        (``concat_all_gather`` runs under no_grad). Local sample ``i`` is
        matched with global row ``rank * bs + i``, which assumes every rank
        contributes the same batch size.
        """
        image_feat_all = concat_all_gather(image_features)
        text_feat_all = concat_all_gather(text_features)

        logit_scale = self.model.logit_scale.exp()
        sim_i2t = logit_scale * torch.matmul(image_features, text_feat_all.T)
        sim_t2i = logit_scale * torch.matmul(image_feat_all, text_features.T).T

        rank = dist.get_rank() if is_dist_avail_and_initialized() else 0
        bs = image_features.shape[0]
        targets = torch.arange(rank * bs, rank * bs + bs, device=device)

        return (
            F.cross_entropy(sim_i2t.float(), targets, label_smoothing=0.1)
            + F.cross_entropy(sim_t2i.float(), targets, label_smoothing=0.1)
        ) / 2

    def inference(self, images, texts):
        """Contrastive loss between videos and their (long) captions.

        Returns:
            ``(loss, image_features, text_features)`` where both feature
            tensors are L2-normalised along the last dimension.
        """
        image_features = self.model.encode_video(images).float()
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        text_features = self.model.encode_text(texts).float()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        loss_itc = self._contrastive_loss(image_features, text_features, images.device)
        return loss_itc, image_features, text_features

    def inference_short(self, image_features_origin, texts, text_features_origin):
        """Contrastive loss between PCA-reconstructed videos and short captions.

        The PCA threshold is the batch-mean cosine similarity between each
        short caption and its long caption (``text_features_origin`` is the
        normalised output of ``inference``).
        """
        text_features = self.model.encode_text(texts).float()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        with torch.no_grad():
            avg_sim_txt = (text_features * text_features_origin).sum(dim=-1).mean()

        image_features = self.PCA(image_features_origin, avg_sim_txt=avg_sim_txt).float()
        return self._contrastive_loss(image_features, text_features, texts.device)

    def inference_rank(self, image_features, rank_text,):
        """Pairwise ranking loss: earlier texts should score at least as high as later ones.

        Args:
            image_features: (b, d) normalised video features.
            rank_text: (b, n, tokens) tokenised texts per video, best first.

        Returns:
            Mean of ``relu(s_j - s_i)`` over all pairs ``i < j``, where ``s`` is
            the text/video cosine similarity.

        Raises:
            ValueError: if fewer than two texts per video are given.
        """
        bs, n_texts, _ = rank_text.shape
        if n_texts < 2:
            raise ValueError(f"inference_rank needs at least 2 texts per sample, got {n_texts}.")

        text_features = self.model.encode_text(rank_text.reshape(-1, rank_text.shape[-1])).float()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        text_features = text_features.reshape(bs, n_texts, text_features.shape[-1])  # b, n, d

        sim_score_each = torch.bmm(text_features, image_features.unsqueeze(-1)).squeeze(-1)  # b, n

        # All pairs (i, j) with i < j, in row-major order, without a Python loop.
        i_idx, j_idx = torch.triu_indices(n_texts, n_texts, offset=1, device=sim_score_each.device)
        score_diff = sim_score_each[:, i_idx] - sim_score_each[:, j_idx]  # b, n*(n-1)/2
        return torch.relu(-score_diff).mean()
    
    
def train():
    """Parse arguments, build data and model, and run training with ``Trainer``.

    Resumes from the newest ``checkpoint-<N>`` directory in
    ``training_args.output_dir`` when one exists, otherwise trains from
    scratch, then saves the final model with ``trainer.save_model()``.
    """
    global local_rank

    parser = transformers.HfArgumentParser((TrainingArguments,))
    (training_args,) = parser.parse_args_into_dataclasses()

    # Read by rank0_print to decide which process prints.
    local_rank = training_args.local_rank

    # Load data
    data_module = make_supervised_data_module(training_args,)

    # Load model
    model = VideoCLIPXL(training_args,)

    trainer = Trainer(
        model=model,
        args=training_args,
        **data_module
    )

    latest_checkpoint_path = get_latest_checkpoint(training_args.output_dir,)
    rank0_print('#' * 80)
    if latest_checkpoint_path:
        rank0_print(f"Resuming from {latest_checkpoint_path}")
    else:
        rank0_print("No checkpoint found, train from scratch.")
    rank0_print('#' * 80)
    # None (no checkpoint found) is Trainer.train's default: start from scratch.
    trainer.train(resume_from_checkpoint=latest_checkpoint_path)

    trainer.save_model()

# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()
