"""Conditional text generation with a GPT-2 model from the ``transformers`` library.

Only GPT-2 is loaded in this file; the library's other auto-regressive models
(GPT/CTRL/Transformer-XL/XLNet) are not used here.
"""

import argparse
import logging

import numpy as np
import torch

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
)

import sys

import json

# Add the parent directory (resolved relative to the current working directory) to the module
# search path. NOTE(review): nothing in the visible code imports from it -- confirm it is needed.
sys.path.append('..')

# Configure root logging at import time: timestamped records at INFO level and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Fallback generation length used by adjust_length_to_model() when the caller passes a negative
# length and the model reports no positive maximum sequence length.
MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop


def set_seed(args):
    """Seed the NumPy and PyTorch random number generators from ``args``.

    Args:
        args: Object exposing ``seed`` (the seed value) and ``n_gpu`` (GPU count).
            The CUDA generators are seeded only when ``n_gpu`` is positive.
    """
    seed_value = args.seed
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)

    # Nothing more to do on a CPU-only run.
    if args.n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed_value)


def adjust_length_to_model(length, max_sequence_length):
    """Clamp a requested generation length to what the model supports.

    Args:
        length: Requested length; a negative value means "no explicit request".
        max_sequence_length: The model's maximum sequence length; a non-positive
            value means the model reports no limit.

    Returns:
        ``max_sequence_length`` when the model has a limit and ``length`` is negative
        or exceeds it; ``MAX_LENGTH`` when there is no limit and ``length`` is
        negative; otherwise ``length`` unchanged.
    """
    model_has_limit = max_sequence_length > 0

    if model_has_limit:
        # Unspecified or oversized requests both collapse to the model's own limit.
        if length < 0 or length > max_sequence_length:
            return max_sequence_length
        return length

    # No model limit: fall back to the hard-coded ceiling for unspecified requests.
    if length < 0:
        return MAX_LENGTH
    return length


class GPT2Generator:
    """Sample text continuations for a prompt from a GPT-2 checkpoint.

    The tokenizer and language-model head are both loaded from ``model_path``. The
    model is placed on CUDA when CUDA is available and ``cuda`` is not ``None``,
    otherwise on the CPU.

    Args:
        model_path: Location handed to ``from_pretrained`` for the tokenizer and the model.
        stop_token: Optional string; each continuation is cut at its first occurrence.
        max_len: Requested number of tokens to generate on top of the prompt. It is
            clamped by ``adjust_length_to_model`` against the model's
            ``max_position_embeddings``.
        cuda: Any value other than ``None`` requests CUDA (``0`` and ``False`` also
            count as a request, because only ``None`` is tested for).
        seed: Seed stored for :meth:`set_seed`. It is NOT applied automatically;
            call :meth:`set_seed` explicitly for reproducible sampling.
    """

    def __init__(self, model_path, stop_token=None, max_len=512, cuda=1, seed=2021):
        use_cuda = cuda is not None and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        # Count GPUs only when actually running on CUDA, so that set_seed() seeds
        # them. (The condition used to be inverted: 0 on CUDA, device_count() on CPU.)
        self.n_gpu = torch.cuda.device_count() if use_cuda else 0
        self.seed = seed
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_path)
        self.model.to(self.device)
        self.length = adjust_length_to_model(max_len, max_sequence_length=self.model.config.max_position_embeddings)
        self.stop_token = stop_token

    def set_seed(self):
        """Seed NumPy and PyTorch (and all CUDA devices when ``n_gpu > 0``) with ``self.seed``."""
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if self.n_gpu > 0:
            torch.cuda.manual_seed_all(self.seed)

    def generate(self, prompt):
        """Sample continuation(s) of ``prompt`` with nucleus sampling (top_p=0.9, temperature=1.0).

        Args:
            prompt: Text to continue. An empty prompt is passed to the model as
                ``input_ids=None``.

        Returns:
            A list with one string per sequence returned by the model. Each string is
            the decoded output with the decoded prompt removed from the front and,
            when ``stop_token`` is set and occurs in the continuation, everything
            from the stop token onward removed.
        """
        encoded_prompt = self.tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        encoded_prompt = encoded_prompt.to(self.device)
        prompt_token_count = encoded_prompt.size()[-1]
        input_ids = encoded_prompt if prompt_token_count > 0 else None

        # max_length counts prompt tokens too; keep the total within the model's
        # position limit so a long prompt cannot push generation past it.
        max_length = self.length + prompt_token_count
        position_limit = self.model.config.max_position_embeddings
        if 0 < position_limit < max_length:
            max_length = position_limit

        output_sequences = self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            temperature=1.0,
            top_p=0.9,
            do_sample=True,
        )

        # The decoded prompt length is the same for every sequence, so compute it once.
        prompt_text_length = len(self.tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True))

        generated_sequences = []
        for generated_sequence in output_sequences:
            text = self.tokenizer.decode(generated_sequence.tolist(), clean_up_tokenization_spaces=True)

            # Drop the prompt first so a stop token inside the prompt cannot truncate the result.
            continuation = text[prompt_text_length:]

            # Cut at the stop token only when it is present: str.find() returns -1
            # on a miss, and slicing with -1 would silently drop the last character.
            if self.stop_token:
                stop_index = continuation.find(self.stop_token)
                if stop_index != -1:
                    continuation = continuation[:stop_index]

            generated_sequences.append(continuation)

        return generated_sequences

if __name__ == "__main__":
    # Script entry point: sample 1000 continuations of one fixed story prompt and
    # write them all to a single JSON file.
    story_generator = GPT2Generator('../gpt2_en_ckpt_origin')

    prompt = '''The next hour was a blur for the new recruits, they were ferried unceremoniously from the office with a quick salute from Sir Hugh towards an unrecognisable airfield on the outskirts of London. There, the group was met by the charismatic Artur Boruc, a Polish national pilot who was going to get the group infiltrated into the area and prepared for the attack. The small cargo plane flew across the sea with ease, the fine clear sky allowing easy flying until they were above the Polish shore. Twenty-Five minutes out. Artur shouted from the cockpit before a deafening thunder clap began to ring out repeated and rhythmically. "Ah!" Artur screamed as the plane banked hard, throwing the passengers across the cramped fuselage.  Metal began to tear through the thin wings outside the small windows and pinged heavily off the underside of the plane, a quiet arrival in Poland wasn't going to be an option anymore.'''

    # Map each run index to the list of continuations returned for that run.
    samples_by_run = {run_idx: story_generator.generate(prompt) for run_idx in range(1000)}

    with open("gpt_result_1000_noprompt.json", 'w') as out_file:
        json.dump(samples_by_run, out_file, indent=4, separators=(',', ':'))


