import os
from collections import defaultdict, Counter
from itertools import chain
import pickle
import numpy as np


def load_vocab(path):
    """Read a vocabulary file and return a ``{word: id}`` dict.

    Each non-blank line is split on whitespace; field 0 is parsed as the
    integer id and field 1 is taken as the word. Blank lines are skipped.
    If a word appears more than once, the last occurrence wins.
    """
    v2i = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:  # tolerate blank / trailing empty lines
                continue
            v2i[fields[1]] = int(fields[0])
    return v2i


def load_pairs(path):
    """Read a training file into a list of integer-id lists, one per line.

    Every whitespace-separated field on a line is converted with ``int``.
    Blank lines are skipped.
    """
    with open(path, encoding="utf-8") as f:
        return [list(map(int, line.split())) for line in f if line.strip()]


def build_context(pairs):
    """Map each id to the ids it is paired with, in both directions.

    For every pair ``p``, ``p[1]`` is appended to the list for ``p[0]``,
    and ``p[0]`` to the list for ``p[1]``. Only the first two fields of
    each pair are used. Returns a ``defaultdict(list)``.
    """
    word2context = defaultdict(list)
    for p in pairs:
        word2context[p[0]].append(p[1])
        word2context[p[1]].append(p[0])
    return word2context


def count_frequencies(pairs, vocab_size):
    """Return a float array of length *vocab_size* with per-id occurrence counts.

    Every field of every pair is counted. Ids are used directly as array
    indices, so an id >= *vocab_size* raises ``IndexError``.
    """
    freq = np.zeros(vocab_size)
    for k, v in Counter(chain.from_iterable(pairs)).items():
        freq[k] = v
    return freq


def dump_pickle(obj, path):
    """Pickle *obj* to *path*, closing the file deterministically."""
    with open(path, mode='wb') as f:
        pickle.dump(obj, f)


def main(base_tsv="data/noun/noun_closure.tsv", train_percentages=(0, 10, 25, 50)):
    """Build vocab/frequency/context pickles for each training split.

    Reads ``<base_tsv>.vocab`` once. For each percentage ``tp`` it reads
    ``<base_tsv>.train_<tp>percent`` and writes three pickles next to that
    file: ``_vocab.pkl``, ``_freq.pkl`` and ``_context.pkl``.

    Args:
        base_tsv: path prefix of the dataset files.
        train_percentages: iterable of split percentages to process.
    """
    v2i = load_vocab(base_tsv + ".vocab")
    # NOTE(review): this list follows the vocab file's line order, not id
    # order; it lines up with ids only if the file is sorted by id — confirm.
    vocab = list(v2i.keys())

    for tp in train_percentages:
        training_data = base_tsv + ".train_{}percent".format(tp)
        print("Loading ", training_data)

        pairs = load_pairs(training_data)
        word2context = build_context(pairs)
        freq = count_frequencies(pairs, len(v2i))

        dump_pickle(vocab, training_data + '_vocab.pkl')
        dump_pickle(freq, training_data + '_freq.pkl')
        dump_pickle(word2context, training_data + '_context.pkl')


if __name__ == "__main__":
    main()
