Why shouldn’t we call the fit_transform method on our test data?

Photo by David Rangel on Unsplash

As we set foot in the field of data science, we are tempted to make a mistake that may go unnoticed at first but could be the very reason our model performs poorly in production.

The mistake:

X_test = imputer.fit_transform(X_test)

Do you see something wrong here?

Before we jump into the explanation of what is happening here, let’s see what the fit_transform() method does and how it differs from the transform() method.

fit_transform():

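This method is used on our training data to learn the parameters of the transformation (for a scaler, the mean and variance of each feature in the training set; for an imputer, the fill value for each feature, such as its mean). These parameters are later reused to transform our test data.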

fit_transform() is equivalent to calling the fit() and transform() methods in order.
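To make that equivalence concrete, here is a minimal sketch using scikit-learn's StandardScaler. The toy array and the variable names are made up for illustration and are not part of the original example.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical toy training data: 4 samples, 2 features.
X_train = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])

# One call: learn the mean and variance, then scale.
scaled_one_step = StandardScaler().fit_transform(X_train)

# Two calls: fit() learns the parameters, transform() applies them.
scaler = StandardScaler()
scaler.fit(X_train)
scaled_two_steps = scaler.transform(X_train)

# Both routes should produce the same scaled array.
same_result = np.allclose(scaled_one_step, scaled_two_steps)

# Per-feature mean that fit() learned from X_train.
learned_mean = scaler.mean_
print(same_result, learned_mean)
```

After fitting, the learned parameters live on the scaler object (for example in mean_), which is what allows a later transform() call to reuse them on other data.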

transform():

This method applies the parameters already learned from the training data to transform the test data. It does not learn anything new.

So what is it that we were doing wrong in the above code snippet?

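We are unknowingly refitting the imputer on the test set, so the test data is transformed with parameters computed from the test data itself instead of the ones learned from the training set. This mixing of information between the two sets is called train-test contamination (a form of data leakage).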

Surprisingly, this mistake can give you deceptively good results during model evaluation. By using the fit method on our validation set, we compute a new mean and variance from the validation set and let that information shape what the model sees (making the evaluation biased). This should never happen, as the validation set (or test data) should come as a “surprise” to our model.
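The effect of refitting can be seen on a small example. The sketch below assumes a StandardScaler and two hypothetical one-feature arrays. It compares the test data scaled with the training parameters against the same data scaled after a second fit on the test set.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical data: the test values lie far from the training values.
X_train = np.array([[1.0], [2.0], [3.0], [4.0]])  # mean 2.5
X_test = np.array([[10.0], [12.0]])               # mean 11.0

# Correct: reuse the mean and variance learned from the training set.
train_scaler = StandardScaler().fit(X_train)
X_test_ok = train_scaler.transform(X_test)

# Mistake: fit_transform() learns a new mean and variance from the test set.
leaky_scaler = StandardScaler()
X_test_leaky = leaky_scaler.fit_transform(X_test)

print("train mean:", train_scaler.mean_)
print("mean learned from the test set:", leaky_scaler.mean_)
print("scaled with training parameters:", X_test_ok.ravel())
print("scaled after refitting:", X_test_leaky.ravel())
```

With the training parameters, the test values remain far from zero, which reflects how different they are from the training data. After refitting, they are centered around zero and that difference is hidden.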

So how do we scale our test data then? This is where the transform() method comes into the picture. It uses the parameters learned from the training data to transform the test data or the validation set.

It is common practice to scale the data so that the model is not biased towards any particular feature of the dataset. Using transform() on the validation set (or test data) lets us do that while ensuring that our model does not learn anything about it.

Thus, the correct way:

X_train = imputer.fit_transform(X_train)
X_test = imputer.transform(X_test)
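For a self-contained version of this pattern, here is a small sketch. The article does not say how imputer was created, so this assumes scikit-learn's SimpleImputer with strategy="mean". The arrays and the *_filled names are hypothetical.

```python
import numpy as np
from sklearn.impute import SimpleImputer

# Hypothetical data with missing values (np.nan).
X_train = np.array([[1.0], [3.0], [np.nan], [5.0]])  # observed mean = 3.0
X_test = np.array([[np.nan], [100.0]])

# Assumption: a mean imputer stands in for the article's `imputer`.
imputer = SimpleImputer(strategy="mean")

# Learn the fill value from the training data and apply it there.
X_train_filled = imputer.fit_transform(X_train)

# Reuse the same fill value on the test data, without refitting.
X_test_filled = imputer.transform(X_test)

print(imputer.statistics_)    # fill value learned from X_train
print(X_test_filled.ravel())  # the missing test value is filled with the training mean
```

The missing test value is filled with the training mean, not with a value computed from the test set, so nothing about the test data influences the preprocessing.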
