import csv
import pickle
import random
from collections import defaultdict, Counter
from itertools import chain

import numpy as np
import tqdm


# NOTE: despite its name, MAMMAL_TSV points at a comma-separated file.
MAMMAL_TSV = "./data/mammal/mammal_closure.csv"
OUT_TSV = "./data/mammal/mammal_closure.tsv"

# Read the closure file of the mammal subtree. The first row is a header.
# Each data row contributes the pair [column 1, column 0], i.e. the first two
# columns in reversed order; any further columns are ignored.
closures = []
# newline="" is what the csv module expects so it can handle line endings
# (including quoted fields) itself.
with open(MAMMAL_TSV, newline="") as f:
    reader = csv.reader(f)
    next(reader, None)  # skip the header row
    for row in reader:
        if not row:
            # csv.reader yields [] for blank lines (e.g. a trailing empty
            # line); the old split(",") version crashed on these.
            continue
        closures.append([row[1], row[0]])

# Build the vocabulary: every node name gets a dense integer id, numbered in
# order of first appearance when scanning the closure pairs left to right.
# dict.fromkeys de-duplicates while keeping that first-seen order.
vocab = {
    name: idx
    for idx, name in enumerate(dict.fromkeys(chain.from_iterable(closures)))
}

# Translate each name pair into a tuple of vocabulary ids.
id_closures = [(vocab[pair[0]], vocab[pair[1]]) for pair in closures]

# Persist the id pairs to OUT_TSV, one tab-separated pair per line.
with open(OUT_TSV, "w") as out_file:
    out_file.writelines("{}\t{}\n".format(*pair) for pair in id_closures)

# Create negative closures: distinct ordered (n1, n2) id pairs with n1 != n2
# that do NOT occur among the positive closure pairs, so the "_neg" file never
# contains a true edge.
NUM_NEGATIVES = 1000
positive_closures = set(id_closures)
nodes = list(vocab.keys())
# Only this many candidate pairs exist: all ordered non-self pairs minus the
# (non-self) positives. Capping the target keeps the loop from spinning
# forever when the vocabulary is too small to supply NUM_NEGATIVES pairs.
num_candidates = len(nodes) * (len(nodes) - 1) - sum(
    1 for n1, n2 in positive_closures if n1 != n2
)
num_negatives = min(NUM_NEGATIVES, num_candidates)

negative_closures = set()
while len(negative_closures) < num_negatives:
    n1 = vocab[random.choice(nodes)]
    n2 = vocab[random.choice(nodes)]
    if n1 == n2 or (n1, n2) in positive_closures:
        continue
    negative_closures.add((n1, n2))  # set.add already ignores repeats

with open(OUT_TSV + "_neg", "w") as f:
    for c in negative_closures:
        f.write("{}\t{}\n".format(*c))

# Save the vocabulary as "<name>\t<id>" lines, in id (insertion) order.
vocab_lines = ["{}\t{}\n".format(name, idx) for name, idx in vocab.items()]
with open(OUT_TSV + ".vocab", "w") as vocab_file:
    vocab_file.write("".join(vocab_lines))

# Reload the integer id pairs that were written to OUT_TSV above.
with open(OUT_TSV) as f:
    pairs = [list(map(int, line.split())) for line in f]

# Adjacency lists in both directions: each id maps to every id it is paired
# with, whichever column it appears in (repeats are kept, not de-duplicated).
word2context = defaultdict(list)
for p in pairs:
    word2context[p[0]].append(p[1])
    word2context[p[1]].append(p[0])

# freq[i] = number of occurrences of id i across all pairs (both columns).
freq = np.zeros(len(vocab))
cnt = Counter(chain.from_iterable(pairs))
for k, v in cnt.items():
    freq[k] = v

# Context managers guarantee each pickle file is flushed and closed here,
# rather than whenever the anonymous file object is garbage-collected.
with open(OUT_TSV + '_vocab.pkl', mode='wb') as f:
    pickle.dump(list(vocab.keys()), f)
with open(OUT_TSV + '_freq.pkl', mode='wb') as f:
    pickle.dump(freq, f)
with open(OUT_TSV + '_context.pkl', mode='wb') as f:
    pickle.dump(word2context, f)