import copy
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn import svm
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
# from sklearn.utils.testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
from statistics import mean, stdev
import math
import numpy as np
import gc
from warnings import filterwarnings
filterwarnings('ignore')
from sklearn.linear_model import LogisticRegression

# @ignore_warnings(category=ConvergenceWarning)
def train_lr(   train_x,
                train_y,
                random_state,
                max_iter=100,
                ):
    """Fit a StandardScaler -> LogisticRegression pipeline on the training data.

    Args:
        train_x: Training features, passed unchanged to ``Pipeline.fit``.
        train_y: Training labels, passed unchanged to ``Pipeline.fit``.
        random_state: Forwarded to ``LogisticRegression``.
            NOTE(review): scikit-learn documents ``random_state`` as used only
            by the 'sag', 'saga' and 'liblinear' solvers. No solver is chosen
            here, so with the library default ('lbfgs') this seed likely has no
            effect. Confirm this for the installed version.
        max_iter: Maximum solver iterations. The default of 100 is the
            previously hard-coded value. This module calls
            ``filterwarnings('ignore')`` at import, so if the cap is too small,
            ``ConvergenceWarning`` is never shown. Raise the cap if results look
            under-fit.

    Returns:
        The fitted scikit-learn ``Pipeline``.
    """
    clf = make_pipeline(
        StandardScaler(),
        LogisticRegression(random_state=random_state, max_iter=max_iter),
    )
    clf.fit(train_x, train_y)
    return clf


def evaluate_svm(   clf,
                    test_x,
                    test_y,
                    ):
    """Return the accuracy of a fitted classifier on a held-out split.

    Despite the name, this works for any object exposing ``predict``
    (e.g. the logistic-regression pipeline built by ``train_lr``).

    Args:
        clf: Fitted estimator with a ``predict`` method.
        test_x: Features to predict on.
        test_y: Ground-truth labels for ``test_x``.

    Returns:
        The fraction of correct predictions, as computed by
        ``sklearn.metrics.accuracy_score``.
    """
    return accuracy_score(test_y, clf.predict(test_x))


def train_eval_lr(  train_x, train_y,
                    test_x, test_y,
                    num_seeds = 3,
                    stdev_threshold = 0.01,
                    ):
    """Train and score the logistic-regression pipeline once per seed; return the mean accuracy.

    For each ``random_state`` in ``range(num_seeds)``, this trains a model
    with ``train_lr`` and scores it on the test split with ``evaluate_svm``.
    With more than one seed, if the sample standard deviation of the
    accuracies exceeds ``stdev_threshold``, it prints the per-seed accuracies.
    It presumably uses ``print`` rather than ``warnings.warn`` because this
    module calls ``filterwarnings('ignore')`` at import, which would silence
    a warning.

    NOTE(review): ``train_lr`` uses LogisticRegression's default solver, for
    which ``random_state`` is documented as unused. The seeds may therefore
    produce identical models, so the extra runs likely only cost time.
    Confirm this.

    Args:
        train_x, train_y: Training features and labels.
        test_x, test_y: Evaluation features and labels.
        num_seeds: Number of seeds (0 .. num_seeds-1) to train with. Must be
            at least 1.
        stdev_threshold: Accuracy spread above which the per-seed results
            are printed. The default of 0.01 is the previously hard-coded
            value.

    Returns:
        The mean test accuracy across all seeds.

    Raises:
        ValueError: If ``num_seeds < 1``. Previously ``mean([])`` raised
            ``StatisticsError``, which is a ``ValueError`` subclass.
    """
    if num_seeds < 1:
        raise ValueError(f"num_seeds must be >= 1, got {num_seeds}")

    reg_acc_list = [
        evaluate_svm(train_lr(train_x, train_y, random_state), test_x, test_y)
        for random_state in range(num_seeds)
    ]

    # stdev needs at least two samples; a single run has no spread to report.
    if num_seeds > 1 and stdev(reg_acc_list) > stdev_threshold:
        print(f"big stdev! {reg_acc_list}")

    return mean(reg_acc_list)