"""Input/Output utility functions for data preprocessing."""


import os
import json
import random
from bs4 import BeautifulSoup
from openpyxl import load_workbook

from processing.base.crawler import Table
from processing.base.structure import table_to_struc
from processing.annotation.read import read_table_anno



# %% data load
def load_tables(html_path, verbose=True):
    """Parse every file of a directory into a `Table`, keyed by table id.

    Every directory entry is read as UTF-8 text and parsed with the
    'html.parser' backend. The part of the file name before the first dot
    is converted with `int` and used as the table id.

    Args:
        html_path: directory containing the html files.
        verbose: when equal to True, print how many tables were collected.

    Returns:
        dict mapping the integer table id to the `Table` built from its file.
    """
    tables = {}

    for entry in os.listdir(html_path):
        with open(os.path.join(html_path, entry), 'r', encoding='utf-8') as handle:
            markup = handle.read()
        soup = BeautifulSoup(markup, 'html.parser')

        # the file-name stem is the numeric table id
        tid = int(entry.split('.')[0])
        parsed = Table(tid, None, None)
        parsed.set_html(soup)
        tables[tid] = parsed

    if verbose == True:
        print(f"Collected {len(tables)} Table Candidates.")
    return tables


# %% data dump

def calc_thresholds(portion, epsilon=1e-6):
    """Turn three split weights into two cumulative probability cut-offs.

    A uniform draw ``p`` is routed to train when ``p < thres1``, to valid
    when ``p < thres2`` and to test otherwise.

    Args:
        portion: three weights, in train / valid / test order.
        epsilon: added to the weight sum so an all-zero `portion` does not
            divide by zero (both thresholds are then 0.0).

    Returns:
        The tuple ``(thres1, thres2)``.

    Raises:
        AssertionError: if `portion` does not hold exactly three entries.
    """
    assert len(portion) == 3
    denominator = sum(portion) + epsilon
    thres1 = portion[0] / denominator
    thres2 = (portion[0] + portion[1]) / denominator
    print(f'[io_utils >> calc_thres] has portion {portion} and thres [{thres1:.2f}, {thres2:.2f}]')
    return thres1, thres2


def split_dataset_by_key(dataset, key, portion=(0.8, 0.1, 0.1), shuffle=True, verbose=True):
    """Split dict-type samples so that samples sharing `key` stay together.

    Samples are grouped by ``sample[key]``; each group is assigned as a whole
    to train, valid or test by a random draw against the thresholds derived
    from `portion`. Once a split holds at least its target number of samples
    its weight is set to zero, so later groups are steered to the other
    splits. When all three weights are zero, remaining groups go to test.

    'key' currently takes: 'sub_sent_id', 'table_id', 'url_id'.

    Args:
        dataset: sequence of dict samples, each containing `key`.
        key: name of the field to group by.
        portion: three weights for train / valid / test. It is never
            modified; the function works on a private copy.
        shuffle: accepted for interface compatibility; not used here.
        verbose: when true, print the resulting split sizes.

    Returns:
        The tuple ``(trainset, validset, testset)`` of sample lists.
    """
    # Private copy: the weights are zeroed below as splits fill up, which
    # must not leak into the caller's sequence or into the shared default.
    weights = list(portion)
    n = len(dataset)

    # group samples based on their 'key' value
    key_samples_dict = {}
    for sample in dataset:
        key_samples_dict.setdefault(sample[key], []).append(sample)
    print(f'[split] got {len(key_samples_dict)} categories based on key {key}')

    # target sizes; test takes whatever integer truncation leaves over
    nmax_train = int(n * weights[0])
    nmax_valid = int(n * weights[1])
    nmax_test = n - nmax_train - nmax_valid
    capacities = (nmax_train, nmax_valid, nmax_test)

    trainset, validset, testset = [], [], []
    splits = (trainset, validset, testset)

    thres1, thres2 = calc_thresholds(weights)

    for ksamples in key_samples_dict.values():
        prob = random.random()
        if prob < thres1:
            target = 0    # train
        elif prob < thres2:
            target = 1    # valid
        else:
            target = 2    # test
        splits[target].extend(ksamples)

        # close the split that just reached its target and re-derive the
        # thresholds from the remaining weights
        if len(splits[target]) >= capacities[target]:
            weights[target] = 0.0
            thres1, thres2 = calc_thresholds(weights)

    if verbose:
        print(f'[io_utils >> split_dataset] results {len(trainset)} / {len(validset)} / {len(testset)} samples.')
    return trainset, validset, testset


def get_table_page_id(dataset, pair_path):
    """Attach a page-level 'url_id' to each sample, looked up by 'table_id'.

    The table-url-pairs workbook must contain a worksheet named 'Sheet' whose
    first row is a header with 'table id' in column 2 and 'url' in column 4
    (compared case-insensitively). Every distinct url receives a consecutive
    integer id, starting at 0 in order of first appearance.

    Args:
        dataset: list of dict samples, each holding a 'table_id' that is
            looked up against the integer table ids read from the workbook.
        pair_path: path of the table-url-pairs xlsx file.

    Returns:
        The same `dataset` list, updated in place. A sample whose table id
        is not listed in the workbook is reported on stdout and is left
        without a 'url_id' entry.

    Raises:
        AssertionError: if a header cell does not carry the expected name.
    """
    wb = load_workbook(filename=pair_path)
    ws = wb['Sheet']
    table_id_column_name = ws.cell(1, 2).value
    assert table_id_column_name.lower() == 'table id', f"unexpected table-id column name: {table_id_column_name}"
    url_column_name = ws.cell(1, 4).value
    assert url_column_name.lower() == 'url', f"unexpected url column name: {url_column_name}"

    url_id_map = {}      # url (str) -> url id (int)
    tid_uid_dict = {}    # table id (int) -> url id (int)
    # `max_row` is the 1-based index of the last row, so the range has to
    # include it; stopping at `max_row` would drop the final record.
    for irow in range(2, ws.max_row + 1):
        # read 'table id' and 'url' annotations
        table_id = ws.cell(irow, 2).value
        url = ws.cell(irow, 4).value
        if table_id is None and url is None:
            continue    # fully blank row, nothing to record
        url = url.strip()

        # a url seen for the first time gets the next free id
        url_id = url_id_map.setdefault(url, len(url_id_map))

        # save table-id and url-id record
        tid_uid_dict[int(table_id)] = url_id
    print(f'[io_utils >> get_table_page_id] reads {len(url_id_map)} urls of {len(tid_uid_dict)} tables.')

    # @USER TODO: how many tables & sub-sentences are there in a page?

    for sample in dataset:
        tid = sample['table_id']
        if tid not in tid_uid_dict:
            print(f'dataset tid [{tid}] of type [{type(tid)}]')
            continue
        sample['url_id'] = tid_uid_dict[tid]
    return dataset



def get_table_domain_id(dataset, pair_path, column_index=5):
    """Attach a 'domain_id' to each sample, looked up by 'table_id'.

    The table-url-pairs workbook must contain a worksheet named 'Sheet' whose
    first row is a header with 'table id' in column 2 and 'domain' in column
    `column_index` (compared case-insensitively). Every distinct domain text
    receives a consecutive integer id, starting at 1 in order of first
    appearance; id 0 is reserved for 'nsf' tables.

    Args:
        dataset: list of dict samples, each holding a 'table_id' that is
            looked up against the integer table ids read from the workbook.
        pair_path: path of the table-url-pairs xlsx file.
        column_index: 1-based index of the 'domain' column.

    Returns:
        The same `dataset` list, updated in place. Samples with a table id
        of 10000 or more are skipped untouched. A smaller table id that is
        not listed in the workbook is reported on stdout and the sample is
        left without a 'domain_id' entry.

    Raises:
        AssertionError: if a header cell does not carry the expected name.
    """
    wb = load_workbook(filename=pair_path)
    ws = wb['Sheet']
    table_id_column_name = ws.cell(1, 2).value
    assert table_id_column_name.lower() == 'table id', \
        f"unexpected table-id column name: {table_id_column_name}"
    domain_column_name = ws.cell(1, column_index).value
    assert domain_column_name.lower() == 'domain', \
        f"unexpected domain column name: {domain_column_name}"

    domain_id_map = {}     # domain text (str) -> domain id (int, from 1)
    tid_did_dict = {}      # table id (int) -> domain id (int)
    # `max_row` is the 1-based index of the last row, so the range has to
    # include it; stopping at `max_row` would drop the final record.
    for irow in range(2, ws.max_row + 1):
        # read 'table id' (int) and 'domain' (text) annotations
        table_id = ws.cell(irow, 2).value
        domain = ws.cell(irow, column_index).value
        if table_id is None and domain is None:
            continue    # fully blank row, nothing to record
        domain = domain.strip()

        # a domain seen for the first time gets the next free id;
        # ids start at 1 because index '0' is used by 'nsf' tables
        domain_id = domain_id_map.setdefault(domain, len(domain_id_map) + 1)

        # save table-id and domain-id record
        tid_did_dict[int(table_id)] = domain_id
    print(f'[io_utils >> get_table_domain_id] reads {len(domain_id_map)} domain of {len(tid_did_dict)} tables.')
    print(f'got domains: {list(domain_id_map.keys())}')

    for sample in dataset:
        tid = sample['table_id']
        # NOTE(review): ids >= 10000 are presumably the 'nsf' tables, which
        # would need to carry 'domain_id' 0 already -- confirm with the caller.
        if tid >= 10000:
            continue
        if tid not in tid_did_dict:
            print(f'dataset tid [{tid}] of type [{type(tid)}]')
            continue
        sample['domain_id'] = tid_did_dict[tid]
    return dataset



def split_dataset(dataset, pair_path, method, portion=(0.8, 0.1, 0.1), shuffle=True, verbose=True):
    """Split the dataset into train / valid / test based on the given method.

    Supported methods:
        'subsent': contiguous slices of the (optionally shuffled) dataset.
        'table':   samples sharing a 'table_id' stay in the same split.
        'page':    samples sharing a 'url_id' stay in the same split; the
                   ids are read from the workbook at `pair_path`.
        'domain':  only samples with a non-zero 'domain_id' are split,
                   grouped by domain; samples with 'domain_id' 0 are dropped.
        'website': samples with a non-zero 'domain_id' are split into train
                   and valid with weights 0.89 / 0.11; samples with
                   'domain_id' 0 form the test set.

    Args:
        dataset: list of dict samples. It is shuffled IN PLACE when
            `shuffle` is true.
        pair_path: path of the table-url-pairs xlsx file; only read by the
            'page', 'domain' and 'website' methods.
        method: one of the names listed above.
        portion: three weights for train / valid / test. It is never
            modified; the function works on a private copy.
        shuffle: whether to shuffle `dataset` before splitting.
        verbose: whether to print progress information.

    Returns:
        The tuple ``(trainset, validset, testset)``, or None (after printing
        a message) when `method` is not recognised.
    """
    # Private copy: the key-based splitter zeroes weights while it runs,
    # which must not leak into the caller's sequence or the shared default.
    portion = list(portion)

    n = len(dataset)
    if verbose:
        print(f"collect {n} samples in the dataset.")
    if shuffle:
        random.shuffle(dataset)

    if method == 'subsent':
        nbunch1, nbunch2 = int(n * portion[0]), int(n * portion[1])
        trainset = dataset[: nbunch1]
        # valid is sized by portion[1]; test takes the remainder, matching
        # how the key-based splitter derives its target sizes
        validset = dataset[nbunch1: nbunch1 + nbunch2]
        testset = dataset[nbunch1 + nbunch2:]
    elif method == 'table':
        trainset, validset, testset = split_dataset_by_key(
            dataset, key='table_id',
            portion=portion, shuffle=shuffle, verbose=verbose)
    elif method == 'page':
        dataset = get_table_page_id(dataset, pair_path=pair_path)
        trainset, validset, testset = split_dataset_by_key(
            dataset, key='url_id',
            portion=portion, shuffle=shuffle, verbose=verbose)
    elif method in ('domain', 'website'):
        dataset = get_table_domain_id(dataset, pair_path=pair_path)
        # domain id 0 marks 'nsf' samples, any other id a 'canada' domain
        cnd_dataset = [sample for sample in dataset if sample['domain_id'] != 0]
        nsf_dataset = [sample for sample in dataset if sample['domain_id'] == 0]
        if method == 'domain':     # use only the canada for train/validation/test
            print(f'[split] by [domain] use canada {len(cnd_dataset)} and dump nsf {len(nsf_dataset)}')
            trainset, validset, testset = split_dataset_by_key(
                cnd_dataset, key='domain_id',
                portion=portion, shuffle=shuffle, verbose=verbose)
        else:                      # 'canada' as train&validation, 'nsf' as test
            print(f'[split] by website has canada {len(cnd_dataset)} and nsf {len(nsf_dataset)}')
            trainset, validset, _ = split_dataset_by_key(
                cnd_dataset, key='domain_id',
                portion=[0.89, 0.11, 0.0], shuffle=shuffle, verbose=verbose)
            testset = nsf_dataset
            print(f'[split] by website has train {len(trainset)}, valid {len(validset)}, test {len(testset)}.')
    else:
        print(f'[io_utils >> split_dataset] gets unexpected method {method}')
        return
    return trainset, validset, testset


def _write_jsonl(path, samples):
    """Write `samples` to `path`, one JSON document per line."""
    with open(path, 'w') as fw:
        fw.writelines(json.dumps(sample) + '\n' for sample in samples)


def dump_dataset(
    dataset, pair_path,
    data_dir, sample_dir=None,
    names=('train.json', 'valid.json', 'test.json'),
    method='table',
    shuffle=True, verbose=True
):
    """Split train/valid/testset, output sample files if subdir specified.

    The three splits are written into `data_dir` as JSON-lines files. When
    `sample_dir` is given, the first 10 samples of each split are also
    written there under the same file names.

    Args:
        dataset: list of JSON-serialisable dict samples.
        pair_path: path of the table-url-pairs xlsx file, forwarded to
            `split_dataset`.
        data_dir: directory receiving the full split files.
        sample_dir: optional directory receiving the 10-sample previews.
        names: file names for the train / valid / test files, in that order.
        method: splitting method forwarded to `split_dataset`;
            choices = ['subsent', 'table', 'page', 'domain', 'website'].
        shuffle: forwarded to `split_dataset`.
        verbose: whether to print a summary of what was written.

    Raises:
        TypeError: if `method` is not recognised, because `split_dataset`
            then returns None, which cannot be unpacked.
    """
    trainset, validset, testset = split_dataset(
        dataset, pair_path, method,
        shuffle=shuffle, verbose=verbose)

    train_name, valid_name, test_name = names
    named_splits = (
        (train_name, trainset),
        (valid_name, validset),
        (test_name, testset),
    )

    for file_name, samples in named_splits:
        _write_jsonl(os.path.join(data_dir, file_name), samples)

    if sample_dir is not None:
        # small preview files: the first 10 samples of every split
        for file_name, samples in named_splits:
            _write_jsonl(os.path.join(sample_dir, file_name), samples[:10])

    if verbose:
        print(f"[io_utils >> dump_dataset] writes {len(trainset)} train, {len(validset)} valid, {len(testset)} test samples.")
        if sample_dir is not None:
            print("[io_utils >> dump_dataset] writes 10 resp samples for test.")
    return




# %% annotation

def load_anno_nsf(anno_path, operator):
    """Read sentence annotations from the refined xlsx file.

    Args:
        anno_path: path of the annotation file.
        operator: passed through unchanged to `read_table_anno`.

    Returns:
        The sentences returned by `read_table_anno`, or None when the file
        does not exist or no sentence was read from it.
    """
    if os.path.exists(anno_path):
        parsed = read_table_anno(anno_path, operator)
        if len(parsed) != 0:
            return parsed
    return None


def load_anno_and_struc(anno_path, table_id, table, operator):
    """Try to read the sentence annotations and the html structure.

    Args:
        anno_path: path of the annotation file.
        table_id: id of the table; not used by this function.
        table: table object handed to `table_to_struc`.
        operator: passed through unchanged to `read_table_anno`.

    Returns:
        The tuple ``(structure, sentences)``, or None when the annotation
        file does not exist or no sentence was read from it.
    """
    # no annotation file -> nothing to pair with the table
    if not os.path.exists(anno_path):
        return None

    # read and parse annotations; bail out when none is usable
    parsed = read_table_anno(anno_path, operator)
    if len(parsed) == 0:
        return None

    # read html and parse structure
    return table_to_struc(table), parsed


def get_sub_sent_id(desc_sent_id, isub):
    """Build a sub-sentence id of the form '<desc_sent_id>-<isub>'."""
    return '{}-{}'.format(desc_sent_id, isub)

def pairing(structure, sentences):
    """Iteratively yield struc-subsent pairs, for OpCheck inputs.

    Args:
        structure: object yielded unchanged with every sub-sample.
        sentences: iterable of dicts; each is read for its
            'table_desc_sent_id' and 'sub_samples' entries.

    Yields:
        ``(structure, subdict, sub_sent_id)`` for every sub-sample that is
        not None, where `sub_sent_id` is built from the sentence id and the
        position of the sub-sample inside 'sub_samples'.
    """
    for sentence in sentences:
        parent_id = sentence['table_desc_sent_id']
        for position, sub in enumerate(sentence['sub_samples']):
            # None entries still occupy a position but are not emitted
            if sub is None:
                continue
            yield structure, sub, get_sub_sent_id(parent_id, position)

