import os, sys

from project_root import ROOT_DIR
from baselines.MetricClass import MetricClass

from comet import download_model, load_from_checkpoint

from huggingface_hub import snapshot_download, login

class XComet(MetricClass):
    """XCOMET-XXL scorer backed by a checkpoint that is already on local disk.

    Each call pairs ``gt[i]`` (sent to COMET as the ``"src"`` field) with
    ``hyp[i]`` (sent as the ``"mt"`` field) and returns the model's scores.
    No ``"ref"`` field is supplied.

    The checkpoint is expected to live in the Hugging Face hub cache layout
    under ``cache_dir``, e.g. as produced by
    ``huggingface_hub.snapshot_download(repo_id="Unbabel/XCOMET-XXL",
    cache_dir=cache_dir)``. No login or download is attempted here.
    """

    name = 'XComet'

    # Pinned snapshot of Unbabel/XCOMET-XXL, relative to ``cache_dir`` in the
    # hub cache layout: <cache_dir>/<_MODEL_DIR>/snapshots/<_SNAPSHOT>/...
    _MODEL_DIR = "models--Unbabel--XCOMET-XXL"
    _SNAPSHOT = "bad20b47daa64c41a8b29f3d3016be75baf0d7b4"

    def __init__(self, access_token, batch_size=8, cache_dir="/work/<REDACTED>/.cache",
                 *args, gpus=1, checkpoint_path=None, **kwargs):
        """Load the XCOMET checkpoint.

        Args:
            access_token: Hugging Face access token. Currently unused: the
                checkpoint must already be present on disk.
            batch_size: Batch size forwarded to ``model.predict``.
            cache_dir: Root of the Hugging Face hub cache holding the pinned
                snapshot. The default resolves to the same checkpoint path
                that was previously hard-coded.
            *args, **kwargs: Accepted for interface compatibility; ignored.
            gpus: Number of GPUs forwarded to ``model.predict``. Defaults to
                1, the previously hard-coded value.
            checkpoint_path: Explicit path to a ``model.ckpt`` file. When
                given, it overrides the path derived from ``cache_dir``.
        """
        if checkpoint_path is None:
            checkpoint_path = os.path.join(
                cache_dir, self._MODEL_DIR, "snapshots", self._SNAPSHOT,
                "checkpoints", "model.ckpt",
            )
        self.checkpoint_path = checkpoint_path
        self.model = load_from_checkpoint(checkpoint_path)
        self.batch_size = batch_size
        self.gpus = gpus

    def __call__(self, gt, hyp):
        """Score each ``(gt[i], hyp[i])`` pair with XCOMET.

        Args:
            gt: Iterable of strings passed to COMET as the ``"src"`` field.
            hyp: Iterable of strings passed to COMET as the ``"mt"`` field.

        Returns:
            The ``scores`` attribute of ``model.predict``'s output, or ``[]``
            when there are no pairs. Per COMET's documented ``Prediction``
            output, this should hold one score per pair.

        Raises:
            ValueError: If ``gt`` and ``hyp`` have different lengths.
        """
        gt, hyp = list(gt), list(hyp)
        # zip() would silently drop the tail of the longer input, leaving the
        # returned scores misaligned with what the caller passed in.
        if len(gt) != len(hyp):
            raise ValueError(
                f"gt and hyp must have the same length (got {len(gt)} and {len(hyp)})"
            )
        if not gt:
            # Nothing to score; avoid invoking the predictor on zero samples.
            return []

        data = [{"src": g, "mt": h} for g, h in zip(gt, hyp)]
        model_output = self.model.predict(data, batch_size=self.batch_size, gpus=self.gpus)

        return model_output.scores



if __name__ == '__main__':
    # Smoke test: load the model, print its parameter count, then score two
    # (src, mt) pairs. The first CLI argument, if any, is passed as the access
    # token. It is optional because XComet does not use it to load the model.
    token = sys.argv[1] if len(sys.argv) > 1 else None
    b = XComet(token)

    print(sum(p.numel() for p in b.model.parameters()))
    print(b(["A test sentence", "Sentence B"],["So Cummings was told that these units must be preserved in their entirety.", "Satz B"]))
