Skip to content

Commit

Permalink
Merge pull request #898 from JuliaAI/loss-functions-0.9
Browse files Browse the repository at this point in the history
Bump compat LossFunctions = "0.9" and address breakage
  • Loading branch information
ablaom authored Apr 20, 2023
2 parents 4cb9ba7 + 9603c64 commit 8f3a29d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ CategoricalDistributions = "0.1"
ComputationalResources = "0.3"
Distributions = "0.25.3"
InvertedIndices = "1"
LossFunctions = "0.5, 0.6, 0.7, 0.8"
LossFunctions = "0.9"
MLJModelInterface = "1.7"
Missings = "0.4, 1"
OrderedCollections = "1.1"
Expand Down
6 changes: 3 additions & 3 deletions src/measures/loss_functions_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ err_wrap(n) = ArgumentError("Bad @wrap syntax: $n. ")

# We define amacro to wrap a concrete `LossFunctions.SupervisedLoss`
# type and define its constructor, and to define property access in
# case of paramters; the macro also defined calling behaviour:
# case of parameters; the macro also defines calling behaviour:
macro wrap_loss(ex)
ex.head == :call || throw(err_wrap(1))
Loss_ex = ex.args[1]
Expand Down Expand Up @@ -130,7 +130,7 @@ MMI.prediction_type(::Type{<:DistanceLoss}) = :deterministic
MMI.target_scitype(::Type{<:DistanceLoss}) = Union{Vec{Continuous},Vec{Count}}

call(measure::DistanceLoss, yhat, y) =
LossFunctions.value(getfield(measure, :loss), y, yhat)
LossFunctions.value(getfield(measure, :loss), yhat, y)

function call(measure::DistanceLoss, yhat, y, w::AbstractArray)
return w .* call(measure, yhat, y)
Expand All @@ -148,7 +148,7 @@ _scale(p) = 2p - 1
function call(measure::MarginLoss, yhat, y)
probs_of_observed = broadcast(pdf, yhat, y)
return (LossFunctions.value).(getfield(measure, :loss),
1, _scale.(probs_of_observed))
_scale.(probs_of_observed), 1)
end

call(measure::MarginLoss, yhat, y, w::AbstractArray) =
Expand Down
17 changes: 6 additions & 11 deletions test/measures/loss_functions_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,9 @@ end

for M_ex in MARGIN_LOSSES
m = eval(:(MLJBase.$M_ex()))
@test m(yhat, y) LossFunctions.value(getfield(m, :loss), ym, yhatm)
@test MLJBase.Mean()(m(yhat, y, w))
LossFunctions.value(getfield(m, :loss),
ym,
yhatm,
WeightedSum(w))/N
@test m(yhat, y) LossFunctions.value(getfield(m, :loss), yhatm, ym)
@test m(yhat, y, w)
w .* LossFunctions.value(getfield(m, :loss), yhatm, ym)
end
end

Expand All @@ -64,10 +61,8 @@ end
m_ex = MLJBase.snakecase(M_ex)
@test m == eval(:(MLJBase.$m_ex))
@test m(yhat, y)
LossFunctions.value(getfield(m, :loss), y, yhat)
@test mean(m(yhat ,y, w))
LossFunctions.value(getfield(m, :loss), y, yhat,
WeightedSum(w))/N

LossFunctions.value(getfield(m, :loss), yhat, y)
@test m(yhat ,y, w)
w .* LossFunctions.value(getfield(m, :loss), yhat, y)
end
end

0 comments on commit 8f3a29d

Please sign in to comment.