Showing content from https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html below:

Callback — PyTorch Lightning 2.5.3 documentation

Callback

Callbacks let you add arbitrary, self-contained programs to your training. The Callback interface exposes specific points in the flow of execution (hooks) at which your program runs, so a single callback can encapsulate a full set of functionality. This decouples functionality that does not need to live in the LightningModule and lets it be shared across projects.

Lightning's callback system executes callbacks when needed. Callbacks should capture NON-ESSENTIAL logic that is NOT required for your LightningModule to run.

A complete list of Callback hooks can be found in the Callback API section.

An overall Lightning system should have:

  1. Trainer for all engineering

  2. LightningModule for all research code.

  3. Callbacks for non-essential code.

Example:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback


class MyPrintingCallback(Callback):
    # runs once, right before the training loop starts
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")

    # runs once, right after the training loop finishes
    def on_train_end(self, trainer, pl_module):
        print("Training is ending")


trainer = Trainer(callbacks=[MyPrintingCallback()])

We successfully extended functionality without polluting our super clean lightning module research code.

You can do pretty much anything with callbacks.

Built-in Callbacks

Lightning has a few built-in callbacks.

Note

For a richer collection of callbacks, check out our bolts library.

Save Callback state

Some callbacks require internal state in order to function properly. You can optionally choose to persist your callback’s state as part of model checkpoint files using state_dict() and load_state_dict(). Note that the returned state must be able to be pickled.

When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, then the callback must define a state_key property in order for Lightning to be able to distinguish the different states when loading the callback state. This concept is best illustrated by the following example.

class Counter(Callback):
    def __init__(self, what="epochs", verbose=True):
        self.what = what
        self.verbose = verbose
        self.state = {"epochs": 0, "batches": 0}

    @property
    def state_key(self) -> str:
        # note: we do not include `verbose` here on purpose
        return f"Counter[what={self.what}]"

    def on_train_epoch_end(self, *args, **kwargs):
        if self.what == "epochs":
            self.state["epochs"] += 1

    def on_train_batch_end(self, *args, **kwargs):
        if self.what == "batches":
            self.state["batches"] += 1

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)

    def state_dict(self):
        return self.state.copy()


# two callbacks of the same type are being used
trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])

A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:

{
    "state_dict": ...,
    "callbacks": {
        "Counter[what=batches]": {"batches": 32, "epochs": 0},
        "Counter[what=epochs]": {"batches": 0, "epochs": 2},
        ...
    }
}

Implementing state_key is essential here. By default, state_key is only the class name (here, Counter), so without the override both callbacks would share the same key and Lightning could not tell their states apart.
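To make the collision concrete, here is a minimal sketch that mimics collecting callback states into a dictionary keyed by state_key. Plain classes stand in for callbacks, and `collect_states` is a hypothetical helper written for illustration; it is not Lightning's implementation.

```python
class PlainCounter:
    """Stand-in for a stateful callback that keeps the default key (class name only)."""

    def __init__(self, what):
        self.what = what
        self.state = {"epochs": 0, "batches": 0}

    @property
    def state_key(self):
        return type(self).__qualname__

    def state_dict(self):
        return self.state.copy()


class KeyedCounter(PlainCounter):
    """Same stand-in, but the key also encodes the `what` argument."""

    @property
    def state_key(self):
        return f"Counter[what={self.what}]"


def collect_states(callbacks):
    # Hypothetical helper: one entry per state_key, so instances that
    # share a key overwrite each other.
    return {cb.state_key: cb.state_dict() for cb in callbacks}


collided = collect_states([PlainCounter("epochs"), PlainCounter("batches")])
distinct = collect_states([KeyedCounter("epochs"), KeyedCounter("batches")])

print(sorted(collided))  # one key shared by both instances
print(sorted(distinct))  # one key per instance
```

With the class name as the only key, the second instance's state replaces the first; encoding the distinguishing constructor argument in the key keeps both.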

Best Practices

The following are best practices when using/designing callbacks.

  1. Callbacks should be isolated in their functionality.

  2. Your callback should not rely on the behavior of other callbacks in order to work properly.

  3. Do not manually call methods from the callback.

  4. Directly calling hook methods (e.g., on_validation_end) is strongly discouraged; let the Trainer invoke them.

  5. Whenever possible, your callbacks should not depend on the order in which they are executed.

Entry Points

Lightning supports registering Trainer callbacks directly through Entry Points. Entry points allow an arbitrary package to include callbacks that the Lightning Trainer can automatically use, without you having to add them to the Trainer manually. This is useful in production environments where it is common to provide specialized monitoring and logging callbacks globally for every application.

Here is a callback factory function that returns two special callbacks:

def my_custom_callbacks_factory():
    return [MyCallback1(), MyCallback2()]

If we make this factories.py file into an installable package, we can define an entry point for this factory function. Here is a minimal example of the setup.py file for the package my-package:

from setuptools import setup

setup(
    name="my-package",
    version="0.0.1",
    install_requires=["lightning"],
    entry_points={
        "lightning.pytorch.callbacks_factory": [
            # The format here must be [any name]=[module path]:[function name]
            "monitor_callbacks=factories:my_custom_callbacks_factory"
        ]
    },
)

The group name for the entry points is lightning.pytorch.callbacks_factory and it contains a list of strings that specify where to find the function within the package.

Now, if you pip install -e . this package, it will register the my_custom_callbacks_factory function and Lightning will automatically call it to collect the callbacks whenever you run the Trainer!

To unregister the factory, simply uninstall the package with pip uninstall my-package.

Callback API

Here is the full API of methods available in the Callback base class.

The Callback class is the base for all callbacks in Lightning, just as the LightningModule is the base for all models. It defines a public interface that each callback implementation must follow. The key members are:

Properties

state_key
Callback.state_key

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.

Hooks

setup
Callback.setup(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune begins.

Return type:

None

teardown
Callback.teardown(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune ends.

Return type:

None
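As an illustration of how setup and teardown pair up, here is a minimal sketch of a callback that records the stage it was set up and torn down for. `StageLog` is a hypothetical name, the stand-in base class exists only so the sketch runs without Lightning installed, and the hooks are invoked by hand purely to show the flow; in a real run the Trainer calls them.

```python
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class StageLog(Callback):
    """Hypothetical callback that records which stages were set up and torn down."""

    def __init__(self):
        self.events = []

    def setup(self, trainer, pl_module, stage):
        self.events.append(f"setup:{stage}")

    def teardown(self, trainer, pl_module, stage):
        self.events.append(f"teardown:{stage}")


# Manual invocation for demonstration only; `None` stands in for the
# trainer and the module, which this callback does not use.
log = StageLog()
log.setup(None, None, stage="fit")
log.teardown(None, None, stage="fit")
print(log.events)
```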

on_fit_start
Callback.on_fit_start(trainer, pl_module)[source]

Called when fit begins.

Return type:

None

on_fit_end
Callback.on_fit_end(trainer, pl_module)[source]

Called when fit ends.

Return type:

None

on_sanity_check_start
Callback.on_sanity_check_start(trainer, pl_module)[source]

Called when the validation sanity check starts.

Return type:

None

on_sanity_check_end
Callback.on_sanity_check_end(trainer, pl_module)[source]

Called when the validation sanity check ends.

Return type:

None

on_train_batch_start
Callback.on_train_batch_start(trainer, pl_module, batch, batch_idx)[source]

Called when the train batch begins.

Return type:

None

on_train_batch_end
Callback.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]

Called when the train batch ends.

Return type:

None

Note

The value outputs["loss"] here will be the normalized value w.r.t accumulate_grad_batches of the loss returned from training_step.
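A small worked example of that note, assuming that "normalized" means the loss is divided by accumulate_grad_batches. `normalized_loss` is a hypothetical helper used only for the arithmetic.

```python
def normalized_loss(step_loss, accumulate_grad_batches):
    """Loss as a callback would see it in outputs["loss"], under the assumption above."""
    return step_loss / accumulate_grad_batches


# training_step returned 2.0 and gradients are accumulated over 4 batches
seen_by_callback = normalized_loss(2.0, 4)
print(seen_by_callback)  # 0.5

# multiply back to recover the value returned by training_step
recovered = seen_by_callback * 4
print(recovered)  # 2.0
```

If your callback reports per-batch losses, keep this scaling in mind when comparing runs that use different accumulation settings.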

on_train_epoch_start
Callback.on_train_epoch_start(trainer, pl_module)[source]

Called when the train epoch begins.

Return type:

None

on_train_epoch_end
Callback.on_train_epoch_end(trainer, pl_module)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type:

None

on_validation_epoch_start
Callback.on_validation_epoch_start(trainer, pl_module)[source]

Called when the validation epoch begins.

Return type:

None

on_validation_epoch_end
Callback.on_validation_epoch_end(trainer, pl_module)[source]

Called when the validation epoch ends.

Return type:

None

on_test_epoch_start
Callback.on_test_epoch_start(trainer, pl_module)[source]

Called when the test epoch begins.

Return type:

None

on_test_epoch_end
Callback.on_test_epoch_end(trainer, pl_module)[source]

Called when the test epoch ends.

Return type:

None

on_predict_epoch_start
Callback.on_predict_epoch_start(trainer, pl_module)[source]

Called when the predict epoch begins.

Return type:

None

on_predict_epoch_end
Callback.on_predict_epoch_end(trainer, pl_module)[source]

Called when the predict epoch ends.

Return type:

None

on_validation_batch_start
Callback.on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]

Called when the validation batch begins.

Return type:

None

on_validation_batch_end
Callback.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the validation batch ends.

Return type:

None
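The dataloader_idx argument matters when you validate on more than one dataloader. Here is a minimal sketch of a callback that tallies batches per dataloader; `BatchesPerLoader` is a hypothetical name, the stand-in base class only lets the sketch run without Lightning installed, and the loop at the bottom imitates the Trainer's calls for two dataloaders with 3 and 2 batches.

```python
from collections import Counter as Tally

try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class BatchesPerLoader(Callback):
    """Hypothetical callback that counts validation batches per dataloader."""

    def __init__(self):
        self.seen = Tally()

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        self.seen[dataloader_idx] += 1


# Imitate the Trainer for demonstration only: loader 0 has 3 batches, loader 1 has 2.
cb = BatchesPerLoader()
for loader_idx, num_batches in enumerate([3, 2]):
    for batch_idx in range(num_batches):
        cb.on_validation_batch_end(None, None, None, None, batch_idx, loader_idx)

print(dict(cb.seen))
```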

on_test_batch_start
Callback.on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]

Called when the test batch begins.

Return type:

None

on_test_batch_end
Callback.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the test batch ends.

Return type:

None

on_predict_batch_start
Callback.on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]

Called when the predict batch begins.

Return type:

None

on_predict_batch_end
Callback.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the predict batch ends.

Return type:

None

on_train_start
Callback.on_train_start(trainer, pl_module)[source]

Called when training begins.

Return type:

None

on_train_end
Callback.on_train_end(trainer, pl_module)[source]

Called when training ends.

Return type:

None

on_validation_start
Callback.on_validation_start(trainer, pl_module)[source]

Called when the validation loop begins.

Return type:

None

on_validation_end
Callback.on_validation_end(trainer, pl_module)[source]

Called when the validation loop ends.

Return type:

None

on_test_start
Callback.on_test_start(trainer, pl_module)[source]

Called when testing begins.

Return type:

None

on_test_end
Callback.on_test_end(trainer, pl_module)[source]

Called when testing ends.

Return type:

None

on_predict_start
Callback.on_predict_start(trainer, pl_module)[source]

Called when predict begins.

Return type:

None

on_predict_end
Callback.on_predict_end(trainer, pl_module)[source]

Called when predict ends.

Return type:

None

on_exception
Callback.on_exception(trainer, pl_module, exception)[source]

Called when any trainer execution is interrupted by an exception.

Return type:

None
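For instance, a callback can use this hook to keep a record of what went wrong. The sketch below is a minimal illustration under stated assumptions: `ExceptionRecorder` is a hypothetical name, the stand-in base class only lets the sketch run without Lightning installed, and the hook is called by hand to imitate what the Trainer would do when a run fails.

```python
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class ExceptionRecorder(Callback):
    """Hypothetical callback that remembers the last exception it was handed."""

    def __init__(self):
        self.last_exception = None

    def on_exception(self, trainer, pl_module, exception):
        self.last_exception = repr(exception)


# Imitate an interrupted run for demonstration only.
recorder = ExceptionRecorder()
try:
    raise RuntimeError("boom")
except RuntimeError as err:
    recorder.on_exception(None, None, err)

print(recorder.last_exception)
```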

state_dict
Callback.state_dict()[source]

Called when saving a checkpoint. Implement this to generate the callback’s state_dict.

Return type:

dict[str, Any]

Returns:

A dictionary containing callback state.

on_save_checkpoint
Callback.on_save_checkpoint(trainer, pl_module, checkpoint)[source]

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

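Parameters:

trainer – the current Trainer instance.

pl_module – the current LightningModule instance.

checkpoint – the checkpoint dictionary that will be saved.

Return type: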

None

load_state_dict
Callback.load_state_dict(state_dict)[source]

Called when loading a checkpoint. Implement this to reload the callback’s state from its state_dict.

Parameters:

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

on_load_checkpoint
Callback.on_load_checkpoint(trainer, pl_module, checkpoint)[source]

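Called when loading a model checkpoint. Use this to reload state.

Parameters:

trainer – the current Trainer instance.

pl_module – the current LightningModule instance.

checkpoint – the full checkpoint dictionary that was loaded by the Trainer.

Return type: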

None
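To show how on_save_checkpoint and on_load_checkpoint mirror each other, here is a minimal round-trip sketch. `RunTag` and the `"run_tag"` key are hypothetical names chosen for illustration, the stand-in base class only lets the sketch run without Lightning installed, and a plain dict plays the role of the checkpoint dictionary.

```python
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class RunTag(Callback):
    """Hypothetical callback that stores an extra value in the checkpoint."""

    def __init__(self, tag=None):
        self.tag = tag

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # "run_tag" is an illustrative key, not one Lightning defines.
        checkpoint["run_tag"] = self.tag

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        self.tag = checkpoint.get("run_tag")


# Imitate a save followed by a load; a plain dict stands in for the checkpoint.
checkpoint = {"state_dict": {}}
RunTag(tag="experiment-7").on_save_checkpoint(None, None, checkpoint)

restored = RunTag()
restored.on_load_checkpoint(None, None, checkpoint)
print(restored.tag)
```

For state that belongs to the callback itself, state_dict() and load_state_dict() are usually the better fit; these two hooks are for anything else you want to place in or read from the checkpoint.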

on_before_backward
Callback.on_before_backward(trainer, pl_module, loss)[source]

Called before loss.backward().

Return type:

None

on_after_backward
Callback.on_after_backward(trainer, pl_module)[source]

Called after loss.backward() and before optimizers are stepped.

Return type:

None

on_before_optimizer_step
Callback.on_before_optimizer_step(trainer, pl_module, optimizer)[source]

Called before optimizer.step().

Return type:

None

on_before_zero_grad
Callback.on_before_zero_grad(trainer, pl_module, optimizer)[source]

Called before optimizer.zero_grad().

Return type:

None

