import argparse
from types import SimpleNamespace

import os
import sys

import torch
import torch.nn as nn
from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM)

sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))

from utils.training_utils import PDICT, logger

# Module-level argument parser. NOTE(review): `parser` is not referenced
# anywhere else in this file — presumably used by (or left over for) another
# module that imports it; confirm before removing.
parser = argparse.ArgumentParser()
# Import-time side effect: limits PyTorch's intra-op CPU parallelism to a
# single thread for the whole process, not just for this module.
torch.set_num_threads(1)


def return_model(args: SimpleNamespace):
    """Load the seq2seq backbone and its tokenizer, wrapped in ``TaskModel``.

    Args:
        args: Run configuration. ``args.model_type`` must be a key of
            ``PDICT``; ``PDICT[args.model_type]['model_args']`` is passed to
            ``from_pretrained`` for both the tokenizer and the model. The whole
            ``args`` object is also stored on the returned ``TaskModel``.

    Returns:
        A ``(TaskModel, tokenizer)`` tuple.

    Raises:
        ValueError: if ``args.model_type`` is not a key of ``PDICT``.
    """
    # Explicit check rather than ``assert``: assertions are stripped when
    # Python runs with ``-O``, which would turn a bad config into a KeyError.
    if args.model_type not in PDICT:
        raise ValueError(
            f"Model type defining error: unknown model_type "
            f"{args.model_type!r}; expected one of {list(PDICT)}")

    model_args = PDICT[args.model_type]['model_args']
    tokenizer = AutoTokenizer.from_pretrained(model_args)
    # NOTE(review): no dtype is passed to from_pretrained, so
    # ``args.precision`` does not affect how the weights are loaded. If
    # half-precision weights are wanted, pass ``torch_dtype=torch.float16``
    # here — but first confirm the training loop does not already handle
    # mixed precision (e.g. autocast), since fp16 master weights can
    # destabilise training.
    core_model = AutoModelForSeq2SeqLM.from_pretrained(model_args)

    return TaskModel(core_model, args), tokenizer


class TaskModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a seq2seq model.

    Maps this project's batch keys (``src_ids``, ``src_attention_mask``, ...)
    onto the wrapped model's keyword arguments and returns the model output
    together with its loss.
    """

    def __init__(self, core_model, args):
        """Store the wrapped model and the run configuration.

        Args:
            core_model: Callable accepting ``input_ids``, ``attention_mask``,
                ``decoder_input_ids``, ``decoder_attention_mask``, ``labels``
                and ``return_dict`` keyword arguments (e.g. a Hugging Face
                ``AutoModelForSeq2SeqLM``).
            args: Run configuration, kept as ``self.args``.
        """
        super().__init__()
        # Registered as a submodule by nn.Module.__setattr__ when it is an
        # nn.Module, so its parameters are visible through this wrapper.
        self.model = core_model
        self.args = args

    def set_external(self, **kwargs):
        """Attach every keyword argument to this module as an attribute."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def forward_seq2seq(self, **kwargs):
        """Run the wrapped model on one batch.

        Required keys: ``src_ids`` and ``src_attention_mask``. Optional keys
        ``decoder_input_ids``, ``decoder_attention_mask`` and ``labels``
        default to ``None``, so batches without targets (e.g. inference) can
        also be passed through.

        Returns:
            ``(out, loss)``: ``out`` is the wrapped model's output
            (``return_dict=True``); ``loss`` is its ``'loss'`` entry, or
            ``None`` when the output carries no loss.
        """
        out = self.model(
            input_ids=kwargs['src_ids'],
            attention_mask=kwargs['src_attention_mask'],
            decoder_input_ids=kwargs.get('decoder_input_ids'),
            decoder_attention_mask=kwargs.get('decoder_attention_mask'),
            labels=kwargs.get('labels'),
            return_dict=True)
        # ``.get`` instead of ``['loss']``: a Hugging Face ``ModelOutput`` only
        # stores non-None fields as keys, so indexing raises KeyError when no
        # loss was computed. ``.get`` also works for plain-dict outputs.
        return out, out.get('loss')

    def forward(self, **kwargs):
        """Return ``{'out': model_output, 'loss': loss}`` for one batch.

        Accepts the same keyword arguments as ``forward_seq2seq``.
        """
        out, loss = self.forward_seq2seq(**kwargs)
        return {'out': out, 'loss': loss}

def _report_import_ok():
    """Print a smiley; only reached after all module top-level code has run."""
    print(':)')


if __name__ == '__main__':
    _report_import_ok()
