import random
import os
import json
from copy import deepcopy
from collections import defaultdict


# Annotator identifiers: the strings '1' through '11'.
ANNOTATORS = [str(annotator_number) for annotator_number in range(1, 12)]


# Result folders searched for the 'formal' target file, one entry per system.
FORMAL_FOLDERS = [
    '../../gyafc/eval/chatgpt_fixed_prompt/2024-06-03_20-29-02',  # GPT-4
    '../../gyafc/eval/chatgpt_fixed_prompt/2024-06-03_20-58-27',  # GPT-3.5
    '../../gyafc/eval/sft_v2_outputs/2024-05-28-12_07_50',  # TinyStyle
    '../../gyafc/eval/paraguide/2024-05-28_00-38-51_200.0',  # ParaGuide
    '../../gyafc/eval/mix_match/2024-05-29_11-52-31',  # Ham
]

# Result folders searched for the 'informal' target file, one entry per system.
INFORMAL_FOLDERS = [
    '../../gyafc/eval/chatgpt_fixed_prompt/2024-06-03_20-29-02',  # GPT-4
    '../../gyafc/eval/chatgpt_fixed_prompt/2024-06-03_20-58-27',  # GPT-3.5
    '../../gyafc/eval/sft_v2_outputs/2024-05-28-12_07_50',  # TinyStyle
    '../../gyafc/eval/paraguide/2024-05-28_00-39-50_200.0',  # ParaGuide
    '../../gyafc/eval/mix_match/2024-05-29_11-52-51',  # Ham
]

# Copies made of every selected record, each assigned to one annotator.
ANNOT_PER = 3
# Upper bound on the record positions sampled per direction.
N = 75
# Leading records of every file that are set aside as examples.
NUM_EXAMPLES = 10

# There must be at least as many annotators as copies made of each record.
assert ANNOT_PER <= len(ANNOTATORS)

def standardize_inference_data(data):
    """Normalize one inference record in place.

    Ensures ``data['original_text']`` exists (copying ``data['source_text']``
    when it does not) and rewrites ``data['output']`` as a one-element list
    holding the first leaf value reached by unwrapping nested containers:

    * a list is replaced by its first element;
    * a dict with exactly one key is replaced by that key's value;
    * a dict with several keys is replaced by its ``'text'`` value.

    Args:
        data: Mutable mapping with an ``'output'`` entry and an
            ``'original_text'`` or ``'source_text'`` entry.

    Raises:
        KeyError: If ``'output'`` is missing, or if both ``'original_text'``
            and ``'source_text'`` are missing.
        AssertionError: If a dict with several keys has no ``'text'`` key.
        IndexError: If an empty list or an empty dict is encountered.
    """
    if 'original_text' not in data:
        data['original_text'] = data['source_text']

    result = data['output']
    while isinstance(result, (dict, list)):
        if not result:
            # Nothing to unwrap; fail with context rather than a bare
            # "list index out of range".
            raise IndexError(
                f"Cannot unwrap empty {type(result).__name__} in output: "
                f"{data['output']!r}"
            )
        if isinstance(result, list):
            result = result[0]
        elif len(result) == 1:
            (result,) = result.values()
        else:
            # Raised explicitly (not via ``assert``) so the check is still
            # enforced when Python runs with -O.
            if 'text' not in result:
                raise AssertionError(
                    f"Expected a 'text' key among {list(result.keys())!r}"
                )
            result = result['text']
    data['output'] = [result]

def get_inference_file(target, folder):
    """Return the path of the inference file for ``target`` inside ``folder``.

    Two file names are tried in order: ``<target>.jsonl`` and then
    ``to_<target>.jsonl``. The first one that exists is returned.

    Args:
        target: Name used to build the candidate file names.
        folder: Directory in which the candidates are looked up.

    Returns:
        The joined path of the first existing candidate.

    Raises:
        FileNotFoundError: If neither candidate exists.
    """
    for file_name in (f'{target}.jsonl', f'to_{target}.jsonl'):
        candidate = os.path.join(folder, file_name)
        if os.path.exists(candidate):
            return candidate

    raise FileNotFoundError(f'Could not find {target} in {folder}')

def load_data(file_path):
    """Load a JSON Lines inference file as a list of standardized records.

    Every non-blank line is parsed as JSON, normalized in place by
    ``standardize_inference_data`` and tagged with where it came from.

    Args:
        file_path: Path of the ``.jsonl`` file to read.

    Returns:
        The parsed records in file order. Each record gains a ``'data_path'``
        entry (``file_path``) and a ``'data_index'`` entry (the zero-based
        line number of the record within the file).
    """
    data = []
    # JSON is UTF-8 by specification; do not rely on the platform's default
    # encoding, which differs between systems.
    with open(file_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if not line.strip():
                # Tolerate blank lines (such as an extra newline at the end
                # of the file) instead of failing inside json.loads.
                continue
            result = json.loads(line)
            standardize_inference_data(result)
            result['data_path'] = file_path
            result['data_index'] = i
            data.append(result)
    return data

def validate_same_inputs(json_lists):
    """Check that all record lists share the same ``original_text`` sequence.

    Every list is compared against the first one; an empty ``json_lists``
    passes trivially.

    Args:
        json_lists: Iterable of record lists. Every record must provide an
            ``'original_text'`` entry.

    Raises:
        AssertionError: If a list's texts differ from those of the first
            list. The message names the offending list and how it differs.
        KeyError: If a record has no ``'original_text'`` entry.
    """
    all_original_texts = [
        [record['original_text'] for record in json_list]
        for json_list in json_lists
    ]
    if not all_original_texts:
        return

    reference = all_original_texts[0]
    for list_index, texts in enumerate(all_original_texts[1:], start=1):
        if texts == reference:
            continue
        if len(texts) != len(reference):
            detail = f'{len(texts)} records instead of {len(reference)}'
        else:
            position = next(
                (i for i, (a, b) in enumerate(zip(texts, reference)) if a != b),
                None,
            )
            detail = f'first difference at record {position}'
        # Raised explicitly (not via ``assert``) so the validation is still
        # performed when Python runs with -O.
        raise AssertionError(
            f'Input list {list_index} does not match input list 0: {detail}'
        )


# Load one list of records per system and per target, then make sure every
# system within a target was run on the same original texts.
FORMAL_DATA = [load_data(get_inference_file('formal', folder)) for folder in FORMAL_FOLDERS]
INFORMAL_DATA = [load_data(get_inference_file('informal', folder)) for folder in INFORMAL_FOLDERS]

validate_same_inputs(FORMAL_DATA)
validate_same_inputs(INFORMAL_DATA)


# The first NUM_EXAMPLES original texts of each target are kept as examples.
# NOTE(review): the originals of the 'formal' target are stored as
# "informal" examples (and vice versa), presumably because the input to a
# to-formal system is informal text -- confirm.
informal_examples = [FORMAL_DATA[0][i]['original_text'] for i in range(NUM_EXAMPLES)]
formal_examples = [INFORMAL_DATA[0][i]['original_text'] for i in range(NUM_EXAMPLES)]

# Exclude the example records from the data that can be sampled later.
FORMAL_DATA = [x[NUM_EXAMPLES:] for x in FORMAL_DATA]
INFORMAL_DATA = [x[NUM_EXAMPLES:] for x in INFORMAL_DATA]

# Save the examples, one per line. The encoding is explicit so non-ASCII text
# is written identically on every platform instead of depending on the locale.
for examples_path, examples in (
    ('informal_examples.txt', informal_examples),
    ('formal_examples.txt', formal_examples),
):
    with open(examples_path, 'w', encoding='utf-8') as f:
        f.writelines(example + '\n' for example in examples)


# Seed the generator so the shuffles below give the same result on every run.
random.seed(42)

total_formal = len(FORMAL_DATA[0])
total_informal = len(INFORMAL_DATA[0])

print('Total formal: ', total_formal)
print('Total informal: ', total_informal)

# Sample record positions: shuffle all positions (formal first, then
# informal) and keep the first N of each.
formal_indices = [*range(total_formal)]
random.shuffle(formal_indices)
informal_indices = [*range(total_informal)]
random.shuffle(informal_indices)

formal_indices = formal_indices[:N]
informal_indices = informal_indices[:N]

# Save the sampled positions.
for indices_path, sampled_indices in (
    ('formal_indices.json', formal_indices),
    ('informal_indices.json', informal_indices),
):
    with open(indices_path, 'w') as f:
        f.write(json.dumps(sampled_indices))

# Take the sampled positions from every system's records, system by system.
FORMAL_DATA_TO_ANNOTATE = [
    system_records[position]
    for system_records in FORMAL_DATA
    for position in formal_indices
]

INFORMAL_DATA_TO_ANNOTATE = [
    system_records[position]
    for system_records in INFORMAL_DATA
    for position in informal_indices
]


# Pool both targets, then shuffle the records and (in place) the annotator
# list; the order of these two shuffles fixes the random sequence.
ALL_DATA = [*FORMAL_DATA_TO_ANNOTATE, *INFORMAL_DATA_TO_ANNOTATE]
random.shuffle(ALL_DATA)
random.shuffle(ANNOTATORS)

print('Total data to annotate: ', len(ALL_DATA))


# Make ANNOT_PER independent copies of every record. Annotation ids count up
# from 0 across all copies, and annotators are assigned round-robin by id.
data_copied_with_assignments = []
# Number of copies handed to each annotator, tallied in assignment order.
check_number_per_annotator = defaultdict(int)

for example_position, example in enumerate(ALL_DATA):
    first_id = example_position * ANNOT_PER
    for annotation_id in range(first_id, first_id + ANNOT_PER):
        assigned_copy = deepcopy(example)
        assigned_copy['annotator'] = ANNOTATORS[annotation_id % len(ANNOTATORS)]
        assigned_copy['annotation_id'] = annotation_id
        data_copied_with_assignments.append(assigned_copy)
        check_number_per_annotator[assigned_copy['annotator']] += 1

print(check_number_per_annotator)

# Dump every assigned copy as one JSON object per line.
with open('all_data_to_annotate.jsonl', 'w') as f:
    f.writelines(json.dumps(record) + '\n' for record in data_copied_with_assignments)

os.makedirs('TODO', exist_ok=True)

# Column titles of each annotation sheet; the untitled columns are spacers.
_TSV_HEADER = ['id', 'Reference', '', 'Output Text', '', 'Meaning Preserved (0=no, 1=yes)', 'Well-formedness (by internet standards) (0=badly-formed, 1=well-formed)', 'Formality (0=informal, 1=formal)']

# A tab or line break inside a cell would split the row, so map them to spaces.
_TSV_CELL_CLEANUP = str.maketrans({'\t': ' ', '\n': ' ', '\r': ' '})


def _tsv_line(cells):
    """Join string ``cells`` into one newline-terminated TSV line."""
    return '\t'.join(cell.translate(_TSV_CELL_CLEANUP) for cell in cells) + '\n'


# Group the assigned copies once instead of rescanning all of them for every
# annotator; the original record order is kept within each group.
_examples_by_annotator = defaultdict(list)
for example in data_copied_with_assignments:
    _examples_by_annotator[example['annotator']].append(example)

# One sheet per annotator, written even when the annotator has no rows.
for annotator in ANNOTATORS:

    sheet_path = os.path.join('TODO', annotator + '_to_annotate.tsv')
    # Explicit encoding so non-ASCII text does not depend on the locale.
    with open(sheet_path, 'w', encoding='utf-8') as in_:
        in_.write(_tsv_line(_TSV_HEADER))
        for example in _examples_by_annotator.get(annotator, []):
            decoded = example['output'][0]

            row = [str(example['annotation_id']), example['original_text'], '', decoded, '']
            # Pad with empty rating cells so every row has exactly as many
            # columns as the header.
            row += [''] * (len(_TSV_HEADER) - len(row))
            in_.write(_tsv_line(row))





