import numpy as np
from numpy import linalg as LA
import torch
import torch.nn as nn

# Prune each layer with specific percentage [Not used]
# def prune_by_percentile(percent, model, mask):
#     # Calculate percentile value
#     state_dict = model.state_dict()
#     for name in mask.keys():

#         # We do not prune bias term 
#         tensor = state_dict[name].data.cpu().numpy()
#         alive = tensor[np.nonzero(tensor)] # flattened array of nonzero values
#         percentile_value = np.percentile(abs(alive), percent)

#         # Convert Tensors to numpy and calculate
#         new_mask = np.where(abs(tensor) < percentile_value, 0, mask[name])
        
#         # Apply new weight and mask
#         mask[name] = new_mask
        
# Prune weight globally
def prune_by_percentile_global(percent, model, mask):
    """Globally prune the smallest-magnitude surviving weights of all masked tensors.

    One magnitude threshold is computed over the non-zero values of every tensor
    named in ``mask``. Each mask is then replaced so that entries whose weight
    magnitude falls below that shared threshold become 0; all other entries keep
    their previous mask value.

    Args:
        percent: Percentile (0-100) of the surviving (non-zero) weights to prune.
        model: Module whose ``state_dict()`` contains every name in ``mask``.
        mask: Dict mapping state_dict names to numpy mask arrays; updated in place
            (each value is replaced by a new array).

    If ``mask`` is empty or no non-zero weights remain, there is nothing to take a
    percentile of, so the mask is left unchanged.
    """
    state_dict = model.state_dict()

    # 1. Convert each masked tensor to numpy once; reused when building masks.
    #    Only names in ``mask`` are visited, so biases are skipped when the mask
    #    does not contain them.
    tensors = {name: state_dict[name].data.cpu().numpy() for name in mask}

    # 2. Gather the surviving (non-zero) weights from all layers.
    alive_lst = [tensor[np.nonzero(tensor)] for tensor in tensors.values()]
    alive = np.concatenate(alive_lst) if alive_lst else np.empty(0)
    if alive.size == 0:
        # np.concatenate([]) / np.percentile on an empty array would fail.
        return

    # 3. Find one global threshold shared by every layer.
    percentile_value = np.percentile(np.abs(alive), percent)

    # 4. Zero mask entries below the threshold; keep the old mask elsewhere.
    for name, tensor in tensors.items():
        mask[name] = np.where(np.abs(tensor) < percentile_value, 0, mask[name])
        
# Globally prune adapter neurons by the L2 norm of their upProject weights
def prune_by_percentile_node_global(percent, model, mask):
    """Globally prune adapter neurons by the L2 norm of their up-projection weights.

    For every masked tensor whose name contains ``'upProject'``, the L2 norm of
    each column (``axis=0``) is used as that neuron's score. One percentile
    threshold is computed over all non-zero scores. Every neuron scoring below
    it is removed by zeroing its column in the ``upProject`` mask and the
    matching row in the sibling ``...downProject.weight`` mask.

    Args:
        percent: Percentile (0-100) of surviving neurons to prune.
        model: Module whose ``state_dict()`` contains the names in ``mask``.
        mask: Dict of numpy mask arrays keyed by state_dict name; the arrays are
            modified in place. For every ``...upProject...`` entry it must also
            contain the corresponding ``...downProject.weight`` entry.

    If there are no ``upProject`` entries, or every neuron norm is already zero,
    the mask is left unchanged.
    """
    state_dict = model.state_dict()

    # 1. Compute the L2 norm of each upProject column (one value per neuron).
    l2norm = {}
    for name in mask:
        if 'upProject' in name:
            tensor = state_dict[name].data.cpu().numpy()
            l2norm[name] = LA.norm(tensor, 2, 0)

    # 2. Find a global threshold over neurons that are still alive (non-zero norm).
    alive_lst = [norms[np.nonzero(norms)] for norms in l2norm.values()]
    alive = np.concatenate(alive_lst) if alive_lst else np.empty(0)
    if alive.size == 0:
        # np.concatenate([]) / np.percentile on an empty array would fail.
        return
    percentile_value = np.percentile(alive, percent)

    # 3. Remove each weak neuron from both projections of the same adapter.
    for name, norms in l2norm.items():
        dead = np.flatnonzero(norms < percentile_value)
        mask[name][:, dead] = 0
        # e.g. '...adapter.upProject.weight' -> '...adapter.downProject.weight'
        down_name = name[:name.find('upProject')] + 'downProject.weight'
        mask[down_name][dead, :] = 0
        
        
# Reinitialize the weight and bias, Mask the weight with mask
def original_initialization(mask, initial_state_dict, model):
    """Reset ``model`` to ``initial_state_dict`` and re-apply the pruning mask.

    All parameters and buffers are first restored via ``load_state_dict``.
    Afterwards, every trainable parameter whose name contains ``"weight"`` and
    that has an entry in ``mask`` is replaced by ``mask[name] * initial value``
    (computed in numpy), placed back on the parameter's original device.
    """
    model.load_state_dict(initial_state_dict)
    for pname, p in model.named_parameters():
        is_masked_weight = "weight" in pname and p.requires_grad and pname in mask
        if not is_masked_weight:
            continue
        target_device = p.device
        init_values = initial_state_dict[pname].cpu().numpy()
        p.data = torch.from_numpy(mask[pname] * init_values).to(target_device)
#ANCHOR Print a summary of zero and non-zero weight counts
def print_nonzeros(model):
    """Print a one-line sparsity summary of the trainable 'dense' weight tensors.

    Only parameters whose name contains both ``'weight'`` and ``'dense'`` and that
    require grad are counted.

    Returns:
        Percentage of counted weights that are non-zero, rounded to one decimal
        place, or 0.0 when no parameter matches.
    """
    nonzero = total = 0
    for name, p in model.named_parameters():
        if 'weight' in name and p.requires_grad and 'dense' in name:
            tensor = p.data.cpu().numpy()
            nonzero += np.count_nonzero(tensor)
            total += tensor.size
    # Guard both divisions: a fully pruned model has no non-zeros (infinite
    # compression), and a model with no matching parameters has no weights.
    compression = total / nonzero if nonzero else float('inf')
    pruned_pct = 100 * (total - nonzero) / total if total else 0.0
    print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {compression:10.2f}x  ({pruned_pct:6.2f}% pruned)')
    return round((nonzero / total) * 100, 1) if total else 0.0
# Function to make an empty mask of the same size as the model
def make_mask(model):
    """Create an all-ones numpy mask for every trainable 'dense' weight of ``model``.

    Returns:
        Dict mapping each parameter name that contains both ``'weight'`` and
        ``'dense'`` (and requires grad) to ``np.ones_like`` of its values.
    """
    return {
        pname: np.ones_like(p.data.cpu().numpy())
        for pname, p in model.named_parameters()
        if 'weight' in pname and p.requires_grad and 'dense' in pname
    }
# Get the percentage of non-zero (unpruned) parameters
def persentage_nonzeros(model):
    """Return the percentage of non-zero entries in trainable 'dense' parameters.

    Every parameter whose name contains ``'dense'`` and that requires grad is
    counted. There is no ``'weight'`` filter, so biases are counted too.

    Returns:
        Non-zero percentage rounded to one decimal place, or 0.0 when no
        parameter matches (instead of raising ZeroDivisionError).
    """
    nonzero = total = 0
    for name, p in model.named_parameters():
        if p.requires_grad and 'dense' in name:
            tensor = p.data.cpu().numpy()
            nonzero += np.count_nonzero(tensor)
            total += tensor.size
    if total == 0:
        return 0.0
    return round((nonzero / total) * 100, 1)



# Prune Adapter

# Get the L1 of the adapter
def getAdapterL1(adapter):
    """Return the summed absolute value of an adapter's down- and up-projection weights.

    Reads ``adapter.downProject.weight`` and ``adapter.upProject.weight``, moves
    them to CPU, and returns the L1 total as a Python float.
    """
    projections = (adapter.downProject.weight, adapter.upProject.weight)
    down_l1, up_l1 = (torch.sum(torch.abs(w.cpu())) for w in projections)
    return (down_l1 + up_l1).item()


# Compute the L1 norm of every adapter not yet marked as pruned in prune_lst
# Return a list of dicts {'layer', 'up_down', 'l1'} sorted by ascending 'l1'
def adapterL1(layerList, prune_lst):
    """Rank the not-yet-pruned adapters of every layer by the L1 norm of their weights.

    Each entry of ``layerList`` has two adapters. One is at
    ``attention.output.dense.adapter`` (``up_down`` 0) and the other at
    ``output.dense.adapter`` (``up_down`` 1). ``prune_lst[i][k]`` is truthy when
    adapter ``k`` of layer ``i`` has already been pruned; such adapters are skipped.

    Returns:
        List of dicts ``{'layer': i, 'up_down': k, 'l1': value}`` sorted by
        ascending ``l1``. The sort is stable, so ties keep layer/slot order.
    """
    candidates = []
    for idx, layer in enumerate(layerList):
        adapters = (layer.attention.output.dense.adapter, layer.output.dense.adapter)
        for slot, adapter in enumerate(adapters):
            if prune_lst[idx][slot]:
                continue
            candidates.append({'layer': idx, "up_down": slot, "l1": getAdapterL1(adapter)})
    return sorted(candidates, key=lambda entry: entry['l1'])


# Update the prune list
'''
persent: fraction (e.g. 0.2) of the remaining candidate adapters to prune
'''
def updatePruneLst(persent, dic, prune_lst):
    """Mark the first ``int(persent * len(dic))`` entries of ``dic`` as pruned.

    Args:
        persent: Fraction of ``dic`` entries to prune in this round.
        dic: Candidate list of ``{'layer', 'up_down', ...}`` dicts. It is
            presumably the ascending-L1 list from ``adapterL1``, so the
            weakest adapters come first.
        prune_lst: Per-layer ``[front, back]`` flags; updated in place.

    Returns:
        False (after printing "finished pruning") when the fraction rounds down
        to zero entries; otherwise True, after printing how many were marked.
    """
    count = int(persent * len(dic))
    if count == 0:
        print("finished pruning")
        return False
    for entry in dic[:count]:
        prune_lst[entry['layer']][entry['up_down']] = True
    print("Pruned layer:", count)
    return True


# Prune adapter layer according to the provided prune_lst
def pruneAdapter(layerList,prune_lst):
    """Replace every adapter flagged in ``prune_lst`` with ``nn.Identity``.

    ``prune_lst[i][0]`` controls ``layerList[i].attention.output.dense.adapter``
    and ``prune_lst[i][1]`` controls ``layerList[i].output.dense.adapter``.
    The layers are modified in place; nothing is returned.
    """
    for idx, layer in enumerate(layerList):
        front_flag, back_flag = prune_lst[idx][0], prune_lst[idx][1]
        if front_flag:
            layer.attention.output.dense.adapter = nn.Identity()
        if back_flag:
            layer.output.dense.adapter = nn.Identity()
    
'''
Example:
dic = adapterL1(model.model.bert.encoder.layer,prune_lst)
updatePruneLst(0.2, dic, prune_lst)
pruneAdapter(model.model.bert.encoder.layer, prune_lst)
'''
