Traceback (most recent call last): File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution executenb( File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute() File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped return just_run(coro(*args, **kwargs)) File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run return loop.run_until_complete(coro) File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete return future.result() File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute await self.async_execute_cell( File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell self._check_raise_for_error(cell, exec_reply) File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content']) nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell: ------------------ # state transition matrix A = np.array([ [0.95, 0.05], [0.10, 0.90] ]) # observation matrix B = np.array([ [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die ]) pi = np.array([0.5, 0.5]) (nstates, nobs) = np.shape(B) import distrax from distrax import HMM hmm = HMM(trans_dist=distrax.Categorical(probs=A), init_dist=distrax.Categorical(probs=pi), obs_dist=distrax.Categorical(probs=B)) seed = 314 n_samples = 300 z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples) ------------------ --------------------------------------------------------------------------- AttributeError Traceback (most recent call last)  in   15 (nstates, nobs) = np.shape(B)  16  ---> 17 import distrax  18 from distrax import HMM  19  /opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in   18 from distrax._src.bijectors.bijector import Bijector  19 from distrax._src.bijectors.bijector import BijectorLike ---> 20 from distrax._src.bijectors.block import Block  21 from distrax._src.bijectors.chain import Chain  22 from distrax._src.bijectors.gumbel_cdf import GumbelCDF /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in   18   19 from distrax._src.bijectors import bijector as base ---> 20 from distrax._src.utils import conversion  21 from distrax._src.utils import math  22  /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in   17 from typing import Optional, Union  18  ---> 19 import chex  20 from distrax._src.bijectors import bijector  21 from distrax._src.bijectors import bijector_from_tfp /opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in   15 """Chex: Testing made fun, in JAX!"""  16  ---> 17 from chex._src.asserts import assert_axis_dimension  18 from chex._src.asserts import assert_axis_dimension_gt  19 from chex._src.asserts import assert_devices_available /opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in   26   27 from chex._src import asserts_internal as ai ---> 28 from chex._src import pytypes  29 import jax  30 import jax.numpy as jnp /opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in   37 Shape = Tuple[int, ...]  38  ---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice  40 GpuDevice = jax.lib.xla_extension.GpuDevice  41 TpuDevice = jax.lib.xla_extension.TpuDevice AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice' AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'