import os
import time
from datetime import datetime, timedelta
from typing import List

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report
from collections import defaultdict

from bert_text_classifier.preprocess import preprocess
from bert_text_classifier.dataset import load_coda_darkweb_texts

# Fixed seed so that NumPy and PyTorch random number generation is repeatable.
# The same value is passed as `random_state` to the sklearn splitters in this file.
RANDOM_SEED = 50
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Index of the CUDA device to use when `torch.cuda.is_available()` is true.
CUDA_DEVICE_NO = 0
# Directory that trained model state dicts are saved to and loaded from.
MODEL_SAVE_PATH = 'bert_text_classifier/saved_models/'


class TextClassifier:
    """Fine-tune, evaluate and apply a `BertTextClassifierModel`.

    Training flow: construct, call `config_training`, then `train`.
    Inference flow: construct, call `load_trained_model`, then `predict`.
    """

    def __init__(self, class_names, device, bert_model_name, bert_tokenizer,
                 max_seq_len, preproc_type, experiment_name='textcat'):
        """Create the model on `device` and store tokenizer/preprocessing settings.

        Args:
            class_names: sequence of class names; its length is the number of output classes.
            device: torch device (or device string) the model and batches are moved to.
            bert_model_name: name passed to `BertTextClassifierModel` (and used in file names).
            bert_tokenizer: tokenizer handed to `TextClassificationTokenizedDataset` in `predict`.
            max_seq_len: maximum token length used when tokenizing in `predict`.
            preproc_type: preprocessing variant applied by `predict` (also used in file names).
            experiment_name: prefix of saved model file names.
        """
        self.preproc_type = preproc_type
        self.experiment_name = experiment_name

        model = BertTextClassifierModel(len(class_names), bert_model_name)
        model = model.to(device)

        self.tokenizer = bert_tokenizer
        self.max_seq_len = max_seq_len

        self.model = model
        self.class_names = class_names
        self.device = device
        self.bert_model_name = bert_model_name

        # Filled in by config_training()
        self.scheduler = None
        self.epochs = None
        self.optimizer = None
        self.loss_fn = None

    def load_trained_model(self, bert_model_name, model_path):
        """Replace `self.model` with a fresh model loaded from the state dict at `model_path`."""
        model = BertTextClassifierModel(len(self.class_names), bert_model_name)
        # map_location remaps the stored tensors onto this classifier's device, so a
        # checkpoint written on a GPU can also be restored on a CPU-only machine.
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        self.model = model.to(self.device)

    def config_training(self, num_training_examples, epochs, learning_rate):
        """Create the optimizer, loss function and linear LR schedule used by `train`.

        Args:
            num_training_examples: multiplied by `epochs` to obtain the scheduler's
                `num_training_steps`.
                NOTE(review): `train_epoch` steps the scheduler once per batch, so this
                presumably should be the number of batches per epoch - confirm with callers.
            epochs: number of epochs `train` will run.
            learning_rate: initial learning rate for AdamW.
        """
        total_steps = num_training_examples * epochs
        self.epochs = epochs
        self.optimizer = AdamW(self.model.parameters(), lr=learning_rate, correct_bias=False)
        self.loss_fn = nn.CrossEntropyLoss().to(self.device)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=0,
            num_training_steps=total_steps
        )

    @staticmethod
    def _model_file_path(filename_stem, model_suffix):
        """Return the `.bin` path under MODEL_SAVE_PATH, creating the directory if needed."""
        if model_suffix:
            # Append the suffix to the stem directly instead of str.replace('.bin', ...),
            # which would also rewrite a '.bin' occurring elsewhere in the name.
            filename_stem = f'{filename_stem}_{model_suffix}'
        os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
        return os.path.join(MODEL_SAVE_PATH, filename_stem + '.bin')

    def train(self, train_data_loader, val_data_loader=None, save_best_model=True,
              model_suffix=''):
        """Run the configured number of epochs; requires a prior `config_training` call.

        Args:
            train_data_loader: loader yielding training batches.
            val_data_loader: optional loader; when given, the model is evaluated on it
                after every epoch and the best-scoring epoch is tracked.
            save_best_model: with a validation loader, save the state dict whenever the
                validation accuracy improves.
            model_suffix: optional suffix inserted before `.bin` in saved file names.

        Returns:
            With a validation loader:
                `(best_val_accuracy: float, history, best_model_file_path, best_epochs)`;
                the path is `''` and `best_epochs` is 0 if accuracy never exceeded 0.
            Without one: `(history, model_file_path)` after saving the final model.
            `history` maps 'train_acc'/'train_loss' (and 'val_acc'/'val_loss' when
            validating) to per-epoch lists.
        """
        experiment_name = self.experiment_name
        epochs = self.epochs
        model = self.model
        bert_model_name = self.bert_model_name

        history = defaultdict(list)
        # Kept as a plain float so the final return works even if no epoch improves on it.
        best_accuracy = 0.0
        best_epochs = 0
        best_model_file_path = ''

        start_training_time = time.time()

        for epoch in range(epochs):
            start_epoch_time = time.time()

            print(f'Epoch {epoch + 1}/{epochs}')
            print('-' * 10)
            train_acc, train_loss = self.train_epoch(train_data_loader)
            print(f'Train loss {train_loss:.4f} accuracy {train_acc:.2%}')

            # Training metrics are recorded for every epoch, with or without validation.
            history['train_acc'].append(train_acc)
            history['train_loss'].append(train_loss)

            # --------------------------------------------------------
            #     Model selection on the validation set (per epoch)
            # --------------------------------------------------------
            if val_data_loader:
                val_acc, val_loss = self.eval_model(val_data_loader)

                print(f'  Val loss {val_loss:.4f} accuracy {val_acc:.2%}')

                history['val_acc'].append(val_acc)
                history['val_loss'].append(val_loss)

                val_acc_value = float(val_acc)

                # New best validation result: optionally save this model to a file
                if val_acc_value > best_accuracy:
                    print(f'Higher than the best val accuracy so far ({best_accuracy:.2%})')
                    if save_best_model:
                        filename_stem = (f'{experiment_name}_model_state_{bert_model_name}'
                                         f'_{self.preproc_type}_len{self.max_seq_len}'
                                         f'_{val_acc_value:.2%}')
                        best_model_file_path = self._model_file_path(filename_stem, model_suffix)
                        print(f'Saving the best model ({val_acc_value:.2%}): {best_model_file_path}')
                        torch.save(model.state_dict(), best_model_file_path)

                    best_accuracy = val_acc_value
                    best_epochs = epoch + 1

            print(f'Elapsed epoch time: {time.time() - start_epoch_time:.1f} sec')
            print()

        training_sec = time.time() - start_training_time
        # Drop the fractional seconds from the H:MM:SS representation
        times = str(timedelta(seconds=training_sec)).split(".")[0]

        num_training_examples = len(train_data_loader.dataset)

        print(f'Elapsed time for training loop: {training_sec:.1f} sec ({times})')
        print('Batch size:', train_data_loader.batch_size)
        print('BERT model:', bert_model_name)
        print('max_seq_len:', self.max_seq_len)
        print('# training examples:', num_training_examples)
        print('Total # epochs:', epochs)

        if val_data_loader:
            print('Best # epochs:', best_epochs)
            return best_accuracy, history, best_model_file_path, best_epochs

        # =======================================================
        #     Save the trained model (no validation-based selection)
        # =======================================================
        filename_stem = (f'{experiment_name}_model_state_{bert_model_name}_{self.preproc_type}'
                         f'_len{self.max_seq_len}_ep{epochs}_trained_on_{num_training_examples}')
        model_file_path = self._model_file_path(filename_stem, model_suffix)
        print(f'Saving the trained model after {epochs} epochs: {model_file_path}')
        torch.save(model.state_dict(), model_file_path)

        return history, model_file_path

    def train_epoch(self, data_loader):
        """Run one optimisation pass over `data_loader`.

        Each batch must provide "input_ids", "attention_mask" and "targets".

        Returns:
            `(accuracy, mean_loss)`: accuracy is a 0-dim double tensor (correct predictions
            divided by `len(data_loader.dataset)`), mean_loss the average batch loss.
        """
        model = self.model.train()  # put the model into training mode
        loss_fn = self.loss_fn
        optimizer = self.optimizer
        device = self.device
        scheduler = self.scheduler

        losses = []
        # Tensor counter (not int 0) so `.double()` below also works for an empty loader.
        num_correct_predictions = torch.zeros((), dtype=torch.long, device=device)

        num_train_examples = len(data_loader.dataset)

        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            targets = batch["targets"].to(device)

            # forward computation
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            # argmax over the class dimension (dim=1) gives the predicted class index
            preds = torch.max(logits, dim=1).indices
            num_correct_predictions += torch.sum(preds == targets)

            loss = loss_fn(logits, targets)  # cross-entropy loss
            losses.append(loss.item())
            loss.backward()  # backpropagation

            # ------- Parameter update (fine-tuning) ------
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()  # the schedule advances once per batch
            optimizer.zero_grad()

        ratio_correct_predictions = num_correct_predictions.double() / num_train_examples
        mean_loss = np.mean(losses) if losses else float('nan')
        return ratio_correct_predictions, mean_loss

    def eval_model(self, data_loader, get_predictions=False):
        """Evaluate the model on `data_loader` without updating it.

        Each batch must provide "input_ids", "attention_mask" and "targets"
        (plus "raw_text" when `get_predictions` is true).

        Returns:
            `(accuracy, mean_loss)`, or with `get_predictions=True`
            `(accuracy, mean_loss, raw_texts, predictions, prediction_probs, real_values)`
            where the last three are CPU tensors with one row per example.
        """
        model = self.model.eval()
        loss_fn = self.loss_fn
        device = self.device

        losses = []
        # Tensor counter (not int 0) so `.double()` below also works for an empty loader.
        correct_predictions = torch.zeros((), dtype=torch.long, device=device)

        raw_texts = []
        prediction_batches = []
        prediction_prob_batches = []
        real_value_batches = []

        num_examples_used = 0
        num_examples = len(data_loader.dataset)

        # No gradients are needed for evaluation
        with torch.no_grad():
            for d in data_loader:
                input_ids = d["input_ids"].to(device)
                attention_mask = d["attention_mask"].to(device)
                targets = d["targets"].to(device)
                logits = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                preds = torch.max(logits, dim=1).indices
                loss = loss_fn(logits, targets)
                correct_predictions += torch.sum(preds == targets)
                losses.append(loss.item())

                if get_predictions:
                    raw_texts.extend(d["raw_text"])
                    # Move each batch to the CPU right away so results do not pile up
                    # on the model device.
                    prediction_batches.append(preds.cpu())
                    prediction_prob_batches.append(F.softmax(logits, dim=1).cpu())
                    real_value_batches.append(targets.cpu())

                num_examples_used += len(input_ids)

        # Sanity check: every example of the dataset was seen exactly once
        assert num_examples_used == num_examples

        accuracy = correct_predictions.double() / num_examples
        mean_loss = np.mean(losses) if losses else float('nan')

        if get_predictions:
            # Convert the list of per-batch tensors into a single tensor each
            predictions = torch.cat(prediction_batches)
            prediction_probs = torch.cat(prediction_prob_batches)
            real_values = torch.cat(real_value_batches)

            return accuracy, mean_loss, raw_texts, predictions, prediction_probs, real_values

        return accuracy, mean_loss

    def predict(self, texts: List[str]):
        """Lazily classify `texts` one at a time.

        The texts are preprocessed according to `self.preproc_type`
        ('all_id_removed', 'all_id_masked_preprocessed' or '' for no preprocessing).

        Yields:
            `(prediction_prob, prediction)` per text: the softmax probability (float) of
            the predicted class and the predicted class index as a 1-element tensor.

        Raises:
            ValueError: on the first iteration, if `self.preproc_type` is not supported.
        """
        device = self.device
        model = self.model.eval()
        preproc_type = self.preproc_type

        if preproc_type == 'all_id_removed':
            preprocessed_texts = [preprocess(text, remove_identifiers=True, spacy_preproc=False) for text in texts]
        elif preproc_type == 'all_id_masked_preprocessed':
            preprocessed_texts = [preprocess(text) for text in texts]
        elif preproc_type == '':
            preprocessed_texts = texts
        else:
            raise ValueError('Invalid preprocessing type: ' + preproc_type)

        dataset = TextClassificationTokenizedDataset(
            raw_texts=preprocessed_texts,
            tokenizer=self.tokenizer,
            max_len=self.max_seq_len
        )

        data_loader = DataLoader(
            dataset=dataset,
            batch_size=1,  # one text per step, so index [0] below is the only row
            num_workers=0  # zero => load data in the main process
        )

        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Inference only, so no autograd graph is built. The no_grad block is closed
            # before yielding so the disabled-grad mode never leaks into the consumer.
            with torch.no_grad():
                logits = model(input_ids, attention_mask)
                prediction = torch.max(logits, dim=1).indices
                class_probs = F.softmax(logits, dim=1)[0].cpu()

            # Index with a CPU copy: the probabilities live on the CPU while
            # `prediction` is still on the model device.
            prediction_prob = class_probs[prediction.cpu()].item()

            yield prediction_prob, prediction

# Selects the classification head built by BertTextClassifierModel:
#   True  -> dropout followed by a single linear layer on BERT's pooled output
#   False -> a small feed-forward network (Linear -> ReLU -> Linear) on the last
#            hidden state of the first ([CLS]) token
USE_ONLY_LINEAR_ON_CLASSIFICATION_LAYER = True
#USE_ONLY_LINEAR_ON_CLASSIFICATION_LAYER = False

class BertTextClassifierModel(nn.Module):
    """Pretrained BERT encoder topped with a classification head.

    The head variant is chosen by the module-level flag
    USE_ONLY_LINEAR_ON_CLASSIFICATION_LAYER. `forward` returns raw logits
    (no softmax is applied).
    """

    def __init__(self, n_classes, bert_model_name, freeze_bert=False):
        """Load `bert_model_name` and build a head with `n_classes` outputs.

        When `freeze_bert` is true, the BERT parameters are excluded from training.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        hidden_size = self.bert.config.hidden_size

        if not USE_ONLY_LINEAR_ON_CLASSIFICATION_LAYER:
            # Feed-forward classifier with one hidden layer of 50 units
            hidden_units = 50
            self.classifier = nn.Sequential(
                nn.Linear(hidden_size, hidden_units),
                nn.ReLU(),
                nn.Linear(hidden_units, n_classes),
            )
        else:
            # Head applied to BERT's "pooled output"
            self.drop = nn.Dropout(p=0.3)
            self.out = nn.Linear(hidden_size, n_classes)  # dense / fully-connected layer
            self.softmax = nn.Softmax(dim=1)  # not used by forward()

        if freeze_bert:
            for bert_param in self.bert.parameters():
                bert_param.requires_grad_(False)

    def forward(self, input_ids, attention_mask):
        """Return the class logits for a batch of token ids and attention masks."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        if not USE_ONLY_LINEAR_ON_CLASSIFICATION_LAYER:
            # Last hidden state of the first token ([CLS]) feeds the classifier
            cls_hidden_state = encoded[0][:, 0, :]
            return self.classifier(cls_hidden_state)

        # [NOTE] HuggingFace https://huggingface.co/transformers/main_classes/output.html
        #   pooler_output: last-layer hidden state of the first token of the sequence
        #   (classification token), further processed by a Linear layer and a Tanh
        #   activation. The Linear layer weights are trained from the next sentence
        #   prediction (classification) objective during pretraining.
        pooled = self.drop(encoded.pooler_output)  # dropout for regularization
        return self.out(pooled)  # hidden state => one logit per class

# Subclasses PyTorch's Dataset so instances can be passed straight to a DataLoader
class TextClassificationTokenizedDataset(Dataset):
    """Map-style dataset that tokenizes one raw text per item on access."""

    def __init__(self, raw_texts,
                 tokenizer,
                 max_len: int,
                 targets=None):
        """Store the texts, tokenizer settings and optional labels.

        Args:
            raw_texts: indexable collection of texts (array, list, tuple, ...).
            tokenizer: object exposing `encode_plus(...)` that returns a mapping with
                'input_ids' and 'attention_mask' tensors.
            max_len: length every text is padded/truncated to.
            targets: optional integer labels aligned with `raw_texts`. Any array-like
                is accepted; omit (or pass an empty array) for unlabeled data.
        """
        self.raw_texts = raw_texts
        self.tokenizer = tokenizer
        self.max_len = max_len
        # `None` replaces the former shared ndarray default; np.asarray lets plain
        # lists be passed too (they have no `.size` attribute).
        self.targets = np.empty(shape=(0, 0)) if targets is None else np.asarray(targets)

    def __len__(self):
        """Return the number of texts."""
        return len(self.raw_texts)

    def __getitem__(self, item):
        """Return the tokenized example at position `item`.

        The result holds 'raw_text' (str), 'input_ids' and 'attention_mask'
        (flattened tensors) and, when labels were supplied, 'targets' (long tensor).
        """
        raw_text = str(self.raw_texts[item])
        encoding = self.tokenizer.encode_plus(
            raw_text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length', truncation=True,  # replaces deprecated pad_to_max_length
            return_attention_mask=True,
            return_tensors='pt',
        )

        data = {
            'raw_text': raw_text,  # str
            'input_ids': encoding['input_ids'].flatten(),  # torch.Tensor
            'attention_mask': encoding['attention_mask'].flatten(),  # torch.Tensor
        }

        if self.targets.size > 0:
            data['targets'] = torch.tensor(self.targets[item], dtype=torch.long)

        return data


def split_dataset(X, y, split_ratio: str):
    """Split `X`/`y` into stratified train/validation/test parts.

    Args:
        X: inputs; must expose `.shape[0]`.
        y: labels aligned with `X`, used for stratification.
        split_ratio: "train:val:test" proportions, e.g. "8:1:1" or "9:0.5:0.5".
            A zero val and/or test share yields `None` for that part.

    Returns:
        `(X_train, X_val, X_test, y_train, y_val, y_test)`.

    Raises:
        ValueError: if `split_ratio` does not have three ':'-separated numbers, or
            the numbers are negative or all zero.
    """
    if split_ratio.count(':') != 2:
        raise ValueError(f'Invalid notation for split ratio: "{split_ratio}". (Needs to be like "8:1:1")')

    # e.g.) "9:0.5:0.5" => [9.0, 0.5, 0.5]
    split_ratios = [float(part) for part in split_ratio.split(':')]
    # e.g.) [9.0, 0.5, 0.5] => 10.0
    total = sum(split_ratios)

    # Reject ratios that would otherwise end in a ZeroDivisionError or a nonsensical split
    if total <= 0 or any(ratio < 0 for ratio in split_ratios):
        raise ValueError(f'Invalid split ratio: "{split_ratio}". '
                         f'(Ratios must be non-negative and not all zero)')

    # e.g.) 0.5/10.0 => 0.05
    val_ratio = split_ratios[1] / total
    test_ratio = split_ratios[2] / total
    # e.g.) 0.05+0.05 => 0.1  (val + test = 10%)
    non_train_ratio = val_ratio + test_ratio

    print('@Total examples:', X.shape[0])
    print('@Non-train ratio:', non_train_ratio)

    # Use all as training data ...
    if non_train_ratio == 0.0:
        print('   ==> Use all data as training data!')
        return X, None, None, y, None, None

    # e.g.) 0.05/(0.05+0.05) => 0.5 (50% of val+test = test)
    test_to_val_ratio = test_ratio / non_train_ratio

    print('@test/non-train ratio:', test_to_val_ratio)

    # Split the whole data into train & non-train
    X_train, X_non_train, y_train, y_non_train = train_test_split(
        X, y,
        test_size=non_train_ratio,
        random_state=RANDOM_SEED,  # shuffle!
        stratify=y,
    )

    if val_ratio == 0.0 or test_to_val_ratio >= 1.0:
        # No validation data... Use all non-train part as test data
        X_val, y_val = None, None
        X_test, y_test = X_non_train, y_non_train
    elif test_ratio == 0.0:
        # No test data... Use all non-train part as validation data.
        # (train_test_split rejects test_size=0.0, so this case is handled here.)
        X_val, y_val = X_non_train, y_non_train
        X_test, y_test = None, None
    else:
        # split val+test into val & test (by specifying relative test ratio)
        X_val, X_test, y_val, y_test = train_test_split(
            X_non_train, y_non_train,
            test_size=test_to_val_ratio,
            random_state=RANDOM_SEED,  # shuffle!
            stratify=y_non_train
        )

    return X_train, X_val, X_test, y_train, y_val, y_test


def create_loader(raw_texts, targets, bert_tokenizer, max_seq_len, batch_size,
                  num_workers=4, shuffle=False):
    """Wrap texts and labels in a tokenizing dataset and return a DataLoader over it.

    Args:
        raw_texts: texts exposing `.to_numpy()` (e.g. a pandas Series).
        targets: labels exposing `.to_numpy()`, aligned with `raw_texts`.
        bert_tokenizer: tokenizer used by `TextClassificationTokenizedDataset`.
        max_seq_len: length every text is padded/truncated to.
        batch_size: number of examples per batch.
        num_workers: number of DataLoader worker processes (0 loads in the main
            process). Defaults to the previously hard-coded value 4.
        shuffle: reshuffle the examples on every pass over the loader. Defaults to
            False, i.e. the data is served in the given order.
    """
    dataset = TextClassificationTokenizedDataset(
        raw_texts=raw_texts.to_numpy(),
        targets=targets.to_numpy(),
        tokenizer=bert_tokenizer,
        max_len=max_seq_len
    )

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers
    )


def print_data_split_info(data_type, num_examples, raw_texts, labels, extra='', class_names=None):
    """Print a data split: its texts, its labels and the label distribution.

    Labels are listed from most to least frequent (ties: larger label first).
    When `class_names` is given, each label is used as an index into it for display.
    `num_examples` must equal the number of labels.
    """
    banner = f'\n======= Dataframe <{data_type}> | # texts: {num_examples}'
    if extra:
        banner += f' | {extra}'
    print(banner + ' ==============')

    for caption, content in (('------- input texts -------', raw_texts),
                             ('--------- labels ----------', labels)):
        print(caption)
        print(content)

    distinct_labels, frequencies = np.unique(labels, return_counts=True)
    total_count = np.sum(frequencies)
    assert total_count == num_examples

    # lexsort orders by the last key first: ascending count, then ascending label.
    # Reversing yields the most frequent labels first.
    descending_order = np.lexsort((distinct_labels, frequencies))[::-1]

    print('----- Label distribution (Ratio) -----')

    for position in descending_order:
        frequency = frequencies[position]
        shown_label = distinct_labels[position]
        if class_names:
            shown_label = class_names[shown_label]
        print(f'    {shown_label}: {frequency} ({frequency/total_count:.1%})')

    print()


def iter_data_loaders_for_each_split(dataframe, input_text_col_name, output_label_col_name,
                                     class_names, bert_tokenizer, max_seq_len, batch_size, split_ratio,
                                     k_fold=0, early_stopping=True):
    """Yield `(train, val, test)` data loaders for every data split.

    With `k_fold > 0`, one triple per cross-validation fold is produced (`split_ratio`
    is not used); otherwise a single triple based on `split_ratio` is produced
    (`early_stopping` is not used). The shape and keys of the first batch of each
    available loader are printed before the triple is yielded.
    """
    if k_fold > 0:
        assert isinstance(k_fold, int)
        loader_triples = create_cross_validate_loaders(
            dataframe, input_text_col_name, output_label_col_name,
            bert_tokenizer, max_seq_len, batch_size, k_fold,
            early_stopping=early_stopping, class_names=class_names)
    else:
        single_triple = create_split_loaders(
            dataframe, input_text_col_name, output_label_col_name,
            bert_tokenizer, max_seq_len, batch_size, split_ratio,
            class_names=class_names)
        loader_triples = [single_triple]

    for train_data_loader, val_data_loader, test_data_loader in loader_triples:
        first_train_batch = next(iter(train_data_loader))

        print('\nEach batch:')
        print('  * train:', first_train_batch['input_ids'].shape)
        print('   - keys:', first_train_batch.keys())

        # (label, loader, whether a blank line follows the report)
        optional_loaders = (
            ('  *   val:', val_data_loader, False),
            ('  *  test:', test_data_loader, True),
        )
        for label, loader, blank_line_after in optional_loaders:
            if not loader:
                continue
            first_batch = next(iter(loader))
            print(label, first_batch['input_ids'].shape)
            print('   - keys:', first_batch.keys())
            if blank_line_after:
                print()

        yield train_data_loader, val_data_loader, test_data_loader


def create_cross_validate_loaders(dataframe, input_text_col_name, output_label_col_name,
                                  bert_tokenizer, max_seq_len, batch_size, k_fold,
                                  early_stopping=True, class_names=None):
    """Yield `(train, val, test)` data loaders for each of `k_fold` stratified folds.

    Each fold's held-out part becomes the test set. With `early_stopping`, the
    remaining part is split again (first fold of another stratified `k_fold` split)
    into training and validation sets; otherwise the validation loader is `None`.
    """
    total_num_examples = dataframe.shape[0]
    all_texts = dataframe[input_text_col_name]
    all_targets = dataframe[output_label_col_name]

    fold_splitter = StratifiedKFold(n_splits=k_fold, shuffle=True, random_state=RANDOM_SEED)
    separator = '---------------' * 4

    def disjoint(first_index, second_index):
        # True when the two pandas indices share no label
        return len(np.intersect1d(first_index, second_index, assume_unique=True)) == 0

    for fold_idx, (fold_train_pos, fold_test_pos) in enumerate(fold_splitter.split(all_texts, all_targets)):
        fold_label = f'FOLD: {fold_idx+1}/{k_fold}'

        print(separator)
        print(f'      {fold_label}  (Total {total_num_examples} examples)')
        print(separator)

        train_raw_texts = all_texts.iloc[fold_train_pos]
        train_targets = all_targets.iloc[fold_train_pos]
        test_raw_texts = all_texts.iloc[fold_test_pos]
        test_targets = all_targets.iloc[fold_test_pos]
        val_raw_texts, val_targets = None, None

        # ------------------------------------------------------
        #   Split training set into training & validation sets
        # ------------------------------------------------------
        if early_stopping:
            val_splitter = StratifiedKFold(n_splits=k_fold, shuffle=True, random_state=RANDOM_SEED)
            kept_pos, val_pos = next(val_splitter.split(train_raw_texts, train_targets))

            # Take the validation rows out first, then shrink the training part
            val_raw_texts = train_raw_texts.iloc[val_pos]
            val_targets = train_targets.iloc[val_pos]
            train_raw_texts = train_raw_texts.iloc[kept_pos]
            train_targets = train_targets.iloc[kept_pos]

        # The parts must cover all examples and must not overlap
        val_size = len(val_raw_texts.index) if early_stopping else 0
        assert len(train_raw_texts.index) + val_size + len(test_raw_texts.index) == total_num_examples
        assert disjoint(train_raw_texts.index, test_raw_texts.index)
        if early_stopping:
            assert disjoint(train_raw_texts.index, val_raw_texts.index)
            assert disjoint(val_raw_texts.index, test_raw_texts.index)

        named_parts = [('Train', train_raw_texts, train_targets)]
        if early_stopping:
            named_parts.append(('Val', val_raw_texts, val_targets))
        named_parts.append(('Test', test_raw_texts, test_targets))

        for part_name, part_texts, part_targets in named_parts:
            print_data_split_info(part_name, part_texts.shape[0], part_texts, part_targets,
                                  extra=fold_label, class_names=class_names)

        train_data_loader = create_loader(train_raw_texts, train_targets, bert_tokenizer, max_seq_len, batch_size)
        test_data_loader = create_loader(test_raw_texts, test_targets, bert_tokenizer, max_seq_len, batch_size)
        val_data_loader = (
            create_loader(val_raw_texts, val_targets, bert_tokenizer, max_seq_len, batch_size)
            if early_stopping else None
        )

        yield train_data_loader, val_data_loader, test_data_loader


def create_split_loaders(dataframe, input_text_col_name, output_label_col_name,
                         bert_tokenizer, max_seq_len, batch_size, split_ratio,
                         class_names=None):
    """Split `dataframe` by `split_ratio` and build a data loader for each part.

    Args:
        dataframe: data holding the text and label columns.
        input_text_col_name: name of the text column.
        output_label_col_name: name of the label column.
        bert_tokenizer, max_seq_len, batch_size: forwarded to `create_loader`.
        split_ratio: "train:val:test" notation understood by `split_dataset`.
        class_names: optional class names used when printing label distributions.

    Returns:
        `(train_data_loader, val_data_loader, test_data_loader)`; the validation and
        test loaders are `None` when the split produced no such part.
    """
    raw_texts = dataframe[input_text_col_name]
    targets = dataframe[output_label_col_name]

    train_raw_texts, val_raw_texts, test_raw_texts, \
        train_targets, val_targets, test_targets = split_dataset(raw_texts, targets, split_ratio)

    train_data_loader = create_loader(train_raw_texts, train_targets, bert_tokenizer, max_seq_len, batch_size)

    print_data_split_info('TRAIN', train_raw_texts.shape[0], train_raw_texts, train_targets, class_names=class_names)

    # Identity checks against None: the truth value of a pandas Series is ambiguous
    # and raises ValueError, so the parts must never be used directly in an `if`.
    val_exists = val_raw_texts is not None
    test_exists = test_raw_texts is not None

    if val_exists:
        print_data_split_info('VAL', val_raw_texts.shape[0], val_raw_texts, val_targets, class_names=class_names)
        val_data_loader = create_loader(val_raw_texts, val_targets, bert_tokenizer, max_seq_len, batch_size)
    else:
        print('\n**** No validation dataset specified ****')
        val_data_loader = None

    if test_exists:
        print_data_split_info('TEST', test_raw_texts.shape[0], test_raw_texts, test_targets, class_names=class_names)
        test_data_loader = create_loader(test_raw_texts, test_targets, bert_tokenizer, max_seq_len, batch_size)
    else:
        print('\n**** No test dataset specified ****')
        test_data_loader = None

    # Sanity checks: no overlap between train/val/test datasets (indices) ...
    if test_exists:
        assert len(np.intersect1d(train_raw_texts.index, test_raw_texts.index, assume_unique=True)) == 0
    if val_exists:
        assert len(np.intersect1d(train_raw_texts.index, val_raw_texts.index, assume_unique=True)) == 0
    if val_exists and test_exists:
        assert len(np.intersect1d(val_raw_texts.index, test_raw_texts.index, assume_unique=True)) == 0

    # ... and the existing parts together cover the whole dataframe
    num_split_examples = train_raw_texts.shape[0]
    if val_exists:
        num_split_examples += val_raw_texts.shape[0]
    if test_exists:
        num_split_examples += test_raw_texts.shape[0]
    assert num_split_examples == dataframe.shape[0]

    return train_data_loader, val_data_loader, test_data_loader


def demo_predict():
    """Classify a benchmark dataset with a trained model and print the results.

    Loads the tokenizer and the trained classifier, predicts a category for every
    benchmark text, then prints each prediction, the accuracy per
    "<category>-<forum name>" group and an overall classification report.
    The model and data paths are placeholders to be edited before running.
    """
    trained_model_path = 'path/to/trained_model'
    input_data_path = 'path/to/benchmark_dataset'

    coda_class_names = ['Arms', 'Crypto', 'Drugs', 'Electronic', 'Financial', 'Gambling',
                        'Hacking', 'Others', 'Porn', 'Violence']

    # ===================================================
    #      DO NOT CHANGE THE ORDER OF CLASS NAMES
    # ===================================================
    class_names = coda_class_names

    # =================== Configuration ============================
    # NOTE(review): TextClassifier.predict in this file only accepts 'all_id_removed',
    # 'all_id_masked_preprocessed' or '' and raises ValueError otherwise - confirm
    # which preprocessing the trained model expects and adjust this value.
    preproc_type = 'norm_removed'
    device = torch.device(f'cuda:{CUDA_DEVICE_NO}' if torch.cuda.is_available() else "cpu")
    bert_model_name = 'bert-base-uncased'
    max_seq_len = 256
    print('Class names:', ' '.join(class_names))
    # ==============================================================

    # ================ Loading the classifier ======================
    print('Loading the tokenizer ...')
    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    print('Loading the classifier ...')
    bert_classifier = TextClassifier(class_names, device, bert_model_name, bert_tokenizer, max_seq_len, preproc_type)
    print('Loading the model ...\n  => ' + trained_model_path)
    bert_model_path = os.path.join(MODEL_SAVE_PATH, trained_model_path)
    bert_classifier.load_trained_model(bert_model_name, bert_model_path)
    # ==============================================================

    from bert_text_classifier.dataset import load_darkweb_extra_benchmark_data
    from bert_text_classifier.util_text import compress_whitespace_in_text

    print('Input data:', input_data_path)

    forum_data = [(text, category, forum_name) for text, category, forum_name
                  in load_darkweb_extra_benchmark_data(input_data_path)]

    # zip(*[]) cannot be unpacked into three names, so stop early on empty data
    if not forum_data:
        print('No benchmark examples found in:', input_data_path)
        return

    forum_texts, forum_categories, forum_names = zip(*forum_data)

    true_labels = []
    pred_labels = []

    for i, (pred_prob, pred_label) in enumerate(bert_classifier.predict(forum_texts)):
        pred_label = pred_label.item()
        pred_coda_category = class_names[pred_label]

        true_coda_category = forum_categories[i]
        forum_name = forum_names[i]

        true_labels.append(coda_class_names.index(true_coda_category))
        pred_labels.append(pred_label)

        print('-----------------------------------------------')
        print(f'#{i+1} | {forum_name} | {true_coda_category}')
        print(f'{compress_whitespace_in_text(forum_texts[i])[:100]}')
        print(f'   ** {pred_coda_category} | prob={pred_prob:.2f} **')

    # ==================================
    #      Per-website performance
    # ==================================
    # "<category>-<forum name>" => list of 1 (correct) / 0 (wrong) per example
    forum_name_match_dict = defaultdict(list)

    for true_label, pred_label, forum_name, forum_category \
            in zip(true_labels, pred_labels, forum_names, forum_categories):
        forum_name_cat = f'{forum_category}-{forum_name}'
        forum_name_match_dict[forum_name_cat].append(int(pred_label == true_label))

    # Every key holds at least one entry, so the ratio is always defined
    for forum_name_cat in sorted(forum_name_match_dict):
        matches = forum_name_match_dict[forum_name_cat]
        total = len(matches)
        match = sum(matches)
        ratio = match / total
        print(f'{forum_name_cat:15}: {match}/{total} ({ratio:.1%})')

    # ==================================
    #         Total performance
    # ==================================
    report = classification_report(true_labels, pred_labels, target_names=coda_class_names,
                                   labels=list(range(len(coda_class_names))), digits=4)
    print(report)
    print('Model:', trained_model_path)
    print('Test data:', input_data_path)
    print()


def train_bert_classifier(data_path, text_type):
    """Train and evaluate a BERT text classifier on the data at *data_path*.

    The data is loaded with ``load_coda_darkweb_texts`` and handed to
    ``iter_data_loaders_for_each_split``, which yields one
    (train, validation, test) data-loader triple per split.  A fresh
    ``TextClassifier`` is trained for every split.  When a split has a test
    loader, the saved model is reloaded, evaluated on it, and a classification
    report is printed.  For k-fold runs the test predictions of all folds are
    pooled into one final report.

    All hyperparameters (model name, sequence length, batch size, learning
    rate, epochs, fold count, early stopping, split ratio) are hard-coded at
    the top of the function body.

    Args:
        data_path: Dataset location, forwarded to ``load_coda_darkweb_texts``.
        text_type: Text variant to load; also forwarded to ``TextClassifier``
            as its preprocessing type.

    Raises:
        SystemExit: With status 1 if early stopping is enabled without a
            validation loader, or a validation loader is present while early
            stopping is disabled.
    """
    df, input_text_col_name, output_label_col_name, class_names \
        = load_coda_darkweb_texts(data_path=data_path, text_type=text_type)

    device = torch.device(f"cuda:{CUDA_DEVICE_NO}" if torch.cuda.is_available() else "cpu")

    # Experiment settings (edit here to change the run)
    bert_model_name = 'bert-base-uncased'
    max_seq_len = 256
    batch_size = 32
    learning_rate = 2e-5
    epochs = 10
    k_fold = 0  # > 0 selects the k-fold cross-validation branches below
    early_stopping = False
    split_ratio = '7:0:3'  # "train:val:test"

    start_time_str = datetime.now().strftime("%Y/%m/%d %H:%M:%S")

    print('@Time:', start_time_str)
    print('@data_path:', data_path)
    print('@preproc_type:', text_type)
    print('@RANDOM_SEED:', RANDOM_SEED)
    print('@class_names:', class_names)
    print('@bert_model_name:', bert_model_name)
    print('@max_seq_len:', max_seq_len)
    print('@batch_size:', batch_size)
    print('@epochs:', epochs)

    if k_fold > 0:
        print('@Cross-validation:', f'{k_fold}-fold')
    else:
        print('@split_ratio:', split_ratio)

    print('@early_stopping:', early_stopping)

    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)

    # Create training/validation/test data loader triples for each data split
    # (n splits for n-fold cross-validation)
    data_loader_iterator = iter_data_loaders_for_each_split(
        dataframe=df,
        input_text_col_name=input_text_col_name,
        output_label_col_name=output_label_col_name,
        class_names=class_names,
        bert_tokenizer=bert_tokenizer,
        max_seq_len=max_seq_len,
        batch_size=batch_size,
        split_ratio=split_ratio,
        k_fold=k_fold,
        early_stopping=early_stopping,
    )

    # Passing explicit label ids keeps classification_report from raising when
    # a class is absent from a test split.
    # NOTE(review): assumes labels are integer indices into class_names - confirm.
    label_ids = list(range(len(class_names)))

    # Aggregate all test labels across cross-validation splits (true/predictions)
    y_pred_test_all, y_true_test_all = [], []
    test_acc_all = []
    best_epochs_all = []

    start_training_time = time.time()

    for nth_split, data_loaders in enumerate(data_loader_iterator):
        train_data_loader, val_data_loader, test_data_loader = data_loaders
        fold_no = nth_split + 1  # 1-based, used for both display and model file names

        # Early stopping and the presence of a validation set must agree.
        has_val_data = bool(val_data_loader)
        if early_stopping and not has_val_data:
            print('[Error] Specify a validation set to use early stopping.')
            raise SystemExit(1)
        if has_val_data and not early_stopping:
            print('[Error] A validation set was given, but early stopping is disabled.')
            raise SystemExit(1)

        # ==============================================
        #     Initialize a BERT classification model
        # ==============================================
        # NOTE(review): `dataset_name` is not defined inside this function; it must
        # exist as a module-level name or this call raises NameError - confirm.
        classifier = TextClassifier(class_names, device, bert_model_name, bert_tokenizer,
                                    max_seq_len, text_type, experiment_name=dataset_name)
        # The number of batches per epoch is passed as the step count per epoch.
        classifier.config_training(len(train_data_loader), epochs, learning_rate)

        print()
        print('========' * 5)
        if k_fold > 0:
            print(' ' * 14 + f'Experiment: Fold #{fold_no}')
        else:
            print(' ' * 14 + 'Experiment')
        print('========' * 5)

        if early_stopping:
            # Train (with early stopping based on validation set)
            best_val_accuracy, _history, model_path, best_epochs \
                = classifier.train(train_data_loader, val_data_loader,
                                   model_suffix=f'fold{fold_no},{k_fold}')
            best_epochs_all.append(best_epochs)
            print(f'Best validation accuracy: {best_val_accuracy:.2%}')
            print('Loading the best model for evaluation:', model_path)
        else:
            # Train (without early stopping); the fold number in the suffix is
            # 1-based so that it matches the early-stopping branch and the banner.
            model_suffix = f'fold{fold_no},{k_fold}' if k_fold > 0 else ''

            _history, model_path \
                = classifier.train(train_data_loader, model_suffix=model_suffix)
            print('Loading the pretrained model for evaluation:', model_path)

        # =====================================
        #           Test (evaluation)
        # =====================================
        if test_data_loader:
            # Measure accuracy on the held-out test data using the saved model.
            print(f'\nEvaluating on the test data ...')

            classifier.load_trained_model(bert_model_name, model_path)

            test_acc, _mean_loss, _input_texts, y_pred_test, _pred_probs, y_true_test \
                = classifier.eval_model(test_data_loader, get_predictions=True)
            print(f'Test accuracy: {test_acc.item():.2%}  (see below)')
            test_acc_all.append(test_acc.item())

            # ==========================================
            #       Prediction & Confusion matrix
            # ==========================================
            print()
            print(classification_report(y_true_test, y_pred_test, target_names=class_names,
                                        labels=label_ids, digits=4))

            # Aggregate all test results obtained from each split
            if k_fold > 0:
                y_pred_test_all.extend(y_pred_test)
                y_true_test_all.extend(y_true_test)

    total_training_time = time.time() - start_training_time
    # "H:MM:SS" form of the elapsed time (fractional seconds dropped)
    elapsed_hms = str(timedelta(seconds=total_training_time)).split(".")[0]

    if k_fold > 0:
        print(f'=============================================')
        print(f'  Result of {k_fold}-fold cross-validation')
        print(f'=============================================')
        print('Execution time:', start_time_str)
        print(f'Total time for cross-validation: {total_training_time:.1f} seconds ({elapsed_hms})')
        if early_stopping:
            print('Best epochs:', best_epochs_all)
        print()
        for i, test_acc in enumerate(test_acc_all):
            print(f'Test accuracy for split {i+1}: {test_acc:.2%}')
        print()
        print('Micro-averaged results:')
        print(classification_report(y_true_test_all, y_pred_test_all, target_names=class_names,
                                    labels=label_ids, digits=4))


def demo_train():
    """Run ``train_bert_classifier`` with the hard-coded demo path and text type."""
    # Alternative text types that were previously tried here:
    #   'min_id_masked', 'all_id_masked'
    train_bert_classifier(
        data_path='path/to/coda',
        text_type='all_id_masked_preprocessed',
    )


# Script entry point: runs the training demo when this file is executed directly.
if __name__ == '__main__':
    #demo_predict()
    demo_train()

