Enzyme fails on GPU kernel
Created by: pxl-th
On CPU, Enzyme works fine and is able to differentiate through the `spherical_harmonics!` kernel; however, on `CUDADevice` it fails with the error below. It looks like something similar to `transform_gpu!`, which inserts `return nothing`, is missing.
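
For reference, the kind of rewrite meant here is roughly the following. This is only an illustrative sketch of appending `return nothing` to a kernel-body expression, not the actual `transform_gpu!` implementation:

```julia
# Illustrative sketch (not the real KernelAbstractions code): the GPU code path
# rewrites the kernel body so it always ends in `return nothing`, keeping the
# compiled kernel's return type `Nothing`. The differentiated kernel generated
# for the GPU backend appears to skip this step, hence the `Tuple{}` return
# type in the error below.
function append_return_nothing!(body::Expr)
    push!(body.args, :(return nothing))
    return body
end

body = :(begin
    a[1] = 2.0f0
end)
append_return_nothing!(body)  # `body` now ends with `return nothing`
```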
Error:
```
ERROR: LoadError: GPU compilation of kernel #df#1(KernelAbstractions.CompilerMetadata{KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicCheck, Nothing, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, KernelAbstractions.NDIteration.NDRange{1, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.StaticSize{(512,)}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, Nothing}}, Duplicated{CuDeviceMatrix{Float32, 1}}, Duplicated{CuDeviceMatrix{Float32, 1}}) failed
KernelError: kernel returns a value of type `Tuple{}`

Make sure your kernel function ends in `return`, `return nothing` or `nothing`.
If the returned value is of type `Union{}`, your Julia code probably throws an exception.
Inspect the code with `@device_code_warntype` for more details.
```
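
Following the last hint, one way to inspect the failing kernel is to wrap the launch in `@device_code_warntype` (re-exported by CUDA.jl), in place of the commented-out `@device_code` line inside `∇spherical_harmonics!` below. A sketch, reusing the names from that function:

```julia
# Sketch only: this fragment assumes the variables from `∇spherical_harmonics!`
# in the MWE (encodings, ∂encodings, directions, ∂directions, device).
# `@device_code_warntype` dumps the typed device code for every kernel compiled
# inside the wrapped expression.
∇k! = Enzyme.autodiff(spherical_harmonics!(device, linear_threads(device)))
@device_code_warntype wait(∇k!(
    Duplicated(encodings, ∂encodings),
    Duplicated(directions, ∂directions);
    ndrange=size(encodings, 2)))
```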
MWE:
```julia
using InteractiveUtils
using CUDA
using CUDAKernels
using Enzyme
using KernelAbstractions
using KernelGradients

# Convenience constructors that pick the array type for the given device.
Base.rand(::CPU, T, shape) = rand(T, shape)
Base.rand(::CUDADevice, T, shape) = CUDA.rand(T, shape)
Base.zeros(::CPU, T, shape) = zeros(T, shape)
Base.zeros(::CUDADevice, T, shape) = CUDA.zeros(T, shape)
Base.ones(::CPU, T, shape) = ones(T, shape)
Base.ones(::CUDADevice, T, shape) = CUDA.ones(T, shape)

linear_threads(::CPU) = Threads.nthreads()
linear_threads(::CUDADevice) = 128

# Degree-0 and degree-1 spherical harmonics for each column of `directions`.
@kernel function spherical_harmonics!(encodings, @Const(directions))
    i = @index(Global)
    x = directions[1, i]
    y = directions[2, i]
    z = directions[3, i]
    encodings[1, i] = 0.28209479177387814f0
    encodings[2, i] = -0.48860251190291987f0 * y
    encodings[3, i] = 0.48860251190291987f0 * z
    encodings[4, i] = -0.48860251190291987f0 * x
end

function ∇spherical_harmonics!(∂encodings, ∂directions, encodings, directions, device)
    # KernelGradients extends `Enzyme.autodiff` to return a differentiated kernel.
    ∇k! = Enzyme.autodiff(spherical_harmonics!(device, linear_threads(device)))
    # @device_code dir="./" ∇k!(Duplicated(encodings, ∂encodings), Duplicated(directions, ∂directions); ndrange=1)
    n = size(encodings, 2)
    wait(∇k!(Duplicated(encodings, ∂encodings), Duplicated(directions, ∂directions); ndrange=n))
    nothing
end

function main()
    device = CUDADevice()
    n = 1
    x = rand(device, Float32, (3, n))
    y = zeros(device, Float32, (4, n))
    ∂L∂x = zeros(device, Float32, (3, n))
    ∂L∂y = ones(device, Float32, (4, n))

    wait(spherical_harmonics!(device, linear_threads(device))(y, x; ndrange=n))
    ∇spherical_harmonics!(∂L∂y, ∂L∂x, y, x, device)
end
main()
```
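
For comparison, the CPU backend differentiates the same kernel without error (as noted at the top). A minimal sketch reusing the definitions from the MWE, with `CPU()` from KernelAbstractions:

```julia
# Same pipeline as `main()`, but on the KernelAbstractions CPU backend, where
# the differentiated kernel runs fine according to the report.
function main_cpu()
    device = CPU()
    n = 1
    x = rand(device, Float32, (3, n))
    y = zeros(device, Float32, (4, n))
    ∂L∂x = zeros(device, Float32, (3, n))
    ∂L∂y = ones(device, Float32, (4, n))
    wait(spherical_harmonics!(device, linear_threads(device))(y, x; ndrange=n))
    ∇spherical_harmonics!(∂L∂y, ∂L∂x, y, x, device)
end
main_cpu()
```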
I'm on Julia 1.8.0-rc1. Output of `]st`:
```
[052768ef] CUDA v3.10.1
[72cfdca4] CUDAKernels v0.4.2 `https://github.com/JuliaGPU/KernelAbstractions.jl.git:lib/CUDAKernels#master`
[7da242da] Enzyme v0.10.0
[63c18a36] KernelAbstractions v0.8.2 `https://github.com/JuliaGPU/KernelAbstractions.jl.git#master`
[e5faadeb] KernelGradients v0.1.1 `https://github.com/JuliaGPU/KernelAbstractions.jl.git:lib/KernelGradients#master`
```