numpy NP.分位数()的TensorFlow概率(tfp)等效值

vd8tlhqk  于 2023-02-08  发布在  其他
关注(0)|答案(2)|浏览(237)

我正在尝试寻找np.quantile()的TensorFlow等价物。我找到了tfp.stats.quantiles()tfp代表TensorFlow概率)。但是,它的构造与np.quantile()的构造略有不同。
请看下面的例子:

import tensorflow_probability as tfp
import tensorflow as tf 
import numpy as np 

inputs = tf.random.normal((1, 4096, 4))

print("NumPy")
print(np.quantile(inputs.numpy(), q=0.9, axis=1, keepdims=False))

从TFP文档中我不确定如何使用tfp.stats.quantile()编写上述代码。我尝试检查这两个方法的源代码,但没有帮助。

mwkjh3gx

mwkjh3gx1#

让我试着比在GitHub上更有帮助。
np.quantiletfp.stats.quantiles在行为上有所不同。
沿着指定轴计算数据的第q个分位数。
其中q
要计算的分位数或分位数序列,必须介于0和1之间(含0和1)。
tfp.stats.quantiles
给定样本向量x,此函数通过返回num_quantiles + 1分界点来估计分界点
所以你需要告诉tfp.stats.quantiles你想要多少个分位数,然后选择出q的分位数。如果你不清楚如何从API中完成这件事,如果你看一下tfp.stats.quantiles的源代码(对于v0.19.0),我们可以看到它向我们展示了如何获得与NumPy类似的返回结构。
为完整起见,请使用设置虚拟环境

$ cat requirements.txt
numpy==1.24.2
tensorflow==2.11.0
tensorflow-probability==0.19.0

让我们能够

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

inputs = tf.random.normal((1, 4096, 4), dtype=tf.float64)
q = 0.9

numpy_quantiles = np.quantile(inputs.numpy(), q=q, axis=1, keepdims=False)

tfp_quantiles = tfp.stats.quantiles(
    inputs, num_quantiles=100, axis=1, interpolation="linear"
)[int(q * 100)]

assert np.allclose(numpy_quantiles, tfp_quantiles.numpy())

print(f"{numpy_quantiles=}")
# numpy_quantiles=array([[1.31727661, 1.2699167 , 1.28735237, 1.27137588]])
print(f"{tfp_quantiles=}")
# tfp_quantiles=<tf.Tensor: shape=(1, 4), dtype=float64, numpy=array([[1.31727661, 1.2699167 , 1.28735237, 1.27137588]])>
axzmvihb

axzmvihb2#

您还可以使用tfp.stats.percentile(inputs, 90., axis=1, keepdims=False)--与分位数的唯一区别是90.替换了.90

相关问题