Python Pandas DataFrames:高效地在一列中查找下一个值,其中另一列的值更大

Python Pandas DataFrames:高效地在一列中查找下一个值,其中另一列的值更大,python,pandas,Python,Pandas,标题描述了我的情况。我已经有了一个工作版本,但是当扩展到大型数据帧(>1M行)时,效率非常低。我想知道是否有人有更好的想法来做这件事 带有解决方案和代码的示例 创建一个新列next_time,该列具有price列大于当前行的下一个时间值 import pandas as pd df = pd.DataFrame({'time': [15, 30, 45, 60, 75, 90], 'price': [10.00, 10.01, 10.00, 10.01, 10.02, 9.99]}) print

标题描述了我的情况。我已经有了一个工作版本,但是当扩展到大型数据帧(>1M行)时,效率非常低。我想知道是否有人有更好的想法来做这件事

带有解决方案和代码的示例

创建一个新列
next_time
,该列具有
price
列大于当前行的下一个时间值

import pandas as pd
df = pd.DataFrame({'time': [15, 30, 45, 60, 75, 90], 'price': [10.00, 10.01, 10.00, 10.01, 10.02, 9.99]})
print(df)
   time  price
0    15  10.00
1    30  10.01
2    45  10.00
3    60  10.01
4    75  10.02
5    90   9.99

series_to_concat = []
for price in df['price'].unique():
    index_equal_to_price = df[df['price'] == price].index
    series_time_greater_than_price = df[df['price'] > price]['time']
    time_greater_than_price_backfilled = series_time_greater_than_price.reindex(index_equal_to_price.union(series_time_greater_than_price.index)).fillna(method='backfill')

    series_to_concat.append(time_greater_than_price_backfilled.reindex(index_equal_to_price))

df['next_time'] = pd.concat(series_to_concat, sort=False)

print(df)
   time  price  next_time
0    15  10.00       30.0
1    30  10.01       75.0
2    45  10.00       60.0
3    60  10.01       75.0
4    75  10.02        NaN
5    90   9.99        NaN
这让我得到了想要的结果。当放大到一些大的数据帧时,计算这可能需要几分钟。有人对如何处理这个问题有更好的想法吗

编辑:约束澄清

我们可以假设数据帧是按时间排序的。
另一种说法是,给定任何一行(时间、价格),0价格n并且没有y,因此nyx价格当我在这个样本上用
%timeit
测试时,这些解决方案速度更快,但我在一个更大的数据帧上进行了测试,它们比您的解决方案慢得多。在更大的数据帧中,看看这3种解决方案中是否有任何一种更快,这将是一件有趣的事情。我会查看
dask
或查看:

我希望其他人能够发布一个更有效的解决方案。以下是一些不同的答案:


  • 您可以使用
    next
    一个行程序来实现这一点,该行程序与
    zip
    同时循环通过
    time
    price
    列。
    next
    函数的工作原理与列表理解完全相同,但您使用的是need to括号而不是括号,它只返回第一个
    True
    值。您还需要传递
    None
    ,以便在
    next
    函数中将错误作为参数处理
  • 您需要传递轴=1,因为您正在按列进行比较
  • 这将提高性能,因为当迭代在返回第一个值并移动到下一行后停止时,不会循环遍历整个列


    正如你所见,列表理解将返回相同的结果,但在理论上会慢得多。。。因为迭代的总次数会显著增加,特别是对于大数据帧

    df['next_time'] = (df.apply(lambda x: [z for (y, z) in zip(df['price'], df['time'])
                                           if y > x['price'] if z > x['time']], axis=1)).str[0]
    df
    Out[2]: 
       time  price  next_time
    0    15  10.00       30.0
    1    30  10.01       75.0
    2    45  10.00       60.0
    3    60  10.01       75.0
    4    75  10.02        NaN
    5    90   9.99        NaN
    

    使用一些
    numpy
    和np.where()创建函数的另一个选项:


    这一次在不到7秒内为我返回了1000000行和162000个唯一价格的数据帧变体。因此,我认为既然你在660000行和12000个独特的价格上运行它,速度的提高将是100x-1000x

    你的问题的另一个复杂性是,最接近的更高价格必须在以后的时间。这个答案帮助我发现了
    对分
    函数,但它没有增加依赖时间列的复杂性。因此,我必须从两个不同的角度来解决这个问题(正如您在关于我的
    np.where()
    的评论中提到的,将其分解为两种不同的方法)

    将熊猫作为pd导入
    df=pd.DataFrame({'time':[15,30,45,60,75,90],'price':[10.00,10.01,10.00,10.01,10.02,9.99])
    def对分_右(a、x、lo=0、hi=None):
    如果lo<0:
    raise VALUERROR('lo必须为非负')
    如果hi为无:
    hi=len(a)
    当lodf['next_time']=np。其中(df['next_time']David确实提出了一个很好的解决方案,可以在以后找到最接近的更高价格。不过,我确实想在以后找到更高价格的下一次出现。与我的同事合作,我们找到了这个解决方案

    包含元组的堆栈(索引、价格)

  • 遍历所有行(索引i)
  • 当堆栈为非空且堆栈顶部的价格较低时,则弹出并在弹出的索引中填入时间[index]
  • 将(i,价格[i])推到堆栈上
  • 这个解决方案实际上执行得非常快。我不完全确定,但我相信复杂性将接近O(n)因为这是整个数据帧的一次完整传递。之所以表现如此出色,是因为堆栈基本上是经过排序的,其中最大的价格在底部,最小的价格在堆栈顶部

    下面是我对实际数据帧的测试

    print(f'{len(df):,.0f} rows with {len(df["price"].unique()):,.0f} unique prices ranging from ${df["price"].min():,.2f} to ${df["price"].max():,.2f}')
    667,037 rows with 11,786 unique prices ranging from $1,857.52 to $2,022.00
    
    def find_next_time_with_greater_price(df):
        times = df['time'].to_numpy()
        prices = df['price'].to_numpy()
        stack = []
        next_times = np.full(len(df), np.nan)
        for i in range(len(df)):
            while stack and prices[i] > stack[-1][1]:
                stack_time_index, stack_price = stack.pop()
                next_times[stack_time_index] = times[i]
            stack.append((i, prices[i]))
        return next_times
    
    %timeit -n10 -r10 df['next_time'] = find_next_time_with_greater_price(df)
    434 ms ± 11.8 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
    

    嘿,大卫,谢谢你的快速回复!我明白你的意思了,使用较小的数据帧,我完全可以理解为什么应用程序更快。我最初做这项工作时试图避免应用程序,但很快就意识到为其提出矢量化解决方案是一个困难的问题。我将使用其中一个较大的数据帧和d我会告诉你性能的最新情况。另外,如果我可以使用增强性能链接中的大部分内容,那就太好了,但是我有数千个这样的大型数据帧需要计算,所以我已经在并行跟踪中这样做了:我在约660000行和约12000个唯一价格的数据帧上运行了这个。我的函数使用了about还有5分钟就要跑了,而你的3个中的第一个已经跑了20分钟了,仍然没有重新跑过
    def closest(x):
        try:
            lst = df.groupby(df['price'].cummax())['time'].transform('first')
            lst = np.asarray(lst)
            lst = lst[lst>x] 
            idx = (np.abs(lst - x)).argmin() 
            return lst[idx]
        except ValueError:
            pass
    
    
    df['next_time'] = np.where((df['price'].shift(-1) > df['price']),
                                df['time'].shift(-1),
                                df['time'].apply(lambda x: closest(x)))
    
    import pandas as pd
    df = pd.DataFrame({'time': [15, 30, 45, 60, 75, 90], 'price': [10.00, 10.01, 10.00, 10.01, 10.02, 9.99]})
    
    def bisect_right(a, x, lo=0, hi=None):
        if lo < 0:
            raise ValueError('lo must be non-negative')
        if hi is None:
            hi = len(a)
        while lo < hi:
            mid = (lo+hi)//2
            if x < a[mid]: hi = mid
            else: lo = mid+1
        return lo
    
    
    def get_closest_higher(df, col, val):
        higher_idx = bisect_right(df[col].values, val)
        return higher_idx
    
    
    df = df.sort_values(['price', 'time']).reset_index(drop=True)
    df['next_time'] = df['price'].apply(lambda x: get_closest_higher(df, 'price', x))
    
    df['next_time'] = df['next_time'].map(df['time'])
    df['next_time'] = np.where(df['next_time'] <= df['time'], np.nan, df['next_time'] )
    df = df.sort_values('time').reset_index(drop=True)
    df['next_time'] = np.where((df['price'].shift(-1) > df['price'])
                               ,df['time'].shift(-1),
                               df['next_time'])
    df['next_time'] = df['next_time'].ffill()
    df['next_time'] = np.where(df['next_time'] <= df['time'], np.nan, df['next_time'])
    df
    
    Out[1]: 
       time  price  next_time
    0    15  10.00       30.0
    1    30  10.01       75.0
    2    45  10.00       60.0
    3    60  10.01       75.0
    4    75  10.02        NaN
    5    90   9.99        NaN
    
    import numpy as np
    import pandas as pd
    df = pd.DataFrame({'time': [15, 30, 45, 60, 75, 90], 'price': [10.00, 10.01, 10.00, 10.01, 10.02, 9.99]})
    print(df)
       time  price
    0    15  10.00
    1    30  10.01
    2    45  10.00
    3    60  10.01
    4    75  10.02
    5    90   9.99
    
    times = df['time'].to_numpy()
    prices = df['price'].to_numpy()
    stack = []
    next_times = np.full(len(df), np.nan)
    for i in range(len(df)):
        while stack and prices[i] > stack[-1][1]:
            stack_time_index, stack_price = stack.pop()
            next_times[stack_time_index] = times[i]
        stack.append((i, prices[i]))
    df['next_time'] = next_times
    
    print(df)
       time  price  next_time
    0    15  10.00       30.0
    1    30  10.01       75.0
    2    45  10.00       60.0
    3    60  10.01       75.0
    4    75  10.02        NaN
    5    90   9.99        NaN
    
    print(f'{len(df):,.0f} rows with {len(df["price"].unique()):,.0f} unique prices ranging from ${df["price"].min():,.2f} to ${df["price"].max():,.2f}')
    667,037 rows with 11,786 unique prices ranging from $1,857.52 to $2,022.00
    
    def find_next_time_with_greater_price(df):
        times = df['time'].to_numpy()
        prices = df['price'].to_numpy()
        stack = []
        next_times = np.full(len(df), np.nan)
        for i in range(len(df)):
            while stack and prices[i] > stack[-1][1]:
                stack_time_index, stack_price = stack.pop()
                next_times[stack_time_index] = times[i]
            stack.append((i, prices[i]))
        return next_times
    
    %timeit -n10 -r10 df['next_time'] = find_next_time_with_greater_price(df)
    434 ms ± 11.8 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)