Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAML++: BNRS #327

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

* New vision example: MAML++. (@[DubiousCactus](https:/DubiousCactus))
* New BatchNorm layer with per-step running statistics and weights & biases from MAML++. (@[Théo Morales](https:/DubiousCactus))
* New vision example: MAML++. (@[Théo Morales](https:/DubiousCactus))
* Add tutorial: "Demystifying Task Transforms", ([Varad Pimpalkhute](https:/nightlessbaron/))

### Changed
Expand Down
321 changes: 0 additions & 321 deletions examples/vision/mamlpp/cnn4_bnrs.py

This file was deleted.

14 changes: 7 additions & 7 deletions examples/vision/mamlpp/maml++_miniimagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Tuple
from tqdm import tqdm

from examples.vision.mamlpp.cnn4_bnrs import CNN4_BNRS
from learn2learn.vision.models.cnn4_metabatchnorm import CNN4_MetaBatchNorm
from examples.vision.mamlpp.MAMLpp import MAMLpp


Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
)

# Model
self._model = CNN4_BNRS(ways, adaptation_steps=steps)
self._model = CNN4_MetaBatchNorm(ways, steps)
if self._use_cuda:
self._model.cuda()

Expand Down Expand Up @@ -147,19 +147,19 @@ def _training_step(
# Adapt the model on the support set
for step in range(self._steps):
# forward + backward + optimize
pred = learner(s_inputs, step)
pred = learner(s_inputs)
support_loss = self._inner_criterion(pred, s_labels)
learner.adapt(support_loss, first_order=not second_order, step=step)
# Multi-Step Loss
if msl:
q_pred = learner(q_inputs, step)
q_pred = learner(q_inputs)
query_loss += self._step_weights[step] * self._inner_criterion(
q_pred, q_labels
)

# Evaluate the adapted model on the query set
if not msl:
q_pred = learner(q_inputs, self._steps-1)
q_pred = learner(q_inputs, inference=True)
query_loss = self._inner_criterion(q_pred, q_labels)
acc = accuracy(q_pred, q_labels).detach()

Expand All @@ -180,12 +180,12 @@ def _testing_step(
# Adapt the model on the support set
for step in range(self._steps):
# forward + backward + optimize
pred = learner(s_inputs, step)
pred = learner(s_inputs)
support_loss = self._inner_criterion(pred, s_labels)
learner.adapt(support_loss, step=step)

# Evaluate the adapted model on the query set
q_pred = learner(q_inputs, self._steps-1)
q_pred = learner(q_inputs, inference=True)
query_loss = self._inner_criterion(q_pred, q_labels).detach()
acc = accuracy(q_pred, q_labels)

Expand Down
1 change: 1 addition & 0 deletions learn2learn/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .misc import *
from .protonet import PrototypicalClassifier
from .metaoptnet import SVClassifier
from .metabatchnorm import MetaBatchNorm
Loading