import pickle

from compute_vocabulary_overlap import compute_vocabulary_overlap, get_vocabulary_set, compute_jaccard_index
from hfdataset import HFDataset
from dataset_processing_info import dataset_info_dict
from utils.path_utils import get_output_path
import os
import time
from datetime import datetime
from tqdm import tqdm
import shutil
import json
import numpy as np
from config import VOCAB_OVERLAP_DIR

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Sampling budgets: examples requested per target dataset, and the list of
# per-source budgets the outer loop of the sweep iterates over.
num_target_samples = 1000
num_source_samples_list = [10000]

# Seeds to iterate over; each seed is passed to the target HFDataset and to
# get_output_path.
seed_range = [None]

# Forwarded as the `streaming` argument of every HFDataset that is built.
stream_datasets = True

# Not referenced anywhere in the visible part of this script.
delete_cache = False

# Target datasets for which overlap scores are computed.
target_dataset_names = [
    "imdb_plain_text",
    "tweet_eval_emotion",
    "tweet_eval_sentiment",
    "llm-book__JGLUE_JSTS",
    "google_wellformed_query_default",
    "paws-x_en",
    "md_gender_bias_convai2_inferred",
    "google__civil_comments_default",
]

# Every key of dataset_info_dict is used as a source dataset.
source_dataset_names = list(dataset_info_dict)

# Names of source datasets whose processing raised an exception; filled in by
# the sweep and pickled afterwards.
failed_sources = []

# Sweep over every (source budget, seed, target, source) combination and store
# the value returned by compute_jaccard_index for the target and source
# vocabulary sets in <output_dir>/metric.npy.  A pair whose metric file already
# exists is skipped, so an interrupted run can simply be restarted.
for num_source_samples in tqdm(num_source_samples_list):
    for seed in seed_range:
        print(f'Seed {seed}')

        for target_dataset_name in target_dataset_names:
            target_output_dir = get_output_path(VOCAB_OVERLAP_DIR,
                                                num_train_samples=num_target_samples,
                                                num_source_samples=num_source_samples,
                                                seed=seed,
                                                target_name=target_dataset_name)

            start_time = time.time()

            # Loading the target is intentionally not guarded: an exception
            # here propagates and stops the script.
            target_dataset = HFDataset(target_dataset_name,
                                       split='train',
                                       max_num_examples=num_target_samples,
                                       seed=seed,
                                       streaming=stream_datasets)

            # Computed once per target and reused for every source below.
            target_vocabulary_set = get_vocabulary_set(target_dataset)

            for source_dataset_name in tqdm(source_dataset_names):
                output_dir = get_output_path(target_output_dir, source_name=source_dataset_name)
                metric_path = os.path.join(output_dir, 'metric.npy')

                # Resume support: this pair has already been computed.
                if os.path.isfile(metric_path):
                    continue

                try:
                    # NOTE(review): unlike the target, the source is built
                    # without `seed`, although the output path is keyed by the
                    # seed -- confirm this is intended.
                    source_dataset = HFDataset(source_dataset_name,
                                               split='train',
                                               max_num_examples=num_source_samples,
                                               streaming=stream_datasets)

                    source_vocabulary_set = get_vocabulary_set(source_dataset)
                    vocabulary_overlap = compute_jaccard_index(target_vocabulary_set,
                                                               source_vocabulary_set)

                    os.makedirs(output_dir, exist_ok=True)
                    np.save(metric_path, vocabulary_overlap)
                except Exception as e:
                    # Best effort: one broken source must not abort the whole
                    # sweep.  Report it, remember its name and move on.
                    print(source_dataset_name)
                    print(str(e))
                    failed_sources.append(source_dataset_name)

            # Wall-clock time for this target, including sources that were
            # skipped because their metric already existed.
            time_elapsed = time.time() - start_time

            timer_dict = {
                'timestamp': datetime.now().strftime("%m/%d/%Y, %H:%M:%S"),
                'elapsed': time_elapsed,
                'num_sources': len(source_dataset_names)
            }

            # Only the per-source directories are created above, and only on
            # success.  If every source failed, the target directory may not
            # exist yet, so create it before writing the timer file.
            os.makedirs(target_output_dir, exist_ok=True)
            with open(os.path.join(target_output_dir, 'timer.json'), 'w') as f:
                json.dump(timer_dict, f)

    # Written after each source budget so the failure list survives a later
    # crash.  The directory is created first for the same reason as above.
    os.makedirs(VOCAB_OVERLAP_DIR, exist_ok=True)
    with open(os.path.join(VOCAB_OVERLAP_DIR, 'failed_sources.pkl'), 'wb') as f:
        pickle.dump(failed_sources, f)
