`predict()` returns `.pred_class` as a factor, but `multi_predict()` returns it as a character vector:
```r
library(tidymodels)
#> Registered S3 method overwritten by 'xts':
#>   method     from
#>   as.zoo.xts zoo
#> ── Attaching packages ───────────────────────────────────────────────────────────────────────────────────────── tidymodels 0.0.3 ──
#> ✔ broom     0.5.2          ✔ purrr     0.3.3     
#> ✔ dials     0.0.3.9001     ✔ recipes   0.1.7.9001
#> ✔ dplyr     0.8.3          ✔ rsample   0.0.5     
#> ✔ ggplot2   3.2.1          ✔ tibble    2.1.3     
#> ✔ infer     0.5.0          ✔ yardstick 0.0.4     
#> ✔ parsnip   0.0.3.9001
#> ── Conflicts ──────────────────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard()  masks scales::discard()
#> ✖ dplyr::filter()   masks stats::filter()
#> ✖ dplyr::lag()      masks stats::lag()
#> ✖ ggplot2::margin() masks dials::margin()
#> ✖ dials::offset()   masks stats::offset()
#> ✖ recipes::step()   masks stats::step()
library(tune)
library(glmnet)
#> Loading required package: Matrix
#> 
#> Attaching package: 'Matrix'
#> The following objects are masked from 'package:tidyr':
#> 
#>     expand, pack, unpack
#> Loaded glmnet 3.0
library(mlbench)
data("Satellite")

mod <- multinom_reg() %>% 
  set_engine("glmnet")

fit <- mod %>% fit(classes ~ ., data = Satellite[-(1:10),])

predict(fit, new_data = Satellite[1:10, -37], penalty = .01)
#> # A tibble: 10 x 1
#>    .pred_class   
#>    <fct>         
#>  1 grey soil     
#>  2 grey soil     
#>  3 grey soil     
#>  4 grey soil     
#>  5 grey soil     
#>  6 grey soil     
#>  7 grey soil     
#>  8 grey soil     
#>  9 damp grey soil
#> 10 damp grey soil

multi_predict(fit, new_data = Satellite[1:10, -37], penalty = c(.1, 1))$.pred[[1]]
#> # A tibble: 2 x 2
#>   .pred_class penalty
#>   <chr>         <dbl>
#> 1 grey soil       0.1
#> 2 red soil        1
```
Created on 2019-10-23 by the reprex package (v0.3.0)