import csv
import numpy as np
import re
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from sklearn.utils import shuffle
import pandas as pd
import pickle
import json


def init_embeddings(embeddings):
    """Fill ``embeddings`` in place with samples from U(-limit, limit).

    ``limit = sqrt(3 / embeddings.size(1))``, so every entry has variance
    ``1 / embeddings.size(1)``.

    Returns:
        The tensor returned by ``torch.nn.init.uniform_`` (the same tensor,
        modified in place).
    """
    width = embeddings.size(1)
    limit = np.sqrt(3.0 / width)
    return torch.nn.init.uniform_(embeddings, a=-limit, b=limit)


def load_embeddings(emb_file, word_map):
    """Build an embedding matrix for ``word_map`` from a GloVe-style text file.

    Each line of ``emb_file`` is expected to look like ``<word> <v1> ... <vD>``
    with single-space separators. ``D`` (``emb_dim``) is taken from the first
    line. The matrix starts from the random initialisation of
    ``init_embeddings``. For every word ``w`` of ``word_map`` found in the
    file, row ``word_map[w]`` is overwritten with the file's vector (the first
    occurrence wins).

    Args:
        emb_file: path of the embedding text file, read as UTF-8.
        word_map: dict mapping word -> row index. The matrix has
            ``len(word_map) + 1`` rows, so indices are assumed to lie in
            ``[0, len(word_map)]``. The read_* helpers of this module use
            ``1..len(word_map)``; confirm this for other callers.

    Returns:
        ``(embeddings, emb_dim)``: a FloatTensor of shape
        ``(len(word_map) + 1, emb_dim)`` and the vector width.

    Raises:
        ValueError: if no vector width can be read from the first line.
    """
    vocab = set(word_map.keys())

    print("Loading embedding...")

    with open(emb_file, 'r', encoding='utf-8') as f:
        emb_dim = len(f.readline().rstrip().split(' ')) - 1
    if emb_dim <= 0:
        raise ValueError(
            "cannot determine embedding dimension from {}".format(emb_file))

    embeddings = init_embeddings(torch.FloatTensor(len(vocab) + 1, emb_dim))

    found = set()  # words already copied; later duplicates in the file are ignored
    malformed = 0
    with open(emb_file, 'r', encoding='utf-8') as f:
        for line in f:
            tokens = line.rstrip().split(' ')
            if len(tokens) <= emb_dim:
                malformed += 1
                continue
            # Some GloVe releases (reportedly glove.840B) contain words with
            # embedded spaces. Treat the LAST emb_dim fields as the vector and
            # everything before them as the word.
            word = ' '.join(tokens[:-emb_dim])
            # Check membership first: most lines are not in the vocabulary,
            # so their floats never need to be parsed.
            if word not in vocab or word in found:
                continue
            try:
                vector = [float(t) for t in tokens[-emb_dim:]]
            except ValueError:
                malformed += 1
                continue
            embeddings[word_map[word]] = torch.FloatTensor(vector)
            found.add(word)

    cnt = len(found)
    print("Number of words read: ", cnt)
    # Only real vocabulary words count as OOV; the extra row is not a word.
    print("Number of OOV: ", len(vocab) - cnt)
    if malformed:
        print("Number of malformed lines skipped: ", malformed)

    return embeddings, emb_dim
    

def pretrain_embed(vocab, database_path, emb_file='/mnt/glove.840B.300d.txt'):
    """Load pretrained vectors for ``vocab`` and save them under ``database_path``.

    Writes ``{database_path}/glove300d_pretrain_embed.pth``, a dict with keys
    ``'pretrain'`` (the matrix returned by ``load_embeddings``) and ``'dim'``
    (its vector width). The file name is fixed even when ``emb_file`` points
    at a different embedding set.

    Args:
        vocab: word -> row-index mapping (as built by the read_* functions).
        database_path: output directory, expected to exist already.
        emb_file: GloVe-style text file. It defaults to the path that was
            previously hard-coded here.
    """
    pretrain, embed_dim = load_embeddings(emb_file, vocab)
    embed = {'pretrain': pretrain, 'dim': embed_dim}
    torch.save(embed, '{}/glove300d_pretrain_embed.pth'.format(database_path))



def clean_str(string, TREC=False):
    """Tokenise and normalise a sentence for every dataset except SST.

    The substitutions are applied in this order:

    - Characters outside ``[A-Za-z0-9(),!?'`]`` become spaces.
    - The clitics ``'s 've n't 're 'd 'll`` are split off.
    - ``, ! ( ) ?`` are surrounded by spaces.
    - Runs of whitespace are collapsed to a single space.

    Args:
        string: raw sentence.
        TREC: when True the original case is kept; otherwise the result is
            lower-cased.

    Returns:
        The cleaned, space-separated, stripped token string.
    """
    substitutions = (
        (r"[^A-Za-z0-9(),!?\'\`]", " "),
        (r"\'s", " \'s"),
        (r"\'ve", " \'ve"),
        (r"n\'t", " n\'t"),
        (r"\'re", " \'re"),
        (r"\'d", " \'d"),
        (r"\'ll", " \'ll"),
        (r",", " , "),
        (r"!", " ! "),
        # Plain replacements are deliberate. A template like " \( " makes
        # re.sub keep the backslash of the unknown escape, which yields tokens
        # "\(", "\)", "\?" that are unlikely to exist in pretrained
        # vocabularies. As a non-raw literal it also triggers invalid-escape
        # warnings.
        (r"\(", " ( "),
        (r"\)", " ) "),
        (r"\?", " ? "),
        (r"\s{2,}", " "),
    )
    # Order matters: whitespace is collapsed only after all padding is inserted.
    for pattern, replacement in substitutions:
        string = re.sub(pattern, replacement, string)
    return string.strip() if TREC else string.strip().lower()

def clean_str_sst(string):
    """Normalise a sentence from the SST dataset.

    Every character outside ``[A-Za-z0-9(),!?'`]`` becomes a space, runs of
    whitespace are collapsed to one space, and the result is stripped and
    lower-cased. Unlike ``clean_str``, no clitic or punctuation splitting is
    done.
    """
    rules = (
        (r"[^A-Za-z0-9(),!?\'\`]", " "),
        (r"\s{2,}", " "),
    )
    for pattern, replacement in rules:
        string = re.sub(pattern, replacement, string)
    return string.strip().lower()

def _read_mr_file(path, label, vocab, data):
    """Append ``[label, cleaned_line]`` rows from one MR file; return the longest token count.

    New tokens are added to ``vocab`` with consecutive ids starting at 1.
    Lines that are empty after ``clean_str`` are skipped.
    """
    longest = 0
    # errors="replace" turns undecodable bytes into U+FFFD. clean_str maps that
    # (like any non-ASCII char) to a space, so valid UTF-8 gives exactly the
    # same output as before and mis-encoded bytes no longer abort the run.
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for raw in f:
            line = clean_str(raw.rstrip("\n"))
            tokens = line.split()
            if not tokens:
                continue
            for token in tokens:
                # len(vocab) + 1 is the next free id (ids are 1..len(vocab)).
                vocab.setdefault(token, len(vocab) + 1)
            longest = max(longest, len(tokens))
            data.append([label, line])
    return longest


def read_MR(random_state=None):
    """Build the MR train/test split, vocabulary and word map.

    Positive sentences get label 2 and negative ones label 1. The rows are
    shuffled, and the first ``len(data) // 10 * 9`` go to
    ``data/MR/train.csv``; the rest go to ``data/MR/test.csv``. The
    vocabulary is dumped to ``data/MR/wordmap.json``.

    Args:
        random_state: forwarded to ``sklearn.utils.shuffle``. The default of
            None keeps the previous non-reproducible split.

    Returns:
        ``(len(vocab) + 1, max sentence length in tokens, train size,
        test size, vocab)``.
    """
    data = []
    vocab = {}
    # Positive file first so ids are assigned in the same order as before.
    pos_max = _read_mr_file("data_source/MR/rt-polarity.pos", 2, vocab, data)
    neg_max = _read_mr_file("data_source/MR/rt-polarity.neg", 1, vocab, data)
    sequence_max_length = max(pos_max, neg_max)

    data = shuffle(data, random_state=random_state)
    test_idx = len(data) // 10 * 9

    data_train = pd.DataFrame(data=data[:test_idx])
    data_test = pd.DataFrame(data=data[test_idx:])
    data_train.to_csv("data/MR/train.csv", index=None, header=None)
    data_test.to_csv("data/MR/test.csv", index=None, header=None)
    train_len = test_idx
    test_len = len(data) - train_len
    with open('data/MR/wordmap.json', 'w') as j:
        json.dump(vocab, j)
    return len(vocab) + 1, sequence_max_length, train_len, test_len, vocab


def read_SUBJ(random_state=None):
    """Build the SUBJ train/test split, vocabulary and word map.

    Each line of ``data_source/SUBJ/all.txt`` is read as ISO-8859-1 and split
    on whitespace. The first field is an integer label, stored as ``label + 1``.
    The remaining fields are the sentence, cleaned with ``clean_str``. Blank
    lines and sentences that are empty after cleaning are skipped. The rows are
    shuffled, and the first ``len(data) // 10 * 9`` go to
    ``data/SUBJ/train.csv``; the rest go to ``data/SUBJ/test.csv``. The
    vocabulary is dumped to ``data/SUBJ/wordmap.json``.

    Args:
        random_state: forwarded to ``sklearn.utils.shuffle``. The default of
            None keeps the previous non-reproducible split.

    Returns:
        ``(len(vocab) + 1, max sentence length in tokens, train size,
        test size, vocab)``.
    """
    data = []
    vocab = {}
    sequence_max_length = 0
    with open("data_source/SUBJ/all.txt", "r", encoding="ISO8859-1") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank line: no label to read
            label = fields[0]
            sequence = clean_str(" ".join(fields[1:])).split()
            if not sequence:
                continue
            for token in sequence:
                # len(vocab) + 1 is the next free id (ids are 1..len(vocab)).
                vocab.setdefault(token, len(vocab) + 1)
            sequence_max_length = max(sequence_max_length, len(sequence))
            data.append([int(label) + 1, " ".join(sequence)])

    data = shuffle(data, random_state=random_state)
    test_idx = len(data) // 10 * 9

    data_train = pd.DataFrame(data=data[:test_idx])
    data_test = pd.DataFrame(data=data[test_idx:])
    data_train.to_csv("data/SUBJ/train.csv", index=None, header=None)
    data_test.to_csv("data/SUBJ/test.csv", index=None, header=None)
    train_len = test_idx
    test_len = len(data) - train_len
    with open('data/SUBJ/wordmap.json', 'w') as j:
        json.dump(vocab, j)
    return len(vocab) + 1, sequence_max_length, train_len, test_len, vocab

def read_CR(source_path="data_source/SUBJ/all.txt", out_dir="data/SUBJ"):
    """Build a single unsplit CSV, vocabulary, word map and pretrained embeddings.

    Each line of ``source_path`` is read as ISO-8859-1 and split on
    whitespace. The first field is an integer label, stored as ``label + 1``.
    The remaining fields are the sentence, cleaned with ``clean_str``. Blank
    lines and sentences that are empty after cleaning are skipped. The
    function writes ``{out_dir}/main.csv`` (no train/test split), calls
    ``pretrain_embed(vocab, out_dir)``, and dumps the vocabulary to
    ``{out_dir}/wordmap.json``.

    NOTE(review): despite the name, the defaults read and write the SUBJ paths,
    which looks copied from read_SUBJ. The defaults keep that behaviour; pass
    the CR source file and output directory explicitly once the CR input
    format is confirmed.

    Returns:
        ``(len(vocab) + 1, max sentence length in tokens, vocab)``.
    """
    data = []
    vocab = {}
    sequence_max_length = 0
    with open(source_path, "r", encoding="ISO8859-1") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank line: no label to read
            label = fields[0]
            sequence = clean_str(" ".join(fields[1:])).split()
            if not sequence:
                continue
            for token in sequence:
                # len(vocab) + 1 is the next free id (ids are 1..len(vocab)).
                vocab.setdefault(token, len(vocab) + 1)
            sequence_max_length = max(sequence_max_length, len(sequence))
            data.append([int(label) + 1, " ".join(sequence)])

    pd.DataFrame(data=data).to_csv(
        "{}/main.csv".format(out_dir), index=None, header=None)

    pretrain_embed(vocab, out_dir)

    with open("{}/wordmap.json".format(out_dir), 'w') as j:
        json.dump(vocab, j)

    return len(vocab) + 1, sequence_max_length, vocab


def _read_trec_file(path, classes, vocab, rows):
    """Append ``[class_id, cleaned_question]`` rows from one TREC file; return the longest token count.

    The first whitespace field of each line carries the label. Only the part
    before ':' is used and looked up in ``classes``, so an unknown label
    raises KeyError. The rest of the line is cleaned with
    ``clean_str(..., TREC=True)``, which keeps the case. New tokens are added
    to ``vocab`` with consecutive ids starting at 1. Blank lines and questions
    that are empty after cleaning are skipped.
    """
    longest = 0
    # errors="replace": undecodable bytes become U+FFFD, which clean_str
    # removes, so valid UTF-8 input is processed exactly as before.
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank line: no label to read
            label = fields[0].split(":")[0]
            sequence = clean_str(" ".join(fields[1:]), TREC=True).split()
            if not sequence:
                continue
            for token in sequence:
                # len(vocab) + 1 is the next free id (ids are 1..len(vocab)).
                vocab.setdefault(token, len(vocab) + 1)
            longest = max(longest, len(sequence))
            rows.append([classes[label], " ".join(sequence)])
    return longest


def read_TREC():
    """Build the TREC train/test CSVs, vocabulary and word map.

    The coarse labels DESC, ENTY, ABBR, HUM, LOC and NUM map to 1..6. Rows are
    written in file order, without shuffling, to ``data/TREC/train.csv`` and
    ``data/TREC/test.csv``. The shared vocabulary is dumped to
    ``data/TREC/wordmap.json``.

    Returns:
        ``(len(vocab) + 1, max question length in tokens, train size,
        test size, vocab)``.
    """
    classes = {'DESC': 1, 'ENTY': 2, 'ABBR': 3, 'HUM': 4, 'LOC': 5, 'NUM': 6}
    vocab = {}
    train_rows, test_rows = [], []
    # The train file is read first, so its tokens receive the lowest ids.
    train_max = _read_trec_file("data_source/TREC/train.txt", classes, vocab, train_rows)
    test_max = _read_trec_file("data_source/TREC/test.txt", classes, vocab, test_rows)
    sequence_max_length = max(train_max, test_max)

    data_train = pd.DataFrame(data=train_rows)
    data_test = pd.DataFrame(data=test_rows)
    data_train.to_csv("data/TREC/train.csv", index=None, header=None)
    data_test.to_csv("data/TREC/test.csv", index=None, header=None)

    with open('data/TREC/wordmap.json', 'w') as j:
        json.dump(vocab, j)
    return len(vocab) + 1, sequence_max_length, len(data_train), len(data_test), vocab





if __name__ == '__main__':
    vocab_len, sequence_max_length, vocab = read_CR()
    # read_CR already returns len(vocab) + 1 (ids start at 1, leaving row 0
    # for what is presumably padding). Print it as-is; adding another 1
    # overstated the vocabulary size.
    print(vocab_len, sequence_max_length)

