import torch

from time import time



class MultitaskRerankingDatasetTrainer:
    """Indexable dataset producing tokenized inputs for summary reranking.

    Each item combines one source text with its candidate summaries and the
    scores of those candidates, and returns a dict of tensors (see
    ``__getitem__``).

    Args:
        mode: Split name; ``"train"`` yields a ``mode`` tensor of ``[1]``,
            anything else ``[0]``.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors, and exposing ``decode``
            (presumably a Hugging Face tokenizer -- confirm with caller).
        texts: Sequence of source documents.
        scored_summaries: Sequence aligned with ``texts``; element ``[0]`` is
            the list of candidate summaries and element ``[1]`` a list of
            score rows (one row per metric, one value per candidate).
        labels: Stored on the instance; not read by this class.
        args: Namespace providing ``max_length``, ``max_summary_length``,
            ``encode_begin_end``, ``pack_text_summaries`` and ``sep_symbol``.
    """

    # Multipliers applied to score rows that look like raw BERTScore /
    # BARTScore values (see ``_rescale_scores``).
    _BERTSCORE_SCALE = 100
    _BARTSCORE_SCALE = 30

    def __init__(self, mode, tokenizer, texts, scored_summaries, labels, args):
        self.mode = mode
        self.tokenizer = tokenizer
        self.texts = texts
        self.scored_summaries = scored_summaries
        self.labels = labels
        self.args = args

    def __len__(self):
        """Return the number of source texts."""
        return len(self.texts)

    def _rescale_scores(self, summary_scores):
        """Return score rows with BERTScore/BARTScore-like rows rescaled.

        A row strictly inside (0, 1) is multiplied by 100; a row strictly
        inside (-10, 0) is multiplied by 30; other rows are returned as-is.
        New lists are built for rescaled rows so the stored dataset is never
        modified -- rescaling in place would compound on every access for
        rows that remain inside the detection window after scaling.
        """
        rescaled = []
        for row in summary_scores:
            lowest, highest = min(row), max(row)
            if lowest > 0.0 and highest < 1.0:
                row = [value * self._BERTSCORE_SCALE for value in row]
            elif lowest > -10.0 and highest < 0.0:
                row = [value * self._BARTSCORE_SCALE for value in row]
            rescaled.append(row)
        return rescaled

    def _encode_fixed_length(self, texts, length, truncate):
        """Tokenize ``texts`` padded to ``length`` and clip to ``length`` columns.

        ``truncation=True`` is only forwarded when ``truncate`` is set, so the
        tokenizer is invoked with the same options as before; the explicit
        slice guarantees the width either way.
        """
        options = {"return_tensors": "pt", "max_length": length, "padding": "max_length"}
        if truncate:
            options["truncation"] = True
        encoded = self.tokenizer(texts, **options)
        encoded["input_ids"] = encoded["input_ids"][:, :length]
        encoded["attention_mask"] = encoded["attention_mask"][:, :length]
        return encoded

    def _head_text_ids(self, text, full_text_ids):
        """Return the ``max_length``-wide token ids used as the text "head".

        With ``args.encode_begin_end`` and a text of at least ``max_length``
        tokens, the first 67% of the budget is taken from the beginning of
        the text and the remainder from its end. Otherwise the text is padded
        and clipped to ``max_length``.
        """
        max_length = self.args.max_length
        if self.args.encode_begin_end and full_text_ids.shape[1] >= max_length:
            n_begin = int(0.67 * max_length)
            n_end = max_length - n_begin
            return torch.cat((full_text_ids[:, :n_begin], full_text_ids[:, -n_end:]), 1)
        return self._encode_fixed_length(text, max_length, truncate=False)["input_ids"]

    def __getitem__(self, item):
        """Build the tensor dict for the ``item``-th text.

        Returns:
            dict with keys ``mode``, ``text_input_ids``/``text_attn_mask``,
            ``cand_input_ids``/``cand_attn_mask``,
            ``text_and_summaries_input_ids``/``..._attn_mask``,
            ``text_tail_and_summaries_input_ids``/``..._attn_mask``,
            ``scores`` (one row per score row of the item) and ``labels``
            (the maximum of each score row).
        """
        text = self.texts[item]
        scored = self.scored_summaries[item]
        summary_candidates = scored[0]
        summary_scores = self._rescale_scores(scored[1])

        max_length = self.args.max_length
        max_summary_length = self.args.max_summary_length
        joint_length = max_length + max_summary_length
        separator = " " + self.args.sep_symbol + " "
        decode = self.tokenizer.decode

        # Unpadded tokenization of the full text, shared by the head, tail
        # and packed variants below.
        full_text_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]

        summary_candidates_inputs = self._encode_fixed_length(
            summary_candidates, max_summary_length, truncate=True
        )
        # Candidates as they read after truncation to max_summary_length.
        clipped_candidates = [
            decode(candidate_ids, skip_special_tokens=True)
            for candidate_ids in summary_candidates_inputs["input_ids"]
        ]

        if self.args.pack_text_summaries:
            # Give each candidate's unused token budget to the text.
            text_and_summaries = []
            for candidate in summary_candidates:
                candidate_ids = self.tokenizer(candidate, return_tensors="pt")["input_ids"]
                # The 3 reserved slots presumably cover special/separator
                # tokens -- TODO confirm. Clamp at 0: a negative bound would
                # slice from the end and keep most of the text instead of none.
                text_budget = max(0, joint_length - 3 - len(candidate_ids[0]))
                text_prefix = decode(full_text_ids[0][:text_budget], skip_special_tokens=True)
                text_and_summaries.append(text_prefix + separator + candidate)
        else:
            head_ids = self._head_text_ids(text, full_text_ids)
            text_head = decode(head_ids[0], skip_special_tokens=True)
            text_and_summaries = [text_head + separator + candidate for candidate in clipped_candidates]
        text_and_summaries_inputs = self._encode_fixed_length(
            text_and_summaries, joint_length, truncate=True
        )

        # Same pairing, but using the last max_length tokens of the text.
        text_tail = decode(full_text_ids[0][-max_length:], skip_special_tokens=True)
        text_tail_and_summaries = [text_tail + separator + candidate for candidate in clipped_candidates]
        text_tail_and_summaries_inputs = self._encode_fixed_length(
            text_tail_and_summaries, joint_length, truncate=True
        )

        # Text alone, at the same width as the text+summary inputs.
        text_inputs = self._encode_fixed_length(text, joint_length, truncate=False)

        scores = torch.cat([torch.tensor(row).unsqueeze(0) for row in summary_scores], 0)
        labels = torch.max(scores, dim=1)[0]
        mode = torch.tensor([1 if self.mode == "train" else 0])

        return {
            "mode": mode,
            "text_input_ids": text_inputs["input_ids"],
            "text_attn_mask": text_inputs["attention_mask"],
            "cand_input_ids": summary_candidates_inputs["input_ids"],
            "cand_attn_mask": summary_candidates_inputs["attention_mask"],
            "text_and_summaries_input_ids": text_and_summaries_inputs["input_ids"],
            "text_and_summaries_attn_mask": text_and_summaries_inputs["attention_mask"],
            "text_tail_and_summaries_input_ids": text_tail_and_summaries_inputs["input_ids"],
            "text_tail_and_summaries_attn_mask": text_tail_and_summaries_inputs["attention_mask"],
            "scores": scores,
            "labels": labels,
        }

