Skip to content

Commit

Permalink
fix conflicting bug
Browse files Browse the repository at this point in the history
  • Loading branch information
alter-xp committed May 24, 2021
1 parent d048e98 commit cfc0ccb
Showing 1 changed file with 115 additions and 0 deletions.
115 changes: 115 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,29 @@ def uniform_strategy(attrs, inputs, out_type, target):
return strategy


# segment_max
def wrap_compute_segment_max(topi_compute):
"""wrap segment_max topi compute"""

def _compute_segment_max(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "max")]

return _compute_segment_max


@override_native_generic_func("segment_max_strategy")
def segment_max_strategy(attrs, inputs, out_type, target):
"""segment_max generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_max(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_max.generic",
)
return strategy


def wrap_compute_scanop(topi_compute):
"""Wrap scanop style topi compute"""

Expand All @@ -1561,6 +1584,29 @@ def cumsum_strategy(attrs, inputs, out_type, target):
return strategy


# segment_min
def wrap_compute_segment_min(topi_compute):
"""wrap segment_min topi compute"""

def _compute_segment_min(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "min")]

return _compute_segment_min


@override_native_generic_func("segment_min_strategy")
def segment_min_strategy(attrs, inputs, out_type, target):
"""segment_min generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_min(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_min.generic",
)
return strategy


@override_native_generic_func("cumprod_strategy")
def cumprod_strategy(attrs, inputs, out_type, target):
"""cumprod generic strategy"""
Expand All @@ -1573,6 +1619,29 @@ def cumprod_strategy(attrs, inputs, out_type, target):
return strategy


# segment_mean
def wrap_compute_segment_mean(topi_compute):
"""wrap segment_mean topi compute"""

def _compute_segment_mean(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "mean")]

return _compute_segment_mean


@override_native_generic_func("segment_mean_strategy")
def segment_mean_strategy(attrs, inputs, out_type, target):
"""segment_mean generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_mean(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_mean.generic",
)
return strategy


def wrap_compute_unique(topi_compute):
"""Wrap unique topi compute"""

Expand All @@ -1594,8 +1663,54 @@ def unique_strategy(attrs, inputs, out_type, target):
return strategy


# segment_sum
def wrap_compute_segment_sum(topi_compute):
"""wrap segment_sum topi compute"""

def _compute_segment_sum(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "sum")]

return _compute_segment_sum


@override_native_generic_func("segment_sum_strategy")
def segment_sum_strategy(attrs, inputs, out_type, target):
"""segment_sum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_sum(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_sum.generic",
)
return strategy


@generic_func
def schedule_transpose(attrs, outs, target):
"""schedule transpose"""
with target:
return schedule_injective(attrs, outs, target)


# segment_prod
def wrap_compute_segment_prod(topi_compute):
"""wrap segment_prod topi compute"""

def _compute_segment_prod(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "prod")]

return _compute_segment_prod


@override_native_generic_func("segment_prod_strategy")
def segment_prod_strategy(attrs, inputs, out_type, target):
"""segment_prod generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_prod(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_prod.generic",
)
return strategy

0 comments on commit cfc0ccb

Please sign in to comment.