from typing import List, Any, Dict
import torch
from torch.nn import functional as F
from argparse import ArgumentParser
from ._resnet_backbone import make_resnet_backbone
from ._ssl_base_model import SslModelBase
from tissue_purifier.models._optim_scheduler import LARS, linear_warmup_and_cosine_protocol
class NTXentLoss(torch.nn.Module):
""" Very smart implementation of contrastive loss """
def __init__(self,
temperature: float = 0.5):
super().__init__()
self.temperature = temperature
self.cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
self.eps = 1e-8
def forward(self,
out0: torch.Tensor,
out1: torch.Tensor):
""" Forward pass through Contrastive Cross-Entropy Loss.
Args:
out0: representation for the first set of transformed images. Shape: (batch_size, embedding_size)
out1: representation for the second set of transformed images. Shape: (batch_size, embedding_size)
Returns:
Contrastive Cross Entropy Loss value.
Example:
>>> batch, latent_dim = 3, 1028
>>> out_1 = torch.randn((batch, latent_dim))
>>> out_2 = out1 + 0.1 # this mimic a very good encoding where pair images have close embeddings
>>> ntx_loss = NTXentLoss()
>>> my_loss = ntx_loss(out_1, out_2)
"""
# normalize the output to length 1
out0 = F.normalize(out0, p=2, dim=1) # shape: batch_size, latent_dim
out1 = F.normalize(out1, p=2, dim=1) # shape: batch_size, latent_dim
batch_size = out0.shape[0]
# the logits are the similarity matrix divided by the temperature
output = torch.cat((out0, out1), dim=0) # shape: 2*batch_size, latent_dim
logits = (output @ output.t()) / self.temperature # shape: 2*batch_size, 2*batch_size
# We need to remove the similarities of samples to themselves
mask_diag = torch.eye(2*batch_size, dtype=torch.bool, device=out0.device)
logits = logits[~mask_diag].view(2*batch_size, -1) # shape: 2*batch_size, 2*batch_size - 1
# The labels point from a sample in out_i to its equivalent in out_(1-i)
target = torch.arange(batch_size, device=out0.device, dtype=torch.long) # shape: batch_size
target = torch.cat([target + batch_size - 1, target]) # shape: 2*batch_size
# compute loss
loss = self.cross_entropy(logits, target) # shape: after reduction is a scalar
return loss
[docs]class SimclrModel(SslModelBase):
"""
Simclr self supervised learning model.
Inspired by the `Simclr official implementation <https://github.com/google-research/simclr>`_
and this `Simclr pytorch lightning reimplementation <https://github.com/PyTorchLightning/lightning-bolts/\
blob/0.1.0/pl_bolts/models/self_supervised/simclr/simclr_module.py#L47-L281>`_
"""
def __init__(
self,
# architecture
backbone_type: str,
image_in_ch: int,
head_hidden_chs: List[int],
head_out_ch: int,
# optimizer
optimizer_type: str,
# scheduler
warm_up_epochs: int,
warm_down_epochs: int,
max_epochs: int,
min_learning_rate: float,
max_learning_rate: float,
min_weight_decay: float,
max_weight_decay: float,
# validation
val_iomin_threshold: float = 0.0,
**kwargs,
):
"""
Args:
backbone_type: Either 'resnet18', 'resnet34' or 'resnet50'
image_in_ch: number of channels in the input images, used to adjust the first
convolution filter in the backbone
head_hidden_chs: List of integers with the size of the hidden layers of the projection head
head_out_ch: output dimension of the projection head
optimizer_type: Either 'adamw', 'lars', 'sgd', 'adam' or 'rmsprop'
warm_up_epochs: epochs during which to linearly increase learning rate (at the beginning of training)
warm_down_epochs: epochs during which to anneal learning rate with cosine protocoll (at the end of training)
max_epochs: total number of epochs
min_learning_rate: minimum learning rate (at the very beginning and end of training)
max_learning_rate: maximum learning rate (after linear ramp)
min_weight_decay: minimum weight decay (during the entirety of the linear ramp)
max_weight_decay: maximum weight decay (reached at the end of training)
val_iomin_threshold: during validation, only patches with Intersection Over MinArea < IoMin_threshold
are used. Should be in [0.0, 1.0). If 0 only strictly non-overlapping patches are allowed.
"""
super(SimclrModel, self).__init__(val_iomin_threshold=val_iomin_threshold)
# Next two lines will make checkpointing much simpler
self.save_hyperparameters() # all hyperparameters are saved to the checkpoint
self.neptune_run_id = None # if from scratch neptune_experiment_is is None
# architecture
self.backbone = make_resnet_backbone(
backbone_in_ch=image_in_ch,
backbone_type=backbone_type)
tmp_in = torch.zeros((1, image_in_ch, 64, 64))
tmp_out = self.backbone(tmp_in)
backbone_ch_out = tmp_out.shape[1]
self.projection = self.init_projection(
ch_in=backbone_ch_out,
ch_hidden=head_hidden_chs,
ch_out=head_out_ch)
# loss
self.nt_xent_loss = NTXentLoss()
# optimizer
self.optimizer_type = optimizer_type
# scheduler
assert warm_up_epochs + warm_down_epochs <= max_epochs
self.learning_rate_fn = linear_warmup_and_cosine_protocol(
f_values=(min_learning_rate, max_learning_rate, min_learning_rate),
x_milestones=(0, warm_up_epochs, max_epochs - warm_down_epochs, max_epochs))
self.weight_decay_fn = linear_warmup_and_cosine_protocol(
f_values=(min_weight_decay, min_weight_decay, max_weight_decay),
x_milestones=(0, warm_up_epochs, max_epochs - warm_down_epochs, max_epochs))
@staticmethod
def init_projection(
ch_in: int,
ch_out: int,
ch_hidden: List[int] = None):
sizes = [ch_in] + ch_hidden + [ch_out]
layers = []
for i in range(len(sizes) - 2):
layers.append(torch.nn.Linear(sizes[i], sizes[i + 1], bias=False))
layers.append(torch.nn.BatchNorm1d(sizes[i + 1]))
layers.append(torch.nn.ReLU(inplace=True))
layers.append(torch.nn.Linear(sizes[-2], sizes[-1], bias=False))
return torch.nn.Sequential(*layers)
[docs] @classmethod
def add_specific_args(cls, parent_parser):
"""
Utility functions which add parameters to argparse to simplify setting up a CLI
Example:
>>> import sys
>>> import argparse
>>> parser = argparse.ArgumentParser(add_help=False, conflict_handler='resolve')
>>> parser = SimclrModel.add_specific_args(parser)
>>> args = parser.parse_args(sys.argv[1:])
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler='resolve')
# validation
parser.add_argument("--val_iomin_threshold", type=float, default=0.0,
help="during validation, only patches with IoMinArea < IoMin_threshold are used "
"in the kn-classifier and kn-regressor.")
# architecture
parser.add_argument("--image_in_ch", type=int, default=3, help="number of channels in the input images")
parser.add_argument("--backbone_type", type=str, default="resnet34", help="backbone type",
choices=['resnet18', 'resnet34', 'resnet50'])
parser.add_argument("--head_hidden_chs", type=int, nargs='+', default=[128, 256],
help="List of integers. Hidden channels in projection head.")
parser.add_argument("--head_out_ch", type=int, default=128, help="head output channels")
# optimizer
parser.add_argument("--optimizer_type", type=str, default='adam', help="optimizer type",
choices=['adamw', 'lars', 'sgd', 'adam', 'rmsprop'])
# scheduler
parser.add_argument("--max_epochs", default=1000, type=int,
help="Total number of epochs in training.")
parser.add_argument("--warm_up_epochs", default=100, type=int,
help="Number of epochs for the linear learning-rate warm up.")
parser.add_argument("--warm_down_epochs", default=500, type=int,
help="Number of epochs for the cosine decay.")
parser.add_argument('--min_learning_rate', type=float, default=1e-5,
help="Target LR at the end of cosine protocol (smallest LR used during training).")
parser.add_argument("--max_learning_rate", type=float, default=5e-4,
help="learning rate at the end of linear ramp (largest LR used during training).")
parser.add_argument('--min_weight_decay', type=float, default=0.04,
help="Minimum value of the weight decay. It is used during the linear ramp.")
parser.add_argument('--max_weight_decay', type=float, default=0.4,
help="Maximum Value of the weight decay. It is reached at the end of cosine protocol.")
return parser
[docs] @classmethod
def get_default_params(cls) -> dict:
"""
Get the default configuration parameters for this model
Example:
>>> config = SimclrModel.get_default_params()
>>> my_barlow = SimclrModel(**config)
"""
parser = ArgumentParser()
parser = SimclrModel.add_specific_args(parser)
args = parser.parse_args(args=[])
return args.__dict__
def forward(self, x):
# this is the stuff that will generate the backbone embeddings
y = self.backbone(x) # shape (batch, ch)
return y
def head_and_backbone_embeddings_step(self, x):
# this generates both head and backbone embeddings
y = self(x) # shape: (batch, ch)
z = self.projection(y) # shape: (batch, latent)
return z, y
def training_step(self, batch, batch_idx) -> dict:
# this is data augmentation
with torch.no_grad():
list_imgs, list_labels, list_metadata = batch
img1 = self.trsfm_train_global(list_imgs)
img2 = self.trsfm_train_global(list_imgs)
z1, y1 = self.head_and_backbone_embeddings_step(img1)
z2, y2 = self.head_and_backbone_embeddings_step(img2)
world_z1 = self.all_gather(z1, sync_grads=True).flatten(end_dim=-2) # shape: world*batch_size, latent_dim
world_z2 = self.all_gather(z2, sync_grads=True).flatten(end_dim=-2) # shape: world*batch_size, latent_dim
loss = self.nt_xent_loss(world_z1, world_z2)
# Update the optimizer parameters and log stuff
with torch.no_grad():
lr = self.learning_rate_fn(self.current_epoch)
wd = self.weight_decay_fn(self.current_epoch)
for i, pg in enumerate(self.optimizers().param_groups):
pg["lr"] = lr
if i == 0: # only the first group is regularized
pg["weight_decay"] = wd
else:
pg["weight_decay"] = 0.0
# Finally I log interesting stuff
batch_size_total = float(world_z1.shape[0])
batch_size_per_gpu = float(z1.shape[0])
self.log('train_loss', loss, on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)
self.log('weight_decay', wd, on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)
self.log('learning_rate', lr, on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)
self.log('batch_size_per_gpu_train', batch_size_per_gpu, on_step=False, on_epoch=True, rank_zero_only=True)
self.log('batch_size_total_train', batch_size_total, on_step=False, on_epoch=True, rank_zero_only=True)
return loss
def configure_optimizers(self) -> torch.optim.Optimizer:
regularized = []
not_regularized = []
for name, param in list(self.backbone.named_parameters()) + list(self.projection.named_parameters()):
if not param.requires_grad:
continue
# we do not regularize biases nor Norm parameters
if name.endswith(".bias") or len(param.shape) == 1:
not_regularized.append(param)
else:
regularized.append(param)
arg_for_optimizer = [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.0}]
# The real lr will be set in the training step
# The weight_decay for the regularized group will be set in the training step
if self.optimizer_type == 'adam':
return torch.optim.Adam(arg_for_optimizer, betas=(0.9, 0.999), lr=0.0)
elif self.optimizer_type == 'sgd':
return torch.optim.SGD(arg_for_optimizer, momentum=0.9, lr=0.0)
elif self.optimizer_type == 'rmsprop':
return torch.optim.RMSprop(arg_for_optimizer, alpha=0.99, lr=0.0)
elif self.optimizer_type == 'lars':
# for convnet with large batch_size
return LARS(arg_for_optimizer, momentum=0.9, lr=0.0)
else:
# do adamw
raise Exception("optimizer is misspecified")