From 99b1809e0bd721024190d486ff4125c7901644dd Mon Sep 17 00:00:00 2001 From: Sonnet Contributor Date: Wed, 18 Sep 2024 08:17:22 -0700 Subject: [PATCH] Resolves potential Nones caught by tytype. PiperOrigin-RevId: 675998642 --- sonnet/src/metrics.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/sonnet/src/metrics.py b/sonnet/src/metrics.py index fb871a0..aae94ed 100644 --- a/sonnet/src/metrics.py +++ b/sonnet/src/metrics.py @@ -62,7 +62,13 @@ def initialize(self, value: tf.Tensor): def update(self, value: tf.Tensor): """See base class.""" self.initialize(value) - self.sum.assign_add(value) + self._checked_sum.assign_add(value) + + @property + def _checked_sum(self): + if self.sum is None: + raise ValueError("Metric is not initialized. Call `initialize` first.") + return self.sum @property def value(self) -> tf.Tensor: @@ -71,6 +77,8 @@ def value(self) -> tf.Tensor: def reset(self): """See base class.""" + if self.sum is None: + raise ValueError("Metric is not initialized. Call `initialize` first.") self.sum.assign(tf.zeros_like(self.sum)) @@ -90,15 +98,23 @@ def initialize(self, value: tf.Tensor): def update(self, value: tf.Tensor): """See base class.""" self.initialize(value) - self.sum.assign_add(value) + self._checkedsum.assign_add(value) self.count.assign_add(1) + @property + def _checked_sum(self) -> tf.Variable: + if self.sum is None: + raise ValueError("Metric is not initialized. Call `initialize` first.") + return self.sum + @property def value(self) -> tf.Tensor: """See base class.""" # TODO(cjfj): Assert summed type is floating-point? - return self.sum / tf.cast(self.count, dtype=self.sum.dtype) + return self._checked_sum / tf.cast( + self.count, dtype=self._checked_sum.dtype + ) def reset(self): - self.sum.assign(tf.zeros_like(self.sum)) + self._checked_sum.assign(tf.zeros_like(self._checked_sum)) self.count.assign(0)