    Kaustubh
    @kc611:matrix.org
    [m]
    The Jax file structure is pretty much the same, except for the dispatch being in the same file. We should probably split it up like Numba's.
    remilouf
    @remilouf:matrix.org
    [m]
    Thanks! Oh, and 3: why is the code generation approach used with Numba and not (that I saw) with Jax?
    Kaustubh
    @kc611:matrix.org
    [m]

    Probably because Jax offers much more flexibility than Numba.

    For instance, take the case of Scan. Jax has built-in scan-like functionality, but in the case of Numba we have to create the loops manually.

    We can use the code generation approach in Jax too, but I think we've yet to run into logic that cannot be implemented in Jax.
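
    A minimal sketch of that difference (assuming jax.lax.scan on the Jax side and a hand-written loop compiled with numba.njit on the Numba side; the cumulative-sum step is only an illustration):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from numba import njit
    
    # Jax: the loop is expressed through the built-in scan primitive.
    def step(carry, x):
        carry = carry + x
        return carry, carry
    
    _, jax_out = jax.lax.scan(step, jnp.float32(0.0), jnp.arange(5, dtype=jnp.float32))
    
    # Numba: the equivalent loop has to be written (or generated) explicitly.
    @njit
    def numba_scan(xs):
        out = np.empty_like(xs)
        carry = 0.0
        for i in range(xs.shape[0]):
            carry = carry + xs[i]
            out[i] = carry
        return out
    
    numba_out = numba_scan(np.arange(5, dtype=np.float64))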

    remilouf
    @remilouf:matrix.org
    [m]
    That makes sense
    Kaustubh
    @kc611:matrix.org
    [m]
    Regarding debugging, I don't know what the usual way to go about it is. I just put a breakpoint in compile_function_src, which gives me the source string for the required function.
    That function is located in link/utils.py, I think.
    The problem is that some of the functions aren't generated from strings while others are, so it's pretty hard to get uniform debugging functionality for them.
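
    Roughly, the string-based approach builds the function's source as text and compiles it with exec(), so pausing where the string gets compiled (e.g. in compile_function_src) lets you read the generated code. A generic sketch of the pattern (compile_from_src below is a stand-in, not Aesara's actual helper):

    # Stand-in for string-based code generation: the function body is
    # assembled as text and then turned into a real function with exec().
    def compile_from_src(src, name):
        # A breakpoint (or print) here exposes the generated source string.
        namespace = {}
        exec(compile(src, "<generated>", "exec"), namespace)
        return namespace[name]
    
    src = """
    def add_op(x, y):
        return x + y
    """
    
    add_op = compile_from_src(src, "add_op")
    print(add_op(2, 3))  # 5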
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    and both are technically taking the exact same approach
    the difference is the interface
    at least in the case of Scan
    remilouf
    @remilouf:matrix.org
    [m]
    I think he meant most operations can be expressed as a function
    Kaustubh
    @kc611:matrix.org
    [m]
    Yeah exactly.
    remilouf
    @remilouf:matrix.org
    [m]
    including looping / branching logic
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    that has more to do with JAX's limitations (or scope/purpose), though
    remilouf
    @remilouf:matrix.org
    [m]
    true
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    JAX serves the same set of operations as Aesara
    it's a tensor library
    Numba is much more general
    attempting to cover nearly all of Python
    but, yeah, I see what you're saying
    it's a limitation in terms of mapping simplicity and the like
    or something of that nature
    but Numba's generality is what will ultimately make it considerably easier to write high performance Ops
    e.g. like writing standard Python Ops
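
    For example, a kernel written as ordinary Python (loops, branching, NumPy indexing) can be compiled as-is with numba.njit; clip_and_sum here is only an illustration, not an existing Op:

    import numpy as np
    from numba import njit
    
    @njit
    def clip_and_sum(x, lo, hi):
        # Plain Python control flow; Numba compiles it directly to machine code.
        total = 0.0
        for i in range(x.shape[0]):
            v = x[i]
            if v < lo:
                v = lo
            elif v > hi:
                v = hi
            total += v
        return total
    
    print(clip_and_sum(np.random.normal(size=10), -1.0, 1.0))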
    Kaustubh
    @kc611:matrix.org
    [m]
    You mean lowering of those Ops?
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    I'm saying that Op.perform implementations will ultimately be able to serve as the only necessary Op implementation code
    we might need/want to refactor the interface a bit, and the Python needs to be Numba compatible
    but that's the idea
    remilouf
    @remilouf:matrix.org
    [m]
    what interface are you talking about?
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    the Op interface
    like Op.perform
    what it takes as arguments and what it returns/does
    the current design is built around Theano's old ideas
    e.g. that lists serve as "pointers" to memory that Theano manages manually
    it's neat, but not particularly good
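
    As a sketch of that interface: a minimal Op whose only implementation is perform, using the output_storage lists as the "pointers" just described (this follows the documented custom-Op pattern; DoubleOp is only an illustration):

    import numpy as np
    import aesara.tensor as at
    from aesara.graph.basic import Apply
    from aesara.graph.op import Op
    
    class DoubleOp(Op):
        """An Op implemented solely through a pure-Python perform."""
    
        def make_node(self, x):
            x = at.as_tensor_variable(x)
            return Apply(self, [x], [x.type()])
    
        def perform(self, node, inputs, output_storage):
            (x,) = inputs
            # output_storage is a list of one-element lists that act as
            # "pointers" to the memory the backend manages.
            output_storage[0][0] = np.asarray(2 * x)
    
    y = DoubleOp()(at.vector("x"))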
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    Does anyone have access to a Windows machine to see if this issue replicates? aesara-devs/aesara#707
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    I can try to spin up a VM
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]

    brandonwillard: I tried to follow your blog post on DLMs, but I didn't succeed in getting auto_updates to work with this minimal graph:

    import numpy as np
    import aesara
    import aesara.tensor as at
    
    rng = aesara.shared(np.random.default_rng(), borrow=True)
    rng.tag.is_rng = True
    rng.default_update = rng
    
    def step(v_tm1, rng):
        v_t = at.random.normal(v_tm1, 1.0, name='v', rng=rng)
        return v_t
    
    v, updates = aesara.scan(
        fn=step,
        outputs_info=[np.array(0.0)],
        non_sequences=[rng],
        n_steps=5,
        strict=True,
    )
    
    v_draws = aesara.function([], v, updates=updates)()
    assert len(np.unique(np.diff(v_draws))) > 1  # Fails: all values have the same offset

    Anything obvious I am missing?

    brandonwillard
    @brandonwillard:matrix.org
    [m]
    I don't think the old Theano was necessarily using the auto updates in that case
    I believe it was in-placing the shared RNG objects
    if you do something like the following, it should work, though:
    for key, value in updates.items():
        key.default_update = value
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]

    Edit: brandonwillard: I tried to follow your blog post on DLMs, but I didn't succeed in getting auto_updates to work with this minimal graph:

    import numpy as np
    import aesara
    import aesara.tensor as at
    
    rng = aesara.shared(np.random.default_rng(), borrow=True)
    rng.tag.is_rng = True
    rng.default_update = rng
    
    def step(v_tm1, rng):
        v_t = at.random.normal(v_tm1, 1.0, name='v', rng=rng)
        return v_t
    
    v, updates = aesara.scan(
        fn=step,
        outputs_info=[np.array(0.0)],
        non_sequences=[rng],
        n_steps=5,
        strict=True,
    )
    
    for key, value in updates.items():
        key.default_update = value
    
    v_draws = aesara.function([], v, updates=updates)()
    print(np.diff(v_draws))  # [0.57665853 0.57665853 0.57665853 0.57665853]

    Anything obvious I am missing?

    Still not working. I updated the example
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    what exactly isn't working?
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    All the samples are the same (constant offset from v_tm1)
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    ohh
    using non_sequences with the RNG, right?
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]

    Not supposed to do that?