from collections import defaultdict

from common import cyk_spans
from rule import Sym, Rule, nt_escape

class Node(object):
    def init_leaf(self, i, sym, rule, binarizer):
        self.sym = sym
        self.tag = sym.sym
        self.i = i  # start of span
        self.j = i + 1  # end of span
        self.rule = rule
        self.binarizer = binarizer
        self.is_leaf = True
        self.lexical = not self.sym.isvar  # lexical nodes cover only terminals

        if self.sym.isvar:
            self.min = self.max = self.rule.f2e[self.rule.sym_idx2nt_idx[i]]

        # init measures
        self.binarizable = 0  # 0 for True and 1 for False
        self.linear = 0
        self.scope = 0
        self.exp = 0  # expected number of edges to be built
                      # for the binarization tree rooted at this node
        self.new_nt = 0

    def init_nonleaf(self, lnode, rnode):
        """Set this node up as the combination of `lnode` and `rnode`.

        The node spans from lnode.i to rnode.j, gets a virtual symbol named
        after the escaped tags of both children, and inherits the rule and
        binarizer the children share (asserted to be the same objects).
        All measures are then computed from the children.
        """
        escaped = (nt_escape(lnode.tag), nt_escape(rnode.tag))
        self.tag = '%s+%s' % escaped
        self.sym = Sym('[V_%s]' % self.tag)
        self.lnode, self.rnode = lnode, rnode
        self.i = lnode.i
        self.j = rnode.j
        assert lnode.rule is rnode.rule
        self.rule = lnode.rule
        assert lnode.binarizer is rnode.binarizer
        self.binarizer = lnode.binarizer
        self.is_leaf = False
        # a node is lexical only when everything below it is
        self.lexical = lnode.lexical and rnode.lexical

        # compute measures, in this fixed order
        for compute in (self.compute_binarizable,
                        self.compute_linear,
                        self.compute_scope,
                        self.compute_exp,
                        self.compute_new_nts):
            compute()

    def leaf(self):
        """Return the `is_leaf` flag set by init_leaf / init_nonleaf."""
        return self.is_leaf

    def __str__(self):
        """Render a leaf as its symbol and an inner node as '(left right)'."""
        if not self.leaf():
            return '(%s %s)' % (self.lnode, self.rnode)
        return '%s' % self.sym

    def info(self):
        """Return a one-line debug summary of this node.

        Format: "[i,j] <tree> (binarizable, linear, scope, exp)".
        """
        span_and_tree = (self.i, self.j, str(self))
        measures = (self.binarizable, self.linear, self.scope, self.exp)
        return "[%s,%s] %s (%s, %s, %s, %s)" % (span_and_tree + measures)

    def __lt__(self, other):
        self_cost = ()
        other_cost = ()
        for c in self.binarizer.cost_function.split():
            if c == 'b':
                self_cost += (self.binarizable,)
                other_cost += (other.binarizable,)
            elif c == 'l':
                self_cost += (self.linear,)
                other_cost += (other.linear,)
            elif c == 'e':
                self_cost += (self.exp,)
                other_cost += (other.exp,)
            elif c == 'n':
                self_cost += (self.new_nt,)
                other_cost += (other.new_nt,)
            elif c == '-n':
                self_cost += (-self.new_nt,)
                other_cost += (-other.new_nt,)
            else:
                assert False, 'illegal cost function: %s' % c
        return self_cost < other_cost

    def iter_nodes(self):
        """Yield every node of this subtree in post-order.

        The left subtree comes first, then the right subtree, and finally
        this node itself.
        """
        if self.leaf():
            yield self
            return
        for child in (self.lnode, self.rnode):
            for descendant in child.iter_nodes():
                yield descendant
        yield self

    def monotonic(self):
        """Return True if only straight (inverted) rules are present.

        As implemented: a leaf is always monotonic; otherwise every non-leaf
        node in this subtree must carry the same `inverted` flag as this
        node.
        """
        # TODO: wrong
        if self.leaf():
            return True
        orientation = self.inverted
        return all(n.inverted == orientation
                   for n in self.iter_nodes()
                   if not n.leaf())

    # ------- internal methods: class users usually do not need these -------

    def compute_binarizable(self):
        """A node is binarizable if and only if:
        1. both of its children are binarizable
        2. its children cover consecutive spans of nonterminals on the e side

        Also, unaligned (lexical) nodes are always binarizable."""
        if self.lnode.binarizable == 0 and self.rnode.binarizable == 0:
            if self.lexical:
                self.binarizable = 0
                self.inverted = False
            elif self.lnode.lexical:
                self.min, self.max = self.rnode.min, self.rnode.max
                self.binarizable = 0
                self.inverted = False
            elif self.rnode.lexical:
                self.min, self.max = self.lnode.min, self.lnode.max
                self.binarizable = 0
                self.inverted = False
            else:
                if self.lnode.max == self.rnode.min - 1:
                    self.min = self.lnode.min
                    self.max = self.rnode.max
                    self.binarizable = 0
                    self.inverted = False
                elif self.rnode.max == self.lnode.min - 1:
                    self.min = self.rnode.min
                    self.max = self.lnode.max
                    self.binarizable = 0
                    self.inverted = True
                else:
                    self.binarizable = 1
        else:
            self.binarizable = 1

    def compute_linear(self):
        """A node is linear if and only if it combines a linear node and
        a leaf node."""
        if (self.lnode.linear == 0 and self.rnode.leaf()) or \
           (self.rnode.linear == 0 and self.lnode.leaf()):
            self.linear = 0
        else:
            self.linear = 1

    def compute_scope(self):
        """Double-counting scope computation."""
        span = [sym.isvar for sym in self.rule.f[self.i:self.j]]
        scope = sum(l and r for l, r in zip([True] + span, span + [True]))
        self.scope = self.lnode.scope + self.rnode.scope + scope

    def compute_exp(self):
        """Compute expected number of edges to be built following this path.
        """
        p = self.binarizer.node_prob[(self.i, self.j)]
        self.exp = self.lnode.exp + self.rnode.exp + p

    def compute_new_nts(self):
        """New nonterminals generated by this path."""
        self.new_nt = self.lnode.new_nt + self.rnode.new_nt
        if self.sym not in self.binarizer.binarized_nt_count:
            self.new_nt += 1

class SCFGBinarizer(object):
    def __init__(self, freq_file, early_attach, cost_function):
        """The binarizer counts freq_file for source side word prob estimates.
        """
        self.early_attach = early_attach
        self.cost_function = cost_function
        count = {}
        f = open(freq_file)
        for line in f:
            for word in line.split():
                if word in count:
                    count[word] += 1
                else:
                    count[word] = 1
        f.close()
        total = sum(c for c in count.values())
        self.word_prob = {}
        for w, c in count.items():
            self.word_prob[w] = float(c)/total

        # stats
        self.processed = defaultdict(int)
        self.binarizable = defaultdict(int)
        self.linear = defaultdict(int)
        self.monotonic = defaultdict(int)
        self.max_nt_count = 0  # max number of rhs nts in any rule

        self.original_nt_count = {}
        self.binarized_nt_count = {}
        self.lbinarized_nt_count = defaultdict(int)
        self.rbinarized_nt_count = defaultdict(int)
        self.lbinarized_nt_count_binarizable_only = defaultdict(int)
        self.rbinarized_nt_count_binarizable_only = defaultdict(int)
                
    def binarize(self, rule):
        self.count_nts(rule, self.original_nt_count)
        self.count_left_right_binarization_nts(rule,
                                               self.lbinarized_nt_count,
                                               self.rbinarized_nt_count)
        rhs_nts = len(rule.e2f)
        if rhs_nts > self.max_nt_count:
            self.max_nt_count = rhs_nts
        self.processed[rhs_nts] += 1
        if len(rule.f) <= 2:
            rules = [rule]
            self.binarizable[rhs_nts] += 1
            self.linear[rhs_nts] += 1
            self.monotonic[rhs_nts] += 1
        else:
            self.prepare_rule(rule)
            self.precompute_node_probs(rule)
            N = len(rule.f)
            chart = [[None for j in range(N+1)] for i in range(N+1)]
            # initialization
            for i in range(N):
                node = Node()
                node.init_leaf(i, rule.f[i], rule, self)
                # print(node.info())
                chart[i][i+1] = node
            # recursion
            for i, j in cyk_spans(N):
                # print(i, j)
                for k in range(i+1, j):
                    # print(k)
                    node = Node()
                    node.init_nonleaf(chart[i][k], chart[k][j])
                    # print(node.info())
                    old_node = chart[i][j]
                    if old_node is None or node < old_node:
                        # print('better')
                        chart[i][j] = node
            root = chart[0][N]
            if root.binarizable == 1:
                rules = []
            else:
                self.binarizable[rhs_nts] += 1
                if root.linear == 0:
                    self.linear[rhs_nts] += 1
                root.sym = rule.lhs
                self.postprocess(root)
                rules = self.generate_rules(root)
        if rules != []:  # binarizable
            self.count_left_right_binarization_nts(
                rule,
                self.lbinarized_nt_count_binarizable_only,
                self.rbinarized_nt_count_binarizable_only)
        for r in rules:
            self.count_nts(r, self.binarized_nt_count)
        return rules

    # ------- internal methods: class users usually do not need these -------

    def prepare_rule(self, rule):
        """Precompute two mappings used in binarization:
        f2e[i] = j means the i'th nonterminal on f side is mapped to the j'th
        on the e side
        sym_idx2nt_idx[i] = j means the i'th sym on f side is the j'th
        nonterminal on the f side"""
        rule.f2e = [0] * len(rule.e2f)
        for i, j in enumerate(rule.e2f):
            rule.f2e[j] = i
        rule.sym_idx2nt_idx = [None] * len(rule.f)
        nt_idx = 0
        for i, sym in enumerate(rule.f):
            if sym.isvar:
                rule.sym_idx2nt_idx[i] = nt_idx
                nt_idx += 1

    def precompute_node_probs(self, rule):
        """Precompute the prob for each node and store it in self.node_prob.

        For every (i, j) pair yielded by cyk_spans(len(rule.f)), the prob is
        the product over the symbols in rule.f[i:j], where a nonterminal
        contributes 1 and every terminal contributes a constant 0.001.
        """
        # Fixed stand-in used for every terminal; the corpus estimate
        # self.word_prob.get(sym.sym, 0) is not consulted here.
        terminal_prob = 0.001
        table = self.node_prob = {}
        for start, end in cyk_spans(len(rule.f)):
            prob = 1
            for sym in rule.f[start:end]:
                # nonterminals have prob 1 and leave the product untouched
                if not sym.isvar:
                    prob *= terminal_prob
            table[(start, end)] = prob

    def postprocess(self, root):
        """In post-order tranversal, mark each node with a e_span attribute,
        which is the e side span covered by the node. e_span is None for nodes
        that cover no nonterminals."""
        rule = root.rule
        # coverage vector to mark covered symbols on e side
        self.coverage = [False] * len(rule.e)
        for node in root.iter_nodes():
            if node.leaf():
                # initialization: terminal nodes get e_span None, nonterminal
                # nodes get a e_span that corresponds to the e side idx of the
                # nonterminal
                if node.lexical:
                    node.e_span = None
                else:
                    e_nt_idx = node.min
                    nt_idx = 0
                    # find the e side idx
                    for i, sym in enumerate(rule.e):
                        if sym.isvar:
                            if nt_idx == e_nt_idx:
                                node.e_span = (i, i+1)
                                break
                            else:
                                nt_idx += 1
                    self.coverage[i] = True
            else:
                # recursion: the e_span of a node is the concatenation of the
                # e_spans of its children. note that there can be a gap between
                # the children e_spans, which means e side terminals are
                # in between. after the concatenation, the e_span is expanded
                # to include adjacent unaligned terminals
                if node.lnode.e_span is None and node.rnode.e_span is None:
                    node.e_span = None
                elif node.lnode.e_span is None:
                    node.e_span = node.rnode.e_span
                    if self.early_attach:
                        node.e_span = self.expand_e_span(node)
                elif node.rnode.e_span is None:
                    node.e_span = node.lnode.e_span
                    if self.early_attach:
                        node.e_span = self.expand_e_span(node)
                else:
                    node.e_span = (min(node.lnode.e_span[0],
                                       node.rnode.e_span[0]),
                                   max(node.lnode.e_span[1],
                                       node.rnode.e_span[1]))
                    if self.early_attach:
                        node.e_span = self.expand_e_span(node)
        # pure lexical rules will result a root e_span of None. assign all
        # words in e to top level rule.
        if root.e_span is None and len(rule.e) > 0:
            root.e_span = (0, len(rule.e))
        # if we are not doing early target terminal attaching, give the whole
        # target span to root
        if not self.early_attach:
            root.e_span = (0, len(rule.e))

    def expand_e_span(self, node):
        """Expands a e_span of a node to include adjacent uncovered terminals.
        """
        i, j = node.e_span
        rule = node.rule
        while i-1 >= 0 and not rule.e[i-1].isvar and not self.coverage[i-1]:
            self.coverage[i-1] = True
            i -= 1
        while j < len(rule.e) and not rule.e[j].isvar and not self.coverage[j]:
            self.coverage[j] = True
            j += 1
        return (i, j)

    def generate_rules(self, root):
        """Generate rules from each non-leaf binarization tree node."""
        rules = []
        for node in root.iter_nodes():
            if node.leaf():
                pass
            else:
                rule = self.make_rule(node)
                # for top rule, this is the model score
                # for virtual rule, this is the heuristic score
                rule.feats = node.rule.feats
                # if node is root:
                #     rule.feats = node.rule.feats
                # else:
                #     rule.feats = [1]
                rules.append(rule)
        return rules
                  
    def make_rule(self, node):
        """Make a binarized rule from a binarization tree node.

        The f side symbols are generated by simply reading off children
        symbols. The e side symbols are generated by taking the symbols on
        the e side covered by e_span, and replacing the e side symbols
        within the children's e_spans with the children's symbols.
        """
        left, right = node.lnode, node.rnode
        target = node.rule.e

        new_rule = Rule()
        new_rule.lhs = node.sym
        new_rule.f = [left.sym, right.sym]

        if node.e_span is None:
            # no e_span: the e side holds only the children that are
            # nonterminals, in f-side order
            new_rule.e = [c.sym for c in (left, right) if c.sym.isvar]
            if len(new_rule.e) == 1:
                new_rule.e2f = [0]
            elif len(new_rule.e) == 2:
                new_rule.e2f = [0, 1]
            return new_rule

        start, stop = node.e_span
        lspan, rspan = left.e_span, right.e_span

        if lspan is None or rspan is None:
            if lspan is None and rspan is None:
                # neither child owns e-side symbols: keep the whole span
                head, tail = target[start:stop], []
            else:
                # exactly one child owns a span: keep what surrounds it
                owned = rspan if lspan is None else lspan
                head, tail = target[start:owned[0]], target[owned[1]:stop]
            # if the f side consists of a terminal and a nonterminal,
            # the terminal should not appear on the e side
            nts = [c.sym for c in (left, right) if c.sym.isvar]
            # NOTE(review): e2f is [0, 1] here even when `nts` holds fewer
            # than two symbols, unlike the branch above -- confirm intended.
            new_rule.e2f = [0, 1]
            new_rule.e = head + nts + tail
            return new_rule

        # both children own an e-side span; their e-side order depends on
        # whether the node is inverted
        if node.inverted:
            first, second = right, left
            new_rule.e2f = [1, 0]
        else:
            first, second = left, right
            new_rule.e2f = [0, 1]
        before = target[start:first.e_span[0]]
        between = target[first.e_span[1]:second.e_span[0]]
        after = target[second.e_span[1]:stop]
        new_rule.e = before + [first.sym] + between + [second.sym] + after
        return new_rule

    def count_nts(self, rule, counter):
        """Count appearances of nonterminals in a rule."""
        for sym in [rule.lhs] + rule.f + rule.e:
            if sym.isvar:
                if sym in counter:
                    counter[sym] += 1
                else:
                    counter[sym] = 1

    def count_left_right_binarization_nts(self, rule, lcounter, rcounter):
        # left
        for sym in [rule.lhs] + rule.f[:2]:
            if sym.isvar:
                lcounter[sym] += 1
        for i in range(2, len(rule.f)):
            sym = tuple(rule.f[:i])
            lcounter[sym] += 1
        # right
        for sym in [rule.lhs] + rule.f[-2:]:   
            if sym.isvar:
                rcounter[sym] += 1
        for j in range(1, len(rule.f)-1):
            sym = tuple(rule.f[j:])
            rcounter[sym] += 1
                
    def stats(self):
        w = 18  # width of column
        result = '-- Binarizer Stats --\n'
        result += 'NTCount before binarization: %s\n' % \
                len(self.original_nt_count)
        result += 'NTCount after binarization: %s\n' % \
                len(self.binarized_nt_count)
        result += 'NTCount left binarization: %s\n' % \
                len(self.lbinarized_nt_count)
        result += 'NTCount right binarization: %s\n' % \
                len(self.rbinarized_nt_count)
        result += 'NTCount left binarization (binarizable rules only): %s\n' % \
                len(self.lbinarized_nt_count_binarizable_only)
        result += 'NTCount right binarization (binarizable rules only): %s\n' % \
                len(self.rbinarized_nt_count_binarizable_only)
        result += 'RHS_NTcount'.rjust(w) + \
                  'Processed'.rjust(w) + \
                  'Binarizable'.rjust(w) + \
                  'Linear'.rjust(w) + \
                  'Monotonic'.rjust(w) + \
                  '\n'
        for n in range(self.max_nt_count + 1):
            result += str(n).rjust(w) + \
                      str(self.processed[n]).rjust(w) + \
                      str(self.binarizable[n]).rjust(w) + \
                      str(self.linear[n]).rjust(w) + \
                      str(self.monotonic[n]).rjust(w) + \
                      '\n'
        result += '-'*(w*5) + '\n'
        result += 'Total'.rjust(w) + \
                  str(sum(self.processed[n]
                          for n in range(self.max_nt_count + 1))).rjust(w) + \
                  str(sum(self.binarizable[n]
                          for n in range(self.max_nt_count + 1))).rjust(w) + \
                  str(sum(self.linear[n]
                          for n in range(self.max_nt_count + 1))).rjust(w) + \
                  str(sum(self.monotonic[n]
                          for n in range(self.max_nt_count + 1))).rjust(w) + \
                  '\n'
        return result
