import collections
import heapq
import operator
import sys

import tqdm

import skemb.xyprobs
import skemb.sgpattern


class SGPatternMiner:
    """Iteratively mine skip-gram patterns that discriminate sequence labels.

    A pattern is a tuple of ``frozenset`` objects, each holding attributes.
    Mining starts from single-attribute patterns and, on every iteration,
    extends the patterns selected in the previous iteration by one attribute
    (either added to the first/last itemset or as a new itemset to the left
    or right).  Extended patterns are kept only when they have enough support
    and improve the score of a pattern they were derived from.

    Args:
        sequences: Iterable of sequences.  Each sequence must be indexable
            and each of its items must be a set-like collection of hashable
            attributes (items are iterated and subtracted from frozensets).
        labels: Iterable with one hashable label per sequence.
        str_representations: Optional iterable with one printable object per
            sequence, used only in log output.  When ``None`` (or empty) the
            sequences themselves are logged.
        skipgram_k: Forwarded to ``skemb.sgpattern.SGPatterns``; new itemsets
            are searched up to ``skipgram_k + 1`` positions away from the
            ends of a match.
        min_support: Minimum count (``x_count``) an attribute or pattern needs
            in order to be considered.
        max_iterations: Number of extend/filter/select rounds performed.
        pattern_buffer_size: Number of most recent candidate sets kept
            available for selection.
        xy_score_alpha: Forwarded to ``skemb.xyprobs.XYProbs`` as
            ``default_xy_score_alpha``.
        min_xy_score_diff: Minimum score gain an extended pattern must show
            over a pattern it was derived from.

    Attributes:
        logfile: File-like object receiving the detailed log, or ``None``
            (the default) to disable it.
        msgfile: File-like object receiving progress messages (defaults to
            ``sys.stderr``); a falsy value disables them.
    """

    # Thresholds applied by the final per-sequence selection.
    FINAL_MIN_INFO_GAIN = 0.5
    FINAL_TOP_N = 5

    def __init__(self,
                 sequences,
                 labels,
                 str_representations=None,
                 skipgram_k=5,
                 min_support=10,
                 max_iterations=10,
                 pattern_buffer_size=6,
                 xy_score_alpha=.8,
                 min_xy_score_diff=.05):
        self.sequences = list(sequences)
        # ``None`` is the documented default and must not be passed to list().
        self.str_representations = (
            None if str_representations is None else list(str_representations)
        )
        # Labels are zipped against the sequences once per pass, so a one-shot
        # iterable has to be materialised just like the sequences are.
        self.labels = list(labels)
        self.skipgram_k = skipgram_k
        self.min_support = min_support
        self.max_iterations = max_iterations
        self.logfile = None
        self.msgfile = sys.stderr
        self.pattern_buffer_size = pattern_buffer_size
        self.xy_score_alpha = xy_score_alpha
        self.min_xy_score_diff = min_xy_score_diff

    def __attribute_is_valid(self, attr):
        """Return whether ``attr`` occurs in at least ``min_support`` sequences."""
        return self.__single_attributes_probs.x_count[attr] >= self.min_support

    def run(self):
        """Run the mining loop and return the final patterns.

        Returns:
            A ``skemb.sgpattern.SGPatterns`` built from ``skipgram_k`` and the
            patterns that survived the final selection.
        """
        self.__calc_single_attributes_probs()
        selected_patterns = set()
        last_selection = set()
        last_selection_scores = dict()
        # Bounded buffer: the oldest candidate set is dropped automatically.
        patterns_buffer = collections.deque([], self.pattern_buffer_size)
        for it in range(self.max_iterations):
            self.__it = it
            self.__msg(f'Running iteration #{it}')
            # Generate new patterns: on the first iteration we gather every
            # attribute found in the items of the sequences. Otherwise, we
            # take the last selection and extend the patterns in it.
            self.__msg(f'it#{it}: Finding new patterns')
            if it == 0:
                new_patterns = self.__patterns_for_first_iteration()
                new_probs = self.__calc_probs(new_patterns)
                prev_map = dict()
            else:
                new_patterns, new_probs, prev_map = self.__extend_patterns(
                    last_selection,
                    selected_patterns,
                    patterns_buffer,
                )

            self.__msg(f'it#{it}: Found {len(new_patterns)} new patterns')

            # Filter new patterns
            self.__msg(f'it#{it}: Filtering new patterns')
            new_patterns = {
                pattern for pattern in new_patterns
                if new_probs.x_count[pattern] >= self.min_support
                and self.__improves_on_parents(
                    pattern, new_probs, prev_map, last_selection_scores)
            }
            self.__msg(f'it#{it}: {len(new_patterns)} new patterns after filtering')

            # Push new patterns to the buffer
            patterns_buffer.append((new_patterns, new_probs))

            # Select patterns
            self.__msg(f'it#{it}: Performing pattern selection')
            last_selection, last_selection_scores = self.__select_patterns(patterns_buffer)
            self.__msg(f'it#{it}: Selected {len(last_selection)} patterns')
            selected_patterns.update(last_selection)

            # Clear selected patterns from the buffer
            for patterns, _ in patterns_buffer:
                patterns.difference_update(last_selection)

        selected_patterns = self.__final_selection(selected_patterns)
        self.__msg(f'Final selection has {len(selected_patterns)} patterns')

        return skemb.sgpattern.SGPatterns(self.skipgram_k, selected_patterns)

    def __improves_on_parents(self, pattern, new_probs, prev_map, prev_scores):
        """Return whether ``pattern`` beats a pattern it was extended from.

        A pattern without an entry in ``prev_map`` (e.g. any first-iteration
        pattern) is always accepted.  Otherwise it is accepted when, for at
        least one parent pattern and one label recorded for that parent in
        ``prev_scores``, its score exceeds the parent's score by at least
        ``min_xy_score_diff``.
        """
        if pattern not in prev_map:
            return True
        return any(
            new_probs.xy_score(pattern, label) - prev_score >= self.min_xy_score_diff
            for prev_pattern in prev_map[pattern]
            for label, prev_score in prev_scores[prev_pattern].items()
        )

    def __final_selection(self, patterns):
        """Keep, per sequence, the best-scoring informative patterns.

        Probabilities are recomputed for ``patterns`` over all sequences.  For
        every sequence, the patterns found in it whose ``x_info_gain`` is at
        least ``FINAL_MIN_INFO_GAIN`` are ranked by their score for the
        sequence's label and the top ``FINAL_TOP_N`` are kept.

        Returns:
            The set with the union of the patterns kept for each sequence.
        """
        self.__msg('Calculating probabilities for final selection of patterns')
        probs = self.__calc_probs(patterns)

        self.__log('')
        self.__log('FINAL PATTERN SELECTION')
        self.__log('=======================')

        finder = skemb.sgpattern.SGPatterns(self.skipgram_k, patterns)

        final_selection = set()

        str_representations = self.str_representations or self.sequences
        rows = zip(tqdm.tqdm(self.sequences), self.labels, str_representations)
        for seq, label, str_representation in rows:
            found = set(pattern for _, pattern in finder.search(seq))
            candidates = (
                (p, probs.xy_score(p, label))
                for p in found
                if probs.x_info_gain(p) >= self.FINAL_MIN_INFO_GAIN
            )
            # nlargest copes with an empty iterable by returning [].
            selected = heapq.nlargest(
                self.FINAL_TOP_N, candidates, key=operator.itemgetter(1))

            final_selection.update(pattern for pattern, _ in selected)

            # Logging
            self.__log(f'>>> {str_representation!s}')
            self.__log(f'    label={label}')
            for selected_idx, (pattern, score) in enumerate(selected):
                self.__log(f'    selected[{selected_idx}]={pattern} {score:.4f}')

        # The report below is only useful (and only visible) with a logfile.
        if self.logfile:
            self.__log('')
            self.__log('RESULT')
            self.__log('======')
            for y in probs.y_count:
                xs = set(probs.y_xs[y]) & final_selection
                for x in sorted(xs, key=lambda x: -probs.xy_score(x, y)):
                    self.__log(' '.join(self.__probs_fields(probs, x, y)))

        return final_selection

    @staticmethod
    def __probs_fields(probs, x, y):
        """Return the ``key=value`` log fields describing ``x`` against ``y``."""
        # NOTE(review): ``px=`` previously logged ``probs.py(y)``. The fix
        # assumes XYProbs.px accepts ``x`` alone, mirroring ``py(y)`` -
        # confirm against skemb.xyprobs.
        return (
            f'sc={probs.xy_score(x, y):.4f}',
            f'py_cx={probs.py(y, cond_x=x):.4f}',
            f'px_cy={probs.px(x, cond_y=y):.4f}',
            f'x_ig={probs.x_info_gain(x):.4f}',
            f'py={probs.py(y):.4f}',
            f'px={probs.px(x):.4f}',
            f'y={y}',
            f'x={x}',
            f'ny={probs.y_count[y]}',
            f'nx={probs.x_count[x]}',
        )

    def __patterns_for_first_iteration(self):
        """Return one single-attribute pattern per sufficiently frequent attribute."""
        attrs = {
            attr
            for seq in self.sequences
            for attrset in seq
            for attr in attrset
        }
        return {
            (frozenset({attr}),)
            for attr in attrs
            if self.__attribute_is_valid(attr)
        }

    def __extensions(self, seq, indices, pattern, max_dist):
        """Yield the one-attribute extensions of ``pattern`` at one match.

        ``indices`` are the positions in ``seq`` where ``pattern`` matched;
        only the first and last positions are used.  Attributes failing the
        minimum-support check are skipped.  The same extension may be yielded
        more than once.
        """
        first, last = indices[0], indices[-1]
        is_valid = self.__attribute_is_valid

        # Grow the first itemset with attributes present at its position.
        for attr in seq[first] - pattern[0]:
            if is_valid(attr):
                yield (pattern[0] | {attr},) + pattern[1:]

        # Grow the last itemset with attributes present at its position.
        for attr in seq[last] - pattern[-1]:
            if is_valid(attr):
                yield pattern[:-1] + (pattern[-1] | {attr},)

        # Prepend a new itemset taken from up to max_dist positions to the left.
        for i in range(max(0, first - max_dist), first):
            for attr in seq[i]:
                if is_valid(attr):
                    yield (frozenset({attr}),) + pattern

        # Append a new itemset taken from up to max_dist positions to the right.
        for i in range(last + 1, min(len(seq), last + 1 + max_dist)):
            for attr in seq[i]:
                if is_valid(attr):
                    yield pattern + (frozenset({attr}),)

    def __extend_patterns(self, patterns, selected_patterns, patterns_buffer):
        """Extend ``patterns`` by one attribute wherever they occur.

        Args:
            patterns: Patterns to extend.
            selected_patterns: Patterns already selected; never re-proposed.
            patterns_buffer: Iterable of ``(patterns, probs)`` pairs whose
                patterns are also never re-proposed.

        Returns:
            A tuple ``(new_patterns, new_probs, prev_map)``: the set of new
            patterns, an ``XYProbs`` updated with the new patterns of each
            sequence and its label, and a mapping from each generated pattern
            to the set of patterns it was extended from.
        """
        finder = skemb.sgpattern.SGPatterns(self.skipgram_k, patterns)
        new_probs = skemb.xyprobs.XYProbs(default_xy_score_alpha=self.xy_score_alpha)
        max_dist = self.skipgram_k + 1

        new_patterns = set()
        prev_map = collections.defaultdict(set)
        for seq, label in zip(tqdm.tqdm(self.sequences), self.labels):
            seq_new_patterns = set()
            for indices, pattern in finder.search(seq):
                for new_pattern in self.__extensions(seq, indices, pattern, max_dist):
                    seq_new_patterns.add(new_pattern)
                    prev_map[new_pattern].add(pattern)

            # Make sure we remove patterns already processed
            seq_new_patterns.difference_update(selected_patterns)
            for patterns_from_buffer, _ in patterns_buffer:
                seq_new_patterns.difference_update(patterns_from_buffer)

            new_patterns.update(seq_new_patterns)
            new_probs.update(seq_new_patterns, {label})

        return new_patterns, new_probs, prev_map

    def __calc_probs(self, patterns):
        """Return an ``XYProbs`` updated, per sequence, with the patterns found in it and its label."""
        probs = skemb.xyprobs.XYProbs(default_xy_score_alpha=self.xy_score_alpha)
        finder = skemb.sgpattern.SGPatterns(self.skipgram_k, patterns)

        for seq, label in zip(tqdm.tqdm(self.sequences), self.labels):
            xs = set(pattern for _, pattern in finder.search(seq))
            probs.update(xs, {label})

        return probs

    def __select_patterns(self, patterns_buffer):
        """Pick, for every sequence, the buffered pattern scoring best for its label.

        Args:
            patterns_buffer: Iterable of ``(patterns, probs)`` pairs.

        Returns:
            A tuple ``(selected_patterns, selection_scores)``: the set of
            chosen patterns and a mapping ``pattern -> {label: score}`` with
            the score each pattern had when it was chosen.
        """
        selected_patterns = set()
        selection_scores = collections.defaultdict(dict)

        finders = [
            skemb.sgpattern.SGPatterns(self.skipgram_k, patterns)
            for patterns, _ in patterns_buffer
        ]

        log_title = f'PATTERN SELECTION (it: #{self.__it})'
        self.__log(log_title)
        self.__log('=' * len(log_title))

        str_representations = self.str_representations or self.sequences
        rows = zip(tqdm.tqdm(self.sequences), self.labels, str_representations)

        for seq, label, str_representation in rows:
            # Score every buffered pattern found in this sequence.
            candidates = dict()
            candidate_probs = dict()  # only needed for the log output
            for finder, (_, probs) in zip(finders, patterns_buffer):
                found = set(pattern for _, pattern in finder.search(seq))
                for pattern in found:
                    candidates[pattern] = probs.xy_score(pattern, label)
                    candidate_probs[pattern] = probs

            # Select best pattern
            if candidates:
                selected, score = max(candidates.items(), key=operator.itemgetter(1))
                selected_patterns.add(selected)
                selection_scores[selected][label] = score
            else:
                selected = None

            # Logging
            self.__log(f'>>> {str_representation!s}')
            self.__log(f'    label={label}')
            self.__log(f'    selected={selected}')
            if self.logfile:
                ranked = sorted(
                    candidates.items(), key=operator.itemgetter(1), reverse=True)
                for cand_idx, (cand, score) in enumerate(ranked[:5]):
                    line = (
                        f'cand[{cand_idx}]={cand}',
                        f'score={score:.4f}',
                        f'nx={candidate_probs[cand].x_count[cand]}',
                    )
                    self.__log(f'    {" ".join(line)}')
            self.__log('')

        return selected_patterns, selection_scores

    def __calc_single_attributes_probs(self):
        """Compute, store and log the ``XYProbs`` of individual attributes."""
        probs = skemb.xyprobs.XYProbs(default_xy_score_alpha=self.xy_score_alpha)

        for seq, label in zip(self.sequences, self.labels):
            xs = set(attr for attrset in seq for attr in attrset)
            probs.update(xs, {label})

        self.__single_attributes_probs = probs
        self.__log_single_attributes_probs()

    def __log_single_attributes_probs(self):
        """Write one ``SAP`` log line per (label, attribute) pair, best score first."""
        if not self.logfile:
            return

        self.__log('SINGLE ATTRIBUTES PROBS')
        probs = self.__single_attributes_probs
        for y in probs.y_count:
            for x in sorted(probs.y_xs[y], key=lambda x: -probs.xy_score(x, y)):
                line = ('SAP',) + self.__probs_fields(probs, x, y)
                self.__log(' '.join(line))

    def __log(self, *k, **kw):
        """``print`` to ``self.logfile``; does nothing when it is unset."""
        if not self.logfile:
            return
        kw['file'] = self.logfile
        print(*k, **kw)

    def __msg(self, *k, **kw):
        """``print`` to ``self.msgfile``; does nothing when it is unset."""
        if not self.msgfile:
            return
        kw['file'] = self.msgfile
        print(*k, **kw)
