[TOPI][CUDA] Improve the performance of scatter_nd #8479

Merged
122 changes: 85 additions & 37 deletions python/tvm/topi/cuda/scatter.py
@@ -772,9 +772,10 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
updates = ib.buffer_ptr(updates_ptr)
out = ib.buffer_ptr(out_ptr)

# We combine all the indices dimensions but the first one into a single
# dimension so we can iterate it in single loop instead of an arbitrary
# number of loops. We do the same thing for all the update dimensions.
atomic_add_return = ib.allocate(
updates.dtype, (1,), name="atomic_add_return", scope="local"
)

fused_indices_dimension = 1
for i in indices_ptr.shape[1:]:
fused_indices_dimension *= i
@@ -787,44 +788,91 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
for i in data_ptr.shape:
fused_shape *= i

# For now we avoid parallelizing over dimensions indexed by `indices` as
# there may be repeated indices and handling parallel accumulation can
# be hard. So we parallelize over X_M .. X_{N-1} instead. This will
# work well when these dimensions are large enough to saturate memory
# bandwidth, but performance will be bad when these dimensions are
# small.
Comment on lines -790 to -795

Contributor: Can you add a comment about how we are doing parallelism (we are thread-parallel over all the update dimensions and each block handles one set of indices)?

@CaptainDuke (Contributor), Jul 20, 2021: We follow the original parallelism scheme, but replace ib.for_range() with blockIdx.y. atomic_add guarantees correctness when mode == "add".

Contributor: Can you update the comment in the code to reflect this?

Contributor: Added
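
For reference, a minimal NumPy sketch (not TVM code; names and shapes are illustrative) of the index arithmetic this scheme implements: the outer bx loop plays the role of blockIdx.x (one set of indices per block), the inner j loop stands in for by * tdim + tx over the fused update dimension, and the += stands in for atomic_add when mode == "add".

```python
import numpy as np

def scatter_nd_add_emulated(data, indices, updates):
    # data: (X_0, ..., X_{N-1}); indices: (M, Y_0, ..., Y_{K-1});
    # updates: (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})
    out = data.copy()
    M = indices.shape[0]
    fused_indices_dimension = int(np.prod(indices.shape[1:], dtype=np.int64))
    fused_updates_dimension = int(np.prod(data.shape[M:], dtype=np.int64))
    flat_indices = indices.reshape(M, fused_indices_dimension)
    flat_updates = updates.reshape(fused_indices_dimension, fused_updates_dimension)
    flat_out = out.reshape(-1)

    for bx in range(fused_indices_dimension):       # blockIdx.x in the kernel
        for j in range(fused_updates_dimension):    # by * tdim + tx in the kernel
            offset = fused_updates_dimension
            index = j                               # x_M .. x_{N-1} part of the out index
            for l in reversed(range(M)):
                index += offset * flat_indices[l, bx]
                offset *= data.shape[l]
            flat_out[index] += flat_updates[bx, j]  # atomic_add in the kernel
    return out

# Tiny check: duplicate indices accumulate in "add" mode.
data = np.zeros((4, 3), dtype="float32")
indices = np.array([[0, 2, 0]])                 # (M=1, Y_0=3)
updates = np.ones((3, 3), dtype="float32")      # (Y_0=3, X_1=3)
print(scatter_nd_add_emulated(data, indices, updates))  # row 0 -> 2.0, row 2 -> 1.0
```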

bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
tdim = min(max_threads, fused_updates_dimension)
ib.scope_attr(tx, "thread_extent", tdim)
bdim = ceil_div(fused_updates_dimension, tdim)
ib.scope_attr(bx, "thread_extent", bdim)

# Copy data into the output. This loop writes to the same portions of
# memory as the following loop, so we do not need a memory sync.
with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
index = i * fused_updates_dimension + bx * tdim + tx
with ib.if_scope(bx * tdim + tx < fused_updates_dimension):

with ib.new_scope():
bdim = ceil_div(fused_shape, tdim)
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", bdim)
ib.scope_attr(tx, "thread_extent", tdim)

index = bx * tdim + tx
with ib.if_scope(index < fused_shape):
out[index] = data[index]

with ib.for_range(0, fused_indices_dimension) as i:
j = bx * tdim + tx
with ib.if_scope(j < fused_updates_dimension):
offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
# of the index into out.
Comment on lines -815 to -817

Contributor: Can you keep this comment? I believe it still holds.

Contributor: Added

for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
offset *= data_ptr.shape[l]
if mode == "update":
out[index] = updates[i * fused_updates_dimension + j]
elif mode == "add":
out[index] += updates[i * fused_updates_dimension + j]
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
# For better performance, we introduce blockIdx.y to replace the for-loop that
# each thread previously executed, so the code is also parallel over the
# scattered indices. atomic_add guarantees correctness when mode == "add".

# For now, atomics are not supported by the "vulkan" and "metal" targets, or by
# "cuda" with "int64", so in those cases we fall back to the normal algorithm,
# using "+=" rather than atomic_add.

# TODO (CaptainDuke):
# Multiple threads compete for the same write index, which leads to
# non-deterministic output in update mode. We could add a new attribute,
# "allow_non_deterministic", which each frontend can conditionally set to True
# when non-determinism is allowed.
cur_target_kind = str(tvm.target.Target.current(allow_none=False).kind)
with ib.new_scope():
if (
mode == "add"
and cur_target_kind not in ["vulkan", "metal"]
and updates.dtype in ["int32", "float32"]
):
bdim_x = fused_indices_dimension
bdim_y = ceil_div(fused_updates_dimension, tdim)
# In case of large input sizes, fused_indices_dimension might be too large,
# so we map it to blockIdx.x, which supports a larger grid extent than blockIdx.y.
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", bdim_x)
ib.scope_attr(by, "thread_extent", bdim_y)
ib.scope_attr(tx, "thread_extent", tdim)

j = by * tdim + tx
with ib.if_scope(j < fused_updates_dimension):
offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}]
# part of the index into out.
up_index = bx * fused_updates_dimension + j
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[bx + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
index += offset * indices[bx + l * fused_indices_dimension]
offset *= data_ptr.shape[l]
atomic_add_return[0] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", out[index]),
updates[up_index],
)
else:
bdim_x = ceil_div(fused_updates_dimension, tdim)
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", bdim_x)
ib.scope_attr(tx, "thread_extent", tdim)
with ib.for_range(0, fused_indices_dimension) as i:
j = bx * tdim + tx
with ib.if_scope(j < fused_updates_dimension):
offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the
# indices[0, y_0, .. y_{K-1}], ... indices[M-1, y_0, .. y_{K-1}]
# part of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i + l * fused_indices_dimension] = indices[l, y_0,
# ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
offset *= data_ptr.shape[l]
if mode == "update":
out[index] = updates[i * fused_updates_dimension + j]
elif mode == "add":
out[index] += updates[i * fused_updates_dimension + j]
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)

return ib.get()
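
As a rough, back-of-the-envelope illustration of the launch configuration the atomic path above produces (hypothetical shapes and thread count, not taken from this PR):

```python
from math import ceil, prod

data_shape = (16384, 256)     # hypothetical (X_0, X_1)
indices_shape = (1, 16384)    # hypothetical (M, Y_0), i.e. M = 1
max_threads = 1024            # a typical CUDA max_num_threads

M = indices_shape[0]
fused_indices_dimension = prod(indices_shape[1:])   # 16384
fused_updates_dimension = prod(data_shape[M:])      # 256

tdim = min(max_threads, fused_updates_dimension)    # 256 threads per block
bdim_x = fused_indices_dimension                    # 16384 -> blockIdx.x extent
bdim_y = ceil(fused_updates_dimension / tdim)       # 1     -> blockIdx.y extent

# CUDA allows gridDim.x up to 2**31 - 1 but gridDim.y only up to 65535,
# which is why the potentially huge fused_indices_dimension is mapped to
# blockIdx.x rather than blockIdx.y.
print(tdim, bdim_x, bdim_y)  # 256 16384 1
```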

3 changes: 2 additions & 1 deletion tests/python/relay/test_op_level3.py
@@ -1884,7 +1884,8 @@ def verify_scatter_nd_with_stack(
):
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices_vars = [
relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np)
relay.var("ind%d" % i, shape=v.shape, dtype=str(v.dtype))
for i, v in enumerate(indices_np)
]
updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))
