from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn import metrics
from sklearn.pipeline import Pipeline
from sklearn.model_selection import PredefinedSplit
from dataloaders.ecthr_dataset import ECtHRDataset
from dataloaders.fscs_dataset import FSCSDataset
from dataloaders.spc_dataset import SPCDataset
import numpy as np
import logging
import os
import argparse
import pandas as pd


def eval_by_group(dataset: pd.DataFrame, attribute: str):
    """Print macro-F1 for every distinct value of ``attribute`` plus a summary line.

    The rows sharing each value of ``dataset[attribute]`` (visited in the order
    returned by ``Series.unique``) are scored by comparing their ``labels`` and
    ``predictions`` columns. Afterwards the mean and standard deviation of the
    per-group scores are printed, all scaled to percentages.

    Args:
        dataset: frame holding the ``attribute``, ``labels`` and ``predictions`` columns.
        attribute: name of the column whose values define the groups.
    """
    separator = '-' * 100
    print(separator)
    print(f'RESULTS PER GROUP IN {attribute.upper()}')
    column = dataset[attribute]
    group_scores = []
    for group_value in column.unique():
        subset = dataset[column == group_value]
        score = metrics.f1_score(subset.labels.tolist(), subset.predictions.tolist(), average="macro")
        print(f'Macro-F1 ({group_value}): {score * 100:.1f}')
        group_scores.append(score)
    mean_pct = np.mean(group_scores) * 100
    std_pct = np.std(group_scores) * 100
    print(f'Macro-F1 (GROUP): {mean_pct:.1f} +/- {std_pct:.1f}')
    print(separator)


def _add_zero_class(labels):
    """Append a "no label" indicator column to a multi-label target matrix.

    Args:
        labels: equal-length binary label vectors, one per sample (list of
            lists or a 2-D array).

    Returns:
        ``np.int32`` array of shape ``(n_samples, n_labels + 1)``: the original
        labels followed by a column that is 1 exactly where a row sums to zero.
    """
    labels = np.asarray(labels)
    no_label = (labels.sum(axis=1) == 0).astype(np.int32)
    return np.column_stack([labels.astype(np.int32), no_label])


def _load_dataset(name, attributes):
    """Return ``(dataset, stop_words)`` for the dataset called ``name``.

    Raises:
        RuntimeError: if ``name`` is not ``fscs``, ``ecthr`` or ``spc``.
    """
    if name == 'fscs':
        dataset = FSCSDataset(version="1.0", group_by_fields=attributes, root_dir='../data/datasets')
        stop_words = stopwords.words('german') + stopwords.words('french') + stopwords.words('italian')
    elif name == 'ecthr':
        dataset = ECtHRDataset(version="1.0", group_by_fields=attributes, root_dir='../data/datasets')
        stop_words = stopwords.words('english')
    elif name == 'spc':
        dataset = SPCDataset(version="1.0", group_by_fields=attributes, root_dir='../data/datasets')
        stop_words = []
    else:
        raise RuntimeError(f'The dataset `{name}` is not recognised.')
    return dataset, stop_words


def _build_classifier(task_type):
    """Return ``(classifier, param_grid)`` for ``task_type``.

    ``multi_label`` wraps the logistic regression in a one-vs-rest scheme, so
    its ``C`` grid key goes through ``estimator__``. Any other value is treated
    as a single-label task.
    """
    base = LogisticRegression(random_state=12, max_iter=50000, solver='liblinear')
    if task_type == 'multi_label':
        return OneVsRestClassifier(base), {
            'vect__max_features': [40000, 60000],
            'clf__estimator__C': [1, 10, 20],
        }
    return base, {
        'vect__max_features': [40000, 60000],
        'clf__C': [1, 10, 20],
    }


def _evaluate(title, model, frame, multi_label):
    """Print micro/macro-F1 of ``model`` on ``frame`` and return ``(y_true, y_pred)``.

    Predictions are made from ``frame.text`` and compared with ``frame.y``. For
    multi-label tasks both sides get an extra "no label" column first, which
    gives samples without any positive label a class to be scored on.
    """
    print(title)
    y_pred = model.predict(frame.text.tolist())
    y_true = frame.y.tolist()
    if multi_label:
        y_true = _add_zero_class(y_true)
        y_pred = _add_zero_class(y_pred)
    print(f'Micro-F1: {metrics.f1_score(y_true, y_pred, average="micro")*100:.1f}')
    print(f'Macro-F1: {metrics.f1_score(y_true, y_pred, average="macro")*100:.1f}')
    return y_true, y_pred


def main():
    """Run the TF-IDF + logistic-regression baseline end to end.

    The steps are:
    1. Parse the CLI options and configure logging under ``logs/<dataset>/``.
    2. Load the requested dataset.
    3. Grid-search the vocabulary size and ``C`` on a fixed train/validation split.
    4. Fit the selected configuration on the training split only.
    5. Print micro/macro-F1 on validation and test.
    6. Print per-group test macro-F1 for each attribute in ``--attributes``.

    Raises:
        RuntimeError: if ``--dataset`` is not ``fscs``, ``ecthr`` or ``spc``.
    """
    parser = argparse.ArgumentParser()
    # Every option has a default, so none is required on the command line.
    parser.add_argument('--dataset', default='fscs')
    # nargs='+' so `--attributes region language` yields a list, not one string.
    parser.add_argument('--attributes', nargs='+', default=['region', 'language', 'legal_area'])
    parser.add_argument('--task_type', default='multi_class')
    parser.add_argument('--text_limit', type=int, default=-1)
    parser.add_argument('--n_classes', type=int, default=100)
    config = parser.parse_args()

    os.makedirs(f'logs/{config.dataset}', exist_ok=True)
    handlers = [logging.FileHandler(f'logs/{config.dataset}/svm.txt'), logging.StreamHandler()]
    logging.basicConfig(handlers=handlers, level=logging.INFO)
    # NOTE(review): the results below are emitted with print(), so they reach
    # stdout but not the svm.txt log file configured above.

    dataset, stop_words = _load_dataset(config.dataset, config.attributes)

    # Subsets. Use .copy() so the columns added to the test frame later do not
    # write into a slice of dataset.data_df.
    data_df = dataset.data_df
    datasets = {
        split: data_df[data_df['data_type'] == dataset.split_dict[key]].copy()
        for split, key in (('train', 'train'), ('validation', 'val'), ('test', 'test'))
    }

    classifier, parameters = _build_classifier(config.task_type)

    # Init pipeline: word 1-3-gram counts -> TF-IDF -> logistic regression.
    text_clf = Pipeline([('vect', CountVectorizer(stop_words=stop_words,
                                                  ngram_range=(1, 3), min_df=5)),
                         ('tfidf', TfidfTransformer()),
                         ('clf', classifier),
                         ])

    # Fixate the validation split. Rows marked -1 (train) are always used for
    # fitting; rows marked 0 (validation) form the single held-out fold.
    split_index = [-1] * len(datasets['train']) + [0] * len(datasets['validation'])
    val_split = PredefinedSplit(test_fold=split_index)
    # refit=False: the default refit retrains the best configuration on
    # train + validation, which would make the validation scores below
    # measurements on training data.
    gs_clf = GridSearchCV(text_clf, parameters, cv=val_split, n_jobs=16, verbose=4, refit=False)

    # Pre-process inputs and outputs; the concatenation order must match split_index.
    x_train = datasets['train'].text.tolist() + datasets['validation'].text.tolist()
    y_train = datasets['train'].y.tolist() + datasets['validation'].y.tolist()

    # Hyper-parameter search
    gs_clf.fit(x_train, y_train)

    print('Best Parameters:')
    for param_name in sorted(parameters):
        print(f'{param_name}: {gs_clf.best_params_[param_name]!r}')

    # Fit the selected configuration on the training split only. GridSearchCV
    # fits clones, so text_clf itself is still unfitted here.
    model = text_clf.set_params(**gs_clf.best_params_)
    model.fit(datasets['train'].text.tolist(), datasets['train'].y.tolist())

    multi_label = config.task_type == 'multi_label'
    _evaluate('VALIDATION RESULTS:', model, datasets['validation'], multi_label)
    y_true, y_pred = _evaluate('TEST RESULTS:', model, datasets['test'], multi_label)

    test_df = datasets['test']
    test_df['labels'] = list(y_true)
    test_df['predictions'] = list(y_pred)
    for attribute in config.attributes:
        eval_by_group(test_df, attribute)


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()