import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Transformer


class ValueModel(nn.Module):
    """Transformer encoder-decoder followed by a small MLP value head.

    The model width is ``game.memory_size + 5``. The same tensor is fed to the
    transformer as both source and target, and every output position is
    mapped to a scalar in ``[-1, 1]`` by a ``tanh``-terminated head.

    Args:
        game: object exposing an integer ``memory_size`` attribute; it
            determines the model width ``d_model = game.memory_size + 5``.
        num_encoder_layers: number of ``nn.Transformer`` encoder layers.
        num_decoder_layers: number of ``nn.Transformer`` decoder layers.
        device: device the module's parameters are moved to (also stored on
            ``self.device``).
        nhead: number of attention heads (keyword-only). ``nn.Transformer``
            requires ``d_model`` to be divisible by it. Defaults to 5, the
            value that was previously hard-coded.
        head_dropout: dropout probability for both dropout layers of the
            value head (keyword-only). Defaults to 0.05, the value that was
            previously hard-coded.
    """

    def __init__(self, game, num_encoder_layers, num_decoder_layers, device,
                 *, nhead=5, head_dropout=.05):
        super().__init__()
        self.device = device

        d_model = game.memory_size + 5

        # nn.Transformer raises if d_model is not divisible by nhead.
        self.transformer = Transformer(d_model, nhead=nhead, num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers)

        # Attribute name kept as ``valueHead`` so existing state_dict keys
        # (checkpoints) still load.
        self.valueHead = nn.Sequential(
            nn.Dropout(head_dropout),
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(head_dropout),
            nn.Linear(d_model, 1),
            nn.Tanh()
        )

        self.to(device)

    def forward(self, x):
        """Return a value estimate in ``[-1, 1]`` for every position of ``x``.

        Args:
            x: tensor whose last dimension is ``game.memory_size + 5``. The
                transformer is built with ``nn.Transformer``'s default
                ``batch_first=False``, so the expected layout is presumably
                ``(seq_len, batch, d_model)``; this is not checked here.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a trailing
            dimension of size 1. The head is applied to each transformer
            output position independently.
        """
        # Self-conditioned: the same sequence serves as encoder input and
        # decoder target, with no attention masks.
        x = self.transformer(x, x)
        value = self.valueHead(x)
        return value