import numpy as np
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

import transformers

import sklearn
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_recall_fscore_support

import argparse

# Command-line interface. Both options are mandatory: the input value is
# concatenated with '.npy' before loading and the output value is handed to
# np.save, so a missing option would otherwise only surface later as an
# opaque error instead of a usage message.
parser = argparse.ArgumentParser(
    description="Project the rows of a saved array through a linear layer "
                "and save the expanded array."
)

parser.add_argument(
    '--input_dir',
    type=str,
    required=True,
    help="Path of the input array WITHOUT its '.npy' extension "
         "(the extension is appended before loading).",
)
parser.add_argument(
    '--output_dir',
    type=str,
    required=True,
    help="Destination path passed to numpy.save for the expanded array.",
)

args = parser.parse_args()

# Choose the compute device once at start-up: CUDA when PyTorch reports a
# usable GPU, otherwise the CPU. The choice is announced on stdout.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

# --input_dir is given without an extension; '.npy' is appended here before
# the array is read from disk.
ent_biobert_trn = np.load(file=args.input_dir + '.npy')

def init_weights(m):
    """Initialise a plain ``nn.Linear`` layer in place.

    Intended to be passed to ``Module.apply``. Weights receive Xavier-uniform
    initialisation and the bias, when the layer has one, is filled with 0.01.
    Any module whose type is not exactly ``nn.Linear`` is left untouched.

    Args:
        m: The module currently being visited.
    """
    # Exact-type match is kept on purpose so that subclasses of nn.Linear
    # continue to be skipped, as before.
    if type(m) is not nn.Linear:
        return
    torch.nn.init.xavier_uniform_(m.weight)
    # Layers built with bias=False expose ``bias`` as None; skip them rather
    # than crash with an AttributeError.
    if m.bias is not None:
        m.bias.data.fill_(0.01)
# Seed the global RNG before the layer is constructed so that the random
# weight initialisation below is reproducible from run to run.
torch.manual_seed(123)
# A single linear layer mapping 768-dimensional inputs to 1024 dimensions.
model = nn.Sequential(nn.Linear(768, 1024))
model.apply(init_weights)
# Move the model to the `device` selected at start-up (CUDA when available,
# otherwise CPU) instead of unconditionally calling .cuda(), which raises on
# machines without a GPU.
model.to(device)

def dimension_expander_ent(input_seqs, model):
    """Apply ``model`` to every item of a two-level nested input.

    Args:
        input_seqs: Iterable of groups; each group is itself an iterable of
            array-likes that ``torch.tensor`` can convert.
        model: ``torch.nn.Module`` (or compatible callable exposing
            ``parameters()``) applied to each innermost item separately.

    Returns:
        ``np.ndarray`` built from one array per group, where each per-group
        array holds that group's model outputs in input order.
    """
    # Run on whatever device the model already lives on instead of assuming
    # CUDA, so CPU-only machines work. Parameter-free modules fall back to CPU.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    output = []
    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for seqs in input_seqs:
            group = [
                model(torch.tensor(seq).to(target)).detach().cpu().numpy()
                for seq in seqs
            ]
            output.append(np.array(group))
    return np.array(output)
def dimension_expander(input_seqs, model):
    """Apply ``model`` to each item of ``input_seqs`` and stack the results.

    Args:
        input_seqs: Iterable of array-likes that ``torch.tensor`` can convert.
        model: ``torch.nn.Module`` (or compatible callable exposing
            ``parameters()``) applied to each item separately.

    Returns:
        ``np.ndarray`` holding one model output per input item, in input order.
    """
    # Run on whatever device the model already lives on instead of assuming
    # CUDA, so CPU-only machines work. Parameter-free modules fall back to CPU.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        output = [
            model(torch.tensor(seq).to(target)).detach().cpu().numpy()
            for seq in input_seqs
        ]
    return np.array(output)
# Pass every top-level item of the loaded array through the model, one item
# at a time, and collect the outputs into a single array.
ent_biobert_trn_exp = dimension_expander(input_seqs=ent_biobert_trn, model=model)

# Persist the expanded array at the path given by --output_dir.
np.save(args.output_dir, ent_biobert_trn_exp)

# Report the resulting array shape on stdout.
print(ent_biobert_trn_exp.shape)