-
Notifications
You must be signed in to change notification settings - Fork 346
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Tangram #1743
Implement Tangram #1743
Conversation
Codecov ReportBase: 90.84% // Head: 90.85% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## main #1743 +/- ##
========================================
Coverage 90.84% 90.85%
========================================
Files 115 118 +3
Lines 9758 9969 +211
========================================
+ Hits 8865 9057 +192
- Misses 893 912 +19
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
tensor_dict = self._get_tensor_dict(device=device) | ||
training_plan = JaxTrainingPlan(self.module, **plan_kwargs) | ||
module_init = self.module.init(self.module.rngs, tensor_dict) | ||
state, params = module_init.pop("params") | ||
training_plan.set_train_state(params, state) | ||
train_step_fn = JaxTrainingPlan.jit_training_step | ||
pbar = track(range(max_epochs), style="tqdm", description="Training") | ||
history = pd.DataFrame(index=np.arange(max_epochs), columns=["loss"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @justjhong it's actually pretty flexible
you can use the training plan just as a shell
@dataclass | ||
class AnnDataManagerValidationCheck: | ||
""" | ||
Validation checks for AnnorMudata scvi-tools compat. | ||
|
||
Parameters | ||
---------- | ||
check_if_view | ||
If True, checks if AnnData is a view. | ||
check_fully_paired_mudata | ||
If True, checks if MuData is fully paired across mods. | ||
""" | ||
|
||
check_if_view: bool = True | ||
check_fully_paired_mudata: bool = True | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@justjhong what do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems fine, the user will never interact with this right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
correct
Fixes #1738
Example on this tutorials PR.
Also fixes an issue with the
to()
to device method after removingJaxModuleWrapper