register_model_provider
inference_models.developer_tools.register_model_provider
Register a custom weights provider for model retrieval.
Once registered, a provider can be used with AutoModel.from_pretrained() and
get_model_from_provider(), which enables loading models from custom sources
beyond the default Roboflow provider.
Parameters:
- provider_name (str) – Unique name for the provider. It is used as the
  weights_provider parameter in AutoModel.from_pretrained().
- provider_handler (WeightsProvider) – Callable that implements the provider
  interface. It must accept (model_id: str, api_key: Optional[str], **kwargs)
  and return a ModelMetadata object.
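The handler contract above can be checked before a handler is registered. The following is a minimal sketch, not part of inference_models: the helper name looks_like_provider_handler is hypothetical, and it inspects only the call shape (parameters named model_id and api_key, plus **kwargs), not whether the handler actually returns a ModelMetadata object.

```python
import inspect
from typing import Any, Callable, Optional


def looks_like_provider_handler(handler: Callable[..., Any]) -> bool:
    """Hypothetical helper: check that `handler` declares the parameters named
    in the provider contract (model_id, api_key, and **kwargs).

    Only the signature is inspected; the return type is not verified.
    """
    try:
        params = inspect.signature(handler).parameters
    except (TypeError, ValueError):
        # Non-callables (and some builtins) have no inspectable signature.
        return False
    has_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    return "model_id" in params and "api_key" in params and has_var_kwargs


# Example handlers used to exercise the check.
def conforming_handler(model_id: str, api_key: Optional[str] = None, **kwargs: Any) -> None:
    return None


def handler_without_kwargs(model_id: str, api_key: Optional[str] = None) -> None:
    return None
```

A check like this catches handlers that would raise a TypeError if extra keyword arguments (such as base_path in the local file example below) are forwarded to them.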
Examples:
Register a custom provider:
>>> from inference_models.developer_tools import (
... register_model_provider,
... ModelMetadata,
... ModelPackageMetadata,
... ONNXPackageDetails
... )
>>> from inference_models import BackendType, Quantization
>>> from typing import Optional
>>>
>>> def my_custom_provider(model_id: str, api_key: Optional[str] = None, **kwargs):
... # Fetch model metadata from your custom source
... return ModelMetadata(
... model_id=model_id,
... model_packages=[
... ModelPackageMetadata(
... backend_type=BackendType.ONNX,
... quantization=Quantization.FP32,
... package_details=ONNXPackageDetails(
... download_url="https://my-server.com/model.onnx",
... md5_hash="abc123...",
... # ... other details
... ),
... # ... other metadata
... )
... ],
... dependencies=[]
... )
>>>
>>> # Register the provider
>>> register_model_provider("my_provider", my_custom_provider)
>>>
>>> # Now use it with AutoModel
>>> from inference_models import AutoModel
>>> model = AutoModel.from_pretrained(
... "my-model-id",
... weights_provider="my_provider"
... )
Register a provider for the local file system:
>>> from typing import Optional
>>> def local_file_provider(model_id: str, api_key: Optional[str] = None, **kwargs):
... base_path = kwargs.get("base_path", "/models")
... model_path = f"{base_path}/{model_id}"
...
... # Return metadata pointing to local files
... return ModelMetadata(
... model_id=model_id,
... model_packages=[...], # Configure packages pointing at files under model_path
... dependencies=[]
... )
>>>
>>> register_model_provider("local", local_file_provider)
>>>
>>> model = AutoModel.from_pretrained(
... "yolov8n",
... weights_provider="local",
... base_path="/my/models"
... )
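The local file example leaves its packages elided, so it does not show how missing files are reported. The sketch below illustrates the kind of error handling such a provider might do. It assumes a layout of <base_path>/<model_id>/model.onnx (the model.onnx file name is hypothetical), and StandInModelMetadata is a hypothetical stand-in for ModelMetadata so that the snippet runs without inference_models installed.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional


@dataclass
class StandInModelMetadata:
    """Hypothetical stand-in for ModelMetadata; not the library class."""
    model_id: str
    weights_path: Path
    dependencies: List[Any] = field(default_factory=list)


def local_file_provider(
    model_id: str, api_key: Optional[str] = None, **kwargs: Any
) -> StandInModelMetadata:
    # api_key is accepted to match the provider signature; local files need no auth.
    base_path = Path(kwargs.get("base_path", "/models"))
    # Reject ids that could escape base_path (absolute paths or "..").
    id_path = Path(model_id)
    if id_path.is_absolute() or ".." in id_path.parts:
        raise ValueError(f"Invalid model_id: {model_id!r}")
    # Assumed layout: <base_path>/<model_id>/model.onnx (hypothetical file name).
    weights_path = base_path / id_path / "model.onnx"
    if not weights_path.is_file():
        # Fail early, naming the exact path that was checked.
        raise FileNotFoundError(f"No weights for {model_id!r} at {weights_path}")
    return StandInModelMetadata(model_id=model_id, weights_path=weights_path)
```

Raising early with the checked path makes a misconfigured base_path much easier to diagnose than a failure deep inside model loading.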
Note
- Provider handlers must return a ModelMetadata object.
- Provider names should be unique: registering a name that is already in use
  overrides the existing provider.
- Provider handlers should handle authentication and error cases.
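The override rule in the Note can be illustrated with a plain dictionary keyed by provider name. This is a hypothetical sketch, not the inference_models implementation: register_provider and fetch_metadata are illustrative names, and the warning emitted on override is a design choice of this sketch that the library may not share.

```python
import warnings
from typing import Any, Callable, Dict, Optional

# Hypothetical in-memory registry mapping provider names to handlers.
_PROVIDERS: Dict[str, Callable[..., Any]] = {}


def register_provider(name: str, handler: Callable[..., Any]) -> None:
    if name in _PROVIDERS:
        # Registering an existing name replaces the previous handler.
        warnings.warn(f"Overriding existing provider {name!r}", stacklevel=2)
    _PROVIDERS[name] = handler


def fetch_metadata(
    name: str, model_id: str, api_key: Optional[str] = None, **kwargs: Any
) -> Any:
    try:
        handler = _PROVIDERS[name]
    except KeyError:
        raise ValueError(f"Unknown weights provider: {name!r}") from None
    # Extra keyword arguments are passed straight through to the handler.
    return handler(model_id=model_id, api_key=api_key, **kwargs)
```

Silently overwriting keeps re-registration during development cheap; warning on override, as here, makes accidental name collisions visible.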
See Also
- get_model_from_provider(): Retrieve models using registered providers
- ModelMetadata: Structure for model metadata
- ModelPackageMetadata: Structure for package metadata