# -*- coding: utf-8 -*-

import csv
import os
import pickle
import random
import stat
import string
import sys
import tempfile
import types
from zipfile import ZipFile

from ..common.logger import LOGGER, open_file
from ..common.utils import query_yes_no

try:
    from pip.utils import get_installed_distributions
    from pip import main as pip_main
except ModuleNotFoundError:
    from pip._internal.utils.misc import get_installed_distributions
    from pip._internal import main as pip_main


# Relative paths of non-Python files this module needs at runtime
# (get_module_infos reads an ``__ASSETS__`` attribute from user modules).
__ASSETS__ = ['magic_load.py', 'magic_load.sh', 'magic_entry.py']

# Lazily built (file2package, packages) index; see get_packages_cache().
PACKAGE_CACHE = None

PYTHON_LOAD_PATH = os.path.join(os.path.dirname(__file__), 'magic_load.py')
SHELL_LOAD_PATH = os.path.join(os.path.dirname(__file__), 'magic_load.sh')


def _read_text(path):
    """Return the whole text of ``path``, closing the file afterwards."""
    with open(path, encoding='utf-8') as fp:
        return fp.read()


# The shell template is str.format-ed with a random separator and the Python
# loader source, then UTF-8 encoded; pack_to_executable later %-formats the
# resulting bytes with the code/data byte lengths.
# NOTE(review): because the loader source is inlined before that %-formatting,
# any literal '%' in magic_load.py would be re-parsed — confirm it has none
# (or that they are escaped as '%%').
SHELL_HEADER = _read_text(SHELL_LOAD_PATH).format(
    ''.join(random.choices(string.ascii_uppercase, k=8)),  # random separator
    _read_text(PYTHON_LOAD_PATH)  # python load
).encode()

ENTRY_FILE_CONTENT = '''# -*- coding: utf-8 -*-
from {}.magic_entry import main

if __name__ == '__main__':
    main()
'''


def _is_inside_venv():
    """Tell whether the running interpreter belongs to a virtual environment.

    Legacy virtualenv sets ``sys.real_prefix``; PEP 405 venvs instead make
    ``sys.base_prefix`` differ from ``sys.prefix``.
    """
    if hasattr(sys, 'real_prefix'):
        return True
    # A missing base_prefix falls back to sys.prefix and so compares equal.
    return getattr(sys, 'base_prefix', sys.prefix) != sys.prefix


def _build_package_index():
    """Index every installed distribution reported by pip.

    :return: ``(file2package, packages)`` where ``file2package`` maps the
        ``os.path.normpath``-ed path of each file recorded for a distribution
        to that distribution object, and ``packages`` maps each distribution's
        ``project_name`` to its ``version``.
    """
    file2package = {}
    packages = {}
    for dist in get_installed_distributions():
        packages[dist.project_name] = dist.version
        if dist.has_metadata('RECORD'):
            # .dist-info RECORD files are CSV rows (path, hash, size) relative
            # to the install location; paths containing commas are quoted,
            # so a plain str.split(',') would truncate them.
            root = dist.location
            rel_paths = [row[0] for row in csv.reader(dist.get_metadata_lines('RECORD')) if row]
        elif dist.has_metadata('installed-files.txt'):
            # .egg-info installs: pip's log, relative to the egg-info directory.
            root = dist.egg_info
            rel_paths = list(dist.get_metadata_lines('installed-files.txt'))
        else:
            continue
        for rel_path in rel_paths:
            file2package[os.path.normpath(os.path.join(root, rel_path))] = dist
    return file2package, packages


def get_packages_cache():
    """Return the ``(file2package, packages)`` index, building it on first use."""
    global PACKAGE_CACHE
    if PACKAGE_CACHE is not None:
        return PACKAGE_CACHE
    PACKAGE_CACHE = _build_package_index()
    return PACKAGE_CACHE


def clear_package_cache():
    """Drop the cached package index so the next get_packages_cache() rebuilds it."""
    global PACKAGE_CACHE
    PACKAGE_CACHE = None


def get_package_version(project_name):
    """Return the installed version of distribution ``project_name``, or None."""
    _, versions = get_packages_cache()
    return versions.get(project_name)


def get_file_package(path):
    """Return the distribution that installed file ``path``, or None.

    ``path`` is looked up verbatim, so callers must normalize it
    (``os.path.normpath``) the same way the index keys were built.
    """
    file_index, _ = get_packages_cache()
    return file_index.get(path)


def get_module_path(module):
    """Return ``module.__file__``, or None when the module has no such attribute."""
    try:
        return module.__file__
    except AttributeError:
        return None


def get_module_package(module):
    """Return the distribution that installed ``module``, or None.

    ``module.__file__`` is made absolute and ``os.path.normpath``-ed before the
    lookup, the same normalization get_module_infos applies, so both paths
    resolve against the package index identically.
    """
    path = get_module_path(module)
    if path is None:
        return None
    return get_file_package(os.path.normpath(os.path.abspath(path)))


def get_module_version(module, dist=None):
    """Best-effort version of ``module``.

    :param module: an imported module object.
    :param dist: the distribution owning ``module``; looked up with
        get_module_package when omitted.
    :return: ``dist.version`` when available; otherwise the first usable value
        among the module's ``version``, ``__version__``, ``VERSION``,
        ``__VERSION__`` and ``get_version`` attributes. Callables are called,
        a callable that raises counts as absent, and module objects are
        skipped. Returns None when nothing fits.
    """
    if dist is None:
        dist = get_module_package(module)

    version = dist.version if dist is not None else None
    if version is not None:
        return version

    for name in ('version', '__version__', 'VERSION', '__VERSION__', 'get_version'):
        candidate = getattr(module, name, None)
        if callable(candidate):
            try:
                candidate = candidate()
            except Exception:  # best-effort probe: a failing getter means "unknown"
                candidate = None
        # Attributes such as ``numpy.version`` are submodules, not version values.
        if candidate is not None and not isinstance(candidate, types.ModuleType):
            return candidate
    return None


def _is_path_within(path, directory):
    """Return True if normalized ``path`` equals ``directory`` or lies beneath it."""
    if path == directory:
        return True
    prefix = directory if directory.endswith(os.sep) else directory + os.sep
    return path.startswith(prefix)


def get_module_infos(base_dir):
    """Split the loaded modules into user modules and third-party packages.

    :param base_dir: directory of the user's code; every loaded module whose
        file lies under it is a user module.
    :return: ``(packages, user_modules)``. ``packages`` maps a top-level import
        name to ``(project_name, version)`` of the distribution that installed
        it. ``user_modules`` maps a module name to ``(path relative to
        base_dir, absolute file path, __ASSETS__ list or None)``; the
        ``__main__`` entry is stored under ``'__this_module__'``.
    :raises Exception: if a user module's ``__ASSETS__`` holds an absolute path.
    :raises KeyError: if the ``__main__`` module is not located under base_dir.
    """
    base_dir = os.path.normpath(os.path.abspath(base_dir))
    packages = {}
    user_modules = {}
    # Iterate over a snapshot: sys.modules may change while modules are inspected.
    for module in list(sys.modules.values()):
        name = getattr(module, '__name__', None)
        # sys.modules can contain None placeholders or objects without a name.
        if name is None:
            continue
        module_file = get_module_path(module)
        if module_file is None:
            continue
        module_file = os.path.normpath(os.path.abspath(module_file))
        # Compare whole path components so that /proj does not also claim /proj2.
        if _is_path_within(module_file, base_dir):
            rel_path = os.path.relpath(module_file, base_dir)
            external_files = getattr(module, '__ASSETS__', None)
            if external_files is not None and any(os.path.isabs(f) for f in external_files):
                raise Exception('__ASSETS__ should contain only relative paths')
            user_modules[name] = rel_path, module_file, external_files

        dist = get_file_package(module_file)
        if dist is not None:
            packages[name.split('.')[0]] = dist.project_name, get_module_version(module, dist=dist)

    try:
        user_modules['__this_module__'] = user_modules.pop('__main__')
    except KeyError:
        raise KeyError('__main__ module is not located under %s' % base_dir) from None
    return packages, user_modules


def install_given_packages(packages):
    """Offer to pip-install distributions that are missing or at another version.

    :param packages: mapping ``module_name -> (package_name, required_version)``,
        the shape of the first value returned by get_module_infos.
    """
    missing_packages = {}
    for module_name, (package_name, required_version) in packages.items():
        installed_version = get_package_version(package_name)
        # With no recorded version there is nothing to compare against, so any
        # installed version is acceptable; only an absent package needs work.
        if installed_version is not None and (not required_version or required_version == installed_version):
            continue
        _, _, module_names = missing_packages.setdefault(
            package_name, (installed_version, required_version, []))
        module_names.append(module_name)

    if not missing_packages:
        return

    args = ['install']
    if not _is_inside_venv():
        # Outside a virtual environment, install into the user site instead.
        args.append('--user')

    LOGGER.warning('Following packages are mismatched:')
    for package_name, (installed_version,
                       required_version,
                       module_names) in missing_packages.items():
        LOGGER.warning('  %s: %s is installed, but %s is required. %s',
                       package_name, installed_version or 'MISSING', required_version, module_names)
        args.append(f'{package_name}=={required_version}' if required_version else package_name)

    # NOTE(review): assumes query_yes_no returns the string 'y' on confirmation;
    # verify against ..common.utils (a bool-returning helper would never install).
    if query_yes_no('Install?', default='no') == 'y':
        pip_main(args)
        # Installed versions just changed; drop the stale index.
        clear_package_cache()


def collect_user_modules(output_path, entry_class, entry_point, base_dir=None):
    """Write the user's code and its metadata into the zip archive ``output_path``.

    Archive layout:

    * ``__HEAD__`` -- pickled dict with ``packages`` (from get_module_infos),
      ``packer`` (this module's name), ``entry_point`` and ``entry_class``;
    * ``__entry_point__.py`` -- stub that runs this package's ``magic_entry.main``;
    * each user module under a path derived from its dotted name (package
      ``__init__`` files as ``<name>/__init__.py``, the ``__main__`` script as
      ``__this_module__.py``), plus the files each module lists in
      ``__ASSETS__``, placed relative to that module. Each asset is stored once
      even if several modules list it.

    :param output_path: destination handed to ``ZipFile`` in write mode.
    :param entry_class: dotted class path stored in the header.
    :param entry_point: value stored in the header.
    :param base_dir: root of the user's code; defaults to the directory of the
        ``__main__`` script.
    :raises ValueError: if base_dir is omitted and ``__main__`` is missing or
        has no ``__file__`` to derive it from.
    """
    if base_dir is None:
        main_module = sys.modules.get('__main__')
        main_path = get_module_path(main_module) if main_module is not None else None
        if main_path is None:
            raise ValueError('base_dir must be given when the __main__ module has no __file__')
        base_dir = os.path.dirname(os.path.abspath(main_path))
    packages, modules = get_module_infos(base_dir)

    with ZipFile(output_path, 'w') as code_zip:
        code_zip.writestr('__HEAD__', pickle.dumps({
            'packages': packages,
            'packer': __name__,
            'entry_point': entry_point,
            'entry_class': entry_class,
        }))
        code_zip.writestr('__entry_point__.py', ENTRY_FILE_CONTENT.format(__package__))
        written_assets = set()
        for module_name, (_, module_file, external_files) in modules.items():
            module_arc_file = module_name.replace('.', os.path.sep)
            if module_file.endswith('__init__.py'):
                module_arc_file = os.path.join(module_arc_file, '__init__.py')
            elif not module_arc_file.endswith('.py'):
                module_arc_file += '.py'
            code_zip.write(module_file, module_arc_file)

            if external_files is None:
                continue
            module_dir = os.path.dirname(module_file)
            module_arc_dir = os.path.dirname(module_arc_file)
            for file in external_files:
                arcname = os.path.normpath(os.path.join(module_arc_dir, file))
                # Sibling modules of one package may list the same asset;
                # writing it again would add a duplicate zip entry.
                if arcname in written_assets:
                    continue
                written_assets.add(arcname)
                code_zip.write(os.path.join(module_dir, file), arcname)


def pack_to_executable(data_bytes, output_path, entry_class, entry_point, base_dir=None):
    """Write an executable file made of the shell header, zipped code and data.

    :param data_bytes: payload appended after the code archive.
    :param output_path: path of the file to create; it is marked owner-executable.
    :param entry_class: class object, or its dotted ``module.Name`` string.
    :param entry_point: forwarded to collect_user_modules for the archive header.
    :param base_dir: root of the user's code, forwarded to collect_user_modules.
    """
    if not isinstance(entry_class, str):  # convert class to path
        module = entry_class.__module__
        if module is None or module == str.__module__:
            entry_class = entry_class.__name__
        else:
            entry_class = module + '.' + entry_class.__name__

    # Build the archive inside a private temporary directory: unlike a still-open
    # NamedTemporaryFile, its path can be reopened on every platform, and the
    # directory is removed even if collecting the modules fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        zip_path = os.path.join(tmp_dir, 'code.zip')
        collect_user_modules(zip_path, entry_class, entry_point, base_dir=base_dir)
        with open_file(zip_path, 'rb') as zip_fp:
            code_bytes = zip_fp.read()

    # The header is %-formatted with the code and data byte lengths.
    shell_bytes = SHELL_HEADER % (len(code_bytes), len(data_bytes))
    with open_file(output_path, 'wb') as fp:
        fp.write(shell_bytes)
        fp.write(code_bytes)
        fp.write(data_bytes)
    st = os.stat(output_path)
    os.chmod(output_path, st.st_mode | stat.S_IEXEC)
