from transformers import MarianMTModel, MarianTokenizer
import torch

def translate(texts, model, tokenizer, language="fr"):
    """Translate a batch of texts with a seq2seq model and its tokenizer.

    Args:
        texts: Iterable of source strings. A single ``str`` is treated as a
            one-element batch instead of being iterated character by
            character.
        model: Model exposing ``generate`` and ``device`` (the file imports
            ``MarianMTModel`` for this purpose).
        tokenizer: Tokenizer paired with ``model``; it is called directly to
            encode the batch and its ``batch_decode`` is used to decode.
        language: Target language code. Unless it equals ``"en"``, every
            text is prefixed with ``>>{language}<< `` before tokenization.

    Returns:
        The list of strings produced by ``tokenizer.batch_decode`` on the
        generated token ids, or ``[]`` when ``texts`` is empty.
    """
    # A bare string would otherwise be split into single characters.
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = list(texts)

    # Nothing to translate: skip the tokenizer and the model entirely.
    if not texts:
        return []

    # The language check is loop-invariant, so build the prefix once.
    prefix = "" if language == "en" else f">>{language}<< "
    src_texts = [f"{prefix}{text}" for text in texts]

    # Calling the tokenizer directly replaces the deprecated
    # ``prepare_seq2seq_batch``; padding/truncation mirror its defaults and
    # ``return_tensors="pt"`` removes the manual ``torch.tensor`` conversion.
    encoded = tokenizer(
        src_texts, return_tensors="pt", padding=True, truncation=True)

    # Follow the model's device rather than assuming a GPU is present.
    encoded = encoded.to(model.device)

    # Generate translation using model
    translated = model.generate(**encoded)

    # Convert the generated token indices back into text
    return tokenizer.batch_decode(translated, skip_special_tokens=True)


def back_translate(texts, target_model, target_tokenizer, en_model, en_tokenizer, source_lang="en", target_lang="fr"):
    """Round-trip texts through a pivot language and back.

    Runs ``translate`` twice: first with ``target_model`` /
    ``target_tokenizer`` and ``language=target_lang``, then feeds that
    output to ``en_model`` / ``en_tokenizer`` with ``language=source_lang``.

    Args:
        texts: Source texts handed to the first ``translate`` call.
        target_model: Model used for the forward (source -> target) pass.
        target_tokenizer: Tokenizer paired with ``target_model``.
        en_model: Model used for the return (target -> source) pass.
        en_tokenizer: Tokenizer paired with ``en_model``.
        source_lang: Language code passed to the return pass.
        target_lang: Language code passed to the forward pass.

    Returns:
        Whatever the second ``translate`` call returns.
    """
    # Forward pass: source language -> pivot language.
    pivot_texts = translate(
        texts,
        target_model,
        target_tokenizer,
        language=target_lang,
    )

    # Return pass: pivot language -> source language.
    return translate(
        pivot_texts,
        en_model,
        en_tokenizer,
        language=source_lang,
    )

# Example usage (target_model, target_tokenizer, en_model and en_tokenizer
# must already be loaded):
# en_texts = ['This is so cool', 'I hated the food', 'They were very helpful']
# # en_texts = 'I love you.'
# aug_texts = back_translate(en_texts, target_model, target_tokenizer, en_model, en_tokenizer, source_lang="en", target_lang="fr")
# print(aug_texts)
