import logging
import random

import numpy as np
from nltk.corpus import stopwords
from sklearn import metrics
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

from dataloaders import get_dataset

# All results of this script are written to a log file rather than stdout.
logging.basicConfig(filename='linear_logs_v3.txt', level=logging.INFO)

# Experiment configuration.
DATASET = 'fscs'
GROUP_FIELD = 'language'  # column of dataset.data_df used to split examples into groups
N_GROUPS = 3              # group ids are taken to be 0 .. N_GROUPS - 1

# Group by the same attribute that the per-group evaluation filters on, so that
# changing GROUP_FIELD keeps loading and evaluation in sync.
dataset = get_dataset(DATASET, group_by_fields=[GROUP_FIELD], root_dir='data/datasets')


def _split_frame(split_name):
    """Return the rows of ``dataset.data_df`` that belong to the named split.

    ``split_name`` is a key of ``dataset.split_dict`` ('train', 'val' or 'test');
    the mapped value is matched against the ``data_type`` column.
    """
    return dataset.data_df[dataset.data_df['data_type'] == dataset.split_dict[split_name]]


def _texts_and_labels(frame):
    """Return the ``text`` and ``label`` columns of ``frame`` as two plain lists."""
    return frame['text'].tolist(), frame['label'].tolist()


train_dataset = _split_frame('train')
train_x, train_y = _texts_and_labels(train_dataset)
val_dataset = _split_frame('val')
val_x, val_y = _texts_and_labels(val_dataset)
test_dataset = _split_frame('test')
test_x, test_y = _texts_and_labels(test_dataset)

# Accumulators for the reported metrics:
#   'macro'     -> one list of macro-F1 values per group id
#   'group_avg' -> initialised empty here
#   'ds'        -> initialised empty here
scores = {
    'macro': {i: [] for i in range(N_GROUPS)},
    'group_avg': [],
    'ds': [],
}

# Hyper-parameter grid explored by the grid search. It does not depend on the
# seed, so it is built once instead of on every iteration.
# NOTE(review): newer scikit-learn releases spell the logistic loss 'log_loss'
# and reject 'log' -- confirm against the scikit-learn version pinned for this project.
parameters = {
    'vect__max_features': [20000, 35000, 50000, 60000],
    'clf__loss': ('hinge', 'log', 'squared_hinge'),
}

# One full train / evaluate cycle per random seed (currently only seed 1).
for seed in range(1, 2):
    # Word 1- to 3-gram counts -> TF-IDF weighting -> linear model trained with SGD.
    text_clf = Pipeline([
        ('vect', CountVectorizer(stop_words=None, ngram_range=(1, 3))),
        ('tfidf', TfidfTransformer()),
        ('clf', SGDClassifier(early_stopping=True, learning_rate='adaptive', tol=1e-4,
                              eta0=1e-4, validation_fraction=0.1, max_iter=10000,
                              random_state=seed)),
    ])

    # cv=None selects scikit-learn's default cross-validation on the training data.
    gs_clf = GridSearchCV(text_clf, parameters, cv=None, n_jobs=-1, verbose=4)
    gs_clf.fit(train_x, train_y)

    # Record the winning value of every searched hyper-parameter.
    for param_name in sorted(parameters):
        logging.info("%s: %r", param_name, gs_clf.best_params_[param_name])

    logging.info('VALIDATION RESULTS:')
    preds = gs_clf.predict(val_x)
    logging.info(f'Micro-F1: {metrics.f1_score(val_y, preds, average="micro")*100:.1f}')
    logging.info(f'Macro-F1: {metrics.f1_score(val_y, preds, average="macro")*100:.1f}')

    logging.info('TEST RESULTS:')
    preds = gs_clf.predict(test_x)
    logging.info(f'Micro-F1: {metrics.f1_score(test_y, preds, average="micro")*100:.1f}')
    logging.info(f'Macro-F1: {metrics.f1_score(test_y, preds, average="macro")*100:.1f}')

    # Per-group macro-F1. The subsets are drawn from the held-out test split only,
    # so examples the model was fitted on cannot inflate the group scores.
    for i in range(N_GROUPS):
        group_set = test_dataset[test_dataset[GROUP_FIELD] == i]
        group_x, group_y = group_set['text'].tolist(), group_set['label'].tolist()
        preds = gs_clf.predict(group_x)
        scores['macro'][i].append(metrics.f1_score(group_y, preds, average="macro"))
        logging.info('-'*100)
        logging.info(f'Macro-F1 ({i}): {scores["macro"][i][-1]*100:.1f}')

    # Mean and standard deviation of this seed's per-group macro-F1 values.
    latest_group_f1 = [scores['macro'][i][-1] for i in range(N_GROUPS)]
    logging.info(f'Group-F1: {np.mean(latest_group_f1) * 100:.1f}')
    logging.info(f'Group-D : {np.std(latest_group_f1) * 100:.1f}')
    logging.info('-' * 100)
