import collections
import itertools
import math
import pickle


class XYProbs:
    """Empirical co-occurrence statistics between x items and y labels.

    Each call to :meth:`update` records one observation made of a set of x
    items and a set of y labels. All probabilities are plain count ratios
    over the observations recorded so far (no smoothing), so querying an
    event that has never been observed raises ``ZeroDivisionError``.
    """

    def __init__(self, default_xy_score_alpha=.5):
        """Create an empty model.

        Args:
            default_xy_score_alpha: Exponent used by :meth:`xy_score` when
                no explicit ``alpha`` is passed.
        """
        # Number of observations recorded via update().
        self.total = 0
        # Number of observations containing each x, each y, and each
        # (x, y) pair respectively.
        self.x_count = collections.Counter()
        self.y_count = collections.Counter()
        self.xy_count = collections.Counter()
        # y -> set of every x ever observed together with it.
        self.y_xs = {}
        # x -> set of every y ever observed together with it.
        self.x_ys = {}
        self.default_xy_score_alpha = default_xy_score_alpha

    def update(self, xs, ys):
        """Record one observation of x items ``xs`` with y labels ``ys``.

        Duplicates inside ``xs`` or ``ys`` are collapsed first, so each
        count grows by at most one per call.
        """
        xs = set(xs)
        ys = set(ys)
        self.total += 1
        self.x_count.update(xs)
        self.y_count.update(ys)
        self.xy_count.update(itertools.product(xs, ys))

        for y in ys:
            self.y_xs.setdefault(y, set()).update(xs)
        for x in xs:
            self.x_ys.setdefault(x, set()).update(ys)

    def py(self, y, cond_x=None, cond_not_x=None):
        """Return P(y), P(y | x) or P(y | not x).

        Args:
            y: Label whose probability is requested.
            cond_x: If given, condition on observations containing this x.
                Takes precedence when ``cond_not_x`` is also given.
            cond_not_x: If given (and ``cond_x`` is not), condition on
                observations that do not contain this x.

        Raises:
            ZeroDivisionError: If nothing has been recorded, or the
                conditioning event has never occurred.
        """
        if cond_x is not None:
            return self.xy_count[cond_x, y] / self.x_count[cond_x]
        if cond_not_x is not None:
            x = cond_not_x
            # P(y | not x) = #(y and not x) / #(not x). Use integer counts
            # so the result is rounded once, instead of going through the
            # cancellation-prone float expression (P(y) - P(x, y)) / (1 - P(x)).
            return ((self.y_count[y] - self.xy_count[x, y])
                    / (self.total - self.x_count[x]))
        return self.y_count[y] / self.total

    def pny(self, y, cond_x=None, cond_not_x=None):
        """Return ``1 - py(y, cond_x, cond_not_x)``, the probability y is absent."""
        return 1 - self.py(y, cond_x, cond_not_x)

    def px(self, x, cond_y=None, cond_not_y=None):
        """Return P(x), P(x | y) or P(x | not y).

        Args:
            x: Item whose probability is requested.
            cond_y: If given, condition on observations containing this y.
                Takes precedence when ``cond_not_y`` is also given.
            cond_not_y: If given (and ``cond_y`` is not), condition on
                observations that do not contain this y.

        Raises:
            ZeroDivisionError: If nothing has been recorded, or the
                conditioning event has never occurred.
        """
        if cond_y is not None:
            return self.xy_count[x, cond_y] / self.y_count[cond_y]
        if cond_not_y is not None:
            y = cond_not_y
            # Exact-count form of (P(x) - P(x, y)) / (1 - P(y)); see py().
            return ((self.x_count[x] - self.xy_count[x, y])
                    / (self.total - self.y_count[y]))
        return self.x_count[x] / self.total

    def pnx(self, x, cond_y=None, cond_not_y=None):
        """Return ``1 - px(x, cond_y, cond_not_y)``, the probability x is absent."""
        return 1 - self.px(x, cond_y, cond_not_y)

    def pxy(self, x, y):
        """Return the joint probability P(x, y) that an observation has both."""
        return self.xy_count[x, y] / self.total

    def xy_score(self, x, y, alpha=None):
        """Return ``P(y | x) ** alpha * P(x | y) ** (1 - alpha)``.

        This is a weighted geometric mean of the two conditionals. ``alpha``
        defaults to ``self.default_xy_score_alpha``.

        Raises:
            ZeroDivisionError: If ``x`` or ``y`` has never been observed.
        """
        if alpha is None:
            alpha = self.default_xy_score_alpha
        y_term = self.py(y, cond_x=x)**alpha
        x_term = self.px(x, cond_y=y)**(1 - alpha)
        return y_term * x_term

    def x_info_gain(self, x, y=None):
        """Return how much observing ``x`` shifts label probabilities (nats).

        With ``y`` given, this is the KL divergence of the binary variable
        "y present" conditioned on x from its marginal:
        ``sum over v in {y, not y} of P(v | x) * log(P(v | x) / P(v))``.

        With ``y`` omitted, it sums ``P(y_i | x) * log(P(y_i | x) / P(y_i))``
        over every label seen so far. Labels can co-occur in one observation,
        so these probabilities need not form a single distribution.

        Terms with zero conditional probability are skipped (0 * log 0 == 0).

        Raises:
            ZeroDivisionError: If ``x`` has never been observed and at least
                one term is evaluated.
        """
        r = 0
        if y is None:
            for y_i in self.y_count:
                p_cond = self.py(y_i, cond_x=x)
                if p_cond == 0:
                    continue
                r += p_cond * math.log(p_cond / self.py(y_i))
        else:
            p_y_cond_x = self.py(y, cond_x=x)
            if p_y_cond_x != 0:
                r += p_y_cond_x * math.log(p_y_cond_x / self.py(y))

            p_not_y_cond_x = self.pny(y, cond_x=x)
            # If P(not y) == 0 then every observation has y, so
            # P(not y | x) is exactly 0 too and this term is skipped.
            if p_not_y_cond_x != 0:
                r += p_not_y_cond_x * math.log(p_not_y_cond_x / self.pny(y))

        return r

    def save(self, path):
        """Pickle this whole instance to ``path``, overwriting the file.

        The file can be read back with this module's ``load``. Pickle files
        must only be loaded from trusted sources.
        """
        with open(path, 'wb') as f:
            pickle.dump(self, f)


def load(path):
    """Unpickle and return the object stored at ``path``.

    Typically used to read back a file written by ``XYProbs.save``.

    Warning: unpickling can execute arbitrary code, so only load files that
    come from a trusted source.
    """
    with open(path, mode='rb') as stream:
        obj = pickle.load(stream)
    return obj
