#!/usr/bin/python3
#-*- coding:utf-8 -*-
############################
#File Name: wbuild_graph.py
#Created Time:
############################
import os
import math
import random
import numpy as np
import pickle as pkl
# import networkx as nx
import scipy.sparse as sp
from math import log
from sklearn import svm
from nltk.corpus import wordnet as wn
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.spatial.distance import cosine
from collections import defaultdict, Counter
#from utils.utils import loadWord2Vec, clean_str
import sys
sys.path.append('../')

# The script takes exactly one command-line argument: the dataset name.
if len(sys.argv) != 2:
    sys.exit("Use: python build_graph.py <dataset>")

# Dataset names this script accepts. The chosen name is embedded in the paths
# of the files read and written below.
# Earlier benchmark list, kept for reference:
#   ['20ng', 'R8', 'R52', 'ohsumed', 'mr', 'WebKB', 'ag_news', 'dbpedia', 'yahoo']
datasets = [
    'patent_1sub', 'patent_f', 'patent_1sub_abs', 'patent_2gro',
    'aapd_h', 'wos', 'eurlex',
    'patent_fm_10', 'patent_fm_20', 'patent_fm_50',
]

dataset = sys.argv[1]
if dataset not in datasets:
    sys.exit("wrong dataset name")

# Embedding dimensionality constant.
word_embeddings_dim = 300

#--------------step1--------------------------------------
# shuffling
# read training information set, split test and train docs
# Every stripped metadata line, in file order.
doc_name_list = []
# Subsets of the lines above whose second tab-separated field mentions
# 'train' / 'test'. Lines mentioning neither appear only in doc_name_list.
doc_train_list = []
doc_test_list = []

with open('../data/' + dataset + '.txt', 'r') as f:
    for raw_line in f:
        entry = raw_line.strip()
        doc_name_list.append(entry)
        # Field 1 carries the split tag; 'test' is checked first, so a tag
        # containing both substrings counts as test.
        split_tag = raw_line.split("\t")[1]
        if 'test' in split_tag:
            doc_test_list.append(entry)
        elif 'train' in split_tag:
            doc_train_list.append(entry)

# read all docs text
# Cleaned text of every document, one stripped line per document.
# NOTE(review): presumably line k here belongs to metadata line k of
# ../data/<dataset>.txt — the code below indexes both lists with the same ids.
with open('../data/corpus/' + dataset + '.clean.txt', 'r') as f:
    doc_content_list = [raw_line.strip() for raw_line in f]

# read train docs ids
# Position of the FIRST occurrence of each metadata line in doc_name_list.
# This yields exactly what list.index() returned, but with one pass over the
# list instead of a linear scan per document (which was quadratic overall).
# NOTE(review): as before, identical metadata lines all resolve to the first
# occurrence, so duplicated lines share a single id.
doc_position = {}
for position, name in enumerate(doc_name_list):
    doc_position.setdefault(name, position)

# Ids of the training documents, in random order.
train_ids = [doc_position[train_name] for train_name in doc_train_list]
random.shuffle(train_ids)

train_ids_str = '\n'.join(str(index) for index in train_ids)
with open('../data/' + dataset + '.train.index', 'w') as f:
    f.write(train_ids_str)

# Ids of the test documents, in random order (shuffled after the training ids
# so the random stream is consumed in the same order as before).
test_ids = [doc_position[test_name] for test_name in doc_test_list]
random.shuffle(test_ids)

test_ids_str = '\n'.join(str(index) for index in test_ids)
with open('../data/' + dataset + '.test.index', 'w') as f:
    f.write(test_ids_str)

# total text ids
# train_ids first and test ids second
# Final document order: all (shuffled) training ids, then all (shuffled)
# test ids.
ids = train_ids + test_ids

# Metadata lines and document texts re-ordered to that final order.
# Comprehensions are used so no loop variable is left behind at module level;
# the previous loop variable was named `id` and shadowed the builtin. The ids
# are already ints, so no int() conversion is needed.
shuffle_doc_name_list = [doc_name_list[doc_id] for doc_id in ids]
shuffle_doc_words_list = [doc_content_list[doc_id] for doc_id in ids]

shuffle_doc_name_str = '\n'.join(shuffle_doc_name_list)
shuffle_doc_words_str = '\n'.join(shuffle_doc_words_list)

with open('../data/' + dataset + '_shuffle.txt', 'w') as f:
    f.write(shuffle_doc_name_str)

with open('../data/corpus/' + dataset + '_shuffle.txt', 'w') as f:
    f.write(shuffle_doc_words_str)

# build vocab and count words frequency

# Corpus-wide token counts and the set of distinct tokens.
word_freq = {}
word_set = set()
for document in shuffle_doc_words_list:
    for token in document.split():
        word_set.add(token)
        word_freq[token] = word_freq.get(token, 0) + 1

# Vocabulary table: position in this list is the word id used everywhere below.
# NOTE(review): the order comes from set iteration, so it is not guaranteed to
# be the same from one run to the next.
vocab = list(word_set)
vocab_size = len(vocab)

# word -> list of (shuffled) document indices containing it, each index once.
word_doc_list = {}
# label -> concatenation of the tokens of every document carrying that label.
label_word_list = defaultdict(list)

for doc_idx, doc_text in enumerate(shuffle_doc_words_list):
    tokens = doc_text.split()
    # Field 2 of the metadata line holds the space-separated labels.
    for lab in shuffle_doc_name_list[doc_idx].split('\t')[2].split(" "):
        label_word_list[lab].extend(tokens)
    # dict.fromkeys de-duplicates while keeping first-seen order.
    for token in dict.fromkeys(tokens):
        word_doc_list.setdefault(token, []).append(doc_idx)

# word -> number of documents containing it (document frequency).
word_doc_freq = {token: len(doc_ids) for token, doc_ids in word_doc_list.items()}

# word -> word id (its position in vocab).
word_id_map = {token: position for position, token in enumerate(vocab)}

# Persist the vocabulary, one word per line, in id order.
vocab_str = '\n'.join(vocab)
with open('../data/corpus/' + dataset + '_vocab.txt', 'w') as f:
    f.write(vocab_str)

# label list
# Every label that occurs in the metadata (field 2, space-separated).
label_set = set()
for doc_meta in shuffle_doc_name_list:
    label_set.update(doc_meta.split('\t')[2].split(" "))

# Size of the label space. Labels are the decimal strings "0" .. "NUM_LABELS-1"
# and a label's position in label_list is its index in the one-hot vectors.
# Other sizes that were previously tried here: 798, 150, 441.
# NOTE(review): this value is not derived from the chosen dataset — confirm it
# matches the dataset being processed.
NUM_LABELS = 9162
label_list = [str(label_id) for label_id in range(NUM_LABELS)]

# Fail fast: a label outside the label space would otherwise surface much
# later as a bare "x is not in list" ValueError while the one-hot vectors are
# built, after several output files have already been written.
unknown_labels = label_set.difference(label_list)
if unknown_labels:
    raise ValueError(
        "{} label(s) in the data are outside the configured label space "
        "of {} labels, e.g. {!r}".format(
            len(unknown_labels), NUM_LABELS, sorted(unknown_labels)[:10]))

label_list_str = '\n'.join(label_list)
with open('../data/corpus/' + dataset + '_labels.txt', 'w') as f:
    f.write(label_list_str)

# select 90% of the training set (the remaining 10% is the validation share)
# Split sizes. A tenth of the training documents (rounded down) is counted as
# validation; the rest is the "real" training set.
train_size, test_size = len(train_ids), len(test_ids)
val_size = int(0.1 * train_size)
real_train_size = train_size - val_size

# The real training documents are the first real_train_size entries of the
# shuffled order; their metadata lines are written out one per line.
real_train_doc_names = shuffle_doc_name_list[:real_train_size]
real_train_doc_names_str = '\n'.join(real_train_doc_names)
with open('../data/' + dataset + '.real_train.name', 'w') as f:
    f.write(real_train_doc_names_str)



# ---------------------------------------------------------------------------
# Word graph
# ---------------------------------------------------------------------------

# Word co-occurrence is measured over sliding context windows.
window_size = 6
windows = []

for doc_words in shuffle_doc_words_list:
    tokens = doc_words.split()
    if len(tokens) <= window_size:
        # Short document: the whole token list is its single window (an empty
        # document contributes one empty window).
        windows.append(tokens)
        continue
    # One window per start offset, each exactly window_size tokens long.
    last_start = len(tokens) - window_size
    windows.extend(
        tokens[start:start + window_size] for start in range(last_start + 1)
    )

# word -> number of windows that contain it (a word counts once per window).
word_window_freq = {}
for window in windows:
    # dict.fromkeys drops repeats inside the window, keeping first-seen order.
    for token in dict.fromkeys(window):
        word_window_freq[token] = word_window_freq.get(token, 0) + 1

# "id_a,id_b" -> number of co-occurrences of the two words inside a window.
# Every unordered pair of positions in a window is counted once, and the count
# is stored under both key orders so the relation is symmetric.
word_pair_count = {}
for window in windows:
    window_ids = [word_id_map[token] for token in window]
    for later in range(1, len(window_ids)):
        later_id = window_ids[later]
        for earlier in range(later):
            earlier_id = window_ids[earlier]
            if later_id == earlier_id:
                # The same word at two positions is not a pair.
                continue
            for pair_key in ('{},{}'.format(later_id, earlier_id),
                             '{},{}'.format(earlier_id, later_id)):
                word_pair_count[pair_key] = word_pair_count.get(pair_key, 0) + 1


##--------------step2------------------
# Build word graph and tfidf matrix


# train_labels
# a document may carry more than one label
def _multi_hot(labels, index_of, width):
    """Return a list of *width* zeros with a 1 at the index of every label.

    labels   -- iterable of label strings of one document
    index_of -- dict mapping a label string to its column index
    width    -- length of the returned list

    Raises ValueError for a label missing from *index_of*, the same exception
    type the former ``label_list.index(lab)`` lookup raised.
    """
    row = [0] * width
    for lab in labels:
        try:
            row[index_of[lab]] = 1
        except KeyError:
            raise ValueError("{!r} is not in label_list".format(lab)) from None
    return row


# label -> position of its first occurrence in label_list. A dict lookup
# replaces the linear label_list.index() scan that ran for every label of
# every document.
label_to_index = {}
for position, name in enumerate(label_list):
    label_to_index.setdefault(name, position)
num_labels = len(label_list)

# Training documents are the first train_size entries of the shuffled order.
# train_y stays a plain list of lists; label_pairs keeps each training
# document's raw label list for the label co-occurrence statistics.
train_y = []
label_pairs = []
for doc_meta in shuffle_doc_name_list[:train_size]:
    label = doc_meta.split('\t')[2].split(' ')
    label_pairs.append(label)
    train_y.append(_multi_hot(label, label_to_index, num_labels))

# Test documents follow the training documents; test_y is a NumPy array.
test_y = np.array([
    _multi_hot(doc_meta.split('\t')[2].split(" "), label_to_index, num_labels)
    for doc_meta in shuffle_doc_name_list[train_size:train_size + test_size]
])

# label -> {other label -> number of training documents carrying both}.
label_pair_freq = defaultdict(dict)
# label -> number of times it occurs in the training label lists.
label_nums = defaultdict(int)

for doc_labels in label_pairs:
    for first in doc_labels:
        label_nums[first] += 1
        for second in doc_labels:
            if first != second:
                partners = label_pair_freq[first]
                partners[second] = partners.get(second, 0) + 1


# COO triples (row, column, weight) of the word/label adjacency matrix.
row_wd, col_wd, weight_wd = [], [], []

# Word-word edges are weighted by pointwise mutual information, estimated
# from window counts.
num_window = len(windows)

for pair_key, count in word_pair_count.items():
    left, right = pair_key.split(',')
    i, j = int(left), int(right)
    word_freq_i = word_window_freq[vocab[i]]
    word_freq_j = word_window_freq[vocab[j]]
    pmi = log((1.0 * count / num_window) /
              (1.0 * word_freq_i * word_freq_j/(num_window * num_window)))
    # Only strictly positive PMI becomes an edge.
    if pmi > 0:
        row_wd.append(i)
        col_wd.append(j)
        weight_wd.append(pmi)


# ---------------------------------------------------------------------------
# Label -> word edges, from per-label word frequencies
# ---------------------------------------------------------------------------

# Total number of tokens in the corpus.
total_words = sum(word_freq.values())

# Constant factor applied to every label-word weight.
LABEL_WORD_SCALE = 500.0

for i, lab in enumerate(label_list):
    # Token counts over all documents carrying this label.
    word_c = Counter(label_word_list[lab])
    tot = sum(word_c.values())
    for k, v in word_c.items():
        j = word_id_map.get(k)
        if j is None:
            continue
        # Label nodes are placed after the vocab_size word nodes.
        row_wd.append(i + vocab_size)
        col_wd.append(j)
        freq_tot = word_freq[k] / total_words   # share of the word in the corpus
        freq_lab = v / tot                      # share of the word within the label
        # Alternative weightings that were tried before:
        #   freq_lab / (1 + math.log(1 + freq_tot))
        #   1000.0 * freq_lab * log(1.0 / freq_tot)
        #   freq_lab
        # (The per-edge debug print that used to be here was removed: it wrote
        # one line to stdout for every label-word edge.)
        weight_wd.append(LABEL_WORD_SCALE * freq_lab * log(1.0 / freq_tot))


# Label -> label edges, weighted by how often the second label accompanies
# the first one among the training documents: count(both) / count(first).
#
# label -> position of its first occurrence in label_list; replaces a linear
# label_list.index() scan per co-occurring label pair with a dict lookup.
label_position = {}
for position, name in enumerate(label_list):
    label_position.setdefault(name, position)

for i, lab in enumerate(label_list):
    for k, v in label_pair_freq[lab].items():
        j = label_position[k]
        if i == j:
            continue
        # Both endpoints are label nodes, which follow the word nodes.
        row_wd.append(i+vocab_size)
        col_wd.append(j+vocab_size)
        # A 10x scaled variant (10.0*v/label_nums[lab]) was tried before.
        weight_wd.append(v/label_nums[lab])
#
label_size = len(label_list)
word_dim = 300

# Location of the pre-trained GloVe vectors. The original hard-coded path is
# still the default; set the GLOVE_PATH environment variable to override it.
glove_path = os.environ.get(
    "GLOVE_PATH", "/home/chixiao/dataset/glove.840B.300d.txt")

# word -> array of its word_dim vector components (kept as the strings read
# from the file, so they are written back unchanged).
# Only vocabulary words are kept; the other rows of the GloVe table are never
# looked up, so holding them in memory was pure overhead.
word_vectors = {}
with open(glove_path) as f:
    for line in f:
        fields = line.split()
        # A well-formed row is one token followed by word_dim numbers. Rows
        # with a different field count (e.g. a token that itself contains
        # whitespace) are skipped: keying them by their first field could
        # overwrite the vector of a genuine token with the wrong values.
        if len(fields) != word_dim + 1:
            continue
        voc = fields[0]
        if voc in word_id_map:
            word_vectors[voc] = np.array(fields[1:])

# Write one embedding row per graph node: vocabulary words first (in word-id
# order), then one row per label.
cnt = 0  # vocabulary words that have a pre-trained vector
with open("../data/ind.{}.glove".format(dataset),'w') as f:
    for word in vocab:
        vec = word_vectors.get(word)
        if vec is None:
            # No pre-trained vector: draw a random one. Drawn only when
            # needed, instead of once per word regardless of a hit.
            vec = np.random.uniform(-1, 1, word_dim)
        else:
            cnt += 1
        f.write(" ".join(map(str, vec)) + '\n')

    # Labels have no pre-trained vectors; every label row is random.
    for _ in range(label_size):
        vec = np.random.uniform(-1, 1, word_dim)
        f.write(" ".join(map(str, vec)) + '\n')

# Share of the vocabulary covered by pre-trained vectors (0.0 for an empty
# vocabulary rather than a ZeroDivisionError).
print(cnt / vocab_size if vocab_size else 0.0)
#row_allx = []
#col_allx = []
#data_allx = []

# doc word frequency
# "doc_id,word_id" -> number of times the word occurs in that document.
doc_word_freq = {}

for doc_id, doc_text in enumerate(shuffle_doc_words_list):
    for token in doc_text.split():
        pair_key = '{},{}'.format(doc_id, word_id_map[token])
        doc_word_freq[pair_key] = doc_word_freq.get(pair_key, 0) + 1

label_size = len(label_list)

# COO triples (document, word id, tf * idf) of the document-word matrix.
row_tf, col_tf, weight_tf = [], [], []

for i, doc_text in enumerate(shuffle_doc_words_list):
    # Each distinct word of the document yields one entry, in first-seen order.
    for token in dict.fromkeys(doc_text.split()):
        j = word_id_map[token]
        freq = doc_word_freq['{},{}'.format(i, j)]
        # idf = log(number of documents / number of documents with the word)
        idf = log(1.0 * len(shuffle_doc_words_list) /
                  word_doc_freq[vocab[j]])
        row_tf.append(i)
        col_tf.append(j)
        weight_tf.append(freq * idf)

# The graph has one node per vocabulary word followed by one node per label.
net_size = vocab_size + label_size

# Word/label adjacency matrix, assembled from the (row, column, weight) triples.
wordadj = sp.csr_matrix(
    (weight_wd, (row_wd, col_wd)),
    shape=(net_size, net_size),
)

# Document x word TF-IDF matrix.
tfmatrix = sp.csr_matrix(
    (weight_tf, (row_tf, col_tf)),
    shape=(len(shuffle_doc_words_list), vocab_size),
)

# Pickle each artefact to its own file, in this order:
# adjacency matrix, TF-IDF matrix, one-hot train labels, one-hot test labels.
for path_template, payload in (
    ("../data/ind.{}.adj", wordadj),
    ("../data/ind.{}.tfidf", tfmatrix),
    ("../data/ind.{}.train_labels", train_y),
    ("../data/ind.{}.test_labels", test_y),
):
    with open(path_template.format(dataset), 'wb') as f:
        pkl.dump(payload, f)
