import re
import sys
import torch
import clip
import json
import os
import numpy as np
import json

from PIL import Image
from argparse import ArgumentParser
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration
from tqdm import tqdm
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration



# Command-line interface. Integer options use real int defaults (argparse
# would convert the former string defaults to the same values) and every
# option carries help text so that `--help` is self-explanatory.
parser = ArgumentParser("VQAv2 Prompt")
parser.add_argument("--statement", type=str, default="vqa_statement.txt",
                    help="name of the tab-separated statement file under <mount>/DATA/")
parser.add_argument("--result", type=str, default="GT_LM_PROMPT.json",
                    help="output file name suffix; results go to <mount>/tmp_test/<portion>_<result>")
parser.add_argument("--mount", type=str, default="/mnt",
                    help="root directory that contains DATA/, datasets/ and tmp_test/")
parser.add_argument("--partition", type=int, default=64,
                    help="number of slices the statement list is divided into")
parser.add_argument("--portion", type=int, default=0,
                    help="zero-based index of the slice processed by this run")
parser.add_argument("--mask_token", type=str, default="<extra_id_0>",
                    help="token substituted for '<mask>' in statements "
                         "(BERT: [MASK], T5: <extra_id_0>, BART: <mask>)")
parser.add_argument("--T5", action="store_true",
                    help="run the T5-based prompt and label generation")
# NOTE(review): no code path in this script reads args.BERT; kept so existing
# command lines that pass the flag keep working.
parser.add_argument("--BERT", action="store_true",
                    help="accepted for compatibility; not used by the T5 pipeline")
args = parser.parse_args()


def process_template(templt, mask, question):
    """Turn a declarative statement template into a fill-in prompt.

    The template is lower-cased, auxiliary "does"/"do" are dropped, a few
    known artefacts ("ising", "areing", "looking in", "is there",
    "are there") are rewritten, and the ``<mask>`` placeholder is replaced
    by ``mask``. A trailing period is detached from the last word, and
    "The " is prepended when the prompt starts with the mask.

    Args:
        templt: Statement template, normally containing ``<mask>``.
        mask: Mask token of the language model (e.g. ``<extra_id_0>``).
        question: Original question; only inspected for "looking ..." forms.

    Returns:
        A one-element list holding the processed prompt.

    An empty template, or one without the mask, is returned after the same
    normalisation instead of raising IndexError / ValueError.
    """
    templt = templt.lower()
    templt = templt.replace(' does ', ' ').replace(' do ', ' ')
    templt = templt.replace('ising', 'is doing').replace('areing', 'are doing')

    # Align the preposition after "looking" with the one used in the question.
    # Each test already implies that 'looking' occurs in the question.
    if 'looking at' in question:
        templt = templt.replace('looking in ', 'looking at ')
    elif 'looking into ' in question:
        templt = templt.replace('looking in ', 'looking into ')
    elif 'looking?' in question:
        templt = templt.replace('looking in ', 'looking ')

    # Question word order -> statement word order.
    templt = templt.replace('is there ', 'there is ')
    templt = templt.replace('are there ', 'there are ')
    templt = templt.replace('<mask>', mask)

    # Separate the final period from the preceding word ("x." -> "x .").
    if templt.endswith('.') and not templt.endswith(' .'):
        templt = templt[:-1] + ' .'
    # startswith() is the non-raising equivalent of `index(mask) == 0`.
    if templt.startswith(mask):
        templt = 'The ' + templt
    return [templt]

# Punctuation and function words that are never useful as answer labels.
_IGNORED_LABELS = frozenset([
    '.', '?', ',', '!', 'it', 'i', 'the', '', '\\', '-', '|', '...', 'a', 'an',
    'this', 'that', ':', ' :', 'not', 'in',
])


def filter_labels(labs):
    """Return the usable labels from ``labs``, lower-cased and de-duplicated.

    A label is dropped when it is a stop word / punctuation mark, has at most
    one character, or consists only of digits. Input order is preserved and
    only the first occurrence of each label is kept.

    All checks are performed on the lower-cased form, so "The" is filtered
    like "the" and "Red"/"red" yield a single entry.

    Args:
        labs: Iterable of candidate label strings.

    Returns:
        List of lower-cased labels in first-seen order.
    """
    kept = []
    seen = set()  # O(1) duplicate check instead of scanning `kept`
    for label in labs:
        lowered = label.lower()
        if len(lowered) <= 1 or lowered.isdigit():
            continue
        if lowered in _IGNORED_LABELS or lowered in seen:
            continue
        seen.add(lowered)
        kept.append(lowered)
    return kept

class CandidateDataset(Dataset):
    """Pairs one masked sentence with every candidate answer.

    Item ``idx`` is ``{"input_ids": <encoded sentence>, "labels": <encoded
    candidate idx>}``. Candidate encodings are padded to a common length and
    padding positions are set to -100 in the labels.

    Args:
        sentence: Input sentence containing the mask token.
        candidates: Sequence of candidate answer strings.
        tokenizer: Callable tokenizer exposing ``pad_token_id``; it is called
            with ``return_tensors='pt'`` for the sentence and
            ``padding=True`` for the candidate list.
        template: Format string with one ``{}`` slot that wraps a candidate
            into the target sequence.
    """

    def __init__(self,
                 sentence,
                 candidates,
                 tokenizer,
                 template="<extra_id_0> {} <extra_id_1>"):
        self.sentence = sentence
        self.candidates = candidates
        self.tokenizer = tokenizer
        self.template = template
        self.build_inputs()

    def _encode_sentence(self):
        """Tokenize ``self.sentence`` and drop only the leading batch axis."""
        encoded = self.tokenizer(self.sentence, return_tensors='pt').input_ids
        # squeeze(0) rather than squeeze(): a one-token sentence must stay 1-D.
        return encoded.squeeze(0)

    def _encode_candidates(self, texts):
        """Tokenize ``texts`` with padding and build the matching label tensor."""
        self.all_candidate_ids = self.tokenizer(texts, padding=True).input_ids
        pad_id = self.tokenizer.pad_token_id
        # -100 marks padding positions in the labels.
        self.all_candidate_labels = torch.tensor([
            [(tok if tok != pad_id else -100) for tok in ids]
            for ids in self.all_candidate_ids
        ])

    def build_inputs(self):
        """For T5: encode the sentence and the template-wrapped candidates."""
        self.input_ids = self._encode_sentence()
        # Use the instance's own candidates, not a module-level global.
        self._encode_candidates(
            [self.template.format(cand) for cand in self.candidates])

    def update_sentence(self, sentence):
        """Replace the input sentence and re-encode it."""
        self.sentence = sentence
        self.input_ids = self._encode_sentence()

    def update_candidate(self):
        """Just for BART: targets are the full sentence with ``<mask>`` filled in."""
        template = self.sentence.replace("<mask>", "{}")
        self._encode_candidates(
            [template.format(cand) for cand in self.candidates])

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids,
            "labels": self.all_candidate_labels[idx]
        }

    def __len__(self):
        return len(self.candidates)


def read_candidates(filename):
    """Load the answer vocabulary from a UTF-8 JSON object file.

    The file must contain a JSON object; its values, in file order, are
    returned as a list (the keys are ignored).
    """
    with open(filename, "r", encoding="utf-8") as handle:
        vocab = json.load(handle)
    return list(vocab.values())


def compute_logits(cand_loader, model, device="cuda"):
    """Score every candidate in ``cand_loader`` with ``model``.

    Despite the name, the returned numbers are losses: for each row the
    token-level cross-entropy between the model's logits and the row's
    ``labels`` is averaged over the full (padded) label length, with
    positions labelled -100 contributing zero. Lower means a better fit.

    Args:
        cand_loader: Iterable of batches, each a dict of tensors that
            includes a 2-D ``"labels"`` entry; the dict is passed to the
            model as keyword arguments.
        model: Callable returning an object with a ``logits`` attribute.
        device: Device every batch tensor is moved to.

    Returns:
        Flat list of floats, one per candidate, in loader order.
    """
    criterion = CrossEntropyLoss(ignore_index=-100, reduction="none")
    scores = []
    with torch.no_grad():
        for batch in cand_loader:
            # Move the batch to the target device, updating the dict in place.
            for name in list(batch):
                batch[name] = batch[name].to(device)
            targets = batch["labels"]
            vocab_logits = model(**batch).logits
            token_loss = criterion(
                vocab_logits.view(-1, vocab_logits.size(-1)),
                targets.view(-1),
            )
            # Back to (batch, seq) so the loss can be averaged per candidate.
            n_rows, n_cols = targets.shape
            per_candidate = token_loss.view(n_rows, n_cols).mean(-1)
            scores += per_candidate.detach().cpu().tolist()
    return scores

def LM_infilling(sentence, candidates, tokenizer, device, LM_model, topK=20,
                 batch_size=128):
    """Rank ``candidates`` by how well they fill the mask in ``sentence``.

    Every candidate is scored with ``compute_logits`` (a loss, lower is
    better) and the candidates are returned in ascending-loss order.

    Args:
        sentence: Input sentence containing the mask token.
        candidates: Sequence of candidate answer strings.
        tokenizer: Tokenizer handed to ``CandidateDataset``.
        device: Device the batches are moved to for scoring.
        LM_model: Language model used for scoring.
        topK: Number of best candidates to return.
        batch_size: Number of candidates scored per forward pass.

    Returns:
        ``(top_candidates, top_loss)``: two parallel lists holding at most
        ``topK`` candidates and their losses, best first. Both are empty
        when ``candidates`` is empty.
    """
    if len(candidates) == 0:
        # Nothing to rank; skip tokenization and the model entirely.
        return [], []

    # The constructor already encodes `sentence`, so no extra
    # update_sentence() call is needed.
    cand_dataset = CandidateDataset(sentence, candidates, tokenizer)
    cand_loader = DataLoader(cand_dataset, batch_size=batch_size, shuffle=False)
    losses = compute_logits(cand_loader, LM_model, device=device)

    # sorted() is stable, so equally scored candidates keep their input order.
    ranked = sorted(zip(candidates, losses), key=lambda item: item[1])[:topK]
    top_candidates = [cand for cand, _ in ranked]
    top_loss = [loss for _, loss in ranked]
    return top_candidates, top_loss


print('Reading statement...')
# One dict per statement line; columns are tab-separated:
# qid, question, statement, flag, answer, image.
data_lst = []
statement_path = args.mount + '/DATA/' + args.statement
with open(statement_path, encoding='utf-8') as src:
    for line_no, line in enumerate(src, start=1):
        line = line.strip().split('\t')
        if line == ['']:
            continue  # blank line
        if len(line) >= 6:
            qid, question, statement, flag, answer, image = line[:6]
        elif len(line) == 5:
            # NOTE(review): a five-column line is read as "answer column
            # missing", so the image is the fifth field. The previous code
            # indexed line[5] here, which always raised IndexError. Confirm
            # against the statement file format.
            qid, question, statement, flag, image = line
            answer = ''
        else:
            raise ValueError(
                f"{statement_path}:{line_no}: expected at least 5 tab-separated "
                f"fields, got {len(line)}")
        data_lst.append({'qid':qid, 'question':question, 'statement':statement, 'flag':flag, 'answer':answer, 'image':image})

print('Reading candidates...')
# Despite the .txt suffix the vocabulary is parsed as JSON by read_candidates.
candidate_path = args.mount + '/DATA/' + 'answer_vocab.txt'
answer_vocab = read_candidates(candidate_path)


# CLIP runs on the GPU when one is visible and on the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP model together with its image preprocessing transform.
# Checkpoints listed by the original author:
# ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
clip_model, preprocess = clip.load('ViT-B/16', device)

# Number of ranked candidates requested from each LM_infilling call.
top_K = 200

if args.T5:
    # Writes one JSON object per question:
    # {'qid', 'question', 'prompts', 'labels', 'answer', 'image'}
    print('Loading T5 model ...')
    T5_PATH = 't5-large'  # "t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)
    t5_config = T5Config.from_pretrained(T5_PATH)
    t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config).to(DEVICE)

    tgt_path = args.mount+'/tmp_test/'+str(args.portion)+'_'+args.result

    # NOTE: the former get_colors() / get_others() calls were removed. Neither
    # function is defined or imported in this file (NameError) and their
    # results were never used.

    # The CLIP text side depends only on the answer vocabulary, so it is
    # encoded and normalised once instead of once per image.
    classes = [f'a photo containing {c} .' for c in answer_vocab]
    with torch.no_grad():
        text_features = clip_model.encode_text(clip.tokenize(classes).to(device))
    text_features /= text_features.norm(dim=-1, keepdim=True)
    # topk() raises when k exceeds the number of classes, so clamp it.
    clip_top_k = min(200, len(classes))

    with open(tgt_path, 'w', encoding='utf-8') as tgt:
        print(f"Processing portion {args.portion} with partition {args.partition}.")
        # Equal-sized slices; the last portion also takes the remainder.
        idx = len(data_lst) // args.partition
        if args.portion == args.partition - 1:
            anns = data_lst[idx * args.portion:]
        else:
            anns = data_lst[idx * args.portion: idx * (args.portion + 1)]

        for data in tqdm(anns):
            qid, question = data['qid'], data['question']
            template = data['statement'].replace('<mask>', args.mask_token)
            answer = data['answer']
            image = data['image']

            # flag '0' and every "how ..." question use the generic
            # "It is <mask> ." prompt; flag '1' uses the statement template.
            open_ended = data['flag'] == '0' or 'how ' in question.lower()
            if not open_ended and data['flag'] != '1':
                raise ValueError(f"unexpected flag {data['flag']!r} for question {qid}")

            # Knowledge from vision: CLIP shortlists the answers that best
            # match the image and T5 ranks that shortlist. For open-ended
            # questions the vision labels have always been left empty, so
            # CLIP and the extra T5 pass are skipped there.
            label_vision = []
            if not open_ended:
                image_path = f"{args.mount}/datasets/VQA_v2/InputImages/val2014/{image}"
                with Image.open(image_path) as img:  # close the file handle promptly
                    image_input = preprocess(img).unsqueeze(0).to(device)
                with torch.no_grad():
                    image_features = clip_model.encode_image(image_input)
                image_features /= image_features.norm(dim=-1, keepdim=True)
                similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
                _, indices = similarity[0].topk(clip_top_k)

                # Keep the module-level name `candidates`; it must match the
                # list passed to LM_infilling.
                candidates = [answer_vocab[i] for i in indices.tolist()]
                template = process_template(template, args.mask_token, question)[0]
                label_vision, _ = LM_infilling(f"{question} {template}", candidates,
                                               t5_tokenizer, DEVICE, t5_mlm, topK=top_K)
                label_vision = filter_labels(label_vision)

            # Knowledge from language: T5 ranks the whole answer vocabulary.
            candidates = answer_vocab
            if open_ended:
                candidate_lst, _ = LM_infilling(question + f' It is {args.mask_token} .', candidates,
                                                t5_tokenizer, DEVICE, t5_mlm, topK=top_K)
                if 'or' in question.split(' '):
                    # "A or B" question: keep only the three best, unfiltered.
                    label_language = candidate_lst[:3]
                    templates = [args.mask_token]
                else:
                    templates = [f'It is {args.mask_token} .']
                    label_language = filter_labels(candidate_lst)
            else:
                # NOTE(review): the template was already processed in the
                # vision stage. This second pass is kept so output stays
                # byte-identical (it lower-cases a leading "The").
                template = process_template(template, args.mask_token, question)[0]
                label_language, _ = LM_infilling(f"{question} {template}", candidates,
                                                 t5_tokenizer, DEVICE, t5_mlm, topK=top_K)
                label_language = filter_labels(label_language)
                templates = [template]

            # Union of both sources, sorted so the output is reproducible
            # (set iteration order varies between runs).
            labels = sorted(set(label_language) | set(label_vision))

            tmp = {'qid': qid, 'question': question, 'prompts': templates, 'labels': labels, 'answer': answer,
                   'image': image}

            print(f"{question} {template}")
            print(f"{tmp}  {answer in labels}")
            tgt.write(json.dumps(tmp) + '\n')

    print(f"Data saved at {tgt_path}")

