import torch
import torch.nn as nn
from src.utils.config import DEVICE, UTC

class TimeLSTM(nn.Module):
    """Time-aware LSTM layer that rescales short-term memory at every step.

    At each step the previous cell state ``c`` is split into a short-term
    component ``tanh(W_d(c))`` and a long-term remainder ``c - tanh(W_d(c))``.
    Only the short-term component is multiplied by the step's ``timestamps``
    value before the usual gated LSTM update is applied.

    Args:
        input_size: Size of the last dimension of ``inputs``.
        hidden_size: Size of the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        # Recurrent projection h -> stacked pre-activations (f, i, o, candidate).
        self.W_all = nn.Linear(hidden_size, hidden_size * 4)
        # Input projection x -> stacked pre-activations (f, i, o, candidate).
        self.U_all = nn.Linear(input_size, hidden_size * 4)
        # Extracts the short-term component from the previous cell state.
        self.W_d = nn.Linear(hidden_size, hidden_size)

    def forward(self, inputs, timestamps, reverse=False):
        """Run the recurrence over the time axis of ``inputs``.

        Args:
            inputs: Tensor of shape ``(batch, seq, input_size)``.
            timestamps: 2-D tensor where ``timestamps[:, s:s + 1]`` must be
                expandable to ``(batch, hidden_size)`` for every step ``s``
                (e.g. shape ``(batch, seq)``). NOTE(review): the raw value is
                used as the short-term-memory multiplier, so it is presumably
                an already-computed time-decay weight - confirm with callers.
            reverse: If True, flip the returned sequence along the time axis.
                The recurrence itself always runs from step 0 to ``seq - 1``.
                To process the sequence backwards, reverse ``inputs`` and
                ``timestamps`` before calling.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)`` holding the hidden
            state produced at every step.
        """
        batch_size, seq_len, _ = inputs.size()
        if seq_len == 0:
            # torch.stack rejects an empty list; return an empty sequence.
            return inputs.new_zeros(batch_size, 0, self.hidden_size)

        # Create the initial states on the inputs' device and with their dtype,
        # so the layer also works after .double()/.half() or on any device.
        h = inputs.new_zeros(batch_size, self.hidden_size)
        c = inputs.new_zeros(batch_size, self.hidden_size)

        # The input projection does not depend on the recurrent state, so it
        # is computed for all steps in one matmul instead of once per step.
        input_proj = self.U_all(inputs)

        outputs = []
        for s in range(seq_len):
            c_short = torch.tanh(self.W_d(c))  # short-term memory
            step_weight = timestamps[:, s:s + 1].expand_as(c_short)
            c_long = c - c_short  # long-term memory
            c_adj = c_long + c_short * step_weight  # adjusted previous memory

            gates = self.W_all(h) + input_proj[:, s]
            f, i, o, c_tmp = torch.chunk(gates, 4, 1)
            f = torch.sigmoid(f)
            i = torch.sigmoid(i)
            o = torch.sigmoid(o)
            # NOTE(review): a standard LSTM candidate memory uses tanh here.
            # sigmoid is kept so existing trained weights reproduce the same
            # outputs - confirm whether this is intended.
            c_tmp = torch.sigmoid(c_tmp)

            c = f * c_adj + i * c_tmp
            h = o * torch.tanh(c)
            outputs.append(h)

        if reverse:
            outputs.reverse()
        return torch.stack(outputs, 1)

class RTimeLSTM(nn.Module):
    """Time-aware LSTM variant whose short-term memory uses two per-step weights.

    Identical to a time-aware LSTM step except that the short-term component
    ``tanh(W_d(c))`` of the previous cell state is multiplied by both the
    step's ``timestamps`` value and its ``reach_weights`` value. The long-term
    remainder ``c - tanh(W_d(c))`` is left unscaled.

    Args:
        input_size: Size of the last dimension of ``inputs``.
        hidden_size: Size of the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        # Recurrent projection h -> stacked pre-activations (f, i, o, candidate).
        self.W_all = nn.Linear(hidden_size, hidden_size * 4)
        # Input projection x -> stacked pre-activations (f, i, o, candidate).
        self.U_all = nn.Linear(input_size, hidden_size * 4)
        # Extracts the short-term component from the previous cell state.
        self.W_d = nn.Linear(hidden_size, hidden_size)

    def forward(self, inputs, timestamps, reach_weights, reverse=False):
        """Run the recurrence over the time axis of ``inputs``.

        Args:
            inputs: Tensor of shape ``(batch, seq, input_size)``.
            timestamps: 2-D tensor where ``timestamps[:, s:s + 1]`` must be
                expandable to ``(batch, hidden_size)`` for every step ``s``
                (e.g. shape ``(batch, seq)``). NOTE(review): used directly as a
                short-term-memory multiplier, presumably a precomputed
                time-decay weight - confirm with callers.
            reach_weights: Indexed exactly like ``timestamps`` and multiplied
                into the short-term memory alongside it. NOTE(review): the
                meaning of these weights is not visible here - confirm.
            reverse: If True, flip the returned sequence along the time axis.
                The recurrence itself always runs from step 0 to ``seq - 1``.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)`` holding the hidden
            state produced at every step.
        """
        batch_size, seq_len, _ = inputs.size()
        if seq_len == 0:
            # torch.stack rejects an empty list; return an empty sequence.
            return inputs.new_zeros(batch_size, 0, self.hidden_size)

        # Create the initial states on the inputs' device and with their dtype,
        # so the layer also works after .double()/.half() or on any device.
        h = inputs.new_zeros(batch_size, self.hidden_size)
        c = inputs.new_zeros(batch_size, self.hidden_size)

        # The input projection does not depend on the recurrent state, so it
        # is computed for all steps in one matmul instead of once per step.
        input_proj = self.U_all(inputs)

        outputs = []
        for s in range(seq_len):
            c_short = torch.tanh(self.W_d(c))  # short-term memory
            time_weight = timestamps[:, s:s + 1].expand_as(c_short)
            reach_weight = reach_weights[:, s:s + 1].expand_as(c_short)
            c_long = c - c_short  # long-term memory
            # Adjusted previous memory: only the short-term part is rescaled.
            c_adj = c_long + c_short * time_weight * reach_weight

            gates = self.W_all(h) + input_proj[:, s]
            f, i, o, c_tmp = torch.chunk(gates, 4, 1)
            f = torch.sigmoid(f)
            i = torch.sigmoid(i)
            o = torch.sigmoid(o)
            # NOTE(review): a standard LSTM candidate memory uses tanh here.
            # sigmoid is kept so existing trained weights reproduce the same
            # outputs - confirm whether this is intended.
            c_tmp = torch.sigmoid(c_tmp)

            c = f * c_adj + i * c_tmp
            h = o * torch.tanh(c)
            outputs.append(h)

        if reverse:
            outputs.reverse()
        return torch.stack(outputs, 1)
