Source code for olive.resource_path

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import shutil
import tempfile
from abc import abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Union

from pydantic import Field, validator

from olive.azureml.azureml_client import AzureMLClientConfig
from olive.common.auto_config import AutoConfigClass
from olive.common.config_utils import ConfigBase, ConfigParam, serialize_to_json, validate_config
from olive.common.utils import retry_func

logger = logging.getLogger(__name__)


class ResourceType(str, Enum):
    """String-valued enumeration of the resource kinds known to this module.

    Because the enum also derives from ``str``, each member compares equal to
    its plain string value (for example ``ResourceType.LocalFile == "file"``).
    """

    # resources on the local file system
    LocalFile = "file"
    LocalFolder = "folder"
    # a bare string identifier
    StringName = "string_name"
    # resources hosted in AzureML
    AzureMLModel = "azureml_model"
    AzureMLDatastore = "azureml_datastore"
    AzureMLJobOutput = "azureml_job_output"

    def __str__(self) -> str:
        """Return the raw string value instead of ``ResourceType.<member>``."""
        return self.value
# Resource types that live on the local file system.
LOCAL_RESOURCE_TYPES = [ResourceType.LocalFile, ResourceType.LocalFolder]
# Resource types that are hosted in AzureML.
AZUREML_RESOURCE_TYPES = [ResourceType.AzureMLModel, ResourceType.AzureMLDatastore, ResourceType.AzureMLJobOutput]


class ResourcePath(AutoConfigClass):
    """Abstract base class for every resource path.

    Subclasses set ``name`` to their ``ResourceType`` and implement ``get_path``
    and ``save_to_dir``.
    """

    # Looked up by resource type in ResourcePathConfig.
    # NOTE(review): presumably populated by AutoConfigClass when subclasses are defined - confirm.
    registry: Dict[str, "ResourcePath"] = {}
    name: ResourceType = None

    @property
    def type(self) -> ResourceType:
        """Resource type of this path (alias of the class-level ``name``)."""
        return self.name

    @abstractmethod
    def get_path(self) -> str:
        """Return the resource path as a string."""
        raise NotImplementedError

    @abstractmethod
    def save_to_dir(self, dir_path: Union[Path, str], name: str = None, overwrite: bool = False) -> str:
        """Save the resource to a directory."""
        raise NotImplementedError

    def is_local_resource(self) -> bool:
        """Return True if the resource is a local resource."""
        return self.type in LOCAL_RESOURCE_TYPES

    def is_azureml_resource(self) -> bool:
        """Return True if the resource is an AzureML resource."""
        return self.type in AZUREML_RESOURCE_TYPES

    def is_string_name(self) -> bool:
        """Return True if the resource is a string name."""
        return self.type == ResourceType.StringName

    def to_json(self):
        """Serialize the resource path as a ``{"type": ..., "config": ...}`` structure."""
        json_data = {"type": self.type, "config": self.config.to_json()}
        return serialize_to_json(json_data)


class ResourcePathConfig(ConfigBase):
    """Config holding a resource type together with the config for that type."""

    type: ResourceType = Field(..., description="Type of the resource.")
    config: ConfigBase = Field(..., description="Config of the resource.")

    @validator("config", pre=True)
    def validate_config(cls, v, values):
        """Validate ``config`` against the config class registered for ``type``.

        :raises ValueError: if ``type`` is not present in the already-validated values.
        """
        if "type" not in values:
            raise ValueError("Invalid type.")

        config_class = ResourcePath.registry[values["type"]].get_config_class()
        # inside a method body this name resolves to the module-level validate_config helper
        return validate_config(v, ConfigBase, config_class)

    def create_resource_path(self) -> ResourcePath:
        """Instantiate the resource path class registered for ``type`` with ``config``."""
        return ResourcePath.registry[self.type](self.config)


def create_resource_path(
    resource_path: Union[str, Path, Dict[str, Any], ResourcePathConfig, ResourcePath]
) -> ResourcePath:
    """
    Create a resource path from a string or a dict.
    If a string is provided, it is inferred to be a file, folder, or string name.
    If a Path is provided, it is inferred to be a file or folder.
    If a dict is provided, it must have "type" and "config" fields. The "type" field must be one of the
    values in the ResourceType enum. The "config" field must be a dict that can be used to create a resource
    config of the specified type.

    :param resource_path: A string, a Path, or a dict.
    :return: A resource path.
    :raises ValueError: if a Path is provided that does not exist.
    """
    if isinstance(resource_path, ResourcePath):
        return resource_path

    if isinstance(resource_path, (ResourcePathConfig, dict)):
        resource_path_config = validate_config(resource_path, ResourcePathConfig)
        return resource_path_config.create_resource_path()

    if isinstance(resource_path, Path) and not resource_path.exists():
        # a Path can only denote a file or a folder, never a string name, so fail here
        raise ValueError(f"Resource path {resource_path} of type Path is not a file or folder.")

    # check if the resource path is a file, folder, or a string name
    resource_type: ResourceType = None
    config_key = "path"
    if Path(resource_path).is_file():
        resource_type = ResourceType.LocalFile
    elif Path(resource_path).is_dir():
        resource_type = ResourceType.LocalFolder
    else:
        resource_type = ResourceType.StringName
        config_key = "name"

    logger.debug(f"Resource path {resource_path} is inferred to be of type {resource_type}.")
    return ResourcePathConfig(type=resource_type, config={config_key: resource_path}).create_resource_path()


def _overwrite_helper(new_path: Union[Path, str], overwrite: bool):
    """Make sure nothing exists at ``new_path`` before a resource is saved there.

    :param new_path: Destination path of the resource.
    :param overwrite: If True, an existing file/folder at ``new_path`` is deleted.
    :raises FileExistsError: if ``new_path`` exists and ``overwrite`` is False.
    """
    new_path = Path(new_path).resolve()

    # nothing to do if the resource does not exist yet
    if not new_path.exists():
        return

    if not overwrite:
        # raise an error if the file/folder with same name already exists and overwrite is set to False
        # Olive doesn't know if the existing file/folder is the same as the one being saved
        # or if the user wants to overwrite the existing file/folder
        raise FileExistsError(
            f"Trying to save resource to {new_path} but a file/folder with the same name already exists and"
            " overwrite is set to False. If you want to overwrite the existing file/folder, set overwrite to True."
        )

    # delete the resource if it already exists
    if new_path.is_file():
        new_path.unlink()
    else:
        shutil.rmtree(new_path)


def _validate_path(v):
    """Validator: require that ``v`` exists and return it as a resolved ``Path``."""
    if not Path(v).exists():
        raise ValueError(f"Path {v} does not exist.")
    return Path(v).resolve()


class LocalResourcePath(ResourcePath):
    """Base class for a local resource path."""

    @staticmethod
    def _default_config() -> Dict[str, Any]:
        return {
            "path": ConfigParam(type_=Union[Path, str], required=True, description="Path to the resource."),
        }

    @staticmethod
    def _validators() -> Dict[str, Callable]:
        return {"validate_path": validator("path", allow_reuse=True)(_validate_path)}

    def get_path(self) -> str:
        """Return the configured path as a string."""
        return str(self.config.path)

    def save_to_dir(self, dir_path: Union[Path, str], name: str = None, overwrite: bool = False) -> str:
        """Copy the local file/folder into ``dir_path``.

        :param dir_path: Directory to copy into; created if it does not exist.
        :param name: Optional new name. Its suffix is replaced by the suffix of the source path.
        :param overwrite: Whether to replace an existing file/folder at the destination.
        :return: The destination path as a string.
        """
        # directory to save the resource to
        dir_path = Path(dir_path).resolve()
        dir_path.mkdir(parents=True, exist_ok=True)

        # path to save the resource to
        if name:
            new_path_name = Path(name).with_suffix(self.config.path.suffix).name
        else:
            new_path_name = self.config.path.name
        new_path = dir_path / new_path_name
        _overwrite_helper(new_path, overwrite)

        # is the resource a file or a folder
        is_file = Path(self.config.path).is_file()
        # copy the resource to the new path
        if is_file:
            shutil.copy(self.config.path, new_path)
        else:
            shutil.copytree(self.config.path, new_path)

        return str(new_path)


def _validate_file_path(v):
    """Validator: require that ``v`` is an existing file and return it as a ``Path``."""
    path = Path(v)
    if not path.is_file():
        raise ValueError(f"Path {path} is not a file.")
    return path


class LocalFile(LocalResourcePath):
    """Local file resource path"""

    name = ResourceType.LocalFile

    @staticmethod
    def _validators() -> Dict[str, Callable[..., Any]]:
        validators = LocalResourcePath._validators()
        validators.update({"validate_file_path": validator("path", allow_reuse=True)(_validate_file_path)})
        return validators


def _validate_folder_path(v):
    """Validator: require that ``v`` is an existing folder and return it as a ``Path``."""
    path = Path(v)
    if not path.is_dir():
        raise ValueError(f"Path {path} is not a folder.")
    return path


class LocalFolder(LocalResourcePath):
    """Local folder resource path"""

    name = ResourceType.LocalFolder

    @staticmethod
    def _validators() -> Dict[str, Callable[..., Any]]:
        validators = LocalResourcePath._validators()
        validators.update({"validate_folder_path": validator("path", allow_reuse=True)(_validate_folder_path)})
        return validators


class StringName(ResourcePath):
    """String name resource path"""

    name = ResourceType.StringName

    @staticmethod
    def _default_config() -> Dict[str, Any]:
        return {
            "name": ConfigParam(type_=str, required=True, description="Name of the resource."),
        }

    def get_path(self) -> str:
        """Return the configured name."""
        return self.config.name

    def save_to_dir(self, dir_path: Union[Path, str], name: str = None, overwrite: bool = False) -> str:
        """Return the configured name; nothing is written to ``dir_path``."""
        logger.debug("Resource is a string name. No need to save to directory.")
        return self.config.name


def _get_azureml_resource_prefix(workspace_config: Dict[str, str]) -> str:
    """Build the ``azureml://subscriptions/.../workspaces/...`` prefix from a workspace config.

    :param workspace_config: Dict with "subscription_id", "resource_group" and "workspace_name" keys.
    """
    return (
        f"azureml://subscriptions/{workspace_config['subscription_id']}"
        f"/resourcegroups/{workspace_config['resource_group']}"
        f"/workspaces/{workspace_config['workspace_name']}"
    )


class AzureMLModel(ResourcePath):
    """AzureML Model resource path"""

    name = ResourceType.AzureMLModel

    @staticmethod
    def _default_config() -> Dict[str, Any]:
        return {
            "azureml_client": ConfigParam(
                type_=AzureMLClientConfig, required=True, description="AzureML client config."
            ),
            "name": ConfigParam(type_=str, required=True, description="Name of the model."),
            "version": ConfigParam(type_=Union[int, str], required=True, description="Version of the model."),
        }

    def get_path(self) -> str:
        """Return the ``azureml:<name>:<version>`` identifier of the model."""
        return f"azureml:{self.config.name}:{self.config.version}"

    def save_to_dir(self, dir_path: Union[Path, str], name: str = None, overwrite: bool = False) -> str:
        """Download the model into ``dir_path``.

        :param dir_path: Directory to download into; created if it does not exist.
        :param name: Optional new name. Its suffix is replaced by the suffix of the model path.
        :param overwrite: Whether to replace an existing file/folder at the destination.
        :return: The destination path as a string.
        """
        # directory to save the resource to
        dir_path = Path(dir_path).resolve()
        dir_path.mkdir(parents=True, exist_ok=True)

        # azureml client
        ml_client = self.config.azureml_client.create_client()

        # azureml model
        model = ml_client.models.get(self.config.name, self.config.version)
        model_path = Path(model.path)

        # path to save the resource to
        if name:
            new_path_name = Path(name).with_suffix(model_path.suffix).name
        else:
            new_path_name = model_path.name
        new_path = dir_path / new_path_name
        _overwrite_helper(new_path, overwrite)

        # download the resource to the new path
        logger.debug(f"Downloading model {self.config.name} version {self.config.version} to {new_path}.")
        from azure.core.exceptions import ServiceResponseError

        # download into a temporary directory inside dir_path, then move the artifact to its final name
        with tempfile.TemporaryDirectory(dir=dir_path, prefix="olive_tmp") as temp_dir:
            temp_dir = Path(temp_dir)
            retry_func(
                ml_client.models.download,
                [self.config.name],
                {"version": self.config.version, "download_path": temp_dir},
                max_tries=self.config.azureml_client.max_operation_retries,
                delay=self.config.azureml_client.operation_retry_interval,
                exceptions=ServiceResponseError,
            )
            new_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(temp_dir / self.config.name / model_path.name, new_path)

        return str(new_path)


class AzureMLDatastore(ResourcePath):
    """AzureML datastore resource path"""

    name = ResourceType.AzureMLDatastore

    @staticmethod
    def _default_config() -> Dict[str, Any]:
        return {
            "azureml_client": ConfigParam(
                type_=AzureMLClientConfig, required=True, description="AzureML client config."
            ),
            "datastore_name": ConfigParam(type_=str, required=True, description="Name of the datastore."),
            "relative_path": ConfigParam(type_=str, required=True, description="Relative path to the resource."),
        }

    def get_path(self) -> str:
        """Return the full ``azureml://.../datastores/<name>/paths/<relative_path>`` URI."""
        workspace_config = self.config.azureml_client.get_workspace_config()
        return (
            f"{_get_azureml_resource_prefix(workspace_config)}"
            f"/datastores/{self.config.datastore_name}/paths/{self.config.relative_path}"
        )

    def save_to_dir(self, dir_path: Union[Path, str], name: str = None, overwrite: bool = False) -> str:
        """Download the datastore resource into ``dir_path``.

        :param dir_path: Directory to download into.
        :param name: Optional new name, forwarded to ``AzureMLModel.save_to_dir``.
        :param overwrite: Whether to replace an existing file/folder at the destination.
        :return: The destination path as a string.
        """
        # there is no direct way to download a file from a datastore
        # so we will use a workaround to download the file by creating a aml model
        # that references the file and downloading the model
        from azure.ai.ml.constants import AssetTypes
        from azure.ai.ml.entities import Model
        from azure.core.exceptions import ServiceResponseError

        # azureml client
        ml_client = self.config.azureml_client.create_client()

        # create aml model
        logger.debug(f"Creating aml model for datastore {self.config.datastore_name} path {self.config.relative_path}.")
        aml_model = retry_func(
            ml_client.models.create_or_update,
            [
                Model(
                    path=self.get_path(),
                    name="olive-backend-model",
                    description="Model created by Olive backend. Ignore this model.",
                    type=AssetTypes.CUSTOM_MODEL,
                )
            ],
            max_tries=self.config.azureml_client.max_operation_retries,
            delay=self.config.azureml_client.operation_retry_interval,
            exceptions=ServiceResponseError,
        )

        # use the AzureMLModel to download the model
        logger.debug(
            f"Downloading aml model for datastore {self.config.datastore_name} path {self.config.relative_path}."
        )
        azureml_model_resource = AzureMLModel(
            {"azureml_client": self.config.azureml_client, "name": aml_model.name, "version": aml_model.version}
        )
        return azureml_model_resource.save_to_dir(dir_path, name, overwrite)


class AzureMLJobOutput(ResourcePath):
    """AzureML job output resource path"""

    name = ResourceType.AzureMLJobOutput

    @staticmethod
    def _default_config() -> Dict[str, Any]:
        return {
            "azureml_client": ConfigParam(
                type_=AzureMLClientConfig, required=True, description="AzureML client config."
            ),
            "job_name": ConfigParam(type_=str, required=True, description="Name of the job."),
            "output_name": ConfigParam(type_=str, required=True, description="Name of the output."),
            "relative_path": ConfigParam(type_=str, required=True, description="Relative path to the resource."),
        }

    def get_path(self) -> str:
        """Return the ``azureml://jobs/<job>/outputs/<output>/paths/<relative_path>`` URI."""
        return (
            f"azureml://jobs/{self.config.job_name}/outputs/{self.config.output_name}/paths/{self.config.relative_path}"
        )

    def save_to_dir(self, dir_path: Union[Path, str], name: str = None, overwrite: bool = False) -> str:
        """Download the job output into ``dir_path``.

        :param dir_path: Directory to download into; created if it does not exist.
        :param name: Optional new name. Its suffix is replaced by the suffix of ``relative_path``.
        :param overwrite: Whether to replace an existing file/folder at the destination.
        :return: The destination path as a string.
        """
        # directory to save the resource to
        dir_path = Path(dir_path).resolve()
        dir_path.mkdir(parents=True, exist_ok=True)

        # path to save the resource to
        if name:
            new_path_name = Path(name).with_suffix(Path(self.config.relative_path).suffix).name
        else:
            new_path_name = Path(self.config.relative_path).name
        new_path = dir_path / new_path_name
        _overwrite_helper(new_path, overwrite)

        # download the resource to the new path
        ml_client = self.config.azureml_client.create_client()
        logger.debug(f"Downloading job output {self.config.job_name} output {self.config.output_name} to {new_path}.")
        from azure.core.exceptions import ServiceResponseError

        # download into a temporary directory inside dir_path, then move the artifact to its final name
        with tempfile.TemporaryDirectory(dir=dir_path, prefix="olive_tmp") as temp_dir:
            temp_dir = Path(temp_dir)
            retry_func(
                ml_client.jobs.download,
                [self.config.job_name],
                {"output_name": self.config.output_name, "download_path": temp_dir},
                max_tries=self.config.azureml_client.max_operation_retries,
                delay=self.config.azureml_client.operation_retry_interval,
                exceptions=ServiceResponseError,
            )
            new_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(temp_dir / "named-outputs" / self.config.output_name / self.config.relative_path, new_path)

        return str(new_path)