scratchpad.log 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. Traceback (most recent call last):
  2. File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
  3. executenb(
  4. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
  5. return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
  6. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
  7. return just_run(coro(*args, **kwargs))
  8. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
  9. return loop.run_until_complete(coro)
  10. File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
  11. return future.result()
  12. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
  13. await self.async_execute_cell(
  14. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
  15. self._check_raise_for_error(cell, exec_reply)
  16. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
  17. raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
  18. nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
  19. ------------------
  20. ### Import standard libraries
  21. import abc
  22. from dataclasses import dataclass
  23. import functools
  24. from functools import partial
  25. import itertools
  26. import matplotlib.pyplot as plt
  27. import numpy as np
  28. from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
  29. import jax
  30. import jax.numpy as jnp
  31. from jax import lax, vmap, jit, grad
  32. #from jax.scipy.special import logit
  33. #from jax.nn import softmax
  34. import jax.random as jr
  35. import distrax
  36. import optax
  37. import jsl
  38. import ssm_jax
  39. ------------------
  40. ---------------------------------------------------------------------------
  41. AttributeError Traceback (most recent call last)
  42. <ipython-input-1-00ce8083638c> in <module>
  43.  19 
  44.  20 
  45. ---> 21 import distrax
  46.  22 import optax
  47.  23 
  48. /opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in <module>
  49.  18 from distrax._src.bijectors.bijector import Bijector
  50.  19 from distrax._src.bijectors.bijector import BijectorLike
  51. ---> 20 from distrax._src.bijectors.block import Block
  52.  21 from distrax._src.bijectors.chain import Chain
  53.  22 from distrax._src.bijectors.gumbel_cdf import GumbelCDF
  54. /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in <module>
  55.  18 
  56.  19 from distrax._src.bijectors import bijector as base
  57. ---> 20 from distrax._src.utils import conversion
  58.  21 from distrax._src.utils import math
  59.  22 
  60. /opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in <module>
  61.  17 from typing import Optional, Union
  62.  18 
  63. ---> 19 import chex
  64.  20 from distrax._src.bijectors import bijector
  65.  21 from distrax._src.bijectors import bijector_from_tfp
  66. /opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in <module>
  67.  15 """Chex: Testing made fun, in JAX!"""
  68.  16 
  69. ---> 17 from chex._src.asserts import assert_axis_dimension
  70.  18 from chex._src.asserts import assert_axis_dimension_gt
  71.  19 from chex._src.asserts import assert_devices_available
  72. /opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in <module>
  73.  26 
  74.  27 from chex._src import asserts_internal as ai
  75. ---> 28 from chex._src import pytypes
  76.  29 import jax
  77.  30 import jax.numpy as jnp
  78. /opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in <module>
  79.  37 Shape = Tuple[int, ...]
  80.  38 
  81. ---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice
  82.  40 GpuDevice = jax.lib.xla_extension.GpuDevice
  83.  41 TpuDevice = jax.lib.xla_extension.TpuDevice
  84. AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
  85. AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'