# Source code for ai4drpm.services.shared.mlflow_service

"""GitLab MLflow integration service.

Provides unified interface for experiment tracking and model registry
using GitLab's MLflow-compatible API.

GitLab acts as both experiment tracking server and model registry,
storing runs, metrics, parameters, and model artifacts.

Usage:
    from ai4drpm.services.shared.mlflow_service import get_mlflow_service

    service = get_mlflow_service()

    if service.enabled:
        # Use the module exposed by the service so it is already configured
        mlflow = service.mlflow
        with service.start_run("ai4drpm-classifiers", "data-1.0.1"):
            mlflow.log_param("C", 10)
            mlflow.log_metric("f1_score", 0.847)
            mlflow.sklearn.log_model(model, artifact_path="")
"""

# Suppress GitPython warnings in environments without git installed
import os

os.environ.setdefault("GIT_PYTHON_REFRESH", "quiet")

import logging
from contextlib import contextmanager
from typing import Any, Dict, Generator, Optional

from ai4drpm import config

logger = logging.getLogger(__name__)


class GitLabMLflowService:
    """Service for GitLab MLflow experiment tracking and model registry.

    Uses GitLab as the MLflow backend via their API v4 compatibility layer.
    Supports both experiment tracking and model registry operations.

    Features:
        - Experiment tracking with runs, parameters, metrics
        - Model registry with semantic versioning
        - Automatic CI/CD job linking when running in GitLab CI
        - Graceful degradation when MLflow is not configured

    Attributes:
        EXPERIMENT_CLASSIFIERS: Experiment name for classifier training
        EXPERIMENT_LLM_PIPELINES: Experiment name for LLM pipeline tracking
        EXPERIMENT_ML_INFERENCE: Experiment name for ML inference tracking
    """

    # Experiment name constants (from config)
    EXPERIMENT_CLASSIFIERS = config.MLFLOW_EXPERIMENT_CLASSIFIERS
    EXPERIMENT_LLM_PIPELINES = config.MLFLOW_EXPERIMENT_LLM_PIPELINES
    EXPERIMENT_ML_INFERENCE = config.MLFLOW_EXPERIMENT_ML_INFERENCE

    def __init__(self) -> None:
        """Initialize with environment configuration.

        Reads MLFLOW_TRACKING_URI and MLFLOW_TRACKING_TOKEN from config.
        If both are set, configures MLflow to use GitLab backend.
        """
        self.tracking_uri = config.MLFLOW_TRACKING_URI
        self.tracking_token = config.MLFLOW_TRACKING_TOKEN
        self._configured = False
        self._mlflow = None  # Lazy import

        if self.tracking_uri and self.tracking_token:
            self._configure()
        else:
            logger.info(
                "MLflow tracking disabled. Set MLFLOW_TRACKING_URI and "
                "MLFLOW_TRACKING_TOKEN to enable GitLab Model Registry integration."
            )

    def _configure(self) -> None:
        """Configure MLflow to use GitLab backend.

        Sets the tracking URI and token environment variable.
        The token is passed via env var as required by MLflow client.
        Any failure is logged and leaves the service disabled instead of
        raising.
        """
        try:
            import mlflow

            self._mlflow = mlflow

            # GitLab requires token in env var
            os.environ["MLFLOW_TRACKING_TOKEN"] = self.tracking_token
            mlflow.set_tracking_uri(self.tracking_uri)

            self._configured = True

            # Log truncated URI for security
            uri_display = (
                self.tracking_uri[:60] + "..."
                if len(self.tracking_uri) > 60
                else self.tracking_uri
            )
            logger.info(f"MLflow configured with GitLab tracking: {uri_display}")

        except ImportError:
            logger.error("mlflow package not installed. Run: poetry add mlflow")
            self._configured = False
        except Exception as e:
            logger.error(f"Failed to configure MLflow: {e}")
            self._configured = False

    def _require_configured(self) -> Any:
        """Return the mlflow module, or raise if the service is not configured.

        Single place for the "not configured" check so that every accessor
        raises the same error.

        Raises:
            RuntimeError: If MLflow is not configured
        """
        if not self._configured or self._mlflow is None:
            raise RuntimeError(
                "MLflow not configured. Set MLFLOW_TRACKING_URI and "
                "MLFLOW_TRACKING_TOKEN environment variables."
            )
        return self._mlflow

    @property
    def enabled(self) -> bool:
        """Check if MLflow tracking is enabled and configured.

        Returns:
            True if MLflow is configured and ready to use
        """
        return self._configured

    @property
    def mlflow(self):
        """Get the mlflow module.

        Returns:
            The mlflow module for direct API access

        Raises:
            RuntimeError: If MLflow is not configured
        """
        return self._require_configured()

    @property
    def client(self):
        """Get MLflow client for registry operations.

        A new client is created on every access.

        Returns:
            MlflowClient instance for model registry operations

        Raises:
            RuntimeError: If MLflow is not configured
        """
        # Enforce the documented contract: without this check an unconfigured
        # service would hand out a client bound to MLflow's default backend.
        self._require_configured()

        from mlflow import MlflowClient

        return MlflowClient()

    @contextmanager
    def start_run(
        self,
        experiment_name: str,
        run_name: Optional[str] = None,
        nested: bool = False,
        tags: Optional[Dict[str, str]] = None
    ) -> Generator[Optional[Any], None, None]:
        """Context manager for MLflow run with GitLab CI integration.

        Selects the experiment via mlflow.set_experiment() and tags the run
        with CI job information when running in GitLab CI.

        Args:
            experiment_name: Name of the experiment (e.g., "ai4drpm-classifiers")
            run_name: Optional name for this run (e.g., "data-1.0.1")
            nested: If True, allows nested runs within an active run
            tags: Optional dict of tags to add to the run

        Yields:
            MLflow Run object if enabled, None otherwise

        Example:
            with service.start_run("ai4drpm-classifiers", "data-1.0.1") as run:
                if run:
                    mlflow.log_param("C", 10)
                    mlflow.log_metric("f1_score", 0.847)
        """
        if not self.enabled:
            # No-op context manager when disabled
            logger.debug("MLflow disabled, skipping run tracking")
            yield None
            return

        mlflow = self.mlflow

        # Set or create experiment
        mlflow.set_experiment(experiment_name)

        with mlflow.start_run(run_name=run_name, nested=nested) as run:
            # Add GitLab CI integration tags
            self._add_ci_tags()

            # Add custom tags
            if tags:
                mlflow.set_tags(tags)

            logger.debug(f"Started MLflow run: {run.info.run_id} ({run_name})")
            yield run
            # Only reached when the caller's block finished without raising
            logger.debug(f"Finished MLflow run: {run.info.run_id}")

    def _add_ci_tags(self) -> None:
        """Add GitLab CI environment tags to current run.

        Does nothing unless the GITLAB_CI environment variable is set.
        Otherwise tags the active run with the job, pipeline, commit and
        project variables that are present in the environment.
        """
        if not os.getenv("GITLAB_CI"):
            return

        mlflow = self.mlflow

        ci_tags = {
            "gitlab.CI_JOB_ID": os.getenv("CI_JOB_ID"),
            "gitlab.CI_PIPELINE_ID": os.getenv("CI_PIPELINE_ID"),
            "gitlab.CI_COMMIT_SHA": os.getenv("CI_COMMIT_SHA"),
            "gitlab.CI_COMMIT_REF_NAME": os.getenv("CI_COMMIT_REF_NAME"),
            "gitlab.CI_PROJECT_PATH": os.getenv("CI_PROJECT_PATH"),
        }

        # Filter out None values
        ci_tags = {k: v for k, v in ci_tags.items() if v is not None}

        if ci_tags:
            mlflow.set_tags(ci_tags)
            logger.debug(f"Added GitLab CI tags: {list(ci_tags.keys())}")

    # =========================================================================
    # Model Registry Operations
    # =========================================================================

    def list_registered_models(self) -> list[dict]:
        """List all registered models from GitLab Model Registry.

        Supports two discovery methods controlled by
        MLFLOW_REGISTRY_DISCOVERY_METHOD:
        - "http": Direct REST API call (default, works with GitLab)
        - "mlflow_client": Uses MLflow Python client (for when GitLab fixes
          compatibility)

        Any value other than "mlflow_client" selects the HTTP path.

        Returns:
            List of dicts with 'name' and 'latest_version' keys.
            Empty list if MLflow is disabled or the request fails.
        """
        if not self.enabled:
            return []

        method = config.MLFLOW_REGISTRY_DISCOVERY_METHOD

        if method == "mlflow_client":
            return self._list_models_via_client()
        return self._list_models_via_http()

    def _list_models_via_http(self) -> list[dict]:
        """List models using direct HTTP call to MLflow REST API.

        This bypasses the MLflow Python client which sends parameters
        that GitLab's compatibility layer rejects.

        Returns:
            List of dicts with 'name' and 'latest_version' keys; entries
            without a name are skipped. Empty list on any failure.
        """
        import requests

        url = f"{self.tracking_uri.rstrip('/')}{config.MLFLOW_REGISTRY_SEARCH_PATH}"
        headers = {"Authorization": f"Bearer {self.tracking_token}"}

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            data = response.json()

            results = []
            for model_data in data.get("registered_models", []):
                name = model_data.get("name")
                if not name:
                    continue

                latest_version = None
                latest_versions = model_data.get("latest_versions", [])
                if latest_versions:
                    # The last entry is treated as the latest version
                    latest_version = latest_versions[-1].get("version")

                results.append({
                    "name": name,
                    "latest_version": latest_version,
                })

            return results

        except Exception as e:
            # Best-effort listing: callers get an empty list, not an error
            logger.warning(f"Failed to list registered models via HTTP: {e}")
            return []

    def _list_models_via_client(self) -> list[dict]:
        """List models using the MLflow Python client.

        Use when GitLab's MLflow compatibility improves to support
        search_registered_models() properly.

        Returns:
            List of dicts with 'name' and 'latest_version' keys. The version
            is the "gitlab.version" tag when present, otherwise the MLflow
            version. Empty list on any failure.
        """
        try:
            client = self.client
            registered_models = client.search_registered_models()

            results = []
            for model in registered_models:
                latest_version = None
                if model.latest_versions:
                    mv = model.latest_versions[-1]
                    latest_version = (
                        mv.tags.get("gitlab.version") if mv.tags else None
                    ) or mv.version

                results.append({
                    "name": model.name,
                    "latest_version": latest_version,
                })

            return results

        except Exception as e:
            logger.warning(f"Failed to list registered models via MLflow client: {e}")
            return []

    def register_model(
        self,
        model_name: str,
        version: str,
        description: Optional[str] = None,
        run_id: Optional[str] = None
    ) -> Optional[Any]:
        """Register a new model version in GitLab Model Registry.

        Creates the model if it doesn't exist, then creates a new version.

        IMPORTANT: On GitLab, each model version IS its own run. The run_id
        parameter is ignored by GitLab's create_model_version. To log
        artifacts to a model version, use the run_id returned by
        get_model_version() after calling this method.

        Args:
            model_name: Name of the model (e.g., "data_classifier")
            version: Semantic version string (e.g., "1.0.0")
            description: Optional description for this version
            run_id: Ignored by GitLab (kept for API compatibility); it is
                not forwarded to the client

        Returns:
            ModelVersion object if successful, None if MLflow disabled.
            Use model_version.run_id to log artifacts to this version.
        """
        if not self.enabled:
            return None

        client = self.client

        # Ensure model exists
        try:
            client.get_registered_model(model_name)
            logger.debug(f"Model '{model_name}' exists in registry")
        except Exception:
            # Any lookup failure is treated as "model missing"; a genuine
            # backend error will resurface from the create call below.
            client.create_registered_model(model_name, description=description)
            logger.info(f"Created model '{model_name}' in GitLab Model Registry")

        # Create version with SemVer tag
        # Note: GitLab ignores source and run_id params
        tags = {"gitlab.version": version}

        model_version = client.create_model_version(
            model_name,
            source="",
            description=description,
            tags=tags
        )

        logger.info(
            f"Registered model version: {model_name}/{version} "
            f"(version_run_id={model_version.run_id})"
        )
        return model_version

    def get_latest_model_version(self, model_name: str) -> Optional[str]:
        """Get latest version of a model from registry.

        Uses get_registered_model() which is reliably supported by GitLab's
        MLflow API (search_model_versions returns 404 on some GitLab
        instances).

        Each entry's version string is its "gitlab.version" tag when present,
        otherwise its MLflow version. Entries that carry a run_id (artifacts
        available) are preferred; if there are none, all remaining entries
        are considered. The highest version of the chosen group is returned.

        Args:
            model_name: Name of the model

        Returns:
            Latest version string (e.g., "1.0.1") or None if not found,
            if MLflow is disabled, or if the lookup fails
        """
        if not self.enabled:
            return None

        try:
            model = self.client.get_registered_model(model_name)
            if not model.latest_versions:
                return None

            # Collect all versions, split by whether a run_id is attached
            versions_with_run = []
            versions_without_run = []

            for v in model.latest_versions:
                tag_version = (
                    v.tags.get("gitlab.version") if v.tags else None
                )
                version_str = tag_version or v.version

                if v.run_id:
                    versions_with_run.append(version_str)
                else:
                    versions_without_run.append(version_str)

            # Prefer versions that have a run_id (artifacts available).
            # latest_versions is non-empty here, so one of the two lists
            # always has at least one entry.
            return self._max_semver(versions_with_run or versions_without_run)

        except Exception as e:
            logger.warning(f"Could not fetch latest version for {model_name}: {e}")
            return None

    def load_model(self, model_name: str, version: str = "latest") -> Any:
        """Load model from GitLab Model Registry.

        Builds a "models:/<name>/<version>" URI and loads it through
        mlflow.pyfunc.load_model.

        Args:
            model_name: Name of the model
            version: Version to load ("latest" or specific like "1.0.1")

        Returns:
            Loaded model object

        Raises:
            RuntimeError: If MLflow is not configured
            Exception: If model cannot be loaded
        """
        if not self.enabled:
            raise RuntimeError("MLflow not configured")

        model_uri = f"models:/{model_name}/{version}"
        logger.info(f"Loading model from GitLab: {model_uri}")

        return self.mlflow.pyfunc.load_model(model_uri)

    def log_sklearn_model(
        self,
        model: Any,
        model_name: str,
        version: Optional[str] = None,
        register: bool = True
    ) -> Optional[str]:
        """Log an sklearn model to the current run and optionally register it.

        Convenience method that combines logging and registration.

        Args:
            model: Trained sklearn model
            model_name: Name for the model in registry
            version: Optional version; when omitted the latest version found
                in the registry after logging is returned
            register: Whether to register in model registry

        Returns:
            Version string if registered, None otherwise
        """
        if not self.enabled:
            return None

        mlflow = self.mlflow

        # Log model artifact
        mlflow.sklearn.log_model(
            model,
            artifact_path="",  # Required empty for GitLab
            registered_model_name=model_name if register else None
        )

        if register and version:
            # Register with explicit version
            # NOTE(review): log_model above was already given
            # registered_model_name; confirm this second registration does
            # not create a duplicate version on GitLab.
            self.register_model(model_name, version)
            return version
        if register:
            # Return auto-assigned version
            return self.get_latest_model_version(model_name)

        return None

    # =========================================================================
    # Utility Methods
    # =========================================================================

    @staticmethod
    def _max_semver(versions: list[str]) -> str:
        """Return the highest semantic version from a list of version strings.

        Handles versions like "1.1.9", "1.1.11" correctly (numeric
        comparison). Versions that cannot be parsed as dot-separated integers
        rank as (0,), i.e. below any parsable version with a positive
        component; among equally ranked entries the first one wins.

        Args:
            versions: List of version strings (must be non-empty)

        Returns:
            The highest version string

        Raises:
            ValueError: If versions is empty
        """
        def _parse_version(v: str) -> tuple:
            try:
                return tuple(int(p) for p in v.split("."))
            except (ValueError, AttributeError):
                # Non-numeric component, or a value that is not a string
                return (0,)

        return max(versions, key=_parse_version)

    def increment_version(self, version: str, bump: str = "patch") -> str:
        """Increment a semantic version string.

        Args:
            version: Current version (e.g., "1.0.0")
            bump: Which component to bump ("major", "minor", "patch");
                any other value is treated as "patch"

        Returns:
            Incremented version string, or "1.0.0" if version is invalid
            (not exactly three dot-separated integers)

        Example:
            increment_version("1.0.0", "patch") → "1.0.1"
            increment_version("1.0.1", "minor") → "1.1.0"
            increment_version("1.1.0", "major") → "2.0.0"
        """
        try:
            parts = [int(p) for p in version.split(".")]
            if len(parts) != 3:
                return "1.0.0"
        except (ValueError, AttributeError):
            return "1.0.0"

        if bump == "major":
            parts[0] += 1
            parts[1] = 0
            parts[2] = 0
        elif bump == "minor":
            parts[1] += 1
            parts[2] = 0
        else:  # patch
            parts[2] += 1

        return ".".join(str(p) for p in parts)

    def get_next_version(self, model_name: str, bump: str = "patch") -> str:
        """Determine next version for a model based on latest in registry.

        Args:
            model_name: Name of the model
            bump: Version component to bump ("major", "minor", "patch")

        Returns:
            Next version string (e.g., "1.0.1"), or "1.0.0" if no versions
            exist
        """
        latest = self.get_latest_model_version(model_name)
        if latest:
            return self.increment_version(latest, bump)
        return "1.0.0"

    def get_model_version_run_id(
        self, model_name: str, version: str | None = None
    ) -> tuple[str, str] | None:
        """Get (run_id, resolved_version) for a model version.

        On GitLab, each model version IS its own run. The run_id is
        available via get_model_version() — not via get_registered_model()
        which returns run_id=None in latest_versions.

        Args:
            model_name: Name of the model in the registry
            version: Specific version to look up, or None for latest

        Returns:
            Tuple of (run_id, resolved_version) if found, None otherwise
            (MLflow disabled, no version resolved, no run_id, or lookup
            failure)
        """
        if not self.enabled:
            return None

        # Resolve version if not specified
        resolved_version = version
        if resolved_version is None:
            resolved_version = self.get_latest_model_version(model_name)
            if resolved_version is None:
                logger.warning(f"No version found for model '{model_name}'")
                return None

        # On GitLab, get_model_version returns the version's own run_id
        try:
            mv = self.client.get_model_version(model_name, resolved_version)
            if mv and mv.run_id:
                logger.debug(
                    f"Found run_id={mv.run_id} for {model_name} "
                    f"v{resolved_version}"
                )
                return (mv.run_id, resolved_version)
            logger.warning(
                f"No run_id for model '{model_name}' v{resolved_version}"
            )
        except Exception as e:
            logger.warning(
                f"Could not get model version for {model_name} "
                f"v{resolved_version}: {e}"
            )

        return None

    def load_model_from_run(self, run_id: str) -> Any:
        """Load sklearn model from a specific MLflow run.

        Downloads artifacts and loads the model.joblib file.
        Falls back to mlflow.sklearn.load_model for backward compatibility
        with models logged via mlflow.sklearn.log_model.

        Args:
            run_id: The MLflow run ID containing the model artifact

        Returns:
            Loaded sklearn model object

        Raises:
            RuntimeError: If MLflow is not configured
            FileNotFoundError: If neither model.joblib nor an sklearn-format
                model could be loaded from the run
        """
        import joblib

        self._require_configured()

        # Download artifacts and look for model.joblib
        # NOTE(review): MlflowClient.download_artifacts is deprecated in newer
        # MLflow releases in favour of mlflow.artifacts.download_artifacts;
        # confirm against the pinned MLflow version.
        local_dir = self.client.download_artifacts(run_id, "")
        model_path = os.path.join(local_dir, "model.joblib")

        if os.path.exists(model_path):
            # SECURITY: joblib.load unpickles the file and can execute
            # arbitrary code; only load artifacts from a trusted registry.
            model = joblib.load(model_path)
            logger.info(f"Loaded model from MLflow run {run_id} (model.joblib)")
            return model

        # Fallback: try mlflow sklearn format (for older models)
        try:
            model = self.mlflow.sklearn.load_model(f"runs:/{run_id}/")
            logger.info(f"Loaded sklearn model from MLflow run {run_id} (sklearn format)")
            return model
        except Exception as e:
            # Chain the cause so the underlying MLflow error stays visible
            raise FileNotFoundError(
                f"No model artifact found in run {run_id}: {e}"
            ) from e

    def load_vectorizer_from_run(self, run_id: str) -> Any:
        """Download and load the vectorizer joblib artifact from a run.

        Looks for vectorizer.joblib first (new format), then falls back to
        the alphabetically first .joblib file that isn't model.joblib
        (backward compat).

        Args:
            run_id: The MLflow run ID containing the vectorizer artifact

        Returns:
            Loaded vectorizer object

        Raises:
            RuntimeError: If MLflow is not configured
            FileNotFoundError: If the run contains no vectorizer artifact
        """
        import joblib

        self._require_configured()

        local_dir = self.client.download_artifacts(run_id, "")

        # Try exact name first
        vectorizer_path = os.path.join(local_dir, "vectorizer.joblib")
        if os.path.exists(vectorizer_path):
            # SECURITY: joblib.load unpickles the file and can execute
            # arbitrary code; only load artifacts from a trusted registry.
            vectorizer = joblib.load(vectorizer_path)
            logger.info(
                f"Loaded vectorizer from MLflow run {run_id} "
                f"(vectorizer.joblib)"
            )
            return vectorizer

        # Fallback: find any .joblib that isn't model.joblib.
        # os.listdir() order is arbitrary, so sort to make the choice
        # deterministic when several candidates exist.
        for filename in sorted(os.listdir(local_dir)):
            if filename.endswith(".joblib") and filename != "model.joblib":
                path = os.path.join(local_dir, filename)
                vectorizer = joblib.load(path)
                logger.info(
                    f"Loaded vectorizer from MLflow run {run_id} "
                    f"(artifact: {filename})"
                )
                return vectorizer

        raise FileNotFoundError(
            f"No vectorizer artifact found in run {run_id}"
        )
# =============================================================================
# Singleton Instance
# =============================================================================

# Module-wide cached service; stays None until the first get_mlflow_service()
_mlflow_service: Optional[GitLabMLflowService] = None


def get_mlflow_service() -> GitLabMLflowService:
    """Return the shared MLflow service, creating it on first use.

    Returns:
        GitLabMLflowService instance (same instance on repeated calls)

    Example:
        service = get_mlflow_service()
        if service.enabled:
            with service.start_run("experiment", "run-name"):
                ...
    """
    global _mlflow_service

    # Fast path: an instance was already built by an earlier call
    if _mlflow_service is not None:
        return _mlflow_service

    _mlflow_service = GitLabMLflowService()
    return _mlflow_service
def reset_mlflow_service() -> None:
    """Drop the cached MLflow service singleton (intended for tests).

    After this call the next get_mlflow_service() builds a fresh instance.
    """
    global _mlflow_service
    _mlflow_service = None