Skip to content

Commit

Permalink
percentile_of_score
Browse files Browse the repository at this point in the history
Summary: Implements `percentile_of_score` in BoTorch utils, akin to `scipy.stats.percentileofscore`

Reviewed By: SebastianAment

Differential Revision: D64137730
  • Loading branch information
Jihao Andreas Lin authored and facebook-github-bot committed Oct 9, 2024
1 parent df93789 commit f284972
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
30 changes: 30 additions & 0 deletions botorch/utils/probability/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,33 @@ def compute_log_prob_feas_from_bounds(
dist_u = (con_upper - means[..., i]) / sigmas[..., i]
log_prob = log_prob + log_prob_normal_in(a=dist_l, b=dist_u).sum(dim=-1)
return log_prob


def percentile_of_score(data: Tensor, score: Tensor, dim: int = -1) -> Tensor:
    """Compute the percentile rank of `score` relative to `data`.

    For example, a return value of 70 means that roughly 70% of the values in
    `data` lie below `score`. The computation mirrors
    `scipy.stats.percentileofscore` with its default settings `kind='rank'`
    and `nan_policy='propagate'`: with `left` the number of entries strictly
    below `score` and `right` the number of entries less than or equal to
    `score`, the rank is `(left + right + (left < right)) * 50 / n`, where `n`
    is the size of `data` along `dim`.

    Args:
        data: A `... x n x output_shape`-dim Tensor of data.
        score: A `... x 1 x output_shape`-dim Tensor of scores.
        dim: The dimension of `data` over which the counts are taken; the same
            dimension is squeezed out of `score` when building the NaN mask.

    Returns:
        A `... x output_shape`-dim Tensor of percentile ranks. An entry is NaN
        if the corresponding score is NaN or if any of the data values it is
        compared against along `dim` is NaN.
    """
    num_samples = data.shape[dim]
    # Counts of data entries strictly below / at or below the score.
    num_below = torch.count_nonzero(data < score, dim=dim)
    num_at_or_below = torch.count_nonzero(data <= score, dim=dim)
    # True wherever at least one data entry equals the score; contributes the
    # "+1" of the `kind='rank'` formula.
    has_tie = num_below < num_at_or_below
    ranks = (num_below + num_at_or_below + has_tie) * (50.0 / num_samples)
    # ranks shape: `... x output_shape`

    # Propagate NaNs: comparisons with NaN are always False, so the counts
    # above silently treat NaNs as "not below"; overwrite those entries.
    score_is_nan = torch.isnan(score.squeeze(dim))
    data_has_nan = torch.any(torch.isnan(data), dim=dim)
    invalid = torch.broadcast_to(score_is_nan | data_has_nan, ranks.shape)
    ranks[invalid] = torch.nan
    return ranks
70 changes: 70 additions & 0 deletions test/utils/probability/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

from __future__ import annotations

import itertools

import numpy as np

import torch
from botorch.utils.probability import ndtr, utils
from botorch.utils.probability.utils import (
Expand All @@ -14,11 +18,13 @@
log_ndtr,
log_phi,
log_prob_normal_in,
percentile_of_score,
phi,
standard_normal_log_hazard,
)
from botorch.utils.testing import BotorchTestCase
from numpy.polynomial.legendre import leggauss as numpy_leggauss
from scipy.stats import percentileofscore as percentile_of_score_scipy


class TestProbabilityUtils(BotorchTestCase):
Expand Down Expand Up @@ -321,3 +327,67 @@ def test_gaussian_probabilities(self) -> None:

with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
log_ndtr(torch.tensor(1.0, dtype=torch.float16, device=self.device))

def test_percentile_of_score(self) -> None:
    """Check `percentile_of_score` against `scipy.stats.percentileofscore`.

    The scipy reference is called with its default settings (`kind='rank'`,
    `nan_policy='propagate'`) over a grid of dtypes, batch shapes and output
    shapes, including batch-shape pairs that cannot be broadcast together.
    """
    torch.manual_seed(12345)
    n = 10
    dtypes = (torch.float, torch.double)
    shapes = ((), (1,), (2,), (2, 3))
    for dtype, data_batch_shape, score_batch_shape, output_shape in (
        itertools.product(dtypes, shapes, shapes, shapes)
    ):
        tkwargs = {"dtype": dtype, "device": self.device}
        # the reduced dimension sits directly in front of the output shape
        dim = -1 - len(output_shape)
        data = torch.rand(*data_batch_shape, n, *output_shape, **tkwargs)
        score = torch.rand(*score_batch_shape, 1, *output_shape, **tkwargs)
        # sprinkle in some nans so that nan propagation is exercised as well
        data[data < 0.01] = torch.nan
        score[score < 0.01] = torch.nan

        # torch result; a RuntimeError is only acceptable when the batch
        # shapes are genuinely incompatible for broadcasting
        try:
            perct_torch = percentile_of_score(data, score, dim=dim).cpu().numpy()
        except RuntimeError:
            with self.assertRaises(ValueError):
                np.broadcast_shapes(data_batch_shape, score_batch_shape)
            continue

        broadcast_shape = np.broadcast_shapes(data_batch_shape, score_batch_shape)
        self.assertEqual(perct_torch.shape, broadcast_shape + output_shape)

        # scipy.stats.percentileofscore does not broadcast, so expand both
        # inputs explicitly and evaluate one 1-d slice at a time
        data_scipy = np.broadcast_to(
            data.cpu().numpy(), broadcast_shape + (n,) + output_shape
        )
        score_scipy = np.broadcast_to(
            score.cpu().numpy(), broadcast_shape + (1,) + output_shape
        )
        perct_scipy = np.zeros_like(perct_torch)
        batch_indices = itertools.product(*map(range, broadcast_shape))
        output_indices = itertools.product(*map(range, output_shape))
        for batch_idx, out_idx in itertools.product(batch_indices, output_indices):
            perct_scipy[batch_idx + out_idx] = percentile_of_score_scipy(
                data_scipy[batch_idx + (slice(None),) + out_idx],
                score_scipy[batch_idx + (0,) + out_idx],
            )
        self.assertTrue(np.array_equal(perct_torch, perct_scipy, equal_nan=True))

0 comments on commit f284972

Please sign in to comment.