Traceback (most recent call last): File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution executenb( File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute() File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped return just_run(coro(*args, **kwargs)) File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run return loop.run_until_complete(coro) File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete return future.result() File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute await self.async_execute_cell( File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell self._check_raise_for_error(cell, exec_reply) File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content']) nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell: ------------------ ### Import standard libraries import abc from dataclasses import dataclass import functools from functools import partial import itertools import matplotlib.pyplot as plt import numpy as np from typing import Any, Callable, NamedTuple, Optional, Union, Tuple import jax import jax.numpy as jnp from jax import lax, vmap, jit, grad #from jax.scipy.special import logit #from jax.nn import softmax import jax.random as jr import distrax import optax import jsl import ssm_jax ------------------ --------------------------------------------------------------------------- AttributeError Traceback (most recent call last)  in   19   20  ---> 21 import distrax  22 import optax  23  /opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in   18 from distrax._src.bijectors.bijector import Bijector  19 from distrax._src.bijectors.bijector import BijectorLike ---> 20 from distrax._src.bijectors.block import Block  21 from distrax._src.bijectors.chain import Chain  22 from distrax._src.bijectors.gumbel_cdf import GumbelCDF /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in   18   19 from distrax._src.bijectors import bijector as base ---> 20 from distrax._src.utils import conversion  21 from distrax._src.utils import math  22  /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in   17 from typing import Optional, Union  18  ---> 19 import chex  20 from distrax._src.bijectors import bijector  21 from distrax._src.bijectors import bijector_from_tfp /opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in   15 """Chex: Testing made fun, in JAX!"""  16  ---> 17 from chex._src.asserts import assert_axis_dimension  18 from chex._src.asserts import assert_axis_dimension_gt  19 from chex._src.asserts import assert_devices_available /opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in   26   27 from chex._src import asserts_internal as ai ---> 28 from chex._src import pytypes  29 import jax  30 import jax.numpy as jnp /opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in   37 Shape = Tuple[int, ...]  38  ---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice  40 GpuDevice = jax.lib.xla_extension.GpuDevice  41 TpuDevice = jax.lib.xla_extension.TpuDevice AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice' AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'