I am trying to run the following Python code from Ubuntu on WSL (Windows):
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions
dist = tfd.Normal(loc=0., scale=3.)
dist.cdf(1.)
dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])
dist.prob([0, 1.5])
dist.sample([3])
I get the following error:
> AttributeError Traceback (most recent call
> last) Cell In[45], line 4
> 2 tfd = tfp.distributions
> 3 dist = tfd.Normal(loc=0., scale=3.)
> ----> 4 dist.cdf(1.)
> 5 dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])
> 6 dist.prob([0, 1.5])
>
> File
> ~/.local/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:1429,
> in Distribution.cdf(self, value, name, **kwargs) 1411 def cdf(self,
> value, name='cdf', **kwargs): 1412 """Cumulative distribution
> function. 1413 1414 Given random variable `X`, the cumulative
> distribution function `cdf` is: (...) 1427 values of type
> `self.dtype`. 1428 """
> -> 1429 return self._call_cdf(value, name, **kwargs)
>
> File
> ~/.local/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:1405,
> in Distribution._call_cdf(self, value, name, **kwargs) 1403 with
> self._name_and_control_scope(name, value, kwargs): 1404 if
> hasattr(self, '_cdf'):
> -> 1405 return self._cdf(value, **kwargs) 1406 if hasattr(self, '_log_cdf'): 1407 return
> tf.exp(self._log_cdf(value, **kwargs))
>
> File
> ~/.local/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/normal.py:195,
> in Normal._cdf(self, x)
> 194 def _cdf(self, x):
> --> 195 return special_math.ndtr(self._z(x))
>
> File
> ~/.local/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/internal/special_math.py:136,
> in ndtr(x, name)
> 132 if dtype_util.as_numpy_dtype(x.dtype) not in [np.float32, np.float64]:
> 133 raise TypeError(
> 134 "x.dtype=%s is not handled, see docstring for supported types."
> 135 % x.dtype)
> --> 136 return _ndtr(x)
>
> File
> ~/.local/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/internal/special_math.py:141,
> in _ndtr(x)
> 139 def _ndtr(x):
> 140 """Implements ndtr core logic."""
> --> 141 half_sqrt_2 = tf.constant(
> 142 0.5 * np.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
> 143 half = tf.constant(0.5, x.dtype)
> 144 one = tf.constant(1., x.dtype)
>
> File
> ~/.local/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:117,
> in _constant(value, dtype, shape, name)
> 116 def _constant(value, dtype=None, shape=None, name='Const'): # pylint: disable=unused-argument
> --> 117 x = convert_to_tensor(value, dtype=dtype)
> 118 if shape is None:
> 119 return x
>
> File
> ~/.local/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:167,
> in _convert_to_tensor(value, dtype, dtype_hint, name)
> 164 pass
> 166 if ret is None:
> --> 167 ret = conversion_func(value, dtype=dtype)
> 168 return ret
>
> File
> ~/.local/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:222,
> in _default_convert_to_tensor(value, dtype)
> 218 """Default tensor conversion function for array, bool, int, float, and complex."""
> 219 if JAX_MODE:
> 220 # TODO(b/223267515): We shouldn't need to specialize here.
> 221 if hasattr(value, 'dtype') and jax.dtypes.issubdtype(
> --> 222 value.dtype, jax.dtypes.prng_key
> 223 ):
> 224 return value
> 225 if isinstance(value, (list, tuple)) and value:
>
> AttributeError: module 'jax.dtypes' has no attribute 'prng_key'
The installed packages are:
> Package Version
------------------------ --------------------
absl-py 2.0.0
Ambit-Stochastics 1.0.6
anyio 3.6.2
argon2-cffi 21.3.0
argon2-cffi-bindings 21.2.0
asttokens 2.2.1
attrs 19.3.0
Automat 0.8.0
Babel 2.11.0
backcall 0.2.0
beautifulsoup4 4.11.2
bleach 6.0.0
blinker 1.4
certifi 2019.11.28
cffi 1.15.1
chardet 3.0.4
charset-normalizer 3.0.1
Click 7.0
cloud-init 23.2.1
cloudpickle 2.2.1
colorama 0.4.3
comm 0.1.2
command-not-found 0.3
configobj 5.0.6
constantly 15.1.0
cryptography 2.8
cycler 0.11.0
dbus-python 1.2.16
debugpy 1.6.6
decorator 5.1.1
defusedxml 0.7.1
distlib 0.3.6
distro 1.4.0
distro-info 0.23ubuntu1
dm-tree 0.1.8
entrypoints 0.3
env 0.1.0
executing 1.2.0
fastjsonschema 2.16.2
filelock 3.9.0
gast 0.5.4
httplib2 0.14.0
hyperlink 19.0.0
idna 2.8
importlib-metadata 6.0.0
importlib-resources 5.10.2
incremental 16.10.1
ipykernel 6.20.2
ipython 8.9.0
ipython-genutils 0.2.0
ipywidgets 8.1.1
jax 0.4.13
jaxlib 0.4.13
jedi 0.18.2
Jinja2 3.1.2
json5 0.9.11
jsonpatch 1.22
jsonpointer 2.0
jsonschema 4.17.3
jupyter 1.0.0
jupyter-client 8.0.2
jupyter-console 6.6.3
jupyter-core 5.2.0
jupyter-events 0.6.3
jupyter-server 2.2.0
jupyter-server-terminals 0.4.4
jupyterlab 3.5.3
jupyterlab-pygments 0.2.2
jupyterlab-server 2.19.0
jupyterlab-widgets 3.0.9
keyring 18.0.1
kiwisolver 1.3.2
language-selector 0.1
launchpadlib 1.10.13
lazr.restfulclient 0.14.2
lazr.uri 1.0.3
MarkupSafe 2.1.2
matplotlib 3.4.3
matplotlib-inline 0.1.6
mistune 2.0.4
ml-dtypes 0.2.0
more-itertools 4.2.0
nbclassic 0.5.1
nbclient 0.7.2
nbconvert 7.2.9
nbformat 5.7.3
nest-asyncio 1.5.6
netifaces 0.10.4
notebook 6.5.2
notebook-shim 0.2.2
numpy 1.21.3
oauthlib 3.1.0
opt-einsum 3.3.0
packaging 23.0
pandas 1.3.4
pandocfilters 1.5.0
parso 0.8.3
pexpect 4.6.0
pickleshare 0.7.5
Pillow 8.4.0
pip 20.0.2
pkgutil-resolve-name 1.3.10
platformdirs 2.6.2
prometheus-client 0.16.0
prompt-toolkit 3.0.36
psutil 5.9.4
ptyprocess 0.7.0
pure-eval 0.2.2
pyasn1 0.4.2
pyasn1-modules 0.2.1
pycparser 2.21
Pygments 2.14.0
PyGObject 3.36.0
PyHamcrest 1.9.0
PyJWT 1.7.1
pymacaroons 0.13.0
PyNaCl 1.3.0
pyOpenSSL 19.0.0
pyparsing 3.0.4
pyrsistent 0.15.5
pyserial 3.4
python-apt 2.0.1+ubuntu0.20.4.1
python-dateutil 2.8.2
python-debian 0.1.36+ubuntu1.1
python-json-logger 2.0.4
pytz 2021.3
PyYAML 5.3.1
pyzmq 25.0.0
qtconsole 5.4.4
QtPy 2.4.0
requests 2.28.2
requests-unixsocket 0.2.0
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
scipy 1.10.1
SecretStorage 2.3.1
Send2Trash 1.8.0
service-identity 18.1.0
setuptools 45.2.0
simplejson 3.16.0
six 1.14.0
sniffio 1.3.0
sos 4.4
soupsieve 2.3.2.post1
ssh-import-id 5.10
stack-data 0.6.2
systemd-python 234
terminado 0.17.1
tfp-nightly 0.22.0.dev20231002
tinycss2 1.2.1
tomli 2.0.1
tornado 6.2
traitlets 5.9.0
Twisted 18.9.0
typing-extensions 4.5.0
ubuntu-advantage-tools 8001
ufw 0.36
unattended-upgrades 0.1
urllib3 1.25.8
virtualenv 20.17.1
wadllib 1.3.3
wcwidth 0.2.6
webencodings 0.5.1
websocket-client 1.5.0
wheel 0.34.2
widgetsnbextension 4.0.9
zipp 1.0.0
zope.interface 4.7.1
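If I change the Normal distribution to a Gamma distribution, I no longer get the error. Do you know why this is? The code snippet is taken directly from the TensorFlow website. Thanks!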
1 Answer
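jax.dtypes.prng_key was added in JAX 0.4.14, while the package list above shows jax 0.4.13 installed alongside tfp-nightly 0.22.0.dev20231002. You should update JAX to a newer version (0.4.14 or later) or, if that is not possible, downgrade tensorflow_probability to an older version (0.21.0 or older should be enough to resolve this issue).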
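To check which side of that cutoff an environment is on before importing tensorflow_probability, something like the following could be used. It is only a sketch: the helper names (version_tuple, jax_is_new_enough, installed_jax_version) are made up for illustration, the 0.4.14 threshold is taken from the answer above, and the version parsing is deliberately simple (it reads only the leading digits of the first three components).

```python
import re
from importlib import metadata

# Assumption from the answer above: jax.dtypes.prng_key first appeared in JAX 0.4.14.
MIN_JAX = (0, 4, 14)


def version_tuple(version):
    """Convert a version string such as '0.4.13' into a tuple of ints.

    Only the first three dot-separated components are used, and only the
    leading digits of each one, so suffixes like 'dev20231002' are ignored.
    """
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    return tuple(parts)


def jax_is_new_enough(version):
    """Return True if the given JAX version is at least MIN_JAX."""
    return version_tuple(version) >= MIN_JAX


def installed_jax_version():
    """Return the installed JAX version string, or None if JAX is not installed."""
    try:
        return metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None


found = installed_jax_version()
if found is None:
    print("jax is not installed in this environment")
elif jax_is_new_enough(found):
    print(f"jax {found} should provide jax.dtypes.prng_key")
else:
    print(f"jax {found} predates 0.4.14; upgrade jax or use tensorflow_probability <= 0.21.0")
```

With the versions listed in the question, jax_is_new_enough("0.4.13") evaluates to False, which matches the AttributeError in the traceback.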