import itertools
import os
import pathlib
import subprocess
from typing import NoReturn, List, Tuple, Union, Any

import logging
import IPython

logger = logging.getLogger(__name__)


class OMCSBase:
    """Base helper for looking up sentences in the OMCS sentence dumps.

    On construction, each file listed in ``cn_web_paths`` (the
    ``omcs-sentences-*.txt`` files from the ConceptNet S3 bucket) is downloaded
    with ``wget`` into ``cache_path`` if it is not already present there.
    """

    def __init__(self) -> None:
        # Local directory holding the downloaded sentence files.
        self.cache_path: pathlib.Path = pathlib.Path(os.path.expanduser('~/mowgli-cache'))
        # Remote sources; the local file name is the last URL path component.
        self.cn_web_paths = [
            'https://s3.amazonaws.com/conceptnet/downloads/2018/omcs-sentences-more.txt',
            'https://s3.amazonaws.com/conceptnet/downloads/2018/omcs-sentences-free.txt',
        ]
        self._check_cn_data_exists()

    def _check_cn_data_exists(self) -> None:
        """Download every file in ``cn_web_paths`` that is missing from ``cache_path``.

        Best effort: a failed download is logged and skipped, not raised.
        NOTE(review): a partial file left by an interrupted ``wget`` would be
        treated as present on later runs — confirm whether that matters.
        """
        for p in self.cn_web_paths:
            file_name = p.split('/')[-1]
            if (self.cache_path / file_name).exists():
                continue
            try:
                subprocess.run(args=['wget', f'--directory-prefix={self.cache_path}', p], check=True)
            except subprocess.CalledProcessError:
                logger.error(f'Failed to download {p}')
            except OSError as err:
                # e.g. ``wget`` is not installed; keep the best-effort contract
                # instead of letting the error escape from ``__init__``.
                logger.error(f'Failed to download {p}: {err}')

    @staticmethod
    def _lookup_in_file(words: List[str], path: pathlib.Path, keep_score: bool = False, add_words: bool = False) \
            -> List[Tuple[Any, ...]]:
        """Return English sentences from a tab-separated dump that match all ``words``.

        The file is filtered with ``awk``, using one regex per lower-cased word
        joined with ``&&``, so a line is kept only if it matches every word.
        Matching is against the raw line, so it is case-sensitive. Each word is
        still interpreted as an awk regular expression.

        Each matched line is split on tabs. Column 1 is returned as the sentence,
        column 4 must equal ``"en"``, and the last column is the score.

        Args:
            words: Words (awk regexes) that must all occur in a line.
            path: File to search.
            keep_score: Also return ``float(last column)`` for each sentence.
            add_words: Prepend the ``words`` list itself to every result tuple.

        Returns:
            A list of tuples shaped ``([words,] sentence[, score])``.

        Raises:
            subprocess.CalledProcessError: if ``awk`` exits with a non-zero status.
        """
        # Escape backslash and slash so each word stays inside its /.../ regex
        # literal. Otherwise a word containing "/" (or ending in "\") breaks the
        # awk program, or lets the rest of the word be parsed as awk code.
        patterns = []
        for w in words:
            escaped = w.lower().replace('\\', '\\\\').replace('/', '\\/')
            patterns.append(f'/{escaped}/')
        cmd_words = " && ".join(patterns)
        out = subprocess.run(['awk', cmd_words, str(path)],
                             check=True, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)

        prefix: Tuple[Any, ...] = (words,) if add_words else ()
        results: List[Tuple[Any, ...]] = []
        for line in out.stdout.decode("utf-8").split('\n'):
            if not line:
                continue  # blank lines / trailing newline
            fields = line.split('\t')
            # Skip malformed lines without a language column instead of raising
            # IndexError; keep only English sentences.
            if len(fields) < 5 or fields[4] != "en":
                continue
            if keep_score:
                results.append((*prefix, fields[1], float(fields[-1])))
            else:
                results.append((*prefix, fields[1]))
        return results


def flatten_2D_list(in_list: List[List[Any]]) -> Tuple[List[Any], List[int]]:
    """Flatten a list of lists and record the length of each sublist.

    Args:
        in_list: The nested list to flatten.

    Returns:
        A pair ``(flat, sizes)``. ``flat`` holds every element in order, and
        ``sizes[i]`` is ``len(in_list[i])``.
    """
    # Sizes are gathered in a first pass, before flattening, as a separate step.
    sizes: List[int] = [len(row) for row in in_list]
    flat: List[Any] = [item for row in in_list for item in row]
    return flat, sizes

def unflatten_list(in_list: List[Any], sizes: List[int]) -> List[List[Any]]:
    """Split a flat list back into consecutive sublists of the given sizes.

    This is the inverse of ``flatten_2D_list``: for a list of lists ``x``,
    ``unflatten_list(*flatten_2D_list(x)) == x``.

    Args:
        in_list: The flat sequence to split.
        sizes: Length of each output sublist, in order.

    Returns:
        A list whose i-th element holds the next ``sizes[i]`` items of ``in_list``.

    Raises:
        ValueError: if any size is negative, or the sizes do not add up to
            ``len(in_list)``.
    """
    sizes = list(sizes)  # allow any iterable; it is traversed more than once
    if any(s < 0 for s in sizes):
        raise ValueError('sizes must be non-negative')
    if sum(sizes) != len(in_list):
        raise ValueError(
            f'sizes sum to {sum(sizes)} but in_list has {len(in_list)} elements')

    out_list: List[List[Any]] = []
    start = 0
    for s in sizes:
        out_list.append(list(in_list[start:start + s]))
        start += s

    return out_list