How to apply function to 2D numpy array with multiprocessing
Suppose I have the following function:
def f(x, y):
    return x * y
How can I apply this function to each element of an NxM 2D numpy array using the multiprocessing module? Using sequential iteration, the code might look like this:
import numpy as np
N = 10
M = 12
results = np.zeros(shape=(N,M))
for x in range(N):
    for y in range(M):
        results[x, y] = f(x, y)
Here's how you can parallelize your example function using multiprocessing. I've also included a nearly identical pure Python function that uses non-parallel for loops, and a numpy one-liner that achieves the same result:
import numpy as np
from multiprocessing import Pool

def f(x, y):
    return x * y

# this helper function is needed because map() can only be used for functions
# that take a single argument (see http://stackoverflow.com/q/5442910/1461210)
def splat_f(args):
    return f(*args)

# a pool of 8 worker processes
pool = Pool(8)

def parallel(M, N):
    results = pool.map(splat_f, ((i, j) for i in range(M) for j in range(N)))
    return np.array(results).reshape(M, N)

def nonparallel(M, N):
    # integer output array, filled one element at a time
    out = np.zeros((M, N), int)
    for i in range(M):
        for j in range(N):
            out[i, j] = f(i, j)
    return out

def broadcast(M, N):
    # ogrid gives an (M, 1) column and a (1, N) row; their product broadcasts to (M, N)
    i, j = np.ogrid[:M, :N]
    return i * j
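As an aside, on Python 3.3 and later the splat_f helper above is not strictly needed, because Pool.starmap unpacks each argument tuple itself. The following is only a minimal sketch of that variant: parallel_starmap is a name made up for illustration, it takes an already created pool as an argument, and it returns nested lists rather than a numpy array.

```python
from itertools import product


def f(x, y):
    return x * y


def parallel_starmap(M, N, pool):
    """Evaluate f(i, j) for every index pair and return M rows of length N.

    `pool` is assumed to be a multiprocessing.Pool (or any object with a
    compatible starmap method); this helper is hypothetical, not part of
    the original answer.
    """
    # starmap calls f(*args) for each (i, j) tuple, so no wrapper is required
    flat = pool.starmap(f, product(range(M), range(N)))
    # cut the flat, row-major result list back into M rows
    return [flat[i * N:(i + 1) * N] for i in range(M)]
```

Passing the result to np.array gives the same (M, N) array as parallel(M, N). The per-element overhead is the same as before, so this is a convenience and not a speedup.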
Now let's look at the performance:
%timeit parallel(1000, 1000)
# 1 loops, best of 3: 1.67 s per loop
%timeit nonparallel(1000, 1000)
# 1 loops, best of 3: 395 ms per loop
%timeit broadcast(1000, 1000)
# 100 loops, best of 3: 2 ms per loop
The non-parallel pure Python version beats the parallel version by a factor of about 4, and the version using numpy array broadcasting absolutely crushes the other two.
The problem is that starting and stopping Python subprocesses, and passing arguments and results between them, carries quite a lot of overhead, and your test function is so trivial that each worker process spends only a tiny fraction of its lifetime doing useful work. Multiprocessing only makes sense if each process does a substantial amount of work before it gets killed. For example, you could give each worker a larger chunk of the output array to compute (try playing around with the chunksize= parameter to pool.map()), but with such a trivial example, I doubt you will see much improvement.
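One way to hand each worker a larger chunk is to make a whole row the unit of work, so that each task calls f once per column instead of once in total. This is a rough sketch under my own assumptions: compute_row and parallel_rows are hypothetical names, the default of 4 processes is arbitrary, and the result is a list of rows that you would still need to pass to np.array.

```python
from multiprocessing import Pool


def f(x, y):
    return x * y


def compute_row(args):
    """Compute one full output row, so a single task makes n_cols calls to f."""
    i, n_cols = args
    return [f(i, j) for j in range(n_cols)]


def parallel_rows(n_rows, n_cols, pool=None, chunksize=None):
    """Return the n_rows x n_cols result as a list of rows, one task per row.

    `pool` may be any object with a Pool-style map(); if it is omitted, a
    temporary pool of 4 worker processes is created (4 is an arbitrary
    choice for this sketch).
    """
    tasks = [(i, n_cols) for i in range(n_rows)]
    if pool is not None:
        # chunksize groups several rows into each message sent to a worker
        return pool.map(compute_row, tasks, chunksize)
    with Pool(4) as own_pool:
        return own_pool.map(compute_row, tasks, chunksize)
```

On platforms that spawn workers by re-importing the main module, call parallel_rows from under an if __name__ == "__main__": guard. With a function as cheap as x * y, this cuts the number of messages but will probably still lose to the broadcast version.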
I don't know what your real code looks like - maybe your function is large and expensive enough to warrant multiprocessing. However, I would argue that there are much better ways to improve its performance.