import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

class SelfSupervisedModel(nn.Module):
    """Sequence classifier head with separate 'chosen' / 'reject' branches.

    ``forward`` takes one feature sequence, adds a fixed sinusoidal positional
    encoding, then runs the branch selected by ``x_type``:
    linear transform -> max-pool over the sequence -> dropout -> linear
    projection to ``num_labels`` logits.

    The CNN layers (``convs``) and ``classifier`` are used only by
    ``forward_bak`` / ``conv_and_pool``, not by ``forward``. They are still
    registered submodules, so they remain part of the state dict.

    Args:
        num_labels: number of output logits.
        num_feature: size of each input feature vector (hidden size).
        max_len: number of positions in the precomputed positional encoding.
            Longer input sequences fail when the encoding is added.
        num_filters: output channels of each convolution in ``convs``.
        filter_sizes: kernel heights of the convolutions in ``convs``.
            ``None`` (the default) means ``[2, 3]``.
        compute_dtype: if ``torch.bfloat16`` or ``torch.float16``, ``forward``
            casts its input to that dtype first.
    """

    def __init__(self, num_labels=3, num_feature=4096, max_len=10000, num_filters=128, filter_sizes=None, compute_dtype=None):
        super().__init__()
        # A list literal as the default would be shared by every instance;
        # build a fresh default per call instead.
        if filter_sizes is None:
            filter_sizes = [2, 3]
        self.num_labels = num_labels
        self.num_feature = num_feature
        self.num_filters = num_filters
        self.filter_sizes = filter_sizes
        self.compute_dtype = compute_dtype

        # Legacy CNN head (see conv_and_pool / forward_bak).
        self.convs = nn.ModuleList([nn.Conv2d(1, self.num_filters, (k, self.num_feature)) for k in self.filter_sizes])
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.num_filters * len(self.filter_sizes), num_labels)

        # Sinusoidal positional encoding, computed once; the frequency terms
        # are evaluated in log space.
        pe = torch.zeros(max_len, num_feature)
        position = torch.arange(0, max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, num_feature, 2) * -(math.log(10000.0) / num_feature))
        angles = position * div_term  # [max_len, ceil(num_feature / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # When num_feature is odd there is one fewer odd-indexed column than
        # even-indexed column, so drop the last angle for the cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, : num_feature // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # [1, max_len, num_feature]

        # chosen branch
        self.transform_chosen = nn.Linear(self.num_feature, self.num_feature)
        self.classifier_chosen = nn.Linear(self.num_feature, self.num_labels)
        # reject branch
        self.transform_reject = nn.Linear(self.num_feature, self.num_feature)
        self.classifier_reject = nn.Linear(self.num_feature, self.num_labels)

    @staticmethod
    def _check_x_type(x_type):
        """Raise ValueError unless ``x_type`` is 'chosen' or 'reject'."""
        if x_type not in ('chosen', 'reject'):
            raise ValueError(
                f"Unknown x_type:{x_type}. Should be one of ['chosen', 'reject']"
            )

    def transform_and_pool(self, x, x_type=None):
        """Apply the branch's linear transform, then max-pool over dim 1.

        Shapes: [1, seq_len, hidden] -> [1, seq_len, hidden] -> [1, 1, hidden].

        Raises:
            ValueError: if ``x_type`` is not 'chosen' or 'reject'.
        """
        self._check_x_type(x_type)
        transform = self.transform_chosen if x_type == 'chosen' else self.transform_reject
        x = transform(x)
        # keepdim=True keeps the output 3-D: [1, 1, hidden].
        x, _ = torch.max(x, dim=1, keepdim=True)
        return x

    def conv_and_pool(self, x, conv):
        """Apply one convolution + ReLU, then max-pool over the whole length.

        Expects ``x`` as [batch, 1, seq_len, hidden]. The conv kernel spans the
        full hidden size, so dim 3 collapses to 1 and is squeezed. The result
        is [batch, num_filters].
        """
        x = F.relu(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward_bak(self, x):
        """Legacy head: positional encoding, dropout, mean-pool, ``classifier``.

        NOTE(review): ``self.classifier`` expects
        ``num_filters * len(filter_sizes)`` inputs but receives a
        ``num_feature``-sized vector here. This path therefore only runs when
        those sizes coincide.
        """
        x = x.unsqueeze(0)

        # Add the positional encoding; the buffer does not require grad.
        x = x + self.pe[:, :x.size(1)]

        sequence_output = self.dropout(x)

        # Mean over the tokens of the (single) sequence.
        token_nums = sequence_output.shape[1]
        sequence_output = sequence_output.sum(1) / float(token_nums)

        logits = self.classifier(sequence_output)

        return logits

    def forward(self, x, x_type=None):
        """Return float32 logits for one feature sequence.

        Args:
            x: feature sequence. Presumably [seq_len, num_feature]; a leading
                dim of 1 is added here. ``seq_len`` must not exceed ``max_len``.
            x_type: 'chosen' or 'reject'. Selects which transform/classifier
                pair is used.

        Returns:
            Logits cast to float32, shape [1, num_labels] for a 2-D ``x``.

        Raises:
            ValueError: if ``x_type`` is not 'chosen' or 'reject'.
        """
        if self.compute_dtype == torch.bfloat16:
            x = x.bfloat16()
        elif self.compute_dtype == torch.float16:
            x = x.half()
        x = x.unsqueeze(0)  # [1, seq_len, hidden]

        # Keep the encoding buffer on the input's device (no-op when it already is).
        self.pe = self.pe.to(x.device)

        # NOTE(review): if x was cast to a half type above but the pe buffer is
        # still float32, type promotion makes the sum float32 again. Confirm
        # the module itself is converted to compute_dtype.
        x = x + self.pe[:, :x.size(1)]  # [1, seq_len, hidden]

        # CNN variant (currently unused), kept for reference:
        #   out = torch.cat([self.conv_and_pool(x.unsqueeze(1), conv) for conv in self.convs], 1)
        #   logits = self.classifier(self.dropout(out))  # [1, num_labels]

        # MLP version. transform_and_pool validates x_type, so only
        # 'chosen' / 'reject' reach the classifier selection below.
        pooled = self.transform_and_pool(x, x_type)  # [1, 1, hidden]
        sequence_output = self.dropout(pooled).squeeze(1)  # [1, hidden]
        classifier = self.classifier_chosen if x_type == 'chosen' else self.classifier_reject
        logits = classifier(sequence_output)  # [1, num_labels]
        return logits.to(torch.float32)