scvi.distributions.JaxNegativeBinomialMeanDisp

class scvi.distributions.JaxNegativeBinomialMeanDisp(mean, inverse_dispersion, validate_args=None, eps=1e-08)

Negative binomial parameterized by mean and inverse dispersion.
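For reference, the probability mass function in this parameterization can be written as follows, using \(\mu\) for mean and \(\phi\) for inverse_dispersion (symbols chosen here for illustration). Several docstrings below are inherited from a Gamma-Poisson base and are written in terms of \(\mathrm{GammaPoisson}(\alpha, \lambda)\). Assuming the class follows numpyro's NegativeBinomial2 convention (its pytree_data_fields entry concentration points that way), the two sets of symbols are related by \(\alpha = \phi\) and \(\lambda = \phi / \mu\):

\[P(X = x) = \frac{\Gamma(x + \phi)}{\Gamma(\phi)\, x!} \left(\frac{\phi}{\phi + \mu}\right)^{\phi} \left(\frac{\mu}{\phi + \mu}\right)^{x}, \qquad x = 0, 1, 2, \ldots\]

\[\mathbb{E}[X] = \frac{\alpha}{\lambda} = \mu, \qquad \mathrm{Var}[X] = \frac{\alpha}{\lambda^2}(1 + \lambda) = \mu + \frac{\mu^2}{\phi}\]

Under this mapping, a smaller inverse dispersion means more overdispersion relative to a Poisson distribution with the same mean.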

Attributes table

arg_constraints

batch_shape

Returns the shape over which the distribution parameters are batched.

event_dim

Number of dimensions of individual events.

event_shape

Returns the shape of a single sample from the distribution without batching.

has_enumerate_support

has_rsample

inverse_dispersion

is_discrete

mean

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the mean is \(\mathbb{E}[X] = \alpha / \lambda\).

pytree_aux_fields

pytree_data_fields

reparametrized_params

support(x)

variance

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the variance is \(\mathrm{Var}[X] = \frac{\alpha}{\lambda^2}(1 + \lambda)\).

Methods table

cdf(value)

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the cumulative distribution function is the regularized incomplete beta function \(F_{X}(x) = I_{\lambda/(1+\lambda)}(\alpha, x + 1)\).

entropy()

Returns the entropy of the distribution.

enumerate_support([expand])

Returns an array with shape len(support) x batch_shape containing all values in the support.

expand(batch_shape)

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

expand_by(sample_shape)

Expands a distribution by adding sample_shape to the left side of its batch_shape.

gather_pytree_aux_fields()

gather_pytree_data_fields()

get_args()

Get arguments of the distribution.

icdf(q)

The inverse cumulative distribution function of this distribution.

infer_shapes(*args, **kwargs)

Infers batch_shape and event_shape given shapes of args to __init__().

log_prob(value)

Log probability.

mask(mask)

Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's batch_shape.

rsample(key[, sample_shape])

sample(key[, sample_shape])

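If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then samples are drawn as \(\theta \sim \mathrm{Gamma}(\alpha, \lambda)\) followed by \(X \mid \theta \sim \mathrm{Poisson}(\theta)\).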

sample_with_intermediates(key[, sample_shape])

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

set_default_validate_args(value)

shape([sample_shape])

The tensor shape of samples from this distribution.

to_event([reinterpreted_batch_ndims])

Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.

tree_flatten()

tree_unflatten(aux_data, params)

validate_args([strict])

Validate the arguments of the distribution.

validate_sample()

Attributes

JaxNegativeBinomialMeanDisp.arg_constraints: dict[str, Any] = {'inverse_dispersion': Positive(lower_bound=0.0), 'mean': Positive(lower_bound=0.0)}
JaxNegativeBinomialMeanDisp.batch_shape

Returns the shape over which the distribution parameters are batched.

Returns:

batch shape of the distribution.

Return type:

tuple[int, ...]

JaxNegativeBinomialMeanDisp.event_dim

Number of dimensions of individual events.

Return type:

int

JaxNegativeBinomialMeanDisp.event_shape

Returns the shape of a single sample from the distribution without batching.

Returns:

event shape of the distribution.

Return type:

tuple[int, ...]

JaxNegativeBinomialMeanDisp.has_enumerate_support: bool = False
JaxNegativeBinomialMeanDisp.has_rsample
JaxNegativeBinomialMeanDisp.inverse_dispersion
JaxNegativeBinomialMeanDisp.is_discrete
JaxNegativeBinomialMeanDisp.mean
JaxNegativeBinomialMeanDisp.pytree_aux_fields: tuple[str, ...] = ('_batch_shape', '_event_shape')
JaxNegativeBinomialMeanDisp.pytree_data_fields: tuple[str, ...] = ('concentration',)
JaxNegativeBinomialMeanDisp.reparametrized_params: list[str] = []
JaxNegativeBinomialMeanDisp.support(x) = IntegerNonnegative(lower_bound=0)
JaxNegativeBinomialMeanDisp.variance

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the variance is:

\[\mathrm{Var}[X] = \frac{\alpha}{\lambda^2}(1 + \lambda)\]
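To see how the inherited \((\alpha, \lambda)\) formulas read in terms of the constructor arguments, the pure-JAX sketch below maps mean and inverse_dispersion to \((\alpha, \lambda)\) under the assumed NegativeBinomial2 convention (\(\alpha\) = inverse_dispersion, \(\lambda\) = inverse_dispersion / mean) and evaluates the mean and variance formulas. The helper names are hypothetical and the code does not call scvi-tools.

```python
import jax.numpy as jnp


def to_gamma_poisson(mean, inverse_dispersion):
    """Hypothetical helper: (mean, inverse_dispersion) -> (alpha, lambda).

    Assumes the NegativeBinomial2 convention: alpha = inverse_dispersion,
    lambda = inverse_dispersion / mean.
    """
    return inverse_dispersion, inverse_dispersion / mean


def gamma_poisson_moments(alpha, lam):
    """Mean alpha / lambda and variance alpha / lambda**2 * (1 + lambda)."""
    return alpha / lam, alpha / lam**2 * (1.0 + lam)


mu = jnp.array([0.5, 2.0, 10.0])   # mean
phi = jnp.array([1.0, 5.0, 0.2])   # inverse_dispersion
alpha, lam = to_gamma_poisson(mu, phi)
m, v = gamma_poisson_moments(alpha, lam)

# In (mean, inverse_dispersion) terms: m equals mu and v equals mu + mu**2 / phi.
print(m, v)
```

Under that assumption the mean is exactly mean, and the variance grows quadratically in the mean, with inverse_dispersion controlling the quadratic term.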

Methods

JaxNegativeBinomialMeanDisp.cdf(value)

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the cumulative distribution function is:

\[F_{X}(x) = \frac{1}{\mathrm{B}(\alpha, x + 1)} \int_{0}^{\frac{\lambda}{1 + \lambda}} t^{\alpha - 1} (1 - t)^{x} dt\]

which is the regularized incomplete beta function. This implementation uses betainc().

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]
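The betainc() formula can be checked in a few lines. This sketch uses a hypothetical helper nb_cdf and the same assumed mapping (\(\alpha\) = inverse_dispersion, \(\lambda\) = inverse_dispersion / mean). It evaluates \(I_{\lambda/(1+\lambda)}(\alpha, x + 1)\) with jax.scipy.special.betainc and relies on \(I_p(a, 1) = p^a\), so \(F_X(0)\) must equal \(P(X = 0) = (\lambda/(1+\lambda))^{\alpha}\).

```python
import jax.numpy as jnp
from jax.scipy.special import betainc


def nb_cdf(x, mean, inverse_dispersion):
    """Hypothetical helper: P(X <= x) = I_p(alpha, x + 1), p = lambda / (1 + lambda).

    Assumes alpha = inverse_dispersion and lambda = inverse_dispersion / mean.
    """
    alpha = inverse_dispersion
    lam = inverse_dispersion / mean
    p = lam / (1.0 + lam)
    # Broadcast explicitly so betainc receives same-shaped arguments.
    a, b, p = jnp.broadcast_arrays(alpha, x + 1.0, p)
    return betainc(a, b, p)


mu, phi = 4.0, 1.5
xs = jnp.arange(0.0, 100.0)
cdf = nb_cdf(xs, mu, phi)

# I_p(a, 1) = p**a, so F(0) should equal P(X = 0) = (phi / (phi + mu)) ** phi.
p0 = (phi / (phi + mu)) ** phi
print(cdf[0], p0)
```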

JaxNegativeBinomialMeanDisp.entropy()

Returns the entropy of the distribution.

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]

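JaxNegativeBinomialMeanDisp.enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support. This method is inherited; because has_enumerate_support is False for this distribution (its support is the unbounded set of nonnegative integers), enumeration is not available here.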

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]

JaxNegativeBinomialMeanDisp.expand(batch_shape)

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters:

batch_shape (tuple) – batch shape to expand to.

Returns:

an instance of ExpandedDistribution.

Return type:

ExpandedDistribution

JaxNegativeBinomialMeanDisp.expand_by(sample_shape)

Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.

Parameters:

sample_shape (tuple) – The size of the iid batch to be drawn from the distribution.

Returns:

An expanded version of this distribution.

Return type:

ExpandedDistribution

classmethod JaxNegativeBinomialMeanDisp.gather_pytree_aux_fields()

Return type:

tuple[str, ...]

classmethod JaxNegativeBinomialMeanDisp.gather_pytree_data_fields()

Return type:

tuple[str, ...]

JaxNegativeBinomialMeanDisp.get_args()

Get arguments of the distribution.

Return type:

dict[str, Any]

JaxNegativeBinomialMeanDisp.icdf(q)

The inverse cumulative distribution function of this distribution.

Parameters:

q (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]) – quantile values, which should lie in [0, 1].

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]

Returns:

the values whose CDF equals q.

classmethod JaxNegativeBinomialMeanDisp.infer_shapes(*args, **kwargs)

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args (Any) – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs (Any) – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

JaxNegativeBinomialMeanDisp.log_prob(value)

Log probability mass of value.

Return type:

Array
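The log probability has a closed form. The sketch below is written from the probability mass function (it is not copied from the scvi-tools source) and adds a small eps inside the logarithms; that this is how the constructor's eps argument is used is an assumption. Summing the resulting probabilities over a long range of counts should give approximately 1, with a mean of approximately mean.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def nb_log_prob(x, mean, inverse_dispersion, eps=1e-8):
    """Hypothetical helper: log P(X = x) for the mean / inverse-dispersion NB.

    eps keeps the logarithms finite; using the constructor's eps this way is an
    assumption of this sketch.
    """
    mu, phi = mean, inverse_dispersion
    log_phi_mu = jnp.log(phi + mu + eps)
    return (
        phi * (jnp.log(phi + eps) - log_phi_mu)
        + x * (jnp.log(mu + eps) - log_phi_mu)
        + gammaln(x + phi)
        - gammaln(phi)
        - gammaln(x + 1.0)
    )


mu, phi = 3.0, 0.7
xs = jnp.arange(0.0, 500.0)
pmf = jnp.exp(nb_log_prob(xs, mu, phi))

total = jnp.sum(pmf)          # should be close to 1
mean_est = jnp.sum(xs * pmf)  # should be close to mu
print(total, mean_est)
```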

JaxNegativeBinomialMeanDisp.mask(mask)

Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's batch_shape.

Parameters:

mask (bool or jnp.ndarray) – A boolean or boolean valued array (True includes a site, False excludes a site).

Returns:

A masked copy of this distribution.

Return type:

MaskedDistribution

Example:

JaxNegativeBinomialMeanDisp.rsample(key, sample_shape=())

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]

JaxNegativeBinomialMeanDisp.sample(key, sample_shape=())

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the sampling procedure is:

\[\begin{aligned} \theta &\sim \mathrm{Gamma}(\alpha, \lambda) \\ X \mid \theta &\sim \mathrm{Poisson}(\theta) \end{aligned}\]

It uses Gamma to generate samples from the Gamma distribution and Poisson to generate samples from the Poisson distribution.

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]
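The two-stage procedure above can be reproduced with jax.random alone. This sketch uses a hypothetical helper gamma_poisson_sample and assumes \(\alpha\) = inverse_dispersion and \(\lambda\) = inverse_dispersion / mean. Because jax.random.gamma draws from \(\mathrm{Gamma}(\alpha, 1)\), the draws are divided by the rate \(\lambda\).

```python
import jax
import jax.numpy as jnp


def gamma_poisson_sample(key, mean, inverse_dispersion, n):
    """Hypothetical helper: n draws of theta ~ Gamma(alpha, lambda), X | theta ~ Poisson(theta).

    Assumes alpha = inverse_dispersion and lambda = inverse_dispersion / mean.
    """
    alpha = inverse_dispersion
    lam = inverse_dispersion / mean
    key_gamma, key_poisson = jax.random.split(key)
    # jax.random.gamma samples Gamma(alpha, 1); dividing by the rate gives Gamma(alpha, lambda).
    theta = jax.random.gamma(key_gamma, alpha, shape=(n,)) / lam
    return jax.random.poisson(key_poisson, theta)


mu, phi = 5.0, 2.0
x = gamma_poisson_sample(jax.random.PRNGKey(0), mu, phi, 200_000)
print(jnp.mean(x), jnp.var(x))
```

With enough draws, the sample mean and variance should approach mean and mean + mean**2 / inverse_dispersion.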

JaxNegativeBinomialMeanDisp.sample_with_intermediates(key, sample_shape=())

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.key) – the random key used to draw samples from the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

static JaxNegativeBinomialMeanDisp.set_default_validate_args(value)

Return type:

None

JaxNegativeBinomialMeanDisp.shape(sample_shape=())

The tensor shape of samples from this distribution.

Samples are of shape:

d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

Parameters:

sample_shape (tuple) – the size of the iid batch to be drawn from the distribution.

Returns:

shape of samples.

Return type:

tuple

JaxNegativeBinomialMeanDisp.to_event(reinterpreted_batch_ndims=None)

Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.

Parameters:

reinterpreted_batch_ndims (Optional[int] (default: None)) – Number of rightmost batch dims to interpret as event dims.

Returns:

An instance of Independent distribution.

Return type:

numpyro.distributions.distribution.Independent

JaxNegativeBinomialMeanDisp.tree_flatten()

Return type:

tuple[tuple[Any, ...], tuple[Any, ...]]

classmethod JaxNegativeBinomialMeanDisp.tree_unflatten(aux_data, params)

Return type:

Distribution

JaxNegativeBinomialMeanDisp.validate_args(strict=True)

Validate the arguments of the distribution.

Parameters:

strict (bool (default: True)) – Require strict validation, raising an error if the function is called inside jitted code.

Return type:

None

JaxNegativeBinomialMeanDisp.validate_sample()