import os
import json
import torch
from tqdm import tqdm

# Reserved "[unusedN]" vocabulary tokens repurposed as markers:
# ENT_START_TAG / ENT_END_TAG wrap the mention span in a context input, and
# ENT_TITLE_TAG separates an entity's title tokens from its text tokens.
ENT_START_TAG = "[unused0]"
ENT_END_TAG = "[unused1]"
ENT_TITLE_TAG = "[unused2]"

# Domain (corpus) names of each split. load_entities reads one document file
# per listed domain; load_mentions / load_heldout_mentions validate the mention
# corpora they find against these lists.
train_domains = ['american_football', 'doctor_who', 'fallout', 'final_fantasy', 'military', 'pro_wrestling', 'starwars', 'world_of_warcraft']
val_domains = ['coronation_street', 'muppets', 'ice_hockey', 'elder_scrolls']
test_domains = ['forgotten_realms', 'lego', 'star_trek', 'yugioh']

def load_mentions(data_dir, logger):
    """Load the train/val/test mention files, grouped by domain.

    Each split is read from ``<data_dir>/blink_format/<split>.jsonl`` (one JSON
    object per line) and its mentions are grouped by their ``corpus`` field.

    Args:
        data_dir: Root directory of the dataset.
        logger: Logger used for progress messages.

    Returns:
        ``(train_mentions, val_mentions, test_mentions)``; each is a dict
        mapping a domain name to the list of mention dicts of that domain.

    Raises:
        ValueError: if the domains found in a split differ from
            ``train_domains`` / ``val_domains`` / ``test_domains``.
    """
    mention_path = os.path.join(data_dir, 'blink_format')

    def read_mentions(part):
        mentions = dict()
        # Explicit UTF-8: the default encoding is platform-dependent and the
        # mention text may contain non-ASCII characters.
        with open(os.path.join(mention_path, "%s.jsonl" % part), encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. an extra newline at EOF).
                    continue
                field = json.loads(line)
                mentions.setdefault(field['corpus'], []).append(field)
        for key, items in mentions.items():
            logger.info("Read {} mentions from {}".format(len(items), key))

        return mentions

    def check_domains(mentions, expected, part):
        # Explicit check rather than `assert` (stripped under `python -O`).
        # Names are compared, not just counts, so a misnamed domain fails here
        # instead of as a KeyError later in the pipeline.
        if set(mentions) != set(expected):
            raise ValueError(
                "Expected domains {} in '{}' mentions, found {}".format(
                    sorted(expected), part, sorted(mentions)))

    logger.info("Loading train mentions")
    train_mentions = read_mentions('train')
    logger.info("Loading val mentions")
    val_mentions = read_mentions('val')
    logger.info("Loading test mentions")
    test_mentions = read_mentions('test')

    check_domains(train_mentions, train_domains, 'train')
    check_domains(val_mentions, val_domains, 'val')
    check_domains(test_mentions, test_domains, 'test')

    return train_mentions, val_mentions, test_mentions

def load_heldout_mentions(data_dir, logger):
    """Load the held-out train mention splits, grouped by domain.

    Reads ``heldout_train_seen.jsonl`` and ``heldout_train_unseen.jsonl`` from
    ``<data_dir>/blink_format`` (one JSON object per line) and groups the
    mentions by their ``corpus`` field.

    Args:
        data_dir: Root directory of the dataset.
        logger: Logger used for progress messages.

    Returns:
        ``(train_seen_mentions, train_unseen_mentions)``; each maps a domain
        name to its list of mention dicts.

    Raises:
        ValueError: if either split's domains differ from ``train_domains``.
    """
    mention_path = os.path.join(data_dir, 'blink_format')

    def read_mentions(part):
        mentions = dict()
        # Explicit UTF-8: the default encoding is platform-dependent and the
        # mention text may contain non-ASCII characters.
        with open(os.path.join(mention_path, "%s.jsonl" % part), encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. an extra newline at EOF).
                    continue
                field = json.loads(line)
                mentions.setdefault(field['corpus'], []).append(field)
        for key, items in mentions.items():
            logger.info("Read {} mentions from {}".format(len(items), key))

        return mentions

    def check_domains(mentions, part):
        # Explicit check rather than `assert` (stripped under `python -O`);
        # compare domain names, not just their number.
        if set(mentions) != set(train_domains):
            raise ValueError(
                "Expected domains {} in '{}' mentions, found {}".format(
                    sorted(train_domains), part, sorted(mentions)))

    logger.info("Loading train seen mentions")
    train_seen_mentions = read_mentions('heldout_train_seen')
    logger.info("Loading train unseen mentions")
    train_unseen_mentions = read_mentions('heldout_train_unseen')

    check_domains(train_seen_mentions, 'heldout_train_seen')
    check_domains(train_unseen_mentions, 'heldout_train_unseen')

    return train_seen_mentions, train_unseen_mentions


def load_entities(data_dir, logger):
    """Load the entity documents of every train/val/test domain.

    Each domain is read from ``<data_dir>/documents/<domain>.json``, one JSON
    object per line.

    Args:
        data_dir: Root directory of the dataset.
        logger: Logger used for progress messages.

    Returns:
        ``(train_e, val_e, test_e)``; each maps a domain name to a tuple
        ``(doc_list, docid_map)``. ``doc_list`` holds the document dicts in
        file order; ``docid_map`` maps each ``document_id`` to its index in
        ``doc_list`` (if an id repeats, the later index wins).
    """
    doc_path = os.path.join(data_dir, 'documents')

    def read_entities(domains):
        doc = dict()
        for domain in domains:
            # Explicit UTF-8 (platform default may differ); skip blank lines.
            with open(os.path.join(doc_path, domain + '.json'), encoding='utf-8') as f:
                doc_list = [json.loads(line) for line in f if line.strip()]
            docid_map = {d['document_id']: i for i, d in enumerate(doc_list)}

            doc[domain] = (doc_list, docid_map)
            logger.info("Read {} entities from {}".format(len(doc_list), domain))

        # One key per requested domain by construction, so no post-hoc
        # length check is needed.
        return doc

    logger.info("Loading train entities")
    train_e = read_entities(train_domains)
    logger.info("Loading val entities")
    val_e = read_entities(val_domains)
    logger.info("Loading test entities")
    test_e = read_entities(test_domains)

    return train_e, val_e, test_e

def select_field(data, key1, key2=None):
    """Collect ``example[key1]`` (or ``example[key1][key2]``) from every item.

    Returns a new list in the same order as ``data``.
    """
    picked = []
    for example in data:
        value = example[key1]
        if key2 is not None:
            value = value[key2]
        picked.append(value)
    return picked

def get_context_representation(sample, tokenizer, max_seq_length):
    """Build the fixed-length BERT input for a mention in its context.

    The mention is wrapped in ENT_START_TAG / ENT_END_TAG and surrounded by as
    much left and right context as fits. The budget is split roughly evenly,
    and whatever one side does not need is given to the other. The result is
    framed as ``[CLS] ... [SEP]`` and the ids are zero-padded to
    ``max_seq_length``.

    Args:
        sample: dict with ``"mention"``, ``"context_left"`` and
            ``"context_right"`` strings.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        max_seq_length: length of the returned ``ids``.

    Returns:
        ``{"tokens": [...], "ids": [...]}``. ``tokens`` are the unpadded tokens
        (including ``[CLS]``/``[SEP]``); ``ids`` has length ``max_seq_length``.
    """
    mention_tokens = []
    if sample["mention"]:
        mention_body = tokenizer.tokenize(sample["mention"])
        # Leave room for [CLS], [SEP] and the two entity tags, so an over-long
        # mention cannot overflow the sequence or make the quotas negative.
        mention_body = mention_body[: max(max_seq_length - 4, 0)]
        mention_tokens = [ENT_START_TAG] + mention_body + [ENT_END_TAG]

    context_left = tokenizer.tokenize(sample["context_left"])
    context_right = tokenizer.tokenize(sample["context_right"])

    left_quota = (max_seq_length - len(mention_tokens)) // 2 - 1
    right_quota = max_seq_length - len(mention_tokens) - left_quota - 2
    left_add = len(context_left)
    right_add = len(context_right)
    # Hand the unused budget of the shorter side over to the longer side.
    if left_add <= left_quota:
        if right_add > right_quota:
            right_quota += left_quota - left_add
    else:
        if right_add <= right_quota:
            left_quota += right_quota - right_add

    # Keep the *last* left_quota tokens of the left context. The slice
    # `context_left[-left_quota:]` is wrong when left_quota == 0, because
    # `[-0:]` selects the whole list.
    left_start = max(len(context_left) - left_quota, 0)
    context_tokens = (
        context_left[left_start:] + mention_tokens + context_right[:right_quota]
    )

    context_tokens = ["[CLS]"] + context_tokens + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(context_tokens)
    padding = [0] * (max_seq_length - len(input_ids))
    input_ids += padding
    assert len(input_ids) == max_seq_length

    return {"tokens": context_tokens, "ids": input_ids}

def get_entity_representation(sample, tokenizer, max_seq_length):
    """Build the fixed-length BERT input for an entity document.

    Layout is ``[CLS] <title> ENT_TITLE_TAG <text> [SEP]``. The title/text
    portion is cut to ``max_seq_length - 2`` tokens, and the ids are
    zero-padded to ``max_seq_length``.

    Returns:
        ``{"tokens": [...], "ids": [...]}``; ``tokens`` is the unpadded list.
    """
    # Text is tokenized before the title, matching the established call order.
    body_tokens = tokenizer.tokenize(sample['text'])
    head_tokens = tokenizer.tokenize(sample['title'])

    budget = max_seq_length - 2
    joined = head_tokens + [ENT_TITLE_TAG] + body_tokens
    framed = ["[CLS]", *joined[:budget], "[SEP]"]

    ids = tokenizer.convert_tokens_to_ids(framed)
    ids += [0] * (max_seq_length - len(ids))
    assert len(ids) == max_seq_length

    return {"tokens": framed, "ids": ids}

def process_m_data(ms, es, tokenizer, max_len):
    """Convert per-domain mentions into id tensors.

    Args:
        ms: dict mapping domain -> list of mention dicts (as produced by
            ``load_mentions``).
        es: dict mapping domain -> ``(doc_list, docid_map)`` (as produced by
            ``load_entities``). Only ``docid_map`` is used, to turn each
            mention's ``label_document_id`` into a document index.
        tokenizer: passed through to ``get_context_representation``.
        max_len: sequence length of every encoded context.

    Returns:
        ``(processed_m, total_num)``. ``processed_m[domain]`` is
        ``{"tokens": LongTensor of shape (n, max_len), "g_doc": LongTensor of
        shape (n,)}``; ``total_num`` is the mention count over all domains.
    """
    processed_m = dict()
    total_num = 0

    print("Convert mention to tensor")
    for domain, mentions in ms.items():
        docid_map = es[domain][1]
        context_ids = []
        gold_idx = []

        for m in tqdm(mentions, desc=domain):
            context_ids.append(get_context_representation(m, tokenizer, max_len)["ids"])
            gold_idx.append(docid_map[m['label_document_id']])
        total_num += len(gold_idx)

        # reshape keeps the tokens tensor 2-D ((0, max_len)) for an empty domain.
        processed_m[domain] = {
            "tokens": torch.tensor(context_ids, dtype=torch.long).reshape(-1, max_len),
            "g_doc": torch.tensor(gold_idx, dtype=torch.long),
        }

    return processed_m, total_num

def process_e_data(es, tokenizer, max_len):
    """Convert per-domain entity documents into id tensors.

    Args:
        es: dict mapping domain -> ``(doc_list, docid_map)`` (as produced by
            ``load_entities``). Only ``doc_list`` (all entities) is used.
        tokenizer: passed through to ``get_entity_representation``.
        max_len: sequence length of every encoded entity.

    Returns:
        ``(processed_e, total_num)``. ``processed_e[domain]`` is
        ``{"tokens": LongTensor of shape (n, max_len)}``; ``total_num`` is the
        entity count over all domains.
    """
    processed_e = dict()
    total_num = 0

    print("Convert entity to tensor")
    for domain, entry in es.items():
        entities = entry[0]  # the full list of entity documents
        entity_ids = [
            get_entity_representation(e, tokenizer, max_len)["ids"]
            for e in tqdm(entities, desc=domain)
        ]
        total_num += len(entity_ids)

        # reshape keeps the tensor 2-D ((0, max_len)) for an empty domain.
        processed_e[domain] = {
            "tokens": torch.tensor(entity_ids, dtype=torch.long).reshape(-1, max_len),
        }

    return processed_e, total_num


if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # load_mentions / load_entities take a logger as their second argument.
    train_m, val_m, test_m = load_mentions('debug_dataset', logger)
    train_e, val_e, test_e = load_entities('debug_dataset', logger)

    from pytorch_transformers.tokenization_bert import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

    # process_m_data / process_e_data return (per-domain tensors, total count).
    train_m_processed, _ = process_m_data(train_m, train_e, tokenizer, 128)
    print(train_m_processed['final_fantasy']['tokens'])
    print(train_m_processed['final_fantasy']['g_doc'])

    train_e_processed, _ = process_e_data(train_e, tokenizer, 128)

    # NOTE(review): the disabled debug-dataset generator below writes mentions
    # to ./debug_dataset/mentions/<split>.json, but load_mentions reads
    # <data_dir>/blink_format/<split>.jsonl. Align the paths before re-enabling.
    '''
    # generate debug dataset
    train_m_debug, val_m_debug, test_m_debug = dict(), dict(), dict()
    train_e_debug, val_e_debug, test_e_debug = dict(), dict(), dict()
    
    # randomly sample 200 cases from each domain
    sample_num = 200
    import random
    for domain in train_m.keys():
        sampled_m = random.sample(train_m[domain], sample_num)
        doc, docid_map = train_e[domain][0], train_e[domain][1]
        train_m_debug[domain], train_e_debug[domain] = [], []
        for m in sampled_m:
            # context_document_id, label_document_id
            train_m_debug[domain].append(m)
            train_e_debug[domain].append(doc[docid_map[m['context_document_id']]])
            train_e_debug[domain].append(doc[docid_map[m['label_document_id']]])

    for domain in val_m.keys():
        sampled_m = random.sample(val_m[domain], sample_num)
        doc, docid_map = val_e[domain][0], val_e[domain][1]
        val_m_debug[domain], val_e_debug[domain] = [], []
        for m in sampled_m:
            # context_document_id, label_document_id
            val_m_debug[domain].append(m)
            val_e_debug[domain].append(doc[docid_map[m['context_document_id']]])
            val_e_debug[domain].append(doc[docid_map[m['label_document_id']]])

    for domain in test_m.keys():
        sampled_m = random.sample(test_m[domain], sample_num)
        doc, docid_map = test_e[domain][0], test_e[domain][1]
        test_m_debug[domain], test_e_debug[domain] = [], []
        for m in sampled_m:
            # context_document_id, label_document_id
            test_m_debug[domain].append(m)
            test_e_debug[domain].append(doc[docid_map[m['context_document_id']]])
            test_e_debug[domain].append(doc[docid_map[m['label_document_id']]])

    # save debug mentions and entities
    def save_mentions(mentions, save_dir):
        with open(save_dir, 'w') as f:
            for domain in mentions.keys():
                for m in mentions[domain]:
                    f.write(json.dumps(m)+'\n')

    def save_entities(es, save_dir):
        for domain in es.keys():
            with open(os.path.join(save_dir, domain+'.json'), 'w') as f:
                for e in es[domain]:
                    f.write(json.dumps(e)+'\n')

    print("Saving debug train mentions...")
    save_mentions(train_m_debug, './debug_dataset/mentions/train.json')
    print("Saving debug val mentions...")
    save_mentions(val_m_debug, './debug_dataset/mentions/val.json')
    print("Saving debug test mentions...")
    save_mentions(test_m_debug, './debug_dataset/mentions/test.json')

    print("Saving debug train entities...")
    save_entities(train_e_debug, './debug_dataset/documents')
    print("Saving debug val entities...")
    save_entities(val_e_debug, './debug_dataset/documents')
    print("Saving debug test entities...")
    save_entities(test_e_debug, './debug_dataset/documents')
    '''



