import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from Watermark.WatermarkingFnFourier import WatermarkingFnFourier
from Watermark.WatermarkingFnSquare import WatermarkingFnSquare
from Watermark.WatermarkerBase import Watermarker

# System instruction for the paraphrasing chat turn: three sentences joined
# by single spaces into one instruction string.
prompt = " ".join(
    (
        "Paraphrase the user provided text while preserving semantic similarity.",
        "Do not include any other sentences in the response, such as explanations of the paraphrasing.",
        "Do not summarize.",
    )
)
# Text placed at the start of the assistant turn (it is appended after the
# chat template's generation prompt), presumably to steer the model straight
# into the paraphrase instead of a preamble of its own.
pre_paraphrased = (
    "Here is a paraphrased version of the text while preserving the semantic similarity:"
    + "\n\n"
)

def watermark_and_evaluate(T_o, *, wm=None, tok=None, wm_id=None, key=None):
    """Watermark-paraphrase ``T_o`` and print verification results.

    Builds a chat-formatted paraphrasing prompt and generates the watermarked
    text ``T_w``. It then prints the verification score of key ``key`` on
    both ``T_o`` and ``T_w``, and the key recovered from ``T_w``, which is the
    1-based argmax of its verification scores.

    Args:
        T_o: Original text to paraphrase and watermark.
        wm: Watermarker exposing ``generate(prompt)`` and ``verify(text)``.
            Defaults to the module-level ``watermarker`` created when this
            file runs as a script.
        tok: Tokenizer used to apply the chat template. Defaults to the
            module-level ``tokenizer``.
        wm_id: ID shown in the header line. Defaults to ``args.id``.
        key: 1-based perturbation key whose score is reported. Defaults to
            the module-level ``k_p``.

    Raises:
        ValueError: If ``key`` is not within ``1..len(scores)``.
    """
    # Resolve overrides, falling back to the globals set up under __main__
    # so the original single-argument call keeps working unchanged.
    if wm is None:
        wm = watermarker
    if tok is None:
        tok = tokenizer
    if wm_id is None:
        wm_id = args.id
    if key is None:
        key = k_p
    # Keys are 1-based: scores are read at key - 1, so key <= 0 would
    # silently index from the end. Check before the expensive generation.
    if key < 1:
        raise ValueError(f"k_p must be >= 1, got {key}")

    print(f"\nUsing ID: \033[91m{wm_id}\033[0m and k_p: \033[91m{key}\033[0m to watermark original text T_o:\n{T_o}\n")

    # Generate watermarked text
    paraphrasing_prompt = tok.apply_chat_template(
        [
            {"role": "system", "content": prompt},
            {"role": "user", "content": T_o},
        ],
        tokenize=False,
        add_generation_prompt=True,
    ) + f"{pre_paraphrased}\n\n"
    # NOTE(review): pre_paraphrased already ends with "\n\n", so the model is
    # primed with four newlines; kept as-is to preserve existing generations.
    watermarked = wm.generate(paraphrasing_prompt)

    print(f"\nWatermarked text T_w:\n{watermarked}\n")

    # Verify on original text.
    # Assumes verify(...)[0] is a 1-D sequence of per-key scores — TODO confirm.
    res = wm.verify(T_o)[0]
    if key > len(res):
        raise ValueError(f"k_p={key} exceeds the {len(res)} keys scored by verify()")
    q = res[key - 1]
    print(f"Verification score of T_o: \033[93m{q:.4f}\033[0m")

    # Verify on watermarked text
    res = wm.verify(watermarked)[0]
    q = res[key - 1]
    print(f"Verification score of T_w: \033[92m{q:.4f}\033[0m")

    # Extract from watermarked text: the best-scoring key, converted to 1-based.
    print(f"\nExtracted k_p from T_w   : \033[96m{np.argmax(res)+1}\033[0m")

if __name__ == "__main__":
    # Script-only dependency: BitsAndBytesConfig is the supported way to request
    # 8-bit loading; the bare ``load_in_8bit=`` keyword is deprecated.
    from transformers import BitsAndBytesConfig

    # CLI name -> watermarking function class; "fourier" is the default,
    # matching the previously hard-coded choice.
    watermarking_fns = {
        "fourier": WatermarkingFnFourier,
        "square": WatermarkingFnSquare,
    }

    parser = argparse.ArgumentParser(description='generate text watermarked with a key')
    parser.add_argument('--id', default=0, type=int,
            help='id: unique ID')
    parser.add_argument('--kappa', default=6, type=float,
            help='kappa: watermarking strength')
    parser.add_argument('--k_p', default=1, type=int,
            help="k_p: Perturbation key")
    # NOTE(review): --outdir is parsed but never read in this file.
    parser.add_argument('--outdir', default="outdir", type=str,
            help='out directory')
    parser.add_argument('--model', default='meta-llama/Llama-2-13b-chat-hf', type=str,
            help="model")
    parser.add_argument('--T_o', default='In a magical forest, a gentle deer named Lily discovered a hidden forest filled with colorful flowers that glowed in the moonlight. Every night, her animal friends would join her there, dancing joyfully under the twinkling stars.',
            type=str,
            help="original_text")
    parser.add_argument('--fn', default='fourier', choices=sorted(watermarking_fns),
            help="watermarking function")

    args = parser.parse_args()
    # Keys are 1-based (scores are indexed with k_p - 1); reject bad values
    # before spending time loading the model.
    if args.k_p < 1:
        parser.error(f"--k_p must be >= 1, got {args.k_p}")

    wm_id = args.id  # named wm_id rather than ``id`` to avoid shadowing the builtin
    kappa = args.kappa
    k_p = args.k_p
    model_name_or_path = args.model
    T_o = args.T_o

    # Initialize tokenizer and model (8-bit quantized, placed on device 0)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float16,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map=0,
    )

    # ``args``, ``tokenizer``, ``watermarker`` and ``k_p`` stay module globals
    # because watermark_and_evaluate falls back to them.
    watermarker = Watermarker(
        model,
        tokenizer,
        wm_id,
        kappa,
        k_p,
        n_gram=2,
        watermarkingFnClass=watermarking_fns[args.fn],
    )

    watermark_and_evaluate(T_o)