import sys
from datasets import load_from_disk, load_dataset, Features, Value, Sequence, ClassLabel
import evaluate
from transformers import AutoModelForTokenClassification, Trainer, TrainingArguments
from transformers import DataCollatorForTokenClassification
from transformers import AutoTokenizer
import numpy as np
from torch import nn, max as torch_max, IntTensor
from train_on_dataset import compute_metrics

SPECIAL_TOKEN = -100

def test_eval_on_sentence(best_model, input_text, tokenizer):
    """Print and return the predicted label for every token of one sentence.

    The sentence is encoded with ``tokenizer.encode`` (as a batch of one),
    passed through ``best_model``, and the highest-scoring class of each
    token position is mapped to its name via ``best_model.config.id2label``.

    Args:
        best_model: Callable token-classification model. Element ``0`` of its
            output is taken as the logits (assumed to be shaped
            ``(1, sequence, num_labels)`` -- TODO confirm for custom models),
            and ``best_model.config.id2label`` must map class ids to names.
        input_text: The sentence to label, e.g. "The text on which I test".
        tokenizer: Tokenizer exposing ``encode(..., return_tensors="pt")``.

    Returns:
        A list with one label name per encoded token. NOTE(review): if the
        tokenizer adds special tokens, they receive a label as well.
    """
    # Local import: the module only imports selected names from torch.
    from torch import no_grad

    input_ids = tokenizer.encode(input_text,
                                 truncation=True,
                                 padding=True,
                                 return_tensors="pt")

    # Run the input on whatever device the model lives on (if it says so).
    device = getattr(best_model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)

    # Inference only: no need to build an autograd graph.
    with no_grad():
        logits = best_model(input_ids)[0]

    # Drop ONLY the batch axis. A bare squeeze() would also drop the sequence
    # axis for a one-token input and break the per-token arg-max below.
    # Sigmoid is monotonic, so the arg-max is taken directly on the logits.
    label_ids = logits.squeeze(0).argmax(dim=-1).tolist()

    predicted_labels = [best_model.config.id2label[idx] for idx in label_ids]
    print(predicted_labels)
    return predicted_labels


# Despite its ``test_`` prefix this is a helper that needs explicit arguments;
# keep pytest from collecting it as a test case.
test_eval_on_sentence.__test__ = False


def convert_predictions(predictions, model_config):
    """Translate sequences of class ids into sequences of label names.

    Args:
        predictions: Iterable of id sequences, one per example.
        model_config: Object whose ``id2label`` mapping turns an id into its
            label name.

    Returns:
        A list with one list of label names per input sequence; entries equal
        to -100 are skipped.
    """
    id_to_name = model_config.id2label
    return [
        [id_to_name[label_id] for label_id in sequence if label_id != -100]
        for sequence in predictions
    ]



if __name__ == "__main__":
    # Script entry point: load a fine-tuned token-classification model,
    # predict on the dataset's 'test' split, print the metrics and write the
    # predicted label names (one example per line) to predictions.txt.
    #
    # Usage: <script> MODEL_PATH DATASET_PATH OUTPUT_PATH
    import os  # local import: only this entry point needs it

    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} MODEL_PATH DATASET_PATH OUTPUT_PATH")
    model_path, dataset_path, output_path = sys.argv[1:4]

    best_model = AutoModelForTokenClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
    dataset = load_from_disk(dataset_path)

    training_args = TrainingArguments(
        disable_tqdm=True,
        do_train=False,
        do_eval=False,
        do_predict=True,
        output_dir=output_path,
        eval_accumulation_steps=100
    )

    trainer = Trainer(
        model=best_model,
        args=training_args,
        eval_dataset=dataset['validation'],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
    )

    predictions = trainer.predict(dataset['test'])
    print(predictions.metrics)

    # `predictions.label_ids` holds the gold labels; the model's output is in
    # `predictions.predictions` (logits), so take the arg-max over classes.
    logits = predictions.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    predicted_ids = np.argmax(logits, axis=-1)

    # Re-use the gold labels' ignore marker so special/padded positions are
    # dropped by convert_predictions.
    # NOTE(review): without labels, padded positions cannot be filtered here.
    if predictions.label_ids is not None:
        predicted_ids = np.where(predictions.label_ids == SPECIAL_TOKEN,
                                 SPECIAL_TOKEN,
                                 predicted_ids)
    txt_predictions = convert_predictions(predicted_ids, best_model.config)

    # Join the path properly so OUTPUT_PATH works with or without a trailing
    # separator (plain concatenation produced e.g. "outpredictions.txt").
    os.makedirs(output_path, exist_ok=True)
    predictions_file = os.path.join(output_path, 'predictions.txt')
    with open(predictions_file, 'w', encoding='utf-8') as f:
        f.writelines(f"{p}\n" for p in txt_predictions)

