如何有效地为矢量形状的NxLxC(批处理,序列维,通道)实现填充转发逻辑(灵感来自pandas ffill
)。因为每个通道序列是独立的,所以这可以等效于使用Tensor形状的(N*C)xL。
计算应该保持 Torch 变量,以便实际输出是可微的。
我设法使先进的索引的东西,但它是L**2在内存和操作的数量,所以不是很大和GPU友好.
范例:
假设你有一个序列[0,1,2,0,0,3,0,4,0,0,0,5,6,0]
,在一个形状为1x14
的Tensor中,向前填充将给予序列[0,1,2,2,2,3,3,4,4,4,4,5,6,6]
。
另一个形状为2x4
的例子是[[0, 1, 0, 3], [1, 2, 0, 3]]
,应该向前填充到[[0, 1, 1, 3], [1, 2, 2, 3]]
中。
今天使用的方法:
我们使用以下代码,它高度未优化,但仍然比非向量化循环快:
def last_zero_sequence_start_indices(t: torch.Tensor) -> torch.Tensor:
"""
Given a 3D tensor `t`, this function returns a two-dimensional tensor where each entry represents
the starting index of the last contiguous sequence of zeros up to and including the current index.
If there's no zero at the current position, the value is the tensor's length.
In essence, for each position in `t`, the function pinpoints the beginning of the last contiguous
sequence of zeros up to that position.
Args:
- t (torch.Tensor): Input tensor with shape [Batch, Channel, Time].
Returns:
- torch.Tensor: Three-dimensional tensor with shape [Batch, Channel, Time] indicating the starting position of
the last sequence of zeros up to each index in `t`.
"""
# Create a mask indicating the start of each zero sequence
start_of_zero_sequence = (t == 0) & torch.cat([
torch.full(t.shape[:-1] + (1,), True, device=t.device),
t[..., :-1] != 0,
], dim=2)
# Duplicate this mask into a TxT matrix
duplicated_mask = start_of_zero_sequence.unsqueeze(2).repeat(1, 1, t.size(-1), 1)
# Extract the lower triangular part of this matrix (including the diagonal)
lower_triangular = torch.tril(duplicated_mask)
# For each row, identify the index of the rightmost '1' (start of the last zero sequence up to that row)
indices = t.size(-1) - 1 - lower_triangular.int().flip(dims=[3]).argmax(dim=3)
return indices
3条答案
按热度按时间vvppvyoh1#
这一个避免使用任何循环,torch.tril()和torch.argmax()函数被优化。
下面是如何使用forward_fill()的例子:
产出:
fhity93d2#
以下是解决这个问题的方法,无需创建TxT矩阵:
请注意,这是一个2D输入的解决方案,但可以很容易地修改为更多的维度。
cidc1ykv3#
我能想到的最短的方法是,使用掩码来标识最后的零序起始位置,然后进行一些索引和置换操作来创建最终输出。
使用高级索引来提取这些索引处的值,并将它们设置为向前填充值。该解决方案是有效的,并且可以容易地适用于批量序列。