    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    The second one. So that we can just write at.sqrt(x + at.pi) or at.switch(..., ..., at.inf)
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    ok
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    They only have a handful of constants, ignoring their own aliases: https://numpy.org/doc/stable/reference/constants.html
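For reference, the two Aesara expressions proposed above have direct NumPy counterparts. This minimal sketch uses NumPy's module-level constants, with np.where standing in for at.switch. The `constants` dict is only for illustration, and aliases such as np.PINF are left out.

```python
import numpy as np

# NumPy's main module-level constants (aliases omitted).
constants = {
    "pi": np.pi,
    "e": np.e,
    "euler_gamma": np.euler_gamma,
    "inf": np.inf,
    "nan": np.nan,
}

x = np.array([0.0, 1.0, -1.0])

# NumPy analogue of at.sqrt(x + at.pi)
shifted_root = np.sqrt(x + np.pi)

# NumPy analogue of at.switch(x >= 0, x, at.inf)
clipped = np.where(x >= 0, x, np.inf)
```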
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    Next aesara-family library idea aevmap?
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    what does the vmap represent?
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    JAX's vmap behavior
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    what part of it?
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    Vectorization of "arbitrary" subgraphs. I have tried this in the past to vectorize a model logp, and the best I could come up with was OpFromGraph + Scan
    Which for a small model was no better than a Python list comprehension
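This rough illustration shows the gap between a Python-level loop and a batched computation. It uses a hypothetical toy model, a unit-variance normal likelihood, evaluated for several values of mu. The first version uses a list comprehension. The second uses hand-written broadcasting. Ideally, a vmap-like transform would derive the second form automatically from the per-value function.

```python
import numpy as np


def normal_logp(mu, data):
    """Log-likelihood of `data` under Normal(mu, 1), for a scalar mu (toy model)."""
    return np.sum(-0.5 * (data - mu) ** 2 - 0.5 * np.log(2 * np.pi))


rng = np.random.default_rng(0)
data = rng.normal(size=50)
mus = np.linspace(-1.0, 1.0, 5)

# Baseline: one Python-level call per parameter value.
looped = np.array([normal_logp(mu, data) for mu in mus])

# Batched: add a leading batch axis for mu and reduce over the data axis only.
batched = np.sum(
    -0.5 * (data[None, :] - mus[:, None]) ** 2 - 0.5 * np.log(2 * np.pi),
    axis=1,
)
```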
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    this all depends on exactly what vectorization means in this context
    for instance, if it's more directly aligned with NumPy-like ufuncs, then that's already covered by Aesara's Composite Op
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    That would be more in line with numpy.vectorize (although that one does nothing clever), I think?
    That's what Composite is more like right? Builds on top of scalar operators
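Without a signature, numpy.vectorize behaves like an elementwise operation over a scalar function, with full broadcasting. That makes it the closest NumPy analogue to the Elemwise/Composite behavior discussed here. Its documentation notes it is provided for convenience rather than performance, since it is essentially a Python for-loop. In this sketch, `soft_threshold` is a hypothetical scalar-only function.

```python
import numpy as np


def soft_threshold(x, t):
    """Scalar-only function: uses Python control flow, so it is not a ufunc."""
    if x > t:
        return x - t
    if x < -t:
        return x + t
    return 0.0


# Without a signature, np.vectorize applies the scalar function elementwise,
# broadcasting the inputs against each other (via a Python-level loop).
vsoft = np.vectorize(soft_threshold, otypes=[float])

x = np.array([[-2.0], [0.5], [3.0]])  # shape (3, 1)
t = np.array([0.0, 1.0])              # shape (2,)
out = vsoft(x, t)                     # broadcast to shape (3, 2)
```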
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    yeah, this is all very overlapping functionality
    same with Elemwise
    but, yes, there is no numpy.vectorize-like helper function in Aesara
    and I don't recall jax.vmap doing anything particularly special
    it seemed quite literally like a numpy.vectorize clone
    the effective difference being that the end result is a JITed function
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    I think one big difference is that numpy.vectorize requires a base function that works only with scalars, whereas vmap can be built on top of tensor functions?
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    I think numpy.vectorize can handle non-scalar functions
    that's why you can specify a signature with explicit dimensions for each input
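The point about signatures can be illustrated directly. With a signature, np.vectorize takes a core function over arrays rather than scalars, and it loops only over the extra leading dimensions. In this minimal sketch, `normalize` is a hypothetical core function.

```python
import numpy as np


def normalize(v):
    """Core function on a 1-D vector: rescale it so that it sums to one."""
    return v / v.sum()


# "(n)->(n)": one core dimension in, the same core dimension out.
# Any leading dimensions are treated as loop (batch) dimensions.
vnormalize = np.vectorize(normalize, signature="(n)->(n)")

batch = np.arange(1.0, 7.0).reshape(2, 3)  # two vectors of length 3
out = vnormalize(batch)
```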
    yeah, jax.vmap looks like a clone of numpy.vectorize
    with a somewhat different interface
    e.g. specified in terms of numbers of axes
    or labeled axes?
    yeah, it looks like their explicit version of numpy.vectorize just constructs vmap calls
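The axis-based interface can be mimicked in NumPy. Below is a minimal loop-based sketch of vmap's `in_axes` semantics. The helper `simple_vmap` is hypothetical. JAX itself does not loop in Python: it transforms the traced function using per-primitive batching rules. Like JAX's default `out_axes=0`, this sketch stacks results along a new leading axis.

```python
import numpy as np


def simple_vmap(f, in_axes=0):
    """Hypothetical vmap-like helper that maps `f` over one axis per argument.

    `in_axes` is either an int shared by all arguments or a tuple/list with one
    entry per argument; an entry of None means the argument is not mapped.
    """

    def mapped(*args):
        if isinstance(in_axes, (list, tuple)):
            axes = list(in_axes)
        else:
            axes = [in_axes] * len(args)
        sizes = {np.shape(a)[ax] for a, ax in zip(args, axes) if ax is not None}
        if len(sizes) != 1:
            raise ValueError("need at least one mapped argument, all of the same size")
        (size,) = sizes
        # Move each mapped axis to the front so that a[i] selects the i-th slice.
        moved = [a if ax is None else np.moveaxis(a, ax, 0) for a, ax in zip(args, axes)]
        results = [
            f(*[a if ax is None else a[i] for a, ax in zip(moved, axes)])
            for i in range(size)
        ]
        # Stack per-slice results along a new leading axis.
        return np.stack(results)

    return mapped


# Map np.dot over the columns of x while holding w fixed.
x = np.arange(6.0).reshape(2, 3)
w = np.array([1.0, 10.0])
out = simple_vmap(np.dot, in_axes=(1, None))(x, w)  # array([30., 41., 52.])
```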
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    anyway, we could provide a numpy.vectorize helper function
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    TIL about the signature keyword:
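    import numpy as np
    import scipy.stats as st
    
    def dirichlet_logp(a, x):
        # Core function: one concentration vector `a`, one point `x` on the simplex
        return st.dirichlet(a).logpdf(x)
    
    # "(n),(n)->()": two length-n vectors in, one scalar out; leading dims are looped over
    vfunc = np.vectorize(dirichlet_logp, signature='(n),(n)->()')
    vfunc(np.arange(1, 10).reshape(3, 3), np.ones(3)/3)  # shape (3,): one logp per row of `a`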
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    yeah, it's pretty useful
    this kind of logic is essentially what we're doing in the RandomVariable class interface
    e.g. with ndim_supp and ndims_params
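The correspondence with RandomVariable's metadata can be made concrete. `ndims_params` gives the number of core dimensions of each parameter, and `ndim_supp` gives the number of core dimensions of a single draw. That is roughly the information a gufunc-style signature carries. The helper below is hypothetical and not part of Aesara. It names every core dimension independently, so it cannot express that, for example, a Dirichlet draw has the same length as its concentration vector.

```python
def rv_signature(ndims_params, ndim_supp):
    """Hypothetical helper: build a gufunc-style signature string from
    RandomVariable-like metadata (core ndim per parameter, core ndim of a draw).

    Each core dimension gets a fresh name, so shared lengths are not encoded.
    """
    letters = iter("abcdefghijklmnopqrstuvwxyz")

    def core(ndim):
        return "(" + ",".join(next(letters) for _ in range(ndim)) + ")"

    inputs = ",".join(core(nd) for nd in ndims_params)
    return f"{inputs}->{core(ndim_supp)}"


normal_sig = rv_signature([0, 0], 0)  # scalar mu and sigma, scalar draw
dirichlet_sig = rv_signature([1], 1)  # vector alpha, vector draw
```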
    and Elemwise is constrained to scalar functions, of course
    (same with Composite, since it's just an optimization for Elemwise)
    Ricardo Vieira
    @ricardov94:matrix.org
    [m]
    And you think we could do something similar in Aesara, without having to resort to Scan loops internally?
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    yeah, we could have a generalized Elemwise
    but the implementation/computation details are important
    remilouf
    @remilouf:matrix.org
    [m]
    This would be very useful for instance to advance several chains in aehmc
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    I think there are directly analogous Numba vectorizations
    so that backend might be covered easily
    and, obviously, we can map to jax.vmap for the JAX backend
    and numpy.vectorize/for-loops in the Python case
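For the Python case, a generalized Elemwise could be evaluated with plain for-loops over the broadcast batch dimensions. This is a minimal sketch. The names `gufunc_loop` and `core_ndims` are hypothetical. It assumes `core_fn` returns the same shape for every batch element, and it does not handle empty batches.

```python
import numpy as np


def gufunc_loop(core_fn, core_ndims, *args):
    """Hypothetical generalized-Elemwise evaluator using plain for-loops.

    core_ndims[i] is the number of trailing (core) dimensions of args[i];
    all leading (batch) dimensions are broadcast together and looped over.
    Returns None if the broadcast batch is empty (not handled in this sketch).
    """
    args = [np.asarray(a) for a in args]
    batch_shapes = [a.shape[: a.ndim - nd] for a, nd in zip(args, core_ndims)]
    batch_shape = np.broadcast_shapes(*batch_shapes)
    # Broadcast only the batch dimensions, leaving each core shape intact.
    args = [
        np.broadcast_to(a, batch_shape + a.shape[a.ndim - nd:])
        for a, nd in zip(args, core_ndims)
    ]
    out = None
    for idx in np.ndindex(*batch_shape):
        res = np.asarray(core_fn(*[a[idx] for a in args]))
        if out is None:
            # Allocate the output once the core output shape is known.
            out = np.empty(batch_shape + res.shape, dtype=res.dtype)
        out[idx] = res
    return out


# Quadratic form x^T A x: A is a shared matrix, xs is a batch of 4 vectors.
A = np.array([[2.0, 0.0], [0.0, 3.0]])
xs = np.arange(8.0).reshape(4, 2)
out = gufunc_loop(lambda a, x: x @ a @ x, (2, 1), A, xs)
```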
    I just don't want to deal with an implementation in the current C backend
    although that would only involve some straightforward for-loops as well
    per usual, I've been avoiding these kinds of things until we replace/remove the C backend
    brandonwillard
    @brandonwillard:matrix.org
    [m]
    anyway, as a basic graph abstraction it's a good thing
    it's a generalization that has been motivated a few times by our RandomVariable work
    I think we're actually talking about gufuncs: https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html#
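NumPy exposes the signatures of its built-in gufuncs, which shows the kind of information such an Op would carry. For example, np.matmul is a gufunc, while ordinary ufuncs such as np.add report no signature:

```python
import numpy as np

# np.matmul is a NumPy generalized ufunc: its signature names the core
# dimensions, and any leading dimensions are broadcast as loop dimensions.
sig = np.matmul.signature  # '(n?,k),(k,m?)->(n?,m?)'
plain = np.add.signature   # None: an ordinary elementwise (scalar-core) ufunc

a = np.ones((5, 2, 3))  # batch of 5 matrices, each 2x3
b = np.ones((3, 4))     # single 3x4 matrix, broadcast over the batch
c = np.matmul(a, b)     # shape (5, 2, 4)
```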