{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "improving-payment",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "# Seed every RNG used below (Python, NumPy, PyTorch) for reproducibility.\n",
    "random.seed(5)\n",
    "np.random.seed(5)\n",
    "torch.random.manual_seed(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1d49658",
   "metadata": {},
   "source": [
    "### Question 1.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "violent-pulse",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 3)\n"
     ]
    }
   ],
   "source": [
    "x1 = np.array([-2.0, 1.0, 0.5])\n",
    "x2 = np.array([1.0, 1.5, -0.5])\n",
    "x3 = np.array([-1.5, 1.0, -0.5])\n",
    "x4 = np.array([-2.0, -2.5, 1.5])\n",
    "\n",
    "X = np.array([x1, x2, x3, x4])\n",
    "\n",
    "print(X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dressed-oracle",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n",
      "[4.5 0. 4.5 0. ]\n",
      "[2.59807621 0. 2.59807621 0. ]\n",
      "[0.46536883 0.03463117 0.46536883 0.03463117]\n",
      "[-1.66342208 0.8961065 0.03463117]\n"
     ]
    }
   ],
   "source": [
    "def softmax(scores, axis=-1):\n",
    "    \"\"\"Numerically stable softmax of `scores` along `axis`.\n",
    "\n",
    "    The maximum is subtracted before exponentiating so large scores cannot\n",
    "    overflow; softmax is shift-invariant, so the result is unchanged.\n",
    "    \"\"\"\n",
    "    exp_scores = np.exp(scores - np.max(scores, axis=axis, keepdims=True))\n",
    "    return exp_scores / np.sum(exp_scores, axis=axis, keepdims=True)\n",
    "\n",
    "\n",
    "q = np.array([-2.0, 1.0, -1.0])\n",
    "\n",
    "print(np.size(q))\n",
    "\n",
    "# Scaled dot-product attention of the single query q over the rows of X,\n",
    "# which act as both keys and values here.\n",
    "scores = X.dot(q) / np.sqrt(np.size(q))\n",
    "probabilities = softmax(scores)\n",
    "output = X.T.dot(probabilities)\n",
    "\n",
    "print(X.dot(q))\n",
    "print(scores)\n",
    "print(probabilities)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "relevant-centre",
   "metadata": {},
   "source": [
    "### Question 1.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "opposed-trailer",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-2.25 4.5 ]\n",
      " [ 1.25 2. ]\n",
      " [-1.25 4.75]\n",
      " [-2.75 -3.5 ]]\n",
      "[[ 5.75 1.5 ]\n",
      " [ 2. -0.5 ]\n",
      " [ 4.5 2. ]\n",
      " [-2.5 0.5 ]]\n",
      "[[-2.5 -7.5 ]\n",
      " [ 0.25 0. ]\n",
      " [-2. -5.25]\n",
      " [-0.75 -1.5 ]]\n",
      "1.4142135623730951\n",
      "[[ -4.37522321 -4.77297077 -0.79549513 5.5684659 ]\n",
      " [ 7.20365033 1.06066017 6.80590277 -1.50260191]\n",
      " [ -0.04419417 -3.44714556 2.74003878 3.8890873 ]\n",
      " [-14.89343658 -2.65165043 -13.70019389 3.62392225]]\n",
      "[[4.79433566e-05 3.22098620e-05 1.71943035e-03 9.98200416e-01]\n",
      " [5.97319598e-01 1.28333499e-03 4.01298182e-01 9.88847812e-05]\n",
      " [1.46423661e-02 4.87223528e-04 2.37021787e-01 7.47848624e-01]\n",
      " [9.06143069e-09 1.87817885e-03 2.98824011e-08 9.98121782e-01]]\n",
      "[[-0.75220098 -1.50668721]\n",
      " [-2.29564869 -6.58686077]\n",
      " [-1.07141415 -2.47595506]\n",
      " [-0.74812187 -1.4971829 ]]\n",
      "[1. 1. 1. 1.]\n"
     ]
    }
   ],
   "source": [
    "W_Q = np.array([[1, -1.5], [0, 2], [-0.5, -1]])\n",
    "W_K = np.array([[-1.5, -1], [2.5, 0], [0.5, -1]])\n",
    "W_V = np.array([[1, 2.5], [-0.5, -2], [0, -1]])\n",
    "\n",
    "Q = X.dot(W_Q)\n",
    "K = X.dot(W_K)\n",
    "V = X.dot(W_V)\n",
    "\n",
    "print(Q)\n",
    "print(K)\n",
    "print(V)\n",
    "\n",
    "d_k = np.size(Q, 1)  # query/key dimension used to scale the scores\n",
    "print(np.sqrt(d_k))\n",
    "\n",
    "scores = Q.dot(K.T) / np.sqrt(d_k)\n",
    "# Row i is the attention distribution of query i over the 4 keys.\n",
    "probabilities = softmax(scores, axis=1)\n",
    "\n",
    "Z = probabilities.dot(V)\n",
    "\n",
    "print(scores)\n",
    "print(probabilities)\n",
    "print(Z)\n",
    "\n",
    "print(probabilities.sum(axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "animated-barrel",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAQcAAAD8CAYAAAB6iWHJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAANI0lEQVR4nO3df6zddX3H8eeLUosIyC8zutIJDmZmnANpKo5kISAZEEOXiBksETCQLkYmLjNRt4Rl/rHg/tDEYFwaIANDFAPKOsJCasComSClKR0tQ+/4hxYytGChQ6uXvffH+ZZdLp/bQs/3fM8tfT6Sk/s95/vpfb9vSl+c+/1+z/edqkKS5jts2g1IWpwMB0lNhoOkJsNBUpPhIKnJcJDUNFY4JDk+yYYkP+2+HrfAupeTbO4e68epKWkYGec6hyT/CDxXVTck+RxwXFV9trFud1UdNUafkgY2bjg8AZxbVc8kWQ58r6re3VhnOEgHmXHD4RdVdWy3HeD5vc/nrZsFNgOzwA1VdfcC328tsBZgCUvOOpJjDrg3qS+/976Xpt3CxDyyZc/Pq+odrX37DYck3wVOauz6W+DWuWGQ5Pmqes1xhyQrqmpHkncB9wPnV9V/7avuMTm+PpDz99mbNIT7nt487RYmZsnymUeqalVr3+H7+8NV9aGF9iX57yTL5/xa8ewC32NH9/XJJN8DzgT2GQ6SpmvcU5nrgSu77SuBf5m/IMlxSZZ12ycC5wDbxqwracLGDYcbgAuS/BT4UPecJKuS3NSt+X1gY5JHgQcYHXMwHKRFbr+/VuxLVe0EXnNgoKo2Atd02/8O/ME4dSQNzyskJTUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpp6CYckFyZ5IslMN/lq/v5lSe7o9j+U5JQ+6kqanLHDIckS4KvARcB7gMuTvGfesqsZDbw5Dfgy8MVx60qarD7eOawGZqrqyar6NfBNYM28NWuAW7vtO4HzuwlZkhapPsJhBfDUnOfbu9eaa6pqFtgFnNBDbUkTMtat6fs2d1bmERw55W6kQ1sf7xx2ACvnPD+5e625JsnhwNuBnfO/UVWtq6pVVbVqKct6aE3SgeojHB4GTk9yapK3AJcxGpM319yxeZcC99c4470lTdzYv1ZU1WySa4H7gCXALVW1NckXgI1VtR64Gfh6khngOUYBImkR6+WYQ1XdC9w777Xr52z/CvhoH7UkDcMrJCU1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTUPNyrwqyc+SbO4e1/RRV9LkjH2D2TmzMi9gNO3q4STrq2rbvKV3VNW149aTNIw+7j79yqxMgCR7Z2XOD4c3ZM+pb+XJfzhj/O4WmXf9+eZptzAxOz77R9NuYSL+5Len3cEkzSy4Z6hZmQAfSbIlyZ1JVjb2k2Rtko1JNv7vi//TQ2uSDtRQByT/FTilqt4HbOD/J26/ytxxeIcd/baBWpPUMsiszKraWVV7uqc3AWf1UFfSBA0yKzPJ8jlPLwEe76GupAkaalbmp5JcAswympV51bh1JU3WULMyPw98vo9akobhFZKSmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTX2Nw7slybNJHltgf5J8pRuXtyXJ+/uoK2ly+nrn8M/AhfvYfxFwevdYC3ytp7qSJqSXcKiq7zO6q/RC1gC31ciDwLHzblcvaZEZ6pjD6xqZ5zg8afFYVAckHYcnLR5DhcN+R+ZJWlyGCof1wBXdWYuzgV1V9cxAtSUdgF4mXiX5BnAucGKS7cDfAUsBquqfGE3DuhiYAV4CPt5HXUmT09c4vMv3s7+AT/Z
RS9IwFtUBSUmLh+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpaahxeOcm2ZVkc/e4vo+6kianl3tIMhqHdyNw2z7W/KCqPtxTPUkTNtQ4PEkHmb7eObweH0zyKPA08Jmq2jp/QZK1jAbtcgRH8rtXNH9LOajVtBuYoBO2zk67hYn4rR8dM+0WJufshXcNFQ6bgHdW1e4kFwN3M5q4/SpVtQ5YB3DMYce/mf8dSYveIGcrquqFqtrdbd8LLE1y4hC1JR2YQcIhyUlJ0m2v7uruHKK2pAMz1Di8S4FPJJkFfglc1k3BkrRIDTUO70ZGpzolHSS8QlJSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpaexwSLIyyQNJtiXZmuS6xpok+UqSmSRbkrx/3LqSJquPe0jOAn9dVZuSHA08kmRDVW2bs+YiRnMqTgc+AHyt+yppkRr7nUNVPVNVm7rtF4HHgRXzlq0BbquRB4Fjkywft7akyen1mEOSU4AzgYfm7VoBPDXn+XZeGyAkWZtkY5KNv6k9fbYm6Q3qLRySHAXcBXy6ql44kO9RVeuqalVVrVqaZX21JukA9BIOSZYyCobbq+rbjSU7gJVznp/cvSZpkerjbEWAm4HHq+pLCyxbD1zRnbU4G9hVVc+MW1vS5PRxtuIc4GPAfyTZ3L32N8DvwCvj8O4FLgZmgJeAj/dQV9IEjR0OVfVDIPtZU8Anx60laTheISmpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUNNQ4vHOT7EqyuXtcP25dSZM11Dg8gB9U1Yd7qCdpAEONw5N0kOnjncMr9jEOD+CDSR4FngY+U1VbG39+LbAW4AiOpGZn+2xPE3bEPT+edgsTcdu6zdNuYWJu38e+3sJhP+PwNgHvrKrdSS4G7mY0cftVqmodsA7gmBxfffUm6Y0bZBxeVb1QVbu77XuBpUlO7KO2pMkYZBxekpO6dSRZ3dXdOW5tSZMz1Di8S4FPJJkFfglc1k3BkrRIDTUO70bgxnFrSRqOV0hKajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNfVxg9kjkvw4yaPdOLy/b6xZluSOJDNJHurmW0haxPp457AHOK+q/hA4A7gwydnz1lwNPF9VpwFfBr7YQ11JE9THOLzaO5MCWNo95t9Zeg1wa7d9J3D+3lvVS1qc+hpqs6S7Lf2zwIaqmj8ObwXwFEBVzQK7gBP6qC1pMnoJh6p6uarOAE4GVid574F8nyRrk2xMsvE37OmjNUkHqNezFVX1C+AB4MJ5u3YAKwGSHA68ncbEq6paV1WrqmrVUpb12ZqkN6iPsxXvSHJst/1W4ALgP+ctWw9c2W1fCtzvxCtpcetjHN5y4NYkSxiFzbeq6p4kXwA2VtV6RrM0v55kBngOuKyHupImqI9xeFuAMxuvXz9n+1fAR8etJWk4XiEpqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoaalbmVUl+lmRz97hm3LqSJquPu0/vnZW5O8lS4IdJ/q2qHpy37o6quraHepIG0MfdpwvY36xMSQeZPt450M2seAQ4DfhqY1YmwEeS/DHwE+CvquqpxvdZC6ztnu7+bt35RB/9vU4nAj8fsN5Q/LnGtGT5EFVeZci/s3cutCN9Dp7qJl99B/jLqnpszusnALurak+SvwD+rKrO661wD5JsrKpV0+6jb/5cB5/F8rMNMiu
zqnZW1d7JuDcBZ/VZV1L/BpmVmWTuG7NLgMfHrStpsoaalfmpJJcAs4xmZV7VQ92+rZt2AxPiz3XwWRQ/W6/HHCS9eXiFpKQmw0FS0yEfDkkuTPJEkpkkn5t2P31JckuSZ5M8tv/VB48kK5M8kGRbd7n+ddPuqQ+v52MIg/d0KB9z6A6i/oTRGZbtwMPA5VW1baqN9aC74Gw3cFtVvXfa/fSlO/O1vKo2JTma0cV3f3qw/50lCfC2uR9DAK5rfAxhMIf6O4fVwExVPVlVvwa+CayZck+9qKrvMzoz9KZSVc9U1aZu+0VGp8VXTLer8dXIovoYwqEeDiuAuZdxb+dN8B/aoSLJKcCZQOty/YNOkiVJNgPPAhsW+BjCYA71cNBBKslRwF3Ap6vqhWn304eqermqzgBOBlYnmeqvg4d6OOwAVs55fnL3mhax7nfyu4Dbq+rb0+6nbwt9DGFoh3o4PAycnuTUJG8BLgPWT7kn7UN34O5m4PGq+tK0++nL6/kYwtAO6XCoqlngWuA+Rge2vlVVW6fbVT+SfAP4EfDuJNuTXD3tnnpyDvAx4Lw5dxa7eNpN9WA58ECSLYz+p7Whqu6ZZkOH9KlMSQs7pN85SFqY4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU3/B29FDXeWBQweAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "plt.imshow(probabilities)" ] }, { "cell_type": "markdown", "id": "bibliographic-funds", "metadata": {}, "source": [ "### Question 1.3" ] }, { "cell_type": "code", "execution_count": 7, "id": "patient-wyoming", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[4.79433566e-05 3.22098620e-05 1.71943035e-03 9.98200416e-01]\n", " [5.97319598e-01 1.28333499e-03 4.01298182e-01 9.88847812e-05]\n", " [1.46423661e-02 4.87223528e-04 2.37021787e-01 7.47848624e-01]\n", " [9.06143069e-09 1.87817885e-03 2.98824011e-08 9.98121782e-01]]\n", "[[-0.75220098 -1.50668721]\n", " [-2.29564869 -6.58686077]\n", " [-1.07141415 -2.47595506]\n", " [-0.74812187 -1.4971829 ]]\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQcAAAD8CAYAAAB6iWHJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAANI0lEQVR4nO3df6zddX3H8eeLUosIyC8zutIJDmZmnANpKo5kISAZEEOXiBksETCQLkYmLjNRt4Rl/rHg/tDEYFwaIANDFAPKOsJCasComSClKR0tQ+/4hxYytGChQ6uXvffH+ZZdLp/bQs/3fM8tfT6Sk/s95/vpfb9vSl+c+/1+z/edqkKS5jts2g1IWpwMB0lNhoOkJsNBUpPhIKnJcJDUNFY4JDk+yYYkP+2+HrfAupeTbO4e68epKWkYGec6hyT/CDxXVTck+RxwXFV9trFud1UdNUafkgY2bjg8AZxbVc8kWQ58r6re3VhnOEgHmXHD4RdVdWy3HeD5vc/nrZsFNgOzwA1VdfcC328tsBZgCUvOOpJjDrg3qS+/976Xpt3CxDyyZc/Pq+odrX37DYck3wVOauz6W+DWuWGQ5Pmqes1xhyQrqmpHkncB9wPnV9V/7avuMTm+PpDz99mbNIT7nt487RYmZsnymUeqalVr3+H7+8NV9aGF9iX57yTL5/xa8ewC32NH9/XJJN8DzgT2GQ6SpmvcU5nrgSu77SuBf5m/IMlxSZZ12ycC5wDbxqwracLGDYcbgAuS/BT4UPecJKuS3NSt+X1gY5JHgQcYHXMwHKRFbr+/VuxLVe0EXnNgoKo2Atd02/8O/ME4dSQNzyskJTUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpp6CYckFyZ5IslMN/lq/v5lSe7o9j+U5JQ+6kqanLHDIckS4KvARcB7gMuTvGfesqsZDbw5Dfgy8MVx60qarD7eOawGZqrqyar6NfBNYM28NWuAW7vtO4HzuwlZkhapPsJhBfDUnOfbu9eaa6pqFtgFnNBDbUkTMtat6fs2d1bmERw55W6kQ1sf7xx2ACvnPD+5e625JsnhwNuBnfO/U
VWtq6pVVbVqKct6aE3SgeojHB4GTk9yapK3AJcxGpM319yxeZcC99c4470lTdzYv1ZU1WySa4H7gCXALVW1NckXgI1VtR64Gfh6khngOUYBImkR6+WYQ1XdC9w777Xr52z/CvhoH7UkDcMrJCU1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTUPNyrwqyc+SbO4e1/RRV9LkjH2D2TmzMi9gNO3q4STrq2rbvKV3VNW149aTNIw+7j79yqxMgCR7Z2XOD4c3ZM+pb+XJfzhj/O4WmXf9+eZptzAxOz77R9NuYSL+5Len3cEkzSy4Z6hZmQAfSbIlyZ1JVjb2k2Rtko1JNv7vi//TQ2uSDtRQByT/FTilqt4HbOD/J26/ytxxeIcd/baBWpPUMsiszKraWVV7uqc3AWf1UFfSBA0yKzPJ8jlPLwEe76GupAkaalbmp5JcAswympV51bh1JU3WULMyPw98vo9akobhFZKSmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTX2Nw7slybNJHltgf5J8pRuXtyXJ+/uoK2ly+nrn8M/AhfvYfxFwevdYC3ytp7qSJqSXcKiq7zO6q/RC1gC31ciDwLHzblcvaZEZ6pjD6xqZ5zg8afFYVAckHYcnLR5DhcN+R+ZJWlyGCof1wBXdWYuzgV1V9cxAtSUdgF4mXiX5BnAucGKS7cDfAUsBquqfGE3DuhiYAV4CPt5HXUmT09c4vMv3s7+AT/ZRS9IwFtUBSUmLh+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpaahxeOcm2ZVkc/e4vo+6kianl3tIMhqHdyNw2z7W/KCqPtxTPUkTNtQ4PEkHmb7eObweH0zyKPA08Jmq2jp/QZK1jAbtcgRH8rtXNH9LOajVtBuYoBO2zk67hYn4rR8dM+0WJufshXcNFQ6bgHdW1e4kFwN3M5q4/SpVtQ5YB3DMYce/mf8dSYveIGcrquqFqtrdbd8LLE1y4hC1JR2YQcIhyUlJ0m2v7uruHKK2pAMz1Di8S4FPJJkFfglc1k3BkrRIDTUO70ZGpzolHSS8QlJSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpaexwSLIyyQNJtiXZmuS6xpok+UqSmSRbkrx/3LqSJquPe0jOAn9dVZuSHA08kmRDVW2bs+YiRnMqTgc+AHyt+yppkRr7nUNVPVNVm7rtF4HHgRXzlq0BbquRB4Fjkywft7akyen1mEOSU4AzgYfm7VoBPDXn+XZeGyAkWZtkY5KNv6k9fbYm6Q3qLRySHAXcBXy6ql44kO9RVeuqalVVrVqaZX21JukA9BIOSZYyCobbq+rbjSU7gJVznp/cvSZpkerjbEWAm4HHq+pLCyxbD1zRnbU4G9hVVc+MW1vS5PRxtuIc4GPAfyTZ3L32N8DvwCvj8O4FLgZmgJeAj/dQV9IEjR0OVfVDIPtZU8Anx60laTheISmpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUNNQ4vHOT7EqyuXtcP25dSZM11Dg8gB9U1Yd7qCdpAEONw5N0kOnjncMr9jEOD+CDSR4FngY+U1VbG39+LbAW4AiOpGZn+2xPE3bEPT+edgsTcdu6zdNuYWJu38e+3sJhP+PwNgHvrKrdSS4G7mY0cftVqmodsA7gmBxfffUm6Y0bZBxeVb1QVbu77XuBpUlO7KO2pMkYZBxekpO6d
SRZ3dXdOW5tSZMz1Di8S4FPJJkFfglc1k3BkrRIDTUO70bgxnFrSRqOV0hKajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNfVxg9kjkvw4yaPdOLy/b6xZluSOJDNJHurmW0haxPp457AHOK+q/hA4A7gwydnz1lwNPF9VpwFfBr7YQ11JE9THOLzaO5MCWNo95t9Zeg1wa7d9J3D+3lvVS1qc+hpqs6S7Lf2zwIaqmj8ObwXwFEBVzQK7gBP6qC1pMnoJh6p6uarOAE4GVid574F8nyRrk2xMsvE37OmjNUkHqNezFVX1C+AB4MJ5u3YAKwGSHA68ncbEq6paV1WrqmrVUpb12ZqkN6iPsxXvSHJst/1W4ALgP+ctWw9c2W1fCtzvxCtpcetjHN5y4NYkSxiFzbeq6p4kXwA2VtV6RrM0v55kBngOuKyHupImqI9xeFuAMxuvXz9n+1fAR8etJWk4XiEpqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoaalbmVUl+lmRz97hm3LqSJquPu0/vnZW5O8lS4IdJ/q2qHpy37o6quraHepIG0MfdpwvY36xMSQeZPt450M2seAQ4DfhqY1YmwEeS/DHwE+CvquqpxvdZC6ztnu7+bt35RB/9vU4nAj8fsN5Q/LnGtGT5EFVeZci/s3cutCN9Dp7qJl99B/jLqnpszusnALurak+SvwD+rKrO661wD5JsrKpV0+6jb/5cB5/F8rMNMiuzqnZW1d7JuDcBZ/VZV1L/BpmVmWTuG7NLgMfHrStpsoaalfmpJJcAs4xmZV7VQ92+rZt2AxPiz3XwWRQ/W6/HHCS9eXiFpKQmw0FS0yEfDkkuTPJEkpkkn5t2P31JckuSZ5M8tv/VB48kK5M8kGRbd7n+ddPuqQ+v52MIg/d0KB9z6A6i/oTRGZbtwMPA5VW1baqN9aC74Gw3cFtVvXfa/fSlO/O1vKo2JTma0cV3f3qw/50lCfC2uR9DAK5rfAxhMIf6O4fVwExVPVlVvwa+CayZck+9qKrvMzoz9KZSVc9U1aZu+0VGp8VXTLer8dXIovoYwqEeDiuAuZdxb+dN8B/aoSLJKcCZQOty/YNOkiVJNgPPAhsW+BjCYA71cNBBKslRwF3Ap6vqhWn304eqermqzgBOBlYnmeqvg4d6OOwAVs55fnL3mhax7nfyu4Dbq+rb0+6nbwt9DGFoh3o4PAycnuTUJG8BLgPWT7kn7UN34O5m4PGq+tK0++nL6/kYwtAO6XCoqlngWuA+Rge2vlVVW6fbVT+SfAP4EfDuJNuTXD3tnnpyDvAx4Lw5dxa7eNpN9WA58ECSLYz+p7Whqu6ZZkOH9KlMSQs7pN85SFqY4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU3/B29FDXeWBQweAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[[1.18306921e-01 2.01966211e-02 1.68483136e-01 6.93013322e-01]\n", " [8.48429312e-04 9.98944583e-01 2.06267364e-04 7.20592823e-07]\n", " [2.67590116e-02 7.79843042e-04 5.42703523e-02 9.18190793e-01]\n", " [2.47463407e-05 6.12522986e-10 2.06437556e-04 9.99768815e-01]]\n", "[[-2.26628332 -2.26628332]\n", " [ 1.99725652 1.99725652]\n", " [-2.82066255 -2.82066255]\n", " [-2.99952526 -2.99952526]]\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQcAAAD8CAYAAAB6iWHJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAANO0lEQVR4nO3df+hd9X3H8efLGGPVzp8F05hpi9KtdKvWkCnCEH9QlaKDWqZ/tCpKRqmrHSus3cCx9h+7PywUS0dQmZbSWmznsuImEZW2bDrTEK0mUzNh0yi1jVYbrLGJ7/1xT9zXbz9fo7nnnnu/5vmAy/fcez75vt9fEl6533POPe9UFZI03wHTbkDSbDIcJDUZDpKaDAdJTYaDpCbDQVLTWOGQ5Kgk65M80X09coF1u5Ns6h7rxqkpaRgZ5zqHJH8PPF9V1yX5AnBkVf1VY92OqjpsjD4lDWzccHgMOLOqnk2yHLivqj7QWGc4SIvMuOHwy6o6otsO8MKe5/PW7QI2AbuA66rqjgW+3xpgDcCSLD310GVH73Nvs6p2vjrtFibnkIOn3cFE1Ird025hYnY8/rNfVNV7WvsO3NsfTnI3cGxj19/MfVJVlWShpDm+qrYleT9wT5KfVtV/z19UVWuBtQCHv2t5nf7+K/bW3qLz2pP/O+0WJua1D//etFuYiN1ffmHaLUzMfWdf/z8L7dtrOFTVOQvtS/KzJMvn/Frx3ALfY1v39ckk9wGnAL8VDpJmx7inMtcBl3XblwH/PH9BkiOTLOu2jwHOADaPWVfShI0bDtcB5yZ5Ajine06SVUlu7Nb8PrAhyUPAvYyOORgO0ozb668Vb6aqtgNnN17fAFzVbf878Afj1JE0PK+QlNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGrqJRySnJfksSRbu8lX8/cvS3Jbt/+BJCf0UVfS5IwdDkmWAF8Hzgc+CFya5IPzll3JaODNicBXga+MW1fSZPXxzmE1sLWqnqyqV4HvABfNW3MRcEu3fTtwdjchS9KM6iMcVgBPzXn+dPdac01V7QJeBN55s+6kd5CZOiCZZE2SDUk2vLr75Wm3I+3X+giHbcDKOc+P615rrklyIHA4sH3+N6qqtVW1qqpWHbTkkB5ak7Sv+giHB4GTkrwvyUHAJYzG5M01d2zexcA9Nc54b0kTN9bEKxgdQ0hyNXAXsAS4uaoeTfIlYENVrQNuAr6ZZCvwPKMAkTTDxg4HgKq6E7hz3mvXztl+BfhEH7UkDWOmDkhKmh2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVLTULMyL
0/y8ySbusdVfdSVNDlj32B2zqzMcxlNu3owybqq2jxv6W1VdfW49SQNo4+7T78+KxMgyZ5ZmfPD4W2pV3aye8sTPbQ3W+56ZtO0W5iYj75357RbmIgDz5l2B9Mx1KxMgI8neTjJ7UlWNva/YRzeb3hn/kOTFouhDkj+C3BCVf0hsJ7/n7j9BnPH4S1l2UCtSWoZZFZmVW2vqj1vBW4ETu2hrqQJGmRWZpLlc55eCGzpoa6kCRpqVuZnk1wI7GI0K/PycetKmqyhZmV+EfhiH7UkDcMrJCU1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKa+hqHd3OS55I8ssD+JPlaNy7v4SQf6aOupMnp653DPwLnvcn+84GTusca4Bs91ZU0Ib2EQ1X9kNFdpRdyEXBrjdwPHDHvdvWSZsxQxxze0sg8x+FJs2OmDkg6Dk+aHUOFw15H5kmaLUOFwzrgU91Zi9OAF6vq2YFqS9oHvUy8SvJt4EzgmCRPA38LLAWoqn9gNA3rAmAr8DJwRR91JU1OX+PwLt3L/gI+00ctScOYqQOSkmaH4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKlpqHF4ZyZ5Mcmm7nFtH3UlTU4v95BkNA7vBuDWN1nzo6r6WE/1JE3YUOPwJC0yfb1zeCtOT/IQ8Azw+ap6dP6CJGsYDdrl4BzKAQcfPGB7w/joe0+edgsTs+Too6bdwkR8ecO/TbuFibn7hIX3DRUOG4Hjq2pHkguAOxhN3H6DqloLrAU4/ICja6DeJDUMcraiql6qqh3d9p3A0iTHDFFb0r4ZJBySHJsk3fbqru72IWpL2jdDjcO7GPh0kl3Ar4FLuilYkmbUUOPwbmB0qlPSIuEVkpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNY4dDkpVJ7k2yOcmjSa5prEmSryXZmuThJB8Zt66kyerjHpK7gL+sqo1J3g38JMn6qto8Z835jOZUnAT8EfCN7qukGTX2O4eqeraqNnbbvwK2ACvmLbsIuLVG7geOSLJ83NqSJqfXYw5JTgBOAR6Yt2sF8NSc50/z2wFCkjVJNiTZ8Co7+2xN0tvUWzgkOQz4HvC5qnppX75HVa2tqlVVteoglvXVmqR90Es4JFnKKBi+VVXfbyzZBqyc8/y47jVJM6qPsxUBbgK2VNX1CyxbB3yqO2txGvBiVT07bm1Jk9PH2YozgE8CP02yqXvtr4HfhdfH4d0JXABsBV4GruihrqQJGjscqurHQPaypoDPjFtL0nC8QlJSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpaahxeGcmeTHJpu5x7bh1JU3WUOPwAH5UVR/roZ6kAQw1Dk/SItPHO4fXvck4PIDTkzwEPAN8vqoebfz5NcAagIM5hNdeeaXP9jRhu7c/P+0WJuLUZQdNu4Wp6C0c9jIObyNwfFXtSHIBcAejidtvUFVrgbUAv5Ojqq/eJL19g4zDq6qXqmpHt30nsDTJMX3UljQZg4zDS3Jst44kq7u628etLWlyhhqHdzHw6SS7gF8Dl3RTsCTNqKHG4d0A3DBuLUnD8QpJSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpKY+bjB7cJL/TPJQNw7v7xprliW5LcnWJA908y0kzbA+3jnsBM6qqg8DJwPnJTlt3porgReq6kTgq8BXeqgraYL6GIdXe2ZSAEu7x/w7S18E3NJt3
w6cvedW9ZJmU19DbZZ0t6V/DlhfVfPH4a0AngKoql3Ai8DRfdSWNBm9hENV7a6qk4HjgNVJPrQv3yfJmiQbkmz4DTv7aE3SPur1bEVV/RK4Fzhv3q5twEqAJAcCh9OYeFVVa6tqVVWtWsqyPluT9Db1cbbiPUmO6LbfBZwL/Ne8ZeuAy7rti4F7nHglzbY+xuEtB25JsoRR2Hy3qn6Q5EvAhqpax2iW5jeTbAWeBy7poa6kCepjHN7DwCmN16+ds/0K8Ilxa0kajldISmoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKahpqVeXmSnyfZ1D2uGreupMnq4+7Te2Zl7kiyFPhxkn+tqvvnrbutqq7uoZ6kAfRx9+kC9jYrU9Ii08c7B7qZFT8BTgS+3piVCfDxJH8MPA78RVU91fg+a4A13dMdd9ftj/XR31t0DPCLAesNxZ9rTEuWD1HlDYb8Ozt+oR3pc/BUN/nqn4A/r6pH5rx+NLCjqnYm+TPgT6vqrN4K9yDJhqpaNe0++ubPtfjMys82yKzMqtpeVXsm494InNpnXUn9G2RWZpK5b8wuBLaMW1fSZA01K/OzSS4EdjGalXl5D3X7tnbaDUyIP9fiMxM/W6/HHCS9c3iFpKQmw0FS034fDknOS/JYkq1JvjDtfvqS5OYkzyV5ZO+rF48kK5Pcm2Rzd7n+NdPuqQ9v5WMIg/e0Px9z6A6iPs7oDMvTwIPApVW1eaqN9aC74GwHcGtVfWja/fSlO/O1vKo2Jnk3o4vv/mSx/50lCXDo3I8hANc0PoYwmP39ncNqYGtVPVlVrwLfAS6ack+9qKofMjoz9I5SVc9W1cZu+1eMTouvmG5X46uRmfoYwv4eDiuAuZdxP8074B/a/iLJCcApQOty/UUnyZIkm4DngPULfAxhMPt7OGiRSnIY8D3gc1X10rT76UNV7a6qk4HjgNVJpvrr4P4eDtuAlXOeH9e9phnW/U7+PeBbVfX9affTt4U+hjC0/T0cHgROSvK+JAcBlwDrptyT3kR34O4mYEtVXT/tfvryVj6GMLT9OhyqahdwNXAXowNb362qR6fbVT+SfBv4D+ADSZ5OcuW0e+rJGcAngbPm3Fnsgmk31YPlwL1JHmb0n9b6qvrBNBvar09lSlrYfv3OQdLCDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGr6P68zE5LVU+UjAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[[-0.75220098 -1.50668721 -2.26628332 -2.26628332]\n", " [-2.29564869 -6.58686077 1.99725652 1.99725652]\n", " [-1.07141415 -2.47595506 -2.82066255 -2.82066255]\n", " [-0.74812187 -1.4971829 -2.99952526 -2.99952526]]\n", "[[-6.04664898 3.77781072 -0.75731086]\n", " [ 8.28741825 0.14750295 10.57968068]\n", " [-7.3905735 5.09982766 -0.01158073]\n", " [-8.25045389 4.87428797 -1.50140321]]\n" ] } ], "source": [ "W_O = np.array([[-1, 1.5, 2], [0, -1, -2], [1, -1.5, 0], [2, 0, 1]])\n", "\n", "W_Q_heads = [W_Q, np.ones_like(W_Q)]\n", "W_K_heads = [W_K, np.ones_like(W_K)]\n", "W_V_heads = [W_V, np.ones_like(W_V)]\n", "\n", "head_representations = []\n", "\n", "for W_Q_h, W_K_h, W_V_h in zip(W_Q_heads, W_K_heads, W_V_heads):\n", " Q_h = X.dot(W_Q_h)\n", " K_h = X.dot(W_K_h)\n", " V_h = X.dot(W_V_h)\n", " scores = Q_h.dot(K_h.T) / np.sqrt(np.size(Q_h, 1))\n", " probabilities = np.exp(scores) / np.sum(np.exp(scores), axis=1)[:, None]\n", " Z_h = probabilities.dot(V_h)\n", " head_representations.append(Z_h)\n", " print(probabilities)\n", " print(Z_h)\n", " plt.imshow(probabilities)\n", " plt.show()\n", " \n", "print(np.concatenate(head_representations, axis=1))\n", "\n", "Z = np.concatenate(head_representations, axis=1).dot(W_O)\n", "\n", "print(Z)" ] }, { "cell_type": "markdown", "id": "f538217d", "metadata": {}, "source": [ "### Question 2" ] }, { "cell_type": "markdown", "id": "cd9e39cd", "metadata": {}, "source": [ "In this exercise, you will implement a sequence-to-sequence network that reverses strings with the help of attention. We will randomly generate strings consisting of \"a\", \"b\", \"c\", and \"d\"." 
] }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dbb48a59",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['ccadbabacdbd', 'ababdcbdbab', 'dbbaabbbbccb', 'bbbdcacdbbc', 'ccac', 'cccd', 'cbddbacacdadcd', 'adabbabdcccd', 'ccad', 'bccb']\n"
     ]
    }
   ],
   "source": [
    "# Special tokens must differ from each other and from the raw characters;\n",
    "# otherwise they collapse onto a single index in `stoi`.\n",
    "BOS = \"<bos>\"\n",
    "EOS = \"<eos>\"\n",
    "\n",
    "raw_vocab = list(\"abcd\")\n",
    "itos = [BOS, EOS] + raw_vocab\n",
    "stoi = {n: i for i, n in enumerate(itos)}\n",
    "assert len(stoi) == len(itos), \"vocabulary entries must be unique\"\n",
    "vocab_size = len(itos)  # raw characters plus BOS/EOS\n",
    "\n",
    "N = 200  # number of training strings\n",
    "valid_size = 100\n",
    "\n",
    "def sample_string(min_length, max_length):\n",
    "    \"\"\"Random string over `raw_vocab` with length in [min_length, max_length).\"\"\"\n",
    "    length = random.randrange(min_length, max_length)\n",
    "    return \"\".join([random.choice(raw_vocab) for _ in range(length)])\n",
    "\n",
    "def sample_strings(min_length, max_length, size):\n",
    "    \"\"\"List of `size` independent `sample_string` draws.\"\"\"\n",
    "    return [sample_string(min_length, max_length) for _ in range(size)]\n",
    "\n",
    "def to_tensor(name):\n",
    "    \"\"\"Encode a string as a 1 x (len + 2) LongTensor wrapped in BOS ... EOS.\"\"\"\n",
    "    indices = [stoi[BOS]] + [stoi[n] for n in name] + [stoi[EOS]]\n",
    "    return torch.tensor(indices, dtype=torch.long).unsqueeze(0)\n",
    "\n",
    "def make_dataset(lines):\n",
    "    \"\"\"Pair every string with its reversal: [(src_tensor, tgt_tensor), ...].\"\"\"\n",
    "    return [(to_tensor(line), to_tensor(line[::-1])) for line in lines]\n",
    "\n",
    "train_lines = sample_strings(3, 15, N)\n",
    "valid_lines = sample_strings(3, 15, valid_size)\n",
    "\n",
    "train_dataset = make_dataset(train_lines)\n",
    "valid_dataset = make_dataset(valid_lines)\n",
    "\n",
    "print(train_lines[:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c4dc806",
   "metadata": {},
   "source": [
    "The first part of the model is an RNN-based encoder:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ad9671f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    \"\"\"Embedding + LSTM encoder over batch-first sequences of token indices.\n",
    "\n",
    "    When `bidirectional` is True each direction gets `hidden_size // 2` units, so the\n",
    "    concatenated outputs and merged final states have width `hidden_size`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embedding_size, hidden_size, bidirectional=False):\n",
    "        super(Encoder, self).__init__()\n",
    "\n",
    "        self.embeddings = nn.Embedding(vocab_size, embedding_size)\n",
    "\n",
    "        if bidirectional:\n",
    "            # An odd size would silently lose one unit after halving and no\n",
    "            # longer match a decoder of width `hidden_size`.\n",
    "            if hidden_size % 2 != 0:\n",
    "                raise ValueError(\n",
    "                    \"hidden_size must be even for a bidirectional encoder, got {}\".format(hidden_size)\n",
    "                )\n",
    "            hidden_size //= 2\n",
    "        self.rnn = nn.LSTM(\n",
    "            embedding_size,\n",
    "            hidden_size,\n",
    "            bidirectional=bidirectional,\n",
    "            batch_first=True\n",
    "        )\n",
    "\n",
    "    def forward(self, input, hidden=None):\n",
    "        \"\"\"\n",
    "        input (LongTensor): batch x src length\n",
    "        hidden: optional initial (h_0, c_0) for the LSTM; PyTorch uses zeros when None\n",
    "        returns:\n",
    "            output (FloatTensor): batch x src length x hidden size\n",
    "            hidden_n: (h_n, c_n), each num_layers x batch x hidden size\n",
    "                (forward and backward states concatenated when bidirectional)\n",
    "        \"\"\"\n",
    "        emb = self.embeddings(input)\n",
    "        output, hidden_n = self.rnn(emb, hidden)\n",
    "        if self.rnn.bidirectional:\n",
    "            hidden_n = self._reshape_hidden(hidden_n)\n",
    "        return output, hidden_n\n",
    "\n",
    "    def _merge_tensor(self, state_tensor):\n",
    "        \"\"\"Concatenate each layer's forward and backward states along the feature dim.\n",
    "\n",
    "        Bidirectional LSTM states are stacked as (layer0 fwd, layer0 bwd, layer1 fwd, ...),\n",
    "        so even indices hold forward states and odd indices hold backward states.\n",
    "        \"\"\"\n",
    "        forward_states = state_tensor[::2]\n",
    "        backward_states = state_tensor[1::2]\n",
    "        return torch.cat([forward_states, backward_states], 2)\n",
    "\n",
    "    def _reshape_hidden(self, hidden):\n",
    "        \"\"\"\n",
    "        hidden:\n",
    "            num_layers * num_directions x batch x hidden_size // 2\n",
    "            or a tuple of these\n",
    "        returns:\n",
    "            num_layers x batch x hidden_size (or a tuple of these)\n",
    "        \"\"\"\n",
    "        assert self.rnn.bidirectional\n",
    "        if isinstance(hidden, tuple):\n",
    "            return tuple(self._merge_tensor(h) for h in hidden)\n",
    "        else:\n",
    "            return self._merge_tensor(hidden)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d964f08c",
   "metadata": {},
   "source": [
    "We also need to define a decoder. This implementation works both with and without an attention mechanism."
] }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "057b604e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    \"\"\"Embedding + LSTM decoder with an optional attention layer over the encoder outputs.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embedding_size, hidden_size, attn=None):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.embeddings = nn.Embedding(vocab_size, embedding_size)\n",
    "        self.rnn = nn.LSTM(embedding_size, hidden_size, batch_first=True)\n",
    "        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
    "        self.attn = attn\n",
    "\n",
    "    def forward(self, input, context, hidden):\n",
    "        \"\"\"\n",
    "        input (LongTensor): batch x tgt length\n",
    "        context (FloatTensor): batch x src length x hidden size\n",
    "        hidden: hidden/cell state used to initialise the decoder LSTM\n",
    "        returns:\n",
    "            logits (FloatTensor): (batch*tgt length) x vocab size\n",
    "            alignment: attention weights returned by `self.attn`,\n",
    "                or None when no attention layer is configured\n",
    "        \"\"\"\n",
    "        emb = self.embeddings(input)\n",
    "        output, hidden_n = self.rnn(emb, hidden)\n",
    "\n",
    "        alignment = None\n",
    "        # apply attention between source context and query from\n",
    "        # decoder RNN\n",
    "        if self.attn is not None:\n",
    "            output, alignment = self.attn(output, context)\n",
    "\n",
    "        # Flatten batch and time so every position is scored independently.\n",
    "        flat_output = output.contiguous().view(-1, self.rnn.hidden_size)\n",
    "        return self.output_layer(flat_output), alignment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43860e02",
   "metadata": {},
   "source": [
    "We can put them together into an encoder-decoder model class, like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e5557b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2Seq(nn.Module):\n",
    "    \"\"\"Encoder-decoder wrapper; the encoder's final state initialises the decoder.\"\"\"\n",
    "\n",
    "    def __init__(self, encoder, decoder):\n",
    "        super(Seq2Seq, self).__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def forward(self, src, tgt):\n",
    "        \"\"\"\n",
    "        src, tgt (LongTensor): (batch size x sequence length)\n",
    "        returns: the decoder's (logits, alignment) pair; logits are\n",
    "            (batch*tgt length) x output size\n",
    "        \"\"\"\n",
    "        context, enc_hidden = self.encoder(src)\n",
    "        return self.decoder(tgt, context=context, hidden=enc_hidden)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e5dbdc9",
   "metadata": {},
   "source": [
    "With our base model defined, we can write training and validation code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8959b3a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(model, train_iter, loss, optimizer):\n",
    "    \"\"\"Run one teacher-forced pass over `train_iter` and return the summed batch loss.\n",
    "\n",
    "    Note: `train_iter` is shuffled in place.\n",
    "    \"\"\"\n",
    "    epoch_loss = 0.0\n",
    "    model.train()\n",
    "    random.shuffle(train_iter)  # present examples in random order\n",
    "    for src, tgt in train_iter:\n",
    "        model.zero_grad()\n",
    "        # Teacher forcing: feed gold tokens 0..T-1 and predict tokens 1..T.\n",
    "        tgt_in = tgt[:, :-1]\n",
    "        pred, _ = model(src, tgt_in)\n",
    "        gold = tgt[:, 1:].contiguous().view(-1)\n",
    "\n",
    "        batch_loss = loss(pred, gold)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += batch_loss.item()\n",
    "    return epoch_loss\n",
    "\n",
    "\n",
    "def validate(model, data_iter):\n",
    "    \"\"\"Return teacher-forced token-level accuracy of `model` on `data_iter`.\"\"\"\n",
    "    model.eval()\n",
    "    n_correct = 0\n",
    "    n_total = 0\n",
    "    with torch.no_grad():\n",
    "        for src, tgt in data_iter:\n",
    "            tgt_in = tgt[:, :-1]\n",
    "            pred = model(src, tgt_in)[0].argmax(dim=1)\n",
    "            gold = tgt[:, 1:].contiguous().view(-1)\n",
    "            n_correct += (pred == gold).sum().item()\n",
    "            n_total += gold.size(0)\n",
    "    return n_correct / n_total\n",
    "\n",
    "\n",
    "def train(model, train, valid, epochs=30, learning_rate=0.5):\n",
    "    \"\"\"Train `model` with SGD, reporting loss and validation accuracy after every epoch.\n",
    "\n",
    "    Returns:\n",
    "        (train_losses, valid_accs): per-epoch summed training loss and validation accuracy.\n",
    "    \"\"\"\n",
    "    loss = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "    train_losses = []\n",
    "    valid_accs = []\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        print('Training epoch {}'.format(epoch))\n",
    "        train_loss = train_epoch(model, train, loss, optimizer)\n",
    "        train_losses.append(train_loss)\n",
    "        valid_acc = validate(model, valid)\n",
    "        valid_accs.append(valid_acc)\n",
    "        print('Train loss: {} ; Validation acc: {}'.format(train_loss, valid_acc))\n",
    "    return train_losses, valid_accs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32421fb3",
   "metadata": {},
   "source": [
    "Train a unidirectional model without attention:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "5cf2d5f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
"Seq2Seq(\n", " (encoder): Encoder(\n", " (embeddings): Embedding(6, 6)\n", " (rnn): LSTM(6, 64, batch_first=True)\n", " )\n", " (decoder): Decoder(\n", " (embeddings): Embedding(6, 6)\n", " (rnn): LSTM(6, 64, batch_first=True)\n", " (output_layer): Linear(in_features=64, out_features=6, bias=True)\n", " )\n", ")\n", "Training epoch 1\n", "Train loss: 321.6173617839813 ; Validation acc: 0.21871599564744287\n", "Training epoch 2\n", "Train loss: 314.5878132581711 ; Validation acc: 0.24374319912948858\n", "Training epoch 3\n", "Train loss: 311.40864992141724 ; Validation acc: 0.24483133841131666\n", "Training epoch 4\n", "Train loss: 302.0876899957657 ; Validation acc: 0.30467899891186073\n", "Training epoch 5\n", "Train loss: 280.0403906106949 ; Validation acc: 0.32861806311207836\n", "Training epoch 6\n", "Train loss: 252.06888407468796 ; Validation acc: 0.367791077257889\n", "Training epoch 7\n", "Train loss: 240.5655066370964 ; Validation acc: 0.44178454842219805\n", "Training epoch 8\n", "Train loss: 222.34247043728828 ; Validation acc: 0.49836779107725787\n", "Training epoch 9\n", "Train loss: 202.32590752840042 ; Validation acc: 0.5690968443960827\n", "Training epoch 10\n", "Train loss: 186.37792070955038 ; Validation acc: 0.5636561479869423\n", "Training epoch 11\n", "Train loss: 166.40333053469658 ; Validation acc: 0.6539717083786725\n", "Training epoch 12\n", "Train loss: 154.96789521723986 ; Validation acc: 0.6256800870511425\n", "Training epoch 13\n", "Train loss: 142.93639708310366 ; Validation acc: 0.6974972796517954\n", "Training epoch 14\n", "Train loss: 129.50957854278386 ; Validation acc: 0.750816104461371\n", "Training epoch 15\n", "Train loss: 121.79228046350181 ; Validation acc: 0.6855277475516867\n", "Training epoch 16\n", "Train loss: 114.97754304856062 ; Validation acc: 0.7094668117519043\n", "Training epoch 17\n", "Train loss: 98.03939318470657 ; Validation acc: 0.7747551686615887\n", "Training epoch 18\n", "Train loss: 93.48300758749247 ; 
Validation acc: 0.7845484221980413\n", "Training epoch 19\n", "Train loss: 90.27981984382495 ; Validation acc: 0.7899891186071817\n", "Training epoch 20\n", "Train loss: 77.29049127455801 ; Validation acc: 0.7878128400435256\n", "Training epoch 21\n", "Train loss: 78.23951551923528 ; Validation acc: 0.8324265505984766\n", "Training epoch 22\n", "Train loss: 72.50530843716115 ; Validation acc: 0.8052230685527747\n", "Training epoch 23\n", "Train loss: 65.12976488936692 ; Validation acc: 0.8128400435255713\n", "Training epoch 24\n", "Train loss: 62.22007679147646 ; Validation acc: 0.8280739934711643\n", "Training epoch 25\n", "Train loss: 60.39888472785242 ; Validation acc: 0.8465723612622416\n", "Training epoch 26\n", "Train loss: 56.14480369747616 ; Validation acc: 0.8302502720348205\n", "Training epoch 27\n", "Train loss: 54.18179729941767 ; Validation acc: 0.8302502720348205\n", "Training epoch 28\n", "Train loss: 48.06646375160199 ; Validation acc: 0.8280739934711643\n", "Training epoch 29\n", "Train loss: 39.795042789541185 ; Validation acc: 0.8389553862894451\n", "Training epoch 30\n", "Train loss: 52.13874210277572 ; Validation acc: 0.8433079434167573\n", "Training epoch 31\n", "Train loss: 47.36617806646973 ; Validation acc: 0.8291621327529923\n", "Training epoch 32\n", "Train loss: 39.336181659135036 ; Validation acc: 0.8509249183895539\n", "Training epoch 33\n", "Train loss: 36.418385865632445 ; Validation acc: 0.8258977149075082\n", "Training epoch 34\n", "Train loss: 36.226696672150865 ; Validation acc: 0.8661588683351469\n", "Training epoch 35\n", "Train loss: 25.227022941282485 ; Validation acc: 0.8705114254624592\n", "Training epoch 36\n", "Train loss: 37.27447834808845 ; Validation acc: 0.8628944504896626\n", "Training epoch 37\n", "Train loss: 32.050380388769554 ; Validation acc: 0.8607181719260065\n", "Training epoch 38\n", "Train loss: 30.29406469326932 ; Validation acc: 0.8683351468988031\n", "Training epoch 39\n", "Train loss: 27.635538705944782 
; Validation acc: 0.8574537540805223\n", "Training epoch 40\n", "Train loss: 23.632056825765176 ; Validation acc: 0.8596300326441785\n", "Training epoch 41\n", "Train loss: 30.07283775685937 ; Validation acc: 0.8628944504896626\n", "Training epoch 42\n", "Train loss: 21.734947333912714 ; Validation acc: 0.8607181719260065\n", "Training epoch 43\n", "Train loss: 16.206716568412958 ; Validation acc: 0.8737758433079434\n", "Training epoch 44\n", "Train loss: 23.5461353562132 ; Validation acc: 0.8302502720348205\n", "Training epoch 45\n", "Train loss: 23.969940600160044 ; Validation acc: 0.8639825897714908\n", "Training epoch 46\n", "Train loss: 16.859995011778665 ; Validation acc: 0.8726877040261154\n", "Training epoch 47\n", "Train loss: 20.785299404582474 ; Validation acc: 0.8824809575625681\n", "Training epoch 48\n", "Train loss: 14.41157994288369 ; Validation acc: 0.8705114254624592\n", "Training epoch 49\n", "Train loss: 16.775881042180117 ; Validation acc: 0.8487486398258978\n", "Training epoch 50\n", "Train loss: 14.489260943642876 ; Validation acc: 0.85310119695321\n" ] } ], "source": [ "hidden_size = 64\n", "# One embedding dimension per token, so the identity matrix below fits as the embedding table.\n", "embedding_size = vocab_size\n", "\n", "enc = Encoder(vocab_size, embedding_size, hidden_size)\n", "dec = Decoder(vocab_size, embedding_size, hidden_size)\n", "\n", "# Encoder and decoder share one identity matrix as their embedding weights, so every\n", "# token embeds to a one-hot vector; requires_grad_(False) keeps autograd from updating it.\n", "enc.embeddings.weight.data = dec.embeddings.weight.data = torch.eye(vocab_size)\n", "enc.embeddings.weight.requires_grad_(False)\n", "dec.embeddings.weight.requires_grad_(False)\n", "\n", "model = Seq2Seq(enc, dec)\n", "print(model)\n", "\n", "train(model, train_dataset, valid_dataset, epochs=50)" ] }, { "cell_type": "markdown", "id": "cf88e301", "metadata": {}, "source": [ "This model often manages to predict the right sequence, but it also often fails. 
Note that the decoder makes use of the *last* hidden state from the encoder, which has recently seen the final time step of the source sequence (in other words, the first element it needs to predict), but the other elements less recently. If only there were a way to make it easier for the model to focus on less recent positions...\n", "\n", "The attention mechanism is an extra layer for the decoder that can do precisely this. In this exercise, we consider a simple but effective attention mechanism called *dot product attention*.\n" ] }, { "cell_type": "code", "execution_count": 21, "id": "5dd55f2a", "metadata": {}, "outputs": [], "source": [ "class DotProdAttention(nn.Module):\n", "    \"\"\"Dot-product attention followed by a tanh projection of [c; query].\n", "\n", "    Scores are raw (unscaled) dot products between every query vector and every\n", "    context vector, normalised with a softmax over the source positions.\n", "    \"\"\"\n", "\n", "    def __init__(self, hidden_size):\n", "        super().__init__()\n", "        # Projects the concatenation [c; query] (2 * hidden_size) back to hidden_size.\n", "        self.mlp = nn.Sequential(\n", "            nn.Linear(hidden_size * 2, hidden_size),\n", "            nn.Tanh()\n", "        )\n", "\n", "    def forward(self, query, context):\n", "        \"\"\"\n", "        query: batch x tgt_length x hidden_size\n", "        context: batch x src_length x hidden_size\n", "\n", "        Returns (attn_h_t, alignment):\n", "            attn_h_t: batch x tgt_length x hidden_size, tanh(Linear([c; query]))\n", "            alignment: batch x tgt_length x src_length; each row sums to 1\n", "\n", "        Raises ValueError if query or context is not 3-D.\n", "        \"\"\"\n", "        if query.dim() != 3 or context.dim() != 3:\n", "            raise ValueError(\"query and context must be 3-D: batch x length x hidden_size\")\n", "        # (batch, tgt_len, hidden) @ (batch, hidden, src_len) -> (batch, tgt_len, src_len)\n", "        attn_scores = torch.bmm(query, context.transpose(1, 2))\n", "        # Softmax over source positions: one distribution per target step.\n", "        alignment = torch.softmax(attn_scores, dim=2)\n", "        # Attention-weighted sum of context vectors -> (batch, tgt_len, hidden)\n", "        c = torch.bmm(alignment, context)\n", "        attn_h_t = self.mlp(torch.cat([c, query], dim=2))\n", "        return attn_h_t, alignment" ] }, { "cell_type": "markdown", "id": "06d56b9e", "metadata": {}, "source": [ "Now that attention has been implemented, we can train the model:" ] }, { "cell_type": "code", "execution_count": 22, "id": "fb3e6acd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Seq2Seq(\n", " (encoder): Encoder(\n", " (embeddings): Embedding(6, 6)\n", " (rnn): LSTM(6, 32, batch_first=True, bidirectional=True)\n", " )\n", " (decoder): Decoder(\n", " (embeddings): Embedding(6, 6)\n", " (rnn): LSTM(6, 64, batch_first=True)\n", " (output_layer): Linear(in_features=64, out_features=6, bias=True)\n", " 
(attn): DotProdAttention(\n", " (mlp): Sequential(\n", " (0): Linear(in_features=128, out_features=64, bias=True)\n", " (1): Tanh()\n", " )\n", " )\n", " )\n", ")\n", "Training epoch 1\n", "Train loss: 323.3171902894974 ; Validation acc: 0.2970620239390642\n", "Training epoch 2\n", "Train loss: 307.8389209508896 ; Validation acc: 0.3558215451577802\n", "Training epoch 3\n", "Train loss: 278.0971790552139 ; Validation acc: 0.17845484221980412\n", "Training epoch 4\n", "Train loss: 244.63097456097603 ; Validation acc: 0.4047878128400435\n", "Training epoch 5\n", "Train loss: 194.5210478901863 ; Validation acc: 0.719260065288357\n", "Training epoch 6\n", "Train loss: 120.59430788457394 ; Validation acc: 0.8813928182807399\n", "Training epoch 7\n", "Train loss: 42.478179505094886 ; Validation acc: 0.9532100108813928\n", "Training epoch 8\n", "Train loss: 23.710292062081862 ; Validation acc: 0.984766050054407\n", "Training epoch 9\n", "Train loss: 13.080233365355525 ; Validation acc: 0.984766050054407\n", "Training epoch 10\n", "Train loss: 7.509345229467726 ; Validation acc: 0.9836779107725789\n", "Training epoch 11\n", "Train loss: 2.1899635198205942 ; Validation acc: 0.998911860718172\n", "Training epoch 12\n", "Train loss: 0.5873596607416403 ; Validation acc: 0.9967355821545157\n", "Training epoch 13\n", "Train loss: 0.31321665963332634 ; Validation acc: 0.9956474428726877\n", "Training epoch 14\n", "Train loss: 0.26936489673971664 ; Validation acc: 0.998911860718172\n", "Training epoch 15\n", "Train loss: 0.36754418490090757 ; Validation acc: 1.0\n", "Training epoch 16\n", "Train loss: 0.13098835148048238 ; Validation acc: 0.9978237214363439\n", "Training epoch 17\n", "Train loss: 0.07573139200758305 ; Validation acc: 1.0\n", "Training epoch 18\n", "Train loss: 0.0442355618433794 ; Validation acc: 1.0\n", "Training epoch 19\n", "Train loss: 0.03942857300717151 ; Validation acc: 1.0\n", "Training epoch 20\n", "Train loss: 0.03561293949314859 ; Validation acc: 
1.0\n", "Training epoch 21\n", "Train loss: 0.032497035406777286 ; Validation acc: 1.0\n", "Training epoch 22\n", "Train loss: 0.029893240634919493 ; Validation acc: 1.0\n", "Training epoch 23\n", "Train loss: 0.027674177041262737 ; Validation acc: 1.0\n", "Training epoch 24\n", "Train loss: 0.02576795702589152 ; Validation acc: 1.0\n", "Training epoch 25\n", "Train loss: 0.024103457562887343 ; Validation acc: 1.0\n", "Training epoch 26\n", "Train loss: 0.02264278964139521 ; Validation acc: 1.0\n", "Training epoch 27\n", "Train loss: 0.0213495231437264 ; Validation acc: 1.0\n", "Training epoch 28\n", "Train loss: 0.02019345318149135 ; Validation acc: 1.0\n", "Training epoch 29\n", "Train loss: 0.019156646067131078 ; Validation acc: 1.0\n", "Training epoch 30\n", "Train loss: 0.01822057438766933 ; Validation acc: 1.0\n" ] } ], "source": [ "# Construction order is kept as-is: each module draws its initial weights from the seeded RNG.\n", "attn = DotProdAttention(hidden_size)\n", "enc = Encoder(vocab_size, embedding_size, hidden_size, bidirectional=True)\n", "dec = Decoder(vocab_size, embedding_size, hidden_size, attn=attn)\n", "\n", "# Shared identity-matrix embeddings (one-hot tokens), excluded from gradient computation.\n", "enc.embeddings.weight.data = dec.embeddings.weight.data = torch.eye(vocab_size)\n", "enc.embeddings.weight.requires_grad_(False)\n", "dec.embeddings.weight.requires_grad_(False)\n", "\n", "attn_model = Seq2Seq(enc, dec)\n", "print(attn_model)\n", "\n", "train(attn_model, train_dataset, valid_dataset, epochs=30)" ] }, { "cell_type": "markdown", "id": "7d2b4132", "metadata": {}, "source": [ "We can also visualize the model's attention matrix:" ] }, { "cell_type": "code", "execution_count": 24, "id": "2fba04b2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQEAAAECCAYAAAD+eGJTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOK0lEQVR4nO3dbYxc5XnG8evyvthe82JMgxVsVPwBEVGU1umIQKgowrRywMVR1USgEEGDZEVKE4IiURBSab9VSoqSKmkiCwiosZwPDkkQBYIDSVGlBLGAQ/3WmCYUGwx2gARix95d790PM0hm/bLeec6ec8z9/0nWzsyO73Oxu74458yZZx0RApDXnKYDAGgWJQAkRwkAyVECQHKUAJAcJQAk15oSsL3S9v/YfsH2bQ1lOMf2j21vtb3F9s1N5JiSacD2c7YfajDDQtsbbG+3vc32JQ3luKX3fdlse73teTVu+17be2xvPuyxRbY32t7R+3hGQzm+1PvePG/7e7YXzmRmK0rA9oCkr0v6qKQLJF1n+4IGokxI+mJEXCDpYkmfbSjH4W6WtK3hDF+V9GhEfEDSHzeRx/YSSZ+X1ImICyUNSLq2xgj3SVo55bHbJD0eEedJerx3v4kcGyVdGBEflPQLSbfPZGArSkDSRZJeiIhfRsSYpO9IWl13iIjYHRHP9m6/re4P+5K6c7zD9lJJV0u6u8EMp0u6TNI9khQRYxHxm4biDEqab3tQ0oikV+racEQ8KemNKQ+vlnR/7/b9kj7WRI6IeCwiJnp3fyZp6UxmtqUElkjaedj9XWrwH58k2T5X0nJJTzUY4yuSbpU02WCGZZL2SvpW77DkbtsL6g4RES9L+rKklyTtlvTbiHis7hxTLI6I3b3br0pa3GSYnk9LemQmf6EtJdAqtk+R9F1JX4iItxrKsErSnoh4pontH2ZQ0ockfSMilkvap3p2e9+ld7y9Wt1SOlvSAtvX153jWKJ7/X2j1+DbvkPdQ9p1M/l7bSmBlyWdc9j9pb3Hamd7SN0CWBcRDzSRoedSSdfYflHdw6MrbH+7gRy7JO2KiHf2iDaoWwp1u1LSryJib0SMS3pA0kcayHG412y/X5J6H/c0FcT2jZJWSfpkzPANQW0pgaclnWd7me1hdU/4PFh3CNtW99h3W0TcVff2DxcRt0fE0og4V92vxxMRUfv/+SLiVUk7bZ/fe2iFpK1151D3MOBi2yO979MKNX/C9EFJN/Ru3yDpB02EsL1S3cPGayJi/4wHREQr/ki6St0zm/8r6Y6GMvyZurt0z0va1PtzVQu+NpdLeqjB7f+JpNHe1+X7ks5oKMc/SdouabOkf5c0t8Ztr1f3XMS4untHN0k6U91XBXZI+pGkRQ3leEHdc2rv/Mx+cyYz3RsMIKm2HA4AaAglACRHCQDJUQJAcpQAkFzrSsD2mqYzSOQ4mrZkIceRSrK0rgQkteULS44jtSULOY70nioBADWq9WKhYc+NeTr+G9DGdVBDmltTInLMRFuy1JUjTh05fo7xfRoamv4NlQNnjxVnmXxl+LifHxvbp+HhY2c5cOBNjY3t89E+N1gWbWbmaYE+7BV1bhJZ+ag/7zMycVE175M6/R9fKp7x1p3nTP+k4xh9+uvH/ByHA0BylACQHCUAJFdUAm1YIRhAmb5LoEUrBAMoULIn0IoVggGUKSmB1q0QDGDmZv06gd41zWskaZ6Of/EFgPqV7Amc0ArBEbE2IjoR0WnD1WYA3q2kBFqxQjCAMn0fDkTEhO2/k/RDdX8v3L0RsaWyZABqUXROICIelvRwRVkANIArBoHkKAEguVrfSowWq+Ctt5KkCtan8GD5j+Vrn7moeMbQR/cWz5CkuR8v/9oO/qbwdNvB3x/zU+wJAMlRAkBylACQHCUAJEcJAMlRAkBylACQHCUAJEcJAMlRAkBylACQHCUAJEcJAMlRAkBylACQHCUAJMeiIu8FVS0IUoUKsuz7qz8tntH51M+LZ7z
yiUXFMyRpYk8Fi5OULtZynL/PngCQHCUAJEcJAMlRAkBylACQXN8lYPsc2z+2vdX2Fts3VxkMQD1KXiKckPTFiHjW9qmSnrG9MSK2VpQNQA363hOIiN0R8Wzv9tuStklaUlUwAPWo5JyA7XMlLZf0VBXzANSn+IpB26dI+q6kL0TEW0f5/BpJayRpnkZKNwegYkV7AraH1C2AdRHxwNGeExFrI6ITEZ0hzS3ZHIBZUPLqgCXdI2lbRNxVXSQAdSrZE7hU0qckXWF7U+/PVRXlAlCTvs8JRMR/SWrR29cA9IMrBoHkKAEgOUoASI6Vhd4DPDxcPCPGJypIInlgoHjGv9z1teIZ//DXNxbPiJ3bi2d0BxWuCjTL2BMAkqMEgOQoASA5SgBIjhIAkqMEgOQoASA5SgBIjhIAkqMEgOQoASA5SgBIjhIAkqMEgOQoASA5SgBIjkVFmjanfBGOKhYE8Zxq1oztPP374hm3/P3nimecsqmCX4ZV1WIgbsF6vMf5T2FPAEiOEgCSowSA5CgBILniErA9YPs52w9VEQhAvarYE7hZ0rYK5gBoQOmvJl8q6WpJd1cTB0DdSvcEviLpVkmT5VEANKHvErC9StKeiHhmmuetsT1qe3RcB/vdHIBZUrIncKmka2y/KOk7kq6w/e2pT4qItRHRiYjOkOYWbA7AbOi7BCLi9ohYGhHnSrpW0hMRcX1lyQDUgusEgOQqeQNRRPxE0k+qmAWgXuwJAMlRAkBylACQXP2LipQuojF5qDxDRYs8eKB8QZA5IyPlORYtLJ5x+UNbimdI0hOXnF0847TxTeVB5s8vn1HB91eSJn/3u/IhVS1wchTsCQDJUQJAcpQAkBwlACRHCQDJUQJAcpQAkBwlACRHCQDJUQJAcpQAkBwlACRHCQDJUQJAcpQAkBwlACRX/6IiVSwK0hIxMVE+5Kwzi0es+P7Pi2f86PJlxTMkKcbKF9AYOOt9xTMm9/66fMb+/cUzTgbsCQDJUQJAcpQAkBwlACRHCQDJFZWA7YW2N9jebnub7UuqCgagHqUvEX5V0qMR8Te2hyWVL6IPoFZ9l4Dt0yVdJulGSYqIMUlj1cQCUJeSw4FlkvZK+pbt52zfbXtBRbkA1KSkBAYlfUjSNyJiuaR9km6b+iTba2yP2h4d18GCzQGYDSUlsEvSroh4qnd/g7ql8C4RsTYiOhHRGdLcgs0BmA19l0BEvCppp+3zew+tkLS1klQAalP66sDnJK3rvTLwS0l/Wx4JQJ2KSiAiNknqVBMFQBO4YhBIjhIAkqMEgOTqX1moBTwwUM2cueUveX720UeKZ/zbX64snqGJN8tnSNKh8pWjJl9/o3zGgQPFM7JgTwBIjhIAkqMEgOQoASA5SgBIjhIAkqMEgOQoASA5SgBIjhIAkqMEgOQoASA5SgBIjhIAkqMEgOQoASC5k29REbt4xMDisyoIIi3/j53FM/71+k+UB3lxS/GIgdNPK89Rkcn9+5uOkAp7AkBylACQHCUAJEcJAMkVlYDtW2xvsb3Z9nrb86oKBqAefZeA7SWSPi+pExEXShqQdG1VwQDUo/RwYFDSfNuDkkYkvVIeCUCdSn41+cuSvizpJUm7Jf02Ih6rKhiAepQcDpwhabWkZZLOlrTA9vVHed4a26O2R8d1sP+kAGZFyeHAlZJ+FRF7I2Jc0gOSPjL1SRGxNiI6EdEZUvmv7QJQrZISeEnSxbZHbFvSCknbqokFoC4l5wSekrRB0rOS/rs3a21FuQDUpOgNRBFxp6Q7K8oCoAFcMQgkRwkAyVECQHL1Lipiy0PDZSOGh4pjXPbojuIZkvTk6j8qnuGd24tnzKlgQRCPzC+eIUnx5puVzClWweIziiifcRJgTwBIjhIAkqMEgOQoASA5SgBIjhIAkqMEgOQoASA5SgBIjhIAkqMEgOQoASA5SgBIjhIAkqMEgOQoASC5WhcV8cAczTntlKIZV/7ni8U5HvnMZcUzJGnwtWo
WJyl1qIqFPKpaDGTOQPmMyUPlM5IsCFIF9gSA5CgBIDlKAEiOEgCSm7YEbN9re4/tzYc9tsj2Rts7eh/PmN2YAGbLiewJ3Cdp5ZTHbpP0eEScJ+nx3n0AJ6FpSyAinpT0xpSHV0u6v3f7fkkfqzYWgLr0e05gcUTs7t1+VdLiivIAqFnxicGICEnHvDLD9hrbo7ZHxyYPlG4OQMX6LYHXbL9fknof9xzriRGxNiI6EdEZnjOvz80BmC39lsCDkm7o3b5B0g+qiQOgbifyEuF6ST+VdL7tXbZvkvTPkv7C9g5JV/buAzgJTfsGooi47hifWlFxFgAN4IpBIDlKAEiOEgCSowSA5GpdWei8C97Wwz98omjG1R9eVZxjcO/24hmSNHnwYPmQtqyAY1czp4pVgVAr9gSA5CgBIDlKAEiOEgCSowSA5CgBIDlKAEiOEgCSowSA5CgBIDlKAEiOEgCSowSA5CgBIDlKAEiOEgCSq3VRkXFNas+hfUUz4tSR8iCvt6j7qljMo4qFSdqyuAlq16J/DQCaQAkAyVECQHKUAJDcifwuwntt77G9+bDHvmR7u+3nbX/P9sJZTQlg1pzInsB9klZOeWyjpAsj4oOSfiHp9opzAajJtCUQEU9KemPKY49FxETv7s8kLZ2FbABqUMU5gU9LeqSCOQAaUFQCtu+QNCFp3XGes8b2qO3R11+fLNkcgFnQdwnYvlHSKkmfjDj25WYRsTYiOhHROfNMXowA2qavy4Ztr5R0q6Q/j4j91UYCUKcTeYlwvaSfSjrf9i7bN0n6mqRTJW20vcn2N2c5J4BZMu2eQERcd5SH75mFLAAawEE6kBwlACRHCQDJ+Tiv7lW/MXuvpP+b5ml/IOnXNcSZDjmO1JYs5DjSdFn+MCLed7RP1FoCJ8L2aER0yNGuHFJ7spDjSCVZOBwAkqMEgOTaWAJrmw7QQ44jtSULOY7Ud5bWnRMAUK827gkAqBElACRHCQDJUQJAcpQAkNz/A2Ha74OS909WAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "string = \"abacadabacc\" # try something\n", "reversed_string = reversed(string)\n", "\n", "src = to_tensor(string)\n", "tgt = to_tensor(reversed_string)\n", "\n", "with torch.no_grad():\n", " _, alignment = attn_model(src, tgt)\n", "\n", "attn_matrix = alignment.squeeze(0).numpy()\n", "plt.matshow(attn_matrix)" ] }, { "cell_type": "code", "execution_count": null, "id": "31f9df9f", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "87029bc8", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 5 }