import pandas as pd
import numpy as np
from ekphrasis.classes.spellcorrect import SpellCorrector
from ekphrasis.classes.segmenter import Segmenter

class DataProcessing:
    """Load word/emoji vectors and the OffensEval training TSV, and build
    id-encoded batches for cross-validated UNT-vs-TIN classification of the
    offensive ('OFF') tweets.

    Vocabulary ids 0-3 are reserved: 0 '__nok__' (all-zero embedding row,
    also the padding id used by verPro), 1 '__unk__', 2 '__num__',
    3 '__date__'. Labels are one-hot lists: UNT -> [0, 1], TIN -> [1, 0].
    """

    # Reserved vocabulary ids (assigned in getVecFromLocal).
    _NOK_ID = 0
    _UNK_ID = 1
    _NUM_ID = 2
    _DATE_ID = 3

    # Preallocated embedding sizes. These are presumably sized for the
    # specific word/emoji vector files used with this class -- TODO confirm
    # before using other files (too few rows raises IndexError).
    _WORD_EMBD_ROWS = 1195098
    _EMOJI_COUNT = 1661
    _EMBD_DIM = 200

    def __init__(self, modelFile, dataFile, emojiFile, swearFile, numDimensions, eab=0):
        """Load embeddings and the training tweets.

        Args:
            modelFile: text file, one token per line followed by its
                space-separated vector components.
            dataFile: tab-separated file with 'tweet', 'subtask_a' and
                'subtask_b' columns.
            emojiFile: text file of emoji vectors in the same line format.
            swearFile: not used by this class.
            numDimensions: stored as ``self.numDimensions``.
            eab: emoji handling mode. 0: emojis kept and space-padded;
                1: emojis deleted; 2: emojis kept and their emoji-vocabulary
                ids collected per tweet (required by getTrainBatch_eab*).
        """
        print('Loading Embeddings...')
        self.vocab, embd = self.getVecFromLocal(modelFile, emojiFile)
        self.vocab_size = len(self.vocab)
        self.embedding = np.asarray(embd)
        self.eab = eab
        self.emoji_vocab, emoji_embd = self.getVecFromEmoji(emojiFile)
        if eab != 0:
            self.emoji_vocab_size = len(self.emoji_vocab)
            self.emoji_embedding = np.asarray(emoji_embd)
        print('Embeddings loaded.')

        print('Loading Train Data...')
        self.seg = Segmenter(corpus="twitter")
        self.spc = SpellCorrector(corpus="twitter")
        if eab == 2:
            (self.train_off_unt, self.train_off_tin,
             self.train_off_unt_emoji, self.train_off_tin_emoji) = self.getSentences_eab(dataFile)
        else:
            self.train_off_unt, self.train_off_tin = self.getSentences(dataFile)
        self.train_off_unt_leng = len(self.train_off_unt)
        self.train_off_tin_leng = len(self.train_off_tin)
        self.numDimensions = numDimensions
        self.emoji_numDimensions = 200
        print('Train Data loaded.')

    def getVecFromLocal(self, modelFile, emojiFile):
        """Build the word vocabulary and embedding matrix.

        Ids 0-3 are the reserved tokens. Row 0 stays all-zero; rows 1-3 get
        random vectors with components in (-1, 1]. Tokens from ``modelFile``
        follow from id 4. Emojis from ``emojiFile`` that are not already in
        the vocabulary are appended after them.

        Returns:
            (vocab, embd): dict token -> id, and a
            (_WORD_EMBD_ROWS, 200) float array (unused rows stay zero).
        """
        vocab = {
            '__nok__': self._NOK_ID,
            '__unk__': self._UNK_ID,
            '__num__': self._NUM_ID,
            '__date__': self._DATE_ID,
        }
        embd = np.zeros((self._WORD_EMBD_ROWS, self._EMBD_DIM))
        embd[1] = 1 - 2 * np.random.random(200)
        embd[2] = 1 - 2 * np.random.random(200)
        embd[3] = 1 - 2 * np.random.random(200)
        i = 4
        with open(modelFile, "rb") as f:
            for raw in f:
                parts = raw.decode(encoding="utf-8").split(' ')
                vocab[parts[0]] = i
                embd[i] = parts[1:]
                i += 1
        with open(emojiFile, "rb") as f:
            for raw in f:
                parts = raw.decode(encoding="utf-8").replace(" \r\n", "").split(' ')
                # Keep the word-file vector when the emoji is already known.
                if parts[0] not in vocab:
                    vocab[parts[0]] = i
                    embd[i] = parts[1:]
                    i += 1
        return vocab, embd

    def getVecFromEmoji(self, modelFile):
        """Build the emoji vocabulary and embedding matrix.

        Id 0 is '__nok__' (all-zero row). At most the first _EMOJI_COUNT
        lines of ``modelFile`` are read, getting ids 1.._EMOJI_COUNT.

        Returns:
            (vocab, embd): dict emoji -> id, and a (_EMOJI_COUNT + 1, 200)
            float array.
        """
        vocab = {'__nok__': self._NOK_ID}
        embd = np.zeros((self._EMOJI_COUNT + 1, self._EMBD_DIM))
        with open(modelFile, "rb") as f:
            for i, raw in enumerate(f, start=1):
                parts = raw.decode(encoding="utf-8").replace(" \r\n", "").split(' ')
                vocab[parts[0]] = i
                embd[i] = parts[1:]
                if i == self._EMOJI_COUNT:
                    break
        return vocab, embd

    def getVecFromSent(self, sentence):
        """Encode a pre-processed sentence as a list of vocabulary ids.

        The sentence is split on single spaces and empty tokens are skipped.
        Out-of-vocabulary tokens are resolved by getVecFromSent_unkWordPro1.
        """
        ids = []
        for word in sentence.split(' '):
            if not word:
                continue
            try:
                ids.append(self.vocab[word])
            except KeyError:
                ids = self.getVecFromSent_unkWordPro1(word, ids)
        return ids

    def getVecFromSent_unkWordPro1(self, word, ids):
        """First fallback for an out-of-vocabulary token.

        Removes apostrophes (straight and curly) and splits the token with
        ``self.seg.segment``. Each known piece is appended to ``ids``; unknown
        pieces go to getVecFromSent_unkWordPro2. Returns ``ids``, which is
        extended in place.
        """
        for quote in ("'", "‘", "’"):
            word = word.replace(quote, "")
        for piece in self.seg.segment(word).split(' '):
            if not piece:
                continue
            try:
                ids.append(self.vocab[piece])
            except KeyError:
                ids = self.getVecFromSent_unkWordPro2(piece, ids)
        return ids

    def getVecFromSent_unkWordPro2(self, word, ids):
        """Last fallback for an out-of-vocabulary token.

        The token is spell-corrected and looked up again. If it is still
        unknown, the id appended is chosen as follows:
        - an all-decimal-digit token whose value is in 1900..2020 gets the
          '__date__' id (3);
        - any other all-decimal-digit token gets the '__num__' id (2);
        - anything else gets the '__unk__' id (1).

        Returns ``ids``, which is extended in place.
        """
        word = self.spc.correct(word)
        try:
            ids.append(self.vocab[word])
        except KeyError:
            # isdecimal (not isdigit): '²'.isdigit() is True but int('²')
            # raises ValueError, whereas every isdecimal string parses.
            if word.isdecimal():
                ids.append(self._DATE_ID if 1900 <= int(word) <= 2020 else self._NUM_ID)
            else:
                ids.append(self._UNK_ID)
        return ids

    # Ordered (old, new) substitutions applied by prePro_1 after lower-casing.
    # Order matters: e.g. "'s" is split off before "s'" is examined.
    _PRE_PRO_1_RULES = (
        ("@user", " <user> "),
        ("'m", " 'm"), ("’m", " ’m"),
        ("'re", " 're"), ("’re", " ’re"),
        ("'ll", " 'll"), ("’ll", " ’ll"),
        ("'ve", " 've"), ("’ve", " ’ve"),
        ("'d", " 'd"), ("’d", " ’d"),
        ("n't", " n't"), ("n’t", " n’t"),
        ("'s", " 's"), ("’s", " ’s"),
        ("s'", " s'"), ("s’", " s’"),
        (".", " . "), (",", " , "),
        ("!", " ! "), ("！", " ！ "),
        ("?", " ? "), ("？", " ？ "),
        ("#", " # "), ("$", " $ "),
        ("\"", " \" "), ("”", " ” "), ("“", " “ "),
        (";", " ; "), (":", " : "),
        ("(", " ( "), (")", " ) "),
        ("[", " [ "), ("]", " ] "),
        ("{", " { "), ("}", " } "),
        ("+", " + "), ("-", " - "), ("*", " * "), ("/", " / "), ("=", " = "),
        ("—", " — "), ("…", " … "), ("•", " • "),
        ("\u00a0", " "),  # non-breaking space -> ordinary space
        # Emoji skin-tone (Fitzpatrick) modifiers U+1F3FB..U+1F3FF ...
        ("\U0001F3FB", " "), ("\U0001F3FC", " "), ("\U0001F3FD", " "),
        ("\U0001F3FE", " "), ("\U0001F3FF", " "),
        # ... and male/female signs used in gendered emoji sequences.
        ("\u2642", " "), ("\u2640", " "),
    )

    def prePro_1(self, sentence):
        """Lower-case a tweet and tokenize it with space padding.

        Maps '@user' to '<user>', splits clitics and punctuation off with
        spaces, and replaces skin-tone/gender emoji modifiers with spaces.
        """
        sentence = sentence.lower()
        for old, new in self._PRE_PRO_1_RULES:
            sentence = sentence.replace(old, new)
        return sentence

    def prePro_2_delEmoji(self, sentence):
        """Replace every emoji-vocabulary key found in ``sentence`` with a space."""
        for key in self.emoji_vocab:
            sentence = sentence.replace(key, " ")
        return sentence

    def prePro_2_attEmoji(self, sentence, getEmoji=True):
        """Surround every emoji-vocabulary key found in ``sentence`` with spaces.

        Returns the padded sentence. If ``getEmoji`` is true, also returns the
        ids of the keys found, in emoji-vocabulary order.
        """
        emoji_encode_list = []
        for key, emoji_id in self.emoji_vocab.items():
            if key in sentence:
                emoji_encode_list.append(emoji_id)
                sentence = sentence.replace(key, " " + key + " ")
        if not getEmoji:
            return sentence
        return sentence, emoji_encode_list

    def getSentences(self, file):
        """Read the training TSV and return pre-processed OFF tweets.

        Returns:
            (UNT sentences, TIN sentences). Emojis are deleted when
            ``self.eab == 1`` and space-padded otherwise.
        """
        train = pd.read_csv(file, sep='\t', header=0)
        train_off_unt, train_off_tin = [], []
        for tweet, label_a, label_b in zip(train['tweet'], train['subtask_a'], train['subtask_b']):
            if str(label_a) != 'OFF':
                continue
            sentence = self.prePro_1(str(tweet))
            if self.eab == 1:
                sentence = self.prePro_2_delEmoji(sentence)
            else:
                sentence = self.prePro_2_attEmoji(sentence, False)
            label_b = str(label_b)
            if label_b == 'UNT':
                train_off_unt.append(sentence)
            elif label_b == 'TIN':
                train_off_tin.append(sentence)
        return train_off_unt, train_off_tin

    def getSentences_eab(self, file):
        """Read the training TSV and return OFF tweets with their emoji ids.

        Returns:
            (UNT sentences, TIN sentences, UNT emoji-id lists,
            TIN emoji-id lists). Emojis in the sentences are space-padded.
        """
        train = pd.read_csv(file, sep='\t', header=0)
        train_off_unt, train_off_tin = [], []
        train_off_unt_emoji, train_off_tin_emoji = [], []
        for tweet, label_a, label_b in zip(train['tweet'], train['subtask_a'], train['subtask_b']):
            if str(label_a) != 'OFF':
                continue
            label_b = str(label_b)
            if label_b == 'UNT':
                sentences, emoji_lists = train_off_unt, train_off_unt_emoji
            elif label_b == 'TIN':
                sentences, emoji_lists = train_off_tin, train_off_tin_emoji
            else:
                continue  # not part of the UNT/TIN task; skip pre-processing
            sentence, sent_emoji_encode_list = self.prePro_2_attEmoji(self.prePro_1(str(tweet)))
            sentences.append(sentence)
            emoji_lists.append(sent_emoji_encode_list)
        return train_off_unt, train_off_tin, train_off_unt_emoji, train_off_tin_emoji

    def verPro(self, x):
        """Right-pad every id sequence in ``x`` with 0 to the longest length.

        Returns a NumPy array of the padded rows. ``x`` and its inner lists
        are not modified; callers pass lists that are also stored in
        ``self.train_off_*_emoji``.
        """
        max_len = max((len(seq) for seq in x), default=0)
        return np.array([list(seq) + [0] * (max_len - len(seq)) for seq in x])

    def _foldBatch(self, verify, unt_bounds, tin_bounds, half_batch_size, with_emoji):
        """Build one batch for a cross-validation fold.

        Args:
            verify: if true, return every held-out sentence (UNT first, then
                TIN). Otherwise draw ``half_batch_size`` sentences uniformly,
                with replacement, from the non-held-out part of each class
                (UNT first, then TIN).
            unt_bounds, tin_bounds: (start, end) of the held-out slice of
                ``self.train_off_unt`` / ``self.train_off_tin``.
            half_batch_size: samples per class in training mode.
            with_emoji: also gather emoji-id lists (requires eab == 2 data).

        Returns:
            (inputs, labels), or (inputs, emojis, labels) if ``with_emoji``.
            ``inputs`` and ``emojis`` are zero-padded arrays from verPro.
        """
        inputs, emojis, labels = [], [], []
        classes = (
            (self.train_off_unt, self.train_off_unt_emoji if with_emoji else None, unt_bounds, [0, 1]),
            (self.train_off_tin, self.train_off_tin_emoji if with_emoji else None, tin_bounds, [1, 0]),
        )
        for sentences, sent_emojis, (start, end), label in classes:
            if verify:
                pool, pool_emojis = sentences, sent_emojis
                indices = range(start, end)
            else:
                pool = sentences[:start] + sentences[end:]
                pool_emojis = sent_emojis[:start] + sent_emojis[end:] if with_emoji else None
                # Lazy, so random draws interleave with encoding as before.
                indices = (np.random.randint(0, len(pool)) for _ in range(half_batch_size))
            for idx in indices:
                inputs.append(self.getVecFromSent(pool[idx]))
                if with_emoji:
                    emojis.append(pool_emojis[idx])
                labels.append(list(label))
        inputs = self.verPro(inputs)
        if not with_emoji:
            return inputs, labels
        return inputs, self.verPro(emojis), labels

    def getTrainBatch2(self, verify, cross_deviation, cross_multiple=5, half_batch_size=88):
        """Batch for fold ``cross_deviation`` (0-4) of 5-fold CV, without emojis.

        Held-out UNT: [105k, 105k + 105), with the last fold ending at 524.
        Held-out TIN: [0, 776) for k == 0, else [776 + 775(k-1), 776 + 775k).
        These bounds assume 524 UNT and 3876 TIN tweets; the TIN folds then
        tile [0, 3876) without overlap.
        ``cross_multiple`` is not used; the fold sizes are fixed.

        Returns:
            (inputs, labels); see _foldBatch.
        """
        unt_start = 105 * cross_deviation
        unt_end = 524 if cross_deviation == 4 else unt_start + 105
        tin_start = 0 if cross_deviation == 0 else 776 + 775 * (cross_deviation - 1)
        tin_end = 776 + 775 * cross_deviation
        return self._foldBatch(verify, (unt_start, unt_end), (tin_start, tin_end),
                               half_batch_size, with_emoji=False)

    def getTrainBatch_eab(self, verify, cross_deviation, cross_multiple=5, half_batch_size=88):
        """Batch for fold ``cross_deviation`` (0-4) of 5-fold CV, with emoji ids.

        Held-out UNT: [0, 104) for k == 0, else 105 per fold after that.
        Held-out TIN: [0, 776) for k == 0, else 775 per fold after that.
        ``cross_multiple`` is not used.

        Returns:
            (inputs, emojis, labels); see _foldBatch.
        """
        if cross_deviation == 0:
            unt_bounds, tin_bounds = (0, 104), (0, 776)
        else:
            unt_bounds = (104 + 105 * (cross_deviation - 1), 104 + 105 * cross_deviation)
            tin_bounds = (776 + 775 * (cross_deviation - 1), 776 + 775 * cross_deviation)
        return self._foldBatch(verify, unt_bounds, tin_bounds, half_batch_size, with_emoji=True)

    def getTrainBatch_eab_10(self, verify, cross_deviation, cross_multiple=10, half_batch_size=88):
        """Batch for fold ``cross_deviation`` (0-9) of 10-fold CV, with emoji ids.

        Folds 0-3 hold 53 UNT / 387 TIN tweets; folds 4-9 hold 52 UNT / 388 TIN.
        ``cross_multiple`` is not used.

        Returns:
            (inputs, emojis, labels); see _foldBatch.
        """
        if cross_deviation < 4:
            unt_bounds = (53 * cross_deviation, 53 * (cross_deviation + 1))
            tin_bounds = (387 * cross_deviation, 387 * (cross_deviation + 1))
        else:
            unt_bounds = (212 + 52 * (cross_deviation - 4), 212 + 52 * (cross_deviation - 3))
            tin_bounds = (1548 + 388 * (cross_deviation - 4), 1548 + 388 * (cross_deviation - 3))
        return self._foldBatch(verify, unt_bounds, tin_bounds, half_batch_size, with_emoji=True)

# wordModel = DataProcessing('D:/SemEval2019-Task6/workspace/wordModel/glove.twitter.27B.200d_1.txt',
#                            'D:/SemEval2019-Task6/workspace/data/train/offenseval-training-v1.tsv',
#                            'D:/SemEval2019-Task6/workspace/wordModel/emoji2vec200.txt',
#                            'D:/SemEval2019-Task6/workspace/wordModel/SwearWord2vec.txt',
#                            200, eab=2)
# print(wordModel.train_off_unt_leng)
# print(wordModel.train_off_tin_leng)
# print(wordModel.train_off_unt)
