Skip to content

Commit

Permalink
Add more ops (including all Reduce ops) using the relay frontend (apache#11)
Browse files Browse the repository at this point in the history

* [WIP] add more ops. Some fail at the moment

* skip some tests

* Remove duplicate tests for squeeze
  • Loading branch information
Florin Blanaru authored and Josh Fromm committed Feb 3, 2023
1 parent 1f57626 commit e96d033
Show file tree
Hide file tree
Showing 2 changed files with 843 additions and 9 deletions.
26 changes: 24 additions & 2 deletions python/tvm/relax/frontend/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,26 @@ def _get_convert_map(opset):
"Constant": Constant,
"Sub": Sub,
"LayerNormalization": relay.frontend.onnx.LayerNormalization,
"SkipLayerNormalization": relay.frontend.onnx.SkipLayerNormalization,
"EmbedLayerNormalization": relay.frontend.onnx.EmbedLayerNormalization,
# defs/reduction
"ReduceMax": relay.frontend.onnx.ReduceMax,
"ReduceMin": relay.frontend.onnx.ReduceMin,
"ReduceSum": relay.frontend.onnx.ReduceSum,
"ReduceMean": relay.frontend.onnx.ReduceMean,
"ReduceProd": relay.frontend.onnx.ReduceProd,
"ReduceLogSumExp": relay.frontend.onnx.ReduceLogSumExp,
"ReduceLogSum": relay.frontend.onnx.ReduceLogSum,
"ReduceSumSquare": relay.frontend.onnx.ReduceSumSquare,
"ReduceL1": relay.frontend.onnx.ReduceL1,
"ReduceL2": relay.frontend.onnx.ReduceL2,
"Expand": relay.frontend.onnx.Expand,
"ConstantOfShape": relay.frontend.onnx.ConstantOfShape,
"Slice": relay.frontend.onnx.Slice,
"Attention": relay.frontend.onnx.Attention,
"Pad": relay.frontend.onnx.Pad,
"Split": relay.frontend.onnx.Split,
"Tile": relay.frontend.onnx.Tile,
}


Expand Down Expand Up @@ -780,7 +800,7 @@ def _parse_attr(self, attr_proto):

def _relay_input_adapter(self, inputs):
"""Creates equivalent input Relay vars from the input Relax vars"""
relay_vars = []
relay_vars = onnx_input()
for relax_var in inputs:
shape_values = []
for shape_value in relax_var.struct_info.shape.values:
Expand Down Expand Up @@ -896,7 +916,9 @@ def _convert_operator(self, op_name, inputs, attrs, opset):
if issubclass(convert_class, RelayOnnxOpConverter):
relay_inputs = self._relay_input_adapter(inputs)
# The op_function might change the inputs to the relay op. Use a copy of the inputs.
relay_inputs_copy = [relay_input for relay_input in relay_inputs]
relay_inputs_copy = onnx_input()
for relay_input in relay_inputs:
relay_inputs_copy.append(relay_input)
# TODO handle params passing
relay_output = op_function(relay_inputs_copy, attrs, params=[])
sym = self._relay_output_adapter(inputs, relay_inputs, relay_output)
Expand Down
Loading

0 comments on commit e96d033

Please sign in to comment.