Reverse-mode AD for SDEs #1084

Open
elgazzarr opened this issue Jul 22, 2024 · 1 comment

@elgazzarr

Reverse-mode adjoints for SDEs only work with 'TrackerAdjoint()', and only on the CPU. 🐞

Training large (e.g., neural) SDEs on GPUs fails. The only sensitivity algorithm that works is 'TrackerAdjoint()', and currently only on the CPU.
None of the continuous adjoint methods, e.g. 'InterpolatingAdjoint()' or 'BacksolveAdjoint()', work on either the CPU or the GPU.

  • I suspect the problem with the continuous methods is the shape of the noise during the backward solve.
  • With 'TrackerAdjoint()' on GPUs, something is transferred to the CPU during the backward pass: in the stack trace below, Tracker's accum! receives a CPU Matrix{Float32} gradient for a CuArray parameter. The same happens for ODEs.
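Frame [24] of the stack trace shows Tracker's accum! adding a CPU Matrix{Float32} cotangent to a gradient stored as a CuArray. Broadcasting + over a CuArray and a Matrix selects the CUDA broadcast style, so the CPU array ends up as a kernel argument, which produces the non-isbits KernelError reported below. The following is a minimal sketch of the kind of device matching that would avoid this, assuming the fix belongs at the accumulation step. The helpers to_same_storage and accumulate_grad are hypothetical and are not part of Tracker.jl or SciMLSensitivity.jl.

```julia
# Hypothetical helper (not part of Tracker.jl or SciMLSensitivity.jl):
# copy the cotangent Δ into an array of the same storage kind as x
# (CPU Array or GPU CuArray), so a later broadcast never mixes the two.
to_same_storage(x::AbstractArray, Δ::AbstractArray) =
    copyto!(similar(x, eltype(Δ), size(Δ)), Δ)

# Hypothetical accumulation step, mirroring what accum! does in frame [24],
# but moving Δ onto g's storage before the broadcasted addition.
accumulate_grad(g::AbstractArray, Δ::AbstractArray) = g .+ to_same_storage(g, Δ)

# CPU usage: both arrays are plain Matrix{Float32}, result is a Matrix{Float32}.
g = zeros(Float32, 2, 2)
Δ = ones(Float32, 2, 2)
accumulate_grad(g, Δ)
```

With g = CUDA.zeros(Float32, 2, 2), the same call should return a CuArray: similar keeps the storage type of g, and copyto! performs the host-to-device copy. This only illustrates the mismatch; it does not identify where in the backward pass the CPU array is created.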

MWE

using DifferentialEquations, Lux, ComponentArrays, Random, SciMLSensitivity, Zygote, BenchmarkTools, LuxCUDA, CUDA,
      OptimizationOptimisers

dev = gpu_device()
sensealg = TrackerAdjoint()  # This works only on CPU

data = rand32(32,100,512) |> dev
x₀ = rand32(32,512) |> dev
ts = range(0.0f0, 1.0f0, length=100)
drift = Dense(32, 32, tanh)
diffusion = Scale(32, sigmoid)

basic_tgrad(u, p, t) = zero(u)

struct NeuralSDE{D, F} <: Lux.AbstractExplicitContainerLayer{(:drift, :diffusion)}
    drift::D
    diffusion::F
    solver
    tspan
    sensealg
end

function (model::NeuralSDE)(x₀, ts, p, st)
    μ(u, p, t) = model.drift(u, p.drift, st.drift)[1]
    σ(u, p, t) = model.diffusion(u, p.diffusion, st.diffusion)[1]
    func = SDEFunction{false}(μ, σ; tgrad=basic_tgrad)
    prob = SDEProblem{false}(func, x₀, model.tspan, p)
    sol = solve(prob, model.solver; saveat=ts, dt=0.01f0, sensealg = model.sensealg)
    return permutedims(cat(sol.u..., dims=3), (1,3,2))
end

function loss!(p, data)
    pred = model(x₀, ts, p, st)
    l = sum(abs2, data .- pred)
    return l, st, pred
end

rng = Random.default_rng()
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), sensealg)
p, st = Lux.setup(rng, model)
p = p |> ComponentArray{Float32} |> dev


adtype = AutoZygote()
optf = OptimizationFunction((p, _ ) -> loss!(p, data), adtype)
optproblem = OptimizationProblem(optf, p)
result = Optimization.solve(optproblem, ADAMW(5e-4), maxiters=10)

Error & Stacktrace

ERROR: LoadError: GPU compilation of MethodInstance for (::GPUArrays.var"#35#37")(::CUDA.CuKernelContext, ::CuDeviceMatrix{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}} which is not isbits.
      .x is of type Matrix{Float32} which is not isbits.


Stacktrace:
    [1] check_invocation(job::GPUCompiler.CompilerJob)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/validation.jl:92
    [2] macro expansion
      @ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:128 [inlined]
    [3] macro expansion
      @ ~/.julia/packages/TimerOutputs/Lw5SP/src/TimerOutput.jl:253 [inlined]
    [4] 
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:126
    [5] 
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:111
    [6] compile
      @ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:103 [inlined]
    [7] #1145
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:254 [inlined]
    [8] JuliaContext(f::CUDA.var"#1145#1148"{GPUCompiler.CompilerJob{}}; kwargs::@Kwargs{})
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:52
    [9] JuliaContext(f::Function)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:42
   [10] compile(job::GPUCompiler.CompilerJob)
      @ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:253
   [11] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:237
   [12] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:151
   [13] macro expansion
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:369 [inlined]
   [14] macro expansion
      @ ./lock.jl:267 [inlined]
   [15] cufunction(f::GPUArrays.var"#35#37", tt::Type{Tuple{…}}; kwargs::@Kwargs{})
      @ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:364
   [16] cufunction
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:361 [inlined]
   [17] macro expansion
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:112 [inlined]
   [18] #launch_heuristic#1204
      @ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:17 [inlined]
   [19] launch_heuristic
      @ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:15 [inlined]
   [20] _copyto!
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:78 [inlined]
   [21] copyto!
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:44 [inlined]
   [22] copy
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:29 [inlined]
   [23] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{…}, Nothing, typeof(+), Tuple{…}})
      @ Base.Broadcast ./broadcast.jl:903
   [24] accum!(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/params.jl:46
   [25] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:134
   [26] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [27] #64
      @ ./tuple.jl:628 [inlined]
   [28] BottomRF
      @ ./reduce.jl:86 [inlined]
   [29] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
      @ Base ./reduce.jl:58
   [30] foldl_impl
      @ ./reduce.jl:48 [inlined]
   [31] mapfoldl_impl
      @ ./reduce.jl:44 [inlined]
   [32] mapfoldl
      @ ./reduce.jl:175 [inlined]
   [33] foldl
      @ ./reduce.jl:198 [inlined]
   [34] foreach
      @ ./tuple.jl:628 [inlined]
   [35] back_(g::Tracker.Grads, c::Tracker.Call{Tracker.var"#583#584"{…}, Tuple{…}}, Δ::Vector{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
   [36] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 1, CUDA.DeviceMemory}}, Δ::Vector{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
   [37] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [38] #64
      @ ./tuple.jl:628 [inlined]
   [39] BottomRF
      @ ./reduce.jl:86 [inlined]
   [40] _foldl_impl
      @ ./reduce.jl:58 [inlined]
   [41] foldl_impl
      @ ./reduce.jl:48 [inlined]
   [42] mapfoldl_impl(f::typeof(identity), op::Base.var"#64#65"{}, nt::Nothing, itr::Base.Iterators.Zip{…})
      @ Base ./reduce.jl:44
   [43] mapfoldl(f::Function, op::Function, itr::Base.Iterators.Zip{Tuple{Tuple{…}, Tuple{…}}}; init::Nothing)
      @ Base ./reduce.jl:175
   [44] mapfoldl
      @ ./reduce.jl:175 [inlined]
   [45] foldl
      @ ./reduce.jl:198 [inlined]
   [46] foreach(::Function, ::Tuple{Tracker.Tracked{…}, Tracker.Tracked{…}}, ::Tuple{Vector{…}, Vector{…}})
      @ Base ./tuple.jl:628
   [47] back_(g::Tracker.Grads, c::Tracker.Call{Tracker.var"#552#555"{…}, Tuple{…}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
   [48] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
   [49] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [50] #64
      @ ./tuple.jl:628 [inlined]
   [51] BottomRF
      @ ./reduce.jl:86 [inlined]
   [52] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
      @ Base ./reduce.jl:58
--- the last 12 lines are repeated 98 more times ---
 [1229] foldl_impl
      @ ./reduce.jl:48 [inlined]
 [1230] mapfoldl_impl
      @ ./reduce.jl:44 [inlined]
 [1231] mapfoldl
      @ ./reduce.jl:175 [inlined]
 [1232] foldl
      @ ./reduce.jl:198 [inlined]
 [1233] foreach
      @ ./tuple.jl:628 [inlined]
 [1234] back_(g::Tracker.Grads, c::Tracker.Call{…}, Δ::RODESolution{…})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
 [1235] back(g::Tracker.Grads, x::Tracker.Tracked{…}, Δ::RODESolution{…})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
 [1236] #712
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:155 [inlined]
 [1237] #715
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:164 [inlined]
 [1238] (::SciMLSensitivity.var"#tracker_adjoint_backpass#368"{})(ybar::RODESolution{…})
      @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4hOeN/src/concrete_solve.jl:1319
 [1239] ZBack
      @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [1240] (::Zygote.var"#kw_zpullback#53"{})(dy::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [1241] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1242] (::Zygote.var"#2169#back#293"{})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [1243] #solve#51
      @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
 [1244] (::Zygote.Pullback{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1245] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1246] (::Zygote.var"#2169#back#293"{})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [1247] solve
      @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
 [1248] (::Zygote.Pullback{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1249] NeuralSDE
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:31 [inlined]
 [1250] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CuArray{Float32, 3, CUDA.DeviceMemory})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1251] loss!
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:36 [inlined]
 [1252] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1253] #39
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:48 [inlined]
 [1254] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1255] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1256] OptimizationFunction
      @ ~/.julia/packages/SciMLBase/rR75x/src/scimlfunctions.jl:3763 [inlined]
 [1257] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1258] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1259] #37
      @ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:94 [inlined]
 [1260] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1261] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1262] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1263] #39
      @ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97 [inlined]
 [1264] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1265] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [1266] gradient(f::Function, args::ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{…}}})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [1267] (::OptimizationZygoteExt.var"#38#56"{})(::ComponentVector{…}, ::ComponentVector{…})
      @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97
 [1268] macro expansion
      @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [1269] macro expansion
      @ ~/.julia/packages/Optimization/fPKIF/src/utils.jl:32 [inlined]
 [1270] __solve(cache::OptimizationCache{…})
      @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [1271] solve!(cache::OptimizationCache{…})
      @ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:188
 [1272] solve(::OptimizationProblem{…}, ::OptimiserChain{…}; kwargs::@Kwargs{})
      @ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:96
in expression starting at /home/artiintel/ahmelg/code/NeuroDynamics.jl/examples/mwe.jl:50
Some type information was truncated. Use `show(err)` to see complete types.

I am using the latest releases for the packages and Julia 1.10.4.

@ChrisRackauckas
Member

I think this is related to the ComponentArrays issue we just found. @avik-pal is looking into it.
