import re

import nltk
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import (
    AdamW,
    AutoModel,
    BertForSequenceClassification,
    BertModel,
    BertTokenizer,
    BertTokenizerFast,
)
from unidecode import unidecode

def remove_links(x):
    """Return *x* with every URL-like substring deleted.

    The scheme (``http``, ``https`` or ``ftp``) is optional in the pattern, so
    bare ``name.tld`` tokens are stripped as well, together with any path or
    query characters that directly follow them.
    """
    url_pattern = (
        r"((http|ftp|https):\/\/)?"                       # optional scheme
        r"[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b"    # host ending in a 2-6 letter suffix
        r"([-a-zA-Z0-9@:%_\+.~#?&//=]*)"                  # trailing path / query characters
    )
    return re.compile(url_pattern).sub("", x)

# Patterns used by preproc(), compiled once at import time rather than on
# every call. Raw strings avoid the invalid "\d" / "\." string escapes.
_TAB_RUN_RE = re.compile(r"\t+")
_LIST_NUMBER_RE = re.compile(r"\d+\. ")
_SENTENCE_END_RE = re.compile(r"[?!.]")
_EMOJI_RE = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # flags (iOS)
    "]+",
    flags=re.UNICODE,
)


def preproc(text):
    """Normalise one raw text string before tokenisation.

    Steps, in order:

    1. delete runs of tab characters (nothing is inserted in their place);
    2. delete list-style numbering, i.e. digits followed by ``". "``;
    3. replace every ``?``, ``!`` or ``.`` with ``" . "``;
    4. delete characters in the emoji ranges listed in ``_EMOJI_RE``;
    5. drop whitespace-separated tokens that start with ``@``;
    6. join the remaining tokens with single spaces.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string, with no leading/trailing whitespace and exactly
        one space between tokens.
    """
    text = _TAB_RUN_RE.sub("", text)
    text = _LIST_NUMBER_RE.sub("", text)
    text = _SENTENCE_END_RE.sub(" . ", text)
    text = _EMOJI_RE.sub("", text)
    # str.split() never yields empty tokens, and joining on a single space
    # already collapses all whitespace runs, so no further \s+ pass is needed.
    kept_tokens = [token for token in text.split() if not token.startswith("@")]
    return " ".join(kept_tokens)

# Read the header-less training CSV and name its two columns.
hope_train = pd.read_csv("./hope_train.csv", header=None)
data = pd.DataFrame(hope_train)
data.columns = ['text', 'tag']

# Discard rows labelled as not being English.
non_english_index = data.index[data['tag'] == 'not-English']
data.drop(non_english_index, inplace=True)

# Clean the text: strip links first, then apply the general preprocessing.
data['text'] = data['text'].apply(remove_links).apply(preproc)

# Binary target: 1 for 'Hope_speech', 0 for every other remaining tag.
data['tag'] = data['tag'].apply(lambda label: int(label == 'Hope_speech'))

# Show one cleaned example as a quick sanity check.
print(data['text'].iloc[1])

X = data['text']
y = data['tag']

# First carve off 80% for training; the remaining 20% is a hold-out pool.
X_train, X_holdout, y_train, y_holdout = train_test_split(
    X, y, train_size=0.8, stratify=y, random_state=42
)
# Then halve the hold-out pool into validation and test sets. Splitting the
# hold-out pool (not the full X, y) keeps all three sets disjoint, so no
# training example leaks into validation or test.
X_val, X_test, y_val, y_test = train_test_split(
    X_holdout, y_holdout, test_size=0.5, stratify=y_holdout, random_state=42
)

# Use the GPU when torch reports one as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
# bert = AutoModel.from_pretrained('bert-base-uncased')

# Encoding options shared by all three splits: sequences are truncated at
# 512 tokens and each call pads to the longest sequence it was given.
encode_kwargs = {"max_length": 512, "padding": "longest", "truncation": True}

train_tokens = bert_tokenizer.batch_encode_plus(X_train.tolist(), **encode_kwargs)
test_tokens = bert_tokenizer.batch_encode_plus(X_test.tolist(), **encode_kwargs)
val_tokens = bert_tokenizer.batch_encode_plus(X_val.tolist(), **encode_kwargs)

# Convert the tokenizer output and the labels of each split into tensors.
train_seq = torch.tensor(train_tokens['input_ids'])
train_mask = torch.tensor(train_tokens['attention_mask'])
train_y = torch.tensor(y_train.tolist())

val_seq = torch.tensor(val_tokens['input_ids'])
val_mask = torch.tensor(val_tokens['attention_mask'])
val_y = torch.tensor(y_val.tolist())

test_seq = torch.tensor(test_tokens['input_ids'])
test_mask = torch.tensor(test_tokens['attention_mask'])
test_y = torch.tensor(y_test.tolist())

batch_size = 25

# Training batches are drawn in random order.
train_data = TensorDataset(train_seq, train_mask, train_y)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# Validation and test batches are served in dataset order, so predictions
# collected batch by batch line up with val_y / test_y.
val_data = TensorDataset(val_seq, val_mask, val_y)
val_sampler = SequentialSampler(val_data)
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)

test_data = TensorDataset(test_seq, test_mask, test_y)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

# class BERT_FT(nn.Module):
    # def __init__(self, bert):
    #   super(BERT_FT, self).__init__()
    #   self.bert = bert 
    #   self.dropout = nn.Dropout(0.1)
    #   self.relu =  nn.ReLU()
    #   self.fc1 = nn.Linear(768,512)
    #   self.fc2 = nn.Linear(512,2)
    #   self.softmax = nn.LogSoftmax(dim=1)
# 
    # def forward(self, sent_id, mask):
    #   _, cls_hs = self.bert(sent_id, attention_mask=mask)
    #   x = self.fc1(cls_hs)
    #   x = self.relu(x)
    #   x = self.dropout(x)
    #   x = self.fc2(x)
    #   x = self.softmax(x)
    #   return x

# Load pretrained 'bert-base-uncased' for sequence classification with two
# labels, and move it to the selected device.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model = model.to(device)

lr = 2e-5
# NOTE: max_grad_norm and the step/warmup counts below are only referenced by
# the commented-out scheduler line; nothing else in this section uses them.
max_grad_norm = 1.0
num_total_steps = 1000
num_warmup_steps = 100
warmup_proportion = num_warmup_steps / num_total_steps  # 0.1


### In PyTorch-Transformers, the optimizer and the schedule are split and instantiated separately, like this:
optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
# scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)  # PyTorch scheduler


# Per-batch training losses across all epochs, kept for later plotting.
train_loss_set = []


epochs = 5

# tqdm wraps the epoch range to display a progress bar over epochs.
for epoch in tqdm(range(epochs)):
    # Put the model in training mode (as opposed to evaluation mode).
    model.train()

    # Running totals for this epoch.
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0

    # Actual number of batches per epoch, used in the progress message.
    steps_per_epoch = len(train_dataloader)

    # Train on the data for one epoch.
    for i, batch in enumerate(train_dataloader):
        # Move every tensor of the batch to the selected device.
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        # Forward pass; the first element of the output is used as the loss.
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        loss = outputs[0]
        loss_value = loss.item()
        train_loss_set.append(loss_value)

        # Backward pass, parameter update, then clear gradients for the next batch.
        loss.backward()
        optimizer.step()
        # scheduler.step()
        optimizer.zero_grad()

        # Update the per-epoch tracking totals.
        tr_loss += loss_value
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1

        if i % 50 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, epochs, i+1, steps_per_epoch, loss_value))

    if nb_tr_steps:
        print('Epoch [{}/{}], Mean train loss: {:.4f}'
              .format(epoch+1, epochs, tr_loss / nb_tr_steps))

    # Save a checkpoint of the weights after every epoch.
    torch.save(model.state_dict(), str(epoch)+'sw.pt')


# optimizer = AdamW(model.parameters(), lr=1e-5)
# 
# class_weights = compute_class_weight('balanced', np.unique(y_train), y_train)
# weights= torch.tensor(class_weights,dtype=torch.float)
# weights = weights.to(device)
# 
# cross_entropy  = nn.NLLLoss(weight=weights) 
# 
# model.load_state_dict(torch.load("saved_weights.pt"))
# 
# total_preds = []
# 
# for step, batch in enumerate(test_dataloader):
    # if step % 50 == 0 and not step == 0:
        # print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(test_dataloader)))
    # batch = [r.to(device) for r in batch]
    # sent_id, mask, labels = batch
    # preds = model(sent_id, mask)
    # preds=preds.detach().cpu().numpy()
    # total_preds.append(preds)
# total_preds = np.concatenate(total_preds, axis = 0)
# 
# preds = np.argmax(total_preds, axis = 1)
# print(classification_report(test_y, preds))