import os
import json
import pickle
from tqdm import tqdm
import openai
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, util
import torch
import pandas as pd
from typing import List, Optional, Dict, Tuple, Union

from data_fns import featurise_data, prep_newspaper_data, prep_sotu_data, clean_entity


def check_special_tokens_exist(text, stoks):
    """
    Report whether every special token appears in ``text``.

    :param text: (str) text to search
    :param stoks: (dict) mapping of token name -> token string; only the values
        (the token strings) are searched for
    :returns: (bool) True when each token string occurs at least once in ``text``
        (vacuously True for an empty dict), False as soon as any one is missing
    """
    for token in stoks.values():
        if token not in text:
            return False
    return True
        
        
    

def get_completion_from_messages(
        text: str,
        model: str,
        openai_key: str,
        openai_prompt: str = None,
        openai_params: Dict = None
) -> Tuple[str, int]:
    """
    Send ``text`` to the OpenAI chat-completions endpoint and return the reply.

    The request contains an optional system message (``openai_prompt``) followed by a
    user message (``text``). Both the ``openai>=1.0`` client API and the deprecated
    pre-1.0 module-level API are supported.

    :param text: (str) user input sent to the API
    :param model: (str) name of the model to use (see "https://platform.openai.com/docs/models")
    :param openai_key: (str) OpenAI API key; it is also exported to
        ``os.environ["OPENAI_API_KEY"]`` as a side effect
    :param openai_prompt: (str) system prompt; when None, no system message is sent
    :param openai_params: (dict) optional request parameters. Recognised keys and the
        value used when a key is absent: ``temperature`` (0), ``top_p`` (0),
        ``frequency_penalty`` (0), ``presence_penalty`` (0), ``request_timeout`` (10).
        None is treated as an empty dict.
    :returns: tuple ((str, int)): response text from the API and total number of tokens used
    """
    # Previously a None default crashed on the first ``"key" in openai_params`` check.
    params = openai_params or {}

    messages = []
    if openai_prompt is not None:
        messages.append({"role": "system", "content": openai_prompt})
    messages.append({"role": "user", "content": text})

    # Sampling arguments shared by both API generations.
    # NOTE: max_tokens is deliberately not sent (it was commented out in earlier versions).
    sampling = {
        "temperature": params.get("temperature", 0),
        "top_p": params.get("top_p", 0),
        "frequency_penalty": params.get("frequency_penalty", 0),
        "presence_penalty": params.get("presence_penalty", 0),
    }
    timeout = params.get("request_timeout", 10)

    os.environ["OPENAI_API_KEY"] = openai_key

    # Compare the numeric major version rather than version strings lexicographically.
    try:
        uses_v1_client = int(str(openai.__version__).split(".", 1)[0]) >= 1
    except (AttributeError, ValueError):
        uses_v1_client = hasattr(openai, "OpenAI")

    if uses_v1_client:
        client = openai.OpenAI(api_key=openai_key, timeout=timeout)
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            **sampling,
        )
        return response.choices[0].message.content, response.usage.total_tokens

    # Deprecated (<1.0) module-level API: the timeout is a per-request argument.
    openai.api_key = openai_key
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        request_timeout=timeout,
        **sampling,
    )
    return response.choices[0].message["content"], response.usage["total_tokens"]



def clean_ocr_errors_gpt(article, stoks,
                         prompt="Return the entire text taken from an old newspaper article by cleaning OCR errors and artifacts. Return only the article text as a string. Don't add any new words or context. Simply act like a spellchecker - but do clean up proper nouns as needed.  Retain special tokens like '</s>', '[M]', '[/M]' - don't clean those up.",
                         retries=0,
                         model="gpt-4-turbo-2024-04-09"):
    """
    Ask an OpenAI chat model to clean OCR errors in ``article`` while keeping special tokens.

    Before sending, whitespace in the article is normalised: newlines and runs of
    whitespace are collapsed to single spaces, and the ends are stripped. A completion
    is accepted only when every token string in ``stoks`` still appears in it.
    Otherwise the request is repeated. At most ``2 - retries`` requests are made.

    :param article: (str) raw article text
    :param stoks: (dict) special tokens that must survive cleaning (the values are checked)
    :param prompt: (str) system prompt sent to the model
    :param retries: (int) number of attempts already used up; kept for backward
        compatibility with the former recursive implementation
    :param model: (str) OpenAI model name to use
    :returns: (str) the cleaned article text. If no attempt produced a completion
        containing all special tokens, the input ``article`` is returned unchanged.
    """
    max_attempts = 2
    attempts_left = max_attempts - retries

    # str.split() with no argument splits on any whitespace (including newlines) and
    # drops leading/trailing whitespace, so this collapses and strips in one pass.
    normalised = " ".join(article.split())

    for attempt in range(attempts_left):
        completion, _tokens = get_completion_from_messages(
            text=normalised,
            model=model,
            openai_key=os.environ["OPENAI_API_KEY"],
            openai_prompt=prompt,
            openai_params={
                "temperature": 0,
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "request_timeout": 100,
            }
        )

        if check_special_tokens_exist(completion, stoks):
            print("Final completion: ", completion)
            return completion

        is_last_attempt = attempt + 1 >= attempts_left
        print("Special tokens not found in the completion."
              + ("" if is_last_attempt else " Re-prompting"))
        print("Original article: ", normalised)
        print("Original completion: ", completion)

    # Previously this path returned the literal string "Original", which then replaced
    # the article text downstream; return the article itself as the message promises.
    print("Exceeded maximum retries. Returning the original article")
    return article

if __name__ == '__main__':

    # Alternative checkpoints that have been used with this script:
    # trained_model_path = '/mnt/data01/entity/trained_models/cgis_model_ent_mark_incontext_90' # not finetuned
    # trained_model_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/entity_split_newspaper_wiki_coref_disamb_more_incontext' # finetuned
    trained_model_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext' # coref model
    # trained_model_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/asymmetric_disambiguation_full_100' # assym wiki only
    # trained_model_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/asymm_newspapers_0.3530040607711171_64_5_0.9285210198462236/' # assym wiki + news

    stoks = {'men_start': "[M]", 'men_end': "[/M]", "sep": '</s>'}

    # Index of the first article to clean in this run.
    # NOTE(review): earlier articles were presumably cleaned by a previous run (hence the
    # "_intermediate_2_" output name) — TODO confirm, and update when resuming.
    start_index = 628
    intermediate_output_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_disamb_format_wp_ids_cleaned_intermediate_2_gpt4.json'

    model = SentenceTransformer(trained_model_path)
    ds = prep_sotu_data(
            dataset_path='/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_reformatted_wp_ids.json',
            model=model,
            special_tokens=stoks,
            featurisation='ent_mark',
            disamb_or_coref='disamb',
            date_featurisation='prepend_1',
            override_max_seq_length=256,
            keep_entity_types=['PER'],
        )

    text_list = [d['text'] for d in ds]

    ##Clean up the OCR errors
    cleaned_text_list = []
    for text in tqdm(text_list[start_index:]):
        cleaned_text_list.append(clean_ocr_errors_gpt(text, stoks))
        # Rewrite the whole checkpoint after every article so that an interrupted run
        # (API error, timeout) keeps everything cleaned so far.
        with open(intermediate_output_path, 'w', encoding='utf-8') as f:
            json.dump(cleaned_text_list, f)

    # ##Replace the text in the dataset (cleaned_text_list[0] corresponds to ds[start_index])
    # for i, cleaned in enumerate(cleaned_text_list, start=start_index):
    #     ds[i]['text'] = cleaned

    # ##Save the cleaned dataset
    # with open('/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_disamb_format_wp_ids_cleaned.json', 'w', encoding='utf-8') as f:
    #     json.dump(ds, f)