"""GitLab MLflow integration service.
Provides unified interface for experiment tracking and model registry
using GitLab's MLflow-compatible API.
GitLab acts as both experiment tracking server and model registry,
storing runs, metrics, parameters, and model artifacts.
Usage:
from ai4drpm.services.shared.mlflow_service import get_mlflow_service
service = get_mlflow_service()
if service.enabled:
with service.start_run("ai4drpm-classifiers", "data-1.0.1"):
mlflow.log_param("C", 10)
mlflow.log_metric("f1_score", 0.847)
mlflow.sklearn.log_model(model, artifact_path="")
"""
# Suppress GitPython warnings in environments without git installed.
# GitPython reads GIT_PYTHON_REFRESH when it is first imported, so this must run
# before any import that may (transitively) load it -- that is why it sits above
# the remaining imports. setdefault keeps any value already set in the environment.
import os
os.environ.setdefault("GIT_PYTHON_REFRESH", "quiet")
import logging
from contextlib import contextmanager
from typing import Any, Dict, Generator, Optional
from ai4drpm import config
# Module-level logger used by every message this service emits.
logger = logging.getLogger(__name__)
class GitLabMLflowService:
    """Service for GitLab MLflow experiment tracking and model registry.

    Uses GitLab as the MLflow backend via its API v4 compatibility layer.
    Supports both experiment tracking and model registry operations.

    Features:
        - Experiment tracking with runs, parameters, metrics
        - Model registry with semantic versioning (``gitlab.version`` tag)
        - Automatic CI/CD job linking when running in GitLab CI
        - Graceful degradation when MLflow is not configured

    Attributes:
        EXPERIMENT_CLASSIFIERS: Experiment name for classifier training
        EXPERIMENT_LLM_PIPELINES: Experiment name for LLM pipeline tracking
        EXPERIMENT_ML_INFERENCE: Experiment name for ML inference tracking
    """

    # Experiment name constants (read from config when the class is defined)
    EXPERIMENT_CLASSIFIERS = config.MLFLOW_EXPERIMENT_CLASSIFIERS
    EXPERIMENT_LLM_PIPELINES = config.MLFLOW_EXPERIMENT_LLM_PIPELINES
    EXPERIMENT_ML_INFERENCE = config.MLFLOW_EXPERIMENT_ML_INFERENCE

    # Model-version tag that carries the SemVer string of a model version.
    _VERSION_TAG = "gitlab.version"

    # Message of the RuntimeError raised by accessors that need a configured client.
    _NOT_CONFIGURED_MSG = (
        "MLflow not configured. Set MLFLOW_TRACKING_URI and "
        "MLFLOW_TRACKING_TOKEN environment variables."
    )

    def __init__(self):
        """Initialize from environment configuration.

        Reads MLFLOW_TRACKING_URI and MLFLOW_TRACKING_TOKEN from ``config``.
        If both are set, configures MLflow to use the GitLab backend;
        otherwise the service stays disabled (``enabled`` is False).
        """
        self.tracking_uri = config.MLFLOW_TRACKING_URI
        self.tracking_token = config.MLFLOW_TRACKING_TOKEN
        self._configured = False
        self._mlflow = None  # mlflow module, imported lazily in _configure()
        if self.tracking_uri and self.tracking_token:
            self._configure()
        else:
            logger.info(
                "MLflow tracking disabled. Set MLFLOW_TRACKING_URI and "
                "MLFLOW_TRACKING_TOKEN to enable GitLab Model Registry integration."
            )

    def _configure(self) -> None:
        """Configure MLflow to use the GitLab backend.

        Imports mlflow, exports the token as the MLFLOW_TRACKING_TOKEN
        environment variable (the MLflow client reads it from there) and sets
        the tracking URI. Any failure is logged and leaves the service
        disabled instead of raising.
        """
        try:
            import mlflow

            self._mlflow = mlflow
            # GitLab requires the token in the environment, not as an argument.
            os.environ["MLFLOW_TRACKING_TOKEN"] = self.tracking_token
            mlflow.set_tracking_uri(self.tracking_uri)
            self._configured = True
            # Log a truncated URI so long URIs are not written out in full.
            uri_display = (
                self.tracking_uri[:60] + "..."
                if len(self.tracking_uri) > 60
                else self.tracking_uri
            )
            logger.info(f"MLflow configured with GitLab tracking: {uri_display}")
        except ImportError:
            logger.error("mlflow package not installed. Run: poetry add mlflow")
            self._configured = False
        except Exception as e:
            logger.error(f"Failed to configure MLflow: {e}")
            self._configured = False

    def _ensure_configured(self) -> None:
        """Raise RuntimeError unless MLflow was configured successfully."""
        if not self._configured or self._mlflow is None:
            raise RuntimeError(self._NOT_CONFIGURED_MSG)

    @property
    def enabled(self) -> bool:
        """Check if MLflow tracking is enabled and configured.

        Returns:
            True if MLflow is configured and ready to use
        """
        return self._configured

    @property
    def mlflow(self):
        """Get the mlflow module.

        Returns:
            The mlflow module for direct API access

        Raises:
            RuntimeError: If MLflow is not configured
        """
        self._ensure_configured()
        return self._mlflow

    @property
    def client(self):
        """Get a new MLflow client for registry operations.

        Returns:
            MlflowClient instance for model registry operations

        Raises:
            RuntimeError: If MLflow is not configured
        """
        # Previously this silently built a client even when disabled,
        # contradicting the documented RuntimeError.
        self._ensure_configured()
        from mlflow import MlflowClient

        return MlflowClient()

    @contextmanager
    def start_run(
        self,
        experiment_name: str,
        run_name: Optional[str] = None,
        nested: bool = False,
        tags: Optional[Dict[str, str]] = None,
    ) -> Generator[Optional[Any], None, None]:
        """Context manager for an MLflow run with GitLab CI integration.

        Creates the experiment if it doesn't exist (via
        mlflow.set_experiment). When running in GitLab CI, it also tags the
        run with pipeline, job and commit information.

        Args:
            experiment_name: Name of the experiment (e.g., "ai4drpm-classifiers")
            run_name: Optional name for this run (e.g., "data-1.0.1")
            nested: If True, allows nested runs within an active run
            tags: Optional dict of tags to add to the run

        Yields:
            MLflow Run object if enabled, None otherwise (no-op mode)

        Example:
            with service.start_run("ai4drpm-classifiers", "data-1.0.1") as run:
                if run:
                    mlflow.log_param("C", 10)
                    mlflow.log_metric("f1_score", 0.847)
        """
        if not self.enabled:
            # No-op context manager when disabled
            logger.debug("MLflow disabled, skipping run tracking")
            yield None
            return
        mlflow = self.mlflow
        # Set or create experiment
        mlflow.set_experiment(experiment_name)
        with mlflow.start_run(run_name=run_name, nested=nested) as run:
            # Add GitLab CI integration tags
            self._add_ci_tags()
            # Add custom tags
            if tags:
                mlflow.set_tags(tags)
            logger.debug(f"Started MLflow run: {run.info.run_id} ({run_name})")
            yield run
            logger.debug(f"Finished MLflow run: {run.info.run_id}")

    def _add_ci_tags(self) -> None:
        """Tag the active run with GitLab CI metadata.

        Does nothing unless GITLAB_CI is set. Otherwise it copies whichever
        CI job, pipeline, commit and project variables are present into
        ``gitlab.*`` run tags, which links the run to its pipeline and commit.
        """
        if not os.getenv("GITLAB_CI"):
            return
        mlflow = self.mlflow
        ci_tags = {
            "gitlab.CI_JOB_ID": os.getenv("CI_JOB_ID"),
            "gitlab.CI_PIPELINE_ID": os.getenv("CI_PIPELINE_ID"),
            "gitlab.CI_COMMIT_SHA": os.getenv("CI_COMMIT_SHA"),
            "gitlab.CI_COMMIT_REF_NAME": os.getenv("CI_COMMIT_REF_NAME"),
            "gitlab.CI_PROJECT_PATH": os.getenv("CI_PROJECT_PATH"),
        }
        # Filter out unset variables
        ci_tags = {k: v for k, v in ci_tags.items() if v is not None}
        if ci_tags:
            mlflow.set_tags(ci_tags)
            logger.debug(f"Added GitLab CI tags: {list(ci_tags.keys())}")

    # =========================================================================
    # Model Registry Operations
    # =========================================================================

    def list_registered_models(self) -> list[dict]:
        """List all registered models from the GitLab Model Registry.

        Two discovery methods are supported, selected by
        config.MLFLOW_REGISTRY_DISCOVERY_METHOD:
            - "mlflow_client": uses the MLflow Python client (for when GitLab
              fixes compatibility)
            - any other value (e.g. "http"): direct REST API call, which works
              with GitLab

        Returns:
            List of dicts with 'name' and 'latest_version' keys.
            Empty list if MLflow is disabled or the request fails.
        """
        if not self.enabled:
            return []
        method = config.MLFLOW_REGISTRY_DISCOVERY_METHOD
        if method == "mlflow_client":
            return self._list_models_via_client()
        return self._list_models_via_http()

    def _list_models_via_http(self) -> list[dict]:
        """List models using a direct HTTP call to the MLflow REST API.

        This bypasses the MLflow Python client, which sends parameters that
        GitLab's compatibility layer rejects. As in _list_models_via_client,
        the reported version prefers the ``gitlab.version`` tag of the last
        entry in ``latest_versions`` and falls back to its MLflow version.
        Failures are logged and produce an empty list.
        """
        import requests

        url = f"{self.tracking_uri.rstrip('/')}{config.MLFLOW_REGISTRY_SEARCH_PATH}"
        headers = {"Authorization": f"Bearer {self.tracking_token}"}
        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            data = response.json()
            results = []
            # `or []` also covers keys present with a JSON null value.
            for model_data in data.get("registered_models") or []:
                name = model_data.get("name")
                if not name:
                    continue
                latest_version = None
                latest_versions = model_data.get("latest_versions") or []
                if latest_versions:
                    mv = latest_versions[-1]
                    latest_version = (
                        self._json_tag_value(mv.get("tags"), self._VERSION_TAG)
                        or mv.get("version")
                    )
                results.append({
                    "name": name,
                    "latest_version": latest_version,
                })
            return results
        except Exception as e:
            logger.warning(f"Failed to list registered models via HTTP: {e}")
            return []

    def _list_models_via_client(self) -> list[dict]:
        """List models using the MLflow Python client.

        Use when GitLab's MLflow compatibility improves to support
        search_registered_models() properly. Failures are logged and produce
        an empty list.
        """
        try:
            client = self.client
            registered_models = client.search_registered_models()
            results = []
            for model in registered_models:
                latest_version = None
                if model.latest_versions:
                    mv = model.latest_versions[-1]
                    latest_version = self._version_tag(mv) or mv.version
                results.append({
                    "name": model.name,
                    "latest_version": latest_version,
                })
            return results
        except Exception as e:
            logger.warning(f"Failed to list registered models via MLflow client: {e}")
            return []

    def register_model(
        self,
        model_name: str,
        version: str,
        description: Optional[str] = None,
        run_id: Optional[str] = None,
    ) -> Optional[Any]:
        """Register a new model version in the GitLab Model Registry.

        Creates the model if it doesn't exist, then creates a new version.

        IMPORTANT: On GitLab, each model version IS its own run. The run_id
        parameter is ignored by GitLab's create_model_version. To log
        artifacts to a model version, use the run_id returned by
        get_model_version() after calling this method.

        Args:
            model_name: Name of the model (e.g., "data_classifier")
            version: Semantic version string (e.g., "1.0.0")
            description: Optional description for this version
            run_id: Ignored by GitLab (kept for API compatibility)

        Returns:
            ModelVersion object if successful, None if MLflow disabled.
            Use model_version.run_id to log artifacts to this version.
        """
        if not self.enabled:
            return None
        client = self.client
        # Ensure the registered model exists; any lookup failure is treated as
        # "missing" and answered with a create call.
        try:
            client.get_registered_model(model_name)
            logger.debug(f"Model '{model_name}' exists in registry")
        except Exception:
            client.create_registered_model(model_name, description=description)
            logger.info(f"Created model '{model_name}' in GitLab Model Registry")
        # Create the version with its SemVer tag.
        # Note: GitLab ignores the source and run_id params.
        tags = {self._VERSION_TAG: version}
        model_version = client.create_model_version(
            model_name,
            source="",
            description=description,
            tags=tags,
        )
        logger.info(
            f"Registered model version: {model_name}/{version} "
            f"(version_run_id={model_version.run_id})"
        )
        return model_version

    def get_latest_model_version(self, model_name: str) -> Optional[str]:
        """Get the latest version of a model from the registry.

        Uses get_registered_model(), which is reliably supported by GitLab's
        MLflow API (search_model_versions returns 404 on some GitLab
        instances).

        Candidates are grouped as follows, and the highest version (see
        _max_semver) from the first non-empty group wins:
            1. ``gitlab.version`` tags of versions that also have a run_id
               (artifacts available);
            2. ``gitlab.version`` tags of any version;
            3. plain MLflow version numbers of untagged versions.
        MLflow version numbers and SemVer tags are never compared with each
        other; otherwise MLflow version "3" would outrank tag "1.0.5".

        Args:
            model_name: Name of the model

        Returns:
            Latest version string (e.g., "1.0.1"). Returns None if the model
            has no versions, MLflow is disabled, or the lookup fails.
        """
        if not self.enabled:
            return None
        try:
            model = self.client.get_registered_model(model_name)
            tagged_with_run: list[str] = []
            tagged: list[str] = []
            untagged: list[str] = []
            for mv in model.latest_versions or []:
                tag_version = self._version_tag(mv)
                if tag_version:
                    tagged.append(tag_version)
                    if mv.run_id:
                        tagged_with_run.append(tag_version)
                else:
                    untagged.append(mv.version)
            for candidates in (tagged_with_run, tagged, untagged):
                if candidates:
                    return self._max_semver(candidates)
            return None
        except Exception as e:
            logger.warning(f"Could not fetch latest version for {model_name}: {e}")
            return None

    def load_model(self, model_name: str, version: str = "latest") -> Any:
        """Load a model from the GitLab Model Registry.

        Loads ``models:/<model_name>/<version>`` through
        mlflow.pyfunc.load_model.

        Args:
            model_name: Name of the model
            version: Version to load ("latest" or specific like "1.0.1")

        Returns:
            Loaded model object

        Raises:
            RuntimeError: If MLflow is not configured
            Exception: If model cannot be loaded
        """
        if not self.enabled:
            raise RuntimeError("MLflow not configured")
        model_uri = f"models:/{model_name}/{version}"
        logger.info(f"Loading model from GitLab: {model_uri}")
        return self.mlflow.pyfunc.load_model(model_uri)

    def log_sklearn_model(
        self,
        model: Any,
        model_name: str,
        version: Optional[str] = None,
        register: bool = True,
    ) -> Optional[str]:
        """Log an sklearn model to the current run and optionally register it.

        Convenience method that combines logging and registration.

        Args:
            model: Trained sklearn model
            model_name: Name for the model in registry
            version: Optional explicit version. Without one, the registry's
                latest version is returned after logging.
            register: Whether to register in model registry

        Returns:
            Version string if registered, None otherwise
        """
        if not self.enabled:
            return None
        mlflow = self.mlflow
        # Log model artifact
        mlflow.sklearn.log_model(
            model,
            artifact_path="",  # Required empty for GitLab
            registered_model_name=model_name if register else None,
        )
        # NOTE(review): with register=True and an explicit version, log_model
        # (registered_model_name=...) presumably already creates a registry
        # version and register_model() below creates another one. Confirm
        # whether GitLab ends up with two versions per call.
        if register and version:
            # Register with explicit version
            self.register_model(model_name, version)
            return version
        if register:
            # Return auto-assigned version
            return self.get_latest_model_version(model_name)
        return None

    # =========================================================================
    # Utility Methods
    # =========================================================================

    @classmethod
    def _version_tag(cls, model_version: Any) -> Optional[str]:
        """Return the ``gitlab.version`` tag of an MLflow ModelVersion, if any."""
        tags = model_version.tags
        return tags.get(cls._VERSION_TAG) if tags else None

    @staticmethod
    def _json_tag_value(tags: Any, key: str) -> Optional[str]:
        """Look up a tag value in a REST-API ModelVersion ``tags`` field.

        The MLflow REST API encodes tags as a list of ``{"key", "value"}``
        objects. A plain mapping is accepted as well, to be lenient.

        Returns:
            The tag value, or None if absent or ``tags`` has another shape.
        """
        if isinstance(tags, dict):
            return tags.get(key)
        if isinstance(tags, list):
            for tag in tags:
                if isinstance(tag, dict) and tag.get("key") == key:
                    return tag.get("value")
        return None

    @staticmethod
    def _max_semver(versions: list[str]) -> str:
        """Return the highest version from a list of version strings.

        Dot-separated integer versions compare numerically per component, so
        "1.1.11" > "1.1.9". Any purely numeric version outranks a non-numeric
        one. Non-numeric versions are compared among themselves as plain
        strings.

        Args:
            versions: List of version strings (must be non-empty)

        Returns:
            The highest version string

        Raises:
            ValueError: If ``versions`` is empty (from ``max``)
        """

        def _sort_key(v: Any) -> tuple:
            text = str(v)
            try:
                # The leading 1 ranks numeric versions above non-numeric ones
                # and avoids comparing int tuples with strings.
                return (1, tuple(int(p) for p in text.split(".")), "")
            except ValueError:
                return (0, (), text)

        return max(versions, key=_sort_key)

    def increment_version(self, version: str, bump: str = "patch") -> str:
        """Increment a semantic version string.

        Args:
            version: Current version (e.g., "1.0.0")
            bump: Which component to bump ("major", "minor", "patch"); any
                other value is treated as "patch"

        Returns:
            Incremented version string, or "1.0.0" if version is not three
            dot-separated integers

        Example:
            increment_version("1.0.0", "patch") → "1.0.1"
            increment_version("1.0.1", "minor") → "1.1.0"
            increment_version("1.1.0", "major") → "2.0.0"
        """
        try:
            parts = [int(p) for p in version.split(".")]
            if len(parts) != 3:
                return "1.0.0"
        except (ValueError, AttributeError):
            return "1.0.0"
        if bump == "major":
            parts[0] += 1
            parts[1] = 0
            parts[2] = 0
        elif bump == "minor":
            parts[1] += 1
            parts[2] = 0
        else:  # patch
            parts[2] += 1
        return ".".join(str(p) for p in parts)

    def get_next_version(self, model_name: str, bump: str = "patch") -> str:
        """Determine the next version for a model from the latest in the registry.

        Args:
            model_name: Name of the model
            bump: Version component to bump ("major", "minor", "patch")

        Returns:
            Next version string (e.g., "1.0.1"), or "1.0.0" if no versions exist
        """
        latest = self.get_latest_model_version(model_name)
        if latest:
            return self.increment_version(latest, bump)
        return "1.0.0"

    def get_model_version_run_id(
        self, model_name: str, version: str | None = None
    ) -> tuple[str, str] | None:
        """Get (run_id, resolved_version) for a model version.

        On GitLab, each model version IS its own run. The run_id is available
        via get_model_version(), not via get_registered_model(), which returns
        run_id=None in latest_versions.

        Args:
            model_name: Name of the model in the registry
            version: Specific version to look up, or None for latest

        Returns:
            Tuple of (run_id, resolved_version) if found, None otherwise
        """
        if not self.enabled:
            return None
        # Resolve version if not specified
        resolved_version = version
        if resolved_version is None:
            resolved_version = self.get_latest_model_version(model_name)
            if resolved_version is None:
                logger.warning(f"No version found for model '{model_name}'")
                return None
        # On GitLab, get_model_version returns the version's own run_id
        try:
            mv = self.client.get_model_version(model_name, resolved_version)
            if mv and mv.run_id:
                logger.debug(
                    f"Found run_id={mv.run_id} for {model_name} "
                    f"v{resolved_version}"
                )
                return (mv.run_id, resolved_version)
            logger.warning(
                f"No run_id for model '{model_name}' v{resolved_version}"
            )
        except Exception as e:
            logger.warning(
                f"Could not get model version for {model_name} "
                f"v{resolved_version}: {e}"
            )
        return None

    def load_model_from_run(self, run_id: str) -> Any:
        """Load an sklearn model from a specific MLflow run.

        Downloads the run's artifacts into a temporary directory, which is
        removed once loading finishes, and loads ``model.joblib`` from it.
        If that file is absent, falls back to mlflow.sklearn.load_model for
        models logged via mlflow.sklearn.log_model.

        Args:
            run_id: The MLflow run ID containing the model artifact

        Returns:
            Loaded sklearn model object

        Raises:
            RuntimeError: If MLflow is not configured
            FileNotFoundError: If neither model.joblib nor an MLflow sklearn
                model can be loaded from the run
        """
        import tempfile

        import joblib

        self._ensure_configured()
        # joblib.load unpickles: only load artifacts from a trusted registry.
        # A TemporaryDirectory keeps repeated loads from leaking downloads.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_dir = self.client.download_artifacts(run_id, "", dst_path=tmp_dir)
            model_path = os.path.join(local_dir, "model.joblib")
            if os.path.exists(model_path):
                model = joblib.load(model_path)
                logger.info(f"Loaded model from MLflow run {run_id} (model.joblib)")
                return model
        # Fallback: try mlflow sklearn format (for older models)
        try:
            model = self.mlflow.sklearn.load_model(f"runs:/{run_id}/")
        except Exception as e:
            raise FileNotFoundError(
                f"No model artifact found in run {run_id}: {e}"
            ) from e
        logger.info(f"Loaded sklearn model from MLflow run {run_id} (sklearn format)")
        return model

    def load_vectorizer_from_run(self, run_id: str) -> Any:
        """Download and load the vectorizer joblib artifact from a run.

        Looks for ``vectorizer.joblib`` first (new format). For backward
        compatibility it then falls back to the first ``.joblib`` file, in
        sorted name order, that isn't ``model.joblib``. Artifacts are
        downloaded into a temporary directory that is removed afterwards.

        Args:
            run_id: The MLflow run ID containing the vectorizer artifact

        Returns:
            Loaded vectorizer object

        Raises:
            RuntimeError: If MLflow is not configured
            FileNotFoundError: If no suitable .joblib artifact exists in the run
        """
        import tempfile

        import joblib

        self._ensure_configured()
        # joblib.load unpickles: only load artifacts from a trusted registry.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_dir = self.client.download_artifacts(run_id, "", dst_path=tmp_dir)
            # Try exact name first
            vectorizer_path = os.path.join(local_dir, "vectorizer.joblib")
            if os.path.exists(vectorizer_path):
                vectorizer = joblib.load(vectorizer_path)
                logger.info(
                    f"Loaded vectorizer from MLflow run {run_id} "
                    f"(vectorizer.joblib)"
                )
                return vectorizer
            # Fallback: any .joblib that isn't model.joblib. Sorting makes the
            # choice deterministic, unlike raw os.listdir order.
            for filename in sorted(os.listdir(local_dir)):
                if filename.endswith(".joblib") and filename != "model.joblib":
                    path = os.path.join(local_dir, filename)
                    vectorizer = joblib.load(path)
                    logger.info(
                        f"Loaded vectorizer from MLflow run {run_id} "
                        f"(artifact: {filename})"
                    )
                    return vectorizer
        raise FileNotFoundError(
            f"No vectorizer artifact found in run {run_id}"
        )
# =============================================================================
# Singleton Instance
# =============================================================================
# Process-wide cached service: None until get_mlflow_service() first creates it,
# and reset back to None by reset_mlflow_service().
_mlflow_service: Optional[GitLabMLflowService] = None
def get_mlflow_service() -> GitLabMLflowService:
    """Return the shared MLflow service, building it on first use.

    The instance is cached in the module-level ``_mlflow_service``. Repeated
    calls therefore hand back the very same object until
    reset_mlflow_service() clears the cache.

    Returns:
        The process-wide GitLabMLflowService instance.

    Example:
        service = get_mlflow_service()
        if service.enabled:
            with service.start_run("experiment", "run-name"):
                ...
    """
    global _mlflow_service
    # Fast path: already built.
    if _mlflow_service is not None:
        return _mlflow_service
    _mlflow_service = GitLabMLflowService()
    return _mlflow_service
def reset_mlflow_service() -> None:
    """Forget the cached MLflow service so it is rebuilt on next access.

    Intended for tests: after this call, the next get_mlflow_service()
    constructs a fresh GitLabMLflowService from the current configuration.
    """
    global _mlflow_service
    _mlflow_service = None