Unified way to use array indexing for 0D and 1D numpy arrays
(Note: the original question was slightly different, and a different answer applies to it; see the revision history for the original question.)
Is there a uniform way to index numpy arrays when those arrays can also be scalar?
I am trying to write a function that accepts a float, a list of floats, or a 0-D/1-D numpy array. To handle all of these uniformly, I use numpy.asarray(), which works well overall (I don't mind getting a numpy.float64 back when the input is a standard Python float).
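To see what numpy.asarray() does with each kind of input, you can check the number of dimensions it produces. The sketch below assumes a reasonably recent numpy and uses only the input kinds mentioned above. The 0-d case is the one that causes trouble later.

```python
import numpy as np

# A plain float, a list of floats, a NumPy scalar and a 1-D array.
inputs = [5.5, [5.5], [4.5, 5.5], np.float64(5.5), np.array([4.5, 5.5])]

# numpy.asarray() returns an ndarray for every input; only ndim differs.
arrays = [np.asarray(item) for item in inputs]
dims = [arr.ndim for arr in arrays]
print(dims)  # [0, 1, 1, 0, 1]
```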
Problems arise when I combine conditional (masked) assignment with an intermediate array function, for example:
import numpy as np
value = np.asarray(5.5)
mask = value > 5
tmpvalue = np.asarray(np.cos(value))
tmpvalue[mask] = value
This will throw an exception:
Traceback (most recent call last):
  File "testscalars.py", line 27, in <module>
    tmpvalue[mask] = value
IndexError: 0-d arrays can't be indexed
Is there an elegant solution for this?
It turns out this issue is specific to numpy 1.8 and earlier; upgrading to numpy 1.9 (1.9.2) fixes it.
The numpy 1.9 release notes state:
Boolean indexing into scalar arrays will always return a new 1-d array. This means that array(1)[array(True)] gives array([1]) and not the original array.
This conveniently turns tmpvalue[mask] into a temporary 1-D array, so that value can be assigned to it:
tmpvalue[mask] = value
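Here is a minimal check of the numpy 1.9+ behavior quoted above, assuming numpy 1.9 or later is installed. It repeats the release-note example and then runs the masked assignment from the question, which raised IndexError on numpy 1.8.

```python
import numpy as np

# Release-note example: boolean indexing of a 0-d array yields a new 1-D array.
selected = np.array(1)[np.array(True)]
print(repr(selected))  # array([1])

# The pattern from the question; on numpy >= 1.9 the assignment succeeds.
value = np.asarray(5.5)
mask = value > 5
tmpvalue = np.asarray(np.cos(value))
tmpvalue[mask] = value
print(repr(tmpvalue))  # array(5.5)
```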
The problem I originally asked about, which was what actually hit me and caused the TypeErrors, is essentially this:
import numpy
value = numpy.asarray(5.5)
mask = value > 5
tmpvalue = numpy.cos(value)
tmpvalue[mask] = value[mask]  # fails: tmpvalue is a numpy.float64 scalar, not an array
The problem is that value has type numpy.ndarray, but because it is a 0-d array, numpy.cos returns a numpy scalar (numpy.float64) instead of an array, and that scalar cannot be assigned into by index.
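You can check the scalar return directly. This sketch assumes a current numpy, where item assignment on a NumPy scalar raises TypeError.

```python
import numpy as np

value = np.asarray(5.5)
mask = value > 5

result = np.cos(value)
# A ufunc applied to a 0-d array returns a NumPy scalar, not an ndarray.
print(type(value).__name__, type(result).__name__)  # ndarray float64

try:
    result[mask] = value[mask]
    assignment_failed = False
except TypeError as exc:
    # NumPy scalars are immutable, so item assignment is rejected.
    assignment_failed = True
    print("TypeError:", exc)
```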
I think this numpy issue is directly related to the problem.
For now, the simplest solution seems to be wrapping the results of numpy ufuncs in numpy.asarray:
import numpy
value = numpy.asarray(5.5)
mask = value > 5
tmpvalue = numpy.asarray(numpy.cos(value))
tmpvalue[mask] = value[mask]
I successfully tested it with the inputs 5.5, 4.5, [5.5], [4.5] and [4.5, 5.5].
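The same workaround can be packaged as a function and run on the inputs listed above. The sketch below assumes numpy 1.9 or later. The helper name `cos_keep_large` and its `threshold` parameter are hypothetical and exist only for illustration. The body is the `numpy.asarray` wrapping shown above.

```python
import numpy as np

def cos_keep_large(x, threshold=5):
    """Hypothetical helper: elementwise cos(x), but keep x where x > threshold.

    Accepts a float, a list of floats, or a 0-D/1-D array.
    """
    value = np.asarray(x)
    mask = value > threshold
    # np.cos returns a NumPy scalar for 0-d input; re-wrap it so it is indexable.
    tmpvalue = np.asarray(np.cos(value))
    tmpvalue[mask] = value[mask]
    return tmpvalue

for x in (5.5, 4.5, [5.5], [4.5], [4.5, 5.5]):
    print(x, "->", cos_keep_large(x))
```

Scalar inputs come back as 0-d arrays and list inputs as 1-D arrays, so callers can index the result the same way in both cases.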
Note that this behavior also applies to even more common operations, such as addition:
>>> import numpy
>>> x = numpy.asarray(5)
>>> y = numpy.asarray(6)
>>> z = x + y
>>> type(x), type(y), type(z)
(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'numpy.int64'>)
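The exact integer scalar type (numpy.int64 in the session above) may depend on the platform. This sketch therefore checks only that the sum is a NumPy integer scalar rather than an ndarray. It also shows that wrapping the result in numpy.asarray gives back a 0-d array.

```python
import numpy as np

x = np.asarray(5)
y = np.asarray(6)
z = x + y

# Arithmetic on two 0-d arrays also returns a NumPy scalar.
print(isinstance(z, np.ndarray), isinstance(z, np.integer))  # False True

# Wrapping the result in np.asarray gives back a 0-d array.
z_arr = np.asarray(z)
print(type(z_arr).__name__, z_arr.ndim)  # ndarray 0
```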