12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- Traceback (most recent call last):
- File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
- executenb(
- File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
- return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
- File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
- return just_run(coro(*args, **kwargs))
- File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
- return loop.run_until_complete(coro)
- File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
- return future.result()
- File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
- await self.async_execute_cell(
- File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
- self._check_raise_for_error(cell, exec_reply)
- File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
- raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
- nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
- ------------------
- import collections
- def compute_counts(state_seq, nstates):
- wseq = np.array(state_seq)
- word_pairs = [pair for pair in zip(wseq[:-1], wseq[1:])]
- counter_pairs = collections.Counter(word_pairs)
- counts = np.zeros((nstates, nstates))
- for (k,v) in counter_pairs.items():
- counts[k[0], k[1]] = v
- return counts
- def normalize(u, axis=0, eps=1e-15):
- u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))
- c = u.sum(axis=axis)
- c = jnp.where(c == 0, 1, c)
- return u / c, c
- def normalize_counts(counts):
- ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts)
- return ncounts
- init_dist = jnp.array([1.0, 0.0])
- trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])
- obs_mat = jnp.eye(2)
- hmm = HMM(trans_dist=distrax.Categorical(probs=trans_mat),
- init_dist=distrax.Categorical(probs=init_dist),
- obs_dist=distrax.Categorical(probs=obs_mat))
- rng_key = jax.random.PRNGKey(0)
- seq_len = 500
- state_seq, _ = hmm.sample(seed=PRNGKey(seed), seq_len=seq_len)
- counts = compute_counts(state_seq, nstates=2)
- print(counts)
- trans_mat_empirical = normalize_counts(counts)
- print(trans_mat_empirical)
- assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1)
- ------------------
- [0;31m---------------------------------------------------------------------------[0m
- [0;31mNameError[0m Traceback (most recent call last)
- [0;32m<ipython-input-6-f054683fcd82>[0m in [0;36m<module>[0;34m[0m
- [1;32m 30[0m [0mrng_key[0m [0;34m=[0m [0mjax[0m[0;34m.[0m[0mrandom[0m[0;34m.[0m[0mPRNGKey[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
- [1;32m 31[0m [0mseq_len[0m [0;34m=[0m [0;36m500[0m[0;34m[0m[0;34m[0m[0m
- [0;32m---> 32[0;31m [0mstate_seq[0m[0;34m,[0m [0m_[0m [0;34m=[0m [0mhmm[0m[0;34m.[0m[0msample[0m[0;34m([0m[0mseed[0m[0;34m=[0m[0mPRNGKey[0m[0;34m([0m[0mseed[0m[0;34m)[0m[0;34m,[0m [0mseq_len[0m[0;34m=[0m[0mseq_len[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
- [0m[1;32m 33[0m [0;34m[0m[0m
- [1;32m 34[0m [0mcounts[0m [0;34m=[0m [0mcompute_counts[0m[0;34m([0m[0mstate_seq[0m[0;34m,[0m [0mnstates[0m[0;34m=[0m[0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
- [0;31mNameError[0m: name 'PRNGKey' is not defined
- NameError: name 'PRNGKey' is not defined
|