This example illustrates the use of deep domain adaptation (DA) methods in Skada on a simple image classification task.
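# Author: Théo Gnassounou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4

# Imports used below (module paths assumed from the skada and skorch package layout)
import torch
from skorch.dataset import Dataset

from skada.datasets import load_mnist_usps
from skada.deep import DeepCoral, DeepCoralLoss
from skada.deep.base import (
    DomainAwareCriterion,
    DomainAwareModule,
    DomainBalancedDataLoader,
)
from skada.deep.modules import MNISTtoUSPSNet

Load the image datasets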
dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
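To make the role of `sample_domain` more concrete, here is a minimal sketch on toy arrays. It assumes skada's documented convention that source samples carry positive domain labels and target samples negative ones, and that target labels are masked (shown here as -1) in the training pack. The arrays and their values are made up for illustration and are not the output of `pack_train`.

```python
import numpy as np

# Toy stand-ins for the arrays returned by pack_train (values are invented).
sample_domain = np.array([1, 1, 1, -2, -2])  # assumed: >0 source, <0 target
y = np.array([0, 1, 0, -1, -1])              # assumed: target labels masked as -1

# Split the samples by the sign of their domain label.
source_mask = sample_domain > 0
target_mask = sample_domain < 0

n_source = int(source_mask.sum())
n_target = int(target_mask.sum())
source_labels = y[source_mask]

print("source samples:", n_source)
print("target samples:", n_target)
print("source labels:", source_labels.tolist())
```

Under that assumption, a deep DA method sees labels for the source domain only and must use the unlabeled target samples through the adaptation loss.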
/home/circleci/project/skada/datasets/_mnist_usps.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  mnist_target = torch.tensor(mnist_dataset.targets)

Training parameters

Training with skorch
model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    batch_size=batch_size,
    max_epochs=max_epochs,
    train_split=False,
    reg=reg,
    lr=lr,
    device=device,
)
model.fit(X, y, sample_domain=sample_domain)
  epoch    train_loss      dur
-------  ------------  -------
      1        2.2764  10.0830
      2        2.1936   9.4029

<class 'skada.deep.base.DomainAwareNet'>[initialized](
  module_=DomainAwareModule(
    (base_module_): MNISTtoUSPSNet(
      (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu2): ReLU()
      (dropout1): Dropout(p=0.25, inplace=False)
      (dropout2): Dropout(p=0.5, inplace=False)
      (fc1): Linear(in_features=9216, out_features=128, bias=True)
      (relu3): ReLU()
      (fc2): Linear(in_features=128, out_features=10, bias=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  ),
)

Training with skorch with dataset
X_dict = {"X": torch.tensor(X), "sample_domain": torch.tensor(sample_domain)}

# TODO create a dataset also without skorch
dataset = Dataset(X_dict, torch.tensor(y))

model = DeepCoral(
    MNISTtoUSPSNet(),
    layer_name="fc1",
    batch_size=batch_size,
    max_epochs=max_epochs,
    train_split=False,
    reg=reg,
    lr=lr,
    device=device,
)
model.fit(dataset, y=None, sample_domain=None)
  epoch    train_loss     dur
-------  ------------  ------
      1        2.2715  8.8677
      2        2.1959  9.2985

<class 'skada.deep.base.DomainAwareNet'>[initialized](
  module_=DomainAwareModule(
    (base_module_): MNISTtoUSPSNet(
      (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu2): ReLU()
      (dropout1): Dropout(p=0.25, inplace=False)
      (dropout2): Dropout(p=0.5, inplace=False)
      (fc1): Linear(in_features=9216, out_features=128, bias=True)
      (relu3): ReLU()
      (fc2): Linear(in_features=128, out_features=10, bias=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  ),
)

Training with torch
model = DomainAwareModule(MNISTtoUSPSNet(), layer_name="fc1").to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
dataloader = DomainBalancedDataLoader(dataset, batch_size=batch_size)
loss_fn = DomainAwareCriterion(torch.nn.CrossEntropyLoss(), DeepCoralLoss(), reg=reg)

# Training loop
for epoch in range(max_epochs):
    model.train()
    running_loss = 0.0
    n_batches = 0
    for inputs, labels in dataloader:
        # Move the batch to the model's device; `inputs` is a dict of tensors
        # (it is unpacked as keyword arguments in the forward pass below)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(**inputs, is_fit=True)
        loss = loss_fn(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Report the mean training loss over the epoch
    print("Loss:", running_loss / n_batches)
Loss: 0.8874167464673519
Loss: 0.18708933144807816
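For intuition about the adaptation term that `DomainAwareCriterion` combines with the cross-entropy loss, the following is a rough NumPy sketch of the CORAL loss as defined in the Deep CORAL paper: the squared Frobenius distance between the source and target feature covariances, scaled by 1/(4 d^2). It is not skada's `DeepCoralLoss`, which operates on torch tensors and may normalize differently. The function name `coral_loss` and the random features are hypothetical stand-ins for the activations of the `fc1` layer.

```python
import numpy as np


def coral_loss(features_s, features_t):
    """Squared Frobenius distance between feature covariances, scaled by 1/(4 d^2)."""
    d = features_s.shape[1]
    cov_s = np.cov(features_s, rowvar=False)  # (d, d) source covariance
    cov_t = np.cov(features_t, rowvar=False)  # (d, d) target covariance
    return float(np.sum((cov_s - cov_t) ** 2) / (4 * d**2))


rng = np.random.default_rng(0)
source = rng.normal(size=(200, 8))              # hypothetical source features
target_same = rng.normal(size=(200, 8))         # same distribution as the source
target_scaled = 3.0 * rng.normal(size=(200, 8)) # different second-order statistics

loss_same = coral_loss(source, target_same)
loss_shifted = coral_loss(source, target_scaled)

print("matching statistics:", loss_same)
print("shifted statistics:", loss_shifted)
```

The loss is zero when both domains have identical covariances and grows as their second-order statistics diverge. Assuming `reg` plays the usual role of a trade-off weight, the training objective would then be the classification loss plus `reg` times this term.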
Total running time of the script: (1 minutes 1.895 seconds)