Traceback (most recent call last): File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution executenb( File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute() File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped return just_run(coro(*args, **kwargs)) File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run return loop.run_until_complete(coro) File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete return future.result() File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute await self.async_execute_cell( File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell self._check_raise_for_error(cell, exec_reply) File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content']) nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell: ------------------ import collections def compute_counts(state_seq, nstates): wseq = np.array(state_seq) word_pairs = [pair for pair in zip(wseq[:-1], wseq[1:])] counter_pairs = collections.Counter(word_pairs) counts = np.zeros((nstates, nstates)) for (k,v) in counter_pairs.items(): counts[k[0], k[1]] = v return counts def normalize(u, axis=0, eps=1e-15): u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u)) c = u.sum(axis=axis) c = jnp.where(c == 0, 1, c) return u / c, c def normalize_counts(counts): ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts) return ncounts init_dist = jnp.array([1.0, 0.0]) trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]]) obs_mat = jnp.eye(2) hmm = HMM(trans_dist=distrax.Categorical(probs=trans_mat), init_dist=distrax.Categorical(probs=init_dist), obs_dist=distrax.Categorical(probs=obs_mat)) rng_key = jax.random.PRNGKey(0) seq_len = 500 state_seq, _ = hmm.sample(seed=PRNGKey(seed), seq_len=seq_len) counts = compute_counts(state_seq, nstates=2) print(counts) trans_mat_empirical = normalize_counts(counts) print(trans_mat_empirical) assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1) ------------------ --------------------------------------------------------------------------- NameError Traceback (most recent call last)  in   30 rng_key = jax.random.PRNGKey(0)  31 seq_len = 500 ---> 32 state_seq, _ = hmm.sample(seed=PRNGKey(seed), seq_len=seq_len)  33   34 counts = compute_counts(state_seq, nstates=2) NameError: name 'PRNGKey' is not defined NameError: name 'PRNGKey' is not defined