import numpy as np
import re
import pandas as pd
from tqdm import tqdm
import nltk
from unidecode import unidecode
import os

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import shuffle
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertModel, AlbertTokenizerFast, BertTokenizerFast, AutoModel, AdamW

# Load the Tamil training CSV and shuffle all rows once up front.
data = pd.DataFrame(pd.read_csv("./tamil_final_train.csv"))
data = data.sample(frac=1).reset_index(drop=True)

# Hold out the final 800 shuffled rows for validation; the rest is training data.
# NOTE(review): the training loop further down re-shuffles `data` and re-splits
# it through data_prep(), so these module-level splits are not what it trains on.
X_val, y_val = data['text'][-800:], data['tag'][-800:]
X_train, y_train = data['text'][:-800], data['tag'][:-800]

# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_tokenizer = AlbertTokenizerFast.from_pretrained('./indic-bert-v1/')

def data_prep(data, batch, val_size=800, max_length=512, tokenizer=None):
    """Split ``data`` into train/validation sets, tokenize both, and wrap them in DataLoaders.

    The last ``val_size`` rows (by position) become the validation split and all
    preceding rows the training split, so shuffle ``data`` beforehand if a random
    split is wanted.

    Args:
        data: DataFrame with a ``'text'`` column of raw strings and a ``'tag'``
            column of labels (presumably integer class ids, since they are fed
            to ``torch.tensor`` and later to ``NLLLoss`` — TODO confirm).
        batch: Batch size used for both DataLoaders.
        val_size: Number of trailing rows held out for validation (default 800).
        max_length: Token limit passed to the tokenizer; longer texts are truncated.
        tokenizer: Object exposing ``batch_encode_plus``; defaults to the
            module-level ``bert_tokenizer``.

    Returns:
        Tuple ``(y_train, train_data, train_sampler, train_dataloader,
        val_data, val_sampler, val_dataloader)``. ``y_train`` is the training
        label Series. Training batches are drawn in random order and validation
        batches sequentially.

    Raises:
        ValueError: if ``val_size`` does not leave at least one row in each split.
    """
    if tokenizer is None:
        tokenizer = bert_tokenizer

    n_rows = len(data)
    # Previously a dataset with <= 800 rows silently produced an empty training set.
    if not 0 < val_size < n_rows:
        raise ValueError(
            "val_size must leave at least one row in both splits: "
            f"got val_size={val_size} for {n_rows} rows"
        )
    n_train = n_rows - val_size

    # .iloc makes the split explicitly positional, independent of the index labels.
    X_train = data['text'].iloc[:n_train]
    y_train = data['tag'].iloc[:n_train]
    X_val = data['text'].iloc[n_train:]
    y_val = data['tag'].iloc[n_train:]

    def _to_dataset(texts, labels):
        # Tokenize one split into (input_ids, attention_mask, labels) tensors.
        # Sequences are padded to the longest one in this split, not to max_length.
        tokens = tokenizer.batch_encode_plus(
            texts.tolist(),
            max_length=max_length,
            padding="longest",
            truncation=True,
        )
        return TensorDataset(
            torch.tensor(tokens['input_ids']),
            torch.tensor(tokens['attention_mask']),
            torch.tensor(labels.tolist()),
        )

    train_data = _to_dataset(X_train, y_train)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch)

    val_data = _to_dataset(X_val, y_val)
    val_sampler = SequentialSampler(val_data)
    val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch)

    return (y_train, train_data, train_sampler, train_dataloader,
            val_data, val_sampler, val_dataloader)

# for param in bert.parameters():
#     param.requires_grad = False

class BERT_FT(nn.Module):
    """Sentence classifier: a pretrained encoder followed by a two-layer MLP head.

    The head maps the encoder's pooled sentence representation through
    Linear(hidden, 512) -> ReLU -> Dropout(0.1) -> Linear(512, num_labels) and
    returns log-probabilities (LogSoftmax over dim 1), so it is meant to be
    trained with ``nn.NLLLoss``.

    Args:
        bert: Encoder called as ``bert(sent_id, attention_mask=mask)`` whose
            output holds the pooled representation at index 1 (as Hugging Face
            ``BertModel``/``AlbertModel`` outputs do, whether returned as a
            tuple or as a ``ModelOutput``).
        num_labels: Number of output classes (default 2).
    """

    def __init__(self, bert, num_labels=2):
        super(BERT_FT, self).__init__()
        self.bert = bert
        # Size fc1 from the encoder config when available. The fallback of 768
        # matches the previously hard-coded width for encoders without a config.
        hidden_size = getattr(getattr(bert, "config", None), "hidden_size", 768)
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(hidden_size, 512)
        self.fc2 = nn.Linear(512, num_labels)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        """Return log-probabilities over ``num_labels`` classes (along dim 1)."""
        # Index rather than tuple-unpack: unpacking a transformers ModelOutput
        # iterates over its *keys*, which would hand fc1 the string
        # 'pooler_output' instead of a tensor. Indexing works for tuples too.
        cls_hs = self.bert(sent_id, attention_mask=mask)[1]
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)

def train(model, train_dataloader, optimizer, cross_entropy):
    """Run one training epoch of ``model`` over ``train_dataloader``.

    Each batch is expected to be an ``(input_ids, attention_mask, labels)``
    triple and is moved to the module-level ``device`` before the forward
    pass. Gradients are clipped to a global norm of 1.0 before every
    optimizer step.

    Args:
        model: Module called as ``model(sent_id, mask)``.
        train_dataloader: Iterable of batches with a known ``len``.
        optimizer: Optimizer over ``model``'s parameters.
        cross_entropy: Loss callable taking ``(preds, labels)``.

    Returns:
        ``(avg_loss, preds)``: the mean of the per-batch losses and a NumPy
        array of every batch's model outputs concatenated along axis 0.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    n_batches = len(train_dataloader)
    if n_batches == 0:
        raise ValueError("train_dataloader is empty; cannot compute an average loss")

    model.train()
    total_loss = 0.0
    total_preds = []
    for step, batch in enumerate(train_dataloader):
        if step % 50 == 0 and step != 0:
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, n_batches), flush=True)
        sent_id, mask, labels = (t.to(device) for t in batch)
        model.zero_grad()
        preds = model(sent_id, mask)
        loss = cross_entropy(preds, labels)
        total_loss += loss.item()
        loss.backward()
        # Cap the global gradient norm to keep updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_preds.append(preds.detach().cpu().numpy())

    return total_loss / n_batches, np.concatenate(total_preds, axis=0)

def evaluate(model, val_dataloader, criterion=None):
    """Compute the average loss and outputs of ``model`` on ``val_dataloader`` without gradients.

    The model is switched to eval mode. Each batch is expected to be an
    ``(input_ids, attention_mask, labels)`` triple and is moved to the
    module-level ``device``.

    Args:
        model: Module called as ``model(sent_id, mask)``.
        val_dataloader: Iterable of batches with a known ``len``.
        criterion: Loss callable taking ``(preds, labels)``. When omitted, the
            module-level ``cross_entropy`` is used; the training script assigns
            it before each evaluation. This fallback keeps existing
            two-argument calls working.

    Returns:
        ``(avg_loss, preds)``: the mean of the per-batch losses and a NumPy
        array of every batch's model outputs concatenated along axis 0.

    Raises:
        ValueError: if ``val_dataloader`` yields no batches.
    """
    if criterion is None:
        criterion = cross_entropy

    n_batches = len(val_dataloader)
    if n_batches == 0:
        raise ValueError("val_dataloader is empty; cannot compute an average loss")

    print("\nEvaluating...", flush=True)
    model.eval()
    total_loss = 0.0
    total_preds = []
    with torch.no_grad():
        for step, batch in enumerate(val_dataloader):
            if step % 50 == 0 and step != 0:
                print('  Batch {:>5,}  of  {:>5,}.'.format(step, n_batches), flush=True)
            sent_id, mask, labels = (t.to(device) for t in batch)
            preds = model(sent_id, mask)
            total_loss += criterion(preds, labels).item()
            total_preds.append(preds.cpu().numpy())

    return total_loss / n_batches, np.concatenate(total_preds, axis=0)

# Run mode: 1 trains the model ensemble below. The interactive prompt
# (input("train(1) or test(2)")) is disabled, and no test branch exists here.
do = 1

if int(do) == 1:
    epochs = 5
    save_dir = './models/tam/'
    # torch.save does not create parent directories; create the checkpoint
    # folder up front so the first save cannot fail.
    os.makedirs(save_dir, exist_ok=True)

    # Train 11 independently initialised models, each on a freshly shuffled split.
    for m in range(11):
        bert = AutoModel.from_pretrained('./indic-bert-v1/')
        model = BERT_FT(bert).to(device)
        optimizer = AdamW(model.parameters(), lr=1e-5)
        print("model " + str(m) + " training", flush=True)
        best_valid_loss = float('inf')
        train_losses = []
        valid_losses = []

        data = data.sample(frac=1).reset_index(drop=True)
        y_train, train_data, train_sampler, train_dataloader, val_data, val_sampler, val_dataloader = data_prep(data, 20)

        # 'balanced' class weights counter label imbalance in the loss. scikit-learn
        # >= 1.0 makes `classes` and `y` keyword-only, so pass them by name.
        class_weights = compute_class_weight(
            class_weight='balanced', classes=np.unique(y_train), y=y_train
        )
        weights = torch.tensor(class_weights, dtype=torch.float).to(device)
        # NLLLoss pairs with BERT_FT's LogSoftmax output.
        cross_entropy = nn.NLLLoss(weight=weights)

        for epoch in range(epochs):
            print('\n Epoch {:} / {:}'.format(epoch + 1, epochs), flush=True)
            train_loss, _ = train(model, train_dataloader, optimizer, cross_entropy)
            valid_loss, _ = evaluate(model, val_dataloader)
            # Keep only the checkpoint with the lowest validation loss so far.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                # NOTE(review): "tanil" looks like a typo for "tamil"; left unchanged
                # in case other scripts load checkpoints by this file name.
                torch.save(model.state_dict(), save_dir + str(m) + 'tanil_saved_weights.pt')
            train_losses.append(train_loss)
            valid_losses.append(valid_loss)
            print(f'\nTraining Loss: {train_loss:.3f}', flush=True)
            print(f'Validation Loss: {valid_loss:.3f}', flush=True)
        print("model " + str(m) + " trained", flush=True)

        # Move the model off the accelerator and drop references before the next member.
        model = model.cpu()
        del bert
        del model
        del optimizer
        del class_weights
        del weights
        del cross_entropy