#!/usr/bin/env python
# coding: utf-8
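# Fine-tunes a BERT model from TensorFlow Hub to classify texts into
# offensive-language categories, using SOUP resampling (multi_imbalance)
# to compensate for class imbalance.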


import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow_hub as hub
import tokenization  # tokenization.py from the official BERT repository (provides FullTokenizer)
from sklearn.model_selection import train_test_split
from nltk.tokenize import WordPunctTokenizer 
from nltk.stem import WordNetLemmatizer 
import json
import re
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from wordsegment import load, segment
import nltk
nltk.download('wordnet')
import seaborn as sns
import matplotlib.pyplot as plt


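# Convert raw strings into the three fixed-length inputs BERT expects:
# token ids, attention (padding) masks, and segment ids.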
def bert_encode(texts, tokenizer, max_len=512):
    all_tokens = []
    all_masks = []
    all_segments = []
    
    for text in texts:
        text = tokenizer.tokenize(text)                         # WordPiece tokenization
        text = text[:max_len - 2]                               # leave room for [CLS] and [SEP]
        input_sequence = ["[CLS]"] + text + ["[SEP]"]
        pad_len = max_len - len(input_sequence)

        tokens = tokenizer.convert_tokens_to_ids(input_sequence)
        tokens += [0] * pad_len                                 # pad token ids up to max_len
        pad_masks = [1] * len(input_sequence) + [0] * pad_len   # 1 = real token, 0 = padding
        segment_ids = [0] * max_len                             # single-segment input
        
        all_tokens.append(tokens)
        all_masks.append(pad_masks)
        all_segments.append(segment_ids)
    
    return np.array(all_tokens), np.array(all_masks), np.array(all_segments)


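# Classification head on top of BERT: the embedding of the [CLS] token feeds a
# dense output layer. The whole network, including the BERT weights, is fine-tuned.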
def build_model(bert_layer, num_classes, max_len=512):
    input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    input_mask = Input(shape=(max_len,), dtype=tf.int32, name="input_mask")
    segment_ids = Input(shape=(max_len,), dtype=tf.int32, name="segment_ids")

    _, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])
    clf_output = sequence_output[:, 0, :]                       # [CLS] token embedding
    # Multi-class head: the labels cover several offence categories, and the
    # evaluation below takes an argmax over class probabilities.
    out = Dense(num_classes, activation='softmax')(clf_output)

    model = Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=out)
    model.compile(Adam(learning_rate=2e-6),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model



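# BERT-large uncased (24 layers, hidden size 1024, 16 attention heads) from
# TensorFlow Hub, loaded as a trainable Keras layer so it is fine-tuned.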
module_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1"
bert_layer = hub.KerasLayer(module_url, trainable=True)



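# all.xlsx: labelled training data with 'text' and 'category' columns.
# test.xlsx: held-out evaluation data (assumed to carry the same columns).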
train = pd.read_excel("all.xlsx")
test = pd.read_excel("test.xlsx")

from sklearn.preprocessing import LabelEncoder

# 'category' is encoded once; both integer columns hold the same encoding,
# so a single encoder (reused later for the test labels) is enough.
LE = LabelEncoder()
train['label1'] = LE.fit_transform(train['category'])
train['Types_Category'] = train['label1']

print(train['Types_Category'].value_counts()[:20])


tok = WordPunctTokenizer()
lemmatizer = WordNetLemmatizer()
nltk.download('stopwords')

from nltk.corpus import stopwords
stopword = stopwords.words('english')


load()  # the wordsegment dictionary must be loaded before segment() is called


def helper(a):
    # split run-together strings (e.g. hashtags) into individual words
    k = segment(a)
    return ' '.join(k)


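# Basic text cleaning: drop mentions, links and non-alphabetic characters,
# lowercase, remove stopwords and very short words, then lemmatize.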
def tweet_cleaner(text):
    newString = str(text)                                      # make sure we are working with a string
    newString = re.sub(r'@[A-Za-z0-9_]+', '', newString)       # remove user mentions
    newString = re.sub(r'(www|http)\S+', '', newString)        # remove links (before punctuation is stripped)
    letters_only = re.sub("[^a-zA-Z]", " ", newString)         # keep only alphabetic characters
    lower_case = letters_only.lower()                          # lowercase everything
    words = tok.tokenize(lower_case)                           # tokenize
    rs = [word for word in words if word not in stopword]      # remove stopwords
    long_words = []
    for i in rs:
        if len(i) > 3:                                         # drop very short words
            long_words.append(lemmatizer.lemmatize(i))         # lemmatize the rest
    return " ".join(long_words).strip()
train['text'] = train['text'].apply(tweet_cleaner)
test['text'] = test['text'].apply(tweet_cleaner)   # apply the same cleaning to the evaluation texts


sns.countplot(x="category", data=train)
plt.show()
# Held-out split of the cleaned text (kept for reference; the BERT pipeline below evaluates on test.xlsx).
X_train, X_test, y_train, y_test = train_test_split(train['text'], train['category'], test_size=0.33, random_state=42)


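# Build the WordPiece tokenizer from the vocabulary file shipped with the hub module.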
vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case)


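# Encode both datasets with a maximum sequence length of 120 tokens.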
train_input = bert_encode(train.text.values, tokenizer, max_len=120)
test_input = bert_encode(test.text.values, tokenizer, max_len=120)
train_labels = train.label1.values

num_classes = train['label1'].nunique()
model = build_model(bert_layer, num_classes, max_len=120)
model.summary()


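# Class-imbalance handling: SOUP (multi_imbalance) undersamples the majority class
# and oversamples the minority class. Here it is applied directly to the encoded
# token-id matrix, which keeps the original pipeline but is a fairly rough form of rebalancing.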
from multi_imbalance.resampling.soup import SOUP
from multi_imbalance.utils.plot import plot_visual_comparision_datasets

# SOUP expects a single 2-D feature matrix, so the three BERT inputs (token ids,
# masks, segment ids) are stacked side by side and split back apart after resampling.
X_stacked = np.hstack(train_input)

soup = SOUP(maj_int_min={
        'maj': LE.transform(['Not_offensive']).tolist(),
        'min': LE.transform(['Offensive_Targeted_Insult_Other']).tolist()
    })
X_res, y_train_res = soup.fit_resample(X_stacked, np.array(train_labels))

max_len = 120
X_res = np.asarray(X_res).astype(np.int32)
X_train_res = [X_res[:, :max_len], X_res[:, max_len:2 * max_len], X_res[:, 2 * max_len:]]

plot_visual_comparision_datasets(X_stacked, train_labels, X_res, y_train_res, 'Raw Data', 'After processing')
plt.show()


print(train_input)
print(train_labels)

print(len(train_input[0]), len(train_labels))   # number of encoded examples vs number of labels

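# Fine-tune the full model on the resampled training data.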
train_history = model.fit(
    X_train_res, y_train_res,
    validation_split=0.2,
    epochs=1,
    batch_size=1)

# A model containing a hub.KerasLayer does not serialise cleanly to HDF5,
# so only the trained weights are saved here.
model.save_weights('bert_to_extract_model.h5')


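# Evaluation on the encoded test set. This assumes test.xlsx carries the same
# 'category' labels as the training file so they can be mapped through LE.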
test_labels = LE.transform(test['category'].values)

score = model.evaluate(test_input, test_labels, batch_size=256, verbose=1)
print('Test accuracy:', score[1])


preds = model.predict(test_input)

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
pred_labels = np.argmax(preds, axis=1)
print(confusion_matrix(test_labels, pred_labels))
print(classification_report(test_labels, pred_labels,
                            labels=np.arange(len(LE.classes_)), target_names=LE.classes_))





