from collections import OrderedDict
from copy import deepcopy
import codecs
import json
from sacremoses import MosesDetokenizer
from conllu import parse
from lib.rule import Question, AnswerSpan
import pattern
import stanza
from argparse import ArgumentParser
from pattern.en import conjugate
from tqdm import tqdm

# Warm up pattern.en with one throwaway conjugation before the real work starts.
# NOTE(review): this appears to work around pattern's lazy lexicon loading failing
# on its very first call under Python 3.7+ (PEP 479 turns an internal StopIteration
# into RuntimeError); later calls then succeed. Confirm against the installed pattern.
# `except Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
try:
	print(conjugate('say', tense='present', person=3, number='singular'))
except Exception:
	pass

def qa2d(idx):
	"""Turn question ``idx`` and its answer span into one declarative sentence.

	Reads the module-level ``examples`` mapping. The question tokens live
	under key ``idx`` and the answer tokens under ``str(idx) + '_answer'``.
	Both token lists are deep-copied before being handed to
	``Question``/``AnswerSpan``, so nothing those classes do can alter the
	entries stored in ``examples``.

	Returns the declarative sentence detokenized by the module-level Moses
	``detokenizer``. Returns ``''`` (after printing a message) when the
	question or the answer span reports itself as not valid.
	"""
	question = Question(deepcopy(examples[idx].tokens))
	if not question.isvalid:
		print("Question {} is not valid.".format(idx))
		return ''

	answer_key = str(idx) + '_answer'
	span = AnswerSpan(deepcopy(examples[answer_key].tokens))
	if not span.isvalid:
		print("Answer span {} is not valid.".format(idx))
		return ''

	question.insert_answer_default(span)
	declarative_tokens = question.format_declr()
	return detokenizer.detokenize(declarative_tokens, return_str=True)

def print_sentence(idx):
	"""Return the detokenized surface text of sentence ``idx`` in ``examples``.

	Despite its name this does not print anything. It collects the ``'form'``
	field of every token of ``examples[idx]`` (a sentence parsed by
	``conllu.parse``) and joins them with the module-level Moses ``detokenizer``.
	"""
	forms = [token['form'] for token in examples[idx].tokens]
	return detokenizer.detokenize(forms, return_str=True)

def _to_conllu(words, form=None):
	"""Render stanza words as CoNLL-U rows joined by newlines.

	Fills ID, FORM, UPOS, XPOS, HEAD and DEPREL; every other column is '_'.
	When *form* is given it replaces each word's text in the FORM column.
	"""
	return '\n'.join(
		f"{word.id}\t{word.text if form is None else form}\t_\t{word.upos}\t{word.xpos}\t_\t{word.head}\t{word.deprel}\t_\t_"
		for word in words
	)


def _flush_batch(tgt, conllu_batch, id_batch, answer_batch, image_batch):
	"""Convert one buffered batch to declaratives, write it to *tgt*, then empty the batch.

	Each entry of *conllu_batch* holds a question sentence followed by an answer
	sentence. After parsing, the module-level ``examples`` mapping read by
	``qa2d``/``print_sentence`` is rebuilt: question ``i`` goes under key ``i``
	and its answer under ``str(i) + '_answer'``. One tab-separated row is written
	per question: id, question text, declarative (or the question text again if
	conversion failed), '1'/'0' success flag, answer, image.

	Raises:
		RuntimeError: if the parsed sentence count does not match two sentences
			per buffered question, which would otherwise misalign rows and ids.
	"""
	global examples
	if not id_batch:
		return
	sentences = parse(''.join(conllu_batch))
	if len(sentences) != 2 * len(id_batch):
		raise RuntimeError(
			"CoNLL-U batch produced {} sentences for {} questions; expected one "
			"question/answer pair per question".format(len(sentences), len(id_batch)))

	examples = OrderedDict()
	for i, (question_sent, answer_sent) in enumerate(zip(sentences[0::2], sentences[1::2])):
		examples[i] = question_sent
		examples[str(i) + '_answer'] = answer_sent

	for i, (qid, answer, image) in enumerate(zip(id_batch, answer_batch, image_batch)):
		out = qa2d(i)
		question_text = print_sentence(i)
		converted = out != ''
		if not converted:
			out = question_text  # fall back to the untouched question
		tgt.write(f"{qid}\t{question_text}\t{out}\t{'1' if converted else '0'}\t{answer}\t{image}\n")

	# Clear in place so the caller's (module-level) lists are emptied too.
	for batch in (conllu_batch, id_batch, answer_batch, image_batch):
		batch.clear()


# Input: tab-separated lines with at least five columns
# (type, question id, question, answer list, image).
parser = ArgumentParser("Stanford Parser")
parser.add_argument("--file", type=str, default="./data/vqa.txt")
parser.add_argument("--result", type=str, default="./data/vqa_q2d.txt")
parser.add_argument("--batch_size", type=int, default=128,
	help="number of questions buffered before they are converted and written")

args = parser.parse_args()
if args.batch_size < 1:
	parser.error("--batch_size must be a positive integer")
nlp = stanza.Pipeline('en')
detokenizer = MosesDetokenizer()

# Characters stripped from the answer-list column before splitting it on ', '.
ANSWER_STRIP_TABLE = str.maketrans('', '', '[]"\'')

# The answer placeholder never changes, so parse it once instead of once per question.
answer = '<mask>'
doc = nlp(answer)
parsed_answer = _to_conllu(doc.sentences[0].words, form='<mask>')

counter = 0
conllu_batch = []
id_batch = []
answer_batch = []
image_batch = []
with open(args.file, encoding='utf-8') as scr, open(args.result, 'w', encoding='utf-8') as tgt:
	for line in tqdm(scr):
		if not line.strip():
			continue  # tolerate blank lines
		fields = line.strip().split('\t')
		typ, qid, question, image = fields[0], fields[1], fields[2], fields[4]
		ans = fields[3].translate(ANSWER_STRIP_TABLE).split(', ')
		if typ != 'other':
			continue
		counter += 1
		doc = nlp(question)
		# NOTE(review): only the first sentence stanza finds is kept, so a
		# multi-sentence question is truncated.
		parsed_question = _to_conllu(doc.sentences[0].words)
		conllu_batch.append(parsed_question + '\n\n' + parsed_answer + '\n\n')
		id_batch.append(qid)
		# Most frequent answer; max() keeps the first one on ties.
		answer_batch.append(max(ans, key=ans.count))
		image_batch.append(image)

		if len(id_batch) >= args.batch_size:
			_flush_batch(tgt, conllu_batch, id_batch, answer_batch, image_batch)

	# Write the final partial batch; previously it was silently dropped.
	_flush_batch(tgt, conllu_batch, id_batch, answer_batch, image_batch)


