import itertools
from collections import defaultdict
from typing import Dict, List

import pandas as pd

from align.config.app_config import DatasetName
from align.utils import PROJECT_ROOT, load_jsonl_as_dataframe, save_as_json


def format_to_lima(dataset_path: DatasetName):
    """Provide the requested dataset in LIMA format.

    Args:
        dataset_path: Despite its name, this is the ``DatasetName`` member
            selecting which dataset to process, not a filesystem path.

    Returns:
        For ``DatasetName.lima``: whatever ``load_jsonl_as_dataframe`` returns
        for ``data/LIMA/train.jsonl`` under ``PROJECT_ROOT``.
        For ``DatasetName.oasst1``: ``None`` -- the converted data is written
        to disk by ``_format_oasst_to_lima`` instead of being returned.

    Raises:
        ValueError: If ``dataset_path`` is neither of the supported datasets.
    """
    if dataset_path == DatasetName.lima:
        return load_jsonl_as_dataframe(data_file_path=PROJECT_ROOT / "data/LIMA/train.jsonl")
    elif dataset_path == DatasetName.oasst1:
        # Conversion happens purely as a side effect (files are written);
        # nothing is returned for this branch.
        _format_oasst_to_lima()
    else:
        # Name the offending value so a misconfiguration is easy to trace.
        raise ValueError(f"Unknown Dataset: {dataset_path!r}")


def _format_oasst_to_lima():
    """Convert the oasst1 ``train`` and ``val`` partitions into LIMA-style files.

    For each partition, ``data/oasst1/oasst1_<partition>.jsonl`` (relative to
    ``PROJECT_ROOT``) is read, every root-to-leaf path of every conversation
    tree is flattened into a list of utterance texts, and the result is
    handed to ``save_as_json`` with the output path
    ``oasst1_<partition>_in_lima_format.jsonl`` next to the input file.

    Each written entry has the shape
    ``{"conversations": [<text>, ...], "source": "oasst"}``.

    NOTE(review): the slicing below assumes that the rows of one conversation
    tree are stored contiguously and start with the tree's root (the row whose
    ``parent_id`` is null), and that the DataFrame index equals the row
    position -- confirm against the data files.
    """
    for partition in ["train", "val"]:
        data_file_path = PROJECT_ROOT / f"data/oasst1/oasst1_{partition}.jsonl"
        data = pd.read_json(path_or_buf=data_file_path, lines=True)
        root_indices = data[data.parent_id.isnull()].index.to_list()
        records = data.to_dict("records")

        # Pairing consecutive root indices alone never closes the final tree,
        # which would silently drop the last conversation tree (and everything
        # when there is only one root). Use the end of the data as the closing
        # boundary. With no roots at all this is a single boundary, so the
        # loop below simply does not run.
        tree_boundaries = root_indices + [len(records)]

        conversations = []
        for start_idx, end_idx in window(tree_boundaries):
            conversation_tree = records[start_idx:end_idx]
            forest_paths = __get_all_tree_paths(conversation_tree)
            conversations.extend(_get_conversations(forest_paths=forest_paths))

        conversations_in_lima_format = [
            {"conversations": conversation, "source": "oasst"} for conversation in conversations
        ]
        save_as_json(
            entries=conversations_in_lima_format,
            out_path=data_file_path.parent / (data_file_path.stem + "_in_lima_format.jsonl"),
        )


def window(seq, n=2):
    """Yield a sliding window (of width n) over the items of an iterable.

    s -> (s0, s1, ..., s[n-1]), (s1, s2, ..., sn), ...

    Each window is a tuple. Nothing is yielded when the iterable holds fewer
    than ``n`` items.
    """
    iterator = iter(seq)
    current = tuple(itertools.islice(iterator, n))
    if len(current) != n:
        # Fewer than n items were available, so the iterator is already
        # exhausted and there is no full window to produce.
        return
    yield current
    for item in iterator:
        # Slide by one: drop the oldest item, append the newest.
        current = (*current[1:], item)
        yield current


def __get_all_tree_paths(conversation_tree: List[Dict]) -> List[List[Dict]]:
    """Return every root-to-leaf path of one conversation tree.

    Args:
        conversation_tree: The utterances of a single tree. Each utterance is
            a dict with at least ``"message_id"`` and ``"parent_id"``; the
            root is the utterance whose ``"parent_id"`` is ``None``.

    Returns:
        One list per leaf (an utterance that is nobody's parent), holding the
        utterances from the root down to that leaf. Paths are ordered by the
        position of their leaf in ``conversation_tree``, so the result is
        reproducible across runs.

    Raises:
        KeyError: If an utterance references a ``parent_id`` that is not part
            of ``conversation_tree``.
    """
    utterance_by_id = {utterance["message_id"]: utterance for utterance in conversation_tree}
    parent_ids = {utterance["parent_id"] for utterance in conversation_tree if utterance["parent_id"] is not None}

    paths = []
    # Iterate the dict (insertion-ordered) rather than a set of leaf ids:
    # set iteration order of strings varies between interpreter runs, which
    # made the order of the generated conversations non-deterministic.
    for message_id, leaf in utterance_by_id.items():
        if message_id in parent_ids:
            continue  # has at least one reply, so it is an inner node
        # Walk from the leaf up to the root, then flip to root-first order.
        path = [leaf]
        parent_id = leaf["parent_id"]
        while parent_id is not None:
            parent = utterance_by_id[parent_id]
            path.append(parent)
            parent_id = parent["parent_id"]
        path.reverse()
        paths.append(path)
    # NOTE(review): the original author noted "the last conversation must not
    # be of the assistant!" -- nothing in this function enforces that; confirm
    # whether such paths are filtered elsewhere.
    return paths


def _get_conversations(forest_paths) -> List[List[str]]:
    """Reduce each path of utterances to the list of its ``"text"`` values.

    Args:
        forest_paths: An iterable of paths, each path being an iterable of
            utterance dicts that carry a ``"text"`` key.

    Returns:
        One list of texts per path, in the same order as the input.
    """
    conversations = []
    for path in forest_paths:
        conversations.append([utterance["text"] for utterance in path])
    return conversations
