ML Papers: Why do tree-based models still outperform deep learning on typical tabular data?
In my ML Papers reading group, we discussed the paper "Why do tree-based models still outperform deep learning on typical tabular data?" Below are some notes I took during our discussion.
For text and image datasets, deep learning has shown clear benefits. This paper explores why simpler tree-based models still outperform deep learning models on typical tabular data. The authors investigate this by benchmarking a range of models on diverse datasets, accounting for factors such as hyper-parameter tuning and model training costs.
I found this an interesting read at a time when the risk of over-engineering is growing as the field of machine learning and data science advances.
Experiment setup
Data
The experiments utilize a variety of publicly available tabular datasets. These datasets cover a range of domains, including healthcare, finance, and marketing, to ensure the findings are generalizable across different types of tabular data.
Models
Tree-Based Models: The primary tree-based models evaluated are Random Forests and Gradient Boosting Machines (e.g., XGBoost, LightGBM).
Deep Learning Models: The study includes several deep learning architectures, such as fully connected neural networks (MLPs) and specialized architectures designed for tabular data.
Bayesian deep learning models were not included in the comparison, even though we noted that these usually perform well on tabular data. Neural nets can also be beneficial when working with high-cardinality features, yet high-cardinality datasets were excluded from the experiments. This, too, could have skewed the results in favour of tree-based models.
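As a reference point for the kind of head-to-head comparison the paper runs, here is a minimal sketch in scikit-learn. The dataset and model settings are illustrative assumptions on my part, not the paper's benchmark configuration.

```python
# A minimal gradient-boosting-vs-MLP comparison on a small tabular dataset.
# Dataset and hyper-parameters are illustrative, not the paper's setup.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Tree-based baseline: a gradient boosting machine.
gbm = HistGradientBoostingClassifier(random_state=0).fit(X_train, y_train)

# Deep learning baseline: a plain MLP (feature scaling matters for neural nets).
mlp = make_pipeline(
    StandardScaler(),
    MLPClassifier(hidden_layer_sizes=(128, 128), max_iter=500, random_state=0),
).fit(X_train, y_train)

print("GBM test accuracy:", gbm.score(X_test, y_test))
print("MLP test accuracy:", mlp.score(X_test, y_test))
```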
Evaluation Metrics
The performance of the models is assessed using standard evaluation metrics: accuracy, precision, recall, and F1-score for classification tasks, and mean squared error (MSE) or R-squared for regression tasks.
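For completeness, these metrics map directly onto scikit-learn helpers; the predictions below are hypothetical placeholders.

```python
# The evaluation metrics mentioned above, computed with scikit-learn
# on made-up predictions purely for illustration.
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    mean_squared_error, r2_score,
)

# Classification metrics.
y_true, y_pred = [0, 1, 1, 0, 1], [0, 1, 0, 0, 1]
print("accuracy:", accuracy_score(y_true, y_pred))
print("precision:", precision_score(y_true, y_pred))
print("recall:", recall_score(y_true, y_pred))
print("F1:", f1_score(y_true, y_pred))

# Regression counterparts: MSE and R-squared.
y_true_r, y_pred_r = [2.5, 0.0, 2.0, 8.0], [3.0, -0.5, 2.0, 7.0]
print("MSE:", mean_squared_error(y_true_r, y_pred_r))
print("R^2:", r2_score(y_true_r, y_pred_r))
```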
Key Findings
Performance and Computational Cost:
Tree-based models, such as Random Forests and Gradient Boosting Machines, generally provide better predictive performance with significantly lower computational costs compared to deep learning models.
The performance advantage of tree-based models is attributed to the specific nature of tabular data, which often includes irregular patterns in the target function and uninformative features.
Rotational Invariance:
Rotational invariance in the context of machine learning, particularly in neural networks, refers to the property of a model to produce the same output for inputs that are rotated versions of each other. This concept is more commonly discussed in domains like image processing, where it is beneficial for models to recognize objects regardless of their orientation.
In neural networks: a rotationally invariant learning procedure produces the same predictions when the input features are rotated in multidimensional space.
In the context of tabular data: each column represents a distinct feature with its own semantic meaning (income, weight, age). Rotating data in a higher dimensional space can mix these features, leading to a loss of the original meaningful structure of the data.
The study shows that tree-based models exploit the natural basis of tabular data, in which individual features carry meaning on their own (e.g., age, weight). In contrast, neural networks, particularly those that are rotationally invariant, may not capture this inherent structure effectively.
Applying random rotations to the datasets reverses the picture: tree-based models lose the most performance, while rotationally invariant networks are barely affected, indicating that rotation invariance is not a desirable property for tabular data.
So, neural networks are more likely to be rotationally invariant, which is actually not desirable when working with tabular data. We don’t want to be able to arbitrarily rotate or transform the original columns and lose their intrinsic meaning.
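A minimal sketch of such a rotation experiment, assuming a random orthogonal rotation of the feature matrix (this is my illustration, not the paper's exact protocol):

```python
# Rotate the feature space with a random orthogonal matrix and see how a
# tree-based model copes once the original columns are mixed together.
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Draw a random orthogonal matrix via QR decomposition: applying it rotates
# the feature space, destroying the individual meaning of each column.
rng = np.random.default_rng(0)
R, _ = np.linalg.qr(rng.normal(size=(X.shape[1], X.shape[1])))

for name, (tr, te) in {
    "original": (X_train, X_test),
    "rotated": (X_train @ R, X_test @ R),
}.items():
    clf = RandomForestClassifier(random_state=0).fit(tr, y_train)
    print(name, "test accuracy:", clf.score(te, y_test))
```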
Tabular data may also contain irregular and complex patterns that are tied to the specific features. Rotational invariance might obscure these patterns, making it harder for the model to learn them.
An example of such an irregular pattern in the context of medical data: the dataset could contain features like age, cholesterol levels, blood pressure, and medication dosage. The risk of a cardiovascular event might not increase linearly with age or cholesterol levels. Instead, there could be thresholds or interaction effects, such as a significant increase in risk for patients above a certain age with cholesterol levels above a particular threshold, which might not be apparent if considering each feature independently.
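This kind of threshold interaction is easy to simulate. The sketch below uses invented feature names and cutoffs: risk is high only when age and cholesterol both exceed a threshold, which axis-aligned tree splits capture directly while a linear model cannot.

```python
# Synthetic version of the threshold-interaction pattern described above.
# Feature names and cutoffs are invented for illustration.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
n = 5000
age = rng.uniform(20, 90, n)
cholesterol = rng.uniform(120, 320, n)
# Hypothetical target: high risk only for patients over 60 whose
# cholesterol is also above 240 -- a non-linear interaction effect.
y = ((age > 60) & (cholesterol > 240)).astype(int)
X = np.column_stack([age, cholesterol])

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Axis-aligned splits can represent the two thresholds exactly...
tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
# ...while a single linear decision boundary cannot express the interaction.
linear = LogisticRegression(max_iter=1000).fit(X_train, y_train)

print("decision tree accuracy:", tree.score(X_test, y_test))
print("logistic regression accuracy:", linear.score(X_test, y_test))
```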
Effect of Uninformative Features:
Adding or removing uninformative features affects tree-based models far less severely than neural networks. An MLP must first learn to ignore the irrelevant inputs, which consumes capacity and degrades performance.
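A minimal sketch of this experiment, padding a dataset with pure-noise columns (the dataset, noise dimensionality, and model settings are illustrative assumptions, not the paper's protocol):

```python
# Append uninformative noise features and compare how a GBM and an MLP cope.
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X, y = load_breast_cancer(return_X_y=True)
# Pad the data with 200 pure-noise columns carrying no signal.
X_noisy = np.hstack([X, rng.normal(size=(X.shape[0], 200))])

for name, data in {"clean": X, "with noise features": X_noisy}.items():
    X_tr, X_te, y_tr, y_te = train_test_split(data, y, random_state=0)
    gbm = HistGradientBoostingClassifier(random_state=0).fit(X_tr, y_tr)
    mlp = make_pipeline(
        StandardScaler(),
        MLPClassifier(hidden_layer_sizes=(128,), max_iter=500, random_state=0),
    ).fit(X_tr, y_tr)
    print(f"{name}: GBM={gbm.score(X_te, y_te):.3f}, "
          f"MLP={mlp.score(X_te, y_te):.3f}")
```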
Discussion and Conclusion
The paper concludes that the ease of achieving good predictions with tree-based models and their lower computational cost explain their superiority on tabular data. The specific features of tabular data, such as irregular patterns and the presence of uninformative features, could make tree-based models more suitable in these cases.
The study encourages further research into the inductive biases of tree-based models and their performance in different settings, such as small or very large datasets, and into how both model types handle challenges like missing data or high-cardinality categorical features.