Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from Ghost to Umlaut #104

Merged
merged 6 commits into from
Feb 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,15 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Ghost = "4f8f7498-1303-42e1-920c-5033445536df"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Umlaut = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841"

[compat]
ChainRules = "1"
ChainRulesCore = "1"
FiniteDifferences = "0.12"
Ghost = "^0.2.1"
NNlib = "^0.7.27"
OrderedCollections = "1.4"
NNlib = "0.7.27"
julia = "1.6"
130 changes: 19 additions & 111 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,65 +2,30 @@ import ChainRulesCore: rrule, no_rrule
import ChainRulesCore: rrule_via_ad, RuleConfig, NoForwardsMode, HasReverseMode
import Ghost: make_name, Input, to_expr


###############################################################################
# Primitives #
###############################################################################

"""
Collect list of function signatures for which rrule() or no_rrule() is defined
"""
function rrule_covered_signatures(fn=rrule)
rrule_methods = methods(fn).ms
rrule_sigs = [rr.sig for rr in rrule_methods]
primal_sigs = []
for rr_sig in rrule_sigs
# remove `rrule` parameter
sig = remove_first_parameter(rr_sig)
Ts = collect(Ghost.get_type_parameters(sig))
# skip rules with config with features that we don't support
if Ts[1] <: RuleConfig && !(Ts[1] <: RuleConfig{>:HasReverseMode})
continue
end
# remove RuleConfig parameter
if Ts[1] <: RuleConfig{>:HasReverseMode}
sig = remove_first_parameter(sig)
end
# now sig looks like the signature of the primal function
push!(primal_sigs, sig)
end
# add keyword version of these functions as well
kw_sigs = [kwsig for kwsig in map(kwfunc_signature, primal_sigs) if kwsig !== Tuple{}]
return [primal_sigs; kw_sigs]
end


struct ChainRulesCtx end

const CHAINRULES_PRIMITIVES = Ref(FunctionResolver{Bool}())
const NUM_CHAINRULES_METHODS = Ref{Int}(0)


function update_chainrules_primitives!(;force=false)
num_methods = length(methods(rrule)) + length(methods(no_rrule))
if force || num_methods != NUM_CHAINRULES_METHODS[]
sigs_flags = [
[sig => true for sig in rrule_covered_signatures(rrule)];
[sig => false for sig in rrule_covered_signatures(no_rrule)] # override rrule(sig...)
]
P = FunctionResolver{Bool}(sigs_flags)
CHAINRULES_PRIMITIVES[] = P
NUM_CHAINRULES_METHODS[] = num_methods
function isprimitive(::ChainRulesCtx, f, args...)
Ts = [a isa DataType ? Type{a} : typeof(a) for a in (f, args...)]
Core.Compiler.return_type(rrule, (YotaRuleConfig, Ts...,)) !== Nothing && return true
if is_kwfunc(Ts[1])
Ts_kwrrule = (Any, typeof(Core.kwfunc(f)), YotaRuleConfig, Ts[2:end]...,)
Core.Compiler.return_type(Core.kwfunc(rrule), Ts_kwrrule) !== Nothing && return true
end
return false
end


is_chainrules_primitive(sig) = CHAINRULES_PRIMITIVES[][sig] == true


###############################################################################
# RuleConfig #
###############################################################################

struct YotaRuleConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} end
struct YotaRuleConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} end


function to_rrule_expr(tape::Tape)
Expand Down Expand Up @@ -114,14 +79,14 @@ Generate a function equivalent to (but not extending) ChainRulesCore.rrule(),
i.e. returning the primal value and the pullback.


### Examples:
Examples:
=========

foo(x) = 2x + 1
rr = make_rrule(foo, 2.0)
val, pb = rr(foo, 3.0)
pb(1.0)

```
foo(x) = 2x + 1
rr = make_rrule(foo, 2.0)
val, pb = rr(foo, 3.0)
pb(1.0)
```
"""
make_rrule(tape::Tape) = Base.eval(@__MODULE__, to_rrule_expr(tape))
make_rrule(f, args...) = make_rrule(gradtape(f, args...))
Expand All @@ -132,7 +97,7 @@ const GENERATED_RRULE_CACHE = Dict()
function ChainRulesCore.rrule_via_ad(::YotaRuleConfig, f, args...)
res = rrule(f, args...)
!isnothing(res) && return res
sig = call_signature(f, args...)
sig = map(typeof, (f, args...))
if haskey(GENERATED_RRULE_CACHE, sig)
rr = GENERATED_RRULE_CACHE[sig]
# return Base.invokelatest(rr, f, args...)
Expand All @@ -145,61 +110,4 @@ function ChainRulesCore.rrule_via_ad(::YotaRuleConfig, f, args...)
val, pb = Base.invokelatest(rr, f, args...)
return val, dy -> Base.invokelatest(pb, dy)
end
end

###############################################################################
# Rules #
###############################################################################

function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(identity), x)
identity_pullback(dy) = (NoTangent(), NoTangent(), dy)
return x, identity_pullback
end

# test_rrule(broadcasted, identity, [1.0, 2.0]) -- fails at the moment


function ChainRulesCore.rrule(
::YotaRuleConfig, ::typeof(Core._apply_iterate), ::typeof(iterate),
f::F, args...
) where F
# flatten nested arguments
flat = []
for a in args
push!(flat, a...)
end
# apply rrule of the function on the flat arguments
y, pb = rrule_via_ad(YotaRuleConfig(), f, flat...)
sizes = map(length, args)
function _apply_iterate_pullback(dy)
if dy isa NoTangent
return ntuple(_-> NoTangent(), length(args) + 3)
end
flat_dargs = pb(dy)
df = flat_dargs[1]
# group derivatives to tuples of the same sizes as arguments
dargs = []
j = 2
for i = 1:length(args)
darg_val = flat_dargs[j:j + sizes[i] - 1]
if length(darg_val) == 1 && darg_val[1] isa NoTangent
push!(dargs, darg_val[1])
else
darg = Tangent{typeof(darg_val)}(darg_val...)
push!(dargs, darg)
end
j = j + sizes[i]
end
return NoTangent(), NoTangent(), df, dargs...
end
return y, _apply_iterate_pullback
end


function ChainRulesCore.rrule(::typeof(tuple), args...)
y = tuple(args...)
return y, dy -> (NoTangent(), collect(dy)...)
end

# test_rrule(tuple, 1, 2, 3; output_tangent=Tangent{Tuple}((1, 2, 3)), check_inferred=false)

end
25 changes: 14 additions & 11 deletions src/core.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
import Statistics
using LinearAlgebra
using OrderedCollections
using ChainRulesCore
using ChainRules
using NNlib
using Ghost
using Ghost: Tape, Variable, V, Call, mkcall, Constant, inputs
using Ghost: bound, compile, play!, isstruct
using Ghost: remove_first_parameter, kwfunc_signature, call_signature
using Umlaut
# TODO: make these objects exported by default?
import Umlaut: Tape, Variable, V, Call, mkcall, Constant
import Umlaut: record_primitive!, isprimitive, BaseCtx
import Umlaut: play!, compile
import Ghost
import Ghost: bound, inputs


include("helpers.jl")
include("drules.jl")
include("utils.jl")
include("deprecated.jl")
# include("drules.jl")
include("chainrules.jl")
include("rulesets.jl")
include("grad.jl")
include("update.jl")
include("gradcheck.jl")


function __init__()
update_chainrules_primitives!()
end


# TODO: move to Ghost
function show_compact(tape::Tape)
println(typeof(tape))
Expand Down Expand Up @@ -59,3 +59,6 @@ function show_compact(tape::Tape)
end
end
end


Base.show(io::IO, tape::Tape{GradCtx}) = show_compact(tape)
3 changes: 3 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
function update_chainrules_primitives!(;force=false)
@info "update_chainrules_primitives!() is deprecated, you can safely remove this call"
end
Loading