#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
14/11/2024
Author: Katarina
"""

# Script that selects the best Wikidata search result for each extracted entity.

import time
import re
import os
import pickle as p
import random as r
import pandas as pd

from nltk import sent_tokenize, word_tokenize
from deep_translator import GoogleTranslator
from simplemma import lemmatize

import flair
from flair.data import Sentence
from flair.models import SequenceTagger

# Flair sequence tagger, loaded once at import time (so importing this module
# triggers the model load).
# NOTE(review): no function in this file references `pos_tagger` — confirm
# whether this load is still needed.
pos_tagger = SequenceTagger.load("flair/upos-multi")

# Manual overrides, keyed by the entity string as it appears in the raw
# wikidata dict. In main(), an entity found here is always linked to the mapped
# label; its Wikidata id is taken from the first of the top five search results
# whose lower-cased label equals the mapped label, or None if there is none.
# NOTE(review): 'po' is also listed in `to_remove`, and main() checks that list
# first, so the 'po' entry below is never reached.
manual_mistakes = {'gaza' : 'gaza strip', 'po' : 'po', 'rassemblement national' : 'rassemblement national', 'pd' : 'democratic party', 'striscia' : 'gaza strip', 'cvid' : 'covid', 'njp' : 'njp', 'trs' : 'trs', 'manovra' : 'maneuver', 'aifa' : 'italian medicines agency', 'cdc' : 'cdc', 'miu' : 'nicolae miu', 'lega' : 'lega nord', 'lazio' : 'lazio', 'spiriti' : 'spiriti', 'pc' : 'pc', 'riese' : 'riese', 'conte' : 'conte', 'mauro sandri' : 'mauro sandri', 'ipl' : 'indian premier league', 'stura' : 'stura', 'belmarsh' : 'belmarsh', 'draghi' : 'draghi', 'espionage act' : 'espionage act', 'galeazzo' : 'galeazzo', 'westminster' : 'westminster', 'stockport' : 'stockport', 'american' : 'american', 'scottish' : 'scotland', 'virgin' : 'virgin', 'unite': 'unite', 'eu' : 'european union', 'sky' : 'sky', 'xi' : 'xi jinping', 'medvedev' : 'medvedev', 'schoof' : 'schoof', 'nsc' : 'new social contract', 'bbb' : 'farmer-citizen movement', 'fd' : 'financieel dagblad', 'belgian' : 'belgium', 'nederlandsche bank' : 'nederlandsche bank', 'vn' : 'united nations', 'un' : 'united nations', 'congress' : 'congress', 'snp' : 'scottish national party', 'ukraine war' : 'ukraine war', 'kremlin' : 'kremlin', 'iranian' : 'iran', 'lando norris' : 'lando norris', 'anp' : 'algemeen nederlands persbureau', 'democratic' : 'democratic', 'iq' : 'intelligence quotient', 'wilders' : 'wilders', 'elena' : 'elena', 'prigozhin' : 'yevgeny prigozhin', 'falcone' : 'falcone', 'stalin' : 'stalin', 'zelensky' : 'volodymyr zelenskyy', 'islam radical' : 'islam', 'haz' : 'haz', 'eastern europe' : 'eastern europe', 'ashkenazi' : 'ashkenazi', 'non-white' : 'people of color', 'civil rights act' : 'civil rights act', 'civil rights' : 'civil rights', 'racial' : 'race', 'torrente aspio' : 'aspio', 'viru $' : 'virus'}

# Entity strings that main() skips outright (in addition to every
# single-character entity).
to_remove = ['an', 'al', 'bi', 'br', 'fa', 'hd', 'ho', 'am', 'bp', 'paar', 'nu', 'jo', 'mo', 'bl', 'et', 'id', 'je', 'po', 'beni', 'irl', 'non', 'dr', 'di', 'ad', 'cp', 'orbene']


def get_raw(name, entities):
    """Return the first entity whose lower-cased ``'word'`` equals *name*.

    *name* is compared as given (it is not lower-cased here). Returns None
    when no entity matches or *entities* is empty.
    """
    matching = (
        candidate
        for candidate in entities
        if candidate['word'].lower() == name
    )
    # The generator is lazy, so scanning stops at the first match.
    return next(matching, None)

def link_entity(entity):
    """Derive a lower-cased label for a recognised entity without a search hit.

    *entity* is a mapping with the keys ``'entity_group'``, ``'word'`` and
    ``'score'`` (``'score'`` is not read for the PER/ORG groups).

    Rules, applied in order:

    * ``PER`` / ``ORG``: the word itself, regardless of score.
    * ``MERGE`` / ``ADJ`` with score above 0.8: the word itself.
    * any remaining entity with score below 0.8: rejected.
    * ``MISC``: the word is split on every non-word character (consecutive
      separators therefore yield empty tokens that count towards the total).
      One token: lemmatise, then translate to English. Two or three tokens:
      translate the word as is. Four or more: rejected.
    * ``LOC``: translated to English.
    * anything else: rejected.

    Returns the lower-cased label, or None when the entity is rejected or
    when linking fails (missing key, translation error, ...). Failures are
    reported on stdout instead of being raised.
    """
    try:
        group = entity['entity_group']
        if group in ('PER', 'ORG'):
            return entity['word'].lower()
        if group in ('MERGE', 'ADJ') and entity['score'] > 0.8:
            return entity['word'].lower()
        if entity['score'] < 0.8:
            return None
        if group == 'MISC':
            # Raw string: '\W' in a plain literal is an invalid escape sequence.
            tokens = re.split(r'\W', entity['word'])
            if len(tokens) == 1:
                lemma = lemmatize(entity['word'], lang=('en', 'it', 'nl', 'es'))
                return GoogleTranslator(target='en').translate(lemma).lower()
            if len(tokens) < 4:
                return GoogleTranslator(target='en').translate(entity['word']).lower()
            return None
        if group == 'LOC':
            return GoogleTranslator(target='en').translate(entity['word']).lower()
        return None
    except Exception as exc:
        # Best effort by design: one bad entity must not abort the whole run.
        # `Exception` (not a bare except) lets KeyboardInterrupt / SystemExit
        # propagate so the script can still be stopped.
        print('Linking went wrong!')
        print(entity)
        print(f'\t{type(exc).__name__}: {exc}')
        return None

def _resolve_entity(entity, wikidata, channel_data, i):
    """Return ``(linked_entity, entity_id)`` for one entity, or None to skip it.

    *wikidata* is the raw wikidata mapping of paragraph *i*; *channel_data* is
    the whole unpickled channel dict (its ``'preprocessed_entities'`` entry is
    only read when the entity has no search result).
    """
    # Single-character entities and blacklisted strings are dropped outright.
    if len(entity) == 1 or entity in to_remove:
        return None

    try:
        search_results = wikidata[entity]['search']
    except (KeyError, TypeError):
        # No usable search payload was stored for this entity.
        search_results = []

    if entity in manual_mistakes:
        # Manual override: the label is fixed; the id is only known when one
        # of the top five search results carries exactly that label.
        linked_entity = manual_mistakes[entity]
        titles = [
            result['display']['label']['value'].lower()
            for result in search_results[:5]
        ]
        if linked_entity in titles:
            return linked_entity, search_results[titles.index(linked_entity)]['id']
        return linked_entity, None

    if len(search_results) > 0:
        # Default: trust the top-ranked search result.
        best = search_results[0]
        return best['display']['label']['value'].lower(), best['id']

    # No search result at all: fall back to the recognised entity itself.
    raw_entity = get_raw(entity, channel_data['preprocessed_entities'][i])
    if raw_entity is None:
        print(f'could not find entity for {entity}!!')
        return None
    linked_entity = link_entity(raw_entity)
    if linked_entity is None:
        return None
    # '0000' is the placeholder id for entities linked without a search hit.
    return linked_entity, '0000'


def main(in_directory='data/5_raw_wikidata', out_directory='data/6_linked_wikidata'):
    """Link the entities of every pickled channel file and save the result.

    For each ``*.p`` file in *in_directory* that does not yet exist in
    *out_directory*, the pickled dict is loaded, a ``'cleaned_wikidata'`` list
    is added (one dict per paragraph mapping entity -> ``(linked_entity,
    entity_id)``), and the dict is pickled to *out_directory* under the same
    file name. *out_directory* is created if it does not exist.

    Per entity: single-character and ``to_remove`` entities are skipped;
    ``manual_mistakes`` entries override the search results; otherwise the
    first search result is used; entities without any search result are
    linked via ``link_entity`` and receive the id ``'0000'``.
    """
    os.makedirs(out_directory, exist_ok=True)
    # Read the output directory once instead of once per input file.
    already_done = set(os.listdir(out_directory))

    for file in os.listdir(in_directory):
        if not file.endswith('.p') or file in already_done:
            continue
        # NOTE(security): pickle.load can execute arbitrary code; only run
        # this on files produced by the trusted earlier pipeline steps.
        with open(os.path.join(in_directory, file), 'rb') as f:
            channel_data = p.load(f)
        print(f'Processing {file}')

        raw_wikidata = channel_data['raw_wikidata']
        total = len(raw_wikidata)
        cleaned_wikidata = []
        for i, wikidata in enumerate(raw_wikidata):
            if i % 1000 == 0 and i > 0:
                print(f"\t{i} out of {total} done!")
            paragraph_cleaned_wikidata = {}
            for entity in wikidata:
                resolved = _resolve_entity(entity, wikidata, channel_data, i)
                if resolved is not None:
                    paragraph_cleaned_wikidata[entity] = resolved
            cleaned_wikidata.append(paragraph_cleaned_wikidata)

        channel_data['cleaned_wikidata'] = cleaned_wikidata
        with open(os.path.join(out_directory, file), 'wb') as f:
            p.dump(channel_data, f)

# Run the linking pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
