import torch
import torch.nn as nn
from tqdm import tqdm
from torch.autograd import Variable
from batch_generator import *
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import time
from itertools import chain
from custom_rnn import *

# Hyper Parameters
embedding_size = 300  # word-embedding width handed to RNN (also the conv input channels)
context_size = 600  # width of the latent context vector that generates the filters
filter_size = 100  # output channels of every dynamic convolution
num_classes = 10
batch_size = 128  # samples contributing to one optimizer step
iter_size = 64  # samples per forward/backward pass (micro-batch)
iter_num = batch_size // iter_size  # micro-batches accumulated per optimizer step
num_epochs = 2
init_learning_rate = [0.001, 0.001, 0.001]  # one independent training run per entry
# Must be set before CUDA is first initialised for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Fall back to the CPU when no GPU is visible instead of failing later on
# the first ``.to(device)`` call.
device = "cuda" if torch.cuda.is_available() else "cpu"
# device=torch.device('cuda:1')
# train_filename='imdb_tr.pkl'
# test_filename='imdb_te.pkl'
# dic_filename='imdb_dic.pkl'

def get_maxlen(df_name):
    """Return the largest summed entry of the ``'len'`` column of a stored table.

    The object is read with ``load_file(df_name)``; each entry of its
    ``'len'`` column is an iterable of numbers that is summed, and the
    maximum of those sums is returned.
    """
    frame = load_file(df_name)
    return max(sum(entry) for entry in frame['len'].tolist())

# Locations of the pickled train/test tables, the dictionary and the GloVe
# vectors used by this run.
train_filename='data/text_classification/yahoo/allen/train.pkl'
test_filename='data/text_classification/yahoo/allen/test.pkl'
dic_filename='data/text_classification/yahoo/allen/dictionary.pkl'
# NOTE(review): this None is immediately overwritten by the next line;
# presumably the next line is commented out to train without GloVe — confirm.
glove_filename = None
glove_filename= 'data/text_classification/yahoo/allen/glove.pkl'

# Everything below runs at import time and reads the files from disk.
dic = load_file(dic_filename)
trdf = load_file(train_filename)
tedf = load_file(test_filename)
# The meaning of the positional ``20`` is defined by BucketedDataIterator in
# batch_generator (looks like a bucket count — TODO confirm).
tr = BucketedDataIterator(trdf,dic, 20,char_embedding=False, max_len=1000)
# The test iterator reuses the training iterator's out-of-vocabulary index.
te = BucketedDataIterator(tedf,dic,char_embedding=False,oov_index=tr.oov_index, max_len=1000)
# The OOV index is treated as the last valid id, so the vocabulary has one more entry.
vocab_size = tr.oov_index + 1
# maxlen = get_maxlen(train_filename)


def adjust_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group with ``lr``.

    The optimizer is updated in place and handed back to the caller.
    """
    for group in optimizer.param_groups:
        group.update({'lr': lr})
    return optimizer

class Dynamic_Fc(nn.Module):
    """Linear layer whose weight matrix is produced per sample from a context vector.

    ``fc`` maps the context to ``input_size * num_class`` numbers. Those are
    standardised over the last axis, rescaled to ``sqrt(2 / (input_size *
    num_class))`` and reshaped into one ``[input_size, num_class]`` matrix
    per sample, which is then applied to that sample's input.
    """

    def __init__(self, context_size,input_size, num_class):
        super().__init__()
        self.context_size = context_size
        self.input_size = input_size
        self.num_class = num_class
        n_weights = input_size * num_class
        # Scale given to the generated weights after they are standardised.
        self.std = math.sqrt(2 / n_weights)
        self.fc = nn.Linear(context_size, n_weights)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-normal weights and a zero bias for the weight generator."""
        nn.init.xavier_normal_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self,context,x):
        """Apply the context-generated weights to ``x``.

        ``context`` is fed to ``fc``; ``x`` is multiplied (batched) with the
        resulting ``[batch, input_size, num_class]`` weights, giving
        ``[batch, num_class]``.
        """
        generated = self.fc(context)  # [b, i*c]
        centred = generated - generated.mean(-1, keepdim=True)
        rescaled = (centred / generated.std(-1, keepdim=True)) * self.std
        per_sample = rescaled.view(-1, self.input_size, self.num_class)
        return torch.bmm(x.unsqueeze(1), per_sample).squeeze(1)


class Dynamic_Conv(nn.Module):
    """1-D convolution whose filters are generated per sample from a context vector.

    A shared table ``hash_b`` holds ``hash_size`` candidate filters of size
    ``infilter_size * kernel_size``. At construction time every output
    channel is assigned ``rank`` random rows of that table (``hash_f``). For
    each sample, a linear layer turns the context vector into softmax mixing
    weights over those ``rank`` rows; the mixture is the filter of that
    output channel for that sample.

    Args:
        context_size: size of the context vector fed to the filter generator.
        infilter_size: number of input channels.
        outfilter_size: number of output channels.
        kernel_size: width of the convolution window.
        hash_size: number of rows in the shared filter table.
        rank: number of table rows mixed for every output channel.
    """

    def __init__(self, context_size, infilter_size, outfilter_size, kernel_size, hash_size = 20, rank=5):
        super(Dynamic_Conv, self).__init__()
        self.context_size = context_size
        self.infilter_size = infilter_size
        self.outfilter_size = outfilter_size
        self.kernel_size = kernel_size
        self.padding_size = int((kernel_size - 1) / 2)
        # latent_size and std_conv are not read by the current forward pass;
        # they are kept so existing attribute access keeps working.
        self.latent_size = 100
        self.rank = rank
        self.std_conv = math.sqrt(2 / (outfilter_size + infilter_size * kernel_size))
        # Random channel -> table-row assignment. Stored as a buffer so it
        # follows the module across devices and is saved in the state_dict;
        # otherwise a reloaded model would mix different rows than the ones
        # it was trained with.
        hash_f = np.random.randint(0, hash_size, size=(outfilter_size, rank))
        self.register_buffer('hash_f', torch.from_numpy(hash_f).long())
        self.hash_b = nn.Embedding(hash_size, infilter_size * kernel_size)
        self.convs = nn.Sequential(nn.Linear(context_size, outfilter_size * rank))
        self.ln = nn.LayerNorm(outfilter_size)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-normal init for the filter table and generator, zero generator bias."""
        nn.init.kaiming_normal_(self.hash_b.weight,nonlinearity='linear')
        for layer in self.convs:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight)
                nn.init.constant_(layer.bias, val=0)

    def get_filters(self, context):
        """Generate the per-sample filters.

        :param context: [batch, context_size]
        :return: filters [batch, infilter_size * kernel_size, outfilter_size]
        """
        mixing = self.convs(context).view(context.size(0), -1, self.rank)  # [batch, out, rank]
        mixing = F.softmax(mixing, -1)
        table = self.hash_b(self.hash_f)  # [out, rank, in*k]
        filters = torch.bmm(mixing.transpose(0, 1), table)  # [out, batch, in*k]
        return filters.permute(1, 2, 0)  # [batch, in*k, out]

    def forward(self, context, x):
        """
        Convolve ``x`` with the filters generated from ``context``.

        The sequence must be at least ``kernel_size`` long. ``padding_size``
        all-zero windows are added on each side, so the output length equals
        the input length for odd kernel sizes and is one shorter for even
        kernel sizes (>= 2).

        :param context: [batch, context_size]
        :param x: [batch, lens, channel]
        :return: (placeholder, out) -- a constant ``torch.zeros(1, 1, 1)``
            kept for interface compatibility, and the layer-normalised,
            ReLU-activated output [batch, out_len, outfilter_size]
        """
        windows = x.unfold(1, self.kernel_size, 1).contiguous()  # [batch, lens-k+1, in, k]
        if self.padding_size:
            # new_zeros matches the dtype and device of x.
            padding = x.new_zeros(x.size(0), self.padding_size, self.infilter_size, self.kernel_size)
            windows = torch.cat([padding, windows, padding], 1)
        windows = windows.view(windows.size(0), windows.size(1), -1)  # [batch, out_len, in*k]

        out = torch.bmm(windows, self.get_filters(context))  # [batch, out_len, out]
        out = F.relu(self.ln(out))
        return torch.zeros(1,1,1),out


class RNN(nn.Module):
    """Text classifier that convolves word embeddings with context-generated filters.

    The embedded sentence is summarised into one latent vector per sample,
    either by a bidirectional GRU followed by attention pooling
    (``use_lstm=True``; despite the flag name a GRU is built) or by a
    Linear+ReLU applied to the mean embedding. That vector drives one
    ``Dynamic_Conv`` per kernel width in ``Ks``; the convolution outputs are
    max-pooled over time, concatenated and classified by a linear layer.

    Args:
        embed_size: word-embedding width.
        context_size: width of the latent vector (the GRU uses
            ``context_size // 2`` units per direction).
        num_classes: number of output classes.
        filter_size: output channels of each dynamic convolution.
        use_lstm: choose the GRU + attention summary over the mean-pooled one.

    The embedding layer is built from the module-level ``vocab_size``,
    ``glove_filename`` and ``dic``.
    """

    def __init__(self, embed_size,context_size, num_classes, filter_size=100,use_lstm=True):
        super(RNN, self).__init__()
        print('use lstm is',use_lstm)
        self.embed_size = embed_size
        self.filter_size = filter_size
        self.context_size = context_size
        self.embed = Word_embedding(embed_size,vocab_size,glove_filename,dic)
        self.Ks = [3, 4, 5]  # kernel widths, one Dynamic_Conv each
        self.use_lstm = use_lstm
        self.convs = nn.ModuleList()
        if use_lstm:
            self.latent_context = nn.GRU(embed_size, self.context_size //2, 1, batch_first=True, bidirectional=True)
            # Attention direction (normalised in the forward pass) and its learned scale.
            self.self_att_vec = nn.Parameter(torch.Tensor(self.context_size))
            self.att_vec_g = nn.Parameter(torch.Tensor([1.0]))
        else:
            self.latent_context = nn.Sequential(nn.Linear(embed_size, self.context_size), nn.ReLU())
        for kernel in self.Ks:
            self.convs.append(Dynamic_Conv(self.context_size,self.embed_size,self.filter_size,kernel))
        self.fc = nn.Linear(len(self.Ks)*self.filter_size,num_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the attention vector, the GRU and the classifier."""
        if self.use_lstm:
            nn.init.normal_(self.self_att_vec.data, std=1.0 / math.sqrt(self.context_size))
            stdv = math.sqrt(2.0 / (self.embed_size*2))
            stdv_1 = math.sqrt(2.0 / (self.embed_size*3))
            for name, param in self.latent_context.named_parameters():
                if 'bias' in name:
                    nn.init.constant_(param, val=0)
                elif 'weight' in name:
                    # First-layer weights use stdv, any other layer stdv_1.
                    nn.init.normal_(param, 0, stdv if 'l0' in name else stdv_1)
        nn.init.kaiming_normal_(self.fc.weight,nonlinearity='linear')
        nn.init.zeros_(self.fc.bias)

    def _latent(self, embedded, seq_lengths):
        """Summarise embedded tokens [batch, len, embed] into [batch, context_size]."""
        if not self.use_lstm:
            return self.latent_context(embedded.mean(dim=1))
        # pack_padded_sequence requires a lengths tensor to live on the CPU.
        pack_lengths = seq_lengths.cpu() if torch.is_tensor(seq_lengths) else seq_lengths
        packed_input = pack_padded_sequence(embedded, pack_lengths, batch_first=True)
        packed_output, _ = self.latent_context(packed_input, None)
        out_rnn, _ = pad_packed_sequence(packed_output, batch_first=True)
        # Unit-norm attention direction scaled by a learned gain.
        att_vec = self.att_vec_g * (self.self_att_vec / torch.sqrt(torch.sum(self.self_att_vec ** 2)))
        att_vec = att_vec.unsqueeze(0).unsqueeze(0)
        logits = torch.sum(torch.mul(out_rnn, att_vec), -1)  # [b, l]
        # softmax_with_len is a project helper; presumably a softmax masked by length.
        att = softmax_with_len(logits, seq_lengths).unsqueeze(-1)  # [b, l, 1]
        return torch.sum(out_rnn * att, 1)

    def contexualized_filters(self, x, seq_lengths):
        """Return the filters the first convolution generates for a batch.

        :param x: token ids [batch, len]
        :param seq_lengths: true length of every sequence
        :return: [batch, embed_size * Ks[0], filter_size]; the batch axis is
            kept even for a batch of one
        """
        latent = self._latent(self.embed(x), seq_lengths)
        return self.convs[0].get_filters(latent)

    def forward(self, x, seq_lengths):
        """Classify a batch of token-id sequences.

        :param x: token ids [batch, len]
        :param seq_lengths: true length of every sequence
        :return: class logits [batch, num_classes]
        """
        x = self.embed(x)
        batch_size, length, _ = x.size()
        latent = self._latent(x, seq_lengths)
        widest = max(self.Ks)
        if length < widest:
            # Every convolution needs at least kernel_size positions; pad on
            # the same device and dtype as the embeddings.
            padded = x.new_zeros(batch_size, widest, self.embed_size)
            padded[:, :length, :] = x
            x = padded
        # Dynamic_Conv returns (placeholder, output); only the output is used.
        convolutions = [conv(latent, x)[1] for conv in self.convs]
        pooled = [F.max_pool1d(c.transpose(1,2), c.size(1)).squeeze(2) for c in convolutions]
        return self.fc(torch.cat(pooled, 1))

def save_filters(model,batch_generator):
    """Dump the generated filters of 100 batches into the ``filters`` directory.

    For each batch drawn from ``batch_generator`` the filters returned by
    ``model.contexualized_filters`` are written one file per sample:
    ``pos<N>.pkl`` when the sample's label is 0 and ``neg<N>.pkl``
    otherwise, each holding the transposed filter matrix as a numpy array.

    :param model: model exposing ``contexualized_filters(x, lengths)``
    :param batch_generator: iterator exposing ``next_batch(size)`` that
        returns ``(x, lengths, labels)``
    """
    model_device = next(model.parameters()).device
    os.makedirs('filters', exist_ok=True)
    pos_cnt = 0
    neg_cnt = 0
    # Nothing is trained here, so skip building the autograd graph.
    with torch.no_grad():
        for _ in range(100):
            x, lengths, labels = batch_generator.next_batch(iter_size)
            x = torch.from_numpy(x).long().to(model_device)
            lengths = torch.from_numpy(np.array(lengths)).to(model_device)
            filters = model.contexualized_filters(x, lengths)
            for idx, label in enumerate(labels):
                sample = filters[idx].transpose(0,1).cpu().numpy()
                if label == 0:
                    save_file(os.path.join(r'filters','pos'+str(pos_cnt)+'.pkl'), sample)
                    pos_cnt += 1
                else:
                    save_file(os.path.join(r'filters', 'neg' + str(neg_cnt) + '.pkl'), sample)
                    neg_cnt += 1

# Train the Model
# One independent run (fresh model and optimizer) per entry of
# init_learning_rate; the best test accuracy of every run is collected here.
max_accuraries = []
for lr in init_learning_rate:
    tr.epochs = 0
    te.epochs = 0
    current_epochs = 0
    learning_rate = lr
    best_acc = 0.0
    rnn = RNN(embedding_size,context_size, num_classes,filter_size)
    rnn.to(device)
    # Trainable parameter count excluding the embedding table.
    print(count_parameters(rnn)-count_parameters(rnn.embed))
    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
    while tr.epochs < num_epochs:
        t = time.time()
        adjust_learning_rate(optimizer, learning_rate)
        cnt = 0
        rnn.train()
        # tr.epochs is advanced by the iterator once the training data is exhausted.
        while current_epochs == tr.epochs:
            cnt += 1
            pbar = tqdm(range(100))
            loss_eval = 0
            total = 0
            correct = 0
            mse = 0
            for i in pbar:
                optimizer.zero_grad()
                # Gradients of the iter_num micro-batches are accumulated
                # (summed, not averaged) before a single optimizer step.
                for _ in range(iter_num):
                    x, lengths, labels = tr.next_batch(iter_size)
                    x = torch.from_numpy(x).long().to(device)
                    labels = torch.from_numpy(labels).to(device)
                    lengths = torch.from_numpy(np.array(lengths)).to(device)
                    # Forward + Backward
                    outputs = rnn(x,lengths)
                    _, predicted = torch.max(outputs.data, 1)
                    loss = criterion(outputs, labels)
                    loss_eval += loss.item()
                    loss.backward()
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    mse += ((predicted - labels) ** 2).sum().item()

                torch.nn.utils.clip_grad_norm_(rnn.parameters(), 5.0)
                optimizer.step()
                pbar.set_description("accuracy : %f mse : %f loss : %f epoch : %d iter : %d" % (
                    100 * float(correct) / float(total), float(mse) / float(total),
                    loss_eval / ((float(i) + 1) * iter_num), (current_epochs + 1),
                    cnt))
                if current_epochs < tr.epochs:
                    pbar.close()
                    break
        # Decay the learning rate tenfold after every epoch.
        learning_rate = learning_rate * 0.1
        total = 0
        correct = 0
        mse = 0
        pbar = tqdm()
        rnn.eval()
        # Evaluation needs no gradients; skipping them saves memory and time.
        with torch.no_grad():
            while current_epochs == te.epochs:
                for _ in range(iter_num):
                    x, lengths, labels = te.next_batch(iter_size)
                    x = torch.from_numpy(x).long().to(device)
                    lengths = torch.from_numpy(np.array(lengths)).to(device)
                    labels = torch.from_numpy(labels).to(device)
                    outputs = rnn(x,lengths)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    mse += ((predicted - labels) ** 2).sum().item()
                pbar.set_description(
                    "test accuracy : %f test MSE : %f time spent %f" % (
                    100 * float(correct) / float(total), float(mse) / float(total), time.time() - t))
        pbar.close()
        acc = 100 * float(correct) / float(total)
        best_acc = max(best_acc, acc)
        current_epochs += 1
    # NOTE(review): max_accuraries was declared but never filled and acc was
    # discarded; recording the best epoch of each run is the presumed intent.
    max_accuraries.append(best_acc)
# for i in range(num_layers):
#     cell = rnn.lstm.get_cell(i)
#     print(cell.weight_ar.data)
