scvi.external.tangram.TangramMapper

class scvi.external.tangram.TangramMapper(n_obs_sc, n_obs_sp, lambda_g1=1.0, lambda_d=0.0, lambda_g2=0.0, lambda_r=0.0, lambda_count=1.0, lambda_f_reg=1.0, constrained=False, target_count=None, training=True, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: JaxBaseModuleClass

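Tangram mapper model. Learns a mapping of the n_obs_sc single-cell observations onto the n_obs_sp spatial observations, following Tangram (Biancalani et al., 2021).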
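As a rough illustration of the core Tangram idea (a NumPy sketch, not the scvi-tools implementation, which is written in JAX/Flax), assume a learnable logit matrix of shape (n_obs_sc, n_obs_sp) whose rows are softmax-normalized into per-cell probability distributions over spots. Projecting single-cell expression through this matrix gives predicted spatial expression. The toy sizes, random data, and variable names are assumptions made for illustration.

```python
import numpy as np


def softmax(x, axis):
    # Subtract the max along `axis` for numerical stability.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
n_obs_sc, n_obs_sp, n_genes = 5, 3, 4  # toy sizes (assumed)

# Hypothetical learnable logits: one row per cell, one column per spot.
logits = rng.normal(size=(n_obs_sc, n_obs_sp))
# Row-wise softmax turns each cell's row into a distribution over spots.
mapper = softmax(logits, axis=1)

# Toy single-cell expression matrix (cells x genes).
sc = rng.poisson(2.0, size=(n_obs_sc, n_genes)).astype(float)
# Predicted spatial expression (spots x genes) is mapper^T @ sc.
sp_pred = mapper.T @ sc

print(mapper.sum(axis=1))  # approximately 1.0 for every cell
print(sp_pred.shape)       # (3, 4)
```

Because every row of the mapping sums to 1, the projection only redistributes each cell's expression across spots, so sp_pred.sum() equals sc.sum() up to floating-point error.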

Attributes table

constrained

lambda_count

lambda_d

lambda_f_reg

lambda_g1

lambda_g2

lambda_r

name

parent

required_rngs

Returns a tuple of rng sequence names required for this Flax module.

scope

target_count

training

n_obs_sc

n_obs_sp

Methods table

generative()

No generative model here.

inference()

Run inference model.

loss(tensors, inference_outputs, ...)

Compute loss.

setup()

Set up the model.

Attributes

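TangramMapper.constrained: bool = False
Whether to use constrained mode, in which a learned per-cell filter selects roughly target_count cells to map.

TangramMapper.lambda_count: float = 1.0
Weight of the count term, which penalizes the difference between the number of cells selected by the filter and target_count. Only relevant when constrained is True.

TangramMapper.lambda_d: float = 0.0
Weight of the density term, a KL divergence between a target spot density and the predicted density. Contributes nothing at the default of 0.0.

TangramMapper.lambda_f_reg: float = 1.0
Weight of the filter regularizer, which pushes filter values toward 0 or 1. Only relevant when constrained is True.

TangramMapper.lambda_g1: float = 1.0
Weight of the gene-wise cosine similarity between predicted and measured spatial expression.

TangramMapper.lambda_g2: float = 0.0
Weight of the spot-wise (voxel-gene) cosine similarity. Contributes nothing at the default of 0.0.

TangramMapper.lambda_r: float = 0.0
Weight of the entropy regularizer on the mapping. Contributes nothing at the default of 0.0.

TangramMapper.name: Optional[str] = None
Flax module name.

TangramMapper.parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
Flax parent module or scope.

TangramMapper.required_rngs
Tuple of RNG sequence names required for this Flax module.

TangramMapper.scope: Optional[Scope] = None
Flax variable scope.

TangramMapper.target_count: Optional[int] = None
Number of cells the filter should select in constrained mode.

TangramMapper.training: bool = True
Whether the module is in training mode.

TangramMapper.n_obs_sc: int
Number of single-cell observations to map.

TangramMapper.n_obs_sp: int
Number of spatial observations (spots) to map onto.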
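The constrained-mode attributes (constrained, target_count, lambda_count, lambda_f_reg) are the least self-explanatory. The NumPy sketch below assumes, as in the Tangram paper, a per-cell filter obtained by passing learnable logits through a sigmoid; it computes the count penalty and the filter regularizer that these weights scale. The function name constrained_terms is hypothetical and not part of scvi-tools.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def constrained_terms(filter_logits, target_count, lambda_count=1.0, lambda_f_reg=1.0):
    """Count and filter-regularization penalties (hypothetical helper)."""
    f = sigmoid(filter_logits)  # per-cell inclusion value in (0, 1)
    # Penalize deviation of the (soft) number of selected cells from target_count.
    count_term = lambda_count * abs(f.sum() - target_count)
    # f - f**2 is 0 at f in {0, 1} and peaks at 0.25 when f = 0.5,
    # so this term pushes the filter toward Boolean values.
    f_reg = lambda_f_reg * (f - f**2).sum()
    return count_term, f_reg


# Saturated logits: 3 cells "on", 2 cells "off", matching target_count=3.
count_term, f_reg = constrained_terms(np.array([20.0, 20.0, 20.0, -20.0, -20.0]), target_count=3)
print(count_term, f_reg)  # both close to 0

# Undecided filter: every cell sits at 0.5.
count_half, f_reg_half = constrained_terms(np.zeros(5), target_count=3)
print(count_half, f_reg_half)  # 0.5 1.25
```

The second call shows why both terms are needed: an all-0.5 filter is penalized by the regularizer even though its soft count (2.5) is close to the target.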

Methods

TangramMapper.generative()

No generative model here.

Return type:

dict

TangramMapper.inference()

Run inference model.

Return type:

dict

TangramMapper.loss(tensors, inference_outputs, generative_outputs)

Compute loss.
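The loss combines the lambda weights listed under Attributes. The following NumPy sketch of the unconstrained case follows the Tangram objective (Biancalani et al., 2021) and is an illustration, not the scvi-tools source. It assumes the lambda_g1 term averages cosine similarities over genes, the lambda_g2 term averages them over spots, lambda_r scales the entropy of the mapping, and lambda_d scales a KL divergence between a target spot density and the fraction of cells mapped to each spot. The helper names (softmax, cosine_similarity, tangram_loss) are hypothetical.

```python
import numpy as np


def softmax(x, axis):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def cosine_similarity(a, b, axis, eps=1e-8):
    # Cosine similarity between matching slices of a and b, reducing over `axis`.
    num = (a * b).sum(axis=axis)
    denom = np.linalg.norm(a, axis=axis) * np.linalg.norm(b, axis=axis)
    return num / np.maximum(denom, eps)


def tangram_loss(logits, sc, sp, density=None,
                 lambda_g1=1.0, lambda_g2=0.0, lambda_r=0.0, lambda_d=0.0):
    """Unconstrained Tangram-style objective; lower is better (hypothetical helper)."""
    mapper = softmax(logits, axis=1)  # cells x spots, each row sums to 1
    sp_pred = mapper.T @ sc           # predicted spots x genes
    # Gene-wise term: compare each gene's profile across spots, then average.
    gv = cosine_similarity(sp, sp_pred, axis=0).mean()
    # Spot-wise term: compare each spot's profile across genes, then average.
    vg = cosine_similarity(sp, sp_pred, axis=1).mean()
    # Mapping entropy; adding it to the loss favors peaked cell assignments.
    entropy = -(mapper * np.log(mapper)).sum()
    loss = -lambda_g1 * gv - lambda_g2 * vg + lambda_r * entropy
    if lambda_d > 0 and density is not None:
        # KL(density || d_pred), where d_pred is the fraction of cells per spot.
        d_pred = mapper.sum(axis=0) / mapper.shape[0]
        loss += lambda_d * np.sum(density * (np.log(density + 1e-10) - np.log(d_pred)))
    return loss


rng = np.random.default_rng(0)
sc = rng.random((6, 4))                # toy data: 6 cells x 4 genes
assign = np.array([0, 1, 2, 0, 1, 2])  # ground-truth spot of each cell
onehot = np.eye(3)[assign]             # cells x spots
sp = onehot.T @ sc                     # spatial data built from that assignment

loss_sharp = tangram_loss(50.0 * onehot, sc, sp)  # near-perfect mapping
loss_random = tangram_loss(rng.normal(size=(6, 3)), sc, sp)
print(loss_sharp, loss_random)  # loss_sharp is close to -1 and lower than loss_random
```

With the default weights only the gene-wise similarity term is active; a weight of 0.0 removes its term entirely, so lambda_g2, lambda_r, and lambda_d change the objective only when set to positive values.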

TangramMapper.setup()

Set up the model.
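In Flax, setup() is the hook where a module declares its parameters and submodules. A minimal NumPy sketch of what this module plausibly initializes, assuming one logit per (cell, spot) pair and, in constrained mode, one filter logit per cell. The parameter names, the helper init_mapper_params, and the standard-normal initialization are assumptions, not confirmed by this page.

```python
import numpy as np


def init_mapper_params(n_obs_sc, n_obs_sp, constrained=False, seed=0):
    """Hypothetical parameter initialization sized by n_obs_sc and n_obs_sp."""
    rng = np.random.default_rng(seed)
    # One unconstrained logit per (cell, spot) pair; a row-wise softmax is applied later.
    params = {"mapper_unconstrained": rng.standard_normal((n_obs_sc, n_obs_sp))}
    if constrained:
        # One logit per cell for the filter; a sigmoid maps it into (0, 1).
        params["filter_unconstrained"] = rng.standard_normal((n_obs_sc, 1))
    return params


params = init_mapper_params(n_obs_sc=100, n_obs_sp=20, constrained=True)
print({k: v.shape for k, v in params.items()})
# {'mapper_unconstrained': (100, 20), 'filter_unconstrained': (100, 1)}
```

The number of learnable mapping parameters grows as n_obs_sc × n_obs_sp, which is worth keeping in mind for large single-cell and spatial datasets.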