diff --git a/python/tvm/meta_schedule/schedule_rule/auto_bind.py b/python/tvm/meta_schedule/schedule_rule/auto_bind.py index c211093e9275..99a91f606e32 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_bind.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_bind.py @@ -33,12 +33,15 @@ class AutoBind(ScheduleRule): The maximum number of threadblock on GPU. thread_extents: Optional[List[int]] Candidates of thread axis extent. + max_threads_per_block: int + The maximum number of threads per block, if it is known when this schedule rule is created. """ def __init__( self, max_threadblocks: int = 256, thread_extents: Optional[List[int]] = None, + max_threads_per_block: int = -1, ) -> None: if thread_extents is None: thread_extents = [32, 64, 128, 256, 512, 1024] @@ -46,4 +49,5 @@ def __init__( _ffi_api.ScheduleRuleAutoBind, # type: ignore # pylint: disable=no-member max_threadblocks, thread_extents, + max_threads_per_block, )