import torch
import torch.nn as nn
from tqdm import tqdm
from torch.autograd import Variable
from batch_generator import *
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import time
from itertools import chain
from custom_rnn import *

# Hyper Parameters
embedding_size = 300   # word-embedding dimension (also the conv kernel width)
filter_size = 100      # feature maps per convolution kernel height
num_classes = 2
batch_size = 128       # examples per optimizer step
iter_size = 128        # examples per forward/backward pass
# Number of mini-batches whose gradients are accumulated before each
# optimizer step; batch_size should be a multiple of iter_size.
iter_num = batch_size // iter_size
num_epochs = 2
# One independent training run is performed per entry.
init_learning_rate = [0.001, 0.001, 0.001]
# Fall back to the CPU when CUDA is unavailable instead of failing later on
# the first .to(device) call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Alternative dataset files (imdb_*.pkl):
# train_filename='imdb_tr.pkl'
# test_filename='imdb_te.pkl'
# dic_filename='imdb_dic.pkl'

def get_maxlen(df_name):
    """Return the largest per-row total of the ``'len'`` column of a pickled frame.

    Each entry of ``'len'`` is summed (it presumably holds a list of segment
    lengths per document -- TODO confirm), and the maximum of those sums over
    all rows is returned.

    Args:
        df_name: Path handed to ``load_file``.
    """
    row_lengths = load_file(df_name)['len'].tolist()
    return max(sum(row) for row in row_lengths)

# Dataset, vocabulary and word-vector files.
train_filename = 'data/text_classification/yelp_p/allen/train.pkl'
test_filename = 'data/text_classification/yelp_p/allen/test.pkl'
dic_filename = 'data/text_classification/yelp_p/allen/dictionary.pkl'
# Word-vector file handed to Word_embedding (a `None` alternative was
# previously left commented out here).
glove_filename = 'data/text_classification/yelp_p/allen/glove.pkl'

# Loaded in the same order as before: dictionary, train frame, test frame.
dic, trdf, tedf = (
    load_file(path) for path in (dic_filename, train_filename, test_filename)
)
tr = BucketedDataIterator(trdf, dic, 20, char_embedding=False, max_len=1000)
# The test iterator reuses the training iterator's OOV index so both map
# unseen words to the same id.
te = BucketedDataIterator(
    tedf, dic, char_embedding=False, oov_index=tr.oov_index, max_len=1000
)
# Presumably ids run 0..oov_index inclusive, hence the +1 -- TODO confirm.
vocab_size = tr.oov_index + 1
# get_maxlen(train_filename) is available to inspect document lengths; the
# iterators above cap lengths via max_len.


def adjust_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group in ``optimizer``.

    Args:
        optimizer: A ``torch.optim`` optimizer; it is modified in place.
        lr: The new learning rate applied to all groups.

    Returns:
        The same optimizer object, for call chaining.
    """
    new_lr = lr
    for group in optimizer.param_groups:
        group.update({'lr': new_lr})
    return optimizer


# Convolutional text classifier. Historically named ``RNN``; the name is kept
# so existing callers keep working.
class RNN(nn.Module):
    """Multi-kernel CNN over word embeddings with max-over-time pooling.

    Token ids are embedded, each kernel height in ``self.Ks`` is applied as a
    2-D convolution spanning the full embedding width, every feature map is
    ReLU-activated and max-pooled over the time axis, and the concatenated
    pooled features are projected to class logits.

    Args:
        embed_size: Embedding dimension; also the width of every conv kernel.
        num_classes: Number of output logits.
        conv_channels: Feature maps produced per kernel height.

    Note:
        The embedding layer is built from the module-level ``vocab_size``,
        ``glove_filename`` and ``dic`` globals.
    """

    def __init__(self, embed_size, num_classes, conv_channels=100):
        super(RNN, self).__init__()
        self.embed_size = embed_size
        self.conv_channels = conv_channels
        self.embed = Word_embedding(embed_size, vocab_size, glove_filename, dic)
        # Kernel heights (in tokens) of the parallel convolutions.
        self.Ks = [1, 3, 5]
        self.convs1 = nn.ModuleList(
            [nn.Conv2d(1, self.conv_channels, (K, embed_size)) for K in self.Ks])
        for conv in self.convs1:
            # nn.init functions operate on parameters directly (under no_grad).
            nn.init.kaiming_normal_(conv.weight)

        self.fc = nn.Linear(len(self.Ks) * self.conv_channels, num_classes)
        nn.init.kaiming_normal_(self.fc.weight)

    def conv_and_pool(self, x, conv):
        """Apply ``conv`` + ReLU to ``x`` and max-pool each map over time.

        Args:
            x: Input of shape (batch, 1, seq_len, embed_width).
            conv: One of ``self.convs1``.

        Returns:
            Tensor of shape (batch, conv.out_channels).
        """
        x = F.relu(conv(x)).squeeze(3)  # (batch, channels, seq_len - K + 1)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def get_filters(self):
        """Return the first convolution's weights with singleton dims squeezed."""
        print('get filters')
        filters = self.convs1[0].weight.squeeze()
        return filters

    def forward(self, x, seq_lengths):
        """Return class logits for a batch of token ids.

        Args:
            x: Token-id tensor of shape (batch, seq_len).
            seq_lengths: Unused; accepted for interface compatibility with the
                training loop.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.embed(x)  # (batch, seq_len, embed_width)
        # The widest kernel needs at least max(Ks) time steps, so right-pad
        # short sequences with zeros along the time axis. Padding x itself
        # keeps its device and dtype instead of relying on a global device.
        shortfall = max(self.Ks) - x.size(1)
        if shortfall > 0:
            x = F.pad(x, (0, 0, 0, shortfall))
        x = x.unsqueeze(1)  # (batch, 1, seq_len, embed_width)

        pooled = [self.conv_and_pool(x, conv) for conv in self.convs1]
        features = torch.cat(pooled, 1)  # (batch, len(Ks) * conv_channels)
        return self.fc(features)


# Train the model once per entry of ``init_learning_rate`` and record, for each
# run, the best test accuracy (in percent) reached across its epochs.
max_accuraries = []
for lr in init_learning_rate:
    # Restart both iterators' epoch counters for this run.
    tr.epochs = 0
    te.epochs = 0
    current_epochs = 0
    learning_rate = lr
    best_acc = 0.0
    rnn = RNN(embedding_size, num_classes, conv_channels=filter_size)
    rnn.to(device)
    # Parameter count of the model excluding the embedding layer.
    print(count_parameters(rnn) - count_parameters(rnn.embed))
    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
    while tr.epochs < num_epochs:
        t = time.time()
        adjust_learning_rate(optimizer, learning_rate)
        rnn.train()
        cnt = 0
        # Presumably tr.next_batch advances tr.epochs once the training data is
        # exhausted, which ends this phase -- TODO confirm in BucketedDataIterator.
        while current_epochs == tr.epochs:
            cnt += 1
            pbar = tqdm(range(100))
            loss_eval = 0
            total = 0
            correct = 0
            mse = 0
            for i in pbar:
                optimizer.zero_grad()
                # Accumulate gradients over iter_num mini-batches of iter_size
                # examples before a single optimizer step.
                for _ in range(iter_num):
                    x, lengths, labels = tr.next_batch(iter_size)
                    x = torch.from_numpy(x).long().to(device)
                    labels = torch.from_numpy(labels).to(device)
                    lengths = torch.from_numpy(np.array(lengths)).to(device)
                    outputs = rnn(x, lengths)
                    predicted = outputs.detach().argmax(1)
                    loss = criterion(outputs, labels)
                    loss_eval += loss.item()
                    # Average over the accumulated mini-batches so the gradient
                    # magnitude (and hence the clip threshold) does not depend
                    # on iter_num.
                    (loss / iter_num).backward()
                    total += labels.size(0)
                    correct += (predicted == labels).sum()
                    mse += ((predicted - labels) ** 2).sum()

                torch.nn.utils.clip_grad_norm_(rnn.parameters(), 5.0)
                optimizer.step()
                pbar.set_description("accuracy : %f mse : %f loss : %f epoch : %d iter : %d" % (
                    100 * float(correct) / float(total), float(mse) / float(total),
                    loss_eval / ((float(i) + 1) * iter_num), (current_epochs + 1),
                    cnt))
                if current_epochs < tr.epochs:
                    pbar.close()
                    break
        # Decay the learning rate tenfold after every epoch.
        learning_rate = learning_rate * 0.1

        # Evaluate on one full pass over the test iterator.
        rnn.eval()
        total = 0
        correct = 0
        mse = 0
        pbar = tqdm()
        # Gradients are not needed for evaluation; avoid building the graph.
        with torch.no_grad():
            while current_epochs == te.epochs:
                for _ in range(iter_num):
                    x, lengths, labels = te.next_batch(iter_size)
                    x = torch.from_numpy(x).long().to(device)
                    lengths = torch.from_numpy(np.array(lengths)).to(device)
                    labels = torch.from_numpy(labels).to(device)
                    predicted = rnn(x, lengths).argmax(1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum()
                    mse += ((predicted - labels) ** 2).sum()
                pbar.set_description(
                    "test accuracy : %f test MSE : %f time spent %f" % (
                    100 * float(correct) / float(total), float(mse) / float(total), time.time() - t))

        pbar.close()
        acc = 100 * float(correct) / float(total)
        best_acc = max(best_acc, acc)
        current_epochs += 1
    max_accuraries.append(best_acc)
