esquilax.transforms.edge_map

esquilax.transforms.edge_map(f: Callable) → Callable

Map a function over graph edges and related nodes.

Maps a function over a set of edges (with data) and the data of each edge's start and end nodes.
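As a rough mental model (a sketch, not esquilax's actual implementation), the mapping can be written in plain JAX. Gather start-node data with the first row of the edge indices and end-node data with the second row. Then vectorize f over edges with jax.vmap, keeping params unmapped. The helper edge_map_reference is a hypothetical name used only for this sketch. It reproduces the output of the first example in the Examples section.

```python
import jax
import jax.numpy as jnp


def edge_map_reference(f, params, starts, ends, edges, edge_idxs):
    """Hypothetical reference version of the edge_map semantics."""
    # Row 0 of edge_idxs holds start-node indices, row 1 end-node indices
    start_idxs, end_idxs = edge_idxs[0], edge_idxs[1]
    # Gather the data of the start and end node of every edge
    start_data = starts[start_idxs]
    end_data = ends[end_idxs]
    # params is shared (not mapped); node and edge data are mapped along axis 0
    return jax.vmap(f, in_axes=(None, 0, 0, 0))(params, start_data, end_data, edges)


def f(params, start, end, edge):
    return params + start + end + edge


edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
edges = jnp.array([0, 1, 2])
starts = jnp.array([0, 1, 2])
ends = jnp.array([0, 1, 2])
result = edge_map_reference(f, 2, starts, ends, edges, edge_idxs)
# result = [3, 5, 5]
```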

Warning

Edge indices and any associated data should be sorted using esquilax.utils.sort_edges().
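The exact order produced by esquilax.utils.sort_edges() is not stated here. The following sketch assumes a lexicographic (start, end) order computed with jnp.lexsort. Its main point is that the same permutation must be applied to the edge indices and to the edge data, so each data entry stays attached to its edge. sort_edges_sketch is a hypothetical helper. In real code, use esquilax.utils.sort_edges().

```python
import jax.numpy as jnp


def sort_edges_sketch(edge_idxs, edge_data):
    """Hypothetical stand-in: lexicographic sort by (start, end) index."""
    # jnp.lexsort treats the *last* key as the primary key,
    # so pass end indices first and start indices last
    order = jnp.lexsort((edge_idxs[1], edge_idxs[0]))
    # Apply the same permutation to the indices and to the edge data
    return edge_idxs[:, order], edge_data[order]


edge_idxs = jnp.array([[1, 0, 0], [0, 2, 1]])
edge_data = jnp.array([2, 1, 0])
sorted_idxs, sorted_data = sort_edges_sketch(edge_idxs, edge_data)
# sorted_idxs = [[0, 0, 1], [1, 2, 0]], sorted_data = [0, 1, 2]
```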

Examples

import jax.numpy as jnp
import esquilax

def f(params, start, end, edge):
    return params + start + end + edge

edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
edges = jnp.array([0, 1, 2])
starts = jnp.array([0, 1, 2])
ends = jnp.array([0, 1, 2])

# Call transform with edge indexes
result = esquilax.transforms.edge_map(
    f
)(
    2, starts, ends, edges, edge_idxs=edge_idxs
)
# result = [3, 5, 5]

The transform can also be used as a decorator. Arguments can also be PyTrees, or None if unused.

@esquilax.transforms.edge_map
def f(_params, _start, end, edge):
    return end[0] + end[1] + edge

edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
edges = jnp.array([0, 1, 2])
ends = (jnp.array([0, 1, 2]), jnp.array([0, 1, 2]))

# Call transform with edge indexes
f(None, None, ends, edges, edge_idxs=edge_idxs)
# [2, 5, 2]
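PyTree and None arguments can be handled by gathering node data leaf by leaf. The following sketch assumes that approach. jax.tree_util.tree_map indexes every array in the tree, and None, which has no leaves, passes through untouched. gather_nodes is a hypothetical helper, not part of esquilax. The sketch reproduces the decorator example's output in plain JAX.

```python
import jax
import jax.numpy as jnp


def gather_nodes(node_data, idxs):
    # Index every leaf of the PyTree; None has no leaves and stays None
    return jax.tree_util.tree_map(lambda x: x[idxs], node_data)


def f(_params, _start, end, edge):
    return end[0] + end[1] + edge


edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
edges = jnp.array([0, 1, 2])
ends = (jnp.array([0, 1, 2]), jnp.array([0, 1, 2]))

start_data = gather_nodes(None, edge_idxs[0])  # stays None
end_data = gather_nodes(ends, edge_idxs[1])  # tuple of gathered arrays
result = jax.vmap(f, in_axes=(None, 0, 0, 0))(None, start_data, end_data, edges)
# result = [2, 5, 2]
```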

JAX random keys can be passed to the wrapped function by including the key keyword argument.

import jax

@esquilax.transforms.edge_map
def f(_params, _start, _end, _edge, *, key):
    # Sample a random integer for each edge
    return jax.random.choice(key, 100, ())

k = jax.random.PRNGKey(101)
# edge_idxs as defined in the previous example
result = f(None, None, None, None, edge_idxs=edge_idxs, key=k)
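The source does not say how the key is distributed across edges. A common JAX pattern, assumed here, is to split the key into one key per edge with jax.random.split. The sketch then vmaps over those keys, so every edge draws independently. It calls an undecorated f directly and is not esquilax's internal code.

```python
import jax
import jax.numpy as jnp


def f(_params, _start, _end, _edge, *, key):
    # Sample a random integer in [0, 100) for one edge
    return jax.random.choice(key, 100, ())


edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
n_edges = edge_idxs.shape[1]

k = jax.random.PRNGKey(101)
# One independent key per edge
edge_keys = jax.random.split(k, n_edges)
# Map over the keys only; the unused arguments are None
result = jax.vmap(lambda key: f(None, None, None, None, key=key))(edge_keys)
# result has shape (3,), one sample per edge
```

Each edge gets its own key, so the samples are independent draws. Passing the same unsplit key to every edge would make them all identical.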

Parameters:

f

Function with the signature

def f(params, start, end, edge, **static_kwargs):
    ...
    return x

where the arguments are:

  • params: Parameters (shared over the map)

  • start: Start node data

  • end: End node data

  • edge: Edge data

  • **static_kwargs: Any values required at compile-time by JAX can be passed as keyword arguments.

The keyword argument key can be included to pass a random key to the mapped function.
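Compile-time values go through keyword arguments rather than being mapped. The following plain-JAX sketch shows why. It binds a hypothetical static argument n_repeat with functools.partial and marks it static for jax.jit, because that argument determines an output shape. The names run and n_repeat are invented for illustration. How esquilax forwards static_kwargs internally is an assumption not covered here.

```python
import functools

import jax
import jax.numpy as jnp


def f(params, start, end, edge, *, n_repeat):
    # n_repeat sets an output shape, so it must be a concrete Python value
    return jnp.full((n_repeat,), params + start + end + edge)


@functools.partial(jax.jit, static_argnames=("n_repeat",))
def run(params, starts, ends, edges, edge_idxs, n_repeat):
    # Bind the static keyword before vectorizing over edges
    g = functools.partial(f, n_repeat=n_repeat)
    start_data = starts[edge_idxs[0]]
    end_data = ends[edge_idxs[1]]
    return jax.vmap(g, in_axes=(None, 0, 0, 0))(params, start_data, end_data, edges)


edge_idxs = jnp.array([[0, 0, 1], [1, 2, 0]])
x = jnp.array([0, 1, 2])
out = run(2, x, x, x, edge_idxs, n_repeat=2)
# out = [[3, 3], [5, 5], [5, 5]]
```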