0

我尝试使用 dfply 包在给定条件下创建一个累加器列,但自定义函数失败。

以钻石数据为例:我想创建一个累加器列,如果价格大于 500,则 +1,否则 +0。

我的代码如下:

import panda as pd
from dfply import *

@make_symbolic
def accu(s, threshold):
    cur = 0
    res = []
    for x in s:
        if x > threshold:
            cur += 1
        res += [cur]
    return pd.Series(res)


(diamonds >> 
 mask(X.color == 'D', X.cut == 'Premium', X.carat > 0.32) >>
 mutate(row_id = row_number(X.price),        # Get the row number
        accu_id = accu(X.price, 500)) >>     # Get the accumulator, this step failed
 arrange(X.row_id) >>
 head(10)
)

预期输出将如下所示:

price row_id accu_id
498   1      0
499   2      0
501   3      1
502   4      2
400   5      2
503   6      3
4

0 回答 0