# scvi.distributions.JaxNegativeBinomialMeanDisp

class scvi.distributions.JaxNegativeBinomialMeanDisp(mean, inverse_dispersion, validate_args=None, eps=1e-08)

Negative binomial parameterized by mean and inverse dispersion.
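For reference, the standard mean/inverse-dispersion form of the negative binomial (often called NB2) is written out below, with mean $\mu$ (mean) and inverse dispersion $\theta$ (inverse_dispersion). This is the textbook parameterization, assumed here rather than read from the scvi-tools source.

```latex
P(X = k \mid \mu, \theta)
  = \frac{\Gamma(k + \theta)}{\Gamma(\theta)\, k!}
    \left(\frac{\theta}{\theta + \mu}\right)^{\theta}
    \left(\frac{\mu}{\theta + \mu}\right)^{k},
  \qquad k = 0, 1, 2, \dots

\mathbb{E}[X] = \mu, \qquad
\operatorname{Var}[X] = \mu + \frac{\mu^{2}}{\theta}
```

As $\theta$ grows, the variance approaches $\mu$ and the distribution approaches a Poisson; a small $\theta$ means strong overdispersion.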

## Attributes table

| Name | Description |
|---|---|
| arg_constraints | |
| batch_shape | Returns the shape over which the distribution parameters are batched. |
| event_dim | Number of dimensions of individual events. |
| event_shape | Returns the shape of a single sample from the distribution without batching. |
| has_enumerate_support | |
| has_rsample | |
| inverse_dispersion | |
| is_discrete | |
| mean | Mean of the distribution. |
| reparametrized_params | |
| support | |
| variance | Variance of the distribution. |

## Methods table

| Name | Description |
|---|---|
| cdf(value) | The cumulative distribution function of this distribution. |
| enumerate_support([expand]) | Returns an array with shape len(support) x batch_shape containing all values in the support. |
| expand(batch_shape) | Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape. |
| expand_by(sample_shape) | Expands a distribution by adding sample_shape to the left side of its batch_shape. |
| icdf(q) | The inverse cumulative distribution function of this distribution. |
| infer_shapes(*args, **kwargs) | Infers batch_shape and event_shape given shapes of args to __init__(). |
| log_prob(*args, **kwargs) | Evaluates the log probability for a batch of samples given by value. |
| mask(mask) | Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape. |
| rsample(key[, sample_shape]) | |
| sample(key[, sample_shape]) | Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. |
| sample_with_intermediates(key[, sample_shape]) | Same as sample except that any intermediate computations are returned (useful for TransformedDistribution). |
| set_default_validate_args(value) | |
| shape([sample_shape]) | The tensor shape of samples from this distribution. |
| to_event([reinterpreted_batch_ndims]) | Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions. |
| tree_flatten() | |
| tree_unflatten(aux_data, params) | |

## Attributes

### arg_constraints

JaxNegativeBinomialMeanDisp.arg_constraints = {'inverse_dispersion': GreaterThan(lower_bound=0.0), 'mean': GreaterThan(lower_bound=0.0)}

### batch_shape

JaxNegativeBinomialMeanDisp.batch_shape

Returns the shape over which the distribution parameters are batched.

Returns:

batch shape of the distribution.

Return type:

tuple

### event_dim

JaxNegativeBinomialMeanDisp.event_dim

Returns:

Number of dimensions of individual events.

Return type:

int

### event_shape

JaxNegativeBinomialMeanDisp.event_shape

Returns the shape of a single sample from the distribution without batching.

Returns:

event shape of the distribution.

Return type:

tuple

### has_enumerate_support

JaxNegativeBinomialMeanDisp.has_enumerate_support = False

### has_rsample

JaxNegativeBinomialMeanDisp.has_rsample

### inverse_dispersion

JaxNegativeBinomialMeanDisp.inverse_dispersion

### is_discrete

JaxNegativeBinomialMeanDisp.is_discrete

### mean

JaxNegativeBinomialMeanDisp.mean

Mean of the distribution.

### reparametrized_params

JaxNegativeBinomialMeanDisp.reparametrized_params = []

### support

JaxNegativeBinomialMeanDisp.support = IntegerGreaterThan(lower_bound=0)
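The documented constraints can be read as simple predicates. A negative binomial assigns positive probability to a count of 0, so its support is the non-negative integers {0, 1, 2, ...}; the sketch below assumes that is what IntegerGreaterThan(lower_bound=0) denotes here. The helper names are hypothetical and not part of scvi-tools or numpyro.

```python
import math


# Hypothetical helpers (not scvi-tools or numpyro API) that mirror the
# documented constraints, for illustration only.
def params_are_valid(mean: float, inverse_dispersion: float) -> bool:
    """arg_constraints: both parameters must be strictly greater than 0.0."""
    return mean > 0.0 and inverse_dispersion > 0.0


def in_support(value: float) -> bool:
    """support: value must be a finite, non-negative whole number (a count)."""
    return math.isfinite(value) and value >= 0 and value % 1 == 0


print(params_are_valid(2.0, 0.5), params_are_valid(0.0, 1.0))  # True False
print(in_support(0), in_support(3.0), in_support(2.5), in_support(-1))  # True True False False
```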

### variance

JaxNegativeBinomialMeanDisp.variance

Variance of the distribution.
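Assuming the standard NB2 relationship Var[X] = mean + mean² / inverse_dispersion, the closed form can be checked numerically by summing the probability mass function. This is a stand-alone sketch with hypothetical function names; it does not call the class.

```python
import math


def nb_log_pmf(k: int, mean: float, inverse_dispersion: float) -> float:
    """Log-pmf in the mean / inverse-dispersion (NB2) parameterization."""
    mu, theta = mean, inverse_dispersion
    log_total = math.log(theta + mu)
    return (
        math.lgamma(k + theta) - math.lgamma(theta) - math.lgamma(k + 1)
        + theta * (math.log(theta) - log_total)
        + k * (math.log(mu) - log_total)
    )


def moments_by_summation(mean: float, inverse_dispersion: float, k_max: int = 2000):
    """Approximate E[X] and Var[X] by summing the pmf over k = 0..k_max."""
    probs = [math.exp(nb_log_pmf(k, mean, inverse_dispersion)) for k in range(k_max + 1)]
    first = sum(k * p for k, p in enumerate(probs))
    second = sum(k * k * p for k, p in enumerate(probs))
    return first, second - first * first


mu, theta = 4.0, 1.5
approx_mean, approx_var = moments_by_summation(mu, theta)
closed_form_var = mu + mu**2 / theta  # assumed NB2 variance: mu + mu^2 / theta
print(approx_mean, approx_var, closed_form_var)
```

The truncation at k_max is harmless here because the pmf tail decays geometrically for these parameter values.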

## Methods

### cdf

JaxNegativeBinomialMeanDisp.cdf(value)

The cumulative distribution function of this distribution.

Parameters:
value

samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.
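For a count distribution, the CDF at a value is the sum of the pmf from 0 up to that value. The direct-summation sketch below is for intuition only (hypothetical names; the library's implementation may differ) and checks itself against the geometric special case inverse_dispersion = 1.

```python
import math


def nb_pmf(k: int, mean: float, inverse_dispersion: float) -> float:
    """pmf in the mean / inverse-dispersion (NB2) parameterization."""
    mu, theta = mean, inverse_dispersion
    log_total = math.log(theta + mu)
    return math.exp(
        math.lgamma(k + theta) - math.lgamma(theta) - math.lgamma(k + 1)
        + theta * (math.log(theta) - log_total)
        + k * (math.log(mu) - log_total)
    )


def nb_cdf(value: float, mean: float, inverse_dispersion: float) -> float:
    """P(X <= value), computed by summing the pmf over 0..floor(value)."""
    if value < 0:
        return 0.0
    return sum(nb_pmf(k, mean, inverse_dispersion) for k in range(math.floor(value) + 1))


# With inverse_dispersion = 1 the distribution is geometric, so
# P(X <= k) = 1 - (mean / (1 + mean)) ** (k + 1) gives an independent check.
print(nb_cdf(3, 2.0, 1.0), 1 - (2.0 / 3.0) ** 4)
```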

### enumerate_support

JaxNegativeBinomialMeanDisp.enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support.

### expand

JaxNegativeBinomialMeanDisp.expand(batch_shape)

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters:
batch_shape : tuple

batch shape to expand to.

Returns:

an instance of ExpandedDistribution.

Return type:

ExpandedDistribution

### expand_by

JaxNegativeBinomialMeanDisp.expand_by(sample_shape)

Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.

Parameters:
sample_shape : tuple

The size of the iid batch to be drawn from the distribution.

Returns:

An expanded version of this distribution.

Return type:

ExpandedDistribution
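expand and expand_by differ only in how the new batch_shape is formed: expand grows existing dimensions (for example a size-1 dimension) by broadcasting, while expand_by prepends sample_shape. The shape-only sketch below uses hypothetical helper names and assumes NumPy-style broadcasting rules for expand.

```python
def broadcast_shapes(*shapes: tuple) -> tuple:
    """NumPy-style broadcasting of shape tuples (dimensions aligned on the right)."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


def expand_batch_shape(batch_shape: tuple, new_batch_shape: tuple) -> tuple:
    """expand(): the new batch_shape must be reachable by broadcasting."""
    if broadcast_shapes(batch_shape, new_batch_shape) != tuple(new_batch_shape):
        raise ValueError(f"cannot expand {batch_shape} to {new_batch_shape}")
    return tuple(new_batch_shape)


def expand_by_batch_shape(batch_shape: tuple, sample_shape: tuple) -> tuple:
    """expand_by(): sample_shape is prepended to the existing batch_shape."""
    return tuple(sample_shape) + tuple(batch_shape)


print(expand_batch_shape((1, 2000), (128, 2000)))  # (128, 2000)
print(expand_by_batch_shape((2000,), (128,)))      # (128, 2000)
```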

### icdf

JaxNegativeBinomialMeanDisp.icdf(q)

The inverse cumulative distribution function of this distribution.

Parameters:
q

quantile values, which should lie in [0, 1].

Returns:

the samples whose CDF values equal q.
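Because the CDF of a count distribution is a step function, an exact match CDF(k) = q rarely exists; the usual reading is the generalized inverse, the smallest k with CDF(k) >= q. The linear-search sketch below assumes that convention; its names are hypothetical and the library may compute icdf differently.

```python
import math


def nb_pmf(k: int, mean: float, inverse_dispersion: float) -> float:
    """pmf in the mean / inverse-dispersion (NB2) parameterization."""
    mu, theta = mean, inverse_dispersion
    log_total = math.log(theta + mu)
    return math.exp(
        math.lgamma(k + theta) - math.lgamma(theta) - math.lgamma(k + 1)
        + theta * (math.log(theta) - log_total)
        + k * (math.log(mu) - log_total)
    )


def nb_icdf(q: float, mean: float, inverse_dispersion: float, k_max: int = 100_000) -> int:
    """Smallest count k with P(X <= k) >= q, found by a linear search."""
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must lie in [0, 1]")
    cumulative = 0.0
    for k in range(k_max + 1):
        cumulative += nb_pmf(k, mean, inverse_dispersion)
        if cumulative >= q:
            return k
    # A floating-point sum may never reach q == 1.0 exactly; give up at k_max.
    return k_max


# Geometric case (inverse_dispersion = 1, mean = 2): P(X <= k) = 1 - (2/3) ** (k + 1),
# so the CDF steps through about 0.333, 0.556, 0.704, 0.802, ...
print([nb_icdf(q, 2.0, 1.0) for q in (0.3, 0.5, 0.8)])  # [0, 1, 3]
```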

### infer_shapes

classmethod JaxNegativeBinomialMeanDisp.infer_shapes(*args, **kwargs)

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
*args

Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

**kwargs

Keyword args mapping the name of each input arg to a tuple representing the sizes of that tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
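For this distribution each parameter entry describes a single scalar count, so one would expect batch_shape to be the broadcast of the mean and inverse_dispersion shapes and event_shape to be (). The sketch below is a hypothetical analogue written under that assumption; it does not call infer_shapes itself.

```python
def broadcast_shapes(*shapes: tuple) -> tuple:
    """NumPy-style broadcasting of shape tuples (dimensions aligned on the right)."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


def infer_nb_shapes(mean_shape: tuple, inverse_dispersion_shape: tuple):
    """Hypothetical analogue of infer_shapes: (batch_shape, event_shape)."""
    # Scalar (per-element) counts: parameters broadcast into batch_shape,
    # and each individual event has shape ().
    return broadcast_shapes(mean_shape, inverse_dispersion_shape), ()


# e.g. per-cell, per-gene means with a per-gene inverse dispersion
print(infer_nb_shapes((128, 2000), (2000,)))  # ((128, 2000), ())
```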

### log_prob

JaxNegativeBinomialMeanDisp.log_prob(*args, **kwargs)

Evaluates the log probability mass for a batch of samples given by value.

Parameters:
value

A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the rightmost len(self.event_shape) (event) dimensions removed

Return type:

numpy.ndarray
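The constructor accepts an eps argument (default 1e-08). One common way such a constant is used is to add it inside the logarithms of the log-pmf so that very small means or inverse dispersions do not produce log(0). The sketch below assumes that usage; it is an illustration with hypothetical names, not the library's log_prob.

```python
import math


def nb_log_prob(value: float, mean: float, inverse_dispersion: float, eps: float = 1e-8) -> float:
    """NB log-pmf with eps added inside the logarithms (assumed stabilization)."""
    x, mu, theta = value, mean, inverse_dispersion
    log_theta_mu_eps = math.log(theta + mu + eps)
    return (
        theta * (math.log(theta + eps) - log_theta_mu_eps)
        + x * (math.log(mu + eps) - log_theta_mu_eps)
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )


# event_shape is (), so a batch of counts yields one log-probability per count.
counts = [0, 3, 7]
means = [2.0, 2.0, 6.0]
inverse_dispersions = [1.0, 1.0, 4.0]
log_probs = [nb_log_prob(x, m, t) for x, m, t in zip(counts, means, inverse_dispersions)]
print(log_probs)
```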

### mask

JaxNegativeBinomialMeanDisp.mask(mask)

Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape.

Parameters:
mask : bool or jnp.ndarray

A boolean or boolean-valued array (True includes a site, False excludes a site).

Returns:

A masked copy of this distribution.

Return type:

MaskedDistribution
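Conceptually, excluding a site means it contributes nothing to a summed log-likelihood, which is how missing observations can be ignored. The list-based sketch below assumes that masked-out sites receive a log-probability of 0 (the behavior of numpyro's MaskedDistribution); the function name is hypothetical and the code does not call the class.

```python
def masked_log_prob(log_probs: list, mask: list) -> list:
    """Keep each log-probability where mask is True; use 0.0 where it is False."""
    return [lp if keep else 0.0 for lp, keep in zip(log_probs, mask)]


log_probs = [-1.2, -0.7, -3.4]
mask = [True, False, True]  # e.g. the second observation is missing
total = sum(masked_log_prob(log_probs, mask))  # only sites 1 and 3 contribute
print(total)
```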

### rsample

JaxNegativeBinomialMeanDisp.rsample(key, sample_shape=())

### sample

JaxNegativeBinomialMeanDisp.sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
key : jax.random.PRNGKey

the PRNG key to be used for sampling.

sample_shape : tuple

the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray
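One standard way to draw negative binomial counts in this parameterization is the Gamma-Poisson mixture: draw a rate from Gamma(shape = inverse_dispersion, scale = mean / inverse_dispersion), then a Poisson count with that rate. The sketch uses only Python's standard library (random.Random plays the role of the PRNG key) and is not necessarily how sample is implemented; the function names are hypothetical.

```python
import math
import random
import statistics


def poisson_knuth(rng: random.Random, lam: float) -> int:
    """Knuth's multiplication method; adequate for small-to-moderate rates."""
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def sample_nb(rng: random.Random, mean: float, inverse_dispersion: float, n: int) -> list:
    """Draw n iid counts: rate ~ Gamma(theta, scale=mu/theta), X | rate ~ Poisson(rate)."""
    shape, scale = inverse_dispersion, mean / inverse_dispersion
    return [poisson_knuth(rng, rng.gammavariate(shape, scale)) for _ in range(n)]


rng = random.Random(0)
draws = sample_nb(rng, mean=3.0, inverse_dispersion=2.0, n=20_000)
sample_mean = statistics.fmean(draws)
sample_var = statistics.pvariance(draws)
print(sample_mean, sample_var)  # close to 3.0 and 3.0 + 3.0**2 / 2.0 = 7.5
```

The Gamma rate has mean equal to mean, and the extra Gamma variability is what produces the mean² / inverse_dispersion overdispersion term.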

### sample_with_intermediates

JaxNegativeBinomialMeanDisp.sample_with_intermediates(key, sample_shape=())

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
key : jax.random.PRNGKey

the PRNG key to be used for sampling.

sample_shape : tuple

the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

### set_default_validate_args

static JaxNegativeBinomialMeanDisp.set_default_validate_args(value)

### shape

JaxNegativeBinomialMeanDisp.shape(sample_shape=())

The tensor shape of samples from this distribution.

Samples are of shape:

d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

Parameters:
sample_shape : tuple

the size of the iid batch to be drawn from the distribution.

Returns:

shape of samples.

Return type:

tuple
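The shape rule above is plain tuple concatenation. The sketch below spells it out for a hypothetical batch of 128 cells by 2000 genes, assuming event_shape is () because each event is a single count; the helper name is illustrative.

```python
def sample_array_shape(sample_shape: tuple, batch_shape: tuple, event_shape: tuple = ()) -> tuple:
    """d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


# Hypothetical batch of 128 cells x 2000 genes; event_shape is () for this
# scalar count distribution.
print(sample_array_shape((), (128, 2000)))     # (128, 2000)
print(sample_array_shape((10,), (128, 2000)))  # (10, 128, 2000)
```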

### to_event

JaxNegativeBinomialMeanDisp.to_event(reinterpreted_batch_ndims=None)

Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.

Parameters:
reinterpreted_batch_ndims

Number of rightmost batch dims to interpret as event dims.

Returns:

An instance of Independent distribution.

Return type:

numpyro.distributions.distribution.Independent
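to_event moves the rightmost batch dimensions into event_shape, and the resulting Independent distribution sums log-probabilities over those dimensions. The sketch below shows both effects with hypothetical helpers; it assumes that reinterpreted_batch_ndims=None means "reinterpret every batch dimension".

```python
def to_event_shapes(batch_shape: tuple, event_shape: tuple, reinterpreted_batch_ndims=None):
    """Return (new_batch_shape, new_event_shape) after reinterpreting batch dims."""
    if reinterpreted_batch_ndims is None:
        # Assumption: None reinterprets all batch dimensions.
        reinterpreted_batch_ndims = len(batch_shape)
    if not 0 <= reinterpreted_batch_ndims <= len(batch_shape):
        raise ValueError("reinterpreted_batch_ndims out of range")
    split = len(batch_shape) - reinterpreted_batch_ndims
    return tuple(batch_shape[:split]), tuple(batch_shape[split:]) + tuple(event_shape)


def independent_log_prob(log_prob_rows: list) -> list:
    """With one reinterpreted dim, per-gene log-probs are summed within each cell."""
    return [sum(row) for row in log_prob_rows]


print(to_event_shapes((128, 2000), (), 1))  # ((128,), (2000,))
print(independent_log_prob([[-1.0, -2.0], [-0.5, -0.5]]))  # [-3.0, -1.0]
```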

### tree_flatten

JaxNegativeBinomialMeanDisp.tree_flatten()

### tree_unflatten

classmethod JaxNegativeBinomialMeanDisp.tree_unflatten(aux_data, params)