NNpredict

NNpredict function


Description

A function to produce predictions from a trained network

Usage

NNpredict(
  net,
  param,
  newdata,
  newtruth = NULL,
  freq = 1000,
  record = FALSE,
  plot = FALSE
)

Arguments

net

an object of class network, see ?network

param

vector of trained parameters from the network, see ?train

newdata

input data to be predicted, a list of vectors (i.e. ragged array)

newtruth

the truth, a list of vectors to compare with output from the feed-forward network

freq

how often to print progress updates to the console; the default is every 1000th data point

record

logical, whether to record details of the prediction. Default is FALSE

plot

logical, whether to produce diagnostic plots. Default is FALSE
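
The newdata and newtruth arguments are both lists with one vector per observation. The sketch below shows one way to build such lists from a feature matrix and integer labels using base R only. The objects X and labels and the helper one_hot are hypothetical and exist only for illustration. The one-hot coding of the truth follows the pattern used in Example 2.

```r
# Hypothetical toy data: 4 observations, 3 features, class labels in 1..2
X <- matrix(c(0.1, 0.2, 0.3,
              0.4, 0.5, 0.6,
              0.7, 0.8, 0.9,
              1.0, 1.1, 1.2), ncol = 3, byrow = TRUE)
labels <- c(1L, 2L, 2L, 1L)
n_classes <- 2L

# One numeric vector per observation (row of X)
newdata <- lapply(seq_len(nrow(X)), function(i) X[i, ])

# Hypothetical helper: one-hot vector of length n with a 1 in position k
one_hot <- function(k, n) {
    v <- numeric(n)
    v[k] <- 1
    v
}

# One one-hot vector per observation
newtruth <- lapply(labels, one_hot, n = n_classes)
```

Lists built this way have the same shape as d and truth in Example 2, so they could be passed as newdata and newtruth for a network whose input and output dimensions match.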

Value

If record is FALSE, the output of the neural network is returned. Otherwise, a list of objects is returned, including: rec, the predicted probabilities; err, the L1 error between truth and prediction; pred, the predicted categories based on maximum probability; pred_MC, the predicted categories based on maximum probability; truth, the object newtruth converted to integer class numbers
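
To make the recorded quantities concrete, the following base R sketch shows how values of this kind could be derived from a list of predicted probabilities and a one-hot truth. It is an illustration under assumptions, not the package's internal code: the probabilities are made up, and whether NNpredict stores the L1 error per observation or aggregated is not stated here, so the per-observation form is an assumption.

```r
# Hypothetical softmax outputs for three observations and two classes
rec <- list(c(0.9, 0.1), c(0.3, 0.7), c(0.6, 0.4))

# Matching one-hot truth vectors
truth_onehot <- list(c(1, 0), c(0, 1), c(0, 1))

# Predicted category: index of the largest probability
pred_class <- sapply(rec, which.max)

# Truth as an integer class number: index of the 1 in each one-hot vector
truth_class <- sapply(truth_onehot, which.max)

# L1 error per observation (assumed form): sum of absolute differences
l1_err <- mapply(function(p, t) sum(abs(p - t)), rec, truth_onehot)

# Proportion of observations classified correctly
accuracy <- mean(pred_class == truth_class)
```

With these made-up inputs, the third observation is misclassified, so its L1 error is the largest of the three.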

References

  1. Ian Goodfellow, Yoshua Bengio, Aaron Courville, Francis Bach. Deep Learning. (2016)

  2. Terrence J. Sejnowski. The Deep Learning Revolution (The MIT Press). (2018)

  3. Neural Networks YouTube playlist by 3Blue1Brown: https://www.youtube.com/playlist?list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi

  4. http://neuralnetworksanddeeplearning.com/

Examples

# Example in context:

library(deepNN)

download_mnist("mnist.RData") # only need to download once
load("mnist.RData") # loads objects train_set, truth, test_set and test_truth

net <- network( dims = c(784,16,16,10),
                activ=list(ReLU(),ReLU(),softmax()))

netwts <- train(dat=train_set,
                truth=truth,
                net=net,
                eps=0.001,
                tol=0.8, # normally would use a higher tol here e.g. 0.95
                loss=multinomial(),
                batchsize=100)

pred <- NNpredict(  net=net,
                    param=netwts$opt,
                    newdata=test_set,
                    newtruth=test_truth,
                    record=TRUE,
                    plot=TRUE)


# Example 2

set.seed(1) # make the simulated data reproducible
N <- 1000
d <- matrix(rnorm(5*N),ncol=5)

fun <- function(x){
    lp <- 2*x[2]
    pr <- exp(lp) / (1 + exp(lp))
    ret <- c(0,0)
    ret[1+rbinom(1,1,pr)] <- 1
    return(ret)
}

d <- lapply(1:N,function(i){return(d[i,])})

truth <- lapply(d,fun)

net <- network( dims = c(5,10,2),
                activ=list(ReLU(),softmax()))

netwts <- train( dat=d,
                 truth=truth,
                 net=net,
                 eps=0.01,
                 tol=100,            # run for 100 iterations (not enough for convergence)
                 batchsize=10,
                 loss=multinomial(),
                 stopping="maxit")

pred <- NNpredict(  net=net,
                    param=netwts$opt,
                    newdata=d,
                    newtruth=truth,
                    record=TRUE,
                    plot=TRUE)
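
# A possible follow-up once predictions have been recorded is to cross-tabulate predicted and true classes. The sketch below uses made-up integer vectors standing in for the pred and truth components described under Value. It does not call NNpredict, and the vector names are hypothetical.

```r
# Hypothetical predicted and true class numbers for five observations
pred_class  <- c(1L, 2L, 2L, 1L, 2L)
truth_class <- c(1L, 2L, 1L, 1L, 2L)

# Confusion table: rows are true classes, columns are predicted classes
confusion <- table(truth = truth_class, predicted = pred_class)
print(confusion)

# Overall accuracy: correctly classified observations lie on the diagonal
acc <- sum(diag(confusion)) / sum(confusion)
```

# Off-diagonal counts show which classes are confused with one another, which a single accuracy figure hides.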

deepNN

Deep Learning

v1.0
GPL-3
Authors
Benjamin Taylor [aut, cre]
Initial release
