More efficient way to multiply each column of a 2D matrix by each slice of a 3D matrix
I have an 8x8x25000 array W and an 8x25000 array r. I want to multiply each 8x8 slice of W by the corresponding column (8x1) of r and store the result in Wres, which will end up as an 8x25000 matrix.
I currently do this with a for loop:
for i in range(0,25000):
    Wres[:,i] = np.matmul(W[:,:,i],r[:,i])
But this is slow and I hope there is a faster way to accomplish this.
Any ideas?
matmul can broadcast over a stack of matrices as long as the two arrays have the same length along the stacking axis. From the docs:
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
Thus, starting from these example arrays:
import numpy as np
a = np.random.rand(8,8,100)
b = np.random.rand(8, 100)
you must perform 2 operations before calling matmul:
- transpose a and b so that the first axis holds the 100 slices
- add an extra dimension to b so that b.shape = (100, 8, 1)
Then:
at = a.transpose(2, 0, 1) # swap to shape 100, 8, 8
bt = b.T[..., None] # swap to shape 100, 8, 1
c = np.matmul(at, bt)
c now has shape (100, 8, 1); to go back to (8, 100):
c = np.squeeze(c).swapaxes(0, 1)
or
c = np.squeeze(c).T
And finally, a one-liner for convenience:
c = np.squeeze(np.matmul(a.transpose(2, 0, 1), b.T[..., None])).T
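As a sanity check, the batched matmul can be compared against the explicit loop from the question. The sketch below is only an illustration: the seed and the variable name `expected` are arbitrary choices, not part of the original answer. It also swaps np.squeeze for indexing with [..., 0], which drops only the trailing length-1 axis; np.squeeze removes every length-1 axis, so it would also collapse the slice axis if there happened to be a single slice.

```python
import numpy as np

rng = np.random.default_rng(0)      # arbitrary seed, for reproducibility only
a = rng.random((8, 8, 100))
b = rng.random((8, 100))

# Reference result: one matrix-vector product per slice, as in the question
expected = np.empty((8, 100))
for i in range(a.shape[2]):
    expected[:, i] = a[:, :, i] @ b[:, i]

at = a.transpose(2, 0, 1)           # shape (100, 8, 8)
bt = b.T[..., None]                 # shape (100, 8, 1)

# Batched product, then drop the trailing axis and transpose back to (8, 100)
c = np.matmul(at, bt)[..., 0].T

print(c.shape)
print(np.allclose(c, expected))
```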
An alternative to using np.matmul is np.einsum, which does the job in one short and arguably more readable line of code, without method chaining.
Example arrays:
np.random.seed(123)
w = np.random.rand(8,8,25000)
r = np.random.rand(8,25000)
wres = np.einsum('ijk,jk->ik',w,r)
# a quick check on result equivalency to your loop
print(np.allclose(np.matmul(w[:, :, 1], r[:, 1]), wres[:, 1]))
True
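To unpack the subscript string 'ijk,jk->ik': the index j appears in both inputs but not in the output, so it is summed over, while k appears in both inputs and in the output, so slices are paired up rather than summed. A rough sketch of the same computation written as explicit loops, using small made-up arrays (the names w_small and r_small and the 3x3x4 shape are chosen here purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(1)      # arbitrary seed
w_small = rng.random((3, 3, 4))     # 4 slices, each 3x3
r_small = rng.random((3, 4))        # 4 columns, each of length 3

# out[i, k] = sum over j of w_small[i, j, k] * r_small[j, k]
out = np.zeros((3, 4))
for i in range(3):
    for k in range(4):
        for j in range(3):
            out[i, k] += w_small[i, j, k] * r_small[j, k]

print(np.allclose(out, np.einsum('ijk,jk->ik', w_small, r_small)))
```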
The timing is equivalent to @Imanol's solution, so take your pick. Both are roughly 25 to 30 times faster than the loop. einsum is competitive here because of the size of the arrays. With arrays larger than these it will likely win, and with smaller arrays it will likely lose. For details see.
def solution1():
    return np.einsum('ijk,jk->ik',w,r)
def solution2():
    return np.squeeze(np.matmul(w.transpose(2, 0, 1), r.T[..., None])).T
def solution3():
    Wres = np.empty((8, 25000))
    for i in range(0,25000):
        Wres[:,i] = np.matmul(w[:,:,i],r[:,i])
    return Wres
%timeit solution1()
100 loops, best of 3: 2.51 ms per loop
%timeit solution2()
100 loops, best of 3: 2.52 ms per loop
%timeit solution3()
10 loops, best of 3: 64.2 ms per loop
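%timeit is an IPython magic, so the timings above need IPython or Jupyter. Outside of those, a comparable measurement can be sketched with the standard-library timeit module. The repeat and number values and the helper name best_ms below are arbitrary choices for illustration, and absolute timings will differ from machine to machine.

```python
import timeit
import numpy as np

np.random.seed(123)
w = np.random.rand(8, 8, 25000)
r = np.random.rand(8, 25000)

def solution1():
    return np.einsum('ijk,jk->ik', w, r)

def solution2():
    return np.squeeze(np.matmul(w.transpose(2, 0, 1), r.T[..., None])).T

def solution3():
    Wres = np.empty((8, 25000))
    for i in range(25000):
        Wres[:, i] = np.matmul(w[:, :, i], r[:, i])
    return Wres

def best_ms(func, repeat=3, number=5):
    # Best average time per call, in milliseconds, over `repeat` runs
    times = timeit.repeat(func, repeat=repeat, number=number)
    return min(times) / number * 1e3

for func in (solution1, solution2, solution3):
    print(f"{func.__name__}: {best_ms(func):.2f} ms per call")
```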
Credit: @Divakar