TensorFlow Probability (tfp): PRNG key error with the Normal distribution

Posted by wwtsj6pe on 2023-10-23 in Other

I am trying to run the following Python code in Ubuntu under WSL (Windows Subsystem for Linux):

import jax; import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions
dist = tfd.Normal(loc=0., scale=3.)
dist.cdf(1.)
dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])
dist.prob([0, 1.5])
dist.sample([3], seed=jax.random.PRNGKey(0))  # TFP on JAX needs an explicit PRNG key to sample

I get the following error:

> AttributeError                            Traceback (most recent call last)
> Cell In[45], line 4
>       2 tfd = tfp.distributions
>       3 dist = tfd.Normal(loc=0., scale=3.)
> ----> 4 dist.cdf(1.)
>       5 dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])
>       6 dist.prob([0, 1.5])
> 
> File ~/.local/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:1429, in Distribution.cdf(self, value, name, **kwargs)
>    1411 def cdf(self, value, name='cdf', **kwargs):
>    1412   """Cumulative distribution function.
>    1413 
>    1414   Given random variable `X`, the cumulative distribution function `cdf` is:
>    (...)
>    1427       values of type `self.dtype`.
>    1428   """
> -> 1429   return self._call_cdf(value, name, **kwargs)
> 
> File ~/.local/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:1405, in Distribution._call_cdf(self, value, name, **kwargs)
>    1403 with self._name_and_control_scope(name, value, kwargs):
>    1404   if hasattr(self, '_cdf'):
> -> 1405     return self._cdf(value, **kwargs)
>    1406   if hasattr(self, '_log_cdf'):
>    1407     return tf.exp(self._log_cdf(value, **kwargs))
> 
> File ~/.local/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/normal.py:195, in Normal._cdf(self, x)
>     194 def _cdf(self, x):
> --> 195   return special_math.ndtr(self._z(x))
> 
> File ~/.local/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/internal/special_math.py:136, in ndtr(x, name)
>     132 if dtype_util.as_numpy_dtype(x.dtype) not in [np.float32, np.float64]:
>     133   raise TypeError(
>     134       "x.dtype=%s is not handled, see docstring for supported types."
>     135       % x.dtype)
> --> 136 return _ndtr(x)
> 
> File ~/.local/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/internal/special_math.py:141, in _ndtr(x)
>     139 def _ndtr(x):
>     140   """Implements ndtr core logic."""
> --> 141   half_sqrt_2 = tf.constant(
>     142       0.5 * np.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
>     143   half = tf.constant(0.5, x.dtype)
>     144   one = tf.constant(1., x.dtype)
> 
> File ~/.local/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:117, in _constant(value, dtype, shape, name)
>     116 def _constant(value, dtype=None, shape=None, name='Const'):  # pylint: disable=unused-argument
> --> 117   x = convert_to_tensor(value, dtype=dtype)
>     118   if shape is None:
>     119     return x
> 
> File ~/.local/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:167, in _convert_to_tensor(value, dtype, dtype_hint, name)
>     164     pass
>     166 if ret is None:
> --> 167   ret = conversion_func(value, dtype=dtype)
>     168 return ret
> 
> File ~/.local/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:222, in _default_convert_to_tensor(value, dtype)
>     218 """Default tensor conversion function for array, bool, int, float, and complex."""
>     219 if JAX_MODE:
>     220   # TODO(b/223267515): We shouldn't need to specialize here.
>     221   if hasattr(value, 'dtype') and jax.dtypes.issubdtype(
> --> 222       value.dtype, jax.dtypes.prng_key
>     223   ):
>     224     return value
>     225   if isinstance(value, (list, tuple)) and value:
> 
> AttributeError: module 'jax.dtypes' has no attribute 'prng_key'
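The last frame shows where it breaks: TFP's JAX backend evaluates jax.dtypes.prng_key, and the installed JAX module has no such attribute. A quick way to confirm this in a given environment is to probe for the attribute directly. The sketch below is a hypothetical helper (has_prng_key_dtype is not part of JAX or TFP) and only assumes that, if JAX is installed, jax.dtypes can be imported.

```python
import importlib


def has_prng_key_dtype():
    """Probe the installed JAX for the attribute the TFP JAX backend uses.

    Hypothetical helper, not a JAX/TFP API. Returns True or False if
    jax.dtypes can be imported, and None if JAX is not installed.
    """
    try:
        jax_dtypes = importlib.import_module("jax.dtypes")
    except ImportError:
        return None  # no JAX in this environment
    # The failing line in ops.py looks up jax.dtypes.prng_key.
    return hasattr(jax_dtypes, "prng_key")


if __name__ == "__main__":
    status = has_prng_key_dtype()
    if status is None:
        print("JAX is not installed")
    elif status:
        print("jax.dtypes.prng_key exists; the lookup in the traceback should succeed")
    else:
        print("jax.dtypes.prng_key is missing; this JAX is older than this TFP build expects")
```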

The installed packages are:

Package                  Version             
------------------------ --------------------
absl-py                  2.0.0               
Ambit-Stochastics        1.0.6               
anyio                    3.6.2               
argon2-cffi              21.3.0              
argon2-cffi-bindings     21.2.0              
asttokens                2.2.1               
attrs                    19.3.0              
Automat                  0.8.0               
Babel                    2.11.0              
backcall                 0.2.0               
beautifulsoup4           4.11.2              
bleach                   6.0.0               
blinker                  1.4                 
certifi                  2019.11.28          
cffi                     1.15.1              
chardet                  3.0.4               
charset-normalizer       3.0.1               
Click                    7.0                 
cloud-init               23.2.1              
cloudpickle              2.2.1               
colorama                 0.4.3               
comm                     0.1.2               
command-not-found        0.3                 
configobj                5.0.6               
constantly               15.1.0              
cryptography             2.8                 
cycler                   0.11.0              
dbus-python              1.2.16              
debugpy                  1.6.6               
decorator                5.1.1               
defusedxml               0.7.1               
distlib                  0.3.6               
distro                   1.4.0               
distro-info              0.23ubuntu1         
dm-tree                  0.1.8               
entrypoints              0.3                 
env                      0.1.0               
executing                1.2.0               
fastjsonschema           2.16.2              
filelock                 3.9.0               
gast                     0.5.4               
httplib2                 0.14.0              
hyperlink                19.0.0              
idna                     2.8                 
importlib-metadata       6.0.0               
importlib-resources      5.10.2              
incremental              16.10.1             
ipykernel                6.20.2              
ipython                  8.9.0               
ipython-genutils         0.2.0               
ipywidgets               8.1.1               
jax                      0.4.13              
jaxlib                   0.4.13              
jedi                     0.18.2              
Jinja2                   3.1.2               
json5                    0.9.11              
jsonpatch                1.22                
jsonpointer              2.0                 
jsonschema               4.17.3              
jupyter                  1.0.0               
jupyter-client           8.0.2               
jupyter-console          6.6.3               
jupyter-core             5.2.0               
jupyter-events           0.6.3               
jupyter-server           2.2.0               
jupyter-server-terminals 0.4.4               
jupyterlab               3.5.3               
jupyterlab-pygments      0.2.2               
jupyterlab-server        2.19.0              
jupyterlab-widgets       3.0.9               
keyring                  18.0.1              
kiwisolver               1.3.2               
language-selector        0.1                 
launchpadlib             1.10.13             
lazr.restfulclient       0.14.2              
lazr.uri                 1.0.3               
MarkupSafe               2.1.2               
matplotlib               3.4.3               
matplotlib-inline        0.1.6               
mistune                  2.0.4               
ml-dtypes                0.2.0               
more-itertools           4.2.0               
nbclassic                0.5.1               
nbclient                 0.7.2               
nbconvert                7.2.9               
nbformat                 5.7.3               
nest-asyncio             1.5.6               
netifaces                0.10.4              
notebook                 6.5.2               
notebook-shim            0.2.2               
numpy                    1.21.3              
oauthlib                 3.1.0               
opt-einsum               3.3.0               
packaging                23.0                
pandas                   1.3.4               
pandocfilters            1.5.0               
parso                    0.8.3               
pexpect                  4.6.0               
pickleshare              0.7.5               
Pillow                   8.4.0               
pip                      20.0.2              
pkgutil-resolve-name     1.3.10              
platformdirs             2.6.2               
prometheus-client        0.16.0              
prompt-toolkit           3.0.36              
psutil                   5.9.4               
ptyprocess               0.7.0               
pure-eval                0.2.2               
pyasn1                   0.4.2               
pyasn1-modules           0.2.1               
pycparser                2.21                
Pygments                 2.14.0              
PyGObject                3.36.0              
PyHamcrest               1.9.0               
PyJWT                    1.7.1               
pymacaroons              0.13.0              
PyNaCl                   1.3.0               
pyOpenSSL                19.0.0              
pyparsing                3.0.4               
pyrsistent               0.15.5              
pyserial                 3.4                 
python-apt               2.0.1+ubuntu0.20.4.1
python-dateutil          2.8.2               
python-debian            0.1.36+ubuntu1.1    
python-json-logger       2.0.4               
pytz                     2021.3              
PyYAML                   5.3.1               
pyzmq                    25.0.0              
qtconsole                5.4.4               
QtPy                     2.4.0               
requests                 2.28.2              
requests-unixsocket      0.2.0               
rfc3339-validator        0.1.4               
rfc3986-validator        0.1.1               
scipy                    1.10.1              
SecretStorage            2.3.1               
Send2Trash               1.8.0               
service-identity         18.1.0              
setuptools               45.2.0              
simplejson               3.16.0              
six                      1.14.0              
sniffio                  1.3.0               
sos                      4.4                 
soupsieve                2.3.2.post1         
ssh-import-id            5.10                
stack-data               0.6.2               
systemd-python           234                 
terminado                0.17.1              
tfp-nightly              0.22.0.dev20231002  
tinycss2                 1.2.1               
tomli                    2.0.1               
tornado                  6.2                 
traitlets                5.9.0               
Twisted                  18.9.0              
typing-extensions        4.5.0               
ubuntu-advantage-tools   8001                
ufw                      0.36                
unattended-upgrades      0.1                 
urllib3                  1.25.8              
virtualenv               20.17.1             
wadllib                  1.3.3               
wcwidth                  0.2.6               
webencodings             0.5.1               
websocket-client         1.5.0               
wheel                    0.34.2              
widgetsnbextension       4.0.9               
zipp                     1.0.0               
zope.interface           4.7.1

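If I change the Normal distribution to a Gamma distribution, I no longer get the error. Do you know why? The code snippet is taken directly from the TensorFlow website. Thanks!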

kyks70gy (answer #1):

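jax.dtypes.prng_key was added in JAX 0.4.14. Your package list shows jax 0.4.13 together with tfp-nightly 0.22.0.dev20231002, a TFP build that already uses this attribute. You should either update JAX to a newer version (0.4.14 or later) or, if that is not possible, downgrade tensorflow_probability to an older version (0.21.0 or older should be enough to fix this issue).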
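To check an environment against the threshold given in this answer, one can compare the installed JAX version with 0.4.14. This is a sketch under stated assumptions, not an official compatibility check: parse_release, jax_needs_upgrade and installed_version are hypothetical names, the 0.4.14 cutoff is taken from this answer, and the parser deliberately ignores pre-release and dev suffixes (packaging.version.Version implements the full PEP 440 rules if that matters).

```python
from importlib import metadata

# Assumed cutoff, taken from the answer: jax.dtypes.prng_key first appears in JAX 0.4.14.
MIN_JAX_WITH_PRNG_KEY = (0, 4, 14)


def parse_release(version):
    """Return the leading numeric components of a version string.

    Simplified on purpose: '0.22.0.dev20231002' -> (0, 22, 0); any
    pre-release or dev suffix is dropped rather than compared.
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component, e.g. 'dev2023...'
        parts.append(int(digits))
    return tuple(parts)


def jax_needs_upgrade(jax_version):
    """True if this JAX version is older than the assumed 0.4.14 cutoff."""
    return parse_release(jax_version) < MIN_JAX_WITH_PRNG_KEY


def installed_version(dist_name):
    """Installed version of a pip distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_version = installed_version("jax")
    # The question's environment uses tfp-nightly; a release install is
    # named tensorflow-probability instead.
    tfp_version = installed_version("tfp-nightly") or installed_version(
        "tensorflow-probability")
    print(f"jax={jax_version}, tfp={tfp_version}")
    if jax_version is not None and jax_needs_upgrade(jax_version):
        print("JAX predates jax.dtypes.prng_key: upgrade to jax>=0.4.14, "
              "or use tensorflow_probability 0.21.0 or older")
```

With the versions from the question's package list, jax_needs_upgrade("0.4.13") returns True, and parse_release("0.22.0.dev20231002") gives (0, 22, 0).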
