import argparse
import json
import random
import sys
import logging
import numpy as np

import configparser
import openai

from typing import Any, List, Dict
from tqdm import tqdm
from nltk.tokenize.treebank import TreebankWordDetokenizer
from nltk import word_tokenize
from nltk.corpus import stopwords

from negative_generation.utils import set_logger, read_dataset, save_results

# Command-line interface of the script. The parser is built at import time so
# that main() can simply call parser.parse_args().
parser = argparse.ArgumentParser(description="Process arguments for generating negatives by gpt3")
parser.add_argument("--example-dataset-path", type=str, default="./data/dailydialog++/train/train.json", help="dataset path for sampling examples in prompts")
parser.add_argument("--target-dataset-path", type=str, default="./data/dailydialog/train/train_augmentation.json", help="dataset path for sampling targets in prompts")
parser.add_argument("--result-path", type=str, default="./data/dailydialog/train/train_augmented_20000.json", help="result dataset path for saving")
# Only the "dts" (direct task specification) and "meta" prompt builders exist
# in this file, so any other value is rejected at parse time instead of
# silently producing an empty prompt later on.
parser.add_argument("--type", type=str, default="dts", choices=["dts", "meta"], help="task instruction type in prompts for generating negatives")

# Arguments for ablation studies
parser.add_argument("--num-of-example", type=int, default=2, help="set number of examples for constructing prompts (e.g. 2 means 2-shot)")
# argparse applies `type` only to string defaults, so the default is spelled as a float.
parser.add_argument("--sampling-ratio", type=float, default=1.0, help="sampling ratio of example dataset")
parser.add_argument("--is-reuse", action="store_true", help="if this argument is given, reuse the generated sample")
parser.add_argument("--is-pos", action="store_true", help="if this argument is given, use the positive response for construction prompts")

# Arguments for test run
parser.add_argument("--is-test", action="store_true", help="if this argument is given, run the test mode (test generation)")


def constructing_meta_prompts(target_context: Dict[str, Any]) -> str:
    """
    Constructing prompt corresponding to meta-prompt type

    The prompt lists the dialogue turns with alternating ``A:`` / ``B:``
    prefixes (A always speaks first), followed by a narration announcing an
    irrelevant reply and an open ``<speaker>:`` line for the model to complete.

    :param target_context: target context to generate negative responses;
        only its ``"context"`` entry (the list of utterances) is read
    :return: meta-type prompt
    """
    context = target_context["context"]

    pieces = ['Now A and B are talking.\n']
    for turn, utt in enumerate(context):
        speaker = "A" if turn % 2 == 0 else "B"
        pieces.append(f"{speaker}: {utt}\n")

    # The speaker who did not utter the last turn answers next; with an empty
    # context this is A, who would open the dialogue.
    next_speaker = "A" if len(context) % 2 == 0 else "B"

    # Adjacent literals are used instead of a backslash continuation so that
    # the source indentation does not leak into the prompt text.
    pieces.append(
        f"Suddenly, {next_speaker} makes an awkward response. "
        "The response appears to be okay at first glance, but it's irrelevant to the conversation.\n"
        f"{next_speaker}:"
    )

    return "".join(pieces)


def _format_dialogue_turns(context: List[str]) -> str:
    """Render utterances as alternating ``A:`` / ``B:`` lines, A speaking first."""
    return "".join(
        f"{'A' if turn % 2 == 0 else 'B'}: {utt}\n"
        for turn, utt in enumerate(context)
    )


def constructing_dts_prompts(
    example_list: List[Dict[str, Any]],
    target_context: Dict[str, Any],
    is_pos: bool
) -> str:
    """
    Constructing prompt corresponding to direct task specification type

    Every example contributes a ``###`` section with its dialogue, optionally
    one randomly chosen positive response, the task instruction and its
    numbered adversarial negatives. The target section follows the same
    layout but ends open: with ``1.`` when examples are given (few-shot), or
    with ``<next speaker>:`` when ``example_list`` is empty (zero-shot).

    :param example_list: dialogue examples to construct prompt; each needs
        ``"context"`` and ``"adversarial_negative_responses"`` (plus
        ``"positive_responses"`` when ``is_pos`` is set)
    :param target_context: target context to generate negative responses
    :param is_pos: whether to contain the positive response
    :return: direct task specification type prompt
    :raises IndexError: if ``is_pos`` is set and a ``"positive_responses"``
        list is empty
    """
    section_header = '###\nDialogue context:\n"""\n'
    few_shot_instruction = "Create five irrelevant responses containing keywords of the given dialogue context:\n"

    pieces = []
    for sample in example_list:
        pieces.append(section_header)
        pieces.append(_format_dialogue_turns(sample["context"]))
        pieces.append('"""\n')
        if is_pos:
            # random.choice works for any number of positive responses
            # instead of assuming there are exactly five of them.
            pieces.append("Relevant response: " + random.choice(sample["positive_responses"]) + "\n")
        pieces.append(few_shot_instruction)
        # numbering start from 1.
        pieces.extend(
            f"{number}. {neg}\n"
            for number, neg in enumerate(sample["adversarial_negative_responses"], start=1)
        )
        pieces.append("\n")

    target_turns = target_context["context"]
    pieces.append(section_header)
    pieces.append(_format_dialogue_turns(target_turns))
    pieces.append('"""')
    if is_pos:
        pieces.append("\nRelevant response: " + random.choice(target_context["positive_responses"]))

    if example_list:
        # Few-shot: open the numbered list so the model continues it.
        pieces.append("\n" + few_shot_instruction + "1.")
    else:
        # Zero-shot: ask for a single reply of the speaker whose turn is next
        # (A when the context is empty, since A opens the dialogue).
        next_speaker = "A" if len(target_turns) % 2 == 0 else "B"
        pieces.append(
            f"\nCreate an {next_speaker}'s irrelevant response containing keywords of the given dialogue context:\n{next_speaker}:"
        )

    return "".join(pieces)


def generating_negatives_by_inferencing_gpt3(
    type: str,
    prompt: str,
    stop_seq: List[str],
    max_tokens: int
):
    """
    Query the GPT-3 completion endpoint once and return the generated text

    Credentials are loaded from the ``OPENAI_INFO`` section of ``config.ini``
    (resolved against the current working directory) on every call.

    :param type: prompt type
    :param prompt: target prompt
    :param stop_seq: stop sequence token list
    :param max_tokens: upper bound on the number of generated tokens
    :return: the raw completion text, or - for the "dts" type with "###"
        among the stop sequences - the completion prefixed with "1." and
        split into a list of lines
    """
    settings = configparser.ConfigParser()
    settings.read('config.ini')
    credentials = settings['OPENAI_INFO']

    openai.organization = credentials['OPENAI_ORGANIZATION']
    openai.api_key = credentials['OPENAI_API_KEY']

    request_options = {
        "engine": "davinci",
        "prompt": prompt,
        "temperature": 0.8,
        "max_tokens": max_tokens,
        "top_p": 1,
        "frequency_penalty": 0.4,
        "presence_penalty": 0.4,
        "stop": stop_seq,
    }
    completion = openai.Completion.create(**request_options)
    completion_text = completion["choices"][0]["text"]

    if type != "dts" or "###" not in stop_seq:
        return completion_text

    # The few-shot prompt built in this file ends with "1.", so the marker is
    # put back in front of the completion before it is cut into lines.
    return ("1." + completion_text).split("\n")


def _strip_matching_quotes(text: str) -> str:
    """Drop one pair of enclosing double or single quotes; safe on empty input."""
    if text and text[0] == text[-1] and text[0] in ('"', "'"):
        return text[1:-1]
    return text


def _is_too_generic(text: str, stop_words: set) -> bool:
    """Tell whether *text* has no token outside *stop_words* (also true without tokens)."""
    return all(token.lower() in stop_words for token in word_tokenize(text))


def _collect_few_shot_negatives(generated_lines: List[str], stop_words: set) -> List[str]:
    """Refine the numbered lines of a few-shot completion, stopping at the first invalid line."""
    accepted = []
    for number, line in enumerate(generated_lines, start=1):
        # filtering incorrectly generated negatives
        if "A." in line or "B." in line or "__" in line or len(line) < 5:
            break

        # refining generated negatives
        pivot = f"{number}. "
        is_numbered = pivot in line
        if is_numbered:
            line = line.replace(pivot, "").strip()
        line = _strip_matching_quotes(line)

        # filtering too generic negatives; a line that became empty during
        # refinement has no tokens and is rejected here as well
        if _is_too_generic(line, stop_words):
            break

        # lines without their expected "<n>. " marker are skipped, not fatal
        if is_numbered:
            accepted.append(line)
    return accepted


def _clean_zero_shot_negative(generated_text: str, stop_words: set) -> str:
    """Refine a single-response completion; return "" when it has to be rejected."""
    candidate = generated_text.strip()
    if "__" in candidate or len(candidate) < 5:
        return ""
    if _is_too_generic(candidate, stop_words):
        return ""

    if "A:" in candidate:
        candidate = candidate.replace("A:", "")
    elif "B:" in candidate:
        candidate = candidate.replace("B:", "")

    return _strip_matching_quotes(candidate)


def multi_negatives_generation(
    logger: logging.Logger,
    type: str,
    example_dataset: List[Any],
    target_dataset: List[Any],
    result_path: str,
    num_of_example: int,
    sampling_ratio: float,
    is_pos: bool,
    is_reuse: bool,
    stop_word_list: List[str]
):
    """
    Set of negative responses generation pipeline

    For every target dialogue a prompt is built, GPT-3 is queried repeatedly
    until five distinct negatives pass the filters, and the result is stored.
    Results are written to ``result_path`` after every 500 targets and once
    more at the end. There is no upper bound on the number of retries per
    target; API errors are logged and the request is repeated.

    :param logger: logger object
    :param type: prompt type ("dts" or "meta")
    :param example_dataset: dataset for sampling examples in prompts
    :param target_dataset: dataset for sampling targets in prompts
    :param result_path: result dataset path for saving
    :param num_of_example: set number of examples for constructing prompts (e.g. 2 means 2-shot)
    :param sampling_ratio: sampling ratio of example dataset
    :param is_pos: whether to use the positive response for construction prompts
    :param is_reuse: whether to reuse the generated sample
    :param stop_word_list: stop word list for filtering
    :raises ValueError: if ``type`` is not supported, or if examples are
        requested but the sampled example pool has fewer than two entries
    """
    # An unknown type would leave the prompt empty and never satisfy a
    # stopping criterion below, i.e. it would query the API forever.
    if type not in ("dts", "meta"):
        raise ValueError(f"unsupported prompt type: {type!r} (expected 'dts' or 'meta')")

    # Slicing copies the list, so appending reused samples below does not
    # modify the caller's example_dataset.
    sample_set = example_dataset[: int(len(example_dataset) * sampling_ratio) + 1]

    # The sampling loop below rejects draws containing the target index; with
    # a single candidate (index 0) it could never terminate for target 0.
    if target_dataset and num_of_example > 0 and len(sample_set) < 2:
        raise ValueError(
            "at least two sampled examples are required to draw few-shot examples "
            f"(got {len(sample_set)})"
        )

    # Only the "dts" type with at least one example produces a numbered list;
    # every other setting generates one response per request.
    is_few_shot = type == "dts" and num_of_example != 0
    if is_few_shot:
        stop_seqs = ["###"]
        max_tokens = 200
    else:
        # arguments setup for zero-shot setting
        stop_seqs = ["\n", "A:", "B:"]
        max_tokens = 50

    stop_words = set(stop_word_list)
    gpt3_neg_data = []
    attempt = 0

    effective_ratio = len(sample_set) / len(example_dataset) if example_dataset else 0.0
    logger.info(f"total_target: {len(target_dataset)}")
    logger.info(f"sampling_ratio: {effective_ratio}")

    for i in tqdm(range(len(target_dataset))):
        # excluding target context, randomly select N few-shot (num_of_example)
        while True:
            rand_idx = np.random.choice(len(sample_set), num_of_example)
            if i not in rand_idx:
                break

        sample_list = [sample_set[j] for j in rand_idx]
        target_context = target_dataset[i]

        if type == "dts":
            # constructing direct task specification type prompt
            prompts = constructing_dts_prompts(sample_list, target_context, is_pos=is_pos)
        else:
            # constructing meta-type prompt
            prompts = constructing_meta_prompts(target_context)

        target_negs = []
        while True:
            attempt += 1
            try:
                generated_negatives = generating_negatives_by_inferencing_gpt3(type, prompts, stop_seqs, max_tokens)
            except openai.error.APIError:
                logger.info(f"error prompts: {prompts}")
                continue

            if is_few_shot:
                # [Few-shot setting] one completion must yield all five negatives
                candidates = _collect_few_shot_negatives(generated_negatives, stop_words)
                if len(candidates) == 5 and len(set(candidates)) == 5:
                    target_negs = candidates
                    break
            else:
                # [Zero-shot setting] negatives are accumulated one request at a time
                neg = _clean_zero_shot_negative(generated_negatives, stop_words)
                if not neg or neg in target_negs:
                    continue

                target_negs.append(neg)
                if len(target_negs) == 5:
                    break

        data = {
            "id": str(i),
            "context": target_context["context"],
            "positive_responses": target_context["positive_responses"],
            "gpt3_negative_responses": target_negs,
            "random_negative_responses": target_context["random_negative_responses"],
            "sample_idxs": [sample["id"] for sample in sample_list]
        }

        if is_reuse:
            # The freshly generated negatives become a candidate example for
            # the prompts of the following targets.
            sample_set.append({
                "id": str(i),
                "context": target_context["context"],
                "positive_responses": target_context["positive_responses"],
                "adversarial_negative_responses": target_negs,
                "random_negative_responses": target_context["random_negative_responses"]
            })

        gpt3_neg_data.append(data)

        if len(gpt3_neg_data) % 500 == 0:
            logger.info(f"saving intermediate results with {len(gpt3_neg_data)} lengths")
            save_results(result_path, gpt3_neg_data)

    logger.info(f"total number of inferences: {attempt}")
    save_results(result_path, gpt3_neg_data)


def test_negative_generation(
    logger: logging.Logger,
    type: str,
    example_dataset: List[Any],
    target_dataset: List[Any],
    num_of_example: int,
    is_pos: bool,
    target_idx: int
):
    """
    Build one prompt for a single target, query GPT-3 once and log the outcome

    Nothing is filtered, retried or saved; the prompt and the raw generation
    result are only written to the logger.

    :param logger: logger object
    :param type: prompt type
    :param example_dataset: dataset for sampling examples in prompts
    :param target_dataset: dataset for sampling targets in prompts
    :param num_of_example: set number of examples for constructing prompts (e.g. 2 means 2-shot)
    :param is_pos: whether to use the positive response for construction prompts
    :param target_idx: target index of target dataset
    """
    chosen = np.random.choice(len(example_dataset), num_of_example)
    few_shot_examples = [example_dataset[k] for k in list(chosen)]
    target_context = target_dataset[target_idx]

    # Fallbacks that stay in effect when the type matches neither branch below.
    prompt: str = ""
    stop_seqs: List[str] = []
    max_tokens: int = 50

    single_response_stops = ["\n", "A:", "B:"]

    if type == "dts":
        zero_shot = num_of_example == 0
        stop_seqs = single_response_stops if zero_shot else ["###"]
        max_tokens = 50 if zero_shot else 200
        prompt = constructing_dts_prompts(few_shot_examples, target_context, is_pos=is_pos)
    elif type == "meta":
        stop_seqs = single_response_stops
        max_tokens = 50
        prompt = constructing_meta_prompts(target_context)

    logger.info(f"test prompt is\n {prompt}")

    responses = generating_negatives_by_inferencing_gpt3(type, prompt, stop_seqs, max_tokens)

    logger.info(f"generated responses are\n{responses}")


def main():
    """
    Entry point: parse the CLI arguments and run test or full generation mode

    The target dataset is restricted to the hard-coded slice ``[10000:20000]``
    of the file given by ``--target-dataset-path``; test mode always uses the
    entry at index 20 of that slice.
    """
    args = parser.parse_args()

    # set logger
    logger = set_logger("generating-negatives")

    example_dataset = read_dataset(args.example_dataset_path)
    target_dataset = read_dataset(args.target_dataset_path)[10000:20000]

    if args.is_test:
        # test mode
        test_negative_generation(
            logger,
            args.type,
            example_dataset,
            target_dataset,
            args.num_of_example,
            args.is_pos,
            target_idx=20
        )
    else:
        # The stop words are only needed by the filters of the full pipeline.
        stop_word_list = stopwords.words('english') + ['.', ',', '!', '?', '"', "'", 'yes', 'no']

        # Keyword arguments keep every value bound to the intended parameter;
        # result_path precedes num_of_example in the callee's signature, which
        # is easy to get wrong positionally.
        multi_negatives_generation(
            logger=logger,
            type=args.type,
            example_dataset=example_dataset,
            target_dataset=target_dataset,
            result_path=args.result_path,
            num_of_example=args.num_of_example,
            sampling_ratio=args.sampling_ratio,
            is_pos=args.is_pos,
            is_reuse=args.is_reuse,
            stop_word_list=stop_word_list
        )


# Run the CLI only when the file is executed as a script. main() has no
# return statement, so sys.exit receives None and the process exits with
# status 0 unless an exception propagates.
if __name__ == "__main__":
    sys.exit(main())