{ "cells": [ { "cell_type": "markdown", "source": [ "# MNIST Classification with Falkon" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 1, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[pyKeOps]: Warning, no cuda detected. Switching to cpu only.\n" ] } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "plt.style.use('ggplot')\n", "\n", "import torch\n", "import torchvision\n", "\n", "import falkon" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download the MNIST dataset & load it in memory" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "mnist_train_dataset = torchvision.datasets.MNIST(\n", " root=\".\", train=True, download=True,\n", " transform=torchvision.transforms.ToTensor())\n", "mnist_test_dataset = torchvision.datasets.MNIST(\n", " root=\".\", train=False, download=True,\n", " transform=torchvision.transforms.ToTensor())" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Load the whole dataset in memory\n", "mnist_tr_img_list, mnist_tr_label_list = [], []\n", "for i in range(len(mnist_train_dataset)):\n", " data_point = mnist_train_dataset[i]\n", " mnist_tr_img_list.append(data_point[0])\n", " mnist_tr_label_list.append(data_point[1])\n", "mnist_ts_img_list, mnist_ts_label_list = [], []\n", "for i in range(len(mnist_test_dataset)):\n", " data_point = mnist_test_dataset[i]\n", " mnist_ts_img_list.append(data_point[0])\n", " mnist_ts_label_list.append(data_point[1])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "mnist_tr_x = torch.vstack(mnist_tr_img_list)\n", "mnist_tr_x = mnist_tr_x.reshape(mnist_tr_x.shape[0], -1)\n", "mnist_ts_x = torch.vstack(mnist_ts_img_list)\n", "mnist_ts_x = 
mnist_ts_x.reshape(mnist_ts_x.shape[0], -1)\n", "mnist_tr_y = torch.tensor(mnist_tr_label_list)\n", "mnist_ts_y = torch.tensor(mnist_ts_label_list)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUXUlEQVR4nO3dfVRUZ34H8O8MoICEcXgJRpZJAuoqCa4loFnXBA2j61GPIa6auolWkx5LwCRKYpbYHuk22p1snEWTxXi6TTSxm1ZMV5Km2+Q42gxbjYKlxrO+g1pfFkFgQIyQwMztH/ZcHPU+g8O8XH2+n7/und+9d37O+OW+PDNzDYqiKCCiu54x3A0QUWgw7ESSYNiJJMGwE0mCYSeSBMNOJInIgax88OBBbN68GR6PB/n5+SgoKPC5zlTjPHW6osaG4vGlA2khaPTam177AtibvwLZ207Pds2a33t2j8eD9957D6tWrUJ5eTn27NmD8+fP+7s5Igoyv8NeX1+PYcOGISUlBZGRkZg4cSJqa2sD2RsRBZDB30/Q7du3DwcPHkRhYSEAoLq6GidPnsTzzz/vtZzD4YDD4QAA2Gw2nDjQoNYsY1Jx9ugFf3sPKr32pte+APbmr0D2NionQ7Pm9zn7rf5GGAyGmx6zWq2wWq3q/PXnJrKcRwWSXvsC2Ju/dH/OnpiYiNbWVnW+tbUVZrPZ380RUZD5HfaMjAw0NjaiubkZvb292Lt3L3JycgLZGxEFkN+H8REREXjuueewdu1aeDweTJkyBWlpaYHsjYgCaEDj7NnZ2cjOzg5UL0QURPwEHZEkGHYiSTDsRJJg2IkkwbATSYJhJ5IEw04kCYadSBIMO5EkGHYiSTDsRJJg2IkkwbATSYJhJ5IEw04kCYadSBIMO5EkGHYiSTDsRJJg2IkkwbATSYJhJ5IEw04kCYadSBIMO5EkGHYiSTDsRJJg2IkkwbATSWJAd3El/TNEit/iiOSkwD1ZVBQi7xvm9dDxVx/QXNwd6xFu7v6MZmE9tsggrF/81SB1umdENC59+n11vi5nm3DdFvc3wvqE7a8I6yNK9gnr4TCgsBcXFyM6OhpGoxERERGw2WyB6ouIAmzAe/aysjLEx8cHohciCiKesxNJwqAoiuLvysXFxYiLiwMATJ06FVar9aZlHA4HHA4HAMBms+HEgQa1ZhmTirNHL/j79EGl195uuy+D+LwWPs7pb4dlZArOnmzyeuzbYYM0lgYUH7uaQYN7hHXj/4r/bT1pffX0uGScunJJnc8c0iZct1cRX0846koR1qPPic/5rxfI/2ujcjI0awMKe1tbGxISEtDR0YE1a9ZgyZIlyMzMFK4z1ThPna6osaF4fKm/Tx9Ueu3tdvsK5QW6Db9fiZdnvOX1mF4u0P120gt45r/eVef1dIEukP/Xdnq2a9YGdBifkJAAADCZTMjNzUV9ff1ANkdEQeR32Lu7u9HV1aVOHzp0CBaLJWCNEVFg+X3C1tHRgXXr1gEA3G43Jk2ahHHjxgWqr7tKxJiRwroyOEpY/1PeUHW6JyUOF1+e6FXvelT7kDPBJD4c/cMPxIeztyMqcSmqDvx7wLb3H1fvEdbf/PV0YX1/1kfqdFTs09j/SN/86Z4u4bq2p
qnC+vA/+H32GzZ+hz0lJQVvvfWW7wWJSBc49EYkCYadSBIMO5EkGHYiSTDsRJLgV1wDwD05W1j/1ZYKYX1UlPZHSm8UlWjFgdfe6ffyetajuIX11e8sFtYjvxEPf/1w+zJ1+p/+OhnPru2bv+dCr3DdwS3iobnYA/uFdT3inp1IEgw7kSQYdiJJMOxEkmDYiSTBsBNJgmEnkgTH2QNg8PE/Cev/3Z0mrI+KahLWw+mVxkeF9VNX+n7pZkOcCS+fnOVV35Lxsea6HR7xOHnK23v70WH/RP7Vk0h876t+L3/nfYHVN+7ZiSTBsBNJgmEnkgTDTiQJhp1IEgw7kSQYdiJJcJw9AHobLwrr77w5T1hfO138c88Rh+LU6Y8L78XcTS961b8u8v/77Wtaxgrr9dZYYd3d3qhOKzU96Jnc6FX/6Q+LNNc985K4twfxtXgBui3csxNJgmEnkgTDTiQJhp1IEgw7kSQYdiJJMOxEkuA4ewgkbBZ/jzr53xKFdXdrmzo96MnZSPt77+099PBzmusefvx94bY//Yc8Yf3e9oF9p9zwlfZY+YP9/3o5BYDPsG/cuBF1dXUwmUyw2+0AgCtXrqC8vByXLl1CcnIyVqxYgbi4OB9bIqJw8nkYP3nyZKxatcrrsaqqKmRlZeHtt99GVlYWqqqqgtUfEQWIz7BnZmbetNeura1FXt61w7+8vDzU1tYGpzsiChi/ztk7OjpgNpsBAGazGZcvX9Zc1uFwwOFwAABsNhsqamxqzTIm1WteT0LaW6SPt6G3775kljGpqNj/C6/yt+kxmqtGxT0p3PRHK+8V1qP+Yra4t+vw/fRPqHoL+gU6q9UKq9WqzhePL1WnK2psXvN6EsreIpL6f4GuYv8vUDzhda/66Y+0v8zi6wLdT996UVi/t6L/F+j4fvonkL3t9GzXrPk19GYymeByuQAALpcL8fHx/nVGRCHjV9hzcnLgdDoBAE6nE7m5uQFtiogCz+dh/Pr163HkyBF0dnaisLAQ8+fPR0FBAcrLy7F7924kJSWhpKQkFL3etdwtrbe3guL9q+Y9l/t/f/cbPfTMEWH90rsR4g14xPdYJ/3wGfbly5ff8vHVq1cHuhciCiJ+XJZIEgw7kSQYdiJJMOxEkmDYiSTBr7jeBcb87IRmbUlWvnDdzffvEtbz5hUL6/ds2yesk35wz04kCYadSBIMO5EkGHYiSTDsRJJg2IkkwbATSYLj7HcBd3uHZq31hTHCdc9+2iWsl675UFh/ff5T6vR3GTE4/68PedWV/zFprpu21sdvSd/wVV4aGO7ZiSTBsBNJgmEnkgTDTiQJhp1IEgw7kSQYdiJJcJz9Luf5+qiw/uc/Xyms/7ZsnbB+8NG+cfioIT/xmgcAPKq97kNDlgm3PfI3jcJ676kzwjp5456dSBIMO5EkGHYiSTDsRJJg2IkkwbATSYJhJ5IEx9kll/C++Dvly46Lfzc+3nZenf7lPfF47dSPver/nP6F5rqHF/1auO3RaX8prH//5+J9lfvkKWFdNj7DvnHjRtTV1cFkMsFutwMAKisrsWvXLsTHxwMAFixYgOzs7OB2SkQD4jPskydPxvTp01FRUeH1+MyZMzF79uygNUZEgeXznD0zMxNxcXGh6IWIgsigKL5/6Ku5uRlvvvmm12G80+lETEwM0tPTsWjRIs0/CA6HAw6HAwBgs9lw4kCDWrOMScXZoxcC8e8IOL32Fuq+lLhYYT3ie9+p09+LGYbzXRe96umDLvv93H/sTBLWoxvd4g10f6tO6vX9BALb26icDM2aX2Fvb29Xz9e3bdsGl8uFoqKifjUz1ThPna6osaF4fGm/1gs1vfYW6r6UH40T1r0u0P2gFK99bfOqiy7Q+TL6P31doNP+oU3A+wKdXt9PILC97fRs16z5NfQ2dOhQGI1GGI1G5Ofno6GhwfdKRBRWfoXd5XKp0zU1NUhLSwtYQ0QUHD6vxq9fvx5HjhxBZ2cnCgsLMX/+f
Bw+fBhnzpyBwWBAcnIyli5dGopeKQwMew4K61fn3qtOez4Hrs713n/kPv2i5rr7f7ZBuO1jU/5RWH/mgWnCesckYVk6PsO+fPnymx574okngtELEQURPy5LJAmGnUgSDDuRJBh2Ikkw7ESS4FdcaUDcTc19Mz293vMAUt5uhpbu13qF2441DBLWf/PAZ8L6rKeWq9OeoUNw9akJfdvesV+47t2Ie3YiSTDsRJJg2IkkwbATSYJhJ5IEw04kCYadSBIcZychz6RxwnrDvGh1ujttCE5u8L5H88Pjzmiu62sc3Zd32v5MWI/95IA6bXz9Ka95GXHPTiQJhp1IEgw7kSQYdiJJMOxEkmDYiSTBsBNJguPsdzlDzsPC+omXfHxn/EcfCOuPR/fd/inKPA3H51YIlr493yo9wvq+tgfFG/A03jDv43ZRdznu2YkkwbATSYJhJ5IEw04kCYadSBIMO5EkGHYiSXCc/Q4Q+eD9fTODB3nPA2hYMlxz3b99+l+E2/5JXMuAehuIVU05wrrzhu/G38j8wVeBbOeu5zPsLS0tqKioQHt7OwwGA6xWK2bMmIErV66gvLwcly5dQnJyMlasWIG4uLhQ9ExEfvAZ9oiICCxcuBDp6eno6upCaWkpxo4diy+//BJZWVkoKChAVVUVqqqq8Oyzz4aiZyLyg89zdrPZjPT0dABATEwMUlNT0dbWhtraWuTl5QEA8vLyUFtbG9xOiWhADIqiKP1duLm5GWVlZbDb7SgqKsKWLVvU2pIlS7B58+ab1nE4HHA4HAAAm82GEwca1JplTCrOHr0wgPaDR1e9De77/LolIwVnG5q8yt8mRmmuOjzBJdy02Si+39rtMERmQOlt8L3g/7vQGyusdzYNEdYjW7/p93Pp6v28QSB7G5WToVnr9wW67u5u2O12LF68GLGx4jfpelarFVarVZ0vHl+qTlfU2Lzm9URPvV1/QW5D1ct4uWCDV10vF+iiEj9BT+uT/V5+fQgv0Onp/bxRIHvb6dmuWevX0Ftvby/sdjsee+wxTJhw7U6YJpMJLte1vYbL5UJ8fHwAWiWiYPG5Z1cUBZs2bUJqaipmzZqlPp6TkwOn04mCggI4nU7k5uYGtdE7WeQDFmG945H7hPWn/+5zddr0wFVM/+ygV71w6O/87m2gXmns2/u+Fj8Ev2z03ht/tVF7752wpUa4bbOHQ2uB5DPsx48fR3V1NSwWC1auXAkAWLBgAQoKClBeXo7du3cjKSkJJSUlQW+WiPznM+yjR49GZWXlLWurV68OeENEFBz8uCyRJBh2Ikkw7ESSYNiJJMGwE0mCX3Htp8j7hmnW2t4Xf6zzhQedwvqCe5qE9etFRXyHwqGn+r28L8suTBLW694dJ6wnffxHdbr7SyNOTon2qid0cqxcL7hnJ5IEw04kCYadSBIMO5EkGHYiSTDsRJJg2IkkIc04+3c/Fv8qyncr2rzm3SMH48rn6er8qhG/11x3Wkz/fx4pGJrcXZq1xz99Rbju6L85JqwntIvHyT3Xz7jd8HR2Cpen8OGenUgSDDuRJBh2Ikkw7ESSYNiJJMGwE0mCYSeShDTj7GcKxH/XTmR530kjKuZZfJmlfXeN21HRrn1LHgDY4JwmrBvcBnX603nJmP27Iq/66DWnNdcd2bRfuG23sEp3E+7ZiSTBsBNJgmEnkgTDTiQJhp1IEgw7kSQYdiJJ+Bxnb2lpQUVFBdrb22EwGGC1WjFjxgxUVlZi165diI+PB3DtNs7Z2dlBb9hfo14Q3wt81guPeM1X1MSiePwjGksH1iiIe7te9ISnMPJF77FzjpVTf/gMe0REBBYuXIj09HR0dXWhtLQUY8eOBQDMnDkTs2fPDnqTRDRwPsNuNpthNpsBADExMUhNTUVbW5uPtYhIbwyKoij9Xbi5uRllZWWw2+347LPP4HQ6ERMTg/T0dCxatAhxcXE3reNwOOBwOAAANpsNJw40qDXLmFScPXohAP+MwNNrb
3rtC2Bv/gpkb6NytD+a3e+wd3d3o6ysDHPmzMGECRPQ3t6unq9v27YNLpcLRUVFPrYCTDXOU6cramwoHl/an6cPOb32pte+APbmr0D2ttOj/X2Ofl2N7+3thd1ux2OPPYYJEyYAAIYOHQqj0Qij0Yj8/Hw0NDT42AoRhZPPsCuKgk2bNiE1NRWzZs1SH3e5XOp0TU0N0tLSgtMhEQWEzwt0x48fR3V1NSwWC1auXAng2jDbnj17cObMGRgMBiQnJ2Pp0qVBb5aI/Ocz7KNHj0ZlZeVNj+t5TJ2IbsZP0BFJgmEnkgTDTiQJhp1IEgw7kSQYdiJJMOxEkmDYiSTBsBNJgmEnkgTDTiQJhp1IEgw7kSQYdiJJ3NZv0BHRnSuse/bSUn3+Jhig39702hfA3vwVqt54GE8kCYadSBJhDbvVag3n0wvptTe99gWwN3+FqjdeoCOSBA/jiSTBsBNJwudPSQfDwYMHsXnzZng8HuTn56OgoCAcbdxScXExoqOjYTQaERERAZvNFrZeNm7ciLq6OphMJtjtdgDAlStXUF5ejkuXLiE5ORkrVqy45T32wtGbXm7jrXWb8XC/dmG//bkSYm63W1m2bJly8eJFpaenR3n11VeVc+fOhboNTUVFRUpHR0e421AURVEOHz6sNDQ0KCUlJepjW7duVXbs2KEoiqLs2LFD2bp1q25627Ztm/LJJ5+EpZ/rtbW1KQ0NDYqiKMrVq1eVl156STl37lzYXzutvkL1uoX8ML6+vh7Dhg1DSkoKIiMjMXHiRNTW1oa6jTtCZmbmTXue2tpa5OXlAQDy8vLC9trdqje9MJvNSE9PB+B9m/Fwv3ZafYVKyA/j29rakJiYqM4nJibi5MmToW5DaO3atQCAqVOn6m7IpqOjA2azGcC1/zyXL18Oc0fevvjiC1RXVwtv4x1Kzc3NOH36NEaMGKGr1+76vo4dOxaS1y3kYVduMdJnMBhC3YamN954AwkJCejo6MCaNWswfPhwZGZmhrutO8K0adMwd+5cANdu4/3hhx/26zbewdLd3Q273Y7FixcjNjY2bH3c6Ma+QvW6hfwwPjExEa2trep8a2ur+tdWDxISEgAAJpMJubm5qK+vD3NH3kwmk3oHXZfLpV7U0QM93cb7VrcZ18NrF87bn4c87BkZGWhsbERzczN6e3uxd+9e5OTkhLqNW+ru7kZXV5c6fejQIVgsljB35S0nJwdOpxMA4HQ6kZubG+aO+ujlNt6Kxm3Gw/3aafUVqtctLJ+gq6urwwcffACPx4MpU6Zgzpw5oW7hlpqamrBu3ToAgNvtxqRJk8La2/r163HkyBF0dnbCZDJh/vz5yM3NRXl5OVpaWpCUlISSkpKwnBffqrfDhw/fdBvvcBy1HTt2DKtXr4bFYlFPERcsWICRI0eG9bXT6utWtz8PxuvGj8sSSYKfoCOSBMNOJAmGnUgSDDuRJBh2Ikkw7ESSYNiJJPF/h2UBXzw4klYAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots()\n", "ax.imshow(mnist_tr_x[0].reshape(28,28))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preprocessing\n", "\n", "We convert the labels to their one-hot representation. \n", "This is the best way to run multi-class classification with Falkon, which minimizes the squared error." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# one-hot labels\n", "A = torch.eye(10, dtype=torch.float32)\n", "mnist_tr_y = A[mnist_tr_y.to(torch.long), :]\n", "mnist_ts_y = A[mnist_ts_y.to(torch.long), :]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def classif_error(y_true, y_pred):\n", " y_true = torch.argmax(y_true, dim=1)\n", " y_pred = torch.argmax(y_pred, dim=1)\n", " err = y_true.flatten() != y_pred.flatten()\n", " return torch.mean(err.to(torch.float32))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run Falkon\n", "\n", "There are a few parameters which need to be provided to the algorithm:\n", "\n", " - The `FalkonOptions` class is used to provide non-standard tuning knobs. It allows you to, for example, tune the amount of GPU memory the algorithm can use, adjust the convergence tolerance, and decide whether certain parts of the algorithm are computed on CPU or GPU. \n", " \n", " It can be used with default parameters for most purposes!\n", " \n", " - The **kernel** is the most important choice, and it depends on the data at hand. We use the `GaussianKernel`, which is the most common option, and initialize it with a length-scale of 15.\n", " \n", " - The **penalty** determines the amount of regularization. A higher value corresponds to more regularization.\n", " \n", " - The **number of centers** `M` strongly influences the time needed for fitting. By default, the centers\n", " are chosen uniformly at random." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Force CPU execution and report the classification error after every iteration.\n", "options = falkon.FalkonOptions(use_cpu=True)\n", "kernel = falkon.kernels.GaussianKernel(sigma=15)\n", "flk = falkon.Falkon(kernel=kernel, \n", " penalty=1e-8,\n", " M=1000, \n", " maxiter=10,\n", " options=options,\n", " error_every=1,\n", " error_fn=classif_error)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 1 - Elapsed 0.55s - training error: 0.11998333\n", "Iteration 2 - Elapsed 1.02s - training error: 0.07140000\n", "Iteration 3 - Elapsed 1.50s - training error: 0.05766667\n", "Iteration 4 - Elapsed 1.98s - training error: 0.05121667\n", "Iteration 5 - Elapsed 2.46s - training error: 0.04776667\n", "Iteration 6 - Elapsed 2.95s - training error: 0.04556667\n", "Iteration 7 - Elapsed 3.45s - training error: 0.04376667\n", "Iteration 8 - Elapsed 3.93s - training error: 0.04340000\n", "Iteration 9 - Elapsed 4.42s - training error: 0.04286667\n", "Iteration 10 - Elapsed 5.39s - training error: 0.04223333\n" ] } ], "source": [ "_ = flk.fit(mnist_tr_x, mnist_tr_y)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training error: 4.22%\n", "Test error: 4.13%\n" ] } ], "source": [ "train_pred = flk.predict(mnist_tr_x)\n", "test_pred = flk.predict(mnist_ts_x)\n", "\n", "# classif_error is declared as (y_true, y_pred): pass the one-hot labels first.\n", "print(\"Training error: %.2f%%\" % (classif_error(mnist_tr_y, train_pred) * 100))\n", "print(\"Test error: %.2f%%\" % (classif_error(mnist_ts_y, test_pred) * 100))\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.9" }, "toc": { 
"base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }