import re
from random import sample

import numpy as np
import pandas as pd

# NOTE(review): find_message_id_fuzzy calls `fuzz.partial_ratio`, but `fuzz` is
# never imported in this file -- presumably `from fuzzywuzzy import fuzz`;
# confirm and add the import.

# Load the labelled dataset.
# NOTE(review): the path is empty in this script; it has to be filled in with
# the labelled CSV before this can run.
data = pd.read_csv('')

# Take an explicit copy so the column assignment below targets an independent
# frame instead of a filtered slice of `data` (avoids SettingWithCopyWarning).
probing_data = data[data['annotation_type'] == 'Probing'].copy()

# probingQuestionID has the form "<group_id>_<n>", where n is the 1-based
# running count of probing rows within the group, in row order.
probing_order = probing_data.groupby('group_id').cumcount() + 1
probing_data['probingQuestionID'] = (
    probing_data['group_id'].astype(str) + '_' + probing_order.astype(str)
)

# Copy the IDs back onto the matching rows of the full dataframe; rows that
# are not probing questions are left without a value.
data.loc[probing_data.index, 'probingQuestionID'] = probing_data['probingQuestionID']



def get_previous_messages(data, group_id, message_id, max_history=22, excluded_origin=4):
    """Build the chat history that precedes a message within its group.

    Args:
        data: DataFrame with 'group_id', 'origin', 'message_id' and
            'original_text' columns.
        group_id: Group whose messages are searched.
        message_id: Message whose preceding messages are collected.
        max_history: Maximum number of preceding messages to include
            (default 22, the value previously hard-coded).
        excluded_origin: Messages with this 'origin' are dropped before the
            history is built (default 4, previously hard-coded).

    Returns:
        A string with one "Participant <origin>: <text>" line per preceding
        message, oldest first (empty when the message is the first one kept
        in its group). Returns "No previous messages found." when the group
        has no rows after filtering, and "Current message not found." when
        `message_id` is not among the filtered rows (which includes the case
        where the message itself has the excluded origin).
    """
    keep = (data['group_id'] == group_id) & (data['origin'] != excluded_origin)
    group_data = data[keep].reset_index(drop=True)
    if group_data.empty:
        return "No previous messages found."

    # After reset_index the index is 0..n-1, so the first matching label is
    # also the position of the earliest occurrence of the message.
    positions = group_data.index[group_data['message_id'] == message_id]
    if len(positions) == 0:
        return "Current message not found."
    current_index = int(positions[0])

    # Clamp at 0 so a message near the start of the conversation simply gets
    # a shorter history.
    start_index = max(0, current_index - max_history)
    previous_messages = group_data.iloc[start_index:current_index]

    return "\n".join(
        f"Participant {origin}: {text}"
        for origin, text in zip(
            previous_messages['origin'], previous_messages['original_text']
        )
    )

# Probing rows get their preceding chat history and their own text in two
# dedicated columns; every other row gets None in both.
def _history_if_probing(row):
    """Return the message history for a probing row, None for any other row."""
    if row['annotation_type'] == 'Probing':
        return get_previous_messages(data, row['group_id'], row['message_id'])
    return None


def _text_if_probing(row):
    """Return the row's own text for a probing row, None for any other row."""
    if row['annotation_type'] == 'Probing':
        return row['original_text']
    return None


data['prev_utterance_history'] = data.apply(_history_if_probing, axis=1)
data['probing_utterance'] = data.apply(_text_if_probing, axis=1)

# Written with the default index column (unlike the later save, which passes
# index=False).
data.to_csv('wtd_with_probingID.csv')


##1
# Explicit copy: columns are added to this frame below, and adding them to a
# filtered slice of `data` triggers pandas' SettingWithCopyWarning.
probing_data = data[data['annotation_type'] == 'Probing'].copy()

# Create the (initially empty) columns that will hold the message IDs of the
# causal counterparts, on both the probing subset and the full dataframe.
for counterpart_column in (
    'causal_counterpart1_msgID',
    'causal_counterpart2_msgID',
    'causal_counterpart3_msgID',
):
    probing_data[counterpart_column] = None
    data[counterpart_column] = None


#  text cleaning and partial matching
def find_message_id_fuzzy(group, index, text):

    max_index = max(0, index - 22)  # don't go out of bounds
    relevant_group = group[max_index:index]
    
    cleaned_text = re.sub(r"Participant \d+:\s*", "", text).lower().strip()
    for idx, row in relevant_group.iterrows():
        row_text_cleaned = re.sub(r"Participant \d+:\s*", "", row['original_text']).lower().strip()
        if cleaned_text in row_text_cleaned:
            return row['message_id']
    
    
    # Using FuzzyWuzzy 
    best_score = 0
    best_id = None
    for idx, row in relevant_group.iterrows():
        row_text_cleaned = re.sub(r"Participant \d+:\s*", "", row['original_text']).lower().strip()
        score = fuzz.partial_ratio(cleaned_text, row_text_cleaned)
        if score > best_score:
            best_score = score
            best_id = row['message_id']
   # print(best_id)
    return best_id

# Resolve each annotated counterpart text of a probing row to the message_id
# of the matching earlier message in the same group.
for index, row in probing_data.iterrows():
    group_data = data[data['group_id'] == row['group_id']]
    # Position of the probing row inside its group (the search looks back
    # from this position).
    main_index = group_data.index.get_loc(index)
    for counterpart_number in (1, 2, 3):
        counterpart_text = row[f'causal_counterpart_{counterpart_number}']
        if pd.isna(counterpart_text):
            continue
        target_column = f'causal_counterpart{counterpart_number}_msgID'
        matched_id = find_message_id_fuzzy(group_data, main_index, counterpart_text)
        probing_data.at[index, target_column] = matched_id
        # `probing_data` is a separate copy, but `data` is the frame saved
        # below, so the result has to be stored there as well; otherwise the
        # saved counterpart columns stay empty.
        data.at[index, target_column] = matched_id


# Save the updated data back to a new CSV file
updated_file_path = 'updated_wtd_with_probingID.csv'
data.to_csv(updated_file_path, index=False)

##2
np.random.seed(42)

# NOTE(review): `data_deduplicated` is not defined anywhere in this script --
# presumably a de-duplicated version of `data` built in a step that is missing
# here. Confirm and define it before this section runs.

# Get unique groups
unique_groups = data_deduplicated['group_id'].unique()
total_groups = len(unique_groups)

# Shuffle in place so the split below is random but reproducible (seed 42).
np.random.shuffle(unique_groups)

# Split sizes in groups; every group left after train and dev goes to test
# (2 groups when there are 10 in total).
N_TRAIN_GROUPS = 7
N_DEV_GROUPS = 1
train_groups = unique_groups[:N_TRAIN_GROUPS]
dev_groups = unique_groups[N_TRAIN_GROUPS:N_TRAIN_GROUPS + N_DEV_GROUPS]
test_groups = unique_groups[N_TRAIN_GROUPS + N_DEV_GROUPS:]

# Map each group_id in the dataframe to the corresponding set. Anything that
# is neither a train nor a dev group is labelled 'Test'.
set_by_group = {group: 'Train' for group in train_groups}
set_by_group.update({group: 'Dev' for group in dev_groups})
data_deduplicated['set'] = data_deduplicated['group_id'].apply(
    lambda group: set_by_group.get(group, 'Test'))

# Verify the split
split_counts = data_deduplicated['set'].value_counts()
print('total groups', total_groups)

data_deduplicated.to_csv('final.csv')
print('dist', split_counts)
probing_data_corrected = data_deduplicated[data_deduplicated['probingQuestionID'].notnull()]
print(probing_data_corrected['group_id'].value_counts())

# One dict per probing question, keyed by its probingQuestionID.
probing_dict_corrected = {
    row['probingQuestionID']: row.to_dict()
    for _, row in probing_data_corrected.iterrows()
}

# Display a sample from the corrected dictionary to verify structure
sample_key_corrected = next(iter(probing_dict_corrected))
print(probing_dict_corrected[sample_key_corrected], len(probing_dict_corrected))

import pickle

with open('final.pkl', 'wb') as file:
    pickle.dump(probing_dict_corrected, file)

print(data_deduplicated['group_id'].value_counts())


##3
def create_lists(file_path, negatives_per_probe=3, seed=None,
                 positive_path='positive_samples.csv',
                 negative_path='negative_samples.csv'):
    """Write positive and negative (probingQuestionID, message_id) pairs to CSV.

    Only rows whose 'annotation_type' is 'Probing' and that have at least one
    causal counterpart message ID are used.

    * Positive pairs: the probing question paired with each of its causal
      counterpart message IDs.
    * Negative pairs: the probing question paired with messages from the same
      group that come earlier in the file and are not causal counterparts.
      When at least `negatives_per_probe` such messages exist, that many are
      drawn at random; otherwise all of them are used.

    Args:
        file_path: CSV with 'annotation_type', 'probingQuestionID',
            'group_id', 'message_id' and the three
            'causal_counterpart<N>_msgID' columns.
        negatives_per_probe: Number of negatives drawn per probing question
            (default 3, the value previously hard-coded).
        seed: Optional seed for reproducible sampling. With None the global
            `random` module state is used, as before.
        positive_path: Output CSV for the positive pairs.
        negative_path: Output CSV for the negative pairs.

    Returns:
        Tuple ``(positive_tuples, negative_samples)`` with the lists that were
        written to the two files.
    """
    import random

    rng = random.Random(seed) if seed is not None else random

    # Load the dataset
    data = pd.read_csv(file_path)

    causal_columns = ['causal_counterpart1_msgID', 'causal_counterpart2_msgID', 'causal_counterpart3_msgID']
    has_counterpart = data[causal_columns].notna().any(axis=1)
    probing_data = data[(data['annotation_type'] == 'Probing') & has_counterpart]

    positive_tuples = []
    negative_samples = []

    # `probing_index` is the row's own label in `data`. Looking the row up by
    # message_id instead would pick the wrong row whenever message IDs repeat
    # (for example across groups).
    for probing_index, row in probing_data.iterrows():
        probing_id = row['probingQuestionID']
        causal_ids = [row[col] for col in causal_columns if pd.notna(row[col])]

        positive_tuples.extend((probing_id, causal_id) for causal_id in causal_ids)

        # Same group, earlier in the file, and not a causal counterpart.
        is_candidate = (
            (data['group_id'] == row['group_id'])
            & (data.index < probing_index)
            & ~data['message_id'].isin(causal_ids)
        )
        candidate_ids = data.loc[is_candidate, 'message_id'].tolist()
        if len(candidate_ids) >= negatives_per_probe:
            candidate_ids = rng.sample(candidate_ids, negatives_per_probe)
        negative_samples.extend((probing_id, msg_id) for msg_id in candidate_ids)

    # Save both sets of samples to files
    columns = ['probingQuestionID', 'message_id']
    pd.DataFrame(positive_tuples, columns=columns).to_csv(positive_path, index=False)
    pd.DataFrame(negative_samples, columns=columns).to_csv(negative_path, index=False)

    return positive_tuples, negative_samples

# Build the positive/negative sample files from 'final.csv', the file the
# split step above writes.
file_path = 'final.csv'  # Path to your CSV file
create_lists(file_path)