import argparse
import csv
import json
import os
import pickle
import random
import sys

# Seed the global RNG at import time so the random.shuffle() call in
# generate() is reproducible from one run of the script to the next.
random.seed(42)

def getParser():
    """Build the command-line parser for the sampling script.

    Options: ``--bertdata``, ``--types``, ``--num-samples``, ``--num-splits``,
    ``--num-copies`` (all with defaults) and the required ``--outdir``.
    Defaults are shown in ``--help`` via ``ArgumentDefaultsHelpFormatter``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(
        description="parser for arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flag, keyword arguments) pairs, registered in declaration order.
    option_table = (
        ("--bertdata", {
            "type": str,
            "help": "pickle containing bert predicitons",
            "default": '../BERT/LAMA/output/results/bert_large/freebase/uncased_result.pkl',
        }),
        ("--types", {
            "type": str,
            "help": "file containing the types",
            "default": "eval/human/types.txt",
        }),
        ("--num-samples", {
            "type": int,
            "help": "number of samples to extract",
            "default": 150,
        }),
        ("--num-splits", {
            "type": int,
            "help": "number of files into which samples should be split",
            "default": 3,
        }),
        ("--num-copies", {
            "type": int,
            "help": "number of copies for each sample",
            "default": 3,
        }),
        ("--outdir", {
            "type": str,
            "help": "output dir",
            "required": True,
        }),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli

def _split_label(index):
    """Return a spreadsheet-style letter label for a zero-based split index.

    0 -> "A", 7 -> "H", 25 -> "Z", 26 -> "AA", ... so every split gets a
    unique label no matter how many splits are requested.
    """
    label = ""
    index += 1
    while index > 0:
        index, remainder = divmod(index - 1, 26)
        label = chr(ord("A") + remainder) + label
    return label


def _collect_rows(triples, ordered_indices, limit):
    """Return up to *limit* ``[idx, sub, pred, obj]`` rows, visiting triples in the given order.

    Triples whose predicate contains a "." are skipped.
    """
    rows = []
    if limit <= 0:
        return rows
    for idx in ordered_indices:
        sample = triples[idx]['sample']
        pred = sample['pred']
        if '.' in pred:
            continue
        rows.append([str(idx), sample['sub'], pred, sample['obj']])
        if len(rows) >= limit:
            break
    return rows


def generate(params):
    """Sample triples from the pickled results and write them to split CSV files.

    The triples under ``bert['list_of_results']`` are visited in a random
    order (driven by the global ``random`` state); those whose predicate
    contains a "." are skipped. The first
    ``(num_samples // num_splits) * num_splits`` remaining triples are divided
    evenly over ``num_splits`` files, and that set of files is written
    ``num_copies`` times as ``<copy><label>.csv`` (``0A.csv``, ``0B.csv``, ...).
    Each CSV row is ``idx,sub,pred,obj``. The full shuffled index order is
    saved to ``sample_indices.json`` in the same directory.

    Args:
        params: Namespace providing ``bertdata`` (path to the pickle),
            ``num_samples``, ``num_splits``, ``num_copies`` and ``outdir``.
            ``params.types`` is not consulted: the type list was previously
            loaded but never used.

    Raises:
        ValueError: If ``num_splits`` is smaller than 1, or if there are not
            enough eligible triples to fill every split. Nothing is written
            in either case.
    """
    if params.num_splits < 1:
        raise ValueError(f"--num-splits must be at least 1, got {params.num_splits}")

    # SECURITY: pickle.load can execute arbitrary code while unpickling; only
    # point --bertdata at files produced by a trusted source.
    with open(params.bertdata, 'rb') as fin:
        bert = pickle.load(fin)
    triples = bert['list_of_results']

    sample_indices = list(range(len(triples)))
    random.shuffle(sample_indices)

    samples_per_split = params.num_samples // params.num_splits
    # Only whole splits are written, so this is the exact number of rows used.
    needed = samples_per_split * params.num_splits
    rows = _collect_rows(triples, sample_indices, needed)
    if len(rows) < needed:
        raise ValueError(
            f"need {needed} samples but only {len(rows)} eligible triples are available"
        )

    os.makedirs(params.outdir, exist_ok=True)
    with open(os.path.join(params.outdir, 'sample_indices.json'), 'w') as fout:
        json.dump(sample_indices, fout)

    for copy_idx in range(params.num_copies):
        for split_idx in range(params.num_splits):
            filename = os.path.join(params.outdir, f"{copy_idx}{_split_label(split_idx)}.csv")
            start = split_idx * samples_per_split
            # newline='' lets the csv module control line endings; the writer
            # quotes any field that contains a comma, quote or newline.
            with open(filename, 'w', newline='', encoding='utf-8') as fout:
                writer = csv.writer(fout, lineterminator="\n")
                writer.writerows(rows[start:start + samples_per_split])
def main():
    """Parse the command-line arguments and run ``generate`` with them.

    argparse signals both ``--help`` and invalid arguments by raising
    ``SystemExit`` after printing its own message. Argument errors keep the
    script's historical exit status of 1, while a successful ``--help`` now
    exits with status 0 instead of being reported as a failure.
    """
    parser = getParser()
    try:
        params = parser.parse_args()
    except SystemExit as exc:
        # Catch only argparse's own exit; a non-zero code means a usage error.
        sys.exit(1 if exc.code else 0)
    generate(params)

# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

