import torch
import torch.nn as nn
import torch.nn.functional as K

import transformers

from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import ViTConfig, ViTModel
from transformers import T5Config, T5EncoderModel, T5Tokenizer
from transformers import FlavaConfig, FlavaProcessor, FlavaModel
from transformers import AutoTokenizer, AutoProcessor, AutoFeatureExtractor, AutoModel
from transformers import TrainingArguments, Trainer, logging

import os
import numpy as np
import pandas as pd
import json
from copy import deepcopy

import datasets
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from collections import OrderedDict
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score

from tqdm.notebook import tqdm

import argparse


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because every non-empty string is truthy.  This parser accepts the usual
    spellings instead.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` for true/t/yes/y/1, ``False`` for false/f/no/n/0 and for the
        empty string (the only input the old ``type=bool`` mapped to False).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()

# --- Data locations -------------------------------------------------------
# Boolean options accept an explicit value (``--flag True`` / ``--flag False``)
# or may be given bare (``--flag``), which means True.
parser.add_argument('--label_dict_path',default=None,type=str)
parser.add_argument('--data_from_file',default=False,type=_str2bool,nargs='?',const=True)
parser.add_argument('--data_path',default=None,type=str)
parser.add_argument('--data_preprocessed',default=False,type=_str2bool,nargs='?',const=True)
parser.add_argument('--pp_trn_path',default=None,type=str)
parser.add_argument('--pp_dev_path',default=None,type=str)
parser.add_argument('--pp_tst_path',default=None,type=str)

# --- Pretrained backbones --------------------------------------------------
parser.add_argument('--text_model',default='t5-base',type=str)
parser.add_argument('--vision_model',default="google/vit-base-patch16-224-in21k",type=str)

parser.add_argument('--flava_model',default="facebook/flava-full",type=str)

# --- Fine-tuned checkpoints ------------------------------------------------
parser.add_argument('--finetuned_flava_model',default=None,type=str)
parser.add_argument('--finetuned_student_model',default=None,type=str)

parser.add_argument('--output_dir',default=None,type=str)

# --- Optimisation hyper-parameters -----------------------------------------
parser.add_argument('--max_seq_len',default=24,type=int)
parser.add_argument('--batch_size',default=10,type=int)
parser.add_argument('--epochs',default=5,type=int)
parser.add_argument('--lr',default=3e-5,type=float)
parser.add_argument('--random_seed',default=42,type=int)
parser.add_argument('--dropout',default=0.5,type=float)
parser.add_argument('--scheduler_name',default='linear',type=str)
parser.add_argument('--intermediate_dim',default=1536,type=int)
parser.add_argument('--fp16',default=False,type=_str2bool,nargs='?',const=True)
parser.add_argument('--warmup_ratio',default=0.0,type=float)
parser.add_argument('--mm_depth',default=1,type=int)
parser.add_argument('--atloss_ratio',default=0.5,type=float)
# Restricting the choices here reports a bad mode as a normal usage error.
parser.add_argument('--atloss_mode',default='all',type=str,choices=['all','mm_only','no_mm'])
parser.add_argument('--ga_steps',default=1,type=int)

# --- Run stages -------------------------------------------------------------
parser.add_argument('--do_train',default=False,type=_str2bool,nargs='?',const=True)
parser.add_argument('--do_eval',default=False,type=_str2bool,nargs='?',const=True)
parser.add_argument('--do_predict',default=False,type=_str2bool,nargs='?',const=True)

# Parse the command line once; the rest of the script reads its settings from `args`.
args = parser.parse_args()

# Switch off the tokenizers library's own parallelism via its environment flag.
# Presumably this avoids clashes with the multi-worker data loader used for
# training -- TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Validate explicitly instead of with `assert`, which is stripped under `python -O`.
if args.atloss_mode not in ('all', 'mm_only', 'no_mm'):
    parser.error("--atloss_mode must be one of 'all', 'mm_only', 'no_mm' (got %r)" % (args.atloss_mode,))

# Select the compute device.  Querying the CUDA device name on a host without
# a GPU raises, so fall back to the CPU with a clear message instead.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print('No GPU available, falling back to the CPU.')

def get_score(count: int) -> float:
    """Turn an answer's vote count into a soft score.

    Each vote is worth one third, and the score is capped at 1.0 (so three or
    more votes give full credit).
    """
    ratio = count / 3
    return ratio if ratio < 1.0 else 1.0

def generate_target_labels(dataset_answers,answer_label_map_dict,data_type,unknown_label=125):
    """Build one dense target vector per example.

    Args:
        dataset_answers: For ``data_type == 'train'``, an iterable where each
            item is a list of dicts carrying an ``"answer"`` entry.  For any
            other ``data_type``, an iterable of single answer strings.
        answer_label_map_dict: Maps an answer string to its label id (anything
            ``int()`` accepts).  Its length sets the width of every target.
        data_type: ``'train'`` selects soft, vote-based targets; every other
            value selects one-hot targets.
        unknown_label: Index set to 1.0 when a non-train answer is missing
            from ``answer_label_map_dict``.  Defaults to 125, the value that
            used to be hard-coded.
            NOTE(review): presumably the index of a catch-all class in the
            label file -- confirm.

    Returns:
        A list of ``numpy`` float arrays, each of length
        ``len(answer_label_map_dict)``.

    Raises:
        IndexError: If a label id (or ``unknown_label``) falls outside the
            target vector.
    """
    num_labels = len(answer_label_map_dict)
    all_targets = []
    if data_type == 'train':
        for answers in tqdm(dataset_answers):
            # Count how many annotators gave each distinct answer.
            answer_count = {}
            for answer in answers:
                answer_ = answer["answer"]
                answer_count[answer_] = answer_count.get(answer_, 0) + 1

            targets = np.zeros(num_labels)
            for answer_, count in answer_count.items():
                # Direct dict membership is O(1); answers outside the
                # vocabulary contribute nothing to the target.
                if answer_ not in answer_label_map_dict:
                    continue
                targets[int(answer_label_map_dict[answer_])] = get_score(count)
            all_targets.append(targets)
    else:
        for mc_answer in tqdm(dataset_answers):
            if mc_answer in answer_label_map_dict:
                label = int(answer_label_map_dict[mc_answer])
            else:
                label = unknown_label
            targets = np.zeros(num_labels)
            targets[label] = 1.0
            all_targets.append(targets)

    return all_targets

# Read the label -> answer table from JSON; it is later looked up with
# stringified label ids.
with open(args.label_dict_path,'r') as f:
    label_answer_map_dict = json.load(f)

# Inverse table (answer -> label) used when building the training targets.
answer_label_map_dict = {answer: label for label, answer in label_answer_map_dict.items()}

# Load the data either raw (targets are generated here) or already preprocessed.
if not args.data_preprocessed:
    if args.data_from_file:
        print("Load dataset from disk ... ")
        dataset_vqav2 = datasets.load_from_disk(args.data_path)
        print("Done ! ")
    else:
        print("Load dataset from hub and cache ... ")
        dataset_vqav2 = datasets.load_dataset("HuggingFaceM4/VQAv2")
        print("Done ! ")

    dataset_train = deepcopy(dataset_vqav2['train'])
    dataset_val = deepcopy(dataset_vqav2['validation'])

    # The prediction stage at the end of the script reads `dataset_test`, which
    # this branch previously never defined (a NameError after training).
    # NOTE(review): assumes the held-out split is named 'test' -- confirm.
    dataset_test = dataset_vqav2['test'] if 'test' in dataset_vqav2 else None
    if dataset_test is None:
        print("Warning: no 'test' split found; the prediction stage will have no data.")

    # Train targets are soft scores from all annotator answers; validation
    # targets are one-hot on the single multiple-choice answer.
    all_targets_trn = generate_target_labels(dataset_train['answers'],answer_label_map_dict,'train')
    all_targets_val = generate_target_labels(dataset_val['multiple_choice_answer'],answer_label_map_dict,'val')

    dataset_train = dataset_train.add_column("labels",all_targets_trn)
    dataset_val = dataset_val.add_column("labels",all_targets_val)
else:
    # All three preprocessed paths are read below, so check all three up front.
    missing_paths = [name for name in ('pp_trn_path', 'pp_dev_path', 'pp_tst_path')
                     if getattr(args, name) is None]
    if missing_paths:
        raise ValueError("--data_preprocessed requires: " + ", ".join("--" + name for name in missing_paths))

    print("Load preprocessed train and dev data ...")
    dataset_train = datasets.load_from_disk(args.pp_trn_path)
    dataset_val = datasets.load_from_disk(args.pp_dev_path)

    # In this mode the dev split is folded into the training data.
    dataset_train = datasets.concatenate_datasets([dataset_train,dataset_val])
    dataset_val = None # Memory flush

    dataset_test = datasets.load_from_disk(args.pp_tst_path)

@dataclass
class VQACollatorForTransfer:
    """Collate raw VQA examples into paired teacher (``_T``) and student (``_S``) inputs.

    The collator accepts either a list of example dicts (a batch) or one
    example dict.  Each example must provide ``'question'``, ``'image'`` (an
    object with ``.convert('RGB')``) and ``'labels'``.

    For a list input every returned tensor keeps its leading batch dimension,
    including when the list holds exactly one example.  The previous
    unconditional ``.squeeze()`` silently removed the batch dimension in that
    case.  For a single example dict the batch dimension is dropped, as before.
    """
    # Tokenizer / image preprocessor feeding the teacher model.
    tokenizer_T: AutoTokenizer
    preprocessor_T: AutoFeatureExtractor
    # Tokenizer / image preprocessor feeding the student model.
    tokenizer_S: AutoTokenizer
    preprocessor_S: AutoFeatureExtractor

    def _tokenize(self, tokenizer, texts, suffix):
        """Tokenize ``texts`` to ``args.max_seq_len`` and key the result with ``suffix``."""
        encoded_text = tokenizer(
            text=texts,
            padding='max_length',
            max_length=args.max_seq_len,
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
        )
        input_ids = encoded_text['input_ids']
        attention_mask = encoded_text['attention_mask']
        if isinstance(texts, str):
            # Single un-batched question: strip only the leading batch axis.
            input_ids = input_ids.squeeze(0)
            attention_mask = attention_mask.squeeze(0)
        return {
            "input_ids" + suffix: input_ids,
            "attention_mask" + suffix: attention_mask,
        }

    def _preprocess(self, preprocessor, images, suffix):
        """Run ``preprocessor`` on ``images`` and key the pixel tensor with ``suffix``."""
        processed_images = preprocessor(
            images,
            return_tensors="pt",
        )
        pixel_values = processed_images['pixel_values']
        if not isinstance(images, (list, tuple)):
            # Single un-batched image: strip only the leading batch axis.
            pixel_values = pixel_values.squeeze(0)
        return {
            "pixel_values" + suffix: pixel_values,
        }

    def tokenize_text_T(self, texts: List[str]):
        """Return ``input_ids_T`` / ``attention_mask_T`` for the teacher."""
        return self._tokenize(self.tokenizer_T, texts, "_T")

    def preprocess_images_T(self, images: List[str]):
        """Return ``pixel_values_T`` for the teacher."""
        return self._preprocess(self.preprocessor_T, images, "_T")

    def tokenize_text_S(self, texts: List[str]):
        """Return ``input_ids_S`` / ``attention_mask_S`` for the student."""
        return self._tokenize(self.tokenizer_S, texts, "_S")

    def preprocess_images_S(self, images: List[str]):
        """Return ``pixel_values_S`` for the student."""
        return self._preprocess(self.preprocessor_S, images, "_S")

    def __call__(self, raw_batch_dict):
        """Build the model input dict from one example dict or a list of them."""
        if isinstance(raw_batch_dict, dict):
            questions = raw_batch_dict['question']
            images = raw_batch_dict['image'].convert('RGB')
            labels = raw_batch_dict['labels']
        else:
            questions = [i['question'] for i in raw_batch_dict]
            # Convert each image once and share the result between both preprocessors.
            images = [i['image'].convert('RGB') for i in raw_batch_dict]
            labels = [i['labels'] for i in raw_batch_dict]

        return {
            **self.tokenize_text_T(questions),
            **self.preprocess_images_T(images),
            **self.tokenize_text_S(questions),
            **self.preprocess_images_S(images),
            'labels': torch.tensor(labels, dtype=torch.float64),
        }

@dataclass
class VQAInferenceCollator:
    """Collate raw VQA examples into label-free inputs for a single model.

    The collator accepts either a list of example dicts (a batch) or one
    example dict.  Each example must provide ``'question'`` and ``'image'``
    (an object with ``.convert('RGB')``).

    For a list input every returned tensor keeps its leading batch dimension,
    including when the list holds exactly one example.  The previous
    unconditional ``.squeeze()`` silently removed the batch dimension in that
    case.  For a single example dict the batch dimension is dropped, as before.
    """
    tokenizer: AutoTokenizer
    preprocessor: AutoFeatureExtractor

    def tokenize_text(self, texts: List[str]):
        """Tokenize ``texts`` to ``args.max_seq_len``; return ids and attention mask."""
        encoded_text = self.tokenizer(
            text=texts,
            padding='max_length',
            max_length=args.max_seq_len,
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
        )
        input_ids = encoded_text['input_ids']
        attention_mask = encoded_text['attention_mask']
        if isinstance(texts, str):
            # Single un-batched question: strip only the leading batch axis.
            input_ids = input_ids.squeeze(0)
            attention_mask = attention_mask.squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def preprocess_images(self, images: List[str]):
        """Run the image preprocessor; return the ``pixel_values`` tensor."""
        processed_images = self.preprocessor(
            images,
            return_tensors="pt",
        )
        pixel_values = processed_images['pixel_values']
        if not isinstance(images, (list, tuple)):
            # Single un-batched image: strip only the leading batch axis.
            pixel_values = pixel_values.squeeze(0)
        return {
            "pixel_values": pixel_values,
        }

    def __call__(self, raw_batch_dict):
        """Build the model input dict from one example dict or a list of them."""
        if isinstance(raw_batch_dict, dict):
            questions = raw_batch_dict['question']
            images = raw_batch_dict['image'].convert('RGB')
        else:
            questions = [i['question'] for i in raw_batch_dict]
            images = [i['image'].convert('RGB') for i in raw_batch_dict]

        return {
            **self.tokenize_text(questions),
            **self.preprocess_images(images),
        }

# Teacher-side preprocessing: one FLAVA processor provides both the tokenizer
# and the image feature extractor passed to the transfer collator below.
processor_T = FlavaProcessor.from_pretrained(args.flava_model)
# Student-side preprocessing: tokenizer from --text_model, image feature
# extractor from --vision_model.
tokenizer_S = AutoTokenizer.from_pretrained(args.text_model)
preprocessor_S = AutoFeatureExtractor.from_pretrained(args.vision_model)
# Training collator: emits teacher (_T) and student (_S) inputs plus labels.
multimodal_collator = VQACollatorForTransfer(tokenizer_T=processor_T.tokenizer,
                                             preprocessor_T=processor_T.feature_extractor,
                                             tokenizer_S=tokenizer_S,
                                             preprocessor_S=preprocessor_S)

# Inference collator: student inputs only, no labels.
collator_inf = VQAInferenceCollator(tokenizer=tokenizer_S, preprocessor=preprocessor_S)

def intuitive_loss(gs,prd):
    """Average number of sign-activation mismatches per row.

    Both inputs are binarised elementwise (1 where the value is > 0, else 0).
    The result is the total number of positions where the two binarised
    arrays differ, divided by the number of rows (``len(gs)``).

    Args:
        gs: Array-like reference values.
        prd: Array-like predicted values with the same shape as ``gs``.

    Returns:
        The mean per-row mismatch count as a ``numpy.float64``.

    Raises:
        AssertionError: If the two shapes differ.
        ZeroDivisionError: If ``gs`` has no rows.
    """
    gs = np.array(gs)
    prd = np.array(prd)

    assert gs.shape == prd.shape

    # The L1 norm of the difference of two 0/1 vectors is the number of
    # positions where they disagree, so one vectorised comparison replaces the
    # per-element Python loops.
    mismatches = int(np.count_nonzero((gs > 0) != (prd > 0)))
    return np.float64(mismatches / len(gs))

def criterion_alternative_L2(source, target, margin=1.0, scaler=1000.0):
    """Margin-based squared penalty between ``source`` and the sign of ``target``.

    Where ``target <= 0`` and ``source > -margin`` the penalty is
    ``(source + margin) ** 2``.  Where ``target > 0`` and ``source <= margin``
    it is ``(source - margin) ** 2``.  All other elements contribute zero.
    The summed penalty is divided by ``scaler``.
    """
    # Elements that should be pushed below -margin.
    negative_mask = (source > -margin) & (target <= 0)
    # Elements that should be pushed above +margin.
    positive_mask = (source <= margin) & (target > 0)

    negative_term = (source + margin) ** 2 * negative_mask.float()
    positive_term = (source - margin) ** 2 * positive_mask.float()

    loss = negative_term + positive_term
    return torch.abs(loss).sum() / scaler

class FlavaForVQA(nn.Module):
    """FLAVA backbone with a classification head over the answer vocabulary.

    The head reads the first position of the multimodal encoder output,
    projects it to ``classifier_hidden_size`` and then applies LayerNorm,
    GELU and a linear layer producing ``num_labels`` logits.

    The ``dropout`` argument is accepted but not used by this class.
    """
    def __init__(self,  pretrained_flava_name, num_labels=len(answer_label_map_dict), dropout=0.5, classifier_hidden_size=1536):
        super().__init__()
        self.num_labels = num_labels
        self.classifier_hidden_size = classifier_hidden_size
        self.pretrained_flava_name = pretrained_flava_name

        self.flava_model = FlavaModel.from_pretrained(pretrained_flava_name)

        # Classification head: projection first, then norm / activation / output layer.
        multimodal_width = self.flava_model.config.multimodal_config.hidden_size
        self.intermediate_linear = nn.Linear(multimodal_width, classifier_hidden_size)
        self.classifier = nn.Sequential(
            nn.LayerNorm(classifier_hidden_size),
            nn.GELU(),
            nn.Linear(classifier_hidden_size, num_labels),
        )

        self.criterion = nn.BCEWithLogitsLoss()

    def forward(
            self,
            input_ids: torch.LongTensor,
            pixel_values: torch.FloatTensor,
            attention_mask: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None):
        """Run FLAVA and the head.

        Returns a dict with the first-position states of the text, image and
        multimodal outputs, the projected intermediate features and the
        logits.  When ``labels`` is given, a BCE-with-logits ``"loss"`` is added.
        """
        flava_out = self.flava_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            return_dict=True,
        )

        text_cls = flava_out['text_output'].last_hidden_state[:, 0, :]
        image_cls = flava_out['image_output'].last_hidden_state[:, 0, :]
        multimodal_cls = flava_out['multimodal_output'].last_hidden_state[:, 0, :]

        intermediate_out = self.intermediate_linear(multimodal_cls)
        logits = self.classifier(intermediate_out)

        # Key order is kept as-is so positional consumers of the outputs see
        # the same layout.
        out = {
            "text_cls": text_cls,
            "image_cls": image_cls,
            "multimodal_cls": multimodal_cls,
            "intermediate_out": intermediate_out,
            "logits": logits,
        }
        if labels is not None:
            out["loss"] = self.criterion(logits, labels)

        return out
    
class BasicStudentModelForVQA(nn.Module):
    """Student VQA model: separate text and image encoders fused by a FLAVA multimodal encoder.

    Text and image token states are linearly projected to the multimodal
    encoder's hidden size, concatenated along the sequence axis (text first)
    and passed through ``flava_mm_encoder``.  The first position of its output
    goes through a linear fusion layer, LayerNorm + GELU and a linear
    classifier.

    The ``dropout`` argument is accepted but not used by this class.
    """
    def __init__(self,  pretrained_text_name, pretrained_image_name, flava_mm_encoder,
                 num_labels=len(answer_label_map_dict), intermediate_dim=1536, dropout=0.5):
        super().__init__()

        self.num_labels = num_labels
        self.pretrained_text_name = pretrained_text_name
        self.pretrained_image_name = pretrained_image_name
        self.intermediate_dim = intermediate_dim

        # Unimodal encoders.
        self.text_encoder = T5EncoderModel.from_pretrained(pretrained_text_name)
        self.image_encoder = AutoModel.from_pretrained(pretrained_image_name)

        # Multimodal encoder supplied by the caller.
        self.flava_mm_encoder = flava_mm_encoder
        fusion_width = flava_mm_encoder.config.hidden_size

        # Bring both modalities to the multimodal encoder's width.
        self.text_projection = nn.Linear(self.text_encoder.config.hidden_size, fusion_width)
        self.image_projection = nn.Linear(self.image_encoder.config.hidden_size, fusion_width)

        self.fusion = nn.Linear(fusion_width, intermediate_dim)
        self.fusion_activations = nn.Sequential(nn.LayerNorm(intermediate_dim),nn.GELU())

        self.classifier = nn.Linear(intermediate_dim, num_labels)

        self.criterion = nn.BCEWithLogitsLoss()

    def forward(
            self,
            input_ids: torch.LongTensor,
            pixel_values: torch.FloatTensor,
            attention_mask: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None):
        """Encode, fuse and classify.

        Returns a dict with the first-position projected text state, projected
        image state and multimodal state, plus the logits.  When ``labels`` is
        given, a BCE-with-logits ``"loss"`` is added.
        """
        text_states = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )['last_hidden_state']
        image_states = self.image_encoder(
            pixel_values=pixel_values,
            return_dict=True,
        )['last_hidden_state']

        projected_text = self.text_projection(text_states)
        projected_image = self.image_projection(image_states)

        # Text tokens come first, so position 0 of the fused sequence is the
        # first text position.
        fused_input = torch.cat([projected_text,projected_image], dim=1)
        multimodal_output = self.flava_mm_encoder(fused_input)['last_hidden_state']
        multimodal_cls = multimodal_output[:, 0, :]

        fused_features = self.fusion_activations(self.fusion(multimodal_cls))
        logits = self.classifier(fused_features)

        # Key order is kept as-is: the prediction code indexes these outputs
        # by position.
        out = {
            "text_cls": projected_text[:, 0, :],
            "image_cls": projected_image[:, 0, :],
            "multimodal_cls": multimodal_cls,
            "logits": logits,
        }
        if labels is not None:
            out["loss"] = self.criterion(logits, labels)

        return out


# --- Teacher: fine-tuned FLAVA VQA model -------------------------------------
# The teacher checkpoint is mandatory; fail with a clear message rather than
# inside torch.load(None).
if args.finetuned_flava_model is None:
    raise ValueError("--finetuned_flava_model is required: path to the fine-tuned FLAVA teacher state dict")

model_flava_ft = FlavaForVQA(pretrained_flava_name=args.flava_model,dropout=args.dropout)

# Load onto the CPU: the teacher module is still on the CPU here, and `sd`
# stays referenced afterwards, so a checkpoint saved from a GPU would
# otherwise keep a second copy of the weights in GPU memory.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
sd = torch.load(args.finetuned_flava_model, map_location='cpu')

model_flava_ft.load_state_dict(sd)

# --- Student: fresh encoders plus a copy of FLAVA's pretrained multimodal encoder
model_flava = FlavaModel.from_pretrained(args.flava_model)
flava_mm_encoder = deepcopy(model_flava.multimodal_model)
model_flava=None # Mem flush

model_student_ft = BasicStudentModelForVQA(pretrained_text_name=args.text_model,
                                           pretrained_image_name=args.vision_model,
                                           flava_mm_encoder=flava_mm_encoder,
                                           intermediate_dim=args.intermediate_dim).to(device)

if args.finetuned_student_model is not None:
    print("Restoring student model from fine-tuned checkpoint ... ")

    # load_state_dict copies the values into the parameters wherever they
    # live, so the checkpoint itself can be staged on the CPU.
    model_student_ft.load_state_dict(torch.load(args.finetuned_student_model, map_location='cpu'))
else:
    print("Initialize the student model ... ")

class TransferModelForVQA(nn.Module):
    """Wrap a teacher and a student model for distillation-style training.

    The teacher runs without gradients.  The student is trained with a
    weighted sum of the BCE-with-logits task loss and MSE losses between the
    student's and teacher's ``text_cls`` / ``image_cls`` / ``multimodal_cls``
    outputs.  Which MSE terms are used is chosen by ``args.atloss_mode``
    (``'all'``, ``'no_mm'``, anything else means multimodal only).  The
    weighting is ``args.atloss_ratio``.
    """
    def __init__(self,  teacher_model, student_model, num_labels=len(answer_label_map_dict)):
        super(TransferModelForVQA, self).__init__()

        self.num_labels = num_labels
        self.teacher_model = teacher_model
        self.student_model = student_model

        self.criterion1 = nn.BCEWithLogitsLoss()
        self.criterion2 = nn.MSELoss()

    def train(self, mode: bool = True):
        """Set the training mode of the wrapper while keeping the teacher in eval mode.

        The teacher only provides fixed targets, so train-time layers such as
        dropout should stay disabled in it even when the wrapper is in
        training mode.
        """
        super().train(mode)
        self.teacher_model.eval()
        return self

    def forward(
            self,
            input_ids_T: torch.LongTensor,
            pixel_values_T: torch.FloatTensor,
            input_ids_S: torch.LongTensor,
            pixel_values_S: torch.FloatTensor,
            attention_mask_T: Optional[torch.LongTensor] = None,
            attention_mask_S: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None):
        """Run teacher and student; return the student logits and, with labels, the combined loss.

        Args:
            input_ids_T, attention_mask_T, pixel_values_T: Teacher inputs.
            input_ids_S, attention_mask_S, pixel_values_S: Student inputs.
            labels: Targets for the BCE-with-logits loss; when ``None`` no
                loss is computed.

        Returns:
            A dict with ``"logits"`` (student logits) and, if ``labels`` is
            given, ``"loss"``.
        """
        with torch.no_grad():
            output_teacher = self.teacher_model(input_ids=input_ids_T,
                                                attention_mask=attention_mask_T,
                                                pixel_values=pixel_values_T)

        # Labels are not forwarded to the student: its own loss would be the
        # same BCE that is computed below and then discarded.
        output_student = self.student_model(input_ids=input_ids_S,
                                            attention_mask=attention_mask_S,
                                            pixel_values=pixel_values_S)

        out = {
            "logits": output_student["logits"]
        }
        if labels is not None:
            loss = self.criterion1(output_student["logits"], labels)

            # Pick the representations to match; only those terms are computed.
            if args.atloss_mode == 'all':
                matched_keys = ('text_cls', 'image_cls', 'multimodal_cls')
            elif args.atloss_mode == 'no_mm':
                matched_keys = ('text_cls', 'image_cls')
            else:
                matched_keys = ('multimodal_cls',)

            loss_AT = sum(self.criterion2(output_student[key], output_teacher[key])
                          for key in matched_keys)

            out["loss"] = (1 - args.atloss_ratio) * loss + args.atloss_ratio * loss_AT

        return out

# Freeze every teacher parameter so only the student receives gradients.
model_flava_ft.requires_grad_(False)

# Bundle teacher and student into a single trainable module and move it to
# the selected device.
model_transfer = TransferModelForVQA(
    teacher_model=model_flava_ft,
    student_model=model_student_ft,
)
model_transfer.to(device)

def compute_metrics(eval_tuple: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
    """Compute accuracy and macro-F1 from a ``(logits, labels)`` pair.

    Both arrays are reduced with an argmax over their last axis, so the
    highest-scoring class is compared with the highest-valued target entry.
    """
    scores, targets = eval_tuple
    predicted_ids = scores.argmax(axis=-1)
    gold_ids = targets.argmax(axis=-1)

    accuracy = accuracy_score(gold_ids, predicted_ids)
    macro_f1 = f1_score(gold_ids, predicted_ids, average='macro')
    return {"acc": accuracy, "f1": macro_f1}

# Trainer configuration shared by the training run and the later prediction run.
# Evaluation is disabled during training; a checkpoint is saved every epoch and
# at most five are kept.
# NOTE(review): args.fp16 is parsed on the command line but is not forwarded here.
multi_args = TrainingArguments(
    output_dir=args.output_dir,
    seed=args.random_seed, 
    learning_rate=args.lr,
    # Alternative step-based evaluation/saving schedule, currently disabled:
    #evaluation_strategy="steps",
    #eval_steps=100,
    #save_strategy="steps",
    #save_steps=100,
    evaluation_strategy="no",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=500,
    save_total_limit=5,
    metric_for_best_model='acc',
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    # The collators read the raw 'question' / 'image' / 'labels' columns, so
    # the Trainer must not drop columns it does not recognise.
    remove_unused_columns=False,
    num_train_epochs=args.epochs,
    dataloader_num_workers=8,
    load_best_model_at_end=False,
    lr_scheduler_type=args.scheduler_name,
    eval_accumulation_steps=2000,
    warmup_ratio=args.warmup_ratio,
    gradient_accumulation_steps=args.ga_steps,
)

# Train the teacher/student transfer model.  No evaluation set is attached;
# batches are built by the transfer collator (teacher + student inputs + labels).
multi_trainer = Trainer(
    model_transfer,
    multi_args,
    train_dataset=dataset_train,
    eval_dataset=None,
    data_collator=multimodal_collator,
    compute_metrics=compute_metrics
)

# Runs the full training loop; the return value is kept in train_multi_metrics.
train_multi_metrics = multi_trainer.train()

# Move the trained transfer model to the CPU *before* snapshotting its weights.
# A state dict taken while the model is on the GPU keeps referencing the GPU
# tensors, so the memory flush below would not release them.
model_transfer = model_transfer.cpu()
st_dict_old = model_transfer.state_dict()
del model_transfer
torch.cuda.empty_cache()

# Rebuild a stand-alone student with the same architecture as the one trained
# inside the transfer model.
model_flava = FlavaModel.from_pretrained(args.flava_model)
flava_mm_encoder = deepcopy(model_flava.multimodal_model)
model_flava=None # Mem flush
model = BasicStudentModelForVQA(pretrained_text_name=args.text_model,
                                 pretrained_image_name=args.vision_model,
                                 flava_mm_encoder=flava_mm_encoder,
                                 intermediate_dim=args.intermediate_dim).to(device)

# Keep only the student's weights and strip the wrapper prefix.  A prefix
# check and slice are used so a parameter name that merely contains the text
# elsewhere cannot be matched or mangled.
student_prefix = 'student_model.'
sd_new = OrderedDict()
for k,v in st_dict_old.items():
    if k.startswith(student_prefix):
        sd_new[k[len(student_prefix):]] = v
model.load_state_dict(sd_new)

# A second Trainer is used only as a batched prediction loop for the
# stand-alone student; it uses the label-free inference collator.
multi_trainer_inf = Trainer(
    model=model,
    args=multi_args,
    train_dataset=dataset_train,
    eval_dataset=None,
    data_collator=collator_inf,
    compute_metrics=compute_metrics,
)

prd_results = multi_trainer_inf.predict(dataset_test)

# The student returns (text_cls, image_cls, multimodal_cls, logits) in that
# order, so position 3 of the predictions holds the logits.
student_logits = prd_results.predictions[3]
prd_label_ids = np.argmax(student_logits, axis=-1)

# Pair every predicted answer string with its question id.
final_result_output = [
    {
        'answer': label_answer_map_dict[str(label_id)],
        'question_id': qid,
    }
    for label_id, qid in zip(prd_label_ids, dataset_test['question_id'])
]

results_path = os.path.join(args.output_dir, "results.json")
with open(results_path, "w") as f:
    json.dump(final_result_output, f)