After installing the `catboost` Python package, use one of the following examples to get started:
# CatBoostClassifier
import numpy as np
from catboost import CatBoostClassifier

# Training set: 100 rows x 10 random integer features in [0, 100),
# with random binary labels (0 or 1).
train_data = np.random.randint(0, 100, size=(100, 10))
train_labels = np.random.randint(0, 2, size=(100))

# Separate hold-out features to predict on. Previously this reused the
# training data, so the "test" predictions were in-sample.
test_data = np.random.randint(0, 100, size=(50, 10))

# A deliberately tiny model (2 trees of depth 2) so the example runs instantly.
model = CatBoostClassifier(iterations=2,
                           depth=2,
                           learning_rate=1,
                           loss_function='Logloss',
                           verbose=True)
model.fit(train_data, train_labels)

# predict() returns class labels; predict_proba() returns per-class probabilities.
preds_class = model.predict(test_data)
preds_proba = model.predict_proba(test_data)
print("class = ", preds_class)
print("proba = ", preds_proba)
CatBoostRegressor
import numpy as np
from catboost import Pool, CatBoostRegressor

# Regression example: random integer features, three of which (columns
# 0, 2 and 5) are declared categorical, with integer targets in [0, 1000).
cat_columns = [0, 2, 5]

train_data = np.random.randint(low=0, high=100, size=(100, 10))
train_label = np.random.randint(low=0, high=1000, size=100)
test_data = np.random.randint(low=0, high=100, size=(50, 10))

# Both pools must declare the same categorical columns.
train_pool = Pool(data=train_data, label=train_label, cat_features=cat_columns)
test_pool = Pool(data=test_data, cat_features=cat_columns)

# Tiny model (2 trees of depth 2) trained on squared error.
regressor_params = dict(iterations=2, depth=2, learning_rate=1, loss_function='RMSE')
model = CatBoostRegressor(**regressor_params)
model.fit(train_pool)

preds = model.predict(test_pool)
print(preds)
CatBoost
Datasets can also be read from input files; the `Pool` class, for example, offers this functionality. The example below builds its pools from in-memory NumPy arrays instead.
import numpy as np
from catboost import CatBoost, Pool

# Training set: 100 rows x 10 random integer features with binary labels;
# test set: 50 unlabeled rows with the same feature layout.
train_data = np.random.randint(0, 100, size=(100, 10))
train_labels = np.random.randint(0, 2, size=(100))
test_data = np.random.randint(0, 100, size=(50, 10))
train_pool = Pool(train_data, train_labels)
test_pool = Pool(test_data)

# Per the CatBoost docs, the generic CatBoost class defaults to a regression
# objective (RMSE). The labels are binary and 'Class' / 'Probability'
# predictions are requested below, so set a classification loss explicitly.
param = {'iterations': 5, 'loss_function': 'Logloss'}
model = CatBoost(param)
model.fit(train_pool)

# The same trained model queried three ways: predicted class labels,
# class probabilities, and the raw (untransformed) model output.
preds_class = model.predict(test_pool, prediction_type='Class')
preds_proba = model.predict(test_pool, prediction_type='Probability')
preds_raw_vals = model.predict(test_pool, prediction_type='RawFormulaVal')
print("Class", preds_class)
print("Proba", preds_proba)
print("Raw", preds_raw_vals)
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4