
Save Memory on GPU

Memory capacity is critical in deep learning training and inference and determines whether a model can run at all. Common memory saving approaches include:

- Gradient accumulation: accumulate gradients over several iterations and update the parameters only once every N steps, so a small per-iteration batch can emulate a larger effective batch.
- Gradient checkpointing: trade compute for memory by discarding intermediate activations during the forward pass and recomputing them in the backward pass.
- Large model training techniques such as FSDP (FullyShardedDataParallel), which shard model parameters, gradients, and optimizer states across GPUs.

MMEngine supports gradient accumulation, gradient checkpointing, and FSDP-based large model training; their usages are described below.

Gradient Accumulation

The configuration can be written in this way:

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.001, momentum=0.9),
    # accumulate gradients for 4 iterations, then update parameters once
    accumulative_counts=4)
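
For intuition, accumulative_counts=4 behaves roughly like the following plain PyTorch loop. This is a minimal sketch of the idea, not MMEngine's actual implementation:

import torch
import torch.nn as nn

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 8
accumulative_counts = 4

for idx, (img, label) in enumerate(dataset):
    # scale the loss so the accumulated update matches one large-batch step
    loss = (model(img) - label).pow(2).mean() / accumulative_counts
    loss.backward()  # gradients accumulate in .grad across iterations
    if (idx + 1) % accumulative_counts == 0:
        optimizer.step()       # update once every accumulative_counts iterations
        optimizer.zero_grad()  # clear the accumulated gradients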

The full example working with Runner is as follows.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mmengine.runner import Runner
from mmengine.model import BaseModel

# a toy dataset of (input, label) pairs
train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)


class ToyModel(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        # the Runner calls forward with mode='loss' during training; values in
        # the returned dict whose keys contain 'loss' are summed for backprop
        feat = self.linear(img)
        loss1 = (feat - label).pow(2)
        loss2 = (feat - label).abs()
        return dict(loss1=loss1, loss2=loss2)


runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01),
                       accumulative_counts=4)
)
runner.train()
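
With batch_size=2 and accumulative_counts=4, the parameters are updated once every 4 iterations, so the effective batch size is 8 even though each forward pass only sees 2 samples.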
Gradient Checkpointing

Note

Gradient checkpointing is supported starting from MMEngine v0.9.0. For performance comparisons, see #1319; if you encounter any issues during usage, feel free to leave feedback in #1319.

You can enable gradient checkpointing simply by setting activation_checkpointing in the cfg parameter of the Runner.

Let’s take Get Started in 15 Minutes as an example:

# MMResNet50, train_dataloader, val_dataloader, SGD, Accuracy, and args
# are defined in the Get Started in 15 Minutes example
cfg = dict(
    # apply gradient checkpointing to these submodules of the model
    activation_checkpointing=['resnet.layer1', 'resnet.layer2', 'resnet.layer3']
)
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    launcher=args.launcher,
    cfg=cfg,
)
runner.train()
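
Under the hood, gradient checkpointing wraps the forward pass of the listed submodules so that their activations are recomputed during the backward pass instead of being stored. A minimal plain-PyTorch sketch of the idea, not MMEngine's actual implementation:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

blocks = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = torch.randn(4, 8, requires_grad=True)
# activations inside `blocks` are recomputed during backward instead of stored
out = checkpoint(blocks, x, use_reentrant=False)
out.sum().backward()
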
Large Model Training

Warning

If you need to train large models, we recommend reading Training Big Models.

FSDP (FullyShardedDataParallel) has been officially supported since PyTorch 1.11. The config can be written in this way:

# located in cfg file
model_wrapper_cfg=dict(type='MMFullyShardedDataParallel', cpu_offload=True)
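
MMFullyShardedDataParallel is MMEngine's wrapper around PyTorch's native FSDP, which shards parameters, gradients, and optimizer states across GPUs; cpu_offload=True additionally offloads sharded parameters to CPU memory when they are not in use. For reference, a minimal sketch in plain PyTorch (assuming a distributed process group has already been initialized):

import torch.nn as nn
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# shard the model across ranks and offload sharded parameters to CPU
model = FSDP(nn.Linear(1, 1).cuda(), cpu_offload=CPUOffload(offload_params=True))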

The full example working with Runner is as follows.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mmengine.runner import Runner
from mmengine.model import BaseModel

train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)


class ToyModel(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        feat = self.linear(img)
        loss1 = (feat - label).pow(2)
        loss2 = (feat - label).abs()
        return dict(loss1=loss1, loss2=loss2)


runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    # wrap the model with MMFullyShardedDataParallel and offload sharded parameters to CPU
    cfg=dict(model_wrapper_cfg=dict(type='MMFullyShardedDataParallel', cpu_offload=True))
)
runner.train()

Please note that FSDP only works in distributed training environments.
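
For example, assuming the script above is saved as train.py (a hypothetical filename), it could be launched on two GPUs with torchrun --nproc_per_node=2 train.py, passing launcher='pytorch' to the Runner so that the distributed environment is set up.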

