library(tidyverse)
library(fs)

# Global ggplot2 theme for every figure in this script: black-and-white base
# theme, Times at base size 15, legend above the panel, black axis text.
theme_set(
  theme_bw(base_family = "Times", base_size = 15) +
    theme(
      axis.text = element_text(color = "black"),
      legend.position = "top"
      # panel.grid = element_blank()
    )
)

# Load every per-model results CSV and normalise the labels used downstream.
#
# - `model` is filled from `.id`; the steps below strip the directory prefix,
#   the file extension and the organisation prefix, leaving the bare model
#   name, which is then turned into a factor with a fixed display order.
#   Files whose name is not in `levels` become NA and are dropped at the end.
# - `correct` and `heuristic` are recoded from the raw lowercase labels to the
#   display labels ("First", "Recent", "Alt-*"); a heuristic of "none" becomes
#   "0-shot".
instresults <- dir_ls(
  "data/results/comps-qa-instructions/",
  # Use a glob rather than the regex "*.csv": as a regex the dot is a wildcard
  # and the pattern is unanchored, so it matches any path containing
  # "<any char>csv" instead of only files with a .csv extension.
  glob = "*.csv"
) %>%
  map_df(read_csv, .id = "model") %>%
  # replace_na(list(heuristic = "0-shot")) %>%
  mutate(
    # Escape the dot and anchor at the end so only the extension is removed.
    model = str_remove(model, "\\.csv$"),
    model = str_remove(
      model,
      "data/results/(comps\\-qa|comps)(-instructions)?/"
    ),
    model = str_remove(
      model,
      "(_home_shared_|facebook_|mistralai_|meta-llama_)"
    ),
    model = factor(
      model,
      levels = c(
        "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl",
        "opt-125m", "opt-350m", "opt-1.3b", "opt-2.7b", "opt-6.7b",
        "llama2-7b-hf", "Llama-2-13b-hf", "Mistral-7B-v0.1"
      )
    ),
    # Same three substitutions for both label columns. str_replace() only
    # touches the first match, so "alt-first" becomes "Alt-First".
    across(
      c(correct, heuristic),
      function(label) {
        label %>%
          str_replace("first", "First") %>%
          str_replace("second", "Recent") %>%
          str_replace("alt", "Alt")
      }
    ),
    heuristic = str_replace(heuristic, "none", "0-shot"),
    heuristic = factor(
      heuristic,
      levels = c("0-shot", "First", "Recent", "Alt-First", "Alt-Recent")
    )
  ) %>%
  filter(!is.na(model))

# Sanity check: number of rows loaded per model.
instresults %>% count(model)

# Zero-shot rows, as loaded.
inst_zero_shot_base <- filter(instresults, heuristic == "0-shot")

# Replicate the zero-shot rows once under each heuristic label (presumably so
# the zero-shot point shows up in every heuristic panel of the plots), then
# restore the shared factor levels.
inst_zero_shot <- c("First", "Alt-First", "Recent", "Alt-Recent") %>%
  map(function(relabel) {
    # `.env$` avoids picking up a data column that happens to share the name.
    mutate(inst_zero_shot_base, heuristic = .env$relabel)
  }) %>%
  bind_rows() %>%
  mutate(
    heuristic = factor(
      heuristic,
      levels = c("0-shot", "First", "Recent", "Alt-First", "Alt-Recent")
    )
  )

# For every (model, prompt_domain, prompt_id, heuristic, prompt_length,
# correct) cell, keep only the single highest-accuracy row.
# NOTE(review): `instresults_full` is assigned again further down (the
# `inner_join(best_instructions)` version), so this result is overwritten
# before anything reads it. Confirm whether this step is still needed.
instresults_full <- instresults %>%
  filter(heuristic != "0-shot") %>%
  bind_rows(inst_zero_shot) %>%
  group_by(
    model, prompt_domain, prompt_id, heuristic, prompt_length, correct
  ) %>%
  nest() %>%
  mutate(
    # Sort each nested cell by descending accuracy and take its first row.
    best = map(data, ~ .x %>% arrange(-accuracy) %>% slice(1))
  ) %>%
  select(-data) %>%
  unnest(best) %>%
  ungroup()

# For each model, find the instruction type with the highest mean accuracy
# over the non-"Alt" heuristics (zero-shot rows included via their replicated
# copies). Ties keep every tied instruction type.
best_instructions <- bind_rows(
  instresults %>%
    filter(heuristic != "0-shot"),
  inst_zero_shot
) %>%
  filter(!str_detect(heuristic, "Alt")) %>%
  group_by(model, instruction_type) %>%
  # Keep the `model` grouping explicitly so max() below is taken per model.
  summarize(acc = mean(accuracy), .groups = "drop_last") %>%
  filter(acc == max(acc)) %>%
  ungroup()

# Restrict the full result set (all heuristics, "Alt" included) to each
# model's best instruction type; the winning mean is carried along as `acc`.
instresults_full <- bind_rows(
  instresults %>%
    filter(heuristic != "0-shot"),
  inst_zero_shot
) %>%
  # Explicit keys: a natural join would silently change meaning if the two
  # tables ever shared another column name.
  inner_join(best_instructions, by = c("model", "instruction_type"))

# instresults_alt <- bind_rows(
#   instresults %>%
#     filter(heuristic != "0-shot"),
#   inst_zero_shot
# ) %>%
#   mutate(
#     match = case_when(
#       correct == heuristic ~ "Matched",
#       TRUE ~ "Mismatched"
#     ),
#   ) %>%
#   filter(!str_detect(heuristic, "Alt")) %>%
#   group_by(model, prompt_domain, prompt_id, prompt_length, correct, match) %>%
#   nest() %>%
#   mutate(
#     best = map(data, function(x) {
#       x %>% arrange(-accuracy) %>% slice(1)
#     })
#   ) %>% select(-data) %>% unnest(best) %>%
#   ungroup()

# bind_rows(
#   instresults %>%
#     filter(heuristic != "0-shot"),
#   inst_zero_shot
# ) %>% 
#   filter(str_detect(model, "Llama-2-13b-hf")) %>%
#   group_by(instruction_type, heuristic, prompt_length, correct) %>%
#   summarize(
#     ste = 1.96 * plotrix::std.error(accuracy),
#     accuracy = mean(accuracy)
#   ) %>%
#   ggplot(aes(prompt_length, accuracy, color = correct, shape = correct)) +
#   geom_point(size = 2.5) +
#   geom_line(linewidth = 0.8) +
#   geom_errorbar(aes(ymin = accuracy-ste, ymax=accuracy+ste), width = 0.3) +
#   geom_hline(yintercept = 0.5, linetype="dashed") +
#   scale_x_continuous(breaks = scales::pretty_breaks(7)) +
#   scale_y_continuous(limits = c(0.0, 1.0)) +
#   scale_color_manual(values = c("#7570b3", "#d95f02")) +
#   ggh4x::facet_grid2(heuristic ~ instruction_type, scales = "free_x", independent = "x") +
#   labs(
#     x = "Number of Exemplars (0 = Zero-shot)",
#     y = "Accuracy",
#     color = "Test ground-truth",
#     fill = "Test ground-truth",
#     shape = "Test ground-truth",
#   )

# Mean accuracy per model x heuristic x ground-truth label x number of
# exemplars, for the five models shown in the figure ("Alt" heuristics
# excluded). `ste` is 1.96 x the standard error, i.e. the half-width of the
# band drawn around the mean.
plot_data <- instresults_full %>%
  filter(
    model %in% c(
      "gpt2-xl", "opt-6.7b", "llama2-7b-hf", "Llama-2-13b-hf",
      "Mistral-7B-v0.1"
    )
  ) %>%
  filter(!str_detect(heuristic, "Alt")) %>%
  group_by(model, heuristic, correct, prompt_length) %>%
  summarize(
    ste = 1.96 * plotrix::std.error(accuracy),
    accuracy = mean(accuracy),
    # Return an ungrouped tibble so later steps that reuse `plot_data` do not
    # inherit a leftover grouping.
    .groups = "drop"
  )

# Accuracy vs. number of exemplars: one row of panels per heuristic, one
# column per model, coloured by the test item's ground-truth label.
plot_data %>%
  ggplot(
    aes(
      prompt_length, accuracy,
      color = correct, fill = correct, shape = correct
    )
  ) +
  geom_point(size = 2.5) +
  geom_line(linewidth = 0.8) +
  geom_ribbon(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    alpha = 0.4, color = NA
  ) +
  # geom_errorbar(aes(ymin = accuracy-ste, ymax=accuracy+ste), width = 0.3) +
  # Dashed reference line at accuracy 0.5.
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  scale_y_continuous(limits = c(0.0, 1.0)) +
  scale_color_manual(
    values = c("#7570b3", "#d95f02"),
    aesthetics = c("color", "fill")
  ) +
  ggh4x::facet_grid2(heuristic ~ model, scales = "free_x", independent = "x") +
  labs(
    x = "Number of Exemplars (0 = Zero-shot)",
    y = "Accuracy",
    color = "Test ground-truth",
    fill = "Test ground-truth",
    shape = "Test ground-truth"
  )

## match mismatch

# Accuracy split by whether the in-context heuristic equals the test item's
# ground-truth label ("Match") or not ("Mismatch"), per model and number of
# exemplars. `ste` is 1.96 x the standard error.
plot_data_alt <- instresults_full %>%
  filter(
    model %in% c(
      "gpt2-xl", "opt-6.7b", "llama2-7b-hf", "Llama-2-13b-hf",
      "Mistral-7B-v0.1"
    )
  ) %>%
  filter(!str_detect(heuristic, "Alt")) %>%
  mutate(
    match = case_when(
      heuristic == correct ~ "Match",
      TRUE ~ "Mismatch"
    )
  ) %>%
  group_by(model, match, prompt_length) %>%
  summarize(
    ste = 1.96 * plotrix::std.error(accuracy),
    accuracy = mean(accuracy),
    .groups = "drop"
  )

# Same subset, averaged over both match conditions.
avg_accuracy <- instresults_full %>%
  filter(
    model %in% c(
      "gpt2-xl", "opt-6.7b", "llama2-7b-hf", "Llama-2-13b-hf",
      "Mistral-7B-v0.1"
    )
  ) %>%
  filter(!str_detect(heuristic, "Alt")) %>%
  group_by(model, prompt_length) %>%
  summarize(
    ste = 1.96 * plotrix::std.error(accuracy),
    accuracy = mean(accuracy),
    .groups = "drop"
  )

bind_rows(
  plot_data_alt,
  avg_accuracy %>% mutate(match = "Avg")
) %>%
  mutate(
    match = factor(
      match,
      levels = c("Match", "Mismatch", "Avg"),
      labels = c("Heuristics work", "Heuristics don't work", "Avg Performance")
    ),
    # Suppress the band at the zero-shot point. NA_real_ (not bare NA, which
    # is logical) keeps both branches double, as older dplyr releases require
    # all case_when() outputs to share one type.
    ste = case_when(
      prompt_length == 0 ~ NA_real_,
      TRUE ~ ste
    )
  ) %>%
# plot_data_alt %>% 
  ggplot(
    aes(prompt_length, accuracy, color = match, fill = match, shape = match)
  ) +
  geom_point(size = 2.5) +
  geom_line(linewidth = 0.8) +
  # geom_point(data=avg_accuracy, aes(prompt_length, accuracy, color, ))
  geom_ribbon(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    alpha = 0.4, color = NA
  ) +
  # geom_errorbar(aes(ymin = accuracy-ste, ymax=accuracy+ste), width = 0.3) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  scale_y_continuous(limits = c(0.0, 1.0)) +
  # scale_color_brewer(palette = "Dark2", direction = -1) +
  scale_color_manual(
    values = c("#7570b3", "#d95f02", "#4a4a4a"),
    aesthetics = c("color", "fill")
  ) +
  facet_wrap(~ model, nrow = 1) +
  theme(
    legend.title = element_blank()
  ) +
  labs(
    x = "Number of Exemplars (0 = Zero-shot)",
    y = "Accuracy"
    # color = "Train-test Heuristics Setup",
    # fill = "Train-test Heuristics Setup",
    # shape = "Train-test Heuristics Setup",
  )

# Interactive inspection of the rows behind the figures above. View() needs a
# data viewer, so only call it in an interactive session; this keeps
# non-interactive runs of the script from stopping here.
if (interactive()) {
  instresults_full %>%
    filter(
      model %in% c(
        "gpt2-xl", "opt-6.7b", "llama2-7b-hf", "Llama-2-13b-hf",
        "Mistral-7B-v0.1"
      )
    ) %>%
    filter(!str_detect(heuristic, "Alt")) %>%
    View()
}

# NOTE(review): unfinished stub. The `model` argument is given no value, so
# no real data is defined here. Presumably a placeholder for a hand-built
# "ideal" data set; TODO: fill in the columns or remove.
fake_data_ideal <- tibble(
  model = 
)

# Single-model version of the heuristic x ground-truth figure, built from the
# gpt2-xl rows relabelled as "MODEL".
# NOTE(review): accuracy is shifted by +10 and ste by +100, which puts every
# value outside the y limits c(0, 1). Presumably intentional, to render an
# empty template panel; confirm.
plot_data %>%
  filter(model == "gpt2-xl") %>%
  mutate(
    accuracy = accuracy + 10,
    ste = ste + 100,
    model = "MODEL"
  ) %>%
  ggplot(
    aes(
      x = prompt_length, y = accuracy,
      color = correct, fill = correct, shape = correct
    )
  ) +
  geom_point(size = 2.5) +
  geom_line(linewidth = 0.8) +
  # geom_ribbon(aes(ymin = accuracy-ste, ymax=accuracy+ste), alpha = 0.4, color = NA) +
  geom_errorbar(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    width = 0.3
  ) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  ggh4x::facet_grid2(heuristic ~ model, scales = "free_x", independent = "x") +
  scale_color_manual(values = c("#7570b3", "#d95f02")) +
  scale_y_continuous(limits = c(0.0, 1.0)) +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  labs(
    x = "Number of Exemplars (0 = Zero-shot)",
    y = "Accuracy",
    color = "Test ground-truth",
    fill = "Test ground-truth",
    shape = "Test ground-truth"
  )


# Accuracy vs. number of exemplars, one column of panels per instruction type
# and one row per heuristic, pooled over the five plotted models.
#
# This must start from the row-level `instresults_full`: the summarised
# `plot_data` has no `instruction_type` column, so grouping it by that column
# fails. The model / "Alt" filters mirror the ones used to build `plot_data`.
# NOTE(review): `instresults_full` only holds each model's best instruction
# type. Confirm that this is the intended input for a per-instruction figure.
instresults_full %>%
  filter(
    model %in% c(
      "gpt2-xl", "opt-6.7b", "llama2-7b-hf", "Llama-2-13b-hf",
      "Mistral-7B-v0.1"
    )
  ) %>%
  filter(!str_detect(heuristic, "Alt")) %>%
  group_by(instruction_type, heuristic, correct, prompt_length) %>%
  summarize(
    ste = 1.96 * plotrix::std.error(accuracy),
    accuracy = mean(accuracy),
    .groups = "drop"
  ) %>%
  ggplot(aes(prompt_length, accuracy, color = correct, shape = correct)) +
  geom_point(size = 3) +
  geom_errorbar(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    width = 0.3
  ) +
  geom_line() +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  scale_y_continuous(limits = c(0.0, 1.0)) +
  # facet_wrap(~instruction_type, nrow=1)
  ggh4x::facet_grid2(
    heuristic ~ instruction_type,
    scales = "free_x", independent = "x"
  ) +
  scale_color_manual(values = c("#7570b3", "#d95f02")) +
  labs(
    x = "Number of Exemplars (0 = Zero-shot)",
    y = "Accuracy",
    color = "Correct answer in Test",
    shape = "Correct answer in Test"
  )


# matched vs mismatched
#
# Accuracy vs. number of exemplars per instruction type; colour is the test
# item's ground-truth label, line type is whether the in-context heuristic
# equals that label.
#
# This must start from the row-level `instresults_full`: the summarised
# `plot_data` has neither an `instruction_type` nor a `match` column, so
# grouping it by them fails. `match` is derived with the same rule and labels
# as in the match/mismatch figure above.
# NOTE(review): `instresults_full` only holds each model's best instruction
# type. Confirm that this is the intended input for a per-instruction figure.
instresults_full %>%
  filter(
    model %in% c(
      "gpt2-xl", "opt-6.7b", "llama2-7b-hf", "Llama-2-13b-hf",
      "Mistral-7B-v0.1"
    )
  ) %>%
  filter(!str_detect(heuristic, "Alt")) %>%
  mutate(
    match = case_when(
      heuristic == correct ~ "Match",
      TRUE ~ "Mismatch"
    )
  ) %>%
  group_by(instruction_type, match, correct, prompt_length) %>%
  summarize(
    ste = 1.96 * plotrix::std.error(accuracy),
    accuracy = mean(accuracy),
    .groups = "drop"
  ) %>%
  ggplot(
    aes(
      prompt_length, accuracy,
      color = correct, shape = correct, linetype = match
    )
  ) +
  # geom_point(size = 3) +
  # Error bars stay solid regardless of the match line type.
  geom_errorbar(
    aes(ymin = accuracy - ste, ymax = accuracy + ste),
    width = 0.3, linetype = "solid"
  ) +
  geom_line(linewidth = 0.7) +
  scale_x_continuous(breaks = scales::pretty_breaks(7)) +
  scale_y_continuous(limits = c(0.0, 1.0)) +
  facet_wrap(~instruction_type, nrow = 1) +
  # ggh4x::facet_grid2(match ~ instruction_type, scales = "free_x", independent = "x") +
  scale_color_manual(values = c("#7570b3", "#d95f02")) +
  labs(
    x = "Number of Exemplars (0 = Zero-shot)",
    y = "Accuracy",
    color = "Correct answer in Test",
    shape = "Correct answer in Test"
  )


# Wide table of mean accuracy per instruction type x heuristic x number of
# exemplars, with one column per ground-truth label and their sum in `total`.
#
# Two fixes relative to the labels produced when the data is loaded:
# - heuristics are recoded to "Alt-First" / "Alt-Recent", so the exclusion
#   pattern must be "Alt" (lowercase "alt" never matches and filtered nothing);
# - `correct` is recoded to "First" / "Recent", so those are the column names
#   pivot_wider() creates (there are no `first` / `second` columns to add).
# View() needs a data viewer, so it is only called in interactive sessions.
if (interactive()) {
  instresults_full %>%
    filter(!str_detect(heuristic, "Alt")) %>%
    group_by(instruction_type, heuristic, correct, prompt_length) %>%
    summarize(
      accuracy = mean(accuracy),
      # Drop the grouping so `correct` is free to be pivoted into columns.
      .groups = "drop"
    ) %>%
    pivot_wider(names_from = correct, values_from = accuracy) %>%
    mutate(
      total = First + Recent
    ) %>%
    View()
}
