from typing import Optional, Dict, Iterable, List, Tuple

import transformers


import torch
from torch.nn.utils.stateless import _reparametrize_module

import Levenshtein
from torch.utils.data import DataLoader

from config_evaluator import Lazy

import torchopt

import tqdm

from logger import Logger


def compute_loss_for_input(model, params, inputs):
    """Run ``model`` statelessly with ``params`` swapped in and return the output's ``.loss``.

    ``params`` is a name -> tensor mapping substituted for the module's own
    parameters/buffers for this call only; ``inputs`` is forwarded as keyword
    arguments to ``model.forward``. The model's output is assumed to expose a
    ``loss`` attribute (e.g. a Hugging Face ``ModelOutput`` -- confirm for
    other models).
    """
    # Pattern taken from the functorch MAML example:
    # https://github.com/pytorch/benchmark/blob/a6843e77414a22fc150a660f6f27c524169863f4/torchbenchmark/models/functorch_maml_omniglot/__init__.py
    outputs = torch.func.functional_call(model, params, args=(), kwargs=inputs)
    loss = outputs.loss
    return loss


def to_param_name(s):
    """Return ``s`` with every ``.`` turned into ``_`` (e.g. ``"a.b"`` -> ``"a_b"``)."""
    dot_to_underscore = str.maketrans(".", "_")
    return s.translate(dot_to_underscore)

def dict_grad_update(all_params, trainable_names, model, batch, step_size, create_graph=True):
    """Take one gradient-descent step on the trainable subset of ``all_params``.

    The loss comes from running ``model`` functionally with ``all_params`` on
    ``batch`` (``compute_loss_for_input``). Gradients are taken only with
    respect to entries whose name is ``in trainable_names``. With
    ``create_graph=True`` (the default) the update stays differentiable, so an
    outer loop can backpropagate through it (presumably MAML-style).

    Args:
        all_params: mapping name -> tensor used as the model's parameters.
        trainable_names: container of names to update (tested with ``in``).
        model: module evaluated statelessly.
        batch: keyword inputs for the model's forward.
        step_size: scalar learning rate, or ``None`` to use learned
            per-parameter rates read from
            ``all_params["lr_" + to_param_name(name)]`` and exponentiated.
        create_graph: forwarded to ``torch.autograd.grad``.

    Returns:
        A new dict: a copy of ``all_params`` whose trainable entries are
        replaced by ``param - rate * grad``. ``all_params`` is not modified.
    """
    chosen = [(name, tensor) for name, tensor in all_params.items() if name in trainable_names]
    loss = compute_loss_for_input(model, all_params, batch)
    grads = torch.autograd.grad(loss, inputs=[tensor for _, tensor in chosen],
                                create_graph=create_graph)

    result = dict(all_params)
    for (name, tensor), grad in zip(chosen, grads):
        if step_size is None:
            # Learned rate is stored un-exponentiated (log-space); read from the
            # original mapping, never from the partially-updated copy.
            rate = torch.exp(all_params["lr_" + to_param_name(name)])
        else:
            rate = step_size
        result[name] = tensor - rate * grad
    return result

from collections import deque

class MovingAvg:
    """Arithmetic mean of the most recent ``window_length`` values.

    Naive implementation: every update re-sums the whole window, so each
    ``append`` costs O(window_length). Switch to a running sum if windows
    become large.
    """

    def __init__(self, window_length):
        """Create an empty window holding at most ``window_length`` values.

        Raises:
            ValueError: if ``window_length`` < 1. Such a value used to make the
                eviction check unreachable, so the average grew unbounded.
        """
        if window_length < 1:
            raise ValueError(f"window_length must be >= 1, got {window_length!r}")
        # deque(maxlen=...) drops the oldest element automatically on append.
        self.l = deque(maxlen=window_length)
        self.window_length = window_length

    def extend(self, l):
        """Append every value of ``l`` in order.

        Returns the average after the last append, or ``None`` if ``l`` is empty.
        """
        last = None
        for x in l:
            last = self.append(x)
        return last

    def get(self):
        """Return the current average, or ``0`` if nothing has been appended yet."""
        if not self.l:
            return 0
        return sum(self.l) / len(self.l)

    def append(self, x):
        """Add ``x`` (evicting the oldest value if the window is full); return the new average."""
        self.l.append(x)
        return self.get()


def evaluate_on(model, tokenizer, dataloader, *, verbose=True):
  """Greedy-decode ``dataloader`` with ``model`` and score predictions against the labels.

  Each batch is moved to ``model.device``. Every key except ``"labels"`` is
  passed to ``model.generate`` (``num_beams=1``), with up to
  ``labels.shape[1] + 2`` new tokens allowed. Predictions and gold labels are
  decoded with ``skip_special_tokens=True`` and compared as strings.

  Args:
    model: model exposing ``generate``, ``device`` and ``eval``/``train``
      (presumably a Hugging Face seq2seq model -- confirm against callers).
    tokenizer: tokenizer whose ``batch_decode`` maps id tensors to strings.
    dataloader: iterable of dict batches of tensors that include ``"labels"``.
    verbose: if True (the original behaviour), print every
      ``prediction | gold`` pair.

  Returns:
    ``(accuracy, mean_edit_distance, mean_normalized_edit_distance)``:
    - the exact-match rate;
    - the mean Levenshtein distance per example;
    - the mean of ``distance / max(1, len(gold))`` per example.

  Raises:
    ValueError: if ``dataloader`` yields no examples.
  """
  # Remember the caller's mode so evaluation does not leave the model stuck in
  # eval() (which would e.g. disable dropout for any training that follows).
  was_training = model.training
  model.eval()
  correct, total, edit_dist, per = 0, 0, 0, 0
  try:
    for test_batch in dataloader:
      test_batch = {k: v.to(model.device) for k, v in test_batch.items()}
      labels = test_batch["labels"]
      gen_inputs = {k: v for k, v in test_batch.items() if k != "labels"}
      generated = model.generate(**gen_inputs, max_new_tokens=labels.shape[1] + 2,
                                 early_stopping="never", num_beams=1, no_repeat_ngram_size=0)
      preds = tokenizer.batch_decode(generated, skip_special_tokens=True)
      # -100 marks ignored label positions and is not a decodable id; replace it
      # by 0 (presumably the pad id, which skip_special_tokens then drops --
      # confirm for non-T5 tokenizers).
      gold = tokenizer.batch_decode(labels.masked_fill(labels == -100, 0), skip_special_tokens=True)
      if verbose:
        for p, g in zip(preds, gold):
          print(p, "\t|\t", g)
      correct += sum(x == y for x, y in zip(preds, gold))
      total += len(gold)
      edit_dist += sum(Levenshtein.distance(x, y) for x, y in zip(preds, gold))
      # Edit distance normalised by gold length (max(1, ...) guards empty gold).
      per += sum(Levenshtein.distance(x, y) / max(1, len(y)) for x, y in zip(preds, gold))
  finally:
    model.train(was_training)
  if total == 0:
    raise ValueError("evaluate_on: dataloader yielded no examples")
  return correct / total, edit_dist / total, per / total


def meta_training_load(model, save_dir):
    """Load the task adapter and both meta adapters from ``save_dir`` into ``model``.

    Each adapter is read from ``save_dir + "/<adapter_name>"``. Afterwards the
    active configuration is set to the stack
    ``meta_adapter_1 -> task_adapter -> meta_adapter_2``. That order differs
    from the load order. Returns ``model`` for chaining.
    """
    for suffix in ("/task_adapter", "/meta_adapter_1", "/meta_adapter_2"):
        model.load_adapter(save_dir + suffix)
    stack = transformers.adapters.composition.Stack(
        "meta_adapter_1", "task_adapter", "meta_adapter_2"
    )
    model.active_adapters = stack
    return model

#################
from math import ceil
def get_device_map(n_layers, devices):
    """Return a dictionary of layers distributed evenly across all devices.

    Layers ``0 .. n_layers-1`` are split into contiguous chunks of
    ``ceil(n_layers / len(devices))`` layers, and chunk *i* goes to the i-th
    device. If the chunks run out first, the trailing devices are left out of
    the map (e.g. 2 layers on 3 devices -> ``{d0: [0], d1: [1]}``).

    Args:
        n_layers: number of layers to place.
        devices: sized iterable of device identifiers, e.g. ``range(n_gpus)``.

    Returns:
        dict mapping device -> list of layer indices; ``{}`` when ``n_layers <= 0``.

    Raises:
        ValueError: if ``devices`` is empty (e.g. no CUDA device is visible).
    """
    n_devices = len(devices)
    if n_devices == 0:
        raise ValueError("get_device_map: `devices` is empty, no device to place layers on")
    if n_layers <= 0:
        # Previously crashed in range(..., step=0).
        return {}
    layers = list(range(n_layers))
    # Integer ceiling division, no float rounding involved.
    per_device = -(-n_layers // n_devices)
    chunks = [layers[start:start + per_device] for start in range(0, n_layers, per_device)]
    return dict(zip(devices, chunks))
    
def hack_t5_parallelize(model):
    """Spread a T5-style model's encoder and decoder blocks over all visible CUDA devices.

    Each stack (``model.encoder`` then ``model.decoder``) has its
    ``parallelize`` called with an even block->GPU map from ``get_device_map``.
    ``lm_head`` is then moved to the decoder's ``first_device``, and
    ``model.model_parallel`` is set. Returns ``model``.
    NOTE(review): relies on the ``parallelize`` API of transformers' T5 stacks,
    which is deprecated upstream -- confirm the installed version still has it.
    """
    for stack_name in ("encoder", "decoder"):
        stack = getattr(model, stack_name)
        gpu_ids = range(torch.cuda.device_count())
        stack.parallelize(get_device_map(len(stack.block), gpu_ids))

    model.lm_head = model.lm_head.to(model.decoder.first_device)
    model.model_parallel = True
    return model
