scvi.external.tangram.TangramMapper

class scvi.external.tangram.TangramMapper(n_obs_sc, n_obs_sp, lambda_g1=1.0, lambda_d=0.0, lambda_g2=0.0, lambda_r=0.0, lambda_count=1.0, lambda_f_reg=1.0, constrained=False, target_count=None, training=True, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: JaxBaseModuleClass

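Tangram mapper module. Learns a mapping of n_obs_sc single-cell observations onto n_obs_sp spatial observations by minimizing a weighted loss whose terms are controlled by the lambda_* attributes.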
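A minimal NumPy sketch of how such a mapping can be represented, assuming the parameterization of the original Tangram method: an unconstrained score matrix of shape (n_obs_sc, n_obs_sp) is passed through a row-wise softmax, so each single-cell observation gets a probability distribution over spatial observations, and predicted spatial expression is the transposed mapping times the single-cell expression matrix. The function names are hypothetical, and the real module is written in JAX rather than NumPy.

```python
import numpy as np


def softmax_rows(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax: each row becomes a probability distribution."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def predict_spatial(logits: np.ndarray, sc_x: np.ndarray) -> np.ndarray:
    """Project single-cell expression onto spatial observations.

    logits: (n_obs_sc, n_obs_sp) unconstrained mapping scores (hypothetical name)
    sc_x:   (n_obs_sc, n_genes) single-cell expression
    returns (n_obs_sp, n_genes) predicted spatial expression, mapping.T @ sc_x
    """
    mapping = softmax_rows(logits)
    return mapping.T @ sc_x


rng = np.random.default_rng(0)
n_obs_sc, n_obs_sp, n_genes = 6, 4, 3
logits = rng.normal(size=(n_obs_sc, n_obs_sp))
sc_x = rng.random((n_obs_sc, n_genes))

mapping = softmax_rows(logits)
g_pred = predict_spatial(logits, sc_x)
print(mapping.sum(axis=1))  # each row sums to 1 (up to floating-point error)
print(g_pred.shape)         # (4, 3)
```

The softmax keeps every entry positive and each row normalized, so the optimizer can work on unconstrained scores while the mapping stays a valid set of per-cell distributions.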

Attributes table

Methods table

generative()

No generative model here.

inference()

Run inference model.

loss(tensors, inference_outputs, ...)

Compute loss.

setup()

Set up model.

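Attributes

TangramMapper.constrained: bool = False

Whether to also learn a per-cell filter that selects a subset of the single-cell observations (constrained mode).

TangramMapper.lambda_count: float = 1.0

Weight of the count term, which penalizes deviation of the number of selected cells from target_count. Used only when constrained is True.

TangramMapper.lambda_d: float = 0.0

Weight of the density term, a KL divergence between the target and predicted spatial cell densities.

TangramMapper.lambda_f_reg: float = 1.0

Weight of the filter regularizer, which pushes filter values toward 0 or 1. Used only when constrained is True.

TangramMapper.lambda_g1: float = 1.0

Weight of the gene-wise cosine-similarity term between predicted and observed spatial expression.

TangramMapper.lambda_g2: float = 0.0

Weight of the spot-wise (voxel-wise) cosine-similarity term between predicted and observed spatial expression.

TangramMapper.lambda_r: float = 0.0

Weight of the entropy regularizer on the mapping.

TangramMapper.name: Optional[str] = None

Inherited from flax.linen.Module.

TangramMapper.parent: Union[Module, Scope, _Sentinel, None] = None

Inherited from flax.linen.Module.

TangramMapper.required_rngs

Names of the random-number streams required by the module.

TangramMapper.scope: Scope | None = None

Inherited from flax.linen.Module.

TangramMapper.target_count: int | None = None

Number of single-cell observations the filter should select. Required when constrained is True.

TangramMapper.training: bool = True

Whether the module is in training mode.

TangramMapper.n_obs_sc: int

Number of single-cell observations.

TangramMapper.n_obs_sp: int

Number of spatial observations.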
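When constrained is True, target_count, lambda_count and lambda_f_reg govern a per-cell filter. Below is a minimal NumPy sketch of the two resulting penalties, assuming the formulation of the original Tangram method: a sigmoid filter, an absolute-difference count penalty, and an f - f**2 penalty that favors binary filter values. The function and argument names are hypothetical; this is not the scvi-tools code.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def constrained_terms(filter_logits, target_count, lambda_count=1.0, lambda_f_reg=1.0):
    """Count and filter-regularization penalties for constrained mode (sketch).

    filter_logits: (n_obs_sc,) unconstrained per-cell filter scores (hypothetical name)
    """
    f = sigmoid(filter_logits)  # per-cell selection weight in (0, 1)
    # Penalize the gap between the soft number of selected cells and target_count.
    count_term = lambda_count * abs(f.sum() - target_count)
    # f * (1 - f) is 0 when f is 0 or 1 and peaks at 0.25 when f = 0.5,
    # so this term pushes the filter toward binary values.
    f_reg = lambda_f_reg * (f - f**2).sum()
    return float(count_term), float(f_reg)


undecided = constrained_terms(np.zeros(4), target_count=2)
decided = constrained_terms(np.array([20.0, 20.0, -20.0, -20.0]), target_count=2)
print(undecided)  # (0.0, 1.0)
print(decided)    # both terms close to 0
```

An undecided filter (every weight 0.5) can already hit the target count in expectation, which is why the separate binarization penalty is needed to force a clear selection.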

Methods

TangramMapper.generative()

No generative model is used; this method exists to satisfy the module interface.

Return type:

dict

TangramMapper.inference()

Run the inference step. The mapping is stored directly as module parameters, so no encoder is evaluated.

Return type:

dict
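The inference and generative methods return dictionaries that are handed to loss. Because the mapper has no encoder or decoder, both can be empty and the loss does all the work. The plain-Python sketch below mirrors that call order; the class MapperSketch, the helper forward and the toy objective are hypothetical, and the real scvi-tools base class adds input-preparation steps between these calls.

```python
from typing import Any


class MapperSketch:
    """Hypothetical stand-in for a module with no encoder or decoder."""

    def inference(self) -> dict:
        return {}  # nothing to infer: the mapping is a learned parameter

    def generative(self) -> dict:
        return {}  # no generative model

    def loss(self, tensors: dict, inference_outputs: dict, generative_outputs: dict) -> float:
        # Toy objective for illustration only: sum of all tensor entries.
        return float(sum(sum(values) for values in tensors.values()))


def forward(module: MapperSketch, tensors: dict[str, Any]) -> float:
    """Chain inference -> generative -> loss, as a training step would."""
    inference_outputs = module.inference()
    generative_outputs = module.generative()
    return module.loss(tensors, inference_outputs, generative_outputs)


print(forward(MapperSketch(), {"sc": [1.0, 2.0], "sp": [3.0]}))  # 6.0
```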

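TangramMapper.loss(tensors, inference_outputs, generative_outputs)

Compute the Tangram loss: the weighted expression-similarity, density and regularization terms selected by the lambda_* attributes.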
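A NumPy sketch of the unconstrained objective, assuming the formulation of the original Tangram method: cosine similarities between predicted and observed spatial expression are maximized, an entropy term sharpens the mapping, and a KL divergence matches the predicted cell density to a target. Function names, argument order and the exact reductions are assumptions; the scvi-tools method instead reads its inputs from tensors.

```python
import numpy as np


def cosine_similarity(a, b, axis, eps=1e-8):
    """Cosine similarity between matching slices of a and b along `axis`."""
    num = (a * b).sum(axis=axis)
    denom = np.linalg.norm(a, axis=axis) * np.linalg.norm(b, axis=axis)
    return num / np.maximum(denom, eps)


def tangram_loss(mapping, sc_x, sp_x, density=None,
                 lambda_g1=1.0, lambda_d=0.0, lambda_g2=0.0, lambda_r=0.0):
    """Unconstrained Tangram-style objective (sketch; lower is better).

    mapping: (n_obs_sc, n_obs_sp), strictly positive, rows sum to 1
    sc_x:    (n_obs_sc, n_genes) single-cell expression
    sp_x:    (n_obs_sp, n_genes) spatial expression
    density: (n_obs_sp,) target cell density, sums to 1
    """
    g_pred = mapping.T @ sc_x  # predicted spatial expression
    loss = 0.0
    if lambda_g1 > 0:
        # Gene-wise: compare each gene's profile across spots (columns).
        loss -= lambda_g1 * cosine_similarity(sp_x, g_pred, axis=0).mean()
    if lambda_g2 > 0:
        # Spot-wise: compare each spot's expression vector (rows).
        loss -= lambda_g2 * cosine_similarity(sp_x, g_pred, axis=1).mean()
    if lambda_r > 0:
        # -(M * log M).sum() is the entropy, so this adds lambda_r * entropy
        # and minimizing the loss favors peaked rows.
        loss -= lambda_r * (mapping * np.log(mapping)).sum()
    if lambda_d > 0:
        # KL(density || predicted density), skipping zero-density spots.
        d_pred = mapping.sum(axis=0) / mapping.shape[0]
        mask = density > 0
        loss += lambda_d * (density[mask] * (np.log(density[mask]) - np.log(d_pred[mask]))).sum()
    return float(loss)


rng = np.random.default_rng(0)
logits = rng.normal(size=(6, 4))
mapping = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
sc_x = rng.random((6, 3))
sp_x = mapping.T @ sc_x  # spatial data this mapping reproduces exactly
print(tangram_loss(mapping, sc_x, sp_x))  # about -1.0: perfect gene-wise match
```

With only lambda_g1 active, the loss is bounded below by -lambda_g1, reached when every gene's predicted spatial profile points in the same direction as the observed one.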

TangramMapper.setup()

Set up the model parameters (the Flax Module.setup hook).
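In a Flax module, setup is where parameters are declared. The framework-free sketch below shows what that parameter set plausibly looks like, assuming standard-normal initialization, an (n_obs_sc, n_obs_sp) mapper score matrix, an (n_obs_sc, 1) filter score vector only in constrained mode, and a check that target_count is provided when constrained is True. The function init_params and its dictionary keys are hypothetical.

```python
import numpy as np


def init_params(n_obs_sc, n_obs_sp, constrained=False, target_count=None, seed=0):
    """Initialize unconstrained mapper (and optional filter) scores (sketch)."""
    if constrained and target_count is None:
        raise ValueError("target_count must be set when constrained=True")
    rng = np.random.default_rng(seed)
    # One score per (cell, spot) pair; a row-wise softmax turns these into the mapping.
    params = {"mapper_logits": rng.normal(size=(n_obs_sc, n_obs_sp))}
    if constrained:
        # One score per cell; a sigmoid turns these into filter weights.
        params["filter_logits"] = rng.normal(size=(n_obs_sc, 1))
    return params


params = init_params(100, 20, constrained=True, target_count=50)
print({k: v.shape for k, v in params.items()})
# {'mapper_logits': (100, 20), 'filter_logits': (100, 1)}
```

Storing unconstrained scores and applying softmax or sigmoid at loss time lets a plain gradient optimizer update the parameters without any projection step.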