diff --git a/morpheus/stages/inference/triton_inference_stage.py b/morpheus/stages/inference/triton_inference_stage.py index e5901363f9..e6c5c0fbb7 100644 --- a/morpheus/stages/inference/triton_inference_stage.py +++ b/morpheus/stages/inference/triton_inference_stage.py @@ -781,3 +781,13 @@ def _get_cpp_inference_node(self, builder: mrc.Builder) -> mrc.SegmentObject: self._needs_logits, self._input_mapping, self._output_mapping) + + def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject: + node = super()._build_single(builder, input_node) + + # ensure that the C++ impl only uses a single progress engine + if (self._build_cpp_node()): + node.launch_options.pe_count = 1 + node.launch_options.engines_per_pe = 1 + + return node