import pandas as pd
import numpy as np
from ekphrasis.classes.spellcorrect import SpellCorrector
from ekphrasis.classes.segmenter import Segmenter

class DataProcessing:
    """Load word/emoji embeddings and OffensEval sub-task C tweets; build id batches.

    Word ids come from ``self.vocab``: ``__nok__`` = 0 (the id ``verPro`` pads
    with; its embedding row stays all zeros), ``__unk__`` = 1, ``__num__`` = 2,
    ``__date__`` = 3, then one id per line of the word model file, then one id
    per emoji-file token not already in the vocabulary.

    Only rows with ``subtask_a == 'OFF'`` and ``subtask_b == 'TIN'`` are kept,
    split by ``subtask_c`` into ``train_ind`` / ``train_grp`` / ``train_oth``.

    ``eab`` controls emoji handling in the text:
      * 1 -- emojis are deleted from every tweet;
      * 2 -- emojis are spaced out and each tweet's emoji-vocab ids are also
        kept (``train_*_emoji``), for use by ``getTrainBatch_eab``;
      * anything else (default 0) -- emojis are spaced out only.
    """

    # Rows pre-allocated for the word-embedding matrix.  Must be at least
    # 4 + (lines in the word model file) + (new tokens from the emoji file);
    # loading more than that raises IndexError.
    _WORD_EMBD_ROWS = 1195098
    # Maximum number of emoji-file lines read by getVecFromEmoji.
    _EMOJI_VOCAB_MAX = 1661
    # Width of every embedding row (word and emoji).
    _EMBD_DIM = 200
    # Inclusive range in which an all-digit token is mapped to __date__.
    _YEAR_RANGE = (1900, 2020)

    # (old, new) pairs applied by prePro_1 one after another, exactly as a
    # chain of str.replace calls would; keep the order.
    _PRE_PRO_1_REPLACEMENTS = (
        ("@user", " <user> "),
        # Split English clitics off the preceding word (ASCII and curly quotes).
        ("'m", " 'm"), ("’m", " ’m"),
        ("'re", " 're"), ("’re", " ’re"),
        ("'ll", " 'll"), ("’ll", " ’ll"),
        ("'ve", " 've"), ("’ve", " ’ve"),
        ("'d", " 'd"), ("’d", " ’d"),
        ("n't", " n't"), ("n’t", " n’t"),
        ("'s", " 's"), ("’s", " ’s"),
        ("s'", " s'"), ("s’", " s’"),
        # Surround punctuation and symbols with spaces so they become tokens.
        (".", " . "), (",", " , "),
        ("!", " ! "), ("！", " ！ "),
        ("?", " ? "), ("？", " ？ "),
        ("#", " # "), ("$", " $ "),
        ("\"", " \" "), ("”", " ” "), ("“", " “ "),
        (";", " ; "), (":", " : "),
        ("(", " ( "), (")", " ) "),
        ("[", " [ "), ("]", " ] "),
        ("{", " { "), ("}", " } "),
        ("+", " + "), ("-", " - "), ("*", " * "), ("/", " / "), ("=", " = "),
        ("—", " — "), ("…", " … "), ("•", " • "),
        # Normalise non-ASCII spaces (no-break, ideographic) to a plain space,
        # since tokens are later split on ' ' only.
        ("\u00a0", " "), ("\u3000", " "),
        # Replace all five skin-tone modifiers and the gender signs with a
        # space, leaving the base emoji on its own.
        ("\U0001F3FB", " "), ("\U0001F3FC", " "), ("\U0001F3FD", " "),
        ("\U0001F3FE", " "), ("\U0001F3FF", " "),
        ("♂", " "), ("♀", " "),
    )

    def __init__(self, modelFile, dataFile, emojiFile, swearFile, numDimensions, eab=0):
        """Load embeddings and the OFF/TIN training tweets.

        Args:
            modelFile: word-vector text file, one ``word v1 v2 ...`` per line
                (space separated, ``_EMBD_DIM`` values).
            dataFile: tab-separated training file with ``tweet``,
                ``subtask_a``, ``subtask_b`` and ``subtask_c`` columns.
            emojiFile: emoji-vector text file in the same layout; it feeds both
                the word vocabulary and the separate emoji vocabulary.
            swearFile: accepted for compatibility; currently unused.
            numDimensions: stored as ``self.numDimensions``; the embedding
                width itself is fixed at ``_EMBD_DIM``.
            eab: emoji handling mode, see the class docstring.
        """
        print('Loading Embeddings...')
        self.vocab, embd = self.getVecFromLocal(modelFile, emojiFile)
        self.vocab_size = len(self.vocab)
        self.embedding = np.asarray(embd)
        self.eab = eab
        # The emoji vocabulary is always needed: prePro_2_* scan tweets with it.
        self.emoji_vocab, emoji_embd = self.getVecFromEmoji(emojiFile)
        if eab != 0:
            self.emoji_vocab_size = len(self.emoji_vocab)
            self.emoji_embedding = np.asarray(emoji_embd)
        print('Embeddings loaded.')

        print('Loading Train Data...')
        self.seg = Segmenter(corpus="twitter")
        self.spc = SpellCorrector(corpus="twitter")
        if eab == 2:
            (self.train_ind, self.train_grp, self.train_oth,
             self.train_ind_emoji, self.train_grp_emoji,
             self.train_oth_emoji) = self.getSentences_eab(dataFile)
        else:
            self.train_ind, self.train_grp, self.train_oth = self.getSentences(dataFile)
        self.numDimensions = numDimensions
        self.emoji_numDimensions = self._EMBD_DIM
        self.train_ind_leng = len(self.train_ind)
        self.train_grp_leng = len(self.train_grp)
        self.train_oth_leng = len(self.train_oth)
        print('Train Data loaded.')

    def getVecFromLocal(self, modelFile, emojiFile):
        """Build the word vocabulary and embedding matrix.

        Rows 1-3 (``__unk__``, ``__num__``, ``__date__``) get uniform random
        values in (-1, 1]; row 0 (``__nok__``) stays zero.  Every line of
        *modelFile* then gets the next id, followed by every *emojiFile* token
        not already present (a trailing ``" \\r\\n"`` is stripped from emoji
        lines).

        Returns:
            ``(vocab, embd)``: token -> row dict, and a
            ``(_WORD_EMBD_ROWS, _EMBD_DIM)`` float array whose unused trailing
            rows stay zero.
        """
        vocab = {'__nok__': 0, '__unk__': 1, '__num__': 2, '__date__': 3}
        embd = np.zeros((self._WORD_EMBD_ROWS, self._EMBD_DIM))
        for row in (1, 2, 3):
            embd[row] = 1 - 2 * np.random.random(self._EMBD_DIM)

        next_id = 4
        # Iterate the file lazily instead of materialising it with readlines().
        with open(modelFile, "rb") as f:
            for raw in f:
                fields = raw.decode(encoding="utf-8").split(' ')
                vocab[fields[0]] = next_id
                embd[next_id] = fields[1:]
                next_id += 1
        with open(emojiFile, "rb") as f:
            for raw in f:
                fields = raw.decode(encoding="utf-8").replace(" \r\n", "").split(' ')
                if fields[0] not in vocab:
                    vocab[fields[0]] = next_id
                    embd[next_id] = fields[1:]
                    next_id += 1
        return vocab, embd

    def getVecFromEmoji(self, modelFile):
        """Read at most ``_EMOJI_VOCAB_MAX`` lines of the emoji-vector file.

        Returns:
            ``(vocab, embd)``: emoji -> row dict (``__nok__`` = 0, file lines
            from 1), and a ``(_EMOJI_VOCAB_MAX + 1, _EMBD_DIM)`` array whose
            row 0 stays zero.
        """
        vocab = {'__nok__': 0}
        embd = np.zeros((self._EMOJI_VOCAB_MAX + 1, self._EMBD_DIM))
        with open(modelFile, "rb") as f:
            for row, raw in enumerate(f, start=1):
                fields = raw.decode(encoding="utf-8").replace(" \r\n", "").split(' ')
                vocab[fields[0]] = row
                embd[row] = fields[1:]
                if row == self._EMOJI_VOCAB_MAX:
                    break
        return vocab, embd

    def getVecFromSent(self, sentence):
        """Map a preprocessed, space-separated *sentence* to a list of word ids.

        Empty tokens are skipped; tokens missing from ``self.vocab`` go through
        ``getVecFromSent_unkWordPro1``.
        """
        ids = []
        for word in sentence.split(' '):
            if not word:
                continue
            word_id = self.vocab.get(word)
            if word_id is None:
                ids = self.getVecFromSent_unkWordPro1(word, ids)
            else:
                ids.append(word_id)
        return ids

    def getVecFromSent_unkWordPro1(self, word, ids):
        """Resolve an out-of-vocabulary *word* by removing quotes and segmenting it.

        ASCII and curly apostrophes are removed, the result is split with
        ``self.seg.segment``, and each piece is looked up, falling back to
        ``getVecFromSent_unkWordPro2``.  Appends to *ids* in place and
        returns it.
        """
        for quote in ("'", "‘", "’"):
            word = word.replace(quote, "")
        for piece in self.seg.segment(word).split(' '):
            if not piece:
                continue
            piece_id = self.vocab.get(piece)
            if piece_id is None:
                ids = self.getVecFromSent_unkWordPro2(piece, ids)
            else:
                ids.append(piece_id)
        return ids

    def getVecFromSent_unkWordPro2(self, word, ids):
        """Last-resort lookup for *word*; appends one id to *ids* and returns it.

        The word is spell-corrected with ``self.spc.correct``.  If it is still
        unknown, an all-decimal-digit token becomes ``__date__`` when its value
        lies in ``_YEAR_RANGE`` and ``__num__`` otherwise; anything else
        becomes ``__unk__``.
        """
        word = self.spc.correct(word)
        word_id = self.vocab.get(word)
        if word_id is not None:
            ids.append(word_id)
        # isdecimal (not isdigit): int() rejects digit-like characters such
        # as '²' that isdigit() accepts.
        elif word.isdecimal():
            low, high = self._YEAR_RANGE
            # Look the special ids up by name; id 4 is the first model-file
            # word, not a special token.
            if low <= int(word) <= high:
                ids.append(self.vocab['__date__'])
            else:
                ids.append(self.vocab['__num__'])
        else:
            ids.append(self.vocab['__unk__'])
        return ids

    def prePro_1(self, sentence):
        """Lower-case *sentence* and space out mentions, clitics and punctuation.

        Applies ``_PRE_PRO_1_REPLACEMENTS`` in order and returns the new string.
        """
        sentence = sentence.lower()
        for old, new in self._PRE_PRO_1_REPLACEMENTS:
            sentence = sentence.replace(old, new)
        return sentence

    def prePro_2_delEmoji(self, sentence):
        """Replace every occurrence of every emoji-vocabulary key with a space."""
        for key in self.emoji_vocab:
            sentence = sentence.replace(key, " ")
        return sentence

    def prePro_2_attEmoji(self, sentence, getEmoji=True):
        """Surround emoji-vocabulary keys in *sentence* with spaces.

        Keys are scanned in vocabulary order.  Each key found adds its id once
        to the returned list, however often it occurs.

        Returns:
            ``(sentence, emoji_ids)``, or just the sentence when *getEmoji*
            is false.
        """
        emoji_encode_list = []
        for key, emoji_id in self.emoji_vocab.items():
            if key in sentence:
                emoji_encode_list.append(emoji_id)
                sentence = sentence.replace(key, " " + key + " ")
        if not getEmoji:
            return sentence
        return sentence, emoji_encode_list

    def _iter_tin_tweets(self, file):
        """Yield ``(subtask_c label, prePro_1-cleaned tweet)`` for each OFF/TIN row of *file*."""
        train = pd.read_csv(file, sep='\t', header=0)
        rows = zip(train['subtask_a'], train['subtask_b'], train['subtask_c'], train['tweet'])
        for sub_a, sub_b, sub_c, tweet in rows:
            if str(sub_a) == 'OFF' and str(sub_b) == 'TIN':
                yield str(sub_c), self.prePro_1(str(tweet))

    def getSentences(self, file):
        """Return the preprocessed IND, GRP and OTH tweets of *file* as three lists.

        Emojis are deleted when ``self.eab == 1``; otherwise they are spaced out.
        """
        train_ind, train_grp, train_oth = [], [], []
        buckets = {'IND': train_ind, 'GRP': train_grp, 'OTH': train_oth}
        for label, sentence in self._iter_tin_tweets(file):
            if self.eab == 1:
                sentence = self.prePro_2_delEmoji(sentence)
            else:
                sentence = self.prePro_2_attEmoji(sentence, False)
            bucket = buckets.get(label)
            if bucket is not None:
                bucket.append(sentence)
        return train_ind, train_grp, train_oth

    def getSentences_eab(self, file):
        """Like ``getSentences`` but also return each tweet's emoji-vocab ids.

        Returns:
            ``(ind, grp, oth, ind_emoji, grp_emoji, oth_emoji)``, where the
            ``*_emoji`` lists are aligned index-by-index with the sentence lists.
        """
        train_ind, train_grp, train_oth = [], [], []
        train_ind_emoji, train_grp_emoji, train_oth_emoji = [], [], []
        buckets = {
            'IND': (train_ind, train_ind_emoji),
            'GRP': (train_grp, train_grp_emoji),
            'OTH': (train_oth, train_oth_emoji),
        }
        for label, sentence in self._iter_tin_tweets(file):
            sentence, emoji_ids = self.prePro_2_attEmoji(sentence)
            if label in buckets:
                sentences, emoji_lists = buckets[label]
                sentences.append(sentence)
                emoji_lists.append(emoji_ids)
        return train_ind, train_grp, train_oth, train_ind_emoji, train_grp_emoji, train_oth_emoji

    def verPro(self, x):
        """Right-pad every id list in *x* with 0 to the longest length; return an ndarray.

        The input lists are copied, never modified.  This matters because
        ``getTrainBatch_eab`` passes the stored ``train_*_emoji`` lists here.
        """
        width = max((len(row) for row in x), default=0)
        return np.array([list(row) + [0] * (width - len(row)) for row in x])

    def _fold_bounds(self, cross_deviation):
        """Return ``(start, end)`` of the held-out fold for IND, GRP and OTH.

        Fold sizes are hard-coded: IND folds hold 482, 482, then 481 tweets;
        GRP folds hold 215, then 214; OTH folds hold 79.
        """
        cd = cross_deviation
        if cd == 0:
            ind = (0, 482)
        elif cd == 1:
            ind = (482, 964)
        else:
            ind = (964 + 481 * (cd - 2), 964 + 481 * (cd - 1))
        grp = (0, 215) if cd == 0 else (215 + 214 * (cd - 1), 215 + 214 * cd)
        oth = (79 * cd, 79 * (cd + 1))
        return ind, grp, oth

    def _collect_batch(self, verify, cross_deviation, one_third_batch_size, with_emoji):
        """Return unpadded ``(inputs, emojis, labels)`` lists for one batch.

        Classes are visited in the order IND, GRP, OTH with one-hot labels
        ``[0, 0, 1]``, ``[0, 1, 0]`` and ``[1, 0, 0]``.  When *verify* is false,
        ``one_third_batch_size`` tweets per class are drawn with replacement
        via ``np.random.randint`` from the tweets outside the held-out fold.
        Otherwise every tweet of the held-out fold is returned in order.
        ``emojis`` stays empty unless *with_emoji* is true.
        """
        sentence_groups = (self.train_ind, self.train_grp, self.train_oth)
        if with_emoji:
            emoji_groups = (self.train_ind_emoji, self.train_grp_emoji, self.train_oth_emoji)
        else:
            emoji_groups = (None, None, None)
        one_hots = ([0, 0, 1], [0, 1, 0], [1, 0, 0])

        inputs, emojis, labels = [], [], []
        for sentences, emoji_lists, (start, end), one_hot in zip(
                sentence_groups, emoji_groups, self._fold_bounds(cross_deviation), one_hots):
            if not verify:
                # Training: sample from everything except the held-out fold.
                pool = sentences[:start] + sentences[end:]
                emoji_pool = emoji_lists[:start] + emoji_lists[end:] if with_emoji else None
                for _ in range(one_third_batch_size):
                    idx = np.random.randint(0, len(pool))
                    inputs.append(self.getVecFromSent(pool[idx]))
                    if with_emoji:
                        emojis.append(emoji_pool[idx])
                    labels.append(list(one_hot))
            else:
                for idx in range(start, end):
                    inputs.append(self.getVecFromSent(sentences[idx]))
                    if with_emoji:
                        emojis.append(emoji_lists[idx])
                    labels.append(list(one_hot))
        return inputs, emojis, labels

    def getTrainBatch(self, verify, cross_deviation, cross_multiple=5, one_third_batch_size=50):
        """Build a training batch (``verify`` false) or the full validation fold.

        Args:
            verify: false to sample a balanced training batch; true to return
                the held-out fold selected by *cross_deviation*.
            cross_deviation: index of the held-out fold (see ``_fold_bounds``).
            cross_multiple: accepted for compatibility; fold sizes are fixed.
            one_third_batch_size: tweets sampled per class in training mode.

        Returns:
            ``(inputs, labels)``: a zero-padded id array and a list of one-hot
            labels.
        """
        inputs, _, labels = self._collect_batch(
            verify, cross_deviation, one_third_batch_size, with_emoji=False)
        return self.verPro(inputs), labels

    def getTrainBatch_eab(self, verify, cross_deviation, cross_multiple=5, one_third_batch_size=50):
        """Like ``getTrainBatch`` but also return zero-padded emoji-id rows.

        Requires ``eab == 2`` so that the ``train_*_emoji`` lists exist.

        Returns:
            ``(inputs, emojis, labels)``.
        """
        inputs, emojis, labels = self._collect_batch(
            verify, cross_deviation, one_third_batch_size, with_emoji=True)
        return self.verPro(inputs), self.verPro(emojis), labels

# Driver: runs at import time and builds the emoji-aware (eab=2) data model
# from the project's local copies of the embedding and training files.
wordModel = DataProcessing(
    modelFile='D:/SemEval2019-Task6/workspace/wordModel/glove.twitter.27B.200d_1.txt',
    dataFile='D:/SemEval2019-Task6/workspace/data/train/offenseval-training-v1.tsv',
    emojiFile='D:/SemEval2019-Task6/workspace/wordModel/emoji2vec200.txt',
    swearFile='D:/SemEval2019-Task6/workspace/wordModel/SwearWord2vec.txt',
    numDimensions=200,
    eab=2,
)
# One line each: number of IND, GRP and OTH tweets loaded.
print(len(wordModel.train_ind), len(wordModel.train_grp), len(wordModel.train_oth), sep='\n')
