function [net, codeBook, obj] = FOFE_RL_DNN_LM(D, para)
%FOFE_RL_DNN_LM Train a FOFE-based feed-forward neural-network language model on the GPU.
%
%   [net, codeBook, obj] = FOFE_RL_DNN_LM(D, para)
%
% Inputs
%   D : corpus struct
%       D.train_sent_ind, D.train_sent_len : training sentences (cell array of
%                                            word-index vectors) and their lengths
%       D.valid_sent_ind, D.valid_sent_len : validation sentences and lengths
%       D.test_sent_ind,  D.test_sent_len  : test sentences and lengths
%       D.codeBook                          : code book (word vectors); copied to
%                                             the GPU as net.codeBook
%       D.codeBookSize                      : vocabulary size (not read here;
%                                             typically the last entry of para.NNsize)
%   para : hyper-parameter struct (example values in parentheses)
%       para.lrate     : initial learning rate (0.4), consumed by DoBackProp_GPU;
%                        halved every epoch once validation perplexity first rises
%       para.batchSize : number of predicted words per minibatch (200)
%       para.vecSize   : word-vector size (200)
%       para.n_words   : number of context words per predicted position (1)
%       para.NNsize    : layer sizes, e.g. [para.vecSize*para.n_words 400 400 D.codeBookSize]
%       para.alpha     : FOFE forgetting factor; the forgetting matrix holds alpha^(i-j)
%       para.maxEpoch  : (optional) hard cap on the number of epochs, default Inf
%
% Outputs
%   net      : trained network (net.w{i}, net.b{i}, net.codeBook, all gpuArrays)
%   codeBook : CPU copy of net.codeBook
%   obj      : obj.validPPL / obj.testPPL, per-epoch perplexities from Evaluate_obj
%
% Side effects
%   Writes ./tmp/epoch<k>.mat (net, obj, para) after every epoch and reloads the
%   previous checkpoint when validation perplexity rises or becomes NaN.

%% ------------------------------------------------------------------ Init
nLayers = length(para.NNsize);
for iLayer = 1 : nLayers-1
    fanIn  = para.NNsize(iLayer);
    fanOut = para.NNsize(iLayer+1);
    % Uniform weights in [-r, r], r = 0.5*sqrt(6/(fanIn+fanOut)); zero biases.
    r = 0.5 * sqrt(6 / (fanIn + fanOut));
    net.w{iLayer} = (gpuArray.rand(fanIn, fanOut)*2 - 1) * r;
    net.b{iLayer} = gpuArray.zeros(1, fanOut);
end

% Move the code book to the GPU. `clear D.codeBook` does not remove a struct
% field, so use rmfield (the caller's D is unaffected either way).
net.codeBook = gpuArray(D.codeBook);
D = rmfield(D, 'codeBook');

% Lower-triangular forgetting matrix: F(i,j) = alpha^(i-j) for j <= i, else 0.
% Built vectorised on the CPU (instead of millions of scalar GPU writes) and
% sized to cover at least one full minibatch so training never indexes past it.
maxLen = max(2000, para.batchSize);
[I, J] = ndgrid(1:maxLen, 1:maxLen);
forgetting_matrix = gpuArray(tril(para.alpha .^ (I - J)));
clear I J;

%% ------------------------------------------------------------------ Train
totalSent = length(D.train_sent_len);
% Offsets 0, V, 2V, ... (V = output-layer size), one per minibatch row; passed
% through to DoBackProp_GPU (presumably for linear indexing of the targets).
SHIFT_VECTOR = ((0:para.batchSize-1) * para.NNsize(end))';

if isfield(para, 'maxEpoch')
    maxEpoch = para.maxEpoch;
else
    maxEpoch = Inf;
end

if ~exist('./tmp', 'dir')
    mkdir('./tmp');
end

obj      = struct('validPPL', [], 'testPPL', []);
halfStep = 1;
halfFlag = 0;
curEpoch = 0;   % incremented at the top of each epoch, so the first epoch is 1

% Train until the learning rate has been halved 6 times (or maxEpoch is hit).
while halfStep < 7 && curEpoch < maxEpoch
    epochTimer = tic;
    curEpoch = curEpoch + 1;
    if halfFlag
        para.lrate = para.lrate / 2;
        halfStep = halfStep + 1;
    end

    % One pass over the shuffled training sentences.
    randNdx = randperm(totalSent);
    sent_st = 1;
    while sent_st <= totalSent
        M             = gpuArray.zeros(para.batchSize, para.batchSize);
        row_st        = 1;
        x             = [];
        t             = [];
        cur_batchSize = 0;

        % Pack whole sentences into the batch until batchSize positions are filled.
        while cur_batchSize < para.batchSize && sent_st <= totalSent
            sentIdx  = randNdx(sent_st);
            words    = D.train_sent_ind{sentIdx};
            % Number of predicted positions; clamped so a sentence shorter than
            % n_words cannot move row_st backwards and overlap blocks in M.
            full_len = max(D.train_sent_len(sentIdx) - para.n_words, 0);
            cur_len  = full_len;
            cur_batchSize = cur_batchSize + cur_len;
            if cur_batchSize > para.batchSize
                % Truncate so the batch is exactly batchSize rows.
                cur_len = cur_len - (cur_batchSize - para.batchSize);
            end

            % Block-diagonal M: each sentence gets its own forgetting block.
            rows = row_st : row_st+cur_len-1;
            M(rows, rows) = forgetting_matrix(1:cur_len, 1:cur_len);
            row_st = row_st + cur_len;

            % Context words: column j holds the (j-1)-shifted word sequence.
            tmp_x = zeros(cur_len, para.n_words);
            for j = 1 : para.n_words
                ctx = words(j : cur_len+j-1);
                tmp_x(:, j) = ctx(:);
            end
            x = [x; tmp_x]; %#ok<AGROW>

            % Targets: the word following each context window.
            target = words(para.n_words+1 : cur_len+para.n_words);
            t = [t; target(:)]; %#ok<AGROW>
            sent_st = sent_st + 1;
        end

        % The last sentence overflowed: its truncated prefix fills this batch.
        % If the whole sentence fits in a batch of its own, replay it so it starts
        % the next batch rather than losing its tail.
        if cur_batchSize > para.batchSize && full_len <= para.batchSize
            sent_st = sent_st - 1;
        end

        % Only full batches are trained (M and SHIFT_VECTOR are sized for exactly
        % batchSize rows); a trailing partial batch at epoch end is skipped.
        if cur_batchSize >= para.batchSize
            net = DoBackProp_GPU(M, x, t, net, para, SHIFT_VECTOR);
        end
    end

    % Per-epoch perplexities (training-set PPL is not computed; too costly).
    obj.validPPL(curEpoch) = Evaluate_obj(net, forgetting_matrix, D.valid_sent_ind, D.valid_sent_len, para);
    obj.testPPL(curEpoch)  = Evaluate_obj(net, forgetting_matrix, D.test_sent_ind,  D.test_sent_len,  para);
    trainTime = toc(epochTimer);
    fprintf('curEpoch=%d, lrate=%f,  Time = %.2fs, testPPL=%.2f\n', curEpoch, para.lrate, trainTime, obj.testPPL(curEpoch));

    % Roll back to the previous checkpoint and start halving the learning rate
    % the first time validation perplexity rises, or whenever it becomes NaN.
    worse = ~halfFlag && curEpoch > 1 && obj.validPPL(curEpoch) > obj.validPPL(curEpoch-1);
    if worse || isnan(obj.validPPL(curEpoch))
        if curEpoch <= 1
            error('FOFE_RL_DNN_LM:noCheckpoint', ...
                  'Validation perplexity is NaN after the first epoch; no checkpoint to roll back to.');
        end
        curEpoch = curEpoch - 1;
        ckpt = load(sprintf('./tmp/epoch%d.mat', curEpoch), 'net', 'obj', 'para');
        net  = ckpt.net;
        obj  = ckpt.obj;
        para = ckpt.para;
        halfFlag = 1;
    end
    save(sprintf('./tmp/epoch%d.mat',curEpoch),'net','obj','para');
end

codeBook = gather(net.codeBook);
Best_testPPL = min(obj.testPPL);
fprintf('The best perplexity: test= %.2f\n', Best_testPPL);
end
