Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF frontend] add some "Segment" and "UnsortedSegment" ops #6928

Closed
wants to merge 5 commits into from

Conversation

alter-xp
Copy link
Contributor

  • segment_max, segment_min, segment_mean, segment_sum, segment_prod
  • unsorted_segment_max, unsorted_segment_min, unsorted_segment_mean
  • unsorted_segment_prod, unsorted_segment_sum

@giuseros @siju-samuel

Copy link
Contributor

@giuseros giuseros left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice algorithm @alter-xp . I will review a bit more tomorrow. General question:

  • Could you add a bit more comment to the generated TIR part?
  • Would this also work for GPU?

@@ -73,3 +75,341 @@ def full_like(x, fill_value):
The result.
"""
return cpp.full_like(x, fill_value)


def segment_max(data, segment_ids, num_out):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a bit more comments through this file? This would make it easier to read and also more future proof

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👌

with ib.for_range(0, num_segment) as n:
with ib.for_range(0, inner_size) as j:
out_index = n * inner_size + j
out[out_index] = -3.4028235e38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a better way to do this? In theory you could pre-compute also the segment sizes. Or you could calculate the sizes on the fly. At least, I would put something like: sys.float_info.min

Copy link
Contributor Author

@alter-xp alter-xp Nov 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sys.float_info.min is a number like 2.2250738585072014e-308, , which is not very suitable here, because it is always greater than 0. I replaced it with float("inf")

Comment on lines 162 to 187
def _segment_min(data, segment_ids, out_buf):

ib = tir.ir_builder.create()
input_data = ib.buffer_ptr(data)
seg_ids = ib.buffer_ptr(segment_ids)
out = ib.buffer_ptr(out_buf)

shape = get_const_tuple(data.shape)
num_segment = get_const_tuple(out_buf.shape)[0]
inner_size = 1
for s in range(1, len(shape)):
inner_size = inner_size * shape[s]

with ib.for_range(0, num_segment) as n:
with ib.for_range(0, inner_size) as j:
out_index = n * inner_size + j
out[out_index] = 3.4028235e38

with ib.for_range(0, shape[0]) as k:
with ib.if_scope(seg_ids[k] == n):
with ib.for_range(0, inner_size) as l:
out_index = n * inner_size + l
in_index = k * inner_size + l
out[out_index] = te.min(input_data[in_index], out[out_index])

return ib.get()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like these function share a common implementation. Could you write a single _segment_op with a string op parameters and add some logic inside this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👌

temp_index[num[0]] = k
num[0] += 1

with ib.if_scope(num[0] > 0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, if the segment is not present we omit it? Why we don't do this for max and min? Some explanation of this could would be useful, I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At that time, it was to ensure that the division by 0 would not occur during the calculation process. Now modified to only be used in the mean

}


def verify_segmet(name, data_shape, segmnet_size):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo verify_segmet -> verify_segment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👌

segment_ids.append(segment_ids[-1])
return np.array(segment_ids).astype("int32")

def get_ref_data():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using tflite to get the reference data, could you add some reference data manually? In this way tests can be independent and also show what the goal of each function is (also, you already added reference tests in test_forward.py). What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, thanks for all your comments. All suggestions have been revised.

Copy link
Contributor

@giuseros giuseros left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @alter-xp for addressing the comments! LGTM

@alter-xp
Copy link
Contributor Author

alter-xp commented Nov 23, 2020

when testing for unsorted_segment_mean, a single constant data will as input for strideslice op in tf graph. But tf frontend in tvm not support this situation. fix this in pr #6949.

@alter-xp
Copy link
Contributor Author

@giuseros hi, can you help me see what's wrong with this branch? It hasn't been merged

@alter-xp alter-xp requested a review from yzhliu as a code owner May 24, 2021 07:31
* segment_max, segment_min, segment_mean, segment_sum, segment_prod
* unsorted_segment_max, unsorted_segment_min, unsorted_segment_mean
* unsorted_segment_prod, unsorted_segment_sum
@jroesch
Copy link
Member

jroesch commented Jan 19, 2022

This PR appears to be out of date, please feel free to reopen it if this is not the case.

As part of the new year we are attempting to triage the project's open pull requests to ensure that code which
is ready for review and/or merging receives adequate attention.

Thanks again for your contribution, and feel free to reach out to discuss these changes.

@jroesch jroesch closed this Jan 19, 2022
@alter-xp alter-xp deleted the dev branch October 8, 2022 07:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants