A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://github.com/tidymodels/parsnip/commit/43c15db377ea9ef27483ff209f6bd0e98cb830d2 below:

fixed multi_predict column names · tidymodels/parsnip@43c15db · GitHub

File tree Expand file treeCollapse file tree 8 files changed

+22

-17

lines changed

Filter options

Expand file treeCollapse file tree 8 files changed

+22

-17

lines changed Original file line number Diff line number Diff line change

@@ -33,7 +33,7 @@ utils::globalVariables(

33 33

'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',

34 34

"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",

35 35

"max_terms", "max_tree", "name", "num_terms", "penalty", "trees",

36 -

"sub_neighbors")

36 +

"sub_neighbors", ".pred_class")

37 37

)

38 38 39 39

# nocov end

Original file line number Diff line number Diff line change

@@ -12,8 +12,10 @@

12 12

#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")`

13 13

#' such as `type`.

14 14

#' @return A tibble with the same number of rows as the data being predicted.

15 -

#' Mostly likely, there is a list-column named `.pred` that is a tibble with

16 -

#' multiple rows per sub-model.

15 +

#' There is a list-column named `.pred` that contains tibbles with

16 +

#' multiple rows per sub-model. Note that, within the tibbles, the column names

17 +

#' follow the usual standard based on prediction `type` (i.e. `.pred_class` for

18 +

#' `type = "class"` and so on).

17 19

#' @export

18 20

multi_predict <- function(object, ...) {

19 21

if (inherits(object$fit, "try-error")) {

Original file line number Diff line number Diff line change

@@ -404,7 +404,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {

404 404

} else {

405 405

if (type == "class") {

406 406

pred <- object$spec$method$pred$class$post(pred, object)

407 -

pred <- tibble(.pred = factor(pred, levels = object$lvl))

407 +

pred <- tibble(.pred_class = factor(pred, levels = object$lvl))

408 408

} else {

409 409

pred <- object$spec$method$pred$prob$post(pred, object)

410 410

pred <- as_tibble(pred)

@@ -503,7 +503,7 @@ C50_by_tree <- function(tree, object, new_data, type, ...) {

503 503 504 504

# switch based on prediction type

505 505

if (type == "class") {

506 -

pred <- tibble(.pred = factor(pred, levels = object$lvl))

506 +

pred <- tibble(.pred_class = factor(pred, levels = object$lvl))

507 507

} else {

508 508

pred <- as_tibble(pred)

509 509

names(pred) <- paste0(".pred_", names(pred))

Original file line number Diff line number Diff line change

@@ -309,7 +309,7 @@ multi_predict._lognet <-

309 309

if (is.null(type))

310 310

type <- "class"

311 311

if (!(type %in% c("class", "prob", "link", "raw"))) {

312 -

stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)

312 +

stop("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)

313 313

}

314 314

if (type == "prob")

315 315

dots$type <- "response"

@@ -321,12 +321,12 @@ multi_predict._lognet <-

321 321

param_key <- tibble(group = colnames(pred), penalty = penalty)

322 322

pred <- as_tibble(pred)

323 323

pred$.row <- 1:nrow(pred)

324 -

pred <- gather(pred, group, .pred, -.row)

324 +

pred <- gather(pred, group, .pred_class, -.row)

325 325

if (dots$type == "class") {

326 -

pred[[".pred"]] <- factor(pred[[".pred"]], levels = object$lvl)

326 +

pred[[".pred_class"]] <- factor(pred[[".pred_class"]], levels = object$lvl)

327 327

} else {

328 328

if (dots$type == "response") {

329 -

pred[[".pred2"]] <- 1 - pred[[".pred"]]

329 +

pred[[".pred2"]] <- 1 - pred[[".pred_class"]]

330 330

names(pred) <- c(".row", "group", paste0(".pred_", rev(object$lvl)))

331 331

pred <- pred[, c(".row", "group", paste0(".pred_", object$lvl))]

332 332

}

@@ -371,3 +371,4 @@ predict_raw._lognet <- function(object, new_data, opts = list(), ...) {

371 371

object$spec <- eval_args(object$spec)

372 372

predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)

373 373

}

374 + Original file line number Diff line number Diff line change

@@ -290,7 +290,7 @@ multi_predict._multnet <-

290 290

pred <-

291 291

tibble(

292 292

.row = rep(1:nrow(new_data), length(penalty)),

293 -

.pred = as.vector(pred),

293 +

.pred_class = as.vector(pred),

294 294

penalty = rep(penalty, each = nrow(new_data))

295 295

)

296 296

}

Original file line number Diff line number Diff line change

@@ -119,7 +119,7 @@ test_that('glmnet prediction, mulitiple lambda', {

119 119

mult_pred$rows <- rep(1:7, 2)

120 120

mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ]

121 121

mult_pred <- mult_pred[, c("penalty", "values")]

122 -

names(mult_pred) <- c("penalty", ".pred")

122 +

names(mult_pred) <- c("penalty", ".pred_class")

123 123

mult_pred <- tibble::as_tibble(mult_pred)

124 124 125 125

expect_equal(

@@ -148,7 +148,7 @@ test_that('glmnet prediction, mulitiple lambda', {

148 148

form_pred$rows <- rep(1:7, 2)

149 149

form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ]

150 150

form_pred <- form_pred[, c("penalty", "values")]

151 -

names(form_pred) <- c("penalty", ".pred")

151 +

names(form_pred) <- c("penalty", ".pred_class")

152 152

form_pred <- tibble::as_tibble(form_pred)

153 153 154 154

expect_equal(

@@ -180,7 +180,7 @@ test_that('glmnet prediction, no lambda', {

180 180

mult_pred$rows <- rep(1:7, 2)

181 181

mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ]

182 182

mult_pred <- mult_pred[, c("penalty", "values")]

183 -

names(mult_pred) <- c("penalty", ".pred")

183 +

names(mult_pred) <- c("penalty", ".pred_class")

184 184

mult_pred <- tibble::as_tibble(mult_pred)

185 185 186 186

expect_equal(mult_pred, multi_predict(xy_fit, lending_club[1:7, num_pred]) %>% unnest())

@@ -206,7 +206,7 @@ test_that('glmnet prediction, no lambda', {

206 206

form_pred$rows <- rep(1:7, 2)

207 207

form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ]

208 208

form_pred <- form_pred[, c("penalty", "values")]

209 -

names(form_pred) <- c("penalty", ".pred")

209 +

names(form_pred) <- c("penalty", ".pred_class")

210 210

form_pred <- tibble::as_tibble(form_pred)

211 211 212 212

expect_equal(

Original file line number Diff line number Diff line change

@@ -123,7 +123,7 @@ test_that('glmnet probabilities, mulitiple lambda', {

123 123 124 124

mult_class <- names(mult_probs)[apply(mult_probs, 1, which.max)]

125 125

mult_class <- tibble(

126 -

.pred = mult_class,

126 +

.pred_class = mult_class,

127 127

penalty = rep(lams, each = 3),

128 128

row = rep(1:3, 2)

129 129

)

You can’t perform that action at this time.


RetroSearch is an open source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4