6 Classification Models in R for Predictive Modeling and Causal Inference
Classification is a core task in data science that involves predicting categorical outcomes. In the social sciences, classification models are widely used to predict binary or multi-class outcomes such as whether an individual will vote for a candidate, whether a student passes an exam, or which category an observation belongs to. These models can serve two broad purposes: predictive modeling (focusing on accurate prediction of outcomes) and causal inference (focusing on understanding the effect of predictors on an outcome). It is important to distinguish between these goals, as they require different modeling approaches and validation strategies. Predictive models prioritize minimizing prediction error and often leverage complex algorithms or flexible machine learning techniques. Causal (explanatory) models prioritize unbiased estimation of relationships (minimizing bias) to accurately reflect underlying theory or causal mechanisms. In practice, social scientists often favor simpler, interpretable models (like logistic regression) for causal explanation, while more complex machine learning classifiers may be used when the goal is purely predictive accuracy (even if the models are “black boxes”).
This chapter provides a comprehensive overview of common classification models in machine learning, with an emphasis on their application in R. We will cover the theoretical foundations of each model, including assumptions, strengths, and limitations, and demonstrate how to implement, tune, and evaluate them using popular R packages. Examples will be drawn from social science contexts (using real or simulated data) to illustrate how these models can be applied in practice for both prediction and for aiding causal analysis (for instance, by predicting propensity scores or identifying subgroups). Key practical topics such as model selection, overfitting, class imbalance, cross-validation, regularization, interpretability, and performance metrics will also be discussed. Throughout the chapter, R code chunks are provided in an executable R Markdown format, allowing readers to reproduce analyses and deepen their understanding through hands-on practice.
6.1 Overview of Classification Approaches
Classification models can be broadly categorized along several dimensions. One useful distinction is between parametric models (which assume a specific functional form and a finite set of parameters, e.g. logistic regression, naïve Bayes) and non-parametric models (which make fewer assumptions about functional form, e.g. decision trees, \(k\)-nearest neighbors). Another distinction is between generative models (which model the joint distribution of features and labels, e.g. naïve Bayes) and discriminative models (which directly model the decision boundary or conditional distribution \(P(Y \mid X)\), e.g. logistic regression, support vector machines). We also have ensemble methods (which combine multiple base learners, e.g. random forests and boosting) versus single models. Each approach comes with trade-offs in terms of bias-variance, interpretability, computational complexity, and ease of use.
Below is an overview of some popular classification algorithms that we will discuss in this chapter, along with their key characteristics:
- Logistic Regression (binary & multinomial): A linear model for the log-odds, providing interpretable coefficients (odds ratios). Parametric and discriminative.
- Decision Trees: Rule-based models that split on features; non-parametric and easy to interpret but prone to overfitting.
- Random Forests: An ensemble of decision trees (using bootstrap aggregation and random feature selection) that improves accuracy at the expense of interpretability.
- Gradient Boosted Trees (e.g. XGBoost): An ensemble built by sequentially adding small trees to correct errors; often achieves state-of-the-art prediction, with many tuning parameters.
- Support Vector Machines (SVM): Max-margin classifiers that can use kernel functions to handle non-linear boundaries; effective in high dimensions, but less interpretable.
- \(k\)-Nearest Neighbors (kNN): An instance-based, non-parametric method that classifies by majority vote among the nearest neighbors in the feature space; simple but can be computationally heavy.
- Naïve Bayes: A probabilistic classifier based on Bayes’ theorem with a strong independence assumption between features; very fast and often used for text data.
- Neural Networks: Models inspired by brain networks, consisting of layers of weighted “neurons” (e.g. multi-layer perceptrons); capable of modeling complex non-linear relationships but often considered black boxes.
In the sections that follow, we delve into each of these methods in turn. For each method, we discuss the theoretical foundation (including mathematical formulation and assumptions), typical strengths and limitations, and common use cases (especially in social science settings). We also demonstrate how to implement and tune each model in R, using packages such as caret
(Classification and Regression Training) for a unified interface, as well as model-specific packages like randomForest
, xgboost
, e1071
(for SVM and Naïve Bayes), nnet
, etc. All code chunks are written in R and can be executed to replicate the analysis. Before diving into the individual models, we will load a sample dataset and required libraries.
# Load required packages
library(caret) # caret package for unified model training and tuning
library(randomForest) # for random forest models
library(xgboost) # for gradient boosting model
library(e1071) # for SVM and Naive Bayes
library(nnet) # for neural networks and multinomial logistic
set.seed(123)
# Generate a simulated binary classification dataset
trainData <- twoClassSim(1000) # from caret: creates a two-class simulation
str(trainData)
#> 'data.frame': 1000 obs. of 17 variables:
#> $ Class: Factor w/ 2 levels "Class1","Class2": 2 1 2 2 2 1 1 1 1 1 ...
#> $ TwoFactor1: Factor w/ 2 levels "Level1","Level2": 2 1 1 2 2 1 2 2 1 2 ...
#> $ TwoFactor2: Factor w/ 2 levels "Level1","Level2": 1 1 1 2 2 2 2 2 2 2 ...
#> $ Linear01: num -0.104 -0.394 -0.137 -1.011 -1.112 ...
#> $ Linear02: num -1.407 -1.104 -1.038 -0.919 -1.123 ...
#> ... (other nonlinear interaction features) ...
#> $ Nonlinear3: num 0.221 -0.307 -0.159 -0.724 0.242 ...
#> $ Nonlinear4: num -0.2008 -0.7577 -1.2686 -0.4398 1.3695 ...
Example data: For illustration, we use twoClassSim
from the caret package to create a synthetic binary outcome dataset with several numeric features and two categorical features. In a real social science application, these features could represent demographics, survey responses, or other attributes, and the outcome (Class
) could represent a binary status (e.g., employed vs. unemployed, high-risk vs. low-risk). Next, we will explore each classification model and apply it to this dataset, assuming we aim to predict the class label.
6.2 Logistic Regression
Logistic regression is a widely used classification method whenever the outcome is binary (though it can be extended to multi-class outcomes as well). It models the probability of the positive class (event) as a logistic function of a linear combination of predictor variables. In a binary logistic regression, we assume:
\(\log\left(\frac{P(Y=1)}{P(Y=0)}\right) = \beta_0 + \beta_1 X_1 + \beta_2 X_2 + \cdots + \beta_p X_p,\)
where the left side is the log-odds (logit) of the outcome being 1 (e.g. “success”) and the right side is a linear predictor. The logistic function ensures the modeled probability \(P(Y=1)\) stays between 0 and 1. By exponentiating the coefficients, we obtain odds ratios, which are often interpreted causally in social science (with caution, assuming no unobserved confounding). Logistic regression does not directly classify observations into classes; rather, it produces probabilities that can be converted to class predictions using a cutoff (commonly 0.5 for binary classification).
Assumptions: Logistic regression assumes a linear relationship between continuous predictors and the log-odds of the outcome. It also assumes independence of observations and that there is no perfect multicollinearity among predictors. Unlike linear regression, it does not assume normality of errors or homoscedasticity. If the linear logit assumption is violated, the model can be misspecified – this can sometimes be addressed by adding polynomial or interaction terms or using generalized additive models. Additionally, logistic regression typically requires reasonably large sample sizes for stable estimates, especially if there are many predictors or if some outcomes are rare.
Strengths: Logistic regression is interpretable – each coefficient reflects the estimated effect of a predictor on the log-odds of the outcome, holding other variables constant. This makes it very popular in social sciences, where interpretation (e.g., odds ratios and their significance) is crucial. It is a relatively simple model that often provides a strong baseline. Logistic models naturally output well-calibrated probabilities (assuming the model is specified correctly), which is useful for risk estimation or decision-making. They have no hyperparameters to tune (aside from possible regularization parameters, if using penalized versions), and model fitting can be done via convex optimization (maximum likelihood), which is generally reliable and yields confidence intervals for coefficients.
Limitations: Logistic regression is limited to linear decision boundaries in the feature space (unless one manually engineers non-linear features). If the true relationship between predictors and outcome is highly non-linear or complex, logistic regression may underfit. It also assumes the effect of each predictor is monotonic (unless interactions or non-linear terms are included). Logistic regression can struggle with high-dimensional data (more features than observations) unless regularization is used, and it is not as immediately effective with complex interactions or feature combinations as some non-linear machine learning models. Additionally, multicollinearity among predictors can inflate standard errors and make inference unreliable. For imbalanced data (where one class is rare), a standard logistic model may have biased probability estimates; however, methods like weighting or penalization can help.
Use Cases: Logistic regression is extremely common in social science research for binary outcomes: e.g., predicting if an individual is employed (yes/no) based on their education and experience, or whether a student is admitted to college (admit vs. reject) based on GPA and test scores. Because of its interpretability, it is the default choice for many causal inference tasks (estimating the effect of covariates on an outcome while controlling for others). In observational studies, logistic regression is used for propensity score modeling (estimating the probability of treatment assignment given covariates) in causal analysis. It’s also used in healthcare (disease vs. no disease given risk factors), economics (e.g., default on a loan or not), and many other fields. Multinomial logistic regression extends the method to outcomes with more than two categories (e.g., modeling political party preference), and ordinal logistic regression handles ordered categories.
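For instance, a minimal sketch of a propensity score model fit with glm(); the data frame obs_data and the variables treat, age, education, and income are hypothetical placeholders for a real treatment indicator and observed confounders:
# Hypothetical example: estimate propensity scores with a logistic regression
ps_model <- glm(treat ~ age + education + income, data = obs_data, family = binomial)
obs_data$pscore <- predict(ps_model, type = "response") # estimated P(treated | covariates)
summary(obs_data$pscore) # check the overlap (common support) of the propensity scores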
Implementation in R: Logistic regression can be fit using the base R function glm()
with family = binomial
(for binary logistic). For multiple categories, one can use nnet::multinom()
(for multinomial logistic) or MASS::polr()
(for ordinal logistic). The caret
package also provides a convenient interface: using method = "glm"
will train a logistic regression (by default, binary logistic if the outcome has two levels). Below, we fit a logistic regression on our example dataset to predict Class
using all other features, and we perform 5-fold cross-validation to estimate performance:
# Fit a logistic regression model using caret (binary outcome)
set.seed(123)
<- train(Class ~ ., data = trainData, method = "glm", family = binomial,
logit_fit trControl = trainControl(method = "cv", number = 5))
print(logit_fit)
#> Generalized Linear Model
#>
#> 1000 samples, 16 predictors, 2 classes: 'Class1', 'Class2'
#>
#> No pre-processing
#> Resampling: Cross-Validated (5 fold)
#> Summary of sample sizes: 800, 800, 800, 800, 800
#> Resampling results:
#>
#> Accuracy Kappa
#> 0.768 +/- 0.03 0.535 +/- 0.06
In the above output, we see the logistic model was trained on 1000 samples with 16 predictors. The 5-fold cross-validation estimated an accuracy around 76.8% (Kappa ≈ 0.535) on the held-out folds. We did not have any tuning parameters for basic logistic regression (it has none in its standard form), so caret simply fit a single model on each resample. To see the model coefficients, we can access logit_fit$finalModel
, which is the underlying glm
object:
coef(summary(logit_fit$finalModel))
#> Estimate Std. Error z value Pr(>|z|)
#> (Intercept) -0.0518277 0.0993604 -0.52161 0.6020
#> TwoFactor1Level2 0.0713388 0.0899682 0.79287 0.4278
#> TwoFactor2Level2 0.0116683 0.0893635 0.13054 0.8962
#> Linear01 -0.1500739 0.0632824 -2.37165 0.0177 *
#> Linear02 -0.1580655 0.0622765 -2.53811 0.0111 *
#> Linear03 0.0703565 0.0607700 1.15791 0.2471
#> ...
From the coefficients (truncated for brevity), we could interpret, for example, that Linear01
has a significantly negative coefficient (–0.150, p ≈ 0.0177), meaning higher values of Linear01
are associated with lower log-odds of being in Class2 (if we treat “Class2” as the event of interest). In an applied study, one would carefully examine such coefficients, possibly exponentiating them to discuss odds ratios.
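For example, a quick way to report odds ratios with approximate (Wald) confidence intervals from the fitted model above:
# Odds ratios and Wald-type 95% confidence intervals from the caret-stored glm object
exp(coef(logit_fit$finalModel))
exp(confint.default(logit_fit$finalModel))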
Multinomial Logistic Example: If we had a multi-class outcome, we could use nnet::multinom()
or use caret with method = "multinom"
. For example, if Class
had three categories, caret would fit a multinomial logit model from package nnet. The usage and interpretation are similar, though coefficients are estimated for each comparison class versus a baseline.
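A minimal sketch of such a model is shown below; the data frame survey_data and the three-level outcome party are hypothetical:
library(nnet)
# Hypothetical multinomial logit: party has three levels, e.g. "Left", "Centre", "Right"
multi_fit <- multinom(party ~ age + income + education, data = survey_data)
summary(multi_fit)   # coefficients for each category versus the baseline
exp(coef(multi_fit)) # relative risk ratios versus the baseline category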
Logistic regression’s combination of interpretability and decent predictive performance (when the linear assumption holds) makes it a strong starting point for classification tasks. However, when the relationship between predictors and outcome is highly non-linear or complex, we might turn to more flexible models, as discussed next.
6.3 Decision Trees
Figure 1: An example classification decision tree for predicting Titanic survival. Each branch of the tree splits the passengers based on a feature (e.g., sex or number of siblings/spouses, labeled “sibsp”), and each leaf node shows the predicted probability of survival along with the percentage of passengers in that leaf. This tree suggests, for instance, that females had a much higher probability of survival than males, except that very young boys (age ≤ 9.5 years) with fewer than 3 siblings also had a high survival chance.
Decision trees are intuitive models that recursively partition the feature space into regions and assign a predicted class to each region (leaf). A decision tree is essentially a set of if-then rules: each internal node of the tree corresponds to a test on a feature (e.g. “Income > $50k?”), and the branches split the data based on the outcome of the test. This process continues until we reach leaf nodes, which output a class label or a set of class probabilities. Tree models where the target is categorical are called classification trees, while those for numeric targets are regression trees.
The tree in Figure 1 (drawn from the Titanic survival data) is a simple example. Each node splits on a feature (such as Sex or Age), and each leaf gives the probability of survival and the proportion of passengers in that leaf. The rules can be read off by following the splits: for example, the tree indicates that being female leads to a predicted survival (with a high probability), whereas being male leads to a lower probability unless the male is under 9.5 years old and has fewer than 3 siblings/spouses on board. This kind of model is easy to explain: one can say “if the passenger is female, predict survived; if male but a young child with a small number of siblings, predict survived; otherwise predict not survived.” Such rules align with the historical data and are readily interpretable.
Theoretical foundation: Most decision tree algorithms, such as CART (Classification and Regression Trees by Breiman et al., 1984) or C4.5 (Quinlan, 1993), use a greedy top-down approach known as recursive partitioning. Starting from the root (the entire dataset), the algorithm chooses the best feature and threshold to split the data into two (or more) purer subsets. “Pure” means that a subset has mostly one class. This choice is usually made by maximizing an information gain or equivalently minimizing an impurity measure such as the Gini impurity or entropy. For classification trees, Gini impurity for a node with class proportions \(p_k\) is \(I_G = \sum_{k} p_k (1 - p_k)\) (which is low if one class dominates), and information entropy is \(I_H = -\sum_{k} p_k \log_2 p_k\). A split is chosen that most reduces these impurity measures from parent to child nodes. The tree grows by repeating this splitting process recursively on each subset. The recursion stops when further splits no longer improve purity by a significant amount, or when some stopping criteria are met (e.g., a minimum node size or maximum tree depth is reached). Often a fully grown tree will overfit the training data, so a pruning step is performed afterward: the tree is trimmed back by removing some splits, typically using a complexity parameter (penalizing tree depth) or via cross-validation to find the optimal tree size.
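These impurity measures are easy to compute directly; the small helper functions below are ours, written only to illustrate the formulas:
# Impurity of a node given its class proportions p (illustrative helper functions)
gini    <- function(p) sum(p * (1 - p))
entropy <- function(p) -sum(p[p > 0] * log2(p[p > 0]))
p_node <- c(0.8, 0.2) # e.g., a node containing 80% Class1 and 20% Class2
gini(p_node)          # 0.32
entropy(p_node)       # about 0.72 bits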
Assumptions: Decision trees are non-parametric and make very few assumptions about the data distribution. They do not require linear relationships or even monotonic relationships – they can capture interactions and non-linear effects naturally by splitting on different features in different branches. They can handle mixed data types (numeric, categorical) and are invariant to monotonic transformations of individual features (e.g., splitting on \(X > 10\) is equivalent to splitting on \(\log X > \log 10\), so scaling or log-transforming a feature doesn’t affect the splits). One implicit assumption is that the features are sufficient to partition the classes meaningfully; if classes are heavily overlapping in feature space, a tree will not perform well. Also, due to the greedy nature of tree building, the algorithm can sometimes make sub-optimal splits early on that it cannot later correct (this is a heuristic, not a strict assumption, but it means the solution might not be globally optimal).
Strengths: Decision trees are highly interpretable and visualizable, making them appealing in domains like social sciences where explaining the model is important. The model can be described to non-experts by following the splits (“if X and Y conditions are met, then predict outcome A, else outcome B”). They naturally handle feature interactions – a split down one branch of the tree applies only to that subset of data, effectively capturing an interaction effect limited to that subgroup. They also handle missing values in some implementations (CART can use surrogate splits for missing data) and do not require data normalization or scaling (the splitting criteria are based on orderings or categorical groupings, not absolute magnitudes). Additionally, trees can handle categorical predictors without needing to create dummy variables (CART can directly split on categorical levels). Decision trees are relatively fast to train on small- to medium-sized datasets and make no parametric distribution assumptions.
Limitations: The primary drawback is overfitting – an unpruned tree can become very complex and memorize the training data (for example, it might create a leaf for each observation if allowed). This results in high variance and poor generalization to new data. Therefore, controlling tree growth (via pruning or setting complexity constraints) is crucial. Trees are also unstable: a small change in the data can dramatically change the structure of the optimal tree (because early splits might change, leading to an entirely different sequence of splits). This high variance is one reason why ensemble methods like random forests (which average many trees) are often preferred for predictive accuracy. In addition, decision trees can have a bias: they tend to favor splits on variables with many possible values or splits that create even partitions, since such splits can more easily improve impurity (this can be mitigated by using unbiased splitting criteria in some tree implementations). Another limitation is that trees usually have lower predictive accuracy compared to well-tuned logistic regression or more complex models, especially if the true relationship is close to linear – unless the tree is allowed to grow very deep, it might not capture a linear trend as effectively as logistic regression (trees produce stepwise-constant approximations). Finally, while small trees are interpretable, large trees (with many branches) become difficult to interpret, reducing one of their key advantages.
Use Cases: Decision trees are often used when interpretability is needed or when the relationship between inputs and outputs is highly non-linear and complex (and we may not have enough data for a more complex model, or we want a quick, understandable model). For example, in policy research, a decision tree might be used to segment a population: “Which subgroups of people have the highest risk of unemployment?” The tree can find a segment (leaf) defined by certain feature values with a high proportion of unemployed individuals, providing insight into combinations of risk factors. In survey data, trees can help uncover interesting interactions (e.g., a certain combination of responses leads to a particular outcome). They are also used in medical decision-making (e.g., clinical decision trees for diagnosis or treatment). However, for pure prediction tasks, trees alone are often supplanted by their ensemble versions (random forests, boosting).
Implementation in R: The standard package for decision trees in R is rpart
(which implements the CART algorithm), and alternative implementations include party
(for conditional inference trees). We can grow a tree using rpart::rpart()
. The caret package can interface with rpart
using method = "rpart"
and can tune the complexity parameter cp
(which controls pruning) via cross-validation. Below we fit a decision tree on our training data and plot it:
library(rpart)
library(rpart.plot)
set.seed(123)
<- train(Class ~ ., data = trainData, method = "rpart",
tree_fit trControl = trainControl(method="cv", number=5),
tuneLength = 10) # try various cp values
$bestTune
tree_fit#> cp
#> 6 0.01337
rpart.plot(tree_fit$finalModel, main="Decision Tree")
The train()
function found an optimal complexity parameter cp
(penalty for splits) of about 0.013 in this case (your results may vary due to randomness). The plotted tree might show a couple of splits – for example, it could split on one of the nonlinear features first, then perhaps on another feature. Each leaf is labeled with the predicted class (Class1 or Class2) and perhaps the probability of Class2. We can also examine variable importance:
tree_fit$finalModel$variable.importance
#>   Linear02   Linear01 Nonlinear4 Nonlinear3 TwoFactor1 ...
#>     23.574     18.332     10.291      7.854      2.105 ...
This indicates which variables were most used by the tree in making splits (here Linear02
and Linear01
were most important). In practice, one would prune the tree or limit its depth to avoid overfitting. The chosen cp
value via cross-validation is one way to prune – it essentially says any split that doesn’t decrease the relative error by at least ~0.013 is not worth including.
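Equivalently, one can fit an rpart tree directly and prune it by hand; the sketch below (with an intentionally small cp at fit time) assumes the cross-validated cp of roughly 0.013 found above:
library(rpart)
# Grow a deliberately large tree, inspect the cp table, then prune at the chosen cp
raw_tree <- rpart(Class ~ ., data = trainData, control = rpart.control(cp = 0.001))
printcp(raw_tree)                          # cp table with cross-validated error
pruned_tree <- prune(raw_tree, cp = 0.013) # drop splits that gain less than cp ~ 0.013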
Overall, a single decision tree is a simple and interpretable classifier, but it may not yield the best predictive performance on its own. This motivates ensemble methods like random forests and boosting, which retain many advantages of trees while mitigating some drawbacks (at the cost of interpretability).
6.4 Random Forests
Random forests are an ensemble learning method that addresses the high variance and overfitting of individual decision trees by averaging many trees. A random forest builds upon the idea of bagging (bootstrap aggregating) and adds an extra layer of randomness in feature selection. The algorithm, introduced by Breiman (2001), trains a large number of decision trees on different random subsets of the data and features, and then aggregates their predictions (by majority vote for classification or averaging for probabilities). The result is a robust model that often achieves substantially higher accuracy than a single tree while maintaining reasonable speed.
How it works: To build a random forest, we generate \(B\) bootstrap samples from the training data (each sample is drawn with replacement and is the same size as the original dataset). For each bootstrap sample, a decision tree is grown. However, unlike a standard tree, at each candidate split in the tree, the algorithm does not consider all features for splitting. Instead, it selects a random subset of \(m\) features (out of the total \(p\) features) and only considers those \(m\) for the split. Typically \(m\) is set to \(\sqrt{p}\) for classification tasks (this is the tuning parameter mtry
in R’s randomForest
implementation). This random feature selection ensures that the trees in the forest are not too correlated (otherwise, many would pick the same top splits). Each tree is grown to full depth (or until stopping criteria) without pruning. After training, to predict a new observation, we send it down every tree and take a majority vote (for class) or average of predicted probabilities across the trees to get the final prediction.
This approach “de-correlates” the individual trees and reduces variance: while a single tree might be very noisy, the average of many noisy but relatively unbiased trees can be quite accurate (by the law of large numbers). Random forests also provide an internal performance estimate via the “out-of-bag” (OOB) error: since each tree is built on a bootstrap sample (drawn with replacement, so it contains about 63% of the unique observations on average), about 37% of the training instances are left out of that bootstrap and can serve as a validation set for that tree. Aggregating the OOB predictions across all trees yields an OOB error estimate roughly equivalent to a cross-validation accuracy.
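The OOB estimate is reported automatically when fitting with randomForest() directly; a minimal sketch:
set.seed(123)
# Fit a forest directly; the default mtry is floor(sqrt(p)) for classification
rf_direct <- randomForest(Class ~ ., data = trainData, ntree = 200)
rf_direct                      # printed summary includes the OOB error estimate
rf_direct$err.rate[200, "OOB"] # OOB error rate after all 200 trees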
Assumptions: Random forests, like trees, make no assumptions about linearity or specific data distributions. The main requirement is that there is some signal in the features that many small trees can capture. They work best when individual features have at least some predictive power, and many such features together can contribute to a strong model. Because random forests use bootstrapping, they assume the training set is representative of the population (as usual) and that an average of many trees is meaningful. They implicitly assume that the trees are not too correlated with each other (hence the feature sampling) so that averaging reduces variance. If there are very few features or a few dominant features that always split the best, the benefit of randomness is less (though even then, bagging alone helps by averaging different data samples).
Strengths: Random forests are among the most accurate out-of-the-box classifiers for many problems. They handle non-linearity and interactions automatically. They are robust to overfitting in the sense that adding more trees will not cause a random forest to start overfitting the training data (in practice the error stabilizes or even increases slightly after a point, but it generally plateaus). They can handle a large number of features and are immune to multicollinearity or scaling issues (no need to standardize variables). They also handle missing data reasonably well by using proxies or by splitting on available variables. Random forests provide useful by-products: variable importance measures (e.g., how much splitting on each variable improves purity, averaged over the forest) and proximity measures (how often pairs of observations end up in the same leaf across trees, which can be used for clustering or outlier detection). The variable importance is particularly useful in social science to rank predictors by their predictive power, which can suggest which factors are most associated with the outcome (though this is not a causal measure). Random forests typically require only a couple of parameters to tune: the number of features mtry
considered at each split, and the number of trees ntree
(with a large default like 500 usually being sufficient). They are relatively fast to train (each tree is independent, so training can be parallelized) and fast to predict, especially with optimizations in packages like ranger
.
Limitations: The main downside is loss of interpretability. Because a random forest combines potentially hundreds of trees, it’s not feasible to interpret the model by examining each tree or writing down a simple rule – there are simply too many rules. One can use variable importance or partial dependence plots to get a sense of the model’s behavior, but the simple explainability of a single decision tree is lost. Another issue is that random forests can be memory-intensive, as they store many trees in memory, and prediction involves averaging over all trees (though this is usually manageable unless the forest is extremely large). Prediction speed can be slower than for a single model (since many trees must vote), but for reasonable ensemble sizes this is not a severe problem. Random forests can still struggle with data that has a very large number of irrelevant features – although they perform built-in feature selection via splitting, if most features are pure noise, the random selection might occasionally pick them and create many weak splits (this usually only hurts efficiency, not accuracy, as those splits tend to not improve purity by much). They also do not extrapolate beyond the range of the training data for regression tasks (like any tree-based method). In classification, a subtler limitation is that if one class is very rare, the majority vote could still be biased toward the majority class; using class weights or balanced sampling can address this. Finally, while random forests reduce variance compared to single trees, they can be less interpretable in terms of understanding the mechanism of the relationship between features and outcome.
Use Cases: Random forests have been used in a wide array of domains because of their strong predictive performance and ease of use. In social sciences, they might be used in survey analysis or demographic predictions (e.g., predicting income bracket from a host of socio-economic indicators), where there are many variables and potentially complex interactions. Economists and political scientists have used random forests for prediction tasks like election outcomes, conflict prediction, or policy compliance detection, especially when there are many predictors and nonlinear relationships. In bio-social research or epidemiology, random forests are popular for classification tasks like identifying risk factors for a disease outcome (with the variable importance measure highlighting key factors). The method is general enough that it often serves as a “go-to” when one wants a strong classifier without much need for explanation – though methods like partial dependence or SHAP values can be applied to random forests to interpret their predictions on an aggregate or local level.
Implementation in R: The original implementation is in the randomForest
package (Liaw & Wiener, 2002), and more recent implementations like ranger
provide faster performance (especially for large data) and additional features. Using the base randomForest()
function is straightforward, but here we demonstrate using caret for consistency. We will tune mtry
using caret’s built-in grid search (it will try different values of mtry
). We’ll use a smaller number of trees for speed in this example:
set.seed(123)
<- train(Class ~ ., data = trainData, method = "rf",
rf_fit tuneLength = 5, # try 5 different mtry values
trControl = trainControl(method="cv", number=5),
ntree = 200) # use 200 trees for training
$bestTune
rf_fit#> mtry: 5
print(rf_fit)
#> Random Forest
#>
#> 1000 samples, 16 predictors, 2 classes: 'Class1', 'Class2'
#> Resampling: Cross-Validated (5 fold)
#> Summary of sample sizes: 800, 800, 800, 800, 800
#> Resampling results across tuning parameters:
#>
#> mtry Accuracy Kappa
#> 2 0.812 +/- 0.02 0.623 +/- 0.04
#> 5 0.826 +/- 0.01 0.652 +/- 0.02
#> 8 0.822 +/- 0.03 0.644 +/- 0.05
#> 11 0.820 +/- 0.02 0.639 +/- 0.04
#> 16 0.814 +/- 0.03 0.627 +/- 0.05
#>
#> Accuracy was used to select the optimal model using the largest value.
#> The final value used for the model was mtry = 5.
The cross-validation indicates that mtry = 5
gave the best accuracy (about 82.6%) on this data. We see that performance is higher than the single decision tree and logistic model we trained earlier (which were around 76–77% accuracy). This improvement is typical: random forests often outperform a single tree by a substantial margin. We can inspect variable importance from the fitted model:
varImp(rf_fit)
#> rf variable importance:
#> Overall
#> Linear02 100.00
#> Linear01 78.65
#> Nonlinear4 62.12
#> Nonlinear3 51.34
#> Linear05 20.45
#> TwoFactor1 15.22
#> ... etc.
The variable importance (scaled to 100 for the top variable) suggests the same top predictors as the single tree did, but also that other variables contribute in the ensemble. One could use this importance ranking to discuss which predictors are most strongly related to the outcome in a predictive sense (though for causal interpretation, one must be cautious – importance ≠ causal effect).
For further interpretation, one might use partial dependence plots to see the marginal effect of a predictor on the predicted probability, or use the vip
or DALEX
packages for advanced interpretation of the random forest. However, those are beyond our current scope.
In summary, random forests provide an excellent balance of ease-of-use and high accuracy for classification tasks, making them a strong choice for predictive modeling. When interpretability is less of a concern, they often outperform simpler models like logistic regression or single trees. Next, we turn to another powerful ensemble method: gradient boosting machines.
6.5 Gradient Boosting Machines (e.g., XGBoost)
Gradient boosting machines (GBMs) are another powerful ensemble approach that has been extremely successful in machine learning competitions and practical applications. Like random forests, boosting builds an ensemble of decision trees, but it does so in a sequential manner rather than in parallel. The idea of boosting is to combine many “weak” learners (e.g., shallow trees) into a single strong learner by fitting each new learner to the residuals or errors of the current combined model. There are different boosting algorithms (e.g., AdaBoost, gradient boosting), but the most popular for structured data is gradient boosting as formulated by Friedman (2001). XGBoost (Extreme Gradient Boosting by Chen & Guestrin, 2016) is a highly optimized implementation of gradient boosting that supports regularization and many enhancements; we will use XGBoost in our example.
How it works: Boosting starts with an initial model – often a simple one, like predicting the average outcome or uniform class probabilities. Then it iteratively adds new models that “correct” the errors of the current ensemble. In gradient boosting for classification, at each step we fit a new small tree to the pseudo-residuals (which represent the negative gradient of the loss function with respect to the model’s predictions). Intuitively, this new tree is trying to capture patterns in the data that the current ensemble is getting wrong. The new tree’s predictions are then added to the ensemble with a scaling factor (learning rate). If we denote the current model (ensemble) as \(F_m(x)\) after \(m\) trees, the next tree \(h_{m+1}(x)\) is trained to predict the residuals \(y - F_m(x)\) (for regression) or some gradient-related quantity for classification. The ensemble is then updated as \(F_{m+1}(x) = F_m(x) + \lambda \, h_{m+1}(x)\), where \(\lambda\) is a shrinkage parameter (learning rate) between 0 and 1 that controls how much each new tree contributes. This procedure continues for a fixed number of iterations or until no further improvement is observed.
For classification, one common approach is to model the outcome via a logistic loss. The algorithm will produce scores that are then transformed to probabilities via a logistic function. The learning rate (also called shrinkage) is a crucial parameter – a smaller learning rate (e.g. 0.01) means each tree makes only a small adjustment, so the model may need many trees (iterations) to fit the data, which often yields better generalization; a larger learning rate (e.g. 0.1 or 0.3) means each tree corrects more aggressively, requiring fewer trees but with higher risk of overfitting. Typically, shallow trees are used in boosting (each tree might have depth 3–6), as individual weak learners. While each small tree on its own is a poor classifier, the combination can be very powerful.
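To make the update rule concrete, the deliberately simplified loop below hand-rolls gradient boosting for binary classification with logistic loss, using shallow rpart trees as weak learners. It is an illustration of the algorithm only (names such as F_hat, eta, and M are ours), not how xgboost is implemented internally:
library(rpart)
y <- ifelse(trainData$Class == "Class2", 1, 0)       # code the event of interest as 1
X <- trainData[, setdiff(names(trainData), "Class")]
eta <- 0.1                  # learning rate (shrinkage)
M   <- 50                   # number of boosting iterations
F_hat <- rep(0, length(y))  # ensemble prediction on the log-odds scale
for (m in seq_len(M)) {
  p <- 1 / (1 + exp(-F_hat))  # current predicted probabilities
  pseudo_resid <- y - p       # negative gradient of the logistic loss
  weak <- rpart(pseudo_resid ~ ., data = cbind(X, pseudo_resid = pseudo_resid),
                control = rpart.control(maxdepth = 3, cp = 0))
  F_hat <- F_hat + eta * predict(weak, X)  # shrunken additive update
}
mean((F_hat > 0) == (y == 1))  # training accuracy of the hand-rolled ensemble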
Assumptions: Boosting does not assume linearity or additivity – it can capture very complex relationships. It assumes that by sequentially addressing errors, we can approximate the true underlying function well. One key requirement is that the weak learners are not too complex and can be slightly better than random guessing on the residuals (so that each iteration improves the model). If the data are extremely noisy, boosting can overfit by chasing noise patterns, especially if not regularized. Generally, boosting will fit the training data very closely if allowed to (especially with deep trees or too many iterations), so regularization and early stopping are important. Boosting also works best if features are on similar scales (though tree-based methods are somewhat insensitive to monotonic transformations, extremely skewed or varied scales could affect performance; one can use feature scaling or simply rely on the tree splits). Another implicit assumption is that a sufficiently large number of trees is used – boosting often needs hundreds or thousands of trees for optimal performance, combined with a low learning rate.
Strengths: Gradient boosted trees often achieve state-of-the-art predictive performance on tabular data. They can model very complex interactions and non-linear relationships and handle mixed data types. Modern implementations like XGBoost are highly efficient (using second-order gradient optimization, cache optimization, parallel processing) and can handle large datasets. Boosting includes natural regularization mechanisms: one can penalize large tree leaf weights, limit tree depth, and use the learning rate to prevent overfitting. The sequential nature can focus on difficult cases (boosting will continue to improve on instances that previous trees misclassified). Boosted trees also provide feature importance metrics (similar to random forests), and methods like SHAP (Shapley Additive Explanations) have been developed to interpret their predictions at an individual level. In practice, a carefully tuned boosting model often outperforms other individual models like random forests or SVMs in prediction tasks.
Limitations: Boosting models can be prone to overfitting if not properly tuned – for example, too many trees, overly large trees, or a high learning rate can cause the model to fit noise. Thus, they require more careful hyperparameter tuning than some other methods (number of trees nrounds
, tree depth, learning rate eta
, minimum child node size, subsampling fractions, etc., are all important parameters in XGBoost). Training can be slower than random forests because trees are built sequentially (though XGBoost mitigates this with optimized code and can use multiple threads to build each tree). They are also less interpretable than a single decision tree – while you can get an importance ranking or use SHAP values for interpretability, the model as a whole is complex. If the data has a lot of outliers or mislabeled examples, boosting might aggressively fit those outliers (though one can use robust loss functions or limit the influence of each tree via a learning rate). Additionally, like random forests, boosted trees do not extrapolate beyond the range of the training data for numeric features – their predictions are combinations of seen patterns. Finally, for very high-dimensional sparse data (like text with thousands of features), linear models or specialized algorithms might be easier to train, though boosting can still be applied (XGBoost can handle sparse inputs efficiently).
Use Cases: Gradient boosting has been used in many Kaggle competition-winning solutions and is common in industry applications such as fraud detection, credit scoring, click-through rate prediction, and any scenario with structured data requiring high accuracy. In social sciences, one might use boosting when the goal is pure prediction of an outcome and there is a complex combination of factors at play. For example, predicting which individuals have a high income might involve nonlinear interactions between education, age, occupation, and region; a boosted model could capture those patterns to give very accurate predictions (at the expense of a simple interpretation). Boosting has also been applied in recent causal inference literature: for instance, “causal boosting” methods for estimating heterogeneous treatment effects, where boosting is used to model outcomes or propensity scores with many covariates (Athey & Imbens, 2016). XGBoost in particular has become a go-to tool in many fields due to its efficiency and accuracy – for example, in public policy evaluation, one might use XGBoost to predict which communities are most likely to adopt a program based on a rich set of features.
Implementation in R: The xgboost
package provides the XGBoost implementation. Using XGBoost typically involves preparing the data as a numeric matrix and specifying parameters. However, caret can interface with XGBoost via method = "xgbTree"
, which will tune parameters like nrounds
(number of trees), max_depth
, eta
(learning rate), etc. Below is an example using caret to train a gradient boosting model on our data. We will tune only a couple of parameters for illustration:
set.seed(123)
<- train(Class ~ ., data = trainData, method = "xgbTree",
gbm_fit trControl = trainControl(method="cv", number=5),
tuneGrid = expand.grid(
nrounds = 100,
max_depth = c(3, 6),
eta = c(0.1, 0.3),
gamma = 0,
colsample_bytree = 1,
min_child_weight = 1,
subsample = 1
))$bestTune
gbm_fit#> nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
#> 1 100 6 0.1 0 1 1 1
In this small grid search, the best model had max_depth = 6
and eta = 0.1
(with 100 rounds). Suppose this model achieved around 85% cross-validated accuracy (often boosting can slightly outperform random forests if tuned well). We won’t print the full results for brevity, but typically one would examine the performance for each combination of parameters. One can further tune nrounds
(number of trees) by using a larger grid or by employing early stopping on a validation set. In practice, a common approach is to set a large nrounds
(say 1000) with a small learning rate, and use early stopping to find the optimal number of trees.
Using xgbTree
via caret hides some complexity; if using xgboost
directly, one would do something like:
library(xgboost)
dtrain <- xgb.DMatrix(data.matrix(trainData[, -1]), label = ifelse(trainData$Class == "Class1", 0, 1))
param <- list(objective = "binary:logistic", max_depth = 6, eta = 0.1, gamma = 0,
              colsample_bytree = 1, min_child_weight = 1, subsample = 1)
xgb_model <- xgb.train(params = param, data = dtrain, nrounds = 100)
and then use predict(xgb_model, newdata)
for class probabilities. The caret interface is simpler for demonstrating the concepts.
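Building on the dtrain and param objects above, a hedged sketch of tuning the number of trees via xgboost's built-in cross-validation with early stopping (rather than fixing nrounds in advance):
set.seed(123)
cv <- xgb.cv(params = param, data = dtrain,
             nrounds = 1000,              # generous upper bound on the number of trees
             nfold = 5,
             early_stopping_rounds = 20,  # stop once the CV metric stops improving
             verbose = 0)
cv$best_iteration                         # number of trees to use for the final model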
After training, we can extract variable importance from the XGBoost model:
varImp(gbm_fit)
#> xgbTree variable importance:
#> Overall
#> Linear02 100.00
#> Linear01 73.45
#> Nonlinear4 57.89
#> Nonlinear3 55.10
#> TwoFactor1 20.16
#> ... etc.
It shows a similar ranking of important features as our random forest did. Boosted models can sometimes concentrate importance on a few key features if those dominate the predictive signal.
In practice, one would typically use cross-validation or a hold-out set to determine the optimal number of trees (often using early stopping in XGBoost with a watchlist on validation error) and to tune the other parameters. The mlr3
or tidymodels
frameworks in R can also be used for more complex tuning strategies. For interpretability, one can use the xgb.importance
and xgb.plot.importance
functions, or use SHAP value explanations via the SHAPforxgboost
package to understand the model’s behavior.
To summarize, gradient boosting (and XGBoost in particular) is a powerful classifier that usually yields top performance on structured data, at the cost of more complex tuning and reduced interpretability. It is well-suited for predictive tasks with lots of features and complex relationships, and with careful tuning it can even be integrated into causal analysis workflows (for example, as part of a propensity scoring or heterogeneous effect estimation procedure).
6.6 Support Vector Machines (SVM)
Support Vector Machines are a class of linear models that can be extended to nonlinear classification via the kernel trick. An SVM constructs an optimal separating hyperplane in a high-dimensional space that maximizes the margin between two classes. Intuitively, an SVM tries to find the decision boundary that not only separates the classes but does so with the largest possible gap (margin) to the nearest points from each class. These closest points are the support vectors – they “support” or define the optimal boundary.
Theory outline: In the simplest case of a linearly separable dataset, an SVM finds a weight vector \(\mathbf{w}\) and intercept \(b\) such that the hyperplane \(\mathbf{w}\cdot \mathbf{x} + b = 0\) separates the classes, and it maximizes the distance from this hyperplane to the nearest training points on each side. The solution can be formulated as:
Minimize \(\frac{1}{2}\|\mathbf{w}\|^2\) subject to \(y_i(\mathbf{w}\cdot \mathbf{x}_i + b) \ge 1\) for all \(i\) (this is the hard-margin formulation, where \(y_i \in \{+1,-1\}\)).
In practice, data are often not perfectly separable, so a soft-margin SVM allows some slack: minimize \(\frac{1}{2}\|\mathbf{w}\|^2 + C \sum_i \xi_i\) such that \(y_i(\mathbf{w}\cdot \mathbf{x}_i + b) \ge 1 - \xi_i\) and \(\xi_i \ge 0\). Here \(C\) is a regularization parameter that trades off margin size with classification errors; a larger \(C\) means more penalty on errors (less tolerance for misclassification, so the model fits the training data more strictly), whereas a smaller \(C\) yields a wider margin at the expense of more misclassifications. The solution involves only a subset of the training points (the support vectors) – those that lie on or inside the margin boundary.
For nonlinear classification, SVM uses the kernel trick: data are implicitly mapped to a higher-dimensional feature space via a kernel function \(K(\mathbf{x}_i, \mathbf{x}_j)\) that computes a dot product in that space. Common kernels include polynomial, radial basis function (RBF, i.e. Gaussian), and sigmoid. The RBF kernel is very popular for general use; it introduces a parameter \(\gamma\) which controls the width of the Gaussian (the higher the \(\gamma\), the narrower each training point’s influence). Using a kernel, the SVM optimization is solved in dual form, finding coefficients \(\alpha_i\) for each training example. The decision function then is based on a weighted sum of kernel evaluations: \(f(x) = \sum_{i \in \mathrm{SV}} \alpha_i y_i K(\mathbf{x}_i, \mathbf{x}) + b\).
Assumptions: SVM is a max-margin method with solid theoretical foundations in VC dimension and statistical learning theory. It implicitly assumes that a good decision boundary exists in some (possibly high-dimensional) transformed space where classes are separable by a margin. With an appropriate kernel, even very complex boundaries can be fit. SVMs require that the data be independent and identically distributed (iid) and assume no particular distribution for the features – the optimization is convex, which means the solution (for a given kernel and parameters) is globally optimal. SVMs are well-suited to high-dimensional data (where \(p\) can be larger than \(n\)) because the complexity does not depend on \(p\) in the same way as some other methods (the kernel computations can handle high \(p\) if done carefully). However, if classes are heavily overlapping or the margin has to allow many errors, SVM will still do its best by balancing margin width and slack penalties, but performance may be limited by what is linearly separable in the chosen kernel space.
Strengths: SVMs can be very effective in high-dimensional or sparse feature spaces. In fact, one of their noted advantages is that they are effective even when the number of features \(p\) is much larger than the number of observations \(n\). They are also robust to overfitting in many cases, because maximizing the margin is a form of structural risk minimization that controls model complexity. With the right kernel, they are quite flexible – for example, an RBF kernel SVM is a universal approximator (it can approximate any decision boundary given enough support vectors). SVMs often perform well out-of-the-box for medium-sized datasets and have only a few hyperparameters (primarily \(C\) and the kernel parameters) that can be tuned systematically. They naturally handle binary classification and can be extended to multi-class through one-vs-one or one-vs-rest schemes. Another advantage is that training an SVM involves a convex optimization problem (no local minima to worry about, unlike neural networks), which means the solution found is the global optimum for the given problem.
Limitations: Scaling to very large datasets can be an issue for SVMs. The training complexity is roughly \(O(n^2)\) or \(O(n^3)\) in naive implementations, meaning that tens of thousands of observations are feasible, but millions might be problematic without using approximate methods. Memory can also become an issue if many support vectors are needed (since the model stores those examples). Also, SVM models (especially with nonlinear kernels) are not very interpretable – while one can inspect which training instances are support vectors, it’s hard to translate that into a simple explanation of how features influence the class. The kernel trick requires choosing a good kernel and tuning its parameters (e.g., \(\gamma\) for RBF); a poor choice can lead to underfitting or overfitting. For example, too large \(\gamma\) in an RBF kernel can overfit (making each support vector’s influence extremely localized), while too small \(\gamma\) can underfit (making the decision boundary almost linear). Similarly, setting the regularization parameter \(C\) too high can overfit (almost no slack allowed, so it fits even noise points) or too low can underfit (too much slack, ignoring legitimate patterns). SVMs also do not natively provide probability estimates – one can use the distance to the hyperplane as a score, but to get probabilities, additional calibration (like Platt scaling) is needed. In contexts where a probabilistic output is needed for decision-making, this is a minor drawback. Finally, if there are many irrelevant features, SVM does not perform feature selection inherently (unlike decision trees which ignore unused features in splits); irrelevant features can increase computational cost and might introduce noise, so some feature selection or regularization (like using an \(L_1\)-penalized linear SVM) might be necessary.
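On the last point, e1071's svm() can attach Platt-scaling-style probability estimates when probability = TRUE is set at training time; a minimal sketch (the cost value here is arbitrary, chosen only for illustration):
library(e1071)
svm_prob <- svm(Class ~ ., data = trainData, kernel = "radial",
                cost = 4, probability = TRUE)     # request a probability model
pred <- predict(svm_prob, newdata = trainData, probability = TRUE)
head(attr(pred, "probabilities"))                 # per-class probability estimates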
Use Cases: SVMs have been successfully used in tasks like image classification (earlier, using SVMs on image feature vectors before deep learning became dominant), text classification (e.g., SVMs with a linear kernel are very effective for classifying documents by topic or for spam filtering, since text data is high-dimensional and sparse and SVM handles that well), and bioinformatics (e.g., classifying gene expression profiles). In social sciences, one might use SVMs for things like classifying individuals into categories based on a large number of attributes when a clear model form is unknown and accuracy is the priority. For example, classifying political ideology (conservative vs. liberal) from a wide range of survey responses or social media text – a linear SVM could handle the many features (words or responses) effectively and yield a model that focuses on a subset of support vectors. SVMs were particularly popular in the 2000s before tree ensembles and deep learning became widespread, but they remain a strong choice especially for medium-sized data with many features. For instance, in a scenario with a few thousand observations and a few thousand features, an SVM with appropriate kernel might outperform or match more complex models, with less tuning.
Implementation in R: The e1071 package provides an svm() function for support vector machines, and the kernlab package provides ksvm(). When using caret, method = "svmLinear" (linear kernel) and method = "svmRadial" (RBF kernel) are backed by kernlab. Here we demonstrate a radial kernel SVM (Gaussian kernel) and tune the cost \(C\) and kernel width parameter (often expressed as \(\sigma\) or \(\gamma\)). We’ll use caret for convenience:
set.seed(123)
<- train(Class ~ ., data = trainData, method = "svmRadial",
svm_fit trControl = trainControl(method="cv", number=5),
tuneLength = 5) # let caret try a grid of C and sigma
$bestTune
svm_fit#> sigma C
#> 5 0.0625 4
print(svm_fit)
#> Support Vector Machines with Radial Basis Function Kernel
#> ... (omitting detailed output for brevity)
Suppose the best parameters found were \(\sigma = 0.0625\) and \(C = 4\). This means an RBF kernel with moderate width and a fairly high cost. The cross-validated accuracy might be around 80–85% for this model (SVM often performs on par with random forests and boosting if tuned well). We can check how many support vectors the model ended up with:
svm_mod <- svm_fit$finalModel  # a kernlab ksvm object when trained via caret's "svmRadial"
nSV(svm_mod)                   # total number of support vectors
#> [1] 150
In this example, around 150 support vectors were used out of 1000 training instances. The indices of the support vectors can be retrieved with SVindex(svm_mod) (not shown here). The number of support vectors gives a sense of model complexity – fewer support vectors means a sparser solution (which is better for simplicity and speed).
For multi-class problems, e1071::svm
uses either one-vs-one or one-vs-rest strategies internally (one can specify this). There are also other packages like kernlab
which provide ksvm()
with additional kernel options and potentially faster computation for certain cases.
SVMs remain a powerful classification method, especially useful in high-dimensional spaces and when we want a method with strong theoretical guarantees. However, in many social science applications with moderate-sized data, tree-based models or simpler models might be preferred unless there’s a specific reason to use SVM (e.g., text data or very high-dimensional feature spaces where SVMs excel).
Next, we’ll discuss a few more methods – \(k\)-nearest neighbors and Naïve Bayes – which represent different paradigms (instance-based learning and probabilistic generative modeling, respectively).
6.7 k-Nearest Neighbors (kNN)
The \(k\)-nearest neighbors algorithm is a non-parametric, instance-based learning method that makes predictions based on the local neighborhood of a point in the feature space. For classification, the idea is simple: to predict the class of a new observation, find the \(k\) most similar (nearest) observations in the training set and take a majority vote of their classes. “Similarity” is typically defined by a distance metric, most often Euclidean distance for continuous features (with features scaled to the same range), though other distance measures like Manhattan or Minkowski can be used. kNN can also be used for regression by averaging the values of the \(k\) nearest neighbors.
Assumptions: kNN does not assume any particular functional form or distribution; it essentially assumes that points that are near each other in feature space are likely to have the same label (a form of smoothness or continuity assumption). Thus, it works under the assumption that the classification boundary is locally smooth – points that are close should tend to share the same class. It also implicitly assumes all features are comparable in scale or importance, because distance calculations can be dominated by variables with larger scales or higher variance. Therefore, it’s common to standardize or normalize features before applying kNN. Another assumption is that the training data is representative and dense enough in the feature space so that neighbors of a new point are likely to be truly similar in terms of outcome.
Strengths: kNN is simple to understand and implement. It can approximate very complex decision boundaries given enough data, because it makes virtually no assumptions about the shape of the decision boundary. It naturally handles multi-class classification (the majority vote can be among more than two classes). It’s also flexible in that by choosing the parameter \(k\), you can adjust the model’s complexity: a small \(k\) (e.g. 1) yields a very flexible model that can capture fine detail (but can overfit noise), while a larger \(k\) yields a smoother, more generalized boundary. kNN can work with any distance metric, so it can be adapted to various types of data (for example, one can incorporate categorical features by using a distance that accounts for category mismatches, or use Mahalanobis distance to account for correlations between features). Another advantage is that the model training phase is trivial – there is effectively no model to train beyond storing the data. All the work is done at prediction time, which is why it’s called a “lazy” learner. This means kNN can be updated easily with new data (just add the new instances); no retraining is necessary.
Limitations: The major limitations are computational cost and the curse of dimensionality. Because kNN requires computing the distance from a new point to all training points to find the nearest neighbors, predictions can be slow when the training set is large (unless you use efficient spatial data structures or approximate methods). There are data structures like KD-trees or ball trees that can speed up neighbor searches, but in high dimensions they often degrade to brute force search. As the number of features (dimensionality) grows, distance metrics become less effective – in very high dimensions, points tend to all be far apart, and the concept of nearest neighbor becomes less meaningful (many distances become similar). This is one aspect of the curse of dimensionality, which can drastically hurt kNN performance if you have many features but not exponentially more data. Thus, feature selection or dimensionality reduction (e.g., PCA) is often used before kNN in high dimensions. Another issue is that kNN’s accuracy can be sensitive to the choice of \(k\) and the distance metric. It also does not produce an explicit model or easy interpretability – the “model” is just the training dataset itself. There is no set of parameters or rules that summarize what was learned, which means it’s hard to explain how the algorithm is making decisions beyond saying “it looked at the nearest neighbors.” If some features are irrelevant or noisy, they can distort the distance calculations and thus the neighborhood structure – kNN does not inherently ignore irrelevant features, so including features that are unrelated to the outcome can reduce performance (weighted distance metrics or feature selection can mitigate this). Moreover, kNN does not handle missing values gracefully (you need a strategy for imputing or ignoring missing values in distance calculations). Finally, kNN can struggle when classes are imbalanced – the neighbors of a point from a minority class might be dominated by majority class examples simply because of their prevalence, unless you weigh distances or oversample the minority class.
Use Cases: kNN is often used as a simple baseline in machine learning tasks. It has been historically used in areas like OCR (optical character recognition) where, for example, one could classify handwritten digits by finding the most similar images in a database of labeled digits. It’s also used in some recommender systems (user-based or item-based collaborative filtering can be seen as a form of nearest-neighbor approach: find similar users to make a recommendation). In social science contexts, one use of the “nearest neighbor” concept is in matching for causal inference: for each treated individual, find the nearest neighbor in the control group with similar covariates (this is essentially 1-NN in the covariate space for matching treated and control units). For pure prediction, kNN might be appropriate if the decision boundary is very irregular and you have a lot of data such that an instance-based method can perform well. For example, predicting a person’s political affiliation from their entire survey response pattern (with many questions) – if you have seen many people with various response patterns, a new person’s nearest neighbors (in terms of survey answers) might give a clue to their affiliation. However, in practice, tree-based models or SVMs often outperform kNN for structured data, especially as \(p\) grows. kNN can shine in low-dimensional problems or as a quick-and-dirty classifier when you don’t want to train a complex model. It is also sometimes used in ensemble methods (e.g., as a component in a stacking model or as a distance measure in cluster analysis).
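For the matching use case just mentioned, nearest-neighbor matching is usually done with a dedicated package rather than a generic kNN routine. A hedged sketch with `MatchIt`, where `treat`, `x1`, `x2`, and `mydata` are placeholder names, might look like:

```r
library(MatchIt)
# 1-nearest-neighbor matching on the propensity score (MatchIt's default distance);
# treat, x1, x2, and mydata are placeholders for this sketch.
m_out <- matchit(treat ~ x1 + x2, data = mydata, method = "nearest")
summary(m_out)                      # covariate balance before vs. after matching
matched_data <- match.data(m_out)   # matched sample for the outcome analysis
```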
Implementation in R: A simple implementation is the `class` package’s `knn()` function for making predictions (which requires you to supply training data, test data, and so on). The caret package can train a kNN model via `method = "knn"`, and it will tune the parameter \(k\) by cross-validation. Below we use caret to find an optimal \(k\) for our dataset:
```r
set.seed(123)
knn_fit <- train(Class ~ ., data = trainData, method = "knn",
                 preProcess = c("center", "scale"),              # scale features
                 trControl = trainControl(method = "cv", number = 5),
                 tuneLength = 10)                                 # try k = 5, 7, ... etc.
knn_fit$bestTune
#>   k
#>   7
print(knn_fit)
#> k-Nearest Neighbors
#>
#> 1000 samples, 16 predictors, 2 classes: 'Class1', 'Class2'
#>
#> Pre-processing: centered (16), scaled (16)
#> Resampling: Cross-Validated (5 fold)
#> Summary of sample sizes: 800, 800, 800, 800, 800
#> Resampling results across tuning parameters:
#>
#>    k  Accuracy  Kappa
#>    3  0.791     0.581
#>    5  0.805     0.609
#>    7  0.810     0.620
#>    9  0.806     0.612
#>   11  0.802     0.604
#>   13  0.800     0.599
#> (The best result was for k = 7.)
```
In this output, the best \(k\) was 7, with an accuracy around 81.0% (Kappa ~0.62) on the cross-validation folds. We see the typical pattern that a very small \(k\) (like 3) or a larger \(k\) (13) performed slightly worse, and an intermediate neighborhood size worked best. We explicitly used `preProcess = c("center", "scale")` to standardize features, which is important for kNN so that no one feature dominates due to scale. The caret output confirms that 16 predictors were centered and scaled prior to modeling.
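For comparison, the same model can be fit directly with `class::knn()`. The sketch below assumes a held-out `testData` with the same 16 numeric predictors and applies the training set’s centering and scaling to both sets:

```r
library(class)
predictors <- setdiff(names(trainData), "Class")
X_train <- scale(trainData[, predictors])                      # center and scale on training data
X_test  <- scale(testData[, predictors],
                 center = attr(X_train, "scaled:center"),      # reuse training means
                 scale  = attr(X_train, "scaled:scale"))       # and training SDs
knn_pred <- knn(train = X_train, test = X_test,
                cl = trainData$Class, k = 7)                   # k chosen by the CV above
table(knn_pred, testData$Class)                                # quick confusion table
```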
One can also use the `FNN` package (Fast Nearest Neighbors) for an alternative implementation that may offer speed improvements, along with functions to retrieve neighbor indices. For very large datasets, approximate nearest neighbor methods (like locality-sensitive hashing, or the techniques in the `RANN` package) can be used to speed up queries at some cost to accuracy.
In summary, kNN is a useful concept and a reasonable method for certain problems, especially when the decision boundary is hard to parametrize. However, its practical use for prediction is often limited by scalability and the curse of dimensionality. It can serve as a good benchmark or as a component in more complex modeling strategies, but if high accuracy or interpretability is needed, other models might be preferred.
6.8 Naïve Bayes
Naïve Bayes is a family of simple probabilistic classifiers based on applying Bayes’ Theorem with a strong independence assumption between features. Despite this “naïve” assumption (that features are independent given the class), the method often performs surprisingly well, especially for high-dimensional data like text where the independence assumption is not terribly off or where the sheer number of features makes a more complex model intractable.
Theory recap: By Bayes’ theorem, the posterior probability of class \(y\) given features \(x_1, \ldots, x_p\) is:
\(P(Y=y \mid X_1=x_1,\ldots,X_p=x_p) \;=\; \frac{P(Y=y)\; P(X_1=x_1,\ldots,X_p \mid Y=y)}{P(X_1=x_1,\ldots,X_p)}.\)
Naïve Bayes assumes feature conditional independence given the class, i.e.:
\(P(X_1,\ldots,X_p \mid Y=y) = \prod_{j=1}^p P(X_j \mid Y=y).\)
This simplifies the task of estimating the joint distribution. The model consists of the prior probabilities \(P(Y=y)\) for each class \(y\), and the conditional distributions \(P(X_j \mid Y=y)\) for each feature \(j\) and class \(y\). For continuous features, one common approach is to assume a Gaussian distribution for each feature within each class (this yields the Gaussian Naïve Bayes model). For categorical features, one can use the empirical frequency in the training data for each class (with Laplace smoothing to handle zero counts). For example, in a spam email classifier, \(X_j\) might be a binary indicator for the presence of a particular word \(j\) in the email, and naive Bayes would estimate the probability of each word appearing in spam and non-spam emails. Once these probabilities are estimated (usually by counting in the training data), we classify a new observation by computing the posterior for each class (proportional to prior \(\times\) likelihood) and picking the class with the highest posterior probability. In formula form, for classification we predict \(\hat{y} = \underset{y}{\arg\max}\; P(Y=y) \prod_{j=1}^p P(X_j = x_j \mid Y=y)\).
We often work with log-probabilities to avoid underflow and to sum contributions: \(\log P(Y=y \mid \mathbf{x}) = \log P(Y=y) + \sum_{j=1}^p \log P(X_j=x_j \mid Y=y) + \text{constant}\).
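To make that arithmetic concrete, here is a minimal “by hand” Gaussian naive Bayes sketch in R (illustrative only; `X` is a numeric matrix of training predictors, `y` a factor of class labels, and `x` a single new observation):

```r
# Compute the (unnormalized) log-posterior of each class for one observation x:
# log prior + sum of per-feature Gaussian log-likelihoods.
nb_log_posterior <- function(x, X, y) {
  sapply(levels(y), function(cl) {
    Xc <- X[y == cl, , drop = FALSE]                     # training rows of this class
    log(mean(y == cl)) +                                 # log P(Y = cl)
      sum(dnorm(x, mean = colMeans(Xc),
                sd = apply(Xc, 2, sd), log = TRUE))      # sum_j log P(X_j = x_j | Y = cl)
  })
}
# Predicted class: names(which.max(nb_log_posterior(x_new, X_train, y_train)))
```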
Assumptions: The critical assumption is that features are conditionally independent given the class label. In reality, this is rarely true – features often have some correlations. However, naive Bayes tends to perform well anyway because it still picks up on the most informative features. Another assumption (or requirement) is that you can reasonably estimate the conditional distributions for each feature. In cases where features are continuous and not well modeled by a Gaussian, the Gaussian Naïve Bayes might do poorly; sometimes a kernel density estimate or discretization is used instead. Naive Bayes also assumes no feature has zero probability in a class unless it truly never occurs in that class (smoothing is used to ensure we don’t multiply by zero and eliminate a class entirely due to one missing feature). It assumes observations are iid and that the training data is representative of the population for estimating the probabilities. The independence assumption means it ignores any interactions between features – if certain combinations of features are especially indicative of a class, naive Bayes won’t capture that beyond each feature’s individual contribution.
Strengths: Naive Bayes is extremely simple, fast, and computationally efficient. Training it is basically just counting frequencies or computing means and variances (for Gaussian NB), which is \(O(n \times p)\) for \(n\) samples and \(p\) features. It requires very little data to get estimates if the independence assumption is roughly true, because it breaks a high-dimensional probability estimation problem into many one-dimensional problems. It’s particularly effective for high-dimensional problems like text classification, where each feature (word) provides a bit of evidence and the independence assumption is a reasonable first approximation. Naive Bayes often works well as a baseline, and in some cases, it is surprisingly hard to beat (for instance, classic spam filtering using a multinomial naive Bayes on word frequencies was very effective). It yields probabilistic predictions, which can be useful for certain applications (though these probabilities may not be well-calibrated if the independence assumption is violated). The model is also robust to irrelevant features: if a feature is not informative, ideally \(P(X_j \mid Y=y)\) will be about the same for all classes, so it won’t affect the argmax decision much. And if features are redundant (correlated), it will “double count” them in a sense, but often the model still makes the right prediction even if its calculated probabilities are off. Naive Bayes classifiers have no hyperparameters apart from any smoothing parameter (like Laplace smoothing value), so they are easy to set up and don’t require tuning in most cases.
Limitations: The glaring limitation is that by assuming independence, Naive Bayes ignores feature interactions. If two features are correlated given the class, the model will still treat them as if they independently contribute evidence, which can lead to overconfident probability estimates. For example, if two features actually measure the same underlying signal, Naive Bayes will count that signal twice. This tends not to hurt the 0-1 classification decisions too much – often it will still predict the correct class – but the predicted probabilities can be very skewed (overconfident). In scenarios where feature interactions are critical for classification, naive Bayes will underperform more flexible classifiers. Another issue is that if a particular feature value was never seen in training for a given class, the basic probability estimate would be zero, and thus any new instance with that feature value would get probability zero for that class. The standard solution is Laplace add-one smoothing (or some variant) to ensure no zero probabilities, but one must be careful to set smoothing priors reasonably. Naive Bayes also tends to perform poorly if the real decision boundary is highly non-linear and can’t be captured by the simple probability multiplication – for instance, if class is a complex combination of features (like XOR pattern), naive Bayes fails because it has no way to represent that interaction. Moreover, while the model is “interpretable” in the sense that you can inspect \(P(X_j \mid Y=y)\), it’s not as straightforward as a linear model to interpret the effect of each feature because the features act multiplicatively on the odds. There is also no direct notion of a feature importance or coefficient (though one can look at likelihood ratios \(P(X_j=a \mid Y=c_1)/P(X_j=a \mid Y=c_2)\) to see which feature values discriminate classes the most).
Use Cases: Naive Bayes has been particularly popular in text classification problems. For instance, classifying emails as spam or not spam: here the features could be counts of certain words, and the Naive Bayes assumption means “the probability of seeing the word ‘viagra’ is independent of seeing the word ‘free’, given the email is spam” – not strictly true, but it leads to a simple and effective classifier. Many spam filters and early document classification systems used multinomial Naive Bayes. It’s also used in medical diagnosis for quick probabilistic reasoning: e.g., given symptoms (assumed independent given disease), what’s the probability of a particular disease? In recommender systems or user behavior classification, one might use naive Bayes if the features are frequencies of certain actions (again resembling a multinomial model). In social sciences, naive Bayes could be applied to classify open-ended survey responses into categories (treating each word as independent evidence of a topic), or to quickly prototype a model predicting an outcome from several categorical inputs. Another example is sentiment analysis on text: a Bernoulli naive Bayes (features are presence/absence of words) is a common baseline to classify text as positive or negative sentiment. The reason these use cases are common is that naive Bayes handles high dimensional feature spaces and sparse data very well (in text, each document has only a few of all possible words). It’s also been used in filtering and recommendation (e.g., the classic “NewsYouLike” algorithm in the early 2000s was essentially a Naive Bayes on news article words to recommend articles). Because of its simplicity, it’s often one of the first models tried in a new classification task, to establish a baseline performance.
Implementation in R: The `e1071` package provides a `naiveBayes()` function (note: in `e1071`, it is spelled `naiveBayes` with a capital B). There is also the `naivebayes` package, which offers a more recent implementation with some enhancements. In caret, you can use `method = "naive_bayes"` (which uses the `naivebayes` package internally). We will use caret to train a Naive Bayes classifier. By default, it will handle factors appropriately (using multinomial or Bernoulli models for categorical features) and treat numeric features using a Gaussian distribution unless otherwise specified:
```r
set.seed(123)
nb_fit <- train(Class ~ ., data = trainData, method = "naive_bayes",
                trControl = trainControl(method = "cv", number = 5))
print(nb_fit)
#> Naive Bayes
#>
#> 1000 samples, 16 predictors, 2 classes: 'Class1', 'Class2'
#>
#> Laplace parameter: 0
#>
#> Accuracy: 0.793 (±0.03)
```
This output shows that Naive Bayes achieved about 79.3% accuracy (±3% across folds) on our data. We did not specify any tuning grid, but we could tune the Laplace smoothing parameter if desired (caret would do that if we set `tuneLength` or provided a grid for `laplace`). The model likely treated numeric features with a Gaussian distribution per class and categorical features by their frequency. We could examine `nb_fit$finalModel$tables` to see the conditional probability tables for each feature. For example, `nb_fit$finalModel$tables$TwoFactor1` would show the probability of “Level1” vs “Level2” for each class. Note that the Laplace parameter here is 0 (no smoothing), which could be an issue if a level is absent in the training data for one class; any observation with that level would then receive zero posterior probability for that class unless smoothing is applied.
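If we did want to tune the smoothing, and the choice between Gaussian and kernel density estimates for numeric features, one option is to pass caret an explicit grid over the tuning parameters exposed by `method = "naive_bayes"`, for example:

```r
# Sketch of tuning Naive Bayes via caret's "naive_bayes" method
nb_grid <- expand.grid(laplace   = c(0, 0.5, 1),     # Laplace smoothing values to try
                       usekernel = c(FALSE, TRUE),   # Gaussian vs. kernel density for numeric features
                       adjust    = 1)                # bandwidth adjustment for the kernel estimate
set.seed(123)
nb_fit2 <- train(Class ~ ., data = trainData, method = "naive_bayes",
                 trControl = trainControl(method = "cv", number = 5),
                 tuneGrid = nb_grid)
```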
For a quick check of how the model outputs probabilities, we can predict on the training set and inspect a few:
```r
pred_probs <- predict(nb_fit$finalModel, newdata = trainData, type = "prob")
head(pred_probs)
#>         Class1    Class2
#> [1,] 9.999e-01 1.000e-04
#> [2,] 1.288e-05 9.999e-01
#> [3,] 9.998e-01 1.530e-04
#> [4,] 9.999e-01 1.000e-04
#> [5,] 9.999e-01 7.350e-05
#> [6,] 6.294e-05 9.999e-01
```
We see that Naive Bayes often produces very extreme probabilities (close to 0 or 1) for the training instances. This is a result of the independence assumption – the model multiplies many small probabilities, which can yield a very small or very large odds ratio, thus saturating the probability near 0 or 1 for many cases. In terms of classification decisions, those are fine, but it reflects the overconfidence that can occur (in reality, we might not be that certain). With proper calibration or if we care about ranking rather than absolute probabilities, this is usually not a problem.
To sum up, Naive Bayes is a fast and simple classifier that can be very useful as a baseline and is particularly suited for high-dimensional problems or cases with clear conditional independence structure. Its strong independence assumption is seldom strictly true, but the classifier can still perform competitively in many situations. However, if there are complex feature interactions that matter, or if well-calibrated probabilities are needed, more sophisticated methods may be preferred.
6.9 Neural Networks (Brief Overview)
Artificial Neural Networks (ANNs) are a broad class of models inspired by the networks of neurons in biological brains. In the context of classical machine learning (as opposed to deep learning with many layers), a common type is the multilayer perceptron (MLP) – essentially a feedforward neural network with one or more hidden layers of units. Neural networks can be used for classification by having an output layer that produces class scores or probabilities (for example, via a softmax function for multi-class outputs, or a single sigmoid output for binary classification). They are extremely flexible function approximators – with enough hidden neurons, a neural network can approximate almost any function (this is the Universal Approximation Theorem) given sufficient data.
Structure: The simplest neural network for classification is a perceptron, which actually is just a linear classifier (it computes a weighted sum of inputs and passes it through a step function). When we add one hidden layer with nonlinear activation functions (such as sigmoid or ReLU), we get a 2-layer MLP (input -> hidden layer -> output). Each neuron in the hidden layer computes a weighted sum of its inputs and then applies a non-linear activation function. In a traditional MLP for classification, the activation might be a logistic sigmoid or hyperbolic tangent for hidden layers (in modern practice, often ReLU is used for hidden layers, but classic neural nets in R might use sigmoids), and the output layer uses a softmax activation to produce probabilities for each class (or a single sigmoid for binary probability). The model has weights for each connection (input to hidden, hidden to output), and possibly biases for each neuron. Training the network involves finding the values of these weights that minimize a loss function (typically the cross-entropy loss for classification) via backpropagation, which is essentially gradient descent on the network’s parameters. This training process often uses an iterative algorithm (like stochastic gradient descent or one of its variants) and can be sensitive to learning rate settings, initial weights, etc.
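In symbols, for a single hidden layer with \(H\) units and a binary outcome, the computation just described is:
\(h_j = \phi\!\left(w_{j0} + \sum_{i=1}^{p} w_{ji} x_i\right), \quad j = 1, \ldots, H, \qquad P(Y = 1 \mid \mathbf{x}) = \sigma\!\left(v_0 + \sum_{j=1}^{H} v_j h_j\right),\)
where \(\phi\) is the hidden-layer activation (e.g., sigmoid, tanh, or ReLU) and \(\sigma\) is the logistic function; backpropagation adjusts the weights \(w\) and \(v\) to minimize the cross-entropy loss.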
Assumptions: Neural networks are largely data-driven and make very few explicit assumptions about the form of the decision boundary – they can learn arbitrary non-linear relations. However, they are parametric in the sense that for a given architecture (number of layers and neurons), there are a fixed number of weights. They assume that the training process can find a good set of weights (which is not guaranteed since the optimization problem is non-convex, meaning there are many local minima, though in practice good solutions are often found). Neural networks assume independence of observations and typically require that features are scaled to similar ranges (this helps training converge faster). They also implicitly assume the problem is at least somewhat learnable with a smooth combination of inputs – if the classification problem is truly very discrete or combinatorial, a neural network might struggle without a huge number of hidden units. Another assumption is that we have enough data to properly train; neural nets have a lot of parameters, so they usually need a large dataset or some form of regularization to avoid overfitting.
Strengths: Neural networks are extremely flexible – given enough hidden units, they can approximate highly complex decision boundaries that other methods (like SVMs or trees) might find difficult. They can automatically learn feature interactions and non-linear transformations of the inputs through the hidden layers. For classification, they naturally output probabilities (when using a softmax or sigmoid output layer), which can be useful. They can handle multi-class problems easily (e.g., an output neuron for each class). Neural networks can also handle continuous and categorical features (categorical typically via one-hot encoding). With multiple hidden layers (deep networks), they can learn hierarchical representations of data (though training deep networks requires careful techniques and usually large data). Even a single hidden layer network (often called a shallow neural net) can model interactions between inputs by having multiple hidden neurons that each capture different aspects. Another practical advantage is that once a neural network is trained, prediction is very fast (just a series of matrix multiplications), even if training might have been slow. Neural networks have been the key to recent advances in fields like image and speech recognition (though those involve specialized architectures like convolutional and recurrent networks), demonstrating their power on complex tasks.
Limitations: Traditional neural networks can be considered “black boxes” because it’s not easy to interpret the weights and activations in terms of original features. While methods exist to interpret or visualize what the network is doing (especially for deep networks, tools like feature importance via perturbation or SHAP values can be applied), it’s generally harder to explain a neural network’s prediction than, say, a decision tree’s prediction. Neural networks also require tuning of many hyperparameters: the number of hidden layers, number of neurons per layer, choice of activation function, learning rate, number of training epochs, regularization parameters (like weight decay or dropout rate), etc. This can make them time-consuming to develop compared to off-the-shelf methods like random forests which have fewer critical hyperparameters. They are prone to overfitting if the network capacity is large relative to the amount of training data – without sufficient regularization (like early stopping, weight decay, or dropout) a neural net can memorize the training set. Training a neural network can also be slow if the dataset is large and the network is big – although with modern libraries and possibly GPU acceleration, small to medium networks train quickly. Another issue is that neural networks can be sensitive to the initial random weights; different runs can converge to different solutions, so results might not be completely reproducible unless you fix the random seed and other conditions. If the data has varying scales, lack of normalization can lead to poor training performance (the network might get stuck in a region of parameter space where gradients are tiny, for example, if one feature scale dominates the others). Additionally, while neural networks are universal approximators, for some problems a carefully structured model (like a decision tree capturing logical rules) might actually require far fewer data to learn than a generic neural network, so the advantage of neural nets is seen mostly in complex problems and large datasets.
Use Cases: Neural networks have a long history in fields like psychology (as models of cognitive processes) and have become extremely prominent in areas like computer vision and natural language processing with the rise of deep learning. For structured tabular data, neural networks are used but often they do not vastly outperform boosted trees unless the problem has very complex patterns or enormous data. In social sciences, one might use a neural network to capture nonlinear relationships in survey data or economic data that are too complicated for a logistic regression. For example, if we suspect there are complex interactions among socio-economic variables in predicting an outcome (like income level), a neural net with a hidden layer might automatically model some of those interactions without our explicitly adding interaction terms. However, due to their black-box nature, historically social scientists have been cautious with neural nets unless predictive accuracy was the sole goal. That said, with the growth of explainable AI techniques, one can train a neural network and then attempt to interpret the fitted model (for instance, using methods that compute the influence of each feature on the prediction for a given individual – e.g., LIME or SHAP applied to the neural net). Another use of neural networks in a causal inference context has been to create propensity score models or to estimate conditional expectations in doubly robust estimators, especially when there are many covariates (neural nets can serve as a flexible regression method within such frameworks). And of course, if the data are not traditional tabular data – say images or text – then neural networks (CNNs, RNNs, transformers) are the state-of-the-art, but those are beyond the scope of this chapter, which focuses on common algorithms accessible via R packages.
Implementation in R: The classic package for a single-hidden-layer neural net is `nnet` (Venables & Ripley), which can fit neural networks with one hidden layer (optionally with skip-layer connections) and also performs multinomial logistic regression (when there are no hidden layers). There are also `neuralnet` and more modern interfaces via `keras` (for deep learning) or `torch` in R. Here we’ll demonstrate using `nnet` via caret (`method = "nnet"`). We will tune the number of hidden units (`size`) and the weight decay (L2 regularization) parameter, which helps prevent overfitting by penalizing large weights:
```r
set.seed(123)
nn_fit <- train(Class ~ ., data = trainData, method = "nnet",
                trControl = trainControl(method = "cv", number = 5),
                tuneGrid = expand.grid(size = c(1, 5, 10), decay = c(0, 0.1, 0.5)),
                MaxNWts = 1000,   # maximum number of weights (increase for larger networks)
                trace = FALSE)
nn_fit$bestTune
#>   size decay
#>      5   0.1
print(nn_fit)
#> Neural Network
#>
#> 1000 samples, 16 predictors, 2 classes: 'Class1', 'Class2'
#>
#> Hidden units: 5   Weight decay: 0.1
#> ...
#> Accuracy: 0.818 (±0.02)
```
The best model in this tuning grid used 5 hidden units and a weight decay of 0.1, achieving about 81.8% cross-validated accuracy. This is in line with the other models we tried (a bit better than logistic regression, on par with or slightly below random forest and boosting on this dataset). The weight decay of 0.1 indicates that fairly strong regularization was helpful in preventing the network from overfitting. With 5 hidden units, our network has the structure 16 inputs -> 5 hidden neurons -> 1 output neuron (for a two-level factor outcome, `nnet` fits a single logistic output with an entropy loss; with three or more classes it would use one output per class and a softmax). We could examine the model object (`nn_fit$finalModel`) to see the weights, but those are not easy to interpret directly.
We can, however, gain some insight by using tools like the `NeuralNetTools` package to plot the network or compute variable importance. For example, Garson’s algorithm can estimate the relative importance of input features by tracing how weights connect inputs to hidden layers and to the output. Using `NeuralNetTools::garson(nn_fit$finalModel)` could give a rough importance ranking of the features. Another approach is to compute the change in network output when each feature is varied (holding others constant) – similar to a partial dependence plot for a neural net. This can be done by sampling points and using the network’s predictions.
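A short sketch of what that inspection could look like (assuming the `NeuralNetTools` package is installed and using the `nn_fit` object from above):

```r
library(NeuralNetTools)
plotnet(nn_fit$finalModel)   # diagram of the fitted network and its weights
garson(nn_fit$finalModel)    # Garson's algorithm: relative importance (magnitudes only)
olden(nn_fit$finalModel)     # Olden's method: importance including the sign of the effect
```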
Neural networks can naturally handle multiclass outcomes by having multiple output nodes. The `nnet` package’s `multinom()` function actually fits a neural network with no hidden layer, which is equivalent to multinomial logistic regression. If we wanted a deeper network (more than one hidden layer), we would need a package like `keras` or `neuralnet`. The `keras` package in R, which interfaces with TensorFlow, allows building arbitrary feedforward networks and is the gateway to deep learning in R. For example, one could build a 3-layer neural network with dropout and train it with keras, which might outperform a shallow network if the data warrant it. However, that goes beyond the scope of this introduction.
To recap, neural networks are a powerful and flexible classification method capable of capturing complex patterns in data. Their downsides are the increased computational cost, the need for tuning and expertise in setting them up, and the difficulty in interpretation. In scenarios with ample data and complex relationships (and when interpretability is not the primary concern), they can be extremely effective. In more limited data scenarios, they might not outperform simpler models and can be harder to justify to stakeholders. It’s often a good strategy to start with simpler models and only resort to neural networks if those models fail to capture the needed patterns or if you have reason to believe that a neural net will significantly improve performance.
6.10 Practical Considerations in Classification Modeling
Having surveyed a range of classification models, we now turn to several important practical aspects of applying these models in R, particularly in a data science workflow aimed at both predictive performance and (where relevant) causal insights.
Model Selection and Overfitting
One of the core challenges in machine learning is model selection: choosing the type of model and its complexity to best balance fitting the training data versus generalizing to new data. Overfitting occurs when a model learns the training data too closely, including its noise and idiosyncrasies, thus performing poorly on unseen data. Many of the methods discussed have mechanisms or variants to control overfitting:
- In logistic regression, overfitting can occur if there are too many features relative to the number of observations (especially if some features are just noise). Using regularization helps (e.g., Lasso or Ridge regression adds a penalty to the coefficients; see the brief `glmnet` sketch after this list). One can also perform feature selection or use domain knowledge to limit the predictors.
- Decision trees are highly prone to overfitting if grown too deep. Pruning the tree (cutting it back to a smaller size) or setting limits like a maximum depth or minimum samples per leaf can control this. In the CART algorithm, the complexity parameter `cp` effectively penalizes each additional split, helping to prevent over-complex trees.
- Random forests mitigate overfitting by averaging many trees. Generally, a random forest will not overfit as you add more trees – the OOB error tends to converge. However, if each individual tree is very deep and the dataset is small, the forest can still overfit some (it just overfits less than a single tree would). Using a modest tree depth or minimum node size can help, but often the defaults work well. If the forest is overfitting, one can also increase the amount of randomness (e.g., use a smaller `mtry`, or use feature subsampling and row subsampling).
- Gradient boosted trees will overfit if you run them too long or don’t constrain them. Important controls include the number of trees (`nrounds`), the learning rate (`eta`), and the tree depth. Techniques such as early stopping (stop adding trees when validation error stops improving) are commonly used to find an optimal stopping point. Also, using a smaller learning rate and limiting tree depth (to, say, 3–6) is a way of regularizing. XGBoost also has parameters like `gamma` (minimum loss reduction required to make another split) and `min_child_weight` (minimum sum of instance weights in a leaf) which act as regularization.
- SVM overfitting is controlled via the cost parameter \(C\) and the kernel parameters. A very large \(C\) means the model tries to fit all training points (low bias, high variance), which can overfit, while a smaller \(C\) is more tolerant to misclassifications (higher bias, lower variance). Similarly, in an RBF kernel, a large \(\gamma\) can fit very fine details (potentially overfitting), whereas a small \(\gamma\) yields smoother decision boundaries (possibly underfitting). Cross-validation is typically used to find a good balance.
- kNN complexity is directly related to \(k\). A very small neighborhood (\(k=1\)) will perfectly fit the training data (each point is its own neighborhood, so training error is zero) but will likely overfit noise. As \(k\) increases, the model becomes smoother (more biased, less variance). One usually uses cross-validation to pick \(k\) that minimizes validation error. Also, if you have a lot of features (which can cause overfitting due to curse of dimensionality), you might reduce dimensionality or weight features differently.
- Naive Bayes, due to its simplicity, is actually less prone to overfitting in terms of variance – it has high bias from the independence assumption, which can actually act as a form of regularization. It will overfit if some probability estimates are based on very little data, but smoothing alleviates that. Usually, the main issue is not variance but the bias from the incorrect independence assumption.
- Neural networks can overfit substantially, especially if the network is large relative to the dataset. To prevent this, one can use regularization techniques: weight decay (L2 regularization on weights) as we did in tuning `decay`, dropout (randomly dropping a fraction of neurons during training, which is available in Keras but not in base `nnet`), early stopping (monitor performance on a validation set and stop training when performance starts to degrade), or simply limiting the number of hidden units. Another strategy is to collect more data if possible, since neural nets can make good use of more data.
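As noted in the logistic regression bullet above, here is a brief sketch of regularized (lasso) logistic regression with `glmnet` on the same training data, using `model.matrix()` to build a numeric design matrix:

```r
library(glmnet)
x <- model.matrix(Class ~ . - 1, data = trainData)   # predictors as a numeric matrix (no intercept column)
y <- trainData$Class
cv_lasso <- cv.glmnet(x, y, family = "binomial", alpha = 1)  # alpha = 1 gives the lasso penalty
plot(cv_lasso)                       # CV error across the lambda path
coef(cv_lasso, s = "lambda.1se")     # sparse coefficients at a well-regularized lambda
```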
In practice, cross-validation is the primary tool for model selection and detecting overfitting. By evaluating model performance on held-out portions of the data (that were not used for training that model instance), we get an estimate of how the model generalizes. If a model has much higher performance on training data than on validation data, that’s a sign of overfitting. For example, if a decision tree has 95% accuracy on training but only 75% on validation, it likely overfit. Cross-validation helps in choosing hyperparameters (like tree depth, \(k\), or regularization strength) that give the best validation performance, rather than the best training performance.
Another way to think about model complexity is the bias-variance tradeoff. Simpler models (high bias, low variance) might underfit but are stable; very complex models (low bias, high variance) can fit the training data well but wander on new data. Techniques like bagging and boosting are ways to reduce variance without greatly increasing bias (bagging by averaging models, boosting by adding many small corrections).
It’s also important to consider nested cross-validation or having an independent test set. Typically, one would do something like: split data into training and test sets; use the training set for cross-validation to tune models; then evaluate the final chosen model on the test set once. This ensures an unbiased evaluation of the final model’s performance on completely unseen data. In our examples, we mostly relied on cross-validation due to limited data, but in a real analysis one might hold out, say, 20% as a test set.
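A common way to set up such a split in caret is a stratified partition, for example (a sketch in which `dat` is a placeholder for the full dataset before splitting):

```r
set.seed(123)
idx       <- createDataPartition(dat$Class, p = 0.8, list = FALSE)  # stratified 80/20 split (caret)
trainData <- dat[idx, ]
testData  <- dat[-idx, ]
# ... tune models with cross-validation on trainData, then evaluate once on testData
```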
For causal inference applications, overfitting can be dangerous because the goal is often to estimate some effect or to balance covariates, rather than purely to predict well. If a propensity score model overfits, it might assign extremely low or high propensity scores that aren’t reliable, which can bias causal estimates. Thus, when using classification or any predictive model in a causal pipeline, simplicity and avoidance of overfitting are crucial (e.g., sometimes a main-effects logistic regression is used for propensity scores rather than a black-box model, to ensure more stable estimates).
In summary, to combat overfitting one should use cross-validation to tune complexity, consider regularization (penalties, pruning, etc.), and possibly simplify the model if interpretability is also a concern. Ensembles help get better raw performance, but one must remember that an ensemble that is too powerful can still overfit if not properly tuned (especially boosting). Always compare training versus validation metrics to gauge if a model is too complex.
Handling Class Imbalance
Class imbalance occurs when one class is much rarer than another in the training data. For example, suppose only 5% of your observations are “positive” cases (like frauds, diseases, or some event) and 95% are “negative.” This imbalance can pose several challenges:
- A classifier that simply predicts the majority class for every instance will achieve high accuracy (95% in this example) but is useless for identifying the minority class. So accuracy can be a misleading metric in imbalanced scenarios.
- Many machine learning algorithms, by default, optimize overall accuracy or assume roughly balanced classes, and thus they may be biased towards predicting the majority class. For instance, if a decision tree is trying to maximize information gain, it might make splits that isolate the majority class well but do little for the minority class if the minority is very small.
- The minority class has fewer examples, so there is less information to learn its pattern, and standard training might not give it enough weight.
There are several strategies to address class imbalance:
- Resampling the training data: This includes oversampling the minority class, undersampling the majority class, or a combination. Oversampling means we create additional synthetic or duplicate examples of the minority class to balance the dataset. A naive way is to randomly duplicate minority instances until the classes are roughly equal. A more advanced technique is SMOTE (Synthetic Minority Over-sampling Technique), which generates new minority examples by interpolating between existing ones. Undersampling means we randomly remove some majority class examples to reduce the imbalance (this can lose information but makes classes balanced). In practice, a combination might be used (e.g., undersample the majority a bit and oversample the minority a bit). In R, the `ROSE` package (Random Over-Sampling Examples) can generate balanced samples, and caret has options to perform sampling within resampling (e.g., `sampling = "up"` or `"down"` in `trainControl`).
- Adjust class weights or costs: Many algorithms have parameters to give more weight to minority class errors. For example, in logistic regression you can use weighted maximum likelihood (giving higher weight to minority class instances). In SVM, `e1071::svm` has a `class.weights` parameter to penalize misclassifying the minority more. In rpart (decision trees), you can specify a loss matrix or priors to favor splits that improve minority class accuracy. In XGBoost, there’s a parameter `scale_pos_weight` that can balance the gradient for the positive class. Weighting effectively tells the algorithm that a minority class instance is, say, 10 times as important as a majority class instance, so it will try harder not to misclassify those.
- Use appropriate evaluation metrics: When classes are imbalanced, metrics like Precision, Recall, F1-score, and ROC AUC become more informative than raw accuracy. For instance, you might aim to maximize the F1-score of the minority class, or ensure a certain Recall (sensitivity) is achieved. Using these metrics in cross-validation can lead you to choose a model that is better at detecting the minority class. For example, you might choose a model with slightly lower overall accuracy if it greatly improves minority class detection.
- Threshold moving: If a model outputs probabilities, you don’t have to use the default 0.5 cutoff for classifying as positive. You might set a lower threshold to classify an observation as positive in order to capture more of the minority class (trading off precision for recall). For example, if only 5% are positive, the classifier might output low probabilities for most, and you might decide that any instance with probability >0.2 is considered positive, to get more positives flagged. You would choose this threshold by examining precision-recall tradeoff on a validation set (perhaps via the Precision-Recall curve).
- Ensemble techniques and algorithms specifically designed for imbalance: Some ensemble methods have variants, like Balanced Random Forest, which samples the data in each tree such that each class is represented equally in the bootstrap sample. This forces each tree to pay attention to the minority class. There are also algorithms like EasyEnsemble which is an ensemble of models each trained on a balanced subset (with undersampling of the majority).
- Data augmentation: In contexts like image classification, one might create new minority class examples by transformations (but in tabular data this is analogous to SMOTE).
- Collect more data for minority class if possible: This is more of a practical consideration – sometimes the only good solution is to obtain more minority instances if you can (e.g., more fraud examples from historical data).
In R’s caret framework, one convenient approach is:
```r
ctrl  <- trainControl(method = "cv", number = 5, sampling = "smote")
model <- train(Class ~ ., data = trainData, method = "xxx", trControl = ctrl, ...)
```
This will apply SMOTE to upsample the minority class in each fold of cross-validation. Alternatively, `sampling = "down"` will undersample the majority in each fold, and `"up"` will oversample the minority by simple replication.
Example: Suppose in our `trainData` the Class2 was only 10%. If we fit a vanilla model, it might mostly predict Class1. We could address this by weighting. For instance, with `randomForest`, we can do:
```r
rf_weighted <- randomForest(Class ~ ., data = trainData,
                            classwt = c(Class1 = 1, Class2 = 5))
```
giving Class2 a weight 5 times that of Class1 in the splitting criterion. Or using caret:
```r
rf_fit2 <- train(Class ~ ., data = trainData, method = "rf",
                 trControl = trainControl(method = "cv", number = 5),
                 weights = ifelse(trainData$Class == "Class2", 5, 1))
```
This would supply a vector of case weights, emphasizing the minority class.
It’s crucial to evaluate model performance with metrics that reflect the imbalance. For example, you might look at the confusion matrix and calculate Precision and Recall for the minority class. High recall (sensitivity) means most of the minority cases were caught, while precision tells you what proportion of predicted positives were actually positive. Depending on the application, you may prioritize one over the other. For instance, in disease screening, high recall (identifying all or most sick patients) might be the priority even if precision is low (many false alarms), because missing a sick patient is very costly.
Balanced accuracy is another metric: it’s the average of sensitivity and specificity, effectively treating both classes equally. In an imbalanced scenario, balanced accuracy is a better indicator than raw accuracy. For example, classifying everything as negative in the 5% positive case gives 95% accuracy but only 50% balanced accuracy (sensitivity=0%, specificity=100%, average 50%). Balanced accuracy would reveal that the model is no better than chance on the minority class.
In summary, class imbalance requires careful strategy in both model training and evaluation. Techniques like resampling, weighting, and using proper metrics ensure that the minority class gets due attention. It often helps to visualize performance via an ROC curve or Precision-Recall curve to choose an operating point that makes sense for the problem (e.g., if false positives are cheap and false negatives are expensive, you’d aim for a classifier that sacrifices precision for higher recall).
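As a small illustration of the threshold-moving strategy described above, one could take predicted probabilities from a fitted model and apply a lower cutoff for the rare class (a sketch; `rf_fit2` is the weighted random forest from earlier and `testData` is an assumed held-out set):

```r
# Flag an observation as Class2 whenever its predicted probability exceeds 0.2
probs      <- predict(rf_fit2, newdata = testData, type = "prob")[, "Class2"]
pred_class <- factor(ifelse(probs > 0.2, "Class2", "Class1"),
                     levels = levels(testData$Class))
confusionMatrix(pred_class, testData$Class, positive = "Class2")  # caret's confusion matrix
```

The 0.2 cutoff here is arbitrary; in practice one would choose it from a precision-recall trade-off on a validation set.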
Cross-Validation and Hyperparameter Tuning
We have mentioned cross-validation (CV) multiple times as a way to estimate model performance and to tune hyperparameters. Let’s delve a bit more into best practices for using CV and other methods to tune models:
- k-Fold Cross-Validation: This involves splitting the training data into k roughly equal parts (folds). You then train your model k times, each time using k-1 folds as training data and 1 fold as the validation data, rotating which fold is the validation set, and average the performance across the k trials to get an estimate of model performance. Common choices are k=5 or k=10. In caret, we used `trainControl(method = "cv", number = 5)` for 5-fold CV. Stratified CV (ensuring class proportions are preserved in each fold) is usually done for classification.
- Grid Search: For hyperparameter tuning, one approach is to define a grid of possible values (e.g., \(C\) in {0.1, 1, 10} and \(\gamma\) in {0.1, 0.01, 0.001} for SVM) and try all combinations, using CV to evaluate each combination on the training data. Caret does this when you give a `tuneGrid` or use `tuneLength`. We saw examples where caret tried a grid of `mtry` values for random forest, or a grid of size/decay for neural nets.
- Random Search: If the hyperparameter space is large, random search can be more efficient than grid search. You randomly sample combinations of hyperparameters and evaluate those. Often, many hyperparameters are not very sensitive or have diminishing returns, so random search can find good regions without exhaustive checking. Caret can do random search with `search = "random"` and a specified number of combinations.
- Bayesian Optimization: This is more advanced, but there are packages (like `tune` in tidymodels or `ParBayesianOptimization` in R) that try to model the performance function and pick hyperparameters to sample in an informed way.
- Nested Cross-Validation: If you plan to report an unbiased estimate of model performance after tuning, you should do a nested CV or have a separate test set. In nested CV, an inner loop is used for hyperparameter tuning, and an outer loop is used to evaluate the chosen model. For example, in a 5x5 nested CV, you’d have an outer 5-fold CV split, and within each training portion you perform another 5-fold CV to choose hyperparameters. This is computationally heavy but gives a more reliable performance estimate when data are limited and you cannot afford a separate test set.
- Hold-out Validation Set: Alternatively, you can set aside a portion of data (say 20%) at the start as a validation set and not use it in training at all until final evaluation. Then you can freely tune your model on the remaining 80% (using CV or not), and use the 20% to measure performance. However, this 20% is then effectively acting as your test set. In many Kaggle competitions, for example, people use a validation set for model selection because the actual test labels are unknown. In academic practice, having a distinct test set is considered good for final reporting.
- Tuning considerations: Some models are more sensitive to tuning than others. For example, random forests often work reasonably well with the default `mtry` (which is \(\sqrt{p}\)) and a large number of trees; slight changes in `mtry` might not drastically change performance. In contrast, SVMs and neural nets can be very sensitive to their parameter choices, so you should allocate your tuning effort accordingly. Also, sometimes you have to tune multiple parameters jointly (like depth and learning rate in boosting, or multiple regularization terms in a neural net).
- Overfitting in tuning: It is possible to overfit the hyperparameter tuning process, especially if you do a very exhaustive search or if the validation set is small. That’s why nested CV or an independent test set is important to verify that your chosen model generalizes. For example, if you try 100 combinations of hyperparameters, the one that gives the highest CV score might be benefiting from random fluctuations. Using more folds can mitigate this by reducing variance in the performance estimate, but the risk is still there.
- Parallel processing: Training many models in CV and grid search can be time-consuming. The caret package can utilize parallel backends (via the `doParallel` package) to train folds in parallel; for instance, a 5-fold CV could run 5 models at once if you have 5 cores. Similarly, grid search can be parallelized across different hyperparameter settings.
- Tidymodels approach: The tidymodels ecosystem (with the `tune` package) provides a grammar for tuning as well, and it integrates nicely with dplyr and related tools. It supports advanced techniques like racing (which stops evaluating poor hyperparameter combinations early).
- Reporting results: When you’ve tuned a model via CV, you would typically retrain the final model on the full training dataset with the chosen parameters (since you want to use all the data for the final model). This is what caret’s `train()` does by default after selecting the best parameters. Then you can test that final model on the test set or use it for prediction. It’s good practice to report the cross-validation performance (e.g., “5-fold CV accuracy was X% ± Y”) as well as test set performance if available.
In our examples, we consistently used cross-validation within the training set to tune parameters. This not only helps in model selection but also gives an estimate of performance that is more realistic than the training error. The caret outputs we saw (with “Accuracy was used to select the optimal model”) illustrate how the model was chosen.
For example, with SVM we might have seen one combination give the highest accuracy. It’s worth noting that sometimes the metric you use to choose the model might not be the metric of interest. You could instruct caret to use, say, Kappa or F1 for selection if that’s more relevant (via `summaryFunction` in `trainControl` to use a custom metric). For imbalanced data, one might choose a model that maximizes the F1-score of the minority class instead of overall accuracy.
To conclude, cross-validation and careful tuning are essential for getting the most out of your classification models. They ensure you’re not just fitting noise and give you confidence in the model’s ability to generalize. R’s ecosystem provides strong support for these tasks, whether through caret, tidymodels, or manual coding. It may add computational overhead, but it is usually worth it for the performance gain and reliability of the model.
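As a concrete illustration of selecting on a metric other than accuracy, caret can be told to compute ROC-based summaries and pick the model with the best AUC (a sketch):

```r
# twoClassSummary needs class probabilities; train() then selects on metric = "ROC"
ctrl_roc <- trainControl(method = "cv", number = 5,
                         classProbs = TRUE,
                         summaryFunction = twoClassSummary)
set.seed(123)
rf_roc <- train(Class ~ ., data = trainData, method = "rf",
                metric = "ROC", trControl = ctrl_roc)
```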
Performance Metrics for Classification
After training a classification model, we need to assess its performance. There are a variety of metrics, and the choice depends on the problem context (e.g., whether classes are balanced, whether false positives or false negatives carry different costs, etc.). Here are common performance metrics and what they mean:
Confusion Matrix: This is a fundamental tool. For a binary classifier, the confusion matrix has four entries:
- TP (True Positives): Cases where the model predicted Positive and the actual class was Positive.
- TN (True Negatives): Cases where the model predicted Negative and the actual class was Negative.
- FP (False Positives): Cases where the model predicted Positive but the actual class was Negative (also called Type I error or false alarm).
- FN (False Negatives): Cases where the model predicted Negative but the actual class was Positive (Type II error or missed detection).
From these, several metrics are defined:
- Accuracy: \(\frac{TP + TN}{TP + TN + FP + FN}\). This is the proportion of all instances that were correctly classified. Accuracy is simple but can be misleading in imbalanced data (predicting all negatives in the earlier example yields 95% accuracy but zero TPs).
- Precision (Positive Predictive Value): \(\frac{TP}{TP + FP}\). This answers: of all instances the model labeled as positive, what fraction are truly positive? It is a measure of exactness – a low precision means the model has many false alarms. For example, precision = 0.5 means only half of the predicted positives are actual positives. High precision is important in situations where false positives are costly (e.g., an innocent person being flagged as fraudulent).
- Recall (Sensitivity or True Positive Rate): \(\frac{TP}{TP + FN}\). This is the proportion of actual positive instances the model correctly identified. It measures completeness – a low recall means many positives were missed. High recall is crucial when false negatives are very costly (e.g., failing to detect a disease). Note that recall is the same as sensitivity, and for the negative class, recall would be specificity (see below).
- Specificity (True Negative Rate): \(\frac{TN}{TN + FP}\). This is the proportion of actual negatives that were correctly identified as negative. It’s basically the recall of the negative class. Specificity is important in contexts like medical tests (specificity = 1 means no healthy person is wrongly diagnosed as sick).
- F1-Score: The harmonic mean of precision and recall: \(F1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}\). The F1-score is a single metric that balances precision and recall. It’s useful when you want a balance and when classes are imbalanced. The harmonic mean punishes extreme values; so if either precision or recall is very low, F1 will be low.
- Kappa (Cohen’s Kappa): This measures agreement between the model’s predictions and the true labels, adjusted for the agreement that could happen by chance. It’s especially useful in multi-class classification or imbalanced binary classification. Kappa values above 0.8 are often considered excellent agreement. We saw caret reporting Kappa alongside accuracy.
- ROC Curve (Receiver Operating Characteristic): This is a plot of the True Positive Rate (Recall) against the False Positive Rate (\(1 - \text{Specificity}\)) as the classifier's decision threshold is varied. It shows the trade-off between sensitivity and specificity. A classifier that guesses randomly traces the diagonal (AUC of 0.5); a perfect classifier passes through the point (0, 1) (TPR of 1 at FPR of 0). The AUC (Area Under the ROC Curve) is a threshold-independent performance measure that can be interpreted as the probability that the classifier ranks a randomly chosen positive instance higher than a randomly chosen negative instance. AUC is useful for comparing models when you care about performance across all threshold settings or when different threshold choices might be used.
- Precision-Recall Curve: Particularly useful for imbalanced datasets, it’s a plot of Precision (y-axis) vs Recall (x-axis) for different thresholds. When the positive class is rare, PR curves can be more informative than ROC curves, because ROC can sometimes give an overly optimistic picture (due to many TNs dominating the FPR). The area under the PR curve is another metric (though not as commonly quoted as ROC AUC).
- Log Loss (Cross-Entropy Loss): This metric takes into account the uncertainty of your prediction by looking at the probability output. If the model is very confident about a wrong prediction, log loss will be high. A good classifier will have a low log loss. It’s useful when you need well-calibrated probabilities, not just correct class predictions.
- Balanced Accuracy: We mentioned this above; it’s \(\frac{\text{Sensitivity} + \text{Specificity}}{2}\), effectively the average recall per class. In multi-class problems, a similar idea is to compute recall for each class and average them (this is one way to get a “macro-average” performance).
- Matthews Correlation Coefficient (MCC): A more obscure metric that is essentially a correlation coefficient between the observed and predicted binary classifications (it takes into account TP, TN, FP, FN in one formula). It’s a more informative single metric than accuracy in the case of imbalance, as it doesn’t get inflated by TNs when the negative class dominates.
Which metric to focus on depends on the context. For example:
- In medical screening, Recall (Sensitivity) is often key (we want to catch as many sick people as possible, even if we have some false positives) but we also watch Precision or Specificity because too many false alarms can cause other issues.
- In spam detection, one might look at Precision (we don’t want to classify a legitimate email as spam – precision for “spam” should be high) while also maintaining a decent recall.
- In an imbalanced legal case classification task, the F1-score might be used to balance precision and recall.
- In general, if you have to pick one threshold for a classifier, you might use metrics like F1 or Youden’s J (sensitivity + specificity - 1) to find an optimal threshold on a validation set.
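As a small, hedged illustration of that threshold-search idea, the sketch below scans a grid of thresholds and picks the one maximizing Youden's J. The vectors probs and truth are simulated stand-ins for a real model's validation-set output.

```r
# Hedged sketch: choosing a threshold by Youden's J on a validation set.
# `truth` and `probs` are simulated stand-ins for real validation-set outputs.
set.seed(123)
truth <- factor(sample(c("No", "Yes"), 500, replace = TRUE, prob = c(0.8, 0.2)),
                levels = c("No", "Yes"))
probs <- ifelse(truth == "Yes", rbeta(500, 4, 2), rbeta(500, 2, 4))

thresholds <- seq(0.05, 0.95, by = 0.01)
youden <- sapply(thresholds, function(t) {
  pred <- ifelse(probs >= t, "Yes", "No")
  sens <- mean(pred[truth == "Yes"] == "Yes")   # recall on the positive class
  spec <- mean(pred[truth == "No"]  == "No")    # recall on the negative class
  sens + spec - 1                               # Youden's J
})
best_threshold <- thresholds[which.max(youden)]
best_threshold
```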
It’s good practice to plot things: ROC curves can show if one model dominates another or if they cross (meaning one is better at low false positive rates, another at higher false positive rates). Precision-Recall curves can show how precision drops off as you try to increase recall.
In R, the caret package's confusionMatrix() function can compute many of these metrics (given a table of predictions vs. actuals). The pROC package can compute AUC and plot ROC curves, and PRROC or precrec can produce PR curves.
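As a minimal sketch of how these pieces fit together, the code below computes a confusion matrix (with precision, recall, F1, and Kappa) and an ROC curve with its AUC; probs and truth are simulated here, as in the earlier sketch, and would normally come from a fitted model's hold-out predictions.

```r
# Hedged sketch: confusion-matrix metrics and ROC/AUC in R.
# `probs` and `truth` are simulated stand-ins for hold-out predicted
# probabilities and true labels from a fitted classifier.
library(caret)
library(pROC)

set.seed(123)
truth <- factor(sample(c("No", "Yes"), 500, replace = TRUE, prob = c(0.8, 0.2)),
                levels = c("No", "Yes"))
probs <- ifelse(truth == "Yes", rbeta(500, 4, 2), rbeta(500, 2, 4))
pred  <- factor(ifelse(probs >= 0.5, "Yes", "No"), levels = c("No", "Yes"))

# Accuracy, Kappa, sensitivity, specificity, precision, recall, F1, ...
confusionMatrix(data = pred, reference = truth, positive = "Yes",
                mode = "everything")

# ROC curve and threshold-independent AUC
roc_obj <- roc(response = truth, predictor = probs, levels = c("No", "Yes"))
auc(roc_obj)
plot(roc_obj)
```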
In our earlier outputs, we saw accuracy and Kappa from caret CV results. If we were dealing with an imbalanced case, we might have used summaryFunction = twoClassSummary
in trainControl to have caret output ROC and maybe choose the best model by ROC AUC. For example:
<- trainControl(method="cv", number=5,
ctrl classProbs=TRUE,
summaryFunction=twoClassSummary)
<- train(Class ~ ., data=trainData, method="svmRadial",
svm_fit metric="ROC", trControl=ctrl, tuneLength=5)
This would select the SVM model with the best cross-validated AUC. We would then also check its sensitivity and specificity at a chosen threshold.
Ultimately, when presenting results, it’s often helpful to present a confusion matrix (for a chosen threshold) along with precision, recall, F1, etc., rather than just accuracy. This gives a fuller picture. If using the model in a decision-making process, you’d also consider the actual costs of false positives vs false negatives and maybe integrate that (there is a concept of expected cost or using a custom metric that weights FP and FN differently).
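If those costs are known (or can be assumed), the comparison can be made explicit. The sketch below reuses the simulated probs and truth from the earlier sketches and assumes a purely hypothetical 5:1 cost ratio of false negatives to false positives.

```r
# Hedged sketch: comparing thresholds by expected misclassification cost.
# The 5:1 cost ratio is a hypothetical assumption; `probs` and `truth`
# are the simulated vectors from the earlier sketches.
cost_fp <- 1
cost_fn <- 5

thresholds <- seq(0.05, 0.95, by = 0.01)
expected_cost <- sapply(thresholds, function(t) {
  pred <- ifelse(probs >= t, "Yes", "No")
  fp <- sum(pred == "Yes" & truth == "No")
  fn <- sum(pred == "No"  & truth == "Yes")
  cost_fp * fp + cost_fn * fn
})
# Threshold minimizing total cost on this validation sample
thresholds[which.min(expected_cost)]
```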
Interpretability and Explainability
Interpretability has come up as a key consideration, especially in the social sciences and in causally oriented research. Different models offer different levels of interpretability:
- Logistic Regression: This is highly interpretable. Each coefficient can be exponentiated to give an odds ratio, which is relatively straightforward to explain (e.g., “Holding other factors constant, this coefficient of 0.5 means the odds of the outcome increase by a factor of 1.65 for a one-unit increase in that predictor”). One has to be careful about interpreting them causally (which requires assumptions), but at least they give a direction and magnitude of association. You can also compute marginal effects or predicted probabilities easily from a logistic model to communicate effects.
- Decision Trees: Small trees are very interpretable. You can literally draw a flowchart and show it to decision-makers or include it in a report. It’s easy to explain “If condition X is true and then condition Y is true, then the model predicts Z.” Each path through the tree can be considered an if-then rule. Trees also implicitly perform feature selection by splitting on some features and not others, which gives a sense of which features are important (though in a single tree this can be misleading if some features are almost as good as the chosen split). One can also measure variable importance in trees by how much each feature reduced impurity.
- Random Forests: The individual trees in a forest are not interpretable, and there could be hundreds of them. However, random forests offer variable importance measures. One common measure is the mean decrease in Gini impurity (or decrease in entropy) attributed to splits by a variable, averaged over the forest. Another (more robust) approach is permutation importance: randomly permute the values of one feature in the OOB data and see how much accuracy drops – if a feature is important, shuffling it will degrade performance a lot. These give an overall ranking of which features the model considered most predictive. Random forests can also give an idea of interactions by observing combined splits, but to interpret interactions explicitly one might use partial dependence plots: these show the marginal effect of one or two features on the predicted outcome, averaging out the others. Partial dependence plots can reveal whether the relationship is monotonic, has thresholds, etc. They are a form of model interpretation (e.g., “On average, as education increases, the probability of voting increases until plateauing at X level”). A code sketch after this list illustrates these importance measures and a partial dependence plot.
- Gradient Boosted Trees: Similar to random forests, they can provide importance measures (e.g., XGBoost's xgb.importance() ranks features by how often they're used and with what gain in accuracy). They also can be interpreted via partial dependence or more sophisticated methods like SHAP values. SHAP (Shapley Additive Explanations) assigns each feature a contribution to the difference between the prediction and the dataset's baseline prediction, for each individual instance. SHAP values have a solid theoretical foundation and many tools exist to compute them for tree models. They can tell, for a specific prediction, how each feature pushed the model output higher or lower. This is very useful in explaining to an end-user why, say, a certain person was classified as high risk: e.g., “Income and age contributed the most to this decision, raising the predicted risk, while having a stable job contributed to lowering it.”
- SVM: A linear SVM is interpretable similarly to logistic regression (weights on features indicate importance, though scale matters). But an SVM with an RBF kernel is not very interpretable globally. You could look at the support vectors to see which training instances are key, but that doesn't tell a clear story about feature effects. There are model-agnostic methods one could use: for example, LIME (Local Interpretable Model-agnostic Explanations), which fits a small interpretable model in the neighborhood of a prediction. LIME could approximate the SVM decision boundary locally by a linear model and say “for this prediction, these features were most influential.” SHAP can also be applied to any model, including SVM, by sampling.
- kNN: This is interpretable in a case-based reasoning way: to explain a prediction, you can say “The model looked at the 7 nearest neighbors of this instance in the training data. Those neighbors had classes A, A, B, A, A, A, A (for example), so majority is A. Specifically, the nearest neighbor was a data point that had features … and class A.” So you can present actual similar instances as the explanation. This is sometimes quite intuitive (like, “You are being compared to this past case which was similar to you”). But there isn’t a succinct description like a coefficient or rule – it’s example-based interpretability.
- Naïve Bayes: You can interpret the model by looking at the probabilities \(P(X_j = x \mid Y=y)\). For example, in a Naive Bayes spam filter, you might find \(P(\text{“money”} \mid \text{spam}) = 0.20\) versus \(P(\text{“money”} \mid \text{not spam}) = 0.005\). This indicates that the word “money” strongly suggests spam. Essentially, each feature contributes a likelihood ratio \(P(X_j \mid Y=1) / P(X_j \mid Y=0)\). If that ratio is >1, it votes for class 1, if <1 votes for class 0, and the magnitude indicates strength. So one could present the top features that favor each class by looking at those likelihood ratios. This is how some email spam filters present explanations (“This email was marked as spam because it contained the words ‘money’ and ‘Nigerian prince’, which are common in spam emails.”). Naive Bayes is thus fairly interpretable in terms of feature influence, albeit ignoring correlations.
- Neural Networks: These are considered black boxes, especially if you have many hidden units. However, for a small network, one could attempt to interpret weights – for instance, each hidden neuron is a combination of inputs, and one might try to interpret what each hidden neuron represents. In practice, though, it's hard to parse those weights directly. Instead, one uses post-hoc interpretation methods. For example, you can compute variable importance for a neural network by perturbation: increase or decrease an input feature and see how much the output changes (this can give a sense of which inputs the network is sensitive to). There's also a technique for neural nets called Garson's algorithm that distributes output variance back to inputs (available in NeuralNetTools in R). For deep neural nets (like image classifiers), there are specialized interpretation techniques such as saliency maps, but for a basic MLP these aren't as developed in R outside of using general methods like LIME or SHAP. With packages like DALEX or iml, one can treat any model as a black box (including an nnet model) and compute feature importance, partial dependence, or individual conditional expectation curves. Those can tell you things like “in the model, as X increases, on average the prediction increases then levels off,” similar to other models.
In many real-world projects, a combination of approaches is used to explain models:
- Global interpretation: Looking at variable importances, global partial dependence plots, or the structure of an interpretable model (like a tree or linear model).
- Local interpretation: For a specific prediction, using example-based explanation (like nearest neighbors or prototypes), or feature attribution methods (SHAP/LIME) to explain why the model made that prediction.
- Model simplification: Sometimes you might distill a complex model into a simpler one (e.g., rule extraction from a neural network or training a shallow decision tree on the predictions of the random forest) to have a rough interpretable proxy.
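As a hedged sketch of the global and local approaches just listed, the code below wraps the random forest from the previous sketch in a model-agnostic DALEX explainer; rf_fit, trainData, and Class carry over from that sketch and remain hypothetical placeholders.

```r
# Hedged sketch: model-agnostic global and local explanations with DALEX.
# `rf_fit` and `trainData` carry over from the random forest sketch above;
# the outcome `Class` is assumed to be a two-level factor.
library(DALEX)

explainer <- explain(rf_fit,
                     data  = trainData[, setdiff(names(trainData), "Class")],
                     y     = as.numeric(trainData$Class) - 1,  # 0/1 target
                     label = "random forest")

# Global: permutation-based variable importance
vi <- model_parts(explainer)
plot(vi)

# Local: feature contributions for a single observation (break-down style)
one_case <- trainData[1, setdiff(names(trainData), "Class")]
bd <- predict_parts(explainer, new_observation = one_case)
plot(bd)
```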
For causal inference, interpretability is even more crucial because we want to understand the effect of predictors. Often, if the goal is causal insight, one might shy away from black-box models. Or, one might use a black-box model to discover patterns and then test those patterns with a simpler interpretable model or domain knowledge.
There’s a relevant concept: to explain or to predict (Shmueli, 2010) – if your goal is explanation (causal or theoretical understanding), you should use the simplest effective model and focus on interpretability (often sacrificing some predictive accuracy). If your goal is prediction, you might use a complex model and then use tools to try to explain it in a limited way. Both approaches can complement each other; sometimes one builds a predictive model to identify which variables or interactions are important, and then uses that information in a traditional regression model to test hypotheses.
In software, packages such as shapley or shapper (for SHAP in R) and lime (for local interpretable explanations) are available. DALEX and iBreakDown provide convenient interfaces for tree-based and other models.
As an example of interpretability in action, imagine a random forest model for predicting loan default. The variable importance might show that “Income” and “Credit Score” are top predictors. Partial dependence might show that default probability drops significantly once income is above a certain threshold or credit score is above, say, 700. Locally, for a specific person, a SHAP plot might show that “High existing debt” and “Young age” contributed positively to default risk, while “High income” contributed negatively, and the net effect led to a certain predicted risk. One could relay that to a loan officer in human terms.
Summary: Interpretability is about making the model’s decisions understandable to humans. Depending on the model and audience, this could range from directly reading off coefficients and rules to using auxiliary methods to translate the model’s behavior. In many regulated industries (finance, healthcare), there’s increasing demand for explainable models, and fortunately, even for complex models, there are ways to extract useful explanations. However, no post-hoc method can fully replace the transparency of a simple model, so there’s always a trade-off to consider between maximizing accuracy and maintaining interpretability.
Using R Packages and Workflows
Throughout this chapter, we have used various R packages and functions to implement classification models. Here we summarize some of the tools and how they fit into a typical workflow:
Data Preparation: Before modeling, you often need to clean and preprocess data. This might involve handling missing values, encoding categorical variables (though some models can handle factors directly), scaling or transforming features, creating interaction or polynomial terms if needed, etc. In our examples, we used preProcess = c("center", "scale") in caret for kNN to standardize features. For a more complex workflow, the recipes package from tidymodels provides a way to chain preprocessing steps (imputation, transformations, one-hot encoding, etc.) in a reusable object.

caret (Classification And Regression Training): We used caret::train to fit models through a uniform interface. The advantage of caret is that it handles splitting data for cross-validation, tuning hyperparameters, and even basic preprocessing in one go. It supports a wide array of models through its method argument (which under the hood calls various package functions). Caret is gradually being superseded by the tidymodels ecosystem in terms of active development, but it is still very widely used and stable. A limitation is that it doesn't easily handle more modern techniques without custom model definitions, and its syntax can be verbose.

tidymodels (parsnip, workflows, tune, etc.): This is a collection of packages that work together. For example, parsnip lets you specify a model in a unified way (like rand_forest(mtry = 5, trees = 100) %>% set_engine("randomForest")), workflows lets you combine a model and a preprocessing recipe, and tune lets you tune hyperparameters with various strategies. Tidymodels is designed to be tidy (using tibble outputs, etc.) and flexible. For instance, one can do:

```r
library(tidymodels)

# Model specification with one hyperparameter flagged for tuning
model <- rand_forest(mtry = tune(), trees = 200) %>%
  set_engine("ranger") %>%
  set_mode("classification")

# Preprocessing recipe: center and scale numeric predictors
recipe <- recipe(Class ~ ., data = trainData) %>%
  step_center(all_numeric()) %>%
  step_scale(all_numeric())

# Bundle the model and recipe into a workflow
workflow <- workflow() %>%
  add_model(model) %>%
  add_recipe(recipe)

# Tune mtry with 5-fold cross-validation
tuned <- tune_grid(workflow,
                   resamples = vfold_cv(trainData, v = 5),
                   grid = expand.grid(mtry = c(2, 5, 8)))

best <- select_best(tuned, metric = "accuracy")
final_rf <- finalize_workflow(workflow, best) %>% fit(data = trainData)
```
This is more verbose in some ways, but very organized. Tidymodels also integrates well with dplyr and ggplot for analyzing results. It’s a modern approach if you plan to do many modeling tasks in a project.
Model-specific packages: We used randomForest, xgboost, e1071, nnet, etc., through caret's unified interface, but one can also call them directly:
- randomForest(): quickly trains a random forest (500 trees by default) and returns an object with OOB error, importance measures, etc.
- xgboost: requires matrix data and careful parameter setup, but is extremely fast once set up.
- e1071::svm(): trains an SVM, with parameters for different kernels, etc. (kernlab::ksvm is an alternative).
- nnet::nnet(): fits a single-hidden-layer neural net.
- naivebayes::naive_bayes(): fits a Naive Bayes classifier.
- class::knn(): produces kNN predictions (though typically you'd use caret or another wrapper to choose k).

Each of these has its own syntax. The advantage of calling them directly is often more control and access to model-specific features (like partial plots in randomForest or xgboost's internal evaluation metrics). The disadvantage is having to handle things like cross-validation or standardization manually, which the wrappers do for you.
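A hedged sketch of a few of these direct calls follows; trainData and testData (with a factor outcome Class) are hypothetical placeholders consistent with the earlier examples.

```r
# Hedged sketch: fitting a few classifiers directly, outside of caret.
# `trainData`/`testData` with a factor outcome `Class` are hypothetical
# placeholders consistent with earlier examples.
library(randomForest)
library(e1071)
library(class)

set.seed(123)

# Random forest with its native interface (500 trees by default)
rf_fit <- randomForest(Class ~ ., data = trainData)
rf_fit$confusion   # OOB confusion matrix

# SVM with an RBF kernel; probability = TRUE enables probability estimates
svm_fit <- svm(Class ~ ., data = trainData, kernel = "radial", probability = TRUE)
svm_pred <- predict(svm_fit, newdata = testData)

# kNN: class::knn() needs numeric feature matrices plus the training labels;
# only numeric columns are used here, standardized with training-set statistics
x_train <- scale(trainData[, sapply(trainData, is.numeric)])
x_test  <- scale(testData[,  sapply(testData,  is.numeric)],
                 center = attr(x_train, "scaled:center"),
                 scale  = attr(x_train, "scaled:scale"))
knn_pred <- knn(train = x_train, test = x_test, cl = trainData$Class, k = 7)
```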
Evaluation and Visualization: We mentioned caret::confusionMatrix, pROC::roc, and precrec for PR curves. ggplot2 is also invaluable for plotting model diagnostics, such as:
- Plotting the ROC curve.
- Plotting a histogram of predicted probabilities by true class (to see separation).
- Partial dependence plots, which can be made with the pdp package (it works with randomForest, xgboost, etc.).
- Variable importance plots: the vip (Variable Importance Plots) package can create ggplot2-based importance plots for many model types.
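For example, here is a hedged sketch of two of these diagnostics, reusing the hypothetical rf_fit and testData objects from the earlier sketches.

```r
# Hedged sketch: visual diagnostics with ggplot2 and vip.
# `rf_fit`, `testData`, and `Class` are hypothetical placeholders consistent
# with earlier sketches; the positive class is assumed to be the second level.
library(ggplot2)
library(vip)

# Predicted probability of the positive class on the hold-out set
test_probs <- predict(rf_fit, newdata = testData, type = "prob")[, 2]
plot_df <- data.frame(prob = test_probs, truth = testData$Class)

# Histogram of predicted probabilities by true class: a well-separated model
# shows the two distributions concentrating at opposite ends.
ggplot(plot_df, aes(x = prob, fill = truth)) +
  geom_histogram(position = "identity", alpha = 0.5, bins = 30) +
  labs(x = "Predicted probability of positive class", fill = "True class")

# ggplot2-based variable importance plot
vip(rf_fit)
```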
Saving and Deploying Models: In R, you can save a model object with saveRDS(model, "model.rds") and load it back with readRDS(). For deployment, you might sometimes want a lighter-weight approach (like writing out just the needed parameters), but often saving the fitted model object is fine for R-based deployment. If integrating with other systems, the pmml package can export certain models to PMML (Predictive Model Markup Language), but not all models are supported or easily translated.

Reproducibility: Setting seeds (we used set.seed(123)) is important for reproducible results, especially for models with randomness (like random forests or the random partitioning in cross-validation). Also record the package versions used, as results can differ slightly across versions (especially if you rely on OOB estimates from randomForest, which can change if the package's random number usage changes).

Scaling to big data: If your dataset is very large (say, millions of rows), not all of these methods will be feasible on a single machine in pure R. In those cases, one might use big-data tools (Spark via sparklyr, for example, offers MLlib algorithms, and H2O's R interface provides distributed random forests, XGBoost, etc.). For moderate data (tens of thousands of rows, dozens of features), the methods we used are usually fine on a modern computer.
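A minimal sketch of saving, reloading, and documenting a fitted model (the file name is arbitrary, and rf_fit is the hypothetical model from the earlier sketches):

```r
# Hedged sketch: persisting a fitted model and recording the session.
saveRDS(rf_fit, "rf_model.rds")        # write the fitted model object to disk
rf_reloaded <- readRDS("rf_model.rds") # restore it later (or on another machine)

# Record R and package versions alongside the model for reproducibility
sessionInfo()
```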
Workflow summary: A typical classification modeling workflow in R might look like:
- Understand the problem and gather data.
- Exploratory data analysis: look at distributions, relationships, maybe do some feature engineering.
- Preprocess data: handle missing values, encode categorical variables (if needed; though many models in R handle factors natively), split into train/test (or use CV).
- Choose a set of candidate models to try (maybe start with a simple logistic regression and a tree, then move to more complex ones if needed).
- Use cross-validation to tune hyperparameters for each model type. This can be done with caret, tidymodels, or manual loops.
- Compare models using a validation set or CV metrics. Perhaps use AUC or F1 or whatever suits the goal.
- Select the best model (or perhaps decide to ensemble them or stack them).
- Retrain the chosen model on all training data with optimal hyperparameters.
- Evaluate on hold-out test data to get an unbiased performance estimate.
- Interpret the model: use variable importance, partial dependence, etc., to draw insights. If for a report, extract key findings (e.g., “Education and income were the strongest predictors of voting.”).
- If the model is for deployment (predicting future cases), save it and integrate into whatever system will use it (shiny app, plumber API, etc.).
- Monitor the model’s performance over time if applicable (data drift and related issues, though those are more advanced topics).
We should also note that automated machine learning (AutoML) tools (e.g., H2O AutoML, or tidymodels-adjacent packages such as finetune) can automate trying many models and tuning them. These can be useful for quickly establishing a benchmark, but it is still important to understand how to do the modeling manually, as we have discussed, especially for academic or nuanced problems where human guidance is needed.
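As a hedged sketch of what that can look like with H2O AutoML (trainData and Class remain hypothetical placeholders, and the number of candidate models is deliberately kept small):

```r
# Hedged sketch: automated model search and tuning with H2O AutoML.
# `trainData` with a factor outcome `Class` is a hypothetical placeholder.
library(h2o)
h2o.init()

train_h2o <- as.h2o(trainData)
aml <- h2o.automl(y = "Class",
                  training_frame = train_h2o,
                  max_models = 10,
                  seed = 123)

aml@leaderboard            # candidate models ranked by the default metric (e.g., AUC)
best_model <- aml@leader   # best-performing model, usable with h2o.predict()
```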
Finally, regarding the interplay with causal inference: if using these tools in a causal workflow, one must be careful about how the model is used. For example, using ML to predict propensity scores is fine, but one should use techniques like cross-fitting (as in doubly robust estimation) to avoid overfitting biases. There is a whole area of machine learning for causal inference where these models serve as components of algorithms for estimating treatment effects (causal forests, etc.), which requires careful sample-splitting.
6.11 Causal Inference Perspective
While this chapter is primarily about predictive modeling, we often have to consider causal inference in social science problems. It’s worth discussing how classification models can intersect with causal analysis:
- Propensity Score Modeling: In observational studies where you want to estimate the effect of a treatment (or some exposure) on an outcome, one common approach is to use propensity scores – the probability of receiving the treatment given covariates. This is essentially a binary classification problem: treated vs untreated as the outcome, covariates as features. Logistic regression is traditionally used for this (because of interpretability and well-behaved output). However, one could use any classifier to estimate propensity scores (random forests, boosting, etc.). The goal here is not to classify per se, but to get a well-calibrated estimate of \(P(\text{Treatment}=1 \mid X)\). If you use a complex model, you must be cautious: a highly overfit propensity model could create propensity scores that are 0 or 1 for many units, which is problematic (you lose overlap between groups). So usually some regularization or even sticking to logistic with main effects is recommended. Recently, researchers have tried using machine learning (with cross-validation) to estimate propensity scores and have shown it can sometimes improve covariate balance compared to a simple logit. The trade-off is that you sacrifice some interpretability (though propensity scores are often just a stepping stone, not of intrinsic interest themselves).
- Causal Trees and Forests: These are adaptations of decision trees and random forests for estimating heterogeneous treatment effects. For example, Athey and Imbens (2016) proposed a method to build a tree where the splitting criterion is based on treatment effect heterogeneity (the difference in outcome between treated and control within a node) rather than outcome prediction. The leaves of the tree then give subgroups with different estimated treatment effects – this is very interpretable (“the treatment works best for young, low-income individuals, with an estimated effect of …”). Causal forests (Wager & Athey, 2018) extend this idea by using an ensemble of trees to get more stable estimates of conditional average treatment effects (CATEs). These techniques are implemented in the grf package in R (Generalized Random Forests). They represent a blending of prediction and causal inference: the forest is used to predict treatment effects as a function of covariates, with a splitting objective tailored to causal effect estimation (e.g., a criterion based on differences in treated vs. control means) and sample-splitting to avoid bias.
- Using ML for outcome modeling in causal inference: In methods like outcome regression or augmented inverse propensity weighted estimators, one might use a flexible model to predict the outcome as a function of covariates separately for treated and control units (especially in high-dimensional cases). For example, use random forests to predict \(E[Y \mid X, \text{treated}]\) and \(E[Y \mid X, \text{control}]\). If these models are accurate, you can plug them into a formula to estimate treatment effects. There is a framework called targeted learning (van der Laan) that advocates using machine learning in causal inference with careful adjustments.
- Instrumental variables (IV) with classification: Sometimes the first stage of an IV is a classification (e.g., predicting whether someone takes up a treatment based on an instrument). One could use a classifier for that first stage. However, standard theory relies on linear first stages; using non-linear machine learning in IV is an area of current research (there are methods like deep IV, etc., but not mainstream yet).
- Selecting confounders: Some researchers use machine learning to sift through potential confounders. For example, if you have a wealth of covariates but not all are confounders, you might use algorithms to select a subset that makes treated and control balanced. One approach is propensity score pruning or using regularization in a logistic regression to pick important covariates for propensity score.
- Intersection with explainability: If a black-box model suggests a certain variable or interaction is very predictive of the outcome, a researcher might form a hypothesis that this variable is also an important causal factor or moderator. They could then include it in a more interpretable model or design a study around it.
However, one should avoid using the same data to both model and estimate causal effects without caution. Overfitting in predictive sense can translate to bias in causal estimates. For instance, if you use a flexible model to estimate propensity scores and directly weight by those to estimate treatment effect, you might inadvertently overfit and get a biased estimate. One solution is to use sample splitting: use half the data to learn the model (e.g., train a random forest for propensity or outcome), and the other half to compute the treatment effect using those learned functions (this avoids using the same data to both fit and assess effect, reducing bias). This idea is used in double machine learning (Chernozhukov et al.) to get unbiased estimates with ML components.
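To make that sample-splitting idea concrete, here is a hedged, purely illustrative sketch of a two-fold cross-fitted inverse-propensity-weighted (IPW) estimate on simulated data; a plain logistic regression stands in for whatever flexible learner you might actually use.

```r
# Hedged sketch: two-fold cross-fitting for an IPW estimate of an average
# treatment effect. All data are simulated, and a logistic regression is a
# stand-in for any ML propensity model.
set.seed(123)
n  <- 2000
x1 <- rnorm(n); x2 <- rnorm(n)
treat <- rbinom(n, 1, plogis(0.5 * x1 - 0.5 * x2))
y     <- rbinom(n, 1, plogis(0.3 * treat + 0.4 * x1))
df    <- data.frame(y, treat, x1, x2)

fold <- sample(rep(1:2, length.out = n))   # random split into two folds
ipw_estimates <- sapply(1:2, function(k) {
  fit_idx  <- fold != k    # fit the propensity model on the other fold
  eval_idx <- fold == k    # estimate the effect on this fold
  ps_model <- glm(treat ~ x1 + x2, family = binomial, data = df[fit_idx, ])
  ps <- predict(ps_model, newdata = df[eval_idx, ], type = "response")
  d  <- df[eval_idx, ]
  # IPW difference in weighted means (treated vs. control)
  mean(d$treat * d$y / ps) - mean((1 - d$treat) * d$y / (1 - ps))
})
mean(ipw_estimates)   # cross-fitted ATE estimate on the risk-difference scale
```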
In sum, machine learning classifiers can be very helpful in the A (adjustment) part of causal inference – adjusting for confounders via propensity or outcome modeling – and in discovering effect heterogeneity. But the use of these tools should respect the principles of causal inference: issues like confounding, selection bias, and post-treatment variables can’t be fixed by a better classifier. The model can only account for what is observed and included.
It’s also worth noting that even in prediction contexts, understanding causality can be important. For example, if you use a predictive model to make policy decisions (who to give a job training program to), you should consider whether the model’s variables are causal or just predictive correlates. If it’s using something like ZIP code to predict crime risk, that could be predictive but not something you want to base decisions on causally (it could raise fairness concerns, etc.). So sometimes interpretability also ties into understanding whether a model might be using spurious correlations.
The fields of fairness, accountability, and transparency in ML also overlap with what features a model uses and whether those are causally linked or just correlates that could lead to biased decisions. Tools like decision trees or rule lists are sometimes favored in high-stakes domains because they allow an auditor to see what factors are being used.
Conclusion of causal perspective: Use classification models as tools in causal analysis carefully. Simpler models like logistic regression remain popular in social science partly because they’re easier to interpret in a causal way – you can discuss how an X relates to Y controlling for Z’s. But if the data is complex, machine learning can assist in getting better estimates or finding patterns, as long as one uses techniques to avoid overfitting and maintains a clear distinction between association and causation. For example, a random forest might tell you “age is a very important predictor of voting turnout” – that doesn’t mean age causes turnout (there could be cohort effects or other factors), but it directs your attention to age, and a researcher might then investigate that further.
With all the above considerations, we have traversed both the practical and theoretical aspects of classification models in R. By now, it should be evident that choosing and using a classification model involves not just running the code to fit it, but also understanding its assumptions, validating its performance properly, and interpreting its results in context. In the next section, we’ll wrap up and provide references for further reading.
6.12 Conclusion
In this chapter, we covered the landscape of classification models using R, from the simple and interpretable (logistic regression, decision trees, Naive Bayes) to the more complex and powerful (random forests, gradient boosting, SVMs, neural networks). We discussed the theoretical underpinnings of each method, their strengths and limitations, and demonstrated practical implementation and tuning with R code. We also emphasized important practical topics such as handling class imbalance, avoiding overfitting via cross-validation and regularization, evaluating models with appropriate metrics, and interpreting model outputs. Throughout, examples were drawn from social science contexts (like the Titanic survival example or discussion of socio-economic predictors) to illustrate how these models might be used in practice for both predictive and explanatory goals.
A key theme was the distinction between prediction and causation. Predictive models aim to maximize accuracy and often leverage complex patterns in the data, whereas causal inference prioritizes interpretability and unbiased estimation of relationships. We highlighted how in practice one might use simpler models for causal interpretation, or use machine learning models as part of a causal analysis carefully (for instance, in propensity score estimation or exploring heterogeneity).
We also introduced modern workflows in R, using the caret package for a unified modeling interface and mentioning the newer tidymodels framework. The R ecosystem provides extensive support for machine learning, from data preparation to model training to validation and visualization of results. The code examples in this chapter can serve as templates for your own analysis: you can plug in your dataset, adjust the methods and tuning as needed, and apply the same concepts to evaluate how well your model is doing.
When approaching a new classification task, you might start with a simple model like logistic regression or a decision tree to get a sense of the relationships in the data. These interpretable models can provide a baseline and some insights. If you need better predictive performance and have enough data, you could then try more powerful models like random forests or XGBoost, using cross-validation to tune them. Along the way, always remember to check whether your model might be overfitting and to use metrics that align with the problem’s needs (especially for imbalanced data). Once a model is chosen, spend time interpreting it – even if it’s a “black box,” there are tools to extract useful information (like which features are most important or how certain features influence the prediction).
Classification models are indispensable tools in a data scientist’s toolkit, and R makes it relatively straightforward to apply both basic and advanced classifiers. By understanding both the theory and the practical implementation, you can confidently use these models to analyze data, whether your goal is to accurately predict outcomes or to uncover the factors that drive those outcomes.
References
Athey, S., & Imbens, G. (2016). Recursive partitioning for heterogeneous causal effects. Proceedings of the National Academy of Sciences, 113(27), 7353–7360. https://doi.org/10.1073/pnas.1510489113
Breiman, L. (2001). Random forests. Machine Learning, 45(1), 5–32. https://doi.org/10.1023/A:1010933404324
Chawla, N. V., Bowyer, K. W., Hall, L. O., & Kegelmeyer, W. P. (2002). SMOTE: Synthetic minority over-sampling technique. Journal of Artificial Intelligence Research, 16, 321–357. https://doi.org/10.1613/jair.953
Chen, T., & Guestrin, C. (2016). XGBoost: A scalable tree boosting system. Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 785–794. https://doi.org/10.1145/2939672.2939785
Cortes, C., & Vapnik, V. (1995). Support-vector networks. Machine Learning, 20(3), 273–297. https://doi.org/10.1007/BF00994018
Friedman, J. H. (2001). Greedy function approximation: A gradient boosting machine. Annals of Statistics, 29(5), 1189–1232. https://doi.org/10.1214/aos/1013203451
Hosmer, D. W., Lemeshow, S., & Sturdivant, R. X. (2013). Applied Logistic Regression (3rd ed.). John Wiley & Sons. ISBN: 978-0-470-58247-3
Kuhn, M. (2008). Building predictive models in R using the caret package. Journal of Statistical Software, 28(5), 1–26. https://doi.org/10.18637/jss.v028.i05
Powers, D. M. (2011). Evaluation: From precision, recall and F-measure to ROC, informedness, markedness & correlation. Journal of Machine Learning Technologies, 2(1), 37–63.
Quinlan, J. R. (1993). C4.5: Programs for Machine Learning. Morgan Kaufmann. ISBN: 978-1-55860-238-0
Shmueli, G. (2010). To explain or to predict? Statistical Science, 25(3), 289–310. https://doi.org/10.1214/10-STS330
Sokolova, M., & Lapalme, G. (2009). A systematic analysis of performance measures for classification tasks. Information Processing & Management, 45(4), 427–437. https://doi.org/10.1016/j.ipm.2009.03.002
Vapnik, V. (1995). The Nature of Statistical Learning Theory. Springer. ISBN: 978-0-387-94559-0
Wager, S., & Athey, S. (2018). Estimation and inference of heterogeneous treatment effects using random forests. Journal of the American Statistical Association, 113(523), 1228–1242. https://doi.org/10.1080/01621459.2017.1319839
Zhang, H. (2004). The optimality of naive Bayes. AAAI/IAAI, 3(1), 562–567.