{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "-jlGXqM_t452" }, "source": [ "# P8 - Convolutional Neural Networks (CNNs)\n", "We have now learned about the Perceptron, linear and logistic regression, the multi-layer perceptron and backpropagation, and auto-encoders.\n", "\n", "In this practical session on Convolutional Neural Networks (CNNs), we will use the MNIST dataset.\n", "\n", "First, we will obtain baselines using a Logistic Regression and a Feed-forward Neural Network." ] }, { "cell_type": "markdown", "metadata": { "id": "ITJR4snhxdT0" }, "source": [ "## 0.0 - Imports\n", "We first import the libraries used in this session: a plotting library ([matplotlib](https://matplotlib.org/)), a neural network package ([torch](https://pytorch.org/)), and helper packages for data handling ([sklearn](https://scikit-learn.org/), [numpy](https://numpy.org/))." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "MWGjU3tDw4bD" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from sklearn.base import BaseEstimator\n", "from sklearn.datasets import load_digits\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.utils import check_random_state\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torchvision import datasets, transforms\n", "from torch.autograd import Variable\n", "from torch.utils.data import Dataset, DataLoader\n", "from torch.utils.data.sampler import SubsetRandomSampler\n", "import time\n", "import copy" ] }, { "cell_type": "markdown", "metadata": { "id": "W-od7M6WMN0N" }, "source": [ "Next, we need to set a few global variables. 
These include the size of the dataset we will use and the device (GPU if available, otherwise CPU) that computations will run on:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ECqewHJ0MM62", "outputId": "e5377940-a224-4e98-b427-bad0a9579863" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cpu\n" ] } ], "source": [ "# Configure Device\n", "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", "print(device)" ] }, { "cell_type": "markdown", "metadata": { "id": "odY0Ng9yycgr" }, "source": [ "### 0.1 - Create Dataloaders\n", "#### MNIST dataset \n", "Using torchvision, we can easily download the MNIST dataset and use it to create our train and validation dataloaders." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "snFv-Hu-zRnW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5ef88637fe884886acab052dc627132e", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/9912422 [00:00 creates an iterator of the dataloader and gets the next batch\n", "batch_idx, (example_imgs, example_targets) = next(enumerate(mnist_train_dataloader))\n", "# info about the dataset\n", "D_in = np.prod(example_imgs.shape[1:])\n", "D_out = len(mnist_train_dataloader.dataset.targets.unique())\n", "print(\"Datasets shapes:\", {x: dataloaders[x].dataset.data.shape for x in ['train', 'val']})\n", "print(\"N input features:\", D_in, \"Output classes:\", D_out)\n", "print(\"Train batch:\", example_imgs.shape, example_targets.shape)\n", "batch_idx, (example_imgs, example_targets) = next(enumerate(mnist_val_dataloader))\n", "print(\"Val batch:\", example_imgs.shape, 
example_targets.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "JnFAmoinjY1T" }, "source": [ "We can plot some examples with corresponding labels using the following function. This function can also receive the predicted labels." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 284 }, "id": "5ZWvjQOvC2ep", "outputId": "c77ced2a-931a-4fb1-db71-5354316f0e6d" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZoAAAELCAYAAADgPECFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAg3ElEQVR4nO3de7hVVb3/8c8XlYsieAFEkIsnNRNFURNNOXL8JSSKqGVo5IXwkueEovRDLCQR7RFT4fd4SbyiWd7SFC3spHnbkhqipZkiJSAXUzQFBCVl/P6Yk9kYU/baa629xlprw/v1PPthjDXm5bvnHqzvmmPONaY55wQAQCytah0AAGDjRqIBAERFogEAREWiAQBERaIBAERFogEAREWiKZGZzTCzS2odB1oe+g7K1dL7TsFEY2arvJ91ZrbGq4+oVpBm9oNcLGvSeDo1svwCL9Z/pH+k9tWK14vjVDNzZjYu9/piMxtY7XiqqY76zpFm1mBmH5jZ22Z2k5ltXWB5+k6N1VHf2dHMZprZ0vRv0buJ5ek7jSiYaJxz7df/SFokaaj32s/XL2dmm8cM0jn341wsUyQ94ZxbXmC1oemy+0raX9KE/AKx4069L2lcoTe3jVG99B1JHSVdIqmbpC9J6i7pJ02sQ9+poTrqO+skPSLp6yWsQ9/ZgLKGzsxsYJodzzeztyXdmmbRhtxyzsx2ScttzOwKM1uUZvvrzaxdGfs2SSdLuq2Y5Z1zSyTNkrSnF9P/mNkbkt5IXzvKzF5KP/XONrO+3v76mdlcM1tpZndLaltiyH+V9AdJ5zXy+7Qxs2npp6alablN2rb+OI81s3fMbJmZjcyt2+xjWk3V7jvOuV845x5xzq12zv1T0o2SDi5yXfpOHalB3/mHc+46SX8sNVb6Tqg512i6StpOUi9JZxSx/GWSdpO0j6RdlHyynLi+MT3YhxSxnQGSuki6r5ggzayHpCGSXvRePkZSf0l7mFk/SbdIOlPS9pKmS5qZHszWkh6Q9DMlv+u9yn26KTLuCyWNMbPtNtD2Q0kHKjkue0s6QOGnoK5KPpV3lzRK0rVmtm3aVvCY1rFa9R1J+k9JfylmQfpOXapl3ykafSfHOVfUj6QFkr6algdKWiuprdd+qqSG3DouDcQkfSTpC17bQZLeLHb/3no3S5pRRKyrJH0gaaGk6yS182I6zFv2p5Im59Z/XdKhSt6Ulkoyr222pEuKjDU7JpLukTQlLS+WNDAt/03SEG+dwZIWeMd5jaTNvfZ3lHSQih3T2D911HcOl/RPSbvRd+g7JcSwebrN3kXESt/ZwE9zxgrfdc59XOSynSVtKekFM1v/mknarJQdmtmWko6XNKyIxY9xzj3aSNtbXrmXpFPMbLT3WmslY/pO0hKXHs3UwhJC9k2U9LyZXZV7vVtumwvT19Z7zzn3qVdfLam9KnRMa6QWfedASb+Q9A3n3LwmFqfv1K+q950S0Xc2oDlDZ/lpnz9KA0j2bNbVa1uuJEP2cc5
tk/50dMlFs1Icq+Qi1xNlxOvzY39L0qVeXNs457Z0zt0paZmk7uYdUUk9y9qhc69Jul/JKatvqZJO529/aRGbrNQxrYWq9p10mGKmpO845x5rRtwSfafWavG+UymbbN+p5Pdo/iSpj5ntY2ZtJV20vsE5t07JRdipZtZFksysu5kNLnEfp0i6PZfpm+tGSd81s/6W2MqSW2K3VnIx7VNJZ5vZFmZ2nJKxzHJNkjRS0jbea3dKmmBmnS25XXuipDua2lAFj2k9iNZ3zGxPJXcOjXbOPVThuOk7tRf1fSfdZpu02iatV8Im1XcqlmjS4YiLJT2q5K6Khtwi50uaL+lZM1uRLvfF9Y2W3Hs+oLHtm1l3SYdJur1SMadxz5F0uqRrlIzfz1cyxinn3FpJx6X19yUNV/LpwI+rYNy5fb2p5ALfVt7Ll0iaI+nPkl6WNDd9rRgFj2lLEbnvjFVyun+z/fu7GEXdDFBE3PSdGov9vqPk0/uqtPxaWq9E3JtU37HKnhwAABBiChoAQFQkGgBAVCQaAEBUJBoAQFQkGgBAVCXPDGBm3KZWh5xz1vRStUO/qVvLnXOdax1EIfSdulV03+GMBti0lTu1CVB03yHRAACiItEAAKIi0QAAoiLRAACiItEAAKIi0QAAoiLRAACiItEAAKIqeWYAYFOw9dZbZ+XTTjutYtt98skns/LcuXMrtl2gnnFGAwCIikQDAIiKRAMAiIprNNgkHXXUUUF93LhxQX3nnXfOyjvuuGPQZhZOlO1c8ZMLf/jhh1n5N7/5TdB2xhlnZOU1a9YUvU3Ur5deeimo77XXXkH94osvzsqTJk2qRkg1wRkNACAqEg0AICor5bRf4iFE9YoHnzVt/PjxWXnixIlBW+vWrYveTnOGzgptp2fPnll5yZIlZW2zDC845/av1s7KUQ99pxT+rfHPP/980LbrrrsG9XfeeScrd+vWLW5glVd03+GMBgAQFYkGABAViQYAEFWLuL15883DMPO3m/pGjRoV1P1xz3xbIU8//XRQf/jhh4P61VdfnZU/+eSToreL6hk2bFhQv/TSS7Ny/hrIZ5991uh2Zs6cGdTzfcM3evTooN6vX7+gvuWWWza67plnnrnBWCX6WEty0EEHZeX8NZm8hQsXxg6nLnBGAwCIikQDAIiqLm9vzt/mN3369KB+xBFHFL0t/xbSUn7Xpm5h/e53v5uVH3300aBtwYIFRe+nUri9uWljxozJyjfddFPQtmrVqij7HDFiRFC/7bbbsnKhPnbwwQcHbc8991yE6CRxe3PFDRo0KCvnZ3/Ie/XVV7Ny3759o8UUCbc3AwDqA4kGABAViQYAEFVd3t48ZcqUoD5kyJCgXu6UH5V0/fXXZ+Vly5YFbT169Kh2OCjCtGnTqr7Pp556Kqj712VatQo/561bty4rDxw4MGiLeI0GNdSpU6dah1AVnNEAAKIi0QAAoiLRAACiqstrNPnvOHzrW98qet25c+cGdf8+9vx28/bdd9+sfOGFFwZt+alEfIWmxAF8/vVF/5pMvi3/BND8dUu0DPnvSpXavrHgjAYAEBWJBgAQVV0OnT355JNBPT9b7oABA7Ky/4Q66fMz9uZvPS5k8eLFWfnFF18M2kqZVubcc8/NylOnTi16PWA9f3ZwtCwrVqzIyqtXrw7a2rVrF9Tr4asa1cAZDQAgKhINACAqEg0AIKq6vEaTN2HChKDuX6O54YYbgrb33nuvIvvMb2f27NlB3X+KHtCYjh07lrXepvLkxY3Rs88+m5Xnz58ftO21117VDqcucEYDAIiKRAMAiIpEAwCIqkVco2loaChYjyF///vSpUuj7xMtX+/evYP6vffeW5tAgDrCGQ0AICoSDQAgqhYxdFYLW2+9dVDfddddi17XnwUam5auXbsG9d12263RZfNP2Lz77ruzMk/U3DRstdVWWXmPPfYI2l599dVqhxMNZzQAgKhINACAqEg0AICouEbTiO7duwf1vff
eu9FlV61aFdRPOumkKDGhNPvss09QL+U6W7nGjh0b1AtNA59/wubFF18cJSbUTv4Jmvl6hw4dsnL+2i7XaAAAKBKJBgAQFYkGABAV12gaMXHixKBeaKw9/6gCxLP99tsH9cMPPzwrjx49Omjr2bNnUO/WrVtWbs4jdPPj7JV6HO+HH35Yke2gfuT7Rr6ev063seKMBgAQFYkGABAVQ2eeE088MSsPHz48aCs0PJJ/ih4qy58O6Pbbbw/aBg8eXO1wornuuuuy8uTJk4O2OXPmVDscVNnXv/71oH7HHXfUKJLK44wGABAViQYAEBWJBgAQFddoPBMmTCh62Xnz5mXlu+66K0Y4SJ122mlZuZRrMg8++GBQf/rpp4te159GqND0Q5V05JFHZuX8dCQ33nhjUGe6mpYhf03xJz/5SaPL9urVK3Y4NcMZDQAgKhINACAqEg0AIKpN6hpNmzZtgvozzzwT1HffffesnH/M7scffxzUv/3tb2dlpg6Ja9iwYVk5P/2Lb/HixUE9/70E3+mnnx7U+/TpE9Tzjxjw5ftGudOIFNqOP12OJP3oRz8K6itWrMjK06ZNK2v/iG/58uVFL9u5c+eg3rt376C+YMGCCkRUG5zRAACiItEAAKLa6IfO/OGyqVOnBm354RF/mpn8UNmYMWOC+ty5cysTIJrk36Z8yCGHNLpc165dg/rChQuDuj/stsMOOwRtm222WVAv5cmYixYtysp333130Hb11VcH9aOPPjorH3rooUGbf3tzu3btGt2/JA0dOjQrM3TWcuSHfv3h0/xwaf6JsAydAQDQCBINACAqEg0AIKqN/hrNV77ylax8xhlnFL1eQ0NDUJ8+fXrFYkJpli1blpXXrl0btLVu3Tor56+zdO/ePaj74+OlPBUzPzbuxyNJ5557blZuajr/n/70pxssS+G0M2PHjg3a8k8L5amuLVOhJ2xW6kmt9YgzGgBAVCQaAEBUG93Q2X777RfUH3jggaLX9W9ZPu644yoVEprJnx07f5t53759s7I/jNaU1atXB/X87erXXHNNVn7ppZeCtlhPVPVjGDFiRJR9ALXAGQ0AICoSDQAgKhINACCqje4ajT+NhyR16NAhK+enDsmP00+aNCkrr1y5MkJ0aK4DDzwwqJ9wwglZOT8FTSGvv/56UJ81a1bzAgPQKM5oAABRkWgAAFGRaAAAUVmp0x6YWV3Nk7DnnnsG9d/97ndBvUuXLlk5/7uOHDkyqP/sZz+rcHTV45xr/NGTdaDe+g0yLzjn9q91EIW05L6Tf9zDPffcE9SHDBmSlWfMmBG0nXXWWUE9P/1SHSi673BGAwCIikQDAIiqxQ+dPfPMM0G9f//+Qd2fsffll18O2gYMGBDUW/ItzQydoUwMnaFcDJ0BAOoDiQYAEBWJBgAQ1UY3BU3e0qVLs/KgQYOCtpZ8TQYAWgrOaAAAUZFoAABRkWgAAFG1+Gs0Bx98cK1DAAAUwBkNACAqEg0AICoSDQAgKhINACAqEg0AICoSDQAgqnJub14uaWGlA0Gz9Kp1AEWg39Qn+g7KVXTfKfl5NAAAlIKhMwBAVCQaAEBUJBoAQFQkGgBAVCQaAEBUJBoAQFQkGgBAVCQaAEBUJBoAQFQkGgBAVCQaAEBUJBoAQFQkGgBAVCSaEpnZDDO7pNZxoOWh76BcZvaEmZ1W6zjKVTDRmNkq72edma3x6iOqFaQlfmhmi8xshZndZWYdCiy/wIv1H+l/8PbViteL41Qzc2Y2Lvf6YjMbWO14qqle+k4aS2cz+4WZfWhm/zSznxdYlr5TY/XSd8xsYLp/P55TCizvzOyjdLklZnaVmW1WrXi9OC5KY/mm99rm6Wu9qx2P1ESicc61X/8jaZGkod5r2X9WMyvnAWqlOFnSSZIOltRNUjtJVzexztA07n0l7S9pQn6BKsQtSe9LGmdmW1dhX3WjjvqOJN0v6W1JPSV1kXRFE8vTd2qozvrOUj8e59xtTSy/dxr
3/5H0LUmn5xeoYt+ZVItEtyFlDZ2lmX6xmZ1vZm9LujX9BNaQW86Z2S5puY2ZXZGelfzDzK43s3ZF7nKopJudc28551ZJmiJpuJlt2dSKzrklkmZJ2tOL6X/M7A1Jb6SvHWVmL5nZB2Y228z6er9DPzOba2YrzexuSW2LjHm9v0r6g6TzNtSYHpdpZrY0/ZlmZm3StvXHeayZvWNmy8xsZG7dco9pTVS775jZIEk9JP1f59yHzrl/OedeLGZd+k59qcH7Ttmcc69JelrSnmbWO41plJktkvT7NLbvmNlfLTnL/q2ZZU+sNLPDzew1S87Cr5FkJYbwiKS1kr69oUYz62hmt5vZu2a20MwmmFmrtO1UM2tIj9s/zexNMzsit+7NaZ9aYmaXWBMJrTnXaLpK2k7J4zzPKGL5yyTtJmkfSbtI6i5p4vrG9D/qIQXWt1y5jaRdm9qpmfWQNESS/+ZyjKT+kvYws36SbpF0pqTtJU2XNDPtoK0lPSDpZ0p+13slfT23/abilqQLJY0xs+020PZDSQcqOS57SzpA4SforpI6KjleoyRda2bbpm0Fj2kdq2bfOVDS65JuM7P3zOyPZnZoMUHSd+pStd93uqQJ6k0zm2pmWxUTpJntIWmAwr5zqKQvSRpsZsMk/UDScZI6K0lKd6brdlJyFj5BUidJf1MymrN+2z3TuHsWCMEp6Ts/MrMtNtB+tZK+8R9pXCdLGum191fy/6aTpMsl3Wxm69+DZ0j6VMnx7CdpkKTC14+cc0X9SFog6atpeaCSbNnWaz9VUkNuHZcGY5I+kvQFr+0gSW8Wue/TJM2T1Ds9ODPTbR9UINZVkj5Q8qzx6yS182I6zFv2p5Im59Z/PT34/ylpqdJHXqdtsyVdUmTc2TGRdI+kKWl5saSBaflvkoZ46wyWtMA7zmskbe61v6PkzaVZx7SaPzXuOzek2xolaQtJJ6T9ohN9h77TxL67StpDyQfynSU9JWl6geWdpBWS/pn+bS5J1+2dtv2Ht+wsSaO8eitJq5Uk0JMlPeu1Wfp3P63IuC+SdEdafk7SWZI2T2PoLWmz9Dju4a1zpqQnvGM632vbMl23q6QdJH2y/v9E2n6ipMcLxdScscJ3nXMfF7ls5zTYF/6dFGXpL1yMW5QMfzyh5IBdqWQ4bXGBdY5xzj3aSNtbXrmXpFPMbLT3Wmsl14KcpCUuPZqphUXGnDdR0vNmdlXu9W65bS5MX1vvPefcp159taT2av4xraVq9p01St58b07rd5nZD5V8QnywkXXoO/Wran3HOfe2kmt7kvSmJTdmPKzkTbkx+zrn5vsvePvO953/Z2ZX+osqOePq5i/rnHNm5q9bigmSblVyZr1eJyUfuvJ9p7tXX/97yzm3Ov0d2is5m9xC0jLv92ql8Hf7nOYMnblc/SMlf1RJkpl19dqWK/kP38c5t03609ElF82a3pFz65xzP3LO9XbO7STpL5KWpD/Njf0tSZd6cW3jnNvSOXenpGWSununjFJyQbn0HSZjtvcrGe7wLVXS6fztLy1ik806pjVWtb4j6c8b2F++Xgr6Tm1Vs+9saN+Ves98S9KZub7Tzjk3W0nf6bF+wbQP9VAZnHO/kzRf0n97Ly+X9C99vu8U8376lpIzmk5e3B2cc30KrVTJ79H8SVIfM9vHzNoqOX2TlCQKSTdKmmpmXSTJzLqb2eBiNmxm25nZFyyxh6SrJF2cbre5bpT0XTPrn25/KzM70pI7ff6gZCzybDPbwsyOUzIOXq5JSsZBt/Feu1PSBEtuwe2k5NPrHU1tqLnHtM5E6zuSfiVpWzM7xcw2M7NvSNpJ0jMViJu+U3sx33f+y8x6pX/bHkqu9zR2Flyq6yVdYGZ90n11NLPj07Zfp7/TcZbcoXa2kmGrcv1QUnabvHPuMyXDsZea2daW3IRwnorrO8sk/a+kK82sg5m1St+bC173rFiicc7Nk3SxpEeV3JHTkFvkfCWZ9VkzW5Eu98X
1jZbcez6gkc13kvQbJZ9eZkm6xTl3Q4XinqPkFsRrlIytzlcyRinn3FolF+tOVXK74HAlnywzTcSd39ebSk5h/QuKl0iao+ST98uS5qavFaPgMW0pYvYd59z7ko6W9H1JH0oaL2mYc255BeKm79RY5Pedfkquq32U/vuykjf9SsT9KyV3z96VxvWKpCPStuWSjleS2N5TctNT9sEovRlgVRM3A/j7ekbS87mXRyv5vf6u5Jj9QsklimKcrGSI+FUl/f6XknYstIKFQ8gAAFQWU9AAAKIi0QAAoiLRAACiItEAAKIi0QAAoip5ZgAz4za1OuScK3XSvaqi39St5c65zrUOohD6Tt0quu9wRgNs2sqdFgcouu+QaAAAUZFoAABRkWgAAFGRaAAAUVXj2dUAUvff/+95NQcMCOdy7Ny5rm/+AsrGGQ0AICoSDQAgKhINACAqrtEAEfXu3TuoDxs2LCvPmTOnytEAtcEZDQAgKhINACAqEg0AICqu0QARjR8/PqivW7cuK59zzjnVDgdV0Ldv36x89tlnB22vvPJKUP/jH/9Y9HYbGhqysnONT2h97LHHBvUHH3yw6H3EwhkNACAqEg0AICqGzoAK2meffYL6KaecEtSfeuqprPzss89WIyREtv322wf1Bx54ICvnb28vNOTVFH/dQtvp2bNn2fuIhTMaAEBUJBoAQFQkGgBAVFyjASpo6tSpQb1NmzZBPX/rKVq+cePGBfVevXpF3+fatWuD+vTp0zdYrhec0QAAoiLRAACi2uiGzoYOHRrU99tvv6x84YUXBm2tWoV51v/W9mWXXRa0zZ07N6jfd999zYoTG4+RI0dm5fxTM88777ygvnLlyqrEhOrx//6V9MEHHwT1l19+OSufeuqpQduCBQuixFApnNEAAKIi0QAAoiLRAACislKnRDCz8udQqJAddtghK//qV78K2vr16xfUt9hii0a3Y2ZBvdCx+PTTT4P62LFjs/K1117beLBV4pyzppeqnXroN5XSoUOHoO7PyJvvb7vssktQ/+ijj+IFVp4XnHP71zqIQuq977zzzjtB3Z+SJt/28MMPB/W33norKz/66KMFtzt//vxmxRlB0X2HMxoAQFQkGgBAVCQaAEBULeJ7NDvuuGNQ/+Uvf5mVDzjggKrEsPnm4aG6/PLLs/JLL70UtD3zzDPVCAk1csUVVwT1nXbaKSvn+2MdXpNBM+UfBdG+ffug7l/7nTFjRtB2wQUXxAqrrnFGAwCIikQDAIiqRQydHX/88UG9f//+jS77+OOPB/Xvfe97Ze1z2LBhQf3HP/5xUPenpMkPnWHjsvvuuwf14cOHB3V/OiJ/mhBsnHbbbbegnp+h2/+aBFNVJTijAQBERaIBAERFogEARNUirtH4Uzrk5W8fnTJlSlB//fXXy9rn22+/XbC9T58+WflLX/pS0DZnzpyy9on60bZt26z861//OmjLj8mff/75WfmTTz6JGxhq7sQTTyx62e7duxes+1atWhXUH3vssdICq2Oc0QAAoiLRAACiItEAAKJqEddoBg8eHNT96zJjxowJ2vJTbZer0FiqFE4Vf8455wRtJ510UkViQO2MHz8+K++8885B2+TJk4P63//+96rEhNrxrxN/7WtfK3q9/GNMCj2K5LPPPgvq+Uc5+1NbnXHGGUHbu+++W3RMtcAZDQAgKhINACCqFjF0lucPVfgzOZeqU6dOQX3o0KFZ+Qc/+EHR23nuuefKjgH1oVWr8DPXoEGDGl32oYceih0O6syKFSuy8h133BG0fec736nIPjbbbLOgnv9ahz8tVv7rF2eddVZFYoiFMxoAQFQkGgBAVCQaAEBULeIajT8+KkmHHXZYVs7fPjht2rSg3qNHj6ycf7zAnnvuGdTzT84r1l/+8pey1kP9uOiii4L6gQce2Oiyv//974P6n/70p6z8xhtvFNzPhRdemJWXLFlSQoSopX/9619Z+fTTTw/annjiiaDuf8XiqKO
OCtr8viKF01X16tWr4LIdO3bMymeeeWbQNmvWrKw8c+bMz8Vfa5zRAACiItEAAKIi0QAAorJCUyJscAWz0laogJ122imo+99j2GuvvcrerpkF9VKOxZ///OesfPTRRwdtixcvLjumcjnnrOmlaqcW/aaQ/HcW5s2bF9T9Ppef5ujjjz8O6l/96lez8hFHHBG0bbXVVkHd73Pf/OY3g7YHHnigcNBxvOCc278WOy5Wc/qO//2o/PH2p3y59957y91FNFdddVVQ9/th/r1q4sSJWfnSSy+NGpen6L7DGQ0AICoSDQAgqhYxdJbXrVu3rHzeeecFbflhDv9Wv/xUMYWGzvynJkrhbM1SePtrPTxRk6Gz0uSn7Lj22msbrY8ePbrs/XTp0iWoL126NCvPnTs3aDvggAPK3k8zbNRDZyeccEJW/vnPfx60+X+b9957r9xdVM26deuycv5925+SZvfddw/aVq5cGSskhs4AAPWBRAMAiIpEAwCIqkVMQZPnj3N///vfD9ry9VLst99+WfmCCy4I2vJTi+Sn6UbL0q9fv4Lt5d5q7E95JEk33XRTUPfH2fPXF1F5PXv2bLRtyJAhWTn/uJE1a9ZEi6lc/hM3/eloJKlr165ZuXXr1tUKqWic0QAAoiLRAACiItEAAKJqkddoYvG/G9O+ffugLf8ogFpMM4Pm8aeDOfzww4O2Tz/9NKg/9thjjW6nbdu2Qf2YY47JytOnTy8Yw7HHHpuVGxoaCi6LuGbMmJGV/ce4S5+frqYe3H///Vl55MiRNYykdJzRAACiItEAAKJi6MyTn3YGGxf/lubevXsHbf4TFKVwGo/hw4cHbflhC/8W2t/+9rdBW/5JiIsWLSo+YFRUfsop3ze+8Y2gnv/6wmWXXZaVH3nkkaDttddeq0B0Tdthhx2ycv53efDBB7NyPU6nwxkNACAqEg0AICoSDQAgqk36Gs0Xv/jFoL7tttvWKBJUw/PPP5+V33///aAt/7d/8cUXs3KbNm2CtgULFgR1/9bY/Pi9/xRHVJ8/BdB2220XtI0dOzYr+0/ilD7/eIcrr7wyK0+aNClo86+PSNIrr7ySlW+55ZaC8flPa/3kk0+Ctvw0M0ceeWRWzj8mwH/ibz3ijAYAEBWJBgAQ1SY9dHbccccF9Xbt2tUoElTD2rVrs/JFF10UtE2dOjWo+0Nnt99+e9CWHw7xt4v64g+Rjh8/Pmjz/8bjxo0L2grN7p2fNWTEiBGNLuvfFi19fsjLv909f+v7IYcc0uh281544YWil60FzmgAAFGRaAAAUZFoAABRWX7MsMkVzEpboY7ln6I5efLkRpfNz+46a9asKDGVyznX+PwadWBj6jcbmRecc/vXOohCqtF38jNyn3zyyUHdfxrqrrvuWvR281PFlPp+61u9enVWnjJlStB2+eWXZ+UqXjMsuu9wRgMAiIpEAwCIikQDAIhqk/4eTSlWrlxZ6xAAROJPBSNJN9xwQ1C/9dZbs/KXv/zloO3EE08M6v71nlGjRpUd07x584L6hAkTsvJ9991X9nZrgTMaAEBUJBoAQFQMnRVpzJgxQb2hoaE2gQCoOv8JrLNnzw7a8nXf6aefHi2mloQzGgBAVCQaAEBUJBoAQFRco2mEPyYrff4pegCA4nBGAwCIikQDAIiKRAMAiGqTvkbz+OOPB/WHHnooK99zzz1B25133lmVmABgY8MZDQAgKhINACCqTfoJmxsTnrCJMvGETZSLJ2wCAOoDiQYAEBWJBgAQVTm3Ny+XtLDSgaBZetU6gCLQb+oTfQflKrrvlHwzAAAApWDoDAAQFYkGABAViQYAEBWJBgAQFYkGABAViQYAEBWJBgAQFYkGABAViQYAENX/B3ffCNdAnzqjAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def plot_img_label_prediction(imgs, y_true, y_pred=None, shape=(2, 3)):\n", "    # Show the first prod(shape) images with their true (and, if given, predicted) labels\n", "    y_pred = [None] * len(y_true) if y_pred is None else y_pred\n", "    fig = plt.figure()\n", "    for i in range(np.prod(shape)):\n", "        plt.subplot(*shape, i+1)\n", "        plt.tight_layout()\n", "        plt.imshow(imgs[i][0], cmap='gray', interpolation='none')\n", "        plt.title(\"True: {} Pred: {}\".format(y_true[i], y_pred[i]))\n", "        plt.xticks([])\n", "        plt.yticks([])\n", "\n", "plot_img_label_prediction(imgs=example_imgs, y_true=example_targets, y_pred=None, shape=(2, 3))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Mj3utDDuzDCj" }, "source": [ "### 1.1 - Logistic Regression\n", "\n", "We can use a very simple Logistic Regression that receives each input image flattened into a vector and predicts the digit. This will be our first baseline to compare with the CNNs." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TniyY4bQzBMS", "outputId": "54a7e07e-3078-4a71-95f6-5670a051b4b2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test score with penalty: 0.9030\n" ] } ], "source": [ "# Flatten each image into a vector and standardize the features (scaler is fit on the training set only)\n", "scaler = StandardScaler()\n", "X_train = scaler.fit_transform(np.reshape(X_train, (X_train.shape[0], -1)))\n", "X_val = scaler.transform(np.reshape(X_val, (X_val.shape[0], -1)))\n", "\n", "# Multinomial logistic regression; the loose tolerance keeps training fast\n", "clf = LogisticRegression(C=50., multi_class='multinomial', solver='sag', tol=0.1)\n", "clf.fit(X_train, y_train)\n", "score = clf.score(X_val, y_val)\n", "\n", "print(\"Test score with penalty: %.4f\" % score)" ] }, { "cell_type": "markdown", "metadata": { "id": "A8rylkCnrwIy" }, "source": [ "We can select the coefficients for each class and reshape them into the image shape to plot them. This lets us visualize which pixels contribute most to the classification of each digit.\n", "\n", "But what happens if the digits are not centered? 
Will we still get such a good performance? Lets test that out later!" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 339 }, "id": "2pucfjpaDF9_", "outputId": "3d370f1b-27ce-4a5a-a05a-62e25f560876" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAFCCAYAAAAe+Ly1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAABvWklEQVR4nO29ebxk11Xf+1s1V9156HnUbA3WZOHZxsaGxAw2BJIYgh0nQOL3yfBiiJMH4fMeGRiSvBDyCfDIgzAEgp0AD8cEE0xiG9nGxrIly5ZsWbKkHtXq7tt3nurWsN8fVbf2Wqu6ju69qrq3W+f3/Xz606dqnzpnn7323mffvSYJIYAQQgghJC1k9roChBBCCCG7CRc/hBBCCEkVXPwQQgghJFVw8UMIIYSQVMHFDyGEEEJSBRc/hBBCCEkVXPyQ6x4R+QkR+a0BXv9xEXlT+1hE5NdEZE5EPicibxCRrw3gnsdFZFlEsv2+9iAQkdtE5IsisiQif19EyiLyByKyICK/IyJ/TUQ+uoXr/JiI/Mpu1PmlgIj8byJysd1Xpva6PoRcLwjj/JDrARH5PgA/DOBlAJYAfBHAT4YQPiUiPwHg5hDC9+9CPd4A4AMAbgshrPTxuqcA/GAI4X/265q7iYj8RwCLIYT3tT+/C8DfA/DaEEJ9D+pzEsCzAPK7ef9d7ot5AIsAXh1CeHTQ9yPkpQR3fsg1j4j8MICfA/BTAA4AOA7gFwG8Yw+qcwLAqX4ufF4inADwuPv85F4sfK5nRCS3jdMPACjBtvtW7yMiwvmfpJcQAv/x3zX7D8AYgGUAfznhnJ8A8Fvq8+8AeB7AAoAHAdypyr4VwFfQ2j06D+Aftr+fBvDfAcwDmAXwSQCZdtkpAG8F8AMA1gE02nX6pwDeBOCcuv4xAP8fgMsArgD4+fb3NwH4WPu7GQD/GcB4u+w3ATQBrLWv+48AnAQQAOTa5xwG8OF23b4O4Ifc8/9XAP+p/VyPA3ggob3uBPAn7WtdBPBj7e+LaC0yn2v/+zkARfW7b0drx20ewJ8BuLv9/cfabbLerv8HAGwAqLU//wCA9wD41Bbq4GX56va95gE8CuBNquwTAP45gE+3n/ujAKbbZWfa7bfc/vca1waH2+09qb67ry2bfPvz3wTwVQBzAP4YwImk+gP4i+65H92i7H4XwG+htYvzgwBeCeDz7c8XAfzsVWR4K4AV9Ywfa3//WgAPodX3H0Jr502310+222sNrR2qPR/j/Md/e/FvzyvAf/yX9K/9QqmjvQjocY5/Yf5NACOIL/MvqrILAN7QPp4AcH/7+KcB/BKAfPvfGxDVwqcAvLV9/B7Yl/ib0F78AMi2X9D/FsAQWn+Vv75ddjOAb27XaR9ai7KfU9fp3KP9+STs4udBtHa7SgDuRWtx9U3q+dfRWthl28/y2R5tNdJugx9pX2sEwKvaZf8MwGcB7G/X8c8A/PN22X0ALgF4Vfsef71d52K7/BNoqe16yaTTbi9Qh87vABxBa7H4rWjtUn9z+/M+dc+n0VoIlNuff+Zq7dejLT4GuxD51wB+qX38DrQWKrcDyAH4cQB/tp36q+u+kOxqAL6z/YxlAJ8B8K52+TBaaq2r1d88I4BJtBZq72rX+Xvbn6dUe51Ba+GWQ3uRx3/8l8Z/3PYk1zpTAGbCNtQnIYRfDSEs
hRCqaL1c7hGRsXZxDcAdIjIaQpgLITysvj+E1l/3tRDCJ0MI2zWIeyVaf+W/P4SwEkJYDyF8ql2nr4cQ/iSEUA0hXAbwswC+cSsXFZFjAF4H4B+3r/lFAL8C4N3qtE+FED4SQmigtZN0T4/LfTuA50MI/6Z9raUQwp+3y/4agH8WQrjUruM/RetFCgB/C8B/CCH8eQihEUL4DQBVtHZmtktSHTTfD+Aj7edqhhD+BK0dkW9V5/xaCOHJEMIaWrtf926jHr+N1gIBIiIA3tn+DgDeC+CnQwhfbfe9nwJwr4ic2Eb9tyq7z4QQPtR+xjW0+uLNIjIdQlgOIXx2i8/zbQCeCiH8ZgihHkL4AIAnAHyHOufXQwiPt8trW7wuIS85uPgh1zpXAExv1RZCRLIi8jMi8rSILKK1OwG01FoA8N1ovTxPi8ifishr2t//a7T+0v+oiDwjIv/HDup6DMDpqy3UROSAiHxQRM636/Vbqk4vxGEAsyGEJfXdabR2RjZ5Xh2vAij1aLNjaO2W9LrPaXePw+3jEwB+RETmN/+1r3UY2yepDpoTAP6yu+fr0VqkbuKfe3gb9fg9AK8RkUMA3oiW6vGT6t7/Tt13FoCg1eZbrT+wNdmddb/5AbR2s54QkYdE5Nu3ca/T7rsXuhchqYSLH3Kt8xm0dhi+c4vnfx9aKou3omUvdLL9vQBACOGhEMI70FLtfAit3QK0/3r/kRDCjQDeDuCHReQt26zrWQDHeyw6fgotFcXLQwijaO1qiCpP2mV6DsCkiIyo746jZbO0Xc4CuDHhPifcPZ5Tv/vJEMK4+ldp7y70sw7+vN909xwKIfzMFn77grt2IYQ5tOyE/ipa/eaDarfvLIC/7e5dDiH82QvU3993K7IzvwkhPBVC+F60+ui/BPC7IjL0Qs+Dbvm94L0ISStc/JBrmhDCAoD/E8AviMh3ikhFRPIi8jYR+VdX+ckIWoulKwAqaC06AAAiUmjHmxlrb/kvovXXPkTk20Xk5rb6YwEtA97mNqv7ObRsQX5GRIZEpCQir1P1WgawICJHALzf/fYierxQQwhn0bK/+en2Ne9Ga3dgJ7GN/juAQyLyD0SkKCIjIvKqdtkHAPy4iOwTkWm02n3zHr8M4L0i8qq2p9CQiHybe6n3ow6a3wLwHSLyF9o7eiUReZOIHN3CPS6jJb8XWmT9NloqqO9BVHkBLfuvHxWROwFARMZE5C9vof4XAZzc9KTaiexE5PtFZF8IoYmWoTewtb74EQC3isj3iUhORP4qgDva9SWEKLj4Idc8IYR/g1aMnx9H66V2FsDfRWvnxvOf0NrqP4+WV5e3l3gXgFNt1dN70bJzAYBbAPxPtBYonwHwiyGEj2+zng207CtuRsuw9BxauwpAy37mfrQWVn+IlkeY5qfRWnjMi8g/vMrlvxetXaznAPw+gP8r7CAmUFv98s3tej4P4CkAb24X/wu0bGq+BODLAB5uf4cQwucB/BCAn0fLiPbraBkxb5sXqIM+7yxau3g/hij392ML81YIYRVtz6Z2m/ayTfowWrJ/PqhYOSGE30dr1+WD7b7yGIC3baH+v9P+/4qIbNqTbVd2fxHA4yKyDODfAXhn2xbohZ75Clr2SD+C1uL/HwH49hDCzNXOF5E/EpEfU5+X23GsIK3gncsvdE9CrlcY5JAQQgghqYI7P4QQQghJFVz8EEIIISRVcPFDCCGEkFTBxQ8hhBBCUgUXP4QQQghJFVz8EEIIISRVcPFDCCGEkFTBxQ8hhBBCUgUXP4QQQghJFVz8EEIIISRVcPFDCCGEkFTBxQ8hhBBCUgUXP4QQQghJFVz8EEIIISRVcPFDCCGEkFTBxQ8hhBBCUgUXP4QQQghJFVz8EEIIISRVcPFDCCGEkFTBxQ8hhBBCUgUXP4QQQghJFVz8EEIIISRVcPFDCCGEkFTBxQ8hhBBCUgUXP4QQQghJFVz8EEIIISRVcPFDCCGEkFTBxQ8hhBBCUgUX
P4QQQghJFVz8EEIIISRVcPFDCCGEkFTBxQ8hhBBCUgUXP4QQQghJFVz8EEIIISRVcPFDCCGEkFTBxQ8hhBBCUgUXP4QQQghJFVz8EEIIISRV5LZz8vT0dDhx/Pig6tI3wg5+I32vRf85feYMZmZm+lJVynJv6acsgZY8j18H8tQPvBPZXquc6fPYvN5kuVWuB5lTllvjepAlADzyyCMzIYR9/vttLX5OHD+OT3/60/2r1YB4qb4wX/e61/XtWpTl3tJPWQLA8etEnr0WP80EQWeuA4H2U57Xoyy3yvXwwkylLENTfdiaQuh6kCUAVCqV01f7nmovQgghhKSKbe387JSdrhDNXxbNeu8TM/YxamoRm5WrH/t61V0lc9fBX5t7wSBW+1ttar07sFpvmrKNRiycKvSuZcj07vK+HknP+lLpHonPEZpJpT2pq7+pam5Lp6QGYZIKTP/u+WU79qcr2c7xSNIM5uuv/6JN+Ov2evmL1tOXPunbbKt9IGm3QF2jq476+n5sbnEHIomXvCy3IS/R79BGwvs0G+UQciV7jY3Vq18PMPILuYK7eYIst9h3Qh/6g4Y7P4QQQghJFVz8EEIIISRVcPFDCCGEkFTRN5sfrVtdcwY05Z0a0CidYpd+sal1mxumqKj1nlrf3Gz0vFXW6RODqDon6RqztgmTbEqSbA50i+21PclO9eTavGNbHjoJngb6mnX1ISv2BoVoBoINZG2ZbvYk2zF37y3bwziZX0uyBPpkP+DL1DMve9W/xBbwNj+ryiBPFwXX6ZqqFUeLVi5G9sGOaamtY0s4Wdfzlc7xmrMny6j77Xgu6xN9uXvSGHByliSbHz0H67naC1Nf3o1bI4eQIDs/r2bi77rmXHVNuU5su7raeas2VA377jMy8bLcWOl9P20DlNmibZwvU3KQRs2em1Fzsn9n6ut4WSbZi6nf7USW3PkhhBBCSKrg4ocQQgghqWIgru4F51MuSdvnelsrYSvOb6OZS2SsmsOcu8Wt+64tQrWNm9lY63k/79InCVt4oVCOx9ltuALuMltVX3l5GdVhkpz9tnuCnDPqmrl8MV7Ot5faSpWNZXuN6lIsq1XRC9+Pklw+QzZ/1fNaN+zdr+DlvtcoWXj1wXozCj8E2zbLa/G5Lq9ama0rtdHyhpX1jDq31ojnDRftvUeUHjOfsbLeNxTbcF/F/m68ENVXXgVm+lbT9oOcktNw3sq6LqpvOXn22/32RaFkeblqB65V1zmVn3qkRrC/azSjHBY3rIqxqkRbdOPD3jveL5/x6ur4uVKw9conqKu1y3UmWFnqsdnlcq1VM66e15IabMuqSeeybvq4U4np8SB122ZafabNQ5LUlF3vMO3qrudHACgOq7KEkAYJ7vNexafv3zUuk8xPNi/9gmcQQgghhLyE4OKHEEIIIamCix9CCCGEpIqB2Pz4NBJGmbodV0r9M6dDNG7r3j26NNI5rmeiXnClZu+lUyIM5e01tC46F5xeVdtJePdoHY7b6R3177pccrVuusvtvvV5UDppf13jqp1g41B3a2dtFdIlV/257twz1TWbxSFTtNSMbTi3FmXuXd3LOS1Le41cIeqbvRvz7Hq8pu4PrcrEw6Fgn3VE2aGMNJ1tST3aIwRlg9K6ZoIuv08kus162wkli2rW2vUsVeN1Gs59WbfjuUX7/IvKGOTyipX1M5eiPVYjIZtpWdn8jFWsbcHt+6M8sweGTVlBzRO2FwCyNh+Pnc2DtsVqKts8AMhVJuIHP6bd//3Gd8mklCAbyj6nWrdy3mjEX3pZLlXjGFhydj1za9GGpOrGzrCWUSm2y1TFztVZ1f2GnV1PJRvrIg03jlYWOseZtQVTlqlFO8xm3spLy6ipZQc3v3hZtsfJoGS5HRujDfVqLvgtCv3+cfZpxs3f2/yod6afI3RZc2UxXr/e29Y2MzxuPofyaDx2tlb6nbzm7AfrqpPXGs7mTPXVfMZec0R1rFzdv097VrsDd34I
IYQQkiq4+CGEEEJIquib2is5EqhS93iVh3ZR9m7IetsuyU3duShrF7gFpdZYqNot3bVa7239Sj5ec7Rgr1/Ox+uXs07tpY5zfhtXPZ93IRS1S93LfX63tmOXN2K7VPK9VVtebSHKjTXr5JWpquiitVX0wruUVypTneONfO+NYr0FX1qfs9fUoQmyVk2io0YvrPdWSRW866ai6VRbWbVFH/bA1b1LFauOV5u2Lzcltnfw8lQ/XNmwz5GkDlmtxXE2v9p72zxJ7aXL/Hn6+nWv4UsaJHpc+Wjv2tW35samUlF3qRkGLE9vQqA/Ljk1/lotttO605cFNb95VeS6Cjmg2xYAFhPGRK8wGD40QSPh1VCp6BAVNqSIfleIU2s01+J8EhaumLLsWJwzuswl9Fywy7L0vT3pnanfR8s1/8uE0AE61INXuWdiO2WXLtr7qTZszF2Kda66d5iWrZOzDE12jr26Uc873vxEq9DnXX/TEeFPjBVNmYqggHLezcFbeFly54cQQgghqYKLH0IIIYSkCi5+CCGEEJIqdmzz06W/NIUJ7uw+tLUOz+1d81QIc5+RXaeK6OV+6o+9e3RN6cVnVu29tc51smz1xkdHo+4xWbVodcjazVPbwACwNghOl7qVUN39ZFxV22v8tYq5mPGu4dG+wz9fZiXqlGVt0ZSFjahXzhSsHl6q0TV6evRALHBh3bMzz8XrzZyzZePTneORQ3eZsjVl21Vw+vOiUiqPFa2tTLEW6yXL1lbB1D9YPXWXDdAA8H3Su0trtM1M09m9aVsebwui0W7OnsyErc2R8ShfbVdXdLZzuu1HCvb6E2o8jjg/YCPD2jayl2ubEm8zpd3ifcqaAY9NH34hp55vbt3KRNe65n6n5zcvy0vKBmip2rvNSjk7BvLKsEKbZa07W0dtlzJWsnOpdnMeK42ZshHllu5DRkgpjr/MsrX5CTVlK+TS5eylLBPfFS4ERQixnXwfyOo+3uxdlmT30px93nyufvXznePnP/fVzvGlR8/b8xZj2x6894ApO/Htb+wcF+/7RlNWUfPuhnNnX1WPsOzscnVfrTXt+1Q/+aoz/hvJvvA8y50fQgghhKQKLn4IIYQQkip2vM/XtaOWlN242Tuzut6W9FGPdRRP7wINE0nZbsf22uavJbjMzq3b7VHtDnrFbdXqLfmic48sqTK/XbkqKvpszqpDZEOpiZq2jTYvM7CswwlRnLP+XI2PFmyyBjtZqudrLNptau1mKU6W2QNRDhnlqpyUUVhKdotc9x0f9XS4EOWQdX24KLF/ZJz7vFHjuWcNKqJ0Q0U2bZ08mL83BGpMOnmuKjF51Z7uo358aHdpr/LQKuRhp5bS6qwJpzKeVtGadRgFnfkbAIoqC7nTxhkX14JTl2WbV89ODdh+4MMBZFX4BS9PaUQVe2jacTsIeQqiqYDPsl5UH327nFmMKp1ZN581lWwzrg9omRwaLvYsG3aqX63O0mYCl1bsvS8tx3otuHqVlPyyrl41FXF8qDRpykrKhd2rtvRc02VKoaMjwzFgWSbi5rOmcvH27aLVWd7VPafC0kvVhhTRc1btuWdN2ewTpzvHWtX16S9Y9dizSrZvfW7ZlI3fcqxzvO/W++y9VfiPEafe1Or2gyOuHZSQJkq2/w2r+cO7z29Fltz5IYQQQkiq4OKHEEIIIamCix9CCCGEpIpt2fwEAPW2Di4p7HqXvk3pPL0ePinLtcnW61zktV2Fd4NcVvo/c+zc6HRYd+8aqlnasHW8tBz1yOPOzTef4F+o9ZKVotVtZtZVSgTn6r55yUGlt/DySsg6YPD2QKKNELwdkQ4zn+udKiL49CfK/qlZjPYzc7AyX69EPfLBiaM9r+8ZakS7Mqm58PrKzT5TXbLV0iHgXZ/OTB9RJw7etd3TcH/TlFUX7Rq36rNP8VJQJxfcD0tqPPrfjSrbEP87rbfP6zQnPryDss+pu/6o7ZS6h1u8jrcF1Ij33Nc2F2tW1lD9
blA2W72o5OwD6qbwz65tGH1Yj4qSyYEhO/fsH4rjcbrsbLuWL8fjxQumTI/3RiXa5Ozfd8Sc95g6Pr9o7alKuWhDUnR2X9omzE/PRZVBvOnd2XW/cjY/hswe/+2f0Je0LY+fLXW4gyycDa2as/y7Vts7ZlQKEACYfNmJznFjLZ53xzML5rzD6n1361tPmrKJVz7QOQ77bzBlTRWaxqdF0fNF2aVTMu3gQihApdwZ9qk8kuS+WY8XPIMQQggh5CUEFz+EEEIISRXbdnXfSrbUbaFd1n02c6US0yoPAGiW4rbnoss4fVn59l5YcpniFdqt07t46oiofqt2eSNuJ86uWpWY3ub3rqg6e20pZ5t+XKt7XPRibCFa5YshKQLwdjCZ6gs2e3pTqX8yo+756nHburlioz9rauqaDz41a8qemonurd91p408euNobOuM2sYHXLZ5v1WqVbJVF8VZbSnrDNMAkNXhG9yWPDa3pQcY6dlvhWsCrq5SvRqTSkU10nRus6vK9T8hInlmyW6b42zMJh2U3DMqCjcANCbiNnytMGrK1tU48q7gI3k1rnzEXi1P51psXN+9aj4xjEdCFOkXQ497rikdoJfycDE+b8WpD8ZU2aFhq0gZuhgVU9WP/6kpe/6Rr3SOV87PmLLRkwc7x/v+wts6x43brdrr7EIcO2fm7DhqhqgOGSvaepVyCe2ux49XsSeEtkiM4jwoWW7hfuLubUKo+PAw2oXd9xP1fD6cQ2M8uqLL8D5TVrrh7s7xwYkHO8erl+fNefmh2LYn3/VOU9a8N/aBK24qLSqjjbJ71rx69+VdmQ5boN3lASvnhsooDwBwa4mrwZ0fQgghhKQKLn4IIYQQkiq4+CGEEEJIquhbegttN5INvXWnwblgitZtepdrda628QGsC/tzS9auYm4tftYZ2fdVrE55f1GlL9iwuui18Xi/M84e6LmlaAN0caW3TVHNuTlrV9RcpmzKxtVxV1j+nnfYOWEQ19XpJ7puqEIT1HpnQdepLgAAyl1Zhwr4j586ZU67cjG6eL7lJms/kslE+5Tcog3XHhqqrZ0dVsgr90mXXiVTjPITZ6dk+rjP1lxuu+Tvssv0Jj6MgXYt9TYWw+vKxuOpz5my+pXYjs3leVNWW4nyrS1aW6G1WWcD1Gbq5beYz6V7Y4bo8vH77TXUcPchKjaUp3Yhb2Vm7D9qbtyqeUiGxk1RU1/Hy23QcnT2HjqdR90JU7usT7oQHOO52A9zpz5rymb/+MOd42f/+FFTdvHLMaRDwbnI36VsfnDrazqHH3923pz3oUee6xxvuOzb5Tv2d45Pjts5cUj1zS4XfOV2L27uNnYv3rZLpRXyNjEDk+Xmdb2dn/7sQ2JIlJd/N2k37uDS59Szsa8u+Azppu2ta/jwSLQBmj4a01uM33TYnDdx752xysrGBwAeej7Wc8ylhDo8rNIPNeycIPU4FrtSXJm0T25dodpBXwMAmt71/Spw54cQQgghqYKLH0IIIYSkih2rvbx7tHGv9a55ekvKuRo2i0PqxN5uez5D+uxa3NJbqNrtsEMjcWvz5FjcfiucesicV/1K3MqvO3fl8o1xe+/WO77RlGmV1bPzdkvylPqs1VyAddH12a5DRamFxIfh3bs1qo/GaQudmkh/8HVW7ow+y3FdRUtefOLrpqw0Fcs27vqOzvHXHj5rzsuXoswnyrZeWtXVdC7rUo7t3qxMmDITXsG7TqpwBJkhFxFYqVT0NjsQx80gVJk9UdvFWacGKKtQx9kVGwag/rmPdI7nHn3clC2fj+euXLTPv3Ylbmuvztj2ri7GflCaiG1TnbMZoo+rKMGZyWOmbHIoRqf10ax1OImCc/fWao5MzW696zmrWR4zRU11v97hEHZVogCAghuc2r19Avb5chef6RyvPfKgKZv58tOd49UZ+7vKdJzrbniLVU0e+Bt/r3P8ofOx/f7fB58x5y2rOfHYEdu240o9cmDYRZ4uxTbNXTll
yjAXVWmZ8pApahbi51CwqrStuEAPjITwFiG79VdxUJkPqhmr2j2/GHXCj1+2Y+rZ2SjbUacWffMN0VV8Uqnxy/vtnFi681Wd489dsqqmh85Hlfarjo6bsuxolLNU7e9kI9ary+QjkxC6Qqs3807OW4A7P4QQQghJFVz8EEIIISRVcPFDCCGEkFSxLZsfQXRx7wqLr1Te3qZD2xyEotXPVlXO2ob/mfpiyaWwuLQSde+Tzn7mhtHoFikPfahz/PzHP27Om3syuvRVF60ecv990abk0JB1sz940+s7x7Nrtgm1m/2pWas/126eR0atLcjdk9G+RHxm8wHj7Xr0xy7bLl3o7Hp09m3vuVnU+lrnsli7ENv6yuPPmrLhI9FtfUm5bi5deNqcd8MDMaPw8VHbH2pfi9fMumzGQaXMaJatfruuXN19O+iE4TnnUq1TZng9/2b79TtLTCJKTlq/Djh7vDOPmbKZh74Yjx+zNlbarmfpOWsvNzcf3VU3nDu2tlMpNlRoBNdhGksxPEHhkpV1ZjTaGMm4dcXNZ9S48qHyVXh8k54DNg1J3dn8nFmO/a7etH3+xpHNz7sj0ZwKI3LQpanQCeAzc/b5GpejjUx93c4vxfE4Bg6/0trETN0Zs3OPvf3dpuzj6zGNzC8/+FTneGnW2nmNTcdxdPfxcVN2x7547yPuefIXY2qNxsUzpkynxOlC2X94d/bE9BaDxt87IfO4fgeErG2XFWVzenrOvrc+d36+c/zRxy6asosqBdDLb7TpIL7lJmXXpsJ4lI+fNOfVx492jp84ZW2KLql3aM2Ne53qKRTsGkDbRXqbH7N2cPaTuj19GqGwhfRB3PkhhBBCSKrg4ocQQgghqWJbe4BJUYElyR1bbU951zydIX154+oqAgC4sGy3CJfU726ZtG5u2a9FV85z/+0PO8cXHz1nzpt7Zr5z3KzZe5cmYj33X7FRgSs3xu0971Z9frH39ne5kO1Z5iOR7gabqhyvwkxy2tVlXhWktzp9bzCunM7dfFlli144bSMADx2M27OPPh/VHZUpq+6487aoHstfetKUVdWWecapMJGLW8re3VQ/XgjuYXU4goTIreJ+tisO0Ulj0dVVq390yIHWqfHcjOskOeWiPHTARlPV7tHDh4ZN2fChqFIqTUVZTNxq3dkzhTj+wrpT1Y2o9l23bvalvMqUvWHVcdnl6J4fXFRu6FATbnv98lzcUp9ZtfPQ/krrGfxY2AuSXIZRjO1ZmLDq3ckY1QP5ITuXlu5/c+f4TOWkKfvoo6fjNVV08PtutxnDv0G5UR8ZtfP/PhWVuiszgFbX5t386D/rnyWoPHT29K45d8AhRdZcJylrtY1TR4di7I9L7t307FxUK//JUzOm7MEnYh9/7rRVfdbVO3P1qFXtNtTcLdNRtVWo2AjSy6Uoy/ML9n16aTHWy4+TmbXY1ouu3YeU2UC5aGWg1bzeHETP112hArag3uTODyGEEEJSBRc/hBBCCEkVO/b2SjyvuuK+iGssra4CbPK1K6vWYltHRPZl+1WSvYlgrc6XPxfVXpcfi14Os1+324AzS3EbbdgldSyrLXlxW39NleSy0bRbcRnlzXJsonfUyZGCi3StE8P2iJzc7531Lo+9LaB3bn2STKMZ8h5kais6OE+NZi1+loyVw75X3N45/uK5+c7x1NED5rzX3xLVXrUv/YEpWz4d+8DEoRtMWUbVRRp2210yKtmlT8hrTzRl2jsjuIiyg5KlwScATNjOD8ptLTNqPeGm77ujczx8xKoymhsqUeaQVWUU9sVklYWb77Y3HFZeJmr7W9cDsN4bTZe8UT+feI+u5agGEBfFOSgPsuDUQqLHnNtC18FwffLXcvtzYiT0ndAeL3X396n26Gq4AZjLJbSnlrPzeCyWovdNdsLJWSW89IHnv+F4VJ+9/sZ4TR1hH7AJpbN+HKmPs875afLAy+LvXB8QPW7rNhmmVmd5FVjQY9yrRgas9vLPriP5
NwpWPXx5NdZzsWqf4ZELUWWr1VwAcOHMfOd4fcXOs0VlorHm3sNPqMje+47HSN7FaXveVy/Htn74lH2fLihP569dsu/kpuqr+12i8KMjUV5N59BVysZ+W3RetWYecAlRg0uEfjW480MIIYSQVMHFDyGEEEJSBRc/hBBCCEkVOw536W0WtG7V62AbI9E+Y33V/lK7us84ux7tLletW73nbdNRT527ZDNOzz8Zo9HqLMXra9Y+QNu8HHyZjXg5qSKbZo7caspmlQ627vTuh5Q+c6xom1e7gk8PWVsQYzPgot1uhhHYrajA+pG8LUNSmW7PrC+s9o5mWhyP+vzjb7rNlBXe8N2d43N/Et3g865tX3NsvHO8/CkbEbg6H92hw4btm7qv+iih2iMzKQq2txUyLro9op0PVJYJbp51Z1uQVXYimfK4Kcsfi/ZWxarLXK/u4bOgN0YPdY7Prti2mVVjcFWN/QkXpb1cim2Yd40/ORTtAArrNjSCsTf09kAqzEFw9hfNSrRfmW/asVlrxNAWc24O2Yw8PyhX95zrKGuN+IWfe5rKrqfiZJJT9lwZ5+bfqMYx0Fy4Yn93RUW6d+7R33ZLnDNzyr5Kqi6MgBoCup0BYL4e+9Hsum3beWUPOl48aMomR1RYDW9jmhA52YR62OVoz3m31aD7zLp7vwVTZu1u5tR7subTIijG99lIyiMqJMx4xY43nTHhmfl4rDMWAMDvfDFmRXjW2fwMj0WbnCUnS/3uq/iGUHg7Nv1S8eEANF02PozwTAghhBBi4eKHEEIIIalix/t+Xdv2Ru1V61lWb9pfrqrolas1u713RW3FDTs1h94686oMTVElEB07YK9/UEVxPvaNVt0ydN9rY52nbzRlC4vx+XJuS35M+cXqY8CqTsrOZda77GoG4R6dFLZA19OrbTIJ7qA5ie3rVUEZF21XM3xrVCuOTuw3ZbNDR+KxitA75BLD3jymts9Xe/eH5pqrh36+pEjN/rm1vLyaUqnPQsb5bg7YnRZA17a/dv3VqgQAqKlEnY2m3SZvqMSuQxVb74LajvauuKfPRTfXL1+y6jLtHltUY+BupbYErBrMq49vmYr1OjpiVTF5IxfrHh10wsuCjUpdDVFlNLts++6jF50aZzdo95MFF/W+qrIHe1kWlUwODVs1QEapwTIjLoHvxajaWr9ko3w3nozR0suHP2+vqVzkGzp0gJv/c0duir+58T5TtpaJ6rizCwnj1kWGzkjsE8Mlp+JT4Ue8C7SJfJ3gIt9X2vOIn0N0PUdcdOvhIds/NbepZLDV2+18ef5ALDs4bkOt3HUwjgevZtYhHLSq6xPPWDXo5x+P/SM4FdVxde97j1iZ3KXKhgu2HYbUu9y/F3NaZ+rnTj0/V61rvY6Q3Qvu/BBCCCEkVXDxQwghhJBUwcUPIYQQQlJF/3z9tuBaBnSHSNdurD78tybjXVOVulFcpu5xlSG6puw/xtetLcT4LfG8qTd+oykLN7yic3xp1erW9b292553y9VoO4nmNrKEz6y17l/fWhNvm7qrinGvdXUxLt7O9VBqMVu7t2GSDVXm5JW//ZWdY+0mDQBPXom/026d999kQxNkVqJuujju3LnXVbj7vNV1B5WqxGfzNjpm3w76+XwIfSXL0NNWqL++0UFdUZwL77ryqV12NiTadfvCctWULasUFhMl227ans276Z6ajzJ76JlZU/bMueianssrl3Wn6z86GW0eitPWFqmqnmfB2b2MVaINyYazSVhT9oVrLuRGVdlpeduTmeU4b2i7iUGhZVlyOWj0tPHkFWu/VtHt6X53SIUbya+58AC5KNuFp8+bsrmvX4z3drZt2p4yk1WhCUasvcqRvxDll3mZleVlZT+56rKX6/F+yDX7skvPoBkpxPGer1l37ET39gHY4wXEOaDrzaBsjDILF+zvlE3aIRceoHgkNsZNk71tg8aLNsXJeCl+9qmNltS88NRsHL8bbmxPKvf5E9P23u+8
P2aDf/kBWzbaiDY54lLLoK5Cg6y5lDRJdlhKljqlEABkVuf82d0/f8EzCCGEEEJeQnDxQwghhJBU0b8Iz6bQu/7G7fOhvFUtbKjt85lV506rtsJrbst1SW3J6wjSAFC+KbpOH1YZp6Xosk/feFes8ol7TNlCLm4t1pwrbzkfn9a75untZh+tUqvEml2boDp0shXLvvLm79Ff2ioYn0XekODinXFRfzNuO12j1UvNIZtVulmIW6nzGauymlmNbsZ3HInqstfeaK8BiTKqnDhh7622WXMHjtt7F+O9fQZ2e6ILReCjOuv7qS3r0KNtQ59jPJvQBU5m5YxqG9eJTs1HtcOlFav2WlRRWmddpNdJ5Sqbz9hrPqfURj57tHaPzajxMOzUakdUtFjv6q5ZdqoSreqqu7DLWkXmVXVa1XXRtUO5ENUFdx+w/XOzPfuZ1V0QQ0zkXNvOKZk8O2tVBONGJs5MYCK6PR+dsGMgdzS6L5f3P2XKVp6Pasvqom2X9bnYZjk1j5cmrVpbZ4qfhVWHXFqOc4iPJqxDIfgoxxlR2b7dPJtNcHXXLtC9xmY/EQRIO/SEV+HoMCLBZSzPrMd28RGs9ym1+j43ZxnzgkvWTb1+Oao0s2N2/swfje/CQ8OxjR44Pm7Ou1PpH+/ab3WRt4/HeuVnnjRlWFPvisLWzQtCUUVlzzmTBdWewUU096EsrgZ3fgghhBCSKrj4IYQQQkiq4OKHEEIIIaliW0pP7YLpVO3Qmsfg7Hq0vUQ5Z0Nur+eivvbQiP2ddk1dqPqM7/HzwrjVX47f9YbOcf6maIeibUsAoDkabYUWg9WdrtVivcoutbIOBe7LdLtkvT+hIu+LEtz8B5HeIgBotEPEZ5PCFDj7EUlIB5GkQ28Wotx9Zucqoi53dd3q9g8Nxz7x+pumO8c3TVidbijG1incakPom/Qn4zY7tE55kOjq6sqMjt6VNfTfFAPK9p1IQpqOkrNR0yEkvDu7LnvOuX9rW55VZ9ezvB7H5pjLHj18Msr+VuU2/rqTNnTBoeE4Hou53u7e6y5Ow5VqnGuWXL10+pymsxMZVnY95bydo7QLuWfThKq/Flzo9Ck/Nhvq4edX7ZyoM2nPO/sZ/ezBZfs+dsMDneOJ11s5F5Tb+vL5y+hFeSraXAzdfLMtvDle//KqtZXT2cQ9WiYeHRZl1KVLyKzPxw9+TOs5aovhWV4MAdKxTUlMYeTeTdC2OxsudcNctNFqXLEu8vVL0a5H22sBQFPZyY7eYEOKFJU9zdEbX905zmVdihg1bA4O2fk+O3s6fliz9qDN9Wi31DXLFsv+m6ui7UaBbvf27cKdH0IIIYSkCi5+CCGEEJIqtqX20u60Gw27ZVhQET67ojKqrUYf2biktrSnXKbZ2lg8tzFvf6fdIp+etS6YR0ajamNo6HDPe+totyH451Hbqi5SZk5n4HU7pzrbcC5pL7xLNaEOvarJ/d8PBCrK5zZUM1q1FQrW7Vd/Dk6NF0rRZXHVqSq023HOuegeH4t96eBI7B8+GnjIRRk1Jo6YMr3d7LdODX5bWsvBb5+rz775fPBuzUbb/brfG+4mwrMv1OPPuXjrdvSZnmtKNTTu1FdXVNRjreYCgLFKlNlRF4H2zv2xj9yuIjdPld0Yq0U3bh+CYEFFhNUhLwDgjFLPVd2zatfpJJWK/91x5XbvQwVsqoH7OTa1LH2/0/LyUbF1NF6vEnvaucVralNRDje8/K2mbFS5RFcunrH1VNnbdQiJ5oGbzHmzhejqfnnB1kPLxLvnazOIkYJ9VU2q/lKqWRWLjzxv6pxRct8FV3eDv5+ab3zUY11Pqa6ZsuZSjF7cmLOqyOXzM53j9Ssukrd6RzfW7TszrMc2y86d7RwfLtmwBUG5qcuqbXdzr7wdsxnd7q7MrBfc2kGbS3SpuV5kRG7u/BBCCCEkVXDxQwghhJBUwcUPIYQQQlLFjpWeQz7Xgko/0ZUdW4VozztbjQ0Vgj7rLjlR
1qkvrI7+eZWB+iuXrSvgQjXef1rZH+SzV7elAYCRoi0bUTYBxsYHgNSjvUPI2ibMhd6um1t1pV53Yfk3Mzv302u6CWCtfZ+Sb3h9nrtpUDZNjYLVB2tLCZ/aY22tdxZmbV81Betqm1mOOm3tXl4btilNtBlRzrnS6/Qq0rC2EIkkhF3XZT7sQ1BGP0UX7iDbNgTr918dSektdHeqOnur6aE4PrzLuraRKzr7Ep3eQrtRA8D+oTj+bnY2PzeMKxf2uegaK89bl11RNgKNKZuOAcVoQ+LrrO11fJ2NfYnr89r1Xc8ZgO2fw37e67uPu5WlmwpMipvb91ubOy2Hmht/V5RL+XmXpsLaOFnX5iPHv6FzXJo+aeupxkSjGH+3krGuy5eXeo85Hcpi2NlWahvQSWcTVqlH1+nsinXphkkb4d5FmgFkce+6BaIsk1JCeRtJaDu3knWDl4JK2VSy46uyf7xznM3bd1NG2U3lxm1oiVCL/SMzF93lM2VrN9RUcoZ79+k6m/Pg7HVcyhZjh+nsovTv/LtWz+tdYVe2IFvu/BBCCCEkVXDxQwghhJBU0T9fvwS3YKnFbdZC2boarzfidl/GaxbUxuCYiz57Rblyzq/3jmaqt3gPugjSR0djXbybpamHd4HWmc1r1g3RuBUXtha50lPoZ4roLeBVW0m319vpswmqrOD8vfX2vY/YO6bCXefOPmHv98xjnePsdIxKmrvlNea8lWzcGl4PtlsX1HZsIWvVkpKQnR0J0aw1PpC31ix5tcUuixaArV/BVVZrhrSaGQCGlOq37juJwstzSl1nPFi349zzMdtz/czXOsfeZTczFLfN8y4L9JhSa/o615rxXJ+5fUxnHs9ZNYoe/16zpZ8nv2ozZTeHbHT5fpMQJB5HRu1c6l30NTpKtVdT6s+PPm/dly8sxbGzb8hmzh7WkZXVFFxz86XWMB4ZtbLU/bHs1JQ6rECxZk0bMqvR3btbZRTv0RXaYhdUXb3oEqWui3eD1yFFnPt3RoUfyDWtzDNKJZYbt2NPq8v0+AIAaFd7FcIgrNl2FxXOpJl3Ufb1+86ZHtSz8d5J82XOBwHRqlWxbWRG8A7kyp0fQgghhKQKLn4IIYQQkiq4+CGEEEJIqtixzU+i/tK7wCn34kzdujIPKffljFh9X065xGXE6j21jt5nBtaZj7XL+uERq//V7pPjPoWFCq8vVasHz6iMuyFnbZG0jtnbkxhXvQQdZbYr+UHr3H6ai2QAlHsYFGjrjqxz888oTWvD2fXoj/7SWi3vbZoy67F9TQZ2AJmRqDuWitJTOxscHTJh3aVe0Q7BBe/Wn+ltq2bu4csasc9lu9K5xOdbc3YYm67S/Tb9MSkRHNoluezsc44Mx/7r7YEaCfLMrC/G440VUybzarwszpiy+uXoRlt3Gak1oabsDpatu212ONrdHBw6aMr0M/hwC/p5/LNOlmK/9uEJllUsg2LRuh0H938/SJKltoMZKdp5tpSL9fT2Tjp1iU8No4dj1Y2d80txPM4620rtpq7tqXwYFP25K1SAws96xXqcg7WND2BDVpiUFbA2MiHv7C7VeO9lj7dbsjSu7l1u3Mrmx4WOkZHpznG2bMMdZKtxPOj0IwDsXOdshYzNj3qnydC4Oa1RinNws2xtwBbU7XxIjYZ6j3jzQZ3iyvcPUXY+wf1QS30nMuPODyGEEEJSBRc/hBBCCEkVA0lr613zgoqerKMjA0ChFre0C96lXLn7TbuErienlTuji3RqVBQ6c27duznHfTpZdJmA/bagwqi6klzs3POYbc587+zivbK6XwuUQgwdcHi49zPkN5xr6nqUc2beRQ2dUaqQuUu2bCWqUDIqm3HebXVPjantXte2oakyA2dc30zK7KzLEiI8d7vBx7JS7hr4+0LVtbxh2167CcuqDduQr0Z1ljTc2KnHftBYsioJLbPgxpzeUs8fuiFer2hlZt1yXSRxpXYu5+ZN2RG1tR/yLku9cs31
EZC1Guyci0isI66Xxmw9e+eGHwxTuVi3Zeeu32jGZ9Cu7YA1DViv+0je8bjm5r0VFUE769TVi+txftPRs70L/gEVRdyrPLS7fMVnZ1d9rGtMq2zjXa7uCXOynlu9Kref6q6toOsiCZGNUXAu5XltWmHVetoVvfuGvUN3iDLz0O3X9GpD/Tv3fhs22dq9WULvamk1szhZahn5/gfE+/mSrcjyGpiZCSGEEEJ2Dy5+CCGEEJIquPghhBBCSKrom82P1rGJ17kmucFrFztvI6Pd4p3beGY1QavnbYeuVg84F0lfZ1XPLruQTG+XdX3NkHMu0AkpQHZb37xTRNlbFIJ1cTZpP6rO/VnJMqwuoheZUZsuIDt1SBUqHXnehXxX4Qiaro9tuW29TLROPsmOwH2+lv+i8ClXjB2c6+dN5dad2ej9VJlRmyFaVHiC0EhQ9ufU/bydoPrccPZdZg5xtkimHu5zHqvq2J27Hvv1sfIQetE7P/nu0FShNE6UbViI+WJsMx9ioZSLz+TtnbwNkEanzPCu4RodUmS4kHVlmZ5lBcQ5I8mup8uuLikMhS5LsunbY0z/9HOPfnd0vRer6tj1SDUevJ0eqrGPN7xt5UIMHyHFOEdo2zsAxmYv4+6t55LpA8dNWahEt/jgwkVgPcrLp4up6xQa/nkS3sNb4VqepwkhhBBC+g4XP4QQQghJFX3bE0x0x9YZapNUC3kbyTIkuJsnZdnuCybjbu86b/ka7rPfQb6W3NlNXVw768y9XS756ty6y+prol1bLYnFt7X6nSgVaS3rI2urbeIktWsCSeqxncrrWpCrUUl79ZIem05Nq2XWcC6vosN5+7Got+m9S636bPrPdtQTZj5JcHP26s+E+wXnTtyL3XZtT8JnLNfxdkcLbmzq37nriFGVOJOBBDnrPhAy+ryN3udt+AzsKhqzH6e6H+2wfySx16YGZlz6Mv0MJRvqARtRfdv0c/CGC9miy0rx3Nz+o6asqcJJ6Cj7zTVvvtDz8tZEoVjufaIjaeyZcATbeJ9uBe78EEIIISRVcPFDCCGEkFTBxQ8hhBBCUsWu+AEafZzXUeoMvN52IEnBnpheoM9sx40uwd5Bf/a2F9csrp5JLZ2kw27u9Hm32EN92HxNP1b4SbY714Jdz1bpsvcwNjK93c27bL/Mh122zUsIGbGjazj22hZkqyTZoSX2Vy+TBJtMe8OEPpCAmRf8NZJkmVSXPtjxXUtsx6bQhATwNpLKjbzLJlPbc2W8fdXVU19oO0vAhYdxNJLGVDbBLV3fz/dNZY/m02a9WLjzQwghhJBUwcUPIYQQQlKFhLD1jUERuQzg9OCqQ16AEyGEff24EGW55/RNlgDleQ3AsfnSgbJ8aXFVeW5r8UMIIYQQcr1DtRchhBBCUgUXP4QQQghJFdfN4kdEDorIB0XkaRH5goh8RERuFZGTIvLYgO5ZFJH/IiJfF5E/F5GTg7hP2tgjWb5RRB4WkbqIfM8g7pFW9kiePywiXxGRL4nI/xKRE4O4T9rYI1m+V0S+LCJfFJFPicgdg7hPGtkLeap7f7eIBBF5YJD32SnXxeJHRATA7wP4RAjhphDCKwD8KIADA771DwCYCyHcDODfAviXA77fS549lOUZAO8B8NsDvk+q2EN5PgLggRDC3QB+F8C/GvD9XvLsoSx/O4Tw8hDCvWjJ8WcHfL9UsIfyhIiMAPjfAfz5oO+1U66LxQ+ANwOohRB+afOLEMKjIYRP6pPaq9lPtv/Cf1hEXtv+/pCIPNj+y+IxEXmDiGRF5Nfbn78sIu+7yn3fAeA32se/C+At7Q5Fds6eyDKEcCqE8CUkx2kk22ev5PnxEMJmFsfPAjjqzyHbZq9kuag+DuH6iU14rbNX700A+OdobRas9yjfc3YlwnMfuAvAF7Zw3iUA3xxCWBeRWwB8AMADAL4PwB+HEH5SRLIAKgDuBXAkhHAXAIjI+FWudwTAWQAIIdRFZAHAFICZF/c4qWav
ZEkGw7Ugzx8A8Ec7qz5R7JksReTvAPhhAAUA3/Qin4O02BN5isj9AI6FEP5QRN7flycZANfL4mer5AH8vIjcC6AB4Nb29w8B+FURyQP4UAjhiyLyDIAbReTfA/hDAB/diwqTnlCWLy0GIk8R+X60JupvHGTliaHvsgwh/AKAXxCR7wPw4wD++oCfgUT6Jk9p5cv5WbRMDK5prhe11+MAXrGF894H4CKAe9CaEAsAEEJ4EMAbAZwH8Osi8u4Qwlz7vE8AeC+AX7nK9c4DOAYAIpIDMAbgyot5ELJnsiSDYc/kKSJvBfBPALw9hFB9cY9BcG2MzQ8C+M4d1J10sxfyHEFrx+kTInIKwKsBfFiuQaPn62Xx8zEARRH5W5tfiMjdIvIGd94YgAuhlSH1XWinRpWWJ8jFEMIvoyWs+0VkGkAmhPB7aP2lcf9V7vthxL9AvgfAxwKjQr5Y9kqWZDDsiTxF5D4A/wGthc+lATxXGtkrWd6iPn4bgKf6+ExpZtflGUJYCCFMhxBOhhBOomWP9/YQwucH84g757pQe4UQgoh8F4CfE5F/jJYR1SkA/8Cd+osAfk9E3g3gfwBYaX//JgDvF5EagGUA70bLnufXJKa1/tGr3Po/AvhNEfk6gFkA7+zXM6WVvZKliHwDWp4PEwC+Q0T+aQjhzj4+WirZw7H5rwEMA/gdafkgnAkhvL1Pj5VK9lCWf7e9i1cDMAeqvPrCHsrzuoDpLQghhBCSKq4XtRchhBBCSF/g4ocQQgghqYKLH0IIIYSkCi5+CCGEEJIquPghhBBCSKrg4ocQQgghqYKLH0IIIYSkCi5+CCGEEJIquPghhBBCSKrg4ocQQgghqYKLH0IIIYSkCi5+CCGEEJIquPghhBBCSKrg4ocQQgghqYKLH0IIIYSkCi5+CCGEEJIquPghhBBCSKrg4ocQQgghqYKLH0IIIYSkCi5+CCGEEJIquPghhBBCSKrg4ocQQgghqYKLH0IIIYSkCi5+CCGEEJIquPghhBBCSKrg4ocQQgghqYKLH0IIIYSkCi5+CCGEEJIquPghhBBCSKrg4ocQQgghqYKLH0IIIYSkCi5+CCGEEJIquPghhBBCSKrg4ocQQgghqYKLH0IIIYSkCi5+CCGEEJIquPghhBBCSKrg4ocQQgghqYKLH0IIIYSkitx2Tp6eng4njh8fVF3IC3D6zBnMzMxIP641PT0djl/HskxqhLBrtdg5Z/ooS+D6l+f1Tj/leb3IcicPm7axyXfm3vPwI4/MhBD2+e+3tfg5cfw4Pv3pT/evVmRbvO51r+vbtY5fJ7JsqNlSz0bZhKnpephg+ylL4PqR50uVNI7Nl+rip5+y5Dtz7ylXKqev9v22Fj+7Tmj2LmvWO4fizguitHn6OOF6/hrmehnXTLJFbeFWz3uJoBcqjaad5mrqc7Vhy9Zqvdu+kI1tqA4xXsya8/RiaMNd/7nl2FeGC1YmpVz8Yc39rpxX9xY71W80Yp2H8umSM/ACLz49lvy46jU2t3MNjR+bfeB6eEFfjR1vVSS1tfqcOM9ukYy/fjN+DlknSy3bHfaj61WWfSFBlontmfS97g/qHdx162yh93W2eu8Bk75ZmxBCCCGphosfQgghhKQKLn4IIYQQkir23uZnq/rmhtMvapufZsMUiSlTv/PX0Dh9s7bzEa+TVGXB2YKYMq/DTtJTX6P2Qd6OQJvFeH26trWZXbNtvaFsfpw5EB59fqnn/VdrUbbza7XO8dJ6b1l6eyP9eWF1w5RlM/EJ7zk2bsqOjJY6xweGrA57shxlO+xsfmRjtXPcLFR61nMv8PLcqk1Ekk1cFwm2AL3sdxJt7pLGRsPKM3EcbbFsp220G/TFridBPn6elUYcc/53veZZqdfMeYk2W6rdQy7f87SQK9nPWXWus/vS864k2GteS3LdMduw6zGyTXrX6vdp0vU95p1ZtT/T78kEeSW+F/v8jrw237iEEEIIIQOCix9CCCGEpIpd
UXstK1fmkcayLdRbZY2E7VK3vS31DXVst9jMFrb6nXdZN1unHnWNZnnMFFVDdLNe927bdeW66R5nSqlK8m7/Wnt7F2RvXQF11fzWcF2pkNyjY109+4ZTPemP1bp9vuFC9qrnAcBEOcpIu5t7tdflxfXO8eqGVYM+fW6hc7w4u2bKiur6/pq3Hx7tHN84adVX+X3DneNG0wr6yEg8d2bN1mW6bF309xo9Prrcxrfo1qrdlVufE87V9w4Jigd978SL2LGRtL2OjFJ5JLnIJ4y3a0oltsVQIIBTeXSVxf4r9XVbVu9dZq5Ti/2oubpoz9NqsIzr/0ommcqoKZKsmhcKQ/bWak4OTu3cyEQVdVc8sCTX/QGETegLiSoqbw6iTUXcC6iZ0Ad6qS39ezdpbGuVonN1D0ruIW9VmKLO9e/kkFPX2WnImR5w54cQQgghqYKLH0IIIYSkCi5+CCGEEJIq+qfkTHCxGwnKPqdLD6lsIpx+MVNd6XlNjXh31/X4u+ZqdKPODI/b84Ym462LI/YSuahjXtmw985memv668oQRrtRA8DzKzV/eofJktKFZ+2atNDcfL7BWBh02VTotnZ6VVE2FS5TBHJKr5txIQCaQdsKWb3/vkrU85adMVRRtYW2N6o1e9tanV+0/eGpK7E/LFStzlrbEWnbIwAYLsbhMVmyumjtPl/M2d+tKAOufc7GZzO1x17Ziiy5VCJlpVPPhd52AF1hIkJv2wIzHutuTNeUzZU+b8PZ7WkKRftZh5Pwdnw5dW7O2R1od+ku25PeISqMu7QfD+7/frPlNCLus9Sc7Y6fd01hgq2G/lyZsGX52NZN1bbeDlJTckY4mXVlH1S1IS+07Ym+V+sL9awbK6Yor5614eq81oj332hYWZZzezg2d2i/1SVXNaa6+oAae95+KyxHu8jm0rw6nrNVWV9FL0SFKsiMjJuyzEiUQ9a9h0Mp2k+GfNmWKfugrrIXaQ/EnR9CCCGEpAoufgghhBCSKgai9vLucGZbchuRYpulkd6FejvWXTOTj1t4UlEuke4Sxv3ObZFr1cUV5668Vo+fZ1bttqOOSHzFRRPWKpyTE9Z1upSN23sN9zyFwmbdBrO57ttFR7TOuO3RrKpnrma3QHUbFoq2a/lM7uaaSj2Y37ChEESpQ/RWrY/4Oq62RPcdsG1765Takq/3rkfeqSm19rHhum1DqfG8VPx19hpdm0rO/r2jq1p300E2pyPlJqjEvMu6inCdXZs3RfVL5zrHjSvPd461err1hRpzTkUlpSjfzJB1j9aq7a6t96H4uStisFKrhHyld5nfTvfZq/tM19jULskJqq0ut/Qtzrvd7RLH1VywZWfm4ti8oFQl5xetCvP8fFS3bNTtXHpUhZC4/9C4KTs0HNvWq6Qryr29KzSIaiM/7SyqMBglZ15QlFbZru0IJEVjTgjlomUrNVem5uSwPG9vt6HmT1eV5sKVznFj7lLnuDpzxZy3fiWqKdcuW5WYpjg+bD4PHdnXOS5MTZuy7NhU59iP2exoLGsOT5myUI7nerWo78dXgzs/hBBCCEkVXPwQQgghJFVw8UMIIYSQVNE3m5+GqNQN1QVTpvXUXmeudZve7ibJbbWW1XYcTuebj3Y+ZWXj4EOd6+zbnkJXXPTImYWoO720bO16Vjbis15xZWuq7EvnbRu94aaoB33FIWvrtJZtaWi3kVd7YGRD78zAuj19CPNiQuh47e4q3uZH67SNe6tNUxGy8XdFZ4NS0vYPVSdzZdfSZe9QjOEOQsHqsJsF5Xa5DbuPTRuEQbrT+rD9a83Yl33Gex2SYMMZSNhwAl2WTZ2jkhublfE4/vKuH2Qb0eZC2xk0162NysZSlFMmb/tOYTReI3iX9ZJKg5CUXTwpbH9SJuucsy0YcEqEJOuxxPQF/hl0ChLnyq9TR4SStaFaU31idtneT9s7rikbyXzC3FlwYSFKan7OZ+y7QbvM19atrdCyst3JuftpM8xcxtZZ
2+6VdzubRVJoAhcSQrulZ1adbc3STOewOXfZFGnX9OD6v3ZFz04fMmW5k7fH373yHZ3j2TUrk8sqXMtw0cry2Ei8fmX5OVvnC0/1rvPK4lWPASCzFu2Fvbiaqh8HF1plK3Myd34IIYQQkiq4+CGEEEJIqtj5xp/bwsvpyJJuSzm7FF3nurK26m3jru3Y6Aa55qICz67ELd5VF7W2mItbYAcq8ThXs+60SfUaGT3QOd4o27IbxqPKo+5UBVrt5d06v3R6Pp63aLf5L8zHzy9/222mrNDe2u+nA7Xs8HrG1TYpaqyPCKxdN5u2XUSrvbxbZ48swl5FpevSfO7rpqz27OOd47VLdgtZu2Tmj5w0ZbkDxzvHmZFJUyYJqgKtimkWrbos1/Yt77czfDPE8Aw+urZmzbn6l9SQ06oEAJhZ1WPMlq0r/UHTqdJ0ZOwplX0bAPYdiO6qE/tv7hxXFuw2eUWryr3KUYXACF4NpeYQW2OnQnWu4DqavFep2uvbueDFZpa+GmZs+v6fEH3dbPW73wWt3nUyWZfYhhtuLtUhPwouhMPNk3Ee1G7jY8XebeKj3qupGqv13mrXharvf7Gs6X53RrnWN5wK/PhYrHOj6eo5KBXmpiySshS4eU+rusKVc6as8fyZznFdhYsAAMmq92TO9lU9n+HYHabsVDa+737rM3Es/ucPf9Wcd/Grn+8cjx+73ZR9/1+5v3P8Q688asqOKC1b1qnJdT21KhwAglaJrVqVmJTjvCtujtiKWQF3fgghhBCSKrj4IYQQQkiq4OKHEEIIIamif0pO7Urp7D2aVZVNtmR1vjqVQtPpXDdU9bzNQVXZHKw4W4WRonLDDVGXmuQyKM4tMKcyDO8bsmG1R6eiArOYkC5gw7ngL83Gdjj72BOmLDSj/tTbZWx6cg4qaUJXCH1dlhCawLsLm/OSbCG8i7y2HXK/a6rw+kHZgHn9fHb+bOe4dvZJU3bhM491jpfO2T4wfkMMu76/YHXkmUq0LckWrI2RrmWXJl/VUwo2VUJBun/fD0SiPYV3Z1+tKZfhZm9tuO93S8p+7byzUdP2Qd6uQpN19kcTyn5uuhKP1+sHzXlza7FPLFTtfJKVWHZszNoC3jgR7Qkmy7ZsrBRlmK8520PVJzNrNgyFebokN/hdQGf07nLz19nZMzZcRl3ZRKy6eWlZ2dOsuTLdXXz4Dy3bKdXWpbMPm/P0eBRnhyJqXA3f8kpTtlCM867vYtrN/vmlqiuL85Lvm4dGon2Tj5AyX219kZAB58Xh51KdnsSHXlCfQ9PNl8U4J3ald1HpITIue7rccE/n+ErxgCl75lK0edPzx8vvtS7xL7/3OzrH95+cMGVvUeFaRl06ktCM49LbT+ZUn2i4/mHS3jR7j7UuV/ctwJ0fQgghhKQKLn4IIYQQkioG49vnInVm1Dadj4jaqMSts2bFbqNp90bvzj63Frd/Z1z2dB15cikft9vGvIudUnXpTLYAEC6eQS9KymXw1sMvM2WTx+OW3pFRqyo5PRO3Fr/+2T8zZee+Gt2zr6zdb8oODrfaLCExen/RkUeTopJ2udPG5+3Ksqvknlm3IQeaKpJyl9pLqRw3gsrkDJdlfHG2c7zwtVOmbOHZqN5sbCSoKfyWvOq3yNu+o9VxXeh28a7KA3CN3mRTDdFwisym+uyjONfVFne96d3gY121+7rHq9K0OiQpwf3TSg38xTNWHfnE2ah6WnVZwssjURavu8Nu3+uq3LFvyJTlVGXGXOTpoNTcXVGbtczcHHJprXXD2m5pv3R/cn1JR1XXYwUAltRcuuIqe24htq9XMers6ZW8VWUcUFnXc42owvFz6epTUcVfnBw3ZfkjN6kHsPXSffXSip3jvzYTI7qfn3cq2fU4r99+yKqFdHZ43283VXcJAap3xmZ/SlKR+nlBq2HHbRb0zNThznHOXbOpwhjUx46YslMLsV3OKTWX52+8Iv5u3xtOmLJ8VYUlqbko+/l4/ZC1
87+ex0PJqmQzK9G9PVu086osKNOUQu9M7eLUm6FHiBRz3xc8gxBCCCHkJQQXP4QQQghJFVz8EEIIISRV9C29hcHpLxvlcXVHZzuhdIGzG1bZqm1+dCZ1ADg1F/WNc87m5/nlqMO+Y1908bzngHXbq0zGa3gbn0sPRpucxWcvmrL8UNRnTr7M6kQn7767c3zgprtN2Q3fEtNWVOtvN2Wf/tTpzvGpOZt5/L4DLXfpHXjz9SQguvB2XVbL1rtZ6tQU3uZA2e7UslY/a1xmnc5XX6fu1uOza9qlOupxD2dtG9VORbuC6py1KRpWev/8kK3X+K3HYtmxW2219p+M93b2aFq37u16dEiFLpupxguHut8Jgmhf423D1pSre9b9uZNXdjC1pu0J46VoA+VTZjSHemdNHlO/q+Tc79Txc0tx3I6V7FQ0ORzH2OeetiHvS872RFNV/stV1xBV5cNcdQ1R1vZqhQRbNtfnN93pc7v1Z6S6v08Ts6SqvdGwZYvV2C6zazash7bzWVi3ZctqTi66NltXbV3OxZAOh47YcVRZnu8c6/ARAJA9HG1+1oet/dbsfOwfF7w7+3IsW9uw8hqrKDtPZ6umn2Eob59nuW0LNTDbyqQwCd7WsRjbqTlq20XPrd6G76JK+3TxOWvXc3Yhvu9GXbu84Xicz8aee6RzXP/S07bOyp43e8Kmt9g4GFNmaPsiwIaqKbt2nyrH9/LY8D5TltVhZqr2eXxqm+3CnR9CCCGEpAoufgghhBCSKgai9upyFS3GbcimU3ksN+L6y0dxrqktvTm3Vfvlc/Od48dVtnTPW+6JW2raNRMAblAZmhsum+yVr5zvHF942GbODapeQ1+w2ainvxCjmU7d+Xlb9qY3d47/w3e9zZT91HRUGd0yZV10dzeGLOwWbIKc9dYsYLNDrzmXcq01qfvItArvajurQhpMVWK/ys2eMuetXorutWM3WxfPfCVuE+cm7bZq7qhytT10iylbGo4upVr9BgBrC7FePprpQaUmMZnEAawVWiq45gBc3jdddKsuRK2Ocpt1Sk6tjhwr2ufQKiQRO6ZLyh942D3/RDaqJLQba+uG8Zonx6KqJEzZLew3n4zhJB5ykWSfVCEjyk4FdlBF8M37DOJJQcd1n/Tzl4q2KzV7v1xbfTCo6OtJJEW13Wj2DmmQd77c2p291rDPpyPYa3MCAPifT8Qxd0LNX99z1zFz3olXxwzffjZZQ5yDZ1bs2J93KjjNkbHSVY8BG0V8LSFLvQ/tcHSk9bu+u7pv4udSHZE76yIbT0ZzijNLth1WNqIcfJiXUyqjfdO5f2uV3z0HbRb0sUtf7hzXnokR8TMT+22db3pF5/BizkZq/vOn5zvHnz87b8o26nH+PDppo96/6sh45/j4mH1HT5WiOk4SQlB0RTvfAtz5IYQQQkiq4OKHEEIIIamCix9CCCGEpIqd2/wk2CyEgrVZCcq9ve5sBzaUnU82IRa+twfSdj5nvmptcnKFeL/TR6Kbs88wnVEh7aszs6YsX471nL7NZnXXlCasvlm7wVfnrcv1ype/0DmubFjX/X/y6mgDdLFu7R/KmZbutp8rVcEO7RSU3jq4jOWL67316brpq+7O2ibl9IK1K9D691tUVu7mrA0/ULklul1m91mbH1HZjbXbMmBDLSzmrA3TM7NRRuddigVtD7PfuX3vPxRDtGcXbT0b+VZ/7Lc3bUBMseDd2XVdmy7tgTKj6hof2hbEmfVguBDLRhZtmIjGl/60czz7+FdM2dqlGAZgYymGKxg6aO0H9r/9uzrHD5x8A3qRdw+rM8WPFW2ZtmnqsuvQtiG+j6iwBl1hBOoDdo/2aDsHl6Ijo0IVeHunzRQ5AJDL+Iz28To+9YXO8u5TTDz8eLT5+Zyan/3Yv/dwnIN9ioyjo7Ge5byts7bRnKo4mxh1i5ITpqh+HIK3f4vHk2Vbl2zbImm37LeM/Jw9yzll53Pape/QdpGL69ZOakmVjTh3
9nsORjkcFuc2vh7HYu6OV3eO1/fZsAXPqPADf/6stef78CPR/vXKrA1FMj4e58SpYft+qyVla8/GZ5AkO0lf5u2DrgJ3fgghhBCSKrj4IYQQQkiq6JvaS29P+S2nRsIaa7TYO8uujl6pt+4BoKa29woVu/07NhXVMW++Lbo2Hx629ao//tV4r0W7TTd5e3S1PX7EZtXNTcbPGZdptqm2D0PdbhNnSlEdmBmx7rtQLsH7J6yraL+jAXeR4ILZtZ2oPvvM0RuNKJOk7eaGu9+ccmf3rpu3Kbf/4obKKHzgBluvW6Nb+mrOquN0QFGfiX61GfvfZedqq6OKP+8izK6pbX6vrd2Q6EZazto+V8m2GmL7jpnJZABsaqK8y/p6fWc6mYJxZ7eyLq1e7hxrNRcAnPrdP+ocP/VHXzdlp1fidv6CasMHlHoaAF5zPEa1Hb/tTabskHJn96q6snLNLm0j7LJW4fpQHTpqg2sGbE4pA3OPTppnHUOIfXQ42P4qy3HsZJYum7KwHlUg40NWDrXpm9CLi89GNUdNXeN/Tdnx96Uz853j77zvsCl74HAcK0PVeVjUXODm2UTX5gT1yHo2ql+KGTcuBhB+Iun6ei5aDfZ5ZpU6y7us6z7fcGUFpQY+MmrbbP+Q6jtO1dQ8EMN8LBejCvrrMzZz+yMXYj/671+8YMrOKTkPjVrV1p1HY7969XH77juhQhVMFp1pyspC59jP3d6EwbAFWXLnhxBCCCGpgosfQgghhKSKbam9AoDNHfSuHyZEBdZbwpl16wGFelQtFGvWqr2kolzec9B64lx6IEYNXXCqkruORO+Md9wWVVS5Jz5hzpt/OHpfNWt2S23kWIxsWbr71aZMpqNKzEfmzK/HbbqwtoyejFhVWlAJ43REWQBAtncSyb6QtEWYcWW+bgrtYbLmVC1apbngPBQuqMixdec2oz2OdFLclUm7HX95NV5z3fWHg2q7dyjvVLKN2FfXnWpVq1rXnLfhUCFexyfl3FDXLCWoDQeFV8Fo9ZUv083tvS31FnqxbtXC2cW45b30tFVtXXg4ln3JeaqcXbOy3+T+hm37/P6oHln2Km8d+X3DRr/V3l/lnPcqjJ+9liqv1Ch+tBnVglMzbCY63S1nL5PY1KvV1dyTWVswZY1zT3WO188+acqWz0c1mE/8O/a2d3aOs2K9XrOFqEJavRJVYBeenTPnTSq1xm3T1hN4ePFcrOOjHzdlyEaZiFN7ZcZiXbLOhMAkBS3a+xVHlEou7JrUWvixr1RPVVeXci4+e37IJWBViVy9yl3PWZOl3h5ya8VxU6bnAe3h573zhtW85+fEkvLIe5t6PwPA99x1sHO8GUl7E/3OlPlFUyYNNb69ClgnHM5t/x3JnR9CCCGEpAoufgghhBCSKrj4IYQQQkiqGEiEZ2/zo/XP2lYAAJrzMUpoc866YOYmnu0cv/rwbabsgTdGeyDnIY/CYszIXvvoL3WOz33qC+a82mq0R5i4xbqXF05E179w9E5T1hhWmcG9+10puvTJkLWTMOd5Nz2ty3ftFwbs6u4130aN3LDPp+0MvNuvti1Zt+pgo16fW7N2GpeW4zVHnJu21mmvILpPerd0bVM0UXJ66qzKTl73NktRV+wjjWobphMqQilg7XwOjVhZahOVzIZ1FW02r27z8mIJiHp7b9czqgS17myqCupcHxVYX0dq7jkWY0R0cVGW990R7dle6+q5fDG6RBeVO+w9P/h6e+9Xv6NzrO25ABsOwUd41o/gTIWM7Y5/Vm1a5ueTsm4IZ5sh7v++4+dZPRf4eUH17bBg59L6lTjv+sjzTWVDUjho3dR1dGtvX1IaimNifSieNzxux8Nb74xhC24cty7Q+NoTncPVU8/0rFdx3GYhzx+Kc7e3B4Ky+elqP2Wz6OcCH7G+73hXd2XPMlyy916pxXN9xPZxNb9pGxzAzq1+PptR42itbi86qkK4T+Xj5L0v2HGfPRTfb586YGVSUOEq
/tKdB03ZDdloy5O9ZPum1G1YBo22qU10bd/BO5I7P4QQQghJFVz8EEIIISRVbEvtJQBy2Nxe6r0d23DJS5NWWGEtboOvPvVVU7Z8/sF4Db+9nbfucpr6etxGWz4b1Wo152Y7ejxuz1eO2W263ImYKHNDq7kALKqQr3kXabRYUtvEassYcOost00nG0pF5lRNg3Z1T9qyl2aj52fvTjuinn3DqVeWqnE71idIfHYmhgQ4Omm3f/V1dNLFnNuDPzQc+0N52apWZUW50k+dNGXLSn3mowVPV2K7e5fPSZURdDjvtrPVZZquDwyKJoBqu602o0hvol3WvTu7Vhf6six6byVnKipZ5ckbTdnxUmy36bvn7TVVqIGhG2LIiPxr32HOe6Yet9S1ay8ATJajrH1SWR3dupLvrRLLBafOVerI0PTqCVXmXWoHHbrAb+cnzSFKfRBc4mQdid7La+SVyi35iDUveCpEl/I/ePQJU7a6EMf/4Zfd3Dn+/m+5xZz3LTfHawxn3Nym5k9xYTWyqh9lh22oE1FyyBStSrquxlzIWTWb6HAqW0h+2Rd6qGSaKlSAi8pgkvJ6Fa3Gz1la1eVDisyuRRXWURf9eZ+az/IzUf0YXP/OlGJk/ZephLUAcIuKxn/DqEsae/Z057i5at3Ztct/pmLlrPtH0lgQ984MW3hncueHEEIIIamCix9CCCGEpAoufgghhBCSKrav9Oyl49b6OKei1HYPjYoNRZ4dO9I5Lq9aF8zFU9F249Ij50yZdpn19kBZ5dqbUyG+x2+cNOdN3Rn1l8U7XmnK6hPRHmGhanWN2g5lQ6x9hUpa3eVyrLNMZ8XqRMvKhTb4lBIDICC6uHdplLUsnWu21KJtUmbJ2u4UVNlwxdpQPbcUn+/CorVHeOT0fOd4YdW6wb98f9QB3zChsv+WXfstnI312lgxZfXxaNNwftk+z/PKzb7mfJwnlG3JiPPrH1auoT6Ngr5Ol83Bpnttn8MXZEw9etsIeNsCUfYE2aZLl6BsurrS0ujzJvabzwWVlqB4g7Uhye6L471x/N7O8ddXrDy1vcJE2U5TIyNRFj4LdHbpYvzgPWgT2lxnchfnUqttM3bDTsSMzYQ6exdhY7fkzs0ejPMZDtjUMPPDUSafOmPtMf7gy093jr/08HOmTKe3eNtrY+iRt99mbSSPVmJtMlU7NrWbeunocVtWjGWZ4XFTlhmJnxsVO683h6KNUVc7qP4RCkPYFTbfmUlj3qUNGhF1bq13SqGic5Gv5ON8s7Rh7WL1e8vPn0PrMXSFnvOb4zYEzMxCnJ/13AwAt03H/pCdO2vKQleIEcVW33d+7CWNxS3Mr9z5IYQQQkiq4OKHEEIIIali+3u4m9tJPlql2mby6p4FteM1v27VGtVG3Cq77VXfacoOleK2ZGHkT03Z/FNxW62x4bLLTsXtuMr+qGbTmdoBoHTHKzrHYfqEKavmo6tt011fR4pteu87XWaLTOZc775YUi59IevEMgB3WkFUjnRFeDYqzIQIs/MXTVGmFF3WK2rrGQCGlQty3amXVpfi9v1XnBv85/ZHOewfjvKbcqqQ5kjcal8Wm1H43GLsc1+fXTZlWn2lXagBq+oac1GjC0p+RdfhRbmfBp8jfHOrts8yDYjRm0uuPrq5vft6Vm1xe9VWZjVm5864CM8N7a5at2Na1LjNTln1Z/NQDCFxajW26UUn90PDsd2OuCzQhdUr8foXbR8Mql7eBTpot+CkDNFF50Kuosw2nQvtICI7C9QYdFv7OiSGcduGjRjs1URhOI7HpVE7Ph46F+X+R48/b8q++kxUh5SH7bO/7oFoNvDeV0X1yMHaJXMeFlUfcyrpoEJn5PYfMWW6H8mIVW1peTVGD5iyqppoi+JCzWu578I82wSw1q5P2Ydq1rf2slQqTRMGBYDoKONLtq3H1fw85lTu+vkyV+xYD4txTGE09pXFnFVtTZZje/pI+pMNpSZfvmLKTEiDklU36j6A
gquz6v/BhZXRn3cSgoI7P4QQQghJFVz8EEIIISRVcPFDCCGEkFSxc1d35wJdzUQdbMG57WVUuovVmtWnP3Ul6oBrDeu2d+Sub+scT998vymbvvBU51inyACA4GwQOlV32X9lv8oMn7f3XlMpoX2dtU2HT+MQlAWNzw6ts7MPO9fpUIz3F9d+3ianHyS60yZljtanuRD6zZVob5F3bX3DgTs6x/e6sOh/pkKtX37e6qIffCJmAD45Eduo1rB6Y93Wz87Pm7JF5TatbXwAYKwUdcxjLqP8pLIrqngfcS0jZ1ZgtM+DTn/QJoOYffzSqq3Q/op6LhcGXme2lrqzO1Cyb1atzY+38zF1KSvZjEybsuV8lP2qys7uU1FMKVfcwrKzLbv8bKzGFWujomn68Z6PkhHnIowRZfPj5gJtTzCw7O2eHvOs/uxTz4SG+ly2thpNFWJkyYXuaKrBc8tB+7uD49Fu6r4jNlXLa49FWU6d+WznuHHFppfJqlAIfi7T9h6ZSWu701TpcpolO2cE5Wa/GqxNmLbJLDVtn9bj0aduGAR6XPrxpTF2PLA2QNqWC7B2U405a/OjPzcWnN2NHrPOfiY7Gm2qchOHOsc6PAsAnBxT4VqWbXb27MypznFz3b6TdTocydp7G1d+Z1PXVOEIfKqgkFfjewey5M4PIYQQQlIFFz+EEEIISRU7D1fqtpn0RpY4/+8RdZdx5x63fziqHc65yL9PXYkufkdGx03Z0WOv7xx7dcXwxnznWLvrenWSjja9UrDbqpeX4hZhzam28sqVOHjVljpeqtot6+mKyjzuthMbakM9u0uqki3hXYL1tmTObjc39TbrhWdNWUlt17/55N2mrPr6k53jP3zMqjHOqj7w3x6NEWb3j1qXyGEdybti63WDyhR/w7h1fz6ossFPub5pXExdhmTTLl1u07GNGmKHmNvw7Rtajbmv4u6iVSU+E7hWney037kIrdrN2kd0X1EqZD2OJnyYgWZsexO1GUD98vnOcXNp3pTpqMBwaiG93S7ONbap3Nm9S+1uqS4794dSSSfd26ukc7GvNfO2nzeLSn1QtZPWrdNxfNw0aVV+BSWjoyU7BnIzMct79ZnHYoEPfaDUj+LCD2TGo1q0MWZd3RsqXIazPMBGI35RbdhCrUKVutNJ60jerv1Cn6Oue7oyjetx6U/WbvgukrdWQfv+X78Ux8by+RlTVluJv8uV7Pw5drtVEW9SWJsznzNa1bU8a8oaqi5etaU/d40v1SeCi64eijHUSXCqz7rqxrmuwDIvzDX0liWEEEIIGTxc/BBCCCEkVXDxQwghhJBU0T+bn0ZC1lZ17v6K/Z3Wzz63ZKtzXtkAXViyek9th7PfhV0/MBRd4kamo82Bd7OcW4/64Nklq6de8DYeus7KcqPm7JuWq/Gaw84WKaczaDslr075sBs2Pya9hU9VokOKZ60tRlC2BNkRa8+h9c9h3YZkb1w43Tke2rCyfPtN93SOX7bvZlP28WejHdFTyg2+XLB95WUHom74noNWNzyt3KYnnV2PzlyOpd6uqN4FM+gmS5CXT+WRaQu+3+ELktKVmPq5/mr07y6Vgta/S3nYlInX2+vfDcV+4d1T87XY8fdV4v1GXeiH7BVlr7Dh07NvDW/XY0JdODf4JBuuJIL7vx+Y9Ba+Lkn1VH1U20oAQD0Ty4bytg+MFno78OuMIDWfRXv6xs5h7rUxvYxP1aDr0ixYm59VibYnc1Vnn7OqstS7BvYpZTQ69YzPlm7SJSS17SDYxv2Ctk3yKU4Sxp4O51AYtfZb2Xy8TrZsx0Z2LNpXyVoMWZJVx0B3GAODTmFRsSETTCoKH0pC29vl7bjU88eas701cvb2WkxvQQghhBBi4eKHEEIIIali52ovR1eUYIXercpt2KzaE8qNb2jSZgLXmbvn/ZaowiVIx4JWZ63F44bbO60qX7nlDavmyqj93rzbYjXqK7cFWcnHz96dvaSiBOso0YCNKF1z
W3bDe7lEdZmP9RZ2w8k8fzhmefbRRXX059r5p02ZKNfl20fGTdkdx2IG6nCz2kr1/a0RI4pKzbpgynpvlazJ5u225M327HZcjtVnkb4NsS2TGIU4459DRbzNuqjAeru9aCNqS0W1qXf1T2i3EaXe0hnmZd1urwcVRkFcpufMkFJrOjWedat2EZ5VNukul1qtIvNtpOW712Eo1P29Cinjo0H3oJDpPZ85zQIurURzgM+cWzBlnz8d3aAbSr27b9S27cHR2FemXGDtUjbO/xlXrwNDUSYjCSYEJRd9Xfer7lAd+Z5le0nI2z6unyhkXOR1NR4ybr7UkfWzUzbKso6k7EMOZPfFMAM6cn9XFP/VaHqgI3cDVnXmQy1oN/9myanEir3H5ZKJceBMCPS9u+a1F+bakT4hhBBCyC7AxQ8hhBBCUgUXP4QQQghJFX0zSDC2E65MRx/POhdMnbHWh9KeLEeX2WKut42Mi27e5V68if9aX7OSt65/2gW/7O6tL+Ozuutw8FnpbX3h62hc3RPcOPuFDqHfhc587N1blZ2BD1Ou3RQzBWsjkhuONh3NRWsP1JiJ7pMbF06ZsuZ6zBYtWRW23skro1wrpTzUsyy3/6gpC0muw8rWoysku26XBH1zQZxr+YD+3kiSp7EfcM+4hmgD0fBu+c7WRqOz3Hs7kQ11nYz3QNWVUXXx80JT19Nlos9MKzdg77KucelXoPpMYgoLb0cE5XLt7+HHRx8ISMg4nu3tAq3tWXRaHwAo6L6cs3KtFaMN1YabTFeVXeSZuTVT9pkvx1Q0F5+NaQ+GJqxNx9Sh+PmWozb0we2H472PjVk7kZFClJFPYZRXNj/FrG0r3Vvy7lm1nL1LfrPdB/sdhmIrdMlb98+CNZTSEpJDNqyH/l3OpcXosk3U11H2YmF1sed5OnWN7DtuyuoqHUnT2QjqUAv+nbm8oexdXYgZ/V70NrTlsvocnE0wXd0JIYQQQixc/BBCCCEkVQzED9dvayVF40xyPcyFuAWWcWV6C0y7rAN2q0y7t/tts7Fi/Dyc0BJ+S1JrB8rud/pRfSZic03ndj+diVuU3t1vt9E169pa32Ik3KaPSqq2n6Vst2rzkwc6xzkXGVpHig7aVdNnDdYuzl4Volycmz66aE5n83bCNKotX5agGtxhtOBBsdaMnTIDqwqaN2EhertKr9bstnI+QdWno55nElS/JTUei35LW42B0ZFjpmxEqaQzw5dNmVVf2efRW/teLqasYaO9h0xRHe9+6AKDzwyuEJXFPrNh3ZxFRTr22cVFzUVZF5F7v4rC/R23HzBldx2M6qzzi3H+WveZ1BX7huy9b5uOY7OUs2Nau7B7V/cxFTJB6i6idEIbhWzCmN5t9P1dJOqgVF3VvFUJh9Jk53jZvWRq6uWUL9uxpzWa3oxEqxErEvt/ZsOqOrU667wtwkWVJSGz4vtAPHnV1bmp6uyzImjTkUPDdu7KNXuH29gKez8zE0IIIYTsIlz8EEIIISRVcPFDCCGEkFQxEKXnQtXq9KZUVu1c1brRidOvmzLlqjfiXOe07rvhPFqhsq5rGxyp2tQaxmew5l0NlU7ZXT2jXWGd+2CtEPWzyxu9dd91Zw9UrFSufuK1jNez5rQrcW83XG2bAABoRjsDGXZ2JzpVRLO3EZXW5XtX1ZBku2Pc+p3rvk55sI3s2nvhKptEWWJ7b7ghr+3zdGoWAJhdi2PziRlrQ/L05fh5ed2O4VXV75ed66p2p88qO4Nywd77xv1xHJ2YsC7Qt0zFueDg8IQpK6kBP1L27tHxuCsdj7b5SbAZuZZoOvscMxe559M2P7Ji3eDl4rOd42zV2s9MDUW7nmk3Pu5UtnXhSJRDc8imKdJ2IlJ3qWZEjRYfF0ERxI4qWVf19GNa2Y/5LOF7buejMGmf/HtE7UucWrBtpm2qZlbt2FtU422hasvW1Utnw72AltTvRkqxjaYqdiw0Qnwn6zkAAGaXY9lYxb6Uj07G
91vJ2fcdGYkyumXKjnVt9+XDhqyFeI8yer9re8GdH0IIIYSkCi5+CCGEEJIqBrIHuL/ioqdqNz6vDlHHXgWmXRizNesCbaLDJkSuTMK4lPtrJLlLKve74CLT6m39obxdWxYz6mmvARfoHbFVdY9/vNBbLaVVEF0qowRVl0G7XvvfbDHj705d1q81NZdHP1fRuQWfLMW2qrswAOV87Od5l/F9pBCvubRhVVuXluL296VFez+9vb7cY6sdAE4rNdvCqt321+68M6t2nOp63TZtt9CT4qaPFuN1/KSoI2nsQvD1RMx86cefypbdcJGNtQkBnFpKxg52jjN+DtZzt88ar1XSeq5evGBOyySonXU28y4Xdf18PlKzVrG73yWF6uh13l6g+9JM3aqJlpU6y4eKEWXWcXLcvmu1qcXcWu9nr7lwNHPr+n6x3b3phlY5nxy342uyrCKMuxAX2rXeZz6YVOrp0WDni1ouzkkbTmBWq739pcx1+gYmhBBCCNkZXPwQQgghJFVw8UMIIYSQVLE7fn/Z3vpZkznb65STULpcn35C25AYnXWXm2XC2k+XOZuRpBD3ReVWfN3a9Si2oxdP0q9v+R5dtlfmBr3P6/Ubz0tAJi+GpkudolO1eFkPK5u10qi1udg/pLPB975fw6Vx0R/nVDqNvLNr0DYJ/hraZmC0ZIWdlAlcZ492icAxp9J8TJd9iH1ck3RlAteydXI2c2tSdu+uEAC93ef7gg41kd2ZzV1XqhJ1vNd2PVtlyvVj89m1u5b7esOPjfh+zWesTY4e6xnXp3XIiyKi/U/VpcMx9kc+nMKGSkWUc+959W7Q6XYAG56i2rR2hwWV4qrfYQrS/SYghBBCSOrg4ocQQgghqUJ8dvHEk0UuAzg9uOqQF+BECGFfPy5EWe45fZMlQHleA3BsvnSgLF9aXFWe21r8EEIIIYRc71DtRQghhJBUwcUPIYQQQlLFdbP4EZGDIvJBEXlaRL4gIh8RkVtF5KSIPDage75HRC6LyBfb/35wEPdJG3shy/Z9/4qIfEVEHheR3x7UfdLGHo3Nf6vG5ZMiMj+I+6SNPZLlcRH5uIg8IiJfEpFvHcR90sgeyfOEiPyvtiw/ISJHB3GfF8vuxPl5kYiIAPh9AL8RQnhn+7t7ABwAcHbAt/8vIYS/O+B7pIa9kqWI3ALgRwG8LoQwJyL7B3WvNLFX8gwhvE/V4e8BuG9Q90oLezjP/jiA/xpC+H9E5A4AHwFwcoD3SwV7KM//G8B/CiH8hoh8E4CfBvCuAd5vR1wvOz9vBlALIfzS5hchhEdDCJ/UJ7VXs58UkYfb/17b/v6QiDzY/ivxMRF5g4hkReTX25+/LCLvA9kN9kqWPwTgF0IIc+17XhrgM6aJa2Fsfi+AD/T9ydLHXskyABhtH48BeG5Az5c29kqedwD4WPv44wDeMaDne1FcFzs/AO4C8IUtnHcJwDeHENbbf+l/AMADAL4PwB+HEH5SRLIAKgDuBXAkhHAXAIjIeI9rfreIvBHAkwDeF0IY9E7TS529kuWt7bJPoxUD+idCCP/jRT4L2duxCRE5AeAGxMmW7Jy9kuVPAPhoewdvCMBbX+RzkBZ7Jc9HAfwlAP8OwHcBGBGRqRDClRf5PH3leln8bJU8gJ8XkXsBNNB+4QF4CMCvikgewIdCCF8UkWcA3Cgi/x7AHwL46FWu9wcAPhBCqIrI3wbwGwC+adAPQQD0X5Y5ALcAeBOAowAeFJGXhxDmB/oUZJN+y3OTdwL43RBCI+Ec0l/6LcvvBfDrIYR/IyKvAfCbInJXCIPIp0GuQr/l+Q/b13sPgAcBnG9f95rielF7PQ7gFVs4730ALgK4B62VawEAQggPAngjWkL4dRF5d1v9cQ+ATwB4L4Bf8RcLIVwJIVTbH39li3UgyeyJLAGcA/DhEEIthPAsWjt5t7y4RyHYO3lu8k5Q5dUv9kqWPwDgv7av8RkAJQDTL+ZBCIC9e28+F0L4SyGE
+wD8k/Z38y/2YfrN9bL4+RiAooj8rc0vRORuEXmDO28MwIX2XwzvQjvFZXtr/GII4ZfREtb9IjINIBNC+D20DO7u9zcVkUPq49sBfLWPz5RW9kSWAD6E1q4P2uffCuCZPj5XWtkreUJEXgZgAsBn+vxMaWWvZHkGwFva17gdrcXP5b4+WTrZq/fmtEgn++qPAvjVPj9XX7gu1F4hhCAi3wXg50TkHwNYB3AKwD9wp/4igN8TkXcD+B8AVtrfvwnA+0WkBmAZwLsBHAHwa05Inr8vIm8HUAcwC+A9fXqk1LKHsvxjAN8iIl9Bawv2/deaDvp6ZA/lCbR2fT4YGKa+L+yhLH8EwC+3jWcDgPdQpi+ePZTnmwD8tIgEtNRef6dPj9RXmN6CEEIIIanielF7EUIIIYT0BS5+CCGEEJIquPghhBBCSKrg4ocQQgghqYKLH0IIIYSkCi5+CCGEEJIquPghhBBCSKrg4ocQQgghqeL/B1q1uolrM7hOAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "coef = clf.coef_.copy()\n", "plt.figure(figsize=(10, 5))\n", "scale = np.abs(coef).max()\n", "for i in range(10):\n", " l1_plot = plt.subplot(2, 5, i + 1)\n", " l1_plot.imshow(coef[i].reshape(28, 28), interpolation='nearest',\n", " cmap=plt.cm.RdBu, vmin=-scale, vmax=scale)\n", " l1_plot.set_xticks(())\n", " l1_plot.set_yticks(())\n", " l1_plot.set_xlabel('Class %i' % i)\n", "plt.suptitle('Classification coefficient vectors for...')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "kK8v34AK6xXJ" }, "source": [ "### 1.2 Feed-Forward Neural Network\n", "\n", "The first step is to create the functions that will allow us to implement a feed-forward neural network and manage the training and validation process.\n", "\n", "The MLP class will define the architecture of a feed-forward neural network, with a set of hidden layers (fully connected layers [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)), with a activation function in between them ([relu](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html#torch.nn.functional.relu)), and a [softmax](https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html#torch.nn.functional.log_softmax) in the last layer. Since the dataset poses a multiclass classification problem, the last layer should have a number of neurons equal to the number of classes." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "In9r_o8vvNaz" }, "outputs": [], "source": [ "class MLP(nn.Module):\n", "    def __init__(self, dim_layers):\n", "        super(MLP, self).__init__()\n", "        self.dim_layers = dim_layers\n", "        layer_list = [nn.Linear(dim_layers[l], dim_layers[l+1]) for l in range(len(dim_layers) - 1)]\n", "        self.lin_layers = nn.ModuleList(layer_list)\n", "\n", "    def forward(self, X):\n", "        X = X.view(-1, self.dim_layers[0])\n", "        # apply ReLU after each hidden layer\n", "        for layer in self.lin_layers[:-1]:\n", "            X = F.relu(layer(X))\n", "        # use log-softmax for the output layer\n", "        return F.log_softmax(self.lin_layers[-1](X), dim=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "h6OVD_1xUwWH" }, "source": [ "##### Training and validation function for the MLP and CNN" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "id": "B1eUu01N8wIR" }, "outputs": [], "source": [ "def train_val_model(model, criterion, optimizer, dataloaders, num_epochs=25,\n", "                    scheduler=None, log_interval=None):\n", "    since = time.time()\n", "\n", "    best_model_wts = copy.deepcopy(model.state_dict())\n", "    best_acc = 0.0\n", "\n", "    # init dictionaries to save losses and accuracies of training and validation\n", "    losses, accuracies = dict(train=[], val=[]), dict(train=[], val=[])\n", "\n", "    for epoch in range(num_epochs):\n", "        if log_interval is not None and epoch % log_interval == 0:\n", "            print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n", "            print('-' * 10)\n", "\n", "        # execute a training and validation phase for each epoch\n", "        for phase in ['train', 'val']:\n", "            if phase == 'train':\n", "                model.train()  # set model to train mode\n", "            else:\n", "                model.eval()  # set model to eval mode\n", "\n", "            running_loss = 0.0\n", "            running_corrects = 0\n", "\n", "            # iterate over the data\n", "            nsamples = 0\n", "            for inputs, labels in dataloaders[phase]:\n", "                inputs = inputs.to(device)\n", "                labels = labels.to(device)\n", "                nsamples += inputs.shape[0]\n", "\n", "                # set the parameter gradients to zero\n", "                optimizer.zero_grad()\n", "\n", "                with torch.set_grad_enabled(phase == 'train'):\n", "                    outputs = model(inputs)\n", "                    _, preds = torch.max(outputs, 1)\n", "                    loss = criterion(outputs, labels)\n", "\n", "                    # if in training phase, perform backward prop and optimize\n", "                    if phase == 'train':\n", "                        loss.backward()\n", "                        optimizer.step()\n", "\n", "                # increment loss and correct counts\n", "                running_loss += loss.item() * inputs.size(0)\n", "                running_corrects += torch.sum(preds == labels.data)\n", "\n", "            if scheduler is not None and phase == 'train':\n", "                scheduler.step()\n", "\n", "            epoch_loss = running_loss / nsamples\n", "            epoch_acc = running_corrects.double() / nsamples\n", "\n", "            losses[phase].append(epoch_loss)\n", "            accuracies[phase].append(epoch_acc)\n", "            if log_interval is not None and epoch % log_interval == 0:\n", "                print('{} Loss: {:.4f} Acc: {:.2f}%'.format(\n", "                    phase, epoch_loss, 100 * epoch_acc))\n", "\n", "            # deep copy the best model\n", "            if phase == 'val' and epoch_acc > best_acc:\n", "                best_acc = epoch_acc\n", "                best_model_wts = copy.deepcopy(model.state_dict())\n", "        if log_interval is not None and epoch % log_interval == 0:\n", "            print()\n", "\n", "    time_elapsed = time.time() - since\n", "    print('Training complete in {:.0f}m {:.0f}s'.format(\n", "        time_elapsed // 60, time_elapsed % 60))\n", "    print('Best val Acc: {:.2f}%'.format(100 * best_acc))\n", "\n", "    # load best model weights to return\n", "    model.load_state_dict(best_model_wts)\n", "\n", "    return model, losses, accuracies" ] },
{ "cell_type": "markdown", "metadata": { "id": "0CBE5tRMZEfr" }, "source": [ "We will start by creating a simple network with some hidden layers. In addition to the input, it will have 3 fully connected layers which, in this implementation, are given as the input of the MLP class. We will use the Stochastic Gradient Descent optimizer ([optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)) with a learning rate of 0.01 and a momentum of 0.5. The loss function to be optimized will be the negative log-likelihood ([nn.NLLLoss](https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html)). Training and validation will be managed by the previously defined function \"train_val_model\"." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 981 }, "id": "200WI3xND6_M", "outputId": "79913a00-abf0-4e48-8177-64b49bbd6fac" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0/14\n", "----------\n", "train Loss: 0.7707 Acc: 77.78%\n", "val Loss: 0.2943 Acc: 91.41%\n", "\n", "Epoch 2/14\n", "----------\n", "train Loss: 0.1756 Acc: 94.85%\n", "val Loss: 0.1514 Acc: 95.56%\n", "\n", "Epoch 4/14\n", "----------\n", "train Loss: 0.1097 Acc: 96.83%\n", "val Loss: 0.1115 Acc: 96.57%\n", "\n", "Epoch 6/14\n", "----------\n", "train Loss: 0.0750 Acc: 97.82%\n", "val Loss: 0.0881 Acc: 97.24%\n", "\n", "Epoch 8/14\n", "----------\n", "train Loss: 0.0546 Acc: 98.44%\n", "val Loss: 0.0816 Acc: 97.46%\n", "\n", "Epoch 10/14\n", "----------\n", "train Loss: 0.0406 Acc: 98.84%\n", "val Loss: 0.0704 Acc: 97.84%\n", "\n", "Epoch 12/14\n", "----------\n", "train Loss: 0.0296 Acc: 99.20%\n", "val Loss: 0.0738 Acc: 97.92%\n", "\n", "Epoch 14/14\n", "----------\n", "train Loss: 0.0218 Acc: 99.43%\n", "val Loss: 0.0755 Acc: 97.80%\n", "\n", "Training complete in 2m 28s\n", "Best val Acc: 97.92%\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD6CAYAAACxrrxPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAiqUlEQVR4nO3deZhU5Zn38e/djS07LjSo7CDatIiCLYpGVFwGl4CaRTB51URjFtHEaBJM3ngpzrjFMZlJyCgxBseohJeYSAyGOKgDGjU0KsgiS5ClcWtQQEBlu98/nqpUddFLNV3d1XXO73NdddV2uupull+des5zntvcHRERKXxF+S5ARERyQ4EuIhIRCnQRkYhQoIuIRIQCXUQkIhToIiIRkVWgm9loM1tuZqvMbGItz/c2s+fM7DUzW2Rm5+e+VBERqY81NA/dzIqBFcA5QBUwHxjv7kvTtpkCvObu/2Vm5cAsd+9b3+t27drV+/atdxMREcmwYMGCje5eWttzbbL4+eHAKndfDWBm04CxwNK0bRzonLjdBXi7oRft27cvlZWVWby9iIgkmdnaup7LJtB7AOvT7lcBJ2VscyvwVzO7DugAnN3IGkVEpIlydVB0PDDV3XsC5wOPmNk+r21m15hZpZlVVldX5+itRUQEsgv0DUCvtPs9E4+luwqYDuDuLwFtga6ZL+TuU9y9wt0rSktrHQISEZH9lE2gzwcGmlk/MysBxgEzM7ZZB5wFYGaDCIGuXXARkRbUYKC7+25gAjAbWAZMd/clZjbJzMYkNrsR+JqZLQQeB650LeMoItKisjkoirvPAmZlPHZL2u2lwKm5LU1ERBpDZ4qKiEREwQX6iy/CzTeDBnRERGoquEBfsADuugveey/flYiItC4FF+hlZeH6zTfzW4eISGtTcIE+aFC4XrYsv3WIiLQ2BRfoPXtChw7aQxcRyVRwgW4Whl0U6CIiNRVcoEMIdA25iIjUVLCBvn49bNuW70pERFqPggz05IHR5cvzW4eISGtSkIGuqYsiIvsqyEA/8kgoLtY4uohIuoIM9AMPhP79tYcuIpKuIAMdwji6Al1EJKVgA72sDFasgN27812JiEjrUNCBvmsXvPVWvisREWkdCjbQtaaLiEhNBRvomrooIlJTVoFuZqPNbLmZrTKzibU8/1Mzez1xWWFmm3NeaYaDDoLDDtMeuohIUoM9Rc2sGJgMnANUAfPNbGaijygA7n5D2vbXAUObodZ9aJEuEZGUbPbQhwOr3H21u+8EpgFj69l+PPB4LoprSHLqotrRiYhkF+g9gPVp96sSj+3DzPoA/YBn63j+GjOrNLPK6urqxta6j7Iy2LxZ7ehERCD3B0XHATPcfU9tT7r7FHevcPeK0tLSJr+ZDoyKiKRkE+gbgF5p93smHqvNOFpouAU0dVFEJF02gT4fGGhm/cyshBDaMzM3MrMy4GDgpdyWWDe1oxMRSWkw0N19NzABmA0sA6a7+xIzm2RmY9I2HQdMc2+5Q5TJdnTaQxcRyWLaIoC7zwJmZTx2S8b9W3NXVvbKymDu3Hy8s4hI61KwZ4omDRqkdnQiIhCBQE/OdFE7OhGJu8gEug6MikjcFXygqx2diEhQ8IGudnQiIkHBBzqEA6PaQxeRuItEoJeVwcqVakcnIvEWiUAfNEjt6EREIhHoyZkuGnYRkTiLVKDrwKiIxFkkAl3t6EREIhLooHZ0IiKRCXS1oxORuItMoKsdnYjEXWQCPdm9SMMuIhJXkQl0TV0UkbiLTKCrHZ2IxF1WgW5mo81suZmtMrOJdWzzRTNbamZLzOyx3JaZTY1qRyci8dZgCzozKwYmA+cAVcB8M5vp7kvTthkI3Ayc6u4fmlm35iq4PmpHJyJxls0e+nBglbuvdvedwDRgbMY2XwMmu/uHAO7+fm7LzI7a0YlInGUT6D2A9Wn3qxKPpTsKOMrMXjSzl81sdG0vZGbXmFmlmVVWV1fvX8X1UDs6EYmzXB0UbQM
MBM4AxgO/MrODMjdy9ynuXuHuFaWlpTl66xRNXRSROMsm0DcAvdLu90w8lq4KmOnuu9z9LWAFIeBb1IABakcnIvGVTaDPBwaaWT8zKwHGATMztvkjYe8cM+tKGIJZnbsys6N2dCISZw0GurvvBiYAs4FlwHR3X2Jmk8xsTGKz2cAmM1sKPAd8z903NVfR9VE7OhGJqwanLQK4+yxgVsZjt6TdduC7iUtelZXB00+HdnRtsvrtRESiITJniiapHZ2IxFXkAl1ruohIXEU20HVgVETiJnKBrnZ0IhJXkQt0UDs6EYmnSAZ6cuqi2tGJSJxEMtDLymDLFrWjE5F4iWSga00XEYmjSAa6pi6KSBxFMtDVjk5E4iiSga52dCISR5EMdNDURRGJn8gGutrRiUjcRDbQ1Y5OROImsoGuqYsiEjeRDfQjj1Q7OhGJl8gGeklJ6DGqPXQRiYusAt3MRpvZcjNbZWYTa3n+SjOrNrPXE5erc19q42nqoojESYOBbmbFwGTgPKAcGG9m5bVs+jt3Pz5xeTDHde6XsjJYuTK0oxMRibps9tCHA6vcfbW77wSmAWObt6zcSLajW70635WIiDS/bAK9B7A+7X5V4rFMnzOzRWY2w8x65aS6JlL3IhGJk1wdFP0T0NfdhwDPAA/XtpGZXWNmlWZWWV1dnaO3rpsCXUTiJJtA3wCk73H3TDz2T+6+yd0/Tdx9EDihthdy9ynuXuHuFaWlpftTb6OoHZ2IxEk2gT4fGGhm/cysBBgHzEzfwMwOT7s7Bmg1ETpokPbQRSQeGgx0d98NTABmE4J6ursvMbNJZjYmsdn1ZrbEzBYC1wNXNlfBjZWcuqh2dCISdW2y2cjdZwGzMh67Je32zcDNuS0tN9Lb0R12WL6rERFpPpE9UzQpuaaLxtFFJOoiH+ia6SIicRH5QFc7OhGJi8gHutrRiUhcRD7QQVMXRSQeYhHoZWVqRyci0RebQAe1oxORaItFoGvqoojEQSwCPdmOTuPoIhJlsQh0taMTkTiIRaCDpi6KSPTFJtAHDVI7OhGJttgEelmZ2tGJSLTFKtBB4+giEl2xC3SNo4tIVMUm0JPt6LSHLiJRFZtAB63pIiLRFqtAVzs6EYmyrALdzEab2XIzW2VmE+vZ7nNm5mZWkbsSc2fQoFQ7OhGRqGkw0M2sGJgMnAeUA+PNrLyW7ToB3wZeyXWRuaIDoyISZdnsoQ8HVrn7anffCUwDxtay3e3A3cAnOawvpzR1UUSiLJtA7wGsT7tflXjsn8xsGNDL3f9c3wuZ2TVmVmlmldXV1Y0utqmS7ei0hy4iUdTkg6JmVgTcB9zY0LbuPsXdK9y9orS0tKlv3WjJdnTaQxeRKMom0DcAvdLu90w8ltQJGAw8b2ZrgJOBma35wKgCXUSiKJtAnw8MNLN+ZlYCjANmJp909y3u3tXd+7p7X+BlYIy7VzZLxU2kdnQiElUNBrq77wYmALOBZcB0d19iZpPMbExzF5hrye5FakcnIlHTJpuN3H0WMCvjsVvq2PaMppfVfNKnLp5wQn5rERHJpVidKQpqRyci0RW7QFc7OhGJqtgFOqgdnYhEUywDXe3oRCSKYhnoakcnIlEUy0BPTl3UOLqIREksA/3oo8O1xtFFJEpiGehqRyciURTLQAet6SIi0RPbQFc7OhGJmtgGutrRiUjUxDbQ1Y5ORKIm9oGucXQRiYrYBrra0YlI1MQ20NWOTkSiJraBDpq6KCLREutAVzs6EYmSrALdzEab2XIzW2VmE2t5/htm9oaZvW5mL5hZee5LzT21oxORKGkw0M2sGJgMnAeUA+NrCezH3P1Ydz8euAe4L9eFNgdNXRSRKMlmD304sMrdV7v7TmAaMDZ9A3ffmna3A1AQ51+qHZ2IREk2TaJ7AOvT7lcBJ2VuZGbXAt8FSoBRtb2QmV0DXAPQu3fvxtaac8l2dNpDF5EoyNlBUXef7O4DgB8A/7eObaa4e4W7V5SWlu7
/m73yCmzd2vB2WdDURRGJimwCfQPQK+1+z8RjdZkGXNSEmupXXQ2jRsG4cTnpIad2dCISFdkE+nxgoJn1M7MSYBwwM30DMxuYdvcCYGXuSsxQWgr//u/w9NNw001Nfjm1oxORqGhwDN3dd5vZBGA2UAw85O5LzGwSUOnuM4EJZnY2sAv4ELiiOYvmG98Icw1/9rPQfuib39zvl0pvR3fUUbkpT0QkH7I5KIq7zwJmZTx2S9rtb+e4robde28YK7nuOjjpJBg2bL9eJr0d3ZgxOaxPRKSFZRXorVJxMTz+ODz4IBx//H6/jNrRiUhUFPap/506wQ03QFERrFsHGzfu18sMGqSpiyJS+Ao70JM+/RROPx0uuSTcbqTk1EW1oxORQhaNQD/wQLjrLpg3D77+9UYns9rRiUgURCPQAS69FG67DR5+GO65p1E/qjVdRCQKohPoAD/+MYwfDxMnwl/+kvWPpU9dFBEpVIU7y6U2ZvDQQ9CnD5x6atY/1qOH2tGJSOGL1h46QNu2cOedYQbM9u3w7rsN/kiyHd0LL8Ann7RAjSIizSB6gZ7kDhdeCOefH4K9Ad/6Frz2GoweDZs3N395IiK5Ft1ANwtrvSxcCF/+MuzdW+/mX/0qPPYY/O1v8JnPQFVVC9UpIpIj0Q10gAsugPvugz/+EX74wwY3Hz8+HEtdtw5GjIDFi5u/RBGRXIl2oANcf31YzOvuu+GRRxrcfNSoMJ19zx447TSYO7cFahQRyYHoB7oZ/Od/hhUZs5z5ctxx8NJLYY2Xc86BGTOauUYRkRyIfqADHHAA/PKX0L9/OFiaxZovffrAiy9CRQV88Yvw85+3QJ0iIk0Qj0BPN2FCOOr54YcNbnrIIfA//wNjx4aRmx/8oMFjqyIieRO/QB83LrQn+uIXQ6uiBrRrF4ZcvvnNsKLA5ZfDzp0tUKeISCPFL9BPOw1+9auw633ddVkt5FVcDJMnw7/9Gzz6aJg8k6Me1SIiORO/QAe44gq4+WZ44IFwwDQLZmHm429+A889F1brfeedZq5TRKQRsgp0MxttZsvNbJWZTazl+e+a2VIzW2Rmc8ysT+5LzbF//Ve46ioYMqRRP3bllfDUU6H73SmnhNamIiKtQYOBbmbFwGTgPKAcGG9m5RmbvQZUuPsQYAbQuPVr86GoKLSvO/PMcH/btqx/dPRoeP552LEjhPpLLzVPiSIijZHNHvpwYJW7r3b3ncA0YGz6Bu7+nLvvSNx9GeiZ2zKb2eTJcOyxjepwUVERlgk45JBwMtLMmc1Yn4hIFrIJ9B7A+rT7VYnH6nIV8HRtT5jZNWZWaWaV1dXV2VfZ3EaMgPffD9MZ//jHrDseDRgQ5qofeyxcfHEYkhcRyZecHhQ1sy8DFcBPanve3ae4e4W7V5SWlubyrZtm2DCYNQvatAnJPHJkWNQrC926hYOko0eHFQZuuUW9SUUkP7IJ9A1Ar7T7PROP1WBmZwM/Asa4e+M7Nefb6afDG2/A/ffDqlXw8cfh8SzSuUMHePLJsGLj7bfD1VdnNcVdRCSnsgn0+cBAM+tnZiXAOKDGiLGZDQUeIIT5+7kvs4W0aROaTK9ZAyefHB67/nr4zndg06YGf/TBB8Me+kMPwUUXZbUMu4hIzjQY6O6+G5gAzAaWAdPdfYmZTTKzMYnNfgJ0BP6fmb1uZoV9iPDAA8O1O+zeHRZyGTAgnCpaT0sjs9Cn+oEHwjK8Z54ZhuZFRFqCeZ4GfCsqKryysjIv791oixeHhVxmzYLevcPpop/5TL0/8qc/waWXwhFHhKUDjj++ZUoVkWgzswXuXlHbc/E8U7SxBg+GP/8Z5syBnj2hV+KQwo4ddf7IZz8Lzz4LW7bA0KEwZgy8/HIL1SsisaRAb4xRo8I8xT6JE2Evuij0LK2jtdHJJ8OKFTBpUvixESPgrLNC0GsmjIjkmgJ9f+3dG7pf/O1voSPG1VfD22/vs9nBB8OPfwxr18K998LSpSHUTzk
lLCGgYBeRXFGg76+iIvje9+Af/wgzYf77v2HgwDAsU4uOHeHGG+Gtt0KvjXfeCcMyQ4fC9Omh5Z2ISFMo0Jvq0EPhpz+FN9+Eyy6D4cPD42vWhBkyGdq2DWurr1wJU6eGSTOXXgrl5WElR81fF5H9pUDPlf79wzrrnTqFID///LAmwMyZtY6rHHBAWMV3yZKwh96+fTgx6cgjw9IyyfOaRESypUBvDsXFcOedIcjHjg1zFy+5JKwRUMumX/gCvPpqmEjTs2foktevH/zkJ/DRR3moX0QKkgK9OZiFIH/jDXj44XDwdNGiVHPqBQvCUdEbb4Tf/x7efhuzsFP/wgsh9489Fr7//TCh5rbb4IMP8vsriUjrpxOLWpJ7CPsXXggdk+bPh08Ty9706RNOXCovD2uzt23LKwvacMcdYdSmY0f41rfgu9+F7t3z+2uISP7oxKLWwixcf+YzMG9eaEz68stw331w4onhLFSAu++GLl046eZRPDn4R6z++Z8Zd84m7r0X+vYNQzLr1uXttxCRVkp76K3R88/DH/4Q5ri/9lqY09i5Myv//iF33VPEqodf5AM/mMPOHMQlnzMuuggOPzzfRYtIS6hvD12B3tpt3w6VlbBhQ5gWCew85nhKli5kY3E3ntkziucYxQfDzmbEZf24+OIw4UZEokmBHjUrV8K8efhzz7F79hwOqH6H2QddyujN0wC4se/v6fH5EZx75RGUl6dGekSk8CnQo8wdli+HPXtY3e4Y5vx6DV+7ox8AyyhjQZez8DNHUT5hFMNGHaRwFylwCvQ42bsXFi5k6x/msPmJZyl9cy7t9mznMh5lXs/L+Mo5VXz+6Dco//pptDmoY76rFZFGqi/Q27R0MdLMiopg6FA6Dx1K50k3wa5dbJ0znzFvDWLHbNj+yBMM2f1tdk1sw8ruJ8GZo+jz1bMoOeOUcPqqiBQs7aHHzPbqHbz6i7/xwYxn6fHmHIbuDX8HX7vkA84b14ULi5+m3db3wprvyUu7dnmuWkSSmjzkYmajgf8AioEH3f2ujOdHAj8DhgDj3H1GQ6+pQM+/Tz+FuX/awuu/eY17/n4GGzfCU3YhF/ifa244ZAgsXBhu//KXsHlzmDOfDPwePVJt+0SkWTUp0M2sGFgBnANUEZpGj3f3pWnb9AU6AzcBMxXohWfPnnAC6zN/+oQVz23gg4XrOXzPenqznq6HteEfn/s+I0fC2PtGcuAr82r+8IgRYc48wA03hMXJ0vfw+/cP69mISJM1dQx9OLDK3VcnXmwaMBb4Z6C7+5rEc3ubXK3kRXExnH46nH56W2AAH388gFdeCSe0zpoLL00Nq0DCXI7tv50Lj69iZL/1DOtWRWn/Tvxz8syCBWFvfuvW1It/4QthSUmAa68NIT94cLj07h3G/UWkybIJ9B7A+rT7VcBJ+/NmZnYNcA1A7+Rp7tIqtWsHZ5wRLhDWaX/9dZg7F+bN68ADzx/NnU8cDYSzVEfOgNNOg5GT53LMMVC0bStUVcH69XDQQeFFduwIbZrS1y3o2BFuvTUsVLZrVzhLdvBgOOwwTaAXaaQWneXi7lOAKRCGXFryvaVpDjggLDdz4okhe/fuhWXLwh783Lnh8rvfhW0PPhhOPbUzI0eWc9pp5ZwwDA6AsOj72rWhc/bSpaEX6+LFMGhQ+MHly+Hcc8PtQw5J7cVffXVo7ZRc3ExEapVNoG8AeqXd75l4TGKsqAiOOSZcvvGNkLVr1qQCft68sDMOYW//hBNC69XjjoMhQ7oweMgIOowYUfNF+/cPHbQXLw6dPxYvhkcfhQsvDM//9a/wla+kgv6YY6BbNxg5Erp0gQ8/DJcOHcKHR/v2YSxJJCayCfT5wEAz60cI8nHAZc1alRQcs9CUo18/uPzy8Nh774VgnzcvLEfz8MNhZeDk9kcemQz45HV7+pxxJnbmmakXdg9fByDstZ97bgj6++9PtXVatCgsIP/b34b+runatg3tAfv0gQcfDD/
XoUPNy89/HjpNPftsOAbQvn34FEpef/az4ROsqiocG0g+ntymjU7nkNahwX+J7r7bzCYAswnTFh9y9yVmNgmodPeZZnYi8AfgYOCzZnabux/TrJVLq9e9O3z+8+ECIZfXrg3HTBctCtevvw4z0uZEde6cHvBw3HHG4MHFdOhAGO+ZOjVsuGdPeLEPPgifDBAaiUydGhY0S78kx/A7dQpFbd8O1dXh57dvTx2Ufeqp0B82U7KD9+23w5QpNZ9r3z68BoQDvrNm1Qz7I45IjUX94hewYkX4EOnYMVwffnhoKgvhD+TTT8Nzyec7dtSUUMmaTiySvNu2Lex0pwf9okWp9nvJvfmaQR92unM6pL53bzhwu3172PtPXoYNC8+/+mpYGG3HjtRz7nDTTeH5++8P0zc//ji1TZcuYSlkgPHj4S9/Cb9wsoF4eXkYXgI49dTU9M+kE0+Ev/893P6Xf4G33qoZ9iedFA4qAzz2WDjY0b17OKjcvXv4hNRxB9i5MwzHbd0ajuEkry+4AEpK4JlnwpBe+vNbt4bH27WDhx6C2bNTH9bJD+zbbw9/vv/7v+HvJv2bXceOcPLJ4f03bgw7Bjn4Vqe1XKTgJMfkFy2qGfKrVqV6bnfuHII+OdTTt2/qdp8+4f9Nq7VzZwj2XbtSLagqK8M41fbt4bnt28Mw05e+FJ7/0Y9g9erUc9u2QUVFONkLwuu8/37N9xk/PgQ9hOWXO3SoGfiDB6cOSjf2oLN7uBQVhd+nqqrmB+HHH4fjHD16wNtvhw+2Tz4JH5x794afveQSOOqo8Bf7+OOpIbbk9ZVXwoAB4S//scdS75nc5vrrw1/288+HIbX0MN6yBebMCf8g7rkHfvCDfX+Ht98O35Juuy00luncOXwId+4cvtE980w4DnPHHfDII6nfa8eO8P7JMcTLLw/Ppzv00FTbyUsuSX2wp0/j3Q8KdImMbdvCDm0y4P/xjxD8a9aErEjXvXvNkE/e7ts3TH+P3EjG++/Du++GD4XkZcAAuOiisHdYUQHvvBOGm5LHJb7znTDM9PHHYXpS9+7h0q1bCKzLLw9DQuvXw9ln1wzrHTvC8Ydrrw1/Iccfv29NU6fCFVfAiy+GTl2ZnngCLr4Ynn46NNVNZxb2ms8+O/Teveyy8FhRUbg2C8c9hg8PATlxYiqMk9d33hnOe3jtNXjppX2fLy8Pe+hNnUG1eXO4pH9727s3HLCH8MGwcmV4/MgjQ8/h/aRAl8jbuzfk2VtvhcuaNTVvr12bGuWA8H/3iCNqD/t+/cJOZWTXKtuzBzZtCoHfqVP4pT/6KAwfJD8QqqtDcF57bZhZtGlTuJ0+5NCuXZiBdPLJIcyefLLmcES7djBwIJSWhj34zZvDQeqiolQol5SEPeDkXnt6WEutFOgSe3v2hG/XdQV+VVVqpxVCrvToEb7N9+4drtMvvXuH0QuRlqblcyX2iotTS8skvwWnSw4BJ0N+3bqwV792bThOOX16zT18gK5d6w77Pn3CEKp2NKUlKdBFCN/8+/evux9rcg9/7dqaYb92bZjmPnt2GD5N16FDzcDv1Ssci0xekkPVJSXN//tJPCjQRbKQvodfG/cwJT496JOXdevCBJbkhIdMhx6aCvjMwE+/f+ihOvFV6qdAF8kBsxC4hx6amrae6ZNPwvHG5HHHd99NXZL3X345XGfu7UMY1+/Wbd/A79YtDP907RreP3lbU9DjR4Eu0kLatk0NvzRk27Z9wz7z/pIl4fbOnbW/Rps2tQd9+iXz8Y4d9SFQyBToIq1Qx45hunJyVYO6uIdzaDZuDDMLN26s/bJpU1gdM3k7uZpBppKSVNA3dDnkkNS1hoJaBwW6SAEzC+fIdOkSziHKxt694STKzMCv7UNg6dJwXd+HAITlcrL9AEjW26VLhOf654kCXSRmiorCSaEHHxzO+8lG8ptAMtzru7z3Xvg2sGlTaj2eurRrFz4M0kO+sRd9KKQo0EWkQenfBOqa2lm
bnTvD7J/0wN+ype7L5s2pHihbtqRWSK5Pu3ahrk6dUkuwZF7X91z6NoU+dKRAF5FmU1KSmpWzP3buTK2z1dDlo4/CZevW8KGwdWvqsU8/ze792revGfLJhS0zL3U9Xtelpb5FKNBFpNVKHqTt2rVpr7NzZyrs04M/83bmY9u2hWVt1qypucR+5kJwDTnggJofArfeCuPGNe13qo0CXUQir6QkdXA2F/bs2bePSuYlucpxbZdc1ZFJgS4i0kjFxWFYpnPnfFdSU1E2G5nZaDNbbmarzGxiLc8faGa/Szz/ipn1zXmlIiJSrwYD3cyKgcnAeUA5MN7MyjM2uwr40N2PBH4K3J3rQkVEpH7Z7KEPB1a5+2p33wlMAzLbbYwFHk7cngGcZaYTiEVEWlI2gd4DWJ92vyrxWK3buPtuYAvQTMP+IiJSm6zG0HPFzK4xs0ozq6yurm7JtxYRibxsAn0DkL4KdM/EY7VuY2ZtgC7ApswXcvcp7l7h7hWlpaX7V7GIiNQqm0CfDww0s35mVgKMA2ZmbDMTuCJx+/PAs56vZqUiIjHV4Dx0d99tZhOA2UAx8JC7LzGzSUClu88Efg08YmargA8IoS8iIi3I8rUjbWbVwNr9/PGuQB0NvVqlQqq3kGqFwqq3kGqFwqq3kGqFptXbx91rHbPOW6A3hZlVuntFvuvIViHVW0i1QmHVW0i1QmHVW0i1QvPV26KzXEREpPko0EVEIqJQA31KvgtopEKqt5BqhcKqt5BqhcKqt5BqhWaqtyDH0EVEZF+FuocuIiIZCi7QG1rKt7Uws15m9pyZLTWzJWb27XzXlA0zKzaz18zsqXzXUh8zO8jMZpjZm2a2zMxG5Lum+pjZDYl/B4vN7HEza5vvmtKZ2UNm9r6ZLU577BAze8bMViauD85njUl11PqTxL+FRWb2BzM7KI8l/lNttaY9d6OZuZk1sR9TSkEFepZL+bYWu4Eb3b0cOBm4thXXmu7bwLJ8F5GF/wD+4u5lwHG04prNrAdwPVDh7oMJJ+i1tpPvpgKjMx6bCMxx94HAnMT91mAq+9b6DDDY3YcAK4CbW7qoOkxl31oxs17AucC6XL5ZQQU62S3l2yq4+zvu/mri9keEwMlcpbJVMbOewAXAg/mupT5m1gUYSThDGXff6e6b81pUw9oA7RJrHbUH3s5zPTW4+1zCWd7p0pfFfhi4qCVrqktttbr7XxMrvQK8TFhzKu/q+HOF0Dfi+0BOD2IWWqBns5Rvq5Po4DQUeCXPpTTkZ4R/ZHvzXEdD+gHVwG8Sw0MPmlmHfBdVF3ffANxL2Bt7B9ji7n/Nb1VZ6e7u7yRuvwt0z2cxjfBV4Ol8F1EXMxsLbHD3hbl+7UIL9IJjZh2B3wPfcfet+a6nLmZ2IfC+uy/Idy1ZaAMMA/7L3YcC22k9wwH7SIw9jyV8EB0BdDCzL+e3qsZJLLbX6qfEmdmPCMOdj+a7ltqYWXvgh8AtzfH6hRbo2Szl22qY2QGEMH/U3Z/Idz0NOBUYY2ZrCENZo8zst/ktqU5VQJW7J7/xzCAEfGt1NvCWu1e7+y7gCeCUPNeUjffM7HCAxPX7ea6nXmZ2JXAh8KVWvNrrAMIH+8LE/7WewKtmdlguXrzQAj2bpXxbhUQLvl8Dy9z9vnzX0xB3v9nde7p7X8Kf67Pu3ir3It39XWC9mR2deOgsYGkeS2rIOuBkM2uf+HdxFq34IG6a9GWxrwCezGMt9TKz0YThwjHuviPf9dTF3d9w927u3jfxf60KGJb4N91kBRXoiYMeyaV8lwHT3X1Jfquq06nA/yHs6b6euJyf76Ii5DrgUTNbBBwP3JHfcuqW+CYxA3gVeIPw/65VndloZo8DLwFHm1mVmV0F3AWcY2YrCd8y7spnjUl11PoLoBPwTOL/2v15LTKhjlqb7/1a7zcTERFpjILaQxcRkbop0EVEIkKBLiISEQp0EZGIUKCLiESEAl1EJCIU6CIiEaFAFxGJiP8
PgmOmhlM7IpIAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model_mlp = MLP([D_in, 256, 128, 64, D_out]).to(device) # [D_in, 512, 256, 128, 64, D_out]\n", "\n", "optimizer = optim.SGD(model_mlp.parameters(), lr=0.01, momentum=0.5)\n", "criterion = nn.NLLLoss()\n", "\n", "model_mlp, losses, accuracies = train_val_model(model_mlp, criterion, optimizer, dataloaders,\n", " num_epochs=15, log_interval=2)\n", "\n", "_ = plt.plot(losses['train'], '-b', losses['val'], '--r')" ] }, { "cell_type": "markdown", "metadata": { "id": "HXhHtX1TkUba" }, "source": [ "### 1.3 Convolutional Neural Network\n", "\n", "Convolutional layers capture patterns corresponding to relevant features independently of where they occur in the input. To do so, they slide a window over the input and apply the convolution operation with a set of kernels or filters that represent the features. Although it is not their only field of application, convolutional neural networks are best known for their performance on image processing tasks.\n", "\n", "Training and validation for the CNN will be handled in the same way as for the feed-forward network; however, we still have to define the network's architecture.\n", "\n", "For that we will implement a CNN class to define how many layers it comprises and how the layers will be connected.\n", "\n", "The initialization (`__init__`) function will define the architecture and the `forward` function will implement how the different layers are connected. This architecture will be a sequence of 2 convolutional layers ([nn.Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)) (1st: output channels 10, kernel size 5; 2nd: output channels 20, kernel size 5), then 2 fully connected layers ([nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)) (1st: output features 50; 2nd: output features 10 (the number of classes)). 
Once again, the final layer will be a [log-softmax](https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html#torch.nn.functional.log_softmax) function, which outputs log-probabilities over the 10 classes; the predicted class is the one with the highest value.\n", "\n", "After the second convolutional layer we apply a dropout layer ([nn.Dropout2d](https://pytorch.org/docs/stable/generated/torch.nn.Dropout2d.html)), which zeroes entire feature maps, and after the first fully connected layer we apply standard dropout. The idea behind dropout is to disable a fraction of randomly selected units during each step of the training phase in order to reduce overfitting." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "PZ0mCl24EoaM" }, "outputs": [], "source": [ "class CNN(nn.Module):\n", "    \"\"\"Basic PyTorch CNN for MNIST-like data.\"\"\"\n", "\n", "    def __init__(self):\n", "        super(CNN, self).__init__()\n", "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", "        self.conv2_drop = nn.Dropout2d()\n", "        self.fc1 = nn.Linear(320, 50)\n", "        self.fc2 = nn.Linear(50, 10)\n", "\n", "    def forward(self, x, T=1.0):\n", "        # Batch size = 64, images 28x28 =>\n", "        # x.shape = [64, 1, 28, 28]\n", "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", "        # Convolution with 5x5 filter without padding and 10 channels =>\n", "        # x.shape = [64, 10, 24, 24] since 24 = 28 - 5 + 1\n", "        # Max pooling with stride of 2 =>\n", "        # x.shape = [64, 10, 12, 12]\n", "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", "        # Convolution with 5x5 filter without padding and 20 channels =>\n", "        # x.shape = [64, 20, 8, 8] since 8 = 12 - 5 + 1\n", "        # Max pooling with stride of 2 =>\n", "        # x.shape = [64, 20, 4, 4]\n", "        x = x.view(-1, 320)\n", "        # Reshape (320 = 20 * 4 * 4) =>\n", "        # x.shape = [64, 320]\n", "        x = F.relu(self.fc1(x))\n", "        x = F.dropout(x, training=self.training)\n", "        x = self.fc2(x)\n", "        x = F.log_softmax(x, dim=1)\n", "        return x" ] }, { "cell_type": "markdown", "metadata": { "id": "mv9vdZZ7OlSh" }, "source": [ "As 
previously, let's describe the model to be trained. We will use the Adam optimizer ([optim.Adam](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)) with a learning rate of 0.001, and the same negative log-likelihood loss ([nn.NLLLoss](https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html))." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "ImTlr5JeEsb6", "outputId": "a8d8e0a6-e3cc-4b37-d022-adc0241e8b88" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0/24\n", "----------\n", "train Loss: 0.5321 Acc: 83.19%\n", "val Loss: 0.0969 Acc: 97.07%\n", "\n", "Epoch 2/24\n", "----------\n", "train Loss: 0.1932 Acc: 94.42%\n", "val Loss: 0.0587 Acc: 98.02%\n", "\n", "Epoch 4/24\n", "----------\n", "train Loss: 0.1598 Acc: 95.36%\n", "val Loss: 0.0422 Acc: 98.65%\n", "\n", "Epoch 6/24\n", "----------\n", "train Loss: 0.1401 Acc: 95.99%\n", "val Loss: 0.0364 Acc: 98.78%\n", "\n", "Epoch 8/24\n", "----------\n", "train Loss: 0.1305 Acc: 96.24%\n", "val Loss: 0.0331 Acc: 98.90%\n", "\n", "Epoch 10/24\n", "----------\n", "train Loss: 0.1236 Acc: 96.32%\n", "val Loss: 0.0316 Acc: 99.02%\n", "\n", "Epoch 12/24\n", "----------\n", "train Loss: 0.1202 Acc: 96.44%\n", "val Loss: 0.0351 Acc: 98.82%\n", "\n", "Epoch 14/24\n", "----------\n", "train Loss: 0.1156 Acc: 96.50%\n", "val Loss: 0.0295 Acc: 99.01%\n", "\n", "Epoch 16/24\n", "----------\n", "train Loss: 0.1146 Acc: 96.56%\n", "val Loss: 0.0295 Acc: 99.01%\n", "\n", "Epoch 18/24\n", "----------\n", "train Loss: 0.1086 Acc: 96.82%\n", "val Loss: 0.0298 Acc: 98.98%\n", "\n", "Epoch 20/24\n", "----------\n", "train Loss: 0.1062 Acc: 96.81%\n", "val Loss: 0.0300 Acc: 98.97%\n", "\n", "Epoch 22/24\n", "----------\n", "train Loss: 0.1053 Acc: 96.80%\n", "val Loss: 0.0300 Acc: 98.93%\n", "\n", "Epoch 24/24\n", "----------\n", "train Loss: 0.1000 Acc: 97.01%\n", "val Loss: 0.0288 Acc: 
99.08%\n", "\n", "Training complete in 6m 42s\n", "Best val Acc: 99.08%\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdUElEQVR4nO3deZgU9Z3H8fd3ZhgGuUYc5JgBBrzxAp3gJppE3aiIiceyGsm6q4kumqxrNJoEE9cYEp+wxhj3MFkvoj67Bo8YgyuIR/BaNMuARERWGQjKfd8Cw8Bv//hOp3uGOXqgZ2q66vN6nnq6u6qPb03Dp6p/9atfWQgBERGJl4KoCxARkdxTuIuIxJDCXUQkhhTuIiIxpHAXEYmhoqg+uKysLFRWVkb18SIieWnOnDnrQwh9W3teZOFeWVlJdXV1VB8vIpKXzOyjbJ6nZhkRkRhSuIuIxJDCXUQkhhTuIiIxpHAXEYkhhbuISAwp3EVEYijvwv3NN+HWW0EjFYuINC/vwr26GiZNgo0bo65ERKTzyrtwr6jw2+XLo61DRKQzU7iLiMSQwl1EJIbyLtz794eCAoW7iEhL8i7ci4pgwACFu4hIS/Iu3MGbZhTuIiLNU7iLiMRQ3ob7ihVRVyEi0nnlbbhv2wZbt0ZdiYhI55RVuJvZaDP7wMxqzGxCE8uvMrN1Zjavfrom96WmqTukiEjLWg13MysE7gPOB4YD48xseBNPfSKEMKJ+eijHdTagcBcRaVk2e+6jgJoQwpIQQi0wBbiofctqmcJdRKRl2YR7ObAs4/Hy+nmNjTWzd83saTMb1NQbmdl4M6s2s+p169YdQLlu4MD6QhTuIiJNytUB1eeAyhDCScBLwKNNPSmE8EAIoSqEUNW3b98D/rDiYujXT+EuItKcbMJ9BZC5J15RP+/PQggbQgi76x8+BJyam/Kap77uIiLNyybcZwNHmdlQMysGLgemZj7BzAZkPLwQWJi7EpumcBcRaV5Ra08IIdSZ2fXADKAQmBxCWGBmE4HqEMJU4AYzuxCoAzYCV7VjzYCH++uvt/eniIjkp1bDHSCEMA2Y1mje7Rn3bwVuzW1pLauogE2bYMcO6N69Iz9ZRKTzy8szVCHdHVLDEIiI7C/vw13t7iIi+8vbcC+v72mvcBcR2Z/CXUQkhvI23A85BPr0UbiLiDQlb8Md1NddRKQ5CncRkRhSuIuIxFDeh/u6dbBrV9SViIh0Lnkf7gArV0Zbh4hIZxOLcFfTjIhIQwp3EZEYUriLiMRQXod7z57Qq5fCXUSksbwOd1B3SBGRpsQi3DXsr4hIQ7EId+25i4g0FItwX7UK9uyJuhIRkc4jFuEeAqxeHXUlIiKdRyzCHdQ0IyKSSeEuIhJDCncRkRjK+3AvLfWrMincRUTS8j7czdQdUkSksbwPd1C4i4g0pnAXEYmh2IT7ypWwd2/UlYiIdA6xCfe6Oli7NupKREQ6h9iEO6hpRkQkReEuIhJDWYW7mY02sw/MrMbMJrTwvLFmFsysKncltq683G8V7iIirtVwN7NC4D7gfGA4MM7MhjfxvJ7AN4E/5LrI1pSVQXGxwl1EJCWbPfdRQE0IYUkIoRaYAlzUxPN+BPwzsCuH9WWloMD33hXuIiIum3AvB5ZlPF5eP+/PzOwUYFAI4fmW3sjMxptZtZlVr1u3rs3FtkR93UVE0g76gKqZFQD3ADe39twQwgMhhKoQQlXfvn0P9qMbULiLiKRlE+4rgEEZjyvq56X0BE4AXjWzpcBfAFM7+qBqKtxD6MhPFRHpnLIJ99nAUWY21MyKgcuBqamFIYQtIYSyEEJlCKESeBu4MIRQ3S4VN6OiAmprYf36jvxUEZHOqdVwDyHUAdcDM4CFwJMhhAVmNtHMLm
zvArOlvu4iImlF2TwphDANmNZo3u3NPPfMgy+r7TLDfeTIKCoQEek8YnGGKmjPXUQkU2zCvV8/KCxUuIuIQIzCvbAQBg6EFStaf66ISNzFJtxBfd1FRFIU7iIiMRTLcNeJTCKSdLEL9x07YMuWqCsREYlW7MId1DQjIqJwFxGJIYW7iEgMxSrcBwwAM4W7iEiswr1LF+jfX+EuIhKrcAf1dRcRAYW7iEgsKdxFRGIoluG+ZQts2xZ1JSIi0YlluINGhxSRZIttuKtpRkSSLHbhXl7utwp3EUkyhbuISAzFLtxLSqCsTOEuIskWu3AHdYcUEVG4i4jEkMJdRCSGYhvuGzbAzp1RVyIiEo3YhjvoRCYRSa5Yh7uaZkQkqRTuIiIxFMtwT53IpGYZEUmqWIZ7jx5QWqo9dxFJrqzC3cxGm9kHZlZjZhOaWH6dmc03s3lm9qaZDc99qW2j7pAikmSthruZFQL3AecDw4FxTYT34yGEE0MII4C7gHtyXWhbKdxFJMmy2XMfBdSEEJaEEGqBKcBFmU8IIWzNeNgdCLkr8cAo3EUkyYqyeE45sCzj8XLgtMZPMrN/AL4FFANnN/VGZjYeGA8wePDgttbaJhUVsGYN1NZCcXG7fpSISKeTswOqIYT7QghHAN8FbmvmOQ+EEKpCCFV9+/bN1Uc3qaICQoBVq9r1Y0REOqVswn0FMCjjcUX9vOZMAS4+iJpyQn3dRSTJsgn32cBRZjbUzIqBy4GpmU8ws6MyHl4ALMpdiQdG4S4iSdZqm3sIoc7MrgdmAIXA5BDCAjObCFSHEKYC15vZF4A9wCbgyvYsOhsKdxFJsmwOqBJCmAZMazTv9oz738xxXQetVy8/mUnhLiJJFMszVAHM1B1SRJIrtuEOCncRSS6Fu4hIDMU+3Fetgrq6qCsREelYsQ/3vXv9TFURkSSJfbiDmmZEJHliHe6pi3Yo3EUkaWId7tpzF5GkinW4H3YYdO2qcBeR5Il1uOtEJhFJqliHOyjcRSSZFO4iIjGUiHBfsQL27Yu6EhGRjpOIcN+zB9ati7oSEZGOk4hwBzXNiEiyJCbcV7R0YUARkZhJTLhrz11EkiT24X744VBUpHAXkWSJfbgXFPgYMwp3EUmS2Ic7wHHHwfTpsHJl1JWIiHSMRIT7PffAJ5/AV76iC3eISDIkItyPOw5++Ut47TWYODHqakRE2l8iwh3g7/4OrroKfvxjePnlqKsREWlfiQl3gH//d9+Lv+IKWL066mpERNpPosK9e3d48knYuhX+5m/8+qoiInGUqHAHOP5434P//e/hzjujrkZEpH0kLtwBvvpVb5q54w6YOTPqakREci+R4W7mvWeOPtq7R65ZE3VFIiK5lchwB+jRw9vfN2+Gv/1bjfcuIvGS2HAHOOkk+Nd/hZdegp/8JOpqRERyJ9HhDnDNNTBuHNx+O7z+etTViIjkRlbhbmajzewDM6sxswlNLP+Wmb1vZu+a2StmNiT3pbYPM7j/fjjiCA95XbFJROKg1XA3s0LgPuB8YDgwzsyGN3raO0BVCOEk4GngrlwX2p569vT29w0b/ExWtb+LSL7LZs99FFATQlgSQqgFpgAXZT4hhDAzhPBJ/cO3gYrcltn+RoyAn/8cXngB7sqrTZOIyP6yCfdyYFnG4+X185pzNTC9qQVmNt7Mqs2sel0nbP+47jq49FK47TZ4882oqxEROXA5PaBqZlcAVcBPm1oeQngghFAVQqjq27dvLj86J8zgwQdhyBBvf9f4MyKSr7IJ9xXAoIzHFfXzGjCzLwDfBy4MIezOTXkdr3dvb39fvx5OPBGmTIEQoq5KRKRtsgn32cBRZjbUzIqBy4GpmU8ws5HA/Xiwr819mR3r1FOhuhqGDfM9+Isv1lWcRCS/tBruIYQ64HpgBrAQeDKEsMDMJprZhfVP+ynQA3jKzOaZ2dRm3i5vHH88zJoFd98NL74Iw4fD5MnaixeR/GAhorSqqqoK1dXVkX
x2W9XU+MlOr70G55yTbpcXEeloZjYnhFDV2vMSf4ZqNo480ocI/sUv4K23fK/+vvvUH15EOi+Fe5YKCuDrX4f33oMzzoDrr4czz4QPP4y6MhGR/Snc22jIEJg+HX71K5g/H04+GX76U6iri7oyEZE0hfsBMPOLbb//Ppx3HnznO/CZz/hevYhIZ6BwPwgDBsBvf+t94Zcu9b340aO9n/yuXVFXJyJJpnA/SGbw5S/7Xvz3vue3X/4yDBzo7fJz56r7pIh0PIV7jpSVwY9+BH/6k/eLP+88eOghPyFqxAi4914NJywiHUfhnmOFhd4X/te/hlWrvMtkcTHcdBOUl8PYsfD88zoAKyLtS+Hejg49FL7xDZg9G95915tp3ngDvvhFGDwYvvtd+OCDqKsUkThSuHeQE0+Ee+6B5cv9IGxVFfzsZ3DssXD66fDww7BtW9RVikhcKNw7WHGxD0Q2daoH/V13wcaNPrzBgAHwta/5WPI6CCsiB0PhHqH+/eHb3/YeNrNm+QiUTz0Fn/0sHHMMTJqk0ShF5MAo3DsBM/j0p31AstWr4ZFHfC/+1lth0CBvo3/mGaitjbpSEckXCvdOpnt3uPJKH4Hyww9hwgR45x3vZVNeDt/6FrzyCmzfHnWlItKZacjfPLB3r/ednzwZfvc72LPHBzI7+WQf9uD00/128GD/FSAi8ZXtkL8K9zyzZQu8/Tb8z/94O/3bb8OOHb6svNxDPhX4I0ZAly6RlisiOZZtuBd1RDGSO717+9mv553nj+vqfHTKWbPSgf/UU76sWzcYNQpOOQUOOcRPsCoq8tvM+03Nq6jwoY27do1uXUXkwGnPPYZWrPCQTwX+/Pl+MLatFxc55BA4+2w4/3yfhg5tn3pFJHtqlpH9hOABX1fn7fh79zZ9v64OFi6EF17wsesXL/bXH310Oug//3koKYl2fUSSSOEuObNokYf89Onw6qs+nHG3bnDWWemwP+KIqKsUSQaFu7SLnTs94FN79YsW+fwjj4Rhw6BHD+/O2fi2qXmlpb5R0C8Akewp3KVDLF7sIf/SS7Bmjffc2b49fbtzZ8uvLyjwgD/+eJ+GD/fbY45R6Is0ReEuncLevfDJJ/uH/o4dPqbOwoWwYIFPixb588FD/8gj02GfCn6FviSdukJKp1BYCD17+tSa2lo/KzcV9gsW+Lg7zz2XDn3w/vzDhjU99eunE7lEQOEunUhxMZxwgk+Zdu9Oh/6HH/rVrpYs8WEYHn204XO7dWsY9kOHel/93bv3n2prm56/b5+P7TN4sE+DBqVvu3XruL+HyMFQuEun17Wrj4d/4on7L9u1Cz76yMO+8TRzZtNj8Jj5ezY3Acyb51fSaqxv33TYZ4Z/v34tv2fXrr7x0q8K6SgKd8lrJSXeDn/MMfsvCwE2bPCxeDIDtqgou5DdvdtPCFu2DD7+uOG0aBG8/HLbB3ArLvY6Skr8Sl0DBvjQz5lT5ryyMm/aaqnGTZtg82afUvdTt4WFMHKkX8v3sMPaVqvkN4W7xJaZh+OB6to13bzTlBB8rJ+PP/aLnzfVxNPctGuXH1BevRrmzvXbpq7EVVAAhx+eDvodOxoG+a5d2a9PZaWHfFWVT6ecAn36HMAfRvKCwl3kAJl5X/3S0ty8344dHvJNTatW+a+QHj38gHJpqe/5pz4/dT9zXmmph//cuVBd7dOcOfCb36Q/c9iwdOCfeqpPqfVJnc3c2hSC19SjR27+DpIb6gopkjAbN6YDf84cv126NL28S5d0aLdFv37effWII/a/7dOn7ccb9u6FrVvTTUx79qQHuGtpajwQXuYAeQUF+X/cI6f93M1sNPAvQCHwUAhhUqPlnwPuBU4CLg8hPN3aeyrcRTqP9es98OfM8eahVCi2NKUCMwRvmlq82KeaGj9Wkal374aBf/jh3qSVeXyg8f0tW9pnXTNrb7w+3bv7r5lUnamahw3rPOdX5CzczawQ+BA4B1gOzAbGhRDez3hOJd
ALuAWYqnAXSbadO73HUmbgp26XLk2ft9CjR/NNS42bmIqL04PctTY1HhAvc1C8zKnxvK1bve6amoYbFzMfBrtx6Kc2VJm9orI9YH+gcnkS0yigJoSwpP6NpwAXAX8O9xDC0vplbRxUVkTiqFu39JnFjdXVeXD26tV5LyYTgjdf1dQ03DDV1MDUqbB2bfOvba2rbdeu8O1vwyWXtO86ZBPu5cCyjMfLgdMO5MPMbDwwHmDw4MEH8hZu5UrfXBbpeLBIvikq6vzdMs28xsMOg9OaSLutW9OBv2FDdifHZU4dsVHr0HQMITwAPADeLHNAb7Jxox/aHzMGHnww/4+OiEje6dXLzx8YOTLqSppXkMVzVgCDMh5X1M+LRp8+cPXV8PDD8E//FFkZIiKdWTZ77rOBo8xsKB7qlwNfadeqWjNxojd63Xmn97/6x3+MtBwRkc6m1T33EEIdcD0wA1gIPBlCWGBmE83sQgAz+5SZLQcuBe43swXtWTRm8Itf+BGJG26A559v148TEck3WbW5hxCmAdMazbs94/5svLmm4xQWwuOPwx13wOc+16EfLSLS2WXT5t55lZTApEk+WPiOHT4mrIiI5Hm4Z7rmGjjzTPjgg6grERGJXHzCfeJEb4s/7zzvBy8ikmDxCfejjvIrNW/Y4AG/aVPUFYmIRCY+4Q4+Xumzz/q12K65JupqREQiE7/z9//yL+Hpp+G446KuREQkMvHac0/50pd8uLYQPOgjGrNeRCQq8Qz3lGefhUsvhe9/P+pKREQ6VPyaZTJdfDFcey385Cc+SPORR8JnP+sHXAHee88vTFlWphEmRSRW4r3nbgb33Qfjx8Nrr/kJTzNm+LJPPoETT/RLzXfp4gF/3HHwH//hy0OAV15JX1VARCSPxH93tbAQ7r/fp337fLDl1Pwnn/TL1q9d69O6dX75F4C33oIvfAEGDfJRKL/2Nb8vIpIHdIHs5tTWwnPPwQMPwIsv+pV1x4zxPfvy8qirE5GEyvYye/FuljkYxcUwdqw34yxeDBMm+MUfU5eQef31hpeMFxHpRBTu2Rg2zMeOf/ddH6wsBD9JatgwGD0annkG9uyJukoRkT9TuLdF6pJ+ZvDyy3D77d7jZuxYb49/7jlfvmuXbwg2bVIfexGJhML9QA0e7GPJL13qoT5qlF/fFXzo4ZNP9ksC9uzpvXDOPRfeeMOXr1vn7fgLF8L27VGtgYjEWPx7y7S3oiL44hd9Shk6FJ54ApYtazjt2+fLZ83yPvgphx7qG4vJk+GUU2DRIpg71+cNHgz9+3vvHhGRLCnc20OfPnDZZc0v//zn/YDssmXw8cfpqXdvXz5jRsPrwhYVQUUFzJwJlZXw+9/763v3htLS9HTGGd5nf/duf402CCKJpXCPQmmpnynbnK9+Fc46Kx36H33kt2VlvvyNN+CHP9z/ddu3e7hPmAD33utNQqWlftu1q/8aAD9WMHWqHztITaWlftIW+Ov/8AcYONCnAQN8o/JXf+XL6+o67xm9e/Z476Y//cmnpUt93uWXw2mn+Vj/jz3mvaFSU5cu/veurIT16+F//9dPXqur82nvXt8gDxjgZzrPmJFelpouuQSOPdaPsaSOzYhEqJP+D0247t3h+ON9asoPfgC33QZbt8LmzbBli98ecogvHzPGw3rzZp+2bfN++illZTBkiAdRaurRI728Z08PxLff9jDctQuOPjod7uee6xuKVPgPHAif/jR8/eu+fNIkf92ePT7V1cEJJ8Att/jyq66CNWv8c3v18l8gn/mMb9TAex+VlPg6pH6d9OkD3bp509bHH3twL1mSDvGxY72+RYsa/t2Ki/29Tj3Vw33pUrj11v3/pk8/7eE+dy5ccMH+y194wcN93jz4xjf2Xz5mjN8+9phveEeM8GnkSL+tqFDoZ2rvjeC+fX4WeklJ590RaWc6iUlaFkJ6A1JZ6fMeesh7A61cmZ6GDfPmIoBPfQpqanyPuEsX/8919tnwq1
/58jFjfA/ZLL2BuuACf1/wjdTOnQ3r+Pu/9xPKtm3zDUJKQYH3VPrOdzx0d+3yoB461Kf+/Rtu2Pbt82ar2lrf8NTW+lRW5hu4LVvg//7Pa86cKip8o7tzp9fceHlxsa/PSy/5esyb5xua1P+v1auhXz//dbRypR9wLynxHlW1telfco895q/dtMmnjRt9o/LEE778r//a6zv0UJ/69PFhNG6+2ZdPn+7rmFrWvbv/ajv8cF9eU+Oft2+f/yLZu9efO3SoL581q+GQG2b+N0yNsvrWWw2XgW/chwzxMH3ssfRORWq67DLf8C5d6jsBmzf7d9Czp3+XP/4xXHmlb7RvucXnpTb6vXrB+ef7r6IVK+C//9u/o9S0ebN/9yed5Bvga6/1+Vu3pv/2r7/uf9/nn/cuzakdht69fbrxRv8bL17svd8a/2obO9b/Tc6e7b9oM5fV1cFNN/nyd9/1HY1evXzdUtOAATndkGV7ElMyN2mSPbN0kKS0diGU2bNbXj5tWsvL33kn/R83NR19tC/r2RMeecTDduhQD/YuXdKvLSmBK65o/r0LCvwXQLduTS/v3dv38JvT0msBzjnHJ/Bmsvnz4f33PdjBg3/KlIav6d8fVq3y+08/7RvJ1N+88d/+hBM8mDdt8jD84x99A5AK9xtu8ADP9KUveTMc+HGZNWsaLh83Dh5/PF3/J580XD5+vA/fAXD66fuv8803w913e9Clfr0VF6ePBZ15ps8rLYULL/Tb4mLfUG/d6t8l+OP5833eli1+0fvU3+fYY329rrvO5xUWpsN5/Xqf16+fN6+l5vfs6Rvj1IarsNBDeP16f6/Uv7Grr/YAfvbZ9K/LTGed5a+bPt1/NTd2441++8gj8POf77+8rs4/+6abfCM9diz827/t/7wc0567SEeqq/Mrhf3xj+m95rKy9AZl376GvzTaavFiD6/Unv+OHb5XndrgPPus77kXFHjgFBR4uJ56qi+fOTPdqyuVDeXl3p03BP9lkrkM/P1TxxvWrPHwLik58HVI2bvXA79rV9+g7trll9Hs3dt/keRibzi1HmbeRXnZsvSvscJCvx0yxG+3b/caMpcVFvrOhZn/Olu5Mr3R2rbNN5SpnaH//E949VXvEddU016Wst1zV7iLiOQRjS0jIpJgCncRkRhSuIuIxJDCXUQkhhTuIiIxpHAXEYkhhbuISAwp3EVEYiiyk5jMbB3w0QG+vAxYn8Ny8k2S1z/J6w7JXn+tuxsSQujb2gsiC/eDYWbV2ZyhFVdJXv8krzske/217m1bdzXLiIjEkMJdRCSG8jXcH4i6gIglef2TvO6Q7PXXurdBXra5i4hIy/J1z11ERFqgcBcRiaG8C3czG21mH5hZjZlNiLqejmRmS81svpnNM7PYX+nEzCab2Vozey9jXh8ze8nMFtXfHtrSe+SrZtb9DjNbUf/9zzOzMVHW2F7MbJCZzTSz981sgZl9s35+Ur775ta/Td9/XrW5m1kh8CFwDrAcmA2MCyG8H2lhHcTMlgJVIYREnMhhZp8DtgOPhRBOqJ93F7AxhDCpfuN+aAjhu1HW2R6aWfc7gO0hhLujrK29mdkAYEAIYa6Z9QTmABcDV5GM77659b+MNnz/+bbnPgqoCSEsCSHUAlOAiyKuSdpJCOF1YGOj2RcBj9bffxT/Rx87zax7IoQQVoUQ5tbf3wYsBMpJznff3Pq3Sb6FezmwLOPxcg5gpfNYAF40szlmNj7qYiLSL4Swqv7+aqBflMVE4Hoze7e+2SaWzRKZzKwSGAn8gQR+943WH9rw/edbuCfdGSGEU4DzgX+o/+meWMHbFPOnXfHg/RI4AhgBrAJ+Fmk17czMegC/AW4MIWzNXJaE776J9W/T959v4b4CGJTxuKJ+XiKEEFbU364Ffos3UyXNmvo2yVTb5NqI6+kwIYQ1IYS9IYR9wIPE+Ps3sy54sP1XCOGZ+tmJ+e6bWv+2fv/5Fu6zgaPMbKiZFQOXA1
MjrqlDmFn3+oMrmFl34FzgvZZfFUtTgSvr718J/C7CWjpUKtjqXUJMv38zM+BhYGEI4Z6MRYn47ptb/7Z+/3nVWwagvvvPvUAhMDmEcGe0FXUMMxuG760DFAGPx33dzezXwJn4cKdrgB8AzwJPAoPxIaMvCyHE7sBjM+t+Jv6TPABLgWsz2qBjw8zOAN4A5gP76md/D293TsJ339z6j6MN33/ehbuIiLQu35plREQkCwp3EZEYUriLiMSQwl1EJIYU7iIiMaRwFxGJIYW7iEgM/T/TstjqYZfRqgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model = CNN().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", "criterion = nn.NLLLoss()\n", "\n", "model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,\n", " num_epochs=25, log_interval=2)\n", "\n", "_ = plt.plot(losses['train'], '-b', losses['val'], '--r')" ] }, { "cell_type": "markdown", "metadata": { "id": "ULZ91b0cPhy5" }, "source": [ "We have now completed training and validation with 3 different models: Logistic Regression, Feed-Forward Network, and Convolutional Neural Network. \n", "\n", "We have seen that with the CNN, the performance of the model in the validation set, outperforms the other models (~99% accuracy against ~90% and ~98%). " ] }, { "cell_type": "markdown", "metadata": { "id": "PHyGUuZbTvhr" }, "source": [ "The difference in performance between CNNs and MLP is small but how many learnable parameters are we using in the MLP and in CNN models?\n", "\n", "We can find it out using the following lines of code:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "acy0l3-YQjT2", "outputId": "ceb6251b-4ea1-4168-f23a-0c42feea37ce" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of parameters in the MLP model: 242762\n", "Number of parameters in the CNN model: 21840\n" ] } ], "source": [ "#model_mlp = MLP([D_in, 256, 128, 64, D_out]).to(device)\n", "model_parameters_mlp = filter(lambda p: p.requires_grad, model_mlp.parameters())\n", "params_mlp = sum([np.prod(p.size()) for p in model_parameters_mlp])\n", "print('Number of parameters in the MLP model: {}'.format(params_mlp))\n", "\n", "model_parameters_cnn = filter(lambda p: p.requires_grad, model.parameters())\n", "params_cnn = sum([np.prod(p.size()) for p in model_parameters_cnn])\n", "print('Number of parameters in the CNN model: 
{}'.format(params_cnn))" ] }, { "cell_type": "markdown", "metadata": { "id": "Sj28CWvrMbOw" }, "source": [ "You can see that the MLP uses ~11x more learnable parameters than the CNN (242,762 vs. 21,840) to achieve almost the same performance.\n", "\n", "We can experiment with the number of layers and their sizes to see how a much smaller MLP compares." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2RmgJhIPMECw", "outputId": "3c677343-5c57-4a1d-a047-9bca2d8fe52f" }, "outputs": [], "source": [ "model_mlp_test = MLP([D_in, 32, D_out]).to(device)\n", "model_parameters_mlp_test = filter(lambda p: p.requires_grad, model_mlp_test.parameters())\n", "params_mlp_test = sum([np.prod(p.size()) for p in model_parameters_mlp_test])\n", "print('Number of parameters in the MLP model: {}'.format(params_mlp_test))" ] }, { "cell_type": "markdown", "metadata": { "id": "B_oq9682QWCF" }, "source": [ "And how does that model perform? We are about to find out." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 979 }, "id": "w6oa0TeBQU9E", "outputId": "e8c478bf-27c6-4ade-f0bd-59302eaf0e49" }, "outputs": [], "source": [ "optimizer = optim.SGD(model_mlp_test.parameters(), lr=0.01, momentum=0.5)\n", "criterion = nn.NLLLoss()\n", "\n", "model_mlp_test, losses, accuracies = train_val_model(model_mlp_test, criterion, \n", " optimizer, dataloaders,\n", " num_epochs=15, \n", " log_interval=5)\n", "\n", "_ = plt.plot(losses['train'], '-b', losses['val'], '--r')" ] }, { "cell_type": "markdown", "metadata": { "id": "rpmgachOUCnX" }, "source": [ "We can see a drop in performance compared with the previous MLP model. 
This suggests that the CNN does well despite having fewer learnable parameters because of properties such as translation invariance and parameter sharing: the same kernel weights are reused at every position of the input.\n", "\n", "CNNs are expected to be invariant to the location where important features occur in the input. In practice, it is not unusual for the data acquisition process to change, causing a dataset shift. We will simulate such a shift by applying random horizontal translations to our validation dataset and see how robust each model is.\n", "\n", "We can do this by going back to the **0.1 - Create Dataloaders - MNIST dataset** cell and defining the test transform with the following code\n", "\n", "```\n", "mnist_transform_test = transforms.Compose(\n", "    [transforms.ToTensor(),\n", "     transforms.RandomAffine(0, translate=[0.1, 0]),\n", "     transforms.Normalize((0.1307,), (0.3081,))])\n", "```\n", "\n", "and replace\n", "\n", "`mnist_val_dataset = datasets.MNIST('../data', download=True, train=False, transform=mnist_transform)`\n", "\n", "with\n", "\n", "`mnist_val_dataset = datasets.MNIST('../data', download=True, train=False, transform=mnist_transform_test)`" ] }, { "cell_type": "markdown", "metadata": { "id": "-5gcf_gMlcqI" }, "source": [ "After rerunning the different models we can see that the accuracy of the Logistic Regression drops from ~90% to ~72%, the MLP drops from ~98% to ~87%, and the CNN drops from ~99% to ~97%. This shows that the features learned by the CNN are more robust to changes in location, as expected." ] }, { "cell_type": "markdown", "metadata": { "id": "nU3NwQ7Nuvhv" }, "source": [ "# Bonus Case - Attention with small images and CNNs. (And how to create a dataset that takes numpy arrays)\n", "\n", "In this case we will use scikit-learn's digits dataset.\n", "\n", "## Scikit-Learn Digits\n", "\n", "This dataset is provided by scikit-learn and the digit images are returned as NumPy ndarrays. 
We will use PIL (Python Imaging Library) to convert each NumPy ndarray to an image, resize it, and transform it to a tensor.\n", "\n", "In this case we don't have a predefined Digits Dataset provided by torchvision, so we will need to write a custom Dataset class and implement three functions: \n", "\n", "`__init__`, `__len__`, and `__getitem__`.\n", "\n", "Scikit-learn returns the digit images and labels as ndarrays. Each digit image is an 8x8 array.\n", "\n", "To use the previous CNN, we will use a transform to resize the images to the MNIST image size." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A4v-XFzcv9If" }, "outputs": [], "source": [ "SKLEARN_DIGITS_TRAIN_SIZE = 1247\n", "SKLEARN_DIGITS_VAL_SIZE = 550\n", "\n", "class NumpyDataset(Dataset):\n", "    \"\"\"Dataset wrapping NumPy arrays of images and integer labels.\"\"\"\n", "\n", "    def __init__(self, data, targets, transform=None):\n", "        self.data = torch.from_numpy(data).float()\n", "        self.targets = torch.from_numpy(targets).long()\n", "        self.transform = transform\n", "\n", "    def __getitem__(self, index):\n", "        # Add a trailing channel axis: (8, 8) -> (8, 8, 1)\n", "        x = np.expand_dims(self.data[index], axis=2)\n", "        y = self.targets[index]\n", "        if self.transform:\n", "            x = self.transform(x)\n", "        return x, y\n", "\n", "    def __len__(self):\n", "        return len(self.data)\n", "\n", "digits_transform = transforms.Compose([\n", "    transforms.ToPILImage(),\n", "    transforms.Resize(28),\n", "    transforms.ToTensor(),\n", "])\n", "\n", "# Get sklearn digits dataset\n", "X, y = load_digits(return_X_y=True)\n", "X = X.reshape((len(X), 8, 8))\n", "y_train = y[:-SKLEARN_DIGITS_VAL_SIZE]\n", "y_val = y[-SKLEARN_DIGITS_VAL_SIZE:]\n", "X_train = X[:-SKLEARN_DIGITS_VAL_SIZE]\n", "X_val = X[-SKLEARN_DIGITS_VAL_SIZE:]\n", "\n", "digits_train_dataset = NumpyDataset(X_train, y_train, transform=digits_transform)\n", "digits_val_dataset = NumpyDataset(X_val, y_val, transform=digits_transform)\n", "digits_train_dataloader = torch.utils.data.DataLoader(digits_train_dataset, batch_size=64, shuffle=True)\n", "digits_val_dataloader = 
torch.utils.data.DataLoader(digits_val_dataset, batch_size=64, shuffle=True)\n", "\n", "dataloaders = dict(train=digits_train_dataloader, val=digits_val_dataloader)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dhQU3v7Zv9Ih", "outputId": "030f1fa0-62a0-4dc0-d61f-013e09e2457d" }, "outputs": [], "source": [ "# Get some examples of images and targets\n", "_, (example_train_imgs, example_train_targets) = next(enumerate(digits_train_dataloader))\n", "_, (example_val_imgs, example_val_targets) = next(enumerate(digits_val_dataloader))\n", "\n", "# Info about the dataset (computed from the digits batch, not the earlier MNIST batch)\n", "D_in = np.prod(example_train_imgs.shape[1:])\n", "D_out = len(digits_train_dataloader.dataset.targets.unique())\n", "\n", "# Output information\n", "print(\"Datasets shapes (before transformations):\", {x: dataloaders[x].dataset.data.shape for x in ['train', 'val']})\n", "print(\"N input features:\", D_in, \"Output classes:\", D_out)\n", "print(\"Train batch:\", example_train_imgs.shape, example_train_targets.shape)\n", "print(\"Val batch:\", example_val_imgs.shape, example_val_targets.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 284 }, "id": "Vx78pb7Ov9Ih", "outputId": "207e26b2-c2f7-41e7-ae13-9421b398027e" }, "outputs": [], "source": [ "plot_img_label_prediction(imgs=example_train_imgs, y_true=example_train_targets, y_pred=None, shape=(2, 3))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "xbBAH9OTv9Ii" }, "source": [ "### Logistic Regression" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "W46ofUE6v9Ii", "outputId": "5e698e3a-c465-45c3-9f72-07f566ef8055" }, "outputs": [], "source": [ "scaler = StandardScaler()\n", "print(X_train.squeeze().shape)\n", "X_train = scaler.fit_transform(np.reshape(X_train, (X_train.shape[0], -1)))\n", "X_val = 
scaler.transform(np.reshape(X_val, (X_val.shape[0], -1)))\n", "\n", "# Turn up tolerance for faster convergence\n", "clf = LogisticRegression(C=50., multi_class='multinomial', solver='sag', tol=0.1)\n", "clf.fit(X_train, y_train)\n", "#sparsity = np.mean(clf.coef_ == 0) * 100\n", "score = clf.score(X_val, y_val)\n", "\n", "print(\"Validation score: %.4f\" % score)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 339 }, "id": "E_G2Rt0Gv9Ij", "outputId": "59791f8a-e8ac-41f9-81e9-8cdc92cadd42" }, "outputs": [], "source": [ "# Plot the learned coefficients of each class as an 8x8 image\n", "coef = clf.coef_.copy()\n", "plt.figure(figsize=(10, 5))\n", "scale = np.abs(coef).max()\n", "for i in range(10):\n", "    l1_plot = plt.subplot(2, 5, i + 1)\n", "    l1_plot.imshow(coef[i].reshape(8, 8), interpolation='nearest',\n", "                   cmap=plt.cm.RdBu, vmin=-scale, vmax=scale)\n", "    l1_plot.set_xticks(())\n", "    l1_plot.set_yticks(())\n", "    l1_plot.set_xlabel('Class %i' % i)\n", "plt.suptitle('Classification coefficient vectors for...')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "PfYdXpde4bg0" }, "source": [ "### Feed-forward using digits dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "UDKn4WS636Bg", "outputId": "754a52a5-d33e-4a7f-9431-df84fb38b3d8" }, "outputs": [], "source": [ "model = MLP([D_in, 512, 256, 128, 64, D_out]).to(device)\n", "\n", "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)\n", "criterion = nn.NLLLoss()\n", "\n", "model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,\n", "                                            num_epochs=20, log_interval=2)\n", "\n", "_ = plt.plot(losses['train'], '-b', losses['val'], '--r')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 724 }, "id": "R9ImEBeo4MW6", "outputId": 
"c402c8da-dd25-4212-86f6-d3768e1619ed" }, "outputs": [], "source": [ "model = CNN().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", "criterion = nn.NLLLoss()\n", "\n", "model, losses, accuracies = train_val_model(model, criterion, optimizer, dataloaders,\n", "                                            num_epochs=50, log_interval=10)\n", "\n", "_ = plt.plot(losses['train'], '-b', losses['val'], '--r')" ] }, { "cell_type": "markdown", "metadata": { "id": "uIhei09Ruvf-" }, "source": [ "# Bonus Information - Visualizing CNN filters\n", "\n", "Some work has been done to show the kinds of features learned by different filters in different layers. \n", "\n", "For instance, consider a well-known CNN called **VGG16**, which has the following architecture:\n", "\n", "![image](https://media.geeksforgeeks.org/wp-content/uploads/20200219152327/conv-layers-vgg16.jpg)\\[taken from: https://www.geeksforgeeks.org/vgg-16-cnn-model/ \\]\n", "\n", "These are some of the filters from some of its layers: \n", "\n", "
[Figure: filter visualizations for Layer 2 (Conv 1-2), Layer 10 (Conv 2-1), Layer 17 (Conv 3-1), and Layer 24 (Conv 4-1)]
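
To look at the filters of your own network, one option is to tile a layer's kernels into a single image. The sketch below is only an illustration under stated assumptions: the helper `filters_to_grid` is hypothetical (it is not part of torch or torchvision), and random NumPy values stand in for trained weights.

```python
import numpy as np


def filters_to_grid(weights, n_cols=8, pad=1):
    # weights: array of shape (n_filters, height, width),
    # e.g. one input channel of a convolutional layer
    n_filters, height, width = weights.shape
    # Min-max normalize each filter independently so its pattern is visible
    w_min = weights.min(axis=(1, 2), keepdims=True)
    w_max = weights.max(axis=(1, 2), keepdims=True)
    normed = (weights - w_min) / (w_max - w_min + 1e-8)
    # Allocate a zero-filled canvas with a border of `pad` pixels around each tile
    n_rows = int(np.ceil(n_filters / n_cols))
    grid = np.zeros((n_rows * (height + pad) + pad, n_cols * (width + pad) + pad))
    for idx in range(n_filters):
        row, col = divmod(idx, n_cols)
        top = pad + row * (height + pad)
        left = pad + col * (width + pad)
        grid[top:top + height, left:left + width] = normed[idx]
    return grid


# Assumption: random values stand in for the weights of a trained layer
rng = np.random.default_rng(0)
fake_weights = rng.normal(size=(16, 3, 3))
grid = filters_to_grid(fake_weights, n_cols=8)
print(grid.shape)
```

The resulting array can be displayed with `plt.imshow(grid, cmap='gray')`. For a trained PyTorch layer, the weights could be taken from something like `layer.weight[:, 0].detach().cpu().numpy()`, where `layer` is whichever `nn.Conv2d` you want to inspect.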
\n", "\n", "or obtain the class activations:\n", "\n", "
[Figure: class activation examples, with columns Input Image, Layer Vis. (Filter=0), and Filter Vis. (Layer=29)]
\n", "\n", "\\[examples taken from: http://www.github.com/utkuozbulak/pytorch-cnn-visualizations \\]\n" ] }, { "cell_type": "markdown", "metadata": { "id": "VU21L86BvAXK" }, "source": [ "# Bonus Information - Predefined architectures, pre-trained models and transfer learning\n", "\n", "Packages like [torchvision](https://pytorch.org/vision/stable/index.html) and [timm](https://rwightman.github.io/pytorch-image-models/) offer predefined architectures as well as pre-trained models, which can be fine-tuned on the same task or used for transfer learning.\n", "\n", "Besides datasets, transforms, and other utilities, **Torchvision** provides a large number of predefined architectures with the option of loading pre-trained weights." ] }, { "cell_type": "markdown", "metadata": { "id": "w_1YpmkV-PbU" }, "source": [ "#### Torchvision classification models examples\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l5yOWA9l4lrF" }, "outputs": [], "source": [ "import torchvision.models as models\n", "\n", "# construct a model with random weights to be trained\n", "resnet18 = models.resnet18()\n", "\n", "# load a pre-trained model\n", "# (in torchvision >= 0.13, `pretrained=True` is deprecated in favor of\n", "# `weights=models.ResNet18_Weights.DEFAULT`)\n", "resnet18 = models.resnet18(pretrained=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "nAsfqUIyDBD7" }, "source": [ "For examples of different models and how to use pre-trained weights, please visit https://pytorch.org/vision/stable/models.html#\n", "\n", "\n", "\n", "Another possibility is **timm**, which contains models for classification only.\n", "In **timm** you are not restricted to 1- or 3-channel inputs, so you can use architectures or pre-trained models with images that have 2 or more than 3 channels."
] }, { "cell_type": "markdown", "metadata": { "id": "4zS8Ykbo-ZUg" }, "source": [ "#### timm classification models examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "w0FQYnKQ-lr_", "outputId": "0d3ef180-ef41-4653-99dd-6721c41c2475" }, "outputs": [], "source": [ "if 'google.colab' in str(get_ipython()):\n", " !pip install -q timm\n", "import timm\n", "\n", "# list all models\n", "print(timm.list_models())\n", "\n", "# list pre-trained models\n", "print(timm.list_models(pretrained=True))\n", "\n", "# list models architectures by wildcards\n", "print(timm.list_models('*resne*t*'))\n", "\n", "# construct a model with random weights to be trained\n", "model = timm.create_model('resnet18')\n", "\n", "# load a pre-trained model\n", "model = timm.create_model('resnet18', pretrained=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "OSEwqB8ADIao" }, "source": [ "For more details on how to use this package visit https://rwightman.github.io/pytorch-image-models/" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kaank-9kDI72" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "P8-CNNs.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 1 }