import os

import sys
import argparse
import math
import torch
import json
from tqdm import tqdm
from typing import Optional

from dschat.utils.model.model_utils import create_critic_model
from dschat.utils.utils import to_device, load_hf_tokenizer
from deepspeed import get_accelerator


def parse_args(argv: Optional[list] = None) -> argparse.Namespace:
    """Parse the command-line options of the reward-model evaluation script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses;
            passing an explicit list allows programmatic use.

    Returns:
        Namespace with ``model_name_or_path``, ``num_padding_at_beginning``,
        ``add_eot_token``, ``dataset`` and ``dataset_model``.

    Raises:
        SystemExit: Raised by argparse when ``--model_name_or_path`` is missing
            or when ``--dataset`` / ``--dataset_model`` is not an allowed choice.
    """
    # Closed sets of accepted values; argparse rejects anything else.
    dataset_model_names = ["mistral-7B", "wizardmath-7B", "llama2-7B", "mistral-sft", "llama2-13B"]
    datasets = ["GSM8K", "MATH", "CSQA", "STQA"]

    parser = argparse.ArgumentParser(
        description="Eval the fine-tuned reward model")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Directory of the fine-tuned reward model "
        "(must contain pytorch_model.bin).",
    )
    parser.add_argument(
        "--num_padding_at_beginning",
        type=int,
        default=1,
        help="Number of padding tokens at the beginning of the input; "
        "forwarded to create_critic_model.",
    )
    # NOTE(review): nothing in the visible code reads this flag - confirm
    # whether it should be wired into the tokenizer's special tokens.
    parser.add_argument(
        "--add_eot_token",
        action='store_true',
        help="Add <|endoftext|> as additional special token to tokenizer")
    parser.add_argument(
        '--dataset',
        type=str,
        help='The dataset to use.',
        default="GSM8K",
        choices=datasets,
    )
    parser.add_argument(
        '--dataset_model',
        type=str,
        help='The dataset model to use.',
        default="mistral-7B",
        choices=dataset_model_names,
    )
    return parser.parse_args(argv)


def load_stuff(model_name_or_path: str, num_padding_at_beginning: int,
               additional_special_tokens):
    """Load the tokenizer and critic model, then restore the checkpoint weights.

    Args:
        model_name_or_path: Model directory; it is passed to the tokenizer and
            model factories and must contain ``pytorch_model.bin``.
        num_padding_at_beginning: Forwarded unchanged to
            ``create_critic_model``.
        additional_special_tokens: Forwarded to ``load_hf_tokenizer`` as
            ``add_special_tokens`` (``None`` is what the visible caller passes).

    Returns:
        ``(model, tokenizer)`` with the checkpoint's state dict loaded into the
        model. The model is not moved to an accelerator here.

    Raises:
        FileNotFoundError: If ``pytorch_model.bin`` is missing from the
            model directory (raised by ``torch.load``).
    """
    tokenizer = load_hf_tokenizer(model_name_or_path,
                                  fast_tokenizer=True,
                                  add_special_tokens=additional_special_tokens)
    # Only fall back to EOS when the tokenizer ships without a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The third positional argument is None here (presumably the DeepSpeed
    # config - confirm against create_critic_model); dropout is disabled.
    model = create_critic_model(model_name_or_path,
                                tokenizer,
                                None,
                                num_padding_at_beginning,
                                dropout=0.)

    model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
    # Deserialize onto the CPU: without map_location the tensors are restored
    # onto the device they were saved from, which fails when that device does
    # not exist here and otherwise holds a second copy of the weights on it.
    # NOTE(security): torch.load unpickles the file - only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_ckpt_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model, tokenizer


def get_batch(step, tokenizer, max_length=1024):
    """Tokenize one text into a model-ready batch.

    Args:
        step: The text to tokenize.
        tokenizer: Callable tokenizer; it is invoked with ``max_length``,
            ``truncation=True`` and ``return_tensors="pt"`` and must return a
            mapping with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Truncation limit passed to the tokenizer. Defaults to 1024,
            the value that was previously hard-coded.

    Returns:
        A dict holding only the ``"input_ids"`` and ``"attention_mask"``
        entries of the tokenizer output; any other keys are dropped.
    """
    encoded = tokenizer(step,
                        max_length=max_length,
                        truncation=True,
                        return_tensors="pt")
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
    }


def _format_prm_input(question, solution, step_tag):
    """Return the question followed by every solution line with the step tag appended."""
    tagged_steps = "".join(
        step.strip() + step_tag + "\n" for step in solution.split("\n"))
    return question + "\n" + tagged_steps


def run_new(model_name_or_path, dataset, dataset_model,
            num_padding_at_beginning=1, step_tag=None,
            data_path="", rst_dir=""):
    """Score every candidate solution step by step and dump the scores as JSON.

    For each entry of the input JSON, every candidate's ``steps`` text is split
    on newlines, each line gets ``step_tag`` appended, and the model value at
    each position whose token id equals the tag's id becomes that step's score.

    Args:
        model_name_or_path: Model directory handed to ``load_stuff``.
        dataset: Dataset name. Not referenced in the body; kept for interface
            compatibility (presumably it once selected the input path).
        dataset_model: Used as the stem of the output file name.
        num_padding_at_beginning: Forwarded to ``load_stuff``.
        step_tag: Non-empty marker string appended after every step. Required:
            the original hard-coded ``None`` could not be concatenated to a
            step and was encoded as the literal text "None".
        data_path: JSON file holding a list of entries with ``question``,
            ``answer`` and ``candidates`` (each with ``steps`` and ``answer``).
        rst_dir: Directory for ``<dataset_model>.json``; an empty string writes
            to the current working directory.

    Raises:
        ValueError: If ``step_tag`` is not a non-empty string.
        FileNotFoundError: If ``data_path`` cannot be opened.
    """
    # Fail fast, before the expensive model load, instead of crashing with a
    # TypeError inside the scoring loop.
    if not isinstance(step_tag, str) or not step_tag:
        raise ValueError(
            "step_tag must be a non-empty string marking the end of each step")

    # NOTE(review): the input path and result directory were blank literals in
    # the original; callers must supply real locations.
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if rst_dir:
        os.makedirs(rst_dir, exist_ok=True)
    rst_path = os.path.join(rst_dir, f"{dataset_model}.json")

    device = torch.device(get_accelerator().device_name())
    # Use this function's own arguments rather than the module-level ``args``,
    # which only exists when the file is executed as a script.
    rm_model, tokenizer = load_stuff(model_name_or_path,
                                     num_padding_at_beginning,
                                     additional_special_tokens=None)
    rm_model.to(device)
    rm_model.eval()

    # The last id of the encoded tag identifies the score positions.
    step_tag_id = tokenizer.encode(step_tag)[-1]

    rst = []
    with torch.no_grad():
        for entry in tqdm(data, desc='Processing '):
            correct_ans = str(entry['answer'])
            cand_scores = []
            for candidate in entry['candidates']:
                input_for_prm = _format_prm_input(
                    entry['question'], candidate['steps'], step_tag)

                batch = to_device(get_batch(input_for_prm, tokenizer), device)
                outputs = rm_model.forward_value(**batch,
                                                 return_value_only=True)
                # Single-sequence batch: take row 0 of ids and values
                # (values assumed to align one-to-one with input ids).
                input_id = batch['input_ids'][0]
                value = outputs[0]
                # get_batch truncates the prompt, so tags cut off by the
                # length limit produce no score.
                step_scores = value[input_id == step_tag_id].tolist()

                cand_scores.append({
                    'answer': candidate['answer'],
                    'step_scores': step_scores
                })
            rst.append({
                "answer": correct_ans,
                "candidates": cand_scores
            })

    with open(rst_path, 'w', encoding='utf-8') as f:
        json.dump(rst, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    args = parse_args()
    # NOTE(review): step_tag, data_path and rst_dir have no command-line flags
    # yet, so this call stops with a ValueError until they are wired in.
    run_new(args.model_name_or_path,
            dataset=args.dataset,
            dataset_model=args.dataset_model,
            num_padding_at_beginning=args.num_padding_at_beginning)

