# Silence warnings for this script; the previous warning level is kept in
# `oldw`.
oldw <- options(warn = -1)[["warn"]]

# Work relative to the folder of the document currently active in RStudio
# (requires the rstudioapi package and an RStudio session).
setwd(dirname(rstudioapi::getActiveDocumentContext()[["path"]]))
getwd()

# NOTE(review): presumably attaches the packages and defines the helpers and
# data objects used below (e.g. webnlg_only_features, unity_formation) -- confirm.
source(file = "libraries_functions.R")


## features used in the informed model -----------------

# Columns pulled from the feature table (in this order) when `pars` is built.
parsing_features_minimal <- c(
  "ID",
  "reftype",
  "type",
  "dep_rel",
  "first_mention",
  "same_sent_ante",
  "prev_gm",
  "word_dist_quantile",
  "entity_type",
  "gender"
)

## Data for non_parsed features

# Build the modelling table: keep refex, entity and the minimal feature set,
# drop rows whose reftype is "empty", fill missing character values with
# "first_mention", clean up gender / entity_type, turn character columns into
# factors and fill missing numeric values with -1.
pars <- webnlg_only_features %>%
  select(refex, entity, all_of(parsing_features_minimal)) %>%
  filter(reftype != "empty") %>%
  mutate_if(is.character, replace_na, replace = "first_mention") %>%
  mutate(
    gender = as.character(gender),
    # "first_mention" and "neutral" are both mapped to "neutral," (with a
    # trailing comma) before the str_before_first() call below.
    gender = ifelse(gender %in% c("first_mention", "neutral"), "neutral,", gender),
    # NOTE(review): str_before_first() is not defined in this file; presumably
    # it keeps the text before the first comma -- confirm.
    entity_type = str_before_first(as.character(entity_type), ","),
    gender = str_before_first(gender, ",")
  ) %>%
  mutate_if(is.character, as.factor) %>%
  mutate_if(is.numeric, replace_na, replace = -1)


## prepare unique concatenated entity, type, feature info for subsequent generations

# Each frame keeps the leading columns of `pars` and drops its last 2, 3, 4, 5
# or 6 columns; the result is then passed through unity_formation(), which is
# defined outside this file.
pars_ent_ref_form_feat_5 <- unity_formation(pars[seq_len(ncol(pars) - 2)])
pars_ent_ref_form_feat_4 <- unity_formation(pars[seq_len(ncol(pars) - 3)])
pars_ent_ref_form_feat_3 <- unity_formation(pars[seq_len(ncol(pars) - 4)])
pars_ent_ref_form_feat_2 <- unity_formation(pars[seq_len(ncol(pars) - 5)])
pars_ent_ref_form_feat_1 <- unity_formation(pars[seq_len(ncol(pars) - 6)])


# Lookup tables used as the right-hand side of the joins in the generation
# step: the concatenated key column plus the referring expression.
pars_name_value_pairs_5 <- select(
  pars_ent_ref_form_feat_5, unique_entity_reftype_feat, refex
)
pars_name_value_pairs_4 <- select(
  pars_ent_ref_form_feat_4, unique_entity_reftype_feat, refex
)
pars_name_value_pairs_3 <- select(
  pars_ent_ref_form_feat_3, unique_entity_reftype_feat, refex
)
pars_name_value_pairs_2 <- select(
  pars_ent_ref_form_feat_2, unique_entity_reftype_feat, refex
)
pars_name_value_pairs_1 <- select(
  pars_ent_ref_form_feat_1, unique_entity_reftype_feat, refex
)


# Entity and referring-expression columns of the test rows, captured here
# because both columns are removed from `pars` right after.
pars_entity_test <- select(filter(pars, type == "test"), entity)

pars_refex_test <- select(filter(pars, type == "test"), refex)

## create train / dev / test sets of the pars dataframe ------------------

# refex and entity are not model inputs.
pars <- select(pars, -refex, -entity)

# IDs per split (one-column data frames).
pars_train_id <- select(filter(pars, type == "train"), ID)
pars_test_id <- select(filter(pars, type == "test"), ID)
pars_dev_id <- select(filter(pars, type == "dev"), ID)

# Modelling frames per split: everything except ID and the split marker.
pars_train <- select(filter(pars, type == "train"), -ID, -type)
pars_test <- select(filter(pars, type == "test"), -ID, -type)
pars_dev <- select(filter(pars, type == "dev"), -ID, -type)

# Zero-based labels taken from the integer codes of reftype.
pars_train_label <- as.integer(pars_train[["reftype"]]) - 1
pars_dev_label <- as.integer(pars_dev[["reftype"]]) - 1
pars_test_label <- as.integer(pars_test[["reftype"]]) - 1

# Position of the label column (reftype is the first column of the split
# frames); it is excluded from the feature matrix handed to catboost.
target <- 1

pars_train_pool <- catboost.load_pool(
  data = pars_train[, -target],
  label = pars_train_label
)

pars_dev_pool <- catboost.load_pool(
  data = pars_dev[, -target],
  label = pars_dev_label
)

pars_test_pool <- catboost.load_pool(
  data = pars_test[, -target],
  label = pars_test_label
)

## Running the pars model -----------

# Fit on the train pool with the dev pool as evaluation set;
# catboost_fit_params_best is defined outside this file.
set.seed(42)
pars_model <- catboost.train(
  learn_pool = pars_train_pool,
  test_pool = pars_dev_pool,
  params = catboost_fit_params_best
)

## get feature importance (this was used for ordering the features) ---

# Importance table computed on the train pool, sorted from most to least
# important.
pars_varimp <- pars_model %>%
  catboost.get_feature_importance(pool = pars_train_pool) %>%
  as.data.frame() %>%
  rownames_to_column() %>%
  rename(importance = V1) %>%
  arrange(desc(importance))

# Horizontal bar chart of the importances; note this one is computed on the
# dev pool, not the train pool.
pars_model %>%
  catboost.get_feature_importance(pars_dev_pool) %>%
  as.data.frame() %>%
  setNames("Importance") %>%
  rownames_to_column(var = "Feature") %>%
  ggplot(mapping = aes(x = reorder(Feature, Importance), y = Importance)) +
  geom_bar(stat = "identity") +
  coord_flip()

## pars model predictions -----------

# Class probabilities for the test pool.
pars_pred <- catboost.predict(
  pars_model,
  pars_test_pool,
  prediction_type = "Probability"
)

# catboost_predictions() and acc_table() are helpers defined outside this file.
pars_predictions <- catboost_predictions(pars_test_label, pars_pred)

pars_info <- acc_table(pars_predictions)
pars_info

# Make sure the output folder exists so saveRDS() does not fail on a fresh
# checkout; this is a no-op when the folder is already there.
dir.create("webnlg_pars_results", showWarnings = FALSE, recursive = TRUE)

saveRDS(pars_info, file = "webnlg_pars_results/webnlg_pars_stat.rds")

# Predictions side by side with the test features (column 1, reftype, dropped).
pars_pred_with_feat <- cbind(pars_predictions, pars_test[, -1])

saveRDS(pars_pred_with_feat, file = "webnlg_pars_results/pars_pred_with_feat.rds")

## pars model generation ----------

# Turn the predicted reference type of every test row into a surface string:
#   1. build entity-class-feature keys at five levels of specificity,
#   2. look each key up in the name/value tables built earlier and keep the
#      first referring expression found,
#   3. fall back to the entity name (underscores -> spaces) when nothing matches,
#   4. override predicted pronouns with a rule based on gender and dep_rel.
pars_generation <- pars_pred_with_feat %>%
  # Positional selection: column 4 plus columns 6:12.
  # NOTE(review): presumably class_prediction followed by the seven feature
  # columns -- this depends on the layout returned by catboost_predictions().
  .[, c(4, 6:12)] %>%
  unite("concat_feat", 2:6, remove = FALSE, sep = "-") %>%
  cbind(pars_entity_test) %>%
  mutate(seen = ifelse(entity %in% pars_ent_ref_form_feat_5$entity, "seen", "unseen")) %>%
  unite("unique_entity_reftype_feat", c(entity, class_prediction, concat_feat), remove = FALSE, sep = "-") %>%
  # Shorter keys used for backing off to less specific feature combinations.
  mutate(unique_entity_reftype_feat_1 = str_before_nth(unique_entity_reftype_feat, "-", 3)) %>%
  mutate(unique_entity_reftype_feat_2 = str_before_nth(unique_entity_reftype_feat, "-", 4)) %>%
  mutate(unique_entity_reftype_feat_3 = str_before_nth(unique_entity_reftype_feat, "-", 5)) %>%
  mutate(unique_entity_reftype_feat_4 = str_before_nth(unique_entity_reftype_feat, "-", 6)) %>%
  # Joins go from the most specific key to the least specific one; the order
  # matters because coalesce() below walks the refex columns left to right.
  left_join(., pars_name_value_pairs_5, by = "unique_entity_reftype_feat") %>%
  left_join(., pars_name_value_pairs_4, by = c("unique_entity_reftype_feat_4" = "unique_entity_reftype_feat")) %>%
  left_join(., pars_name_value_pairs_3, by = c("unique_entity_reftype_feat_3" = "unique_entity_reftype_feat")) %>%
  left_join(., pars_name_value_pairs_2, by = c("unique_entity_reftype_feat_2" = "unique_entity_reftype_feat")) %>%
  left_join(., pars_name_value_pairs_1, by = c("unique_entity_reftype_feat_1" = "unique_entity_reftype_feat")) %>%
  # Treat the literal string 'NULL' as missing and re-infer column types.
  # (Lambda form replaces the deprecated dplyr::funs() wrapper.)
  mutate_all(~ type.convert(as.character(replace(., . == "NULL", NA)))) %>%
  # Take the first non-missing value across the columns whose name contains "refex".
  mutate(generated_text = do.call(coalesce, .[grepl("refex", names(.))]) %>% as.character()) %>%
  mutate(generated_text = ifelse(is.na(generated_text), gsub("_", " ", entity), generated_text)) %>%
  # Values that look like a deparsed vector, c("a", "b"), are flattened to: a b
  mutate(generated_text = ifelse(grepl("c\\(", generated_text), str_extract(generated_text, '".*"'), generated_text)) %>%
  mutate(generated_text = gsub('", "', " ", generated_text)) %>%
  mutate(generated_text = gsub('"', "", generated_text)) %>%
  # NOTE(review): pars_test_id is a one-column data frame, so ID becomes a
  # data-frame column here; pars_test_id$ID may have been intended. Left as is
  # because consumers of the saved file are not visible from this script.
  mutate(ID = pars_test_id) %>%
  mutate(original_text = pars_refex_test$refex) %>%
  # The gender values were already cut at the first comma when `pars` was
  # built, so comparing against "male," / "female," / "neutral," would likely
  # never match. Normalise away an optional trailing comma so the pronoun rules
  # fire for both spellings.
  mutate(gender_key = sub(",$", "", as.character(gender))) %>%
  mutate(generated_text = ifelse(class_prediction == "pronoun" & gender_key == "male",
                                 ifelse(dep_rel == "nsubj", "he",
                                        ifelse(dep_rel == "oth", "he",
                                               ifelse(dep_rel == "poss", "his", "him"))),
                                 generated_text)) %>%
  mutate(generated_text = ifelse(class_prediction == "pronoun" & gender_key == "female",
                                 ifelse(dep_rel == "nsubj", "she",
                                        ifelse(dep_rel == "oth", "she", "her")),
                                 generated_text)) %>%
  mutate(generated_text = ifelse(class_prediction == "pronoun" & gender_key == "neutral",
                                 ifelse(dep_rel == "poss", "its", "it"),
                                 generated_text)) %>%
  select(ID, class_prediction, entity, seen, generated_text, original_text)


# Persist the full generation table ...
saveRDS(object = pars_generation, file = "webnlg_pars_results/pars_generation.rds")

# ... and the generated strings alone, one per line, without a header.
write_delim(
  as.data.frame(pars_generation$generated_text),
  path = "webnlg_pars_results/webnlg_pars.txt",
  delim = "",
  col_names = FALSE
)


# Restore the warning level that was saved in `oldw` at the top of the script;
# this has to happen before `oldw` is removed below.
if (exists("oldw")) {
  options(warn = oldw)
}

# Remove every object whose name does not match the pattern, so that only
# pars_generation (and any other matching name) survives. `toremove` itself
# does not exist yet when ls() runs, hence it is added to the list explicitly.
toremove <- grep("pars_generation*", ls(), invert = TRUE, value = TRUE)
rm(list = c(toremove, "toremove"))
