
Consistency of regression layers (student-t and standard) #12

Open
brunzema opened this issue Aug 20, 2024 · 0 comments

brunzema commented Aug 20, 2024

Hi,

I just found a subtle inconsistency in the regression layers, which I noticed when drawing posterior samples from the predictive.

For the standard MVN case, W is defined as a method:

    def W(self):
        cov_diag = torch.exp(self.W_logdiag)
        if self.W_dist == Normal:
            cov = self.W_dist(self.W_mean, cov_diag)
        elif self.W_dist == DenseNormal:
            tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag)
            cov = self.W_dist(self.W_mean, tril)
        elif self.W_dist == LowRankNormal:
            cov = self.W_dist(self.W_mean, self.W_offdiag, cov_diag)

        return cov

whereas for the t-VBLL regression layer, it is defined as a property:

    @property
    def W(self):
        cov_diag = torch.exp(self.W_logdiag)
        if self.W_dist == Normal:
            cov = self.W_dist(self.W_mean, cov_diag)
        elif self.W_dist == DenseNormal:
            tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag)
            cov = self.W_dist(self.W_mean, tril)
        elif self.W_dist == LowRankNormal:
            cov = self.W_dist(self.W_mean, self.W_offdiag, cov_diag)

        return cov

This then changes how you sample from W:

  • for VBLL: layer.W().rsample()
  • for tVBLL: layer.W.rsample()
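To make the difference concrete, here is a minimal sketch that does not use the library itself. `ToyNormal`, `MethodLayer`, `PropertyLayer` and `sample_W` are hypothetical stand-ins introduced only for illustration, and the sketch assumes the real distribution objects are not themselves callable.

```python
import random


class ToyNormal:
    """Hypothetical stand-in for a distribution object exposing rsample()."""

    def __init__(self, mean, scale):
        self.mean = mean
        self.scale = scale

    def rsample(self):
        return random.gauss(self.mean, self.scale)


class MethodLayer:
    """Mimics the standard regression layer: W is a plain method."""

    def W(self):
        return ToyNormal(0.0, 1.0)


class PropertyLayer:
    """Mimics the t-VBLL regression layer: W is a property."""

    @property
    def W(self):
        return ToyNormal(0.0, 1.0)


def sample_W(layer):
    """Draw one sample from W whether it is a method or a property.

    Assumes the distribution object itself is not callable.
    """
    w = layer.W
    dist = w() if callable(w) else w
    return dist.rsample()


print(sample_W(MethodLayer()))    # same as layer.W().rsample()
print(sample_W(PropertyLayer()))  # same as layer.W.rsample()
```

With the method variant, `layer.W.rsample()` raises an `AttributeError` because `layer.W` is a bound method rather than a distribution, which is how the inconsistency shows up in user code that handles both layer types.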

I personally prefer W as a property. I'm happy to open a PR for this, but wanted to double-check with you first.


EDIT: I just checked; the same inconsistency holds for the classification case.
