import os
import random
import time
import torch
import torch.nn as nn
import numpy as np

from itertools import accumulate
from math import ceil

from pruner.utils.runs import Run
from pruner.utils.utils import print_message

from pruner.evaluation.metrics import Metrics
from pruner.modeling.inference import ModelInference

def evaluate(args):
    """Score every passage in ``args.data`` with the pruner and report token-level metrics.

    Passages are scored in batches of ``args.bsize`` under ``torch.no_grad()``.
    For each passage, the per-token scores returned by
    ``ModelInference.scoreFromText`` are fed to ``Metrics.add``. The gold tokens
    are the indices at which the passage's label array is non-zero. Final metrics
    are written to ``os.path.join(Run.path, 'pruning.metrics')``.

    Args:
        args: Namespace-like object. Attributes read here are ``pruner``, ``amp``,
            ``data``, ``bsize`` and ``checkpoint``; ``checkpoint`` is indexed
            with the key ``'batch'``. ``data`` is expected to be a list of
            ``(passage, label)`` tuples. Per the original author's comment,
            ``passage`` is a str and ``label`` a float64 ndarray (TODO confirm).

    Raises:
        AssertionError: if the model returns a different number of score lists
            than passages for any batch. The exception type is kept for
            compatibility with the previous bare ``assert``.
    """
    inference = ModelInference(args.pruner, amp=args.amp)
    data = args.data
    total = len(data)

    metrics = Metrics(recall_depths={10, 20, 24, 30, 40, 50, 60}, total_passages=total)

    n_samples = 0
    with torch.no_grad():
        for offset in range(0, total, args.bsize):
            # List slicing clamps at the end, so the final batch may be shorter.
            batch = data[offset:offset + args.bsize]
            passages, labels = zip(*batch)

            # Per the original comment: one list of per-token floats per passage.
            scores, _ = inference.scoreFromText(docs=passages)

            # zip() below would silently drop (or ignore extra) results if the
            # counts disagree. Check explicitly instead of with `assert`, so the
            # check still runs under `python -O`. Passing this for every batch
            # also guarantees n_samples == total at the end.
            if len(scores) != len(labels):
                raise AssertionError(
                    f"model returned {len(scores)} score lists for "
                    f"{len(labels)} passages in batch starting at offset {offset}"
                )

            for score, label in zip(scores, labels):
                metrics.add(prediction=score, gold_tokens=set(np.nonzero(label)[0]))
                n_samples += 1

    print('\n\n')
    print_message("#> checkpoint['batch'] =", args.checkpoint['batch'])
    metrics.output_final_metrics(os.path.join(Run.path, 'pruning.metrics'), n_samples, total)
    print('\n\n')
