[fori_loop|while_loop] add test case for Linear model with changeable weight #7302
base: master
Conversation
I see this PR is in draft state; is it ready for review, @ManfeiBai?
Hi @miladm, I'm still debugging locally; it's not ready for review yet.
Force-pushed from 5b9947b to b2923f0.
test/test_while_loop.py
Outdated
return next_iteri, weights, bias, next_x

inputs = torch.stack((weights[2], bias[2],
                      torch.tensor([[1.0, 1.0], [1.0, 1.0]],
Generate a test case with values greater than 1.0.
done
test/test_while_loop.py
Outdated
                     [[5.1, 6.2], [7.3, 8.4]]],
                    device=device)

bias = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]],
The bias size is not the same as the weights'.
Extended the bias to match the weight size, since concatenation requires matching sizes.
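That size constraint is the reason for the change: `torch.stack` requires every tensor to have exactly the same shape. A small sketch (with hypothetical 2x2 shapes standing in for the test's tensors) of why the bias had to be extended:

```python
import torch

# Hypothetical stand-ins for the test's tensors: three 2x2 weight
# matrices, and a bias that originally had a smaller 1x2 shape per step.
weights = torch.arange(12, dtype=torch.float32).reshape(3, 2, 2)
bias_rows = torch.ones(3, 1, 2)

# Stacking a 2x2 weight with a 1x2 bias raises a RuntimeError,
# because torch.stack insists on equal shapes.
try:
    torch.stack((weights[0], bias_rows[0]))
except RuntimeError:
    print("shapes must match before stacking")

# Broadcasting the bias up to 2x2 (as the PR does by extending it)
# makes the stack legal.
bias = bias_rows.expand(3, 2, 2)
packed = torch.stack((weights[0], bias[0]))
print(packed.shape)  # torch.Size([2, 2, 2])
```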
device = xm.xla_device()
torch.set_grad_enabled(False)

# TODO(@manfei): enable weights[0] != weights[1] and bias[0] != bias[1]; the test currently passes only when weights[0] == weights[1] and bias[0] == bias[1]
This sounds like a bug; do we know why?
torch.tensor([[1.0, 2.0], [3.0, 4.0]],
             dtype=torch.float32,
             device=device)))
print("inputs: ", inputs)  # needed so the function can capture the stacked inputs
remove
@@ -91,6 +91,142 @@ def forward_without_while_loop_op(self, iteri, x):

    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

  def test_while_loop_simple_linear_outside_loop_change_weight_bias(self):
    device = xm.xla_device()
    torch.set_grad_enabled(False)
why do we need this?
expected = inputs
while (iteri >= 0):
  weight_value = expected[0]
Why is weight_value always expected[0]? Shouldn't you be iterating over expected?
Thanks. expected is a list containing (current_weight, current_bias, next_x), which was generated in the previous iteration; we need the stacked value here so the next iteration knows the right weight/bias and input x. We iterate over the lists weights (line 99) and bias (line 103), and expected is a stack of weights[i] and bias[i], so we prefer to read the current weight_value from expected[0] here.
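A minimal sketch of that packing scheme (shapes and names are hypothetical, mirroring the 2x2 tensors in the test): the carry is one stacked tensor, and positional indexing recovers the pieces the previous iteration stored:

```python
import torch

weight = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
bias = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
next_x = torch.ones(2, 2)

# The previous iteration packs (weight, bias, next_x) into one tensor.
expected = torch.stack((weight, bias, next_x))  # shape (3, 2, 2)

# The next iteration unpacks by position: index 0 is always the weight
# stored for this step, index 1 the bias, index 2 the input.
weight_value, bias_value, x = expected[0], expected[1], expected[2]
assert torch.equal(weight_value, weight)
assert torch.equal(bias_value, bias)
```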
def cond_fn(iteri, weights, bias, x):
  return iteri >= 0

def body_fn(iteri, weights, bias, x):
Instead of x, call it carry or something meaningful.
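For reference, the cond_fn/body_fn contract can be written as a plain-Python equivalent (a sketch of the semantics only, not the XLA lowering), where the whole carried state is one tuple, which is why a name like `carry` fits:

```python
def run_while_loop(cond_fn, body_fn, carry):
    """Reference semantics: repeat body_fn while cond_fn holds on the carry."""
    while cond_fn(*carry):
        carry = body_fn(*carry)
    return carry

# A countdown mirroring the test's `iteri >= 0` condition.
def cond_fn(iteri, acc):
    return iteri >= 0

def body_fn(iteri, acc):
    return iteri - 1, acc + iteri

final_iteri, total = run_while_loop(cond_fn, body_fn, (3, 0))
print(final_iteri, total)  # -1 6  (body ran for iteri = 3, 2, 1, 0)
```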
local_linear.bias = torch.nn.parameter.Parameter(
    data=local_bias_value, requires_grad=False)
next_iteri = iteri - 1
next_x = torch.stack((weights[-next_iteri - 1], bias[-next_iteri - 1],
Why do you need to construct weights[-next_iteri - 1] and bias[-next_iteri - 1]? Couldn't you get those values directly at the beginning of this function, instead of computing them at step end and passing them to the next iteration?
Thanks, I'll try that locally.
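One way the suggestion could look (a sketch with hypothetical names, using an eager Python loop in place of while_loop): select the current weight/bias from the iteration counter at the top of the body, rather than packing them at step end for the next iteration:

```python
import torch

torch.set_grad_enabled(False)

# Hypothetical per-iteration parameters: three 2x2 weights, three biases.
weights = torch.arange(12, dtype=torch.float32).reshape(3, 2, 2)
biases = torch.ones(3, 2)
linear = torch.nn.Linear(2, 2, bias=True)

def body_fn(iteri, x):
    # Read this step's parameters at the start of the body: iteri counts
    # down 2, 1, 0, which maps to rows 0, 1, 2 of the parameter stacks.
    idx = weights.shape[0] - 1 - iteri
    linear.weight = torch.nn.parameter.Parameter(
        weights[idx], requires_grad=False)
    linear.bias = torch.nn.parameter.Parameter(
        biases[idx], requires_grad=False)
    return iteri - 1, linear(x)

iteri, x = 2, torch.ones(2)
while iteri >= 0:
    iteri, x = body_fn(iteri, x)
print(x)  # result of applying the three linear layers in order
```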
cc @tengyifei
Add a test for a linear model with a changeable weight/bias layer.
HLO dump for WhileLoopTest.test_while_loop_simple_linear_outside_loop_change_weight_bias: https://gist.github.com/ManfeiBai/73ee0264113f1968135976bee80ef4a0