-
Notifications
You must be signed in to change notification settings - Fork 537
Expand file tree
/
Copy pathSubGraphUtility.cs
More file actions
177 lines (164 loc) · 6.78 KB
/
SubGraphUtility.cs
File metadata and controls
177 lines (164 loc) · 6.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
using System;
using System.Collections.Generic;
using System.Text;
using System.Linq;
using static Tensorflow.Binding;
namespace Tensorflow.Graphs
{
public class SubGraphUtility
{
/// <summary>
/// Copies the tensor and all its inputs recursively to the outer graph.
/// </summary>
/// <param name="tensors"></param>
/// <param name="graph"></param>
/// <param name="add_sources"></param>
/// <param name="handle_captures"></param>
/// <param name="base_graph"></param>
/// <returns></returns>
public static Dictionary<ITensorOrOperation, Operation> lift_to_graph(Tensors init_tensors,
FuncGraph graph,
List<Tensor> sources,
bool add_sources = false,
bool handle_captures = false,
Graph base_graph = null,
Dictionary<ITensorOrOperation, Operation> op_map = null)
{
base_graph = base_graph ?? init_tensors[0].graph;
op_map = op_map ?? new Dictionary<ITensorOrOperation, Operation>();
var visited_ops = sources.Select(x => x.op).ToList();
foreach (var init_tensor in init_tensors)
{
var src = map_subgraph(init_tensor, sources, visited_ops, add_sources);
sources.AddRange(src);
}
var ops_to_copy = new List<Operation>();
var marked_ops = new List<Operation>();
var ops_to_visit = new Stack<Operation>(init_tensors.Select(x => x.op));
var unvisited_ops = new List<Operation>(ops_to_visit.ToList());
while (unvisited_ops.Count > 0)
{
while(ops_to_visit.Count > 0)
{
var op = ops_to_visit.Pop();
if (marked_ops.Contains(op))
continue;
marked_ops.Add(op);
ops_to_copy.append(op);
foreach(var inp in op.inputs)
{
}
}
// difference_update
unvisited_ops.difference_update(marked_ops);
if (unvisited_ops.Count > 0)
ops_to_visit.Push(unvisited_ops.Last());
}
// When lifting from one FuncGraph to another, we will need to capture the
// relevant tensors as well.
var inverse_captures = new Dictionary<Tensor, Tensor>();
Tensor[] internal_captures = null;
if (base_graph is FuncGraph base_func_graph)
{
var captures = base_func_graph.captures;
foreach (var (external_capture, internal_capture) in captures)
inverse_captures[internal_capture] = external_capture;
internal_captures = base_func_graph.internal_captures;
}
graph.as_default();
var source_ops = new List<Operation>();
// Add the sources in the same order as the original graph.
foreach (var s in internal_captures)
{
if (sources.Contains(s))
{
sources.Remove(s);
source_ops.Add(s.op);
_copy_source(s: s,
graph: graph,
op_map: op_map,
handle_captures: handle_captures,
inverse_captures: inverse_captures,
base_graph: base_graph);
}
}
foreach(var op in reversed(ops_to_copy))
{
if (source_ops.Contains(op) || op_map.ContainsKey(op))
continue;
_copy_non_source(op, graph, op_map, base_graph);
}
graph.Exit();
return op_map;
}
static void _copy_source(Tensor s,
FuncGraph graph,
Dictionary<ITensorOrOperation, Operation> op_map,
bool handle_captures,
Dictionary<Tensor, Tensor> inverse_captures,
Graph base_graph)
{
Tensor copied_placeholder = null;
if (handle_captures && inverse_captures.ContainsKey(s))
copied_placeholder = graph.capture(inverse_captures[s], name: s.op.name);
else
throw new NotImplementedException("");
op_map[s] = copied_placeholder;
// Add an entry for the op of the source tensor so that if there are any nodes
// depending on that op via control dependencies it can work correctly.
op_map[s.op] = copied_placeholder.op;
}
static void _copy_non_source(Operation op, FuncGraph graph, Dictionary<ITensorOrOperation, Operation> op_map, Graph base_graph)
{
Operation copied_op = null;
var copied_inputs = new Tensors();
tf_with(ops.control_dependencies(new object[] { op }), delegate
{
// Create a new op in the destination graph if it doesn't exist before.
var attrs = new Dictionary<string, AttrValue>();
foreach (var attr_def in op.node_def.Attr)
attrs[attr_def.Key] = attr_def.Value;
copied_op = graph.create_op(op.type,
copied_inputs,
dtypes: op.outputs.Select(x => x.dtype).ToArray(),
attrs: attrs,
name: op.name);
});
op_map[op] = copied_op;
foreach (var (i, o) in enumerate(op.outputs))
op_map[o] = copied_op.outputs[i];
}
/// <summary>
/// Walk a Graph and capture the subgraph between init_tensor and sources.
/// </summary>
/// <param name="init_tensor"></param>
/// <param name="add_sources"></param>
public static List<Tensor> map_subgraph(Tensor init_tensor,
List<Tensor> sources,
List<Operation> visited_ops,
bool add_sources)
{
var ops_to_visit = new Stack<Operation>();
ops_to_visit.Push(init_tensor.op);
var extra_sources = new List<Tensor>();
while (ops_to_visit.Count > 0)
{
var op = ops_to_visit.Pop();
if (visited_ops.Contains(op))
continue;
visited_ops.Add(op);
bool should_raise = false;
if (should_raise)
throw new RuntimeError($"Unable to lift tensor {init_tensor.name}.");
if(op.type == "Placeholder")
{
extra_sources.AddRange(op.outputs);
}
foreach(var inp in op.inputs)
{
}
}
return extra_sources;
}
}
}