diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 2355d64717..904b84c657 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -174,13 +174,8 @@ class HandleRegistry { } void unregister_handle(const Key& key) noexcept { - try { - std::lock_guard lock(mutex_); - auto it = map_.find(key); - if (it != map_.end() && it->second.expired()) { - map_.erase(it); - } - } catch (...) {} + std::lock_guard lock(mutex_); + map_.erase(key); } Handle lookup(const Key& key) { @@ -969,17 +964,32 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) { ); } +static HandleRegistry graph_node_registry; + GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) { + if (node) { + if (auto h = graph_node_registry.lookup(node)) { + return h; + } + } auto box = std::make_shared(GraphNodeBox{node, h_graph}); - return GraphNodeHandle(box, &box->resource); + GraphNodeHandle h(box, &box->resource); + if (node) { + graph_node_registry.register_handle(node, h); + } + return h; } GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept { return h ? get_box(h)->h_graph : GraphHandle{}; } -void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept { +void invalidate_graph_node(const GraphNodeHandle& h) noexcept { if (h) { + CUgraphNode node = get_box(h)->resource; + if (node) { + graph_node_registry.unregister_handle(node); + } get_box(h)->resource = nullptr; } } diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index 064f1406f6..d63fb86997 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -416,7 +416,7 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_ GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept; // Zero the CUgraphNode resource inside the handle, marking it invalid. -void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept; +void invalidate_graph_node(const GraphNodeHandle& h) noexcept; // ============================================================================ // Graphics resource handle functions diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index 4048c9ee06..1474d10430 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -48,7 +48,7 @@ from cuda.core._resource_handles cimport ( create_graph_handle_ref, create_graph_node_handle, graph_node_get_graph, - invalidate_graph_node_handle, + invalidate_graph_node, ) from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value @@ -57,10 +57,19 @@ from cuda.core._graph._utils cimport ( _attach_user_object, ) +import weakref + from cuda.core import Device from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy from cuda.core._utils.cuda_utils import driver, handle_return +_node_registry = weakref.WeakValueDictionary() + + +cdef inline GraphNode _registered(GraphNode n): + _node_registry[n._h_node.get()] = n + return n + cdef class GraphNode: """Base class for all graph nodes. @@ -144,7 +153,8 @@ cdef class GraphNode: return with nogil: HANDLE_RETURN(cydriver.cuGraphDestroyNode(node)) - invalidate_graph_node_handle(self._h_node) + _node_registry.pop(self._h_node.get(), None) + invalidate_graph_node(self._h_node) @property def pred(self): @@ -522,18 +532,29 @@ cdef inline ConditionalNode _make_conditional_node( n._cond_type = cond_type n._branches = branches - return n + return _registered(n) cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node): + cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph) + + # Sentinel: virtual node to represent the graph entry point. if node == NULL: n = GraphNode.__new__(GraphNode) - (n)._h_node = create_graph_node_handle(node, h_graph) + (n)._h_node = h_node return n - cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph) + # Return a registered object or create and register a new one. + registered = _node_registry.get(h_node.get()) + if registered is not None: + return registered + else: + return _registered(GN_create_impl(h_node)) + + +cdef inline GraphNode GN_create_impl(GraphNodeHandle h_node): cdef cydriver.CUgraphNodeType node_type with nogil: - HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type)) + HANDLE_RETURN(cydriver.cuGraphNodeGetType(as_cu(h_node), &node_type)) if node_type == cydriver.CU_GRAPH_NODE_TYPE_EMPTY: return EmptyNode._create_impl(h_node) @@ -595,10 +616,10 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker, _attach_user_object(as_cu(h_graph), new KernelHandle(ker._h_kernel), _destroy_kernel_handle_copy) - return KernelNode._create_with_params( + return _registered(KernelNode._create_with_params( create_graph_node_handle(new_node, h_graph), conf.grid, conf.block, conf.shmem_size, - ker._h_kernel) + ker._h_kernel)) cdef inline EmptyNode GN_join(GraphNode self, tuple nodes): @@ -624,7 +645,7 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes): HANDLE_RETURN(cydriver.cuGraphAddEmptyNode( &new_node, as_cu(h_graph), deps_ptr, num_deps)) - return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)) + return _registered(EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph))) cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options): @@ -700,9 +721,9 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options): HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode( &new_node, as_cu(h_graph), deps, num_deps, &alloc_params)) - return AllocNode._create_with_params( + return _registered(AllocNode._create_with_params( create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size, - device_id, memory_type, tuple(peer_ids)) + device_id, memory_type, tuple(peer_ids))) cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr): @@ -720,7 +741,7 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr): HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode( &new_node, as_cu(h_graph), deps, num_deps, c_dptr)) - return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr) + return _registered(FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr)) cdef inline MemsetNode GN_memset( @@ -755,9 +776,9 @@ cdef inline MemsetNode GN_memset( &new_node, as_cu(h_graph), deps, num_deps, &memset_params, ctx)) - return MemsetNode._create_with_params( + return _registered(MemsetNode._create_with_params( create_graph_node_handle(new_node, h_graph), c_dst, - val, elem_size, width, height, pitch) + val, elem_size, width, height, pitch)) cdef inline MemcpyNode GN_memcpy( @@ -816,9 +837,9 @@ cdef inline MemcpyNode GN_memcpy( HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode( &new_node, as_cu(h_graph), deps, num_deps, ¶ms, ctx)) - return MemcpyNode._create_with_params( + return _registered(MemcpyNode._create_with_params( create_graph_node_handle(new_node, h_graph), c_dst, c_src, size, - c_dst_type, c_src_type) + c_dst_type, c_src_type)) cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def): @@ -843,8 +864,8 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def): cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph) - return ChildGraphNode._create_with_params( - create_graph_node_handle(new_node, h_graph), h_embedded) + return _registered(ChildGraphNode._create_with_params( + create_graph_node_handle(new_node, h_graph), h_embedded)) cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev): @@ -865,8 +886,8 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev): _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), _destroy_event_handle_copy) - return EventRecordNode._create_with_params( - create_graph_node_handle(new_node, h_graph), ev._h_event) + return _registered(EventRecordNode._create_with_params( + create_graph_node_handle(new_node, h_graph), ev._h_event)) cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev): @@ -887,8 +908,8 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev): _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), _destroy_event_handle_copy) - return EventWaitNode._create_with_params( - create_graph_node_handle(new_node, h_graph), ev._h_event) + return _registered(EventWaitNode._create_with_params( + create_graph_node_handle(new_node, h_graph), ev._h_event)) cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_data): @@ -914,6 +935,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_ &new_node, as_cu(h_graph), deps, num_deps, &node_params)) cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None - return HostCallbackNode._create_with_params( + return _registered(HostCallbackNode._create_with_params( create_graph_node_handle(new_node, h_graph), callable_obj, - node_params.fn, node_params.userData) + node_params.fn, node_params.userData)) diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index f847e60223..9e7307e821 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -186,7 +186,7 @@ cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHand # Graph node handles cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil -cdef void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept nogil +cdef void invalidate_graph_node(const GraphNodeHandle& h) noexcept nogil # Graphics resource handles cdef GraphicsResourceHandle create_graphics_resource_handle( diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index 001f9b4a0c..2090f5026d 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -159,7 +159,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil GraphHandle graph_node_get_graph "cuda_core::graph_node_get_graph" ( const GraphNodeHandle& h) noexcept nogil - void invalidate_graph_node_handle "cuda_core::invalidate_graph_node_handle" ( + void invalidate_graph_node "cuda_core::invalidate_graph_node" ( const GraphNodeHandle& h) noexcept nogil # Graphics resource handles diff --git a/cuda_core/tests/graph/test_graphdef.py b/cuda_core/tests/graph/test_graphdef.py index be6da9515a..562f720ca8 100644 --- a/cuda_core/tests/graph/test_graphdef.py +++ b/cuda_core/tests/graph/test_graphdef.py @@ -661,6 +661,7 @@ def test_node_type_preserved_by_nodes(node_spec): matched = [n for n in all_nodes if n == node] assert len(matched) == 1 assert isinstance(matched[0], spec.roundtrip_class) + assert matched[0] is node def test_node_type_preserved_by_pred_succ(node_spec): @@ -670,6 +671,7 @@ def test_node_type_preserved_by_pred_succ(node_spec): matched = [s for s in predecessor.succ if s == node] assert len(matched) == 1 assert isinstance(matched[0], spec.roundtrip_class) + assert matched[0] is node def test_node_attrs(node_spec): @@ -697,6 +699,31 @@ def test_node_attrs_preserved_by_nodes(node_spec): assert getattr(retrieved, attr) == getattr(node, attr), f"{spec.name}.{attr} not preserved by nodes()" +def test_identity_preservation(init_cuda): + """Round-trips through nodes(), edges(), and pred/succ return extant + objects rather than duplicates.""" + g = GraphDef() + a = g.join() + b = a.join() + + # nodes() + assert any(x is a for x in g.nodes()) + assert any(x is b for x in g.nodes()) + + # succ/pred + a.succ = {b} + (b2,) = a.succ + assert b2 is b + + (a2,) = b.pred + assert a2 is a + + # edges() + ((a2, b2),) = g.edges() + assert a2 is a + assert b2 is b + + # ============================================================================= # GraphDef basics # ============================================================================= diff --git a/cuda_core/tests/graph/test_graphdef_mutation.py b/cuda_core/tests/graph/test_graphdef_mutation.py index dcfd4aab89..ac0d8f5e61 100644 --- a/cuda_core/tests/graph/test_graphdef_mutation.py +++ b/cuda_core/tests/graph/test_graphdef_mutation.py @@ -380,7 +380,7 @@ def test_convert_linear_to_fan_in(init_cuda): for node in g.nodes(): if isinstance(node, MemsetNode): node.pred.clear() - elif isinstance(node, KernelNode) and node != reduce_node: + elif isinstance(node, KernelNode) and node is not reduce_node: node.succ.add(reduce_node) assert len(g.edges()) == 8