🤖 AI Summary
Decision tree learning for tabular data faces challenges including an exponentially large search space, the limitations of greedy optimization, and the lack of suitable structural priors in deep models. Method: This paper is the first to formalize decision tree learning as an amortized structural inference and sequential planning problem, proposing a Bayesian posterior-driven deep reinforcement learning framework based on Generative Flow Networks (GFlowNets). Contribution/Results: The method enables interpretable decision tree generation and ensemble learning, achieving high predictive accuracy while significantly improving out-of-distribution generalization and anomaly detection. It surpasses state-of-the-art decision tree and deep learning methods on multiple real-world tabular classification benchmarks. Crucially, performance improves consistently with ensemble size, marking the first approach to jointly optimize structural interpretability, robustness, and predictive performance in a unified framework.
📝 Abstract
Building predictive models for tabular data presents fundamental challenges, notably in scaling consistently, i.e., more resources translating to better performance, and in generalizing systematically beyond the training data distribution. Designing decision tree models remains especially challenging given the intractably large search space, and most existing methods rely on greedy heuristics, while deep learning inductive biases expect a temporal or spatial structure not naturally present in tabular data. We propose a hybrid amortized structure inference approach to learn predictive decision tree ensembles given data, formulating decision tree construction as a sequential planning problem. We train a deep reinforcement learning (GFlowNet) policy to solve this problem, yielding a generative model that samples decision trees from the Bayesian posterior. We show that our approach, DT-GFN, outperforms state-of-the-art decision tree and deep learning methods on standard classification benchmarks derived from real-world data, in robustness to distribution shifts, and in anomaly detection, all while yielding interpretable models with shorter description lengths. Samples from the trained DT-GFN model can be ensembled to construct a random forest, and we further show that performance scales consistently with ensemble size, yielding ensembles of predictors that continue to generalize systematically.
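The core idea, building a decision tree as a sequence of actions from a learned stochastic policy and then ensembling many sampled trees into a forest, can be illustrated with a minimal sketch. This is not the paper's implementation: the random "policy" below is a stand-in for a trained GFlowNet sampler, and the tree representation, `sample_tree`, `predict`, and `ensemble_predict` are hypothetical names chosen for illustration.

```python
import random
from collections import Counter

def sample_tree(rng, features, max_depth=2, depth=0):
    """Construct a tree sequentially: at each state the 'policy' either
    terminates (emit a leaf) or takes a split action. A trained GFlowNet
    would make these choices so that trees are sampled in proportion to
    their Bayesian posterior probability; here we use uniform randomness."""
    if depth >= max_depth or rng.random() < 0.3:
        return {"leaf": rng.choice([0, 1])}  # terminal action: predict a class
    return {
        "feat": rng.choice(features),        # action: choose a feature to split on
        "thr": rng.random(),                 # action: choose a threshold
        "lo": sample_tree(rng, features, max_depth, depth + 1),
        "hi": sample_tree(rng, features, max_depth, depth + 1),
    }

def predict(tree, x):
    """Route an input down the tree to a leaf label."""
    while "leaf" not in tree:
        tree = tree["lo"] if x[tree["feat"]] <= tree["thr"] else tree["hi"]
    return tree["leaf"]

def ensemble_predict(trees, x):
    """Random-forest-style aggregation: majority vote over sampled trees."""
    votes = Counter(predict(t, x) for t in trees)
    return votes.most_common(1)[0][0]

rng = random.Random(0)
forest = [sample_tree(rng, features=["f0", "f1"]) for _ in range(25)]
print(ensemble_predict(forest, {"f0": 0.4, "f1": 0.9}))  # prints 0 or 1
```

Because each tree is an independent sample from the (here, toy) posterior, growing the ensemble simply means drawing more trees, which is the mechanism behind the abstract's claim that performance scales with ensemble size.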