import os.path
import re
import os

os.environ['HF_HOME'] = '<REDACTED>-PrEx/cache'

import pandas as pd
import torch

from transformers import pipeline
import argparse
from tqdm import tqdm
from torch.utils.data import Dataset

from iterator.utils.model_dict import load_from_catalogue
from project_root import join_with_root
from vllm import LLM, SamplingParams
#import mii


class ListDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable, sized collection of prompts.

    The wrapped object is stored as ``prompt_dataset``; length and item
    lookups are delegated to it unchanged.
    """

    def __init__(self, p):
        # Keep a reference to the caller's collection (no copy is made).
        self.prompt_dataset = p

    def __getitem__(self, i):
        """Return the element of the wrapped collection at index ``i``."""
        prompts = self.prompt_dataset
        return prompts[i]

    def __len__(self):
        """Return the number of elements in the wrapped collection."""
        prompts = self.prompt_dataset
        return len(prompts)


def regroup_outputs(flat_results, group_sizes):
    """Split a flat list of results back into consecutive per-row groups.

    Args:
        flat_results: Generated texts, one per prompt, in prompt order.
        group_sizes: Number of prompts that belonged to each row, in row order.

    Returns:
        A list with one sub-list per entry of ``group_sizes``; the i-th
        sub-list holds the next ``group_sizes[i]`` items of ``flat_results``.

    Raises:
        ValueError: If the sizes do not add up to ``len(flat_results)``.
    """
    expected = sum(group_sizes)
    if expected != len(flat_results):
        raise ValueError(
            f"Expected {expected} results for the given rows, got {len(flat_results)}")

    grouped = []
    start = 0
    for size in group_sizes:
        grouped.append(flat_results[start:start + size])
        start += size
    return grouped


def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='Pass allowed models via command line.')
    # These options have no usable default: the script reads every one of them
    # unconditionally, so fail early with a clear argparse message.
    parser.add_argument('--model', required=True, help='List of models to be allowed')
    parser.add_argument('--task', required=True, help='the language pair to consider')
    parser.add_argument('--fr', required=True, help='Starting sample')
    parser.add_argument('--to', required=True, help='Ending sample')
    parser.add_argument('--max_tokens', required=True, help='Max tokens')
    parser.add_argument('--prompt_df_path', required=True,
                        help='File with the prompts for the current model')
    parser.add_argument('--out_dir', help='output directory', default="<REDACTED>-PrEx/outputs/raw")
    # Only these two backends are implemented below; anything else used to
    # fall through and fail later with a NameError.
    parser.add_argument('--mode', choices=("standard", "vllm"), default="vllm",
                        help='standard or vllm (deepspeed is currently disabled). vllm is the default')
    return parser.parse_args()


def _generate_standard(model_name, prompts, max_new_tokens):
    """Generate one continuation per prompt with a transformers text-generation pipeline.

    Returns the generated texts with the first ``len(prompt)`` characters
    removed, i.e. it assumes the pipeline output starts with the prompt itself.
    """
    # The catalogue also returns two further values (named user / assistant
    # in the original code) that are not needed here.
    model, tokenizer, _user, _assistant = load_from_catalogue(model_name)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer,
                    max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id)
    raw_results = [r for r in tqdm(pipe(ListDataset(prompts)), desc="Progress: ")]
    return [r[0]["generated_text"][len(p):] for p, r in zip(prompts, raw_results)]


def _generate_vllm(model_name, task, prompts, max_tokens):
    """Generate one continuation per prompt with vLLM.

    If fewer outputs than prompts come back, the tail is padded with the
    string 'MISSING' so the result always has ``len(prompts)`` entries.
    """
    sampling_params = SamplingParams(temperature=0, max_tokens=max_tokens)
    if "gptq" in model_name.lower():
        # GPTQ checkpoints are loaded with an explicit context-length cap,
        # which is smaller for the summarization task.
        max_model_len = 1600 if task == "summarization" else 4096
        llm = LLM(model=model_name, quantization="gptq", max_model_len=max_model_len)
    else:
        llm = LLM(model=model_name)

    outputs = llm.generate(prompts, sampling_params)
    texts = [o.outputs[0].text for o in outputs]
    texts += ['MISSING'] * (len(prompts) - len(texts))
    return texts


if __name__ == '__main__':
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'

    args = _parse_args()
    start, stop, max_tokens = int(args.fr), int(args.to), int(args.max_tokens)

    raw_df = pd.read_json(args.prompt_df_path)
    raw_samples = len(raw_df)  # total rows across all tasks; only used in the output file name
    task_df = raw_df[raw_df["task"] == args.task]
    # Copy the slice so that adding a column below never writes to a view of raw_df.
    df = task_df[start:stop].copy()
    del raw_df, task_df

    if df.empty:
        # Abort before loading a model when there is nothing to generate.
        raise SystemExit(
            f"No samples for task {args.task!r} in range [{start}:{stop}] of {args.prompt_df_path}")

    # Create the output directory up front so a missing path cannot discard
    # the results of a finished generation run.
    os.makedirs(args.out_dir, exist_ok=True)

    # Each row holds a list of prompt dicts; flatten them into one list of prompt strings.
    prompt_list = df["prompts"].tolist()
    prompt_list_unpacked = [d["base_prompt"]["prompt"] for e in prompt_list for d in e]

    if args.mode == "standard":
        processed_results = _generate_standard(args.model, prompt_list_unpacked, max_tokens)
    else:  # "vllm" -- guaranteed by the argparse choices
        processed_results = _generate_vllm(args.model, args.task, prompt_list_unpacked, max_tokens)

    # NOTE: a DeepSpeed-MII backend (mii.pipeline) was sketched here previously but is disabled.

    # Regroup by each row's own prompt count instead of assuming that every
    # row has as many prompts as the first one.
    df["generated_text"] = regroup_outputs(processed_results, [len(e) for e in prompt_list])

    torch.cuda.empty_cache()

    df.to_json(
        os.path.join(args.out_dir, f"slurm_pool_{args.model.replace('/', '_')}_{args.fr}_{args.to}_of_{raw_samples}_{args.mode}_{args.task}.json"),
        orient="records", force_ascii=False)
