import stanfordnlp
from nltk.tokenize import sent_tokenize
import json
import time
import sys
from nltk.stem.wordnet import WordNetLemmatizer
# from polyjuice import Polyjuice
from polyjuice.polyjuice_wrapper import Polyjuice
import ilm.tokenize_util
from tqdm import tqdm
import re
import torch
import random 
# Seed Python's and torch's random number generators with a fixed value.
random.seed(42)
torch.manual_seed(42)

# One-time setup: uncomment the next line to download the English models for
# the neural pipeline.
# stanfordnlp.download('en')
# Module-level parser, built once at import time and used by modfiy_passages().
nlp = stanfordnlp.Pipeline() # This sets up a default neural pipeline in English


# for sent in doc.sentences:
#     # print('---')
#     # print(sent.words_string())
#     # print('===')
#     # print(sent.tokens_string())
#     for dep_edge in sent.dependencies:
#         # print(dep_edge[2].text, dep_edge[0].index, dep_edge[1])
#         print('dep_edge', dep_edge)
#         if dep_edge[0].index == '0':
#             assert dep_edge[1] == 'root'
#             print('root=', dep_edge[2].text)

def _find_answer_in_context(answer_text: str, context: str):
    """Finds all instances of the `answer_text` in the context passage.
    
    Returns a list of (start index, end index) tuples.
    """
    context_spans = [
        (m.start(), m.end())
        for m in re.finditer(re.escape(answer_text.lower()), context.lower())
    ]
    return context_spans


# Marker string wrapped around the rewritten span of the negated sentence and
# around the root word of the returned original sentence. Empty by default, so
# nothing extra is emitted.
mark=""

def reconstruct_negation(sent, root_tok, root_idx):
    """Negate the root verb of a parsed sentence.

    Args:
        sent: Parsed sentence. `sent.dependencies` is a sequence of edges whose
            third element is a word exposing `.text`; `sent.print_dependencies()`
            is called on the unsupported paths.
        root_tok: Dependency edge of the root word; `root_tok[2]` exposes
            `.text` and `.xpos`.
        root_idx: Position of the root word in `sent.dependencies`.

    Returns:
        A `(negated, original)` pair of space-joined token strings. The
        rewritten span and the original root word are wrapped in `mark`.
        `('', '')` is returned when a participle root has no recognised
        auxiliary in front of it. When the participle already follows
        "not be" / "not been", the sentence is returned unchanged.

    Raises:
        NotImplementedError: If the root's xpos is not one of VB, VBP, VBZ,
            VBD, VBN or VBG.
    """
    words = [dep_edge[2].text for dep_edge in sent.dependencies]
    org_words = list(words)
    org_words[root_idx] = mark + org_words[root_idx] + mark

    be_forms = ('was', 'were', 'is', 'are', 'am')
    have_forms = ('have', 'had', 'has')
    negators = ("n't", 'not')
    auxiliaries = ('can', 'could', 'dare', 'do', 'have',
                   'may', 'might', 'must', 'need', 'ought',
                   'shall', 'should', 'will', 'would')

    def prev_word(offset):
        # Word `offset` positions before the root, or '' when that position is
        # before the start of the sentence. A plain words[root_idx - offset]
        # would silently wrap around to the END of the sentence.
        pos = root_idx - offset
        return words[pos] if pos >= 0 else ''

    def lemma(word):
        return WordNetLemmatizer().lemmatize(word, 'v')

    def wrap(text):
        return mark + text + mark

    def negate_after(offset):
        # Insert "not" after the auxiliary `offset` words before the root and
        # mark the span from that auxiliary through the root.
        words[root_idx - offset] = mark + words[root_idx - offset] + ' not'
        words[root_idx] = words[root_idx] + mark

    def unsupported(tag):
        # Dump the parse plus a tag identifying the failing path; the empty
        # pair tells the caller that no rewrite was produced.
        sent.print_dependencies()
        print(tag)
        return '', ''

    xpos = root_tok[2].xpos
    root_text = root_tok[2].text

    if xpos in ('VB', 'VBP', 'VBZ'):
        if root_text in ('is', 'am', 'are'):
            words[root_idx] = wrap(words[root_idx] + ' not')
        elif xpos == 'VBZ':
            words[root_idx] = wrap('does not ' + lemma(words[root_idx]))
        elif xpos == 'VBP':
            words[root_idx] = wrap('do not ' + lemma(words[root_idx]))
        elif prev_word(1) in auxiliaries:
            # Base-form verb right after an auxiliary: "should go" -> "should not go".
            words[root_idx] = wrap('not ' + lemma(words[root_idx]))
        else:
            # Base-form verb without a preceding auxiliary.
            words[root_idx] = wrap('do not ' + lemma(words[root_idx]))
    elif xpos == 'VBD':
        if root_text in ('was', 'were'):
            words[root_idx] = wrap(words[root_idx] + ' not')
        else:
            words[root_idx] = wrap('did not ' + lemma(words[root_idx]))
    elif xpos in ('VBN', 'VBG'):
        # Participle root: negate the auxiliary found up to three words back.
        if prev_word(1) in be_forms or prev_word(1) in have_forms:
            negate_after(1)
        elif prev_word(1) == 'be':
            if prev_word(2) not in negators:
                words[root_idx - 1] = mark + 'not ' + words[root_idx - 1]
                words[root_idx] = words[root_idx] + mark
            # else: already negated, sentence is left untouched
        elif prev_word(1) == 'been':
            if prev_word(2) in negators:
                pass  # already negated, sentence is left untouched
            elif prev_word(2) in have_forms:
                negate_after(2)
            elif prev_word(3) in have_forms:
                negate_after(3)
            else:
                return unsupported('1')
        elif prev_word(2) in be_forms:
            negate_after(2)
        elif prev_word(3) in be_forms:
            negate_after(3)
        else:
            return unsupported('2')
    else:
        sent.print_dependencies()
        print('3')
        raise NotImplementedError('unsupported root POS tag: %r' % (xpos,))

    return ' '.join(words), ' '.join(org_words)

def reconstruct_modality(sent, root_tok, root_idx):
    """Rewrite the root verb of a parsed sentence with a "may"/"might" modal.

    Args:
        sent: Parsed sentence. `sent.dependencies` is a sequence of edges whose
            third element is a word exposing `.text`; `sent.print_dependencies()`
            is called on the unsupported paths.
        root_tok: Dependency edge of the root word; `root_tok[2]` exposes
            `.text` and `.xpos`.
        root_idx: Position of the root word in `sent.dependencies`.

    Returns:
        A `(rewritten, '')` pair, where `rewritten` is the space-joined token
        string with the modal inserted. `('', '')` is returned when a
        participle root has no recognised auxiliary in front of it.

    Raises:
        NotImplementedError: If the root's xpos is not one of VB, VBP, VBZ,
            VBD, VBN or VBG.
    """
    words = [dep_edge[2].text for dep_edge in sent.dependencies]

    be_forms = ('was', 'were', 'is', 'are', 'am')
    have_forms = ('have', 'had', 'has')
    negators = ("n't", 'not')

    def prev_word(offset):
        # Word `offset` positions before the root, or '' when that position is
        # before the start of the sentence. A plain words[root_idx - offset]
        # would silently wrap around to the END of the sentence.
        pos = root_idx - offset
        return words[pos] if pos >= 0 else ''

    def lemma(word):
        return WordNetLemmatizer().lemmatize(word, 'v')

    def unsupported():
        # Dump the parse; the empty pair tells the caller no rewrite was made.
        sent.print_dependencies()
        return '', ''

    xpos = root_tok[2].xpos
    root_text = root_tok[2].text

    if xpos in ('VB', 'VBP', 'VBZ'):
        if root_text in ('is', 'am', 'are'):
            words[root_idx] = 'may be'
        else:
            words[root_idx] = 'may ' + lemma(words[root_idx])
    elif xpos == 'VBD':
        if root_text in ('was', 'were'):
            words[root_idx] = 'might be'
        else:
            words[root_idx] = 'might ' + lemma(words[root_idx])
    elif xpos in ('VBN', 'VBG'):
        # Participle root: rewrite the auxiliary found up to three words back.
        if prev_word(1) in be_forms:
            words[root_idx - 1] = 'might be'
        elif prev_word(1) in have_forms:
            # The verb lemma of have/had/has is "have"; spelling it out avoids
            # a WordNet lookup.
            words[root_idx - 1] = 'might have'
        elif prev_word(1) == 'be':
            if prev_word(2) in negators:
                if root_idx < 3:
                    return unsupported()  # nothing in front of the negator to replace
                words[root_idx - 3] = 'might'
                words[root_idx - 2] = 'not'
            else:
                if root_idx < 2:
                    return unsupported()  # nothing in front of "be" to replace
                # Whatever precedes "be" is replaced by the modal.
                words[root_idx - 2] = 'might'
        elif prev_word(1) == 'been':
            if prev_word(2) in negators:
                if root_idx < 3:
                    return unsupported()  # nothing in front of the negator to replace
                words[root_idx - 3] = 'might'
                words[root_idx - 2] = 'not have'
            elif prev_word(2) in have_forms:
                words[root_idx - 2] = 'might have'
            elif prev_word(3) in have_forms:
                # e.g. "have already been warned": the auxiliary sits three
                # words back and is the word that must become "might have".
                words[root_idx - 3] = 'might have'
            else:
                return unsupported()
        elif prev_word(2) in be_forms:
            words[root_idx - 2] = 'might'
            words[root_idx] = 'be ' + words[root_idx]
        elif prev_word(3) in be_forms:
            words[root_idx - 3] = 'might'
            words[root_idx] = 'be ' + words[root_idx]
        else:
            return unsupported()
    else:
        sent.print_dependencies()
        raise NotImplementedError('unsupported root POS tag: %r' % (xpos,))

    return ' '.join(words), ''


# need to be fixed
def reconstruct_future(sent, root_tok, root_idx):
    """Rewrite the root verb of a parsed sentence into a "will ..." form.

    Args:
        sent: Parsed sentence. `sent.dependencies` is a sequence of edges whose
            third element is a word exposing `.text`; `sent.print_dependencies()`
            is called on the unsupported paths.
        root_tok: Dependency edge of the root word; `root_tok[2]` exposes
            `.text` and `.xpos`.
        root_idx: Position of the root word in `sent.dependencies`.

    Returns:
        A `(rewritten, '')` pair, where `rewritten` is the space-joined token
        string with "will" inserted. `('', '')` is returned when a participle
        root has no recognised auxiliary in front of it.

    Raises:
        NotImplementedError: If the root's xpos is not one of VB, VBP, VBZ,
            VBD, VBN or VBG.
    """
    words = [dep_edge[2].text for dep_edge in sent.dependencies]

    be_forms = ('was', 'were', 'is', 'are', 'am')
    have_forms = ('have', 'had', 'has')
    negators = ("n't", 'not')
    auxiliaries = ('could', 'dare', 'do', 'have',
                   'may', 'might', 'must', 'need', 'ought',
                   'shall', 'should', 'will', 'would')

    def prev_word(offset):
        # Word `offset` positions before the root, or '' when that position is
        # before the start of the sentence. A plain words[root_idx - offset]
        # would silently wrap around to the END of the sentence.
        pos = root_idx - offset
        return words[pos] if pos >= 0 else ''

    def lemma(word):
        return WordNetLemmatizer().lemmatize(word, 'v')

    def unsupported():
        # Dump the parse; the empty pair tells the caller no rewrite was made.
        sent.print_dependencies()
        return '', ''

    xpos = root_tok[2].xpos
    root_text = root_tok[2].text

    if xpos in ('VB', 'VBP', 'VBZ'):
        if root_text in ('is', 'am', 'are'):
            words[root_idx] = 'will be'
        elif prev_word(1) in auxiliaries:
            # NOTE(review): this replaces the main verb itself, so
            # "could go" becomes "could will be". Behaviour is kept as-is
            # because the intended output is not determinable from this file;
            # confirm and fix.
            words[root_idx] = 'will be'
        else:
            words[root_idx] = 'will ' + lemma(words[root_idx])
    elif xpos == 'VBD':
        if root_text in ('was', 'were'):
            words[root_idx] = 'will be'
        else:
            words[root_idx] = 'will ' + lemma(words[root_idx])
    elif xpos in ('VBN', 'VBG'):
        # Participle root: rewrite the auxiliary found up to three words back.
        if prev_word(1) in be_forms:
            words[root_idx - 1] = 'will be'
        elif prev_word(1) in have_forms:
            # The verb lemma of have/had/has is "have"; spelling it out avoids
            # a WordNet lookup.
            words[root_idx - 1] = 'will have'
        elif prev_word(1) == 'be':
            if prev_word(2) in negators:
                if root_idx < 3:
                    return unsupported()  # nothing in front of the negator to replace
                words[root_idx - 3] = 'will'
                words[root_idx - 2] = 'not'
            else:
                if root_idx < 2:
                    return unsupported()  # nothing in front of "be" to replace
                # Whatever precedes "be" is replaced by "will".
                words[root_idx - 2] = 'will'
        elif prev_word(1) == 'been':
            if prev_word(2) in negators:
                if root_idx < 3:
                    return unsupported()  # nothing in front of the negator to replace
                words[root_idx - 3] = 'will'
                words[root_idx - 2] = 'not have'
            elif prev_word(2) in have_forms:
                words[root_idx - 2] = 'will have'
            elif prev_word(3) in have_forms:
                # e.g. "have already been warned": the auxiliary sits three
                # words back and is the word that must become "will have".
                words[root_idx - 3] = 'will have'
            else:
                return unsupported()
        elif prev_word(2) in be_forms:
            words[root_idx - 2] = 'will'
            words[root_idx] = 'be ' + words[root_idx]
        elif prev_word(3) in be_forms:
            words[root_idx - 3] = 'will'
            words[root_idx] = 'be ' + words[root_idx]
        else:
            return unsupported()
    else:
        sent.print_dependencies()
        raise NotImplementedError('unsupported root POS tag: %r' % (xpos,))

    return ' '.join(words), ''

def reconstruct_future_old():
    """Placeholder for an earlier future-tense rewrite; does nothing.

    Takes no arguments and returns None. The commented-out draft below is
    incomplete and is kept only as a record of that earlier attempt.
    """
    # # reconstruct new sentence for future tense
    # if root_tok[2].text in ['was', 'were', 'is', 'are', 'am']
    #     words[root_idx] = 'will be'
    # elif root_tok[2].text == 'be':   # would not happen
    #     words[root_idx] = 'will be'
    # elif words[root_idx-1] in ['be', 'can', 'could', 'dare', 'do', 'have',
    #                             'may', 'might', 'must', 'need', 'ought',
    #                             'shall', 'should', 'will', 'would']: # auxiliary verbs
    #     print()
    # elif root_tok[2].xpos in ['VBN', 'VBG']:  # if p.p. or ving
    #     if words[root_idx-1] in ['was', 'were', 'is', 'are', 'am']: # if be verbs

    #     elif words[root_idx-1] in ['be', 'been']:  #
    #         if words[root_idx-2] in ['should', 'could', 'might', 'can', 'may', 'would', 'shall', 'must', 'have']:
    #             be, can, could, dare, eed, ought
    #         elif words[root_idx-2] == 'n\'t':
    #             words[root_idx-3] = 'will'
    #             words[root_idx-2] = 'not'
    #             words[root_idx-1] = 'be'
    #         else:
    #             words[root_idx-1] == 'will be'
    #     else: #all very normal
    return None



# Dotted path of an ILM mask class; not referenced by the code visible in this
# file.
MASK_CLS = 'ilm.mask.hierarchical.MaskHierarchical'
# Directory that prepare_tokenizer() reads 'additional_ids_to_tokens.pkl' from
# and that load_model() passes to GPT2LMHeadModel.from_pretrained().
MODEL_DIR = '/tmp/ilm/models/sto_ilm'

def prepare_tokenizer():
    """Build the tokenizer bundle used by the ILM infilling path.

    Reads the pickled id -> token mapping stored in MODEL_DIR, derives the
    inverse token -> id mapping and registers the extra tokens with the GPT2
    tokenizer from `ilm.tokenize_util`.

    Returns:
        An object with three attributes: `tokenizer`,
        `additional_ids_to_tokens` and `additional_tokens_to_ids`.
    """
    import os
    import pickle

    import ilm.tokenize_util

    class IlmUtils(object):
        """Plain holder for the tokenizer and its additional-token maps."""

        def __init__(self, tokenizer, additional_ids_to_tokens, additional_tokens_to_ids):
            super(IlmUtils, self).__init__()
            self.tokenizer = tokenizer
            self.additional_ids_to_tokens = additional_ids_to_tokens
            self.additional_tokens_to_ids = additional_tokens_to_ids

    gpt2_tokenizer = ilm.tokenize_util.Tokenizer.GPT2

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point MODEL_DIR at trusted model directories.
    mapping_path = os.path.join(MODEL_DIR, 'additional_ids_to_tokens.pkl')
    with open(mapping_path, 'rb') as handle:
        ids_to_tokens = pickle.load(handle)

    tokens_to_ids = {}
    for token_id, token in ids_to_tokens.items():
        tokens_to_ids[token] = token_id

    try:
        ilm.tokenize_util.update_tokenizer(ids_to_tokens, gpt2_tokenizer)
    except ValueError:
        # Presumably raised when the tokenizer was already extended in this
        # process (see the message below) - TODO confirm against ilm.
        print('Already updated')

    return IlmUtils(gpt2_tokenizer, ids_to_tokens, tokens_to_ids)


def load_model():
    """Load the GPT-2 LM-head model stored in MODEL_DIR, ready for inference.

    The model is switched to eval mode and moved to CUDA when torch reports
    it as available, otherwise to the CPU.

    Returns:
        The loaded `GPT2LMHeadModel` instance.
    """
    import torch
    from transformers import GPT2LMHeadModel

    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    target_device = torch.device(device_name)

    lm = GPT2LMHeadModel.from_pretrained(MODEL_DIR)
    lm.eval()
    lm.to(target_device)
    return lm


def reconstruct_ilm(sent, model, ilmutils, root_idx):
    """Replace the root word of a parsed sentence with an ILM-generated infill.

    The root word is blanked out, the blank is turned into the
    '<|infill_word|>' token, and the model is asked for several infills. The
    first decoded candidate that differs from the original sentence wins.

    Args:
        sent: Parsed sentence; `sent.dependencies` is a sequence of edges whose
            third element is a word exposing `.text`.
        model: Model object handed to `ilm.infer.infill_with_ilm`.
        ilmutils: Object exposing `tokenizer` and `additional_tokens_to_ids`
            (as returned by prepare_tokenizer()).
        root_idx: Position of the root word in `sent.dependencies`.

    Returns:
        `(infilled sentence, '')` on success; `('', '')` when the blank token
        cannot be located in the encoded context; `(None, '')` when every
        generated candidate decodes to the original sentence.
    """
    num_infills = 10

    words = [dep_edge[2].text for dep_edge in sent.dependencies]
    org_sent = ' '.join(words)

    words[root_idx] = '_'
    context = ' '.join(words)
    if root_idx == 0:
        # The blank is looked up below via the encoding of ' _' (with a leading
        # space), so a sentence-initial blank needs that space as well.
        context = ' ' + context

    context_ids = ilm.tokenize_util.encode(context, ilmutils.tokenizer)

    # Swap the blank for the word-infill token.
    _blank_id = ilm.tokenize_util.encode(' _', ilmutils.tokenizer)[0]
    try:
        context_ids[context_ids.index(_blank_id)] = ilmutils.additional_tokens_to_ids['<|infill_word|>']
    except (ValueError, KeyError):
        # ValueError: blank id not present in the encoded context.
        # KeyError: the infill token is missing from the additional vocabulary.
        print('exceptexcept')
        return '', ''

    from ilm.infer import infill_with_ilm

    generated = infill_with_ilm(
        model,
        ilmutils.additional_tokens_to_ids,
        context_ids,
        num_infills=num_infills)
    print(generated)

    # Every candidate is considered, including the last one, and a shorter
    # than expected candidate list is handled without an IndexError.
    for candidate in generated:
        out = ilm.tokenize_util.decode(candidate, ilmutils.tokenizer)
        if out != org_sent:
            return out, ''
    return None, ''


def reconstruct_polyjuice(sent, root_idx, control_code, pj):
    """Ask a Polyjuice wrapper to rewrite the root word of a parsed sentence.

    Args:
        sent: Parsed sentence; `sent.dependencies` is a sequence of edges whose
            third element is a word exposing `.text`.
        root_idx: Position of the word that is replaced by '[BLANK]'.
        control_code: Control code forwarded as `ctrl_code`.
        pj: Object exposing `perturb(...)`.

    Returns:
        `(first perturbation, '')`, or `(None, '')` when `pj.perturb` returns
        nothing.
    """
    tokens = [edge[2].text for edge in sent.dependencies]
    org_sent = ' '.join(tokens)

    blanked_tokens = list(tokens)
    blanked_tokens[root_idx] = '[BLANK]'
    blanked = ' '.join(blanked_tokens)

    for line in ('0000000', org_sent, blanked):
        print(line)

    request = dict(
        orig_sent=org_sent,
        # where the blank goes; otherwise it is selected automatically
        blanked_sent=blanked,
        # one of 'resemantic', 'restructure', 'negation', 'insert', 'lexical',
        # 'shuffle', 'quantifier', 'delete'
        ctrl_code=control_code,
        # perplexity threshold
        perplex_thred=15,
        # number of perturbations to return
        num_perturbations=2,
        # extra argument passed through to the huggingface generator
        num_beams=3,
    )
    perturbations = pj.perturb(**request)
    print('pert', perturbations)

    if perturbations:
        return perturbations[0], ''
    return None, ''

# 2711 data points survive this algorithm
def _reconstruct_with_method(parsed_sent, root_tok, root_idx, method, model, ilmutils, pj):
    """Dispatch to the reconstruct_* function selected by `method`.

    Returns the `(rewritten, original)` pair of the chosen function.

    Raises:
        ValueError: If `method` names no known perturbation.
    """
    if method == 'ilm':
        return reconstruct_ilm(parsed_sent, model, ilmutils, root_idx)
    if method == 'modality':
        return reconstruct_modality(parsed_sent, root_tok, root_idx)
    if method == 'future':
        return reconstruct_future(parsed_sent, root_tok, root_idx)
    if method == 'negation':
        return reconstruct_negation(parsed_sent, root_tok, root_idx)
    if method[:4] == 'poly':
        control_code = method.split('-')[1]
        return reconstruct_polyjuice(parsed_sent, root_idx, control_code, pj)
    raise ValueError('unknown perturbation method: %r' % (method,))


def _perturb_parsed_sentence(doc, method, model, ilmutils, pj):
    """Perturb the root verb of one parsed piece of text.

    `doc` is the parse of a single sentence string, which the parser may have
    split into several sentences. The first parsed sentence is the target,
    unless it has three tokens or fewer, in which case the second one is used,
    provided it exists and has more than three tokens.

    Returns:
        `(new_text, old_text)` when the target's root is a verb and the
        reconstruction produced a non-empty result, otherwise None.
    """
    # xpos tags handled by the reconstruct_* functions:
    #   VB  Verb, base form
    #   VBD Verb, past tense
    #   VBG Verb, gerund or present participle
    #   VBN Verb, past participle
    #   VBP Verb, non-3rd person singular present
    #   VBZ Verb, 3rd person singular present
    sentences = doc.sentences
    if len(sentences[0].tokens) > 3:
        target_pos = 0
    elif len(sentences) > 1 and len(sentences[1].tokens) > 3:
        target_pos = 1
    else:
        return None

    target = sentences[target_pos]
    head_indices = [dep_edge[0].index for dep_edge in target.dependencies]
    root_idx = head_indices.index('0')  # the root is the word governed by index '0'
    root_tok = target.dependencies[root_idx]
    if root_tok[2].upos != 'VERB':
        return None

    if target_pos == 1:
        print('reconstruction')  # trace line emitted on the fallback path only
    recon_sent, recon_old = _reconstruct_with_method(
        target, root_tok, root_idx, method, model, ilmutils, pj)
    if not recon_sent:
        return None

    prefix = ''
    if target_pos == 1:
        prefix = ' '.join([l.text for l in sentences[0].tokens]) + ' '
    trailing = ' '.join(
        [' '.join([l.text for l in s.tokens]) for s in sentences[target_pos + 1:]])

    # The separator before `trailing` is always emitted, so both target
    # positions produce the same layout.
    new_sent = prefix + recon_sent + ' ' + trailing
    old_sent = prefix + recon_old + ' ' + trailing
    print('new: %s'%new_sent)
    print('old: %s'%old_sent)
    return new_sent, old_sent


def modfiy_passages(data, method, num_context, model=None, ilmutils=None):
    """Perturb the root verb of every answer-bearing sentence in the passages.

    For each instance, each context passage is split into sentences. A
    sentence that contains any of the instance's answers is parsed and its
    root verb is rewritten with the reconstruction selected by `method`.
    Passages with at least one rewritten sentence are updated IN PLACE in
    `data`. An instance is kept only when every answer-bearing sentence that
    was examined could be rewritten.

    Args:
        data: List of instances; each has `'ctxs'` (a list of dicts with a
            `'text'` passage) and `'answers'` (answer strings).
        method: 'ilm', 'modality', 'future', 'negation', or 'poly-<code>'
            where `<code>` is a Polyjuice control code.
        num_context: When equal to 1 only the first context of every instance
            is processed; otherwise all contexts are.
        model: Model for the 'ilm' method.
        ilmutils: Tokenizer bundle for the 'ilm' method.

    Returns:
        `(mod_data, old_passages)`: the kept instances, and one "old" passage
        string per rewritten context (collected whether or not the instance
        was kept). For methods whose reconstruction returns '' as the
        original sentence, the rewritten sentence is absent from the old
        passage.

    Raises:
        ValueError: If a sentence needs rewriting and `method` is unknown.
    """
    print('start')
    count = 0
    mod_data = []
    old_passages = []
    total_gold, perturbed_gold = 0, 0
    pj = None
    if method[:4] == 'poly':
        # Any poly-<code> method needs the wrapper, not only 'poly-negation'.
        pj = Polyjuice(model_path="uw-hai/polyjuice", is_cuda=True)

    for i, inst in tqdm(enumerate(data)):
        print('%d/%d'%(count, i))

        found_verb_inst = False
        inst_gold, inst_perturb = 0, 0
        for ctx in inst['ctxs']:
            passage = ctx['text']
            new_sent_list = []
            old_sent_list = []
            context_gold, context_perturb = 0, 0

            for sent in sent_tokenize(passage):
                new_sent = sent
                old_sent = sent
                sent_gold, sent_perturb = False, False
                doc = None  # parse lazily, at most once per sentence

                for ans in inst['answers']:
                    if not _find_answer_in_context(ans, sent):
                        continue
                    sent_gold = True
                    if doc is None:
                        doc = nlp(sent)
                    # One attempt per matching answer, so sampling-based
                    # methods get a retry. A failed attempt leaves new_sent
                    # and old_sent untouched.
                    rewritten = _perturb_parsed_sentence(doc, method, model, ilmutils, pj)
                    if rewritten is not None:
                        new_sent, old_sent = rewritten
                        sent_perturb = True
                        break

                old_sent_list.append(old_sent)
                new_sent_list.append(new_sent)
                if sent_gold:
                    context_gold += 1
                if sent_perturb:
                    context_perturb += 1

            inst_gold += context_gold
            inst_perturb += context_perturb

            if context_perturb:
                new_passage = ' '.join(new_sent_list)
                old_passages.append(' '.join(old_sent_list))

                # modify original passage
                ctx['text'] = new_passage

                print('*'*50)
                print('*'*50)
                print('passage', passage)
                print('new_passage', ctx)

                found_verb_inst = True

            # if 1 context, only do 1 time (else substitute all passages)
            if num_context == 1:
                break

        # Keep the instance only if every gold sentence was perturbed.
        if found_verb_inst and (inst_gold == inst_perturb):
            count += 1
            mod_data.append(data[i])
            total_gold += inst_gold
            perturbed_gold += inst_perturb
            print('count:%d inst: %d inst_gold:%d inst_perturb: %d'%(count, i, inst_gold, inst_perturb))

    print('total_gold, perturbed_gold')
    print(total_gold, perturbed_gold)

    return mod_data, old_passages


def main(method, num_context):
    """Perturb the passages of a JSON dataset and write the surviving instances.

    The input path is read from `sys.argv[2]` and the output path from
    `sys.argv[3]`. The old passages returned by modfiy_passages() are not
    written anywhere.

    Args:
        method: One of 'ilm', 'modality', 'future', 'negation',
            'poly-negation'.
        num_context: Forwarded to modfiy_passages().

    Raises:
        ValueError: If `method` is not one of the supported values (instead
            of silently producing no output file).
    """
    print(method)

    supported = ('ilm', 'modality', 'future', 'negation', 'poly-negation')
    if method not in supported:
        raise ValueError('unsupported method %r; expected one of %s'
                         % (method, ', '.join(supported)))

    model = None
    ilmutils = None
    if method == 'ilm':
        # Only the ILM method needs the tokenizer bundle and the GPT-2 model.
        ilmutils = prepare_tokenizer()
        model = load_model()

    with open(sys.argv[2]) as fr:
        data = json.load(fr)

    mod_data, _old_passages = modfiy_passages(data, method, num_context, model, ilmutils)

    with open(sys.argv[3], 'w') as fw:
        fw.write(json.dumps(mod_data, indent=4))

if __name__ == '__main__':
    # Command line: argv[1] = method, argv[2] = input JSON path and
    # argv[3] = output JSON path (both read inside main()), argv[5] = number
    # of contexts. argv[4] is not read by any live code in this file; it is
    # only mentioned in commented-out code that wrote the old passages.
    method=sys.argv[1]
    num_context = int(sys.argv[5])
    main(method, num_context)

    # Recorded survivor counts from earlier runs:
    # survive modality -> 2633
    # survive ilm -> 2712
    # to use ilm -> data/open_domain_data$ source ilm/ilm-env/bin/activate
