library(tidyverse)
library(fs)

# Global ggplot2 theme for every plot below: black-and-white base theme in
# 15pt Times, legend above the panels, black axis text, no grid lines.
theme_set(
  theme_bw(base_family = "Times", base_size = 15) +
    theme(
      panel.grid = element_blank(),
      axis.text = element_text(color = "black"),
      legend.position = "top"
    )
)

# Load every results CSV found in data/results/comps-qa/ into one data frame.
# `.id = "model"` stores the source file path of each row in `model`; the
# mutate() below reduces that path to a short model name and turns `model`
# and `heuristic` into ordered factors for plotting.
results <- dir_ls("data/results/comps-qa/", regexp = "\\.csv$") %>%
  map_df(read_csv, .id = "model") %>%
  # replace_na(list(heuristic = "0-shot")) %>%
  mutate(
    # Strip the file extension. The pattern is escaped and anchored so only a
    # literal trailing ".csv" is removed (an unescaped "." matches any
    # character and could hit an earlier "csv" inside the path).
    model = str_remove(model, "\\.csv$"),
    # Strip the directory prefix, then the organisation / path prefix.
    model = str_remove(model, "data/results/(comps\\-qa|comps)/"),
    model = str_remove(
      model,
      "(_home_shared_|facebook_|mistralai_|meta-llama_)"
    ),
    # Fix the display order of the models. Any model name that is not listed
    # in `levels` becomes NA.
    model = factor(
      model,
      levels = c(
        "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl", "gpt2-xl2",
        "opt-125m", "opt-350m", "opt-1.3b", "opt-1.3b-2", "opt-2.7b",
        "opt-6.7b", "opt-6.7b-2", "llama2-7b-hf", "llama2-7b-hf-2",
        "Llama-2-13b-hf", "Mistral-7B-v0.1"
      )
    ),
    # Relabel the "none" heuristic as "0-shot" and fix the heuristic order.
    heuristic = str_replace(heuristic, "none", "0-shot"),
    heuristic = factor(
      heuristic,
      levels = c("0-shot", "first", "second", "alt-first", "alt-second")
    )
  )

# Rows of `results` that belong to the zero-shot condition.
zero_shot_base <- filter(results, heuristic == "0-shot")

# Replicate the zero-shot rows once per heuristic label (in the order first,
# alt-first, second, alt-second), then restore the factor with the same
# levels used in `results`.
zero_shot <- c("first", "alt-first", "second", "alt-second") %>%
  map(function(label) mutate(zero_shot_base, heuristic = label)) %>%
  bind_rows() %>%
  mutate(
    heuristic = factor(
      heuristic,
      levels = c("0-shot", "first", "second", "alt-first", "alt-second")
    )
  )

# All non-zero-shot rows, followed by the zero-shot rows replicated under each
# heuristic label.
results_full <- results %>%
  filter(heuristic != "0-shot") %>%
  bind_rows(zero_shot)

# Accuracy against prompt length for a selected subset of models, faceted by
# heuristic (rows) and model (columns), with one line per test ground truth.
results_full %>%
  filter(
    str_detect(heuristic, "alt", negate = TRUE),
    model %in% c(
      "gpt2-xl", "gpt2-xl2", "opt-1.3b", "opt-1.3b-2", "opt-2.7b",
      "opt-6.7b", "opt-6.7b-2", "llama2-7b-hf", "llama2-7b-hf-2",
      "Llama-2-13b-hf", "Mistral-7B-v0.1"
    )
  ) %>%
  # Relabel both columns: "second" becomes "Recent", everything else "First".
  mutate(
    across(
      c(correct, heuristic),
      ~ if_else(.x %in% "second", "Recent", "First")
    )
  ) %>%
  # write_csv("data/plot-candidate-results.csv")
  group_by(model, heuristic, prompt_length, correct) %>%
  summarise(
    # `ste` is 1.96 standard errors. `accuracy` is overwritten last so that
    # the two spread measures are computed from the raw per-row values.
    ste = 1.96 * plotrix::std.error(accuracy),
    std = sd(accuracy),
    accuracy = mean(accuracy)
  ) %>%
  ggplot(
    aes(
      x = prompt_length,
      y = accuracy,
      color = correct,
      fill = correct,
      shape = correct
    )
  ) +
  geom_point(size = 3) +
  geom_line(linewidth = 0.8) +
  # geom_ribbon(aes(ymin = accuracy-ste, ymax=accuracy+ste), alpha = 0.4, color = NA) +
  geom_errorbar(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    width = 0.3
  ) +
  # Dashed reference line at an accuracy of 0.5.
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  ggh4x::facet_grid2(heuristic ~ model, scales = "free_x", independent = "x") +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  scale_color_manual(values = c("#7570b3", "#d95f02")) +
  labs(
    x = "Number of Prompts (0 = Zero-shot)",
    y = "Accuracy",
    color = "Test Ground Truth",
    fill = "Test Ground Truth",
    shape = "Test Ground Truth"
  )

# Matched vs Mis-matched
# Accuracy against prompt length, faceted by whether the heuristic label
# agrees with the test ground truth (rows) and by model (columns).
results_full %>%
  filter(
    str_detect(heuristic, "alt", negate = TRUE),
    model %in% c(
      "gpt2-xl", "gpt2-xl2", "opt-1.3b", "opt-1.3b-2", "opt-2.7b",
      "opt-6.7b", "opt-6.7b-2", "llama2-7b-hf", "llama2-7b-hf-2",
      "Mistral-7B-v0.1"
    )
  ) %>%
  mutate(
    # Relabel both columns: "second" becomes "recent", everything else
    # "first".
    across(
      c(correct, heuristic),
      ~ if_else(.x %in% "second", "recent", "first")
    ),
    # Compare the relabelled columns (mutate() evaluates sequentially).
    match = if_else(correct == heuristic, "Matched", "Mismatched")
  ) %>%
  group_by(model, match, correct, prompt_length) %>%
  summarise(
    # `ste` is 1.96 standard errors; `accuracy` is overwritten last so the
    # spread measures use the raw per-row values.
    ste = 1.96 * plotrix::std.error(accuracy),
    std = sd(accuracy),
    accuracy = mean(accuracy)
  ) %>%
  ggplot(
    aes(x = prompt_length, y = accuracy, color = correct, shape = correct)
  ) +
  geom_point(size = 3) +
  geom_line(linewidth = 0.8) +
  geom_errorbar(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    width = 0.3
  ) +
  # Dashed reference line at an accuracy of 0.5.
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  ggh4x::facet_grid2(match ~ model, scales = "free_x", independent = "x") +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  scale_color_manual(values = c("#7570b3", "#d95f02")) +
  labs(
    x = "Number of Exemplars (0 = Zero-shot)",
    y = "Accuracy",
    color = "Test Ground Truth",
    shape = "Test Ground Truth"
  )


# Accuracy against prompt length with one line per (heuristic, ground truth)
# combination, one panel per model. Line type distinguishes combinations in
# which the heuristic label and the ground truth differ.
results_full %>%
  filter(
    str_detect(heuristic, "alt", negate = TRUE),
    model %in% c(
      "gpt2-xl", "gpt2-xl2", "opt-1.3b", "opt-1.3b-2", "opt-2.7b",
      "opt-6.7b", "opt-6.7b-2", "llama2-7b-hf", "llama2-7b-hf-2",
      "Mistral-7B-v0.1"
    )
  ) %>%
  # Relabel both columns: "second" becomes "recent", everything else "first".
  mutate(
    across(
      c(correct, heuristic),
      ~ if_else(.x %in% "second", "recent", "first")
    )
  ) %>%
  group_by(model, heuristic, prompt_length, correct) %>%
  summarise(
    # `ste` is 1.96 standard errors; `accuracy` is overwritten last so the
    # spread measures use the raw per-row values.
    ste = 1.96 * plotrix::std.error(accuracy),
    std = sd(accuracy),
    accuracy = mean(accuracy)
  ) %>%
  ggplot(
    aes(
      x = prompt_length,
      y = accuracy,
      color = paste(heuristic, "-", correct),
      shape = paste(heuristic, "-", correct),
      linetype = (heuristic != correct)
    )
  ) +
  geom_point(size = 3) +
  geom_line(linewidth = 0.8) +
  # geom_ribbon(aes(ymin = accuracy-ste, ymax=accuracy+ste), alpha = 0.4, color = NA) +
  geom_errorbar(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    width = 0.3
  ) +
  # Dashed reference line at an accuracy of 0.5.
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  facet_wrap(~model, nrow = 1) +
  # ggh4x::facet_grid2(~model, scales = "free_x", independent = "x") +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  scale_color_manual(values = c("red", "pink", "lightblue", "blue")) +
  labs(
    x = "Number of Prompts (0 = Zero-shot)",
    y = "Accuracy",
    color = "Heuristic Predicts",
    fill = "Heuristic Predicts",
    shape = "Heuristic Predicts"
  )

# Zero-shot accuracy of every model, one point per row of `results`, coloured
# and shaped by the test ground truth.
ggplot(
  filter(results, heuristic == "0-shot"),
  aes(x = model, y = accuracy, color = correct, shape = correct)
) +
  geom_point(size = 2)

# Two-stage aggregation: first average accuracy within each
# (model, prompt_domain, prompt_id, heuristic) cell, then summarise those cell
# means per (model, heuristic).
overall_accuracy <- results %>%
  group_by(model, prompt_domain, prompt_id, heuristic) %>%
  summarise(accuracy = mean(accuracy)) %>%
  # group_by() replaces whatever grouping the previous summarise() left.
  group_by(model, heuristic) %>%
  summarise(
    # `ste` is 1.96 standard errors; `accuracy` is overwritten last so the
    # spread measures use the per-cell means.
    ste = 1.96 * plotrix::std.error(accuracy),
    std = sd(accuracy),
    accuracy = mean(accuracy)
  )

# Mean accuracy per heuristic, one connected line per model.
ggplot(
  overall_accuracy,
  aes(
    x = heuristic,
    y = accuracy,
    group = model,
    shape = model,
    color = model
  )
) +
  geom_point(size = 3) +
  # geom_errorbar(aes(ymin = accuracy - std, ymax = accuracy + std)) +
  geom_line()

# Mean accuracy per heuristic as bars, one panel per model, with black error
# bars of +/- `ste`. The theme_bw() call here replaces the globally set theme
# for this plot only.
ggplot(
  overall_accuracy,
  aes(x = heuristic, y = accuracy, color = model, fill = model)
) +
  geom_col(width = 0.8) +
  geom_errorbar(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    color = "black",
    width = 0.2
  ) +
  facet_wrap(~model, nrow = 2) +
  # scale_color_brewer(palette = "Dark2") +
  theme_bw(base_size = 15) +
  theme(legend.position = "top")

# Summary table of accuracy per model, split into zero-shot rows, rows where
# the heuristic equals the ground truth, and rows where it does not. The
# result is opened in the data viewer.
results %>%
  filter(str_detect(heuristic, "alt", negate = TRUE)) %>%
  mutate(
    match = case_when(
      heuristic == "0-shot" ~ "0-shot",
      heuristic == correct ~ "match",
      heuristic != correct ~ "mismatch"
    )
  ) %>%
  group_by(model, match) %>%
  summarise(
    # `ste` is 1.96 standard errors; `accuracy` is overwritten last so the
    # spread measures use the raw per-row values.
    ste = 1.96 * plotrix::std.error(accuracy),
    std = sd(accuracy),
    accuracy = mean(accuracy)
  ) %>%
  View()

# Mean accuracy per heuristic (including 0-shot, excluding the "alt"
# heuristics), one line per test ground truth, one panel per model.
results %>%
  filter(!str_detect(heuristic, "alt")) %>%
  group_by(model, heuristic, correct) %>%
  summarize(
    # `ste` is 1.96 standard errors; `accuracy` is overwritten last so the
    # spread measures use the raw per-row values.
    ste = 1.96 * plotrix::std.error(accuracy),
    std = sd(accuracy),
    accuracy = mean(accuracy)
  ) %>%
  ggplot(
    aes(heuristic, accuracy, shape = correct, color = correct, group = correct)
  ) +
  geom_point(size = 2.4) +
  geom_errorbar(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    width = 0.2
  ) +
  # `linewidth` replaces the deprecated `size` argument for line geoms and
  # matches the other plots in this script.
  geom_line(linewidth = 0.8) +
  facet_wrap(~model, nrow = 2) +
  # scale_color_brewer(palette = "Dark2") +
  # Override the global theme for this plot: larger base size.
  theme_bw(base_size = 17, base_family = "Times") +
  theme(
    panel.grid = element_blank(),
    legend.position = "top"
  )




