import json
from typing import IO, Any, Dict, Iterator, List, Tuple, Union


def BIO_file_reader(
    f_handle: IO, existing_title: bool = False
) -> Iterator[Union[List[Tuple[str, str]], Tuple[List[Tuple[str, str]], str]]]:
    """Lazily read sentences from a token-per-line file.

    Parsing rules, applied to each line after stripping surrounding whitespace:

    * an empty line ends the current sentence (empty sentences are skipped);
    * a line starting with ``<title>`` sets the current title to the stripped
      remainder of the line; the title stays in effect for every following
      sentence until the next ``<title>`` line and is ``'Non'`` before the
      first one;
    * any other line is split on whitespace and its first two fields are
      stored as a ``(token, tag)`` tuple; additional fields are ignored.

    A trailing sentence that is not followed by a blank line is still yielded.

    Args:
        f_handle: Iterable of text lines, typically an open text file.
        existing_title: When true, yield ``(sentence, title)`` pairs instead
            of bare sentences.

    Yields:
        Each sentence as a list of ``(token, tag)`` tuples, or a
        ``(sentence, title)`` tuple when ``existing_title`` is true.

    Raises:
        IndexError: If a non-empty, non-title line has fewer than two
            whitespace-separated fields. Raised lazily, while iterating.
    """
    sentence: List[Tuple[str, str]] = []
    title = 'Non'
    for line_no, raw_line in enumerate(f_handle, start=1):
        line = raw_line.strip()
        if not line:
            if sentence:
                yield (sentence, title) if existing_title else sentence
            # Rebind instead of clearing so lists already handed out stay intact.
            sentence = []
        elif line.startswith('<title>'):
            title = line[7:].strip()
        else:
            items = line.split()
            if len(items) < 2:
                # Same exception type as plain indexing would raise, but with
                # enough context to locate the offending line.
                raise IndexError(
                    f'line {line_no}: expected at least two '
                    f'whitespace-separated fields, got {line!r}'
                )
            sentence.append((items[0], items[1]))

    if sentence:
        yield (sentence, title) if existing_title else sentence


def BIO_file_printer(f_handle: IO, sentence: List[Tuple[Any]]) -> None:
    """Write one sentence to ``f_handle``, followed by a blank separator line.

    Every entry of ``sentence`` becomes one output line whose fields are
    joined with tab characters, so all fields must be strings.
    """
    f_handle.writelines('\t'.join(fields) + '\n' for fields in sentence)
    # The empty line marks the end of the sentence.
    f_handle.write('\n')


def feature_reader(f_handle: IO) -> Iterator[List[Dict[str, Union[float, int]]]]:
    """Lazily read sentences of JSON-encoded feature records.

    Each non-empty line (after stripping whitespace) is decoded with
    ``json.loads`` and appended to the current sentence; an empty line ends
    the sentence. Empty sentences are skipped, and a trailing sentence with
    no terminating blank line is still yielded.

    Args:
        f_handle: Iterable of text lines, typically an open text file.

    Yields:
        One list of decoded JSON values per sentence.

    Raises:
        json.JSONDecodeError: If a non-empty line is not valid JSON. Raised
            lazily, while iterating.
    """
    sentence: List[Dict[str, Union[float, int]]] = []
    for raw_line in f_handle:
        line = raw_line.strip()
        if line:
            sentence.append(json.loads(line))
        elif sentence:
            yield sentence
            # Rebind instead of clearing so lists already handed out stay intact.
            sentence = []

    if sentence:
        yield sentence


def split_dev_from_train(input_file: str, out_path: str, split_ratio: float) -> None:
    """Randomly distribute the sentences of a BIO file over train and dev files.

    For every sentence read from ``input_file`` one ``random.random()`` value
    is drawn; the sentence goes to ``out_path + 'dev.txt'`` when the draw is
    below ``split_ratio`` and to ``out_path + 'train.txt'`` otherwise.
    ``out_path`` is used as a plain string prefix, no separator is inserted.
    A summary of the sentence counts is printed once all sentences are written.
    """
    import random

    with open(input_file, 'r', encoding='utf-8') as src, \
            open(out_path + 'train.txt', 'w', encoding='utf-8') as train_out, \
            open(out_path + 'dev.txt', 'w', encoding='utf-8') as dev_out:
        writers = {'train': train_out, 'dev': dev_out}
        counts = {'train': 0, 'dev': 0}
        for sent in BIO_file_reader(src):
            # Exactly one random draw per sentence decides its destination.
            split = 'dev' if random.random() < split_ratio else 'train'
            counts[split] += 1
            BIO_file_printer(writers[split], sent)

        n_train = counts['train']
        n_dev = counts['dev']
        n_total = n_train + n_dev
        print(f'total num: {n_total}, train sentence num: {n_train}, dev sentence num: {n_dev}')


def split_from_book(input_file: str, out_path: str, book_map: Dict, prefix: str = 'test') -> None:
    """Split a titled BIO file into one output file per title.

    For every ``(title, sub_dir)`` entry of ``book_map`` the file
    ``out_path + sub_dir + '/<prefix>.txt'`` is opened for writing. Each
    sentence of ``input_file`` is then appended to the file registered for
    the title that was in effect when the sentence was read.

    Args:
        input_file: Path of the BIO file containing ``<title>`` lines.
        out_path: String prefix prepended to every output path (no separator
            is inserted between it and the mapped directory name).
        book_map: Maps a title to the directory name used in the output path.
        prefix: Base name (without ``.txt``) of each output file.

    Raises:
        KeyError: If a sentence carries a title that is not a key of
            ``book_map``.
    """
    from contextlib import ExitStack

    # ExitStack closes every handle opened so far, even when a later open()
    # fails or an unknown title aborts the loop.
    with ExitStack() as stack:
        outs = {
            title: stack.enter_context(
                open(out_path + sub_dir + f'/{prefix}.txt', 'w', encoding='utf-8')
            )
            for title, sub_dir in book_map.items()
        }
        f = stack.enter_context(open(input_file, 'r', encoding='utf-8'))
        for sentence, title in BIO_file_reader(f, existing_title=True):
            BIO_file_printer(outs[title], sentence)


def generate_label_vocab(input_file: str, out_file: str) -> None:
    """Collect the distinct non-"O" labels of a BIO file and dump them as JSON.

    Labels are numbered from 0 in order of first appearance. Each label maps
    to the pair ``(index, 0)``, which JSON serialises as a two-element array.
    The result is written to ``out_file`` with 4-space indentation and
    without ASCII escaping.
    """
    with open(input_file, 'r', encoding='utf-8') as src, open(out_file, 'w', encoding='utf-8') as dst:
        vocab = {}
        for sent in BIO_file_reader(src):
            for entry in sent:
                label = entry[1]
                if label == "O" or label in vocab:
                    continue
                # The next free index always equals the current vocabulary size.
                vocab[label] = (len(vocab), 0)

        json.dump(vocab, dst, indent=4, ensure_ascii=False)


def split_by_finegrain(input_file: str, out_path: str, fine_grain: List[str]) -> None:
    """Split a titled BIO file into one ``.bio`` file per fine-grained name.

    For every ``name`` in ``fine_grain`` the file
    ``out_path + '/test_<name>.bio'`` is opened for writing. Each sentence of
    ``input_file`` is then appended to the file whose name equals the title
    that was in effect when the sentence was read.

    Args:
        input_file: Path of the BIO file containing ``<title>`` lines.
        out_path: Directory (string prefix) in which the output files are
            created.
        fine_grain: Names expected to occur as titles in ``input_file``.

    Raises:
        KeyError: If a sentence carries a title that is not listed in
            ``fine_grain``.
    """
    from contextlib import ExitStack

    # ExitStack closes every handle opened so far, even when a later open()
    # fails or an unknown title aborts the loop.
    with ExitStack() as stack:
        outs = {
            name: stack.enter_context(
                open(out_path + f'/test_{name}.bio', 'w', encoding='utf-8')
            )
            for name in fine_grain
        }
        f = stack.enter_context(open(input_file, 'r', encoding='utf-8'))
        for sentence, title in BIO_file_reader(f, existing_title=True):
            BIO_file_printer(outs[title], sentence)

