Source code for tissue_purifier.genex.gene_utils

import torch
import numpy
from scanpy import AnnData
from typing import NamedTuple, Union, List, Any, Optional
from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
from tissue_purifier.utils.validation_util import SmartPca

def _make_labels(y: Union[torch.Tensor, numpy.ndarray, List[Any]]) -> (torch.Tensor, dict):
    def _to_numpy(_y):
        if isinstance(_y, numpy.ndarray):
            return _y
        elif isinstance(_y, list):
            return numpy.array(_y)
        elif isinstance(_y, torch.Tensor):
            return _y.detach().cpu().numpy()

    y_np = _to_numpy(y)
    unique_labels = numpy.unique(y_np)
    y_to_ids = dict(zip(unique_labels, numpy.arange(y_np.shape[0])))
    labels = torch.tensor([y_to_ids[tmp] for tmp in y_np])
    return labels, y_to_ids

class GeneDataset(NamedTuple):
    """
    Immutable record bundling everything needed for gene regression:
    covariates, cell-type ids, count data and the associated metadata.
    """
    # The field order matters (instances are also built positionally). Do not change it.

    #: float tensor with the covariates of shape (n, k)
    covariates: torch.Tensor

    #: long tensor with the cell_type_ids of shape (n)
    cell_type_ids: torch.Tensor

    #: long tensor with the count data of shape (n, g)
    counts: torch.Tensor

    #: number of cell types
    k_cell_types: int

    #: dictionary with mapping from unique_cell_type to cell_type_ids
    cell_type_mapping: dict

    #: list of the gene names
    gene_names: List[str]

    def describe(self):
        """
        Print a one-line summary for every field of the GeneDataset.

        Integers and dictionaries are printed as they are, lists are summarized
        by their length and anything else by its ``.shape``.
        """
        for field_name, value in self._asdict().items():
            if isinstance(value, (int, dict)):
                print("{} ---> {}".format(field_name.ljust(20), value))
            elif isinstance(value, list):
                print("{} ---> list of length {}".format(field_name.ljust(20), len(value)))
            else:
                print("{} ---> {}".format(field_name.ljust(20), value.shape))
def make_gene_dataset_from_anndata(
        anndata: AnnData,
        cell_type_key: str,
        covariate_key: str,
        preprocess_strategy: str = 'raw',
        apply_pca: bool = False,
        n_components: Union[int, float] = 0.9) -> GeneDataset:
    """
    Convert an anndata object into a GeneDataset object which can be used for gene regression.

    Args:
        anndata: AnnData object with the raw counts stored in anndata.X
            (either a sparse matrix exposing ``.toarray()`` or a dense array)
        cell_type_key: key corresponding to the cell type, i.e. cell_types = anndata.obs[cell_type_key]
        covariate_key: key corresponding to the covariate, i.e. covariates = anndata.obsm[covariate_key]
        preprocess_strategy: either 'center', 'z_score' or 'raw'. It describes how to preprocess the covariates.
            'raw' (default) means no preprocessing.
        apply_pca: if True, we compute the pca of the covariates. This operation happens after the preprocessing.
        n_components: Used only if :attr:`apply_pca` == True.
            If integer specifies the dimensionality of the data after PCA.
            If float in (0, 1) it auto selects the dimensionality so that the explained variance
            is at least that value.

    Returns:
        GeneDataset: a GeneDataset object

    Raises:
        AssertionError: if the arguments are invalid, the keys are missing from anndata,
            the shapes are inconsistent or the final tensors are not finite.
        ValueError: if the raw covariates contain non-finite values.
    """

    def _to_torch(_x):
        # Bring the input to a cpu tensor and make sure it does not contain nan/inf
        if isinstance(_x, torch.Tensor):
            _y = _x.detach().cpu()
        elif isinstance(_x, numpy.ndarray):
            _y = torch.from_numpy(_x)
        else:
            raise Exception("Error. Expected torch.Tensor or numpy.array. Received {}".format(type(_x)))
        assert torch.all(torch.isfinite(_y)), "Error. Tensor is not finite {}".format(_y)
        return _y

    assert preprocess_strategy in {'center', 'z_score', 'raw'}, \
        'Preprocess strategy must be either "center", "z_score" or "raw"'

    assert (isinstance(n_components, int) and n_components >= 1) or \
           (isinstance(n_components, float) and 0.0 < n_components < 1.0), \
           "n_components must be a positive integer or a float in (0.0, 1.0)"

    assert cell_type_key in set(anndata.obs.keys()), "Cell_type_key is not present in anndata.obs.keys()"
    assert covariate_key in set(anndata.obsm.keys()), "Covariate_key is not present in anndata.obsm.keys()"

    cell_types = list(anndata.obs[cell_type_key].values)
    cell_type_ids_n, mapping_dict = _make_labels(cell_types)

    # anndata.X can be a sparse matrix (which needs densifying) or already a dense array
    x_raw = anndata.X
    x_dense = x_raw.toarray() if hasattr(x_raw, "toarray") else numpy.asarray(x_raw)
    counts_ng = torch.tensor(x_dense).long()

    covariates_nl_raw = torch.tensor(anndata.obsm[covariate_key])
    finite_mask = torch.isfinite(covariates_nl_raw)
    if not torch.all(finite_mask):
        n_not_finite = int(torch.sum(~finite_mask).item())
        raise ValueError("there are n={} not finite covariates in anndata file".format(n_not_finite))

    assert len(counts_ng.shape) == 2
    assert len(covariates_nl_raw.shape) == 2
    assert len(cell_type_ids_n.shape) == 1
    assert counts_ng.shape[0] == covariates_nl_raw.shape[0] == cell_type_ids_n.shape[0]

    k_cell_types = cell_type_ids_n.max().item() + 1  # +1 b/c ids start from zero

    if apply_pca:
        # SmartPca receives the preprocess_strategy, so no separate preprocessing is applied here
        new_covariate = SmartPca(preprocess_strategy=preprocess_strategy).fit_transform(
            data=covariates_nl_raw,
            n_components=n_components)
    elif preprocess_strategy == 'z_score':
        std, mean = torch.std_mean(covariates_nl_raw, dim=-2, unbiased=True, keepdim=True)
        # constant columns have std == 0. Use 1 instead to avoid a division by zero
        mask = (std == 0.0)
        std[mask] = 1.0
        new_covariate = (covariates_nl_raw - mean) / std
    elif preprocess_strategy == 'center':
        mean = torch.mean(covariates_nl_raw, dim=-2, keepdim=True)
        new_covariate = covariates_nl_raw - mean
    else:
        new_covariate = covariates_nl_raw

    return GeneDataset(
        cell_type_ids=_to_torch(cell_type_ids_n),
        covariates=_to_torch(new_covariate),
        counts=_to_torch(counts_ng),
        k_cell_types=k_cell_types,
        cell_type_mapping=mapping_dict,
        gene_names=list(anndata.var_names),
    )
def train_test_val_split(
        data: Union[List[torch.Tensor], List[numpy.ndarray], GeneDataset],
        train_size: float = 0.8,
        test_size: float = 0.15,
        val_size: float = 0.05,
        n_splits: int = 1,
        random_state: Optional[int] = None,
        stratify: bool = True):
    """
    Utility function used to split the data into train/test/val.

    Args:
        data: the data to split into train/test/val
        train_size: the relative size of the train dataset
        test_size: the relative size of the test dataset
        val_size: the relative size of the val dataset
        n_splits: how many times to split the data
        random_state: specify the random state for reproducibility
        stratify: If true the train/test/val are stratified so that they contain approximately the same
            number of examples from each class. If data is a list of arrays the 2nd array is assumed
            to represent the class. If data is a GeneDataset the class is the cell_type.

    Returns:
        tuple: yields multiple splits of the data.

    Raises:
        ValueError: if any of the sizes is not strictly positive, if data is neither a list nor a
            GeneDataset, if data is an empty list, or if stratify is True but data contains
            fewer than two arrays (so there is no class array to stratify on).
        AssertionError: if the arrays do not share the same leading dimension.

    Note:
        This is a generator: the arguments are validated when the first split is requested.

    Example:
        >>> for train, test, val in train_test_val_split(data=[x,y,z]):
        >>>       x_train, y_train, z_train = train
        >>>       x_test, y_test, z_test = test
        >>>       x_val, y_val, z_val = val
        >>>       ... do something ...

    Example:
        >>> for train, test, val in train_test_val_split(data=GeneDataset):
        >>>       assert isinstance(train, GeneDataset)
        >>>       assert isinstance(test, GeneDataset)
        >>>       assert isinstance(val, GeneDataset)
        >>>       ... do something ...
    """

    if train_size <= 0:
        raise ValueError("Train_size must be > 0")
    if test_size <= 0:
        raise ValueError("Test_size must be > 0")
    if val_size <= 0:
        raise ValueError("Val_size must be > 0")

    if isinstance(data, list):
        arrays = data
    elif isinstance(data, GeneDataset):
        # same order as in the definition of GeneDataset NamedTuple
        arrays = [data.covariates, data.cell_type_ids, data.counts]
    else:
        raise ValueError("data must be a list or a GeneDataset")

    if len(arrays) == 0:
        raise ValueError("data must contain at least one array")
    if stratify and len(arrays) < 2:
        # the 2nd array holds the classes used for the stratification
        raise ValueError("stratify=True requires at least two arrays. The 2nd array represents the class")

    n_samples = arrays[0].shape[0]
    assert all(a.shape[0] == n_samples for a in arrays), \
        "Error. All leading dimensions should be the same"

    # Normalize the train/test/val sizes
    norm0 = train_size + test_size + val_size
    train_size_norm0 = train_size / norm0
    test_and_val_size_norm0 = (test_size + val_size) / norm0

    norm1 = test_size + val_size
    test_size_norm1 = test_size / norm1
    val_size_norm1 = val_size / norm1

    # 1st splitter: train vs (test + val). 2nd splitter: (test + val) into val and test.
    splitter_cls = StratifiedShuffleSplit if stratify else ShuffleSplit
    sss0 = splitter_cls(n_splits=n_splits,
                        train_size=train_size_norm0,
                        test_size=test_and_val_size_norm0,
                        random_state=random_state)
    sss1 = splitter_cls(n_splits=1,
                        train_size=val_size_norm1,
                        test_size=test_size_norm1,
                        random_state=random_state)

    if isinstance(data, GeneDataset):
        # metadata which is copied unchanged into the train/test/val datasets
        extras = [data.k_cell_types, data.cell_type_mapping, data.gene_names]
    else:
        extras = None

    # Part common to both stratified and not stratified
    for index_train, index_test_and_val in sss0.split(*arrays[:2]):
        trains = [a[index_train] for a in arrays]
        test_and_val = [a[index_test_and_val] for a in arrays]

        # sss1 has n_splits=1, therefore it produces exactly one (val, test) pair of indices
        index_val, index_test = next(sss1.split(*test_and_val[:2]))
        tests = [a[index_test] for a in test_and_val]
        vals = [a[index_val] for a in test_and_val]

        if extras is not None:
            yield (GeneDataset._make(trains + extras),
                   GeneDataset._make(tests + extras),
                   GeneDataset._make(vals + extras))
        else:
            yield trains, tests, vals