import torch
import logging
import argparse
import sys

from transformers import pipeline
import pandas as pd
from transformers.pipelines.pt_utils import KeyDataset
from tqdm.auto import tqdm
from datasets import Dataset

logger = logging.getLogger(__name__)

def pipeline_evaluate(type, model, data, text_column, text_pair_column=None, batch_size=32, device_num=-1, **kwargs):
    """Run a transformers pipeline over one (or a pair of) text column(s) of ``data``.

    Args:
        type: Task name forwarded as the first argument of ``pipeline``.
            (The name shadows the builtin but is kept because callers pass
            it by keyword.)
        model: Forwarded to ``pipeline`` as ``model=``.
        data: Column-indexable table; the selected columns must support
            ``.astype(str)`` (main passes a pandas DataFrame).
        text_column: Name of the column holding the primary text.
        text_pair_column: Optional name of a second column. When given, each
            input becomes a ``{'text': ..., 'text_pair': ...}`` dict.
        batch_size: Forwarded to the pipeline call.
        device_num: Forwarded to ``pipeline`` as ``device=``.
        **kwargs: Extra keyword arguments forwarded to the pipeline call.

    Returns:
        Whatever the pipeline call returns for the wrapped dataset.
    """
    # Every value is stringified first, so non-string cells are classified
    # by their str() form.
    primary = data[text_column].astype(str)
    if text_pair_column is not None:
        secondary = data[text_pair_column].astype(str)
        inputs = [
            {'text': first, 'text_pair': second}
            for first, second in zip(primary, secondary)
        ]
    else:
        inputs = primary.to_list()

    # Wrap the inputs in a single-column dataset and expose just that column.
    dataset = Dataset.from_dict({'input': inputs})
    keyed_inputs = KeyDataset(dataset, 'input')

    runner = pipeline(type, model=model, device=device_num)
    return runner(keyed_inputs, batch_size=batch_size, **kwargs)

def main():
    """Classify a CSV text column with a text-classification pipeline and save the predictions.

    Command-line arguments:
        --input_file: Path of the CSV file to read (required).
        --text_column: Name of the column whose values are classified (required).
        --num_to_evaluate: If given, classify a random sample of at most this
            many rows instead of the whole file.
        --clf: Forwarded as ``model=`` to ``pipeline_evaluate``.
        --output_file: Path of the CSV written with ``score`` and ``class``
            columns (required).
        --batch_size: Batch size forwarded to the pipeline (default 32).
        --log_level: Logging level name, case-insensitive (default ``INFO``).
        --max_seq_length: Maximum tokenized length; longer inputs are
            truncated (default 256).
    """
    parser = argparse.ArgumentParser(description='Evaluate With Classifier')
    # Required: without these the run either fails late with an obscure error
    # or (for --output_file) finishes without writing any file.
    parser.add_argument('--input_file', type=str, required=True)
    parser.add_argument('--text_column', type=str, required=True)
    parser.add_argument('--num_to_evaluate', type=int, default=None)
    parser.add_argument('--clf', type=str, default=None)
    parser.add_argument('--output_file', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--log_level', type=str, default='INFO')
    parser.add_argument('--max_seq_length', type=int, default=256)
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Level names are upper-case in the logging module; accept e.g. "info".
    logger.setLevel(args.log_level.upper())

    logger.info(f"Evaluation parameters {args}")

    df = pd.read_csv(args.input_file)
    if args.num_to_evaluate is not None:
        # DataFrame.sample raises when asked for more rows than exist, so
        # clamp the request to the available number of rows.
        sample_size = min(args.num_to_evaluate, len(df))
        if sample_size < args.num_to_evaluate:
            logger.warning(
                f"Requested {args.num_to_evaluate} rows but input has only {len(df)}; evaluating all of them"
            )
        df = df.sample(sample_size)

    has_cuda = torch.cuda.is_available()
    device_num = 0 if has_cuda else -1

    logger.info(f"Evaluating with classifier {args.clf}")

    results = pipeline_evaluate(
        type="text-classification",
        model=args.clf,
        device_num=device_num,
        data=df,
        text_column=args.text_column,
        batch_size=args.batch_size,
        return_all_scores=False,
        # Request truncation explicitly instead of relying on the tokenizer's
        # implicit fallback when only max_length is supplied.
        truncation=True,
        max_length=args.max_seq_length,
    )

    classes = []
    scores = []
    for result in tqdm(results, desc='classifing', total=len(df)):
        scores.append(result['score'])
        classes.append(result['label'])

    # NOTE(review): the output holds only score/class in evaluation order.
    # When --num_to_evaluate is used the rows are a random sample, so they
    # cannot be joined back to the input file by position.
    pd.DataFrame({'score': scores, 'class': classes}).to_csv(args.output_file, index=False)
    logger.info(f"Saved results to {args.output_file}")

    # Peak GPU memory is only meaningful when the run actually used CUDA.
    if has_cuda:
        max_gpu_memory = torch.cuda.max_memory_allocated(device=None)
        max_gpu_memory = f"{max_gpu_memory/float(1<<30):,.2f} GiB"
        logger.info(f"Max GPU memory used: {max_gpu_memory}")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()