Pandas.DataFrame.where alignment

I was looking into pandas indexing and selection and came across the function pandas.DataFrame.where(). This function has axis and level parameters for alignment purposes. For example:

In [21]: df = pd.DataFrame(np.random.randn(8,4), index = pd.date_range('20000101',periods=8), columns = list('ABCD'))

In [22]: df
Out[22]: 
                   A         B         C         D
2000-01-01 -0.222193  0.764096 -2.000947 -1.162589
2000-01-02 -0.387643 -0.497687  0.868227 -0.939663
2000-01-03  1.001708 -0.761496  0.179564 -0.403473
2000-01-04  0.469317 -0.161929 -0.844448 -0.211096
2000-01-05  0.580083 -0.952382  0.105044 -0.648209
2000-01-06 -0.312277 -0.762257 -0.894456 -1.169686
2000-01-07 -1.446776 -1.276171 -1.466667  0.800513
2000-01-08 -0.659035 -0.006725 -1.475503  0.353150

In [23]: df.where(df>0, df.A, axis = 1)
Out[23]: 
                   A         B         C         D
2000-01-01       NaN  0.764096       NaN       NaN
2000-01-02       NaN       NaN  0.868227       NaN
2000-01-03  1.001708       NaN  0.179564       NaN
2000-01-04  0.469317       NaN       NaN       NaN
2000-01-05  0.580083       NaN  0.105044       NaN
2000-01-06       NaN       NaN       NaN       NaN
2000-01-07       NaN       NaN       NaN  0.800513
2000-01-08       NaN       NaN       NaN  0.353150

In [24]: df.where(df>0, df.A, axis = 0)
Out[24]: 
                   A         B         C         D
2000-01-01 -0.222193  0.764096 -0.222193 -0.222193
2000-01-02 -0.387643 -0.387643  0.868227 -0.387643
2000-01-03  1.001708  1.001708  0.179564  1.001708
2000-01-04  0.469317  0.469317  0.469317  0.469317
2000-01-05  0.580083  0.580083  0.105044  0.580083
2000-01-06 -0.312277 -0.312277 -0.312277 -0.312277
2000-01-07 -1.446776 -1.446776 -1.446776  0.800513
2000-01-08 -0.659035 -0.659035 -0.659035  0.353150

      

I didn't understand the use of axis here (to be honest, I didn't get the concept of alignment). I know that axis = 0 is called "column-wise" and axis = 1 is "row-wise". Can anyone explain the result (the alignment concept) and also show how to use the level parameter?

1 answer


Example for explanation:

import numpy as np
import pandas as pd

np.random.seed(12)
df = pd.DataFrame(np.random.randn(8,4), 
                  index = pd.date_range('20000101',periods=8), 
                  columns = list('ABCD'))
print (df)
                   A         B         C         D
2000-01-01  0.472986 -0.681426  0.242439 -1.700736
2000-01-02  0.753143 -1.534721  0.005127 -0.120228
2000-01-03 -0.806982  2.871819 -0.597823  0.472457
2000-01-04  1.095956 -1.215169  1.342356 -0.122150
2000-01-05  1.012515 -0.913869 -1.029530  1.209796
2000-01-06  0.501872  0.138846  0.640761  0.527333
2000-01-07 -1.154360 -2.213333 -1.681757 -1.788094
2000-01-08 -2.218535 -0.647431 -0.528404 -0.039209

      


#boolean mask by condition
print (df>0)
                A      B      C      D
2000-01-01   True  False   True  False
2000-01-02   True  False   True  False
2000-01-03  False   True  False   True
2000-01-04   True  False   True  False
2000-01-05   True  False  False   True
2000-01-06   True   True   True   True
2000-01-07  False  False  False  False
2000-01-08  False  False  False  False

#without a replacement value, positions where the mask is False become NaN
print (df.where(df>0))
                   A         B         C         D
2000-01-01  0.472986       NaN  0.242439       NaN
2000-01-02  0.753143       NaN  0.005127       NaN
2000-01-03       NaN  2.871819       NaN  0.472457
2000-01-04  1.095956       NaN  1.342356       NaN
2000-01-05  1.012515       NaN       NaN  1.209796
2000-01-06  0.501872  0.138846  0.640761  0.527333
2000-01-07       NaN       NaN       NaN       NaN
2000-01-08       NaN       NaN       NaN       NaN
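
As a minimal sketch of the same default on a tiny frame with fixed values (the frame and the names with_nan and with_zero are made up for illustration), the second argument of where can also be a plain scalar:

```python
import pandas as pd

# Small illustrative frame, not the data from the answer.
df = pd.DataFrame({"A": [1.5, -2.0], "B": [-0.5, 3.0]})

# Positions where the mask is False become NaN by default ...
with_nan = df.where(df > 0)

# ... or take the scalar passed as the second argument.
with_zero = df.where(df > 0, 0.0)

print(with_nan)
print(with_zero)
```

No axis is needed here because a scalar has no labels to align.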

      

To replace the False positions of the mask with a Series taken from a row, axis=1 is required: the Series index is matched against the column labels and the row is repeated down the DataFrame (broadcasting):

print (df.loc['2000-01-01'])
A    0.472986
B   -0.681426
C    0.242439
D   -1.700736
Name: 2000-01-01 00:00:00, dtype: float64

print (df.where(df>0, df.loc['2000-01-01'], axis = 1))
                   A         B         C         D
2000-01-01  0.472986 -0.681426  0.242439 -1.700736
2000-01-02  0.753143 -0.681426  0.005127 -1.700736
2000-01-03  0.472986  2.871819  0.242439  0.472457
2000-01-04  1.095956 -0.681426  1.342356 -1.700736
2000-01-05  1.012515 -0.681426  0.242439  1.209796
2000-01-06  0.501872  0.138846  0.640761  0.527333
2000-01-07  0.472986 -0.681426  0.242439 -1.700736
2000-01-08  0.472986 -0.681426  0.242439 -1.700736

      

And to replace with a Series taken from a column, you need axis=0: the Series index is matched against the row labels and the column is repeated across the DataFrame (broadcasting):

print (df.A)
2000-01-01    0.472986
2000-01-02    0.753143
2000-01-03   -0.806982
2000-01-04    1.095956
2000-01-05    1.012515
2000-01-06    0.501872
2000-01-07   -1.154360
2000-01-08   -2.218535
Freq: D, Name: A, dtype: float64 

print (df.where(df>0, df.A, axis = 0))
                   A         B         C         D
2000-01-01  0.472986  0.472986  0.242439  0.472986
2000-01-02  0.753143  0.753143  0.005127  0.753143
2000-01-03 -0.806982  2.871819 -0.806982  0.472457
2000-01-04  1.095956  1.095956  1.342356  1.095956
2000-01-05  1.012515  1.012515  1.012515  1.209796
2000-01-06  0.501872  0.138846  0.640761  0.527333
2000-01-07 -1.154360 -1.154360 -1.154360 -1.154360
2000-01-08 -2.218535 -2.218535 -2.218535 -2.218535
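
The two broadcasting directions can be checked side by side on a small frame with fixed values. This is only a sketch; the labels x, y, z and the names row, col, by_columns and by_rows are assumptions made up for the example:

```python
import pandas as pd

df = pd.DataFrame({"A": [1.0, -2.0, 3.0],
                   "B": [-4.0, 5.0, -6.0]},
                  index=["x", "y", "z"])
mask = df > 0

# Row-shaped Series: its index holds the column labels A and B,
# so it is aligned along axis=1 and repeated for every row.
row = pd.Series({"A": 10.0, "B": 20.0})
by_columns = df.where(mask, row, axis=1)

# Column-shaped Series: its index holds the row labels x, y and z,
# so it is aligned along axis=0 and repeated for every column.
col = pd.Series({"x": 100.0, "y": 200.0, "z": 300.0})
by_rows = df.where(mask, col, axis=0)

print(by_columns)
print(by_rows)
```

In by_columns each replaced cell takes the value for its column; in by_rows each replaced cell takes the value for its row.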

      

To replace with values from another DataFrame:

df1 = pd.DataFrame(np.random.randint(10, size=(8,4)), 
                  index = pd.date_range('20000101',periods=8), 
                  columns = list('ABCD'))
print (df1)
            A  B  C  D
2000-01-01  5  3  5  0
2000-01-02  2  9  6  4
2000-01-03  7  6  2  3
2000-01-04  2  6  4  5
2000-01-05  0  0  5  4
2000-01-06  0  3  7  9
2000-01-07  6  8  6  1
2000-01-08  4  9  6  5

print (df.where(df>0, df1))
                   A         B         C         D
2000-01-01  0.472986  3.000000  0.242439  0.000000
2000-01-02  0.753143  9.000000  0.005127  4.000000
2000-01-03  7.000000  2.871819  2.000000  0.472457
2000-01-04  1.095956  6.000000  1.342356  5.000000
2000-01-05  1.012515  0.000000  5.000000  1.209796
2000-01-06  0.501872  0.138846  0.640761  0.527333
2000-01-07  6.000000  8.000000  6.000000  1.000000
2000-01-08  4.000000  9.000000  6.000000  5.000000

      

EDIT:

Explanation of alignment:

Here the function where works with two objects (a Series and a DataFrame). The index labels of the Series are matched against the column labels of the DataFrame, and the replacement is applied only where the labels are common to both. So here only the index labels A,B in the Series match the DataFrame columns A,B, whereas the False positions in the other columns get NaNs.

s = pd.Series(np.random.randint(10, size=4) , index = list('ABEF'))
print (s)
A    5
B    3
E    5
F    0
dtype: int32

print (df.where(df>0, s, axis=1))

                   A         B         C         D
2000-01-01  0.472986  3.000000  0.242439       NaN
2000-01-02  0.753143  3.000000  0.005127       NaN
2000-01-03  5.000000  2.871819       NaN  0.472457
2000-01-04  1.095956  3.000000  1.342356       NaN
2000-01-05  1.012515  3.000000       NaN  1.209796
2000-01-06  0.501872  0.138846  0.640761  0.527333
2000-01-07  5.000000  3.000000       NaN       NaN
2000-01-08  5.000000  3.000000       NaN       NaN
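
The same alignment rule seems to explain the all-NaN replacements in the question's In [23]: df.A is indexed by the row labels, so with axis=1 none of its labels match the columns. A minimal sketch on a made-up two-row frame (the names col_a, mismatched and matched are illustrative):

```python
import pandas as pd

df = pd.DataFrame({"A": [1.0, -2.0], "B": [-3.0, 4.0]}, index=["x", "y"])
col_a = df["A"]  # indexed by the row labels x and y

# axis=1 matches col_a's index (x, y) against the columns (A, B):
# no label is shared, so every replaced cell becomes NaN.
mismatched = df.where(df > 0, col_a, axis=1)

# axis=0 matches col_a's index against the row labels, so each
# replaced cell takes the value of column A from the same row.
matched = df.where(df > 0, col_a, axis=0)

print(mismatched)
print(matched)
```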

      

EDIT1:

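Example with the level parameter:

If your DataFrame has a MultiIndex, add the level parameter to indicate which level of the MultiIndex the Series is aligned against by the function where. Labels of that level that are missing from the Series (b in this example) get NaNs.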

np.random.seed(12)
mux = pd.MultiIndex.from_arrays([pd.date_range('20000101',periods=8), list('aaaabbbb')],
                                names=('date', 'par'))
df = pd.DataFrame(np.random.randn(8,4), 
                  index = mux, 
                  columns = list('ABCD'))
print (df)
                       A         B         C         D
date       par                                        
2000-01-01 a    0.472986 -0.681426  0.242439 -1.700736
2000-01-02 a    0.753143 -1.534721  0.005127 -0.120228
2000-01-03 a   -0.806982  2.871819 -0.597823  0.472457
2000-01-04 a    1.095956 -1.215169  1.342356 -0.122150
2000-01-05 b    1.012515 -0.913869 -1.029530  1.209796
2000-01-06 b    0.501872  0.138846  0.640761  0.527333
2000-01-07 b   -1.154360 -2.213333 -1.681757 -1.788094
2000-01-08 b   -2.218535 -0.647431 -0.528404 -0.039209

s = pd.Series(np.random.randint(10, size=2) , index = list('ac'))
print (s)
a    5
c    3
dtype: int32

print (df.where(df>0, s, axis=0, level=1))
                       A         B         C         D
date       par                                        
2000-01-01 a    0.472986  5.000000  0.242439  5.000000
2000-01-02 a    0.753143  5.000000  0.005127  5.000000
2000-01-03 a    5.000000  2.871819  5.000000  0.472457
2000-01-04 a    1.095956  5.000000  1.342356  5.000000
2000-01-05 b    1.012515       NaN       NaN  1.209796
2000-01-06 b    0.501872  0.138846  0.640761  0.527333
2000-01-07 b         NaN       NaN       NaN       NaN
2000-01-08 b         NaN       NaN       NaN       NaN
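
For comparison, here is a sketch with fixed values in which the Series covers every label of the chosen level, so no NaNs appear. The level names num and par and the variable names fill and out are assumptions for this example only:

```python
import pandas as pd

idx = pd.MultiIndex.from_arrays([[1, 2, 3, 4], list("aabb")],
                                names=("num", "par"))
df = pd.DataFrame({"A": [1.0, -1.0, 1.0, -1.0],
                   "B": [-2.0, 2.0, -2.0, 2.0]},
                  index=idx)

# One replacement value per label of the second index level.
fill = pd.Series({"a": 10.0, "b": 20.0})

# level=1 aligns fill against the 'par' level of the row MultiIndex.
out = df.where(df > 0, fill, axis=0, level=1)

print(out)
```

Rows labelled a take 10.0 and rows labelled b take 20.0 wherever the mask is False.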

      
