import os, pdb, sys
import numpy as np
import json

from tqdm import tqdm as progress_bar

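# NOTE: Download MultiWOZ 2.3 from https://github.com/lexmen318/MultiWOZ-coref and place
# data.json, ontology.json, valListFile.json and testListFile.json in the working directory.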


# Domains kept from each dialogue goal, and the output splits.
domains = ['attraction', 'hotel', 'restaurant', 'taxi', 'train']
splits = ['train', 'dev', 'test']

# Raw dialogues keyed by dialogue id, plus the "domain-slot" -> values ontology.
# Context managers close the handles instead of leaking them until GC.
with open('data.json', 'r') as data_file:
	data = json.load(data_file)
with open('ontology.json', 'r') as ont_file:
	ont = json.load(ont_file)

# Held-out dialogue ids, read one per line. strip() (rather than rstrip('\n'))
# also drops '\r' from CRLF files, which would otherwise never match a guid.
with open('valListFile.json', 'r') as valfile:
	val_list = [line.strip() for line in valfile]
with open('testListFile.json', 'r') as testfile:
	test_list = [line.strip() for line in testfile]
valid_ont = {domain: {} for domain in domains}

size = len(data)
print('data size', size)

# Collect, per tracked domain, the slots whose ontology lists more than two and
# fewer than 100 candidate values (presumably to drop binary and open-ended
# slots -- confirm intent). Result: valid_ont[domain][slot] = values.
# NOTE(review): valid_ont is not written to the output anywhere in this file.
for domain_slot, values in ont.items():
	# maxsplit=1: only the first hyphen separates domain from slot, so a slot
	# name that itself contains '-' no longer breaks the two-name unpacking.
	domain, slot = domain_slot.split('-', 1)
	if domain in domains and 2 < len(values) < 100:
		valid_ont[domain][slot] = values

# Bucket every dialogue into its split and flatten it into a compact record:
# the dialogue id, the tracked domains its goal mentions, and the ordered list
# of [speaker, utterance] pairs.
final = {split: [] for split in splits}
# Build the id sets once, outside the loop: list membership costs O(len(list))
# per dialogue, set membership is O(1).
val_ids = set(val_list)
test_ids = set(test_list)
for guid, conversation in progress_bar(data.items(), total=size):
	if guid in val_ids:
		split = 'dev'
	elif guid in test_ids:
		split = 'test'
	else:
		split = 'train'

	# A goal entry counts as a topic only for tracked domains with a non-empty goal.
	topics = [
		domain
		for domain, slot_vals in conversation['goal'].items()
		if domain in domains and len(slot_vals) > 0
	]

	# Even turn ids are customer utterances, odd turn ids are agent utterances.
	original = [
		['customer' if turn['turn_id'] % 2 == 0 else 'agent', turn['text']]
		for turn in conversation['log']
	]

	# Key order matches the previous output so the dumped JSON is unchanged.
	final[split].append({'convo_id': guid, "original": original, 'topics': topics})

# Report how many dialogues landed in each split.
for split, processed in final.items():
	print(split, len(processed))

# Write all splits to one JSON file. The leftover pdb.set_trace() breakpoint that
# halted the script here is removed, and the context manager guarantees the file
# is flushed and closed before completion is reported.
with open('mwoz_processed.json', 'w') as out_file:
	json.dump(final, out_file)
# List every split written; the old message used the stale loop variable `split`
# and so always said "test".
print(f"Finished pre processing {', '.join(final)}")


