hmm.log 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. Traceback (most recent call last):
  2. File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
  3. executenb(
  4. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
  5. return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
  6. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
  7. return just_run(coro(*args, **kwargs))
  8. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
  9. return loop.run_until_complete(coro)
  10. File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
  11. return future.result()
  12. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
  13. await self.async_execute_cell(
  14. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
  15. self._check_raise_for_error(cell, exec_reply)
  16. File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
  17. raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
  18. nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
  19. ------------------
  20. import collections
  21. def compute_counts(state_seq, nstates):
  22. wseq = np.array(state_seq)
  23. word_pairs = [pair for pair in zip(wseq[:-1], wseq[1:])]
  24. counter_pairs = collections.Counter(word_pairs)
  25. counts = np.zeros((nstates, nstates))
  26. for (k,v) in counter_pairs.items():
  27. counts[k[0], k[1]] = v
  28. return counts
  29. def normalize(u, axis=0, eps=1e-15):
  30. u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))
  31. c = u.sum(axis=axis)
  32. c = jnp.where(c == 0, 1, c)
  33. return u / c, c
  34. def normalize_counts(counts):
  35. ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts)
  36. return ncounts
  37. init_dist = jnp.array([1.0, 0.0])
  38. trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])
  39. obs_mat = jnp.eye(2)
  40. hmm = HMM(trans_dist=distrax.Categorical(probs=trans_mat),
  41. init_dist=distrax.Categorical(probs=init_dist),
  42. obs_dist=distrax.Categorical(probs=obs_mat))
  43. rng_key = jax.random.PRNGKey(0)
  44. seq_len = 500
  45. state_seq, _ = hmm.sample(seed=PRNGKey(seed), seq_len=seq_len)
  46. counts = compute_counts(state_seq, nstates=2)
  47. print(counts)
  48. trans_mat_empirical = normalize_counts(counts)
  49. print(trans_mat_empirical)
  50. assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1)
  51. ------------------
  52. ---------------------------------------------------------------------------
  53. NameError Traceback (most recent call last)
  54. <ipython-input-6-f054683fcd82> in <module>
  55.  30 rng_key = jax.random.PRNGKey(0)
  56.  31 seq_len = 500
  57. ---> 32 state_seq, _ = hmm.sample(seed=PRNGKey(seed), seq_len=seq_len)
  58.  33 
  59.  34 counts = compute_counts(state_seq, nstates=2)
  60. NameError: name 'PRNGKey' is not defined
  61. NameError: name 'PRNGKey' is not defined