import torch
import json
import requests
import argparse
import os,sys
from tqdm import tqdm
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from utils import query_search


# Sequence-classification reranker used by rank_pairs(); loaded once, at import time.
_RERANKER_PATH = './bge-reranker-base'
tokenizer = AutoTokenizer.from_pretrained(_RERANKER_PATH)
model = AutoModelForSequenceClassification.from_pretrained(
    _RERANKER_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
)
# Inference only: switch off training-mode layers such as dropout.
model.eval()

def search_for_rewrite(rewrite, search_fn=None):
    """Run one search per entry of ``rewrite`` and collect the raw results.

    Args:
        rewrite: iterable of rewritten queries (presumably strings — confirm
            against callers); each entry is passed to the search function.
        search_fn: optional callable used for each query. Defaults to
            ``query_search`` from ``utils``, matching the original behavior.

    Returns:
        A list with one search result per entry of ``rewrite``, in order.
    """
    if search_fn is None:
        search_fn = query_search
    return [search_fn(item) for item in rewrite]

def rank_pairs(pairs, batch_size=None):
    """Score ``[query, passage]`` pairs with the module-level reranker.

    Args:
        pairs: list of ``[query, passage]`` string pairs.
        batch_size: optional maximum number of pairs per forward pass.
            ``None`` (default) scores all pairs in a single batch, as before.

    Returns:
        A flat list of float scores (the model's logits flattened with
        ``view(-1)``), in input order. This assumes the model emits a single
        logit per pair — TODO confirm for other checkpoints. An empty
        ``pairs`` returns ``[]`` without touching the model.

    Raises:
        ValueError: if ``batch_size`` is given and is not positive.
    """
    if not pairs:
        return []
    if batch_size is None:
        batch_size = len(pairs)
    elif batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    scores = []
    with torch.no_grad():
        for start in range(0, len(pairs), batch_size):
            batch = pairs[start:start + batch_size]
            # Send inputs to wherever the model actually lives; it is loaded
            # with device_map="auto", so hard-coding "cuda" breaks CPU runs.
            inputs = tokenizer(batch, padding=True, truncation=True, return_tensors='pt', max_length=512).to(model.device)
            logits = model(**inputs, return_dict=True).logits.view(-1, ).float()
            scores.extend(logits.tolist())
    return scores

def convert_ctx(ctx):
    """Flatten a raw search response into simple passage dicts.

    Returns ``[]`` when ``ctx`` is None or has no ``"items"`` key. Otherwise
    every item becomes a dict with keys ``title``, ``snippet``, ``rank`` and
    ``global_rank``. The snippet is the item's ``"description"`` (or ``""``
    when absent), followed by ``" " + text`` for each entry of
    ``"extensions"`` that has a ``"text"`` field.
    """
    if ctx is None or "items" not in ctx.keys():
        return []

    def _build_snippet(item):
        head = item["description"] if "description" in item.keys() else ""
        extensions = item["extensions"] if "extensions" in item.keys() else []
        pieces = [" " + ext["text"] for ext in extensions if "text" in ext]
        return head + "".join(pieces) if pieces else head

    return [
        {
            "title": item["title"],
            "snippet": _build_snippet(item),
            "rank": item["rank"],
            "global_rank": item["global_rank"],
        }
        for item in ctx["items"]
    ]

def clean_text(text):
    """Return ``text`` with every ASCII space character removed.

    Only ``" "`` is dropped; tabs, newlines and other whitespace are kept.
    Used to build comparison keys for deduplication.
    """
    return "".join(text.split(" "))

def dedupe(raw_list, key="snippet"):
    """Drop entries whose ``key`` field repeats once spaces are removed.

    Args:
        raw_list: list of dicts, each with at least ``"title"``,
            ``"snippet"`` and ``key``.
        key: field used for duplicate detection, after ``clean_text``.

    Returns:
        A new list keeping only the first occurrence of each cleaned ``key``
        value, in input order. Each output dict contains only ``"title"``
        and ``"snippet"``.
    """
    seen = set()
    unique = []
    for entry in raw_list:
        fingerprint = clean_text(entry[key])
        if fingerprint in seen:
            continue
        seen.add(fingerprint)
        unique.append({
            "title": entry["title"],
            "snippet": entry["snippet"],
        })
    # Sanity check: exactly one output entry per distinct fingerprint.
    assert len(seen) == len(unique)
    return unique


def _load_json(path):
    """Read and return the JSON document stored at ``path`` (UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _dump_json(obj, path):
    """Write ``obj`` to ``path`` as indented UTF-8 JSON and close the file."""
    # ensure_ascii=False writes raw non-ASCII characters, so pin the encoding
    # instead of relying on the platform default.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=4)


def _rerank_mention(mention):
    """Merge, dedupe and rerank every retrieved passage for one mention.

    Passages are taken from ``mention["ctxs"]``, followed by each non-None
    group in ``mention["rewrite_search"]`` (when present). They are then
    deduplicated and scored against ``mention["query"]``. The result is
    stored, best first, as a flat list in ``mention["rewrite_search"]``.
    ``mention["ctxs"]`` itself is left unchanged.
    """
    # Copy: extending the original list would silently append every rewrite
    # result to mention["ctxs"] in the output file.
    search_list = list(mention["ctxs"])
    for rewrite_search in mention.get("rewrite_search") or []:
        if rewrite_search is not None:
            search_list.extend(rewrite_search)

    deduped_list = dedupe(search_list)
    all_pairs = [
        [mention["query"], f"{document['title']} {document['snippet']}"]
        for document in deduped_list
    ]
    scores = rank_pairs(all_pairs)
    # Highest score first; sorted() is stable, so ties keep their input order.
    order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    mention["rewrite_search"] = [deduped_list[i] for i in order]


def _score_rewrite_groups(mention):
    """Score each passage of each rewrite group against ``mention["query"]``.

    All pairs are scored in a single ``rank_pairs`` call and then sliced back
    per group. The result has one score list per entry of
    ``mention["rewrite_search"]``, aligned with that group's passages. A
    ``None`` group is treated as empty and yields ``[]``.
    """
    query = mention["query"]
    all_pairs = []
    spans = []
    for group in mention["rewrite_search"]:
        group = group or []
        start = len(all_pairs)
        all_pairs.extend(
            [query, f"{document['title']} {document['snippet']}"] for document in group
        )
        spans.append((start, len(all_pairs)))
    scores = rank_pairs(all_pairs)
    return [scores[i:j] for i, j in spans]


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--input_file',
                        type=str,
                        required=True,
                        help='Input JSON file.')
    parser.add_argument('--output_file',
                        type=str,
                        required=True,
                        help='Output JSON file.')
    # NOTE(review): --ckpt_dir is parsed but never used; the reranker is
    # always loaded from ./bge-reranker-base at import time.
    parser.add_argument('--ckpt_dir',
                        type=str,
                        help='Currently unused.')
    parser.add_argument('--debug',
                        action="store_true",
                        help='Process only the first few entries.')
    parser.add_argument('--ranking_ctxs',
                        action="store_true",
                        help='Merge, dedupe and rerank ctxs + rewrite_search '
                             'instead of only scoring rewrite_search.')
    args = parser.parse_args()

    file = args.input_file
    data = _load_json(file)
    print(file)

    if args.ranking_ctxs:
        if isinstance(data, list):
            if args.debug:
                data = data[:2]
            for mention in tqdm(data):
                _rerank_mention(mention)

        elif isinstance(data, dict):
            for k, v in list(data.items()):
                if args.debug:
                    # Store the truncated split so the output matches what was
                    # processed, as the list branch does.
                    v = v[:1]
                    data[k] = v
                for mention in tqdm(v, desc=k):
                    _rerank_mention(mention)

        _dump_json(data, args.output_file)

    else:
        new_data = []
        for idx, mention in enumerate(tqdm(data)):
            mention["ranking_score"] = _score_rewrite_groups(mention)
            mention["idx"] = str(idx)
            new_data.append(mention)

        _dump_json(new_data, args.output_file)