from modeling_extensions import model_getter
from modeling_extensions.modeling_adapters import AdapterBertModel, AdapterRobertaModel, AdapterXLMRobertaModel
from modeling_extensions.modeling_adapters import BottleneckAdapterBertConfig, BottleneckAdapterRobertaConfig, BottleneckAdapterXLMRobertaConfig
from transformers import BertConfig, RobertaConfig, XLMRobertaConfig
import helper
import config as c
import random
from transformers import BertModel, RobertaModel, XLMRobertaModel
import numpy as np
import os
import data_provider

def mean_center(mat):
    """Return ``mat`` with its column-wise (axis 0) mean subtracted from every row.

    Accepts anything ``np.mean`` accepts (arrays or nested lists) and returns a
    new NumPy array; the input is not modified.
    """
    return np.subtract(mat, np.mean(mat, axis=0))

def cka(x, y):
    """Return the linear CKA similarity between two representation matrices.

    Row i of ``x`` is paired with row i of ``y`` (same observations). Columns are
    features, and the two matrices may have different widths. Both inputs are
    column-centred with ``mean_center``, and the score is

        ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)

    ``np.linalg.norm`` of a 2-D matrix is the Frobenius norm. If either input is
    constant in every column, the denominator is 0 and NumPy yields ``nan`` with
    a RuntimeWarning.
    """
    xc = mean_center(x)
    yc = mean_center(y)

    cross = np.linalg.norm(yc.T @ xc)
    self_x = np.linalg.norm(xc.T @ xc)
    self_y = np.linalg.norm(yc.T @ yc)
    return (cross * cross) / (self_x * self_y)

def _load_transformer():
    """Load the encoder selected by ``config`` with hidden-state outputs enabled.

    The model family is chosen from the prefix of ``c.original_transformer``,
    and the adapter or plain variant from ``c.adapter``. Weights are loaded
    from ``c.pretrained_transformer``.

    Raises:
        ValueError: if ``c.original_transformer`` starts with none of
            ``"bert-"``, ``"roberta-"`` or ``"xlm-roberta-"``.
    """
    # (name prefix, plain (config, model) classes, adapter (config, model) classes)
    families = (
        ("bert-",
         (BertConfig, BertModel),
         (BottleneckAdapterBertConfig, AdapterBertModel)),
        ("roberta-",
         (RobertaConfig, RobertaModel),
         (BottleneckAdapterRobertaConfig, AdapterRobertaModel)),
        ("xlm-roberta-",
         (XLMRobertaConfig, XLMRobertaModel),
         (BottleneckAdapterXLMRobertaConfig, AdapterXLMRobertaModel)),
    )
    for prefix, plain_classes, adapter_classes in families:
        if c.original_transformer.startswith(prefix):
            config_cls, model_cls = adapter_classes if c.adapter else plain_classes
            break
    else:
        # Previously this case fell through and crashed later with an
        # UnboundLocalError on `transformer`; fail early with a clear message.
        raise ValueError(
            "Unsupported original_transformer %r: expected a name starting with "
            "'bert-', 'roberta-' or 'xlm-roberta-'" % (c.original_transformer,)
        )

    model_config = config_cls.from_pretrained(c.pretrained_transformer)
    model_config.output_hidden_states = True
    return model_cls.from_pretrained(c.pretrained_transformer, config=model_config)


def get_representations(sentences):
    """Encode each sentence and collect one vector per sentence per layer.

    Args:
        sentences: iterable of sentence strings passed to the tokenizer from
            ``model_getter.get_tokenizer()``.

    Returns:
        dict of NumPy arrays, each with one row per sentence:
          * ``"cls"``: the model's second output (``outputs[1]``) for the
            first sequence in the batch.
          * ``"1"`` .. ``str(num_hidden_layers)``: the mean over tokens of that
            layer's hidden states. When ``c.add_special_tokens`` is set, the
            first and last (special) tokens are excluded from the mean.

    Raises:
        ValueError: if ``c.original_transformer`` names an unsupported family.
    """
    import torch  # function-scope: only needed to disable autograd here

    tokenizer = model_getter.get_tokenizer()
    transformer = _load_transformer()
    transformer.cuda()
    transformer.eval()  # defensive: make sure dropout is off for extraction
    # RoBERTa and XLM-R models are called without token_type_ids.
    is_roberta = "roberta" in c.original_transformer
    # hidden_states[0] is the embedding output; layers are 1..num_layers.
    num_layers = transformer.config.num_hidden_layers

    # Only strip the boundary tokens when the tokenizer actually added them;
    # otherwise [1:-1] would throw away the first and last real tokens.
    token_slice = slice(1, -1) if c.add_special_tokens else slice(None)

    vectors = {"cls": []}
    for layer in range(1, num_layers + 1):
        vectors[str(layer)] = []

    with torch.no_grad():  # inference only: no autograd graph per sentence
        for cnt, sentence in enumerate(sentences, start=1):
            if cnt % 100 == 0:
                print(cnt)

            encoded = tokenizer.encode_plus(sentence, add_special_tokens=c.add_special_tokens, max_length=c.max_length, return_tensors='pt')
            model_inputs = {
                "input_ids": encoded['input_ids'].cuda(),
                "attention_mask": encoded['attention_mask'].cuda(),
            }
            if not is_roberta:
                model_inputs["token_type_ids"] = encoded['token_type_ids'].cuda()
            outputs = transformer(**model_inputs)

            vectors["cls"].append(outputs[1].cpu().numpy()[0])
            hidden_states = outputs[2]
            for layer in range(1, num_layers + 1):
                token_vecs = hidden_states[layer].cpu().numpy()[0][token_slice]
                vectors[str(layer)].append(np.mean(token_vecs, axis=0))

    return {key: np.array(rows) for key, rows in vectors.items()}

# Script entry. With c.topology_eval set, compare the saved per-layer
# representations of two model variants using linear CKA and print one score
# per layer. Otherwise, extract per-layer sentence representations for the
# treebank at c.in_file and save them as c.outpath/layer_<i>.npy.
if c.topology_eval:
    # NOTE(review): assumes 12 saved layers (base-size models); a variant with
    # fewer layers fails here with FileNotFoundError, as before.
    for layer in range(1, 13):
        file_name = "layer_" + str(layer) + ".npy"
        vecs1 = np.load(os.path.join(c.topology_first_variant_path, file_name))
        vecs2 = np.load(os.path.join(c.topology_second_variant_path, file_name))
        score = cka(vecs1, vecs2)
        print(round(score, 2))
else:
    sentences, _, _, _ = data_provider.load_ud_treebank(c.in_file, None, c.max_word_len)
    # Keep the treebank order (no shuffling): cka() pairs row i of one variant
    # with row i of the other, so separate extraction runs must stay aligned.
    sentences = [" ".join(s) for s in sentences]

    vectors = get_representations(sentences)

    os.makedirs(c.outpath, exist_ok=True)
    # Save every layer that get_representations returned ("cls" is not saved).
    for key, layer_vecs in vectors.items():
        if key == "cls":
            continue
        np.save(os.path.join(c.outpath, "layer_" + key + ".npy"), layer_vecs)



        