{ "cells": [ { "cell_type": "markdown", "id": "be9efd3e", "metadata": {}, "source": [ "# Implementing A Custom Kernel\n", "\n", "In this notebook we will show how to implement a custom kernel in Falkon.\n", "\n", "There are several complementary parts to a kernel, which can be added to support different operations.\n", "We will go through them one-by-one in this notebook:\n", "\n", " - Basic support: supports learning with Falkon!\n", " - Autodiff support: supports automatic hyperparameter tuning (in the `hopt` module)\n", " - KeOps support: faster kernel-vector products in low dimension\n", " - Sparse support: support learning on sparse data." ] }, { "cell_type": "code", "execution_count": 1, "id": "cc7a4428", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[pyKeOps]: Warning, no cuda detected. Switching to cpu only.\n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "plt.style.use('ggplot')\n", "\n", "from sklearn import datasets\n", "import torch\n", "import numpy as np\n", "\n", "import falkon\n", "from falkon import FalkonOptions\n", "from falkon.kernels import Kernel, DiffKernel, KeopsKernelMixin" ] }, { "cell_type": "markdown", "id": "6b37d1d0", "metadata": { "heading_collapsed": true }, "source": [ "## Setup a simple problem for testing\n", "\n", "Load and preprocess the *California housing* dataset. The `learn_with_kernel` function sets up Falkon for learning on the California housing dataset with a given kernel." 
] }, { "cell_type": "code", "execution_count": 2, "id": "be73c315", "metadata": { "hidden": true }, "outputs": [], "source": [ "X, Y = datasets.fetch_california_housing(return_X_y=True)\n", "num_train = int(X.shape[0] * 0.8)\n", "num_test = X.shape[0] - num_train\n", "shuffle_idx = np.arange(X.shape[0])\n", "np.random.shuffle(shuffle_idx)\n", "train_idx = shuffle_idx[:num_train]\n", "test_idx = shuffle_idx[num_train:]\n", "\n", "Xtrain, Ytrain = X[train_idx], Y[train_idx]\n", "Xtest, Ytest = X[test_idx], Y[test_idx]\n", "# convert numpy -> pytorch\n", "Xtrain = torch.from_numpy(Xtrain).to(dtype=torch.float32)\n", "Xtest = torch.from_numpy(Xtest).to(dtype=torch.float32)\n", "Ytrain = torch.from_numpy(Ytrain).to(dtype=torch.float32)\n", "Ytest = torch.from_numpy(Ytest).to(dtype=torch.float32)\n", "# z-score normalization\n", "train_mean = Xtrain.mean(0, keepdim=True)\n", "train_std = Xtrain.std(0, keepdim=True)\n", "Xtrain -= train_mean\n", "Xtrain /= train_std\n", "Xtest -= train_mean\n", "Xtest /= train_std" ] }, { "cell_type": "code", "execution_count": 3, "id": "a58a7c95", "metadata": { "hidden": true }, "outputs": [], "source": [ "def rmse(true, pred):\n", "    return torch.sqrt(torch.mean((true.reshape(-1, 1) - pred.reshape(-1, 1))**2))\n", "\n", "def learn_with_kernel(kernel):\n", "    flk_opt = FalkonOptions(use_cpu=True)\n", "    model = falkon.Falkon(\n", "        kernel=kernel, penalty=1e-5, M=1000, options=flk_opt,\n", "        error_every=1, error_fn=rmse)\n", "    model.fit(Xtrain, Ytrain)\n", "    ts_err = rmse(Ytest, model.predict(Xtest))\n", "    print(\"Test RMSE: %.2f\" % (ts_err))" ] }, { "cell_type": "markdown", "id": "1d046959", "metadata": {}, "source": [ "## Basic Kernel Implementation\n", "\n", "We must inherit from the `falkon.kernels.Kernel` class, and implement:\n", " - `compute` method: the core of the kernel implementation. 
\n", "   Given two input matrices (of size $n\\times d$ and $m\\times d$), and an output matrix (of size $n\\times m$), compute the kernel function between the two inputs and store it in the output.\n", "   \n", "   The additional `diag` parameter is a boolean flag. When it is set, a) $n$ is equal to $m$, and b) only the diagonal of the kernel matrix should be computed.\n", " - `compute_sparse` method: this should be used if you want your kernel to support sparse data. \n", "   We will implement it in a later section.\n", " \n", "We will implement a **linear** kernel:\n", "$$k(x, x') = \\sigma (x^\\top x')$$\n", "where the parameter $\\sigma$ is the *variance* of the kernel. It is the only hyperparameter, and it is called `lengthscale` in the code below." ] }, { "cell_type": "code", "execution_count": 4, "id": "bfd90b30", "metadata": {}, "outputs": [], "source": [ "class BasicLinearKernel(Kernel):\n", "    def __init__(self, lengthscale, options):\n", "        # The base class takes as inputs a name for the kernel, and\n", "        # an instance of `FalkonOptions`.\n", "        super().__init__(\"basic_linear\", options)\n", "\n", "        self.lengthscale = lengthscale\n", "\n", "    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool) -> torch.Tensor:\n", "        # To support different devices/data types, you must make sure\n", "        # the lengthscale is compatible with the data.\n", "        lengthscale = self.lengthscale.to(device=X1.device, dtype=X1.dtype)\n", "\n", "        scaled_X1 = X1 * lengthscale\n", "\n", "        if diag:\n", "            out.copy_(torch.sum(scaled_X1 * X2, dim=-1))\n", "        else:\n", "            # The dot-product row-by-row on `X1` and `X2` can be computed\n", "            # on many rows at a time with matrix multiplication.\n", "            out = torch.matmul(scaled_X1, X2.T, out=out)\n", "\n", "        return out\n", "\n", "    def compute_sparse(self, X1, X2, out, diag, **kwargs) -> torch.Tensor:\n", "        raise NotImplementedError(\"Sparse not implemented\")" ] }, { "cell_type": "markdown", "id": "f19f3e95", "metadata": {}, "source": [ "### Test the basic kernel" ] }, { 
"cell_type": "code", "execution_count": 5, "id": "7bfcd1be", "metadata": {}, "outputs": [], "source": [ "# Initialize the kernel\n", "lengthscale_init = torch.tensor([1.0])\n", "k = BasicLinearKernel(lengthscale_init, options=falkon.FalkonOptions())" ] }, { "cell_type": "code", "execution_count": 6, "id": "32b86480", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.3538, 4.0383, -0.5058, -3.1306, -0.3159],\n", " [-0.9498, -2.0581, 0.4684, 0.8994, 0.7577],\n", " [ 0.3122, -0.1038, -0.5039, 2.5076, -0.4032],\n", " [ 0.8383, 3.8545, -1.4094, 1.0497, -1.4979],\n", " [ 0.8344, -4.5258, 2.9362, -7.7300, 2.0740]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The kernel matrix\n", "k(torch.randn(5, 3), torch.randn(5, 3))" ] }, { "cell_type": "code", "execution_count": 7, "id": "a3134217", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[6.1084],\n", " [3.6743],\n", " [1.2653],\n", " [1.2448]])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Kernel-vector product\n", "k.mmv(torch.randn(4, 3), torch.randn(4, 3), v=torch.randn(4, 1))" ] }, { "cell_type": "code", "execution_count": 8, "id": "53521733", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ -3.6467],\n", " [ -9.8628],\n", " [ 1.4857],\n", " [-12.8557]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Double kernel-vector product\n", "k.dmmv(torch.randn(3, 3), torch.randn(4, 3), v=torch.randn(4, 1), w=torch.randn(3, 1))" ] }, { "cell_type": "code", "execution_count": 9, "id": "31837bb2", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 1 - Elapsed 0.07s - training error: 2.36367178\n", "Iteration 2 - Elapsed 0.11s - training error: 2.19508219\n", "Iteration 3 - Elapsed 0.14s - training error: 2.19265079\n", "Iteration 4 - Elapsed 0.17s - training error: 
2.19265032\n", "Iteration 5 - Elapsed 0.20s - training error: 2.19262338\n", "Iteration 6 - Elapsed 0.24s - training error: 2.19262123\n", "Iteration 7 - Elapsed 0.27s - training error: 2.19261861\n", "Iteration 8 - Elapsed 0.30s - training error: 2.19261885\n", "Iteration 9 - Elapsed 0.33s - training error: 2.19261789\n", "Iteration 10 - Elapsed 0.39s - training error: 2.19261765\n", "Iteration 11 - Elapsed 0.42s - training error: 2.19261956\n", "Iteration 12 - Elapsed 0.45s - training error: 2.19261932\n", "Iteration 13 - Elapsed 0.48s - training error: 2.19261909\n", "Iteration 14 - Elapsed 0.51s - training error: 2.19261813\n", "Iteration 15 - Elapsed 0.55s - training error: 2.19261885\n", "Iteration 16 - Elapsed 0.57s - training error: 2.19261742\n", "Iteration 17 - Elapsed 0.61s - training error: 2.19261813\n", "Iteration 18 - Elapsed 0.63s - training error: 2.19261980\n", "Iteration 19 - Elapsed 0.66s - training error: 2.19261956\n", "Iteration 20 - Elapsed 0.73s - training error: 2.19262052\n", "Test RMSE: 2.19\n" ] } ], "source": [ "# Learning on the california housing dataset\n", "learn_with_kernel(k)" ] }, { "cell_type": "markdown", "id": "d1a4b946", "metadata": { "heading_collapsed": true }, "source": [ "## Differentiable Kernel\n", "\n", "A differentiable kernel is needed for automatic hyperparameter optimization (see the [notebook](hyperopt.ipynb)).\n", "\n", "It requires inheriting from `falkon.kernels.DiffKernel`. In addition to the methods already discussed, we must implement:\n", " - `compute_diff`, which works similarly to the `compute` method but it does not have an `out` parameter. 
The implementation should be fully differentiable with respect to its inputs, and to the kernel hyperparameters.\n", " - `detach`, which essentially clones the kernel with the parameters *detached* from the computational graph.\n", " \n", "Another important difference from the basic kernel is the call to the *constructor*, which must include\n", " - All kernel hyperparameters as keyword arguments. These will be available as attributes on the class. Hyperparameters do not need to be tensors.\n", "\n", "**`core_fn` parameter (optional)**\n", "\n", "The constructor can also *optionally* contain a `core_fn` parameter which can simplify implementation by uniting the `compute` and `compute_diff` implementations. Have a look at the implementation of kernels in `falkon.kernels.dot_prod_kernel.py` and `falkon.kernels.distance_kernel.py` for how to use the `core_fn` parameter." ] }, { "cell_type": "code", "execution_count": 10, "id": "13c002e2", "metadata": { "hidden": true }, "outputs": [], "source": [ "class DiffLinearKernel(DiffKernel):\n", "    def __init__(self, lengthscale, options):\n", "        # Super-class constructor call. 
We do not specify core_fn\n", "        # but we must specify the hyperparameter of this kernel (lengthscale)\n", "        super().__init__(\"diff_linear\", \n", "                         options, \n", "                         core_fn=None, \n", "                         lengthscale=lengthscale)\n", "\n", "    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool):\n", "        lengthscale = self.lengthscale.to(device=X1.device, dtype=X1.dtype)\n", "        scaled_X1 = X1 * lengthscale\n", "        if diag:\n", "            out.copy_(torch.sum(scaled_X1 * X2, dim=-1))\n", "        else:\n", "            out = torch.matmul(scaled_X1, X2.T, out=out)\n", "\n", "        return out\n", "\n", "    def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool):\n", "        # The implementation here is similar to `compute` without in-place operations.\n", "        lengthscale = self.lengthscale.to(device=X1.device, dtype=X1.dtype)\n", "        scaled_X1 = X1 * lengthscale\n", "\n", "        if diag:\n", "            return torch.sum(scaled_X1 * X2, dim=-1)\n", "\n", "        return torch.matmul(scaled_X1, X2.T)\n", "\n", "    def detach(self):\n", "        # Clones the class with detached hyperparameters\n", "        return DiffLinearKernel(\n", "            lengthscale=self.lengthscale.detach(), \n", "            options=self.params\n", "        )\n", "\n", "    def compute_sparse(self, X1, X2, out, diag, **kwargs) -> torch.Tensor:\n", "        raise NotImplementedError(\"Sparse not implemented\")" ] }, { "cell_type": "markdown", "id": "a7dd73bf", "metadata": { "hidden": true }, "source": [ "### Test the differentiable kernel" ] }, { "cell_type": "code", "execution_count": 11, "id": "0f223c67", "metadata": { "hidden": true }, "outputs": [], "source": [ "# Initialize the kernel, with a lengthscale which requires grad.\n", "lengthscale_init = torch.tensor([1.0]).requires_grad_()\n", "k = DiffLinearKernel(lengthscale_init, options=falkon.FalkonOptions())" ] }, { "cell_type": "code", "execution_count": 12, "id": "cc862c47", "metadata": { "hidden": true, "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 2.7480, 1.6149, -1.2979, -2.3070, -1.1852],\n", " [ 
4.2437, 2.8397, -2.6248, -3.1610, -1.1940],\n", " [ 2.6474, 0.9644, -0.4447, -1.1742, -1.0197],\n", " [-3.4735, 0.4214, -1.9773, 0.3380, 2.2361],\n", " [-1.8094, -0.2183, -0.5620, 1.8260, 1.8644]],\n", " grad_fn=)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Kernel matrix. Notice how the output has a `grad_fn`\n", "k_mat = k(torch.randn(5, 3), torch.randn(5, 3))\n", "k_mat" ] }, { "cell_type": "code", "execution_count": 13, "id": "a22671f4", "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(tensor([-0.7049]),)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Gradient of the kernel with respect to the lengthscale.\n", "torch.autograd.grad(k_mat.sum(), k.lengthscale)" ] }, { "cell_type": "code", "execution_count": 14, "id": "d508d0d7", "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Kernel-vector product\n", "tensor([[ 0.0198],\n", " [-1.6055],\n", " [ 2.3654],\n", " [-0.6039]], grad_fn=)\n", "Gradients:\n", "(tensor([0.1758]), tensor([[ 0.6192, 1.2183, -0.2544],\n", " [ 0.6192, 1.2183, -0.2544],\n", " [ 0.6192, 1.2183, -0.2544],\n", " [ 0.6192, 1.2183, -0.2544]]))\n" ] } ], "source": [ "# kernel-vector product + gradient\n", "m1 = torch.randn(4, 3).requires_grad_()\n", "m2 = torch.randn(2, 3)\n", "v = torch.randn(2, 1)\n", "k_mmv = k.mmv(m1, m2, v)\n", "print(\"Kernel-vector product\")\n", "print(k_mmv)\n", "print(\"Gradients:\")\n", "print(torch.autograd.grad(k_mmv.sum(), [k.lengthscale, m1]))" ] }, { "cell_type": "code", "execution_count": 15, "id": "c56f9a2a", "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 1 - Elapsed 0.06s - training error: 2.20815659\n", "Iteration 2 - Elapsed 0.10s - training error: 2.19324374\n", "Iteration 3 - Elapsed 0.12s - training error: 2.19264197\n", "Iteration 4 - Elapsed 0.15s - training error: 
2.19263649\n", "Iteration 5 - Elapsed 0.18s - training error: 2.19262934\n", "Iteration 6 - Elapsed 0.21s - training error: 2.19261909\n", "Iteration 7 - Elapsed 0.24s - training error: 2.19261813\n", "Iteration 8 - Elapsed 0.26s - training error: 2.19262004\n", "Iteration 9 - Elapsed 0.29s - training error: 2.19261765\n", "Iteration 10 - Elapsed 0.34s - training error: 2.19261789\n", "Iteration 11 - Elapsed 0.38s - training error: 2.19261909\n", "Iteration 12 - Elapsed 0.40s - training error: 2.19261885\n", "Iteration 13 - Elapsed 0.43s - training error: 2.19261956\n", "Iteration 14 - Elapsed 0.46s - training error: 2.19261932\n", "Iteration 15 - Elapsed 0.49s - training error: 2.19261932\n", "Iteration 16 - Elapsed 0.52s - training error: 2.19262099\n", "Iteration 17 - Elapsed 0.54s - training error: 2.19262123\n", "Iteration 18 - Elapsed 0.57s - training error: 2.19262147\n", "Iteration 19 - Elapsed 0.60s - training error: 2.19262195\n", "Iteration 20 - Elapsed 0.65s - training error: 2.19262338\n", "Test RMSE: 2.19\n" ] } ], "source": [ "# Learning on the california housing dataset\n", "learn_with_kernel(k)" ] }, { "cell_type": "markdown", "id": "fd29c886", "metadata": { "heading_collapsed": true }, "source": [ "## Adding KeOps Support\n", "\n", "We must inherit from `falkon.kernels.KeopsKernelMixin` and implement the method `keops_mmv_impl`.\n", "\n", "KeOps-enabled kernels will still use the implementation in the `compute` function for computing the kernel matrix itself, but will use KeOps to compute kernel-vector products (if the data dimension is small enough).\n", "\n", "This method is responsible for kernel-vector products, and it should contain:\n", " 1. A formula definition (see https://www.kernel-operations.io/keops/api/math-operations.html for the appropriate syntax)\n", " 2. A definition of all variables (again have a look at the KeOps documentation, or the implementation\n", " of other kernels within Falkon)\n", " 3. 
A call to the `keops_mmv` method of the `KeopsKernelMixin` class, responsible for calling into\n", " the KeOps formula.\n", " \n", "For our kernel we will use the `(X | Y)` syntax for the dot-product between samples, and then multiplication with the vector `v`. The aliases list maps the symbols used in the formula with the KeOps variable types.\n", "\n", "For more examples check the [KeOps documentation](https://www.kernel-operations.io) or the implementation of existing kernels." ] }, { "cell_type": "code", "execution_count": 16, "id": "8164d28e", "metadata": { "hidden": true }, "outputs": [], "source": [ "class KeopsLinearKernel(DiffKernel, KeopsKernelMixin):\n", "    def __init__(self, lengthscale, options):\n", "        super().__init__(\"my-keops-linear\", \n", "                         options, \n", "                         core_fn=None, \n", "                         lengthscale=lengthscale)\n", "\n", "    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool):\n", "        lengthscale = self.lengthscale.to(device=X1.device, dtype=X1.dtype)\n", "        scaled_X1 = X1 * lengthscale\n", "\n", "        if diag:\n", "            out.copy_(torch.sum(scaled_X1 * X2, dim=-1))\n", "        else:\n", "            out = torch.matmul(scaled_X1, X2.T, out=out)\n", "\n", "        return out\n", "\n", "    def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool):\n", "        scaled_X1 = X1 * self.lengthscale\n", "\n", "        if diag:\n", "            return torch.sum(scaled_X1 * X2, dim=-1)\n", "\n", "        return torch.matmul(scaled_X1, X2.T)\n", "\n", "    def detach(self):\n", "        return KeopsLinearKernel(\n", "            lengthscale=self.lengthscale.detach(), \n", "            options=self.params\n", "        )\n", "\n", "    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt):\n", "        # Keops formula for kernel-vector.\n", "        formula = '(scale * (X | Y)) * v'\n", "        aliases = [\n", "            'X = Vi(%d)' % (X1.shape[1]),\n", "            'Y = Vj(%d)' % (X2.shape[1]),\n", "            'v = Vj(%d)' % (v.shape[1]),\n", "            'scale = Pm(%d)' % (self.lengthscale.shape[0]),\n", "        ]\n", "        other_vars = [\n", "            self.lengthscale.to(dtype=X1.dtype, 
device=X1.device),\n", "        ]\n", "        # Call to the executor of the formula.\n", "        return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)\n", "\n", "\n", "    def compute_sparse(self, X1, X2, out: torch.Tensor, diag: bool, **kwargs) -> torch.Tensor:\n", "        raise NotImplementedError(\"Sparse not implemented\")" ] }, { "cell_type": "markdown", "id": "11528fb9", "metadata": { "hidden": true }, "source": [ "### Test the KeOps kernel\n", "\n", "Note that KeOps will need to compile the kernels the first time they are run!" ] }, { "cell_type": "code", "execution_count": 17, "id": "12107005", "metadata": { "hidden": true }, "outputs": [], "source": [ "lengthscale_init = torch.tensor([1.0]).requires_grad_()\n", "k = KeopsLinearKernel(lengthscale_init, options=falkon.FalkonOptions(use_cpu=True))" ] }, { "cell_type": "code", "execution_count": 18, "id": "c231eabc", "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Kernel-vector product\n", "tensor([[-1.2121],\n", " [-0.1148],\n", " [ 2.2435],\n", " [ 0.9918]], grad_fn=)\n", "Gradients:\n", "(tensor([1.9084]), tensor([[ 1.0124, -0.8363, 0.7706],\n", " [ 1.0124, -0.8363, 0.7706],\n", " [ 1.0124, -0.8363, 0.7706],\n", " [ 1.0124, -0.8363, 0.7706]], requires_grad=True))\n" ] } ], "source": [ "# kernel-vector product + gradient\n", "m1 = torch.randn(4, 3).requires_grad_()\n", "m2 = torch.randn(2, 3)\n", "v = torch.randn(2, 1)\n", "k_mmv = k.mmv(m1, m2, v)\n", "print(\"Kernel-vector product\")\n", "print(k_mmv)\n", "print(\"Gradients:\")\n", "print(torch.autograd.grad(k_mmv.sum(), [k.lengthscale, m1]))" ] }, { "cell_type": "code", "execution_count": 19, "id": "7d2b80e5", "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 1 - Elapsed 0.17s - training error: 2.27769995\n", "Iteration 2 - Elapsed 0.34s - training error: 2.19313025\n", "Iteration 3 - Elapsed 0.51s - training error: 2.19323778\n", "Iteration 4 - 
Elapsed 0.66s - training error: 2.19308257\n", "Iteration 5 - Elapsed 0.82s - training error: 2.19269753\n", "Iteration 6 - Elapsed 0.98s - training error: 2.19266987\n", "Iteration 7 - Elapsed 1.13s - training error: 2.19262886\n", "Iteration 8 - Elapsed 1.29s - training error: 2.19262505\n", "Iteration 9 - Elapsed 1.45s - training error: 2.19262052\n", "Iteration 10 - Elapsed 1.76s - training error: 2.19260979\n", "Iteration 11 - Elapsed 1.92s - training error: 2.19261813\n", "Iteration 12 - Elapsed 2.08s - training error: 2.19261646\n", "Iteration 13 - Elapsed 2.25s - training error: 2.19263911\n", "Iteration 14 - Elapsed 2.42s - training error: 2.19263911\n", "Iteration 15 - Elapsed 2.58s - training error: 2.19264960\n", "Iteration 16 - Elapsed 2.74s - training error: 2.19265103\n", "Iteration 17 - Elapsed 2.91s - training error: 2.19268680\n", "Iteration 18 - Elapsed 3.07s - training error: 2.19269395\n", "Iteration 19 - Elapsed 3.23s - training error: 2.19270301\n", "Iteration 20 - Elapsed 3.55s - training error: 2.19275403\n", "Test RMSE: 2.19\n" ] } ], "source": [ "# Learning on the california housing dataset.\n", "# Due to differences in floating point code, results may be slightly \n", "# different from the other implementations.\n", "learn_with_kernel(k)" ] }, { "cell_type": "markdown", "id": "4fecea8f", "metadata": {}, "source": [ "## Supporting Sparse Data\n", "\n", "Sparse support can be necessary for kernel learning in extremely high dimensions, when the inputs are sparse.\n", "\n", "Sparse support requires using special functions for common operations such as matrix multiplication. 
Falkon implements sparse tensors in a CSR format (PyTorch is slowly picking this format up, in place of COO), through the `falkon.sparse.SparseTensor` class.\n", "\n", "We will implement the `compute_sparse` method below, supporting both diagonal and full kernels.\n", "However, only CPU support is added here (CUDA support is possible but requires a few more details), and differentiable sparse kernels are not supported." ] }, { "cell_type": "code", "execution_count": 20, "id": "98a92bf4", "metadata": {}, "outputs": [], "source": [ "from falkon.sparse import SparseTensor\n", "from falkon.sparse import sparse_matmul, bdot" ] }, { "cell_type": "code", "execution_count": 21, "id": "d9304478", "metadata": {}, "outputs": [], "source": [ "class SparseLinearKernel(Kernel):\n", "    def __init__(self, lengthscale, options):\n", "        # The base class takes as inputs a name for the kernel, and\n", "        # an instance of `FalkonOptions`.\n", "        super().__init__(\"sparse_linear\", options)\n", "\n", "        self.lengthscale = lengthscale\n", "\n", "    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool) -> torch.Tensor:\n", "        lengthscale = self.lengthscale.to(device=X1.device, dtype=X1.dtype)\n", "\n", "        scaled_X1 = X1 * lengthscale\n", "\n", "        if diag:\n", "            out.copy_(torch.sum(scaled_X1 * X2, dim=-1))\n", "        else:\n", "            # The dot-product row-by-row on `X1` and `X2` can be computed\n", "            # on many rows at a time with matrix multiplication.\n", "            out = torch.matmul(scaled_X1, X2.T, out=out)\n", "\n", "        return out\n", "\n", "    def compute_sparse(self, \n", "                       X1: SparseTensor, \n", "                       X2: SparseTensor, \n", "                       out: torch.Tensor, \n", "                       diag: bool,\n", "                       **kwargs) -> torch.Tensor:\n", "        # The inputs will be matrix X1(n*d) in CSR format, and X2(d*n) in CSC format.\n", "\n", "        # To support different devices/data types, you must make sure\n", "        # the lengthscale is compatible with the data.\n", "        lengthscale = self.lengthscale.to(device=X1.device, dtype=X1.dtype)\n", " 
\n", "        if diag:\n", "            # The diagonal is a dot-product between rows of X1 and X2.\n", "            # The batched-dot is only implemented on CPU.\n", "            out = bdot(X1, X2.transpose_csr(), out)\n", "        else:\n", "            # Otherwise we need to matrix-multiply. Note that X2 is already\n", "            # transposed correctly.\n", "            out = sparse_matmul(X1, X2, out)\n", "\n", "        out.mul_(lengthscale)\n", "        return out" ] }, { "cell_type": "markdown", "id": "75a18424", "metadata": {}, "source": [ "### Testing sparse support\n", "\n", "We generate two sparse matrices, and check that the sparse kernel is equivalent to the dense version." ] }, { "cell_type": "code", "execution_count": 22, "id": "e6e4c0bd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 5.],\n", " [1., 8.],\n", " [2., 0.]])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "indexptr = torch.tensor([0, 1, 3, 4], dtype=torch.long)\n", "index = torch.tensor([1, 0, 1, 0], dtype=torch.long)\n", "value = torch.tensor([5, 1, 8, 2], dtype=torch.float32)\n", "sp1 = SparseTensor(indexptr=indexptr, index=index, data=value, size=(3, 2), sparse_type=\"csr\")\n", "# Converted to dense:\n", "dense1 = torch.from_numpy(sp1.to_scipy().todense())\n", "dense1" ] }, { "cell_type": "code", "execution_count": 23, "id": "7470a0b6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 2.],\n", " [1., 0.],\n", " [3., 4.]])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "indexptr = torch.tensor([0, 1, 2, 4], dtype=torch.long)\n", "index = torch.tensor([1, 0, 0, 1], dtype=torch.long)\n", "value = torch.tensor([2, 1, 3, 4], dtype=torch.float32)\n", "sp2 = SparseTensor(indexptr=indexptr, index=index, data=value, size=(3, 2), sparse_type=\"csr\")\n", "dense2 = torch.from_numpy(sp2.to_scipy().todense())\n", "dense2" ] }, { "cell_type": "code", "execution_count": 24, "id": "f217cde0", "metadata": {}, "outputs": [], "source": [ "# Initialize the 
kernel\n", "lengthscale_init = torch.tensor([1.0])\n", "k = SparseLinearKernel(lengthscale_init, options=falkon.FalkonOptions())" ] }, { "cell_type": "code", "execution_count": 25, "id": "80f6a317", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[True, True, True],\n", " [True, True, True],\n", " [True, True, True]])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "k(sp1, sp2) == k(dense1, dense2)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.9" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }