scvi.train.SemiSupervisedAdversarialTrainingPlanConfig

class scvi.train.SemiSupervisedAdversarialTrainingPlanConfig(key_adversarial='batch', optimizer='Adam', optimizer_creator=None, classification_ratio=50, lr=0.001, weight_decay=1e-06, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', lr_min=0.0, adversarial_classifier=False, scale_adversarial_loss='auto', loss_kwargs=<factory>) [source]

Configuration for `SemiSupervisedAdversarialTrainingPlan`.

Attributes table

Methods table

Attributes

SemiSupervisedAdversarialTrainingPlanConfig.adversarial_classifier: bool | Any = False
SemiSupervisedAdversarialTrainingPlanConfig.classification_ratio: int = 50
SemiSupervisedAdversarialTrainingPlanConfig.key_adversarial: str = 'batch'
SemiSupervisedAdversarialTrainingPlanConfig.lr: float = 0.001
SemiSupervisedAdversarialTrainingPlanConfig.lr_factor: float = 0.6
SemiSupervisedAdversarialTrainingPlanConfig.lr_min: float = 0.0
SemiSupervisedAdversarialTrainingPlanConfig.lr_patience: int = 30
SemiSupervisedAdversarialTrainingPlanConfig.lr_scheduler_metric: Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] = 'elbo_validation'
SemiSupervisedAdversarialTrainingPlanConfig.lr_threshold: float = 0.0
SemiSupervisedAdversarialTrainingPlanConfig.n_epochs_kl_warmup: int | None = 400
SemiSupervisedAdversarialTrainingPlanConfig.n_steps_kl_warmup: int | None = None
SemiSupervisedAdversarialTrainingPlanConfig.optimizer: Literal['Adam', 'AdamW', 'Custom'] = 'Adam'
SemiSupervisedAdversarialTrainingPlanConfig.optimizer_creator: Callable[[Iterable[Tensor]], Optimizer] | None = None
SemiSupervisedAdversarialTrainingPlanConfig.reduce_lr_on_plateau: bool = False
SemiSupervisedAdversarialTrainingPlanConfig.scale_adversarial_loss: Union[float, Literal['auto']] = 'auto'
SemiSupervisedAdversarialTrainingPlanConfig.weight_decay: float = 1e-06
SemiSupervisedAdversarialTrainingPlanConfig.loss_kwargs: dict[str, Any]

Methods

SemiSupervisedAdversarialTrainingPlanConfig.to_kwargs() [source]

Return type: dict[str, Any]