tabnet_fit

Tabnet model


Description

Usage

tabnet_fit(x, ...)

## Default S3 method:
tabnet_fit(x, ...)

## S3 method for class 'data.frame'
tabnet_fit(x, y, ...)

## S3 method for class 'formula'
tabnet_fit(formula, data, ...)

## S3 method for class 'recipe'
tabnet_fit(x, data, ...)

Arguments

x

Depending on the context:

  • A data frame of predictors.

  • A matrix of predictors.

  • A recipe specifying a set of preprocessing steps created from recipes::recipe().

The predictor data should be standardized (e.g. centered and scaled). The model handles categorical predictors internally, so you don't need to encode or otherwise preprocess them.
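As a minimal sketch of the standardization mentioned above, the numeric columns of a data frame can be centered and scaled with base R's scale() while factor columns are left untouched. This assumes plain z-scoring is acceptable; the package does not prescribe a particular method. The helper name standardize_numeric and the toy data frame are hypothetical.

```r
# Hypothetical helper: center and scale numeric columns, leave others as-is.
standardize_numeric <- function(df) {
  is_num <- vapply(df, is.numeric, logical(1))
  # scale() returns a one-column matrix, so convert back to a plain vector.
  df[is_num] <- lapply(df[is_num], function(col) as.numeric(scale(col)))
  df
}

# Toy predictors (made up for illustration): two numeric columns, one factor.
x <- data.frame(
  size  = c(10, 20, 30, 40),
  rooms = c(1, 2, 2, 3),
  zone  = factor(c("a", "b", "a", "c"))
)

x_std <- standardize_numeric(x)
```

The factor column zone is passed through unchanged, in line with the note that categorical predictors need no manual treatment.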

...

Model hyperparameters. See tabnet_config() for a list of all possible hyperparameters.

y

When x is a data frame or matrix, y is the outcome specified as:

  • A data frame with 1 numeric column.

  • A matrix with 1 numeric column.

  • A numeric vector.
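The x/y interface described above might be used as in the following sketch. It reuses the ames data from the Examples section and passes the outcome as a numeric vector; the object names x, y, and fit_xy are illustrative only.

```r
if (torch::torch_is_installed()) {
  library(tabnet)
  data("ames", package = "modeldata")

  # Predictors: every column except the outcome.
  x <- ames[, setdiff(names(ames), "Sale_Price")]
  # Outcome: a numeric vector.
  y <- ames$Sale_Price

  # A single epoch keeps the run short; real fits need more.
  fit_xy <- tabnet_fit(x, y, epochs = 1)
}
```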

formula

A formula specifying the outcome terms on the left-hand side, and the predictor terms on the right-hand side.

data

When a recipe or formula is used, data is specified as:

  • A data frame containing both the predictors and the outcome.
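For the recipe interface, a sketch along the following lines may help. It assumes the recipes package is installed and uses step_normalize() to standardize the numeric predictors; the object names rec and fit_rec are illustrative only.

```r
if (torch::torch_is_installed()) {
  library(tabnet)
  library(recipes)
  data("ames", package = "modeldata")

  # Declare the outcome and predictors, then standardize numeric predictors.
  rec <- recipe(Sale_Price ~ ., data = ames) %>%
    step_normalize(all_numeric(), -all_outcomes())

  # data holds both the predictors and the outcome.
  fit_rec <- tabnet_fit(rec, data = ames, epochs = 1)
}
```

Putting the standardization in the recipe means the same preprocessing is reapplied to new data at prediction time.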

Value

A TabNet model object. It can be used for serialization and predictions.
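A rough sketch of both uses follows. It assumes the returned object works with the generic predict() and survives a saveRDS()/readRDS() round trip; the object names are illustrative and no particular output is implied.

```r
if (torch::torch_is_installed()) {
  library(tabnet)
  data("ames", package = "modeldata")
  fit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 1)

  # Predictions on a data frame with the same predictor columns.
  preds <- predict(fit, ames)

  # Serialization: write the model to disk and read it back.
  path <- tempfile(fileext = ".rds")
  saveRDS(fit, path)
  fit_restored <- readRDS(path)

  # The restored model should be usable for prediction as well.
  preds_restored <- predict(fit_restored, ames)
}
```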

Threading

TabNet uses torch as its backend for computation, and torch uses all available threads by default.

You can control the number of threads used by torch with:

torch::torch_set_num_threads(1)
torch::torch_set_num_interop_threads(1)

Examples

if (torch::torch_is_installed()) {
  data("ames", package = "modeldata")
  fit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 1)
}

tabnet

Fit 'TabNet' Models for Classification and Regression

v0.1.0
MIT + file LICENSE
Authors
Daniel Falbel [aut, cre], RStudio [cph]
Initial release