import numpy as np
from copy import deepcopy
from tqdm import tqdm
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F


class SentenceBertRetriever:
    """Dense exemplar retriever over dialogue samples.

    Every sample's context is embedded once with a Hugging Face encoder
    (first-token pooling followed by L2 normalization). Because queries are
    normalized the same way, query scores are cosine similarities against
    the stored document vectors.

    Each sample is expected to be a dict with a ``'context'`` entry (a
    sequence of utterance strings) and a ``'uuid'`` entry used for
    de-duplication and self-exclusion in :meth:`search_top_k`.
    """

    def __init__(self, samples=None, dataset='MultiWOZ',
                 model_name='BAAI/bge-large-en-v1.5', batch_size=128):
        """Load the encoder and embed every sample's context.

        Args:
            samples: Documents to index (see the class docstring for the
                expected keys). Required despite the ``None`` default.
            dataset: ``'MultiWOZ'`` embeds only the last utterance of each
                context; ``'SMD'`` embeds the space-joined full context.
            model_name: Hugging Face model id or path used for both the
                tokenizer and the encoder.
            batch_size: Number of contexts embedded per forward pass.

        Raises:
            TypeError: If ``samples`` is ``None``.
            ValueError: If ``dataset`` is unsupported or ``batch_size`` is
                not positive.
        """
        # Validate cheaply before loading a (large) pretrained model.
        if samples is None:
            raise TypeError('samples must be a sequence of documents, got None')
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        contexts = self._build_contexts(samples, dataset)

        self.documents = samples
        self.dataset = dataset
        self.device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'

        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model.eval()

        # Size the matrix from the model config rather than hard-coding 1024
        # so other encoders work too; dtype stays float64 (np.zeros default).
        vectors = np.zeros((len(contexts), self.model.config.hidden_size))
        for st in tqdm(range(0, len(contexts), batch_size)):
            en = min(len(contexts), st + batch_size)
            vectors[st:en, :] = self._encode(contexts[st:en])

        self.vectors = vectors
        print('Total documents', len(vectors), self.vectors.shape)

    @staticmethod
    def _build_contexts(samples, dataset):
        """Return the text to embed for each sample, selected by ``dataset``.

        Raises:
            ValueError: If ``dataset`` is neither ``'MultiWOZ'`` nor ``'SMD'``.
        """
        if dataset == 'MultiWOZ':
            print('Using last utterance for exemplar retrieval.')
            return [x['context'][-1] for x in samples]
        if dataset == 'SMD':
            print('Using context for exemplar retrieval.')
            return [' '.join(x['context']) for x in samples]
        # Fail loudly instead of leaving the context list undefined.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'MultiWOZ' or 'SMD'")

    def _encode(self, texts):
        """Embed one batch of ``texts`` as L2-normalized NumPy row vectors.

        Pooling takes the first-token vector of the model's first output.
        """
        tout = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        tout = {k: v.to(self.device) for k, v in tout.items()}
        with torch.no_grad():
            ret = self.model(**tout)
            embs = F.normalize(ret[0][:, 0], p=2, dim=1)
        return embs.to('cpu').numpy()

    def compute_scores(self, text):
        """Return the cosine similarity of ``text`` to every indexed document.

        Returns:
            1-D array with one score per row of ``self.vectors`` (i.e. per
            entry of ``self.documents``).
        """
        qvec = self._encode([text])
        return np.matmul(self.vectors, qvec.T)[:, 0]

    def search_top_k(self, text, k=1, uuid=None, etype=None):
        """Return deep copies of the ``k`` best-scoring documents for ``text``.

        Documents are visited in descending score order, and only the first
        document seen for each ``'uuid'`` is kept. When both ``uuid`` and
        ``etype`` are given, documents whose ``'uuid'`` equals ``uuid`` are
        skipped. ``etype`` is otherwise unused here.

        Args:
            text: Query string.
            k: Maximum number of documents to return. ``k <= 0`` returns an
                empty list; ``None`` returns every unique document.
            uuid: Optional id of the querying document, used for
                self-exclusion.
            etype: Must be non-``None`` for ``uuid`` exclusion to apply.

        Returns:
            List of at most ``k`` deep-copied documents.
        """
        # Previously k=0 never hit the `len(rets) == k` stop and returned all.
        if k is not None and k <= 0:
            return []

        scores = self.compute_scores(text)
        exclude_self = uuid is not None and etype is not None

        rets = []
        dids_sofar = set()
        for ii in np.argsort(scores)[::-1]:
            doc = self.documents[ii]
            doc_uuid = doc['uuid']
            if exclude_self and doc_uuid == uuid:
                continue
            if doc_uuid in dids_sofar:
                continue
            dids_sofar.add(doc_uuid)
            # Copy so callers can mutate results without touching the index.
            rets.append(deepcopy(doc))
            if len(rets) == k:
                break

        return rets
