import sys
import os

import argparse
import json
import re
import uuid

# Matches a parenthesised property id such as "(P31)" (presumably Wikidata
# predicate ids, per the name). Raw string: "\(" in a plain literal is an
# invalid escape sequence that warns on modern Python.
wikipred_re = re.compile(r"\(P[0-9]+\)")

def getParser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with a required ``--triples`` path,
        optional ``--enames`` / ``--rnames`` name-file paths, and an
        ``--allow-multiword`` boolean flag. Defaults appear in ``--help``.
    """
    p = argparse.ArgumentParser(
        description="parser for arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flag, extra keyword options) for the string-valued file arguments,
    # registered in this order so the --help listing stays the same.
    file_options = (
        ("--triples", {"help": "file containing the triples (h,r,t)", "required": True}),
        ("--enames", {"help": "file containing entity names", "default": "ent2id.txt"}),
        ("--rnames", {"help": "file containing relation names", "default": "rel2id.txt"}),
    )
    for flag, extra in file_options:
        p.add_argument(flag, type=str, **extra)
    p.add_argument(
        "--allow-multiword",
        action="store_true",
        default=False,
        help="use this flag to keep first word for multi-word tails",
    )
    return p

def readnames(file, encoding="utf-8"):
    """Load an id -> name mapping from a tab-separated name file.

    The first line is treated as a header and skipped. Every following
    non-blank line is expected to look like ``name<TAB>id``. The name is
    stripped of surrounding whitespace and the id is parsed as an int. If
    an id appears more than once, the later line wins.

    Args:
        file: path of the name file (e.g. ``ent2id.txt`` / ``rel2id.txt``).
        encoding: text encoding of the file. Defaults to UTF-8 instead of
            the platform locale, so non-ASCII names decode identically on
            every machine.

    Returns:
        dict mapping each int id to its name string.

    Raises:
        IndexError: a data line has no tab-separated id column.
        ValueError: the id column is not an integer.
    """
    delim = "\t"
    id2names = {}
    with open(file, "r", encoding=encoding) as fin:
        next(fin, None)  # header line (not a name entry); tolerate empty files
        for line in fin:
            line = line.strip()
            if not line:
                continue
            fields = line.split(delim)
            id2names[int(fields[1])] = fields[0].strip()
    return id2names

def expand_triples(params):
    """Print the named form of every (head, relation, tail) id triple.

    Reads ``params.triples``: a one-line header followed by tab-separated
    ``head_id<TAB>rel_id<TAB>tail_id`` lines. For each triple it writes
    ``head<TAB>rel<TAB>tail`` to stdout, resolving ids through the name
    files ``params.enames`` and ``params.rnames`` (loaded with ``readnames``).

    A triple that references an unknown entity or relation id is skipped.
    It is reported on stderr so the diagnostic never ends up in the TSV
    output.

    Args:
        params: parsed command-line namespace. Uses ``triples``, ``enames``
            and ``rnames``. ``allow_multiword`` is accepted by the parser
            but not consulted here.

    Raises:
        IndexError: a data line has fewer than three tab-separated fields.
        ValueError: an id field is not an integer.
    """
    delim = "\t"
    enames = readnames(params.enames)
    rnames = readnames(params.rnames)
    with open(params.triples, encoding="utf-8") as fin:
        next(fin, None)  # header line, not a triple
        # The header is line 1, so data lines are numbered from 2.
        for lineno, line in enumerate(fin, start=2):
            line = line.strip()
            if not line:
                continue
            fields = line.split(delim)
            head_id = int(fields[0].strip())
            rel_id = int(fields[1].strip())
            tail_id = int(fields[2].strip())
            # None (not a string sentinel) marks a failed lookup, so an
            # entity or relation genuinely named "NA" is still printed.
            head = enames.get(head_id)
            rel = rnames.get(rel_id)
            tail = enames.get(tail_id)
            if head is None or rel is None or tail is None:
                print(
                    "skipping line {}: unknown id in {!r}".format(lineno, line),
                    file=sys.stderr,
                )
                continue
            print(delim.join([head, rel, tail]))

def main():
    """Parse command-line arguments and print the expanded triples.

    After ``--help`` the script exits with status 0. On invalid arguments,
    argparse prints its usage/error message and the script exits with
    status 1.
    """
    parser = getParser()
    try:
        params = parser.parse_args()
    except SystemExit as exc:
        # argparse reports both --help (code 0) and usage errors (code 2)
        # via SystemExit. Let the success case through unchanged and keep
        # this script's historical status 1 for errors. Catching only
        # SystemExit also lets KeyboardInterrupt propagate normally.
        if not exc.code:
            raise
        sys.exit(1)
    expand_triples(params)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()

