import os
import json
import sys

# Root directory of the dataset to validate, taken from the first command-line
# argument. check_data() inspects this same argument to decide whether the
# whitespace checks are skipped for known-noisy datasets.
dataset_dir = sys.argv[1]

# Datasets whose annotations are known to contain many whitespace errors; the
# whitespace-hygiene checks are skipped when the dataset path mentions one.
_LENIENT_DATASETS = ('fb_tod_sf', 'restaurant8')


def _require(condition, message):
    """Raise AssertionError(message) unless *condition* holds.

    Used instead of the ``assert`` statement so the validation still runs
    when Python is started with ``-O``.
    """
    if not condition:
        raise AssertionError(message)


def check_data(data, *, strict=None):
    """Validate annotated samples and return how many were checked.

    Each sample is read as a dict with an ``intent`` string, a ``userInput``
    dict holding ``text``, and a ``labels`` list. Every label carries a
    ``slot`` name and a ``valueSpan`` dict with ``startIndex``, ``endIndex``
    and ``text``.

    Args:
        data: Sized iterable of sample dicts.
        strict: Whether to enforce whitespace hygiene on the utterance text
            and on slot values. When ``None`` (the default) the checks are
            enabled unless the first command-line argument contains one of
            the names in ``_LENIENT_DATASETS``; with no command-line
            argument they are enabled.

    Returns:
        ``len(data)``.

    Raises:
        AssertionError: If a sample violates a check. The message names the
            sample index and the offending value.
        KeyError: If a sample or label lacks one of the keys listed above
            (other than ``labels``, which is reported as AssertionError).
    """
    if strict is None:
        # Decide once instead of re-reading sys.argv for every sample/label.
        dataset_path = sys.argv[1] if len(sys.argv) > 1 else ''
        strict = not any(name in dataset_path for name in _LENIENT_DATASETS)

    for index, sample in enumerate(data):
        where = 'sample %d' % index
        intent = sample['intent']
        text = sample['userInput']['text']

        _require(len(intent) > 0, '%s: empty intent' % where)
        _require(intent[0] != ' ' and intent[-1] != ' ',
                 '%s: intent %r has a leading or trailing space' % (where, intent))

        if strict:
            # Check emptiness explicitly: indexing an empty string would
            # otherwise surface as an unexplained IndexError.
            _require(len(text) > 0, '%s: empty text' % where)
            _require(text[0] != ' ' and text[-1] != ' ',
                     '%s: text %r has a leading or trailing space' % (where, text))
            normalized = ' '.join(text.split())
            # A length change means a whitespace run was collapsed or
            # whitespace was stripped from either end.
            _require(len(normalized) == len(text),
                     '%s: text %r contains irregular whitespace' % (where, text))
            # Spans are compared against the normalised text, in which any
            # single non-space whitespace character has become a space.
            text = normalized

        _require('labels' in sample, "%s: missing 'labels'" % where)
        for label in sample['labels']:
            slot = label['slot']
            _require(len(slot) > 0, '%s: empty slot name' % where)
            _require(' ' not in slot,
                     '%s: slot name %r contains a space' % (where, slot))

            span = label['valueSpan']
            slot_text = span['text']
            start, end = span['startIndex'], span['endIndex']
            _require(text[start:end] == slot_text,
                     '%s: span [%r:%r] of text is %r but slot %r is annotated as %r'
                     % (where, start, end, text[start:end], slot, slot_text))
            _require(len(slot_text) > 0,
                     '%s: empty value for slot %r' % (where, slot))
            if strict:
                _require(slot_text[0] != ' ' and slot_text[-1] != ' ',
                         '%s: value %r of slot %r has a leading or trailing space'
                         % (where, slot_text, slot))
    return len(data)

# The three split files validated for every dataset, all located under
# <dataset_dir>/ripe_data/.
data_files = [
    os.path.join(dataset_dir, 'ripe_data', split + '.json')
    for split in ('train', 'valid', 'test')
]

for path in data_files:
    print('checking', path)
    # Each line of the file is parsed as one JSON sample. Decode explicitly
    # as UTF-8 so the character offsets verified by check_data() do not
    # depend on the platform's default locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        res = [json.loads(line) for line in f]
    check_data(res)
    print('passed')

print('pass')