import torch
import json

from time import time
from tqdm import tqdm
from random import shuffle, randint
from transformers import BertTokenizer, BertModel

from source.model import noname
from source.parser import parse
from source.sampler import sampler
from source.count import counter
from source.zero_shot import zero_shot_calculator
from source.function import *

def test_zero_shot(config, datapath, model, tokenizer, data, types):
    """Run zero-shot evaluation on the test split and return the counter report.

    The test file is ``<datapath><config.dataset>/test_zh.json`` when
    ``config.zh`` is truthy and ``.../test.json`` otherwise.  Batches are
    drawn from ``sampler(...).n_way_k_shot()``.  For every batch the model is
    run to obtain embeddings, and the predictions that get counted come from
    ``zero_shot_calculator.get_output(embeddings)`` rather than from the
    model's own first output.

    Args:
        config: Parsed options; this function reads ``zh``, ``dataset``,
            ``load`` and ``dev`` from it.
        datapath: Prefix that is string-concatenated in front of the dataset
            directory name.
        model: Model to evaluate; it is switched to ``eval()`` mode here.
        tokenizer: Tokenizer handed to ``get_input`` and
            ``zero_shot_calculator``.
        data: Passed through unchanged to ``zero_shot_calculator``.
        types: Type inventory handed to the sampler, ``get_input`` and
            ``zero_shot_calculator``.

    Returns:
        Whatever ``counter.output()`` returns for the accumulated
        predictions.
    """
    print('test:')
    test_file = '/test_zh.json' if config.zh else '/test.json'
    # A context manager closes the handle deterministically; the encoding is
    # explicit so non-ASCII JSON does not depend on the platform locale.
    with open(datapath + config.dataset + test_file, encoding='utf-8') as handle:
        test_data = handle.readlines()

    spj = zero_shot_calculator(config, model, tokenizer, data, types)
    my_sampler = sampler(test_data, types, batch_size=100)

    model.eval()
    cnter = counter()

    # Loop-invariant: the "without_ent" path is used when no checkpoint
    # option (config.load) was given.
    without_ent = config.load is None

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch in my_sampler.n_way_k_shot():
            model_input, mask_tensor, ans, pos = get_input(
                batch, tokenizer, types, without_ent=without_ent)
            model_input = model_input.to(config.dev)
            mask_tensor = mask_tensor.to(config.dev)
            ans = ans.to(config.dev)

            # Only the embeddings are needed; the direct model output is
            # replaced by the zero-shot calculator's prediction.
            _, embeddings = model(
                model_input, pos, mask_tensor=mask_tensor, without_ent=without_ent)
            model_output = spj.get_output(embeddings)

            cnter.count(model_output.view(-1).tolist(), ans.view(-1).tolist())

    print("on all types:")
    return cnter.output()

def test_contrastive(model, data, types, tokenizer, device='cuda:1'):
    """Print the mean contrastive (similarity) loss of ``model`` over ``data``.

    Batches come from ``sampler(data, types, batch_size=64).get_batches()``.
    For each batch the embeddings returned by the model are compared with the
    target produced by ``get_sim`` using ``sim_loss_function``; the per-batch
    losses are averaged over the number of batches and printed.

    Args:
        model: Model to evaluate; it is switched to ``eval()`` mode here and
            is expected to already live on ``device``.
        data: Raw examples handed to the sampler.
        types: Type inventory handed to the sampler and ``get_input``.
        tokenizer: Tokenizer handed to ``get_input``.
        device: Device the batch tensors are moved to.  Defaults to
            ``'cuda:1'``, the value that used to be hard-coded.

    Returns:
        None.  The result is only printed.
    """
    model.eval()

    total_loss = 0
    num_batches = 0

    my_sampler = sampler(data, types, batch_size=64)

    print('begin test')
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for samp in my_sampler.get_batches():
            # The gold answers returned by get_input are not used here.
            model_input, mask_tensor, _, pos = get_input(samp, tokenizer, types)
            model_input = model_input.to(device)
            mask_tensor = mask_tensor.to(device)
            sim = get_sim(samp).to(device)

            _, embeddings = model(model_input, pos, mask_tensor=mask_tensor)
            total_loss += sim_loss_function(embeddings, sim).item()
            num_batches += 1

    if num_batches == 0:
        # Avoid ZeroDivisionError when the sampler yields nothing.
        print('no batches to evaluate')
        return

    print('avg_loss = ', total_loss / num_batches)

def test(model, datapath, tokenizer, types, config):
    """Evaluate ``model`` on the test split and return the counter report.

    The test file is ``<datapath><config.dataset>/test_zh.json`` when
    ``config.zh`` is truthy and ``.../test.json`` otherwise.  The lines are
    shuffled in place; when ``config.showall`` is set and the dataset is
    ``'fewnerd'`` only the first 2000 shuffled lines are kept.  Batches are
    drawn from ``sampler(...).n_way_k_shot()``.

    Args:
        model: Model to evaluate; it is switched to ``eval()`` mode here.
        datapath: Prefix that is string-concatenated in front of the dataset
            directory name.
        tokenizer: Tokenizer handed to ``get_input``.
        types: Type inventory handed to the sampler and ``get_input``.
        config: Parsed options; this function reads ``zh``, ``dataset``,
            ``showall`` and ``dev`` from it.

    Returns:
        Whatever ``counter.output()`` returns for the accumulated
        predictions.
    """
    test_file = '/test_zh.json' if config.zh else '/test.json'
    # A context manager closes the handle deterministically; the encoding is
    # explicit so non-ASCII JSON does not depend on the platform locale.
    with open(datapath + config.dataset + test_file, encoding='utf-8') as handle:
        test_data = handle.readlines()

    shuffle(test_data)
    if config.showall and config.dataset == 'fewnerd':
        # Cap the evaluation set; slicing is a no-op for shorter lists.
        test_data = test_data[:2000]

    print('testing : ')

    my_sampler = sampler(test_data, types, batch_size=500)

    model.eval()
    cnter = counter()

    total_loss = 0.

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch in my_sampler.n_way_k_shot():
            model_input, mask_tensor, ans, pos = get_input(batch, tokenizer, types)
            model_input = model_input.to(config.dev)
            mask_tensor = mask_tensor.to(config.dev)
            ans = ans.to(config.dev)

            # The embeddings returned alongside the output are not needed.
            model_output, _ = model(model_input, pos, mask_tensor=mask_tensor)
            total_loss += loss_function(model_output, ans).item()
            cnter.count(model_output.view(-1).tolist(), ans.view(-1).tolist())

    # NOTE(review): the summed per-batch loss is divided by the number of
    # loaded lines, not by the number of batches -- confirm this matches the
    # reduction used inside loss_function.
    num_examples = len(test_data)
    avg_loss = total_loss / num_examples if num_examples else float('nan')
    print("avg_loss on test : ", avg_loss)
    print("on all types:")
    # The former "on 9 types" report sat after this return and could never
    # run; its counter was never fed, so both were removed.
    return cnter.output()

def main():
    """Load the saved model and print its contrastive loss on open_entity.

    Steps: parse the command line, build the multilingual BERT tokenizer with
    the two extra special tokens, load the type inventory from
    ``./types.json``, load the pickled model from
    ``./model/cross_large_3.pth`` onto ``cuda:1`` with BERT fine-tuning
    switched off, then run ``test_contrastive`` on
    ``open_entity/test.json``.
    """
    # The parsed options are not used below; the call is kept because
    # parse() presumably validates the command line -- TODO confirm.
    config = parse()
    path = '.'
    datapath = ''

    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    tokenizer.add_special_tokens({'additional_special_tokens': ["<ent>", "<blank>"]})
    types = loads_from_file(path + '/types.json')

    # SECURITY NOTE: torch.load unpickles the whole model object, which can
    # execute arbitrary code; only load checkpoints from a trusted source.
    # The map_location callback keeps every storage on CPU during loading,
    # and the model is moved to the GPU afterwards.
    model = torch.load(
        './model/cross_large_3.pth',
        map_location=lambda storage, loc: storage,
    ).to('cuda:1')
    model.activate_bert_fine_tuning(False)

    # A context manager closes the handle deterministically; the encoding is
    # explicit so the JSON does not depend on the platform locale.
    with open(datapath + 'open_entity/test.json', encoding='utf-8') as handle:
        test_data = handle.readlines()
    test_contrastive(model, test_data, types, tokenizer)

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
