Gene regression

This set of classes/functions can be used to predict the gene expression profile based on:

  1. the cell-type and

  2. cell-specific covariates (such as location in space and local micro-environment).

The intended use case is to test whether the semantic features extracted by one of the Self-Supervised Learning (SSL) algorithms contain biologically relevant information.

We treat each cell-type separately. For each cell, the gene counts are modelled as a multinomial distribution:

\[c \sim \frac{N!}{n_1! \, n_2! \cdots n_G!} \, p_1^{n_1} p_2^{n_2} \cdots p_G^{n_G}\]

where \(N=\sum_{g=1}^G n_g\) is the total number of counts in the cell (sometimes referred to as the total UMI count) and \(p_g\), with \(\sum_{g=1}^G p_g=1\), is the probability of measuring gene \(g\). When \(N\) is large and the \(p_g\) are small, the counts for each gene can be approximated by a Poisson distribution with rate \(r_g = N p_g\). Therefore the counts for cell \(n\) and gene \(g\) are modelled as:

\[c_{ng} \sim \text{Poi}( r_{ng} = N_n \, p_{ng})\]
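
As a quick numerical sanity check of this approximation, the following sketch (illustrative only, not part of the library) draws from both distributions:

>>> import torch
>>> G, N = 1000, 10000                        # number of genes, total UMI count
>>> p = torch.softmax(torch.randn(G), dim=0)  # gene probabilities summing to 1
>>> multinomial_counts = torch.distributions.Multinomial(total_count=N, probs=p).sample()
>>> poisson_counts = torch.distributions.Poisson(rate=N * p).sample()
>>> # for large N and small p the two sampling schemes give nearly identical marginals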

To account for noise and the presence of L (cell-specific) covariates, we model the probability as:

\[\log p_{ng} = \left( \beta_g^0 + \sum_l \beta_{lg} X_{nl} \right) + \epsilon_g\]

where \(\beta_g^0\) is a gene-specific intercept, \(X_{nl}\) are the cell covariates and \(\epsilon_g \sim N(0,\sigma_g)\) is a noise term representing the gene-specific over-dispersion.

We recap the dimensions of the variables involved in the full model (for K different cell-types):

  1. \(X_{nl}\) is a fixed covariate matrix of shape \(N \times L\) (i.e. cells by covariates)

  2. \(N_n\) is a fixed vector of shape \(N\) with the (observed) total counts in a cell.

  3. \(\beta_{kg}^0\) are the intercepts of the regression, of shape \(K \times G\) (i.e. cell-types by genes)

  4. \(\beta_{klg}\) are the regression coefficients of shape \(K \times L \times G\) (i.e. cell-types by covariates by genes)

  5. \(\sigma_{kg}\) are the gene over-dispersions, of shape \(K \times G\) (i.e. cell-types by genes)

Typical values are \(N \sim 10^3\), \(G \sim 10^3\), \(K \sim 10\), \(L \sim 50\). The goal of the inference is to determine \(\beta^0\), \(\beta\) and \(\sigma\). We enforce a penalty (either L1 or L2) on the regression coefficients \(\beta\) to encourage them to be small. We put a flat prior on \(\sigma\), which is allowed to vary in a small (predefined) range. There is no prior on \(\beta^0\). Overall the model has two hyper-parameters (the strength of the regularization on \(\beta\) and the allowed range for \(\sigma\)), which are determined by cross-validation.
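
To make the shapes and the generative process concrete, here is a minimal simulation sketch for a single cell-type (all names and values are illustrative, not part of the library):

>>> import torch
>>> N, G, L = 1000, 1000, 50                      # cells, genes, covariates
>>> X_nl = torch.randn(N, L)                      # covariate matrix
>>> beta0_g = torch.randn(G)                      # gene-specific intercepts
>>> beta_lg = 0.1 * torch.randn(L, G)             # regression coefficients
>>> sigma_g = 0.5 * torch.ones(G)                 # gene over-dispersions
>>> eps_ng = sigma_g * torch.randn(N, G)          # noise term, eps ~ N(0, sigma)
>>> log_p_ng = beta0_g + X_nl @ beta_lg + eps_ng
>>> log_p_ng = log_p_ng - torch.logsumexp(log_p_ng, dim=-1, keepdim=True)  # normalize so probabilities sum to 1
>>> total_umi_n = torch.randint(1000, 5000, (N, 1))  # total counts N_n
>>> rate_ng = total_umi_n * torch.exp(log_p_ng)      # r_ng = N_n * p_ng
>>> counts_ng = torch.distributions.Poisson(rate_ng).sample()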

See notebook3 for an example.

class GeneDataset(covariates: torch.Tensor, cell_type_ids: torch.Tensor, counts: torch.Tensor, k_cell_types: int, cell_type_mapping: dict, gene_names: List[str])[source]

Container for organizing the gene expression data

cell_type_ids: torch.Tensor

long tensor with the cell_type_ids of shape (n)

cell_type_mapping: dict

dictionary with mapping from unique_cell_type to cell_type_ids

counts: torch.Tensor

long tensor with the count data of shape (n, g)

covariates: torch.Tensor

float tensor with the covariates of shape (n, l) (i.e. cells by covariates)

describe()[source]

Method which describes the content of the GeneDataset.

gene_names: List[str]

list of the gene names

k_cell_types: int

number of cell types
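
Example

A minimal sketch of assembling a GeneDataset by hand with random data (all values are illustrative):

>>> import torch
>>> n, g, l, k = 100, 50, 10, 3
>>> dataset = GeneDataset(
>>>     covariates=torch.randn(n, l),
>>>     cell_type_ids=torch.randint(0, k, (n,)),
>>>     counts=torch.randint(0, 100, (n, g)),
>>>     k_cell_types=k,
>>>     cell_type_mapping={"type_a": 0, "type_b": 1, "type_c": 2},
>>>     gene_names=["gene_{}".format(i) for i in range(g)])
>>> dataset.describe()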

make_gene_dataset_from_anndata(anndata, cell_type_key, covariate_key, preprocess_strategy='raw', apply_pca=False, n_components=0.9)[source]

Convert an AnnData object into a GeneDataset object which can be used for gene regression.

Parameters
  • anndata (AnnData) – AnnData object with the raw counts stored in anndata.X

  • cell_type_key (str) – key corresponding to the cell type, i.e. cell_types = anndata.obs[cell_type_key]

  • covariate_key (str) – key corresponding to the covariate, i.e. covariates = anndata.obsm[covariate_key]

  • preprocess_strategy (str) – either ‘center’, ‘z_score’ or ‘raw’. It describes how to preprocess the covariates. ‘raw’ (default) means no preprocessing.

  • apply_pca (bool) – if True, we compute the pca of the covariates. This operation happens after the preprocessing.

  • n_components (Union[int, float]) – Used only if apply_pca == True. If integer specifies the dimensionality of the data after PCA. If float in (0, 1) it auto selects the dimensionality so that the explained variance is at least that value.

Returns

GeneDataset – a GeneDataset object

Return type

GeneDataset
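
Example

A minimal usage sketch (the adata variable and the keys are illustrative; use the ones from your AnnData object):

>>> dataset = make_gene_dataset_from_anndata(
>>>     anndata=adata,
>>>     cell_type_key="cell_type",
>>>     covariate_key="ssl_features",
>>>     preprocess_strategy="z_score",
>>>     apply_pca=True,
>>>     n_components=0.9)
>>> dataset.describe()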

train_test_val_split(data, train_size=0.8, test_size=0.15, val_size=0.05, n_splits=1, random_state=None, stratify=True)[source]

Utility function used to split the data into train/test/val.

Parameters
  • data (Union[List[Tensor], List[ndarray], GeneDataset]) – the data to split into train/test/val

  • train_size (float) – the relative size of the train dataset

  • test_size (float) – the relative size of the test dataset

  • val_size (float) – the relative size of the val dataset

  • n_splits (int) – how many times to split the data

  • random_state (Optional[int]) – specify the random state for reproducibility

  • stratify (bool) – If True the train/test/val splits are stratified so that they contain approximately the same number of examples from each class. If data is a list of arrays, the 2nd array is assumed to represent the class. If data is a GeneDataset, the class is the cell_type.

Returns

tuple – yields multiple splits of the data.

Example

>>> for train, test, val in train_test_val_split(data=[x,y,z]):
>>>       x_train, y_train, z_train = train
>>>       x_test, y_test, z_test = test
>>>       x_val, y_val, z_val = val
>>>       ... do something ...

Example

>>> # dataset is a GeneDataset instance
>>> for train, test, val in train_test_val_split(data=dataset):
>>>       assert isinstance(train, GeneDataset)
>>>       assert isinstance(test, GeneDataset)
>>>       assert isinstance(val, GeneDataset)
>>>       ... do something ...

plot_gene_hist(cell_types_n, value1_ng, value2_ng=None, bins=20)[source]

Plot the per cell-type histogram. If value2_ng is defined, the two histograms are interleaved.

Parameters
  • cell_types_n – tensor of shape N with the cell type labels (with K distinct values)

  • value1_ng – the first quantity whose histogram is computed, of shape (N, G)

  • value2_ng – the second quantity to plot of shape (N,G) (optional)

  • bins – number of bins in the histogram

Returns

fig – A figure with G rows and K columns where K is the number of distinct cell types.

Return type

Figure
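
Example

A minimal usage sketch (the value tensors are illustrative, e.g. observed counts and a posterior sample):

>>> fig = plot_gene_hist(
>>>     cell_types_n=dataset.cell_type_ids,
>>>     value1_ng=observed_counts_ng,
>>>     value2_ng=predicted_counts_ng,
>>>     bins=30)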

class GeneRegression[source]

Given the cell-type labels and some covariates the model predicts the gene expression. The counts are modelled as a LogNormalPoisson process. See documentation for more details.

configure_optimizer(optimizer_type='adam', lr=0.005, betas=(0.9, 0.999), momentum=0.9, alpha=0.99)[source]

Configure the optimizer to use.

Parameters
  • optimizer_type (str) – Either ‘adam’ (default), ‘sgd’ or ‘rmsprop’

  • lr (float) – learning rate

  • betas (Tuple[float, float]) – betas for ‘adam’ optimizer. Ignored if optimizer_type is not ‘adam’.

  • momentum (float) – momentum for ‘sgd’ optimizer. Ignored if optimizer_type is not ‘sgd’.

  • alpha (float) – alpha for ‘rmsprop’ optimizer. Ignored if optimizer_type is not ‘rmsprop’.
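
Example

A minimal usage sketch (values are illustrative):

>>> gr = GeneRegression()
>>> gr.configure_optimizer(optimizer_type="adam", lr=0.001, betas=(0.9, 0.999))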

extend_train(dataset, n_steps=2500, print_frequency=50)[source]

Utility method which calls train() with the same parameters just used, effectively extending the training.

Parameters
  • dataset (GeneDataset) – Dataset to train the model on

  • n_steps (int) – number of training steps

  • print_frequency (int) – how frequently to print loss to screen

extend_train_and_test(train_dataset, test_dataset, test_num_samples=10, train_steps=2500, train_print_frequency=50)[source]

Utility method which sequentially calls the methods extend_train() and predict().

Parameters
  • train_dataset (GeneDataset) – Dataset to train the model on

  • test_dataset (GeneDataset) – Dataset to run the prediction on

  • test_num_samples (int) – how many random samples to draw from the predictive distribution

  • train_steps (int) – number of training steps

  • train_print_frequency (int) – how frequently to print loss to screen during training

Returns

metrics – See predict().

Return type

(DataFrame, DataFrame)

get_params()[source]

Returns

df – dataframes with the fitted parameters (beta0, beta and eps).

Note

This method can be used in combination with load_ckpt() to inspect the fitted parameters of a previous run.

Examples

>>> gr = GeneRegression()
>>> gr.load_ckpt(filename="my_old_ckpt.pt")
>>> df_beta0, df_beta, df_eps = gr.get_params()
>>> df_beta0.head()

Return type

(DataFrame, DataFrame, DataFrame)

load_ckpt(filename, map_location=None)[source]

Load the full state of the model and optimizer from disk. Use it in pair with save_ckpt().

property optimizer: pyro.optim.PyroOptim

The optimizer associated with this model.

Return type

PyroOptim

predict(dataset, num_samples=10, subsample_size_cells=None, subsample_size_genes=None)

Use the parameters currently in the param_store to run the prediction and report some metrics. If you want to run the prediction based on a different set of parameters you need to call load_ckpt() first.

The Q metric is \(Q = E\left[|X_{i,g} - Y_{i,g}|\right]\), where \(X\) is the (observed) data, \(Y\) is a sample from the predicted posterior, and \((i,g)\) index cells and genes respectively.

The log_score metric is \(\text{log_score} = \log p_\text{posterior}\left(X_\text{data}\right)\)
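
For intuition, the Q metric for a single posterior sample could be computed as follows (an illustrative sketch, not the library implementation):

>>> import torch
>>> X_ng = torch.randint(0, 20, (100, 50)).float()  # observed counts (illustrative)
>>> Y_ng = torch.randint(0, 20, (100, 50)).float()  # one posterior sample (illustrative)
>>> Q_per_gene = (X_ng - Y_ng).abs().mean(dim=0)    # absolute error averaged over cells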

Parameters
  • dataset (GeneDataset) – the dataset to run the prediction on

  • num_samples (int) – how many random samples to draw from the predictive distribution

  • subsample_size_cells (Optional[int]) – if not None (default is None) the predictions are made in chunks over cells to avoid memory issues

  • subsample_size_genes (Optional[int]) – if not None (default is None) the predictions are made in chunks over genes to avoid memory issues

Returns
  • df_metric – For each cell_type and gene we report the Q and log_score metrics

  • df_counts – For each cell and gene we report the observed counts and a single sample from the posterior

Return type

(DataFrame, DataFrame)
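
Example

A minimal usage sketch (assuming a trained GeneRegression instance gr and a GeneDataset test_dataset):

>>> df_metric, df_counts = gr.predict(
>>>     dataset=test_dataset,
>>>     num_samples=10,
>>>     subsample_size_cells=500)
>>> df_metric.head()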

static remove_params(beta0=False, beta=False, eps=False)[source]

Selectively remove parameters from param_store.

Parameters
  • beta0 (bool) – If True (default is False) remove beta0 of shape \((K, G)\) from the param_store

  • beta (bool) – If True (default is False) remove beta of shape \((K, L, G)\) from the param_store

  • eps (bool) – If True (default is False) remove eps of shape \((K, G)\) from the param_store

Note

This is useful in combination with load_ckpt() and train(). For example, you might have fitted a model with one set of covariates and want to try a different set of covariates. You can load the previous ckpt and remove beta while keeping beta0 and eps (which do not depend on the number of covariates).

Example

>>> gr.load_ckpt("ckpt_with_l1_covariate.pt")
>>> gr.remove_params(beta=True)
>>> gr.train(dataset=dataset_with_l2_covariate, initialization_type="pretrained")

save_ckpt(filename)[source]

Save the full state of the model and optimizer to disk. Use it in pair with load_ckpt().

Note

Pyro saves the unconstrained parameters and the constraining transformation. This means that if you manually “look inside” the ckpt you will see strange values. To get the actual values of the fitted parameters use the get_params() method.
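
Example

A typical save/load round trip (the filename is illustrative):

>>> gr.save_ckpt(filename="my_ckpt.pt")
>>> # later, possibly in a different session
>>> gr_new = GeneRegression()
>>> gr_new.load_ckpt(filename="my_ckpt.pt", map_location="cpu")
>>> df_beta0, df_beta, df_eps = gr_new.get_params()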

show_loss(figsize=(4, 4), logx=False, logy=False, ax=None)[source]

Show the loss history. Useful for checking if the training has converged.

Parameters
  • figsize (Tuple[float, float]) – the size of the image. Used only if ax=None

  • logx (bool) – if True the x_axis is shown in logarithmic scale

  • logy (bool) – if True the y_axis is shown in logarithmic scale

  • ax – The axes object to draw the plot onto. If None (defaults) creates a new figure.

train(dataset, n_steps=2500, print_frequency=50, use_covariates=True, l1_regularization_strength=0.1, l2_regularization_strength=None, eps_range=(0.001, 1.0), subsample_size_cells=None, subsample_size_genes=None, initialization_type='scratch', **kargs)[source]

Train the model. The trained parameters are stored in the pyro.param_store and can be accessed via get_params().

Parameters
  • dataset (GeneDataset) – Dataset to train the model on

  • n_steps (int) – number of training steps

  • print_frequency (int) – how frequently to print loss to screen

  • use_covariates (bool) – if True, use the covariates; if False, use cell-type information only

  • l1_regularization_strength (float) – controls the strength of the L1 regularization on the regression coefficients. If None there is no L1 regularization.

  • l2_regularization_strength (Optional[float]) – controls the strength of the L2 regularization on the regression coefficients. If None there is no L2 regularization.

  • eps_range (Tuple[float, float]) – range of the possible values of the gene-specific noise. Must be a strictly positive range.

  • subsample_size_genes (Optional[int]) – for large datasets, the minibatch can be created using a subset of genes.

  • subsample_size_cells (Optional[int]) – for large datasets, the minibatch can be created using a subset of cells.

  • initialization_type (str) – Either “scratch”, “pretrained” or “resume”. If “resume” both the model and optimizer state are kept and training restarts from where it was left off. If “pretrained” the model state is kept but the optimizer state is erased. If “scratch” (default) both the model and optimizer state are erased (i.e. the simulation starts from scratch).

  • kargs – unused parameters

Note

If you get an out-of-memory error try to tune the subsample_size_cells and subsample_size_genes.
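
Example

A minimal training sketch (values are illustrative; train_dataset is a GeneDataset):

>>> gr = GeneRegression()
>>> gr.configure_optimizer(optimizer_type="adam", lr=0.005)
>>> gr.train(
>>>     dataset=train_dataset,
>>>     n_steps=2500,
>>>     l1_regularization_strength=0.1,
>>>     eps_range=(0.001, 1.0),
>>>     subsample_size_cells=500,
>>>     subsample_size_genes=500)
>>> gr.show_loss(logx=True)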

train_and_test(train_dataset, test_dataset, test_num_samples=10, train_steps=2500, train_print_frequency=50, use_covariates=True, l1_regularization_strength=0.1, l2_regularization_strength=None, eps_range=(0.001, 1.0), subsample_size_cells=None, subsample_size_genes=None, initialization_type='scratch')[source]

Utility method which sequentially calls the methods train() and predict().

Parameters
  • train_dataset (GeneDataset) – Dataset to train the model on

  • test_dataset (GeneDataset) – Dataset to run the prediction on

  • test_num_samples (int) – how many random samples to draw from the predictive distribution

  • train_steps (int) – number of training steps

  • train_print_frequency (int) – how frequently to print loss to screen during training

  • use_covariates (bool) – if True, use the covariates; if False, use cell-type information only

  • l1_regularization_strength (float) – controls the strength of the L1 regularization on the regression coefficients. If None there is no L1 regularization.

  • l2_regularization_strength (Optional[float]) – controls the strength of the L2 regularization on the regression coefficients. If None there is no L2 regularization.

  • eps_range (Tuple[float, float]) – range of the possible values of the gene-specific noise. Must be a strictly positive range.

  • subsample_size_genes (Optional[int]) – for large datasets, the minibatch can be created using a subset of genes.

  • subsample_size_cells (Optional[int]) – for large datasets, the minibatch can be created using a subset of cells.

  • initialization_type (str) – Either “scratch”, “pretrained” or “resume”. If “resume” both the model and optimizer state are kept and training restarts from where it was left off. If “pretrained” the model state is kept but the optimizer state is erased. If “scratch” (default) both the model and optimizer state are erased (i.e. the simulation starts from scratch).

Returns

metrics – See predict().

Return type

(DataFrame, DataFrame)
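
Example

A minimal usage sketch (the datasets are illustrative GeneDataset instances):

>>> df_metric, df_counts = gr.train_and_test(
>>>     train_dataset=train_dataset,
>>>     test_dataset=test_dataset,
>>>     test_num_samples=10,
>>>     train_steps=2500)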

class LogNormalPoisson(n_trials, log_rate, noise_scale, *, num_quad_points=8, validate_args=None)[source]

A Poisson distribution with rate \(r = N \times \exp\left[\log \mu + \epsilon\right]\), where the noise is normally distributed with mean zero and scale \(\sigma\), i.e. \(\epsilon \sim N(0, \sigma)\).

See Mingyuan for discussion of the nice properties of the LogNormalPoisson model.

__init__(n_trials, log_rate, noise_scale, *, num_quad_points=8, validate_args=None)[source]

Parameters
  • n_trials (Tensor) – non-negative number of Poisson trials, i.e. N.

  • log_rate (Tensor) – the log_rate of a single trial, i.e. \(\log \mu\).

  • noise_scale (Tensor) – controls the level of the injected noise, i.e. \(\sigma\).

  • num_quad_points – number of quadrature points used to compute the (approximate) log_prob. Defaults to 8.
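
Example

A minimal sketch that builds the distribution and evaluates the (approximate) log_prob, assuming the class follows the standard torch/pyro Distribution interface (values are illustrative):

>>> import torch
>>> dist = LogNormalPoisson(
>>>     n_trials=torch.tensor([1000.0]),        # total UMI count N
>>>     log_rate=torch.tensor([0.01]).log(),    # log mu for a single gene
>>>     noise_scale=torch.tensor([0.3]),        # sigma
>>>     num_quad_points=8)
>>> log_p = dist.log_prob(torch.tensor([12.0]))  # quadrature-based approximation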