Add ParametricAttention.v2 #913

Merged · 7 commits · Dec 14, 2023
3 changes: 2 additions & 1 deletion thinc/api.py
@@ -41,6 +41,7 @@
MultiSoftmax,
MXNetWrapper,
ParametricAttention,
ParametricAttention_v2,
PyTorchLSTM,
PyTorchRNNWrapper,
PyTorchWrapper,
@@ -207,7 +208,7 @@
"PyTorchWrapper", "PyTorchRNNWrapper", "PyTorchLSTM",
"TensorFlowWrapper", "keras_subclass", "MXNetWrapper",
"PyTorchWrapper_v2", "Softmax_v2", "PyTorchWrapper_v3",
"SparseLinear_v2", "TorchScriptWrapper_v1",
"SparseLinear_v2", "TorchScriptWrapper_v1", "ParametricAttention_v2",

"add", "bidirectional", "chain", "clone", "concatenate", "noop",
"residual", "uniqued", "siamese", "list2ragged", "ragged2list",
2 changes: 2 additions & 0 deletions thinc/layers/__init__.py
@@ -35,6 +35,7 @@
from .noop import noop
from .padded2list import padded2list
from .parametricattention import ParametricAttention
from .parametricattention_v2 import ParametricAttention_v2
from .premap_ids import premap_ids
from .pytorchwrapper import (
PyTorchRNNWrapper,
@@ -94,6 +95,7 @@
"Mish",
"MultiSoftmax",
"ParametricAttention",
"ParametricAttention_v2",
"PyTorchLSTM",
"PyTorchWrapper",
"PyTorchWrapper_v2",
100 changes: 100 additions & 0 deletions thinc/layers/parametricattention_v2.py
@@ -0,0 +1,100 @@
from typing import Callable, Optional, Tuple, cast

from ..config import registry
from ..model import Model
from ..types import Floats2d, Ragged
from ..util import get_width
from .noop import noop

InT = Ragged
OutT = Ragged

KEY_TRANSFORM_REF: str = "key_transform"


@registry.layers("ParametricAttention.v2")
def ParametricAttention_v2(
    *,
    key_transform: Optional[Model[Floats2d, Floats2d]] = None,
    nO: Optional[int] = None
) -> Model[InT, OutT]:
    """Weight inputs by similarity to a learned vector"""
    if key_transform is None:
        key_transform = noop()

    return Model(
        "para-attn",
        forward,
        init=init,
        params={"Q": None},
        dims={"nO": nO},
        refs={KEY_TRANSFORM_REF: key_transform},
        layers=[key_transform],
    )


def forward(model: Model[InT, OutT], Xr: InT, is_train: bool) -> Tuple[OutT, Callable]:
    Q = model.get_param("Q")
    key_transform = model.get_ref(KEY_TRANSFORM_REF)

    attention, bp_attention = _get_attention(
        model.ops, Q, key_transform, Xr.dataXd, Xr.lengths, is_train
    )
    output, bp_output = _apply_attention(model.ops, attention, Xr.dataXd, Xr.lengths)

    def backprop(dYr: OutT) -> InT:
        dX, d_attention = bp_output(dYr.dataXd)
        dQ, dX2 = bp_attention(d_attention)
        model.inc_grad("Q", dQ.ravel())
        dX += dX2
        return Ragged(dX, dYr.lengths)

    return Ragged(output, Xr.lengths), backprop


def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    key_transform = model.get_ref(KEY_TRANSFORM_REF)
    width = get_width(X) if X is not None else None
    if width:
        model.set_dim("nO", width)
        if key_transform.has_dim("nO"):
            key_transform.set_dim("nO", width)

    # Randomly initialize the parameter, as though it were an embedding.
    Q = model.ops.alloc1f(model.get_dim("nO"))
    Q += model.ops.xp.random.uniform(-0.1, 0.1, Q.shape)
    model.set_param("Q", Q)

    X_array = X.dataXd if X is not None else None
    Y_array = Y.dataXd if Y is not None else None

    key_transform.initialize(X_array, Y_array)


def _get_attention(ops, Q, key_transform, X, lengths, is_train):
    K, K_bp = key_transform(X, is_train=is_train)

    attention = ops.gemm(K, ops.reshape2f(Q, -1, 1))
    attention = ops.softmax_sequences(attention, lengths)

    def get_attention_bwd(d_attention):
        d_attention = ops.backprop_softmax_sequences(d_attention, attention, lengths)
        dQ = ops.gemm(K, d_attention, trans1=True)
        dY = ops.xp.outer(d_attention, Q)
        dX = K_bp(dY)
        return dQ, dX

    return attention, get_attention_bwd


def _apply_attention(ops, attention, X, lengths):
    output = X * attention

    def apply_attention_bwd(d_output):
        d_attention = (X * d_output).sum(axis=1, keepdims=True)
        dX = d_output * attention
        return dX, d_attention

    return output, apply_attention_bwd
2 changes: 2 additions & 0 deletions thinc/tests/layers/test_layers_api.py
@@ -129,6 +129,8 @@ def assert_data_match(Y, out_data):
("MultiSoftmax.v1", {"nOs": (1, 3)}, array2d, array2d),
# ("CauchySimilarity.v1", {}, (array2d, array2d), array1d),
("ParametricAttention.v1", {}, ragged, ragged),
("ParametricAttention.v2", {}, ragged, ragged),
("ParametricAttention.v2", {"key_transform": {"@layers": "Gelu.v1"}}, ragged, ragged),
("SparseLinear.v1", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d),
("SparseLinear.v2", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d),
("remap_ids.v1", {"dtype": "f"}, ["a", 1, 5.0], array2dint),
10 changes: 10 additions & 0 deletions thinc/tests/layers/test_parametric_attention_v2.py
@@ -0,0 +1,10 @@
from thinc.layers.gelu import Gelu
from thinc.layers.parametricattention_v2 import (
    KEY_TRANSFORM_REF,
    ParametricAttention_v2,
)


def test_key_transform_used():
    attn = ParametricAttention_v2(key_transform=Gelu())
    assert attn.get_ref(KEY_TRANSFORM_REF).name == "gelu"
38 changes: 38 additions & 0 deletions website/docs/api-layers.md
@@ -686,6 +686,44 @@ attention mechanism.
https:/explosion/thinc/blob/master/thinc/layers/parametricattention.py
```

### ParametricAttention_v2 {#parametricattention_v2 tag="function"}

<inline-list>

- **Input:** <ndarray>Ragged</ndarray>
- **Output:** <ndarray>Ragged</ndarray>
- **Parameters:** <ndarray shape="nO,">Q</ndarray>

</inline-list>

A layer that uses the parametric attention scheme described by
[Yang et al. (2016)](https://aclanthology.org/N16-1174).
The layer learns a parameter vector that is used as the query in a single-headed
attention mechanism.

<infobox variant="warning">

The original `ParametricAttention` layer uses the hidden representations as-is
for the keys in the attention. This differs from the paper that introduces
parametric attention (Equation 5). `ParametricAttention_v2` adds the option to
transform the key representations in line with the paper by passing a
transformation through the `key_transform` argument.

</infobox>
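
For illustration (this snippet is not part of the PR), a minimal construction
sketch: passing a `key_transform` applies that layer to the inputs before they
are compared with the learned vector, while leaving it unset falls back to a
`noop` transform, matching the behaviour of the original `ParametricAttention`
layer.

```python
from thinc.api import Gelu, ParametricAttention_v2

# With a key transform, in line with Equation 5 of Yang et al. (2016).
attn = ParametricAttention_v2(key_transform=Gelu())

# Without a key transform: the keys are the inputs themselves,
# as in the original ParametricAttention layer.
attn_v1_style = ParametricAttention_v2()
```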


| Argument | Type | Description |
|-----------------|----------------------------------------------|------------------------------------------------------------------------|
| `key_transform` | <tt>Optional[Model[Floats2d, Floats2d]]</tt> | Transformation to apply to the key representations. Defaults to `None` (no transformation). |
| `nO` | <tt>Optional[int]</tt> | The size of the output vectors. |
| **RETURNS** | <tt>Model[Ragged, Ragged]</tt> | The created attention layer. |

```python
https:/explosion/thinc/blob/master/thinc/layers/parametricattention_v2.py
```
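
As a usage sketch (again, not prescribed by this PR), the layer can sit inside
a pooling block that weights the rows of each sequence and then sums them, in
the spirit of Yang et al. (2016). The width of 16 and the surrounding layers
(`list2ragged`, `reduce_sum`, `Gelu`) are arbitrary choices for the example.

```python
import numpy
from thinc.api import Gelu, ParametricAttention_v2, chain, list2ragged, reduce_sum

width = 16  # arbitrary example width
pooler = chain(
    list2ragged(),  # List[Floats2d] -> Ragged
    ParametricAttention_v2(key_transform=Gelu(nO=width, nI=width)),
    reduce_sum(),   # sum the attention-weighted rows: Ragged -> Floats2d
)

X = [
    numpy.random.uniform(-1, 1, (4, width)).astype("f"),
    numpy.random.uniform(-1, 1, (7, width)).astype("f"),
]
pooler.initialize(X=X)
Y = pooler.predict(X)  # one pooled vector per sequence, shape (2, width)
```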



### Relu {#relu tag="function"}

<inline-list>