## Session setup ----

# Remember the current warning level in `oldw`, then silence warnings for the
# rest of the script.
oldw <- getOption("warn")
options(warn = -1)

# Make the folder of the active RStudio document the working directory so the
# relative paths used below resolve. The RStudio API is only usable inside an
# RStudio session, and the active document only has a path once it is saved;
# in every other case the current working directory is kept instead of
# aborting the script. `local()` keeps the helper variables out of the
# global environment.
local({
  has_rstudio <- requireNamespace("rstudioapi", quietly = TRUE) &&
    rstudioapi::isAvailable()
  script_path <- if (has_rstudio) {
    rstudioapi::getActiveDocumentContext()$path
  } else {
    ""
  }
  if (nzchar(script_path)) {
    setwd(dirname(script_path))
  }
})
getwd()

# NOTE(review): presumably attaches the packages and defines the helpers used
# below (unity_formation, catboost_predictions, acc_table, ...) - confirm.
source("libraries_functions.R")

## features used in the NP system ----

# Names of the columns pulled from the feature table for the NP system.
# Order matters: later steps select and drop columns by position.
no_extra_parsing_features_minimal <- c(
  "ID",
  "reftype",
  "type",
  "first_mention",
  "same_sent_ante",
  "word_dist_quantile",
  "sent_dist_quantile",
  "same_prev_mention",
  "mention_order_cat"
)

## Data for non_parsed features ----

# Keep the referring expression, the entity and the NP-system feature columns,
# drop rows whose reftype is "empty", then fill the gaps:
#   * missing values in character columns become "first_mention",
#   * character columns are then converted to factors,
#   * missing values in numeric columns become -1.
np <- webnlg_only_features %>%
  select(refex, entity, all_of(no_extra_parsing_features_minimal)) %>%
  filter(reftype != "empty") %>%
  mutate_if(is.character, ~ replace_na(.x, "first_mention")) %>%
  mutate_if(is.character, as.factor) %>%
  mutate_if(is.numeric, ~ replace_na(.x, -1))

## prepare unique concatenated entity, type, feature info for subsequent generations ----

# Level 6 is built from every column of `np`; each lower level drops one more
# trailing column before the table is handed to unity_formation().
ent_ref_form_feat_6 <- unity_formation(np)
ent_ref_form_feat_5 <- unity_formation(np[seq_len(ncol(np) - 1)])
ent_ref_form_feat_4 <- unity_formation(np[seq_len(ncol(np) - 2)])
ent_ref_form_feat_3 <- unity_formation(np[seq_len(ncol(np) - 3)])
ent_ref_form_feat_2 <- unity_formation(np[seq_len(ncol(np) - 4)])
ent_ref_form_feat_1 <- unity_formation(np[seq_len(ncol(np) - 5)])

# Lookup tables, one per level: the concatenated key column paired with the
# referring expression column.
name_value_pairs_6 <- select(
  ent_ref_form_feat_6, unique_entity_reftype_feat, refex
)
name_value_pairs_5 <- select(
  ent_ref_form_feat_5, unique_entity_reftype_feat, refex
)
name_value_pairs_4 <- select(
  ent_ref_form_feat_4, unique_entity_reftype_feat, refex
)
name_value_pairs_3 <- select(
  ent_ref_form_feat_3, unique_entity_reftype_feat, refex
)
name_value_pairs_2 <- select(
  ent_ref_form_feat_2, unique_entity_reftype_feat, refex
)
name_value_pairs_1 <- select(
  ent_ref_form_feat_1, unique_entity_reftype_feat, refex
)

# Entity and reference expression of the test rows, kept aside because both
# columns are dropped from `np` before modelling.
entity_test <- select(filter(np, type == "test"), entity)

refex_test <- select(filter(np, type == "test"), refex)

## create train dev test sets of the NP dataframe ----

# The model does not use the referring expression or the entity itself.
np <- select(np, -refex, -entity)

# IDs per split (one-column tables).
np_train_id <- select(filter(np, type == "train"), ID)
np_test_id <- select(filter(np, type == "test"), ID)
np_dev_id <- select(filter(np, type == "dev"), ID)

# Modelling tables per split: label column plus features.
np_train <- select(filter(np, type == "train"), -ID, -type)
np_test <- select(filter(np, type == "test"), -ID, -type)
np_dev <- select(filter(np, type == "dev"), -ID, -type)

# Zero-based class labels derived from the factor codes of reftype.
np_train_label <- as.integer(np_train[["reftype"]]) - 1
np_dev_label <- as.integer(np_dev[["reftype"]]) - 1
np_test_label <- as.integer(np_test[["reftype"]]) - 1

# Position of the label column that is excluded from the feature matrices.
target <- 1

# catboost pools: features without the label column, labels passed separately.
np_train_pool <- catboost.load_pool(
  data = np_train[, -target],
  label = np_train_label
)
np_dev_pool <- catboost.load_pool(
  data = np_dev[, -target],
  label = np_dev_label
)
np_test_pool <- catboost.load_pool(
  data = np_test[, -target],
  label = np_test_label
)

## Running the NP model ----

# Fit on the train pool, with the dev pool passed as the evaluation pool and
# the parameter list `catboost_fit_params_best` defined outside this section.
# NOTE(review): whether catboost honours R's RNG seed is not visible here.
set.seed(42)
np_model <- catboost.train(
  learn_pool = np_train_pool,
  test_pool = np_dev_pool,
  params = catboost_fit_params_best
)

## get feature importance (this was used for ordering the features) ----

# Importance on the training pool as a table sorted from most to least
# important; feature names come from the row names.
np_varimp <- as.data.frame(
  catboost.get_feature_importance(model = np_model, pool = np_train_pool)
) %>%
  rownames_to_column() %>%
  rename(importance = V1) %>%
  arrange(desc(importance))

# Horizontal bar chart of the importance computed on the dev pool.
ggplot(
  catboost.get_feature_importance(np_model, np_dev_pool) %>%
    as.data.frame() %>%
    setNames("Importance") %>%
    rownames_to_column("Feature"),
  aes(x = reorder(Feature, Importance), y = Importance)
) +
  geom_bar(stat = "identity") +
  coord_flip()

## NP model predictions ----

# Class probabilities for the test pool, combined with the true labels by the
# project helper catboost_predictions().
np_pred <- catboost.predict(
  np_model,
  np_test_pool,
  prediction_type = "Probability"
)

np_predictions <- catboost_predictions(np_test_label, np_pred)

# Summary table produced by the project helper acc_table().
np_info <- acc_table(np_predictions)
np_info

# Create the output folder if it is missing so that the results of the
# (expensive) training run are not lost to a failing saveRDS().
dir.create("webnlg_np_results", showWarnings = FALSE, recursive = TRUE)

saveRDS(np_info, file = "webnlg_np_results/webnlg_np_stat.rds")

# Predictions side by side with the test features (label column dropped).
np_pred_with_feat <- cbind(np_predictions, np_test[, -1])

saveRDS(
  np_pred_with_feat,
  file = "webnlg_np_results/webnlg_np_pred_with_feat.rds"
)

## NP model generation ----

# For every test row, build a key "entity-class_prediction-feat1-...-feat6",
# look it up in the lookup tables from the most specific key (all six
# features) down to the least specific one (one feature), and use the first
# referring expression found. If nothing matches, fall back to the entity
# name with underscores replaced by spaces.
# NOTE(review): keys are truncated at the n-th "-", so an entity or feature
# value that itself contains "-" would shift the truncation points - confirm
# that this cannot occur in the data.
np_generation <- np_pred_with_feat %>%
  # Column 4 is the one referred to as `class_prediction` below; columns 6:11
  # are presumably the six feature columns appended from np_test - this
  # relies on the column layout returned by catboost_predictions().
  .[, c(4, 6:11)] %>%
  unite("concat_feat", 2:7, remove = TRUE, sep = "-") %>%
  cbind(entity_test) %>%
  mutate(
    seen = ifelse(entity %in% ent_ref_form_feat_6$entity, "seen", "unseen")
  ) %>%
  unite(
    "unique_entity_reftype_feat",
    c(entity, class_prediction, concat_feat),
    remove = FALSE,
    sep = "-"
  ) %>%
  # Progressively shorter keys: everything before the 3rd, 4th, ... 7th "-".
  #mutate(unique_entity_reftype_feat_0 = str_before_nth(unique_entity_reftype_feat, '-',2)) %>%
  mutate(unique_entity_reftype_feat_1 = str_before_nth(unique_entity_reftype_feat, "-", 3)) %>%
  mutate(unique_entity_reftype_feat_2 = str_before_nth(unique_entity_reftype_feat, "-", 4)) %>%
  mutate(unique_entity_reftype_feat_3 = str_before_nth(unique_entity_reftype_feat, "-", 5)) %>%
  mutate(unique_entity_reftype_feat_4 = str_before_nth(unique_entity_reftype_feat, "-", 6)) %>%
  mutate(unique_entity_reftype_feat_5 = str_before_nth(unique_entity_reftype_feat, "-", 7)) %>%
  # Each join appends one more refex column; the join order (6 down to 1)
  # therefore fixes the priority used by coalesce() below.
  left_join(., name_value_pairs_6, by = "unique_entity_reftype_feat") %>%
  left_join(., name_value_pairs_5, by = c("unique_entity_reftype_feat_5" = "unique_entity_reftype_feat")) %>%
  left_join(., name_value_pairs_4, by = c("unique_entity_reftype_feat_4" = "unique_entity_reftype_feat")) %>%
  left_join(., name_value_pairs_3, by = c("unique_entity_reftype_feat_3" = "unique_entity_reftype_feat")) %>%
  left_join(., name_value_pairs_2, by = c("unique_entity_reftype_feat_2" = "unique_entity_reftype_feat")) %>%
  left_join(., name_value_pairs_1, by = c("unique_entity_reftype_feat_1" = "unique_entity_reftype_feat")) %>%
  #left_join(.,name_value_pairs_0, by=c("unique_entity_reftype_feat_0" = "unique_entity_reftype_feat")) %>%
  # Turn the literal "NULL" left by unmatched joins into NA and re-derive the
  # column types. A formula lambda replaces the deprecated funs().
  mutate_all(~ type.convert(as.character(replace(., . == "NULL", NA)))) %>%
  # First non-missing value across all columns whose name contains "refex",
  # taken in column order.
  mutate(generated_text = do.call(coalesce, .[grepl("refex", names(.))]) %>% as.character()) %>%
  mutate(generated_text = ifelse(is.na(generated_text), gsub("_", " ", entity), generated_text)) %>%
  # Values that look like a deparsed vector ('c("a", "b")') are reduced to
  # their quoted content, joined with spaces and stripped of the quotes.
  mutate(generated_text = ifelse(grepl("c\\(", generated_text), str_extract(generated_text, '".*"'), generated_text)) %>%
  mutate(generated_text = gsub('", "', " ", generated_text)) %>%
  mutate(generated_text = gsub('"', "", generated_text)) %>%
  # np_test_id is a one-column table; take the ID vector itself so that ID is
  # a plain column rather than a nested data frame.
  mutate(ID = np_test_id$ID) %>%
  select(ID, class_prediction, entity, seen, generated_text) %>%
  mutate(original = refex_test$refex)


# Persist the full generation table.
saveRDS(np_generation, file = "webnlg_np_results/webnlg_np_generation.rds")

# Also write the generated expressions alone, one per line, as plain text.
# The output file is passed positionally: the argument is called `path` in
# older readr releases and `file` in newer ones (where `path` is deprecated),
# and the second position works with both.
write_delim(
  as.data.frame(np_generation$generated_text),
  "webnlg_np_results/webnlg_np.txt",
  delim = "",
  col_names = FALSE
)


# Restore the warning level saved in `oldw` at the top of the script; it must
# happen before the cleanup below removes `oldw`.
options(warn = oldw)

# Remove every object from the workspace except those whose name contains
# "np_generation". The name is matched as a fixed string: the former pattern
# "np_generation*" was a glob, which as a regular expression only made the
# final "n" optional.
toremove <- grep(
  "np_generation",
  ls(),
  fixed = TRUE,
  invert = TRUE,
  value = TRUE
)
rm(list = c(toremove, "toremove"))

