library(quanteda)
library(topicmodels)
library(texteffect)
library(tidyverse)

#########################################################################################################
#' 
#' This code loads the full conversation data, then creates the document-term matrix, then runs the
#' LDA and sIBP analysis. Due to confidentiality agreements and the IRB, we cannot share the raw
#' conversation data or the unlabeled document-term matrix. For transparency, we include all of the code 
#' here. 
#' 
#' This analysis takes about an hour to run on a standard MacBook Pro laptop.
#'  
#########################################################################################################

# Run the upstream script from its own (parent) directory. `chdir = TRUE` makes
# source() switch to the script's directory only while it is evaluated and then
# restore the working directory, so no manual setwd() round trip is needed and
# the working directory is not left changed if the script fails.
# The script is presumably what defines `model_data`, which is used below.
source("../analysis_to_model_data.R", chdir = TRUE)

scripts <- read_csv("../firebase_transcripts.csv")

demogs <- c("immigration", "democrat", "gender", "age", "college", "white",
            "regionMidwest", "regionSouth", "regionWest")

# create dataset of participants, Y, X, and T along with full text of the messages sent to them
#
# Left side:  fully compliant participants, keyed by `receiver` (their
#             App_install_code), with the outcome and the demographic controls.
# Right side: every transcript row, tagged with a `receiver` id.
# The join has no explicit `by`, so it uses all columns the two tables share.
#
# NOTE(review): `receiver` is set to user2 when sender != user1 and to user1
# otherwise. If `sender` always equals user1 or user2, that is the sender's own
# id, not the recipient's, which contradicts "messages sent to them" above.
# Confirm against the transcript schema before changing anything.
scripts_completes <- left_join(
  model_data %>%
    filter(fully_compliant == 1) %>%
    mutate(receiver = App_install_code,
           immigration = 1 * (topic == "immigration")) %>%
    dplyr::select(receiver, depolarization_index_winsor, all_of(c(demogs))),
  scripts %>%
    mutate(receiver = ifelse(sender != user1, user2, user1))
) %>%
  group_by(receiver) %>%
  # One row per participant: average the (constant within participant) outcome
  # and controls, and paste all of that participant's messages into one string.
  summarise(across(all_of(c("depolarization_index_winsor", demogs)), mean),
            messages = str_c(message, collapse = " "))


# Build the document-feature matrix from each participant's concatenated
# messages: tokenize without punctuation, lowercase, then drop SMART stopwords.
# Calling dfm() directly on a character vector with `remove` / `remove_punct`
# is deprecated in quanteda 3 and defunct in quanteda 4; the explicit
# tokens() -> dfm() -> dfm_remove() chain applies the same steps.
scripts.dfm <- scripts_completes$messages %>%
  tokens(remove_punct = TRUE) %>%
  dfm(tolower = TRUE) %>%
  dfm_remove(pattern = stopwords(source = "smart"))

# Keep only features that occur at least 9 times across all documents.
scripts.dfm <- scripts.dfm %>%
  dfm_trim(min_termfreq = 9, termfreq_type = "count")
scripts.topicmodel <- convert(scripts.dfm, to = "topicmodels")

# Numbers of topics to fit: 4, 6, 8, 10, 12.
ks <- seq(4, 12, 2)

# Fit one VEM LDA model per value of k (4 starts each, seeded 1:4), in parallel.
lda_fits <- parallel::mclapply(
  ks,
  function(k, x) LDA(x, k, method = "VEM", control = list(seed = 1:4, nstart = 4)),
  x = scripts.topicmodel
)

# saveRDS() does not create missing directories, so make sure output/ exists.
dir.create("output", showWarnings = FALSE, recursive = TRUE)
saveRDS(lda_fits, "output/lda_fits.rds")
#lda_fits <- readRDS("output/lda_fits.rds")

# Covariates shared by every regression below: the outcome and the demographic
# controls, without the participant id and the raw message text.
# For each fitted LDA model, regress the outcome on the posterior topic
# proportions (columns T1..Tk) plus all remaining columns.
lda_lms <- lapply(lda_fits, function(topic.mod) {
  topic_shares <- as.data.frame(posterior(topic.mod)$topics)
  colnames(topic_shares) <- str_c("T", seq_len(topic.mod@k))

  reg_data <- scripts_completes %>%
    bind_cols(topic_shares) %>%
    dplyr::select(-receiver, -messages)

  lm(depolarization_index_winsor ~ ., data = reg_data)
})

# Same regression without any topic columns.
baseline_lm <- lm(
  depolarization_index_winsor ~ .,
  data = scripts_completes %>% dplyr::select(-receiver, -messages)
)

# Compare each topic regression against the baseline with anova().
joint_test <- lapply(lda_lms, function(mod) anova(mod, baseline_lm))

# Tidy one fitted regression and tag it with its number of topic terms.
#
# Args:
#   mod: a fitted model accepted by broom::tidy().
# Returns:
#   The broom::tidy() coefficient table (with confidence intervals) plus a
#   column K, the count of terms whose name starts with "T" followed by a digit.
tidy_lda_lm <- function(mod) {
  coef_table <- broom::tidy(mod, conf.int = TRUE)
  is_topic_term <- str_detect(coef_table$term, "^T[0-9]")
  coef_table %>% mutate(K = sum(is_topic_term))
}

# Stack the tidied coefficient tables from every LDA regression and keep only
# the topic terms (T1, T2, ...) that were actually estimated.
lda_coef_data <- map_df(lda_lms, tidy_lda_lm) %>%
  filter(term %>% str_detect("^T[0-9]"), !is.na(estimate)) %>%
  filter(K %in% c(4, 6, 8, 10, 12))

# Point estimates and confidence intervals for each topic, grouped by K.
lda_coef_data %>%
  ggplot(aes(y = estimate, x = factor(K), ymin = conf.low, ymax = conf.high, color = factor(K))) +
  geom_errorbar(width = 1, position = position_dodge2(width = 0.9)) +
  geom_point(position = position_dodge2(width = 1)) +
  geom_hline(yintercept = 0) +
  scale_color_brewer(palette = "Set1") +
  labs(x = "Number of LDA Topics", y = "Effect of Topic on Depolarization") +
  theme_bw() +
  theme(legend.position = "none")
#ggsave("output/lda_effects.png",width = 6,height = 6)


# examine histogram of p-values relative to null distribution
# dplyr:: is spelled out to match the other select() calls in this file.
lda_coef_data %>% dplyr::select(p.value) %>% unlist %>% hist(breaks = seq(0, 1, length.out = 11))
# Reference draw from Uniform(0, 1) with one value per plotted coefficient, so
# the comparison stays the same size as the data if the set of models changes.
runif(nrow(lda_coef_data)) %>% hist(breaks = seq(0, 1, length.out = 11))

# print regression results
# Demographic controls are omitted from the table; one column per value of k.
lda_lms %>% stargazer::stargazer(type = "text",
                                 omit = str_c(demogs, collapse = "|"),
                                 omit.stat = c("rsq", "adj.rsq", "f", "ser"),
                                 column.labels = str_c("K=", ks), star.cutoffs = c(0.10, 0.05, 0.01))

# examine top words from most significant topics
lda_coef_data %>% arrange(-abs(statistic))

terms(lda_fits[[1]], 10)
terms(lda_fits[[2]], 10)
terms(lda_fits[[4]], 10)

# Top-10 words for hand-picked topics: topics 7-9 of the 4th fit and topic 5 of
# the 2nd fit. NOTE(review): these indices are hard-coded, presumably chosen
# from the ranking printed above - re-check them if the fits change.
sig_words <- cbind(terms(lda_fits[[4]], 10)[, 7:9],
                   terms(lda_fits[[2]], 10)[, 5])

sig_words %>% t %>% xtable::xtable() %>% xtable::print.xtable(include.rownames = FALSE)

####################
### sIBP Process ###
####################

library(texteffect)

# create data

# Random 50/50 split: `train.ind` holds the row indices of the training half.
set.seed(1)
train.ind <- sample(seq_len(nrow(scripts_completes)),
                    size = 0.5 * nrow(scripts_completes), replace = FALSE)

Y <- scripts_completes$depolarization_index_winsor

# Feature counts as a data frame, features sorted by decreasing frequency.
# The first column of the converted data frame is dropped so that only feature
# columns remain; `drop = FALSE` keeps a data frame even with a single feature.
X <- scripts.dfm %>%
  dfm_sort(decreasing = TRUE, margin = "features") %>%
  convert(to = "data.frame") %>%
  {.[, -1, drop = FALSE]}
#drop features that do not appear in training set
X <- X[, colSums(X[train.ind, , drop = FALSE]) != 0, drop = FALSE]

# Search sIBP for several parameter configurations; fit each to the training set
ks_ibp <- c(4, 6, 8, 10, 12)

# Return optimal sibp fit using exclusivity metric.
# Saves memory over sibp_param_search by only keeping best fit
#
# Args:
#   X, Y, K, train.ind: passed unchanged to sibp().
#   alphas, sigmasq.ns: vectors of candidate values; every combination is tried.
#   iters: number of fits per (alpha, sigma) combination.
#   seed: base seed; fit number i of a combination runs after set.seed(seed + i).
# Returns:
#   The single sibp() fit with the highest sibp_exclusivity(fit, X, num.words = 20).
#   Ties keep the earliest fit.
find_opt_sibp <- function(X, Y, K, alphas, sigmasq.ns, iters, train.ind, seed) {

  print(str_c("Initializing alpha = ", alphas[1], ", sigma = ", sigmasq.ns[1]))
  # The baseline fit stands in for iteration 1 of the first combination, which
  # the loop below skips. Seed it the same way (seed + 1) so that it does not
  # depend on whatever state the random number generator happens to be in.
  set.seed(seed + 1)
  sibp.opt <- sibp(X, Y, K, alphas[1], sigmasq.ns[1], train.ind = train.ind, silent = TRUE)
  exclu.opt <- sibp_exclusivity(sibp.opt, X, num.words = 20)

  for (alpha in alphas) {
    for (sigma in sigmasq.ns) {
      print(str_c("Testing alpha = ", alpha, ", sigma = ", sigma))

      # seq_len() runs zero times when iters is 0 (1:iters would run for 1 and 0).
      for (i in seq_len(iters)) {
        # Already fitted above as the baseline.
        if (i == 1 && alpha == alphas[1] && sigma == sigmasq.ns[1]) next

        set.seed(seed + i)
        test.sibp <- sibp(X, Y, K, alpha, sigma, train.ind = train.ind, silent = TRUE)
        test.exclu <- sibp_exclusivity(test.sibp, X, num.words = 20)

        # Keep only the best fit seen so far to limit memory use.
        if (test.exclu > exclu.opt) {
          exclu.opt <- test.exclu
          sibp.opt <- test.sibp
        }
      }

    }
  }

  return(sibp.opt)
}

# if memory is not a bottleneck, this will work
# One parameter search per K, run in parallel on 4 cores. Only a single
# (alpha, sigma) pair and a single iteration are active here; the wider grid is
# kept in the trailing comment.
sibp.searches <- parallel::mclapply(
  ks_ibp,
  function(k) {
    sibp_param_search(X, Y, K = k,
                      alphas = c(12), sigmasq.ns = c(0.8), #alphas = c(6,9,12), sigmasq.ns = c(0.8,1),
                      iters = 1, train.ind = train.ind, seed = 0515)
  },
  mc.cores = 4
)
#saveRDS(sibp.searches,"output/sibp_searches3.rds")

# Stand-alone search and direct fit for K = 12.
sibp.searches12 <- sibp_param_search(
  X, Y, K = 12,
  alphas = c(12), sigmasq.ns = c(0.8), #alphas = c(6,9,12), sigmasq.ns = c(0.8),
  iters = 1, train.ind = train.ind, seed = 0515
)
opt12 <- sibp(X, Y, K = 12, alpha = 12, sigmasq.n = 0.8, train.ind = train.ind)
#saveRDS(opt12,"output/opt12_12_08.rds")

# Rank the runs within each search.
lapply(sibp.searches, function(s) sibp_rank_runs(s, X = X, num.words = 10))

# From each of the five searches, take the first run stored under
# alpha = "12", sigma = "0.8".
optimal_fits <- lapply(
  seq_len(5),
  function(i) sibp.searches[[i]][["12"]][["0.8"]][[1]]
)


# Re-run the search for each K in sequence and write each result to its own file.
for (k in ks_ibp) {
  search <- sibp_param_search(
    X, Y, K = k,
    alphas = c(12), sigmasq.ns = c(0.8), #alphas = c(6,9,12), sigmasq.ns = c(0.8,1),
    iters = 1, train.ind = train.ind, seed = 0515
  )
  saveRDS(search, str_c("output/sibp_runs/sibp_k", k, ".rds"))
}

# if memory is a bottleneck, this will work

# Note: this overwrites `sibp.searches` with one best fit per K over a larger
# grid; `optimal_fits` above is not updated unless the line below is enabled.
sibp.searches <- lapply(
  ks_ibp,
  function(k) {
    find_opt_sibp(X, Y, K = k, alphas = c(10, 12, 15, 20), sigmasq.ns = c(0.8, 1),
                  iters = 1, train.ind = train.ind, seed = 0515)
  }
)
# optimal_fits <- sibp.searches
# compute test set regressions
#
# Regress the test-set outcome on the inferred sIBP treatments for one fit.
# Prints the top 10 words per treatment and the share of test documents
# assigned to each treatment as side effects.
#
# Args:
#   sibp.fit: a fitted sibp object; its test.ind, meanY and sdY elements are used.
#   Y: outcome vector for all documents, subset here by sibp.fit$test.ind.
#   X: feature data passed to infer_Z(); its column names label the top words.
# Returns:
#   The broom::tidy() coefficient table (with confidence intervals) of the
#   regression, plus a column K holding the number of treatments.
# Note: reads the globals `scripts_completes` and `demogs` for the controls.
get_sibp_lm <- function(sibp.fit, Y, X) {

  #print top words
  print(sibp_top_words(sibp.fit, colnames(X), 10, verbose = TRUE))

  #collect data
  # Fixed seed before inference; note this resets the caller's RNG state.
  set.seed(0)
  Z.test <- infer_Z(sibp.fit, X)
  # Outcome on the scale the fit stores: centered by meanY and scaled by sdY.
  Y.test <- (Y[sibp.fit$test.ind] - sibp.fit$meanY) / sibp.fit$sdY

  # Hard assignment: 1 when the inferred value is at least 0.5, else 0.
  # The comparison is vectorized over the whole matrix, so the result keeps one
  # row per document even when the test set has a single row.
  Z.hard <- as.data.frame(1 * (Z.test >= 0.5))
  colnames(Z.hard) <- str_c("T", seq_len(ncol(Z.hard)))

  #print treatment prevalence
  print(colMeans(Z.hard))

  # Y.test is not a column of `data`; lm() finds it in this function's
  # environment through the formula.
  controls <- dplyr::select(scripts_completes[sibp.fit$test.ind, ], all_of(demogs))
  fit <- lm(Y.test ~ ., data = bind_cols(Z.hard, controls))
  broom::tidy(fit, conf.int = TRUE) %>% mutate(K = ncol(Z.hard))
}

# Test-set regression table for each selected sIBP fit, stacked into one table.
sibp.TEs <- lapply(optimal_fits, function(f) get_sibp_lm(f, Y = Y, X = X))
sibp.TEs <- bind_rows(sibp.TEs)
sibp.TEs %>% arrange(-statistic)

# Treatment coefficients only, multiplied by the sdY stored in the first fit.
# NOTE(review): the first fit's sdY is applied to every K; this presumably
# relies on all fits sharing the same outcome and training split - confirm.
sibp_plot_data <- sibp.TEs %>%
  filter(term %>% str_detect("^T[0-9]"), !is.na(estimate)) %>%
  filter(K %in% c(4, 6, 8, 10, 12)) %>%
  mutate(across(all_of(c("estimate", "conf.low", "conf.high")),
                function(x) x * optimal_fits[[1]]$sdY))

# Shared layout of the sIBP effect plots; only the axis labels differ between
# the stand-alone and the combined figure.
sibp_base_plot <- sibp_plot_data %>%
  ggplot(aes(y = estimate, x = factor(K), ymin = conf.low, ymax = conf.high, color = factor(K))) +
  geom_errorbar(width = 1, position = position_dodge2(width = 0.9)) +
  geom_point(position = position_dodge2(width = 1)) +
  geom_hline(yintercept = 0) +
  scale_color_brewer(palette = "Set1") +
  theme_bw() +
  theme(legend.position = "none")

# plot results
sibp_single_plot <- sibp_base_plot +
  labs(x = "Number of sIBP Treatments", y = "Effect of Treatment on Depolarization")
sibp_single_plot
# Pass the plot explicitly rather than relying on ggplot2's last-plot state.
ggsave("output/sibp_effects.png", plot = sibp_single_plot, width = 6, height = 6)


# combined graphic
sibp.plot <- sibp_base_plot +
  labs(x = "Number of sIBP Treatments", y = "Effect of Treatment\non Depolarization")

lda.plot <- lda_coef_data %>%
  ggplot(aes(y = estimate, x = factor(K), ymin = conf.low, ymax = conf.high, color = factor(K))) +
  geom_errorbar(width = 1, position = position_dodge2(width = 0.9)) +
  geom_point(position = position_dodge2(width = 1)) +
  geom_hline(yintercept = 0) +
  scale_color_brewer(palette = "Set1") +
  labs(x = "Number of LDA Topics", y = "Effect of Topic\non Depolarization") +
  theme_bw() +
  theme(legend.position = "none")

# grid.arrange() draws the layout and returns it; arrange once and reuse the
# returned object for saving instead of arranging (and drawing) a second time.
combined_plot <- gridExtra::grid.arrange(lda.plot, sibp.plot, nrow = 2)
ggsave("output/combined_effects.png", plot = combined_plot,
       width = 6, height = 6)
