Interpretation metrics from a TabNet model
tabnet_explain(object, new_data)
| Argument | Description |
|---|---|
| object | a TabNet fit object |
| new_data | a data.frame for which to obtain interpretation metrics |
Returns a list with:
M_explain: the aggregated feature importance masks, as detailed in the TabNet paper.
masks: a list containing the masks for each step.
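For reference, the TabNet paper aggregates the per-step masks as follows. For observation b, step i and feature j, M_{b,j}[i] is the mask value, d_{b,c}[i] is the c-th component of that step's decision output (of width N_d), N_steps is the number of steps and D the number of features. Each step is weighted by its total ReLU-activated decision output, and the weighted masks are normalized across features. This page does not say whether `tabnet_explain()` applies the final normalization, so read the denominator as the paper's definition rather than a guarantee about `M_explain`.

$$
\eta_b[i] = \sum_{c=1}^{N_d} \mathrm{ReLU}\left(d_{b,c}[i]\right)
$$

$$
M_{\mathrm{agg},b,j} = \frac{\sum_{i=1}^{N_{\mathrm{steps}}} \eta_b[i]\, M_{b,j}[i]}{\sum_{j=1}^{D} \sum_{i=1}^{N_{\mathrm{steps}}} \eta_b[i]\, M_{b,j}[i]}
$$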
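library(tabnet)

if (torch::torch_is_installed()) {
  set.seed(2021)
  n <- 1000
  x <- data.frame(
    x = runif(n),
    y = runif(n),
    z = runif(n)
  )
  # the outcome depends only on the `x` column
  y <- x$x
  # a deliberately small network (one decision step) so the example runs quickly
  fit <- tabnet_fit(x, y, epochs = 20,
                    num_steps = 1,
                    batch_size = 512,
                    attention_width = 1,
                    num_shared = 1,
                    num_independent = 1)
  # interpretation metrics for the training data
  ex <- tabnet_explain(fit, x)
}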
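To make the mask aggregation described for `M_explain` concrete without fitting a model, the base-R sketch below applies the TabNet paper's rule to small toy arrays: each step's mask row is weighted by the sum of that step's ReLU-activated decision outputs, the weighted masks are summed over steps, and rows are optionally normalized to sum to 1. The inputs and the helper `aggregate_masks()` are hypothetical illustrations, not objects returned by `tabnet_explain()` or functions exported by the tabnet package, and the package's internal computation may differ (for example, in whether it normalizes).

```r
# Hypothetical toy inputs standing in for a fitted network's internals:
# 4 observations, 3 features, 2 decision steps, decision-output width 2.
set.seed(1)
n <- 4
D <- 3
n_steps <- 2
n_d <- 2

# Per-step feature masks M_b[i, j]; each row sums to 1, as a sparsemax output would.
masks <- lapply(seq_len(n_steps), function(s) {
  m <- matrix(runif(n * D), n, D)
  m / rowSums(m)
})

# Per-step decision outputs d_b[i, c] before the ReLU.
d_out <- lapply(seq_len(n_steps), function(s) matrix(rnorm(n * n_d), n, n_d))

relu <- function(z) pmax(z, 0)

# Hypothetical helper: weight each step's mask by its step importance and sum.
aggregate_masks <- function(masks, d_out, normalize = TRUE) {
  agg <- matrix(0, nrow(masks[[1]]), ncol(masks[[1]]))
  for (s in seq_along(masks)) {
    # eta: total ReLU-activated decision output of step s, one value per observation
    eta <- rowSums(relu(d_out[[s]]))
    # multiplying an n x D matrix by a length-n vector scales each row by eta
    agg <- agg + masks[[s]] * eta
  }
  if (normalize) {
    totals <- rowSums(agg)
    # leave rows untouched when no step contributed (all-zero importance)
    agg <- agg / ifelse(totals > 0, totals, 1)
  }
  agg
}

m_agg <- aggregate_masks(masks, d_out)
round(m_agg, 3)
```

Each row of `m_agg` gives one observation's relative feature importance; averaging the rows (for example with `colMeans(m_agg)`) yields a dataset-level summary.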