import torch, re, csv
from tqdm.notebook import tqdm

from transformers import BertTokenizer
from torch.utils.data import TensorDataset

from transformers import BertForSequenceClassification

import preprocessor as p
p.set_options(p.OPT.URL)
import demoji
demoji.download_codes() 
import pandas as pd
from sklearn import preprocessing
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

def decontracted(phrase):
    """Expand common English contractions in ``phrase``.

    Rules are applied strictly in order and are case-sensitive. The irregular
    forms ("won't", "can't") come first so the generic "n't" rule does not
    turn them into "wo not" / "ca not". Note that "'s" is always expanded to
    " is", so possessives are rewritten as well.
    """
    rules = (
        # specific
        (r"won\'t", "will not"),
        (r"can\'t", "can not"),
        # general
        (r"n\'t", " not"),
        (r"\'re", " are"),
        (r"\'s", " is"),
        (r"\'d", " would"),
        (r"\'ll", " will"),
        (r"\'t", " not"),
        (r"\'ve", " have"),
        (r"\'m", " am"),
    )
    for pattern, replacement in rules:
        phrase = re.sub(pattern, replacement, phrase)
    return phrase

def getData(file, tag='train'):
    """Read a headerless, tab-separated split into a DataFrame.

    Args:
        file: path of the file to read. It is parsed with a tab delimiter
            regardless of its extension.
        tag: 'train' for a labelled split (text, label, extra column);
            any other value treats the file as unlabelled.

    Returns:
        For tag == 'train': a DataFrame with columns "text" and "label"
        (the file's third column, index 2, is dropped). Otherwise column 0
        is renamed to "text" and any further columns keep their integer names.
    """
    # Use a context manager so the handle is closed as soon as pandas has
    # consumed it; previously the file object was never closed.
    with open(file) as tsv_file:
        df = pd.read_csv(tsv_file, delimiter="\t", header=None)
    if tag == 'train':
        df = df.drop([2], axis=1)
        df = df.rename(columns={0: "text", 1: "label"})
    else:
        df = df.rename(columns={0: "text"})
    return df

def preprocessTweet(row):
    """Return a normalised copy of ``row['text']`` for tokenization.

    Steps, in order:
    1. ``p.clean`` (the module is configured at the top of the file with
       ``p.OPT.URL`` only).
    2. Replace emoji with their textual descriptions (no separator).
    3. Lowercase.
    4. Remove every '#' and '@' character.
    5. Expand contractions via ``decontracted``.
    """
    described = demoji.replace_with_desc(p.clean(row['text']), sep="")
    lowered = described.lower()
    for marker in ("#", "@"):
        lowered = lowered.replace(marker, "")
    return decontracted(lowered)

# Input splits. Despite the .csv extension, getData parses them as TSV.
trainFilename = "english_hope_train.csv"
validFilename = "english_hope_dev.csv"
testFilename = "english_hope_test.csv"

# Labelled splits first (train, then dev), then the unlabelled test split.
trainDF, validDF = [getData(name) for name in (trainFilename, validFilename)]
testDF = getData(testFilename, tag='test')

# Add the cleaned text column to every split, in train/dev/test order.
for frame in (trainDF, validDF, testDF):
    frame['preprocessedText'] = frame.apply(preprocessTweet, axis=1)

# Report split sizes.
for caption, frame in (("Train Data Shape: ", trainDF),
                       ("Validation Data Shape: ", validDF),
                       ("Test Data Shape: ", testDF)):
    print(caption, frame.shape)

# Label distribution of the two labelled splits.
for frame in (trainDF, validDF):
    print(frame['label'].value_counts())
possible_labels = trainDF['label'].unique()

# Map label strings to integer ids. The class list is fixed explicitly so
# the id assignment does not depend on which labels happen to occur.
labelEncoder = preprocessing.LabelEncoder()
labelEncoder.fit(['Non_hope_speech', 'Hope_speech', 'not-English'])
for frame in (trainDF, validDF):
    frame['numericalLabels'] = labelEncoder.transform(frame['label'])

# NOTE(review): the final training set is train + dev, so any later
# evaluation on validDF measures data the model was trained on.
totalTrainData = pd.concat([trainDF, validDF])
print("Total Train Data Shape: ", totalTrainData.shape)

# NOTE(review): a *cased* checkpoint is combined with do_lower_case=True (and
# preprocessTweet already lowercases), so the cased vocabulary is never fully
# used. Confirm this is intended; if not, switch BOTH this tokenizer and the
# model checkpoint to 'bert-base-uncased' together.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased',
                                          do_lower_case=True)

MAX_LEN = 256


def _encode(texts):
    """Tokenize ``texts`` into MAX_LEN-long padded/truncated PyTorch tensors."""
    return tokenizer.batch_encode_plus(
        texts,
        add_special_tokens=True,
        return_attention_mask=True,
        # padding='max_length' is the supported spelling of the deprecated
        # pad_to_max_length=True; the resulting tensors are the same.
        padding='max_length',
        max_length=MAX_LEN,
        truncation=True,
        return_tensors='pt',
    )


encoded_data_train = _encode(totalTrainData.preprocessedText.values)
encoded_data_val = _encode(validDF.preprocessedText.values)
encoded_data_test = _encode(testDF.preprocessedText.values)

input_ids_train = encoded_data_train['input_ids']
attention_masks_train = encoded_data_train['attention_mask']
labels_train = torch.tensor(totalTrainData.numericalLabels.values)

input_ids_val = encoded_data_val['input_ids']
attention_masks_val = encoded_data_val['attention_mask']
labels_val = torch.tensor(validDF.numericalLabels.values)

input_ids_test = encoded_data_test['input_ids']
attention_masks_test = encoded_data_test['attention_mask']

# Train/val batches are (input_ids, attention_mask, labels);
# test batches have no labels.
dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)
dataset_test = TensorDataset(input_ids_test, attention_masks_test)

# A bare tuple expression only displays in a notebook; print it so the sizes
# are also visible when run as a script.
print(len(dataset_train), len(dataset_val), len(dataset_test))

# When True, only the classification head is trained.
freeze_bert = False
# One output per class known to the label encoder (3 for the fitted list).
model = BertForSequenceClassification.from_pretrained("bert-base-cased",
                                                      num_labels=len(labelEncoder.classes_))
# Assigning `model.bert.requires_grad` only created a plain attribute on the
# module and froze nothing. Set the flag on every encoder parameter instead.
for param in model.bert.parameters():
    param.requires_grad = not freeze_bert

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

batch_size = 16

dataloader_train = DataLoader(dataset_train,
                              sampler=RandomSampler(dataset_train),
                              batch_size=batch_size)

dataloader_validation = DataLoader(dataset_val,
                                   sampler=SequentialSampler(dataset_val),
                                   batch_size=batch_size)

dataloader_test = DataLoader(dataset_test,
                             sampler=SequentialSampler(dataset_test),
                             batch_size=batch_size)

from transformers import get_linear_schedule_with_warmup
import torch.optim as optim

# transformers.AdamW is deprecated in favour of torch.optim.AdamW. eps=1e-6
# and weight_decay=0.0 reproduce the transformers implementation's defaults
# (torch's own defaults are eps=1e-8, weight_decay=0.01).
optimizer = optim.AdamW(model.parameters(),
                        lr=2e-5,
                        eps=1e-6,
                        weight_decay=0.0)

epochs = 5

# Linear decay from lr to 0 over all training steps, with no warm-up.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=len(dataloader_train)*epochs)

from sklearn.metrics import f1_score

def f1_score_func(preds, labels):
    """Print a classification report and confusion matrix; return weighted F1.

    Args:
        preds: per-class scores; argmax over axis 1 gives the predicted class.
        labels: array of true class ids (flattened before use).

    Returns:
        sklearn's F1 score with ``average='weighted'``.
    """
    y_pred = np.argmax(preds, axis=1).flatten()
    y_true = labels.flatten()
    report = classification_report(y_true, y_pred, digits=4)
    print(report)
    matrix = confusion_matrix(y_true, y_pred)
    print(matrix)
    return f1_score(y_true, y_pred, average='weighted')

def accuracy_per_class(preds, labels, label_dict=None):
    """Print per-class accuracy and return the counts.

    Args:
        preds: per-class scores; argmax over axis 1 gives the predicted class id.
        labels: array of true class ids (flattened before use).
        label_dict: optional mapping of class name -> class id. When omitted,
            it is built from the module-level ``labelEncoder``, where id ``i``
            is ``labelEncoder.classes_[i]``. Without this default the function
            looked up a ``label_dict`` global that is never defined and raised
            NameError on every call.

    Returns:
        dict mapping each class name present in ``labels`` to a
        ``(correct, total)`` tuple.
    """
    if label_dict is None:
        label_dict = {name: idx for idx, name in enumerate(labelEncoder.classes_)}
    label_dict_inverse = {v: k for k, v in label_dict.items()}

    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()

    results = {}
    for label in np.unique(labels_flat):
        in_class = labels_flat == label
        correct = int(np.sum(preds_flat[in_class] == label))
        total = int(np.sum(in_class))
        name = label_dict_inverse[label]
        print(f'Class: {name}')
        print(f'Accuracy: {correct}/{total}\n')
        results[name] = (correct, total)
    return results

import random
import numpy as np

# Seed Python, NumPy and torch (CPU and all GPUs) with the same value.
# NOTE(review): the model is created earlier in the file than this, so the
# random initialisation of its classification head is not covered by the
# seed. Move the seeding above model creation if full reproducibility matters.
seed_val = 17
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed_val)

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

def evaluate(dataloader_val):
    """Score the global ``model`` on a labelled loader without tracking gradients.

    Each batch is indexed as (input_ids, attention_mask, labels), matching the
    labelled TensorDatasets built in this script.

    Returns:
        (mean of the per-batch losses, concatenated logits as a NumPy array,
        concatenated true labels as a NumPy array)
    """
    model.eval()

    loss_sum = 0
    all_logits, all_gold = [], []

    for raw_batch in dataloader_val:
        on_device = tuple(t.to(device) for t in raw_batch)
        gold = on_device[2]

        with torch.no_grad():
            result = model(input_ids=on_device[0],
                           attention_mask=on_device[1],
                           labels=gold)

        # With labels supplied, output 0 is the loss and output 1 the logits.
        loss_sum += result[0].item()
        all_logits.append(result[1].detach().cpu().numpy())
        all_gold.append(gold.cpu().numpy())

    mean_loss = loss_sum / len(dataloader_val)
    return mean_loss, np.concatenate(all_logits, axis=0), np.concatenate(all_gold, axis=0)

def test(dataloader_test):
    """Return the global ``model``'s logits for an unlabelled loader.

    Each batch is indexed as (input_ids, attention_mask). No labels are passed
    to the model, so its first output is the logits.

    Returns:
        NumPy array of logits for all batches, concatenated along axis 0.
    """
    model.eval()

    chunks = []
    for raw_batch in dataloader_test:
        on_device = tuple(t.to(device) for t in raw_batch)
        with torch.no_grad():
            result = model(input_ids=on_device[0],
                           attention_mask=on_device[1])
        chunks.append(result[0].detach().cpu().numpy())

    return np.concatenate(chunks, axis=0)

# Fine-tune for `epochs` passes. After every epoch, predict the test split
# and store the decoded labels in a column named epoch<N> on testDF.
for epoch in tqdm(range(1, epochs+1)):

    model.train()

    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for batch in progress_bar:

        model.zero_grad()

        batch = tuple(b.to(device) for b in batch)

        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                 }

        outputs = model(**inputs)

        # With labels supplied, the first model output is the loss.
        loss = outputs[0]
        batch_loss = loss.item()
        loss_train_total += batch_loss
        loss.backward()

        # Clip the global gradient norm to 1.0 before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        # Show the batch loss as returned by the model. It used to be divided
        # by len(batch), which is the number of tensors in the batch tuple (3),
        # not the number of examples. That showed a third of the real loss and
        # disagreed with the epoch average reported below.
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(batch_loss)})

    tqdm.write(f'\nEpoch {epoch}')

    loss_train_avg = loss_train_total/len(dataloader_train)
    tqdm.write(f'Training loss: {loss_train_avg}')

    # Validation is currently disabled. NOTE(review): totalTrainData (used for
    # training) already contains validDF, so these scores would be optimistic
    # if re-enabled.
    # val_loss, predictions, true_vals = evaluate(dataloader_validation)
    # val_f1 = f1_score_func(predictions, true_vals)
    # tqdm.write(f'Validation loss: {val_loss}')
    # tqdm.write(f'F1 Score (Weighted): {val_f1}')
    print("test epoch running")
    test_predictions = test(dataloader_test)
    labels = test_predictions.argmax(axis=1)
    labels = list(labelEncoder.inverse_transform(labels))
    header = "epoch"+str(epoch)
    testDF[header]=labels
    #torch.save(model.state_dict(), f'/content/drive/MyDrive/Colab Notebooks/EALC/finetuned_BERT_epoch_{epoch}.model')

# Write one TSV per epoch (output1.tsv ... output<epochs>.tsv). Each row is
# [sentence id, original tweet text, label predicted after that epoch].
#
# Fixes over the previous version:
#   * the sentence counter was reset inside the row loop, so every row got
#     the id "en_sent_1"; ids now run en_sent_1, en_sent_2, ...;
#   * the inner loop reused the outer loop variable `i`, so the progress
#     print showed tweet text instead of the epoch number;
#   * the file is opened with newline='' as the csv module requires.
for epoch_idx in range(1, epochs + 1):
    column = 'epoch' + str(epoch_idx)
    outputs = [
        ["en_sent_" + str(count), text, label]
        for count, (text, label) in enumerate(zip(testDF['text'], testDF[column]), start=1)
    ]

    out_path = 'output' + str(epoch_idx) + '.tsv'
    print(epoch_idx, out_path)
    with open(out_path, 'wt', newline='') as out_file:
        tsv_writer = csv.writer(out_file, delimiter='\t')
        tsv_writer.writerows(outputs)
