import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

## for computing moments and losses ##
# metrics adapted from https://gist.github.com/bshishov/5dc237f59f019b26145648e2124ca1c9

# Small constant added to denominators in this module (the standard deviations in
# `corrcoef` and `actual` in `_percentage_error`) to avoid dividing by exactly zero.
EPSILON = 1e-3

def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix, mirroring ``np.cov``.

    Args:
        tensor: Input whose last two axes hold variables and observations.
        rowvar: If true, each row is a variable and each column an observation;
            otherwise the last two axes are swapped before the computation.
        bias: If true, normalise by ``N``; otherwise by ``N - 1``, where ``N`` is
            the number of observations.

    Returns:
        The (conjugate) outer product of the mean-centred data, scaled by the
        normalisation factor.
    """
    if not rowvar:
        tensor = tensor.transpose(-1, -2)

    # Centre every variable around its own mean over the observation axis.
    centered = tensor - tensor.mean(dim=-1, keepdim=True)

    n_observations = centered.shape[-1]
    ddof = 0 if bias else 1
    scale = 1 / (n_observations - ddof)

    # Scale first, then multiply, matching the left-to-right order of `*` and `@`.
    return (scale * centered) @ centered.transpose(-1, -2).conj()

def corrcoef(tensor, rowvar=True):
    """Compute Pearson product-moment correlation coefficients (``np.corrcoef``).

    Args:
        tensor: Input passed straight to ``cov``.
        rowvar: Forwarded to ``cov``; selects whether rows or columns are variables.

    Returns:
        The covariance matrix normalised by the per-variable standard deviations
        and clipped to ``[-1, 1]`` (real and imaginary parts separately for
        complex input). Because ``EPSILON`` is added to every standard deviation,
        the diagonal ends up slightly below 1 for non-constant variables.
    """
    correlation = cov(tensor, rowvar=rowvar)

    # Take the standard deviations before the matrix is normalised in place:
    # `diagonal` is a view of `correlation`, `sqrt` materialises a new tensor.
    variance = correlation.diagonal(0, -1, -2)
    if variance.is_complex():
        variance = variance.real
    stddev = variance.sqrt() + EPSILON

    # Divide rows, then columns, by the matching standard deviation.
    for axis in (-1, -2):
        correlation.div_(stddev.unsqueeze(axis))

    # Clamp each real-valued component in place.
    if correlation.is_complex():
        components = (correlation.real, correlation.imag)
    else:
        components = (correlation,)
    for component in components:
        component.clamp_(-1, 1)

    return correlation

def _error(actual, predicted):
    """ Simple error """
    return actual - predicted

def _percentage_error(actual, predicted):
    """
    Relative error of ``predicted`` with respect to ``actual``.

    Note: the result is a fraction; it is NOT multiplied by 100.
    """
    # EPSILON shifts the denominator away from zero when `actual` is zero.
    # NOTE(review): the shift is not sign-aware, so `actual == -EPSILON` still
    # yields a zero denominator.
    denominator = actual + EPSILON
    return _error(actual, predicted) / denominator

def mse(actual, predicted):
    """Mean Squared Error: the mean of the squared residuals over all elements."""
    residual = _error(actual, predicted)
    return residual.square().mean()

def rmse(actual, predicted):
    """Root Mean Squared Error: the square root of ``mse``."""
    mean_squared = mse(actual, predicted)
    return mean_squared.sqrt()

def mae(actual, predicted):
    """Mean Absolute Error: the mean of the absolute residuals over all elements."""
    residual = _error(actual, predicted)
    return residual.abs().mean()

def smape(actual, predicted):
    """
    Symmetric Mean Absolute Percentage Error.

    Note: the result is a fraction; it is NOT multiplied by 100.
    """
    numerator = 2.0 * torch.abs(actual - predicted)
    denominator = torch.abs(actual) + torch.abs(predicted)

    # Where both values are zero the ratio is 0/0 = NaN; such elements (and any
    # infinities) contribute 0 to the mean instead of being smoothed with EPSILON.
    ratio = torch.nan_to_num(numerator / denominator, nan=0.0, posinf=0.0, neginf=0.0)
    return torch.mean(ratio)

class CustomDataset(Dataset):
    """Windowed dataset pairing a look-back window with a look-ahead horizon.

    Every item is anchored at a row position ``idx`` taken from ``index_list``:
    input features come from rows ``[idx - window, idx)`` and targets/prices from
    rows ``[idx, idx + horizon)``.

    Args:
        num_dict: Mapping from feature key to a DataFrame. After ``reset_index()``
            the first column is dropped and the remaining values of every key in
            ``keys`` are stacked on a new trailing axis.
        target: DataFrame whose first column is dropped (presumably a date
            column - TODO confirm); the remaining columns are the targets.
        price: DataFrame handled exactly like ``target``.
        global_news: Object sliceable along its first (time) axis.
        local_news: 4-D array with time on axis 0; axes 0 and 1 are swapped per item.
        local_news_time: 3-D array with time on axis 0; axes 0 and 1 are swapped
            per item and a trailing axis is appended.
        date_index: Expected to be a datetime Series named ``date``. Its index
            labels are used as row positions into every array, so a default
            ``RangeIndex`` is assumed - TODO confirm against callers.
        index_list: Anchor row positions, one per dataset item.
        keys: Keys of ``num_dict`` to use, in stacking order.
        window: Number of look-back rows per item.
        horizon: Number of look-ahead rows per item.
    """

    def __init__(self, num_dict, target, price, global_news, local_news, local_news_time, date_index, index_list, keys, window=5, horizon=5):
        self.window = window
        self.horizon = horizon
        self.index_list = index_list
        self.date_index = date_index

        # One 2-D value matrix per key, stacked on a trailing feature axis.
        self.numerical_array = np.stack(
            [num_dict[k].reset_index().iloc[:, 1:].values for k in keys], axis=-1
        )
        self.target_array = target.iloc[:, 1:].values
        self.price_array = price.iloc[:, 1:].values

        self.global_news = global_news
        self.local_text = local_news
        self.local_news_time = local_news_time

        # Calendar features derived from the dates: day of week, ISO week, month.
        self.time_feat = pd.DataFrame(date_index).copy()
        self.time_feat['day'] = self.time_feat.date.dt.dayofweek
        self.time_feat['week'] = self.time_feat.date.dt.isocalendar().week
        self.time_feat['month'] = self.time_feat.date.dt.month
        # Drop the leading date column and keep the three integer features.
        self.time_array = self.time_feat.iloc[:, 1:].astype(int).values

    def __len__(self):
        """Return the number of anchor positions."""
        return len(self.index_list)

    def __getitem__(self, index):
        """Build the sample anchored at ``index_list[index]``.

        Returns:
            A 7-tuple ``(numerical_features, targets, prices, time_features,
            global_news, local_news, local_time)``. The first three have their
            time axis moved to position 1; ``local_time`` carries the news time
            followed by the calendar features on its last axis. If fewer than
            ``horizon`` rows remain after the anchor, ``targets`` and ``prices``
            are correspondingly shorter.

        Raises:
            IndexError: If the anchor lies outside ``date_index`` or leaves fewer
                than ``window`` rows before it.
        """
        idx = self.index_list[index]

        # Fail loudly instead of letting a negative slice start wrap around or an
        # out-of-range anchor produce a silently truncated window.
        if idx < self.window or idx >= len(self.date_index):
            raise IndexError(
                f"anchor {idx} needs {self.window} preceding rows within "
                f"{len(self.date_index)} available rows"
            )

        # Index labels of the sliced dates double as row positions in the arrays.
        window_rows = self.date_index[idx - self.window: idx].index.values
        horizon_rows = self.date_index[idx: idx + self.horizon].index.values

        numerical_features = np.transpose(self.numerical_array[window_rows], (1, 0, 2))
        targets = np.transpose(self.target_array[horizon_rows])[:, :, np.newaxis]
        prices = np.transpose(self.price_array[horizon_rows])[:, :, np.newaxis]
        time_features = self.time_array[window_rows]

        window_start = window_rows[0]
        window_end = window_rows[-1] + 1
        global_news = self.global_news[window_start:window_end]
        local_news = self.local_text[window_start:window_end].transpose(1, 0, 2, 3)
        local_time = self.local_news_time[window_start:window_end].transpose(1, 0, 2)[:, :, :, np.newaxis]

        # Attach the per-row calendar features to every entry along axes 0 and 2;
        # broadcasting avoids two intermediate copies and concatenate materialises
        # the final array.
        calendar = np.broadcast_to(
            time_features[np.newaxis, :, np.newaxis, :],
            local_time.shape[:3] + time_features.shape[-1:],
        )
        local_time = np.concatenate((local_time, calendar), axis=-1)

        return numerical_features, targets, prices, time_features, global_news, local_news, local_time

    def subsequent_mask(self, seq_len):
        """Return a ``(1, seq_len, seq_len)`` boolean mask that is True on and below the diagonal."""
        attn_shape = (1, seq_len, seq_len)
        # Ones strictly above the diagonal mark the positions to hide.
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
        return torch.from_numpy(subsequent_mask) == 0
