import os
import shutil
import tempfile
import time

import git
from findfile import find_files, find_dir
from termcolor import colored


class DatasetItem(list):
    """A named list of dataset identifiers (dataset names or paths).

    The list content is taken from ``dataset_items``; when that is empty or
    omitted, the (normalized) ``dataset_name`` itself becomes the only item.

    Attributes:
        dataset_name: base name of ``dataset_name`` after trailing path
            separators have been removed.
        name: alias of ``dataset_name``.
    """

    def __init__(self, dataset_name, dataset_items=None):
        """Build the item list.

        Args:
            dataset_name: a dataset name or a path to a dataset folder.
            dataset_items: a single item or a ``list`` of items; any
                non-list value is stored as one item.
        """
        super().__init__()

        # Strip trailing '/' and '\' whether or not the path exists on disk.
        # Otherwise os.path.basename('foo/') returns '' and the dataset ends
        # up with an empty name.
        if isinstance(dataset_name, str):
            dataset_name = dataset_name.rstrip('/\\')

        # Naming the dataset with the normalized folder name only
        self.dataset_name = os.path.basename(dataset_name)
        self.name = self.dataset_name

        # Fall back to the dataset name when no items are given
        if not dataset_items:
            dataset_items = dataset_name

        if isinstance(dataset_items, list):
            self.extend(dataset_items)
        else:
            self.append(dataset_items)


class ABSADatasetList(list):
    """List of the built-in datasets, each also exposed as a class attribute.

    Instantiating the class yields a list holding every ``DatasetItem``
    declared below, in declaration order.
    """

    # SemEval
    Laptop14 = DatasetItem('Laptop14', 'Laptop14')
    Restaurant14 = DatasetItem('Restaurant14', 'Restaurant14')
    Restaurant15 = DatasetItem('Restaurant15', 'Restaurant15')
    Restaurant16 = DatasetItem('Restaurant16', 'Restaurant16')

    # Twitter
    ACL_Twitter = DatasetItem('Twitter', 'Twitter')

    MAMS = DatasetItem('MAMS', 'MAMS')

    def __init__(self):
        """Populate the list with all built-in dataset items."""
        super().__init__((
            self.Laptop14,
            self.Restaurant14,
            self.Restaurant15,
            self.Restaurant16,
            self.ACL_Twitter,
            self.MAMS,
        ))


# Name fragments appended to every `exclude_key` list that detect_dataset
# passes to find_dir / find_files.
filter_key_words = [
    '.py',
    '.md',
    'readme',
    'log',
    'result',
    'zip',
    '.state_dict',
    '.model',
    '.png',
    'acc_',
    'f1_',
]


def detect_dataset(dataset_path, task='apc', load_aug=False):
    """Collect the train / test / valid files of one or more datasets.

    Each entry of ``dataset_path`` is handled in one of two ways:

    * If it is not an existing path, or it matches an attribute name of
      ``ABSADatasetList``, a dataset directory is looked up under the current
      working directory with ``find_dir`` and files are searched there using
      the entry itself as an additional key word.
    * Otherwise the entry is treated as a local path and searched directly.

    Args:
        dataset_path: a ``DatasetItem``, or any value accepted by the
            ``DatasetItem`` constructor (it is wrapped automatically).
        task: key word handed to ``find_dir`` / ``find_files`` together with
            the split name.
        load_aug: when false, ``'.ignore'`` is added to the exclude keys of
            every file search; when true it is not.

    Returns:
        A dict with the keys ``'train'``, ``'test'`` and ``'valid'``, each
        mapping to the list of files found for that split. Matches for the
        key word ``'dev'`` are added to ``'valid'``.

    Raises:
        RuntimeError: if no train file was found for any entry.
    """
    if not isinstance(dataset_path, DatasetItem):
        dataset_path = DatasetItem(dataset_path)
    dataset_file = {'train': [], 'test': [], 'valid': []}

    # The only difference between the two load_aug modes is this exclude key.
    extra_exclude = [] if load_aug else ['.ignore']

    # (result key, split key word, name fragment of the split to exclude)
    split_rules = (
        ('train', 'train', 'test.'),
        ('test', 'test', 'train.'),
        ('valid', 'valid', 'train.'),
        ('valid', 'dev', 'train.'),
    )

    search_path = ''
    d = ''
    for d in dataset_path:
        if not os.path.exists(d) or hasattr(ABSADatasetList, d):
            print('Loading {} dataset '.format(d))
            search_path = find_dir(os.getcwd(), [d, task, 'dataset'], exclude_key=['infer', 'test.'] + filter_key_words, disable_alert=True)
            search_root, name_keys = search_path, [d]
        else:
            # The original message never interpolated the dataset name.
            print('Try to load {} dataset from local'.format(d))
            search_root, name_keys = d, []

        for split, split_key, other_split in split_rules:
            dataset_file[split] += find_files(
                search_root,
                name_keys + [split_key, task],
                exclude_key=['.inference', other_split] + filter_key_words + extra_exclude,
            )

    if not dataset_file['train']:
        # List only the locations that really are directories: calling
        # os.listdir on the other one (e.g. the initial '' search_path) would
        # raise and hide the RuntimeError below. The truthiness check also
        # covers a possible None from find_dir -- TODO confirm its contract.
        existing_dirs = [p for p in (d, search_path) if p and os.path.isdir(p)]
        if existing_dirs:
            detected = [entry for p in existing_dirs for entry in os.listdir(p)]
            print('No train set found from: {}, detected files: {}'.format(dataset_path, ', '.join(detected)))
        raise RuntimeError(
            'Fail to locate dataset: {}'.format(dataset_path)
        )

    return dataset_file
