tabnet_explain

Interpretation metrics from a TabNet model


Description

Computes interpretation metrics, namely feature importance masks, for the observations in new_data using a fitted TabNet model.

Usage

tabnet_explain(object, new_data)

Arguments

object

a TabNet fit object, such as one returned by tabnet_fit().

new_data

a data.frame of predictors for which to compute the interpretation metrics.

Value

Returns a list with the following elements:

  • M_explain: the aggregated feature importance masks, as described in the TabNet paper.

  • masks: a list containing the feature selection mask for each decision step.
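To make the aggregation behind M_explain more concrete, here is a minimal base-R sketch of the scheme described in the TabNet paper: each step's mask M[i] is weighted, per observation b, by that step's contribution eta_b[i] (in the paper, the sum of the ReLU-activated decision outputs of step i); the weighted masks are summed over steps, and each row is then normalized to sum to one. The helper `aggregate_masks()` and its `masks`/`eta` inputs are hypothetical illustrations, not part of the tabnet package, and the package's own computation may differ in details such as normalization.

```r
# Hypothetical helper (NOT part of the tabnet package): aggregates per-step
# feature selection masks following the weighting scheme in the TabNet paper.
#   masks: list of n x p matrices, one per decision step (M[i])
#   eta:   n x n_steps matrix; column i holds eta_b[i], the weight of step i
#          for each observation b
aggregate_masks <- function(masks, eta) {
  stopifnot(is.list(masks), is.matrix(eta), length(masks) == ncol(eta))
  # multiply each row b of mask i by eta_b[i]; eta[, i] is a plain vector of
  # length n, so R recycles it down the columns, i.e. row-wise scaling
  weighted <- lapply(seq_along(masks), function(i) masks[[i]] * eta[, i])
  # sum the weighted masks over decision steps
  agg <- Reduce(`+`, weighted)
  # normalize each observation so its importances sum to one
  # (a row whose total weight is zero would produce NaN)
  agg / rowSums(agg)
}

# Usage on toy data: two steps, four observations, three features
set.seed(1)
n <- 4
p <- 3
toy_mask <- function() {
  m <- matrix(runif(n * p), n, p)
  m / rowSums(m)  # rows sum to one, loosely mimicking sparsemax output
}
masks <- list(toy_mask(), toy_mask())
eta <- matrix(runif(n * 2), n, 2)
agg <- aggregate_masks(masks, eta)
print(rowSums(agg))
```

Weighting by eta_b[i] lets steps that contribute more to the final decision dominate the aggregate, which is the motivation given in the paper for not simply averaging the step masks.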

Examples

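library(tabnet)

if (torch::torch_is_installed()) {
  set.seed(2021)

  # simulate three uniform predictors; the outcome depends only on x
  n <- 1000
  x <- data.frame(
    x = runif(n),
    y = runif(n),
    z = runif(n)
  )
  y <- x$x

  # fit a small TabNet model with a single decision step
  fit <- tabnet_fit(x, y, epochs = 20,
                    num_steps = 1,
                    batch_size = 512,
                    attention_width = 1,
                    num_shared = 1,
                    num_independent = 1)

  # compute interpretation metrics on the training data
  ex <- tabnet_explain(fit, x)
}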
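A common follow-up is to collapse the per-observation importances into one score per predictor. The sketch below refits the example model and averages each column of M_explain; it assumes that M_explain can be coerced with as.matrix() to a numeric matrix with one column per predictor, which this page does not state explicitly. The name `importance` is illustrative only.

```r
# Sketch: summarize per-observation importances from tabnet_explain() into a
# single global score per predictor. ASSUMPTION: as.matrix(ex$M_explain)
# yields an n x p numeric matrix with one column per predictor.
if (requireNamespace("tabnet", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  library(tabnet)
  set.seed(2021)

  # same simulated data and model as in the example above
  n <- 1000
  x <- data.frame(x = runif(n), y = runif(n), z = runif(n))
  y <- x$x
  fit <- tabnet_fit(x, y, epochs = 20,
                    num_steps = 1,
                    batch_size = 512,
                    attention_width = 1,
                    num_shared = 1,
                    num_independent = 1)
  ex <- tabnet_explain(fit, x)

  # average each predictor's importance over all observations
  importance <- colMeans(as.matrix(ex$M_explain))
  # list predictors from most to least important
  print(sort(importance, decreasing = TRUE))
}
```

Because the outcome in this example is simply x$x, one would expect the x column to receive most of the importance, although a short 20-epoch fit does not guarantee it.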

tabnet

Fit 'TabNet' Models for Classification and Regression

Version: v0.1.0
License: MIT + file LICENSE
Authors: Daniel Falbel [aut, cre], RStudio [cph]
Initial release
