+22
-17
lines changedFilter options
+22
-17
lines changed Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ utils::globalVariables(
33
33
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
34
34
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
35
35
"max_terms", "max_tree", "name", "num_terms", "penalty", "trees",
36
-
"sub_neighbors")
36
+
"sub_neighbors", ".pred_class")
37
37
)
38
38
39
39
# nocov end
Original file line number Diff line number Diff line change
@@ -12,8 +12,10 @@
12
12
#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")`
13
13
#' such as `type`.
14
14
#' @return A tibble with the same number of rows as the data being predicted.
15
-
#' Mostly likely, there is a list-column named `.pred` that is a tibble with
16
-
#' multiple rows per sub-model.
15
+
#' There is a list-column named `.pred` that contains tibbles with
16
+
#' multiple rows per sub-model. Note that, within the tibbles, the column names
17
+
#' follow the usual standard based on prediction `type` (i.e. `.pred_class` for
18
+
#' `type = "class"` and so on).
17
19
#' @export
18
20
multi_predict <- function(object, ...) {
19
21
if (inherits(object$fit, "try-error")) {
Original file line number Diff line number Diff line change
@@ -404,7 +404,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
404
404
} else {
405
405
if (type == "class") {
406
406
pred <- object$spec$method$pred$class$post(pred, object)
407
-
pred <- tibble(.pred = factor(pred, levels = object$lvl))
407
+
pred <- tibble(.pred_class = factor(pred, levels = object$lvl))
408
408
} else {
409
409
pred <- object$spec$method$pred$prob$post(pred, object)
410
410
pred <- as_tibble(pred)
@@ -503,7 +503,7 @@ C50_by_tree <- function(tree, object, new_data, type, ...) {
503
503
504
504
# switch based on prediction type
505
505
if (type == "class") {
506
-
pred <- tibble(.pred = factor(pred, levels = object$lvl))
506
+
pred <- tibble(.pred_class = factor(pred, levels = object$lvl))
507
507
} else {
508
508
pred <- as_tibble(pred)
509
509
names(pred) <- paste0(".pred_", names(pred))
Original file line number Diff line number Diff line change
@@ -309,7 +309,7 @@ multi_predict._lognet <-
309
309
if (is.null(type))
310
310
type <- "class"
311
311
if (!(type %in% c("class", "prob", "link", "raw"))) {
312
-
stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
312
+
stop("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
313
313
}
314
314
if (type == "prob")
315
315
dots$type <- "response"
@@ -321,12 +321,12 @@ multi_predict._lognet <-
321
321
param_key <- tibble(group = colnames(pred), penalty = penalty)
322
322
pred <- as_tibble(pred)
323
323
pred$.row <- 1:nrow(pred)
324
-
pred <- gather(pred, group, .pred, -.row)
324
+
pred <- gather(pred, group, .pred_class, -.row)
325
325
if (dots$type == "class") {
326
-
pred[[".pred"]] <- factor(pred[[".pred"]], levels = object$lvl)
326
+
pred[[".pred_class"]] <- factor(pred[[".pred_class"]], levels = object$lvl)
327
327
} else {
328
328
if (dots$type == "response") {
329
-
pred[[".pred2"]] <- 1 - pred[[".pred"]]
329
+
pred[[".pred2"]] <- 1 - pred[[".pred_class"]]
330
330
names(pred) <- c(".row", "group", paste0(".pred_", rev(object$lvl)))
331
331
pred <- pred[, c(".row", "group", paste0(".pred_", object$lvl))]
332
332
}
@@ -371,3 +371,4 @@ predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
371
371
object$spec <- eval_args(object$spec)
372
372
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
373
373
}
374
+
Original file line number Diff line number Diff line change
@@ -290,7 +290,7 @@ multi_predict._multnet <-
290
290
pred <-
291
291
tibble(
292
292
.row = rep(1:nrow(new_data), length(penalty)),
293
-
.pred = as.vector(pred),
293
+
.pred_class = as.vector(pred),
294
294
penalty = rep(penalty, each = nrow(new_data))
295
295
)
296
296
}
Original file line number Diff line number Diff line change
@@ -119,7 +119,7 @@ test_that('glmnet prediction, mulitiple lambda', {
119
119
mult_pred$rows <- rep(1:7, 2)
120
120
mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ]
121
121
mult_pred <- mult_pred[, c("penalty", "values")]
122
-
names(mult_pred) <- c("penalty", ".pred")
122
+
names(mult_pred) <- c("penalty", ".pred_class")
123
123
mult_pred <- tibble::as_tibble(mult_pred)
124
124
125
125
expect_equal(
@@ -148,7 +148,7 @@ test_that('glmnet prediction, mulitiple lambda', {
148
148
form_pred$rows <- rep(1:7, 2)
149
149
form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ]
150
150
form_pred <- form_pred[, c("penalty", "values")]
151
-
names(form_pred) <- c("penalty", ".pred")
151
+
names(form_pred) <- c("penalty", ".pred_class")
152
152
form_pred <- tibble::as_tibble(form_pred)
153
153
154
154
expect_equal(
@@ -180,7 +180,7 @@ test_that('glmnet prediction, no lambda', {
180
180
mult_pred$rows <- rep(1:7, 2)
181
181
mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ]
182
182
mult_pred <- mult_pred[, c("penalty", "values")]
183
-
names(mult_pred) <- c("penalty", ".pred")
183
+
names(mult_pred) <- c("penalty", ".pred_class")
184
184
mult_pred <- tibble::as_tibble(mult_pred)
185
185
186
186
expect_equal(mult_pred, multi_predict(xy_fit, lending_club[1:7, num_pred]) %>% unnest())
@@ -206,7 +206,7 @@ test_that('glmnet prediction, no lambda', {
206
206
form_pred$rows <- rep(1:7, 2)
207
207
form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ]
208
208
form_pred <- form_pred[, c("penalty", "values")]
209
-
names(form_pred) <- c("penalty", ".pred")
209
+
names(form_pred) <- c("penalty", ".pred_class")
210
210
form_pred <- tibble::as_tibble(form_pred)
211
211
212
212
expect_equal(
Original file line number Diff line number Diff line change
@@ -123,7 +123,7 @@ test_that('glmnet probabilities, mulitiple lambda', {
123
123
124
124
mult_class <- names(mult_probs)[apply(mult_probs, 1, which.max)]
125
125
mult_class <- tibble(
126
-
.pred = mult_class,
126
+
.pred_class = mult_class,
127
127
penalty = rep(lams, each = 3),
128
128
row = rep(1:3, 2)
129
129
)
You can’t perform that action at this time.
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4