-
Notifications
You must be signed in to change notification settings - Fork 266
Add GraphNode identity cache for stable object round-trips #1853
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dc92437
281ed82
7854b76
5fbd288
aa84e26
b27dd93
8554d30
9813c20
6411881
7a3dbb4
91b3b4e
1b7743d
84f0b30
64d6c2d
a40be9a
729af49
42131b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -416,7 +416,7 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_ | |
| GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept; | ||
|
|
||
| // Zero the CUgraphNode resource inside the handle, marking it invalid. | ||
| void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept; | ||
| void invalidate_graph_node(const GraphNodeHandle& h) noexcept; | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Following on from this discussion, this is a better name. |
||
|
|
||
| // ============================================================================ | ||
| // Graphics resource handle functions | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,7 +48,7 @@ from cuda.core._resource_handles cimport ( | |
| create_graph_handle_ref, | ||
| create_graph_node_handle, | ||
| graph_node_get_graph, | ||
| invalidate_graph_node_handle, | ||
| invalidate_graph_node, | ||
| ) | ||
| from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value | ||
|
|
||
|
|
@@ -57,10 +57,19 @@ from cuda.core._graph._utils cimport ( | |
| _attach_user_object, | ||
| ) | ||
|
|
||
| import weakref | ||
|
|
||
| from cuda.core import Device | ||
| from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy | ||
| from cuda.core._utils.cuda_utils import driver, handle_return | ||
|
|
||
| _node_registry = weakref.WeakValueDictionary() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think you explained this live in one of the PR review meetings, but since it's subtle enough, it might be worth documenting why there is a registry on both the C++ and Python sides and how they interact (if at all). |
||
|
|
||
|
|
||
| cdef inline GraphNode _registered(GraphNode n): | ||
| _node_registry[<uintptr_t>n._h_node.get()] = n | ||
| return n | ||
|
|
||
|
|
||
| cdef class GraphNode: | ||
| """Base class for all graph nodes. | ||
|
|
@@ -144,7 +153,8 @@ cdef class GraphNode: | |
| return | ||
| with nogil: | ||
| HANDLE_RETURN(cydriver.cuGraphDestroyNode(node)) | ||
| invalidate_graph_node_handle(self._h_node) | ||
| _node_registry.pop(<uintptr_t>self._h_node.get(), None) | ||
| invalidate_graph_node(self._h_node) | ||
|
|
||
| @property | ||
| def pred(self): | ||
|
|
@@ -522,18 +532,29 @@ cdef inline ConditionalNode _make_conditional_node( | |
| n._cond_type = cond_type | ||
| n._branches = branches | ||
|
|
||
| return n | ||
| return _registered(n) | ||
|
|
||
| cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node): | ||
| cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph) | ||
|
|
||
| # Sentinel: virtual node to represent the graph entry point. | ||
| if node == NULL: | ||
| n = GraphNode.__new__(GraphNode) | ||
| (<GraphNode>n)._h_node = create_graph_node_handle(node, h_graph) | ||
| (<GraphNode>n)._h_node = h_node | ||
| return n | ||
|
|
||
| cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph) | ||
| # Return a registered object or create and register a new one. | ||
| registered = _node_registry.get(<uintptr_t>h_node.get()) | ||
| if registered is not None: | ||
| return <GraphNode>registered | ||
| else: | ||
| return _registered(GN_create_impl(h_node)) | ||
|
|
||
|
|
||
| cdef inline GraphNode GN_create_impl(GraphNodeHandle h_node): | ||
| cdef cydriver.CUgraphNodeType node_type | ||
| with nogil: | ||
| HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type)) | ||
| HANDLE_RETURN(cydriver.cuGraphNodeGetType(as_cu(h_node), &node_type)) | ||
|
|
||
| if node_type == cydriver.CU_GRAPH_NODE_TYPE_EMPTY: | ||
| return EmptyNode._create_impl(h_node) | ||
|
|
@@ -595,10 +616,10 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker, | |
| _attach_user_object(as_cu(h_graph), <void*>new KernelHandle(ker._h_kernel), | ||
| <cydriver.CUhostFn>_destroy_kernel_handle_copy) | ||
|
|
||
| return KernelNode._create_with_params( | ||
| return _registered(KernelNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), | ||
| conf.grid, conf.block, conf.shmem_size, | ||
| ker._h_kernel) | ||
| ker._h_kernel)) | ||
|
|
||
|
|
||
| cdef inline EmptyNode GN_join(GraphNode self, tuple nodes): | ||
|
|
@@ -624,7 +645,7 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes): | |
| HANDLE_RETURN(cydriver.cuGraphAddEmptyNode( | ||
| &new_node, as_cu(h_graph), deps_ptr, num_deps)) | ||
|
|
||
| return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)) | ||
| return _registered(EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph))) | ||
|
|
||
|
|
||
| cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options): | ||
|
|
@@ -700,9 +721,9 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options): | |
| HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode( | ||
| &new_node, as_cu(h_graph), deps, num_deps, &alloc_params)) | ||
|
|
||
| return AllocNode._create_with_params( | ||
| return _registered(AllocNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size, | ||
| device_id, memory_type, tuple(peer_ids)) | ||
| device_id, memory_type, tuple(peer_ids))) | ||
|
|
||
|
|
||
| cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr): | ||
|
|
@@ -720,7 +741,7 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr): | |
| HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode( | ||
| &new_node, as_cu(h_graph), deps, num_deps, c_dptr)) | ||
|
|
||
| return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr) | ||
| return _registered(FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr)) | ||
|
|
||
|
|
||
| cdef inline MemsetNode GN_memset( | ||
|
|
@@ -755,9 +776,9 @@ cdef inline MemsetNode GN_memset( | |
| &new_node, as_cu(h_graph), deps, num_deps, | ||
| &memset_params, ctx)) | ||
|
|
||
| return MemsetNode._create_with_params( | ||
| return _registered(MemsetNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), c_dst, | ||
| val, elem_size, width, height, pitch) | ||
| val, elem_size, width, height, pitch)) | ||
|
|
||
|
|
||
| cdef inline MemcpyNode GN_memcpy( | ||
|
|
@@ -816,9 +837,9 @@ cdef inline MemcpyNode GN_memcpy( | |
| HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode( | ||
| &new_node, as_cu(h_graph), deps, num_deps, &params, ctx)) | ||
|
|
||
| return MemcpyNode._create_with_params( | ||
| return _registered(MemcpyNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), c_dst, c_src, size, | ||
| c_dst_type, c_src_type) | ||
| c_dst_type, c_src_type)) | ||
|
|
||
|
|
||
| cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def): | ||
|
|
@@ -843,8 +864,8 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def): | |
|
|
||
| cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph) | ||
|
|
||
| return ChildGraphNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), h_embedded) | ||
| return _registered(ChildGraphNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), h_embedded)) | ||
|
|
||
|
|
||
| cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev): | ||
|
|
@@ -865,8 +886,8 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev): | |
| _attach_user_object(as_cu(h_graph), <void*>new EventHandle(ev._h_event), | ||
| <cydriver.CUhostFn>_destroy_event_handle_copy) | ||
|
|
||
| return EventRecordNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), ev._h_event) | ||
| return _registered(EventRecordNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), ev._h_event)) | ||
|
|
||
|
|
||
| cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev): | ||
|
|
@@ -887,8 +908,8 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev): | |
| _attach_user_object(as_cu(h_graph), <void*>new EventHandle(ev._h_event), | ||
| <cydriver.CUhostFn>_destroy_event_handle_copy) | ||
|
|
||
| return EventWaitNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), ev._h_event) | ||
| return _registered(EventWaitNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), ev._h_event)) | ||
|
|
||
|
|
||
| cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_data): | ||
|
|
@@ -914,6 +935,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_ | |
| &new_node, as_cu(h_graph), deps, num_deps, &node_params)) | ||
|
|
||
| cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None | ||
| return HostCallbackNode._create_with_params( | ||
| return _registered(HostCallbackNode._create_with_params( | ||
| create_graph_node_handle(new_node, h_graph), callable_obj, | ||
| node_params.fn, node_params.userData) | ||
| node_params.fn, node_params.userData)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -661,6 +661,7 @@ def test_node_type_preserved_by_nodes(node_spec): | |
| matched = [n for n in all_nodes if n == node] | ||
| assert len(matched) == 1 | ||
| assert isinstance(matched[0], spec.roundtrip_class) | ||
| assert matched[0] is node | ||
|
|
||
|
|
||
| def test_node_type_preserved_by_pred_succ(node_spec): | ||
|
|
@@ -670,6 +671,7 @@ def test_node_type_preserved_by_pred_succ(node_spec): | |
| matched = [s for s in predecessor.succ if s == node] | ||
| assert len(matched) == 1 | ||
| assert isinstance(matched[0], spec.roundtrip_class) | ||
| assert matched[0] is node | ||
|
|
||
|
|
||
| def test_node_attrs(node_spec): | ||
|
|
@@ -697,6 +699,31 @@ def test_node_attrs_preserved_by_nodes(node_spec): | |
| assert getattr(retrieved, attr) == getattr(node, attr), f"{spec.name}.{attr} not preserved by nodes()" | ||
|
|
||
|
|
||
| def test_identity_preservation(init_cuda): | ||
| """Round-trips through nodes(), edges(), and pred/succ return extant | ||
| objects rather than duplicates.""" | ||
| g = GraphDef() | ||
| a = g.join() | ||
| b = a.join() | ||
|
|
||
| # nodes() | ||
| assert any(x is a for x in g.nodes()) | ||
| assert any(x is b for x in g.nodes()) | ||
|
|
||
| # succ/pred | ||
| a.succ = {b} | ||
| (b2,) = a.succ | ||
| assert b2 is b | ||
|
|
||
| (a2,) = b.pred | ||
| assert a2 is a | ||
|
|
||
| # edges() | ||
| ((a2, b2),) = g.edges() | ||
| assert a2 is a | ||
| assert b2 is b | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since a potential failure mode here might be unintentional leaking of references, maybe it's also worth testing something like: [code snippet not captured in this extract] and then asserting that the registry is empty. |
||
|
|
||
| # ============================================================================= | ||
| # GraphDef basics | ||
| # ============================================================================= | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `expired()` check was always unnecessary. For `GraphNode`s it is harmful. If invalidated `GraphNode` handles are not purged, the driver reusing an old pointer could cause a false hit.