import datasets
import tqdm
from huggingface_hub.hf_api import HfFolder
from transformers import AutoModelForCausalLM, AutoTokenizer
import config
from data_loader import load_test_data
import pickle
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
import os

def main():
    """Generate "forced" mnemonics with the DPO model and pickle the outputs.

    Each test term is wrapped in a prompt template that already starts the
    mnemonic ("<Term> sounds like"), so the model must continue from that
    prefix. Decoding uses ``do_sample=False`` and stops when the newline
    stop-token ids appear after the prompt, or after 128 new tokens.

    Results (one ``batch_decode`` list per prompt) are checkpointed to
    ``<dpo_results_dir>_force`` after every prompt. If that file already
    exists, its outputs are loaded and generation resumes at the first prompt
    that has no output yet, instead of re-generating (and duplicating) the
    earlier ones.
    """

    print(config.params)
    print("DPO FORCE")

    dpo_model_name = config.params['dpo_final_output_dir']
    dpo_results_dir = config.params['dpo_results_dir']
    results_file = dpo_results_dir + '_force'

    # Create the parent directory of the results file. os.path.dirname returns
    # '' for a bare filename, and os.makedirs('') would raise, so skip it then.
    results_parent = os.path.dirname(dpo_results_dir)
    if results_parent:
        os.makedirs(results_parent, exist_ok=True)

    cache_dir = config.params['cache_dir']
    hf_token = config.params['hf_read_token']
    HfFolder.save_token(hf_token)

    tokenizer_name = config.params['sft_tokenizer_name']

    def prompt_convert(ex):
        """Batched ``map`` fn: wrap each term in the forced-mnemonic template."""
        ex['prompt'] = [
            f"### Term: {p.title()}\n### Mnemonic: {p.title()} sounds like"
            for p in ex['prompt']
        ]
        return ex

    ds_test = load_test_data().map(prompt_convert, batched=True)
    inf_prompts = ds_test['prompt']

    class StoppingCriteriaSub(StoppingCriteria):
        """Stop once ``stop_tokens`` occurs at or after position ``prompt_len``.

        Searching only past the prompt keeps the template's own newline from
        ending generation. Only the first sequence of the batch is inspected.
        An empty ``stop_tokens`` never triggers a stop.
        """

        def __init__(self, stop_tokens=None, prompt_len=0):
            super().__init__()
            self.prompt_len = prompt_len
            # None sentinel instead of a shared mutable [] default.
            self.stop_tokens = [] if stop_tokens is None else list(stop_tokens)

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
            stop = self.stop_tokens
            if not stop:
                # An empty pattern would "match" everywhere and halt generation
                # after the first token; rely on max_new_tokens instead.
                return False
            ids = input_ids[0].tolist()
            width = len(stop)
            return any(
                ids[i:i + width] == stop
                for i in range(self.prompt_len, len(ids))
            )

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir)
    dpo_model = AutoModelForCausalLM.from_pretrained(dpo_model_name,
                                                    load_in_8bit=config.params['load_in_8bit'],
                                                    load_in_4bit=config.params['load_in_4bit'],
                                                    cache_dir=cache_dir,
                                                    device_map="auto")

    # Resume from an existing checkpoint if present.
    # NOTE(review): pickle.load can execute arbitrary code; this assumes the
    # file was written by this script and has not been tampered with.
    if os.path.isfile(results_file):
        with open(results_file, 'rb') as handle:
            outputs = pickle.load(handle)
    else:
        outputs = []

    def save(results):
        """Overwrite the checkpoint file with ``results``."""
        with open(results_file, 'wb') as handle:
            pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

    stop_token = '\n'
    # Tokenize the stop string once (loop-invariant). The [2:] drops the
    # leading ids the tokenizer emits for a standalone "\n" -- presumably BOS
    # plus a SentencePiece prefix-space piece (LLaMA-style); TODO confirm for
    # sft_tokenizer_name. If the slice comes out empty, the stopping criterion
    # is inert and generation is bounded by max_new_tokens only.
    stop_ids = tokenizer(stop_token).input_ids[2:]

    # Skip prompts that already have an output in the checkpoint.
    start = len(outputs)
    total = len(inf_prompts)
    for idx in tqdm.tqdm(range(start, total), initial=start, total=total):
        prompt = inf_prompts[idx]

        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stop_ids, prompt_len=input_ids.shape[1])]
        )
        input_ids = input_ids.to('cuda')
        out = dpo_model.generate(input_ids, max_new_tokens=128, do_sample=False,
                                 stopping_criteria=stopping_criteria).to('cpu').detach()
        # Keep only the newly generated continuation, not the echoed prompt.
        out = tokenizer.batch_decode(out[:, input_ids.shape[1]:])
        outputs.append(out)

        # Checkpoint after every prompt so an interrupted run can resume.
        save(outputs)

    # Final write; also creates the file when there was nothing to generate.
    save(outputs)

if __name__ == '__main__':
    # Script entry point: run generation only when executed directly, not on import.
    main()