#!/usr/bin/env python3

import argparse
from collections import OrderedDict
import sys

import torch

try:
    from transformers import T5ForConditionalGeneration
except ModuleNotFoundError:
    # Kept non-fatal on purpose: converting a local state dict (`--model file`)
    # never touches `transformers`, so only the pretrained path needs it.
    print(
        'transformers are not officially in the fairseq requirements, please install it on its own: '
        'pip install transformers==2.7.0',
        file=sys.stderr,
    )



def parse_args():
    """Parse the command-line options of this conversion script.

    Options:
        --model / -m   (required) one of the named T5 sizes, or ``file``.
        --path / -p    location of the state dict to read; mandatory when
                       ``--model file`` is selected.
        --output / -o  optional output location.

    Returns:
        argparse.Namespace: with attributes ``model``, ``path`` and ``output``
        (the optional ones are ``None`` when not supplied).

    Raises:
        SystemExit: on invalid arguments, including ``--model file`` without
        ``--path`` (argparse prints the usage message and exits with code 2).
    """
    parser = argparse.ArgumentParser(
        description='Rename the keys of a T5 state dict and save it as a checkpoint.')
    parser.add_argument(
        '--model', '-m',
        type=str,
        required=True,
        choices=['t5-small', 't5-base', 't5-large', 't5-3b', 't5-11b', 'file'],
        help='name of a pretrained T5 model, or "file" to read a state dict from --path')
    parser.add_argument(
        '--path', '-p',
        type=str,
        required=False,
        help='state dict file to load (required when --model is "file")')
    parser.add_argument(
        '--output', '-o',
        type=str,
        required=False,
        help='optional output location')
    args = parser.parse_args()

    # Fail early with a usage message instead of letting the loader receive
    # `None` as a file name later on.
    if args.model == 'file' and not args.path:
        parser.error('--path is required when --model is "file"')

    return args


def _rename_t5_key(name):
    """Return *name* with its leading module prefix rewritten.

    ``encoder...`` becomes ``encoder.t5_stack...``, ``decoder...`` becomes
    ``decoder.t5_stack...`` and ``lm_head...`` becomes ``decoder.lm_head...``.
    Any other key is returned unchanged.
    """
    prefix_map = (
        ('encoder', 'encoder.t5_stack'),
        ('decoder', 'decoder.t5_stack'),
        ('lm_head', 'decoder.lm_head'),
    )
    for prefix, replacement in prefix_map:
        if name.startswith(prefix):
            # Rewrite the leading prefix only; a later occurrence of the same
            # word inside the key must be left untouched.
            return replacement + name[len(prefix):]
    return name


def main(args):
    """Load a T5 state dict, rename its keys and save it as a checkpoint.

    Args:
        args: namespace with ``model`` (a pretrained model name or ``'file'``),
            ``path`` (state dict file, used when ``model == 'file'``) and,
            optionally, ``output`` (destination file). When ``output`` is
            missing or empty the checkpoint is written to ``'<model>.pt'`` in
            the current directory.

    Raises:
        ValueError: if ``model == 'file'`` but no ``path`` was given.

    The saved object is a dict with two entries: ``'model'`` (an OrderedDict
    of the renamed parameters, in their original order) and ``'args'`` (an
    empty ``argparse.Namespace`` instance).
    """
    model_type = args.model

    if model_type == 'file':
        if not args.path:
            raise ValueError('--path is required when --model is "file"')
        # NOTE: torch.load unpickles its input; only use it on trusted files.
        # map_location='cpu' lets state dicts saved from a GPU be converted on
        # a machine without one.
        loaded_model = torch.load(args.path, map_location='cpu')
    else:
        loaded_model = T5ForConditionalGeneration.from_pretrained(model_type).state_dict()

    output_model_state_dict = OrderedDict(
        (_rename_t5_key(name), tensor) for name, tensor in loaded_model.items()
    )

    output_model = {
        'model': output_model_state_dict,
        # Store a Namespace *instance*: the bare class would be shared global
        # state for anything that later sets attributes on the loaded 'args'.
        'args': argparse.Namespace(),
    }

    # Honour the requested destination; fall back to the historical
    # '<model>.pt' name when none was supplied.
    output_path = getattr(args, 'output', None) or f'{model_type}.pt'
    with open(output_path, 'wb') as ffile:
        torch.save(output_model, ffile)


if __name__ == '__main__':
    # Script entry point: parse the command line, then run the conversion.
    main(parse_args())
