import os
import random
import torch

from tqdm import tqdm

from eval_attribute import load_internal_formality_model

def run_formality_model(*, model, tokenizer, texts, label_map, target, device='cuda'):
    """Score each text with the classifier's probability for the ``target`` label.

    Texts are scored one at a time: each is tokenized, moved to ``device``,
    passed through ``model``, and the softmax over the last logits dimension
    is read at the index ``label_map[target]``.

    Args:
        model: Callable taking the tokenizer output as keyword arguments and
            returning an object with a ``logits`` attribute. It is assumed to
            already live on ``device`` (this function does not move it).
        tokenizer: Callable producing PyTorch tensors when invoked with
            ``return_tensors='pt'``; the result must support ``.to(device)``
            and ``**`` unpacking (presumably a Hugging Face tokenizer --
            confirm against ``load_internal_formality_model``).
        texts: Iterable of strings to score.
        label_map: Mapping from label name to the class index in the logits.
        target: Key of ``label_map`` whose probability is returned.
        device: Device the tokenized inputs are moved to.

    Returns:
        A list of Python floats, one per element of ``texts``, in input order.

    Raises:
        KeyError: If ``target`` is not a key of ``label_map``.
    """
    optimizing_label_index = label_map[target]
    scores = []
    # Scoring only: disable autograd so no gradient graph is built per text.
    with torch.no_grad():
        for text in tqdm(texts):
            inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
            # Rebind rather than relying on ``.to`` mutating in place.
            inputs = inputs.to(device)
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)
            # One text per call, so row 0 is the only row in the batch.
            scores.append(probs[0, optimizing_label_index].item())

    return scores

    
if __name__ == '__main__':
    # Builds two exemplar files per style ("formal" / "informal") from one
    # shard of the corpus directory below:
    #   * a purely random sample of `sample_size` lines, and
    #   * the first `sample_size` shuffled lines whose model score for their
    #     own style is strictly above `filter_thresh`.

    path_to_data = '../data/GYAFC_Corpus 2/Entertainment_Music/'
    shard = 'tune'
    sample_size = 128  # named to avoid shadowing the builtin `max`
    random.seed(42)  # fixed seed so the shuffles are reproducible

    filter_thresh = 0.95

    model, tokenizer, label_map = load_internal_formality_model()
    model.to('cuda')
    model.eval()

    # `style` (not `type`) so the builtin is not shadowed at module scope.
    for style in ['formal', 'informal']:
        full_path = os.path.join(path_to_data, shard, style)

        # Explicit encoding so reading/writing does not depend on the
        # platform's locale default.
        with open(full_path, encoding='utf-8') as f:
            lines = [line.strip() for line in f]

        random.shuffle(lines)

        # Random baseline: the first `sample_size` lines of the shuffle.
        random_out_path = full_path + f'_exemplar_sample_random.{sample_size}'
        with open(random_out_path, 'w', encoding='utf-8') as f:
            f.writelines([line + '\n' for line in lines[:sample_size]])

        # Score every (shuffled) line for its own style label.
        scores = run_formality_model(
            model=model,
            tokenizer=tokenizer,
            texts=lines,
            label_map=label_map,
            target=style,
        )

        # Keep lines scoring strictly above the threshold, in shuffled order,
        # capped at `sample_size`; fewer are written if not enough qualify.
        filtered_texts = [
            text for text, score in zip(lines, scores) if score > filter_thresh
        ][:sample_size]

        filtered_out_path = (
            full_path + f'_exemplar_sample_filtered_{filter_thresh}.{sample_size}'
        )
        with open(filtered_out_path, 'w', encoding='utf-8') as f:
            f.writelines([text + '\n' for text in filtered_texts])
        




        
        