import os
import numpy as np
import tensorflow as tf
import nltk
from nltk.tokenize import sent_tokenize

def text_examples_to_tfdataset(examples, tokenizer, max_length=128):
    """Build a ``tf.data.Dataset`` of tokenized ``(features, label)`` pairs.

    Every example is a mapping with a ``"text"`` field (tokenized with
    ``tokenizer.encode_plus``) and a ``"label"`` field. When ``max_length`` is
    truthy the text is truncated and padded to exactly that length; otherwise
    it is truncated (as ``encode_plus`` decides) but not padded.

    Args:
      examples: iterable of example mappings; iterated each time the
        generator below runs.
      tokenizer: object whose ``encode_plus`` result exposes ``input_ids``,
        ``token_type_ids`` and ``attention_mask``.
      max_length: target sequence length (default 128).

    Returns:
      A dataset yielding ``({input_ids, token_type_ids, attention_mask}, label)``
      with 1-D int32 features and a scalar int32 label.
    """
    feature_names = ('input_ids', 'token_type_ids', 'attention_mask')
    padding_mode = "max_length" if max_length else "do_not_pad"

    def gen():
        for example in examples:
            encoded = tokenizer.encode_plus(example["text"],
                                            max_length=max_length,
                                            truncation=True,
                                            padding=padding_mode)
            label = example['label']
            features = {name: encoded[name] for name in feature_names}
            yield (features, label)

    output_types = ({name: tf.int32 for name in feature_names}, tf.int32)
    output_shapes = ({name: tf.TensorShape([None]) for name in feature_names},
                     tf.TensorShape([]))
    return tf.data.Dataset.from_generator(gen,
                                          output_types=output_types,
                                          output_shapes=output_shapes)

def content_examples_to_tfdataset(examples, tokenizer, max_length=128):
    """Build a ``tf.data.Dataset`` of tokenized ``(features, label)`` pairs.

    Every example is a mapping with a ``"content"`` field (tokenized with
    ``tokenizer.encode_plus``) and a ``"label"`` field. A truthy
    ``max_length`` pads/truncates to that length; a falsy one disables padding.

    Args:
      examples: iterable of example mappings; iterated each time the
        generator below runs.
      tokenizer: object whose ``encode_plus`` result exposes ``input_ids``,
        ``token_type_ids`` and ``attention_mask``.
      max_length: target sequence length (default 128).

    Returns:
      A dataset yielding ``({input_ids, token_type_ids, attention_mask}, label)``
      with 1-D int32 features and a scalar int32 label.
    """
    keys = ('input_ids', 'token_type_ids', 'attention_mask')
    padding = "do_not_pad" if not max_length else "max_length"

    def encode(example):
        # Tokenize one example and pair its model inputs with its label.
        encoded = tokenizer.encode_plus(example["content"], max_length=max_length,
                                        truncation=True, padding=padding)
        label = example['label']
        return dict((key, encoded[key]) for key in keys), label

    def gen():
        for example in examples:
            yield encode(example)

    types = (dict.fromkeys(keys, tf.int32), tf.int32)
    shapes = ({key: tf.TensorShape([None]) for key in keys}, tf.TensorShape([]))
    return tf.data.Dataset.from_generator(gen, output_types=types, output_shapes=shapes)


# Dataset name -> function turning its examples into a tf.data.Dataset.
# ag_news/imdb examples carry their text under "text"; dbpedia_14 under "content".
convertors = dict(
    ag_news=text_examples_to_tfdataset,
    imdb=text_examples_to_tfdataset,
    dbpedia_14=content_examples_to_tfdataset,
)


def calc_flops_per_length_bert(n, num_labels):
  """Estimate the FLOPs of a full-length BERT forward pass for one sequence.

  The estimate is ``a*n**2 + b*n + c + num_labels*classifier_flops`` with the
  hard-coded coefficients below (their derivation is not shown in this file).

  Args:
    n: sequence length in tokens.
    num_labels: number of output classes.

  Returns:
    The estimated FLOP count. NumPy integer inputs are converted to Python
    ints first, so the result is exact rather than wrapped.
  """
  # Callers pass entries of NumPy integer arrays (e.g. an int32 length buffer);
  # with fixed-width ints a*n**2 and b*n overflow for realistic lengths.
  if isinstance(n, np.integer):
    n = int(n)
  if isinstance(num_labels, np.integer):
    num_labels = int(num_labels)
  classifier_flops = 1537
  c = 1183490 - classifier_flops
  b = 170216500
  a = 37872
  return a*(n**2) + b*n + c + num_labels * classifier_flops


def calc_flops_per_length_bert_w_lr(layers_n, num_labels):
  """Estimate the FLOPs of a length-reduced BERT forward pass for one sequence.

  Args:
    layers_n: per-layer sequence lengths. ``layers_n[0]`` additionally scales
      a fixed ``6152 * n`` term, so it must be non-empty (IndexError otherwise).
    num_labels: number of output classes.

  Returns:
    ``1183537 + (num_labels - 1) * 1537 + 6152 * layers_n[0]`` plus
    ``3156 * n**2 + 14233554 * n`` summed over every ``n`` in ``layers_n``.
    NumPy integer inputs are converted to Python ints so the sum is exact.
  """
  def _as_python_int(value):
    # Rows of a NumPy int32 length array overflow in 14233554 * n for
    # realistic lengths; Python ints are unbounded.
    return int(value) if isinstance(value, np.integer) else value

  def layer_flop_calculator(n):
    # Per-layer cost as a quadratic in that layer's sequence length.
    return 3156 * n**2 + 14233554 * n

  layers_n = [_as_python_int(n) for n in layers_n]
  num_labels = _as_python_int(num_labels)

  classifier_flops = 1537
  flops = 1183537 + (num_labels - 1) * classifier_flops + 6152 * layers_n[0]
  for n in layers_n:
    flops += layer_flop_calculator(n)

  return flops


# Evaluation for our model
class ModelwlrEvaluator:
  """Evaluate a length-reduction (LR) model on a fixed dataset.

  Args:
    model: callable invoked as ``model(features, lr_mode=..., ...)``; its output
      must expose ``logits`` (and ``hidden_states`` for :meth:`evaluate`).
    eval_dataset: iterable of ``(features, labels)`` batches whose labels have a
      ``.numpy()`` method. It is iterated once here to cache the labels and again
      on every evaluation call, so it must be re-iterable.
    num_labels: number of classes, forwarded to the FLOPs estimators.
    metrics: dict mapping metric name -> ``metric_func(y_true, y_pred)``.
  """

  def __init__(self, model, eval_dataset, num_labels, metrics):
    self.model = model
    self.num_labels = num_labels
    self.metrics = metrics
    self.eval_dataset = eval_dataset
    self.eval_labels: np.ndarray = None

    # Gather labels batch by batch and join them once; repeated np.append
    # copies the whole accumulated array on every batch.
    label_batches = [batch[1].numpy() for batch in eval_dataset]
    if label_batches:
      self.eval_labels = np.concatenate(label_batches, axis=0)

  def evaluate(self, lr_mode):
    """Print every metric plus the mean per-batch FLOPs speedup for ``lr_mode``."""
    eval_preds = []
    speedups = []
    for example in self.eval_dataset:
      output = self.model(example[0], lr_mode=lr_mode, output_hidden_states=True)
      eval_preds.append(tf.math.argmax(output.logits, axis=1).numpy())

      # Calculate FLOPs from axis 1 of each hidden state (presumably the
      # sequence length that layer actually processed).
      num_tokens = [state.shape[1] for state in output.hidden_states]
      # NOTE(review): calc_flops_per_length_albert* are not defined in this file
      # (only the *_bert variants are); confirm they are provided elsewhere,
      # otherwise these calls raise NameError.
      flops_albert_wlr = calc_flops_per_length_albert_w_lr(num_tokens, self.num_labels)
      flops_albert = calc_flops_per_length_albert(num_tokens[0], self.num_labels)
      speedups.append(flops_albert / flops_albert_wlr)

    speedup = np.mean(speedups)
    # Flatten per-batch predictions into one vector aligned with self.eval_labels;
    # handing metrics the raw list of batch tensors gives them a ragged/2-D input.
    if eval_preds:
      all_preds = np.concatenate(eval_preds, axis=0)
    else:
      all_preds = np.empty((0,), dtype=np.int64)
    eval_scores = {}
    for metric_name, metric_func in self.metrics.items():
      eval_scores[metric_name] = metric_func(self.eval_labels, all_preds)

    # Print results
    str_result = ""
    for key, value in eval_scores.items():
      str_result = str_result + " - {0}: {1:.4f}".format(key, value)
    str_result = str_result + " - Speedup: {0:.4f}".format(speedup)
    print(str_result)

  def evaluate_all_layers(self, lr_mode):
    """Score each per-layer classifier and print one line per layer.

    Returns:
      ``(metric_name, per_layer_scores)`` for the first metric in ``self.metrics``.
    """
    eval_preds = []
    for example in self.eval_dataset:
      output = self.model(example[0], lr_mode=lr_mode, multi_classifier_mode=True)
      # Presumably one (batch, num_labels) logits tensor per layer, so the
      # stack is (layers, batch, num_labels) and argmax over axis 2 gives
      # (layers, batch).
      logits = tf.stack(output.logits)
      eval_preds.append(tf.math.argmax(logits, axis=2))

    # Join batches along the example axis -> (layers, num_examples). Stacking
    # instead requires equal batch sizes and yields (layers, batches, batch),
    # which no longer lines up with the flat label vector.
    eval_preds = tf.concat(eval_preds, axis=1).numpy()
    # NOTE(review): other callbacks here reach the encoder via `model.bert`;
    # confirm this model really exposes `.albert`.
    num_layers = self.model.albert.config.num_hidden_layers
    eval_scores = {}
    for metric_name, metric_func in self.metrics.items():
      eval_scores[metric_name] = [metric_func(self.eval_labels, eval_preds[layer]) for layer in range(num_layers)]

    # Print results
    for layer in range(num_layers):
      str_result = "Layer {}".format(layer+1)
      for key, value in eval_scores.items():
        str_result = str_result + " - {0}: {1:.4f}".format(key, value[layer])
      print(str_result)

    # (metric_name, performance array) of the first metric. The previous
    # len > 1 branch indexed dict_items directly and raised TypeError.
    return list(eval_scores.items())[0]


# Callback Full Checkpoint
class ModelCheckpoint(tf.keras.callbacks.Callback):
  """Keras callback that scores the model on several splits after every epoch.

  Args:
    datasets: dict mapping split name -> iterable of ``(features, labels)`` batches.
    metrics: dict mapping metric name -> ``metric_func(y_true, y_pred)``.
    save_model: when truthy, weights are saved after each epoch.
    saved_model_path: path prefix; weights go to ``<prefix>_<epoch>.h5`` with a
      1-based epoch number.
  """

  def __init__(self, datasets, metrics, save_model, saved_model_path):
    super().__init__()
    self.metrics = metrics
    self.save_model = save_model
    self.saved_model_path = saved_model_path
    self.datasets: dict = datasets
    # Not read anywhere in this class; kept so the attributes still exist.
    self.eval_labels: np.ndarray = None
    self.test_labels: np.ndarray = None
    # One evaluate() result appended per finished epoch.
    self.all_scores = []

  def evaluate(self):
    """Return ``{split_name: {metric_name: value}}`` for every dataset split."""
    results = {}
    for split, batches in self.datasets.items():
      predicted = []
      expected = []
      for batch in batches:
        logits = self.model(batch[0]).logits
        predicted.extend(tf.math.argmax(logits, axis=1).numpy())
        expected.extend(batch[1].numpy())
      results[split] = {name: fn(expected, predicted) for name, fn in self.metrics.items()}
    return results

  def on_epoch_end(self, epoch, logs=None):
    """Print each split's scores, optionally save weights, and record the scores."""
    scores = self.evaluate()
    for split_name, split_scores in scores.items():
      pieces = [split_name + ": "]
      pieces.extend(" - {0}: {1:.4f}".format(name, value) for name, value in split_scores.items())
      print("".join(pieces))

    if self.save_model:
      self.model.save_weights(self.saved_model_path + "_" + str(epoch + 1) + ".h5")

    self.all_scores.append(scores)


# Callback Full Checkpoint
class ModelCheckpoint_wlr(tf.keras.callbacks.Callback):
  """Keras callback that evaluates a length-reduction BERT model after each epoch.

  For every split it reports the configured metrics, the example count and an
  estimated FLOPs speedup of the LR model over full-length BERT. When ``PHI``
  is set it also reports a length loss and the combined metric
  ``(1 - first_metric) + PHI * length_loss``, and can checkpoint on improvement.

  Args:
    datasets: dict mapping split name -> iterable of ``(features, labels)``
      batches. The first split drives model selection when ``PHI`` is set.
    metrics: dict mapping metric name -> ``metric_func(y_true, y_pred)``; the
      first entry is the performance metric used in the combined metric.
    save_model: when truthy, weights are saved to ``<saved_model_path>_lr.h5``
      whenever the combined metric improves.
    saved_model_path: path prefix for saved weights.
    logger: optional object with an ``info`` method; results are mirrored to it.
    strategy: optional ``tf.distribute`` strategy; when given, batches are run
      through :meth:`distributed_eval`.
    PHI: optional weight of the length loss in the combined metric.
  """

  def __init__(self, datasets, metrics, save_model, saved_model_path, logger=None, strategy=None, PHI=None):
    super(ModelCheckpoint_wlr, self).__init__()
    self.metrics = metrics
    self.save_model = save_model
    self.saved_model_path = saved_model_path
    self.datasets: dict = datasets
    # Not read anywhere in this class; kept so the attributes still exist.
    self.eval_labels: np.ndarray = None
    self.test_labels: np.ndarray = None
    self.strategy = strategy
    self.logger = logger
    self.PHI = PHI
    if PHI:
      # Lowest combined metric seen so far (lower is better).
      self.best_combined_loss = 1e+10

  @tf.function
  def eval_wo_sals(self, inputs, labels):
    """Run the model with ``lr_mode=False`` and explanations enabled.

    Returns ``(predicted class ids, stacked explanations, logits)``. The
    explanations are stacked and transposed with perm ``[1, 0, 2]``;
    presumably the result is (batch, layers, seq) — confirm against the model.
    """
    outputs = self.model(inputs, lr_mode=False, output_explanations=True)
    y_preds = tf.argmax(outputs.logits, axis=-1, output_type=tf.int32)
    return (y_preds, tf.transpose(tf.stack(outputs.explanations), perm=[1, 0, 2]), outputs.logits)

  @tf.function
  def distributed_eval(self, inputs):
    """Run :meth:`eval_wo_sals` on every replica and gather results and labels."""
    # Pass features and labels as two explicit positional args;
    # `args=(inputs)` was not a tuple wrapper and only worked by unpacking.
    per_replica_outputs = self.strategy.run(self.eval_wo_sals, args=(inputs[0], inputs[1]))
    return (
      self.strategy.gather(per_replica_outputs[0], axis=0),
      self.strategy.gather(per_replica_outputs[1], axis=0),
      self.strategy.gather(per_replica_outputs[2], axis=0),
      self.strategy.gather(inputs[1], axis=0)
    )

  def _collect_split(self, dataset, collect_labels=True):
    """Run the model over ``dataset`` and gather predictions, labels and lengths.

    Returns:
      ``(y_preds, y_true, lengths)``. ``y_true`` stays empty when
      ``collect_labels`` is False. ``lengths`` has one row per example: the
      count of explanation scores above 1e-6 along axis 2 for each entry of
      axis 1, followed by the count of ``exps[:, -1]`` scores above ``ETA[-1]``.
    """
    y_preds = []
    y_true = []
    length_rows = []
    for batch in dataset:
      if self.strategy is not None:
        preds, exps, _, labels = self.distributed_eval(batch)
      else:
        preds, exps, _ = self.eval_wo_sals(batch[0], batch[1])
        labels = batch[1]
      y_preds.extend(preds.numpy())
      if collect_labels:
        y_true.extend(labels.numpy())
      kept_per_layer = np.sum(exps > 1e-6, axis=2)
      last_layer_length = np.sum(exps[:, -1] > self.model.bert.encoder.ETA[-1], axis=-1)
      length_rows.append(np.concatenate([kept_per_layer, np.expand_dims(last_layer_length, axis=-1)], axis=1))

    # Concatenate once at the end (growing the array per batch is quadratic);
    # the column count follows the model instead of a hard-coded width.
    if length_rows:
      lengths = np.concatenate(length_rows, axis=0)
    else:
      lengths = np.empty((0, 13), dtype=np.int32)
    return y_preds, y_true, lengths

  @staticmethod
  def _flops_speedup(lengths, class_ids):
    """Return total full-BERT FLOPs divided by total LR FLOPs over ``lengths``.

    ``row[0]`` is used as the full-BERT length and ``row[1:]`` as the LR
    per-layer lengths. The label count is inferred as ``max(class_ids) + 1``.
    Returns NaN for an empty split instead of dividing zero by zero.
    """
    if len(lengths) == 0:
      return float("nan")
    num_labels = np.max(class_ids) + 1
    flops_bert = 0
    flops_lr = 0
    for row in lengths:
      flops_bert += calc_flops_per_length_bert(row[0], num_labels)
      flops_lr += calc_flops_per_length_bert_w_lr(row[1:], num_labels)
    return flops_bert / flops_lr

  def evaluate(self, inf_lambda=True):
    """Score every split in ``self.datasets``.

    Args:
      inf_lambda: when True, ``model.bert.encoder._lambda`` is set to 1e+20 for
        the duration of the evaluation and restored afterwards, even if the
        evaluation raises.

    Returns:
      ``{split_name: {score_name: value}}`` containing the metrics, the PHI
      entries ("LEN. LOSS", "1-Perf", "COMBINED Metric") when PHI is set,
      "COUNT" and "SPEEDUP (TOTAL)".
    """
    if inf_lambda:
      _lambda = self.model.bert.encoder._lambda.numpy()
      self.model.bert.encoder._lambda.assign(1e+20)

    try:
      scores = {}
      first_metric = next(iter(self.metrics), None)
      for split_name, dataset in self.datasets.items():
        y_preds, y_true, lengths = self._collect_split(dataset)

        split_scores = {}
        for metric_name, metric_func in self.metrics.items():
          split_scores[metric_name] = metric_func(y_true, y_preds)

        if self.PHI:
          # Mean over examples of the per-example sum of all length columns.
          length_loss = np.mean(np.sum(lengths, axis=-1), axis=0)
          split_scores["LEN. LOSS"] = length_loss
          split_scores["1-Perf"] = 1 - split_scores[first_metric]
          split_scores["COMBINED Metric"] = split_scores["1-Perf"] + self.PHI * length_loss

        split_scores["COUNT"] = len(lengths)
        split_scores["SPEEDUP (TOTAL)"] = self._flops_speedup(lengths, y_true)
        scores[split_name] = split_scores
    finally:
      if inf_lambda:
        self.model.bert.encoder._lambda.assign(_lambda)
    return scores

  def predict(self, inf_lambda=True):
    """Return ``{split_name: {"preds", "COUNT", "SPEEDUP (TOTAL)"}}`` per split.

    Same ``inf_lambda`` handling as :meth:`evaluate`. The label count used
    for the FLOPs estimate is inferred from the predictions.
    """
    if inf_lambda:
      _lambda = self.model.bert.encoder._lambda.numpy()
      self.model.bert.encoder._lambda.assign(1e+20)

    try:
      return_outputs = {}
      for split_name, dataset in self.datasets.items():
        y_preds, _, lengths = self._collect_split(dataset, collect_labels=False)
        return_outputs[split_name] = {
          "preds": y_preds,
          "COUNT": len(lengths),
          "SPEEDUP (TOTAL)": self._flops_speedup(lengths, y_preds),
        }
    finally:
      if inf_lambda:
        self.model.bert.encoder._lambda.assign(_lambda)
    return return_outputs

  def on_epoch_end(self, epoch, logs=None, inf_lambda=False):
    """Print/log the scores and ETA; checkpoint when the combined metric improves.

    Checkpointing only happens when ``PHI`` is set and ``inf_lambda`` is True.
    Keras invokes this hook as ``on_epoch_end(epoch, logs)``, i.e. with
    ``inf_lambda=False``.
    """
    scores = self.evaluate(inf_lambda=inf_lambda)
    for split_name, split_scores in scores.items():
      str_result = split_name + ": "
      for metric_name, metric_value in split_scores.items():
        str_result = str_result + " - {0}: {1:.4f}".format(metric_name, metric_value)
      print(str_result, flush=True)
      if self.logger:
        self.logger.info(str_result)

    print(self.model.bert.encoder.ETA.numpy())
    if self.logger:
      self.logger.info(f"Confidence Ratios: {self.model.bert.encoder.ETA.numpy()}")

    if self.PHI:
      # The first split in `datasets` drives model selection.
      dev_combined_loss = scores[list(scores.keys())[0]]["COMBINED Metric"]
      if dev_combined_loss < self.best_combined_loss and inf_lambda:
        print("Combined metric improved from:", self.best_combined_loss, dev_combined_loss, flush=True)
        # logger defaults to None; guard it like the logging calls above.
        if self.logger:
          self.logger.info(f"Combined metric improved from: {self.best_combined_loss}, {dev_combined_loss}")
        self.best_combined_loss = dev_combined_loss

        if self.save_model:
          self.model.save_weights(self.saved_model_path+"_lr.h5")


# Callback dual evaluation
# class DualModelCheckpoint(tf.keras.callbacks.Callback):
#   def __init__(self, eval_dataset, test_dataset, metrics, save_model, saved_model_path):
#     super(DualModelCheckpoint, self).__init__()
#     self.metrics = metrics
#     self.save_model = save_model
#     self.saved_model_path = saved_model_path
#     self.eval_dataset = eval_dataset
#     self.test_dataset = test_dataset
#     self.eval_labels: np.ndarray = None
#     self.test_labels: np.ndarray = None

#     for batch in eval_dataset:
#       if self.eval_labels is None:
#         self.eval_labels = batch[1].numpy()
#       else:
#         self.eval_labels = np.append(self.eval_labels, batch[1].numpy(), axis=0)

#     for batch in test_dataset:
#       if self.test_labels is None:
#         self.test_labels = batch[1].numpy()
#       else:
#         self.test_labels = np.append(self.test_labels, batch[1].numpy(), axis=0)

#   def evaluate(self):
#     eval_preds = {'validation': [], 'test': []}
#     for example in self.eval_dataset:
#       output = self.model(example[0])
#       eval_preds['validation'].extend(tf.math.argmax(output.logits, axis=1).numpy())
#     for example in self.test_dataset:
#       output = self.model(example[0])
#       eval_preds['test'].extend(tf.math.argmax(output.logits, axis=1).numpy())
    
#     eval_scores = {'validation': {}, 'test': {}}
#     for metric_name, metric_func in self.metrics.items():
#       eval_scores['validation'][metric_name] = metric_func(self.eval_labels, eval_preds['validation'])
#       eval_scores['test'][metric_name] = metric_func(self.test_labels, eval_preds['test'])
#     return eval_scores

#   def on_epoch_end(self, epoch, logs=None):
#     eval_scores = self.evaluate()
#     str_result = "validation: "
#     for key, value in eval_scores['validation'].items():
#       str_result = str_result + " - {0}: {1:.4f}".format(key, value)
#     print(str_result)

#     str_result = "test: "
#     for key, value in eval_scores['test'].items():
#       str_result = str_result + " - {0}: {1:.4f}".format(key, value)
#     print(str_result)

#     if self.save_model:
#       self.model.save_weights(self.saved_model_path+"_"+str(epoch+1)+".h5") 


# Callback for bert
# class SingleModelCheckpoint(tf.keras.callbacks.Callback):
#   def __init__(self, eval_dataset, metrics, save_model=False, saved_model_path=""):
#     super(SingleModelCheckpoint, self).__init__()
#     self.metrics = metrics
#     self.eval_dataset = eval_dataset
#     self.save_model = save_model
#     self.saved_model_path = saved_model_path
#     self.eval_labels: np.ndarray = None

#     for batch in eval_dataset:
#       if self.eval_labels is None:
#         self.eval_labels = batch[1].numpy()
#       else:
#         self.eval_labels = np.append(self.eval_labels, batch[1].numpy(), axis=0)

#   def evaluate(self):
#     eval_preds = []
#     for example in self.eval_dataset:
#       output = self.model(example[0])
#       eval_preds.extend(tf.math.argmax(output.logits, axis=1).numpy())

#     eval_scores = {}
#     for metric_name, metric_func in self.metrics.items():
#       eval_scores[metric_name] = metric_func(self.eval_labels, eval_preds)
#     return eval_scores

#   # def on_batch_end(self, batch, logs=None):
#   #   if batch % 10 == 0:
#   #     os.system("sensors >> templog.txt")

#   def on_epoch_end(self, epoch, logs):
#     eval_scores = self.evaluate()
#     str_result = ""
#     for key, value in eval_scores.items():
#       str_result = str_result + " - {0}: {1:.4f}".format(key, value)
#     print(str_result)

#     if self.save_model:
#       self.model.save_weights(self.saved_model_path+"_"+str(epoch+1)+".h5")
