import os
import sys
import zipfile
from findfile import find_file, find_cwd_file
from LSA.core.apc.prediction.sentiment_classifier import SentimentClassifier
from LSA.utils.absa_utils import get_device


def unzip_checkpoint(zip_path):
    """
    Extract a zipped checkpoint into a directory named after the archive.

    ``<name>.zip`` is extracted into ``<name>`` (same parent directory). Only a
    trailing ``.zip`` suffix is removed, so ``.zip`` appearing elsewhere in the
    path (e.g. inside a parent directory name) is left intact.

    :param zip_path: path to the zipped checkpoint
    :return: the extraction directory path. It is returned even when the file is
             not a valid zip archive; in that case a failure message is printed
             instead of raising.
    """
    suffix = '.zip'
    # Strip only the trailing suffix; str.replace would also rewrite '.zip'
    # occurrences in the middle of the path.
    target_dir = zip_path[:-len(suffix)] if zip_path.endswith(suffix) else zip_path
    try:
        print('Find zipped checkpoint: {}, unzipping...'.format(zip_path))
        sys.stdout.flush()
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(target_dir)
        print('Done.')
    except zipfile.BadZipFile:
        # Best-effort: report the failure (including the offending path) rather than raising.
        print('Unzip failed: {}'.format(zip_path))
    return target_dir


class CheckpointManager:
    """Common base class for the task-specific checkpoint managers (e.g. APCCheckpointManager)."""


class APCCheckpointManager(CheckpointManager):
    """Checkpoint manager for the APC task: locates a checkpoint and builds a SentimentClassifier from it."""

    @staticmethod
    def get_sentiment_classifier(checkpoint: str = None,
                                 sentiment_map: dict = None,
                                 auto_device=True,
                                 eval_batch_size=128):
        """
        Locate a checkpoint and build a SentimentClassifier from it.

        The checkpoint is resolved in this order:

        1. If ``checkpoint`` ends with ``.zip``: look under the working directory for an
           already-unzipped copy (``find_cwd_file`` with the archive name minus ``.zip``
           and ``.config`` as keywords). If none is found, unzip the archive and search
           again using the extraction directory.
        2. If ``checkpoint`` is an existing path: search it for a ``.config`` file.
        3. Otherwise: use ``checkpoint`` as a keyword to search the working directory
           for a ``.config`` file.

        When a ``.config`` file is found, its directory is passed to SentimentClassifier;
        otherwise ``checkpoint`` is passed through unchanged (presumably so that
        SentimentClassifier can resolve it itself, e.g. as a named checkpoint -- confirm).

        :param checkpoint: zipped checkpoint name, or checkpoint path or checkpoint name queried from Google Drive
        :param sentiment_map: label to text index map (deprecated and has no effect).
            This param is for someone who wants to load a checkpoint not registered in PyABSA
        :param auto_device: True or False, otherwise 'cuda', 'cpu' works
        :param eval_batch_size: eval batch_size in modeling

        :return: the SentimentClassifier, moved to the device chosen by ``get_device(auto_device)``
        :raises ValueError: if ``checkpoint`` is None
        """
        if checkpoint is None:
            raise ValueError('checkpoint must be a zipped checkpoint, a checkpoint path '
                             'or a checkpoint name, got None')

        checkpoint_config = None
        zip_suffix = '.zip'
        # find ckpt from zip file
        if checkpoint.endswith(zip_suffix):
            # Remove only the trailing suffix: str.strip('.zip') would strip any of the
            # characters '.', 'z', 'i', 'p' from BOTH ends (e.g. 'deep.zip' -> 'dee').
            unzipped_name = checkpoint[:-len(zip_suffix)]
            # find ckpt from a previously unzipped copy
            checkpoint_config = find_cwd_file([unzipped_name, '.config'])
            if not checkpoint_config:
                # unzip ckpt; if the archive is not found under the cwd, use the given path as-is
                zip_path = find_cwd_file(checkpoint) or checkpoint
                checkpoint = unzip_checkpoint(zip_path)
                # Re-search only after unzipping so a config found above is not overwritten.
                checkpoint_config = find_cwd_file([checkpoint, '.config'])
        # find ckpt from a path
        if not checkpoint_config and os.path.exists(checkpoint):
            checkpoint_config = find_file(checkpoint, ['.config'])
        # use "checkpoint" as a keyword to search ckpt
        if not checkpoint_config:
            checkpoint_config = find_cwd_file([checkpoint, '.config'])

        # obtain ckpt location if it is found in the local envs
        if checkpoint_config:
            checkpoint = os.path.dirname(checkpoint_config)

        sent_classifier = SentimentClassifier(checkpoint, sentiment_map=sentiment_map, eval_batch_size=eval_batch_size)
        device, _ = get_device(auto_device)
        sent_classifier.to(device)
        return sent_classifier
