import os
import json
import random
import os
import json
import random
import torch
import datasets
from collections import Counter
from src.dataset.super_glue import AutoTask as AutoSuperGLUETask
from src.dataset.glue_tasks import AutoTask as AutoGLUETask


def load_lora_dict_by_task(args, task_name,
                           path_template='{lora_path}/{task_name}/adapter_model.bin'):
    """Load the golden LoRA checkpoint for one task.

    The checkpoint location is built from ``path_template`` by substituting
    ``args.lora_path`` and ``task_name``. The default matches the layout used
    for t5-v11-large-lm-adapt + P3 (and SNI). Other checkpoint collections can
    be selected with a different template, e.g.:

    * flan-t5-large: ``'{lora_path}/flan_t5_large-{task_name}/adapter_model.bin'``
    * flan-t5-base/small: ``'{lora_path}/flan_t5_base-{task_name}.json/adapter_model.bin'``
    * t5-v11-large-lm-adapt: ``'{lora_path}/{task_name}.json/adapter_model.bin'``
    * t5-large (GLUE): ``'{lora_path}/t5_large-glue_{task_name}/adapter_model.bin'``

    Some earlier variants also normalised the task name first (keeping only
    the part before ':' and replacing '/' with '_'); do that in the caller if
    the checkpoint directories are named that way.

    Args:
        args: object exposing a ``lora_path`` attribute (checkpoint root dir).
        task_name (str): task name used to locate the checkpoint.
        path_template (str): ``str.format`` template with ``{lora_path}`` and
            ``{task_name}`` placeholders.

    Returns:
        The object deserialised by ``torch.load``, mapped onto the CPU.
    """
    checkpoint_path = path_template.format(lora_path=args.lora_path,
                                           task_name=task_name)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    return torch.load(checkpoint_path, map_location='cpu')

def get_dataset_sizes(raw_datasets, split='train', task_key='task_name'):
    """Return the number of examples per task, in order of first appearance.

    Args:
        raw_datasets: mapping from split name to a dataset whose ``task_key``
            column can be indexed by name (e.g. a ``datasets.DatasetDict``).
        split (str): split to count.
        task_key (str): column holding each example's task name.

    Returns:
        list[int]: one count per distinct task name, ordered by the first row
        in which that task appears.
    """
    # Read the column once; on a `datasets.Dataset` every column access may
    # materialise the whole column.
    task_column = raw_datasets[split][task_key]
    # Counter is a dict and keeps keys in first-insertion order (Python 3.7+),
    # which is exactly the appearance order -- no repeated list scans needed.
    return list(Counter(task_column).values())

def _read_bbh_jsonl(file_path):
    """Return the ``context`` and ``completion`` fields of a jsonl file as two lists.

    Blank lines are skipped; every other line must be a JSON object with
    ``context`` and ``completion`` keys.
    """
    inputs, outputs = [], []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            example = json.loads(line)
            inputs.append(example["context"])
            outputs.append(example["completion"])
    return inputs, outputs

def _build_bbh_task_dataset(file_path, task_name):
    """Load one jsonl file as a Dataset with inputs/targets/task_name columns."""
    inputs, outputs = _read_bbh_jsonl(file_path)
    return datasets.Dataset.from_dict({
        'inputs': inputs,
        'targets': outputs,
        'task_name': [task_name] * len(inputs),
    })

def load_bbh_dataset(args, n_demonstrations=0, seed=42):
    """Load BIG-Bench Hard tasks stored as one sub-directory per task.

    Args:
        args: object exposing ``dataset_name``, the root directory whose
            entries are task sub-directories.
        n_demonstrations (int): if > 0, also build per-task demonstration sets
            of at most this many examples sampled from ``example.jsonl``.
        seed (int): seed used to shuffle each task's examples before taking
            the first ``n_demonstrations``.

    Returns:
        If ``n_demonstrations > 0``: a ``(demo_datasets, raw_datasets)`` pair
        of ``DatasetDict``s keyed by task name. Demo sets have ``inputs`` /
        ``targets`` columns; raw sets are read from ``zero_shot.jsonl`` and
        additionally carry ``task_name``.
        Otherwise: a ``DatasetDict`` with a single ``validation`` split that
        concatenates every task's ``few_shot.jsonl`` (columns ``inputs``,
        ``targets``, ``task_name``).
    """
    # NOTE(review): os.listdir order is filesystem-dependent and every entry is
    # treated as a task directory -- confirm no stray files live here.
    sub_dirs = os.listdir(args.dataset_name)

    if n_demonstrations > 0:
        demo_datasets, raw_datasets = {}, {}
        for sub_dir in sub_dirs:
            task_dir = os.path.join(args.dataset_name, sub_dir)

            # Few-shot demonstrations for LoRAHub learning.
            # NOTE: an earlier variant rewrote "(A)"-style options to "A." and
            # trimmed completions to the option letter; it is not applied here.
            example_inputs, example_outputs = _read_bbh_jsonl(
                os.path.join(task_dir, "example.jsonl"))
            # Reseed the (global) RNG before every task so each task's selection
            # is reproducible regardless of the order tasks are visited.
            random.seed(seed)
            pairs = list(zip(example_inputs, example_outputs))
            random.shuffle(pairs)
            # Slicing the pair list (instead of unpacking zip(*pairs)) also
            # copes with an empty example file.
            pairs = pairs[:n_demonstrations]
            demo_datasets[sub_dir] = datasets.Dataset.from_dict({
                'inputs': [inp for inp, _ in pairs],
                'targets': [out for _, out in pairs],
            })

            # Zero-shot examples for evaluation.
            raw_datasets[sub_dir] = _build_bbh_task_dataset(
                os.path.join(task_dir, "zero_shot.jsonl"), sub_dir)
        return datasets.DatasetDict(demo_datasets), datasets.DatasetDict(raw_datasets)

    per_task_datasets = [
        _build_bbh_task_dataset(
            os.path.join(args.dataset_name, sub_dir, "few_shot.jsonl"), sub_dir)
        for sub_dir in sub_dirs
    ]
    return datasets.DatasetDict({
        'validation': datasets.concatenate_datasets(per_task_datasets),
    })

def load_fs_super_glue_dataset(args, n_demonstrations=0, seed=42, task_names=('cb',)):
    """Load few-shot demonstrations and validation data for SuperGLUE tasks.

    For each task, ``{n_demonstrations}_shot.json`` and ``validation.json``
    are read line by line from ``<args.dataset_name>/<task>/`` and converted
    with the task's ``get_dataset(..., add_prefix=True)``.

    Args:
        args: object exposing ``dataset_name`` (root directory of task folders).
        n_demonstrations (int): selects the ``{n}_shot.json`` demonstration file.
        seed (int): currently unused; kept for signature parity with the
            other loaders.
        task_names (Iterable[str] | None): task sub-directories to load.
            Defaults to ``('cb',)``, matching the previous hard-coded
            restriction; pass ``None`` to load every entry of
            ``args.dataset_name``.

    Returns:
        tuple[dict, dict]: ``(demo_datasets, raw_datasets)`` keyed by task name.
    """
    if task_names is None:
        sub_dirs = os.listdir(args.dataset_name)
    else:
        sub_dirs = list(task_names)
    print(sub_dirs)
    demo_datasets, raw_datasets = {}, {}
    for sub_dir in sub_dirs:
        task_dir = os.path.join(args.dataset_name, sub_dir)
        # Load the n-shot demonstrations and the validation split (one JSON per line).
        demo_dataset = load_json_by_line(
            os.path.join(task_dir, "{}_shot.json".format(n_demonstrations)))
        eval_dataset = load_json_by_line(os.path.join(task_dir, "validation.json"))

        # Convert raw records with the task-specific processor.
        demo_dataset = AutoSuperGLUETask.get(task_name=sub_dir).get_dataset(demo_dataset, add_prefix=True)
        eval_dataset = AutoSuperGLUETask.get(task_name=sub_dir).get_dataset(eval_dataset, add_prefix=True)

        print('print eval_dataset')
        print(eval_dataset[0])
        raw_datasets[sub_dir] = eval_dataset
        demo_datasets[sub_dir] = demo_dataset
    return demo_datasets, raw_datasets

def load_fs_glue_dataset(args, n_demonstrations=0, seed=42,
                         demo_dir='data/glue_auto_demonstration'):
    """Load demonstrations and test data for every GLUE task under ``args.dataset_name``.

    Args:
        args: object exposing ``dataset_name``; each entry of that directory
            is treated as a GLUE task name.
        n_demonstrations (int): currently unused -- demonstrations come from
            the pre-selected files in ``demo_dir``.
        seed (int): passed to ``AutoGLUETask.get``.
        demo_dir (str): directory holding ``<task>_train.json`` demonstration files.

    Returns:
        tuple[dict, dict]: ``(demo_datasets, raw_datasets)`` keyed by task
        name; ``raw_datasets`` holds each task's ``test`` split.
    """
    sub_dirs = os.listdir(args.dataset_name)
    print(sub_dirs)
    demo_datasets, raw_datasets = {}, {}
    for sub_dir in sub_dirs:
        # Pre-selected ("auto") demonstrations. The full train split used to be
        # loaded here on every iteration although nothing read it; it is only
        # needed by the alternatives below, so load it there if re-enabled:
        #   train = AutoGLUETask.get(sub_dir, seed=seed).get_dataset(
        #       split="train", n_obs=-1, add_prefix=True)
        #   random n-shot:  train.select(random.sample(range(len(train)), n_demonstrations))
        #   all train data: train
        demo_file = os.path.join(demo_dir, f'{sub_dir}_train.json')
        demo_dataset = datasets.load_dataset("json", data_files=demo_file)['train']

        eval_dataset = AutoGLUETask.get(sub_dir, seed=seed).get_dataset(
            split="test", n_obs=-1, add_prefix=True)
        print('print eval_dataset')
        print(eval_dataset[0])
        raw_datasets[sub_dir] = eval_dataset
        demo_datasets[sub_dir] = demo_dataset
    return demo_datasets, raw_datasets

def get_sni_all_labels(dataset, decoded_labels):
    """Map each decoded label back to the first reference output containing it.

    For every entry of ``decoded_labels``, the ``output`` field of each item in
    ``dataset['Instance']`` is tested in order with ``label in output`` and the
    first hit is kept. Labels with no hit are skipped, so the result can be
    shorter than ``decoded_labels`` and is then no longer index-aligned with it.
    """
    candidates = [inst['output'] for inst in dataset['Instance']]
    not_found = object()
    resorted = []
    for label in decoded_labels:
        match = next((cand for cand in candidates if label in cand), not_found)
        if match is not not_found:
            resorted.append(match)
    return resorted

def save_json(content, path, indent=4, **json_dump_kwargs):
    """Serialise ``content`` as JSON into ``path``, replacing any existing file.

    ``indent`` plus any extra keyword arguments are forwarded to ``json.dump``.
    """
    dump_options = dict(json_dump_kwargs, indent=indent)
    with open(path, "w") as out_file:
        json.dump(content, out_file, **dump_options)

def load_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 (the encoding JSON interchange requires)
    instead of the platform's locale-dependent default.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)

def load_txt(path):
    """Return the lines of the UTF-8 text file at ``path``.

    Each line is stripped of surrounding whitespace; blank lines are kept as
    empty strings so line positions are preserved.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]

def load_json_by_line(path):
    """Load a JSON Lines file (one JSON value per line) decoded as UTF-8.

    Blank or whitespace-only lines, such as a trailing empty line, are skipped
    instead of raising ``json.JSONDecodeError``.

    Returns:
        list: the decoded values, in file order.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def save_json_file(json_dict, outfile_name, output_dir):
    """Save ``json_dict`` as JSON to ``os.path.join(output_dir, outfile_name)``.

    Thin wrapper around ``save_json`` with its default formatting. Only the
    local file is written; nothing is uploaded to any remote bucket.
    """
    save_json(json_dict, os.path.join(output_dir, outfile_name))

def save_predictions(output_dir, predictions, labels, task_names):
    """Write per-example predictions to ``<output_dir>/test_preds_seq2seq.txt``.

    One line per example: ``task_name``, ``prediction`` and ``label`` joined
    by two tab characters. The file is written as UTF-8.

    Args:
        output_dir (str): directory the file is written into.
        predictions (Sequence[str]): model predictions.
        labels (Sequence[str]): reference labels, same length as ``predictions``.
        task_names (Sequence[str]): task name per example; must cover at least
            ``len(predictions)`` examples.

    Raises:
        ValueError: if ``predictions`` and ``labels`` differ in length, or
            ``task_names`` is shorter than ``predictions``. The check runs
            before the file is opened, so no partial file is left behind.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(predictions) != len(labels):
        raise ValueError(
            f"predictions and labels differ in length: {len(predictions)} vs {len(labels)}")
    if len(task_names) < len(predictions):
        raise ValueError(
            f"task_names has {len(task_names)} entries but there are "
            f"{len(predictions)} predictions")

    output_test_preds_file = os.path.join(output_dir, "test_preds_seq2seq.txt")
    with open(output_test_preds_file, 'w', encoding='utf-8') as writer:
        # zip stops at len(predictions); extra task_names are ignored, as before.
        for task_name, prediction, label in zip(task_names, predictions, labels):
            writer.write("\t\t".join([task_name, prediction, label]) + '\n')