train function
A function to train a neural network defined using the network function.
Usage:

train(
  dat,
  truth,
  net,
  loss = Qloss(),
  tol = 0.95,
  eps = 0.001,
  batchsize = NULL,
  dropout = dropoutProbs(),
  parinit = function(n) { return(runif(n, -0.01, 0.01)) },
  monitor = TRUE,
  stopping = "default",
  update = "classification"
)
Arguments:

dat        the input data, a list of vectors

truth      the truth, a list of vectors to compare with the output of the
           feed-forward network

net        an object of class network; see ?network

loss       the loss function; see ?Qloss and ?multinomial

tol        stopping tolerance for training. The current method monitors the
           quality of randomly chosen predictions from the data and terminates
           when the mean predictive probability of the last 20 randomly chosen
           points exceeds tol; default is 0.95

eps        step-size scaling constant for gradient descent or stochastic
           gradient descent

batchsize  size of the mini-batches used with stochastic gradient descent

dropout    optional list of dropout probabilities; see ?dropoutProbs

parinit    a function of a single argument n returning n initial weights;
           the default draws from a uniform distribution on (-0.01, 0.01)

monitor    logical; whether to produce learning/convergence diagnostic plots

stopping   method for stopping computation; the default, "stopping.default",
           is selected by stopping = "default"

update     method for updating the stopping criteria; the default,
           "updateStopping.classification", is selected by
           update = "classification"
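As an illustration of the parinit argument, a custom initialiser can be supplied in place of the default uniform draw. The sketch below is hypothetical: the function name and the standard deviation are arbitrary illustrative choices, not part of the package.

```r
# A sketch of a custom weight initialiser for the parinit argument.
# parinit must be a function of a single argument n that returns n
# initial weights; here we draw from a narrow normal distribution
# instead of the default uniform on (-0.01, 0.01).
narrow_normal_init <- function(n) {
    return(rnorm(n, mean = 0, sd = 0.005))
}
```

This would then be passed to train() via, e.g., train(..., parinit = narrow_normal_init).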
Value:

The optimal cost and parameters from the trained network. At present,
diagnostic plots are also produced illustrating the parameters of the model,
the gradient, and the stopping-criteria trace.
References:

Ian Goodfellow, Yoshua Bengio, Aaron Courville, Francis Bach. Deep Learning. (2016)

Terrence J. Sejnowski. The Deep Learning Revolution (The MIT Press). (2018)

Neural Networks YouTube playlist by 3Blue1Brown: https://www.youtube.com/playlist?list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi

http://neuralnetworksanddeeplearning.com/
Examples:

# Example in context:

download_mnist("mnist.RData")  # only need to download once
load("mnist.RData")            # loads objects train_set, truth, test_set and test_truth

net <- network(dims = c(784, 16, 16, 10),
               activ = list(ReLU(), ReLU(), softmax()))

netwts <- train(dat = train_set,
                truth = truth,
                net = net,
                eps = 0.001,
                tol = 0.8,  # normally would use a higher tol here e.g. 0.95
                loss = multinomial(),
                batchsize = 100)

pred <- NNpredict(net = net,
                  param = netwts$opt,
                  newdata = test_set,
                  newtruth = test_truth,
                  record = TRUE,
                  plot = TRUE)

# Example 2

N <- 1000
d <- matrix(rnorm(5 * N), ncol = 5)

fun <- function(x) {
    lp <- 2 * x[2]
    pr <- exp(lp) / (1 + exp(lp))
    ret <- c(0, 0)
    ret[1 + rbinom(1, 1, pr)] <- 1
    return(ret)
}

d <- lapply(1:N, function(i) { return(d[i, ]) })
truth <- lapply(d, fun)

net <- network(dims = c(5, 10, 2),
               activ = list(ReLU(), softmax()))

netwts <- train(dat = d,
                truth = truth,
                net = net,
                eps = 0.01,
                tol = 100,             # run for 100 iterations
                batchsize = 10,        # note this is not enough
                loss = multinomial(),  # for convergence
                stopping = "maxit")

pred <- NNpredict(net = net,
                  param = netwts$opt,
                  newdata = d,
                  newtruth = truth,
                  record = TRUE,
                  plot = TRUE)