Showing content from https://scikit-adaptation.github.io/auto_examples/validation/plot_gridsearch_for_da.html below:

Using GridSearchCV with skada — SKADA : Scikit Adaptation

This example illustrates how to use a domain adaptation (DA) scorer, here PredictionEntropyScorer, with GridSearchCV.

We first create a shifted dataset. We then build the DA pipeline, which combines the DA estimator (EntropicOTMapping) with a base estimator that performs the classification (SVC). ShuffleSplit serves as the cross-validation strategy.

import warnings

import matplotlib.pyplot as plt
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.model_selection import GridSearchCV, ShuffleSplit
from sklearn.svm import SVC

from skada import EntropicOTMapping
from skada.datasets import make_shifted_datasets
from skada.metrics import PredictionEntropyScorer

warnings.filterwarnings("ignore")

RANDOM_SEED = 42
dataset = make_shifted_datasets(
    n_samples_source=30,
    n_samples_target=20,
    shift="concept_drift",
    label="binary",
    noise=0.4,
    random_state=RANDOM_SEED,
    return_dataset=True,
)
X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"])
X_target, y_target, _ = dataset.pack_test(as_targets=["t"])

estimator = EntropicOTMapping(base_estimator=SVC(probability=True))
cv = ShuffleSplit(n_splits=5, test_size=0.3, random_state=RANDOM_SEED)
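
To make the cross-validation setup concrete, the sketch below shows what a ShuffleSplit configured as above yields. It runs on a placeholder array of 100 rows, a size chosen only for illustration and not the size of the dataset generated in this example; the names X_demo and cv_demo are hypothetical.

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit

# Placeholder data: 100 rows is an arbitrary size chosen for illustration.
X_demo = np.zeros((100, 2))

# Same settings as the example: 5 independent random splits, 30% held out.
cv_demo = ShuffleSplit(n_splits=5, test_size=0.3, random_state=42)

splits = list(cv_demo.split(X_demo))
split_sizes = [(len(train_idx), len(test_idx)) for train_idx, test_idx in splits]

for i, (n_train, n_test) in enumerate(split_sizes):
    print(f"split {i}: {n_train} train rows, {n_test} test rows")
```

Unlike K-fold, each split is drawn independently, so the same row may land in the test part of several splits.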

We want to run a grid search to find the best regularization parameter (reg_e) for the DA estimator. The DA pipeline can be passed directly to GridSearchCV. During the search, PredictionEntropyScorer evaluates each candidate value.

reg_e = [0.01, 0.03, 0.05, 0.08, 0.1]

grid_search = GridSearchCV(
    estimator,
    {"entropicotmappingadapter__reg_e": reg_e},
    cv=cv,
    scoring=PredictionEntropyScorer(),
)
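
The key "entropicotmappingadapter__reg_e" has the `<step name>__<parameter>` form that scikit-learn uses to address parameters of pipeline steps. The sketch below illustrates that convention on a plain scikit-learn pipeline; the StandardScaler step and the name pipe are assumptions made for illustration and are not part of this example. The skada pipeline appears to name its adapter step in the same way, and calling get_params() on the estimator should list the keys that are actually available.

```python
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Illustrative pipeline (not the skada one): make_pipeline names each step
# after its lowercased class name.
pipe = make_pipeline(StandardScaler(), SVC())
step_names = [name for name, _ in pipe.steps]
print(step_names)

# Parameters of a step are addressed as "<step name>__<parameter>".
svc_keys = sorted(key for key in pipe.get_params() if key.startswith("svc__"))
print(svc_keys[:3])

# The same key syntax works in set_params and in a GridSearchCV param grid.
pipe.set_params(svc__C=10.0)
```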

grid_search.fit(X, y, sample_domain=sample_domain)

best_reg_e = grid_search.best_params_["entropicotmappingadapter__reg_e"]
print(f"Best regularization parameter: {best_reg_e}")
Best regularization parameter: 0.08
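
To give an intuition for an entropy-based score, the sketch below computes the mean Shannon entropy of predicted class probabilities. It assumes that this is the quantity behind PredictionEntropyScorer; the exact formula, sign convention, and reduction used by skada are not stated here, and the helper mean_prediction_entropy is hypothetical.

```python
import math


def mean_prediction_entropy(probas):
    """Average Shannon entropy (in nats) over rows of class probabilities."""
    total = 0.0
    for row in probas:
        # Terms with p == 0 contribute nothing, so they are skipped.
        total += -sum(p * math.log(p) for p in row if p > 0.0)
    return total / len(probas)


# Hypothetical predicted probabilities for two binary classifiers.
confident = [[0.99, 0.01], [0.02, 0.98]]
uncertain = [[0.50, 0.50], [0.60, 0.40]]

print(mean_prediction_entropy(confident))
print(mean_prediction_entropy(uncertain))
```

A lower mean entropy means more confident predictions. The computation needs no labels, which is what makes a criterion of this kind usable when target labels are unavailable.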

Plot the decision boundary of the best estimator together with the target points

DecisionBoundaryDisplay.from_estimator(
    grid_search.best_estimator_,
    X_target,
    alpha=0.8,
    eps=0.5,
    response_method="predict",
)

# Plot the target points
plt.scatter(
    X_target[:, 0],
    X_target[:, 1],
    c=y_target,
    alpha=0.5,
)
plt.show()

Total running time of the script: (0 minutes 11.356 seconds)
