import json
import argparse
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from nltk import sent_tokenize
import re
import numpy as np
import string
import torch
from searcher import SearcherWithinDocs

import pandas as pd
from transformers import AutoTokenizer
from utils import *
import inseq
from inseq.commands.attribute_context.attribute_context import AttributeContextArgs, attribute_context, attribute_context_with_model


def remove_citations(sent):
    """Return *sent* with citation markers such as ``[1]`` stripped out.

    The opening part of each marker (``[`` followed by digits) is removed
    first together with one preceding space when there is one, then every
    remaining `` |`` sequence and every ``]`` character is dropped.
    """
    # Markers preceded by a space: remove the space as well.
    cleaned = re.sub(r" \[\d+", "", sent)
    # Markers glued to the previous character.
    cleaned = re.sub(r"\[\d+", "", cleaned)
    # Left-over separators and closing brackets.
    for leftover in (" |", "]"):
        cleaned = cleaned.replace(leftover, "")
    return cleaned

def _cci_threshold(cci_values, topk_CCI):
    """Return the cut-off below which CCI scores are zeroed out.

    ``topk_CCI == 0`` uses the mean score, a negative value keeps the top
    ``-topk_CCI`` percent of the [min, max] score range, and a positive
    value keeps the ``topk_CCI`` largest scores (ties included).
    """
    if topk_CCI == 0:
        return np.mean(cci_values)
    if topk_CCI < 0:
        return (1 + topk_CCI / 100) * np.max(cci_values) - topk_CCI / 100 * np.min(cci_values)
    # Clamp k so that asking for more scores than there are context tokens
    # keeps all of them instead of raising IndexError.
    k = min(topk_CCI, cci_values.size)
    return np.sort(cci_values)[-k]


def pecora_cite(res_pecora, cti_threshold, start_pos_sent, end_pos_sent, topk_CCI, doc_seps):
    """Aggregate the filtered CCI scores of one sentence into per-document sums.

    Args:
        res_pecora: Attribution result dict. Reads ``input_context_tokens``
            (only its length) and ``cci_scores``, a list of dicts with the
            keys ``cti_idx``, ``cti_score`` and ``input_context_scores``.
        cti_threshold: Minimum ``cti_score`` for an entry to contribute.
        start_pos_sent: First generated-token index of the sentence (inclusive).
        end_pos_sent: End generated-token index of the sentence (exclusive).
        topk_CCI: CCI selection strategy, see ``_cci_threshold``.
        doc_seps: Boolean sequence aligned with the context tokens; a true
            value marks the last token of a document.

    Returns:
        A list with one summed score per document segment. The final
        context token always closes a segment, even without a separator.
    """
    sum_value = np.zeros(len(res_pecora['input_context_tokens']))

    for entry in res_pecora['cci_scores']:
        # CTI filtering: only generated tokens that belong to this sentence.
        if not (start_pos_sent <= entry["cti_idx"] < end_pos_sent):
            continue
        if entry['cti_score'] >= cti_threshold:
            # CCI focus: keep only the strongest context scores of this token.
            cci_values = np.array(entry['input_context_scores'])
            cci_values[cci_values < _cci_threshold(cci_values, topk_CCI)] = 0
            sum_value += cci_values
        elif entry['cti_score'] < cti_threshold:
            # Stop at the first in-range entry below the threshold.
            # NOTE(review): this presumes `cci_scores` is ordered by
            # decreasing CTI score - confirm against the attribution output.
            break

    # Collapse token-level sums into one value per document segment.
    res = []
    running_total = 0
    last_pos = len(sum_value) - 1
    for pos, value in enumerate(sum_value):
        running_total += value
        if doc_seps[pos] or pos == last_pos:  # meet '\n'
            res.append(running_total)
            running_total = 0
    return res


def _collapse_newlines(text):
    """Strip *text* and turn each run of up to ten newlines into one space.

    Longer runs are consumed ten newlines at a time, so they yield more
    than one space.
    """
    text = text.strip()
    for run_length in range(10, 0, -1):
        text = text.replace("\n" * run_length, " ")
    return text


def main():
    """Attribute generated answers to their documents and rewrite citations.

    Reads the generation results JSON given by ``--f``. Unless
    ``--only_cite`` is set, it first runs inseq context attribution on every
    non-empty output, passing a per-example ``save_path`` under
    ``dir_pecora``. It then reads those per-example JSON files back, scores
    the documents for each sentence with ``pecora_cite``, and prefixes the
    sentence with citations for the documents whose score is non-zero. The
    updated data is written to ``<--f>.post_hoc_cite<tag>``.

    Raises:
        NotImplementedError: In the attribution stage, when the model name
            contains none of "zephyr", "llama" or "mistral".
    """
    parser = argparse.ArgumentParser()
    # The input file is mandatory: everything below reads from it.
    parser.add_argument("--f", type=str, required=True, help="Output data file")
    parser.add_argument("--only_cite", action="store_true", help="Only re-generate citations with new CTI and CCI thresholds")

    # CTI strategy:
    #   -1 means over average+1SD (a negative value n adds -n standard deviations)
    #    0 means over average
    #    k > 0 means the k highest CTI scores
    topk_CTI = -1

    # CCI strategy:
    #   -5 means range top5%
    #    0 means over average
    #    3 means top 3
    topk_CCI = 3

    dir_pecora = './pecora_results_compare/'

    # When True, citations are listed by ascending document index instead
    # of by decreasing score.
    cite_idx_acs = False

    args = parser.parse_args()

    with open(args.f) as data_file:
        data = json.load(data_file)
    new_data = []

    model_name = data["args"]["model"]
    model_name_lower = model_name.lower()

    prefix = dir_pecora + model_name_lower.replace('/','_')+'-shot'+str(data["args"]["shot"])+'-seed'+str(data["args"]["seed"])

    # answer attribution with PECoRE
    if not args.only_cite:
        # Load prompt
        with open(data["args"]["prompt_file"]) as prompt_file:
            prompt_data = json.load(prompt_file)

        # Load model
        model, tokenizer = load_model(model_name)
        model_pecora = inseq.load_model(
                model,
                "saliency",
                model_kwargs={"device_map": 'cuda:0', "torch_dtype": torch.float16},
                tokenizer_kwargs={"use_fast": False},
        )

        stop = ["\n", "Ċ", "ĊĊ", "<0x0A>"] # In Llama \n is <0x0A>; In OPT \n is Ċ
        stop_token_ids = list(set([tokenizer._convert_token_to_id(stop_token) for stop_token in stop] + [model.config.eos_token_id]))
        is_llama_family = "llama" in model_name_lower or "zephyr" in model_name_lower or "mistral" in model_name_lower
        # Presumably stop strings missing from the vocabulary map to the unk
        # id, which must not act as a stop token - verify per tokenizer.
        if is_llama_family and tokenizer.unk_token_id in stop_token_ids:
            stop_token_ids.remove(tokenizer.unk_token_id)

        special_tokens_to_keep = []

        if "zephyr" in model_name_lower:
            decoder_input_output_separator = '\n '
            special_tokens_to_keep = ["</s>"]
        elif "llama" in model_name_lower:
            decoder_input_output_separator = ' '
        elif "mistral" in model_name_lower:
            decoder_input_output_separator = ' '
        else:
            raise NotImplementedError("model not supported yet")

        num_empty = 0
        for idx, item in enumerate(tqdm(data['data'])):
            if item["output"] == "":
                num_empty += 1
                continue

            item["output"] = _collapse_newlines(item["output"])

            doc_list = item['docs']

            input_context_text = "".join([make_doc_prompt(doc, doc_id, prompt_data["doc_prompt"], use_shorter=None) for doc_id, doc in enumerate(doc_list)])

            input_current_text = item['question']

            input_template = prompt_data["demo_prompt"].replace("{INST}", prompt_data["instruction"]).replace("{Q}", "{current}").replace("{A}</s>", "").replace("{A}", "").replace("{D}", "{context}").rstrip()
            contextless_input_current_text = input_template.replace("{context}", "")

            output_current_text = item["output"]

            ##### Test #####
            print("***********")
            print("input_context_text")
            print(input_context_text)
            print("***********")
            print("input_current_text")
            print(input_current_text)
            print("***********")
            print("input_template")
            print(input_template)
            print("***********")
            print("contextless_input_current_text")
            print(contextless_input_current_text)
            print("***********")
            print("output_current_text")
            print(output_current_text)
            print("***********")
            print("decoder_input_output_separator")
            print(decoder_input_output_separator)

            # Must match `read_path` in the citation stage below.
            save_path = prefix + '-'+str(idx)+'.json'

            lm_rag_prompting_example = AttributeContextArgs(
                    model_name_or_path=model_name,
                    input_context_text=input_context_text,
                    input_current_text=input_current_text,
                    output_template="{current}",
                    input_template=input_template,
                    contextless_input_current_text=contextless_input_current_text,
                    show_intermediate_outputs=False,
                    attributed_fn="contrast_prob_diff",
                    context_sensitivity_std_threshold=0,
                    output_current_text=output_current_text,
                    attribution_method="saliency",
                    attribution_kwargs={"logprob": True},
                    save_path=save_path,
                    tokenizer_kwargs={"use_fast": False},
                    model_kwargs={
                        "device_map": 'auto',
                        "torch_dtype": torch.float16,
                        "max_memory": get_max_memory(),
                        "load_in_8bit": False,
                        "cache_dir": "/projects/0/prjs0888/plms/"
                        },
                    generation_kwargs={
                        "do_sample": True,
                        "temperature": data["args"]["temperature"],
                        "top_p": data["args"]["top_p"],
                        "max_new_tokens": data["args"]["max_new_tokens"],
                        "num_return_sequences": 1,
                        "eos_token_id": stop_token_ids
                        },
                    decoder_input_output_separator=decoder_input_output_separator,
                    special_tokens_to_keep=special_tokens_to_keep,
                    show_viz=False,
                    )

            # The return value is not used here; the citation stage reads
            # the result back from `save_path`.
            attribute_context_with_model(lm_rag_prompting_example, model_pecora)

        print("*********")
        print("num_empty:")
        print(num_empty)
        print("*********")
        print()


    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    # Fix OPT bos token problem in HF
    if "opt" in model_name:
        tokenizer.bos_token = "<s>"
    tokenizer.padding_side = "left"

    num_empty = 0
    for idx, item in enumerate(tqdm(data['data'])):
        if item["output"] == "":
            # Empty generations are passed through untouched.
            new_data.append(item)
            num_empty += 1
            continue

        item["output"] = _collapse_newlines(item["output"])

        output = remove_citations(item["output"])

        # read pecora json results
        read_path = prefix + '-'+str(idx)+'.json'
        with open(read_path) as r:
            res_pecora = json.load(r)

        cti_scores = res_pecora["cti_scores"]
        if topk_CTI == 0:
            cti_threshold = np.mean(cti_scores)
        elif topk_CTI < 0:
            cti_threshold = np.mean(cti_scores) - topk_CTI * np.std(cti_scores)
        else:
            # Clamp so that asking for more scores than exist uses the lowest one.
            cti_threshold = sorted(cti_scores, reverse=True)[min(topk_CTI-1, len(cti_scores)-1)]

        if "qampari" in args.f:
            sents = [item['question'] + ' ' + x.strip() for x in item['output'].rstrip(".").split(",")]
        else:
            sents = sent_tokenize(output)

        # Mark the '\n' context tokens (i.e. <0x0A> in Llama, zephyr, mistral);
        # `pecora_cite` closes one document segment at each of them.
        doc_seps = np.array(res_pecora["input_context_tokens"]) == '<0x0A>'

        new_output = ""
        start_pos_sent = 0
        end_pos_sent = 0
        print("\n\n")
        print("="*5)
        print(item['prompt'])
        print(item['output'])
        for sent in sents:
            # 0-based indices of the citations already in the sentence, e.g. [0, 2, 3]
            original_ref = [int(r[1:])-1 for r in re.findall(r"\[\d+", sent)]
            # Sentences are mapped to generated-token positions by
            # accumulating their tokenized lengths.
            end_pos_sent = start_pos_sent + len(tokenizer.tokenize(sent))

            # One aggregated score per document, e.g. [0, 0, 20, 3, 0]
            cite_result_pecora = pecora_cite(res_pecora, cti_threshold, start_pos_sent, end_pos_sent, topk_CCI, doc_seps)
            start_pos_sent = end_pos_sent

            # Citations are regenerated for every sentence; when no document
            # has a non-zero score the sentence ends up without citations.
            print("\n-----")
            print("Original sentence:", sent)
            print("Original ref:", original_ref)
            sent = remove_citations(sent)

            # Documents with a non-zero score, strongest first.
            scored_docs = [(doc_id, score) for doc_id, score in enumerate(cite_result_pecora) if score]
            scored_docs.sort(key=lambda pair: pair[1], reverse=True)
            best_doc_id = [doc_id for doc_id, _ in scored_docs]

            if cite_idx_acs:
                best_doc_id = sorted(best_doc_id)

            print("New ref:", best_doc_id)
            best_doc_id_str = "".join("[" + str(doc_id+1) + "]" for doc_id in best_doc_id)
            sent = best_doc_id_str + " " + sent
            print("New sentence:", sent)

            if "qampari" in args.f:
                new_output += sent.replace(item['question'], '').strip() + ", "
            else:
                new_output += sent + " "

        item['output'] = new_output.rstrip().rstrip(",")
        print("\n-----")
        print("Final output: " + item['output'])
        new_data.append(item)

    print("num_empty:")
    print(num_empty)
    print()
    data['data'] = new_data

    tag = ".pecora_compare"
    tag += "_topk_CTI_" + str(topk_CTI)
    tag += "_topk_CCI_" + str(topk_CCI)

    if cite_idx_acs:
        tag += '_acs'

    with open(args.f + f".post_hoc_cite{tag}", 'w') as out_file:
        json.dump(data, out_file, indent=4)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
