import torch
import torch.nn as nn
from tqdm import tqdm
from torch.autograd import Variable
from batch_generator import *
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import time
from itertools import chain
from custom_rnn import *

# Hyper Parameters
embedding_size = 300
affine_size = 150
maxlen = 300
channel_size = 75
compress_size = 150
kernel_size = 3
layer_size = 5
num_classes = 10
batch_size = 128
iter_size = 128
# Number of iter_size-sized batches whose gradients are accumulated before
# each optimizer step.
iter_num = batch_size // iter_size
num_epochs = 2
# One independent training run is executed per entry of this list.
init_learning_rate = [0.001, 0.001, 0.001]
# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the script does not fail at the first `.to(device)` call on a
# machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# train_filename='imdb_tr.pkl'
# test_filename='imdb_te.pkl'
# dic_filename='imdb_dic.pkl'

def get_maxlen(df_name):
    """Return the largest summed entry of the ``len`` column of a data file.

    The file named *df_name* is read with ``load_file``. Each entry of its
    ``len`` column is reduced with ``sum`` and the maximum of those sums is
    returned.
    """
    frame = load_file(df_name)
    per_row_totals = map(sum, frame['len'].tolist())
    return max(per_row_totals)

# Locations of the input files consumed below.
train_filename = 'data/text_classification/yahoo/allen/train.pkl'
test_filename = 'data/text_classification/yahoo/allen/test.pkl'
dic_filename = 'data/text_classification/yahoo/allen/dictionary.pkl'
# glove_filename = None
glove_filename = 'data/text_classification/yahoo/allen/glove.pkl'

# Load the dictionary and both data splits, then wrap each split in an iterator.
dic = load_file(dic_filename)
trdf = load_file(train_filename)
tedf = load_file(test_filename)
tr = BucketedDataIterator(
    trdf,
    dic,
    20,
    char_embedding=False,
    max_len=maxlen,
)
# The test iterator reuses the out-of-vocabulary index chosen by the training iterator.
te = BucketedDataIterator(
    tedf,
    dic,
    char_embedding=False,
    oov_index=tr.oov_index,
    max_len=maxlen,
)
# Vocabulary size is one past the out-of-vocabulary index.
vocab_size = tr.oov_index + 1
# maxlen = get_maxlen(train_filename)


def adjust_learning_rate(optimizer, lr):
    """Set the ``'lr'`` entry of every parameter group of *optimizer* to *lr*.

    The optimizer is modified in place and also returned, so the call can be
    chained.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
    return optimizer


class DenseNet(nn.Module):
    """Densely connected stack of 1-D convolutions.

    A bias-free 1x1 convolution first maps ``word_emb_dim`` input channels to
    ``affine_filter`` channels. Each of the ``layer_size`` following layers is
    a bias-free ``Conv1d`` + ``LeakyReLU`` that reads the concatenation of the
    affine output and all earlier layer outputs and emits ``filter_size`` new
    channels.
    """

    def __init__(self, word_emb_dim, affine_filter, filter_size, kernel_size,layer_size):
        super(DenseNet, self).__init__()
        self.padding = int((kernel_size - 1) / 2)
        self.layer_size = layer_size

        # Point-wise projection of the input channels.
        self.affine = nn.Sequential(
            nn.Conv1d(word_emb_dim, affine_filter, kernel_size=1, padding=0, bias=False),
            nn.LeakyReLU(inplace=True),
        )

        # Layer `depth` sees the affine output plus `depth` earlier layer outputs.
        self.convs = nn.ModuleList()
        for depth in range(layer_size):
            in_channels = affine_filter + depth * filter_size
            stage = nn.Sequential(
                nn.Conv1d(in_channels, filter_size, kernel_size=kernel_size,
                          padding=self.padding, bias=False),
                nn.LeakyReLU(inplace=True),
            )
            self.convs.append(stage)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every convolution weight with ``kaiming_normal_``."""
        nn.init.kaiming_normal_(self.affine[0].weight.data)
        conv_layers = (
            module
            for stage in self.convs
            for module in stage
            if isinstance(module, nn.Conv1d)
        )
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight.data)

    def forward(self, embed):
        """Run the dense stack on *embed*.

        Returns a pair ``(layer_outputs, features)``: the list of the
        individual layer outputs, and the channel-wise (dim 1) concatenation
        of the affine output with all of them.
        """
        features = self.affine(embed)
        layer_outputs = []
        for stage in self.convs:
            fresh = stage(features)
            layer_outputs.append(fresh)
            features = torch.cat([features, fresh], dim=1)
        return layer_outputs, features


# DenseNet-based convolutional classifier (many-to-one)
class DSA(nn.Module):
    """DenseNet-based sequence classifier.

    Pipeline (see ``forward``): the input is embedded by ``Word_embedding``,
    zero-padded along the length axis up to ``maxlen``, passed through
    ``DenseNet``, reduced to ``filter_size`` channels, masked with
    ``mask_lengths`` and finally flattened into a two-layer classifier.

    How the dense features are reduced to ``filter_size`` channels depends on
    ``att``:

    * ``att=False`` (default): a bias-free 1x1 convolution + LeakyReLU over
      the concatenation of the affine output and all dense layer outputs.
    * ``att=True``: per position, a softmax over the ``layer_size`` dense
      layers weights a sum of the individual layer outputs.

    The embedding layer is built from the module-level ``vocab_size``,
    ``glove_filename`` and ``dic``.

    Args:
        embed_size: channel size produced by the embedding layer.
        affine_size: output channels of the DenseNet's 1x1 input projection.
        filter_size: channels emitted by every dense layer (and by the
            reduction step).
        compress_size: stored on the instance; not otherwise used here.
        layer_size: number of dense layers.
        kernel_size: kernel width of the dense convolutions.
        num_classes: size of the final output.
        maxlen: sequence length the classifier head is sized for.
        att: select the attention reduction instead of the 1x1 convolution.
    """

    def __init__(self, embed_size,affine_size, filter_size,compress_size, layer_size,kernel_size, num_classes,maxlen,att=False):
        super(DSA, self).__init__()
        print('att is',att)
        self.embed_size = embed_size
        self.num_classes = num_classes
        self.filter_size = filter_size
        self.affine_size = affine_size
        self.layer_size = layer_size
        self.kernel_size = kernel_size
        self.compress_size = compress_size
        self.maxlen = maxlen
        self.att = att
        self.embed = Word_embedding(embed_size, vocab_size, glove_filename, dic)
        self.densenet = DenseNet(embed_size, affine_size, filter_size, kernel_size, layer_size)
        if self.att:
            # NOTE: filter_conv is not used by forward(); it is kept so the
            # module's parameters and state_dict keys stay unchanged.
            self.filter_conv = nn.Sequential(
                nn.Conv1d(layer_size * filter_size, layer_size, kernel_size=1, groups=layer_size),
                nn.LeakyReLU())
            # Maps the concatenated layer outputs at one position to one
            # logit per dense layer.
            self.filter_att = nn.Sequential(
                nn.Linear(layer_size * filter_size, filter_size),
                nn.LeakyReLU(),
                nn.Linear(filter_size, layer_size))
        else:
            self.compress = nn.Sequential(
                nn.Conv1d(affine_size + filter_size * layer_size, filter_size, kernel_size=1, bias=False),
                nn.LeakyReLU())

        # Classifier head over the flattened (filter_size x maxlen) feature map.
        self.fc = nn.Sequential(nn.Linear(maxlen * filter_size, filter_size),
                                nn.LeakyReLU(),
                                nn.Linear(filter_size, num_classes))
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-normal initialise the weights and zero the biases.

        Covers whichever reduction branch was built (``filter_conv`` and
        ``filter_att`` or ``compress``) and the classifier head ``fc``. The
        embedding and DenseNet sub-modules are not touched here.
        """
        if self.att:
            for i in self.filter_conv:
                if isinstance(i, nn.Conv1d):
                    nn.init.kaiming_normal_(i.weight)
                    nn.init.constant_(i.bias, val=0)
            for i in self.filter_att:
                if isinstance(i, nn.Linear):
                    nn.init.kaiming_normal_(i.weight)
                    nn.init.constant_(i.bias, val=0)
        else:
            # compress[0] is bias-free, so only the weight is initialised.
            nn.init.kaiming_normal_(self.compress[0].weight)
        for i in self.fc:
            if isinstance(i, nn.Linear):
                nn.init.kaiming_normal_(i.weight)
                nn.init.constant_(i.bias, val=0)

    def forward(self, x, seq_lengths):
        """Classify a batch.

        Args:
            x: input handed to the embedding layer; dim 0 is the batch.
                Assumes the embedding returns (batch, length, embed_size)
                -- TODO confirm against ``Word_embedding``.
            seq_lengths: per-example lengths, forwarded to ``mask_lengths``
                together with ``maxlen`` to build the position mask.

        Returns:
            Output of the classifier head, shaped (batch, num_classes).

        Inputs longer than ``maxlen`` are not truncated here; the classifier
        head is sized for exactly ``maxlen`` positions.
        """
        batch_size = x.size(0)
        x = self.embed(x)
        length = x.size(1)
        x = x.transpose(1, 2)  # -> (batch, embed_size, length)
        if length < self.maxlen:
            # Right-pad with zeros up to maxlen. The buffer is allocated with
            # the same device and dtype as the embeddings, so forward() does
            # not depend on the module-level `device`.
            padded = x.new_zeros(batch_size, self.embed_size, self.maxlen)
            padded[:, :, :length] = x
            x = padded
        masks = mask_lengths(seq_lengths, self.maxlen)
        # dense_outs: list of per-layer outputs; dense_out: their
        # concatenation with the affine output along the channel axis.
        dense_outs, dense_out = self.densenet(x)

        if self.att:
            # One logit per dense layer at every position: [b, length, layer]
            logits = self.filter_att(torch.cat(dense_outs, 1).transpose(1, 2).contiguous())
            att = torch.softmax(logits, dim=-1).view(batch_size, -1, self.layer_size)
            # [b, channel, length, layer] weighted by [b, 1, length, layer]
            out = torch.stack(dense_outs, -1) * att.unsqueeze(1)
            out = out.sum(dim=-1)
        else:
            out = self.compress(dense_out)

        # Apply the position mask, broadcast over channels: [b, channel, length]
        out *= masks[:, None, :]

        out = self.fc(out.view(batch_size, -1))
        return out


# Train the Model
# Best test accuracy reached during each run (one run per entry of
# init_learning_rate).
max_accuraries = []
# NOTE(review): `np` and `count_parameters` are not imported explicitly in
# this file; presumably they arrive through the star imports -- verify.
for lr in init_learning_rate:
    # Every run starts from fresh epoch counters and a fresh model.
    tr.epochs = 0
    te.epochs = 0
    current_epochs = 0
    learning_rate = lr
    best_acc = 0.0
    rnn = DSA(embedding_size, affine_size, channel_size, compress_size, layer_size, kernel_size, num_classes, maxlen)
    rnn.to(device)
    # Parameter count excluding the embedding layer.
    print(count_parameters(rnn) - count_parameters(rnn.embed))
    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
    while tr.epochs < num_epochs:
        t = time.time()
        adjust_learning_rate(optimizer, learning_rate)
        # The evaluation pass below switches to eval mode; switch back here.
        rnn.train()
        cnt = 0
        # Train in chunks of up to 100 optimizer steps until the training
        # iterator reports that it has moved on to the next epoch.
        while current_epochs == tr.epochs:
            cnt += 1
            pbar = tqdm(range(100))
            loss_eval = 0
            total = 0
            correct = 0
            for i in pbar:

                optimizer.zero_grad()
                # Accumulate gradients over iter_num batches per step.
                for _ in range(iter_num):
                    x, lengths, labels = tr.next_batch(iter_size)
                    x = torch.from_numpy(x).long().to(device)
                    labels = torch.from_numpy(labels).to(device)
                    lengths = torch.from_numpy(np.array(lengths)).to(device)
                    # Forward + Backward
                    outputs = rnn(x, lengths)
                    _, predicted = torch.max(outputs.detach(), 1)
                    loss = criterion(outputs, labels)
                    loss_eval += loss.item()
                    loss.backward()
                    total += labels.size(0)
                    # .item() keeps the running count a plain Python int.
                    correct += (predicted == labels).sum().item()

                torch.nn.utils.clip_grad_norm_(rnn.parameters(), 5.0)
                optimizer.step()
                pbar.set_description("accuracy : %f loss : %f epoch : %d iter : %d" % (
                    100 * float(correct) / float(total), loss_eval / (float(i) + 1), (current_epochs + 1), cnt))
                if current_epochs < tr.epochs:
                    pbar.close()
                    break
        # Decay the learning rate by 10x after every epoch.
        learning_rate = learning_rate * 0.1

        # Evaluation pass over the test iterator: no gradients are needed,
        # so run in eval mode with autograd disabled.
        rnn.eval()
        total = 0
        correct = 0
        pbar = tqdm()
        with torch.no_grad():
            while current_epochs == te.epochs:
                for _ in range(iter_num):
                    x, lengths, labels = te.next_batch(iter_size)
                    x = torch.from_numpy(x).long().to(device)
                    lengths = torch.from_numpy(np.array(lengths)).to(device)
                    labels = torch.from_numpy(labels).to(device)
                    outputs = rnn(x, lengths)
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                pbar.set_description(
                    "test accuracy : %f time spent %f" % (100 * float(correct) / float(total), time.time() - t))
        pbar.close()
        # Guard against an evaluation pass that processed no examples.
        acc = 100 * float(correct) / float(total) if total else 0.0
        best_acc = max(best_acc, acc)

        current_epochs += 1
        # print('Test Accuracy %f%%, time spent %f ' %(100*float(correct) / float(total), time.time()-t))
    max_accuraries.append(best_acc)
# for i in range(num_layers):
#     cell = rnn.lstm.get_cell(i)
#     print(cell.weight_ar.data)
