from Watermark.WatermarkingFn import WatermarkingFn
import numpy as np

def compute_scaling_factor(N):
    """Return 1/std of each sampled cosine and sine basis vector of length N.

    For every frequency k = 1 .. N//2 - 1, the vectors cos(2*pi*k*n/N) and
    sin(2*pi*k*n/N) are sampled at n = 0 .. N-1. The frequency k = N//2 (the
    Nyquist frequency when N is even) is excluded. The result holds the
    cosine factors first, then the sine factors, so it has 2 * (N//2 - 1)
    entries. That is the same length and ordering (real parts, then
    imaginary parts) as the coefficients produced by
    ``WatermarkingFnFourier.q``.

    Mathematically each std equals 1/sqrt(2) for 0 < k < N/2, so every entry
    is sqrt(2) up to floating-point rounding.

    The computation works one frequency at a time, so memory is O(N) rather
    than materialising an (N//2, N) array. At N = 32000 such an array would
    be several GB.

    Args:
        N: number of samples per basis vector (non-negative int).

    Returns:
        1-D float array of length max(2 * (N//2 - 1), 0).
    """
    half_N = N // 2
    n = np.arange(N)
    n_freqs = max(half_N - 1, 0)  # N < 4 yields no frequencies at all
    cos_std = np.empty(n_freqs)
    sin_std = np.empty(n_freqs)
    for i, k in enumerate(range(1, half_N)):
        # Same evaluation order as the vectorised form: integer n*k*2, then *pi, then /N.
        angle = n * k * 2 * np.pi / N
        cos_std[i] = np.std(np.cos(angle))
        sin_std[i] = np.std(np.sin(angle))
    return 1 / np.concatenate((cos_std, sin_std))

# Vocabulary size of the llama tokenizer; matches the default N of
# WatermarkingFnFourier.
N = 32_000
# A module-level scaling factor for this N is not computed at import time;
# the call below is kept disabled.
# scaling_factor = compute_scaling_factor(N)

class WatermarkingFnFourier(WatermarkingFn):
    """Watermarking function built on a single discrete Fourier basis vector.

    ``self.phi`` is one sampled cosine or sine over ``n = 0 .. N-1``, selected
    by ``k_p`` and scaled by ``kappa``. ``q`` maps rows of ``bins`` to the
    real and imaginary parts of their normalised real FFT.
    """

    def __init__(self, id = 0, k_p = 1, N = 32000, kappa = 1):
        """Build the basis vector ``self.phi`` for index ``k_p``.

        Args:
            id: forwarded to ``WatermarkingFn.__init__``.
            k_p: basis index, required to satisfy ``0 < k_p < N``.
                - ``k_p <= N // 2`` selects ``cos(2*pi*k_p*n/N)``.
                - Larger values select ``sin(2*pi*(k_p - N//2)*n/N)``.
            N: length of the basis vector. The default 32000 matches the
                module-level ``N``.
            kappa: amplitude that ``phi`` is multiplied by.

        Raises:
            AssertionError: if ``k_p`` is out of range. This check only runs
                when assertions are enabled, i.e. not under ``python -O``.
        """
        super().__init__(id=id, k_p=k_p, N=N, kappa=kappa)

        # k_p, N and kappa are read back from the base class after init.
        freq = self.k_p
        # NOTE(review): `assert` is stripped under -O; kept so callers still
        # see AssertionError, the exception type they may already handle.
        assert (freq > 0) and (freq < self.N), f"k_p must be 0<k_p<{self.N}, value provided is {freq}"

        half_N = self.N // 2
        n = np.arange(self.N)
        if freq <= half_N:
            # Cosine basis, frequencies 1 .. N//2 (includes Nyquist when N is even).
            self.phi = np.cos(n/self.N*2*np.pi*freq)
        else:
            # Indices above N//2 select sine basis frequencies 1 .. N-1-N//2.
            self.phi = np.sin(n/self.N*2*np.pi*(freq - half_N))
        self.phi *= self.kappa

        # One weight per coefficient returned by q, i.e. 2 * (N//2 - 1) entries.
        # Unit weights are used here. compute_scaling_factor(self.N) would give
        # the 1/std normalisation of each cos/sin basis vector instead.
        self.scaling_factor = np.ones((self.N//2 - 1)*2)

    def q(self, bins):
        """Return the scaled real/imaginary rFFT coefficients of ``bins``.

        Args:
            bins: array-like of shape ``(N,)`` or ``(batch, N)``. Each row is
                divided by its own sum before the transform. The last axis is
                expected to have length ``self.N``, because the result is
                multiplied by ``self.scaling_factor``.

        Returns:
            Array of shape ``(batch, 2 * (N//2 - 1))``, with ``batch = 1`` for
            1-D input. It holds the real parts of rFFT bins ``1 .. N//2 - 1``
            followed by their imaginary parts, multiplied by
            ``self.scaling_factor``. The DC bin and the last rFFT bin are
            dropped.
        """
        bins = np.asarray(bins)
        if bins.ndim == 1:
            bins = bins[None, :]
        # keepdims=True divides each row by its own sum. A plain (batch,)
        # vector would broadcast across columns instead, which fails (or is
        # silently wrong when batch == N) for batch > 1.
        probs = bins / bins.sum(axis=1, keepdims=True)
        coeffs = np.fft.rfft(probs)[:, 1:-1]
        coeffs = np.concatenate((coeffs.real, coeffs.imag), axis=1)
        coeffs *= self.scaling_factor
        return coeffs