Gene regression
This set of classes/functions can be used to predict the gene expression profile based on:
the cell-type and
cell-specific covariates (such as location in space and local micro-environment).
The intended use case is to test whether the semantic features extracted by one of the Self-Supervised Learning (SSL) algorithms contain biologically relevant information.
We treat each cell-type separately. For each cell, the gene counts are modelled as a multinomial distribution:
\(n_1, \dots, n_G \sim \text{Multinomial}(N; p_1, \dots, p_G)\)
where \(N=\sum_{g=1}^G n_g\) is the total number of counts in the cell (sometimes referred to as the total UMI count) and \(p_g\), with \(\sum_{g=1}^G p_g=1\), are the probabilities of measuring each gene. When \(N\) is large and the \(p_g\) are small, the counts for each gene can be approximated by a Poisson distribution with rate \(r_g = N p_g\). Therefore the counts for cell \(n\) and gene \(g\) are modelled as:
\(n_{ng} \sim \text{Poisson}(N_n p_{ng})\)
To account for noise and for the presence of \(L\) cell-specific covariates, we model the probability as:
\(\log p_{ng} = \beta_g^0 + \sum_{l=1}^L X_{nl} \beta_{lg} + \epsilon_g\)
where \(\beta_g^0\) is a gene-specific intercept, \(X_{nl}\) are the cell covariates and \(\epsilon_g \sim N(0,\sigma_g)\) is a noise term representing the gene-specific over-dispersion.
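The Poisson approximation above can be checked numerically. Under the multinomial model the count of a single gene is marginally Binomial(\(N\), \(p_g\)); the sketch below compares that pmf with Poisson(\(N p_g\)) using only the Python standard library. The values \(N = 1000\) and \(p_g = 0.002\) are illustrative assumptions, not values taken from any dataset.

```python
import math


def binomial_pmf(k: int, n: int, p: float) -> float:
    # Marginal of one gene under Multinomial(N, p): Binomial(N, p_g)
    return math.comb(n, k) * p**k * (1.0 - p) ** (n - k)


def poisson_pmf(k: int, rate: float) -> float:
    return math.exp(-rate) * rate**k / math.factorial(k)


N, p_g = 1000, 0.002  # illustrative total UMI count and gene probability
rate = N * p_g        # r_g = N p_g

# Largest pointwise gap between the two pmfs over the region holding the mass
max_diff = max(abs(binomial_pmf(k, N, p_g) - poisson_pmf(k, rate)) for k in range(30))
print(f"rate = {rate:.1f}, max |Binomial - Poisson| = {max_diff:.5f}")
```

Le Cam's inequality bounds the summed absolute difference between the two pmfs by \(2 N p_g^2\), which is 0.008 for these values, so the approximation improves as \(p_g\) shrinks.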
We recap the dimensions of the variables involved in the full model (for \(K\) different cell-types):
\(X_{nl}\) is a fixed covariate matrix of shape \(N \times L\) (i.e. cells by covariates)
\(N_n\) is a fixed vector of shape \(N\) with the (observed) total counts in a cell.
\(\beta_{kg}^0\) are the intercepts of the regression, of shape \(K \times G\) (i.e. cell-types by genes)
\(\beta_{klg}\) are the regression coefficients of shape \(K \times L \times G\) (i.e. cell-types by covariates by genes)
\(\sigma_{kg}\) are the gene over-dispersions, of shape \(K \times G\) (i.e. cell-types by genes)
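To make the shapes above concrete, here is a minimal NumPy sketch that samples counts from the generative model for \(K\) cell types, with each cell using the parameters of its own cell type. The toy sizes, the random parameter values and the choice to draw the noise independently for every cell and gene (our reading of \(\epsilon_g\)) are assumptions for illustration; this is not the library's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_genes, n_types, n_cov = 300, 40, 3, 5  # toy sizes; typical values are larger

# Fixed inputs
X_nl = rng.normal(size=(n_cells, n_cov))        # covariates, N x L
N_n = rng.integers(500, 2000, size=n_cells)      # observed total counts per cell, N
k_n = rng.integers(0, n_types, size=n_cells)     # cell-type id of each cell

# Parameters (random here, purely for illustration)
beta0_kg = np.log(rng.dirichlet(np.ones(n_genes), size=n_types))  # K x G
beta_klg = 0.1 * rng.normal(size=(n_types, n_cov, n_genes))        # K x L x G
sigma_kg = rng.uniform(0.001, 0.5, size=(n_types, n_genes))        # K x G

# Each cell picks the parameters of its own cell type via k_n
eps_ng = rng.normal(size=(n_cells, n_genes)) * sigma_kg[k_n]
log_p_ng = beta0_kg[k_n] + np.einsum("nl,nlg->ng", X_nl, beta_klg[k_n]) + eps_ng

# Poisson approximation of the multinomial: n_ng ~ Poisson(N_n p_ng)
counts_ng = rng.poisson(N_n[:, None] * np.exp(log_p_ng))
print(counts_ng.shape)  # (300, 40)
```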
Typical values are \(N \sim 10^3, G \sim 10^3, K\sim 10, L\sim 50\). The goal of the inference is to determine \(\beta^0, \beta\) and \(\sigma\). We enforce a penalty (either L1 or L2) on the regression coefficients \(\beta\) to encourage them to be small. We put a flat prior on \(\sigma\), which is allowed to vary in a small (predefined) range. There is no prior on \(\beta^0\). Overall, the model has two hyper-parameters (the strength of the regularization on \(\beta\) and the allowed range for \(\sigma\)), which are determined by cross-validation.
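The penalized objective can be sketched as a Poisson negative log-likelihood plus an L1 or L2 term on \(\beta\), with no penalty on \(\beta^0\). To keep the sketch short it sets the noise to zero (\(\sigma \to 0\)), so it reduces to a plain Poisson regression for a single cell type; the actual model integrates the noise out, and the exact scaling of the penalty used by the library is an assumption here.

```python
import math

import numpy as np


def lgamma_array(x: np.ndarray) -> np.ndarray:
    # NumPy has no log-gamma ufunc; vectorize the standard-library one
    return np.vectorize(math.lgamma)(x)


def penalized_loss(counts_ng, N_n, X_nl, beta0_g, beta_lg,
                   l1_strength=None, l2_strength=None) -> float:
    """Poisson NLL plus an optional L1 and/or L2 penalty on beta (noise set to zero)."""
    log_rate = np.log(N_n)[:, None] + beta0_g + X_nl @ beta_lg
    # log Poisson pmf: n log r - r - log n!
    log_lik = counts_ng * log_rate - np.exp(log_rate) - lgamma_array(counts_ng + 1.0)
    loss = -log_lik.sum()
    if l1_strength is not None:
        loss += l1_strength * np.abs(beta_lg).sum()
    if l2_strength is not None:
        loss += l2_strength * np.square(beta_lg).sum()
    # beta0 is deliberately left unpenalized (no prior on the intercepts)
    return float(loss)


rng = np.random.default_rng(1)
counts = rng.poisson(3.0, size=(20, 4))
N_n = np.full(20, 1000.0)
X = rng.normal(size=(20, 3))
beta0 = np.log(np.full(4, 3.0 / 1000.0))
beta = rng.normal(size=(3, 4))

base = penalized_loss(counts, N_n, X, beta0, beta)
with_l1 = penalized_loss(counts, N_n, X, beta0, beta, l1_strength=0.1)
with_l2 = penalized_loss(counts, N_n, X, beta0, beta, l2_strength=0.1)
print(with_l1 - base, with_l2 - base)  # the two penalty terms
```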
See notebook3 for an example.
- class GeneDataset(covariates: torch.Tensor, cell_type_ids: torch.Tensor, counts: torch.Tensor, k_cell_types: int, cell_type_mapping: dict, gene_names: List[str])
Container for organizing the gene expression data
- cell_type_ids: torch.Tensor
long tensor with the cell_type_ids of shape (n)
- cell_type_mapping: dict
dictionary mapping each unique cell type to its cell_type_id
- counts: torch.Tensor
long tensor with the count data of shape (n, g)
- covariates: torch.Tensor
float tensor with the covariates of shape (n, l)
- gene_names: List[str]
list of the gene names
- k_cell_types: int
number of cell types
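A sketch of how cell_type_ids and cell_type_mapping relate, assuming the labels are plain strings and that ids are assigned in sorted label order (the library may order them differently, and it stores the ids as a torch long tensor rather than a NumPy array). The labels below are hypothetical.

```python
import numpy as np

# Hypothetical per-cell annotations, e.g. the values of anndata.obs[cell_type_key]
cell_type_labels = np.array(["T cell", "B cell", "T cell", "Macrophage", "B cell"])

# np.unique returns the sorted unique labels and, for every cell, the position of its label
unique_cell_types, cell_type_ids = np.unique(cell_type_labels, return_inverse=True)
cell_type_mapping = {str(name): i for i, name in enumerate(unique_cell_types)}
k_cell_types = len(unique_cell_types)

print(cell_type_mapping)  # {'B cell': 0, 'Macrophage': 1, 'T cell': 2}
print(cell_type_ids)      # [2 0 2 1 0]
```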
- make_gene_dataset_from_anndata(anndata, cell_type_key, covariate_key, preprocess_strategy='raw', apply_pca=False, n_components=0.9)
Convert an AnnData object into a GeneDataset object which can be used for gene regression.
- Parameters
anndata (AnnData) – AnnData object with the raw counts stored in anndata.X
cell_type_key (str) – key corresponding to the cell type, i.e. cell_types = anndata.obs[cell_type_key]
covariate_key (str) – key corresponding to the covariate, i.e. covariates = anndata.obsm[covariate_key]
preprocess_strategy (str) – either ‘center’, ‘z_score’ or ‘raw’. It describes how to preprocess the covariates. ‘raw’ (default) means no preprocessing.
apply_pca (bool) – if True, we compute the PCA of the covariates. This operation happens after the preprocessing.
n_components (Union[int, float]) – used only if apply_pca == True. If an integer, it specifies the dimensionality of the data after PCA. If a float in (0, 1), the dimensionality is selected automatically so that the explained variance is at least that value.
- Returns
GeneDataset – a GeneDataset object
- Return type
GeneDataset
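The preprocessing options and the PCA step can be sketched in NumPy as follows. The helper names preprocess_covariates and pca_reduce are hypothetical, and computing the PCA from an SVD of the centered covariates is our assumption about an equivalent computation, not necessarily what make_gene_dataset_from_anndata does internally.

```python
import numpy as np


def preprocess_covariates(x: np.ndarray, strategy: str = "raw") -> np.ndarray:
    """Hypothetical helper mirroring the 'raw', 'center' and 'z_score' options."""
    if strategy == "raw":
        return x
    centered = x - x.mean(axis=0)
    if strategy == "center":
        return centered
    if strategy == "z_score":
        std = x.std(axis=0)
        return centered / np.where(std > 0, std, 1.0)  # leave constant columns at zero
    raise ValueError(f"unknown preprocess_strategy: {strategy!r}")


def pca_reduce(x: np.ndarray, n_components) -> np.ndarray:
    """Hypothetical helper: project onto the leading principal components."""
    centered = x - x.mean(axis=0)
    _, s, vt = np.linalg.svd(centered, full_matrices=False)  # rows of vt = principal axes
    if isinstance(n_components, float):
        explained = np.cumsum(s**2) / np.sum(s**2)
        # smallest k whose cumulative explained variance reaches the threshold
        k = int(np.searchsorted(explained, n_components)) + 1
    else:
        k = int(n_components)
    return centered @ vt[: min(k, len(s))].T


rng = np.random.default_rng(0)
latent = rng.normal(size=(500, 3)) * np.array([10.0, 5.0, 3.0])  # 3 informative directions
x = latent @ rng.normal(size=(3, 10)) + 0.01 * rng.normal(size=(500, 10))
reduced = pca_reduce(preprocess_covariates(x, "z_score"), n_components=0.9)
print(reduced.shape)
```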
- train_test_val_split(data, train_size=0.8, test_size=0.15, val_size=0.05, n_splits=1, random_state=None, stratify=True)
Utility function used to split the data into train/test/val.
- Parameters
data (Union[List[Tensor], List[ndarray], GeneDataset]) – the data to split into train/test/val
train_size (float) – the relative size of the train dataset
test_size (float) – the relative size of the test dataset
val_size (float) – the relative size of the val dataset
n_splits (int) – how many times to split the data
random_state (Optional[int]) – specify the random state for reproducibility
stratify (bool) – if True, the train/test/val splits are stratified so that they contain approximately the same fraction of examples from each class. If data is a list of arrays, the 2nd array is assumed to represent the class. If data is a GeneDataset, the class is the cell_type.
- Returns
tuple – yields multiple splits of the data.
Example
>>> for train, test, val in train_test_val_split(data=[x, y, z]):
...     x_train, y_train, z_train = train
...     x_test, y_test, z_test = test
...     x_val, y_val, z_val = val
...     # do something
Example
>>> for train, test, val in train_test_val_split(data=gene_dataset):
...     assert isinstance(train, GeneDataset)
...     assert isinstance(test, GeneDataset)
...     assert isinstance(val, GeneDataset)
...     # do something
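Stratification can be sketched by shuffling the indices of each class separately and cutting them in the requested proportions. The function below is a hypothetical helper that returns index arrays for a single split; it does not reproduce the n_splits generator or the handling of GeneDataset inputs.

```python
import numpy as np


def stratified_split_indices(labels, train_size=0.8, test_size=0.15, val_size=0.05,
                             random_state=None):
    """Hypothetical helper: (train_idx, test_idx, val_idx) with a similar class mix in each."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(random_state)
    total = train_size + test_size + val_size
    parts = ([], [], [])
    for cls in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == cls))  # shuffle within the class
        n_train = int(round(len(idx) * train_size / total))
        n_test = int(round(len(idx) * test_size / total))
        parts[0].append(idx[:n_train])
        parts[1].append(idx[n_train:n_train + n_test])
        parts[2].append(idx[n_train + n_test:])  # the remainder goes to validation
    return tuple(np.sort(np.concatenate(p)) for p in parts)


labels = np.repeat([0, 1, 2], [100, 60, 40])  # three classes of different sizes
train_idx, test_idx, val_idx = stratified_split_indices(labels, random_state=0)
print(len(train_idx), len(test_idx), len(val_idx))  # 160 30 10
```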
- plot_gene_hist(cell_types_n, value1_ng, value2_ng=None, bins=20)
Plot the per cell-type histogram. If value2_ng is defined, the two histograms are interleaved.
- Parameters
cell_types_n – tensor of shape N with the cell type labels (with K distinct values)
value1_ng – the first quantity whose histogram is computed, of shape (N, G)
value2_ng – the second quantity to plot, of shape (N, G) (optional)
bins – number of bins in the histogram
- Returns
fig – A figure with G rows and K columns, where K is the number of distinct cell types.
- Return type
Figure
- class GeneRegression
Given the cell-type labels and some covariates the model predicts the gene expression. The counts are modelled as a LogNormalPoisson process. See documentation for more details.
- configure_optimizer(optimizer_type='adam', lr=0.005, betas=(0.9, 0.999), momentum=0.9, alpha=0.99)
Configure the optimizer to use.
- Parameters
optimizer_type (str) – either ‘adam’ (default), ‘sgd’ or ‘rmsprop’
lr (float) – learning rate
betas (Tuple[float, float]) – betas for the ‘adam’ optimizer. Ignored if optimizer_type is not ‘adam’.
momentum (float) – momentum for the ‘sgd’ optimizer. Ignored if optimizer_type is not ‘sgd’.
alpha (float) – alpha for the ‘rmsprop’ optimizer. Ignored if optimizer_type is not ‘rmsprop’.
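The rule that each optimizer only consumes its own hyper-parameters can be sketched with a hypothetical helper that keeps the relevant arguments. The keyword names match the constructors of torch.optim.Adam, torch.optim.SGD and torch.optim.RMSprop, and a dictionary like this is the kind of optim_args that pyro.optim.Adam, pyro.optim.SGD and pyro.optim.RMSprop accept; whether configure_optimizer is built this way is an assumption.

```python
from typing import Any, Dict, Tuple


def optimizer_kwargs(optimizer_type: str = "adam", lr: float = 0.005,
                     betas: Tuple[float, float] = (0.9, 0.999),
                     momentum: float = 0.9, alpha: float = 0.99) -> Dict[str, Any]:
    """Hypothetical helper: keep only the hyper-parameters the chosen optimizer uses."""
    if optimizer_type == "adam":
        return {"lr": lr, "betas": betas}        # momentum and alpha are ignored
    if optimizer_type == "sgd":
        return {"lr": lr, "momentum": momentum}  # betas and alpha are ignored
    if optimizer_type == "rmsprop":
        return {"lr": lr, "alpha": alpha}        # betas and momentum are ignored
    raise ValueError(f"optimizer_type must be 'adam', 'sgd' or 'rmsprop', got {optimizer_type!r}")


print(optimizer_kwargs("sgd", lr=0.01))  # {'lr': 0.01, 'momentum': 0.9}
```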
- extend_train(dataset, n_steps=2500, print_frequency=50)
Utility method which calls train() with the same parameters just used, effectively extending the training.
- Parameters
dataset (GeneDataset) – dataset to train the model on
n_steps (int) – number of training steps
print_frequency (int) – how frequently to print the loss to screen
- extend_train_and_test(train_dataset, test_dataset, test_num_samples=10, train_steps=2500, train_print_frequency=50)
Utility method which sequentially calls the methods extend_train() and predict().
- Parameters
train_dataset (GeneDataset) – dataset to train the model on
test_dataset (GeneDataset) – dataset to run the prediction on
test_num_samples (int) – how many random samples to draw from the predictive distribution
train_steps (int) – number of training steps
train_print_frequency (int) – how frequently to print the loss to screen during training
- Returns
metrics – see predict().
- Return type
(DataFrame, DataFrame)
- get_params()
- Returns
df_beta0, df_beta, df_eps – dataframes with the fitted parameters.
- Return type
(DataFrame, DataFrame, DataFrame)
Note
This method can be used in combination with load_ckpt() to inspect the fitted parameters of a previous run.
Examples
>>> gr = GeneRegression()
>>> gr.load_ckpt(filename="my_old_ckpt.pt")
>>> df_beta0, df_beta, df_eps = gr.get_params()
>>> df_beta0.head()
- load_ckpt(filename, map_location=None)
Load the full state of the model and optimizer from disk. Use it together with save_ckpt().
- property optimizer: pyro.optim.PyroOptim
The optimizer associated with this model.
- Return type
PyroOptim
- predict(dataset, num_samples=10, subsample_size_cells=None, subsample_size_genes=None)
Use the parameters currently in the param_store to run the prediction and report some metrics. If you want to run the prediction based on a different set of parameters, you need to call load_ckpt() first.
The Q metric is \(Q = E\left[|X_{i,g} - Y_{i,g}|\right]\), where X is the (observed) data, Y is a sample from the predicted posterior and (i, g) index cells and genes, respectively.
The log_score metric is \(\text{log\_score} = \log p_\text{posterior}\left(X_\text{data}\right)\)
- Parameters
dataset (GeneDataset) – the dataset to run the prediction on
num_samples (int) – how many random samples to draw from the predictive distribution
subsample_size_cells (Optional[int]) – if not None (default is None), the predictions are made in chunks to avoid memory issues
subsample_size_genes (Optional[int]) – if not None (default is None), the predictions are made in chunks to avoid memory issues
- Returns
df_metric – for each cell_type and gene we report the Q and log_score metrics
df_counts – for each cell and gene we report the observed counts and a single sample from the posterior
- Return type
(DataFrame, DataFrame)
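A NumPy sketch of the two metrics, with one value per gene. Here rates_sng is a hypothetical array of posterior samples of the Poisson rates, and the log_score is estimated by averaging the likelihood over those samples; how predict() estimates the score and groups it by cell type may differ.

```python
import math

import numpy as np

_lgamma = np.vectorize(math.lgamma)


def q_metric(observed_ng: np.ndarray, samples_sng: np.ndarray) -> np.ndarray:
    """Q = E|X - Y| per gene, averaged over posterior samples s and cells n."""
    return np.abs(samples_sng - observed_ng[None, :, :]).mean(axis=(0, 1))


def log_score(observed_ng: np.ndarray, rates_sng: np.ndarray) -> np.ndarray:
    """Monte Carlo estimate of log p(X) per gene, averaging the likelihood over samples."""
    # Poisson log-likelihood of the observed counts under each posterior sample
    log_lik_sng = observed_ng * np.log(rates_sng) - rates_sng - _lgamma(observed_ng + 1.0)
    log_lik_sg = log_lik_sng.sum(axis=1)          # joint over cells
    m = log_lik_sg.max(axis=0)                    # log-mean-exp over samples, for stability
    return m + np.log(np.mean(np.exp(log_lik_sg - m), axis=0))


rng = np.random.default_rng(0)
rates_sng = rng.gamma(2.0, 1.5, size=(10, 30, 5))  # 10 posterior samples, 30 cells, 5 genes
observed_ng = rng.poisson(rates_sng[0])
samples_sng = rng.poisson(rates_sng)
print(q_metric(observed_ng, samples_sng))
print(log_score(observed_ng, rates_sng))
```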
- static remove_params(beta0=False, beta=False, eps=False)
Selectively remove parameters from the param_store.
- Parameters
beta0 (bool) – if True (default is False), remove beta0 of shape \((K, G)\) from the param_store
beta (bool) – if True (default is False), remove beta of shape \((K, L, G)\) from the param_store
eps (bool) – if True (default is False), remove eps of shape \((K, G)\) from the param_store
Note
This is useful in combination with load_ckpt() and train(). For example, you might have fitted a model with l1 covariates and want to try a different model with l2 covariates (i.e. a different number of covariates). You can load the previous ckpt and remove beta while keeping beta0 and eps (which do not depend on the number of covariates).
Example
>>> gr.load_ckpt("ckpt_with_l1_covariate.pt")
>>> gr.remove_params(beta=True)
>>> gr.train(dataset=dataset_with_l2_covariate, initialization_type="pretrained")
- save_ckpt(filename)
Save the full state of the model and optimizer to disk. Use it together with load_ckpt().
Note
Pyro saves the unconstrained parameters and the constraint transformations. This means that if you manually “look inside” the ckpt you will see strange values. To get the actual values of the fitted parameters, use the get_params() method.
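To see why the stored values look strange, consider a parameter constrained to an interval such as eps_range (assuming the noise scale is declared with an interval constraint). For an interval constraint, torch.distributions.transform_to uses a sigmoid followed by an affine map, so a checkpoint would hold a logit-like unconstrained value. The NumPy sketch below mirrors that map and its inverse.

```python
import numpy as np


def to_constrained(u: np.ndarray, low: float, high: float) -> np.ndarray:
    # sigmoid maps the real line to (0, 1); the affine map rescales to (low, high)
    return low + (high - low) / (1.0 + np.exp(-u))


def to_unconstrained(x: np.ndarray, low: float, high: float) -> np.ndarray:
    # inverse map: logit of the value rescaled to (0, 1)
    z = (x - low) / (high - low)
    return np.log(z) - np.log1p(-z)


eps_range = (0.001, 1.0)
sigma = np.array([0.01, 0.1, 0.5])
stored = to_unconstrained(sigma, *eps_range)  # the kind of value a ckpt would contain
print(stored)
print(to_constrained(stored, *eps_range))     # recovers sigma
```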
- show_loss(figsize=(4, 4), logx=False, logy=False, ax=None)
Show the loss history. Useful for checking whether the training has converged.
- Parameters
figsize (Tuple[float, float]) – the size of the image. Used only if ax=None
logx (bool) – if True, the x-axis is shown in logarithmic scale
logy (bool) – if True, the y-axis is shown in logarithmic scale
ax – the axes object to draw the plot onto. If None (default), a new figure is created.
- train(dataset, n_steps=2500, print_frequency=50, use_covariates=True, l1_regularization_strength=0.1, l2_regularization_strength=None, eps_range=(0.001, 1.0), subsample_size_cells=None, subsample_size_genes=None, initialization_type='scratch', **kargs)
Train the model. The trained parameters are stored in the pyro.param_store and can be accessed via get_params().
- Parameters
dataset (GeneDataset) – dataset to train the model on
n_steps (int) – number of training steps
print_frequency (int) – how frequently to print the loss to screen
use_covariates (bool) – if True, use the covariates; if False, use the cell-type information only
l1_regularization_strength (float) – controls the strength of the L1 regularization on the regression coefficients. If None, there is no L1 regularization.
l2_regularization_strength (Optional[float]) – controls the strength of the L2 regularization on the regression coefficients. If None, there is no L2 regularization.
eps_range (Tuple[float, float]) – range of the possible values of the gene-specific noise. Must be a strictly positive range.
subsample_size_genes (Optional[int]) – for large datasets, the minibatch can be created using a subset of genes.
subsample_size_cells (Optional[int]) – for large datasets, the minibatch can be created using a subset of cells.
initialization_type (str) – either “scratch”, “pretrained” or “resume”. If “resume”, both the model and optimizer state are kept and training restarts from where it left off. If “pretrained”, the model state is kept but the optimizer state is erased. If “scratch” (default), both the model and optimizer state are erased (i.e. training starts from scratch).
kargs – unused parameters
Note
If you get an out-of-memory error, try to tune subsample_size_cells and subsample_size_genes.
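A sketch of why subsampling keeps the objective on the right scale: a random block of cells × genes, rescaled by the inverse sampling fractions, gives an unbiased estimate of the full log-likelihood. pyro.plate applies this kind of rescaling when subsample_size is smaller than size; whether train() nests plates over cells and genes exactly like this is an assumption, and the per-entry log-likelihoods below are random stand-ins.

```python
import numpy as np


def subsampled_loglik(loglik_ng: np.ndarray, subsample_size_cells: int,
                      subsample_size_genes: int, rng: np.random.Generator) -> float:
    """Unbiased minibatch estimate of loglik_ng.sum() from a random block of cells x genes."""
    n_cells, n_genes = loglik_ng.shape
    cells = rng.choice(n_cells, size=subsample_size_cells, replace=False)
    genes = rng.choice(n_genes, size=subsample_size_genes, replace=False)
    # each entry is included with probability (c / n_cells) * (g / n_genes),
    # so rescaling by the inverse makes the estimate unbiased
    scale = (n_cells / subsample_size_cells) * (n_genes / subsample_size_genes)
    return float(scale * loglik_ng[np.ix_(cells, genes)].sum())


rng = np.random.default_rng(0)
loglik_ng = -rng.gamma(2.0, 1.0, size=(1000, 200))  # stand-in per-entry log-likelihoods
full = float(loglik_ng.sum())
estimates = [subsampled_loglik(loglik_ng, 100, 50, rng) for _ in range(500)]
print(full, np.mean(estimates))
```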
- train_and_test(train_dataset, test_dataset, test_num_samples=10, train_steps=2500, train_print_frequency=50, use_covariates=True, l1_regularization_strength=0.1, l2_regularization_strength=None, eps_range=(0.001, 1.0), subsample_size_cells=None, subsample_size_genes=None, initialization_type='scratch')
Utility method which sequentially calls the methods train() and predict().
- Parameters
train_dataset (GeneDataset) – dataset to train the model on
test_dataset (GeneDataset) – dataset to run the prediction on
test_num_samples (int) – how many random samples to draw from the predictive distribution
train_steps (int) – number of training steps
train_print_frequency (int) – how frequently to print the loss to screen during training
use_covariates (bool) – if True, use the covariates; if False, use the cell-type information only
l1_regularization_strength (float) – controls the strength of the L1 regularization on the regression coefficients. If None, there is no L1 regularization.
l2_regularization_strength (Optional[float]) – controls the strength of the L2 regularization on the regression coefficients. If None, there is no L2 regularization.
eps_range (Tuple[float, float]) – range of the possible values of the gene-specific noise. Must be a strictly positive range.
subsample_size_genes (Optional[int]) – for large datasets, the minibatch can be created using a subset of genes.
subsample_size_cells (Optional[int]) – for large datasets, the minibatch can be created using a subset of cells.
initialization_type (str) – either “scratch”, “pretrained” or “resume”. If “resume”, both the model and optimizer state are kept and training restarts from where it left off. If “pretrained”, the model state is kept but the optimizer state is erased. If “scratch” (default), both the model and optimizer state are erased (i.e. training starts from scratch).
- Returns
metrics – see predict().
- Return type
(DataFrame, DataFrame)
- class LogNormalPoisson(n_trials, log_rate, noise_scale, *, num_quad_points=8, validate_args=None)
A Poisson distribution with rate \(r = N \times \exp\left[ \log \mu + \epsilon \right]\), where the noise is normally distributed with mean zero and scale \(\sigma\), i.e. \(\epsilon \sim N(0, \sigma)\).
See Mingyuan for a discussion of the nice properties of the LogNormalPoisson model.
- __init__(n_trials, log_rate, noise_scale, *, num_quad_points=8, validate_args=None)
- Parameters
n_trials (Tensor) – non-negative number of Poisson trials, i.e. \(N\).
log_rate (Tensor) – the log_rate of a single trial, i.e. \(\log \mu\).
noise_scale (Tensor) – controls the level of the injected noise, i.e. \(\sigma\).
num_quad_points – number of quadrature points used to compute the (approximate) log_prob. Defaults to 8.
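The (approximate) log_prob can be sketched with Gauss-Hermite quadrature, which integrates out \(\epsilon\) with a weighted sum over num_quad_points nodes. Gauss-Hermite is our assumption about the quadrature scheme, noise_scale is treated as the standard deviation of \(\epsilon\), and the function handles scalar inputs only, unlike the tensor-valued distribution.

```python
import math

import numpy as np


def lognormal_poisson_log_prob(value, n_trials, log_rate, noise_scale, num_quad_points=8):
    """log p(value) for value ~ Poisson(N exp(log_rate + eps)), eps ~ Normal(0, noise_scale).

    Uses int f(eps) Normal(eps | 0, s) d eps ~= sum_i w_i f(sqrt(2) s x_i) / sqrt(pi).
    """
    x, w = np.polynomial.hermite.hermgauss(num_quad_points)
    eps = math.sqrt(2.0) * noise_scale * x              # quadrature nodes for eps
    log_r = math.log(n_trials) + log_rate + eps          # log Poisson rate at each node
    log_pois = value * log_r - np.exp(log_r) - math.lgamma(value + 1.0)
    log_terms = log_pois + np.log(w) - 0.5 * math.log(math.pi)
    m = log_terms.max()
    return float(m + np.log(np.exp(log_terms - m).sum()))  # log-sum-exp over nodes


N, log_mu, sigma = 1000.0, math.log(0.005), 0.3  # so that N * mu = 5
probs = [math.exp(lognormal_poisson_log_prob(k, N, log_mu, sigma)) for k in range(200)]
total = sum(probs)
mean = sum(k * p for k, p in enumerate(probs))
print(total, mean, 5.0 * math.exp(sigma**2 / 2))
```

Because the Gauss-Hermite weights sum to \(\sqrt{\pi}\), the probabilities sum to one, and the mean matches \(N\mu\, e^{\sigma^2/2}\), the mean of a log-normal rate.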