"""Source code for tissue_purifier.plots.plot_images: helpers to render image tensors with matplotlib."""

from typing import Optional, Tuple, List, Union

import numpy
import torch
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.cm import ScalarMappable
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid
from torchvision.transforms import CenterCrop as trsfm_center_crop
from torchvision.transforms import Compose as trsfm_compose


def _get_color_tensor(_cmap, _ch):
    """
    Sample ``_ch`` evenly spaced colors from a matplotlib colormap.

    Args:
        _cmap: name (or instance) of the matplotlib colormap
        _ch: number of colors to sample

    Returns:
        A float tensor of shape ``(_ch, 3)`` with the RGB components (alpha dropped).
    """
    colormap = plt.get_cmap(_cmap, _ch)
    rgba = colormap(numpy.linspace(0.0, 1.0, _ch))
    # keep only R, G, B; drop the alpha column
    rgb = torch.Tensor(rgba)[:, :3]
    assert rgb.shape[0] == _ch
    return rgb


def _minmax_scale_tensor(tensor, in_range: Tuple[float, float] = None):
    """ Clamp tensor in_range and transform to (0,1) range """
    if in_range is None:
        in_range_min, in_range_max = torch.min(tensor), torch.max(tensor)
    else:
        in_range_min, in_range_max = in_range

    dist = in_range_max - in_range_min
    scale = 1.0 if dist == 0.0 else 1.0 / dist
    return tensor.clamp(min=in_range_min,
                        max=in_range_max).add_(other=in_range_min,
                                               alpha=-1.0).mul_(other=scale).clamp_(min=0.0, max=1.0)


def pad_and_crop_and_stack(x: List[torch.Tensor], pad_value: float = 0.0) -> torch.Tensor:
    """
    Takes a list of tensors and returns a single batched tensor. It is useful for visualization.
    Every tensor is padded with ``pad_value`` so that its content is centered in a canvas of
    size (width_max, height_max); the padded tensors are then stacked.

    Args:
        x: a non-empty list of tensors with the same channel dimension but possibly different
            width and height. Width is ``shape[-2]`` and height is ``shape[-1]``.
        pad_value: float, the value used in padding the images. Defaults to padding with black colors

    Returns:
        tensor: A single batch tensor of shape :math:`(B, c, \\text{width}_\\text{max}, \\text{height}_\\text{max})`

    Raises:
        ValueError: if ``x`` is empty.
    """
    if len(x) == 0:
        raise ValueError("Expected a non-empty list of tensors")

    w_max = max(tmp.shape[-2] for tmp in x)
    h_max = max(tmp.shape[-1] for tmp in x)

    padded = []
    for tmp in x:
        delta_w = w_max - tmp.shape[-2]
        delta_h = h_max - tmp.shape[-1]
        # F.pad pads the LAST dimension first: (left, right, top, bottom) = (dim -1, dim -1, dim -2, dim -2)
        pad = (delta_h // 2, delta_h - delta_h // 2, delta_w // 2, delta_w - delta_w // 2)
        padded.append(torch.nn.functional.pad(tmp, pad, mode="constant", value=pad_value))

    return torch.stack(padded, dim=0)


def show_batch(
        tensor: torch.Tensor,
        cmap: str = None,
        n_col: int = 4,
        n_padding: int = 10,
        title: str = None,
        pad_value: int = 1,
        normalize: bool = True,
        normalize_range: Tuple[float, float] = None,
        figsize: Tuple[float, float] = None) -> plt.Figure:
    """
    Visualize a torch tensor of shape: :math:`(*, \\text{ch}, \\text{width}, \\text{height})`.
    It works for any number of leading dimensions.

    Args:
        tensor: the torch.Tensor to plot
        cmap: the color map to use.
            - 1 or 2 channels: only the first channel is shown, with this colormap ('gray' if None).
            - 3 channels: the image is shown as RGB and ``cmap`` is ignored.
            - more than 3 channels: channels are blended with colors sampled from this colormap
              ('tab20' if None).
        n_col: int, number of columns in the image grid
        n_padding: int, padding (in pixels) between images in the grid
        title: str, the title of the figure
        pad_value: float, value used for the padding between images
        normalize: bool, if True rescale the values to (0, 1) using normalize_range
        normalize_range: tuple, if not specified it is set to the (min, max) of the whole batch
        figsize: size of the figure. Defaults to (4 * n_col, 4 * n_row)

    Returns:
        fig: the (already closed) matplotlib figure with the grid of images
    """
    assert len(tensor.shape) >= 4  # *, ch, width, height
    # Upgrade to float32 BEFORE mixing channels: einsum does not promote dtypes, so half, double
    # or integer inputs would otherwise fail against the float32 color tensor.
    tensor = tensor.flatten(end_dim=-4).detach().float()  # -1, ch, width, height
    ch = tensor.shape[-3]

    if ch > 3:
        colors = _get_color_tensor('tab20' if cmap is None else cmap, ch)
        images = torch.einsum('...cwh,cj->...jwh', tensor, colors.to(device=tensor.device).float())
    elif ch == 3:
        images = tensor
    else:
        # 1 or 2 channels: only the first channel is displayed
        images = tensor[..., :1, :, :]
    images = images.cpu()

    n_images = images.shape[-4]
    n_row = int(numpy.ceil(float(n_images) / n_col))

    # Optionally rescale to (0, 1) using normalize_range or the min/max of the whole batch
    grid = make_grid(images, nrow=n_col, padding=n_padding, normalize=normalize,
                     value_range=normalize_range, scale_each=False, pad_value=pad_value)

    figsize = (4 * n_col, 4 * n_row) if figsize is None else figsize
    fig = plt.figure(figsize=figsize)
    if images.shape[-3] == 1:
        # make_grid replicates a single channel into 3 identical channels: draw one of them with the
        # colormap so that `cmap` is honored. vmin/vmax=(0, 1) mirror how RGB floats are clipped.
        plt.imshow(grid[0].numpy(), cmap='gray' if cmap is None else cmap, vmin=0.0, vmax=1.0)
    else:
        plt.imshow(grid.permute(1, 2, 0).numpy())
    if isinstance(title, str):
        plt.title(title)
    fig.tight_layout()
    plt.close(fig)
    return fig


def _show_raw_one_channel(
        tensor: torch.Tensor,
        ax: "matplotlib.axes.Axes",
        cmap: str,
        in_range: Union[str, Tuple[float, float]] = 'image',
        title: Optional[str] = None):
    """
    Draw a single 2D tensor on ``ax`` with the colormap ``cmap``.

    Args:
        tensor: 2D tensor of shape (width, height)
        ax: the matplotlib axes to draw on
        cmap: matplotlib colormap used to render the values
        in_range: 'image' to rescale with the min/max of ``tensor`` itself, or a (min, max)
            tuple used for clamping and rescaling to (0, 1)
        title: optional title of the panel
    """
    if in_range == 'image':
        in_range = (torch.min(tensor).item(), torch.max(tensor).item())
    scaled = _minmax_scale_tensor(tensor, in_range=in_range)
    # Pin vmin/vmax: without them imshow re-normalizes every 2D panel to its own min/max,
    # which silently undoes a user-supplied in_range or a range shared across panels.
    _ = ax.imshow(scaled.detach().cpu().float().numpy(), cmap=cmap, vmin=0.0, vmax=1.0)
    if title is not None:
        ax.set_title(title)


def _show_raw_all_channels(
        tensor: torch.Tensor,
        ax: "matplotlib.axes.Axes",
        title: Optional[str] = None):
    """
    Draw a 3D tensor on ``ax`` as an image, after a global min/max rescaling to (0, 1).

    Args:
        tensor: tensor with 3 dimensions (presumably (3, width, height) RGB -- confirm with caller)
        ax: the matplotlib axes to draw on
        title: optional title of the panel
    """
    assert len(tensor.shape) == 3
    # normalization: one min/max over all channels, so relative channel intensities are preserved
    tensor = _minmax_scale_tensor(tensor, in_range=(torch.min(tensor).item(), torch.max(tensor).item()))
    _ = ax.imshow(numpy.asarray(to_pil_image(tensor)))
    if title is not None:
        ax.set_title(title)


def show_raw_one_channel(
        data: Union[torch.Tensor, List[torch.Tensor]],
        n_col: int = 4,
        cmap: str = None,
        in_range: Union[str, Tuple[float, float]] = 'image',
        scale_each: bool = True,
        figsize: Tuple[float, float] = None,
        titles: List[str] = None,
        sup_title: str = None,
        show_axis: bool = True) -> plt.Figure:
    """
    Visualize a torch tensor of shape :math:`(*, \\text{width}, \\text{height})`
    or a list of tensors of shape :math:`(\\text{width}, \\text{height})`.
    Each leading dimension is shown in its own panel.

    Args:
        data: A torch.tensor of shape :math:`(*, \\text{width}, \\text{height})` or a list of
            tensors of shape :math:`(\\text{width}, \\text{height})`.
        n_col: maximum number of panels per row
        cmap: the matplotlib color map to use. If None use 'gray' colormap.
        in_range: Either a tuple with a min and max value (for clamping) or the string 'image'.
            If 'image' the min and max values are computed from the data itself.
            Values are clamped to in_range and then mapped to (0.0, 1.0) for visualization.
        scale_each: if True each panel is rescaled with its own min and max, otherwise a single
            min and max computed over all panels is shared. It has effect only if in_range='image'.
        figsize: Optional, the tuple with the width and height of the rendered figure
        titles: list with one title per panel
        sup_title: str, the title for the entire figure
        show_axis: If True (default) show the axis of the panels.

    Returns:
        fig: A (closed) figure with `(*)` panels. Each panel is a rendering of a tensor of
        shape :math:`(\\text{width}, \\text{height})`
    """
    if isinstance(data, list):
        assert all([len(img.shape) == 2 for img in data])
        n_max = len(data)
        panels = [img.float() for img in data]
        if in_range == 'image' and not scale_each:
            # a single (min, max) shared by every panel
            in_range = min([torch.min(img) for img in panels]), max([torch.max(img) for img in panels])
    elif isinstance(data, torch.Tensor):
        panels = data.flatten(end_dim=-3).float()  # shape: (*, w, h)
        n_max = panels.shape[0]
        if in_range == 'image' and not scale_each:
            in_range = torch.min(panels), torch.max(panels)
    else:
        raise Exception("Expected Union[tensor, List[tensor]]. Received {0}".format(type(data)))

    assert titles is None or (isinstance(titles, list) and len(titles) == n_max)

    ncols = min(n_col, n_max)
    nrows = int(numpy.ceil(n_max / ncols))
    if figsize is None:
        figsize = (4 * ncols, 4 * nrows)

    # squeeze=False always yields a 2D array of axes, so no special-casing of 1x1 / 1xN grids
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)
    panel_cmap = 'gray' if cmap is None else cmap

    for idx, ax_curr in enumerate(axes.flat):
        if idx >= n_max:
            # empty slot at the end of the grid
            ax_curr.set_axis_off()
            continue
        _show_raw_one_channel(
            tensor=panels[idx],
            ax=ax_curr,
            cmap=panel_cmap,
            in_range=in_range,
            title=None if titles is None else titles[idx],
        )
        if show_axis:
            ax_curr.set_axis_on()
        else:
            ax_curr.set_axis_off()

    if sup_title:
        fig.suptitle(sup_title)

    plt.close(fig)
    return fig


def show_raw_all_channels(
        data: Union[torch.Tensor, List[torch.Tensor]],
        n_col: int = 4,
        cmap: str = None,
        figsize: Tuple[float, float] = None,
        titles: List[str] = None,
        sup_title: str = None,
        show_colorbar: Optional[bool] = None,
        legend_colorbar: List[str] = None,
        show_axis: bool = True) -> plt.Figure:
    """
    Visualize a torch tensor of shape :math:`(*, \\text{ch}, \\text{width}, \\text{height})`
    or a list of tensors of shape :math:`(\\text{ch}, \\text{width}, \\text{height})`.
    All channels are blended into a single RGB panel using one color per channel.

    Args:
        data: A torch.tensor of shape :math:`(*, \\text{ch}, \\text{width}, \\text{height})` or a list
            of tensors of shape :math:`(\\text{ch}, \\text{width}, \\text{height})`
        n_col: maximum number of panels per row
        cmap: the matplotlib color map used to pick one color per channel. Defaults to RGB
            (if data has 3 channels) or 'tab20' otherwise.
        figsize: Optional, the tuple with the width and height of the rendered figure
        titles: list with one title per panel
        sup_title: the title for the entire figure
        show_colorbar: if True show a discrete colorbar with the color of each channel.
            If None, the colorbar is shown only when legend_colorbar is given.
        legend_colorbar: one label per channel for the colorbar ticks
        show_axis: If True (default) show the axis of the panels

    Returns:
        fig: A (closed) figure with `*` panels, each panel is a rendering of a tensor of shape
        :math:`(\\text{ch}, \\text{width}, \\text{height})`
    """
    if isinstance(data, torch.Tensor):
        if len(data.shape) == 3:
            data = data.unsqueeze(dim=0)
        data = data.detach().clone().cpu().float()
    elif isinstance(data, list):
        data = [img.detach().clone().cpu().float() for img in data]
    else:
        raise Exception("Expected either a tensor or a list of tensors. Received {0}".format(type(data)))

    # every image must have the same number of channels
    channels = [img.shape[-3] for img in data]
    assert all(c == channels[0] for c in channels), \
        "The images have different number of channels {0}".format(channels)

    ch = channels[0]
    assert legend_colorbar is None or len(legend_colorbar) == ch

    # one RGB color per channel
    if cmap is not None:
        colors = _get_color_tensor(cmap, ch)
    elif ch == 3:
        colors = torch.eye(3)
    else:
        colors = _get_color_tensor('tab20', ch)
    colors = colors.to(data[0].device).float()

    # blend the channels: (..., ch, w, h) x (ch, 3) -> (..., 3, w, h)
    if isinstance(data, torch.Tensor):
        imgs = torch.einsum('...cwh,cj->...jwh', data, colors).detach().clone().cpu()
        n_max = imgs.shape[0]
    else:
        imgs = [torch.einsum('...cwh,cj->...jwh', img, colors).detach().clone().cpu() for img in data]
        n_max = len(data)

    ncols = min(n_col, n_max)
    nrows = int(numpy.ceil(n_max / ncols))
    if figsize is None:
        figsize = (4 * ncols, 4 * nrows)
    assert titles is None or (isinstance(titles, list) and len(titles) == n_max)

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    # keep `axes` untouched for fig.colorbar; iterate over a flat view in row-major order
    flat_axes = axes.ravel() if isinstance(axes, numpy.ndarray) else [axes]

    for idx, ax_curr in enumerate(flat_axes):
        if idx >= n_max:
            # empty slot at the end of the grid
            ax_curr.set_axis_off()
            continue
        _show_raw_all_channels(
            tensor=imgs[idx],
            ax=ax_curr,
            title=None if titles is None else titles[idx],
        )
        if show_axis:
            ax_curr.set_axis_on()
        else:
            ax_curr.set_axis_off()

    if sup_title:
        fig.suptitle(sup_title)

    if show_colorbar is None:
        show_colorbar = legend_colorbar is not None

    if show_colorbar:
        # one discrete bin (centered on the integer tick) per channel
        boundaries = numpy.linspace(-0.5, ch - 0.5, ch + 1)
        normalizer = matplotlib.colors.BoundaryNorm(boundaries=boundaries, ncolors=ch, clip=True)
        mappable = ScalarMappable(norm=normalizer, cmap=ListedColormap(colors.numpy()))
        cbar = fig.colorbar(mappable, ticks=numpy.arange(ch), ax=axes)
        if legend_colorbar is None:
            legend_colorbar = numpy.arange(ch).tolist()
        cbar.ax.set_yticklabels(legend_colorbar)

    plt.close(fig)
    return fig