
Allow additional parameters in compute_mat #697

Open
celsofranssa opened this issue Jun 3, 2024 · 10 comments
@celsofranssa

Hello,

The following code snippet shows a custom distance function that scales the simple dot-product distance by the reward associated with each reference embedding.

import torch
from pytorch_metric_learning.distances import BaseDistance


class CustomDistance(BaseDistance):

    def __init__(self, params):
        super().__init__(is_inverted=True)
        assert self.is_inverted
        # dict mapping each ref_emb id to its reward
        self.rewards = params.rewards

    def compute_mat(self, embeddings_ids, embeddings, ref_emb_ids, ref_emb):
        # scaled dot-product similarity between embeddings and ref_emb
        mat = 20 * torch.einsum("ab,cb->ac", embeddings, ref_emb)
        # ref_emb_ids maps column index -> ref_emb id; scale each column by its reward
        for col, ref_emb_idx in ref_emb_ids.items():
            mat[:, col] *= self.rewards[ref_emb_idx]
        return mat

However, the loss function's internal call, mat = self.distance(embeddings, ref_emb), does not allow calling the overridden compute_mat(self, embeddings_ids, embeddings, ref_emb_ids, ref_emb), which needs the embeddings and ref_emb ids to look up each embedding's reward and scale the corresponding distance value.

Is there a workaround?

Thank you.

@celsofranssa
Author

Hello @KevinMusgrave
Do you have any suggestions here?

@KevinMusgrave
Owner

Would overriding the forward method work?

def forward(self, query_emb, ref_emb=None):

I guess it depends on where embeddings_ids and ref_emb_ids are coming from.
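
Something like this, as a rough and untested sketch (query_ids and ref_ids are new argument names you would have to pass in yourself when calling the distance directly; the loss's internal self.distance(embeddings, ref_emb) call would not pass them):

class CustomDistance(BaseDistance):
    def forward(self, query_emb, ref_emb=None, query_ids=None, ref_ids=None):
        # stash the ids so compute_mat can look up rewards later
        self.query_ids = query_ids
        self.ref_ids = ref_ids
        return super().forward(query_emb, ref_emb)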

@KevinMusgrave
Owner

Or do you mean you need to change the mat = self.distance(embeddings, ref_emb) in the loss function?

@celsofranssa
Author

Or do you mean you need to change the mat = self.distance(embeddings, ref_emb) in the loss function?

Would overriding the forward method work?

Exactly @KevinMusgrave,

I am already using the very nice NTXentLoss. Recently, I discovered that scaling the simple dot distance with the rewards associated with each embedding is beneficial for my research task. At this point, I need to access the embedding identifier because each embedding has its associated reward.

@KevinMusgrave
Owner

I see, so you also need to modify NTXentLoss to take in those ids?

@celsofranssa
Author

I see, so you also need to modify NTXentLoss to take in those ids?

If there is no other overriding approach, yes.

@KevinMusgrave
Owner

KevinMusgrave commented Jun 5, 2024

Yeah I don't think there's another approach. The only other way would be to set some attribute of your custom distance object before computing the loss.

dist_fn = CustomDistance()
loss_fn = NTXentLoss(distance=dist_fn)

...
dist_fn.curr_ref_ids = ref_ids
loss = loss_fn(...)

# inside dist_fn refer to self.curr_ref_ids
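
For concreteness, a minimal self-contained sketch of that workaround, assuming the reward dict and the column-to-id mapping from the first snippet (curr_ref_ids is just an attribute you set before each loss call, not part of the library):

import torch
from pytorch_metric_learning.distances import BaseDistance

class CustomDistance(BaseDistance):
    def __init__(self, rewards, **kwargs):
        super().__init__(is_inverted=True, **kwargs)
        self.rewards = rewards      # dict: ref_emb id -> reward
        self.curr_ref_ids = None    # dict: column index -> ref_emb id, set before each loss call

    def compute_mat(self, query_emb, ref_emb):
        # same scaled dot product as in the first snippet
        mat = 20 * torch.einsum("ab,cb->ac", query_emb, ref_emb)
        # scale each column by the reward of its reference embedding
        for col, ref_emb_idx in self.curr_ref_ids.items():
            mat[:, col] *= self.rewards[ref_emb_idx]
        return mat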

@celsofranssa
Author

Yeah I don't think there's another approach. The only other way would be to set some attribute of your custom distance object before computing the loss.

dist_fn = CustomDistance()
loss_fn = NTXentLoss(distance=dist_fn)

...
dist_fn.curr_ref_ids = ref_ids
loss = loss_fn(...)

# inside dist_fn refer to self.curr_ref_ids

I see. For a feature release, it would be great if the mat = self.distance(embeddings, ref_emb) function accepts additional parameters to leverage some custom distance implementations.
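
For example, something like this (purely hypothetical, not the current API; the distance_kwargs plumbing is imagined):

# Hypothetical API: extra keyword arguments flow from the loss call
# down to the distance object.
loss = loss_fn(embeddings, labels, ref_emb=ref_emb,
               distance_kwargs={"ref_emb_ids": ref_emb_ids})

# and inside the loss, roughly:
#     mat = self.distance(embeddings, ref_emb, **distance_kwargs)
# so a custom compute_mat could receive the ids directly.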

@KevinMusgrave KevinMusgrave added the enhancement New feature or request label Jun 5, 2024
@KevinMusgrave KevinMusgrave added this to the v3.0 milestone Jun 5, 2024
@celsofranssa
Author

Hello,

Is there any progress here?

@KevinMusgrave
Owner

Sorry, no progress yet
