import argparse
import numpy as np
import pandas as pd
from transformers import AutoTokenizer
from IPython import embed
import matplotlib.pyplot as plt
import seaborn as sns
from keras.preprocessing.text import Tokenizer
from keras_preprocessing.sequence import pad_sequences
from keras.models import Sequential
from keras.layers import Dense, Embedding, LSTM, SpatialDropout1D
from keras.callbacks import EarlyStopping
from keras.layers import Dropout
import keras
import re
from nltk.corpus import stopwords
from nltk import word_tokenize
from sklearn.metrics import classification_report
import warnings

# Silence every warning for the whole process.
# NOTE(review): this also hides deprecation warnings from keras/transformers.
warnings.filterwarnings("ignore")

# Per-dataset sequence length: load_dataset pads/truncates every example's
# token ids to exactly this many positions.
dataset_to_max_length = dict(
    imdb=512,
    dbpedia=512,
    ag_news=64,
)

# Per-dataset number of classes, i.e. the width of the model's softmax layer.
# NOTE(review): DBpedia-14 normally has 14 classes — confirm 9 matches this
# dataset's labels, otherwise the one-hot targets won't match the output layer.
dataset_to_num_labels = dict(imdb=2, dbpedia=9, ag_news=4)

# Set of NLTK English stopwords, built once at import time.
# NOTE(review): in this file it is only referenced by the commented-out
# clean_text helper, yet it presumably still requires the NLTK "stopwords"
# corpus to be installed for the module to import — confirm it is still needed.
STOPWORDS = set(stopwords.words("english"))
# Text-cleaning regexes (used by the commented-out clean_text helper):
# REPLACE_BY_SPACE_RE matches punctuation to be turned into spaces;
# BAD_SYMBOLS_RE matches anything outside lowercase alnum, space and "#+_".
# Raw strings: "\[", "\]" and "\|" are invalid escapes in a normal literal and
# trigger Deprecation/SyntaxWarning; the compiled pattern is unchanged.
REPLACE_BY_SPACE_RE = re.compile(r"[/(){}\[\]\|@,;]")
BAD_SYMBOLS_RE = re.compile(r"[^0-9a-z #+_]")

import cufflinks
from IPython.core.interactiveshell import InteractiveShell

# Notebook-style global configuration executed at import time: sets IPython's
# InteractiveShell.ast_node_interactivity to "all" and switches cufflinks to
# offline mode with the "pearl" theme.
# NOTE(review): nothing else in this file uses InteractiveShell or cufflinks,
# so these side effects look like notebook leftovers — confirm before removing.
InteractiveShell.ast_node_interactivity = "all"

cufflinks.go_offline()
cufflinks.set_config_file(world_readable=True, theme="pearl")
import os


def get_model(X, MAX_NB_WORDS, EMBEDDING_DIM, num_labels, learning_rate=0.05):
    """Build and compile the LSTM text classifier.

    Args:
        X: Training inputs; only ``X.shape[1]`` (the padded sequence length)
            is used, as the Embedding layer's ``input_length``.
        MAX_NB_WORDS: Number of distinct token ids (Embedding input size).
        EMBEDDING_DIM: Dimensionality of the learned token embeddings.
        num_labels: Number of output classes (width of the softmax layer).
        learning_rate: Adam learning rate. Defaults to 0.05, the value that
            was previously hard-coded.

    Returns:
        A compiled ``Sequential`` model trained with categorical
        cross-entropy, so targets must be one-hot encoded.
    """
    model = Sequential()
    model.add(Embedding(MAX_NB_WORDS, EMBEDDING_DIM, input_length=X.shape[1]))
    model.add(SpatialDropout1D(0.2))
    model.add(LSTM(100, dropout=0.2, recurrent_dropout=0.2))
    model.add(Dense(num_labels, activation="softmax"))
    model.compile(
        loss="categorical_crossentropy",
        # NOTE(review): 0.05 is unusually high for Adam on an LSTM; tune via
        # the learning_rate argument if training diverges.
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=["accuracy"],
    )
    # summary() prints the table itself and returns None; wrapping it in
    # print() emitted a stray "None" line.
    model.summary()

    return model


# def clean_text(text):
#     """
#     text: a string

#     return: modified initial string
#     """
#     text = text.lower()  # lowercase text
#     text = REPLACE_BY_SPACE_RE.sub(
#         " ", text
#     )  # replace REPLACE_BY_SPACE_RE symbols by space in text. substitute the matched string in REPLACE_BY_SPACE_RE with space.
#     text = BAD_SYMBOLS_RE.sub(
#         "", text
#     )  # remove symbols which are in BAD_SYMBOLS_RE from text. substitute the matched string in BAD_SYMBOLS_RE with nothing.
#     text = text.replace("x", "")
#     #    text = re.sub(r'\W+', '', text)
#     text = " ".join(
#         word for word in text.split() if word not in STOPWORDS
#     )  # remove stopwords from text
#     return text


import keras.backend as K


def size(model):
    """Return the number of trainable scalar parameters in ``model``.

    Sums the element counts of every tensor in ``model.trainable_weights``
    using only each weight's shape, so no weight values are materialized.
    Returns a plain ``int`` (0 when there are no trainable weights).
    """
    import math  # local import keeps this helper self-contained

    # math.prod(()) == 1, so a scalar weight counts as one parameter and the
    # total stays an exact int (np.prod(()) would return the float 1.0).
    return sum(math.prod(weight.shape) for weight in model.trainable_weights)


def load_dataset(tokenizer, dataset, data_dir="../datasets"):
    """Load, tokenize and one-hot encode the train split and every test split.

    Reads ``<data_dir>/<dataset>_dataset/train.csv`` plus every file in that
    directory whose name starts with ``test`` or ``adv``. Each CSV must have a
    ``text`` and a ``label`` column.

    Args:
        tokenizer: Hugging Face-style tokenizer; called on a list of strings
            and expected to return a mapping with an ``"input_ids"`` array.
        dataset: Dataset name; must be a key of ``dataset_to_max_length``.
        data_dir: Root directory holding the ``<dataset>_dataset`` folders.
            Defaults to ``"../datasets"`` (relative to the working directory).

    Returns:
        ``(X_train, Y_train, X_test, Y_test, vocab_size, embedding_dim)``:
        ``X_*`` are token-id arrays padded/truncated to
        ``dataset_to_max_length[dataset]`` columns; ``Y_*`` are float32
        one-hot label arrays; ``X_test``/``Y_test`` are dicts keyed by the test
        file name up to its first ``"."``; ``vocab_size`` is
        ``len(tokenizer)``; ``embedding_dim`` is 300.
    """
    dataset_dir = os.path.join(data_dir, f"{dataset}_dataset")
    max_sequence_length = dataset_to_max_length[dataset]
    embedding_dim = 300

    train_df = pd.read_csv(os.path.join(dataset_dir, "train.csv"))
    # Fix the label -> one-hot column mapping from the training labels (sorted,
    # the same order pd.get_dummies uses) so every split shares it, even a test
    # split that is missing some labels.
    label_categories = sorted(train_df["label"].dropna().unique())

    def tokenize_df(df):
        # Token ids, padded/truncated to exactly max_sequence_length columns.
        X = tokenizer(
            df["text"].tolist(),
            padding="max_length",
            truncation=True,
            max_length=max_sequence_length,
            return_tensors="np",
        )["input_ids"]
        labels = pd.Categorical(df["label"], categories=label_categories)
        Y = pd.get_dummies(labels, dtype=np.float32).values
        return X, Y

    X_train, Y_train = tokenize_df(train_df)

    X_test = {}
    Y_test = {}
    for file_name in sorted(os.listdir(dataset_dir)):
        if not file_name.startswith(("test", "adv")):
            continue
        # Split key: file name up to the first "." ("test_x.csv" -> "test_x").
        key = file_name.partition(".")[0]
        # Read each file exactly once.
        test_df = pd.read_csv(os.path.join(dataset_dir, file_name))
        X_test[key], Y_test[key] = tokenize_df(test_df)

    # len() also counts tokens added on top of the base vocabulary, so every id
    # the tokenizer can emit fits in an Embedding of this size.
    return X_train, Y_train, X_test, Y_test, len(tokenizer), embedding_dim


def main(args):
    """Train or load the LSTM classifier, then report metrics per test split.

    Args:
        args: Parsed CLI namespace providing ``dataset``, ``model_dir``,
            ``batch_size``, ``num_epochs`` and ``mode``.

    With ``args.mode == "train"`` the model is fit on the training split (10%
    held out for validation) with early stopping, checkpointing the best
    weights to ``args.model_dir``. Any other mode loads weights from
    ``args.model_dir`` instead. A classification report is then printed for
    every test split whose key starts with ``test_`` or ``adv_``.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    X_train, Y_train, X_test, Y_test, MAX_NB_WORDS, EMBEDDING_DIM = load_dataset(
        tokenizer=tokenizer, dataset=args.dataset
    )

    model = get_model(
        X_train, MAX_NB_WORDS, EMBEDDING_DIM, dataset_to_num_labels[args.dataset]
    )
    print("Number of params:", size(model))

    if args.mode == "train":
        # Only the best epoch's weights are kept at args.model_dir
        # (save_best_only; Keras' default monitor is val_loss).
        cp_callback = keras.callbacks.ModelCheckpoint(
            filepath=args.model_dir,
            save_weights_only=True,
            save_best_only=True,
            verbose=1,
        )
        print("Training the model...")
        model.fit(
            X_train,
            Y_train,
            epochs=args.num_epochs,
            batch_size=args.batch_size,
            validation_split=0.1,
            callbacks=[
                EarlyStopping(monitor="val_loss", patience=3, min_delta=0.0001),
                cp_callback,
            ],
        )
        # NOTE(review): evaluation below uses the final-epoch weights still in
        # memory, not the best checkpoint that the non-train mode loads;
        # confirm that difference is intended.
    else:
        model.load_weights(args.model_dir)

    for key in sorted(X_test):
        # NOTE(review): a split named exactly "test" (e.g. from test.csv) does
        # not match this filter and is skipped — confirm that is intended.
        if not key.startswith(("test_", "adv_")):
            continue
        print(f"Results for {key}:")
        print("Evaluating the model on the test set...")
        all_predictions = model.predict(X_test[key])
        print("The classification report for the model is:")
        print(
            classification_report(
                np.argmax(Y_test[key], axis=1),
                np.argmax(all_predictions, axis=1),
                digits=3,
            )
        )
        print("-" * 100)
        print("\n\n")


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description="Train or evaluate an LSTM text classifier on BERT token ids."
    )

    argparser.add_argument(
        "--dataset",
        type=str,
        required=True,
        # Reject unknown names up front instead of a KeyError deep in main().
        choices=list(dataset_to_max_length),
        help="Dataset name; data is read from ../datasets/<dataset>_dataset/.",
    )
    argparser.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help=(
            "Weights checkpoint path: the best weights are saved here in "
            "train mode and loaded from here otherwise."
        ),
    )

    argparser.add_argument(
        "--batch_size",
        type=int,
        required=True,
        help="Mini-batch size used for training.",
    )

    argparser.add_argument(
        "--num_epochs",
        type=int,
        default=3,
        help="Maximum number of training epochs (early stopping may end sooner).",
    )
    argparser.add_argument(
        "--mode",
        type=str,
        required=True,
        help=(
            "'train' to fit and checkpoint the model; any other value loads "
            "weights from --model_dir and only evaluates."
        ),
    )

    args = argparser.parse_args()

    main(args)
