Traceback (most recent call last): File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution executenb( File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute() File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped return just_run(coro(*args, **kwargs)) File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run return loop.run_until_complete(coro) File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete return future.result() File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute await self.async_execute_cell( File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell self._check_raise_for_error(cell, exec_reply) File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content']) nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell: ------------------ # meta-data does not work yet in VScode # https://github.com/microsoft/vscode-jupyter/issues/1121 { "tags": [ "hide-cell" ] } ### Install necessary libraries try: import jax except: # For cuda version, see https://github.com/google/jax#installation %pip install --upgrade "jax[cpu]" import jax try: import distrax except: %pip install --upgrade distrax import distrax try: import jsl except: %pip install git+https://github.com/probml/jsl import jsl try: import rich except: %pip install rich import rich ------------------ --------------------------------------------------------------------------- AttributeError Traceback (most recent call last)  in   20 try: ---> 21 import distrax  22 except: /opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in   19 from distrax._src.bijectors.bijector import BijectorLike ---> 20 from distrax._src.bijectors.block import Block  21 from distrax._src.bijectors.chain import Chain /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in   19 from distrax._src.bijectors import bijector as base ---> 20 from distrax._src.utils import conversion  21 from distrax._src.utils import math /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in   18  ---> 19 import chex  20 from distrax._src.bijectors import bijector /opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in   16  ---> 17 from chex._src.asserts import assert_axis_dimension  18 from chex._src.asserts import assert_axis_dimension_gt /opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in   27 from chex._src import asserts_internal as ai ---> 28 from chex._src import pytypes  29 import jax /opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in   38  ---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice  40 GpuDevice = jax.lib.xla_extension.GpuDevice AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice' During handling of the above exception, another exception occurred: AttributeError Traceback (most recent call last)  in   22 except:  23 get_ipython().run_line_magic('pip', 'install --upgrade distrax') ---> 24 import distrax  25   26 try: /opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in   18 from distrax._src.bijectors.bijector import Bijector  19 from distrax._src.bijectors.bijector import BijectorLike ---> 20 from distrax._src.bijectors.block import Block  21 from distrax._src.bijectors.chain import Chain  22 from distrax._src.bijectors.gumbel_cdf import GumbelCDF /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in   18   19 from distrax._src.bijectors import bijector as base ---> 20 from distrax._src.utils import conversion  21 from distrax._src.utils import math  22  /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in   17 from typing import Optional, Union  18  ---> 19 import chex  20 from distrax._src.bijectors import bijector  21 from distrax._src.bijectors import bijector_from_tfp /opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in   15 """Chex: Testing made fun, in JAX!"""  16  ---> 17 from chex._src.asserts import assert_axis_dimension  18 from chex._src.asserts import assert_axis_dimension_gt  19 from chex._src.asserts import assert_devices_available /opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in   26   27 from chex._src import asserts_internal as ai ---> 28 from chex._src import pytypes  29 import jax  30 import jax.numpy as jnp /opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in   37 Shape = Tuple[int, ...]  38  ---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice  40 GpuDevice = jax.lib.xla_extension.GpuDevice  41 TpuDevice = jax.lib.xla_extension.TpuDevice AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice' AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'