import torch
import torch.nn as nn
import torch.nn.functional as K

import transformers

from transformers import ViTConfig, ViTModel
from transformers import T5Config, T5EncoderModel, T5Tokenizer
from transformers import FlavaConfig, FlavaProcessor, FlavaModel
from transformers import AutoTokenizer, AutoProcessor, AutoFeatureExtractor
from transformers import TrainingArguments, Trainer, logging

import os
import numpy as np
import pandas as pd
import json
from copy import deepcopy

import datasets
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score

import argparse
def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--do_train False`` would
    silently enable the flag.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, matching what bool("") returned before.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (e.g. True or False), got %r" % (value,)
    )


parser = argparse.ArgumentParser()

# Data / model locations.
parser.add_argument('--data_path', default=None, type=str)
parser.add_argument('--flava_model', default="facebook/flava-full", type=str)
parser.add_argument('--output_dir', default=None, type=str)

# Optimisation and model hyper-parameters.
parser.add_argument('--max_seq_len', default=24, type=int)
parser.add_argument('--batch_size', default=10, type=int)
parser.add_argument('--epochs', default=5, type=int)
parser.add_argument('--lr', default=3e-5, type=float)
parser.add_argument('--random_seed', default=42, type=int)
# NOTE(review): --dropout is parsed but not read anywhere in this script as shown.
parser.add_argument('--dropout', default=0.1, type=float)
parser.add_argument('--scheduler_name', default='linear', type=str)
parser.add_argument('--intermediate_dim', default=1536, type=int)
parser.add_argument('--warmup_ratio', default=0.0, type=float)
parser.add_argument('--confidence_penalty_weight', default=0.0, type=float)
parser.add_argument('--ga_steps', default=1, type=int)

# Boolean switches, given as e.g. ``--do_train True``.
# NOTE(review): these three flags are not consulted by the visible script,
# which always trains and then predicts on the test split.
parser.add_argument('--do_train', default=False, type=_str2bool)
parser.add_argument('--do_eval', default=False, type=_str2bool)
parser.add_argument('--do_predict', default=False, type=_str2bool)

args = parser.parse_args()

# Disable parallelism inside the tokenizers library (presumably to avoid its
# fork warning when DataLoader worker processes are used -- TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Prefer CUDA, but fall back to the CPU instead of crashing inside
# torch.cuda.get_device_name() on machines without a usable GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('There are %d GPU(s) available.' % torch.cuda.device_count())
if device.type == "cuda":
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU found, falling back to the CPU.')

# Load the three splits from the 'train', 'dev' and 'test' sub-paths of
# --data_path, in that order.
print("Load dataset from disk ... ")
dataset_train, dataset_dev, dataset_test = (
    datasets.load_from_disk(os.path.join(args.data_path, split_name))
    for split_name in ('train', 'dev', 'test')
)
print("Done ! ")


@dataclass
class FlavaSNLICollator:
    """Collate raw examples into the tensor dict consumed by the classifier.

    Every example must provide a ``'hypothesis'`` (text), an ``'image'``
    (an object exposing ``.convert('RGB')``, presumably a PIL image) and a
    ``'label'`` (integer class id) entry.
    """

    # Callable that tokenizes the hypothesis text.
    tokenizer: AutoTokenizer
    # Callable that converts images into a ``pixel_values`` tensor.
    preprocessor: AutoFeatureExtractor

    def tokenize_text(self, texts: List[str]):
        """Tokenize text, padded/truncated to ``args.max_seq_len`` tokens.

        Args:
            texts: A list of strings (a batch) or a single string.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` tensors. A list
            input always keeps its leading batch dimension, even when it
            contains a single string; only a bare string yields unbatched
            tensors.
        """
        encoded_text = self.tokenizer(
            text=texts,
            padding='max_length',
            max_length=args.max_seq_len,
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
        )
        input_ids = encoded_text['input_ids']
        attention_mask = encoded_text['attention_mask']
        # A blanket ``.squeeze()`` would also strip the batch dimension of a
        # one-example batch and break the model; only drop the leading axis
        # when the caller passed a single string.
        if isinstance(texts, str):
            input_ids = input_ids.squeeze(0)
            attention_mask = attention_mask.squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def preprocess_images(self, images):
        """Convert images into the ``pixel_values`` tensor.

        Args:
            images: A list/tuple of images (a batch) or a single image.

        Returns:
            Dict with ``pixel_values``. A list/tuple input keeps its leading
            batch dimension even for a single image.
        """
        processed_images = self.preprocessor(
            images,
            return_tensors="pt",
        )
        pixel_values = processed_images['pixel_values']
        # Same reasoning as in tokenize_text: never squeeze a real batch.
        if not isinstance(images, (list, tuple)):
            pixel_values = pixel_values.squeeze(0)
        return {
            "pixel_values": pixel_values,
        }

    def __call__(self, raw_batch_dict):
        """Build model inputs from one example (dict) or a list of examples.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``pixel_values`` and
            int64 ``labels``.
        """
        if isinstance(raw_batch_dict, dict):
            # Single example: the outputs carry no batch dimension.
            texts = raw_batch_dict['hypothesis']
            images = raw_batch_dict['image'].convert('RGB')
            labels = raw_batch_dict['label']
        else:
            # Sequence of examples: gather every field into a list.
            texts = [example['hypothesis'] for example in raw_batch_dict]
            images = [example['image'].convert('RGB') for example in raw_batch_dict]
            labels = [example['label'] for example in raw_batch_dict]

        return {
            **self.tokenize_text(texts),
            **self.preprocess_images(images),
            'labels': torch.tensor(labels, dtype=torch.int64),
        }

def compute_metrics(eval_tuple: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
    """Compute accuracy and micro-averaged F1 for an evaluation pass.

    Args:
        eval_tuple: A ``(outputs, labels)`` pair. ``outputs`` is indexed at
            position 3, which presumably holds the classifier logits (the
            fourth entry of the model's output dict) -- verify if the model
            outputs change.

    Returns:
        Dict with the keys ``"acc"`` and ``"f1"``.
    """
    model_outputs, gold_labels = eval_tuple
    class_scores = model_outputs[3]
    predicted_labels = class_scores.argmax(axis=-1)

    accuracy = accuracy_score(gold_labels, predicted_labels)
    micro_f1 = f1_score(gold_labels, predicted_labels, average='micro')
    return {"acc": accuracy, "f1": micro_f1}

class FlavaForSNLI(nn.Module):
    """Pretrained FLAVA encoder topped with an MLP classification head.

    The head (Linear -> LayerNorm -> GELU -> Linear) reads the first-position
    hidden state of FLAVA's multimodal output; its width comes from the
    module-level ``args.intermediate_dim``.

    Args:
        pretrained_flava_name: Name or path given to ``FlavaModel.from_pretrained``.
        num_labels: Number of output classes.
        dropout: NOTE(review): accepted for interface compatibility but not
            applied anywhere in this class -- confirm whether dropout was
            intended in the head.
    """

    def __init__(self,  pretrained_flava_name, num_labels=3, dropout=0.5):
        super().__init__()
        self.num_labels = num_labels
        self.pretrained_flava_name = pretrained_flava_name

        self.flava_model = FlavaModel.from_pretrained(self.pretrained_flava_name)

        # Classification head. (A plain nn.Linear used to be built here and
        # immediately overwritten; that dead allocation has been removed.)
        hidden_size = self.flava_model.config.multimodal_config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, args.intermediate_dim),
            nn.LayerNorm(args.intermediate_dim),
            nn.GELU(),
            nn.Linear(args.intermediate_dim, self.num_labels),
        )

        self.criterion = nn.CrossEntropyLoss()
        self.sm = torch.nn.Softmax(dim=-1)
        self.lsm = torch.nn.LogSoftmax(dim=-1)

    def forward(
            self,
            input_ids: torch.LongTensor,
            pixel_values: torch.FloatTensor,
            attention_mask: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None):
        """Encode a text/image batch and classify it.

        Args:
            input_ids: Token ids forwarded to FLAVA.
            pixel_values: Image tensor forwarded to FLAVA.
            attention_mask: Optional text attention mask.
            labels: Optional target class ids; when given, a loss is returned.

        Returns:
            Dict with ``text_cls``, ``image_cls``, ``multimodal_cls`` (the
            first-position hidden states of the three FLAVA outputs) and
            ``logits``, plus ``loss`` when ``labels`` is provided.
        """
        model_output = self.flava_model(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        pixel_values=pixel_values,
                                        return_dict=True,
                                        )

        # Slice the multimodal first-position state once and reuse it.
        multimodal_cls = model_output['multimodal_output'].last_hidden_state[:, 0, :]
        logits = self.classifier(multimodal_cls)

        # Key order matters: compute_metrics and the test-time code in this
        # script read the logits at position 3 of the outputs.
        out = {
            "text_cls": model_output['text_output'].last_hidden_state[:, 0, :],
            "image_cls": model_output['image_output'].last_hidden_state[:, 0, :],
            "multimodal_cls": multimodal_cls,
            "logits": logits
        }
        if labels is not None:
            loss_ce = self.criterion(logits, labels)

            # Batch-mean entropy of the predicted distribution:
            # H(p) = -sum_c p_c * log p_c  (a non-negative quantity).
            probabilities = self.sm(logits)
            log_probs = self.lsm(logits)
            entropy = torch.mean(- torch.sum(probabilities * log_probs, dim=-1))

            # Confidence penalty: subtracting the weighted entropy lowers the
            # loss for less peaked (higher-entropy) predictions.
            out["loss"] = loss_ce - args.confidence_penalty_weight * entropy

        return out

def createFlavaSNLICollatorAndModel(pretrained_flava_name="facebook/flava-full"):
    """Create the data collator and the classifier for one FLAVA checkpoint.

    Args:
        pretrained_flava_name: Checkpoint name or path used for both the
            ``FlavaProcessor`` and the ``FlavaForSNLI`` model.

    Returns:
        A ``(collator, model)`` tuple; the model has been moved to the
        module-level ``device``.
    """
    flava_processor = FlavaProcessor.from_pretrained(pretrained_flava_name)

    # The collator reuses the processor's tokenizer and feature extractor.
    snli_collator = FlavaSNLICollator(
        tokenizer=flava_processor.tokenizer,
        preprocessor=flava_processor.feature_extractor,
    )

    snli_model = FlavaForSNLI(pretrained_flava_name=pretrained_flava_name)
    snli_model = snli_model.to(device)

    return snli_collator, snli_model

# Trainer configuration. Evaluation, checkpointing and logging all run once
# per epoch; the strategies could be switched to "steps" (with eval_steps /
# save_steps) for finer-grained checkpoints.
multi_args = TrainingArguments(
    # --- bookkeeping ---
    output_dir=args.output_dir,
    seed=args.random_seed,

    # --- optimisation ---
    learning_rate=args.lr,
    weight_decay=1e-2,
    lr_scheduler_type=args.scheduler_name,
    warmup_ratio=args.warmup_ratio,
    num_train_epochs=args.epochs,
    gradient_accumulation_steps=args.ga_steps,
    fp16=False,

    # --- batching / data loading ---
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    dataloader_num_workers=8,
    # Presumably required so the collator still sees the raw
    # 'hypothesis' / 'image' / 'label' columns -- verify.
    remove_unused_columns=False,

    # --- evaluation, logging and checkpoints ---
    evaluation_strategy="epoch",
    eval_accumulation_steps=500,
    logging_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=5,
    metric_for_best_model='acc',
    load_best_model_at_end=True,
)

# Build the collator/model pair for the requested checkpoint and wire them,
# together with the train/dev splits, into a Trainer.
collator, model = createFlavaSNLICollatorAndModel(args.flava_model)

multi_trainer = Trainer(
    model=model,
    args=multi_args,
    data_collator=collator,
    train_dataset=dataset_train,
    eval_dataset=dataset_dev,
    compute_metrics=compute_metrics,
)

# Fine-tune, then score the test split.
train_multi_metrics = multi_trainer.train()

prd_results = multi_trainer.predict(dataset_test)

# Position 3 of the predictions is read as the class scores, mirroring
# compute_metrics.
prd_label_ids = np.argmax(prd_results.predictions[3], axis=-1)

test_acc = accuracy_score(dataset_test['label'], prd_label_ids)

# Print the result and write the same line to the output directory.
test_report = "test accuracy = "+str(test_acc)
print(test_report)
test_report_path = os.path.join(args.output_dir,'test_eval_result.txt')
with open(test_report_path,'w') as f:
    f.write(test_report)