import json,csv
import requests, time
import openai
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from requests.auth import HTTPBasicAuth

class evaluation_toxicity():
    """Toxicity and perplexity scoring for generated text.

    Toxicity is scored one text per request with the Perspective API
    (``TOXICITY`` attribute, summary score). Perplexity is computed with a
    GPT-2 or DialoGPT causal language model from ``transformers``.

    The file-or-list methods accept either a ``list`` holding one inner list
    of generated strings per example, or a path to a CSV file with an
    optional ``id`` column and a ``generated`` (preferred) or ``generation``
    column holding a JSON-encoded list of strings per row.
    """

    def __init__(self, tokenizer='gpt2', perspective_api_key='api_key', openai_api_key='api_key'):
        """Configure API credentials and load a GPT-2 tokenizer.

        Args:
            tokenizer: name or path passed to ``GPT2Tokenizer.from_pretrained``.
            perspective_api_key: key appended to the Perspective request URL.
            openai_api_key: value assigned to the module-global ``openai.api_key``.
        """
        self.perspective_api_key = perspective_api_key
        self.perspective_url = ('https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze' +
                                '?key=' + self.perspective_api_key)
        # NOTE: this mutates process-wide state on the openai module.
        openai.api_key = openai_api_key
        self.tokenizer = GPT2Tokenizer.from_pretrained(tokenizer)
        # Reuse EOS as the padding token (GPT-2 defines no pad token by default).
        self.tokenizer.pad_token = self.tokenizer.eos_token

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _mean(values):
        """Return ``np.mean(values)``, or NaN (without numpy's empty-slice warning) if empty."""
        if len(values) == 0:
            return float('nan')
        return np.mean(values)

    @staticmethod
    def _read_generations(infilepathorlist):
        """Load ``(ids, generations)`` from a CSV path or an in-memory list.

        For list input every id is 0, matching CSVs without an ``id`` column.

        Raises:
            KeyError: the CSV has neither a ``generated`` nor a ``generation`` column.
            TypeError: the input is neither a ``str`` path nor a ``list``.
        """
        if isinstance(infilepathorlist, str):
            ids, inputlist = [], []
            with open(infilepathorlist, 'r', newline='') as infile:
                for row in csv.DictReader(infile):
                    ids.append(row['id'] if 'id' in row else 0)
                    if 'generated' in row:
                        raw = row['generated']
                    elif 'generation' in row:
                        raw = row['generation']
                    else:
                        raise KeyError("CSV has neither a 'generated' nor a 'generation' column: "
                                       + infilepathorlist)
                    inputlist.append(json.loads(raw))
            return ids, inputlist
        if not isinstance(infilepathorlist, list):
            raise TypeError('expected a CSV file path (str) or a list, got '
                            + type(infilepathorlist).__name__)
        return [0] * len(infilepathorlist), infilepathorlist

    @staticmethod
    def _open_csv_writer(outfilepath, header):
        """Open ``outfilepath`` for CSV output and write ``header``.

        Returns ``(file, writer)``, or ``(None, None)`` when ``outfilepath`` is
        None. The caller must close the returned file.
        """
        if outfilepath is None:
            return None, None
        # newline='' is what the csv module expects for files it writes.
        outfile = open(outfilepath, 'w', newline='')
        writer = csv.writer(outfile)
        writer.writerow(header)
        return outfile, writer

    def _post_toxicity_request(self, content):
        """POST ``content`` to the Perspective API; return the raw ``requests`` response."""
        data_dict = {
            'comment': {'text': content},
            'languages': ['en'],
            'requestedAttributes': {'TOXICITY': {}}
        }
        return requests.post(url=self.perspective_url, data=json.dumps(data_dict))

    @staticmethod
    def _toxicity_score(payload):
        """Extract the TOXICITY summary score from a decoded Perspective response."""
        return payload['attributeScores']['TOXICITY']['summaryScore']['value']

    # ---------------------------------------------------------------- toxicity

    def eval_toxicity(self, textlist, sleep=False, return_response=False):
        """Return the mean Perspective TOXICITY score over ``textlist``.

        Args:
            textlist: sequence of strings, one API request each.
            sleep: pause 1 s after each request; forced on for more than 60 texts.
            return_response: also return the decoded API responses.

        Returns:
            ``avg_toxicity`` or ``(avg_toxicity, response_list)``. The average
            is NaN when ``textlist`` is empty.

        Raises:
            ValueError: a response body is not valid JSON.
            KeyError: a response carries no toxicity score (e.g. an API error body).
        """
        if len(textlist) > 60:
            sleep = True
        response_list = []
        total = 0.0
        for content in textlist:
            payload = json.loads(self._post_toxicity_request(content).content)
            if return_response:
                response_list.append(payload)
            total += self._toxicity_score(payload)
            if sleep:
                time.sleep(1)
        avg_toxicity = total / len(textlist) if len(textlist) > 0 else float('nan')
        if return_response:
            return avg_toxicity, response_list
        return avg_toxicity

    def eval_toxicity_fileorlist(self, infilepathorlist, outfilepath=None, sleep=False):
        """Score every generation of every example and return the overall mean toxicity.

        Failed or malformed responses are reported and skipped (best effort).
        Network errors raised by ``requests.post`` still propagate.

        Args:
            infilepathorlist: CSV path or list of generation lists (see class doc).
            outfilepath: optional CSV with columns ``id, toxicity, avg_toxicity``.
                Without it, each example's mean is printed instead.
            sleep: pause 0.2 s after each request; forced on for more than 60 examples.

        Returns:
            Mean toxicity over all successfully scored texts (NaN if none).
        """
        ids, inputlist = self._read_generations(infilepathorlist)
        if len(inputlist) > 60:
            sleep = True
        all_toxicity = []
        outfile, writer = self._open_csv_writer(outfilepath, ['id', 'toxicity', 'avg_toxicity'])
        try:
            for inputlist_i, textlist in enumerate(inputlist):
                example_scores = []
                for content in textlist:
                    response = self._post_toxicity_request(content)
                    try:
                        example_scores.append(self._toxicity_score(json.loads(response.content)))
                    except (ValueError, KeyError, TypeError) as err:
                        print('error exists, skipped: no. ', inputlist_i, ' testing example.', repr(err))
                    if sleep:
                        time.sleep(0.2)
                if writer is not None:
                    writer.writerow([ids[inputlist_i], json.dumps(example_scores),
                                     self._mean(example_scores)])
                else:
                    print(self._mean(example_scores))
                all_toxicity += example_scores
        finally:
            if outfile is not None:
                outfile.close()
        overall = self._mean(all_toxicity)
        print(overall)
        return overall

    # -------------------------------------------------------------- perplexity

    def _eval_perplexity(self, tokenizer_cls, model_cls, infilepathorlist, outfilepath,
                         model_type, eval_type, device):
        """Shared perplexity loop for the GPT-2 and DialoGPT entry points.

        Raises:
            ValueError: ``eval_type`` is not ``'complete'`` (the only supported mode).
        """
        if eval_type != 'complete':
            raise ValueError("only eval_type='complete' is supported, got %r" % (eval_type,))
        tokenizer = tokenizer_cls.from_pretrained(model_type)
        model = model_cls.from_pretrained(model_type)
        model.eval()
        if torch.cuda.is_available():
            model.to(device)
        device = model.device
        tokenizer.pad_token = tokenizer.eos_token
        ids, inputlist = self._read_generations(infilepathorlist)

        all_ppl = []
        outfile, writer = self._open_csv_writer(
            outfilepath, ['id', eval_type + '_ppl', eval_type + '_avg_ppl'])
        try:
            with torch.no_grad():
                for text_list_i, text_list in enumerate(inputlist):
                    example_ppl = []
                    for text in text_list:
                        tokens = tokenizer.encode(text.strip())
                        if not tokens:
                            # An empty text has no tokens to score.
                            print('empty text, skipped: no. ', text_list_i, ' testing example.')
                            continue
                        # HF causal-LM heads shift labels internally. The EOS prefix
                        # gives the first real token a preceding context, so it is scored.
                        if tokens[0] != tokenizer.eos_token_id:
                            tokens = [tokenizer.eos_token_id] + tokens
                        input_ids = torch.tensor(tokens).unsqueeze(0).to(device)
                        loss = model(input_ids, labels=input_ids, return_dict=True).loss
                        ppl = torch.exp(loss.detach().cpu()).item()
                        all_ppl.append(ppl)
                        example_ppl.append(ppl)
                    if writer is not None:
                        writer.writerow([ids[text_list_i], json.dumps(example_ppl),
                                         self._mean(example_ppl)])
                    else:
                        print(self._mean(example_ppl))
        finally:
            if outfile is not None:
                outfile.close()
        overall = self._mean(all_ppl)
        print(overall)
        return overall

    def eval_perplexity_gpt2(self, infilepathorlist, outfilepath=None, model_type='gpt2', eval_type='complete', device=0):
        """Return the mean GPT-2 perplexity over all generations.

        Args:
            infilepathorlist: CSV path or list of generation lists (see class doc).
            outfilepath: optional CSV with per-example perplexities and their mean.
                Without it, each example's mean is printed instead.
            model_type: checkpoint name for ``GPT2Tokenizer``/``GPT2LMHeadModel``.
            eval_type: must be ``'complete'``.
            device: device the model is moved to when CUDA is available.
        """
        return self._eval_perplexity(GPT2Tokenizer, GPT2LMHeadModel, infilepathorlist,
                                     outfilepath, model_type, eval_type, device)

    def eval_perplexity_dgpt(self, infilepathorlist, outfilepath=None, model_type='microsoft/DialoGPT-large',
                             eval_type='complete', device=0):
        """Return the mean DialoGPT perplexity over all generations.

        Same contract as ``eval_perplexity_gpt2``, but it loads ``model_type``
        through ``AutoTokenizer``/``AutoModelForCausalLM``.
        """
        return self._eval_perplexity(AutoTokenizer, AutoModelForCausalLM, infilepathorlist,
                                     outfilepath, model_type, eval_type, device)


class evaluation_stance():
    """Stance scoring of generated replies via the GATE cloud stance-classification API.

    Each (source utterance, generated reply) pair is sent as two JSON tweet
    objects, the reply marked as answering the source. The per-reply
    ``support``/``deny``/``comment``/``query`` scores are collected.
    """

    def __init__(self, api_key='api_key', api_password='api_psw'):
        """Store HTTP basic-auth credentials and the endpoint configuration."""
        self.api_key = api_key
        self.api_password = api_password
        self.url = 'https://cloud-api.gate.ac.uk/process/stance-classification'
        self.headers = {'Content-Type': 'text/plain'}

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _mean(values):
        """Return ``np.mean(values)``, or NaN (without numpy's empty-slice warning) if empty."""
        if len(values) == 0:
            return float('nan')
        return np.mean(values)

    @staticmethod
    def _read_sources(infilepathorlist):
        """Return ``(ids, sources)`` from a CSV path or a list of strings.

        From a CSV, each source is the last element of the JSON ``utt_list``
        column, and ids come from the optional ``id`` column (default 0).
        For list input every id is 0.

        Raises:
            TypeError: the input is neither a ``str`` path nor a ``list``.
        """
        if isinstance(infilepathorlist, str):
            ids, sources = [], []
            with open(infilepathorlist, 'r', newline='') as infile:
                for row in csv.DictReader(infile):
                    ids.append(row['id'] if 'id' in row else 0)
                    sources.append(json.loads(row['utt_list'])[-1])
            return ids, sources
        if not isinstance(infilepathorlist, list):
            raise TypeError('expected a CSV file path (str) or a list, got '
                            + type(infilepathorlist).__name__)
        return [0] * len(infilepathorlist), infilepathorlist

    @staticmethod
    def _read_generations(infilepathorlist):
        """Return one list of generated replies per example, from a CSV path or a list.

        Raises:
            KeyError: the CSV has neither a ``generated`` nor a ``generation`` column.
            TypeError: the input is neither a ``str`` path nor a ``list``.
        """
        if isinstance(infilepathorlist, str):
            generations = []
            with open(infilepathorlist, 'r', newline='') as infile:
                for row in csv.DictReader(infile):
                    if 'generated' in row:
                        raw = row['generated']
                    elif 'generation' in row:
                        raw = row['generation']
                    else:
                        raise KeyError("CSV has neither a 'generated' nor a 'generation' column: "
                                       + infilepathorlist)
                    generations.append(json.loads(raw))
            return generations
        if not isinstance(infilepathorlist, list):
            raise TypeError('expected a CSV file path (str) or a list, got '
                            + type(infilepathorlist).__name__)
        return infilepathorlist

    def _post_stance_request(self, text1, text2):
        """POST a (source, reply) pair; return the raw ``requests`` response."""
        data_dict_1 = {
            'text': text1,
            'id_str': "1"
        }
        data_dict_2 = {
            'text': text2,
            'id_str': "2",
            'in_reply_to_status_id_str': "1"
        }
        # The two JSON objects are sent back to back as one plain-text body.
        return requests.post(url=self.url, data=json.dumps(data_dict_1) + json.dumps(data_dict_2),
                             headers=self.headers, auth=HTTPBasicAuth(self.api_key, self.api_password))

    # ------------------------------------------------------------------ public

    def eval_stance_fileorlist(self, infilepathorlist1, infilepathorlist2, outfilepath=None, sleep=True):
        """Score the stance of each generated reply toward its source utterance.

        Args:
            infilepathorlist1: sources. Either a list of strings, or a CSV path whose
                ``utt_list`` column (JSON list) yields the last utterance per row.
            infilepathorlist2: replies. Either a list holding one list of strings per
                source, or a CSV path with a ``generated``/``generation`` column.
            outfilepath: optional CSV with per-example score lists and means. Its
                ``stance`` column holds the ``stance_class`` of the last successfully
                scored reply of that example ('' if none). Without it, each
                example's mean support score is printed instead.
            sleep: pause 0.5 s after every request.

        Returns:
            Mean ``support_score`` over all successfully scored replies (NaN if none).

        Raises:
            ValueError: the two inputs contain different numbers of examples.
        """
        ids, inputlist1 = self._read_sources(infilepathorlist1)
        inputlist2 = self._read_generations(infilepathorlist2)
        if len(inputlist1) != len(inputlist2):
            raise ValueError('got %d sources but %d generation lists'
                             % (len(inputlist1), len(inputlist2)))

        all_sup_stance = []
        outfile = open(outfilepath, 'w', newline='') if outfilepath is not None else None
        try:
            writer = None
            if outfile is not None:
                writer = csv.writer(outfile)
                writer.writerow(['id', 'stance',
                                 'sup_stance_score', 'deny_stance_score', 'comment_stance_score', 'query_stance_score',
                                 'avg_sup_score', 'avg_deny_score', 'avg_comment_score', 'avg_query_score'])
            for inputlist_i, text1 in enumerate(inputlist1):
                sup_scores, deny_scores, comment_scores, query_scores = [], [], [], []
                stance_class = ''
                for text2 in inputlist2[inputlist_i]:
                    response = self._post_stance_request(text1, text2)
                    try:
                        stance = response.json()['entities']['TweetStance'][0]
                        # Read all four scores before appending, so a missing key
                        # cannot leave the per-score lists misaligned.
                        scores = (stance['support_score'], stance['deny_score'],
                                  stance['comment_score'], stance['query_score'])
                    except (ValueError, KeyError, IndexError, TypeError) as err:
                        print(response.status_code, response.text)
                        print('error exists, skipped: no. ', inputlist_i, ' testing example.', repr(err))
                    else:
                        sup_scores.append(scores[0])
                        deny_scores.append(scores[1])
                        comment_scores.append(scores[2])
                        query_scores.append(scores[3])
                        stance_class = stance.get('stance_class', stance_class)
                    if sleep:
                        time.sleep(0.5)
                if writer is not None:
                    writer.writerow([ids[inputlist_i], stance_class,
                                     json.dumps(sup_scores), json.dumps(deny_scores),
                                     json.dumps(comment_scores), json.dumps(query_scores),
                                     self._mean(sup_scores), self._mean(deny_scores),
                                     self._mean(comment_scores), self._mean(query_scores)])
                else:
                    print(self._mean(sup_scores))
                all_sup_stance += sup_scores
        finally:
            if outfile is not None:
                outfile.close()
        overall = self._mean(all_sup_stance)
        print(overall)
        return overall
