This example illustrates the DeepCoral method from [1] on a simple image classification task.
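# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4

from skada.datasets import load_mnist_usps
from skada.deep import DeepCoral
from skada.deep.modules import MNISTtoUSPSNet

Load the image datasets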
dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
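The packed arrays hold samples from both domains, so it helps to see how they can be told apart. The following is a minimal sketch, assuming skada's convention that `sample_domain` is positive for source samples and negative for target samples, and that `pack_train` masks target labels with -1. The toy lists and the helper `split_by_domain` are hypothetical and are not part of skada.

```python
# Hypothetical toy data mimicking what pack_train is assumed to return:
# positive domain ids for the source (MNIST), negative ids for the target
# (USPS), and target labels masked with -1.
y = [0, 1, 1, -1, -1]
sample_domain = [1, 1, 1, -2, -2]


def split_by_domain(values, sample_domain):
    """Split values into (source, target) lists using the sign of sample_domain."""
    source = [v for v, d in zip(values, sample_domain) if d > 0]
    target = [v for v, d in zip(values, sample_domain) if d < 0]
    return source, target


y_source, y_target = split_by_domain(y, sample_domain)
print(y_source)  # [0, 1, 1]
print(y_target)  # [-1, -1]
```

Under this assumption, only the source labels carry supervision during training, while the target samples contribute their features alone.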
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

/home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  mnist_target = torch.tensor(mnist_dataset.targets)
Downloading https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2 to ./datasets/usps.t.bz2

Train a classic model
  epoch    train_loss     dur
-------  ------------  ------
      1        1.4900  9.5993
      2        0.3168  8.1940
      3        0.1119  7.3040
      4        0.0610  6.7043
      5        0.0458  8.6105

0.8906752411575563

Train a DeepCoral model
model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    batch_size=128,
    max_epochs=5,
    train_split=False,
    reg=1,
    lr=1e-2,
)
model.fit(X, y, sample_domain=sample_domain)
model.score(X_test, y_test, sample_domain=sample_domain_test)
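To make the roles of `layer_name` and `reg` more concrete, here is a minimal pure-Python sketch of the CORAL penalty. It assumes the definition from the Deep CORAL paper (Sun and Saenko, 2016): the squared Frobenius distance between the source and target feature covariances, scaled by 1/(4 d^2), added to the classification loss with weight `reg`. The helper names are hypothetical, the features stand for activations taken at the chosen layer, and skada's own implementation may normalize differently.

```python
def covariance(features):
    """Unbiased covariance matrix of row-wise feature vectors (needs >= 2 rows)."""
    n, d = len(features), len(features[0])
    means = [sum(row[j] for row in features) / n for j in range(d)]
    centered = [[row[j] - means[j] for j in range(d)] for row in features]
    return [
        [sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
        for i in range(d)
    ]


def coral_loss(source_features, target_features):
    """Squared Frobenius distance between covariances, scaled by 1 / (4 d^2)."""
    d = len(source_features[0])
    cov_s = covariance(source_features)
    cov_t = covariance(target_features)
    sq_frobenius = sum(
        (cov_s[i][j] - cov_t[i][j]) ** 2 for i in range(d) for j in range(d)
    )
    return sq_frobenius / (4 * d * d)


def total_loss(classification_loss, source_features, target_features, reg=1.0):
    """Assumed combined objective: classification loss + reg * CORAL penalty."""
    return classification_loss + reg * coral_loss(source_features, target_features)


# Toy 2-D features standing in for activations of the adapted layer.
source = [[0.0, 0.0], [2.0, 2.0]]
target = [[0.0, 0.0], [4.0, 4.0]]

print(coral_loss(source, source))       # 0.0 (identical statistics)
print(coral_loss(source, target))       # 9.0
print(total_loss(0.5, source, target))  # 9.5
```

With `reg=0` the penalty vanishes and only the classification loss on the source labels remains; larger values push the two feature distributions toward matching covariances.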
  epoch    train_loss      dur
-------  ------------  -------
      1        1.6827  30.1814
      2        0.4499  17.5877
      3        0.1586  14.8993
      4        0.0891  13.2113
      5        0.0644  12.6976

0.8938906752411575
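The two target accuracies printed in this example (0.8907 for the classic model, 0.8939 for DeepCoral) are very close. The sketch below recovers the simplest fractions behind them to show how small the gap is. It assumes the reduced test-set size is below 1000; the true number of USPS test samples is not stated here and could be a multiple of the recovered denominator.

```python
from fractions import Fraction

# Scores copied from the outputs of the two models in this example.
source_only_score = 0.8906752411575563
deepcoral_score = 0.8938906752411575

# Recover the simplest fraction behind each accuracy
# (assumption: denominator below 1000 after reduction).
source_only_frac = Fraction(source_only_score).limit_denominator(1000)
deepcoral_frac = Fraction(deepcoral_score).limit_denominator(1000)

print(source_only_frac)                   # 277/311
print(deepcoral_frac)                     # 278/311
print(deepcoral_frac - source_only_frac)  # 1/311
```

If the test split indeed holds 311 samples, the gain corresponds to a single additional correct prediction, so run-to-run variation may well account for it in this short five-epoch setting.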
Total running time of the script: (2 minutes 22.076 seconds)