This example illustrates the use of the MultiLinearMongeAlignmentAdapter
# Author: Remi Flamary
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4

Generate concept drift classification dataset and plot it
We generate a simple 2D concept drift dataset.
# Settings for the toy concept-drift problem (one source and one target domain).
dataset_options = dict(
    n_samples_source=20,
    n_samples_target=20,
    shift="concept_drift",
    noise=0.2,
    label="multiclass",
    random_state=42,
)
X, y, sample_domain = make_shifted_datasets(**dataset_options)

# Separate the samples of the source domain from those of the target domain.
Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

plt.figure(5, (10, 5))

# Left panel: source samples, coloured by class.
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
# Remember the axis limits so that the target panel uses the same view.
ax = plt.axis()

# Right panel: target samples, drawn with the limits of the source panel.
plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)
(np.float64(-2.3676169789533272), np.float64(4.53817259426296), np.float64(-1.7964573229829714), np.float64(4.336863129384494))

Train a classifier on source data
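Because of the shift, a classifier trained on the source domain performs much worse on the target domain. To correct for this, we fit a MultiLinearMongeAlignmentAdapter on the data from all domains and use it to transform the samples of every domain (source included). We also plot the source data, the target data, and the adapted data.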
# Fit the alignment adapter on every domain, then map all samples
# (source domains included, hence allow_source=True).
clf = MultiLinearMongeAlignmentAdapter()
clf.fit(X, sample_domain=sample_domain)
X_adapt = clf.transform(X, sample_domain=sample_domain, allow_source=True)

plt.figure(5, (10, 3))

# Panel 1: raw source samples.
plt.subplot(1, 3, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

# Panel 2: raw target samples, with the same axis limits as panel 1.
plt.subplot(1, 3, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)

# Panel 3: adapted samples. Domains with a non-negative label are drawn as
# the "Source" group and domains with a negative label as the "Target" group.
plt.subplot(1, 3, 3)
for domain_mask, marker, group_label, opacity in (
    (sample_domain >= 0, "o", "Source", 0.5),
    (sample_domain < 0, "x", "Target", 1),
):
    plt.scatter(
        X_adapt[domain_mask, 0],
        X_adapt[domain_mask, 1],
        c=y[domain_mask],
        marker=marker,
        cmap="tab10",
        vmax=9,
        label=group_label,
        alpha=opacity,
    )
plt.legend()
plt.title("Adapted data")
Text(0.5, 1.0, 'Adapted data')

Train a classifier on adapted data
Average accuracy on all domains: 0.9875
def get_multidomain_data(
    n_samples_source=100,
    n_samples_target=100,
    noise=0.1,
    random_state=None,
    n_sources=3,
    n_targets=2,
):
    """Build a 2D concept-drift dataset with several source and target domains.

    A base dataset is generated with ``make_shifted_datasets``. Then
    ``n_sources - 1`` extra source domains and ``n_targets - 1`` extra target
    domains are appended. Each extra domain is the *target* part of a fresh
    ``make_shifted_datasets`` call made with a random ``mean`` and ``sigma``.

    Parameters
    ----------
    n_samples_source, n_samples_target, noise :
        Forwarded unchanged to every ``make_shifted_datasets`` call.
    random_state : int or None, default=None
        Seed for the global NumPy RNG (used to draw ``mean`` and ``sigma``)
        and base seed for the generated datasets. With ``None`` nothing is
        made reproducible and every dataset is generated unseeded.
    n_sources : int, default=3
        Total number of source domains, counting the base one.
    n_targets : int, default=2
        Total number of target domains, counting the base one.

    Returns
    -------
    X, y, sample_domain :
        Stacked samples, labels and per-sample domain labels. Extra source
        domains are labelled ``2, 3, ...`` and extra target domains
        ``-1, -2, ...``.

    Notes
    -----
    NOTE(review): the domain labels of the base dataset come from
    ``make_shifted_datasets`` and are not visible here. Confirm that the
    extra labels cannot collide with them for the values of ``n_sources`` /
    ``n_targets`` in use.
    """
    # The global NumPy RNG provides the random mean/sigma of the extra domains.
    np.random.seed(random_state)

    def _derived_seed(offset):
        # ``None + int`` raises TypeError, so an unseeded call (the default)
        # must stay unseeded instead of being offset.
        return None if random_state is None else random_state + offset

    X, y, sample_domain = make_shifted_datasets(
        n_samples_source=n_samples_source,
        n_samples_target=n_samples_target,
        noise=noise,
        shift="concept_drift",
        label="multiclass",
        random_state=random_state,
    )

    # Extra source domains: positive labels 2, 3, ...
    for ns in range(n_sources - 1):
        Xi, yi, sample_domaini = make_shifted_datasets(
            n_samples_source=n_samples_source,
            n_samples_target=n_samples_target,
            noise=noise,
            shift="concept_drift",
            label="multiclass",
            random_state=_derived_seed(ns),
            mean=np.random.randn(2),
            sigma=np.random.rand(2) * 0.5 + 0.5,
        )
        # Only the target half of the generated pair is kept as the new domain.
        _, Xt, _, yt = source_target_split(Xi, yi, sample_domain=sample_domaini)
        X = np.vstack([X, Xt])
        y = np.hstack([y, yt])
        sample_domain = np.hstack([sample_domain, np.ones(Xt.shape[0]) * (ns + 2)])

    # Extra target domains: negative labels -1, -2, ...
    for nt in range(n_targets - 1):
        Xi, yi, sample_domaini = make_shifted_datasets(
            n_samples_source=n_samples_source,
            n_samples_target=n_samples_target,
            noise=noise,
            shift="concept_drift",
            label="multiclass",
            # The +42 offset keeps these seeds apart from the source-domain
            # seeds for small domain counts.
            random_state=_derived_seed(nt + 42),
            mean=np.random.randn(2),
            sigma=np.random.rand(2) * 0.5 + 0.5,
        )
        _, Xt, _, yt = source_target_split(Xi, yi, sample_domain=sample_domaini)
        X = np.vstack([X, Xt])
        y = np.hstack([y, yt])
        sample_domain = np.hstack([sample_domain, -np.ones(Xt.shape[0]) * (nt + 1)])

    return X, y, sample_domain


# Three source domains and two target domains, 50 samples each.
X, y, sample_domain = get_multidomain_data(
    n_samples_source=50,
    n_samples_target=50,
    noise=0.1,
    random_state=43,
    n_sources=3,
    n_targets=2,
)

Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

plt.figure(5, (10, 5))

# Left panel: samples of all source domains.
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
# Keep the axis limits so that both panels share the same view.
ax = plt.axis()

# Right panel: samples of all target domains.
plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target domains")
plt.axis(ax)
(np.float64(-2.310098338155625), np.float64(4.756925382279493), np.float64(-2.1443686989830857), np.float64(4.464886123797522))
# Fit the alignment adapter on the multi-domain data, then map all samples
# (source domains included, hence allow_source=True).
clf = MultiLinearMongeAlignmentAdapter()
clf.fit(X, sample_domain=sample_domain)
X_adapt = clf.transform(X, sample_domain=sample_domain, allow_source=True)

plt.figure(5, (10, 3))

# Panel 1: raw source samples.
plt.subplot(1, 3, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

# Panel 2: raw target samples, with the same axis limits as panel 1.
plt.subplot(1, 3, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)

# Panel 3: adapted samples. Domains with a non-negative label are drawn as
# the "Source" group and domains with a negative label as the "Target" group.
plt.subplot(1, 3, 3)
for domain_mask, marker, group_label, opacity in (
    (sample_domain >= 0, "o", "Source", 0.5),
    (sample_domain < 0, "x", "Target", 1),
):
    plt.scatter(
        X_adapt[domain_mask, 0],
        X_adapt[domain_mask, 1],
        c=y[domain_mask],
        marker=marker,
        cmap="tab10",
        vmax=9,
        label=group_label,
        alpha=opacity,
    )
plt.legend()
# Reuse the limits of the raw-data panels so the three views are comparable.
plt.axis(ax)
plt.title("Adapted data")
Text(0.5, 1.0, 'Adapted data')

Train a classifier on adapted data
Average accuracy on all domains: 1.0
Total running time of the script: (0 minutes 0.856 seconds)
Gallery generated by Sphinx-Gallery
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4