New in 0.17.0 (2022-07-14)
Major Changes
Experimental MuData support for TOTALVI via the method setup_mudata(). For several of the existing AnnDataField classes, there is now a MuData counterpart with an additional mod_key argument used to indicate the modality where the data lives (e.g. LayerField to MuDataLayerField). These modified classes are simply wrapped versions of the original AnnDataField code via the new scvi.data.fields.MuDataWrapper method #1474.

Modification of the generative() method's outputs to return prior and likelihood properties as Distribution objects. Affected modules are AmortizedLDAPyroModule, AutoZIVAE, MULTIVAE, PEAKVAE, TOTALVAE, SCANVAE, VAE, and VAEC. This makes these distributions easier to manipulate during model training and inference #1356.

Major changes to Jax support for scvi-tools models to generalize beyond JaxSCVI. Support for Jax remains experimental and is subject to breaking changes:

Consistent module interface for Flax modules (Jax-backed) via JaxModuleWrapper, such that they are compatible with the existing BaseModelClass #1506.

JaxTrainingPlan now leverages PyTorch Lightning to factor out the Jax-specific training loop implementation #1506.

Enable basic device management in Jax-backed modules #1585.
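The MuData wrapping pattern described above can be sketched without any dependencies. This is a minimal illustration, not the scvi-tools implementation: plain dicts stand in for AnnData and MuData objects, and the simplified LayerField and the mudata_wrapper helper are hypothetical. It shows the design idea: the wrapped class keeps the original AnnData logic unchanged and only adds a mod_key that selects the modality first.

```python
from typing import Dict, List

# Hypothetical stand-ins: an "AnnData" is a dict of named layers, and a
# "MuData" is a dict mapping modality names to such AnnData-like dicts.
AnnDataLike = Dict[str, List[List[float]]]
MuDataLike = Dict[str, AnnDataLike]


class LayerField:
    """Simplified, hypothetical field that reads one layer from an AnnData-like object."""

    def __init__(self, registry_key: str, layer: str):
        self.registry_key = registry_key
        self.layer = layer

    def get_data(self, adata: AnnDataLike) -> List[List[float]]:
        return adata[self.layer]


def mudata_wrapper(field_cls):
    """Return a subclass of `field_cls` that first selects the modality `mod_key`."""

    class MuDataField(field_cls):
        def __init__(self, *args, mod_key: str, **kwargs):
            super().__init__(*args, **kwargs)
            self.mod_key = mod_key

        def get_data(self, mdata: MuDataLike):
            # Reuse the unchanged AnnData logic on the selected modality.
            return super().get_data(mdata[self.mod_key])

    MuDataField.__name__ = f"MuData{field_cls.__name__}"
    MuDataField.__qualname__ = MuDataField.__name__
    return MuDataField


MuDataLayerField = mudata_wrapper(LayerField)

# Usage: read the "counts" layer of the protein modality.
mdata = {"rna": {"counts": [[1.0, 0.0]]}, "protein": {"counts": [[5.0, 3.0]]}}
protein_field = MuDataLayerField("X", "counts", mod_key="protein")
print(protein_field.get_data(mdata))  # [[5.0, 3.0]]
```

Wrapping rather than reimplementing keeps a single source of truth for each field's behavior, which is the motivation the release note gives for deriving the MuData classes from the original AnnDataField code.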
Minor changes
Add on_load() callback, which is called during load() prior to loading the module state dict #1542.

Refactor metrics code and use MetricCollection to update metrics in bulk #1529.

Add max_kl_weight and min_kl_weight to TrainingPlan #1595.

Add a warning to UnsupervisedTrainingMixin that is raised if max_kl_weight is not reached during training #1595.
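The interplay between min_kl_weight, max_kl_weight, and the new warning can be sketched as a KL warm-up schedule. This sketch assumes a linear ramp over a number of warm-up epochs with 0-indexed epochs; the helper names kl_weight and warn_if_kl_not_reached are hypothetical and are not the TrainingPlan or UnsupervisedTrainingMixin API.

```python
import warnings
from typing import Optional


def kl_weight(
    epoch: int,
    n_epochs_kl_warmup: Optional[int],
    max_kl_weight: float = 1.0,
    min_kl_weight: float = 0.0,
) -> float:
    """Hypothetical linear KL annealing from min_kl_weight to max_kl_weight."""
    if n_epochs_kl_warmup is None or epoch >= n_epochs_kl_warmup:
        return max_kl_weight
    slope = max_kl_weight - min_kl_weight
    return min_kl_weight + slope * epoch / n_epochs_kl_warmup


def warn_if_kl_not_reached(
    max_epochs: int,
    n_epochs_kl_warmup: Optional[int],
    max_kl_weight: float = 1.0,
    min_kl_weight: float = 0.0,
) -> bool:
    """Warn (and return False) if the last epoch ends before max_kl_weight is reached."""
    final = kl_weight(max_epochs - 1, n_epochs_kl_warmup, max_kl_weight, min_kl_weight)
    if final < max_kl_weight:
        warnings.warn(
            f"max_kl_weight={max_kl_weight} is not reached during training "
            f"(final KL weight: {final:.3f}). Consider increasing max_epochs "
            "or decreasing n_epochs_kl_warmup.",
            UserWarning,
        )
        return False
    return True


# Halfway through a 400-epoch warm-up the weight is halfway between min and max.
print(kl_weight(200, 400))  # 0.5
```

Checking the weight at the final epoch directly, rather than comparing epoch counts, keeps the warning consistent with whatever schedule is actually used.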
Breaking changes
Any methods relying on the output of inference and generative from existing scvi-tools models (e.g. SCVI, SCANVI) will need to be modified to accept torch.distributions.Distribution objects rather than tensors for each parameter (e.g. px_m, px_v) #1356.

The signature of compute_and_log_metrics() has changed to support the use of MetricCollection. Typically, this means changing self.compute_and_log_metrics(scvi_loss, self.elbo_train) to self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train"). The same change is necessary for validation metrics, except with self.val_metrics and the mode "validation" #1529.
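The first breaking change above amounts to replacing hand-rolled density code with calls on a distribution object. In scvi-tools these objects are torch.distributions.Distribution instances (queried with methods such as log_prob and properties such as mean and variance); to keep this sketch dependency-free it uses the standard library's statistics.NormalDist as a stand-in. The output dictionaries and their keys are hypothetical illustrations, not the exact outputs of any scvi-tools module.

```python
import math
from statistics import NormalDist

# Hypothetical generative() outputs before and after the change.
old_outputs = {"px_m": 2.0, "px_v": 0.25}  # raw parameters
new_outputs = {"px": NormalDist(mu=2.0, sigma=math.sqrt(0.25))}  # distribution object


def log_likelihood_old(outputs, x):
    # Before: callers rebuilt the Gaussian log-density from raw parameters.
    mean, var = outputs["px_m"], outputs["px_v"]
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_likelihood_new(outputs, x):
    # After: callers query the distribution object directly
    # (with torch, the analogous call would be outputs["px"].log_prob(x)).
    return math.log(outputs["px"].pdf(x))


# Both versions agree; summary statistics now come from the object itself.
print(math.isclose(log_likelihood_old(old_outputs, 1.5), log_likelihood_new(new_outputs, 1.5)))  # True
print(new_outputs["px"].mean)  # 2.0
```

Code that previously read px_m would instead read the distribution's mean, and likelihood terms would come from the distribution rather than from a separate helper.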
Bug Fixes
Fix issue with get_normalized_expression() with multiple samples and additional continuous covariates. This bug originated from generative() failing to match the dimensions of the continuous covariates with the input when n_samples > 1 in inference() in multiple module classes #1548.

Add support for padding layers in prepare_query_anndata(), which is necessary to run load_query_data() for a model set up with a layer instead of X #1575.
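The dimension mismatch behind the first bug fix above (#1548) can be illustrated without torch. When n_samples > 1, the decoder input gains a leading sample dimension (n_samples x n_cells x n_latent) while the continuous covariates stay n_cells x n_cont, so the covariates must be repeated per sample before concatenating along the feature axis. In PyTorch this is commonly done with cont_covs.unsqueeze(0).expand(n_samples, -1, -1) followed by torch.cat; the nested-list version and the helper name concat_continuous_covariates below are hypothetical and do not reproduce the scvi-tools code.

```python
from typing import List

Matrix = List[List[float]]  # n_cells x n_features


def _concat_rows(rows: Matrix, cont_covs: Matrix) -> Matrix:
    # Concatenate each cell's features with that cell's covariates.
    if len(rows) != len(cont_covs):
        raise ValueError(f"{len(rows)} cells in input but {len(cont_covs)} covariate rows")
    return [row + cov for row, cov in zip(rows, cont_covs)]


def concat_continuous_covariates(decoder_input, cont_covs: Matrix):
    """Append continuous covariates to a 2-D (n_samples == 1) or 3-D (n_samples > 1) input."""
    is_3d = bool(decoder_input) and isinstance(decoder_input[0][0], list)
    if not is_3d:
        return _concat_rows(decoder_input, cont_covs)
    # n_samples > 1: reuse the same covariates for every posterior sample.
    return [_concat_rows(sample, cont_covs) for sample in decoder_input]


# Two posterior samples, two cells, one latent dimension, one covariate.
z = [[[0.1], [0.2]], [[0.3], [0.4]]]
covs = [[1.0], [2.0]]
print(concat_continuous_covariates(z, covs))
# [[[0.1, 1.0], [0.2, 2.0]], [[0.3, 1.0], [0.4, 2.0]]]
```

Raising on a cell-count mismatch, instead of letting zip silently truncate, surfaces exactly the kind of shape error this fix addresses.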