library(tidyverse)

# Load the per-run results and annotate them in a single pass:
# - `model` becomes a factor with a fixed display order; any model name not
#   listed in `levels` becomes NA.
# - `match` is "Matched" where `heuristic` equals `correct`, and
#   "Mismatched" otherwise (including rows where either side is NA, since
#   the equality is then NA and falls through to the TRUE branch).
data <- read_csv("data/plot-candidate-results.csv") %>%
  mutate(
    model = factor(
      model,
      levels = c(
        "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl",
        "opt-125m", "opt-350m", "opt-1.3b", "opt-2.7b", "opt-6.7b",
        "llama2-7b-hf", "Mistral-7B-v0.1"
      )
    ),
    match = case_when(
      heuristic == correct ~ "Matched",
      TRUE ~ "Mismatched"
    )
  )

# Stack four copies of `data` (after dropping any grouping), one per
# (heuristic, correct) pair, in this order: first/first, first/recent,
# recent/first, recent/recent. Each copy has its `heuristic` and `correct`
# columns overwritten (or appended, if absent) with that pair's values.
#
# @param data    A data frame / tibble.
# @param matched If TRUE, add a `match` column ("Matched" when heuristic
#                equals correct, else "Mismatched") and drop `heuristic`.
# @return The stacked tibble (4 * nrow(data) rows).
prepare_data <- function(data, matched=FALSE) {
  base <- ungroup(data)
  pairs <- list(
    c("first", "first"),
    c("first", "recent"),
    c("recent", "first"),
    c("recent", "recent")
  )
  # `!!` injects the constants so a column in `base` cannot shadow `pair`.
  copies <- lapply(pairs, function(pair) {
    mutate(base, heuristic = !!pair[[1]], correct = !!pair[[2]])
  })
  out <- bind_rows(copies)

  if (matched) {
    out <- out %>%
      mutate(
        match = case_when(
          heuristic == correct ~ "Matched",
          TRUE ~ "Mismatched"
        )
      ) %>%
      select(-heuristic)
  }
  out
}

# Summarise accuracy per group of `...` (forwarded to group_by()).
# Returns an ungrouped tibble with:
#   ste      = 1.96 * standard error of the raw accuracies in the group
#              (a normal-approximation 95% half-width),
#   accuracy = mean accuracy in the group.
summarize_accuracy <- function(df, ...) {
  df %>%
    group_by(...) %>%
    summarize(
      # Order matters: `ste` must be computed from the raw values before the
      # next line replaces `accuracy` with its group mean, because later
      # expressions in summarize() see the already-summarised value.
      ste = 1.96 * plotrix::std.error(accuracy),
      accuracy = mean(accuracy),
      .groups = "drop"
    )
}

# Accuracy averaged over all heuristic/ground-truth conditions.
averaged <- summarize_accuracy(data, model, prompt_length)

# Must come after `averaged` is defined (previously it ran before the
# assignment and failed with "object 'averaged' not found").
prepare_data(averaged)

# Accuracy broken down by heuristic and ground truth.
heuristic_results <- summarize_accuracy(
  data, model, heuristic, correct, prompt_length
)

# Accuracy broken down by whether the heuristic matched the ground truth.
matched_results <- summarize_accuracy(
  data, model, match, correct, prompt_length
)

matched_results


# Accuracy vs. number of exemplars, one facet per model. Each line is a
# (ground truth, match) condition from `matched_results`; the overall
# `averaged` series is appended with match = correct = "Average" and drawn
# in black with a solid line.
bind_rows(
  matched_results,
  averaged %>% ungroup() %>% mutate(match = "Average", correct = "Average")
) %>%
  mutate(
    match = factor(match, levels = c("Matched", "Mismatched", "Average")),
    # Title-case `correct` so its values line up with the factor levels.
    correct = factor(
      str_to_title(correct),
      levels = c("First", "Recent", "Average")
    )
  ) %>%
  ggplot(aes(
    prompt_length, accuracy,
    shape = correct, color = correct, fill = correct, linetype = match
  )) +
  geom_point(size = 2.5) +
  geom_line(linewidth = 0.7) +
  # `ste` already holds 1.96 * SE, so the band is an approximate 95% interval.
  geom_ribbon(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    alpha = 0.1, color = NA
  ) +
  # geom_errorbar(aes(ymin = accuracy-ste, ymax=accuracy+ste), width=0.3, linetype="solid") +
  facet_wrap(~ model, nrow = 1) +
  # Named values bind each style to its level explicitly, rather than
  # relying on the order of the factor levels above.
  scale_shape_manual(values = c(First = 15, Recent = 17, Average = 16)) +
  scale_linetype_manual(
    values = c(Matched = "dashed", Mismatched = "dotted", Average = "solid")
  ) +
  scale_color_manual(
    values = c(First = "#7570b3", Recent = "#d95f02", Average = "black"),
    aesthetics = c("fill", "color")
  ) +
  labs(
    x = "Number of Exemplars (0 = Zero-shot)",
    y = "Accuracy",
    color = "Test Ground Truth",
    fill = "Test Ground Truth",
    shape = "Test Ground Truth"
  )



