scvi.module.base.BaseModuleClass#
- class scvi.module.base.BaseModuleClass[source]#
Bases:
torch.nn.modules.module.Module
Abstract class for scvi-tools modules.
Attributes table#
Methods table#
| forward | Forward pass through the network. |
| generative | Run the generative model. |
| inference | Run the inference (recognition) model. |
| loss | Compute the loss for a minibatch of data. |
| sample | Generate samples from the learned model. |
Attributes#
T_destination#
- BaseModuleClass.T_destination#
alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])
device#
- BaseModuleClass.device#
dump_patches#
- BaseModuleClass.dump_patches: bool = False#
This allows better BC support for load_state_dict(). In state_dict(), the version number will be saved in the attribute _metadata of the returned state dict, and thus pickled. _metadata is a dictionary with keys that follow the naming convention of state dict. See _load_from_state_dict on how to use this information in loading.
If new parameters/buffers are added/removed from a module, this number shall be bumped, and the module's _load_from_state_dict method can compare the version number and do appropriate changes if the state dict is from before the change.
training#
- BaseModuleClass.training: bool#
Methods#
forward#
- BaseModuleClass.forward(tensors, get_inference_input_kwargs=None, get_generative_input_kwargs=None, inference_kwargs=None, generative_kwargs=None, loss_kwargs=None, compute_loss=True)[source]#
Forward pass through the network.
- Parameters
- tensors
tensors to pass through
- get_inference_input_kwargs : dict | None (default: None)
Keyword args for _get_inference_input()
- get_generative_input_kwargs : dict | None (default: None)
Keyword args for _get_generative_input()
- inference_kwargs : dict | None (default: None)
Keyword args for inference()
- generative_kwargs : dict | None (default: None)
Keyword args for generative()
- loss_kwargs : dict | None (default: None)
Keyword args for loss()
- compute_loss
Whether to compute loss on forward pass. This adds another return value.
- Return type
Tuple[Tensor, Tensor] | Tuple[Tensor, Tensor, LossRecorder]
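To illustrate the dispatch that forward performs, here is a dependency-free sketch: collect inference inputs, run inference, collect generative inputs, run generative, and optionally compute and append the loss. The ToyModule below, its arithmetic, and its dict layouts are hypothetical stand-ins, not the scvi-tools implementation; only the method names (_get_inference_input, inference, _get_generative_input, generative, loss) come from the documentation above, and real modules operate on torch tensors.

```python
class ToyModule:
    """Hypothetical stand-in that mimics the forward() control flow."""

    def _get_inference_input(self, tensors):
        # Pull the fields inference() needs out of the minibatch dict.
        return {"x": tensors["X"]}

    def inference(self, x):
        # Stand-in "encoder": halve each value to produce a latent z.
        return {"z": [v * 0.5 for v in x]}

    def _get_generative_input(self, tensors, inference_outputs):
        # Pull the fields generative() needs from the inference outputs.
        return {"z": inference_outputs["z"]}

    def generative(self, z):
        # Stand-in "decoder": double z back into a reconstruction.
        return {"px": [v * 2.0 for v in z]}

    def loss(self, tensors, inference_outputs, generative_outputs):
        # Squared error between reconstruction and input.
        return sum(
            (p - x) ** 2
            for p, x in zip(generative_outputs["px"], tensors["X"])
        )

    def forward(self, tensors, compute_loss=True):
        inference_outputs = self.inference(**self._get_inference_input(tensors))
        generative_outputs = self.generative(
            **self._get_generative_input(tensors, inference_outputs)
        )
        if compute_loss:
            # compute_loss=True adds a third return value, as documented.
            return (
                inference_outputs,
                generative_outputs,
                self.loss(tensors, inference_outputs, generative_outputs),
            )
        return inference_outputs, generative_outputs
```

Because the stand-in decoder exactly inverts the stand-in encoder, the toy loss is zero; in a real module these are learned networks and the loss drives training.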
generative#
- abstract BaseModuleClass.generative(*args, **kwargs)[source]#
Run the generative model.
This function should return the parameters associated with the likelihood of the data. This is typically written as \(p(x|z)\).
This function should return a dictionary with str keys and Tensor values.
- Return type
Dict[str, Union[Tensor, Distribution]]
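As an illustration of the str-keyed dictionary contract, here is a hypothetical generative() body. The key names (px_scale, px_rate), the library argument, and the list arithmetic are illustrative stand-ins for decoder outputs; real implementations return tensors or distributions.

```python
def generative(z, library):
    """Hypothetical sketch: map a latent z to likelihood parameters p(x|z)."""
    # Stand-in for a decoder: normalize |z| into per-feature proportions.
    px_scale = [abs(v) for v in z]
    total = sum(px_scale) or 1.0
    px_scale = [v / total for v in px_scale]
    # Scale proportions by the observed library size to get rates.
    px_rate = [library * v for v in px_scale]
    # The contract: a dict with str keys.
    return {"px_scale": px_scale, "px_rate": px_rate}
```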
inference#
- abstract BaseModuleClass.inference(*args, **kwargs)[source]#
Run the inference (recognition) model.
In the case of variational inference, this function will perform steps related to computing variational distribution parameters. In a VAE, this will involve running data through encoder networks.
This function should return a dictionary with str keys and Tensor values.
- Return type
Dict[str, Union[Tensor, Distribution]]
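A dependency-free sketch of the VAE case described above: run the data through a stand-in "encoder" to get variational parameters, draw a reparameterized sample, and return everything in a str-keyed dict. The key names (qz_m, qz_v, z) and the arithmetic are illustrative assumptions, not the scvi-tools API.

```python
import random

def inference(x):
    """Hypothetical sketch of a VAE-style inference (recognition) step."""
    # Stand-in encoder: mean is half the input, variance is fixed at 1.
    qz_m = [v * 0.5 for v in x]
    qz_v = [1.0 for _ in x]
    # Reparameterized sample z ~ N(qz_m, qz_v).
    z = [m + (v ** 0.5) * random.gauss(0.0, 1.0) for m, v in zip(qz_m, qz_v)]
    # The contract: a dict with str keys.
    return {"qz_m": qz_m, "qz_v": qz_v, "z": z}
```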
loss#
- abstract BaseModuleClass.loss(*args, **kwargs)[source]#
Compute the loss for a minibatch of data.
This function uses the outputs of the inference and generative functions to compute a loss. This may optionally include other penalty terms, which should be computed here.
This function should return an object of type LossRecorder.
- Return type
LossRecorder
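A sketch of how such a loss might combine a reconstruction term with a KL penalty. SimpleLossRecorder is a hypothetical stand-in for scvi's LossRecorder container, and the dict key names and arithmetic are illustrative assumptions; only the overall shape (inference outputs + generative outputs in, a recorder object out) follows the contract above.

```python
from dataclasses import dataclass

@dataclass
class SimpleLossRecorder:
    """Hypothetical stand-in for scvi's LossRecorder."""
    loss: float
    reconstruction_loss: float
    kl_local: float

def loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0):
    """Hypothetical sketch: reconstruction error plus a weighted KL penalty."""
    x = tensors["X"]
    px = generative_outputs["px_rate"]
    # Stand-in reconstruction term: squared error against the input.
    recon = sum((p - v) ** 2 for p, v in zip(px, x))
    # KL(N(m, 1) || N(0, 1)) = m^2 / 2 per dimension.
    kl = sum(m ** 2 for m in inference_outputs["qz_m"]) / 2.0
    return SimpleLossRecorder(
        loss=recon + kl_weight * kl,
        reconstruction_loss=recon,
        kl_local=kl,
    )
```

Recording the individual terms alongside the total, rather than returning a single scalar, is what lets the training loop log each component separately.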