import torch
import torch.nn as nn
import torch.nn.functional as K

import transformers

from transformers import ViTConfig, ViTModel
from transformers import T5Config, T5EncoderModel, T5Tokenizer
from transformers import FlavaConfig, FlavaProcessor, FlavaModel
from transformers import AutoTokenizer, AutoProcessor, AutoFeatureExtractor, AutoModel
from transformers import TrainingArguments, Trainer, logging
from transformers import BertLayer

import os
import numpy as np
import pandas as pd
import json
from copy import deepcopy

import datasets
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score

from tqdm.notebook import tqdm

import argparse
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--fp16 False`` would
    silently switch the option on.  This converter accepts the usual
    spellings and rejects anything else with a regular argparse error.
    The empty string maps to ``False``, as it did under ``type=bool``.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()

# --- Data locations -------------------------------------------------------
parser.add_argument('--label_dict_path',default=None,type=str)
parser.add_argument('--data_from_file',default=False,type=_str2bool)
parser.add_argument('--data_path',default=None,type=str)
parser.add_argument('--data_preprocessed',default=False,type=_str2bool)
parser.add_argument('--pp_trn_path',default=None,type=str)
parser.add_argument('--pp_dev_path',default=None,type=str)
parser.add_argument('--pp_tst_path',default=None,type=str)

# --- Pretrained checkpoints -----------------------------------------------
parser.add_argument('--text_model',default='t5-base',type=str)
parser.add_argument('--vision_model',default="google/vit-base-patch16-224-in21k",type=str)
parser.add_argument('--flava_model',default='facebook/flava-full',type=str)

parser.add_argument('--output_dir',default=None,type=str)

# --- Training hyper-parameters --------------------------------------------
# 'pt' loads the pretrained multimodal encoder, 'rnd' uses the same
# architecture with random weights.  `choices` replaces the former `assert`,
# which would be stripped when Python runs with -O.
parser.add_argument('--mm_init',default='pt',type=str,choices=['rnd','pt'])
parser.add_argument('--max_seq_len',default=24,type=int)
parser.add_argument('--batch_size',default=10,type=int)
parser.add_argument('--epochs',default=5,type=int)
parser.add_argument('--lr',default=3e-5,type=float)
parser.add_argument('--random_seed',default=42,type=int)
parser.add_argument('--dropout',default=0.5,type=float)
parser.add_argument('--scheduler_name',default='linear',type=str)
parser.add_argument('--intermediate_dim',default=1536,type=int)
parser.add_argument('--fp16',default=False,type=_str2bool)
parser.add_argument('--warmup_ratio',default=0.0,type=float)
parser.add_argument('--ga_steps',default=1,type=int)

args = parser.parse_args()

print(args)

# Tokenization happens inside DataLoader workers; turn off the tokenizers
# library's own parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# The script targets a CUDA machine; get_device_name(0) raises if no GPU exists.
device = torch.device("cuda")
print('There are %d GPU(s) available.' % torch.cuda.device_count())
print('We will use the GPU:', torch.cuda.get_device_name(0))

def get_score(count: int) -> float:
    """Return the soft score for an answer given ``count`` times.

    The score is ``count / 3`` and is capped at ``1.0``, so three or more
    occurrences earn full credit.
    """
    fraction = count / 3
    return fraction if fraction < 1.0 else 1.0

def generate_target_labels(dataset_answers,answer_label_map_dict,data_type,unknown_label=125):
    """Build one dense target vector per example.

    Args:
        dataset_answers: For ``data_type == 'train'``, an iterable whose items
            are iterables of dicts carrying an ``"answer"`` key (all answers
            given for one question).  For any other ``data_type``, an iterable
            of single answers.
        answer_label_map_dict: Maps an answer to its label id.  Ids are passed
            through ``int()``, so numeric strings are accepted.  The dict size
            is the length of every target vector.
        data_type: ``'train'`` produces soft targets; any other value produces
            one-hot targets.
        unknown_label: Index set to 1.0 when a (non-train) answer is not in
            the dictionary.  Defaults to 125, the value previously hard-coded
            here.  NOTE(review): presumably the id of a placeholder answer in
            the label dictionary - confirm.  It must be smaller than
            ``len(answer_label_map_dict)`` or indexing raises ``IndexError``.

    Returns:
        A list of 1-D ``numpy`` arrays of length ``len(answer_label_map_dict)``.
        In train mode each known answer's slot holds ``get_score(count)`` and
        answers missing from the dictionary are ignored.
    """
    num_labels = len(answer_label_map_dict)
    all_targets = []
    if data_type == 'train':
        for answers in tqdm(dataset_answers):
            # Count how often each distinct answer was given for this question.
            answer_count = {}
            for answer in answers:
                answer_ = answer["answer"]
                answer_count[answer_] = answer_count.get(answer_, 0) + 1

            targets = np.zeros(num_labels)
            for answer_, count in answer_count.items():
                # Direct dict membership: O(1) instead of rebuilding and
                # scanning a list of all keys for every answer.
                if answer_ not in answer_label_map_dict:
                    continue
                targets[int(answer_label_map_dict[answer_])] = get_score(count)
            all_targets.append(targets)
    else:
        for mc_answer in tqdm(dataset_answers):
            if mc_answer in answer_label_map_dict:
                label = int(answer_label_map_dict[mc_answer])
            else:
                label = unknown_label
            targets = np.zeros(num_labels)
            targets[label] = 1.0
            all_targets.append(targets)

    return all_targets

# --- Label dictionary ---------------------------------------------------------
if args.label_dict_path is None:
    raise ValueError("--label_dict_path is required (JSON file mapping label id -> answer)")

with open(args.label_dict_path,'r',encoding='utf-8') as f:
    label_answer_map_dict = json.load(f)
# Inverse lookup: answer -> label id (ids stay JSON strings; consumers call int()).
answer_label_map_dict = {v: k for k, v in label_answer_map_dict.items()}

# --- Datasets -----------------------------------------------------------------
if not args.data_preprocessed:
    if args.data_from_file:
        print("Load dataset from disk ... ")
        dataset_vqav2 = datasets.load_from_disk(args.data_path)
        print("Done ! ")
    else:
        print("Load dataset from hub and cache ... ")
        dataset_vqav2 = datasets.load_dataset("HuggingFaceM4/VQAv2")
        print("Done ! ")

    # No defensive copy is needed: the splits are only read, and add_column
    # returns a new dataset rather than modifying the one it is called on.
    dataset_train = dataset_vqav2['train']
    dataset_val = dataset_vqav2['validation']

    all_targets_trn = generate_target_labels(dataset_train['answers'],answer_label_map_dict,'train')
    all_targets_val = generate_target_labels(dataset_val['multiple_choice_answer'],answer_label_map_dict,'val')

    dataset_train = dataset_train.add_column("labels",all_targets_trn)
    dataset_val = dataset_val.add_column("labels",all_targets_val)

    # This path prepares no test split.  Bind the name explicitly so later
    # code can test for None instead of hitting a NameError.
    dataset_test = None
else:
    # Explicit validation (not `assert`, which is stripped under -O); the test
    # path is checked as well because it is loaded unconditionally below.
    missing_paths = [
        '--' + name
        for name in ('pp_trn_path', 'pp_dev_path', 'pp_tst_path')
        if getattr(args, name) is None
    ]
    if missing_paths:
        raise ValueError(
            "--data_preprocessed requires: " + ", ".join(missing_paths)
        )

    print("Load preprocessed train and dev data ...")
    dataset_train = datasets.load_from_disk(args.pp_trn_path)
    dataset_val = datasets.load_from_disk(args.pp_dev_path)

    # Train on train + dev combined; the separate dev handle is then dropped.
    dataset_train = datasets.concatenate_datasets([dataset_train,dataset_val])
    dataset_val = None # Memory flush

    dataset_test = datasets.load_from_disk(args.pp_tst_path)


@dataclass
class MultimodalCollator:
    """Collate raw VQA examples (question, image, labels) into training tensors.

    The instance is called either with a list of example dicts (the normal
    DataLoader case) or with a single example dict.  A list always yields
    tensors with a leading batch dimension, even when it holds one example.
    A single dict yields un-batched tensors.
    """
    # Quoted annotations: they document the intended factory only and are
    # never evaluated, so the class body does not need these names in scope.
    tokenizer: "AutoTokenizer"
    preprocessor: "AutoFeatureExtractor"
    # Pad/truncate length for questions; None defers to the script-level
    # ``args.max_seq_len`` so existing two-argument construction still works.
    max_seq_len: Optional[int] = None

    def tokenize_text(self, texts: List[str]):
        """Tokenize ``texts`` (a list of questions, or one question string).

        Returns a dict with ``input_ids`` and ``attention_mask``.  The leading
        dimension is removed only when a single string was passed.
        """
        max_length = args.max_seq_len if self.max_seq_len is None else self.max_seq_len
        encoded_text = self.tokenizer(
            text=texts,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
        )
        input_ids = encoded_text['input_ids']
        attention_mask = encoded_text['attention_mask']
        if isinstance(texts, str):
            # Single example: drop only the batch axis.  An unconditional
            # squeeze() would also collapse a real batch of size one and hand
            # the model tensors without a batch dimension.
            input_ids = input_ids.squeeze(0)
            attention_mask = attention_mask.squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def preprocess_images(self, images):
        """Run the image preprocessor on ``images`` (a list of images or one image).

        Returns a dict with ``pixel_values``.  The leading dimension is removed
        only when a single image (not a list/tuple) was passed.
        """
        processed_images = self.preprocessor(
            images,
            return_tensors="pt",
        )
        pixel_values = processed_images['pixel_values']
        if not isinstance(images, (list, tuple)):
            pixel_values = pixel_values.squeeze(0)
        return {
            "pixel_values": pixel_values,
        }

    def __call__(self, raw_batch_dict):
        """Build the model inputs plus float64 ``labels`` for one batch or example."""
        if isinstance(raw_batch_dict, dict):
            questions = raw_batch_dict['question']
            images = raw_batch_dict['image'].convert('RGB')
            labels = raw_batch_dict['labels']
        else:
            questions = [i['question'] for i in raw_batch_dict]
            images = [i['image'].convert('RGB') for i in raw_batch_dict]
            labels = [i['labels'] for i in raw_batch_dict]
        return {
            **self.tokenize_text(questions),
            **self.preprocess_images(images),
            'labels': torch.tensor(labels, dtype=torch.float64),
        }


@dataclass
class VQAInferenceCollator:
    """Collate raw VQA examples (question, image) into inference tensors.

    Produces ``input_ids``, ``attention_mask`` and ``pixel_values`` but no
    ``labels``.  The instance is called either with a list of example dicts or
    with a single example dict.  A list always yields tensors with a leading
    batch dimension, even when it holds one example.  A single dict yields
    un-batched tensors.
    """
    # Quoted annotations: they document the intended factory only and are
    # never evaluated, so the class body does not need these names in scope.
    tokenizer: "AutoTokenizer"
    preprocessor: "AutoFeatureExtractor"
    # Pad/truncate length for questions; None defers to the script-level
    # ``args.max_seq_len`` so existing two-argument construction still works.
    max_seq_len: Optional[int] = None

    def tokenize_text(self, texts: List[str]):
        """Tokenize ``texts`` (a list of questions, or one question string).

        Returns a dict with ``input_ids`` and ``attention_mask``.  The leading
        dimension is removed only when a single string was passed.
        """
        max_length = args.max_seq_len if self.max_seq_len is None else self.max_seq_len
        encoded_text = self.tokenizer(
            text=texts,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
        )
        input_ids = encoded_text['input_ids']
        attention_mask = encoded_text['attention_mask']
        if isinstance(texts, str):
            # Single example: drop only the batch axis.  An unconditional
            # squeeze() would also collapse a real batch of size one and hand
            # the model tensors without a batch dimension.
            input_ids = input_ids.squeeze(0)
            attention_mask = attention_mask.squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def preprocess_images(self, images):
        """Run the image preprocessor on ``images`` (a list of images or one image).

        Returns a dict with ``pixel_values``.  The leading dimension is removed
        only when a single image (not a list/tuple) was passed.
        """
        processed_images = self.preprocessor(
            images,
            return_tensors="pt",
        )
        pixel_values = processed_images['pixel_values']
        if not isinstance(images, (list, tuple)):
            pixel_values = pixel_values.squeeze(0)
        return {
            "pixel_values": pixel_values,
        }

    def __call__(self, raw_batch_dict):
        """Build the label-free model inputs for one batch or one example."""
        if isinstance(raw_batch_dict, dict):
            questions = raw_batch_dict['question']
            images = raw_batch_dict['image'].convert('RGB')
        else:
            questions = [i['question'] for i in raw_batch_dict]
            images = [i['image'].convert('RGB') for i in raw_batch_dict]
        return {
            **self.tokenize_text(questions),
            **self.preprocess_images(images),
        }


class BasicStudentModelForVQA(nn.Module):
    """Student VQA classifier: text encoder + image encoder + shared multimodal encoder.

    Text and image token states are each projected linearly to the multimodal
    encoder's hidden size, concatenated along the sequence axis (text first)
    and passed through ``flava_mm_encoder``.  Position 0 of that encoder's
    output feeds a Linear -> LayerNorm -> GELU -> Linear head producing
    ``num_labels`` logits.

    Args:
        pretrained_text_name: Checkpoint name/path loaded with ``T5EncoderModel``.
        pretrained_image_name: Checkpoint name/path loaded with ``AutoModel``.
        flava_mm_encoder: Ready-built multimodal encoder module exposing
            ``config.hidden_size`` and returning ``last_hidden_state``.
        num_labels: Size of the answer vocabulary (classifier width).
        intermediate_dim: Width of the fusion layer before the classifier.
        dropout: NOTE(review): accepted but not used anywhere in this class.
    """

    def __init__(self,  pretrained_text_name, pretrained_image_name, flava_mm_encoder,
                 num_labels=len(answer_label_map_dict), intermediate_dim=1536, dropout=0.5):
        super().__init__()

        # Plain configuration attributes.
        self.num_labels = num_labels
        self.pretrained_text_name = pretrained_text_name
        self.pretrained_image_name = pretrained_image_name
        self.intermediate_dim = intermediate_dim

        # Unimodal backbones.
        self.text_encoder = T5EncoderModel.from_pretrained(self.pretrained_text_name)
        self.image_encoder = AutoModel.from_pretrained(self.pretrained_image_name)

        # Multimodal encoder supplied by the caller.
        self.flava_mm_encoder = flava_mm_encoder
        mm_hidden = self.flava_mm_encoder.config.hidden_size

        # Bring each modality to the multimodal encoder's hidden size.
        self.text_projection = nn.Linear(self.text_encoder.config.hidden_size, mm_hidden)
        self.image_projection = nn.Linear(self.image_encoder.config.hidden_size, mm_hidden)

        # Classification head.
        self.fusion = nn.Linear(mm_hidden, self.intermediate_dim)
        self.fusion_activations = nn.Sequential(nn.LayerNorm(self.intermediate_dim),nn.GELU())
        self.classifier = nn.Linear(self.intermediate_dim, self.num_labels)

        # Multi-label objective: targets are per-answer scores, not class ids.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(
            self,
            input_ids: torch.LongTensor,
            pixel_values: torch.FloatTensor,
            attention_mask: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None):
        """Run the full model.

        Returns a dict with ``text_cls``, ``image_cls``, ``multimodal_cls``
        (position-0 states of the three encoders) and ``logits``, in that
        order, plus ``loss`` (``BCEWithLogitsLoss`` against ``labels``) when
        ``labels`` is given.  ``attention_mask`` is forwarded to the text
        encoder only.
        """
        text_states = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )['last_hidden_state']
        image_states = self.image_encoder(
            pixel_values=pixel_values,
            return_dict=True,
        )['last_hidden_state']

        # Joint sequence: projected text tokens followed by projected image tokens.
        joint_sequence = torch.cat(
            [self.text_projection(text_states), self.image_projection(image_states)],
            dim=1,
        )
        fused_states = self.flava_mm_encoder(joint_sequence)['last_hidden_state']
        fused_cls = fused_states[:,0,:]

        logits = self.classifier(self.fusion_activations(self.fusion(fused_cls)))

        outputs = {
            "text_cls": text_states[:,0,:],
            "image_cls": image_states[:,0,:],
            "multimodal_cls": fused_cls,
            "logits": logits,
        }
        if labels is None:
            return outputs

        outputs["loss"] = self.criterion(logits, labels)
        return outputs


def compute_metrics(eval_tuple: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
    """Compute accuracy and macro-F1 from an evaluation ``(predictions, labels)`` pair.

    The gold class is the argmax of each label vector; the predicted class is
    the argmax of element 3 of the predictions.
    NOTE(review): index 3 presumably corresponds to "logits", the fourth key
    of the model's output dict - confirm against the trainer's output order.
    """
    model_outputs, target_vectors = eval_tuple
    gold_ids = target_vectors.argmax(axis=-1)
    predicted_ids = model_outputs[3].argmax(axis=-1)

    metrics = {}
    metrics["acc"] = accuracy_score(gold_ids, predicted_ids)
    metrics["f1"] = f1_score(gold_ids, predicted_ids, average='macro')
    return metrics


# Seed Python/NumPy/PyTorch before any weights are created.  The model below is
# built before the Trainer exists, so without this its freshly initialised
# layers (projections, fusion head, classifier and, for --mm_init rnd, the
# whole multimodal encoder) would differ between runs with the same seed.
transformers.set_seed(args.random_seed)

multi_args = TrainingArguments(
    output_dir=args.output_dir,
    seed=args.random_seed,
    learning_rate=args.lr,
    # No evaluation during training; a checkpoint is written after every epoch.
    evaluation_strategy="no",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=500,
    save_total_limit=5,             # Models are large: keep at most 5 checkpoints on disk
    metric_for_best_model='acc',
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    remove_unused_columns=False,    # the collators need the raw 'question'/'image' columns
    num_train_epochs=args.epochs,
    dataloader_num_workers=8,
    load_best_model_at_end=False,
    lr_scheduler_type=args.scheduler_name,
    eval_accumulation_steps=500,
    gradient_accumulation_steps=args.ga_steps,
    warmup_ratio=args.warmup_ratio,
    fp16=args.fp16,                 # previously parsed from the CLI but never applied
)

# Build a FLAVA model only to borrow its multimodal encoder.
if args.mm_init == 'pt':
    model_flava = FlavaModel.from_pretrained(args.flava_model)
elif args.mm_init == 'rnd':
    # Same architecture with FLAVA, but random init
    config_flava = FlavaConfig.from_pretrained(args.flava_model)
    model_flava = FlavaModel(config=config_flava)
else:
    raise ValueError("--mm_init must be 'pt' or 'rnd', got %r" % (args.mm_init,))

# Keep a reference to the sub-module and drop the parent; a deepcopy is not
# needed and would briefly hold two copies of the encoder in memory.
flava_mm_encoder = model_flava.multimodal_model
del model_flava # Mem flush

tokenizer = AutoTokenizer.from_pretrained(args.text_model)
preprocessor = AutoFeatureExtractor.from_pretrained(args.vision_model)
collator = MultimodalCollator(tokenizer=tokenizer, preprocessor=preprocessor)

collator_inf = VQAInferenceCollator(tokenizer=tokenizer, preprocessor=preprocessor)

model = BasicStudentModelForVQA(pretrained_text_name=args.text_model,
                                 pretrained_image_name=args.vision_model,
                                 flava_mm_encoder=flava_mm_encoder,
                                 intermediate_dim=args.intermediate_dim).to(device)


print("#"*20, "Show Model Architecture","#"*20)
print(model)

# --- Training -----------------------------------------------------------------
multi_trainer = Trainer(
    model,
    multi_args,
    train_dataset=dataset_train,
    eval_dataset=None,
    data_collator=collator,
    compute_metrics=compute_metrics
)

train_multi_metrics = multi_trainer.train()

# --- Prediction on the test split -----------------------------------------------
# The test split only exists when preprocessed data was loaded.  Look it up
# defensively so a finished training run does not end in a NameError.
test_split = globals().get('dataset_test')

if test_split is None:
    print("No test split was loaded; skipping prediction.")
else:
    # Second Trainer so prediction uses the label-free inference collator.
    multi_trainer = Trainer(
        model,
        multi_args,
        train_dataset=dataset_train,
        eval_dataset=None,
        data_collator=collator_inf,
        compute_metrics=compute_metrics
    )

    prd_results = multi_trainer.predict(test_split)

    # "logits" is the fourth key of the model's output dict.
    # NOTE(review): relies on the predictions tuple keeping that order - confirm.
    LOGITS_INDEX = 3
    prd_label_ids = np.argmax(prd_results[0][LOGITS_INDEX],axis=-1)

    # The label dictionary comes from JSON, so its keys are strings.
    final_result_output = [
        {'answer': label_answer_map_dict[str(label_id)], 'question_id': qid}
        for label_id, qid in zip(prd_label_ids,test_split['question_id'])
    ]

    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir,"results.json"),"w") as f:
        json.dump(final_result_output,f)
