---
title: "Advanced flashier"
output:
  rmarkdown::html_vignette:
    toc: true
vignette: >
  %\VignetteIndexEntry{Advanced flashier}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE, collapse = TRUE, comment = "#>",
                      fig.width = 6, fig.height = 6, warning = FALSE)
```

In addition to `flashier` and `ggplot2`, we will make use of package `cowplot` for arranging plots into grids and package `dplyr` for wrangling fitting progress data:

```{r load_packages, message=FALSE}
library(flashier)
library(ggplot2)
library(cowplot)
library(dplyr)
```

## Pipeable interface

In addition to the main `flash()` function, `flashier` provides a collection of pipeable `flash_xxx()` functions. Any fit produced via `flash()` can be rewritten using these functions, which make the order of fitting operations more explicit and also provide many more options for customization. For example, the following are equivalent:

```{r adv_interface}
# # Basic interface (not run):
# fit_backfit <- flash(
#   gtex,
#   greedy_Kmax = 5,
#   var_type = 2,
#   backfit = TRUE,
#   verbose = 0
# )

# Pipeable interface:
t_backfit <- system.time(
  fit_backfit <- flash_init(gtex, var_type = 2) |>
    flash_set_verbose(verbose = 0) |>
    flash_greedy(Kmax = 5) |>
    flash_backfit() |>
    flash_nullcheck()
)
```

Function `flash_init()` sets up the `flash` object and handles global parameter `var_type`; `flash_set_verbose()` manages the output that will be printed to console; and `flash_greedy()`, `flash_backfit()`, and `flash_nullcheck()` perform the greedy fit, backfit, and nullcheck described in the introductory vignette.

In many scenarios, the functionality provided by `flash()` will be sufficient. Other scenarios will require the additional flexibility afforded by the pipeable interface, either to achieve specific goals or to produce multiple alternative fits via non-default settings.
In particular, successfully fitting very large datasets is often a trial-and-error process of tinkering and customization. Below, we describe some tasks made possible by the pipeable interface that we have found particularly useful in our own applications. All examples use the same `gtex` dataset that was used in the introductory vignette.

## Customizing the order of operations

Since the pipeable interface modularizes operations, it is possible to perform multiple intermediary backfits and nullchecks. Often, re-arranging the order of operations will produce a different fit, either for better or for worse:

```{r cust_order_ops}
# Pipeable interface:
fit_multiple_backfits <- flash_init(gtex, var_type = 2) |>
  flash_set_verbose(verbose = 0) |>
  flash_greedy(Kmax = 3) |>
  flash_backfit() |>
  flash_nullcheck() |>
  flash_greedy(Kmax = 2) |>
  flash_backfit() |>
  flash_nullcheck()

c(one_bf_elbo = fit_backfit$elbo, two_bf_elbo = fit_multiple_backfits$elbo)
```

Here, we do not obtain an improvement in ELBO, but results will vary from one scenario to the next.

## An alternative backfitting method

Function `flash()` always uses the greedy algorithm to initialize the backfitting algorithm. By using function `flash_factors_init()` within a pipeline, we can instead initialize factor/loadings pairs all at once via `svd()` (or any other method) and then backfit. As its argument, `flash_factors_init()` takes either a list of two matrices (interpreted as $L$ and $F$) or an "SVD-like object" (that is, a list containing fields `u`, `d`, and `v`), so that the output of function `svd()` can be passed in directly:

```{r init.factors}
fit_alternative_backfit <- flash_init(gtex, var_type = 2) |>
  flash_set_verbose(verbose = 0) |>
  flash_factors_init(svd(gtex, nu = 5, nv = 5)) |>
  flash_backfit(verbose = 0)

c(bf_elbo = fit_backfit$elbo, alt_bf_elbo = fit_alternative_backfit$elbo)
```

Again, we fail to obtain an improvement in ELBO in this particular scenario.

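
The two input formats described above can be related to one another using base R alone. The following is a minimal sketch under stated assumptions: `Y` is a small simulated matrix standing in for `gtex`, and the singular values are absorbed into the loadings, which is an arbitrary choice made for illustration and not necessarily how `flashier` scales an SVD-like object internally.

```r
set.seed(1)
# Simulated stand-in for the data matrix (hypothetical; not the gtex data).
Y <- matrix(rnorm(200 * 12), nrow = 200, ncol = 12)

K <- 5
svd_res <- svd(Y, nu = K, nv = K)  # "SVD-like object" with fields u, d, and v

# Explicit (L, F) pair. svd() returns all singular values in d even when
# nu and nv are restricted, so d is truncated to its first K entries here.
L_init <- svd_res$u %*% diag(svd_res$d[1:K])  # 200 x 5 matrix of loadings
F_init <- svd_res$v                           # 12 x 5 matrix of factors
init_list <- list(L_init, F_init)

# Both representations describe the same rank-K approximation of Y.
approx_svd  <- svd_res$u %*% diag(svd_res$d[1:K]) %*% t(svd_res$v)
approx_list <- init_list[[1]] %*% t(init_list[[2]])
max_diff <- max(abs(approx_svd - approx_list))
```

A list such as `init_list` could then be supplied to `flash_factors_init()` in place of the `svd()` output, which is convenient when the initialization comes from a method other than SVD.
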
## Troubleshooting: turning off extrapolation

To accelerate backfits, `flashier` uses an "extrapolation" technique inspired by Ang and Gillis (2018); for details, see Willwerscheid (2021). While this can dramatically improve run time for large datasets, it can be finicky and occasionally results in errors that are difficult to track down. When odd errors are generated, we recommend turning off extrapolation as a first troubleshooting step. To do so, set `extrapolate = FALSE` in the call to `flash_backfit()`:

```{r no.extrap}
t_no_extrapolate <- system.time(
  fit_no_extrapolate <- flash_init(gtex, var_type = 2) |>
    flash_set_verbose(verbose = 0) |>
    flash_greedy(Kmax = 5) |>
    flash_backfit(extrapolate = FALSE) |>
    flash_nullcheck()
)

c(extrapolate_elbo = fit_backfit$elbo, no_extrapolate_elbo = fit_no_extrapolate$elbo)
```

Here, `flashier` appears to find the same solution with and without extrapolation, but without the benefit of extrapolation there is a large increase in run time:

```{r no.extrap.time}
c(t_extrapolate = t_backfit[3], t_no_extrapolate = t_no_extrapolate[3])
```

## Adding an intercept

In many applications, it is useful to add an "intercept" term to account for, say, differences in mean values from row to row.
Here, an intercept might be used to capture effects that are exactly equally shared across all tissues (similar to what is captured by the first factor from previous fits, except that factor values are constrained to be exactly equal rather than estimated as approximately so):

```{r intercept}
fit_with_intercept <- flash_init(gtex, var_type = 2) |>
  flash_set_verbose(verbose = 0) |>
  flash_add_intercept(rowwise = FALSE) |>
  flash_greedy(Kmax = 4) |>
  flash_backfit() |>
  flash_nullcheck()

p1 <- plot(
  fit_backfit,
  pm_which = "factors",
  pm_colors = gtex_colors,
  plot_type = "bar"
) + ggtitle("No intercept")
p2 <- plot(
  fit_with_intercept,
  pm_which = "factors",
  pm_colors = gtex_colors,
  plot_type = "bar"
) + ggtitle("With intercept")
plot_grid(p1, p2, nrow = 2)
```

In essence, `flash_add_intercept()` is a convenience function that initializes the values of the factor or loadings at one via `flash_factors_init()` and then fixes those values using `flash_factors_fix()`. Thus we could achieve the above fit as follows:

```{r fix.mean, eval = FALSE}
ones <- matrix(1, nrow = ncol(gtex), ncol = 1)
init_loadings <- matrix(rowMeans(gtex), ncol = 1)

fit_with_intercept <- flash_init(gtex, var_type = 2) |>
  flash_set_verbose(0) |>
  flash_factors_init(list(init_loadings, ones)) |>
  flash_factors_fix(kset = 1, which_dim = "factors") |>
  flash_greedy(Kmax = 4) |>
  flash_backfit()
```

## Fixed sparsity patterns

Many more options are possible using `flash_factors_init()` in conjunction with `flash_factors_fix()`.
For example, after adding an intercept, we could explicitly add a brain-specific factor by constraining the values for non-brain tissues to be zero and allowing values for brain tissues to be estimated:

```{r fixed.sprs}
is_brain <- grepl("Brain", colnames(gtex))
init_loadings <- rowMeans(gtex[, is_brain]) - rowMeans(gtex[, !is_brain])

fit_fixed_pattern <- flash_init(gtex, var_type = 2) |>
  flash_set_verbose(0) |>
  flash_add_intercept(rowwise = FALSE) |>
  flash_factors_init(list(matrix(init_loadings, ncol = 1),
                          matrix(is_brain, ncol = 1))) |>
  flash_factors_fix(kset = 2, which_dim = "factors", fixed_idx = !is_brain) |>
  flash_greedy(Kmax = 3) |>
  flash_backfit()

plot(
  fit_fixed_pattern,
  pm_which = "factors",
  pm_colors = gtex_colors,
  plot_type = "bar"
)
```

## Modifying the convergence criterion and verbose output

By default, the greedy and backfitting algorithms terminate when the variational lower bound on the log likelihood (ELBO) increases by no more than $np \sqrt{\epsilon}$ from one iteration to the next (where $\epsilon$ denotes machine epsilon). The convergence criterion can be changed using function `flash_set_conv_crit()`.

If we were primarily interested in factor values rather than loadings, then we might like to terminate when their normalized absolute values no longer change by more than, say, 0.001. To confirm that this criterion is respected, we also modify the output printed to console using function `flash_set_verbose()`:

```{r conv.crit}
gtex_conv_crit <- flash_init(gtex, var_type = 2) |>
  flash_set_conv_crit(fn = flash_conv_crit_max_chg_F, tol = .001) |>
  flash_set_verbose(
    fns = c(flash_verbose_elbo, flash_verbose_max_chg_F),
    colnames = c("ELBO", "Max.Chg.Factors"),
    colwidths = c(18, 18)
  ) |>
  flash_greedy(Kmax = 3) |>
  flash_backfit()
```

Note that `flash_set_conv_crit()` and `flash_set_verbose()` both take functions as arguments.
Several functions `flash_conv_crit_xxx()` are provided as alternative convergence criteria, and similar functions `flash_verbose_xxx()` simplify the customization of verbose output. As we demonstrate in the following section, it is also possible to write functions from scratch.

## Writing custom convergence criteria and verbose output functions using flash_fit objects

Custom functions for `flash_set_conv_crit()` and `flash_set_verbose()` require working with `flash_fit` objects, which are much less friendly than their `flash` counterparts. To ease use, `flashier` provides a number of accessor functions `flash_fit_get_xxx()` as well as methods `fitted()`, `residuals()`, and `ldf()` (see the documentation in `?flash_fit` for a full list of helper functions).

Any custom function must take three parameters as input: `curr` (the current `flash_fit` object); `prev` (the `flash_fit` object from the previous iteration); and `k` (which gives the index of the factor currently being optimized by `flash_backfit()` when `extrapolate = FALSE`; if extrapolation has not been turned off, then `k` can safely be ignored).

For example, let's say that we would like to monitor the sparsity of factors 2-5 (which we define as the mixture weight of the point mass $\pi_0$ in estimates of priors $g_f$) over the course of a backfit.
We use the following custom functions:

```{r custom}
verbose_sparsity <- function(new, old, k, f_idx) {
  g <- flash_fit_get_g(new, n = 2) # setting n = 2 gets g_f (n = 1 would get g_\ell)
  pi0 <- g[[f_idx]]$pi[1] # return point mass weight
  return(formatC(pi0, format = "f", digits = 3))
}
verbose_sprs2 <- function(new, old, k) verbose_sparsity(new, old, k, 2)
verbose_sprs3 <- function(new, old, k) verbose_sparsity(new, old, k, 3)
verbose_sprs4 <- function(new, old, k) verbose_sparsity(new, old, k, 4)
verbose_sprs5 <- function(new, old, k) verbose_sparsity(new, old, k, 5)

fit_monitor_sparsity <- flash_init(gtex, var_type = 2) |>
  flash_set_verbose(0) |>
  flash_greedy(Kmax = 5) |>
  flash_set_verbose(
    verbose = 3,
    fns = c(flash_verbose_elbo, verbose_sprs2, verbose_sprs3, verbose_sprs4, verbose_sprs5),
    colnames = c("ELBO", paste0("Sparsity (", 2:5, ")")),
    colwidths = rep(14, 5)
  ) |>
  flash_backfit()
```

## Writing custom EBNM functions

Custom EBNM functions may also be created when specialized prior families are required. Often it is sufficient to use the helper function `flash_ebnm()` to pass non-default arguments to function `ebnm()` in package `ebnm`. For example, we might choose to put a normal prior on the first factor with mode to be estimated (since we do not expect the first factor to be sparse):

```{r normal.est.mode}
fit_flash_ebnm <- flash_init(gtex, var_type = 2) |>
  flash_set_verbose(0) |>
  flash_greedy(ebnm_fn = flash_ebnm(prior_family = "normal", mode = "estimate")) |>
  flash_greedy(Kmax = 4, ebnm_fn = ebnm_point_normal)

fit_flash_ebnm$F_ghat[[1]]
```

For cases where `flash_ebnm()` is not sufficient, completely custom functions can also be created; for details, see the documentation in `?flash_ebnm`. To ensure that the return object is correctly formatted, we recommend calling into function `ebnm()` (with, perhaps, `fix_g = TRUE`) before returning.
For the sake of illustration, we create an EBNM function where the prior family $\mathcal{G}$ is the family of two-component distributions where one component is a point mass at zero and the other is a normal distribution (not necessarily centered at zero):

```{r custom.ebnm}
ebnm_custom <- function(x, s, g_init, fix_g, output) {
  if (fix_g) {
    ebnm_res <- ebnm_ash(
      x, s, g_init = g_init, fix_g = TRUE, output = output,
      mixcompdist = "normal"
    )
  } else {
    # Parameters are:
    #   1. mean of normal component
    #   2. sd of normal component
    neg_llik <- function(par) {
      g <- ashr::normalmix(c(0.5, 0.5), c(0, par[1]), c(0, par[2]))
      ebnm_res <- ebnm_ash(
        x, s, g_init = g, fix_g = FALSE, mixcompdist = "normal"
      )
      return(-ebnm_res$log_likelihood)
    }

    # Run optim to get mean and sd of normal component:
    opt_res <- optim(
      par = c(0, 1), # Initial values
      fn = neg_llik,
      method = "L-BFGS-B",
      lower = c(-Inf, 0.01),
      upper = c(Inf, Inf)
    )

    # Now re-run ash to get mixture weights:
    opt_par <- opt_res$par
    g <- ashr::normalmix(c(0.5, 0.5), c(0, opt_par[1]), c(0, opt_par[2]))
    ebnm_res <- ebnm_ash(
      x, s, g_init = g, fix_g = FALSE, output = output,
      mixcompdist = "normal"
    )
  }

  return(ebnm_res)
}

fit_custom <- flash_init(gtex, var_type = 2) |>
  flash_set_verbose(0) |>
  flash_greedy(
    Kmax = 2,
    ebnm_fn = c(ebnm_point_normal, ebnm_custom)
  )

fit_custom$F_ghat
```

## A recipe for plotting fitting progress

Setting `verbose = -1` outputs a single tab-delimited table of values that makes it straightforward to analyze fitting progress. The code below backfits with and without extrapolation and then compares the per-iteration ELBO for each fit. (Since `sink()` does not play well with R Markdown, this code is not evaluated.)

```{r plot.history, eval = FALSE}
# Backfit with extrapolation, diverting console output to a file:
sink("zz.tsv")
tmp <- flash_init(gtex, var_type = 2) |>
  flash_set_verbose(-1) |>
  flash_factors_init(svd(gtex, nu = 5, nv = 5)) |>
  flash_backfit()
sink() # close the sink first so that all output is flushed before reading
progress_extrapolate <- read.delim("zz.tsv")

# Backfit without extrapolation:
sink("zz.tsv")
tmp <- flash_init(gtex, var_type = 2) |>
  flash_set_verbose(-1) |>
  flash_factors_init(svd(gtex, nu = 5, nv = 5)) |>
  flash_backfit(extrapolate = FALSE)
sink()
progress_no_extrapolate <- read.delim("zz.tsv")

rm(tmp)
file.remove("zz.tsv")

progress_extrapolate <- progress_extrapolate |>
  mutate(Extrapolate = TRUE) |>
  select(Iter, ELBO, Extrapolate)

# Without extrapolation, keep the best ELBO reported within each iteration:
progress_no_extrapolate <- progress_no_extrapolate |>
  group_by(Iter) |>
  summarize(ELBO = max(ELBO, na.rm = TRUE)) |>
  ungroup() |>
  mutate(Extrapolate = FALSE)

tib <- progress_extrapolate |>
  bind_rows(progress_no_extrapolate) |>
  mutate(Iter = as.numeric(Iter), ELBO = as.numeric(ELBO))

ggplot(tib, aes(x = Iter, y = ELBO, col = Extrapolate)) +
  geom_line() +
  theme_minimal()
```

## Session information

The following R version and packages were used to generate this vignette:

```{r}
sessionInfo()
```