import json
import logging
import os
import random as rd
from argparse import ArgumentParser

import clip
import inflect
import numpy as np
import stanza
import torch
from PIL import Image
from torch import nn, optim
from tqdm import tqdm
from transformers import pipeline

from vqaTools.vqa import VQA

parser = ArgumentParser("VQAv2 Prompt")
# Answer types to train on; the training loop skips annotations of disabled types.
parser.add_argument("--yesno", action="store_true")
parser.add_argument("--number", action="store_true")
parser.add_argument("--other", action="store_true")
# --binor: train only parameters whose names contain 'bias', 'ln_' or '.bn'.
parser.add_argument("--binor", action="store_true")
# --best: only write an epoch checkpoint when the mean epoch loss improves.
parser.add_argument("--best", action="store_true")
parser.add_argument("--clip", type=str, default="ViT-B/16")  # ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
parser.add_argument("--result", type=str, default="zero_shot")  # not read by the code below
parser.add_argument("--partition", type=int, default=64)  # not read by the code below
parser.add_argument("--portion", type=int, default=0)  # not read by the code below
parser.add_argument("--topk", type=int, default=64)  # number of candidate answers scored per question
parser.add_argument("--n_examples", type=int, default=16)  # shots per answer / per (question type, answer type)
parser.add_argument("--mount", type=str, default="/mnt")  # root prefix for datasets, prompts, vocab and checkpoints
parser.add_argument("--prompt_type", type=str, default="LM_PROMPT")  # or ANSWER_PROMPT; reads {mount}/DATA/{prompt_type}.json
parser.add_argument("--mask_token", type=str, default="<extra_id_0>")  # BERT: [MASK] T5: <extra_id_0> BART: <mask>
parser.add_argument("--answer", type=str, default="/DATA/answer_vocab.txt")  # JSON answer vocabulary, relative to --mount
parser.add_argument("--save_to", type=str, default="/DATA/MODELS/")  # checkpoint directory, relative to --mount
parser.add_argument("--way_type", type=str, default="question", choices=("question", "answer"))
parser.add_argument("--lr", type=float, default=2e-5)
parser.add_argument("--batch_size", type=int, default=8)

args = parser.parse_args()

# Number of passes over the (reshuffled) support set.
EPOCH = 30


# Keep third-party loggers quiet so only this script's prints and the tqdm bar show up.
for _quiet_logger in ("transformers", "stanza"):
	logging.getLogger(_quiet_logger).setLevel(logging.CRITICAL)
del _quiet_logger

# Fetch the English Stanza resources, then build an English pipeline and an inflect engine.
# NOTE(review): nlp_pipeline and pnoun are not referenced by the code visible here.
stanza.download('en')
nlp_pipeline = stanza.Pipeline('en')
pnoun = inflect.engine()


def top_k_candidate(txt, img, mod, ans, ans_vocab, log_scale, template=None, candidate=None, scorer=None, topk=None):
	"""Pick ``topk`` answer candidates for one example and plant the gold answer at a random slot.

	``mod`` ranks the texts ``txt`` against the image ``img`` by cosine similarity. The
	``topk`` best-ranked entries of ``ans_vocab`` are shuffled. The gold answer ``ans``
	is then placed at a uniformly random index. If ``ans`` was already among the
	candidates it is swapped there; otherwise it overwrites the entry at that index.

	Optionally a fill-mask ``scorer`` scores the ``candidate`` words inside
	``template``. This only happens when ``template``, ``candidate`` and ``scorer``
	are all given and ``template`` contains ``[MASK]``.

	Args:
		txt: tokenized texts passed to ``mod.encode_text``; row ``i`` must correspond
			to ``ans_vocab[i]``.
		img: image input passed to ``mod.encode_image``; only row 0 of the
			image-text similarity is used (presumably a batch of one image).
		mod: model exposing ``encode_image`` and ``encode_text``.
		ans: gold answer string.
		ans_vocab: indexable sequence of answer strings aligned with ``txt``.
		log_scale: multiplier applied to the cosine similarities before softmax.
			A positive value does not change the ranking.
		template, candidate, scorer: optional fill-mask scoring inputs (see above).
		topk: number of candidates to return; defaults to the global ``args.topk``.

	Returns:
		``(candidates, candidate2idx, ans_id, score)``:
		``candidates`` is a list of ``topk`` strings with ``candidates[ans_id] == ans``.
		``candidate2idx`` maps every string in ``candidates`` to an index holding it.
		``ans_id`` is the position of the gold answer.
		``score`` holds scorer scores (x1e6) aligned with ``candidate``, or ``[]``
		when no scoring ran.
	"""
	if topk is None:
		topk = args.topk

	score = []
	if template is not None and candidate is not None and scorer is not None and '[MASK]' in template:
		score_list = scorer(template, targets=candidate)
		# Fill-mask probabilities are tiny, so scale by 1e6. A candidate missing from the
		# scorer output falls back to the lowest returned score (999 if nothing came back).
		score_dict = {}
		min_score = 999
		for item in score_list:
			item_score = item['score'] * 1e6
			score_dict[item['token_str']] = item_score
			min_score = min(min_score, item_score)
		score = [score_dict.get(c, min_score) for c in candidate]

	with torch.no_grad():
		i_features = mod.encode_image(img)
		t_features = mod.encode_text(txt)
		# Out-of-place normalisation so the encoder outputs are never mutated.
		i_features = i_features / i_features.norm(dim=-1, keepdim=True)
		t_features = t_features / t_features.norm(dim=-1, keepdim=True)
		similarity = (log_scale * i_features @ t_features.T).softmax(dim=-1)
		_, indices = similarity[0].topk(topk)

	# Shuffle so the CLIP ranking does not leak into the candidate order.
	indices = indices.detach().cpu().tolist()
	rd.shuffle(indices)
	candidates = [ans_vocab[idx] for idx in indices]
	# Duplicates (e.g. padding tokens) map to their last position.
	candidate2idx = {word: i for i, word in enumerate(candidates)}
	ans_pos = candidate2idx.get(ans)

	ans_id = rd.randint(0, topk - 1)
	displaced = candidates[ans_id]
	if ans_pos is not None:
		# Gold answer was already ranked: swap it with the entry at ans_id.
		candidates[ans_pos] = displaced
		candidate2idx[displaced] = ans_pos
	else:
		# Gold answer was not ranked: it overwrites the entry at ans_id, so remap or
		# drop the displaced word.
		others = [i for i, word in enumerate(candidates) if word == displaced and i != ans_id]
		if others:
			candidate2idx[displaced] = others[-1]
		else:
			del candidate2idx[displaced]
	candidates[ans_id] = ans
	candidate2idx[ans] = ans_id
	return candidates, candidate2idx, ans_id, score


# VQA training split: the file names follow the official VQA API layout.
dataDir = f'{args.mount}/datasets/VQA_v2'
versionType = 'v2_'  # file-name prefix of the VQA v2.0 release; VQA v1.0 files use '' (no prefix)
taskType = 'OpenEnded'  # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
dataType = 'mscoco'  # 'mscoco' only for v2.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
dataSubType = 'train2014'
annFile = '%s/Annotations/%s%s_%s_annotations.json' % (dataDir, versionType, dataType, dataSubType)
quesFile = '%s/InputQuestions/%s%s_%s_%s_questions.json' % (dataDir, versionType, taskType, dataType, dataSubType)
imgDir = '%s/InputImages/%s/' % (dataDir, dataSubType)

# initialize VQA api for QA annotations
vqa = VQA(annFile, quesFile)
annIds = vqa.getQuesIds()
print(len(annIds))
anns = vqa.loadQA(annIds)

# Unseeded shuffle: the order (and therefore which annotations become shots) differs per run.
rd.shuffle(anns)

# Prompt file: one JSON object per line, keyed here by its 'qid'. The training loop later
# reads each record's 'prompts' and 'labels'. Blank lines are skipped.
vqa_prompt = {}
data_counter = 0
with open(f'{args.mount}/DATA/{args.prompt_type}.json', encoding='utf-8') as prompt_file:
	for line in prompt_file:
		line = line.strip()
		if not line:
			continue
		data = json.loads(line)
		vqa_prompt[data['qid']] = data
		data_counter += 1

# Answer vocabulary: a JSON object whose *values* are the answer strings, in file order.
with open(args.mount + args.answer, encoding='utf-8') as ans:
	ans_json = json.load(ans)
answer_vocab = list(ans_json.values())
# If an answer string repeats, its last position wins.
answer2id = {answer: i for i, answer in enumerate(answer_vocab)}

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(args.clip, device, jit=False)
# Train in fp32: clip.load may hand back fp16 weights when loading onto CUDA.
model = model.float()

# CLIP-style temperature, initialised to log(1 / 0.07). It is not among model.parameters(),
# so the optimizer below never updates it.
# NOTE(review): `scale` (= exp(logit_scale), the factor CLIP multiplies its logits by) is not
# used by the code visible here; training multiplies by `logit_scale` itself. Confirm intent.
logit_scale = nn.Parameter(torch.ones([], dtype=torch.half) * np.log(1 / 0.07)).to(device)
scale = logit_scale.exp()
clip_loss = nn.CrossEntropyLoss()
# Built over all parameters before any BiNor freezing. Adam skips parameters whose .grad
# stays None, so frozen weights are still left untouched.
optimizer = optim.Adam(model.parameters(), lr=args.lr)

print("Initialize BERT scorer...")
if torch.cuda.is_available():
	bert_scorer = pipeline('fill-mask', model='bert-large-uncased', top_k=200, device=0)
else:
	bert_scorer = pipeline('fill-mask', model='bert-large-uncased', top_k=200)

# NOTE(review): anomaly detection is a debugging aid that slows every backward pass;
# consider disabling it for real training runs.
torch.autograd.set_detect_anomaly(True)

# Make sure the checkpoint directory exists; torch.save does not create it.
save_dir = args.mount + args.save_to
os.makedirs(save_dir, exist_ok=True)

print("saving original model weights...")
torch.save({
	'epoch': 0,
	'model_state_dict': model.state_dict(),
	'loss': 999,
	'optimizer_state_dict': optimizer.state_dict(),}, f"{save_dir}/model_epoch_0.pt")

if args.binor:
	print('Enable BiNor Training')
	# Freeze every parameter except biases and those whose names contain 'ln_' or '.bn'
	# (presumably CLIP's LayerNorm / BatchNorm layers).
	for n, p in model.named_parameters():
		if not ('bias' in n or 'ln_' in n or '.bn' in n):
			p.requires_grad = False
else:
	print('Finetuning all parameters')


# Question-type strings used to cap shots per (question type, answer type) in the
# 'question' way below. Every ann['question_type'] reaching that lookup must appear here,
# otherwise qa_cover raises KeyError.
question_type = ["how many", "is the", "what", "what color is the", "what is the", "is this", "is this a", "what is",
				 "are the", "what kind of", "is there a", "what type of", "is it", "what are the", "where is the", 
				 "is there", "does the", "what color are the", "are these", "are there", "which", "is", "what is the man", 
				 "is the man", "are", "how", "does this", "what is on the", "what does the", "how many people are", 
				 "what is in the", "what is this", "do", "what are", "are they", "what time", "what sport is", 
				 "are there any", "is he", "what color is", "why", "where are the", "what color", "who is", 
				 "what animal is", "is the woman", "is this an", "do you", "how many people are in", "what room is", 
				 "has", "is this person", "what is the woman", "can you", "why is the", "is the person", 
				 "what is the color of the", "what is the person", "could", "was", "is that a", "what number is", 
				 "what is the name", "what brand", "none of the above"]

# prepare K shots for each category

support_set = []

if args.way_type == 'answer':
	# Up to n_examples shots per answer in answer_vocab; annotations whose answer is not in
	# the vocabulary are ignored. The --yesno/--number/--other flags are not applied here;
	# the training loop skips disabled answer types instead.
	a_cover = dict.fromkeys(answer_vocab, 0)
	for ann in anns:
		answer = ann['multiple_choice_answer']
		if answer in a_cover and a_cover[answer] < args.n_examples:
			support_set.append(ann)
			a_cover[answer] += 1

	print(f"way type is answer, and len(support_set)={len(support_set)}")

elif args.way_type == 'question':
	# Up to n_examples shots per (question type, answer type), for enabled answer types only.
	qa_cover = {qt + suffix: 0 for qt in question_type for suffix in ('yesno', 'number', 'other')}
	# VQA answer_type -> (qa_cover key suffix, enabled by its CLI flag).
	type_config = {
		'yes/no': ('yesno', args.yesno),
		'number': ('number', args.number),
		'other': ('other', args.other),
	}
	# Number questions are only used when the answer is one of the strings '0'..'29'.
	number_answers = {str(i) for i in range(30)}

	# NOTE(review): the first 1000 (shuffled) annotations are never used as shots;
	# presumably they are held out. Confirm.
	for ann in anns[1000:]:
		config = type_config.get(ann['answer_type'])
		if config is None:
			continue
		suffix, enabled = config
		if not enabled:
			continue
		if suffix == 'number' and ann['multiple_choice_answer'] not in number_answers:
			continue
		cover_key = ann['question_type'] + suffix
		if qa_cover[cover_key] < args.n_examples:
			qa_cover[cover_key] += 1
			support_set.append(ann)

else:
	# Fail fast instead of training on an empty support set.
	raise ValueError(f"--way_type must be 'question' or 'answer', got {args.way_type!r}")

print(f"Support set has {len(support_set)} examples.")

best_loss = float("inf")


def _answer_type_enabled(ann):
	"""Return whether this annotation's answer type was enabled via --yesno/--number/--other.

	Answer types other than 'yes/no', 'number' and 'other' are always kept.
	"""
	answer_type = ann['answer_type']
	if answer_type == 'yes/no':
		return args.yesno
	if answer_type == 'number':
		return args.number
	if answer_type == 'other':
		return args.other
	return True


def _train_on_batch(text_feats, image_feats, answer_ids):
	"""Take one optimizer step on an accumulated mini-batch and return its scalar loss.

	text_feats: per-example candidate text features from model.encode_text, one row per
		candidate (topk rows each).
	image_feats: per-example L2-normalised image features, each of shape (1, D).
	answer_ids: per-example index of the gold answer among that example's candidates.
	"""
	ground_truth = torch.tensor(answer_ids, dtype=torch.long).to(device)
	image_batch = torch.stack(image_feats).to(device)  # (B, 1, D)
	text_batch = torch.stack(text_feats).to(device)  # (B, topk, D)
	# (B, 1, D) @ (B, D, topk) -> (B, 1, topk) -> (B, topk): one logit per candidate.
	logits = (logit_scale * image_batch @ torch.transpose(text_batch, 1, 2)).squeeze(1)
	loss = clip_loss(logits, ground_truth)
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()
	return loss.item()


for epoch in range(EPOCH):
	epoch_loss, n_batch = 0.0, 0
	# Accumulators are reset every epoch. A graph built before optimizer.step() must never
	# be back-propagated after the weights it depends on have been updated.
	text_lst, image_lst, answ_lst = [], [], []
	rd.shuffle(support_set)
	for ann in tqdm(support_set):
		# Filter first so skipped examples cost no forward pass.
		if not _answer_type_enabled(ann):
			continue
		quesId = ann['question_id']
		imgId = ann['image_id']
		answer = ann['multiple_choice_answer']
		imgFilename = 'COCO_' + dataSubType + '_' + str(imgId).zfill(12) + '.jpg'
		with Image.open(imgDir + imgFilename) as pil_image:
			image_input = preprocess(pil_image).unsqueeze(0).to(device)
		image_features = model.encode_image(image_input)
		image = image_features / image_features.norm(dim=-1, keepdim=True)

		# Prompt template and LM-proposed answer labels for this question, if available.
		prompt = vqa_prompt.get(str(quesId))
		if prompt is not None:
			template = prompt['prompts'][0]
			answer_classes = [c for c in prompt['labels'] if c not in ('[unk]', 'UNK', '[UNK]')]
		else:
			template = args.mask_token
			answer_classes = ['<unk>']
		if answer not in answer_classes:
			answer_classes = answer_classes + [answer]
		# Pad so top_k_candidate always has at least topk texts to rank.
		if len(answer_classes) < args.topk:
			answer_classes = answer_classes + ['<unk>'] * (args.topk - len(answer_classes) + 1)

		txt = torch.cat(
			[clip.tokenize(template.replace(args.mask_token, str(label)), truncate=True) for label in answer_classes]
		).to(device)

		# The BERT scores (4th value) and the candidate->index map (2nd) are not used below.
		cand_lst, _cand2idx, ans_idx, _lm_score = top_k_candidate(
			txt, image_input, model, answer, answer_classes, logit_scale,
			template=template.replace(args.mask_token, '[MASK]'),
			candidate=answer_classes, scorer=bert_scorer)

		text_inputs = torch.cat(
			[clip.tokenize(template.replace(args.mask_token, str(label)), truncate=True) for label in cand_lst]
		).to(device)
		text_features = model.encode_text(text_inputs)
		# NOTE(review): the loss uses the *unnormalised* text features (unchanged behaviour)
		# together with L2-normalised image features and the raw logit_scale rather than
		# logit_scale.exp() as in CLIP. Confirm this is intended.
		text_lst.append(text_features)
		image_lst.append(image)
		answ_lst.append(ans_idx)
		if len(answ_lst) == args.batch_size:
			epoch_loss += _train_on_batch(text_lst, image_lst, answ_lst)
			n_batch += 1
			text_lst, image_lst, answ_lst = [], [], []

	# Train on the final, possibly partial, batch of this epoch.
	if answ_lst:
		epoch_loss += _train_on_batch(text_lst, image_lst, answ_lst)
		n_batch += 1

	if n_batch == 0:
		raise RuntimeError("No training examples for the enabled answer types (--yesno/--number/--other); "
						   "nothing to train on.")

	mean_loss = epoch_loss / n_batch
	if mean_loss < best_loss or not args.best:
		best_loss = mean_loss
		print(f"saving model at epoch {epoch+1} with loss {best_loss:.4f}...")
		torch.save({
			'epoch': epoch,
			'model_state_dict': model.state_dict(),
			'loss': best_loss,
			'optimizer_state_dict': optimizer.state_dict(),}, f"{args.mount+args.save_to}/model_epoch_{epoch+1}.pt")
	else:
		print(f"skip checkpoint with loss {mean_loss:.4f}, previous best is {best_loss:.4f}")