Source code for tissue_purifier.data.datamodule

from __future__ import annotations
from argparse import ArgumentParser
from argparse import Action as ArgparseAction
import numpy
import os.path
from pytorch_lightning import LightningDataModule
from anndata import read_h5ad
from typing import Dict, Callable, Optional, Tuple, List, Iterable, Any
import torch
import torchvision
from os import cpu_count
from scanpy import AnnData

from tissue_purifier.models.patch_analyzer import SpatialAutocorrelation
from .sparse_image import SparseImage
from .transforms import (
    DropoutSparseTensor,
    SparseToDense,
    TransformForList,
    Rasterize,
    RandomHFlip,
    RandomVFlip,
    RandomStraightCut,
    RandomGlobalIntensity,
    DropChannel,
    # LargestSquareCrop,
    # ToRgb,
)
from .dataset import (
    CropperDataset,
    DataLoaderWithLoad,
    CollateFnListTuple,
    MetadataCropperDataset,
    # CropperDenseTensor,
    CropperSparseTensor,
    CropperTensor,
)


# Sparse tensors cannot be used in a DataLoader with num_workers > 0.
# See https://github.com/pytorch/pytorch/issues/20248
# Therefore the dataset is moved to the GPU and the dataloaders use num_workers = 0.


class ParseDict(ArgparseAction):
    """
    Make argparse able to parse a dictionary from command line.

    Every token must have the form ``key=value``. Only the first ``=`` separates
    the key from the value, so values are allowed to contain ``=`` themselves.
    Keys and values are stored as strings; no type conversion is performed.

    Raises:
        ValueError: if a token does not contain ``=``.
    """
    def __call__(self, parser, namespace, values, option_string=None):
        parsed = dict()
        for item in values:
            # maxsplit=1 so that 'key=a=b' maps 'key' to 'a=b' instead of failing to unpack
            pieces = item.split('=', 1)
            if len(pieces) != 2:
                raise ValueError("Expected an argument of the form key=value. Received {}".format(item))
            key, val = pieces
            parsed[key] = val
        setattr(namespace, self.dest, parsed)


class SslDM(LightningDataModule):
    """
    Base class to inherit from to make a DataModule which can be used with any
    Self Supervised Learning framework.

    Apart from :meth:`get_default_params`, every member raises
    :class:`NotImplementedError` and has to be provided by the subclasses.
    """

    @classmethod
    def get_default_params(cls) -> dict:
        """ Get the default parameters to instantiate an object. """
        # Parsing an empty argument list leaves every option at its declared default.
        default_parser = cls.add_specific_args(ArgumentParser())
        namespace = default_parser.parse_args(args=[])
        return vars(namespace)

    @classmethod
    def add_specific_args(cls, parent_parser):
        """ Utility function which adds parameters to argparse to simplify setting up a CLI. """
        raise NotImplementedError

    def get_metadata_to_regress(self, metadata) -> Dict[str, float]:
        """ Extract one or more quantities to regress from the metadata. """
        raise NotImplementedError

    def get_metadata_to_classify(self, metadata) -> Dict[str, int]:
        """ Extract one or more quantities to classify from the metadata. """
        raise NotImplementedError

    @property
    def ch_in(self) -> int:
        """ How many channels will be present in the images returned by the train/test/val dataloaders? """
        raise NotImplementedError

    @property
    def local_size(self) -> int:
        """ Size in pixel of the local crops (used only for Dino). """
        raise NotImplementedError

    @property
    def global_size(self) -> int:
        """ Size in pixel of the global crops. """
        raise NotImplementedError

    @property
    def n_local_crops(self) -> int:
        """ Number of local crops for each image to use for training (used only for Dino). """
        raise NotImplementedError

    @property
    def n_global_crops(self) -> int:
        """ Number of global crops for each image to use for training (used only for Dino). """
        raise NotImplementedError

    @property
    def cropper_test(self) -> CropperTensor:
        """ Cropper to be used at test time. This specifies the cropping strategy to use at test time. """
        raise NotImplementedError

    @property
    def trsfm_test(self) -> Callable:
        """ Transformation to be applied at test time. This specifies the data-augmentation at test time. """
        raise NotImplementedError

    @property
    def cropper_train(self) -> CropperTensor:
        """ Cropper to be used at train time. This specifies the cropping strategy to use at train time. """
        raise NotImplementedError

    @property
    def trsfm_train_local(self) -> Callable:
        """
        Local Transformation to be applied at train time.
        This specifies the data augmentation for the local crops. Used by Dino only.
        """
        raise NotImplementedError

    @property
    def trsfm_train_global(self) -> Callable:
        """
        Global Transformation to be applied at train time.
        This specifies the data augmentation for the global crops.
        """
        raise NotImplementedError

    def prepare_data(self):
        """
        Use this to download and prepare the data.

        These operations will be done only once in distributed settings.
        For example, one GPU might be used to prepare data and write the results to disk
        so that the other GPUs can read the pre-processed data.
        """
        raise NotImplementedError

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Called on every GPU at the beginning of fit (train + validate), validate, test, and predict.

        This is a good place to set the internal state, i.e. self.something = something_else
        """
        raise NotImplementedError

    def train_dataloader(self) -> DataLoaderWithLoad:
        """ Returns the train dataloader. """
        raise NotImplementedError

    def val_dataloader(self) -> List[DataLoaderWithLoad]:
        """ Returns the validation dataloader. """
        raise NotImplementedError

    def test_dataloader(self) -> List[DataLoaderWithLoad]:
        """ Returns the test dataloader. """
        raise NotImplementedError

    def predict_dataloader(self) -> List[DataLoaderWithLoad]:
        """ Returns the predict dataloader. """
        raise NotImplementedError
class SparseSslDM(SslDM):
    """
    Datamodule for sparse Images with the parameter for the transform (i.e. data augmentation) specified.

    If you are inheriting from this class then you only have to overwrite:
    'prepare_data', 'setup', 'get_metadata_to_classify' and 'get_metadata_to_regress'.
    """

    def __init__(self,
                 global_size: int = 96,
                 local_size: int = 64,
                 n_local_crops: int = 2,
                 n_global_crops: int = 2,
                 global_scale: Tuple[float, float] = (0.8, 1.0),
                 local_scale: Tuple[float, float] = (0.5, 0.8),
                 global_intensity: Tuple[float, float] = (0.8, 1.2),
                 n_element_min_for_crop: int = 200,
                 drop_spot_probs: Tuple[float, ...] = (0.1, 0.2, 0.3),
                 rasterize_sigmas: Tuple[float, ...] = (1.0, 1.5),
                 occlusion_fraction: Tuple[float, float] = (0.1, 0.3),
                 drop_channel_prob: float = 0.0,
                 drop_channel_relative_freq: Optional[Iterable[float]] = None,
                 n_crops_for_tissue_test: int = 50,
                 n_crops_for_tissue_train: int = 50,
                 batch_size_per_gpu: int = 64,
                 **kargs):
        """
        Args:
            global_size: size in pixel of the global crops
            local_size: size in pixel of the local crops
            n_local_crops: number of local crops
            n_global_crops: number of global crops
            global_scale: in RandomResizedCrop the scale of global crops will be drawn uniformly
                between these values
            local_scale: in RandomResizedCrop the scale of local crops will be drawn uniformly
                between these values
            global_intensity: all channels will be multiplied by a number in this range
            n_element_min_for_crop: minimum number of beads/cell in a crop
            drop_spot_probs: Probability of dropping out spots (in sparse image). Should be > 0.0
            rasterize_sigmas: Possible values of the sigma of the gaussian kernel used for rasterization.
            occlusion_fraction: Fraction of the sample which is occluded is drawn uniformly between these values
            drop_channel_prob: Probability that a channel will be set to zero,
            drop_channel_relative_freq: Relative probability of each channel to be set to zero.
                If None (default) all channels are equally likely to be set to zero.
            n_crops_for_tissue_test: The number of crops in each validation epoch will be
                :math:`n_{tissue} \\times \\text{n_crops_for_tissue_test}`
            n_crops_for_tissue_train: The number of crops in each training epoch will be
                :math:`n_{tissue} \\times \\text{n_crops_for_tissue_train}`
            batch_size_per_gpu: batch size for EACH GPUs.
            kargs: accepted and ignored, so that a larger configuration dictionary can be passed in
        """
        super(SparseSslDM, self).__init__()

        # params for overwriting the abstract property
        self._global_size = global_size
        self._local_size = local_size
        self._n_global_crops = n_global_crops
        self._n_local_crops = n_local_crops

        # specify the transform
        self._global_scale = global_scale
        self._local_scale = local_scale
        self._global_intensity = global_intensity
        self._drop_spot_probs = drop_spot_probs
        self._rasterize_sigmas = rasterize_sigmas
        self._occlusion_fraction = occlusion_fraction
        # NOTE(review): the default here (0.0) differs from the CLI default (0.2) declared
        # in add_specific_args. Confirm which one is intended before aligning them.
        self._drop_channel_prob = drop_channel_prob
        self._drop_channel_relative_freq = drop_channel_relative_freq
        self._n_element_min_for_crop = n_element_min_for_crop
        self._n_crops_for_tissue_test = n_crops_for_tissue_test
        self._n_crops_for_tissue_train = n_crops_for_tissue_train

        # batch_size
        self._batch_size_per_gpu = batch_size_per_gpu

        # the datasets are created by setup() in the derived classes
        self._dataset_train: Optional[CropperDataset] = None
        self._dataset_test: Optional[CropperDataset] = None

    @classmethod
    def add_specific_args(cls, parent_parser) -> ArgumentParser:
        """
        Utility functions which add parameters to argparse to simplify setting up a CLI

        Example:
            >>> import sys
            >>> import argparse
            >>> parser = argparse.ArgumentParser(add_help=False, conflict_handler='resolve')
            >>> parser = SparseSslDM.add_specific_args(parser)
            >>> args = parser.parse_args(sys.argv[1:])
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler='resolve')

        parser.add_argument("--global_size", type=int, default=96, help="size in pixel of the global crops")
        parser.add_argument("--local_size", type=int, default=64, help="size in pixel of the local crops")
        parser.add_argument("--n_global_crops", type=int, default=2, help="number of global crops")
        parser.add_argument("--n_local_crops", type=int, default=2, help="number of local crops")
        parser.add_argument("--global_scale", type=float, nargs=2, default=[0.8, 1.0],
                            help="in RandomResizedCrop the scale of global crops will be drawn uniformly "
                                 "between these values")
        parser.add_argument("--local_scale", type=float, nargs=2, default=[0.5, 0.8],
                            help="in RandomResizedCrop the scale of local crops will be drawn uniformly "
                                 "between these values")
        parser.add_argument("--global_intensity", type=float, nargs=2, default=[0.8, 1.2],
                            help="All channels will be multiplied by a value within this range")
        parser.add_argument("--n_element_min_for_crop", type=int, default=200,
                            help="minimum number of beads/cell in a crop")
        parser.add_argument("--drop_spot_probs", type=float, nargs='*', default=[0.1, 0.2, 0.3],
                            help="Probability of dropping out spots in the sparse image. "
                                 "Should be in (0.0, 1.0). "
                                 "If a tuple is given. A random value for the tuple is chosen.")
        parser.add_argument("--rasterize_sigmas", type=float, nargs='*', default=[1.0, 1.5],
                            help="Possible values of the sigma of the gaussian kernel used for rasterization")
        parser.add_argument("--occlusion_fraction", type=float, nargs=2, default=[0.1, 0.3],
                            help="Fraction of the sample which is occluded is drawn uniformly between these values.")
        parser.add_argument("--drop_channel_prob", type=float, default=0.2,
                            help="Probability that a channel in the image will be set to zero.")
        parser.add_argument("--drop_channel_relative_freq", type=float, nargs='*', default=None,
                            help="Relative probability of each channel to be set to zero. "
                                 "If None, all channels have the same probability of being zero")
        parser.add_argument("--n_crops_for_tissue_train", type=int, default=50,
                            help="The number of crops in each training epoch will be: n_tissue * n_crops. "
                                 "Set small for rapid prototyping")
        parser.add_argument("--n_crops_for_tissue_test", type=int, default=50,
                            help="The number of crops in each test epoch will be: n_tissue * n_crops. "
                                 "Set small for rapid prototyping")
        parser.add_argument("--batch_size_per_gpu", type=int, default=64,
                            help="Batch size for EACH GPUs. Set small for rapid prototyping. "
                                 "The total batch_size will increase linearly with the number of GPUs.")
        return parser

    @property
    def global_size(self) -> int:
        """
        Size in pixel of the global crops.
        This specify the size of the patch processed by the ssl model.
        """
        return self._global_size

    @property
    def local_size(self) -> int:
        """
        Size in pixel of the local crops (used only for Dino).
        This specify the size of the patch processed by the ssl model.
        """
        return self._local_size

    @property
    def n_global_crops(self) -> int:
        """ Number of global crops for each image to use for training (used only for Dino). """
        return self._n_global_crops

    @property
    def n_local_crops(self) -> int:
        """ Number of local crops for each image to use for training (used only for Dino). """
        return self._n_local_crops

    @property
    def cropper_test(self) -> CropperSparseTensor:
        """ Cropper to be used at test time. This specify the cropping strategy to use at test time. """
        return CropperSparseTensor(
            strategy='random',
            crop_size=self._global_size,
            n_element_min=self._n_element_min_for_crop,
            n_crops=self._n_crops_for_tissue_test,
            random_order=True,
        )

    @property
    def cropper_train(self) -> CropperSparseTensor:
        """ Cropper to be used at train time. This specify the cropping strategy to use at train time. """
        # The train crops are 1.5x larger (per side) than the global size; the train transforms
        # center-crop back to global_size after the random rotation. The minimum number of
        # elements is scaled by the same area factor (1.5 * 1.5).
        return CropperSparseTensor(
            strategy='random',
            crop_size=int(self._global_size * 1.5),
            n_element_min=int(self._n_element_min_for_crop * 1.5 * 1.5),
            n_crops=self._n_crops_for_tissue_train,
            random_order=True,
        )

    @property
    def trsfm_test(self) -> TransformForList:
        """ Transformation to be applied at test time. This specify the data-augmentation at test time. """
        return TransformForList(
            transform_before_stack=torchvision.transforms.Compose([
                DropoutSparseTensor(p=0.5, dropout_rate=self._drop_spot_probs),
                SparseToDense(),
                Rasterize(sigmas=self._rasterize_sigmas, normalize=False),
                RandomVFlip(p=0.5),
                RandomHFlip(p=0.5),
                RandomGlobalIntensity(f_min=self._global_intensity[0], f_max=self._global_intensity[1])
            ]),
            transform_after_stack=torchvision.transforms.CenterCrop(size=self.global_size),
        )

    def _trsfm_train(self, output_size: int, scale) -> TransformForList:
        """
        Build the train-time augmentation shared by the global and the local crops.

        Args:
            output_size: side in pixel of the image produced by the final RandomResizedCrop
            scale: range of the scale passed to RandomResizedCrop
        """
        return TransformForList(
            transform_before_stack=torchvision.transforms.Compose([
                DropoutSparseTensor(p=0.5, dropout_rate=self._drop_spot_probs),
                SparseToDense(),
                RandomGlobalIntensity(f_min=self._global_intensity[0], f_max=self._global_intensity[1])
            ]),
            transform_after_stack=torchvision.transforms.Compose([
                Rasterize(sigmas=self._rasterize_sigmas, normalize=False),
                torchvision.transforms.RandomRotation(
                    degrees=(-180.0, 180.0),
                    interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
                    expand=False,
                    fill=0.0),
                # both global and local crops are first center-cropped to the global size
                torchvision.transforms.CenterCrop(size=self._global_size),
                RandomVFlip(p=0.5),
                RandomHFlip(p=0.5),
                torchvision.transforms.RandomResizedCrop(
                    size=(output_size, output_size),
                    scale=scale,
                    ratio=(0.95, 1.05),
                    interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
                RandomStraightCut(p=0.5, occlusion_fraction=self._occlusion_fraction),
                DropChannel(p=self._drop_channel_prob, relative_frequency=self._drop_channel_relative_freq),
            ])
        )

    @property
    def trsfm_train_global(self) -> TransformForList:
        """
        Global Transformation to be applied at train time.
        This specify the data augmentation for the global crops.
        """
        return self._trsfm_train(output_size=self._global_size, scale=self._global_scale)

    @property
    def trsfm_train_local(self) -> TransformForList:
        """
        Local Transformation to be applied at train time.
        This specify the data augmentation for the local crops. Used by Dino only.
        """
        return self._trsfm_train(output_size=self._local_size, scale=self._local_scale)

    def _dataloader_device(self) -> torch.device:
        """ Device of the model attached to the trainer, or a cuda/cpu fallback if there is none. """
        try:
            return self.trainer._model.device
        except AttributeError:
            return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def _dataloader_batch_size(self, dataset: CropperDataset) -> int:
        """
        Batch size to give to the dataloader so that each batch holds about batch_size_per_gpu crops.

        Each sample of the dataset generates n_crops_per_tissue crops, therefore the batch size
        is reduced accordingly (but never below 1).
        """
        n_crops = dataset.n_crops_per_tissue
        if n_crops is None:
            return self._batch_size_per_gpu
        return max(1, int(self._batch_size_per_gpu // n_crops))

    def train_dataloader(self) -> DataLoaderWithLoad:
        """ Returns the train dataloader (shuffled, the last partial batch is dropped). """
        device = self._dataloader_device()
        assert isinstance(self._dataset_train, CropperDataset), \
            "The train dataset is not available. Has setup() been called?"

        dataloader_train = DataLoaderWithLoad(
            # move the dataset to GPU so that the cropping happens there
            dataset=self._dataset_train.to(device),
            # each sample generate n_crops therefore reduce batch_size
            batch_size=self._dataloader_batch_size(self._dataset_train),
            collate_fn=CollateFnListTuple(),
            # problem if this is larger than 0, see https://github.com/pytorch/pytorch/issues/20248
            num_workers=0,
            # in the train dataloader, I DO shuffle and drop the last partial_batch
            shuffle=True,
            drop_last=True,
        )
        return dataloader_train

    def val_dataloader(self) -> List[DataLoaderWithLoad]:
        """ Returns a list with the validation dataloader (built from the test dataset, not shuffled). """
        device = self._dataloader_device()
        assert isinstance(self._dataset_test, CropperDataset), \
            "The test dataset is not available. Has setup() been called?"

        test_dataloader = DataLoaderWithLoad(
            # move the dataset to GPU so that the cropping happens there
            dataset=self._dataset_test.to(device),
            # the batch size is derived from the TEST dataset (the train dataset may differ or be missing)
            batch_size=self._dataloader_batch_size(self._dataset_test),
            collate_fn=CollateFnListTuple(),
            # problem if num_workers > 0, see https://github.com/pytorch/pytorch/issues/20248
            num_workers=0,
            # in the test dataloader, I do NOT shuffle and do not drop the last partial_batch
            shuffle=False,
            drop_last=False,
        )
        return [test_dataloader]

    def test_dataloader(self) -> List[DataLoaderWithLoad]:
        """ Same as the validation dataloader. """
        return self.val_dataloader()

    def predict_dataloader(self) -> List[DataLoaderWithLoad]:
        """ Same as the validation dataloader. """
        return self.val_dataloader()

    def prepare_data(self):
        """ To be overwritten by the derived classes. """
        raise NotImplementedError

    def get_metadata_to_classify(self, metadata) -> Dict[str, int]:
        """ To be overwritten by the derived classes. """
        raise NotImplementedError

    def get_metadata_to_regress(self, metadata) -> Dict[str, float]:
        """ To be overwritten by the derived classes. """
        raise NotImplementedError

    def setup(self, stage: Optional[str] = None) -> None:
        """ To be overwritten by the derived classes, which must set the train and test datasets. """
        self._dataset_train = None
        self._dataset_test = None
        raise NotImplementedError
class AnndataFolderDM(SparseSslDM):
    """
    Create a Datamodule ready for Self-supervised learning starting
    from a folder full of anndata files in .h5ad format.
    """

    def __init__(self,
                 data_folder: str,
                 pixel_size: float,
                 x_key: str,
                 y_key: str,
                 category_key: str,
                 categories_to_channels: Dict[Any, int],
                 metadata_to_classify: Callable,
                 metadata_to_regress: Callable,
                 num_workers: int,
                 gpus: int,
                 n_neighbours_moran: int,
                 **kargs):
        """
        Args:
            data_folder: path to folder with the anndata in h5ad format
            pixel_size: size of the pixel (used to convert raw_coordinates to pixel_coordinates)
            x_key: key associated with the x_coordinate in the AnnData object
            y_key: key associated with the y_coordinate in the AnnData object
            category_key: key associated with the the categorical values (cell_types or gene_identities)
                in the AnnData object
            categories_to_channels: dictionary with the mapping from categorical values to channels in the image.
                The values must be non-negative integers. Values given as strings
                (as happens when the dictionary is parsed from the command line) are converted to int.
            metadata_to_classify: callable which defines the values to classify during training.
                If None, the index of the file a crop comes from is used.
            metadata_to_regress: callable which defines the values to regress during training.
                If None, the moran score and the x-location of the crop are used.
            num_workers: number of worker to load data. Meaningful only if dataset is on disk.
                Set to zero if data in memory. If None, the number of cpus is used.
            gpus: number of gpus to use for training. If None, the number of visible cuda devices is used.
            n_neighbours_moran: number of neighbours used to compute Moran's I score
            kargs: all these parameters will be passed to :class:`SparseSslDM`
        """
        assert isinstance(categories_to_channels, dict) and len(categories_to_channels.keys()) >= 1, \
            "Error. Specify a valid categories_to_channels mapping. Received {}".format(categories_to_channels)

        # The command line parser produces string values ('foo=1' -> {'foo': '1'}).
        # Convert them so that the validation below and the channel indexing work on integers.
        categories_to_channels = {
            category: (int(ch) if isinstance(ch, str) else ch)
            for category, ch in categories_to_channels.items()
        }

        set_chs = set(categories_to_channels.values())
        set_chs_should_be = set(range(max(set_chs) + 1))
        assert set_chs == set_chs_should_be, \
            "The values of the categories_to_channels must be integers starting at zero. " \
            "Received {}".format(set_chs)

        self._data_folder = data_folder
        self._pixel_size = pixel_size
        self._x_key = x_key
        self._y_key = y_key
        self._category_key = category_key
        self._categories_to_channels = categories_to_channels
        self._metadata_to_regress = metadata_to_regress
        self._metadata_to_classify = metadata_to_classify
        self._num_workers = cpu_count() if num_workers is None else num_workers
        self._gpus = torch.cuda.device_count() if gpus is None else gpus
        self._n_neighbours_moran = n_neighbours_moran

        # Callable on dataset
        self.compute_moran = SpatialAutocorrelation(
            modality='moran',
            n_neighbours=self._n_neighbours_moran,
            neigh_correct=False)

        # list of all the files used to create the dataset (set by prepare_data and by setup)
        self._all_filenames = None

        super(AnndataFolderDM, self).__init__(**kargs)

    @classmethod
    def add_specific_args(cls, parent_parser) -> ArgumentParser:
        """
        Utility functions which add parameters to argparse to simplify setting up a CLI

        Example:
            >>> import sys
            >>> import argparse
            >>> parser = argparse.ArgumentParser(add_help=False, conflict_handler='resolve')
            >>> parser = AnndataFolderDM.add_specific_args(parser)
            >>> args = parser.parse_args(sys.argv[1:])
        """
        parser_from_super = super().add_specific_args(parent_parser)
        parser = ArgumentParser(parents=[parser_from_super], add_help=False, conflict_handler='resolve')

        parser.add_argument("--data_folder", type=str, default="./",
                            help="directory where to find the anndata in h5ad format")
        parser.add_argument("--pixel_size", type=float, default=4.0,
                            help="size of the pixel (used to convert raw_coordinates to pixel_coordinates)")
        parser.add_argument("--x_key", type=str, default="x",
                            help="key associated with the x_coordinate in the AnnData object")
        parser.add_argument("--y_key", type=str, default="y",
                            help="key associated with the y_coordinate in the AnnData object")
        parser.add_argument("--category_key", type=str, default="cell_type",
                            help="key associated with the the categorical values (cell_types or gene_identities) "
                                 "in the AnnData object")
        parser.add_argument("--categories_to_channels", nargs='*', action=ParseDict,
                            help="dictionary in the form 'foo'=1 'bar'=2 to define "
                                 "how the categorical values are mapped to the different channels in the image")
        parser.add_argument("--metadata_to_classify", default=None,
                            help="callable which defines the values to classify during training")
        parser.add_argument("--metadata_to_regress", default=None,
                            help="callable which defines the values to regress during training")
        parser.add_argument("--num_workers", default=cpu_count(), type=int,
                            help="number of worker to load data. Meaningful only if dataset is on disk. "
                                 "Set to zero if data in memory")
        parser.add_argument("--gpus", default=torch.cuda.device_count(), type=int,
                            help="number of gpus to use for training.")
        parser.add_argument("--n_neighbours_moran", type=int, default=6,
                            help="number of neighbours used to compute Moran's I score")
        return parser

    @property
    def ch_in(self) -> int:
        """ How many channels will be present in the images returned by the train/test/val dataloaders? """
        # builtin int (not a numpy integer) so that the value can be used wherever a plain int is expected
        return int(max(self._categories_to_channels.values())) + 1

    def anndata_to_sparseimage(self, anndata: AnnData):
        """ Converts a anndata object to :class:`SparseImage`. """
        return SparseImage.from_anndata(
            anndata=anndata,
            x_key=self._x_key,
            y_key=self._y_key,
            category_key=self._category_key,
            pixel_size=self._pixel_size,
            categories_to_channels=self._categories_to_channels,
            padding=10)

    def prepare_data(self):
        """
        Read all the .h5ad files in the data folder and write two files in the same folder:
        'train_dataset.pt' (one sparse image per file) and
        'test_dataset.pt' (a fixed set of random crops of each image with their Moran's I score).
        """
        # create train_dataset_random and write to file
        all_metadatas = []
        all_sparse_images = []
        all_labels = []

        # sorted so that the order of the files (and therefore the tissue labels) is reproducible
        for filename in sorted(os.listdir(self._data_folder)):
            f = os.path.join(self._data_folder, filename)
            # checking if it is a file
            if os.path.isfile(f) and filename.endswith('.h5ad'):
                print("reading file {}".format(f))
                anndata = read_h5ad(filename=f)
                anndata.X = None  # set the count matrix to None
                sp_img = self.anndata_to_sparseimage(anndata=anndata).cpu()
                all_sparse_images.append(sp_img)

                # placeholder location and moran score: the train crops are generated on the fly
                metadata = MetadataCropperDataset(f_name=filename, loc_x=0.0, loc_y=0.0, moran=-99)
                all_metadatas.append(metadata)
                all_labels.append(filename)

        # NOTE: prepare_data may run on a single process only, therefore setup() restores
        # this list from the file written below.
        self._all_filenames: list = all_labels

        train_path = os.path.join(self._data_folder, "train_dataset.pt")
        torch.save((all_sparse_images, all_labels, all_metadatas), train_path)
        print("saved the file", train_path)

        # create test_dataset_random and write to file
        all_names = [metadata.f_name for metadata in all_metadatas]
        if torch.cuda.is_available():
            all_sparse_images = [sp_img.cuda() for sp_img in all_sparse_images]

        test_imgs, test_labels, test_metadatas = [], [], []
        for sp_img, label, fname in zip(all_sparse_images, all_labels, all_names):
            sps_tmp, loc_x_tmp, loc_y_tmp = self.cropper_test(sp_img, n_crops=self._n_crops_for_tissue_test)
            morans = [self.compute_moran(sparse_tensor).max().item() for sparse_tensor in sps_tmp]

            test_imgs += [crop.cpu() for crop in sps_tmp]
            test_labels += [label] * len(sps_tmp)
            test_metadatas += [
                MetadataCropperDataset(f_name=fname, loc_x=loc_x, loc_y=loc_y, moran=moran)
                for loc_x, loc_y, moran in zip(loc_x_tmp, loc_y_tmp, morans)
            ]

        test_path = os.path.join(self._data_folder, "test_dataset.pt")
        torch.save((test_imgs, test_labels, test_metadatas), test_path)
        print("saved the file", test_path)

    def get_metadata_to_classify(self, metadata) -> Dict[str, int]:
        """
        Extract one or more quantities to classify from the metadata.

        If no callable was given at construction, the label is the position of
        the file the crop comes from in the list of the files of the dataset.
        """
        if self._metadata_to_classify is None:
            return {"tissue_label": self._all_filenames.index(metadata.f_name)}
        else:
            return self._metadata_to_classify(metadata)

    def get_metadata_to_regress(self, metadata) -> Dict[str, float]:
        """
        Extract one or more quantities to regress from the metadata.

        If no callable was given at construction, the moran score and the x-location are returned.
        """
        if self._metadata_to_regress is None:
            return {
                "moran": float(metadata.moran),
                "loc_x": float(metadata.loc_x),
            }
        else:
            return self._metadata_to_regress(metadata)

    def setup(self, stage: Optional[str] = None) -> None:
        """ Load the files written by :meth:`prepare_data` and create the train and test datasets. """
        train_path = os.path.join(self._data_folder, "train_dataset.pt")
        list_imgs, list_labels, list_metadata = torch.load(train_path)
        print("read the file {}".format(train_path))

        # The labels of the train dataset are the filenames. Restore them here because
        # prepare_data may not have been executed by this process (e.g. multi-gpu training
        # or datasets prepared in a previous run).
        self._all_filenames = list(list_labels)

        list_imgs = [img.coalesce().cpu() for img in list_imgs]
        self._dataset_train = CropperDataset(
            imgs=list_imgs,
            labels=list_labels,
            metadatas=list_metadata,
            cropper=self.cropper_train,
        )
        print("created train_dataset device = {0}, length = {1}".format(self._dataset_train.imgs[0].device,
                                                                        self._dataset_train.__len__()))

        test_path = os.path.join(self._data_folder, "test_dataset.pt")
        list_imgs, list_labels, list_metadata = torch.load(test_path)
        print("read the file {}".format(test_path))

        list_imgs = [img.coalesce().cpu() for img in list_imgs]
        # the test crops were already extracted by prepare_data, therefore no cropper is needed
        self._dataset_test = CropperDataset(
            imgs=list_imgs,
            labels=list_labels,
            metadatas=list_metadata,
            cropper=None,
        )
        print("created test_dataset device = {0}, length = {1}".format(self._dataset_test.imgs[0].device,
                                                                       self._dataset_test.__len__()))