esquilax.transforms.edge_map
- esquilax.transforms.edge_map(f: Callable) -> Callable
Map a function over graph edges and their related nodes.
Maps a function over a set of edges (with edge data) and the data corresponding to each edge's start and end nodes.
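As a rough mental model, this mapping can be sketched in plain JAX. The sketch gathers each edge's start-node and end-node data by index and then uses `jax.vmap` to map the function over edges, sharing `params` across all of them. It assumes that `edge_idxs[0]` holds start indices and `edge_idxs[1]` holds end indices, which matches the example results below. `edge_map_sketch` is a hypothetical helper and is not the esquilax implementation.

```python
import jax
import jax.numpy as jnp


def edge_map_sketch(f, params, starts, ends, edges, edge_idxs):
    """Hypothetical reference for the edge_map semantics (not esquilax code).

    Assumes edge_idxs has shape (2, n_edges): row 0 holds start node
    indices and row 1 holds end node indices. Node and edge data may be
    arrays, PyTrees of arrays, or None.
    """
    start_idx, end_idx = edge_idxs[0], edge_idxs[1]
    # Gather the data of each edge's start and end node; tree_map returns
    # None unchanged because None is an empty PyTree
    start_data = jax.tree_util.tree_map(lambda a: a[start_idx], starts)
    end_data = jax.tree_util.tree_map(lambda a: a[end_idx], ends)
    # Vectorise over edges: params is shared (in_axes=None), the rest is mapped
    return jax.vmap(f, in_axes=(None, 0, 0, 0))(params, start_data, end_data, edges)


def f(params, start, end, edge):
    return params + start + end + edge


data = jnp.array([0, 1, 2])
idx = jnp.array([[0, 0, 1], [1, 2, 0]])
print(edge_map_sketch(f, 2, data, data, data, idx))  # [3 5 5]
```

Because node data is gathered with `tree_map`, the same sketch also accepts PyTree or `None` arguments, mirroring the decorator example below.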
Warning
Edge indices and any associated data should be sorted using esquilax.utils.sort_edges().
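The sketch below is for illustration only. It shows one plausible sorted layout, with edges ordered by start index and then by end index. The same permutation is applied to the edge data so that indices and data stay aligned. The ordering criterion is an assumption, and `sort_edges_sketch` is a hypothetical helper. In practice, use `esquilax.utils.sort_edges()`.

```python
import jax
import jax.numpy as jnp


def sort_edges_sketch(edge_idxs, edge_data):
    """Hypothetical helper: order edges by (start, end) and permute data to match."""
    # jnp.lexsort treats the LAST key as the primary key, so this sorts by
    # start index first and breaks ties using the end index
    order = jnp.lexsort((edge_idxs[1], edge_idxs[0]))
    # Apply the same permutation to the indices and to every leaf of the data
    sorted_idxs = edge_idxs[:, order]
    sorted_data = jax.tree_util.tree_map(lambda a: a[order], edge_data)
    return sorted_idxs, sorted_data


edge_idxs = jnp.array([[1, 0, 0], [0, 2, 1]])
edge_data = jnp.array([10, 20, 30])
sorted_idxs, sorted_data = sort_edges_sketch(edge_idxs, edge_data)
print(sorted_idxs.tolist())  # [[0, 0, 1], [1, 2, 0]]
print(sorted_data.tolist())  # [30, 20, 10]
```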
Examples
```python
import jax.numpy as jnp
import esquilax

def f(params, start, end, edge):
    return params + start + end + edge

# Row 0: start node indices, row 1: end node indices
edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
edges = jnp.array([0, 1, 2])
starts = jnp.array([0, 1, 2])
ends = jnp.array([0, 1, 2])

# Call transform with edge indexes
result = esquilax.transforms.edge_map(f)(
    2, starts, ends, edges, edge_idxs=edge_idxs
)
# result = [3, 5, 5]
```
The transform can also be used as a decorator. Arguments can also be PyTrees, or `None` if unused:

```python
@esquilax.transforms.edge_map
def f(_params, _start, end, edge):
    # end is a tuple of arrays (a PyTree)
    return end[0] + end[1] + edge

edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
edges = jnp.array([0, 1, 2])
ends = (jnp.array([0, 1, 2]), jnp.array([0, 1, 2]))

# Call transform with edge indexes; unused params and start data are None
result = f(None, None, ends, edges, edge_idxs=edge_idxs)
# result = [2, 5, 2]
```
JAX random keys can be passed to the wrapped function by including the `key` keyword argument:

```python
import jax

@esquilax.transforms.edge_map
def f(_params, _start, _end, _edge, *, key):
    # Sample a random integer for each edge
    return jax.random.choice(key, 100, ())

k = jax.random.PRNGKey(101)
# edge_idxs as defined in the previous example
result = f(None, None, None, None, edge_idxs=edge_idxs, key=k)
```
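A common JAX pattern for getting independent randomness per edge from one key is to split the key into one subkey per edge and then map over the subkeys. The sketch below assumes the transform does something equivalent. This page does not document how esquilax actually derives per-edge keys, and `keyed_edge_map_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def keyed_edge_map_sketch(f, params, edges, key):
    """Hypothetical: map f over edges, giving each edge its own PRNG subkey.

    Node data is omitted (passed as None) to keep the sketch short.
    """
    n_edges = edges.shape[0]
    # Derive one independent subkey per edge from the single input key
    keys = jax.random.split(key, n_edges)
    # params is shared; edge data and subkeys are mapped along axis 0
    return jax.vmap(lambda edge, k: f(params, None, None, edge, key=k))(edges, keys)


def f(_params, _start, _end, _edge, *, key):
    # Sample a random integer in [0, 100) for each edge
    return jax.random.choice(key, 100, ())


samples = keyed_edge_map_sketch(f, None, jnp.array([0, 1, 2]), jax.random.PRNGKey(101))
print(samples.shape)  # (3,)
```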
- Parameters:
f – Function with the signature

```python
def f(params, start, end, edge, **static_kwargs):
    ...
    return x
```

where the arguments are:

- `params`: Parameters (shared over the map)
- `start`: Start node data
- `end`: End node data
- `edge`: Edge data
- `**static_kwargs`: Any values required at compile time by JAX can be passed as keyword arguments.

The keyword argument `key` can be included to pass a random key to the mapped function.
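`jax.vmap` maps keyword arguments over their leading axis, so static values are better bound before vectorising. The sketch below does this with `functools.partial`, under that assumption. The bound values remain ordinary Python objects and can be used wherever JAX needs concrete values, such as array shapes. `edge_map_with_static_sketch` and `n_samples` are hypothetical names, and this is not necessarily how esquilax forwards static keyword arguments.

```python
import functools

import jax
import jax.numpy as jnp


def edge_map_with_static_sketch(f, params, starts, ends, edges, edge_idxs, **static_kwargs):
    """Hypothetical: an edge map that forwards static keyword arguments."""
    # Bind static kwargs up front so vmap never tries to batch them
    g = functools.partial(f, **static_kwargs)
    # Gather start/end node data for each edge (row 0 = starts, row 1 = ends)
    start_data = jax.tree_util.tree_map(lambda a: a[edge_idxs[0]], starts)
    end_data = jax.tree_util.tree_map(lambda a: a[edge_idxs[1]], ends)
    return jax.vmap(g, in_axes=(None, 0, 0, 0))(params, start_data, end_data, edges)


def f(params, start, end, edge, *, n_samples):
    # n_samples fixes the output shape, so it must be a concrete Python int
    return jnp.full((n_samples,), params + start + end + edge)


data = jnp.array([0, 1, 2])
idx = jnp.array([[0, 0, 1], [1, 2, 0]])
out = edge_map_with_static_sketch(f, 2, data, data, data, idx, n_samples=4)
print(out.shape)  # (3, 4)
```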