# Demo code; the full template will be available in the public version.

from collections import Counter

import torch
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration, T5TokenizerFast


# T5 checkpoint to use: "t5-small", "t5-base", "t5-large", "t5-3b" or "t5-11b".
T5_PATH = 't5-large'

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer from the same checkpoint as the model so they cannot drift
# apart when T5_PATH is changed (the original t5-* checkpoints happen to share
# one SentencePiece vocabulary, but other checkpoints need not).
t5_tokenizer = T5TokenizerFast.from_pretrained(T5_PATH)
t5_config = T5Config.from_pretrained(T5_PATH)
t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config).to(DEVICE)


# Input text: a series of "question? statement" pairs. The last question's
# answer slot is the T5 sentinel <extra_id_0>, which the model fills in below.
# NOTE(review): the earlier pairs keep a literal "[mask]" placeholder and a few
# wording slips ("color of floor", "the cat on"); they are left verbatim because
# editing them would change the model input.
What_color = (
    "What color is the floor of that area? The color of floor of that area is [mask] . "
    "What color is the chair the cat is on? The color of the chair the cat on is [mask] . "
    "What color is the man's shorts? The color of the man's shorts is [mask] . "
    "What color is the letter on the sign? The color of the letter on the sign is [mask] . "
    "What color is the statue near the building? <extra_id_0> ."
)

text = What_color

# Tokenize the prompt (special tokens included) as PyTorch tensors and put the
# ids on the same device as the model.
encoded = t5_tokenizer(text, add_special_tokens=True, return_tensors='pt')
input_ids = encoded['input_ids'].to(DEVICE)

# Beam search over 20 beams, returning the 10 best candidate sequences;
# max_length=30 caps the length of each generated sequence.
outputs = t5_mlm.generate(
    input_ids=input_ids,
    max_length=30,
    num_beams=20,
    num_return_sequences=10,
)

# Split the prompt around the sentinel so each predicted span can be spliced
# back in between the prefix and the suffix.
_sentinel = '<extra_id_0>'
_0_index = text.index(_sentinel)  # raises ValueError if the prompt lacks the sentinel
_result_prefix = text[:_0_index]
# Derive the skip width from the sentinel instead of a hard-coded 12.
_result_suffix = text[_0_index + len(_sentinel):]

def _filter(output, end_token='<extra_id_1>'):
	"""Splice one generated infill back into the prompt.

	Args:
		output: one row of the ``t5_mlm.generate`` output (decoder token ids).
		end_token: marker that terminates the predicted span.

	Returns:
		``_result_prefix + span + _result_suffix``, i.e. the prompt text with
		``<extra_id_0>`` replaced by the predicted span.
	"""
	# generate() rows start with the decoder start token (<pad>, id 0), normally
	# followed by the sentinel <extra_id_0> (id 32099); skip both.
	# NOTE(review): assumes the model always emits <extra_id_0> first.
	_txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)
	# Cut at the earliest stop marker. Besides end_token, a beam may finish with
	# </s>, and shorter beams are right-padded with <pad>; previously those
	# special tokens leaked into the result.
	for _stop in (end_token, t5_tokenizer.eos_token, t5_tokenizer.pad_token):
		if _stop:
			_txt = _txt.split(_stop, 1)[0]
	# Trim whitespace the decoder may leave around the span; for this prompt the
	# prefix already ends, and the suffix already starts, with a space.
	return _result_prefix + _txt.strip() + _result_suffix

def _filter_2(output, end_token='<extra_id_1>'):
	"""Extract the bare predicted span from one generated sequence.

	Args:
		output: one row of the ``t5_mlm.generate`` output (decoder token ids).
		end_token: marker that terminates the predicted span.

	Returns:
		The decoded span text with surrounding whitespace removed and any
		trailing sentinel / end-of-sequence / padding tokens cut off.
	"""
	# generate() rows start with the decoder start token (<pad>, id 0), normally
	# followed by the sentinel <extra_id_0> (id 32099); skip both.
	# NOTE(review): assumes the model always emits <extra_id_0> first.
	_txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)
	# Cut at the earliest stop marker. Besides end_token, a beam may finish with
	# </s>, and shorter beams are right-padded with <pad>; previously those tokens
	# leaked into the span and split identical answers in the majority vote.
	for _stop in (end_token, t5_tokenizer.eos_token, t5_tokenizer.pad_token):
		if _stop:
			_txt = _txt.split(_stop, 1)[0]
	# Strip so that "red" and "red " count as the same answer.
	return _txt.strip()

# For every returned beam: the full prompt with the span spliced in, and the
# bare predicted span.
results = [_filter(output) for output in outputs]
mask_infilling = [_filter_2(output) for output in outputs]

print(text)
print(mask_infilling)
# Majority vote over the candidate spans. Counter counts in one pass (the old
# max(..., key=list.count) was O(n^2)), and most_common breaks ties by first
# occurrence, so the chosen answer is unchanged.
print(Counter(mask_infilling).most_common(1)[0][0])
