
(* #load "str.cmxa";; *)
(* #load "scripts/trees.ml";; *)

open Str
open String
open Trees


(* Binary tree carrying a payload of type ['a] on every node.
   - [Leaf (payload, s1, s2)]: a terminal node with two extra strings.  In this
     file they are either the two halves of a "p#x" terminal (see
     tuptree_of_tree) or both empty (see instree_of_tuptree).
   - [Branch (payload, left, right)]: an internal node with exactly two
     children. *)
type 'a binary_tree =
  | Leaf of 'a * string * string
  | Branch of 'a * 'a binary_tree * 'a binary_tree
;;


(* Formatting exception: raised when an input tree does not have the shape or
   label syntax this script expects (tuptree_of_tree), or when a tuple tree and
   its companion array tree do not line up (get_a_transp).  The string payload
   is either the offending label text or a short tag naming the failing
   function. *)
exception WrongFormat of string
;;


(* Converts a generic tree from the Trees module into a binary_tree whose node
   payload is the 5-tuple (l, c, h, u, d) parsed out of a node label written as

       l:c{h}^U,d        with U one of L / R and d a single digit

   The U field is lowercased before it is stored.  A node with two children
   becomes a Branch; a node whose only child is a terminal "p#x" becomes
   Leaf (payload, p, x).  Single-child wrapper lists are unwrapped.

   Raises WrongFormat for any other shape: the payload is the label text when a
   labelled node does not fit, "no label" when the first child is not a
   terminal, "term <s>" for a bare terminal, "tuptree_of_tree 2" when the leaf
   text has no '#', and "tuptree_of_tree" otherwise. *)
let rec tuptree_of_tree tree =
  (* True when [s] has the node-label syntax; leaves Str's match state set so
     that [label_fields] can read the groups. *)
  let is_node_label s =
    string_match (regexp "\\(.*\\):\\(.*\\){\\(.*\\)}\\^\\([LR]\\),\\([0-9]\\)") s 0 in
  (* Must run right after a successful [is_node_label s] and before any other
     Str match (including recursive calls), since Str keeps a single global
     match state. *)
  let label_fields s =
    let l = matched_group 1 s in
    let c = matched_group 2 s in
    let h = matched_group 3 s in
    let u = matched_group 4 s in
    let d = matched_group 5 s in
    (l, c, h, lowercase u, d) in
  match tree with
  | ChildList (only, End) ->
      tuptree_of_tree only
  | ChildList (Term label, ChildList (left, ChildList (right, End)))
    when is_node_label label ->
      let fields = label_fields label in
      Branch (fields, tuptree_of_tree left, tuptree_of_tree right)
  | ChildList (Term label, ChildList (Term word, End))
    when is_node_label label ->
      let fields = label_fields label in
      if string_match (regexp "\\(.*\\)#\\(.*\\)") word 0 then begin
        let p = matched_group 1 word in
        let x = matched_group 2 word in
        Leaf (fields, p, x)
      end else
        raise (WrongFormat "tuptree_of_tree 2")
  | ChildList (Term label, _) -> raise (WrongFormat label)
  | ChildList (_, _) -> raise (WrongFormat "no label")
  | Term s -> raise (WrongFormat ("term "^s))
  | _ -> raise (WrongFormat "tuptree_of_tree")
;;


(* Returns the payload stored at the root of a binary_tree, regardless of
   whether the root is a Leaf or a Branch. *)
let first_of_bintree tree =
  match tree with
  | Leaf (payload, _, _) | Branch (payload, _, _) -> payload
;;


(* Largest index used for the per-node arrays: the first command-line argument
   is a count N, and every array in this script has N cells indexed 0..imax
   with imax = N-1 (the indices written as eN in model lines).

   A missing, non-integer or negative argument is reported on stderr and the
   program exits with status 1 instead of dying on an uncaught exception. *)
let imax =
  let bad_usage msg =
    prerr_endline (Sys.argv.(0) ^ ": " ^ msg);
    prerr_endline ("usage: " ^ Sys.argv.(0) ^ " <count>   (model lines and trees are read from stdin)");
    exit 1 in
  if Array.length Sys.argv < 2 then bad_usage "missing <count> argument"
  else
    let count =
      try int_of_string Sys.argv.(1)
      with Failure _ -> bad_usage ("<count> is not an integer: " ^ Sys.argv.(1)) in
    (* a negative count would make every Array.make (imax+1) fail later *)
    if count < 0 then bad_usage ("<count> must not be negative: " ^ Sys.argv.(1))
    else count - 1
;;
(* Model tables filled by the read loop; every value is the number after "="
   on the corresponding input line.  Missing keys are read as 0.0 through
   default_find.

   hA : "A : l:c{eI} = pr"                    key (l, c, I)
   hM : "M u d l:c{eI} : l0:c0 l1:c1 = pr"    key (u lowercased, d, l, c, I, l0, c0, l1, c1)
   hL : "L u d l eI : eJ = pr"                key (u, d, l, I, J)
   hH : "H c eI : h = pr"                     key (c, I, h) *)
let hA : (string * string * int, float) Hashtbl.t =
  Hashtbl.create 1000;;
let hM : (string * string * string * string * int
          * string * string * string * string, float) Hashtbl.t =
  Hashtbl.create 1000;;
let hL : (string * string * string * int * int, float) Hashtbl.t =
  Hashtbl.create 1000;;
let hH : (string * int * string, float) Hashtbl.t =
  Hashtbl.create 1000;;


(* [default_find h k] is the value bound to [k] in the float-valued table [h],
   or 0.0 when [k] has no binding. *)
let default_find h k =
  try Hashtbl.find h k with Not_found -> 0.0
;;


(* Bottom-up pass: maps a tuple tree to a tree of the same shape whose payload
   at every node is a float array e with one cell per index 0..imax.

   Leaf (l,c,h,u,d):   e.(i)   = hH (c, i, h)
   Branch (lP,cP,hP,uP,dP) with children t0, t1 (root payloads
   (l0,c0,_,u0,d0) and (l1,c1,_,u1,d1), arrays e0 and e1):
       eP.(iP) = hM (uP,dP,lP,cP,iP,l0,c0,l1,c1)
                 *. (sum over i0 of hL (u0,d0,l0,iP,i0) *. e0.(i0))
                 *. (sum over i1 of hL (u1,d1,l1,iP,i1) *. e1.(i1))

   Table entries that are absent count as 0.0.  Leaves of the result carry two
   empty strings. *)
let rec instree_of_tuptree node =
  match node with
  | Leaf ((_, c, h, _, _), _, _) ->
      let e = Array.make (imax+1) 0.0 in
      for i = 0 to imax do
        e.(i) <- default_find hH (c,i,h)
      done;
      Leaf ( e, "", "" )
  | Branch ((lP,cP,_,uP,dP), left, right) ->
      let (l0,c0,_,u0,d0) = first_of_bintree left in
      let (l1,c1,_,u1,d1) = first_of_bintree right in
      let ins_left  = instree_of_tuptree left in
      let ins_right = instree_of_tuptree right in
      let e_left  = first_of_bintree ins_left in
      let e_right = first_of_bintree ins_right in
      (* sum over i = 0..imax of hL (key_of i) *. e.(i), accumulated in
         increasing order of i *)
      let sum_out key_of e =
        let rec go i acc =
          if i > imax then acc
          else go (i+1) (acc +. (default_find hL (key_of i)) *. e.(i)) in
        go 0 0.0 in
      let eP = Array.make (imax+1) 0.0 in
      for iP = 0 to imax do
        let from_left  = sum_out (fun i0 -> (u0,d0,l0,iP,i0)) e_left in
        let from_right = sum_out (fun i1 -> (u1,d1,l1,iP,i1)) e_right in
        eP.(iP) <- (default_find hM (uP,dP,lP,cP,iP,l0,c0,l1,c1)) *. from_left *. from_right
      done;
      Branch ( eP, ins_left, ins_right )
;;


(* Accumulators written by get_a_transp and the final table derived from them.

   hATnum   : key (u, d, l, c, i) -> running sum of per-node contributions
   hATdenom : single entry under the key () -> running sum of all contributions
   hAT      : key (u, d, l, c, i) -> hATnum entry divided by the hATdenom total *)
let hATnum : (string * string * string * string * int, float) Hashtbl.t =
  Hashtbl.create 1000;;
let hATdenom : (unit, float) Hashtbl.t =
  Hashtbl.create 1000;;
let hAT : (string * string * string * string * int, float) Hashtbl.t =
  Hashtbl.create 1000;;

(* [at_child_sum (u,d,l) e iP] is the sum over i in 0..imax of
   hL (u,d,l,iP,i) *. e.(i), with absent hL entries counted as 0.0. *)
let at_child_sum (u,d,l) e iP =
  let acc = ref 0.0 in
  for i = 0 to imax do
    acc := !acc +. (default_find hL (u,d,l,iP,i)) *. (Array.get e i)
  done;
  !acc
;;

(* [at_accumulate (u,d,l,c) aP eP] adds, for every index i in 0..imax, the
   share aP.(i) *. eP.(i) /. tot (tot being the sum of those products over all
   i) to hATnum under key (u,d,l,c,i) and to the single hATdenom entry.  When
   tot is 0.0 the share is taken to be 0.0, but the hATnum keys are still
   created. *)
let at_accumulate (u,d,l,c) aP eP =
  let tot = ref 0.0 in
  for i = 0 to imax do
    tot := !tot +. ((Array.get aP i) *. (Array.get eP i))
  done;
  for i = 0 to imax do
    let share =
      if !tot = 0.0 then 0.0
      else ((Array.get aP i) *. (Array.get eP i) /. !tot) in
    Hashtbl.replace hATnum   (u,d,l,c,i) ( (default_find hATnum (u,d,l,c,i)) +. share );
    Hashtbl.replace hATdenom ()          ( (default_find hATdenom ()) +. share )
  done
;;

(* Top-down pass over a tuple tree and the array tree produced for it by
   instree_of_tuptree.

   Argument: a triple (aP, t, te) where
     aP  is a float array over indices 0..imax holding the weight coming from
         above the node (at the root the caller fills it from hA);
     t   is the tuple tree node;
     te  is the matching node of the array tree.

   For a Branch it derives the weight arrays of both children,

     a0.(i0) = sum over iP of aP.(iP) *. hM(...iP...) *. pr1.(iP)
                              *. hL (u0,d0,l0,iP,i0) *. e0.(i0)
     a1.(i1) = sum over iP of aP.(iP) *. hM(...iP...) *. pr0.(iP)
                              *. hL (u1,d1,l1,iP,i1) *. e1.(i1)

   (pr0 / pr1 being the at_child_sum of the left / right child), recurses into
   the left child, then the right child, and finally records this node's own
   contribution with at_accumulate.  A Leaf only records its contribution.

   Returns unit; the results are the side effects on hATnum and hATdenom.
   Raises WrongFormat "estep" when the two trees do not have the same shape
   (or the array-tree leaf does not carry two empty strings). *)
let rec get_a_transp = function
    (aP, Branch((lP,cP,_,uP,dP),t0,t1), Branch(eP,te0,te1)) ->
      let (l0,c0,_,u0,d0) = first_of_bintree t0 in
      let (l1,c1,_,u1,d1) = first_of_bintree t1 in
      let e0 = first_of_bintree te0 in
      let e1 = first_of_bintree te1 in
      let n = imax + 1 in
      (* quantities that depend only on the parent index: computed once and
         shared by both child passes *)
      let rule = Array.init n (fun iP -> default_find hM (uP,dP,lP,cP,iP,l0,c0,l1,c1)) in
      let pr0  = Array.init n (at_child_sum (u0,d0,l0) e0) in
      let pr1  = Array.init n (at_child_sum (u1,d1,l1) e1) in

      (* explore left child: the right child is summed out through pr1 *)
      let a0 = Array.make n 0.0 in
      for iP = 0 to imax do
        let w = (Array.get aP iP) *. (Array.get rule iP) *. (Array.get pr1 iP) in
        for i0 = 0 to imax do
          Array.set a0 i0 ( (Array.get a0 i0) +.
                            (w *. (default_find hL (u0,d0,l0,iP,i0)) *. (Array.get e0 i0)) )
        done
      done;
      get_a_transp (a0, t0, te0);

      (* explore right child: the left child is summed out through pr0 *)
      let a1 = Array.make n 0.0 in
      for iP = 0 to imax do
        let w = (Array.get aP iP) *. (Array.get rule iP) *. (Array.get pr0 iP) in
        for i1 = 0 to imax do
          Array.set a1 i1 ( (Array.get a1 i1) +.
                            (w *. (default_find hL (u1,d1,l1,iP,i1)) *. (Array.get e1 i1)) )
        done
      done;
      get_a_transp (a1, t1, te1);

      (* update AT num and denom; done after both children so that the order
         of additions into the accumulators is children first, then parent *)
      at_accumulate (uP,dP,lP,cP) aP eP

  | (aP, Leaf((lP,cP,_,uP,dP),_,_), Leaf(eP,"","")) ->
      (* update AT num and denom *)
      at_accumulate (uP,dP,lP,cP) aP eP

  | _ -> raise (WrongFormat "estep")
;;


(* Tuple trees collected by the read loop.  Each new tree is pushed on the
   front, so the list is in reverse input order. *)
let ltrees : (string * string * string * string * string) binary_tree list ref =
  ref [];;


(* read loop: consumes stdin line by line until end of file.  A line is tried
   against the four model-line formats in the order A, M, L, H and stored in
   the matching table; any line that matches none of them is parsed as a tree
   and pushed onto ltrees.

   (An older L format, "L l eI : eJ = pr" without the two leading fields, is no
   longer recognised.) *)
let () =
  let re_a = regexp "^A : \\(.*\\):\\(.*\\){e\\(.*\\)} = \\(.*\\)" in
  let re_m = regexp "^M \\(.*\\) \\(.*\\) \\(.*\\):\\(.*\\){e\\(.*\\)} : \\(.*\\):\\(.*\\) \\(.*\\):\\(.*\\) = \\(.*\\)" in
  let re_l = regexp "^L \\([^ ]*\\) \\([^ ]*\\) \\([^ ]*\\) e\\([^ ]*\\) : e\\([^ ]*\\) = \\([^ ]*\\)" in
  let re_h = regexp "^H \\([^ ]*\\) e\\([^ ]*\\) : \\([^ ]*\\) = \\([^ ]*\\)" in
  let consume line =
    (* n-th group of the most recent successful match on [line] *)
    let grp n = matched_group n line in
    if string_match re_a line 0 then
      (* A : l:c{eI} = pr *)
      Hashtbl.replace hA (grp 1, grp 2, (int_of_string (grp 3))) (float_of_string (grp 4))
    else if string_match re_m line 0 then
      (* M u d l:c{eI} : l0:c0 l1:c1 = pr *)
      Hashtbl.replace hM
        (lowercase (grp 1), grp 2, grp 3, grp 4, (int_of_string (grp 5)),
         grp 6, grp 7, grp 8, grp 9)
        (float_of_string (grp 10))
    else if string_match re_l line 0 then
      (* L u d l eI : eJ = pr
         NOTE(review): u is stored verbatim here, whereas the u of M lines and
         of tree labels is lowercased before use -- confirm that L lines
         already carry a lowercase u, otherwise hL lookups never hit. *)
      Hashtbl.replace hL
        (grp 1, grp 2, grp 3, (int_of_string (grp 4)), (int_of_string (grp 5)))
        (float_of_string (grp 6))
    else if string_match re_h line 0 then
      (* H c eI : h = pr *)
      Hashtbl.replace hH (grp 1, (int_of_string (grp 2)), grp 3) (float_of_string (grp 4))
    else begin
      (*prerr_endline ("assuming tree: "^line);*)
      let _, t = tree_of_string line in
      ltrees := (tuptree_of_tree t) :: !ltrees
    end in
  try
    while true do
      consume (input_line stdin)
    done
  with
    End_of_file -> ()
;;


(* Scratch array of root weights, one cell per index 0..imax; it is refilled
   from hA for every tree before the top-down pass. *)
let aP = Array.make (imax+1) 0.0;;


prerr_endline "now processing...";;


(* calc num and denom: for every collected tree, run the bottom-up pass, load
   the root weights hA (l, c, i) of the tree's root label into aP, then run the
   top-down pass that feeds hATnum / hATdenom. *)
let () =
  let process t =
    (*print_endline ("tree...");*)
    let ins = instree_of_tuptree t in
    let (lP, cP, _, _, _) = first_of_bintree t in
    Array.iteri (fun iP _ -> aP.(iP) <- default_find hA (lP,cP,iP)) aP;
    get_a_transp (aP, t, ins) in
  List.iter process !ltrees
;;


(* calc model: every hATnum entry divided by the single hATdenom total is
   stored in hAT under the same key; a zero total yields 0.0 everywhere. *)
let () =
  let denom = default_find hATdenom () in
  let normalised pr = if denom = 0.0 then 0.0 else pr /. denom in
  Hashtbl.iter (fun key pr -> Hashtbl.add hAT key (normalised pr)) hATnum
;;


(* write loop: one "G : u d l:c{eI} = pr" line on stdout per hAT entry, in the
   table's own iteration order. *)
let () =
  let emit (uP,dP,lP,cP,iP) pr =
    print_endline ("G : "^uP^" "^dP^" "^lP^":"^cP^"{e"^(string_of_int iP)^"} = "^(string_of_float pr)) in
  Hashtbl.iter emit hAT
;;
