import transformers
from transformers import BertModel, BertTokenizer, AlbertTokenizer, AlbertModel, AdamW, get_linear_schedule_with_warmup
import torch

import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap

from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
# Tensors and the model are moved to this device further down in the script.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import logging
# Configure the root logger so that only ERROR-level (and above) records are emitted.
logging.basicConfig(level=logging.ERROR)
from langdetect import detect

# Test split; create_data_loader reads its `text` and `tag` columns.
df_test = pd.read_csv("mal_final_test.csv")

# Number of test rows. len() counts every row, whereas DataFrame.count()
# only counts non-null cells (of the first column, in the previous code) and
# could therefore disagree with the length of the dataset built from df_test.
data_count = len(df_test)

# Local directory holding the pretrained tokenizer and encoder weights.
PRE_TRAINED_MODEL_NAME = './indic-bert-v1'
tokenizer = AlbertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

# Encode one sample sentence to sanity-check the tokenizer configuration.
sample_txt = "ച്രfത്സ് ഥനിക് ഓര് ഗയ് അയ കുത്തി ഓന്ദവഥിരികതേ"
encoding = tokenizer.encode_plus(
  sample_txt,
  truncation=True,
  max_length=32,
  add_special_tokens=True, # Add '[CLS]' and '[SEP]'
  return_token_type_ids=False,
  padding='max_length',  # replaces the deprecated pad_to_max_length=True
  return_attention_mask=True,
  return_tensors='pt',  # Return PyTorch tensors
)

# Notebook leftover: the result is discarded when this file runs as a script.
encoding.keys()

# Sequence length every test example is truncated/padded to.
MAX_LEN = 150

class GPReviewDataset(Dataset):
  """Map-style dataset that tokenizes one text per item.

  Args:
    reviews: indexable sequence of texts; each item is converted with str().
    targets: indexable sequence of integer class labels, aligned with reviews.
    tokenizer: object exposing ``encode_plus`` (an AlbertTokenizer in this
      script).
    max_len: length every encoded sequence is truncated/padded to.
  """

  def __init__(self, reviews, targets, tokenizer, max_len):
    self.reviews = reviews
    self.targets = targets
    self.tokenizer = tokenizer
    self.max_len = max_len

  def __len__(self):
    """Return the number of texts in the dataset."""
    return len(self.reviews)

  def __getitem__(self, item):
    """Encode the text at position ``item``.

    Returns:
      dict with keys 'review_text' (str), 'input_ids' and 'attention_mask'
      (flattened tensors from the tokenizer) and 'targets' (long tensor
      holding the label).
    """
    review = str(self.reviews[item])
    target = self.targets[item]

    encoding = self.tokenizer.encode_plus(
      review,
      truncation=True,
      add_special_tokens=True,
      max_length=self.max_len,
      return_token_type_ids=False,
      # `padding='max_length'` is the supported spelling of the deprecated
      # `pad_to_max_length=True` flag.
      padding='max_length',
      return_attention_mask=True,
      return_tensors='pt',
    )

    return {
      'review_text': review,
      # Flatten to 1-D so that a DataLoader can stack items into a batch.
      'input_ids': encoding['input_ids'].flatten(),
      'attention_mask': encoding['attention_mask'].flatten(),
      'targets': torch.tensor(target, dtype=torch.long)
    }

def create_data_loader(df, tokenizer, max_len, batch_size, num_workers=4):
  """Build a DataLoader over the ``text`` and ``tag`` columns of ``df``.

  Args:
    df: DataFrame with a ``text`` column (inputs) and a ``tag`` column
      (labels).
    tokenizer: tokenizer handed to GPReviewDataset.
    max_len: sequence length each example is truncated/padded to.
    batch_size: number of examples per batch.
    num_workers: number of loader worker processes; defaults to 4, the value
      that used to be hard-coded. Pass 0 to load in the main process.

  Returns:
    A DataLoader wrapping a GPReviewDataset built from ``df``.
  """
  ds = GPReviewDataset(
    reviews=df.text.to_numpy(),
    targets=df.tag.to_numpy(),
    tokenizer=tokenizer,
    max_len=max_len
  )

  return DataLoader(
    ds,
    batch_size=batch_size,
    num_workers=num_workers
  )

# Number of test examples per batch.
BATCH_SIZE = 16

# Loader over the test split, tokenized to MAX_LEN tokens per example.
test_data_loader = create_data_loader(
  df=df_test,
  tokenizer=tokenizer,
  max_len=MAX_LEN,
  batch_size=BATCH_SIZE,
)

# Grab the first batch; it is reused below for a smoke-test forward pass.
data = next(iter(test_data_loader))
# Notebook leftover: the result is discarded when this file runs as a script.
data.keys()

# Stand-alone encoder instance loaded from the same local directory.
# NOTE(review): nothing in the visible script reads `bert_model`
# (SentimentClassifier loads its own encoder) - confirm whether it is needed.
bert_model = AlbertModel.from_pretrained('./indic-bert-v1')

class SentimentClassifier(nn.Module):
  """ALBERT encoder followed by dropout and a linear classification head.

  Args:
    n_classes: number of output classes (size of the final linear layer).
  """

  def __init__(self, n_classes):
    super().__init__()
    self.bert = AlbertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
    self.drop = nn.Dropout(p=0.3)
    self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

  def forward(self, input_ids, attention_mask):
    """Return the unnormalized class scores for a batch.

    Args:
      input_ids: token id tensor passed to the encoder.
      attention_mask: attention mask tensor passed to the encoder.

    Returns:
      Output of the linear head applied to the (dropout-regularized) pooled
      encoder output.
    """
    encoder_outputs = self.bert(
      input_ids=input_ids,
      attention_mask=attention_mask
    )
    # Take the second element (the pooled output) by position. Integer
    # indexing works both for a plain tuple and for a dict-like ModelOutput,
    # whereas tuple-unpacking a ModelOutput would iterate over its keys and
    # bind strings instead of tensors.
    pooled_output = encoder_outputs[1]
    output = self.drop(pooled_output)
    return self.out(output)

# Classifier with a three-way output head, placed on the selected device.
model = SentimentClassifier(3).to(device)

# Move the sample batch fetched earlier onto the same device as the model.
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)

# Smoke test: one forward pass turned into per-class probabilities.
# The result is discarded when this file runs as a script.
sample_logits = model(input_ids, attention_mask)
F.softmax(sample_logits, dim=1)

def get_predictions(model, data_loader):
  """Run ``model`` over ``data_loader`` in eval mode and collect its outputs.

  Args:
    model: module called with ``input_ids`` and ``attention_mask`` keyword
      arguments; its output is reduced over dim 1, one score per class.
    data_loader: iterable of batches, each a mapping with the keys
      "review_text", "input_ids", "attention_mask" and "targets".

  Returns:
    Tuple ``(review_texts, predictions, prediction_probs, real_values)``:
    the list of input texts, the predicted class index per sample, the
    softmax probabilities per sample, and the target labels. The three
    tensors live on the CPU. For an empty ``data_loader`` the list and all
    three tensors are empty.
  """
  model = model.eval()

  review_texts = []
  pred_batches = []
  prob_batches = []
  target_batches = []

  with torch.no_grad():
    for d in data_loader:
      input_ids = d["input_ids"].to(device)
      attention_mask = d["attention_mask"].to(device)

      outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask
      )
      _, preds = torch.max(outputs, dim=1)
      probs = F.softmax(outputs, dim=1)

      review_texts.extend(d["review_text"])
      # Keep one tensor per batch (moved off the device right away) instead
      # of one 0-dim tensor per sample.
      pred_batches.append(preds.cpu())
      prob_batches.append(probs.cpu())
      # Targets are only collected, never fed to the model, so they do not
      # need a round trip through the device.
      target_batches.append(d["targets"].cpu())

  if not pred_batches:
    # Nothing to concatenate: return empty results instead of letting the
    # tensor concatenation fail on an empty list.
    return (
      review_texts,
      torch.empty(0, dtype=torch.long),
      torch.empty((0, 0)),
      torch.empty(0, dtype=torch.long),
    )

  predictions = torch.cat(pred_batches)
  prediction_probs = torch.cat(prob_batches)
  real_values = torch.cat(target_batches)
  return review_texts, predictions, prediction_probs, real_values

# --- Majority-vote ensemble over the saved checkpoints ---

# Checkpoints are named "<m>best_model_state.bin" for m in range(NUM_MODELS).
NUM_MODELS = 11
# The first NUM_SUBSET_MODELS checkpoints additionally form a smaller ensemble.
NUM_SUBSET_MODELS = 7
TARGET_NAMES = ["no-hope", "hope"]

# Width of the classifier head; predicted labels are indices in [0, n_classes).
n_classes = model.out.out_features

# Per-sample vote counts with one column per class. Counting votes per class
# (rather than summing the predicted label values and floor-dividing) stays
# correct when a model predicts a label other than 0 or 1; for purely binary
# predictions with an odd number of voters the result is the same majority.
all_votes = torch.zeros(len(df_test), n_classes, dtype=torch.long)
some_votes = torch.zeros(len(df_test), n_classes, dtype=torch.long)
real_preds = torch.tensor([])

print(data_count)

for m in range(NUM_MODELS):
  # map_location lets checkpoints that were saved from a GPU load on whatever
  # device this run uses (including CPU-only machines).
  # NOTE: torch.load unpickles the file - only load checkpoints you trust.
  state_dict = torch.load(
    f'./models/malayalam/{m}best_model_state.bin',
    map_location=device
  )
  model.load_state_dict(state_dict)

  y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
    model,
    test_data_loader
  )

  # One vote per sample for the class this checkpoint predicted.
  ballots = F.one_hot(y_pred, num_classes=n_classes)
  all_votes += ballots
  if m < NUM_SUBSET_MODELS:
    some_votes += ballots

  # The targets are taken from the first pass and reused for every report.
  if m == 0:
    real_preds = y_test

  # NOTE(review): TARGET_NAMES has two entries while the head has n_classes
  # outputs; classification_report raises if the labels it sees do not match
  # the two names - confirm the label set of the test data.
  print(classification_report(real_preds, y_pred, target_names=TARGET_NAMES), flush=True)

# Majority vote: the class with the most votes wins; on a tie argmax picks the
# lowest class index.
some_preds = some_votes.argmax(dim=1)
all_preds = all_votes.argmax(dim=1)
print(classification_report(real_preds, some_preds, target_names=TARGET_NAMES), flush=True)
print(classification_report(real_preds, all_preds, target_names=TARGET_NAMES), flush=True)