
import re
import sys
from copy import deepcopy
from model import Model, CondModel

########################################
#
#  calc_variants function
#
########################################

# define '%'-variants of a category
def calc_variants ( ghi ):
    """Return ghi followed by its '%'-abstracted variants.

    ghi is split into '-'-introduced segments (a '-' plus the run of
    non-'-' characters after it).  For each segment in turn, the prefix
    made of the '-' and any lowercase letters is kept and the remainder is
    replaced by '%'; that replacement is applied to every variant collected
    so far and the results are appended.  The first element of the returned
    list is always ghi itself.

    A segment that is a bare '-' has no usable prefix: the search returns
    None and AttributeError is raised.

    NOTE(review): a later definition in this file reuses the name
    calc_variants with two parameters and replaces this one once the
    module has been executed.
    """
    U = [ghi]
    for av in re.findall('-[^-]*',ghi):
        # '-' plus lowercase letters; the lookahead forbids the match from
        # reaching the end of the segment
        a = re.search('(-[a-z]*)(?!$)',av).group(1)
        # the comprehension is built from the current contents of U before
        # anything is appended, so the new variants are not revisited
        U.extend([ uhihim.replace(av,a+'%') for uhihim in U ])
    return U

def calc_variants ( ghi, uhim ):
    """Return ghi plus the '%'-variants of ghi that are licensed by uhim.

    Each '-'-introduced segment of ghi is reduced to its '-' and lowercase
    prefix followed by '%'.  The reduction is only applied when that
    prefix+'%' text occurs in uhim.  When it is applied, every variant
    collected so far gets the segment replaced, immediately repeated
    '-...' pieces are collapsed to one, and the result is appended.
    """
    variants = [ghi]
    for segment in re.findall('-[^-]*',ghi):
        stem = re.search('(-[a-z]*)(?!$)',segment).group(1)
        if re.search(stem+'%',uhim) is None:
            # uhim does not mention this abstracted prefix: nothing to add
            continue
        # the right-hand side is evaluated over the current list before it grows
        variants += [ re.sub(r'(-[^- |]+)\1',r'\1',item.replace(segment,stem+'%'))
                      for item in variants ]
    return variants


########################################
#
#  calc_ahihim function
#
########################################

def calc_ahihim1 ( uhihim, uhim1 ):
    """Fill the '%' placeholders of uhihim from uhim1.

    For every '-'+lowercase prefix that is directly followed by '%' in
    uhihim, the text in uhim1 that starts with that prefix and runs up to
    the next '-' is looked up.  When several such pieces are adjacent in
    uhim1, the repeated capture group keeps only the last of the run.
    The text found is substituted for prefix+'%' in a working copy of
    uhihim.

    Returns a pair: the concatenation of all looked-up texts (or '-' when
    that concatenation is empty), and the substituted copy of uhihim.
    """
    found = []
    filled = uhihim
    for prefix in re.findall('-[a-z]*(?=%)',uhihim):
        # findall yields the group capture for every match, including
        # empty ones; joining drops the empties
        piece = ''.join(re.findall('('+prefix+'[^-]*)*',uhim1))
        found.append(piece)
        filled = re.sub(prefix+'%',piece,filled)
    joined = ''.join(found)
    if joined == '':
        joined = '-'
    return ( joined , filled )


########################################
#
#  main
#
########################################

########## read rule counts

# init relevant models
# Count tables filled by the stdin loop below.  The comment above each one
# shows the index it is incremented at; "L rules" are 'GG' rules whose branch
# flag is 'L' or 'l', "other rules" are all remaining 'GG' rules.
# L rules: [d,gh][gh0]
Gh0_giv_BL_D_Gh = CondModel('G0_L')
# other rules: [d+1,uh][gh0]
Gh0_giv_BR_D_Uh = CondModel('G0_R')
# L rules: [d,gh][gh1]
Gh1_giv_BL_D_Gh = CondModel('G1_L')
# other rules: [d,gh][gh1]
Gh1_giv_BR_D_Gh = CondModel('G1_R')
# L rules: [d,gh0][uh0]; other rules: [d+1,gh0][uh0]
Uhi_giv_BL_D_Ghi = CondModel('U_L')
# every 'GG' rule: [d,gh1][uh1]
Uhim_giv_BR_D_Ghim = CondModel('U_R')
# L rules: [d,uh,gh][gh0,uh1]
Gh0_Uh1_giv_BL_D_Uh_Gh = CondModel('GU_L')
# other rules: [d,uh,gh][gh0,uh1]
Gh0_Uh1_giv_BR_D_Uh_Gh = CondModel('GU_R')
# set from 'Pg' lines: [first field][second field] = value
Pc = CondModel('Pc')
# set from 'Pw' lines: [first field][second field] = value
Pw = CondModel('Pw')
# sum of 'Pw' values per second field
P  = Model('P')
# sum of 'Pw' values per first field
W  = Model('W')

# read in GG model and obtain relevant models
# Each stdin line is tested against three rule formats ('Pg', 'Pw', 'GG');
# lines matching none of them are ignored.  All patterns are raw strings so
# that '\|' and '\^' reach the regex engine without relying on Python's
# handling of unknown string escapes.
for s in sys.stdin:
    # 'Pg <a>|...^x,y : <b> = <value>'  ->  Pc[a][b] = value
    m = re.search(r'Pg (.*)\|.*\^.,. : (.*) = (.*)',s)
    if m is not None:
        Pc[m.group(1)][m.group(2)] = float(m.group(3))
    # 'Pw <a> : <b> = <value>'  ->  Pw[a][b] = value, plus marginal sums
    m = re.search(r'Pw (.*) : (.*) = (.*)',s)
    if m is not None:
        wt = float(m.group(3))
        Pw[m.group(1)][m.group(2)] = wt
        P[m.group(2)]             += wt
        W[m.group(1)]             += wt
    # rewrite 'G : <x>^L,1 =' into the 'GG' format, using 'REST|REST^R,0'
    # as both the parent and the second child
    s = re.sub(r'G : (.*\^L,1) =',r'GG REST|REST^R,0 : \1 REST|REST^R,0 =',s)
    # 'GG uh|gh^b,d : uh0|gh0^x,y uh1|gh1^x,y = count'
    m = re.search(r'GG (.*)\|(.*)\^(.),(.) : (.*)\|(.*)\^.,. (.*)\|(.*)\^.,. = (.*)',s)
    if m is not None:
        (uh,gh, b,sd, uh0,gh0, uh1,gh1, ct) = m.groups()
        d = int(sd)
        c = float(ct)
        if b in ('L','l'):
            Gh0_giv_BL_D_Gh[d,gh][gh0] += c
            Gh1_giv_BL_D_Gh[d,gh][gh1] += c
            Uhi_giv_BL_D_Ghi[d,gh0][uh0] += c
            Gh0_Uh1_giv_BL_D_Uh_Gh[d,uh,gh][gh0,uh1] += c
        else:
            # for these rules the first child is recorded one level deeper (d+1)
            Gh0_giv_BR_D_Uh[d+1,uh][gh0] += c
            Gh1_giv_BR_D_Gh[d,gh][gh1] += c
            Uhi_giv_BL_D_Ghi[d+1,gh0][uh0] += c
            Gh0_Uh1_giv_BR_D_Uh_Gh[d,uh,gh][gh0,uh1] += c
        # the second child is recorded the same way for both rule kinds
        Uhim_giv_BR_D_Ghim[d,gh1][uh1] += c

# normalize models
# Normalize every table read from stdin, in the same order as before.
for mdl in ( Gh0_giv_BL_D_Gh, Gh0_giv_BR_D_Uh,
             Gh1_giv_BL_D_Gh, Gh1_giv_BR_D_Gh,
             Gh0_Uh1_giv_BL_D_Uh_Gh, Gh0_Uh1_giv_BR_D_Uh_Gh,
             Uhi_giv_BL_D_Ghi, Uhim_giv_BR_D_Ghim,
             Pc, Pw, P, W ):
    mdl.normalize()

# Only these four are written out at this point.
for mdl in ( Pc, Pw, P, W ):
    mdl.write()

# progress marker on stderr
sys.stderr.write('1\n')


########## obtain intermediate models

# define iteration constant
# number of propagation steps run for each context in the closure loops
K=20

# obtain expected counts for unbounded left descendants
# Three tables are kept: _curr holds the mass added at the current step,
# _prev is a copy of the previous step's _curr, and the unsuffixed table
# accumulates the mass of all steps.
Ghi_giv_D_Uh_prev = CondModel('Grl*_k-1')
Ghi_giv_D_Uh_curr = CondModel('Grl*_k')
Ghi_giv_D_Uh      = CondModel('Grl*')
for d,gh in Gh0_giv_BR_D_Uh:
    # step 0: seed both tables with the entries of Gh0_giv_BR_D_Uh[d,gh]
    for gh0 in Gh0_giv_BR_D_Uh[d,gh]:
        Ghi_giv_D_Uh_curr[d,gh][gh0] += Gh0_giv_BR_D_Uh[d,gh][gh0]
        Ghi_giv_D_Uh     [d,gh][gh0] += Gh0_giv_BR_D_Uh[d,gh][gh0]
    # steps 1..K: extend every category reached in the previous step by one
    # Gh0_giv_BL_D_Gh transition (k itself is not used in the body)
    for k in range(1,K+1):
        # NOTE(review): the whole table is copied, although only the (d,gh)
        # entry is read below; _curr may still hold the last step of the
        # previously processed context when the copy is taken.
        Ghi_giv_D_Uh_prev = deepcopy(Ghi_giv_D_Uh_curr)
        Ghi_giv_D_Uh_curr.clear()
        # NOTE(review): the results of .get() are iterated directly; this
        # assumes CondModel.get never returns None for a missing key -- TODO confirm
        for ghi in Ghi_giv_D_Uh_prev.get((d,gh)):
            for ghi0 in Gh0_giv_BL_D_Gh.get((d,ghi)):
                pr = Ghi_giv_D_Uh_prev.get((d,gh)).get(ghi) * Gh0_giv_BL_D_Gh[d,ghi][ghi0]
                if pr > 0.0:
                    Ghi_giv_D_Uh_curr[d,gh][ghi0] += pr
                    Ghi_giv_D_Uh     [d,gh][ghi0] += pr

# progress marker on stderr
sys.stderr.write('2\n')


# obtain expected counts for unbounded right descendants
# Three tables are kept: _curr holds the mass added at the current step,
# _prev is a copy of the previous step's _curr, and the unsuffixed table
# accumulates the mass of all steps.
Ghim_giv_D_Ghi_prev = CondModel('Glr*_k-1')
Ghim_giv_D_Ghi_curr = CondModel('Glr*_k')
Ghim_giv_D_Ghi      = CondModel('Glr*')
for d,gh in Gh1_giv_BL_D_Gh:
    # step 0: seed both tables with the entries of Gh1_giv_BL_D_Gh[d,gh]
    for gh1 in Gh1_giv_BL_D_Gh[d,gh]:
        Ghim_giv_D_Ghi_curr[d,gh][gh1] += Gh1_giv_BL_D_Gh[d,gh][gh1]
        Ghim_giv_D_Ghi     [d,gh][gh1] += Gh1_giv_BL_D_Gh[d,gh][gh1]
    # steps 1..K: extend every category reached in the previous step by one
    # Gh1_giv_BR_D_Gh transition (k itself is not used in the body)
    for k in range(1,K+1):
        # NOTE(review): the whole table is copied, although only the (d,gh)
        # entry is read below; _curr may still hold the last step of the
        # previously processed context when the copy is taken.
        Ghim_giv_D_Ghi_prev = deepcopy(Ghim_giv_D_Ghi_curr)
        Ghim_giv_D_Ghi_curr.clear()
        # NOTE(review): the results of .get() are iterated directly; this
        # assumes CondModel.get never returns None for a missing key -- TODO confirm
        for ghi in Ghim_giv_D_Ghi_prev.get((d,gh)):
            for ghi1 in Gh1_giv_BR_D_Gh.get((d,ghi)):
                pr = Ghim_giv_D_Ghi_prev.get((d,gh)).get(ghi) * Gh1_giv_BR_D_Gh[d,ghi][ghi1]
                if pr > 0.0:
                    Ghim_giv_D_Ghi_curr[d,gh][ghi1] += pr
                    Ghim_giv_D_Ghi     [d,gh][ghi1] += pr

# progress marker on stderr
sys.stderr.write('3\n')


########## obtain agf models

# obtain expansion model
# Ge keeps, per (d,uh), the accumulated left-descendant mass of every
# category that neither starts with an uppercase letter nor contains '_'.
Ge = CondModel('Ge')
for d,uh in Ghi_giv_D_Uh:
    for ghi in Ghi_giv_D_Uh[d,uh]:
        if re.search('(^[A-Z]|_)',ghi):
            # excluded category: leave Ge untouched for it
            continue
        Ge[d,uh][ghi] += Ghi_giv_D_Uh[d,uh][ghi]

Ge.normalize()
Ge.write()

# progress marker on stderr
sys.stderr.write('4\n')


# obtain reduction model
# Fr splits, for each (d,uh,ghi), the accumulated left-descendant mass into
# the part that comes straight from Gh0_giv_BR_D_Uh (outcome '1,<ghi>,-')
# and the remainder (outcome '0,<ghi>,-').  Non-positive parts are skipped.
Fr = CondModel('F')
for d,uh in Ghi_giv_D_Uh:
    for ghi in Ghi_giv_D_Uh[d,uh]:
        direct = Gh0_giv_BR_D_Uh[d,uh][ghi]
        total  = Ghi_giv_D_Uh[d,uh][ghi]
        for tag,pr in ( ('1,',direct), ('0,',total - direct) ):
            if pr > 0.0:
                Fr[d,uh,ghi][tag+ghi+',-'] += pr

Fr.normalize()
Fr.write()

# progress marker on stderr
sys.stderr.write('5\n')


# obtain active transition models
# Both tables receive the same weight: accumulated left-descendant mass of
# ghi under (d,uh), times the weight of uhi for ghi, times the weight of the
# (ghi0,uhi1) expansion of (uhi,ghi).
Uhi_giv_D_Uh_Ghi0      = CondModel('Gtaa')
Uhi1_giv_D_Uh_Uhi_Ghi0 = CondModel('Gtaw')
for d,uh in Ghi_giv_D_Uh:
    for ghi in Ghi_giv_D_Uh[d,uh]:
        for uhi in Uhi_giv_BL_D_Ghi[d,ghi]:
            for ghi0,uhi1 in Gh0_Uh1_giv_BL_D_Uh_Gh[d,uhi,ghi]:
                pr = ( Ghi_giv_D_Uh[d,uh][ghi] *
                       Uhi_giv_BL_D_Ghi[d,ghi][uhi] *
                       Gh0_Uh1_giv_BL_D_Uh_Gh[d,uhi,ghi][ghi0,uhi1] )
                # uhihi1 is uhi with its '%' placeholders filled from uhi1
                (ahihi1,uhihi1) = calc_ahihim1(uhi,uhi1)
                if not pr > 0.0:
                    continue
                Uhi_giv_D_Uh_Ghi0[d,uh,ghi0][uhihi1] += pr
                Uhi1_giv_D_Uh_Uhi_Ghi0[d,uh,uhihi1,ghi0][uhi1] += pr

# normalize both tables first, then write them in the same order
for mdl in ( Uhi_giv_D_Uh_Ghi0, Uhi1_giv_D_Uh_Uhi_Ghi0 ):
    mdl.normalize()
for mdl in ( Uhi_giv_D_Uh_Ghi0, Uhi1_giv_D_Uh_Uhi_Ghi0 ):
    mdl.write()

# progress marker on stderr
sys.stderr.write('6\n')


# obtain awaited transition models
# count/tot drive a 'count/tot' progress line on stderr per (d,uh) context
count = 0
tot = len(Ghi_giv_D_Uh)
Ahihim1_giv_D_Uh_Uhihim_Uhim_Ghim0 = CondModel('A')
Uhihim1_giv_D_Uhihim_Ahihim1       = CondModel('Gtwa')
Uhim1_giv_D_Uhim_Ghim0_Ahihim1     = CondModel('Gtww')
for d,uh in Ghi_giv_D_Uh:
    count += 1
    sys.stderr.write(str(count)+'/'+str(tot)+'\n')
    for ghi in Ghi_giv_D_Uh[d,uh]:
        for ghim in Ghim_giv_D_Ghi[d,ghi]:
            for uhim in Uhim_giv_BR_D_Ghim[d,ghim]:
                for ghim0,uhim1 in Gh0_Uh1_giv_BR_D_Uh_Gh[d,uhim,ghim]:
                    # The weight depends only on the enclosing loop variables,
                    # not on the variant, so it is computed once here.  It must
                    # be set before the variant loop: the '%'-free variant is
                    # weighted with it too and would otherwise pick up a value
                    # left over from an earlier iteration.
                    pr = ( Ghi_giv_D_Uh[d,uh][ghi] *
                           Ghim_giv_D_Ghi[d,ghi][ghim] *
                           Uhim_giv_BR_D_Ghim[d,ghim][uhim] *
                           Gh0_Uh1_giv_BR_D_Uh_Gh[d,uhim,ghim][ghim0,uhim1] )
                    if not pr > 0.0:
                        continue
                    for uhihim in calc_variants(ghi,uhim):
                        if '%' in uhihim:
                            # abstracted variant: fill its placeholders from
                            # uhim1 and record what was filled in
                            (ahihim1,uhihim1) = calc_ahihim1(uhihim,uhim1)
                            Ahihim1_giv_D_Uh_Uhihim_Uhim_Ghim0[d,uh,uhihim,uhim,ghim0]['0,-,'+ahihim1] += pr
                        else:
                            # plain variant: nothing to fill, passed through as is
                            (ahihim1,uhihim1) = ('-',uhihim)
                        Uhihim1_giv_D_Uhihim_Ahihim1[d,uhihim,ahihim1][uhihim1] += pr
                        Uhim1_giv_D_Uhim_Ghim0_Ahihim1[d,uhim,ghim0,ahihim1][uhim1] += pr

Ahihim1_giv_D_Uh_Uhihim_Uhim_Ghim0.normalize()
Uhihim1_giv_D_Uhihim_Ahihim1.normalize()
Uhim1_giv_D_Uhim_Ghim0_Ahihim1.normalize()
Ahihim1_giv_D_Uh_Uhihim_Uhim_Ghim0.write()
Uhihim1_giv_D_Uhihim_Ahihim1.write()
Uhim1_giv_D_Uhim_Ghim0_Ahihim1.write()

