Lookup and instantiate classes with style.
```python
from dataclasses import dataclass

from class_resolver import ClassResolver

class Base: pass

@dataclass
class A(Base):
    name: str

@dataclass
class B(Base):
    name: str

# Index
resolver = ClassResolver([A, B], base=Base)

# Lookup
assert A == resolver.lookup('A')

# Instantiate with a dictionary
assert A(name='hi') == resolver.make('A', {'name': 'hi'})

# Instantiate with kwargs
assert A(name='hi') == resolver.make('A', name='hi')

# A pre-instantiated class will simply be passed through
assert A(name='hi') == resolver.make(A(name='hi'))
```

## 🤖 Writing Extensible Machine Learning Models with class-resolver
Assume you've implemented a simple multi-layer perceptron in PyTorch:
```python
from itertools import chain

from more_itertools import pairwise
from torch import nn

class MLP(nn.Sequential):
    def __init__(self, dims: list[int]):
        super().__init__(*chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                nn.ReLU(),
            )
            for in_features, out_features in pairwise(dims)
        ))
```
This MLP uses a hard-coded rectified linear unit as the non-linear activation function between layers. We can generalize the MLP to use a variety of non-linear activation functions by adding an argument to its `__init__()` function:
```python
from itertools import chain

from more_itertools import pairwise
from torch import nn

class MLP(nn.Sequential):
    def __init__(self, dims: list[int], activation: str = "relu"):
        if activation == "relu":
            activation = nn.ReLU()
        elif activation == "tanh":
            activation = nn.Tanh()
        elif activation == "hardtanh":
            activation = nn.Hardtanh()
        else:
            raise KeyError(f"Unsupported activation: {activation}")
        super().__init__(*chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation,
            )
            for in_features, out_features in pairwise(dims)
        ))
```
The first issue with this implementation is that it relies on a hard-coded set of conditional statements, which makes it hard to extend. It can be improved by using a dictionary lookup:
```python
from itertools import chain

from more_itertools import pairwise
from torch import nn

activation_lookup: dict[str, nn.Module] = {
    "relu": nn.ReLU(),
    "tanh": nn.Tanh(),
    "hardtanh": nn.Hardtanh(),
}

class MLP(nn.Sequential):
    def __init__(self, dims: list[int], activation: str = "relu"):
        activation = activation_lookup[activation]
        super().__init__(*chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation,
            )
            for in_features, out_features in pairwise(dims)
        ))
```
This approach is rigid because it requires pre-instantiation of the activations. If we needed to vary the arguments to the `nn.Hardtanh` class, the previous approach wouldn't work. We can change the implementation to look up the class before instantiation, then optionally pass some arguments:
```python
from itertools import chain
from typing import Any

from more_itertools import pairwise
from torch import nn

activation_lookup: dict[str, type[nn.Module]] = {
    "relu": nn.ReLU,
    "tanh": nn.Tanh,
    "hardtanh": nn.Hardtanh,
}

class MLP(nn.Sequential):
    def __init__(
        self,
        dims: list[int],
        activation: str = "relu",
        activation_kwargs: None | dict[str, Any] = None,
    ):
        activation_cls = activation_lookup[activation]
        activation = activation_cls(**(activation_kwargs or {}))
        super().__init__(*chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation,
            )
            for in_features, out_features in pairwise(dims)
        ))
```
This is pretty good, but it still has a few issues: you have to maintain the `activation_lookup` dictionary yourself, and you can't pass a class or an already-instantiated module through the `activation` keyword.

Enter the `class_resolver` package, which takes care of all of these things using the following:
```python
from itertools import chain
from typing import Any

from class_resolver import ClassResolver, Hint
from more_itertools import pairwise
from torch import nn

activation_resolver = ClassResolver(
    [nn.ReLU, nn.Tanh, nn.Hardtanh],
    base=nn.Module,
    default=nn.ReLU,
)

class MLP(nn.Sequential):
    def __init__(
        self,
        dims: list[int],
        activation: Hint[nn.Module] = None,  # Hint = Union[None, str, nn.Module, type[nn.Module]]
        activation_kwargs: None | dict[str, Any] = None,
    ):
        super().__init__(*chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation_resolver.make(activation, activation_kwargs),
            )
            for in_features, out_features in pairwise(dims)
        ))
```
Because this is such a common pattern, we've made it available through the contrib module `class_resolver.contrib.torch`:
```python
from itertools import chain
from typing import Any

from class_resolver import Hint
from class_resolver.contrib.torch import activation_resolver
from more_itertools import pairwise
from torch import nn

class MLP(nn.Sequential):
    def __init__(
        self,
        dims: list[int],
        activation: Hint[nn.Module] = None,
        activation_kwargs: None | dict[str, Any] = None,
    ):
        super().__init__(*chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation_resolver.make(activation, activation_kwargs),
            )
            for in_features, out_features in pairwise(dims)
        ))
```
Now, you can instantiate the MLP with any of the following:
```python
MLP(dims=[10, 200, 40])  # uses default, which is ReLU
MLP(dims=[10, 200, 40], activation="relu")  # uses lowercase
MLP(dims=[10, 200, 40], activation="ReLU")  # uses stylized
MLP(dims=[10, 200, 40], activation=nn.ReLU)  # uses class
MLP(dims=[10, 200, 40], activation=nn.ReLU())  # uses instance
MLP(dims=[10, 200, 40], activation="hardtanh",
    activation_kwargs={"min_val": 0.0, "max_val": 6.0})  # uses kwargs
MLP(dims=[10, 200, 40], activation=nn.Hardtanh,
    activation_kwargs={"min_val": 0.0, "max_val": 6.0})  # uses kwargs
MLP(dims=[10, 200, 40], activation=nn.Hardtanh(0.0, 6.0))  # uses instance
```
In practice, it makes sense to stick to using the strings in combination with hyper-parameter optimization libraries like Optuna.
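Because only a string crosses the hyper-parameter boundary, the activation becomes an ordinary categorical choice in a search. Here is a minimal sketch of that idea, using `random.choice` as a stand-in for Optuna's `trial.suggest_categorical` (the names `ACTIVATIONS` and `pick_activation` are illustrative, not part of class-resolver):

```python
import random

# Candidate activation names; the resolver turns the winning name
# into a class at model-construction time.
ACTIVATIONS = ["relu", "tanh", "hardtanh"]

def pick_activation(rng: random.Random) -> str:
    # Stand-in for Optuna's trial.suggest_categorical("activation", ACTIVATIONS)
    return rng.choice(ACTIVATIONS)

rng = random.Random(42)
name = pick_activation(rng)
# The sampled string would be passed straight through, e.g.:
#   MLP(dims=[10, 200, 40], activation=name)
assert name in ACTIVATIONS
```

The same string can also be logged, serialized into a config file, or reported alongside the trial's score, which is why strings pair so well with hyper-parameter optimization.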
The most recent release can be installed from PyPI with uv:
```shell
$ uv pip install class_resolver
```
or with pip:
```shell
$ python3 -m pip install class_resolver
```
The most recent code and data can be installed directly from GitHub with uv:
```shell
$ uv pip install git+https://github.com/cthoyt/class-resolver.git
```
or with pip:
```shell
$ python3 -m pip install git+https://github.com/cthoyt/class-resolver.git
```
Contributions, whether filing an issue, making a pull request, or forking, are appreciated. See CONTRIBUTING.md for more information on getting involved.
The code in this package is licensed under the MIT License.
This package was created with @audreyfeldroy's cookiecutter package using @cthoyt's cookiecutter-snekpack template.
The final section of the README is for those who want to get involved by making a code contribution.
To install in development mode, use the following:
```shell
$ git clone https://github.com/cthoyt/class-resolver.git
$ cd class-resolver
$ uv pip install -e .
```
Alternatively, install using pip:
```shell
$ python3 -m pip install -e .
```

## Updating Package Boilerplate
This project uses cruft to keep boilerplate (i.e., configuration, contribution guidelines, documentation configuration) up-to-date with the upstream cookiecutter package. Install cruft with either `uv tool install cruft` or `python3 -m pip install cruft`, then run:
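The command elided above is Cruft's documented `update` subcommand:

```shell
$ cruft update
```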
More info on Cruft's update command is available in its documentation.
After cloning the repository and installing tox with `uv tool install tox --with tox-uv` or `python3 -m pip install tox tox-uv`, the unit tests in the `tests/` folder can be run reproducibly with:
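The command itself was lost in extraction; in the cookiecutter-snekpack template this README is based on, it is presumably the default tox environment:

```shell
$ tox -e py
```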
Additionally, these tests are automatically re-run with each commit in a GitHub Action.
## 📖 Building the Documentation

The documentation can be built locally using the following:
```shell
$ git clone https://github.com/cthoyt/class-resolver.git
$ cd class-resolver
$ tox -e docs
$ open docs/build/html/index.html
```
The documentation automatically installs the package as well as the `docs` extra specified in `pyproject.toml`. `sphinx` plugins like `texext` can be added there. Additionally, they need to be added to the `extensions` list in `docs/source/conf.py`.
The documentation can be deployed to ReadTheDocs using this guide. The `.readthedocs.yml` YAML file contains all the configuration you'll need. You can also set up continuous integration on GitHub to check not only that Sphinx can build the documentation in an isolated environment (i.e., with `tox -e docs-test`) but also that ReadTheDocs can build it.
Zenodo is a long-term archival system that assigns a DOI to each release of your package.
After these steps, you're ready to go! After you make a release on GitHub (steps for this are below), you can navigate to https://zenodo.org/account/settings/github/repository/cthoyt/class-resolver to see the DOI for the release and a link to its Zenodo record.
## Registering with the Python Package Index (PyPI)

You only have to do the following steps once.
You have to do the following steps once per machine.
```shell
$ uv tool install keyring
$ keyring set https://upload.pypi.org/legacy/ __token__
$ keyring set https://test.pypi.org/legacy/ __token__
```
Note that this deprecates previous workflows using `.pypirc`.
After installing the package in development mode and installing tox with `uv tool install tox --with tox-uv` or `python3 -m pip install tox tox-uv`, run the following from the console:
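The release command was lost in extraction; in the cookiecutter-snekpack template this README is based on, it is presumably the `finish` tox environment:

```shell
$ tox -e finish
```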
This script does the following:

1. Bumps the version number in `pyproject.toml`, `CITATION.cff`, `src/class_resolver/version.py`, and `docs/source/conf.py` to no longer have the `-dev` suffix
2. Builds the source and wheel distributions with `uv build`
3. Uploads them to PyPI with `uv publish`
4. Bumps to the next development version (run `tox -e bumpversion -- minor` after if the changes warrant a minor version bump)

This will trigger Zenodo to assign a DOI to your release as well.