Ricardo Vieira
@ricardov94:matrix.org
[m]
Trying to wrap my mind around FunctionGraph.replace
import aesara
import aesara.tensor as at

# Fine
x = at.scalar('x')
y = at.exp(x)
y.name = 'exp(x)'
fg = aesara.graph.FunctionGraph(outputs=[x], clone=False)
fg.replace(x, y)
print('1')
aesara.dprint(fg)

# Not fine
x = at.scalar('x')
y = at.exp(x)
y.name = 'exp(x)'
fg = aesara.graph.FunctionGraph(outputs=[y], clone=False)
fg.replace(x, y)
print('\n2')
aesara.dprint(fg)

# Even less fine
x = at.scalar('x')
y = at.exp(x)
y.name = 'exp(x)'
fg = aesara.graph.FunctionGraph(outputs=[x, y], clone=False)
fg.replace(x, y)
print('\n3')
aesara.dprint(fg)
Why do the last two examples generate an invalid (recursive) graph, but not the first?
brandonwillard
@brandonwillard:matrix.org
[m]
what are the results and why are they invalid?
it sounds like you're running into something involving capture-avoiding substitutions or a similar topic
Ricardo Vieira
@ricardov94:matrix.org
[m]
The last two graphs are recursive: output.owner.inputs[0] is output
brandonwillard
@brandonwillard:matrix.org
[m]
it might help to start with this question: what do you expect from such a substitution?
Ricardo Vieira
@ricardov94:matrix.org
[m]
I am not sure. Maybe to replace x by y once, so x -> exp(x) and exp(x) -> exp(exp(x))?
brandonwillard
@brandonwillard:matrix.org
[m]
(remember that a FunctionGraph is literally a function/lambda)
brandonwillard
@brandonwillard:matrix.org
[m]
anyway, it might not be exactly that
because your substitution is simply recursive
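To make the two readings concrete, here is a minimal pure-Python sketch (not Aesara code) that represents terms as nested tuples such as ("exp", "x"). A one-shot substitution rebuilds the term and never looks inside the replacement it inserts, so x := exp(x) applied to exp(x) gives exp(exp(x)). Treating the same rule as a rewrite to apply until nothing changes never terminates, which is one way to read "recursive" here. The tuple encoding and the names `substitute_once` and `show` are illustrative assumptions.

```python
from typing import Union

# A term is either a variable name (str) or a tuple (op_name, *args).
Term = Union[str, tuple]


def substitute_once(term: Term, var: str, replacement: Term) -> Term:
    """Replace every occurrence of `var` in `term` with `replacement`.

    The replacement is inserted as-is and never re-scanned, so each
    occurrence is substituted exactly once.
    """
    if isinstance(term, str):
        return replacement if term == var else term
    op, *args = term
    return (op, *(substitute_once(arg, var, replacement) for arg in args))


def show(term: Term) -> str:
    """Render a term as a call string, e.g. exp(exp(x))."""
    if isinstance(term, str):
        return term
    op, *args = term
    return f"{op}({', '.join(show(arg) for arg in args)})"


x = "x"
y = ("exp", x)

# One-shot substitution x := exp(x) applied to exp(x)
print(show(substitute_once(y, x, y)))  # exp(exp(x))

# Re-applying the rule never reaches a fixed point: each pass makes the term deeper
term = y
for step in range(3):
    term = substitute_once(term, x, y)
    print(step, show(term))
```

Because the original term is never mutated, a cycle cannot arise in this style of substitution; the price is that a new term is built on every call.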
Ricardo Vieira
@ricardov94:matrix.org
[m]
What would you expect from the 3 cases above?
brandonwillard
@brandonwillard:matrix.org
[m]
I would expect it to fail
Ricardo Vieira
@ricardov94:matrix.org
[m]
Yeah, it does. And why is the first case different? Is the introduced x treated differently than the original one?
brandonwillard
@brandonwillard:matrix.org
[m]
because x is an input to the FunctionGraph, and the checks surrounding that are fairly strict
the first case is probably not treating the x as an input that's distinct from an output
Ricardo Vieira
@ricardov94:matrix.org
[m]
It's also an input in the other cases. Is the difference that it is also an output?
brandonwillard
@brandonwillard:matrix.org
[m]
i.e. it might be treating it exclusively as an output
so the checks are different
Ricardo Vieira
@ricardov94:matrix.org
[m]
I see. I should try the example with 2 intermediate variables then
The other thing is that I struggle to come up with non-recursive replacements. Unless all the inputs are new, won't the replacement variable always contain the variable it is trying to replace as one of its inputs (maybe with some extra nodes in between)?
Or, in other words, what is a valid replacement in the original graph, which is just exp(x)?
brandonwillard
@brandonwillard:matrix.org
[m]
valid substitutions (in the untyped lambda calculus) usually have a formal rename rule as part of their definition
e.g. (λy.P)[x := N] == λz.(P'[x := N]), where P' is P with y renamed to a fresh z (one that is not free in N or P)
brandonwillard
@brandonwillard:matrix.org
[m]
and λy.P is alpha-equivalent to λz.P'
otherwise a naive substitution like (λy.yx)[x := xy] would have the free y in xy bound/captured by the λy
brandonwillard
@brandonwillard:matrix.org
[m]
that's the relevant high-level concept here
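For reference, here is a small sketch of capture-avoiding substitution for the untyped lambda calculus, following the rename rule quoted above: when substituting under a binder λy whose y occurs free in N, y is first renamed to a fresh z. The term classes, the `subst` helper, and the z0, z1, ... fresh-name scheme are illustrative assumptions and have nothing to do with Aesara's internals.

```python
from dataclasses import dataclass
from itertools import count
from typing import Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class App:
    fn: "Term"
    arg: "Term"


@dataclass(frozen=True)
class Lam:
    param: str
    body: "Term"


Term = Union[Var, App, Lam]


def free_vars(t: Term) -> set:
    """Names that occur free in t."""
    if isinstance(t, Var):
        return {t.name}
    if isinstance(t, App):
        return free_vars(t.fn) | free_vars(t.arg)
    return free_vars(t.body) - {t.param}


def fresh(avoid: set) -> str:
    """Pick the first name z0, z1, ... that is not in `avoid`."""
    for i in count():
        name = f"z{i}"
        if name not in avoid:
            return name


def subst(t: Term, x: str, n: Term) -> Term:
    """Compute t[x := n] without capturing free variables of n."""
    if isinstance(t, Var):
        return n if t.name == x else t
    if isinstance(t, App):
        return App(subst(t.fn, x, n), subst(t.arg, x, n))
    if t.param == x:
        # x is rebound here, so there is nothing to substitute inside
        return t
    if t.param in free_vars(n) and x in free_vars(t.body):
        # Alpha-convert: rename the binder to a fresh z before substituting
        z = fresh(free_vars(n) | free_vars(t.body) | {x})
        renamed_body = subst(t.body, t.param, Var(z))
        return Lam(z, subst(renamed_body, x, n))
    return Lam(t.param, subst(t.body, x, n))


def show(t: Term) -> str:
    if isinstance(t, Var):
        return t.name
    if isinstance(t, App):
        return f"({show(t.fn)} {show(t.arg)})"
    return f"(λ{t.param}.{show(t.body)})"


# (λy.yx)[x := xy]: the free y in xy must not be captured by λy
term = Lam("y", App(Var("y"), Var("x")))
result = subst(term, "x", App(Var("x"), Var("y")))
print(show(result))  # (λz0.(z0 (x y)))
```

A naive version without the rename branch would return λy.y(xy), where the y that was free in xy has been captured by the binder.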
Ricardo Vieira
@ricardov94:matrix.org
[m]
I will probably have to read some examples somewhere to understand the concept and "algorithm"
brandonwillard
@brandonwillard:matrix.org
[m]
well, that's just a good reference for how it should work
how it does work is a whole other story
Ricardo Vieira
@ricardov94:matrix.org
[m]
import aesara
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph

def create_fgraph():
    a = at.scalar('a')
    b = at.exp(a); b.name = 'b'
    c = at.log(b); c.name = 'c'
    d = c + 5; d.name = 'd'
    e = at.log(d); e.name = 'e'
    f = e - 3; f.name = 'f'
    nodes = dict(a=a, b=b, c=c, d=d, e=e, f=f)
    fg = FunctionGraph(inputs=[a], outputs=[f], clone=False)
    return nodes, fg

# Can always replace one variable by an earlier one
nodes, fg = create_fgraph()
fg.replace_validate(nodes['c'], nodes['b'])

nodes, fg = create_fgraph()
fg.replace_validate(nodes['e'], nodes['c'])

# Can never replace one variable by a later one
nodes, fg = create_fgraph()
# This would enter an infinite loop!
# fg.replace_validate(nodes['b'], nodes['c'])

# Unless it is the output variable, but then it gives
# an invalid cyclical graph
nodes, fg = create_fgraph()
fg.replace_validate(nodes['e'], nodes['f'])

aesara.dprint(fg)
# Elemwise{sub,no_inplace} [id A] 'f'   0
#  |Elemwise{sub,no_inplace} [id A] 'f'   0
#  |TensorConstant{3} [id B]
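The "earlier is fine, later is not" pattern above follows from an ancestry rule: if every client of v is rewired in place to read r instead, a cycle appears exactly when v is r itself or one of r's ancestors, because some client of v then both feeds into r and reads from r. The sketch below checks that rule on a toy chain mirroring create_fgraph. The `Node` class and the `ancestors` and `would_create_cycle` helpers are hypothetical stand-ins, not Aesara objects, and the model does not cover Aesara's separate handling of FunctionGraph outputs.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so distinct nodes never compare equal
class Node:
    """Hypothetical stand-in for a graph variable and the inputs of its owner."""

    name: str
    inputs: List["Node"] = field(default_factory=list)


def ancestors(var: Node) -> Set[Node]:
    """Return `var` together with every node it (transitively) depends on."""
    seen: Set[Node] = set()
    stack = [var]
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(node.inputs)
    return seen


def would_create_cycle(old: Node, new: Node) -> bool:
    """Rewiring every client of `old` to `new` closes a loop exactly when
    `old` is `new` itself or one of `new`'s ancestors."""
    return old in ancestors(new)


# Same chain as create_fgraph: a -> b -> c -> d -> e -> f (constants 5 and 3 omitted)
a = Node("a")
b = Node("b", [a])
c = Node("c", [b])
d = Node("d", [c])
e = Node("e", [d])
f = Node("f", [e])

for old, new in [(c, b), (e, c), (b, c), (e, f)]:
    verdict = "cycle" if would_create_cycle(old, new) else "ok"
    print(f"replace {old.name} by {new.name}: {verdict}")
# replace c by b: ok
# replace e by c: ok
# replace b by c: cycle
# replace e by f: cycle
```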
Ricardo Vieira
@ricardov94:matrix.org
[m]

:point_up: Edit:
def create_fgraph():
    a = at.scalar('a')
    b = at.exp(a); b.name = 'b'
    c = at.log(b); c.name = 'c'
    d = c + 5; d.name = 'd'
    e = at.log(d); e.name = 'e'
    f = e - 3; f.name = 'f'
    nodes = dict(a=a, b=b, c=c, d=d, e=e, f=f)
    fg = FunctionGraph(inputs=[a], outputs=[f], clone=False)
    return nodes, fg

# Can always replace one variable by an earlier one
nodes, fg = create_fgraph()
fg.replace(nodes['c'], nodes['b'])

nodes, fg = create_fgraph()
fg.replace(nodes['e'], nodes['c'])

# Produces an invalid cyclical graph
nodes, fg = create_fgraph()
fg.replace(nodes['b'], nodes['c'])

# Also produces an invalid cyclical graph
nodes, fg = create_fgraph()
fg.replace(nodes['e'], nodes['f'])

aesara.dprint(fg.outputs)
# ...
#  |TensorConstant{3} [id B]



Ricardo Vieira
@ricardov94:matrix.org
[m]

Last code dump I promise

#%%
a = at.scalar('a')
b = at.exp(a); b.name = 'b'
c = at.log(b); c.name = 'c'
d = c + 5; d.name = 'd'

fg = FunctionGraph(inputs=[a], outputs=[c], clone=False)
fg.replace(b, d)
aesara.dprint(fg.outputs)
# Elemwise{log,no_inplace} [id A] 'c'
#    |Elemwise{log,no_inplace} [id A] 'c'  <-- CYCLICAL
#    |TensorConstant{5} [id C]

#%%
a = at.scalar('a')
b = at.exp(a); b.name = 'b'
c = at.log(b); c.name = 'c'
d = c + 5; d.name = 'd'

fg = FunctionGraph(inputs=[a], outputs=[b], clone=False)
fg.replace(b, d)
aesara.dprint(fg.outputs)
#  |Elemwise{log,no_inplace} [id B] 'c'
#  | |Elemwise{exp,no_inplace} [id C] 'b'
#  |   |a [id D]
#  |TensorConstant{5} [id E]

Does it make sense that the replacement of b -> d "works" when b is the output of the FunctionGraph (second half) but not when it is not (first half)?
By "works" I mean that it produces an acyclic graph.

Is this just a corner case I am hitting, and should both fail or both succeed?
brandonwillard
@brandonwillard:matrix.org
[m]
one minute
brandonwillard
@brandonwillard:matrix.org
[m]
when you replaced the output b in the latter case, it basically cleared the entire FunctionGraph and replaced it with another graph/output that also happened to reference b
so no issue there
remember, these steps need to happen in a fixed sequence of events
and the actual replacement steps are different depending on whether or not the replaced term is an output
in that case, the thing being changed is FunctionGraph.outputs
when a replacement is made on something that is not just a FunctionGraph output, like in the former case, an Apply node is updated in-place
brandonwillard
@brandonwillard:matrix.org
[m]
specifically, I believe it's Apply.inputs that's updated in-place
brandonwillard
@brandonwillard:matrix.org
[m]
anyway, there's no real cycle detection at this level, so it's completely up to the caller to not introduce them
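The explanation above can be made concrete with a deliberately simplified pure-Python model. This is not Aesara's FunctionGraph; `Var`, `ToyFunctionGraph`, `has_cycle`, and `make_chain` are hypothetical names. The model assumes, as described above, that replacing an output only rewrites the outputs list, that replacing an inner variable updates each client's inputs in place, and that nothing checks for cycles; it additionally assumes the clients are collected before any mutation. Under those assumptions it reproduces both halves of the b -> d experiment.

```python
from typing import List, Optional


class Var:
    """Toy variable; `inputs` plays the role of `var.owner.inputs` in Aesara."""

    def __init__(self, name: str, inputs: Optional[List["Var"]] = None):
        self.name = name
        self.inputs = list(inputs or [])


class ToyFunctionGraph:
    """Hypothetical, heavily simplified model of FunctionGraph.replace."""

    def __init__(self, outputs: List[Var]):
        self.outputs = list(outputs)

    def variables(self) -> List[Var]:
        """Every variable reachable from the outputs, each listed once."""
        seen, found, stack = set(), [], list(self.outputs)
        while stack:
            var = stack.pop()
            if id(var) in seen:
                continue
            seen.add(id(var))
            found.append(var)
            stack.extend(var.inputs)
        return found

    def replace(self, var: Var, new_var: Var) -> None:
        # Collect the clients of `var` before anything is mutated
        clients = [
            (node, i)
            for node in self.variables()
            for i, inp in enumerate(node.inputs)
            if inp is var
        ]
        # Output case: only the outputs list is rewritten
        self.outputs = [new_var if out is var else out for out in self.outputs]
        # Inner case: each client's inputs are updated in place, with no cycle check
        for node, i in clients:
            node.inputs[i] = new_var


def has_cycle(outputs: List[Var]) -> bool:
    """Depth-first search for a variable that (transitively) depends on itself."""
    state = {}  # id(var) -> "visiting" or "done"

    def visit(var: Var) -> bool:
        status = state.get(id(var))
        if status == "visiting":
            return True
        if status == "done":
            return False
        state[id(var)] = "visiting"
        if any(visit(inp) for inp in var.inputs):
            return True
        state[id(var)] = "done"
        return False

    return any(visit(out) for out in outputs)


def make_chain():
    """b = exp(a), c = log(b), d = c + 5 (the constant is omitted)."""
    a = Var("a")
    b = Var("b", [a])
    c = Var("c", [b])
    d = Var("d", [c])
    return a, b, c, d


# First half: b is an inner variable, so c.inputs is rewired to d, and d depends on c
a, b, c, d = make_chain()
fg = ToyFunctionGraph([c])
fg.replace(b, d)
print(has_cycle(fg.outputs))  # True

# Second half: b is the output, so only fg.outputs changes; d -> c -> b -> a stays acyclic
a, b, c, d = make_chain()
fg = ToyFunctionGraph([b])
fg.replace(b, d)
print(has_cycle(fg.outputs))  # False
```

In this model the second case only looks safe because c was never reachable from the graph's outputs, so it was not registered as a client of b when the replacement ran.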
Ricardo Vieira
@ricardov94:matrix.org
[m]
Okay, now things are clicking a bit more for me. The (aesara) problem is not with the recursive expression, but somehow its identity?
a = at.scalar('a')
b = at.exp(a); b.name = 'b'
c = at.log(b); c.name = 'c'
d = c + 5; d.name = 'd'
fg = FunctionGraph(inputs=[a], outputs=[d], clone=False)

# Cannot do this
# fg.replace(b, c)

# But can do this
new_c = c.owner.op(*c.owner.inputs); new_c.name = 'new_c'
fg.replace(b, new_c)

aesara.dprint(fg.outputs)
Is there a reason why fg.replace(b, c) should behave differently than fg.replace(b, new_c)?
Ricardo Vieira
@ricardov94:matrix.org
[m]
:point_up: Edit: Things are still not clicking entirely. The (aesara) problem I was facing before is not with the recursive expression, but somehow its identity?