import torch
import torch.nn as nn
import torch.nn.functional as K

import transformers

from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoProcessor, AutoFeatureExtractor, AutoModel, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer, logging

import os
import numpy as np
import pandas as pd
import json
from copy import deepcopy

import time
import math
import datasets
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from collections import OrderedDict
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score

# Turn off parallelism inside the `tokenizers` library; the training run below
# uses DataLoader worker processes (dataloader_num_workers=8).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Optional GPU pinning, kept for reference:
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]="2,3"
# os.environ["CUDA_VISIBLE_DEVICES"]="0,3"

import argparse


def _str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This converter maps the usual
    spellings explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, as it was under ``type=bool``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()

# Where checkpoints and the final results.json are written.
parser.add_argument('--output_dir',default=None,type=str)

# On-disk locations of the pre-processed train / dev / test splits.
parser.add_argument('--pp_trn_path',default=None,type=str)
parser.add_argument('--pp_dev_path',default=None,type=str)
parser.add_argument('--pp_tst_path',default=None,type=str)

# Tokenisation and optimisation hyper-parameters.
parser.add_argument('--max_seq_len',default=40,type=int)
parser.add_argument('--q_min_len',default=10,type=int)
parser.add_argument('--batch_size',default=10,type=int)
parser.add_argument('--epochs',default=5,type=int)
parser.add_argument('--lr',default=3e-5,type=float)
parser.add_argument('--random_seed',default=42,type=int)
parser.add_argument('--scheduler_name',default='linear',type=str)
# Accepts `--fp16`, `--fp16 True` and `--fp16 False`; the last one used to
# evaluate to True under `type=bool`.
parser.add_argument('--fp16',default=False,type=_str2bool,nargs='?',const=True,
                    help='true/false flag; a bare --fp16 means true')
parser.add_argument('--warmup_ratio',default=0.0,type=float)
parser.add_argument('--ga_steps',default=1,type=int)

args = parser.parse_args()

print(args)

# This script expects at least one visible CUDA GPU: the device name of
# GPU 0 is queried unconditionally below.
device = torch.device("cuda")
print(f"There are {torch.cuda.device_count()} GPU(s) available.")
print('We will use the GPU:', torch.cuda.get_device_name(0))

# Load the train / dev / test splits (in that order) from the directories
# supplied on the command line.
dataset_train, dataset_val, dataset_test = (
    datasets.load_from_disk(split_path)
    for split_path in (args.pp_trn_path, args.pp_dev_path, args.pp_tst_path)
)

@dataclass
class FlickrCapCollator:
    """Collate raw caption examples into tensors for the GIT training wrapper.

    Each example is expected to expose a ``'caption'`` string and an
    ``'image'`` object with a ``convert('RGB')`` method (PIL-style). The
    returned dict holds ``input_ids``, ``attention_mask``, ``labels`` and
    ``pixel_values``; every tensor keeps a leading batch dimension, also when
    the batch contains a single example.
    """
    # Callable tokenizer exposing cls/sep/pad token ids.
    tokenizer: AutoTokenizer
    # Image processor called as preprocessor(images, return_tensors="pt").
    preprocessor: AutoFeatureExtractor

    def process_texts(self, texts: List[str]):
        """Tokenise captions and pad them to ``args.max_seq_len``.

        Args:
            texts: caption strings, one per example.

        Returns:
            Dict of ``(batch, args.max_seq_len)`` integer tensors:
            ``input_ids`` (padded with the pad token id), ``attention_mask``
            (1 on real tokens, 0 on padding) and ``labels`` (a copy of
            ``input_ids`` with padding replaced by -100).
        """
        max_len = args.max_seq_len
        cls_id = self.tokenizer.cls_token_id
        sep_id = self.tokenizer.sep_token_id
        pad_id = self.tokenizer.pad_token_id

        # [CLS]/[SEP] are attached by hand below, so the tokenizer must not add
        # its own; otherwise every sequence carries the special tokens twice.
        # Two positions of the length budget are reserved for them.
        token_lists = self.tokenizer(text=texts,
                                     padding='do_not_pad',
                                     max_length=max_len - 2,
                                     truncation=True,
                                     add_special_tokens=False).input_ids

        input_ids = []
        label_ids = []
        attention_masks = []
        for tokens in token_lists:
            sequence = [cls_id] + tokens + [sep_id]
            n_pad = max_len - len(sequence)
            input_ids.append(sequence + [pad_id] * n_pad)
            # -100 marks padded positions in the labels.
            label_ids.append(sequence + [-100] * n_pad)
            attention_masks.append([1] * len(sequence) + [0] * n_pad)

        # No squeeze(): it would drop the batch dimension for a batch of one.
        return {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_masks),
            "labels": torch.tensor(label_ids),
        }

    def preprocess_images(self, images: List):
        """Run the image processor over a list of RGB images.

        Args:
            images: list of image objects accepted by ``self.preprocessor``.

        Returns:
            Dict with ``pixel_values`` exactly as produced by the processor
            (batch dimension preserved).
        """
        processed_images = self.preprocessor(
            images,
            return_tensors="pt",
        )
        return {
            "pixel_values": processed_images['pixel_values'],
        }

    def __call__(self, raw_batch_dict):
        """Build one model-ready batch.

        Args:
            raw_batch_dict: a list of example dicts, a single example dict, or
                a columnar dict whose ``'caption'`` / ``'image'`` values are
                lists.

        Returns:
            The merged outputs of ``process_texts`` and ``preprocess_images``.
        """
        if isinstance(raw_batch_dict, dict):
            captions = raw_batch_dict['caption']
            images = raw_batch_dict['image']
            if isinstance(captions, str):
                # Single example: wrap it so the batch code paths apply.
                captions, images = [captions], [images]
        else:
            captions = [example['caption'] for example in raw_batch_dict]
            images = [example['image'] for example in raw_batch_dict]

        rgb_images = [image.convert('RGB') for image in images]
        return {
            **self.process_texts(captions),
            **self.preprocess_images(rgb_images),
        }
    
class GITForTrainFlickrCap(nn.Module):
    """Wrapper that calls a GIT model and returns its outputs as a plain dict."""

    def __init__(self, git_model):
        """Store the wrapped model on ``self.git_model``."""
        super().__init__()
        self.git_model = git_model

    def forward(self,
                input_ids: torch.LongTensor,
                pixel_values: torch.FloatTensor,
                attention_mask: Optional[torch.LongTensor] = None,
                labels: Optional[torch.LongTensor] = None):
        """Forward all arguments to the wrapped model by keyword.

        Returns:
            ``{"logits": ...}``, plus a ``"loss"`` entry (taken from the wrapped
            model's output) when ``labels`` is not None.
        """
        git_outputs = self.git_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            labels=labels,
        )

        # Without labels only the logits are reported.
        if labels is None:
            return {"logits": git_outputs.logits}
        return {"logits": git_outputs.logits, "loss": git_outputs.loss}


# Processor (tokenizer + image processor) and causal-LM weights both come from
# the "microsoft/git-base" checkpoint.
processor = AutoProcessor.from_pretrained("microsoft/git-base")
git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")

model = GITForTrainFlickrCap(git_model=git_model)
collator = FlickrCapCollator(tokenizer=processor.tokenizer, preprocessor=processor.image_processor)

multi_args = TrainingArguments(
    output_dir=args.output_dir,
    seed=args.random_seed,
    learning_rate=args.lr,
    # No evaluation during training; a checkpoint is written after every epoch.
    # To evaluate/save by step instead, use evaluation_strategy="steps",
    # eval_steps=100, save_strategy="steps", save_steps=100.
    evaluation_strategy="no",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=500,
    save_total_limit=3,
    metric_for_best_model='acc',
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    # The collator reads the raw 'caption' and 'image' columns itself, so the
    # Trainer must not drop columns.
    remove_unused_columns=False,
    num_train_epochs=args.epochs,
    dataloader_num_workers=8,
    load_best_model_at_end=False,
    lr_scheduler_type=args.scheduler_name,
    eval_accumulation_steps=500,
    gradient_accumulation_steps=args.ga_steps,
    warmup_ratio=args.warmup_ratio,
    # Forward the --fp16 command-line flag; it was parsed but never used.
    fp16=args.fp16,
    save_safetensors=False,
)
trainer = Trainer(
    model,
    multi_args,
    train_dataset=dataset_train,
    eval_dataset=None,
    data_collator=collator,
)

train_metrics = trainer.train()

# Caption the test split in batches and dump the decoded captions to
# <output_dir>/results.json.

# Switch to inference mode so dropout is not applied while generating; nothing
# above does this after training.
model.eval()
# Send inputs to wherever the model's weights live instead of assuming .cuda().
gen_device = next(model.parameters()).device

all_answers = []

n_test = len(dataset_test)
start_time = time.time()
for i in range(0, n_test, args.batch_size):

    # Progress report every 25 batches.
    if i % (args.batch_size*25) == 0:
        print(i,"/",n_test," // time elapsed:",time.time()-start_time)

    sample = dataset_test[i:i+args.batch_size]
    images = [item.convert('RGB') for item in sample['image']]
    # The final slice may hold fewer than args.batch_size images; the prompt
    # batch must have exactly one row per image.
    n_images = len(images)

    processed_images = processor.image_processor(images=images,return_tensors="pt",)
    pixel_values = processed_images['pixel_values'].to(gen_device)

    # Each caption is generated from a lone [CLS] token as the text prompt.
    inputs = processor.tokenizer([processor.tokenizer.cls_token]*n_images, add_special_tokens=False,
                       return_tensors='pt')

    generated_ids = model.git_model.generate(pixel_values=pixel_values,
                                       input_ids=inputs.input_ids.to(gen_device),
                                       attention_mask=inputs.attention_mask.to(gen_device),
                                       max_length=50)

    all_answers += processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

with open(os.path.join(args.output_dir,"results.json"),"w") as f:
    json.dump(all_answers,f)