Tensorflow中的分段中值

nwlls2ji  于 2022-11-16  发布在  其他
关注(0)|答案(1)|浏览(151)

我有一个Tensor,我想计算基于分段的中值。使用tf.math.segment_max/sum/mean可以很容易地计算分段最大值、求和、平均值等,但是如果我想计算分段中值,我该怎么做呢?

x = tf.constant([[0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.8, 0.2, 0.6],
                 [0.1, 0.2, 0.4, 0.5],
                 [0.3, 0.1, 0.2, 0.9],
                 [0.1, 0.1, 0.6, 0.5]])

result = tf.math.segment_max(x, tf.constant([0, 1, 1, 1, 2]))

result

tf.constant([[0.1, 0.2, 0.4, 0.5],
             [0.1, 0.2, 0.2, 0.6],
             [0.1, 0.1, 0.6, 0.5]])
iyr7buue

iyr7buue1#

可通过将Tensor转换为粗糙Tensor,其中tf.RaggedTensor.from_value_rowids使用segment_ids。然后沿着每个轴应用median
例如,要获得上例的粗糙Tensor:

a = tf.RaggedTensor.from_value_rowids(
    values=x,
    value_rowids=[0, 1, 1, 1, 2])
# a.shape gives (3, None, 4) - Where None has 1,3,1 dimensions

然后,我们将中值应用于上述Tensor的每一行
1.每行Tensor的中值为:

def median(x):
    return tfp.stats.percentile(x[None,...], 50.0, interpolation='midpoint', axis=1)

1.段中值

def segment_median(data, segment_ids):
    ragged = tf.RaggedTensor.from_value_rowids(values=data, value_rowids=segment_ids)
    m = []
    for row in ragged:
        m.append(median(row))
    return tf.squeeze(tf.stack(m))

segment_median(x,tf.常量([0,1,1,1,2]))返回

<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[0.1, 0.2, 0.4, 0.5],
       [0.1, 0.2, 0.2, 0.6],
       [0.1, 0.1, 0.6, 0.5]], dtype=float32)>

相关问题