diff --git a/src/differentiation/zygote.jl b/src/differentiation/zygote.jl index f8a2e70d16..ffe180113f 100644 --- a/src/differentiation/zygote.jl +++ b/src/differentiation/zygote.jl @@ -1,11 +1,11 @@ struct ZygoteDiffBackend <: AbstractDiffBackend end function Manifolds._gradient(f, p, ::ZygoteDiffBackend) - return Zygote.gradient(f, p) + return Zygote.gradient(f, p)[1] end function Manifolds._gradient!(f, X, p, ::ZygoteDiffBackend) - return Zygote.gradient!(X, f, p) + return copyto!(X, Zygote.gradient(f, p)[1]) end push!(Manifolds._diff_backends, ZygoteDiffBackend()) diff --git a/test/differentiation.jl b/test/differentiation.jl index b060d67962..cc0d2361ed 100644 --- a/test/differentiation.jl +++ b/test/differentiation.jl @@ -128,6 +128,7 @@ using LinearAlgebra: Diagonal, dot @test isapprox(X, [1.0, 0.0]) end @testset for backend in [fd51, fwd_diff, finite_diff, reverse_diff, zygote_diff] + diff_backend!(backend) X = [-1.0, -1.0] @test _gradient(f1, [1.0, -1.0]) ≈ [1.0, -2.0] @test _gradient!(f1, X, [1.0, -1.0]) === X