Python 获取连续最小比较的索引

Python 获取连续最小比较的索引,python,numpy,pytorch,Python,Numpy,Pytorch,如果我有张量 values = torch.tensor([5., 4., 8., 3.]) 我想取每2个连续值的最小值,意思是 min(5., 4.) = 4. min(8., 3.) = 3. 是否有一种矢量化的方法可以做到这一点,并且仍然可以获得最小值的相对索引?这意味着我想要的输出是: min_index = [1, 1] #min_index[0] == 1 as 4. is the minimum of (5., 4.) and is in index 1 of (5., 4.)

如果我有张量

values = torch.tensor([5., 4., 8., 3.])
我想取每2个连续值的最小值,意思是

min(5., 4.) = 4.
min(8., 3.) = 3.
是否有一种矢量化的方法可以做到这一点,并且仍然可以获得最小值的相对索引?这意味着我想要的输出是:

min_index = [1, 1]
#min_index[0] == 1 as 4. is the minimum of (5., 4.) and is in index 1 of (5., 4.)
#min_index[1] == 1 as 3. is the minimum of (8., 3.) and is in index 1 of (8., 3.) 

这里是一个numpy实现:

a = np.random.randint(0,100, 10)
a
array([24, 60, 33, 65,  7, 84, 44, 67, 96, 18])

# Compare pairwise and get all pairs min relative index
min_index = np.argmin([a[:-1], a[1:]], axis=0)
min_index
array([0, 1, 0, 1, 0, 1, 0, 0, 1], dtype=int64)

# Pairs (24,60), (60,33), (33,65), and so on..

# Adding index and array location we get the global index of pairs min in the original array
global_min_index = [i+e for i,e in enumerate(min_index_tmp)]
global_min_index 
[0, 2, 2, 4, 4, 6, 6, 7, 9]

这里是一个numpy实现:

a = np.random.randint(0,100, 10)
a
array([24, 60, 33, 65,  7, 84, 44, 67, 96, 18])

# Compare pairwise and get all pairs min relative index
min_index = np.argmin([a[:-1], a[1:]], axis=0)
min_index
array([0, 1, 0, 1, 0, 1, 0, 0, 1], dtype=int64)

# Pairs (24,60), (60,33), (33,65), and so on..

# Adding index and array location we get the global index of pairs min in the original array
global_min_index = [i+e for i,e in enumerate(min_index_tmp)]
global_min_index 
[0, 2, 2, 4, 4, 6, 6, 7, 9]

我认为重塑你的张量会使它变得容易得多。 之后,
torch.min
自动返回最小值和索引

import torch

values = torch.tensor([5., 4., 8., 3.])
values_reshaped = values.reshape(-1,2) # works for any length
minimums, index = torch.min(values_reshaped, axis = -1)
print(minimums) # tensor of the minimum values
print(index) # tensor of indexes

我认为重塑你的张量会使它变得容易得多。 之后,
torch.min
自动返回最小值和索引

import torch

values = torch.tensor([5., 4., 8., 3.])
values_reshaped = values.reshape(-1,2) # works for any length
minimums, index = torch.min(values_reshaped, axis = -1)
print(minimums) # tensor of the minimum values
print(index) # tensor of indexes