import logging
# Module logger. It is bound to the '__main__' logger rather than
# logging.getLogger(__name__). Presumably this is so messages go through the
# handlers configured by the entry-point script; confirm before changing.
logger = logging.getLogger('__main__')

class PTBEvaluator:
    """Accumulate dependency-parse outputs and score them with UAS/LAS.

    Predictions and gold trees are collected per named evaluation set with
    ``add_pred`` / ``add_truth``. ``evaluation`` scores them and tracks the
    best scores seen so far for each set.

    Each item passed to ``add_truth`` must be a mapping with parallel
    sequences under ``'head'``, ``'rel'`` and ``'tag'``. Items passed to
    ``add_pred`` need ``'head'`` and ``'rel'``.
    """

    def __init__(self, names, vocab, ignore_tags=()):
        """Create an evaluator with an empty state for every set in ``names``.

        Args:
            names: identifiers of the evaluation sets to track.
            vocab: used only to map each entry of ``ignore_tags`` to a tag
                index via ``vocab.get_token_index(tag, 'tag')``. It is not
                touched when ``ignore_tags`` is empty.
            ignore_tags: tag strings whose tokens are excluded from scoring
                (e.g. punctuation tags such as ``'.'`` or ``','``). The
                default is empty, meaning every non-padding token is scored.
        """
        self.eval_set = {}
        for name in names:
            # Start each set ready to receive data, so add_pred/add_truth
            # work without a prior clear() and print_best_result works
            # before the first evaluation.
            self.eval_set[name] = {
                'pred': [],
                'truth': [],
                'Best_UAS': 0,
                'Best_LAS': 0,
            }
        self.ignore_tag = set(
            vocab.get_token_index(tag, 'tag') for tag in ignore_tags)

    def add_pred(self, name, pred):
        """Append one predicted sentence to evaluation set ``name``."""
        self.eval_set[name]['pred'].append(pred)

    def add_truth(self, name, truth):
        """Append one gold sentence to evaluation set ``name``."""
        self.eval_set[name]['truth'].append(truth)

    def evaluation(self, name):
        """Score the accumulated sentences of set ``name`` and log UAS/LAS.

        A token is scored when its gold head or gold relation is non-zero and
        its gold tag is not in ``self.ignore_tag``. Tokens where both are zero
        are skipped, presumably because 0 is the padding id (confirm against
        the data loader). UAS is the fraction of scored tokens with the
        correct head. LAS additionally requires the correct relation.

        Returns:
            True if this run beats the stored best (higher UAS, or higher
            UAS + LAS), in which case the best scores are replaced. Otherwise
            False. A set with no scorable tokens scores UAS = LAS = 0.
        """
        entry = self.eval_set[name]
        entry.setdefault('Best_UAS', 0)
        entry.setdefault('Best_LAS', 0)

        total_token = UA = LA = 0
        # zip() stops at the shorter input, so length mismatches between
        # gold and predicted data are not detected here.
        for truth, pred in zip(entry['truth'], entry['pred']):
            for truth_head, pred_head, truth_rel, pred_rel, tag in zip(
                    truth['head'], pred['head'],
                    truth['rel'], pred['rel'], truth['tag']):
                if not (truth_head or truth_rel) or tag in self.ignore_tag:
                    continue
                total_token += 1
                if truth_head == pred_head:
                    UA += 1
                    if truth_rel == pred_rel:
                        LA += 1

        if total_token:
            UAS = UA * 1.0 / total_token
            LAS = LA * 1.0 / total_token
        else:
            # Nothing to score: report zeros instead of raising
            # ZeroDivisionError.
            UAS = LAS = 0.0
        logger.info("%s: UAS=%f, LAS=%f", name, UAS, LAS)
        if (UAS > entry['Best_UAS']
                or UAS + LAS > entry['Best_UAS'] + entry['Best_LAS']):
            entry['Best_UAS'] = UAS
            entry['Best_LAS'] = LAS
            return True
        return False

    def print_best_result(self, name):
        """Log the best UAS/LAS recorded so far for set ``name``."""
        entry = self.eval_set[name]
        logger.info('Best Results: UAS: %s, LAS: %s',
                    entry['Best_UAS'], entry['Best_LAS'])

    def clear(self, name):
        """Drop accumulated sentences for ``name``; best scores are kept."""
        self.eval_set[name]['pred'] = []
        self.eval_set[name]['truth'] = []
