esquilax.transforms.edge_map

esquilax.transforms.edge_map(f: Callable) → Callable

Map a function over graph edges and related nodes

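Maps a function over a set of edges (with data) and over the data of each edge's start and end nodes. Node data is looked up by index. The first row of edge_idxs gives each edge's start node, and the second row gives its end node.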
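The per-edge semantics can be sketched in plain Python. This is a minimal reference sketch, not esquilax's implementation. It assumes that node and edge data are plain Python sequences, and it treats edge_idxs as a pair of rows (start indices, end indices), as in the examples below. The name edge_map_reference is hypothetical. The sketch passes k unchanged to every call, so it does not model how the real transform handles PRNG keys across edges.

```python
from typing import Callable


def edge_map_reference(f: Callable) -> Callable:
    """Hypothetical pure-Python reference for what edge_map computes (not esquilax's code)."""

    def mapped(k, params, starts, ends, edges, *, edge_idxs, **static_kwargs):
        start_idxs, end_idxs = edge_idxs
        results = []
        for n, (i, j) in enumerate(zip(start_idxs, end_idxs)):
            # Edge n runs from node i (start) to node j (end): gather the data
            # for both nodes and the edge itself; params is shared by all calls.
            # Assumption: k is passed through unchanged (key handling not modelled).
            results.append(f(k, params, starts[i], ends[j], edges[n], **static_kwargs))
        return results

    return mapped


def f(k, params, start, end, edge):
    return params + start + end + edge


edge_idxs = [[0, 0, 1], [1, 2, 0]]
result = edge_map_reference(f)(
    None, 2, [0, 1, 2], [0, 1, 2], [0, 1, 2], edge_idxs=edge_idxs
)
print(result)  # [3, 5, 5]
```

The output matches the first example below. For example, edge 0 runs from node 0 to node 1, so f receives params=2, start=0, end=1 and edge=0, and returns 3.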

Warning

Edge indices and any associated data should be sorted using esquilax.utils.sort_edges() before being passed to the mapped function.
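Sorting must reorder the edge data with the same permutation as the index pairs, or each edge will be matched to the wrong data. The sketch below shows this idea. It assumes lexicographic ordering by start node and then end node, which is an assumption about what sort_edges does and not a description of its implementation. The function sort_edges_reference is hypothetical.

```python
def sort_edges_reference(edge_idxs, *edge_data):
    """Hypothetical sketch: sort edges by (start, end) and permute edge data to match."""
    starts, ends = edge_idxs
    # Permutation that orders edges lexicographically by (start, end)
    order = sorted(range(len(starts)), key=lambda n: (starts[n], ends[n]))
    sorted_idxs = [[starts[n] for n in order], [ends[n] for n in order]]
    # Apply the same permutation to every edge-data sequence so it stays aligned
    sorted_data = [[data[n] for n in order] for data in edge_data]
    return (sorted_idxs, *sorted_data)


idxs, weights = sort_edges_reference([[1, 0, 0], [0, 2, 1]], [10, 20, 30])
print(idxs)     # [[0, 0, 1], [1, 2, 0]]
print(weights)  # [30, 20, 10]
```

The weight 30 originally belonged to the edge (0, 1). It still belongs to that edge after sorting, which is the invariant the warning is protecting.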

Examples

import jax
import jax.numpy as jnp
import esquilax

def f(k, params, start, end, edge):
    return params + start + end + edge

k = jax.random.PRNGKey(101)
edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
edges = jnp.array([0, 1, 2])
starts = jnp.array([0, 1, 2])
ends = jnp.array([0, 1, 2])

# Apply the transform, then call the mapped function with the edge indices
result = esquilax.transforms.edge_map(
    f
)(
    k, 2, starts, ends, edges, edge_idxs=edge_idxs
)
# result = [3, 5, 5]

The transform can also be used as a decorator. Arguments can be PyTrees, or None if f does not use them.

import jax
import jax.numpy as jnp
import esquilax

@esquilax.transforms.edge_map
def f(_k, _params, _start, end, edge):
    return end[0] + end[1] + edge

k = jax.random.PRNGKey(101)
edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
edges = jnp.array([0, 1, 2])
ends = (jnp.array([0, 1, 2]), jnp.array([0, 1, 2]))

# Call the decorated function with the edge indices
f(k, None, None, ends, edges, edge_idxs=edge_idxs)
# [2, 5, 2]
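The second example works because node data is indexed leaf by leaf. With a tuple of arrays, each edge receives a tuple of entries, and an unused None argument stays None. Below is a minimal pure-Python sketch of that gather step. It treats tuples and dicts as containers and any other value as a leaf sequence. Real JAX PyTrees are more general and also treat lists as containers. The helper gather_tree is hypothetical.

```python
def gather_tree(tree, idx):
    """Hypothetical helper: take entry idx from every leaf of a simple PyTree.

    Assumes tuples and dicts are containers and anything else is a leaf
    sequence with one entry per node (or edge); None is passed through.
    """
    if tree is None:
        return None
    if isinstance(tree, tuple):
        return tuple(gather_tree(t, idx) for t in tree)
    if isinstance(tree, dict):
        return {key: gather_tree(value, idx) for key, value in tree.items()}
    return tree[idx]


def f(_k, _params, _start, end, edge):
    return end[0] + end[1] + edge


edge_idxs = [[0, 0, 1], [1, 2, 0]]
edges = [0, 1, 2]
ends = ([0, 1, 2], [0, 1, 2])

result = [
    # Start data is None here, so gather_tree returns None for every edge
    f(None, None, gather_tree(None, i), gather_tree(ends, j), gather_tree(edges, n))
    for n, (i, j) in enumerate(zip(*edge_idxs))
]
print(result)  # [2, 5, 2]
```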

Parameters:

f

Function with the signature

def f(k, params, start, end, edge, **static_kwargs):
    ...
    return x

where the arguments are:

  • k: JAX PRNGKey

  • params: Parameters (shared over the map)

  • start: Start node data

  • end: End node data

  • edge: Edge data

  • **static_kwargs: Any values that JAX requires at compile time can be passed as keyword arguments.