Author: Jason Brownlee
Linear Discriminant Analysis is a linear classification machine learning algorithm.
The algorithm involves developing a probabilistic model per class based on the specific distribution of observations for each input variable. A new example is then classified by calculating the conditional probability of it belonging to each class and selecting the class with the highest probability.
As such, it is a relatively simple probabilistic classification model that makes strong assumptions about the distribution of each input variable, although it can make effective predictions even when these assumptions are violated (i.e. it fails gracefully).
In this tutorial, you will discover the Linear Discriminant Analysis classification machine learning algorithm in Python.
After completing this tutorial, you will know:
- Linear Discriminant Analysis is a simple linear machine learning algorithm for classification.
- How to fit, evaluate, and make predictions with the Linear Discriminant Analysis model with Scikit-Learn.
- How to tune the hyperparameters of the Linear Discriminant Analysis algorithm on a given dataset.
Let’s get started.
Tutorial Overview
This tutorial is divided into three parts; they are:
- Linear Discriminant Analysis
- Linear Discriminant Analysis With scikit-learn
- Tune LDA Hyperparameters
Linear Discriminant Analysis
Linear Discriminant Analysis, or LDA for short, is a classification machine learning algorithm.
It works by calculating summary statistics for the input features by class label, such as the mean and standard deviation. These statistics represent the model learned from the training data. In practice, linear algebra operations are used to calculate the required quantities efficiently via matrix decomposition.
Predictions are made by estimating the probability that a new example belongs to each class label based on the values of each input feature. The class that results in the largest probability is then assigned to the example. As such, LDA may be considered a simple application of Bayes Theorem for classification.
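For intuition, below is a minimal NumPy sketch of this decision rule: class priors and mean vectors are estimated from the training data, a single pooled covariance matrix is shared across all classes, and each example is assigned the class with the largest discriminant score. This is an illustration of the idea, not the scikit-learn implementation, and the function names are made up for the example.

# illustrative sketch of the LDA decision rule in NumPy (not the scikit-learn implementation)
import numpy as np

def lda_fit(X, y):
	classes = np.unique(y)
	# class priors pi_k and class mean vectors mu_k, estimated from the training data
	priors = np.array([np.mean(y == c) for c in classes])
	means = np.array([X[y == c].mean(axis=0) for c in classes])
	# pooled covariance matrix, shared by all classes
	pooled = sum(np.cov(X[y == c], rowvar=False) * (np.sum(y == c) - 1) for c in classes)
	pooled /= (len(y) - len(classes))
	return classes, priors, means, np.linalg.inv(pooled)

def lda_predict(X, classes, priors, means, inv_cov):
	# discriminant score: x . inv_cov . mu_k - 0.5 * mu_k . inv_cov . mu_k + log(pi_k)
	scores = X @ inv_cov @ means.T - 0.5 * np.sum(means @ inv_cov * means, axis=1) + np.log(priors)
	return classes[scores.argmax(axis=1)]

Because the pooled covariance is shared across classes, the discriminant score is linear in the input, which is what makes LDA a linear classifier.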
LDA assumes that the input variables are numeric and normally distributed and that they have the same variance (spread). If this is not the case, it may be desirable to transform the data to have a Gaussian distribution and standardize or normalize the data prior to modeling.
… the LDA classifier results from assuming that the observations within each class come from a normal distribution with a class-specific mean vector and a common variance
— Page 142, An Introduction to Statistical Learning with Applications in R, 2014.
It also assumes that the input variables are not correlated; if they are, a PCA transform may be helpful to remove the linear dependence.
… practitioners should be particularly rigorous in pre-processing data before using LDA. We recommend that predictors be centered and scaled and that near-zero variance predictors be removed.
— Page 293, Applied Predictive Modeling, 2013.
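As a sketch of this advice, the Pipeline below centers and scales the predictors and applies a PCA transform to remove linear correlation before fitting LDA; whether these preprocessing steps actually help will depend on your data.

# sketch: standardize and decorrelate the inputs before fitting LDA
from sklearn.datasets import make_classification
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_score
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# center/scale the predictors, remove linear correlation, then fit LDA
steps = [('scale', StandardScaler()), ('pca', PCA()), ('lda', LinearDiscriminantAnalysis())]
pipeline = Pipeline(steps=steps)
scores = cross_val_score(pipeline, X, y, scoring='accuracy', cv=10, n_jobs=-1)
print('Mean Accuracy: %.3f' % scores.mean())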
Nevertheless, the model can perform well even when these assumptions are violated.
The LDA model is naturally multi-class. This means that it supports two-class classification problems and extends to more than two classes (multi-class classification) without modification or augmentation.
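For example, the same class can be fit on a synthetic dataset with three class labels without any changes (a minimal sketch):

# sketch: the same model handles a three-class problem without modification
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# synthetic dataset with three class labels
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, n_classes=3, random_state=1)
model = LinearDiscriminantAnalysis()
model.fit(X, y)
# class labels and per-class probabilities for the first five rows
print(model.predict(X[:5]))
print(model.predict_proba(X[:5]))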
It is a linear classification algorithm, like logistic regression. This means that classes are separated in the feature space by lines or hyperplanes. Extensions of the method can be used that allow other shapes, like Quadratic Discriminant Analysis (QDA), which allows curved shapes in the decision boundary.
… unlike LDA, QDA assumes that each class has its own covariance matrix.
— Page 149, An Introduction to Statistical Learning with Applications in R, 2014.
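As a sketch, QDA is available in scikit-learn as the QuadraticDiscriminantAnalysis class in the same module and can be swapped in directly; how it compares to LDA will depend on the dataset.

# sketch: QDA as a drop-in alternative to LDA
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# QDA estimates a separate covariance matrix per class, allowing curved decision boundaries
model = QuadraticDiscriminantAnalysis()
scores = cross_val_score(model, X, y, scoring='accuracy', cv=10, n_jobs=-1)
print('Mean Accuracy: %.3f' % scores.mean())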
Now that we are familiar with LDA, let’s look at how to fit and evaluate models using the scikit-learn library.
Linear Discriminant Analysis With scikit-learn
Linear Discriminant Analysis is available in the scikit-learn Python machine learning library via the LinearDiscriminantAnalysis class.
The method can be used directly without configuration, although the implementation does offer arguments for customization, such as the choice of solver and the use of a penalty.
...
# create the lda model
model = LinearDiscriminantAnalysis()
We can demonstrate the Linear Discriminant Analysis method with a worked example.
First, let’s define a synthetic classification dataset.
We will use the make_classification() function to create a dataset with 1,000 examples, each with 10 input variables.
The example creates and summarizes the dataset.
# test classification dataset
from sklearn.datasets import make_classification
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# summarize the dataset
print(X.shape, y.shape)
Running the example creates the dataset and confirms the number of rows and columns of the dataset.
(1000, 10) (1000,)
We can fit and evaluate a Linear Discriminant Analysis model using repeated stratified k-fold cross-validation via the RepeatedStratifiedKFold class. We will use 10 folds and three repeats in the test harness.
The complete example of evaluating the Linear Discriminant Analysis model for the synthetic binary classification task is listed below.
# evaluate a lda model on the dataset
from numpy import mean
from numpy import std
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# define model
model = LinearDiscriminantAnalysis()
# define model evaluation method
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# evaluate model
scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)
# summarize result
print('Mean Accuracy: %.3f (%.3f)' % (mean(scores), std(scores)))
Running the example evaluates the Linear Discriminant Analysis algorithm on the synthetic dataset and reports the average accuracy across the three repeats of 10-fold cross-validation.
Your specific results may vary given the stochastic nature of the evaluation procedure and differences in numerical precision. Consider running the example a few times.
In this case, we can see that the model achieved a mean accuracy of about 89.3 percent.
Mean Accuracy: 0.893 (0.033)
We may decide to use Linear Discriminant Analysis as our final model and make predictions on new data.
This can be achieved by fitting the model on all available data and calling the predict() function, passing in a new row of data.
We can demonstrate this with a complete example listed below.
# make a prediction with a lda model on the dataset
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# define model
model = LinearDiscriminantAnalysis()
# fit model
model.fit(X, y)
# define new data
row = [0.12777556,-3.64400522,-2.23268854,-1.82114386,1.75466361,0.1243966,1.03397657,2.35822076,1.01001752,0.56768485]
# make a prediction
yhat = model.predict([row])
# summarize prediction
print('Predicted Class: %d' % yhat)
Running the example fits the model and makes a class label prediction for a new row of data.
Predicted Class: 1
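Since LDA estimates a probability for each class, as described earlier, we can also ask for those probabilities directly via the predict_proba() function, continuing from the fit model in the example above:

# per-class probabilities for the same row (continuing from the example above)
print('Class Probabilities: %s' % model.predict_proba([row]))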
Next, we can look at configuring the model hyperparameters.
Tune LDA Hyperparameters
The hyperparameters for the Linear Discriminant Analysis method must be configured for your specific dataset.
An important hyperparameter is the solver, which defaults to 'svd' but can also be set to other values for solvers that support the shrinkage capability.
The example below demonstrates this using the GridSearchCV class with a grid of different solver values.
# grid search solver for lda
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# define model
model = LinearDiscriminantAnalysis()
# define model evaluation method
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# define grid
grid = dict()
grid['solver'] = ['svd', 'lsqr', 'eigen']
# define search
search = GridSearchCV(model, grid, scoring='accuracy', cv=cv, n_jobs=-1)
# perform the search
results = search.fit(X, y)
# summarize
print('Mean Accuracy: %.3f' % results.best_score_)
print('Config: %s' % results.best_params_)
Running the example will evaluate each combination of configurations using repeated cross-validation.
Your specific results may vary given the stochastic nature of the evaluation procedure and differences in numerical precision. Try running the example a few times.
In this case, we can see that the default SVD solver performs the best compared to the other built-in solvers.
Mean Accuracy: 0.893
Config: {'solver': 'svd'}
Next, we can explore whether using shrinkage with the model improves performance.
Shrinkage adds a penalty to the model that acts as a type of regularizer, reducing the complexity of the model.
Regularization reduces the variance associated with the sample based estimate at the expense of potentially increased bias. This bias variance trade-off is generally regulated by one or more (degree-of-belief) parameters that control the strength of the biasing towards the “plausible” set of (population) parameter values.
— Regularized Discriminant Analysis, 1989.
Shrinkage is configured via the "shrinkage" argument, which takes a value between 0 and 1. We will test values on a grid with a spacing of 0.01.
In order to use the penalty, a solver that supports this capability must be chosen, such as 'eigen' or 'lsqr'. We will use the latter in this case.
The complete example of tuning the shrinkage hyperparameter is listed below.
# grid search shrinkage for lda
from numpy import arange
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# define model
model = LinearDiscriminantAnalysis(solver='lsqr')
# define model evaluation method
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# define grid
grid = dict()
grid['shrinkage'] = arange(0, 1, 0.01)
# define search
search = GridSearchCV(model, grid, scoring='accuracy', cv=cv, n_jobs=-1)
# perform the search
results = search.fit(X, y)
# summarize
print('Mean Accuracy: %.3f' % results.best_score_)
print('Config: %s' % results.best_params_)
Running the example will evaluate each combination of configurations using repeated cross-validation.
Your specific results may vary given the stochastic nature of the evaluation procedure and differences in numerical precision. Try running the example a few times.
In this case, we can see that using shrinkage offers a slight lift in performance, from about 89.3 percent to about 89.4 percent, with a shrinkage value of 0.02.
Mean Accuracy: 0.894
Config: {'shrinkage': 0.02}
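Note that rather than grid searching a fixed value, the 'lsqr' and 'eigen' solvers also accept shrinkage='auto', which determines the shrinkage intensity analytically using the Ledoit-Wolf lemma. A minimal sketch:

# analytic (Ledoit-Wolf) shrinkage instead of a grid-searched value
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
model = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')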
Further Reading
This section provides more resources on the topic if you are looking to go deeper.
Papers
- Regularized Discriminant Analysis, 1989.
Books
- Applied Predictive Modeling, 2013.
- An Introduction to Statistical Learning with Applications in R, 2014.
APIs
- sklearn.discriminant_analysis.LinearDiscriminantAnalysis API.
- Linear and Quadratic Discriminant Analysis, scikit-learn.
Summary
In this tutorial, you discovered the Linear Discriminant Analysis classification machine learning algorithm in Python.
Specifically, you learned:
- Linear Discriminant Analysis is a simple linear machine learning algorithm for classification.
- How to fit, evaluate, and make predictions with the Linear Discriminant Analysis model with Scikit-Learn.
- How to tune the hyperparameters of the Linear Discriminant Analysis algorithm on a given dataset.
Do you have any questions?
Ask your questions in the comments below and I will do my best to answer.