#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
14/11/2024
new version: 11/4/2025
Author: Katarina
"""

# Script to preprocess the raw entities: join split-up miscellaneous entities and,
# when a single-word entity whose entity group is not PER/LOC/ORG is tagged as an
# adjective, lemmatize it and translate it to English.

import time
import re
import os
import pickle as p
import random as r
import pandas as pd

from nltk import sent_tokenize, word_tokenize
from deep_translator import GoogleTranslator
from simplemma import lemmatize

import flair
from flair.data import Sentence
from flair.models import SequenceTagger

# Load the flair "flair/upos-multi" sequence tagger once, at import time.
# is_adjective() runs it on a sentence and checks the tokens for the 'ADJ' tag.
pos_tagger = SequenceTagger.load("flair/upos-multi")

def get_channel_langs():
    """Map every channel in 'channel_langs.csv' to a two-letter language code.

    The file is read with ';' as separator and must provide the columns
    'channel' and 'language'. Language names are compared case-insensitively;
    channels whose language is not Italian, Dutch or English are left out.
    """
    table = pd.read_csv('channel_langs.csv', sep = ';')
    iso_codes = {'italian' : 'it', 'dutch' : 'nl', 'english' : 'en'}
    langs_by_channel = {}
    for channel, language in zip(table['channel'], table['language']):
        language = language.lower()
        if language in iso_codes:
            langs_by_channel[channel] = iso_codes[language]
    return langs_by_channel

def get_all_entities(raw_entities):
    """Collect the distinct high-confidence entity words of a channel.

    raw_entities is a list of paragraphs, each a list of entity dicts with the
    keys 'score', 'word' and 'entity_group'. An entity counts when its score is
    above 0.99, its word has more than 3 characters and its group is PER, LOC
    or ORG. The lowercased words are returned as a list without duplicates
    (in no particular order).
    """
    trusted_groups = ['PER', 'LOC', 'ORG']
    unique_words = set()
    for paragraph in raw_entities:
        for entity in paragraph:
            if not entity['score'] > 0.99:
                continue
            if not len(entity['word']) > 3:
                continue
            if entity['entity_group'] not in trusted_groups:
                continue
            unique_words.add(entity['word'].lower())
    return list(unique_words)

def get_groups(l):
    """Split integers into sorted runs of consecutive values.

    in:  an iterable of integers (any order, duplicates allowed)
    out: a list of lists; each inner list is an ascending run in which every
         value is exactly one larger than its predecessor

    The input is not modified. Duplicates are collapsed so that a repeated
    value cannot break a run apart, and an empty input yields an empty list.
    """
    groups = []
    for x in sorted(set(l)):
        if groups and groups[-1][-1] == x - 1:
            # x directly continues the current run
            groups[-1].append(x)
        else:
            # gap (or very first value): open a new run
            groups.append([x])
    return groups

def get_misc_entities(raw, lang, paragraph):
    """Join entity fragments that touch each other into single 'MERGE' entities.

    Entities whose character spans follow each other without a gap form a
    group. Within a group the first entity always contributes its word; later
    entities contribute only when they are MISC or start with '##' (the '##'
    is dropped). A merged entity is kept when more than one fragment was
    joined, the joined string has between 3 and 13 characters and the mean
    score of the joined fragments is above 0.7. Merged words that
    is_adjective() accepts are lemmatized and translated to English; the
    untranslated word is then kept under 'original'.

    raw       : list of entity dicts ('word', 'start', 'end', 'score', 'entity_group')
    lang      : source language handed to GoogleTranslator
    paragraph : the text the entities come from, used as context by is_adjective()

    Returns a tuple (merged entities, list of (word, start) tuples of all
    entities belonging to a group that produced a merged entity).
    """
    # every character position covered by an entity, split into gap-free runs
    positions = []
    for entity in raw:
        positions += range(entity['start'], entity['end'])
    groups = get_groups(positions)

    # an entity belongs to the run that holds both its first and last character
    misc_entities = [[] for _ in groups]
    for entity in raw:
        for i, group in enumerate(groups):
            if (entity['start'] in group) and (entity['end'] - 1 in group):
                misc_entities[i].append(entity)
    misc_entities = [sorted(x, key = lambda k : k['start']) for x in misc_entities]

    new_entities = []
    remove_entities = []
    for potential_entity in misc_entities:
        if len(potential_entity) <= 1:
            continue
        string = ''
        old_entities = []
        n_entities = 0
        score = 0
        for i, entity in enumerate(potential_entity):
            old_entities.append((entity['word'], entity['start']))
            if i == 0:
                string += entity['word']
            elif (entity['entity_group'] == 'MISC') or entity['word'].startswith('##'):
                string += entity['word'].replace('##', '')
            else:
                # neither MISC nor a word-piece continuation: not part of the merge
                continue
            n_entities += 1
            score += entity['score']
        if (n_entities > 1) and (len(string) > 2) and (len(string) < 14):
            # keep only spaces, word characters and hyphens
            string = re.sub(r'[^ \w0-9\-]', '', string)
            # trim a leading/trailing space; the former pattern '[( $)|(^ )]' was a
            # character class and deleted every space in the string
            string = re.sub(r'(^ )|( $)', '', string)
            string = re.sub(' - ', '-', string)
            score = score / n_entities
            if score > 0.7:
                new_entities.append({'entity_group' : 'MERGE',
                                     'score' : score,
                                     'word' : string,
                                     'start' : potential_entity[0]['start']})
                remove_entities += old_entities

    # now check which of the new entities are adjectives that need a translation
    new_entities_translated = []
    for entity in new_entities:
        word = entity['word'].lower()
        if is_adjective(word, paragraph):
            lemma = lemmatize(word, lang = ('en', 'it', 'nl', 'es'))
            translation = GoogleTranslator(source = lang, target = 'en').translate(lemma)
            entity['original'] = entity['word']
            entity['word'] = translation
        new_entities_translated.append(entity)
    # returns the new entities + a list of tuples (word, start) of entities to remove
    return new_entities_translated, remove_entities

def is_adjective(word, paragraph):
    """Tell whether word is tagged as an adjective in its first sentence.

    The paragraph is split into sentences and each sentence is lowercased; the
    first one that contains word (substring test) serves as context. That
    context is tagged with the module-level pos_tagger and the result is True
    as soon as a token whose text equals word carries the tag 'ADJ'. Without a
    matching sentence the result is False.
    """
    lowered_sentences = (candidate.lower() for candidate in sent_tokenize(paragraph))
    context = next((candidate for candidate in lowered_sentences if word in candidate), '')
    if context == '':
        return False
    tagged = Sentence(context)
    pos_tagger.predict(tagged)
    return any(token.text == word and token.tag == 'ADJ' for token in tagged)

def in_entity_tokens(name, entity_tokens):
    """Return True when name occurs inside at least one of the entity tokens.

    name          : a string
    entity_tokens : a list of strings

    The comparison is case-insensitive and purely literal: name is looked up
    as a plain substring, so characters such as '.', '+' or '(' are matched
    as themselves instead of being interpreted as a regular expression.
    """
    name = name.lower()
    return any(name in token.lower() for token in entity_tokens)

def main():
    """Preprocess the raw entities of every channel pickle and store the result.

    Every '.p' file in data/3_raw_entities without a same-named file in
    data/4_preprocessed_entities is processed. Per paragraph the output holds
      * POSTHOC entities: tokens equal to a high-confidence entity of the
        channel that are not contained in any entity word of the paragraph,
      * the merged entities from get_misc_entities() (the entities they
        replace are dropped),
      * ADJ entities: single-word entities outside PER/LOC/ORG with more than
        4 characters and a score above 0.75 that is_adjective() accepts,
        lemmatized and translated to English,
      * all remaining raw entities unchanged.
    The per-paragraph lists are stored under 'preprocessed_entities' and the
    channel data is pickled to the output directory under the same file name.
    """
    in_directory = 'data/3_raw_entities'
    out_directory = 'data/4_preprocessed_entities'

    n = 0
    joined_entities = []
    sample = pd.read_csv('final_sample.csv')
    # NOTE(review): the language is 'it' only when the *username* equals 'Italian' and
    # 'en' otherwise -- confirm this is intended (get_channel_langs() is never called).
    sample['language'] = sample['username'].apply(lambda x : 'it' if x == 'Italian' else 'en')
    channel_languages = dict(zip(sample['username'], sample['language']))

    # files written during this run carry the names of inputs that were already
    # visited, so one directory listing up front is enough
    already_processed = set(os.listdir(out_directory))

    for file in os.listdir(in_directory):
        # literal suffix test: the former pattern '.p$' also accepted e.g. 'x.zip'
        if not file.endswith('.p') or file in already_processed:
            continue
        # NOTE: unpickling can execute arbitrary code; only load trusted files
        with open(f'{in_directory}/{file}', 'rb') as f:
            channel_data = p.load(f)
        n += 1
        start_time = time.time()
        print(f'Processing {n} out of {len(channel_languages)}')
        # a set gives constant-time lookups in the per-token check below
        certain_channel_entities = set(get_all_entities(channel_data['raw_entities']))
        channel_preprocessed_entities = []
        lang = channel_languages[file[:-len('.p')]]

        n_adjectives_translated = 0
        n_posthoc = 0
        n_joined = 0
        n_raw_entities = 0
        n_preprocessed_entities = 0

        n_paragraphs = len(channel_data['raw_entities'])
        for i, raw_entities in enumerate(channel_data['raw_entities']):
            if i % 500 == 0 and i != 0:
                print(f"\t\t{i} out of {n_paragraphs} done!")
                # NOTE(review): presumably a pause to throttle translation requests -- confirm
                time.sleep(2)

            n_raw_entities += len(raw_entities)
            paragraph = channel_data['post'][i]
            par_preprocessed_entities = []
            # first we check if there are more entities to be found in the paragraph
            entity_tokens = [x['word'].lower() for x in raw_entities]
            for token in word_tokenize(paragraph):
                token = token.lower()
                if (token in certain_channel_entities) and (not in_entity_tokens(token, entity_tokens)):
                    n_posthoc += 1
                    new_entity = {'entity_group' : 'POSTHOC',
                                    'word' : token}
                    par_preprocessed_entities.append(new_entity)

            merged_entities, remove_entities = get_misc_entities(raw_entities, lang, paragraph)
            n_joined += len(merged_entities)
            # match word and start as a pair so that only the entities that were
            # actually merged get dropped
            removed_keys = {(removed_word.lower(), removed_start)
                            for removed_word, removed_start in remove_entities}
            if removed_keys:
                remove_words = [removed_word.lower() for removed_word, _ in remove_entities]
                joined_entities.append((merged_entities, remove_words))

            par_preprocessed_entities += merged_entities

            # go over the other raw entities and preprocess them
            for raw_entity in raw_entities:
                word = raw_entity['word'].lower()
                if (word, raw_entity['start']) in removed_keys:
                    continue

                # if the entity is miscellaneous and consists of 1 word with at least 5 letters:
                is_candidate = ((raw_entity['entity_group'] not in ['PER', 'LOC', 'ORG'])
                                and (len(word) > 4)
                                and (raw_entity['score'] > 0.75)
                                and (len(word_tokenize(word)) == 1))
                if is_candidate and is_adjective(word, paragraph):
                    lemma = lemmatize(word, lang = ('en', 'it', 'nl', 'es'))
                    translation = GoogleTranslator(source = lang, target = 'en').translate(lemma)
                    n_adjectives_translated += 1
                    new_entity = {'entity_group' : 'ADJ',
                                    'score' : raw_entity['score'],
                                    'word' : translation,
                                    'start' : raw_entity['start'],
                                    'original' : raw_entity['word']}
                    par_preprocessed_entities.append(new_entity)
                else:
                    par_preprocessed_entities.append(raw_entity)
            n_preprocessed_entities += len(par_preprocessed_entities)
            channel_preprocessed_entities.append(par_preprocessed_entities)

        # dump the channel data
        channel_data['preprocessed_entities'] = channel_preprocessed_entities
        with open(f'{out_directory}/{file}', 'wb') as f:
            p.dump(channel_data, f)

        # print the final channel stats
        print(f'Took {time.time() - start_time:.2f} s to process {n_raw_entities}')
        print(f'Added {n_adjectives_translated} translated adjectives, {n_posthoc} posthoc entities, {n_joined} joined entities')
        print(f'Found a final of {n_preprocessed_entities} preprocessed entities\n\n')

    # optional export of a random sample of merges for manual evaluation
    #merge_sample = r.sample(joined_entities, 70)
    #merge_sample = list(zip(*merge_sample))
    #df = {'merged' : merge_sample[0], 'removed' : merge_sample[1]}
    #df = pd.DataFrame(df)
    #df.to_csv('evaluate_merge.csv')

# run the preprocessing only when this file is executed as a script
if __name__ == '__main__':
    main()
