import argparse
import json
import os
import re
from datetime import datetime

import numpy as np
import pandas as pd
from openai import OpenAI
from tqdm import tqdm
# Command-line options. NOTE(review): ``args.exp_id`` is parsed here but is not
# read anywhere else in the visible part of this file.
parser = argparse.ArgumentParser()
parser.add_argument('--exp_id', type=str, default='0')
args = parser.parse_args()

# Take the API key from the environment instead of a literal in the source.
# When OPENAI_API_KEY is unset this falls back to the previous empty string,
# so construction behaves exactly as before.
client = OpenAI(
  api_key=os.environ.get("OPENAI_API_KEY", ""),
)

# Raw event table; it is sorted and filtered by date at the bottom of the file.
summary_merged_event_df = pd.read_csv('./data/future_data.csv')

def format_date(date):
    """Convert a ``YYYYMMDDHHMMSS`` timestamp into a ``YYYY-MM-DD`` string.

    The value is passed through ``int`` first, so numeric strings and floats
    with a zero fractional part are accepted alongside plain integers.
    """
    compact_stamp = str(int(date))
    parsed = datetime.strptime(compact_stamp, '%Y%m%d%H%M%S')
    return parsed.strftime('%Y-%m-%d')


def determine_relationship(score):
    """Map an average relationship score onto a categorical label.

    Missing scores are "Neutral"; a score of 0.75 or more is "Conflict";
    a score of 0.25 or less is "Cooperation"; anything in between is
    "Neutral".
    """
    if pd.isna(score):
        return "Neutral"

    value = float(score)
    if value >= 0.75:
        return "Conflict"
    if value <= 0.25:
        return "Cooperation"
    return "Neutral"


def get_unique_recent_summaries(df, current_index, lookback=15):
    """Collect the most recent rows with distinct summaries before a position.

    Walks backwards from the row just before ``current_index`` and keeps the
    first row seen for each distinct ``Summary`` value, stopping once
    ``lookback`` rows have been collected.

    Args:
        df: DataFrame with a ``Summary`` column.
        current_index: Positional (0-based) index of the target row; only
            rows strictly before it are considered.
        lookback: Number of distinct summaries to collect.

    Returns:
        DataFrame of the selected rows in their original (oldest-first)
        order. It is empty when no earlier rows exist.
    """
    # Use positional access throughout. The previous mix of label-based
    # ``df.at`` and positional ``df.iloc`` only agreed on a default RangeIndex
    # and failed (or read the wrong row) on any other index.
    summaries = df['Summary']
    selected_rows = []
    seen_summaries = set()

    for pos in range(current_index - 1, -1, -1):
        summary = summaries.iat[pos]
        if summary not in seen_summaries:
            selected_rows.append(df.iloc[pos])
            seen_summaries.add(summary)
        if len(selected_rows) == lookback:
            break

    if len(selected_rows) < lookback:
        print(f"Warning: Only found {len(selected_rows)} unique summaries for index {current_index}")

    # Rows were gathered newest-first; flip them back to chronological order.
    return pd.DataFrame(selected_rows[::-1])


def format_prompt(X, Y):
    """Build the forecasting prompt for one target country pair.

    Args:
        X: DataFrame of context rows. The columns read are ``Summary``,
            ``RE_Evidence_for_prediction``, ``CE_Final_Important_Countries``,
            ``CE_DATEADDED``, ``GlobalEventID``, ``RE_country1``,
            ``RE_country2`` and ``RE_average_score``. A frame without a
            ``Summary`` column (e.g. ``pd.DataFrame([])``) is treated as
            "no context available".
        Y: Row (mapping-like) for the target, providing ``RE_country1``,
            ``RE_country2`` and ``CE_DATEADDED``.

    Returns:
        The full prompt text. Each distinct summary in ``X`` contributes one
        event entry unless every one of its rows is marked ``'No'`` in
        ``RE_Evidence_for_prediction``; entries are joined with ``" [SEP] "``.
    """
    # An empty context frame built from an empty record list has no columns
    # at all; indexing X['Summary'] would raise KeyError in that case.
    summaries = X['Summary'].unique() if 'Summary' in X.columns else []

    events_list = []
    for summary in summaries:
        related_rows = X[X['Summary'] == summary]
        evidence = related_rows['RE_Evidence_for_prediction']
        # Skip news items for which no row counts as evidence.
        if (evidence == 'No').all():
            continue

        first_row = related_rows.iloc[0]
        important_countries = first_row['CE_Final_Important_Countries']
        date_added = format_date(first_row['CE_DATEADDED'])

        # Only rows explicitly marked 'Yes' contribute a pair relationship.
        evidence_rows = related_rows[evidence == 'Yes']
        relationships = [
            f"{row['RE_country1']}-{row['RE_country2']}: {determine_relationship(row['RE_average_score'])}"
            for _, row in evidence_rows.iterrows()
        ]
        relationships_str = ', '.join(relationships)

        event = f"[NEWS ID] {first_row['GlobalEventID']} [DATE] {date_added} [News] {summary} [Important Countries in News] {important_countries} [Relationships of Pairs of Important Countries in News] {relationships_str}"
        events_list.append(event)

    events = " [SEP] ".join(events_list)

    # NOTE: the "Follow this format" lines below are plain (non-f) strings, so
    # their doubled braces are sent to the model literally as "{{ ... }}".
    prompt = (
        f"Question: Predict the nature of the relationship on the specified date between the two countries listed below.\n"
        f"Country 1: {Y['RE_country1']}\n"
        f"Country 2: {Y['RE_country2']}\n"
        f"Date: {format_date(Y['CE_DATEADDED'])}\n"
        "\nQuestion Background: "
        f"\nWe have retrieved the following information (summarized news articles) for this question:\n{events}\n"
        f"Also, historical and recent interactions, combined with the prior probability distribution of 0.38:0.24:0.38 for conflict, cooperation, and neutral respectively, should guide your prediction.\n"
        "\nResolution Criteria: The prediction must choose one and only one of the following exact classifications for the relationship between the two countries: 'Conflict', 'Cooperation', or 'Neutral'. "
        "\nInstructions:\n"
        "1. Considering the provided context summaries, discuss whether there are reasons to predict a conflict. If so, what are the reasons and which NewsIDs are considered?\n"
        "2. Considering the provided context summaries, discuss whether there are reasons to predict cooperation. If so, what are the reasons and which NewsIDs are considered?\n"
        "3. Considering the provided context summaries, discuss whether there are reasons to predict neutral. If so, what are the reasons and which NewsIDs are considered?\n"
        "4. Using only the article content, assign probabilities to each possible relationship outcome (conflict, cooperation, neutral) so that the total equals 1 (e.g., conflict: 0.7, cooperate: 0.1, neutral: 0.2).\n"
        "5. Now combine article content, prior knowledge, and your background general knowledge to assign probabilities as in instruction 4.\n"
        "6. Check whether you properly consider the retrieved information (i.e., Question Background).\n"
        "7. Finally, choose one answer. Select from 'conflict', 'cooperation', or 'neutral'. If conflict, answer *0*; if cooperation, answer *1*; if neutral, answer *2*. Your answer must include asterisks around the number; any other form will be incorrect.\n"
        "8. If you predicted 'conflict' or 'cooperation' in step 7, briefly describe a potential specific event that might occur between the two countries on the given date, considering the previous documents and the future date. The description can be somewhat speculative. Use one sentence.\n"

        "\nFollow this format exactly to ensure proper parsing and then answer:\n"
        "1. Conflict Reasons: {{ Insert your thoughts }}\n"
        "2. Cooperation Reasons: {{ Insert your thoughts }}\n"
        "3. Neutral Reasons: {{ Insert your thoughts }}\n"
        "4. Article Content Probabilities: {{ Insert your answer }}\n"
        "5. Combined Probabilities: {{ Insert your answer }}\n"
        "6. Consideration Check: {{ Insert your thoughts }}\n"
        "7. Final Prediction: {{ Insert your answer }}\n"
        "8. Event Description: {{ Insert your answer }}"
    )
    return prompt


def prepare_unique_summaries(df, lookback=15):
    """Pre-compute the context rows used to predict each target row.

    For every row position from 15 onwards, walk backwards through earlier
    rows and, for each not-yet-seen ``Summary`` that belongs to a different
    ``GlobalEventID`` than the target row, collect all rows of that event.
    The walk stops once ``lookback`` distinct summaries have been gathered.

    Args:
        df: DataFrame with ``Summary`` and ``GlobalEventID`` columns. Rows are
            addressed with ``df.at[i, ...]`` for ``i`` in ``range(len(df))``,
            so the index must be the default 0..n-1 range.
        lookback: Number of distinct summaries to gather per target row.

    Returns:
        dict mapping each row position (15 .. len(df) - 1) to a DataFrame of
        its context rows, most recent event first. The frame always carries
        ``df``'s columns, even when no context rows were found.
    """
    # Group each event's rows once up front instead of re-filtering the whole
    # frame for every candidate row inside the nested loop.
    rows_by_event = {
        event_id: event_rows.to_dict('records')
        for event_id, event_rows in df.groupby('GlobalEventID', sort=False)
    }
    unique_summaries_map = {}

    for current_index in tqdm(range(15, len(df)), desc="Preparing unique summaries"):
        current_event_id = df.at[current_index, 'GlobalEventID']
        unique_summaries = []
        seen_summaries = set()

        for i in range(current_index - 1, -1, -1):
            global_event_id = df.at[i, 'GlobalEventID']
            # Rows of the target's own event are never used as context.
            if global_event_id != current_event_id:
                summary = df.at[i, 'Summary']
                if summary not in seen_summaries:
                    unique_summaries.extend(rows_by_event.get(global_event_id, []))
                    seen_summaries.add(summary)
            if len(seen_summaries) >= lookback:
                break

        if len(seen_summaries) < lookback:
            print(f"Warning: Only found {len(seen_summaries)} unique summaries for index {current_index}")

        # Passing the columns explicitly keeps the schema intact when the
        # record list is empty, so later column lookups do not raise KeyError.
        unique_summaries_map[current_index] = pd.DataFrame(unique_summaries, columns=df.columns)

    return unique_summaries_map


def parse_final_prediction(text):
    """Extract the last ``*<digits>*`` marker from a model response.

    Returns the integer between the asterisks of the final match, or ``-1``
    when the text contains no such marker.
    """
    last_match = None
    for last_match in re.finditer(r'\*(\d+)\*', text):
        pass

    if last_match is None:
        return -1
    return int(last_match.group(1))

def save_intermediate_results(predictions_dict, responses_dict, filename_prefix, mid):
    """Write predictions and raw responses to a JSON file.

    Args:
        predictions_dict: Mapping of id -> predicted label.
        responses_dict: Mapping of id -> full response text.
        filename_prefix: Path prefix of the output file.
        mid: When truthy, write ``<prefix>_results_mid.json`` (a checkpoint);
            otherwise write ``<prefix>_results_final.json``.

    Keys are converted to ``str`` so that they are valid JSON object keys.
    """
    results = {
        "predictions": {str(key): value for key, value in predictions_dict.items()},
        "responses": {str(key): value for key, value in responses_dict.items()},
    }

    stage = "mid" if mid else "final"
    output_path = f"{filename_prefix}_results_{stage}.json"

    # ensure_ascii=False emits non-ASCII characters verbatim, so the file must
    # be opened as UTF-8 rather than with the platform's default encoding.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

    label = "intermediate" if mid else "final"
    print(f"Saved {label} results to {output_path}")

def predict_relationship(df, unique_summaries_map, filename_prefix, save_interval=100):
    """Query the chat model for every target row and record its answers.

    Args:
        df: DataFrame of target rows; rows from position 15 onwards are
            processed. Each row must provide ``Unique_ID`` plus the fields
            read by ``format_prompt``.
        unique_summaries_map: Mapping of row position -> context DataFrame,
            with an entry for every processed position.
        filename_prefix: Path prefix for the JSON files written here
            (``_results_mid``, ``_results_final`` and ``_prompt_final``).
        save_interval: A checkpoint is written whenever the row position is a
            multiple of this value; ``0`` disables checkpoints.

    Returns:
        ``(predictions_dict, responses_dict)``, both keyed by the string form
        of ``Unique_ID``. A prediction is ``-1`` when no ``*<n>*`` marker was
        found in the response.
    """
    predictions_dict = {}
    responses_dict = {}
    prompt_dict = {}

    for i in tqdm(range(15, len(df))):
        X = unique_summaries_map[i]
        Y = df.iloc[i]
        prompt = format_prompt(X, Y)

        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=900,
            temperature=1
        )

        # The message content may be None; treat that as an empty response
        # instead of crashing on .strip().
        raw_content = completion.choices[0].message.content
        full_response_text = (raw_content or "").strip()
        predicted_label = parse_final_prediction(full_response_text)

        unique_id = str(Y['Unique_ID'])
        predictions_dict[unique_id] = predicted_label
        responses_dict[unique_id] = full_response_text
        prompt_dict[unique_id] = prompt

        # Guard against save_interval == 0 (modulo by zero).
        if save_interval and (i % save_interval) == 0 and i != 0:
            save_intermediate_results(predictions_dict, responses_dict, filename_prefix, True)
        print(prompt)
        print(f"Completed {i + 1}/{len(df)}: {predicted_label}")
        print(full_response_text)

    # ensure_ascii=False writes non-ASCII text verbatim, so force UTF-8.
    with open(f"{filename_prefix}_prompt_final.json", 'w', encoding='utf-8') as f:
        json.dump(prompt_dict, f, ensure_ascii=False, indent=4)

    save_intermediate_results(predictions_dict, responses_dict, filename_prefix, False)
    return predictions_dict, responses_dict

# Order events chronologically so that "earlier rows" means "earlier news".
summary_merged_event_df = summary_merged_event_df.sort_values(by='CE_DATEADDED').reset_index(drop=True)

# Evaluation window in YYYYMMDDHHMMSS form, end-exclusive. The lower bound was
# previously written with 13 digits (2024051100000), which is smaller than any
# 14-digit timestamp and therefore never excluded anything.
window_start = 20240511000000
window_end = 20240529000000
in_window = (
    (window_start <= summary_merged_event_df['CE_DATEADDED'])
    & (summary_merged_event_df['CE_DATEADDED'] < window_end)
)
two_weeks_df = summary_merged_event_df[in_window].reset_index(drop=True)

# The context map is keyed by row position in two_weeks_df. The frame is
# already in date order (filtering keeps order), and it must not be re-sorted
# after this point: the default sort is not stable, so tied timestamps could
# be reordered and the map positions would no longer match the rows.
unique_summaries_map = prepare_unique_summaries(two_weeks_df)
predictions, responses = predict_relationship(two_weeks_df, unique_summaries_map, filename_prefix="gpt4_future", save_interval=100)