import json
from pathlib import Path

import torch
import typer
from baseline_model import (BaselineModelMultiply
                            )
from datasets import Dataset
from transformers import (AutoTokenizer,
                          AutoConfig
                          )
from utils.data import (get_data
                        )
from utils.utils import (run_training,
                         run_test
                         )

# Typer application object; CLI commands are registered on it with
# the ``@app.command()`` decorator and dispatched by calling ``app()``.
app = typer.Typer()


def _read_json(path):
    """Return the parsed JSON content of ``path``; the file is closed on exit."""
    with Path(path).open('r') as handle:
        return json.load(handle)


def _load_or_build_json(path, build):
    """Return the JSON cached at ``path``, or build, persist and return it.

    ``build`` is a zero-argument callable that is only invoked when ``path``
    does not exist. The freshly built object is returned as-is (it is not
    re-read from disk), so non-string keys keep their Python type on the run
    that creates the cache file.
    """
    path = Path(path)
    if path.exists():
        return _read_json(path)
    built = build()
    with path.open('w') as handle:
        json.dump(built, handle)
    return built


def _index_by_split(values_per_split):
    """Give each split's distinct values consecutive ids in sorted order."""
    return {
        split: {value: idx for idx, value in enumerate(sorted(set(values)))}
        for split, values in values_per_split.items()
    }


def _invert_by_split(mapping_per_split):
    """Swap keys and values of every per-split ``value -> id`` mapping."""
    return {
        split: {idx: value for value, idx in mapping.items()}
        for split, mapping in mapping_per_split.items()
    }


@app.command()
def train_model(
        # Passed to AutoConfig.from_pretrained and to BaselineModelMultiply.
        model_name_or_path: str = typer.Option(...),
        # Passed to AutoTokenizer.from_pretrained.
        tokenizer_path: str = typer.Option(...),
        # Folder holding train.csv / dev.csv / test.csv; the *2idx / idx2*
        # JSON caches are read from and written to the same folder.
        data_folder: str = typer.Option(...),
        # Folder holding intent_descriptions.json, actions.json, concepts.json.
        data_meta: str = typer.Option(...),
        batch_size: int = typer.Option(64, help="choose batch_size"),
        max_length: int = typer.Option(64, help="maximum length of sequence of tokens"),
        number_of_epochs: int = typer.Option(10, help="number of epochs to train model"),
        learning_rate: float = typer.Option(2e-5, help="learning_rate for optimizer"),
        device: str = typer.Option('cuda', help="use cuda or not"),
        do_train: bool = typer.Option(True, "--train/--no_train"),
        do_test: bool = typer.Option(True, "--test/--no_test"),
        save_path: str = typer.Option(None, help="where to save/find model tuned checkpoint"),
        save_test_metrics_file: str = typer.Option(None, help="where to save test metrics"),
        logging_file: str = typer.Option(
            'log.txt', help="file to log training metrics"),
        # NOTE(review): the three flags below were referenced by the
        # run_training / run_test calls but never defined. They are assumed to
        # be booleans that default to off -- confirm against utils.utils.
        predict_action: bool = typer.Option(
            False, "--predict_action/--no_predict_action",
            help="forwarded to run_training/run_test as predict_action"),
        predict_concept: bool = typer.Option(
            False, "--predict_concept/--no_predict_concept",
            help="forwarded to run_training/run_test as predict_concept"),
        metric_learning: bool = typer.Option(
            False, "--metric_learning/--no_metric_learning",
            help="forwarded to run_training/run_test as metric_learning"),
):
    """Train and/or test BaselineModelMultiply on an intent dataset.

    The train/dev/test CSV files are loaded from the data folder, labels are
    mapped to integer ids (the mappings are cached as JSON files in the data
    folder), the utterances and the intent metadata are tokenized, and the
    model is then trained and/or evaluated depending on the train/test flags.
    """
    # torch.load(None) cannot succeed, so fail before any expensive work
    # (in particular before a full training run) instead of after it.
    if do_test and save_path is None:
        raise typer.BadParameter(
            "a checkpoint path is required when testing is enabled",
            param_hint="save_path")

    # data
    data = {split: get_data(str(Path(data_folder, split + '.csv')))
            for split in ['train', 'dev', 'test']}
    data['valid'] = data.pop('dev')
    data['infer'] = data.pop('test')

    # Per-intent metadata, each keyed by intent label.
    meta_folder = Path(data_meta)
    intent2descriptions = _read_json(meta_folder / 'intent_descriptions.json')
    intent2actions = _read_json(meta_folder / 'actions.json')
    intent2concepts = _read_json(meta_folder / 'concepts.json')

    # Label vocabularies, one per split. Each is loaded from its JSON cache
    # when the file exists, otherwise built from the data and written out.
    folder = Path(data_folder)
    intent2idx = _load_or_build_json(
        folder / 'intent2idx.json',
        lambda: _index_by_split(
            {split: d['labels'] for split, d in data.items()}))
    action2idx = _load_or_build_json(
        folder / 'action2idx.json',
        lambda: _index_by_split(
            {split: [intent2actions[i] for i in intent2idx[split]]
             for split in data}))
    concept2idx = _load_or_build_json(
        folder / 'concept2idx.json',
        lambda: _index_by_split(
            {split: [intent2concepts[i] for i in intent2idx[split]]
             for split in data}))

    # The inverse mappings are not used below; the calls are kept so that the
    # idx2*.json files are still created next to the forward mappings (and
    # still parsed, as before, when they already exist).
    _load_or_build_json(folder / 'idx2intent.json',
                        lambda: _invert_by_split(intent2idx))
    _load_or_build_json(folder / 'idx2action.json',
                        lambda: _invert_by_split(action2idx))
    _load_or_build_json(folder / 'idx2concept.json',
                        lambda: _invert_by_split(concept2idx))

    # Ids (in the 'infer' vocabulary) of intents that never occur in 'train'.
    # Sorted so the list is identical from run to run.
    new_intent_names = set(intent2idx['infer']) - set(intent2idx['train'])
    new_intent_idx = [intent2idx['infer'][i] for i in sorted(new_intent_names)]

    data_dict = dict()
    for split_name, d in data.items():
        split_intents = sorted(intent2idx[split_name])
        data_dict[split_name] = {
            'text': d['text'],
            'labels': [intent2idx[split_name][i] for i in d['labels']],
            'labels_actions': [action2idx[split_name][intent2actions[i]] for i in d['labels']],
            'labels_concepts': [concept2idx[split_name][intent2concepts[i]] for i in d['labels']],
            'descriptions': [intent2descriptions[i] for i in split_intents],
            'actions': [intent2actions[i] for i in split_intents],
            'concepts': [intent2concepts[i] for i in split_intents],
            'names': split_intents,
        }

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    def preprocess_function_list(examples, length=64):
        """Tokenize ``examples``, padded/truncated to exactly ``length`` tokens."""
        return tokenizer(examples, padding='max_length',
                         max_length=length, truncation=True, return_tensors='pt')

    # Tokenized per-intent metadata, handed to run_training / run_test.
    parameters = dict()
    for split_name, d in data_dict.items():
        parameters[split_name] = dict()
        for name in ['descriptions', 'actions', 'concepts', 'names']:
            parameters[split_name][name] = preprocess_function_list(d[name], 10)
        # Each string starts with the intent itself. The loop variable ``name``
        # above must not be used here: after the loop it is always 'names'.
        parameters[split_name]['descriptions_names'] = preprocess_function_list(
            [' '.join([intent, intent2descriptions[intent]])
             for intent in d['names']],
            20)
        parameters[split_name]['descriptions_names_concepts_actions'] = preprocess_function_list(
            [' '.join([intent,
                       intent2descriptions[intent],
                       intent2actions[intent],
                       intent2concepts[intent]])
             for intent in d['names']],
            40)

    datasets = dict()
    for split_name, d in data_dict.items():
        datasets[split_name] = {
            name: d[name]
            for name in ['labels', 'labels_actions', 'labels_concepts']
        }
        # Honour the --max-length option for the utterance texts.
        datasets[split_name].update(
            preprocess_function_list(d['text'], max_length))

    datasets_ = {i: Dataset.from_dict(j) for i, j in datasets.items()}
    torch_columns = list(datasets_['train'].features.keys())
    for j in datasets_.values():
        j.set_format(type='torch', columns=torch_columns,
                     output_all_columns=True)
    dataloaders = {i: torch.utils.data.DataLoader(
        j, batch_size=batch_size) for i, j in datasets_.items()}

    print('data ready')
    # model
    device = torch.device(device)
    config = AutoConfig.from_pretrained(model_name_or_path)
    model = BaselineModelMultiply(
        model_name_or_path=model_name_or_path, config=config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    print('model ready')

    # train validate test

    if do_train:
        run_training(model, optimizer, parameters, train_dataloader=dataloaders['train'],
                     valid_dataloader=dataloaders['valid'],
                     number_of_epochs=number_of_epochs,
                     logging_file=logging_file, save_path=save_path,
                     new_intents=new_intent_idx,
                     predict_action=predict_action,
                     predict_concept=predict_concept,
                     metric_learning=metric_learning,
                     )
    if do_test:
        # map_location lets a checkpoint saved on one device be restored on
        # the device selected for this run (e.g. CUDA checkpoint on CPU).
        checkpoint = torch.load(save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        run_test(model, parameters, test_dataloader=dataloaders['infer'],
                 logging_file=logging_file, save_test_metrics_file=save_test_metrics_file,
                 new_intents=new_intent_idx,
                 predict_action=predict_action,
                 predict_concept=predict_concept,
                 metric_learning=metric_learning,
                 )


# Run the Typer application only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app()
