library(tidyverse)
library(fs)
library(glue)
library(patchwork)
library(rstatix)
library(lme4)
library(lmerTest)
library(ggeffects)
library(emmeans)

# Global ggplot2 theme for every figure in this script:
# Times base font, legend above the panel, black axis tick labels.
theme_set(
  theme_bw(base_size = 15, base_family = "Times") +
    theme(
      axis.text = element_text(color = "black"),
      legend.position = "top"
      # panel.grid = element_blank() # uncomment to remove gridlines
    )
)

# All model identifiers that may appear in the results files, ordered by family/size.
LEVELS <- c(
  "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl",
  "opt-125m", "opt-350m", "opt-1.3b", "opt-2.7b", "opt-6.7b",
  "llama2-7b-hf", "Llama-2-13b-hf", "Mistral-7B-v0.1"
)
# Models analyzed below; MODEL_NAMES provides display labels, aligned by position.
MODEL_LIST <- c("gpt2-xl", "opt-6.7b", "llama2-7b-hf", "Llama-2-13b-hf", "Mistral-7B-v0.1")
MODEL_NAMES <- c("GPT-2 XL", "OPT-6.7b", "Llama-2-7b", "Llama-2-13b", "Mistral-7b")


#' Load per-item model predictions for one COMPS results directory.
#'
#' Reads every prediction CSV under data/results/{comps}/predictions/,
#' strips path/prefix noise from the model identifier, recodes the
#' correct/heuristic labels for display, and keeps only the models in
#' `model_list` (relabeled using the global MODEL_NAMES).
#'
#' @param comps Results subdirectory, e.g. "comps-instructions".
#' @param model_list Models to keep; must align positionally with MODEL_NAMES.
#' @param alt Unused; retained for backward compatibility with existing callers.
#' @return A tibble of predictions with cleaned `model` and `heuristic` factors.
load_predictions <- function(comps = "comps-instructions", model_list = MODEL_LIST, alt = FALSE) {
  # NOTE: dir_ls()'s `regexp` argument takes a regular expression, not a glob;
  # the anchored "\\.csv$" matches files ending in .csv (the previous "*.csv"
  # was a glob pattern with a dangling `*` quantifier as a regex).
  dir_ls(glue("data/results/{comps}/predictions/"), regexp = "\\.csv$") %>%
    map_df(read_csv, .id = "model") %>%
    mutate(
      # Reduce the file path to the bare model id. The extension pattern is
      # anchored and escaped: an unanchored ".csv" would let the regex
      # wildcard "." match any character followed by "csv" anywhere.
      model = str_remove(model, "\\.csv$"),
      model = str_remove(model, "data/results/(comps\\-qa|comps)(-instructions)?/predictions/"),
      model = str_remove(model, "(_home_shared_|facebook_|mistralai_|meta-llama_)"),
      model = str_remove(model, "_predictions"),
      model = factor(model, levels = LEVELS),
      # Recode condition labels for display ("alt-first" -> "Alt-First", etc.).
      correct = str_replace(correct, "first", "First"),
      correct = str_replace(correct, "second", "Recent"),
      correct = str_replace(correct, "alt", "Alt"),
      heuristic = str_replace(heuristic, "first", "First"),
      heuristic = str_replace(heuristic, "second", "Recent"),
      heuristic = str_replace(heuristic, "alt", "Alt"),
      heuristic = str_replace(heuristic, "none", "0-shot"),
      heuristic = factor(heuristic, levels = c("0-shot", "First", "Recent", "Alt-First", "Alt-Recent"))
    ) %>%
    # Unrecognized model ids become NA in the factor() call above; drop them,
    # then keep and relabel only the requested models.
    filter(!is.na(model)) %>%
    filter(model %in% model_list) %>%
    mutate(
      model = factor(model, levels = model_list, labels = MODEL_NAMES)
    )
}

# Combine the with-instructions and no-instructions COMPS-QA runs into one
# data frame and code the experimental factors used by the GLMMs below.
#
# NOTE(review): `instruction_present` is TRUE exactly when
# instruction_type == "none", i.e. when instructions are ABSENT — the name
# reads inverted relative to its value. Confirm before interpreting
# coefficient signs; renaming would change the models' reference levels.
comps <- load_predictions("comps-qa-instructions") %>%
  bind_rows(
    load_predictions("comps-qa") %>% mutate(instruction_type = "none")
  ) %>%
  mutate(
    instruction_present = factor(instruction_type == "none"),
    instruction_type = factor(instruction_type),
    # TRUE when any in-context examples were supplied in the prompt.
    icl = factor(prompt_length > 0)
  )

# Restrict the modeling data to a single language model (Mistral-7b).
# (To additionally restrict prompt length, add: filter(prompt_length == 6).)
model_data <- filter(comps, model == "Mistral-7b")

# Fit a binomial GLMM on model_data with random intercepts for
# instruction_type and item (idx); `form` supplies the fixed-effect structure.
# Extracted so the data/family arguments are specified once rather than
# hand-copied across six fits.
fit_comps_glmm <- function(form) {
  lme4::glmer(form, data = model_data, family = binomial(link = "logit"))
}

# Full model: both main effects plus their interaction.
fit1 <- fit_comps_glmm(prediction ~ icl * instruction_present + (1 | instruction_type) + (1 | idx))

# Nested models for likelihood-ratio tests against fit1:
# no instruction_present main effect (icl main + interaction only).
fit1_icl_main <- fit_comps_glmm(prediction ~ icl + icl:instruction_present + (1 | instruction_type) + (1 | idx))

# No icl main effect (instruction_present main + interaction only).
fit1_instructions_main <- fit_comps_glmm(prediction ~ instruction_present + icl:instruction_present + (1 | instruction_type) + (1 | idx))

# instruction_present only.
fit2 <- fit_comps_glmm(prediction ~ instruction_present + (1 | instruction_type) + (1 | idx))

# icl only.
fit3 <- fit_comps_glmm(prediction ~ icl + (1 | instruction_type) + (1 | idx))

# Both main effects, no interaction.
fit4 <- fit_comps_glmm(prediction ~ icl + instruction_present + (1 | instruction_type) + (1 | idx))

# Inspect fixed-effect estimates and random-effect variances for each fit.
summary(fit1)
summary(fit1_icl_main)
summary(fit1_instructions_main)
summary(fit2)
summary(fit3)
summary(fit4)

# Likelihood-ratio tests of nested models against the full model fit1.
anova(fit1, fit2)  # tests icl main effect + interaction jointly

anova(fit1, fit1_icl_main)  # tests the instruction_present main effect

anova(fit1, fit1_instructions_main)  # tests the icl main effect

anova(fit1, fit3)  # tests instruction_present main effect + interaction jointly

anova(fit1, fit4)  # tests the icl:instruction_present interaction alone

# Interaction plot of estimated marginal means from the full model.
emmip(fit1, ~ instruction_present * icl)

# Model-based predicted probabilities for each cell of the 2x2 design,
# plotted with the same base theme used elsewhere in the script.
fit1_preds <- ggpredict(fit1, c("instruction_present", "icl"))
plot(fit1_preds) +
  theme_bw(base_size = 15, base_family = "Times")

# NOTE(review): the original comment promised a t-test on ICL results when
# instructions are present, but no test was ever run — this section only
# visualizes per-condition accuracy. The base plot() call on a grouped tibble
# of factors produced an uninterpretable pairs plot; replaced with an explicit
# ggplot, and the grouping is dropped after summarizing.
comps %>%
  group_by(model, instruction_type) %>%
  summarize(acc = mean(prediction == 1), .groups = "drop") %>%
  ggplot(aes(x = instruction_type, y = acc)) +
  geom_col() +
  facet_wrap(~ model) +
  labs(x = "Instruction type", y = "Accuracy")
