Skip to content

Commit

Permalink
Fix DA cost correction when cost limit is set to Inf (#593)
Browse files Browse the repository at this point in the history
* Introduce the test that actually fails with cost = nan

* Update cost correction algorithm

* Better explanation in the code for how missing_labels and label_match interact

* All close to check semi-supervised mode

* Update RELEASE file

* Suppress runtime warning from numpy about using Inf in the multiplication
  • Loading branch information
kachayev authored Jan 12, 2024
1 parent 98a58d2 commit 7950b11
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 11 deletions.
6 changes: 6 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Releases

## 0.9.3

#### Closed issues
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)


## 0.9.2
*December 2023*

Expand Down
28 changes: 22 additions & 6 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# License: MIT License

import numpy as np
import warnings

from .backend import get_backend
from .bregman import sinkhorn, jcpot_barycenter
Expand Down Expand Up @@ -499,12 +500,27 @@ class label
if self.limit_max != np.infty:
self.limit_max = self.limit_max * nx.max(self.cost_)

# zeros where source label is missing (masked with -1)
missing_labels = ys + nx.ones(ys.shape, type_as=ys)
missing_labels = nx.repeat(missing_labels[:, None], ys.shape[0], 1)
# zeros where labels match
label_match = ys[:, None] - yt[None, :]
self.cost_ = nx.maximum(self.cost_, nx.abs(label_match) * nx.abs(missing_labels) * self.limit_max)
# missing_labels is a (ns, nt) matrix of {0, 1} such that
# the cells (i, j) has 0 iff either ys[i] or yt[j] is masked
missing_ys = (ys == -1) + nx.zeros(ys.shape, type_as=ys)
missing_yt = (yt == -1) + nx.zeros(yt.shape, type_as=yt)
missing_labels = missing_ys[:, None] @ missing_yt[None, :]
# labels_match is a (ns, nt) matrix of {True, False} such that
# the cells (i, j) has False if ys[i] != yt[i]
label_match = (ys[:, None] - yt[None, :]) != 0
# cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
# that he cells (i, j) has -Inf where there's no correction necessary
# by 'correction' we mean setting cost to a large value when
# labels do not match
# we suppress potential RuntimeWarning caused by Inf multiplication
# (as we explicitly cover potential NANs later)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=RuntimeWarning)
cost_correction = label_match * missing_labels * self.limit_max
# this operation is necessary because 0 * Inf = NAN
# thus is irrelevant when limit_max is finite
cost_correction = nx.nan_to_num(cost_correction, -np.infty)
self.cost_ = nx.maximum(self.cost_, cost_correction)

# distribution estimation
self.mu_s = self.distribution_estimation(Xs)
Expand Down
28 changes: 23 additions & 5 deletions test/test_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def test_sinkhorn_lpl1_transport_class(nx):
# test its computed
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(otda, "cost_")
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"
assert hasattr(otda, "coupling_")
assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite"

# test dimensions of coupling
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
Expand Down Expand Up @@ -148,7 +150,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
n_semisup = nx.sum(otda_semi.cost_)

# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"

# check that the coupling forbids mass transport between labeled source
# and labeled target samples
Expand Down Expand Up @@ -238,7 +240,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
n_semisup = nx.sum(otda_semi.cost_)

# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"

# check that the coupling forbids mass transport between labeled source
# and labeled target samples
Expand Down Expand Up @@ -331,7 +333,7 @@ def test_sinkhorn_transport_class(nx):
n_semisup = nx.sum(otda_semi.cost_)

# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"

# check that the coupling forbids mass transport between labeled source
# and labeled target samples
Expand Down Expand Up @@ -371,6 +373,10 @@ def test_unbalanced_sinkhorn_transport_class(nx):
# test dimensions of coupling
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"

# test coupling
assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite"

# test transform
transp_Xs = otda.transform(Xs=Xs)
Expand Down Expand Up @@ -409,19 +415,22 @@ def test_unbalanced_sinkhorn_transport_class(nx):
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornTransport()
otda_unsup.fit(Xs=Xs, Xt=Xt)
assert not np.any(np.isnan(nx.to_numpy(otda_unsup.cost_))), "cost is finite"
n_unsup = nx.sum(otda_unsup.cost_)

otda_semi = ot.da.SinkhornTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert not np.any(np.isnan(nx.to_numpy(otda_semi.cost_))), "cost is finite"
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = nx.sum(otda_semi.cost_)

# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"

# check everything runs well with log=True
otda = ot.da.SinkhornTransport(log=True)
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"
assert len(otda.log_.keys()) != 0


Expand All @@ -448,7 +457,9 @@ def test_emd_transport_class(nx):

# test dimensions of coupling
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite"

# test margin constraints
mu_s = unif(ns)
Expand Down Expand Up @@ -495,15 +506,22 @@ def test_emd_transport_class(nx):
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.EMDTransport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(otda_unsup.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
assert not np.any(np.isnan(nx.to_numpy(otda_unsup.cost_))), "cost is finite"
assert_equal(otda_unsup.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "coupling is finite"
n_unsup = nx.sum(otda_unsup.cost_)

otda_semi = ot.da.EMDTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
assert not np.any(np.isnan(nx.to_numpy(otda_semi.cost_))), "cost is finite"
assert_equal(otda_semi.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "coupling is finite"
n_semisup = nx.sum(otda_semi.cost_)

# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"

# check that the coupling forbids mass transport between labeled source
# and labeled target samples
Expand Down

0 comments on commit 7950b11

Please sign in to comment.