
predict.nn

Neural network prediction


Description

Computes predictions from an artificial neural network of class nn, as produced by neuralnet().

Usage

## S3 method for class 'nn'
predict(object, newdata, rep = 1, all.units = FALSE, ...)
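The "S3 method for class 'nn'" note means that predict() is a generic. Calling predict(object, ...) on an object whose class is "nn" dispatches to predict.nn(). The base-R sketch below shows the same mechanism with a hypothetical class `toy_nn` and a hypothetical method `predict.toy_nn()`. Neither is part of neuralnet; they only illustrate how dispatch works.

```r
# Hypothetical constructor for a toy "model" that stores a slope and intercept.
# It only illustrates S3 dispatch and is not part of the neuralnet package.
new_toy_nn <- function(slope, intercept) {
  structure(list(slope = slope, intercept = intercept), class = "toy_nn")
}

# S3 method: predict() dispatches here for objects of class "toy_nn",
# in the same way it dispatches to predict.nn() for objects of class "nn"
predict.toy_nn <- function(object, newdata, ...) {
  x <- as.matrix(newdata)
  # Return a one-column matrix, mirroring the one-column-per-output-unit layout
  matrix(object$intercept + object$slope * x[, 1], ncol = 1)
}

model <- new_toy_nn(slope = 2, intercept = 1)
pred <- predict(model, data.frame(x = c(0, 1, 2)))
print(pred)
```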

Arguments

object

Neural network of class nn.

newdata

New data of class data.frame or matrix.

rep

Integer indicating which repetition (training run) of the neural network should be used.

all.units

Return output for all units instead of final output only.

...

Further arguments passed to or from other methods.
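When neuralnet() is trained with several repetitions (its own rep argument), the rep argument of predict() selects which set of weights is used. The default of 1 is not necessarily the best run. One common choice is the repetition with the lowest training error, although training error alone can favor overfitted runs. In a fitted object, nn$result.matrix has one column per repetition and a row named "error". The sketch below uses a hypothetical stand-in matrix with invented values so that it runs without fitting a network.

```r
# Hypothetical stand-in for nn$result.matrix: one column per repetition,
# with rows named as in a fitted neuralnet object. The numbers are invented.
result_matrix <- matrix(
  c(0.85, 0.009, 1200,
    0.42, 0.008, 3400,
    0.61, 0.010,  900),
  nrow = 3,
  dimnames = list(c("error", "reached.threshold", "steps"),
                  paste0("rep", 1:3))
)

# Index of the repetition with the smallest training error
best_rep <- unname(which.min(result_matrix["error", ]))
print(best_rep)

# With a real fit, the index would then be passed on, e.g.:
# pred <- predict(nn, newdata, rep = best_rep)
```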

Value

Matrix of predictions, with one row per observation in newdata and one column per output unit. If all.units = TRUE, a list of matrices containing the output of all units is returned instead.
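The shape of the returned matrix follows from a forward pass through the fitted weights. The base-R sketch below makes several assumptions:

- Every layer uses logistic activation, as with linear.output = FALSE.
- The first row of each weight matrix holds the bias weights.
- Each column of a weight matrix corresponds to one unit in the next layer.

The function `forward_pass` and all weights are hypothetical, and this is not neuralnet's own code. Setting `all_units = TRUE` collects each layer's output in a list, which is analogous to all.units = TRUE.

```r
# Logistic (sigmoid) activation
logistic <- function(x) 1 / (1 + exp(-x))

# Hypothetical forward pass: 'weights' is a list of matrices, one per layer.
# Row 1 of each matrix holds the bias weights, the remaining rows hold the
# connection weights, and each column corresponds to one unit of the next layer.
forward_pass <- function(weights, x, all_units = FALSE) {
  out <- as.matrix(x)
  layers <- list(out)  # layer outputs, starting with the inputs
  for (w in weights) {
    out <- logistic(cbind(1, out) %*% w)  # prepend a column of 1s for the bias
    layers <- c(layers, list(out))
  }
  if (all_units) layers else out
}

# Invented weights: 2 inputs -> 3 hidden units -> 2 output units
w_hidden <- matrix(c(0.1, 0.5, -0.2, 0.2, -0.4, 0.3, -0.3, 0.6, 0.1), nrow = 3, ncol = 3)
w_out <- matrix(c(0.2, -0.1, 0.4, 0.3, -0.2, 0.5, -0.3, 0.1), nrow = 4, ncol = 2)

# Two observations with two input variables each
newdata <- matrix(c(1.4, 4.5, 0.2, 1.5), nrow = 2)

pred <- forward_pass(list(w_hidden, w_out), newdata)
print(dim(pred))       # rows = observations, columns = output units

units <- forward_pass(list(w_hidden, w_out), newdata, all_units = TRUE)
print(length(units))   # inputs, hidden layer, output layer
```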

Author(s)

Marvin N. Wright

Examples

library(neuralnet)

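# Split data into training (2/3) and test (1/3) sets; the seed makes both the
# split and the random initial weights reproducible
set.seed(42)
train_idx <- sample(nrow(iris), 2/3 * nrow(iris))
iris_train <- iris[train_idx, ]
iris_test <- iris[-train_idx, ]

# Binary classification: setosa vs. the other species, with a logistic output unit
nn <- neuralnet(Species == "setosa" ~ Petal.Length + Petal.Width, iris_train, linear.output = FALSE)
pred <- predict(nn, iris_test)
# Threshold the single output column at 0.5 and cross-tabulate against the truth
table(iris_test$Species == "setosa", pred[, 1] > 0.5)

# Multiclass classification: one output unit per species
nn <- neuralnet((Species == "setosa") + (Species == "versicolor") + (Species == "virginica")
                 ~ Petal.Length + Petal.Width, iris_train, linear.output = FALSE)
pred <- predict(nn, iris_test)
# Assign each test row to the output unit with the largest value
table(iris_test$Species, apply(pred, 1, which.max))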
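The two examples read their predictions in different ways. The binary network has a single output column, which is thresholded at 0.5. The multiclass network has one column per class, and each row is assigned to the column with the largest value. The base-R sketch below turns both kinds of prediction matrix into labels and an accuracy score. The matrices and true labels are invented for illustration rather than produced by neuralnet(). The sketch also assumes that the multiclass columns follow the order of the response terms in the formula.

```r
# Invented binary predictions: one output column with values in (0, 1)
pred_bin <- matrix(c(0.97, 0.03, 0.88, 0.40), ncol = 1)
truth_bin <- c(TRUE, FALSE, TRUE, TRUE)
label_bin <- pred_bin[, 1] > 0.5
acc_bin <- mean(label_bin == truth_bin)

# Invented multiclass predictions: one row per observation, one column per
# class, assumed to be in the same order as the response terms in the formula
classes <- c("setosa", "versicolor", "virginica")
pred_multi <- matrix(c(0.90, 0.08, 0.01,
                       0.05, 0.70, 0.30,
                       0.02, 0.45, 0.60), ncol = 3, byrow = TRUE)
truth_multi <- c("setosa", "virginica", "virginica")

# Map the index of the largest output in each row back to a class name
label_multi <- classes[apply(pred_multi, 1, which.max)]
acc_multi <- mean(label_multi == truth_multi)

print(table(truth_multi, label_multi))
print(c(binary = acc_bin, multiclass = acc_multi))
```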

neuralnet: Training of Neural Networks

Version: 1.44.2
License: GPL (>= 2)
Authors: Stefan Fritsch [aut], Frauke Guenther [aut], Marvin N. Wright [aut, cre], Marc Suling [ctb], Sebastian M. Mueller [ctb]
Initial release: 2019-02-07
