import os
import argparse
import numpy as np
from operator import itemgetter
from transformers import AutoTokenizer, AutoModelForCausalLM
from DirectShare.models_for_share.original.modeling_llama import LlamaForCausalLM as LlamaForCausalLM_O

def getHeadIndex(share_rate, similarity_values, match_indexs):
    """Select the attention heads whose weights will be replaced by a matched head.

    Row ``r`` of ``similarity_values`` and ``match_indexs`` describes layer
    ``num_layer - 1 - r``: row 0 is the last layer, and layer 0 has no row.
    ``similarity_values[r][h]`` is the matching-function value of head ``h`` in
    that layer. ``match_indexs[r][h]`` is the pair
    ``(match_head_index, match_layer_index)`` of the head it was matched with.

    The ``int(share_rate * num_layer * num_head)`` heads with the highest values
    are selected. ``num_layer`` counts layer 0, and the result is clamped to the
    number of candidate heads. Ties keep their row/head order.

    Args:
        share_rate: Fraction of heads to share. Accepts a number or a numeric
            string (as produced by an argparse ``type=str`` option).
        similarity_values: 2-D array-like, one row per candidate layer.
        match_indexs: Array-like indexed as ``[row][head] -> (head, layer)``.

    Returns:
        ``(replace_head_indexs, cmp_head_indexs, cmp_layer_indexs)``: three lists
        with one inner list per layer (length ``num_layer``). In layer ``i``,
        head ``replace_head_indexs[i][j]`` is to be replaced by head
        ``cmp_head_indexs[i][j]`` of layer ``cmp_layer_indexs[i][j]``.
    """
    # argparse may hand us a string; "0.3" * int would be string repetition.
    share_rate = float(share_rate)
    num_head = len(similarity_values[0])
    num_layer = len(similarity_values) + 1  # layer 0 is not considered before

    # (layer, head) -> similarity value, in row/head order so sorting ties are stable.
    candidates = {}
    for l_index, layer in enumerate(similarity_values):
        layer_id = num_layer - 1 - l_index
        for h_index, head in enumerate(layer):
            candidates[(layer_id, h_index)] = head

    print('Sorting the similarity values...')
    ranked = sorted(candidates.items(), key=itemgetter(1))

    print('Choosing which heads will share the weight...')
    shared_head = int(share_rate * num_layer * num_head)
    # Clamp, and avoid ranked[-0:] which would select every head when the count is 0.
    shared_head = max(0, min(shared_head, len(ranked)))
    selected = ranked[len(ranked) - shared_head:]
    print(['Layer ' + str(layer_id) + ' Head ' + str(head_id)
           for (layer_id, head_id), _ in selected])

    replace_head_indexs = [[] for _ in range(num_layer)]
    cmp_head_indexs = [[] for _ in range(num_layer)]
    cmp_layer_indexs = [[] for _ in range(num_layer)]
    for (layer_id, head_id), _ in selected:
        replace_head_indexs[layer_id].append(head_id)

    for layer_id, head_ids in enumerate(replace_head_indexs):
        for head_id in head_ids:
            # match pair layout: (match_head_index, match_layer_index)
            pair = match_indexs[num_layer - 1 - layer_id][head_id]
            # int() so the values are usable as slice bounds even if the .npy stored floats
            cmp_head_indexs[layer_id].append(int(pair[0]))
            cmp_layer_indexs[layer_id].append(int(pair[1]))

    return replace_head_indexs, cmp_head_indexs, cmp_layer_indexs


def replaceLLaMA(replaced_model, model_ref, replace_head_indexs, cmp_head_indexs, cmp_layer_indexs, model_name):
    """Overwrite selected attention heads of ``replaced_model`` with heads from ``model_ref``.

    For each layer ``i`` and each ``j``, head ``replace_head_indexs[i][j]`` of
    ``replaced_model`` receives, in place, the weight slices of head
    ``cmp_head_indexs[i][j]`` in layer ``cmp_layer_indexs[i][j]`` of ``model_ref``.
    A head occupies ``head_dim`` consecutive rows (or columns) starting at
    ``head_index * head_dim``.

    The weight layout is chosen by a case-insensitive substring of ``model_name``:

    * ``"llama"``: rows of ``q_proj``, ``k_proj`` and ``v_proj`` and columns of
      ``o_proj`` are copied.
    * ``"baichuan"``: the fused ``W_pack`` weight is treated as query, key and
      value blocks stacked along rows, each ``hidden_size`` rows tall. The head's
      rows in all three blocks are copied.

    Args:
        replaced_model: Model whose attention weights are modified in place.
        model_ref: Model the weights are read from (it is not modified).
        replace_head_indexs: Per-layer lists of head indices to overwrite.
        cmp_head_indexs: Per-layer lists of source head indices.
        cmp_layer_indexs: Per-layer lists of source layer indices.
        model_name: Name used to pick the weight layout.

    Returns:
        ``replaced_model`` (the same object, modified in place).

    Raises:
        ValueError: if ``model_name`` contains neither "llama" nor "baichuan".
            Previously this case silently copied nothing.
    """
    name = model_name.lower()
    if "llama" in name:
        fused = False
    elif "baichuan" in name:
        fused = True
    else:
        raise ValueError(
            "Unsupported model_name %r: expected 'llama' or 'baichuan' in the name" % (model_name,))

    for i, block in enumerate(replaced_model.model.layers):
        head_indexs = replace_head_indexs[i]
        if len(head_indexs) == 0:
            continue
        l_attention = block.self_attn  # attention module being overwritten
        head_dim = l_attention.head_dim
        for j, head_index in enumerate(head_indexs):
            # attention module whose weights are copied in
            cmp_attention = model_ref.model.layers[cmp_layer_indexs[i][j]].self_attn
            start_index = head_index * head_dim
            cmp_start_index = cmp_head_indexs[i][j] * head_dim

            if fused:
                split_size = l_attention.hidden_size
                packed = l_attention.W_pack.weight.data
                cmp_packed = cmp_attention.W_pack.weight.data
                # query, key and value blocks of W_pack
                for offset in (0, split_size, 2 * split_size):
                    packed[offset + start_index:offset + start_index + head_dim, :] = \
                        cmp_packed[offset + cmp_start_index:offset + cmp_start_index + head_dim, :]
                # NOTE(review): unlike the llama branch, o_proj is not copied here — confirm intended.
            else:
                rows = slice(start_index, start_index + head_dim)
                cmp_rows = slice(cmp_start_index, cmp_start_index + head_dim)
                # NOTE(review): k_proj/v_proj are sliced with the query-head index, which assumes
                # they hold as many heads as q_proj (no grouped-query attention) — confirm per checkpoint.
                for proj_name in ("q_proj", "k_proj", "v_proj"):
                    getattr(l_attention, proj_name).weight.data[rows, :] = \
                        getattr(cmp_attention, proj_name).weight.data[cmp_rows, :]
                # o_proj is sliced along columns (its input features) rather than rows.
                l_attention.o_proj.weight.data[:, rows] = cmp_attention.o_proj.weight.data[:, cmp_rows]

    return replaced_model

def replaceSeparate(replaced_model, model_ref, replace_head_indexs, cmp_head_indexs, cmp_layer_indexs, model_name, module_name):
    """Overwrite one projection (query, key or value) of selected heads with heads from ``model_ref``.

    For each layer ``i`` and each ``j``, the ``module_name`` rows of head
    ``replace_head_indexs[i][j]`` in ``replaced_model`` are replaced, in place,
    by the same projection's rows of head ``cmp_head_indexs[i][j]`` in layer
    ``cmp_layer_indexs[i][j]`` of ``model_ref``. A head spans ``head_dim``
    consecutive rows starting at ``head_index * head_dim``.

    Layout, by case-insensitive substring of ``model_name``:

    * ``"llama"``: separate ``q_proj``/``k_proj``/``v_proj`` weights.
    * ``"baichuan"``: a fused ``W_pack`` weight holding query, key and value
      blocks stacked along rows, each ``hidden_size`` rows tall.

    Args:
        replaced_model: Model whose attention weights are modified in place.
        model_ref: Model the weights are read from (it is not modified).
        replace_head_indexs: Per-layer lists of head indices to overwrite.
        cmp_head_indexs: Per-layer lists of source head indices.
        cmp_layer_indexs: Per-layer lists of source layer indices.
        model_name: Name used to pick the weight layout.
        module_name: ``"q"``, ``"k"`` or ``"v"``.

    Returns:
        ``replaced_model`` (the same object, modified in place).

    Raises:
        ValueError: for an unsupported ``model_name`` or ``module_name``.
            Previously both cases silently copied nothing.
    """
    name = model_name.lower()
    if "llama" in name:
        fused = False
    elif "baichuan" in name:
        fused = True
    else:
        raise ValueError(
            "Unsupported model_name %r: expected 'llama' or 'baichuan' in the name" % (model_name,))
    if module_name not in ("q", "k", "v"):
        raise ValueError("module_name must be 'q', 'k' or 'v', got %r" % (module_name,))
    # Block position of the chosen projection inside the fused W_pack (query, key, value).
    pack_slot = "qkv".index(module_name)

    for i, block in enumerate(replaced_model.model.layers):
        head_indexs = replace_head_indexs[i]
        if len(head_indexs) == 0:
            continue
        l_attention = block.self_attn  # attention module being overwritten
        head_dim = l_attention.head_dim
        for j, head_index in enumerate(head_indexs):
            # attention module whose weights are copied in
            cmp_attention = model_ref.model.layers[cmp_layer_indexs[i][j]].self_attn
            start_index = head_index * head_dim
            cmp_start_index = cmp_head_indexs[i][j] * head_dim

            if fused:
                offset = pack_slot * l_attention.hidden_size
                dst = l_attention.W_pack.weight.data
                src = cmp_attention.W_pack.weight.data
            else:
                # NOTE(review): k/v are sliced with the query-head index, which assumes no
                # grouped-query attention — confirm per checkpoint.
                offset = 0
                dst = getattr(l_attention, module_name + "_proj").weight.data
                src = getattr(cmp_attention, module_name + "_proj").weight.data

            dst[offset + start_index:offset + start_index + head_dim, :] = \
                src[offset + cmp_start_index:offset + cmp_start_index + head_dim, :]

    return replaced_model

if __name__ == "__main__":
    # Load a model, pick heads to share via getHeadIndex, copy matched head weights
    # into them with replaceLLaMA, and save the result plus tokenizer.

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="llama-2-7b-hf", help='name of model')
    parser.add_argument('--model_path', type=str, default="llama-2-7b-hf", help='path of model ckpt')
    # float (not str): getHeadIndex multiplies this value, and a string would be repeated instead.
    parser.add_argument('--share_rate', type=float, default=0.3, help='the ratio for weight sharing')
    parser.add_argument('--match_index_path', type=str, default="saved_npy/llama-2-7b-hf-qk/match_qk_indexs_llama-2-7b-hf_cos.npy", help='about candidate attention head pairs')
    parser.add_argument('--match_value_path', type=str, default="saved_npy/llama-2-7b-hf-qk/match_qk_values_llama-2-7b-hf_cos.npy", help='about matching function values (cosine similarity)')
    parser.add_argument('--output_folder', type=str, default="shared_models/", help='path of saved model')
    args = parser.parse_args()

    # Two independent copies: weights are read from model_ref and written into
    # model_to_replace, so heads overwritten earlier never feed later copies.
    model_ref = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
    model_to_replace = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    match_indexs = np.load(args.match_index_path)
    similarity_values = np.load(args.match_value_path)
    replace_head_indexs, cmp_head_indexs, cmp_layer_indexs = getHeadIndex(
        args.share_rate, similarity_values, match_indexs)
    replaced_model = replaceLLaMA(
        model_to_replace, model_ref, replace_head_indexs, cmp_head_indexs, cmp_layer_indexs, args.model_name)

    os.makedirs(args.output_folder, exist_ok=True)
    replaced_model.save_pretrained(args.output_folder)
    tokenizer.save_pretrained(args.output_folder)
    