from typing import Any, Dict, Union
from abc import ABC
from dataclasses import dataclass, field
import os
import numpy as np
import logging
import io
import contextlib

import torch
import scrubadub
import scrubadub_spacy
import pycodestyle
from detoxify import Detoxify
import wandb

from .scorer_utils import AddressDetectorNoLibpostal, DateOfBirthDetectorNonNan


class Scorer(ABC):
    """
    Scorer is an abstraction of a computation needed for determining whether a piece of text is aligned or misaligned.
    A scorer can be implemented by a learned reward model or a simpler rule-based heuristic (using a blacklist of
    disallowed words).

    Subclasses must implement `score_text`. They may override `score_texts` for smarter batch scoring, and
    `score_element` / `score_elements` to additionally compute span scores.
    """

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> 'Scorer':
        """
        Instantiate the scorer class named by `config['class_name']`.

        The remaining entries are passed as keyword arguments to that class's constructor. The class is looked up
        among this module's globals. `config` itself is left unmodified, so the same config can be reused.

        Raises:
            KeyError: if `config` has no 'class_name' entry, or no module-level name matches it.
            TypeError: if the named object is not a Scorer subclass.
        """
        params = dict(config)  # copy: popping from the caller's dict would break re-use of the same config
        class_name = params.pop('class_name')
        try:
            scorer_class = globals()[class_name]
        except KeyError:
            raise KeyError(f'Unknown scorer class: {class_name!r}') from None
        # globals() also holds modules, helpers and third-party classes; only build actual scorers.
        if not (isinstance(scorer_class, type) and issubclass(scorer_class, Scorer)):
            raise TypeError(f'{class_name!r} is not a Scorer subclass')
        return scorer_class(**params)

    def score_text(self, text: str) -> float:
        """Return the document-level score of a single text. Must be implemented by subclasses."""
        raise NotImplementedError('A subclass of Scorer must implement score_text')

    def score_texts(self, texts: list[str]) -> list[float]:
        """Return one score per text, in order."""
        # Naive implementation that can be overridden by subclasses that can do smarter batch scoring
        return [self.score_text(text) for text in texts]

    def score_element(self, element: Dict[str, Any]) -> Dict[str, Any]:
        """
        Update a single HuggingFace dataset element with computed scores: a document-level `score` (float) and possibly
        `span_scores` (a list of dicts with `begin` and `end` keys and a `score` key).

        The element is updated in place and also returned.
        """
        # By default, only document score is computed but subclasses can override this method to compute span scores
        element['score'] = self.score_text(element['text'])
        return element

    def score_elements(self, element: Dict[str, Any]) -> Dict[str, Any]:
        """
        Update a batch of HuggingFace dataset elements with computed scores: for each element (document), a
        document-level `score` (float) and possibly `span_scores` (a list of dicts with `begin` and `end` keys and a
        `score` key).

        `element` is a batch: `element['text']` is a list of texts, and `element['score']` becomes the parallel list
        of scores. The batch is updated in place and also returned.
        """
        # By default, only document score is computed but subclasses can override this method to compute span scores
        element['score'] = self.score_texts(element['text'])
        return element


class DetoxifyToxicityScorer(Scorer):
    """
    Scores text toxicity using the Detoxify 'unbiased' model; the score is the model's 'toxicity' output.

    For each scoring call the model is moved onto `self.device`. Unless `keep_on_device` is True, it is moved back to
    CPU afterwards, including when prediction raises.
    """

    def __init__(self, device: Union[str, int, torch.device] = 0, keep_on_device: bool = False):
        """
        Args:
            device: a CUDA device index (an int or a digit string such as 0 or '0'), or any other spec accepted by
                `torch.device` (e.g. 'cuda:1', 'cpu', or a `torch.device` instance).
            keep_on_device: if True, leave the model on `device` between calls instead of moving it back to CPU.
        """
        self.device = self._resolve_device(device)
        self.detoxify = Detoxify('unbiased')
        self.keep_on_device = keep_on_device
        # Scoring is inference-only, so the weights never need gradients.
        for parameter in self.detoxify.model.parameters():
            parameter.requires_grad = False

    @staticmethod
    def _resolve_device(device: Union[str, int, torch.device]) -> torch.device:
        """Convert a device spec to a torch.device, treating bare indices as CUDA ordinals."""
        if isinstance(device, torch.device):
            return device
        if isinstance(device, int) or (isinstance(device, str) and device.isdigit()):
            return torch.device(f'cuda:{device}')
        return torch.device(device)

    def _predict_toxicity(self, inputs):
        """Run Detoxify on `inputs` (a str or a list of str) on self.device and return its 'toxicity' output."""
        self.detoxify.model.to(self.device)
        try:
            return self.detoxify.predict(inputs)['toxicity']
        finally:
            # Free the accelerator even if prediction failed.
            if not self.keep_on_device:
                self.detoxify.model.to('cpu')

    def score_text(self, text: str) -> float:
        """Return the toxicity score of a single text."""
        return self._predict_toxicity(text)

    def score_texts(self, texts: list[str]) -> list[float]:
        """Return one toxicity score per text, computed in a single batched prediction."""
        if not texts:
            return []  # nothing to tokenize; skip the device round-trip
        return self._predict_toxicity(texts)


class PIIScorer(Scorer):
    """
    Scores text on PII density: the number of PII ("filth") objects scrubadub finds in a document, divided by the
    document's length in characters.

    An empty document scores 0.0. A document on which detection raises one of the handled errors also scores 0.0.
    """

    def __init__(self, flag: bool = True):
        """
        Create a scrubber and add the optional detectors (date of birth, Skype, spaCy entities, address).

        Args:
            flag: not used by this class; kept so existing configs that pass it keep working.
        """
        self.scrubber = scrubadub.Scrubber()
        self.scrubber.add_detector(DateOfBirthDetectorNonNan)
        self.scrubber.add_detector(scrubadub.detectors.SkypeDetector)
        self.scrubber.add_detector(scrubadub_spacy.detectors.SpacyEntityDetector(model='en_core_web_sm'))
        self.scrubber.add_detector(AddressDetectorNoLibpostal)

    def score_text(self, text: str) -> float:
        """
        Return the number of PII objects in `text` per character, as a float.

        Returns 0.0 for empty text, or when detection raises ValueError, RecursionError, OverflowError or
        ZeroDivisionError. In the error case a warning is logged.
        """
        if not text:
            return 0.0
        try:
            num_filth = sum(1 for _ in self.scrubber.iter_filth(text))
        except (ValueError, RecursionError, OverflowError, ZeroDivisionError) as exception:
            # Best effort: a failing document should not abort scoring of a whole dataset.
            logging.getLogger(__name__).warning('PII detection failed, scoring as 0.0: %s', exception)
            return 0.0
        return num_filth / len(text)

    def score_texts(self, texts: list[str]) -> list[float]:
        """Return the PII density of each text, in order."""
        return [self.score_text(text) for text in texts]


class PEP8Scorer(Scorer):
    """Scores Python source text by its density of pycodestyle (PEP 8) violations."""

    def score_text(self, text: str) -> float:
        """
        Return number of PEP8 violations per character in text as a float.

        Empty text scores 0.0. If pycodestyle raises UnicodeEncodeError or IndexError while checking, the violation
        count is taken to be 0.
        """
        if not text:
            return 0.0
        checker = pycodestyle.Checker(lines=io.StringIO(text).readlines(), show_source=True)
        # The checker's report writes to stdout; silence it. Opening devnull in the `with` guarantees the handle is
        # closed (the previous version leaked one file descriptor per call).
        with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
            try:
                num_violations = checker.check_all()
            except (UnicodeEncodeError, IndexError):
                num_violations = 0   # this should be rare enough to not worry about
        return num_violations / len(text)


class PEP8LineScorer(Scorer):
    """Scores Python source text line by line by its density of pycodestyle (PEP 8) violations."""

    def score_text(self, text: str) -> list:
        """
        Return a list of PEP 8 violation densities, one per line of `text`.

        Each entry is the number of violations reported on that line divided by the line's length in characters.
        Lines are split with `readlines`, so the length includes any trailing newline.

        Violations are read from the report's `_deferred_print` buffer, which is a private pycodestyle attribute. If
        checking raises UnicodeEncodeError, ZeroDivisionError or IndexError, every line scores 0.0.
        """
        checker = pycodestyle.Checker(lines=io.StringIO(text).readlines(), show_source=True)
        # The checker's report writes to stdout; silence it and make sure the devnull handle is closed afterwards.
        with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
            try:
                checker.check_all()
                scores = np.zeros(len(checker.lines))
                # Entries are (line_number, offset, code, message, doc); line numbers are 1-based.
                for line_number, _offset, _code, _message, _doc in checker.report._deferred_print:
                    scores[line_number - 1] += 1
                scores = scores / [len(line) for line in checker.lines]
            except (UnicodeEncodeError, ZeroDivisionError, IndexError):
                scores = np.zeros(len(checker.lines))  # this should be rare enough to not worry about
        return scores.tolist()
