esquilax.transforms.edge_map
esquilax.transforms.edge_map(f: Callable) -> Callable
Map a function over graph edges and related nodes.
Maps a function over a set of edges (with data) and the data corresponding to each edge's start and end nodes.
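The mapping described above can be sketched in plain JAX. This is a minimal, hypothetical re-implementation (edge_map_sketch is an illustrative name, not part of esquilax). It assumes that row 0 of edge_idxs holds start-node indices and row 1 holds end-node indices, and that the key is split into one key per edge. The real transform may differ, for example in how it handles keys or compilation.

```python
import jax
import jax.numpy as jnp


def edge_map_sketch(f):
    """Hypothetical sketch of edge_map semantics (not the esquilax source)."""

    def wrapped(k, params, starts, ends, edges, *, edge_idxs, **static_kwargs):
        n_edges = edge_idxs.shape[1]
        # Assumption: one independent PRNG key per edge
        keys = jax.random.split(k, n_edges)
        # Gather each edge's start/end node data; tree_map lets node data be
        # arbitrary PyTrees (e.g. tuples of arrays) or None
        start_data = jax.tree_util.tree_map(lambda x: x[edge_idxs[0]], starts)
        end_data = jax.tree_util.tree_map(lambda x: x[edge_idxs[1]], ends)

        def per_edge(key, s, e, d):
            # params is closed over, so it is shared (not mapped) across edges
            return f(key, params, s, e, d, **static_kwargs)

        # Map over keys, gathered node data and edge data along axis 0
        return jax.vmap(per_edge)(keys, start_data, end_data, edges)

    return wrapped
```

With the inputs from the first example below, edge_map_sketch(f)(k, 2, starts, ends, edges, edge_idxs=edge_idxs) gives [3, 5, 5].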
Warning
Edge indices and any associated data should be sorted using esquilax.utils.sort_edges().
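As a rough illustration of what "sorted" could mean, the sketch below orders edges by start index, then end index, and applies the same permutation to the edge data. sort_edges_sketch is a hypothetical helper, and the ordering is an assumption rather than the documented behaviour of esquilax.utils.sort_edges(); use the library function in practice.

```python
import jax
import jax.numpy as jnp


def sort_edges_sketch(edge_idxs, edge_data=None):
    """Hypothetical helper: sort edges by start index, then by end index.

    Not the esquilax.utils.sort_edges implementation; the ordering is assumed.
    """
    # jnp.lexsort treats the last key as the primary sort key
    order = jnp.lexsort((edge_idxs[1], edge_idxs[0]))
    sorted_idxs = edge_idxs[:, order]
    # Apply the same permutation to every leaf of the edge-data PyTree
    # (tree_map over None simply returns None)
    sorted_data = jax.tree_util.tree_map(lambda x: x[order], edge_data)
    return sorted_idxs, sorted_data
```

Keeping indices and data in one call ensures they are permuted together, so each edge's data stays attached to its (start, end) pair.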
Examples
```python
import esquilax
import jax
import jax.numpy as jnp


def f(k, params, start, end, edge):
    return params + start + end + edge


k = jax.random.PRNGKey(101)

# Edges 0->1, 0->2 and 1->0: row 0 holds start nodes, row 1 end nodes
edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
edges = jnp.array([0, 1, 2])
starts = jnp.array([0, 1, 2])
ends = jnp.array([0, 1, 2])

# Call transform with edge indexes
result = esquilax.transforms.edge_map(f)(
    k, 2, starts, ends, edges, edge_idxs=edge_idxs
)
# result = [3, 5, 5]
```
The transform can also be used as a decorator. Arguments can also be PyTrees, or None if unused:

```python
import esquilax
import jax
import jax.numpy as jnp


@esquilax.transforms.edge_map
def f(_k, _params, _start, end, edge):
    return end[0] + end[1] + edge


k = jax.random.PRNGKey(101)

edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
edges = jnp.array([0, 1, 2])
# End-node data is a PyTree (a tuple of two arrays)
ends = (jnp.array([0, 1, 2]), jnp.array([0, 1, 2]))

# Call transform with edge indexes; params and start data are unused (None)
f(k, None, None, ends, edges, edge_idxs=edge_idxs)
# [2, 5, 2]
```
Parameters:
f – Function with the signature

```python
def f(k, params, start, end, edge, **static_kwargs):
    ...
    return x
```

where the arguments are:
- k: JAX PRNGKey
- params: Parameters (shared over the map)
- start: Start node data
- end: End node data
- edge: Edge data
- **static_kwargs: Any values that JAX requires at compile time can be passed as keyword arguments.
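To make the **static_kwargs requirement concrete, the sketch below uses plain JAX rather than esquilax. A hypothetical keyword argument n_out sets the shape of each edge's output, so JAX must see it as a concrete Python value when tracing. Here that is done by marking n_out static through jax.jit's static_argnames. apply_over_edges is an illustrative name, and this does not show how esquilax itself forwards static keyword arguments.

```python
import functools

import jax
import jax.numpy as jnp


def f(k, params, start, end, edge, *, n_out):
    # n_out sets the output shape, so it must be a concrete Python int when
    # JAX traces f; a traced array value could not be used as a shape
    value = params["scale"] * (start + end + edge)
    return jnp.full((n_out,), value)


# Hypothetical driver, not esquilax code: n_out is static, everything else traced
@functools.partial(jax.jit, static_argnames="n_out")
def apply_over_edges(k, params, starts, ends, edges, edge_idxs, n_out):
    keys = jax.random.split(k, edges.shape[0])
    per_edge = functools.partial(f, n_out=n_out)
    # params is broadcast (in_axes=None); keys, node and edge data are mapped
    return jax.vmap(per_edge, in_axes=(0, None, 0, 0, 0))(
        keys, params, starts[edge_idxs[0]], ends[edge_idxs[1]], edges
    )
```

Calling apply_over_edges(..., n_out=3) returns an array of shape (n_edges, 3). Changing n_out triggers a recompilation instead of an error, which is the usual trade-off for static arguments.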