import torch

from tqdm.autonotebook import trange
from transformers import AutoTokenizer, AutoModel
import pickle
import numpy as np
from typing import Union, List, Dict, Tuple
import argparse

class CLS_pooling:
    """Text encoder that uses the first-token hidden state as the embedding.

    For every input text the model's ``last_hidden_state[:, 0, :]`` (the
    position of the first token, i.e. [CLS] for BERT-style tokenizers) is
    taken as the embedding.
    """

    def __init__(self, model_query, model_passage, tokenizer,
                 device: Union[str, torch.device, None] = None, **kwargs):
        """Store the encoders and tokenizer and move the encoders to ``device``.

        Args:
            model_query: Model used by ``encode_queries``.
            model_passage: Model used by ``encode_corpus``. May be the same
                object as ``model_query``.
            tokenizer: Callable tokenizer whose output supports ``.to(device)``
                and ``**`` unpacking into the model call.
            device: Target device. Defaults to ``'cuda'`` when CUDA is
                available, otherwise ``'cpu'``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.model_query = model_query.to(device)
        # The passage encoder used to be silently replaced by the query
        # encoder; honour the argument so distinct encoders work as expected.
        self.model_passage = model_passage.to(device)
        self.tokenizer = tokenizer

    def _encode(self, model, texts, batch_size: int) -> torch.Tensor:
        """Embed ``texts`` with ``model`` in batches of ``batch_size``."""
        if not texts:
            # torch.cat rejects an empty list, so return an empty (0, dim)
            # tensor instead; dim falls back to 0 when the model exposes no
            # ``config.hidden_size``.
            hidden_size = getattr(getattr(model, 'config', None), 'hidden_size', 0)
            return torch.empty((0, hidden_size), device=self.device)

        batch_embeddings = []
        with torch.no_grad():
            for start_idx in trange(0, len(texts), batch_size):
                inputs = self.tokenizer(
                    texts[start_idx:start_idx + batch_size],
                    truncation=True,
                    padding=True,
                    return_tensors='pt',
                    max_length=512,
                ).to(self.device)
                model_out = model(**inputs)  # bs x len x dim
                batch_embeddings.append(model_out.last_hidden_state[:, 0, :])

        # Embeddings stay on ``self.device``, exactly as produced by the model.
        return torch.cat(batch_embeddings, 0)

    def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> torch.Tensor:
        """Embed ``queries`` with the query encoder.

        Args:
            queries: Query strings.
            batch_size: Number of queries per forward pass.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor with one row per query, located on ``self.device``.
        """
        return self._encode(self.model_query, queries, batch_size)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs) -> torch.Tensor:
        """Embed the ``'text'`` field of every corpus row with the passage encoder.

        Args:
            corpus: Rows, each a mapping with a ``'text'`` key.
            batch_size: Number of rows per forward pass.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor with one row per corpus entry, located on ``self.device``.
        """
        texts = [row['text'] for row in corpus]
        return self._encode(self.model_passage, texts, batch_size)

class MEAN_pooling:
    """Text encoder that averages token hidden states under the attention mask.

    For every input text the first element of the model output
    (``model_out[0]``) is averaged over the positions where the tokenizer's
    ``attention_mask`` is non-zero.
    """

    def __init__(self, model_query, model_passage, tokenizer,
                 device: Union[str, torch.device, None] = None, **kwargs):
        """Store the encoders and tokenizer and move the encoders to ``device``.

        Args:
            model_query: Model used by ``encode_queries``.
            model_passage: Model used by ``encode_corpus``. May be the same
                object as ``model_query``.
            tokenizer: Callable tokenizer whose output supports ``.to(device)``,
                ``['attention_mask']`` lookup and ``**`` unpacking into the
                model call.
            device: Target device. Defaults to ``'cuda'`` when CUDA is
                available, otherwise ``'cpu'``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.model_query = model_query.to(device)
        # The passage encoder used to be silently replaced by the query
        # encoder; honour the argument so distinct encoders work as expected.
        self.model_passage = model_passage.to(device)
        self.tokenizer = tokenizer

    def mean_pooling(self, token_embeddings, mask):
        """Average ``token_embeddings`` over dim 1 where ``mask`` is non-zero.

        Args:
            token_embeddings: Tensor indexed as (batch, token, feature).
            mask: Tensor indexed as (batch, token); zero marks ignored tokens.

        Returns:
            Tensor indexed as (batch, feature). A row whose mask is entirely
            zero yields a zero vector.
        """
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        # Clamp the token count so a fully masked row divides by 1, not 0,
        # which would otherwise produce NaN.
        token_counts = mask.sum(dim=1)[..., None].clamp(min=1)
        return token_embeddings.sum(dim=1) / token_counts

    def _encode(self, model, texts, batch_size: int) -> torch.Tensor:
        """Embed ``texts`` with ``model`` in batches of ``batch_size``."""
        if not texts:
            # torch.cat rejects an empty list, so return an empty (0, dim)
            # tensor instead; dim falls back to 0 when the model exposes no
            # ``config.hidden_size``.
            hidden_size = getattr(getattr(model, 'config', None), 'hidden_size', 0)
            return torch.empty((0, hidden_size), device=self.device)

        batch_embeddings = []
        with torch.no_grad():
            for start_idx in trange(0, len(texts), batch_size):
                inputs = self.tokenizer(
                    texts[start_idx:start_idx + batch_size],
                    truncation=True,
                    padding=True,
                    return_tensors='pt',
                    max_length=512,
                ).to(self.device)
                model_out = model(**inputs)  # bs x len x dim
                batch_embeddings.append(
                    self.mean_pooling(model_out[0], inputs['attention_mask'])
                )

        # Embeddings stay on ``self.device``, exactly as produced by the model.
        return torch.cat(batch_embeddings, 0)

    def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> torch.Tensor:
        """Embed ``queries`` with the query encoder.

        Args:
            queries: Query strings.
            batch_size: Number of queries per forward pass.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor with one row per query, located on ``self.device``.
        """
        return self._encode(self.model_query, queries, batch_size)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs) -> torch.Tensor:
        """Embed the ``'text'`` field of every corpus row with the passage encoder.

        Args:
            corpus: Rows, each a mapping with a ``'text'`` key.
            batch_size: Number of rows per forward pass.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor with one row per corpus entry, located on ``self.device``.
        """
        texts = [row['text'] for row in corpus]
        return self._encode(self.model_passage, texts, batch_size)


if __name__ == '__main__':
    # Embed every entry of a pickled text dictionary with the chosen backbone
    # and save the resulting embedding tensor with torch.save.

    # Each supported backbone is paired with the pooling class used for it.
    pooling_by_backbone = {
        'allenai/specter2_base': CLS_pooling,
        'facebook/contriever-msmarco': MEAN_pooling,
    }

    parser = argparse.ArgumentParser()
    # Restricting the choices rejects an unsupported backbone up front; it
    # previously surfaced as a NameError on `text_emb` after the model had
    # already been loaded.
    parser.add_argument('--backbone', type=str, required=True,
                        choices=sorted(pooling_by_backbone))
    parser.add_argument('--text_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, required=True)

    args = parser.parse_args()

    # NOTE(security): pickle.load can execute arbitrary code while
    # deserialising; only point --text_path at files from a trusted source.
    with open(args.text_path, 'rb') as f:
        text_dict = pickle.load(f)

    # load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.backbone)
    model = AutoModel.from_pretrained(args.backbone)

    # The same model serves as both query and passage encoder.
    extractor = pooling_by_backbone[args.backbone](model, model, tokenizer)

    # Values are passed in the dictionary's iteration order, so row i of the
    # saved tensor corresponds to the i-th key of text_dict.
    text_emb = extractor.encode_corpus(list(text_dict.values()), batch_size=16)

    torch.save(text_emb, args.output_path)

