scvi.train.JaxTrainingPlanConfig

class scvi.train.JaxTrainingPlanConfig(optimizer='Adam', optimizer_creator=None, lr=0.001, weight_decay=1e-06, eps=0.01, max_norm=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, loss_kwargs=<factory>)

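Configuration for JaxTrainingPlan: the optimizer and its hyperparameters, gradient clipping, the KL warmup schedule, and extra keyword arguments for the loss.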
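To make the rendered signature concrete, here is a minimal stand-in written with the standard-library `dataclasses` module. It is a hypothetical mirror of the documented fields and defaults, not scvi-tools' source. The `<factory>` shown for `loss_kwargs` is how Python renders a field declared with `field(default_factory=...)`, so every instance gets its own empty dict. The check that `optimizer='Custom'` comes with an `optimizer_creator` is an assumption added for illustration.

```python
from __future__ import annotations

import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, Literal


@dataclass
class JaxTrainingPlanConfigSketch:
    """Hypothetical stand-in mirroring the documented fields; not scvi-tools' code."""

    optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam"
    optimizer_creator: Callable[[], Any] | None = None
    lr: float = 0.001
    weight_decay: float = 1e-06
    eps: float = 0.01
    max_norm: float | None = None
    n_steps_kl_warmup: int | None = None
    n_epochs_kl_warmup: int | None = 400
    # default_factory gives each instance a fresh dict; signatures render it as <factory>
    loss_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Assumed validation for illustration: 'Custom' needs a creator callable.
        if self.optimizer == "Custom" and self.optimizer_creator is None:
            raise ValueError("optimizer='Custom' requires an optimizer_creator")


if __name__ == "__main__":
    # Prints the generated __init__ signature, ending in loss_kwargs ... = <factory>
    print(inspect.signature(JaxTrainingPlanConfigSketch))
```

Using a default factory rather than a literal `{}` default avoids sharing one mutable dict across all config instances; `dataclasses` rejects a plain mutable default for exactly that reason.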

Attributes

JaxTrainingPlanConfig.eps: float = 0.01
JaxTrainingPlanConfig.lr: float = 0.001
JaxTrainingPlanConfig.max_norm: float | None = None
JaxTrainingPlanConfig.n_epochs_kl_warmup: int | None = 400
JaxTrainingPlanConfig.n_steps_kl_warmup: int | None = None
JaxTrainingPlanConfig.optimizer: Literal['Adam', 'AdamW', 'Custom'] = 'Adam'
JaxTrainingPlanConfig.optimizer_creator: Callable[[], Any] | None = None
JaxTrainingPlanConfig.weight_decay: float = 1e-06
JaxTrainingPlanConfig.loss_kwargs: dict[str, Any]
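`n_epochs_kl_warmup` and `n_steps_kl_warmup` control how the weight on the KL term is annealed during training. The sketch below assumes the behavior described for scvi-tools' training plans: a linear ramp from 0 to 1, with the epoch-based setting taking precedence when both are set, and a constant weight of 1 when neither is set. The function `kl_weight` is hypothetical and not part of scvi-tools' API.

```python
from __future__ import annotations


def kl_weight(
    epoch: int,
    step: int,
    n_epochs_kl_warmup: int | None = 400,
    n_steps_kl_warmup: int | None = None,
) -> float:
    """Hypothetical linear KL warmup; the epoch setting overrides the step setting."""
    if n_epochs_kl_warmup:
        # Ramp linearly over epochs, then hold at 1.0.
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup:
        # Only consulted when epoch-based warmup is disabled (None or 0).
        return min(1.0, step / n_steps_kl_warmup)
    # No warmup configured: full KL weight from the start.
    return 1.0
```

With the default `n_epochs_kl_warmup=400`, the weight is 0.25 at epoch 100 and reaches 1.0 at epoch 400. To warm up per minibatch instead, set `n_epochs_kl_warmup=None` and give `n_steps_kl_warmup` a value.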

Methods

JaxTrainingPlanConfig.to_kwargs()
Return the configuration as a dictionary of keyword arguments.

Return type:

dict[str, Any]
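`to_kwargs()` returns a `dict[str, Any]`. A plausible reading, assumed here rather than taken from scvi-tools' source, is that it maps each field name to its value so the result can be unpacked with `**` into a training plan's keyword arguments. The sketch uses `dataclasses.asdict`; `PlanConfigSketch`, `build_plan`, and the `"beta"` loss argument are all hypothetical.

```python
from __future__ import annotations

from dataclasses import asdict, dataclass, field
from typing import Any


@dataclass
class PlanConfigSketch:
    """Hypothetical subset of the documented fields, for illustration only."""

    lr: float = 0.001
    weight_decay: float = 1e-06
    n_epochs_kl_warmup: int | None = 400
    loss_kwargs: dict[str, Any] = field(default_factory=dict)

    def to_kwargs(self) -> dict[str, Any]:
        # Assumed behavior: field name -> value. asdict also copies nested dicts,
        # so mutating the result does not change this config.
        return asdict(self)


def build_plan(
    lr: float,
    weight_decay: float,
    n_epochs_kl_warmup: int | None,
    loss_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Hypothetical stand-in for a constructor that accepts these keyword arguments."""
    return {
        "lr": lr,
        "weight_decay": weight_decay,
        "n_epochs_kl_warmup": n_epochs_kl_warmup,
        "loss_kwargs": loss_kwargs,
    }


config = PlanConfigSketch(lr=5e-4, loss_kwargs={"beta": 2.0})  # "beta" is made up
kwargs = config.to_kwargs()
plan = build_plan(**kwargs)  # unpack the config into keyword arguments
```

Keeping the settings in one config object and expanding them only at construction time lets the same configuration be reused or serialized without repeating every keyword argument.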