scvi.distributions.JaxNegativeBinomialMeanDisp#
- class scvi.distributions.JaxNegativeBinomialMeanDisp(mean, inverse_dispersion, validate_args=None, eps=1e-08)#
Negative binomial parameterized by mean and inverse dispersion.
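The page does not spell out the parameterization, so the following is a minimal plain-Python sketch under a stated assumption: the common convention in which the variance equals mean + mean**2 / inverse_dispersion. The helper `nb_log_pmf` is hypothetical and does not call scvi-tools or NumPyro; it only shows how a log pmf can be written directly in terms of the two constructor arguments, and checks numerically that the resulting distribution has the expected mean and variance.

```python
import math


def nb_log_pmf(x, mean, inverse_dispersion):
    """Log pmf of a negative binomial in mean / inverse-dispersion form.

    Assumed convention (not stated on this page):
    variance = mean + mean**2 / inverse_dispersion.
    """
    theta = inverse_dispersion
    log_theta_mu = math.log(theta + mean)
    return (
        theta * (math.log(theta) - log_theta_mu)
        + x * (math.log(mean) - log_theta_mu)
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )


mean, inverse_dispersion = 3.0, 2.0

# Truncate the infinite support; the tail beyond 500 is negligible here.
support = range(500)
pmf = [math.exp(nb_log_pmf(k, mean, inverse_dispersion)) for k in support]

total = sum(pmf)
numeric_mean = sum(k * p for k, p in zip(support, pmf))
numeric_var = sum((k - numeric_mean) ** 2 * p for k, p in zip(support, pmf))
```

Under this convention a smaller inverse_dispersion means more overdispersion relative to a Poisson with the same mean.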
Attributes table#
| batch_shape | Returns the shape over which the distribution parameters are batched. |
| event_dim | Number of dimensions of individual events. |
| event_shape | Returns the shape of a single sample from the distribution without batching. |
| mean | Mean of the distribution. |
| variance | Variance of the distribution. |
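Methods table#
| cdf | The cumulative distribution function of this distribution. |
| enumerate_support | Returns an array with shape len(support) x batch_shape containing all values in the support. |
| expand | Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape. |
| expand_by | Expands a distribution by adding sample_shape to the left side of its batch_shape. |
| icdf | The inverse cumulative distribution function of this distribution. |
| infer_shapes | Infers batch_shape and event_shape given shapes of args to __init__(). |
| log_prob | Evaluates the log probability density for a batch of samples given by value. |
| mask | Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape. |
| rsample | |
| sample | Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. |
| sample_with_intermediates | Same as sample except that any intermediate computations are returned (useful for TransformedDistribution). |
| set_default_validate_args | |
| shape | The tensor shape of samples from this distribution. |
| to_event | Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions. |
| tree_flatten | |
| tree_unflatten | |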
Attributes#
arg_constraints#
- JaxNegativeBinomialMeanDisp.arg_constraints = {'inverse_dispersion': <numpyro.distributions.constraints._GreaterThan object>, 'mean': <numpyro.distributions.constraints._GreaterThan object>}#
batch_shape#
event_dim#
- JaxNegativeBinomialMeanDisp.event_dim#
Number of dimensions of individual events.
- Return type
int
event_shape#
has_enumerate_support#
- JaxNegativeBinomialMeanDisp.has_enumerate_support = False#
has_rsample#
- JaxNegativeBinomialMeanDisp.has_rsample#
inverse_dispersion#
- JaxNegativeBinomialMeanDisp.inverse_dispersion#
is_discrete#
- JaxNegativeBinomialMeanDisp.is_discrete#
mean#
- JaxNegativeBinomialMeanDisp.mean#
reparametrized_params#
- JaxNegativeBinomialMeanDisp.reparametrized_params = []#
support#
- JaxNegativeBinomialMeanDisp.support = <numpyro.distributions.constraints._IntegerGreaterThan object>#
variance#
- JaxNegativeBinomialMeanDisp.variance#
Methods#
cdf#
- JaxNegativeBinomialMeanDisp.cdf(value)#
The cumulative distribution function of this distribution.
- Parameters
- value
samples from this distribution.
- Returns
output of the cumulative distribution function evaluated at value.
enumerate_support#
- JaxNegativeBinomialMeanDisp.enumerate_support(expand=True)#
Returns an array with shape len(support) x batch_shape containing all values in the support.
expand#
expand_by#
- JaxNegativeBinomialMeanDisp.expand_by(sample_shape)#
Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.
- Parameters
- sample_shape : tuple
The size of the iid batch to be drawn from the distribution.
- Returns
An expanded version of this distribution.
- Return type
ExpandedDistribution
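The difference between expand_by and expand is purely a matter of shape bookkeeping, which the sketch below illustrates on plain tuples. Both helper functions are hypothetical stand-ins written for this illustration; they do not call the library, and the rules for expand are a simplified assumption (a size-1 dimension may grow, any other dimension must match).

```python
def expand_by_shape(batch_shape, sample_shape):
    """Batch shape after prepending sample_shape on the left (expand_by)."""
    return tuple(sample_shape) + tuple(batch_shape)


def expand_shape(batch_shape, target_shape):
    """Batch shape after growing size-1 dims to target_shape (expand).

    Simplified assumption: target_shape has at least as many dims as
    batch_shape, and only dims of size 1 are allowed to change.
    """
    if len(target_shape) < len(batch_shape):
        raise ValueError("target_shape has fewer dims than batch_shape")
    padded = (1,) * (len(target_shape) - len(batch_shape)) + tuple(batch_shape)
    for old, new in zip(padded, target_shape):
        if old != 1 and old != new:
            raise ValueError(f"cannot expand dim of size {old} to {new}")
    return tuple(target_shape)


# expand_by adds new leading dims and leaves existing ones untouched.
prepended = expand_by_shape((1, 4), (2, 3))

# expand keeps the rank here and grows the internal size-1 dim.
grown = expand_shape((1, 4), (5, 4))
```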
icdf#
- JaxNegativeBinomialMeanDisp.icdf(q)#
The inverse cumulative distribution function of this distribution.
- Parameters
- q
quantile values, should belong to [0, 1].
- Returns
the samples whose cdf values equal q.
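For a discrete distribution such as this one, an exact inverse of the cdf generally does not exist, so "the sample whose cdf equals q" is usually read as the smallest integer whose cdf is at least q. The sketch below illustrates that reading in plain Python. It is an illustration of the definitions only: `nb_pmf`, `nb_cdf`, and `nb_icdf` are hypothetical helpers, the parameterization (variance = mean + mean**2 / inverse_dispersion) is an assumption, and this page does not say whether the class itself implements cdf or icdf.

```python
import math


def nb_pmf(k, mean, inverse_dispersion):
    """Negative binomial pmf in the assumed mean / inverse-dispersion form."""
    theta = inverse_dispersion
    log_p = (
        theta * math.log(theta / (theta + mean))
        + k * math.log(mean / (theta + mean))
        + math.lgamma(k + theta)
        - math.lgamma(theta)
        - math.lgamma(k + 1)
    )
    return math.exp(log_p)


def nb_cdf(value, mean, inverse_dispersion):
    """P(X <= value), obtained by summing the pmf over 0..floor(value)."""
    if value < 0:
        return 0.0
    upper = int(math.floor(value))
    return sum(nb_pmf(k, mean, inverse_dispersion) for k in range(upper + 1))


def nb_icdf(q, mean, inverse_dispersion, max_k=10_000):
    """Smallest integer k with cdf(k) >= q (generalized inverse)."""
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must lie in [0, 1]")
    cumulative = 0.0
    for k in range(max_k + 1):
        cumulative += nb_pmf(k, mean, inverse_dispersion)
        if cumulative >= q:
            return k
    raise ValueError("q is too close to 1 for the chosen max_k")


median = nb_icdf(0.5, mean=3.0, inverse_dispersion=2.0)
```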
infer_shapes#
- classmethod JaxNegativeBinomialMeanDisp.infer_shapes(*args, **kwargs)#
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
- *args
Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs
Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
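As a rough illustration of shape inference, the sketch below assumes that both parameters are scalars per event (so event_shape is empty) and that the batch shape is obtained by broadcasting the parameter shapes against each other with NumPy-style rules. `broadcast_shapes` and `infer_shapes_sketch` are hypothetical helpers for this example, not the library's implementation.

```python
def broadcast_shapes(*shapes):
    """Right-aligned broadcasting of shape tuples (NumPy-style rules)."""
    ndim = max((len(shape) for shape in shapes), default=0)
    result = [1] * ndim
    for shape in shapes:
        # Left-pad with ones so that all shapes have the same rank.
        padded = (1,) * (ndim - len(shape)) + tuple(shape)
        for i, size in enumerate(padded):
            if size == 1:
                continue
            if result[i] not in (1, size):
                raise ValueError(f"incompatible shapes: {shapes}")
            result[i] = size
    return tuple(result)


def infer_shapes_sketch(mean, inverse_dispersion):
    """Return (batch_shape, event_shape) from the two parameter shapes.

    Assumption: a univariate count distribution, hence event_shape == ().
    """
    batch_shape = broadcast_shapes(mean, inverse_dispersion)
    event_shape = ()
    return batch_shape, event_shape


# A per-cell mean column broadcast against a per-gene dispersion vector.
shapes = infer_shapes_sketch(mean=(128, 1), inverse_dispersion=(2000,))
```

Only shapes go in and come out; no parameter values are needed, which matches the note above.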
log_prob#
- JaxNegativeBinomialMeanDisp.log_prob(*args, **kwargs)#
Evaluates the log probability density for a batch of samples given by value.
- Parameters
- value
A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
mask#
- JaxNegativeBinomialMeanDisp.mask(mask)#
Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape.
- Parameters
- mask : bool or jnp.ndarray
A boolean or boolean-valued array (True includes a site, False excludes a site).
- Returns
A masked copy of this distribution.
- Return type
MaskedDistribution
Example:
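The original example is not reproduced on this page. As a stand-in, here is a minimal plain-Python sketch of the effect of masking on log probabilities, assuming that excluded sites contribute 0.0 to the log probability while included sites are left unchanged. `masked_log_prob` is a hypothetical helper and operates on lists rather than on a distribution object.

```python
def masked_log_prob(log_probs, mask):
    """Keep log_probs where mask is True and use 0.0 where it is False.

    A scalar bool is broadcast to every site, mirroring the
    "bool or array" signature described above.
    """
    if isinstance(mask, bool):
        mask = [mask] * len(log_probs)
    return [lp if keep else 0.0 for lp, keep in zip(log_probs, mask)]


site_log_probs = [-1.5, -2.0, -0.5]

# The middle site is excluded, so it no longer affects the total.
masked = masked_log_prob(site_log_probs, [True, False, True])
total = sum(masked)

# Masking with a scalar False removes every site.
all_excluded = masked_log_prob(site_log_probs, False)
```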
rsample#
- JaxNegativeBinomialMeanDisp.rsample(key, sample_shape=())#
sample#
- JaxNegativeBinomialMeanDisp.sample(key, sample_shape=())#
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
- key : jax.random.PRNGKey
the PRNG key to be used for the distribution.
- sample_shape : tuple
the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
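To make the shape rule concrete without depending on JAX, the sketch below draws negative binomial samples in plain Python as a Gamma-Poisson mixture, one standard construction for this family. Whether the class samples this way internally is not stated on this page, so treat it as an assumption; the helpers `poisson_draw` and `nb_sample` are hypothetical, and an explicit `random.Random` instance plays the role of the key argument. The batch and event shapes are both empty here, so the result has shape sample_shape.

```python
import math
import random


def poisson_draw(rng, rate):
    """Knuth's multiplication method; adequate for small to moderate rates."""
    threshold = math.exp(-rate)
    count, product = 0, rng.random()
    while product > threshold:
        count += 1
        product *= rng.random()
    return count


def nb_sample(rng, mean, inverse_dispersion, sample_shape=()):
    """Nested lists of iid draws with shape sample_shape."""
    if not sample_shape:
        # Gamma with shape theta and scale mean / theta has expectation mean.
        rate = rng.gammavariate(inverse_dispersion, mean / inverse_dispersion)
        return poisson_draw(rng, rate)
    return [
        nb_sample(rng, mean, inverse_dispersion, sample_shape[1:])
        for _ in range(sample_shape[0])
    ]


rng = random.Random(0)  # explicit random state, analogous to passing a key
draws = nb_sample(rng, mean=3.0, inverse_dispersion=2.0, sample_shape=(4, 5000))

flat = [x for row in draws for x in row]
sample_mean = sum(flat) / len(flat)
```

An empty sample_shape returns a single draw, and every additional leading dimension is filled with independent draws.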
sample_with_intermediates#
- JaxNegativeBinomialMeanDisp.sample_with_intermediates(key, sample_shape=())#
Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).
- Parameters
- key : jax.random.PRNGKey
the PRNG key to be used for the distribution.
- sample_shape : tuple
the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
set_default_validate_args#
- static JaxNegativeBinomialMeanDisp.set_default_validate_args(value)#
shape#
- JaxNegativeBinomialMeanDisp.shape(sample_shape=())#
The tensor shape of samples from this distribution.
Samples are of shape:
d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
to_event#
- JaxNegativeBinomialMeanDisp.to_event(reinterpreted_batch_ndims=None)#
Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.
- Parameters
- reinterpreted_batch_ndims
Number of rightmost batch dims to interpret as event dims.
- Returns
An instance of Independent distribution.
- Return type
numpyro.distributions.distribution.Independent
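The effect of to_event can be sketched on plain tuples and lists: the rightmost batch dimensions move into the event shape, and log probabilities are summed over them so that each remaining batch element gets a single value. The helpers `to_event_shapes` and `event_log_prob` are hypothetical, take an explicit integer for reinterpreted_batch_ndims, and only handle the one-dimension case for the log probabilities.

```python
def to_event_shapes(batch_shape, event_shape, reinterpreted_batch_ndims):
    """Return the (batch_shape, event_shape) pair after reinterpretation."""
    n = reinterpreted_batch_ndims
    if n < 0 or n > len(batch_shape):
        raise ValueError("reinterpreted_batch_ndims out of range")
    split = len(batch_shape) - n
    new_batch = tuple(batch_shape[:split])
    new_event = tuple(batch_shape[split:]) + tuple(event_shape)
    return new_batch, new_event


def event_log_prob(log_probs):
    """Sum per-element log probabilities over the last dimension."""
    return [sum(row) for row in log_probs]


# A (cells, genes) batch becomes one event of length `genes` per cell.
shapes = to_event_shapes((128, 2000), (), 1)

# Per-element log probabilities for batch_shape (2, 3) ...
elementwise = [[-1.0, -2.0, -3.0], [-0.5, -0.5, -1.0]]
# ... collapse to one value per row after reinterpreting one dimension.
per_event = event_log_prob(elementwise)
```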
tree_flatten#
- JaxNegativeBinomialMeanDisp.tree_flatten()#
tree_unflatten#
- classmethod JaxNegativeBinomialMeanDisp.tree_unflatten(aux_data, params)#