import os
os.chdir("rulebert/RuleBert/")
from utils import create_data_set,flat_accuracy,seed_everything,format_time,create_data_set_proba
from typing import Any, Dict, List, cast
import torch,json
from random import sample
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.nn import CrossEntropyLoss
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from copy import deepcopy
import time
import inspect
import argparse

seed_everything(42)


def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string -- including "False" and "0" -- so such flags could never
    be switched off from the command line.  Accept the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y"}:
        return True
    if text in {"", "0", "false", "f", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='Run Logical Rulebert')
parser.add_argument('--cuda', dest='cuda_number', default=0,
                    help='cuda number to train the models on', type=int)
parser.add_argument('--chain', dest='chain_number', default=3, help='TODO', type=int)
# NOTE(review): --alpha, --fake and --mustrule have no effect on the evaluation in
# the visible code (alpha is at most read into an unused variable) -- confirm.
parser.add_argument('--alpha', dest='alpha', default=3, help='TODO', type=float)
parser.add_argument('--batch', dest='batch_size', default=16,
                    help='batch size for neural network training', type=int)
parser.add_argument('--epoch', dest='cur_epoch', default=6,
                    help='number of epochs you want your model to train on', type=int)
parser.add_argument('--lr', dest='learning_rate', default=1e-5,
                    help='learning rate of the adamW optimiser', type=float)
parser.add_argument('--pd', dest='primaldual', default=False,
                    help='whether or not to use primaldual constraint learning', type=_str2bool)
parser.add_argument('--race', dest='race', default=False,
                    help="whether or not to use the 'BASERACE' model architecture", type=_str2bool)
parser.add_argument('--samplenum', dest='samplenum', default=1000000000,
                    help='number of samples to train the model on', type=int)
parser.add_argument('--model_name', dest='model_name', default="modellogical",
                    help='prefix of the evaluated checkpoints (ruletaker/<model_name><epoch>) '
                         'and of the output file <model_name>testi.txt', type=str)
parser.add_argument('--context', dest='context', default=False, help='TODO', type=_str2bool)
parser.add_argument('--adverb', dest='adverb', default=False, help='TODO', type=_str2bool)
parser.add_argument('--fake', dest='fake', default=False, help='TODO', type=_str2bool)
parser.add_argument('--mustrule', dest='mustrule', default=False, help='TODO', type=_str2bool)
args = parser.parse_args()

# Short aliases used throughout the rest of the script.
cude_number = args.cuda_number
chain_number = args.chain_number
batch_size = args.batch_size
epochs = args.cur_epoch
lr = args.learning_rate
apply_PD = args.primaldual
samplenum = args.samplenum

# bool() yields the True/False singletons, exactly like the literal assignments did.
use_context = bool(args.context)
# Log file shared by the rest of the script; every report is also written here.
out_put_file = open(args.model_name + "testi.txt", "w")

def retrieve_name(var):
    """Return the names in the caller's local namespace that are bound to *var*.

    Matching uses identity (``is``), not equality, so objects the interpreter
    shares (small ints, ``True``/``False``, ``None``) may be reported under
    several unrelated names.  Called from module level, the caller's locals are
    the module globals.
    """
    caller_frame = inspect.currentframe().f_back
    matches = []
    for candidate_name, candidate_value in caller_frame.f_locals.items():
        if candidate_value is var:
            matches.append(candidate_name)
    return matches
# Echo each run setting, together with every module-level name currently bound
# to that same object, to stdout and to the log file (flushed per setting).
# NOTE: retrieve_name matches by identity against this module's globals, so the
# loop variable `i` itself -- and any other global sharing a cached value such
# as 0 or False -- appears in the printed name lists. Renaming `i` or adding new
# globals before this loop changes the output.
for i in [cude_number,chain_number,batch_size,epochs,lr,apply_PD,samplenum,out_put_file,use_context,args.race]:
    print(retrieve_name(i),i)
    print(retrieve_name(i),i,file=out_put_file)
    out_put_file.flush()

# ---- Dataset / model configuration ------------------------------------------
include_first = False  # NOTE(review): not read anywhere in the visible code
data_dir = "ruletaker/rule-reasoning-dataset-V2020.2.5.0/original/depth-5/"
# Default architecture; the per-checkpoint loop further down rebinds model_arch.
model_arch = 'BASERACE' if args.race else "rulebert/RuleBert/0newbase3PD4"

max_length = 512
eps = 1e-6
weight_decay = 0.1
warmup_ratio = 0.06
verbose = True  # gates the metric report printed after each checkpoint
time_step_size = 100

if torch.cuda.is_available():
    device = torch.device(f"cuda:{cude_number}")
else:
    device = torch.device('cpu')
print("DEBUG", device)

# ---- Data --------------------------------------------------------------------
# Despite the "train" names, both loaders are built from the test split
# (test.jsonl and meta-test.jsonl).
train_dataloader, train_dataloader_PD = create_data_set_proba(
    data_dir + "test.jsonl",
    data_dir + "meta-test.jsonl",
    samplenum,
    batch_size,
    args.adverb,
    False,
    chain_number,
    True,
)

print("DEBUG", device)

# =============================================================================
# Evaluation helpers
# =============================================================================
# Absolute-gap tolerances behind the PD1 / PD10 / PD25 scores (in that order).
PD_THRESHOLDS = (0.01, 0.10, 0.25)
# Number of depth buckets; rows are indexed by the depth values in the batches.
N_DEPTHS = 6


def _smoothed(hits, total):
    """Return ``(hits + 1) / (total + 1)``; the +1 keeps empty buckets finite."""
    return (hits + 1) / (total + 1)


def _column_sums(table):
    """Sum a per-depth table of ``[index-0, index-1]`` counters over all depths."""
    return [sum(row[0] for row in table), sum(row[1] for row in table)]


def _smoothed_triple(hit_pair, total_pair):
    """Return the smoothed (combined, index-0, index-1) ratios for two counter pairs."""
    hits0, hits1 = hit_pair
    n0, n1 = total_pair
    return _smoothed(hits0 + hits1, n0 + n1), _smoothed(hits0, n0), _smoothed(hits1, n1)


def _update_pd_counts(probs, prev_connect, rule_proba, label_col, counts):
    """Accumulate rule-consistency hits for one PD batch into *counts*.

    For each row ``ii`` whose ``prev_connect`` offset is not -1, compare
    ``probs[ii][label_col]`` with ``probs[ii - offset][label_col] * rule_proba[ii] / 100``.
    ``counts[k]`` is incremented when the absolute gap is below
    ``PD_THRESHOLDS[k]``; ``counts[-1]`` counts the rows that were checked.
    """
    for ii, offset in enumerate(prev_connect.tolist()):
        if offset == -1:
            continue
        gap = torch.abs(probs[ii][label_col] - probs[ii - offset][label_col] * rule_proba[ii] / 100)
        for k, threshold in enumerate(PD_THRESHOLDS):
            if gap < threshold:
                counts[k] += 1
        counts[-1] += 1


def _print_report(stream, tables, depth_totals, pd_counts, w_pd_counts, avg_loss, elapsed):
    """Print the full metric report for one checkpoint to *stream* (None = stdout).

    *tables* is a sequence of ``(label, per_depth_table)`` pairs; every table,
    like *depth_totals*, is indexed as ``table[depth][or_used]``.
    """
    print("total", file=stream)
    overall = _column_sums(depth_totals)
    for label, table in tables:
        print(label, *_smoothed_triple(_column_sums(table), overall), file=stream)
    for depth, depth_total in enumerate(depth_totals):
        print("depth: ", depth, file=stream)
        for label, table in tables:
            print(label, *_smoothed_triple(table[depth], depth_total), file=stream)
    print(" ", file=stream)
    # "PD" scores use class column 1, "W PD" scores use class column 0.
    for prefix, counts in (("PD", pd_counts), ("W PD", w_pd_counts)):
        for pct, hits in zip((1, 10, 25), counts[:-1]):
            print(f"{prefix}{pct} AC: ", _smoothed(hits, counts[-1]), file=stream)
    print("", file=stream)
    print("  Average training loss: {0:.2f}".format(avg_loss), file=stream)
    print("  Training epcoh took: {:}".format(elapsed), file=stream)


softm = torch.nn.Softmax(dim=1)
loss_fct = CrossEntropyLoss(reduction='none')

# =============================================================================
# Evaluate every available checkpoint ruletaker/<model_name><epoch>
# =============================================================================
for model_name_test in [args.model_name]:
    # NOTE(review): the checkpoint epoch suffixes 3 and 4 are hard-coded.
    for epoch_i in range(3, 5):
        model_arch = "ruletaker/" + str(model_name_test) + str(epoch_i)
        try:
            model = AutoModelForSequenceClassification.from_pretrained(model_arch, num_labels=2)
            model = model.to(device)
        except Exception as err:  # best-effort: skip checkpoints that fail to load
            print("Skipping checkpoint", model_arch, "-", repr(err))
            continue
        model.eval()

        print('Model name:', str(model_name_test) + str(epoch_i))
        print('Model name:', str(model_name_test) + str(epoch_i), file=out_put_file)
        t0 = time.time()
        total_train_loss = 0.0

        # Per-depth counters indexed as table[depth][or_used].
        t_ = [[0, 0] for _ in range(N_DEPTHS)]    # rows seen
        bac_ = [[0, 0] for _ in range(N_DEPTHS)]  # thresholded (0.5) agreement
        ac25 = [[0, 0] for _ in range(N_DEPTHS)]  # |p - weight| < 0.25
        ac10 = [[0, 0] for _ in range(N_DEPTHS)]  # |p - weight| < 0.10
        ac1 = [[0, 0] for _ in range(N_DEPTHS)]   # |p - weight| < 0.01

        # ---- Per-example accuracy against the soft targets -----------------
        for batch in train_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)
            b_weights = batch[3].to(device)
            # Plain Python ints: used only as list indices (avoids a device sync per row).
            b_depths = batch[4].tolist()
            or_used = batch[5].tolist()

            with torch.no_grad():
                logits = model(b_input_ids, attention_mask=b_input_mask).logits

            for hh_, yy_, dd_, or_i in zip(softm(logits), b_weights, b_depths, or_used):
                p_true = hh_[1]
                if (p_true > 0.5 and yy_ > 0.5) or (p_true < 0.5 and yy_ < 0.5):
                    bac_[dd_][or_i] += 1
                gap = abs(p_true - yy_)
                if gap < 0.25:
                    ac25[dd_][or_i] += 1
                if gap < 0.10:
                    ac10[dd_][or_i] += 1
                if gap < 0.01:
                    ac1[dd_][or_i] += 1
                t_[dd_][or_i] += 1

            # Weighted cross-entropy, reported as the "training loss".
            loss = torch.mean(loss_fct(logits.view(-1, 2), b_labels.view(-1)) * b_weights)
            total_train_loss += loss.item()

        # ---- Rule-consistency (PD) scores ----------------------------------
        pd_counts = [0, 0, 0, 0]    # hits per PD_THRESHOLDS entry, then rows checked
        w_pd_counts = [0, 0, 0, 0]  # same layout, computed on class column 0
        for pd_batch in train_dataloader_PD:
            b_input_ids = pd_batch[0].to(device)
            b_input_mask = pd_batch[1].to(device)
            prev_connect = pd_batch[3]
            rule_proba = pd_batch[4].to(device)
            with torch.no_grad():
                probs = softm(model(b_input_ids, attention_mask=b_input_mask).logits)
            _update_pd_counts(probs, prev_connect, rule_proba, 1, pd_counts)
            _update_pd_counts(probs, prev_connect, rule_proba, 0, w_pd_counts)

        n_batches = len(train_dataloader)
        avg_train_loss = total_train_loss / n_batches if n_batches else 0.0
        training_time = format_time(time.time() - t0)
        if verbose:
            tables = (("ac25", ac25), ("ac10", ac10), ("ac1", ac1), ("bac", bac_))
            # Identical report on stdout (None) and in the log file.
            for stream in (None, out_put_file):
                _print_report(stream, tables, t_, pd_counts, w_pd_counts,
                              avg_train_loss, training_time)
            out_put_file.flush()
