NumPy intersect1d on arrays whose elements are matrices

I have two arrays, one of shape (200000, 28, 28) and the other of shape (10000, 28, 28), so they are essentially two arrays whose elements are matrices. Now I want to count and get all the elements (in the form (N, 28, 28)) that appear in both arrays. With normal for loops this is way too slow, so I tried NumPy's intersect1d method, but I don't know how to apply it to these kinds of arrays.
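For reference, the loop-based approach described above might look like the sketch below (the name `intersect_naive` is hypothetical, not from the question). Each matrix of `b` is compared against every matrix of `a`, so the work grows with len(a) × len(b), about 2·10^9 matrix comparisons at the stated sizes, which is why it is too slow. It can still serve as a correctness check for faster versions on small inputs.

```python
import numpy as np

def intersect_naive(a, b):
    """Hypothetical brute-force baseline: matrices that occur in both a and b.

    Loops over b and compares each matrix with all of a, so the cost is
    O(len(a) * len(b)) matrix comparisons.
    """
    inner_axes = tuple(range(1, a.ndim))
    found = []
    for m in b:
        # True if some matrix in `a` equals `m` elementwise
        if (a == m).all(axis=inner_axes).any():
            # keep each matching matrix only once
            if not any(np.array_equal(m, f) for f in found):
                found.append(m)
    if not found:
        return np.empty((0,) + a.shape[1:], dtype=a.dtype)
    return np.stack(found)

# Small stand-in data: 2x2 matrices instead of 28x28
rng = np.random.default_rng(0)
a = rng.integers(0, 2, size=(200, 2, 2))
b = rng.integers(0, 2, size=(20, 2, 2))
common = intersect_naive(a, b)
print(len(common), common.shape)
```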



1 answer


Using the approach from this question about finding unique rows:

import numpy as np

def intersect_along_first_axis(a, b):
    # check that casting to void will create equal size elements
    assert a.shape[1:] == b.shape[1:]
    assert a.dtype == b.dtype

    # compute dtypes
    void_dt = np.dtype((np.void, a.dtype.itemsize * np.prod(a.shape[1:])))
    orig_dt = np.dtype((a.dtype, a.shape[1:]))

    # convert to 1d void arrays
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    a_void = a.reshape(a.shape[0], -1).view(void_dt)
    b_void = b.reshape(b.shape[0], -1).view(void_dt)

    # intersect, then convert back
    return np.intersect1d(b_void, a_void).view(orig_dt)
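As a usage sketch, the variant below applies the same void-view idea (the name `intersect_rows_void` is hypothetical). It differs from the function above in two ways that are assumptions of this sketch rather than part of the answer: it converts the result back with `.view(a.dtype).reshape(...)` instead of a subarray dtype, and it also returns the count the question asks for. Small 2×2 matrices stand in for the 28×28 ones.

```python
import numpy as np

def intersect_rows_void(a, b):
    """Hypothetical sketch: unique matrices present in both a and b, plus their count.

    Each matrix is reinterpreted as one opaque byte string (a void scalar),
    so np.intersect1d can compare whole matrices in a single vectorized call.
    """
    assert a.shape[1:] == b.shape[1:] and a.dtype == b.dtype
    inner = a.shape[1:]
    n_elem = int(np.prod(inner))
    # one void element spans all the bytes of one matrix
    void_dt = np.dtype((np.void, a.dtype.itemsize * n_elem))
    # the byte view requires each matrix to be stored contiguously
    a_void = np.ascontiguousarray(a).reshape(len(a), n_elem).view(void_dt).ravel()
    b_void = np.ascontiguousarray(b).reshape(len(b), n_elem).view(void_dt).ravel()
    common = np.intersect1d(a_void, b_void)
    # reinterpret the bytes as the original dtype and restore the matrix shape
    result = common.view(a.dtype).reshape((-1,) + inner)
    return result, len(result)

rng = np.random.default_rng(0)
a = rng.integers(0, 2, size=(1000, 2, 2))
b = rng.integers(0, 2, size=(50, 2, 2))
common, count = intersect_rows_void(a, b)
print(count, common.shape)
```

Because `np.intersect1d` sorts and deduplicates, each shared matrix appears once, in an order determined by sorting the raw bytes rather than the order of either input.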

      



Note that using void is unsafe with floats, as it makes -0.0 compare unequal to 0.0.
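A small sketch of the -0.0 problem, assuming float64 data; `as_void_rows` is a hypothetical helper that repeats the byte view used above. Two matrices that are equal under `==` differ in their sign bits, so the void comparison finds no overlap. One possible workaround, which is an assumption and not part of the answer, is to add `0.0` before taking the view: under IEEE 754 round-to-nearest, `-0.0 + 0.0` is `+0.0`, and other non-NaN values are unchanged.

```python
import numpy as np

def as_void_rows(x):
    """Hypothetical helper: view each matrix in x as one opaque byte string."""
    x = np.ascontiguousarray(x)
    n_elem = int(np.prod(x.shape[1:]))
    void_dt = np.dtype((np.void, x.dtype.itemsize * n_elem))
    return x.reshape(len(x), n_elem).view(void_dt).ravel()

pos = np.zeros((1, 2, 2))   # one 2x2 matrix of +0.0
neg = -pos                  # the same matrix with every entry -0.0

# Numerically the matrices are equal ...
values_equal = np.array_equal(pos, neg)
# ... but their bytes differ in the sign bit, so the void view finds no overlap.
matches_raw = np.intersect1d(as_void_rows(pos), as_void_rows(neg)).size

# Workaround (assumption, not from the answer): adding +0.0 maps -0.0 to +0.0
matches_fixed = np.intersect1d(as_void_rows(pos + 0.0), as_void_rows(neg + 0.0)).size

print(values_equal, matches_raw, matches_fixed)  # True 0 1
```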







