import pandas as pd
import numpy as np
from ekphrasis.classes.spellcorrect import SpellCorrector
from ekphrasis.classes.segmenter import Segmenter

class DataProcessing:
    """Embedding tables plus preprocessed training tweets for OffensEval-style data.

    Construction loads:
      * a word vocabulary / embedding matrix built from ``modelFile`` followed by
        the tokens of ``emojiFile`` (``self.vocab`` / ``self.embedding``);
      * an emoji vocabulary from ``emojiFile`` (always, because preprocessing
        uses it) and, when ``eab != 0``, its embedding matrix;
      * a swear-word vocabulary / embedding from ``swearFile`` when ``sab`` is
        truthy;
      * the training tweets of ``dataFile``, split into ``train_not`` (label
        'NOT') and ``train_off`` (every other label), plus per-tweet emoji and/or
        swear-word id lists when ``eab == 2`` / ``sab`` ask for them.

    The ``getTrainBatch*`` methods then serve zero-padded id batches for k-fold
    cross validation (5 folds by default).
    """

    # Rows pre-allocated for the word embedding matrix (special tokens, every
    # modelFile line, then new emojiFile tokens); unused trailing rows stay 0.
    # NOTE(review): presumably sized for the GloVe Twitter word file this was
    # built with -- a longer modelFile raises IndexError while loading.
    _WORD_EMBED_ROWS = 1195098
    # Maximum number of entries read from the emoji / swear-word vector files
    # (row 0 of each matrix is the all-zero '__nok__' padding row).
    _EMOJI_VOCAB_LIMIT = 1661
    _SWEAR_VOCAB_LIMIT = 673
    # Purely numeric OOV tokens inside this inclusive range are treated as years.
    _YEAR_MIN = 1900
    _YEAR_MAX = 2020
    # Apostrophes removed from an OOV token before word segmentation.
    _APOSTROPHES = str.maketrans('', '', "'‘’")
    # prePro_1 step 1: ordered (old, new) substitutions applied after
    # lower-casing. Order matters (e.g. "'s" is split off before "s'").
    _TOKEN_RULES = (
        ("@user", " <user> "),
        ("'m", " 'm"), ("’m", " ’m"),
        ("'re", " 're"), ("’re", " ’re"),
        ("'ll", " 'll"), ("’ll", " ’ll"),
        ("'ve", " 've"), ("’ve", " ’ve"),
        ("'d", " 'd"), ("’d", " ’d"),
        ("n't", " n't"), ("n’t", " n’t"),
        ("'s", " 's"), ("’s", " ’s"),
        ("s'", " s'"), ("s’", " s’"),
    )
    # prePro_1 step 2: characters padded with one space on each side, in order.
    _SPACED_PUNCT = ".,!！?？#$\"”“;:()[]{}+-*/=—…•"
    # prePro_1 step 3: characters replaced by a single space -- no-break space,
    # the five Fitzpatrick skin-tone modifiers, and the male / female signs.
    _BLANKED_CHARS = "\u00a0\U0001F3FB\U0001F3FC\U0001F3FD\U0001F3FE\U0001F3FF\u2642\u2640"

    def __init__(self, modelFile, dataFile, emojiFile, swearFile, numDimensions, eab=0, sab=False):
        """Load all embeddings and the training data.

        Args:
            modelFile: UTF-8 word-vector file, one "token v1 ... vN" line per word.
            dataFile: tab-separated training file with a header row and at least
                the columns 'tweet' and 'subtask_a'.
            emojiFile: UTF-8 emoji-vector file in the same line format.
            swearFile: swear-word vector file (200-dimensional); read only when
                ``sab`` is truthy.
            numDimensions: dimensionality of the word and emoji vectors.
            eab: emoji mode -- 1 deletes emojis from the text, 2 keeps them as
                tokens and also records per-tweet emoji ids, any other value keeps
                them as space-separated tokens without ids. Any non-zero value
                also stores the emoji embedding matrix.
            sab: when truthy, load swear-word vectors and record per-tweet
                swear-word ids.
        """
        print('Loading Embeddings...')
        self.numDimensions = numDimensions
        self.emoji_numDimensions = numDimensions
        self.swear_numDimensions = 200
        self.vocab, embd = self.getVecFromLocal(modelFile, emojiFile)
        self.vocab_size = len(self.vocab)
        self.embedding = np.asarray(embd)
        self.eab = eab
        # Always loaded: the emoji vocabulary drives emoji spacing / removal in
        # preprocessing even when no emoji embedding is used.
        self.emoji_vocab, emoji_embd = self.getVecFromEmoji(emojiFile)
        if eab != 0:
            self.emoji_vocab_size = len(self.emoji_vocab)
            self.emoji_embedding = np.asarray(emoji_embd)
        if sab:
            self.swear_vocab, swear_embd = self.getVecFromSwear(swearFile)
            self.swear_vocab_size = len(self.swear_vocab)
            self.swear_embedding = np.asarray(swear_embd)
        print('Embeddings loaded.')

        print('Loading Train Data...')
        self.seg = Segmenter(corpus="twitter")
        self.spc = SpellCorrector(corpus="twitter")
        if eab == 2:
            if sab:
                (self.train_not, self.train_off,
                 self.train_not_emoji, self.train_off_emoji,
                 self.train_not_swear, self.train_off_swear) = self.getSentences_esab(dataFile)
            else:
                (self.train_not, self.train_off,
                 self.train_not_emoji, self.train_off_emoji) = self.getSentences_eab(dataFile)
        elif sab:
            (self.train_not, self.train_off,
             self.train_not_swear, self.train_off_swear) = self.getSentences_sab(dataFile)
        else:
            self.train_not, self.train_off = self.getSentences(dataFile)
        self.train_not_leng = len(self.train_not)
        self.train_off_leng = len(self.train_off)
        print('Train Data loaded.')

    # ------------------------------------------------------------------ vectors

    def getVecFromLocal(self, modelFile, emojiFile):
        """Build the word vocabulary and its embedding matrix.

        Row layout: 0 '__nok__' (padding, all zeros); 1 '__unk__', 2 '__num__',
        3 '__date__' (random, uniform in (-1, 1]); then one row per line of
        ``modelFile`` in file order; then one row per ``emojiFile`` token that is
        not already in the vocabulary.

        Returns:
            (vocab, embd): a token -> row-index dict and a
            (_WORD_EMBED_ROWS, numDimensions) float array.
        """
        vocab = {'__nok__': 0, '__unk__': 1, '__num__': 2, '__date__': 3}
        embd = np.zeros((self._WORD_EMBED_ROWS, self.numDimensions))
        for special in ('__unk__', '__num__', '__date__'):
            embd[vocab[special]] = 1 - 2 * np.random.random(self.numDimensions)
        row = len(vocab)  # first free row after the special tokens
        with open(modelFile, "rb") as f:
            for raw in f:
                parts = raw.decode(encoding="utf-8").split(' ')
                vocab[parts[0]] = row
                embd[row] = parts[1:]
                row += 1
        with open(emojiFile, "rb") as f:
            for raw in f:
                parts = raw.decode(encoding="utf-8").replace(" \r\n", "").split(' ')
                # Tokens already covered by the word vectors keep those vectors.
                if parts[0] not in vocab:
                    vocab[parts[0]] = row
                    embd[row] = parts[1:]
                    row += 1
        return vocab, embd

    def _load_limited_vectors(self, modelFile, limit, dims):
        """Read at most ``limit`` "token v1 ... vN" lines from ``modelFile``.

        Returns a vocab whose index 0 is '__nok__' and a (limit + 1, dims) matrix
        whose row 0 stays zero; line k (1-based) fills row k.
        """
        vocab = {'__nok__': 0}
        embd = np.zeros((limit + 1, dims))
        with open(modelFile, "rb") as f:
            for row, raw in enumerate(f, start=1):
                if row > limit:
                    break
                parts = raw.decode(encoding="utf-8").replace(" \r\n", "").split(' ')
                vocab[parts[0]] = row
                embd[row] = parts[1:]
        return vocab, embd

    def getVecFromEmoji(self, modelFile):
        """Load the emoji vocabulary / embedding (first _EMOJI_VOCAB_LIMIT lines)."""
        return self._load_limited_vectors(modelFile, self._EMOJI_VOCAB_LIMIT, self.emoji_numDimensions)

    def getVecFromSwear(self, modelFile):
        """Load the swear-word vocabulary / embedding (first _SWEAR_VOCAB_LIMIT lines)."""
        return self._load_limited_vectors(modelFile, self._SWEAR_VOCAB_LIMIT, self.swear_numDimensions)

    # ------------------------------------------------------------- sentence ids

    def getVecFromSent(self, sentence):
        """Map a preprocessed, space-separated sentence to a list of word ids.

        Empty tokens (from repeated spaces) are skipped; out-of-vocabulary tokens
        are resolved by getVecFromSent_unkWordPro1.
        """
        ids = []
        for token in sentence.split(' '):
            if not token:
                continue
            if token in self.vocab:
                ids.append(self.vocab[token])
            else:
                ids = self.getVecFromSent_unkWordPro1(token, ids)
        return ids

    def getVecFromSent_unkWordPro1(self, word, ids):
        """OOV fallback 1: strip apostrophes, word-segment, look up each piece.

        Pieces still unknown go to getVecFromSent_unkWordPro2. Appends to and
        returns ``ids``.
        """
        word = word.translate(self._APOSTROPHES)
        for piece in self.seg.segment(word).split(' '):
            if not piece:
                continue
            if piece in self.vocab:
                ids.append(self.vocab[piece])
            else:
                ids = self.getVecFromSent_unkWordPro2(piece, ids)
        return ids

    def getVecFromSent_unkWordPro2(self, word, ids):
        """OOV fallback 2: spell-correct, then map to a special token if needed.

        A corrected word found in the vocabulary gets its own id; a decimal
        number in [_YEAR_MIN, _YEAR_MAX] gets '__date__', any other decimal
        number '__num__', and everything else '__unk__'. Appends to and returns
        ``ids``.
        """
        word = self.spc.correct(word)
        if word in self.vocab:
            ids.append(self.vocab[word])
        elif word.isdecimal():
            # isdecimal (not isdigit): characters such as '²' pass isdigit but
            # make int() raise ValueError.
            is_year = self._YEAR_MIN <= int(word) <= self._YEAR_MAX
            ids.append(self.vocab['__date__'] if is_year else self.vocab['__num__'])
        else:
            ids.append(self.vocab['__unk__'])
        return ids

    # ------------------------------------------------------------ preprocessing

    def prePro_1(self, sentence):
        """Lower-case a tweet and split contractions / punctuation into tokens.

        Applies, in order: _TOKEN_RULES ('@user' -> ' <user> ' and contraction
        splitting), space-padding of every character in _SPACED_PUNCT, and
        replacement of every character in _BLANKED_CHARS by a space.
        """
        sentence = sentence.lower()
        for old, new in self._TOKEN_RULES:
            sentence = sentence.replace(old, new)
        for ch in self._SPACED_PUNCT:
            sentence = sentence.replace(ch, " " + ch + " ")
        for ch in self._BLANKED_CHARS:
            sentence = sentence.replace(ch, " ")
        return sentence

    def prePro_2_delEmoji(self, sentence):
        """Replace every emoji-vocabulary key occurring in ``sentence`` with a space."""
        for key in self.emoji_vocab:
            sentence = sentence.replace(key, " ")
        return sentence

    def prePro_2_attEmoji(self, sentence, getEmoji=True):
        """Surround every emoji-vocabulary key found in ``sentence`` with spaces.

        Returns ``(sentence, ids)``, or just the sentence when ``getEmoji`` is
        False. ``ids`` holds one emoji id per distinct key present (not per
        occurrence), in emoji-vocabulary iteration order.
        """
        emoji_encode_list = []
        for key, emoji_id in self.emoji_vocab.items():
            if key in sentence:
                emoji_encode_list.append(emoji_id)
                sentence = sentence.replace(key, " " + key + " ")
        if not getEmoji:
            return sentence
        return sentence, emoji_encode_list

    def prePro_3_attSwear(self, sentence):
        """Return the swear-word id of every space-separated token that is in swear_vocab."""
        return [self.swear_vocab[word] for word in sentence.split(' ')
                if word and word in self.swear_vocab]

    # ------------------------------------------------------------- data loading

    def _load_split(self, file, with_emoji=False, with_swear=False):
        """Read and preprocess the labelled tweets in ``file``.

        Each tweet goes through prePro_1, then the emoji step (prePro_2_attEmoji
        with id collection when ``with_emoji``; otherwise prePro_2_delEmoji if
        ``self.eab == 1`` else prePro_2_attEmoji without ids), then
        prePro_3_attSwear when ``with_swear``.

        Returns the 6-tuple ``(not_sents, off_sents, not_emoji, off_emoji,
        not_swear, off_swear)``; the emoji / swear lists stay empty unless
        requested. Rows whose 'subtask_a' is 'NOT' go to the ``not_*`` lists,
        every other label to ``off_*``.
        """
        train = pd.read_csv(file, sep='\t', header=0)
        # Index 0 collects 'NOT' tweets, index 1 everything else.
        sents, emojis, swears = ([], []), ([], []), ([], [])
        for tweet, label in zip(train['tweet'], train['subtask_a']):
            sentence = self.prePro_1(str(tweet))
            emoji_ids = None
            if with_emoji:
                sentence, emoji_ids = self.prePro_2_attEmoji(sentence)
            elif self.eab == 1:
                sentence = self.prePro_2_delEmoji(sentence)
            else:
                sentence = self.prePro_2_attEmoji(sentence, False)
            bucket = 0 if str(label) == 'NOT' else 1
            sents[bucket].append(sentence)
            if with_emoji:
                emojis[bucket].append(emoji_ids)
            if with_swear:
                swears[bucket].append(self.prePro_3_attSwear(sentence))
        return sents[0], sents[1], emojis[0], emojis[1], swears[0], swears[1]

    def getSentences(self, file):
        """Return ``(train_not, train_off)`` preprocessed sentences."""
        split = self._load_split(file)
        return split[0], split[1]

    def getSentences_eab(self, file):
        """Return ``(train_not, train_off, train_not_emoji, train_off_emoji)``."""
        return self._load_split(file, with_emoji=True)[:4]

    def getSentences_sab(self, file):
        """Return ``(train_not, train_off, train_not_swear, train_off_swear)``."""
        not_s, off_s, _, _, not_sw, off_sw = self._load_split(file, with_swear=True)
        return not_s, off_s, not_sw, off_sw

    def getSentences_esab(self, file):
        """Return sentences, emoji ids and swear ids for both classes (6-tuple)."""
        return self._load_split(file, with_emoji=True, with_swear=True)

    # ------------------------------------------------------------------ batches

    def verPro(self, x):
        """Right-pad each id list in ``x`` with 0 to the longest length; return an ndarray.

        Works on copies: neither ``x`` nor its inner lists are modified, so lists
        stored on the instance (e.g. emoji / swear ids) are never padded in place.
        """
        width = max((len(seq) for seq in x), default=0)
        return np.array([list(seq) + [0] * (width - len(seq)) for seq in x])

    @staticmethod
    def _batch_positions(total, fold_start, fold_end, verify, batch_index, batch_size):
        """Return indices (into a per-class list of length ``total``) for one batch.

        ``verify`` truthy: consecutive items of the held-out fold
        ``[fold_start, fold_end)``. Otherwise: consecutive items of everything
        outside that fold, in original order. Raises IndexError when the batch
        would run past the available items.
        """
        offset = batch_index * batch_size
        if verify:
            first = fold_start + offset
            # Never read past the held-out fold: that would validate on
            # training samples.
            if batch_size > 0 and (offset < 0 or first + batch_size > fold_end):
                raise IndexError(
                    'verification batch %d (size %d) runs outside the held-out fold [%d, %d)'
                    % (batch_index, batch_size, fold_start, fold_end))
            return list(range(first, first + batch_size))
        every = range(total)
        pool = list(every[0:fold_start]) + list(every[fold_end:total])
        return [pool[offset + i] for i in range(batch_size)]

    def _make_batch(self, verify, batch_index, cross_deviation, cross_multiple,
                    not_batch_size, off_batch_size, side_lists=()):
        """Assemble one class-balanced batch for fold ``cross_deviation`` of ``cross_multiple``.

        Returns ``(inputs, sides, labels)``: ``inputs`` is the padded word-id
        array (NOT rows first, then OFF rows); ``sides`` holds one padded array
        per ``(not_list, off_list)`` pair in ``side_lists`` (emoji or swear ids,
        same row order); ``labels`` is [0, 1] per NOT row and [1, 0] per OFF row.
        """
        single_not = int(self.train_not_leng / cross_multiple)
        single_off = int(self.train_off_leng / cross_multiple)
        not_pos = self._batch_positions(
            self.train_not_leng, single_not * cross_deviation, single_not * (cross_deviation + 1),
            verify, batch_index, not_batch_size)
        off_pos = self._batch_positions(
            self.train_off_leng, single_off * cross_deviation, single_off * (cross_deviation + 1),
            verify, batch_index, off_batch_size)
        inputs = ([self.getVecFromSent(self.train_not[p]) for p in not_pos]
                  + [self.getVecFromSent(self.train_off[p]) for p in off_pos])
        sides = [self.verPro([not_list[p] for p in not_pos] + [off_list[p] for p in off_pos])
                 for not_list, off_list in side_lists]
        labels = [[0, 1] for _ in not_pos] + [[1, 0] for _ in off_pos]
        return self.verPro(inputs), sides, labels

    def getTrainBatch(self, verify, batch_index, cross_deviation, cross_multiple=5, not_batch_size=221, off_batch_size=110):
        """Return ``(inputs, labels)`` for one k-fold (default 5-fold) batch.

        ``verify`` selects the held-out fold ``cross_deviation`` instead of the
        training folds; each batch has ``not_batch_size`` NOT rows followed by
        ``off_batch_size`` OFF rows.
        """
        inputs, _, labels = self._make_batch(verify, batch_index, cross_deviation, cross_multiple,
                                             not_batch_size, off_batch_size)
        return inputs, labels

    def getTrainBatch_eab(self, verify, batch_index, cross_deviation, cross_multiple=5, not_batch_size=221, off_batch_size=110):
        """Like getTrainBatch, plus the padded emoji-id array: ``(inputs, emojis, labels)``."""
        inputs, (emojis,), labels = self._make_batch(
            verify, batch_index, cross_deviation, cross_multiple, not_batch_size, off_batch_size,
            [(self.train_not_emoji, self.train_off_emoji)])
        return inputs, emojis, labels

    def getTrainBatch_sab(self, verify, batch_index, cross_deviation, cross_multiple=5, not_batch_size=221, off_batch_size=110):
        """Like getTrainBatch, plus the padded swear-id array: ``(inputs, swears, labels)``."""
        inputs, (swears,), labels = self._make_batch(
            verify, batch_index, cross_deviation, cross_multiple, not_batch_size, off_batch_size,
            [(self.train_not_swear, self.train_off_swear)])
        return inputs, swears, labels

    def getTrainBatch_esab(self, verify, batch_index, cross_deviation, cross_multiple=5, not_batch_size=221, off_batch_size=110):
        """Like getTrainBatch, plus emoji and swear id arrays: ``(inputs, emojis, swears, labels)``."""
        inputs, (emojis, swears), labels = self._make_batch(
            verify, batch_index, cross_deviation, cross_multiple, not_batch_size, off_batch_size,
            [(self.train_not_emoji, self.train_off_emoji),
             (self.train_not_swear, self.train_off_swear)])
        return inputs, emojis, swears, labels

# wordModel = DataProcessing('D:/SemEval2019-Task6/workspace/wordModel/glove.twitter.27B.200d_1.txt',
#                            'D:/SemEval2019-Task6/workspace/data/train/offenseval-training-v1.tsv',
#                            'D:/SemEval2019-Task6/workspace/wordModel/emoji2vec200.txt',
#                            'D:/SemEval2019-Task6/workspace/wordModel/SwearWord2vec.txt',
#                            200, eab=2, sab=False)
# flag = False
# for i in range(len(wordModel.train_off)):
#     flag = False
#     words = wordModel.train_off[i].split(' ')
#     for j in range(len(words)):
#         if words[j] == ' ' or words[j] == '':
#             continue
#         try:
#             x = wordModel.swear_vocab[words[j]]
#             flag = True
#             with open("xxx.txt", "a", encoding="utf-8") as f:
#                 f.write(words[j]+" | ")
#                 f.close()
#         except KeyError:
#             continue
#     if flag == False:
#         with open("xxx.txt", "a", encoding="utf-8") as f:
#             f.write(wordModel.train_off[i]+"\n")
#             f.close()

# for i in range(len(wordModel.train_not)):
#     wordModel.getVecFromSent(wordModel.train_not[i])
# for i in range(len(wordModel.train_off)):
#     wordModel.getVecFromSent(wordModel.train_off[i])
