function [net,codeBook, obj] = RL_DNN_LM(D,para)
%RL_DNN_LM  Train a feed-forward neural n-gram language model on the GPU.
%
%   [net, codeBook, obj] = RL_DNN_LM(D, para)
%
%   Each epoch shuffles the training samples, runs minibatch updates via
%   DoBackProp_GPU, then evaluates validation and test perplexity via
%   Evaluate_obj. The first time the validation perplexity increases, the
%   previous epoch's checkpoint is restored and learning-rate halving starts.
%   From then on the rate is halved once per epoch, and training stops after
%   6 halvings. A NaN validation perplexity also restores the previous
%   checkpoint and starts halving. Checkpoints are written to
%   ./tmp/epoch<k>.mat; the directory is created if it does not exist.
%
%   D (struct):
%       D.train_total_id, D.train_keep_id
%       D.valid_total_id, D.valid_keep_id
%       D.test_total_id,  D.test_keep_id
%           Training sample k uses the inputs
%           train_total_id(train_keep_id(k) + (0:n_words-1)) and the target
%           train_total_id(train_keep_id(k) + n_words).
%       D.codeBook:     embedding matrix (e.g. 10001 x vecSize), copied to the GPU
%       D.codeBookSize: not used by this function
%   para (struct):
%       para.lrate:     initial learning rate, e.g. 0.4
%       para.batchSize: minibatch size, e.g. 100
%                       (para.batchsize is accepted as an alias)
%       para.n_words:   number of context words (2 for a tri-gram model)
%       para.NNsize:    layer sizes, e.g. [200 400 400 10001]; the last entry
%                       is the size of the output layer
%                       (para.NNSize is accepted as an alias)
%
%   Outputs:
%       net:      trained network: net.w{i}, net.b{i} (gpuArrays) and net.codeBook
%       codeBook: net.codeBook gathered back to host memory
%       obj:      per-epoch perplexities, obj.validPPL and obj.testPPL

% Accept the field spellings used by earlier versions of this help text.
if ~isfield(para, 'NNsize') && isfield(para, 'NNSize')
    para.NNsize = para.NNSize;
end
if ~isfield(para, 'batchSize') && isfield(para, 'batchsize')
    para.batchSize = para.batchsize;
end

%############ Init
nLayers = length(para.NNsize);
for iLayer = 1 : nLayers-1
    % Weights uniform in [-r, r], r = 0.5*sqrt(6/(fan_in + fan_out)); zero biases.
    initRange = 0.5*sqrt(6/(para.NNsize(iLayer) + para.NNsize(iLayer+1)));
    net.w{iLayer} = (gpuArray.rand(para.NNsize(iLayer),para.NNsize(iLayer+1))*2-1) * initRange;
    net.b{iLayer} = gpuArray.zeros(1,para.NNsize(iLayer+1));
end

%%%% CPU to GPU
net.codeBook = gpuArray(D.codeBook);
% Drop the local host copy. (`clear D.codeBook` cannot remove a struct field.)
D = rmfield(D, 'codeBook');

%############# Train
% keep_id is linearly indexed below, so count its elements (works for row or column).
trainSize = numel(D.train_keep_id);
nBatches = floor(trainSize / para.batchSize);
% Column of per-row offsets 0, K, 2K, ... where K is the output-layer size.
SHIFT_VECTOR = ((0:para.batchSize-1) * para.NNsize(end))';

halfStep = 1;
halfFlag = 0;
curEpoch = 0;
if ~exist('./tmp', 'dir')
    mkdir('./tmp');
end
while halfStep < 7
    tic
    curEpoch = curEpoch + 1;
    % Once halving has started, halve the learning rate every epoch.
    if halfFlag
        para.lrate = para.lrate/2;
        halfStep = halfStep + 1;
    end

    % Shuffle, then train on full minibatches (a trailing partial batch is skipped).
    randNdx = randperm(trainSize);
    x = zeros(para.batchSize, para.n_words);
    for batch = 1 : nBatches
        index = randNdx(1+(batch-1)*para.batchSize : batch*para.batchSize);
        for j = 1 : para.n_words
            x(:,j) = D.train_total_id(D.train_keep_id(index)+j-1);
        end
        t = D.train_total_id(D.train_keep_id(index)+para.n_words);
        net = DoBackProp_GPU(x, t, net, para, SHIFT_VECTOR);
    end

    %%%% validation and test perplexity after this epoch
    %obj.trainPPL(curEpoch) = Evaluate_obj(net, D.train_total_id, D.train_keep_id, para);
    obj.validPPL(curEpoch) = Evaluate_obj(net, D.valid_total_id, D.valid_keep_id, para);
    obj.testPPL(curEpoch)  = Evaluate_obj(net, D.test_total_id, D.test_keep_id, para);
    trainTime = toc;
    fprintf('curEpoch=%d, lrate=%f,  Time = %.2fs, validPPL=%.2f, testPPL=%.2f\n', curEpoch, para.lrate, trainTime,obj.validPPL(curEpoch), obj.testPPL(curEpoch));

    % Roll back to the previous epoch's checkpoint and start halving when the
    % validation perplexity first increases, or whenever it becomes NaN.
    % (A NaN never satisfies the '>' test, so at most one condition holds.)
    increased = ~halfFlag && curEpoch > 1 && obj.validPPL(curEpoch) > obj.validPPL(curEpoch-1);
    diverged  = isnan(obj.validPPL(curEpoch));
    if increased || diverged
        if curEpoch <= 1
            error('RL_DNN_LM:diverged', ...
                'Validation perplexity is NaN at epoch %d and no earlier checkpoint exists.', curEpoch);
        end
        curEpoch = curEpoch - 1;
        ckpt = load(sprintf('./tmp/epoch%d.mat',curEpoch), 'net', 'obj', 'para');
        net  = ckpt.net;
        obj  = ckpt.obj;
        para = ckpt.para;
        halfFlag = 1;
    end
    % para.lrate = para.lrate * 0.6;
    save(sprintf('./tmp/epoch%d.mat',curEpoch),'net','obj','para');
end

codeBook = gather(net.codeBook);
%Best_trainPPL = min(obj.trainPPL);
%Best_validPPL = min(obj.validPPL);
% NOTE(review): this is the minimum over all epochs, i.e. the epoch is chosen
% on the test set; the test PPL at argmin(obj.validPPL) is the unbiased figure.
Best_testPPL = min(obj.testPPL);
fprintf('The best perplexity: test= %.2f\n', Best_testPPL);
end
