Skip to content

Commit

Permalink
Attempt to fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Jan 25, 2022
1 parent 469bde0 commit 1acac78
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
17 changes: 11 additions & 6 deletions examples/sst2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,22 @@ def __call__(self, inputs: Array,
 class SimpleLSTM(nn.Module):
   """A simple unidirectional LSTM."""

+  def setup(self):
+    self.lstm = nn.OptimizedLSTMCell()
+
   @functools.partial(
       nn.transforms.scan,
       variable_broadcast='params',
       in_axes=1, out_axes=1,
       split_rngs={'params': False})
   @nn.compact
   def __call__(self, carry, x):
-    return nn.OptimizedLSTMCell()(carry, x)
+    return self.lstm(carry, x)

-  @staticmethod
-  def initialize_carry(batch_dims, hidden_size):
+  @nn.module.wrap_method_once
+  def initialize_carry(self, batch_dims, hidden_size):
     # Use fixed random key since default state init fn is just zeros.
-    return nn.OptimizedLSTMCell.initialize_carry(
+    return self.lstm.initialize_carry(
         jax.random.PRNGKey(0), batch_dims, hidden_size)


Expand All @@ -190,12 +193,14 @@ def __call__(self, embedded_inputs, lengths):
     batch_size = embedded_inputs.shape[0]

     # Forward LSTM.
-    initial_state = SimpleLSTM.initialize_carry((batch_size,), self.hidden_size)
+    initial_state = self.forward_lstm.initialize_carry((batch_size,),
+                                                       self.hidden_size)
     _, forward_outputs = self.forward_lstm(initial_state, embedded_inputs)

     # Backward LSTM.
     reversed_inputs = flip_sequences(embedded_inputs, lengths)
-    initial_state = SimpleLSTM.initialize_carry((batch_size,), self.hidden_size)
+    initial_state = self.backward_lstm.initialize_carry((batch_size,),
+                                                        self.hidden_size)
     _, backward_outputs = self.backward_lstm(initial_state, reversed_inputs)
     backward_outputs = flip_sequences(backward_outputs, lengths)

Expand Down
2 changes: 1 addition & 1 deletion examples/sst2/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_lstm_returns_correct_output_shape(self):
     rng = jax.random.PRNGKey(0)
     inputs = np.random.RandomState(0).normal(
         size=[batch_size, seq_len, embedding_size])
-    initial_state = models.SimpleLSTM.initialize_carry((batch_size,), hidden_size)
+    initial_state = model.initialize_carry((batch_size,), hidden_size)
     (_, output), _ = model.init_with_output(rng, initial_state, inputs)
     self.assertEqual((batch_size, seq_len, hidden_size), output.shape)

Expand Down

0 comments on commit 1acac78

Please sign in to comment.