Better Predicting Wine Cultivar with Feature Selection
In supervised machine learning (ML), the goal is to build an accurate model that, based on previously labeled data, provides predictions for new data.
The number one question when it comes to modeling is: "How can I improve my results?"
There are several basic ways to improve your prediction model:
- Hyperparameter optimization
- Feature extraction
- Selecting another model
- Adding more data
- Feature selection
In this blog post, I'll walk you through how I used Feature Selection to improve my model. For the demonstration I'll use the 'Wine' dataset from the UCI ML repository.
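Before diving in, here is a minimal, self-contained preview of what feature selection looks like in scikit-learn. It uses sklearn's built-in copy of the same UCI Wine data and SelectKBest with an arbitrary k=5; this is just a sketch of the idea, not necessarily the exact method used later in the post.
# A minimal preview of feature selection with scikit-learn.
# Uses sklearn's built-in copy of the UCI Wine data; k=5 is an arbitrary choice.
from sklearn.datasets import load_wine
from sklearn.feature_selection import SelectKBest, f_classif

X, y = load_wine(return_X_y=True)
selector = SelectKBest(score_func=f_classif, k=5)
X_selected = selector.fit_transform(X, y)
print(X.shape, '->', X_selected.shape)  # (178, 13) -> (178, 5)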
Most of the functions are from the sklearn (scikit-learn) module.
For the plotting functions make sure to read about matplotlib and seaborn. Both are great plotting modules with great documentation.
Before we jump into the ML model and prediction we need to understand our data. The process of understanding the data is called EDA - exploratory data analysis.
EDA - Exploratory Data Analysis
UCI kindly gave us some basic information about the data set. I'll quote some of the more important info given: "These data are the results of a chemical analysis of wines grown in the same region in Italy but derived from three different cultivars. The analysis determined the quantities of 13 constituents found in each of the three types of wines ... All attributes are continuous ... 1st attribute is class identifier (1-3)"
Based on this, it seems like a classification problem with 3 class labels and 13 numeric attributes, where the goal is to predict the specific cultivar each wine was derived from.
# Loading a few important modules
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
sns.set() # sets a style for the seaborn plots
# Loading the data from its csv,
# and converting the 'label' column to be a string so pandas won't infer it as a numeric value
data = pd.read_csv('wine_data_UCI.csv', dtype={'label':str})
data.head() # print the data's top five instances
I named the first column 'label'. This is the target attribute - what we are trying to predict. Since this is a classification problem, the class label ('label') is not a numeric but a nominal value; that's why I'm telling pandas this column's dtype is 'str'.
data.info() # prints out a basic information about the data.
As we can see, we have 178 entries (instances). As we know from UCI's description of the data, we have 13 numeric attributes and one 'object' type attribute (the target column). All the columns of all the rows have data, which is why we see "178 non-null" next to every column description.
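If you'd rather see the missing values explicitly than infer them from the non-null counts, a quick sanity check (not part of the original output above) is:
# Explicit per-column count of missing values; here it should be all zeros.
print(data.isnull().sum())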
print(data['label'].value_counts()) # prints how many times each value in the 'label' column appears.
sns.countplot(data['label']) # plots the above print
It's important to check the number of instances in each class. There is a difference between the class labels, but it isn't a huge one. If the difference were bigger, we would be facing an imbalanced problem. That would require a lot of extra work, but that's for another post.
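To put a number on the balance, we can look at class proportions instead of raw counts; a quick sketch:
# Class proportions instead of raw counts. With three classes, values near
# 0.33 each would mean a perfectly balanced dataset.
print(data['label'].value_counts(normalize=True))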
# This method prints us some summary statistics for each column in our data.
data.describe()
This is probably only informative to people who have some experience in statistics. Let's try to plot this information and see if it helps us understand the data.
# box plots are best for plotting summary statistics.
sns.boxplot(data=data)
Unfortunately this is not a very informative plot, because the features are not on the same value range. We can resolve the problem by plotting each column side by side.
data_to_plot = data.iloc[:, 1:] # all the columns except the 'label' column
fig, ax = plt.subplots(ncols=len(data_to_plot.columns))
plt.subplots_adjust(right=3, wspace=1) # widen the figure and space out the subplots
for i, col in enumerate(data_to_plot.columns):
    sns.boxplot(y=data_to_plot[col], ax=ax[i])
This is a better way to plot the data.
We can see that we have some outliers (based on the IQR calculation) in almost all the features. These outliers deserve a second look, but we won't deal with them right now.
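For the curious, here is a small sketch of the IQR rule the box-plot whiskers use: a point counts as an outlier if it falls more than 1.5 * IQR below the first quartile or above the third quartile.
# Count IQR-based outliers per feature (the same rule the box plots use).
features = data.iloc[:, 1:] # every column except 'label'
q1 = features.quantile(0.25)
q3 = features.quantile(0.75)
iqr = q3 - q1
outlier_counts = ((features < q1 - 1.5 * iqr) | (features > q3 + 1.5 * iqr)).sum()
print(outlier_counts)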
A pair plot is a great way to see a scatter plot of all the data, two features at a time. Pair plots are good for a small number of features and for a first glance at the columns (features); afterwards, in my opinion, a simple scatter plot of the relevant columns is better (see the sketch after the pair plot below).
columns_to_plot = list(data.columns)
columns_to_plot.remove('label')
sns.pairplot(data, hue='label', vars=columns_to_plot) # the hue parameter colors data instances based on their value in the 'label' column.
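And here is the kind of targeted scatter plot I mean. The two columns below are just the first two features, used as placeholders; substitute whichever pair looked interesting in the pair plot (assumes a seaborn version with scatterplot, 0.9+).
# A focused scatter plot of two features. The columns are placeholders;
# swap in whichever pair looked interesting in the pair plot above.
x_col, y_col = columns_to_plot[0], columns_to_plot[1]
sns.scatterplot(x=x_col, y=y_col, hue='label', data=data)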