# remotes::install_github("rpkgs/gg.layers")
library(tidyverse)
library(gg.layers)
library(scales)
library(glue)
library(latex2exp)
library(cowplot)

# Data ----

# One bias CSV per model, keyed by model name.
files_bias = c(
  "glove-wiki-gigaword-300"="results/bias_glove-wiki-gigaword-300_FEMALE_MALE.csv",
  "word2vec-google-news-300"="results/bias_word2vec-google-news-300_FEMALE_MALE.csv"
)
# Read every file; lapply() carries the names of `files_bias` over to `dfs`.
dfs = lapply(
  files_bias,
  function(path) read_csv(path, show_col_types=FALSE)
)
# Add a factor column `bins` that groups rows by log10(reverse_freq_rank).
#
# Bin edges are 0, 1.5, 2, 2.5, ... in steps of 0.5, and the last edge is the
# largest observed log10 rank, so the top bin is usually narrower than 0.5.
# Modify the edges here if other frequency bins are needed.
#
# Args:
#   df: data frame (or tibble) with a numeric column `reverse_freq_rank`.
#
# Returns: `df` with the extra factor column `bins`. Rows whose log10 rank is
#   missing or not finite get NA.
add_frequency_bins = function(df) {
  log_freq = log10(df[["reverse_freq_rank"]])
  # Ignore NA / infinite values when locating the upper edge; they would
  # otherwise make max() unusable as a break.
  finite_freq = log_freq[is.finite(log_freq)]
  max_value = if (length(finite_freq) > 0) max(finite_freq) else 0
  if (max_value > 1.5) {
    inner = seq(1.5, max_value, by=0.5)
    # Drop a grid point that coincides (up to rounding) with the maximum:
    # cut() rejects duplicated breaks.
    inner = inner[inner < max_value - 1e-8]
    cuts = c(0, inner, max_value)
  } else {
    # Everything fits into the first bin; seq(1.5, max_value) would fail here.
    cuts = c(0, 1.5)
  }
  # print(cuts)
  df[["bins"]] = cut(log_freq, cuts, include.lowest=TRUE)
  return(df)
}

# Attach the frequency bins to every model's table (names are preserved).
dfs = lapply(dfs, add_frequency_bins)

# Explore ----

# Number of words per frequency bin for each model; useNA = "always" adds an
# explicit <NA> column so unbinned rows would be visible.
table(dfs[[1]][["bins"]], useNA = "always")
## 
##   [0,1.5]   (1.5,2]   (2,2.5]   (2.5,3]   (3,3.5]   (3.5,4]   (4,4.5]   (4.5,5] 
##        29        67       214       682      2155      6838     21621     68378 
##   (5,5.5] (5.5,5.6]      <NA> 
##    216227     83773         0
table(dfs[[2]][["bins"]], useNA = "always")
## 
##  [0,1.5]  (1.5,2]  (2,2.5]  (2.5,3]  (3,3.5]  (3.5,4]  (4,4.5]  (4.5,5] 
##       29       66      215      682     2155     6838    21621    68378 
##  (5,5.5]  (5.5,6] (6,6.48]     <NA> 
##   216227   683773  2000000        0

# Plots ----

# Boxplots ----

# Boxplots of `bias` per frequency bin, with the bin mean marked as a point
# and, optionally, a per-bin effect size printed above each box.
#
# Args:
#   df: data frame with a factor column `bins` (interval labels as produced by
#     cut(), e.g. "(1.5,2]") and a numeric column `bias`.
#   xlab, ylab, title, subtitle: forwarded to labs().
#   effect_sizes: if TRUE, annotate every bin with mean(bias) / sd(bias),
#     rounded to two decimals.
#
# Returns: a ggplot object.
boxplots_plt = function(df, xlab=NULL, ylab="Female bias", title=NULL,
                        subtitle=NULL, effect_sizes=FALSE) {
  # bin labels: rewrite each interval label so that both endpoints are typeset
  # as powers of ten, e.g. "(1.5,2]" becomes "($10^{1.5},10^{2}$\]".
  labs = levels(df[["bins"]])
  labs = str_replace(labs, ",", "},10^{")
  labs = str_replace(labs, "([\\[\\(])", "\\1$10^{")
  labs = str_replace(labs, "([\\]\\)])", "}$\\1")
  # Square brackets at either end get a backslash prefix
  # (presumably required by latex2exp -- TODO confirm).
  labs = str_replace(labs, "(\\])$", r"(\\])")
  labs = str_replace(labs, "^(\\[)", r"(\\[)")
  labs = lapply(sprintf(r'(%s)', labs), TeX)
  labs = unlist(labs)
  # plot
  p = ggplot(df, aes(x=bins, y=bias)) +
    gg.layers::geom_boxplot2(
      width.errorbar=0.2, fill="lightblue", color="black") +
    # stat_summary() defaults to a pointrange, which needs ymin/ymax; only the
    # mean is computed here, so draw a plain point (size 2 matches the dot a
    # default pointrange would draw) instead of triggering a
    # "missing values (geom_segment)" warning for every bin.
    stat_summary(fun="mean", geom="point", size=2, color="navy") +
    geom_hline(yintercept=0, color="black", linetype="dashed") +
    # facet_wrap(vars(model), scales="free", ncol=1) +
    labs(x=xlab, y=ylab, title=title, subtitle=subtitle) +
    scale_x_discrete(labels=labs, guide=guide_axis(angle=35)) +
    theme_minimal() +
    theme(
      axis.title.x=element_text(size=15), axis.title.y=element_text(size=16),
      axis.text=element_text(size=14), strip.text=element_text(size=15),
      plot.subtitle=element_text(size=16)
      ) +
    NULL
  if (effect_sizes) {
    # na.rm=TRUE: ggplot drops missing bias values from the boxes, so the
    # printed effect size should be based on the same observations.
    df_effect_sizes = df %>%
      group_by(bins) %>%
      summarise(
        mean_bias = mean(bias, na.rm=TRUE),
        sd_bias = sd(bias, na.rm=TRUE)) %>%
      mutate(ef = mean_bias / sd_bias)
    # Place the labels 7% of the current y-range above the plotted content.
    y_limits = layer_scales(p)$y$get_limits()
    y_adj = (y_limits[2] - y_limits[1]) * 0.07
    y_text = y_limits[2] + y_adj
    p = p +
      geom_text(
        data=df_effect_sizes,
        aes(x=bins, label=round(ef, 2)), y=y_text, color="navy", size=5) +
      # Zoom with coord_cartesian() rather than lims(): scale limits turn
      # observations outside the range into NA before the statistics run,
      # which would change the boxes and means of the bins.
      coord_cartesian(ylim=c(y_limits[1], y_text))
  }
  return(p)
}
# One figure per model; only the last one (bottom panel of the final grid)
# carries the x-axis title.
last_name = tail(names(dfs), 1)

plot_list = list()
for (n in names(dfs)) {
  x_label = if (n == last_name) "Frequency rank" else NULL
  p = boxplots_plt(
    dfs[[n]],
    xlab=x_label,
    ylab=glue("Female bias\n({n})"),
    effect_sizes=TRUE
  )
  plot_list[[n]] = p
  print(p)
}
## Warning: Removed 985 rows containing non-finite values (stat_boxplot).
## Warning: Removed 985 rows containing non-finite values (stat_summary).
## Warning: Removed 10 rows containing missing values (geom_segment).

## Warning: Removed 24639 rows containing non-finite values (stat_boxplot).
## Warning: Removed 24639 rows containing non-finite values (stat_summary).
## Warning: Removed 11 rows containing missing values (geom_segment).

# Stack the per-model figures in a single column and write them to one PNG.
grid = plot_grid(plotlist=plot_list, ncol=1)
## Warning: Removed 985 rows containing non-finite values (stat_boxplot).
## Warning: Removed 985 rows containing non-finite values (stat_summary).
## Warning: Removed 10 rows containing missing values (geom_segment).
## Warning: Removed 24639 rows containing non-finite values (stat_boxplot).
## Warning: Removed 24639 rows containing non-finite values (stat_summary).
## Warning: Removed 11 rows containing missing values (geom_segment).
outname = glue("results/plots/boxplots_pretrained.png")
save_plot(
  filename=outname,
  plot=grid,
  base_height=8.5,
  base_width=7,
  dpi=300
)

# NOTE: we don't use facet_wrap because it is hard to combine with geom_text
# # one DF with all models:
# df_full = dfs %>% bind_rows(.id="model")
# 
# # One plot for all models with facet_wrap:
# p = boxplots_plt(df_full)
# outname = glue("results/plots/boxplots_pretrained.png")
# ggsave(outname, p, width=8, height=2*5, dpi=300)
# print(p)