import random
import argparse
import json
import torch
import os
from fastchat.model import load_model, get_conversation_template, add_model_args
import glob

# Token-id patterns that are searched for inside the tokenized prompt, keyed
# by model name (main() indexes both tables with the value of args.model_path).
# NOTE(review): presumably sub_tensors holds the ids of each chat template's
# user-turn marker and sub_tensors2 the ids of the assistant-turn marker —
# confirm against the respective tokenizers.
sub_tensors = {
    name: torch.tensor(ids)
    for name, ids in (
        ("vicuna-7b-v1.5-16k", [2, 11889, 29901]),
        ("longchat-7b-v1.5-32k", [2, 11889, 29901]),
        ("Meta-Llama-3-8B-Instruct", [14711, 11344, 25]),
    )
}
# Second pattern table; only used for the GSM8K task (process_prompt_reasoning).
sub_tensors2 = {
    name: torch.tensor(ids)
    for name, ids in (
        ("vicuna-7b-v1.5-16k", [9047, 13566, 29901]),
        ("longchat-7b-v1.5-32k", [9047, 13566, 29901]),
        ("Meta-Llama-3-8B-Instruct", [14711, 22103, 25]),
    )
}

def prepare_CS_data(demos_num, data, demos, seed):
    """Assemble a few-shot conversation from dict-style records.

    Each record is read as ``item[1]['input']`` (question text) and
    ``item[1]['answerKey']`` (answer text).

    Args:
        demos_num: Number of demonstrations to sample from ``demos``.
        data: The item to be answered; only ``data[1]`` is read.
        demos: Pool of candidate demonstrations, same layout as ``data``.
        seed: Passed to ``random.seed`` before sampling. This reseeds the
            module-level RNG, which callers will observe.

    Returns:
        A ``(query, all_labels)`` tuple. ``query`` is a list of
        ``[role_index, text]`` pairs (0 for a question, 1 for an answer),
        every text ending in a newline; the final entry is the unanswered
        question from ``data``. ``all_labels`` holds the answer keys of the
        sampled demonstrations followed by the answer key of ``data``.
    """
    random.seed(seed)
    sampled = random.sample(demos, demos_num)

    query = []
    all_labels = []
    for demo in sampled:
        record = demo[1]
        question = record['input']
        answer = record['answerKey']
        query.append([0, question + "\n"])
        query.append([1, answer + "\n"])
        all_labels.append(answer)

    target = data[1]
    query.append([0, target['input'] + "\n"])
    all_labels.append(target['answerKey'])
    return query, all_labels

def prepare_data(demos_num, data, demos, seed):
    """Assemble a few-shot multiple-choice conversation from list-style records.

    Each record is read as ``item[1] == [question, options, answer_index]``.
    Options are appended to the question as `` (A) opt (B) opt ...`` and the
    answer is the letter at ``answer_index``. Only the letters A-K exist, so
    more than 11 options (or an index outside that range) raises IndexError.

    Args:
        demos_num: Number of demonstrations to sample from ``demos``.
        data: The item to be answered; only ``data[1]`` is read.
        demos: Pool of candidate demonstrations, same layout as ``data``.
        seed: Passed to ``random.seed`` before sampling. This reseeds the
            module-level RNG, which callers will observe.

    Returns:
        A ``(query, all_labels)`` tuple. ``query`` is a list of
        ``[role_index, text]`` pairs (0 for a question, 1 for an answer
        letter), every text ending in a newline; the final entry is the
        unanswered question from ``data``. ``all_labels`` holds the answer
        letters of the sampled demonstrations followed by that of ``data``.
    """
    letters = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K"]

    def render_question(record):
        # Options are indexed (not zipped) against the letters so that an
        # over-long option list still fails loudly with IndexError.
        options = "".join(
            " (" + letters[pos] + ") " + text
            for pos, text in enumerate(record[1])
        )
        return record[0] + options + "\n"

    random.seed(seed)
    sampled = random.sample(demos, demos_num)

    query = []
    all_labels = []
    for demo in sampled:
        record = demo[1]
        query.append([0, render_question(record)])
        letter = letters[record[2]]
        query.append([1, letter + "\n"])
        all_labels.append(letter)

    target = data[1]
    query.append([0, render_question(target)])
    all_labels.append(letters[target[2]])
    return query, all_labels

def prepare_syn_data(demos_num, data, demos, seed):
    """Assemble a few-shot conversation from ``[question, answer]`` records.

    Each record is read as ``item[1] == [question, answer]``; the answer is
    converted with ``str`` before use.

    Args:
        demos_num: Number of demonstrations to sample from ``demos``.
        data: The item to be answered; only ``data[1]`` is read.
        demos: Pool of candidate demonstrations, same layout as ``data``.
        seed: Passed to ``random.seed`` before sampling. This reseeds the
            module-level RNG, which callers will observe.

    Returns:
        A ``(query, all_labels)`` tuple. ``query`` is a list of
        ``[role_index, text]`` pairs (0 for a question, 1 for an answer),
        every text ending in a newline; the final entry is the unanswered
        question from ``data``. ``all_labels`` holds the stringified answers
        of the sampled demonstrations followed by that of ``data``.
    """
    random.seed(seed)
    sampled = random.sample(demos, demos_num)

    query = []
    all_labels = []
    for demo in sampled:
        record = demo[1]
        query.append([0, record[0] + "\n"])
        answer = str(record[1])
        query.append([1, answer + "\n"])
        all_labels.append(answer)

    target = data[1]
    query.append([0, target[0] + "\n"])
    all_labels.append(str(target[1]))
    return query, all_labels

def prepare_reasoning_data(demos_num, data, demos, seed):
    """Assemble a few-shot conversation from question/answer dict records.

    Each record is read as ``item[1]['question']`` and ``item[1]['answer']``.
    The full answer text goes into the conversation, while the label is the
    part of the answer after the first ``'#### '`` separator (up to the next
    one, if any). An answer without that separator raises IndexError.

    Args:
        demos_num: Number of demonstrations to sample from ``demos``.
        data: The item to be answered; only ``data[1]`` is read.
        demos: Pool of candidate demonstrations, same layout as ``data``.
        seed: Passed to ``random.seed`` before sampling. This reseeds the
            module-level RNG, which callers will observe.

    Returns:
        A ``(query, all_labels)`` tuple. ``query`` is a list of
        ``[role_index, text]`` pairs (0 for a question, 1 for an answer),
        every text ending in a newline; the final entry is the unanswered
        question from ``data``. ``all_labels`` holds the extracted labels of
        the sampled demonstrations followed by that of ``data``.
    """
    marker = '#### '

    random.seed(seed)
    sampled = random.sample(demos, demos_num)

    query = []
    all_labels = []
    for demo in sampled:
        record = demo[1]
        query.append([0, record['question'] + "\n"])
        solution = record['answer']
        query.append([1, solution + "\n"])
        all_labels.append(solution.split(marker)[1])

    target = data[1]
    query.append([0, target['question'] + "\n"])
    all_labels.append(target['answer'].split(marker)[1])
    return query, all_labels

def process_prompt(x, sub_tensor, batch_num):
    """Locate occurrences of ``sub_tensor`` in ``x`` and derive index data.

    Args:
        x: 1-D tensor that is scanned (main() passes the prompt's token ids).
        sub_tensor: 1-D tensor pattern to look for.
        batch_num: Every ``batch_num``-th recorded match (counting from 1)
            is also returned in ``new_matches``.

    Returns:
        A ``(new_matches, all_matches, golden)`` tuple where, with ``m``
        ranging over the recorded match start positions:
          * ``new_matches`` - every ``batch_num``-th ``m``;
          * ``all_matches`` - ``[m - 3, m - 2]`` for each ``m``;
          * ``golden`` - the one-element slice ``x[m - 2:m - 1]`` for each ``m``.
    """
    width = len(sub_tensor)
    # Occurrences starting at position 0 or 1 are never recorded, so the
    # scan begins at 2.
    matches = [
        start
        for start in range(2, len(x) - width + 1)
        if torch.all(x[start:start + width] == sub_tensor)
    ]
    new_matches = [
        pos
        for rank, pos in enumerate(matches, start=1)
        if rank % batch_num == 0
    ]
    all_matches = [[pos - 3, pos - 2] for pos in matches]
    golden = [x[pos - 2:pos - 1] for pos in matches]
    return new_matches, all_matches, golden
def process_prompt_reasoning(x, sub_tensor, sub_tensor2, batch_num):
    """Locate two patterns in ``x`` and pair their occurrences by rank.

    Args:
        x: 1-D tensor that is scanned (main() passes the prompt's token ids).
        sub_tensor: First 1-D pattern; occurrences starting at position 0 or
            1 are not recorded.
        sub_tensor2: Second 1-D pattern; every occurrence is recorded.
        batch_num: Every ``batch_num``-th recorded ``sub_tensor`` match
            (counting from 1) is also returned in ``new_matches``.

    Returns:
        A ``(new_matches, all_matches, golden)`` tuple where ``m`` is the
        i-th recorded ``sub_tensor`` position and ``m2`` the i-th
        ``sub_tensor2`` position:
          * ``new_matches`` - every ``batch_num``-th ``m``;
          * ``all_matches`` - ``[m2 + 2, m - 2]`` for each pair;
          * ``golden`` - the slice ``x[m2 + 3:m - 1]`` for each pair.

    Raises:
        IndexError: If ``sub_tensor2`` occurs fewer times than the number of
            recorded ``sub_tensor`` matches.
    """
    def occurrences(pattern, first_start):
        # Start positions >= first_start at which `pattern` appears in x.
        span = len(pattern)
        return [
            start
            for start in range(first_start, len(x) - span + 1)
            if torch.all(x[start:start + span] == pattern)
        ]

    matches = occurrences(sub_tensor, 2)
    matches2 = occurrences(sub_tensor2, 0)

    new_matches = [
        pos
        for rank, pos in enumerate(matches, start=1)
        if rank % batch_num == 0
    ]
    # matches2 is indexed (not zipped) so a shortfall raises IndexError
    # instead of silently dropping pairs.
    all_matches = [[matches2[k] + 2, pos - 2] for k, pos in enumerate(matches)]
    golden = [x[matches2[k] + 3:pos - 1] for k, pos in enumerate(matches)]
    return new_matches, all_matches, golden

@torch.inference_mode()
def main(args,model,tokenizer):
    """Run one few-shot generation pass over a task's test set and save results.

    Reads ``Dataset/{task}/test/data.json`` (items to answer) and
    ``Dataset/{task}/train/data.json`` (demonstration pool). For every test
    item that has no marker yet, it builds a few-shot prompt, generates
    ``args.n_num`` completions and writes, under
    ``API_Completion_exp/{task}/{run_dir}/``:

      * ``has_generate/{idx}/`` - an empty directory marking the item as
        taken (created before generation, so failed items are not retried);
      * ``raw_data/{idx}.json`` - ``{"response": [...], "raw": item}``;
      * ``hidden/{idx}.json`` - ``{"all_loss", "predict", "label"}`` of the
        last completion decoded for this item, or ``{}`` if none succeeded.

    Args:
        args: Namespace providing ``task``, ``model_path``, ``demos_num``,
            ``batch_num``, ``demos_seed``, ``filter_factor``, ``temperature``,
            ``n_num``, ``load_8bit``, ``device``, ``repetition_penalty`` and
            ``max_new_tokens``. ``model_path`` is also the key into
            ``sub_tensors`` / ``sub_tensors2``.
        model: Model whose ``generate`` is called with the extra keyword
            arguments ``part_length``, ``attention_factor`` and
            ``filter_factor`` and whose result is indexed as ``[0]``
            (generated ids) and ``[1]`` (stored as ``all_loss``).
            NOTE(review): these are not stock Hugging Face ``generate``
            arguments; presumably a patched model is expected - confirm.
        tokenizer: Tokenizer used to encode the prompt and decode the output.

    Returns:
        None. Returns immediately when ``args.batch_num > args.demos_num``.
    """
    if args.batch_num > args.demos_num:
        return
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("No GPU available.")

    demos_seed = args.demos_seed
    filter_factor = args.filter_factor
    temperature = args.temperature
    task = args.task
    model_name = args.model_path
    demos_num = args.demos_num
    n = args.n_num

    with open("Dataset/{}/test/data.json".format(task), 'r') as f:
        datas = json.load(f)
    # NOTE(review): this shuffle runs before random.seed() below, so the
    # visiting order depends on whatever state the global RNG is in.
    # Presumably intentional (e.g. so concurrent runs sharing the
    # has_generate markers start on different items) - confirm.
    random.shuffle(datas)
    factor_num = args.demos_num / args.batch_num
    random.seed(demos_seed)

    further_dir = "final_para_filter_model-{}_demo-{}_batch-{}_filter-{}_temp-{}_n-{}_seed-{}".format(model_name,demos_num,args.batch_num,filter_factor,temperature,n,demos_seed)
    if args.load_8bit:
        further_dir += "-8bit"
    dir_name = "API_Completion_exp/" + task + "/" + further_dir
    for sub_dir in ("", "/raw_data", "/hidden", "/has_generate"):
        os.makedirs(dir_name + sub_dir, exist_ok=True)

    with open("Dataset/{}/train/data.json".format(task), 'r') as f:
        demos = json.load(f)
    random.shuffle(demos)

    print(args)

    # Leading system text that is cut off the prompt when the conversation
    # template emits it.
    system_preamble = "chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
    cans = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K"]

    for tem in datas:
        idx = tem[0]
        # Re-read the markers on every item so that work done since the last
        # iteration (by this or another process) is skipped.
        has_generate = [int(os.path.basename(t)) for t in glob.glob(dir_name + "/has_generate/*")]
        if idx in has_generate:
            continue
        os.makedirs(dir_name + "/has_generate/{}".format(idx), exist_ok=True)
        print(len(has_generate))

        # Per-item seed so each test item draws its own demonstrations.
        sample_seed = demos_num * demos_seed + demos_num * idx + idx * demos_seed + idx
        if task == 'CommonsenseQA':
            build_query = prepare_CS_data
        elif task in ['GSM8K']:
            build_query = prepare_reasoning_data
        elif 'syn' in task:
            build_query = prepare_syn_data
        else:
            build_query = prepare_data
        query, all_labels = build_query(demos_num, tem, demos, sample_seed)

        cur = {"response": [], "raw": tem}
        # Reset for every item: these must never carry over from a previous
        # item, otherwise a failed generation would be recorded with the
        # previous item's loss and prediction.
        last_loss = None
        last_prediction = None
        try:
            post_fix = ' ' if 'syn' in task else ''
            conv = get_conversation_template(args.model_path)
            for tt in query:
                conv.append_message(conv.roles[tt[0]], tt[1])
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt() + post_fix
            if system_preamble in prompt:
                prompt = prompt.split(system_preamble)[1]
            if 'Have a great celebration!\n' in prompt:
                prompt = prompt.split('Have a great celebration!\n')[-1]

            # Run inference
            inputs = tokenizer([prompt], return_tensors="pt")
            input_ids = inputs['input_ids'][0]
            prompt_len = len(input_ids)
            if task in ['GSM8K']:
                parts, o_parts, golden = process_prompt_reasoning(input_ids, sub_tensors[model_name], sub_tensors2[model_name], args.batch_num)
            else:
                parts, o_parts, golden = process_prompt(input_ids, sub_tensors[model_name], args.batch_num)
            inputs = inputs.to(args.device)

            for _ in range(n):
                # Best effort per completion: a failure is reported and the
                # remaining attempts still run.
                try:
                    outputs = model.generate(
                        **inputs,
                        do_sample=args.temperature > 1e-5,
                        temperature=args.temperature,
                        repetition_penalty=args.repetition_penalty,
                        max_new_tokens=args.max_new_tokens,
                        part_length=[parts, o_parts, golden],
                        attention_factor=factor_num,
                        filter_factor=filter_factor,
                    )
                    output_ids = outputs[0]
                    all_loss = outputs[1]
                    if model.config.is_encoder_decoder:
                        output_ids = output_ids[0]
                    else:
                        # Drop the echoed prompt tokens.
                        output_ids = output_ids[0][prompt_len:]
                    decoded = tokenizer.decode(
                        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
                    )
                    last_loss, last_prediction = all_loss, decoded
                    if task == 'CommonsenseQA':
                        print("{}-{}".format(tem[1]['answerKey'], decoded))
                    elif task in ['GSM8K','syn_counta_5']:
                        print(decoded)
                    else:
                        print("{}-{}".format(cans[tem[1][2]], decoded))
                    cur['response'].append(decoded)
                except Exception as e:
                    print("Error:", e)
        except Exception as e:
            print("Error2:", e)
            cur = {"response": [], "raw": tem}

        with open(dir_name + "/raw_data/{}.json".format(idx), "w") as f:
            json.dump(cur, f)

        if last_prediction is None:
            hidden_cur = {}
        else:
            hidden_cur = {'all_loss': last_loss, 'predict': last_prediction.strip(), 'label': all_labels[-1]}
        with open(dir_name + "/hidden/{}.json".format(idx), "w") as f:
            json.dump(hidden_cur, f)



if __name__ == "__main__":
    # Script entry point: load the model once, then sweep every configured
    # (task, filter_factor, demos_num, batch_num, seed) combination.
    parser = argparse.ArgumentParser()
    add_model_args(parser)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--demos_num", type=int, default=4)
    parser.add_argument("--batch_num", type=int, default=4)
    # Fractional value (default 0.8, sweep uses 0.9), so it has to be parsed
    # as float; with type=int a command-line value such as "0.9" is rejected.
    parser.add_argument("--filter_factor", type=float, default=0.8)
    parser.add_argument("--n_num", type=int, default=1)
    parser.add_argument("--subset", type=int, default=0)
    parser.add_argument("--demos_seed", type=int, default=0)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=1)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--task", type=str, default='SST5')
    args = parser.parse_args()
    model, tokenizer = load_model(
        args.model_path,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        revision=args.revision,
        debug=args.debug,
    )
    # Keep only the second "/"-separated component of the model path (the
    # name part of "org/name"); main() uses it as directory label and as key
    # into sub_tensors. Raises IndexError when the path contains no "/".
    args.model_path=args.model_path.split("/")[1]

    # Sweep configuration, per task. The command-line values of --task,
    # --demos_num, --batch_num, --filter_factor, --demos_seed and
    # --max-new-tokens are overwritten by the loop below.
    demo_map = {'CommonsenseQA':[1,2,4,8,16,32,64,128],'PIQA':[1,2,4,8,16,32,64,96],'ARCE':[1,2,4,8,16,32,64,108],'GSM8K':[1,2,4,8,16,32,64,80],'syn_counta_5':[1,2,4,8,16,32,64,128,256,448]}
    batch_map = {'CommonsenseQA':[32],'PIQA':[2],'ARCE':[12],'GSM8K':[8],'syn_counta_5':[100]}
    factor_map = {'CommonsenseQA':[0.9],'PIQA':[0.9],'ARCE':[0.9],'GSM8K':[0.9],'syn_counta_5':[0.9]}
    tasks = ['CommonsenseQA','PIQA','ARCE','syn_counta_5','GSM8K']
    for task in tasks:
        # GSM8K gets a 250-token budget; every other task generates a
        # single token.
        if task == 'GSM8K':
            args.max_new_tokens = 250
        else:
            args.max_new_tokens = 1
        for filter_factor in factor_map[task]:
            for demo_num in demo_map[task]:
                for batch_num in batch_map[task]:
                    if batch_num > demo_num:
                        continue
                    for i in range(5):
                        args.demos_seed = i
                        args.task = task
                        args.filter_factor = filter_factor
                        args.demos_num = demo_num
                        args.batch_num = batch_num
                        main(args,model,tokenizer)
