# set directory -----------------------------------------------------------

# Silence warnings for the rest of the session; the previous level is kept in
# `oldw` so it can be restored later with options(warn = oldw).
# NOTE(review): this also hides warnings from caret, readr, yardstick, etc.;
# consider scoping suppression to the noisy calls only.
oldw <- getOption("warn")
options(warn = -1)

# Work relative to this script's folder when it is run from RStudio. Outside
# RStudio (e.g. Rscript) rstudioapi cannot report the active document, and an
# unsaved editor buffer may report an empty path, so in those cases the
# working directory is left unchanged instead of erroring.
script_path <- ""
if (requireNamespace("rstudioapi", quietly = TRUE) &&
    rstudioapi::isAvailable()) {
  script_path <- rstudioapi::getActiveDocumentContext()$path
}
if (nzchar(script_path)) {
  setwd(dirname(script_path))
}
getwd()

# Directory and libraries -------------------------------------------------
library(tidyverse)
library(strex)
library(data.table)
library(summarytools)
library(descr)

library(mlr3)
library(mlr3learners)
library(DALEXtra)
library(xgboost)
library(mlr)
require(gridExtra)
library(cvms)
library(janitor)
library(yardstick)
library(caret)
require(Matrix)


#***********************************
# set directory
#***********************************

# The working directory is already set at the top of the script and the
# library() calls in between do not change it, so it is not set again here
# (a second rstudioapi::getActiveDocumentContext() call would also fail when
# the script is not run inside RStudio). Print it as a sanity check.
getwd()

#***********************************
# libraries
#***********************************

library(jsonlite)
#library(reticulate)
library(rlist)

#for oversampling
library(ROSE)
library(scutr)


# read zh data -------------------------------------------------------

# The cached .rds splits can be regenerated from the raw JSON files with the
# commented-out code below.
# train <- fromJSON("zh/zh/train.json")
# test  <- fromJSON("zh/zh/test.json")
# dev   <- fromJSON("zh/zh/dev.json")
# write_rds(train, "zh/zh/train.rds")
# write_rds(test, "zh/zh/test.rds")

# Load the cached training and test splits.
train <- readRDS(file = "zh/zh/train.rds")
test <- readRDS(file = "zh/zh/test.rds")

#***********************************
# functions
#***********************************

# Predict class probabilities for `test` and attach the arg-max class label
# and the observed label.
#
# model: a fitted model whose predict(type = "prob") method returns one
#   probability column per class, named by class label.
# test:  data containing the observed outcome column `reftype`.
# Returns a data.frame with the probability columns, `class_prediction` (a
#   factor whose levels are ALL model classes in column order, so classes that
#   are never predicted are still represented) and `original_value`
#   (test$reftype, unchanged).
predictions <- function(model, test) {
  probs <- as.data.frame(predict(model, test, type = "prob"))
  classes <- colnames(probs)
  # max.col() breaks ties at random by default; "first" keeps the result
  # deterministic and independent of the RNG state.
  best <- max.col(as.matrix(probs), ties.method = "first")
  probs$class_prediction <- factor(classes[best], levels = classes)
  probs$original_value <- test$reftype
  probs
}


# Per-class and macro-averaged precision / recall / F1 for a prediction table.
#
# dt: data frame with columns `original_value` (observed class) and
#   `class_prediction` (predicted class), e.g. the output of predictions().
# modelname: label stored in the `modelname` column; defaults to the
#   expression passed as `dt`.
# Returns a data.frame with one row per class (row names = class labels) and
#   columns modelname, diag (correct predictions), rowsums (observed count),
#   colsums (predicted count), p, q (observed / predicted class shares),
#   precision, recall, f1, and the macro averages repeated on every row.
#   A ratio with a zero denominator (class never predicted, never observed,
#   or with no true positives) is reported as 0 rather than NaN, so the
#   macro averages stay defined.
perclass_all <- function(dt, modelname = deparse(substitute(dt))) {
  force(modelname)
  actual <- dt$original_value
  predicted <- dt$class_prediction

  # Rows and columns must share one ordered label set: otherwise a class that
  # is never predicted makes the table non-square and diag() would pair
  # unrelated classes.
  level_set <- function(x) {
    if (is.factor(x)) levels(x) else sort(unique(as.character(x)))
  }
  classes <- union(level_set(actual), level_set(predicted))
  # Drop labels that occur in neither vector (e.g. unused factor levels).
  seen <- c(as.character(actual), as.character(predicted))
  classes <- classes[classes %in% seen]

  cm <- table(Actual = factor(actual, levels = classes),
              Predicted = factor(predicted, levels = classes))
  n <- sum(cm)
  correct <- diag(cm)
  rowsums <- rowSums(cm)
  colsums <- colSums(cm)
  p <- rowsums / n
  q <- colsums / n

  safe_div <- function(num, den) ifelse(den > 0, num / den, 0)
  precision_raw <- safe_div(correct, colsums)
  recall_raw <- safe_div(correct, rowsums)
  f1_raw <- safe_div(2 * precision_raw * recall_raw,
                     precision_raw + recall_raw)

  # Round only for reporting; F1 and the macro averages use unrounded values.
  data.frame(
    modelname,
    diag = correct,
    rowsums, colsums, p, q,
    precision = round(precision_raw, digits = 5),
    recall = round(recall_raw, digits = 5),
    f1 = round(f1_raw, digits = 5),
    macroPrecision = round(mean(precision_raw), digits = 5),
    macroRecall = round(mean(recall_raw), digits = 5),
    macroF1 = round(mean(f1_raw), digits = 5)
  )
}

#dt = data.frame(modelname,diag,rowsums, colsums,p,q, precision, recall, f1, macroPrecision, macroRecall, macroF1, avgAccuracy)

#***********************************
# xgboost model function
#***********************************
# caret resampling setup shared by every model below:
# - 5-fold cross-validation with progress printed per resample;
# - hold-out predictions saved for every tuning setting (savePredictions =
#   TRUE; use "final" to keep only the selected setting's predictions);
# - class probabilities computed (classProbs = TRUE).
ctrl <- trainControl(
  method = "cv",
  number = 5,
  verboseIter = TRUE,
  savePredictions = TRUE,
  classProbs = TRUE
)


# Fixed xgbTree hyper-parameters. Every argument is a single value, so the
# grid has exactly one row and caret only cross-validates this one setting.
tune_grid <- do.call(
  expand.grid,
  list(
    nrounds          = 1000,  # boosting rounds
    max_depth        = 5,     # maximum tree depth
    eta              = 0.05,  # learning rate
    gamma            = 0.01,  # minimum loss reduction needed to split
    colsample_bytree = 0.75,  # share of columns sampled per tree
    min_child_weight = 0,
    subsample        = 0.5    # share of rows sampled per tree
  )
)


# Fit a caret "xgbTree" model on `train`, evaluate it on `test` and write the
# evaluation artefacts to `out_dir`, each file name prefixed with `model`:
#   <model>_pred.txt      class probabilities, predicted and observed class
#   <model>_perclass.txt  per-class and macro precision / recall / F1, with a
#                         leading `class` column
#   <model>_overall.txt   caret's overall statistics (cm$overall)
#   <model>_confmat.txt   confusion matrix, with a leading `Prediction` column
#   <model>_confmat.jpg   confusion-matrix heatmap
#
# train, test: data frames whose outcome column is `reftype`; all other
#   columns are used as predictors (formula reftype ~ .).
# model: file-name prefix, e.g. "zh_two_way".
# out_dir: output folder; created if it does not exist.
# tr_control, grid: caret trainControl() object and tuning grid; they default
#   to the script-level `ctrl` and `tune_grid`.
# Returns (invisibly) a list with the fitted model, the prediction table, the
#   per-class table and the caret confusionMatrix object.
#
# NOTE(review): this masks xgboost::xgboost() once library(xgboost) is
# attached; the name is kept because the script calls it that way.
xgboost <- function(train, test, model,
                    out_dir = "zh_bal/results",
                    tr_control = ctrl,
                    grid = tune_grid) {
  out_path <- function(suffix) file.path(out_dir, paste0(model, suffix))
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

  set.seed(200)
  # tuneLength is ignored by caret when tuneGrid is supplied, so it is omitted.
  fit <- caret::train(reftype ~ ., data = train, method = "xgbTree",
                      trControl = tr_control,
                      tuneGrid = grid)

  # write_delim()'s destination is passed positionally: it is `file` in
  # current readr (`path` is deprecated) and `path` in older releases.
  pred <- as.data.frame(predictions(fit, test))
  write_delim(pred, out_path("_pred.txt"), delim = "\t")

  # write_delim() drops row names, which hold the class labels here.
  stat <- perclass_all(pred)
  write_delim(rownames_to_column(stat, "class"), out_path("_perclass.txt"),
              delim = "\t")

  # Observed and predicted classes must share one level set: caret's
  # confusionMatrix() and yardstick's conf_mat() reject mismatched levels,
  # which happens whenever a class is never predicted.
  lv <- union(levels(as.factor(pred$original_value)),
              levels(as.factor(pred$class_prediction)))
  obs <- factor(pred$original_value, levels = lv)
  est <- factor(pred$class_prediction, levels = lv)

  cm <- caret::confusionMatrix(data = est, reference = obs,
                               mode = "everything")
  print(cm$overall)
  overall <- tibble(rowname = names(cm$overall), model = unname(cm$overall))
  write_delim(overall, out_path("_overall.txt"), delim = "\t")

  confmat <- rownames_to_column(as.data.frame.matrix(cm$table), "Prediction")
  write_delim(confmat, out_path("_confmat.txt"), delim = "\t")

  cm_plt <- conf_mat(data.frame(obs = obs, pred = est), obs, pred)
  p <- autoplot(cm_plt, type = "heatmap") +
    scale_fill_gradient(low = "#D6EAF8", high = "#2E86C1")
  ggsave(filename = out_path("_confmat.jpg"), plot = p)

  invisible(list(model = fit, predictions = pred, perclass = stat,
                 confusion = cm))
}


#***********************
# 2 way
#***********************************

colnames(train)
two_way <- c("2_way", "disstat", "senstat", "syn", "focus",
             "distant", "intref", "locsal")

# Build the binary modelling table from a raw split: keep the target and the
# predictor columns, turn character and numeric columns into factors, and
# relabel the target "0" -> "zero", anything else -> "one" (NA stays NA).
# Train and test go through the identical pipeline.
prep_two_way <- function(df) {
  df %>%
    select(all_of(two_way)) %>%
    mutate(across(where(is.character), as.factor)) %>%
    rename(reftype = "2_way") %>%
    mutate(reftype = as.factor(ifelse(reftype == "0", "zero", "one"))) %>%
    mutate(across(where(is.numeric), as.factor))
}

two_way_train <- prep_two_way(train)
two_way_test <- prep_two_way(test)

# Balance the binary target by random over-sampling of the minority class.
# The RNG was not seeded before this call, so seed it here to make the
# resampled training set (and hence the results) reproducible.
set.seed(234)
over_two_way_train <- ovun.sample(reftype ~ ., data = two_way_train,
                                  method = "over")$data

xgboost(over_two_way_train, two_way_test, "zh_two_way")


#***********************
# 3 way
#***********************************
three_way <- c("3_way", "disstat", "senstat", "syn", "focus",
               "distant", "intref", "locsal")

# Build the 3-way modelling table from a raw split: keep the target and the
# predictor columns, turn character and numeric columns into factors, and
# relabel the target "0" -> "zero", "1" -> "one", anything else -> "two"
# (NA stays NA). Train and test go through the identical pipeline.
prep_three_way <- function(df) {
  df %>%
    select(all_of(three_way)) %>%
    mutate(across(where(is.character), as.factor)) %>%
    rename(reftype = "3_way") %>%
    mutate(reftype = as.factor(
      ifelse(reftype == "0", "zero",
             ifelse(reftype == "1", "one", "two"))
    )) %>%
    mutate(across(where(is.numeric), as.factor))
}

three_way_train <- prep_three_way(train)
three_way_test <- prep_three_way(test)

# Balance the classes by up-sampling, as is done for the 2-way and 5-way
# targets; all results go to the balanced ("zh_bal") results folder.
# NOTE(review): the original trained 3-way on the unbalanced data.
set.seed(234)
three_way_train_up <- upSample(x = select(three_way_train, -reftype),
                               y = three_way_train$reftype,
                               yname = "reftype") %>%
  relocate(reftype)

xgboost(three_way_train_up, three_way_test, "zh_three_way")


#***********************
# 4 way
#***********************************
four_way <- c("4_way", "disstat", "senstat", "syn", "focus",
              "distant", "intref", "locsal")

# Build the 4-way modelling table from a raw split: keep the target and the
# predictor columns, turn character and numeric columns into factors, and
# relabel the target "0" -> "zero", "1" -> "one", "2" -> "two", anything
# else -> "three" (NA stays NA). Train and test use the identical pipeline.
prep_four_way <- function(df) {
  df %>%
    select(all_of(four_way)) %>%
    mutate(across(where(is.character), as.factor)) %>%
    rename(reftype = "4_way") %>%
    mutate(reftype = as.factor(
      ifelse(reftype == "0", "zero",
             ifelse(reftype == "1", "one",
                    ifelse(reftype == "2", "two", "three")))
    )) %>%
    mutate(across(where(is.numeric), as.factor))
}

four_way_train <- prep_four_way(train)
four_way_test <- prep_four_way(test)

# Balance the classes by up-sampling, as is done for the 2-way and 5-way
# targets; all results go to the balanced ("zh_bal") results folder.
# NOTE(review): the original trained 4-way on the unbalanced data.
set.seed(234)
four_way_train_up <- upSample(x = select(four_way_train, -reftype),
                              y = four_way_train$reftype,
                              yname = "reftype") %>%
  relocate(reftype)

xgboost(four_way_train_up, four_way_test, "zh_four_way")


#***********************
# 5 way
#***********************************
five_way <- c("5_way", "disstat", "senstat", "syn", "focus",
              "distant", "intref", "locsal")

# Build the 5-way modelling table from a raw split: keep the target and the
# predictor columns, turn character and numeric columns into factors, and
# relabel the target "0" -> "zero", "1" -> "one", "2" -> "two", "3" -> "three",
# anything else -> "four" (NA stays NA). Train and test use the identical
# pipeline.
prep_five_way <- function(df) {
  df %>%
    select(all_of(five_way)) %>%
    mutate(across(where(is.character), as.factor)) %>%
    rename(reftype = "5_way") %>%
    mutate(reftype = as.factor(
      ifelse(reftype == "0", "zero",
             ifelse(reftype == "1", "one",
                    ifelse(reftype == "2", "two",
                           ifelse(reftype == "3", "three", "four"))))
    )) %>%
    mutate(across(where(is.numeric), as.factor))
}

five_way_train <- prep_five_way(train)
five_way_test <- prep_five_way(test)

# Balance the classes by up-sampling. The outcome is named "reftype"
# directly and moved to the front by name, instead of being renamed from
# "Class" and reordered with hard-coded column positions.
set.seed(234)
five_way_train_up <- upSample(x = select(five_way_train, -reftype),
                              y = five_way_train$reftype,
                              yname = "reftype") %>%
  relocate(reftype)

xgboost(five_way_train_up, five_way_test, "zh_five_way")
