"""Statistics about the level of top & left header depth."""

import os
import json
import argparse

from processing.io_utils import load_tables
from processing.io_utils import load_anno_and_struc, pairing, load_anno_nsf
from processing.link import get_raw_linked_cells
from processing.base.structure_nsf import process


# %% general
def get_tree_depth(node):
    """Return the number of levels in the tree rooted at ``node``.

    Children are read from ``node.child_cells``. A node without children has
    depth 1; otherwise the depth is one more than its deepest child subtree.
    """
    child_depths = (get_tree_depth(child) for child in node.child_cells)
    return 1 + max(child_depths, default=0)

def get_words(text):
    """Split ``text`` into whitespace-delimited words.

    ``str.split()`` without a separator already splits on every run of
    whitespace (spaces, tabs, newlines, ...) and drops leading/trailing
    whitespace, so no further per-word splitting is needed.

    Args:
        text: the sentence to tokenize; ``None`` or ``""`` is accepted.

    Returns:
        list of words; empty for ``None`` or an all-whitespace string.
    """
    # Sub-sentences may be missing (None) in some annotations; treat them as
    # having no words instead of raising AttributeError.
    if not text:
        return []
    return text.split()

def get_word_num(text):
    """Return how many words ``get_words`` finds in ``text``."""
    return len(get_words(text))

# %% StatCan

def stat_can_sample(table_id, table, args):
    """Yield one statistics record per (table, sub-sentence) pair of a StatCan table.

    The annotation workbook ``{args.can_dir}/{table_id}.xlsx`` is parsed
    together with ``table``. When ``load_anno_and_struc`` returns ``None``,
    the generator ends without producing anything.

    Each record is a dict with:
        table_id, sub_sent_id: identify the sentence,
        top_depth / left_depth: depth of the deepest tree under the top /
            left header virtual roots (0 when there are no roots),
        linked_num: number of entries in the sub-sentence's schema link,
        word_num: number of words in the sub-sentence.
    """
    anno_path = os.path.join(args.can_dir, f'{table_id}.xlsx')
    parsed = load_anno_and_struc(anno_path, table_id, table, operator=None)
    if parsed is None:
        return
    base_structure, sentences = parsed

    for structure, subdict, sub_sent_id in pairing(base_structure, sentences):
        # The linked cells are not needed for these statistics. The call (and
        # its two-value unpacking) is kept so any error it raises still surfaces.
        _linked_cells, _ = get_raw_linked_cells(structure, subdict)
        top_depth = max(
            (get_tree_depth(root) for root in structure['top_header']['virtual_root']),
            default=0,
        )
        left_depth = max(
            (get_tree_depth(root) for root in structure['left_header']['virtual_root']),
            default=0,
        )
        yield {
            'table_id': table_id,
            'sub_sent_id': sub_sent_id,
            'top_depth': top_depth,
            'left_depth': left_depth,
            'linked_num': len(subdict['schema_link']),
            'word_num': get_word_num(subdict['sub_sent']),
        }

def stat_canada_annotations():
    """Collect per-sentence statistics for every annotated StatCan table.

    Tables are loaded from ``args.html_dir`` and visited in ascending
    table-id order. Records come from ``stat_can_sample``, and a progress
    line is printed every 100 records. Relies on the module-level ``args``
    created in the ``__main__`` block.

    Returns:
        list of statistics dicts.
    """
    table_dict = load_tables(args.html_dir)

    statset = []
    for table_id, table in sorted(table_dict.items(), key=lambda item: item[0]):
        for sample in stat_can_sample(table_id, table, args):
            if sample is None:
                break
            statset.append(sample)
            if len(statset) % 100 == 0:
                print(f'canada-stats size: {len(statset)}')
    return statset


# %% NSF

def stat_nsf_sample(nsf_table_id, nsf_page_id, nsf_anno_path, nsf_structure, args):
    """Yield one statistics record per (table, sub-sentence) pair of an NSF table.

    Sentences are read from ``nsf_anno_path``. When ``load_anno_nsf`` returns
    ``None``, the generator ends immediately. Header depths are taken directly
    from the structure's ``num_top_header_rows`` / ``num_left_header_cols``.
    ``nsf_page_id`` and ``args`` are not used here.
    """
    nsf_sentences = load_anno_nsf(anno_path=nsf_anno_path, operator=None)
    if nsf_sentences is None:
        return

    for structure, subdict, sub_sent_id in pairing(nsf_structure, nsf_sentences):
        # The result is not needed here. The call (and its two-value unpacking)
        # is kept so any error it raises still surfaces.
        _linked_cells, _data_paths = get_raw_linked_cells(structure, subdict)
        record = dict(
            table_id=nsf_table_id,
            sub_sent_id=sub_sent_id,
            top_depth=structure['num_top_header_rows'],
            left_depth=structure['num_left_header_cols'],
            linked_num=len(subdict['schema_link']),
            word_num=get_word_num(subdict['sub_sent']),
        )
        yield record

def stat_nsf_annotations():
    """Collect per-sentence statistics for every NSF table.

    Walks the page folders under ``args.raw_nsf_dir`` in sorted order. It
    skips page folders and table files whose (extension-less) name ends with
    ``N``, temporary files starting with ``~``, and files containing ``fig``.
    Each remaining table is converted with ``process``, and its records are
    gathered via ``stat_nsf_sample`` under the id ``10000 + global_table_id``.

    A table that raises is reported and skipped. Its partial records are
    discarded and it does not consume a table id. Relies on the module-level
    ``args`` created in the ``__main__`` block.

    Returns:
        list of statistics dicts.
    """
    statnsf = []
    global_table_id = 0
    num_error_case = 0
    for page_dirname in sorted(os.listdir(args.raw_nsf_dir)):
        if page_dirname.endswith('N'):
            continue
        page_dir = os.path.join(args.raw_nsf_dir, page_dirname)
        # NOTE: table order inside a page follows os.listdir (unsorted), and
        # that order determines global_table_id.
        for tname in os.listdir(page_dir):
            if tname.split('.')[0].endswith('N') or tname.startswith('~') or ('fig' in tname):
                continue

            try:
                nsf_structure, nsf_anno_path = process(
                    file_path=os.path.join(page_dir, tname),
                    file_name=tname,
                    page_id=page_dirname,
                    table_id=global_table_id,
                    new_nsf_dir=args.nsf_anno_dir,
                )
                # Materialize first so that a failure halfway through a table
                # does not leave a partial set of its records behind.
                table_samples = list(stat_nsf_sample(
                    nsf_table_id=10000 + global_table_id,
                    nsf_page_id=int(page_dirname),
                    nsf_anno_path=nsf_anno_path,
                    nsf_structure=nsf_structure,
                    args=args,
                ))
            except Exception as err:  # best effort: one bad table must not abort the run
                num_error_case += 1
                print(f"error processing {tname} on page {page_dir}: {err!r}")
                continue
            statnsf.extend(table_samples)
            global_table_id += 1

    if num_error_case:
        print(f'{num_error_case} NSF table(s) could not be processed.')
    return statnsf



# %% split 

def get_stat_key(sample_stat):
    """Build the bucket key of a sample from its header depths, e.g. ``'t2l1'``.

    An alternative bucketing by linked-cell count would be
    ``f"link{sample_stat['linked_num']}"``.
    """
    top = sample_stat['top_depth']
    left = sample_stat['left_depth']
    return 't{}l{}'.format(top, left)


def split_testset_by_hierarchy(test_path, stats_path, output_dir):
    """Split a JSON-lines test set into one file per statistics bucket.

    Each test sample is matched to its statistics record by
    ``(table_id, sub_sent_id)``. Samples without a record are dropped. The
    bucket key comes from ``get_stat_key``, and the samples of each bucket
    are written one JSON object per line to ``{output_dir}/{key}.json``.

    Args:
        test_path: JSON-lines file of test samples.
        stats_path: JSON-lines file of statistics records.
        output_dir: existing directory that receives the bucket files.
    """

    def _read_jsonl(path):
        # One JSON object per line; blank lines (e.g. a trailing newline) are skipped.
        with open(path, 'r', encoding='utf-8') as fr:
            return [json.loads(line) for line in fr if line.strip()]

    # Index the statistics once instead of scanning all of them for every
    # test sample. setdefault keeps the first record per key, matching a
    # first-match search.
    stat_index = {}
    for stat in _read_jsonl(stats_path):
        stat_index.setdefault((stat['table_id'], stat['sub_sent_id']), stat)

    test_dict = {}
    for sample in _read_jsonl(test_path):
        sample_stat = stat_index.get((sample['table_id'], sample['sub_sent_id']))
        if sample_stat is None:
            continue
        key = get_stat_key(sample_stat)
        if key not in test_dict:
            test_dict[key] = []
            print(f"got new [{key}]")
        test_dict[key].append(sample)

    for key, ksamples in test_dict.items():
        with open(os.path.join(output_dir, f'{key}.json'), 'w', encoding='utf-8') as fw:
            fw.writelines(json.dumps(ksam) + '\n' for ksam in ksamples)
        print(f'wrote {len(ksamples)} samples into {key}.json')
    


# %% count sentence per table

def count_can_sample(table_id, table, args):
    """Return how many annotated sentences a StatCan table has.

    Reads ``{args.can_dir}/{table_id}.xlsx`` together with ``table``.
    Returns ``None`` when ``load_anno_and_struc`` returns ``None``.
    """
    anno_path = os.path.join(args.can_dir, f'{table_id}.xlsx')
    loaded = load_anno_and_struc(anno_path, table_id, table, operator=None)
    if loaded is None:
        return None
    _structure, sentences = loaded
    return len(sentences)

def count_canada_annotations():
    """Count annotated sentences per StatCan table and print the average.

    Tables are loaded from ``args.html_dir`` and visited in ascending
    table-id order. Tables for which ``count_can_sample`` returns ``None``
    are skipped. Relies on the module-level ``args`` created in the
    ``__main__`` block.

    Returns:
        list of per-table sentence counts.
    """
    table_dict = load_tables(args.html_dir)

    statset = []
    for table_id, table in sorted(table_dict.items(), key=lambda item: item[0]):
        count = count_can_sample(table_id, table, args)
        if count is None:
            continue
        statset.append(count)
        if len(statset) % 100 == 0:
            print(f'canada-stats size: {len(statset)}')

    # Guard the average: with no annotated table the division would raise.
    if statset:
        print(f'#num sentence per table (can): {sum(statset)/len(statset):.2f}')
    else:
        print('#num sentence per table (can): n/a (no annotated tables found)')
    return statset


def count_nsf_sample(nsf_table_id, nsf_page_id, nsf_anno_path, nsf_structure, args):
    """Return the number of annotated sentences of one NSF table.

    ``nsf_anno_path`` is joined onto the dataset root, which is
    ``args.dataset_subdir`` or ``'../data/'`` when ``args`` has no such
    attribute. With the CLI default (``'../data'``) this gives the same path
    as the previously hard-coded root. Returns ``None`` when
    ``load_anno_nsf`` returns ``None``. The table id, page id and structure
    are not used here.
    """
    # NOTE(review): assumes the path returned by `process` is relative to the
    # dataset root, whereas stat_nsf_sample passes it to load_anno_nsf
    # unchanged. Confirm which of the two is right.
    data_root = getattr(args, 'dataset_subdir', '../data/')
    p = os.path.join(data_root, nsf_anno_path)
    nsf_sentences = load_anno_nsf(anno_path=p, operator=None)
    if nsf_sentences is None:
        return None
    return len(nsf_sentences)

def count_nsf_annotations():
    """Count annotated sentences per NSF table and print the average.

    Walks the page folders under ``args.raw_nsf_dir`` in sorted order, with
    the same skip rules as ``stat_nsf_annotations``: page folders and table
    files whose (extension-less) name ends with ``N``, ``~`` temporary files,
    and ``fig`` files. Each remaining table is converted with ``process`` and
    counted with ``count_nsf_sample``.

    Tables whose count is ``None`` are left out of the average but still
    advance the table id, as ``stat_nsf_annotations`` does, so both passes
    hand ``process`` the same ids. Tables that raise are reported and
    skipped. Relies on the module-level ``args`` created in the ``__main__``
    block.

    Returns:
        list of per-table sentence counts.
    """
    statnsf = []
    global_table_id = 0
    num_error_case = 0
    for page_dirname in sorted(os.listdir(args.raw_nsf_dir)):
        if page_dirname.endswith('N'):
            continue
        page_dir = os.path.join(args.raw_nsf_dir, page_dirname)
        for tname in os.listdir(page_dir):
            if tname.split('.')[0].endswith('N') or tname.startswith('~') or ('fig' in tname):
                continue

            try:
                nsf_structure, nsf_anno_path = process(
                    file_path=os.path.join(page_dir, tname),
                    file_name=tname,
                    page_id=page_dirname,
                    table_id=global_table_id,
                    new_nsf_dir=args.nsf_anno_dir,
                )
                num_sentences = count_nsf_sample(
                    nsf_table_id=10000 + global_table_id,
                    nsf_page_id=int(page_dirname),
                    nsf_anno_path=nsf_anno_path,
                    nsf_structure=nsf_structure,
                    args=args,
                )
            except Exception as err:  # best effort: one bad table must not abort the run
                num_error_case += 1
                print(f"error processing {tname} on page {page_dir}: {err!r}")
                continue
            global_table_id += 1
            if num_sentences is not None:
                statnsf.append(num_sentences)

    if num_error_case:
        print(f'{num_error_case} NSF table(s) could not be processed.')
    # Guard the average: with no counted table the division would raise.
    if statnsf:
        print(f'#num sent per table (nsf): {sum(statnsf)/len(statnsf):.2f}')
    else:
        print('#num sent per table (nsf): n/a (no annotated tables found)')
    return statnsf


# %% main

def main(args):
    """Print the average number of annotated sentences per table.

    NSF tables are counted first, then StatCan tables. Each helper prints its
    own average, and this function prints the combined one. ``args`` itself
    is not read here; the helpers use the module-level ``args``.

    The header-depth statistics pipeline is kept below, commented out. It
    gathers per-sentence statistics for both sources, prints the average word
    count, writes the records to ``args.stats_path`` and splits the test set
    with ``split_testset_by_hierarchy``.
    """
    # can_stats = stat_canada_annotations()
    # nsf_stats = stat_nsf_annotations()
    # all_stats = can_stats + nsf_stats
    # word_nums = [s['word_num'] for s in all_stats]
    # print(f'average word num per sentence: {sum(word_nums)/len(word_nums):.2f}')

    # with open(args.stats_path, 'w') as fw:
    #     for sample_stats in all_stats:
    #         line = json.dumps(sample_stats)
    #         fw.write(line + '\n')

    # split_testset_by_hierarchy(
    #     test_path=os.path.join(args.dataset_dir, 'test.json'),
    #     stats_path=args.stats_path,
    #     output_dir=args.dataset_dir,
    # )

    nsf_stats = count_nsf_annotations()
    can_stats = count_canada_annotations()
    all_stats = can_stats + nsf_stats
    # Guard the average: with no counted table the division would raise.
    if all_stats:
        print(f'#num sent per table (all): {sum(all_stats)/len(all_stats):.2f}')
    else:
        print('#num sent per table (all): n/a (no annotated tables found)')
    print('bye :)')

    


# %% agreement, all table files from StatCan

from processing.link import get_anno_linked_cells


def linearize_cell_coords(irow, icol, num_cols=128):
    """Map a (row, column) cell coordinate to a single integer id.

    The id is ``num_cols * irow + icol``. It is unique only while
    ``0 <= icol < num_cols``; a column outside that range would produce the
    id of a cell in a neighbouring row, so such columns are rejected rather
    than silently merged.

    Args:
        irow: row index of the cell.
        icol: column index of the cell.
        num_cols: row stride of the linearization (default 128, as before).

    Returns:
        the linearized integer id.

    Raises:
        ValueError: if ``icol`` is outside ``[0, num_cols)``.
    """
    if not 0 <= icol < num_cols:
        raise ValueError(
            f'column index {icol} outside [0, {num_cols}); linearized ids would collide'
        )
    return num_cols * irow + icol


def get_agree_sample(table_id, table, anno_path):
    """Summarize one annotator's annotation of a table for agreement scoring.

    Returns a dict with:
        'cells': sorted linearized positions (via ``linearize_cell_coords``)
            of every cell linked by any sub-sentence,
        'sub_sent': the non-None sub-sentences joined by tabs,
        'question': the non-None questions joined by tabs.
    When ``load_anno_and_struc`` returns ``None``, all three fields are empty.
    """
    loaded = load_anno_and_struc(anno_path, table_id, table, operator=None)
    if loaded is None:
        return {'sub_sent': '', 'question': '', 'cells': []}
    base_structure, sentences = loaded

    cell_positions = set()
    sub_sents = []
    questions = []
    for structure, subdict, _sub_sent_id in pairing(base_structure, sentences):
        linked_cells = get_anno_linked_cells(structure, subdict)
        for _link_kind, coord_map in linked_cells.items():
            for (irow, icol), _cell in coord_map.items():
                cell_positions.add(linearize_cell_coords(irow, icol))
        sub_sents.append(subdict['sub_sent'])
        questions.append(subdict['question'])

    joined_sents = '\t'.join(s for s in sub_sents if s is not None)
    joined_questions = '\t'.join(q for q in questions if q is not None)
    return {
        'cells': sorted(cell_positions),
        'sub_sent': joined_sents,
        'question': joined_questions,
    }


def calc_sample_kapas(anno_samples, key='cells'):
    """Placeholder for the inter-annotator Fleiss' kappa on linked cells.

    Gathers the non-empty ``key`` entries of the annotators' samples but does
    not compute a score yet; always returns ``None``.
    """
    gathered = [sample[key] for sample in anno_samples]
    instances = [cells for cells in gathered if len(cells)]  # TODO: score these
    return None


def calc_sample_bleu(anno_samples, key='sub_sent'):
    """Score how similarly different annotators wrote the same field.

    ``key`` is 'sub_sent' or 'question'. Empty texts are dropped. Every
    ordered pair of texts from two *different* annotators becomes one
    (prediction, reference) example, and BLEU is computed over all of them
    with ``bleu_scorer``. The score is printed and returned.

    Returns:
        the BLEU score, or ``None`` when fewer than two non-empty texts
        remain (there is nothing to compare).
    """
    sents = [s[key] for s in anno_samples]
    sents = [s for s in sents if len(s) > 0]

    # Compare each annotator only with the others. A text paired with itself
    # always matches perfectly and would inflate the agreement score.
    pairs = [
        (pred, ref)
        for i, pred in enumerate(sents)
        for j, ref in enumerate(sents)
        if i != j
    ]
    if not pairs:
        return None

    from experiment.utils.metrics import bleu_scorer

    bleu_dict = bleu_scorer.compute(
        predictions=[get_words(pred) for pred, _ in pairs],
        references=[[get_words(ref)] for _, ref in pairs],
    )
    print(f"BLEU-4: {bleu_dict['bleu']:.4f}")
    return bleu_dict['bleu']



def get_agree_dataset():
    """Gather every annotator's version of each StatCan table and score agreement.

    For each table in ``args.html_dir`` (ascending id order), looks for
    ``{table_id}.xlsx`` in every directory of ``args.anno_dirs``. Tables with
    no annotation file are skipped. Each found file is summarized with
    ``get_agree_sample``. Sentence and question BLEU are then computed with
    ``calc_sample_bleu``, and cell agreement with ``calc_sample_kapas``,
    which is currently a placeholder. Mean BLEU scores are printed at the end.

    Relies on the module-level ``args`` created in the ``__main__`` block.

    Returns:
        dict mapping table id to the list of per-annotator samples.
    """
    table_dict = load_tables(args.html_dir)

    agreeset = {}
    all_sbleu, all_qbleu = [], []
    for table_id, table in sorted(table_dict.items(), key=lambda item: item[0]):
        anno_paths = [os.path.join(d, f'{table_id}.xlsx') for d in args.anno_dirs]
        anno_paths = [ap for ap in anno_paths if os.path.exists(ap)]
        if not anno_paths:
            continue
        anno_samples = [get_agree_sample(table_id, table, anno_path) for anno_path in anno_paths]
        agreeset[table_id] = anno_samples
        print(f'[agree-set] has {len(anno_samples)} samples about table {table_id}')

        sbleu = calc_sample_bleu(anno_samples, key='sub_sent')
        if sbleu is not None:
            all_sbleu.append(sbleu)
        qbleu = calc_sample_bleu(anno_samples, key='question')
        if qbleu is not None:
            all_qbleu.append(qbleu)

        calc_sample_kapas(anno_samples, key='cells')

    # Guard the averages: no scored table would otherwise divide by zero.
    for label, scores in (('sentence', all_sbleu), ('question', all_qbleu)):
        if scores:
            print(f'average {label} bleu: {sum(scores)/len(scores):.3f}')
        else:
            print(f'average {label} bleu: n/a (no scores)')
    return agreeset




def calc_agreement(args):
    """Calculate the inter-annotator agreement scores.

    All tables come from the StatCan website. Each sub-directory of
    ``../data/agreement/`` holds one annotator's files. The directory list is
    stored on ``args`` (``all_anno_dir``, ``anno_dirs``) because
    ``get_agree_dataset`` reads it from the module-level ``args``.

    Planned metrics:
        1. for entity and quantity linking, use Fleiss' Kappa.
        2. for sentences and questions, use BLEU.

    Returns:
        the dict built by ``get_agree_dataset`` (table id -> per-annotator samples).
    """
    args.all_anno_dir = '../data/agreement/'
    # os.listdir returns entries in arbitrary order and includes plain files.
    # Keep only directories, and sort them for a deterministic annotator order.
    args.anno_dirs = sorted(
        d for d in os.listdir(args.all_anno_dir)
        if os.path.isdir(os.path.join(args.all_anno_dir, d))
    )
    print(f'[anno-dirs] {args.anno_dirs}')
    args.anno_dirs = [os.path.join(args.all_anno_dir, d) for d in args.anno_dirs]

    # collect the statcan dataset
    agree_dict = get_agree_dataset()

    # TODO: get all entity linkings
    # TODO: get all quantity linkings
    # TODO: get all sentences
    # TODO: get all questions
    return agree_dict



# %% major pipeline

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # Input / output locations. Most of these are joined onto --dataset_subdir
    # after parsing (see the derived paths below).
    parser.add_argument('--dataset_subdir', type=str, default='../data')
    parser.add_argument('--html_subdir', type=str, default='canada_html')
    parser.add_argument('--can_anno_dir', type=str, default='canada_annotations')
    parser.add_argument('--raw_nsf_subdir', type=str, default='raw_nsf')
    parser.add_argument('--nsf_anno_dir', type=str, default='new_nsf')
    parser.add_argument('--output_subpath', type=str, default='stats.json')
    # Dataset folder whose test.json is split by the (disabled) depth pipeline in main().
    parser.add_argument('--dataset_dir', type=str, default='../data/table-split/raw/concat/both0513')

    args = parser.parse_args()

    # Derived paths; the functions above read them from this module-level `args`.
    data_root = args.dataset_subdir
    args.html_dir = os.path.join(data_root, args.html_subdir)
    args.raw_nsf_dir = os.path.join(data_root, args.raw_nsf_subdir)
    args.can_dir = os.path.join(data_root, args.can_anno_dir)
    args.nsf_dir = os.path.join(data_root, args.nsf_anno_dir)
    args.stats_path = os.path.join(data_root, args.output_subpath)

    # Entry points: `main(args)` reports sentences per table (currently
    # disabled); `calc_agreement(args)` scores inter-annotator agreement.
    calc_agreement(args)
