LuxonisTrain is a user-friendly tool designed to streamline the training of deep learning models, especially for edge devices. Built on top of PyTorch Lightning, it simplifies the process of training, testing, and exporting models with minimal coding required. The training pipeline is defined in a YAML configuration file.
Warning
The project is in a beta state and might be unstable or contain bugs; please report any issues or feedback.
Get started with LuxonisTrain in just a few steps:
Install LuxonisTrain
pip install luxonis-train
This will create the luxonis_train executable in your PATH.
Use the provided configs/detection_light_model.yaml configuration file
You can download the file by executing the following command:
wget https://raw.githubusercontent.com/luxonis/luxonis-train/main/configs/detection_light_model.yaml
Find a suitable dataset for your task
We will use a sample COCO dataset from RoboFlow in this example.
Start training
luxonis_train train \
  --config detection_light_model.yaml \
  loader.params.dataset_dir "roboflow://team-roboflow/coco-128/2/coco"
Monitor progress with TensorBoard
tensorboard --logdir output/tensorboard_logs
Open the provided URL in your browser to visualize the training progress.
Note
For hands-on examples of how to prepare data with LuxonisML and train AI models using LuxonisTrain, check out this guide.
LuxonisTrain requires Python 3.10 or higher. We recommend using a virtual environment to manage dependencies.
Install via pip:
pip install luxonis-train
This will also install the luxonis_train CLI. For more information on how to use it, see CLI Usage.
You can use LuxonisTrain either from the command line or via the Python API. We will demonstrate both ways in the following sections.
The CLI is the most straightforward way to use LuxonisTrain. It provides several commands for training, testing, tuning, exporting, and more.
Available commands:
train - Start the training process
test - Test the model on a specific dataset view
infer - Run inference on a dataset, image directory, or a video file
export - Export the model to either ONNX or BLOB format that can be run on edge devices
archive - Create an NN Archive file that can be used with our DepthAI API (coming soon)
tune - Tune the hyperparameters of the model for better performance
inspect - Inspect the dataset you are using and visualize the annotations
annotate - Annotate a directory using the model's predictions and generate a new LDF
To get help on any command:
luxonis_train <command> --help
Specific usage examples can be found in the respective sections below.
LuxonisTrain uses YAML configuration files to define the training pipeline. Here's a breakdown of the key sections:
model:
  name: model_name

  # Use a predefined detection model instead of defining
  # the model architecture manually
  predefined_model:
    name: DetectionModel
    params:
      variant: light

# Download and parse the coco dataset from RoboFlow.
# Save it internally as `coco_test` dataset for future reference.
loader:
  params:
    dataset_name: coco_test
    dataset_dir: "roboflow://team-roboflow/coco-128/2/coco"

trainer:
  batch_size: 8
  epochs: 200
  n_workers: 8
  validation_interval: 10

  preprocessing:
    train_image_size: [384, 384]

    # Uses the imagenet normalization by default
    normalize:
      active: true

    # Augmentations are powered by Albumentations
    augmentations:
      - name: Defocus
      - name: Sharpen
      - name: Flip

  callbacks:
    - name: ExportOnTrainEnd
    - name: ArchiveOnTrainEnd
    - name: TestOnTrainEnd

  optimizer:
    name: SGD
    params:
      lr: 0.02

  scheduler:
    name: ConstantLR
Configuration Reference
For a complete reference of all available configuration options, see our Configuration Documentation.
Tip
We provide a set of predefined configuration files for common computer vision tasks in the configs directory. These are great starting points that you can customize for your specific needs.
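LuxonisTrain supports several ways of loading data:
Data directory - a directory with the dataset in one of the supported formats
LuxonisDataset - a dataset in our custom LuxonisDataset format
The easiest way to load data is to use a directory with the dataset in one of the supported formats.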
Supported formats:
COCO - We support COCO JSON format in two variants:
Pascal VOC XML
YOLO Darknet TXT
YOLOv4 PyTorch TXT
MT YOLOv6
CreateML JSON
TensorFlow Object Detection CSV
Classification Directory - A directory with subdirectories for each class
dataset_dir/
├── train/
│   ├── class1/
│   │   ├── img1.jpg
│   │   ├── img2.jpg
│   │   └── ...
│   ├── class2/
│   └── ...
├── valid/
└── test/
Segmentation Mask Directory - A directory with images and corresponding masks.
dataset_dir/
├── train/
│   ├── img1.jpg
│   ├── img1_mask.png
│   ├── ...
│   └── _classes.csv
├── valid/
└── test/
The masks are stored as grayscale PNG images where each pixel value corresponds to a class. The mapping from pixel values to classes is defined in the _classes.csv file.
Pixel Value, Class
0, background
1, class1
2, class2
3, class3
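Before pointing LuxonisTrain at one of the directory layouts above, you may want to sanity-check it. The following standard-library sketch counts images per class in a classification directory, parses _classes.csv into a pixel-value-to-class mapping, and pairs segmentation images with their masks. It is not part of LuxonisTrain: the helper names, the accepted image extensions, and the assumptions that segmentation images are .jpg files and masks are named <image stem>_mask.png (as in the tree above) are illustrative choices. It does not inspect mask pixel values.

```python
import csv
from collections import Counter
from pathlib import Path

# Assumed image extensions; adjust to match your dataset.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(dataset_dir: str) -> dict[str, Counter]:
    """Classification layout: return {split: Counter({class_name: n_images})}."""
    summary: dict[str, Counter] = {}
    for split_dir in sorted(Path(dataset_dir).iterdir()):
        if not split_dir.is_dir():
            continue
        counts: Counter = Counter()
        for class_dir in sorted(split_dir.iterdir()):
            if class_dir.is_dir():
                counts[class_dir.name] = sum(
                    1 for f in class_dir.iterdir()
                    if f.suffix.lower() in IMAGE_EXTENSIONS
                )
        summary[split_dir.name] = counts
    return summary


def read_class_mapping(csv_path: str) -> dict[int, str]:
    """Segmentation layout: parse _classes.csv into {pixel_value: class_name}."""
    mapping: dict[int, str] = {}
    with open(csv_path, newline="") as f:
        reader = csv.reader(f, skipinitialspace=True)
        next(reader)  # skip the "Pixel Value, Class" header row
        for row in reader:
            if row:  # ignore blank lines
                mapping[int(row[0])] = row[1].strip()
    return mapping


def pair_images_with_masks(split_dir: str) -> tuple[dict[str, Path], list[str]]:
    """Segmentation layout: return ({image_name: mask_path}, [images without a mask])."""
    pairs: dict[str, Path] = {}
    missing: list[str] = []
    for image in sorted(Path(split_dir).glob("*.jpg")):
        mask = image.with_name(f"{image.stem}_mask.png")
        if mask.exists():
            pairs[image.name] = mask
        else:
            missing.append(image.name)
    return pairs, missing
```

For example, read_class_mapping("dataset_dir/train/_classes.csv") on the CSV above would return {0: "background", 1: "class1", 2: "class2", 3: "class3"}.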
Set the dataset_dir parameter in the configuration file to point to the dataset directory. The dataset_dir can be one of the following:
Local path - e.g. a "data" directory in the current working directory
s3://bucket/path/to/directory for AWS S3
gs://bucket/path/to/directory for Google Cloud Storage
roboflow://workspace/project/version/format for RoboFlow
  workspace - name of the workspace the dataset belongs to
  project - name of the project the dataset belongs to
  version - version of the dataset
  format - one of coco, darknet, voc, yolov4pytorch, mt-yolov6, createml, tensorflow, folder, png-mask-semantic
  e.g. roboflow://team-roboflow/coco-128/2/coco
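As a rough illustration of the accepted dataset_dir forms, the sketch below classifies a string by its prefix and splits a RoboFlow URL into its four components. It is not LuxonisTrain's own parser; the function name and the returned dictionary layout are hypothetical, and the format list is copied from the list above.

```python
# Formats listed above for roboflow:// URLs.
ROBOFLOW_FORMATS = {
    "coco", "darknet", "voc", "yolov4pytorch", "mt-yolov6",
    "createml", "tensorflow", "folder", "png-mask-semantic",
}


def describe_dataset_dir(dataset_dir: str) -> dict[str, str]:
    """Classify a dataset_dir string (hypothetical helper, for illustration only)."""
    if dataset_dir.startswith("s3://"):
        return {"source": "s3", "path": dataset_dir.removeprefix("s3://")}
    if dataset_dir.startswith("gs://"):
        return {"source": "gcs", "path": dataset_dir.removeprefix("gs://")}
    if dataset_dir.startswith("roboflow://"):
        parts = dataset_dir.removeprefix("roboflow://").split("/")
        if len(parts) != 4:
            raise ValueError("expected roboflow://workspace/project/version/format")
        workspace, project, version, fmt = parts
        if fmt not in ROBOFLOW_FORMATS:
            raise ValueError(f"unsupported RoboFlow format: {fmt!r}")
        return {
            "source": "roboflow",
            "workspace": workspace,
            "project": project,
            "version": version,
            "format": fmt,
        }
    # Anything without a known prefix is treated as a local path.
    return {"source": "local", "path": dataset_dir}
```

For the RoboFlow example above, the helper yields workspace "team-roboflow", project "coco-128", version "2" and format "coco".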
Example:
loader:
  params:
    dataset_name: "coco_test"
    dataset_dir: "roboflow://team-roboflow/coco-128/2/coco"
LuxonisDataset is our custom dataset format designed for easy and efficient dataset management. To learn more about how to create a dataset in this format from scratch, see the Luxonis ML repository.
To use the LuxonisDataset as a source of the data, specify the following in the config file:
loader:
  params:
    # name of the dataset
    dataset_name: "dataset_name"

    # one of local (default), s3, gcs
    bucket_storage: "local"
Tip
To inspect the loader output, use the luxonis_train inspect command:
luxonis_train inspect --config configs/detection_light_model.yaml
The inspect command is currently only available in the CLI.
For additional information about the shapes of Luxonis ML data that the loader returns, please refer to the Loaders README.
Once your configuration file and dataset are ready, start the training process.
CLI:
luxonis_train train --config configs/detection_light_model.yaml
Tip
To change a configuration parameter from the command line, use the following syntax:
luxonis_train train \
  --config configs/detection_light_model.yaml \
  loader.params.dataset_dir "roboflow://team-roboflow/coco-128/2/coco"
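Both the command-line pairs and the dictionary passed to LuxonisModel use dotted keys such as loader.params.dataset_dir. The plain-Python sketch below shows how such a key addresses a nested configuration section. It illustrates the key syntax only and is not LuxonisTrain's implementation: apply_override and apply_cli_pairs are hypothetical helpers, values are stored exactly as given (no type parsing), and list indices are not handled.

```python
from typing import Any


def apply_override(config: dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set config[a][b][c] = value for dotted_key "a.b.c", creating sections as needed."""
    *parents, leaf = dotted_key.split(".")
    section = config
    for key in parents:
        section = section.setdefault(key, {})
    section[leaf] = value


def apply_cli_pairs(config: dict[str, Any], args: list[str]) -> None:
    """Apply alternating "key value" command-line arguments to config."""
    if len(args) % 2:
        raise ValueError("overrides must come in key/value pairs")
    for key, value in zip(args[::2], args[1::2]):
        apply_override(config, key, value)
```

Calling apply_cli_pairs(cfg, ["loader.params.dataset_dir", "roboflow://team-roboflow/coco-128/2/coco"]) sets cfg["loader"]["params"]["dataset_dir"] while leaving sibling keys such as dataset_name untouched.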
Python API:
from luxonis_train import LuxonisModel

model = LuxonisModel(
    "configs/detection_light_model.yaml",
    {"loader.params.dataset_dir": "roboflow://team-roboflow/coco-128/2/coco"},
)
model.train()
Expected Output:
INFO Using predefined model: `DetectionModel`
INFO Main metric: `MeanAveragePrecision`
INFO GPU available: True (cuda), used: True
INFO TPU available: False, using: 0 TPU cores
INFO HPU available: False, using: 0 HPUs
...
INFO Training finished
INFO Checkpoints saved in: output/1-coral-wren
Monitoring with TensorBoard:
If not explicitly disabled, the training process will be monitored by TensorBoard. To start the TensorBoard server, run:
tensorboard --logdir output/tensorboard_logs
Open the provided URL to visualize training metrics.
Evaluate your trained model on a specific dataset view (train, val, or test).
CLI:
luxonis_train test --config configs/detection_light_model.yaml \
  --view val \
  --weights path/to/checkpoint.ckpt
Python API:
from luxonis_train import LuxonisModel

model = LuxonisModel("configs/detection_light_model.yaml")
model.test(weights="path/to/checkpoint.ckpt")
The testing process can be started automatically at the end of the training by using the TestOnTrainEnd callback. To learn more about callbacks, see Callbacks.
Run inference on images, datasets, or videos.
CLI:
# infer on a dataset view
luxonis_train infer --config configs/detection_light_model.yaml \
  --view val \
  --weights path/to/checkpoint.ckpt

# infer on a video file
luxonis_train infer --config configs/detection_light_model.yaml \
  --weights path/to/checkpoint.ckpt \
  --source-path path/to/video.mp4

# infer on an image directory and save the results
luxonis_train infer --config configs/detection_light_model.yaml \
  --weights path/to/checkpoint.ckpt \
  --source-path path/to/images \
  --save-dir path/to/save_directory
Python API:
from luxonis_train import LuxonisModel

model = LuxonisModel("configs/detection_light_model.yaml")

# infer on a dataset view
model.infer(weights="path/to/checkpoint.ckpt", view="val")

# infer on a video file
model.infer(weights="path/to/checkpoint.ckpt", source_path="path/to/video.mp4")

# infer on an image directory and save the results
model.infer(
    weights="path/to/checkpoint.ckpt",
    source_path="path/to/images",
    save_dir="path/to/save_directory",
)
Export your trained models to formats suitable for deployment on edge devices.
Supported formats: ONNX and BLOB.
To configure the exporter, specify the exporter section in the config file.
You can see an example export configuration here.
CLI:
luxonis_train export --config configs/example_export.yaml --weights path/to/weights.ckpt
Python API:
from luxonis_train import LuxonisModel

model = LuxonisModel("configs/example_export.yaml")
model.export(weights="path/to/weights.ckpt")
Model export can be run automatically at the end of the training by using the ExportOnTrainEnd callback.
The exported models are saved in the export directory within your output folder.
Create an NN Archive file for easy deployment with the DepthAI API.
The archive contains the exported model together with all the metadata needed for running the model.
CLI:
luxonis_train archive \
  --config configs/detection_light_model.yaml \
  --weights path/to/checkpoint.ckpt
Python API:
from luxonis_train import LuxonisModel

model = LuxonisModel("configs/detection_light_model.yaml")
model.archive(weights="path/to/checkpoint.ckpt")
The archive can be created automatically at the end of the training by using the ArchiveOnTrainEnd callback.
Optimize your model's performance using hyperparameter tuning powered by Optuna.
Configuration:
Include a tuner section in your configuration file.
tuner:
  study_name: det_study
  n_trials: 10
  storage:
    backend: sqlite
  params:
    trainer.optimizer.name_categorical: ["Adam", "SGD"]
    trainer.optimizer.params.lr_float: [0.0001, 0.001]
    trainer.batch_size_int: [4, 16, 4]
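Each key under tuner.params combines a config path with a suffix naming the kind of search space. The sketch below splits such keys and, for the _int case, enumerates the candidate values. It assumes, by analogy with Optuna's trial.suggest_int(name, low, high, step=...), that [4, 16, 4] means low, high (inclusive) and step, and that _float lists [low, high]; check the Configuration Documentation for the exact semantics. The helper names are hypothetical and nothing here calls Optuna.

```python
from typing import Any

# Suffixes used in the example tuner section above.
SUFFIXES = ("_categorical", "_float", "_int")


def split_tuner_key(key: str) -> tuple[str, str]:
    """Split e.g. 'trainer.batch_size_int' into ('trainer.batch_size', 'int')."""
    for suffix in SUFFIXES:
        if key.endswith(suffix):
            return key[: -len(suffix)], suffix[1:]
    raise ValueError(f"unknown search-space suffix in {key!r}")


def int_candidates(spec: list[int]) -> list[int]:
    """Values of an int space, assuming spec = [low, high, step] with high inclusive."""
    low, high, step = spec
    return list(range(low, high + 1, step))


def describe_search_space(params: dict[str, Any]) -> dict[str, tuple[str, Any]]:
    """Map each config path to (space type, raw specification from the YAML)."""
    space: dict[str, tuple[str, Any]] = {}
    for key, spec in params.items():
        path, kind = split_tuner_key(key)
        space[path] = (kind, spec)
    return space
```

Under these assumptions, trainer.batch_size in the example above would be drawn from 4, 8, 12 and 16.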
CLI:
luxonis_train tune --config configs/example_tuning.yaml
Python API:
from luxonis_train import LuxonisModel

model = LuxonisModel("configs/example_tuning.yaml")
model.tune()
LuxonisTrain is highly modular, allowing you to customize various components.
Creating Custom Components:
Implement custom components by subclassing the respective base classes and, where required, registering them. Registered components can be referenced by name in the config file. The base classes are:
BaseLoaderTorch
BaseNode
BaseLoss
BaseMetric
BaseVisualizer
lightning.pytorch.callbacks.Callback - requires manual registration to the CALLBACKS registry
torch.optim.Optimizer - requires manual registration to the OPTIMIZERS registry
torch.optim.lr_scheduler.LRScheduler - requires manual registration to the SCHEDULERS registry
BaseTrainingStrategy
Examples:
Custom Callback:
import lightning.pytorch as pl

from luxonis_train import LuxonisLightningModule
from luxonis_train.registry import CALLBACKS


@CALLBACKS.register()
class CustomCallback(pl.Callback):
    def __init__(self, message: str, **kwargs):
        super().__init__(**kwargs)
        self.message = message

    # Will be called at the end of each training epoch.
    # Consult the PyTorch Lightning documentation for more callback methods.
    def on_train_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: LuxonisLightningModule,
    ) -> None:
        print(self.message)
Custom Loss:
from torch import Tensor

from luxonis_train import BaseLoss, Tasks


# Subclasses of `BaseNode`, `BaseLoss`, `BaseMetric`
# and `BaseVisualizer` are registered automatically.
class CustomLoss(BaseLoss):
    supported_tasks = [Tasks.CLASSIFICATION, Tasks.SEGMENTATION]

    def __init__(self, smoothing: float, **kwargs):
        super().__init__(**kwargs)
        self.smoothing = smoothing

    def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
        # Implement the actual loss logic here
        value = predictions.sum() * self.smoothing
        return value.abs()
For additional examples of creating custom components, please refer to the examples section.
Using custom components in the configuration file:
model:
  nodes:
    - name: SegmentationHead
      losses:
        - name: CustomLoss
          params:
            smoothing: 0.0001

trainer:
  callbacks:
    - name: CustomCallback
      params:
        message: "Hello from the custom callback!"
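Each name/params entry above is resolved by looking the name up in a registry of component classes and passing params as keyword arguments. The plain-Python sketch below shows that general registry pattern; it is a simplified stand-in, not the luxonis_train.registry implementation, and Registry, build, and the stripped-down CustomCallback are hypothetical (registration by class name is an assumption).

```python
from typing import Any, Callable


class Registry:
    """Minimal name -> class registry (simplified illustration)."""

    def __init__(self) -> None:
        self._classes: dict[str, type] = {}

    def register(self) -> Callable[[type], type]:
        def decorator(cls: type) -> type:
            self._classes[cls.__name__] = cls  # register under the class name
            return cls
        return decorator

    def build(self, entry: dict[str, Any]) -> Any:
        """Instantiate from a config entry such as {"name": ..., "params": {...}}."""
        cls = self._classes[entry["name"]]
        return cls(**entry.get("params", {}))


CALLBACKS = Registry()


@CALLBACKS.register()
class CustomCallback:
    def __init__(self, message: str) -> None:
        self.message = message
```

Because params are forwarded as keyword arguments, the keys under params must match the component's constructor arguments; an unknown key raises a TypeError when the component is built.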
Note
Files containing the custom components must be sourced before the training script is run. To do that in CLI, you can use the --source argument:
luxonis_train --source custom_components.py train --config config.yaml
Python API:
You have to import the custom components before creating the LuxonisModel instance.
from custom_components import *
from luxonis_train import LuxonisModel

model = LuxonisModel("config.yaml")
model.train()
For more information on how to define custom components, consult the respective in-source documentation.
Tutorials and Examples
We are actively working on providing examples and tutorials for different parts of the library to help you get started more easily. The tutorials can be found here and will be updated regularly.
When using cloud services, avoid hard-coding credentials or placing them directly in your configuration files. Instead, store them in environment variables, or keep them in a .env file and load it securely, ensuring it's excluded from version control.
Supported Cloud Services:
AWS_ACCESS_KEY_ID
AWS_SECRET_ACCESS_KEY
AWS_S3_ENDPOINT_URL
GOOGLE_APPLICATION_CREDENTIALS
ROBOFLOW_API_KEY
For logging and tracking, we support:
MLFLOW_S3_BUCKET
MLFLOW_S3_ENDPOINT_URL
MLFLOW_TRACKING_URI
WANDB_API_KEY
For remote database storage, we support:
POSTGRES_PASSWORD
POSTGRES_HOST
POSTGRES_PORT
POSTGRES_DB
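If you keep these variables in a .env file, it can be loaded without extra dependencies. The sketch below is a deliberately simple, hypothetical parser (KEY=VALUE lines, # comments, one optional pair of surrounding quotes) that does not overwrite variables already set in the environment; it is independent of LuxonisTrain, and a dedicated package such as python-dotenv covers more edge cases.

```python
import os
from pathlib import Path


def load_env_file(path: str = ".env") -> dict[str, str]:
    """Load KEY=VALUE lines into os.environ without overriding existing values.

    Returns only the variables that were actually set by this call.
    """
    loaded: dict[str, str] = {}
    for raw_line in Path(path).read_text().splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue  # skip blank lines, comments and malformed lines
        key, value = line.split("=", 1)
        key, value = key.strip(), value.strip()
        # Strip one pair of matching surrounding quotes, if present.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        if key not in os.environ:
            os.environ[key] = value
            loaded[key] = value
    return loaded
```

Call load_env_file() before creating the LuxonisModel instance so that variables such as ROBOFLOW_API_KEY are visible to the process.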
We welcome contributions! Please read our Contribution Guide to get started. Whether it's reporting bugs, improving documentation, or adding new features, your help is appreciated.