Weight initialization — mmengine 0.10.7 documentation

Weight initialization

Usually, we customize our modules by subclassing nn.Module from native PyTorch, and torch.nn.init helps us initialize model parameters. To simplify model construction and initialization, MMEngine provides BaseModule, which lets us define a model and initialize it from a config.

Initialize the model from config

The core feature of BaseModule is initializing a model from a config. A subclass of BaseModule accepts init_cfg in its __init__ function and passes it to the parent class; the initialization method is then chosen by configuring init_cfg.

Currently, the following initializers are supported:

| Initializer | Registered name | Function |
| --- | --- | --- |
| ConstantInit | Constant | Initialize the weight and bias with a constant; commonly used for convolution |
| XavierInit | Xavier | Initialize the weight with Xavier initialization and the bias with a constant |
| NormalInit | Normal | Initialize the weight from a normal distribution and the bias with a constant |
| TruncNormalInit | TruncNormal | Initialize the weight from a truncated normal distribution and the bias with a constant; commonly used for Transformers |
| UniformInit | Uniform | Initialize the weight from a uniform distribution and the bias with a constant; commonly used for convolution |
| KaimingInit | Kaiming | Initialize the weight with Kaiming initialization and the bias with a constant; commonly used for convolution |
| Caffe2XavierInit | Caffe2Xavier | Xavier initialization as implemented in Caffe2, which corresponds to Kaiming initialization in PyTorch with "fan_in" mode and a uniform distribution; commonly used for convolution |
| PretrainedInit | Pretrained | Initialize the model from a pretrained checkpoint |

Take the following ToyNet as an example:

import torch
import torch.nn as nn

from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Linear(1, 1)


# Save the checkpoint.
toy_net = ToyNet()
torch.save(toy_net.state_dict(), './pretrained.pth')
pretrained = './pretrained.pth'

toy_net = ToyNet(init_cfg=dict(type='Pretrained', checkpoint=pretrained))

Here init_cfg is configured to load the pretrained checkpoint. The weights are actually loaded when init_weights() is called after the model has been constructed:

# Initialize the model with the saved checkpoint.
toy_net.init_weights()
08/19 16:50:24 - mmengine - INFO - load model from: ./pretrained.pth
08/19 16:50:24 - mmengine - INFO - local loads checkpoint from path: ./pretrained.pth
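
Conceptually, the Pretrained initializer copies the values stored in a checkpoint into the model's parameters. The minimal sketch below reproduces that effect with plain PyTorch (torch.save, torch.load and load_state_dict). It is an approximation for illustration only, not how PretrainedInit is implemented; it ignores URL checkpoints, and the temporary checkpoint path is an assumption of the sketch.

```python
import os
import tempfile

import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Plain nn.Module stand-in for the ToyNet above (no init_cfg support)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Linear(1, 1)


# Save a "pretrained" checkpoint from one randomly initialized instance.
source = ToyNet()
ckpt_path = os.path.join(tempfile.mkdtemp(), 'pretrained.pth')
torch.save(source.state_dict(), ckpt_path)

# A fresh instance starts with its own random weights; loading the state
# dict overwrites them with the checkpoint values.
target = ToyNet()
state_dict = torch.load(ckpt_path, map_location='cpu')
missing, unexpected = target.load_state_dict(state_dict, strict=False)
print(missing, unexpected)  # both lists are empty: every key matched
```

With strict=False, keys that do not match are reported in the returned lists instead of raising an error.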

When init_cfg is a dict, type names an initializer registered in WEIGHT_INITIALIZERS. Pretrained refers to PretrainedInit, which loads the target checkpoint. Every initializer follows the same naming convention: the registered name is the class name with the Init suffix removed, as in Pretrained -> PretrainedInit. The checkpoint argument of PretrainedInit is the path of the checkpoint, which can be a local path or a URL.
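
The name lookup can be pictured with a toy dict-based registry. This is a simplified, hypothetical sketch: INITIALIZERS, register_initializer and build_initializer are not mmengine APIs, and deriving the short name by stripping the Init suffix automatically is a shortcut chosen here to mirror the naming convention.

```python
# Hypothetical, simplified registry; NOT mmengine's Registry implementation.
INITIALIZERS = {}


def register_initializer(cls):
    """Register `cls` under its class name with a trailing 'Init' removed."""
    name = cls.__name__
    if name.endswith('Init'):
        name = name[:-len('Init')]
    INITIALIZERS[name] = cls
    return cls


@register_initializer
class PretrainedInit:
    def __init__(self, checkpoint):
        self.checkpoint = checkpoint


@register_initializer
class KaimingInit:
    def __init__(self, layer=None):
        self.layer = layer


def build_initializer(cfg):
    """Look up cfg['type'] and pass the remaining keys as keyword arguments."""
    cfg = dict(cfg)  # copy so the caller's dict is left untouched
    cls = INITIALIZERS[cfg.pop('type')]
    return cls(**cfg)


init = build_initializer(dict(type='Pretrained', checkpoint='./pretrained.pth'))
print(type(init).__name__, init.checkpoint)  # PretrainedInit ./pretrained.pth
```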

Note

PretrainedInit has a higher priority than any other initializer: the loaded pretrained weights overwrite any weights previously set by other initializers.
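
One way to picture this priority rule is to split the configs so that every non-Pretrained initializer runs first and Pretrained runs last, letting the checkpoint values overwrite everything else. split_init_cfgs is a hypothetical helper written for illustration; it only computes an execution order and does not initialize anything.

```python
def split_init_cfgs(init_cfg):
    """Return (other_cfgs, pretrained_cfgs) from a dict or a list of dicts.

    Hypothetical sketch of the ordering rule only: non-Pretrained configs
    come first, Pretrained configs last so that their weights win.
    """
    cfgs = [init_cfg] if isinstance(init_cfg, dict) else list(init_cfg)
    others = [c for c in cfgs if c['type'] != 'Pretrained']
    pretrained = [c for c in cfgs if c['type'] == 'Pretrained']
    return others, pretrained


init_cfg = [
    dict(type='Pretrained', checkpoint='./pretrained.pth'),
    dict(type='Kaiming', layer='Conv2d'),
]
others, pretrained = split_init_cfgs(init_cfg)
execution_order = [c['type'] for c in others + pretrained]
print(execution_order)  # ['Kaiming', 'Pretrained']
```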

Commonly used initialization methods

Kaiming initialization is configured in the same way as the Pretrained initializer. For example, init_cfg=dict(type='Kaiming', layer='Conv2d') initializes all Conv2d modules with Kaiming initialization.
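
Assuming the Kaiming defaults are the ones printed in the KaimingInit log lines of this section (a=0, mode=fan_out, nonlinearity=relu, normal distribution, bias=0), dict(type='Kaiming', layer='Conv2d') behaves roughly like the plain-PyTorch loop below. kaiming_init_conv2d is a hypothetical name, and selecting modules with isinstance is an approximation of how layer names are matched.

```python
import torch
import torch.nn as nn


def kaiming_init_conv2d(model):
    """Kaiming-normal init for every nn.Conv2d in `model`; biases set to 0.

    a=0, mode='fan_out' and nonlinearity='relu' mirror the KaimingInit log
    values; modules of other types are left untouched.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(
                module.weight, a=0, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)


class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
        self.conv = nn.Conv2d(1, 1, 1)


torch.manual_seed(0)
net = ToyNet()
linear_before = net.linear.weight.detach().clone()
kaiming_init_conv2d(net)  # only net.conv is re-initialized
```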

Sometimes different modules need different initialization methods. For example, we may want Kaiming initialization for Conv2d modules and Xavier initialization for Linear modules. In that case, init_cfg can be a list of dicts, one per initializer:

import torch.nn as nn

from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(1, 1)
        self.conv = nn.Conv2d(1, 1, 1)


# Apply `Kaiming` initialization to `Conv2d` module and `Xavier` initialization to `Linear` module.
toy_net = ToyNet(
    init_cfg=[
        dict(type='Kaiming', layer='Conv2d'),
        dict(type='Xavier', layer='Linear')
    ], )
toy_net.init_weights()
08/19 16:50:24 - mmengine - INFO -
linear.weight - torch.Size([1, 1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
linear.bias - torch.Size([1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

layer can also be a list, where each element names a module type that the initializer is applied to:

# Apply Kaiming initialization to `Conv2d` and `Linear` module.
toy_net = ToyNet(init_cfg=[dict(type='Kaiming', layer=['Conv2d', 'Linear'])], )
toy_net.init_weights()
08/19 16:50:24 - mmengine - INFO -
linear.weight - torch.Size([1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
linear.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
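
How a list-valued layer selects modules can be sketched by comparing each submodule's class name with the listed names. modules_matching is a hypothetical helper, and exact class-name matching is a simplification made for this sketch.

```python
import torch.nn as nn


def modules_matching(model, layer):
    """Yield (name, module) pairs whose class name appears in `layer`.

    `layer` may be one class name or a list of names, mirroring how the
    `layer` key is written in init_cfg.
    """
    layers = [layer] if isinstance(layer, str) else list(layer)
    for name, module in model.named_modules():
        if type(module).__name__ in layers:
            yield name, module


class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
        self.conv = nn.Conv2d(1, 1, 1)


net = ToyNet()
matched = [name for name, _ in modules_matching(net, ['Conv2d', 'Linear'])]
print(matched)  # ['linear', 'conv']
```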

More fine-grained initialization

Sometimes modules of the same type need different initialization methods. For example, with two Conv2d submodules conv1 and conv2, we may want to initialize conv1 with Kaiming initialization and conv2 with Xavier initialization. This can be configured with override in init_cfg:

import torch.nn as nn

from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)


# Apply `Kaiming` initialization to `conv1` and `Xavier` initialization to `conv2`.
toy_net = ToyNet(
    init_cfg=[
        dict(
            type='Kaiming',
            layer=['Conv2d'],
            override=dict(name='conv2', type='Xavier')),
    ], )
toy_net.init_weights()
08/19 16:50:24 - mmengine - INFO -
conv1.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv1.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.weight - torch.Size([1, 1, 1, 1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

override can be understood as a nested init_cfg: it can also be a dict or a list of dicts, and each entry must also set type. The difference is that override must set name to specify the submodule it applies to. In the example above, name='conv2' specifies that Xavier initialization is applied to toy_net.conv2 and all of its submodules.
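
To make the override semantics concrete, here is a plain-PyTorch approximation. In this sketch the parent initializer runs over all matching layers first, then the override looks up the named submodule as an attribute and re-initializes it and everything inside it. init_with_override and its bookkeeping dict are hypothetical; they are not mmengine APIs.

```python
import torch.nn as nn


class ToyNet(nn.Module):
    """Plain-PyTorch stand-in for the ToyNet with conv1 and conv2 above."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)


def init_with_override(net, override_name):
    """Kaiming for every Conv2d, then Xavier for one named submodule.

    Returns a dict mapping module names to the last initializer applied.
    """
    applied = {}
    # Parent initializer: Kaiming (fan_out, relu) for every Conv2d, bias = 0.
    for name, m in net.named_modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            applied[name] = 'Kaiming'
    # Override: re-initialize the named submodule and all of its children
    # with Xavier, regardless of their type, overwriting the values above.
    target = getattr(net, override_name)
    for name, m in target.named_modules(prefix=override_name):
        if getattr(m, 'weight', None) is not None:
            nn.init.xavier_normal_(m.weight, gain=1)
            applied[name] = 'Xavier'
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias, 0)
    return applied


applied = init_with_override(ToyNet(), 'conv2')
print(applied)  # {'conv1': 'Kaiming', 'conv2': 'Xavier'}
```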

Customize the initialization method

Although init_cfg controls the initialization method of different modules, customizing the initialization process through init_cfg requires registering a new initializer in WEIGHT_INITIALIZERS, which is not always convenient. Alternatively, we can override the init_weights method to customize initialization.

Assume we have defined the following modules, where ToyConv implements its own init_weights. ToyNet.init_weights calls the init_weights method of each of its submodules that defines one, in turn.

import torch
import torch.nn as nn

from mmengine.model import BaseModule


class ToyConv(nn.Module):

    def __init__(self):
        super().__init__()
        self.custom_weight = nn.Parameter(torch.empty(1, 1, 1, 1))
        self.custom_bias = nn.Parameter(torch.empty(1))

    def init_weights(self):
        # Fill the parameters in place; no_grad keeps autograd from recording it.
        with torch.no_grad():
            self.custom_weight.fill_(1)
            self.custom_bias.fill_(0)


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.custom_conv = ToyConv()


toy_net = ToyNet(
    init_cfg=[
        dict(
            type='Kaiming',
            layer=['Conv2d'],
            override=dict(name='conv2', type='Xavier'))
    ])

toy_net.init_weights()
08/19 16:50:24 - mmengine - INFO -
conv1.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv1.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.weight - torch.Size([1, 1, 1, 1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
custom_conv.custom_weight - torch.Size([1, 1, 1, 1]):
Initialized by user-defined `init_weights` in ToyConv

08/19 16:50:24 - mmengine - INFO -
custom_conv.custom_bias - torch.Size([1]):
Initialized by user-defined `init_weights` in ToyConv
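
The way ToyNet.init_weights delegates to its submodules can be approximated in plain PyTorch. Parent is a hypothetical class and the loop is a simplification: besides visiting submodules, BaseModule.init_weights also applies init_cfg and logs how each parameter was initialized, which this sketch omits.

```python
import torch
import torch.nn as nn


class ToyConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.custom_weight = nn.Parameter(torch.empty(1, 1, 1, 1))
        self.custom_bias = nn.Parameter(torch.empty(1))

    def init_weights(self):
        with torch.no_grad():
            self.custom_weight.fill_(1)
            self.custom_bias.fill_(0)


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)  # has no init_weights method
        self.custom_conv = ToyConv()    # defines init_weights

    def init_weights(self):
        # Call init_weights on each direct child that defines one; children
        # without it keep their default PyTorch initialization.
        for child in self.children():
            if hasattr(child, 'init_weights'):
                child.init_weights()


net = Parent()
net.init_weights()
print(net.custom_conv.custom_weight.item())  # 1.0
```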

Conclusion

1. Configure init_cfg to initialize the model.

2. Override init_weights to customize the initialization process.

Note

Initialize modules with functions

As mentioned in the previous section, we can customize initialization in init_weights. To make initializing modules more convenient, MMEngine provides a series of module initialization functions, built on torch.nn.init, that initialize a whole module at once. For example, suppose we want to initialize the weight of a convolutional layer from a normal distribution and its bias with a constant. With torch.nn.init, this is written as follows:

from torch.nn.init import normal_, constant_
import torch.nn as nn

model = nn.Conv2d(1, 1, 1)
normal_(model.weight, mean=0, std=0.01)
constant_(model.bias, val=0)
Parameter containing:
tensor([0.], requires_grad=True)

This is a standard routine for initializing a convolutional module from a normal distribution, so MMEngine wraps such common routines in module initialization functions. Unlike the torch.nn.init functions, which take a tensor, these functions accept the module directly:

from mmengine.model import normal_init

normal_init(model, mean=0, std=0.01, bias=0)
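
A module-level initialization function mostly wraps the two torch.nn.init calls shown earlier and skips parameters the module does not have. The sketch below assumes that is roughly what normal_init does; normal_init_sketch is a hypothetical name, not the mmengine implementation.

```python
import torch.nn as nn


def normal_init_sketch(module, mean=0.0, std=1.0, bias=0.0):
    """Draw module.weight from N(mean, std^2) and fill module.bias with `bias`."""
    if getattr(module, 'weight', None) is not None:
        nn.init.normal_(module.weight, mean, std)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)


conv = nn.Conv2d(1, 1, 1)
normal_init_sketch(conv, mean=0, std=0.01, bias=0)

# A convolution without bias: only the weight is touched.
conv_no_bias = nn.Conv2d(1, 1, 1, bias=False)
normal_init_sketch(conv_no_bias, std=0.01)
```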

Similarly, Kaiming and Xavier initialization are available as kaiming_init and xavier_init:

from mmengine.model import kaiming_init, xavier_init

kaiming_init(model)
xavier_init(model)
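
Likewise, a Xavier helper can be sketched with the distribution switch that appears in the XavierInit log lines (gain=1, distribution=normal, bias=0). xavier_init_sketch is hypothetical; the sketch only assumes xavier_init behaves similarly.

```python
import torch
import torch.nn as nn


def xavier_init_sketch(module, gain=1.0, bias=0.0, distribution='normal'):
    """Xavier-initialize module.weight and fill module.bias with a constant."""
    if distribution not in ('uniform', 'normal'):
        raise ValueError(f'unsupported distribution: {distribution}')
    if getattr(module, 'weight', None) is not None:
        if distribution == 'uniform':
            nn.init.xavier_uniform_(module.weight, gain=gain)
        else:
            nn.init.xavier_normal_(module.weight, gain=gain)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)


torch.manual_seed(0)
linear = nn.Linear(4, 2)
# Uniform Xavier samples from [-b, b] with b = gain * sqrt(6 / (fan_in + fan_out)),
# which is 1.0 here because fan_in + fan_out = 6.
xavier_init_sketch(linear, distribution='uniform')
```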

Currently, MMEngine provides the following initialization functions:

| Initialization function | Function |
| --- | --- |
| constant_init | Initialize the weight and bias with a constant; commonly used for convolution |
| xavier_init | Initialize the weight with Xavier initialization and the bias with a constant |
| normal_init | Initialize the weight from a normal distribution and the bias with a constant |
| trunc_normal_init | Initialize the weight from a truncated normal distribution and the bias with a constant; commonly used for Transformers |
| uniform_init | Initialize the weight from a uniform distribution and the bias with a constant; commonly used for convolution |
| kaiming_init | Initialize the weight with Kaiming initialization and the bias with a constant; commonly used for convolution |
| caffe2_xavier_init | Xavier initialization as implemented in Caffe2, which corresponds to Kaiming initialization in PyTorch with "fan_in" mode and a uniform distribution; commonly used for convolution |
| bias_init_with_prob | Compute a bias initialization value from a given prior probability |
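
Assuming bias_init_with_prob follows the usual prior-probability formula, where the bias b is chosen so that sigmoid(b) equals a prior probability p, the value is b = -log((1 - p) / p). The sketch below computes it; bias_from_prior_prob is a hypothetical stand-in, not the mmengine function.

```python
import math


def bias_from_prior_prob(prior_prob):
    """Bias b such that sigmoid(b) == prior_prob, i.e. b = -log((1 - p) / p)."""
    if not 0.0 < prior_prob < 1.0:
        raise ValueError('prior_prob must be in (0, 1)')
    return -math.log((1 - prior_prob) / prior_prob)


b = bias_from_prior_prob(0.01)
print(round(b, 4))                       # -4.5951
print(round(1 / (1 + math.exp(-b)), 4))  # 0.01
```

With p = 0.01 the bias starts near -4.6, so a sigmoid output on top of it begins close to 0.01.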
