import argparse
import csv
import collections
import contextlib
import pathlib
import re
import statistics as st
import sys

def _compiled_regex(text):
    """argparse ``type`` callback: compile *text* as a regular expression.

    ``re.error`` is not one of the exceptions argparse turns into a usage
    error, so it is translated to ``ArgumentTypeError`` here; otherwise an
    invalid ``--log-re`` would surface as a raw traceback.
    """
    try:
        return re.compile(text)
    except re.error as exc:
        raise argparse.ArgumentTypeError(
            f'invalid regular expression {text!r}: {exc}'
        ) from exc


# Command-line interface.  The parser is module-level because other functions
# in this file report fatal problems through ``parser.error``.
parser = argparse.ArgumentParser(
    description='Aggregate the metrics found in the log files of each '
                'workspace into a CSV of per-metric averages and standard '
                'deviations.',
)
parser.add_argument(
    '--workspaces', nargs='+', type=pathlib.Path, required=True,
    help='one or more workspace directories to scan for log files',
)
parser.add_argument(
    '--log-glob', default='classification-log.txt',
    help='glob pattern, relative to each workspace, selecting the log files '
         '(default: %(default)s)',
)
parser.add_argument(
    '--log-re', type=_compiled_regex,
    help='regular expression matched from the start of the name of each file '
         'directly inside a workspace; when given it is used instead of '
         '--log-glob',
)
parser.add_argument(
    '--output', '-o',
    help='CSV file to write (default: standard output)',
)


def msg(*k, **kw):
    """Print to standard error, taking the same arguments as ``print``.

    A ``file`` keyword supplied by the caller is overridden.
    """
    print(*k, **{**kw, 'file': sys.stderr})


def find_logs(path, pattern, is_re=False):
    """Yield the regular files under *path* that *pattern* selects.

    When ``is_re`` is true, *pattern* is a compiled regular expression that
    is matched (anchored at the start) against the name of each direct child
    of *path*.  Otherwise *pattern* is a glob string passed to ``path.glob``.
    Entries that are not regular files are never yielded.
    """
    if is_re:
        for entry in path.iterdir():
            if entry.is_file() and pattern.match(entry.name):
                yield entry
        return

    for entry in path.glob(pattern):
        if entry.is_file():
            yield entry

# A metric line has the form "<name>: <value>".  Only the first
# whitespace-free token after the colon is captured as the value.
log_line_re = re.compile(r'([^:]+):\s*([^\s]+)')

# Metrics that every log file must provide.
expected_metrics = [
    'Accuracy',
    'MCC',
    'SC',
    'BH',
]


def parse_log(path):
    """Extract the numeric ``name: value`` pairs from the log file at *path*.

    Lines that do not match ``log_line_re`` are ignored, and so are matching
    lines whose value is not a valid float (for example ``Model: svm``), so
    that unrelated log output containing a colon does not abort the run.
    Metric names are stripped of surrounding whitespace.  If a name occurs
    more than once, the last occurrence wins.

    Returns a dict mapping each metric name to its float value.

    If any entry of ``expected_metrics`` is absent, the problem is reported
    through ``parser.error``, which terminates the program.
    """
    r = {}
    for line in path.read_text().splitlines():
        m = log_line_re.match(line)
        if not m:
            continue
        try:
            value = float(m.group(2))
        except ValueError:
            # Not a metric line; the pattern is loose enough to match any
            # text that merely contains a colon.
            continue
        r[m.group(1).strip()] = value

    # Sorted so the error message is stable from run to run.
    missing = sorted(set(expected_metrics) - set(r))
    if missing:
        missing_names = ', '.join(missing)
        parser.error(f'{path}: missing the following metrics: {missing_names}')

    return r

def run(args):
    """Write one CSV row of metric averages and deviations per workspace.

    *args* is the namespace produced by ``parser``; it must provide
    ``workspaces``, ``log_glob``, ``log_re`` and ``output``.

    For every workspace the matching log files are parsed and, for each
    entry of ``expected_metrics``, the mean and sample standard deviation
    across the log files are written.  With a single log file per workspace
    the value itself is used as the average and the deviation is 0, because
    ``statistics.stdev`` needs at least two data points.

    The CSV goes to ``args.output`` if set, otherwise to standard output.

    Fatal conditions (a workspace without log files, or workspaces with a
    differing number of log files) are reported through ``parser.error``,
    which terminates the program.
    """
    # A regular expression, when supplied, replaces the glob pattern.
    if args.log_re:
        is_re, pattern = True, args.log_re
    else:
        is_re, pattern = False, args.log_glob

    fieldnames = ['workspace']
    for k in expected_metrics:
        fieldnames += [f'{k}-avg', f'{k}-stdev']

    with contextlib.ExitStack() as exit_stack:
        if args.output:
            # Registered on the stack so the file is closed on every exit
            # path; sys.stdout is deliberately left open.
            csvfile = exit_stack.enter_context(
                open(args.output, 'w', newline='')
            )
        else:
            csvfile = sys.stdout

        csv_writer = csv.DictWriter(csvfile, fieldnames)
        csv_writer.writeheader()

        # Log-file count of the first workspace; all others must match it.
        number_of_log_files = None
        for path in args.workspaces:
            log_files = list(find_logs(path, pattern, is_re))

            if not log_files:
                parser.error(f'no log files found for {path}')

            if number_of_log_files is None:
                number_of_log_files = len(log_files)
            elif number_of_log_files != len(log_files):
                parser.error(f'inconsistent number of log files found for {path}: {len(log_files)}. It should be {number_of_log_files}')

            metrics = collections.defaultdict(list)
            for log_file in log_files:
                for k, v in parse_log(log_file).items():
                    metrics[k].append(v)

            row = {'workspace': path}
            # Only the expected metrics have CSV columns.  Emitting any other
            # key found in the logs would make DictWriter.writerow raise
            # ValueError for a field that is not in fieldnames.
            for k in expected_metrics:
                values = metrics[k]
                if len(values) == 1:
                    row[f'{k}-avg'] = values[0]
                    row[f'{k}-stdev'] = 0
                else:
                    row[f'{k}-avg'] = st.mean(values)
                    row[f'{k}-stdev'] = st.stdev(values)
            csv_writer.writerow(row)

if __name__ == '__main__':
    # Script entry point: parse the command line and produce the CSV.
    run(parser.parse_args())
