import numpy as np
import re
import pandas as pd
from tqdm import tqdm
import nltk
from unidecode import unidecode
from langdetect import detect
from alphabet_detector import AlphabetDetector
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from sklearn.utils import shuffle

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AlbertTokenizerFast, AlbertModel, BertTokenizer, RobertaTokenizer, BertModel, RobertaTokenizerFast, BertTokenizerFast, AutoModel, AdamW, BertForSequenceClassification

import http.client
import json

# URL-like spans: an optional http/https/ftp scheme, a host part containing a
# dot followed by a 2-6 character lower-case suffix, then any trailing
# path/query characters. Compiled once instead of on every call.
_LINK_PATTERN = re.compile(
    r"((http|ftp|https):\/\/)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)"
)


def remove_links(x):
    """Return *x* with every URL-like substring deleted (replaced by nothing)."""
    cleaned = _LINK_PATTERN.sub("", x)
    return cleaned

# Emoji / pictograph ranges stripped from every token by preproc(). Compiled
# once at import time instead of once per token.
# NOTE(review): the U+24C2-U+1F251 range also spans many non-emoji blocks
# (e.g. CJK); Latin and Malayalam (U+0D00-U+0D7F) lie below it and are safe.
_EMOJI_PATTERN = re.compile("["
    u"\U0001F600-\U0001F64F"  # emoticons
    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
    u"\U0001F680-\U0001F6FF"  # transport & map symbols
    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
    u"\U0001F1F2-\U0001F1F4"  # Macau flag
    u"\U0001F1E6-\U0001F1FF"  # flags
    u"\U0001F600-\U0001F64F"
    u"\U00002702-\U000027B0"
    u"\U000024C2-\U0001F251"
    u"\U0001f900-\U0001f999"
    u"\U0001F1F2"
    u"\U0001F1F4"
    u"\U0001F620"
    u"\u200d"
    u"\u2640-\u2642"
    "]+", flags=re.UNICODE)

# Sentence-ending punctuation, normalised to a standalone " . " token.
_SENTENCE_END_PATTERN = re.compile(r'[?!.]')


def preproc(inp):
    """Normalize one raw text string for the classifier.

    Steps, in order:
      1. lower-case;
      2. turn runs of tabs into a single space;
      3. strip enumeration markers such as ``"1. "``;
      4. replace ``?``, ``!`` and ``.`` with a standalone ``" . "`` token;
      5. remove emoji from every whitespace-separated token;
      6. drop tokens that are empty afterwards or start with ``@`` (mentions);
      7. re-join the surviving tokens with single spaces.

    Args:
        inp: The raw text.

    Returns:
        The cleaned text; ``""`` when nothing survives.
    """
    inp = inp.lower()
    # Tabs separate words: deleting them outright would fuse the neighbours
    # ("good\tmorning" -> "goodmorning"), so treat them as spaces.
    inp = re.sub(r'\t+', ' ', inp)
    # A single pass removes every "<digits>. " marker; repeating the same
    # substitution cannot create new matches.
    inp = re.sub(r'\d+\. ', '', inp)
    inp = _SENTENCE_END_PATTERN.sub(" . ", inp)
    final = []
    for token in inp.split():
        token = _EMOJI_PATTERN.sub('', token)
        if token and token[0] != '@':
            final.append(token)
    text = ' '.join(final)
    text = re.sub(r"\s+", " ", text)
    return text

def tokenize_txt(txt):
    """Split *txt* into lower-cased alphabetic tokens, minus English stop words.

    Only runs of ASCII letters become tokens; digits and punctuation are
    dropped. Needs the NLTK ``stopwords`` corpus to be installed.

    Args:
        txt: Input text.

    Returns:
        List of lower-cased tokens not in NLTK's English stop-word list.
    """
    # These names were previously used without ever being imported (NameError).
    from nltk.corpus import stopwords
    from nltk.tokenize import RegexpTokenizer

    stop_words = set(stopwords.words('english'))
    # Letters only -- numbers are NOT kept by this pattern.
    tokenizer = RegexpTokenizer(r'[a-zA-Z]+', gaps=False)
    lowered = (token.lower() for token in tokenizer.tokenize(txt))
    return [word for word in lowered if word not in stop_words]

def lemmatization(data):
    """Lemmatize each token in *data* with NLTK's WordNet lemmatizer.

    Args:
        data: Iterable of token strings (e.g. the output of ``tokenize_txt``).

    Returns:
        The lemmatized tokens joined by single spaces.
    """
    # Previously used without being imported (NameError).
    from nltk.stem import WordNetLemmatizer

    lemmatizer = WordNetLemmatizer()
    return ' '.join(lemmatizer.lemmatize(word) for word in data)

# Google Input Tools transliteration codes (the `itc` query parameter sent by
# request()), one per supported target language.
malayalam, tamil, telugu = 'ml-t-i0-und', 'ta-t-i0-und', 'te-t-i0-und'
# Shared script detector; driver() uses it to find all-Latin tokens.
ad = AlphabetDetector()

def request(input, itc):
    """Send a transliteration query for *input* to Google Input Tools.

    Args:
        input: The text to transliterate. The parameter name is kept for
            keyword callers, even though it shadows the builtin.
        itc: Input-tools language code, e.g. the module-level ``malayalam``.

    Returns:
        The ``http.client.HTTPResponse``; the caller must read it. The
        connection is left open because the response is read through it.
    """
    from urllib.parse import quote

    conn = http.client.HTTPSConnection('inputtools.google.com')
    # Percent-encode user text: a raw '&', '#', space or non-ASCII character
    # would otherwise corrupt the query string or make http.client reject
    # the URL.
    path = ('/request?text=' + quote(input, safe='')
            + '&itc=' + quote(itc, safe='')
            + '&num=1&cp=0&cs=1&ie=utf-8&oe=utf-8&app=test')
    conn.request('GET', path)
    return conn.getresponse()

# Input-tools language code -> indic_transliteration output script.
_ITC_TO_SCRIPT = {
    'ml-t-i0-und': sanscript.MALAYALAM,
    'ta-t-i0-und': sanscript.TAMIL,
    'te-t-i0-und': sanscript.TELUGU,
}


def driver(text, itc):
    """Transliterate the romanized (all-Latin) tokens of *text* into an Indic script.

    *text* is split on whitespace. A token made only of Latin characters is
    treated as Harvard-Kyoto romanization and converted to the script selected
    by *itc*; every other token is kept unchanged.

    Args:
        text: Input string.
        itc: Google Input Tools language code (the module-level ``malayalam``,
            ``tamil`` or ``telugu``). Unknown codes fall back to Malayalam,
            which was previously used for every code.

    Returns:
        The tokens re-joined with each one preceded by a single space (so
        non-empty output starts with a space, as before); ``""`` for empty
        input.
    """
    # Previously `itc` was ignored and the output was always Malayalam.
    target = _ITC_TO_SCRIPT.get(itc, sanscript.MALAYALAM)
    pieces = []
    for token in text.split():
        if ad.only_alphabet_chars(token, "LATIN"):
            token = transliterate(token, sanscript.HK, target)
        pieces.append(token)
    # Keep the historical leading-space format; join avoids quadratic concat.
    return ''.join(' ' + piece for piece in pieces)

def _load_split(path):
    """Read one two-column CSV split and apply the text and label preprocessing.

    Args:
        path: CSV file with no header row; column 0 is text, column 1 the tag.

    Returns:
        DataFrame with ``text`` (links removed, ``preproc``-cleaned,
        transliterated via ``driver``) and ``tag`` (1 for 'Hope_speech',
        otherwise 0).
    """
    df = pd.read_csv(path, header=None)
    df.columns = ['text', 'tag']
    # NOTE(review): the filter value is 'not-English' even though these are
    # the Malayalam files; if off-language rows carry a different tag here,
    # nothing is dropped -- confirm against the data.
    df = df[df['tag'] != 'not-English']
    # Empty text fields are read as NaN (a float), which the regex cleaners
    # cannot process.
    df = df.dropna(subset=['text']).copy()
    df['text'] = (df['text']
                  .apply(remove_links)
                  .apply(preproc)
                  .apply(lambda x: driver(x, malayalam)))
    df['tag'] = (df['tag'] == 'Hope_speech').astype('int64')
    return df


data_train = _load_split("./malayalam_train.csv")
print("Train preprocessed", flush=True)

data_dev = _load_split("./malayalam_dev.csv")
print("Val preprocessed", flush=True)

# DataFrame.append was removed in pandas 2.0; concat is the supported form.
# ignore_index=True gives the merged frame a fresh 0..N-1 row index.
data = pd.concat([data_train, data_dev], ignore_index=True, sort=False)
data = shuffle(data)

# Hold out the final 1000 rows of the shuffled frame as a validation slice and
# split the remainder 80/20 (stratified on the label) into train/test. Only
# y_train is used further down (class weights).
X_val, y_val = data['text'].iloc[-1000:], data['tag'].iloc[-1000:]
X, y = data['text'].iloc[:-1000], data['tag'].iloc[:-1000]
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    train_size=0.8,
    stratify=y,
    random_state=42,
)

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Encoder and its matching fast tokenizer, both read from the local
# IndicBERT checkpoint directory.
bert = AutoModel.from_pretrained('./indic-bert-v1')
bert_tokenizer = AlbertTokenizerFast.from_pretrained('./indic-bert-v1')

def _encode_split(texts, labels):
    """Tokenize one split; return (input_ids, attention_mask, labels) tensors.

    Uses the module-level ``bert_tokenizer``. Sequences are truncated to 512
    tokens and padded to the longest sequence within this split.
    """
    tokens = bert_tokenizer.batch_encode_plus(
        texts.tolist(),
        max_length=512,
        padding="longest",
        truncation=True
    )
    return (
        torch.tensor(tokens['input_ids']),
        torch.tensor(tokens['attention_mask']),
        torch.tensor(labels.tolist()),
    )


def _make_loader(tensors, batch_size, sampler_cls):
    """Wrap *tensors* in a TensorDataset; return (dataset, sampler, loader)."""
    dataset = TensorDataset(*tensors)
    sampler = sampler_cls(dataset)
    loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    return dataset, sampler, loader


def data_prep(data, batch, test=0, val_size=1000):
    """Split *data* and build train/test/validation DataLoaders.

    The last *val_size* rows of *data* form the validation split. The
    remaining rows are divided 80/20 into train/test, stratified on ``tag``
    with ``random_state=42``.

    Args:
        data: DataFrame with ``text`` and ``tag`` columns.
        batch: Batch size used for every DataLoader.
        test: 0 builds all three splits. Any other value returns only the test
            labels and test DataLoader; train/val are then not tokenized.
        val_size: Number of trailing rows held out for validation.

    Returns:
        If ``test == 0``: ``(train_data, train_sampler, train_dataloader,
        test_data, test_sampler, test_dataloader, val_data, val_sampler,
        val_dataloader)``. Otherwise: ``(test_y, test_dataloader)``.

    Raises:
        ValueError: if no rows remain after holding out *val_size* rows.
    """
    cut = len(data) - val_size
    if cut <= 0:
        raise ValueError(
            f"data_prep needs more than {val_size} rows, got {len(data)}"
        )
    held_out = data.iloc[cut:]
    rest = data.iloc[:cut]
    X_train, X_test, y_train, y_test = train_test_split(
        rest['text'], rest['tag'], train_size=0.8, stratify=rest['tag'], random_state=42
    )

    test_tensors = _encode_split(X_test, y_test)
    test_data, test_sampler, test_dataloader = _make_loader(
        test_tensors, batch, SequentialSampler)
    if test != 0:
        # Test-only mode: skip tokenizing the train/val splits entirely.
        return test_tensors[2], test_dataloader

    train_data, train_sampler, train_dataloader = _make_loader(
        _encode_split(X_train, y_train), batch, RandomSampler)
    val_data, val_sampler, val_dataloader = _make_loader(
        _encode_split(held_out['text'], held_out['tag']), batch, SequentialSampler)

    return (train_data, train_sampler, train_dataloader,
            test_data, test_sampler, test_dataloader,
            val_data, val_sampler, val_dataloader)

class BERT_FT(nn.Module):
    """Two-class classification head on top of a pretrained encoder.

    The encoder's pooled output (element ``[1]`` of its return value) goes
    through Linear(hidden, 512) -> ReLU -> Dropout(0.1) -> Linear(512, 2) ->
    LogSoftmax, so ``forward`` returns per-class log-probabilities for
    ``nn.NLLLoss``.

    Args:
        bert: The encoder module. It is stored by reference, not copied.
    """

    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        # Width of the pooled output; read from the encoder config when it
        # has one, otherwise the 768 this head always assumed.
        hidden_size = getattr(getattr(bert, 'config', None), 'hidden_size', 768)
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(hidden_size, 512)
        self.fc2 = nn.Linear(512, 2)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        """Return (batch, 2) log-probabilities for token ids *sent_id* and *mask*."""
        # Index instead of tuple-unpacking: transformers >= 4 returns a
        # ModelOutput (an OrderedDict), and unpacking it yields its *keys*
        # (strings), not tensors. Integer indexing works for both tuple and
        # ModelOutput returns.
        cls_hs = self.bert(sent_id, attention_mask=mask)[1]
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)

model = BERT_FT(bert)
model = model.to(device)

# NOTE(review): recent transformers releases deprecate their AdamW in favour
# of torch.optim.AdamW (whose weight_decay default differs) -- confirm before
# upgrading.
optimizer = AdamW(model.parameters(), lr=1e-5)

# 'balanced' weights are inversely proportional to class frequency in y_train.
# The arguments are keyword-only since scikit-learn 1.0; the old positional
# call raises TypeError there.
class_weights = compute_class_weight(
    class_weight='balanced', classes=np.unique(y_train), y=y_train
)
weights = torch.tensor(class_weights, dtype=torch.float)
weights = weights.to(device)

# The model already ends in LogSoftmax, so NLLLoss gives weighted cross-entropy.
cross_entropy = nn.NLLLoss(weight=weights)

epochs = 10

def train(train_dataloader):
    """Run one training epoch over *train_dataloader*.

    Uses the module-level ``model``, ``optimizer``, ``cross_entropy`` and
    ``device``. Gradients are clipped to a global norm of 1.0 before each
    optimizer step.

    Args:
        train_dataloader: Yields ``(input_ids, attention_mask, labels)`` batches.

    Returns:
        ``(avg_loss, total_preds)``: the mean of the per-batch losses, and the
        model outputs for every example concatenated in iteration order as a
        numpy array (log-probabilities for ``BERT_FT``).
    """
    model.train()
    total_loss = 0
    total_preds = []
    n_batches = len(train_dataloader)
    for step, batch in enumerate(train_dataloader):
        if step % 50 == 0 and not step == 0:
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, n_batches), flush=True)
        sent_id, mask, labels = [r.to(device) for r in batch]
        model.zero_grad()
        preds = model(sent_id, mask)
        loss = cross_entropy(preds, labels)
        total_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_preds.append(preds.detach().cpu().numpy())
    avg_loss = total_loss / n_batches
    total_preds = np.concatenate(total_preds, axis=0)
    return avg_loss, total_preds

def evaluate(val_dataloader):
    """Compute the mean loss and the raw model outputs over *val_dataloader*.

    Puts the module-level ``model`` in eval mode and runs without gradient
    tracking, scoring with ``cross_entropy`` on ``device``.

    Args:
        val_dataloader: Yields ``(input_ids, attention_mask, labels)`` batches.

    Returns:
        ``(avg_loss, total_preds)``: the mean of the per-batch losses, and the
        model outputs for every example concatenated in iteration order as a
        numpy array.
    """
    print("\nEvaluating...", flush=True)
    model.eval()
    total_loss = 0
    total_preds = []
    n_batches = len(val_dataloader)
    with torch.no_grad():
        for step, batch in enumerate(val_dataloader):
            if step % 50 == 0 and not step == 0:
                # (A call to the undefined format_time/time/t0 used to sit
                # here and raised NameError on any loader with >50 batches.)
                print('  Batch {:>5,}  of  {:>5,}.'.format(step, n_batches), flush=True)
            sent_id, mask, labels = [t.to(device) for t in batch]
            preds = model(sent_id, mask)
            total_loss += cross_entropy(preds, labels).item()
            total_preds.append(preds.cpu().numpy())
    avg_loss = total_loss / n_batches
    total_preds = np.concatenate(total_preds, axis=0)
    return avg_loss, total_preds

import os

# Execution mode: 1 fine-tunes the ensemble and saves one checkpoint per
# member; any other value evaluates the saved checkpoints. (This used to be
# read interactively via input("train(1) or test(2)").)
do = 1

# Number of independently fine-tuned ensemble members.
n_models = 11
# Directory holding one checkpoint per ensemble member.
ckpt_dir = './models/mlm/'


def _member_data(df, seed):
    """Return *df* shuffled reproducibly for ensemble member *seed*.

    The merged frame was built with ``ignore_index=True``, so its index is a
    stable row id. Sorting by it first undoes any earlier unseeded shuffle, so
    member *seed* gets the same train/val/test split in the training and the
    evaluation branch. (Checkpoints from runs before this change were trained
    on unrecorded random splits and cannot be evaluated cleanly.)
    """
    return shuffle(df.sort_index(), random_state=seed)


def _ckpt_path(m):
    """Return the checkpoint path for ensemble member *m*."""
    return ckpt_dir + str(m) + 'mlm_saved_weights.pt'


if int(do) == 1:
    # BERT_FT stores the shared `bert` module by reference. Without a reset,
    # every member after the first would start from the previous member's
    # fine-tuned encoder. Keep a CPU copy of the pretrained weights instead.
    pretrained_state = {k: v.detach().cpu().clone()
                        for k, v in bert.state_dict().items()}
    os.makedirs(ckpt_dir, exist_ok=True)
    for m in range(n_models):
        print("model " + str(m) + " training", flush=True)
        data = _member_data(data, m)
        train_data, train_sampler, train_dataloader, test_data, test_sampler, test_dataloader, val_data, val_sampler, val_dataloader = data_prep(data, 20)

        bert.load_state_dict(pretrained_state)
        model = BERT_FT(bert).to(device)
        optimizer = AdamW(model.parameters(), lr=1e-5)

        # Balance the loss on this member's own training labels.
        # The arguments are keyword-only since scikit-learn 1.0.
        train_labels = train_data.tensors[2].numpy()
        class_weights = compute_class_weight(
            class_weight='balanced', classes=np.unique(train_labels), y=train_labels
        )
        weights = torch.tensor(class_weights, dtype=torch.float).to(device)
        cross_entropy = nn.NLLLoss(weight=weights)

        best_valid_loss = float('inf')
        train_losses = []
        valid_losses = []
        for epoch in range(epochs):
            print('\n Epoch {:} / {:}'.format(epoch + 1, epochs), flush=True)
            train_loss, _ = train(train_dataloader)
            valid_loss, _ = evaluate(val_dataloader)
            # Keep only the weights with the lowest validation loss so far.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), _ckpt_path(m))
            train_losses.append(train_loss)
            valid_losses.append(valid_loss)
            print(f'\nTraining Loss: {train_loss:.3f}', flush=True)
            print(f'Validation Loss: {valid_loss:.3f}', flush=True)
        print("model " + str(m) + " trained", flush=True)

else:
    # One (labels, predictions) pair per member: every member has its own
    # test split, so each must be scored against its own labels.
    results = []
    for m in range(n_models):
        data = _member_data(data, m)
        test_y, test_dataloader = data_prep(data, 20, 1)
        # map_location lets GPU-trained checkpoints load on a CPU-only host.
        model.load_state_dict(torch.load(_ckpt_path(m), map_location=device))
        # eval() disables dropout; without it predictions are stochastic.
        model.eval()
        total_preds = []
        with torch.no_grad():
            for step, batch in enumerate(test_dataloader):
                if step % 50 == 0 and not step == 0:
                    print(' Batch {:>5,} of {:>5,}.'.format(step, len(test_dataloader)), flush=True)
                sent_id, mask, _labels = [r.to(device) for r in batch]
                total_preds.append(model(sent_id, mask).cpu().numpy())
        preds = np.argmax(np.concatenate(total_preds, axis=0), axis=1)
        results.append((test_y, preds))
    for test_y, preds in results:
        print(classification_report(test_y, preds), flush=True)