import torch
import torch.nn as nn
import torch.nn.functional as K

import transformers

from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoProcessor, AutoFeatureExtractor, AutoModel, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer, logging

import os
import numpy as np
import pandas as pd
import json
from copy import deepcopy

import time
import math
import datasets
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from collections import OrderedDict
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score

os.environ["TOKENIZERS_PARALLELISM"] = "false" 

import argparse
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` makes argparse call ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` and ``"0"`` -- is parsed as
    ``True``. This converter accepts the usual spellings explicitly and
    rejects anything else with a regular argparse error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()

# Output directory and the checkpoints to start from.
parser.add_argument('--output_dir', default=None, type=str)
parser.add_argument('--teacher_ckpt', default=None, type=str)
parser.add_argument('--student_ckpt', default=None, type=str)

# Paths of the train / dev / test datasets.
parser.add_argument('--pp_trn_path', default=None, type=str)
parser.add_argument('--pp_dev_path', default=None, type=str)
parser.add_argument('--pp_tst_path', default=None, type=str)

# Optimisation and sequence-length hyper-parameters.
parser.add_argument('--max_seq_len', default=40, type=int)
parser.add_argument('--q_min_len', default=10, type=int)
parser.add_argument('--batch_size', default=10, type=int)
parser.add_argument('--epochs', default=5, type=int)
parser.add_argument('--lr', default=3e-5, type=float)
parser.add_argument('--random_seed', default=42, type=int)
parser.add_argument('--scheduler_name', default='linear', type=str)
# `--fp16`, `--fp16 true` and `--fp16 false` are all accepted; the bare flag
# means True.
parser.add_argument('--fp16', default=False, type=_str2bool, nargs='?', const=True)
parser.add_argument('--warmup_ratio', default=0.0, type=float)
parser.add_argument('--ga_steps', default=1, type=int)
parser.add_argument('--kdloss_ratio', default=0.5, type=float)

args = parser.parse_args()

print(args)


# Report the CUDA devices visible to this process; GPU 0 is the one named.
device = torch.device("cuda")
_gpu_count = torch.cuda.device_count()
print('There are %d GPU(s) available.' % _gpu_count)
_gpu_name = torch.cuda.get_device_name(0)
print('We will use the GPU:', _gpu_name)

# Load the three dataset splits from the paths given on the command line,
# in train / dev / test order.
dataset_train, dataset_val, dataset_test = (
    datasets.load_from_disk(_split_path)
    for _split_path in (args.pp_trn_path, args.pp_dev_path, args.pp_tst_path)
)

@dataclass
class FlickrCapTransferCollator:
    """Collate caption/image examples into student (``_S``) and teacher (``_T``) inputs.

    Every example contributes a ``'caption'`` string and an ``'image'`` object
    exposing ``.convert('RGB')``. Each caption is tokenised twice (once per
    tokenizer) and padded to exactly ``args.max_seq_len`` tokens; each image
    is run through both image preprocessors.

    All returned tensors keep their leading batch dimension, including for a
    batch of one example (e.g. a trailing partial batch).
    """
    tokenizer_S: AutoTokenizer
    preprocessor_S: AutoFeatureExtractor

    tokenizer_T: AutoTokenizer
    preprocessor_T: AutoFeatureExtractor

    @staticmethod
    def _pad_batch(token_lists, start_id, end_id, pad_id, total_len):
        """Wrap each token list in start/end ids and right-pad to ``total_len``.

        Returns ``(input_ids, labels, attention_mask)`` tensors. Padded
        positions hold ``pad_id`` in the inputs, ``-100`` in the labels and
        ``0`` in the attention mask.
        """
        input_ids, labels, masks = [], [], []
        for tokens in token_lists:
            body = [start_id] + tokens + [end_id]
            n_pad = total_len - len(body)
            input_ids.append(body + [pad_id] * n_pad)
            labels.append(body + [-100] * n_pad)
            masks.append([1] * len(body) + [0] * n_pad)
        return torch.tensor(input_ids), torch.tensor(labels), torch.tensor(masks)

    def process_texts(self, texts: List[str]):
        """Tokenise ``texts`` for both models.

        Returns a dict with ``input_ids_*``, ``attention_mask_*`` and
        ``labels_*`` for the student (``S``) and the teacher (``T``).
        """
        # Two positions are reserved for the boundary tokens added in
        # `_pad_batch`, hence the `- 2` on the truncation length.
        # NOTE(review): both tokenizers run with their default
        # `add_special_tokens`; if tokenizer_T already inserts its own
        # [CLS]/[SEP], the manual wrapping duplicates them -- confirm against
        # how the teacher checkpoint was trained before changing this.
        input_ids_raw_S = self.tokenizer_S(text=texts,
                                           padding='do_not_pad',
                                           max_length=args.max_seq_len - 2,
                                           truncation=True).input_ids
        input_ids_raw_T = self.tokenizer_T(text=texts,
                                           padding='do_not_pad',
                                           max_length=args.max_seq_len - 2,
                                           truncation=True).input_ids

        input_ids_S, label_ids_S, attention_masks_S = self._pad_batch(
            input_ids_raw_S,
            self.tokenizer_S.bos_token_id,
            self.tokenizer_S.eos_token_id,
            self.tokenizer_S.pad_token_id,
            args.max_seq_len,
        )
        input_ids_T, label_ids_T, attention_masks_T = self._pad_batch(
            input_ids_raw_T,
            self.tokenizer_T.cls_token_id,
            self.tokenizer_T.sep_token_id,
            self.tokenizer_T.pad_token_id,
            args.max_seq_len,
        )

        # No `.squeeze()` here: it would drop the batch dimension whenever the
        # batch holds a single example.
        return {
            "input_ids_S": input_ids_S,
            "attention_mask_S": attention_masks_S,
            "labels_S": label_ids_S,

            "input_ids_T": input_ids_T,
            "attention_mask_T": attention_masks_T,
            "labels_T": label_ids_T,
        }

    def preprocess_images(self, images: List):
        """Run ``images`` through the student and teacher image preprocessors.

        Returns a dict with ``pixel_values_S`` and ``pixel_values_T``; the
        batch dimension is kept as produced by the preprocessors.
        """
        processed_images_S = self.preprocessor_S(
            images,
            return_tensors="pt",
        )
        processed_images_T = self.preprocessor_T(
            images,
            return_tensors="pt",
        )
        return {
            "pixel_values_S": processed_images_S['pixel_values'],
            "pixel_values_T": processed_images_T['pixel_values'],
        }

    def __call__(self, raw_batch_dict):
        """Collate a batch.

        ``raw_batch_dict`` may be a list of example dicts, a single example
        dict, or a dict of columns (lists under ``'caption'`` / ``'image'``).
        A single example is treated as a batch of one.
        """
        if isinstance(raw_batch_dict, dict):
            captions = raw_batch_dict['caption']
            images = raw_batch_dict['image']
            if isinstance(captions, str):
                # One example: wrap it so the batched code path applies.
                captions, images = [captions], [images]
        else:
            captions = [example['caption'] for example in raw_batch_dict]
            images = [example['image'] for example in raw_batch_dict]

        return {
            **self.process_texts(list(captions)),
            **self.preprocess_images([image.convert('RGB') for image in images]),
        }

class GITForTrainFlickrCap(nn.Module):
    """Thin ``nn.Module`` wrapper that calls a GIT model and returns a plain dict."""

    def __init__(self, git_model):
        super().__init__()
        self.git_model = git_model

    def forward(self,
                input_ids: torch.LongTensor,
                pixel_values: torch.FloatTensor,
                attention_mask: Optional[torch.LongTensor] = None,
                labels: Optional[torch.LongTensor] = None):
        """Forward all inputs to the wrapped model.

        Returns ``{"logits": ...}`` and, only when ``labels`` is given,
        additionally ``"loss"`` taken from the wrapped model's output.
        """
        git_out = self.git_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            labels=labels,
        )

        if labels is None:
            return {"logits": git_out.logits}
        return {"logits": git_out.logits, "loss": git_out.loss}

class CEDMForTrainFlickrCap(nn.Module):
    """Captioning student: image encoder features prepended to a causal text model.

    The image encoder's last hidden states are projected to the text model's
    embedding width, concatenated in front of the token embeddings, and the
    combined sequence is fed to the text model through ``inputs_embeds``.
    """

    def __init__(self, text_model, image_model):
        super().__init__()

        self.text_model = text_model
        self.image_model = image_model

        # Maps image features (image_model.config.hidden_size) to the text
        # embedding width (text_model.config.n_embd).
        self.image_projection = nn.Sequential(
            nn.Linear(image_model.config.hidden_size, text_model.config.n_embd),
            nn.LayerNorm(text_model.config.n_embd),
        )

    def forward(self,
                input_ids: torch.LongTensor,
                pixel_values: torch.FloatTensor,
                attention_mask: Optional[torch.LongTensor] = None,
                labels: Optional[torch.LongTensor] = None):
        """Run the student on one batch.

        Args:
            input_ids: token ids of the caption.
            pixel_values: image input for ``image_model``.
            attention_mask: optional mask over ``input_ids``; when given it is
                extended with ones for the image positions and passed to the
                text model. When omitted, no mask is passed.
            labels: optional targets over ``input_ids``; when given they are
                extended with ``-100`` for the image positions so that the
                image prefix does not contribute to the loss.

        Returns:
            ``{"logits": ...}`` plus ``"loss"`` when ``labels`` is given.
        """
        encoded_image = self.image_model(pixel_values=pixel_values)
        encoded_image = self.image_projection(encoded_image.last_hidden_state)

        embedded_text = self.text_model.transformer.wte(input_ids)
        embedded_seq_vqa = torch.cat([encoded_image, embedded_text], dim=1)

        # (batch, n_image_tokens): shape of the prefix added in front of the text.
        prefix_shape = encoded_image.shape[:-1]
        prefix_device = encoded_image.device

        if attention_mask is not None:
            image_mask = torch.ones(prefix_shape, dtype=torch.long, device=prefix_device)
            attention_mask = torch.cat([image_mask, attention_mask], dim=1)

        if labels is not None:
            ignored = torch.ones(prefix_shape, dtype=torch.long, device=prefix_device) * -100
            labels = torch.cat([ignored, labels], dim=1)

        # The mask must be forwarded; otherwise padded caption positions are
        # attended to like real tokens.
        outputs = self.text_model(inputs_embeds=embedded_seq_vqa,
                                  attention_mask=attention_mask,
                                  labels=labels)

        out = {
            "logits": outputs.logits,
        }
        if labels is not None:
            out["loss"] = outputs.loss

        return out

class TransferModelForTrainFlickrCap(nn.Module):
    """Distil a GIT teacher into the image-encoder + causal-LM student.

    The teacher only runs under ``torch.no_grad()``. Its hidden states are
    mapped to the student's width by ``projection_T2S`` and compared with the
    student's through three MSE terms: mean image embedding before the
    language model, mean image hidden state after it, and masked-mean text
    hidden state after it. The final loss mixes the student's own loss with
    the sum of these terms using ``args.kdloss_ratio``.
    """

    def __init__(self, git_wrapper, student_model):
        super().__init__()

        self.git_wrapper = git_wrapper

        self.student_model = student_model

        # Teacher hidden width -> student embedding width.
        self.projection_T2S = nn.Linear(git_wrapper.git_model.config.hidden_size,
                                        student_model.text_model.config.n_embd)

        self.criterion = nn.MSELoss()

    def train(self, mode: bool = True):
        """Set train/eval mode, but always keep the teacher in eval mode.

        The teacher only produces targets, so train-only layers such as
        dropout must stay disabled for it even while the student trains.
        """
        super().train(mode)
        self.git_wrapper.eval()
        return self

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average ``hidden`` over dim 1 at the positions where ``mask == 1``.

        Rows with an all-zero mask are divided by 1 instead of 0.
        """
        keep = mask.unsqueeze(-1) == 1
        summed = torch.where(keep, hidden, torch.zeros_like(hidden)).sum(dim=1)
        valid_lengths = mask.sum(dim=1).unsqueeze(-1).clamp(min=1)
        return summed / valid_lengths

    def forward(self,
                input_ids_S: torch.LongTensor,
                input_ids_T: torch.LongTensor,
                pixel_values_S: torch.FloatTensor,
                pixel_values_T: torch.FloatTensor,
                attention_mask_S: Optional[torch.LongTensor] = None,
                attention_mask_T: Optional[torch.LongTensor] = None,
                labels_S: Optional[torch.LongTensor] = None,
                labels_T: Optional[torch.LongTensor] = None,):
        """Compute student logits and the combined training loss.

        ``*_S`` arguments feed the student, ``*_T`` arguments feed the teacher.
        Missing attention masks default to all ones. When ``labels_S`` is
        omitted the student yields no loss of its own and only the weighted
        distillation term is returned.

        Returns:
            ``{"logits": student logits, "loss": combined loss}``.
        """
        # The slicing below treats the last `text_len_*` positions of every
        # hidden state as text and everything before them as image tokens.
        text_len_S = input_ids_S.shape[1]
        text_len_T = input_ids_T.shape[1]

        if attention_mask_S is None:
            attention_mask_S = torch.ones_like(input_ids_S)
        if attention_mask_T is None:
            attention_mask_T = torch.ones_like(input_ids_T)

        # Teacher forward path
        with torch.no_grad():
            outputs_T = self.git_wrapper.git_model(input_ids=input_ids_T,
                                                   attention_mask=attention_mask_T,
                                                   pixel_values=pixel_values_T,
                                                   labels=labels_T,
                                                   output_hidden_states=True)
            last_hidden_state_T = outputs_T.hidden_states[-1]
            encoded_image_T = outputs_T.hidden_states[0][:, :-text_len_T, :]

        # The projection is applied outside no_grad so that it is trained.
        last_hidden_state_T = self.projection_T2S(last_hidden_state_T)
        encoded_image_T = self.projection_T2S(encoded_image_T)

        # Student forward path
        encoded_image = self.student_model.image_model(pixel_values=pixel_values_S)
        encoded_image = self.student_model.image_projection(encoded_image.last_hidden_state)

        embedded_text = self.student_model.text_model.transformer.wte(input_ids_S)

        embedded_seq_vqa = torch.cat([encoded_image, embedded_text], dim=1)

        prefix_shape = encoded_image.shape[:-1]
        prefix_device = encoded_image.device

        image_mask = torch.ones(prefix_shape, dtype=torch.long, device=prefix_device)
        attention_mask_input_S = torch.cat([image_mask, attention_mask_S], dim=1)

        if labels_S is not None:
            # -100 on the image prefix keeps it out of the student's own loss.
            ignored = torch.ones(prefix_shape, dtype=torch.long, device=prefix_device) * -100
            labels_S = torch.cat([ignored, labels_S], dim=1)

        outputs_S = self.student_model.text_model(inputs_embeds=embedded_seq_vqa,
                                                  attention_mask=attention_mask_input_S,
                                                  labels=labels_S,
                                                  output_hidden_states=True)

        last_hidden_state_S = outputs_S.hidden_states[-1]

        # Distillation Loss calculation
        loss_mse_image_pre = self.criterion(torch.mean(encoded_image_T, dim=1),
                                            torch.mean(encoded_image, dim=1))

        loss_mse_image_post = self.criterion(
            torch.mean(last_hidden_state_S[:, :-text_len_S, :], dim=1),
            torch.mean(last_hidden_state_T[:, :-text_len_T, :], dim=1))

        pooled_text_encoding_T = self._masked_mean(last_hidden_state_T[:, -text_len_T:, :],
                                                   attention_mask_T)
        pooled_text_encoding_S = self._masked_mean(last_hidden_state_S[:, -text_len_S:, :],
                                                   attention_mask_S)

        loss_mse_text_post = self.criterion(pooled_text_encoding_S, pooled_text_encoding_T)

        loss = args.kdloss_ratio * (loss_mse_image_pre + loss_mse_image_post + loss_mse_text_post)
        if outputs_S.loss is not None:
            loss = (1 - args.kdloss_ratio) * outputs_S.loss + loss

        out = {
            "logits": outputs_S.logits,
            "loss": loss,
        }

        return out


# ---- Teacher: GIT base model with fine-tuned weights from --teacher_ckpt ----
git_processor = AutoProcessor.from_pretrained("microsoft/git-base")
git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")

teacher_model = GITForTrainFlickrCap(git_model=git_model)
# map_location="cpu": the freshly built model lives on the CPU at this point,
# and this also lets a checkpoint saved from another device be restored.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
teacher_model.load_state_dict(torch.load(args.teacher_ckpt, map_location="cpu"))


# ---- Student: Swin image encoder + GPT-2 medium language model ----
text_model = transformers.GPT2LMHeadModel.from_pretrained("gpt2-medium")

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2-medium", pad_token="<pad>")

# The tokenizer is given an extra "<pad>" token, so the embedding table is
# resized to the tokenizer's length.
text_model.resize_token_embeddings(len(tokenizer))

image_model = transformers.SwinModel.from_pretrained("microsoft/swin-base-patch4-window12-384-in22k")
processor = transformers.AutoFeatureExtractor.from_pretrained("microsoft/swin-base-patch4-window12-384-in22k")

student_model = CEDMForTrainFlickrCap(text_model=text_model, image_model=image_model)
student_model.load_state_dict(torch.load(args.student_ckpt, map_location="cpu"))

collator = FlickrCapTransferCollator(tokenizer_T=git_processor.tokenizer,
                                     preprocessor_T=git_processor.image_processor,
                                     tokenizer_S=tokenizer,
                                     preprocessor_S=processor)

# The teacher only provides distillation targets; exclude it from optimisation.
for p in teacher_model.parameters():
    p.requires_grad = False

model = TransferModelForTrainFlickrCap(git_wrapper=teacher_model, student_model=student_model)

# NOTE(review): args.fp16 is parsed on the command line but is not forwarded
# to TrainingArguments here -- confirm whether mixed precision was intended.
multi_args = TrainingArguments(
    output_dir=args.output_dir,
    seed=args.random_seed,
    learning_rate=args.lr,
    evaluation_strategy="no",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=500,
    save_total_limit=3,
    metric_for_best_model='acc',
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    # The collator reads the raw 'caption' / 'image' columns, so they must
    # not be stripped from the dataset.
    remove_unused_columns=False,
    num_train_epochs=args.epochs,
    dataloader_num_workers=8,
    load_best_model_at_end=False,
    lr_scheduler_type=args.scheduler_name,
    eval_accumulation_steps=500,
    gradient_accumulation_steps=args.ga_steps,
    warmup_ratio=args.warmup_ratio,
    save_safetensors=False,
)
trainer = Trainer(
    model,
    multi_args,
    train_dataset=dataset_train,
    eval_dataset=None,
    data_collator=collator,
)

train_metrics = trainer.train()

# ---- Caption the test split with the trained student and dump the results ----
model.eval()

# Run inference on whatever device the model's parameters are on.
infer_device = next(model.parameters()).device

all_answers = []

start_time = time.time()

# No gradients are needed for generation; skipping them avoids keeping the
# image encoder's activations alive for every batch.
with torch.no_grad():
    for i in range(0, len(dataset_test), args.batch_size):

        if i % (args.batch_size*50) == 0:
            print(i,"/",len(dataset_test)," // time elapsed:",time.time()-start_time)

        sample = dataset_test[i:i+args.batch_size]

        images = [item.convert('RGB') for item in sample['image']]
        # The last slice may hold fewer than args.batch_size examples, so the
        # prompt batch is sized from the images actually present.
        n_samples = len(images)

        processed_images_swin = processor(images=images, return_tensors="pt")
        pixel_values_swin = processed_images_swin['pixel_values'].to(infer_device)

        # Each caption is generated from a single BOS token placed after the
        # image prefix.
        inputs = tokenizer([tokenizer.bos_token]*n_samples, add_special_tokens=False,
                           return_tensors='pt')

        encoded_image = model.student_model.image_model(pixel_values=pixel_values_swin)
        encoded_image = model.student_model.image_projection(encoded_image.last_hidden_state)

        image_mask = torch.ones(encoded_image.shape[:2], dtype=torch.long, device=infer_device)
        full_attention_mask = torch.cat((image_mask, inputs.attention_mask.to(infer_device)), dim=1)

        embedded_text = model.student_model.text_model.transformer.wte(inputs.input_ids.to(infer_device))

        embedded_seq_vqa = torch.cat([encoded_image, embedded_text], dim=1)

        # NOTE(review): how `max_length` counts the `inputs_embeds` prefix may
        # depend on the installed transformers version -- confirm that up to
        # 50 caption tokens are really generated, or switch to max_new_tokens.
        generated_ids = model.student_model.text_model.generate(inputs_embeds=embedded_seq_vqa,
                                                                attention_mask=full_attention_mask,
                                                                max_length=50)

        all_answers += tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Make sure the target directory exists before writing the results.
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir,"results.json"),"w") as f:
    json.dump(all_answers,f)