import os
import random
import json
import torch
import time
import copy
import numpy as np
import math
from torch import nn
from tqdm import tqdm
from transformers import AutoTokenizer, LlamaTokenizer
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.cluster import KMeans
from MetaICL.metaicl.model import MetaICLModel

from utils import codex_execution


def get_instance_length(input_text,output_text,tokenizer):
    """Return ``(n_input_tokens, n_output_tokens)`` for a text pair.

    ``tokenizer`` is called on each text separately and must return a mapping
    with an ``'input_ids'`` sequence (Hugging Face tokenizer style).
    """
    encoded_input = tokenizer(input_text)['input_ids']
    encoded_output = tokenizer(output_text)['input_ids']
    return len(encoded_input), len(encoded_output)

def prepro_sentence_pair_single(ids1, ids2, max_length,
                                bos_token_id, eos_token_id,
                                allow_truncation=False):
    """Concatenate a context/answer token pair and right-pad it to ``max_length``.

    Args:
        ids1: context token ids (list).
        ids2: answer token ids (list); never truncated.
        max_length: target sequence length after padding.
        bos_token_id, eos_token_id: accepted for interface compatibility but
            currently unused (BOS/EOS insertion is disabled).
        allow_truncation: when True and the pair is too long, drop tokens from
            the *front* of ``ids1`` so the pair fits exactly.

    Returns:
        ``(input_ids, attention_mask, token_type_ids)``: padding uses id 0 and
        mask 0; ``token_type_ids`` is 1 on answer tokens and 0 elsewhere, so it
        doubles as the loss mask for the answer.
    """
    overflow = len(ids1) + len(ids2) - max_length
    if allow_truncation and overflow > 0:
        # Keep the tail of the context, i.e. the part closest to the answer.
        ids1 = ids1[overflow:]
        assert len(ids1) + len(ids2) == max_length

    pad_len = max_length - len(ids1) - len(ids2)
    assert pad_len >= 0, (max_length, len(ids1), len(ids2))
    content = ids1 + ids2
    padding = [0] * pad_len
    input_ids = content + padding
    attention_mask = [1] * len(content) + padding
    token_type_ids = [0] * len(ids1) + [1] * len(ids2) + padding
    return input_ids, attention_mask, token_type_ids

def _prepro_each_datapoint(dp, is_first=True, is_training=False, for_unlabeled_demonstrations=False,
                            add_newlines=True, tokenizer_gpt=None):
    dp = dp.copy()
    if add_newlines:
        no_input = dp["input"]==""
    if not is_first:
        if no_input:
            dp["input"] = "\n\n" + dp["input"]
        else:
            dp["input"] = "\n\n\n" + dp["input"]
    input_tokens = tokenizer_gpt(dp["input"])["input_ids"]
    if "options" in dp:
        dp["options"] = ["\n" + opt for opt in dp["options"]]

    if for_unlabeled_demonstrations:
        # U_tr(x + 伪y)
        option_tokens = [tokenizer_gpt(option)["input_ids"] for option in dp["options"]]
        return [input_tokens+option_tokens_ for option_tokens_ in option_tokens]
        # only_x_input_ids = tokenizer_gpt(dp["input"] + '\n\n')["input_ids"]
        # return only_x_input_ids

    dp["output"] = "\n" + dp["output"]

    assert len(dp["options"])>=2, dp
    assert dp["output"] in dp["options"]
    option_tokens = [tokenizer_gpt(option)["input_ids"] for option in dp["options"]]
    input_tokens = [input_tokens for _ in option_tokens]
    output_tokens = option_tokens
    option_tokens = [dp["options"].index(dp["output"])]
    return input_tokens, output_tokens, option_tokens

def prompt_retrieval(train_embs,test_embs,train_examples,eval_examples,return_string,format_example,
                     maximum_input_len,args, label_map,prompt_identifier='prompts',single_context_example_len=None, pseudo_indices=None):
    """Pick in-context demonstrations from the annotated pool for every eval example.

    For each eval example, training examples are ranked either by cosine
    similarity of embeddings (``args.prompt_retrieval_method == 'similar'``) or
    by a random permutation (``'random'``). They are then added greedily from
    the best-ranked one until the token budget ``maximum_input_len`` (which
    already includes the test input) would be exceeded.

    The chosen demonstrations are re-ordered so the most similar one sits right
    before the test input. The result is written to
    ``<args.output_dir>/<prompt_identifier>/<example id>.json``.

    Args:
        train_embs, test_embs: embeddings consumed by ``nn.CosineSimilarity``
            (torch tensors), indexed by example position.
        train_examples, eval_examples: example dicts; eval dicts need "id"
            and "label".
        return_string: if True, store the prompt as one string ending with the
            test input; otherwise store a list of ``{"input", "output"[, "options"]}``.
        format_example: callable returning ``(input_text, output_text)``.
        maximum_input_len: token budget for demonstrations plus test input.
        args: needs ``model_cache_dir``, ``output_dir``,
            ``prompt_retrieval_method`` and ``task_name``.
        label_map: forwarded to ``format_example``.
        prompt_identifier: sub-directory name for the JSON files.
        single_context_example_len: optional per-demonstration token cap.
        pseudo_indices: currently unused; kept for interface compatibility.

    Raises:
        ValueError: for an unsupported ``args.prompt_retrieval_method``.
    """
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    bar = tqdm(range(len(eval_examples)), desc="Retrieve examples from annotated pool")
    if 'llama' in args.model_cache_dir:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_cache_dir)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_cache_dir)
    prompt_cache_dir = os.path.join(args.output_dir,prompt_identifier)
    os.makedirs(prompt_cache_dir, exist_ok=True)
    mc_tasks = ['hellaswag', 'commonsense_qa', 'piqa', 'cosmos_qa', 'copa']

    for test_id, one_test_instance in enumerate(eval_examples):
        one_test_instance_input_text,one_test_instance_output_text = format_example(example=one_test_instance,args=args,
                                                                                    label_map=label_map)
        cur_prompt_string_len = get_instance_length(one_test_instance_input_text,one_test_instance_output_text,tokenizer)[0]
        if args.prompt_retrieval_method=='similar':
            test_e_reshape = test_embs[test_id].reshape(1, -1)
            scores = cos(test_e_reshape, train_embs).numpy()
            # Ascending order: the most similar training example is last.
            sorted_indices = np.argsort(scores)
        elif args.prompt_retrieval_method=='random':
            # These indices address train_examples/train_embs, so permute the
            # training pool (the eval-set size was used before, which could
            # index out of range or only ever sample a prefix of the pool).
            sorted_indices = np.random.permutation(len(train_examples))
        else:
            raise ValueError(f"The prompt retrieval method {args.prompt_retrieval_method} is not supported")

        # Greedy selection from the end of sorted_indices (best first).
        # selected_indices stores positions into sorted_indices, not example ids.
        selected_indices = []
        for idx in range(len(sorted_indices) - 1, -1, -1):
            # Skip exact duplicates of the test example.
            if args.prompt_retrieval_method=='similar' and scores[sorted_indices[idx]]==1:
                continue
            cur_example_input_text,cur_example_output_text = format_example(example=train_examples[sorted_indices[idx]],
                                                                            args=args,label_map=label_map)
            cur_len = sum(get_instance_length(cur_example_input_text, cur_example_output_text,tokenizer=tokenizer))
            if single_context_example_len is not None and cur_len>single_context_example_len:
                continue
            cur_prompt_string_len += cur_len
            if cur_prompt_string_len > maximum_input_len:
                break
            selected_indices.append(idx)

        # Re-rank the chosen demonstrations by similarity to this test example
        # (needed for the 'random' method; a no-op reorder for 'similar').
        one_test_emb = test_embs[test_id]
        indices_scores = [
            [idx, cos(train_embs[sorted_indices[idx]].reshape(1, -1), one_test_emb.reshape(1, -1)).item()]
            for idx in selected_indices
        ]
        indices_scores = sorted(indices_scores, key=lambda x: x[1], reverse=True)
        selected_indices = [x[0] for x in indices_scores]

        second_phase_selected_indices = []
        cur_train_data = '' if return_string else []
        # Least similar first, so the most similar demonstration is adjacent
        # to the test input.
        for idx in reversed(selected_indices):
            train_idx = sorted_indices[idx]
            cur_input_text, cur_output_text = format_example(
                example=train_examples[train_idx],
                args=args, label_map=label_map)
            if return_string:
                cur_train_data += f'{cur_input_text}{cur_output_text}\n\n'
            elif args.task_name in mc_tasks:
                cur_train_data.append({
                    'input': cur_input_text,
                    'output': cur_output_text,
                    'options': train_examples[train_idx]['endings']
                })
            else:
                cur_train_data.append({
                    'input': cur_input_text,
                    'output': cur_output_text
                })
            second_phase_selected_indices.append([train_idx.item()])
        if return_string:
            cur_train_data += format_example(
                example=one_test_instance,
                args=args, label_map=label_map)[0]
        with open(os.path.join(prompt_cache_dir,f"{one_test_instance['id']}.json"),'w') as f:
            json.dump([[test_id, second_phase_selected_indices, one_test_instance['label']],
                       cur_train_data,
                       one_test_instance
                       ], f, indent=4)
        bar.update(1)
    bar.close()

def np_softmax(x, axis=None):
    """Numerically stable softmax of ``x``.

    Args:
        x: array-like of scores.
        axis: axis to normalise over. ``None`` (the default, matching the
            previous behaviour) normalises over all entries. Pass e.g.
            ``axis=-1`` to get one distribution per row.

    Returns:
        ndarray with the same shape as ``x`` whose entries sum to 1 along ``axis``.
    """
    x = np.asarray(x)
    # Subtracting the max avoids overflow in exp without changing the result.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

def in_dev_prompt_retrieval(train_embs,test_embs,train_examples,eval_examples,return_string,format_example,
                     maximum_input_len,args, label_map,prompt_identifier='prompts',single_context_example_len=None, inference_model=None, tokenizer_gpt=None):
    """Score sampled single demonstrations on the dev set and keep the most helpful ones.

    Every training example is expanded into one demonstration per option
    ("x + option"), so demonstration index ``d`` is training example
    ``d // num_options`` paired with option ``d % num_options``.

    For each eval example, ``args.M_sample`` demonstrations are drawn with
    replacement, with probability proportional to ``demonstrations_freqs``.
    Each one is scored by the LM loss of the gold option with that
    demonstration prepended. A penalty of 1000 is added when the
    lowest-loss option is not the gold one. Eval examples where every sample
    is penalised are skipped. Otherwise the two lowest-loss demonstrations
    are kept and their frequency is boosted (by 60 and 30), which biases
    sampling for later eval examples.

    Side effects: overwrites "input" and "options" (and "output" for eval
    examples) on the dicts in ``train_examples``/``eval_examples``; creates
    ``<args.output_dir>/<prompt_identifier>``. ``train_embs``, ``test_embs``,
    ``return_string`` and ``single_context_example_len`` are unused.

    Returns:
        ``(unique selected demonstration indices, demonstrations_freqs)``.

    Raises:
        ValueError: for generation tasks ('xsum', 'nq'), which have no fixed
            option set (previously a NameError on ``all_labels``).
    """
    prompt_cache_dir = os.path.join(args.output_dir,prompt_identifier)
    os.makedirs(prompt_cache_dir, exist_ok=True)

    mc_tasks = ['hellaswag', 'commonsense_qa', 'piqa', 'cosmos_qa', 'copa']
    if args.task_name in ['xsum', 'nq']:
        raise ValueError(f"in_dev_prompt_retrieval needs a fixed option set; task {args.task_name} is not supported")
    if args.task_name not in mc_tasks:
        all_labels = list(label_map.values())

    for dp in train_examples:
        dp['input'] = format_example(dp, label_map=label_map, args=args)[0]
        dp['options'] = dp['endings'] if args.task_name in mc_tasks else all_labels
    for dp in eval_examples:
        dp['input'], dp['output'] = format_example(dp, label_map=label_map, args=args)
        dp['options'] = dp['endings'] if args.task_name in mc_tasks else all_labels

    if isinstance(inference_model, MetaICLModel):
        upper_bound_inference_model = inference_model.model
    else:
        upper_bound_inference_model = inference_model

    demonstrations = []
    for dp in train_examples:
        demonstrations.extend(_prepro_each_datapoint(
            dp, is_first=True, for_unlabeled_demonstrations=True,
            add_newlines=True, tokenizer_gpt=tokenizer_gpt))  # U_tr
    # Option count used to decode demonstration indices; assumed uniform.
    num_options = len(train_examples[-1]['options'])
    assert len(demonstrations) == num_options * len(train_examples)  # U_tr

    M_sample = args.M_sample
    demonstrations_freqs = np.ones(len(demonstrations))  # init freqs

    bos_token_id = tokenizer_gpt.bos_token_id
    eos_token_id = tokenizer_gpt.eos_token_id
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    def _option_losses(demonstration, inputs, outputs, max_length):
        """Mean LM loss over each option's answer tokens with ``demonstration`` as context."""
        option_losses = []
        for inputs_, outputs_ in zip(inputs, outputs):
            input_ids, attention_mask, token_type_ids = prepro_sentence_pair_single(
                demonstration + inputs_, outputs_, max_length, bos_token_id, eos_token_id,
                allow_truncation=True)
            input_ids = torch.LongTensor(input_ids).cuda()
            attention_mask = torch.LongTensor(attention_mask).cuda()
            # token_type_ids only marks the answer positions for the loss. It is
            # NOT passed to the model: GPT-2-family models embed token_type_ids
            # and add them to the hidden states, which would perturb the scores.
            label_mask = torch.LongTensor(token_type_ids).cuda()[..., 1:].contiguous()
            with torch.no_grad():
                lm_outputs = upper_bound_inference_model(input_ids=input_ids, attention_mask=attention_mask)
                logits = lm_outputs.logits[..., :-1, :].contiguous()
                labels = input_ids[..., 1:].contiguous()
                losses = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
                losses = torch.sum(losses * label_mask) / torch.sum(label_mask)
            option_losses.append(losses.item())
        return option_losses

    bar = tqdm(range(len(eval_examples) * M_sample), desc=f"Retrieve oracle examples for dev set, eval: {len(eval_examples)}, demo: {M_sample}")
    all_selected_indices = []
    for dp_idx, dp in enumerate(eval_examples):
        inputs, outputs, _ = _prepro_each_datapoint(
            dp, is_first=False, for_unlabeled_demonstrations=False, add_newlines=True, tokenizer_gpt=tokenizer_gpt)
        # Invariant across the sampled demonstrations for this eval example.
        cur_max_output_len = max(len(x) for x in outputs)
        input_len = len(inputs[0])
        gold = dp['label']

        sampling_probs = demonstrations_freqs / np.linalg.norm(demonstrations_freqs, ord=1)
        sample_indices = np.random.choice(range(len(demonstrations)), M_sample, p=sampling_probs)  # select demos
        all_losses = []
        for sample_index in sample_indices:
            demonstration = demonstrations[sample_index]
            max_length = min(len(demonstration) + input_len + cur_max_output_len + 10, maximum_input_len)
            example_losses = _option_losses(demonstration, inputs, outputs, max_length)
            # Heavily penalise demonstrations that lead to a wrong prediction.
            penalty = 1000 if np.argmin(example_losses) != gold else 0
            all_losses.append(penalty + example_losses[gold])
            bar.update(1)
        assert len(all_losses) == len(sample_indices)

        order = np.argsort(all_losses)
        oracle_indices = [sample_indices[i] for i in order]
        if all_losses[order[0]] > 1000:
            # Every sampled demonstration led to a wrong prediction.
            continue
        selected_indices = oracle_indices[:2]

        if dp_idx == 0:
            oracle_array = np.array(oracle_indices)
            print('selected_indices: ', selected_indices, 'oracle_examples_indices:', list(oracle_array // num_options), 'oracle_examples_label_indices:', list(oracle_array % num_options))

        # Boost the best demonstrations more: rank 0 gets 30*k, rank 1 gets 30*(k-1), ...
        for rank, oracle_index in enumerate(selected_indices):
            demonstrations_freqs[oracle_index] += 30 * (len(selected_indices) - rank)

        all_selected_indices.extend(int(i) for i in selected_indices)
    bar.close()
    return list(set(all_selected_indices)), demonstrations_freqs

def in_dev_multi_prompt_retrieval(train_embs,test_embs,train_examples,eval_examples,return_string,format_example,
                     maximum_input_len,args, label_map,prompt_identifier='prompts',single_context_example_len=None, inference_model=None, tokenizer_gpt=None):
    """Rank ``args.M_sample`` random 10-demonstration sets by dev-set accuracy.

    Demonstration index ``d`` is training example ``d // num_options`` paired
    with option ``d % num_options``.

    Candidate sets are sampled without duplicates:
    - Label-based tasks draw ``10 // num_options`` demonstrations per option
      and shuffle them.
    - Multiple-choice tasks draw 10 demonstrations uniformly.

    For every eval example, the set's demonstrations are ordered by cosine
    similarity to the test embedding (most similar placed last, next to the
    test input). They are added greedily while the prompt fits
    ``maximum_input_len``. The prediction is the option with the lowest
    answer loss.

    Side effects: overwrites "input"/"options" (and "output" for eval
    examples) on the example dicts; creates
    ``<args.output_dir>/<prompt_identifier>``; prints per-set accuracy.

    Returns:
        Flat list of demonstration indices: every candidate set's members,
        concatenated from the most to the least accurate set.

    Raises:
        ValueError: for generation tasks ('xsum', 'nq'), which have no fixed
            option set (previously a NameError on ``all_labels``).
    """
    prompt_cache_dir = os.path.join(args.output_dir,prompt_identifier)
    os.makedirs(prompt_cache_dir, exist_ok=True)

    mc_tasks = ['hellaswag', 'commonsense_qa', 'piqa', 'cosmos_qa', 'copa']
    if args.task_name in ['xsum', 'nq']:
        raise ValueError(f"in_dev_multi_prompt_retrieval needs a fixed option set; task {args.task_name} is not supported")
    if args.task_name not in mc_tasks:
        all_labels = list(label_map.values())

    for dp in train_examples:
        dp['input'] = format_example(dp, label_map=label_map, args=args)[0]
        dp['options'] = dp['endings'] if args.task_name in mc_tasks else all_labels
    for dp in eval_examples:
        dp['input'], dp['output'] = format_example(dp, label_map=label_map, args=args)
        dp['options'] = dp['endings'] if args.task_name in mc_tasks else all_labels

    if isinstance(inference_model, MetaICLModel):
        upper_bound_inference_model = inference_model.model
    else:
        upper_bound_inference_model = inference_model

    # Two renderings of every demonstration: without a leading separator (for
    # the front of the prompt) and with one (for every later position).
    demonstrations = []
    nonfirst_demonstrations = []
    for dp in train_examples:
        demonstrations.extend(_prepro_each_datapoint(
            dp, is_first=True, for_unlabeled_demonstrations=True,
            add_newlines=True, tokenizer_gpt=tokenizer_gpt))  # U_tr
        nonfirst_demonstrations.extend(_prepro_each_datapoint(
            dp, is_first=False, for_unlabeled_demonstrations=True,
            add_newlines=True, tokenizer_gpt=tokenizer_gpt))  # U_tr
    # Option count used to decode demonstration indices; assumed uniform.
    num_options = len(train_examples[-1]['options'])
    assert len(demonstrations) == num_options * len(train_examples)  # U_tr

    M_sample = args.M_sample
    num_demos_per_set = 10

    bos_token_id = tokenizer_gpt.bos_token_id
    eos_token_id = tokenizer_gpt.eos_token_id
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    def _option_losses(demonstration, inputs, outputs, max_length):
        """Mean LM loss over each option's answer tokens with ``demonstration`` as context."""
        option_losses = []
        for inputs_, outputs_ in zip(inputs, outputs):
            input_ids, attention_mask, token_type_ids = prepro_sentence_pair_single(
                demonstration + inputs_, outputs_, max_length, bos_token_id, eos_token_id,
                allow_truncation=True)
            input_ids = torch.LongTensor(input_ids).cuda()
            attention_mask = torch.LongTensor(attention_mask).cuda()
            # token_type_ids only marks the answer positions for the loss. It is
            # NOT passed to the model: GPT-2-family models embed token_type_ids
            # and add them to the hidden states, which would perturb the scores.
            label_mask = torch.LongTensor(token_type_ids).cuda()[..., 1:].contiguous()
            with torch.no_grad():
                lm_outputs = upper_bound_inference_model(input_ids=input_ids, attention_mask=attention_mask)
                logits = lm_outputs.logits[..., :-1, :].contiguous()
                labels = input_ids[..., 1:].contiguous()
                losses = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
                losses = torch.sum(losses * label_mask) / torch.sum(label_mask)
            option_losses.append(losses.item())
        return option_losses

    def _build_context(ranked_demo_indices, input_len, cur_max_output_len):
        """Greedily pack demonstrations (best first) and return their concatenated ids.

        The best-ranked demonstration ends up last (adjacent to the test input).
        Whichever demonstration actually lands at the front uses its
        no-separator rendering, even when the length budget stops packing early.
        """
        budget = maximum_input_len - input_len - cur_max_output_len
        picked = []
        used = 0  # length of the picked demos in their non-first rendering
        for demo_index in ranked_demo_indices:
            # A new pick becomes the front of the prompt, so check its first rendering.
            if used + len(demonstrations[demo_index]) > budget:
                break
            picked.append(demo_index)
            used += len(nonfirst_demonstrations[demo_index])
        context = []
        for position, demo_index in enumerate(reversed(picked)):
            rendering = demonstrations if position == 0 else nonfirst_demonstrations
            context += rendering[demo_index]
        return context

    candidate_indices_set = []
    seen_combos = set()
    while len(candidate_indices_set) < M_sample:
        if args.task_name not in mc_tasks:
            # Label-balanced: equal number of demonstrations per option.
            per_label_num = num_demos_per_set // num_options
            indices = []
            for label_idx in range(num_options):
                indices.extend(random.sample(range(label_idx, len(demonstrations), num_options), per_label_num))
            random.shuffle(indices)
        else:
            indices = random.sample(range(len(demonstrations)), num_demos_per_set)
        new_combo = tuple(indices)
        if new_combo not in seen_combos:
            seen_combos.add(new_combo)
            candidate_indices_set.append(new_combo)

    # Tokenisation of the eval examples does not depend on the candidate set.
    prepared_eval = []
    for dp in eval_examples:
        inputs, outputs, _ = _prepro_each_datapoint(
            dp, is_first=False, for_unlabeled_demonstrations=False, add_newlines=True, tokenizer_gpt=tokenizer_gpt)
        prepared_eval.append((dp, inputs, outputs))

    bar = tqdm(range(len(eval_examples) * M_sample), desc=f"Calculate acc examples for dev set, eval: {len(eval_examples)}, demo: {M_sample}")
    all_M_sample_acc = []
    for step_i in range(M_sample):
        sample_indices = np.array(candidate_indices_set[step_i])
        sample_embeds = torch.tensor(train_embs[sample_indices // num_options])

        correct = 0
        for dp_idx, (dp, inputs, outputs) in enumerate(prepared_eval):
            cur_max_output_len = max(len(x) for x in outputs)
            input_len = len(inputs[0])
            test_e_reshape = test_embs[dp_idx].reshape(1, -1)
            scores = cos(torch.tensor(test_e_reshape), sample_embeds).numpy()
            ranked = [sample_indices[i] for i in np.argsort(-scores)]
            demonstration = _build_context(ranked, input_len, cur_max_output_len)

            example_losses = _option_losses(
                demonstration, inputs, outputs, len(demonstration) + input_len + cur_max_output_len)
            if np.argmin(example_losses) == dp['label']:
                correct += 1
            bar.update(1)

        total = len(prepared_eval)
        in_dev_acc = correct * 1.0 / total if total else 0.0
        print('in_dev_acc: ', in_dev_acc, 'oracle_examples_indices:', list(sample_indices // num_options), 'oracle_examples_label_indices:', list(sample_indices % num_options))
        all_M_sample_acc.append(in_dev_acc)
    bar.close()

    all_selected_indices = []
    for acc_index in np.argsort(-np.array(all_M_sample_acc)):
        all_selected_indices.extend(candidate_indices_set[acc_index])
    return all_selected_indices

def in_dev_acc_prompt_retrieval(train_embs,test_embs,train_examples,eval_examples,return_string,format_example,
                     maximum_input_len,args, label_map,prompt_identifier='prompts',single_context_example_len=None, inference_model=None, tokenizer_gpt=None, inference_data_module=None):

    prompt_cache_dir = os.path.join(args.output_dir,prompt_identifier)
    if not os.path.isdir(prompt_cache_dir):
        os.makedirs(prompt_cache_dir, exist_ok=True)

    if not args.task_name in ['hellaswag','xsum','nq','commonsense_qa','piqa','cosmos_qa', 'copa']:
        all_labels = []
        label_to_digit = {}
        for k, v in label_map.items():
            all_labels.append(v)
            label_to_digit[v] = k
    
    for i, dp in enumerate(train_examples):
        dp['input'] = format_example(dp, label_map=label_map, args=args)[0]
        if args.task_name in ['hellaswag', 'commonsense_qa', 'piqa', 'cosmos_qa', 'copa']:
            dp['options'] = dp['endings']
        else:
            dp['options'] = all_labels
    for i, dp in enumerate(eval_examples):
        dp['input'], dp['output'] = format_example(dp, label_map=label_map, args=args)
        if args.task_name in ['hellaswag', 'commonsense_qa', 'piqa', 'cosmos_qa', 'copa']:
            dp['options'] = dp['endings']
        else:
            dp['options'] = all_labels


    if isinstance(inference_model, MetaICLModel):
        upper_bound_inference_model = inference_model.model
        # upper_bound_inference_model = inference_model
    else:
        upper_bound_inference_model = inference_model
    
    demonstrations = []
    for i, dp in enumerate(train_examples):
        input_ = _prepro_each_datapoint(
            dp, is_first=True, for_unlabeled_demonstrations=True,
            add_newlines=True, tokenizer_gpt=tokenizer_gpt)
        demonstrations.extend(input_) # U_tr
        # demonstrations.append(input_)
    nonfirst_demonstrations = []
    for i, dp in enumerate(train_examples):
        input_ = _prepro_each_datapoint(
            dp, is_first=False, for_unlabeled_demonstrations=True,
            add_newlines=True, tokenizer_gpt=tokenizer_gpt)
        nonfirst_demonstrations.extend(input_) # U_tr
    assert len(demonstrations) == len(dp['options']) * len(train_examples) # U_tr
    # assert len(demonstrations) == len(train_examples)


    M_sample = args.M_sample
    K_shot_size = args.K_shot_size

    bos_token_id = tokenizer_gpt.bos_token_id
    eos_token_id = tokenizer_gpt.eos_token_id
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    bar = tqdm(range(len(eval_examples) * M_sample), desc=f"Calculate acc examples for dev set, eval: {len(eval_examples)}, demo: {M_sample}")
    all_selected_indices = []

    candidate_indices_set = []
    while True:
        indices = random.sample(range(len(demonstrations)), K_shot_size)
        new_combo = tuple(indices)
        if new_combo not in candidate_indices_set:
            candidate_indices_set.append(new_combo)
        if len(candidate_indices_set) >= M_sample:
            break
    candidate_indices_set = [list(_) for _ in candidate_indices_set]

    # if not args.task_name in ['hellaswag','xsum','nq','commonsense_qa','piqa','cosmos_qa', 'copa']:
    #     per_label_num = M_sample // len(dp['options'])
    #     label_num = len(dp['options'])
    #     per_label_indices = [[demo_index for demo_index in range(i,len(demonstrations),label_num)] for i in range(label_num)]
    #     assert len(per_label_indices) == label_num
    #     for per_label_indices_ in per_label_indices:
    #         candidate_indices_set.extend(random.sample(per_label_indices_, per_label_num))
    #     assert len(candidate_indices_set) == per_label_num * label_num
    #     if len(candidate_indices_set) < M_sample:
    #         while True:
    #             if len(candidate_indices_set) == M_sample:
    #                 break
    #             selected_candidate_index = random.choice(range(len(demonstrations)))
    #             if selected_candidate_index not in candidate_indices_set:
    #                 candidate_indices_set.append(selected_candidate_index)
    #     random.shuffle(candidate_indices_set)
    # else:
    #     candidate_indices_set = random.sample(range(len(demonstrations)), M_sample)

    assert len(candidate_indices_set) == M_sample
    
    print('candidate_indices_set is: ', candidate_indices_set, '. length is: ', len(candidate_indices_set))
    all_M_sample_acc = []
    for step_i in range(M_sample):
        sample_indices = np.array(candidate_indices_set[step_i])
        sample_embeds_indices = np.array(sample_indices // len(dp['options']))
        correct = 0
        total = 0
        for dp_idx, dp in enumerate(eval_examples):
            test_e_reshape = test_embs[dp_idx].reshape(1, -1)
            scores = cos(torch.tensor(test_e_reshape), torch.tensor(train_embs[sample_embeds_indices])).numpy()
            demo_sorted_indices = np.argsort(-scores)
            demonstration = []
            for demo_i, demo_idx in enumerate(demo_sorted_indices):
                if demo_i == len(demo_sorted_indices) - 1:
                    added_demo = demonstrations[sample_indices[demo_idx]]
                else:
                    added_demo = nonfirst_demonstrations[sample_indices[demo_idx]]
                demonstration = added_demo + demonstration
            inputs, outputs, answer = _prepro_each_datapoint(
                dp, is_first=False, for_unlabeled_demonstrations=False, add_newlines=True, tokenizer_gpt=tokenizer_gpt)
            
            one_test_instance_input_text,one_test_instance_output_text = format_example(example=dp,args=args,
                                                                                label_map=label_map)
            cur_prompt_string_len = get_instance_length(one_test_instance_input_text,one_test_instance_output_text,tokenizer_gpt)[0]

            cur_max_output_len = max(map(lambda x: len(x), outputs))
            input_len = len(inputs[0])
            example_losses = []
            for inputs_, outputs_ in zip(inputs, outputs):
                inputs_ = demonstration + inputs_
                encoded = prepro_sentence_pair_single(
                    inputs_, outputs_, len(demonstration)+input_len+cur_max_output_len+10, bos_token_id, eos_token_id,
                    allow_truncation=True)

                # input_ids.append(encoded[0])
                # attention_mask.append(encoded[1])
                # token_type_ids.append(encoded[2])
            
                inference_inputs = dict(input_ids=torch.LongTensor(encoded[0]).cuda(),
                attention_mask=torch.LongTensor(encoded[1]).cuda(),
                token_type_ids=torch.LongTensor(encoded[2]).cuda())

                with torch.no_grad():
                    lm_outputs = upper_bound_inference_model(**inference_inputs)
                    logits = lm_outputs.logits[..., :-1, :].contiguous()
                    labels = inference_inputs['input_ids'][..., 1:].contiguous()
                    label_mask = inference_inputs['token_type_ids'][..., 1:].contiguous()

                    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
                    losses = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) # [batch_size, length]
                    losses = losses * label_mask
                    losses = torch.sum(losses) / torch.sum(label_mask)
                    example_losses.append(losses.item())
            if np.argmin(example_losses) != dp['label']: ##error case
                total+=1
            else:
                total+=1
                correct+=1

            bar.update(1)

        in_dev_acc = correct * 1.0 / total
        print('in_dev_acc: ', in_dev_acc, 'candidata_indices: ', candidate_indices_set[step_i])

        all_M_sample_acc.append(in_dev_acc)

        # oracle_examples_indices = np.array(oracle_indices)
        # oracle_indices = np.array(oracle_indices)

        # selected_indices = []
        # for idx in range(len(oracle_examples_indices)):
        #     cur_example_input_text,cur_example_output_text = format_example(example=train_examples[oracle_examples_indices[idx]],
        #                                                                     args=args,label_map=label_map)
        #     cur_len = sum(get_instance_length(cur_example_input_text, cur_example_output_text,tokenizer=tokenizer_gpt))
        #     cur_prompt_string_len += cur_len
        #     if cur_prompt_string_len > 1024:
        #         break
        #     if losses_list[idx] > 1000:
        #         break
        #     selected_indices.append(oracle_examples_indices[idx])
        
        
        # if dp_idx == 0:
        #     print('selected_indices: ', selected_indices, 'oracle_examples_indices:', list(oracle_examples_indices), 'oracle_examples_label_indices:', list(oracle_examples_label_indices))


        # all_selected_indices.extend([int(_) for _ in selected_indices])
        # with open(os.path.join(prompt_cache_dir,f"{dp['id']}.json"),'w') as f:
        #     json.dump([[int(dp_idx), [[int(_)] for _ in selected_indices], dp['label']],
        #                 cur_train_data,
        #                 dp
        #                 ], f, indent=4)
    acc_sort_indices = np.argsort(-np.array(all_M_sample_acc))

    # for acc_index in acc_sort_indices:
    #     all_selected_indices.append(candidate_indices_set[acc_index])
    # return all_selected_indices
    for acc_index in acc_sort_indices:
        all_selected_indices.extend(list(candidate_indices_set[acc_index]))
    return all_selected_indices

def fast_votek(embeddings,select_num,k,vote_file=None):
    """Greedily select ``select_num`` representative, diverse examples (fast vote-k).

    Every example votes for its ``k`` most cosine-similar examples (the top-ranked
    entry, normally the example itself, is excluded). Examples are then picked one
    at a time by vote score. A voter that already supports ``t`` selected examples
    contributes only ``10 ** -t``, which discounts regions that are already covered.

    Args:
        embeddings: 2-D array-like of embeddings, one row per example.
        select_num: number of indices to select.
        k: number of neighbours each example votes for.
        vote_file: optional path of a JSON cache of the vote statistics. It is
            loaded if it exists; otherwise the votes are computed and written there.

    Returns:
        List of distinct int indices into ``embeddings``, of length ``select_num``.
        It is shorter only when ``select_num`` exceeds ``len(embeddings)``.
    """
    n = len(embeddings)
    if vote_file is not None and os.path.isfile(vote_file):
        with open(vote_file) as f:
            # JSON object keys are always strings. Convert them back to int so they
            # compare equal to the int indices kept in ``selected_indices``.
            # Otherwise already-selected examples are never excluded and get picked again.
            vote_stat = {int(key): voters for key, voters in json.load(f).items()}
    else:
        bar = tqdm(range(n),desc=f'voting')
        vote_stat = defaultdict(list)
        for i in range(n):
            cur_emb = embeddings[i].reshape(1, -1)
            cur_scores = np.sum(cosine_similarity(embeddings, cur_emb), axis=1)
            # k best neighbours, dropping the single most similar entry (normally i itself)
            sorted_indices = np.argsort(cur_scores).tolist()[-k-1:-1]
            for idx in sorted_indices:
                if idx!=i:
                    vote_stat[idx].append(i)
            bar.update(1)
        if vote_file is not None:
            with open(vote_file,'w') as f:
                json.dump(vote_stat,f)
    votes = sorted(vote_stat.items(),key=lambda x:len(x[1]),reverse=True)
    selected_indices = []
    selected_set = set()  # O(1) membership mirror of selected_indices
    selected_times = defaultdict(int)
    while len(selected_indices)<select_num:
        cur_scores = defaultdict(int)
        for idx,candidates in votes:
            if idx in selected_set:
                continue
            for one_support in candidates:
                if one_support not in selected_set:
                    cur_scores[idx] += 10 ** (-selected_times[one_support])
        if cur_scores:
            cur_selected_idx = max(cur_scores.items(),key=lambda x:x[1])[0]
        else:
            # No unselected example has any remaining unselected voter. Fall back to the
            # lowest unselected index rather than re-picking a selected one.
            cur_selected_idx = next((i for i in range(n) if i not in selected_set), None)
            if cur_selected_idx is None:
                break
        cur_selected_idx = int(cur_selected_idx)
        selected_indices.append(cur_selected_idx)
        selected_set.add(cur_selected_idx)
        for idx_support in vote_stat.get(cur_selected_idx, []):
            selected_times[idx_support] += 1
    return selected_indices

def iterative_selection(train_embs,test_embs,train_examples,test_examples,return_string,format_example,maximum_input_len,
                        label_map,single_context_example_len,inference_model,inference_data_module,tokenizer_gpt,args):
    """Grow the annotated set round by round using inference-model feedback.

    The first ``args.batch_size`` indices come from random sampling
    (``least_confidence``) or :func:`fast_votek` (``votek``). Each round then:

    1. builds prompts for ``test_examples`` from the currently selected training
       examples with ``prompt_retrieval``;
    2. runs the inference model on every prompt file, repeating the pass until an
       output file exists for each prompt;
    3. reads one score per example index from the saved predictions;
    4. adds new indices. ``least_confidence`` takes up to ``args.batch_size`` unselected
       indices in score order. ``votek`` runs a vote-k pass over consecutive
       segments of the score-ordered list.

    Rounds repeat until ``args.annotation_size`` indices are selected.

    Returns:
        List of selected training-example indices.

    Raises:
        ValueError: if ``args.selective_annotation_method`` is neither
            ``'least_confidence'`` nor ``'votek'``.
    """
    if args.selective_annotation_method=='least_confidence':
        selected_indices = random.sample(range(len(train_examples)),args.batch_size)
    elif args.selective_annotation_method=='votek':
        selected_indices = fast_votek(embeddings=train_embs,select_num=args.batch_size,k=150,
                                      vote_file=os.path.join(args.output_dir,'votek_cache.json'))
    else:
        raise ValueError(f'iterative selection does not support {args.selective_annotation_method}')
    if not args.task_name in ['hellaswag','xsum','nq','commonsense_qa','piqa','cosmos_qa', 'copa']:
        # Label-set tasks: every demonstration is given the full list of label strings as options.
        all_labels = list(label_map.values())
    batch_count = 0
    device = torch.device('cuda')
    while len(selected_indices)<args.annotation_size:
        batch_count += 1
        cur_annotated_examples = [train_examples[idx] for idx in selected_indices]
        prompt_retrieval(train_embs=train_embs[selected_indices],
                         test_embs=test_embs,
                         train_examples=cur_annotated_examples,
                         eval_examples=test_examples,
                         return_string=return_string,
                         format_example=format_example,
                         maximum_input_len=maximum_input_len,
                         args=args,label_map=label_map,
                         prompt_identifier=f'prompts_{batch_count}',
                         single_context_example_len=single_context_example_len)

        candidate_prompt_files = os.listdir(os.path.join(args.output_dir,f'prompts_{batch_count}'))
        prompt_files = [f for f in candidate_prompt_files if f.endswith('.json')]
        assert len(prompt_files) == len(test_examples), f"len(prompt_files)={len(prompt_files)}," \
                                                                  f"len(processed_eval_examples)={len(test_examples)}"
        output_dir = os.path.join(args.output_dir,f'results_iteration_{batch_count}')
        prompt_dir = os.path.join(args.output_dir,f'prompts_{batch_count}')
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir, exist_ok=True)
        count = 0
        execution_count = 0
        if args.model_key is not None:
            model_keys = args.model_key.split('##')
        # Keep sweeping until every prompt file has a matching result file
        # (the nq branch may fail transiently and is retried on the next sweep).
        running_flag = True
        while running_flag:
            running_flag = False
            count += 1
            bar = tqdm(range(len(prompt_files)), desc=f"  prediction iteration {batch_count}")
            for file in prompt_files:
                bar.update(1)
                if not os.path.isfile(os.path.join(output_dir,file)):
                    running_flag = True

                    if args.task_name in ['hellaswag', 'commonsense_qa', 'piqa', 'cosmos_qa', 'copa']:
                        with open(os.path.join(prompt_dir, file)) as f:
                            one_test_example = json.load(f)
                        cur_train_data = one_test_example[1]
                        cur_input = {'input': format_example(one_test_example[2], label_map=label_map, args=args)[0],
                                     'options': one_test_example[2]['endings']}
                        inference_data_module.k = len(cur_train_data)
                        inference_data_module.tensorize(cur_train_data, [cur_input])
                        prediction = inference_model.do_predict(inference_data_module, require_loss=True)[0]
                        with open(f"{output_dir}/{file}", 'w') as f:
                            json.dump(prediction, f)
                    elif args.task_name=='xsum':
                        with open(os.path.join(prompt_dir, file)) as f:
                            one_test_example = json.load(f)
                        context = one_test_example[1]
                        input_ids = tokenizer_gpt(context, return_tensors="pt").input_ids
                        input_ids = input_ids[:, :1900]
                        input_len = input_ids.shape[1]
                        input_ids = input_ids.to(device)
                        gen_tokens = inference_model.generate(input_ids, do_sample=False, temperature=0.7,
                                                              max_length=input_len + 64,
                                                              output_scores=True, return_dict_in_generate=True)
                        generated_text = tokenizer_gpt.batch_decode(gen_tokens.sequences.view(-1, 1))  #
                        stop = ['--', '\n', ';', '#']
                        stop_index = len(generated_text)
                        for i, c in enumerate(generated_text):
                            if i > input_len and c.strip(' ') in stop:
                                stop_index = i
                                break
                        prediction = [' '.join(generated_text[input_len:stop_index]),
                                      sum(gen_tokens.probs[:stop_index - input_len])]
                        with open(f"{output_dir}/{file}", 'w') as f:
                            json.dump(prediction, f)
                    elif args.task_name=='nq':
                        cur_key = model_keys[execution_count % len(model_keys)]
                        execution_count += 1
                        try:
                            codex_execution(key=cur_key,output_path=os.path.join(output_dir,file),
                                            prompt_path=os.path.join(prompt_dir, file))
                        except Exception as e:
                            # best-effort: log and retry this file on the next sweep
                            print(e)
                            time.sleep(3)
                    else:
                        with open(os.path.join(prompt_dir, file)) as f:
                            one_test_example = json.load(f)
                        cur_train_data = one_test_example[1]
                        for idx in range(len(cur_train_data)):
                            cur_train_data[idx]['options'] = all_labels
                        cur_input = format_example(one_test_example[2],label_map=label_map,args=args)[0]
                        inference_data_module.k = len(cur_train_data)
                        inference_data_module.tensorize(cur_train_data, [cur_input], options=all_labels)
                        prediction = inference_model.do_predict(inference_data_module, require_loss=True)[0]
                        with open(f"{output_dir}/{file}", 'w') as f:
                            json.dump(prediction, f)


        # Score every example. Already-selected indices get +/-inf so that they sort to
        # the end of the candidate list in either ordering.
        idx_scores = {}
        n = len(test_examples)
        for idx in range(n):
            if idx in selected_indices:
                if args.task_name in ['xsum','nq']:
                    idx_scores[idx] = float('inf')
                else:
                    idx_scores[idx] = float('-inf')
                continue
            with open(f"{output_dir}/{idx}.json") as f:
                one_pred = json.load(f)
                if args.task_name in ['nq']:
                    # mean token log-probability of the first returned choice
                    idx_scores[idx] = sum(one_pred['choices'][0]["logprobs"]["token_logprobs"]) / len(
                        one_pred['choices'][0]["logprobs"]["token_logprobs"])
                else:
                    # NOTE(review): assumes element 1 of do_predict's output is a loss-like score — confirm
                    idx_scores[idx] = one_pred[1]
        if args.task_name in ['xsum','nq']:
            sorted_scores = sorted(idx_scores.items(), key=lambda x: x[1])
        else:
            sorted_scores = sorted(idx_scores.items(), key=lambda x:x[1],reverse=True)
        sorted_scores_len = len(sorted_scores)
        if args.selective_annotation_method=='least_confidence':
            cur_selected = []
            cur_select_num = min(args.batch_size, args.annotation_size - len(selected_indices))
            for sorted_scores_iter in range(sorted_scores_len):
                if len(cur_selected)>=cur_select_num:
                    break
                if not sorted_scores[sorted_scores_iter][0] in selected_indices:
                    cur_selected.append(sorted_scores[sorted_scores_iter][0])
            selected_indices += cur_selected
        else:
            with open(os.path.join(args.output_dir,'votek_cache.json')) as f:
                vote_stat = json.load(f)
            # JSON object keys are strings: every vote_stat lookup below must use str(idx).
            selected_times = defaultdict(int)
            select_num_1 = args.annotation_size - len(selected_indices)
            # Split the score-ordered list into segments and take one example per segment.
            # max(1, ...) keeps a segment non-empty when the pool is small relative to the budget.
            inter = max(1, int(len(train_examples) * 0.9 / select_num_1))
            for prev_idx in selected_indices:
                for idx_support in vote_stat.get(str(prev_idx), []):
                    selected_times[idx_support] += 1
            count_t = 0
            while len(selected_indices) < args.annotation_size and count_t * inter < sorted_scores_len:
                cur_scores = defaultdict(int)
                for idx, _ in sorted_scores[count_t * inter:(count_t + 1) * inter]:
                    # check selection first so a selected index without votes cannot be re-picked
                    if idx in selected_indices:
                        cur_scores[idx] = -100
                        continue
                    if not str(idx) in vote_stat:
                        cur_scores[idx] = 0
                        continue
                    candidates = vote_stat[str(idx)]
                    for one_support in candidates:
                        if not one_support in selected_indices:
                            cur_scores[idx] += 10 ** (-selected_times[one_support])
                if not cur_scores:
                    # nothing scorable in this segment; move on to the next one
                    count_t += 1
                    continue
                cur_selected_idx = max(cur_scores.items(), key=lambda x: x[1])[0]
                selected_indices.append(cur_selected_idx)
                for idx_support in vote_stat.get(str(cur_selected_idx), []):
                    selected_times[idx_support] += 1
                count_t += 1
            if len(selected_indices) < args.annotation_size:
                already_selected = set(selected_indices)
                unselected_indices = [i for i in range(len(train_examples)) if i not in already_selected]
                # compute before sampling so the log reports the real number of random picks
                num_random = args.annotation_size - len(selected_indices)
                selected_indices += random.sample(unselected_indices, num_random)
                print(f"{num_random} examples are randomly selected")
    return selected_indices

def get_dpp_kernel(embed, perplexities, args):
    """Build a perplexity-weighted DPP kernel over candidate examples.

    Diversity comes from pairwise cosine similarity of the candidate embeddings,
    shifted into [0, 1]. Quality comes from the reciprocal perplexity, so lower
    perplexity gives a larger weight.

    Args:
        embed: 2-D array with one embedding row per candidate.
        perplexities: per-candidate perplexities, same length as ``embed``.
        args: namespace providing ``scale_factor``. Larger values flatten the
            quality weights towards 1 and so favour diversity.

    Returns:
        ``(weights, kernel_matrix)``: the per-candidate quality weights (max 1.0),
        and the ``(N, N)`` kernel ``weights[i] * sim[i, j] * weights[j]``.
    """
    assert embed.shape[0] == len(perplexities)
    # L2-normalise each row so that embed @ embed.T is the pairwise cosine similarity.
    # Dividing by the norm of the whole matrix instead would shrink every similarity
    # towards 0 and leave the kernel almost constant.
    embed = np.asarray(embed, dtype=float)
    row_norms = np.linalg.norm(embed, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0  # keep all-zero rows at zero instead of producing NaN
    embed = embed / row_norms

    # reciprocal: lower perplexity -> larger quality score; then scale to unit norm
    perplexities = 1.0 / np.asarray(perplexities, dtype=float)
    perplexities = perplexities / np.linalg.norm(perplexities)

    # to prevent overflow error
    perplexities -= perplexities.max()

    # to balance relevance and diversity
    perplexities = np.exp(perplexities / (2 * args.scale_factor))

    # cosine similarity is in [-1, 1]; map it to [0, 1] so the kernel is non-negative
    sim_matrix = np.matmul(embed, embed.T)
    sim_matrix = (sim_matrix + 1) / 2

    kernel_matrix = perplexities[:, None] * sim_matrix * perplexities[None]
    return perplexities, kernel_matrix

def get_vanilla_dpp_kernel(embed, args):
    """Build a similarity-only DPP kernel: pairwise cosine similarity mapped to [0, 1].

    Args:
        embed: 2-D array with one embedding row per candidate.
        args: unused; kept for interface parity with ``get_dpp_kernel``.

    Returns:
        ``(N, N)`` array ``(cos(e_i, e_j) + 1) / 2``.
    """
    # L2-normalise each row so that embed @ embed.T is the pairwise cosine similarity.
    # Dividing by the norm of the whole matrix would shrink all similarities towards 0.
    embed = np.asarray(embed, dtype=float)
    row_norms = np.linalg.norm(embed, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0  # keep all-zero rows at zero instead of producing NaN
    embed = embed / row_norms

    # to make kernel-matrix non-negative
    sim_matrix = np.matmul(embed, embed.T)
    sim_matrix = (sim_matrix + 1) / 2

    return sim_matrix

def fast_map_dpp(kernel_matrix, max_length, epsilon=1e-10):
    """
    fast implementation of the greedy algorithm
    reference: https://github.com/laming-chen/fast-map-dpp/blob/master/dpp_test.py
    paper: Fast Greedy MAP Inference for Determinantal Point Process to Improve Recommendation Diversity

    Greedily add the item with the largest marginal gain in log-determinant, using
    incremental Cholesky updates. Stops after ``max_length`` items, or earlier once
    the best remaining gain falls below ``epsilon``.

    Args:
        kernel_matrix: square ``(N, N)`` PSD kernel.
        max_length: maximum number of items to return. Values <= 0 return ``[]``.
        epsilon: minimum residual variance for an item to still be added.

    Returns:
        List of int item indices in selection order.
    """
    item_size = kernel_matrix.shape[0]
    if max_length <= 0 or item_size == 0:
        return []
    # at most item_size items can ever be selected, so no more Cholesky rows are needed
    max_length = min(max_length, item_size)
    cis = np.zeros((max_length, item_size))
    di2s = np.copy(np.diag(kernel_matrix))
    selected_items = list()
    selected_item = np.argmax(di2s)
    selected_items.append(int(selected_item))
    while len(selected_items) < max_length:
        k = len(selected_items) - 1
        ci_optimal = cis[:k, selected_item]
        di_optimal = math.sqrt(di2s[selected_item])
        elements = kernel_matrix[selected_item, :]
        eis = (elements - np.dot(ci_optimal, cis[:k, :])) / di_optimal
        cis[k, :] = eis
        di2s -= np.square(eis)
        di2s[selected_item] = -np.inf  # never re-select an item
        selected_item = np.argmax(di2s)
        if di2s[selected_item] < epsilon:
            break
        selected_items.append(int(selected_item))

    return selected_items

def dpp_selection(embeddings, train_examples, select_num, raw_diversity_num, inference_model, label_map, tokenizer_gpt, format_example, args):
    """Pick ``select_num`` training indices with a similarity-only DPP.

    A greedy pre-filter starts from one random example. It repeatedly adds the
    example whose summed cosine similarity to the already-kept set is smallest,
    until ``raw_diversity_num`` candidates are kept. ``fast_map_dpp`` then runs on
    ``get_vanilla_dpp_kernel`` over those candidates. If the DPP stops early, the
    result is padded with random indices that were not selected.

    ``inference_model``, ``label_map``, ``tokenizer_gpt`` and ``format_example`` are
    unused here; they mirror the signature of the LM-based variants.
    """
    assert raw_diversity_num > select_num
    pool = copy.deepcopy(np.array(embeddings))

    # Stage 1: greedy diversity pre-filter seeded with a random example.
    seed = random.choice(range(len(pool)))
    diverse_ids = [seed]
    kept_reps = pool[seed].reshape(1, -1)
    while len(diverse_ids) < raw_diversity_num:
        total_sim = np.sum(cosine_similarity(pool, kept_reps), axis=1)
        total_sim[diverse_ids] = float('inf')  # never pick an already-kept example
        farthest = np.argmin(total_sim)
        kept_reps = np.concatenate((kept_reps, pool[farthest].reshape(1, -1)), axis=0)
        diverse_ids.append(farthest.item())

    assert len(pool) == len(train_examples)
    assert len(diverse_ids) == raw_diversity_num

    # Stage 2: greedy MAP inference over the pre-filtered pool; map back to corpus indices.
    kernel = get_vanilla_dpp_kernel(pool[diverse_ids], args=args)
    picked = [int(diverse_ids[pos]) for pos in fast_map_dpp(kernel, select_num)]

    # Stage 3: top up with random unselected examples when the DPP stopped early.
    shortfall = select_num - len(picked)
    if shortfall > 0:
        remaining = [i for i in range(len(train_examples)) if i not in picked]
        picked = picked + random.sample(remaining, shortfall)

    return picked

def lm_dpp_selection(embeddings, train_examples, select_num, raw_diversity_num, inference_model, label_map, tokenizer_gpt, format_example, args):
    """Pick ``select_num`` training indices with a perplexity-weighted DPP (LM-DPP).

    1. Greedy diversity pre-filter: start from one random example and keep adding
       the example with the smallest summed cosine similarity to those already
       kept, until ``raw_diversity_num`` candidates exist.
    2. Score each candidate's formatted input by its perplexity under the language
       model (``exp`` of the LM loss).
    3. Build the kernel with ``get_dpp_kernel`` and select with ``fast_map_dpp``.
    4. Pad with random unselected indices if the DPP stopped early.

    Side effect: sets ``tokenizer_gpt.truncation_side = 'right'``.
    """
    assert raw_diversity_num > select_num
    pool = copy.deepcopy(np.array(embeddings))

    # An earlier variant built the candidate pool with KMeans (nearest point to each
    # centroid); the greedy max-min diversity pass below is used instead.
    seed = random.choice(range(len(pool)))
    pool_ids = [seed]
    kept_reps = pool[seed].reshape(1, -1)
    while len(pool_ids) < raw_diversity_num:
        total_sim = np.sum(cosine_similarity(pool, kept_reps), axis=1)
        total_sim[pool_ids] = float('inf')  # exclude examples already in the pool
        farthest = np.argmin(total_sim)
        kept_reps = np.concatenate((kept_reps, pool[farthest].reshape(1, -1)), axis=0)
        pool_ids.append(farthest.item())

    assert len(pool) == len(train_examples)
    assert len(pool_ids) == raw_diversity_num

    # MetaICLModel keeps the underlying causal LM in its `.model` attribute.
    lm = inference_model.model if isinstance(inference_model, MetaICLModel) else inference_model

    candidate_texts = [format_example(train_examples[pid], label_map=label_map)[0] for pid in pool_ids]

    tokenizer_gpt.truncation_side='right'
    perplexities = []
    with torch.no_grad():
        for text in tqdm(candidate_texts, desc='Calculate perplexity for Candidates'):
            # llama configs expose the context size as max_position_embeddings, GPT-style ones as n_positions
            if 'llama' in args.model_cache_dir:
                context_len = lm.config.max_position_embeddings
            else:
                context_len = lm.config.n_positions
            token_ids = tokenizer_gpt.encode(text, add_special_tokens=True, return_tensors="pt",
                                             truncation=True, max_length=context_len).cuda()
            nll = lm(token_ids, labels=token_ids).loss
            perplexities.append(torch.exp(nll).item())

    _, kernel = get_dpp_kernel(pool[pool_ids], perplexities, args=args)

    picked = [int(pool_ids[pos]) for pos in fast_map_dpp(kernel, select_num)]
    shortfall = select_num - len(picked)
    if shortfall > 0:
        remaining = [i for i in range(len(train_examples)) if i not in picked]
        picked = picked + random.sample(remaining, shortfall)

    return picked

def lm_dpp_selection_in_domain_dev(embeddings, train_examples, select_num, raw_diversity_num, inference_model, label_map, tokenizer_gpt, format_example, inference_data_module, args):
    """Choose an in-domain dev set with LM-DPP, then choose demonstrations with it.

    1. Greedy diversity pre-filter of ``raw_diversity_num`` candidates, seeded with
       a random example.
    2. A perplexity-weighted DPP over those candidates picks up to
       ``int(args.in_dev_rate * select_num)`` examples as the in-domain dev set.
    3. ``in_dev_prompt_retrieval`` runs with the dev set as queries against all other
       training examples. It returns retrieved indices and one frequency per
       (remaining example, label) pair. The most frequent pairs fill the rest of the
       ``args.annotation_size`` budget.
    4. If retrieval returned too few indices, its picks are used as-is and the
       budget is topped up with random examples appended to the dev list.

    Side effect: sets ``tokenizer_gpt.truncation_side = 'right'``.

    Returns:
        ``[(example_idx, label_idx), ...]`` for the demonstration picks, followed by
        plain int indices of the dev-set examples, ``args.annotation_size`` entries
        in total.
    """
    assert raw_diversity_num > select_num
    corpus_embeddings = copy.deepcopy(np.array(embeddings))


    # # Perform kmean clustering
    # clustering_model = KMeans(n_clusters=raw_diversity_num).fit(corpus_embeddings)
    # cluster_assignment = clustering_model.labels_
    # cluster_centers = clustering_model.cluster_centers_

    # nearest_points = []
    # for i, center in enumerate(cluster_centers):
    #     cluster_indices = np.where(cluster_assignment == i)[0]
    #     distances = euclidean_distances([center], [corpus_embeddings[j] for j in cluster_indices])
    #     closest_index = np.argmin(distances)
    #     nearest_points.append(cluster_indices[closest_index])

    # print(f'Kmeans Clustering Done...      Cluster nums:{len(cluster_centers)} ...')

    # Perform diversity selection: repeatedly add the example with the smallest
    # summed cosine similarity to the examples already kept.
    nearest_points = []
    first_id = random.choice(range(len(corpus_embeddings)))
    nearest_points.append(first_id)
    selected_representations = corpus_embeddings[first_id].reshape(1, -1)
    for count in range(raw_diversity_num - 1):
        scores = np.sum(cosine_similarity(corpus_embeddings, selected_representations), axis=1)
        for i in nearest_points:
            scores[i] = float('inf')
        min_idx = np.argmin(scores)
        selected_representations = np.concatenate((selected_representations,
                                                corpus_embeddings[min_idx].reshape(1, -1)), axis=0)
        nearest_points.append(min_idx.item())

    assert len(corpus_embeddings) == len(train_examples)
    assert len(nearest_points) == raw_diversity_num

    # Vanilla LM_DPP

    # nearest_points = list(range(len(train_examples)))

    # number of in-domain dev examples requested from the DPP
    in_dev_num = int(args.in_dev_rate*select_num)

    # MetaICLModel keeps the underlying causal LM in its `.model` attribute.
    if isinstance(inference_model, MetaICLModel):
        dpp_inference_model = inference_model.model
    else:
        dpp_inference_model = inference_model

    raw_diversity_examples = [train_examples[nearest_point] for nearest_point in nearest_points]
    raw_diversity_sentences = [format_example(raw_diversity_example, label_map=label_map)[0] for raw_diversity_example in raw_diversity_examples]

    tokenizer_gpt.truncation_side='right'
    perplexities = []
    with torch.no_grad():
        for sentence in tqdm(raw_diversity_sentences, desc='Step1(select in-dev set): Calculate perplexity for Candidates'):
            if 'llama' in args.model_cache_dir:
                inputs = tokenizer_gpt.encode(sentence, add_special_tokens=True, return_tensors="pt", truncation=True, max_length=dpp_inference_model.config.max_position_embeddings).cuda()
            else:
                inputs = tokenizer_gpt.encode(sentence, add_special_tokens=True, return_tensors="pt", truncation=True, max_length=dpp_inference_model.config.n_positions).cuda()
            outputs = dpp_inference_model(inputs, labels=inputs)
            loss = outputs.loss
            perplexity = torch.exp(loss)
            perplexities.append(perplexity.item())

    perplexities, kernel_matrix = get_dpp_kernel(corpus_embeddings[nearest_points], perplexities, args=args)

    selected_indices = fast_map_dpp(kernel_matrix, in_dev_num)
    in_dev_indices = [int(nearest_points[selected_indice]) for selected_indice in selected_indices]
    in_dev_examples = [train_examples[in_dev_index] for in_dev_index in in_dev_indices]

    # Budget left for demonstrations. Size it from the number of dev examples actually
    # returned: fast_map_dpp may stop early (or differ from in_dev_num), and sizing from
    # in_dev_num would then make the final annotation_size assertion fail.
    data_model_select_num = args.annotation_size - len(in_dev_indices)

    # in_dev_multi_retrieval
    # remain_examples_indices = [all_index for all_index in range(len(train_examples)) if all_index not in in_dev_indices]
    # remain_train_examples = [train_examples[all_index] for all_index in remain_examples_indices]
    # assert len(in_dev_examples) + len(remain_train_examples) == len(train_examples)
    # in_dev_selected_indices = in_dev_multi_prompt_retrieval(train_embs=corpus_embeddings[remain_examples_indices],test_embs=corpus_embeddings[in_dev_indices],train_examples=remain_train_examples,
    #                                                                eval_examples=in_dev_examples,return_string=False,format_example=format_example,
    #                                                                maximum_input_len=1000,single_context_example_len=250,label_map=label_map,args=args,inference_model=inference_model,tokenizer_gpt=tokenizer_gpt)
    # if args.task_name in ['hellaswag', 'commonsense_qa', 'piqa', 'cosmos_qa', 'copa']:
    #     label_set_num = len(train_examples[0]['endings'])
    # else:
    #     label_set_num = len(label_map)
    # if len(in_dev_selected_indices) >= data_model_select_num:
    #     data_model_indices = np.array(in_dev_selected_indices[:data_model_select_num])
    #     data_model_labels_indices = (data_model_indices % label_set_num).tolist()
    #     data_model_indices = [remain_examples_indices[data_model_index] for data_model_index in (data_model_indices // label_set_num).tolist()]

    # in_dev_acc_retrieval
    # remain_examples_indices = [nearest_point for nearest_point in nearest_points if nearest_point not in in_dev_indices]
    # remain_train_examples = [train_examples[all_index] for all_index in remain_examples_indices]
    # assert len(in_dev_examples) + len(remain_train_examples) == len(nearest_points)
    # in_dev_selected_indices = in_dev_acc_prompt_retrieval(train_embs=corpus_embeddings[remain_examples_indices],test_embs=corpus_embeddings[in_dev_indices],train_examples=remain_train_examples,
    #                                                                eval_examples=in_dev_examples,return_string=False,format_example=format_example,
    #                                                                maximum_input_len=1000,single_context_example_len=250,label_map=label_map,args=args,inference_model=inference_model,inference_data_module=inference_data_module,tokenizer_gpt=tokenizer_gpt)
    # if args.task_name in ['hellaswag', 'commonsense_qa', 'piqa', 'cosmos_qa', 'copa']:
    #     label_set_num = len(train_examples[0]['endings'])
    # else:
    #     label_set_num = len(label_map)

    # if len(in_dev_selected_indices) >= data_model_select_num:
    #     data_model_indices = np.array(in_dev_selected_indices[:data_model_select_num])
    #     data_model_labels_indices = (data_model_indices % label_set_num).tolist()
    #     data_model_indices = [remain_examples_indices[data_model_index] for data_model_index in (data_model_indices // label_set_num).tolist()]


    # in_dev_one_retrieval
    in_dev_index_set = set(in_dev_indices)
    remain_examples_indices = [all_index for all_index in range(len(train_examples)) if all_index not in in_dev_index_set]
    remain_train_examples = [train_examples[all_index] for all_index in remain_examples_indices]
    assert len(in_dev_examples) + len(remain_train_examples) == len(train_examples)
    in_dev_selected_indices, demonstrations_freqs = in_dev_prompt_retrieval(train_embs=corpus_embeddings[remain_examples_indices],test_embs=corpus_embeddings[in_dev_indices],train_examples=remain_train_examples,
                                                                   eval_examples=in_dev_examples,return_string=False,format_example=format_example,
                                                                   maximum_input_len=1000,single_context_example_len=250,label_map=label_map,args=args, inference_model=inference_model, tokenizer_gpt=tokenizer_gpt)

    if args.task_name in ['hellaswag', 'commonsense_qa', 'piqa', 'cosmos_qa', 'copa']:
        label_set_num = len(train_examples[0]['endings'])
    else:
        label_set_num = len(label_map)
    # One entry per (remaining example, label) pair; a flat index decodes as
    # example_position = idx // label_set_num, label = idx % label_set_num.
    assert len(demonstrations_freqs) == len(remain_train_examples) * label_set_num # U_tr
    if len(in_dev_selected_indices) >= data_model_select_num:
        # take the most frequent (example, label) pairs
        remain_train_examples_freqs = np.asarray(demonstrations_freqs)
        data_model_indices = np.argsort(-remain_train_examples_freqs)[:data_model_select_num]
        data_model_labels_indices = (data_model_indices % label_set_num).tolist()
        data_model_indices = [remain_examples_indices[data_model_index] for data_model_index in (data_model_indices // label_set_num).tolist()]
    else:
        # Too few retrieved pairs: keep them all and top up the budget at random.
        pair_indices = np.asarray(in_dev_selected_indices, dtype=int)
        data_model_indices = [remain_examples_indices[data_model_index] for data_model_index in (pair_indices // label_set_num).tolist()]
        data_model_labels_indices = (pair_indices % label_set_num).tolist()

        taken = set(data_model_indices) | in_dev_index_set
        unselected_indices = [i for i in range(len(train_examples)) if i not in taken]
        # compute before sampling so the log reports the real number of random picks
        num_random = args.annotation_size - len(in_dev_indices) - len(data_model_indices)
        in_dev_indices = in_dev_indices + random.sample(unselected_indices, num_random)
        print(f"{num_random} examples are randomly selected")

    assert len(data_model_indices+in_dev_indices) == args.annotation_size

    data_model_example_label_indices = [(int(x),int(y)) for x,y in zip(data_model_indices, data_model_labels_indices)]
    in_dev_indices = [int(_) for _ in in_dev_indices]
    return data_model_example_label_indices + in_dev_indices

def selective_annotation(args,**kwargs):
    if args.selective_annotation_method=='random':
        train_examples = kwargs['train_examples']
        selected_indices = random.sample(range(len(train_examples)),args.annotation_size)
    elif args.selective_annotation_method=='diversity':
        embeddings = kwargs['embeddings']
        selected_indices = []
        first_id = random.choice(range(len(embeddings)))
        selected_indices.append(first_id)
        selected_representations = embeddings[first_id].reshape(1, -1)
        for count in range(args.annotation_size - 1):
            scores = np.sum(cosine_similarity(embeddings, selected_representations), axis=1)
            for i in selected_indices:
                scores[i] = float('inf')
            min_idx = np.argmin(scores)
            selected_representations = torch.cat((selected_representations,
                                                  embeddings[min_idx].reshape(1, -1)), 0)
            selected_indices.append(min_idx.item())
    elif args.selective_annotation_method=='fast_votek':
        selected_indices = fast_votek(embeddings=kwargs['embeddings'],select_num=args.annotation_size,k=150,
                                      vote_file=os.path.join(args.output_dir,'nearest_neighbors.json'))
    elif args.selective_annotation_method=='mfl':
        embeds = kwargs['embeddings']
        N, D = embeds.shape
        norm_embeds = embeds / embeds.norm(dim=1, keepdim=True)
        cosine = torch.einsum('nd,md->nm', norm_embeds, norm_embeds)
        selected = torch.zeros(N, dtype=torch.bool)
        max_similarity = torch.zeros(N) - 1
        for k in tqdm(range(args.annotation_size)):
            marginal_gain = torch.relu(cosine - max_similarity).sum(dim=1) * (1 - selected.float())
            node = torch.argmax(marginal_gain)
            selected[node] = True
            max_similarity = torch.max(max_similarity, cosine[node])
        selected_indices = torch.nonzero(selected).squeeze().tolist()
    elif args.selective_annotation_method=='kmeans':
        # Perform kmean clustering
        clustering_model = KMeans(n_clusters=args.annotation_size, init='k-means++').fit(np.array(kwargs['embeddings']))
        cluster_assignment = clustering_model.labels_
        cluster_centers = clustering_model.cluster_centers_

        clustered_indices = [[] for i in range(args.annotation_size)]
        for sentence_id, cluster_id in enumerate(cluster_assignment):
            clustered_indices[cluster_id].append(sentence_id)
        print(f'Kmeans Clustering Done...      Cluseter nums:{len(cluster_centers)} ...')
        selected_indices = [random.choice(clustered_indices[i]) for i in range(args.annotation_size)]
    elif args.selective_annotation_method=='perplexity':
        raw_sentences = [kwargs['format_example'](raw_example, label_map=kwargs['label_map'])[0] for raw_example in kwargs['train_examples']]
        tokenizer_gpt=kwargs['tokenizer_gpt']
        tokenizer_gpt.truncation_side='right'
        inference_model=kwargs['inference_model']
        if isinstance(inference_model, MetaICLModel):
            inference_model = inference_model.model
        else:
            inference_model = inference_model
        perplexities = []
        with torch.no_grad():
            for sentence in tqdm(raw_sentences, desc='Calculate perplexity for trainset', total=len(kwargs['train_examples'])):
                if 'llama' in args.model_cache_dir:
                    inputs = tokenizer_gpt.encode(sentence, add_special_tokens=True, return_tensors="pt", truncation=True, max_length=dpp_inference_model.config.max_position_embeddings).cuda()
                else:
                    inputs = tokenizer_gpt.encode(sentence, add_special_tokens=True, return_tensors="pt", truncation=True, max_length=dpp_inference_model.config.n_positions).cuda()
                outputs = inference_model(inputs, labels=inputs)
                loss = outputs.loss
                perplexity = torch.exp(loss)
                perplexities.append(perplexity.item())
        selected_indices = list(np.argsort(np.array(perplexities)))[:args.annotation_size]
        selected_indices = [int(_) for _ in selected_indices]
    elif args.selective_annotation_method in ['votek','least_confidence']:
        selected_indices = iterative_selection(train_embs=kwargs['embeddings'],
                                               test_embs=kwargs['embeddings'],
                                               train_examples=kwargs['train_examples'],
                                               test_examples=kwargs['train_examples'],
                                               return_string=kwargs['return_string'],
                                               format_example=kwargs['format_example'],
                                               maximum_input_len=kwargs['maximum_input_len'],
                                               label_map=kwargs['label_map'],
                                               single_context_example_len=kwargs['single_context_example_len'],
                                               inference_model=kwargs['inference_model'],
                                               inference_data_module=kwargs['inference_data_module'],
                                               tokenizer_gpt=kwargs['tokenizer_gpt'],
                                               args=args)
    elif args.selective_annotation_method == 'lm_dpp':
        selected_indices = lm_dpp_selection(embeddings=kwargs['embeddings'],
            train_examples=kwargs['train_examples'],
            select_num=args.annotation_size,
            raw_diversity_num=args.raw_diversity_num,
            inference_model=kwargs['inference_model'],
            label_map=kwargs['label_map'],
            tokenizer_gpt=kwargs['tokenizer_gpt'],
            format_example=kwargs['format_example'],
            args=args
        )
    elif args.selective_annotation_method == 'dpp':
        selected_indices = dpp_selection(embeddings=kwargs['embeddings'],
            train_examples=kwargs['train_examples'],
            select_num=args.annotation_size,
            raw_diversity_num=args.raw_diversity_num,
            inference_model=kwargs['inference_model'],
            label_map=kwargs['label_map'],
            tokenizer_gpt=kwargs['tokenizer_gpt'],
            format_example=kwargs['format_example'],
            args=args
        )
    elif args.selective_annotation_method == 'lm_dpp_in_domain_dev':
        selected_indices = lm_dpp_selection_in_domain_dev(embeddings=kwargs['embeddings'],
            train_examples=kwargs['train_examples'],
            select_num=args.annotation_size,
            raw_diversity_num=args.raw_diversity_num,
            inference_model=kwargs['inference_model'],
            label_map=kwargs['label_map'],
            tokenizer_gpt=kwargs['tokenizer_gpt'],
            format_example=kwargs['format_example'],
            inference_data_module=kwargs['inference_data_module'],
            args=args
        )
    else:
        raise ValueError(f'The selective annotation method {args.selective_annotation_method} is not supported')
    return selected_indices

def get_instance_length(input_text,output_text,tokenizer):
    """Return ``(input_token_count, output_token_count)`` for one instance.

    ``tokenizer`` is called on each text, and the length of the resulting
    ``'input_ids'`` entry is taken. The input is tokenized before the output.
    NOTE(review): this re-defines an identical function near the top of the
    module; the later definition wins at import time.
    """
    def _token_count(text):
        return len(tokenizer(text)['input_ids'])

    input_len = _token_count(input_text)
    output_len = _token_count(output_text)
    return input_len, output_len

