Comparison between grid search and successive halving
This example compares the parameter search performed by HalvingGridSearchCV and GridSearchCV.
from time import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.svm import SVC
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
# The experimental enable-import has to come before the HalvingGridSearchCV
# import below; it is imported only for that effect, hence the noqa.
from sklearn.experimental import enable_halving_search_cv  # noqa
from sklearn.model_selection import HalvingGridSearchCV

print(__doc__)
We first define the parameter space for an SVC estimator, and compute the time required to train a HalvingGridSearchCV instance, as well as a GridSearchCV instance.
# One seeded RandomState instance is shared by the data generator, the SVC
# and the halving search below.
rng = np.random.RandomState(0)
X, y = datasets.make_classification(n_samples=1000, random_state=rng)
# Search grid: 7 gamma values x 6 C values = 42 parameter combinations.
gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
Cs = [1, 10, 100, 1e3, 1e4, 1e5]
param_grid = {'gamma': gammas, 'C': Cs}
clf = SVC(random_state=rng)
# Wall-clock time needed to fit the successive-halving search.
tic = time()
gsh = HalvingGridSearchCV(estimator=clf, param_grid=param_grid, factor=2,
random_state=rng)
gsh.fit(X, y)
gsh_time = time() - tic
# Wall-clock time needed to fit the exhaustive grid search, using the same
# estimator object and the same parameter grid.
tic = time()
gs = GridSearchCV(estimator=clf, param_grid=param_grid)
gs.fit(X, y)
gs_time = time() - tic
We now plot heatmaps for both search estimators.
def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
    """Draw a gamma-by-C heatmap of mean test scores for a fitted search.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes the heatmap is drawn on.
    gs : fitted search estimator
        Object whose ``cv_results_`` holds ``param_gamma``, ``param_C`` and
        ``mean_test_score`` entries (plus ``iter`` when ``is_sh`` is true).
    is_sh : bool, default=False
        Set for successive-halving results. Each cell then shows the score
        from the last iteration in which that combination was evaluated,
        and is annotated with that iteration number.
    make_cbar : bool, default=False
        When true, free up space on the right of the figure and add a
        colorbar labelled ``mean_test_score``.
    """
    results = pd.DataFrame.from_dict(gs.cv_results_)

    if is_sh:
        # A combination has one row per iteration it took part in; sorting
        # by 'iter' and keeping the last row yields its latest score.
        scores_matrix = results.sort_values('iter').pivot_table(
            index='param_gamma', columns='param_C',
            values='mean_test_score', aggfunc='last'
        )
    else:
        scores_matrix = results.pivot(index='param_gamma', columns='param_C',
                                      values='mean_test_score')

    im = ax.imshow(scores_matrix)

    # Take the tick labels from the pivoted frame itself so that every label
    # matches the row/column actually drawn, whatever order the search grid
    # was listed in.
    c_values = scores_matrix.columns
    gamma_values = scores_matrix.index

    ax.set_xticks(np.arange(len(c_values)))
    ax.set_xticklabels(['{:.0E}'.format(x) for x in c_values])
    ax.set_xlabel('C', fontsize=15)

    ax.set_yticks(np.arange(len(gamma_values)))
    ax.set_yticklabels(['{:.0E}'.format(x) for x in gamma_values])
    ax.set_ylabel('gamma', fontsize=15)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    if is_sh:
        # Same index/columns as scores_matrix, so cell (i, j) lines up with
        # the heatmap; imshow puts columns on x and rows on y.
        iterations = results.pivot_table(index='param_gamma',
                                         columns='param_C', values='iter',
                                         aggfunc='max').values
        for (i, j), last_iter in np.ndenumerate(iterations):
            ax.text(j, i, last_iter,
                    ha="center", va="center", color="w", fontsize=20)

    if make_cbar:
        # Use the figure that owns `ax` rather than a module-level `fig`.
        fig = ax.figure
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.colorbar(im, cax=cbar_ax)
        cbar_ax.set_ylabel('mean_test_score', rotation=-90, va="bottom",
                           fontsize=15)
# Two panels side by side with a shared y axis: successive halving on the
# left, exhaustive grid search (which also carries the colorbar) on the right.
fig, axes = plt.subplots(ncols=2, sharey=True)
ax1, ax2 = axes

# Left panel.
make_heatmap(ax1, gsh, is_sh=True)
ax1.set_title(
    'Successive Halving\ntime = {:.3f}s'.format(gsh_time),
    fontsize=15,
)

# Right panel.
make_heatmap(ax2, gs, make_cbar=True)
ax2.set_title(
    'GridSearch\ntime = {:.3f}s'.format(gs_time),
    fontsize=15,
)

plt.show()
The heatmaps show the mean test score of the parameter combinations for an SVC instance. The HalvingGridSearchCV also shows the iteration at which the combinations were last used. The combinations marked as 0 were only evaluated at the first iteration, while the ones with 5 are the parameter combinations that are considered the best ones.
We can see that the HalvingGridSearchCV class is able to find parameter combinations that are just as accurate as GridSearchCV, in much less time.
Total running time of the script: ( 0 minutes 17.141 seconds)
© 2007–2020 The scikit-learn developers
Licensed under the 3-clause BSD License.
https://scikit-learn.org/0.24/auto_examples/model_selection/plot_successive_halving_heatmap.html