from dataclasses import dataclass
from transformers.utils import PaddingStrategy
# from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import torch
from typing import *
import json
from data_collator import MyDataCollatorWithPadding
from transformers import CodeGenTokenizer, AutoTokenizer, PreTrainedTokenizerBase


class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding (prompt, answer) token-id pairs.

    The JSON file at ``data_path`` must hold a list of objects, each with the
    keys ``pid``, ``nl``, ``input_format``, ``output_format``, ``test_case``
    (a list whose first element has ``input`` and ``output``), ``code`` (a
    list whose first element has ``code``) and ``step`` (a list of strings).

    ``task_type`` selects how the problem text, the steps and the reference
    code are split between the prompt and the answer:

    * ``"no-steps"``: prompt is the problem text; answer is the code.
    * ``"ground-truth-guided"``: prompt is the problem text followed by the
      steps; answer is the code.
    * ``"self-guided"``: prompt is the problem text; answer is the steps
      followed by the code.
    * ``"steps-genetate"``: prompt is the problem text; answer is the steps.
    """

    # The spelling "steps-genetate" (sic) is the value callers pass, so it is
    # kept exactly as is.
    _TASK_TYPES = (
        "no-steps", "ground-truth-guided", "self-guided", "steps-genetate")
    # Text appended to every answer before tokenization.
    _EOS_TEXT = "<|endoftext|>"

    def __init__(self, data_path: str, max_tokens: int, task_type: str,
                 tokenizer: "PreTrainedTokenizerBase"):
        """Validate ``task_type`` and eagerly load every sample.

        Args:
            data_path: Path of the UTF-8 JSON file to load.
            max_tokens: Upper bound on ``len(prompt) + len(answer)`` in tokens.
            task_type: One of the four supported task names (see class doc).
            tokenizer: Object whose ``encode(text, add_special_tokens=False,
                verbose=False)`` returns a list of token ids.

        Raises:
            AssertionError: If ``task_type`` is not a supported task name.
        """
        # Raised explicitly instead of via ``assert`` so the check still runs
        # under ``python -O``; the exception type is kept for compatibility.
        if task_type not in self._TASK_TYPES:
            raise AssertionError(
                f"unknown task_type {task_type!r}; "
                f"expected one of {self._TASK_TYPES}")
        self.data_path = data_path
        self.max_tokens = max_tokens
        self.tokenizer = tokenizer
        self.task_type = task_type
        self.samples: List[Dict] = self._initialize()

    def _initialize(self) -> List[Dict]:
        """Read the JSON file and flatten each record into a sample dict.

        Only the first test case and the first code entry of every record are
        used; the ``step`` list is joined with newlines into one string.
        """
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data: List[Dict] = json.load(f)
        print(f"Loading data {self.data_path} ...")
        all_samples: List[Dict] = []
        for item in data:
            first_case = item["test_case"][0]
            all_samples.append({
                "pid": item["pid"],
                "nl": item["nl"],
                "input_format": item["input_format"],
                "input": first_case["input"],
                "output_format": item["output_format"],
                "output": first_case["output"],
                "ans": item["code"][0]["code"],
                "steps": "\n".join(item["step"]),
            })
        print("Finish loading data!")
        return all_samples

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.samples)

    def _build_text(self, sample: Dict) -> Tuple[str, str]:
        """Return the ``(prompt, answer)`` strings for the current task type."""
        # Problem statement shared by every task type.
        base = (
            f'Problem description:\n{sample["nl"]}\n'
            f'Input format:\n{sample["input_format"]}\n'
            f'Output format:\n{sample["output_format"]}\n'
            'Examples:\n'
            f'Input>>\n{sample["input"]}\n'
            f'Output>>\n{sample["output"]}\n'
            'Answer:\n'
        )
        task = self.task_type
        if task == "no-steps":
            return base, sample["ans"]
        if task == "ground-truth-guided":
            # The steps are given in the prompt; only the code is the target.
            return (f'{base}{sample["steps"]}\nBelow is the code:\n',
                    sample["ans"])
        if task == "self-guided":
            # Steps and code are both part of the target.
            return base, f"{sample['steps']}\nBelow is the code:\n{sample['ans']}"
        if task == "steps-genetate":
            return base, sample["steps"]
        # Reachable only if ``task_type`` was changed after construction.
        raise AssertionError(f"unknown task_type {task!r}")

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and clip it to ``max_tokens`` tokens.

        The prompt has priority: it is cut to ``max_tokens`` first, and the
        answer only receives whatever budget is left (possibly none).

        Returns:
            ``{"input_ids": (prompt_ids, answer_ids)}`` where both elements
            are 1-D ``torch.LongTensor`` objects. The original author's note
            says the tuple is consumed by the data collator.
        """
        prompt, ans = self._build_text(self.samples[idx])
        ans = ans + self._EOS_TEXT

        prompt_token_ids = self.tokenizer.encode(
            prompt, add_special_tokens=False, verbose=False)
        ans_token_ids = self.tokenizer.encode(
            ans, add_special_tokens=False, verbose=False)

        # Clip the prompt, then give the answer the remaining budget. When the
        # prompt alone fills the budget the answer becomes empty.
        prompt_token_ids = prompt_token_ids[:self.max_tokens]
        remaining = max(self.max_tokens - len(prompt_token_ids), 0)
        ans_token_ids = ans_token_ids[:remaining]

        input_ids = (torch.LongTensor(prompt_token_ids),
                     torch.LongTensor(ans_token_ids))
        return {"input_ids": input_ids}
