#
 #     MILIE: Modular & Iterative Multilingual Open Information Extraction
 #
 #
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

import sys
import logging
import os
import math
from collections import defaultdict
import pdb
from tqdm import tqdm
import numpy as np
import scipy, spacy
from copy import deepcopy
import json, pickle
#from milie.evals.oie_conj import get_conj_split_objs, create_conj_obj_data
import torch
from torch.utils.data import DataLoader, SequentialSampler

from .util import write_list_to_file, compute_softmax, plot_attention

LOGGER = logging.getLogger(__name__)


def get_predictor(milie_args, output_losses=False):
    """
    Factory returning the predictor used for inference.

    At present a single predictor type is produced, regardless of the arguments.

    :param milie_args: an instance of milieArguments, handed to the predictor's constructor
    :param output_losses: accepted for interface compatibility; this factory does not forward it
    :return: an instance of :py:class:`IterativeMultiHead`
    """
    return IterativeMultiHead(milie_args)


class GreedyPredictor(object):
    """
    Base predictor: runs one forward pass of the model and sorts its outputs by head type
    (generation, classification on [CLS], token classification), optionally recording losses
    and attention probabilities.
    """
    def __init__(self, milie_args, output_losses=False):
        """
        Initializes the predictor from an instance of milieArguments.

        - milie_args.plus_generation: number of generation heads to read from the model output
        - milie_args.plus_classify_sequence: number of [CLS] classification heads to read
        - milie_args.plus_classify_tokens: number of token classification heads to read
        - milie_args.output_attentions: if set, attention probabilities are returned as well

        :param milie_args: instance of milieArguments
        :param output_losses: if True, labels are passed to the model and the losses it returns
                are accumulated in ``self.losses``
        """
        # if MetaHandler is used even in prediction, these numbers should be merged.
        for option in ('plus_classify_sequence', 'plus_classify_tokens',
                       'plus_generation', 'output_attentions'):
            setattr(self, option, getattr(milie_args, option))
        self.output_losses = output_losses
        # score name -> one entry per processed minibatch
        self.losses = defaultdict(list)
        self.plot_attention_indices = [0, 1, 2]  # milie_args.plot_attention_indices
        self.plot_attention_layers = range(12)  # milie_args.plot_attention_layers

    def get_model_output(self, model, batch, device):
        """
        Run the model on one minibatch and split its outputs per head type.

        The model output is read positionally: (loss, only if ``self.output_losses``), then
        ``plus_generation`` generation logits, ``plus_classify_sequence`` [CLS] logits,
        ``plus_classify_tokens`` token logits, then (only if ``self.output_losses``) the four
        scores gen_loss, cls_loss, tok_loss and perplexity. If ``self.output_attentions`` is
        set, the last output is taken to be the per-layer attention probabilities.

        :param model: the model; called with the input ids plus keyword arguments
        :param batch: minibatch of seven features (input ids, input mask, segment ids,
                generation labels, [CLS] labels, token labels, and a seventh entry ignored here)
        :param device: where to run the computation, e.g. gpu
        :return: a tuple of

                - list of generation logits, one entry per generation head (may be empty)
                - list of [CLS] classification logits, one entry per head (may be empty)
                - list of token classification logits, one entry per head (may be empty)
                - list of attention probability arrays (numpy, one per element of the model's
                  last output), or None if attentions are not requested
        """
        input_ids, input_mask, segment_ids, gen_label_ids, \
            classify_id_cls, classify_id_tokens, _ = batch

        input_ids = input_ids.to(device)
        forward_kwargs = {'attention_mask': input_mask.to(device),
                          'token_type_ids': segment_ids.to(device)}
        if self.output_losses:
            # Labels are only needed (and only moved to the device) when losses are requested.
            gen_label_ids = gen_label_ids.to(device)
            classify_id_cls = classify_id_cls.to(device)
            classify_id_tokens = classify_id_tokens.to(device)
            forward_kwargs['labels_tok'] = classify_id_tokens
            forward_kwargs['masked_lm_labels'] = gen_label_ids
            forward_kwargs['labels_cls'] = classify_id_cls

        # TODO: if we have models other than VariableHeadsNSP, the call here needs to be handled
        # differently
        with torch.no_grad():
            outputs = model(input_ids, **forward_kwargs)

        cursor = 0
        if self.output_losses:
            num_expected = 1 + self.plus_generation + self.plus_classify_sequence \
                + self.plus_classify_tokens + 4
            assert len(outputs) >= num_expected
            # take mean across different devices
            total_loss = outputs[0].mean()
            self.losses['total_loss'].append(total_loss.detach().cpu().item())
            cursor = 1

        # Heads are laid out consecutively; an empty range yields an empty list for that type.
        gen_end = cursor + self.plus_generation
        cls_end = gen_end + self.plus_classify_sequence
        tok_end = cls_end + self.plus_classify_tokens
        batch_gen_logits = [outputs[pos] for pos in range(cursor, gen_end)]
        batch_cls_logits = [outputs[pos] for pos in range(gen_end, cls_end)]
        batch_tokens_logits = [outputs[pos] for pos in range(cls_end, tok_end)]

        if self.output_losses:
            score_names = ('gen_loss', 'cls_loss', 'tok_loss', 'perplexity')
            for offset, score_name in enumerate(score_names):
                # take mean across different devices
                mean_score = outputs[tok_end + offset].mean()
                self.losses[score_name].append(mean_score.detach().cpu().item())

        attention_probs = None
        if self.output_attentions:
            # outputs[-1] is iterated element by element (one entry per layer)
            attention_probs = [layer_att.detach().cpu().numpy() for layer_att in outputs[-1]]
        return batch_gen_logits, batch_cls_logits, batch_tokens_logits, attention_probs

    @staticmethod
    def predict_classification(logits):
        """
        Pick the most likely classification label.

        :param logits: array-like of logits; the argmax is taken over the flattened input
        :return: index of the largest logit as ``np.int64``
        """
        # TODO need to modify for regression on [CLS]
        best_index = np.argmax(logits)
        return best_index.astype(np.int64)







class IterativeMultiHead(GreedyPredictor):
    """
    Extracts triples by querying the token classification heads one after another.

    The first element of a triple is read from the logits of the initial forward pass. It is
    then surrounded by tags in the input ids and the tagged sentence is fed back to the model
    to predict the next element, and so on. Every fed-back sentence costs one forward pass,
    so this can be quite expensive.

    A *pattern* is a tuple of head codes giving the order in which the heads are queried,
    with Subj=0, Pred=1, Obj=2, Args=3.

    Inherits ``get_model_output`` and ``predict_classification`` from
    :py:class:`GreedyPredictor`.
    """

    # Location probed for an optional pickled oracle; the file is only read if it exists.
    ORACLE_PATH = '/home/bkotnis/local/data/milie/models/oracle.pkl'

    # Extraction orders whose results are merged by voting in iteratively_operate_on_batch().
    RECALL_PATTERNS = ((1, 2, 0, 3), (1, 0, 2, 3), (0, 1, 2, 3),
                       (0, 2, 1, 3), (2, 1, 0, 3), (2, 0, 1, 3))

    def __init__(self, milie_args):
        """
        Initializes the predictor.

        Reads ``predict`` and ``max_seq_length`` from ``milie_args`` in addition to what
        :py:class:`GreedyPredictor` reads. Loss output is always switched off for this
        predictor. Loads the spaCy pipeline 'en_core_web_lg' (without 'ner' and 'lemmatizer').

        :param milie_args: instance of milieArguments
        """
        super().__init__(milie_args)
        self.predict = milie_args.predict
        self.max_seq_len = milie_args.max_seq_length
        self.output_losses = False
        self.pattern = None
        # pattern -> triples for the example currently being processed; created here so that
        # extract_triples() also works when called without predict_dataset().
        self._pattern_cache = dict()

        self.oracle = None
        if os.path.exists(self.ORACLE_PATH):
            logging.info("Oracle found.")
            # SECURITY: pickle.load can execute arbitrary code; only keep trusted files here.
            with open(self.ORACLE_PATH, 'rb') as f:
                self.oracle = pickle.load(f)
        self.nlp = spacy.load('en_core_web_lg', disable=['ner', 'lemmatizer'])

    def predict_dataset(self, data_handler, model, device):
        """
        Given a data set (via data_handler), extract triples for every evaluation batch.

        :param data_handler: an instance or a subclass instance of
                :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler`;
                its ``eval_dataloader`` is iterated
        :param model: the model that should predict
        :param device: the device to move the computation to
        :return: a 5-tuple of

                 - an empty list (no generation results are produced here)
                 - an empty list (no [CLS] classification results are produced here)
                 - a dict mapping each sentence to its list of extracted triples
                 - an empty list
                 - None (no prediction order is tracked)
        """
        all_results_gen = []
        all_results_classify = []
        all_results_classify_tokens = dict()
        all_prediction_order = None
        self._pattern_cache = dict()
        for batch in tqdm(data_handler.eval_dataloader, desc="Evaluating"):
            predicted_triples = self.iteratively_operate_on_batch(data_handler, model, batch, device)
            all_results_classify_tokens.update(predicted_triples)

        return all_results_gen, all_results_classify, all_results_classify_tokens, \
               [], all_prediction_order

    def add_recall_triples(self, old_triples, recall_triples):
        """
        Merge two collections of triples, dropping duplicates.

        :param old_triples: iterable of hashable triples
        :param recall_triples: iterable of hashable triples
        :return: list with the union of both inputs; the order is unspecified
        """
        return list(set(old_triples).union(set(recall_triples)))

    def iteratively_operate_on_batch(self, data_handler, model, batch, device):
        """
        Extract triples for every example of a batch.

        One forward pass over the whole batch provides the token logits of all heads. Each
        example is then processed on its own with every pattern in ``RECALL_PATTERNS``, and the
        per-pattern results are merged by voting.

        :param data_handler: an instance or a subclass instance of
                :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler`;
                ``data_handler.examples`` is indexed with the example ids of the batch
        :param model: the model that should predict
        :param batch: minibatch of seven features, the last one being the example indices
        :param device: the device to move the computation to
        :return: dict mapping the ``sentence`` of each example to a list of triples; every
                 triple is a list of strings followed by the constant 1.0
        """
        self._pattern_cache = dict()
        all_triples = dict()
        input_ids, input_mask, segment_ids, gen_label_ids, \
            classify_id_cls, classify_id_tokens, example_indices = batch
        _, _, batch_token_logits, _ = self.get_model_output(model, batch, device)

        # Stacked per head; indexed below as [head, example, :, :].
        # NOTE(review): stacking assumes every head emits logits of the same shape -- confirm.
        batch_token_logits = np.asarray([x.detach().cpu().numpy() for x in batch_token_logits])
        example_indices = example_indices.tolist()
        num_patterns = len(self.RECALL_PATTERNS)
        # NOTE(review): self.oracle is loaded in __init__ but does not influence extraction;
        # the pattern list previously derived from it here was never used.
        for row, ex_id in enumerate(example_indices):
            batch_row = [input_ids[row], input_mask[row], segment_ids[row],
                         gen_label_ids[row], classify_id_cls[row], classify_id_tokens[row], ex_id]
            token_logits = batch_token_logits[:, row, :, :]

            # Work on copies: predict_fixed() writes to the example and to the input mask.
            example = deepcopy(data_handler.examples[ex_id])
            batch_row_cp = deepcopy(batch_row)

            triples = []
            # Lower the required share of agreeing patterns step by step, from all patterns
            # down to two of them. Only the first pass runs the model; later passes are served
            # from the pattern cache.
            for votes in range(num_patterns, 1, -1):
                recall_triples = self.extract_triples(
                    self.RECALL_PATTERNS, model, data_handler, example, token_logits,
                    batch_row_cp, votes / num_patterns, device)
                recall_triples = self.break_objects(
                    recall_triples, example, token_logits, batch_row_cp, model, data_handler,
                    device, break_trp=False)
                triples = self.add_recall_triples(triples, recall_triples)

            # Every triple gets the constant score 1.0 appended.
            triples = [list(trp) + [1.0] for trp in triples]
            sent = data_handler.examples[ex_id].sentence
            all_triples[sent] = triples
            # The cache is keyed by pattern only, so it must not leak into the next example.
            self._pattern_cache = dict()

        return all_triples

    def break_objects(self, triples, example, token_logits, batch, model, data_handler, device, break_trp=True):
        """
        Strip the trailing token dict from triples and optionally split conjunctive objects.

        Triples whose last element is not a dict are passed through unchanged. For the others
        the dict is removed; if ``break_trp`` is set, its 'tok' entry is additionally handed to
        the conjunction splitter and a relation is predicted for every split via
        :py:meth:`predict_fixed`.

        :param triples: iterable of triples (lists or tuples), optionally ending in a
                ``{'tok': ...}`` dict
        :param example: the example; its ``pred_elem`` / ``pred_map`` are overwritten and then
                reset to None while splitting
        :param token_logits: token logits of the example, passed on to predict_fixed()
        :param batch: features of the example; a deep copy is passed on to predict_fixed()
        :param model: the model that should predict
        :param data_handler: the dataset handler
        :param device: the device to move the computation to
        :param break_trp: whether to split conjunctive objects
        :return: list of tuples
        """
        new_triples = []
        for trp in triples:
            if not isinstance(trp[-1], dict):
                new_triples.append(trp)
                continue
            tokens = trp[-1]['tok']
            new_triples.append(trp[:-1])
            if not break_trp:
                continue
            # Imported lazily: the module-level import of these helpers is commented out in
            # this file, which would otherwise make this branch fail with a NameError.
            from milie.evals.oie_conj import get_conj_split_objs, create_conj_obj_data
            line = create_conj_obj_data(tokens)
            if line.strip() == '':
                continue
            for s, r, o in get_conj_split_objs(line, self.nlp):
                batch_cp = deepcopy(batch)
                # Force subject s and object o; only the predicate is predicted.
                example.pred_elem = {'0': [s]}
                example.pred_map = {'0-2': {s: [o]}}
                rel_triples = self.predict_fixed(
                    model, data_handler, example, token_logits, batch_cp, device,
                    pattern=(0, 2, 1, 3), nary=False)
                new_triples.extend(x[0:3] for x in rel_triples)
                example.pred_elem = None
                example.pred_map = None
        return [tuple(x) for x in new_triples]

    def create_rel_dict(self, triples):
        """
        Index triples by their first two elements.

        :param triples: iterable of tuples (slices of them are used as dict keys)
        :return: a tuple of two defaultdicts keyed by ``trp[:2]``: the first maps to the set of
                 all elements from position 2 onwards, the second to the set of ``trp[0]``
        """
        rel_object = defaultdict(set)
        rel_subject = defaultdict(set)
        for trp in triples:
            key = trp[:2]
            rel_object[key].update(trp[2:])
            rel_subject[key].add(trp[0])
        return rel_object, rel_subject

    def extract_triples(self, patterns, model, data_handler, example, token_logits, batch, threshold, device):
        """
        Run :py:meth:`predict_fixed` for every pattern and keep the triples enough patterns
        agree on.

        Per-pattern results are memoised in ``self._pattern_cache`` (keyed by pattern only), so
        the caller has to reset the cache when switching to another example.

        :param patterns: sequence of patterns (hashable tuples of head codes)
        :param model: the model that should predict
        :param data_handler: the dataset handler
        :param example: the example to extract from
        :param token_logits: token logits of the example
        :param batch: features of the example
        :param threshold: minimum value of (occurrences of a triple / number of patterns)
                for a triple to be kept
        :param device: the device to move the computation to
        :return: list of triples; a triple that was seen with a trailing token dict is returned
                 as a list ending in that dict, all others as tuples
        """
        def merge_triples(all_triples):
            # Count each triple (ignoring a trailing token dict) and keep the frequent ones.
            num_patterns = len(patterns)
            trp_freq = defaultdict(int)
            token_dict = dict()
            for trp in all_triples:
                if isinstance(trp[-1], dict):
                    key = tuple(trp[:-1])
                    token_dict[key] = trp[-1]
                else:
                    key = tuple(trp)
                trp_freq[key] += 1
            merged = []
            for key, freq in trp_freq.items():
                if freq / num_patterns >= threshold:
                    # Re-attach the token dict where one was seen for this triple.
                    merged.append(list(key) + [token_dict[key]] if key in token_dict else key)
            return merged

        mrg_trps = []
        for pat in patterns:
            if pat not in self._pattern_cache:
                self._pattern_cache[pat] = self.predict_fixed(
                    model, data_handler, example, token_logits, batch, device, pattern=pat)
            mrg_trps.extend(self._pattern_cache[pat])
        return merge_triples(mrg_trps)

    def forward_pass(self, model, data_handler, input_ids, input_mask, seg_ids, example, batch, head_id, device):
        """
        Feed one (re-tagged) sentence to the model and decode a single token head.

        :param model: the model that should predict
        :param data_handler: the dataset handler; its ``arrange_token_classify_output`` decodes
                the predicted labels
        :param input_ids: input ids of the sentence (list)
        :param input_mask: input mask of the sentence (tensor)
        :param seg_ids: segment ids of the sentence
        :param example: passed through to ``arrange_token_classify_output``
        :param batch: features of the example; entries 0-2 are replaced by the given inputs
        :param head_id: index of the token classification head to decode
        :param device: the device to move the computation to
        :return: whatever ``data_handler.arrange_token_classify_output`` returns (callers in
                 this class unpack three values)
        """
        # NOTE(review): predict_fixed() passes the example *index* as ``example`` here, while
        # it passes the example object when decoding level 1 itself -- confirm which one
        # arrange_token_classify_output expects.
        input_ids = torch.tensor(input_ids).to(device)
        input_mask = input_mask.to(device)
        seg_ids = torch.tensor(seg_ids).to(device)
        # Add the batch axis for a single un-batched sequence. Checking the rank (rather than
        # the length) also handles a sequence that happens to have exactly one token.
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
            input_mask = input_mask.unsqueeze(0)
            seg_ids = seg_ids.unsqueeze(0)
        new_batch = list(batch)
        new_batch[0], new_batch[1], new_batch[2] = input_ids, input_mask, seg_ids
        _, _, batch_tokens_logits, _ = self.get_model_output(model, new_batch, device)
        classify_tokens = torch.argmax(batch_tokens_logits[head_id], dim=2)[0].tolist()
        return data_handler.arrange_token_classify_output(example, classify_tokens, input_ids.tolist())

    def predict_fixed(self, model, data_handler, example, token_logits, batch, device, pattern=(0, 1, 2, 3), nary=True):
        """
        Extract triples by querying the heads in the fixed order given by ``pattern``.

        The elements of head ``pattern[0]`` are decoded from ``token_logits``. Each of them is
        tagged in the input and fed back to predict the elements of head ``pattern[1]``, then
        ``pattern[2]``, and finally ``pattern[3]`` for additional arguments. Elements listed in
        ``example.pred_elem`` / ``example.pred_map`` replace the model's predictions for the
        first / second level. Triples containing an empty string are dropped, and triples with
        additional arguments are converted to binary triples by recursive calls.

        :param model: the model that should predict
        :param data_handler: the dataset handler
        :param example: the example; ``pred_elem`` and ``pred_map`` are read, and reset to None
                whenever additional arguments are converted
        :param token_logits: token logits of the example, indexed by head
        :param batch: the seven features of the example; the input mask is modified in place
        :param device: the device to move the computation to
        :param pattern: four head codes (Subj=0, Pred=1, Obj=2, Args=3); the first three must
                be drawn from 0-2
        :param nary: whether predicted additional arguments are appended to the triple
        :return: list of triples in (subject, predicate, object) order; lists may end in a
                 ``{'tok': ...}`` dict, tuples are plain (s, r, o)
        :raises RuntimeError: if ``pattern`` contains a code other than 0-3
        """

        def clean(elem):
            # Text form of an element without '[UNK]' placeholders.
            return str(elem).replace('[UNK]', '').strip()

        def insert_tags(input_ids, input_mask, substr, tags):
            # Surround substr with tags inside input_ids; returns ([], []) if it is not found.
            # NOTE(review): with a flat list of ids, list.index(substr) looks for an *element*
            # equal to substr, so in practice the position is found via substr[0] only --
            # confirm whether a sub-sequence search was intended.
            try:
                idx = input_ids.index(substr)
            except (ValueError, TypeError):
                try:
                    idx = input_ids.index(substr[0])
                except (ValueError, IndexError, TypeError):
                    return [], []
            tagged_input = input_ids[:idx] + tags + substr + tags + input_ids[idx + len(substr):]
            tagged_input = tagged_input[:self.max_seq_len]
            # The mask is updated in place, i.e. the caller's tensor changes as well.
            # 102 is presumably the [SEP] id of the vocabulary in use -- TODO confirm.
            if 102 in tagged_input:
                idx = tagged_input.index(102)
                input_mask[:idx + 2] = 1
            else:
                input_mask[:] = 1
            return tagged_input, input_mask

        def use_forced(elems, ids, forced):
            # Forced elements, if any, replace the predicted ones (duplicates dropped).
            if len(forced) == 0:
                return elems, ids
            forced_ids = [data_handler._sent_ids(elem) for elem in forced]
            elems, ids = [], []
            for elem, elem_id in zip(forced, forced_ids):
                if elem not in elems:
                    elems.append(elem)
                    ids.append(elem_id)
            return elems, ids

        # Position of each head code within the pattern (-1 if absent).
        positions = {0: -1, 1: -1, 2: -1, 3: -1}
        for count, code in enumerate(pattern):
            if code not in positions:
                raise RuntimeError('Code not found. Subj=0, Pred=1, Obj=2, Args=3')
            positions[code] = count
        subj_pos, pred_pos, obj_pos = positions[0], positions[1], positions[2]

        # Step 1: decode the level 1 elements from the logits of the initial forward pass
        classify_tokens = np.argmax(token_logits[pattern[0]], axis=1).tolist()
        input_ids, input_mask, _, _, _, _, example_indices = batch
        lev1_elems, lev1_ids, _ = data_handler.arrange_token_classify_output(
            example, classify_tokens, input_ids.tolist())
        tag_map = {0: data_handler.subj_tags, 1: data_handler.pred_tags, 2: data_handler.obj_tags}
        pattern_tags = [tag_map[x] for x in pattern[:3]]
        if example.pred_elem is not None and str(pattern[0]) in example.pred_elem:
            pred_elems = example.pred_elem[str(pattern[0])]
            assert isinstance(pred_elems, list)
            lev1_elems, lev1_ids = use_forced(lev1_elems, lev1_ids, pred_elems)

        triples = []
        # Step 2: tag each level 1 element and predict the level 2 elements
        for lev1, lev1_id in zip(lev1_elems, lev1_ids):
            lev1_input_id, lev1_mask = insert_tags(input_ids.tolist(), input_mask, lev1_id, pattern_tags[0])
            if len(lev1_input_id) <= 0:
                continue
            lev1_segment_ids = data_handler.get_segment_ids(example, lev1_input_id)
            lev2_elems, lev2_ids, _ = self.forward_pass(
                model, data_handler, lev1_input_id, lev1_mask, lev1_segment_ids,
                example_indices, batch, pattern[1], device)
            # Level 2 elements forced for this level 1 element via pred_map take precedence.
            if example.pred_map is not None:
                key = '-'.join([str(pattern[0]), str(pattern[1])])
                if key in example.pred_map:
                    forced = example.pred_map[key].get(lev1, [])
                    lev2_elems, lev2_ids = use_forced(lev2_elems, lev2_ids, forced)

            # Step 3: tag each level 2 element and predict the level 3 elements
            for lev2, lev2_id in zip(lev2_elems, lev2_ids):
                lev2_input_id, lev2_mask = insert_tags(lev1_input_id, input_mask, lev2_id, pattern_tags[1])
                if len(lev2_input_id) <= 0:
                    continue
                lev2_segment_ids = data_handler.get_segment_ids(example, lev2_input_id)
                lev3_elems, lev3_ids, _ = self.forward_pass(
                    model, data_handler, lev2_input_id, lev2_mask, lev2_segment_ids,
                    example_indices, batch, pattern[2], device)
                if len(lev3_elems) <= 0:
                    # Recorded with an empty third element; removed by the blank filter below.
                    trp = [clean(lev1), clean(lev2), '']
                    triples.append([trp[subj_pos], trp[pred_pos], trp[obj_pos]])
                    continue

                # Step 4: tag each level 3 element and predict additional arguments
                for lev3, lev3_id in zip(lev3_elems, lev3_ids):
                    trp = [clean(lev1), clean(lev2), clean(lev3)]
                    trp = [trp[subj_pos], trp[pred_pos], trp[obj_pos]]
                    lev3_input_id, lev3_mask = insert_tags(lev2_input_id, input_mask, lev3_id, pattern_tags[2])
                    if len(lev3_input_id) <= 0:
                        # Element could not be tagged: keep the plain triple.
                        triples.append(trp)
                        continue
                    lev3_segment_ids = data_handler.get_segment_ids(example, lev3_input_id)
                    args, _, tokens = self.forward_pass(
                        model, data_handler, lev3_input_id, lev3_mask, lev3_segment_ids,
                        example_indices, batch, pattern[3], device)
                    if len(args) > 0 and nary:
                        trp += [clean(x) for x in args]
                    triples.append(trp + [{'tok': tokens}])

        def nary2binary(triples):
            # Turn (s, r, o, arg1, ...) into (s, r, o) plus one predicted relation per argument.
            binary_triples = []
            for trp in triples:
                if len(trp) <= 4:
                    binary_triples.append(trp)
                    continue
                s, r, o = trp[0:3]
                binary_triples.append((s, r, o))
                for arg in trp[3:]:
                    if isinstance(arg, dict):
                        continue
                    batch_cp = deepcopy(batch)
                    # Force subject s and object arg; only the predicate is predicted.
                    example.pred_elem = {'0': [s]}
                    example.pred_map = {'0-2': {s: [arg]}}
                    rel_triples = self.predict_fixed(
                        model, data_handler, example, token_logits, batch_cp, device,
                        pattern=(0, 2, 1, 3), nary=False)
                    binary_triples.extend(rel_triples)
                    example.pred_elem = None
                    example.pred_map = None
            return binary_triples

        # Drop every triple that has an empty element.
        filt_triples = [trp for trp in triples if not any(x == '' for x in trp)]
        return nary2binary(filt_triples)

    def apos_mistakes(self, trp):
        """
        Check whether any element of a triple contains a detached apostrophe.

        :param trp: iterable of elements supporting the ``in`` operator
        :return: True if an element contains one of the marker strings, else False
        """
        markers = (" 't", " 'd", " ' ", " '' ", "' ")
        return any(marker in elem for elem in trp for marker in markers)


