A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://mlflow.org/docs/latest/python_api/../_modules/mlflow/tracking/_model_registry/utils.html below:

mlflow.tracking._model_registry.utils

from functools import partial
from typing import Optional

from mlflow.environment_variables import MLFLOW_REGISTRY_URI
from mlflow.store.db.db_types import DATABASE_ENGINES
from mlflow.store.model_registry.databricks_workspace_model_registry_rest_store import (
    DatabricksWorkspaceModelRegistryRestStore,
)
from mlflow.store.model_registry.file_store import FileStore
from mlflow.store.model_registry.rest_store import RestStore
from mlflow.tracking._model_registry.registry import ModelRegistryStoreRegistry
from mlflow.tracking._tracking_service.utils import (
    _resolve_tracking_uri,
)
from mlflow.utils._spark_utils import _get_active_spark_session
from mlflow.utils.credentials import get_default_host_creds
from mlflow.utils.databricks_utils import (
    get_databricks_host_creds,
    is_in_databricks_serverless_runtime,
    warn_on_deprecated_cross_workspace_registry_uri,
)
from mlflow.utils.uri import (
    _DATABRICKS_UNITY_CATALOG_SCHEME,
    _OSS_UNITY_CATALOG_SCHEME,
    construct_db_uc_uri_from_profile,
    get_db_info_from_uri,
    is_databricks_uri,
)

# NOTE: in contrast to tracking, the model registry cannot be configured through
# tracking-style environment variables such as MLFLOW_TRACKING_USERNAME, ...
# The model registry URI can be specified
#  - via the ``model_registry_uri`` parameter when creating an ``MlflowClient`` or
#    ``ModelRegistryClient``;
#  - via the utility method ``mlflow.set_registry_uri``;
#  - via the ``MLFLOW_REGISTRY_URI`` environment variable, or the
#    ``spark.mlflow.modelRegistryUri`` configuration of an active Spark session;
#  - by not specifying anything: in this case the registry URI is derived from the
#    tracking URI. A Databricks tracking URI maps to the Databricks Unity Catalog
#    registry ("databricks-uc"); any other tracking URI is reused as-is. In the latter
#    case Tracking and Model Registry are backed by the same backend DB/REST server, but
#    we access them via different ``Store`` classes (e.g.
#    ``mlflow.store.tracking.SQLAlchemyStore`` & ``mlflow.store.model_registry.SQLAlchemyStore``).
# This means the following combinations are not supported:
#  - Tracking RestStore & Model Registry RestStore that use different credentials.

# Registry URI set explicitly via ``set_registry_uri``. ``None`` means "not set", in which case
# ``_get_registry_uri_from_context`` consults MLFLOW_REGISTRY_URI and the Spark session instead.
_registry_uri: Optional[str] = None


def set_registry_uri(uri: str) -> None:
    """Set the registry server URI. This method is especially useful if you have a registry server
    that's different from the tracking server.

    Args:
        uri: The registry URI. Supported forms include:

            - An empty string, in which case the registry URI falls back to the default
              derived from the tracking URI.
            - A local file path, prefixed with ``file:/``.
            - A database URI such as ``sqlite:////tmp/registry.db``.
            - An HTTP URI like ``https://my-tracking-server:5000`` or
              ``http://my-oss-uc-server:8080``.
            - A Databricks workspace, provided as the string ``"databricks"`` or, to use a
              Databricks CLI `profile <https://github.com/databricks/databricks-cli#installation>`_,
              ``"databricks://<profileName>"``.

            Passing an empty string (or ``None``) also removes any ``MLFLOW_REGISTRY_URI``
            value previously exported by this function.

    .. code-block:: python
        :caption: Example

        import mlflow

        # Set model registry uri, fetch the set uri, and compare
        # it with the tracking uri. They should be different
        mlflow.set_registry_uri("sqlite:////tmp/registry.db")
        mr_uri = mlflow.get_registry_uri()
        print(f"Current registry uri: {mr_uri}")
        tracking_uri = mlflow.get_tracking_uri()
        print(f"Current tracking uri: {tracking_uri}")

        # They should be different
        assert tracking_uri != mr_uri

    .. code-block:: text
        :caption: Output

        Current registry uri: sqlite:////tmp/registry.db
        Current tracking uri: file:///.../mlruns

    """
    global _registry_uri
    _registry_uri = uri
    if uri:
        # Set 'MLFLOW_REGISTRY_URI' environment variable
        # so that subprocess can inherit it.
        MLFLOW_REGISTRY_URI.set(uri)
    else:
        # Clear a previously exported value. Otherwise subprocesses would inherit a stale
        # registry URI, and after ``set_registry_uri(None)`` this process would keep reading
        # it back through ``_get_registry_uri_from_context``.
        MLFLOW_REGISTRY_URI.unset()


def _get_registry_uri_from_spark_session():
    """Look up a registry URI from the active Spark session, if there is one.

    Returns:
        ``None`` when no Spark session is active; ``"databricks-uc"`` when running on
        Databricks serverless; otherwise the ``spark.mlflow.modelRegistryUri`` Spark conf
        value (``None`` when it is unset or reading it raises ``AnalysisException``).
    """
    spark = _get_active_spark_session()
    if spark is not None:
        if is_in_databricks_serverless_runtime():
            # Connected to Databricks serverless: use the Unity Catalog registry.
            return "databricks-uc"

        # Imported here rather than at module level, presumably so pyspark stays optional.
        from pyspark.sql.utils import AnalysisException

        try:
            configured_uri = spark.conf.get("spark.mlflow.modelRegistryUri", None)
        except AnalysisException:
            # Serverless clusters reject session.conf.get() with an AnalysisException. This
            # can happen when DBConnect talks to a serverless cluster: the serverless check
            # above returns False there (as of 2025-06-07 it inspects an environment variable
            # that DBConnect does not set).
            configured_uri = None
        return configured_uri
    return None


def _get_registry_uri_from_context():
    """Return the registry URI configured in the current context, or ``None``.

    A URI set explicitly through ``set_registry_uri`` (even an empty string) wins outright.
    Otherwise the ``MLFLOW_REGISTRY_URI`` environment variable and then the active Spark
    session are consulted, and the first truthy value is returned.
    """
    if _registry_uri is not None:
        return _registry_uri
    # ``or`` short-circuits: the Spark session is only queried when the env var is empty.
    return MLFLOW_REGISTRY_URI.get() or _get_registry_uri_from_spark_session() or None


def _get_default_registry_uri_for_tracking_uri(tracking_uri: Optional[str]) -> Optional[str]:
    """
    Derive the registry URI to use when none has been configured explicitly.

    Databricks tracking URIs map to the Databricks Unity Catalog registry:

    - ``"databricks"``, or a Databricks URI without a profile, yields ``"databricks-uc"``.
    - A URI carrying a profile (and optional key prefix) yields the matching ``databricks-uc``
      URI for that profile.

    Any other tracking URI, including ``None``, is returned unchanged.

    Args:
        tracking_uri: The tracking URI to derive the default registry URI from.

    Returns:
        The default registry URI.
    """
    # For non-databricks tracking URIs, use the tracking URI as the registry URI.
    if tracking_uri is None or not is_databricks_uri(tracking_uri):
        return tracking_uri

    # Unity Catalog Model Registry is the recommended model registry offering on Databricks.
    if tracking_uri == "databricks":
        return _DATABRICKS_UNITY_CATALOG_SCHEME

    profile, key_prefix = get_db_info_from_uri(tracking_uri)
    if not profile:
        return _DATABRICKS_UNITY_CATALOG_SCHEME

    # Rebuild the profile string, including the key prefix when one is present.
    profile_spec = profile if not key_prefix else f"{profile}:{key_prefix}"
    return construct_db_uc_uri_from_profile(profile_spec)


def get_registry_uri() -> str:
    """Get the current registry URI.

    The registry URI is resolved in the following order:

    1. A URI set with ``mlflow.set_registry_uri``. An empty string skips straight to the
       tracking-URI default in step 4.
    2. The ``MLFLOW_REGISTRY_URI`` environment variable.
    3. The active Spark session: ``"databricks-uc"`` on Databricks serverless, otherwise the
       ``spark.mlflow.modelRegistryUri`` Spark configuration.
    4. A default derived from the tracking URI. Databricks tracking URIs (``"databricks"`` or
       ``"databricks://<profile>"``) map to the Databricks Unity Catalog registry
       (``"databricks-uc"``, keeping the profile if present). Any other tracking URI is used
       as the registry URI as-is.

    Returns:
        The registry URI.

    .. code-block:: python

        import mlflow

        # Get the current model registry uri
        mr_uri = mlflow.get_registry_uri()
        print(f"Current model registry uri: {mr_uri}")

        # Get the current tracking uri
        tracking_uri = mlflow.get_tracking_uri()
        print(f"Current tracking uri: {tracking_uri}")

        # With no registry URI configured and a non-Databricks tracking URI,
        # they should be the same
        assert mr_uri == tracking_uri

    .. code-block:: text

        Current model registry uri: file:///.../mlruns
        Current tracking uri: file:///.../mlruns

    """
    return _resolve_registry_uri()


def _resolve_registry_uri(
    registry_uri: Optional[str] = None, tracking_uri: Optional[str] = None
) -> Optional[str]:
    """
    Resolve the registry URI following the same logic as get_registry_uri().

    The first truthy candidate wins: the explicit ``registry_uri``, then the URI configured in
    the current context, then the default derived from the (resolved) tracking URI.
    """
    if registry_uri:
        return registry_uri

    context_uri = _get_registry_uri_from_context()
    if context_uri:
        return context_uri

    resolved_tracking_uri = _resolve_tracking_uri(tracking_uri)
    return _get_default_registry_uri_for_tracking_uri(resolved_tracking_uri)


def _get_sqlalchemy_store(store_uri, **_):
    """Build a SQLAlchemy-backed model registry store for a database URI.

    Like the other store builders in this module, extra keyword arguments are accepted and
    ignored. Presumably the store registry may invoke builders with additional keywords
    (e.g. a tracking URI), which would otherwise raise ``TypeError`` here.

    Args:
        store_uri: A database URI whose scheme is one of ``DATABASE_ENGINES``.

    Returns:
        A ``SqlAlchemyStore`` for ``store_uri``.
    """
    # Imported lazily, presumably so SQLAlchemy is only loaded when a database-backed
    # registry is actually used.
    from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore

    return SqlAlchemyStore(store_uri)


def _get_rest_store(store_uri, **_):
    """Build a REST model registry store for an ``http``/``https`` URI; extra kwargs are ignored.

    The store receives a zero-argument ``partial`` of ``get_default_host_creds`` bound to
    ``store_uri``, not the credentials themselves.
    """
    creds_provider = partial(get_default_host_creds, store_uri)
    return RestStore(creds_provider)


def _get_databricks_rest_store(store_uri, **_):
    """Build a Databricks workspace model registry REST store; extra kwargs are ignored.

    Calls ``warn_on_deprecated_cross_workspace_registry_uri`` for ``store_uri`` first. It
    presumably warns when the URI targets another workspace's registry.
    The store receives a ``partial`` of ``get_databricks_host_creds`` bound to ``store_uri``
    rather than resolved credentials.
    """
    warn_on_deprecated_cross_workspace_registry_uri(registry_uri=store_uri)
    creds_provider = partial(get_databricks_host_creds, store_uri)
    return DatabricksWorkspaceModelRegistryRestStore(creds_provider)


# Cache for the ``ModelRegistryStoreRegistry`` that ``_get_store_registry`` builds on first use.
# We define the global variable as `None` (instead of building the registry at import time) so
# that instantiating the store does not lead to circular dependency issues.
_model_registry_store_registry: Optional[ModelRegistryStoreRegistry] = None


def _get_file_store(store_uri, **_):
    """Build a file-based model registry store for ``store_uri``; extra kwargs are ignored."""
    file_store = FileStore(store_uri)
    return file_store


def _get_store_registry():
    """Return the model registry store registry, building and caching it on the first call.

    The registry maps URI schemes to store builders:

    - ``databricks`` -> ``_get_databricks_rest_store``
    - the Databricks Unity Catalog scheme -> ``UcModelRegistryStore``
    - the OSS Unity Catalog scheme -> ``UnityCatalogOssStore``
    - ``http`` / ``https`` -> ``_get_rest_store``
    - every scheme in ``DATABASE_ENGINES`` -> ``_get_sqlalchemy_store``
    - ``""`` / ``file`` -> ``_get_file_store``

    Builders contributed through entry points (``register_entrypoints``) are also registered.

    Returns:
        The cached ``ModelRegistryStoreRegistry`` instance.
    """
    global _model_registry_store_registry

    # Fast path: reuse the cached registry without executing the imports below.
    if _model_registry_store_registry is not None:
        return _model_registry_store_registry

    # Imported inside the function rather than at module level, presumably to avoid
    # circular imports while this module is being loaded.
    from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore
    from mlflow.store._unity_catalog.registry.uc_oss_rest_store import UnityCatalogOssStore

    # The registry is published to the global before it is populated. A re-entrant call made
    # during registration (if any) therefore gets this same instance instead of building another.
    registry = _model_registry_store_registry = ModelRegistryStoreRegistry()
    registry.register("databricks", _get_databricks_rest_store)
    # Registry URIs with the "databricks-uc" scheme are served by the Unity Catalog store.
    registry.register(_DATABRICKS_UNITY_CATALOG_SCHEME, UcModelRegistryStore)
    registry.register(_OSS_UNITY_CATALOG_SCHEME, UnityCatalogOssStore)

    for scheme in ("http", "https"):
        registry.register(scheme, _get_rest_store)

    for scheme in DATABASE_ENGINES:
        registry.register(scheme, _get_sqlalchemy_store)

    # URIs without a scheme are handled like ``file`` URIs.
    for scheme in ("", "file"):
        registry.register(scheme, _get_file_store)

    registry.register_entrypoints()
    return registry


def _get_store(store_uri=None, tracking_uri=None):
    """Return the model registry store for the given URIs.

    Args:
        store_uri: Registry URI, passed through unchanged to the store registry's
            ``get_store``.
        tracking_uri: Tracking URI, passed through unchanged to the store registry's
            ``get_store``.
    """
    store_registry = _get_store_registry()
    return store_registry.get_store(store_uri, tracking_uri)

RetroSearch is an open source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4