import torch.nn as nn
from collections import OrderedDict
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
class Adapter(nn.Module):
    """Bottleneck adapter with a residual connection.

    The input is projected down from ``hid_size`` to ``size`` features,
    passed through GELU, projected back up to ``hid_size`` and finally added
    to the original input.

    Args:
        hid_size: Feature size of the input (and therefore of the output).
        size: Width of the bottleneck between the two projections.
    """

    def __init__(self, hid_size, size):
        super().__init__()
        self.downProject = nn.Linear(hid_size, size, bias=True)
        self.activation = nn.GELU()
        self.upProject = nn.Linear(size, hid_size, bias=True)

        # GELU owns no parameters, so only the two projections need the
        # custom initialisation.
        self.downProject.apply(self.init_weight)
        self.upProject.apply(self.init_weight)

    def forward(self, x):
        """Return ``x`` plus the down-project / GELU / up-project of ``x``."""
        bottleneck = self.activation(self.downProject(x))
        return self.upProject(bottleneck) + x

    def init_weight(self, m):
        """Initialise a linear layer; any other module is left untouched.

        Weights are drawn from a normal distribution with standard deviation
        ``1e-2`` and then clipped to two standard deviations on either side.
        The bias, when the layer has one, is set to ``0.01``.

        Args:
            m: Module handed in by ``nn.Module.apply``.
        """
        # isinstance also covers subclasses of nn.Linear.
        if not isinstance(m, nn.Linear):
            return

        std = 1e-2
        limit = 2 * std
        nn.init.normal_(m.weight, std=std)
        with torch.no_grad():
            # Clip outliers so no weight starts further than 2*std from zero.
            m.weight.clamp_(-limit, limit)
        # A linear layer built with bias=False has ``bias is None``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)



# Freeze the model so that only the adapter and classifier parameters are trained
def activateAdapter(model):
    """Leave only adapter and classifier parameters trainable.

    Every parameter of ``model`` whose qualified name contains ``'adapter'``
    or ``'classifier'`` gets ``requires_grad = True``; every other parameter
    gets ``requires_grad = False``.

    Args:
        model: Module whose ``named_parameters()`` are updated in place.
    """
    trainable_markers = ('adapter', 'classifier')
    for param_name, tensor in model.named_parameters():
        tensor.requires_grad = any(
            marker in param_name for marker in trainable_markers
        )

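# Add an adapter after both dense output projections of every layer
# layerList: list of Transformer encoder layers
# hid_size: size of the hidden layer
# adapter_size: bottleneck size used for every adapter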
def setAdapter(layerList, hid_size, adapter_size):
    """Append an adapter to both dense output projections of every layer.

    For each layer, ``layer.attention.output.dense`` and
    ``layer.output.dense`` are replaced in place by an ``nn.Sequential``
    holding the original module under the key ``'dense'`` followed by a new
    ``Adapter`` under the key ``'adapter'``.

    Args:
        layerList: Iterable of encoder layers to modify.
        hid_size: Feature size passed to each ``Adapter``.
        adapter_size: Bottleneck size passed to each ``Adapter``.
    """
    def with_adapter(dense):
        # Original projection first, freshly initialised adapter second.
        return nn.Sequential(OrderedDict([
            ('dense', dense),
            ('adapter', Adapter(hid_size, adapter_size)),
        ]))

    for block in layerList:
        attention_out = block.attention.output
        attention_out.dense = with_adapter(attention_out.dense)

        layer_out = block.output
        layer_out.dense = with_adapter(layer_out.dense)

    # NOTE: the new modules are not moved to any device here.
        
# Add adapters to the model, with an individual size per position
# layerList: list of Transformer encoder layers
# hid_size: size of the hidden layer
# size_lst: 12 x 2 list giving, for each layer, the sizes of its two adapters (0 = no adapter)
def setAdapterByLst(layerList, hid_size, size_lst):
    """Append adapters with per-layer bottleneck sizes.

    ``size_lst[i][0]`` is the adapter size for ``layer.attention.output.dense``
    of the i-th layer and ``size_lst[i][1]`` the size for
    ``layer.output.dense``. A size of ``0`` leaves that projection untouched;
    otherwise the projection is replaced in place by an ``nn.Sequential`` of
    the original module (key ``'dense'``) and a new ``Adapter`` (key
    ``'adapter'``).

    Args:
        layerList: Iterable of encoder layers to modify.
        hid_size: Feature size passed to each ``Adapter``.
        size_lst: Indexable of ``[attention_size, output_size]`` entries, one
            per layer.
    """
    def wrap(dense, bottleneck):
        return nn.Sequential(OrderedDict([
            ('dense', dense),
            ('adapter', Adapter(hid_size, bottleneck)),
        ]))

    for idx, block in enumerate(layerList):
        attention_size = size_lst[idx][0]
        if attention_size != 0:
            block.attention.output.dense = wrap(
                block.attention.output.dense, attention_size
            )

        output_size = size_lst[idx][1]
        if output_size != 0:
            block.output.dense = wrap(block.output.dense, output_size)
# Adapter model
'''
    WARNING: Only BERT is supported for now, because of the activateAdapter function
'''
class AdapterModel(torch.nn.Module):
    """Sequence classifier whose BERT encoder layers carry adapters.

    A pretrained sequence-classification model is loaded, adapters are
    inserted into every layer of ``model.bert.encoder.layer`` and
    ``activateAdapter`` is applied to ``model.bert``.

    Args:
        model_path: Name or path handed to ``from_pretrained``.
        num_labels: Number of output labels of the classification model.
        adapter_size: Bottleneck size used for every adapter when
            ``size_lst`` is ``None``.
        size_lst: Optional per-layer ``[attention_size, output_size]`` list;
            when given it is used instead of ``adapter_size``.
    """

    def __init__(self, model_path, num_labels, adapter_size=0, size_lst = None):
        super().__init__()
        # NOTE(review): `mirror` looks like a download-mirror option of older
        # transformers releases - confirm the installed version accepts it.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_path, num_labels=num_labels, mirror='tuna'
        )

        # Take the hidden size from the loaded config instead of assuming
        # 768, so checkpoints with a different width get matching adapters.
        hidden_size = self.model.config.hidden_size
        encoder_layers = self.model.bert.encoder.layer
        if size_lst is None:
            setAdapter(encoder_layers, hidden_size, adapter_size)
        else:
            setAdapterByLst(encoder_layers, hidden_size, size_lst)

        # Only the ``bert`` sub-module is passed, so parameters that live
        # outside it are left untouched by this call.
        activateAdapter(self.model.bert)

    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,):
        """Forward every argument to the wrapped model and return its output."""
        # Keyword arguments keep the call independent of the wrapped model's
        # positional parameter order.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )