import pandas as pd
import sys
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import torch
from datetime import datetime
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

import json
import re

class ValueResonanceClassifier:
    '''
    ValueResonanceClassifier classifier:
        - 
    '''
    def __init__(self,model_path,model_type='transformers',tokenizer='roberta-large-mnli'):
        if model_type=='transformers':    
            self.model_config = \
                    AutoConfig.from_pretrained(f"{model_path}config.json")
            self.model = \
                    AutoModelForSequenceClassification.from_pretrained(f"{model_path}pytorch_model.bin",
                        config=self.model_config)
            self.tokenizer=AutoTokenizer.from_pretrained(tokenizer)
        elif model_type=='torch':
            self.model = AutoModelForSequenceClassification.from_pretrained(tokenizer, num_labels=3)
            self.model.load_state_dict(torch.load(f"{model_path}pytorch_model.pt"))
            # self.model=torch.load(f"{model_path}pytorch_model.pt")
            self.tokenizer=AutoTokenizer.from_pretrained(tokenizer)
        elif model_type=='rte':
            self.model = \
                    AutoModelForSequenceClassification.from_pretrained(tokenizer)
            self.tokenizer=AutoTokenizer.from_pretrained(tokenizer)

    def ensure_clean(self, text):
        return re.sub("\s\s+", " ", str(text))

    def score_entailment(self, premise, hypothesis):
        '''
        Scores entailment of premise/hypothesis pair. 
        To use a transformers model type like roberta-large-mnli (which we use for training), set argument:
            - model_type = 'transformers'
            
        If you are using a transformer model and have already loaded the model/tokenizer be sure to specify the
            tokenizer in the arguments, otherwise you can just put the transformers model name in the model 
            argument
            
        If you want the full allenlp json for the entailment scoring instead of just the final label set
            `return_full_json=True`
        '''
        premise=self.ensure_clean(premise)
        hypothesis=self.ensure_clean(hypothesis)
        text = f"{premise}</s></s>{hypothesis}."
        inputs = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
            predicted_class_id = logits.argmax().item()
            result = self.model.config.id2label[predicted_class_id].lower()
        return result
    
    def score_entailment_logits(self, premise, hypothesis):
        '''
        Scores entailment of premise/hypothesis pair. 
        To use a transformers model type like roberta-large-mnli (which we use for training), set argument:
            - model_type = 'transformers'
            
        If you are using a transformer model and have already loaded the model/tokenizer be sure to specify the
            tokenizer in the arguments, otherwise you can just put the transformers model name in the model 
            argument
            
        If you want the full allenlp json for the entailment scoring instead of just the final label set
            `return_full_json=True`
        '''
        premise=self.ensure_clean(premise)
        hypothesis=self.ensure_clean(hypothesis)
        text = f"{premise}</s></s>{hypothesis}."
        inputs = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
            probs = logits.softmax(dim=-1).detach().cpu().flatten().numpy().tolist()
            predicted_class_id = logits.argmax().item()
        return {'pred' : predicted_class_id, 'logits' : logits, 'probs' : probs}

    def score_entailments(self, premises, hypotheses):
        '''
        Essentially executes `score_entailment` over multiple hypotheses
        Arguments are essentially the same except premises and hypotheses must be a list of hypotheses
        
        Output pandas dataframe has hypotheses as the columns and premises as the rows
        '''
        results_dict={}
        for premise in premises:
            hypothesis_results=[]
            for hypothesis in hypotheses:
                hypothesis_result=self.score_entailment(premise, hypothesis)
                hypothesis_results.append(hypothesis_result)
            results_dict[premise]=hypothesis_results

        results_df=pd.DataFrame(results_dict, index=hypotheses)
        return results_df

    def predict(self, premises, hypotheses, do_eval=False, true_labels=[]):
        if len(premises) != len(hypotheses):
            print(f'You must have an equal length of premises and hypotheses. Currently n-premises ({len(premises)}) and\
            n-hypotheses ({len(hypotheses)})')
            sys.exit()
        elif (do_eval==True) and (len(true_labels) != len(premises)):
            print(f'If you wish to evaluate results, you must have an equal length of true labels and and \
            premises. Currently n-premises ({len(premises)}) and\
            n-true_labels ({len(true_labels)})')
            sys.exit()
        elif (do_eval==True) and (len(true_labels) != len(hypotheses)):
            print(f'If you wish to evaluate results, you must have an equal length of true labels and and \
            hypotheses. Currently n-hypotheses ({len(hypotheses)}) and\
            n-true_labels ({len(true_labels)})')
            sys.exit()
        if (do_eval==True) and (len(true_labels)==0):
            print('In order to do evaluation you must provide a list of true labels')

        scores=[]
        for i in range(len(premises)):
            score=self.score_entailment(premises[i],hypotheses[i])
            scores.append(score)

        if len(true_labels)==len(premises):
            results_df=pd.DataFrame({'premise':premises,'hypothesis':hypotheses,'prediction':scores,'label':true_labels})
            if do_eval==True:
                averaging='weighted'
                y_pred=results_df.loc[:,'prediction'].values
                y_true=results_df.loc[:,'label'].values
                if type(y_pred[0]) != type(y_true[0]):
                    if type(y_pred[0]) == str:
                        y_pred_new=[]
                        for val in y_pred:
                            if val == 'contradiction':
                                y_pred_new.append(0)
                            elif val == 'neutral':
                                y_pred_new.append(1)
                            elif val == 'entailment':
                                y_pred_new.append(2)
                        y_pred=y_pred_new
                accuracy_scores=pd.DataFrame({'accuracy':[accuracy_score(y_true, y_pred)],
                                    'precision':[precision_score(y_true, y_pred,average=averaging)],
                                    'recall':[recall_score(y_true, y_pred,average=averaging)],
                                    'F1':[f1_score(y_true, y_pred,average=averaging)]})
                
                return results_df, accuracy_scores
            else:
                return results_df
        else:
            results_df=pd.DataFrame({'premise':premises,'hypothesis':hypotheses,'prediction':scores})
            return results_df