-
Notifications
You must be signed in to change notification settings - Fork 537
Expand file tree
/
Copy pathAutoGraphAttribute.cs
More file actions
125 lines (112 loc) · 4.42 KB
/
AutoGraphAttribute.cs
File metadata and controls
125 lines (112 loc) · 4.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
using MethodBoundaryAspect.Fody.Attributes;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Functions;
using static Tensorflow.Binding;
namespace Tensorflow.Graphs
{
/// <summary>
/// Method-boundary aspect that traces the decorated method into a <see cref="ConcreteFunction"/>
/// on its first call and replays the cached function on later calls.
/// Port of func_graph.py func_graph_from_py_func.
/// </summary>
/// <remarks>
/// First call: <see cref="OnEntry"/> creates a <see cref="ConcreteFunction"/>, calls
/// <c>Enter()</c> on it, and replaces eager tensor arguments with placeholders named "inputs".
/// The method body then runs against those placeholders. <see cref="OnExit"/> builds the graph,
/// caches the function and runs it with the original eager inputs.
/// Later calls: <see cref="OnEntry"/> finds the cached function, calls it and sets
/// <c>FlowBehavior.Return</c> so the method body is skipped.
/// </remarks>
[AllowChangingInputArguments]
public sealed class AutoGraphAttribute : OnMethodBoundaryAspect
{
    ConcreteFunction function;
    // Eager tensors the caller actually passed; fed to the traced function after tracing.
    Tensors originalInputs;
    string func_name;
    // True between function.Enter() and the matching function.Exit(); makes ExitGraphMode
    // safe to call from OnExit and OnException without exiting twice.
    bool graphModeEntered;
    // Traced functions keyed by "DeclaringType.FullName.MethodName", shared by all instances.
    static readonly Dictionary<string, ConcreteFunction> functions = new Dictionary<string, ConcreteFunction>();

    /// <summary>
    /// Replays a cached function when one exists for this method; otherwise starts tracing by
    /// entering the new function and swapping eager tensor arguments for placeholders.
    /// </summary>
    /// <param name="args">Weaver-supplied call context; its arguments may be rewritten.</param>
    public override void OnEntry(MethodExecutionArgs args)
    {
        // TODO: the key ignores argument dtypes/shapes, so a method traced once is replayed
        // for any later inputs; consider adding them (FullName + Args) to the key.
        func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}";
        if (functions.TryGetValue(func_name, out var cached))
        {
            function = cached;
            args.ReturnValue = ConvertReturnValue(function.FilteredCall(CollectTensorArguments(args)));
            // Skip the method body; the cached graph already produced the result.
            args.FlowBehavior = FlowBehavior.Return;
            return;
        }

        // make function as an Operation by autograph;
        // the mode entered here is restored in OnExit (or OnException if the body throws).
        function = new ConcreteFunction(func_name);
        function.Enter();
        graphModeEntered = true;

        if (args.Arguments.Length > 0 && args.Arguments[0] is Tensors inputs)
        {
            originalInputs = inputs;
            var placeholders = inputs
                .Select(x => tf.placeholder(x.dtype, shape: x.shape, name: "inputs"))
                .ToArray();
            args.Arguments[0] = new Tensors(placeholders);
        }
        else
        {
            originalInputs = new Tensors();
            // Replace each eager tensor argument with a placeholder of the same dtype/shape.
            for (var i = 0; i < args.Arguments.Length; i++)
            {
                if (args.Arguments[i] is EagerTensor tensor)
                {
                    originalInputs.Add(tensor);
                    args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.shape, name: "inputs");
                }
            }
        }
    }

    /// <summary>
    /// Finishes tracing: wires the placeholders and returned tensors into the function's graph,
    /// leaves graph mode, caches the function and replaces the return value with the result of
    /// running it on the original eager inputs.
    /// </summary>
    /// <param name="args">Weaver-supplied call context holding the traced return value.</param>
    /// <exception cref="NotSupportedException">
    /// The method returned something other than a <see cref="Tensor"/> or <see cref="Tensors"/>.
    /// </exception>
    public override void OnExit(MethodExecutionArgs args)
    {
        try
        {
            var inputs = CollectTensorArguments(args)
                .Where(IsTracedInput)
                .ToArray();

            if (args.ReturnValue is Tensors outputs)
            {
                function.ToGraph(inputs, mark_as_return(outputs));
            }
            else if (args.ReturnValue is Tensor output)
            {
                function.ToGraph(inputs, array_ops.identity(output));
            }
            else
            {
                var actual = args.ReturnValue?.GetType().FullName ?? "null";
                throw new NotSupportedException(
                    $"AutoGraph method '{func_name}' must return Tensor or Tensors, but returned {actual}.");
            }
        }
        finally
        {
            // Always leave the mode entered in OnEntry, even when tracing failed.
            ExitGraphMode();
        }

        // Record the declared kind of result (not a runtime subclass) so ConvertReturnValue
        // hands back a Tensor for Tensor-returning methods.
        function.ReturnType = args.ReturnValue is Tensor ? typeof(Tensor) : args.ReturnValue.GetType();
        function._set_infer_function();
        functions[func_name] = function;

        // run function
        args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs));
    }

    /// <summary>
    /// Called when the traced method body throws: leaves the mode entered in
    /// <see cref="OnEntry"/>. Nothing is cached, so the next call traces again.
    /// </summary>
    /// <param name="args">Weaver-supplied call context carrying the exception.</param>
    public override void OnException(MethodExecutionArgs args)
    {
        ExitGraphMode();
    }

    /// <summary>Calls <c>function.Exit()</c> once per <c>Enter()</c>; later calls are no-ops.</summary>
    void ExitGraphMode()
    {
        if (!graphModeEntered)
            return;
        graphModeEntered = false;
        function.Exit();
    }

    /// <summary>
    /// Returns the call's tensor arguments: the first argument when it is a <see cref="Tensors"/>,
    /// otherwise every argument that is a <see cref="Tensor"/> (non-tensor arguments are skipped
    /// rather than passed on as null).
    /// </summary>
    static Tensors CollectTensorArguments(MethodExecutionArgs args)
    {
        if (args.Arguments.Length > 0 && args.Arguments[0] is Tensors tensors)
            return tensors;
        return new Tensors(args.Arguments.OfType<Tensor>().ToArray());
    }

    /// <summary>
    /// True for the placeholders created in <see cref="OnEntry"/>; matched by prefix,
    /// presumably because graph op names get uniquified (e.g. "inputs_1").
    /// </summary>
    static bool IsTracedInput(Tensor x)
        => x.op.OpType == "Placeholder" && x.op.name.StartsWith("inputs", StringComparison.Ordinal);

    /// <summary>
    /// Shapes a function result for the caller: a single <see cref="Tensor"/> when the traced
    /// method returned one, otherwise the <see cref="Tensors"/> unchanged.
    /// </summary>
    object ConvertReturnValue(Tensors tensors)
    {
        if (function.ReturnType == typeof(Tensor))
            return (Tensor)tensors;
        return tensors;
    }

    /// <summary>
    /// Acts like identity but marks the `Tensor` as a return value.
    /// </summary>
    /// <param name="tensors">Tensors returned by the traced method; may be null.</param>
    /// <returns>A new <see cref="Tensors"/> holding one identity op per input, or null for null input.</returns>
    public Tensors mark_as_return(Tensors tensors)
    {
        if (tensors == null)
            return null;
        var result = new Tensors();
        foreach (var tensor in tensors)
            result.Add(array_ops.identity(tensor));
        return result;
    }
}
}