A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://scikit-adaptation.github.io/auto_examples/validation/plot_cross_validation_for_da.html below:

Website Navigation


Visualizing cross-validation behavior in skada — SKADA : Scikit Adaptation

Visualizing cross-validation behavior in skada

This example illustrates the use of DA cross-validation objects such as DomainShuffleSplit.

Let's prepare the imports:

# Author: Yanis Lalou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 1

First we generate a dataset with 4 different domains. The domains are drawn from 4 different distributions: 2 source and 2 target distributions. The target distributions are shifted versions of the source distributions. Thus we will have a domain adaptation problem with 2 source domains and 2 target domains.

# Draw two shifted datasets that are configured identically and differ only
# in their random seed (RANDOM_SEED and RANDOM_SEED + 1).
dataset, dataset2 = (
    make_shifted_datasets(
        n_samples_source=3,
        n_samples_target=2,
        shift="concept_drift",
        label="binary",
        noise=0.4,
        random_state=RANDOM_SEED + seed_offset,
        return_dataset=True,
    )
    for seed_offset in (0, 1)
)
# Fold the second dataset into the first one; its "s"/"t" domains are
# registered under the names "s2"/"t2".
dataset.merge(dataset2, names_mapping={"s": "s2", "t": "t2"})

# Pack the four domains twice: once through pack_train() and once through
# pack(train=False); only the labels of the second packing are kept.
X, y, sample_domain = dataset.pack_train(as_sources=["s", "s2"], as_targets=["t", "t2"])
_, target_labels, _ = dataset.pack(
    as_sources=["s", "s2"], as_targets=["t", "t2"], train=False
)

# Order the samples by sample_domain, breaking ties with target_labels
# (np.lexsort uses its LAST key as the primary sort key).
indx_sort = np.lexsort((target_labels, sample_domain))
X, y, target_labels, sample_domain = (
    arr[indx_sort] for arr in (X, y, target_labels, sample_domain)
)


# Same preparation for the Lodo methods
X_lodo, y_lodo, sample_domain_lodo = dataset.pack_lodo()

# Primary key: sample_domain_lodo, secondary key: y_lodo
indx_sort = np.lexsort((y_lodo, sample_domain_lodo))
X_lodo, y_lodo, sample_domain_lodo = (
    arr[indx_sort] for arr in (X_lodo, y_lodo, sample_domain_lodo)
)

We define functions to visualize the behavior of each cross-validation object. The number of splits is set to 4 (2 for the lodo method). For each split, we visualize the indices selected for the training set (in blue) and the test set (in orange).

# Code source: scikit-learn documentation
# Modified for documentation by Yanis Lalou
# License: BSD 3 clause
def plot_cv_indices(cv, X, y, sample_domain, ax, n_splits, lw=10):
    """Draw the train/test membership of each split of ``cv`` on ``ax``.

    One row of horizontal markers is drawn per split: samples in the test
    indices are coloured with value 1, samples in the train indices with
    value 0, and samples in neither stay NaN. Two extra rows are appended
    below the splits, coloured by ``y`` and by ``sample_domain``.

    Parameters
    ----------
    cv : object
        Splitter exposing ``split(X=..., y=..., sample_domain=...)`` that
        yields ``(train_indices, test_indices)`` pairs.
    X, y, sample_domain : array-like
        Data forwarded to ``cv.split``; ``len(X)`` fixes the number of
        markers per row.
    ax : object
        Axes-like object providing ``scatter``, ``set`` and ``set_title``
        (presumably a matplotlib ``Axes``).
    n_splits : int
        Number of split rows used for the y ticks, tick labels and limits.
    lw : int, default=10
        Line width of the markers.

    Returns
    -------
    ax
        The axes that were passed in.
    """
    n_samples = len(X)
    # Marker centres 0.5, 1.5, ..., one per sample
    x_centres = [k / 2 for k in range(1, n_samples * 2 + 1, 2)]

    for row, (train_idx, test_idx) in enumerate(
        cv.split(X=X, y=y, sample_domain=sample_domain)
    ):
        # NaN = sample not used by this split; train is written last so it
        # wins should an index appear in both sets
        membership = np.array([np.nan] * n_samples)
        membership[test_idx] = 1
        membership[train_idx] = 0

        ax.scatter(
            x_centres,
            [row + 0.5] * n_samples,
            c=membership,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    # Rows appended under the last split: classes first, then domains
    trailing_rows = (
        (1.5, y, cmap_data, -1.2, 0.2),
        (2.5, sample_domain, cmap_domain, -3.2, 3.2),
    )
    for offset, values, colormap, low, high in trailing_rows:
        ax.scatter(
            x_centres,
            [row + offset] * n_samples,
            c=values,
            marker="_",
            lw=lw,
            cmap=colormap,
            vmin=low,
            vmax=high,
        )

    # Formatting: inverted y axis so that split 0 sits at the top
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=[*range(n_splits), "class", "sample_domain"],
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, n_samples],
    )
    ax.set_title(f"{type(cv).__name__}", fontsize=15)
    return ax


def plot_lodo_indices(cv, X, y, sample_domain, ax, lw=10):
    """Draw the train/test membership of each split of ``cv`` on ``ax``.

    One row of horizontal markers is drawn per split: samples in the test
    indices are coloured with value 1, samples in the train indices with
    value 0, and samples in neither stay NaN. Two extra rows are appended
    below the splits, coloured by ``y`` and by ``sample_domain``.

    This function takes no ``n_splits`` argument. The y ticks, tick labels
    and limits are derived from the number of splits that ``cv.split``
    actually yields, so the labels always line up with the drawn rows and
    no module-level variable is required.

    Parameters
    ----------
    cv : object
        Splitter exposing ``split(X=..., y=..., sample_domain=...)`` that
        yields ``(train_indices, test_indices)`` pairs.
    X, y, sample_domain : array-like
        Data forwarded to ``cv.split``; ``len(X)`` fixes the number of
        markers per row.
    ax : object
        Axes-like object providing ``scatter``, ``set`` and ``set_title``
        (presumably a matplotlib ``Axes``).
    lw : int, default=10
        Line width of the markers.

    Returns
    -------
    ax
        The axes that were passed in.
    """
    # Generate the training/testing visualizations for each CV split
    cv_args = {"X": X, "y": y, "sample_domain": sample_domain}
    n_samples = len(X)
    # Marker centres 0.5, 1.5, ..., one per sample (computed once)
    x_centres = [k + 0.5 for k in range(n_samples)]

    # Count the rows really drawn; stays 0 when the splitter yields nothing
    n_rows = 0
    for ii, (tr, tt) in enumerate(cv.split(**cv_args)):
        # NaN = sample not used by this split; train is written last so it
        # wins should an index appear in both sets
        indices = np.full(n_samples, np.nan)
        indices[tt] = 1
        indices[tr] = 0

        # Visualize the results
        ax.scatter(
            x_centres,
            [ii + 0.5] * n_samples,
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
            s=1.8,
        )
        n_rows = ii + 1

    # Plot the data classes and sample_domain right below the last split
    ax.scatter(
        x_centres,
        [n_rows + 0.5] * n_samples,
        c=y,
        marker="_",
        lw=lw,
        cmap=cmap_data,
        vmin=-1.2,
        vmax=0.2,
    )

    ax.scatter(
        x_centres,
        [n_rows + 1.5] * n_samples,
        c=sample_domain,
        marker="_",
        lw=lw,
        cmap=cmap_domain,
        vmin=-3.2,
        vmax=3.2,
    )

    # Formatting: one tick per drawn row, inverted y axis (split 0 on top)
    yticklabels = list(range(n_rows)) + ["class", "sample_domain"]
    ax.set(
        yticks=np.arange(n_rows + 2) + 0.5,
        yticklabels=yticklabels,
        ylim=[n_rows + 2.2, -0.2],
        xlim=[0, n_samples],
    )
    ax.set_title(f"{type(cv).__name__}", fontsize=15)
    return ax


def plot_st_shuffle_indices(cv, X, y, target_labels, sample_domain, ax, n_splits, lw):
    """Draw the splits of ``cv`` twice: with ``y`` and with ``target_labels``.

    Panel ``ax[0]`` uses ``y`` as the labels handed to ``cv.split`` and
    panel ``ax[1]`` uses ``target_labels``. In each panel one row of
    markers is drawn per split (test indices coloured with value 1, train
    indices with value 0, untouched samples NaN), followed by one row
    coloured by the labels of that panel and one row coloured by
    ``sample_domain``.

    Parameters
    ----------
    cv : object
        Splitter exposing ``split(X=..., y=..., sample_domain=...)`` that
        yields ``(train_indices, test_indices)`` pairs.
    X, y, target_labels, sample_domain : array-like
        Data forwarded to ``cv.split``; ``len(X)`` fixes the number of
        markers per row.
    ax : sequence
        Indexable collection whose items 0 and 1 provide ``scatter`` and
        ``set`` (presumably matplotlib ``Axes``).
    n_splits : int
        Number of split rows used for the y ticks, tick labels and limits.
    lw : int
        Line width of the markers.

    Returns
    -------
    ax
        The collection that was passed in.
    """
    n_samples = len(X)
    # Marker centres 0.5, 1.5, ..., one per sample
    x_centres = [k / 2 for k in range(1, n_samples * 2 + 1, 2)]

    for n, labels in enumerate([y, target_labels]):
        panel = ax[n]

        for row, (train_idx, test_idx) in enumerate(
            cv.split(X=X, y=labels, sample_domain=sample_domain)
        ):
            # NaN = sample not used by this split; train is written last
            membership = np.array([np.nan] * n_samples)
            membership[test_idx] = 1
            membership[train_idx] = 0

            panel.scatter(
                x_centres,
                [row + 0.5] * n_samples,
                c=membership,
                marker="_",
                lw=lw,
                cmap=cmap_cv,
                vmin=-0.2,
                vmax=1.2,
            )

        # Rows appended under the last split: classes first, then domains
        trailing_rows = (
            (1.5, labels, cmap_data, -1.2, 0.2),
            (2.5, sample_domain, cmap_domain, -3.2, 3.2),
        )
        for offset, values, colormap, low, high in trailing_rows:
            panel.scatter(
                x_centres,
                [row + offset] * n_samples,
                c=values,
                marker="_",
                lw=lw,
                cmap=colormap,
                vmin=low,
                vmax=high,
            )

        # Formatting: inverted y axis so that split 0 sits at the top
        panel.set(
            yticks=np.arange(n_splits + 2) + 0.5,
            yticklabels=[*range(n_splits), "class", "sample_domain"],
            ylim=[n_splits + 2.2, -0.2],
            xlim=[0, n_samples],
        )

    return ax

The following plot illustrates the behavior of SourceTargetShuffleSplit. The left plot shows the indices of the training and testing sets for each split when the dataset is packed with pack_train() (the target domain labels are masked (=-1)), while the right plot shows the indices of the training and testing sets for each split when the dataset is packed with pack_test().

cvs = [SourceTargetShuffleSplit]
for cv in cvs:
    # Two panels sharing the y axis: ax[0] is drawn with y, ax[1] with
    # target_labels
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(7, 3), sharey=True)
    fig.suptitle(f"{cv.__name__}", fontsize=15)
    plot_st_shuffle_indices(
        cv(n_splits),
        X,
        y,
        target_labels,
        sample_domain,
        ax,
        n_splits,
        10,
    )

    # Legend patches are sampled from the colormap used for the split rows
    fig.legend(
        handles=[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        labels=["Testing set", "Training set"],
        loc="center right",
    )
    fig.text(0.48, 0.01, "Sample index", horizontalalignment="center")
    fig.text(
        0.001,
        0.5,
        "CV iteration",
        verticalalignment="center",
        rotation="vertical",
    )

    # Make the legend fit: the subplots only occupy the left 70% of the figure
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)

The following plot illustrates the behavior of LeaveOneDomainOut. The plot shows the indices of the training and testing sets for each split and which domain is used as the target domain for each split.

cvs = [LeaveOneDomainOut]
for cv in cvs:
    fig, ax = plt.subplots(figsize=(6, 3))
    # The Lodo splitter is fed the arrays packed with pack_lodo()
    plot_lodo_indices(
        cv(n_splits_lodo),
        X=X_lodo,
        y=y_lodo,
        sample_domain=sample_domain_lodo,
        ax=ax,
    )

    # Legend patches are sampled from the colormap used for the split rows
    fig.legend(
        handles=[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        labels=["Testing set", "Training set"],
        loc="center right",
    )
    fig.text(0.48, 0.01, "Sample index", horizontalalignment="center")
    fig.text(
        0.001,
        0.5,
        "CV iteration",
        verticalalignment="center",
        rotation="vertical",
    )

    # Make the legend fit: the subplots only occupy the left 70% of the figure
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)

Now let's see how the other cross-validation objects behave on our dataset.

cvs = [DomainShuffleSplit, StratifiedDomainShuffleSplit]

# One figure per splitter class, each instantiated with n_splits
for cv in cvs:
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(
        cv(n_splits),
        X=X,
        y=y,
        sample_domain=sample_domain,
        ax=ax,
        n_splits=n_splits,
    )

    # Legend patches are sampled from the colormap used for the split rows
    fig.legend(
        handles=[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        labels=["Testing set", "Training set"],
        loc="center right",
    )
    fig.text(0.48, 0.01, "Sample index", horizontalalignment="center")
    fig.text(
        0.001,
        0.5,
        "CV iteration",
        verticalalignment="center",
        rotation="vertical",
    )

    # Make the legend fit: the subplots only occupy the left 70% of the figure
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
As we can see, each splitter has a very different behavior.

Total running time of the script: (0 minutes 0.539 seconds)

Gallery generated by Sphinx-Gallery


RetroSearch is an open source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4