package fractaldragon.assessment;

import fractaldragon.assessment.HierarchicalTopicAgreementXtra.TopicMatch;
import fractaldragon.topicanalysis.HierarchicalModelLevelStore;
import fractaldragon.util.HungarianAlgorithm;
import fractaldragon.util.IdKeyCountsMap;
import fractaldragon.util.Tree;
import gnu.trove.iterator.TIntIntIterator;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.ListIterator;
import static fractaldragon.util.Information.jsDivergence;

/**
 * Match topics considering hierarchy - revision of earlier TopicAgreement.
 *
 * Establish the correspondence between topics from different analyses 
 * based on topic compositions, hierarchy and weighting, and provide
 * a global measure of divergence between topic structures. 
 * 
 * 
 * @author 
 * 
 */
public class HierarchicalTopicAgreementXtra {
    
    // Sentinel stored in correspondence[] for a left topic without a matched right topic.
    static final int UNDEFINED = -1;
    
    // Model dimensions, filled in by setModels().
    int maxNumTypes;
    int maxNumTopics;
    int numTopicsl;
    int numTopicsr;
    // Per-topic count vectors copied in setModels() from the left (l) and
    // right (r) model: zCounts* come from n(), zmCounts* from m().
    int[] zCountsl;
    int[] zCountsr;
    int[] zmCountsl;
    int[] zmCountsr;
    // Count maps obtained from idKeyCounts() of the left and right model.
    IdKeyCountsMap typeTopicCountsl;
    IdKeyCountsMap typeTopicCountsr;

    
    // Similarities strictly below this limit are replaced by 0.0 when weighting.
    double censorLimit = 0.0;
    
    // sims[zl][zr]: cosine similarity between left topic zl and right topic zr.
    double[][] sims;
    // wtSims[zl][zr]: censored and (optionally) frequency-weighted similarity.
    double[][] wtSims;
    // Topic trees of the left and right model.
    Tree<Integer> tøl, tør;
    // Tree node of each topic, indexed by the topic value held in the node.
    Tree<Integer>[] reverseTøl, reverseTør;
    // correspondence[zl] is the right topic matched to left topic zl, or UNDEFINED.
    int[] correspondence;
    // correspondingSims[zl] is sims[zl][correspondence[zl]] for matched topics, else 0.0.
    double[] correspondingSims;
    // Both accessors hand out the internal array itself, not a copy.
    public int[] getCorrespondence() { return correspondence; }
    public double[] getCorrespondingSims() { return correspondingSims; }
    
    // Similarity and divergence results that can be fetched.
    // The accessors below return the internal arrays, which stay null until
    // the corresponding calculate/report method has been run.
    // [0] matching number of topics, [1] matching total topic count, [2] total topic count.
    int[] topicCounts; 
    public int[] getTopicCounts() { return topicCounts; }
    // [0] matching avg similarity, [1] matching weighted avg similarity, [2] weighted avg similarity.
    double[] avgSimilarities;   
    public double[] getAvgSimilarities() { return avgSimilarities; }
    // [0] JSD flat, [2] JSD hierarchical, [4] JSD strict hierarchical.
    // The array has six slots; the odd ones ([1], [3], [5]) are currently never written.
    double[] divergences;   
    public double[] getDivergences() { return divergences; }
    
    // The map reference is final (never replaced); its contents change.
    // It memoizes the best subtree match per (left parent, right parent) pair.
    final TIntObjectMap<TopicMatch> topicMatchMap = new TIntObjectHashMap<>();
    // Result of the most recent getOptimalFlatMatch() call; null before the first call.
    TopicMatch flatTopicMatch;

    
    /**
     * Creates an instance without any models attached; models are supplied
     * afterwards through {@code setModels}.
     */
    public HierarchicalTopicAgreementXtra() {
    }
    
    /**
     * Attaches the two models to compare and copies their dimensions,
     * topic trees and count vectors into this instance.
     *
     * @param ml left model
     * @param mr right model; asserted to have the same number of types as {@code ml}
     */
    public void setModels(HierarchicalModelLevelStore ml,  HierarchicalModelLevelStore mr) {
        assert ml.numTypes() ==  mr.numTypes();

        maxNumTypes = ml.numTypes();
        final int leftTopics = ml.maxTopic() + 1;
        final int rightTopics = mr.maxTopic() + 1;
        numTopicsl = leftTopics;
        numTopicsr = rightTopics;
        maxNumTopics = Math.max(leftTopics, rightTopics);

        // setTø sizes its reverse lookups from maxNumTopics, so that must be set first.
        setTø(ml.tø(), mr.tø());

        // Topic count vectors, cut or zero-padded to the number of topics.
        zCountsl = Arrays.copyOf(ml.n(), leftTopics);
        zCountsr = Arrays.copyOf(mr.n(), rightTopics);
        zmCountsl = Arrays.copyOf(ml.m(), leftTopics);
        zmCountsr = Arrays.copyOf(mr.m(), rightTopics);

        typeTopicCountsl = ml.idKeyCounts();
        typeTopicCountsr = mr.idKeyCounts();
    }
    
        
    /**
     * Sets the censoring threshold; similarities strictly below it are
     * treated as 0.0 when the weighted similarities are computed.
     *
     * @param limit the new threshold
     */
    public final void setCensorLimit(double limit) {
        this.censorLimit = limit;
    }
    
    /**
     * Stores both topic trees and rebuilds the topic-to-node lookup tables,
     * each sized by the current {@code maxNumTopics}.
     *
     * @param tøl topic tree of the left model
     * @param tør topic tree of the right model
     */
    public final void setTø(Tree<Integer> tøl, Tree<Integer> tør) {
        final int lookupSize = maxNumTopics;

        this.tøl = tøl;
        this.reverseTøl = getReverseLookup(tøl, lookupSize);

        this.tør = tør;
        this.reverseTør = getReverseLookup(tør, lookupSize);
    }
    
    /**
     * Result of matching two sets of topics: the accumulated measure and the
     * matched pairs themselves.
     */
    public static class TopicMatch {
        /** Calculated measure of the match; starts out as 0.0. */
        public double measure;

        /** Matched topics stored as consecutive pairs (z, z'). */
        public int[] match;
    }

    
    /**
     * Matches topics on weighted similarity alone, ignoring the hierarchy.
     *
     * All topics of both trees are assigned with the Hungarian algorithm.
     * When the hierarchy is not obeyed this can give high similarity and
     * high JSD but low hJSD and shJSD. The result is also kept in
     * {@code flatTopicMatch} for later retrieval.
     *
     * @return the match; pairs whose weighted similarity is not positive are left out
     */
    public TopicMatch getOptimalFlatMatch() {
        final int[] leftTopics = getFlatTopics(tøl);
        final int[] rightTopics = getFlatTopics(tør);

        // Cost matrix for the assignment: negated weights, keeping zero as +0.0.
        final double[][] costs = new double[leftTopics.length][rightTopics.length];
        for (int row = 0; row < leftTopics.length; row++) {
            for (int col = 0; col < rightTopics.length; col++) {
                final double wt = wtSims[leftTopics[row]][rightTopics[col]];
                costs[row][col] = wt == 0.0 ? 0.0 : -wt;
            }
        }

        final int[] assignment = new HungarianAlgorithm(costs).execute();

        // Collect the assigned pairs and add up their weights.
        int[] pairs = new int[assignment.length * 2];
        int used = 0;
        double total = 0.0;
        for (int row = 0; row < assignment.length; row++) {
            final int col = assignment[row];
            if (col < 0) {
                continue;
            }
            final int zl = leftTopics[row];
            final int zr = rightTopics[col];
            final double wt = wtSims[zl][zr];
            // Zero weights are invalid here and are excluded.
            if (wt > 0.0) {
                pairs[used++] = zl;
                pairs[used++] = zr;
                total += wt;
            }
        }
        if (used < pairs.length) {
            pairs = Arrays.copyOf(pairs, used);
        }

        final TopicMatch result = new TopicMatch();
        result.match = pairs;
        result.measure = total;

        // Keep the match for retrieval.
        flatTopicMatch = result;
        return result;
    }

    

    /**
     * Builds the left-to-right topic correspondence from the flat match made
     * by {@link #getOptimalFlatMatch()} and records the raw similarity of
     * every matched pair.
     *
     * The {@code correspondence} and {@code correspondingSims} fields are
     * replaced only after the mapping has been built completely.
     *
     * @return array indexed by left topic that holds the matched right topic,
     *         or {@code UNDEFINED} for unmatched left topics
     * @throws IllegalStateException if no flat match has been computed yet
     */
    public int[] retrieveOptimalFlatMapping() {
        if (flatTopicMatch == null) {
            throw new IllegalStateException(
                "getOptimalFlatMatch() must be called before retrieveOptimalFlatMapping()");
        }

        int[] mapping = new int[numTopicsl];
        Arrays.fill(mapping, UNDEFINED);

        // The match is a flat list of (left, right) pairs.
        int[] pairs = flatTopicMatch.match;
        for (int p = 0; p < pairs.length; p += 2) {
            mapping[pairs[p]] = pairs[p + 1];
        }

        // Unmatched topics keep a similarity of 0.0.
        double[] matchedSims = new double[numTopicsl];
        for (int zl = 0; zl < numTopicsl; zl++) {
            int zr = mapping[zl];
            if (zr >= 0) {
                matchedSims[zl] = sims[zl][zr];
            }
        }

        correspondence = mapping;
        correspondingSims = matchedSims;
        return correspondence;
    }
    
    /**
     * Computes the optimal hierarchy-aware match, starting from the pair of
     * roots (topic 0 on both sides).
     *
     * The memoized subtree matches are discarded first. They are derived from
     * the current weighted similarities and trees, so results left over from
     * an earlier run would otherwise be returned unchanged after the models
     * or the weighting have been replaced.
     *
     * @return the match for the root pair
     */
    public TopicMatch getOptimalTreeMatch() {
        topicMatchMap.clear();
        return getOptimalSubtreeMatch(0, 0);
    }
    
    /**
     * Returns the best match of the subtrees below a pair of parent topics,
     * computing and memoizing it on first request.
     *
     * Assumes that the root is not a leaf.
     *
     * @param zlPar parent topic in the left tree
     * @param zrPar parent topic in the right tree
     * @return the memoized or newly calculated match
     */
    TopicMatch getOptimalSubtreeMatch(int zlPar, int zrPar) {
        TopicMatch result = subtreeMatches(zlPar, zrPar);
        if (result == null) {
            // Not seen before: calculate it and remember it.
            result = calculateTopicMatchHungarian(zlPar, zrPar);
            subtreeMatches(zlPar, zrPar, result);
        }
        return result;
    }
    
    /**
     * Calculates the best match between the children of two parent topics
     * with the Hungarian algorithm.
     *
     * The cost of a child pair is its weighted similarity when either child
     * is a leaf, otherwise the measure of the pair's own optimal subtree
     * match (obtained recursively and memoized).
     *
     * @param zlPar parent topic in the left tree
     * @param zrPar parent topic in the right tree
     * @return the match; {@code match} is null and {@code measure} 0.0 when
     *         the parents' weighted similarity is zero and neither parent is
     *         topic 0; {@code match} is empty and {@code measure} equals the
     *         parents' weighted similarity when either parent has no children
     */
    TopicMatch calculateTopicMatchHungarian(int zlPar, int zrPar) {
        // Require: the weighted similarities are already determined.
        // Require: parent similarity weight cannot be zero except for zero root.
        final double parentWtSim = wtSims[zlPar][zrPar];
        if (parentWtSim == 0.0 && zlPar != 0 && zrPar != 0) {
            TopicMatch rejected = new TopicMatch();
            rejected.match = null;
            rejected.measure = 0.0;
            return rejected;
        }

        final int[] zlvec = getChildTopicsl(zlPar);
        final int[] zrvec = getChildTopicsr(zrPar);
        if (zlvec == null || zrvec == null) {
            // A childless parent (e.g. a root that is a leaf) has nothing to
            // assign; the pair contributes only its own weight.
            TopicMatch childless = new TopicMatch();
            childless.match = new int[0];
            childless.measure = parentWtSim;
            return childless;
        }

        // Prepare the cost array for the Hungarian algorithm.
        double[][] costs = new double[zlvec.length][zrvec.length];
        for (int i = 0; i < zlvec.length; i++) {
            for (int j = 0; j < zrvec.length; j++) {
                int zl = zlvec[i];
                int zr = zrvec[j];
                if (isLeafl(zl) || isLeafr(zr)) {
                    // Child is a leaf. Get weight directly from wtSims,
                    // keeping a zero weight as +0.0 rather than -0.0.
                    costs[i][j] = wtSims[zl][zr] == 0.0 ? 0.0 : -wtSims[zl][zr];
                } else {
                    // Get weight from the (memoized) subtree match.
                    costs[i][j] = -getOptimalSubtreeMatch(zl, zr).measure;
                }
            }
        }

        // Get the optimal assignment of left children to right children.
        int[] map = new HungarianAlgorithm(costs).execute();

        // Convert the assignment into (left, right) pairs and add up the measure.
        int[] match = new int[map.length * 2];
        double measure = parentWtSim; // Include parent weight in subtree weight.
        int pos = 0;
        for (int i = 0; i < map.length; i++) {
            if (map[i] >= 0) {
                int zl = zlvec[i];
                int zr = zrvec[map[i]];
                // A memoized entry exists exactly for the pairs that were
                // costed through their subtree match; other pairs use wtSims.
                TopicMatch childMatch = subtreeMatches(zl, zr);
                double wt = childMatch != null ? childMatch.measure : wtSims[zl][zr];
                if (wt > 0.0) {
                    match[pos++] = zl;
                    match[pos++] = zr;
                    measure += wt;
                }
            }
        }
        if (pos < match.length) {
            match = Arrays.copyOf(match, pos);
        }

        TopicMatch topicMatch = new TopicMatch();
        topicMatch.match = match;
        topicMatch.measure = measure;
        return topicMatch;
    }
                
    /**
     * Builds the left-to-right topic correspondence from the memoized tree
     * matches, starting at the root pair, and records the raw similarity of
     * every matched pair.
     *
     * @return array indexed by left topic that holds the matched right topic,
     *         or {@code UNDEFINED} for unmatched left topics
     */
    public int[] retrieveOptimalMapping() {
        final int n = numTopicsl;
        correspondence = new int[n];
        correspondingSims = new double[n];
        Arrays.fill(correspondence, UNDEFINED);

        // Walk down from the two roots; this fills in correspondence.
        retrieveOptimalSubtreeMapping(0, 0);

        for (int zl = 0; zl < n; zl++) {
            final int zr = correspondence[zl];
            if (zr < 0) {
                continue;
            }
            correspondingSims[zl] = sims[zl][zr];
        }
        return correspondence;
    }
    
    /**
     * Records {@code tl -> tr} in the correspondence and then descends into
     * every child pair stored in the memoized match for (tl, tr), if any.
     *
     * @param tl topic in the left tree
     * @param tr topic in the right tree
     */
    private void retrieveOptimalSubtreeMapping( int tl, int tr) {
        correspondence[tl] = tr;

        final TopicMatch memo = subtreeMatches(tl, tr);
        final int[] pairs = memo == null ? null : memo.match;
        if (pairs == null) {
            // Nothing memoized for this pair, or the match carries no pairs.
            return;
        }
        for (int p = 0; p < pairs.length; p += 2) {
            retrieveOptimalSubtreeMapping(pairs[p], pairs[p + 1]);
        }
    }
    
    /**
     * Memoizes a match under the pair (zlPar, zrPar).
     *
     * The two topics are packed into one int key, left topic shifted up by
     * 16 bits; keys are unique only while zrPar fits into the lower 16 bits.
     */
    private void subtreeMatches(int zlPar, int zrPar, TopicMatch topicMatch) {
        final int packedKey = (zlPar << 16) | zrPar;
        topicMatchMap.put(packedKey, topicMatch);
    }
    
    /**
     * Looks up the match memoized under the pair (zlPar, zrPar).
     *
     * Uses the same packed key as the storing overload: left topic shifted
     * up by 16 bits, right topic in the lower bits.
     *
     * @return the memoized match, or null if there is none
     */
    private TopicMatch subtreeMatches(int zlPar, int zrPar) {
        final int packedKey = (zlPar << 16) | zrPar;
        return topicMatchMap.get(packedKey);
    }
    
    /** Reports whether the left-tree node of topic {@code z} is a leaf. */
    private boolean isLeafl(int z) {
        final Tree<Integer> node = reverseTøl[z];
        return node.isLeaf();
    }
    /** Reports whether the right-tree node of topic {@code z} is a leaf. */
    private boolean isLeafr(int z) {
        final Tree<Integer> node = reverseTør[z];
        return node.isLeaf();
    }
        
    /** Returns the child topics of {@code par} in the left tree, or null if it has none. */
    private int[] getChildTopicsl(int par) {
        return getChildTopics(reverseTøl[par]);
    }
    
    /** Returns the child topics of {@code par} in the right tree, or null if it has none. */
    private int[] getChildTopicsr(int par) {
        return getChildTopics(reverseTør[par]);
    }
    
    /**
     * Collects the values of a node's direct children, in iteration order.
     *
     * @param parTree the parent node
     * @return the child values, or null when the node has no children
     */
    private int[] getChildTopics(Tree<Integer> parTree) {
        final int n = parTree.getNumChildren();
        if (n == 0) {
            return null;
        }

        final int[] values = new int[n];
        final ListIterator<Tree<Integer>> children = parTree.childrenIterator();
        int filled = 0;
        while (filled < n) {
            values[filled] = children.next().getValue();
            filled++;
        }
        return values;
    }
    
    /**
     * Lists the values of all nodes of a tree in iteration order.
     *
     * The tree is processed twice (once by {@code getCount()} and once to
     * fill the array); a boxed list was judged equally wasteful.
     */
    private int[] getFlatTopics(Tree<Integer> tree) {
        final int[] values = new int[tree.getCount()];
        int next = 0;
        for (Tree<Integer> node : tree) {
            values[next++] = node.getValue();
        }
        return values;
    }

        
    /**
     * Identity weighting: the weighted similarities are a fresh copy of the
     * plain similarities in which every value below the censor limit is
     * replaced by 0.0.
     */
    void weightSimilarities() {
        final int rows = sims.length;
        final int cols = sims[0].length;
        wtSims = new double[rows][cols];

        final double[][] target = wtSims;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                final double s = sims[r][c];
                if (s < censorLimit) {
                    target[r][c] = 0.0;
                } else {
                    target[r][c] = s;
                }
            }
        }
    }
    
    /** Weighting functions that can be applied to the topic similarities. */
    public enum WTFUNS {
        IDENTITY,
        RATIO,
        SQRTRATIO,
        LOG1PRATIO
    }
    
    /**
     * Weights the similarities with the given function, taking the topic
     * frequencies from the models' {@code zmCounts} vectors.
     *
     * @param wtFun weighting function; {@code IDENTITY} only applies censoring
     */
    public void weightSimilarities( WTFUNS wtFun) {
        if (wtFun == WTFUNS.IDENTITY) {
            weightSimilarities();
            return;
        }
        weightSimilarities(wtFun, zmCountsl, zmCountsr);
    }
    
    /**
     * Weights the similarities by a function of the topics' relative
     * frequencies.
     *
     * Each count is first turned into its share of the side's total. The
     * similarity of pair (i, j) is then multiplied by the product of both
     * shares ({@code RATIO}, also used for any unlisted function), by its
     * square root ({@code SQRTRATIO}) or by its {@code log1p}
     * ({@code LOG1PRATIO}). Similarities below the censor limit become 0.0.
     * {@code IDENTITY} ignores the counts and only applies censoring.
     *
     * A side whose counts sum to zero gets a share of 0.0 for every topic
     * instead of dividing by zero, so no NaN or infinite weights result.
     *
     * @param wtFun weighting function
     * @param ml    counts per left topic; asserted to match the rows of the similarities
     * @param mr    counts per right topic; asserted to match the columns of the similarities
     */
    public void weightSimilarities( WTFUNS wtFun,  int[] ml, int[] mr) {
        if (wtFun == WTFUNS.IDENTITY) {
            weightSimilarities();
            return;
        }

        // Will use some version of the ratio.
        assert ml.length == sims.length;
        assert mr.length == sims[0].length;

        final double[] rl = countShares(ml);
        final double[] rr = countShares(mr);

        wtSims = new double[ml.length][mr.length];

        for (int i = 0; i < ml.length; i++) {
            for (int j = 0; j < mr.length; j++) {
                if (sims[i][j] < censorLimit) {
                    wtSims[i][j] = 0.0;
                    continue;
                }
                switch (wtFun) {
                    case SQRTRATIO:
                        wtSims[i][j] = sims[i][j] * Math.sqrt(rl[i]*rr[j]);
                        break;
                    case LOG1PRATIO:
                        wtSims[i][j] = sims[i][j] * Math.log1p(rl[i]*rr[j]);
                        break;
                    case RATIO:
                    default:
                        wtSims[i][j] = sims[i][j] * rl[i]*rr[j];
                }
            }
        }
    }

    /** Converts counts into shares of their total; all shares are 0.0 when the total is zero. */
    private static double[] countShares(int[] counts) {
        final double total = sumCounts(counts);
        final double[] shares = new double[counts.length];
        if (total == 0.0) {
            return shares;
        }
        for (int i = 0; i < counts.length; i++) {
            shares[i] = counts[i] / total;
        }
        return shares;
    }
    
    
    /**
     * Calculates the cosine similarity between every left and right topic.
     *
     * The models' count maps are first transposed into maps sized by number
     * of topics and number of types; the pairwise calculation is then done
     * on the transposed maps.
     */
    public void calculateCosineSimilarities() {
        final IdKeyCountsMap leftByTopic = new IdKeyCountsMap(numTopicsl, maxNumTypes);
        typeTopicCountsl.transposeIdKeyCounts(leftByTopic);

        final IdKeyCountsMap rightByTopic = new IdKeyCountsMap(numTopicsr, maxNumTypes);
        typeTopicCountsr.transposeIdKeyCounts(rightByTopic);

        calculateCosineSimilarities(leftByTopic, rightByTopic);
    }
    
    /**
     * Fills {@code sims} with the cosine similarity of every pair of a left
     * and a right topic; pairs in which either topic has no counts keep 0.0.
     *
     * @param topicIdCountsl counts per left topic
     * @param topicIdCountsr counts per right topic
     */
    void calculateCosineSimilarities(
      IdKeyCountsMap topicIdCountsl, IdKeyCountsMap topicIdCountsr) {

        sims = new double[numTopicsl][numTopicsr];

        for (int zl = 0; zl < numTopicsl; zl++) {
            final TIntIntMap leftMap = topicIdCountsl.keyCounts(zl);
            if (!leftMap.isEmpty()) {
                final int[] leftVec = fillCountsFromMap(leftMap);

                for (int zr = 0; zr < numTopicsr; zr++) {
                    final TIntIntMap rightMap = topicIdCountsr.keyCounts(zr);
                    if (!rightMap.isEmpty()) {
                        // The right vector is rebuilt for each left topic; redundant,
                        // but it avoids having to manage a store of vectors.
                        final int[] rightVec = fillCountsFromMap(rightMap);
                        sims[zl][zr] = cosineSimilarity(leftVec, rightVec);
                    }
                }
            }
        }
    }

    /** Totals the counts of every topic, using this instance's number of types. */
    int[] calculateTopicCounts(IdKeyCountsMap topicIdCounts) {
        final int numTypes = maxNumTypes;
        return calculateTopicCounts(topicIdCounts, numTypes);
    }
    
    /** Expands a sparse count map into a dense vector of this instance's number of types. */
    int[] fillCountsFromMap(TIntIntMap countsMap) {
        final int numTypes = maxNumTypes;
        return fillCountsFromMap(countsMap, numTypes);
    }
    
    // Service modules
    
    /**
     * Totals the counts of every topic.
     *
     * @param topicIdCounts counts per topic
     * @param numTypes      length of the dense vector each topic is expanded into
     * @return one total per topic; topics without counts stay at 0
     */
    static int[] calculateTopicCounts(IdKeyCountsMap topicIdCounts, int numTypes) {
        final int[] totals = new int[topicIdCounts.numIds()];
        final int numTopics = topicIdCounts.numIds();

        for (int z = 0; z < numTopics; z++) {
            final TIntIntMap idCounts = topicIdCounts.keyCounts(z);
            if (!idCounts.isEmpty()) {
                totals[z] = sumCounts(fillCountsFromMap(idCounts, numTypes));
            }
        }
        return totals;
    }
    
    
    /**
     * Adds up all entries of a count vector using plain int arithmetic.
     *
     * @param counts the values to add
     * @return their sum; 0 for an empty array
     */
    static int sumCounts(int[] counts) {
        return Arrays.stream(counts).sum();
    }
    
    /**
     * Expands a sparse count map into a dense vector.
     *
     * @param countsMap map from index to count
     * @param numTypes  length of the resulting vector; every key of the map
     *                  is used as an index into it
     * @return the dense vector, 0 where the map has no entry
     */
    static int[] fillCountsFromMap(TIntIntMap countsMap, int numTypes) {
        final int[] dense = new int[numTypes];
        for (TIntIntIterator entry = countsMap.iterator(); entry.hasNext(); ) {
            entry.advance();
            dense[entry.key()] = entry.value();
        }
        return dense;
    }
        
        
    /**
     * Calculates the cosine similarity of two count vectors.
     *
     * @param counts0 first vector; its length determines how many positions are compared
     * @param counts1 second vector; must be at least as long as {@code counts0}
     * @return the cosine of the angle between the vectors, or 0.0 when either
     *         vector has zero magnitude (the cosine is undefined there and
     *         the division would otherwise produce NaN)
     */
    static double cosineSimilarity(int[] counts0, int[] counts1) {
        double dotProduct = 0.0;
        double magnitude0 = 0.0;
        double magnitude1 = 0.0;

        for (int i = 0; i < counts0.length; i++) {
            double count0 = counts0[i];
            double count1 = counts1[i];
            dotProduct += count0 * count1;
            magnitude0 += count0 * count0;
            magnitude1 += count1 * count1;
        }

        // An all-zero vector has no direction; treat it as sharing nothing.
        if (magnitude0 == 0.0 || magnitude1 == 0.0) {
            return 0.0;
        }
        return dotProduct / (Math.sqrt(magnitude0) * Math.sqrt(magnitude1));
    }

    
    /**
     * Builds a lookup table from topic value to the tree node holding it.
     *
     * The array is created with component type {@code Tree} rather than the
     * runtime class of the root, so a node of any {@code Tree} subtype can be
     * stored without an {@code ArrayStoreException}.
     *
     * @param tø   tree whose nodes are indexed
     * @param size length of the table; every node value must be below it
     * @return the table; entries for values not present in the tree are null
     */
    @SuppressWarnings("unchecked") // Generic array creation; only Tree<Integer> nodes are stored.
    static private Tree<Integer>[] getReverseLookup(Tree<Integer> tø, int size) {
        Tree<Integer>[] reverseTø = (Tree<Integer>[]) new Tree[size];
        for (Tree<Integer> node : tø) {
            reverseTø[node.getValue()] = node;
        }
        return reverseTø;
    }

    /**
     * Prints the current correspondence to standard output as a single line
     * of {@code left=>right} entries; unmatched left topics are left out.
     */
    public void reportCorrespondences() {
        System.out.println("Correspondence model left index to model right:");

        // Build the human-readable list of correspondences.
        final StringBuilder mapping = new StringBuilder();
        for (int l = 0; l < correspondence.length; l++) {
            final int r = correspondence[l];
            if (r >= 0) {
                mapping.append(l).append("=>").append(r).append(", ");
            }
        }
        System.out.println(mapping.toString());
    }
    
    /**
     * Recalculates the average similarities and prints them, together with
     * the topic counts they are based on, to standard output.
     */
    public void reportSimilaritiesAverage() {
        System.out.println("\nAverage similarities.");
        calculateSimilaritiesAverage();

        final int[] counts = topicCounts;
        final double[] averages = avgSimilarities;

        String matchingTopics = "Matching number topics: " + counts[0]
            + "; avg similarity:" + averages[0];
        String matchingCount = "Matching total count:" + counts[1]
            + "; matching wt avg similarity:" + averages[1];
        String totalCount = "Total count:" + counts[2]
            + "; total wt avg similarity:" + averages[2];

        System.out.println(matchingTopics);
        System.out.println(matchingCount);
        System.out.println(totalCount);
    }
    
    /**
     * Calculates unweighted and frequency-weighted average similarities over
     * the matched topics and stores them in {@code topicCounts} and
     * {@code avgSimilarities}.
     *
     * The weight of a matched pair is the mean of the two topics' counts;
     * pairs with a weight that is not positive are not counted as topics.
     * When no pair qualifies, the first two averages are 0.0/0 and therefore NaN.
     */
    public void calculateSimilaritiesAverage() {
        topicCounts = new int[3];
        avgSimilarities = new double[3];

        int matchedTopics = 0;
        double matchedWeight = 0;
        double simSum = 0.0;
        double weightedSimSum = 0.0;

        // Topics with sim < censor could be excluded here.
        for (int z = 0; z < numTopicsl; z++) {
            final int zr = correspondence[z];
            if (zr < 0) {
                continue;
            }
            final double wt = (zCountsl[z] + zCountsr[zr]) / 2.0;
            if (wt > 0) {
                // If no frequency weight, then no topic.
                matchedTopics++;
                matchedWeight += wt;
                simSum += correspondingSims[z];
                weightedSimSum += correspondingSims[z] * wt;
            }
        }

        // Figures over the matched topics only.
        topicCounts[0] = matchedTopics;
        topicCounts[1] = (int) Math.round(matchedWeight);
        avgSimilarities[0] = simSum / matchedTopics;
        avgSimilarities[1] = weightedSimSum / matchedWeight;

        // Same weighted sum, but relative to the mean total count of both models.
        final double overallWeight = (sumCounts(zCountsl) + sumCounts(zCountsr)) / 2.0;
        topicCounts[2] = (int) Math.round(overallWeight);
        avgSimilarities[2] = weightedSimSum / overallWeight;
    }

    /**
     * How strictly the hierarchy is taken into account when calculating the
     * Jensen-Shannon divergence between the corpus0 and corpus1 topic
     * distributions from the topic correspondence and topic counts.
     *
     * Non-matching topics from distribution P are kept separate from
     * non-matching topics from distribution Q, so that each contributes to
     * the total JSD as (P_nm + Q_nm)/2. Note that the divergence is reported
     * on a log2 basis.
     */
    public enum JSDPolicy {
        Flat,
        Hierarchical,
        StrictHierarchical
    }
    
    /**
     * Calculates the divergence under each policy, stores the values in
     * {@code divergences} at positions 0, 2 and 4 and prints them to
     * standard output. The odd positions are left at 0.0.
     */
    public void reportJSDivergences() {
        System.out.println("Hierarchical divergence.");
        divergences = new double[6];

        final JSDPolicy[] policies = {
            JSDPolicy.Flat, JSDPolicy.Hierarchical, JSDPolicy.StrictHierarchical
        };
        final String[] labels = { "Flat", "Hierarchical", "Strict Hierarchical" };

        for (int p = 0; p < policies.length; p++) {
            final double[] jsd = calculateTopicJSDivergence(policies[p]);
            divergences[2 * p] = jsd[0];
            System.out.println(labels[p] + ": JSD = " + jsd[0]);
        }
    }

    /**
     * Calculates the Jensen-Shannon divergence between the topic frequencies
     * of the two models, aligned through the current correspondence.
     *
     * Both sides are laid out over the right model's topic indices. A left
     * topic that is unmatched, or whose match is rejected by the policy, is
     * added to one shared bucket at position numTopics. The counts of all
     * right topics that did not receive an accepted match are added to a
     * separate bucket at position numTopics+1, so that misses on the right
     * are penalized as a whole rather than as many small discrepancies.
     *
     * @param policy {@code Flat} accepts every match; the hierarchical
     *               policies accept a match only if the parents correspond
     * @return [0] the divergence; [1] the divergence with the root topic's
     *         counts zeroed when root maps to root, otherwise equal to [0]
     */
    public double[] calculateTopicJSDivergence(JSDPolicy policy) {
        final int numTopics = Math.max(numTopicsl, numTopicsr);
        final int leftBucket = numTopics;       // unassigned counts of l
        final int rightBucket = numTopics + 1;  // unassigned counts of r

        final int[] alignedl = new int[numTopics + 2];
        final int[] alignedr = new int[numTopics + 2];
        long unassignedr = sumCounts(zCountsr);

        for (int kl = 0; kl < numTopicsl; kl++) {
            final int kr = correspondence[kl];
            if (kr == UNDEFINED) {
                // Undefined either because not used or because there is no
                // match; an unused topic adds a count of zero.
                alignedl[leftBucket] += zCountsl[kl];
                continue;
            }

            boolean accepted = false;
            switch (policy) {
                case Flat:
                    accepted = true;
                    break;
                case Hierarchical:
                    accepted = matchOnHierarchy(kl, kr);
                    break;
                case StrictHierarchical:
                    accepted = strictMatchOnHierarchy(kl, kr);
                    break;
                default:
                    // No further policies are defined; nothing is counted.
                    continue;
            }

            if (accepted) {
                alignedl[kr] += zCountsl[kl];
                alignedr[kr] += zCountsr[kr];
                unassignedr -= zCountsr[kr];
            } else {
                alignedl[leftBucket] += zCountsl[kl];
            }
        }

        assert unassignedr >= 0;
        alignedr[rightBucket] = (int) unassignedr;

        final double withRoot = jsDivergence(alignedl, alignedr);

        // Adjusted JSD that ignores the contribution from the root. The root
        // should always be right, so this focuses on the harder part.
        final double withoutRoot;
        if (correspondence[0] == 0) {
            alignedl[0] = 0;
            alignedr[0] = 0;
            withoutRoot = jsDivergence(alignedl, alignedr);
        } else {
            withoutRoot = withRoot;
        }

        return new double[] { withRoot, withoutRoot };
    }
    
    /* calculateFlatTopicJSDivergence
        // Algorithm:
        // 1. Determine mapping for corpora 0 and 1. 
        //    1.1 numTopics = max(numTopics1, numTopics2)
        //    1.2 Corpus 0 topics mapped to 1; undefined summed in numTopics position. 
        //    1.3 Corpus 1 topics as given; unmapped stay in place.
        // 2. Calculate Jensen-Shannon divergence.
    
        calculateHierarchicalTopicJSDivergence
        // Algorithm:
        // 1. Determine mapping for corpora 0 and 1. 
        //    1.1 numTopics = max(numTopics1, numTopics2)
        //    1.2 Corpus 0 topics mapped to 1 where the hierarchy corresponds; 
        //        undefined and incorrectly mapped are summed in numTopics position.
        //        So the significant change from flat is with incorrectly mapped.
        //        Only the parent  is checked for a match so this gives credit for
        //        a partial match on hierarchy even if the next level up is incorrect.
        //        More demanding would be to insist the entire path has to match.
        //    1.3 Corpus 1 topics as given; unmapped stay in place.
        // 2. Calculate Jensen-Shannon divergence.
    */
    
    /**
     * Tests whether a matched pair of topics also agrees on the hierarchy,
     * looking only one level up.
     *
     * Topic 0 is treated as having no parent. Two parentless topics agree;
     * a parentless topic never agrees with one that has a parent. The second
     * rule must be checked before consulting the correspondence: the "no
     * parent" marker has the same value as the "unmatched" marker stored in
     * the correspondence, so an unmatched left parent would otherwise look
     * like a match with a parentless right topic.
     *
     * @param kl topic in the left tree
     * @param kr topic in the right tree that {@code kl} is matched to
     * @return true if both are parentless, or if the parent of {@code kl}
     *         is matched to the parent of {@code kr}
     */
    boolean matchOnHierarchy(int kl, int kr) {
        final int klPar = kl == 0 ? UNDEFINED : reverseTøl[kl].getParentValue();
        final int krPar = kr == 0 ? UNDEFINED : reverseTør[kr].getParentValue();

        if (klPar == UNDEFINED && krPar == UNDEFINED) {
            return true;  // Root node match.
        }
        if (klPar == UNDEFINED || krPar == UNDEFINED) {
            return false; // Root node mismatch.
        }
        return correspondence[klPar] == krPar;
    }

    /**
     * Tests whether a matched pair of topics agrees on the hierarchy all the
     * way up: at every level the left parent must be matched to the right
     * parent, and both paths must run out of parents at the same step.
     *
     * Topic 0 is treated as having no parent.
     *
     * @param kl topic in the left tree
     * @param kr topic in the right tree that {@code kl} is matched to
     * @return true if the complete parent paths correspond
     */
    boolean strictMatchOnHierarchy(int kl, int kr) {
        int left = kl;
        int right = kr;
        while (true) {
            final int leftPar = left == 0 ? UNDEFINED : reverseTøl[left].getParentValue();
            final int rightPar = right == 0 ? UNDEFINED : reverseTør[right].getParentValue();

            if (leftPar == UNDEFINED && rightPar == UNDEFINED) {
                return true;  // Root node match.
            }
            if (leftPar == UNDEFINED || rightPar == UNDEFINED) {
                return false; // Root node mismatch.
            }
            if (correspondence[leftPar] != rightPar) {
                return false;
            }
            // Parents correspond; repeat the test one level up.
            left = leftPar;
            right = rightPar;
        }
    }
 
}
