import os
import copy
import torch
from torch.nn.utils.rnn import pad_sequence
import random
import numpy as np
from collections import deque
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, T5Model
import logging

logger = logging.getLogger(__name__)

class PersonaChatPreprocessor:
    """Turn batched PersonaChat rows into one tokenized record per conversation."""

    def __init__(self, tokenizer):
        # Any object exposing ``encode(text, add_special_tokens=False) -> list[int]``.
        self.tokenizer = tokenizer

    @staticmethod
    def _last_row_indices(conv_ids):
        """Return the index of the last row of every run of equal ``conv_id`` values.

        Rows of one conversation are expected to be contiguous; the last row of
        each run is the one whose ``history`` is used as the dialog (presumably
        the most complete one). The final row of the batch is always included,
        because nothing later in this batch can show whether its conversation
        continues. A conversation split across two ``map`` batches therefore
        shows up once per batch.
        """
        last = len(conv_ids) - 1
        return [i for i in range(len(conv_ids)) if i == last or conv_ids[i + 1] != conv_ids[i]]

    def preprocess_function(self, examples):
        """Build one record per conversation found in a batch of rows.

        Only works with ``batched=True``: ``examples`` maps column names to
        per-row lists, and ``conv_id``, ``history`` and ``personality`` are read.

        Returns:
            ``{"data": records}`` where each record holds the batch-local row
            index (``id``), the raw and tokenized persona, and the raw and
            tokenized dialog turns.
        """
        results = []
        for idx in self._last_row_indices(examples["conv_id"]):
            dialogs = examples["history"][idx]
            persona = examples["personality"][idx]
            results.append({
                "id": idx,
                "persona": persona,
                "tokenized_persona": self.tokenize_persona(persona),
                "dialogs": dialogs,
                "tokenized_dialogs": self.tokenize_dialogs(dialogs),
            })
        return {"data": results}

    def tokenize_dialogs(self, dialogs):
        """Tokenize each turn, prefixing even turns with "A: " and odd turns with "B: "."""
        outputs = []
        for i, dial in enumerate(dialogs):
            speaker = "A" if i % 2 == 0 else "B"
            outputs.append(self.tokenizer.encode(speaker + ": " + dial + "\n\n", add_special_tokens=False))
        return outputs

    def tokenize_persona(self, persona):
        """Tokenize the persona sentences as a single "persona: ..." segment."""
        text = "persona: " + " ".join(persona) + "\n\n"
        return self.tokenizer.encode(text, add_special_tokens=False)

class SeriesQueueAgent:
    """Stream ``(reset, sample)`` pairs from conversations popped off a shared queue.

    The agent works through one conversation at a time. For every turn whose
    first token is not the "A" token it yields the dialog history so far as
    encoder input and that turn as the label.
    """

    def __init__(self, tokenizer, max_seq_length = 512):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        # Generator over the conversation currently being consumed; None when idle.
        self.iterator = None
        # First token id of "A"; turns starting with it are context only, never targets.
        self.role_A = tokenizer.encode("A", add_special_tokens=False)[0]

    def get(self, queue):
        """Return the next ``(if_empty, (reset, result))`` from this agent.

        Pops a new conversation from ``queue`` whenever the current one is
        exhausted. Conversations that yield nothing are skipped iteratively, so
        long runs of them cannot exhaust the recursion limit. When the queue
        is empty and no conversation is in progress, returns ``(True, (True, None))``.
        """
        while True:
            if self.iterator is None:
                if not queue:
                    return True, (True, None)
                self.iterator = iter(self.iter_example(queue.popleft()))
            output = next(self.iterator, None)
            if output is not None:
                return output
            self.iterator = None

    def iter_example(self, example):
        """Yield ``(False, (reset, result))`` for every non-"A" turn of ``example``.

        ``example`` provides ``tokenized_persona`` (token ids) and
        ``tokenized_dialogs`` (one token-id list per turn). ``reset`` is True
        only for the first sample yielded from the conversation.
        ``result["history"]`` is ``[bos] + context + [eos]``, where context is
        the persona plus all earlier turns, left-truncated so the whole
        sequence is at most ``max_seq_length`` tokens. The two special tokens
        are always present.
        """
        # Copy so the dataset's persona list is not mutated by ``+=`` below.
        history = copy.copy(example["tokenized_persona"])
        # Context budget once bos and eos have been added.
        budget = max(self.max_seq_length - 2, 0)
        bos = self.tokenizer.bos_token_id
        eos = self.tokenizer.eos_token_id
        reset = True

        for item in example["tokenized_dialogs"]:
            if item[0] != self.role_A:
                context = history[-budget:] if budget > 0 else []
                result = {"history": [bos] + context + [eos],
                          "labels": [bos] + item}
                yield False, (reset, result)
                reset = False

            history += item

class SeriesQueueDataLoader:
    """Batch conversations by running ``batch_size`` agents over a shared queue.

    Batch slot ``i`` is always filled by agent ``i``. That agent streams one
    conversation at a time, so consecutive batches continue the same
    conversation in each slot. Slots whose agent has nothing left are yielded
    as ``(True, (True, None))`` entries.
    """

    def __init__(self, dataset, batch_size, tokenizer, max_seq_length, shuffle=False, collate_fn=None, enable_length=True, output_if_empty=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.collate_fn = collate_fn
        self.enable_length = enable_length
        # Stored for callers; not read anywhere in this class.
        self.output_if_empty = output_if_empty
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        self._saved_length = None

        # Kept for backward compatibility; each pass builds its own agents.
        self.agents = self._make_agents()

    def _make_agents(self):
        """Create one fresh agent per batch slot."""
        return [SeriesQueueAgent(self.tokenizer, self.max_seq_length) for _ in range(self.batch_size)]

    def __iter__(self):
        """Yield batches until every agent is idle and the queue is empty.

        Each pass uses fresh agents. An epoch cut short (by the cached length
        or by the caller), or a ``len()`` call made mid-iteration, therefore
        cannot leak half-consumed conversations into another pass. When a
        length has been cached, at most that many batches are yielded.
        """
        indices = np.arange(len(self.dataset))
        if self.shuffle:
            np.random.shuffle(indices)
        queue = deque(self.dataset[idx] for idx in indices)
        agents = self._make_agents()

        count = 0
        while not (isinstance(self._saved_length, int) and count >= self._saved_length):
            if_empty, batch = zip(*[agent.get(queue) for agent in agents])
            if all(if_empty):
                break

            if self.collate_fn is not None:
                yield self.collate_fn((if_empty, batch))
            else:
                yield if_empty, batch

            count += 1

    def __len__(self):
        """Return the number of batches per pass.

        Raises NotImplementedError when ``enable_length`` is False. Without
        shuffling, every call runs a full pass and counts its batches. With
        shuffling, the count is presumably order-dependent, so a conservative
        value is computed once from the first pass and cached; ``__iter__``
        then stops after that many batches.
        """
        if not self.enable_length:
            raise NotImplementedError

        if self.shuffle is False:
            return sum(1 for _ in self)

        if self._saved_length is None:
            count = sum(1 for _ in self)
            # Safety margin to avoid over-batching. NOTE(review): this is negative
            # when a pass has fewer than 2 * batch_size batches, and len() then
            # raises ValueError.
            self._saved_length = (count // self.batch_size - 2) * self.batch_size
        return self._saved_length

class DataLoaderHelper:
    """Build PersonaChat train/validation loaders from a config object."""

    def __init__(self, config) -> None:
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.task.tokenizer)
        processor = PersonaChatPreprocessor(self.tokenizer)
        raw_dataset = load_dataset("bavard/personachat_truecased")
        self.tokenized_dataset = raw_dataset.map(processor.preprocess_function, batched=True, remove_columns=raw_dataset["train"].column_names)

    def eval_collate_fn(self, batch):
        """Collate ``(if_empty, ((reset, sample), ...))`` into padded tensors.

        ``sample`` is a dict with ``history`` and ``labels`` token-id lists, or
        None for an idle slot, which is treated as empty sequences. The
        attention masks and the -100 label padding come from the true sequence
        lengths. Comparing against ``pad_token_id`` is avoided so real tokens
        that share the pad id are neither masked nor dropped from the loss.

        Returns a dict with:
            reset: bool per row.
            encoder_input_ids, encoder_attention_mask: padded history and its mask.
            decoder_input_ids, decoder_attention_mask: padded labels without the last position, and its mask.
            target: labels shifted left by one, with padding set to -100.
        """
        _if_empty, pairs = batch
        resets, samples = zip(*pairs)
        pad_id = self.tokenizer.pad_token_id
        samples = [s if s is not None else {"history": [], "labels": []} for s in samples]

        histories = [torch.tensor(s["history"], dtype=torch.long) for s in samples]
        label_seqs = [torch.tensor(s["labels"], dtype=torch.long) for s in samples]
        history_lens = torch.tensor([len(h) for h in histories], dtype=torch.long)
        label_lens = torch.tensor([len(l) for l in label_seqs], dtype=torch.long)

        input_ids = pad_sequence(histories, batch_first=True, padding_value=pad_id)
        attention_mask = torch.arange(input_ids.size(1)).unsqueeze(0) < history_lens.unsqueeze(1)

        labels = pad_sequence(label_seqs, batch_first=True, padding_value=pad_id)
        decoder_input_ids = labels[:, :-1].clone()
        target = labels[:, 1:].clone()
        positions = torch.arange(decoder_input_ids.size(1)).unsqueeze(0)
        # Column j of decoder_input_ids is label j, so it is real while j < len.
        decoder_attention_mask = positions < label_lens.unsqueeze(1)
        # Column j of target is label j + 1, so it is padding once j >= len - 1.
        target[positions >= (label_lens - 1).unsqueeze(1)] = -100

        return {
            "reset": torch.tensor(resets, dtype=torch.bool),
            "encoder_input_ids": input_ids,
            "encoder_attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "target": target,
        }

    def train_collate_fn(self, batch):
        """Same as ``eval_collate_fn`` but wrapped in a one-element list."""
        return [self.eval_collate_fn(batch)]

    def train_loader_fn(self):
        """Build the shuffled training loader over a ``train_portion`` subset."""
        dataset = self.tokenized_dataset['train']['data']
        # Reproducible subset. A private RandomState(42) gives the same permutation
        # as seeding the global RNG with 42, without resetting the global NumPy RNG
        # that the loader's per-epoch shuffling draws from.
        indices = np.arange(len(dataset))
        np.random.RandomState(42).shuffle(indices)
        split = int(len(indices) * self.config.task.train_portion)
        dataset = [dataset[idx] for idx in indices[:split]]

        dataloader = SeriesQueueDataLoader(dataset, self.config.training.batch_size, self.tokenizer, self.config.task.max_seq_length,
                                            collate_fn=self.train_collate_fn, shuffle=True)
        print("Loading dataloader length: ", len(dataloader))
        return dataloader

    def valid_loader_fn(self):
        """Build the unshuffled validation loader over the full validation split."""
        dataset = self.tokenized_dataset['validation']['data']
        dataloader = SeriesQueueDataLoader(dataset, self.config.training.evaluation.batch_size, self.tokenizer, self.config.task.max_seq_length,
                                            collate_fn=self.eval_collate_fn, shuffle=False)
        print("Loading dataloader length: ", len(dataloader))
        return dataloader