Pandas.DataFrame.where alignment
I was looking into pandas indexing and selection and came across the function pandas.DataFrame.where(). This function has the parameters axis and level for alignment purposes. For example:
In [21]: df = pd.DataFrame(np.random.randn(8,4), index = pd.date_range('20000101',periods=8), columns = list('ABCD'))
In [22]: df
Out[22]:
A B C D
2000-01-01 -0.222193 0.764096 -2.000947 -1.162589
2000-01-02 -0.387643 -0.497687 0.868227 -0.939663
2000-01-03 1.001708 -0.761496 0.179564 -0.403473
2000-01-04 0.469317 -0.161929 -0.844448 -0.211096
2000-01-05 0.580083 -0.952382 0.105044 -0.648209
2000-01-06 -0.312277 -0.762257 -0.894456 -1.169686
2000-01-07 -1.446776 -1.276171 -1.466667 0.800513
2000-01-08 -0.659035 -0.006725 -1.475503 0.353150
In [23]: df.where(df>0, df.A, axis = 1)
Out[23]:
A B C D
2000-01-01 NaN 0.764096 NaN NaN
2000-01-02 NaN NaN 0.868227 NaN
2000-01-03 1.001708 NaN 0.179564 NaN
2000-01-04 0.469317 NaN NaN NaN
2000-01-05 0.580083 NaN 0.105044 NaN
2000-01-06 NaN NaN NaN NaN
2000-01-07 NaN NaN NaN 0.800513
2000-01-08 NaN NaN NaN 0.353150
In [24]: df.where(df>0, df.A, axis = 0)
Out[24]:
A B C D
2000-01-01 -0.222193 0.764096 -0.222193 -0.222193
2000-01-02 -0.387643 -0.387643 0.868227 -0.387643
2000-01-03 1.001708 1.001708 0.179564 1.001708
2000-01-04 0.469317 0.469317 0.469317 0.469317
2000-01-05 0.580083 0.580083 0.105044 0.580083
2000-01-06 -0.312277 -0.312277 -0.312277 -0.312277
2000-01-07 -1.446776 -1.446776 -1.446776 0.800513
2000-01-08 -0.659035 -0.659035 -0.659035 0.353150
I don't understand how axis is used here (to be honest, I don't get the concept of alignment). I know that axis = 0 is called "column-wise" and axis = 1 is called "row-wise". Can anyone explain the results above (the alignment concept) and also show how to use the level parameter?
Example for explanation:
import numpy as np
import pandas as pd
np.random.seed(12)
df = pd.DataFrame(np.random.randn(8,4),
index = pd.date_range('20000101',periods=8),
columns = list('ABCD'))
print (df)
A B C D
2000-01-01 0.472986 -0.681426 0.242439 -1.700736
2000-01-02 0.753143 -1.534721 0.005127 -0.120228
2000-01-03 -0.806982 2.871819 -0.597823 0.472457
2000-01-04 1.095956 -1.215169 1.342356 -0.122150
2000-01-05 1.012515 -0.913869 -1.029530 1.209796
2000-01-06 0.501872 0.138846 0.640761 0.527333
2000-01-07 -1.154360 -2.213333 -1.681757 -1.788094
2000-01-08 -2.218535 -0.647431 -0.528404 -0.039209
#boolean mask by condition
print (df>0)
A B C D
2000-01-01 True False True False
2000-01-02 True False True False
2000-01-03 False True False True
2000-01-04 True False True False
2000-01-05 True False False True
2000-01-06 True True True True
2000-01-07 False False False False
2000-01-08 False False False False
#without a replacement value, positions where the mask is False become NaN
print (df.where(df>0))
A B C D
2000-01-01 0.472986 NaN 0.242439 NaN
2000-01-02 0.753143 NaN 0.005127 NaN
2000-01-03 NaN 2.871819 NaN 0.472457
2000-01-04 1.095956 NaN 1.342356 NaN
2000-01-05 1.012515 NaN NaN 1.209796
2000-01-06 0.501872 0.138846 0.640761 0.527333
2000-01-07 NaN NaN NaN NaN
2000-01-08 NaN NaN NaN NaN
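As a smaller, deterministic illustration of the same behaviour, the sketch below uses a hand-written two-by-two frame (not the random data above). It also passes a scalar as the replacement and compares the result with DataFrame.mask, which is the inverse operation; both the tiny frame and the use of mask are additions for illustration only.

```python
import pandas as pd

# Tiny hand-written frame, chosen so the result is easy to check by eye.
df = pd.DataFrame({"A": [1.0, -2.0], "B": [-3.0, 4.0]})

# No replacement given: values where the condition is False become NaN.
kept = df.where(df > 0)

# Scalar replacement: every False position receives the same value.
filled = df.where(df > 0, 0.0)

# mask() replaces where the condition is True, so the inverted
# condition gives the same frame as where().
inverse = df.mask(df <= 0, 0.0)

print(kept)
print(filled)
print(filled.equals(inverse))
```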
To replace the False positions of the mask with a Series that represents a row (indexed by the column names), axis=1 is required. The row is repeated for every row of the DataFrame (broadcasting):
print (df.loc['2000-01-01'])
A 0.472986
B -0.681426
C 0.242439
D -1.700736
Name: 2000-01-01 00:00:00, dtype: float64
print (df.where(df>0, df.loc['2000-01-01'], axis = 1))
A B C D
2000-01-01 0.472986 -0.681426 0.242439 -1.700736
2000-01-02 0.753143 -0.681426 0.005127 -1.700736
2000-01-03 0.472986 2.871819 0.242439 0.472457
2000-01-04 1.095956 -0.681426 1.342356 -1.700736
2000-01-05 1.012515 -0.681426 0.242439 1.209796
2000-01-06 0.501872 0.138846 0.640761 0.527333
2000-01-07 0.472986 -0.681426 0.242439 -1.700736
2000-01-08 0.472986 -0.681426 0.242439 -1.700736
And to replace with a Series that represents a column (indexed like the rows), you need axis=0. The column is repeated for every column of the DataFrame (broadcasting):
print (df.A)
2000-01-01 0.472986
2000-01-02 0.753143
2000-01-03 -0.806982
2000-01-04 1.095956
2000-01-05 1.012515
2000-01-06 0.501872
2000-01-07 -1.154360
2000-01-08 -2.218535
Freq: D, Name: A, dtype: float64
print (df.where(df>0, df.A, axis = 0))
A B C D
2000-01-01 0.472986 0.472986 0.242439 0.472986
2000-01-02 0.753143 0.753143 0.005127 0.753143
2000-01-03 -0.806982 2.871819 -0.806982 0.472457
2000-01-04 1.095956 1.095956 1.342356 1.095956
2000-01-05 1.012515 1.012515 1.012515 1.209796
2000-01-06 0.501872 0.138846 0.640761 0.527333
2000-01-07 -1.154360 -1.154360 -1.154360 -1.154360
2000-01-08 -2.218535 -2.218535 -2.218535 -2.218535
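The two cases can be put side by side on a small frame where the numbers are easy to follow. This is only a sketch with made-up labels (r0, r1) and made-up replacement values; the point is that the labels of the Series decide which axis it fits.

```python
import pandas as pd

df = pd.DataFrame({"A": [1.0, -2.0], "B": [-3.0, 4.0]}, index=["r0", "r1"])

# A "row" Series is labelled by the column names, so it lines up along axis=1:
# every False cell in column A gets 10.0, every False cell in column B gets 20.0.
row = pd.Series({"A": 10.0, "B": 20.0})
by_columns = df.where(df > 0, row, axis=1)

# A "column" Series is labelled by the row index, so it lines up along axis=0:
# every False cell in row r0 gets 100.0, every False cell in row r1 gets 200.0.
col = pd.Series({"r0": 100.0, "r1": 200.0})
by_rows = df.where(df > 0, col, axis=0)

print(by_columns)
print(by_rows)
```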
To replace with values from another DataFrame:
df1 = pd.DataFrame(np.random.randint(10, size=(8,4)),
index = pd.date_range('20000101',periods=8),
columns = list('ABCD'))
print (df1)
A B C D
2000-01-01 5 3 5 0
2000-01-02 2 9 6 4
2000-01-03 7 6 2 3
2000-01-04 2 6 4 5
2000-01-05 0 0 5 4
2000-01-06 0 3 7 9
2000-01-07 6 8 6 1
2000-01-08 4 9 6 5
print (df.where(df>0, df1))
A B C D
2000-01-01 0.472986 3.000000 0.242439 0.000000
2000-01-02 0.753143 9.000000 0.005127 4.000000
2000-01-03 7.000000 2.871819 2.000000 0.472457
2000-01-04 1.095956 6.000000 1.342356 5.000000
2000-01-05 1.012515 0.000000 5.000000 1.209796
2000-01-06 0.501872 0.138846 0.640761 0.527333
2000-01-07 6.000000 8.000000 6.000000 1.000000
2000-01-08 4.000000 9.000000 6.000000 5.000000
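With a DataFrame as the replacement, the match is made by row and column labels rather than by position. A minimal sketch, assuming a second frame whose rows and columns are deliberately written in a different order (all values here are invented for the illustration):

```python
import pandas as pd

df = pd.DataFrame({"A": [1.0, -2.0], "B": [-3.0, 4.0]}, index=["r0", "r1"])

# Same labels as df, but rows and columns are listed in reverse order.
other = pd.DataFrame({"B": [40.0, 30.0], "A": [20.0, 10.0]}, index=["r1", "r0"])

# Each False cell of df takes the value of `other` at the same (row, column) label.
result = df.where(df > 0, other)
print(result)
```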
EDIT:
Explanation of alignment:
Here the function where works with two objects (a DataFrame and a Series), so their labels are aligned before the replacement is applied. With axis=1 the index of the Series is matched against the columns of the DataFrame. In the example below only the labels A and B of the Series match columns of the DataFrame, so the replaced values in the other columns (C and D) are NaN.
s = pd.Series(np.random.randint(10, size=4) , index = list('ABEF'))
print (s)
A 5
B 3
E 5
F 0
dtype: int32
print (df.where(df>0, s, axis=1))
A B C D
2000-01-01 0.472986 3.000000 0.242439 NaN
2000-01-02 0.753143 3.000000 0.005127 NaN
2000-01-03 5.000000 2.871819 NaN 0.472457
2000-01-04 1.095956 3.000000 1.342356 NaN
2000-01-05 1.012515 3.000000 NaN 1.209796
2000-01-06 0.501872 0.138846 0.640761 0.527333
2000-01-07 5.000000 3.000000 NaN NaN
2000-01-08 5.000000 3.000000 NaN NaN
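The same reasoning appears to explain the NaN output in the question: df.A is labelled by dates, and with axis=1 those labels are looked up among the column names, where none of them exist. The sketch below reproduces this on a small hand-written frame; the labels r0, r1 and the Series named partial are hypothetical. The reindex call is shown only as a way to see what label matching produces, not as a claim about how pandas implements it.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"A": [1.0, -2.0], "B": [-3.0, 4.0]}, index=["r0", "r1"])

# df["A"] is labelled r0, r1. With axis=1 these labels are matched against
# the column names A, B; nothing matches, so every replacement is NaN.
mismatched = df.where(df > 0, df["A"], axis=1)

# With axis=0 the labels r0, r1 match the row index and column A is reused.
matched = df.where(df > 0, df["A"], axis=0)

# Partial overlap: only label A is shared with the columns of df.
partial = pd.Series({"A": 10.0, "E": 50.0})
aligned = partial.reindex(df.columns)  # what matching against the columns yields
partly = df.where(df > 0, partial, axis=1)

print(mismatched)
print(matched)
print(aligned)
print(partly)
```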
EDIT1:
Example with the level parameter:
If your DataFrame has a MultiIndex, you need to pass the level parameter to tell where which level of the MultiIndex the Series should be aligned on.
np.random.seed(12)
mux = pd.MultiIndex.from_arrays([pd.date_range('20000101',periods=8), list('aaaabbbb')],
names=('date', 'par'))
df = pd.DataFrame(np.random.randn(8,4),
index = mux,
columns = list('ABCD'))
print (df)
A B C D
date par
2000-01-01 a 0.472986 -0.681426 0.242439 -1.700736
2000-01-02 a 0.753143 -1.534721 0.005127 -0.120228
2000-01-03 a -0.806982 2.871819 -0.597823 0.472457
2000-01-04 a 1.095956 -1.215169 1.342356 -0.122150
2000-01-05 b 1.012515 -0.913869 -1.029530 1.209796
2000-01-06 b 0.501872 0.138846 0.640761 0.527333
2000-01-07 b -1.154360 -2.213333 -1.681757 -1.788094
2000-01-08 b -2.218535 -0.647431 -0.528404 -0.039209
s = pd.Series(np.random.randint(10, size=2) , index = list('ac'))
print (s)
a 5
c 3
dtype: int32
print (df.where(df>0, s, axis=0, level=1))
A B C D
date par
2000-01-01 a 0.472986 5.000000 0.242439 5.000000
2000-01-02 a 0.753143 5.000000 0.005127 5.000000
2000-01-03 a 5.000000 2.871819 5.000000 0.472457
2000-01-04 a 1.095956 5.000000 1.342356 5.000000
2000-01-05 b 1.012515 NaN NaN 1.209796
2000-01-06 b 0.501872 0.138846 0.640761 0.527333
2000-01-07 b NaN NaN NaN NaN
2000-01-08 b NaN NaN NaN NaN
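One way to read this output is to build the replacement column by hand: look up each row's par label in the Series and use the result row by row. The sketch below does that on a small invented frame, without the level parameter, under the assumption that this is equivalent to what level=1 does here. Label b is missing from the Series, so rows in group b get NaN, just as in the output above.

```python
import numpy as np
import pandas as pd

idx = pd.MultiIndex.from_arrays(
    [["d1", "d2", "d3", "d4"], ["a", "a", "b", "b"]], names=("date", "par")
)
df = pd.DataFrame({"A": [1.0, -2.0, -3.0, 4.0]}, index=idx)
s = pd.Series({"a": 5.0, "c": 3.0})

# Look up each row's "par" label in s; "b" is not in s, so it maps to NaN.
per_row = s.reindex(df.index.get_level_values("par"))
per_row.index = df.index  # relabel so it matches df row for row

# Now an ordinary axis=0 replacement: one value per row.
result = df.where(df > 0, per_row, axis=0)
print(result)
```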