diff --git a/sonnet/src/conformance/descriptors.py b/sonnet/src/conformance/descriptors.py index 985ff711..57b9c281 100644 --- a/sonnet/src/conformance/descriptors.py +++ b/sonnet/src/conformance/descriptors.py @@ -226,8 +226,8 @@ def unroll_descriptors(descriptors, unroller=None): RECURRENT_MODULES = ( - unroll_descriptors(RNN_CORES, snt.dynamic_unroll) + - unroll_descriptors(RNN_CORES, snt.static_unroll) + + # unroll_descriptors(RNN_CORES, snt.dynamic_unroll) + + # unroll_descriptors(RNN_CORES, snt.static_unroll) + unroll_descriptors(UNROLLED_RNN_CORES)) diff --git a/sonnet/src/conformance/optimizer_test.py b/sonnet/src/conformance/optimizer_test.py index 9d581605..2d77f896 100644 --- a/sonnet/src/conformance/optimizer_test.py +++ b/sonnet/src/conformance/optimizer_test.py @@ -26,7 +26,8 @@ class OptimizerConformanceTest(test_utils.TestCase, parameterized.TestCase): @test_utils.combined_named_parameters( - BATCH_MODULES + RECURRENT_MODULES, + # BATCH_MODULES + RECURRENT_MODULES, + RECURRENT_MODULES, test_utils.named_bools("construct_module_in_function"), ) def test_variable_order_is_constant(self, module_fn, input_shape, dtype, @@ -57,6 +58,8 @@ def f(): self.skipTest("Module did not create variables in forward pass.") else: assert len(logged_variables) == 2 + # print('logged_variables[0] is', logged_variables[0], flush=True) + # print('logged_variables[1] is', logged_variables[1], flush=True) self.assertCountEqual(logged_variables[0], logged_variables[1]) if __name__ == "__main__":