import torch
from src.model.universal_model import UniversalModel

def truncate_at_stop(batched_prediction):
    """Cut each instance's predicted step sequence right after its first stop step.

    Every prediction step is unpacked as a 4-tuple ``(left, right, op_id, stop_id)``.
    The first step whose ``stop_id == 1`` is kept and all later steps are dropped;
    instances that contain no stop step are kept whole.

    Args:
        batched_prediction: sequence of per-instance step sequences, e.g. the value
            returned by ``get_batched_prediction_consider_multiple_m0``.

    Returns:
        A new list with one (sliced) step sequence per instance. The input is not
        mutated.
    """
    truncated = []
    for inst_predictions in batched_prediction:
        end = len(inst_predictions)
        for p, (_left, _right, _op_id, stop_id) in enumerate(inst_predictions):
            if stop_id == 1:
                end = p + 1  # keep the stop step itself
                break
        truncated.append(inst_predictions[:end])
    return truncated


if __name__ == '__main__':
    # Evaluate a trained UniversalModel on the large_math test split and print, for
    # each example, the beam-search output and the prediction decoded from the
    # forward-pass logits.
    import random

    import numpy as np
    from torch.utils.data import DataLoader
    from transformers import BertTokenizer

    from src.data.universal_dataset import UniversalDataset
    from universal_main import get_batched_prediction_consider_multiple_m0

    # Seed every RNG so repeated runs are comparable.
    random.seed(42)
    torch.manual_seed(42)
    np.random.seed(42)

    dev = torch.device('cpu')

    # Operator label vocabulary; num_labels is derived from its length (6).
    uni_labels = ['+', '-', '-_rev', '*', '/', '/_rev']
    # Constant vocabulary; its length (13) is used both as the model's constant_num
    # and by the decoder below, so the two cannot drift apart.
    constants = ['5.0', '10.0', '2.0', '8.0', '30.0', '1.0', '6.0',
                 '7.0', '12.0', '4.0', '31.0', '3.14', '3.0']
    constant2id = {c: idx for idx, c in enumerate(constants)}
    constant_values = [float(c) for c in constants]

    model = UniversalModel.from_pretrained('model_files/ours_fp16_best',
                                           num_labels=len(uni_labels),
                                           constant_num=len(constants),
                                           diff_param_for_height=False,
                                           height=10,
                                           add_replacement=True,
                                           consider_multiple_m0=True)
    model.to(dev)  # keep the model on the same device as the inputs below
    model.eval()

    tokenizer = BertTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
    eval_dataset = UniversalDataset(file="data/large_math/large_math_test_nodup.json",
                                    tokenizer=tokenizer,
                                    number=-1,
                                    filtered_steps=None,
                                    constant2id=constant2id,
                                    constant_values=constant_values,
                                    add_replacement=True,
                                    use_incremental_labeling=True,
                                    add_new_token=False)
    valid_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False, num_workers=0,
                                  collate_fn=eval_dataset.collate_function)

    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        for feature in valid_dataloader:
            # Move each input tensor to the device once; both calls below share them.
            inputs = dict(
                input_ids=feature.input_ids.to(dev),
                attention_mask=feature.attention_mask.to(dev),
                token_type_ids=feature.token_type_ids.to(dev),
                variable_indexs_start=feature.variable_indexs_start.to(dev),
                variable_indexs_end=feature.variable_indexs_end.to(dev),
                num_variables=feature.num_variables.to(dev),
                variable_index_mask=feature.variable_index_mask.to(dev),
                labels=feature.labels.to(dev),
                label_height_mask=feature.label_height_mask.to(dev),
            )

            beam_output = model.beam_search(**inputs, return_dict=True, is_eval=True, num_beams=3)
            print(beam_output)

            all_logits = model(**inputs, return_dict=True, is_eval=True).all_logits
            batched_prediction = get_batched_prediction_consider_multiple_m0(
                feature=feature,
                all_logits=all_logits,
                constant_num=len(constant_values),
                add_replacement=True,
            )
            # Post-process: drop the extra steps predicted after the stop step.
            print(truncate_at_stop(batched_prediction))

