import bisect
import itertools
import json
import os

import _jsonnet

# Multilingual pretrained language models to experiment with.
# NOTE(review): these look like HuggingFace model identifiers (makeParams()
# writes such a name into the config's 'transformer_model' field) — confirm.
mlms = [
    'bert-base-multilingual-cased',
    'xlm-roberta-large',
]

# Universal Dependencies release versions. They are kept as strings so that
# '2.10' is not confused with 2.1.
udVersions = [
    '2.2',
    '2.5',
    '2.10',
]


# Treebank directory names grouped under the name "hasNoSplits".
# NOTE(review): the inclusion criterion is not visible in this file — confirm
# with the code that consumes this list before relying on its meaning.
hasNoSplits = [
    'UD_Afrikaans-AfriBooms',
    'UD_Ancient_Greek-PROIEL',
    'UD_Ancient_Greek-Perseus',
    'UD_Arabic-NYUAD',
    'UD_Arabic-PADT',
    'UD_Armenian-ArmTDP',
    'UD_Basque-BDT',
    'UD_Bulgarian-BTB',
    'UD_Buryat-BDT',
    'UD_Catalan-AnCora',
    'UD_Chinese-GSD',
    'UD_Coptic-Scriptorium',
    'UD_Croatian-SET',
    'UD_Danish-DDT',
    'UD_Dutch-Alpino',
    'UD_Dutch-LassySmall',
    'UD_Estonian-EDT',
    'UD_Finnish-FTB',
    'UD_French-Sequoia',
    'UD_French-Spoken',
    'UD_Gothic-PROIEL',
    'UD_Greek-GDT',
    'UD_Hebrew-HTB',
    'UD_Hindi-HDTB',
    'UD_Hungarian-Szeged',
    'UD_Indonesian-GSD',
    'UD_Irish-IDT',
    'UD_Italian-ISDT',
    'UD_Japanese-BCCWJ',
    'UD_Kazakh-KTB',
    'UD_Korean-Kaist',
    'UD_Kurmanji-MG',
    'UD_Latin-ITTB',
    'UD_Latin-PROIEL',
    'UD_Latin-Perseus',
    'UD_Lithuanian-HSE',
    'UD_Marathi-UFAL',
    'UD_North_Sami-Giella',
    'UD_Norwegian-NynorskLIA',
    'UD_Old_Church_Slavonic-PROIEL',
    'UD_Old_French-SRCMF',
    'UD_Persian-Seraji',
    'UD_Polish-SZ',
    'UD_Portuguese-GSD',
    'UD_Romanian-RRT',
    'UD_Russian-GSD',
    'UD_Serbian-SET',
    'UD_Slovak-SNK',
    'UD_Slovenian-SSJ',
    'UD_Slovenian-SST',
    'UD_Spanish-AnCora',
    'UD_Swedish-LinES',
    'UD_Swedish-Talbanken',
    'UD_Swedish_Sign_Language-SSLC',
    'UD_Tamil-TTB',
    'UD_Telugu-MTG',
    'UD_Upper_Sorbian-UFAL',
    'UD_Urdu-UDTB',
    'UD_Uyghur-UDT',
    'UD_Vietnamese-VTB',
]

class ScriptFinder():
    """Map characters (and whole strings) to their Unicode script name.

    The code-point ranges come from the Unicode Character Database file
    ``Scripts.txt`` (version 15.0.0). It is downloaded into
    ``scripts/Scripts.txt`` if it is not already there.
    ``self.ranges`` is a sorted list of ``[first, last, script]`` entries.
    Both bounds are inclusive.
    """

    def __init__(self):
        scriptsPath = 'scripts/Scripts.txt'
        if not os.path.isfile(scriptsPath):
            os.makedirs('scripts', exist_ok=True)
            # NOTE(review): --no-check-certificate disables TLS verification,
            # and the exit status is ignored. If the download fails, the open()
            # below raises or parses an incomplete file.
            os.system('wget https://www.unicode.org/Public/15.0.0/ucd/Scripts.txt --no-check-certificate -O scripts/Scripts.txt')
        ranges = []
        with open(scriptsPath, encoding='utf-8') as scriptsFile:
            for line in scriptsFile:
                tok = line.split(';')
                # Data lines look like "0041..005A    ; Latin # ...".
                # Comment lines, and lines without exactly one ';', are skipped.
                if line[0] == '#' or len(tok) != 2:
                    continue
                char_range_int = [int(x, 16) for x in tok[0].strip().split('..')]
                # A single code point is stored as a one-element range
                # [cp, cp]. Both bounds are inclusive, so for Latin, 65 (A)
                # and 90 (Z) both belong to the range 65-90.
                if len(char_range_int) == 1:
                    char_range_int.append(char_range_int[0])
                # The script name is the first word after ';'; the trailing
                # "# ..." comment is dropped.
                ranges.append(char_range_int + [tok[1].strip().split()[0]])
        self.ranges = sorted(ranges)

    def find_char(self, char):
        """Return the script name of the first character of *char*.

        Returns None when *char* is empty or when its code point is not
        covered by any range.
        """
        if not char:
            return None
        char_idx = ord(char[0])
        # self.ranges is sorted by start code point, and the ranges do not
        # overlap (each code point has one Script value in Scripts.txt).
        # So the only candidate is the last range starting at or before
        # char_idx. The probe [char_idx, inf] sorts after every
        # [start, end, name] with start <= char_idx, so bisect_right lands
        # just past that candidate.
        pos = bisect.bisect_right(self.ranges, [char_idx, float('inf')])
        if pos == 0:
            return None
        candidate = self.ranges[pos - 1]
        if char_idx <= candidate[1]:
            return candidate[2]
        return None

    def guess_script(self, text):
        """Return the most frequent script in *text*.

        Characters with no known script, and characters in the 'Common'
        script, are ignored. Returns None when nothing is left to count.
        Ties go to the script that appears first in *text*.
        """
        counts = {}
        for char in text:
            script = self.find_char(char)
            if script is None or script == 'Common':
                continue
            counts[script] = counts.get(script, 0) + 1
        if not counts:
            return None
        # max() returns the first maximal key; dicts keep insertion order.
        return max(counts, key=counts.get)


def getTrainDevTest(path):
    """
    Locate the train, dev and test ``.conllu`` files in a directory.

    A file counts for a split when its name ends with ``conllu`` and contains
    the substring ``train``, ``dev`` or ``test``. One file can match several
    splits. If several files match the same split, the one that sorts last by
    name wins. Entries are sorted first because ``os.listdir`` returns them
    in arbitrary order, which would make that choice nondeterministic.

    Parameters
    ----------
    path: str
        the directory to search (not recursive)

    Returns
    -------
    tuple of str
        ``(train, dev, test)``, each built as ``path + '/' + filename``, or
        an empty string for a split that was not found.
    """
    splits = {'train': '', 'dev': '', 'test': ''}
    for conlFile in sorted(os.listdir(path)):
        if not conlFile.endswith('conllu'):
            continue
        for split in splits:
            if split in conlFile:
                splits[split] = path + '/' + conlFile
    return splits['train'], splits['dev'], splits['test']

def hasColumn(path, idx, threshold=.1, maxLines=5000):
    """
    Check whether column *idx* of a tab-separated (CoNLL-U style) file is filled in.

    Only the first *maxLines* lines are inspected. These are streamed, so
    the rest of the file is never read. Comment lines (starting with ``#``)
    and blank lines are skipped. Every other line is split on tabs; it counts
    as missing when column *idx* is ``_`` or the line has too few columns.

    Parameters
    ----------
    path: str
        the file to inspect (read as UTF-8)
    idx: int
        0-based column index
    threshold: float
        the fraction of missing values must be strictly below this
    maxLines: int
        how many leading lines of the file to inspect

    Returns
    -------
    bool
        True if the fraction of missing values is below *threshold*. False
        when no token lines were found, instead of dividing by zero.
    """
    total = 0
    noWord = 0
    with open(path, encoding='utf-8') as conlFile:
        for line in itertools.islice(conlFile, maxLines):
            # `not line.strip()` also catches '\r\n' and whitespace-only lines.
            if line.startswith('#') or not line.strip():
                continue
            tok = line.strip().split('\t')
            if idx >= len(tok) or tok[idx] == '_':
                noWord += 1
            total += 1
    if total == 0:
        return False
    return noWord / total < threshold

def getModel(name, logDir='machamp/logs/'):
    """
    Return the path of a ``model.pt`` checkpoint for experiment *name*.

    Run sub-directories of ``logDir + name + '/'`` are tried in descending
    name order. The first one that contains ``model.pt`` is returned.
    Sorting is needed because ``os.listdir`` order is arbitrary.
    NOTE(review): this assumes run directories are named so that
    lexicographic order is chronological (e.g. timestamps), which makes the
    result the most recent run — confirm against the training code.

    Parameters
    ----------
    name: str
        the experiment name (sub-directory of *logDir*)
    logDir: str
        the logs directory, including a trailing ``/``

    Returns
    -------
    str
        the checkpoint path, or an empty string if none exists
    """
    nameDir = logDir + name + '/'
    if not os.path.isdir(nameDir):
        return ''
    for runDir in sorted(os.listdir(nameDir), reverse=True):
        modelPath = nameDir + runDir + '/model.pt'
        if os.path.isfile(modelPath):
            return modelPath
    return ''

def load_json(path: str):
    """
    Loads a jsonnet file through the json package and returns the result.

    The jsonnet source is evaluated with an empty filename, so any imports
    inside it are not resolved relative to *path*. The JSON output is then
    parsed with ``json.loads``; a config whose top level is an object gives
    a dict.

    Parameters
    ----------
    path: str
        the path to the json(net) file to load (read as UTF-8)
    """
    with open(path, encoding='utf-8') as jsonFile:
        snippet = jsonFile.read()
    # Pass the file text through unchanged. Joining readlines() with '\n'
    # would double every newline, shifting line numbers in jsonnet errors
    # and adding blank lines inside multi-line ||| text blocks.
    return json.loads(_jsonnet.evaluate_snippet("", snippet))

def makeParams(defaultPath, mlm):
    """
    Write a copy of a json(net) config with ``transformer_model`` set to *mlm*.

    Parameters
    ----------
    defaultPath: str
        the config file to start from (loaded with ``load_json``)
    mlm: str
        the value stored under ``transformer_model``

    Returns
    -------
    str
        the path of the written JSON file, ``configs/params.<mlm>.json``,
        where every ``/`` in *mlm* is replaced by ``_``
    """
    config = load_json(defaultPath)
    config['transformer_model'] = mlm
    tgt_path = 'configs/params.' + mlm.replace('/', '_') + '.json'
    # Create configs/ if needed, and close the file explicitly so the JSON
    # is fully flushed before callers read tgt_path.
    os.makedirs(os.path.dirname(tgt_path), exist_ok=True)
    with open(tgt_path, 'w') as tgtFile:
        json.dump(config, tgtFile, indent=4)
    return tgt_path


