"""This script converts the original arXiv dataset format (nested JSON) into flat textual input.
We add tags such as <section>, <beginsection> and <endsection> in order to flatten the JSON structure."""

import os
import json

# Root directory: input split files are read from here, output dirs created here.
path = '.'
splits = ['val', 'test', 'train']
# Source split name -> output directory name in the gonito layout.
gonito_mapping_dir = {'val': 'dev-0', 'test': 'test-A', 'train': 'train'}

# Make sure every output directory exists before any split is converted
# (exist_ok makes re-runs harmless).
for _gonito_dir in gonito_mapping_dir.values():
    os.makedirs(f'{path}/{_gonito_dir}', exist_ok=True)
del _gonito_dir

def _strip_sentence_tags(sentence):
    """Remove one leading ``<S>`` and one trailing ``</S>`` marker plus surrounding spaces."""
    # str.lstrip/rstrip treat their argument as a *set of characters*, so
    # lstrip(' <S>') would also eat a leading 'S' of the first word
    # ('Simulations' -> 'imulations'); strip the exact markers instead.
    text = sentence.strip(' ')
    if text.startswith('<S>'):
        text = text[len('<S>'):]
    if text.endswith('</S>'):
        text = text[:-len('</S>')]
    return text.strip(' ')


def flattenize_json(json_line):
    """Flatten one nested arXiv record into single-line article and abstract strings.

    Each section is rendered as
    ``<section> NAME <beginsection> SENTENCES <endsection> NAME`` (sentences
    joined with single spaces), and the rendered sections are joined with
    single spaces. Abstract sentences have their surrounding ``<S>`` /
    ``</S>`` markers removed and are joined with single spaces.

    Args:
        json_line: a decoded record; reads the ``section_names``, ``sections``
            (one list of sentence strings per section name) and
            ``abstract_text`` (list of sentence strings) keys.

    Returns:
        A ``(art_text, abstract_text)`` tuple of strings.

    Raises:
        ValueError: if ``section_names`` and ``sections`` differ in length.
    """
    section_names = json_line['section_names']
    sections = json_line['sections']
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if len(section_names) != len(sections):
        raise ValueError(
            f"section_names ({len(section_names)}) and sections "
            f"({len(sections)}) have different lengths")

    art_text = ' '.join(
        f'<section> {sn} <beginsection> {" ".join(sec)} <endsection> {sn}'
        for sn, sec in zip(section_names, sections))

    abstract_text = ' '.join(
        _strip_sentence_tags(el) for el in json_line['abstract_text'])
    return art_text, abstract_text


if __name__ == '__main__':
    # For every split, read <path>/<split>.txt (one JSON record per line) and
    # write the flattened article to <gonito dir>/in.tsv and the flattened
    # abstract to <gonito dir>/expected.tsv, one record per output line.
    for s in splits:
        gs = gonito_mapping_dir[s]
        split_path = os.path.join(path, s + '.txt')
        path_in_line = os.path.join(path, gs, 'in.tsv')
        path_out_line = os.path.join(path, gs, 'expected.tsv')
        # Explicit UTF-8 so the result does not depend on the platform's
        # default locale encoding.
        with open(split_path, encoding='utf-8') as split, \
                open(path_in_line, 'w', encoding='utf-8') as in_line, \
                open(path_out_line, 'w', encoding='utf-8') as out_line:
            for i, line in enumerate(split):
                if not line.strip():
                    # json.loads rejects blank lines (e.g. a trailing one).
                    continue
                json_line = json.loads(line)
                art_text, abstract_text = flattenize_json(json_line)
                in_line.write(art_text + '\n')
                out_line.write(abstract_text + '\n')
                if i % 1000 == 0:
                    print('.')
                if i % 10000 == 0:
                    print(i)
        print(f'finished {s} split')
    print('Done!')

