This example illustrates DeepJDOT, an Optimal Transport deep domain adaptation (DA) method, on a simple image classification task: adapting a classifier from MNIST (source domain) to USPS (target domain).
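As background on the optimal transport part, here is a minimal numpy sketch of discrete OT between two small point clouds. It is not skada code. It assumes uniform weights and equal sample sizes. In that case some optimal transport plan is a permutation matrix scaled by 1/n, so for a handful of points every permutation can simply be tried. The toy data and the helper `exact_ot_uniform` are hypothetical. Real code would call a dedicated solver, such as POT's `ot.emd`.

```python
import itertools

import numpy as np


def exact_ot_uniform(cost: np.ndarray) -> np.ndarray:
    """Hypothetical helper: exact OT plan between two uniform empirical
    measures of equal size n.

    With weights 1/n on both sides, an optimal plan can be taken to be a
    permutation matrix scaled by 1/n, so tiny problems can be brute-forced.
    """
    n, m = cost.shape
    if n != m:
        raise ValueError("this sketch assumes equal sample sizes")
    rows = np.arange(n)
    # Pick the permutation (source i -> target p[i]) with the lowest total cost.
    best = min(
        itertools.permutations(range(n)),
        key=lambda p: cost[rows, list(p)].sum(),
    )
    plan = np.zeros((n, n))
    plan[rows, list(best)] = 1.0 / n
    return plan


# Hypothetical 1-D "source" and "target" samples.
xs = np.array([[0.0], [1.0], [2.0]])
xt = np.array([[2.1], [0.2], [1.1]])

# Squared Euclidean ground cost between every source/target pair.
cost = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)
plan = exact_ot_uniform(cost)
ot_cost = (plan * cost).sum()  # transport cost under the optimal plan

print(plan * len(xs))  # row i has a 1 in the column of the target matched to source i
print(ot_cost)
```

In this 1-D example the plan matches the sorted source points to the sorted target points. For a convex cost on the real line, this monotone coupling is known to be optimal.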
```python
# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4

from skada.datasets import load_mnist_usps
from skada.deep import DeepJDOT
from skada.deep.modules import MNISTtoUSPSNet
```

Load the image datasets

```python
dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
```
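The arrays returned by `pack_train` contain samples from both domains, and `sample_domain` tells them apart. The sketch below uses hypothetical toy arrays rather than the real MNIST/USPS data. It also assumes skada's convention for these arrays. Under that convention, positive `sample_domain` values mark source samples and negative values mark target samples. Target labels in the training pack are masked with `-1`.

```python
import numpy as np

# Hypothetical stand-ins for the output of pack_train: three labelled source
# samples (domain id 1) followed by two target samples (domain id -2) whose
# labels are masked with -1.
X = np.arange(10, dtype=float).reshape(5, 2)
y = np.array([0, 1, 1, -1, -1])
sample_domain = np.array([1, 1, 1, -2, -2])

# Assumed convention: positive ids are source domains, negative ids are targets.
is_source = sample_domain > 0
is_target = sample_domain < 0

X_source, y_source = X[is_source], y[is_source]  # supervised part
X_target = X[is_target]  # unlabelled part, only usable for adaptation

print(X_source.shape, X_target.shape, y_source)
```

Under this convention, a source-only model would be fit only on the rows where `is_source` is true. A DA method such as DeepJDOT also consumes the unlabelled target rows.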
Train a classic model
```text
  epoch    train_loss      dur
-------  ------------  -------
      1        1.6460   6.1998
      2        0.4327   6.4993
      3        0.1498   6.6010
      4        0.0746  12.6988
      5        0.0511  13.5012
0.8938906752411575
```

Train a DeepJDOT model
```python
model = DeepJDOT(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    batch_size=128,
    max_epochs=5,
    train_split=False,
    reg_dist=0.1,
    reg_cl=0.01,
    lr=1e-2,
)
model.fit(X, y, sample_domain=sample_domain)
model.score(X_test, y_test, sample_domain=sample_domain_test)
```
```text
  epoch    train_loss      dur
-------  ------------  -------
      1        2.1794  38.5859
      2        1.3292  13.6001
      3        0.8222  10.7061
      4        0.6452  10.4921
      5        0.5480  10.5015
0.9389067524115756
```
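The `layer_name`, `reg_dist` and `reg_cl` arguments configure the alignment term that DeepJDOT adds to the source classification loss. The numpy sketch below is not skada's implementation. It follows the original DeepJDOT formulation and makes three assumptions:

- `layer_name` selects the layer whose output serves as the embedding.
- The ground cost between source sample i and target sample j is `reg_dist` times the squared embedding distance, plus `reg_cl` times the cross-entropy of target j's prediction against source label i.
- The plan is an exact OT plan on each mini-batch, with uniform weights.

The embeddings, labels and predicted probabilities are random placeholders. `deepjdot_cost` is a hypothetical helper.

```python
import itertools

import numpy as np


def deepjdot_cost(emb_s, y_s, emb_t, proba_t, reg_dist, reg_cl):
    """Hypothetical helper: DeepJDOT-style joint ground cost.

    C[i, j] = reg_dist * ||emb_s[i] - emb_t[j]||^2
              + reg_cl * CE(y_s[i], proba_t[j])
    """
    feat_dist = ((emb_s[:, None, :] - emb_t[None, :, :]) ** 2).sum(axis=-1)
    # proba_t[:, y_s].T[i, j] is the probability target j assigns to label y_s[i].
    label_loss = -np.log(proba_t[:, y_s].T)
    return reg_dist * feat_dist + reg_cl * label_loss


rng = np.random.default_rng(0)
n, dim, n_classes = 4, 3, 2
emb_s = rng.normal(size=(n, dim))  # placeholder "fc1" features of a source batch
emb_t = rng.normal(size=(n, dim))  # placeholder "fc1" features of a target batch
y_s = np.array([0, 1, 0, 1])  # source labels
logits_t = rng.normal(size=(n, n_classes))
proba_t = np.exp(logits_t) / np.exp(logits_t).sum(axis=1, keepdims=True)  # softmax

cost = deepjdot_cost(emb_s, y_s, emb_t, proba_t, reg_dist=0.1, reg_cl=0.01)

# Exact OT plan for uniform weights and equal batch sizes: best permutation / n.
rows = np.arange(n)
best = min(
    itertools.permutations(range(n)),
    key=lambda p: cost[rows, list(p)].sum(),
)
plan = np.zeros((n, n))
plan[rows, list(best)] = 1.0 / n

alignment_loss = (plan * cost).sum()  # added to the source classification loss
print(alignment_loss)
```

In the original method, each mini-batch step has two phases. First the plan is computed with the network held fixed. Then the plan is treated as a constant while the network parameters are updated on `alignment_loss` plus the source classification loss.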
Total running time of the script: (2 minutes 14.978 seconds)