import pandas as pd
import random
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader
import math

class Base_Batchfier(IterableDataset):
    """Shared length-filtering and bucketing helpers for the batchfiers below.

    Subclasses read DataFrames containing a length column (``criteria``,
    ``'lens'`` by default) and use these helpers to drop short rows, clip long
    texts, sort by length and shuffle at batch-bucket granularity.
    """

    def __init__(self, batch_size: int = 32, seq_len=512, minlen=50, maxlen: int = 4096,
                 criteria: str = 'lens',
                 padding_index=70000, epoch_shuffle=False, device='cuda'):
        """Store the batching configuration.

        Args:
            batch_size: rows per batch; also the bucket size used by ``shuffle``.
            seq_len: chunk length used when splitting long texts.
            minlen: rows whose length value is below this are dropped by
                ``truncate_small``.
            maxlen: texts longer than this are cut by ``truncate_large``.
            criteria: column name used for sorting.
            padding_index: value used to pad sequences.
            epoch_shuffle: whether subclasses should sort + bucket-shuffle.
            device: torch device the produced tensors are moved to.
        """
        super().__init__()
        self.maxlen = maxlen
        self.minlen = minlen
        self.size = batch_size
        self.criteria = criteria
        self.seq_len = seq_len
        self.padding_index = padding_index
        self.epoch_shuffle = epoch_shuffle
        self.device = device

    def truncate_small(self, df, criteria='lens'):
        """Return ``df`` without the rows whose ``criteria`` value is below ``self.minlen``.

        A positional boolean mask is used so the result does not depend on the
        frame's index labels. Rows whose value compares False (e.g. NaN) are kept.
        """
        too_short = (df[criteria] < self.minlen).to_numpy()
        return df.loc[~too_short]

    def truncate_large(self, texts, lens):
        """Clip over-long texts and trim short ragged tails.

        Texts longer than ``self.maxlen`` are cut to ``self.maxlen`` and their
        length set to ``self.maxlen``. Other texts whose length leaves a
        remainder of 1-9 tokens past the last full ``seq_len`` chunk have that
        tail removed, and the same amount is subtracted from their length.

        Returns:
            ``(new_texts, new_lens)`` lists, parallel to the inputs.
        """
        new_texts = []
        new_lens = []
        for text, length in zip(texts, lens):
            if len(text) > self.maxlen:
                new_texts.append(text[:self.maxlen])
                new_lens.append(self.maxlen)
                continue
            remainder = len(text) % self.seq_len
            # Tiny tails would otherwise become an almost-empty extra chunk.
            if remainder and remainder < 10:
                text = text[:-remainder]
                length = length - remainder
            new_texts.append(text)
            new_lens.append(length)
        return new_texts, new_lens

    def shuffle(self, df, num_buckets):
        """Shuffle ``df`` at bucket granularity, keeping each bucket's rows together.

        Bucket ``b`` is the positional slice ``[b * size, (b + 1) * size)``.
        With ``num_buckets == 0`` an empty slice of ``df`` is returned rather
        than letting ``pd.concat`` fail on an empty list.
        """
        buckets = [df.iloc[b * self.size:(b + 1) * self.size] for b in range(num_buckets)]
        if not buckets:
            return df.iloc[0:0]
        random.shuffle(buckets)
        return pd.concat(buckets)

    def sort(self, df, criteria):
        """Return ``df`` sorted by ``criteria`` with a fresh 0..n-1 index."""
        return df.sort_values(criteria).reset_index(drop=True)



class Lyrics_Batchfier(Base_Batchfier):
    """Stream per-sample ``seq_len`` chunks from a list of pickled DataFrames.

    Each file is loaded with ``pd.read_pickle`` and must provide ``texts`` and
    ``lens`` columns (presumably token-id sequences and their lengths —
    TODO confirm against the data pipeline). ``collate`` turns a list of the
    yielded ``(chunk, chunk_len)`` items into padded tensors.
    """

    def __init__(self, filelist: list, batch_size: int = 32, seq_len=512, minlen=50, maxlen: int = 4096,
                 criteria: str = 'lens',
                 padding_index=70000, epoch_shuffle=False, device='cuda'):
        super(Lyrics_Batchfier, self).__init__(batch_size, seq_len, minlen, maxlen, criteria,
                                               padding_index, epoch_shuffle, device)
        self.filelist = filelist

    def iterator(self):
        """Yield ``(chunk, chunk_len)`` pairs for every file in ``self.filelist``.

        Rows are taken ``self.size`` at a time. Each group is split into
        ``seq_len`` chunks and emitted chunk-major: chunk 0 of every sample,
        then chunk 1, and so on. Samples that run out early yield an empty
        chunk with length 0, so every chunk step has one item per sample.
        """
        for filename in self.filelist:
            cur_df = pd.read_pickle(filename)
            if self.epoch_shuffle:
                # Drop too-short rows, sort by length, then shuffle whole
                # batch-sized buckets so each batch holds similar lengths.
                cur_df = self.truncate_small(cur_df)
                num_buckets = len(cur_df) // self.size + (len(cur_df) % self.size != 0)
                cur_df = self.sort(cur_df, self.criteria)
                cur_df = self.shuffle(cur_df, num_buckets)
            cur_pos = 0
            while cur_pos < len(cur_df):
                cur_batch = cur_df.iloc[cur_pos:cur_pos + self.size]
                cur_pos += self.size
                texts = cur_batch['texts'].to_list()
                lens = cur_batch['lens'].to_list()
                texts, lens = self.truncate_large(texts, lens)
                longest = max(lens)
                n_chunk = longest // self.seq_len + (longest % self.seq_len != 0)
                for chunk in range(n_chunk):
                    start = chunk * self.seq_len
                    for i, full_text in enumerate(texts):
                        text = full_text[start:start + self.seq_len]
                        # Remaining length clipped to [0, seq_len] for this chunk.
                        text_len = max(min(lens[i], self.seq_len), 0)
                        lens[i] -= self.seq_len
                        yield text, text_len

    def __iter__(self):
        return self.iterator()

    def __len__(self):
        # TODO: hard-coded placeholder; the real count depends on the files.
        return 30000

    def collate(self, batch):
        """Pad a list of ``(chunk, chunk_len)`` items into tensors on ``self.device``.

        Token ids are converted straight to int64. Going through the float32
        ``torch.Tensor`` constructor would lose exactness for ids >= 2**24.

        Returns:
            ``(inputs, lens, targets)`` where ``inputs`` is the right-padded
            ``(batch, max_len)`` id matrix (padded with ``self.padding_index``)
            and ``targets`` is ``inputs[:, 1:]``.
        """
        texts = [torch.tensor(item[0], dtype=torch.long) for item in batch]
        lens = torch.tensor([item[1] for item in batch], dtype=torch.long)
        texts = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=self.padding_index)
        texts = texts.to(self.device)
        return texts, lens.to(self.device), texts[:, 1:]



class LyricsSampleBatchfier:
    """Yield prompt batches for sampling, read from ``filelist[0]`` only.

    Only texts strictly longer than ``prefix_len + token_len`` are used, each
    cut to exactly that many tokens. Each yielded triple is
    ``(x, prefix_lens, x[:, 1:])`` with one prefix length per row of ``x``.
    """

    def __init__(self, filelist: list, batch_size: int = 32,
                 n_samples=10000, prefix_len=5, token_len=100, device='cuda'):
        self.filelist = filelist
        self.batch_size = batch_size
        self.device = device
        self.n_samples = n_samples
        self.prefix_len = prefix_len
        self.token_len = token_len

    def __iter__(self):
        cur_df = pd.read_pickle(self.filelist[0])
        n_rows = len(cur_df)
        textlen = self.token_len + self.prefix_len
        tot_generated = 0
        cur = 0
        while tot_generated < self.n_samples:
            tot_generated += self.batch_size
            batch_text = []
            short = self.batch_size
            # The first read pulls a whole batch of rows; later reads only
            # refill the rows rejected by the length filter.
            while True:
                rows = cur_df.iloc[cur:cur + short]
                cur += short
                batch_text += [text[:textlen] for text in rows['texts'].tolist() if len(text) > textlen]
                short = self.batch_size - len(batch_text)
                if short == 0:
                    break
                if cur >= n_rows:
                    # Data exhausted: emit the partial batch, then stop.
                    tot_generated = self.n_samples
                    break
            if not batch_text:
                # Nothing left; an empty 1-D tensor would also break x[:, 1:].
                break
            x = torch.LongTensor(batch_text).to(self.device)
            prefix_lens = torch.LongTensor([self.prefix_len] * len(batch_text)).to(self.device)
            yield x, prefix_lens, x[:, 1:]




class BpttIterator:
    """Iterate over one long token stream in truncated-BPTT windows.

    The stream is trimmed to a multiple of ``batch_size`` and reshaped to
    ``(batch_size, -1)``, so row ``i`` is the ``i``-th contiguous segment of
    the stream. Each step yields ``(x, lens, x[:, 1:])`` for the next
    ``bptt_len`` columns (the last window may be narrower).
    """

    def __init__(self, dataset, batch_size, bptt_len, device='cuda', **kwargs):
        self.bptt_len = bptt_len
        self.dataset = dataset
        self.size = batch_size
        self.device = device
        self.iterations = 0  # total windows yielded across all passes
        self.data = self.prepair_dataset(dataset)

    def prepair_dataset(self, text):
        """Drop the trailing ``len(text) % size`` tokens and reshape to ``(size, -1)``."""
        remainder = len(text) % self.size
        if remainder:
            text = text[:-remainder]
        return np.array(text).reshape((self.size, -1))

    def __len__(self):
        # Exactly one batch per started window of ``bptt_len`` columns, which
        # matches the non-overlapping loop in ``__iter__``.
        return math.ceil(self.data.shape[1] / self.bptt_len)

    def __iter__(self):
        cur = 0
        data = self.data
        while cur < self.data.shape[1]:
            self.iterations += 1
            batch_text = data[:, cur:cur + self.bptt_len]
            x = torch.from_numpy(batch_text).to(self.device)
            cur += self.bptt_len
            yield x, torch.LongTensor([batch_text.shape[1]] * self.size).to(self.device), x[:, 1:]


class SamplingIterator:
    """Walk a flat token stream in windows of ``prefix_len + generate_len`` columns.

    The stream is cut down to a multiple of ``batch_size`` and laid out as a
    ``(batch_size, -1)`` array. Every step yields ``(x, lens, x[:, 1:])``,
    where ``lens`` repeats the window width once per row.
    """

    def __init__(self, dataset, batch_size, prefix_len, generate_len, device='cuda', **kwargs):
        self.prefix_len, self.generate_len = prefix_len, generate_len
        self.pg_len = prefix_len + generate_len
        self.dataset = dataset
        self.size = batch_size
        self.device = device
        self.iterations = 0  # windows yielded so far, across all passes
        self.data = self.prepair_dataset(dataset)

    def prepair_dataset(self, text):
        """Discard the ragged tail and reshape the stream into ``size`` rows."""
        extra = len(text) % self.size
        trimmed = text[:len(text) - extra] if extra else text
        return np.array(trimmed).reshape((self.size, -1))

    def __iter__(self):
        source = self.data
        offset = 0
        while offset < self.data.shape[1]:
            self.iterations += 1
            window = source[:, offset:offset + self.pg_len]
            offset += self.pg_len
            tokens = torch.from_numpy(window).to(self.device)
            widths = torch.LongTensor([window.shape[1]] * self.size).to(self.device)
            yield tokens, widths, tokens[:, 1:]

