import json
import logging
import random
import sys
from glob import glob
from os.path import basename, join, normpath

import numpy as np
import torch

# Three prompt-vector count limits; MAX_NUM_VECTORS is their combined total.
MAX_NUM_1 = 2
MAX_NUM_2 = 10
MAX_NUM_3 = 6
MAX_NUM_VECTORS = MAX_NUM_1 + MAX_NUM_2 + MAX_NUM_3

# Configure the root logger at import time (INFO level, timestamped lines)
# and obtain this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)


def load_file(filename):
    """Read a JSON Lines file and return its records as a list.

    The file is decoded as UTF-8 explicitly, so non-ASCII labels load the
    same on every platform rather than depending on the locale. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped instead of
    making ``json.loads`` fail.

    Args:
        filename: path to a ``.jsonl`` file with one JSON value per line.

    Returns:
        list of the decoded values, in file order.
    """
    with open(filename, "r", encoding="utf-8") as f:
        # Iterate the file lazily instead of materializing readlines().
        return [json.loads(line) for line in f if line.strip()]


def get_relation_meta(args):
    """Return the metadata record for the relation named by ``args.relation``.

    Reads the JSON Lines file at ``args.relation_profile`` and returns the
    first record whose ``'relation'`` field equals ``args.relation``.

    Raises:
        ValueError: if no record in the profile matches.
    """
    candidates = (
        entry
        for entry in load_file(args.relation_profile)
        if entry['relation'] == args.relation
    )
    found = next(candidates, None)
    if found is None:
        raise ValueError('Relation info %s not found in file %s'%(args.relation, args.relation_profile))
    return found


def get_labelWords(tokenizer, dataDir="../LAMA_data/autoprompt_data"):
    """Collect, per relation, the token ids used as object labels.

    Every immediate sub-directory of ``dataDir`` is expected to contain
    ``train.jsonl``, ``dev.jsonl`` and ``test.jsonl``. For each record the
    ``obj_label`` is tokenized and the id of its FIRST sub-token is added to
    the set kept for the record's ``predicate_id``.

    Args:
        tokenizer: object exposing ``tokenize(str)`` and
            ``convert_tokens_to_ids(list)`` (presumably a HuggingFace
            tokenizer -- TODO confirm).
        dataDir: directory whose sub-directories hold the split files.

    Returns:
        dict mapping predicate id -> set of token ids.
    """
    predict_labelWords = {}
    for rel_dir in glob(join(dataDir, "*", "")):
        for split in ("train", "dev", "test"):
            # `with` closes each split file; the previous bare open() calls leaked handles.
            with open(join(rel_dir, split + ".jsonl"), "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    item = json.loads(line)
                    tokens = tokenizer.tokenize(item["obj_label"])
                    if not tokens:
                        # No first sub-token to record; skip rather than raise IndexError.
                        continue
                    label_id = tokenizer.convert_tokens_to_ids(tokens)[0]
                    predict_labelWords.setdefault(item["predicate_id"], set()).add(label_id)
    return predict_labelWords


# def get_new_token(vid):
#     assert(vid > 0 and vid <= MAX_NUM_VECTORS)
#     return '[V%d]'%(vid)
#
#
# def add_new_token(model):
#     new_tokens = [get_new_token(i + 1) for i in range(MAX_NUM_VECTORS)]
#     model.lm_tokenizer.add_tokens(new_tokens)
#     ebd = model.mlm_model.resize_token_embeddings(len(model.tokenizer))
#     logger.info('# vocab after adding new tokens: %d' % len(model.tokenizer))


def init_template(args, tokenizer):
    """Turn the manual template of ``args.relation`` into a list of tokens.

    The template text comes from ``get_relation_meta(args)["template"]`` and
    is split on whitespace. The ``[X]`` and ``[Y]`` placeholders are kept
    verbatim. Every other word is tokenized with a leading space prepended,
    presumably so that tokenizers which encode a preceding space treat it as
    a word start (TODO confirm). The word's sub-tokens are then spliced into
    the result in order.

    Only manual templates are supported. An alternative that built the
    template from ``<unk>`` placeholders (controlled by
    ``args.init_manual_template``) is disabled.
    """
    placeholders = ('[X]', '[Y]')
    manual_template = get_relation_meta(args)["template"]
    pieces = []
    for word in manual_template.split():
        if word in placeholders:
            pieces.append(word)
            continue
        pieces.extend(tokenizer.tokenize(' ' + word))
    return pieces


def load_vocab(vocab_filename):
    """Read a vocabulary file with one entry per line.

    The file is decoded as UTF-8 explicitly, matching how the JSONL data is
    read, so non-ASCII entries do not depend on the platform locale.

    Args:
        vocab_filename: path to the vocabulary text file.

    Returns:
        list of entries with surrounding whitespace stripped, in file order.
        Blank lines are kept as empty strings.
    """
    with open(vocab_filename, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def parse_template(template, subject_label, object_label):
    """Fill a relation template and return it as a one-element list.

    Every ``[X]`` is replaced by ``subject_label`` first, then every ``[Y]``
    by ``object_label``. A ``[Y]`` that appears inside the subject label is
    therefore replaced as well.
    """
    filled = template.replace("[X]", subject_label).replace("[Y]", object_label)
    return [filled]


def gen_feature_sample(data_sample, template, mask_token):
    """Convert one raw data record into a feature-sample dict.

    Args:
        data_sample: raw record with ``predicate_id``, ``sub_label`` and
            ``obj_label`` keys. ``uuid`` and ``evidences`` are optional;
            ``evidences`` is a list of dicts, each with a ``masked_sentence``.
        template: relation template containing ``[X]``/``[Y]`` placeholders.
        mask_token: string substituted for ``[Y]``.

    Returns:
        dict with the copied ``predicate_id``/``sub_label``/``obj_label``,
        ``uuid`` ('' when absent), ``template`` (a one-element list holding
        the filled template) and ``masked_sentences`` (the evidence
        sentences; empty when the record has no ``evidences``).
    """
    feature_sample = {
        'predicate_id': data_sample['predicate_id'],
        'sub_label': data_sample['sub_label'],
        'obj_label': data_sample['obj_label'],
        'uuid': data_sample.get('uuid', ''),
    }
    templates = parse_template(template.strip(), feature_sample['sub_label'].strip(), mask_token)
    feature_sample['template'] = [templates[0]]

    # Treat evidences as optional, like uuid. Some callers (load_all_data)
    # overwrite masked_sentences with the template, so a record without
    # evidence should not raise KeyError here.
    evidences = data_sample.get("evidences", [])
    feature_sample["masked_sentences"] = [evidence["masked_sentence"] for evidence in evidences]
    return feature_sample


def load_data(data_path, template, vocab_subset=None, mask_token='[MASK]'):
    """Load one relation's JSONL file into feature samples.

    Following the LAMA setting, only the first occurrence of each
    ``(sub_label, obj_label)`` pair is kept.

    Args:
        data_path: path to the relation's ``.jsonl`` file.
        template: relation template with ``[X]``/``[Y]`` placeholders.
        vocab_subset: optional collection of allowed object labels. Samples
            whose ``obj_label`` is not in it are dropped. ``None`` (the
            default) disables the filter instead of raising ``TypeError``.
        mask_token: string substituted for ``[Y]`` in the template.

    Returns:
        list of feature-sample dicts, in file order.
    """
    # Build the membership set once: testing against a large list is O(len) per sample.
    allowed = None if vocab_subset is None else set(vocab_subset)
    all_samples = []
    distinct_facts = set()
    for data_sample in load_file(data_path):
        fact = (data_sample['sub_label'], data_sample['obj_label'])
        if fact in distinct_facts:
            continue
        if allowed is not None and data_sample['obj_label'] not in allowed:
            continue
        distinct_facts.add(fact)
        all_samples.append(gen_feature_sample(data_sample, template, mask_token))
    return all_samples


def load_all_data(data_dirs, relations_meta, vocab_subset=None,
                  mask_token='[MASK]', mode="train", shuffle=True):
    """Load one split of several relations into a single list of samples.

    For every directory in ``data_dirs`` the file ``<dir>/<mode>.jsonl`` is
    read. The directory's own name selects the relation template from
    ``relations_meta``. Each kept sample's ``masked_sentences`` is replaced
    by its filled template.

    Args:
        data_dirs: relation directories (such as those produced by a
            ``"*/"`` glob). A trailing separator is allowed.
        relations_meta: iterable of dicts with ``"relation"`` and
            ``"template"`` keys. Later duplicates override earlier ones.
        vocab_subset: optional collection of allowed object labels. Samples
            whose ``obj_label`` is not in it are dropped. ``None`` disables
            the filter instead of raising ``TypeError``.
        mask_token: string substituted for ``[Y]`` in the template.
        mode: split to read: "train", "dev" or "test".
        shuffle: if true, shuffle the result in place with ``random.shuffle``.

    Returns:
        list of feature-sample dicts.

    Raises:
        KeyError: if a directory name has no matching relation in
            ``relations_meta`` (raised when that directory's first sample
            is kept).
    """
    template_dict = {meta["relation"]: meta["template"] for meta in relations_meta}
    # Build the membership set once: testing against a large list is O(len) per sample.
    allowed = None if vocab_subset is None else set(vocab_subset)
    # NOTE(review): this set is shared across ALL relations, so a (sub, obj)
    # pair already seen under one relation is dropped from every later one.
    # LAMA deduplicates per relation file; confirm whether the key should
    # also include the relation name.
    distinct_facts = set()
    all_samples = []
    for rel_dir in data_dirs:
        # normpath drops a trailing separator (and normalizes separators on
        # Windows), so basename is the directory's own name either way.
        name = basename(normpath(rel_dir))
        for data_sample in load_file(join(rel_dir, mode + ".jsonl")):
            fact = (data_sample['sub_label'], data_sample['obj_label'])
            if fact in distinct_facts:
                continue
            if allowed is not None and data_sample['obj_label'] not in allowed:
                continue
            distinct_facts.add(fact)
            feature_sample = gen_feature_sample(data_sample, template_dict[name], mask_token)
            # Use the filled template in place of the evidence sentences.
            feature_sample["masked_sentences"] = feature_sample["template"]
            all_samples.append(feature_sample)
    if shuffle:
        random.shuffle(all_samples)
    return all_samples

def batchify(data, batch_size):
    """Group the items of ``data`` into consecutive lists of ``batch_size``.

    Order is preserved, and any leftover items form a final, shorter batch.
    ``data`` may be any iterable. A ``batch_size`` below 1 yields one-item
    batches.
    """
    batches = []
    pending = []
    for item in data:
        pending.append(item)
        if len(pending) >= batch_size:
            batches.append(pending)
            pending = []
    if pending:
        batches.append(pending)
    return batches


def output_result(result, eval_loss):
    """Log per-relation and aggregate evaluation metrics.

    Args:
        result: mapping from relation name to a 4-tuple
            ``(correct, total, avg, summed_loss)``. ``avg`` is logged and
            macro-averaged as given; presumably it is per-relation accuracy
            (TODO confirm).
        eval_loss: loss summed over all evaluated samples. It is divided by
            the total sample count for logging.

    Returns:
        ``(micro, macro)``: ``micro`` = total correct / total samples, and
        ``macro`` = mean of the per-relation ``avg`` values. Both are 0.0
        when there is nothing to average.
    """
    logger.info('* Evaluation result *')
    cor = 0
    tot = 0
    macro = 0.0
    loss = 0.0
    for rel in result:
        cor_, tot_, avg_, loss_ = result[rel]
        cor += cor_
        tot += tot_
        macro += avg_
        # Out-of-place division avoids ZeroDivisionError for empty relations
        # and never mutates a mutable value (e.g. a tensor) held in `result`.
        loss_ = loss_ / tot_ if tot_ > 0 else 0.0
        loss += loss_
        logger.info('%s\t%.5f\t%d\t%d\t%.5f' % (rel, avg_, cor_, tot_, loss_))
    num_rel = len(result)
    macro = macro / num_rel if num_rel > 0 else 0.0
    micro = cor / tot if tot > 0 else 0.0
    # Guard eval_loss the same way as micro so an empty evaluation does not crash.
    mean_eval_loss = eval_loss / tot if tot > 0 else 0.0
    mean_rel_loss = loss / num_rel if num_rel > 0 else 0.0
    logger.info('Macro avg: %.5f' % macro)
    logger.info('Micro avg: %.5f, Eval_loss: %.5f, Eval_loss (common vocab): %.5f' %(micro, mean_eval_loss, mean_rel_loss))
    sys.stdout.flush()
    return micro, macro
