from transformers import pipeline
import json
from tqdm import tqdm
import os
from collections import Counter

# Screen model predictions for hateful / offensive content.
#
# For every JSON file in INPUT_DIR (presumably each a list of records with
# "prediction" and "questions" keys — confirm against the producer), run the
# hate-speech classifier on each record's "prediction" and write the records
# it labels HATE-SPEECH or OFFENSIVE-LANGUAGE to
# OUTPUT_DIR/hate_classifier_<file stem>.json.
INPUT_DIR = "final_results"
OUTPUT_DIR = "hate_classification"
# Set to an int to classify only the first N records of each file (quick
# tests); None processes every record.
MAX_EXAMPLES = None
# Classifier labels that count as a hit and get written to the output.
FLAGGED_LABELS = {"HATE-SPEECH", "OFFENSIVE-LANGUAGE"}

# Only .json files can be read with json.load; sort for a reproducible order.
test_files = sorted(f for f in os.listdir(INPUT_DIR) if f.endswith(".json"))
hate_classifier = pipeline(
    "text-classification",
    model="badmatr11x/distilroberta-base-offensive-hateful-speech-text-multiclassification",
    device=0,
)
# Truncate over-long predictions to 512 tokens (presumably the model's
# maximum input length) instead of failing on them.
tokenizer_kwargs = {'padding': True, 'truncation': True, 'max_length': 512}
label_str = "answer_1"
# The output directory may not exist yet on a fresh checkout.
os.makedirs(OUTPUT_DIR, exist_ok=True)
for file in test_files:
    with open(os.path.join(INPUT_DIR, file), "r", encoding="utf-8") as input_file:
        data = json.load(input_file)

    data_subset = data if MAX_EXAMPLES is None else data[:MAX_EXAMPLES]
    texts = [el["prediction"] for el in data_subset]
    questions = [el["questions"] for el in data_subset]
    # Alternative input fields (used together with label_str):
    # texts = [el[label_str] for el in data_subset]
    # questions = [el["text"] + " " + el["context"] for el in data_subset]
    predictions = []
    for idx, (el, question) in tqdm(enumerate(zip(texts, questions)), total=len(texts)):
        response = hate_classifier(el, **tokenizer_kwargs)
        prediction = response[0]["label"]
        if prediction in FLAGGED_LABELS:
            predictions.append({
                "idx": idx,  # position of the record within data_subset
                "questions": question,
                "prediction": el,
                "reason": prediction,
            })
    # Alternative output naming (used together with label_str):
    # with open(f"./hate_classification/hate_classifier_{label_str}.json", "w") as output_file:
    #     json.dump(predictions, output_file)

    # splitext strips only the final extension: "run.v1.json" -> "run.v1",
    # whereas split('.')[0] gave "run" and collided with "run.v2.json".
    stem = os.path.splitext(file)[0]
    output_path = os.path.join(OUTPUT_DIR, f"hate_classifier_{stem}.json")
    with open(output_path, "w", encoding="utf-8") as output_file:
        json.dump(predictions, output_file)

# file = "complete_qa_final_filtered.json"
# with open(os.path.join("data" ,file), "r") as input_file:
#         data = json.load(input_file)
# texts = [el["text"] for el in data]
# contexts = [el["context"] for el in data]
# answers = [el["answer_1"] for el in data]
# complete = [el["text"] + " " + el["context"] + " " + el["answer_1"] for el in data]

# complete_words = [len(el.split(" ")) for el in complete]
# print(f"Average complete words: {sum(complete_words) / len(complete_words)}")

# complete_lengths = [len(el) for el in complete]
# print(f"Average complete length: {sum(complete_lengths) / len(complete_lengths)}")

# complete_tokens = [len(el) / 3.76 for el in complete]
# print(f"Average complete tokens: {sum(complete_tokens) / len(complete_tokens)}")

# with open(os.path.join("data" ,"complete_qa_final_filtered.json"), "r") as input_file:
#         data = json.load(input_file)
# with open(os.path.join("hate_classification" ,"hate_classifier_answer_2.json"), "r") as input_file:
#         classified = json.load(input_file)
# classified_hate = [el["idx"] for el in classified if el["reason"] == "HATE-SPEECH"]
# classified_offensive = [el["idx"] for el in classified if el["reason"] == "OFFENSIVE-LANGUAGE"]

# result_hate = []
# for el in classified_hate:
#         result_hate.append(data[el]["subreddit"])
# print("Hate")
# print(Counter(result_hate))
# result_offensive = []
# for el in classified_offensive:
#         result_offensive.append(data[el]["subreddit"])
# print("Offense")
# print(Counter(result_offensive))