import os
import pandas as pd
from tqdm import tqdm
import pickle
import tiktoken
import sys

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, classification_report

import openai
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

import argparse
from load_data import convert_data, combine_codes

# Token budget for one request; truncate_text subtracts PROMPT_LENGTH and the
# fixed text from it. (4096 presumably matches gpt-3.5-turbo's context size.)
MAX_TOKEN = 4096
# Tokens reserved for the prompt around the conversation (assumed estimate).
PROMPT_LENGTH = 70
# Tokenizer used to count tokens for the chat model queried below.
encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")

# Command-line configuration for this run.
parser = argparse.ArgumentParser()
parser.add_argument('--num_turns', type=int, default=4)
parser.add_argument('--output_label', type=str, default='positive')
# Boolean switches: each is False unless the flag is given.
for _flag in ('binarize_label', 'conv', 'code', 'feature', 'summary', 'objective'):
    parser.add_argument('--' + _flag, action='store_true')
del _flag
args = parser.parse_args()

# Module-level aliases read throughout the script.
NUM_TURNS, OUTPUT_LABEL = args.num_turns, args.output_label
IS_BINARY = args.binarize_label
USE_CONV, USE_CODE, USE_FEATURE = args.conv, args.code, args.feature
USE_SUMMARY, USE_OBJECTIVE = args.summary, args.objective

# Alternative checkpoint: "roberta-base"
checkpoint = "distilbert-base-uncased"
# Separator string used to join/split conversation turns; it depends on
# which checkpoint family is configured.
if 'distilbert' in checkpoint:
    SEP_TOKEN = ' <SEP> '
else:
    SEP_TOKEN = ' </s></s> '

# Prompt pieces sent with every chat request: a system message, and a
# question appended after the conversation text.
_system_msg = "You are a helpful assistant to help me understand the chat conversation between HelpSeeker and Counselor. Briefly answer questions about the conversation.\n\n"
_question = "Would the help seeker have felt more positive after the conversation? Answer '0' if they would not feel more positive at all, and answer '1' otherwise.\n"

# Load your API key from an environment variable or secret management service
openai.api_key = os.environ.get("OPENAI_API_KEY")

@retry(
    wait=wait_random_exponential(min=5, max=60),
    stop=stop_after_attempt(6),
    # Surface the real API error after the last attempt instead of tenacity's
    # opaque RetryError, so the caller's error message is informative.
    reraise=True,
)
def completion_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create(**kwargs)`` with retries.

    Retries up to 6 attempts in total, waiting a random exponential backoff
    of 5-60 seconds between attempts. If every attempt fails, the exception
    from the last attempt is re-raised unchanged.
    """
    return openai.ChatCompletion.create(**kwargs)

def truncate_text(text, fixed_text):
    """Drop turns from the middle of a conversation until it fits the token budget.

    The budget is ``MAX_TOKEN - PROMPT_LENGTH - tokens(fixed_text)``, with
    tokens counted by the module-level ``encoder``. While the conversation's
    token count is at or above the budget, two adjacent turns around the
    middle are removed. The pair starts one turn earlier when the middle turn
    begins with 'Counselor' (presumably so a HelpSeeker/Counselor exchange is
    dropped together — assumes turns alternate, TODO confirm).

    Args:
        text: conversation whose turns are joined by ``SEP_TOKEN``.
        fixed_text: text that accompanies the conversation in the prompt; its
            tokens are subtracted from the budget.

    Returns:
        The possibly shortened conversation, still joined by ``SEP_TOKEN``.
        Text that already fits is returned unchanged. At least one turn is
        always kept, so the result may still exceed the budget when only one
        or two (very long) turns remain.
    """
    budget = MAX_TOKEN - PROMPT_LENGTH - len(encoder.encode(fixed_text))
    turns = text.split(SEP_TOKEN)
    # Compare *tokens* (not the number of turns) with the budget. Stop at two
    # turns: dropping another pair then would empty the conversation.
    while len(turns) > 2 and len(encoder.encode(SEP_TOKEN.join(turns))) >= budget:
        discard_idx = len(turns) // 2
        if turns[discard_idx].startswith('Counselor'):
            discard_idx -= 1
        del turns[discard_idx:discard_idx + 2]
    return SEP_TOKEN.join(turns)

def process_text(conv, feature):
    """Build the text sent to the model for one conversation.

    When ``USE_CONV`` is off, the feature text is used on its own. Otherwise
    the speaker markers are rewritten as readable prefixes, the conversation
    is truncated to fit alongside ``feature``, turns are put on separate
    lines, and the feature is appended under a "Summary of the conversation"
    heading.
    """
    if not USE_CONV:
        return feature

    speaker_prefixes = {
        '___HelpSeeker___': 'HelpSeeker: ',
        '___Counselor___': 'Counselor: ',
    }
    for marker, prefix in speaker_prefixes.items():
        conv = conv.replace(marker, prefix)

    shortened = truncate_text(conv, feature)
    return shortened.replace(SEP_TOKEN, '\n') + '\nSummary of the conversation:\n' + feature

def identify_testset(df, train_percent, eval_percent):
    """Return the rows of ``df`` whose id belongs to any of the 10 test folds.

    Each fold splits the sorted ids with two ``train_test_split`` calls seeded
    0-9. The first call keeps a ``train_percent`` share for training. The
    second call takes ``1 - eval_percent`` of the remainder as that fold's
    test split. The union of all folds' test ids selects the returned rows.
    """
    candidate_ids = sorted(df['id'].tolist())
    union_of_tests = set()
    for seed in range(10):
        _, held_out = train_test_split(candidate_ids, test_size=1-train_percent, random_state=seed)
        _, fold_test = train_test_split(held_out, test_size=1-eval_percent, random_state=seed)
        union_of_tests.update(fold_test)
    return df[df['id'].isin(list(union_of_tests))]

def retrieve_data(train_percent, eval_percent):
    """Load every data source, build the model table and extract the test rows.

    All inputs are tab-separated files under ``<cwd>/data/``.

    Returns:
        A pair ``(df, test_frame)``. ``df`` is the full table produced by
        ``convert_data``. ``test_frame`` has columns ``id``, ``text`` (the
        prompt text from ``process_text``) and ``label`` for each row selected
        by ``identify_testset``.
    """
    data_dir = os.getcwd() + '/data/'

    def _read(name):
        # Every source file is a TSV in the data directory.
        return pd.read_csv(data_dir + name, sep='\t')

    df_transcripts = _read('extended_transcripts_allOutcome.tsv')
    df_codes = combine_codes(
        _read('strategy_indicator.tsv'),
        _read('extended_transcripts_allOutcome_strategy_indicator.tsv'),
    )
    df_features = _read('chatgpt_convToFeature.tsv')
    df_summary = _read('chatgpt_convToSummary.tsv')
    df_ObjSummary = _read('chatgpt_convToObjectiveSummary.tsv')

    full_df = convert_data(
        df_transcripts, df_codes, df_features, df_summary, df_ObjSummary,
        IS_BINARY, NUM_TURNS, OUTPUT_LABEL, SEP_TOKEN,
        USE_CONV, USE_CODE, USE_FEATURE, USE_SUMMARY, USE_OBJECTIVE,
    )
    test_rows = [row for _, row in identify_testset(full_df, train_percent, eval_percent).iterrows()]

    test_frame = pd.DataFrame(data={
        'id': [row['id'] for row in test_rows],
        'text': [process_text(row['text'], row['feature']) for row in test_rows],
        'label': [row['label'] for row in test_rows],
    })
    return full_df, test_frame

def cast_elem(elem):
    """Convert one model answer (or stored numeric answer) to an integer label.

    Conversion rules:
        1, 0, ... (integral numbers) -> returned as is
        1.0, 0.0, ...                -> int(elem)
        NaN / inf                    -> 1, logged (an empty answer read back
                                        from a TSV arrives as NaN)
        '1', ' 0\\n', ...             -> int(elem), after stripping whitespace
        '1.0', '0.5', ...            -> int(float(elem))
        any other string             -> 0 if it contains the character '0',
                                        otherwise 1 (heuristic, logged)

    Raises:
        TypeError: if ``elem`` is neither a number nor a string.
    """
    import math
    import numbers

    if isinstance(elem, numbers.Integral):
        return elem
    if isinstance(elem, numbers.Real):
        if math.isfinite(elem):
            return int(elem)
        # int(nan) would raise; treat a missing answer like an unparseable one.
        print(f'converting the answer {elem} to non-negative')
        return 1
    if isinstance(elem, str):
        text = elem.strip()
        if text.isdecimal():
            return int(text)
        # Exactly one decimal point (e.g. '1.0'); '1.0.0' would make float() raise.
        if text.count('.') == 1 and text.replace('.', '').isdecimal():
            return int(float(text))
        if '0' in text:
            print(f'converting the answer {elem} to Negative')
            return 0
        print(f'converting the answer {elem} to non-negative')
        return 1
    # Previously fell through and returned None, which only failed later inside f1_score.
    raise TypeError(f'cannot interpret answer {elem!r} of type {type(elem).__name__}')

def print_eval(original_df, answer_df, train_percent, eval_percent, model_name, IS_CUSTOM_MODEL):
    """Score the model answers fold by fold and write the report to a text file.

    The 10 test folds are rebuilt with the same two-stage ``train_test_split``
    calls (seeds 0-9) that ``identify_testset`` uses. For each fold, the macro
    F1 of the answers whose id falls in that fold is printed. This is followed
    by the overall macro F1 and a ``classification_report`` over all folds
    combined.

    Everything printed while the report is produced, including conversion
    messages from ``cast_elem``, goes to
    ``<cwd>/console_outputs/llm_outputs/<model_name>.txt``. The directory is
    created if needed. Standard output is restored afterwards.

    Args:
        original_df: full table; only its ``id`` column is used to rebuild folds.
        answer_df: table with ``id``, ``answer`` and ``label`` columns.
        train_percent: training share used for the first split.
        eval_percent: share of the held-out ids kept for evaluation (the rest is test).
        model_name: report file name and first line of the report.
        IS_CUSTOM_MODEL: selects how ``label`` values are mapped to 0/1 targets.
    """
    import contextlib

    fname = os.getcwd()+'/console_outputs/llm_outputs/'+model_name+'.txt'
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    all_ids = sorted(original_df['id'].tolist())
    preds, tgts = [], []
    # Redirect only while the report is written. Reassigning sys.stdout leaked
    # the file handle and sent every later print() in the script to the report.
    with open(fname, 'w', encoding='utf-8') as report, contextlib.redirect_stdout(report):
        print(f"\n{model_name}")
        for i in range(10):
            _, _eval_test = train_test_split(all_ids, test_size=1-train_percent, random_state=i)
            _, _test = train_test_split(_eval_test, test_size=1-eval_percent, random_state=i)
            _curr_fold_df = answer_df[answer_df['id'].isin(_test)]

            pred = [cast_elem(elem) for elem in _curr_fold_df['answer'].tolist()]

            tgt = _curr_fold_df['label'].tolist()
            if IS_CUSTOM_MODEL:
                # NOTE(review): assumes each label is indexable with a digit at
                # position 1 (e.g. a stringified sequence), and that a 1 there
                # means class 0 — confirm against convert_data.
                tgt = [0 if int(elem[1]) == 1 else 1 for elem in tgt]
            else:
                # Collapse labels to 0 vs. everything else.
                tgt = [0 if elem == 0 else 1 for elem in tgt]

            preds += pred
            tgts += tgt

            print('[Fold %d] Test f1: %.6f' % (i, f1_score(y_true=tgt, y_pred=pred, average='macro')))

        overall_f1 = f1_score(y_true=tgts, y_pred=preds, average='macro')
        print('[Overall] Test f1: %.6f' % (overall_f1))

        if IS_BINARY:
            print(classification_report(tgts, preds, target_names=['Negative', 'Neut and Posi']))
        else:
            # NOTE(review): targets above are always collapsed to {0, 1}, so three
            # target_names will likely make classification_report raise unless
            # the predictions contain a third class — confirm intent.
            print(classification_report(tgts, preds, target_names=['Not at all', 'A little', 'A lot']))

    return

# Tags that make up this run's model/output file name.
if IS_BINARY:
    binary_indicator_text = 'binary'
else:
    binary_indicator_text = 'multi'

# Concatenate the tag of every enabled input stream, in a fixed order.
input_stream_text = ''.join(
    tag
    for enabled, tag in (
        (USE_CONV, 'Conv'),
        (USE_CODE, 'Code'),
        (USE_FEATURE, 'Feat'),
        (USE_SUMMARY, 'Summary'),
        (USE_OBJECTIVE, "Objective"),
    )
    if enabled
)

# Conversation combined with generated feature/summary text; print_eval uses
# this flag to choose how labels are mapped to 0/1 targets.
IS_CUSTOM_MODEL = bool(USE_CONV and (USE_FEATURE or USE_SUMMARY))

train_percent, eval_percent = 0.6, 0.5
original_df, test_df = retrieve_data(train_percent, eval_percent)

model_name = f"{input_stream_text}-ChatGPT-{OUTPUT_LABEL}-{binary_indicator_text}"

# Query the chat model once per test conversation. Only the API loop is
# guarded: if it fails part-way, the answers collected so far are written to a
# *_partial.tsv file rather than lost. Saving and evaluation run only after a
# complete pass, so their errors are no longer mislabelled as a partial run.
results = {'id': [], 'label': [], 'answer': []}
try:
    for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
        message_list = [
            {"role": "system", "content": _system_msg},
            {"role": "user", "content": row['text'] + "\n\n" + _question},
        ]

        response = completion_with_backoff(model="gpt-3.5-turbo", messages=message_list)

        # Gather every value first so the three columns never get out of step
        # (unequal lengths would make the partial save below fail).
        answer = response['choices'][0]['message']['content']
        row_id, label = row['id'], row['label']
        results['answer'].append(answer)
        results['id'].append(row_id)
        results['label'].append(label)
except Exception as e:
    print(e)
    answer_df = pd.DataFrame(data=results)
    answer_df.to_csv(os.getcwd()+'/data/'+model_name+'_partial.tsv', sep='\t', index=False)
    print("Partial ChatGPT answers saved.")
else:
    answer_df = pd.DataFrame(data=results)
    answer_df.to_csv(os.getcwd()+'/data/'+model_name+'.tsv', sep='\t', index=False)
    # Evaluate what was written to disk (presumably so the dtypes match a
    # later re-evaluation from the saved file).
    answer_df = pd.read_csv(os.getcwd()+'/data/'+model_name+'.tsv', sep='\t')

    print_eval(original_df, answer_df, train_percent, eval_percent, model_name, IS_CUSTOM_MODEL)