"""
Disclaimer: this process is not optimized at all, since it only had to run once.
Final output: the dictionaries required to load a dataset (serialized with pickle to save time in future runs).
---- 
Script that transforms the CoNLL chain as presented in AMALGUM into the one required for this task.
Two-step process:
    1: transformation to our annotation
    2: resolution of nested entities
        approach 1: only use the outer entities (as in the previous publication)
        approach 2: create document copies, each making use of one specific
            entity, based on the entity mentions and occurrences.
"""

import os
import numpy as np

# Directory holding the raw AMALGUM sources, relative to the working directory.
ROOT_DIR_NAME = 'sources/amalgum'

# Pipeline switches: `balanced` selects the '_balanced' directory postfix in
# main(); `strategy` is the un-nesting strategy handed to conllu_to_et().
balanced = False
strategy = 'meancols'
strategy_depth = 3

# Numeric key per document genre; conllu_to_et() prefixes entity ids with it.
DOC_KEYS = dict(
    reddit=1,
    academic=2,
    bio=3,
    fiction=4,
    interview=5,
    news=6,
    whow=7,
    voyage=8,
)

# Parameters of the train/dev split performed by random_split().
percent = 0.10
seed = 1337
# source data class
class ProtoGum():
    """One row of the source annotation, stored field by field as given.

    Attributes mirror the constructor arguments: ``sent``, ``chars``, ``tok``,
    ``ent`` (entities with ids), ``ent_prop`` (entity properties such as
    new/giv/acc), ``c_type`` and ``c_link``.
    """

    def __init__(self, sent, chars, tok, ent, ent_prop, c_type, c_link):
        self.sent, self.chars, self.tok = sent, chars, tok
        # entities with ids, and their properties (new/giv/acc)
        self.ent, self.ent_prop = ent, ent_prop
        self.c_type, self.c_link = c_type, c_link

    def __str__(self):
        # All seven fields, space separated, in constructor order.
        return (
            f"{self.sent} {self.chars} {self.tok} {self.ent} "
            f"{self.ent_prop} {self.c_type} {self.c_link}"
        )

# target data class
class ProtoET():
    """A token paired with the entities attached to it (target format)."""

    def __init__(self, tok, ents):
        self.tok, self.ents = tok, ents

    def __str__(self):
        # Token followed by its entities, space separated.
        return f"{self.tok} {self.ents}"

def entity_generalization(directory,second_dir):
    """Rewrite every file in ``directory/second_dir`` with global entity ids.

    First pass: the last tab-separated field of every line is registered in a
    single mapping (``'0'`` is pre-assigned to 0, every unseen value receives
    the next free integer). Second pass: each file is overwritten in place,
    keeping only lines that have exactly two fields and a non-empty token,
    with the entity replaced by its mapped integer.

    Returns None.
    """
    folder = os.path.join(directory, second_dir)
    file_names = os.listdir(folder)

    # entity string -> global integer id; the next free id is always len(mapping)
    entity_ids = {'0': 0}

    for name in file_names:
        with open(os.path.join(folder, name), 'r') as src:
            for raw in src.readlines():
                label = raw.strip('\n').split('\t')[-1]
                if label not in entity_ids:
                    entity_ids[label] = len(entity_ids)

    print('\t {} total entities identified in {}'.format(len(entity_ids), folder))
    print('Total entities: {}'.format(len(entity_ids)))

    for name in file_names:
        target = os.path.join(folder, name)
        with open(target) as src:
            raw_lines = src.readlines()
        with open(target, 'w') as dst:
            for raw in raw_lines:
                fields = raw.strip('\n').split('\t')
                # lines that are not exactly "token<TAB>entity" are dropped
                if len(fields) != 2:
                    continue
                tok, ent = fields
                if tok == '':
                    continue
                dst.write('{}\t{}\n'.format(tok, str(entity_ids[ent])))
    return None

def anox_to_conllu(input_path, output_path):
    """Convert every file under ``<input_path>/tsv`` to a ``*_protoet.conllu`` file.

    Each input line is split on tabs; lines with more than one field are read
    as (sent, chars, tok, ent, ent_prop, c_type, c_link) and every other line
    becomes an empty placeholder row. For each token row the entities in the
    ``ent`` column are rewritten as ``type-id`` and, depending on the matching
    ``ent_prop`` value (new / giv / acc), resolved against the coreference
    chains collected so far in ``links``.

    Output: one file per input file in ``output_path`` (created if missing),
    with lines ``<row index>\\t<token>\\t<entities joined by '|'>``.

    Returns None.
    """
    print('Working on directory: {}'.format(output_path))
    if not os.path.exists(output_path):
        print('\t Directory missing. Creating...')
        os.mkdir(output_path)

    # the source files are read from the 'tsv' sub-directory of input_path
    input_path = os.path.join(input_path,'tsv')
    for file in os.listdir(input_path):
        print('Processing file: {}'.format(file))
        outname = file.split('.')[0]+'_protoet.conllu'

        with open(os.path.join(input_path,file)) as f:
            data_raw = f.readlines()

        # parse raw lines into ProtoGum rows
        data = []
        for line in data_raw:
            m = line.strip('\n').split('\t')
            if len(m) > 1:
                # NOTE(review): a line with 2-6 fields raises IndexError here.
                point = ProtoGum(m[0], m[1], m[2], m[3], m[4], m[5], m[6])
            else:
                # lines without a tab (blank lines, presumably header/comment
                # lines too) become empty placeholder rows
                point = ProtoGum('', '', '', '', '', '', '')
            data.append(point)

        new_data = []  # converted rows (ProtoET), in input order
        links = []  # coreference chains: [first mention, [all mentions of the chain]]

        for current in data:
            if current.tok != '':
                new_entities = []  # resolved entities for this token

                entity_types = current.ent_prop.split("|")
                entity_content = current.ent.split("|")
                assert len(entity_content) == len(entity_types)  # they have to describe the same entities.

                # NOTE(review): str.split always returns at least one element,
                # so this branch looks unreachable; if it ever ran, the token
                # would be appended to new_data twice (here and below).
                if not entity_types:
                    new_data.append(ProtoET(current.tok, '-'))
                    # print('Not part of any entity')

                # extract info from the link if any: for every link that has a
                # bracketed part, the text inside the last brackets is split at
                # '_' into (target, source)
                coref_targets, coref_sources = [], []
                if current.c_link != '_':
                    coref_link = current.c_link.split("|")
                    for link in coref_link:
                        if '[' in link:
                            link = link.split('[')[-1].strip(']')
                            target, source = link.split('_')
                            coref_targets.append(target)
                            coref_sources.append(source)

                for e in range(0, len(entity_types)):
                    # transform entity from "type[id]" to "type-id"
                    entity = entity_content[e]
                    if '[' not in entity:
                        # handling an AMALGUM error where entities with id 0 have no id.
                        # example: organization[0] is identified as just organization
                        # example from: amalgum_news_doc39.tsv
                        e_type = entity
                        e_id = '0'
                        entity = e_type+ '-' + e_id
                    else:
                        e_type, e_id = entity.split('[')[0], entity.split('[')[-1].strip(']')
                        entity = e_type + '-' + e_id

                    # if new  (singleton - or token in an entity sequence in a previously added mapping )
                    if "new" in entity_types[e]:
                        new_entities.append(e_type + '-' + e_id)
                        # print(' Entity is new - adding to entities')

                    if "giv" in entity_types[e]:  # if giv
                        # try to find the mapping in the mappings (links)
                        resolved = False
                        # NOTE(review): there is no break, so every chain that
                        # contains this mention appends its first mention, and
                        # or_e_id ends up holding the id of the last match.
                        for i in range(0, len(links)):
                            # check the mappings from the list
                            if entity in links[i][-1]:
                                # instead of reconstructing, take the original from the mappings.
                                new_entities.append(links[i][0])
                                # change e_id to the current for the next step
                                or_e_id = links[i][0].split('-')[-1]
                                # print(' Entity was given - resolved to: ' + e_type + '-' + or_e_id)
                                resolved = True
                        if not resolved:  # first time link introduction
                            # print(' Given entity failed to be resolved - adding:' + e_type + '-' + e_id)
                            links.append([e_type + '-' + e_id, [e_type + '-' + e_id]])  # start a new chain headed by this mention
                            new_entities.append(e_type + '-' + e_id)
                            or_e_id = e_id  # since it failed, original is the current
                            resolved = True

                    if "acc" in entity_types[e]:
                        new_entities.append(e_type + '-' + e_id)

                    # this mention is the source of a coreference link:
                    # record the id of the next mention in its chain
                    if e_id in coref_sources:
                        if "new" in entity_types[e] or "acc" in entity_types[e]:  # first time mapping:
                            next_mention = e_type + '-' + coref_targets[coref_sources.index(e_id)]
                            links.append([entity, [entity, next_mention]])
                        if "giv" in entity_types[e]:
                            next_mention = e_type + '-' + coref_targets[coref_sources.index(e_id)]
                            # or_e_id was set by the "giv" branch above
                            original_mention = e_type + '-' + or_e_id
                            # list.index raises ValueError when no chain starts with original_mention
                            idx = [keys[0] for keys in links].index(original_mention)
                            links[idx][-1].append(next_mention)

                # print('Initializing ProtoET with: ')
                new_data.append(ProtoET(current.tok, new_entities))


            # placeholder row (empty token): keep an empty entry in the output
            else:
                new_data.append(ProtoET('', ''))

        # new_data holds the converted document; write one row per line
        with open(os.path.join(output_path,outname), 'w') as outf:
            for i, dp in enumerate(new_data):
                outf.write('{}\t{}\t{}\n'.format(str(i),dp.tok,'|'.join(dp.ents)))
    return None

def conllu_to_et(input_path,output_path,strategy):
    """Un-nest the entities of every ``*.conllu`` file into flat ``.et`` files.

    Every input file is expected to be named ``<prefix>_<genre>_doc<N>_<rest>``
    and to contain lines of the form ``<index>\\t<token>\\t<ent1|ent2|...>``.
    For each file, one output file per entity column is written to
    ``<output_path>/<strategy>/<name><strategy>_<column>.et`` with lines
    ``<token>\\t<entity id>``, where the entity id is ``'0'`` for "no entity"
    and otherwise the document id followed by a per-document entity number.

    Args:
        input_path: directory holding the files to convert.
        output_path: root output directory (created if missing); a
            sub-directory named after the strategy is created inside it.
        strategy: ``'maxcol'`` keeps as many columns as the most nested token
            has entities; ``'meancols'`` keeps ``round(mean + std)`` of the
            per-token entity counts.

    Returns:
        None. An unknown strategy only prints a message.

    Notes:
        * Runs of blank rows are collapsed to a single blank row and leading
          blank rows are dropped. (The previous implementation popped from the
          list while iterating it and therefore skipped elements.)
        * Files without any token are skipped instead of raising IndexError.
        * Per-document entity numbers are assigned in sorted order of the
          entity names, so the output is reproducible between runs.
        * NOTE(review): ids are built by plain concatenation of genre key,
          document number and entity number, which is ambiguous (e.g. doc 1 +
          entity 23 and doc 12 + entity 3 both give "...123"). Left unchanged
          so the id format stays the same; consider adding a separator.
    """
    if not os.path.exists(output_path):
        print('\t Root output directory missing. Creating...')
        os.mkdir(output_path)

    implemented_strategies = ['maxcol','meancols']
    if strategy not in implemented_strategies:
        print('Selected strategy not yet implemented.')
        return None

    output_path = os.path.join(output_path, strategy)
    print('Working on directory: {}'.format(output_path))
    if not os.path.exists(output_path):
        print('\t Directory missing. Creating...')
        os.mkdir(output_path)

    for file in sorted(os.listdir(input_path)):
        # unique document id: genre key followed by the document number
        _, doc_type, doc_id, _ = file.split('_')
        document_id = str(DOC_KEYS[doc_type]) + doc_id.strip('doc')

        rows = _collapse_blank_rows(_read_token_rows(os.path.join(input_path, file)))
        if not rows:
            print('\t Skipping {}: no tokens found.'.format(file))
            continue

        print('Processing file: {}'.format(file))
        outname = file.split('.')[0] + strategy

        # number of entity columns (= number of document copies) to produce
        lengths = [len(ents) for _, ents in rows]
        if strategy == 'maxcol':
            n_cols = max(lengths)
        else:  # 'meancols'
            n_cols = int(np.round(np.mean(lengths) + np.std(lengths)))

        tokens = [tok for tok, _ in rows]
        columns = _spread_entities(rows, n_cols, document_id)

        for col_idx, column in enumerate(columns):
            out_file = os.path.join(output_path, outname + '_' + str(col_idx) + '.et')
            with open(out_file, 'w') as outf:
                outf.writelines(
                    '{}\t{}\n'.format(tok, ent) for tok, ent in zip(tokens, column)
                )

        if strategy == 'maxcol':
            print('\t {} document instances created.'.format(len(columns)))

    return None


def _read_token_rows(filepath):
    """Return ``(token, [entity, ...])`` pairs for every 3-field line of a file."""
    rows = []
    with open(filepath) as f:
        for line in f:
            fields = line.split('\t')
            if len(fields) > 2:
                # the last field still carries the line break
                rows.append((fields[1], [ent.strip('\n') for ent in fields[2].split('|')]))
    return rows


def _collapse_blank_rows(rows):
    """Drop leading blank-token rows and keep one row per run of blank rows."""
    cleaned = []
    for tok, ents in rows:
        if tok == '' and (not cleaned or cleaned[-1][0] == ''):
            continue
        cleaned.append((tok, ents))
    return cleaned


def _spread_entities(rows, n_cols, document_id):
    """Return ``n_cols`` lists of entity ids, one id per row in each list."""
    # per-document numbering of the entity names, in sorted order
    names = sorted({ent for _, ents in rows for ent in ents if ent != ''})
    document_vocab = {name: number for number, name in enumerate(names)}

    columns = [[] for _ in range(n_cols)]
    for _, ents in rows:
        for col_idx, column in enumerate(columns):
            ent = ents[col_idx] if col_idx < len(ents) else ''
            if ent == '':
                column.append('0')
            else:
                column.append(document_id + str(document_vocab[ent]))
    return columns

def random_split(path,percentage=0.1,seed=5):
    """Move the files of ``path`` into ``path/train`` and ``path/dev``.

    The regular files directly inside ``path`` are sorted by name, a window of
    ``round(len(files) * percentage)`` consecutive files starting at a
    seeded-random position goes to ``dev`` and all remaining files go to
    ``train``. Both directories are created when missing.

    Args:
        path: directory whose files are split.
        percentage: fraction of the files that goes to ``dev``.
        seed: seed of the ``numpy.random.RandomState`` that picks the window.

    Notes:
        * Only regular files are split: existing ``train``/``dev`` directories
          (e.g. from a previous run) are no longer treated as data files.
        * The listing is sorted, so the same seed gives the same split on any
          file system (``os.listdir`` order is arbitrary).
        * The window start is drawn from every valid position. ``randint``
          excludes its upper bound, so the previous ``length - dev_files_num``
          bound could never select the last files and raised ValueError when
          all files belong to ``dev``.
    """
    rs = np.random.RandomState(seed=seed)
    all_data = sorted(
        name for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name))
    )
    length = len(all_data)
    dev_files_num = round(length * percentage)
    # valid window starts are 0 .. length - dev_files_num (inclusive)
    idx = rs.randint(0, max(length - dev_files_num, 0) + 1)

    train_dir = os.path.join(path,'train')
    if not os.path.exists(train_dir):
        os.mkdir(train_dir)
        print('Created train directory: {}'.format(train_dir))
    dev_dir = os.path.join(path,'dev')
    if not os.path.exists(dev_dir):
        os.mkdir(dev_dir)
        print('Created dev directory: {}'.format(dev_dir))

    dev_files = all_data[idx:idx+dev_files_num]
    train_files = all_data[:idx] + all_data[idx+dev_files_num:]

    for file in dev_files:
        os.replace(os.path.join(path,file),os.path.join(dev_dir,file))

    for file in train_files:
        os.replace(os.path.join(path,file),os.path.join(train_dir,file))

def make_dict(path, files,tokenizer):
    """Encode the given ``.et`` files into parallel id / mask / entity lists.

    Every line of every file is split on tabs; the first field is passed to
    ``tokenizer.encode`` and the second field is repeated once per returned
    id, so ids and entities stay aligned.

    Returns a dict with the keys 'attention_mask', 'input_ids' and 'entities',
    each holding one list per file (in the order of ``files``).
    """
    print(f'Processing dir: {path}')
    data_dict = {'attention_mask':[], 'input_ids':[], 'entities':[]}

    for file in files:
        print(f'Processing file: {file}')
        with open(os.path.join(path, file), 'r') as f:
            rows = [raw.strip('\n').split('\t') for raw in f.readlines()]

        doc_ids, doc_ents = [], []
        for row in rows:
            piece_ids = tokenizer.encode(row[0])
            doc_ids.extend(piece_ids)
            # one entity label per produced id
            doc_ents.extend([row[1]] * len(piece_ids))

        data_dict['input_ids'].append(doc_ids)
        data_dict['attention_mask'].append([1] * len(doc_ids))
        data_dict['entities'].append(doc_ents)

    return data_dict

def create_dicts(path):
    """Pickle the encoded train and dev documents found under ``path``.

    Loads the 'gpt2' tokenizer, encodes the files of ``path/train`` and
    ``path/dev`` with make_dict() and writes the results to ``path/train.pkl``
    and ``path/validation.pkl``.

    Returns True.
    """
    from transformers import GPT2Tokenizer
    import pickle as pkl

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    # both directories are listed before any encoding starts
    split_dirs = {name: os.path.join(path, name) for name in ('train', 'dev')}
    split_docs = {name: os.listdir(folder) for name, folder in split_dirs.items()}

    train_out = os.path.join(path, 'train.pkl')
    encoded_train = make_dict(split_dirs['train'], split_docs['train'], tokenizer)
    with open(train_out, 'wb') as pf:
        pkl.dump(encoded_train, pf, protocol=pkl.HIGHEST_PROTOCOL)
        print(f'Saving at :{train_out}')

    val_out = os.path.join(path, 'validation.pkl')
    encoded_dev = make_dict(split_dirs['dev'], split_docs['dev'], tokenizer)
    with open(val_out, 'wb') as pf:
        pkl.dump(encoded_dev, pf, protocol=pkl.HIGHEST_PROTOCOL)
        print(f'Saving at: {val_out}')

    return True


def create_csv_files(path,output):
    """Copy every file of ``path`` to ``output`` as a tab-delimited csv file.

    Each line is split on tabs and written back through ``csv.writer``
    (delimiter tab, quote character '|', minimal quoting) under the same
    file name. ``output`` must already exist.
    """
    import csv

    for name in os.listdir(path):
        with open(os.path.join(path, name), 'r') as src:
            rows = [raw.strip('\n').split('\t') for raw in src.readlines()]

        with open(os.path.join(output, name), 'w', newline='') as dst:
            writer = csv.writer(dst, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            writer.writerows(rows)


def main():
    """Run the whole conversion pipeline from the current working directory.

    Steps, in order: source -> CoNLL-U conversion (main and reddit sources),
    CoNLL-U -> ET un-nesting with the module-level ``strategy``, mapping of
    all entity ids to one global vocabulary, train/dev split and creation of
    the pickled dictionaries. All output goes below ``<cwd>/amalgum``.

    Returns True once every step has finished.
    """
    # assumes the dataset source files are in this directory
    cwd = os.getcwd()
    output_directory_root = os.path.join(cwd,'amalgum')
    output_directory_conllu = os.path.join(output_directory_root,'conllu')
    output_directory_et = os.path.join(output_directory_root, 'et')
    dataset_path = os.path.join(cwd,ROOT_DIR_NAME)
    strategy_directory = os.path.join(output_directory_et, strategy)

    if not os.path.exists(output_directory_root):
        print('Creating root data directory at: {}'.format(output_directory_root))
        os.mkdir(output_directory_root)

    # the balanced variant of the dataset lives in '*_balanced' directories
    postfix = '_balanced' if balanced else ''

    print('Starting Anox to CONNLU conversion...')
    anox_to_conllu(dataset_path+postfix,output_directory_conllu+postfix)
    anox_to_conllu(dataset_path+'_reddit'+postfix,output_directory_conllu+'_reddit'+postfix)
    print('Starting CoNLLU to ET conversion...')
    print('\t Un-nesting strategy followed is: {}'.format(strategy))
    conllu_to_et(output_directory_conllu+postfix,output_directory_et,strategy)
    conllu_to_et(output_directory_conllu+'_reddit'+postfix,output_directory_et,strategy)
    print('Generalizing the vocabulary to one.')
    entity_generalization(output_directory_et,strategy)
    # `percent` is a fraction, so format it as a percentage (0.10 -> 10%)
    print('Creating persistent train/dev split with {:.0%} on dev (seed={})'.format(percent,seed))
    random_split(strategy_directory,percentage=percent,seed=seed)
    print('Making of dicts to be used for LM task...')
    create_dicts(strategy_directory)

    print('Process ended successfully.')
    return True

# Run the full pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()