import random, os , torch , datetime, torch ,json,tqdm,time,inspect,argparse,re
import numpy as np
os.chdir("/ruletaker/")
from typing import Any, Dict, List, cast
from random import sample
from torch.utils.data import DataLoader,RandomSampler,SequentialSampler , TensorDataset
from torch.nn import CrossEntropyLoss
from transformers import AdamW,get_linear_schedule_with_warmup
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from copy import deepcopy

def seed_everything(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy's global generator, and PyTorch's CPU and
    CUDA generators, then switches cuDNN into its deterministic mode.

    Args:
        seed: The integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also cover every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # defeats the deterministic flag above, so it has to be turned off.
    torch.backends.cudnn.benchmark = False

def format_time(elapsed):
    """Render a duration given in seconds as a ``[D day(s), ]H:MM:SS`` string.

    The duration is rounded to the nearest whole second before formatting.
    """
    whole_seconds = int(round(elapsed))
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)
    
def flat_accuracy(preds, labels):
    """Return the fraction of rows whose arg-max class equals the gold label.

    ``preds`` holds one row of class scores per example; ``labels`` is a NumPy
    array of gold class ids that is flattened before the comparison.
    """
    predicted_classes = np.argmax(preds, axis=1).flatten()
    gold_classes = labels.flatten()
    hits = predicted_classes == gold_classes
    return np.sum(hits) / len(gold_classes)

def create_data_set(data,samplenum,batch_size):
    """Build a shuffled DataLoader of (context, question, label) examples.

    Args:
        data: Path to a JSON-lines file.  Every record must provide a
            ``"context"`` string and a ``"questions"`` list whose items have
            ``"text"`` and ``"label"`` entries.
        samplenum: Soft cap on the number of questions.  Reading stops after
            the first record that pushes the running total above this value,
            so all questions of that last record are still kept.
        batch_size: Batch size of the returned DataLoader.

    Returns:
        A ``DataLoader`` with a ``RandomSampler`` over a ``TensorDataset`` of
        ``(input_ids, attention_mask, label)``, each example tokenized as a
        context/question pair padded or truncated to 384 tokens.
    """
    model_arch = 'roberta-large'

    contexts = []
    questions = []
    labels = []

    # Stream the file so that only the records needed to reach ``samplenum``
    # are read, instead of loading every line into memory first.
    with open(data, 'r', encoding='utf-8') as json_file:
        for json_str in json_file:
            if not json_str.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            result = json.loads(json_str)
            for q in result["questions"]:
                contexts.append(result["context"])
                questions.append(q["text"])
                labels.append(q["label"])
            # Checked once per record, hence the cap is soft.
            if len(labels) > samplenum:
                break

    print("size of the context is {} and the size of the questions is {} and the size of the labels is {}".format(len(contexts),len(questions),len(labels)))
    tokenizer = AutoTokenizer.from_pretrained(model_arch)

    # Tokenize each context/question pair; every encoding has shape (1, 384).
    train_input_ids = []
    train_attention_masks = []
    # ``zip`` has no length, so pass it explicitly for a usable progress bar.
    for c, h in tqdm.tqdm(zip(contexts, questions), total=len(contexts)):
        encoded = tokenizer.encode_plus(c,h,max_length=384,truncation=True,return_tensors='pt',padding='max_length')
        train_input_ids.append(encoded['input_ids'])
        train_attention_masks.append(encoded['attention_mask'])

    train_input_ids = torch.cat(train_input_ids, dim=0)
    train_attention_masks = torch.cat(train_attention_masks, dim=0)
    train_labels = torch.LongTensor(labels)

    train_dataset = TensorDataset(train_input_ids, train_attention_masks, train_labels)
    train_dataloader = DataLoader(dataset=train_dataset,sampler=RandomSampler(train_dataset),batch_size=batch_size)
    return train_dataloader

def or_proba(rules_probabilities=None,prof=""):
    """Combine the probabilities of alternative proofs with ProbLog.

    ``prof`` is split on the literal ``"OR"``.  For each alternative, the
    percentages of all rules it mentions (``ruleN`` refers to
    ``rules_probabilities[N-1]``) are multiplied together.  A ProbLog program
    is then built with one fact ``aK`` and one probabilistic clause
    ``p::b :- aK`` per alternative, and the evaluated probability of ``b`` is
    returned.

    Args:
        rules_probabilities: Per-rule probabilities in percent (0-100),
            ordered by rule number.  ``None`` is treated as an empty list.
        prof: Proof string whose alternatives are separated by ``OR``.

    Returns:
        The value ProbLog reports for ``query(b)`` (a probability in [0, 1]).

    Raises:
        IndexError: If a proof names a rule with no entry in
            ``rules_probabilities``.
    """
    if rules_probabilities is None:
        rules_probabilities = []

    # Imported lazily so the rest of the module works without ProbLog.
    from problog.program import PrologString
    from problog import get_evaluatable

    program_lines = []
    for num, alternative in enumerate(prof.split("OR")):
        prob_hyp = 100
        # Capture the whole rule number: reading only the last character of a
        # single-digit match would resolve e.g. "rule12" to rule 1.
        for rule_number in re.findall(r"rule(\d+)", alternative):
            prob_hyp *= rules_probabilities[int(rule_number) - 1] / 100
        program_lines.append("a"+str(num)+" . \n")
        program_lines.append(str(prob_hyp/100)+"::b :- a"+str(num)+" . \n")
    program_lines.append("query(b). \n")

    p = PrologString("".join(program_lines))
    return list(get_evaluatable().create_from(p).evaluate().items())[0][1]

def convert_to_adverb(x):
    """Map a percentage to a frequency adverb and its canonical percentage.

    Returns a ``(adverb, canonical_value)`` tuple for the first bucket whose
    lower bound is strictly below ``x``; anything at or below 7 is "never".
    """
    # (exclusive lower bound, adverb, canonical percentage), highest first.
    buckets = (
        (95, "always", 100),
        (85, "usually", 90),
        (75, "normally", 80),
        (60, "often", 65),
        (40, "sometimes", 50),
        (20, "occasionally", 30),
        (7, "seldom", 15),
    )
    for lower_bound, word, canonical in buckets:
        if x > lower_bound:
            return word, canonical
    return "never", 0
    

def create_data_set_proba(data,data_meta,samplenum,batch_size,adverb,fake,depth_train,mustrule):
    """Build a probabilistic question dataset and a proof-chain ("PD") dataset.

    The two input files are read line by line in lockstep; each line is parsed
    with ``json.loads``.  For every usable question a random probability is
    drawn for each rule of the theory, the probability of the hypothesis is
    derived from the rules named in its proof, and the context is rewritten so
    that every rule text is prefixed with its probability (numeric or adverb).
    Afterwards the ``allProofs`` field is parsed into up to 16 intermediate
    conclusions per example, which are regrouped into linked pairs.

    Args:
        data: Path to a JSON-lines file; ``result["context"]`` is read.
        data_meta: Path to the matching meta JSON-lines file; the keys
            ``questions`` (with ``Q1``..``Q19``, each having ``question``,
            ``QDep`` and ``proofs``), ``rules`` and ``allProofs`` are read.
        samplenum: Stop once more than this many questions were collected
            (only checked for examples that yielded proof entries, see below).
        batch_size: Batch size of both returned DataLoaders.
        adverb: If truthy, prefix rules with an adverb; otherwise with
            ``"With a probability of N.00 percent , "``.
        fake: If truthy, the unmodified context is used instead of the
            annotated one.
        depth_train: If equal to 0, the pair-building step is skipped and the
            PD dataset ends up empty.
        mustrule: Only referenced in commented-out code; currently unused.

    Returns:
        ``(train_dataloader, train_dataloader_PD)``.  The first uses a
        ``RandomSampler`` over ``(input_ids, attention_mask, label, proba,
        depth, or_used)``; the second uses a ``SequentialSampler`` over
        ``(input_ids, attention_mask, proba_PD, prev_connect, rule_proba)``.
    """
    # The data generation is always seeded with 42, independent of the caller.
    seed_everything(42)
    os.environ['TOKENIZERS_PARALLELISM'] = "true"
    model_arch = 'roberta-large'
    counter=0
    
    
    import json,random,re
    with open(data, 'r') as json_file:
        json_list = list(json_file)
    with open(data_meta, 'r') as json_file:
        json_list_meta = list(json_file)

    # Per-question accumulators for the main dataset.
    contexts=[]
    questions=[]
    labels=[]
    probas=[]
    depths=[]
    or_used_in_example_list=[]
    # Accumulators for the proof-chain ("PD") dataset.
    rule_proba,prev_connect,contexts_PD,questions_PD,probas_PD=[],[],[],[],[]
    # Counts of rule-based examples with probability > 50 / <= 50 (printed below).
    meaw,no_meaw=0,0
    for json_str,meta_str in zip(json_list,json_list_meta):
        
        result = json.loads(json_str)
        result_meta=json.loads(meta_str)
        
        
        

        # Questions are looked up under the keys "Q1" .. "Q19".
        for i in range(1,20):
            try:
                rrrr=result_meta['questions']["Q"+str(i)]
            except:
                # NOTE(review): bare except; presumably only KeyError (missing
                # question) is expected here.
                continue
            c=result["context"]
            # Questions containing " not " are skipped entirely.
            if " not " in result_meta['questions']["Q"+str(i)]['question']:
                continue
            if int(result_meta['questions']["Q"+str(i)]["QDep"])>5:# or int(result_meta['questions']["Q"+str(i)]["QDep"])==0:
                continue
            #print(result_meta['questions']["Q"+str(i)]['question'])
            # help_proba shifts the mean of the sampling distribution by +-10
            # after each failed attempt, steering towards the forced answer.
            help_proba=0
            # Target side of 50% for this example, drawn at random.
            forced_answer=random.randint(0,1)
            prob_hyp=-1
            rule_used=False
            # Resample until the hypothesis probability lands on the forced
            # side of 50 (exactly 50 satisfies neither side and is resampled).
            while((forced_answer and prob_hyp<=50) or (not forced_answer and prob_hyp>=50) or prob_hyp==-1):
                prob_hyp=100
                rules_probabilities=[]
                for i_ppp in range(len(result_meta['rules'])):
                    # Gaussian draw, clipped to [0, 100], then snapped to the
                    # canonical percentage of its adverb bucket.
                    prob=random.gauss(60+help_proba+10*int(result_meta['questions']["Q"+str(i)]["QDep"]),60)
                    if prob>100:
                        prob=100
                    if prob<0:
                        prob=0
                    _,prob=convert_to_adverb(int(prob*100//100))
                    rules_probabilities.append(prob)
                or_used_in_example=False
                if "OR" in result_meta['questions']["Q"+str(i)]["proofs"]:
                    # Alternative proofs are combined by or_proba (returns 0..1).
                    prob_hyp=or_proba(rules_probabilities=rules_probabilities,prof=result_meta['questions']["Q"+str(i)]["proofs"])*100
                    rule_used=True
                    if prob_hyp==100:
                        pass
                    or_used_in_example=True
                else:
                    # Single proof: multiply the probabilities of its rules.
                    # NOTE(review): r"rule\d" plus i_[-1] reads one digit only,
                    # so rule numbers >= 10 would map to the wrong index.
                    for j in [int(i_[-1])-1 for i_ in re.findall(r"rule\d",result_meta['questions']["Q"+str(i)]["proofs"])]:
                        rule_used=True
                        prob_hyp*=rules_probabilities[j]/100
                if (forced_answer and prob_hyp<50):
                    help_proba+=10
                if (not forced_answer and prob_hyp>50):
                    help_proba-=10
                # Depth-0 questions are accepted after a single draw.
                if int(result_meta['questions']["Q"+str(i)]["QDep"])==0:
                    break
                #print(forced_answer,prob_hyp)

                #if not rule_used:
                #    break
                #print(forced_answer,prob_hyp,result_meta['questions']["Q"+str(i)]["QDep"],help_proba,"OR" in result_meta['questions']["Q"+str(i)]["proofs"])
            #if not rule_used:
                #continue
            
            if prob_hyp>50 and rule_used:
                meaw+=1
            elif rule_used:
                no_meaw+=1
            # NOTE(review): questions containing " not " were already skipped
            # above, so this branch appears to be unreachable.
            if " not " in result_meta['questions']["Q"+str(i)]['question']:
                prob_hyp=100-prob_hyp
            #print("probability=",prob_hyp)
            c=result["context"]
            # Prefix every rule text in the context with its probability.
            # rules_probabilities is indexed by the iteration order of
            # result_meta['rules'] here -- presumably "rule1", "rule2", ...;
            # verify against the data.
            for num,j in enumerate(result_meta['rules'].keys()):
                adverb_converted,_=convert_to_adverb(int(rules_probabilities[num]))
                if not adverb:
                    c=c.replace(result_meta['rules'][j]["text"],"With a probability of "+str(int(rules_probabilities[num]))+".00 percent , "+result_meta['rules'][j]["text"])
                else:
                    #print(int(rules_probabilities[num]),convert_to_adverb(int(rules_probabilities[num])))
                    c=c.replace(result_meta['rules'][j]["text"],adverb_converted+" , "+result_meta['rules'][j]["text"])
            #print(c)
            # With fake set, the annotations are dropped again.
            if fake:
                c=result["context"]
            #if (mustrule and rule_used) or not mustrule:
            #print(c)
            contexts.append(c)
            depths.append(result_meta['questions']["Q"+str(i)]["QDep"])
            questions.append(result_meta['questions']["Q"+str(i)]['question'])
            if prob_hyp>50:
                labels.append(True)
            else:
                labels.append(False)
            # Stored on a 0..1 scale (the PD probabilities below stay on 0..100).
            probas.append(prob_hyp/100)
            or_used_in_example_list.append(or_used_in_example)
            counter+=1
            #print(c,result_meta['questions']["Q"+str(i)]['question'],prob_hyp/100,result_meta['questions']["Q"+str(i)]["QDep"])
            
            
            
        # ---- Proof-chain entries parsed from the "allProofs" string ----
        # NOTE(review): `c` and `rules_probabilities` used below are whatever
        # the last processed question left behind (possibly from an earlier
        # example, or undefined if no question was processed yet).
        p_=result_meta['allProofs']
        #print(p_)

        # Maps a cleaned proof string to its slot index within this example.
        probs_dict={}
        # Index of the last slot used for this example; at most 16 slots (0..15).
        cur_index=-1
        for i in range(1,len( p_.split("@") ) ):
            if cur_index==15:
                    break
            # Each item is split on "." below: part 0 is used as the question
            # text, part 1 as its proof.
            for j in p_.split("@")[i][2:].split("]")[:-1]:
                if "OR" in j.split(".")[1]:
                    continue
                prob_hyp=100
                r_s=[int(i[-1])-1 for i in re.findall(r"rule\d",j.split(".")[1])]
                if len(r_s)<1:
                    continue
                if cur_index==15:
                    break
                cur_index+=1
                ignore_this=False
                for j_ in r_s:
                    try:
                        prob_hyp*=rules_probabilities[j_]/100
                    except:
                        # NOTE(review): bare except; likely meant to catch an
                        # IndexError/NameError from the stale list above.
                        ignore_this=True
                        continue
                # NOTE(review): cur_index was already advanced for an ignored
                # entry, so fewer than 16 rows may be written for this example
                # while later code steps through the lists in strides of 16.
                if ignore_this:
                    continue
                
                # Probability of the last rule named in this proof.
                if r_s:
                    rule_proba.append(rules_probabilities[r_s[-1]])
                else:
                    rule_proba.append(-1)
                probs_dict[j.split(".")[1].strip("[()")]=cur_index
                # Locate the second-to-last "rule" occurrence in order to cut
                # out the sub-proof this step builds on.
                index=j.split(".")[1].rfind("rule",0,j.split(".")[1].rfind("rule")-1)
                if index==-1:
                    index=j.split(".")[1].rfind("rule",0,j.split(".")[1].rfind("->")-1)-7
                # NOTE(review): after the "-7" above a failed search yields -8,
                # not -1; the offsets 6 and 7 look tuned to single-digit rule
                # names -- confirm.
                if index==-1:
                    prev_connect.append(-1)
                else:
                    index+=6
                    final_name=j.split(".")[1][0:index]+")"
                    #print(final_name.strip("[()"))
                    #print(probs_dict)
                    #print("*"*10)
                    # Link to the slot of the sub-proof if it was seen before.
                    if final_name.strip("[()") in probs_dict:
                        
                        
                        prev_connect.append(probs_dict[final_name.strip("[()")])
                    else:
                        prev_connect.append(-1)
                #print("-->",c)
                contexts_PD.append(c)
                questions_PD.append(j.split(".")[0])
                probas_PD.append(prob_hyp)
                
        #print(cur_index)
        if not cur_index==-1:
            # Pad this example up to slot 15 by repeating the last entry,
            # marked as unlinked (-1).
            for _i_ in range(cur_index+1,16):
                contexts_PD.append(contexts_PD[-1])
                questions_PD.append(questions_PD[-1])
                probas_PD.append(probas_PD[-1])
                prev_connect.append(-1)
                rule_proba.append(-1)
            
            # NOTE(review): the sample limit is only checked inside this
            # branch, i.e. for examples that produced proof entries.
            if counter>samplenum:
                break
                
    ######### make small batch
    contexts_PD_batch_4,questions_PD_batch_4,probas_PD_batch_4,prev_connect_batch_4,rule_proba_batch_4=[],[],[],[],[]
    if not depth_train==0:
        # Emit (predecessor, node) pairs; `even` alternates the stored link
        # between 0 and 2, matching the predecessor's position when pairs are
        # laid out in groups of four.
        even=True
        for i in range(0,len(contexts_PD),16):
            # Visit the 16 slots of each example from last to first.
            for j in range(0,16):
                cur_node=15-j+i
                if not prev_connect[cur_node]==-1:

                    prev_node=i+prev_connect[cur_node]

                    # Predecessor first, stored as unlinked (-1) ...
                    contexts_PD_batch_4.append(contexts_PD[prev_node])
                    questions_PD_batch_4.append(questions_PD[prev_node])
                    probas_PD_batch_4.append(probas_PD[prev_node])
                    prev_connect_batch_4.append(-1)
                    rule_proba_batch_4.append(rule_proba[prev_node])

                    # ... then the node itself, pointing back at it.
                    contexts_PD_batch_4.append(contexts_PD[cur_node])
                    questions_PD_batch_4.append(questions_PD[cur_node])
                    probas_PD_batch_4.append(probas_PD[cur_node])
                    if even:
                        prev_connect_batch_4.append(0)
                    else:
                        prev_connect_batch_4.append(2)
                    even=not even
                    rule_proba_batch_4.append(rule_proba[cur_node])
        # Drop the trailing pair so the total length is a multiple of 4.
        if not len(contexts_PD_batch_4)%4==0:
            contexts_PD_batch_4.pop()
            contexts_PD_batch_4.pop()

            questions_PD_batch_4.pop()
            questions_PD_batch_4.pop()

            probas_PD_batch_4.pop()
            probas_PD_batch_4.pop()

            prev_connect_batch_4.pop()
            prev_connect_batch_4.pop()

            rule_proba_batch_4.pop()
            rule_proba_batch_4.pop()
        #print(contexts_PD_batch_4)
        #print(questions_PD_batch_4)
        #print(probas_PD_batch_4)
        #print(prev_connect_batch_4)
        #print(rule_proba_batch_4)
    # The PD lists are replaced by the paired version (empty if depth_train==0).
    contexts_PD,questions_PD,probas_PD,prev_connect,rule_proba=contexts_PD_batch_4\
    ,questions_PD_batch_4,probas_PD_batch_4,prev_connect_batch_4,rule_proba_batch_4
    #########
                
    #print(probas_PD)
    #print(prev_connect)
    #print(rule_proba)
    print("size of the context is {} and the size of the questions is {} and the size of the labels is {}".format(len(contexts),len(questions),len(labels)))
    print("number of true",sum(labels),meaw,no_meaw)
    tokenizer = AutoTokenizer.from_pretrained(model_arch)
    # tokenize training data
    train_input_ids = []
    train_attention_masks = []
    for c, h in tqdm.tqdm(zip(contexts, questions)):
        encoded = tokenizer.encode_plus(c,h,max_length=384,truncation=True,return_tensors='pt',padding='max_length')

        train_input_ids.append(encoded['input_ids'])
        train_attention_masks.append(encoded['attention_mask'])

    train_input_ids = torch.cat(train_input_ids, dim=0)
    train_attention_masks = torch.cat(train_attention_masks, dim=0)
    train_labels = torch.LongTensor(labels)
    train_proba = torch.FloatTensor(probas)
    depths = torch.LongTensor(depths)
    or_used_in_example_list = torch.LongTensor(or_used_in_example_list)

    train_dataset = TensorDataset(train_input_ids, train_attention_masks, train_labels ,train_proba,depths,or_used_in_example_list)
    train_dataloader = DataLoader(dataset=train_dataset,sampler=RandomSampler(train_dataset),batch_size=batch_size)
    
    # NOTE(review): this message ends with "labels is" but formats only two values.
    print("size of the context is {} and the size of the questions is {} and the size of the labels is".format(len(contexts_PD),len(questions_PD)))
    # tokenize the proof-chain (PD) data
    train_input_ids_PD = []
    train_attention_masks_PD = []
    for c, h in tqdm.tqdm(zip(contexts_PD, questions_PD)):
        encoded = tokenizer.encode_plus(c,h,max_length=384,truncation=True,return_tensors='pt',padding='max_length')

        train_input_ids_PD.append(encoded['input_ids'])
        train_attention_masks_PD.append(encoded['attention_mask'])
    try:
        train_input_ids_PD = torch.cat(train_input_ids_PD, dim=0)
        train_attention_masks_PD = torch.cat(train_attention_masks_PD, dim=0)
    except:
        # Fallback when torch.cat fails -- presumably for empty PD lists.
        train_input_ids_PD=torch.LongTensor(train_input_ids_PD)
        train_attention_masks_PD=torch.LongTensor(train_attention_masks_PD)
    train_proba_PD = torch.FloatTensor(probas_PD)
    
    prev_connect = torch.LongTensor(prev_connect)
    rule_proba = torch.FloatTensor(rule_proba)

    train_dataset_PD = TensorDataset(train_input_ids_PD, train_attention_masks_PD, train_proba_PD ,prev_connect,rule_proba)
    # Sequential order keeps each (predecessor, node) pair adjacent.
    train_dataloader_PD = DataLoader(dataset=train_dataset_PD,sampler=SequentialSampler(train_dataset_PD),batch_size=batch_size)
    
    # Label histogram per depth: dist_of_depths[depth][label], depths 0..5.
    dist_of_depths=[[0,0] for i in range(6)]
    for i in train_dataset:
        dist_of_depths[i[4].item()][i[2].item()]+=1
    print(dist_of_depths)
    return train_dataloader,train_dataloader_PD