{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Falkon Regression Tutorial\n", "\n", "## Introduction\n", "\n", "This notebook introduces the main interface of the Falkon library, \n", "using a toy regression problem.\n", "\n", "We will be using the California housing dataset, which can be fetched through `scikit-learn`, to train a Falkon model.\n", "Since the dataset is fairly small, the Nystroem approximation is not strictly necessary here. It is, however, useful to demonstrate the simple API offered by Falkon." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[pyKeOps]: Warning, no cuda detected. Switching to cpu only.\n" ] } ], "source": [ "%matplotlib inline\n", "from sklearn import datasets\n", "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", "plt.style.use('ggplot')\n", "\n", "import falkon" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data\n", "\n", "The California housing dataset poses a regression problem with 20640 data points in 8 dimensions.\n", "The goal is to predict the median house value of a district given attributes such as median income, house age, average number of rooms, and geographic location.\n", "\n", "After loading the data, we split it into two parts: a training set (containing 80% of the points) and a test \n", "set with the remaining 20%. 
Data splitting could alternatively be done with the scikit-learn utilities found in the [model_selection module](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [], "source": [ "X, Y = datasets.fetch_california_housing(return_X_y=True)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [], "source": [ "num_train = int(X.shape[0] * 0.8)\n", "num_test = X.shape[0] - num_train\n", "shuffle_idx = np.arange(X.shape[0])\n", "np.random.shuffle(shuffle_idx)\n", "train_idx = shuffle_idx[:num_train]\n", "test_idx = shuffle_idx[num_train:]\n", "\n", "Xtrain, Ytrain = X[train_idx], Y[train_idx]\n", "Xtest, Ytest = X[test_idx], Y[test_idx]" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Pre-process the data\n", "\n", "We must convert the numpy arrays to PyTorch tensors before using them in Falkon.\n", "This is very easy and fast with the `torch.from_numpy` function.\n", "\n", "Another preprocessing step which is often necessary with kernel methods is z-score normalization:\n", "rescaling each feature to have zero mean and unit standard deviation.\n", "We use the statistics of the training data for both sets, to avoid leakage from the test set." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [], "source": [ "# convert numpy -> pytorch\n", "Xtrain = torch.from_numpy(Xtrain).to(dtype=torch.float32)\n", "Xtest = torch.from_numpy(Xtest).to(dtype=torch.float32)\n", "Ytrain = torch.from_numpy(Ytrain).to(dtype=torch.float32)\n", "Ytest = torch.from_numpy(Ytest).to(dtype=torch.float32)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [], "source": [ "# z-score normalization\n", "train_mean = Xtrain.mean(0, keepdim=True)\n", "train_std = Xtrain.std(0, keepdim=True)\n", "Xtrain -= train_mean\n", "Xtrain /= train_std\n", "Xtest -= train_mean\n", "Xtest /= train_std" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Create the Falkon model\n", "\n", "The Falkon object is the main API of this library. \n", "It is similar in spirit to the fit-predict API of scikit-learn, while supporting some\n", "additional features such as monitoring of validation error.\n", "\n", "While Falkon models have many options, most are related to performance fine-tuning, which becomes useful with much \n", "larger datasets.\n", "Here we only showcase some of the more basic options.\n", "\n", "Mandatory parameters are:\n", " - the kernel function (here we use a Gaussian kernel with `sigma=1`)\n", " - the amount of regularization, which we set to some small positive value\n", " - the number of inducing points M. We set M to 5000, which is a sizable portion of the dataset." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/giacomo/Dropbox/unige/falkon/falkon/falkon/utils/switches.py:25: UserWarning: Failed to initialize CUDA library; falling back to CPU. 
Set 'use_cpu' to True to avoid this warning.\n", " warnings.warn(get_error_str(\"CUDA\", None))\n" ] } ], "source": [ "options = falkon.FalkonOptions(keops_active=\"no\")\n", "\n", "kernel = falkon.kernels.GaussianKernel(sigma=1, opt=options)\n", "flk = falkon.Falkon(kernel=kernel, penalty=1e-5, M=5000, options=options)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Training the model\n", "\n", "The Falkon model is trained using the preconditioned conjugate gradient algorithm (TODO: Add a reference). Thus there are\n", "two steps to the algorithm: first the preconditioner is computed, and then the conjugate gradient iterations are performed.\n", "To gain more insight into the various steps of the algorithm you can set `debug=True` in the `FalkonOptions` passed to the Falkon object. \n", "\n", "Model training runs on the GPU if one is available and CUDA is properly installed, \n", "and falls back to the CPU otherwise. " ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "Falkon(M=5000, center_selection=, kernel=GaussianKernel(sigma=Parameter containing:\n", "tensor([1.], dtype=torch.float64)), options=FalkonOptions(keops_acc_dtype='auto', keops_sum_scheme='auto', keops_active='no', keops_memory_slack=0.7, chol_force_in_core=False, chol_force_ooc=False, chol_par_blk_multiplier=2, pc_epsilon_32=1e-05, pc_epsilon_64=1e-13, cpu_preconditioner=False, cg_epsilon_32=1e-07, cg_epsilon_64=1e-15, cg_tolerance=1e-07, cg_full_gradient_every=10, cg_differential_convergence=False, debug=False, use_cpu=False, max_gpu_mem=inf, max_cpu_mem=inf, compute_arch_speed=False, no_single_kernel=True, min_cuda_pc_size_32=10000, min_cuda_pc_size_64=30000, min_cuda_iter_size_32=300000000, min_cuda_iter_size_64=900000000, never_store_kernel=False, store_kernel_d_threshold=1200, num_fmm_streams=2), penalty=1e-05)" ] }, "execution_count": 7, "metadata": 
{}, "output_type": "execute_result" } ], "source": [ "flk.fit(Xtrain, Ytrain)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Optimization converges very quickly to a minimum, where convergence is detected by checking the change in model parameters between iterations." ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Evaluating model performance" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the problem is regression, a natural error metric is the root mean squared error (RMSE). Given a fitted model, we can run the `predict` method to obtain predictions on new data.\n", "\n", "Here we print the error on both train and test sets." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training RMSE: 0.527\n", "Test RMSE: 0.578\n" ] } ], "source": [ "train_pred = flk.predict(Xtrain).reshape(-1, )\n", "test_pred = flk.predict(Xtest).reshape(-1, )\n", "\n", "def rmse(true, pred):\n", "    # flatten both tensors to column vectors so the difference is element-wise\n", "    return torch.sqrt(torch.mean((true.reshape(-1, 1) - pred.reshape(-1, 1))**2))\n", "\n", "print(\"Training RMSE: %.3f\" % (rmse(Ytrain, train_pred)))\n", "print(\"Test RMSE: %.3f\" % (rmse(Ytest, test_pred)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we plot histograms of the test labels and of the model predictions to check that the two distributions are close." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAY9ElEQVR4nO3df2xb1d3H8bdvHJOWLMGOG6KEVHtCUqGKQNjStQmwlNbVWNWJqNKYWhWNtJMGTA9qgxhhdC1SC4s0kkCmVJWmUoYmTZu2NZug+4PgkWp4kwylWvih0rBoUDU0P2xSWpI4if380QdrrKFJ7Bvb8fm8/qqv7Xu+p44/Offk3nMdsVgshoiIGMFKdwEiIpI6Cn0REYMo9EVEDKLQFxExiEJfRMQgCn0REYM4013AfJw7dy7dJczJ6/UyMjKS7jJskU19AfUnk2VTXyCz+lNaWjrrdo30RUQMotAXETGIQl9ExCBLYk5fRGShYrEYExMTRKNRHA5HSto8f/48k5OTKWkLLvfRsizy8vLm3UeFvohkpYmJCXJzc3E6UxdzTqeTnJyclLUHMD09zcTEBMuWLZvX6zW9IyJZKRqNpjTw08XpdBKNRuf9eoW+iGSlVE3pZIKF9FWhLyJikOw/9hERAWZ+ccDW/eX870+v+nwoFOJ73/seAMPDw+Tk5ODxeAB4+eWXcblcttYzXwp9sYXdX6iFmOvLJ5IOHo+HV155BYC2tjauvfZaHnjggfjz09PTafmbg0JfRCRFdu/ezXXXXcfbb79NdXU1+fn5X/hlsGHDBn71q19RXl7OH/7wB55//nkikQi33XYbP/vZz2w5M0hz+iIiKfSvf/2L3/72t+zfv/9LX3PmzBn+/Oc/093dzSuvvEJOTg5//OMfbWlfI30RkRTasmXLnCP2v/3tb/T19bF582bg8jUHXq/XlvYV+iIiKbR8+fL4v3Nycr5wjv3nV/PGYjG++93v8vjjj9vevqZ3RETSpLy8nL6+PgD6+vr48MMPAbjjjjt46aWX4ss0h8Nhzp49a0ubGumLiBEy8SyvzZs38/vf/55NmzZRU1NDRUUFAKtWreLHP/4x27ZtIxaL4XQ6eeqpp7jhhhuSblOhLyKyyB555JFZty9btozf/OY3sz53zz33cM8999hei6Z3REQMMudI/9ChQ5w8eZLCwkLa2toAuHjxIh0dHQwPD7NixQr27NlDfn4+AMeOHcPv92NZFk1NTdTU1ACXT1Pq6uqKn3Pa1NRk1NoY2e7p/NvT1nbmHbSLZK45R/rr16/nJz/5yRe2dXd3U11dTWdnJ9XV1XR3dwNw9uxZAoEA7e3tPPHEExw5ciT+l+lf/vKX/PCHP6Szs5OPP/6YU6dO2d4ZERG5ujlDf/Xq1fFR/OeCwSANDQ0ANDQ0EAwG49vr6+vJzc2luLiYkpIS+vv7CYfDjI+Ps2rVKhwOB9/85jfj7xERkdRJaE5/bGwMt9sNgNvt5sKFC8DlBYaKiorir/N4PIRCoSu2FxUVEQqFkqlbREQSYOvZO7FYbEHbv0xPTw89PT0AtLa22nYl2mJyOp1Los75SKQvDit95wTMVWs2fTaQXf1ZzL6cP38+PQuapaHNa665Zt7/jwlVV1hYSDgcxu12Ew6HKSgoAC6P4EdHR+OvC4VCeDyeK7aPjo7Glxidjc/nw+fzxR9/foFCJvN6vUuizvlIpC+xBdy5x25z1ZpNnw1kV38Wsy+Tk5NfWO7g4Gsf2br/vevLr9jmdDqZnp6OPy4vL+emm25iZmaGyspKnnvuuXnf1vC/7d69G5/Px5YtW654bnJy8or/x9LS0ln3k9DwrLa2lt7eXgB6e3tZs2ZNfHsgEGBqaoqhoSEGBweprKzE7XazbNky3n//fWK
xGCdOnKC2tjaRpkVEloy8vDxeeeUV/H4/LpeLF1988QvPz8zMpLymOUf6zz77LO+++y6ffvopDzzwAPfeey+NjY10dHTg9/vxer00NzcDl3+r1dXV0dzcjGVZ7Nq1C+v/D/t/8IMfcOjQISKRCDU1Ndx2222L2zMRkQzyjW98g/feey9+huP111/PO++8w6uvvsrTTz/N3//+dyKRCN///ve57777iMVi7N27l9dff53y8iuPKhI1Z+jv3r171u379u2bdfvWrVvZunXrFdtvvPHG+Hn+IiImmZ6e5q9//Svr168H4NSpU/j9flauXMmvf/1rvvKVr3D8+HEmJydpbGykoaGBt99+mw8++IBXX32V4eFh7rrrrviduJKhZRhERBbJxMQEmzZtAmDt2rVs27aNN954g5qaGlauXAlcniJ/7733ePnllwH49NNPGRgY4B//+AeNjY3k5ORQUlLC7bfbcwGkQl9EZJF8Pqf/3/5zeWWAgwcPxo8CPvfqq68uyqoFWntHRCSNGhoaePHFF5mamgLggw8+4LPPPmPdunX86U9/YmZmhvPnzxMIBGxpTyN9ETHCbKdYZoLt27fz0UcfcffddxOLxfB4PDz//PN8+9vf5vXXX2fjxo1UVFSwbt06W9pzxBZ65VQanDt3Lt0lzMn0c6cPHPUvUjVz+2nThqs+n02fDWRXfxazL5999tkV0yiL7b/P00+V2fpq63n6IiKyNCn0RUQMotAXkay0BGaubbOQvir0RSQrWZaVlvn1VJueno6vfDAfOntHRLJSXl4eExMTTE5Opuwufddccw2Tk5MpaQsuj/AtyyIvL2/e71Hoi0hWcjgcCa9omailcGaVpndERAyi0BcRMYhCX0TEIAp9ERGD6A+5WWbmFweS3kfY5WImElnYm/LtWfZVRBaXRvoiIgbRSF+WvLlueO1ynSey0COXecjUVRtFrkYjfRERgyj0RUQMotAXETGIQl9ExCAKfRERgyj0RUQMotAXETGIQl9ExCAKfRERgyj0RUQMotAXETGIQl9ExCBJLbj20ksv4ff7cTgclJeX89BDDxGJROjo6GB4eJgVK1awZ88e8vPzATh27Bh+vx/LsmhqaqKmpsaOPoiIyDwlPNIPhUL85S9/obW1lba2NqLRKIFAgO7ubqqrq+ns7KS6upru7m4Azp49SyAQoL29nSeeeIIjR44QjUbt6oeIiMxDUtM70WiUSCTCzMwMkUgEt9tNMBikoaEBgIaGBoLBIADBYJD6+npyc3MpLi6mpKSE/v7+5HsgIiLzlvD0jsfj4Tvf+Q4PPvggLpeLW2+9lVtvvZWxsTHcbjcAbrebCxcuAJePDKqqqr7w/lAoNOu+e3p66OnpAaC1tRWv15tomSnjdDozos6wy5X0PiyHA9cC9+Ow0vfnoblqdTisBfdnPtL1eWfKz5odsqkvsDT6k3DoX7x4kWAwSFdXF8uXL6e9vZ0TJ0586etjsdi89+3z+fD5fPHHIyMjiZaZMl6vNyPqXPBtDmfhcrkWfNORmCt9U3Vz1ZpIf+YjXZ93pvys2SGb+gKZ1Z/S0tJZtyc8POvr66O4uJiCggKcTidr167l/fffp7CwkHA4DEA4HKagoACAoqIiRkdH4+8PhUJ4PJ5EmxcRkQQkHPper5czZ84wOTlJLBajr6+PsrIyamtr6e3tBaC3t5c1a9YAUFtbSyAQYGpqiqGhIQYHB6msrLSnFyIiMi8JT+9UVVWxbt06HnvsMXJycvjqV7+Kz+djYmKCjo4O/H4/Xq+X5uZmAMrLy6mrq6O5uRnLsti1axdWGueBs9XT+bcnvQ+HZaV1ukZEFo8jtpDJ9jQ5d+5cukuYU6bM5R046k96Hw7LIraETqd1/E/VVZ9frDn9dN0YPVN+1uyQTX2BzOqP7XP6IiKy9Cj0RUQMotAXETGIQl9ExCAKfRERgyj0RUQMotAXETGIQl9ExCBJ3URFJBPEBs5c9fnIYl1slqa
Ls0SSoZG+iIhBFPoiIgZR6IuIGEShLyJiEIW+iIhBFPoiIgZR6IuIGEShLyJiEIW+iIhBFPoiIgZR6IuIGEShLyJiEIW+iIhBFPoiIgZR6IuIGEShLyJiEIW+iIhBFPoiIgZR6IuIGEShLyJiEIW+iIhBnMm8+dKlSxw+fJiPPvoIh8PBgw8+SGlpKR0dHQwPD7NixQr27NlDfn4+AMeOHcPv92NZFk1NTdTU1NjRBxERmaekQv/o0aPU1NTwyCOPMD09zeTkJMeOHaO6uprGxka6u7vp7u5mx44dnD17lkAgQHt7O+FwmAMHDvDcc89hWTrYEBFJlYQT97PPPuO9995jw4YNADidTq699lqCwSANDQ0ANDQ0EAwGAQgGg9TX15Obm0txcTElJSX09/fb0AUREZmvhEf6Q0NDFBQUcOjQIf79739TUVHB/fffz9jYGG63GwC3282FCxcACIVCVFVVxd/v8XgIhUKz7runp4eenh4AWltb8Xq9iZaZMk6nMyPqdNhw5OQAyKIjsMXqT7o+70z5WbNDNvUFlkZ/Eg79mZkZBgYG2LlzJ1VVVRw9epTu7u4vfX0sFpv3vn0+Hz6fL/54ZGQk0TJTxuv1ZkSdsWg0+Z1Ylj37yRSL1J/dv3/L9n3Oh8vl4sf116elbbtlyvfGLpnUn9LS0lm3Jzz8KSoqoqioKD56X7duHQMDAxQWFhIOhwEIh8MUFBTEXz86Ohp/fygUwuPxJNq8iIgkIOHQv+666ygqKuLcuXMA9PX1ccMNN1BbW0tvby8Avb29rFmzBoDa2loCgQBTU1MMDQ0xODhIZWWlDV0QEZH5SursnZ07d9LZ2cn09DTFxcU89NBDxGIxOjo68Pv9eL1empubASgvL6euro7m5mYsy2LXrl06c0dEJMUcsYVMtqfJ50cTmSxT5vIOHPUnvQ9Hls3pL1Z/HP9TNfeLFoHm9DNXJvXH9jl9ERFZehT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhBnugsQWapiA2fS0m7EsiBLbowuqaeRvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQZK+IjcajdLS0oLH46GlpYWLFy/S0dHB8PAwK1asYM+ePeTn5wNw7Ngx/H4/lmXR1NRETU1Nss2LiMgCJD3SP378OGVlZfHH3d3dVFdX09nZSXV1Nd3d3QCcPXuWQCBAe3s7TzzxBEeOHCEajSbbvIiILEBSoT86OsrJkyfZuHFjfFswGKShoQGAhoYGgsFgfHt9fT25ubkUFxdTUlJCf39/Ms2LiMgCJTW988ILL7Bjxw7Gx8fj28bGxnC73QC43W4uXLgAQCgUoqqqKv46j8dDKBSadb89PT309PQA0NraitfrTabMlHA6nRlRp8NK/s80DgAb9pMpsrE/mfCzZodM+d7YZSn0J+HQf/PNNyksLKSiooJ33nlnztfHYrF579vn8+Hz+eKPR0ZGEqoxlbxeb0bUGbNjysyy7NlPpsjC/mTCz5odMuV7Y5dM6k9paems2xMO/dOnT/PGG2/w1ltvEYlEGB8fp7Ozk8LCQsLhMG63m3A4TEFBAQBFRUWMjo7G3x8KhfB4PIk2LyIiCUj4mHf79u0cPnyYrq4udu/ezc0338zDDz9MbW0tvb29APT29rJmzRoAamtrCQQCTE1NMTQ0xODgIJWVlfb0QkRE5sX2m6g0NjbS0dGB3+/H6/XS3NwMQHl5OXV1dTQ3N2NZFrt27cLKonlWEZGlwBFbyGR7mpw7dy7dJcwpU+byDhz1J70PR5bNgWdjf/Z+f326y7BFpnxv7JJJ/bF9Tl9E0ufgax+lpd2968vT0q7YR/MrIiIGUeiLiBhEoS8iYhCFvoiIQRT
6IiIGUeiLiBhEoS8iYhCFvoiIQXRxlsgSFBs4k56GdXHWkqeRvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYpCEb4w+MjJCV1cXn3zyCQ6HA5/Px+bNm7l48SIdHR0MDw+zYsUK9uzZQ35+PgDHjh3D7/djWRZNTU3U1NTY1Q8RkUUx84sD835t2OViJhKxpd2c//2pLfv5bwmHfk5ODvfddx8VFRWMj4/T0tLCLbfcwmuvvUZ1dTWNjY10d3fT3d3Njh07OHv2LIFAgPb2dsLhMAcOHOC5557DsnSwISKSKgmHvtvtxu12A7Bs2TLKysoIhUIEg0GefPJJABoaGnjyySfZsWMHwWCQ+vp6cnNzKS4upqSkhP7+flatWmVLRzLJQkYGtsu/PX1ti0jGSzj0/9PQ0BADAwNUVlYyNjYW/2Xgdru5cOECAKFQiKqqqvh7PB4PoVBo1v319PTQ09MDQGtrK16v144yF5XT6YzXGXa50laHw4YjJwdAFh2BqT/2sfu7+J/fm0y1kO+z5XDgsun7716k/5ekQ39iYoK2tjbuv/9+li9f/qWvi8Vi896nz+fD5/PFH4+MjCRVYyp4vd54nXbN6SUi5oomvxPLIha1YT+ZQv2xjd3fxf/83mSqhXyfXS4XEZu+/8n+v5SWls66PanhwvT0NG1tbdx5552sXbsWgMLCQsLhMADhcJiCggIAioqKGB0djb83FArh8XiSaV5ERBYo4dCPxWIcPnyYsrIytmzZEt9eW1tLb28vAL29vaxZsya+PRAIMDU1xdDQEIODg1RWViZZvoiILETC0zunT5/mxIkTrFy5kkcffRSAbdu20djYSEdHB36/H6/XS3NzMwDl5eXU1dXR3NyMZVns2rVLZ+6IiKRYwqF/00038bvf/W7W5/bt2zfr9q1bt7J169ZEmxQRkSRpqC0iYhBbTtmUL3pa58qLSIbSSF9ExCAa6YvIvB187SNb9+dynZ/Xee1715fb2q7JFPoiMm+xgTO27i8y3wvNFPq20fSOiIhBFPoiIgZR6IuIGEShLyJiEIW+iIhBFPoiIgZR6IuIGEShLyJiEIW+iIhBFPoiIgbRMgwikvHsXvNnIR5PW8uLQyN9ERGDKPRFRAyi6R0RyXh2r+5pMo30RUQMotAXETGIpndERK5iIfe8dlgWMdc8bgozDz+1ZS9X0khfRMQgCn0REYMo9EVEDKLQFxExiEJfRMQgCn0REYMo9EVEDKLQFxExSMovzjp16hRHjx4lGo2yceNGGhsbU12CiIixUhr60WiUI0eOsHfvXoqKinj88cepra3lhhtuWJT2UrkGt8t1nkgkkrL2REQSkdLpnf7+fkpKSrj++utxOp3U19cTDAZTWYKIiNFSOtIPhUIUFRXFHxcVFXHmzJVLpvb09NDT0wNAa2srpaWlCbV3aHti70ve2jS1KyJydSkd6cdisSu2ORyOK7b5fD5aW1tpbW1NRVm2aGlpSXcJtsmmvoD6k8myqS+wNPqT0tAvKipidHQ0/nh0dBS3253KEkREjJbS0L/xxhsZHBxkaGiI6elpAoEAtbW1qSxBRMRoKZ3Tz8nJYefOnTz11FNEo1HuuusuysvLU1nCovH5fOkuwTbZ1BdQfzJZNvUFlkZ/HLHZJtpFRCQr6YpcERGDKPRFRAyie+QmKZuWlTh06BAnT56ksLCQtra2dJeTtJGREbq6uvjkk09wOBz4fD42b96c7rISEolE2L9/P9PT08zMzLBu3TruvffedJeVtGg0SktLCx6PZ0mc7ng1P/rRj8jLy8OyLHJycjL2lHOFfhJSvazEYlu/fj133303XV1d6S7FFjk5Odx3331UVFQwPj5OS0sLt9xyy5L8fHJzc9m/fz95eXlMT0+zb98
+ampqWLVqVbpLS8rx48cpKytjfHw83aXYYv/+/RQUFKS7jKvS9E4Ssm1ZidWrV5Ofn5/uMmzjdrupqKgAYNmyZZSVlREKhdJcVWIcDgd5eXkAzMzMMDMzM+uFjUvJ6OgoJ0+eZOPGjekuxSga6SdhvstKSPoNDQ0xMDBAZWVluktJWDQa5bHHHuPjjz/mW9/6FlVVVekuKSkvvPACO3bsyJpRPsBTTz0FwKZNmzL29E2FfhLmu6yEpNfExARtbW3cf//9LF++PN3lJMyyLH7+859z6dIlnnnmGT788ENWrlyZ7rIS8uabb1JYWEhFRQXvvPNOusuxxYEDB/B4PIyNjXHw4EFKS0tZvXp1usu6gkI/CVpWIvNNT0/T1tbGnXfeydq12bEQ3rXXXsvq1as5derUkg3906dP88Ybb/DWW28RiUQYHx+ns7OThx9+ON2lJczj8QBQWFjImjVr6O/vz8jQ15x+ErSsRGaLxWIcPnyYsrIytmzZku5yknLhwgUuXboEXD6Tp6+vj7KysjRXlbjt27dz+PBhurq62L17NzfffPOSDvyJiYn4NNXExAT//Oc/M/YXskb6Sci2ZSWeffZZ3n33XT799FMeeOAB7r33XjZs2JDushJ2+vRpTpw4wcqVK3n00UcB2LZtG1/72tfSXNnChcNhurq6iEajxGIx6urq+PrXv57usuT/jY2N8cwzzwCX/9B+xx13UFNTk96ivoSWYRARMYimd0REDKLQFxExiEJfRMQgCn0REYMo9EVEDKLQFxExiEJfRMQg/wcfh6NZb+HHugAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots()\n", "hist_range = (min(Ytest.min(), test_pred.min()).item(), max(Ytest.max(), test_pred.max()).item())\n", "ax.hist(Ytest.numpy(), bins=10, range=hist_range, alpha=0.7, label=\"True\")\n", "ax.hist(test_pred.numpy(), bins=10, range=hist_range, alpha=0.7, label=\"Pred\")\n", "ax.legend(loc=\"best\");" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.9" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 1 }