Source code for tissue_purifier.genex.gene_visualization

import numpy
import torch
import matplotlib.pyplot as plt


[docs]def plot_gene_hist(cell_types_n, value1_ng, value2_ng=None, bins=20) -> plt.Figure: """ Plot the per cell-type histogram. If :attr:`value2_ng` is defined the two histogram are interlieved. Args: cell_types_n: tensor of shape N with the cell type labels (with K distinct values) value1_ng: the first quantity to whose histogram is computed lot of shape (N,G) value2_ng: the second quantity to plot of shape (N,G) (optional) bins: number of bins in the histogram Returns: fig: A figure with G rows and K columns where K is the number of distinct cell types. """ def _to_torch(_x): if isinstance(_x, torch.Tensor): return _x elif isinstance(_x, numpy.ndarray): return torch.tensor(_x) else: raise Exception("Expected torch.tensor or numpy.ndarray. Received {0}".format(type(_x))) # TODO: use seaborn instead # fig, ax = plt.subplots(figsize=(12, 6)) # _ = seaborn.histplot(data=df, x="eps", hue="cell_type", bins=200, ax=ax, multiple="dodge") raise NotImplementedError value2_ng = value1_ng if value2_ng is None else value2_ng value1_torch_ng = _to_torch(value1_ng) value2_torch_ng = value1_torch_ng if value2_ng is None else _to_torch(value1_ng) assert len(cell_types_n.shape) == 1 assert len(value1_ng.shape) >= 2 assert cell_types_n.shape[0] == value1_ng.shape[-2] assert value1_torch_ng.shape == value2_torch_ng.shape assert value1_torch_ng.dtype == value2_torch_ng.dtype ctypes = torch.unique(cell_types_n) genes = value1_ng.shape[-1] nrows = genes ncols = len(ctypes) fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(4 * ncols, 4 * nrows)) for r in range(genes): for c, c_type in enumerate(ctypes): mask = (cell_types_n == c_type) if r == 0: _ = axes[r, c].set_title("cell_type {0}".format(c_type)) if value2_ng is None: v1 = value1_torch_ng[mask][r] if v1.dtype == torch.long: y1 = torch.bincount(v1) x1 = torch.arange(y1.shape[0]) barWidth = 0.9 * (x1[1] - x1[0]) _ = axes[r, c].bar(x1, y1, width=barWidth) else: y1, x1 = numpy.histogram(v1, bins=bins, density=True) barWidth = 0.9 * (x1[1] - x1[0]) _ = axes[r, c].bar(x1[:-1], y1, width=barWidth) else: v1 = value1_torch_ng[mask][r] v2 = value2_torch_ng[mask][r] a1, b1 = v1.min(), v1.max() a2, b2 = v2.min(), v2.max() a = min(a1.item(), a2.item()) b = max(b1.item(), b2.item(), 1) myrange = (a, b+1) if v1.dtype == torch.long: y1 = torch.bincount(v1, minlength=int(myrange[1])) y2 = torch.bincount(v2, minlength=int(myrange[1])) x = torch.arange(y1.shape[0]) barWidth = 0.4 * (x[1] - x[0]) _ = axes[r, c].bar(x, y1, width=barWidth) _ = axes[r, c].bar(x + barWidth, y2, width=barWidth) else: y1, x1 = numpy.histogram(v1, range=myrange, bins=bins, density=True) y2, x2 = numpy.histogram(v2, range=myrange, bins=bins, density=True) barWidth = 0.4 * (x1[1] - x1[0]) _ = axes[r, c].bar(x1[:-1], y1, width=barWidth) _ = axes[r, c].bar(x2[:-1] + barWidth, y2, width=barWidth) plt.close() return fig