A training helper for PyTorch.

A Runner object can be built from a config via runner = Runner.from_cfg(cfg), where the cfg usually contains training, validation, and test-related configurations to build the corresponding components. We usually use the same config to launch training, testing, and validation tasks. However, only some of these components are necessary at the same time, e.g., testing a model does not need training or validation-related components.

To avoid repeatedly modifying the config, the construction of Runner adopts lazy initialization, which only initializes components when they are about to be used. Therefore, the model is always initialized at the beginning, while the training, validation, and testing related components are only initialized when calling runner.train(), runner.val(), and runner.test(), respectively.
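A minimal sketch of this lazy behavior, assuming a cfg like the one in the Examples section below:

# Only the model is built during construction; the remaining
# components are built lazily on first use.
runner = Runner.from_cfg(cfg)  # model is initialized here
runner.train()  # training loop, train_dataloader, optim_wrapper built here
runner.val()    # validation loop, val_dataloader, val_evaluator built here
runner.test()   # test loop, test_dataloader, test_evaluator built here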
Warning
This is an experimental feature, and its interface is subject to change.
model (torch.nn.Module or dict) – The model to be run. It can be a dict used to build a model.

work_dir (str) – The working directory to save checkpoints. The logs will be saved in the subdirectory of work_dir named timestamp. Defaults to 'work_dir'.

experiment_name (str | None) – Name of the current experiment. If not specified, timestamp will be used as experiment_name. Defaults to None.

train_dataloader (DataLoader | Dict | None) – A dataloader object or a dict to build a dataloader. If None is given, it means skipping training steps. Defaults to None. See build_dataloader() for more details.

optim_wrapper (OptimWrapper | Dict | None) – Computes gradients of model parameters. If specified, train_dataloader should also be specified. If automatic mixed precision or gradient accumulation training is required, the type of optim_wrapper should be AmpOptimizerWrapper. See build_optim_wrapper() for examples. Defaults to None.

param_scheduler (_ParamScheduler | Dict | List | None) – Parameter scheduler for updating optimizer parameters. If specified, optimizer should also be specified. Defaults to None. See build_param_scheduler() for examples.

train_cfg (Dict | None) – A dict to build a training loop. If it does not provide the "type" key, it should contain "by_epoch" to decide which type of training loop, EpochBasedTrainLoop or IterBasedTrainLoop, should be used. If train_cfg is specified, train_dataloader should also be specified. Defaults to None. See build_train_loop() for more details.

val_dataloader (DataLoader | Dict | None) – A dataloader object or a dict to build a dataloader. If None is given, it means skipping validation steps. Defaults to None. See build_dataloader() for more details.

val_evaluator (Evaluator | Dict | List | None) – An evaluator object used for computing metrics for validation. It can be a dict or a list of dicts to build an evaluator. If specified, val_dataloader should also be specified. Defaults to None.

val_cfg (Dict | None) – A dict to build a validation loop. If it does not provide the "type" key, ValLoop will be used by default. If val_cfg is specified, val_dataloader should also be specified. If ValLoop is built with fp16=True, runner.val() will be performed under fp16 precision. Defaults to None. See build_val_loop() for more details.

test_dataloader (DataLoader | Dict | None) – A dataloader object or a dict to build a dataloader. If None is given, it means skipping test steps. Defaults to None. See build_dataloader() for more details.

test_evaluator (Evaluator | Dict | List | None) – An evaluator object used for computing metrics for test steps. It can be a dict or a list of dicts to build an evaluator. If specified, test_dataloader should also be specified. Defaults to None.

test_cfg (Dict | None) – A dict to build a test loop. If it does not provide the "type" key, TestLoop will be used by default. If test_cfg is specified, test_dataloader should also be specified. If TestLoop is built with fp16=True, runner.test() will be performed under fp16 precision. Defaults to None. See build_test_loop() for more details.

strategy (BaseStrategy | Dict | None) – A strategy object or a dict to build a strategy. Defaults to None. If not specified, the strategy will be inferred automatically.

auto_scale_lr (Dict | None) – Config to scale the learning rate automatically. It includes base_batch_size and enable. base_batch_size is the batch size that the optimizer lr is based on. enable is the switch to turn the feature on and off.

default_hooks (Dict[str, Hook | Dict] | None) – Hooks to execute default actions like updating model parameters and saving checkpoints. Default hooks are OptimizerHook, IterTimerHook, LoggerHook, ParamSchedulerHook and CheckpointHook. Defaults to None. See register_default_hooks() for more details.

custom_hooks (List[Hook | Dict] | None) – Hooks to execute custom actions like visualizing images processed by pipeline. Defaults to None.

data_preprocessor (Dict | None) – The pre-process config of BaseDataPreprocessor. If the model argument is a dict and doesn't contain the key data_preprocessor, set the argument as the data_preprocessor of the model dict. Defaults to None.

load_from (str | None) – The checkpoint file to load from. Defaults to None.

resume (bool) – Whether to resume training. Defaults to False. If resume is True and load_from is None, automatically find the latest checkpoint in work_dir. If not found, resuming does nothing.

launcher (str | None) – Way to launch multiple processes. Supported launchers are 'pytorch', 'mpi', 'slurm' and 'none'. If 'none' is provided, a non-distributed environment will be launched. If launcher is None, the launcher will be inferred from the environment. Defaults to None.

env_cfg (Dict) – A dict used for setting up the environment. Defaults to dict(dist_cfg=dict(backend='nccl')).

log_processor (Dict | None) – A processor to format logs. Defaults to None.

log_level (str) – The log level of MMLogger handlers. Defaults to 'INFO'.

visualizer (Visualizer | Dict | None) – A Visualizer object or a dict to build a Visualizer object. Defaults to None. If not specified, the default config will be used.

default_scope (str | None) – Used to reset the registry location. Defaults to "mmengine".

randomness (Dict) – Settings to make the experiment as reproducible as possible, like seed and deterministic. Defaults to dict(seed=None). If seed is None, a random number will be generated, and it will be broadcast to all other processes if in a distributed environment. If cudnn_benchmark is True in env_cfg but deterministic is True in randomness, the value of torch.backends.cudnn.benchmark will finally be False.

compile (bool | Dict) – Whether to enable torch.compile. Defaults to False.

cfg (Config | None) – Full config. Defaults to None.
Note
Since PyTorch 2.0.0, you can enable torch.compile by passing compile=True. If you want to control the compile options, you can pass a dict, e.g. cfg.compile = dict(backend='eager'). Refer to the PyTorch API documentation for more valid options.
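For instance, a brief sketch of both forms (assuming PyTorch >= 2.0.0 and a cfg built as in the Examples below):

# enable torch.compile with its default options
cfg.compile = True

# or pass a dict to control the compile options
cfg.compile = dict(backend='eager')
runner = Runner.from_cfg(cfg)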
Examples
>>> from mmengine.runner import Runner
>>> cfg = dict(
>>>     model=dict(type='ToyModel'),
>>>     work_dir='path/of/work_dir',
>>>     train_dataloader=dict(
>>>         dataset=dict(type='ToyDataset'),
>>>         sampler=dict(type='DefaultSampler', shuffle=True),
>>>         batch_size=1,
>>>         num_workers=0),
>>>     val_dataloader=dict(
>>>         dataset=dict(type='ToyDataset'),
>>>         sampler=dict(type='DefaultSampler', shuffle=False),
>>>         batch_size=1,
>>>         num_workers=0),
>>>     test_dataloader=dict(
>>>         dataset=dict(type='ToyDataset'),
>>>         sampler=dict(type='DefaultSampler', shuffle=False),
>>>         batch_size=1,
>>>         num_workers=0),
>>>     auto_scale_lr=dict(base_batch_size=16, enable=False),
>>>     optim_wrapper=dict(type='OptimizerWrapper', optimizer=dict(
>>>         type='SGD', lr=0.01)),
>>>     param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
>>>     val_evaluator=dict(type='ToyEvaluator'),
>>>     test_evaluator=dict(type='ToyEvaluator'),
>>>     train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1),
>>>     val_cfg=dict(),
>>>     test_cfg=dict(),
>>>     custom_hooks=[],
>>>     default_hooks=dict(
>>>         timer=dict(type='IterTimerHook'),
>>>         checkpoint=dict(type='CheckpointHook', interval=1),
>>>         logger=dict(type='LoggerHook'),
>>>         optimizer=dict(type='OptimizerHook', grad_clip=False),
>>>         param_scheduler=dict(type='ParamSchedulerHook')),
>>>     launcher='none',
>>>     env_cfg=dict(dist_cfg=dict(backend='nccl')),
>>>     log_processor=dict(window_size=20),
>>>     visualizer=dict(type='Visualizer',
>>>         vis_backends=[dict(type='LocalVisBackend',
>>>                            save_dir='temp_dir')])
>>> )
>>> runner = Runner.from_cfg(cfg)
>>> runner.train()
>>> runner.test()
Build dataloader.
The method builds three components:
Dataset
Sampler
Dataloader
An example of dataloader:

dataloader = dict(
    dataset=dict(type='ToyDataset'),
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_size=1,
    num_workers=9
)
dataloader (DataLoader or dict) – A DataLoader object or a dict to build a DataLoader object. If dataloader is a DataLoader object, just returns itself.
seed (int, optional) – Random seed. Defaults to None.
diff_rank_seed (bool) – Whether or not to set different seeds for different ranks. If True, the seed passed to the sampler is set to None, in order to synchronize the seeds used in samplers across different ranks. Defaults to False.
DataLoader built from dataloader_cfg.

DataLoader
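A hedged usage sketch, assuming build_dataloader is called as a static method and the ToyDataset and DefaultSampler types are registered as in the Examples above:

from mmengine.runner import Runner

dataloader_cfg = dict(
    dataset=dict(type='ToyDataset'),
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_size=1,
    num_workers=0)
# build a DataLoader from the config, seeding its sampler
dataloader = Runner.build_dataloader(dataloader_cfg, seed=42)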
Build evaluator.
Examples of evaluator:

# evaluator could be a built Evaluator instance
evaluator = Evaluator(metrics=[ToyMetric()])

# evaluator can also be a list of dict
evaluator = [
    dict(type='ToyMetric1'),
    dict(type='ToyEvaluator2')
]

# evaluator can also be a list of built metrics
evaluator = [ToyMetric1(), ToyMetric2()]

# evaluator can also be a dict with the key metrics
evaluator = dict(metrics=ToyMetric())

# metric is a list
evaluator = dict(metrics=[ToyMetric()])
Build log processor.
Examples of log_processor:

# LogProcessor will be used
log_processor = dict()

# custom log_processor
log_processor = dict(type='CustomLogProcessor')
log_processor (LogProcessor or dict) – A log processor or a dict to build a log processor. If log_processor is a log processor object, just returns itself.
Log processor object built from log_processor_cfg.
Build a globally accessible MessageHub.
message_hub (dict, optional) – A dict to build MessageHub object. If not specified, default config will be used to build MessageHub object. Defaults to None.
A MessageHub object built from message_hub.
Build a strategy.
A strategy object.
Build test loop.
Examples of loop:

# `TestLoop` will be used
loop = dict()

# custom test loop
loop = dict(type='CustomTestLoop')
Build training loop.
Examples of loop:

# `EpochBasedTrainLoop` will be used
loop = dict(by_epoch=True, max_epochs=3)

# `IterBasedTrainLoop` will be used
loop = dict(by_epoch=False, max_iters=3)

# custom training loop
loop = dict(type='CustomTrainLoop', max_epochs=3)
Build validation loop.
Examples of loop:

# ValLoop will be used
loop = dict()

# custom validation loop
loop = dict(type='CustomValLoop')
Build a globally accessible Visualizer.
visualizer (Visualizer or dict, optional) – A Visualizer object or a dict to build a Visualizer object. If visualizer is a Visualizer object, just returns itself. If not specified, the default config will be used to build the Visualizer object. Defaults to None.
A Visualizer object built from visualizer.
Call all hooks.
fn_name (str) – The function name in each hook to be called, such as “before_train_epoch”.
**kwargs – Keyword arguments passed to hook.
None
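For example, a brief sketch of triggering a hook point manually (the keyword arguments shown are those the corresponding Hook methods conventionally accept):

# run the `before_train_epoch` method of every registered hook
runner.call_hook('before_train_epoch')

# keyword arguments are forwarded to each hook's method
runner.call_hook('after_train_iter', batch_idx=0, data_batch=None, outputs=None)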
Whether cuDNN selects deterministic algorithms.
Whether current environment is distributed.
Dump config to work_dir.
None
Current epoch.
Name of experiment.
Build a runner from config.
cfg (ConfigType) – A config used for building the runner. For the keys of cfg, see __init__().
A runner built from cfg.
A list of registered hooks.
List[Hook]
Current iteration.
Load checkpoint from the given filename.
filename (str) – Accepts a local filepath, URL, torchvision://xxx, or open-mmlab://xxx.
map_location (str or callable) – A string or a callable function specifying how to remap storage locations. Defaults to 'cpu'.
strict (bool) – Whether to allow different params for the model and checkpoint.
revise_keys (list) – A list of customized keywords to modify the state_dict in the checkpoint. Each item is a (pattern, replacement) pair of regular expression operations. Defaults to stripping the prefix 'module.' with [(r'^module.', '')].
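For example, a hedged sketch of loading a checkpoint saved with DistributedDataParallel (the path is hypothetical):

runner.load_checkpoint(
    'work_dir/epoch_3.pth',            # hypothetical local checkpoint path
    map_location='cpu',
    strict=False,
    revise_keys=[(r'^module\.', '')],  # strip the DDP 'module.' prefix
)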
Load or resume checkpoint.
Total epochs to train model.
Total iterations to train model.
Name of the model, usually the module class name.
Rank of current process.
Register custom hooks into hook list.
Register default hooks into hook list.
hooks will be registered into the runner to execute some default actions like updating model parameters or saving checkpoints.
Default hooks and their priorities:
Hooks                  Priority
RuntimeInfoHook        VERY_HIGH (10)
IterTimerHook          NORMAL (50)
DistSamplerSeedHook    NORMAL (50)
LoggerHook             BELOW_NORMAL (60)
ParamSchedulerHook     LOW (70)
CheckpointHook         VERY_LOW (90)
If hooks is None, the above hooks will be registered by default:

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
)
If not None, hooks will be merged into default_hooks. If there are None values in default_hooks, the corresponding items will be popped from default_hooks:

hooks = dict(timer=None)

The final registered default hooks will be RuntimeInfoHook, DistSamplerSeedHook, LoggerHook, ParamSchedulerHook and CheckpointHook.
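A hedged sketch of that merge behavior, assuming a constructed runner; the interval value is illustrative:

# replace the default CheckpointHook config and drop the timer hook;
# all other defaults are kept as-is
runner.register_default_hooks(dict(
    checkpoint=dict(type='CheckpointHook', interval=5),
    timer=None,
))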
Register a hook into the hook list.
The hook will be inserted into a priority queue with the specified priority (see Priority for details of priorities). Hooks with the same priority are triggered in the order in which they were registered.

The priority of the hook will be decided in the following order:

1. The priority argument. If priority is given, it will be the priority of the hook.
2. If the hook argument is a dict and priority is in it, the priority will be the value of hook['priority'].
3. If the hook argument is a dict without priority in it, or hook is an instance of Hook, the priority will be hook.priority.
hook (Hook or dict) – The hook to be registered.
priority (int or str or Priority, optional) – Hook priority. A lower value means a higher priority.
None
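A hedged sketch of both registration styles (EMAHook is a real mmengine hook; the priority value is illustrative):

from mmengine.hooks import EMAHook

# register from a config dict, overriding the priority explicitly
runner.register_hook(dict(type='EMAHook'), priority='NORMAL')

# or register a built Hook instance; its own `priority` attribute is used
runner.register_hook(EMAHook())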
Register default hooks and custom hooks into hook list.
default_hooks (dict[str, dict] or dict[str, Hook], optional) – Hooks to execute default actions like updating model parameters and saving checkpoints. Defaults to None.
custom_hooks (list[dict] or list[Hook], optional) – Hooks to execute custom actions like visualizing images processed by pipeline. Defaults to None.
None
Resume model from checkpoint.
filename (str) – Accepts a local filepath, URL, torchvision://xxx, or open-mmlab://xxx.
resume_optimizer (bool) – Whether to resume optimizer state. Defaults to True.
resume_param_scheduler (bool) – Whether to resume param scheduler state. Defaults to True.
map_location (str or callable) – A string or a callable function specifying how to remap storage locations. Defaults to 'default'.
None
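A hedged sketch of resuming from a specific checkpoint (the path is hypothetical):

runner.resume(
    'work_dir/epoch_3.pth',        # hypothetical checkpoint path
    resume_optimizer=True,         # restore optimizer state
    resume_param_scheduler=True,   # restore param scheduler state
)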
Save checkpoints.
CheckpointHook invokes this method to save checkpoints periodically.
out_dir (str) – The directory that checkpoints are saved.
filename (str) – The checkpoint filename.
file_client_args (dict, optional) – Arguments to instantiate a FileClient. See mmengine.fileio.FileClient for details. It will be deprecated in the future; please use backend_args instead. Defaults to None.
save_optimizer (bool) – Whether to save the optimizer to the checkpoint. Defaults to True.
save_param_scheduler (bool) – Whether to save the param_scheduler to the checkpoint. Defaults to True.
meta (dict, optional) – The meta information to be saved in the checkpoint. Defaults to None.
by_epoch (bool) – Whether the scheduled momentum is updated by epochs. Defaults to True.
backend_args (dict, optional) – Arguments to instantiate the backend corresponding to the URI prefix. Defaults to None.
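A hedged sketch of saving a checkpoint manually, outside of CheckpointHook (the filename is illustrative):

runner.save_checkpoint(
    out_dir='work_dir',               # directory to save the checkpoint in
    filename='manual_checkpoint.pth', # illustrative checkpoint filename
    save_optimizer=True,
    save_param_scheduler=True,
)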
The random seed used to set random modules.
Launch test.
A dict of metrics on testing set.
The data loader for testing.
An evaluator for testing.
Evaluator
A loop to run testing.
Timestamp when creating experiment.
Launch training.
The model after training.
nn.Module
The data loader for training.
A loop to run training.
Launch validation.
A dict of metrics on validation set.
The epoch/iteration to start running validation during training.
The data loader for validation.
An evaluator for validation.
Evaluator
Interval to run validation during training.
A loop to run validation.
The working directory to save checkpoints and logs.
Number of processes participating in the job.