When trying to predict probabilities from a multinomial logistic regression with the 'glmnet' engine, I get the following error:
options(stringsAsFactors = F)
library(tidyverse)
library(magrittr)
#>
#> Attaching package: 'magrittr'
#> The following object is masked from 'package:purrr':
#>
#>     set_names
#> The following object is masked from 'package:tidyr':
#>
#>     extract
library(tidymodels)
#> ── Attaching packages ────────────────────── tidymodels 0.0.2 ──
#> ✔ broom     0.5.2       ✔ recipes   0.1.6
#> ✔ dials     0.0.2       ✔ rsample   0.0.5
#> ✔ infer     0.4.0.1     ✔ yardstick 0.0.3
#> ✔ parsnip   0.0.3.1
#> ── Conflicts ───────────────────────── tidymodels_conflicts() ──
#> ✖ scales::discard()     masks purrr::discard()
#> ✖ magrittr::extract()   masks tidyr::extract()
#> ✖ dplyr::filter()       masks stats::filter()
#> ✖ recipes::fixed()      masks stringr::fixed()
#> ✖ dplyr::lag()          masks stats::lag()
#> ✖ magrittr::set_names() masks purrr::set_names()
#> ✖ yardstick::spec()     masks readr::spec()
#> ✖ recipes::step()       masks stats::step()
library(glmnet)
#> Loading required package: Matrix
#>
#> Attaching package: 'Matrix'
#> The following object is masked from 'package:tidyr':
#>
#>     expand
#> Loading required package: foreach
#>
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#>
#>     accumulate, when
#> Loaded glmnet 2.0-18

# create a toy dataset
n_obs = 100
n_feats = 200
mat = matrix(NA, nrow = n_obs, ncol = n_feats,
             dimnames = list(paste0("observation_", seq_len(n_obs)),
                             paste0("feature_", seq_len(n_feats))))
mat[] = rnorm(length(c(mat)))

# create labels
labels = runif(n = n_obs, min = 0, max = 3) %>%
  floor() %>%
  factor()

# get optimized penalty with cv.glmnet
penalty = cv.glmnet(mat, labels, nfolds = 3, family = 'multinomial') %>%
  extract2("lambda.1se")

# create classifier
clf = logistic_reg(mixture = 1, penalty = penalty, mode = 'classification') %>%
  set_engine('glmnet', family = 'multinomial')

# fit models in cross-validation
x = as.data.frame(mat) %>%
  mutate(label = labels)
cv = vfold_cv(x, v = 3, strata = 'label')
folded = cv %>%
  mutate(
    recipes = splits %>%
      map(~ prepper(., recipe = recipe(.$data, label ~ .))),
    test_data = splits %>% map(analysis),
    fits = map2(
      recipes,
      test_data,
      ~ fit(
        clf,
        label ~ .,
        data = bake(object = .x, new_data = .y)
      )
    )
  )

# predict on the left-out data
retrieve_predictions = function(split, recipe, model) {
  test = bake(recipe, assessment(split))
  tbl = tibble(
    true = test$label,
    pred = predict(model, test)$.pred_class,
    prob = predict(model, test, type = 'prob')) %>%
    # convert prob from nested df to columns
    cbind(.$prob) %>%
    select(-prob)
  return(tbl)
}
predictions = folded %>%
  mutate(
    pred = list(
      splits,
      recipes,
      fits
    ) %>%
      pmap(retrieve_predictions)
  )
#> Error in attr(x, "names") <- as.character(value): 'names' attribute [3] must be the same length as the vector [2]
Created on 2019-08-21 by the reprex package (v0.3.0)
The issue is caused by this line:
predict(model, test, type = 'prob')
Running traceback() on that line alone gives the following:
10: `names<-.tbl_df`(`*tmp*`, value = value)
9: `names<-`(`*tmp*`, value = value)
8: `colnames<-`(`*tmp*`, value = object$lvl)
7: object$spec$method$pred$prob$post(res, object)
6: predict_classprob.model_fit(object, new_data = new_data, ...)
5: predict_classprob._multnet(object = object, new_data = new_data,
...)
4: predict_classprob(object = object, new_data = new_data, ...)
3: predict.model_fit(object = object, new_data = new_data, type = type,
opts = opts)
2: predict._multnet(model, test, type = "prob")
1: predict(model, test, type = "prob")
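Frame 8 of the traceback assigns object$lvl as column names, and the error message reports three names for a two-element vector. The following is a minimal base R sketch of that mismatch only, not parsnip's actual code: the data frame `res` and the level vector `lvl` are hypothetical stand-ins for the prediction result and the outcome levels, and the two-column shape of `res` is an assumption read off the error message.

```r
# Hypothetical stand-in for the probability predictions: two columns.
res <- data.frame(a = c(0.2, 0.7), b = c(0.8, 0.3))

# Hypothetical stand-in for object$lvl: three outcome levels.
lvl <- c("0", "1", "2")

# Assigning three names to two columns fails; capture the message
# instead of stopping the script.
err <- tryCatch(
  {
    colnames(res) <- lvl
    NULL
  },
  error = function(e) conditionMessage(e)
)

print(err)
```

If this reading is right, the question becomes why the prediction step hands back two columns when the outcome has three levels.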
(As a side note, I am having trouble interpreting the output returned by predict(model, test)$.pred_class. In most tidymodels predict functions, this is an object with length equal to that of test$label, but in this example it is three times as long. Why is this?)