{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "# Hyperparameter Tuning with Falkon\n", "\n", "## Introduction\n", "\n", "We apply Falkon to a multi-class classification problem (the [digits](https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits) dataset) and show how to integrate it with scikit-learn\n", "for hyperparameter optimization.\n", "\n", "Since both `Falkon` and `LogisticFalkon` are estimators that follow scikit-learn's API, integration is seamless." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[pyKeOps]: Warning, no cuda detected. Switching to cpu only.\n" ] } ], "source": [ "%matplotlib inline\n", "from sklearn import datasets, model_selection, metrics\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import matplotlib.pyplot as plt\n", "\n", "import falkon" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data\n", "\n", "We use the **digits** dataset, which is distributed alongside scikit-learn." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "X, Y = datasets.load_digits(return_X_y=True)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Label: 1\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD8CAYAAACvvuKtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAALn0lEQVR4nO3d34tc9RnH8c+nmwR/xaxUK2LEVKgBEboJIVQCmh8qsUpy04sEFCot6UUrhhZEexP9B8ReFGGJGsEY0WhIkdYa0EWEVpvEtUY3Fg0Rk6hRzCZqoUF9ejEnZSsb98zs+Z6dPL5fMOTMzsx5nk3yme85Z86cryNCAPL43kw3AKBZhBpIhlADyRBqIBlCDSRDqIFk+iLUtlfbftv2O7bvLljnYdtHbe8rVWNCrctsv2h7zPabtu8sWOss26/afr2qdV+pWhNqDth+zfazLdQ6aPsN26O2dxeuNWh7u+391b/dNYXqLKx+n1O3E7Y3NrLyiJjRm6QBSe9KukLSHEmvS7qqUK1rJS2WtK+F3+sSSYur5bmS/lXw97Kk86rl2ZJekfSTwr/fbyU9LunZFv4uD0q6sHSdqtajkn5ZLc+RNNhCzQFJH0q6vIn19cNIvVTSOxFxICJOSnpC0toShSLiJUmfllj3JLU+iIi91fJnksYkXVqoVkTE59Xd2dWt2FlFtudLulnS5lI1ZoLt89V5439IkiLiZESMt1B6laR3I+K9JlbWD6G+VNL7E+4fUqH//DPF9gJJi9QZQUvVGLA9KumopF0RUayWpAck3SXp64I1JgpJz9veY3tDwTpXSPpY0iPVrsVm2+cWrHfKOknbmlpZP4Tak/wszbmrts+T9LSkjRFxolSdiPgqIoYkzZe01PbVJerYvkXS0YjYU2L9p7EsIhZLuknSr21fW6jOLHV2zx6MiEWSvpBU7BiPJNmeI2mNpKeaWmc/hPqQpMsm3J8v6cgM9dIo27PVCfTWiHimjZrV5uKIpNWFSiyTtMb2QXV2lVbafqxQLUlSRByp/jwqaYc6u2wlHJJ0aMJWznZ1Ql7STZL2RsRHTa2wH0L9D0k/sv3D6l1rnaQ/zXBP02bb6uybjUXE/YVrXWR7sFo+W9L1kvaXqBUR90TE/IhYoM6/1QsRcWuJWpJk+1zbc08tS7pRUpFPLyLiQ0nv215Y/WiVpLdK1JpgvRrc9JY6mxszKiK+tP0bSX9V5yjgwxHxZolatrdJWi7pQtuHJG2KiIdK1FJnRLtN0hvVvq4k/T4i/lyg1iWSHrU9oM4b9ZMRUfyjppZcLGlH5z1SsyQ9HhHPFax3h6St1QBzQNLtpQrZPkfSDZJ+1eh6q0PqAJLoh81vAA0i1EAyhBpIhlADyRBqIJm+CXXh0/9mrFbb9ahFrb4JtaQ2g9ZqqFuuR63veK1+CjWABhQ5+cQ2Z7Q04Morr+z6NcePH9e8efO6ft2sWd2fXHjs2DFdcMEFXb/u8OHDXb/m5MmTmjNnTtevO378eNevOVNExGRfhiLU/WxkZKS1WoODg63V2rRpU2u1du7c2Vqttp0u1Gx+A8kQaiAZQg0kQ6iBZAg1kAyhBpIh1EAyhBpIplao25oWB8D0TRnq6mJ2f1TnUqZXSVpv+6rSjQHoTZ2RurVpcQBMX51Q15oWx/YG27tLz0oI4NvV+WpOrWlxIm
JY0rDEFzqAmVRnpE47LQ6QUZ1Qp5wWB8hqys3vNqfFATB9tS53Uc3/VGIOKAAN44wyIBlCDSRDqIFkCDWQDKEGkiHUQDKEGkim+2kZ0Jrx8fHWal133XWt1VqxYkVrtTJfzP90GKmBZAg1kAyhBpIh1EAyhBpIhlADyRBqIBlCDSRDqIFkCDWQTJ0ZOh62fdT2vjYaAjA9dUbqLZJWF+4DQEOmDHVEvCTp0xZ6AdAA9qmBZBr76qXtDZI2NLU+AL1pLNTMpQX0Bza/gWTqfKS1TdLfJC20fcj2L8q3BaBXdebSWt9GIwCaweY3kAyhBpIh1EAyhBpIhlADyRBqIBlCDSTDtDtdGBoaarXe8uXLW63XltHR0ZluITVGaiAZQg0kQ6iBZAg1kAyhBpIh1EAyhBpIhlADyRBqIBlCDSRT5xpll9l+0faY7Tdt39lGYwB6U+fc7y8l/S4i9tqeK2mP7V0R8Vbh3gD0oM60Ox9ExN5q+TNJY5IuLd0YgN50tU9te4GkRZJeKdINgGmr/dVL2+dJelrSxog4McnjTLsD9IFaobY9W51Ab42IZyZ7DtPuAP2hztFvS3pI0lhE3F++JQDTUWefepmk2ySttD1a3X5auC8APaoz7c7LktxCLwAawBllQDKEGkiGUAPJEGogGUINJEOogWQINZAMoQaSOePn0tq4cWNrte69997WaknSvHnzWq3XlpGRkZluITVGaiAZQg0kQ6iBZAg1kAyhBpIh1EAyhBpIhlADyRBqIJk6Fx48y/artl+vpt25r43GAPSmzmmi/5G0MiI+ry4V/LLtv0TE3wv3BqAHdS48GJI+r+7Orm5c1xvoU7X2qW0P2B6VdFTSrohg2h2gT9UKdUR8FRFDkuZLWmr76m8+x/YG27tt7264RwBd6Orod0SMSxqRtHqSx4YjYklELGmmNQC9qHP0+yLbg9Xy2ZKul7S/cF8AelTn6Pclkh61PaDOm8CTEfFs2bYA9KrO0e9/qjMnNYAzAGeUAckQaiAZQg0kQ6iBZAg1kAyhBpIh1EAyhBpIxp1vVja8UjvlVzMHBwdbrXfs2LFW67Vl0aL2zmUaHR1trVbbIsKT/ZyRGkiGUAPJEGogGUINJEOogWQINZAMoQaSIdRAMoQaSIZQA8nUDnV1Qf/XbHPRQaCPdTNS3ylprFQjAJpRd9qd+ZJulrS5bDsApqvuSP2ApLskfX26JzDtDtAf6szQcYukoxGx59uex7Q7QH+oM1Ivk7TG9kFJT0haafuxol0B6NmUoY6IeyJifkQskLRO0gsRcWvxzgD0hM+pgWTqTJD3PxExos5UtgD6FCM1kAyhBpIh1EAyhBpIhlADyRBqIBlCDSTT1efUQBOGhoZaq5V52p3TYaQGkiHUQDKEGkiGUAPJEGogGUINJEOogWQINZAMoQaSIdRAMrVOE62uJPqZpK8kfcllgIH+1c253ysi4pNinQBoBJvfQDJ1Qx2Snre9x/aGyZ7AtDtAf6i7+b0sIo7Y/oGkXbb3R8RLE58QEcOShiXJdjTcJ4Caao3UEXGk+vOopB2SlpZsCkDv6kyQd67tuaeWJd0oaV/pxgD0ps7m98WSdtg+9fzHI+K5ol0B6NmUoY6IA5J+3EIvABrAR1pAMoQaSIZQA8kQaiAZQg0kQ6iBZAg1kAyhBpIh1EAyhBpIhlADyRBqIBlCDSRDqIFkCDWQDKEGkiHUQDKEGkimVqhtD9rebnu/7THb15RuDEBv6l73+w+SnouIn9meI+mcgj0BmIYpQ237fEnXSvq5JEXESUkny7YFoFd1Nr+vkPSxpEdsv2Z7c3X9bwB9qE6oZ0laLOnBiFgk6QtJd3/zScylBfSHOqE+JOlQRLxS3d+uTsj/T0QMR8QS5q4GZtaUoY6IDyW9b3th9aNVkt4q2hWAntU9+n2HpK3Vke8Dkm4v1xKA6agV6ogYlcRmNXAG4IwyIBlCDSRDqIFkCDWQDKEGkiHUQDKEGkiGUA
PJ1D2jDJLGx8dbrbdz587Waq1du7a1WsuXL2+t1pYtW1qr1S8YqYFkCDWQDKEGkiHUQDKEGkiGUAPJEGogGUINJEOogWSmDLXthbZHJ9xO2N7YQm8AejDlaaIR8bakIUmyPSDpsKQdZdsC0KtuN79XSXo3It4r0QyA6es21OskbSvRCIBm1A51dc3vNZKeOs3jTLsD9IFuvnp5k6S9EfHRZA9GxLCkYUmyHQ30BqAH3Wx+rxeb3kDfqxVq2+dIukHSM2XbATBddafd+bek7xfuBUADOKMMSIZQA8kQaiAZQg0kQ6iBZAg1kAyhBpIh1EAyjmj+NG3bH0vq9uuZF0r6pPFmZr5W2/Wo9d2odXlEXDTZA0VC3QvbuyNiSbZabdejFrXY/AaSIdRAMv0U6uGktdquR63veK2+2acG0Ix+GqkBNIBQA8kQaiAZQg0kQ6iBZP4Lhj5lm5RUNJYAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots()\n", "_ = ax.matshow(X[1].reshape((8, 8)), cmap='gray')\n", "print(\"Label: %d\" % Y[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split into training and test sets\n", "\n", "We split the data into a training set with 80% of the samples and a test set with the remaining 20%." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, Y_train, Y_test = model_selection.train_test_split(\n", " X, Y, test_size=0.2, random_state=10, shuffle=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preprocessing\n", "\n", "As always with Falkon we must:\n", " 1. Convert from numpy arrays to torch tensors\n", " 2. Convert data and labels to the same data-type (in this case float32)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "X_train = torch.from_numpy(X_train).to(dtype=torch.float32)\n", "X_test = torch.from_numpy(X_test).to(dtype=torch.float32)\n", "Y_train = torch.from_numpy(Y_train)\n", "Y_test = torch.from_numpy(Y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Normalizing the data is always a good idea. Here we use the global mean and standard deviation of the training set for z-score normalization." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# z-score normalization\n", "train_mean = X_train.mean()\n", "train_std = X_train.std()\n", "X_train -= train_mean\n", "X_train /= train_std\n", "X_test -= train_mean\n", "X_test /= train_std" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since Falkon optimizes with respect to the square loss, using ordinal labels (e.g. 1, 4, 5) is not ideal since closeness in the natural numbers is meaningless for classification. We therefore convert the labels to a 1-hot representation." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First label vector: tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])\n" ] } ], "source": [ "# Convert labels to 1-hot\n", "eye = torch.eye(10, dtype=torch.float32)\n", "Y_train = eye[Y_train]\n", "Y_test = eye[Y_test]\n", "print(\"First label vector: \", Y_train[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Search for the optimal parameters\n", "\n", "Since Falkon (with the Gaussian kernel) has only 3 important hyperparameters, it is entirely feasible to run a grid search over them to find the best parameter settings.\n", "\n", "Scikit-learn has great support for this with the `GridSearchCV` class. For each parameter setting it runs 5-fold cross-validation on the training set to determine which setting gives the best results.\n", "\n", "Given that the dataset is quite small and Falkon is fast, we can run all 160 model fits (4 kernels × 4 penalties × 2 values of `M`, each with 5-fold cross-validation) in well under a minute." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def mclass_loss(true, pred):\n", "    # Fraction of misclassified samples: compare the argmax of the 1-hot targets\n", "    # with the argmax of the predicted scores.\n", "    true = torch.argmax(true, dim=1)\n", "    pred = torch.argmax(pred, dim=1)\n", "    return torch.mean((true != pred).to(torch.float32))\n", "\n", "# greater_is_better=False: this is a loss, so scikit-learn negates it when scoring\n", "mclass_scorer = metrics.make_scorer(mclass_loss, greater_is_better=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The parameter settings which will be explored by the grid-search are:\n", " - four different kernel length-scales (varying around small positive numbers, which are usually good for normalized data)\n", " - four different regularization values\n", " - two different values for `M`: the number of inducing points. As we will see, a larger `M` is almost always better than a smaller one (but it leads to longer training times). 
\n", " The exception is a dataset that is easy to overfit, where reducing `M` can act as additional regularization.\n", "\n", "When we create the estimator we pass it additional parameters via the `FalkonOptions` class.\n", "In our case we want to ensure that the model runs on the CPU by setting `use_cpu=True`." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "parameter_grid = {\n", " 'kernel': [falkon.kernels.GaussianKernel(sigma=1),\n", " falkon.kernels.GaussianKernel(sigma=5),\n", " falkon.kernels.GaussianKernel(sigma=10),\n", " falkon.kernels.GaussianKernel(sigma=15)],\n", " 'penalty': [1e-3, 1e-5, 1e-7, 1e-9],\n", " 'M': [500, 1000],\n", "}\n", "estimator = falkon.Falkon(\n", " kernel=falkon.kernels.GaussianKernel(1), penalty=1e-3, M=1000, # Mandatory parameters, will be overridden\n", " maxiter=10, options=falkon.FalkonOptions(use_cpu=True))\n", "\n", "grid_search = model_selection.GridSearchCV(estimator, parameter_grid, scoring=mclass_scorer, cv=5)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The best parameters are: {'M': 500, 'kernel': GaussianKernel(sigma=Parameter containing:\n", "tensor([10.], dtype=torch.float64)), 'penalty': 1e-07}\n", "CPU times: user 52.3 s, sys: 1.78 s, total: 54.1 s\n", "Wall time: 13.6 s\n" ] } ], "source": [ "%%time\n", "grid_search.fit(X_train, Y_train)\n", "print(\"The best parameters are: \", grid_search.best_params_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluating the model\n", "We evaluate the model on the held-out test set and obtain a respectable error of about 1% on 10 classes." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 534 ms, sys: 23.8 ms, total: 558 ms\n", "Wall time: 139 ms\n" ] } ], "source": [ "%%time\n", "flk = grid_search.best_estimator_\n", "# GridSearchCV (refit=True by default) has already refit the best estimator on the\n", "# full training set, so this explicit fit is optional.\n", "flk.fit(X_train, Y_train)\n", "test_pred = flk.predict(X_test)\n", "train_pred = flk.predict(X_train)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training error: 0.00%\n", "Test error: 1.11%\n" ] } ], "source": [ "print(\"Training error: %.2f%%\" % (mclass_loss(Y_train, train_pred) * 100))\n", "print(\"Test error: %.2f%%\" % (mclass_loss(Y_test, test_pred) * 100))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot grid-search results\n", "\n", "Plotting the results of a grid search is useful because it shows which ranges of parameters worked well.\n", "If the initial grid was too coarse, one can then run a second grid search to obtain even better accuracy.\n", "\n", "In the plots, red indicates a high error, while darker blue indicates a low error." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Collect the cross-validation results and extract one column per hyperparameter.\n", "# mean_test_score is the negated error (greater_is_better=False), so values closer to 0 are better.\n", "res_df = pd.DataFrame.from_dict(grid_search.cv_results_)\n", "res_df[\"M\"] = res_df.params.apply(lambda x: x.get(\"M\"))\n", "res_df[\"penalty\"] = res_df.params.apply(lambda x: x.get(\"penalty\"))\n", "res_df[\"sigma\"] = res_df.params.apply(lambda x: x.get(\"kernel\").sigma.item())\n", "res_df = res_df[[\"mean_test_score\", \"M\", \"penalty\", \"sigma\"]]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def plot_heatmap(ax, df, xlabel, ylabel, value, scale):\n", "    # Pivot to a (ylabel x xlabel) grid of scores and draw it as a heatmap\n", "    piv = pd.pivot_table(df, index=ylabel, columns=xlabel, values=value)\n", "    # plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib versions no longer provide\n", "    cmap = plt.get_cmap('coolwarm_r', 20)\n", "    ax.grid(False)\n", "    c = ax.pcolormesh(piv, cmap=cmap, vmin=scale[0], vmax=scale[1])\n", "    ax.set_yticks(np.arange(piv.shape[0]) + 0.5, minor=False)\n", "    ax.set_xticks(np.arange(piv.shape[1]) + 0.5, minor=False)\n", "    ax.set_xticklabels(piv.columns, minor=False)\n", "    ax.set_yticklabels(piv.index, minor=False)\n", "    ax.set_xlabel(xlabel)\n", "    ax.set_ylabel(ylabel)\n", "    return c" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAt0AAAFNCAYAAADcudMsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZaUlEQVR4nO3dfZCsZ1kn4N9tjghFkjUkbAqTiEBFsuEjEaJmtVaRddeQskwwspJCA1asCIruolQtlAq4H7qolFWAhI0xewjWRgkSCSUfKovEwogkkpAckXgA0UNSfEogREHCvX/Me6SdzJwzPWeemZ7T11XV1f1+dd/zVJ87v7z99NvV3QEAAMb5qp0uAAAAjnZCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCN0e9qvqbqvpiVZ20av0tVdVV9Q1b/Hp/XFX/WFX3TLcPrNr+76vqr6rq3qp6R1U9fGZbVdVLq+pT0+2Xq6q2sj6ARbIDPfq5VXVTVX2hqvausX3TPbqqvmE65t7pOb57K2tndxO6WRYfTnLxwYWqelySBw18ved297HT7dEzr3tSkjck+fkkD0lyU5LfmTnusiQXJjkryeOTfG+SHxtYJ8Ai2M4efWeS/5HkqtUbtqBHX5PkvUlOTPKzSV5fVQ/d8r+AXUnoZlm8NsklM8vPTHL1DtTx/Un2dfe13f2PSV6S5KyqOmOmrpd194Hu/miSlyV51g7UCbCdtq1Hd/cbuvv3knxqjc2b7tFV9Y1JnpDkxd39D939u0luS3LRiL+D3UfoZln8WZLjq+rfVNUxSX4wyW8d6oCqelVVfWad2/sO83q/VFWfrKp3VdWTZtY/JsmtBxe6+/NJPjitv9/26fFjAnB02+4evZ4j6dGPSfKh7v7cOttZcnt2ugDYRgfPpLwzyV8l+eihdu7uH0/y45t4nf+a5C+TfDHJ05O8qarO7u4PJjk2ySdW7X93kuOmx8dOy7Pbjq2q6u7eRC0Au8V29ehD2XSPXmPbwe2nbHGN7FJCN8vktUluSPKIDJxa0t3vnll8TVVdnOT8JK9Ick+S41cdcnySg2dGVm8/Psk9AjewBLalRx/Gpnt0VR3uWJac6SUsje7+SFa+rHN+Vr4oc0hV9eqZK5Csvu2b56WTHPx2+76sfAHn4Gs8OMmjpvX32z49nue1AHalHezRs46kR+9L8siqOm6d7Sw5oZtlc2mSJ0/z9A6pu589cwWS1bc15+hV1ddW1fdU1QOrak9VPSPJdyR527TLdUkeW1UXVdUDk7woyfu6+6+m7Vcn+emqOqWqvi7JzyTZe2R/MsCuMbRHJ8nUmx+Y5Jgkxxzs19PmTffo7r4jyS1JXjw951OzcoWT393EOHAUMr2EpTLNqx7pq7NyKaozktyXlXmJF3b3B6bX/0RVXZTklVn5ktC7szLv+6D/neSRWfnGe5JcOa0DOOptQ49Okp9L8uKZ5R9K8gtJXrIFPfrpWQnhf5/kb5P8QHevniPOkipTRQEAYCzTSwAAYDChGwAABhO6AQBgMKEbAAAGE7oBAGCwpbhk4IOOPbGPP+Hrd7qMXeOYPf5fbB4nP+ienS5hV7nn/R/Z6RJ2lf35wie7+6E7Xcd20rPno2fPR8+en749n/X69lKE7uNP+Pr84M+8c6fL2DUecuKDdrqEXeU/P+5dO13CrvKn3/xjO13CrvK9X7pj6f5rp2fPR8+ej549P317Puv1bf97DAAAgwndAAAwmNANAACDCd0AADCY0A0AAIMJ3QAAMJjQDQAAgwndAAAwmNANAACDCd0AADCY0A0AAIMJ3QAAMJjQDQAAgwndAAAwmNANAACDCd0AADCY0A0AAIMJ3QAAMJjQDQAAgwndAAAwmNANAACDCd0AADCY0A0AAIMJ3QAAMJjQDQAAgwndAAAwmNANAACDCd0AADDY0NBdVedV1Qeqan9VvWCN7VVVL5+2v6+qnnC4Y6vqaVW1r6q+XFXnjKw
fYJno2QDjDAvdVXVMkl9P8pQkZya5uKrOXLXbU5KcPt0uS3L5Bo69Pcn3J7lhVO0Ay0bPBhhr5Jnub0myv7s/1N1fTPLbSS5Ytc8FSa7uFX+W5Gur6mGHOra739/dHxhYN8Ay0rMBBhoZuk9J8nczywemdRvZZyPHArB19GyAgUaG7lpjXW9wn40ce+gXr7qsqm6qqpv+4fOfmudQgGWkZwMMNDJ0H0hy2szyqUnu3OA+Gzn2kLr7iu4+p7vPedCDT5znUIBlpGcDDDQydL8nyelV9YiqekCSpye5ftU+1ye5ZPpG/LlJ7u7uuzZ4LABbR88GGGjPqCfu7i9V1XOTvC3JMUmu6u59VfXsafurk7w5yflJ9ie5N8mPHOrYJKmqpyZ5RZKHJvn9qrqlu79n1N8BsAz0bICxhoXuJOnuN2elSc+ue/XM407yExs9dlp/XZLrtrZSAPRsgHH8IiUAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMtmenC2Dx/NHr/nSnS9hV3vPOh+50CbvKN191+06XsLtc8oCdroAFp2fPR8+en749p3X6tjPdAAAwmNANAACDCd0AADCY0A0AAIMJ3QAAMJjQDQAAgwndAAAwmNANAACDCd0AADCY0A0AAIMJ3QAAMJjQDQAAgwndAAAwmNANAACDCd0AADCY0A0AAIMJ3QAAMJjQDQAAgwndAAAwmNANAACDCd0AADCY0A0AAIMJ3QAAMJjQDQAAgwndAAAwmNANAACDCd0AADCY0A0AAIMJ3QAAMNiw0F1VV1XVx6vq9k0c+8Squq2q9lfVy6uqpvXPqqpPVNUt0+1Ht75ygOWkbwOMM/JM994k523y2MuTXJbk9Ok2+zy/091nT7crj6xEAGbsjb4NMMSw0N3dNyT59Oy6qnpUVb21qm6uqj+pqjNWH1dVD0tyfHff2N2d5OokF46qE4AV+jbAONs9p/uKJD/Z3U9M8vwkr1pjn1OSHJhZPjCtO+iiqnpfVb2+qk4bVyoA0bcBtsSe7Xqhqjo2ybcluXaa6pckX7PWrmus6+n+TUmu6e4vVNWzk7wmyZPXeb3LsvJRZ447QY8HmNd29m09GzjabVvozspZ9c9099mzK6vqmCQ3T4vXZ2Ve4Kkzu5ya5M4k6e5Pzaz/jSQvXe/FuvuKrJyhycmnfVOvtx8A69q2vq1nA0e7bZte0t2fTfLhqnpaktSKs7r7vpkv2Lyou+9K8rmqOnf69vslSd44HfOwmaf8viTv3676AZaNvg2wdUZeMvCaJDcmeXRVHaiqS5M8I8mlVXVrkn1JLljn8OckuTLJ/iQfTPKWaf1PVdW+6fifSvKsUfUDLBt9G2CcYdNLuvvidTYd9nJU3X1Tkseusf6FSV54hKUBsAZ9G2Acv0gJAACDCd0AADCY0A0AAIMJ3QAAMJjQDQAAgwndAAAwmNANAACDCd0AADCY0A0AAIMJ3QAAMJjQDQAAgwndAAAw2IZCd1U9ZHQhAGwNPRtg8Wz0TPe7q+raqjq/qmpoRQAcKT0bYMFsNHR/Y5Irkvxwkv1V9YtV9Y3jygLgCOjZAAtmQ6G7V/xhd1+c5EeTPDPJn1fVO6vq3w6tEIC56NkAi2fPRnaqqhOT/FBWzpp8LMlPJrk+ydlJrk3yiEH1ATAnPRtg8WwodCe5Mclrk1zY3Qdm1t9UVa/e+rIAOAJ6NsCC2eic7p/r7v8+27yr6mlJ0t0vHVIZAJulZwMsmI2G7hesse6FW1kIAFtGzwZYMIecXlJVT0lyfpJTqurlM5uOT/KlkYUBMB89G2BxHW5O951Jbk7yfdP9QZ9
L8rxRRQGwKXo2wII6ZOju7luT3FpVv9XdzpIALDA9G2BxHW56yW1Jenp8v+3d/fgxZQEwLz0bYHEdbnrJ925LFQBsBT0bYEEdbnrJR7arEACOjJ4NsLg2dMnAqjq3qt5TVfdU1Rer6r6q+uzo4gCYn54NsHg2ep3uVya5OMlfJ3lQkh9N8opRRQFwRPRsgAWz0Z+BT3fvr6pjuvu+JP+nqv50YF0AHAE9G2CxbDR031tVD0hyS1X9cpK7kjx4XFmwe9z9sU/udAm7yh+9znhtAz0b1qFnz0/f3hobnV7yw9O+z03y+SSnJbloVFEAHBE9G2DBbOhM98w34v8xyS+MKweAI6VnAyyeDYXuqvr2JC9J8vDZY7r7kWPKAmCz9GyAxbPROd2/meR5SW5Oct+4cgDYAno2wILZaOi+u7vfMrQSALaKng2wYDYaut9RVb+S5A1JvnBwZXf/xZCqADgSejbAgtlo6P7W6f6cmXWd5MlbWw4AW0DPBlgwG716yXeNLgSAraFnAyyeDV2nu6pOrqrfrKq3TMtnVtWlY0sDYDP0bIDFs9Efx9mb5G1Jvm5aviPJfxlQDwBHbm/0bICFstHQfVJ3vy7Jl5Oku78Ul6ECWFR6NsCC2Wjo/nxVnZiVL+Kkqs5NcvewqgA4Eno2wILZ6NVLfjrJ9UkeWVXvSvLQJD8wrCoAjoSeDbBgNhq6/zLJdUnuTfK5JL+XlTmCACwePRtgwWx0esnVSc5I8otJXpHk9CSvHVUUAEdEzwZYMBs90/3o7j5rZvkdVXXriIIAOGJ6NsCC2eiZ7vdOX8RJklTVtyZ515iSADhCejbAgpnnZ+Avqaq/nZa/Psn7q+q2JN3djx9SHQCboWcDLJiNhu7zhlYBwFbSswEWzIZCd3d/ZHQhAGwNPRtg8Wx0TjcAALBJQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGDDQndVXVVVH6+q2zdx7BOr6raq2l9VL6+qmtb/WlXdMt3uqKrPbHnhAEtK3wYYZ+SZ7r1JztvksZcnuSzJ6dPtvCTp7ud199ndfXaSVyR5w5GXCcBkb/RtgCGGhe7uviHJp2fXVdWjquqtVXVzVf1JVZ2x+riqeliS47v7xu7uJFcnuXCNl7g4yTUDSgdYSvo2wDh7tvn1rkjy7O7+66r61iSvSvLkVfuckuTAzPKBad0/q6qHJ3lEkv83sFYA9G2ALbFtobuqjk3ybUmunab6JcnXrLXrGut61fLTk7y+u+87xOtdlpWPOnPcCafNXS/AstvOvq1nA0e77TzT/VVJPjPN6/tnVXVMkpunxeuzMi/w1JldTk1y56rnenqSnzjUi3X3FVk5Q5OTT/um1c0fgMPbtr6tZwNHu227ZGB3fzbJh6vqaUlSK87q7vsOfsmmu1/U3Xcl+VxVnTt9+/2SJG88+DxV9egkJyS5cbtqB1hG+jbA1hl5ycBrstJgH11VB6rq0iTPSHJpVd2aZF+SC9Y5/DlJrkyyP8kHk7xlZtvFSX57+rIOAFtE3wYYZ9j0ku6+eJ1Nh70cVXfflOSx62x7yRGUBcA69G2AcfwiJQAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAy2Z6cLAJbLvzr5pJ0uAYA56Ntbw5luAAAYTOgGAIDBhG4AABhM6AYAgMGEbgAAGEzoBgCAwYRuAAAYTOgGAIDBhG4AABhM6AYAgMGEbgAAGEzoBgCAwYRuAAAYTOgGAIDBhG4AABhM6AYAgMGEbgAAGEzoBgCAwYRuAAAYTOgGAIDBhG4AABhM6AYAgMG
EbgAAGEzoBgCAwYRuAAAYTOgGAIDBhG4AABhM6AYAgMGEbgAAGGxY6K6qq6rq41V1+yaOfWJV3VZV+6vq5VVV0/qHV9Xbq+p9VfXHVXXq1lcOsJz0bYBxRp7p3pvkvE0ee3mSy5KcPt0OPs+vJrm6ux+f5L8l+aUjrBGAr9gbfRtgiGGhu7tvSPLp2XVV9aiqemtV3VxVf1JVZ6w+rqoeluT47r6xuzvJ1UkunDafmeTt0+N3JLlgVP0Ay0bfBhhnu+d0X5HkJ7v7iUmen+RVa+xzSpIDM8sHpnVJcmuSi6bHT01yXFWdOKhWAPRtgC2xZ7teqKqOTfJtSa6dpvolydestesa63q6f36SV1bVs5LckOSjSb60zutdlpWPOnPcCadtum6AZbWdfVvPBo522xa6s3JW/TPdffbsyqo6JsnN0+L1WZkXOPtFm1OT3Jkk3X1nku+fjjs2yUXdffdaL9bdV2TlDE1OPu2beq19ADikbevbejZwtNu26SXd/dkkH66qpyVJrTiru+/r7rOn24u6+64kn6uqc6dvv1+S5I3TMSdV1cGaX5jkqu2qH2DZ6NsAW2fkJQOvSXJjkkdX1YGqujTJM5JcWlW3JtmX9b9Q85wkVybZn+SDSd4yrX9Skg9U1R1JTk7yP0fVD7Bs9G2AcYZNL+nui9fZdNjLUXX3TUkeu8b61yd5/RGWBsAa9G2AcfwiJQAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBgQjcAAAxW3b3TNQxXVZ9I8pGdrmMNJyX55E4XsYsYr/kYr/ks6ng9vLsfutNFbCc9+6hhvOZnzOazqOO1Zt9eitC9qKrqpu4+Z6fr2C2M13yM13yMF4fjPTIf4zU/Yzaf3TZeppcAAMBgQjcAAAwmdO+sK3a6gF3GeM3HeM3HeHE43iPzMV7zM2bz2VXjZU43AAAM5kw3AAAMJnQPVlVXVdXHq+r2dbZXVb28qvZX1fuq6gnbXeOiqaq/qarbquqWqrppje1LPWZrvaeq6iFV9YdV9dfT/QnrHHteVX1gGrsXbF/VO2ed8XpJVX10eo/dUlXnr3Ps0o3XstOz56dnH5qePb+jtW8L3ePtTXLeIbY/Jcnp0+2yJJdvQ027wXd199nrXApo2cdsb+7/nnpBkrd39+lJ3j4t/wtVdUySX8/K+J2Z5OKqOnNsqQthb9b+N/hr03vs7O5+8+qNSzxey25v9OzN0LPXtzd69rz25ijs20L3YN19Q5JPH2KXC5Jc3Sv+LMnXVtXDtqe6XWupx2yd99QFSV4zPX5NkgvXOPRbkuzv7g919xeT/PZ03FFtA/8G17OU47Xs9OwhlnrM9Oz5Ha19W+jeeack+buZ5QPTumXWSf6gqm6uqsvW2G7M7u/k7r4rSab7f73GPsbtX3ru9FH3Vet8tGu8WIv3xf3p2fPTszdnV/dtoXvn1Rrrlv2SMt/e3U/IysdDP1FV37FquzHbHOP2FZcneVSSs5PcleRla+xjvFiL98X96dljGLd/adf3baF75x1IctrM8qlJ7tyhWhZCd9853X88yXVZ+bholjG7v48d/Lh2uv/4GvsYt0l3f6y77+vuLyf5jdz/PZYYL9bmfbGKnr0pevacjoa+LXTvvOuTXDJ9u/vcJHcf/MhpGVXVg6vquIOPk/zHJKuvImDM7u/6JM+cHj8zyRvX2Oc9SU6vqkdU1QOSPH06bumsmk/61Nz/PZYYL9am/8zQszdNz57T0dC39+x0AUe7qromyZOSnFRVB5K8OMlXJ0l3vzrJm5Ocn2R/knuT/MjOVLowTk5yXVUlK+/P/9vdb62qZyfGLFn
3PfW/kryuqi5N8rdJnjbt+3VJruzu87v7S1X13CRvS3JMkqu6e99O/A3baZ3xelJVnZ2Vjx3/JsmPTfsu/XgtOz17bnr2YejZ8zta+7ZfpAQAgMFMLwEAgMGEbgAAGEzoBgCAwYRuAAAYTOgGAIDBhG44hKq6sqrO3Ok6ADg8PZtF5pKBAAAwmDPdMJl+We33q+rWqrq9qn6wqv64qs6Ztl9aVXdM636jql45rd9bVZdX1Tuq6kNV9Z1VdVVVvb+q9s48/+VVdVNV7auqX9ihPxPgqKBns9sI3fAV5yW5s7vP6u7HJnnrwQ3TL179fJJzk/yHJGesOvaEJE9O8rwkb0rya0kek+Rx0y9oJcnPdvc5SR6f5Dur6vED/xaAo52eza4idMNX3Jbku6vqpVX177r77plt35Lknd396e7+pyTXrjr2Tb0yV+u2JB/r7tu6+8tJ9iX5hmmf/1RVf5HkvVlp7uYdAmyens2usmenC4BF0d13VNUTk5yf5Jeq6g9mNtdhDv/CdP/lmccHl/dU1SOSPD/JN3f3308fYT5wayoHWD56NruNM90wmT6OvLe7fyvJryZ5wszmP8/Kx4snVNWeJBfN+fTHJ/l8krur6uQkT9mKmgGWlZ7NbuNMN3zF45L8SlV9Ock/JXlOVhp5uvujVfWLSd6d5M4kf5nk7vWeaLXuvrWq3puVjy4/lORdW1w7wLLRs9lVXDIQNqiqju3ue6azJtcluaq7r9vpugC4Pz2bRWN6CWzcS6rqliS3J/lwkt/b0WoAOBQ9m4XiTDcAAAzmTDcAAAwmdAMAwGBCNwAADCZ0AwDAYEI3AAAMJnQDAMBg/x9D103bFBoy5gAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, axs = plt.subplots(ncols=2, figsize=(12, 5))\n", "scale = (res_df[\"mean_test_score\"].min(), res_df[\"mean_test_score\"].max())\n", "\n", "c = plot_heatmap(axs[0], res_df[res_df.M == 500], \"sigma\", \"penalty\", \"mean_test_score\", scale)\n", "axs[0].set_title(\"M = 500\")\n", "c = plot_heatmap(axs[1], res_df[res_df.M == 1000], \"sigma\", \"penalty\", \"mean_test_score\", scale)\n", "_ = axs[1].set_title(\"M = 1000\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.9" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 1 }