from collections import deque
import torch
import torch.utils.data as data
import numpy as np
from torch import Tensor
import random
from tqdm import tqdm
from sentence_transformers import InputExample

# Cut-off depths k for which print_metrics reports NDCG@k, MAP@k and Recall@k.
topks = [3, 5, 10, 20, 50, 100, 500]

def cosine_similarity(vector1, vector2):
    """Return the cosine similarity of two 1-D tensors as a 0-dim tensor.

    Args:
        vector1, vector2: 1-D tensors of equal length (``torch.dot`` requires it).

    Returns:
        dot(v1, v2) / (||v1|| * ||v2||). When either vector has zero norm the
        cosine is undefined; 0 is returned instead of the NaN a plain 0/0
        division would produce, so callers that rank these values never see NaN.
    """
    dot_product = torch.dot(vector1, vector2)
    denominator = torch.linalg.norm(vector1) * torch.linalg.norm(vector2)
    if denominator == 0:
        # Zero vector: no direction, treat as "no similarity" rather than NaN.
        return torch.zeros_like(dot_product)
    return dot_product / denominator

def convert_to_rank_score(org_results, raw=0.1):
    """Replace each item's score by a weight that depends only on its rank.

    Items are ranked by descending score; the item at 1-based rank ``r``
    receives ``(1 / r) ** raw``, so the top item always gets 1.0 and weights
    decay monotonically with rank.

    Args:
        org_results: dict mapping item -> score (anything ``float()`` accepts).
        raw: exponent controlling how quickly the weight decays with rank.

    Returns:
        dict mapping every key of ``org_results`` (unchanged) to its rank
        weight as a numpy float64. An empty input yields an empty dict.
    """
    if not org_results:
        return {}

    items = list(org_results.keys())
    # Keep keys and scores apart: packing (key, score) pairs into a single
    # numpy array coerces both to one dtype (e.g. str), rewriting the keys.
    scores = np.asarray([float(org_results[key]) for key in items], dtype=float)

    # One weight per actual rank (no fixed 3000-entry table), so inputs with
    # more than 3000 items no longer fail with a shape mismatch.
    rank_importance = np.asarray([(1 / rank) ** raw for rank in range(1, len(items) + 1)])

    # Stable sort: tied scores keep their input order on every platform.
    order = np.argsort(-scores, kind="stable")

    rank_score = np.empty(len(items), dtype=float)
    rank_score[order] = rank_importance
    return dict(zip(items, rank_score))

def distinveness(term, c, hd_c, k=20):
    """Weight ``term`` by its count, optionally discounted by ``hd_c``.

    Returns ``c[term]`` unchanged when ``term`` has no entry in ``hd_c``;
    otherwise ``c[term] * log(k / hd_c[term])`` (presumably an IDF-style
    penalty for terms that are common across ``k`` groups — TODO confirm).
    The function name is kept as-is for existing callers.
    """
    base = c[term]
    if term in hd_c:
        return base * np.log(k / hd_c[term])
    return base

def lexical_sim(doc2terms, term_set):
    """Sum the document weights of the terms it shares with ``term_set``.

    ``doc2terms`` maps term -> weight for one document; only terms that also
    appear in ``term_set`` contribute. Returns 0.0 when nothing overlaps.
    """
    shared_terms = doc2terms.keys() & term_set
    total = 0.
    for term in shared_terms:
        total = total + doc2terms[term]
    return total

def single_level_assignment(docnumber, parent_topic, cid_list, class_dict, merged_class_dict, topic2topicid_dict, doc2phrase_dic, corpus_emb, topic_embs, mode=0):
    """Score the child topics of ``parent_topic`` for one document; keep the best two.

    For every child of ``parent_topic`` (per ``class_dict``) that has an id in
    ``topic2topicid_dict``:
      * lexical score  = lexical_sim(doc2phrase_dic[doc id], merged_class_dict[child])
      * semantic score = cosine_similarity(corpus_emb[docnumber], topic_embs[child id])
    Each score family is turned into rank weights with convert_to_rank_score and
    combined as ``(0.5 * semantic + 0.5 * mode * lexical) ** 10``; ``mode``
    therefore weights the lexical part (the default 0 ignores it).

    Returns:
        ``(False, False)`` if ``parent_topic`` is not in ``class_dict`` or none of
        its children are in ``topic2topicid_dict``; otherwise ``(True, top)`` where
        ``top`` holds at most two ``[child_topic, score]`` pairs, best first.
    """
    docid = cid_list[docnumber]

    if parent_topic not in class_dict:
        return False, False
    child_topics = class_dict[parent_topic]

    lexical_score_dict, semantic_score_dict = {}, {}
    for child in child_topics:
        if child not in topic2topicid_dict:
            continue
        terms_of_child = merged_class_dict[child]
        child_row = topic2topicid_dict[child]
        lexical_part = lexical_sim(doc2phrase_dic[docid], terms_of_child)
        semantic_part = cosine_similarity(corpus_emb[docnumber], topic_embs[child_row])
        lexical_score_dict[child] = lexical_part
        semantic_score_dict[child] = float(semantic_part)

    if not lexical_score_dict:
        return False, False

    lexical_rank = convert_to_rank_score(lexical_score_dict)
    semantic_rank = convert_to_rank_score(semantic_score_dict)

    # Second pass over the original child sequence keeps its order for ties.
    scored = [
        [child, ((semantic_rank[child] * 0.5 + lexical_rank[child] * 0.5 * mode) ** 10)]
        for child in child_topics
        if child in topic2topicid_dict
    ]
    scored.sort(key=lambda pair: -pair[1])

    return True, scored[:2]


def omit_substrings(terms):
    """Drop every term that is contained in another (equal-or-longer) term.

    Terms are first sorted by length (stable), and a term is kept only if it
    does not occur inside any term that follows it in that order. The result
    is ordered shortest-first; exact duplicates collapse to a single copy.
    """
    by_length = sorted(terms, key=len)
    kept = []
    for pos, candidate in enumerate(by_length):
        if any(candidate in longer for longer in by_length[pos + 1:]):
            continue
        kept.append(candidate)
    return kept

def to_np(x):
    """Return the data of tensor ``x`` as a NumPy array (moved to CPU, no grad)."""
    host_tensor = x.data.cpu()
    return host_tensor.numpy()

def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))`` on a tensor."""
    denominator = torch.exp(torch.neg(x)) + 1
    return 1 / denominator

def normalize(x):
    """L2-normalise ``x`` along dim 1, in place, and return it.

    The division is in place, so the caller's tensor is modified and the same
    object is returned. Slices whose norm is zero are left as zeros instead of
    turning into NaN (0 / 0).
    """
    norms = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    # Divide all-zero slices by 1 so they stay zero rather than becoming NaN.
    norms = torch.where(norms == 0, torch.ones_like(norms), norms)
    x /= norms
    return x

def softmax(x):
    """Softmax of tensor ``x`` along dim 1.

    Uses ``torch.softmax``, which subtracts the per-slice maximum internally,
    so large logits no longer overflow to inf/inf = NaN. The whole computation
    stays in torch; a torch tensor is no longer passed to ``np.sum``.
    Non-floating inputs are promoted to the default float dtype, matching what
    ``torch.exp`` would have produced.
    """
    if not torch.is_floating_point(x):
        x = x.to(torch.get_default_dtype())
    return torch.softmax(x, dim=1)

def score_mat_2_rank_mat(score_mat):
    """Convert a score matrix into 0-based descending ranks along the last dim.

    Entry ``[i][j]`` of the result is the position of column ``j`` when row
    ``i`` is sorted by decreasing score (0 = highest score). The result is an
    int64 tensor on the CPU.
    """
    order = torch.argsort(-score_mat, dim=-1).to('cpu')
    ranks = torch.zeros_like(order)
    for i, perm in enumerate(order):
        # perm lists column indices best-first; write each column's position back.
        ranks[i][perm] = torch.arange(len(perm))
    return ranks

def eval_full_score_mat(score_mat, qid_list, cid_list):
    """Turn a dense (query x corpus) score matrix into a nested results dict.

    Returns ``{qid: {cid: float(score)}}`` where row ``r`` belongs to
    ``qid_list[r]`` and column ``c`` to ``cid_list[c]``. Rows sharing a qid are
    merged into the same inner dict (later rows overwrite equal cids).
    A tqdm progress bar is shown over the rows.
    """
    results_dict = {}
    for row in tqdm(range(score_mat.shape[0])):
        per_doc = results_dict.setdefault(qid_list[row], {})
        for col in range(score_mat.shape[1]):
            per_doc[cid_list[col]] = float(score_mat[row][col])
    return results_dict

def print_metrics(metrics):
    """Print one line per metric family (NDCG, MAP, Recall) for every k in topks.

    ``metrics`` is indexed positionally: ``metrics[0]`` holds NDCG values,
    ``metrics[1]`` MAP and ``metrics[2]`` Recall, each keyed like ``'NDCG@10'``.
    Values are rounded to 4 decimals.
    """
    for position, prefix in enumerate(['NDCG@', 'MAP@', 'Recall@']):
        pieces = []
        for k in topks:
            key = prefix + str(k)
            pieces.extend((key + ':', str(round(metrics[position][key], 4)), ', '))
        print(''.join(pieces))

class CLF_dataset(data.Dataset):
    """Dataset over feature rows ``X`` with two label matrices ``Y`` and ``Y2``.

    ``__getitem__`` yields ``(idx, X[idx])`` only; labels for a batch are fetched
    separately with ``get_labels``. ``Y``/``Y2`` must support row indexing and
    ``.todense()`` (presumably scipy sparse matrices — confirm with callers).
    """

    def __init__(self, X, Y, Y2):
        super().__init__()
        self.X, self.Y, self.Y2 = X, Y, Y2

    def __len__(self):
        n_rows = self.X.shape[0]
        return n_rows

    def __getitem__(self, idx):
        features = self.X[idx]
        return idx, features

    def get_labels(self, batch_indices):
        """Return dense ``(Y, Y2)`` label rows for ``batch_indices``."""
        primary = self.Y[batch_indices].todense()
        secondary = self.Y2[batch_indices].todense()
        return primary, secondary
    
class TripletDataset(data.Dataset):
    """Per-query triplets that cycle through positives and hard negatives.

    Each item is ``(qid, pos_id, neg_id, class_label_vec, phrase_label_vec)``
    where the two vectors are multi-hot encodings of the positive document's
    classes (length ``num_class``) and phrases (length ``num_phrase``).

    Every ``__getitem__`` call moves the head of the query's ``'pos'`` and
    ``'hard_neg'`` lists to their tail, so repeated visits walk through the
    candidates in turn. Note that ``train_triplets`` is mutated: its
    ``'pos'``/``'hard_neg'`` entries are replaced by lists, and the hard
    negatives are shuffled with the ``random`` module.
    """

    def __init__(self, train_triplets, doc2class_dict, doc2phrase_dict, q_emb, q_list, c_emb, c_list, num_class, num_phrase):
        self.train_triplets = train_triplets
        self.queries_ids = list(train_triplets.keys())
        self.doc2class_dict = doc2class_dict
        self.doc2phrase_dict = doc2phrase_dict

        for qid in self.queries_ids:
            entry = self.train_triplets[qid]
            entry['pos'] = list(entry['pos'])
            entry['hard_neg'] = list(entry['hard_neg'])
            random.shuffle(entry['hard_neg'])

        # id -> row index into the matching embedding matrix (last duplicate wins).
        self.q_emb, self.q_list = q_emb, q_list
        self.qid2idx = {qid: row for row, qid in enumerate(q_list)}

        self.c_emb, self.c_list = c_emb, c_list
        self.cid2idx = {cid: row for row, cid in enumerate(c_list)}

        self.num_class = num_class
        self.num_phrase = num_phrase

    @staticmethod
    def _rotate(pool):
        """Move the first element of ``pool`` to its end and return it."""
        head = pool.pop(0)
        pool.append(head)
        return head

    def __getitem__(self, item):
        qid = self.queries_ids[item]
        entry = self.train_triplets[qid]

        pos_id = self._rotate(entry['pos'])
        neg_id = self._rotate(entry['hard_neg'])

        class_label_vec = torch.zeros([self.num_class])
        class_label_vec[self.doc2class_dict[pos_id]] = 1

        phrase_label_vec = torch.zeros([self.num_phrase])
        phrase_label_vec[self.doc2phrase_dict[pos_id]] = 1

        return qid, pos_id, neg_id, class_label_vec, phrase_label_vec

    def qid2emb(self, qids):
        """Stack the query embeddings for ``qids`` in the given order."""
        return self.q_emb[[self.qid2idx[qid] for qid in qids]]

    def cid2emb(self, cids):
        """Stack the corpus embeddings for ``cids`` in the given order."""
        return self.c_emb[[self.cid2idx[cid] for cid in cids]]

    def __len__(self):
        return len(self.queries_ids)


class Standard_TripletDataset(data.Dataset):
    """Text triplets ``[query, positive, hard negative]`` as ``InputExample``s.

    ``queries`` maps qid -> ``{'query': text, 'pos': ids, 'hard_neg': ids}`` and
    ``corpus`` maps doc id -> ``{'text': ...}``. The ``'pos'``/``'hard_neg'``
    entries of ``queries`` are replaced in place by lists (hard negatives
    shuffled), and each ``__getitem__`` rotates both lists by one.
    """

    def __init__(self, queries, corpus):
        self.queries = queries
        self.queries_ids = list(queries.keys())
        self.corpus = corpus

        for record in self.queries.values():
            record['pos'] = list(record['pos'])
            record['hard_neg'] = list(record['hard_neg'])
            random.shuffle(record['hard_neg'])

    def _next_text(self, pool):
        """Rotate ``pool`` by one and return the corpus text of the taken id."""
        # NOTE(review): if the corpus lookup raises, the popped id is not
        # re-appended (same ordering as before); left unchanged on purpose.
        doc_id = pool.pop(0)
        text = self.corpus[doc_id]["text"]
        pool.append(doc_id)
        return text

    def __getitem__(self, item):
        record = self.queries[self.queries_ids[item]]
        query_text = record['query']
        pos_text = self._next_text(record['pos'])
        neg_text = self._next_text(record['hard_neg'])
        return InputExample(texts=[query_text, pos_text, neg_text])

    def __len__(self):
        return len(self.queries)


def get_degree_dict(label_graph):
    """Count how many parent->child edges touch each node of ``label_graph``.

    ``label_graph`` maps parent -> iterable of children. Every edge adds one
    to both endpoints (a self-loop adds two). Parents without children do not
    appear in the result.
    """
    degree = {}
    for parent in label_graph:
        for child in label_graph[parent]:
            for endpoint in (parent, child):
                degree[endpoint] = degree.get(endpoint, 0) + 1
    return degree
    
def get_adj_mat(num_class, child2parent_dict):
    """Build a sparse ``num_class x num_class`` child -> parent adjacency matrix.

    Keys of ``child2parent_dict`` and the ids in their parent collections are
    used directly as row/column indices (assumed to be ints in
    ``[0, num_class)`` — confirm with callers). Entry ``(child, parent)`` is
    ``1 / len(parents of child)``, so each child's row sums to 1.

    Returns:
        An (uncoalesced) sparse COO float32 tensor.
    """
    rows, cols, weights = [], [], []
    for child in child2parent_dict:
        parents = child2parent_dict[child]
        for parent in parents:
            rows.append(child)
            cols.append(parent)
            weights.append(1 / len(parents))

    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
    # legacy constructor; dtype is pinned to match the old FloatTensor output.
    indices = torch.tensor([rows, cols], dtype=torch.long)
    values = torch.tensor(weights, dtype=torch.float32)
    return torch.sparse_coo_tensor(indices, values, (num_class, num_class))

def return_topK_result(result_dict, topk=1000):
    """Keep only the ``topk`` highest-scoring pids for every query.

    ``result_dict`` maps qid -> {pid: score}. Each inner dict of the result is
    ordered by descending score; equal scores keep their original order.
    """
    return {
        qid: dict(sorted(scores.items(), key=lambda pair: pair[1], reverse=True)[:topk])
        for qid, scores in result_dict.items()
    }

def dot_score(a: Tensor, b: Tensor) -> Tensor:
    """Return the matrix of dot products between the rows of ``a`` and ``b``.

    Non-tensor inputs are converted with ``torch.tensor`` and 1-D inputs are
    treated as a single row, so the result has shape ``(rows_a, rows_b)``.
    """
    def _as_matrix(value):
        tensor = value if isinstance(value, torch.Tensor) else torch.tensor(value)
        return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor

    left = _as_matrix(a)
    right = _as_matrix(b)
    return torch.mm(left, right.transpose(0, 1))