From 1807b436bc0e835e4e0e43a1ddd8e27c1b59cd6c Mon Sep 17 00:00:00 2001 From: Ulrich <ulrich.kerzel@rwth-aachen.de> Date: Wed, 16 Nov 2022 17:10:54 +0100 Subject: [PATCH] Numba example with 2D random walk --- pythonintro/solutions/Solution_Numba.ipynb | 149 +++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 pythonintro/solutions/Solution_Numba.ipynb diff --git a/pythonintro/solutions/Solution_Numba.ipynb b/pythonintro/solutions/Solution_Numba.ipynb new file mode 100644 index 0000000..7dd4058 --- /dev/null +++ b/pythonintro/solutions/Solution_Numba.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fast Python with Numba\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "metadata": {}, + "outputs": [], + "source": [ + "# all imports\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl\n", + "import seaborn as sns\n", + "import numba\n", + "from numba import jit\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.56.4\n" + ] + } + ], + "source": [ + "print(numba.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": 168, + "metadata": {}, + "outputs": [], + "source": [ + "@jit(nopython=True)\n", + "def random_walk(n_steps = 5000, step_size = 1):\n", + "\n", + " # we always start at (0,0)\n", + " x_points = [0]\n", + " y_points = [0]\n", + "\n", + " # do the random walk:\n", + " for i in range(0, n_steps):\n", + " # choose direction:\n", + " # the following is the same as np.random.choice([-1,1]) but this cannot be optimized with Numba\n", + " x_dir = np.round(2*(np.random.randint(0,2)-0.5)) \n", + " y_dir = np.round(2*(np.random.randint(0,2)-0.5))\n", + "\n", + " # calculate new positions: last position + step_size * direction\n", + " new_x = x_points[-1] + step_size * x_dir\n", + " new_y = y_points[-1] + step_size * y_dir\n", + "\n", + " # append to arrays\n", + " x_points.append(new_x)\n", + " y_points.append(new_y)\n", + "\n", + " # calculate distance between start and end as Eucledian distance\n", + " # bit explicit as numba does not work with the one line we have used before\n", + " x_start = x_points[0]\n", + " y_start = y_points[0] \n", + " x_stop = x_points[-1] \n", + " y_stop = y_points[-1]\n", + " distance2 = (x_stop - x_start )**2 + ( y_stop - y_start )**2\n", + " distance = np.sqrt( distance2)\n", + " \n", + "\n", + " return x_points, y_points, distance\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can compare the timings with and without the ```@jit``` decorator. \\\n", + "Remember that decorators change the behaviour of the function - but we do not have to change the function itself.\n", + "\n", + "In this case, Numba is a specialised package that optimises a function \"behind the scenes\".\n", + "\n", + "Note that the first call includes the optimisation / compile time. If we want to measure the time the optimised function takes, we need to discard the timing from the first call.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 170, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 33.4 ms, sys: 0 ns, total: 33.4 ms\n", + "Wall time: 33 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "distances = []\n", + "\n", + "for i in range(0,200):\n", + " _, _, distance = random_walk()\n", + " distances.append(distance)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.6 ('pythonintro-A08DqnGu-py3.10')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "eb3a944512c48027e49906d9e47d21cfb9b2c1ddd33f9b79c72df4b5e3a553dc" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} -- GitLab