Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Even more generic broadcasted #107

Merged
merged 2 commits into from
Mar 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Yota"
uuid = "cd998857-8626-517d-b929-70ad188a48f0"
authors = ["Andrei Zhabinski <[email protected]>"]
version = "0.7.0"
version = "0.7.1"

[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Expand All @@ -20,5 +20,5 @@ ChainRules = "1"
ChainRulesCore = "1"
FiniteDifferences = "0.12"
NNlib = "0.8"
Umlaut = "0.2.2"
Umlaut = "0.2.3"
julia = "1.6"
3 changes: 3 additions & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ using ChainRules
using NNlib
using Umlaut
import Umlaut: record_primitive!, isprimitive, BaseCtx
import Umlaut: record_or_recurse!, Tracer, trace!, getcode


const V = Umlaut.Variable
const broadcasted = Broadcast.broadcasted


include("helpers.jl")
Expand Down
58 changes: 54 additions & 4 deletions src/cr_api.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ChainRulesCore: rrule, no_rrule
import ChainRulesCore: rrule_via_ad, RuleConfig, NoForwardsMode, HasReverseMode
import Umlaut: make_name, Input, to_expr
import Umlaut: make_name, Input, to_expr, BcastCtx


###############################################################################
Expand Down Expand Up @@ -34,9 +34,36 @@ Extends RuleConfig{Union{NoForwardsMode,HasReverseMode}}.
struct YotaRuleConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} end


###############################################################################
# rrule_via_ad #
###############################################################################

"""
    bcast_rrule(::YotaRuleConfig, ::typeof(broadcasted), f, args...; kw...)

Similar to `rrule(config, broadcasted, f, args...)`, but works only for
ChainRules-primitive functions, i.e. functions `f` for which `rrule(config, f, x...)`
is defined for the elements of `args`. For a more flexible handling of broadcasting
use `rrule(...)` directly.

Return the broadcasted values and a pullback. The pullback returns a tuple with one
tangent for `broadcasted` (always `NoTangent()`), one for `f` and one per element
of `args`.
"""
function bcast_rrule(::YotaRuleConfig, ::typeof(broadcasted), f::F, args...; kw...) where F
    # Apply the scalar rule elementwise: one (value, pullback) pair per broadcast element.
    ys, pbs = unzip(rrule.(YOTA_RULE_CONFIG, f, args...; kw...))
    function pullback(Δ)
        if Δ isa NoTangent || Δ isa ZeroTangent
            # Propagate the zero tangent to `f` and to every argument. The count must
            # follow the number of arguments, not the number of broadcast elements
            # (`length(pbs)`), so that it matches the arity of the non-zero path below.
            return (NoTangent(), ntuple(_ -> Δ, length(args) + 1)...)
        end
        Δ = unthunk(Δ)
        # Run each element's pullback on the matching element of Δ, then regroup
        # the per-element tangent tuples into one collection per input.
        dxs = map((pb, Δ) -> pb(Δ), pbs, Δ) |> unzip
        # Collapse collections consisting solely of NoTangent() into a single NoTangent().
        dxs = [all(dx .== NoTangent()) ? NoTangent() : dx for dx in dxs]
        return (NoTangent(), dxs...,)
    end
    return ys, pullback
end


function to_rrule_expr(tape::Tape)
# TODO (maybe): add YotaRuleConfig() as the first argument for consistency
fn_name = gensym("rrule_$(tape[V(1)].val)")
header = Expr(:call, fn_name)
push!(header.args, Expr(:(::), :config, YotaRuleConfig))
for v in inputs(tape)
op = tape[v]
push!(header.args, Expr(:(::), make_name(op), op.typ))
Expand Down Expand Up @@ -95,7 +122,30 @@ Examples:

"""
make_rrule(tape::Tape) = Base.eval(@__MODULE__, to_rrule_expr(tape))
make_rrule(f, args...) = make_rrule(gradtape(f, args...; seed=:auto))

# Build an rrule for `f` by recording its gradient tape (traced under a fresh
# `GradCtx`, with `seed=:auto`) and compiling that tape via `make_rrule(::Tape)`.
function make_rrule(f, args...)
    tape = gradtape(f, args...; seed=:auto, ctx=GradCtx())
    return make_rrule(tape)
end

"""
    make_rrule(::typeof(broadcasted), f, args...)

Create an rrule-like function for the broadcasted call `broadcasted(f, args...)`.

If `f` applied to the first elements of `args` is a primitive of `GradCtx`, the generic
`bcast_rrule` is returned as is. Otherwise `f` is traced under `BcastGradCtx` using the
element types of `args`, the gradient is recorded onto the resulting tape, the tape is
patched to accept `broadcasted` as an extra leading input, and the tape is compiled
with `make_rrule(::Tape)`.
"""
function make_rrule(::typeof(broadcasted), f, args...)
# primitiveness is checked on the first element of each argument
if isprimitive(GradCtx(), f, map(first, args)...)
return bcast_rrule # (YOTA_RULE_CONFIG, broadcasted, f, args...)
end
ctx = BcastGradCtx(GradCtx())
# trace `f` with function-argument types built from the element types of `args`
_, tape = trace(f, args...; ctx=ctx, fargtypes=(f, map(eltype, args)))
# re-wrap the traced tape with the inner context before recording gradients
tape = Tape(tape; ctx=ctx.inner)
gradtape!(tape, seed=:auto)
# insert imaginary broadcasted to the list of inputs
insert!(tape, 1, Umlaut.Input(broadcasted))
# insert ZeroTangent to the result to account for the additional argument
# NOTE(review): assumes the gradient tuple is recorded exactly two operations
# before the tape result; the assertion below guards this assumption
grad_tuple_op = tape[V(tape.result.id - 2)]
@assert grad_tuple_op isa Call && grad_tuple_op.fn == tuple
grad_tuple_op.args = [ZeroTangent(), grad_tuple_op.args...]
# re-execute the patched tuple operation and the two operations following it
# (presumably so that their recorded values reflect the added ZeroTangent)
for id=grad_tuple_op.id:grad_tuple_op.id + 2
Umlaut.exec!(tape, tape[V(id)])
end
return make_rrule(tape)
end


const GENERATED_RRULE_CACHE = Dict()
Expand All @@ -113,13 +163,13 @@ function ChainRulesCore.rrule_via_ad(::YotaRuleConfig, f, args...)
if haskey(GENERATED_RRULE_CACHE, sig)
rr = GENERATED_RRULE_CACHE[sig]
# return Base.invokelatest(rr, f, args...)
val, pb = Base.invokelatest(rr, f, args...)
val, pb = Base.invokelatest(rr, YOTA_RULE_CONFIG, f, args...)
return val, dy -> Base.invokelatest(pb, dy)
else
rr = make_rrule(f, args...)
GENERATED_RRULE_CACHE[sig] = rr
# return Base.invokelatest(rr, f, args...)
val, pb = Base.invokelatest(rr, f, args...)
val, pb = Base.invokelatest(rr, YOTA_RULE_CONFIG, f, args...)
return val, dy -> Base.invokelatest(pb, dy)
end
end
43 changes: 37 additions & 6 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,40 @@ function record_primitive!(tape::Tape{GradCtx}, v_fargs...)
end


#################################################################################
# GRAD #
#################################################################################
###############################################################################
# BCAST GRAD CONTEXT #
###############################################################################

"""
    BcastGradCtx(inner)

Tracing context that wraps another context stored in the `inner` field.
`record_or_recurse!(::Tracer{BcastGradCtx}, ...)` consults `inner` to decide whether
a call is a primitive and stores pullback variables in `inner.pullbacks`.
"""
struct BcastGradCtx
# the wrapped context; `make_rrule` constructs it as `BcastGradCtx(GradCtx())`
inner
end


"""
    record_or_recurse!(t::Tracer{BcastGradCtx}, v_fargs...)

Handle the call described by `v_fargs` (function followed by its arguments, each either
a tape variable or a constant) while tracing under `BcastGradCtx`.

If the call is a primitive of the inner context, record a call to `bcast_rrule` followed
by two `_getfield` calls extracting its 1st and 2nd components (value and pullback),
register the pullback variable in the inner context's `pullbacks` under the value
variable, and return the value variable. Otherwise look up the code of the function for
the element types of its arguments and continue tracing into it.
"""
function record_or_recurse!(t::Tracer{BcastGradCtx}, v_fargs...)
    # Replace tape variables with the values recorded for them; pass constants through.
    fargs = [v isa V ? t.tape[v].val : v for v in v_fargs]
    inner = t.tape.c.inner
    if !isprimitive(inner, fargs...)
        # Non-primitive: trace into the code selected by the arguments' element types.
        types = map(eltype, fargs[2:end])
        return trace!(t, getcode(fargs[1], types), v_fargs...)
    end
    v_f, v_args... = v_fargs
    rr_op = if is_kwfunc(fargs[1])
        # Keyword call: `v_args[1]` is passed in the keyword-arguments position and
        # `v_args[2:end]` holds the actual function and its positional arguments.
        mkcall(
            Core.kwfunc(bcast_rrule), v_args[1],
            bcast_rrule, YOTA_RULE_CONFIG, broadcasted, v_args[2:end]...,
        )
    else
        mkcall(bcast_rrule, YOTA_RULE_CONFIG, broadcasted, v_f, v_args...)
    end
    v_rr = push!(t.tape, rr_op)
    v_val = push!(t.tape, mkcall(_getfield, v_rr, 1))
    v_pb = push!(t.tape, mkcall(_getfield, v_rr, 2))
    # Remember which pullback belongs to the recorded value.
    inner.pullbacks[v_val] = v_pb
    return v_val
end


###############################################################################
# GRAD #
###############################################################################

# Look up the derivative variable recorded for `v` in the tape context's `derivs`,
# keyed by `bound(tape, v)`; return `nothing` when no derivative is recorded.
function getderiv(tape::Tape, v::Variable)
    return get(tape.c.derivs, bound(tape, v), nothing)
end
setderiv!(tape::Tape, x::Variable, dx::Variable) = (
Expand Down Expand Up @@ -197,14 +228,14 @@ end


"""
gradtape(f::Union{Function, DataType}, args...; seed=1)
gradtape(f::Union{Function, DataType}, args...; ctx=GradCtx(), seed=1)
gradtape!(tape::Tape; seed=1)

Calculate and record to the tape gradients of `tape[tape.resultid]` w.r.t. `Input` nodes.
See grad() for more high-level API.
"""
function gradtape(f::Union{Function,DataType}, args...; seed=1)
_, tape = trace(f, args...; ctx=GradCtx())
function gradtape(f::Union{Function,DataType}, args...; ctx=GradCtx(), seed=1)
    # Trace `f(args...)` under the given context; only the tape is needed here.
    _, traced = trace(f, args...; ctx=ctx)
    # Record the gradient operations and hand back what `gradtape!` returns.
    return gradtape!(traced; seed=seed)
end
Expand Down
9 changes: 1 addition & 8 deletions src/rulesets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,7 @@ end


function rrule(::YotaRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args...) where F
ys, pbs = unzip(rrule_via_ad.(YOTA_RULE_CONFIG, f, args...))
function pullback(Δ)
Δ = unthunk(Δ)
dxs = map((pb, Δ) -> pb(Δ), pbs, Δ) |> unzip
dxs = [all(dx .== NoTangent()) ? NoTangent() : dx for dx in dxs]
return NoTangent(), dxs...
end
return ys, pullback
return rrule_via_ad(YOTA_RULE_CONFIG, broadcasted, f, args...)
end


Expand Down
20 changes: 18 additions & 2 deletions test/test_cr_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ import ChainRulesTestUtils.test_rrule
import ChainRulesCore: rrule, unthunk
import Yota: isprimitive, CR_CTX

const broadcasted = Broadcast.broadcasted


# Test helper: double the input and add one (scalar method).
function double_inc(x::Number)
    return 2 * x + 1
end

# Test helper: double the array and add one to every element.
function double_inc(x::AbstractArray)
    return 2 * x .+ 1
end

# Test helper: double the input and subtract one; defined for scalars only,
# the tests apply it to arrays through `broadcasted`.
function double_dec(x::Number)
    return 2 * x - 1
end

# Test helper with a keyword-argument method: add `y` (default 1) to `x`.
function primitive_test(x; y=1)
    return x + y
end

# Test helper with a positional two-argument method: add `y` to `x`.
function primitive_test(x, y)
    return x + y
end
Expand All @@ -19,13 +22,18 @@ rrule(::YotaRuleConfig, ::typeof(primitive_test2), x; y=1) = primitive_test2(x;


@testset "chainrules api" begin
config = YotaRuleConfig()

rr = make_rrule(double_inc, 2.0)
val, pb = rr(double_inc, 3.0)
val, pb = rr(config, double_inc, 3.0)
@test val == 7
@test pb(1.0) == (ZeroTangent(), 2.0)

config = YotaRuleConfig()
rr = make_rrule(broadcasted, double_dec, [1.0, 2.0])
val, pb = rr(config, broadcasted, double_dec, [3.0, 4.0])
@test val == [5.0, 7.0]
@test pb([1, 1]) == (ZeroTangent(), ZeroTangent(), [2.0, 2.0])

val, pb = rrule_via_ad(config, double_inc, 3.0)
@test val == 7
@test pb(1.0) == (ZeroTangent(), 2.0)
Expand All @@ -38,6 +46,14 @@ rrule(::YotaRuleConfig, ::typeof(primitive_test2), x; y=1) = primitive_test2(x;
dxs = map(unthunk, pb([1, 2, 3]))
@test dxs == (ZeroTangent(), [2.0, 4.0, 6.0])

x = rand(3)
val, pb = rrule_via_ad(config, broadcasted, double_dec, x)
@test val == broadcast(double_dec, x)
dxs = map(unthunk, pb(ones(3)))
@test dxs == (ZeroTangent(), ZeroTangent(), [2.0, 2.0, 2.0])
dxs = map(unthunk, pb([1, 2, 3]))
@test dxs == (ZeroTangent(), ZeroTangent(), [2.0, 4.0, 6.0])

x, y = rand(2)
@test isprimitive(CR_CTX, primitive_test, x) == true
@test isprimitive(CR_CTX, Core.kwfunc(primitive_test), (y=1,), primitive_test, x) == true
Expand Down
25 changes: 16 additions & 9 deletions test/test_rulesets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,26 @@ import Yota.YotaRuleConfig
end


# broadcastable non-primitive
function sin_inc(x::Number)
    return sin(x) + 1
end


@testset "generic broadcasted" begin
# see the discussion here:
# https://github.com/JuliaDiff/ChainRules.jl/issues/531
f = sin
xs = rand(2)
for f in [sin, sin_inc]
xs = rand(2)

# manually get pullbacks for each element and apply them to seed 1.0
pbs = [rrule(f, x)[2] for x in xs]
dxs = [pbs[1](1.0)[2], pbs[2](1.0)[2]]
# manually get pullbacks for each element and apply them to seed 1.0
pbs = !isnothing(rrule(f, xs[1])) ?
[rrule(f, x)[2] for x in xs] : # just in case
[rrule_via_ad(YotaRuleConfig(), f, x)[2] for x in xs]
dxs = [pbs[1](1.0)[2], pbs[2](1.0)[2]]

# use rrule for broadcasted
_, bcast_pb = rrule(YotaRuleConfig(), Broadcast.broadcasted, f, xs)
dxs_bcast = bcast_pb(ones(2))[end]
# use rrule for broadcasted
_, bcast_pb = rrule(YotaRuleConfig(), Broadcast.broadcasted, f, xs)
dxs_bcast = bcast_pb(ones(2))[end]

@test all(dxs .== dxs_bcast)
@test all(dxs .== dxs_bcast)
end
end