Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 669413396
  • Loading branch information
stanojevic committed Aug 30, 2024
1 parent 2afad64 commit 0801921
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion synjax/_src/utils/semirings.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,13 @@ def _sum(a):
def _sum_jvp(primals, tangents):
a, = primals
a_dot, = tangents
selection = self._one_hot_selection(a, axis)
if self.smoothing == "st-softmax":
# This is a special case because - gradient of
# special.straight_through_replace is correctly defined, but not used
# here because we are overriding the gradient with defjvp.
selection = jax.nn.softmax(a / self.temperature, axis=axis)
else:
selection = self._one_hot_selection(a, axis)
return jnp.sum(selection * a, axis), jnp.sum(selection * a_dot, axis)
return _sum(a)

Expand Down
2 changes: 1 addition & 1 deletion synjax/_src/utils/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def straight_through_replace(differentiable_input, non_differentiable_output):
raise ValueError("Shapes for straight-through replacement don't match.")
return tadd(jax.lax.stop_gradient(tsub(non_differentiable_output,
differentiable_input)),
non_differentiable_output)
differentiable_input)


def sparsemax(x: Array, axis: Union[int, Shape] = -1) -> Array:
Expand Down

0 comments on commit 0801921

Please sign in to comment.