import os.path
import pickle
import numpy as np
import pycountry

# Language names (already normalised to their capitalised display form) whose
# files Task1File loads as empty: a file whose derived name is in this set
# gets ``data == []``.
banned = {
    'Arabic',
    'Armenian',
    'Bengali',
    'Georgian',
    'Haida',
    'Hebrew',
    'Khaling',
    'Latvian',
    'Lower Sorbian',
    'Macedonian',
    'Navajo',
    'Persian',
    'Quechua',
    'Serbo-Croatian',
    'Slovak',
    'Urdu',
}
class Task1File:
    """Reader for a tab-separated data file named ``<language>-<part>``.

    The file name (not its contents) determines the language and the data
    split.  Each line must hold exactly three tab-separated columns: source
    form, target form and a ``;``-separated tag string.

    Attributes:
        part: Second ``-``-separated piece of the matched suffix
            (e.g. ``'dev'``, ``'train'``, ``'test'``).
        language: Base name of the file with the matched suffix removed.
        data: List of ``(src, trg, tags, lang)`` tuples where ``trg`` is
            ``None`` when the column is ``'--'`` or empty, ``tags`` is the
            tag string split on ``';'`` and ``lang`` is the ISO 639-3 code
            reported by ``pycountry``.  Empty when the language is in
            ``banned`` or cannot be resolved by ``pycountry``.
    """

    def __init__(self, filename):
        """Parse ``filename`` and populate ``part``, ``language`` and ``data``.

        Args:
            filename: Path to the data file; its base name must end with one
                of the known split suffixes.

        Raises:
            ValueError: If the base name has no recognised suffix or nothing
                precedes the suffix; also propagated from row unpacking when
                a line does not have exactly three columns.
            OSError: If the file cannot be opened.
        """
        base = os.path.basename(filename)
        # NOTE(review): '-test' is tried before '-covered-test', so a name
        # ending in '-covered-test' is matched as '-test' and its language
        # keeps a trailing '-covered'.  The order is left untouched so such
        # files keep resolving to empty data rather than being parsed.
        suffixes = ('-dev', '-train-low', '-train-medium', '-train-high',
                    '-test', '-covered-test')
        # Initialise up front so an unmatched name is reported cleanly
        # instead of failing with AttributeError on a missing attribute.
        self.part = None
        self.language = None
        for suffix in suffixes:
            if base.endswith(suffix):
                self.part = suffix.split('-')[1]
                self.language = base[:-len(suffix)]
                break

        if not (self.part and self.language):
            raise ValueError(
                'cannot derive language/part from file name: %r' % base)

        with open(filename, 'r', encoding='utf-8') as f:
            data = [line.rstrip('\n').split('\t', -1) for line in f]

        # Build the display name from the parsed language rather than from
        # the raw path, so directory names cannot influence it.
        lname = ' '.join(word.capitalize()
                         for word in self.language.split('-'))

        # Spellings that differ between the file names and pycountry/banned.
        if lname == 'Norwegian Bokmal':
            lname = 'Norwegian Bokmål'
        elif lname == 'Serbo Croatian':
            lname = 'Serbo-Croatian'
        elif lname == 'Slovene':
            lname = 'Slovenian'

        if lname in banned:
            self.data = []
            return

        try:
            # Depending on the pycountry version an unknown name either
            # raises a lookup error or returns None (-> AttributeError).
            lang = pycountry.languages.get(name=lname).alpha_3
        except (AttributeError, LookupError):
            print(lname, 'error')
            self.data = []
            return

        print(lname, lang)
        self.data = [(src, None if trg in ('--', '') else trg,
                      tag.split(';'), lang)
                     for src, trg, tag in data]

    def get_alphabet(self):
        """Return the sorted list of characters used in sources and targets.

        Rows whose target is ``None`` contribute only their source
        characters.
        """
        # The `if trg` filter must precede `for c in trg`; otherwise a None
        # target is iterated before it can be filtered out.
        return sorted({c for src, _, _, _ in self.data for c in src} |
                      {c for _, trg, _, _ in self.data if trg for c in trg})

    def get_features(self):
        """Return the sorted list of distinct tags over all rows."""
        return sorted({f for _, _, feats, _ in self.data for f in feats})

    def merge(self, other_file):
        """Append the rows of ``other_file`` to this file's rows.

        A new list is assigned to ``self.data``; ``other_file`` is not
        modified.
        """
        self.data = self.data + other_file.data


class G2pFile:
    """Reader for a tab-separated ``*_train.sig`` / ``*_test.sig`` file.

    Each line must hold exactly four tab-separated columns:
    source, target, tag and language.  Only rows whose language column is in
    the requested set are kept.

    Attributes:
        part: Text after the underscore in the matched suffix, i.e.
            ``'train.sig'`` or ``'test.sig'`` (the extension is retained).
        data: List of ``(src, trg, tag, lang)`` string tuples.
    """

    def __init__(self, filename, languages):
        """Load the rows of ``filename`` that belong to ``languages``.

        Args:
            filename: Path whose base name ends in ``_train.sig`` or
                ``_test.sig``.
            languages: Iterable of language identifiers; only the part
                before the first ``'-'`` of each one is compared against the
                file's language column.

        Raises:
            ValueError: If the base name has no recognised suffix; also
                propagated from row unpacking when a line does not have
                exactly four columns.
            OSError: If the file cannot be opened.
        """
        languages = {lang.split('-')[0] for lang in languages}

        base = os.path.basename(filename)
        # Initialise up front so an unmatched name is reported cleanly
        # instead of failing with AttributeError on a missing attribute.
        self.part = None
        for suffix in ('_train.sig', '_test.sig'):
            if base.endswith(suffix):
                self.part = suffix.split('_')[1]
                break

        if not self.part:
            raise ValueError(
                'cannot derive part from file name: %r' % base)

        # The language is taken from each row, not from the file name.
        with open(filename, 'r', encoding='utf-8') as f:
            data = [line.rstrip('\n').split('\t', -1) for line in f]

        self.data = [(src, trg, tag, lang)
                     for src, trg, tag, lang in data if lang in languages]

    def get_alphabet(self):
        """Return the sorted union of source characters and target tokens.

        Sources contribute individual characters; targets contribute their
        whitespace-separated tokens.
        """
        # TODO: See about making src/trg unique
        return sorted({c for src, _, _, _ in self.data for c in src} |
                      {c for _, trg, _, _ in self.data for c in trg.split()})

    def get_features(self):
        """Return the sorted set of items found by iterating each tag.

        NOTE(review): ``tag`` is stored as the raw column string, so this
        yields its individual characters -- confirm that is intended.
        """
        return sorted({f for _, _, feats, _ in self.data for f in feats})

    def merge(self, other_file):
        """Append the rows of ``other_file`` to this file's rows.

        A new list is assigned to ``self.data``; ``other_file`` is not
        modified.
        """
        self.data = self.data + other_file.data

class ASJPFile:
    """Reader for a tab-separated ``*_train.asjp`` / ``*_dev.asjp`` file.

    Each line must hold exactly four tab-separated columns:
    source, target, tag and language.  Only rows whose language column is in
    the requested set are kept.

    Attributes:
        part: Text after the underscore in the matched suffix, i.e.
            ``'train.asjp'`` or ``'dev.asjp'`` (the extension is retained).
        data: List of ``(src, trg, tag, lang)`` string tuples.
    """

    def __init__(self, filename, languages):
        """Load the rows of ``filename`` that belong to ``languages``.

        Args:
            filename: Path whose base name ends in ``_train.asjp`` or
                ``_dev.asjp``.
            languages: Iterable of language identifiers; only the part
                before the first ``'-'`` of each one is compared against the
                file's language column.

        Raises:
            ValueError: If the base name has no recognised suffix; also
                propagated from row unpacking when a line does not have
                exactly four columns.
            OSError: If the file cannot be opened.
        """
        languages = {lang.split('-')[0] for lang in languages}

        base = os.path.basename(filename)
        # Initialise up front so an unmatched name is reported cleanly
        # instead of failing with AttributeError on a missing attribute.
        self.part = None
        for suffix in ('_train.asjp', '_dev.asjp'):
            if base.endswith(suffix):
                self.part = suffix.split('_')[1]
                break

        if not self.part:
            raise ValueError(
                'cannot derive part from file name: %r' % base)

        # The language is taken from each row, not from the file name.
        with open(filename, 'r', encoding='utf-8') as f:
            data = [line.rstrip('\n').split('\t', -1) for line in f]

        self.data = [(src, trg, tag, lang)
                     for src, trg, tag, lang in data if lang in languages]

    def get_alphabet(self):
        """Return the sorted union of source characters and target tokens.

        Sources contribute individual characters; targets contribute their
        whitespace-separated tokens.
        """
        # TODO: See about making src/trg unique
        return sorted({c for src, _, _, _ in self.data for c in src} |
                      {c for _, trg, _, _ in self.data for c in trg.split()})

    def get_features(self):
        """Return the sorted set of items found by iterating each tag.

        NOTE(review): ``tag`` is stored as the raw column string, so this
        yields its individual characters -- confirm that is intended.
        """
        return sorted({f for _, _, feats, _ in self.data for f in feats})

    def merge(self, other_file):
        """Append the rows of ``other_file`` to this file's rows.

        A new list is assigned to ``self.data``; ``other_file`` is not
        modified.
        """
        self.data = self.data + other_file.data

def load_language_embeddings():
    """Load the pickled language-id object and the language vector arrays.

    All inputs are read from the current working directory:
    ``lvec_ids.pkl`` plus the three arrays ``lvec_p1.npy``, ``lvec_p2.npy``
    and ``lvec_p3.npy``.

    Returns:
        A ``(lang2id, p1)`` pair: the unpickled object and the first array
        only.  The three arrays are also stacked horizontally, but that
        result is currently not returned.
    """
    # NOTE: pickle.load executes arbitrary code from the file; only use
    # with a trusted lvec_ids.pkl.
    with open('./lvec_ids.pkl', 'rb') as ids_file:
        lang2id = pickle.load(ids_file)

    part_paths = ('./lvec_p1.npy', './lvec_p2.npy', './lvec_p3.npy')
    parts = [np.load(path) for path in part_paths]

    # Still computed (so missing files or incompatible shapes fail here),
    # although only the first part is handed back to the caller.
    _stacked = np.hstack(parts)

    print('Loaded pre-trained language embeddings')
    return lang2id, parts[0]

if __name__ == '__main__':
    # Debug entry point: for every path given on the command line, print the
    # path and then the source/target pair of each row that was loaded.
    import sys
    for filename in sys.argv[1:]:
        print(filename)
        tf = Task1File(filename)
        # Rows are 4-tuples (src, trg, tags, lang); unpacking only three
        # names would raise "too many values to unpack".
        for src, trg, _tags, _lang in tf.data:
            print(src, trg)
