如何在pytorch中有效地实现向前填充

nfg76nw0  于 2023-10-20  发布在  其他
关注(0)|答案(3)|浏览(108)

如何有效地为矢量形状的NxLxC(批处理,序列维,通道)实现填充转发逻辑(灵感来自pandas ffill)。因为每个通道序列是独立的,所以这可以等效于使用Tensor形状的(N*C)xL。
计算应该保持 Torch 变量,以便实际输出是可微的。
我设法使先进的索引的东西,但它是L**2在内存和操作的数量,所以不是很大和GPU友好.
范例:
假设你有一个序列[0,1,2,0,0,3,0,4,0,0,0,5,6,0],在一个形状为1x14的Tensor中,向前填充将给予序列[0,1,2,2,2,3,3,4,4,4,4,5,6,6]
另一个形状为2x4的例子是[[0, 1, 0, 3], [1, 2, 0, 3]],应该向前填充到[[0, 1, 1, 3], [1, 2, 2, 3]]中。
今天使用的方法:
我们使用以下代码,它高度未优化,但仍然比非向量化循环快:

def last_zero_sequence_start_indices(t: torch.Tensor) -> torch.Tensor:
    """
    Given a 3D tensor `t`, this function returns a two-dimensional tensor where each entry represents
    the starting index of the last contiguous sequence of zeros up to and including the current index.
    If there's no zero at the current position, the value is the tensor's length.

    In essence, for each position in `t`, the function pinpoints the beginning of the last contiguous
    sequence of zeros up to that position.

    Args:
    - t (torch.Tensor): Input tensor with shape [Batch, Channel, Time].

    Returns:
    - torch.Tensor: Three-dimensional tensor with shape [Batch, Channel, Time] indicating the starting position of
        the last sequence of zeros up to each index in `t`.
    """

    # Create a mask indicating the start of each zero sequence
    start_of_zero_sequence = (t == 0) & torch.cat([
        torch.full(t.shape[:-1] + (1,), True, device=t.device),
        t[..., :-1] != 0,
    ], dim=2)

    # Duplicate this mask into a TxT matrix
    duplicated_mask = start_of_zero_sequence.unsqueeze(2).repeat(1, 1, t.size(-1), 1)

    # Extract the lower triangular part of this matrix (including the diagonal)
    lower_triangular = torch.tril(duplicated_mask)

    # For each row, identify the index of the rightmost '1' (start of the last zero sequence up to that row)
    indices = t.size(-1) - 1 - lower_triangular.int().flip(dims=[3]).argmax(dim=3)

    return indices
vvppvyoh

vvppvyoh1#

这一个避免使用任何循环,torch.tril()和torch.argmax()函数被优化。

import torch
    def forward_fill(t: torch.Tensor) -> torch.Tensor:
      """
      Efficiently implements forward fill on a PyTorch tensor.
    
      Args:
        t (torch.Tensor): Input tensor with shape [Batch, Channel, Time].
    
      Returns:
        torch.Tensor: Output tensor with the same shape as the input tensor, but with the
          missing values filled in with the last non-missing value in each sequence.
      """
    
      # Convert input tensor to 2D
      flattened_tensor = t.view(-1, t.size(-1))
    
      # Extract lower triangular part of the tensor
      lower_triangular = torch.tril(flattened_tensor)
    
      # Find index of the rightmost non-zero element in every row
      indices = flattened_tensor.size(-1) - 1 - lower_triangular.int().flip(dims=[1]).argmax(dim=1)
    
      # Set values of the tensor to the corresponding values in the last row,
      # starting at index found in the previous step
      flattened_tensor[torch.arange(flattened_tensor.size(0)), indices] = flattened_tensor[-1]
    
      # Convert tensor back to the original shape
      return flattened_tensor.view(t.shape)

下面是如何使用forward_fill()的例子:

# Create tensor with some missing values
t = torch.tensor([
    [0, 1, 2, 0, 0, 3, 0, 4, 0, 0, 0, 5, 6, 0],
    [1, 2, 0, 3, 0, 3, 0, 4, 0, 0, 0, 5, 6, 0],
])

# Fill in the missing values using the `forward_fill()` function
filled_t = forward_fill(t)

# Print the filled-in tensor
print(filled_t)

产出:

tensor([[0, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 5, 6, 6],
       [1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 6, 6]])
fhity93d

fhity93d2#

以下是解决这个问题的方法,无需创建TxT矩阵:

import torch
def forward_fill(t: torch.Tensor) -> torch.Tensor:
    n_dim, t_dim = t.shape
    # Generate indices range
    rng = torch.arange(t_dim)
    
    rng_2d = rng.unsqueeze(0).repeat(n_dim, 1)
    # Replace indices to zero for elements that equal zero
    rng_2d[t == 0] = 0
    
    # Forward fill of indices range so all zero elements will be replaced with previous non-zero index.
    idx = rng_2d.cummax(1).values
    t = t[torch.arange(n_dim)[:, None], idx]
    return t

请注意,这是一个2D输入的解决方案,但可以很容易地修改为更多的维度。

cidc1ykv

cidc1ykv3#

我能想到的最短的方法是,使用掩码来标识最后的零序起始位置,然后进行一些索引和置换操作来创建最终输出。

def fill_seq_with_last(seq: torch.Tensor) -> torch.Tensor:
    idx = torch.ones_like(seq, dtype=torch.int64) * (len(seq))
    idx[seq != 0] = t.nonzero()[:, -1]
    idx = torch.tril(idx)
    idx[idx == 0] = (len(seq))
    idx = idx.permute(1, 0)[torch.arange(0, len(seq)), seq != 0].add(1)
    return seq[torch.arange(len(seq)).unsqueeze(0)] = seq[torch.arange(len(seq)).unsqueeze(0), idx].unsqueeze(-1)

使用高级索引来提取这些索引处的值,并将它们设置为向前填充值。该解决方案是有效的,并且可以容易地适用于批量序列。

相关问题