A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://scikit-adaptation.github.io/auto_examples/methods/plot_monge_alignment_da.html below:

Website Navigation


Multi-domain Linear Monge Alignment — SKADA : Scikit Adaptation

Multi-domain Linear Monge Alignment

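This example illustrates the use of the MultiLinearMongeAlignmentAdapter, first on a simple dataset with one source and one target domain, then on a dataset with several source and target domains.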

# Author: Remi Flamary
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4
Generate a concept drift classification dataset and plot it

We generate a simple 2D concept drift dataset.

# Two-domain dataset (one source, one target) with a concept drift between
# the domains.
X, y, sample_domain = make_shifted_datasets(
    n_samples_source=20,
    n_samples_target=20,
    shift="concept_drift",
    noise=0.2,
    label="multiclass",
    random_state=42,
)


Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


# Always open a new figure: passing a fixed figure number would re-activate an
# existing figure with that number (ignoring figsize) when the script is run
# standalone, so later plots would be drawn on top of each other.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
# Remember the source panel's (xmin, xmax, ymin, ymax) so the target panel
# is shown with the same view.
axis_limits = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(axis_limits)
(np.float64(-2.3676169789533272), np.float64(4.53817259426296), np.float64(-1.7964573229829714), np.float64(4.336863129384494))
Align the domains with MultiLinearMongeAlignmentAdapter

We fit a MultiLinearMongeAlignmentAdapter on the samples of all domains and use it to transform every sample; allow_source=True lets the source samples be transformed as well. We then plot the source data, the target data, and the adapted data side by side.

# Fit the adapter on the samples of every domain, then transform all of them;
# allow_source=True lets the source samples be transformed as well.
clf = MultiLinearMongeAlignmentAdapter()
clf.fit(X, sample_domain=sample_domain)

X_adapt = clf.transform(X, sample_domain=sample_domain, allow_source=True)


# Always open a new figure (a fixed figure number would reuse an existing one
# and ignore figsize when the script is run standalone).
plt.figure(figsize=(10, 3))
plt.subplot(1, 3, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
axis_limits = plt.axis()

plt.subplot(1, 3, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(axis_limits)

plt.subplot(1, 3, 3)
# Non-negative domain labels are plotted as sources, negative ones as targets.
is_source = sample_domain >= 0
for mask, marker, label, alpha in (
    (is_source, "o", "Source", 0.5),
    (~is_source, "x", "Target", 1),
):
    plt.scatter(
        X_adapt[mask, 0],
        X_adapt[mask, 1],
        c=y[mask],
        marker=marker,
        cmap="tab10",
        vmax=9,
        label=label,
        alpha=alpha,
    )
plt.legend()
plt.title("Adapted data")
Text(0.5, 1.0, 'Adapted data')
Train a classifier on adapted data
Average accuracy on all domains: 0.9875
def get_multidomain_data(
    n_samples_source=100,
    n_samples_target=100,
    noise=0.1,
    random_state=None,
    n_sources=3,
    n_targets=2,
):
    """Build a 2D concept-drift dataset with several source and target domains.

    A first ``make_shifted_datasets`` call provides the initial domains. Every
    extra domain comes from another ``make_shifted_datasets`` call with a
    random ``mean`` and ``sigma``. Only the target part of that dataset is
    kept, and it is relabelled as a new source domain (positive label) or a
    new target domain (negative label).

    Parameters
    ----------
    n_samples_source, n_samples_target : int
        Forwarded to every ``make_shifted_datasets`` call.
    noise : float
        Forwarded to every ``make_shifted_datasets`` call.
    random_state : int or None
        Seed used to reseed NumPy's global RNG (which supplies the random
        ``mean``/``sigma`` of extra domains) and base seed for the extra
        datasets: extra source ``k`` (0-based) uses ``base + k`` and extra
        target ``k`` uses ``base + k + 42``. With ``None`` a base seed is
        drawn at random.
    n_sources, n_targets : int
        Number of source / target domains to produce; ``n_sources - 1`` extra
        sources and ``n_targets - 1`` extra targets are added to the initial
        call's domains (assumed to be one source and one target domain).

    Returns
    -------
    X : ndarray
        Samples of all domains stacked row-wise.
    y : ndarray
        Class labels aligned with ``X``.
    sample_domain : ndarray
        Domain label of each sample, with the dtype of the initial call's
        labels. Extra sources get labels ``2, 3, ...`` and extra targets
        ``-1, -2, ...``, skipping any label that is already in use.
    """
    # Reseeding the global RNG is kept so a given integer random_state
    # reproduces the same extra-domain means/scales as before (note: this is
    # a global side effect).
    np.random.seed(random_state)
    if random_state is None:
        # Extra-domain seeds are integer offsets from a base seed, and
        # ``None + k`` raises TypeError, so draw an integer base seed.
        base_seed = int(np.random.randint(0, 2**31))
    else:
        base_seed = random_state

    X, y, sample_domain = make_shifted_datasets(
        n_samples_source=n_samples_source,
        n_samples_target=n_samples_target,
        noise=noise,
        shift="concept_drift",
        label="multiclass",
        random_state=random_state,
    )

    used_labels = set(np.unique(sample_domain).tolist())
    X_parts, y_parts, domain_parts = [X], [y], [sample_domain]

    def _extra_domain(seed):
        """Return the target samples/labels of a randomly shifted dataset."""
        Xi, yi, sample_domaini = make_shifted_datasets(
            n_samples_source=n_samples_source,
            n_samples_target=n_samples_target,
            noise=noise,
            shift="concept_drift",
            label="multiclass",
            random_state=seed,
            # Keyword arguments are evaluated left to right: mean is drawn
            # before sigma, exactly as in the original implementation.
            mean=np.random.randn(2),
            sigma=np.random.rand(2) * 0.5 + 0.5,
        )
        _, Xt, _, yt = source_target_split(Xi, yi, sample_domain=sample_domaini)
        return Xt, yt

    def _free_label(candidate, step):
        """Return ``candidate``, or the next unused label moving by ``step``."""
        while candidate in used_labels:
            candidate += step
        used_labels.add(candidate)
        return candidate

    def _append(Xt, yt, label):
        """Queue one extra domain's samples, class labels and domain labels."""
        X_parts.append(Xt)
        y_parts.append(yt)
        domain_parts.append(
            np.full(Xt.shape[0], label, dtype=sample_domain.dtype)
        )

    for ns in range(n_sources - 1):
        Xt, yt = _extra_domain(base_seed + ns)
        _append(Xt, yt, _free_label(ns + 2, 1))

    for nt in range(n_targets - 1):
        Xt, yt = _extra_domain(base_seed + nt + 42)
        _append(Xt, yt, _free_label(-(nt + 1), -1))

    # Concatenate once instead of re-stacking the growing arrays per domain.
    return np.vstack(X_parts), np.hstack(y_parts), np.hstack(domain_parts)


# Multi-domain dataset requested with 3 source and 2 target domains.
X, y, sample_domain = get_multidomain_data(
    n_samples_source=50,
    n_samples_target=50,
    noise=0.1,
    random_state=43,
    n_sources=3,
    n_targets=2,
)

Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


# Always open a new figure (a fixed figure number would reuse an existing one
# and ignore figsize when the script is run standalone).
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
# Share the source panel's view with the target panel.
axis_limits = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target domains")
plt.axis(axis_limits)
(np.float64(-2.310098338155625), np.float64(4.756925382279493), np.float64(-2.1443686989830857), np.float64(4.464886123797522))
# Fit the adapter on the samples of every domain, then transform all of them;
# allow_source=True lets the source samples be transformed as well.
clf = MultiLinearMongeAlignmentAdapter()
clf.fit(X, sample_domain=sample_domain)

X_adapt = clf.transform(X, sample_domain=sample_domain, allow_source=True)


# Always open a new figure (a fixed figure number would reuse an existing one
# and ignore figsize when the script is run standalone).
plt.figure(figsize=(10, 3))
plt.subplot(1, 3, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
axis_limits = plt.axis()

plt.subplot(1, 3, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(axis_limits)

plt.subplot(1, 3, 3)
# Non-negative domain labels are plotted as sources, negative ones as targets.
is_source = sample_domain >= 0
for mask, marker, label, alpha in (
    (is_source, "o", "Source", 0.5),
    (~is_source, "x", "Target", 1),
):
    plt.scatter(
        X_adapt[mask, 0],
        X_adapt[mask, 1],
        c=y[mask],
        marker=marker,
        cmap="tab10",
        vmax=9,
        label=label,
        alpha=alpha,
    )
plt.legend()
# Same view as the source/target panels so the three are comparable.
plt.axis(axis_limits)
plt.title("Adapted data")
Text(0.5, 1.0, 'Adapted data')
Train a classifier on adapted data
Average accuracy on all domains: 1.0

Total running time of the script: (0 minutes 0.856 seconds)

Gallery generated by Sphinx-Gallery


RetroSearch is an open source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4