import numpy as np
import re
import pandas as pd
from tqdm import tqdm
import nltk
from unidecode import unidecode
from langdetect import detect
from alphabet_detector import AlphabetDetector
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from sklearn.utils import shuffle
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, RobertaTokenizer, AlbertTokenizer, BertModel, RobertaModel, AlbertModel, RobertaTokenizerFast, BertTokenizerFast, AutoModel, AdamW, BertForSequenceClassification

import http.client
import json

import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, LSTM, Embedding, Input, Flatten

from keras import backend as K
from keras import utils

# Fixed seed so the shuffle and the train/validation split are reproducible.
# The embeddings cached in 'ttrain_indic.npy' are later paired row-by-row with
# df_train['tag'], so the cache is only valid for the exact split it was built
# from. NOTE(review): regenerate the cached embeddings once with this seed.
SEED = 42

init_df = pd.read_csv("tamil_final_train.csv")
init_df = init_df.sample(frac=1, random_state=SEED).reset_index(drop=True)

# len() gives the number of rows; count()[0] only counted the non-null cells of
# the first column and relied on deprecated positional Series indexing.
indices = np.arange(len(init_df))
x_indices = indices
y_indices = init_df['tag']
# Stratified 95/5 split on row positions; the label halves (ylol1, ylol2) are
# not used below because labels are read back from the frames.
train_indices, val_indices, ylol1, ylol2 = train_test_split(
    x_indices,
    y_indices,
    train_size=0.95,
    stratify=y_indices,
    random_state=SEED,
)
print(train_indices, val_indices, flush=True)
df_train = init_df.iloc[train_indices]
df_val = init_df.iloc[val_indices]
print(df_train.count(), df_val.count(), flush=True)

df_test = pd.read_csv("tamil_final_test.csv")

def bert_embeddings(X_input, model_dir, tokenizer_dir='./models/finetune/tamil_indic', max_length=512):
    """Compute one mean-pooled ALBERT embedding per input text.

    Each text is encoded with the tokenizer's own special tokens, run through
    the model without gradients, and represented by the mean over tokens of
    the second-to-last hidden layer.

    Args:
        X_input: Iterable of strings (e.g. a pandas Series of texts).
        model_dir: Path or name passed to ``AlbertModel.from_pretrained``.
        tokenizer_dir: Path or name passed to ``AlbertTokenizer.from_pretrained``.
            Defaults to the location this script previously hard-coded.
        max_length: Maximum number of token ids fed to the model; longer
            inputs are truncated while keeping the final (separator) token.

    Returns:
        A numpy array with one embedding row per input text.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AlbertModel.from_pretrained(model_dir, output_hidden_states=True)
    tokenizer = AlbertTokenizer.from_pretrained(tokenizer_dir)
    model = model.to(device)
    model.eval()

    sentence_embeddings = []
    for line in tqdm(X_input):
        # Let the tokenizer insert its real special tokens; concatenating the
        # words "CLS"/"SEP" into the text only adds two ordinary word pieces.
        token_ids = tokenizer.encode(line, add_special_tokens=True)
        if len(token_ids) > max_length:
            # Keep the closing special token when truncating.
            token_ids = token_ids[:max_length - 1] + [token_ids[-1]]

        input_ids = torch.tensor([token_ids], device=device)
        # Single-sentence input: every token is attended to and belongs to
        # segment 0. The arguments are passed by keyword because the second
        # positional parameter of the model is the attention mask, not the
        # segment ids.
        attention_mask = torch.ones_like(input_ids)
        token_type_ids = torch.zeros_like(input_ids)

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
        # Index 2 holds the per-layer hidden states because the model was
        # loaded with output_hidden_states=True.
        hidden_states = outputs[2]
        # Mean over the token axis of the second-to-last layer.
        sentence_embeddings.append(torch.mean(hidden_states[-2][0], dim=0).cpu().numpy())

    return np.array(sentence_embeddings)

# The embedding caches are produced once with bert_embeddings and then reused:
#   np.save('ttrain_indic.npy', bert_embeddings(df_train['text'], "./models/finetune/tamil_indic"))
#   np.save('ttest_indic.npy', bert_embeddings(df_test['text'], "./models/finetune/tamil_indic"))
# Load the cached train and test embeddings (train first, then test).
train_embeds, test_embeds = (
    np.load(cache_file) for cache_file in ('ttrain_indic.npy', 'ttest_indic.npy')
)

def recall_m(y_true, y_pred):
    """Keras-backend recall: rounded true positives over rounded actual positives.

    Both tensors are clipped to [0, 1] and rounded before summing; the backend
    epsilon in the denominator avoids division by zero.
    """
    hit_mask = K.round(K.clip(y_true * y_pred, 0, 1))
    actual_mask = K.round(K.clip(y_true, 0, 1))
    hit_count = K.sum(hit_mask)
    actual_count = K.sum(actual_mask)
    return hit_count / (actual_count + K.epsilon())

def precision_m(y_true, y_pred):
    """Keras-backend precision: rounded true positives over rounded predicted positives.

    Both tensors are clipped to [0, 1] and rounded before summing; the backend
    epsilon in the denominator avoids division by zero.
    """
    hit_mask = K.round(K.clip(y_true * y_pred, 0, 1))
    predicted_mask = K.round(K.clip(y_pred, 0, 1))
    hit_count = K.sum(hit_mask)
    predicted_count = K.sum(predicted_mask)
    return hit_count / (predicted_count + K.epsilon())

def f1_m(y_true, y_pred):
    """Keras-backend F1 score: harmonic mean of precision_m and recall_m.

    The backend epsilon in the denominator avoids division by zero when both
    precision and recall are zero.
    """
    p = precision_m(y_true, y_pred)
    r = recall_m(y_true, y_pred)
    ratio = (p * r) / (p + r + K.epsilon())
    return 2 * ratio

def _fit_predict_dnn(X_train, y_train, X_test):
    """Fit the small LSTM network on the embeddings and return predicted class indices for X_test."""
    # Reshape into (samples, 1 timestep, features) on local names only, so the
    # caller's 2-D arrays stay untouched.
    seq_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
    seq_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))
    onehot_train = utils.to_categorical(y_train, 2)

    model = Sequential()
    model.add(LSTM(100, input_shape=(1, seq_train.shape[2])))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(2, activation='softmax'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy', f1_m])
    model.fit(seq_train, onehot_train, epochs=3, batch_size=8, validation_split=0.1)

    probabilities = model.predict(seq_test, batch_size=8)
    return np.argmax(probabilities, axis=1)


def train(X_train, y_train, X_test, y_test, ClassifierList):
    """Fit each classifier on the training data and print its test metrics.

    For every entry the binary confusion matrix and the classification report
    on ``(X_test, y_test)`` are printed.

    Args:
        X_train: 2-D array of training features (samples x features).
        y_train: Training labels; the "DNN" entry one-hot encodes them into
            two classes.
        X_test: 2-D array of test features (samples x features).
        y_test: Test labels; the confusion matrix is unpacked as a 2x2 matrix.
        ClassifierList: Sequence whose items are either the string ``"DNN"``
            (Keras LSTM model) or objects exposing ``fit`` and ``predict``.
    """
    for n, Classifier in enumerate(ClassifierList):
        print(f'{n+1} of {len(ClassifierList)} of methods attempting')
        if isinstance(Classifier, str) and Classifier == "DNN":
            # The DNN path works on its own reshaped/one-hot copies. Rebinding
            # X_train/y_train/X_test/y_test here used to corrupt the inputs of
            # every classifier listed after "DNN".
            prediction = _fit_predict_dnn(X_train, y_train, X_test)
        else:
            print(X_train.shape, y_train.shape)
            Classifier.fit(X_train, y_train)
            prediction = Classifier.predict(X_test)

        tn, fp, fn, tp = confusion_matrix(y_test, prediction).ravel()
        print()
        print(f"{str(Classifier).split('(')[0]} Confusion Matrix:")
        print(f"True Negatives: {tn}")
        print(f"False Positives: {fp}")
        print(f"False Negatives: {fn}")
        print(f"True Positives: {tp}")
        print()

        report = classification_report(y_test, prediction, target_names=['Predict 0', 'Predict 1'], output_dict=True)
        class_table = pd.DataFrame(report).transpose()
        print(class_table)

# Models to evaluate on the cached embeddings: three scikit-learn classifiers
# and the Keras LSTM model selected by the "DNN" marker string.
classifier_lst = [LogisticRegression(), RandomForestClassifier(), SVC(), "DNN"]
# count must be called; printing the bare attribute only shows the bound method.
print(df_train.count())
train(train_embeds, df_train['tag'], test_embeds, df_test['tag'], classifier_lst)
