import torch
import clip
import random as rd
from sklearn.metrics import accuracy_score
from PIL import Image
from torch import nn, optim
from argparse import ArgumentParser
from lib.snliveloader import SNLIVE


# Command-line configuration for SNLI-VE training. `args` is read as a global
# by the rest of the script (model head sizes, optimiser, checkpoint path).
parser = ArgumentParser("SNLI-VE")
parser.add_argument("--bitfit", action="store_true")
parser.add_argument("--clip", type=str, default="RN50x16",
                    help="CLIP backbone passed to clip.load, e.g. RN50, RN101, "
                         "RN50x4, RN50x16, ViT-B/32, ViT-B/16")
parser.add_argument("--clip_size", type=int, default=768,
                    help="embedding width produced by the chosen CLIP backbone")
parser.add_argument("--n_examples", type=int, default=16)
parser.add_argument("--mount", type=str, default="/mnt/",
                    help="prefix string-concatenated in front of --save_to")
parser.add_argument("--save_to", type=str, default="/DATA/SNLI_MODELS",
                    help="checkpoint directory (appended to --mount)")
parser.add_argument("--identifier", type=str, default="1024_128_3",
                    help="tag embedded in checkpoint file names")
parser.add_argument("--lr", type=float, default=3e-6, help="Adam learning rate")
parser.add_argument("--dropout", type=float, default=0.0,
                    help="dropout probability in the MLP head")
parser.add_argument("--batch_size", type=int, default=128, help="training batch size")

# Hidden-layer widths of the MLP classification head.
parser.add_argument("--size1", type=int, default=1024,
                    help="width of the first MLP hidden layer")
parser.add_argument("--size2", type=int, default=128,
                    help="width of the second MLP hidden layer")

args = parser.parse_args()


class MLP(nn.Module):
    """Classification head: (Linear -> BatchNorm1d -> ReLU -> Dropout) x 2 -> Linear.

    Args:
        input_size: width of each input feature vector.
        num_class: number of output logits.
        hidden1: width of the first hidden layer; defaults to ``args.size1``.
        hidden2: width of the second hidden layer; defaults to ``args.size2``.
        dropout: dropout probability after each hidden layer; defaults to
            ``args.dropout``.

    The layers live in ``self.linear`` (an ``nn.Sequential``) exactly as before,
    so ``state_dict`` keys such as ``linear.0.weight`` are unchanged.
    """

    def __init__(self, input_size, num_class, hidden1=None, hidden2=None, dropout=None):
        super().__init__()
        # Fall back to the command-line configuration so the existing call
        # style ``MLP(input_size, num_class)`` behaves exactly as before.
        if hidden1 is None:
            hidden1 = args.size1
        if hidden2 is None:
            hidden2 = args.size2
        if dropout is None:
            dropout = args.dropout
        self.linear = nn.Sequential(
            nn.Linear(input_size, hidden1),
            nn.BatchNorm1d(hidden1),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden1, hidden2),
            nn.BatchNorm1d(hidden2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden2, num_class),
        )

    def forward(self, x):
        """Return raw class logits for a batch ``x`` (expected shape (batch, input_size))."""
        return self.linear(x)


def fusion(vector1, vector2):
    """Combine two embedding batches into a single feature tensor.

    The output is the concatenation along dim 1 of, in this order:
    ``vector1 + vector2``, ``vector1 - vector2``, ``vector1``, ``vector2`` and
    ``vector1 * vector2``. For inputs of width D the result has width 5 * D.
    Both inputs are presumably (batch, D) tensors of the same shape.
    """
    parts = [
        vector1 + vector2,
        vector1 - vector2,
        vector1,
        vector2,
        vector1 * vector2,
    ]
    return torch.cat(parts, dim=1)



snlive_dataset = SNLIVE(root_path='/home/data/datasets/SNLI-VE/data')
EPOCH = 20  # number of training epochs

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load(args.clip, device, jit=False)
# Run CLIP in fp32 so its features match the dtype of the fp32 MLP head.
clip_model = clip_model.float()

# Freeze CLIP: it is used only as a feature extractor and its parameters are
# never handed to the optimizer. The tensor flag is `requires_grad`; assigning
# a misspelled attribute would silently leave the parameters trainable.
clip_model.requires_grad_(False)

# Classification head over fusion() output, which is 5 * clip_size wide
# (sum, difference, both inputs, product); 3 output classes.
OutputLayer = MLP(args.clip_size * 5, 3).to(device)

training_loss = nn.CrossEntropyLoss()
# Only the head's parameters are optimised.
optimizer = optim.Adam([{'params': OutputLayer.parameters()}], lr=args.lr)

# Dataset splits exposed by the SNLIVE loader.
# 0: neutral, 1: entailment, 2: contradiction
snlive_train = snlive_dataset.train_set
snlive_valid = snlive_dataset.valid_set
snlive_test = snlive_dataset.test_set
snlive_caption = snlive_dataset.caption
nli_train = snlive_dataset.nli_train_set

# Per-epoch test accuracy when the head is fed (sentence1, sentence2)
# versus (image, sentence2).
text_acc_lst = []
image_acc_lst = []


@torch.no_grad()
def _encode_sentences(sentences):
    """Tokenise ``sentences`` and return L2-normalised CLIP text embeddings.

    Runs without autograd because CLIP is used as a frozen feature extractor.
    """
    token_ids = clip.tokenize(sentences, truncate=True).to(device)
    features = clip_model.encode_text(token_ids)
    return features / features.norm(dim=-1, keepdim=True)


@torch.no_grad()
def _encode_images(image_paths):
    """Load, preprocess and encode images; return L2-normalised CLIP image embeddings.

    Each file is opened in a ``with`` block so its handle is released as soon
    as the image has been preprocessed.
    """
    pixels = []
    for path in image_paths:
        with Image.open(path) as img:
            pixels.append(preprocess(img))
    features = clip_model.encode_image(torch.stack(pixels, dim=0).to(device))
    return features / features.norm(dim=-1, keepdim=True)


def _iter_batches(items, batch_size):
    """Yield consecutive slices of ``items``; the last slice may be shorter.

    Every element lands in exactly one slice -- nothing is dropped.
    """
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


for epoch in range(EPOCH):
    # ---- Training: the head sees text-only (sentence1, sentence2) pairs ----
    OutputLayer.train()
    rd.shuffle(snlive_train)
    train_ground_truth, train_predict = [], []
    for step, batch in enumerate(_iter_batches(snlive_train, args.batch_size)):
        # BatchNorm1d cannot compute batch statistics from one sample in
        # train mode, so a trailing single-item batch is skipped.
        if len(batch) < 2:
            continue
        label_batch = [item['label_id'] for item in batch]
        sent1_embedding = _encode_sentences([item['sentence1'] for item in batch])
        sent2_embedding = _encode_sentences([item['sentence2'] for item in batch])
        input_embedding = fusion(sent1_embedding, sent2_embedding)
        tmp_ground_truth = torch.tensor(label_batch, dtype=torch.long, device=device)

        # Clear gradients every step; without this they accumulate over the
        # whole run and each update uses the sum of all previous gradients.
        optimizer.zero_grad()
        logits = OutputLayer(input_embedding)
        loss = training_loss(logits, tmp_ground_truth)
        loss.backward()
        optimizer.step()

        train_ground_truth += label_batch
        train_predict += torch.argmax(logits, dim=-1).cpu().tolist()

        if step % 1000 == 0:
            print(f"Training Epoch {epoch+1} Step {step}")

    print(f"saving model at epoch {epoch+1}...")
    torch.save({
        'epoch': epoch+1,
        'model_state_dict': OutputLayer.state_dict(),
        'clip_state_dict': clip_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(), },
        f"{args.mount}{args.save_to}/snlive_model_{args.identifier}_epoch_{epoch + 1}.pt")

    # ---- Evaluation on the test split ----
    # The same head scores sentence1 + sentence2 (text) and image + sentence2.
    # NOTE(review): accuracy is tracked on snlive_test every epoch while
    # snlive_valid is not used here -- confirm this is intended.
    print('Evaluating...')
    OutputLayer.eval()
    ground_truth = []
    prediction, prediction_2 = [], []
    eval_batch_size = 64
    with torch.no_grad():
        for batch in _iter_batches(snlive_test, eval_batch_size):
            ground_truth += [item['label_id'] for item in batch]
            sent1_embedding = _encode_sentences([item['sentence1'] for item in batch])
            sent2_embedding = _encode_sentences([item['sentence2'] for item in batch])
            image_embedding = _encode_images([item['Image_path'] for item in batch])

            logits = OutputLayer(fusion(sent1_embedding, sent2_embedding))
            image_logits = OutputLayer(fusion(image_embedding, sent2_embedding))

            prediction += torch.argmax(logits, dim=-1).cpu().tolist()
            prediction_2 += torch.argmax(image_logits, dim=-1).cpu().tolist()

    text_acc = round(accuracy_score(ground_truth, prediction), 4)
    image_acc = round(accuracy_score(ground_truth, prediction_2), 4)
    text_acc_lst.append(text_acc)
    image_acc_lst.append(image_acc)
    print(f"  Epoch {epoch+1}, "
          f" Train Acc: {round(accuracy_score(train_ground_truth, train_predict),4)},"
          f" Text Test Acc: {text_acc},"
          f" Image Test Acc: {image_acc}\n")

print(f"text acc:{text_acc_lst}")
print(f"image acc:{image_acc_lst}")


