import os

from project_root import ROOT_DIR
from utils.bart_score import BARTScorer
from baselines.MetricClass import MetricClass



class BARTScore(MetricClass):
    """BARTScore metric exposed through the ``MetricClass`` interface.

    Thin wrapper around ``utils.bart_score.BARTScorer``: instances are
    callable as ``metric(gt, hyp)`` and return whatever
    ``BARTScorer.score`` returns for those inputs.
    """

    name = 'BARTSCORE'

    def __init__(self, batch_size=8, lang='en', *args,
                 device='cuda:0', checkpoint='facebook/bart-large-cnn',
                 **kwargs):
        """Load the underlying BART scorer.

        Args:
            batch_size: Batch size forwarded to ``BARTScorer.score`` on
                every call.
            lang: Accepted for interface compatibility with other metrics;
                it is not used. The checkpoint alone determines the model.
                NOTE(review): the default ``facebook/bart-large-cnn``
                checkpoint is presumably English-only - confirm before
                scoring other languages.
            *args: Accepted and ignored (compatibility with callers that
                pass a shared set of metric options).
            device: Device string handed to ``BARTScorer``. Defaults to
                ``'cuda:0'``, the previously hard-coded value; pass e.g.
                ``'cpu'`` on machines without a GPU.
            checkpoint: Model checkpoint name handed to ``BARTScorer``.
                Defaults to the previously hard-coded
                ``'facebook/bart-large-cnn'``.
            **kwargs: Accepted and ignored (see ``*args``).
        """
        self.bart_scorer = BARTScorer(device=device, checkpoint=checkpoint)
        self.batch_size = batch_size

    def __call__(self, gt, hyp):
        """Score hypotheses against ground-truth references.

        Args:
            gt: Reference texts. Passed as the second (target) argument of
                ``BARTScorer.score``.
            hyp: Hypothesis texts. Passed as the first (source) argument of
                ``BARTScorer.score``.

        Returns:
            The result of ``BARTScorer.score(hyp, gt, batch_size=...)``.
            NOTE(review): presumably one score per (hyp, gt) pair, i.e. the
            log-likelihood of ``gt`` conditioned on ``hyp`` - confirm
            against ``utils.bart_score``.
        """
        # The hyp -> gt direction is intentional and kept unchanged; callers
        # may depend on it.
        return self.bart_scorer.score(hyp, gt, batch_size=self.batch_size)



if __name__ == '__main__':
    # Smoke test: build the metric with its defaults, report the model's
    # parameter count, then score two hypotheses against two references.
    metric = BARTScore()

    n_params = sum(param.numel() for param in metric.bart_scorer.model.parameters())
    print(n_params)

    references = ["A test sentence", "Sentence B"]
    hypotheses = [
        "So Cummings was told that these units must be preserved in their entirety.",
        "Satz B",
    ]
    print(metric(references, hypotheses))
