
import argparse
import glob
import logging as log
import os
import random
import time
import torch.nn.functional as F

import numpy as np
import torch
from eval_utils import f1_score, precision_score, recall_score, classification_report, macro_score
from utils import gen_knn_mix_batch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
import pickle
from transformers import *
from read_data import *

from tensorboardX import SummaryWriter

from bert_models import BertModel4Mix

# Module-level logger named after this module.
logger = log.getLogger(__name__)

# NOTE: CUDA is deliberately NOT probed here.  Calling
# ``torch.cuda.is_available()`` can initialise the CUDA runtime, after which
# later changes to CUDA_VISIBLE_DEVICES are ignored.  ``use_cuda`` is assigned
# further down, right after CUDA_VISIBLE_DEVICES has been exported from the
# parsed ``--gpu`` argument, so the flag actually takes effect.

# (config class, model class, tokenizer class) per model family.
# Presumably looked up with the ``--model-type`` value; confirm against callers.
MODEL_CLASSES = {"bert": (BertConfig, BertForTokenClassification, BertTokenizer)}

# Command-line interface for the script.  Option names, types, defaults and
# actions are part of the external contract; only help texts were corrected.
parser = argparse.ArgumentParser(description='PyTorch BaseNER')

# --- data / model locations -------------------------------------------------
# NOTE(review): ``required=True`` means the default below is never used; it is
# kept as-is so existing invocations behave exactly as before.
parser.add_argument("--data-dir", default='./data', type=str, required=True)
parser.add_argument("--model-type", default='bert', type=str)
parser.add_argument("--model-name", default='bert-base-multilingual-cased', type=str)
parser.add_argument("--output-dir", default='./german_eval', type=str)
parser.add_argument('--gpu', default='0,1,2,3', type=str,
                    help='id(s) for CUDA_VISIBLE_DEVICES')
parser.add_argument('--train-examples', default=-1, type=int)

parser.add_argument("--labels", default="", type=str)
parser.add_argument('--config-name', default='', type=str)
parser.add_argument("--tokenizer-name", default='', type=str)
parser.add_argument("--max-seq-length", default=128, type=int)

# --- run modes --------------------------------------------------------------
parser.add_argument("--do-train", action="store_true",
                    help="Whether to run training.")
parser.add_argument("--do-eval", action="store_true",
                    help="Whether to run eval on the dev set.")
parser.add_argument("--do-predict", action="store_true",
                    help="Whether to run predictions on the test set.")
parser.add_argument("--evaluate-during-training", action="store_true",
                    help="Whether to run evaluation during training at each logging step.")
parser.add_argument("--do-lower-case", action="store_true",
                    help="Set this flag if you are using an uncased model.")

# --- batching ---------------------------------------------------------------
parser.add_argument("--batch-size", default=16, type=int)
parser.add_argument('--eval-batch-size', default=128, type=int)

parser.add_argument("--gradient-accumulation-steps", type=int, default=1,
                    help="Number of update steps to accumulate before performing a backward/update pass.")

# --- optimisation -----------------------------------------------------------
parser.add_argument("--learning-rate", default=5e-5, type=float,
                    help="The initial learning rate for Adam.")
parser.add_argument("--weight-decay", default=0.0, type=float,
                    help="Weight decay if we apply some.")
parser.add_argument("--adam-epsilon", default=1e-8, type=float,
                    help="Epsilon for Adam optimizer.")
parser.add_argument("--max-grad-norm", default=1.0, type=float,
                    help="Max gradient norm.")

# --- schedule ---------------------------------------------------------------
parser.add_argument("--num-train-epochs", default=20, type=float,
                    help="Total number of training epochs to perform.")
parser.add_argument("--max-steps", default=-1, type=int,
                    help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
parser.add_argument('--warmup-steps', default=0, type=int,
                    help="Linear warmup over warmup_steps.")

# --- logging / checkpoints --------------------------------------------------
parser.add_argument('--logging-steps', default=150, type=int,
                    help="Log every X update steps.")
parser.add_argument("--save-steps", type=int, default=0,
                    help="Save checkpoint every X update steps.")
parser.add_argument("--eval-all-checkpoints", action="store_true",
                    help="Evaluate all checkpoints starting with the same prefix as model_name and ending with step number")
parser.add_argument("--overwrite-output-dir", action="store_true",
                    help="Overwrite the content of the output directory")

parser.add_argument("--seed", type=int, default=42,
                    help="random seed for initialization")

# --- sub-token labelling ----------------------------------------------------
parser.add_argument("--pad-subtoken-with-real-label", action="store_true",
                    help="give real label to the padded token instead of `-100` ")
parser.add_argument("--subtoken-label-type", default='real', type=str,
                    help="[real|repeat|O] three ways to do pad subtoken with real label. [real] give the subtoken a real label e.g., B -> B I. [repeat] simply repeat the label e.g., B -> B B. [O] give it a O label. B -> B O")

parser.add_argument("--eval-pad-subtoken-with-first-subtoken-only", action="store_true",
                    help="only works when --pad-subtoken-with-real-label is true, in this mode, we only test the prediction of the first subtoken of each word (if the word could be tokenized into multiple subtoken)")
parser.add_argument("--label-sep-cls", action="store_true",
                    help="label [SEP] [CLS] with special labels, but not [PAD]")

# --- misc -------------------------------------------------------------------
parser.add_argument("--log-file", default="results.csv", type=str,
                    help="the file to store results")

parser.add_argument("--optimizer", default="adam", type=str, help='optimizer')
parser.add_argument('--special-label-weight', default=0, type=float,
                    help='the special_label_weight in training. default 0')



# intra-mix




# Parse the command line once, at import time.
args = parser.parse_args()

# Export the requested GPU ids first, then query CUDA and stash the resulting
# device handle and GPU count on the namespace.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
args.device = device
args.n_gpu = torch.cuda.device_count()
print("gpu num: ", args.n_gpu)

# Module-level best-F1 value, starting at 0.
# NOTE(review): presumably updated during evaluation; its use is not visible here.
best_f1 = 0





def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators from ``args.seed``.

    Args:
        args: namespace providing ``seed`` (the value fed to every generator)
            and ``gpu`` (a sized value; the CUDA generators are seeded only
            when it is non-empty).
    """
    seed = args.seed
    logger.info("random seed %s", seed)

    # Same seed for each CPU-side generator, in a fixed order.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    # Only touch the CUDA generators when at least one GPU id was requested.
    if len(args.gpu) != 0:
        torch.cuda.manual_seed_all(seed)

def count_rule_based_aug(args, tokenizer, labels, pad_token_label_id, mode,  
              omit_sep_cls_token=False,
              pad_subtoken_with_real_label=False):

    examples = read_examples_from_file_excel("data/conll2003",mode)  #examples = read_examples_from_file_excel("data/conll2003",mode)
    augexamples=[]
    totalexamplelist=copy.deepcopy(examples)
    examples_for_rule_based_aug_training=copy.deepcopy(examples)
    
    if mode =='train':

    
        
        
        import pickle 
        

        #Count all eligible examples
        num_of_aug_examples_to_generate_for_aug_type_1=10000000000000000000000000000
        num_of_aug_examples_to_generate_for_aug_type_3=10000000000000000000000000000
        num_of_aug_examples_to_generate_for_aug_type_4=10000000000000000000000000000
        augruniter=0
       
        num_of_aug_examples_generated=0
        

         
        #Pattern 1
    
        data4 = pd.read_excel ('data/conll2003/Sports.xlsx') 
        sportnames=list(data4['Names'])
        data = pd.read_excel ('data/conll2003/ORG Entity Phrases.xlsx') 
        df=data['Pattern1']
        placement=list(data['Placement'])
        phrasedomains=list(data['Domain'])
        compatibility=list(data['Compatibility'])
        labelsdf=list(data['Labels'])
        phrases=[]
        labels2=[]
        dfindex=0
        exampleindex=0
        while dfindex< len(df.index):
            phrases.append(df[dfindex].split(","))
            dfindex=dfindex+1

        for label in labelsdf:
            labels2.append(label.replace(" ", "").split("|"))


        examplesguids=[]
        selectedwords=[]
        exampleguidsalreadyaugmented=[]
        iccands=[]
        iccandbool=False

        i=0
        selectedword=''
        pattern1orgtransitioncounter=0
    

        for ex in examples:
            #if len(examplesguids)==400:
            #    break
            e=random.choice(examples)
            for word,label in zip(e.words,e.labels):
                if 'B-PER' in label or 'B-LOC' in label and 'B-MISC' not in label:
                    iccandbool=True
                    selectedword=word            
                
                elif iccandbool==True and label=='O' and i==1:
                    
                    if e.guid not in examplesguids:

                        examplesguids.append(e.guid)
                        selectedwords.append(selectedword)
                

                    break 
                
                if iccandbool==True:                
                    i=i+1        
            pattern1orgtransitioncounter=pattern1orgtransitioncounter+1
            iccandbool=False
            i=0


        #Code for putting data into examples instead of cand list of tuples
        index=0
        
        for ex in examples:
            
            if ex.guid in examplesguids and ex.guid not in exampleguidsalreadyaugmented:
            
                for guid2,selword in zip(examplesguids,selectedwords):
                    if ex.guid==guid2:
                        
                    

                        randomindex=random.randint(0,len(phrases)-1)
                    
                        
                        if placement[randomindex]=='before':
                            currentexample=copy.deepcopy(examples[index])
                            if "B-LOC" == currentexample.labels[currentexample.words.index(selword)]:
                                if currentexample.words[currentexample.words.index(selword)-1].lower()=="in":
                                    if currentexample.words.index(selword)-1==0:
                                        currentexample.words[currentexample.words.index(selword)-1]="At"
                                    else:
                                        currentexample.words[currentexample.words.index(selword)-1]="at"
                                        o=0
                                    o=0
                            if currentexample.words.index(selword)==0:
                                currentexample.words.insert(currentexample.words.index(selword),"The")
                                currentexample.labels.insert(currentexample.words.index(selword)-1,"O")
                            else:
                                currentexample.words.insert(currentexample.words.index(selword),"the")
                                currentexample.labels.insert(currentexample.words.index(selword)-1,"O")
                            
                        
                            #Sport phrase code
                            if  "<x>" in phrases[randomindex]:
                                exwords = [word.lower() for word in currentexample.words]
                                insports=False
                                
                                for sport in sportnames:                               
                                    if sport.lower() in exwords :
                                        insports=True
                                        sportphrase1=phrases[randomindex]
                                        sportphrase=copy.deepcopy(sportphrase1)
                                        pindex=sportphrase.index("<x>")
                                        
                                        sportphrase[pindex]=sport
                                        
                                    
                                        break
                                if insports==False:
                                    randomsportindex=random.randint(0,len(sportnames)-1) 
                                    sportphrase1=phrases[randomindex]
                                    sportphrase=copy.deepcopy(sportphrase1)
                                    pindex=phrases[randomindex].index("<x>")
                                    sportphrase[pindex]=sportnames[randomsportindex]

                                for word,label in zip(sportphrase,labels2[randomindex]):
                                

                                    currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                    currentexample.words.insert(currentexample.words.index(selword),word)
                                    currentexample.labels.insert(currentexample.words.index(selword),label)
                                firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex])
                                currentexample.labels[firstwordindex]='B-ORG'
                                r=0
                            #Regular insertion 
                            else: 
                            
                                for word,label in zip(phrases[randomindex],labels2[randomindex]):
                                    

                                    currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                    currentexample.words.insert(currentexample.words.index(selword),word)
                                    currentexample.labels.insert(currentexample.words.index(selword),label)
                                firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex])
                                currentexample.labels[firstwordindex]='B-ORG'
                            
                                

                        elif placement[randomindex]=='after':
                            currentexample=copy.deepcopy(examples[index])
                            if "B-LOC" == currentexample.labels[currentexample.words.index(selword)]:
                                if currentexample.words[currentexample.words.index(selword)-1].lower()=="in":
                                    if currentexample.words.index(selword)-1==0:
                                        currentexample.words[currentexample.words.index(selword)-1]="At"
                                    else:
                                        currentexample.words[currentexample.words.index(selword)-1]="at"
                            if currentexample.words.index(selword)==0:
                                currentexample.words.insert(currentexample.words.index(selword),"The")
                                currentexample.labels.insert(currentexample.words.index(selword)-1,"O")
                            else:
                                currentexample.words.insert(currentexample.words.index(selword),"the")
                                currentexample.labels.insert(currentexample.words.index(selword)-1,"O")


                            #Sport phrase code
                            if  "<x>" in phrases[randomindex]:
                                exwords = [word.lower() for word in currentexample.words]
                                insports=False
                                
                                for sport in sportnames:                               
                                    if sport.lower() in exwords :
                                        insports=True
                                        sportphrase1=phrases[randomindex]
                                        sportphrase=copy.deepcopy(sportphrase1)
                                        pindex=sportphrase.index("<x>")
                                        
                                        sportphrase[pindex]=sport
                            
                                    
                                        break
                                if insports==False:
                                    randomsportindex=random.randint(0,len(sportnames)-1) 
                                    sportphrase1=phrases[randomindex]
                                    sportphrase=copy.deepcopy(sportphrase1)
                                    pindex=phrases[randomindex].index("<x>")
                                    sportphrase[pindex]=sportnames[randomsportindex]

                                for word,label in zip(sportphrase,labels2[randomindex]):
                                

                                    currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                    currentexample.words.insert(currentexample.words.index(selword),word)
                                    currentexample.labels.insert(currentexample.words.index(selword),label)
                                firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex])
                                currentexample.labels[firstwordindex]='B-ORG'
                                

                            #Regular insertion 
                            else:                        
                    
                            
                                #Start inserting words
                                k=1
                                for word,label in zip(phrases[randomindex],labels2[randomindex]):
                                    currentexample.words.insert(currentexample.words.index(selword)+k,word)
                                    currentexample.labels.insert(currentexample.words.index(selword)+k,label)
                                    currentexample.labels[currentexample.words.index(selword)]='B-ORG'
                                    k=k+1
                        augexamples.append(currentexample)
                        exampleguidsalreadyaugmented.append(ex.guid)
                        num_of_aug_examples_generated=num_of_aug_examples_generated+1
                        break
                        
            if num_of_aug_examples_generated>=num_of_aug_examples_to_generate_for_aug_type_1:
                num_of_aug_examples_generated=0
                break
                        

            index=index+1
            

        
        dataloc = pd.read_excel ('data/conll2003/Transition To LOC.xlsx') 
        dfloc=list(dataloc['Pattern1'])
        dataloc2 = pd.read_excel ('data/conll2003/LOC Context Phrases.xlsx') 
        phrasesdfloc2=list(dataloc2['Pattern2'])
        labelsloc2=list(dataloc2['EntityType'])
        
        loclabels=[]
        phrasesloc=[]
        phrasesloc2=[]
        for locword in phrasesdfloc2:
            phrasesloc2.append(locword.split("|"))

        for label in labelsloc2:
            loclabels.append(label.replace(" ", "").split("|"))     
        
        for loc in dfloc:
            phrasesloc.append(loc.split(","))

        placementloc=list(dataloc['Placement'])

        labels3=[]
        labelset=[]
        for phrase, place in zip(phrasesloc, placementloc):
            if place=='before':
                first=True
                for word in phrase:
                    if first==True:
                        labelset.append('B-ORG')
                    else:
                        labelset.append('I-ORG')
            elif place=='after':
                
                for word in phrase:
                    labelset.append('I-ORG')
            labels3.append(labelset)
            labelset=[]
        index=0
        toloccands=[]
        iccandbool=False
        toloccounter=0
        selectedword=''
        pattern1loctransitioncounter=0
        #while  pattern1loctransitioncounter <30:
            
        for e in examples:
            if toloccounter==70:
                break

            if e.guid not in exampleguidsalreadyaugmented:
                for phrase,placementlociter in zip(phrasesloc,placementloc):
                    
                    if(set(phrase).issubset(set( e.words))):
                        
                        if placementlociter=='after':
                            
                            entityindex=e.words.index(phrase[-1])+1

                            if e.labels[entityindex]=='I-ORG':
                                #match e guid with example guid
                                index=0
                                
                                
                                for ex in examples:
                                    if ex.guid==e.guid:
                                        currentexample=copy.deepcopy(examples[index])
                                        
                                        currentexample.labels[entityindex]='B-LOC'
                                        k=1
                                        
                                        for word in phrase:
                                            
                                            currentexample.words.pop(entityindex-k)
                                            currentexample.labels.pop(entityindex-k)
                                            
                                            k=k+1
                                        
                                        if currentexample.labels[currentexample.labels.index('B-LOC')+1]=='I-ORG':
                                            currentexample.labels[currentexample.labels.index('B-LOC')+1]='I-LOC'
                                        
                                        #pattern 2
                                        randomindex=random.randint(0,len(phrasesdfloc2)-1)
                                        selword=currentexample.labels.index('B-LOC')

                                        k=1
                                        
                                        for word,label in zip(phrasesloc2[randomindex],loclabels[randomindex]):
                                            currentexample.words.insert(selword+k,word)
                                            currentexample.labels.insert(selword+k,label)
                                            
                                            k=k+1

                                        augexamples.append(currentexample)
                                        #num_of_aug_examples_generated=num_of_aug_examples_generated+1
                                        exampleguidsalreadyaugmented.append(ex.guid)
                                        toloccounter=toloccounter+1
                                        break
                                        
                                        
                                    index=index+1
                            else:
                                break

                        elif placementlociter=='before':
                            entityindex=e.words.index(phrase[0])-1
                            if e.labels[entityindex]=='B-ORG':
                                #match e guid with example guid
                                index=0
                                
                                
                                for ex in examples:
                                    if ex.guid==e.guid:
                                        currentexample=copy.deepcopy(examples[index])
                                        
                                        currentexample.labels[entityindex]='B-LOC'
                                        
                                        for word in phrase:
                                            
                                            currentexample.words.pop(entityindex+1)
                                            currentexample.labels.pop(entityindex+1)
                                        
                                        if currentexample.labels[currentexample.labels.index('B-LOC')+1]=='I-ORG':
                                            currentexample.labels[currentexample.labels.index('B-LOC')+1]='I-LOC'
                                        
                                        #pattern 2
                                        randomindex=random.randint(0,len(phrasesdfloc2)-1)
                                        selword=currentexample.labels.index('B-LOC')

                                        k=1
                                        
                                        for word,label in zip(phrasesloc2[randomindex],loclabels[randomindex]):
                                            currentexample.words.insert(selword+k,word)
                                            currentexample.labels.insert(selword+k,label)
                                            
                                            k=k+1

                                        augexamples.append(currentexample)
                                        
                                        exampleguidsalreadyaugmented.append(ex.guid)
                                        #num_of_aug_examples_generated=num_of_aug_examples_generated+1
                                        toloccounter=toloccounter+1
                                        break
                                        
                                        
                                    index=index+1

                            else:
                                break


                      

                    if e.guid in exampleguidsalreadyaugmented:
                        break
                        
           

                            
    
        dataPER = pd.read_excel ('data/conll2003/PER Context Phrases.xlsx') 
        dataPER2 = pd.read_excel ('data/conll2003/PER Headline Phrases.xlsx') 
        dataORG = pd.read_excel ('data/conll2003/ORG Context Phrases.xlsx') 
        dfPER2=list(dataPER2['Pattern2'])
        labelsPER2=list(dataPER2['EntityType'])
        compatibility2=list(dataORG['Compatibility'])

        dataNames = pd.read_excel ('data/conll2003/Names.xlsx') 
        dfNames=list(dataNames['Names'])
        dfPER=list(dataPER['Pattern2'])
        perphrasedomains=list(dataPER['Domain'])
        orgphrasedomains=list(dataORG['Domain'])
        dfORG=list(dataORG['Pattern2'])
        labelsPER=list(dataPER['EntityType'])
        labelsORG=list(dataORG['EntityType'])
        perphrases=[] 
        orgphrases=[]
        perlabels=[]  
        orglabels=[]
        perphrases2=[] 
        perlabels2=[] 


        for per2 in dfPER2:
            perphrases2.append(per2.split("|"))

        for perlabel2 in labelsPER2:
            perlabels2.append(perlabel2.replace(" ", "").split("|"))

        for per in dfPER:
            perphrases.append(per.split("|"))
        
        for org in dfORG:
            orgphrases.append(org.split("|"))

        for label in labelsPER:
            perlabels.append(label.replace(" ", "").split("|"))
        for label in labelsORG:
            orglabels.append(label.replace(" ", "").split("|"))

        percands=[]
        orgcands=[]
        iccandbool=False
        i=0
        selectedword=''
        numofexamples=0
        pertransitioncounter=0
        orgtransitioncounter=0
        originallabel=''
        examplesguids=[]
        selectedwords=[]
        #while  pertransitioncounter <30:
        for ex in examples:
            #if len(examplesguids)==400:
            #    break
            e=random.choice(examples)
            if  e.guid not in exampleguidsalreadyaugmented:

                for word,label in zip(e.words,e.labels):
                                
                    nextwordinddex=e.words.index(word)+1
                    if nextwordinddex<len(e.words):
                        w=e.words[nextwordinddex]
                    else:
                        w=''

                    if 'B-ORG' in label or 'B-LOC' in label  and 'B-MISC' not in label and ')' not in w and '.' not in word:
                        iccandbool=True
                        selectedword=word
                        originallabel=label
                    elif iccandbool==True and label=='O' and i==1:
                        if e.guid not in examplesguids:
                        
                            examplesguids.append(e.guid)
                            selectedwords.append(selectedword)

                        break 
                    
                    if iccandbool==True:                
                        i=i+1        
                
                iccandbool=False
                i=0
                pertransitioncounter=pertransitioncounter+1
        index=0
        for ex in examples:
                
                if ex.guid in examplesguids and ex.guid not in exampleguidsalreadyaugmented:
                
                    for guid2,selword in zip(examplesguids,selectedwords):
                        if ex.guid==guid2:
                                            
                            
                            currentexample=copy.deepcopy(examples[index])
                            numbercount = sum(entry.isdigit() for entry in currentexample.words)
                            if len(currentexample.words)<4 and 'AT' not in currentexample.words or numbercount>3:
                                randomindex=random.randint(0,len(dfPER2)-1)
                                k=0
                                for word,label in zip(perphrases2[randomindex],perlabels2[randomindex]):
                                    currentexample.words.insert(0+k,word)
                                    currentexample.labels.insert(0+k,label)
                                    currentexample.labels[currentexample.words.index(selword)]='B-PER'
                                    k=k+1
                            else:
                                randomindex=random.randint(0,len(dfPER)-1)        
                                k=1
                                for word,label in zip(perphrases[randomindex],perlabels[randomindex]):
                                    currentexample.words.insert(currentexample.words.index(selword)+k,word)
                                    currentexample.labels.insert(currentexample.words.index(selword)+k,label)
                                    currentexample.labels[currentexample.words.index(selword)]='B-PER'
                                    k=k+1

                            #Pattern 1
                            randomindex3=random.randint(0,len(dfNames)-1)
                            currentexample.words.insert(currentexample.words.index(selword)+1,dfNames[randomindex3])
                            currentexample.labels.insert(currentexample.words.index(selword)+1,'I-PER')
                            augexamples.append(currentexample)
                            exampleguidsalreadyaugmented.append(currentexample.guid)
                            num_of_aug_examples_generated=num_of_aug_examples_generated+1
                            
                            break

                index=index+1
                if num_of_aug_examples_generated>=num_of_aug_examples_to_generate_for_aug_type_3:
                    num_of_aug_examples_generated=0
                    break
        examplesguids=[]    
        selectedwords=[]

        
            
        
        for ex in examples:
 

            e=random.choice(examples)
            if  e.guid not in exampleguidsalreadyaugmented:
                for word,label in zip(e.words,e.labels):
                    if 'B-PER' in label or 'B-LOC' in label  and 'B-MISC' not in label:
                        iccandbool=True
                        selectedword=word
                        originallabel=label
                    elif iccandbool==True and label=='O' and i==1:
                    
                        if e.guid not in examplesguids:
                            examplesguids.append(e.guid)
                            selectedwords.append(selectedword)

                        break 
                    
                    if iccandbool==True:                
                        i=i+1        
                
                iccandbool=False
                i=0
                orgtransitioncounter=orgtransitioncounter+1

        j=0
        index=0
        for ex in examples:
                
                if ex.guid in examplesguids and ex.guid not in exampleguidsalreadyaugmented:
                    
                    for guid2,selword in zip(examplesguids,selectedwords):
                        if ex.guid==guid2:
                        

                            
                            currentexample=copy.deepcopy(examples[index])
                                                        
                            randomindex2=random.randint(0,len(phrases)-1)
                            randomindex=random.randint(0,len(orgphrases)-1)

                                    
                            k=1
                            for word,label in zip(orgphrases[randomindex],orglabels[randomindex]):
                                currentexample.words.insert(currentexample.words.index(selword)+k,word)
                                currentexample.labels.insert(currentexample.words.index(selword)+k,label)
                                currentexample.labels[currentexample.words.index(selword)]='B-ORG'
                                k=k+1
                      
                            
                            if placement[randomindex2]=='before':

                                if "B-LOC" == currentexample.labels[currentexample.words.index(selword)]:
                                    if currentexample.words[currentexample.words.index(selword)-1].lower()=="in":
                                        if currentexample.words.index(selword)-1==0:
                                            currentexample.words[currentexample.words.index(selword)-1]="At"
                                        else:
                                            currentexample.words[currentexample.words.index(selword)-1]="at"
                                if currentexample.words.index(selword)==0:
                                    currentexample.words.insert(currentexample.words.index(selword),"The")
                                    currentexample.labels.insert(currentexample.words.index(selword)-1,"O")
                                else:
                                    currentexample.words.insert(currentexample.words.index(selword),"the")
                                    currentexample.labels.insert(currentexample.words.index(selword)-1,"O")



                                
                                #Sport phrase code
                                if  "<x>" in phrases[randomindex2]:
                                    exwords = [word.lower() for word in currentexample.words]
                                    insports=False
                                    
                                    for sport in sportnames:                               
                                        if sport.lower() in exwords :
                                            insports=True
                                            sportphrase1=phrases[randomindex2]
                                            sportphrase=copy.deepcopy(sportphrase1)
                                            pindex=sportphrase.index("<x>")
                                            
                                            sportphrase[pindex]=sport
                                    
                                        
                                            break
                                    if insports==False:
                                        randomsportindex=random.randint(0,len(sportnames)-1) 
                                        sportphrase1=phrases[randomindex2]
                                        sportphrase=copy.deepcopy(sportphrase1)
                                        pindex=phrases[randomindex2].index("<x>")
                                        sportphrase[pindex]=sportnames[randomsportindex]

                                    for word,label in zip(sportphrase,labels2[randomindex2]):
                                    

                                        currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                        currentexample.words.insert(currentexample.words.index(selword),word)
                                        currentexample.labels.insert(currentexample.words.index(selword),label)
                                    firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex2])
                                    currentexample.labels[firstwordindex]='B-ORG'
                                    

                                #Regular insertion 
                                else:
                                        
                                
                                    for word,label in zip(phrases[randomindex2],labels2[randomindex2]):
                                        currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                        currentexample.words.insert(currentexample.words.index(selword),word)
                                        currentexample.labels.insert(currentexample.words.index(selword),label)
                                    firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex2])
                                    currentexample.labels[firstwordindex]='B-ORG'


                            elif placement[randomindex2]=='after':
                                if "B-LOC" == currentexample.labels[currentexample.words.index(selword)]:
                                    if currentexample.words[currentexample.words.index(selword)-1].lower()=="in":
                                        if currentexample.words.index(selword)-1==0:
                                            currentexample.words[currentexample.words.index(selword)-1]="At"
                                        else:
                                            currentexample.words[currentexample.words.index(selword)-1]="at"
                                if currentexample.words.index(selword)==0:
                                    currentexample.words.insert(currentexample.words.index(selword),"The")
                                    currentexample.labels.insert(currentexample.words.index(selword)-1,"O")
                                else:
                                    currentexample.words.insert(currentexample.words.index(selword),"the")
                                    currentexample.labels.insert(currentexample.words.index(selword)-1,"O")


                                
            
                                #Sport phrase code
                                if  "<x>" in phrases[randomindex2]:
                                    exwords = [word.lower() for word in currentexample.words]
                                    insports=False
                            
                                    for sport in sportnames:                               
                                        if sport.lower() in exwords :
                                            insports=True
                                            sportphrase1=phrases[randomindex2]
                                            sportphrase=copy.deepcopy(sportphrase1)
                                            pindex=sportphrase.index("<x>")
                                            
                                            sportphrase[pindex]=sport
                                        
                                        
                                            break
                                    if insports==False:
                                        randomsportindex=random.randint(0,len(sportnames)-1) 
                                        sportphrase1=phrases[randomindex2]
                                        sportphrase=copy.deepcopy(sportphrase1)
                                        pindex=phrases[randomindex2].index("<x>")
                                        sportphrase[pindex]=sportnames[randomsportindex]

                                    for word,label in zip(sportphrase,labels2[randomindex2]):
                                    

                                        currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                        currentexample.words.insert(currentexample.words.index(selword),word)
                                        currentexample.labels.insert(currentexample.words.index(selword),label)
                                    firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex2])
                                    currentexample.labels[firstwordindex]='B-ORG'
                                    

                                #Regular insertion 
                                else:
                        
                                    k=1
                                    for word,label in zip(phrases[randomindex2],labels2[randomindex2]):
                                        currentexample.words.insert(currentexample.words.index(selword)+k,word)
                                        currentexample.labels.insert(currentexample.words.index(selword)+k,label)
                                        currentexample.labels[currentexample.words.index(selword)]='B-ORG'
                                        k=k+1
                                        
                            exampleguidsalreadyaugmented.append(currentexample.guid)
                            augexamples.append(currentexample)
                            num_of_aug_examples_generated=num_of_aug_examples_generated+1
                            

                            break

                index=index+1
                
                if num_of_aug_examples_generated>=num_of_aug_examples_to_generate_for_aug_type_4:
                    num_of_aug_examples_generated=0
                    break
        logger.info("Number of augmented examples: %s", len(augexamples))
        print("Number of augmented examples",len(augexamples))
        file_name="zeroshotaugexamples100percent_no_held_out_phrases.pkl"

        file_path = os.path.join(args.data_dir, file_name)
        
        open_file = open(file_path, "wb")
        pickle.dump(augexamples, open_file)
        open_file.close()
        return len(augexamples)

def read_data_rule_based_aug(args,amount, percentagename,tokenizer, labels, pad_token_label_id, mode,  
              omit_sep_cls_token=False,
              pad_subtoken_with_real_label=False):

    examples = read_examples_from_file_excel("data/conll2003",mode)  #examples = read_examples_from_file_excel("data/conll2003",mode)
    augexamples=[]
    totalexamplelist=copy.deepcopy(examples)
    examples_for_rule_based_aug_training=copy.deepcopy(examples)
    
    if mode =='train':

    
        
        
        import pickle 
        

        
       

        num_of_aug_examples_to_generate_minus_loc_transition=amount-70
        num_of_aug_examples_to_generate=num_of_aug_examples_to_generate_minus_loc_transition/3
        num_of_aug_examples_to_generate=round(num_of_aug_examples_to_generate)

        num_of_aug_examples_to_generate_for_aug_type_1=num_of_aug_examples_to_generate
        num_of_aug_examples_to_generate_for_aug_type_3=num_of_aug_examples_to_generate
        num_of_aug_examples_to_generate_for_aug_type_4=num_of_aug_examples_to_generate

        augruniter=0
       
        num_of_aug_examples_generated=0
        

         
        #Pattern 1
    
        data4 = pd.read_excel ('data/conll2003/Sports.xlsx') 
        sportnames=list(data4['Names'])
        data = pd.read_excel ('data/conll2003/ORG Entity Phrases.xlsx') 
        df=data['Pattern1']
        placement=list(data['Placement'])
        phrasedomains=list(data['Domain'])
        compatibility=list(data['Compatibility'])
        labelsdf=list(data['Labels'])
        phrases=[]
        labels2=[]
        dfindex=0
        exampleindex=0
        while dfindex< len(df.index):
            phrases.append(df[dfindex].split(","))
            dfindex=dfindex+1

        for label in labelsdf:
            labels2.append(label.replace(" ", "").split("|"))


        examplesguids=[]
        selectedwords=[]
        exampleguidsalreadyaugmented=[]
        iccands=[]
        iccandbool=False

        i=0
        selectedword=''
        pattern1orgtransitioncounter=0
    

        for ex in examples:
            #if len(examplesguids)==400:
            #    break
            e=random.choice(examples)
            for word,label in zip(e.words,e.labels):
                if 'B-PER' in label or 'B-LOC' in label and 'B-MISC' not in label:
                    iccandbool=True
                    selectedword=word            
                
                elif iccandbool==True and label=='O' and i==1:
                    
                    if e.guid not in examplesguids:

                        examplesguids.append(e.guid)
                        selectedwords.append(selectedword)
                

                    break 
                
                if iccandbool==True:                
                    i=i+1        
            pattern1orgtransitioncounter=pattern1orgtransitioncounter+1
            iccandbool=False
            i=0


        #Code for putting data into examples instead of cand list of tuples
        index=0
        
        for ex in examples:
            
            if ex.guid in examplesguids and ex.guid not in exampleguidsalreadyaugmented:
            
                for guid2,selword in zip(examplesguids,selectedwords):
                    if ex.guid==guid2:
                        
                    

                        randomindex=random.randint(0,len(phrases)-1)
                    
                        
                        if placement[randomindex]=='before':
                            currentexample=copy.deepcopy(examples[index])
                            if "B-LOC" == currentexample.labels[currentexample.words.index(selword)]:
                                if currentexample.words[currentexample.words.index(selword)-1].lower()=="in":
                                    if currentexample.words.index(selword)-1==0:
                                        currentexample.words[currentexample.words.index(selword)-1]="At"
                                    else:
                                        currentexample.words[currentexample.words.index(selword)-1]="at"
                                        o=0
                                    o=0
                            if currentexample.words.index(selword)==0:
                                currentexample.words.insert(currentexample.words.index(selword),"The")
                                currentexample.labels.insert(currentexample.words.index(selword)-1,"O")
                            else:
                                currentexample.words.insert(currentexample.words.index(selword),"the")
                                currentexample.labels.insert(currentexample.words.index(selword)-1,"O")
                            
                        
                            #Sport phrase code
                            if  "<x>" in phrases[randomindex]:
                                exwords = [word.lower() for word in currentexample.words]
                                insports=False
                                
                                for sport in sportnames:                               
                                    if sport.lower() in exwords :
                                        insports=True
                                        sportphrase1=phrases[randomindex]
                                        sportphrase=copy.deepcopy(sportphrase1)
                                        pindex=sportphrase.index("<x>")
                                        
                                        sportphrase[pindex]=sport
                                        
                                    
                                        break
                                if insports==False:
                                    randomsportindex=random.randint(0,len(sportnames)-1) 
                                    sportphrase1=phrases[randomindex]
                                    sportphrase=copy.deepcopy(sportphrase1)
                                    pindex=phrases[randomindex].index("<x>")
                                    sportphrase[pindex]=sportnames[randomsportindex]

                                for word,label in zip(sportphrase,labels2[randomindex]):
                                

                                    currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                    currentexample.words.insert(currentexample.words.index(selword),word)
                                    currentexample.labels.insert(currentexample.words.index(selword),label)
                                firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex])
                                currentexample.labels[firstwordindex]='B-ORG'
                                r=0
                            #Regular insertion 
                            else: 
                            
                                for word,label in zip(phrases[randomindex],labels2[randomindex]):
                                    

                                    currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                    currentexample.words.insert(currentexample.words.index(selword),word)
                                    currentexample.labels.insert(currentexample.words.index(selword),label)
                                firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex])
                                currentexample.labels[firstwordindex]='B-ORG'
                            
                                

                        elif placement[randomindex]=='after':
                            currentexample=copy.deepcopy(examples[index])
                            if "B-LOC" == currentexample.labels[currentexample.words.index(selword)]:
                                if currentexample.words[currentexample.words.index(selword)-1].lower()=="in":
                                    if currentexample.words.index(selword)-1==0:
                                        currentexample.words[currentexample.words.index(selword)-1]="At"
                                    else:
                                        currentexample.words[currentexample.words.index(selword)-1]="at"
                            if currentexample.words.index(selword)==0:
                                currentexample.words.insert(currentexample.words.index(selword),"The")
                                currentexample.labels.insert(currentexample.words.index(selword)-1,"O")
                            else:
                                currentexample.words.insert(currentexample.words.index(selword),"the")
                                currentexample.labels.insert(currentexample.words.index(selword)-1,"O")


                            #Sport phrase code
                            if  "<x>" in phrases[randomindex]:
                                exwords = [word.lower() for word in currentexample.words]
                                insports=False
                                
                                for sport in sportnames:                               
                                    if sport.lower() in exwords :
                                        insports=True
                                        sportphrase1=phrases[randomindex]
                                        sportphrase=copy.deepcopy(sportphrase1)
                                        pindex=sportphrase.index("<x>")
                                        
                                        sportphrase[pindex]=sport
                            
                                    
                                        break
                                if insports==False:
                                    randomsportindex=random.randint(0,len(sportnames)-1) 
                                    sportphrase1=phrases[randomindex]
                                    sportphrase=copy.deepcopy(sportphrase1)
                                    pindex=phrases[randomindex].index("<x>")
                                    sportphrase[pindex]=sportnames[randomsportindex]

                                for word,label in zip(sportphrase,labels2[randomindex]):
                                

                                    currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                    currentexample.words.insert(currentexample.words.index(selword),word)
                                    currentexample.labels.insert(currentexample.words.index(selword),label)
                                firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex])
                                currentexample.labels[firstwordindex]='B-ORG'
                                

                            #Regular insertion 
                            else:                        
                    
                            
                                #Start inserting words
                                k=1
                                for word,label in zip(phrases[randomindex],labels2[randomindex]):
                                    currentexample.words.insert(currentexample.words.index(selword)+k,word)
                                    currentexample.labels.insert(currentexample.words.index(selword)+k,label)
                                    currentexample.labels[currentexample.words.index(selword)]='B-ORG'
                                    k=k+1
                        augexamples.append(currentexample)
                        exampleguidsalreadyaugmented.append(ex.guid)
                        if num_of_aug_examples_generated>=num_of_aug_examples_to_generate_for_aug_type_1:
                            num_of_aug_examples_generated=0
                            break
                        num_of_aug_examples_generated=num_of_aug_examples_generated+1
                        break
                        
            if num_of_aug_examples_generated>=num_of_aug_examples_to_generate_for_aug_type_1:
                num_of_aug_examples_generated=0
                break
                        

            index=index+1
            
   
        
        dataloc = pd.read_excel ('data/conll2003/Transition To LOC.xlsx') 
        dfloc=list(dataloc['Pattern1'])
        dataloc2 = pd.read_excel ('data/conll2003/LOC Context Phrases.xlsx') 
        phrasesdfloc2=list(dataloc2['Pattern2'])
        labelsloc2=list(dataloc2['EntityType'])
        
        loclabels=[]
        phrasesloc=[]
        phrasesloc2=[]
        for locword in phrasesdfloc2:
            phrasesloc2.append(locword.split("|"))

        for label in labelsloc2:
            loclabels.append(label.replace(" ", "").split("|"))     
        
        for loc in dfloc:
            phrasesloc.append(loc.split(","))

        placementloc=list(dataloc['Placement'])

        labels3=[]
        labelset=[]
        for phrase, place in zip(phrasesloc, placementloc):
            if place=='before':
                first=True
                for word in phrase:
                    if first==True:
                        labelset.append('B-ORG')
                    else:
                        labelset.append('I-ORG')
            elif place=='after':
                
                for word in phrase:
                    labelset.append('I-ORG')
            labels3.append(labelset)
            labelset=[]
        index=0
        toloccands=[]
        iccandbool=False
        toloccounter=0
        selectedword=''
        pattern1loctransitioncounter=0

            
        for e in examples:
            if toloccounter==70:
                break

            if e.guid not in exampleguidsalreadyaugmented:
                for phrase,placementlociter in zip(phrasesloc,placementloc):
                    
                    if(set(phrase).issubset(set( e.words))):
                        
                        if placementlociter=='after':
                            
                            entityindex=e.words.index(phrase[-1])+1

                            if e.labels[entityindex]=='I-ORG':
                                #match e guid with example guid
                                index=0
                                
                                
                                for ex in examples:
                                    if ex.guid==e.guid:
                                        currentexample=copy.deepcopy(examples[index])
                                        
                                        currentexample.labels[entityindex]='B-LOC'
                                        k=1
                                        
                                        for word in phrase:
                                            
                                            currentexample.words.pop(entityindex-k)
                                            currentexample.labels.pop(entityindex-k)
                                            
                                            k=k+1
                                        
                                        if currentexample.labels[currentexample.labels.index('B-LOC')+1]=='I-ORG':
                                            currentexample.labels[currentexample.labels.index('B-LOC')+1]='I-LOC'
                                        
                                        #pattern 2
                                        randomindex=random.randint(0,len(phrasesdfloc2)-1)
                                        selword=currentexample.labels.index('B-LOC')

                                        k=1
                                        
                                        for word,label in zip(phrasesloc2[randomindex],loclabels[randomindex]):
                                            currentexample.words.insert(selword+k,word)
                                            currentexample.labels.insert(selword+k,label)
                                            
                                            k=k+1

                                        augexamples.append(currentexample)
                                        #num_of_aug_examples_generated=num_of_aug_examples_generated+1
                                        exampleguidsalreadyaugmented.append(ex.guid)
                                        toloccounter=toloccounter+1
                                        break
                                        
                                        
                                    index=index+1
                            else:
                                break

                        elif placementlociter=='before':
                            entityindex=e.words.index(phrase[0])-1
                            if e.labels[entityindex]=='B-ORG':
                                #match e guid with example guid
                                index=0
                                
                                
                                for ex in examples:
                                    if ex.guid==e.guid:
                                        currentexample=copy.deepcopy(examples[index])
                                        
                                        currentexample.labels[entityindex]='B-LOC'
                                        
                                        for word in phrase:
                                            
                                            currentexample.words.pop(entityindex+1)
                                            currentexample.labels.pop(entityindex+1)
                                        
                                        if currentexample.labels[currentexample.labels.index('B-LOC')+1]=='I-ORG':
                                            currentexample.labels[currentexample.labels.index('B-LOC')+1]='I-LOC'
                                        
                                        #pattern 2
                                        randomindex=random.randint(0,len(phrasesdfloc2)-1)
                                        selword=currentexample.labels.index('B-LOC')

                                        k=1
                                        
                                        for word,label in zip(phrasesloc2[randomindex],loclabels[randomindex]):
                                            currentexample.words.insert(selword+k,word)
                                            currentexample.labels.insert(selword+k,label)
                                            
                                            k=k+1

                                        augexamples.append(currentexample)
                                        
                                        exampleguidsalreadyaugmented.append(ex.guid)
                                        #num_of_aug_examples_generated=num_of_aug_examples_generated+1
                                        toloccounter=toloccounter+1
                                        break
                                        
                                        
                                    index=index+1

                            else:
                                break


                    

                    if e.guid in exampleguidsalreadyaugmented:
                        break
                        
          

                            
        
  
        dataPER = pd.read_excel ('data/conll2003/PER Context Phrases.xlsx') 
        dataPER2 = pd.read_excel ('data/conll2003/PER Headline Phrases.xlsx') 
        dataORG = pd.read_excel ('data/conll2003/ORG Context Phrases.xlsx') 
        dfPER2=list(dataPER2['Pattern2'])
        labelsPER2=list(dataPER2['EntityType'])
        compatibility2=list(dataORG['Compatibility'])

        dataNames = pd.read_excel ('data/conll2003/Names.xlsx') 
        dfNames=list(dataNames['Names'])
        dfPER=list(dataPER['Pattern2'])
        perphrasedomains=list(dataPER['Domain'])
        orgphrasedomains=list(dataORG['Domain'])
        dfORG=list(dataORG['Pattern2'])
        labelsPER=list(dataPER['EntityType'])
        labelsORG=list(dataORG['EntityType'])
        perphrases=[] 
        orgphrases=[]
        perlabels=[]  
        orglabels=[]
        perphrases2=[] 
        perlabels2=[] 


        for per2 in dfPER2:
            perphrases2.append(per2.split("|"))

        for perlabel2 in labelsPER2:
            perlabels2.append(perlabel2.replace(" ", "").split("|"))

        for per in dfPER:
            perphrases.append(per.split("|"))
        
        for org in dfORG:
            orgphrases.append(org.split("|"))

        for label in labelsPER:
            perlabels.append(label.replace(" ", "").split("|"))
        for label in labelsORG:
            orglabels.append(label.replace(" ", "").split("|"))

        percands=[]
        orgcands=[]
        iccandbool=False
        i=0
        selectedword=''
        numofexamples=0
        pertransitioncounter=0
        orgtransitioncounter=0
        originallabel=''
        examplesguids=[]
        selectedwords=[]
        #while  pertransitioncounter <30:
        for ex in examples:
            #if len(examplesguids)==400:
            #    break
            e=random.choice(examples)
            if  e.guid not in exampleguidsalreadyaugmented:

                for word,label in zip(e.words,e.labels):
                                
                    nextwordinddex=e.words.index(word)+1
                    if nextwordinddex<len(e.words):
                        w=e.words[nextwordinddex]
                    else:
                        w=''

                    if 'B-ORG' in label or 'B-LOC' in label  and 'B-MISC' not in label and ')' not in w and '.' not in word:
                        iccandbool=True
                        selectedword=word
                        originallabel=label
                    elif iccandbool==True and label=='O' and i==1:
                        if e.guid not in examplesguids:
                        
                            examplesguids.append(e.guid)
                            selectedwords.append(selectedword)

                        break 
                    
                    if iccandbool==True:                
                        i=i+1        
                
                iccandbool=False
                i=0
                pertransitioncounter=pertransitioncounter+1
        index=0
        for ex in examples:
                
                if ex.guid in examplesguids and ex.guid not in exampleguidsalreadyaugmented:
                
                    for guid2,selword in zip(examplesguids,selectedwords):
                        if ex.guid==guid2:
                                            
                            
                            currentexample=copy.deepcopy(examples[index])
                            numbercount = sum(entry.isdigit() for entry in currentexample.words)
                            if len(currentexample.words)<4 and 'AT' not in currentexample.words or numbercount>3:
                                randomindex=random.randint(0,len(dfPER2)-1)
                                k=0
                                for word,label in zip(perphrases2[randomindex],perlabels2[randomindex]):
                                    currentexample.words.insert(0+k,word)
                                    currentexample.labels.insert(0+k,label)
                                    currentexample.labels[currentexample.words.index(selword)]='B-PER'
                                    k=k+1
                            else:
                                randomindex=random.randint(0,len(dfPER)-1)        
                                k=1
                                for word,label in zip(perphrases[randomindex],perlabels[randomindex]):
                                    currentexample.words.insert(currentexample.words.index(selword)+k,word)
                                    currentexample.labels.insert(currentexample.words.index(selword)+k,label)
                                    currentexample.labels[currentexample.words.index(selword)]='B-PER'
                                    k=k+1

                            #Pattern 1
                            randomindex3=random.randint(0,len(dfNames)-1)
                            currentexample.words.insert(currentexample.words.index(selword)+1,dfNames[randomindex3])
                            currentexample.labels.insert(currentexample.words.index(selword)+1,'I-PER')
                            augexamples.append(currentexample)
                            exampleguidsalreadyaugmented.append(currentexample.guid)
                            if num_of_aug_examples_generated>=num_of_aug_examples_to_generate_for_aug_type_3:
                                num_of_aug_examples_generated=0
                                break
                            num_of_aug_examples_generated=num_of_aug_examples_generated+1
                            
                            break

                index=index+1
                if num_of_aug_examples_generated>=num_of_aug_examples_to_generate_for_aug_type_3:
                    num_of_aug_examples_generated=0
                    break
        examplesguids=[]    
        selectedwords=[]

        #if num_of_aug_examples_generated<num_of_aug_examples_to_generate:
            
        
        for ex in examples:
            #if len(examplesguids)==400:
            #     break
            e=random.choice(examples)
            if  e.guid not in exampleguidsalreadyaugmented:
                for word,label in zip(e.words,e.labels):
                    if 'B-PER' in label or 'B-LOC' in label  and 'B-MISC' not in label:
                        iccandbool=True
                        selectedword=word
                        originallabel=label
                    elif iccandbool==True and label=='O' and i==1:
                    
                        if e.guid not in examplesguids:
                            examplesguids.append(e.guid)
                            selectedwords.append(selectedword)

                        break 
                    
                    if iccandbool==True:                
                        i=i+1        
                
                iccandbool=False
                i=0
                orgtransitioncounter=orgtransitioncounter+1

        j=0
        index=0
        for ex in examples:
                
                if ex.guid in examplesguids and ex.guid not in exampleguidsalreadyaugmented:
                    
                    for guid2,selword in zip(examplesguids,selectedwords):
                        if ex.guid==guid2:
                        

                            
                            currentexample=copy.deepcopy(examples[index])
                                                        
                            randomindex2=random.randint(0,len(phrases)-1)
                            randomindex=random.randint(0,len(orgphrases)-1)

                                    
                            k=1
                            for word,label in zip(orgphrases[randomindex],orglabels[randomindex]):
                                currentexample.words.insert(currentexample.words.index(selword)+k,word)
                                currentexample.labels.insert(currentexample.words.index(selword)+k,label)
                                currentexample.labels[currentexample.words.index(selword)]='B-ORG'
                                k=k+1
                    
                            
                            if placement[randomindex2]=='before':

                                if "B-LOC" == currentexample.labels[currentexample.words.index(selword)]:
                                    if currentexample.words[currentexample.words.index(selword)-1].lower()=="in":
                                        if currentexample.words.index(selword)-1==0:
                                            currentexample.words[currentexample.words.index(selword)-1]="At"
                                        else:
                                            currentexample.words[currentexample.words.index(selword)-1]="at"
                                if currentexample.words.index(selword)==0:
                                    currentexample.words.insert(currentexample.words.index(selword),"The")
                                    currentexample.labels.insert(currentexample.words.index(selword)-1,"O")
                                else:
                                    currentexample.words.insert(currentexample.words.index(selword),"the")
                                    currentexample.labels.insert(currentexample.words.index(selword)-1,"O")



                                
                                #Sport phrase code
                                if  "<x>" in phrases[randomindex2]:
                                    exwords = [word.lower() for word in currentexample.words]
                                    insports=False
                                    
                                    for sport in sportnames:                               
                                        if sport.lower() in exwords :
                                            insports=True
                                            sportphrase1=phrases[randomindex2]
                                            sportphrase=copy.deepcopy(sportphrase1)
                                            pindex=sportphrase.index("<x>")
                                            
                                            sportphrase[pindex]=sport
                                    
                                        
                                            break
                                    if insports==False:
                                        randomsportindex=random.randint(0,len(sportnames)-1) 
                                        sportphrase1=phrases[randomindex2]
                                        sportphrase=copy.deepcopy(sportphrase1)
                                        pindex=phrases[randomindex2].index("<x>")
                                        sportphrase[pindex]=sportnames[randomsportindex]

                                    for word,label in zip(sportphrase,labels2[randomindex2]):
                                    

                                        currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                        currentexample.words.insert(currentexample.words.index(selword),word)
                                        currentexample.labels.insert(currentexample.words.index(selword),label)
                                    firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex2])
                                    currentexample.labels[firstwordindex]='B-ORG'
                                    

                                #Regular insertion 
                                else:
                                        
                                
                                    for word,label in zip(phrases[randomindex2],labels2[randomindex2]):
                                        currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                        currentexample.words.insert(currentexample.words.index(selword),word)
                                        currentexample.labels.insert(currentexample.words.index(selword),label)
                                    firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex2])
                                    currentexample.labels[firstwordindex]='B-ORG'


                            elif placement[randomindex2]=='after':
                                if "B-LOC" == currentexample.labels[currentexample.words.index(selword)]:
                                    if currentexample.words[currentexample.words.index(selword)-1].lower()=="in":
                                        if currentexample.words.index(selword)-1==0:
                                            currentexample.words[currentexample.words.index(selword)-1]="At"
                                        else:
                                            currentexample.words[currentexample.words.index(selword)-1]="at"
                                if currentexample.words.index(selword)==0:
                                    currentexample.words.insert(currentexample.words.index(selword),"The")
                                    currentexample.labels.insert(currentexample.words.index(selword)-1,"O")
                                else:
                                    currentexample.words.insert(currentexample.words.index(selword),"the")
                                    currentexample.labels.insert(currentexample.words.index(selword)-1,"O")


                                
            
                                #Sport phrase code
                                if  "<x>" in phrases[randomindex2]:
                                    exwords = [word.lower() for word in currentexample.words]
                                    insports=False
                            
                                    for sport in sportnames:                               
                                        if sport.lower() in exwords :
                                            insports=True
                                            sportphrase1=phrases[randomindex2]
                                            sportphrase=copy.deepcopy(sportphrase1)
                                            pindex=sportphrase.index("<x>")
                                            
                                            sportphrase[pindex]=sport
                                        
                                        
                                            break
                                    if insports==False:
                                        randomsportindex=random.randint(0,len(sportnames)-1) 
                                        sportphrase1=phrases[randomindex2]
                                        sportphrase=copy.deepcopy(sportphrase1)
                                        pindex=phrases[randomindex2].index("<x>")
                                        sportphrase[pindex]=sportnames[randomsportindex]

                                    for word,label in zip(sportphrase,labels2[randomindex2]):
                                    

                                        currentexample.labels[currentexample.words.index(selword)]='I-ORG'
                                        currentexample.words.insert(currentexample.words.index(selword),word)
                                        currentexample.labels.insert(currentexample.words.index(selword),label)
                                    firstwordindex=currentexample.words.index(selword)-len(phrases[randomindex2])
                                    currentexample.labels[firstwordindex]='B-ORG'
                                    

                                #Regular insertion 
                                else:
                        
                                    k=1
                                    for word,label in zip(phrases[randomindex2],labels2[randomindex2]):
                                        currentexample.words.insert(currentexample.words.index(selword)+k,word)
                                        currentexample.labels.insert(currentexample.words.index(selword)+k,label)
                                        currentexample.labels[currentexample.words.index(selword)]='B-ORG'
                                        k=k+1
                                        
                            exampleguidsalreadyaugmented.append(currentexample.guid)
                            augexamples.append(currentexample)
                            if num_of_aug_examples_generated>=num_of_aug_examples_to_generate_for_aug_type_4:
                                num_of_aug_examples_generated=0
                                break
                            num_of_aug_examples_generated=num_of_aug_examples_generated+1
                            

                            break

                index=index+1
                
                if num_of_aug_examples_generated>=num_of_aug_examples_to_generate_for_aug_type_4:
                    num_of_aug_examples_generated=0
                    break


        logger.info("Number of augmented examples: %s", len(augexamples))
        file_name=percentagename+".pkl"
        open_file = open(file_name, "wb")
        pickle.dump(augexamples, open_file)
        open_file.close()
     

def generate_aug_percentage_categories(args, tokenizer, labels, pad_token_label_id, mode,  
              omit_sep_cls_token=False,
              pad_subtoken_with_real_label=False):   
    """Generate rule-based augmentation sets at 5%, 10%, 30% and 50% of the base count.

    The base count comes from ``count_rule_based_aug``; for every fraction the
    scaled (float) count and an output name are handed to
    ``read_data_rule_based_aug``.

    Note: the ``mode``, ``omit_sep_cls_token`` and
    ``pad_subtoken_with_real_label`` parameters are accepted but not read here;
    the helpers are always invoked with ``mode='train'`` and
    ``args.pad_subtoken_with_real_label``.
    """
    base_count = count_rule_based_aug(
        args, tokenizer, labels, pad_token_label_id,
        mode='train',
        pad_subtoken_with_real_label=args.pad_subtoken_with_real_label,
    )

    # (fraction of the base count, output name) -- processed in this order.
    # NOTE(review): the output name presumably becomes "<name>.pkl" inside
    # read_data_rule_based_aug -- confirm against that function.
    targets = (
        (0.05, "zeroshotaugexamples5percent_no_held_out_phrases"),
        (0.1, "zeroshotaugexamples10percent_no_held_out_phrases"),
        (0.3, "zeroshotaugexamples30percent_no_held_out_phrases"),
        (0.5, "zeroshotaugexamples50percent_no_held_out_phrases"),
    )
    for fraction, output_name in targets:
        read_data_rule_based_aug(
            args, base_count * fraction, output_name, tokenizer, labels, pad_token_label_id,
            mode='train',
            pad_subtoken_with_real_label=args.pad_subtoken_with_real_label,
        )

        
   
    
    
    
def linear_rampup(current, rampup_length=args.num_train_epochs):
    """Return ``current / rampup_length`` clamped to the interval [0.0, 1.0].

    A ``rampup_length`` of zero short-circuits to ``1.0``.

    Note: the default for ``rampup_length`` is read from
    ``args.num_train_epochs`` once, when this function is defined; later
    changes to ``args.num_train_epochs`` do not affect it.
    """
    if rampup_length == 0:
        return 1.0
    progress = current / rampup_length
    return float(np.clip(progress, 0.0, 1.0))


def _save_best_checkpoint(args, model, tokenizer, optimizer, scheduler, epoch, epochs_trained):
    """Write the current model and its training state to ``<args.output_dir>/best``.

    Saves the model (unwrapped from ``DataParallel`` when needed), the
    tokenizer, the ``args`` namespace, and the optimizer/scheduler state dicts.
    """
    output_dir = os.path.join(args.output_dir, "best")
    os.makedirs(output_dir, exist_ok=True)
    logger.info("Saving best model to %s", output_dir)
    logger.info("Epochs trained is %s", epochs_trained)
    logger.info("Epoch is %s", epoch)
    # DataParallel keeps the real model under ``.module``; save that one.
    model_to_save = model.module if hasattr(model, "module") else model
    model_to_save.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, "training_args.bin"))
    logger.info("Saving model checkpoint to %s", output_dir)

    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
    logger.info("Saving optimizer and scheduler states to %s", output_dir)


def train(args,train_dataset, eval_dataset, test_dataset_regular,test_dataset_challenging, model, tokenizer, labels, pad_token_label_id):
    """Fine-tune ``model`` on ``train_dataset``, checkpointing whenever dev F1 improves.

    When ``args.evaluate_during_training`` is set, the dev set is evaluated
    every ``args.logging_steps`` optimizer steps and again at the end of every
    epoch. Each time the dev F1 is at least the module-level ``best_f1``, both
    test sets are evaluated and the model is saved to ``<args.output_dir>/best``.

    Side effects: updates the global ``best_f1``; may overwrite
    ``args.num_train_epochs`` (when ``args.max_steps > 0``); sets
    ``args.tb_writer_logdir``; writes TensorBoard scalars.

    Returns:
        ``(global_step, average_loss, test_f1_regular, test_f1_challenging)``
        where the two lists hold the test F1 recorded at each new best dev
        score. ``average_loss`` is ``0.0`` if no optimizer step was taken.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
    """
    global best_f1

    # Fail fast with a clear message instead of an UnboundLocalError when the
    # scheduler is later built from an optimizer that was never created.
    if args.optimizer not in ('adam', 'sgd'):
        raise ValueError(
            "Unsupported optimizer '{}'; expected 'adam' or 'sgd'.".format(args.optimizer)
        )

    tb_writer = SummaryWriter()
    print('tb_writer.logdir',tb_writer.logdir)

    train_dataloader = DataLoader(train_dataset, batch_size = args.batch_size, shuffle = True)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Bias and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    if args.optimizer=='adam':
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        # 'sgd' (validated above); note it does not use the weight-decay groups.
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info(
        "  Total train batch size (w. parallel, accumulation) = %d",
        args.batch_size * args.gradient_accumulation_steps,
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    tr_loss, logging_loss = 0.0, 0.0

    test_f1_regular = []
    test_f1_challenging = []
    model.zero_grad()

    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc='Epoch')
    set_seed(args)

    for epoch in train_iterator:

        print("Epoch #",epoch)
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model.train()

            # Every tensor is moved to args.device once; the fields used below
            # are (by the names they are bound to) 0 = input ids,
            # 1 = attention mask, 3 = labels.
            batch = tuple(t.to(args.device) for t in batch)
            b_input_ids = batch[0]
            b_input_mask = batch[1]
            b_labels = batch[3]

            # Clear gradients only at the start of an accumulation window.
            # Zeroing before every forward pass would throw away all but the
            # last micro-batch when gradient_accumulation_steps > 1.
            if step % args.gradient_accumulation_steps == 0:
                model.zero_grad()

            result = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)

            loss = result[0]

            if args.n_gpu > 1:
                # DataParallel gathers one loss per replica; backward() needs a scalar.
                loss = loss.mean()

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()
            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (args.evaluate_during_training):

                        results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, eval_dataset, parallel = False, mode="dev", prefix = str(global_step))
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                        logger.info("Model name: %s", args.output_dir)
                        logger.info("Epoch is %s", epoch)
                        if results['f1'] >= best_f1:
                            best_f1 = results['f1']
                            results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, test_dataset_regular, parallel = False, mode="test", prefix = str(global_step))
                            test_f1_regular.append(results['f1'])
                            results, _ = evaluate(args,model, tokenizer, labels, pad_token_label_id, test_dataset_challenging, parallel = False, mode="test", prefix = str(global_step))
                            test_f1_challenging.append(results['f1'])

                            _save_best_checkpoint(args, model, tokenizer, optimizer, scheduler, epoch, epochs_trained)

                    tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                    logger.info("logging train info!!!")
                    logger.info("*")

        # eval and save the best model based on dev set after each epoch
        if (args.evaluate_during_training):

            results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, eval_dataset, parallel = False, mode="dev", prefix = str(global_step))
            for key, value in results.items():
                tb_writer.add_scalar("eval_{}".format(key), value, global_step)

            if results['f1'] >= best_f1:
                best_f1 = results['f1']
                results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, test_dataset_regular, parallel = False, mode="test", prefix = str(global_step))
                test_f1_regular.append(results['f1'])
                results, _ = evaluate(args,model, tokenizer, labels, pad_token_label_id, test_dataset_challenging, parallel = False, mode="test", prefix = str(global_step))
                test_f1_challenging.append(results['f1'])

                # The original message had no placeholder for ``epoch``, which
                # makes the logging module report a formatting error.
                logger.info("Epoch # %s", epoch)
                _save_best_checkpoint(args, model, tokenizer, optimizer, scheduler, epoch, epochs_trained)

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break
        logger.info("Epoch is %s", epoch)
    args.tb_writer_logdir=tb_writer.logdir
    tb_writer.close()
    # Guard against an empty dataloader / zero epochs, where no step was taken.
    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss, test_f1_regular ,test_f1_challenging


def main():
    """Configure logging, load the tokenizer/config, and generate the augmented data sets.

    Reads the module-level ``args`` namespace. Steps, in order:

    1. Refuse to reuse a non-empty output directory when training was
       requested without ``args.overwrite_output_dir``.
    2. Attach a console handler and a file handler
       (``<output_dir>/<train_examples>-log.txt``) to the module logger.
    3. Seed the RNGs, resolve the label set, and store ``args.num_labels``.
    4. Load the config and tokenizer for ``args.model_type``.
    5. Call ``generate_aug_percentage_categories``.

    Raises:
        ValueError: if the output directory exists, is non-empty, and
            ``args.do_train`` is set without ``args.overwrite_output_dir``.
    """
    global best_f1
    if (os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir):
        raise ValueError( "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    logger.setLevel(log.INFO)
    formatter = log.Formatter("%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S")

    os.makedirs(args.output_dir, exist_ok=True)

    fh = log.FileHandler(os.path.join(args.output_dir, str(args.train_examples) + '-' + 'log.txt'))
    fh.setLevel(log.INFO)
    fh.setFormatter(formatter)

    ch = log.StreamHandler()
    ch.setLevel(log.INFO)
    ch.setFormatter(formatter)

    logger.addHandler(ch)
    logger.addHandler(fh)

    logger.info("------NEW RUN-----")

    logger.info("device: %s, n_gpu: %s", args.device, args.n_gpu)

    set_seed(args)

    labels = get_labels(args.labels)
    num_labels = len(labels)
    args.num_labels=num_labels

    # Label id assigned to padding positions: the loss's ignore_index.
    pad_token_label_id = CrossEntropyLoss().ignore_index

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # ``config`` is not used further in this function; it is still loaded so
    # that an invalid config/model name fails here, as before.
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name,
        num_labels=num_labels,
    )
    # Load the tokenizer once, honouring --tokenizer-name. An earlier extra
    # load from args.model_name was immediately overwritten and was removed.
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name,
        do_lower_case=args.do_lower_case,
    )

    logger.info("Training/evaluation parameters %s", args)

    generate_aug_percentage_categories(args, tokenizer, labels, pad_token_label_id, mode = 'train', pad_subtoken_with_real_label=args.pad_subtoken_with_real_label)

    




# Script entry point. NOTE: this call is unconditional (there is no
# ``if __name__ == "__main__"`` guard), so importing this module also runs main().
main()

    