def nt_escape(sym):
    """Return `sym` with ',', '[' and ']' spelled out as COMMA, LB and RB.

    The result contains none of the characters that have special meaning in
    the bracketed nonterminal notation, so it can be used as a nonterminal
    name.
    """
    substitutions = (
        (',', 'COMMA'),
        ('[', 'LB'),
        (']', 'RB'),
    )
    escaped = sym
    for special, spelled_out in substitutions:
        escaped = escaped.replace(special, spelled_out)
    return escaped

class Sym(object):
    """terminal or nonterminal (possibly with index)

    A string wrapped in brackets (``[X]`` or ``[X,1]``) is a nonterminal;
    any other string is a terminal. The index of a nonterminal is parsed
    but not stored, so nonterminals that differ only in their index compare
    equal and hash alike.

    >>> a = Sym('a')
    >>> b = Sym('b')
    >>> c = Sym('a')
    >>> X = Sym('[X,1]')
    >>> Y = Sym('[Y,1]')
    >>> Z = Sym('[X,2]')
    >>> str(a)
    'a'
    >>> a == b
    False
    >>> a == c
    True
    >>> print(X)
    [X]
    >>> X == Y
    False
    >>> X == Z
    True
    >>> a == 'a'
    False
    """
    def __init__(self, s=None):
        """Create a symbol, parsing `s` when it is a non-empty string."""
        self.sym = None  # symbol
        self.isvar = None  # true if this token is a nonterminal
        if s:
            self.fromstr(s)

    def fromstr(self, s):
        """Parse `s` into this symbol and return an idx if there is one.

        Returns the integer index for input of the form ``[X,1]`` and None
        for terminals and for nonterminals without an index.
        Raises ValueError if the index part is not an integer.
        """
        idx = None
        if s.startswith('[') and s.endswith(']'):
            self.isvar = True
            parts = s[1:-1].split(',')
            if len(parts) == 2:
                idx = int(parts[1])  # idx, optional for nonterminal
            self.sym = parts[0].strip()  # symbol
        else:
            self.sym = s
            self.isvar = False
        return idx

    def tostr(self, idx):
        """Write as an indexed nonterminal, like [X,1]"""
        assert self.isvar, 'only nonterminals can have indices'
        return '[%s,%s]' % (self.sym, idx)

    def __str__(self):
        if self.isvar:
            return '[%s]' % self.sym
        else:
            return self.sym

    def __repr__(self):
        # Built from the raw fields so it also works on an unparsed symbol,
        # where __str__ would return None.
        return 'Sym(%r, isvar=%r)' % (self.sym, self.isvar)

    def __eq__(self, other):
        """Symbols are equal when both name and terminal/nonterminal kind match."""
        try:
            return self.sym == other.sym and self.isvar == other.isvar
        except AttributeError:
            # `other` is not symbol-like (e.g. a str or None): let Python
            # fall back to its default comparison instead of crashing.
            return NotImplemented

    def __ne__(self, other):
        # Explicit so that != stays consistent with == on Python 2 as well,
        # where it is not derived from __eq__ automatically.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash((self.sym, self.isvar))

    def is_virtual(self):
        """Return true if this is a virtual item."""
        # An unparsed or empty symbol is never virtual.
        return bool(self.sym) and self.sym.startswith('V_')

# save only one copy of each symbol: Rule.fromstr() looks every parsed Sym up
# here with setdefault(sym, sym), so key and value are the same object and
# equal symbols end up sharing the first instance that was stored.
# TODO: ugly
symbol_table = {}

# there are unary, binary, and lexical rules
# other forms of rules are currently not supported
# A -> B C , C B
# B -> B , B
# A -> a b c ... , d e f ...
#TODO: sanity check
class Rule(object):
    """SCFG rule

    The textual form is ``lhs ||| e side ||| f side ||| feature values``.
    Nonterminals are written ``[X]`` or ``[X,idx]``; the indices only serve
    to link nonterminals across the two sides when reading and writing.

    >>> r = Rule()
    >>> r.fromstr('[X] ||| aa [A,1] cc [A,2] dd [B] [C] ||| xx [C] [B] [A,2] [A,1] ||| 1.0 0 0 0')
    >>> str(r)
    '[X] ||| aa [A,4] cc [A,3] dd [B,2] [C,1] ||| xx [C,1] [B,2] [A,3] [A,4] ||| 1.0 0.0 0.0 0.0'
    >>> r.arity
    4
    >>> r.rewrite([['A1', 'A1'], ['A2'], ['B', 'B'], ['C']])
    ['aa', 'C', 'cc', 'B', 'B', 'dd', 'A2', 'A1', 'A1']
    """

    def __init__(self):
        self.lhs = None  # lhs nonterminal
        self.f = []  # foreign phrase
        self.e = []  # english phrase
        # this is weighted sum of self.fcosts
        # precomputed to be used in decoding
        self.cost = 0
        # added to self.cost by rank_cost()
        self.hcost = 0
        # list of feature values read from input
        self.feats = []
        # list of feature values computed by stateless feature functions
        # these are real costs used in decoding.
        # this is different from feature fields read from input because
        # feature functions may not pick all the features in input file,
        # may use log probabilities, etc.
        self.fcosts = []
        # e2f[i] = j means the i'th nonterminal on the e side is linked to
        # the j'th nonterminal on the f side. nonterminal permutation
        # information is saved only here, the idx assigned to nonterminals
        # are only for input/output purposes.
        self.e2f = []
        self.grammar = None
        # number of nonterminals on the e side; set here so that a rule that
        # has not been parsed yet can still be queried or rewritten
        self.arity = 0

    def rank_cost(self):
        """Return the value rules are ordered by: cost + hcost."""
        return self.cost + self.hcost

    def fromstr(self, line):
        """Fill this rule from its textual form.

        `line` must consist of four '|||'-separated fields: lhs, e side,
        f side and whitespace-separated feature values.

        Raises AssertionError if the field count is wrong or the two sides
        have a different number of nonterminals, KeyError if an e-side
        nonterminal has no counterpart on the f side, and ValueError if an
        index or feature value is not a number.
        """
        s = line.split('|||')
        assert len(s) == 4, \
            "expected 4 '|||'-separated fields, got %d" % len(s)
        lhs, e, f, probs = s
        self.lhs = Sym(lhs.strip())

        self.f = []
        # (symbol, idx) -> position among the f-side nonterminals.
        # NOTE: if the f side repeats the same (symbol, idx) pair, the later
        # occurrence overwrites the earlier one in this table.
        f_var2idx = {}
        n_f_vars = 0
        for x in f.split():
            sym = Sym()
            idx = sym.fromstr(x)
            self.f.append(sym)
            if sym.isvar:
                f_var2idx[(sym.sym, idx)] = n_f_vars
                n_f_vars += 1

        self.e = []
        self.e2f = []  # maps e var indices to f var indices
        for x in e.split():
            sym = Sym()
            idx = sym.fromstr(x)
            self.e.append(sym)
            if sym.isvar:
                key = (sym.sym, idx)
                if key not in f_var2idx:
                    raise KeyError(
                        'nonterminal %s on the e side has no match on the '
                        'f side: %s' % (x, line))
                self.e2f.append(f_var2idx[key])

        # e2f holds one entry per e-side nonterminal
        assert n_f_vars == len(self.e2f), line

        self.feats = [float(x) for x in probs.split()]
        self.arity = len(self.e2f)

        # save memory
        self.f = [symbol_table.setdefault(w, w) for w in self.f]
        self.e = [symbol_table.setdefault(w, w) for w in self.e]

    def rewrite(self, vars):
        """'vars' are lists of target side symbols. rewrite variables with lists
        of symbols in 'vars' and return the target side list of symbols after
        rewriting

        vars[j] replaces the e-side nonterminal that is linked to the j'th
        nonterminal of the f side; terminals are emitted as plain strings.
        """
        assert len(vars) == self.arity
        result = []
        e_var_idx = 0
        for sym in self.e:
            if sym.isvar:
                f_var_idx = self.e2f[e_var_idx]
                result += vars[f_var_idx]
                e_var_idx += 1
            else:
                result.append(str(sym))  # convert Sym to Python string
        return result

    def __str__(self):
        # f side always straight: nonterminals are numbered 1, 2, ... in order
        f_strs = []
        i = 1
        for sym in self.f:
            if sym.isvar:
                f_strs.append(sym.tostr(i))
                i += 1
            else:
                f_strs.append(str(sym))
        # e side permuted: each nonterminal carries the number of the f-side
        # nonterminal it is linked to
        e_strs = []
        i = 0
        for sym in self.e:
            if sym.isvar:
                e_strs.append(sym.tostr(self.e2f[i] + 1))
                i += 1
            else:
                e_strs.append(str(sym))
        return '%s ||| %s ||| %s ||| %s' % (
            self.lhs,
            ' '.join(e_strs),
            ' '.join(f_strs),
            ' '.join([str(v) for v in self.feats]))

    def __lt__(self, other):
        # rules sort by ascending rank_cost()
        return self.rank_cost() < other.rank_cost()

# When executed as a script, run the doctests embedded in this module's
# docstrings.
if __name__ == '__main__':
    import doctest
    doctest.testmod()
