import demjson
import re

from torch.utils.data import DataLoader,Dataset
import torch.nn as nn
import torch.nn.functional as F
from skimage import io,transform
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
import transformers as tfs
import math
import random
import os
import sys
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
import warnings
#import preprocessor
import json
import tqdm
import createdata
import tree2seq 
from nltk.translate.bleu_score import sentence_bleu,SmoothingFunction

from createdata import sememeDataset
from transformers import BertModel, BertTokenizer
from model import TreePred

# NOTE(review): assigns the private torch flag `torch.cuda._initialized` at import
# time, presumably as a workaround around CUDA lazy initialization — confirm why
# it is needed; this is not public torch API and may change between versions.
torch.cuda._initialized = True
# Device used for training/evaluation: the default CUDA device when available,
# otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def active_bytes():
    """Return the number of bytes currently held by active CUDA allocations.

    Reads the ``'active_bytes.all.current'`` entry of ``torch.cuda.memory_stats()``.
    Returns 0 when CUDA is unavailable, or when the allocator has not reported
    that statistic. This lets callers (e.g. progress-bar text) also work on CPU
    runs instead of raising.
    """
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.memory_stats().get('active_bytes.all.current', 0)


def _dropFirst(tokens, marker):
    """Remove the first occurrence of ``marker`` from ``tokens`` in place; no-op if absent."""
    if marker in tokens:
        tokens.remove(marker)


def testFirst(model, testSet = "/data3/private/yyn/structPred/datas/testSetForward.pkl",beamWidth = 1, numSamples = 1000):
    """Evaluate ``model`` on a fixed, seeded random subset of the test set.

    The test set is loaded with ``torch.load`` and the sample order is shuffled
    after ``random.seed(2077)``, so every call evaluates the same subset. For
    each sample, the model's prediction for ``cont[1][0]`` is compared with a
    reference built from ``cont[2][0]``.

    Args:
        model: model exposing ``sememe2id`` (mapping used to build reference ids)
            and ``predict(texts, maxLen=..., beamWidth=...)`` returning a token-id
            sequence.
        testSet: path of the saved test set.
        beamWidth: beam width forwarded to ``model.predict``.
        numSamples: maximum number of samples to evaluate (default 1000, the
            previously hard-coded value); clipped to the test-set size.

    Returns:
        ``[bleu, strict0, strict1, strict2]``: the mean smoothed sentence BLEU,
        followed by the means of the three entries of the first value returned by
        ``tree2seq.computeStrict``. All zeros when no sample is evaluated.

    Side effects:
        Reseeds the global ``random`` module (kept for compatibility with the
        previous behaviour) and leaves ``model`` in eval mode.
    """
    startTok, endTok = "2083", "2084"  # start / end marker ids, as strings

    testData = torch.load(testSet)
    shunxu = list(range(len(testData)))
    random.seed(2077)
    random.shuffle(shunxu)
    model = model.eval()

    acc1, acc2, acc3, acc4 = 0.0, 0.0, 0.0, 0.0
    total = 0
    # Clip so test sets smaller than numSamples do not raise IndexError.
    numEval = min(numSamples, len(testData))
    with torch.no_grad():
        bar = tqdm.tqdm(range(numEval))
        for step in bar:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            cont = testData[shunxu[step]]

            # Reference: start marker, ids of cont[2][0] without its first element,
            # end marker; stringified to match the prediction format.
            ans = [2083] + [model.sememe2id[i] for i in cont[2][0]][1:] + [2084]
            ans = [str(i) for i in ans]
            out = model.predict([cont[1][0]], maxLen=15, beamWidth=beamWidth)
            ls = [str(i) for i in out]

            scores, _, _ = tree2seq.computeStrict(ls, ans)

            # BLEU is computed without the markers. A prediction may lack one
            # (e.g. when maxLen is reached before the end marker), so only strip
            # markers that are actually present instead of raising ValueError.
            for tokens in (ans, ls):
                _dropFirst(tokens, startTok)
                _dropFirst(tokens, endTok)
            score1 = sentence_bleu([ans], ls, smoothing_function=SmoothingFunction().method1)

            acc1 += score1
            acc2 += scores[0]
            acc3 += scores[1]
            acc4 += scores[2]
            total += 1
            bar.set_description(f"{format(acc1 / total, '.4f')}, {format(acc2/total, '.4f')}, {format(acc3/total, '.4f')}, {format(acc4/total, '.4f')}")

    if total == 0:
        return [0.0, 0.0, 0.0, 0.0]
    return [acc1 / total, acc2 / total, acc3 / total, acc4 / total]
    

    
def myCollate(batch):
    """Collate dataset items into parallel per-field lists for the DataLoader.

    From each item ``cont``, only ``cont[1][0]`` (collected into ``texts``) and
    ``cont[2][0]`` (collected into ``dicts``) are used; other fields are ignored.

    Returns:
        ``(texts, dicts, bias, depth)``. ``bias`` and ``depth`` are always empty
        lists: their collation is currently disabled, but the 4-tuple shape is
        kept because the training loop unpacks four values.
    """
    texts = [cont[1][0] for cont in batch]
    dicts = [cont[2][0] for cont in batch]
    bias, depth = [], []
    return texts, dicts, bias, depth

            

def _saveCheckpoint(model, optimizer, epochId, modelName, modelDir):
    """Save BERT weights and model/optimizer state to ``{modelDir}/{modelName}_{epochId:03d}/``.

    Moves ``model`` to CPU and switches it to eval mode before saving; the caller
    must move it back. Only the optimizer's ``param_groups`` (hyper-parameters,
    not its per-parameter state) are stored, under the ``'optimizer'`` key.
    """
    saveDir = f"{modelDir}/{modelName}_{format(epochId,'03d')}/"
    model.to("cpu")
    model.eval()
    # makedirs also creates missing parent directories and tolerates reruns.
    os.makedirs(saveDir, exist_ok=True)
    model.saveBert(saveDir)
    state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict()['param_groups'], 'epoch': epochId}
    torch.save(state, saveDir + "modelWeight.pt")


def train(epoch = 601, batchSize = 32, trainSet = "/data3/private/yyn/structPred/datas/trainSetForward.pkl", testSet = "/data3/private/yyn/structPred/datas/testSetForward.pkl",modelName="base",lr = 1e-5, modelDir = "/data3/private/yyn/structPred/models"):
    """Train a ``TreePred`` model with AdamW + cross-entropy, evaluating and checkpointing every 10 epochs.

    Args:
        epoch: number of epochs to run.
        batchSize: DataLoader batch size (batches are not shuffled).
        trainSet: path of the saved training set (loaded with ``torch.load``).
        testSet: path of the test set passed to ``testFirst``.
        modelName: run name; used for the log file ``info_{modelName}.txt`` and
            for checkpoint directory names.
        lr: AdamW learning rate.
        modelDir: directory under which ``{modelName}_{epoch:03d}/`` checkpoint
            folders are created.

    On epochs where ``epochId % 10 == 0``, the ``testFirst`` scores are appended
    to ``info_{modelName}.txt`` and a checkpoint is written.
    """
    trainData = torch.load(trainSet)
    trainLoader = DataLoader(dataset=trainData, batch_size=batchSize, shuffle=False, num_workers=0, collate_fn=myCollate)

    model = TreePred(wordDim = 768, maxPosEmbed = 50,sememeDim = 768, head=8,attentionLayer = 8,hiddenDim = 128, pretrained = 1, TUPE=0, MASK=1, seq = False, depthMethod='depth', biasMethod='distance')
    model.to(device)

    # To resume from a saved "<modelName>_<epoch>" checkpoint directory (loadname)
    # instead of starting fresh, replace the optimizer construction below with:
    #   stateTest = torch.load(f"{modelDir}/{loadname}/modelWeight.pt", map_location=torch.device('cpu'))
    #   model.load_state_dict(stateTest['net'], strict=False)
    #   model.loadBert(f"{modelDir}/{loadname}/")
    #   pg = stateTest['optimizer'][0]
    #   optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
    #       lr=pg['lr'], betas=pg['betas'], eps=pg['eps'], weight_decay=pg['weight_decay'])
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epochId in range(epoch):
        # Checkpointing moves the model to CPU, so move it back every epoch.
        model.to(device)
        model.device = device
        bar = tqdm.tqdm(enumerate(trainLoader), total=len(trainLoader))
        model = model.train()
        for _, (texts, dicts, bias, depth) in bar:
            # All trees in a batch must have the same length so the target ids
            # can be stacked into one tensor; skip batches that don't.
            if len({len(d) for d in dicts}) != 1:
                continue
            out, dictsSeqId = model(texts, dicts)
            # out : [batchSize, treeLen - 1, sememeCount]
            out = out.reshape(-1, out.shape[2])

            # Targets drop each sequence's first id to align with `out`. reshape
            # (not view) is required: on CPU `.to(device)` returns the
            # non-contiguous slice unchanged and `.view` would raise.
            target = torch.tensor(dictsSeqId, dtype=torch.long)[:, 1:].to(device).reshape(-1)

            loss = criterion(out, target)
            bar.set_description(f"[#{epochId+1}]loss: {format(float(loss),'.6f')},cuda: {active_bytes()}")

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if epochId % 10 == 0:
            results = testFirst(model, testSet)
            with open("info_"+modelName+".txt", 'a', encoding='utf-8') as f:
                f.write(f"{epochId},   {format(results[0],'.5f')},  {format(results[1],'.5f')}, {format(results[2],'.5f')}, {format(results[3],'.5f')}\n")
            _saveCheckpoint(model, optimizer, epochId, modelName, modelDir)
    
    
if __name__ == "__main__":
    # Usage: python <script> <modelName>
    # modelName names the run's log file and checkpoint directories.
    if len(sys.argv) < 2:
        sys.exit(f"usage: python {sys.argv[0]} <modelName>")
    print(f"begin training {sys.argv[1]}")
    train(modelName=sys.argv[1])