inference.log 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. Traceback (most recent call last):
  2. File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
  3. executenb(
  4. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
  5. return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
  6. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
  7. return just_run(coro(*args, **kwargs))
  8. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
  9. return loop.run_until_complete(coro)
  10. File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
  11. return future.result()
  12. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
  13. await self.async_execute_cell(
  14. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
  15. self._check_raise_for_error(cell, exec_reply)
  16. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
  17. raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
  18. nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
  19. ------------------
  20. # state transition matrix
  21. A = np.array([
  22. [0.95, 0.05],
  23. [0.10, 0.90]
  24. ])
  25. # observation matrix
  26. B = np.array([
  27. [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
  28. [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
  29. ])
  30. pi = np.array([0.5, 0.5])
  31. (nstates, nobs) = np.shape(B)
  32. import distrax
  33. from distrax import HMM
  34. hmm = HMM(trans_dist=distrax.Categorical(probs=A),
  35. init_dist=distrax.Categorical(probs=pi),
  36. obs_dist=distrax.Categorical(probs=B))
  37. seed = 314
  38. n_samples = 300
  39. z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)
  40. ------------------
  41. ---------------------------------------------------------------------------
  42. AttributeError Traceback (most recent call last)
  43. <ipython-input-2-155367c5b347> in <module>
  44.  15 (nstates, nobs) = np.shape(B)
  45.  16 
  46. ---> 17 import distrax
  47.  18 from distrax import HMM
  48.  19 
  49. /opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in <module>
  50.  18 from distrax._src.bijectors.bijector import Bijector
  51.  19 from distrax._src.bijectors.bijector import BijectorLike
  52. ---> 20 from distrax._src.bijectors.block import Block
  53.  21 from distrax._src.bijectors.chain import Chain
  54.  22 from distrax._src.bijectors.gumbel_cdf import GumbelCDF
  55. /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in <module>
  56.  18 
  57.  19 from distrax._src.bijectors import bijector as base
  58. ---> 20 from distrax._src.utils import conversion
  59.  21 from distrax._src.utils import math
  60.  22 
  61. /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in <module>
  62.  17 from typing import Optional, Union
  63.  18 
  64. ---> 19 import chex
  65.  20 from distrax._src.bijectors import bijector
  66.  21 from distrax._src.bijectors import bijector_from_tfp
  67. /opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in <module>
  68.  15 """Chex: Testing made fun, in JAX!"""
  69.  16 
  70. ---> 17 from chex._src.asserts import assert_axis_dimension
  71.  18 from chex._src.asserts import assert_axis_dimension_gt
  72.  19 from chex._src.asserts import assert_devices_available
  73. /opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in <module>
  74.  26 
  75.  27 from chex._src import asserts_internal as ai
  76. ---> 28 from chex._src import pytypes
  77.  29 import jax
  78.  30 import jax.numpy as jnp
  79. /opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in <module>
  80.  37 Shape = Tuple[int, ...]
  81.  38 
  82. ---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice
  83.  40 GpuDevice = jax.lib.xla_extension.GpuDevice
  84.  41 TpuDevice = jax.lib.xla_extension.TpuDevice
  85. AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
  86. AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'