from transformers import Trainer, TrainingArguments, TrainerCallback


class ContrastiveTrainer(Trainer):
    """``Trainer`` that sums the losses of a primary and a contrastive forward pass.

    Every batch is expected to carry two views of the data:

    * the primary view: ``input_ids``, ``attention_mask``, optional
      ``token_type_ids`` and ``labels``;
    * the contrastive view: the same model inputs prefixed with ``conts_``
      (``conts_input_ids``, ``conts_attention_mask``, optional
      ``conts_token_type_ids``) with its labels stored under ``contras_label``.

    The model is run once on each view and the two losses are added together.
    """

    # Prefix the batch uses for the contrastive view's model inputs.
    _CONTRAST_PREFIX = "conts_"

    @staticmethod
    def _model_inputs(inputs, prefix="", label_key=None):
        """Collect the keyword arguments for one forward pass from the batch.

        Args:
            inputs: The batch mapping handed to ``compute_loss``.
            prefix: Prefix prepended to the model-input keys ("" for the
                primary view, ``"conts_"`` for the contrastive view).
            label_key: Batch key whose value is forwarded as ``labels``; when
                ``None`` no labels are passed, so the model returns no loss.

        Returns:
            A dict with ``input_ids`` and ``attention_mask``, plus
            ``token_type_ids`` only when the batch contains it (not every
            tokenizer produces it and not every model accepts it), plus
            ``labels`` when ``label_key`` is given.

        Raises:
            KeyError: If a required input key or ``label_key`` is missing.
        """
        model_inputs = {
            "input_ids": inputs[prefix + "input_ids"],
            "attention_mask": inputs[prefix + "attention_mask"],
        }
        token_type_key = prefix + "token_type_ids"
        if token_type_key in inputs:
            model_inputs["token_type_ids"] = inputs[token_type_key]
        if label_key is not None:
            model_inputs["labels"] = inputs[label_key]
        return model_inputs

    @staticmethod
    def _extract_loss(outputs):
        """Return the loss a model computed from its own ``labels`` argument."""
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        return outputs["loss"] if isinstance(outputs, dict) else outputs[0]

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model on both views of the batch and return the summed loss.

        When a label smoother is configured and the batch has ``labels``, the
        labels are withheld from the model and both losses come from
        ``self.label_smoother``. Otherwise the labels are passed to the model,
        and each forward pass's own loss is used.

        Args:
            model: The model being trained.
            inputs: Batch mapping containing both views (see the class docstring).
            return_outputs: If true, also return the primary view's outputs.
            num_items_in_batch: Accepted because newer ``transformers`` releases
                pass it to ``compute_loss``; it is not used here.
                NOTE(review): losses are therefore not renormalised by item
                count — confirm this is acceptable with gradient accumulation.

        Returns:
            ``loss`` or ``(loss, outputs)``, where ``outputs`` are the primary
            view's model outputs (the contrastive outputs are not returned).

        Raises:
            KeyError: If a required input or label key is missing from ``inputs``.
        """
        use_label_smoother = self.label_smoother is not None and "labels" in inputs

        if use_label_smoother:
            # Read both label tensors up front so a missing key fails before any forward pass.
            labels = inputs["labels"]
            contras_labels = inputs["contras_label"]
            outputs = model(**self._model_inputs(inputs))
            outputs_2 = model(**self._model_inputs(inputs, prefix=self._CONTRAST_PREFIX))
        else:
            outputs = model(**self._model_inputs(inputs, label_key="labels"))
            outputs_2 = model(
                **self._model_inputs(
                    inputs, prefix=self._CONTRAST_PREFIX, label_key="contras_label"
                )
            )

        # Save past state if it exists (only the primary view's is kept).
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if use_label_smoother:
            loss = self.label_smoother(outputs, labels) + self.label_smoother(
                outputs_2, contras_labels
            )
        else:
            loss = self._extract_loss(outputs) + self._extract_loss(outputs_2)

        return (loss, outputs) if return_outputs else loss



