import numpy as np
import tensorflow as tf
import pickle
from os import path

class cwe_term:
    """Contextual word embeddings (CWEs) of one term across a set of contexts.

    Construction locates the term in every context, runs the model on each
    context, and stores the hidden-state vectors at the term's subtoken
    positions per layer in ``cwe_by_layer``. ``form_subtoken_cwes`` then pools
    those subtoken vectors into one vector per context. The object pickles
    itself into ``save_dir`` at the end of construction.
    """

    #Constructor

    def __init__(self, term, contexts, model, tokenizer, save_dir, layers = None, hidden_state_target = None):
        """Extract the term's representations and pickle the object.

        Args:
            term: the word/phrase whose embeddings are collected.
            contexts: iterable of texts passed to the tokenizer; each is
                expected to contain ``term``.
            model, tokenizer, layers, hidden_state_target: forwarded to
                ``extract_representations``.
            save_dir: directory the pickle is written to (see ``save_term``).

        Raises:
            ValueError: if ``term`` does not occur in one of ``contexts``.
        """

        self.term = term
        self.first_subtoken_cwes = {}
        self.last_subtoken_cwes = {}
        self.mean_subtoken_cwes = {}
        self.max_subtoken_cwes = {}
        self.concat_subtoken_cwes = {}
        # Name -> result dict; aliases the five dicts above (same objects).
        self.subtoken_cwe_dict = {'First':  self.first_subtoken_cwes, 'Last': self.last_subtoken_cwes, 'Mean': self.mean_subtoken_cwes, 'Max': self.max_subtoken_cwes, "Concat": self.concat_subtoken_cwes}
        self.contexts = contexts
        self.extract_representations(model, tokenizer, layers, hidden_state_target)

        self.save_term(save_dir)
        print(f'Configured term {self.term}-{self.model_type} and saved to {save_dir}')

    #Methods for obtaining CWE vectors

    def _find_term_positions(self, encoded_context, context):
        """Return the token positions of the FIRST occurrence of ``self.encoding``.

        Both single- and multi-subtoken terms use the same rule. A term that
        does not occur raises ``ValueError`` instead of yielding an empty
        position list that would break pooling later.
        """
        width = len(self.encoding)
        if width == 0:
            raise ValueError(f'Term {self.term!r} encodes to no tokens')
        for start in range(len(encoded_context) - width + 1):
            if encoded_context[start:start + width] == self.encoding:
                return list(range(start, start + width))
        raise ValueError(f'Term {self.term!r} (token ids {self.encoding}) not found in context: {context!r}')

    def extract_representations(self, model, tokenizer, layers = None, hidden_state_target = None):
        """Run ``model`` on every context and collect the term's subtoken vectors.

        Sets ``model_type``, ``encoding``, ``num_subtokens``, ``cwe_by_layer``
        and ``surrounding_cwe_by_layer``. ``cwe_by_layer[layer]`` holds one
        entry per context: the list of vectors at the term's subtoken
        positions, read as ``model(inputs)[hidden_state_target][layer][0][pos]``.

        Args:
            model: callable on the tokenizer's output; its ``name`` attribute is
                stored as ``model_type``.
            tokenizer: provides ``encode`` and ``__call__``.
            layers: layer indices to collect; falsy means ``0..12``.
            hidden_state_target: index into the model output; ``None`` means
                ``-1``. ``0`` is honoured.

        Raises:
            ValueError: if the term's token ids do not occur in a context.
        """

        if not layers:
            # NOTE(review): 13 presumably = embedding output + 12 blocks of a
            # base-size encoder; confirm for other model sizes.
            layers = list(range(13))

        # `is None` (not a falsy test) so that output index 0 can be selected.
        if hidden_state_target is None:
            hidden_state_target = -1

        self.model_type = model.name

        self.encoding = tokenizer.encode(self.term, add_special_tokens = False, add_prefix_space = True)
        self.num_subtokens = len(self.encoding)

        self.cwe_by_layer = {layer: [] for layer in layers}

        # Not populated by this class; kept so the attribute still exists.
        self.surrounding_cwe_by_layer = {layer: [] for layer in layers}

        for context in self.contexts:
            # Assumes encode() and __call__ add the same special tokens, so
            # these positions index the sequence the model sees below.
            positions = self._find_term_positions(tokenizer.encode(context), context)

            inputs = tokenizer(context, return_tensors = 'tf')
            hidden_states = model(inputs)[hidden_state_target]

            for layer in layers:
                sequence = hidden_states[layer][0]
                self.cwe_by_layer[layer].append([sequence[position] for position in positions])

    #Creates subtoken cwe dictionary
    def form_subtoken_cwes(self, subtoken_type = 'Mean'):
        """Pool each context's subtoken vectors into a single vector.

        For a single-subtoken term, every strategy receives that lone vector
        and all five result dicts are filled regardless of ``subtoken_type``.
        Otherwise only ``subtoken_type`` is computed:

        - ``'First'`` / ``'Last'``: the first / last subtoken vector.
        - ``'Mean'`` / ``'Max'``: element-wise mean / NaN-ignoring max over
          the subtokens, converted with ``tf.convert_to_tensor``.
        - ``'Concat'``: the subtoken vectors flattened and concatenated,
          converted with ``tf.convert_to_tensor``.

        Results go into the matching ``*_subtoken_cwes`` dict (keyed by layer)
        and are exposed through ``subtoken_cwe_map``, which is reset per call.

        Raises:
            ValueError: if ``subtoken_type`` is not one of the five names.
        """
        targets = {
            'First': self.first_subtoken_cwes,
            'Last': self.last_subtoken_cwes,
            'Mean': self.mean_subtoken_cwes,
            'Max': self.max_subtoken_cwes,
            'Concat': self.concat_subtoken_cwes,
        }
        if subtoken_type not in targets:
            raise ValueError(f'subtoken_type must be one of {list(targets)}, got {subtoken_type!r}')

        self.subtoken_cwe_map = {}

        if len(self.encoding) == 1:
            # Nothing to pool: every strategy gets the single subtoken vector.
            for name, target in targets.items():
                for layer, cwes in self.cwe_by_layer.items():
                    target[layer] = [cwe[0] for cwe in cwes]
                self.subtoken_cwe_map[name] = target
            return

        pool = {
            'First': lambda cwe: cwe[0],
            'Last': lambda cwe: cwe[-1],
            'Mean': lambda cwe: tf.convert_to_tensor(np.mean(np.array(cwe), axis = 0)),
            'Max': lambda cwe: tf.convert_to_tensor(np.nanmax(np.array(cwe), axis = 0)),
            # One concatenate instead of repeated np.append (same flattened result).
            'Concat': lambda cwe: tf.convert_to_tensor(np.concatenate([np.ravel(subtoken) for subtoken in cwe])),
        }[subtoken_type]

        target = targets[subtoken_type]
        for layer, cwes in self.cwe_by_layer.items():
            target[layer] = [pool(cwe) for cwe in cwes]
        self.subtoken_cwe_map[subtoken_type] = target

    def save_term(self, directory):
        """Pickle the whole object to ``<directory>/<term>_<model_type>-object.pkl``.

        NOTE(review): ``term`` is used verbatim in the filename; a term
        containing a path separator would produce an unexpected path.
        """

        with open(path.join(directory, f'{self.term}_{self.model_type}-object.pkl'), 'wb') as pickle_writer:
            pickle.dump(self, pickle_writer)