library(tidyverse)
library(politeness)

#######################
#' Load Data
#' 
#' This section loads data that cannot be shared. model_data contains a unique participant identifier column,
#' outcome data, treatment indicator, and demographic controls. politeness_receiver_data and ps_data_output
#' contain the politeness scores created as described in Section 4 of the paper. The former contains the scores
#' from the R package politeness, the latter the scores from the Python library convokit.
#' 
#' At the bottom of this file, we provide residualised versions of the Y, T, and M variables where the linear
#' projection of X (the control variables) has been removed. These data permit replication of the regression
#' results and mediation sensitivity without the identifying control variables.
#' 
#' This analysis takes less than 15 minutes to run on a standard MacBook Pro laptop.
#' 
#######################

# Run the upstream preparation script that lives one directory above this one;
# it is expected to create `model_data`, which the code below extends.
# `chdir = TRUE` switches to the script's own directory only while it is being
# sourced and restores the previous working directory afterwards (also when the
# script stops with an error), so relative paths inside the script resolve as
# before without hard-coding the name of this folder.
source("../analysis_to_model_data.R", chdir = TRUE)



# Receiver-level politeness scores: one file for the scores produced with the
# R package `politeness`, one for the scores produced with convokit.
politeness_receiver_data <- read_csv(
  "../gt_input/politeness_receiver_data.csv"
)
ps_data_output <- read_csv(
  "../convokit/convokit_user_scores.csv"
)

# Attach both score tables to the participant-level data. In each table the
# `receiver` column is renamed to `App_install_code` first; no `by` is given,
# so the joins match on every column name the two sides share.
model_data <- model_data %>%
  left_join(rename(politeness_receiver_data, App_install_code = receiver)) %>%
  left_join(rename(ps_data_output, App_install_code = receiver))

# Averages of the raw score totals (total divided by `count`; presumably the
# number of messages received -- confirm against the score files). These are
# computed before the totals themselves are rescaled in the next step.
model_data <- model_data %>%
  mutate(
    politeness_index_mean = polite_index_all / count,
    ps_convokit_mean = ps_convokit / count
  ) %>%
  # Express each politeness measure in units of its own standard deviation.
  mutate_at(
    vars(starts_with("polite_index"), projection, ps_sum, ps_convokit),
    function(x) x / sd(x, na.rm = TRUE)
  )

# final dataset
mediation_data <- model_data %>%
  # keep control and completes only
  filter(displayCondition == "ROBUST" | fully_compliant == 1) %>%
  mutate(
    # set politeness index as mediator; it is only defined for treated rows
    mediator = ifelse(treatment_assigned == 1, polite_index, NA),
    mediator = mediator / sd(mediator, na.rm = TRUE),
    # na.rm = TRUE matches the other rescalings: without it a single missing
    # outcome would turn the whole rescaled outcome column into NA.
    depolarization_index_winsor =
      depolarization_index_winsor /
        sd(depolarization_index_winsor, na.rm = TRUE),
    immigration = topic == "immigration"
  )

# Names of the demographic control columns used in the regressions below.
demogs <- c(
  "gender", "age", "college", "white",
  "regionMidwest", "regionSouth", "regionWest"
)

# Distribution of the politeness index among treated participants,
# one panel per discussion topic.
mediation_data %>%
  filter(treatment_assigned == 1) %>%
  mutate(
    topic = ifelse(topic == "immigration", "Immigration", "Gun Control")
  ) %>%
  ggplot(aes(x = polite_index)) +
  geom_histogram(bins = 50) +
  facet_wrap(~topic, ncol = 1) +
  theme_bw()

# Politeness index against depolarization change among treated participants,
# with a separate linear fit per topic.
mediation_data %>%
  filter(treatment_assigned == 1) %>%
  mutate(
    Topic = ifelse(topic == "immigration", "Immigration", "Gun Control")
  ) %>%
  ggplot(
    aes(x = polite_index, y = depolarization_index_change, color = Topic)
  ) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE) +
  theme_bw() +
  labs(
    x = latex2exp::TeX("Partner's Politeness ($\\sigma$)"),
    y = latex2exp::TeX("Depolarization ($\\sigma$)")
  ) +
  theme(legend.position = "bottom")
# NOTE(review): the file name below is misspelled; it is kept unchanged because
# other files may refer to it.
ggsave(
  "output/depolarization_polieness_by_toipc.png",
  width = 8,
  height = 7
)

###############################
### preliminary regressions ###
###############################

# Control variables entering every regression in this section.
cvars <- c(demogs)

# Outcome regressed on the observed mediator among treated participants:
# all treated, then the democrat == 1 and democrat == 0 subsamples.
y.model.treat <- lm(
  str_c(
    "depolarization_index_winsor ~ mediator + democrat + ",
    str_c(cvars, collapse = " + ")
  ),
  data = filter(mediation_data, treatment_assigned == 1)
)
y.model.treat.dem <- lm(
  str_c(
    "depolarization_index_winsor ~ mediator + democrat + ",
    str_c(cvars, collapse = " + ")
  ),
  data = filter(mediation_data, treatment_assigned == 1, democrat == 1)
)
y.model.treat.rep <- lm(
  str_c(
    "depolarization_index_winsor ~ mediator + democrat + ",
    str_c(cvars, collapse = " + ")
  ),
  data = filter(mediation_data, treatment_assigned == 1, democrat == 0)
)

# Text table of the three fits; the demographic controls are omitted from
# the printed rows.
stargazer::stargazer(
  y.model.treat, y.model.treat.dem, y.model.treat.rep,
  type = "text",
  column.labels = c("All", "Dem", "Rep"),
  covariate.labels = c("Politeness", "Democrat"),
  dep.var.labels = "Depolarization Index",
  dep.var.labels.include = FALSE,
  omit = c(str_c(demogs, collapse = "|")),
  omit.stat = c("rsq", "adj.rsq", "ser", "f"),
  omit.table.layout = "lon",
  star.cutoffs = c(0.05, 0.01, 0.001),
  no.space = TRUE,
  column.sep.width = "2.5pt"
)

# Mediator model used for imputation: fitted on treated participants only,
# the rows for which the mediator is defined.
mhat.mod <- lm(
  mediator ~ democrat + gender + age + college + white +
    regionMidwest + regionSouth + regionWest,
  data = filter(mediation_data, treatment_assigned == 1)
)

# Treated rows keep their observed mediator; control rows receive the value
# predicted by the model above from their covariates.
mediation_data$mediator_mhat <- case_when(
  mediation_data$treatment_assigned == 1 ~ mediation_data$mediator,
  mediation_data$treatment_assigned == 0 ~
    predict(mhat.mod, newdata = mediation_data)
)


# Mediator on controls, using every row of the analysis sample.
m.model.all <- lm(
  str_c("mediator ~ democrat + ", str_c(cvars, collapse = " + ")),
  data = mediation_data
)

# Outcome models: (1) treated rows only, (2) full sample without the mediator,
# (3) full sample with the observed-or-imputed mediator.
y.model.treat <- lm(
  str_c(
    "depolarization_index_winsor ~ treatment_assigned + mediator_mhat + democrat + ",
    str_c(cvars, collapse = " + ")
  ),
  data = filter(mediation_data, treatment_assigned == 1)
)
y.model.noM <- lm(
  str_c(
    "depolarization_index_winsor ~ treatment_assigned + democrat + ",
    str_c(cvars, collapse = " + ")
  ),
  data = mediation_data
)
y.model.all <- lm(
  str_c(
    "depolarization_index_winsor ~ treatment_assigned + mediator_mhat + democrat + ",
    str_c(cvars, collapse = " + ")
  ),
  data = mediation_data
)

# Text table of the three outcome models, demographic controls omitted.
# (Column labels are currently not set:
#  column.labels = c("T Only", "Total Effect", ""))
stargazer::stargazer(
  y.model.treat, y.model.noM, y.model.all,
  type = "text",
  covariate.labels = c("Treatment", "Politeness", "Democrat"),
  dep.var.labels.include = FALSE,
  omit = c(str_c(demogs, collapse = "|")),
  omit.stat = c("f", "ser", "rsq", "adj.rsq"),
  star.cutoffs = c(0.05, 0.01, 0.001),
  digits = 3,
  no.space = TRUE,
  column.sep.width = "2.5pt"
)

# Upper bound on beta: the treatment coefficient from the model without the
# mediator divided by the mediator coefficient from the treated-only model.
beta_ub <- y.model.noM$coefficients["treatment_assigned"] /
  y.model.treat$coefficients["mediator_mhat"]

#' Mediation effects for an assumed treatment effect on the mediator
#'
#' Shifts `mediator_mhat` down by `control_delta` in the control rows
#' (`treatment_assigned == 0`) of a copy of `data`, refits the outcome and
#' mediator regressions on the shifted copy, and decomposes the treatment
#' effect into indirect, direct and total effects.
#'
#' @param control_delta Amount subtracted from `mediator_mhat` in control rows.
#' @param data Data frame containing `depolarization_index_winsor`,
#'   `treatment_assigned`, `mediator_mhat`, `democrat`, `gender`, `age`,
#'   `college`, `white`, `regionMidwest`, `regionSouth` and `regionWest`.
#'   The caller's object is not modified.
#' @param print If `TRUE`, print both regressions and, when available, the
#'   summary of the `mediation::mediate()` result.
#' @param seed Optional RNG seed making the bootstrap reproducible. With
#'   `NULL` (the default) the RNG is re-initialised before bootstrapping.
#' @param mediation_package If `TRUE`, estimates and confidence intervals come
#'   from a nonparametric bootstrap via `mediation::mediate()`. If `FALSE`,
#'   only product-of-coefficients point estimates from the two linear fits are
#'   returned, with `NA` confidence limits.
#' @param sims Number of bootstrap simulations.
#' @return A list with `ide`/`ide_ci` (indirect effect), `de`/`de_ci` (direct
#'   effect), `tot`/`tot_ci` (total effect), `control_delta`, and
#'   `med.results` (the `mediate` object, or `NULL` when
#'   `mediation_package = FALSE`).
get_mediation_effect <- function(control_delta, data = mediation_data,
                                 print = FALSE, seed = NULL,
                                 mediation_package = TRUE, sims = 2000) {
  if (!is.null(seed)) {
    # Honour an explicit seed so that the bootstrap can be reproduced.
    set.seed(seed)
  } else if (mediation_package) {
    # No seed requested: re-initialise the RNG so bootstrap draws do not
    # depend on whatever state the calling process happened to have.
    set.seed(NULL)
  }

  # Subtract the assumed effect from the mediator values in control. which()
  # drops rows whose treatment status is NA instead of failing on them.
  shifted_data <- data
  control_rows <- which(shifted_data$treatment_assigned == 0)
  shifted_data$mediator_mhat[control_rows] <-
    shifted_data$mediator_mhat[control_rows] - control_delta

  # Outcome and mediator regressions on the shifted data.
  y.model <- lm(
    depolarization_index_winsor ~ treatment_assigned + mediator_mhat +
      democrat + gender + age + college + white +
      regionMidwest + regionSouth + regionWest,
    data = shifted_data
  )
  m.model <- lm(
    mediator_mhat ~ treatment_assigned + democrat + gender + age + college +
      white + regionMidwest + regionSouth + regionWest,
    data = shifted_data
  )

  if (mediation_package) {
    med.results <- mediation::mediate(
      m.model, y.model,
      treat = "treatment_assigned", mediator = "mediator_mhat",
      boot = TRUE, sims = sims
    )
    # d0 - indirect effect, z0 - direct effect, tau - total effect
    effects <- list(
      ide = med.results$d0, ide_ci = med.results$d0.ci,
      de = med.results$z0, de_ci = med.results$z0.ci,
      tot = med.results$tau.coef, tot_ci = med.results$tau.ci
    )
  } else {
    # Without the mediation package: point estimates straight from the two
    # linear fits; no confidence intervals are available.
    med.results <- NULL
    t_on_m <- unname(coef(m.model)["treatment_assigned"])
    m_on_y <- unname(coef(y.model)["mediator_mhat"])
    t_on_y <- unname(coef(y.model)["treatment_assigned"])
    no_ci <- c(NA_real_, NA_real_)
    effects <- list(
      ide = t_on_m * m_on_y, ide_ci = no_ci,
      de = t_on_y, de_ci = no_ci,
      tot = t_on_m * m_on_y + t_on_y, tot_ci = no_ci
    )
  }

  if (print) {
    stargazer::stargazer(y.model, m.model,
      star.cutoffs = c(0.05, 0.01, 0.001), type = "text"
    )
    if (!is.null(med.results)) {
      # Values are not auto-printed inside a function, so print explicitly.
      print(summary(med.results))
    }
  }

  list(
    ide = effects$ide, ide_ci = effects$ide_ci,
    de = effects$de, de_ci = effects$de_ci,
    tot = effects$tot, tot_ci = effects$tot_ci,
    control_delta = control_delta,
    med.results = med.results
  )
}


# Fine grid close to zero, used to detect the first significant effect.
effect_grid_small_range <- seq(from = 0, to = 0.15, length.out = 16)
list_results_small_range <- parallel::mclapply(
  effect_grid_small_range,
  get_mediation_effect,
  seed = NULL,
  sims = 2000,
  mc.cores = 4
)
saveRDS(
  list_results_small_range,
  "output/list_results_small_range.rds"
)

# Wide grid running up to the upper bound on beta, used for the plot.
effect_grid_large_range <- seq(from = 0, to = beta_ub, length.out = 16)
list_results_large_range <- parallel::mclapply(
  effect_grid_large_range,
  get_mediation_effect,
  seed = NULL,
  sims = 2000,
  mc.cores = 4
)
saveRDS(
  list_results_large_range,
  "output/list_results_large_range.rds"
)

# Stack the estimates from both grids into one long data frame: three rows
# (indirect, direct, total) for every value of the assumed shift.
results <- c(list_results_small_range, list_results_large_range) %>%
  map_df(function(fit) {
    # One row holding a point estimate and its two interval limits.
    effect_row <- function(label, estimate, interval) {
      tibble(
        Effect = label,
        est = estimate,
        lb = interval[1],
        ub = interval[2]
      )
    }

    bind_rows(
      effect_row("Indirect", fit$ide, fit$ide_ci),
      effect_row("Direct", fit$de, fit$de_ci),
      effect_row("Total", fit$tot, fit$tot_ci)
    ) %>%
      mutate(delta = fit$control_delta)
  })

# Figure in paper: estimates with interval ribbons against the assumed shift.
# The total effect gets its own panel; indirect and direct effects share one.
results %>%
  mutate(
    panel = ifelse(Effect == "Total", "Total Effect", "Effect Decomposition")
  ) %>%
  ggplot(aes(x = delta, y = est, ymin = lb, ymax = ub)) +
  geom_line(aes(color = Effect)) +
  geom_ribbon(aes(fill = Effect), alpha = .3) +
  geom_hline(yintercept = 0, color = "black") +
  facet_wrap(~panel) +
  labs(
    x = latex2exp::TeX("Effect of Treatment on Mediator ($\\beta$)"),
    y = "Causal Effect on Outcome"
  ) +
  scale_color_brewer(palette = "Set1") +
  scale_fill_brewer(palette = "Set1") +
  theme_bw() +
  theme(legend.position = "bottom")
ggsave("output/effect_decomp.png", width = 8, height = 7)

# Show the indirect-effect rows.
filter(results, Effect == "Indirect")

########################################
### Sharable data are generated here ###
########################################

# Control matrix: party indicator, demographic controls and a trailing column
# of ones for the intercept.
X <- cbind(
  as.matrix(select(mediation_data, all_of(c("democrat", demogs)))),
  matrix(1, nrow = nrow(mediation_data), ncol = 1)
)
# Residual-maker matrix I - X (X'X)^{-1} X': multiplying a vector by it removes
# that vector's linear projection on the columns of X.
A_X <- diag(nrow(X)) - X %*% solve(t(X) %*% X) %*% t(X)

# sharable data: outcome, mediator and treatment with the controls removed
Y.share <- A_X %*% mediation_data$depolarization_index_winsor
M.share <- A_X %*% mediation_data$mediator_mhat
T.share <- A_X %*% mediation_data$treatment_assigned

# example replications
summary(lm(Y.share ~ T.share + M.share))
summary(lm(Y.share ~ T.share))

save(
  Y.share, M.share, T.share,
  file = "output/replication_data.Rdata"
)


