# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Built-in distribution-creation functions.
"""
from warnings import warn
import numpy as np
from astropy import units as u
from .core import Distribution
__all__ = ["normal", "poisson", "uniform"]
[docs]
def normal(
    center, *, std=None, var=None, ivar=None, n_samples, cls=Distribution, **kwargs
):
    """
    Create a Gaussian/normal distribution.

    Exactly one of ``std``, ``var`` or ``ivar`` must be given.

    Parameters
    ----------
    center : `~astropy.units.Quantity`
        The center of this distribution.
    std : `~astropy.units.Quantity` or None
        The standard deviation/σ of this distribution. Shape must be
        broadcastable with, and unit must be compatible with, ``center``;
        or be `None` (if ``var`` or ``ivar`` are set).
    var : `~astropy.units.Quantity` or None
        The variance of this distribution. Shape must be broadcastable with
        ``center`` and unit must be compatible with the *square* of the unit
        of ``center``; or be `None` (if ``std`` or ``ivar`` are set).
    ivar : `~astropy.units.Quantity` or None
        The inverse variance of this distribution. Shape must be broadcastable
        with ``center`` and unit must be compatible with the *inverse square*
        of the unit of ``center``; or be `None` (if ``std`` or ``var`` are
        set).
    n_samples : int
        The number of Monte Carlo samples to use with this distribution.
    cls : class
        The class to use to create this distribution.  Typically a
        `Distribution` subclass.
    **kwargs
        Remaining keywords are passed into the constructor of the ``cls``.

    Returns
    -------
    distr : `~astropy.uncertainty.Distribution` or object
        The sampled Gaussian distribution.
        The type will be the same as the parameter ``cls``.

    Raises
    ------
    ValueError
        If more than one of ``std``, ``var`` and ``ivar`` is given, or if
        none of them is.
    """
    center = np.asanyarray(center)

    # Reduce whichever width parameter was supplied to a standard deviation,
    # rejecting ambiguous combinations along the way.
    if var is not None:
        if std is not None:
            raise ValueError("normal cannot take both std and var")
        std = np.asanyarray(var) ** 0.5
    if ivar is not None:
        # ``std`` may have been derived from ``var`` just above.
        if std is not None:
            raise ValueError("normal cannot take both ivar and std or var")
        std = np.asanyarray(ivar) ** -0.5
    if std is None:
        raise ValueError("normal requires one of std, var, or ivar")
    std = np.asanyarray(std)

    # Samples live along a new trailing axis; ``center`` and ``std`` are
    # broadcast against each other, e.g. a scalar std with an array center.
    randshape = np.broadcast(std, center).shape + (n_samples,)
    samples = (
        center[..., np.newaxis] + np.random.randn(*randshape) * std[..., np.newaxis]
    )
    return cls(samples, **kwargs)
# Units that ``poisson`` accepts without complaint as "counting" units.  A
# Quantity in any other unit triggers a warning there, since Poisson
# statistics may not apply.  Order is kept stable because the tuple is
# rendered verbatim in that warning message.
COUNT_UNITS = (
    u.count, u.electron, u.dimensionless_unscaled, u.chan,
    u.bin, u.vox, u.bit, u.byte,
)
[docs]
def poisson(center, n_samples, cls=Distribution, **kwargs):
    """
    Create a Poisson distribution.

    Parameters
    ----------
    center : `~astropy.units.Quantity` or array-like
        The center value of this distribution (i.e., λ).  If it has a
        ``unit``, the samples are returned with that same unit.
    n_samples : int
        The number of Monte Carlo samples to use with this distribution.
    cls : class
        The class to use to create this distribution.  Typically a
        `Distribution` subclass.
    **kwargs
        Remaining keywords are passed into the constructor of the ``cls``.

    Returns
    -------
    distr : `~astropy.uncertainty.Distribution` or object
        The sampled Poisson distribution.
        The type will be the same as the parameter ``cls``.

    Warns
    -----
    UserWarning
        If ``center`` is in ADU, or in any unit not listed in ``COUNT_UNITS``.
    """
    # np.random.poisson has trouble with quantities, so sample from the bare
    # values and re-attach the unit afterwards.
    has_unit = hasattr(center, "unit")
    poissonarr = np.asanyarray(center.value if has_unit else center)

    # Samples live along a new trailing axis.
    randshape = poissonarr.shape + (n_samples,)
    samples = np.random.poisson(poissonarr[..., np.newaxis], randshape)

    if has_unit:
        if center.unit == u.adu:
            warn(
                "ADUs were provided to poisson.  ADUs are not strictly count "
                "units because they need the gain to be applied. It is "
                "recommended you apply the gain to convert to e.g. electrons.",
                stacklevel=2,
            )
        elif center.unit not in COUNT_UNITS:
            warn(
                f"Unit {center.unit} was provided to poisson, which is not one of"
                f' {COUNT_UNITS}, and therefore suspect as a "counting" unit.  Ensure'
                " you mean to use Poisson statistics.",
                stacklevel=2,
            )
        # re-attach the unit
        samples = samples * center.unit
    return cls(samples, **kwargs)