from sys import version
import torch
from tqdm import tqdm, trange
import torch.nn as nn
import os
import torch.nn.functional as F
import argparse
from torch.utils.tensorboard import SummaryWriter

class LayerNormApproximation(nn.Module):
    """Learnable element-wise affine map ``weight * x + bias``.

    No mean/variance normalisation is performed; the module only applies a
    trainable scale and shift. Both parameters have shape ``hidden_size``;
    the scale starts at ones and the shift at zeros, so a freshly built
    module is the identity.
    """

    def __init__(self, hidden_size):
        """Create the scale (``weight``) and shift (``bias``) parameters.

        Args:
            hidden_size: Size passed to ``torch.ones`` / ``torch.zeros`` for
                both parameters.
        """
        super().__init__()
        scale_init = torch.ones(hidden_size)
        shift_init = torch.zeros(hidden_size)
        self.weight = nn.Parameter(scale_init)
        self.bias = nn.Parameter(shift_init)

    def forward(self, x):
        """Return ``x`` scaled by ``weight`` and shifted by ``bias``."""
        scaled = self.weight * x
        return scaled + self.bias






def train(args, model):
    """Fit ``model`` to reproduce softmax on random Gaussian inputs.

    Every iteration draws a fresh ``(128, 128, 128)`` standard-normal tensor
    scaled by ``args.scale``, takes softmax over the last dimension as the
    target, and minimises the mean-squared error between that target and
    ``model(x, dim=-1)`` using Adam. The loss is logged to TensorBoard under
    the tag ``"loss"`` and shown on the progress bar.

    Args:
        args: Namespace providing ``lr``, ``num_iter`` and ``scale``. If it
            also carries ``log_dir``, TensorBoard events are written there;
            otherwise ``SummaryWriter`` picks its default location.
        model: Module callable as ``model(x, dim=-1)``. Its output is
            flattened and compared against the flattened target, so it is
            presumably expected to have the same number of elements as
            ``x`` -- confirm against the model definition.

    Returns:
        The same ``model`` object, updated in place and left on the device
        used for training (CUDA when available, else CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # Loop-invariant: build the loss module once instead of per iteration.
    criterion = nn.MSELoss()

    # Honour the --log_dir option; ``None`` keeps SummaryWriter's default.
    writer = SummaryWriter(log_dir=getattr(args, "log_dir", None))
    try:
        iter_bar = tqdm(range(args.num_iter))
        for t in iter_bar:
            input_tensor = torch.randn(128, 128, 128, device=device) * args.scale
            label = torch.softmax(input_tensor, dim=-1)

            pred = model(input_tensor, dim=-1)

            loss = criterion(pred.view(-1), label.view(-1))
            # Read the scalar once; reused for logging and the progress bar.
            loss_value = loss.item()
            writer.add_scalar("loss", loss_value, t)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            iter_bar.set_description("loss: %.4f" % loss_value)
    finally:
        # Flush buffered events and release the event file even on error.
        writer.close()
    return model

def main():
    """Command-line entry point: load, fine-tune and save the reciprocal net.

    Parses the options below, builds a ``SoftmaxApproximation``, loads a
    state dict into its ``reciprocal`` submodule from ``--reciprocal_path``,
    runs ``train`` and writes the updated ``reciprocal`` state dict to
    ``--output_path``.

    Options:
        --num_iter: number of training iterations (default 10000).
        --lr: learning rate (default 0.001).
        --scale: multiplier applied to the random training inputs
            (default 1.0).
        --log_dir: log directory string (default "logs").
        --output_path: file the trained state dict is saved to
            (default "outputs").
        --reciprocal_path: checkpoint loaded into ``model.reciprocal``;
            defaults to the previously hard-coded location.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_iter", type=int, default=10000)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--scale", type=float, default=1.0)
    parser.add_argument("--log_dir", type=str, default="logs")
    parser.add_argument("--output_path", type=str, default="outputs")
    # Previously a hard-coded absolute path; kept as the default so existing
    # invocations behave the same, but now overridable on other machines.
    parser.add_argument(
        "--reciprocal_path",
        type=str,
        default="/home/LAB/chenty/workspace/personal/MSRA_project/encryption/unilm2-he/reciprocal_model.pt",
        help="state dict to load into model.reciprocal before training",
    )
    args = parser.parse_args()

    # NOTE(review): SoftmaxApproximation is neither defined nor imported in
    # the visible part of this file; it must be made available and expose a
    # ``reciprocal`` submodule for the calls below to work.
    model = SoftmaxApproximation()
    state_dict = torch.load(args.reciprocal_path, map_location="cpu")
    model.reciprocal.load_state_dict(state_dict)

    model = train(args, model)
    model = model.cpu()

    # torch.save does not create missing directories; do it up front so a
    # long training run is not lost to a bad output path.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(model.reciprocal.state_dict(), args.output_path)

    print("save to {}".format(args.output_path))

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()



