{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Stein Variational Gradient Descent\n", "\n", "This notebook showcases how to use Stein Variational Gradient Descent to sample \n", "from the Banana function introduced in the paper Relativistic Monte Carlo. \n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD8CAYAAACfF6SlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xd4FOX6xvHvkxCKUkIJhBKlRylSpdgIgoKigPVw9Aio\nR5SD5SioIAiKogj2LioCyk/NERWwoQiRo4gggkiLoNSINGmBJKS8vz924UQpIewum+zcn+vK5ew7\nszPvg8m9M++UNeccIiLiLVHh7oCIiJx4Cn8REQ9S+IuIeJDCX0TEgxT+IiIepPAXEfGggMPfzBLM\nbLaZLTezZWZ2h7+9kpl9YWar/P+tGHh3RUQkGCzQ6/zNrDpQ3Tn3g5mVAxYCPYG+wB/OudFmNhio\n6Jy7N9AOi4hI4ALe83fObXLO/eCf3gOsAGoCPYCJ/sUm4vtAEBGRIiDgPf8/rcysNjAHaAKsd87F\n+tsN2HHg9V/e0w/oB1C6dOlWp5xyStD6U9Tk5eURFRW5p1lUX/EVybVB5Nf3888/b3POxRXmPUEL\nfzMrC3wFjHLOvW9mO/OHvZntcM4dddw/MTHRpaamBqU/RVFKSgpJSUnh7kbIqL7iK5Jrg8ivz8wW\nOudaF+Y9QfkoNLMYYAow2Tn3vr95s/98wIHzAluCsS0REQlcMK72MeB1YIVz7sl8s6YBffzTfYCp\ngW5LRESCo0QQ1nE2cB3wk5kt9rfdB4wGks3sRmAdcHUQtiUiIkEQcPg7574G7AizOwW6fhHxhuzs\nbDZu3EhmZmbQ112hQgVWrFgR9PWeaKVLl6ZWrVrExMQEvK5g7PmLiARs48aNlCtXjtq1a+MbTQ6e\nPXv2UK5cuaCu80RzzrF9+3Y2btxInTp1Al5f5F77JCLFSmZmJpUrVw568EcKM6Ny5cpBOzJS+ItI\nkaHgP7pg/vso/EVEPEjhLyLiFx0dTfPmzWnWrBktW7Zk7ty54e5SyOiEr4iIX5kyZVi82HfF+owZ\nMxgyZAhfffVVmHsVGtrzFxE5jN27d1Oxou+JNOnp6XTq1ImWLVvStGlTpk713bO6du1aTj/9dG66\n6SYaN27MhRdeSEZGBgCvvvoqZ555Js2aNeOKK65g3759APTt25fbb7+ds846i7p16/Lee+8ddRuh\noj1/ESlyHpy+jOW/7Q7a+nJzc2maUJERlzY+6nIZGRk0b96czMxMNm3axKxZswDf9fUffPAB5cuX\nZ9u2bbRr147u3bsDsGrVKt5++21effVVrr76aqZMmcI//vEPLr/8cm666SYAhg0bxuuvv85tt90G\nwKZNm/j6669ZuXIl3bt358orrzziNkJ1ElzhLyLil3/Y59tvv6V3794sXboU5xz33Xcfc+bMISoq\nirS0NDZv3gxAnTp1aN68OQCtWrVi7dq1ACxdupRhw4axc+dO0tPT6dKly8Ht9OzZk6ioKBo1anRw\nPUfaRnx8fEhqVfiLSJFT0B56YR3PTV7t27dn27ZtbN26lU8++YStW7eycOFCYmJiqF279sHr7UuV\nKnXwPdHR0QeHffr27cuHH35Is2bNmDBhAikpKQeXy/+eA09Wnjx5
8hG3EQoa8xcROYyVK1eSm5tL\n5cqV2bVrF1WrViUmJobZs2ezbt26At+/Z88eqlevTnZ2NpMnTy5w+ePZRiC05y8i4ndgzB98e+QT\nJ04kOjqaa6+9lksvvZSmTZvSunVrTjvttALX9dBDD9G2bVvi4uJo27Yte/bsOeryx7ONQCj8RUT8\ncnNzD9tepUoVvv3228POW7p06cHpQYMGHZzu378//fv3P2T5CRMm/Ol1enp6gdsIBQ37iIh4kMJf\nRMSDFP4iUmQE6zvFI1Uw/30U/iJSJJQuXZrt27frA+AIDjzPv3Tp0kFZn074ivg559i9fQ/bf9vB\nzq27Sd+RTubeLLKzssnLzWPV6lVkrnGUObkUJ1U4idi48lSMj6VitQpER0eHu/vFXq1atdi4cSNb\nt24N+rozMzODFprhdOCbvIIhKOFvZuOBS4Atzrkm/rZKwLtAbWAtcLVzbkcwticSqO2bdpC6YDW/\nLFrLmmXr2bAyjd9/3ULmvqyjvu9TDn3IV4mYaKrVrkpCYg1ObVSLes3r0LB1XWrUi9fz6QshJiYm\nKN9QdTgpKSm0aNEiJOsuroK15z8BeB6YlK9tMPClc260mQ32v743SNsTKZQdm3fy/ec/sujLn/hp\nznJ+X+vbu3RAdmw59leOJfv0euRUOJmcsieTe1JpckuXJK9kDC46GgxuOS2HV5YZUftziMraz2tX\nNOaPTTvZsn4rv/26mQ0r0/h+xmJysn2XC8bGlafJuafTvGMTWndpRs361cP4LyDyZ0EJf+fcHDOr\n/ZfmHkCSf3oikILCX06gTb9u5qvkuXwzdQErv1sFQG6ZUmQkxJPR8Uwyq1chq2olXMlj+zLs0hVz\nyKnwvz+ZbbVr0fOytn9aJnt/NuuWbyR1/mqWzU1lyVfL+fr97wBISKzBWT3O5Lyr2tOgZV0dFUhY\nhXLMv5pzbpN/+negWgi3JQLA3l17mfV/X/P5pK8OBn5mfGXSz2nBvrq1yKpWCYIUumNnpNKzRc0/\ntcWUjKF+8zrUb16Hbv0uwDnHb7/8zoJPF/Pt9AW89+RHvDtmKjUbVOeC6zpwQZ8OVE2oEpT+iBSG\nBevMun/P/6N8Y/47nXOx+ebvcM5VPMz7+gH9AOLi4lolJycHpT9FUXp6OmXLlg13N0ImnPX9vnor\n33/wE8tmrSInK4dyCbHUOqsONdudyklVgtOnamVgc8af25rWrFCodWTszmTlf39l6cxU1v/4GxZl\n1G97Kq17NqVO64SwHQ3od7N469ix40LnXOvCvCeU4Z8KJDnnNplZdSDFOZd4tHUkJia61NTUoPSn\nKEpJSSEpKSnc3QiZE12fc47vZywmeexUFs9ehospwe7T67KrWUOy4isHbQ8f/9/IXU1zeHLp/4aI\nasaW4ZvB5x/3ajf9uplPX/+ST1+fxc4tu6jdOIErB17K+decQ8wxDkUFi343izczK3T4h3LYZxrQ\nBxjt/29ov5ZGPMM5x7fTv+etkf9h1Q9roMLJbO3Qmt3NGpBXulTBKzggN4+SO3ZRcvsuYnbsJmZX\nOiX27CN6bwbRmVlEZe7HcnKIys0DYDpQP8rIK1GCvFIxxFePZXDKXKrUqES1U+OoUT+eWv4rfkqf\nVHA/qtetxg2jruEfw6/iq3fn8t6T03n8hheZ9EAyvQZfRtcbOp7wDwHxjmBd6vk2vpO7VcxsIzAC\nX+gnm9mNwDrg6mBsS7xt8eylvD5kMivnr6ZUfEU2dz2b3Y3rQkHX2TtHiV3plNm4mTJpWyi1aRul\ntu3E8vIOLpJTphQ55U4m9+Qy7K9cgbzSJcmLKeG/2sdoXzWXeZvAsnNoX6McseSy4/edrF26nj82\n7Tx4c5KZUathdRqeWY9G7RJpcs5p1G6SQFTU4e+pLFkqhgt6d6Dzdecx/9NFTH74PZ7916skj/mQ\nPiN7cf415xzxvSLHK1hX+/z9
CLM6BWP9IhtS0xh395vM+2ghcbUqs7nrWexuXB+ijxyKtj+bk9Zt\n4uRfN3LS2t+I2eV7emJuqRiy4quwo3Uj9sdVZH/lWPZXLI8rFcPa0d2OuL6UlBReOMLQwf7M/Wz6\ndTPrV6Sx5qf1rF68hkUzf+LLt/4LQIUq5WjR+QzadG1Bm4tbUKFK+UP7a0bbi1vS5qIWfD9jMeOH\nvs1jvZ9jylMfccuTfWjWIbhfcCLepjt8pUjLSM/grZHvMeXpjylVpiTbOrRidavTcSUO/6tr2Tmc\nvHoD5Vau4aQ1aUTl5JJbMoaMU+LZcWZjMhLi2V8l9uD5gKOFfWGULF2SUxslcGqjBM69oh3gG57a\nvG4rS75azqJZP7Hw8x9JeecboqKMpuc1osNV7TnvqvaHfBCYGWd2bUGrC5sx++1veP2+yQzq+ABJ\nfzuLmx/vTZWalYPSZ/E2hb8UWXOnLeD5215n64btdOnbkZdOrkruyWUOXdA5Sv+2lfJLVlE2dS3R\n+7PJObkMu5s2IL3hKWTUqvanYaFgBX5BzIz42lWJr12VC/skkZeXx+pFa5j74QLmTJnHswNe44U7\n3qBtt5Z0veF82lzUgugS/+tnVFQUna49l3Mub0PymGm889gHfPfxD1z/8N/pPqCLHikhAVH4S5Gz\nc+suXrh9PCnvzqV2kwSGvn0n3aatPmQ5y86h3PJfif1hBaW27iAvpgTpibXZ3bgeGQnVIN84eflS\n0Sx5sOuJLOMQUVFRNGxVj4at6tFn5N/4dck6Zr45hy8nz2Hu1AXE1apMt5svoFu/zsTG/e8S0lJl\nSnHdiKvofN15PHfra7z47zeY/c7XDHz9X5x6enCe8yLeo/CXIuWbD+fz9M2vkL5zL31H9uLqe7pz\nztiUPy0TtS+T2B9WELtoJdEZWWRVrcjmC9uz5/S6uFJ/vjrmRO3lF5aZUa9Zbeo1q82Nj17DvI8W\nMv2lGUy4/x3+b9QULuidxFWDLv3TIyGq163GqI/vY9b/fc0Ld4ynf8t7uP7hv3PFnd10QlgKTeEv\nRUJGegYv/nsCn42fRf0WdRjz5QjqNDkFgM179gMQvTeDivOXUmFxKlHZOaTXS2DnmY19e/l/uaa/\nqIb+4ZSIKcE5l7XlnMvasm7FRt5/6iM+n5jCp6/NJKnX2Vw77EpOOc13J7GZ0enac2nZuSlP3zKO\ncXdPYv4nC7ln4m3E1dK5ADl22l2QsPvlx7X8q/W9zHhjNr0GX8az3446GPwAUZlZVJ6zkNrjphD7\n/XLSG5zCuht6sumKTmScEn/Ym7nOGPHZiSwhaE49vRZ3jruFt9a8wBV3XsLcqQu4qcmdjOn7PL+v\n3XJwuYrVYnng/bsZ+Fp/Vs5fzc3NBzF32oIw9lyKG4W/hI1zjk9enclt7e4jIz2TMTOHc+Mj1xy8\nsSl7fzbvP/0xtce9T6V5P7G3fgLrbuzJ5kvO812xcxS7sw7/RdzFRaX4ivQb25s3f/V9CHyVPJcb\nTruDVwZNYs8O3yWrZkbXG87npR/GUu3UOEb0HMMrgyaRk50T5t5LcaDwl7DYn7mfJ258iadufoWm\n553OSz+MpXnHJgfnz/90Ef3OGMhLd03A1azCuj6X8vulHciudOzP0qk9+GPOHj2LDxelhaKEEyI2\nrgL9xvZmws/Pcf415zLlqY/o2/B2PnrlC3JzfR9wtRpU55m5o7i0fxfee3I693QeyY7NO8Pccynq\nFP5ywm1L285dHYYzY8Jsrh12BY98ch8Vq/pCfcv6rTxw+RiGdnsEgIc/GsKsn5+i4nE+Cz9tZwZD\n3v+pWH8AAMTVqsyg8f/ixYWPcWrjWjzTfxy3tbuP1O9/AXx3Cd/+wj8Z/Obt/Pz9L/yr9b2kLjj0\nCimRAxT+ckKtnL+KAWcOZv2KNB54/276juxFdHQ0ubm5vP/0x9zY+E4Wfr6EGx+5hnFLnqDtxS
0x\nM74begFrR3c7+HN2vUrHvM2M7Fz+/e7iYn8UAFC/eR2emP0gQybfwfa0P7i93RBe/PcbZKT7Hjfa\n6dpzeWbuKKJLRHNXh+HMfuebMPdYiiqFv5wwc977loFJIyhZpiTPzB3F2T3bALB+ZRp3nns/L901\ngTM6NOLVpU/Sa/BlR32o2eSb2hf6ip5IOQowM87/+zmMX/E03W6+kA+f+5R+Zwzkhy9/AqBes9o8\nP/9REs+szyPXPM2bI/+jL0WXQyj8JeScc/zniek8dPWT1G9Zl+fmPUKdJqeQl5fHe09O55YWd7Px\n500MfvN2Hp4+hPjaVY953YX9AIiko4CTK5zM7S/8kyfnjKREyRLce8FInh3wGhl7M4mNq8Doz+/n\ngt4dmPRAMo/f+KJOBMuf6Dp/Cam8vDzGDZrElKc/5twr2zF40m2ULF2SrRu3M6bPcyyevYz23Vvz\n75f7USn+kO/6OSZrR3fjw0VpjJ2RStrODAzfd/MezYGjAOCQb+MqbpqcfRovLxrLG0PfZsrTH7Po\nyyUMfusOElvX4+43BhBfuypvjvwPO37fyf3Jd1Gm7GEekSGeoz1/CZmc7BzG9H2eKU9/TM/bLmLY\nO3dSsnRJvvlwPjc3G8jK+au569VbePCDe447+A/o2aIm3ww+n7Wju/HU35pTM7bggMvIzmVg8o/U\niYCrgkqVKcUtT/Zl7JcjyNq3nzvOGkry2Kk45+j9wNXcOe4WFn7+I/d0Hsnu7XvC3V0pAhT+EhL7\nM/fz4JWP8+Vb/6XvQ73419PXk5Odywu3j+eBy8cSX7caL/0wlotu7BT0ry488EHw9N+aUybm6A8/\ny3UOR+ScD2jesQmv/Pg4Z/Vozav3vsWwSx5l17bdXPzPToyYcje//LiOgUkj2L5pR7i7KmGm8Jeg\ny9yXxf3dRzNv+kJue/6fXDv0CrZu2MZd593Ph89/yuV3dOOZbx6mVoPju3zzWPVsUZNHL296TEcB\n4DsSGDuj+H+NaLmKZbk/eSC3v/BPFs9aSv+W97B83s+c1eNMRn08hN/XbuGuDsPZsmFbuLsqYaTw\nl6DK2JvJsEseZfGspdz9xgC6/6sLP8xcQv9W97Ih9TeGvzeI/k/1PWFfT1iYowCA33ZmFLhMcWBm\nXNq/i++yz5hoBnYYzvSXZtC8YxMe+/x+dm3dzcCkEWxetzXcXZUwUfhL0OzPyGbYJY/y05zl3Dvp\nNi7o3YH/PDGdIV0fplJ8LC8seIxzL28blr7lPwowIPoIQ001/EcJHy5K4+zRs4r9+YAGLevy4veP\n0fKCM3h2wGs8edPL1G9Zl8e+GE76jr0M6jiCLev1AeBFCn8JiqyMLP4z7GOW/ncF9755O+dc3pax\n17/AuLsncfZlbXj221EhH+YpyIGjgDWju/HE1c0OORIoExPN3V0S+XBRGkPe/4m0nRl/Oh+wMyM7\nPB0PULmKZXlo2mCuHXoFn42fxd2dHqRqQmUe++J+9uzYy6DzH2T31vRwd1NOsJCHv5l1NbNUM1tt\nZoNDvT058bL3ZzPyqidYuziNQW8MoEWnptzdeSRfTPqK3g9czf3JA4vc5YV/PRKoGVuGRy9vSs8W\nNRk7I5WM7D8/GC4jO5fNuzLD09kgiIqKou9DvRj27l38smgNt7YdQkzJEoyeMYxdW3fzf4OmsmPL\nrnB3U06gkIa/mUUDLwAXAY2Av5tZo1BuU06s3NxcHuv9HPM/WcTFdybRsHU9bm83hF8WreH+5Lu4\nbvhVQb+aJ1jyHwl8M/j8g9f7H2ncf39u3onsXkh0uKo9T84ZSW5OLnecPYw9f6Qz6uMh7NqSzuAu\nD5G+c2+4uygnSKj3/NsAq51zvzrn9gPvAD1CvE05QZxzPH/r63yV/C39xlxHpYRY/n32MLIy9vNE\nyoOcd2X7cHfxuNQ4wtVBJaMjY5S0Yat6PP/do9SoH8+wS0
ezfkUaV468iPXLNzK8x2NkZWSFu4ty\nAlgon/lhZlcCXZ1z//S/vg5o65y7Nd8y/YB+AHFxca2Sk5ND1p9wS09Pp2zZsuHuRtDMmTif/05c\nQPteLYlvUIVpj86kYs0K/O3RS4iNLx/u7h23nRnZpO3IIC/f30aUGTXLGrHly4WxZ8GVtW8/H4yc\nwS/z19Pm6jOomRjPBw9/TsOz6nDFA12JipAPO4i8v72/6tix40LnXOvCvCfsj3dwzo0DxgEkJia6\npKSk8HYohFJSUoiU+j4bP4v/TlzABX06UL9ZHV66awIJTavzTMojlKtY/P/IDjwu4redGdSILcPd\nXRKJ3bUqYv7/HdD5wk48fcs4Phs/i67XV+aWx/vw8sCJLH3/V257/sYiO2RXWJH0txcsoQ7/NCAh\n3+ta/jYpxhZ+8SNP3fwKLS84g4pxFXjprgmcfVkbzrm5RUQEP/jOB/z1mT8pKavC1JvQiS4RzV2v\n3kJ69m4+e2M27bu3pseArkx94TOq163GVQMvDXcXJURCfVy3AGhgZnXMrCTQC5gW4m1KCK1bsZGH\nrn6ShMQalK9cjuTHp9Htps7cn3wXJUqG/UBSjoOZ0eH6ttz63I3Mm76QX5eso9WFzXj1njeZO1Xf\nCxypQhr+zrkc4FZgBrACSHbOLQvlNiV0dv+xh+HdRxMVZZxUvgwp73zD34dcxh0v9yM6uuC7Z6Vo\n6zGgK0Mm38Hyb39m64ZtVDu1Co/+4xnW/LQu3F2TEAj5GR3n3CfOuYbOuXrOuVGh3p6ERm5uLqP+\n/jS//bIZ52DFvFX0G3MdN4y6JmLGhQU69jqbBz+8h9/XbGH3H+lk7s1ieM8x7P5DTwKNNJFzOl9C\nauLwd/nhiyUA7N21jztfuZmrBnUPc68kFNpe3JJHPxt28EsRfl+zhdHXPUdeXvG/z0H+R+EvBVrw\n2SLefvQDwHeCcMjkO7j4ps5h7pWE0hnnNWLMlyMoX9l3aeuCTxfx7mNTw9wrCSaFvxzVrm27ue/i\nRw6+fuD9u+nY6+ww9khOlMTW9Xgi5UEq1/B90c74of9H6ve/hLlXEiwKfzmqEZeNOTj9RMqDtLuk\nVRh7Iyda7cYJPP31w8TG+W7aG9RxRJh7JMGia/PkqJKuPptK8bH0efBvnNoooeA3SMSJr12V15Y9\nxaP/eJa9evZPxFD4y1H1vO0iet52Ubi7IWFWoUp5Rn82jFA+DkZOLA37iMgx02W9kUPhLyLiQQp/\nEREPUviLiHiQwl9ExIMU/iIiHqTwFxHxIIW/iIgHKfxFRDxI4S8i4kEKfxERD1L4i4h4UEDhb2ZX\nmdkyM8szs9Z/mTfEzFabWaqZdQmsmyIiEkyBPtVzKXA58Er+RjNrBPQCGgM1gJlm1tA5lxvg9kRE\nJAgC2vN3zq1wzqUeZlYP4B3nXJZzbg2wGmgTyLZERCR4QvU8/5rAvHyvN/rbDmFm/YB+AHFxcaSk\npISoS+GXnp6u+oqxSK4vkmuDyK/veBQY/mY2E4g/zKyhzrmAv9HZOTcOGAeQmJjokpKSAl1lkZWS\nkoLqK74iub5Irg0iv77jUWD4O+c6H8d604D83/lXy98mIiJFQKgu9ZwG9DKzUmZWB2gAzA/RtkRE\npJACvdTzMjPbCLQHPjazGQDOuWVAMrAc+AwYoCt9RESKjoBO+DrnPgA+OMK8UcCoQNYvIiKhoTt8\nRUQ8SOEvIuJBCn8REQ9S+IuIeJDCX0TEgxT+IiIepPAXEfEghb+IiAcp/EVEPEjhLyLiQQp/EREP\nUviLiHiQwl9ExIMU/iIiHqTwFxHxIIW/iIgHKfxFRDxI4S8i4kGBfofvWDNbaWZLzOwDM4vNN2+I\nma02s1Qz6xJ4V0VEJFgC3fP/AmjinDsD+BkYAmBmjYBeQGOgK/CimUUHuC0REQmSgMLfOfe5cy7H\n/3IeUMs/3QN4xzmX5Z
xbA6wG2gSyLRERCZ4SQVzXDcC7/uma+D4MDtjobzuEmfUD+gHExcWRkpIS\nxC4VLenp6aqvGIvk+iK5Noj8+o5HgeFvZjOB+MPMGuqcm+pfZiiQA0wubAecc+OAcQCJiYkuKSmp\nsKsoNlJSUlB9xVck1xfJtUHk13c8Cgx/51zno803s77AJUAn55zzN6cBCfkWq+VvExGRIiDQq326\nAvcA3Z1z+/LNmgb0MrNSZlYHaADMD2RbIiISPIGO+T8PlAK+MDOAec65W5xzy8wsGViObzhogHMu\nN8BtiYhIkAQU/s65+keZNwoYFcj6RUQkNHSHr4iIByn8RUQ8SOEvIuJBCn8REQ9S+IuIeJDCX0TE\ngxT+IiIepPAXEfEghb+IiAcp/EVEPEjhLyLiQQp/EREPUviLiHiQwl9ExIMU/iIiHqTwFxHxIIW/\niIgHKfxFRDwo0C9wf8jMlpjZYjP73Mxq+NvNzJ41s9X++S2D010REQmGQPf8xzrnznDONQc+Aob7\n2y8CGvh/+gEvBbgdEREJooDC3zm3O9/LkwHnn+4BTHI+84BYM6seyLZERCR4SgS6AjMbBfQGdgEd\n/c01gQ35Ftvob9t0mPf3w3d0QFxcHCkpKYF2qchKT09XfcVYJNcXybVB5Nd3PMw5d/QFzGYC8YeZ\nNdQ5NzXfckOA0s65EWb2ETDaOfe1f96XwL3Oue+Ptq3ExESXmppa2BqKjZSUFJKSksLdjZBRfcVX\nJNcGkV+fmS10zrUuzHsK3PN3znU+xnVNBj4BRgBpQEK+ebX8bSIiUgQEerVPg3wvewAr/dPTgN7+\nq37aAbucc4cM+YiISHgEOuY/2swSgTxgHXCLv/0T4GJgNbAPuD7A7YiISBAFFP7OuSuO0O6AAYGs\nW0REQkd3+IqIeJDCX0TEgxT+IiIepPAXEfEghb+IiAcp/EVEPEjhLyLiQQp/EREPUviLiHiQwl9E\nxIMU/iIiHqTwFxHxIIW/iIgHKfxFRDxI4S8i4kEKfxERD1L4i4h4kMJfRMSDghL+ZjbQzJyZVfG/\nNjN71sxWm9kSM2sZjO2IiEhwBBz+ZpYAXAisz9d8EdDA/9MPeCnQ7YiISPAEY8//KeAewOVr6wFM\ncj7zgFgzqx6EbYmISBCUCOTNZtYDSHPO/Whm+WfVBDbke73R37bpMOvoh+/ogLi4OFJSUgLpUpGW\nnp6u+oqxSK4vkmuDyK/veBQY/mY2E4g/zKyhwH34hnyOm3NuHDAOIDEx0SUlJQWyuiItJSUF1Vd8\nRXJ9kVwbRH59x6PA8HfOdT5cu5k1BeoAB/b6awE/mFkbIA1IyLd4LX+biIgUAcc95u+c+8k5V9U5\nV9s5Vxvf0E5L59zvwDSgt/+qn3bALufcIUM+IiISHgGN+R/FJ8DFwGpgH3B9iLYjIiLHIWjh79/7\nPzDtgAHBWreIiASX7vAVEfEghb+IiAcp/EVEPEjhLyLiQQp/EREPUviLiHiQwl9ExIMU/iIiHqTw\nFxHxIIW/iIgHKfxFRDxI4S8i4kEKfxERD1L4i4h4kMJfRMSDFP4iIh6k8BcR8SCFv4iIBwUU/mb2\ngJmlmdli/8/F+eYNMbPVZpZqZl0C76qIiARLML7D9ynn3OP5G8ysEdALaAzUAGaaWUPnXG4Qtici\nIgEK1bDYgAFpAAAGsElEQVRPD+Ad51yWc24NsBpoE6JtiYhIIQUj/G81syVmNt7MKvrbagIb8i2z\n0d8mIiJFgDnnjr6A2Uwg/jCzhgLzgG2AAx4CqjvnbjCz54F5zrm3/Ot4HfjUOffeYdbfD+gHEBcX\n1yo5OTmAcoq29PR0ypYtG+5uhIzqK74iuTaI/Po6duy40DnXujDvKXDM3znX+VhWZGavAh/5X6YB\nCflm1/K3HW7944BxAImJiS4pKelYNlcspaSkoPqKr0iuL5Jrg8iv73gEerVP9XwvLwOW
+qenAb3M\nrJSZ1QEaAPMD2ZaIiARPoFf7jDGz5viGfdYCNwM455aZWTKwHMgBBuhKHxGRoiOg8HfOXXeUeaOA\nUYGsX0REQkN3+IqIeJDCX0TEgxT+IiIepPAXEfEghb+IiAcp/EVEPEjhLyLiQQp/EREPUviLiHiQ\nwl9ExIMU/iIiHqTwFxHxIIW/iIgHKfxFRDxI4S8i4kEKfxERD1L4i4h4kMJfRMSDFP4iIh4UcPib\n2W1mttLMlpnZmHztQ8xstZmlmlmXQLcjIiLBE9AXuJtZR6AH0Mw5l2VmVf3tjYBeQGOgBjDTzBo6\n53ID7bCIiAQu0D3//sBo51wWgHNui7+9B/COcy7LObcGWA20CXBbIiISJAHt+QMNgXPNbBSQCQxy\nzi0AagLz8i230d92CDPrB/Tzv8wys6UB9qkoqwJsC3cnQkj1FV+RXBtEfn2JhX1DgeFvZjOB+MPM\nGup/fyWgHXAmkGxmdQvTAefcOGCcf1vfO+daF+b9xYnqK94iub5Irg28UV9h31Ng+DvnOh9lg/2B\n951zDphvZnn4PmHTgIR8i9byt4mISBEQ6Jj/h0BHADNrCJTEd2g1DehlZqXMrA7QAJgf4LZERCRI\nAh3zHw+M94/T7wf6+I8ClplZMrAcyAEGHOOVPuMC7E9Rp/qKt0iuL5JrA9V3CPNltYiIeInu8BUR\n8SCFv4iIBxWJ8Dezh8xsiZktNrPPzayGv93M7Fn/YyKWmFnLcPf1eJjZWP8jMJaY2QdmFptvXrF+\nDIaZXeV/tEeembX+y7xiXdsBZtbVX8NqMxsc7v4EyszGm9mW/PfUmFklM/vCzFb5/1sxnH0MhJkl\nmNlsM1vu/928w99e7Gs0s9JmNt/MfvTX9qC/vY6Zfef/HX3XzEoWuDLnXNh/gPL5pm8HXvZPXwx8\nChi+ewm+C3dfj7O+C4ES/unHgMf8042AH4FSQB3gFyA63P0tZG2n47vBJAVona+92NfmryPa3/e6\n+K5m+xFoFO5+BVjTeUBLYGm+tjHAYP/04AO/o8XxB6gOtPRPlwN+9v8+Fvsa/VlY1j8dA3znz8Zk\noJe//WWgf0HrKhJ7/s653flengwcOAvdA5jkfOYBsWZW/YR3MEDOuc+dczn+l/Pw3fcAEfAYDOfc\nCudc6mFmFfva/NoAq51zvzrn9gPv4Kut2HLOzQH++EtzD2Cif3oi0POEdiqInHObnHM/+Kf3ACvw\nPWGg2Nfoz8J0/8sY/48Dzgfe87cfU21FIvwBzGyUmW0ArgWG+5trAhvyLXbEx0QUIzfgO5qByKzv\ngEipLVLqKEg159wm//TvQLVwdiZYzKw20ALfHnJE1Ghm0Wa2GNgCfIHvyHRnvh3MY/odPWHhb2Yz\nzWzpYX56ADjnhjrnEoDJwK0nql/BUlB9/mWG4rvvYXL4elp4x1KbRA7nGzso9teAm1lZYArw77+M\nLhTrGp1zuc655vhGENoApx3PegK9yeuYuaM8JuIvJgOfACMoRo+JKKg+M+sLXAJ08v/iQTGprxD/\n7/IrFrUdg0ipoyCbzay6c26Tf2h1S4HvKMLMLAZf8E92zr3vb46oGp1zO81sNtAe35B4Cf/e/zH9\njhaJYR8za5DvZQ9gpX96GtDbf9VPO2BXvsO2YsPMugL3AN2dc/vyzYrkx2BESm0LgAb+qylK4vue\nimlh7lMoTAP6+Kf7AFPD2JeAmJkBrwMrnHNP5ptV7Gs0s7gDVwuaWRngAnznNGYDV/oXO7bawn32\n2r8TPAVYCiwBpgM1853ZfgHfmNZP5LuapDj94DvZuQFY7P95Od+8of76UoGLwt3X46jtMnxjjFnA\nZmBGpNSWr46L8V0x8gswNNz9CUI9bwObgGz//7sbgcrAl8AqYCZQKdz9DKC+c/AN6SzJ9zd3cSTU\nCJwBLPLXthQY7m+vi2/najXwH6BUQevS4x1ERDyo
SAz7iIjIiaXwFxHxIIW/iIgHKfxFRDxI4S8i\n4kEKfxERD1L4i4h40P8D2a8A3wqYVjsAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "%matplotlib inline\n",
 "import functools\n",
 "import os\n",
 "import sys\n",
 "from collections import namedtuple\n",
 "from itertools import islice\n",
 "\n",
 "# Put the directory three levels up on the import path (presumably the\n",
 "# repository root, so the pysgmcmc imports below resolve -- TODO confirm).\n",
 "sys.path.insert(0, os.path.join(os.path.abspath(\".\"), \"..\", \"..\", \"..\"))\n",
 "\n",
 "import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
 "import tensorflow as tf\n",
 "\n",
 "from pysgmcmc.diagnostics.objective_functions import (\n",
 "    banana_log_likelihood,\n",
 ")\n",
 "from pysgmcmc.samplers.svgd import SVGDSampler\n",
 "\n",
 "# Configuration {{{ #\n",
 "SEED = 1  # graph-level TensorFlow seed, so particle initialisation is repeatable\n",
 "N_PARTICLES = 10\n",
 "# }}} Configuration #\n",
 "\n",
 "ObjectiveFunction = namedtuple(\n",
 "    \"ObjectiveFunction\", [\"function\", \"dimensionality\"]\n",
 ")\n",
 "\n",
 "objective_functions = (\n",
 "    ObjectiveFunction(\n",
 "        function=banana_log_likelihood, dimensionality=2\n",
 "    ),\n",
 ")\n",
 "\n",
 "\n",
 "def cost_function(log_likelihood_function):\n",
 "    \"\"\"Return a callable that computes the negated log likelihood.\n",
 "\n",
 "    The wrapper keeps the wrapped function's ``__name__``: `plot_samples`\n",
 "    uses ``sampler.cost_fun.__name__`` as the key into `plot_functions`.\n",
 "    \"\"\"\n",
 "    @functools.wraps(log_likelihood_function)\n",
 "    def wrapped(*args, **kwargs):\n",
 "        return -log_likelihood_function(*args, **kwargs)\n",
 "    return wrapped\n",
 "\n",
 "\n",
 "# Banana Contour {{{ #\n",
 "def banana_plot():\n",
 "    \"\"\"Draw a contour of exp(banana_log_likelihood) on a new figure.\n",
 "\n",
 "    Returns the matplotlib axes so callers can draw on top of the contour.\n",
 "    \"\"\"\n",
 "    x = np.arange(-25, 25, 0.05)\n",
 "    y = np.arange(-50, 20, 0.05)\n",
 "    xx, yy = np.meshgrid(x, y, sparse=True)\n",
 "    # With sparse=True, xx holds one row of x values and yy one column of\n",
 "    # y values, so this evaluates one grid row per y value.\n",
 "    # NOTE(review): assumes banana_log_likelihood broadcasts over numpy\n",
 "    # arrays -- confirm against its definition.\n",
 "    densities = np.asarray([\n",
 "        np.exp(banana_log_likelihood((x_row, y_value)))\n",
 "        for x_row in xx for y_value in yy\n",
 "    ])\n",
 "    _, ax = plt.subplots(1)\n",
 "    ax.contour(x, y, densities, 1)\n",
 "    # Contour sets take no label; this empty line provides the legend entry.\n",
 "    ax.plot([], [], label=\"Banana\")\n",
 "    ax.legend()\n",
 "    ax.grid()\n",
 "    ax.set_ylim(-60, 20)\n",
 "    ax.set_xlim(-30, 30)\n",
 "    return ax\n",
 "\n",
 "# }}} Banana Contour #\n",
 "\n",
 "\n",
 "plot_functions = {\n",
 "    \"banana_log_likelihood\": banana_plot,\n",
 "}\n",
 "\n",
 "\n",
 "def extract_samples(sampler, n_samples=1000, keep_every=10):\n",
 "    \"\"\"Collect every `keep_every`-th item produced by iterating `sampler`.\n",
 "\n",
 "    At most ``n_samples * keep_every`` iterations are consumed. Each item is\n",
 "    unpacked as a pair; the first element is averaged over axis 0 and the\n",
 "    second is discarded.\n",
 "    NOTE(review): presumably the first element stacks one row per particle\n",
 "    and the second is the cost -- confirm against SVGDSampler.\n",
 "\n",
 "    Returns a numpy array with one entry per kept item.\n",
 "    \"\"\"\n",
 "    n_iterations = n_samples * keep_every\n",
 "    return np.asarray(\n",
 "        [np.mean(sample, axis=0) for sample, _ in\n",
 "         islice(sampler, 0, n_iterations, keep_every)]\n",
 "    )\n",
 "\n",
 "\n",
 "def plot_samples(sampler, n_samples=5000, keep_every=10):\n",
 "    \"\"\"Scatter samples drawn from `sampler` over its target's contour plot.\n",
 "\n",
 "    The contour plot is looked up in `plot_functions` under\n",
 "    ``sampler.cost_fun.__name__``. Returns the axes that were drawn on.\n",
 "    \"\"\"\n",
 "    samples = extract_samples(\n",
 "        sampler, n_samples=n_samples, keep_every=keep_every\n",
 "    )\n",
 "    ax = plot_functions[sampler.cost_fun.__name__]()\n",
 "\n",
 "    # One row per sample, one column per coordinate.\n",
 "    coordinates = samples.reshape(len(samples), -1)\n",
 "    if coordinates.shape[1] == 1:\n",
 "        # A single coordinate cannot be scattered on its own, so plot it\n",
 "        # against exp(-cost) instead.\n",
 "        densities = np.exp([-sampler.cost_fun(sample) for sample in samples])\n",
 "        ax.scatter(coordinates[:, 0], densities)\n",
 "    else:\n",
 "        ax.scatter(*coordinates.T)\n",
 "    return ax\n",
 "\n",
 "\n",
 "for function, dimensionality in objective_functions:\n",
 "    tf.reset_default_graph()\n",
 "    graph = tf.Graph()\n",
 "\n",
 "    # Entering the session also makes `graph` the default graph, so the seed\n",
 "    # and the particle variables below are created in it.\n",
 "    with tf.Session(graph=graph) as session:\n",
 "        tf.set_random_seed(SEED)\n",
 "        particles = [\n",
 "            tf.get_variable(\"particle_{}\".format(i), (dimensionality,), initializer=tf.random_normal_initializer())\n",
 "            for i in range(N_PARTICLES)\n",
 "        ]\n",
 "        sampler = SVGDSampler(\n",
 "            particles=particles,\n",
 "            cost_fun=cost_function(function),\n",
 "            session=session,\n",
 "            dtype=tf.float32\n",
 "        )\n",
 "        plot_samples(sampler, n_samples=5000)\n",
 "        plt.show()"
 ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }