Commit
Add ParametricAttention.v2
This layer is an extension of the existing `ParametricAttention` layer,
adding support for transformations (such as a non-linear layer) of the
key representation. This brings the model closer to the paper that
suggested it (Yang et al., 2016) and gave slightly better results in
experiments.
danieldk committed Dec 12, 2023
1 parent c16f552 commit a2e178f
Showing 5 changed files with 152 additions and 1 deletion.
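
As a quick illustration of the new layer (a minimal sketch, not part of this commit's diff), it can be constructed with an optional `key_transform`; a `Gelu` layer is used here, as in the registry-based test case added below:

```python
import numpy

from thinc.api import Gelu, ParametricAttention_v2
from thinc.types import Ragged

# Two ragged sequences of lengths 3 and 2; each row is a 4-dimensional vector.
data = numpy.random.uniform(-1.0, 1.0, (5, 4)).astype("f")
Xr = Ragged(data, numpy.asarray([3, 2], dtype="i"))

# Attention with a Gelu transform applied to the key representations.
attn = ParametricAttention_v2(key_transform=Gelu())
attn.initialize(X=Xr)

Yr, backprop = attn(Xr, is_train=False)
# Each row of Yr.dataXd is the matching row of Xr.dataXd scaled by its
# attention weight; the weights are normalized per sequence.
```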
3 changes: 2 additions & 1 deletion thinc/api.py
@@ -41,6 +41,7 @@
    MultiSoftmax,
    MXNetWrapper,
    ParametricAttention,
    ParametricAttention_v2,
    PyTorchLSTM,
    PyTorchRNNWrapper,
    PyTorchWrapper,
@@ -207,7 +208,7 @@
"PyTorchWrapper", "PyTorchRNNWrapper", "PyTorchLSTM",
"TensorFlowWrapper", "keras_subclass", "MXNetWrapper",
"PyTorchWrapper_v2", "Softmax_v2", "PyTorchWrapper_v3",
"SparseLinear_v2", "TorchScriptWrapper_v1",
"SparseLinear_v2", "TorchScriptWrapper_v1", "ParametricAttention_v2",

"add", "bidirectional", "chain", "clone", "concatenate", "noop",
"residual", "uniqued", "siamese", "list2ragged", "ragged2list",
2 changes: 2 additions & 0 deletions thinc/layers/__init__.py
@@ -35,6 +35,7 @@
from .noop import noop
from .padded2list import padded2list
from .parametricattention import ParametricAttention
from .parametricattention_v2 import ParametricAttention_v2
from .premap_ids import premap_ids
from .pytorchwrapper import (
    PyTorchRNNWrapper,
@@ -94,6 +95,7 @@
"Mish",
"MultiSoftmax",
"ParametricAttention",
"ParametricAttention_v2",
"PyTorchLSTM",
"PyTorchWrapper",
"PyTorchWrapper_v2",
107 changes: 107 additions & 0 deletions thinc/layers/parametricattention_v2.py
@@ -0,0 +1,107 @@
from typing import Callable, Optional, Tuple, cast

from ..config import registry
from ..model import Model
from ..types import Floats2d, Ragged
from ..util import get_width

InT = Ragged
OutT = Ragged


@registry.layers("ParametricAttention.v2")
def ParametricAttention_v2(
    *,
    key_transform: Optional[Model[Floats2d, Floats2d]] = None,
    nO: Optional[int] = None
) -> Model[InT, OutT]:
    """Weight inputs by similarity to a learned vector."""
    if key_transform is None:
        layers = []
        refs = {}
    else:
        layers = [key_transform]
        refs = {"key_transform": cast(Optional[Model], key_transform)}

    return Model(
        "para-attn",
        forward,
        init=init,
        params={"Q": None},
        dims={"nO": nO},
        layers=layers,
        refs=refs,
    )


def forward(model: Model[InT, OutT], Xr: InT, is_train: bool) -> Tuple[OutT, Callable]:
    Q = model.get_param("Q")
    key_transform = model.maybe_get_ref("key_transform")

    attention, bp_attention = _get_attention(
        model.ops, Q, key_transform, Xr.dataXd, Xr.lengths, is_train
    )
    output, bp_output = _apply_attention(model.ops, attention, Xr.dataXd, Xr.lengths)

    def backprop(dYr: OutT) -> InT:
        dX, d_attention = bp_output(dYr.dataXd)
        dQ, dX2 = bp_attention(d_attention)
        model.inc_grad("Q", dQ.ravel())
        dX += dX2
        return Ragged(dX, dYr.lengths)

    return Ragged(output, Xr.lengths), backprop


def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    key_transform = model.maybe_get_ref("key_transform")
    width = get_width(X) if X is not None else None
    if width:
        model.set_dim("nO", width)
        if key_transform is not None:
            key_transform.set_dim("nO", width)

    # Randomly initialize the parameter, as though it were an embedding.
    Q = model.ops.alloc1f(model.get_dim("nO"))
    Q += model.ops.xp.random.uniform(-0.1, 0.1, Q.shape)
    model.set_param("Q", Q)

    X_array = X.dataXd if X is not None else None
    Y_array = Y.dataXd if Y is not None else None

    if key_transform is not None:
        key_transform.initialize(X_array, Y_array)


def _get_attention(ops, Q, key_transform, X, lengths, is_train):
    # Without a key transform the inputs themselves act as the keys, and the
    # backprop through the (identity) transform is a no-op.
    if key_transform is None:
        K, K_bp = X, lambda dY: dY
    else:
        K, K_bp = key_transform(X, is_train=is_train)

    # One attention weight per row, normalized over each sequence.
    attention = ops.gemm(K, ops.reshape2f(Q, -1, 1))
    attention = ops.softmax_sequences(attention, lengths)

    def get_attention_bwd(d_attention):
        d_attention = ops.backprop_softmax_sequences(d_attention, attention, lengths)
        dQ = ops.gemm(K, d_attention, trans1=True)
        dY = ops.xp.outer(d_attention, Q)
        dX = K_bp(dY)
        return dQ, dX

    return attention, get_attention_bwd


def _apply_attention(ops, attention, X, lengths):
    output = X * attention

    def apply_attention_bwd(d_output):
        d_attention = (X * d_output).sum(axis=1, keepdims=True)
        dX = d_output * attention
        return dX, d_attention

    return output, apply_attention_bwd
3 changes: 3 additions & 0 deletions thinc/tests/layers/test_layers_api.py
@@ -8,6 +8,7 @@
from thinc.api import Dropout, Model, NumpyOps, registry, with_padded
from thinc.backends import NumpyOps
from thinc.compat import has_torch
from thinc.layers.relu import Relu
from thinc.types import Array2d, Floats2d, FloatsXd, Padded, Ragged, Shape
from thinc.util import data_validation, get_width

@@ -129,6 +130,8 @@ def assert_data_match(Y, out_data):
("MultiSoftmax.v1", {"nOs": (1, 3)}, array2d, array2d),
# ("CauchySimilarity.v1", {}, (array2d, array2d), array1d),
("ParametricAttention.v1", {}, ragged, ragged),
("ParametricAttention.v2", {}, ragged, ragged),
("ParametricAttention.v2", {"key_transform": {"@layers": "Gelu.v1"}}, ragged, ragged),
("SparseLinear.v1", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d),
("SparseLinear.v2", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d),
("remap_ids.v1", {"dtype": "f"}, ["a", 1, 5.0], array2dint),
38 changes: 38 additions & 0 deletions website/docs/api-layers.md
@@ -686,6 +686,44 @@
https://github.com/explosion/thinc/blob/master/thinc/layers/parametricattention.py
```

### ParametricAttention_v2 {#parametricattention_v2 tag="function"}

<inline-list>

- **Input:** <ndarray>Ragged</ndarray>
- **Output:** <ndarray>Ragged</ndarray>
- **Parameters:** <ndarray shape="nO,">Q</ndarray>

</inline-list>

A layer that uses the parametric attention scheme described by
[Yang et al. (2016)](https://www.cs.cmu.edu/~./hovy/papers/16HLT-hierarchical-attention-networks.pdf).
The layer learns a parameter vector that is used as the keys in a single-headed
attention mechanism.

<infobox variant="warning">

The original `ParametricAttention` layer uses the hidden representations as-is
as the keys of the attention mechanism. This differs from the paper that
introduced parametric attention (Equation 5). `ParametricAttention_v2` adds the
option to transform the key representations in line with the paper by passing
such a transformation through the `key_transform` argument.

</infobox>


| Argument | Type | Description |
|-----------------|----------------------------------------------|------------------------------------------------------------------------|
| `key_transform` | <tt>Optional[Model[Floats2d, Floats2d]]</tt> | Transformation to apply to the key representations. Defaults to `None`. |
| `nO` | <tt>Optional[int]</tt> | The size of the output vectors. |
| **RETURNS** | <tt>Model[Ragged, Ragged]</tt> | The created attention layer. |

```python
https://github.com/explosion/thinc/blob/master/thinc/layers/parametricattention_v2.py
```
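
Since the layer is registered as `ParametricAttention.v2`, the `key_transform` can also be supplied through the config system. A minimal sketch for illustration, mirroring the registry-based test case added in this commit (which uses a `Gelu.v1` transform):

```python
from thinc.api import Config, registry

# The nested [model.key_transform] section is resolved by the registry and
# passed to ParametricAttention_v2 as the `key_transform` argument.
CONFIG = """
[model]
@layers = "ParametricAttention.v2"

[model.key_transform]
@layers = "Gelu.v1"
"""

model = registry.resolve(Config().from_str(CONFIG))["model"]
```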



### Relu {#relu tag="function"}

<inline-list>
