import re
import spacy
from statistics import mean

from tqdm import tqdm
# Import-time side effect: register tqdm's pandas integration so that
# `progress_apply` (used by TextPreprocessor.preprocess) is available on
# pandas objects, with every progress bar labelled 'Processing text'.
tqdm.pandas(desc='Processing text')


class TextPreprocessor:
    """Clean raw text with a spaCy pipeline and optional token filters.

    Each text is optionally lowercased, tokenised by the loaded spaCy
    pipeline, filtered (punctuation / numbers / stop words), and re-joined
    with single spaces, using either each token's lemma or its text.
    """

    # Matches a token containing at least one digit anywhere (e.g. "covid19").
    # Raw string so "\d" is a regex escape rather than an invalid string
    # escape; `search` (not `match` with ".*") also finds digits that follow
    # a newline inside the token. Compiled once instead of per token.
    _DIGIT_RE = re.compile(r'\d')

    def __init__(self, spacy_model):
        """Load the spaCy pipeline ``spacy_model`` via ``spacy.load``."""
        self.nlp = spacy.load(spacy_model)

    def preprocess(self, series, lowercase=True, remove_punct=True,
                   remove_num=True, remove_stop=True, lemmatize=True):
        """Apply :meth:`preprocess_text` to every element of ``series``.

        Uses ``series.progress_apply`` so a tqdm progress bar is shown; this
        requires tqdm's pandas integration (``tqdm.pandas()``) to have been
        registered. All flags are forwarded unchanged to
        :meth:`preprocess_text`.

        Returns:
            Whatever ``series.progress_apply`` returns (presumably a pandas
            Series of preprocessed strings).
        """
        return series.progress_apply(
            lambda text: self.preprocess_text(
                text, lowercase, remove_punct, remove_num, remove_stop, lemmatize
            )
        )

    def preprocess_text(self, text, lowercase, remove_punct,
                        remove_num, remove_stop, lemmatize):
        """Return ``text`` normalised according to the given flags.

        Args:
            text: Input string.
            lowercase: Lowercase ``text`` *before* spaCy runs, so the pipeline
                itself sees lowercased input.
            remove_punct: Drop tokens whose ``is_punct`` is true.
            remove_num: Drop numeric-looking tokens (see ``_remove_numbers``).
            remove_stop: Drop tokens whose ``is_stop`` is true.
            lemmatize: Join ``lemma_`` values instead of the token text.

        Returns:
            The surviving tokens joined by single spaces.
        """
        if lowercase:
            text = self._lowercase(text)
        tokens = self.nlp(text)
        # Each filter returns a plain list of tokens; later steps only need
        # an iterable of tokens, not the original spaCy Doc.
        if remove_punct:
            tokens = self._remove_punctuation(tokens)
        if remove_num:
            tokens = self._remove_numbers(tokens)
        if remove_stop:
            tokens = self._remove_stop_words(tokens)
        return self._lemmatize(tokens) if lemmatize else self._get_text(tokens)

    def _lowercase(self, text):
        """Return ``text`` lowercased with ``str.lower``."""
        return text.lower()

    def _remove_punctuation(self, doc):
        """Return the tokens of ``doc`` that are not punctuation."""
        return [t for t in doc if not t.is_punct]

    def _remove_numbers(self, doc):
        """Return the tokens of ``doc`` that do not look numeric.

        A token is dropped when spaCy flags it ``is_digit`` or ``like_num``,
        or when its text contains any digit.
        """
        return [
            t for t in doc
            if not (t.is_digit or t.like_num or self._DIGIT_RE.search(t.text))
        ]

    def _remove_stop_words(self, doc):
        """Return the tokens of ``doc`` that are not stop words."""
        return [t for t in doc if not t.is_stop]

    def _lemmatize(self, doc):
        """Join the lemmas of ``doc``'s tokens with single spaces."""
        return ' '.join(t.lemma_ for t in doc)

    def _get_text(self, doc):
        """Join the texts of ``doc``'s tokens with single spaces."""
        return ' '.join(t.text for t in doc)



class Evaluator:
    """Per-query ranked-retrieval metrics, plus a mean over many queries.

    Every per-query metric takes ``ground_truths`` (the relevant items),
    ``results`` (the retrieved items, best first) and an optional cutoff
    ``at``; ``at=None`` means "use all of ``results``". Relevance is tested
    with ``in``, so ``ground_truths`` may be any container (a ``set`` makes
    each lookup O(1)). A relevant item repeated in ``results`` is counted
    once per occurrence, so ``results`` is assumed to be duplicate-free.
    Degenerate inputs (nothing retrieved, nothing relevant) score 0.0
    instead of raising.
    """

    def compute_mean_score(self, func, all_ground_truths, all_results, at=None):
        """Return the arithmetic mean of ``func`` over paired queries.

        Args:
            func: Per-query metric called as
                ``func(ground_truths, results, at=at)``, e.g. one of this
                class's bound methods.
            all_ground_truths: Iterable of per-query ground-truth containers.
            all_results: Iterable of per-query ranked results, aligned with
                ``all_ground_truths`` (``zip`` truncates to the shorter one).
            at: Cutoff forwarded to ``func`` for every query.

        Raises:
            statistics.StatisticsError: If there are no query pairs.
        """
        return mean([func(truths, res, at=at)
                     for truths, res in zip(all_ground_truths, all_results)])

    @staticmethod
    def _hits(ground_truths, ranked):
        """Count the items of ``ranked`` that occur in ``ground_truths``."""
        return sum(1 for d in ranked if d in ground_truths)

    def precision(self, ground_truths, results, at=None):
        """Return the fraction of the top-``at`` results that are relevant.

        The denominator is the number of results actually inspected, which is
        smaller than ``at`` when fewer results were retrieved. Returns 0.0
        when nothing is inspected (empty ``results`` or ``at == 0``).
        """
        top = results[:at]  # slicing with None keeps everything
        if len(top) == 0:
            return 0.0
        return self._hits(ground_truths, top) / len(top)

    def recall(self, ground_truths, results, at=None):
        """Return the fraction of ``ground_truths`` found in the top ``at``.

        Returns 0.0 when ``ground_truths`` is empty.
        """
        if len(ground_truths) == 0:
            return 0.0
        return self._hits(ground_truths, results[:at]) / len(ground_truths)

    def fscore(self, ground_truths, results, at=None):
        """Return the harmonic mean (F1) of precision and recall at ``at``.

        ``at`` is accepted so this method can be passed to
        :meth:`compute_mean_score`, which always forwards ``at``. Returns 0.0
        when both precision and recall are zero.
        """
        p = self.precision(ground_truths, results, at=at)
        r = self.recall(ground_truths, results, at=at)
        return (2 * p * r) / (p + r) if (p != 0.0 or r != 0.0) else 0.0

    def reciprocal_rank(self, ground_truths, results, at=None):
        """Return 1/rank of the first relevant result within the top ``at``.

        Ranks are 1-based. Returns 0.0 when no relevant result is found,
        including when ``results`` is empty.
        """
        for rank, d in enumerate(results[:at], start=1):
            if d in ground_truths:
                return 1 / rank
        return 0.0

    def average_precision(self, ground_truths, results, at=None):
        """Return average precision over the top ``at`` results.

        Sums precision@k at every rank k holding a relevant result, then
        divides by the total number of ground truths, so relevant items not
        retrieved within the cutoff contribute zero. Returns 0.0 when
        ``ground_truths`` is empty.
        """
        if len(ground_truths) == 0:
            return 0.0
        hits = 0
        total = 0.0
        # A running hit count gives precision@rank directly (hits / rank),
        # avoiding a full precision recomputation at every relevant rank.
        for rank, d in enumerate(results[:at], start=1):
            if d in ground_truths:
                hits += 1
                total += hits / rank
        return total / len(ground_truths)
