library(tidyverse)
library(fs)
library(glue)
library(patchwork)
library(rstatix)
library(ggtext)

# Global ggplot2 theme used by every figure in this script: black-and-white
# base at 15 pt Times, legend above the panels, and black axis tick labels.
theme_set(
  theme_bw(base_size = 15, base_family = "Times") +
    theme(
      axis.text = element_text(color = "black"),
      legend.position = "top"
      # To remove gridlines, add `panel.grid = element_blank()` here.
    )
)

# Every model id that may appear in a results file name, in display order.
# Loaded rows whose model id is not listed here become NA when converted to a
# factor with these levels.
LEVELS <- c(
  "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl",
  "opt-125m", "opt-350m", "opt-1.3b", "opt-2.7b", "opt-6.7b",
  "llama2-7b-hf", "Llama-2-13b-hf", "Mistral-7B-v0.1"
)

# Models included in the figures and t-tests, in facet order.
MODEL_LIST <- c(
  "gpt2-xl", "opt-6.7b", "llama2-7b-hf", "Llama-2-13b-hf", "Mistral-7B-v0.1"
)

# Display labels, paired positionally with MODEL_LIST.
MODEL_NAMES <- c("GPT-2 XL", "OPT-6.7b", "Llama-2-7b", "Llama-2-13b", "Mistral-7b")

# MODEL_NAMES is used as factor labels for MODEL_LIST, so the two must have the
# same length; a MODEL_LIST entry missing from LEVELS would silently vanish.
stopifnot(
  length(MODEL_LIST) == length(MODEL_NAMES),
  all(MODEL_LIST %in% LEVELS)
)

# Load every per-model CSV in data/results/<comps>/, keeping all instruction
# types (no best-instruction selection).
#
# Args:
#   comps: results sub-directory under data/results/ (e.g.
#     "comps-instructions").
#   alt: if FALSE (default), drop the "Alt-First"/"Alt-Recent" heuristics.
#
# Returns a tibble of all CSV rows plus a `model` factor (LEVELS order), taken
# from the file name; rows whose model id is not in LEVELS are dropped.
# Zero-shot rows (raw heuristic "none") are copied under each few-shot
# heuristic label so prompt_length 0 appears on every heuristic's curve.
# Errors if the directory contains no .csv files.
load_instruction_results <- function(comps = "comps-instructions", alt = FALSE) {
  heuristic_levels <- c("0-shot", "First", "Recent", "Alt-First", "Alt-Recent")
  results_dir <- glue("data/results/{comps}/")

  # `glob`, not `regexp`: as a regular expression "*.csv" would match any path
  # containing "<any char>csv", not just files ending in ".csv".
  files <- dir_ls(results_dir, glob = "*.csv")
  if (length(files) == 0) {
    stop(glue("No .csv result files found in {results_dir}"), call. = FALSE)
  }

  # Raw files use lower-case labels ("first", "second", "alt-..."); map them
  # to the display names used in the figures.
  relabel <- function(x) {
    x %>%
      str_replace("first", "First") %>%
      str_replace("second", "Recent") %>%
      str_replace("alt", "Alt")
  }

  results <- files %>%
    map_df(read_csv, show_col_types = FALSE, .id = "model") %>%
    mutate(
      # `.id` holds the file path; keep only the file name without extension.
      model = as.character(path_ext_remove(path_file(model))),
      model = str_remove(model, "(_home_shared_|facebook_|mistralai_|meta-llama_)"),
      model = factor(model, levels = LEVELS),
      correct = relabel(correct),
      heuristic = str_replace(relabel(heuristic), "none", "0-shot"),
      heuristic = factor(heuristic, levels = heuristic_levels)
    ) %>%
    filter(!is.na(model))

  # Replicate the zero-shot rows under every few-shot heuristic label.
  zero_shot_base <- results %>% filter(heuristic == "0-shot")
  zero_shot <- c("First", "Alt-First", "Recent", "Alt-Recent") %>%
    map_df(function(label) mutate(zero_shot_base, heuristic = !!label)) %>%
    mutate(heuristic = factor(heuristic, levels = heuristic_levels))

  results_full <- bind_rows(
    results %>% filter(heuristic != "0-shot"),
    zero_shot
  )

  if (!alt) {
    results_full <- results_full %>% filter(!str_detect(heuristic, "Alt"))
  }

  results_full
}

# Instruction-type comparison on COMPS: mean accuracy (ribbon = mean +/- 1.96
# standard errors) vs. number of exemplars, one row of panels per model and
# one column per "does the exemplars' heuristic match the answer" outcome.
comps_qa_inst <- load_instruction_results("comps-instructions")

comps_qa_inst %>%
  group_by(
    model, instruction_type, prompt_length,
    match = heuristic == correct
  ) %>%
  summarize(
    ste = 1.96 * plotrix::std.error(accuracy),
    acc = mean(accuracy)
  ) %>%
  # Exploratory views:
  # select(-ste) %>%
  # pivot_wider(names_from = instruction_type, values_from = acc) %>%
  # View() %>%
  # filter(model == "Mistral-7B-v0.1") %>%
  mutate(
    match = if_else(
      match, "Heuristic works", "Heuristic doesn't work",
      missing = "Heuristic doesn't work"
    )
  ) %>%
  ggplot(aes(
    x = prompt_length, y = acc,
    color = instruction_type, shape = instruction_type, fill = instruction_type
  )) +
  geom_point(size = 3) +
  geom_line() +
  geom_ribbon(aes(ymin = acc - ste, ymax = acc + ste), alpha = 0.2, color = NA) +
  # facet_wrap(~match) +
  facet_grid(model ~ match) +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  # scale_y_continuous(limits = c(0.5, 1.0)) +
  scale_color_brewer(
    palette = "Dark2", direction = -1, aesthetics = c("color", "fill")
  ) +
  scale_shape_manual(values = c(16, 17, 18, 15))
  

# Load per-model results from data/results/<comps>/ for the main figures.
#
# Args:
#   comps: results sub-directory under data/results/. Names containing
#     "instruction" trigger best-instruction selection (see below).
#   alt: if FALSE (default), drop the "Alt-First"/"Alt-Recent" heuristics.
#
# Returns a tibble of all CSV rows plus a `model` factor (LEVELS order), taken
# from the file name; rows whose model id is not in LEVELS are dropped.
# Zero-shot rows (raw heuristic "none") are copied under each few-shot
# heuristic label so prompt_length 0 appears on every heuristic's curve.
# For instruction runs, only each model's highest-mean-accuracy
# instruction_type is kept (ties keep all tied types) and that mean is added
# as column `acc`. Errors if the directory contains no .csv files.
load_results <- function(comps = "comps", alt = FALSE) {
  heuristic_levels <- c("0-shot", "First", "Recent", "Alt-First", "Alt-Recent")
  results_dir <- glue("data/results/{comps}/")

  # `glob`, not `regexp`: as a regular expression "*.csv" would match any path
  # containing "<any char>csv", not just files ending in ".csv".
  files <- dir_ls(results_dir, glob = "*.csv")
  if (length(files) == 0) {
    stop(glue("No .csv result files found in {results_dir}"), call. = FALSE)
  }

  # Raw files use lower-case labels ("first", "second", "alt-..."); map them
  # to the display names used in the figures.
  relabel <- function(x) {
    x %>%
      str_replace("first", "First") %>%
      str_replace("second", "Recent") %>%
      str_replace("alt", "Alt")
  }

  results <- files %>%
    map_df(read_csv, show_col_types = FALSE, .id = "model") %>%
    mutate(
      # `.id` holds the file path; keep only the file name without extension.
      model = as.character(path_ext_remove(path_file(model))),
      model = str_remove(model, "(_home_shared_|facebook_|mistralai_|meta-llama_)"),
      model = factor(model, levels = LEVELS),
      correct = relabel(correct),
      heuristic = str_replace(relabel(heuristic), "none", "0-shot"),
      heuristic = factor(heuristic, levels = heuristic_levels)
    ) %>%
    filter(!is.na(model))

  # Replicate the zero-shot rows under every few-shot heuristic label.
  zero_shot_base <- results %>% filter(heuristic == "0-shot")
  zero_shot <- c("First", "Alt-First", "Recent", "Alt-Recent") %>%
    map_df(function(label) mutate(zero_shot_base, heuristic = !!label)) %>%
    mutate(heuristic = factor(heuristic, levels = heuristic_levels))

  results_full <- bind_rows(
    results %>% filter(heuristic != "0-shot"),
    zero_shot
  )

  # Unused alternative: keep only the best prompt per item.
  # results_full <- results %>%
  #   group_by(model, prompt_domain, prompt_id, heuristic, prompt_length, correct) %>%
  #   nest() %>%
  #   mutate(
  #     best = map(data, function(x) {
  #       x %>% arrange(-accuracy) %>% slice(1)
  #     })
  #   ) %>% select(-data) %>% unnest(best) %>%
  #   ungroup()

  if (str_detect(comps, "instruction")) {
    # NOTE(review): the best instruction is picked on the same rows that are
    # later reported, and zero-shot rows count once per replicated heuristic
    # label (Alt ones included, since this runs before the `alt` filter).
    best_instructions <- results_full %>%
      group_by(model, instruction_type) %>%
      summarize(acc = mean(accuracy), .groups = "drop_last") %>%
      filter(acc == max(acc)) %>% # max within each model
      ungroup()

    results_full <- results_full %>%
      inner_join(best_instructions, by = c("model", "instruction_type"))
  }

  if (!alt) {
    results_full <- results_full %>% filter(!str_detect(heuristic, "Alt"))
  }

  results_full
}

# Best-instruction COMPS-QA results.
comps <- load_results("comps-qa-instructions")

comps_qa_inst <- load_instruction_results("comps-instructions")

# Paired t-tests between instruction types, run within each (model, match)
# group; `match` is TRUE when the exemplars' heuristic equals the correct
# answer. Bonferroni adjustment is requested from rstatix, which runs the test
# separately per group of the grouped input. Models outside MODEL_LIST are
# removed before testing so they are not tested at all.
paired_t_test <- comps_qa_inst %>%
  filter(model %in% MODEL_LIST) %>%
  mutate(
    match = heuristic == correct,
    instruction_type = case_when(
      instruction_type == "inst1" ~ "Detailed-1",
      instruction_type == "inst2" ~ "Detailed-2",
      TRUE ~ instruction_type
    )
  ) %>%
  group_by(model, match) %>%
  pairwise_t_test(
    accuracy ~ instruction_type,
    paired = TRUE,
    p.adjust.method = "bonferroni"
  ) %>%
  mutate(model = factor(model, levels = MODEL_LIST, labels = MODEL_NAMES))

# Compact table of adjusted p-values. `p_adj` is numeric, so both case_when()
# branches must yield character (mixing character and double is an error).
paired_t_test %>%
  select(model, match, group1, group2, t = statistic, p.adj) %>%
  janitor::clean_names() %>%
  mutate(
    p_adj = case_when(
      p_adj < 0.001 ~ "< 0.001",
      TRUE ~ sprintf("%.3f", p_adj)
    )
  )
# accuracy ~ instruction + prompt_length + match + (1|item_id)

comps %>%
  mutate(
    match = case_when(
      heuristic == correct ~ "Match",
      TRUE ~ "Mismatch"
    )
  ) %>%
  group_by(model, match, heuristic, correct, prompt_length) %>%
  summarize(
    ste = 1.96 * plotrix::std.error(accuracy),
    accuracy = mean(accuracy),
    .groups = "drop"
  ) %>%
  filter(model %in% MODEL_LIST) %>%
  ggplot(aes(prompt_length, accuracy, color = match, shape = match)) +
  geom_point() +
  geom_line() +
  facet_grid(heuristic ~ model)

# Per-heuristic accuracy summary for the detailed (heuristic x model) figures.
#
# Args:
#   comps: results sub-directory, forwarded to load_results().
#   model_list: model ids to keep, in facet order. Ids found in MODEL_LIST are
#     labelled with the matching MODEL_NAMES entry; other ids keep their raw id.
#   alt: forwarded to load_results().
#
# Returns an ungrouped tibble with one row per model x match x heuristic x
# correct x prompt_length, plus "Avg Performance" rows that pool over
# `correct`. Columns: accuracy (mean) and ste (1.96 x standard error, i.e. a
# 95% CI half-width; NA at prompt_length 0). `heuristic` is relabelled with
# HTML <span>s for ggtext strips; only First/Recent get labels, so any
# Alt-* rows (alt = TRUE) end up with an NA heuristic.
get_detailed_results <- function(comps = "comps", model_list = MODEL_LIST, alt = FALSE) {
  results_all <- load_results(comps, alt) %>%
    filter(model %in% model_list)

  # Look labels up by id so any model_list works (a fixed MODEL_NAMES vector
  # only fits when model_list has exactly its length).
  model_labels <- coalesce(MODEL_NAMES[match(model_list, MODEL_LIST)], model_list)

  by_match <- results_all %>%
    mutate(
      match = if_else(heuristic == correct, "Match", "Mismatch", missing = "Mismatch")
    ) %>%
    group_by(model, match, heuristic, correct, prompt_length) %>%
    summarize(
      ste = 1.96 * plotrix::std.error(accuracy),
      accuracy = mean(accuracy),
      .groups = "drop"
    )

  avg_accuracy <- results_all %>%
    group_by(model, heuristic, prompt_length) %>%
    summarize(
      ste = 1.96 * plotrix::std.error(accuracy),
      accuracy = mean(accuracy),
      .groups = "drop"
    ) %>%
    mutate(match = "Avg")

  bind_rows(by_match, avg_accuracy) %>%
    mutate(
      match = factor(
        match, levels = c("Match", "Mismatch", "Avg"),
        labels = c("Heuristics work", "Heuristics don't work", "Avg Performance")
      ),
      # No interval is drawn for the zero-shot point.
      ste = case_when(
        prompt_length == 0 ~ NA_real_,
        TRUE ~ ste
      ),
      model = factor(model, levels = model_list, labels = model_labels),
      heuristic = factor(heuristic, levels = c("First", "Recent"), labels = c("<span style='font-size: 9pt; color:#3D85C6'>FIRST-CORRECT</span>", "<span style='font-size: 9pt;color:#F1C232'>RECENT-CORRECT</span>"))
    )
}

# Called only after the definition above so the script can be source()d
# top to bottom.
comps <- get_detailed_results("comps-qa")

# comps %>%

# get_detailed_results("comps")


# Detailed figure: accuracy vs. number of exemplars, with heuristic rows
# (FIRST-/RECENT-CORRECT) and one column per model. Each panel shows curves
# for items where the heuristic works, where it doesn't, and the average.
# Ribbons are accuracy +/- ste (the 95% CI half-width from
# get_detailed_results()), and the dashed line marks 0.5.
#
# Args:
#   comps, model_list, alt: forwarded to get_detailed_results().
#   legend_position: value for theme(legend.position = ...), e.g. "top".
#
# Returns the ggplot object.
create_detailed_plot <- function(comps = "comps", model_list = MODEL_LIST, legend_position = "top", alt = FALSE) {
  summary_df <- get_detailed_results(comps, model_list, alt)

  # Identical wide keys on all four mapped aesthetics so ggplot2 merges them
  # into a single legend.
  wide_key <- function() guide_legend(keywidth = 2)

  curve_aes <- aes(
    x = prompt_length, y = accuracy,
    color = match, fill = match, shape = match, linetype = match
  )

  ggplot(summary_df, curve_aes) +
    geom_point(size = 2.5) +
    geom_line(linewidth = 0.8) +
    geom_ribbon(
      aes(ymin = accuracy - ste, ymax = accuracy + ste),
      alpha = 0.4, color = NA
    ) +
    geom_hline(yintercept = 0.5, linetype = "dashed") +
    scale_x_continuous(breaks = scales::pretty_breaks(7)) +
    scale_y_continuous(limits = c(0.0, 1.0)) +
    scale_color_manual(
      values = c("#d95f02", "#7570b3", "#4a4a4a"),
      aesthetics = c("color", "fill")
    ) +
    scale_linetype_manual(values = c("dotted", "dashed", "solid")) +
    guides(
      color = wide_key(),
      fill = wide_key(),
      shape = wide_key(),
      linetype = wide_key()
    ) +
    facet_grid(heuristic ~ model) +
    theme(
      legend.title = element_blank(),
      legend.position = legend_position,
      # Heuristic strip labels are HTML <span>s; ggtext renders them.
      strip.text.y = element_markdown(size = 11),
      strip.background.y = element_rect(fill = "white")
    ) +
    labs(
      x = "Number of Exemplars (0 = Zero-shot)",
      y = "Accuracy (+95% CI)"
    )
}

create_detailed_plot("comps")

# Accuracy summary split by whether the exemplars' heuristic matches the
# correct answer, for the one-row-per-model summary figures.
#
# Args:
#   comps: results sub-directory, forwarded to load_results().
#   model_list: model ids to keep, in facet order. Ids found in MODEL_LIST are
#     labelled with the matching MODEL_NAMES entry; other ids keep their raw id.
#   alt: forwarded to load_results().
#
# Returns an ungrouped tibble with one row per model x match x prompt_length
# (match in "Heuristics work", "Heuristics don't work", "Avg Performance").
# Columns: accuracy (mean) and ste (1.96 x standard error, i.e. a 95% CI
# half-width; NA at prompt_length 0).
get_matched_results <- function(comps = "comps", model_list = MODEL_LIST, alt = FALSE) {
  results_full <- load_results(comps, alt) %>%
    filter(model %in% model_list)

  # Look labels up by id so any model_list works (a fixed MODEL_NAMES vector
  # only fits when model_list has exactly its length).
  model_labels <- coalesce(MODEL_NAMES[match(model_list, MODEL_LIST)], model_list)

  by_match <- results_full %>%
    mutate(
      match = if_else(heuristic == correct, "Match", "Mismatch", missing = "Mismatch")
    ) %>%
    group_by(model, match, prompt_length) %>%
    summarize(
      ste = 1.96 * plotrix::std.error(accuracy),
      accuracy = mean(accuracy),
      .groups = "drop"
    )

  avg_accuracy <- results_full %>%
    group_by(model, prompt_length) %>%
    summarize(
      ste = 1.96 * plotrix::std.error(accuracy),
      accuracy = mean(accuracy),
      .groups = "drop"
    ) %>%
    mutate(match = "Avg")

  bind_rows(by_match, avg_accuracy) %>%
    mutate(
      match = factor(
        match, levels = c("Match", "Mismatch", "Avg"),
        labels = c("Heuristics work", "Heuristics don't work", "Avg Performance")
      ),
      # No interval is drawn for the zero-shot point.
      ste = case_when(
        prompt_length == 0 ~ NA_real_,
        TRUE ~ ste
      ),
      model = factor(model, levels = model_list, labels = model_labels)
    )
}

# Summary figure: accuracy vs. number of exemplars, one facet per model in a
# single row. Each facet shows curves for items where the heuristic works,
# where it doesn't, and the average. Ribbons are accuracy +/- ste (the 95% CI
# half-width from get_matched_results()), and the dashed line marks 0.5.
#
# Args:
#   comps, model_list, alt: forwarded to get_matched_results().
#   legend_position: value for theme(legend.position = ...), e.g. "top" or
#     "none".
#
# Returns the ggplot object.
create_plot <- function(comps = "comps", model_list = MODEL_LIST, legend_position = "top", alt = FALSE) {
  summary_df <- get_matched_results(comps, model_list, alt)

  # Identical wide keys on all four mapped aesthetics so ggplot2 merges them
  # into a single legend.
  wide_key <- function() guide_legend(keywidth = 2)

  curve_aes <- aes(
    x = prompt_length, y = accuracy,
    color = match, fill = match, shape = match, linetype = match
  )

  ggplot(summary_df, curve_aes) +
    geom_point(size = 2.5) +
    geom_line(linewidth = 0.8) +
    geom_ribbon(
      aes(ymin = accuracy - ste, ymax = accuracy + ste),
      alpha = 0.4, color = NA
    ) +
    geom_hline(yintercept = 0.5, linetype = "dashed") +
    scale_x_continuous(breaks = scales::pretty_breaks(7)) +
    scale_y_continuous(limits = c(0.0, 1.0)) +
    scale_color_manual(
      values = c("#d95f02", "#7570b3", "#4a4a4a"),
      aesthetics = c("color", "fill")
    ) +
    scale_linetype_manual(values = c("dotted", "dashed", "solid")) +
    guides(
      color = wide_key(),
      fill = wide_key(),
      shape = wide_key(),
      linetype = wide_key()
    ) +
    facet_wrap(~ model, nrow = 1) +
    theme(
      legend.title = element_blank(),
      legend.position = legend_position
    ) +
    labs(
      x = "Number of Exemplars (0 = Zero-shot)",
      y = "Accuracy (+95% CI)"
    )
}

# Render the combined (create_plot) and detailed (create_detailed_plot) figure
# for each dataset variant and write them as PDFs under analysis/figures/.
#
# Returns (invisibly) the character vector of file paths written, in write
# order: for each dataset, "<stem>-combined.pdf" then "<stem>-detailed.pdf".
generate_plots <- function() {
  # Results sub-directory -> file-name stem. The QA + instructions stem is
  # singular ("instruction") to keep the existing figure file names.
  stems <- c(
    "comps" = "comps",
    "comps-instructions" = "comps-instructions",
    "comps-qa" = "comps-qa",
    "comps-qa-instructions" = "comps-qa-instruction"
  )

  # Save one plot as a cairo PDF and return its path.
  save_pdf <- function(plot, path, height, width) {
    ggsave(path, plot = plot, height = height, width = width, dpi = 300, device = cairo_pdf)
    path
  }

  written <- map(names(stems), function(comps) {
    prefix <- file.path("analysis", "figures", stems[[comps]])
    c(
      save_pdf(create_plot(comps), paste0(prefix, "-combined.pdf"), height = 3.5, width = 12),
      save_pdf(create_detailed_plot(comps), paste0(prefix, "-detailed.pdf"), height = 4.25, width = 8.75)
    )
  })

  invisible(unlist(written, use.names = FALSE))
}

generate_plots()
create_plot("comps")

# All four dataset variants stacked vertically (patchwork). Only the top panel
# keeps a legend; ggplot2's value for hiding it is lower-case "none".
p1 <- create_plot("comps")
p2 <- create_plot("comps-instructions", legend_position = "none")
p3 <- create_plot("comps-qa", legend_position = "none")
p4 <- create_plot("comps-qa-instructions", legend_position = "none")

p5 <- p1 / p2 / p3 / p4


p5
# ggsave("analysis/figures/comps-qa-instruction-combined.pdf", p, height=3.5, width=12, dpi=300, device=cairo_pdf)

p1 <- create_detailed_plot("comps")
p2 <- create_detailed_plot("comps-instructions", legend_position = "none")
p3 <- create_detailed_plot("comps-qa", legend_position = "none")
p4 <- create_detailed_plot("comps-qa-instructions", legend_position = "none")

p5 <- p1 / p2 / p3 / p4
p5

ggsave("analysis/figures/main-results-stacked-detailed.pdf", p5, height = 12, width = 7, dpi = 300, device = cairo_pdf)

# Matched/mismatched accuracy for all four dataset variants, reshaped so the
# with- and without-instruction accuracies sit in side-by-side columns.
matched_results <- tribble(
  ~results_dir,            ~dataset_label, ~inst_label,
  "comps",                 "COMPS",        "no_instructions",
  "comps-instructions",    "COMPS",        "instructions",
  "comps-qa",              "COMPS-QA",     "no_instructions",
  "comps-qa-instructions", "COMPS-QA",     "instructions"
) %>%
  pmap(function(results_dir, dataset_label, inst_label) {
    get_matched_results(results_dir) %>%
      mutate(dataset = dataset_label, instructions = inst_label)
  }) %>%
  bind_rows() %>%
  select(-ste) %>%
  pivot_wider(names_from = instructions, values_from = accuracy)

# COMPS-QA with and without instructions, for the main figure.
# Names are results sub-directories; values are facet labels.
combined <- c(
  # "comps" = "COMPS",
  # "comps-instructions" = "COMPS w/ Inst",
  "comps-qa" = "COMPS-QA",
  "comps-qa-instructions" = "COMPS-QA w/ Inst"
) %>%
  imap(function(dataset_label, results_dir) {
    get_matched_results(results_dir) %>% mutate(dataset = dataset_label)
  }) %>%
  bind_rows()

# Main figure: dataset variants (rows) x models (columns), with curves for
# heuristic-works / heuristic-doesn't-work / average accuracy.
main_plot <- combined %>%
  ggplot(aes(prompt_length, accuracy, color = match, fill = match, shape = match, linetype = match)) +
  geom_point(size = 2.5) +
  geom_line(linewidth = 0.8) +
  geom_ribbon(aes(ymin = accuracy - ste, ymax = accuracy + ste), alpha = 0.4, color = NA) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  scale_y_continuous(limits = c(0.0, 1.0)) +
  scale_color_manual(values = c("#d95f02", "#7570b3", "#4a4a4a"), aesthetics = c("color", "fill")) +
  scale_linetype_manual(values = c("dotted", "dashed", "solid")) +
  guides(
    color = guide_legend(keywidth = 2),
    fill = guide_legend(keywidth = 2),
    shape = guide_legend(keywidth = 2),
    linetype = guide_legend(keywidth = 2)
  ) +
  facet_grid(dataset ~ model) +
  theme(
    legend.title = element_blank()
  ) +
  labs(
    x = "Number of Exemplars (0 = Zero-shot)",
    y = "Accuracy (+95% CI)"
  )

main_plot

# Save this plot explicitly rather than whatever last_plot() happens to be.
ggsave("analysis/figures/main-results.pdf", main_plot, height = 8, width = 9.8, dpi = 300, device = cairo_pdf)


# Detailed results for all four dataset variants: models (rows) x dataset and
# heuristic (columns).
combined_detailed <- bind_rows(
  get_detailed_results("comps") %>% mutate(dataset = "COMPS"),
  get_detailed_results("comps-instructions") %>% mutate(dataset = "COMPS w/ Inst"),
  get_detailed_results("comps-qa") %>% mutate(dataset = "COMPS-QA"),
  get_detailed_results("comps-qa-instructions") %>% mutate(dataset = "COMPS-QA w/ Inst")
)

detailed_plot <- combined_detailed %>%
  ggplot(aes(prompt_length, accuracy, color = match, fill = match, shape = match, linetype = match)) +
  geom_point(size = 2.5) +
  geom_line(linewidth = 0.8) +
  geom_ribbon(aes(ymin = accuracy - ste, ymax = accuracy + ste), alpha = 0.4, color = NA) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  scale_y_continuous(limits = c(0.0, 1.0)) +
  scale_color_manual(values = c("#d95f02", "#7570b3", "#4a4a4a"), aesthetics = c("color", "fill")) +
  scale_linetype_manual(values = c("dotted", "dashed", "solid")) +
  guides(
    color = guide_legend(keywidth = 2),
    fill = guide_legend(keywidth = 2),
    shape = guide_legend(keywidth = 2),
    linetype = guide_legend(keywidth = 2)
  ) +
  facet_grid(model ~ dataset + heuristic) +
  theme(
    legend.title = element_blank(),
    legend.position = "top",
    # `heuristic` labels from get_detailed_results() are HTML <span>s; a plain
    # element_text would print the raw tags in the top strips.
    strip.text.x = element_markdown()
  ) +
  labs(
    x = "Number of Exemplars (0 = Zero-shot)",
    y = "Accuracy (+95% CI)"
  )

detailed_plot

# Save this plot explicitly rather than whatever last_plot() happens to be.
ggsave("analysis/figures/main-results-detailed.pdf", detailed_plot, height = 9.5, width = 14.25, dpi = 300, device = cairo_pdf)
