diff --git a/torch_nn_layers.py b/torch_nn_layers.py index b17d0e6..89e4e91 100644 --- a/torch_nn_layers.py +++ b/torch_nn_layers.py @@ -20,7 +20,7 @@ class TorchModel(Component): model_in: InArg[list] loss_in: InArg[str] - learning_rate = InArg[float] + learning_rate: InArg[float] optimizer_in: InArg[str] should_flatten: InArg[bool] model_config: OutArg[nn.Module] @@ -357,4 +357,4 @@ def execute(self, ctx) -> None: if self.model_in.value is None: self.model_out.value = [nn.Dropout(prob)] else: - self.model_out.value = self.model_in.value + [nn.Dropout(prob)] \ No newline at end of file + self.model_out.value = self.model_in.value + [nn.Dropout(prob)]