
[fori_loop|while_loop] add test case for Linear model with changeable weight #7302

Open
wants to merge 10 commits into base: master
Conversation

ManfeiBai
Collaborator

@ManfeiBai ManfeiBai commented Jun 17, 2024

Add test for linear model with changeable weight/bias layer

Dump HLO for WhileLoopTest.test_while_loop_simple_linear_outside_loop_change_weight_bias:https://gist.github.com/ManfeiBai/73ee0264113f1968135976bee80ef4a0
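The pattern such a test follows (visible in the diff's final assertion) is to run the linear computation once through the structured while-loop helper and once through an ordinary loop, then compare. A minimal sketch in plain Python, without torch or torch_xla; every name below (`linear`, `with_loop`, `without_loop`) is illustrative, not from the PR:

```python
# Plain-Python sketch of the test pattern: a "linear layer" whose
# weight/bias are swapped in from external lists on each iteration,
# computed once with a while loop and once with a for loop.

def linear(w, b, x):
    # stand-in for a torch.nn.Linear whose weight/bias are replaced per step
    return w * x + b

def with_loop(weights, biases, x, iters):
    i = iters - 1
    while i >= 0:
        # pick this step's weight/bias, mimicking a changeable layer
        w, b = weights[iters - 1 - i], biases[iters - 1 - i]
        x = linear(w, b, x)
        i -= 1
    return x

def without_loop(weights, biases, x, iters):
    # reference computation with an ordinary Python loop
    for step in range(iters):
        x = linear(weights[step], biases[step], x)
    return x
```

The test then asserts that both paths produce the same result, analogous to the `torch.eq` comparison in the diff.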

@miladm miladm assigned miladm and ManfeiBai and unassigned miladm Jun 19, 2024
@miladm
Collaborator

miladm commented Jun 20, 2024

I see this PR is in draft state; is it ready for review @ManfeiBai?

@ManfeiBai
Collaborator Author

ManfeiBai commented Jun 20, 2024

> I see this PR is in draft state; is it ready for review @ManfeiBai?

Hi @miladm, I'm still debugging locally; it's not ready for review yet.

@ManfeiBai ManfeiBai force-pushed the linear_model_with_changeable_weight branch from 5b9947b to b2923f0 Compare June 21, 2024 01:17
@ManfeiBai ManfeiBai marked this pull request as ready for review June 21, 2024 01:17
@ManfeiBai ManfeiBai requested a review from JackCaoG June 21, 2024 01:17
@ManfeiBai ManfeiBai changed the title [test] Linear model with changeable weight [fori_loop|while_loop] add test case for Linear model with changeable weight Jun 21, 2024
@ManfeiBai ManfeiBai requested review from miladm and qihqi June 21, 2024 04:19
return next_iteri, weights, bias, next_x

inputs = torch.stack((weights[2], bias[2],
torch.tensor([[1.0, 1.0], [1.0, 1.0]],
Collaborator Author

generate a test case with values greater than 1.0

Collaborator Author

done

[[5.1, 6.2], [7.3, 8.4]]],
device=device)

bias = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]],
Collaborator Author

the bias size is not the same as the weights

Collaborator Author

extended the size to match the weight size, since concatenation requires matching sizes

device = xm.xla_device()
torch.set_grad_enabled(False)

# TODO(@manfei): enable weights[0] != weights[1] and bias[0] != bias[1], now test pass with weights[0] == weights[1] and bias[0]==bias[1]
Collaborator

this sounds like a bug, do we know why?

torch.tensor([[1.0, 2.0], [3.0, 4.0]],
dtype=torch.float32,
device=device)))
print("inputs: ", inputs) # needed to enable func catch stacked inputs
Collaborator

remove

@@ -91,6 +91,142 @@ def forward_without_while_loop_op(self, iteri, x):

self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

def test_while_loop_simple_linear_outside_loop_change_weight_bias(self):
device = xm.xla_device()
torch.set_grad_enabled(False)
Collaborator

why do we need this?


expected = inputs
while (iteri >= 0):
weight_value = expected[0]
Collaborator

why is weight_value always expected[0]? Shouldn't you be iterating over expected?

Collaborator Author

@ManfeiBai ManfeiBai Aug 14, 2024

Thanks,

expected is a list containing (current_weight, current_bias, next_x), which was generated in the last iteration;

we need the stacked value here so that the next iteration knows the right weight/bias and input_x;

we iterate over the lists weights (line 99) and bias (line 103), and expected is a stack of weights[i] and bias[i], so we take the current weight_value from expected[0] here

def cond_fn(iteri, weights, bias, x):
return iteri >= 0

def body_fn(iteri, weights, bias, x):
Collaborator

instead of x just called it carry or something meaningful.

local_linear.bias = torch.nn.parameter.Parameter(
data=local_bias_value, requires_grad=False)
next_iteri = iteri - 1
next_x = torch.stack((weights[-next_iteri - 1], bias[-next_iteri - 1],
Collaborator

why do you need to construct weights[-next_iteri - 1] and bias[-next_iteri - 1]? Couldn't you get those values directly at the beginning of this function, instead of doing it at step end and passing them to the next iteration?

Collaborator Author

Thanks, I will try that locally.
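The reviewer's suggestion amounts to indexing into the weight/bias lists at the start of the body rather than stacking the next step's values at step end. A plain-Python sketch under that assumption (`body`, `run`, and the `idx` scheme are hypothetical, not the PR's code):

```python
# Sketch of the suggested alternative: compute this step's index from
# iteri at the top of the body, read weight/bias directly, and carry
# only the running x between iterations.

def body(iteri, weights, bias, x):
    idx = len(weights) - 1 - iteri     # hypothetical indexing scheme
    w, b = weights[idx], bias[idx]     # read directly at step start
    y = w * x + b                      # stand-in for the linear layer
    return iteri - 1, weights, bias, y

def run(weights, bias, x, iteri):
    # mirrors the PR's `while iteri >= 0` loop convention
    while iteri >= 0:
        iteri, weights, bias, x = body(iteri, weights, bias, x)
    return x
```

This avoids re-stacking tensors at the end of each step, at the cost of needing the step index inside the body.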

@ManfeiBai
Collaborator Author

cc @tengyifei
