from typing import List, Dict, Union, Tuple, Set
from itertools import chain
from argparse import ArgumentParser
import os
import torch
import json
import copy
import time
import numpy as np
import subprocess
from collections import OrderedDict
from tqdm import tqdm

from beir.datasets.data_loader import GenericDataLoader

from utils.logger import logger
from utils.utils import load_pseudo_queries, load_json_queries, save_input_arguments, get_flat_document_repr, save_data_to_reproduce_experiments, load_json_corpus, save_elapsed_time, load_tsv_ranking

from sklearn.feature_extraction.text import TfidfVectorizer
from itertools import chain

import numpy as np
from sklearn.cluster import KMeans
from analysis.colbert_prf_oracle_doc_term.get_exp_star_from_prf import IdfFinder

from .more_utils import get_word_subword_mapping, map_subword_to_word, combine_mapping
import random
import ipdb
import openai
import copy
import time
from pathlib import Path

def prompt_construction(prompt_data, n_queries=2, n_prf_docs=3):
    """Build the few-shot demonstration prefix for the LLM term-refinement prompt.

    Each of the first ``n_queries`` entries of ``prompt_data['queries']`` is
    rendered as: ``Query: ...``, then its first ``n_prf_docs`` PRF documents
    (``prf_docs[:n_prf_docs]``, numbered from 0), then the full list of key
    search terms and the refined top-10 list. Every example ends with a newline.

    Args:
        prompt_data: dict with a ``'queries'`` list; each entry provides
            ``'query'``, ``'prf_docs'`` (dicts with ``'title'`` and ``'text'``),
            ``'key_search_terms'`` and ``'key_search_terms_top10'``.
        n_queries: number of demonstration examples taken from the front.
        n_prf_docs: value used to slice each example's PRF documents.

    Returns:
        The concatenated prompt prefix (empty string when ``n_queries`` is 0).

    Raises:
        IndexError: if ``n_queries`` exceeds the number of available examples.
    """
    pieces = []
    for example_idx in range(n_queries):
        example = prompt_data['queries'][example_idx]

        header = f"Query: {example['query']}\n"
        docs_block = "".join(
            f"[Relevant Document {doc_idx}]\nTitle: {doc['title']}\nDocument: {doc['text']}\n"
            for doc_idx, doc in enumerate(example['prf_docs'][:n_prf_docs])
        )

        candidates = example['key_search_terms']
        refined = example['key_search_terms_top10']
        footer = (
            f"List of key search terms: [ {' , '.join(candidates)} ]\n"
            f"Refine the list to the top 10 key search terms: [ {' , '.join(refined)} ]"
        )

        pieces.append(header + docs_block + footer + '\n')
    return "".join(pieces)

def _md_weights_for_doc(d_tokens: List[str], d_imps: list) -> List[float]:
    """Return one MD importance weight per ColBERT token of a single PRF document.

    Args:
        d_tokens: the document's ColBERT tokens after masking/cropping.
        d_imps: the document's ``(token, md_score)`` pairs, cropped and masked
            exactly like ``d_tokens``.

    Returns:
        The MD scores (``d_imps[i][1]``) when both token sequences agree, and
        uniform weights of ``1.0`` otherwise, because misaligned MD scores
        cannot be attributed to ColBERT tokens.
    """
    md_tokens = " ".join(imp[0] for imp in d_imps)
    if " ".join(d_tokens) != md_tokens:
        return [1.0] * len(d_tokens)
    return [imp[1] for imp in d_imps]


def _parse_llm_terms(content: str) -> List[str]:
    """Parse the LLM completion into lower-cased, non-empty search terms.

    Only the first line is used. Surrounding whitespace and square brackets
    are removed before splitting on commas, so ``"a , b , c ] "`` yields
    ``['a', 'b', 'c']``.
    """
    first_line = content.split('\n')[0]
    terms = [item.strip() for item in first_line.strip().strip('[]').strip().split(',')]
    return [term.lower() for term in terms if term != '']


def _match_llm_terms(llm_terms: List[str], mapping: dict, orig_exp_tokens: List[str]) -> List[str]:
    """Map LLM-selected terms back to the original (sub)word expansion tokens.

    A term is first matched against the whole words in ``mapping``
    (``mapping[token]['word']``; the first matching key wins). Terms without
    such a match are kept only if they literally occur in ``orig_exp_tokens``
    and are placed after the word matches. Only tokens present in
    ``orig_exp_tokens`` are returned; repeated LLM terms yield repeated tokens.
    """
    matched = []
    matched_words = set()
    unmatched = []
    for term in llm_terms:
        for token, info in mapping.items():
            if info['word'] == term:
                matched.append((term, token))
                matched_words.add(term)
                break
        if term not in matched_words:
            unmatched.append(term)
    matched.extend((term, term) for term in unmatched if term in orig_exp_tokens)
    return [token for _, token in matched if token in orig_exp_tokens]


if __name__ == '__main__':
    # Pass the text as `description`: the first positional argument of
    # ArgumentParser is `prog`, which would turn it into the program name.
    parser = ArgumentParser(description=(
        "Obtain expansion term vectors and weights from PRF documents: the ColBERT token "
        "embeddings of the PRF documents are clustered with (MD-score weighted) KMeans, each "
        "centroid is mapped to its most frequent nearest token and weighted by that token's "
        "IDF value, and an LLM then refines the resulting list of expansion terms."))

    parser.add_argument('--output_dir_specified', type=str)
    parser.add_argument('--output_dir', type=str)
    parser.add_argument("--overwrite", action="store_true")

    parser.add_argument('--data_path', type=str, required=True)
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--dataset_split', type=str, default='test')

    parser.add_argument('--checkpoint', type=str, default="experiments/checkpoints/colbertv2.0")
    parser.add_argument('--doc_maxlen', type=int, default=300)
    parser.add_argument('--query_maxlen', type=int, default=32)

    parser.add_argument('--n_expansions', type=int, default=10)
    parser.add_argument('--n_clusters', type=int, default=24)

    parser.add_argument('--temperature', type=float, default=0.2)

    parser.add_argument('--ranking_md_score_path', type=str, required=True)
    parser.add_argument('--mult_relevance', action='store_true')
    parser.add_argument('--md_normalization', type=str, default='none', choices=['none', 'softmax'])

    parser.add_argument('--ranking_path', required=True)
    parser.add_argument('--depth', type=int, default=3)

    parser.add_argument('--nprobe', dest='nprobe', default=10, type=int)
    parser.add_argument('--index_root', dest='index_root', required=True)
    parser.add_argument('--index_name', dest='index_name', required=True)
    parser.add_argument('--partitions', dest='partitions', default=None, type=int)

    args = parser.parse_args()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (args.output_dir or args.output_dir_specified):
        parser.error("either --output_dir or --output_dir_specified is required")

    ## load llm prompt (few-shot demonstrations)
    with open('prompt.json', 'r', encoding='utf-8') as f_prompt:
        prompt_data = json.load(f_prompt)

    # Size of the few-shot prefix. On a rejected request, n_prompt_prf_docs is
    # reduced and the prefix rebuilt. The reduction persists for later queries,
    # so the counter always describes the prefix actually in use.
    n_prompt_queries = 2
    n_prompt_prf_docs = 3
    init_prompt = prompt_construction(prompt_data, n_queries=n_prompt_queries, n_prf_docs=n_prompt_prf_docs)

    # Read the key from the environment instead of hard-coding it. The former
    # placeholder string is only a fallback when the variable is unset.
    openai.api_key = os.environ.get("OPENAI_API_KEY", "openai.api_key")
    llm_model = "gpt-3.5-turbo-0301"

    if args.output_dir_specified:
        output_dir = args.output_dir_specified
    else:
        output_dir = os.path.join(args.output_dir, os.path.basename(__file__), args.dataset)
    print()
    # NOTE(review): nothing in this script writes expansion.pt (results go to one
    # file per query under expansions/). This check therefore only fires if that
    # file is produced elsewhere; confirm this is intended.
    outfile_path: str = os.path.join(output_dir, "expansion.pt")
    if os.path.exists(outfile_path):
        logger.info(f"#> Results already exist at {outfile_path}")
        if args.overwrite:
            logger.info(f'We will overwrite results.')
        else:
            print()
            raise SystemExit
    logger.info(f"#> Results will be saved at {outfile_path}")
    print()
    os.makedirs(output_dir, exist_ok=True)

    save_data_to_reproduce_experiments(output_directory=output_dir, path_to_python_script=__file__, input_arguments=args, prefix=os.path.basename(__file__))

    output_dir_exp = Path(output_dir) / 'expansions'
    output_dir_exp.mkdir(parents=True, exist_ok=True)

    print()
    corpus, queries, qrels = GenericDataLoader(data_folder=args.data_path).load(split=args.dataset_split)

    ranking = load_tsv_ranking(ranking_path=args.ranking_path, depth=args.depth)

    # Collect the top-`depth` PRF documents (with their MD scores) per query.
    # Assumes each query's rows start with rank 1; a query whose first row has
    # another rank raises KeyError below.
    queries_prf = OrderedDict()
    with open(args.ranking_md_score_path) as fIn:
        for line in fIn:
            qid, doc_id, rank, rel_score, md_scores = line.strip().split("\t")
            if qid not in queries:
                continue
            md_scores = json.loads(md_scores)
            if int(rank) == 1:
                queries_prf[qid] = {"_id": qid, "text": queries[qid], "PRF": []}
            if int(rank) <= args.depth:
                doc = corpus[doc_id]
                queries_prf[qid]["PRF"].append({"_id": doc_id, **doc, "md": md_scores, "rel_score": float(rel_score)})
    logger.info(f"The number of queries with PRF = {len(queries_prf)}")

    from transformers import BertTokenizer
    from run_colbert.colbert.end_to_end_ranking.faiss_index import FaissIndex
    from run_colbert.colbert.indexing.faiss import get_faiss_index_name
    from argparse import Namespace
    device = torch.device('cuda')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    index_path = os.path.join(args.index_root, args.index_name)
    faiss_index_path = os.path.join(index_path, get_faiss_index_name(args))
    # FaissIndex receives a minimal stand-in object exposing `.query_tokenizer.tok`.
    inference = Namespace(query_tokenizer=Namespace(tok=tokenizer))
    ann_faiss_index = FaissIndex(index_path=index_path, faiss_index_path=faiss_index_path, nprobe=args.nprobe, part_range=None, inference=inference)
    idf_finder = IdfFinder(faiss_index=ann_faiss_index, tokenizer=tokenizer, device=device)

    from run_colbert.colbert.evaluation.load_model import load_model
    print()
    logger.info(f"Load ColBERT")
    colbert_args = copy.deepcopy(args)
    colbert_args.similarity = 'cosine'
    colbert_args.dim = 128
    colbert_args.mask_punctuation = True
    colbert_args.amp = False
    colbert, _ = load_model(args=colbert_args, do_print=True)
    colbert.to(device)
    colbert.eval()

    print()
    logger.info(f"Load DocTokenizer")
    from run_colbert.colbert.modeling.tokenization.doc_tokenization import DocTokenizer
    doc_tokenizer = DocTokenizer(doc_maxlen=args.doc_maxlen)
    doc_tokenizer.tok = tokenizer

    from analysis.colbert_lexical_bias.utils import doc_string_to_tokens_batch
    expansion_vecs_dict = OrderedDict()
    expansion_weights_dict = OrderedDict()
    expansion_metadata = OrderedDict()
    with torch.no_grad():
        query_iter = tqdm(queries_prf.items(), desc="Tokenize/Encode/Obtain expansion vectors", total=len(queries_prf))
        for n_seen, (qid, query_prf) in enumerate(query_iter, start=1):
            savefile_path = output_dir_exp / f"qid{qid}.pt"
            if savefile_path.exists():
                print(f"Query {qid} is already done. Skipping.")
                continue

            # Crude client-side throttling of LLM API calls.
            if n_seen % 50 == 0:
                time.sleep(31)

            prf = query_prf["PRF"]

            docs_rel: List[float] = [d["rel_score"] for d in prf]
            docs_rel_softmax: torch.Tensor = torch.nn.functional.softmax(torch.tensor(docs_rel), dim=0)

            docs = [get_flat_document_repr(d) for d in prf]
            list_of_doc_tokens, doc_input_ids, doc_mask = doc_string_to_tokens_batch(doc_tokenizer=doc_tokenizer, list_of_doc_text=docs)

            # Align MD scores with the ColBERT input: insert a zero-score "[D]"
            # entry after the first one, then crop to each document's token count.
            docs_imp = [d["md"] for d in prf]
            docs_imp = [d[0:1] + [("[D]", 0.0)] + d[1:] for d in docs_imp]
            docs_imp = [d[:len(d2)] for d, d2 in zip(docs_imp, list_of_doc_tokens)]

            D, D_mask = colbert.doc(input_ids=doc_input_ids, attention_mask=doc_mask, keep_dims='return_mask')
            D = D.cpu().data
            D_mask = D_mask.cpu().data

            vecs_flatten = []
            md_flatten = []
            for rank, (d, d_mask) in enumerate(zip(D, D_mask)):
                d_tokens: List[str] = list_of_doc_tokens[rank][:len(d)]
                d_imps = docs_imp[rank][:len(d_tokens)]
                d_mask = d_mask.squeeze(-1)
                keep = d_mask.tolist()

                d = d[d_mask, :]
                d_tokens = [x for x, b in zip(d_tokens, keep) if b]
                d_imps = [x for x, b in zip(d_imps, keep) if b]

                # Drop the two leading positions and the final position. In the
                # MD list, the second leading entry is the "[D]" inserted above.
                d = d[2:-1]
                d_tokens = d_tokens[2:-1]
                d_imps = d_imps[2:-1]

                # Compute the weights first, so vectors and weights are only
                # appended together and always stay the same length.
                weights = _md_weights_for_doc(d_tokens, d_imps)
                if args.mult_relevance:
                    weights = (torch.tensor(weights) * docs_rel_softmax[rank]).cpu().numpy().tolist()
                vecs_flatten.append(d.float())
                md_flatten.extend(weights)

            vecs_flatten = torch.cat(vecs_flatten, dim=0)

            if args.md_normalization == 'softmax':
                md_flatten = torch.tensor(md_flatten).softmax(dim=-1).cpu().numpy()
            kmn = KMeans(n_clusters=args.n_clusters)
            kmn.fit(X=vecs_flatten.cpu().numpy(), sample_weight=md_flatten)
            centroids = np.float32(kmn.cluster_centers_)

            # Map each centroid to its most frequent nearest token and weight the
            # candidate by that token's entry in idf_finder.idfdict.
            toks2freqs = idf_finder.get_nearest_tokens_for_embs(centroids)
            triples = []
            for cluster_idx, tok2freq in zip(range(args.n_clusters), toks2freqs):
                if len(tok2freq) == 0:
                    continue
                most_likely_tok = max(tok2freq, key=tok2freq.get)
                tid = tokenizer.convert_tokens_to_ids(most_likely_tok)
                triples.append((centroids[cluster_idx], idf_finder.idfdict[tid], most_likely_tok))
            if not triples:
                logger.info(f"#> No expansion candidates for query {qid}. Skipping.")
                continue

            # Highest-weight candidates first.
            triples = sorted(triples, key=lambda tup: -tup[1])
            exp_embs, exp_weights, exp_tokens = zip(*triples)
            exp_embs = torch.tensor(np.array(exp_embs))
            exp_weights = torch.tensor(np.array(exp_weights))
            exp_tokens = list(exp_tokens)

            expansion_vecs_dict[qid] = exp_embs
            expansion_weights_dict[qid] = exp_weights

            # Subword->word mappings from (up to) the first three PRF documents.
            # Slicing keeps this working when fewer than three PRF documents exist.
            # NOTE(review): assumes combine_mapping accepts a list of any length.
            subword_mappings = [get_word_subword_mapping(d['md']) for d in query_prf['PRF'][:3]]
            mapping = combine_mapping([map_subword_to_word(exp_tokens, m) for m in subword_mappings])

            orig_exp_tokens = list(exp_tokens)
            # Show the LLM whole words instead of "##" subword pieces.
            word_tokens = [mapping[tok]['word'] if tok.startswith('##') else tok for tok in mapping]
            random.shuffle(word_tokens)

            test_prompt = f"Query: {query_prf['text']}\n"
            for idx, prf_doc in enumerate(query_prf['PRF'][:3]):
                test_prompt += f"[Relevant Document {idx}]\nTitle: {prf_doc['title']}\nDocument: {prf_doc['text']}\n"
            test_prompt += f"List of key search terms: [ {' , '.join(word_tokens)} ]\nRefine the list to the top 10 key search terms: [ "

            while True:
                try:
                    response = openai.ChatCompletion.create(
                        model=llm_model,
                        temperature=args.temperature,
                        max_tokens=200,
                        n=1,
                        frequency_penalty=0.0,
                        presence_penalty=0.0,
                        messages=[{"role": "system", "content": ""},
                                  {"role": "user", "content": init_prompt + test_prompt}],
                    )
                except openai.error.InvalidRequestError:
                    # Presumably the prompt is too long. Show one PRF document fewer
                    # per few-shot example and retry. Once the few-shot examples have
                    # no documents left, give up instead of looping forever.
                    if n_prompt_prf_docs == 0:
                        logger.info(f"#> Request for query {qid} rejected even without few-shot PRF documents.")
                        raise
                    n_prompt_prf_docs -= 1
                    init_prompt = prompt_construction(prompt_data, n_queries=n_prompt_queries, n_prf_docs=n_prompt_prf_docs)
                else:
                    print(f"Passed with n_prf_docs={n_prompt_prf_docs}")
                    break

            llm_exp_tokens = _parse_llm_terms(response["choices"][0]['message']['content'])
            print(llm_exp_tokens)

            selected = _match_llm_terms(llm_exp_tokens, mapping, orig_exp_tokens)
            final_exp_tokens = list(selected)
            final_exp_vecs = [exp_embs[orig_exp_tokens.index(tok)] for tok in selected]
            final_exp_weights = [exp_weights[orig_exp_tokens.index(tok)] for tok in selected]
            if not final_exp_tokens:
                # Nothing matched: fall back to one zero-weight "[D]" term whose
                # vector and weight match the per-term shape/dtype of real terms,
                # so the stacked weights are 1-D like in the normal case.
                final_exp_tokens = ["[D]"]
                final_exp_vecs = [torch.zeros(exp_embs.shape[-1], dtype=exp_embs.dtype)]
                final_exp_weights = [torch.zeros((), dtype=exp_weights.dtype)]

            exp_tokens = final_exp_tokens
            exp_embs = torch.stack(final_exp_vecs)
            exp_weights = torch.stack(final_exp_weights)

            expansion_metadata[qid] = {
                "query": queries[qid],
                "query_prf": query_prf,
                "exp_tokens": exp_tokens,
                "exp_vecs": exp_embs,
                "exp_weights": exp_weights,
            }
            print(f"N matched: {len(exp_tokens)}")

            # NOTE(review): "expansion_vecs"/"expansion_weights" hold the full
            # clustering output, while only "metadata" carries the LLM-filtered
            # expansions. Confirm which one downstream consumers read.
            torch.save({"expansion_vecs": expansion_vecs_dict[qid],
                        "expansion_weights": expansion_weights_dict[qid],
                        "metadata": expansion_metadata[qid]},
                       savefile_path)