Probably because JAX offers much more flexibility than Numba.
For instance, take the case of Scan: JAX has built-in scan-like functionality, but with Numba we have to write the loops manually.
We can use a code-generation approach in JAX too, but I think we have yet to run into logic that can't be implemented in JAX.
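To make the Scan comparison concrete: the JAX documentation describes the semantics of `jax.lax.scan` as a plain Python loop that threads a carry through a step function and stacks the per-step outputs. A loop of roughly this shape is what you would write by hand in a Numba-compiled function. The sketch below uses only NumPy. The name `scan_loop` and the cumulative-sum step are illustrative and are not part of JAX's or Numba's API.

```python
import numpy as np


def scan_loop(f, init, xs):
    """Reference semantics of a scan: thread a carry through f and stack per-step outputs.

    f takes (carry, x) and returns (new_carry, y), mirroring the jax.lax.scan convention.
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # one step: update the carry and emit an output
        ys.append(y)
    return carry, np.stack(ys)


def cumsum_step(carry, x):
    # Running total: the new carry and the emitted value are the same number.
    total = carry + x
    return total, total


final, partial_sums = scan_loop(cumsum_step, 0, np.array([1, 2, 3, 4]))
print(final)         # 10
print(partial_sums)  # [ 1  3  6 10]
```

`jax.lax.scan` takes the same `(f, init, xs)` arguments and returns the same kind of `(carry, stacked outputs)` pair. With Numba you would write and compile the explicit loop yourself.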
Scan … Ops … Ops …
Op.perform implementations will ultimately be able to serve as the only necessary Op implementation code
Op.perform … lists serve as "pointers" to memory that Theano manages manually
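Here is a minimal pure-Python sketch of the storage-cell convention behind `Op.perform`. It assumes Theano's documented signature `perform(node, inputs, output_storage)`, where each entry of `output_storage` is a one-element list (a "cell") that the op writes its result into. `AddOneOp` is hypothetical and no Theano import is used. The point is that whoever owns the cell keeps a reference to it, so the cell behaves like a pointer to memory that can be reused across calls.

```python
import numpy as np


class AddOneOp:
    """Hypothetical stand-in for a Theano Op; only mimics the perform() convention."""

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        cell = output_storage[0]  # single-element list owned by the caller
        out = cell[0]
        if out is not None and out.shape == x.shape and out.dtype == x.dtype:
            np.add(x, 1, out=out)  # reuse the buffer already sitting in the cell
        else:
            cell[0] = x + 1  # allocate a new array and store it in the cell


out_cell = [None]  # one storage cell for the single output
op = AddOneOp()

op.perform(None, [np.arange(3)], [out_cell])
first = out_cell[0]
print(first)  # [1 2 3]

op.perform(None, [np.arange(3) * 10], [out_cell])
print(out_cell[0])           # [ 1 11 21]
print(out_cell[0] is first)  # True: the same buffer was reused
```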
brandonwillard: I tried to follow your blog post on DLMs, but I didn't succeed in getting auto_updates to work with this minimal graph:
import numpy as np
import aesara
import aesara.tensor as at
rng = aesara.shared(np.random.default_rng(), borrow=True)
rng.tag.is_rng = True
rng.default_update = rng
def step(v_tm1, rng):
    v_t = at.random.normal(v_tm1, 1.0, name='v', rng=rng)
    return v_t
v, updates = aesara.scan(
    fn=step,
    outputs_info=[np.array(0.0)],
    non_sequences=[rng],
    n_steps=5,
    strict=True,
)
v_draws = aesara.function([], v, updates=updates)()
assert len(np.unique(np.diff(v_draws))) > 1  # fails: all values have the same offset
Anything obvious I'm missing?
for key, value in updates.items():
    key.default_update = value  # use each scan update as that shared variable's default update
:point_up: Edit: brandonwillard: I tried to follow your blog post on DLMs, but I didn't succeed in getting auto_updates to work with this minimal graph:
import numpy as np
import aesara
import aesara.tensor as at
rng = aesara.shared(np.random.default_rng(), borrow=True)
rng.tag.is_rng = True
rng.default_update = rng
def step(v_tm1, rng):
    v_t = at.random.normal(v_tm1, 1.0, name='v', rng=rng)
    return v_t
v, updates = aesara.scan(
    fn=step,
    outputs_info=[np.array(0.0)],
    non_sequences=[rng],
    n_steps=5,
    strict=True,
)
for key, value in updates.items():
    key.default_update = value
v_draws = aesara.function([], v, updates=updates)()
print(np.diff(v_draws))
Anything obvious I'm missing?
:point_up: Edit: brandonwillard: I tried to follow your blog post on DLMs, but I didn't succeed in getting auto_updates to work with this minimal graph:
import numpy as np
import aesara
import aesara.tensor as at
rng = aesara.shared(np.random.default_rng(), borrow=True)
rng.tag.is_rng = True
rng.default_update = rng
def step(v_tm1, rng):
    v_t = at.random.normal(v_tm1, 1.0, name='v', rng=rng)
    return v_t
v, updates = aesara.scan(
    fn=step,
    outputs_info=[np.array(0.0)],
    non_sequences=[rng],
    n_steps=5,
    strict=True,
)
for key, value in updates.items():
    key.default_update = value
v_draws = aesara.function([], v, updates=updates)()
print(np.diff(v_draws))  # [0.57665853 0.57665853 0.57665853 0.57665853]
Anything obvious I'm missing?
… non_sequences with the RNG, right?
:point_up: Edit: brandonwillard: I tried to follow your blog post on DLMs, but I didn't succeed in getting auto_updates to work with this minimal graph:
import numpy as np
import aesara
import aesara.tensor as at
rng = aesara.shared(np.random.default_rng(), borrow=True)
rng.tag.is_rng = True
rng.default_update = rng
def step(v_tm1, rng):
    v_t = at.random.normal(v_tm1, 1.0, name='v', rng=rng)
    return v_t
v, updates = aesara.scan(
    fn=step,
    outputs_info=[np.array(0.0)],
    non_sequences=[rng],
    n_steps=5,
    strict=True,
)
for key, value in updates.items():
    key.default_update = value
v_draws = aesara.function([], v, updates=updates)()
print(np.diff(v_draws))  # [0.57665853 0.57665853 0.57665853 0.57665853]
Anything obvious I'm missing?
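The printout, where every increment equals 0.57665853, is what you get when each scan step draws from the same RNG state and that state never advances. Below is a NumPy-only sketch of that failure mode. It assumes that the RNG state read by the inner random variable is never advanced between steps (for example, because its update points back to the RNG itself, as in `rng.default_update = rng`). The helper `random_walk` and its `advance_rng` flag are hypothetical, and this is not Aesara code.

```python
import copy

import numpy as np


def random_walk(n_steps, seed, advance_rng):
    """Gaussian random walk v_t ~ Normal(v_{t-1}, 1).

    If advance_rng is False, every step draws from a fresh copy of the *initial*
    generator state, mimicking an RNG whose update is a no-op, so each step
    adds the same increment.
    """
    rng = np.random.default_rng(seed)
    v, draws = 0.0, []
    for _ in range(n_steps):
        step_rng = rng if advance_rng else copy.deepcopy(rng)
        v = step_rng.normal(v, 1.0)  # loc = previous value, scale = 1
        draws.append(v)
    return np.array(draws)


stuck = random_walk(5, seed=123, advance_rng=False)
moving = random_walk(5, seed=123, advance_rng=True)
print(np.diff(stuck))   # four (numerically) identical increments
print(np.diff(moving))  # four different increments
```

In Aesara terms, the analogous fix would presumably be to update the shared RNG at each step with the new RNG state produced by the random variable, rather than with the RNG itself.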