import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
from transformers import DataCollatorForSeq2Seq
from transformers.generation.configuration_utils import GenerationConfig

@dataclass
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
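    Data collator for pairwise (chosen / rejected) preference data.

    Besides what it inherits from ``DataCollatorForSeq2Seq``, the methods of
    this class read ``self.steps``, ``self.scheduler``, ``self.sample_tokenizer``
    and ``self.accelerator``. None of them is declared as a field in this
    class, so they are presumably assigned by the caller after construction.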
    """

    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
        padded_labels = []
        for feature, (prompt_len, answer_len) in zip(batch, positions):
            if self.tokenizer.padding_side == "left":
                start, end = feature.size(0) - answer_len, feature.size(0)
            else:
                start, end = prompt_len, prompt_len + answer_len
            padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
            padded_tensor[start:end] = feature[start:end]
            padded_labels.append(padded_tensor)
        return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory

    def online_sampling(self, features, feature_keys, *, print_answers=True):
        """
        Build prompt+answer examples, replacing "chosen" answers with fresh samples.

        The prompts are padded with ``self.sample_tokenizer`` and handed to
        ``self.sampling``. For every key in ``feature_keys`` whose name contains
        ``"chosen_id"`` the answer is taken from the generated continuation;
        every other key keeps the answer stored in the feature. When a
        continuation holds at most one non-pad token, the answer stored under
        that key is used instead.

        Args:
            features: Examples; each maps ``"prompt_ids"`` and every key in
                ``feature_keys`` to a list of token ids. ``"rejected_ids"`` is
                additionally read when ``print_answers`` is true.
            feature_keys: Answer keys to expand, e.g.
                ``[chosen_ids_1, chosen_ids_2, ..., rejected_ids]``.
            print_answers: Print prompt / sample / chosen / rejected text for
                each sampled answer. Defaults to ``True``, which matches the
                previously hard-coded behaviour.

        Returns:
            ``(concatenated_features, label_positions)``. The first is a list of
            ``{"input_ids", "attention_mask"}`` dicts ordered key by key and,
            within a key, example by example; the second is the matching list
            of ``(prompt_len, answer_len)`` tuples.
        """
        print("sampling...")
        prompts = [
            {
                "input_ids": feature["prompt_ids"],
                "attention_mask": [1] * len(feature["prompt_ids"]),
            }
            for feature in features
        ]
        prompt_max_length = max((len(prompt["input_ids"]) for prompt in prompts), default=0)

        # NOTE(review): the original comment says "left padding"; the padding side
        # is whatever self.sample_tokenizer is configured with -- confirm.
        prompt_features = self.sample_tokenizer.pad(
            prompts,
            padding=True,
            return_tensors=self.return_tensors,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )

        sampled_ids_with_padding = self.sampling(
            prompt_features["input_ids"],
            prompt_features["attention_mask"],
            prompt_max_length,
            need_unload=False
        )

        assert sampled_ids_with_padding.size(0) == prompt_features["input_ids"].size(0)
        # Everything after the padded prompt width is treated as the continuation.
        prompt_len = prompt_features["input_ids"].size(1)
        pad_token_id = self.tokenizer.pad_token_id

        concatenated_features = []
        label_positions = []
        for key in feature_keys:
            use_sample = "chosen_id" in key
            for i, feature in enumerate(features):
                prompt_ids = feature["prompt_ids"]
                answer_ids = feature[key]

                if use_sample:
                    continuation = sampled_ids_with_padding[i, prompt_len:]
                    # Number of non-pad tokens; the continuation is cut to that length.
                    valid_sample_len = int((continuation != pad_token_id).sum())

                    if valid_sample_len <= 1:
                        # Fixed: this used to read the non-existent ``self.step``.
                        print(f"Sampling Empty Sentences: {self.steps}, using original label")
                        sampled_ids = feature[key]
                    else:
                        sampled_ids = continuation[:valid_sample_len].tolist()

                    if print_answers:
                        print(f"----------step: {self.steps}----------------------------")
                        print("prompt: ", self.tokenizer.decode(feature["prompt_ids"]))
                        print("sample: ", self.tokenizer.decode(sampled_ids))
                        print("chosen: ", self.tokenizer.decode(feature[key]))
                        print("rejected: ", self.tokenizer.decode(feature["rejected_ids"]))
                        print("")

                    answer_ids = sampled_ids

                raw_prompt_len, answer_len = len(prompt_ids), len(answer_ids)
                concatenated_features.append({
                    "input_ids": prompt_ids + answer_ids,
                    "attention_mask": [1] * (raw_prompt_len + answer_len)
                })
                label_positions.append((raw_prompt_len, answer_len))
        return concatenated_features, label_positions
    
    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        # sample chosen_ids (self.steps > ) and (self.steps % 320 ==0)
        feature_keys = [x for x in list(features[0].keys()) if x != 'prompt_ids'] 
        feature_keys = sorted(feature_keys)
        if (self.scheduler is not None) and self.scheduler.need_sampling(self.steps): # 8: accumulation steps 
            concatenated_features, label_positions = self.online_sampling(features, feature_keys) 
        else:
            concatenated_features = []
            label_positions = []
            for key in feature_keys: # [chosen_id_1, chosen_id_2, ..., rejected_ids] 不能有prompt_id!!!
                for i, feature in enumerate(features):
                    raw_prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
                    concatenated_features.append({
                        "input_ids": feature["prompt_ids"] + feature[key],
                        "attention_mask": [1] * (raw_prompt_len + answer_len)
                    })
                    label_positions.append((raw_prompt_len, answer_len))

        batch = self._pad_concatenated_features(concatenated_features)
        # we pad label after sampling 
        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
        self.steps += 1
        return batch

    def _pad_concatenated_features(self, concatenated_features) -> Dict[str, torch.Tensor]:
        return self.tokenizer.pad(
            concatenated_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
    
    def sampling(self, prompt_ids, attention_mask, prompt_max_length, need_unload, **kwargs):
        r"""
        Generate one sampled continuation per prompt with ``self.model``.

        Args:
            prompt_ids: Prompt token ids; moved to ``self.accelerator.device``.
            attention_mask: Mask matching ``prompt_ids``; moved to the same device.
            prompt_max_length: Subtracted from ``self.tokenizer.model_max_length``
                to obtain ``max_new_tokens``.
            need_unload: When true, generation runs inside the
                ``disable_adapter()`` context of the unwrapped model.
            **kwargs: Accepted but not used.

        Returns:
            The output of ``self.model.generate`` moved to the CPU.
        """
        pad_id = self.tokenizer.pad_token_id

        with torch.no_grad():
            # TODO: stage 3 needs synced_gpus
            target_device = self.accelerator.device
            prompt_ids = prompt_ids.to(target_device)
            attention_mask = attention_mask.to(target_device)

            # Stochastic decoding: temperature 1.0 with top-k / top-p filtering.
            generate_config = GenerationConfig(
                max_new_tokens=self.tokenizer.model_max_length - prompt_max_length,
                temperature=1.0,
                top_p=0.95,
                top_k=50,
                do_sample=True,
                output_scores=False,
                use_cache=False,
                return_dict_in_generate=False,
                num_return_sequences=1,
                pad_token_id=pad_id,
            )

            def run_generate():
                # pad_token_id is passed both here and through the config, as before.
                return self.model.generate(
                    input_ids=prompt_ids,
                    attention_mask=attention_mask,
                    pad_token_id=pad_id,
                    generation_config=generate_config,
                )

            if not need_unload:
                outputs = run_generate()
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    outputs = run_generate()

        # Hand the generated sequences back on the CPU.
        return outputs.cpu()