Source: https://github.com/tidymodels/parsnip/issues/174

predict() on an mlp with nnet double-prefixes the output names with `.pred_` (Issue #174, tidymodels/parsnip)

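This problem is similar to an already closed issue (#107), but it occurs with `mlp()` using the `nnet` engine: the class-probability columns returned by `predict(..., type = "prob")` are named `.pred_.pred_bad` and `.pred_.pred_good` instead of `.pred_bad` and `.pred_good`.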
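The doubled names look as if the `.pred_` prefix is applied twice: once somewhere in the nnet-specific post-processing and once in parsnip's generic formatting of probability columns. That is an assumption about parsnip's internals, not something confirmed from its source. The base-R sketch below uses a hypothetical `add_pred_prefix()` helper (not a parsnip function) and made-up probabilities to show that prefixing twice yields exactly the column names reported for the nnet fit in the reprex:

```r
# Hypothetical helper (not parsnip's code): prefix every column name with ".pred_",
# mimicking the naming convention used for class-probability predictions
add_pred_prefix <- function(probs) {
  names(probs) <- paste0(".pred_", names(probs))
  probs
}

# Made-up probabilities for the two Status levels, "bad" and "good"
probs <- data.frame(bad = c(0.6, 0.3), good = c(0.4, 0.7))

once  <- add_pred_prefix(probs)  # one pass: matches the glm column names
twice <- add_pred_prefix(once)   # two passes: matches the nnet column names

names(once)   # ".pred_bad" ".pred_good"
names(twice)  # ".pred_.pred_bad" ".pred_.pred_good"
```

Because each pass pastes `.pred_` onto whatever names are already present, the step is not idempotent, which would explain why a single pass (glm) and an apparent second pass (nnet) disagree.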


```r
library(tidymodels)
#> -- Attaching packages ------------------------------------------------- tidymodels 0.0.2 --
#> v broom     0.5.1       v purrr     0.3.2  
#> v dials     0.0.2       v recipes   0.1.5  
#> v dplyr     0.8.0.1     v rsample   0.0.4  
#> v ggplot2   3.1.0       v tibble    2.1.1  
#> v infer     0.4.0       v yardstick 0.0.3  
#> v parsnip   0.0.2
#> -- Conflicts ---------------------------------------------------- tidymodels_conflicts() --
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()
data(credit_data)

set.seed(7075)
# Stratified 75/25 split on the outcome
data_split <- initial_split(credit_data, strata = "Status", prop = 0.75)

credit_train <- training(data_split)
credit_test  <- testing(data_split)

# Impute missing values with nearest neighbours, dummy-encode the nominal
# predictors (but not the outcome), then center and scale all predictors
credit_rec <- 
  recipe(Status ~ ., data = credit_train) %>%
  step_knnimpute(Home, Job, Marital, Income, Assets, Debt) %>%
  step_dummy(all_nominal(), -Status) %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors()) %>%
  prep(training = credit_train, retain = TRUE)

# Apply the trained recipe to the test set, keeping only the predictors
test_normalized <- bake(credit_rec, new_data = credit_test, all_predictors())

set.seed(57974)
nnet_fit <- 
  mlp(mode = "classification", hidden_units = 10) %>%
  set_engine("nnet") %>%
  fit(Status ~ ., data = juice(credit_rec))

glm_fit <- 
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Status ~ ., data = juice(credit_rec))

# Bug: with the nnet engine the probability columns are prefixed twice
glimpse(predict(nnet_fit, new_data = test_normalized, type = "prob"))
#> Observations: 1,113
#> Variables: 2
#> $ .pred_.pred_bad  <dbl> 0.5608545, 0.7023505, 0.3303682, 0.4221877, 0...
#> $ .pred_.pred_good <dbl> 0.4391455, 0.2976495, 0.6696318, 0.5778123, 0...

# Expected naming, as produced by the glm engine (no issue)
glimpse(predict(glm_fit, new_data = test_normalized, type = "prob"))
#> Observations: 1,113
#> Variables: 2
#> $ .pred_bad  <dbl> 0.04675355, 0.94317298, 0.24316454, 0.06970005, 0.0...
#> $ .pred_good <dbl> 0.95324645, 0.05682702, 0.75683546, 0.93029995, 0.9...
```
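Until the engine is fixed, one possible workaround is to normalize the names after calling `predict()`. The sketch below assumes only that the extra text is a repeated leading `.pred_`; `fix_pred_names()` and `has_expected_prob_names()` are hypothetical helpers, not parsnip functions, and the data frame is a made-up stand-in for the nnet predictions in the reprex:

```r
# Hypothetical helper (not part of parsnip): collapse one or more leading
# ".pred_" prefixes in the column names into a single ".pred_"
fix_pred_names <- function(pred) {
  names(pred) <- sub("^(\\.pred_)+", ".pred_", names(pred))
  pred
}

# Hypothetical check: are the columns named ".pred_<level>", in level order?
has_expected_prob_names <- function(pred, levels) {
  identical(names(pred), paste0(".pred_", levels))
}

# Made-up stand-in for the doubly prefixed nnet probability predictions
nnet_probs <- setNames(
  data.frame(c(0.6, 0.3), c(0.4, 0.7)),
  c(".pred_.pred_bad", ".pred_.pred_good")
)

has_expected_prob_names(nnet_probs, c("bad", "good"))  # FALSE
fixed <- fix_pred_names(nnet_probs)
names(fixed)                                           # ".pred_bad" ".pred_good"
has_expected_prob_names(fixed, c("bad", "good"))       # TRUE
```

With the reprex objects this would be applied as `fix_pred_names(predict(nnet_fit, new_data = test_normalized, type = "prob"))`. Names that already carry a single prefix, such as the glm output, pass through unchanged, so the helper is safe to apply to either model.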
