import heapq
import math
import operator
import os
import pathlib
import pickle
import sys

import numpy as np
import joblib
import tqdm

import skemb.sgpattern
import skemb.tokenizer
import skemb.xyprobs


class Embedding:
    """Base interface for text embeddings.

    Subclasses implement ``train``/``save``/``load``/``extract``.  Optional
    diagnostic output is written to ``self.logfile`` when that attribute is
    set to a truthy file-like object.
    """

    def set_params(self, *k, **kw):
        """Accept and ignore parameters; subclasses override to configure."""

    def train(self, texts, sequences, labels):
        """Fit the embedding on a training corpus; must be overridden."""
        raise NotImplementedError()

    def save(self, path):
        """Persist the trained model under ``path``; must be overridden."""
        raise NotImplementedError()

    def load(self, path):
        """Restore a model previously written by ``save``; must be overridden."""
        raise NotImplementedError()

    def extract(self, text, seq):
        """Return the feature vector for one document; must be overridden."""
        raise NotImplementedError()

    def logging_enabled(self):
        """Return True when a truthy ``logfile`` attribute is set."""
        return bool(getattr(self, 'logfile', None))

    def log(self, *k, **kw):
        """``print`` the arguments to ``self.logfile``; no-op when disabled."""
        logfile = getattr(self, 'logfile', None)
        if not logfile:
            return
        # Always route to the log file, overriding any caller-supplied file=.
        kw['file'] = logfile
        print(*k, **kw)


class SGPatternEmbedding(Embedding):
    """Base class for embeddings built on mined sequential patterns.

    Same contract as ``Embedding`` except that ``train`` additionally
    receives the candidate ``patterns`` collection.
    """

    def train(self, texts, sequences, labels, patterns):
        """Fit the embedding using ``patterns``; must be overridden."""
        raise NotImplementedError


class BinaryBagOfPatterns(SGPatternEmbedding):
    """Binary indicator embedding: one slot per pattern, 1 if it matches."""

    def train(self, texts, sequences, labels, patterns):
        """Adopt ``patterns`` as the feature set (no fitting is performed)."""
        self.__patterns = patterns
        self.__indices = {k: i for i, k in enumerate(self.__patterns)}

    def save(self, path):
        """Write type tag, patterns and index map into a new directory."""
        path = pathlib.Path(path)
        path.mkdir()

        (path / 'type').write_text('BinaryBagOfPatterns')

        self.__patterns.save(path / 'patterns.pickle')
        with open(path / 'indices.pickle', 'wb') as f:
            pickle.dump(self.__indices, f)

    def load(self, path):
        """Restore a model written by ``save``.

        Uses pickle, so only load directories from a trusted source.
        """
        path = pathlib.Path(path)
        # Must be stored on the instance: extract() reads self.__patterns.
        self.__patterns = skemb.sgpattern.load(path / 'patterns.pickle')
        with open(path / 'indices.pickle', 'rb') as f:
            self.__indices = pickle.load(f)

    def extract(self, text, seq):
        """Return a 0/1 vector marking which patterns occur in ``seq``."""
        matched_patterns = set(pattern for _, pattern in self.__patterns.search(seq))
        r = np.zeros(len(self.__patterns))
        for p in matched_patterns:
            r[self.__indices[p]] = 1
        return r


class PatternsProbabilities(SGPatternEmbedding):
    """Embedding mixing pattern matches with per-attribute pattern probabilities.

    ``train`` keeps the ``dimension`` candidate patterns scored highest by
    ``XYProbs.x_info_gain`` and then fits an ``XYProbs`` relating sequence
    attributes (x) to matched selected patterns (y).  ``extract`` returns
    ``alpha * u1 + (1 - alpha) * u2``, where ``u1`` is the 0/1 match vector
    and ``u2`` averages ``probs.py(pattern, cond_x=attr)`` (presumably
    P(pattern | attr)) over the sequence's well-supported attributes.
    """

    def __init__(self):
        # Install defaults so train() works even if set_params() is never called.
        self.set_params()

    def __msg(self, *k, **kw):
        """Print a progress message to stderr."""
        kw['file'] = sys.stderr
        print(*k, **kw)

    def set_params(self, alpha=.5, min_support=10, dimension=100, multiclass=None):
        """Configure the embedding.

        alpha: weight of the binary match term (``1 - alpha`` weighs the
            probability term).
        min_support: attributes whose training count is not strictly greater
            than this are ignored by ``extract``.
        dimension: number of patterns kept by ``train``.
        multiclass: True selects patterns per label, False globally; None
            decides at training time (more than two labels -> True).
        """
        self.__alpha = alpha
        self.__min_support = min_support
        self.__dimension = dimension
        self.__multiclass = multiclass

    def train(self, texts, sequences, labels, patterns):
        """Select informative patterns and fit attribute/pattern probabilities."""
        # Materialise once: needed for len(), the parallel pass and the final pass
        # (a generator would otherwise be exhausted before the last pass).
        sequences = list(sequences)

        X = self.__match_patterns(sequences, patterns)
        probs = self.__label_probabilities(X, labels)

        if self.__multiclass is None:
            self.__multiclass = len(probs.y_count) > 2

        if self.__multiclass:
            selected = self.__select_multiclass(probs)
        else:
            selected = self.__select_global(probs)
        self.__patterns = skemb.sgpattern.SGPatterns(patterns.k, selected)
        self.__indices = {k: i for i, k in enumerate(self.__patterns)}

        if self.logging_enabled():
            self.log('SELECTED PATTERNS')
            for p in selected:
                self.log(f'{p}')

        self.__probs = self.__pattern_probabilities(sequences)
        self.__generate_indexed_probs()

    def __match_patterns(self, sequences, patterns):
        """Return, for each sequence, the set of candidate patterns it matches."""
        self.__msg('Matching patterns')

        def search(seq):
            return set(p for _, p in patterns.search(seq))

        # os.cpu_count() may return None.
        n_jobs = os.cpu_count() or 1
        return joblib.Parallel(
            n_jobs=n_jobs,
            verbose=10,
            # Roughly one batch per worker; joblib rejects 0 (empty input).
            batch_size=max(1, math.ceil(len(sequences) / n_jobs)),
            pre_dispatch='all',
        )(joblib.delayed(search)(seq) for seq in sequences)

    def __label_probabilities(self, X, labels):
        """Fit an XYProbs relating matched candidate patterns (x) to labels (y)."""
        self.__msg('Calculating label probabilities')
        probs = skemb.xyprobs.XYProbs()
        for xs, label in zip(tqdm.tqdm(X), labels):
            probs.update(xs, {label})
        return probs

    def __select_global(self, probs):
        """Keep the ``dimension`` patterns with the highest overall info gain."""
        scored = [(p, probs.x_info_gain(p)) for p in probs.x_count]
        best = heapq.nlargest(self.__dimension, scored, key=operator.itemgetter(1))
        return [p for p, _ in best]

    def __select_multiclass(self, probs):
        """Round-robin over labels, taking the best remaining patterns for each."""
        classes = list(probs.y_count)
        candidates = set(probs.x_count)
        selected = set()
        if not classes:
            return selected
        # At least one pattern per label per round: with dimension < number of
        # labels a quota of 0 would never make progress (infinite loop).
        quota = max(1, self.__dimension // len(classes))
        while candidates and len(selected) < self.__dimension:
            for y in classes:
                d = min(quota, self.__dimension - len(selected))
                if d <= 0 or not candidates:
                    break
                pairs = [(p, probs.x_info_gain(p, y=y)) for p in candidates]
                best = heapq.nlargest(d, pairs, key=operator.itemgetter(1))
                best = [p for p, _ in best]
                selected.update(best)
                candidates.difference_update(best)
        return selected

    def __pattern_probabilities(self, sequences):
        """Fit an XYProbs relating sequence attributes (x) to selected patterns (y)."""
        self.__msg('Calculating patterns probabilities')
        probs = skemb.xyprobs.XYProbs()
        for seq in tqdm.tqdm(sequences):
            ys = set(p for _, p in self.__patterns.search(seq))
            xs = set(attr for attrset in seq for attr in attrset)
            probs.update(xs, ys)
        return probs

    def extract(self, text, seq, return_terms=False):
        """Embed ``seq``.

        Returns ``alpha * u1 + (1 - alpha) * u2``, or the tuple
        ``(u1, u2, alpha)`` when ``return_terms`` is true.
        """
        n = len(self.__patterns)
        matched_patterns = set(pattern for _, pattern in self.__patterns.search(seq))

        u1 = np.zeros(n)
        for pattern, i in self.__indices.items():
            if pattern in matched_patterns:
                u1[i] = 1

        attrs = set(attr for attrset in seq for attr in attrset)
        # Skip attributes never seen in training (no column, and indexing
        # x_count with them could raise) and those below the support threshold.
        attr_indices = [
            self.__attr_indices[attr]
            for attr in attrs
            if attr in self.__attr_indices
            and self.__probs.x_count[attr] > self.__min_support
        ]
        if attr_indices:
            u2 = self.__indexed_prob_y_cond_x[:, attr_indices].mean(axis=1)
        else:
            u2 = np.zeros(n)

        if return_terms:
            return u1, u2, self.__alpha
        return self.__alpha * u1 + (1 - self.__alpha) * u2

    def save(self, path):
        """Write type tag, patterns, probabilities, indices and params to a new dir."""
        path = pathlib.Path(path)
        path.mkdir()

        (path / 'type').write_text('PatternsProbabilities')

        self.__patterns.save(path / 'patterns.pickle')
        self.__probs.save(path / 'probs.pickle')

        with open(path / 'indices.pickle', 'wb') as f:
            pickle.dump(self.__indices, f)

        params = (self.__alpha, self.__min_support)
        with open(path / 'params.pickle', 'wb') as f:
            pickle.dump(params, f)

    def load(self, path):
        """Restore a model written by ``save`` (pickle: trusted sources only)."""
        path = pathlib.Path(path)
        self.__patterns = skemb.sgpattern.load(path / 'patterns.pickle')
        self.__probs = skemb.xyprobs.load(path / 'probs.pickle')

        with open(path / 'indices.pickle', 'rb') as f:
            self.__indices = pickle.load(f)

        with open(path / 'params.pickle', 'rb') as f:
            self.__alpha, self.__min_support = pickle.load(f)

        self.__generate_indexed_probs()

    def __generate_indexed_probs(self):
        """Cache ``probs.py(pattern, cond_x=attr)`` in a (patterns x attrs) matrix."""
        self.__attr_indices = {
            attr: i for i, attr in enumerate(self.__probs.x_count)
        }

        matrix = np.zeros((len(self.__indices), len(self.__attr_indices)))
        for pattern, pattern_idx in self.__indices.items():
            for attr, attr_idx in self.__attr_indices.items():
                matrix[pattern_idx, attr_idx] = self.__probs.py(pattern, cond_x=attr)
        self.__indexed_prob_y_cond_x = matrix


class BagOfWords(Embedding):
    """Binary bag-of-words over the tokens observed during training."""

    def train(self, texts, sequences, labels):
        """Assign one vector slot to each distinct token in ``sequences``."""
        vocabulary = {token for tokens in sequences for token in tokens}
        self.__indices = dict(zip(vocabulary, range(len(vocabulary))))

    def save(self, path):
        """Pickle the token index map into a newly created directory."""
        target = pathlib.Path(path)
        target.mkdir()
        with (target / 'indices.pickle').open('wb') as fh:
            pickle.dump(self.__indices, fh)

    def load(self, path):
        """Read back the token index map written by ``save``."""
        source = pathlib.Path(path) / 'indices.pickle'
        with source.open('rb') as fh:
            self.__indices = pickle.load(fh)

    def extract(self, text, seq):
        """Return a 0/1 vector marking known tokens present in ``seq``."""
        vector = np.zeros(len(self.__indices))
        for token in set(seq):
            slot = self.__indices.get(token)
            if slot is not None:
                vector[slot] = 1
        return vector


class Tok1BagOfWords(BagOfWords):
    """Bag of words over ``skemb.tokenizer.tok1`` tokens of the raw text."""

    def train(self, texts, sequences, labels):
        """Build the vocabulary from ``tok1(text)``; ``sequences`` is ignored."""
        sequences = [
            skemb.tokenizer.tok1(text) for text in texts
        ]
        super().train(texts, sequences, labels)

    def extract(self, text, seq):
        """Embed ``text`` tokenised with ``tok1``, exactly as in ``train``.

        ``seq`` is ignored so that features always line up with the vocabulary.
        """
        return super().extract(text, skemb.tokenizer.tok1(text))

    def save(self, path):
        """Save the vocabulary and tag the directory as ``Tok1BagOfWords``."""
        # Accept str paths too: the '/' operator below needs a Path.
        path = pathlib.Path(path)
        super().save(path)
        (path / 'type').write_text('Tok1BagOfWords')


class Tok2BagOfWords(BagOfWords):
    """Bag of words over ``skemb.tokenizer.tok2`` tokens of the raw text."""

    def train(self, texts, sequences, labels):
        """Build the vocabulary from ``tok2(text)``; ``sequences`` is ignored."""
        sequences = [
            skemb.tokenizer.tok2(text) for text in texts
        ]
        super().train(texts, sequences, labels)

    def extract(self, text, seq):
        """Embed ``text`` tokenised with ``tok2``, exactly as in ``train``.

        ``seq`` is ignored so that features always line up with the vocabulary.
        """
        return super().extract(text, skemb.tokenizer.tok2(text))

    def save(self, path):
        """Save the vocabulary and tag the directory as ``Tok2BagOfWords``."""
        # Accept str paths too: the '/' operator below needs a Path.
        path = pathlib.Path(path)
        super().save(path)
        (path / 'type').write_text('Tok2BagOfWords')


def load(path):
    """Load a model directory written by one of the embeddings' ``save`` methods.

    The class to instantiate is read from the directory's ``type`` file.
    Model files are unpickled, so only load directories from a trusted source.

    Raises:
        ValueError: if the recorded type is unknown.  This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still work.
    """
    path = pathlib.Path(path)
    # Tolerate surrounding whitespace such as a trailing newline from a hand edit.
    model_type = (path / 'type').read_text().strip()

    if model_type == 'BinaryBagOfPatterns':
        model = BinaryBagOfPatterns()
    elif model_type == 'PatternsProbabilities':
        model = PatternsProbabilities()
    elif model_type == 'Tok1BagOfWords':
        model = Tok1BagOfWords()
    elif model_type == 'Tok2BagOfWords':
        model = Tok2BagOfWords()
    else:
        raise ValueError(f'unknown model type: {model_type}')

    model.load(path)
    return model
