import json
import logging
import traceback
from ast import literal_eval
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import click
import pandas as pd
from tqdm import tqdm

# Parent of the directory that contains this module file (presumably the
# project root, per the name).
PROJECT_ROOT: Path = Path(__file__).parent.parent


def save_as_text(entries: List[str], out_path: Union[str, Path]):
    """Write every string in ``entries`` to ``out_path``, one per line.

    A newline is appended after each entry, including the last one. The file
    is opened in UTF-8 text mode and truncated if it already exists.
    """
    target = Path(out_path) if isinstance(out_path, str) else out_path
    with target.open(mode="w", encoding="utf-8") as handle:
        logging.info("Storing data as txt...")
        handle.writelines(line + "\n" for line in entries)


def save_as_json(entries: Union[List[Dict[Any, Any]], Dict[Any, Any], pd.DataFrame], out_path: Union[str, Path]):
    """Serialize ``entries`` to ``out_path`` as JSON or JSON Lines (UTF-8).

    The output format depends on the data type and the file name:

    * a list or DataFrame written to a ``*.jsonl`` file becomes JSON Lines,
      one record per line;
    * a list or DataFrame written to any other file becomes a single JSON
      array (a DataFrame is written as an array of row records);
    * a dict is always written as a single JSON object.

    Non-ASCII characters are written unescaped in every case.

    Raises:
        ValueError: if ``entries`` is not a list, dict or DataFrame.
    """
    if isinstance(out_path, str):
        out_path = Path(out_path)

    if isinstance(entries, (list, pd.DataFrame)):
        is_jsonl = out_path.name.endswith(".jsonl")
    elif isinstance(entries, dict):
        is_jsonl = False
    else:
        raise ValueError(f"Unknown type of data to store as JSON(L) to {out_path}")

    with out_path.open(mode="w", encoding="utf-8") as f:
        if is_jsonl:
            logging.info("Storing data as JSONL...")
            if isinstance(entries, pd.DataFrame):
                # Let pandas serialize frame values. `index` is not passed:
                # orient="records" never writes the index, and older pandas
                # versions reject index=False for this orient.
                entries.to_json(f, orient="records", lines=True, force_ascii=False)
            else:
                # Only non-DataFrame lists reach this loop; iterating a
                # DataFrame would yield its column names, not its rows.
                for entry in entries:
                    f.write(json.dumps(entry, ensure_ascii=False) + "\n")
        else:
            logging.info("Storing data as JSON...")
            if isinstance(entries, pd.DataFrame):
                # json.dump cannot serialize a DataFrame object directly.
                entries.to_json(f, orient="records", force_ascii=False)
            else:
                json.dump(entries, f, ensure_ascii=False)
    logging.info(f"Stored data in: {out_path}")


def load_json_as_dataframe(data_file_path: Union[str, Path]) -> pd.DataFrame:
    """Read a ``*.json`` file into a DataFrame with ``pd.read_json`` defaults.

    Args:
        data_file_path: Path to the file. A plain string is accepted too,
            matching ``load_json`` / ``load_jsonl``.

    Raises:
        AssertionError: if the file suffix is not ``.json``. The check uses
            ``assert``, so it is skipped under ``python -O``.
    """
    data_file_path = Path(data_file_path)
    assert ".json" == data_file_path.suffix, f"File {data_file_path} does not have JSON filename extention!"
    logging.info(f"Read data {data_file_path}...")
    return pd.read_json(path_or_buf=str(data_file_path))


def load_jsonl_as_dataframe(data_file_path: Union[str, Path]) -> pd.DataFrame:
    """Read a JSON Lines (``*.jsonl``) file, one record per line, into a DataFrame.

    Args:
        data_file_path: Path to the file. A plain string is accepted too,
            matching ``load_json`` / ``load_jsonl``.

    Raises:
        AssertionError: if the file suffix is not ``.jsonl``. The check uses
            ``assert``, so it is skipped under ``python -O``.
    """
    data_file_path = Path(data_file_path)
    assert ".jsonl" == data_file_path.suffix, f"File {data_file_path} does not have JSONL filename extention!"
    logging.info(f"Read data {data_file_path}...")
    return pd.read_json(path_or_buf=str(data_file_path), orient="records", lines=True)


def load_json(path: Union[str, Path]) -> Dict[Any, Any]:
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8, the same encoding ``save_as_json`` writes,
    so reading does not depend on the platform's locale encoding. The result
    is whatever ``json.load`` produces (an object in the common case).
    """
    path = Path(path)
    logging.info(f"Loading {path}...")
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def load_jsonl(path: Union[str, Path]) -> List[Any]:
    """Parse a UTF-8 JSON Lines file into a list, skipping undecodable lines.

    Whitespace-only lines are skipped silently. Any other line that fails to
    decode is logged at ERROR level, together with its 0-based line index and
    the traceback, and is then skipped.
    """
    path = Path(path)
    entries = []
    with path.open("r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if not line.strip():
                continue
            try:
                entries.append(json.loads(line.rstrip()))
            except json.JSONDecodeError:
                # logging.exception attaches the active traceback. The previous
                # logging.error(traceback.print_exc()) logged only "None",
                # because print_exc prints to stderr and returns None.
                logging.exception(f"Decoding error. Skip line {i}")
    return entries


def stream_jsonl_store_err_lines(path: Union[str, Path]) -> Any:
    """Lazily yield decoded JSON values from a JSON Lines file.

    Lines that fail to decode are counted on the tqdm progress bar postfix and
    collected. Once the generator has been fully consumed, they are written
    (one per line) to ``<stem>_erroneous_lines.txt`` in the input file's
    directory. That file is written even when there are no errors. If the
    consumer stops iterating early, it is not written at all.
    """
    # Normalize up front: `path.parent` / `path.stem` below need a Path, and
    # the signature also allows a str.
    path = Path(path)
    num_decoding_error = 0
    err_lines = []
    pbar = tqdm(stream_text_file(path))
    for line in pbar:
        try:
            record = json.loads(line.rstrip())
        except json.JSONDecodeError:
            num_decoding_error += 1
            pbar.set_postfix(dict(num_decoding_error=num_decoding_error))
            # Strip the line's own newline; save_as_text appends one per
            # entry, so keeping it would interleave blank lines.
            err_lines.append(line.rstrip("\n"))
            continue
        yield record
    save_as_text(err_lines, path.parent / (path.stem + "_erroneous_lines.txt"))


def stream_text_file(path: Union[str, Path]):
    """Lazily yield the lines of a UTF-8 text file, newline characters included.

    This is a generator: the file is opened (and the log line emitted) only
    when iteration starts, and it is closed when iteration finishes. UTF-8
    matches what the module's writers produce and keeps reading independent
    of the platform locale.
    """
    path = Path(path)
    logging.info(f"Streaming {path}...")
    with path.open("r", encoding="utf-8") as f:
        yield from f


def lima_csv_to_jsonl(file_path: Path):
    """Convert a ``*.csv`` file (presumably LIMA data, per the name) to ``*.jsonl`` beside it.

    The user is asked before an existing output file is overwritten. Before
    writing, the columns are cleaned:

    * missing ``comment`` values become ``""``;
    * missing ``highlight`` values become ``0``, stored as nullable ``Int64``;
    * each ``conversations`` cell is parsed from its Python-literal string
      form with ``ast.literal_eval``.
    """
    assert ".csv" == file_path.suffix, f"File {file_path} does not have CSV filename extention!"
    frame = pd.read_csv(filepath_or_buffer=file_path)
    target = file_path.with_suffix(".jsonl")
    _abort_if_user_forbids_override(target)
    frame = frame.assign(
        comment=frame.comment.fillna(value=""),
        highlight=frame.highlight.fillna(value=0).astype("Int64"),
        conversations=frame.conversations.apply(literal_eval),
    )
    records = frame.to_dict(orient="records")
    save_as_json(entries=records, out_path=target)


def lima_val_csv_to_jsonl(file_path: Path):
    """Convert a ``*.csv`` file (presumably the LIMA validation split) to ``*.jsonl`` beside it.

    Each output record keeps the CSV's ``source`` column and adds a
    ``conversations`` list ``[title, answer]`` built from that row. As in
    ``lima_csv_to_jsonl`` and ``jsonl_to_csv``, the user is asked before an
    existing output file is overwritten.
    """
    assert ".csv" == file_path.suffix, f"File {file_path} does not have CSV filename extention!"
    data = pd.read_csv(filepath_or_buffer=file_path)
    out_file_path = file_path.with_suffix(".jsonl")
    # Consistency with the sibling converters: never silently clobber output.
    _abort_if_user_forbids_override(out_file_path)
    data["conversations"] = [[t, a] for t, a in zip(data.title.to_list(), data.answer.to_list())]
    data = data[["source", "conversations"]]
    save_as_json(entries=data.to_dict(orient="records"), out_path=out_file_path)


def jsonl_to_csv(file_path: Path):
    """Convert a ``*.jsonl`` file into a ``*.csv`` file (no index column) beside it.

    The user is asked before an existing output file is overwritten. The input
    suffix must be ``.jsonl`` (checked by ``load_jsonl_as_dataframe``).
    """
    # load_jsonl_as_dataframe takes `data_file_path`; the previous
    # `filepath_or_buffer=` keyword always raised TypeError.
    data = load_jsonl_as_dataframe(file_path)
    out_file_path = file_path.with_suffix(".csv")
    _abort_if_user_forbids_override(out_file_path)
    data.to_csv(path_or_buf=out_file_path, index=False)


def _abort_if_user_forbids_override(out_file_path: Path):
    if out_file_path.exists():
        do_overwrite = click.confirm(f"Do you want to overwrite {out_file_path}?", default=True)
        if not do_overwrite:
            logging.info("Do not overwrite file. Abort.")
            exit()


def split_into_train_val_test(
    data: pd.DataFrame,
    *,
    train_frac: float = 0.8,
    val_frac_of_rest: float = 0.5,
    random_state: int = 200,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Randomly split ``data`` into train, validation and test partitions.

    A fraction ``train_frac`` of the rows goes to train. Of the remaining
    rows, a fraction ``val_frac_of_rest`` goes to validation and the rest to
    test. The defaults reproduce the previous fixed split (roughly 80/10/10,
    seed 200) exactly.

    Rows are selected by position, so a non-unique index no longer loses rows
    (``DataFrame.drop`` removes *every* row sharing a label). Each partition
    keeps the original index labels and row objects from ``data``.
    """
    # Sample positions via a RangeIndex-backed Series so that `drop` is exact.
    # `sample` picks positions from the row count and the seed alone, so the
    # chosen rows match what sampling `data` directly would pick.
    positions = pd.Series(range(len(data)))
    train_pos = positions.sample(frac=train_frac, random_state=random_state)
    rest_pos = positions.drop(train_pos.index)
    val_pos = rest_pos.sample(frac=val_frac_of_rest, random_state=random_state)
    test_pos = rest_pos.drop(val_pos.index)
    return (
        data.iloc[train_pos.to_numpy()],
        data.iloc[val_pos.to_numpy()],
        data.iloc[test_pos.to_numpy()],
    )
