import os
import json
import random
from tqdm import tqdm
from datasets import load_dataset
import datasets
import pandas as pd

# Event-type vocabulary for the 'maven' task: integer id -> event type name.
# Keys run 0..167 in insertion order, so an id equals the name's position in
# ``maven_label_map.values()``; pre_process_maven maps names to ids with
# ``list(maven_label_map.values()).index(name)``.
maven_label_map = {0: 'Achieve', 1: 'Action', 2: 'Adducing', 3: 'Agree_or_refuse_to_act', 4: 'Aiming', 5: 'Arranging', 6: 'Arrest', 7: 'Arriving', 8: 'Assistance', 9: 'Attack', 10: 'Award', 11: 'Bearing_arms', 12: 'Becoming', 13: 'Becoming_a_member', 14: 'Being_in_operation', 15: 'Besieging', 16: 'Bodily_harm', 17: 'Body_movement', 18: 'Breathing', 19: 'Bringing', 20: 'Building', 21: 'Carry_goods', 22: 'Catastrophe', 23: 'Causation', 24: 'Cause_change_of_position_on_a_scale', 25: 'Cause_change_of_strength', 26: 'Cause_to_amalgamate', 27: 'Cause_to_be_included', 28: 'Cause_to_make_progress', 29: 'Change', 30: 'Change_event_time', 31: 'Change_of_leadership', 32: 'Change_sentiment', 33: 'Change_tool', 34: 'Check', 35: 'Choosing', 36: 'Collaboration', 37: 'Come_together', 38: 'Coming_to_be', 39: 'Coming_to_believe', 40: 'Commerce_buy', 41: 'Commerce_pay', 42: 'Commerce_sell', 43: 'Commitment', 44: 'Committing_crime', 45: 'Communication', 46: 'Competition', 47: 'Confronting_problem', 48: 'Connect', 49: 'Conquering', 50: 'Containing', 51: 'Control', 52: 'Convincing', 53: 'Cost', 54: 'Create_artwork', 55: 'Creating', 56: 'Criminal_investigation', 57: 'Cure', 58: 'Damaging', 59: 'Death', 60: 'Deciding', 61: 'Defending', 62: 'Departing', 63: 'Destroying', 64: 'Dispersal', 65: 'Earnings_and_losses', 66: 'Education_teaching', 67: 'Emergency', 68: 'Employment', 69: 'Emptying', 70: 'Escaping', 71: 'Exchange', 72: 'Expansion', 73: 'Expend_resource', 74: 'Expressing_publicly', 75: 'Extradition', 76: 'Filling', 77: 'Forming_relationships', 78: 'GetReady', 79: 'Getting', 80: 'GiveUp', 81: 'Giving', 82: 'Having_or_lacking_access', 83: 'Hiding_objects', 84: 'Hindering', 85: 'Hold', 86: 'Hostile_encounter', 87: 'Imposing_obligation', 88: 'Incident', 89: 'Influence', 90: 'Ingestion', 91: 'Institutionalization', 92: 'Judgment_communication', 93: 'Justifying', 94: 'Kidnapping', 95: 'Killing', 96: 'Know', 97: 'Labeling', 98: 'Legality', 99: 'Legal_rulings', 100: 'Lighting', 101: 
'Limiting', 102: 'Manufacturing', 103: 'Military_operation', 104: 'Motion', 105: 'Motion_directional', 106: 'Name_conferral', 107: 'Openness', 108: 'Participation', 109: 'Patrolling', 110: 'Perception_active', 111: 'Placing', 112: 'Practice', 113: 'Presence', 114: 'Preserving', 115: 'Preventing_or_letting', 116: 'Prison', 117: 'Process_end', 118: 'Process_start', 119: 'Protest', 120: 'Publishing', 121: 'Quarreling', 122: 'Ratification', 123: 'Receiving', 124: 'Recording', 125: 'Recovering', 126: 'Reforming_a_system', 127: 'Releasing', 128: 'Removing', 129: 'Renting', 130: 'Reporting', 131: 'Request', 132: 'Rescuing', 133: 'Research', 134: 'Resolve_problem', 135: 'Response', 136: 'Reveal_secret', 137: 'Revenge', 138: 'Rewards_and_punishments', 139: 'Risk', 140: 'Rite', 141: 'Robbery', 142: 'Scouring', 143: 'Scrutiny', 144: 'Self_motion', 145: 'Sending', 146: 'Sign_agreement', 147: 'Social_event', 148: 'Statement', 149: 'Submitting_documents', 150: 'Supply', 151: 'Supporting', 152: 'Surrendering', 153: 'Surrounding', 154: 'Suspicion', 155: 'Telling', 156: 'Temporary_stay', 157: 'Terrorism', 158: 'Testing', 159: 'Theft', 160: 'Traveling', 161: 'Use_firearm', 162: 'Using', 163: 'Violence', 164: 'Vocalizations', 165: 'Warning', 166: 'Wearing', 167: 'Writing'}
# Intent vocabulary for the 'clinc' task: ids 0..150 (151 classes, including
# the 'oos' entry at 132). Ids equal positions in ``.values()``, which is how
# pre_process_clinc looks them up.
clinc_label_map = {0: 'pto_request_status', 1: 'schedule_meeting', 2: 'what_song', 3: 'plug_type', 4: 'do_you_have_pets', 5: 'reminder_update', 6: 'ingredient_substitution', 7: 'bill_due', 8: 'credit_limit_change', 9: 'vaccines', 10: 'spelling', 11: 'routing', 12: 'interest_rate', 13: 'are_you_a_bot', 14: 'make_call', 15: 'timer', 16: 'whisper_mode', 17: 'international_fees', 18: 'apr', 19: 'todo_list_update', 20: 'taxes', 21: 'pay_bill', 22: 'calendar_update', 23: 'last_maintenance', 24: 'change_speed', 25: 'shopping_list', 26: 'thank_you', 27: 'damaged_card', 28: 'change_language', 29: 'meaning_of_life', 30: 'calendar', 31: 'change_ai_name', 32: 'nutrition_info', 33: 'current_location', 34: 'ingredients_list', 35: 'sync_device', 36: 'repeat', 37: 'measurement_conversion', 38: 'how_old_are_you', 39: 'mpg', 40: 'definition', 41: 'direct_deposit', 42: 'balance', 43: 'new_card', 44: 'income', 45: 'translate', 46: 'travel_alert', 47: 'pto_request', 48: 'what_can_i_ask_you', 49: 'report_lost_card', 50: 'goodbye', 51: 'accept_reservations', 52: 'credit_limit', 53: 'meeting_schedule', 54: 'change_user_name', 55: 'fun_fact', 56: 'date', 57: 'alarm', 58: 'pin_change', 59: 'restaurant_reservation', 60: 'improve_credit_score', 61: 'weather', 62: 'text', 63: 'directions', 64: 'shopping_list_update', 65: 'lost_luggage', 66: 'oil_change_when', 67: 'book_hotel', 68: 'confirm_reservation', 69: 'next_holiday', 70: 'restaurant_reviews', 71: 'uber', 72: 'gas', 73: 'reminder', 74: 'oil_change_how', 75: 'international_visa', 76: 'pto_balance', 77: 'transfer', 78: 'book_flight', 79: 'cancel', 80: 'restaurant_suggestion', 81: 'tell_joke', 82: 'credit_score', 83: 'find_phone', 84: 'pto_used', 85: 'todo_list', 86: 'w2', 87: 'travel_suggestion', 88: 'schedule_maintenance', 89: 'change_accent', 90: 'distance', 91: 'spending_history', 92: 'cancel_reservation', 93: 'yes', 94: 'order', 95: 'traffic', 96: 'payday', 97: 'play_music', 98: 'expiration_date', 99: 'rewards_balance', 100: 
'flight_status', 101: 'update_playlist', 102: 'food_last', 103: 'how_busy', 104: 'insurance_change', 105: 'report_fraud', 106: 'cook_time', 107: 'what_is_your_name', 108: 'recipe', 109: 'travel_notification', 110: 'reset_settings', 111: 'meal_suggestion', 112: 'car_rental', 113: 'tire_pressure', 114: 'who_made_you', 115: 'timezone', 116: 'next_song', 117: 'exchange_rate', 118: 'maybe', 119: 'no', 120: 'smart_home', 121: 'freeze_account', 122: 'calories', 123: 'who_do_you_work_for', 124: 'card_declined', 125: 'where_are_you_from', 126: 'calculator', 127: 'change_volume', 128: 'order_status', 129: 'rollover_401k', 130: 'min_payment', 131: 'redeem_rewards', 132: 'oos', 133: 'transactions', 134: 'account_blocked', 135: 'jump_start', 136: 'bill_balance', 137: 'share_location', 138: 'flip_coin', 139: 'time', 140: 'carry_on', 141: 'user_name', 142: 'roll_dice', 143: 'what_are_your_hobbies', 144: 'tire_change', 145: 'application_status', 146: 'greeting', 147: 'order_checks', 148: 'gas_type', 149: 'insurance', 150: 'replacement_card_duration'}
# Relation vocabulary for the 'semeval' task: ids 0..8, looked up by position
# in pre_process_semeval. There is no 'Other' entry, so an example whose gold
# name is 'Other' would make ``list(...).index`` raise ValueError.
semeval_label_map = {0: 'Cause-Effect', 1: 'Component-Whole', 2: 'Content-Container', 3: 'Entity-Destination', 4: 'Entity-Origin', 5: 'Instrument-Agency', 6: 'Member-Collection', 7: 'Message-Topic', 8: 'Product-Producer'}


def read_json(path: str):
    """Read a JSON Lines file and return its records as a list.

    Blank (whitespace-only) lines are ignored. Object records whose ``label``
    field is an empty list are dropped; records without a ``label`` key, or
    with a non-list label, are kept.

    Args:
        path: Path to a UTF-8 encoded file containing one JSON value per line.

    Returns:
        The parsed records, in file order.
    """
    examples = []
    with open(path, 'r', encoding='utf-8') as fin:
        for raw_line in fin:
            line = raw_line.strip()
            if not line:
                continue
            record = json.loads(line)
            label = record.get('label') if isinstance(record, dict) else None
            # An empty label list marks an unlabeled example; skip it.
            if isinstance(label, list) and not label:
                continue
            examples.append(record)
    return examples

def pre_process_maven(example):
    """Normalize a MAVEN example in place and return it.

    Removes the prompt prefix ``'what is the event type of '`` and every
    ``'?'`` from ``example['text']``. Sets ``example['retrieve_text']`` to the
    stripped part of the cleaned text before the first ``' in '``, which is
    presumably the trigger phrase (the whole text if there is no ``' in '``).
    Sets ``example['label']`` to the id of
    ``example['labels']['standard'][0]['name']`` in ``maven_label_map``.

    Raises:
        ValueError: if the gold event name is not in ``maven_label_map``.
    """
    text = example['text'].replace('what is the event type of ', '').replace('?', '')
    example['text'] = text
    # Split on the word " in ", not the bare substring "in". The substring
    # also matches inside words ("killing" -> "kill") and truncated the span.
    head, _, _ = text.partition(' in ')
    example['retrieve_text'] = head.strip()

    gold_name = example['labels']['standard'][0]['name']
    example['label'] = list(maven_label_map.values()).index(gold_name)
    return example

def pre_process_clinc(example):
    """Attach the integer CLINC intent id to ``example`` and return it.

    The gold intent name is read from ``example['labels']['151class'][0]['name']``.
    Its id is its position among the values of ``clinc_label_map``. Raises
    ValueError if the name is not in the map.
    """
    intent_names = list(clinc_label_map.values())
    gold_intent = example['labels']['151class'][0]['name']
    example['label'] = intent_names.index(gold_intent)
    return example

def pre_process_semeval(example):
    """Normalize a SemEval relation example in place and return it.

    Removes every ``'?'`` from ``example['text']`` and rewrites ``'What is the'``
    to ``'The'``. The two entities are parsed from the trailing
    ``'The relation between <head> and <tail>'`` clause.
    ``example['retrieve_text']`` is set to the span of the text running from
    whichever entity occurs first (by first occurrence) to the end of
    whichever occurs last. ``example['label']`` is set to the id of
    ``example['labels']['standard'][0]['name']`` in ``semeval_label_map``.

    Raises:
        IndexError: if the text has no ``'The relation between'`` clause.
        ValueError: if the clause does not split into exactly two entities,
            an entity is not found in the text, or the relation name is unknown.
    """
    text = example['text'].replace('?', '').replace('What is the', 'The')
    example['text'] = text
    # The question clause is at the end, so split on its last occurrence.
    entity_clause = text.rsplit('The relation between', 1)[1]
    head_entity, tail_entity = entity_clause.strip().split(' and ')

    head_start = text.index(head_entity)
    tail_start = text.index(tail_entity)
    # The entities may appear in either order in the sentence. Slicing
    # head..tail gave an empty string whenever the tail preceded the head.
    span_start = min(head_start, tail_start)
    span_end = max(head_start + len(head_entity), tail_start + len(tail_entity))
    example['retrieve_text'] = text[span_start:span_end]

    gold_name = example['labels']['standard'][0]['name']
    example['label'] = list(semeval_label_map.values()).index(gold_name)
    return example

def load_customize_dataset(path: str, args=None):
    """Load a JSON Lines task file into a ``datasets.Dataset``.

    Records are read with ``read_json`` and converted through a pandas
    DataFrame. They are then passed through the task-specific ``pre_process_*``
    function via ``Dataset.map``, after which bookkeeping columns are removed.

    Args:
        path: Path to the JSON Lines file.
        args: Object with a ``task_name`` attribute: one of ``'maven'``,
            ``'clinc'``, ``'semeval'``, ``'conllner'`` or ``'fewnerd'``.
            Required; the ``None`` default is kept only for signature
            compatibility.

    Returns:
        The processed ``datasets.Dataset``.

    Raises:
        ValueError: if ``args`` is ``None``.
        NotImplementedError: if ``args.task_name`` is not supported.
    """
    if args is None:
        raise ValueError('load_customize_dataset requires args with a task_name')
    task_name = args.task_name

    # Each preprocessor name is resolved only inside its own branch, so a
    # task never depends on another task's helper existing.
    # NOTE(review): pre_process_conllner / pre_process_fewnerd are not defined
    # in this part of the file; confirm they exist elsewhere.
    if task_name == 'maven':
        pre_process, dropped_columns = pre_process_maven, ['trace_id', 'delimiter', 'labels']
    elif task_name == 'clinc':
        pre_process, dropped_columns = pre_process_clinc, ['trace_id', 'delimiter', 'labels']
    elif task_name == 'semeval':
        pre_process, dropped_columns = pre_process_semeval, ['trace_id', 'delimiter', 'labels']
    elif task_name == 'conllner':
        pre_process, dropped_columns = pre_process_conllner, ['trace_id']
    elif task_name == 'fewnerd':
        pre_process, dropped_columns = pre_process_fewnerd, ['trace_id']
    else:
        raise NotImplementedError(f'unsupported task_name: {task_name!r}')

    examples = read_json(path)
    dataset = datasets.Dataset.from_pandas(pd.DataFrame(data=examples))
    return dataset.map(pre_process).remove_columns(dropped_columns)

def format_nq_dataset(sample):
    """Flatten one Natural Questions record into question/answer targets.

    Args:
        sample: Record with ``question.text``, ``document.tokens`` (parallel
            ``token`` / ``is_html`` lists) and ``annotations`` holding
            ``long_answer`` (spans with ``start_token`` / ``end_token``),
            ``short_answers`` (items with a ``text`` list) and
            ``yes_no_answer``.

    Returns:
        A dict with these keys:

        - ``question``: the question text.
        - ``short``: de-duplicated short-answer strings in first-seen order.
        - ``long``: de-duplicated long-answer passages with HTML tokens removed.
        - ``category``: ``'yes'``/``'no'`` if any annotation has a yes/no
          answer (the first one wins, and ``short`` then holds just that
          word); otherwise ``'other'`` when a short answer exists, else
          ``'null'``.
    """
    question = sample['question']['text']
    tokens = sample['document']['tokens']['token']
    is_html = sample['document']['tokens']['is_html']
    annotations = sample['annotations']

    # yes_no_answer: 0 -> "no", 1 -> "yes"; other values are ignored.
    for answer in annotations['yes_no_answer']:
        if answer in (0, 1):
            verdict = "no" if answer == 0 else "yes"
            return {"question": question, "short": [verdict], "long": [], "category": verdict}

    # dict.fromkeys de-duplicates while keeping first-seen order. The old
    # list(set(...)) produced an order that changed between interpreter runs
    # (string hash randomization), making cached outputs non-reproducible.
    short_targets = list(dict.fromkeys(
        text for short_answer in annotations['short_answers'] for text in short_answer['text']
    ))

    long_targets = []
    for long_answer in annotations['long_answer']:
        start, end = long_answer['start_token'], long_answer['end_token']
        if start == -1:  # span without a long answer
            continue
        passage = " ".join(
            token for token, html in zip(tokens[start:end], is_html[start:end]) if not html
        )
        if passage not in long_targets:
            long_targets.append(passage)

    category = "other" if short_targets else "null"
    return {"question": question, "short": short_targets, "long": long_targets, "category": category}

def process_mnli_examples(examples):
    """Convert raw MNLI rows into ``id``/``label``/``premise``/``hypothesis`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {
            'id': idx,
            'label': raw_data['label'],
            'premise': raw_data['premise'],
            'hypothesis': raw_data['hypothesis'],
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process mnli examples'))
    ]

def process_qnli_examples(examples):
    """Convert raw QNLI rows into ``id``/``label``/``premise``/``hypothesis`` dicts.

    ``premise`` is taken from ``text2`` and ``hypothesis`` from ``text1``.
    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {
            'id': idx,
            'label': raw_data['label'],
            'premise': raw_data['text2'],
            'hypothesis': raw_data['text1'],
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process qnli examples'))
    ]

def process_subj_examples(examples):
    """Convert raw SUBJ rows into ``id``/``label``/``text`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {'id': idx, 'label': raw_data['label'], 'text': raw_data['text']}
        for idx, raw_data in enumerate(tqdm(examples, desc='process subj examples'))
    ]

def process_scicite_examples(examples):
    """Convert raw SciCite rows into ``id``/``label``/``text`` dicts.

    ``text`` comes from the row's ``string`` field. ``id`` is the row's
    position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {'id': idx, 'label': raw_data['label'], 'text': raw_data['string']}
        for idx, raw_data in enumerate(tqdm(examples, desc='process scicite examples'))
    ]

def process_maven_examples(examples):
    """Convert preprocessed MAVEN rows into ``id``/``label``/``text`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {'id': idx, 'label': raw_data['label'], 'text': raw_data['text']}
        for idx, raw_data in enumerate(tqdm(examples, desc='process maven examples'))
    ]

def process_clinc_examples(examples):
    """Convert preprocessed CLINC rows into ``id``/``label``/``text`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {'id': idx, 'label': raw_data['label'], 'text': raw_data['text']}
        for idx, raw_data in enumerate(tqdm(examples, desc='process clinc examples'))
    ]

def process_semeval_examples(examples):
    """Convert preprocessed SemEval rows into ``id``/``label``/``text`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {'id': idx, 'label': raw_data['label'], 'text': raw_data['text']}
        for idx, raw_data in enumerate(tqdm(examples, desc='process semeval examples'))
    ]

def process_rte_examples(examples):
    """Convert raw RTE rows into ``id``/``label``/``sentence1``/``sentence2`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {
            'id': idx,
            'label': raw_data['label'],
            'sentence1': raw_data['sentence1'],
            'sentence2': raw_data['sentence2'],
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process rte examples'))
    ]

def process_sst5_examples(examples):
    """Convert raw SST-5 rows into ``id``/``label``/``text`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {'id': idx, 'label': raw_data['label'], 'text': raw_data['text']}
        for idx, raw_data in enumerate(tqdm(examples, desc='process sst5 examples'))
    ]

def process_yelp_examples(examples):
    """Convert raw Yelp rows into ``id``/``label``/``text`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {'id': idx, 'label': raw_data['label'], 'text': raw_data['text']}
        for idx, raw_data in enumerate(tqdm(examples, desc='process yelp examples'))
    ]

def process_sst2_examples(examples):
    """Convert raw SST-2 rows into ``id``/``label``/``sentence`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {'id': idx, 'label': raw_data['label'], 'sentence': raw_data['sentence']}
        for idx, raw_data in enumerate(tqdm(examples, desc='process sst2 examples'))
    ]

def process_mrpc_examples(examples):
    """Convert raw MRPC rows into ``id``/``label``/``sentence1``/``sentence2`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {
            'id': idx,
            'label': raw_data['label'],
            'sentence1': raw_data['sentence1'],
            'sentence2': raw_data['sentence2'],
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process mrpc examples'))
    ]

def process_boolq_examples(examples):
    """Convert raw BoolQ rows into ``id``/``label``/``question``/``passage`` dicts.

    ``label`` is 0 when ``answer`` is truthy and 1 otherwise. ``id`` is the
    row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {
            'id': idx,
            'label': 0 if raw_data['answer'] else 1,
            'question': raw_data['question'],
            'passage': raw_data['passage'],
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process boolq examples'))
    ]

def process_snli_examples(examples):
    """Convert raw SNLI rows into ``id``/``label``/``premise``/``hypothesis`` dicts.

    Rows with ``label == -1`` are dropped. ``id`` numbers the kept rows
    consecutively from 0. Progress (over all input rows) is shown with tqdm.
    """
    labelled_rows = (
        raw_data
        for raw_data in tqdm(examples, desc='process snli examples')
        if raw_data['label'] != -1
    )
    return [
        {
            'id': idx,
            'label': raw_data['label'],
            'premise': raw_data['premise'],
            'hypothesis': raw_data['hypothesis'],
        }
        for idx, raw_data in enumerate(labelled_rows)
    ]

def process_dbpedia_examples(examples):
    """Convert raw DBpedia-14 rows into ``id``/``label``/``title``/``content`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {
            'id': idx,
            'label': raw_data['label'],
            'title': raw_data['title'],
            'content': raw_data['content'],
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process dbpedia_14 examples'))
    ]

def process_trec_examples(examples):
    """Convert raw TREC rows into ``id``/``label``/``text`` dicts.

    ``label`` comes from the row's ``coarse_label`` field. ``id`` is the row's
    position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {'id': idx, 'label': raw_data['coarse_label'], 'text': raw_data['text']}
        for idx, raw_data in enumerate(tqdm(examples, desc='process trec examples'))
    ]

def process_ag_news_examples(examples):
    """Convert raw AG News rows into ``id``/``label``/``text`` dicts.

    ``id`` is the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {'id': idx, 'label': raw_data['label'], 'text': raw_data['text']}
        for idx, raw_data in enumerate(tqdm(examples, desc='process ag_news examples'))
    ]

def process_hellaswag_examples(examples):
    """Convert raw HellaSwag rows into multiple-choice dicts.

    Each output keeps ``ctx_a``, ``ctx_b``, ``ctx``, ``endings`` and
    ``activity_label`` as given, and casts ``label`` to ``int``. ``id`` is the
    row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {
            'id': idx,
            'ctx_a': raw_data['ctx_a'],
            'ctx_b': raw_data['ctx_b'],
            'ctx': raw_data['ctx'],
            'endings': raw_data['endings'],
            'label': int(raw_data['label']),
            'activity_label': raw_data['activity_label'],
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process hellaswag examples'))
    ]

def process_copa_examples(examples):
    """Convert raw COPA rows into two-choice dicts.

    ``ctx`` comes from ``p``, ``ask`` from ``asks-for``, and ``endings`` is
    ``[a1, a2]``. ``label`` is ``most-plausible-alternative`` converted from
    1-based to 0-based. ``id`` is the row's position in ``examples``. Progress
    is shown with tqdm.
    """
    return [
        {
            'id': idx,
            'ctx': raw_data['p'],
            'ask': raw_data['asks-for'],
            'endings': [raw_data['a1'], raw_data['a2']],
            'label': int(raw_data['most-plausible-alternative']) - 1,
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process copa examples'))
    ]

def process_cosmos_qa_examples(examples):
    """Convert raw Cosmos QA rows into four-choice dicts.

    ``endings`` is ``[answer0, answer1, answer2, answer3]`` and ``label`` is
    cast to ``int``. ``id`` is the row's position in ``examples``. Progress is
    shown with tqdm.
    """
    return [
        {
            'id': idx,
            'question': raw_data['question'],
            'context': raw_data['context'],
            'endings': [raw_data['answer0'], raw_data['answer1'], raw_data['answer2'], raw_data['answer3']],
            'label': int(raw_data['label']),
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process cosmos_qa examples'))
    ]

def process_piqa_examples(examples):
    """Convert raw PIQA rows into two-choice dicts.

    ``question`` comes from ``goal``, and ``endings`` is the whitespace-stripped
    ``[sol1, sol2]``. ``label`` is cast to ``int``. ``id`` is the row's
    position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {
            'id': idx,
            'question': raw_data['goal'],
            'endings': [raw_data['sol1'].strip(), raw_data['sol2'].strip()],
            'label': int(raw_data['label']),
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process piqa examples'))
    ]

def process_commonsense_qa_examples(examples):
    """Convert raw CommonsenseQA rows into five-choice dicts.

    ``endings`` is ``choices['text']`` and ``label`` is the index of
    ``answerKey`` in ``'A'..'E'``; any other key raises ValueError. ``id`` is
    the row's position in ``examples``. Progress is shown with tqdm.
    """
    return [
        {
            'id': idx,
            'question': raw_data['question'],
            'concept': raw_data['question_concept'],
            'endings': raw_data['choices']['text'],
            'label': ['A', 'B', 'C', 'D', 'E'].index(raw_data['answerKey']),
        }
        for idx, raw_data in enumerate(tqdm(examples, desc='process commonsense_qa examples'))
    ]

def process_xsum_examples(examples):
    """Convert XSum rows into ``id``/``document``/``summary``/``label`` dicts.

    The reference summary is stored under both ``summary`` and ``label``.
    ``id`` is the row's position in ``examples``.
    """
    return [
        {
            'id': position,
            'document': row["document"],
            'summary': row["summary"],
            'label': row["summary"],
        }
        for position, row in enumerate(examples)
    ]

def process_nq_examples(examples):
    """Convert formatted NQ rows into training dicts.

    Each output carries ``question``, ``category`` and ``long`` through
    unchanged. The row's ``short`` list is exposed as both ``short_targets``
    and ``label`` (the same list object). ``id`` is the row's position in
    ``examples``.
    """
    return [
        {
            'id': position,
            'question': row['question'],
            'short_targets': row['short'],
            'category': row['category'],
            'long': row['long'],
            'label': row['short'],
        }
        for position, row in enumerate(examples)
    ]

def get_task(args):
    task_name = args.task_name
    data_cache_dir = args.data_cache_dir
    if task_name=='mnli':
        if os.path.isfile(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json')) and \
            os.path.isfile(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            mnli_datasets = load_dataset(path=os.path.join(data_cache_dir, 'glue', 'mnli'))
            total_train_examples = [e for e in mnli_datasets['train']]
            total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_mnli_examples(total_train_examples)
            total_eval_examples = [e for e in mnli_datasets['validation']]
            total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_mnli_examples(total_eval_examples)
            with open(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json'),'w') as f:
                json.dump(total_train_examples,f,indent=4)
            with open(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json'),'w') as f:
                json.dump(total_eval_examples,f,indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example,label_map,**kwargs):
            return f"{example['premise']}. Based on that information, is the claim {example['hypothesis']} \"True\", " \
               f"\"False\", or \"Inconclusive\"?\nanswer:", f"{label_map[example['pseudo_label']]}" if 'pseudo_label' in example else f"{label_map[example['label']]}"

        all_train_text_to_encode = ["{}. Based on that information, is the claim {} \"True\", \"False\", or \"Inconclusive\"?"
                                        .format(raw_item["premise"], raw_item["hypothesis"]) for raw_item in total_train_examples]
        all_eval_text_to_encode = ["{}. Based on that information, is the claim {} \"True\", \"False\", or \"Inconclusive\"?"
                                        .format(raw_item["premise"], raw_item["hypothesis"]) for raw_item in total_eval_examples]
        label_map = {0:"True",1:"Inconclusive",2:"False"}
    elif task_name=='snli':
        if os.path.isfile(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json')) and \
            os.path.isfile(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            snli_datasets = load_dataset(path=os.path.join(data_cache_dir, 'snli'))
            total_train_examples = [e for e in snli_datasets['train']]
            total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_snli_examples(total_train_examples)
            total_eval_examples = [e for e in snli_datasets['validation']]
            total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_snli_examples(total_eval_examples)
            with open(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json'),'w') as f:
                json.dump(total_train_examples,f,indent=4)
            with open(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json'),'w') as f:
                json.dump(total_eval_examples,f,indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example,label_map,**kwargs):
            return f"{example['premise']}. Based on that information, is the claim {example['hypothesis']} \"entailment\", " \
               f"\"neutral\", or \"contradiction\"?\nanswer:", f"{label_map[example['pseudo_label']]}" if 'pseudo_label' in example else f"{label_map[example['label']]}"

        all_train_text_to_encode = ["{}. Based on that information, is the claim {} \"entailment\", \"neutral\", or \"contradiction\"?"
                                        .format(raw_item["premise"], raw_item["hypothesis"]) for raw_item in total_train_examples]
        all_eval_text_to_encode = ["{}. Based on that information, is the claim {} \"entailment\", \"neutral\", or \"contradiction\"?"
                                        .format(raw_item["premise"], raw_item["hypothesis"]) for raw_item in total_eval_examples]
        label_map = {0:"entailment",1:"neutral",2:"contradiction"}
    elif task_name=='qnli':
        if os.path.isfile(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json')) and \
            os.path.isfile(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            qnli_datasets = load_dataset(path=os.path.join(data_cache_dir, 'qnli'))
            total_train_examples = [e for e in qnli_datasets['train']]
            total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_qnli_examples(total_train_examples)
            total_eval_examples = [e for e in qnli_datasets['validation']]
            total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_qnli_examples(total_eval_examples)
            with open(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json'),'w') as f:
                json.dump(total_train_examples,f,indent=4)
            with open(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json'),'w') as f:
                json.dump(total_eval_examples,f,indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example,label_map,**kwargs):
            return f"{example['premise']}. Based on that information, is the claim {example['hypothesis']} \"entailment\", " \
               f"or \"contradiction\"?\nanswer:", f"{label_map[example['pseudo_label']]}" if 'pseudo_label' in example else f"{label_map[example['label']]}"

        all_train_text_to_encode = ["{}. Based on that information, is the claim {} \"entailment\", or \"contradiction\"?"
                                        .format(raw_item["premise"], raw_item["hypothesis"]) for raw_item in total_train_examples]
        all_eval_text_to_encode = ["{}. Based on that information, is the claim {} \"entailment\", or \"contradiction\"?"
                                        .format(raw_item["premise"], raw_item["hypothesis"]) for raw_item in total_eval_examples]
        label_map = {0:"entailment",1:"contradiction"}
    elif task_name=='boolq':
        if os.path.isfile(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json')) and \
            os.path.isfile(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            boolq_datasets = load_dataset(path=os.path.join(data_cache_dir, 'boolq'))
            total_train_examples = [e for e in boolq_datasets['train']]
            # total_train_examples = random.sample(total_train_examples, 10000) # None
            total_train_examples = process_boolq_examples(total_train_examples)
            total_eval_examples = [e for e in boolq_datasets['validation']]
            # total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_boolq_examples(total_eval_examples)
            with open(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json'),'w') as f:
                json.dump(total_train_examples,f,indent=4)
            with open(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json'),'w') as f:
                json.dump(total_eval_examples,f,indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example,label_map,**kwargs):
            return f"{example['passage']}. Based on that information, is the claim {example['question']} \"True\", " \
               f"or \"False\"?\nanswer:", f"{label_map[example['pseudo_label']]}" if 'pseudo_label' in example else f"{label_map[example['label']]}"

        all_train_text_to_encode = ["{}. Based on that information, is the claim {} \"True\", or \"False\"?"
                                        .format(raw_item["passage"], raw_item["question"]) for raw_item in total_train_examples]
        all_eval_text_to_encode = ["{}. Based on that information, is the claim {} \"True\", or \"False\"?"
                                        .format(raw_item["passage"], raw_item["question"]) for raw_item in total_eval_examples]
        label_map = {0:"True",1:"False"}
    elif task_name=='sst2':
        if os.path.isfile(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json')) and \
            os.path.isfile(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            sst2_datasets = load_dataset(path=os.path.join(data_cache_dir, 'sst2'))
            total_train_examples = [e for e in sst2_datasets['train']]
            total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_sst2_examples(total_train_examples)
            total_eval_examples = [e for e in sst2_datasets['validation']]
            total_eval_examples = process_sst2_examples(total_eval_examples)
            with open(os.path.join(args.output_dir,f'train_examples_seed_{args.seed}.json'),'w') as f:
                json.dump(total_train_examples,f,indent=4)
            with open(os.path.join(args.output_dir,f'eval_examples_seed_{args.seed}.json'),'w') as f:
                json.dump(total_eval_examples,f,indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Return the ``(prompt, target)`` pair for a single SST-2 example.

            The prompt asks how the reader feels about ``example['sentence']``.
            The target is the entry of ``label_map`` selected by
            ``example['pseudo_label']`` when that key exists, otherwise by
            ``example['label']``. Extra keyword arguments are accepted and ignored.
            """
            prompt = f"How do you feel about the following sentence?\n{example['sentence']}\nanswer:"
            if 'pseudo_label' in example:
                class_id = example['pseudo_label']
            else:
                class_id = example['label']
            return prompt, f"{label_map[class_id]}"

        all_train_text_to_encode = [raw_item["sentence"] for raw_item in total_train_examples]
        all_eval_text_to_encode = [raw_item["sentence"] for raw_item in total_eval_examples]
        label_map = {0:"negative", 1:"positive"}
    elif task_name=='rte':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            rte_datasets = load_dataset(path=os.path.join(data_cache_dir, 'glue', 'rte'))
            total_train_examples = [e for e in rte_datasets['train']]
            # total_train_examples = random.sample(total_train_examples, 250) # delete
            total_train_examples = process_rte_examples(total_train_examples)
            total_eval_examples = [e for e in rte_datasets['validation']]
            # total_eval_examples = random.sample(total_eval_examples, 256)
            total_eval_examples = process_rte_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        # def format_example(example,label_map,**kwargs):
        #     return f"{example['sentence1']}.\nquestion: {example['sentence2']}. True or False?\nanswer:",\
        #            f"{label_map[example['pseudo_label']]}" if 'pseudo_label' in example else f"{label_map[example['label']]}"

        # all_train_text_to_encode = ["{}.\nquestion: {}".format(raw_item["sentence1"], raw_item["sentence2"])
        #                             for raw_item in total_train_examples]
        # all_eval_text_to_encode = ["{}.\nquestion: {}".format(raw_item["sentence1"], raw_item["sentence2"])
        #                             for raw_item in total_eval_examples]
        # label_map = {0:"True",1:"False"}
        def format_example(example, label_map, **kwargs):
            """Build the RTE ``(question, answer)`` pair for one premise/hypothesis example.

            ``sentence1`` is the premise and ``sentence2`` is the claim. The
            answer verbalizes ``pseudo_label`` if present, else the gold
            ``label``. Unused keyword arguments are ignored.
            """
            premise = example['sentence1']
            hypothesis = example['sentence2']
            question = (f"{premise}. Based on that information, is the claim {hypothesis} "
                        f"\"entailment\", or \"contradiction\"?\nanswer:")
            source_key = 'pseudo_label' if 'pseudo_label' in example else 'label'
            return question, f"{label_map[example[source_key]]}"

        all_train_text_to_encode = ["{}. Based on that information, is the claim {} \"entailment\", or \"contradiction\"?"
                                        .format(raw_item["sentence1"], raw_item["sentence2"]) for raw_item in total_train_examples]
        all_eval_text_to_encode = ["{}. Based on that information, is the claim {} \"entailment\", or \"contradiction\"?"
                                        .format(raw_item["sentence1"], raw_item["sentence2"]) for raw_item in total_eval_examples]
        label_map = {0:"entailment",1:"contradiction"}
    elif task_name=='sst5':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            sst5_datasets = load_dataset(path=os.path.join(data_cache_dir, 'sst5'))
            total_train_examples = [e for e in sst5_datasets['train']]
            # total_train_examples = random.sample(total_train_examples, 1000) # None
            total_train_examples = process_sst5_examples(total_train_examples)
            total_eval_examples = [e for e in sst5_datasets['test']]
            # total_eval_examples = random.sample(total_eval_examples, 256) # None
            total_eval_examples = process_sst5_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Turn one SST-5 record into a ``(prompt, verbalized_label)`` tuple.

            The sentence is read from ``example['text']``. ``pseudo_label`` is
            preferred over the gold ``label`` when the record carries one.
            Any keyword arguments are ignored.
            """
            sentence = example['text']
            prompt = f"How do you feel about the following sentence?\n{sentence}\nanswer:"
            label_field = 'pseudo_label' if 'pseudo_label' in example else 'label'
            verbalized = label_map[example[label_field]]
            return prompt, f"{verbalized}"

        all_train_text_to_encode = [raw_item["text"] for raw_item in total_train_examples]
        all_eval_text_to_encode = [raw_item["text"] for raw_item in total_eval_examples]
        label_map = {0:"very negative",1:"negative",2:"neutral",3:"positive",4:"very positive"}
    elif task_name=='yelp':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            yelp_datasets = load_dataset(path=os.path.join(data_cache_dir, 'yelp'))
            total_train_examples = [e for e in yelp_datasets['train']]
            total_train_examples = random.sample(total_train_examples, 10000) # None
            total_train_examples = process_yelp_examples(total_train_examples)
            total_eval_examples = [e for e in yelp_datasets['test']]
            total_eval_examples = random.sample(total_eval_examples, 4000) # None
            total_eval_examples = process_yelp_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Produce the ``(prompt, target)`` pair for one Yelp review.

            The review body comes from ``example['text']``. The target is the
            ``label_map`` entry selected by ``pseudo_label`` when present,
            otherwise by ``label``. Keyword arguments are accepted but unused.
            """
            review_prompt = f"How do you feel about the following sentence?\n{example['text']}\nanswer:"
            has_pseudo = 'pseudo_label' in example
            star_index = example['pseudo_label'] if has_pseudo else example['label']
            return review_prompt, f"{label_map[star_index]}"

        all_train_text_to_encode = [raw_item["text"] for raw_item in total_train_examples]
        all_eval_text_to_encode = [raw_item["text"] for raw_item in total_eval_examples]
        label_map = {0:"terrible",1:"bad",2:"okay",3:"good",4:"great"}
    elif task_name=='mrpc':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            mrpc_datasets = load_dataset(path=os.path.join(data_cache_dir, 'glue', 'mrpc'))
            total_train_examples = [e for e in mrpc_datasets['train']]
            # total_train_examples = random.sample(total_train_examples, 1000) # None
            total_train_examples = process_mrpc_examples(total_train_examples)
            total_eval_examples = [e for e in mrpc_datasets['validation']]
            # total_eval_examples = random.sample(total_eval_examples, 256) # None
            total_eval_examples = process_mrpc_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Format one MRPC sentence pair as a ``(prompt, target)`` tuple.

            Both sentences are printed on their own lines, each followed by a
            period, under a question asking whether they are equivalent. The
            target verbalizes ``pseudo_label`` if present, else ``label``.
            Keyword arguments are ignored.
            """
            first = example['sentence1']
            second = example['sentence2']
            prompt = ("Are the following two sentences 'equivalent' or 'not equivalent'?\n"
                      f"{first}.\n{second}.\nanswer:")
            if 'pseudo_label' not in example:
                return prompt, f"{label_map[example['label']]}"
            return prompt, f"{label_map[example['pseudo_label']]}"

        all_train_text_to_encode = ["{}.\n{}".format(raw_item["sentence1"], raw_item["sentence2"])
                                    for raw_item in total_train_examples]
        all_eval_text_to_encode = ["{}.\n{}".format(raw_item["sentence1"], raw_item["sentence2"])
                                   for raw_item in total_eval_examples]
        label_map = {0:"not equivalent",1:"equivalent"}
    elif task_name=='dbpedia_14':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            dbpedia_datasets = load_dataset(path=os.path.join(data_cache_dir, 'd_bpedia14'))
            total_train_examples = [e for e in dbpedia_datasets['train']]
            total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_dbpedia_examples(total_train_examples)
            total_eval_examples = [e for e in dbpedia_datasets['test']]
            total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_dbpedia_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Return ``(prompt, target)`` for one DBpedia-14 article.

            The prompt is the article title and content. There is no trailing
            answer cue, and the spacing around ';' differs from the text used
            for encoding elsewhere in this branch. The target is the class
            name for ``pseudo_label`` when present, else for ``label``.
            Keyword arguments are ignored.
            """
            title, content = example['title'], example['content']
            prompt = f"title: {title}; content: {content}"
            chosen = example['pseudo_label'] if 'pseudo_label' in example else example['label']
            return prompt, f"{label_map[chosen]}"

        all_train_text_to_encode = ["title: {} ; content: {}".format(raw_item["title"], raw_item["content"])
                                    for raw_item in total_train_examples]
        all_eval_text_to_encode = ["title: {} ; content: {}".format(raw_item["title"], raw_item["content"])
                                   for raw_item in total_eval_examples]
        label_map = {0: "company",1: "educational institution",2: "artist",3: "athlete",4: "office holder",
            5: "mean of transportation",6: "building",7: "natural place",8: "village",9: "animal",10: "plant",
            11: "album",12: "film",13: "written work"}
    elif task_name=='ag_news':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            ag_news_datasets = load_dataset(path=os.path.join(data_cache_dir, 'ag_news'))
            total_train_examples = [e for e in ag_news_datasets['train']]
            total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_ag_news_examples(total_train_examples)
            total_eval_examples = [e for e in ag_news_datasets['test']]
            total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_ag_news_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Build the AG News ``(prompt, target)`` pair for one article.

            The prompt is ``"content: "`` followed by ``example['text']``. The
            target is the topic name chosen by ``pseudo_label`` if present,
            otherwise by ``label``. Extra keyword arguments are ignored.
            """
            body = f"content: {example['text']}"
            key = 'pseudo_label' if 'pseudo_label' in example else 'label'
            topic = label_map[example[key]]
            return body, f"{topic}"

        all_train_text_to_encode = ["content: {}".format(raw_item["text"])
                                    for raw_item in total_train_examples]
        all_eval_text_to_encode = ["content: {}".format(raw_item["text"])
                                   for raw_item in total_eval_examples]
        label_map = {0: "World",1: "Sports",2: "Business",3: "Sci/Tech"}
    elif task_name=='trec':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            trec_datasets = load_dataset(path=os.path.join(data_cache_dir, 'trec'))
            total_train_examples = [e for e in trec_datasets['train']]
            # total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_trec_examples(total_train_examples)
            total_eval_examples = [e for e in trec_datasets['test']]
            # total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_trec_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Map one TREC question to a ``(prompt, target)`` tuple.

            The prompt prefixes ``example['text']`` with ``"content: "``. The
            target is the ``label_map`` entry for ``pseudo_label`` when the key
            exists, else for ``label``. Keyword arguments are unused.
            """
            question_prompt = f"content: {example['text']}"
            if 'pseudo_label' in example:
                target = f"{label_map[example['pseudo_label']]}"
            else:
                target = f"{label_map[example['label']]}"
            return question_prompt, target

        all_train_text_to_encode = ["content: {}".format(raw_item["text"])
                                    for raw_item in total_train_examples]
        all_eval_text_to_encode = ["content: {}".format(raw_item["text"])
                                   for raw_item in total_eval_examples]
        label_map = {0: "abbreviation",1: "entity",2: "description and abstract concept",3: "human being",4:"location",5:"numeric value"}
    elif task_name=='subj':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            subj_datasets = load_dataset(path=os.path.join(data_cache_dir, 'subj'))
            total_train_examples = [e for e in subj_datasets['train']]
            # total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_subj_examples(total_train_examples)
            total_eval_examples = [e for e in subj_datasets['test']]
            # total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_subj_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Return the SUBJ ``(prompt, target)`` pair for one sentence.

            The prompt is ``"Input: <text> \\nType:"``. The target verbalizes
            ``pseudo_label`` when available, falling back to ``label``.
            Any keyword arguments are ignored.
            """
            text = example['text']
            prompt = f"Input: {text} \nType:"
            use_pseudo = 'pseudo_label' in example
            label_id = example['pseudo_label'] if use_pseudo else example['label']
            return prompt, f"{label_map[label_id]}"

        all_train_text_to_encode = ["Input: {} \nType:".format(raw_item["text"])
                                    for raw_item in total_train_examples]
        all_eval_text_to_encode = ["Input: {} \nType:".format(raw_item["text"])
                                   for raw_item in total_eval_examples]
        label_map = {0: "objective",1: "subjective"}
    elif task_name=='scicite':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            scicite_datasets = load_dataset(path=os.path.join(data_cache_dir, 'scicite'))
            total_train_examples = [e for e in scicite_datasets['train']]
            # total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_scicite_examples(total_train_examples)
            total_eval_examples = [e for e in scicite_datasets['test']]
            # total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_scicite_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Build the SciCite ``(prompt, target)`` pair for one citation.

            The prompt asks whether ``example['text']`` describes a method, a
            result, or background. The target is the ``label_map`` entry for
            ``pseudo_label`` if present, else for ``label``. Keyword arguments
            are accepted and ignored.
            """
            citation = example['text']
            prompt = f"Is the following citation from a scientific paper describing a \"method\", a \"result\", or \"background\"?\n{citation}\nanswer:"
            intent_id = example['pseudo_label'] if 'pseudo_label' in example else example['label']
            return prompt, f"{label_map[intent_id]}"

        all_train_text_to_encode = ["{}".format(raw_item["text"])
                                    for raw_item in total_train_examples]
        all_eval_text_to_encode = ["{}".format(raw_item["text"])
                                   for raw_item in total_eval_examples]
        label_map = {0: "method",1: "background", 2:"result"}
    elif task_name=='maven':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            dev_ds_path = f"{data_cache_dir}/{args.task_name}/en_cls-MAVEN-test.json"
            train_ds_path = f"{data_cache_dir}/{args.task_name}/en_cls-MAVEN-train.json"
            maven_train_datasets = load_customize_dataset(train_ds_path, args=args)
            maven_dev_datasets = load_customize_dataset(dev_ds_path, args=args)
            total_train_examples = [e for e in maven_train_datasets]
            total_train_examples = random.sample(total_train_examples, 3000) # 10000
            total_train_examples = process_maven_examples(total_train_examples)
            total_eval_examples = [e for e in maven_dev_datasets]
            total_eval_examples = random.sample(total_eval_examples, 256) # 350
            total_eval_examples = process_maven_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Return ``(prompt, target)`` for one MAVEN event sentence.

            The prompt appends ``" It is about"`` to ``example['text']``. The
            target is the event-type name for ``pseudo_label`` when present,
            else for ``label``. Keyword arguments are ignored.
            """
            prompt = f"{example['text']} It is about"
            if 'pseudo_label' in example:
                event_id = example['pseudo_label']
            else:
                event_id = example['label']
            return prompt, f"{label_map[event_id]}"

        all_train_text_to_encode = [f"{raw_item['text']} It is about" for raw_item in total_train_examples]
        all_eval_text_to_encode = [f"{raw_item['text']} It is about" for raw_item in total_eval_examples]
        label_map = maven_label_map
    elif task_name=='clinc':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            dev_ds_path = f"{data_cache_dir}/{args.task_name}/en_clinc-test.json"
            train_ds_path = f"{data_cache_dir}/{args.task_name}/en_clinc-train.json"
            clinc_train_datasets = load_customize_dataset(train_ds_path, args=args)
            clinc_dev_datasets = load_customize_dataset(dev_ds_path, args=args)
            total_train_examples = [e for e in clinc_train_datasets]
            total_train_examples = random.sample(total_train_examples, 3000) # 10000
            total_train_examples = process_clinc_examples(total_train_examples)
            total_eval_examples = [e for e in clinc_dev_datasets]
            total_eval_examples = random.sample(total_eval_examples, 256) # 350
            total_eval_examples = process_clinc_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Return ``(prompt, target)`` for one CLINC utterance.

            The prompt is ``example['text']`` followed by ``" It is about"``.
            The target is the intent name selected by ``pseudo_label`` if the
            key exists, otherwise by ``label``. Keyword arguments are unused.
            """
            utterance = example['text']
            prompt = f"{utterance} It is about"
            field = 'pseudo_label' if 'pseudo_label' in example else 'label'
            return prompt, f"{label_map[example[field]]}"

        all_train_text_to_encode = [f"{raw_item['text']} It is about" for raw_item in total_train_examples]
        all_eval_text_to_encode = [f"{raw_item['text']} It is about" for raw_item in total_eval_examples]
        label_map = clinc_label_map
    elif task_name=='semeval':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            dev_ds_path = f"{data_cache_dir}/{args.task_name}/en_cls-Semeval-test.json"
            train_ds_path = f"{data_cache_dir}/{args.task_name}/en_cls-Semeval-train.json"
            semeval_train_datasets = load_customize_dataset(train_ds_path, args=args)
            semeval_dev_datasets = load_customize_dataset(dev_ds_path, args=args)
            total_train_examples = [e for e in semeval_train_datasets]
            total_train_examples = random.sample(total_train_examples, 3000) # None
            total_train_examples = process_semeval_examples(total_train_examples)
            total_eval_examples = [e for e in semeval_dev_datasets]
            total_eval_examples = random.sample(total_eval_examples, 256) # None
            total_eval_examples = process_semeval_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Return ``(prompt, target)`` for one SemEval relation example.

            The prompt is ``example['text']`` followed by ``" is"``. The target
            is the relation name for ``pseudo_label`` when present, else for
            ``label``. Keyword arguments are ignored.
            """
            prompt = f"{example['text']} is"
            relation_id = example['label'] if 'pseudo_label' not in example else example['pseudo_label']
            return prompt, f"{label_map[relation_id]}"

        all_train_text_to_encode = [f"{raw_item['text']} is" for raw_item in total_train_examples]
        all_eval_text_to_encode = [f"{raw_item['text']} is" for raw_item in total_eval_examples]
        label_map = semeval_label_map
    elif task_name=='hellaswag':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            hellaswag_datasets = load_dataset(path=os.path.join(data_cache_dir, 'hellaswag'))
            total_train_examples = [e for e in hellaswag_datasets['train']]
            total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_hellaswag_examples(total_train_examples)
            total_eval_examples = [e for e in hellaswag_datasets['validation']]
            total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_hellaswag_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example, label_map, **kwargs):
            """Return ``(context_prompt, ending)`` for one HellaSwag example.

            The prompt joins the activity label with ``ctx_a`` and ``ctx_b`` and
            keeps a trailing space. The target is the element of
            ``example['endings']`` indexed by ``pseudo_label`` when present,
            otherwise by ``label``. ``label_map`` and keyword arguments are
            accepted for signature compatibility but are not used.
            """
            context = f"The topic is {example['activity_label']}. {example['ctx_a']} {example['ctx_b']} "
            index_key = 'pseudo_label' if 'pseudo_label' in example else 'label'
            ending = example['endings'][example[index_key]]
            return context, f"{ending}"

        all_train_text_to_encode = [f"The topic is {raw_item['activity_label']}. {raw_item['ctx_a']} {raw_item['ctx_b']} | " \
                                  f"{raw_item['endings'][0]} | " \
                                  f"{raw_item['endings'][1]} | " \
                                  f"{raw_item['endings'][2]} | " \
                                  f"{raw_item['endings'][3]}" for raw_item in total_train_examples]
        all_eval_text_to_encode = [f"The topic is {raw_item['activity_label']}. {raw_item['ctx_a']} {raw_item['ctx_b']} | " \
                                  f"{raw_item['endings'][0]} | " \
                                  f"{raw_item['endings'][1]} | " \
                                  f"{raw_item['endings'][2]} | " \
                                  f"{raw_item['endings'][3]}" for raw_item in total_eval_examples]
        label_map = None
    elif task_name=='copa':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            copa_datasets = load_dataset(path=os.path.join(data_cache_dir, 'copa'))
            total_train_examples = [e for e in copa_datasets['validation']]
            # total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_copa_examples(total_train_examples)
            total_eval_examples = [e for e in copa_datasets['test']]
            # total_eval_examples = random.sample(total_eval_examples, 10) # 4000
            total_eval_examples = process_copa_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example,label_map,**kwargs):
            """Turn a COPA example into a ``(prompt, target)`` pair.

            The prompt asks for the ``ask`` relation of the premise ``ctx``.
            The target is the ending chosen by ``pseudo_label`` if present,
            otherwise by ``label``. ``label_map`` and ``kwargs`` are unused.
            """
            label_key = 'pseudo_label' if 'pseudo_label' in example else 'label'
            prompt = f"{example['ctx']}. What was the {example['ask']} of this? "
            return prompt, f"{example['endings'][example[label_key]]}"

        all_train_text_to_encode = [f"{raw_item['ctx']}. What was the {raw_item['ask']} of this? | " \
                                  f"{raw_item['endings'][0]} | " \
                                  f"{raw_item['endings'][1]}" for raw_item in total_train_examples]
        all_eval_text_to_encode = [f"{raw_item['ctx']}. What was the {raw_item['ask']} of this? | " \
                                  f"{raw_item['endings'][0]} | " \
                                  f"{raw_item['endings'][1]}" for raw_item in total_eval_examples]
        label_map = None
    elif task_name=='commonsense_qa':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            csqa_datasets = load_dataset(path=os.path.join(data_cache_dir, 'commonsense_qa'))
            total_train_examples = [e for e in csqa_datasets['train']]
            # total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_commonsense_qa_examples(total_train_examples)
            total_eval_examples = [e for e in csqa_datasets['validation']]
            # total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_commonsense_qa_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example,label_map,**kwargs):
            """Turn a CommonsenseQA example into a ``(prompt, target)`` pair.

            The prompt is ``"The topic is <concept>. <question> "``. The target
            is the answer choice indexed by ``pseudo_label`` when available,
            falling back to the gold ``label``. ``label_map`` and ``kwargs``
            are ignored.
            """
            label_key = 'pseudo_label' if 'pseudo_label' in example else 'label'
            prompt = f"The topic is {example['concept']}. {example['question']} "
            return prompt, f"{example['endings'][example[label_key]]}"

        all_train_text_to_encode = [f"The topic is {raw_item['concept']}. {raw_item['question']} | " \
                                  f"{raw_item['endings'][0]} | " \
                                  f"{raw_item['endings'][1]} | " \
                                  f"{raw_item['endings'][2]} | " \
                                  f"{raw_item['endings'][3]} | " \
                                  f"{raw_item['endings'][4]}" for raw_item in total_train_examples]
        all_eval_text_to_encode = [f"The topic is {raw_item['concept']}. {raw_item['question']} | " \
                                  f"{raw_item['endings'][0]} | " \
                                  f"{raw_item['endings'][1]} | " \
                                  f"{raw_item['endings'][2]} | " \
                                  f"{raw_item['endings'][3]} | " \
                                  f"{raw_item['endings'][4]}" for raw_item in total_eval_examples]
        label_map = None
    elif task_name=='cosmos_qa':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            cosmos_qa_datasets = load_dataset(path=os.path.join(data_cache_dir, 'cosmos_qa'))
            total_train_examples = [e for e in cosmos_qa_datasets['train']]
            total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_cosmos_qa_examples(total_train_examples)
            total_eval_examples = [e for e in cosmos_qa_datasets['validation']]
            # total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_cosmos_qa_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example,label_map,**kwargs):
            """Turn a CosmosQA example into a ``(prompt, target)`` pair.

            The prompt is ``"<context>. <question> "``. The target is the
            answer chosen by ``pseudo_label`` if the example has one, else by
            ``label``. ``label_map`` and ``kwargs`` are unused.
            """
            label_key = 'pseudo_label' if 'pseudo_label' in example else 'label'
            prompt = f"{example['context']}. {example['question']} "
            return prompt, f"{example['endings'][example[label_key]]}"

        all_train_text_to_encode = [f"{raw_item['context']}. {raw_item['question']} | " \
                                  f"{raw_item['endings'][0]} | " \
                                  f"{raw_item['endings'][1]} | " \
                                  f"{raw_item['endings'][2]} | " \
                                  f"{raw_item['endings'][3]}" for raw_item in total_train_examples]
        all_eval_text_to_encode = [f"{raw_item['context']}. {raw_item['question']} | " \
                                  f"{raw_item['endings'][0]} | " \
                                  f"{raw_item['endings'][1]} | " \
                                  f"{raw_item['endings'][2]} | " \
                                  f"{raw_item['endings'][3]}" for raw_item in total_eval_examples]
        label_map = None
    elif task_name=='piqa':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            piqa_datasets = load_dataset(path=os.path.join(data_cache_dir, 'piqa'))
            total_train_examples = [e for e in piqa_datasets['train']]
            total_train_examples = random.sample(total_train_examples, 10000) # 10000
            total_train_examples = process_piqa_examples(total_train_examples)
            total_eval_examples = [e for e in piqa_datasets['validation']]
            # total_eval_examples = random.sample(total_eval_examples, 4000) # 4000
            total_eval_examples = process_piqa_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example,label_map,**kwargs):
            """Turn a PIQA example into a ``(prompt, target)`` pair.

            The prompt is the goal ``question`` followed by a single space.
            The target is the solution indexed by ``pseudo_label`` when present,
            otherwise by ``label``. ``label_map`` and ``kwargs`` are ignored.
            """
            label_key = 'pseudo_label' if 'pseudo_label' in example else 'label'
            prompt = f"{example['question']} "
            return prompt, f"{example['endings'][example[label_key]]}"

        all_train_text_to_encode = [f"{raw_item['question']}. | " \
                                  f"{raw_item['endings'][0]} | " \
                                  f"{raw_item['endings'][1]}" for raw_item in total_train_examples]
        all_eval_text_to_encode = [f"{raw_item['question']}. | " \
                                  f"{raw_item['endings'][0]} | " \
                                  f"{raw_item['endings'][1]}" for raw_item in total_eval_examples]
        label_map = None
    elif task_name == 'xsum':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            xsum_dataset = load_dataset(path=os.path.join(data_cache_dir, 'xsum'))
            total_train_examples = [e for e in xsum_dataset['train']]
            total_train_examples = random.sample(total_train_examples, 3000) # 10000
            total_train_examples = process_xsum_examples(total_train_examples)
            total_eval_examples = [e for e in xsum_dataset['test']]
            total_eval_examples = random.sample(total_eval_examples, 256) # 256
            total_eval_examples = process_xsum_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]
        def format_example(example,label_map,**kwargs):
            """Turn an XSum example into a ``(prompt, target)`` pair.

            The prompt wraps ``document`` in a summarisation instruction that
            ends with ``TL;DR:``. The target is the reference ``summary``.
            ``label_map`` and ``kwargs`` are unused.
            """
            document, summary = example['document'], example['summary']
            return f"write a short summary:\n{document}\nTL;DR:", f"{summary}"

        all_train_text_to_encode = [raw_item['document']
                                    for raw_item in total_train_examples]
        all_eval_text_to_encode = [raw_item['document']
                                   for raw_item in total_eval_examples]
        label_map = None
    elif task_name == 'nq':
        if os.path.isfile(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) and \
                os.path.isfile(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')):
            print('use cached examples')
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json')) as f:
                total_train_examples = json.load(f)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json')) as f:
                total_eval_examples = json.load(f)
        else:
            nq_dataset = load_dataset('natural_questions', cache_dir=data_cache_dir)
            first_sub_sample_indices = random.sample(range(len(nq_dataset['train'])), 12000)
            train_data = nq_dataset['train'].select(first_sub_sample_indices).map(format_nq_dataset)
            total_train_examples = train_data.remove_columns(["annotations", "document", "id"]).filter(
                lambda x: x['category'] != "null")
            total_train_examples = [e for e in total_train_examples]
            total_train_examples = random.sample(total_train_examples, 3000)
            total_train_examples = process_nq_examples(total_train_examples)
            total_eval_examples = nq_dataset['validation'].map(format_nq_dataset).remove_columns(
                ["annotations", "document", "id"]).filter(lambda x: x['category'] != "null")
            total_eval_examples = [e for e in total_eval_examples]
            total_eval_examples = random.sample(total_eval_examples, 256)
            total_eval_examples = process_nq_examples(total_eval_examples)
            with open(os.path.join(args.output_dir, f'train_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_train_examples, f, indent=4)
            with open(os.path.join(args.output_dir, f'eval_examples_seed_{args.seed}.json'), 'w') as f:
                json.dump(total_eval_examples, f, indent=4)
        if args.debug:
            args.annotation_size = 10
            args.batch_size = 1
            total_train_examples = total_train_examples[:50]
            total_eval_examples = total_eval_examples[:5]

        def format_example(example, label_map, **kwargs):
            """Turn a Natural Questions example into a ``(prompt, target)`` pair.

            * ``category`` ``'yes'``/``'no'``: the prompt ends with ``"\\nclass"``
              and the target is the category itself.
            * ``category`` ``'other'``: the prompt ends with ``"\\nother "`` and
              the target is the first entry of ``short_targets``.

            ``label_map`` and ``kwargs`` are unused.

            Raises:
                ValueError: if ``category`` is not one of ``'yes'``, ``'no'``,
                    ``'other'``, or if an ``'other'`` example has no short targets.
            """
            category = example['category']
            if category in ['yes', 'no']:
                return f"Write an answer: {example['question']}\nclass", f"{category}"
            # Explicit raises instead of `assert`: asserts are stripped under
            # `python -O`, which would let malformed examples through silently.
            if category != 'other':
                raise ValueError(f"unexpected nq category: {category!r}")
            short_targets = example['short_targets']
            if not short_targets:
                raise ValueError(f"nq example has no short targets: {short_targets}")
            return f"Write an answer: {example['question']}\n{category} ", f"{short_targets[0]}"

        all_train_text_to_encode = [raw_item['question']
                                    for raw_item in total_train_examples]
        all_eval_text_to_encode = [raw_item['question']
                                   for raw_item in total_eval_examples]
        label_map = None
    else:
        raise ValueError(f"{args.task_name} is not supported")
    return total_train_examples,total_eval_examples,all_train_text_to_encode,\
           all_eval_text_to_encode,format_example,label_map
