from argparse import ArgumentParser
from typing import Dict, Optional, Tuple

import torch
from torch.nn import functional as F
from pytorch_lightning.utilities.distributed import sync_ddp_if_available  # wrapper around torch.distributed.all_reduce
from neptune.new.types import File

from ._ssl_base_model import SslModelBase
from ._resnet_backbone import (
    make_vae_decoder_backbone_from_scratch,
    make_vae_decoder_backbone_from_resnet,
    make_vae_encoder_backbone_from_scratch,
    make_vae_encoder_backbone_from_resnet,
)
from tissue_purifier.models._optim_scheduler import LARS, linear_warmup_and_cosine_protocol
from tissue_purifier.plots.plot_images import show_batch
class ConvolutionalVae(torch.nn.Module):
    """
    Convolutional Variational Auto Encoder.

    The encoder backbone maps the image to a small spatial feature map which is flattened and projected by
    two linear layers onto the posterior mean (``mu``) and log-variance (``log_var``).
    The decoder linearly projects a latent code back onto that feature map, runs the decoder backbone and a
    final upsampling head (ConvTranspose2d with stride 2 -> BatchNorm -> LeakyReLU -> Conv2d -> activation)
    which restores the input number of channels.
    """

    def __init__(self,
                 backbone_type: str,
                 in_size: int,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: Tuple[int, ...] = (32, 64, 128, 256, 512),
                 decoder_activation: Optional[torch.nn.Module] = None,
                 ) -> None:
        """
        Args:
            backbone_type: Either 'vanilla', 'resnet18', 'resnet34' or 'resnet50'
            in_size: size in pixels of the (square) input images. Must be a multiple of 32.
            in_channels: number of channels of the input images
            latent_dim: number of latent dimensions
            hidden_dims: channels of the hidden layers of the encoder. Used only when backbone_type == 'vanilla',
                in which case the decoder uses the same values in reverse order.
            decoder_activation: non-linearity applied to the reconstructed image.
                If None (default) a fresh :class:`torch.nn.Identity` is used.

        Raises:
            AssertionError: if in_size is not a multiple of 32 or backbone_type is not recognized.
        """
        super(ConvolutionalVae, self).__init__()
        assert (in_size % 32) == 0, "The input size must be a multiple of 32. Received {0}".format(in_size)
        assert backbone_type in ('vanilla', 'resnet18', 'resnet34', 'resnet50'), \
            "Invalid backbone_type. Received {0}".format(backbone_type)

        if decoder_activation is None:
            # Create a new instance per model; a module used as default argument would be shared by all models.
            decoder_activation = torch.nn.Identity()

        # Dummy batch used only to infer the shapes of the intermediate feature maps.
        x_fake = torch.zeros((2, in_channels, in_size, in_size))

        # encoder
        self.latent_dim = latent_dim
        if backbone_type == 'vanilla':
            self.encoder_backbone = make_vae_encoder_backbone_from_scratch(
                in_channels=in_channels,
                hidden_dims=hidden_dims
            )
        elif backbone_type.startswith("resnet"):
            self.encoder_backbone = make_vae_encoder_backbone_from_resnet(
                in_channels=in_channels,
                resnet_type=backbone_type
            )
        else:
            # Reachable only when assertions are disabled (python -O).
            raise Exception("Invalid backbone_type. Received {0}".format(backbone_type))

        x_latent = self._probe(self.encoder_backbone, x_fake)
        small_ch = x_latent.shape[-3]
        self.small_size = x_latent.shape[-1]
        flat_dim = small_ch * self.small_size * self.small_size
        self.fc_mu = torch.nn.Linear(flat_dim, latent_dim)
        self.fc_var = torch.nn.Linear(flat_dim, latent_dim)

        # decoder
        self.decoder_input = torch.nn.Linear(latent_dim, flat_dim)
        z_to_decode = torch.zeros((2, small_ch, self.small_size, self.small_size))
        if backbone_type == 'vanilla':
            reverse_hidden_dims = tuple(hidden_dims)[::-1]
            self.decoder_backbone = make_vae_decoder_backbone_from_scratch(hidden_dims=reverse_hidden_dims)
        elif backbone_type.startswith("resnet"):
            self.decoder_backbone = make_vae_decoder_backbone_from_resnet(resnet_type=backbone_type)
        else:
            # Reachable only when assertions are disabled (python -O).
            raise Exception("Invalid backbone_type. Received {0}".format(backbone_type))

        x_tmp = self._probe(self.decoder_backbone, z_to_decode)
        ch_tmp = x_tmp.shape[-3]
        last_hidden_ch = min(ch_tmp, 64)
        self.final_layer = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(
                in_channels=ch_tmp,
                out_channels=last_hidden_ch,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=1,
                output_padding=1),
            torch.nn.BatchNorm2d(last_hidden_ch),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(
                in_channels=last_hidden_ch,
                out_channels=in_channels,
                kernel_size=(3, 3),
                padding=1),
            decoder_activation,
        )

        # make sure the VAE reproduces the correct shape
        dict_vae = self._probe(self, x_fake)
        assert dict_vae['x_rec'].shape == x_fake.shape

    @staticmethod
    def _probe(module: torch.nn.Module, x: torch.Tensor):
        """
        Run ``module(x)`` purely for shape inference.

        The call is done under ``torch.no_grad()`` with every submodule in eval mode, so that the dummy batch
        neither updates BatchNorm running statistics nor builds an autograd graph.
        The training flag of every submodule is restored afterwards, even if the call raises.
        """
        saved_modes = [(m, m.training) for m in module.modules()]
        module.eval()
        try:
            with torch.no_grad():
                return module(x)
        finally:
            for m, mode in saved_modes:
                m.training = mode

    def encode(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Encodes the input by passing it through the encoder network.

        Args:
            x: (Tensor) [B x C x H x W]

        Returns:
            dict with keys 'mu' and 'log_var', both Tensors of shape [B x latent_dim]
        """
        result = self.encoder_backbone(x)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return {'mu': mu, 'log_var': log_var}

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Maps the given latent codes onto the image space.

        Args:
            z: (Tensor) [B x latent_dim]

        Returns:
            (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # reshape the flat projection back into the small spatial feature map found at construction time
        result = result.view(z.shape[0], -1, self.small_size, self.small_size)
        result = self.decoder_backbone(result)
        x_rec = self.final_layer(result)
        return x_rec

    @staticmethod
    def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Encode, sample a latent code and decode it.

        Returns:
            dict with keys 'x_rec' (reconstruction), 'x_in' (the input), 'mu' and 'log_var'.
        """
        dict_encoder = self.encode(x)
        z = self.reparameterize(mu=dict_encoder['mu'], logvar=dict_encoder['log_var'])
        x_rec = self.decode(z)
        return {'x_rec': x_rec, 'x_in': x, 'mu': dict_encoder['mu'], 'log_var': dict_encoder['log_var']}
class VaeModel(SslModelBase):
    """
    Convolutional Variational Auto Encoders (VAE) with dynamically adjusted hyper-parameter :math:`\\beta`.
    The loss function is a weighted sum of the reconstruction (MSE) and regularization (KL):

    :math:`\\text{loss} = \\beta \\times \\text{KL} + (1-\\beta) \\times \\text{MSE}`

    We view this problem as a Multi-Objective Optimization (minimizing MSE *and* KL) and we dynamically
    adjust :math:`\\beta \\in (.0, 1.0)` taking inspiration from the ideas in
    `Multi-Task Learning as Multi-Objective Optimization
    <https://proceedings.neurips.cc/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf>`_

    Note:
        Depending on the data pre-processing step, the input images might be mostly zeros with few "spots" in them.
        In this case, the VAE might collapse to a local minimum in the loss function corresponding to identically
        zero reconstruction. You might try to solve this collapse by using a larger latent dimension or by implementing
        a KL-free variation of the VAE, such as `VQ-VAE <https://arxiv.org/pdf/1711.00937v2.pdf>`_.
        However, spotty images are inherently difficult for
        (unstructured) VAEs, which are required to retain (in their latent embedding) precise information about the
        location of each individual spot.
    """

    def __init__(
            self,
            # architecture
            backbone_type: str,
            global_size: int,
            image_in_ch: int,
            latent_dim: int,
            encoder_hidden_dims: Tuple[int],
            decoder_output_activation: str,
            # optimizer
            optimizer_type: str,
            beta_vae_init: float,
            momentum_beta_vae: float,
            # scheduler
            warm_up_epochs: int,
            warm_down_epochs: int,
            max_epochs: int,
            min_learning_rate: float,
            max_learning_rate: float,
            min_weight_decay: float,
            max_weight_decay: float,
            # gradient clipping (handled inside training_step because optimization is manual)
            gradient_clip_val: float = 0.0,
            gradient_clip_algorithm: str = 'value',
            # validation
            val_iomin_threshold: float = 0.0,
            **kwargs,
    ):
        """
        Args:
            backbone_type: Either 'vanilla', 'resnet18', 'resnet34' or 'resnet50'
            global_size: Size in pixel of the input image. Must be a multiple of 32.
            image_in_ch: number of channels in the input images, used to adjust the first
                convolution filter in the backbone
            latent_dim: number of latent dimensions of the embeddings
            encoder_hidden_dims: Dimension of the hidden layers. Used only in :attr:`backbone_type` == 'vanilla'.
            decoder_output_activation: The non-linearity used to produce the reconstructed image.
                In most cases "identity" (default) will work just fine. This is true even when the pixel values are
                strictly positive and a "softplus" or "sigmoid" activations could be used.
            optimizer_type: Either 'adamw', 'sgd', 'adam', 'rmsprop' or 'lars'.
            beta_vae_init: Initial value for :math:`\\beta` (the coefficient multiplying the KL divergence in the loss)
                It should be in (0.0, 1.0). The reconstruction error in the loss is multiplied by :math:`(1-\\beta)`.
            momentum_beta_vae: momentum for the Exponential Moving Average which updates the value of :math:`\\beta`.
                It should be in (0.0, 1.0).
            warm_up_epochs: epochs during which to linearly increase learning rate (at the beginning of training)
            warm_down_epochs: epochs during which to anneal learning rate with cosine protocol (at the end of training)
            max_epochs: total number of epochs
            min_learning_rate: minimum learning rate (at the very beginning and end of training)
            max_learning_rate: maximum learning rate (after linear ramp)
            min_weight_decay: minimum weight decay (during the entirety of the linear ramp)
            max_weight_decay: maximum weight decay (reached at the end of training)
            gradient_clip_algorithm: Either "norm" or "value". The algorithm to use for gradient clipping.
            gradient_clip_val: Clip the gradients to this value. If 0 no clipping
            val_iomin_threshold: during validation, only patches with Intersection Over MinArea < IoMin_threshold
                are used. Should be in [0.0, 1.0). If 0 only strictly non-overlapping patches are allowed.
        """
        super(VaeModel, self).__init__(val_iomin_threshold=val_iomin_threshold)

        # Important: This property activates manual optimization.
        self.automatic_optimization = False

        # Next two lines will make checkpointing much simpler. Always keep them as-is
        self.save_hyperparameters()  # all hyperparameters are saved to the checkpoint
        self.neptune_run_id = None  # if from scratch neptune_experiment_is is None

        # to make sure that you load the input images only once
        self.already_loaded_input_val_images = False

        # architecture
        if decoder_output_activation == 'identity':
            output_activation = torch.nn.Identity()
        elif decoder_output_activation == 'relu':
            output_activation = torch.nn.ReLU()
        elif decoder_output_activation == 'tanh':
            output_activation = torch.nn.Tanh()
        elif decoder_output_activation == 'softplus':
            output_activation = torch.nn.Softplus()
        elif decoder_output_activation == "sigmoid":
            output_activation = torch.nn.Sigmoid()
        else:
            raise Exception("invalid decoder_output_activation. Received {0}".format(decoder_output_activation))

        self.image_size = global_size
        self.vae = ConvolutionalVae(
            backbone_type=backbone_type,
            in_size=self.image_size,
            in_channels=image_in_ch,
            latent_dim=latent_dim,
            hidden_dims=tuple(encoder_hidden_dims),
            decoder_activation=output_activation
        )

        # stuff to do gradient clipping internally
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm

        # stuff to keep the gradients and adjust beta_vae
        self.loss_type = None
        self.grad_due_to_kl = None
        self.grad_due_to_mse = None
        self.grad_old = None
        assert 0.0 < beta_vae_init < 1.0, \
            "Error. beta_vae_init should be in (0,1). Received {0}".format(beta_vae_init)
        # beta is a buffer (not a parameter): it is saved in checkpoints but updated by hand, not by the optimizer
        self.register_buffer("beta_vae", float(beta_vae_init) * torch.ones(1, requires_grad=False).float())
        self.momentum_beta_vae = momentum_beta_vae

        # optimizer
        self.optimizer_type = optimizer_type

        # scheduler
        assert warm_up_epochs + warm_down_epochs <= max_epochs
        self.learning_rate_fn = linear_warmup_and_cosine_protocol(
            f_values=(min_learning_rate, max_learning_rate, min_learning_rate),
            x_milestones=(0, warm_up_epochs, max_epochs - warm_down_epochs, max_epochs))
        self.weight_decay_fn = linear_warmup_and_cosine_protocol(
            f_values=(min_weight_decay, min_weight_decay, max_weight_decay),
            x_milestones=(0, warm_up_epochs, max_epochs - warm_down_epochs, max_epochs))

    @classmethod
    def add_specific_args(cls, parent_parser):
        """
        Utility functions which add parameters to argparse to simplify setting up a CLI

        Example:
            >>> import sys
            >>> import argparse
            >>> parser = argparse.ArgumentParser(add_help=False, conflict_handler='resolve')
            >>> parser = VaeModel.add_specific_args(parser)
            >>> args = parser.parse_args(sys.argv[1:])
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler='resolve')

        # validation
        parser.add_argument("--val_iomin_threshold", type=float, default=0.0,
                            help="during validation, only patches with IoMinArea < IoMin_threshold are used "
                                 "in the kn-classifier and kn-regressor.")

        # this model has manual optimization therefore it has to handle clipping internally.
        parser.add_argument("--gradient_clip_val", type=float, default=0.5,
                            help="Clip the gradients to this value. If 0 no clipping")
        parser.add_argument("--gradient_clip_algorithm", type=str, default="value", choices=["norm", "value"],
                            help="Algorithm to use for gradient clipping.")

        # architecture
        parser.add_argument("--backbone_type", type=str, default="resnet18",
                            choices=["vanilla", "resnet18", "resnet34", "resnet50"],
                            help="The backbone architecture of the VAE")
        parser.add_argument("--global_size", type=int, default=64,
                            help="size in pixel of the input image. Must be a multiple of 32")
        parser.add_argument("--image_in_ch", type=int, default=3, help="number of channels of the input image")
        parser.add_argument("--latent_dim", type=int, default=128, help="number of latent dimensions")
        parser.add_argument("--encoder_hidden_dims", type=int, nargs='*', default=[32, 64, 128, 256, 512],
                            help="dimension of the hidden layers. Used only in backbone_type='vanilla'.")
        parser.add_argument("--decoder_output_activation", type=str, default="identity",
                            choices=["sigmoid", "identity", "tanh", "softplus", "relu"],
                            help="The non-linearity used to produce the reconstructed image.")

        # optimizer ('lars' is implemented in configure_optimizers, so it is selectable here as well)
        parser.add_argument("--optimizer_type", type=str, default='adam', help="optimizer type",
                            choices=['adamw', 'sgd', 'adam', 'rmsprop', 'lars'])

        # Parameters to update the beta (i.e. the balancing between MSE and KL)
        parser.add_argument('--beta_vae_init', type=float, default=0.1,
                            help="Initial value for beta (coefficient in front of KL). Should be in (0.0, 1.0)")
        parser.add_argument('--momentum_beta_vae', type=float, default=0.999,
                            help="momentum for the EMA which updates the value of beta")

        # scheduler
        parser.add_argument("--warm_up_epochs", default=10, type=int,
                            help="Number of epochs for the linear learning-rate warm up.")
        parser.add_argument("--warm_down_epochs", default=100, type=int,
                            help="Number of epochs for the cosine decay.")
        parser.add_argument("--max_epochs", type=int, default=300, help="maximum number of training epochs")
        parser.add_argument('--min_learning_rate', type=float, default=1e-5,
                            help="Target LR at the end of cosine protocol (smallest LR used during training).")
        parser.add_argument("--max_learning_rate", type=float, default=5e-4,
                            help="learning rate at the end of linear ramp (largest LR used during training).")
        parser.add_argument('--min_weight_decay', type=float, default=0.0,
                            help="Minimum value of the weight decay. It is used during the linear ramp.")
        parser.add_argument('--max_weight_decay', type=float, default=0.0,
                            help="Maximum Value of the weight decay. It is reached at the end of cosine protocol.")

        return parser

    @classmethod
    def get_default_params(cls) -> dict:
        """
        Get the default configuration parameters for this model

        Example:
            >>> config = VaeModel.get_default_params()
            >>> my_vae = VaeModel(**config)
        """
        parser = ArgumentParser()
        parser = cls.add_specific_args(parser)
        args = parser.parse_args(args=[])
        return args.__dict__

    @staticmethod
    def compute_losses(x_in, x_rec, mu, log_var):
        """
        Compute the two terms of the VAE loss.

        Returns:
            dict with
            'mse_loss': mean squared error averaged over all elements of x_in,
            'kl_loss': KL(N(mu, exp(log_var)) || N(0, 1)) summed over latent dimensions and averaged over the batch.
        """
        assert len(mu.shape) == 2
        batch_size = mu.shape[0]
        kl_loss = 0.5 * (mu ** 2 + log_var.exp() - log_var - 1.0).sum() / batch_size
        mse_loss = F.mse_loss(x_in, x_rec, reduction='mean')
        return {
            'mse_loss': mse_loss,
            'kl_loss': kl_loss,
        }

    def head_and_backbone_embeddings_step(self, x):
        """Return the posterior mean twice, so that it is used both as backbone and as head embedding."""
        mu = self(x)
        return mu, mu

    def forward(self, x) -> torch.Tensor:
        """Return the embeddings, i.e. the mean of the approximate posterior."""
        dict_encoder = self.vae.encode(x)
        return dict_encoder['mu']

    def training_step(self, batch, batch_idx):
        """
        One step of manual optimization.

        On the first batch of every epoch (batch_idx == 0) the KL and MSE terms are back-propagated separately,
        so that their gradients w.r.t. the last encoder layers (fc_mu, fc_var) can be compared and used to
        update :math:`\\beta` through a slow EMA. On every other batch a single backward pass is done.
        """
        with torch.no_grad():
            # Update the optimizer parameters
            opt: torch.optim.Optimizer = self.optimizers()
            lr = self.learning_rate_fn(self.current_epoch)
            wd = self.weight_decay_fn(self.current_epoch)
            for i, param_group in enumerate(opt.param_groups):
                param_group["lr"] = lr
                if i == 0:  # only the first group is regularized
                    param_group["weight_decay"] = wd
                else:
                    param_group["weight_decay"] = 0.0

        # this is data augmentation
        list_imgs = batch[0]
        batch_size = len(list_imgs)
        img_in = self.trsfm_train_global(list_imgs)
        assert img_in.shape[-1] == self.image_size, \
            "img.shape {0} vs image_size {1}".format(img_in.shape[-1], self.image_size)

        # does the encoding-decoding
        dict_vae = self.vae(img_in)

        loss_dict = self.compute_losses(
            x_in=dict_vae['x_in'],
            x_rec=dict_vae['x_rec'],
            mu=dict_vae['mu'],
            log_var=dict_vae['log_var']
        )

        assert torch.all(torch.isfinite(dict_vae['mu'])), "In training step. mu in NOT finite"
        assert torch.all(torch.isfinite(dict_vae['x_rec'])), "In training step. x_rec is NOT finite"

        # Manual optimization
        opt.zero_grad()
        loss_kl = self.beta_vae * loss_dict["kl_loss"]
        loss_mse = (1.0 - self.beta_vae) * loss_dict["mse_loss"]
        if batch_idx == 0:
            # two backward passes to collect the two gradients separately
            self.manual_backward(loss_kl, retain_graph=True)
            grad_due_to_kl_tmp = self.__get_grad_from_last_layer_of_encoder__()
            self.manual_backward(loss_mse, retain_graph=False)
            grad_tot_tmp = self.__get_grad_from_last_layer_of_encoder__()
            # gradients accumulate, therefore the MSE contribution is the difference
            grad_due_to_mse_tmp = grad_tot_tmp - grad_due_to_kl_tmp
        else:
            # a single backward pass
            self.manual_backward(loss_kl + loss_mse)
            grad_due_to_kl_tmp, grad_due_to_mse_tmp = None, None

        self.clip_gradients(
            opt,
            gradient_clip_val=self.gradient_clip_val,
            gradient_clip_algorithm=self.gradient_clip_algorithm
        )
        opt.step()
        # end manual optimization

        with torch.no_grad():
            self.log('weight_decay', wd,
                     on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)
            self.log('learning_rate', lr,
                     on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)

            # 75% quantile of the per-pixel squared error (summed over channels),
            # i.e. 75% of the pixels are reconstructed at least this well
            mse_for_constraint = torch.quantile(
                input=(dict_vae['x_in'] - dict_vae['x_rec']).pow(2).sum(dim=-3),
                q=0.75
            )

            self.log('train_loss', loss_kl + loss_mse,
                     on_step=False, on_epoch=True, rank_zero_only=True, batch_size=batch_size)
            self.log('train_mse_loss', loss_dict['mse_loss'],
                     on_step=False, on_epoch=True, rank_zero_only=True, batch_size=batch_size)
            self.log('train_kl_loss', loss_dict['kl_loss'],
                     on_step=False, on_epoch=True, rank_zero_only=True, batch_size=batch_size)
            self.log('mse_for_constraint', mse_for_constraint,
                     on_step=False, on_epoch=True, rank_zero_only=True, batch_size=batch_size)

            # batch_size
            self.log('batch_size_per_gpu_train', float(len(list_imgs)),
                     on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)

            # update the beta_vae if necessary
            if grad_due_to_mse_tmp is not None and grad_due_to_kl_tmp is not None:
                grad_due_to_mse = sync_ddp_if_available(grad_due_to_mse_tmp, group=None, reduce_op='mean')
                grad_due_to_kl = sync_ddp_if_available(grad_due_to_kl_tmp, group=None, reduce_op='mean')

                # divide out the current weights so that c11, c22, c12 refer to the raw KL and MSE gradients
                c11 = torch.dot(grad_due_to_kl, grad_due_to_kl) / self.beta_vae**2
                c22 = torch.dot(grad_due_to_mse, grad_due_to_mse) / (1.0 - self.beta_vae)**2
                c12 = torch.dot(grad_due_to_kl, grad_due_to_mse) / (self.beta_vae * (1.0 - self.beta_vae))

                method = 0
                if method == 0:
                    # find beta in (0,1) which minimizes: || beta * grad_kl + (1-beta) * grad_mse ||^2
                    # see paper: "Multi-Task Learning as Multi-Objective Optimization"
                    # This is the closed form solution
                    ideal_beta_vae = ((c22 - c12) / (c11 + c22 - 2 * c12)).clamp(min=0.0, max=1.0)
                elif method == 1:
                    # find beta in (0,1) which makes the two gradient equal size, i.e.:
                    # set: beta * sqrt(c11) = (1 - beta) * sqrt(c22)
                    # leads to: beta = sqrt(c22) / (sqrt(c11) + sqrt(c22))
                    ideal_beta_vae = (c22.sqrt() / (c11.sqrt() + c22.sqrt())).clamp(min=0.0, max=1.0)
                else:
                    raise Exception("Method can only be 0 or 1. Received {0}".format(method))

                # update beta using a slow Exponential Moving Average (EMA)
                self.__update_beta_vae__(ideal_beta=ideal_beta_vae, beta_momentum=self.momentum_beta_vae)

                self.log('beta/c11', c11, on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)
                self.log('beta/c12', c12, on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)
                self.log('beta/c22', c22, on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)
                self.log('beta/beta_vae', self.beta_vae,
                         on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)
                self.log('beta/ideal_beta_vae', ideal_beta_vae,
                         on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)

    def __get_grad_from_last_layer_of_encoder__(self) -> torch.Tensor:
        """Return a detached copy of the gradients of fc_mu and fc_var (biases and weights) as one flat vector."""
        grad = torch.cat((
            self.vae.fc_mu.bias.grad.detach().clone().flatten(),
            self.vae.fc_mu.weight.grad.detach().clone().flatten(),
            self.vae.fc_var.bias.grad.detach().clone().flatten(),
            self.vae.fc_var.weight.grad.detach().clone().flatten()
        ), dim=0)
        return grad

    def __update_beta_vae__(self, ideal_beta, beta_momentum):
        """EMA update of beta_vae towards ideal_beta, clamped to [1e-5, 1 - 1e-5]. Skipped if ideal_beta is not finite."""
        if ideal_beta.isfinite():
            tmp = beta_momentum * self.beta_vae + (1.0 - beta_momentum) * ideal_beta
            self.beta_vae = tmp.clamp(min=1.0E-5, max=1.0 - 1.0E-5)

    def validation_step(self, batch, batch_idx, dataloader_idx: int = -1):
        """
        On rank 0 and for the first batch only, log example reconstructions to neptune
        (and, once per run, the corresponding inputs). Then delegate to the parent validation_step.
        """
        # Log an example of the reconstructed images
        if self.global_rank == 0 and batch_idx == 0:
            list_imgs = batch[0]
            img_in = self.trsfm_test(list_imgs)
            dict_vae = self.vae(img_in)
            img_out = dict_vae['x_rec'].clone().detach().float()  # make sure this is in full precision for plotting

            one_ch_tmp_out_plot = show_batch(img_out[0].unsqueeze(dim=-3), n_col=5,
                                             title="output, epoch={0}".format(self.current_epoch))
            self.logger.run["rec/output_imgs/one_ch"].log(File.as_image(one_ch_tmp_out_plot))

            all_ch_tmp_out_plot = show_batch(img_out[:10], n_col=5,
                                             title="output, epoch={0}".format(self.current_epoch))
            self.logger.run["rec/output_imgs/all_ch"].log(File.as_image(all_ch_tmp_out_plot))

            if not self.already_loaded_input_val_images:
                img_in_tmp = img_in.clone().detach().float()  # make sure this is in full precision for plotting
                one_ch_tmp_in_plot = show_batch(img_in_tmp[0].unsqueeze(dim=-3), n_col=5,
                                                title="input, epoch={0}".format(self.current_epoch))
                self.logger.run["rec/input_imgs/one_ch"].log(File.as_image(one_ch_tmp_in_plot))

                all_ch_tmp_in_plot = show_batch(img_in_tmp[:10], n_col=5,
                                                title="input, epoch={0}".format(self.current_epoch))
                self.logger.run["rec/input_imgs/all_ch"].log(File.as_image(all_ch_tmp_in_plot))
                self.already_loaded_input_val_images = True

        # call the super.validation_step
        return super(VaeModel, self).validation_step(batch, batch_idx, dataloader_idx)

    def configure_optimizers(self):
        """
        Build the optimizer selected by ``optimizer_type``.

        Parameters are split in two groups: weights (regularized) and biases / 1D parameters such as
        Norm layers (never regularized). The learning rate is a 0.0 placeholder and the weight decay of
        the first group is a placeholder too: both are overwritten at every training_step by the scheduler.

        Raises:
            Exception: if optimizer_type is not one of 'adam', 'adamw', 'sgd', 'rmsprop', 'lars'.
        """
        regularized = []
        not_regularized = []
        for name, param in self.vae.named_parameters():
            if not param.requires_grad:
                continue
            # we do not regularize biases nor Norm parameters
            if name.endswith(".bias") or len(param.shape) == 1:
                not_regularized.append(param)
            else:
                regularized.append(param)
        arg_for_optimizer = [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.0}]

        if self.optimizer_type == 'adam':
            return torch.optim.Adam(arg_for_optimizer, betas=(0.9, 0.999), lr=0.0)
        elif self.optimizer_type == 'adamw':
            # decoupled weight decay; the real value for the first group is set in training_step
            return torch.optim.AdamW(arg_for_optimizer, betas=(0.9, 0.999), lr=0.0, weight_decay=0.0)
        elif self.optimizer_type == 'sgd':
            return torch.optim.SGD(arg_for_optimizer, momentum=0.9, lr=0.0)
        elif self.optimizer_type == 'rmsprop':
            return torch.optim.RMSprop(arg_for_optimizer, alpha=0.99, lr=0.0)
        elif self.optimizer_type == 'lars':
            # for convnet with large batch_size
            return LARS(arg_for_optimizer, momentum=0.9, lr=0.0)
        else:
            raise Exception("optimizer is misspecified. Received {0}".format(self.optimizer_type))