import argparse
import os
import tempfile
from typing import Dict, Iterator, List, Tuple

import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

import xgboost as xgb


class Plotting(xgb.callback.TrainingCallback):
    """Plot evaluation results during training.

    Only for demonstration purposes, as redrawing with matplotlib after every
    boosting round is quite slow.

    """

    def __init__(self, rounds: int) -> None:
        """Create the figure the evaluation curves are drawn on.

        Parameters
        ----------
        rounds :
            Number of boosting rounds the plot should have room for.  Each
            metric history is zero-padded to this length before drawing.

        """
        super().__init__()
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.rounds = rounds
        # One line artist per "<data>-<metric>" key; filled in on the first
        # call to `after_iteration`.
        self.lines: Dict[str, plt.Line2D] = {}
        self.fig.show()
        # `rounds` evenly spaced x positions in [0, rounds].
        self.x = np.linspace(0, self.rounds, self.rounds)
        plt.ion()

    def _get_key(self, data: str, metric: str) -> str:
        """Return the legend label / lookup key for a dataset-metric pair."""
        return f"{data}-{metric}"

    def _padded_logs(
        self, evals_log: Dict[str, dict]
    ) -> Iterator[Tuple[str, List[float]]]:
        """Yield ``(key, history)`` for every metric in `evals_log`.

        Each history is right-padded with zeros up to ``self.rounds`` entries
        so that it can be plotted against ``self.x``.
        """
        for data, metrics in evals_log.items():
            for metric_name, log in metrics.items():
                padding = [0] * (self.rounds - len(log))
                yield self._get_key(data, metric_name), log + padding

    def after_iteration(
        self, model: xgb.Booster, epoch: int, evals_log: Dict[str, dict]
    ) -> bool:
        """Update the plot with the latest evaluation history.

        Parameters
        ----------
        model :
            The booster being trained (unused).
        epoch :
            Index of the current boosting round (unused).
        evals_log :
            Mapping of dataset name to a mapping of metric name to the list of
            values recorded so far.

        Returns
        -------
        Always ``False`` to indicate training should not stop.

        """
        first_draw = not self.lines
        for key, history in self._padded_logs(evals_log):
            if first_draw:
                (self.lines[key],) = self.ax.plot(self.x, history, label=key)
            else:
                # https://pythonspot.com/matplotlib-update-plot/
                self.lines[key].set_ydata(history)

        # Build the legend once, and only if something was actually plotted.
        if first_draw and self.lines:
            self.ax.legend()

        self.fig.canvas.draw()
        # Give the GUI backend a chance to process the redraw; without this the
        # window may not refresh while training keeps the interpreter busy.
        self.fig.canvas.flush_events()
        # False to indicate training should not stop.
        return False


def custom_callback(device: str = "cuda") -> None:
    """Demo for defining a custom callback that plots evaluation results during
    training.

    Parameters
    ----------
    device :
        Value passed as the XGBoost ``device`` training parameter.  Defaults to
        ``"cuda"``; pass ``"cpu"`` to run without a GPU.

    """
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    num_boost_round = 100
    plotting = Plotting(num_boost_round)

    # Pass it to the `callbacks` parameter as a list.
    xgb.train(
        {
            "objective": "binary:logistic",
            "eval_metric": ["error", "rmse"],
            "tree_method": "hist",
            "device": device,
        },
        D_train,
        evals=[(D_train, "Train"), (D_valid, "Valid")],
        num_boost_round=num_boost_round,
        callbacks=[plotting],
    )


def check_point_callback() -> None:
    """Demo for using the checkpoint callback.

    Custom logic for handling output is usually required and users are
    encouraged to define their own callback for checkpointing operations.  The
    builtin one can be used as a starting point.

    """
    # Only for demo, set a larger value (like 100) in practice as checkpointing
    # is quite slow.
    rounds = 2
    num_boost_round = 10

    def check(directory: str, as_pickle: bool) -> None:
        """Assert that every expected checkpoint file exists in `directory`."""
        if as_pickle:
            extension = "pkl"
        else:
            extension = xgb.callback.TrainingCheckPoint.default_format
        # Round 0 is not checked; the first expected file is `model_<rounds>`.
        for i in range(rounds, num_boost_round, rounds):
            path = os.path.join(directory, f"model_{i}.{extension}")
            assert os.path.exists(path), f"Missing checkpoint: {path}"

    X, y = load_breast_cancer(return_X_y=True)
    m = xgb.DMatrix(X, y)

    # Check point to a temporary directory for demo
    with tempfile.TemporaryDirectory() as tmpdir:
        # Use callback class from xgboost.callback
        # Feel free to subclass/customize it to suit your need.
        check_point = xgb.callback.TrainingCheckPoint(
            directory=tmpdir, interval=rounds, name="model"
        )
        xgb.train(
            {"objective": "binary:logistic"},
            m,
            num_boost_round=num_boost_round,
            verbose_eval=False,
            callbacks=[check_point],
        )
        check(tmpdir, False)

        # This version of checkpoint saves everything including parameters and
        # model.  See: doc/tutorials/saving_model.rst
        check_point = xgb.callback.TrainingCheckPoint(
            directory=tmpdir, interval=rounds, as_pickle=True, name="model"
        )
        xgb.train(
            {"objective": "binary:logistic"},
            m,
            num_boost_round=num_boost_round,
            verbose_eval=False,
            callbacks=[check_point],
        )
        check(tmpdir, True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--plot", default=1, type=int)
    parser.add_argument(
        "--device",
        default="cuda",
        type=str,
        help="XGBoost `device` parameter used by the plotting demo.",
    )
    args = parser.parse_args()

    check_point_callback()

    if args.plot:
        custom_callback(args.device)
Gallery generated by Sphinx-Gallery
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4