Skip to content

Commit

Permalink
Merge pull request bitsandbytes-foundation#14 from SirRob1997/main
Browse files Browse the repository at this point in the history
[FIX] pass the `sparse` argument through to `torch.nn.Embedding` in `StableEmbedding` (it was previously ignored and hard-coded to `False`)
  • Loading branch information
TimDettmers authored Nov 29, 2021
2 parents 037022e + 67a1283 commit 262350c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
class StableEmbedding(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
sparse: bool = True, _weight: Optional[Tensor] = None) -> None:
super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, False, _weight)
sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
self.norm = torch.nn.LayerNorm(embedding_dim)
GlobalOptimManager.get_instance().register_parameters(self.weight)
GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
Expand Down

0 comments on commit 262350c

Please sign in to comment.