import os
import math
import torch
import numpy as np
import pickle as pkl
from src.utils import prompt_direct_inferring, prompt_direct_inferring_twice, prompt_direct_inferring_few_shot, prompt_direct_inferring_few_shot_agnews, prompt_direct_inferring_masked, prompt_for_aspect_inferring
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import random
from datasets import load_dataset, load_from_disk

class MyDataset(Dataset):
    """Map-style ``Dataset`` that serves items straight from an in-memory sequence."""

    def __init__(self, data):
        # Placeholder counter; this class never updates it.
        self.data_length = 0
        # Items are returned exactly as stored, without any transformation.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        return item


class MyDataLoader_agnews:
    """Build few-shot prompting DataLoaders for the AG News topic task.

    Attaches a ``Preprocessor_agnews`` to ``config.preprocessor``, loads a
    left-padding tokenizer, samples one demonstration example per class from
    the training split, and collates batches into tokenized prompts via
    ``collate_fn_few_shot``.
    """

    # AG News integer label -> answer word used in the few-shot demonstrations.
    ANSWER_LABELS = {0: 'world', 1: 'sports', 2: 'business', 3: 'sci/tech'}

    def __init__(self, config):
        self.config = config
        config.preprocessor = Preprocessor_agnews(config)

        self.tokenizer = AutoTokenizer.from_pretrained(config.model_path)
        is_llama = 'llama' in config.model_path
        if is_llama:
            # Reuse EOS as the pad token (presumably because the LLaMA
            # tokenizer defines none — verify for the checkpoint in use).
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Left padding, presumably so prompts end flush against the
        # positions a decoder-only model generates from.
        self.tokenizer.padding_side = "left"
        if is_llama:
            print('self.tokenizer.pad_token_id: ', self.tokenizer.pad_token_id)

        self.add_phrase = self.config.add_phrase
        # Replaced by freshly sampled indices in get_data_few_shot.
        self.few_shot_example_indices = self.config.few_shot_example_indices

    def worker_init(self, worked_id):
        """Seed numpy and ``random`` inside a DataLoader worker.

        The seed is derived from ``torch.initial_seed()``. ``worked_id`` is the
        worker index passed by DataLoader and is unused. With
        ``num_workers=0`` (as used below) DataLoader never calls this.
        """
        worker_seed = torch.initial_seed() % 2 ** 32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def _load_or_build_data(self):
        """Return the preprocessed splits, using a pickle cache when present.

        The cache file lives in ``config.preprocessed_dir``, named after
        ``data_name``, ``model_size`` and ``model_path`` (slashes -> dashes).
        The result is also stored on ``self.data``.
        """
        cfg = self.config
        cache_name = '{}_{}_{}.pkl'.format(
            cfg.data_name, cfg.model_size, cfg.model_path).replace('/', '-')
        path = os.path.join(cfg.preprocessed_dir, cache_name)
        if os.path.exists(path):
            # NOTE: pickle can execute arbitrary code on load; only point
            # preprocessed_dir at trusted, locally produced caches.
            with open(path, 'rb') as f:
                self.data = pkl.load(f)
        else:
            self.data = cfg.preprocessor.forward()
            with open(path, 'wb') as f:
                pkl.dump(self.data, f)
        return self.data

    def _select_few_shot_examples(self, train_data):
        """Pick one random training index per label in ``ANSWER_LABELS``.

        Labels are visited in ``ANSWER_LABELS`` order. A label absent from
        ``train_data`` contributes no index. ``train_data`` rows are
        ``[text, label]``.
        """
        random_indices = list(range(len(train_data)))
        random.shuffle(random_indices)
        selected = []
        for label in self.ANSWER_LABELS:
            for ind in random_indices:
                if train_data[ind][1] == label:
                    selected.append(ind)
                    break
        return selected

    def _print_few_shot_examples(self):
        """Log the demonstration examples stored in ``self.few_shot_examples``."""
        print('=' * 77)
        print('Examples used: ')
        print('-' * 77)
        pairs = zip(self.few_shot_examples['sent'], self.few_shot_examples['answer'])
        for i, (sent, answer) in enumerate(pairs):
            print(f'Example {i}')
            print("Sent: ", sent)
            print("Answer: ", answer)
            print('-' * 77)

    def get_data_few_shot(self):
        """Load AG News, sample demonstrations, and build train/test loaders.

        Returns:
            ``([train_loader, test_loader], config)``. Each loader has a
            ``data_length`` attribute equal to its number of batches.
        """
        train_data, test_data = self._load_or_build_data()[:2]

        # NOTE(review): Preprocessor_agnews.forward returns [train, test], so
        # this stores the *test split* as word_dict — confirm this is intended.
        self.config.word_dict = self.data[-1]

        self.few_shot_example_indices = self._select_few_shot_examples(train_data)
        print('=' * 77)
        print('self.few_show_example_indices: ', self.few_shot_example_indices)

        self.few_shot_examples = {
            'sent': [train_data[i][0] for i in self.few_shot_example_indices],
            'answer': [self.ANSWER_LABELS[train_data[i][1]]
                       for i in self.few_shot_example_indices],
        }
        self._print_few_shot_examples()

        batch_size = self.config.batch_size

        def make_loader(dataset):
            # config.shuffle applies to both the train and the test loader.
            return DataLoader(MyDataset(dataset), num_workers=0,
                              worker_init_fn=self.worker_init,
                              shuffle=self.config.shuffle, batch_size=batch_size,
                              collate_fn=self.collate_fn_few_shot)

        train_loader = make_loader(train_data)
        test_loader = make_loader(test_data)
        train_loader.data_length = math.ceil(len(train_data) / batch_size)
        test_loader.data_length = math.ceil(len(test_data) / batch_size)

        return [train_loader, test_loader], self.config

    def collate_fn_few_shot(self, data):
        """Collate ``(text, label)`` pairs into tokenized few-shot prompt tensors.

        Each text is cut to ``config.max_length - 25`` whitespace-separated
        words before being wrapped in the few-shot prompt. All returned tensors
        are moved to ``config.device``.

        Returns:
            dict with ``input_ids``, ``input_masks``, ``output_ids``,
            ``output_masks`` and ``input_labels`` tensors.

        Raises:
            ValueError: if ``config.reasoning`` is not ``'prompt_few_shot'``
                (previously the batch silently became ``None``).
            NotImplementedError: if ``config.zero_shot`` is falsy. The masked
                prompt path needs per-example targets that AG News rows
                (``[text, label]``) do not carry; it previously crashed with a
                NameError.
        """
        if self.config.reasoning != 'prompt_few_shot':
            raise ValueError(
                "collate_fn_few_shot only supports reasoning='prompt_few_shot', "
                f"got {self.config.reasoning!r}")
        if not self.config.zero_shot:
            raise NotImplementedError(
                'masked prompting (zero_shot=False) is not supported for AG News: '
                'the data provides no per-example targets')

        input_tokens, input_labels = zip(*data)
        max_words = self.config.max_length - 25
        new_tokens = []
        for i, line in enumerate(input_tokens):
            line = ' '.join(line.split()[:max_words])
            new_tokens.append(prompt_direct_inferring_few_shot_agnews(
                line, self.config.prompt_file,
                few_shot_examples=self.few_shot_examples,
                show_prompt=(i == 0)))  # print the rendered prompt once per batch

        # NOTE(review): with padding=True and no explicit truncation, max_length
        # likely does not truncate here — confirm against the transformers version.
        batch_input = self.tokenizer.batch_encode_plus(new_tokens, padding=True, return_tensors='pt',
                                                       max_length=self.config.max_length).data
        labels = [self.config.label_list[int(w)] for w in input_labels]
        batch_output = self.tokenizer.batch_encode_plus(labels, max_length=3, padding=True,
                                                        return_tensors="pt").data

        res = {
            'input_ids': batch_input['input_ids'],
            'input_masks': batch_input['attention_mask'],
            'output_ids': batch_output['input_ids'],
            'output_masks': batch_output['attention_mask'],
            'input_labels': torch.tensor(input_labels),
        }
        return {k: v.to(self.config.device) for k, v in res.items()}

class Preprocessor_agnews:
    """Load AG News and convert its splits into lists of ``[text, label]``."""

    def __init__(self, config):
        self.config = config

    def read_file(self):
        """Return the ``(train, test)`` AG News splits from ``datasets.load_dataset``."""
        dataset = load_dataset('ag_news', split=['train', 'test'])
        train_data = dataset[0]
        test_data = dataset[1]

        return train_data, test_data

    def transformer2indices(self, cur_data, max_id=-1):
        """Convert one split into a list of ``[text, label]`` pairs.

        Args:
            cur_data: object indexable by ``'text'`` and ``'label'`` returning
                parallel sequences (an HF ``Dataset`` or a plain dict).
            max_id: keep only the first ``max_id`` rows. A negative value
                (default ``-1``) keeps every row.

        Progress is printed every 100 rows.
        """
        # Fetch each column once: on an HF Dataset, cur_data['text'] can build
        # the full column, so doing it per row is quadratic work.
        texts = cur_data['text']
        labels = cur_data['label']
        res = []
        for i, (text, label) in enumerate(zip(texts, labels)):
            if i % 100 == 0:
                print('i: ', i)

            if i == max_id:
                break
            res.append([text, label])
        return res

    def forward(self, train_limit=2000):
        """Return ``[train_pairs, test_pairs]`` built from AG News.

        Args:
            train_limit: maximum number of training rows to keep (default
                2000, the previous hard-coded value). Pass ``-1`` for all rows.
                The test split is always kept in full.

        NOTE(review): MyDataLoader_agnews caches this result under a filename
        that does not include ``train_limit``.
        """
        train_split, test_split = self.read_file()
        return [
            self.transformer2indices(train_split, max_id=train_limit),
            self.transformer2indices(test_split),
        ]
