import math
import numpy as np

from einops import rearrange, repeat

import torch
from torch import einsum
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl

from utils import corrcoef, rmse, mae, smape

## modules ##

class TemporalAttention(nn.Module):
    """Attention pooling that collapses dimension 1 of its input.

    A small scorer (Linear -> Tanh -> bias-free Linear to one unit) produces
    one scalar per slice along dim 1. The scalars are softmax-normalised over
    dim 1 and used as weights for a weighted sum of the input.
    """

    def __init__(self, in_dim, hidden_dim=128):
        super().__init__()
        scorer_layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1, bias=False),
        ]
        self.project = nn.Sequential(*scorer_layers)

    def forward(self, z):
        """Return the attention-weighted sum of ``z`` over dim 1."""
        scores = self.project(z)
        # Normalise across dim 1 so the weights along that axis sum to one.
        weights = scores.softmax(dim=1)
        pooled = torch.sum(weights * z, dim=1)
        return pooled

def scaled_dot_product_attention(query, key, value, mask=None):
    """Batched scaled dot-product attention.

    Args:
        query: 3-D tensor; ``bmm`` against ``key.transpose(1, 2)`` requires
            matching batch and last dimensions.
        key: 3-D tensor.
        value: 3-D tensor whose dim 1 matches ``key``'s dim 1.
        mask: optional tensor broadcastable to the score matrix; positions
            where ``mask == 0`` are suppressed.

    Returns:
        ``softmax(query @ key^T / sqrt(d)) @ value`` where ``d`` is the last
        dimension of ``query``.
    """
    attn_scores = query.bmm(key.transpose(1, 2))
    if mask is not None:
        # A large negative score drives the softmax weight to (numerically) zero.
        attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
    scale = query.size(-1) ** 0.5
    # NOTE: the weights must stay 3-D. Squeezing dim 0 here would drop the
    # batch axis for a batch of one and make the following bmm fail.
    attn_weights = F.softmax(attn_scores / scale, dim=-1)
    return attn_weights.bmm(value)

class SingleAttentionLayer(nn.Module):
    """One attention head: linear query/key/value projections followed by
    ``scaled_dot_product_attention``.
    """

    def __init__(self, input_dim, k_dim, v_dim):
        super().__init__()
        self.input_dim, self.k_dim, self.v_dim = input_dim, k_dim, v_dim

        # Queries and keys share the projected width k_dim; values use v_dim.
        self.q = nn.Linear(input_dim, k_dim)
        self.k = nn.Linear(input_dim, k_dim)
        self.v = nn.Linear(input_dim, v_dim)

    def forward(self, q_seq, k_seq, v_seq, mask=None):
        """Project the three sequences and attend; ``mask`` is passed through."""
        queries = self.q(q_seq)
        keys = self.k(k_seq)
        values = self.v(v_seq)
        return scaled_dot_product_attention(queries, keys, values, mask)

class MultiHeadAttentionLayer(nn.Module):
    """Runs ``num_heads`` independent ``SingleAttentionLayer`` heads,
    concatenates their outputs on the last axis and projects the result back
    to ``input_dim``.
    """

    def __init__(self, num_heads, input_dim, k_dim, v_dim):
        super().__init__()
        head_list = []
        for _ in range(num_heads):
            head_list.append(SingleAttentionLayer(input_dim, k_dim, v_dim))
        self.heads = nn.ModuleList(head_list)
        # Concatenated head outputs have width num_heads * v_dim.
        self.linear = nn.Linear(num_heads * v_dim, input_dim)

    def forward(self, q_seq, k_seq, v_seq, mask=None):
        """Apply every head to the same inputs and mix the concatenation."""
        head_outputs = []
        for head in self.heads:
            head_outputs.append(head(q_seq, k_seq, v_seq, mask))
        merged = torch.cat(head_outputs, dim=-1)
        return self.linear(merged)

class FeedForward(nn.Module):
    """Two-layer MLP: ``input_dim -> ff_dim -> input_dim`` with a ReLU
    between the two linear layers.
    """

    def __init__(self, input_dim, ff_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, ff_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ff_dim, input_dim)

    def forward(self, x, mask=None):
        """Apply the MLP to ``x``.

        ``mask`` is accepted so the module can be called with the same
        arguments as the attention sublayer; it is not used.
        """
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

class Residual(nn.Module):
    """Wraps a sublayer with dropout, a skip connection and LayerNorm.

    Computes ``LayerNorm(tensors[-1] + Dropout(sublayer(*tensors, mask)))``;
    the last positional tensor is the one carried by the skip connection.
    """

    def __init__(self, sublayer, dim, dropout=0.2):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *tensors, mask=None):
        """Run the sublayer and add its output to the last input tensor.

        ``mask`` is forwarded to the sublayer as its final positional
        argument.
        """
        skip = tensors[-1]
        update = self.sublayer(*tensors, mask)
        return self.norm(skip + self.dropout(update))

class PositionalEncoder(nn.Module):
    """Adds a fixed sinusoidal position table to a scaled input.

    The table has shape ``(1, seq_len, input_dim)`` and is registered as the
    buffer ``pe``. Even columns hold sines and odd columns hold cosines.

    NOTE: the exponent uses ``2 * i`` with ``i`` already being the column
    index, so the frequencies differ from the textbook Transformer formula.
    This is kept as-is so that existing results are unchanged.
    """

    def __init__(self, input_dim, seq_len):
        super().__init__()
        self.input_dim = input_dim

        pe = torch.zeros(seq_len, input_dim)
        for pos in range(seq_len):
            for i in range(0, input_dim, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / input_dim)))
                # With an odd input_dim the last sine column has no cosine
                # partner; skip it instead of indexing out of bounds.
                if i + 1 < input_dim:
                    pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / input_dim)))
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Scale ``x`` by ``sqrt(input_dim)`` and add the first ``x.size(1)``
        rows of the position table (broadcast over dim 0).
        """
        x = x * math.sqrt(self.input_dim)
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len]
        return x

class TransformerEncoderLayer(nn.Module):
    """One encoder layer: multi-head self-attention followed by a
    feed-forward network, each wrapped in ``Residual`` (dropout, skip
    connection, LayerNorm).
    """

    def __init__(self, input_dim, num_heads, ff_dim, dropout):
        super().__init__()
        # Each head gets an equal share of the model width (integer division).
        head_dim = input_dim // num_heads

        multi_head = MultiHeadAttentionLayer(num_heads, input_dim, head_dim, head_dim)
        self.attention = Residual(multi_head, dim=input_dim, dropout=dropout)

        mlp = FeedForward(input_dim, ff_dim)
        self.feed_forward = Residual(mlp, dim=input_dim, dropout=dropout)

    def forward(self, src, mask=None):
        """Self-attend over ``src`` (used as query, key and value), then apply
        the feed-forward sublayer. ``mask`` only reaches the attention step.
        """
        attended = self.attention(src, src, src, mask=mask)
        return self.feed_forward(attended)

class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` ``TransformerEncoderLayer`` modules preceded
    by a fixed ``PositionalEncoder``.

    Args:
        num_layers: number of encoder layers.
        input_dim: model width passed to every layer and to the positional
            encoder.
        num_heads: attention heads per layer.
        ff_dim: hidden width of each layer's feed-forward network.
        seq_len: number of positions in the positional table.
        dropout: dropout probability used inside each layer.
    """

    def __init__(self, num_layers, input_dim, num_heads, ff_dim, seq_len, dropout):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(input_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)
        ])

        self.pe = PositionalEncoder(input_dim, seq_len)

    def forward(self, src, mask=None):
        """Add positional information to ``src`` and run it through every
        layer in order, passing ``mask`` to each.
        """
        src = self.pe(src)

        for layer in self.layers:
            src = layer(src, mask)

        return src

class GEGLU(nn.Module):
    """GELU-gated linear unit: halves the last dimension by multiplying the
    first half of the input with the GELU of the second half.
    """

    def forward(self, x):
        values, gates = torch.chunk(x, 2, dim=-1)
        return values * F.gelu(gates)

class LatentEncoder(nn.Module):
    """Single-head cross-attention in which a shared set of latent vectors
    queries a batch of sequences.

    All three projections are bias-free and are initialised with Xavier-normal
    weights at a very small gain (1e-4).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()

        # Latent queries already live in hidden_dim; keys/values come from input_dim.
        self.dense_q = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.dense_k = nn.Linear(input_dim, hidden_dim, bias=False)
        self.dense_v = nn.Linear(input_dim, hidden_dim, bias=False)
        self.scale = hidden_dim ** -0.5  # 1 / sqrt(hidden_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every parameter with more than one dimension."""
        matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_normal_(weight, gain=0.0001)

    def forward(self, latent_q, k, v):
        """Attend from the latents to ``k``/``v``.

        ``latent_q`` is a 2-D ``(n, d)`` tensor that is replicated along a new
        leading axis of size ``k.size(0)`` before projection. The result has
        one row per latent for every batch element.
        """
        n_batch = k.size(0)
        # Give every batch element its own copy of the shared latents.
        batched_latents = repeat(latent_q, 'n d -> b n d', b=n_batch)

        queries = self.dense_q(batched_latents)
        keys = self.dense_k(k)
        values = self.dense_v(v)

        scores = einsum('b i d, b j d -> b i j', queries, keys) * self.scale
        weights = scores.softmax(dim=-1)
        return einsum('b i j, b j d -> b i d', weights, values)

class LatentDecoder(nn.Module):
    """Single-head cross-attention in which the original sequence queries a
    set of encoded vectors.

    Projections are bias-free and initialised with Xavier-normal weights at a
    very small gain (1e-4). The ``activation`` argument is accepted but not
    used by this class.
    """

    def __init__(self, orig_dim, input_dim, hidden_dim, activation=F.relu):
        super().__init__()

        self.orig_dense_q = nn.Linear(orig_dim, hidden_dim, bias=False)
        self.dense_k = nn.Linear(input_dim, hidden_dim, bias=False)
        self.dense_v = nn.Linear(input_dim, hidden_dim, bias=False)
        self.scale = hidden_dim ** -0.5  # 1 / sqrt(hidden_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every parameter with more than one dimension."""
        matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_normal_(weight, gain=0.0001)

    def forward(self, orig_q, k, v):
        """Return, for every row of ``orig_q``, a softmax-weighted mix of the
        projected ``v`` rows, with weights from scaled query-key dot products.
        """
        queries = self.orig_dense_q(orig_q)
        keys = self.dense_k(k)
        values = self.dense_v(v)

        scores = einsum('b i d, b j d -> b i j', queries, keys) * self.scale
        weights = scores.softmax(dim=-1)
        return einsum('b i j, b j d -> b i d', weights, values)

class LatentCrossAttn(nn.Module):
    """Single-head cross-attention between two sequences that share the same
    input width.

    Projections are bias-free and initialised with Xavier-normal weights at a
    very small gain (1e-4).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()

        self.cross_q = nn.Linear(input_dim, hidden_dim, bias=False)
        self.dense_k = nn.Linear(input_dim, hidden_dim, bias=False)
        self.dense_v = nn.Linear(input_dim, hidden_dim, bias=False)
        self.scale = hidden_dim ** -0.5  # 1 / sqrt(hidden_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every parameter with more than one dimension."""
        matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_normal_(weight, gain=0.0001)

    def forward(self, cross_q, k, v):
        """Attend from ``cross_q`` to ``k``/``v`` and return the weighted
        values, one row per query row.
        """
        queries = self.cross_q(cross_q)
        keys = self.dense_k(k)
        values = self.dense_v(v)

        scores = einsum('b i d, b j d -> b i j', queries, keys) * self.scale
        weights = scores.softmax(dim=-1)
        return einsum('b i j, b j d -> b i d', weights, values)

class GlobalLocalGuidedCrossAttn(nn.Module):
    """Per-step bilinear attention from local features to global features.

    For each step, local features form the queries and that step's global
    features form the keys and values. Scores are ``k @ gl_weight @ q^T``
    scaled by ``1 / sqrt(hidden_dim)``. ``gl_weight`` starts at zero, and
    ``reset_parameters`` is defined but not called by the constructor.
    """

    def __init__(self, global_dim, hidden_dim):
        super().__init__()

        self.QLocal = nn.Linear(hidden_dim, hidden_dim)
        self.KGlobal = nn.Linear(global_dim, hidden_dim)
        self.VGlobal = nn.Linear(global_dim, hidden_dim)
        self.gl_weight = nn.Parameter(torch.zeros(size=(hidden_dim, hidden_dim)))
        self.scale = hidden_dim ** -0.5

    def reset_parameters(self):
        """Xavier-normal (gain 1e-4) init for every parameter with more than
        one dimension.
        """
        matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_normal_(weight, gain=0.0001)

    def forward(self, global_feat, local_feat):
        """Build one attended vector per local row for every step.

        ``global_feat`` is indexed step by step; ``local_feat`` is sliced as
        ``local_feat[:, step, :]``. Per-step results are stacked on dim 1.
        """
        per_step = []
        for step, step_global in enumerate(global_feat):
            # Global features may live on another device or dtype than the model.
            step_global = step_global.to(local_feat.device).float()

            q = F.gelu(self.QLocal(local_feat[:, step, :]))
            k = F.gelu(self.KGlobal(step_global).squeeze(0))
            v = F.gelu(self.VGlobal(step_global).squeeze(0))

            # Rows index global items, columns index local rows.
            scores = torch.matmul(torch.matmul(k, self.gl_weight), q.T) * self.scale
            # Normalise over the global items (dim 0) for each local row.
            weights = F.softmax(scores, dim=0)
            per_step.append(torch.matmul(weights.T, v))
        return torch.stack(per_step, dim=1)

# for text sequences
class IntraDayTemporalEncoder(nn.Module):
    """Fuses features with a learned time embedding supplied per element.

    The time embedding concatenates a periodic part (sin, or cos for any other
    ``func`` value) and a sigmoid part, both bias-free linear maps of the time
    features. It is appended to the features on dim 2 and the result is
    projected to ``hidden_dim``.
    """

    def __init__(self, feat_in_dim, hidden_dim, time_in_dim, time_out_dim=8, func='sin'):
        super().__init__()
        self.func = func
        self.fc_periodic = nn.Linear(time_in_dim, time_out_dim, bias=False)
        self.fc_sigmoid = nn.Linear(time_in_dim, time_out_dim, bias=False)

        # Input width is the features plus both halves of the time embedding.
        self.project = nn.Linear(feat_in_dim + 2 * time_out_dim, hidden_dim, bias=False)

        self.f = torch.sin if func == 'sin' else torch.cos

    def forward(self, x, time_features):
        """Concatenate ``x`` with the time embedding on dim 2 and project."""
        periodic_part = self.f(self.fc_periodic(time_features))
        gated_part = torch.sigmoid(self.fc_sigmoid(time_features))

        time_embedding = torch.cat([periodic_part, gated_part], 2)
        fused = torch.cat([x, time_embedding], dim=2)
        return self.project(fused)

# for numerical and text sequences that are daily
class InterDayTemporalEncoder(nn.Module):
    """Fuses features with a learned time embedding shared across dim 0.

    The time embedding (periodic part + sigmoid part, joined on dim 1) is
    tiled ``x.shape[0]`` times along a new leading axis, appended to the
    features on dim 2 and projected to ``hidden_dim``.
    """

    def __init__(self, feat_in_dim, hidden_dim, time_in_dim, time_out_dim, func='sin'):
        super().__init__()
        self.func = func
        self.fc_periodic = nn.Linear(time_in_dim, time_out_dim, bias=False)
        self.fc_sigmoid = nn.Linear(time_in_dim, time_out_dim, bias=False)

        # Input width is the features plus both halves of the time embedding.
        self.project = nn.Linear(feat_in_dim + 2 * time_out_dim, hidden_dim, bias=False)

        # Any value other than 'sin' selects cosine.
        self.f = torch.sin if func == 'sin' else torch.cos

    def forward(self, x, time_features):
        """Tile the time embedding over ``x``'s leading axis, concatenate on
        dim 2 and project.
        """
        periodic_part = self.f(self.fc_periodic(time_features))
        gated_part = torch.sigmoid(self.fc_sigmoid(time_features))
        time_embedding = torch.cat([periodic_part, gated_part], 1)

        # One copy of the time embedding per leading-axis entry of x.
        tiled = time_embedding.repeat(x.shape[0], 1, 1)
        fused = torch.cat([x, tiled], dim=2)
        return self.project(fused)

class TemporalEncoder(nn.Module):
    """Fuses features with a learned time embedding shared across dim 0.

    A periodic part (sin, or cos for any other ``func`` value) and a sigmoid
    part are computed from the time features by bias-free linear maps, joined
    on dim 1, tiled over ``x``'s leading axis, appended to ``x`` on dim 2 and
    projected to ``hidden_dim``.
    """

    def __init__(self, feat_in_dim, hidden_dim, time_in_dim, time_out_dim, func='sin'):
        super().__init__()
        self.func = func
        self.fc_periodic = nn.Linear(time_in_dim, time_out_dim, bias=False)
        self.fc_sigmoid = nn.Linear(time_in_dim, time_out_dim, bias=False)

        projected_in = feat_in_dim + 2 * time_out_dim
        self.project = nn.Linear(projected_in, hidden_dim, bias=False)

        self.f = torch.sin if func == 'sin' else torch.cos

    def forward(self, x, time_features):
        """Return the projection of ``x`` concatenated with the tiled time
        embedding.
        """
        wave = self.f(self.fc_periodic(time_features))
        gate = torch.sigmoid(self.fc_sigmoid(time_features))

        time_vec = torch.cat([wave, gate], 1)
        # Replicate the embedding for every entry along x's leading axis.
        time_vec_tiled = time_vec.repeat(x.shape[0], 1, 1)
        return self.project(torch.cat([x, time_vec_tiled], dim=2))

class TemporalTransformerEncoder(nn.Module):
    """Transformer encoder stack that uses a learned ``TemporalEncoder`` in
    place of a fixed positional encoding.

    Args:
        num_layers: number of ``TransformerEncoderLayer`` modules.
        feat_in_dim: feature width; also the model width of every layer.
        time_in_dim: width of the raw time features.
        time_out_dim: width of each half of the learned time embedding.
        func: ``'sin'`` or anything else (cosine) for the periodic part.
        num_heads: attention heads per layer.
        ff_dim: hidden width of each layer's feed-forward network.
        seq_len: accepted for signature parity; not used by this class.
        dropout: dropout probability used inside each layer.
    """

    def __init__(self, num_layers, feat_in_dim, time_in_dim, time_out_dim, func, num_heads, ff_dim, seq_len, dropout):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(feat_in_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)
        ])

        # Projects back to feat_in_dim so the layers see the original width.
        self.te = TemporalEncoder(feat_in_dim, feat_in_dim, time_in_dim, time_out_dim, func=func)

    def forward(self, src, time_features):
        """Fuse ``src`` with its time embedding, then apply every layer
        (without a mask).
        """
        # The temporal encoder concatenates a time embedding to the features
        # and projects the result back to the feature width.
        src = self.te(src, time_features)
        for layer in self.layers:
            src = layer(src)

        return src

################
## main model ##
################

class GAME(nn.Module):
    """Multimodal model combining numerical, local-text and global-text
    sequences to forecast moments and a correlation matrix.

    Pipeline (as implemented in ``forward``):
      1. Time-aware encoding of numerical and local-text sequences.
      2. Compression of both sequences onto a shared set of learned latents.
      3. Cross-attention between the two latent sets, fused by a linear layer.
      4. Per-latent bilinear attention across rows, masked by ``adj_mat``.
      5. Decoding back to the original sequence positions, plus attention to
         global news guided by the decoded numerical sequence.
      6. One transformer encoder per stream, attention pooling per stream and
         across streams.
      7. Linear heads for the forecast and for a tanh-bounded similarity
         matrix.

    Notes:
        * ``device`` is accepted by the constructor but not used.
        * ``adj_mat`` is stored as a plain attribute (not a buffer), so it is
          not moved by ``.to()`` and is not part of the ``state_dict``.
        * ``inner_weight`` starts at zero; ``reset_parameters`` is defined but
          not called by the constructor.
        * The text transformer's positional table holds ``window * 10``
          positions.
    """

    def __init__(self, num_nodes, adj_mat, num_latents, num_dim, txt_dim, time_in_dim, time_out_dim, hidden_dim, output_dim,
                num_layers, window, dropout, num_heads, device):
        super().__init__()

        self.window = window
        self.num_nodes = num_nodes
        self.adj_mat = adj_mat
        self.num_latents = num_latents

        # Time-aware input encoders. The text encoder takes one extra time feature.
        self.local_num_te = InterDayTemporalEncoder(num_dim, hidden_dim, time_in_dim, time_out_dim)
        self.local_txt_te = IntraDayTemporalEncoder(txt_dim, hidden_dim, time_in_dim+1, time_out_dim)

        # Latent bottleneck: both modalities are read by the same latents.
        self.num_seq_encoder = LatentEncoder(input_dim=hidden_dim, hidden_dim=hidden_dim)
        self.txt_seq_encoder = LatentEncoder(input_dim=hidden_dim, hidden_dim=hidden_dim)

        self.latents = nn.Parameter(torch.randn(num_latents, hidden_dim))

        self.num_crossattn = LatentCrossAttn(input_dim=hidden_dim, hidden_dim=hidden_dim)
        self.txt_crossattn = LatentCrossAttn(input_dim=hidden_dim, hidden_dim=hidden_dim)

        self.combine_enc = nn.Linear(hidden_dim*2, hidden_dim, bias=False)

        # Shared by the row-to-row attention and by the correlation head.
        self.key_fc = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.query_fc = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # One bilinear form per latent.
        self.inner_weight = nn.Parameter(torch.zeros(size=(num_latents,hidden_dim,hidden_dim)))

        self.num_seq_decoder = LatentDecoder(orig_dim=hidden_dim, input_dim=hidden_dim, hidden_dim=hidden_dim)
        self.txt_seq_decoder = LatentDecoder(orig_dim=hidden_dim, input_dim=hidden_dim, hidden_dim=hidden_dim)

        self.globalattn = GlobalLocalGuidedCrossAttn(global_dim=txt_dim, hidden_dim=hidden_dim)

        self.num_tfms = TransformerEncoder(num_layers=num_layers, input_dim=hidden_dim, 
                                            num_heads=num_heads, ff_dim=hidden_dim*2, seq_len=window, dropout=dropout)

        self.global_tfms = TransformerEncoder(num_layers=num_layers, input_dim=hidden_dim, 
                                            num_heads=num_heads, ff_dim=hidden_dim*2, seq_len=window, dropout=dropout)
        
        self.txt_tfms = TransformerEncoder(num_layers=num_layers, input_dim=hidden_dim, 
                                            num_heads=num_heads, ff_dim=hidden_dim*2, seq_len=window*10, dropout=dropout)


        # Attention pooling: one per stream, then one across the three streams.
        self.num_attfuse = TemporalAttention(in_dim=hidden_dim, hidden_dim=hidden_dim)
        self.global_attfuse = TemporalAttention(in_dim=hidden_dim, hidden_dim=hidden_dim)
        self.txt_attfuse = TemporalAttention(in_dim=hidden_dim, hidden_dim=hidden_dim)

        self.all_attfuse = TemporalAttention(in_dim=hidden_dim, hidden_dim=hidden_dim)

        # Not used in forward().
        self.combine_dec = nn.Linear(hidden_dim*3, hidden_dim, bias=False)
        
        # Heads read the pooled embedding concatenated with `window` raw values.
        self.fc_forecast = nn.Linear((hidden_dim)+window, output_dim, bias=False)

        self.fc_cov = nn.Linear((hidden_dim)+window, hidden_dim, bias=False)

    def reset_parameters(self):
        """Xavier-normal (gain 1e-5) init for every parameter with more than
        one dimension. Not invoked automatically.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p, gain=0.00001)      
    
    def forward(self, num_features, time_features, local_news, local_time, global_news):
        """Run the full model.

        Args:
            num_features: numerical sequences. Assumed to be
                ``(rows, window, num_dim)`` — TODO confirm; the code relies on
                ``num_features[:, :, 0]`` having ``window`` columns.
            time_features: time features for the numerical sequence.
            local_news: local text features.
            local_time: time features for the local text.
            global_news: indexable collection of per-step global text features.

        Returns:
            ``((preds_moments, preds_corr), embed)`` where ``preds_moments`` is
            the ``fc_forecast`` output, ``preds_corr`` is a tanh-bounded
            row-by-row similarity matrix and ``embed`` is the tensor fed to
            both heads.
        """
        device = num_features.device
        latents = self.latents.to(device)

        # 1-2. Time-aware encoding, then compression onto the shared latents.
        enc_num_features = self.local_num_te(num_features, time_features)
        local_num_features = self.num_seq_encoder(latents, enc_num_features, enc_num_features)

        enc_txt_features = self.local_txt_te(local_news, local_time)
        local_txt_features = self.txt_seq_encoder(latents, enc_txt_features, enc_txt_features)

        # 3. Each modality's latents are queried by the other modality's latents.
        cross_num_features = self.num_crossattn(local_txt_features, local_num_features, local_num_features)
        cross_txt_features = self.txt_crossattn(local_num_features, local_txt_features, local_txt_features)

        feat_encoded = F.leaky_relu(self.combine_enc(torch.cat([cross_num_features, cross_txt_features], dim=2)))

        # 4. Swap the first two axes so that bmm batches over latents and the
        #    attention runs across rows.
        key = self.key_fc(feat_encoded).transpose(1,0)
        query = self.query_fc(feat_encoded).transpose(1,0)
        scale  = query.size(-1) ** 0.5

        att_weights = torch.tanh(torch.bmm(torch.bmm(query, self.inner_weight), key.transpose(2,1))/scale)
        # Additive mask: 0 where adj_mat is positive, a large negative value elsewhere.
        adj_mat_mask = torch.where(self.adj_mat>0.0, torch.tensor(0.0), torch.tensor(-1.0e9)).to(device)

        encoded_seq = torch.bmm(F.softmax(att_weights + adj_mat_mask, 2), feat_encoded.transpose(1,0)).transpose(1,0)

        # 5. Decode back to sequence positions; global news attends via the
        #    decoded numerical sequence.
        num_decoded_seq = self.num_seq_decoder(enc_num_features, encoded_seq, encoded_seq) 
        txt_decoded_seq = self.txt_seq_decoder(enc_txt_features, encoded_seq, encoded_seq) 
        global_txt_seq = self.globalattn(global_news, num_decoded_seq)

        # 6. Per-stream transformer, then attention pooling over dim 1.
        num_out = self.num_tfms(num_decoded_seq)
        txt_out = self.txt_tfms(txt_decoded_seq)
        global_out = self.global_tfms(global_txt_seq)

        num_out = self.num_attfuse(num_out)
        txt_out = self.txt_attfuse(txt_out)
        global_out = self.global_attfuse(global_out)

        # Pool across the three streams.
        stacked_out = torch.stack([num_out, txt_out, global_out], dim=1) 
        out = self.all_attfuse(stacked_out)

        # 7. Append the first numerical feature of every step to the embedding.
        concat_out = torch.cat([out, num_features[:,:,0]], dim=1)
        embed = concat_out
        preds_moments = self.fc_forecast(concat_out)

        # Project once and reuse for both the key and the query branch.
        cov_embed = self.fc_cov(concat_out)
        key_cov = self.key_fc(cov_embed)
        query_cov = self.query_fc(cov_embed)
        preds_corr = torch.tanh(torch.matmul(query_cov, key_cov.T)/scale)

        return (preds_moments, preds_corr), embed

class MainModel_GAME_Moments(pl.LightningModule):
    """Lightning wrapper that trains a model to predict the mean, standard
    deviation and correlation of the target series.

    Args:
        model: module called as
            ``model(numerical_features, time_features, local_news, local_time, global_news)``
            and returning ``((moments, corr), embed)``.
        criterion: loss callable applied as ``criterion(prediction, label)``.
        n_epochs: ``T_max`` of the cosine learning-rate schedule.
        learning_rate: Adam learning rate.
    """

    def __init__(self, model, criterion, n_epochs, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.n_epochs = n_epochs
        self.criterion = criterion
        self.model = model
        self.save_hyperparameters()

    @staticmethod
    def _prepare_batch(batch):
        """Cast a raw batch to float and merge its two leading axes.

        Returns ``(inputs, labels)`` where ``inputs`` is the tuple expected by
        ``forward``. The ``prices`` entry of the batch is not used by any step
        and is therefore dropped. ``global_news`` is passed through untouched.
        """
        numerical_features, targets, _prices, time_features, global_news, local_news, local_time = batch

        numerical_features = numerical_features.float().reshape(
            numerical_features.shape[0] * numerical_features.shape[1],
            numerical_features.shape[2],
            numerical_features.shape[3],
        )
        time_features = time_features.float().reshape(
            time_features.shape[0] * time_features.shape[1], time_features.shape[2]
        )
        labels = targets.float().reshape(targets.shape[0] * targets.shape[1], targets.shape[2])
        # reshape (rather than view) also works when the tensors are not contiguous.
        # Axes 0-1 and axes 2-3 are merged pairwise.
        local_news = local_news.float().reshape(
            local_news.shape[0] * local_news.shape[1],
            local_news.shape[2] * local_news.shape[3],
            local_news.shape[4],
        )
        local_time = local_time.float().reshape(
            local_time.shape[0] * local_time.shape[1],
            local_time.shape[2] * local_time.shape[3],
            local_time.shape[4],
        )

        inputs = (numerical_features, time_features, local_news, local_time, global_news)
        return inputs, labels

    def _evaluate_and_log(self, batch, total_name):
        """Shared body of the validation and test steps.

        Logs rmse/mae/smape for the mean, volatility and correlation targets
        and logs the sum of the three rmse values under ``total_name``.
        """
        inputs, labels = self._prepare_batch(batch)
        preds, _ = self(inputs)
        _, label_parts, pred_parts = self.compute_training_loss(preds, labels)

        rmse_terms = []
        for prefix, label, pred in zip(('mean', 'vol', 'corr'), label_parts, pred_parts):
            rmse_loss = rmse(label, pred)
            self.log(f'{prefix}_rmse', rmse_loss, prog_bar=True)
            self.log(f'{prefix}_mae', mae(label, pred))
            self.log(f'{prefix}_smape', smape(label, pred))
            rmse_terms.append(rmse_loss)

        self.log(total_name, rmse_terms[0] + rmse_terms[1] + rmse_terms[2])

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        inputs, labels = self._prepare_batch(batch)

        preds, _ = self(inputs)
        loss, _, _ = self.compute_training_loss(preds, labels)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation metrics; ``val_loss`` is the sum of the three rmse values."""
        self._evaluate_and_log(batch, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Log test metrics; ``test_loss`` is the sum of the three rmse values."""
        self._evaluate_and_log(batch, 'test_loss')

    def forward(self, inputs):
        """Unpack the input tuple and call the wrapped model.

        Returns ``(preds, embed)`` exactly as produced by the model.
        """
        numerical_features, time_features, local_news, local_time, global_news = inputs
        preds, embed = self.model(numerical_features, time_features, local_news, local_time, global_news)
        return preds, embed

    def configure_optimizers(self):
        """Adam with a cosine-annealed learning rate over ``n_epochs``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.n_epochs)
        return [optimizer], [scheduler]

    def compute_training_loss(self, preds, labels):
        """Build the moment targets from ``labels`` and score the predictions.

        Targets are the per-row mean and standard deviation of ``labels`` over
        dim 1 and ``corrcoef(labels)`` (helper from ``utils``; presumably the
        row-by-row correlation matrix — verify against its implementation).
        Column 0 of the forecast is compared with the mean and column 1 with
        the standard deviation.

        Returns:
            ``(loss, (label_mean, label_std, label_corr),
            (pred_mean, pred_std, forecast_corr))`` where ``loss`` is the sum
            of the three criterion values.
        """
        forecast_moments, forecast_corr = preds
        label_mean = labels.mean(1).unsqueeze(-1)
        label_std = labels.std(1).unsqueeze(-1)
        label_corr = corrcoef(labels)

        # Keep a trailing singleton axis so predictions match the label shapes.
        pred_mean = forecast_moments[:, 0].unsqueeze(1)
        pred_std = forecast_moments[:, 1].unsqueeze(1)

        loss_mean = self.criterion(pred_mean, label_mean)
        loss_std = self.criterion(pred_std, label_std)
        loss_corr = self.criterion(forecast_corr, label_corr)

        loss = loss_mean + loss_std + loss_corr
        labels = (label_mean, label_std, label_corr)
        preds = (pred_mean, pred_std, forecast_corr)

        return loss, labels, preds