# Interface for accessing the SNLI-VE dataset.

import json
import os

class SNLIVE(object):
	"""In-memory interface to the SNLI-VE dataset files stored under one root directory.

	On construction the following files are read eagerly from ``root_path``:

	* ``results_20130124.token``   -> ``self.caption``
	* ``snli_ve_train.jsonl``      -> ``self.train_set``
	* ``snli_ve_dev.jsonl``        -> ``self.valid_set``
	* ``snli_ve_test.jsonl``       -> ``self.test_set``
	* ``multinli_1.0_train.jsonl`` -> ``self.nli_train_set``

	Image paths are built under ``<root_path>/Flickr30K/flickr30k_images`` but the
	images themselves are never opened by this class.

	Every example is a dict with the keys ``Image_path``, ``sentence1``,
	``sentence2``, ``captionID``, ``gold_label`` and (when the gold label is one
	of neutral / entailment / contradiction) ``label_id``.
	"""

	# Single source of truth for the gold-label -> integer class id mapping.
	_LABEL_TO_ID = {'neutral': 0, 'entailment': 1, 'contradiction': 2}

	def __init__(self, root_path=None):
		"""Resolve all dataset paths below ``root_path`` and load every split.

		Args:
			root_path: Directory that contains the dataset files listed in the
				class docstring.

		Raises:
			ValueError: If ``root_path`` is None.
		"""
		if root_path is None:
			# Carry the explanation in the exception itself so it is not lost
			# when stdout is not captured.
			raise ValueError("root_path must be defined")
		self.flickr_image_path = os.path.join(root_path, 'Flickr30K', 'flickr30k_images')
		self.caption_path = os.path.join(root_path, 'results_20130124.token')
		self.train_path = os.path.join(root_path, 'snli_ve_train.jsonl')
		self.dev_path = os.path.join(root_path, 'snli_ve_dev.jsonl')
		self.test_path = os.path.join(root_path, 'snli_ve_test.jsonl')
		self.nli_path = os.path.join(root_path, 'multinli_1.0_train.jsonl')

		self.caption = self.get_caption_data()
		self.train_set = self.get_train_data()
		self.valid_set = self.get_valid_data()
		self.test_set = self.get_test_data()
		self.nli_train_set = self.get_other_nli_data()

	def _load_snli_ve_split(self, path):
		"""Parse one SNLI-VE JSON-lines file into a list of example dicts.

		Each non-blank line must be a JSON object providing ``Flickr30K_ID``,
		``sentence1``, ``sentence2``, ``captionID`` and ``gold_label``.
		Examples whose gold label is not in ``_LABEL_TO_ID`` are kept, but
		without a ``label_id`` key.
		"""
		examples = []
		# Explicit UTF-8 so decoding does not depend on the platform locale.
		with open(path, encoding='utf-8') as split_file:
			for line in split_file:
				if not line.strip():
					# Tolerate blank lines (e.g. a trailing newline at EOF).
					continue
				record = json.loads(line)
				example = {
					'Image_path': os.path.join(self.flickr_image_path, f"{record['Flickr30K_ID']}.jpg"),
					'sentence1': record['sentence1'],
					'sentence2': record['sentence2'],
					# Keep only the part before '.jpg' so the id matches the
					# keys produced by get_caption_data().
					'captionID': record['captionID'].split('.jpg')[0],
					'gold_label': record['gold_label'],
				}
				label_id = self._LABEL_TO_ID.get(record['gold_label'])
				if label_id is not None:
					example['label_id'] = label_id
				examples.append(example)
		return examples

	def get_caption_data(self):
		"""Read the caption file and group sentences by image id.

		Each non-blank line is split on tabs; the first field (truncated at
		'.jpg') is the image id and the second field is the caption sentence.

		Returns:
			dict mapping image id (str) to the list of its caption sentences,
			in file order.
		"""
		print('Loading captions...')
		caption = {}
		with open(self.caption_path, encoding='utf-8') as cap:
			for line in cap:
				line = line.strip()
				if not line:
					# Blank lines carry no caption; skip instead of crashing.
					continue
				fields = line.split('\t')
				cid, sentence = fields[0].split('.jpg')[0], fields[1]
				caption.setdefault(cid, []).append(sentence)
		print(f"caption has {len(caption)} images.")
		return caption

	def get_train_data(self):
		"""Load the SNLI-VE training split from ``self.train_path``.

		Returns:
			list of example dicts (see ``_load_snli_ve_split``).
		"""
		print('Loading training set...')
		train_data = self._load_snli_ve_split(self.train_path)
		print(f"training set has {len(train_data)} examples.")
		return train_data

	def get_valid_data(self):
		"""Load the SNLI-VE validation split from ``self.dev_path``.

		Returns:
			list of example dicts (see ``_load_snli_ve_split``).
		"""
		print('Loading validation set...')
		valid_data = self._load_snli_ve_split(self.dev_path)
		print(f"validation set has {len(valid_data)} examples.")
		return valid_data

	def get_test_data(self):
		"""Load the SNLI-VE test split from ``self.test_path``.

		Returns:
			list of example dicts (see ``_load_snli_ve_split``).
		"""
		print('Loading testing set...')
		test_data = self._load_snli_ve_split(self.test_path)
		print(f"testing set has {len(test_data)} examples.")
		return test_data

	def get_other_nli_data(self):
		"""Load the text-only NLI training file from ``self.nli_path``.

		Records whose gold label is '-' are skipped. Because these examples
		have no image, ``Image_path`` points at a placeholder ``NULL.jpg``
		inside the image directory and ``captionID`` is the constant
		'00000000'.

		Returns:
			list of example dicts with the same keys as the SNLI-VE splits.

		Raises:
			ValueError: If a record has a gold label other than '-', neutral,
				entailment or contradiction.
		"""
		print('Loading NLI training set...')
		nli_data = []
		null_image_path = os.path.join(self.flickr_image_path, 'NULL.jpg')
		with open(self.nli_path, encoding='utf-8') as training:
			for line in training:
				if not line.strip():
					continue
				record = json.loads(line)
				gold_label = record['gold_label']
				if gold_label == '-':
					continue
				if gold_label not in self._LABEL_TO_ID:
					raise ValueError(f"unexpected gold_label: {gold_label!r}")
				nli_data.append({
					'Image_path': null_image_path,
					'sentence1': record['sentence1'],
					'sentence2': record['sentence2'],
					'captionID': '00000000',
					'gold_label': gold_label,
					'label_id': self._LABEL_TO_ID[gold_label],
				})
		print(f"NLI training set has {len(nli_data)} examples.")
		return nli_data
