from typing import List, Any, Dict
import torch
from torch.nn import functional as F
import math
from argparse import ArgumentTypeError as ArgparseArgumentTypeError
from argparse import ArgumentParser
from ._resnet_backbone import make_resnet_backbone
from ._ssl_base_model import SslModelBase
from tissue_purifier.models._optim_scheduler import LARS, linear_warmup_and_cosine_protocol

def smart_bool(v):
    """
    Convert a command-line style value into a bool (meant to be used as ``type=`` in argparse).

    Args:
        v: either a bool (returned unchanged) or a value whose string form is one of
            'yes', 'true', 't', 'y', '1' (-> True) or 'no', 'false', 'f', 'n', '0' (-> False).
            The comparison is case-insensitive and ignores surrounding whitespace.

    Returns:
        The corresponding boolean value.

    Raises:
        argparse.ArgumentTypeError: if the value cannot be interpreted as a boolean.
    """
    if isinstance(v, bool):
        return v

    # str() so that non-string inputs (e.g. None, ints) produce the argparse error
    # below instead of an AttributeError on `.lower()`.
    token = str(v).strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    # Previously this raise was placed after `return False` and therefore unreachable,
    # so unrecognized values silently returned None.
    raise ArgparseArgumentTypeError('Boolean value expected.')

def dino_loss(output_t: torch.Tensor,
              output_s: torch.Tensor,
              ncrops_t: int,
              ncrops_s: int,
              temp_t: torch.Tensor,
              temp_s: torch.Tensor,
              center_t: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor):
    """
    Cross-entropy between the (centered, sharpened) teacher distribution and the student
    distribution, averaged over all pairs of views in which teacher and student look at
    different crops.

    Args:
        output_t: raw teacher output. It is split into :attr:`ncrops_t` chunks along dim 0
            (one chunk per crop) and the last dimension is the one the softmax acts on.
        output_s: raw student output, split into :attr:`ncrops_s` chunks along dim 0.
        ncrops_t: number of crops seen by the teacher
        ncrops_s: number of crops seen by the student
        temp_t: temperature of the teacher softmax
        temp_s: temperature of the student softmax
        center_t: center subtracted from the teacher output before the softmax

    Returns:
        total_loss, prob_t, prob_s: the scalar loss, and the teacher and student probabilities
        for all crops (i.e. NOT chunked, same leading dimension as the inputs).

    Raises:
        ValueError: if there is no (teacher crop, student crop) pair with different indices,
            so that the loss would be an average over zero terms.
    """
    all_logit_s = output_s / temp_s
    all_prob_s = F.softmax(all_logit_s, dim=-1)
    all_log_prob_s = F.log_softmax(all_logit_s, dim=-1)
    log_prob_s = all_log_prob_s.chunk(ncrops_s)

    # The teacher is a fixed target: no gradient flows through it.
    all_logit_t = (output_t - center_t).detach() / temp_t
    all_prob_t = F.softmax(all_logit_t, dim=-1)
    prob_t = all_prob_t.chunk(ncrops_t)

    loss_terms = []
    for iq, q in enumerate(prob_t):
        for ip, log_p in enumerate(log_prob_s):
            if ip == iq:
                # we skip cases where student and teacher operate on the same view
                continue
            # TODO: There is inconsistency in paper vs code. Symmetric or non-symmetric version of KL
            # loss = -0.5 * (p * log_q + q * log_p).sum(dim=-1)  # shape: BATCH_SIZE
            loss = -(q * log_p).sum(dim=-1)  # shape: BATCH_SIZE
            loss_terms.append(loss.mean())

    if not loss_terms:
        raise ValueError("dino_loss needs at least one teacher/student pair acting on different views. "
                         "Received ncrops_t={0}, ncrops_s={1}".format(ncrops_t, ncrops_s))

    total_loss = torch.stack(loss_terms).mean()
    return total_loss, all_prob_t, all_prob_s

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """
    Fill :attr:`tensor` in place with samples from a normal distribution N(mean, std^2)
    truncated to the interval [a, b]. The sampling is done outside of autograd.

    Args:
        tensor: the torch.Tensor to fill (modified in place)
        mean: mean of the (untruncated) normal distribution
        std: standard deviation of the (untruncated) normal distribution
        a: lower bound of the truncation interval
        b: upper bound of the truncation interval

    Returns:
        The same tensor which was passed in.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        print("WARNINGS: The mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \
              The distribution of values may be incorrect.")

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        cdf_lower = norm_cdf((a - mean) / std)
        cdf_upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * cdf_lower - 1, 2 * cdf_upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal. (This step was missing: without it the samples are
        # uniformly, not normally, distributed.)
        tensor.erfinv_()

        # Transform to proper mean, std. (The shift by `mean` was missing too.)
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """
    Public entry point for the truncated-normal initializer.

    Forwards all the arguments unchanged to :func:`_no_grad_trunc_normal_` and
    returns whatever that helper returns.

    Args:
        tensor: the tensor handed to the helper
        mean: mean parameter handed to the helper (default 0.0)
        std: standard-deviation parameter handed to the helper (default 1.0)
        a: lower bound handed to the helper (default -2.0)
        b: upper bound handed to the helper (default 2.0)
    """
    filled = _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
    return filled

class DINOHead(torch.nn.Module):
    """
    Projection head: a multi-layer perceptron, followed by an L2 normalization of the
    features and by a weight-normalized linear layer without bias.

    NOTE(review): the copy of this class under review had lost several lines (the call to
    ``super().__init__()``, the ``else`` of the ``use_bn`` branch, any non-linearity and the
    application of :meth:`_init_weights`). They have been reconstructed following the head of
    the reference DINO implementation (GELU activation, ``weight_g`` filled with 1).
    Please confirm the activation against existing checkpoints/configs.
    """

    def __init__(self, in_dim: int, hidden_dim: List[int], out_dim: int, use_bn=False, norm_last_layer=True):
        """
        Args:
            in_dim: size of the input features
            hidden_dim: sizes of the hidden layers of the MLP (at least one)
            out_dim: size of the output of the last (weight-normalized) layer
            use_bn: if True each hidden Linear layer has no bias and is followed by BatchNorm1d
            norm_last_layer: if True the magnitude (``weight_g``) of the last layer is frozen
        """
        # Must run before any sub-module is assigned to self.
        super().__init__()
        assert len(hidden_dim) >= 1
        sizes = [in_dim] + list(hidden_dim) + [out_dim]

        layers = []
        for ch_in, ch_out in zip(sizes[:-2], sizes[1:-1]):
            if use_bn:
                # The bias is redundant when followed by BatchNorm
                layers.append(torch.nn.Linear(ch_in, ch_out, bias=False))
                layers.append(torch.nn.BatchNorm1d(ch_out))
            else:
                layers.append(torch.nn.Linear(ch_in, ch_out, bias=True))
            # NOTE(review): activation reconstructed from the reference DINO head, see class docstring.
            layers.append(torch.nn.GELU())
        self.mlp = torch.nn.Sequential(*layers)

        # Initialize the MLP before creating the last layer, so that the latter keeps
        # the initialization chosen by weight_norm.
        self.apply(self._init_weights)

        self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """ Truncated-normal initialization of the weights (std=0.02) and zero bias for Linear layers. """
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """ Run the MLP, L2-normalize the features along the last dimension and apply the last layer. """
        x = self.mlp(x)
        x = F.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x

class MultiResolutionNet(torch.nn.Module):
    """
    Net that can act on both a single torch.Tensor or a list of torch.Tensors
    with possibly different spatial resolutions.

    It is made of a backbone (built by :func:`make_resnet_backbone`) followed by a :class:`DINOHead`.
    """

    def __init__(self,
                 backbone_type: str,
                 backbone_in_ch: int,
                 head_hidden_chs: List[int],
                 head_out_ch: int,
                 head_use_bn: bool):
        """
        Args:
            backbone_type: type of backbone, passed to :func:`make_resnet_backbone`
            backbone_in_ch: number of channels of the input images
            head_hidden_chs: sizes of the hidden layers of the head
            head_out_ch: size of the output of the head
            head_use_bn: whether the head uses BatchNorm
        """
        # Must run before any sub-module is assigned to self.
        super().__init__()

        # NOTE(review): the arguments of this call were lost in the copy under review.
        # The keyword names are assumed to mirror the ones of this constructor;
        # confirm against the signature of make_resnet_backbone.
        self.backbone = make_resnet_backbone(
            backbone_type=backbone_type,
            backbone_in_ch=backbone_in_ch)

        # Dry run on a dummy image to discover the number of features produced by the backbone
        in_tmp = torch.zeros((1, backbone_in_ch, 32, 32))
        out_tmp = self.backbone(in_tmp)
        head_ch_in = out_tmp.shape[1]

        self.head = DINOHead(
            in_dim=head_ch_in,
            hidden_dim=head_hidden_chs,
            out_dim=head_out_ch,
            use_bn=head_use_bn)

    @staticmethod
    def init_projection(
            ch_in: int,
            ch_out: int,
            ch_hidden: List[int] = None):
        """
        Build a stand-alone projection MLP: (Linear without bias, BatchNorm1d, ReLU) for each
        hidden layer, followed by a final Linear layer without bias.

        NOTE(review): the method takes no ``self`` so it is declared static. The ReLU is not
        present in the copy under review (where activations appear to have been lost) and it is
        assumed here; confirm before relying on this helper, which is not used by this class.

        Args:
            ch_in: size of the input features
            ch_out: size of the output features
            ch_hidden: sizes of the hidden layers. None (default) means no hidden layers.
        """
        hidden = [] if ch_hidden is None else list(ch_hidden)
        sizes = [ch_in] + hidden + [ch_out]
        layers = []
        for size_in, size_out in zip(sizes[:-2], sizes[1:-1]):
            layers.append(torch.nn.Linear(size_in, size_out, bias=False))
            layers.append(torch.nn.BatchNorm1d(size_out))
            layers.append(torch.nn.ReLU(inplace=True))
        layers.append(torch.nn.Linear(sizes[-2], sizes[-1], bias=False))
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """
        x is either a torch.Tensor or a list of torch.Tensor of possibly different resolutions.
            1. concatenate tensor with same resolution and run the backbone
            2. concatenate the features
            3. run the head just once on the concatenated features

        Returns:
            A tuple (head output, backbone output)

        Raises:
            ValueError: if x is an empty list
        """
        if not isinstance(x, list):
            output = self.backbone(x)  # output is of size (b, c)
        else:
            if len(x) == 0:
                raise ValueError("Expected a torch.Tensor or a non-empty list of torch.Tensor.")

            # Consecutive crops with the same resolution (last dimension) are grouped together.
            # idx_crops contains the (exclusive) index at which each group ends.
            idx_crops = torch.cumsum(torch.unique_consecutive(
                torch.tensor([inp.shape[-1] for inp in x]),
                return_counts=True)[1], dim=0).tolist()

            start_idx, features = 0, []
            for end_idx in idx_crops:
                # one backbone pass for each group of same-resolution crops
                features.append(self.backbone(torch.cat(x[start_idx: end_idx], dim=0)))
                start_idx = end_idx

            # accumulate outputs
            output = torch.cat(features, dim=0)

        # Run the head forward on the concatenated features.
        return self.head(output), output

class DinoModel(SslModelBase):
    """
    DINO self supervised learning model.
    Inspired by the `DINO official implementation <https://github.com/facebookresearch/dino>`_
    and by a DINO pytorch-lightning reimplementation
    (blog post "dino-self-supervised-vision-transformers", 2021/08/01).
    """

    def __init__(
            self,
            # architecture
            image_in_ch: int,
            backbone_type: str,
            head_hidden_chs: List[int],
            head_use_bn: bool,
            head_out_ch: int,
            # teacher centering
            center_momentum: float,
            # optimizer
            optimizer_type: str,
            # scheduler
            warm_up_epochs: int,
            warm_down_epochs: int,
            max_epochs: int,
            min_learning_rate: float,
            max_learning_rate: float,
            min_weight_decay: float,
            max_weight_decay: float,
            # validation
            val_iomin_threshold: float = 0.0,
            # temperatures
            set_temperature_using_ipr_init: bool = False,
            ipr_teacher_init: float = 40.0,
            ipr_student_init: float = 80.0,
            temperature_teacher_init: float = 0.04,
            temperature_student_init: float = 0.1,
            # the teacher's parameters update slowly and then stop
            param_momentum_init: float = 0.996,
            param_momentum_final: float = 0.996,
            param_momentum_epochs_end: int = 1000,
            **kwargs,
    ):
        """
        Args:
            backbone_type: Either 'resnet18', 'resnet34' or 'resnet50'
            image_in_ch: number of channels in the input images, used to adjust the first
                convolution filter in the backbone
            head_hidden_chs: List of integers with the size of the hidden layers of the projection head
            head_use_bn: if True uses BatchNormalization in the projection head
            head_out_ch: output dimension of the projection head
            center_momentum: the teacher network does centering and sharpening.
                The center_momentum which is a float in (0.0, 1.0) described the momentum of
                the centering, i.e.:

                .. math::

                    \\text{center} = c \\times \\text{center} + (1-c) \\times \\text{center}_\\text{empirical}

            optimizer_type: Either 'adamw', 'lars', 'sgd', 'adam' or 'rmsprop'
            warm_up_epochs: epochs during which to linearly increase learning rate (at the beginning of training)
            warm_down_epochs: epochs during which to anneal learning rate with cosine protocol
                (at the end of training)
            max_epochs: total number of epochs
            min_learning_rate: minimum learning rate (at the very beginning and end of training)
            max_learning_rate: maximum learning rate (after linear ramp)
            min_weight_decay: minimum weight decay (during the entirety of the linear ramp)
            max_weight_decay: maximum weight decay (reached at the end of training)
            val_iomin_threshold: during validation, only patches with Intersection Over MinArea < IoMin_threshold
                are used. Should be in [0.0, 1.0). If 0 only strictly non-overlapping patches are allowed.
            set_temperature_using_ipr_init: if True the user specifies the initial InverseParticipationRatio (IPR)
                for the teacher and student network and the temperature of the softmax is computed
                to match the desired initial IPR.
            ipr_teacher_init: Used only if :attr:`set_temperature_using_ipr_init` == True.
                The desired initial IPR for the teacher.
            ipr_student_init: Used only if :attr:`set_temperature_using_ipr_init` == True.
                The desired initial IPR for the student.
            temperature_teacher_init: Used only if :attr:`set_temperature_using_ipr_init` == False.
                The temperature for the softmax of the teacher.
            temperature_student_init: Used only if :attr:`set_temperature_using_ipr_init` == False.
                The temperature for the softmax of the student.
            param_momentum_init: the teacher params are updated using an Exponential Moving Average (EMA).
                This parameter controls the initial momentum of the EMA
            param_momentum_final: This parameter controls the final momentum of the EMA
            param_momentum_epochs_end: The teacher parameters are not updated after this many epochs

        Note:
            NOTE(review): some defaults here differ from the ones in :meth:`add_specific_args`
            (:attr:`param_momentum_epochs_end` is 1000 here and 10000 there,
            :attr:`set_temperature_using_ipr_init` is False here and True there).
            They are left untouched; confirm which ones are intended.
        """
        super(DinoModel, self).__init__(val_iomin_threshold=val_iomin_threshold)

        # Next two lines will make checkpointing much simpler. Always keep them as-is
        self.save_hyperparameters()  # all hyperparameters are saved to the checkpoint
        self.neptune_run_id = None  # if from scratch neptune_experiment_is is None

        # architecture
        self.student = MultiResolutionNet(
            backbone_type=backbone_type,
            backbone_in_ch=image_in_ch,
            head_hidden_chs=head_hidden_chs,
            head_out_ch=head_out_ch,
            head_use_bn=head_use_bn,
        )

        # this is creating a separate teacher object with the same weights as the student object
        self.teacher = MultiResolutionNet(
            backbone_type=backbone_type,
            backbone_in_ch=image_in_ch,
            head_hidden_chs=head_hidden_chs,
            head_out_ch=head_out_ch,
            head_use_bn=head_use_bn,
        )
        self.teacher.load_state_dict(self.student.state_dict())
        # The teacher is never trained by gradient descent (it is updated by EMA of the student)
        for p in self.teacher.parameters():
            p.requires_grad = False

        self.dim_out = head_out_ch
        self.register_buffer("center_teacher", torch.zeros(self.dim_out, requires_grad=False).float())
        self.register_buffer("population_t", torch.zeros(self.dim_out, requires_grad=False).float())
        self.register_buffer("population_s", torch.zeros(self.dim_out, requires_grad=False).float())

        # teacher parameters
        self.set_temperature_using_ipr_init = set_temperature_using_ipr_init
        if self.set_temperature_using_ipr_init:
            # temperatures are placeholders. They are computed at the first training step.
            self.ipr_teacher_init = float(ipr_teacher_init)
            self.ipr_student_init = float(ipr_student_init)
            self.register_buffer("student_temperature",
                                 float(1.0) * torch.ones(1, requires_grad=False).float())
            self.register_buffer("teacher_temperature",
                                 float(1.0) * torch.ones(1, requires_grad=False).float())
        else:
            self.ipr_teacher_init = -1.0
            self.ipr_student_init = -1.0
            self.register_buffer("student_temperature",
                                 float(temperature_student_init) * torch.ones(1, requires_grad=False).float())
            self.register_buffer("teacher_temperature",
                                 float(temperature_teacher_init) * torch.ones(1, requires_grad=False).float())

        self.teacher_center_momentum = float(center_momentum)
        self.teacher_parameter_momentum_fn = linear_warmup_and_cosine_protocol(
            f_values=(param_momentum_init, param_momentum_init, param_momentum_final),
            x_milestones=(0, 0, 0, param_momentum_epochs_end))

        # optimizer
        self.optimizer_type = optimizer_type

        # scheduler
        assert warm_up_epochs + warm_down_epochs <= max_epochs
        self.learning_rate_fn = linear_warmup_and_cosine_protocol(
            f_values=(min_learning_rate, max_learning_rate, min_learning_rate),
            x_milestones=(0, warm_up_epochs, max_epochs - warm_down_epochs, max_epochs))
        self.weight_decay_fn = linear_warmup_and_cosine_protocol(
            f_values=(min_weight_decay, min_weight_decay, max_weight_decay),
            x_milestones=(0, warm_up_epochs, max_epochs - warm_down_epochs, max_epochs))

    @classmethod
    def add_specific_args(cls, parent_parser):
        """
        Utility functions which add parameters to argparse to simplify setting up a CLI

        Example:
            >>> import sys
            >>> import argparse
            >>> parser = argparse.ArgumentParser(add_help=False, conflict_handler='resolve')
            >>> parser = DinoModel.add_specific_args(parser)
            >>> args = parser.parse_args(sys.argv[1:])
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler='resolve')

        # validation
        parser.add_argument("--val_iomin_threshold", type=float, default=0.0,
                            help="during validation, only patches with IoMinArea < IoMin_threshold are used "
                                 "in the kn-classifier and kn-regressor.")

        # architecture
        parser.add_argument("--image_in_ch", type=int, default=3, help="number of channels in the input images")
        parser.add_argument("--backbone_type", type=str, default="resnet34", help="backbone type",
                            choices=['resnet18', 'resnet34', 'resnet50'])
        parser.add_argument("--head_hidden_chs", type=int, nargs='+', default=[256, 512],
                            help="head hidden channels")
        parser.add_argument("--head_out_ch", type=int, default=512, help="head output channels")
        parser.add_argument("--head_use_bn", type=smart_bool, default=True,
                            help="use batch normalization layers in the DINOHead")

        # optimizer
        parser.add_argument("--optimizer_type", type=str, default='adam', help="optimizer type",
                            choices=['adamw', 'lars', 'sgd', 'adam', 'rmsprop'])

        # Updating of the teacher network
        parser.add_argument("--param_momentum_init", type=float, default=0.996,
                            help="Teacher parameters are updated with EMA starting from this value")
        parser.add_argument("--param_momentum_final", type=float, default=0.996,
                            help="Teacher parameters are updated with EMA ending at this value")
        parser.add_argument("--param_momentum_epochs_end", type=int, default=10000,
                            help="The teacher parameters momentum reach its final value after this many epochs")
        parser.add_argument("--center_momentum", type=float, default=0.9,
                            help="momentum for updating the teacher softmax")

        # temperatures
        parser.add_argument("--set_temperature_using_ipr_init", type=smart_bool, default=True,
                            help="If true the student and teacher's temperatures are fixed so that IPR "
                                 "is equal to the ipr_init")
        parser.add_argument("--temperature_teacher_init", type=float, default=0.04,
                            help="Initial value of the temperature in the softmax of the teacher. "
                                 "Ignored if set_temperature_using_ipr_init = True")
        parser.add_argument("--temperature_student_init", type=float, default=0.1,
                            help="Initial value of the temperature in the softmax of the student. "
                                 "Ignored if set_temperature_using_ipr_init = True")
        # The IPR targets are used only when set_temperature_using_ipr_init is True
        # (the help messages used to state the opposite).
        parser.add_argument("--ipr_teacher_init", type=float, default=40.0,
                            help="The desired value of the initial IPR for the teacher. "
                                 "Ignored if set_temperature_using_ipr_init = False")
        parser.add_argument("--ipr_student_init", type=float, default=80.0,
                            help="The desired value of the initial IPR for the student. "
                                 "Ignored if set_temperature_using_ipr_init = False")

        # scheduler
        parser.add_argument("--max_epochs", type=int, default=1000, help="maximum number of training epochs")
        parser.add_argument("--warm_up_epochs", default=100, type=int,
                            help="Number of epochs for the linear learning-rate warm up.")
        parser.add_argument("--warm_down_epochs", default=500, type=int,
                            help="Number of epochs for the cosine decay.")
        parser.add_argument('--min_learning_rate', type=float, default=1e-5,
                            help="Target LR at the end of cosine protocol (smallest LR used during training).")
        parser.add_argument("--max_learning_rate", type=float, default=5e-4,
                            help="learning rate at the end of linear ramp (largest LR used during training).")
        parser.add_argument('--min_weight_decay', type=float, default=0.04,
                            help="Minimum value of the weight decay. It is used during the linear ramp.")
        parser.add_argument('--max_weight_decay', type=float, default=0.4,
                            help="Maximum Value of the weight decay. It is reached at the end of cosine protocol.")
        return parser

    @classmethod
    def get_default_params(cls) -> dict:
        """
        Get the default configuration parameters for this model

        Example:
            >>> config = DinoModel.get_default_params()
            >>> my_dino = DinoModel(**config)
        """
        parser = ArgumentParser()
        parser = DinoModel.add_specific_args(parser)
        args = parser.parse_args(args=[])
        return args.__dict__

    def forward(self, x) -> torch.Tensor:
        """ Returns the second output of the teacher network, i.e. the backbone embeddings. """
        # this is the stuff that will generate the backbone embeddings
        z, y = self.teacher(x)
        return y

    def head_and_backbone_embeddings_step(self, x):
        """ Returns both outputs of the teacher network: (head embeddings, backbone embeddings). """
        # this generates both head and backbone embeddings
        z, y = self.teacher(x)
        return z, y

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        """
        One training step: set learning rate and weight decay, create the crops, run teacher and
        student, compute the DINO loss, update the teacher (parameters and center) and log.

        NOTE(review): :attr:`n_global_crops`, :attr:`n_local_crops`, :attr:`trsfm_train_global` and
        :attr:`trsfm_train_local` are not defined in this class; presumably they are provided by
        the base class or attached at run-time.

        Args:
            batch: a tuple (list_imgs, list_labels, list_metadata)
            batch_idx: index of the batch (unused)

        Returns:
            The loss to be minimized w.r.t. the student parameters.
        """
        with torch.no_grad():
            # Update the optimizer parameters
            lr = self.learning_rate_fn(self.current_epoch)
            wd = self.weight_decay_fn(self.current_epoch)
            for i, param_group in enumerate(self.optimizers().param_groups):
                param_group["lr"] = lr
                if i == 0:  # only the first group is regularized
                    param_group["weight_decay"] = wd
                else:
                    param_group["weight_decay"] = 0.0

            # this is data augmentation
            list_imgs, list_labels, list_metadata = batch

            # create the global and local crops
            list_of_minibatches = [self.trsfm_train_global(list_imgs) for _ in range(self.n_global_crops)]
            list_of_minibatches += [self.trsfm_train_local(list_imgs) for _ in range(self.n_local_crops)]

            # forward for teacher is inside the no-grad context
            z_t, y_t = self.head_and_backbone_embeddings_step(list_of_minibatches[:self.n_global_crops])

        # forward for student is outside the no-grad context
        z_s, y_s = self.student(list_of_minibatches)

        with torch.no_grad():
            if self.global_step == 0 and self.set_temperature_using_ipr_init:
                self.__update_temperatures__(
                    output_t=z_t,
                    output_s=z_s,
                    ipr_t=self.ipr_teacher_init,
                    ipr_s=self.ipr_student_init)

        # loss
        loss: torch.Tensor
        prob_t: torch.Tensor
        prob_s: torch.Tensor
        loss, prob_t, prob_s = dino_loss(
            output_t=z_t,
            output_s=z_s,
            ncrops_t=self.n_global_crops,
            ncrops_s=self.n_global_crops + self.n_local_crops,
            temp_s=self.student_temperature,
            temp_t=self.teacher_temperature,
            center_t=self.center_teacher,
        )

        # gather stuff from all GPUs and update the teacher on all GPUs identically
        with torch.no_grad():
            # compute sample entropy
            tmp_entropy_sample_t = self.__normalized_stable_entropy__(prob_t).mean()  # shape: []
            tmp_entropy_sample_s = self.__normalized_stable_entropy__(prob_s).mean()  # shape: []

            # compute sample ipr
            tmp_ipr_sample_t = self.__inverse_participation_ratio__(prob_t).mean()  # shape: []
            tmp_ipr_sample_s = self.__inverse_participation_ratio__(prob_s).mean()  # shape: []

            # compute the pdf of the population
            tmp_population_t = prob_t.sum(dim=0)  # sum over samples --> shape: latent
            tmp_population_s = prob_s.sum(dim=0)  # sum over samples --> shape: latent

            # compute the teacher center
            tmp_empirical_center_teacher = z_t.mean(dim=0)  # mean over samples --> shape: latent

            # compute the local batch_size
            tmp_local_batch_size = torch.tensor(len(list_imgs), device=self.device, dtype=torch.float)

            # do the all_gather operations together
            tmp = [tmp_entropy_sample_t, tmp_entropy_sample_s,
                   tmp_ipr_sample_t, tmp_ipr_sample_s,
                   tmp_population_t, tmp_population_s,
                   tmp_empirical_center_teacher, tmp_local_batch_size]
            world_tmp = self.all_gather(tmp, sync_grads=False)
            w_ent_t, w_ent_s, w_ipr_t, w_ipr_s, w_pop_t, w_pop_s, w_center_t, w_batch_size = world_tmp

            # update the population over a mini-batch
            if len(w_pop_s.shape) == 1 + len(tmp_population_s.shape):
                # all_gather added a leading dimension (one entry per process)
                self.population_s.add_(w_pop_s.sum(dim=0))  # shape: latent
                self.population_t.add_(w_pop_t.sum(dim=0))  # shape: latent
                empirical_center_teacher = w_center_t.mean(dim=0)  # shape: latent
            else:
                self.population_s.add_(w_pop_s)  # shape: latent
                self.population_t.add_(w_pop_t)  # shape: latent
                empirical_center_teacher = w_center_t  # shape: latent

            # update teacher parameter and center
            center_momentum = self.teacher_center_momentum
            param_momentum = self.teacher_parameter_momentum_fn(self.current_epoch)
            self.__update_teacher_param__(p_momentum=param_momentum)
            self.__update_teacher_center__(empirical_center_teacher=empirical_center_teacher,
                                           c_momentum=center_momentum)

            # Finally I log interesting stuff
            scalars_to_log = {
                'train_loss': loss,
                'weight_decay': wd,
                'learning_rate': lr,
                'teacher_param_momentum': param_momentum,
                'teacher_center_momentum': center_momentum,
                'teacher_temperature': self.teacher_temperature,
                'student_temperature': self.student_temperature,
                # NOTE(review): the key is misspelled ('studnet'). It is kept as-is so that
                # the name of the logged metric does not change for existing dashboards.
                'ipr_studnet_init': self.ipr_student_init,
                'ipr_teacher_init': self.ipr_teacher_init,
            }
            for key, value in scalars_to_log.items():
                self.log(key, value, on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)

            # These are correct both if all_gather add an extra dimension or not
            # (depends if trainer(strategy='ddp')).
            gathered_to_log = {
                'batch_size_per_gpu_train': w_batch_size.mean(),
                'batch_size_total_train': w_batch_size.sum(),
                'entropy_sample_t': w_ent_t.mean(),
                'entropy_sample_s': w_ent_s.mean(),
                'ipr_sample_t': w_ipr_t.mean(),
                'ipr_sample_s': w_ipr_s.mean(),
            }
            for key, value in gathered_to_log.items():
                self.log(key, value, on_step=False, on_epoch=True, rank_zero_only=True)

        return loss

    def on_train_epoch_start(self) -> None:
        """ Reset the population statistics which are accumulated during the epoch. """
        # fill_ must be called. Assigning to it (`fill_ = 0.0`) leaves the buffers untouched.
        self.population_t.fill_(0.0)
        self.population_s.fill_(0.0)

    def on_train_epoch_end(self, unused=None) -> None:
        """ Log entropy and IPR of the population distributions accumulated during the epoch. """
        assert self.population_t.shape == self.population_s.shape == torch.Size([self.dim_out]), \
            "This should be 1D vector of size: {0}. Received {1}".format(self.dim_out, self.population_t.shape)

        tmp_t = self.population_t / torch.sum(self.population_t)  # shape: latent
        tmp_s = self.population_s / torch.sum(self.population_s)  # shape: latent

        # compute population entropy and population_ipr
        population_logs = {
            'entropy_population_t': self.__normalized_stable_entropy__(tmp_t),
            'entropy_population_s': self.__normalized_stable_entropy__(tmp_s),
            'ipr_population_t': self.__inverse_participation_ratio__(tmp_t),
            'ipr_population_s': self.__inverse_participation_ratio__(tmp_s),
        }
        for key, value in population_logs.items():
            self.log(key, value, on_step=False, on_epoch=True, rank_zero_only=True, batch_size=1)

    def __update_temperatures__(self, output_t, output_s, ipr_t, ipr_s):
        """
        Set teacher and student temperatures so that the IPR of their softmax outputs matches
        the desired values. The temperatures found by each process are averaged.

        Args:
            output_t: raw output of the teacher
            output_s: raw output of the student
            ipr_t: desired IPR for the teacher
            ipr_s: desired IPR for the student
        """
        ideal_temp_teacher_tmp = self.__find_ideal_temperature__(
            output=output_t,
            tau_init=torch.ones_like(self.teacher_temperature),
            desired_ipr=ipr_t)
        ideal_temp_student_tmp = self.__find_ideal_temperature__(
            output=output_s,
            tau_init=torch.ones_like(self.student_temperature),
            desired_ipr=ipr_s)

        w_ideal_temp_teacher, w_ideal_temp_student = self.all_gather(
            [ideal_temp_teacher_tmp, ideal_temp_student_tmp], sync_grads=False)

        if len(w_ideal_temp_student.shape) == 1 + len(ideal_temp_student_tmp.shape):
            # all_gather added a leading dimension (one entry per process)
            ideal_temp_teacher = w_ideal_temp_teacher.mean(dim=0).clamp(min=1.0E-5, max=1.0E5)
            ideal_temp_student = w_ideal_temp_student.mean(dim=0).clamp(min=1.0E-5, max=1.0E5)
        else:
            ideal_temp_teacher = w_ideal_temp_teacher.clamp(min=1.0E-5, max=1.0E5)
            ideal_temp_student = w_ideal_temp_student.clamp(min=1.0E-5, max=1.0E5)

        assert ideal_temp_teacher.shape == ideal_temp_student.shape == torch.Size([1]), \
            "Expected torch.Size([1]). Received {0}, {1}".format(ideal_temp_teacher.shape,
                                                                 ideal_temp_student.shape)
        assert ideal_temp_teacher.isfinite() and ideal_temp_student.isfinite(), \
            "Ideal temperature is not finite. Received {0}, {1}".format(ideal_temp_teacher, ideal_temp_student)
        assert ideal_temp_teacher < ideal_temp_student, \
            "Error. teacher temperature {0} must be SMALLER than student temperature {1}.".format(
                ideal_temp_teacher, ideal_temp_student)

        self.teacher_temperature = ideal_temp_teacher
        self.student_temperature = ideal_temp_student

    @staticmethod
    def __find_ideal_temperature__(
            output: torch.Tensor,
            tau_init: torch.Tensor,
            desired_ipr: float,
            f_eps: float = 1E-3,
            eps: float = 1E-4) -> torch.Tensor:
        """
        Find the temperature that makes the IPR = desired_ipr.
        Solve this problem using bisection search.

        Args:
            output: the raw output of the network of shape (batch, latent_dim)
            tau_init: initial guess for the temperature to use. A good guess make search conclude earlier.
            desired_ipr: float larger than 1.01. The desired IPR for a sample.
                IPR=1.0 means one-hot probabilities. Temperature tends to 0.0 if IPR tends to 1.0.
            f_eps: algorithm terminates if (IPR - desired_ipr).abs() < f_eps
            eps: algorithm terminates if the (tau_max - tau_min) < eps

        Return:
            tau, a torch.Tensor with the temperature value which makes the IPR close to the desired value.
            If no solution can be bracketed a tensor containing +inf is returned.
        """

        def f(_temperature: torch.Tensor) -> torch.Tensor:
            """ This is the function I want to find the zero of. """
            _logit = output / _temperature
            _p = F.softmax(_logit, dim=-1)
            _ipr_sample = 1.0 / (_p * _p).sum(dim=-1)
            tmp = _ipr_sample.mean(dim=0) - desired_ipr
            assert torch.isfinite(tmp), \
                "Error. The function I want to find the zero of is not finite {0}".format(tmp)
            return tmp

        assert isinstance(desired_ipr, float) and desired_ipr > 1.01, \
            "Desired IPR must be a float > 1.01. Received {0}".format(desired_ipr)

        with torch.no_grad():
            # Step 1: starting from tau_init, halve or double the temperature till
            # the function changes sign, i.e. till [a, b] brackets a zero.
            n_max = 20
            n = 0
            a = b = tau_init
            fa = fb = f(a)

            if fb > 0:
                # ipr too high -> decrease temperature
                while (fb > 0) and (n < n_max):
                    n = n + 1
                    fa = fb
                    a = b
                    b = b * 0.5
                    fb = f(b)
            else:
                # ipr too low -> increase temperature
                while (fb < 0) and (n < n_max):
                    n = n + 1
                    fa = fb
                    a = b
                    b = b * 2.0
                    fb = f(b)

            if fa * fb < 0.0 and torch.isfinite(a) and torch.isfinite(b):
                # Step 2: I have bracketed a zero. Now I can do bisection algorithm
                n_max = 20
                n = 0
                fc = fb
                while abs(a - b) > eps and fc.abs() > f_eps and n < n_max:
                    c = 0.5 * (a + b)
                    fc = f(c)
                    n = n + 1

                    if fa * fc < 0:
                        b = c
                        fb = fc
                    elif fb * fc < 0:
                        a = c
                        fa = fc
                    else:
                        # fa and fb have opposite sign, therefore this happens only if fc is
                        # (numerically) zero, i.e. c is the solution. This used to raise.
                        return c
                return 0.5 * (a + b)
            else:
                print("WARNING. dynamic temperature routine did not find a finite solution.")
                return torch.tensor([float('inf')], device=output.device, dtype=torch.float)

    @staticmethod
    def __inverse_participation_ratio__(x):
        """ Inverse of the sum (over the last dimension) of the squares of x. """
        ipr = 1.0 / (x * x).sum(dim=-1)
        return ipr

    @staticmethod
    def __normalized_stable_entropy__(x):
        """
        Entropy of x along the last dimension, divided by the log of the size of that dimension.
        Non-finite terms (for example 0 * log(0)) are treated as zero.
        """
        x_logx = x * x.log()
        tmp = torch.where(torch.isfinite(x_logx), x_logx, torch.zeros_like(x_logx))
        return -torch.sum(tmp, dim=-1) / math.log(float(tmp.shape[-1]))

    def __update_teacher_param__(self, p_momentum: float):
        """
        Exponential Moving Average update of the teacher parameters:
        param_t = p_momentum * param_t + (1 - p_momentum) * param_s.
        Nothing is done if p_momentum >= 1.0.
        """
        if p_momentum < 1.0:
            for param_s, param_t in zip(self.student.parameters(), self.teacher.parameters()):
                param_t.data.mul_(p_momentum).add_((1.0 - p_momentum) * param_s.detach().data)

    def __update_teacher_center__(self, empirical_center_teacher: torch.Tensor, c_momentum: float):
        """ Exponential Moving Average update of the center subtracted from the teacher output. """
        assert empirical_center_teacher.shape == torch.Size([self.dim_out]), \
            "Received {0}. Expected torch.Size([{1}])".format(empirical_center_teacher.shape, self.dim_out)
        self.center_teacher = c_momentum * self.center_teacher + (1.0 - c_momentum) * empirical_center_teacher

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """
        Create the optimizer for the student parameters. Two parameter groups are created:
        the first is regularized (weight decay), the second (biases and 1D parameters) is not.

        Raises:
            Exception: if :attr:`optimizer_type` is not one of the supported optimizers
        """
        regularized = []
        not_regularized = []
        for name, param in self.student.named_parameters():
            if not param.requires_grad:
                continue
            # we do not regularize biases nor Norm parameters
            if name.endswith(".bias") or len(param.shape) == 1:
                not_regularized.append(param)
            else:
                regularized.append(param)

        arg_for_optimizer = [{'params': regularized},
                             {'params': not_regularized, 'weight_decay': 0.0}]

        # The real lr will be set in the training step
        # The weight_decay for the regularized group will be set in the training step
        if self.optimizer_type == 'adam':
            return torch.optim.Adam(arg_for_optimizer, betas=(0.9, 0.999), lr=0.0)
        elif self.optimizer_type == 'adamw':
            # 'adamw' is an advertised choice but it used to fall through to the exception below
            return torch.optim.AdamW(arg_for_optimizer, betas=(0.9, 0.999), lr=0.0)
        elif self.optimizer_type == 'sgd':
            return torch.optim.SGD(arg_for_optimizer, momentum=0.9, lr=0.0)
        elif self.optimizer_type == 'rmsprop':
            return torch.optim.RMSprop(arg_for_optimizer, alpha=0.99, lr=0.0)
        elif self.optimizer_type == 'lars':
            # for convnet with large batch_size
            return LARS(arg_for_optimizer, momentum=0.9, lr=0.0)
        else:
            raise Exception("optimizer is misspecified")