import pandas as pd
import pickle
import pandas as pd
from fuzzywuzzy import fuzz
##1
# Load the pickled probing-question dictionary and flatten it into a CSV.
# NOTE(review): the pickle path is blank here and must be filled in before
# this section can run. pickle.load can execute arbitrary code, so only load
# files you produced or trust.
with open('', 'rb') as file:
    data = pickle.load(file)

# One row per dictionary entry: keep the full key as probingQuestionID and the
# part of the key before the first '_' as group_id.
data_for_df = []
for key, value in data.items():
    entry = dict(value)  # copy, so the unpickled dicts are left untouched
    entry['probingQuestionID'] = key
    entry['group_id'] = key.split('_')[0]
    data_for_df.append(entry)

df = pd.DataFrame(data_for_df)
df.to_csv('causal_counterpart.csv', index=False)

pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

# Merge the probing entries onto the message table. Reading the CSV back
# (rather than reusing df) sends group_id through the same read_csv type
# inference as deli_data_complete.csv -- presumably why the round trip exists.
causal_counterpart_df = pd.read_csv('causal_counterpart.csv')
deli_data_complete_df = pd.read_csv('deli_data_complete.csv')

# probing_utterance may carry a "<label>:" prefix; keep only the text after the
# last ':' so it can be compared with original_text.
causal_counterpart_df['cleaned_probing_utterance'] = (
    causal_counterpart_df['probing_utterance']
    .astype(str)
    .str.split(':')
    .str[-1]
    .str.strip()
)

# NOTE(review): if several probing entries share one (group_id, cleaned text)
# key, this left merge duplicates the matching message row.
merged_df = deli_data_complete_df.merge(
    causal_counterpart_df,
    left_on=['group_id', 'original_text'],
    right_on=['group_id', 'cleaned_probing_utterance'],
    how='left',
)

# Probing entries whose (group_id, cleaned text) key matches no message.
# Compare the merge keys themselves: merged_df and causal_counterpart_df have
# unrelated row labels, so comparing their indexes says nothing about matches.
message_keys = pd.MultiIndex.from_frame(deli_data_complete_df[['group_id', 'original_text']])
probe_keys = pd.MultiIndex.from_frame(causal_counterpart_df[['group_id', 'cleaned_probing_utterance']])
unmatched_rows = causal_counterpart_df[~probe_keys.isin(message_keys)]

if unmatched_rows.empty:
    print("No unmatched rows found.")
else:
    print("First 5 unmatched rows:")
    preview = unmatched_rows[['group_id', 'probing_utterance', 'cleaned_probing_utterance']].head()
    print(preview.to_markdown(index=False, numalign="left", stralign="left"))

# Write merged dataframe to csv
merged_df.to_csv('deli_data_complete_matched.csv', index=False)

def extract_text_after_colon(text):
    """Drop a leading "<label>:" prefix from *text*.

    Everything up to and including the first ':' is removed and the remainder
    is whitespace-stripped. Text without any ':' is returned unchanged (and
    is not stripped).
    """
    if ':' not in text:
        return text
    _label, remainder = text.split(':', 1)
    return remainder.strip()
# Read the merged CSV back into a DataFrame.
df = pd.read_csv('deli_data_complete_matched.csv')

# Resolve each causal counterpart text to the message_id of the most similar
# earlier message (by fuzz.token_sort_ratio) in the same group.
CAUSAL_TEXT_COLUMNS = ['causal_counterpart_1', 'causal_counterpart_2', 'causal_counterpart_3']
CAUSAL_ID_COLUMNS = ['causal_counterpart_1_msgID', 'causal_counterpart_2_msgID', 'causal_counterpart_3_msgID']
for id_column in CAUSAL_ID_COLUMNS:
    # Start as columns of None; IDs are filled in row by row below.
    df[id_column] = None


def _best_match_message_id(text, candidates):
    """Return the message_id of the candidate most similar to *text*.

    *candidates* is a sequence of ``(row_index, original_text, message_id)``
    tuples. Similarity is fuzz.token_sort_ratio; only a strictly higher score
    replaces the current best, so ties go to the earliest candidate. Returns
    None when there are no candidates or every score is 0.
    """
    best_score = 0
    best_id = None
    for _, candidate_text, message_id in candidates:
        score = fuzz.token_sort_ratio(text, candidate_text)
        if score > best_score:
            best_score, best_id = score, message_id
    return best_id


# Index the matchable messages once per group instead of re-filtering the whole
# DataFrame for every counterpart. Rows without original_text (or without a
# group_id) can never be matched, so they are dropped here. Candidates are
# appended in df's row order, which keeps the earliest-wins tie-breaking.
_candidates_by_group = {}
for row_index, row_group, row_text, row_msg_id in zip(
    df.index, df['group_id'], df['original_text'], df['message_id']
):
    if pd.isna(row_group) or pd.isna(row_text):
        continue
    _candidates_by_group.setdefault(row_group, []).append((row_index, row_text, row_msg_id))

# Rows where at least one causal counterpart ID is missing. This also counts
# rows that have no causal counterpart text at all.
missing_message_ids = []
instances_with_missing_ids = 0
for index, row in df.iterrows():
    group_id = row['group_id']
    counterpart_texts = [row[name] for name in CAUSAL_TEXT_COLUMNS]

    # Only messages that appear before this row in the file are eligible
    # (assumes file order is conversation order).
    if any(pd.notna(text) for text in counterpart_texts):
        earlier = [
            candidate
            for candidate in _candidates_by_group.get(group_id, [])
            if candidate[0] < index
        ]
    else:
        earlier = []

    causal_counterpart_msgIDs = [
        _best_match_message_id(extract_text_after_colon(text), earlier) if pd.notna(text) else None
        for text in counterpart_texts
    ]
    for id_column, msg_id in zip(CAUSAL_ID_COLUMNS, causal_counterpart_msgIDs):
        df.at[index, id_column] = msg_id

    if any(pd.isna(msg_id) for msg_id in causal_counterpart_msgIDs):
        instances_with_missing_ids += 1
        missing_message_ids.append((group_id, row['probing_utterance']))

import numpy as np

# -1 appears to be a "no message" sentinel. Depending on how message_id was
# parsed it can be the string '-1' or the number -1, so replace both forms.
df[CAUSAL_ID_COLUMNS] = df[CAUSAL_ID_COLUMNS].replace(['-1', -1], np.nan)
df.to_csv('final_deli_data_complete_matched_with_msgID.csv', index=False)
##2
# Split the data into Train / Dev / Test at the group level, so every row of a
# group lands in the same split.
#
# NOTE(review): `data_deduplicated` and `probing_dict_corrected` are not
# defined anywhere in this file -- presumably they come from an earlier
# notebook cell. Define or load them before running this section.

import numpy as np

N_TEST_GROUPS = 100  # the first N_TEST_GROUPS shuffled groups
N_DEV_GROUPS = 100   # the next N_DEV_GROUPS shuffled groups; the rest is Train

# Seed the legacy global RNG so the shuffle, and hence the split, is
# reproducible. Switching to np.random.default_rng would change the split.
np.random.seed(42)

unique_groups = data_deduplicated['group_id'].unique()
np.random.shuffle(unique_groups)

test_groups = unique_groups[:N_TEST_GROUPS]
dev_groups = unique_groups[N_TEST_GROUPS:N_TEST_GROUPS + N_DEV_GROUPS]
train_groups = unique_groups[N_TEST_GROUPS + N_DEV_GROUPS:]

# Sets give O(1) membership instead of scanning the arrays for every row.
_train_group_set = set(train_groups)
_dev_group_set = set(dev_groups)
data_deduplicated['set'] = data_deduplicated['group_id'].apply(
    lambda g: 'Train' if g in _train_group_set else ('Dev' if g in _dev_group_set else 'Test'))

# Report rows per split and groups per split.
split_counts = data_deduplicated['set'].value_counts()
print(split_counts)
print("groups: train=%d, dev=%d, test=%d" % (len(train_groups), len(dev_groups), len(test_groups)))

data_deduplicated = data_deduplicated.rename(columns={"prev_uttterance_history": "prev_utterance_history"})
# index=False, like every other CSV this script writes; otherwise read_csv
# brings the old index back as an extra "Unnamed: 0" column.
data_deduplicated.to_csv('final.csv', index=False)

with open('final.pkl', 'wb') as file:
    pickle.dump(probing_dict_corrected, file)

print(data_deduplicated['group_id'].value_counts())

##3

def create_lists(file_path, n_negatives=3, seed=None,
                 positive_path='positive_samples.csv',
                 negative_path='negative_samples.csv'):
    """Build positive and negative (probingQuestionID, message_id) pairs.

    Positive pairs link each probing question to its matched causal
    counterpart message IDs. Negative pairs are drawn from earlier messages of
    the same group that are not causal counterparts of that question.

    Args:
        file_path: CSV with at least the columns ``annotation_type``,
            ``probingQuestionID``, ``group_id``, ``message_id`` and
            ``causal_counterpart_{1,2,3}_msgID``. Row order is treated as
            message order within a group.
        n_negatives: When a question has at least this many candidates,
            exactly this many are drawn at random; otherwise all candidates
            are used, in file order.
        seed: Optional seed for the negative sampling, for reproducible output.
        positive_path: Where to write the positive pairs CSV.
        negative_path: Where to write the negative pairs CSV.

    Returns:
        ``(positive_df, negative_df)``, the DataFrames that were written.
    """
    import random

    data = pd.read_csv(file_path)
    causal_columns = ['causal_counterpart_1_msgID', 'causal_counterpart_2_msgID', 'causal_counterpart_3_msgID']

    # Probing questions with at least one matched causal counterpart.
    probing_data = data[
        (data['annotation_type'] == 'Probing') & data[causal_columns].notna().any(axis=1)
    ]

    rng = random.Random(seed)
    positive_tuples = []
    negative_samples = []
    # Boolean filtering keeps data's index labels, so probing_index is the
    # question's own row. Looking it up by message_id instead would pick the
    # wrong row whenever message IDs repeat across groups.
    for probing_index, row in probing_data.iterrows():
        question_id = row['probingQuestionID']
        causal_ids = [row[col] for col in causal_columns if pd.notna(row[col])]
        positive_tuples.extend((question_id, msg_id) for msg_id in causal_ids)

        # Candidates: earlier messages of the same group that are not causal
        # counterparts of this question.
        candidates = data.loc[
            (data['group_id'] == row['group_id'])
            & (data.index < probing_index)
            & ~data['message_id'].isin(causal_ids),
            'message_id',
        ].tolist()
        if len(candidates) >= n_negatives:
            candidates = rng.sample(candidates, n_negatives)
        negative_samples.extend((question_id, msg_id) for msg_id in candidates)

    positive_df = pd.DataFrame(positive_tuples, columns=['probingQuestionID', 'message_id'])
    negative_df = pd.DataFrame(negative_samples, columns=['probingQuestionID', 'message_id'])
    positive_df.to_csv(positive_path, index=False)
    negative_df.to_csv(negative_path, index=False)
    return positive_df, negative_df

# Build the positive/negative pair files from final.csv (written in section ##2).
file_path = 'final.csv'
create_lists(file_path=file_path)