Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
dc92437
Reorganize graph test files for clarity
Andy-Jost Mar 31, 2026
281ed82
Enhance Graph.update() and add whole-graph update tests
Andy-Jost Mar 31, 2026
7854b76
Add AdjacencySet proxy for pred/succ and GraphNode.remove()
Andy-Jost Mar 31, 2026
5fbd288
Add edge mutation support and MutableSet interface for GraphNode adja…
Andy-Jost Apr 2, 2026
aa84e26
Use requires_module mark for numpy version checks in mutation tests
Andy-Jost Apr 2, 2026
b27dd93
Fix empty-graph return type: return set() instead of () for nodes/edges
Andy-Jost Apr 2, 2026
8554d30
Rename AdjacencySet to AdjacencySetProxy, add bulk ops and safety guards
Andy-Jost Apr 2, 2026
9813c20
Add destroy() method with handle invalidation, remove GRAPH_NODE_SENT…
Andy-Jost Apr 2, 2026
6411881
Add GraphNode identity cache for stable Python object round-trips
Andy-Jost Apr 2, 2026
7a3dbb4
Purge node cache on destroy to prevent stale identity lookups
Andy-Jost Apr 2, 2026
91b3b4e
Skip NULL nodes in graph_node_registry to fix sentinel identity colli…
Andy-Jost Apr 2, 2026
1b7743d
Unregister destroyed nodes from C++ graph_node_registry
Andy-Jost Apr 3, 2026
84f0b30
Add dedicated test for node identity preservation through round-trips
Andy-Jost Apr 3, 2026
64d6c2d
Merge branch 'main' into graph-node-identity
Andy-Jost Apr 3, 2026
a40be9a
Merge branch 'main' into graph-node-identity
Andy-Jost Apr 3, 2026
729af49
Rename _node_cache/_cached to _node_registry/_registered
Andy-Jost Apr 3, 2026
42131b6
Fix unregister_handle and rename invalidate_graph_node_handle
Andy-Jost Apr 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions cuda_core/cuda/core/_cpp/resource_handles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,8 @@ class HandleRegistry {
}

void unregister_handle(const Key& key) noexcept {
try {
std::lock_guard<std::mutex> lock(mutex_);
auto it = map_.find(key);
if (it != map_.end() && it->second.expired()) {
map_.erase(it);
}
} catch (...) {}
std::lock_guard<std::mutex> lock(mutex_);
map_.erase(key);
Comment on lines -177 to +178
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expired() check was always unnecessary. For GraphNodes it is harmful. By not purging invalidated GraphNode handles, the driver reusing an old pointer could cause a false hit.

}

Handle lookup(const Key& key) {
Expand Down Expand Up @@ -969,17 +964,32 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
);
}

static HandleRegistry<CUgraphNode, GraphNodeHandle> graph_node_registry;

GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
if (node) {
if (auto h = graph_node_registry.lookup(node)) {
return h;
}
}
auto box = std::make_shared<const GraphNodeBox>(GraphNodeBox{node, h_graph});
return GraphNodeHandle(box, &box->resource);
GraphNodeHandle h(box, &box->resource);
if (node) {
graph_node_registry.register_handle(node, h);
}
return h;
}

GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {
return h ? get_box(h)->h_graph : GraphHandle{};
}

void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept {
void invalidate_graph_node(const GraphNodeHandle& h) noexcept {
if (h) {
CUgraphNode node = get_box(h)->resource;
if (node) {
graph_node_registry.unregister_handle(node);
}
get_box(h)->resource = nullptr;
}
}
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_cpp/resource_handles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept;

// Zero the CUgraphNode resource inside the handle, marking it invalid.
void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept;
void invalidate_graph_node(const GraphNodeHandle& h) noexcept;
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following on this discussion, this is a better name.


// ============================================================================
// Graphics resource handle functions
Expand Down
69 changes: 45 additions & 24 deletions cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ from cuda.core._resource_handles cimport (
create_graph_handle_ref,
create_graph_node_handle,
graph_node_get_graph,
invalidate_graph_node_handle,
invalidate_graph_node,
)
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value

Expand All @@ -57,10 +57,19 @@ from cuda.core._graph._utils cimport (
_attach_user_object,
)

import weakref

from cuda.core import Device
from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy
from cuda.core._utils.cuda_utils import driver, handle_return

_node_registry = weakref.WeakValueDictionary()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you explained this live in one of the PR review meetings, but since it's subtle enough, it might be worth documenting why there is a registry on both the C++ and Python sides and how they interact (if at all).



cdef inline GraphNode _registered(GraphNode n):
_node_registry[<uintptr_t>n._h_node.get()] = n
return n


cdef class GraphNode:
"""Base class for all graph nodes.
Expand Down Expand Up @@ -144,7 +153,8 @@ cdef class GraphNode:
return
with nogil:
HANDLE_RETURN(cydriver.cuGraphDestroyNode(node))
invalidate_graph_node_handle(self._h_node)
_node_registry.pop(<uintptr_t>self._h_node.get(), None)
invalidate_graph_node(self._h_node)

@property
def pred(self):
Expand Down Expand Up @@ -522,18 +532,29 @@ cdef inline ConditionalNode _make_conditional_node(
n._cond_type = cond_type
n._branches = branches

return n
return _registered(n)

cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node):
cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph)

# Sentinel: virtual node to represent the graph entry point.
if node == NULL:
n = GraphNode.__new__(GraphNode)
(<GraphNode>n)._h_node = create_graph_node_handle(node, h_graph)
(<GraphNode>n)._h_node = h_node
return n

cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph)
# Return a registered object or create and register a new one.
registered = _node_registry.get(<uintptr_t>h_node.get())
if registered is not None:
return <GraphNode>registered
else:
return _registered(GN_create_impl(h_node))


cdef inline GraphNode GN_create_impl(GraphNodeHandle h_node):
cdef cydriver.CUgraphNodeType node_type
with nogil:
HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type))
HANDLE_RETURN(cydriver.cuGraphNodeGetType(as_cu(h_node), &node_type))

if node_type == cydriver.CU_GRAPH_NODE_TYPE_EMPTY:
return EmptyNode._create_impl(h_node)
Expand Down Expand Up @@ -595,10 +616,10 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker,
_attach_user_object(as_cu(h_graph), <void*>new KernelHandle(ker._h_kernel),
<cydriver.CUhostFn>_destroy_kernel_handle_copy)

return KernelNode._create_with_params(
return _registered(KernelNode._create_with_params(
create_graph_node_handle(new_node, h_graph),
conf.grid, conf.block, conf.shmem_size,
ker._h_kernel)
ker._h_kernel))


cdef inline EmptyNode GN_join(GraphNode self, tuple nodes):
Expand All @@ -624,7 +645,7 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes):
HANDLE_RETURN(cydriver.cuGraphAddEmptyNode(
&new_node, as_cu(h_graph), deps_ptr, num_deps))

return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph))
return _registered(EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)))


cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options):
Expand Down Expand Up @@ -700,9 +721,9 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options):
HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode(
&new_node, as_cu(h_graph), deps, num_deps, &alloc_params))

return AllocNode._create_with_params(
return _registered(AllocNode._create_with_params(
create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size,
device_id, memory_type, tuple(peer_ids))
device_id, memory_type, tuple(peer_ids)))


cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr):
Expand All @@ -720,7 +741,7 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr):
HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode(
&new_node, as_cu(h_graph), deps, num_deps, c_dptr))

return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr)
return _registered(FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr))


cdef inline MemsetNode GN_memset(
Expand Down Expand Up @@ -755,9 +776,9 @@ cdef inline MemsetNode GN_memset(
&new_node, as_cu(h_graph), deps, num_deps,
&memset_params, ctx))

return MemsetNode._create_with_params(
return _registered(MemsetNode._create_with_params(
create_graph_node_handle(new_node, h_graph), c_dst,
val, elem_size, width, height, pitch)
val, elem_size, width, height, pitch))


cdef inline MemcpyNode GN_memcpy(
Expand Down Expand Up @@ -816,9 +837,9 @@ cdef inline MemcpyNode GN_memcpy(
HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode(
&new_node, as_cu(h_graph), deps, num_deps, &params, ctx))

return MemcpyNode._create_with_params(
return _registered(MemcpyNode._create_with_params(
create_graph_node_handle(new_node, h_graph), c_dst, c_src, size,
c_dst_type, c_src_type)
c_dst_type, c_src_type))


cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def):
Expand All @@ -843,8 +864,8 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def):

cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph)

return ChildGraphNode._create_with_params(
create_graph_node_handle(new_node, h_graph), h_embedded)
return _registered(ChildGraphNode._create_with_params(
create_graph_node_handle(new_node, h_graph), h_embedded))


cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev):
Expand All @@ -865,8 +886,8 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev):
_attach_user_object(as_cu(h_graph), <void*>new EventHandle(ev._h_event),
<cydriver.CUhostFn>_destroy_event_handle_copy)

return EventRecordNode._create_with_params(
create_graph_node_handle(new_node, h_graph), ev._h_event)
return _registered(EventRecordNode._create_with_params(
create_graph_node_handle(new_node, h_graph), ev._h_event))


cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev):
Expand All @@ -887,8 +908,8 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev):
_attach_user_object(as_cu(h_graph), <void*>new EventHandle(ev._h_event),
<cydriver.CUhostFn>_destroy_event_handle_copy)

return EventWaitNode._create_with_params(
create_graph_node_handle(new_node, h_graph), ev._h_event)
return _registered(EventWaitNode._create_with_params(
create_graph_node_handle(new_node, h_graph), ev._h_event))


cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_data):
Expand All @@ -914,6 +935,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_
&new_node, as_cu(h_graph), deps, num_deps, &node_params))

cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None
return HostCallbackNode._create_with_params(
return _registered(HostCallbackNode._create_with_params(
create_graph_node_handle(new_node, h_graph), callable_obj,
node_params.fn, node_params.userData)
node_params.fn, node_params.userData))
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_resource_handles.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHand
# Graph node handles
cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil
cdef void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept nogil
cdef void invalidate_graph_node(const GraphNodeHandle& h) noexcept nogil

# Graphics resource handles
cdef GraphicsResourceHandle create_graphics_resource_handle(
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_resource_handles.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
GraphHandle graph_node_get_graph "cuda_core::graph_node_get_graph" (
const GraphNodeHandle& h) noexcept nogil
void invalidate_graph_node_handle "cuda_core::invalidate_graph_node_handle" (
void invalidate_graph_node "cuda_core::invalidate_graph_node" (
const GraphNodeHandle& h) noexcept nogil

# Graphics resource handles
Expand Down
27 changes: 27 additions & 0 deletions cuda_core/tests/graph/test_graphdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ def test_node_type_preserved_by_nodes(node_spec):
matched = [n for n in all_nodes if n == node]
assert len(matched) == 1
assert isinstance(matched[0], spec.roundtrip_class)
assert matched[0] is node


def test_node_type_preserved_by_pred_succ(node_spec):
Expand All @@ -670,6 +671,7 @@ def test_node_type_preserved_by_pred_succ(node_spec):
matched = [s for s in predecessor.succ if s == node]
assert len(matched) == 1
assert isinstance(matched[0], spec.roundtrip_class)
assert matched[0] is node


def test_node_attrs(node_spec):
Expand Down Expand Up @@ -697,6 +699,31 @@ def test_node_attrs_preserved_by_nodes(node_spec):
assert getattr(retrieved, attr) == getattr(node, attr), f"{spec.name}.{attr} not preserved by nodes()"


def test_identity_preservation(init_cuda):
"""Round-trips through nodes(), edges(), and pred/succ return extant
objects rather than duplicates."""
g = GraphDef()
a = g.join()
b = a.join()

# nodes()
assert any(x is a for x in g.nodes())
assert any(x is b for x in g.nodes())

# succ/pred
a.succ = {b}
(b2,) = a.succ
assert b2 is b

(a2,) = b.pred
assert a2 is a

# edges()
((a2, b2),) = g.edges()
assert a2 is a
assert b2 is b

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since a potential failure mode here might be unintentional leaking of references, maybe it's also worth testing something like:

del a, a2, b, b2

and then asserting that the registry is empty.


# =============================================================================
# GraphDef basics
# =============================================================================
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/tests/graph/test_graphdef_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_convert_linear_to_fan_in(init_cuda):
for node in g.nodes():
if isinstance(node, MemsetNode):
node.pred.clear()
elif isinstance(node, KernelNode) and node != reduce_node:
elif isinstance(node, KernelNode) and node is not reduce_node:
node.succ.add(reduce_node)

assert len(g.edges()) == 8
Expand Down
Loading