import itertools
import pathlib
import re


def read_dataset(path):
    """Yield the records of the dataset stored at *path*.

    The reader is chosen from the file extension (compared
    case-insensitively): ``.xlsx`` files are delegated to ``read_xlsx``;
    every other extension is delegated to ``read_txt``. Both readers are
    called with their default options, and whatever they yield is passed
    through unchanged.

    Args:
        path: Location of the dataset; anything ``pathlib.Path`` accepts.
    """
    dataset_path = pathlib.Path(path)
    is_excel = dataset_path.suffix.lower() == '.xlsx'
    # Only the selected reader is looked up and invoked.
    reader = read_xlsx if is_excel else read_txt
    yield from reader(dataset_path)


def read_xlsx(path, label_field=None, text_field=None):
    """Yield ``(label, text)`` pairs, one per row of an Excel sheet.

    The first column of the sheet is turned into the row index, so it is
    no longer addressable as a regular column. When a field name is not
    given (or is falsy), it defaults to a position among the *remaining*
    columns: the label defaults to the first remaining column and the text
    to the second remaining column.

    Args:
        path: Excel file to load with ``pandas.read_excel``.
        label_field: Name of the column holding the label.
        text_field: Name of the column holding the text.

    Yields:
        A ``(label, text)`` tuple for every row, in sheet order.
    """
    # Imported lazily so pandas is only required when an Excel file is read.
    import pandas as pd

    frame = pd.read_excel(path)
    frame = frame.set_index(frame.columns[0])

    # Defaults are resolved after the index column has been removed.
    remaining = frame.columns
    label_field = label_field or remaining[0]
    text_field = text_field or remaining[1]

    for _, record in frame.iterrows():
        yield record[label_field], record[text_field]


def read_txt(path, strict=True):
    """Yield ``(label, text)`` pairs parsed from a plain-text file.

    Every line is stripped of surrounding whitespace and must then consist
    of a label (a run of non-whitespace characters), one or more whitespace
    characters, and the text (the rest of the line). Lines that do not fit
    this shape, including blank lines, are considered invalid.

    Args:
        path: Text file to read.
        strict: When true (the default), the first invalid line raises an
            error. When false, invalid lines are skipped silently.

    Yields:
        A ``(label, text)`` tuple of strings for every valid line.

    Raises:
        ValueError: If ``strict`` is true and an invalid line is found. The
            message names the file and the 1-based line number.
    """
    line_re = re.compile(r'(\S+)\s+(.*)')
    with open(path) as f:
        for lineno, line in enumerate(f, 1):
            m = line_re.match(line.strip())
            if m is None:
                if strict:
                    raise ValueError(
                        f'invalid line found in {path} at line {lineno}'
                    )
                # Non-strict mode: drop the malformed line instead of
                # dereferencing the missing match object.
                continue
            yield m.group(1), m.group(2)
