Source code for pymc_marketing.prior
# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class that represents a prior distribution.

The `Prior` class is a wrapper around PyMC distributions that allows the user
to create priors outside of a PyMC model.

.. note::

    This module has been deprecated and moved to `pymc_extras.prior`.

This is the alternative to using the dictionaries in PyMC-Marketing models.

Examples
--------
Create a normal prior.

.. code-block:: python

    from pymc_extras.prior import Prior

    normal = Prior("Normal")

Create a hierarchical normal prior by using distributions for the parameters
and specifying the dims.

.. code-block:: python

    hierarchical_normal = Prior(
        "Normal",
        mu=Prior("Normal"),
        sigma=Prior("HalfNormal"),
        dims="channel",
    )

Create a non-centered hierarchical normal prior with the `centered` parameter.

.. code-block:: python

    non_centered_hierarchical_normal = Prior(
        "Normal",
        mu=Prior("Normal"),
        sigma=Prior("HalfNormal"),
        dims="channel",
        # Only change needed to make it non-centered
        centered=False,
    )

Create a hierarchical beta prior by using the Beta distribution, distributions
for the parameters, and specifying the dims.

.. code-block:: python

    hierarchical_beta = Prior(
        "Beta",
        alpha=Prior("HalfNormal"),
        beta=Prior("HalfNormal"),
        dims="channel",
    )

Create a transformed hierarchical normal prior by using the `transform`
parameter. Here the "sigmoid" transformation comes from `pm.math`.

.. code-block:: python

    transformed_hierarchical_normal = Prior(
        "Normal",
        mu=Prior("Normal"),
        sigma=Prior("HalfNormal"),
        transform="sigmoid",
        dims="channel",
    )

Create a prior with a custom transform function by registering it with
`register_tensor_transform`.

.. code-block:: python

    from pymc_extras.prior import register_tensor_transform


    def custom_transform(x):
        return x**2


    register_tensor_transform("square", custom_transform)

    custom_distribution = Prior("Normal", transform="square")

"""
from __future__ import annotations
import copy
from typing import Any
from pymc_extras import prior
from pymc_extras.deserialize import deserialize, register_deserialization
[docs]
def is_alternative_prior(data: Any) -> bool:
    """Return ``True`` when *data* looks like a serialized ``Prior``.

    The check is purely structural: *data* must be a ``dict`` that carries a
    ``"distribution"`` key. No other keys or values are inspected.
    """
    if not isinstance(data, dict):
        return False
    return "distribution" in data
[docs]
def deserialize_alternative_prior(data: dict[str, Any]) -> prior.Prior:
    """Alternative deserializer that recursively handles all nested parameters.

    This implementation is more general and handles cases where any parameter
    might be a nested prior, and also extracts the ``centered`` and
    ``transform`` parameters.

    Parameters
    ----------
    data : dict[str, Any]
        Dictionary with a ``"distribution"`` key naming the distribution and,
        optionally, ``"dims"``, ``"centered"`` (default ``True``) and
        ``"transform"`` (default ``None``) keys. Every other key is passed to
        ``prior.Prior`` as a distribution parameter; parameter values that
        are dicts are deserialized with ``deserialize`` first. The input is
        deep-copied, so the caller's dict is not modified.

    Returns
    -------
    prior.Prior
        The deserialized prior. ``dims`` given as a list (as produced by YAML
        or JSON, which have no tuple type) is converted to a tuple.

    Raises
    ------
    KeyError
        If ``data`` has no ``"distribution"`` key.

    Examples
    --------
    This handles cases like:

    .. code-block:: yaml

        distribution: Gamma
        alpha: 1
        beta:
          distribution: HalfNormal
          sigma: 1
          dims: channel
        dims: [brand, channel]

    """
    # Work on a copy: the pops below would otherwise mutate the caller's dict.
    data = copy.deepcopy(data)
    distribution = data.pop("distribution")
    dims = data.pop("dims", None)
    # YAML/JSON deliver sequences as lists; normalize to the tuple form used
    # for dims so the result compares equal to priors built with tuple dims.
    if isinstance(dims, list):
        dims = tuple(dims)
    centered = data.pop("centered", True)
    transform = data.pop("transform", None)

    # Whatever remains are distribution parameters. Nested dicts are routed
    # back through ``deserialize`` so nested priors are rebuilt recursively.
    parameters = {
        key: deserialize(value) if isinstance(value, dict) else value
        for key, value in data.items()
    }

    return prior.Prior(
        distribution,
        transform=transform,
        centered=centered,
        dims=dims,
        **parameters,
    )
# Register the alternative prior deserializer for more complex nested cases.
# This pairs the structural check ``is_alternative_prior`` (a dict with a
# "distribution" key) with ``deserialize_alternative_prior`` via
# ``pymc_extras.deserialize.register_deserialization``. It is a module-level
# side effect, so it runs once, when this module is imported.
# NOTE(review): presumably ``deserialize`` then dispatches matching dicts to
# ``deserialize_alternative_prior``, which is also how the nested
# ``deserialize(value)`` calls inside it reach nested priors. Confirm against
# pymc_extras.
register_deserialization(is_alternative_prior, deserialize_alternative_prior)