条件を満たす最初/最後の行を取得 - pandas
DataFrameがあって、ある列への条件判定で最初にTrueとなる行や、そのindexを得たい。
aaa bbb 0 3 5 aaa==1である 1 1 9 ← 最初の行 2 4 2 3 1 6
もちろん、以下のように書けば得られる。
1 2 3 4 5 6 7 |
df # DataFrameが定義済みとする # aaa列が1になる最初の行を取得 idx = df[df[ 'aaa' ] = = 1 ].iloc[ 0 ] # aaa列が1になる最初の行のindexを取得 idx = df[df[ 'aaa' ] = = 1 ].index[ 0 ] |
しかし、あくまで最初の行が得たいだけなのに、これではまず全要素に対して1と等しいか比較し、それを元に新たなDataFrame参照を作り、やっとそこからindexを取得していて、無駄が多そうに見える。
良くありそうなケースなのに、意外と直感的な方法が見つかりづらい。 indexを取得するケースに絞ると、Stack Overflowなどで議論されていたものとしては、
-
- 「ソートされた列」から特定の値を持つ最初の要素を得る方法であり、若干違うが、実行速度を含めて議論されている
- 2年前の記事だが、紹介されていた計測コードを手元環境(Python3.7.3, pandas0.24.2)で実行しても、およそ同様の結果となった
- index[0]
- for loop
- pandas.idxmax()
- numpy.argmax()
- first_valid_index()
index[0]は、冒頭の例と同じもの。
for loop はその通り1行ずつ見て、見つかり次第breakする。
pandas.idxmax()は配列中の最大値のindexを得る関数だが、「boolが0,1で評価されること」「最大値が複数ある場合は最初のindexが返される仕様」を利用すると、今回の目的通りの結果が返る。
numpy.argmax()はそれをnumpy配列上で行う。
first_valid_index()は、ズバリそのもののことをする関数だが、遅いらしい。
total_time_sec ratio wrt fastest algo argmax numpy: 0.0165 1.00 idxmax pandas: 0.0741 4.49 index[0]: 0.0762 4.62 first_valid_index pandas: 0.1434 8.69 for loop: 9.0507 548.53
- argmax()を使って、
(df['aaa'].values == 1).argmax()
とするのが最も速い - idxmax()を使って、
(df['aaa'] == 1).idxmax()
とするのが次いで速いが、argmax()と比較すると4倍くらい遅い(小さいDFでは比が更に大きくなる) - index[0]はそれより僅かに遅い
- first_valid_index()は更に2倍ほど遅い
- for loopは、幸運にも最初の方にあると速いが、後ろの方にあると致命的に遅い
argmax()を使うのがよいと思われるが、注意点がある。
- 全てがFalseの場合、最大値は“0”なので、それを取る最初のindexである“0”が返されてしまう。確実にTrueが存在する保証か、または事後チェックが必要
- もしDataFrameのindexが0からの連番で無い場合、「配列の添字としてのindex」と「DataFrameのindex」で違いが生じる。df.index[]を通して後から変換する必要がある
- 関数名の意味とズレた目的に使っているので可読性維持のためにはコメントで注意するのが望ましい
1つめ、3つめの問題は、idxmax()も同様である。
そこまで速度にこだわりがなくて、コードの読みやすさを重視したいなら、index[0]がよい。
現状、for loop以外は、最初にbool配列を得るための条件判定は列全体に対して行わざるをえず、これが無駄といえば無駄である。 しかし、複数の値に対して単純な演算処理を適用するのはnumpyで高速に動くよう組まれているので、下手にPythonでfor loopするより速くなる。
最後の行
最後の行を取得する場合は、numpy.argmax()を使うなら配列を反転してから用いて、結果を配列長から引けば求まるのだが、ますます可読性が下がってしまう。
aaa aaa==1 0 1 True 1 3 False 2 1 True 3 4 False ↓ .values[::-1] (numpy配列化して反転) [0, 1, 0, 1] ↓ .argmax() 1 ↓ 配列長-1 から引く 2
Stack Overflowの検証コードを少し変更し、各種方法で最後のindexを取得する速度を比較した。
- index[-1]
- for loop
- 反転してnumpy.argmax()
- 反転してpandas.idxmax()
- last_valid_index()
last_valid_index()という関数もあるが、first同様に何故か遅い。
反転コストはあっても、やはりnumpy.argmax()が明確に最速。pandasは反転コストが高いのか、idxmax() よりは、index[-1]の方が僅かに早い結果となった。
10000行のDFで、最初の5要素のみ'b'で他は'a'の列から、'a'でない最後の行のindex(4)を探すという処理を100回繰り返した合計時間。
total_time_sec ratio wrt fastest algo argmax numpy: 0.0159 1.00 index[-1]: 0.0712 4.48 idxmax pandas: 0.0786 4.94 last_valid_index pandas: 0.1415 8.90 for loop: 9.0257 567.65
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import numpy as np import pandas as pd import timeit # code snippet to be executed only once # mysetup = '''import pandas as pd # import numpy as np # df = pd.DataFrame({"A":['a','a','a','b','b'],"B":[1]*5}) # ''' mysetup = '''import pandas as pd import numpy as np n = 10000 lt = ['a' for _ in range(n)] b = ['b' for _ in range(5)] lt[:5] = b df = pd.DataFrame({"A":lt,"B":[1]*n}) ''' # code snippets whose execution time is to be measured mycode_set = [ ''' df[df.A!='a'].last_valid_index() ''' ] message = [ "last_valid_index pandas:" ] mycode_set.append( '''df.loc[df.A!='a','A'].index[-1]''' ) message.append( "index[-1]: " ) mycode_set.append( '''df.A.ne('a')[::-1].idxmax()''' ) message.append( "idxmax pandas: " ) mycode_set.append( '''len(df) - (df.A.values != 'a')[::-1].argmax() - 1''' ) message.append( "argmax numpy: " ) mycode_set.append( '''for index in df.index[::-1]: if df['A'][index] != 'a': ans = index break ''' ) message.append( "for loop: " ) total_time_in_sec = [] for i in range ( len (mycode_set)): mycode = mycode_set[i] total_time_in_sec.append(np. round (timeit.timeit(setup = mysetup, stmt = mycode, number = 100 ), 4 )) output = pd.DataFrame(total_time_in_sec, index = message, columns = [ 'total_time_sec' ]) output[ "ratio wrt fastest algo" ] = \ np. round (output.total_time_sec / output[ "total_time_sec" ]. min (), 2 ) output = output.sort_values(by = "total_time_sec" ) print (output) |