New in 0.17.0 (2022-07-14)

Major Changes

  • Experimental MuData support for TOTALVI via the method setup_mudata(). For several of the existing AnnDataField classes, there is now a MuData counterpart with an additional mod_key argument used to indicate the modality where the data lives (e.g. LayerField to MuDataLayerField). These modified classes are simply wrapped versions of the original AnnDataField code via the new method #1474.

  • The generative() method now returns the prior and likelihood terms as Distribution objects. Affected modules are AmortizedLDAPyroModule, AutoZIVAE, MULTIVAE, PEAKVAE, TOTALVAE, SCANVAE, VAE, and VAEC. This makes these distributions easier to manipulate during model training and inference #1356.

  • Major changes to Jax support for scvi-tools models to generalize beyond JaxSCVI. Support for Jax remains experimental and is subject to breaking changes:

    • Consistent module interface for Flax modules (Jax-backed) via JaxModuleWrapper, such that they are compatible with the existing BaseModelClass #1506.

    • JaxTrainingPlan now uses PyTorch Lightning, factoring out the Jax-specific training loop implementation #1506.

    • Enable basic device management in Jax-backed modules #1585.
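
The MuData wrapping described for #1474 can be pictured with a toy, stdlib-only sketch. Plain dicts stand in for AnnData and MuData objects, and LayerFieldSketch, mudata_wrapper, and get_data are hypothetical names, not scvi-tools APIs. The sketch only shows how one extra mod_key argument lets unchanged AnnData-level logic be reused for each modality.

```python
from typing import Any, Dict, Optional


class LayerFieldSketch:
    """Hypothetical AnnData-level field that reads X or one named layer."""

    def __init__(self, registry_key: str, layer: Optional[str] = None) -> None:
        self.registry_key = registry_key
        self.layer = layer

    def get_data(self, adata: Dict[str, Any]) -> Any:
        # layer=None means "use X", mirroring the AnnData convention
        if self.layer is None:
            return adata["X"]
        return adata["layers"][self.layer]


def mudata_wrapper(field_cls: type) -> type:
    """Hypothetical helper: derive a MuData counterpart that accepts a mod_key."""

    class MuDataField(field_cls):
        def __init__(self, *args: Any, mod_key: Optional[str] = None, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            self.mod_key = mod_key

        def get_data(self, mdata: Dict[str, Any]) -> Any:
            # Pick the modality first, then reuse the unchanged AnnData-level logic
            adata = mdata if self.mod_key is None else mdata["mod"][self.mod_key]
            return super().get_data(adata)

    MuDataField.__name__ = "MuData" + field_cls.__name__
    return MuDataField


MuDataLayerFieldSketch = mudata_wrapper(LayerFieldSketch)

# A MuData-like container with RNA and protein modalities
mdata = {
    "mod": {
        "rna": {"X": [[1, 0]], "layers": {"counts": [[5, 2]]}},
        "protein": {"X": [[9, 3]], "layers": {}},
    }
}

rna_field = MuDataLayerFieldSketch("X", layer="counts", mod_key="rna")
protein_field = MuDataLayerFieldSketch("protein_expression", mod_key="protein")
print(rna_field.get_data(mdata))      # [[5, 2]]
print(protein_field.get_data(mdata))  # [[9, 3]]
```

Deriving the MuData class by subclassing keeps the AnnData-level code in one place, which matches the description of the MuData fields as wrapped versions of the originals.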

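A consistent interface around a functional module is easiest to see without Jax installed. In the sketch below, a pure function plays the role of a Flax module's apply, and the hypothetical FunctionalModuleWrapper keeps the parameters and a device label next to it so callers get torch.nn.Module-style train(), eval(), to(), state_dict(), and load_state_dict() methods. Those method names are an assumption chosen to mirror PyTorch; this is not the actual JaxModuleWrapper API, and no real device transfer takes place.

```python
from typing import Callable, Dict, List


def linear_apply(params: Dict[str, float], x: List[float], training: bool) -> List[float]:
    """Stand-in for a Flax module's apply(): parameters are passed in explicitly."""
    scale = params["w"] * (0.5 if training else 1.0)  # toy train/eval difference
    return [scale * xi + params["b"] for xi in x]


class FunctionalModuleWrapper:
    """Hypothetical wrapper giving a stateless module a stateful, torch-like interface."""

    def __init__(self, apply_fn: Callable, params: Dict[str, float], device: str = "cpu") -> None:
        self.apply_fn = apply_fn
        self.params = dict(params)
        self.device = device
        self.training = True

    def train(self) -> "FunctionalModuleWrapper":
        self.training = True
        return self

    def eval(self) -> "FunctionalModuleWrapper":
        self.training = False
        return self

    def to(self, device: str) -> "FunctionalModuleWrapper":
        # A real Jax wrapper would move arrays (e.g. with jax.device_put); here we only record it
        self.device = device
        return self

    def state_dict(self) -> Dict[str, float]:
        return dict(self.params)  # copy, so callers cannot mutate the live parameters

    def load_state_dict(self, state: Dict[str, float]) -> None:
        self.params = dict(state)

    def __call__(self, x: List[float]) -> List[float]:
        return self.apply_fn(self.params, x, self.training)


module = FunctionalModuleWrapper(linear_apply, {"w": 2.0, "b": 1.0})
print(module([1.0, 2.0]))         # training mode: [2.0, 3.0]
print(module.eval()([1.0, 2.0]))  # eval mode: [3.0, 5.0]
```
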
Minor changes

  • Add an on_load() callback that load() calls before loading the module state dict #1542.

  • Refactor metrics code and use MetricCollection to update metrics in bulk #1529.

  • Add max_kl_weight and min_kl_weight to TrainingPlan #1595.

  • Add a warning to UnsupervisedTrainingMixin that is raised if max_kl_weight is not reached during training #1595.
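
The ordering of the on_load() hook (hook first, then the state dict) can be sketched in plain Python. BaseModelSketch, its module_state attribute, and the extra_buffer key are hypothetical; the real load() also handles saved files, data setup, and devices, all omitted here.

```python
from typing import Any, Dict, List


class BaseModelSketch:
    """Hypothetical model class showing where on_load() runs inside load()."""

    def __init__(self) -> None:
        self.module_state: Dict[str, Any] = {"weight": 0.0}
        self.events: List[str] = []

    def on_load(self, saved_state: Dict[str, Any]) -> None:
        """Hook called before the state dict is loaded; a no-op by default."""

    @classmethod
    def load(cls, saved_state: Dict[str, Any]) -> "BaseModelSketch":
        model = cls()
        model.on_load(saved_state)  # 1. let subclasses prepare the module
        unexpected = set(saved_state) - set(model.module_state)
        if unexpected:
            # 2a. mimic a strict state-dict load that rejects unknown keys
            raise KeyError(f"unexpected keys: {sorted(unexpected)}")
        model.module_state.update(saved_state)  # 2b. load the saved parameters
        model.events.append("state loaded")
        return model


class ModelWithExtraBuffer(BaseModelSketch):
    """Subclass whose saved state contains an entry the base module lacks."""

    def on_load(self, saved_state: Dict[str, Any]) -> None:
        self.module_state.setdefault("extra_buffer", None)
        self.events.append("on_load called")


model = ModelWithExtraBuffer.load({"weight": 1.5, "extra_buffer": [1, 2]})
print(model.events)  # ['on_load called', 'state loaded']
```

Without the hook, the same saved state would be rejected by the strict load, which is the kind of situation a pre-load callback is meant to handle.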

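One way min_kl_weight and max_kl_weight can interact with a warmup period is a linear ramp from the minimum to the maximum, plus a warning when training ends before the ramp finishes. The sketch below assumes a linear per-epoch schedule; kl_weight, check_kl_warmup, and the n_epochs_kl_warmup parameter are illustrative stand-ins, not the TrainingPlan or UnsupervisedTrainingMixin source.

```python
import warnings


def kl_weight(epoch: int, n_epochs_kl_warmup: int,
              min_kl_weight: float = 0.0, max_kl_weight: float = 1.0) -> float:
    """Assumed linear ramp from min_kl_weight (epoch 0) to max_kl_weight."""
    if n_epochs_kl_warmup <= 0 or epoch >= n_epochs_kl_warmup:
        return max_kl_weight
    slope = (max_kl_weight - min_kl_weight) / n_epochs_kl_warmup
    return min_kl_weight + slope * epoch


def check_kl_warmup(max_epochs: int, n_epochs_kl_warmup: int) -> None:
    """Warn when training ends before the ramp reaches max_kl_weight."""
    # Epochs run from 0 to max_epochs - 1 and the ramp peaks at epoch
    # n_epochs_kl_warmup, so the maximum is missed if max_epochs <= warmup.
    if 0 < n_epochs_kl_warmup and max_epochs <= n_epochs_kl_warmup:
        warnings.warn(
            f"max_epochs={max_epochs} ends before the {n_epochs_kl_warmup}-epoch "
            "KL warmup finishes; max_kl_weight is never reached.",
            UserWarning,
        )


weights = [kl_weight(e, n_epochs_kl_warmup=2, min_kl_weight=0.5) for e in range(4)]
print(weights)  # [0.5, 0.75, 1.0, 1.0]
```
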
Breaking changes

  • Any methods relying on the outputs of inference() and generative() from existing scvi-tools models (e.g. SCVI, SCANVI) must be modified to accept torch.distributions.Distribution objects rather than tensors for each parameter (e.g. px_m, px_v) #1356.

  • The signature of compute_and_log_metrics() has changed to support MetricCollection. Typically, this means changing self.compute_and_log_metrics(scvi_loss, self.elbo_train) to self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train"). Validation metrics need the same change, using self.val_metrics and the mode "validation" #1529.
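
The switch from per-parameter tensors to distribution objects changes how downstream code reads generative() outputs. In this stdlib-only sketch, statistics.NormalDist stands in for a torch.distributions.Distribution, and toy_generative with its pz/px keys is hypothetical. The point is that consumers read parameters and densities off the returned object instead of unpacking tensors such as px_m and px_v; with torch, the density step would use the distribution's log_prob method.

```python
import math
from statistics import NormalDist
from typing import Dict


def toy_generative(z: float) -> Dict[str, NormalDist]:
    """Hypothetical generative step returning distributions instead of raw parameters."""
    return {
        "pz": NormalDist(mu=0.0, sigma=1.0),      # prior over the latent z
        "px": NormalDist(mu=2.0 * z, sigma=0.5),  # likelihood of x given z
    }


def reconstruction_loss(x: float, outputs: Dict[str, NormalDist]) -> float:
    """Negative log-likelihood computed from the distribution object itself."""
    return -math.log(outputs["px"].pdf(x))


# Previously, code unpacked parameter tensors such as px_m and px_v.
# Now the parameters are attributes of the returned distribution.
outputs = toy_generative(z=1.0)
px = outputs["px"]
print(px.mean, px.variance)  # 2.0 0.25
loss = reconstruction_loss(2.0, outputs)
```
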

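The new compute_and_log_metrics(loss, metrics, mode) call passes a whole metric collection plus a mode string instead of a single metric. The sketch below imitates that bulk-update pattern with hypothetical stdlib classes (MeanMetric, MetricCollectionSketch, LossOutputSketch) rather than torchmetrics' MetricCollection. The elbo/reconstruction_loss names, the "{name}_{mode}" log-key format, and the extra log dict argument (standing in for the Lightning logger) are all assumptions.

```python
from dataclasses import dataclass
from typing import Dict


class MeanMetric:
    """Hypothetical running-mean metric with a torchmetrics-like update/compute."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        return self.total / self.count if self.count else float("nan")


class MetricCollectionSketch:
    """Holds several metrics so they can be updated in one call."""

    def __init__(self, metrics: Dict[str, MeanMetric]) -> None:
        self.metrics = metrics

    def update(self, **values: float) -> None:
        for name, metric in self.metrics.items():
            metric.update(values[name])

    def compute(self) -> Dict[str, float]:
        return {name: metric.compute() for name, metric in self.metrics.items()}


@dataclass
class LossOutputSketch:
    elbo: float
    reconstruction_loss: float


def compute_and_log_metrics(loss: LossOutputSketch, metrics: MetricCollectionSketch,
                            mode: str, log: Dict[str, float]) -> None:
    # One bulk update for every metric, then log each under a mode-specific key
    metrics.update(elbo=loss.elbo, reconstruction_loss=loss.reconstruction_loss)
    for name, value in metrics.compute().items():
        log[f"{name}_{mode}"] = value


train_metrics = MetricCollectionSketch({"elbo": MeanMetric(), "reconstruction_loss": MeanMetric()})
logged: Dict[str, float] = {}
compute_and_log_metrics(LossOutputSketch(10.0, 8.0), train_metrics, "train", logged)
compute_and_log_metrics(LossOutputSketch(6.0, 4.0), train_metrics, "train", logged)
print(logged)  # {'elbo_train': 8.0, 'reconstruction_loss_train': 6.0}
```

A separate collection for validation keeps the running means for the two modes independent, which is why the mode string travels with the collection.
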
Bug Fixes