NNpredict function
Description:

Produce predictions from a trained neural network.

Usage:

NNpredict(net, param, newdata, newtruth = NULL, freq = 1000,
          record = FALSE, plot = FALSE)
Arguments:

net       an object of class network; see ?network
param     vector of trained parameters from the network; see ?train
newdata   input data to be predicted: a list of vectors (i.e. a ragged array)
newtruth  the truth: a list of vectors to compare with the output of the
          feed-forward network
freq      frequency with which to print progress updates to the console;
          the default is every 1000th data point
record    logical; whether to record details of the prediction. Default is FALSE
plot      logical; whether to produce diagnostic plots. Default is FALSE
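Note that newdata (and newtruth) must be lists of vectors rather than matrices, so data held in a matrix needs converting first. A minimal sketch using only base R, with hypothetical object names:

```r
# Hypothetical 3-observation, 5-feature data held as a matrix:
X <- matrix(rnorm(3 * 5), ncol = 5)

# NNpredict expects a list of input vectors (a ragged array),
# one list element per observation:
newdata <- lapply(1:nrow(X), function(i) X[i, ])

length(newdata)       # 3: one element per observation
length(newdata[[1]])  # 5: each element is a vector of inputs
```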
Value:

If record is FALSE, the output of the neural network is returned. Otherwise a
list of objects is returned, including:

rec       the predicted probabilities
err       the L1 error between truth and prediction
pred      the predicted categories, based on maximum probability
pred_MC   the predicted categories, sampled in proportion to the predicted
          probabilities (Monte Carlo)
truth     the object newtruth, converted to an integer class number
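With record = TRUE the returned list can be post-processed directly; for example, classification accuracy is the proportion of cases where the predicted category matches the truth. A sketch using hypothetical stand-in values, where out plays the role of the list returned by NNpredict:

```r
# Stand-in for the list returned by NNpredict(..., record = TRUE):
out <- list(pred  = c(1, 2, 2, 1),   # predicted class per observation
            truth = c(1, 2, 1, 1))   # integer class labels

# Classification accuracy: proportion of correct predictions
accuracy <- mean(out$pred == out$truth)
accuracy  # 0.75 for these stand-in values
```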
References:

Goodfellow, I., Bengio, Y., Courville, A. (2016). Deep Learning. MIT Press.

Sejnowski, T. J. (2018). The Deep Learning Revolution. MIT Press.

"Neural Networks" YouTube playlist by 3Blue1Brown:
https://www.youtube.com/playlist?list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi

http://neuralnetworksanddeeplearning.com/
See Also:

NNpredict.regression, network, train, backprop_evaluate, MLP_net,
backpropagation_MLP, logistic, ReLU, smoothReLU, ident, softmax, Qloss,
multinomial, NNgrad_test, weights2list, bias2list, biasInit, memInit,
gradInit, addGrad, nnetpar, nbiaspar, addList, no_regularisation,
L1_regularisation, L2_regularisation
Examples:

# Example 1: in context (MNIST)

download_mnist("mnist.RData")  # only need to download once
load("mnist.RData")            # loads objects train_set, truth, test_set and test_truth

net <- network(dims = c(784, 16, 16, 10),
               activ = list(ReLU(), ReLU(), softmax()))

netwts <- train(dat = train_set,
                truth = truth,
                net = net,
                eps = 0.001,
                tol = 0.8,  # normally would use a higher tol here, e.g. 0.95
                loss = multinomial(),
                batchsize = 100)

pred <- NNpredict(net = net,
                  param = netwts$opt,
                  newdata = test_set,
                  newtruth = test_truth,
                  record = TRUE,
                  plot = TRUE)

# Example 2: simulated binary classification

N <- 1000
d <- matrix(rnorm(5 * N), ncol = 5)

# binary outcome depending on the second input via a logistic model
fun <- function(x){
    lp <- 2 * x[2]
    pr <- exp(lp) / (1 + exp(lp))
    ret <- c(0, 0)
    ret[1 + rbinom(1, 1, pr)] <- 1
    return(ret)
}

d <- lapply(1:N, function(i){return(d[i, ])})
truth <- lapply(d, fun)

net <- network(dims = c(5, 10, 2),
               activ = list(ReLU(), softmax()))

netwts <- train(dat = d,
                truth = truth,
                net = net,
                eps = 0.01,
                tol = 100,            # run for 100 iterations;
                batchsize = 10,       # note this is not enough
                loss = multinomial(), # for convergence
                stopping = "maxit")

pred <- NNpredict(net = net,
                  param = netwts$opt,
                  newdata = d,
                  newtruth = truth,
                  record = TRUE,
                  plot = TRUE)