import re
import torch
import pickle
import numpy as np
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel


def splitCamelCase(string):
    """Split a camelCase / PascalCase identifier into lower-case, space-separated words.

    A space is inserted before every capitalised word, before every run of
    digits that follows a non-digit, and between a lower-case letter or digit
    and a following upper-case letter; the result is then lower-cased.

    >>> splitCamelCase("placeOfBirth")
    'place of birth'
    >>> splitCamelCase("HTTPResponse")
    'http response'
    >>> splitCamelCase("top100")
    'top 100'
    """
    # Capitalised word after any character: "HTTPResponse" -> "HTTP Response".
    string = re.sub(r'(.)([A-Z][a-z]+)', r'\1 \2', string)
    # Digit run after a NON-digit: "top100" -> "top 100". Matching any char here
    # ("(.)") would let it grab a digit and split a leading number ("100" -> "1 00").
    string = re.sub(r'([^0-9])([0-9]+)', r'\1 \2', string)
    # Remaining lower/digit -> upper boundaries: "fooBAR" -> "foo BAR".
    string = re.sub(r'([a-z0-9])([A-Z])', r'\1 \2', string)
    return string.lower()


class PathRanker(nn.Module):
    """Rank candidate relation paths for a question.

    The question and every candidate path are each represented by two
    concatenated views, and ``forward`` scores candidates by a dot product:

    * LM view -- the frozen BERT ``[CLS]`` vector passed through a linear
      layer (``phrase_lm_encoder`` for the question, ``relation_lm_encoder``
      for paths).
    * KG view -- for the question, a 3-layer MLP over the ``[CLS]`` vector;
      for a candidate path, the masked sum of its pre-trained (frozen)
      relation embeddings.

    When ``use_dict`` is set, the question views are additionally fused with
    knowledge read from a key/value dictionary via top-k attention.

    NOTE(review): the concatenated question and path vectors must have equal
    width for the final dot product, and the 'cat'/'gate' fusion layers are
    sized with ``bert_size`` while ``P_lm`` is ``hidden_size`` wide --
    presumably ``hidden_size == bert_size``; nothing here checks either.
    """

    # Number of dictionary entries each question attends to (clamped to the
    # dictionary size at run time).
    DICT_TOPK = 10
    # Rows of the pre-trained relation table kept before the property
    # embeddings are appended (the cut only happens when properties are used).
    NUM_PRETRAINED_RELATIONS = 9822

    def __init__(self, config):
        super().__init__()
        self.config = config

        # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
        # BERT (frozen)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        # Register [R] / [Q] as extra special tokens and grow the embedding matrix to match.
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[R]', '[Q]']})
        self.bert = AutoModel.from_pretrained("bert-base-cased")
        self.bert.resize_token_embeddings(len(self.tokenizer))
        for p in self.bert.parameters():
            p.requires_grad = False

        # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
        # KGE (frozen pre-trained tables)
        model_name = 'RotatE'
        is_use_property = True
        self.kge_rel_embedding = self._load_relation_embedding(model_name, is_use_property)
        self.kge_ent_embedding = self._load_entity_embedding(model_name)

        # bert-level projections
        self.relation_lm_encoder = nn.Linear(config.bert_size, config.hidden_size)
        self.phrase_lm_encoder = nn.Linear(config.bert_size, config.hidden_size)

        # rotate-level: question [CLS] vector -> KGE space
        self.phrase_kg_encoder1 = nn.Linear(config.bert_size, config.hidden_size)
        self.phrase_kg_encoder2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.phrase_kg_encoder3 = nn.Linear(config.hidden_size, config.kge_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.1)

        # DICT
        self.use_dict = True
        # These are plain tensor attributes (not registered buffers), so
        # ``model.to(device)`` never moves them: load them onto the device used
        # for the BERT inputs they are multiplied with.
        self.Kmatrix = torch.load('/DATA_PATH/ImRL/KGE/dict_keys.pt', map_location=config.device)
        self.V_BERT_matrix = torch.load('/DATA_PATH/ImRL/KGE/dict_values_bert.pt', map_location=config.device)
        self.V_ROTATE_matrix = torch.load('/DATA_PATH/ImRL/KGE/dict_values_rotate.pt', map_location=config.device)

        self.cat_lm = nn.Linear(config.bert_size * 2, config.bert_size)
        self.cat_kg = nn.Linear(config.kge_size * 2, config.kge_size)

        self.gate_lm = nn.Linear(config.bert_size * 2, 2)
        self.gate_kg = nn.Linear(config.kge_size * 2, 2)

        self.softmax = nn.Softmax(dim=1)

    def _load_relation_embedding(self, model_name, is_use_property):
        """Build the frozen relation table: pre-trained rows (+ property rows) + unseen rows."""
        relation_path = '/DATA_PATH/ImRL/KGE/ckpts/' + model_name + '/relation.npy'
        table = np.load(relation_path, allow_pickle=True)
        if is_use_property:
            property_embed = np.load('/DATA_PATH/ImRL/data/kge/property_embedding.npy', allow_pickle=True)
            # Property vectors are scaled by 1e-2 for TransE_l1 / RotatE and used as-is otherwise.
            if model_name in ('TransE_l1', 'RotatE'):
                property_embed = property_embed * (1e-2)
            table = np.vstack((table[:self.NUM_PRETRAINED_RELATIONS], property_embed))
        unseen_embed = np.load('/DATA_PATH/ImRL/data/kge/unseen_relation.npy', allow_pickle=True)
        table = np.vstack((table, unseen_embed * (1e-2)))

        embedding = nn.Embedding(table.shape[0], table.shape[1])
        embedding.weight.data.copy_(torch.from_numpy(table))
        embedding = embedding.to(self.config.device)
        # Freeze the weight itself: ``module.requires_grad = False`` would only
        # set a plain attribute and leave the table trainable.
        embedding.weight.requires_grad_(False)
        return embedding

    def _load_entity_embedding(self, model_name):
        """Build the frozen entity table, keeping the first ``config.kge_size`` columns."""
        entity_path = '/DATA_PATH/ImRL/KGE/ckpts/' + model_name + '/entity.npy'
        table = np.load(entity_path, allow_pickle=True)
        kge_size = self.config.kge_size

        embedding = nn.Embedding(table.shape[0], kge_size)
        embedding.weight.data.copy_(torch.from_numpy(table[:, :kge_size]))
        embedding = embedding.to(self.config.device)
        embedding.weight.requires_grad_(False)
        return embedding

    def encode_path_kg(self, input):
        """Compose two relation rotations given as phases (radians).

        Each relation phase theta is mapped to the unit complex number
        ``cos(theta) + i*sin(theta)``; the two hops are multiplied, and the
        angle of the product -- ``theta1 + theta2`` wrapped to (-pi, pi] -- is
        returned. Not called by ``forward`` in this class.

        Args:
            input: tensor whose dim 1 holds the two hops (size 2), e.g.
                ``batch * 2 * dim`` -- TODO confirm against callers.

        Returns:
            Tensor of composed phases with dim 1 of size 1.
        """
        r1, r2 = torch.chunk(input, 2, dim=1)
        r1_re, r1_im = torch.cos(r1), torch.sin(r1)
        r2_re, r2_im = torch.cos(r2), torch.sin(r2)
        # (r1_re + i*r1_im) * (r2_re + i*r2_im)
        r_re = r1_re * r2_re - r1_im * r2_im
        r_im = r1_re * r2_im + r1_im * r2_re
        # atan2 recovers the full angle; asin(r_im) alone would fold it into [-pi/2, pi/2].
        return torch.atan2(r_im, r_re)

    def _encode_cls(self, texts):
        """Run the frozen BERT over *texts* and return the [CLS] vectors as ``n * 1 * width``."""
        encoded = self.tokenizer(texts, padding=True, return_tensors="pt")
        input_ids = encoded.input_ids.to(self.config.device)
        attention_mask = encoded.attention_mask.to(self.config.device)
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return output[:, 0, :].unsqueeze(1)

    def _sum_path_relation_embeddings(self, path_ids, path_masks):
        """Look up each path's relation vectors, weight them by the mask and sum over the path."""
        ids = torch.as_tensor(path_ids, device=self.config.device)
        masks = torch.as_tensor(path_masks, device=self.config.device).unsqueeze(2)
        weighted = self.kge_rel_embedding(ids) * masks
        return weighted.sum(dim=1)  # batch * relation-embedding width

    def _retrieve_dict_knowledge(self, query):
        """Attend over the top-k dictionary keys and return the (LM, KG) value mixtures."""
        scores = torch.matmul(query, self.Kmatrix.transpose(0, 1))  # rows * dict_size
        k = min(self.DICT_TOPK, scores.shape[1])
        _, top_idx = scores.topk(k=k, dim=1)
        outside_topk = torch.ones_like(scores, dtype=torch.bool)
        rows = torch.arange(scores.shape[0], device=scores.device).unsqueeze(-1)
        outside_topk[rows, top_idx] = False
        # -inf (not 0) so entries outside the top-k get exactly zero softmax weight.
        scores = scores.masked_fill(outside_topk, float('-inf'))
        weight = self.softmax(scores)
        know_lm = torch.matmul(weight, self.V_BERT_matrix)
        know_kg = torch.matmul(weight, self.V_ROTATE_matrix)
        return know_lm, know_kg

    def _fuse(self, local, know, cat_layer, gate_layer):
        """Fuse a question view with retrieved knowledge per ``config.fusion_method``.

        'mean' averages, 'cat' projects the concatenation, 'gate' mixes with a
        learned softmax gate; any other value returns *local* unchanged.
        """
        method = self.config.fusion_method
        if method == 'mean':
            return (local + know) / 2
        if method == 'cat':
            return cat_layer(torch.cat((local, know), dim=1))
        if method == 'gate':
            # Keep the gate as a (rows, 1) column so it broadcasts per row.
            gate = self.softmax(gate_layer(torch.cat((local, know), dim=1)))[:, 1:]
            return gate * local + (1 - gate) * know
        return local

    def forward(self, relation_phrases, questions, candidate_relation_path_names, candidate_relation_path_ids, candi_relation_path_masks, positive_index):
        """Score every candidate relation path against the question.

        Args:
            relation_phrases: unused here; kept for interface compatibility.
            questions: only ``questions[0]`` is encoded.
            candidate_relation_path_names: texts given to the tokenizer, one per candidate.
            candidate_relation_path_ids: relation ids per candidate path, looked up in
                ``kge_rel_embedding`` (assumed ``num_candidates * path_len`` -- TODO confirm).
            candi_relation_path_masks: same layout as the ids; each looked-up relation
                vector is multiplied by its mask value before summing over the path.
            positive_index: row of the gold candidate, or -1 when there is none.

        Returns:
            ``(score, P_kg, R_kg_gold)``: the question-by-candidate score matrix
            flattened to one row; the question's KG view before dictionary fusion
            (dim 0 squeezed); and the gold candidate's KG view, or ``None`` when
            ``positive_index == -1``.
        """
        # question
        Q_embedding = self._encode_cls(questions[0])  # 1 * 1 * bert
        P_lm = self.phrase_lm_encoder(Q_embedding).squeeze(1)  # 1 * hidden
        P_kg = self.relu(self.phrase_kg_encoder1(Q_embedding))
        P_kg = self.relu(self.phrase_kg_encoder2(P_kg))
        P_kg = self.phrase_kg_encoder3(self.dropout(P_kg)).squeeze(1)  # 1 * kge_size

        # candidate relation paths
        R_embedding = self._encode_cls(candidate_relation_path_names)  # batch * 1 * bert
        R_lm = self.relation_lm_encoder(R_embedding).squeeze(1)  # batch * hidden
        R_kg = self._sum_path_relation_embeddings(candidate_relation_path_ids, candi_relation_path_masks)
        R_kg_gold = None if positive_index == -1 else R_kg[positive_index, :]

        # dictionary fusion; without the dictionary the plain views are used
        P_lm_fused, P_kg_fused = P_lm, P_kg
        if self.use_dict:
            know_lm, know_kg = self._retrieve_dict_knowledge(Q_embedding.squeeze(1))
            P_lm_fused = self._fuse(P_lm, know_lm, self.cat_lm, self.gate_lm)
            P_kg_fused = self._fuse(P_kg, know_kg, self.cat_kg, self.gate_kg)

        final_Q = torch.cat((P_lm_fused, P_kg_fused), dim=1)
        final_R = torch.cat((R_lm, R_kg), dim=1)
        score = final_Q.mm(final_R.transpose(0, 1))
        return score.reshape(1, -1), P_kg.squeeze(0), R_kg_gold