| Materializer | Handled Data Types | Storage Format |
|---|---|---|
| `BuiltInMaterializer` | `bool`, `float`, `int`, `str`, `None` | `.json` |
| `BytesInMaterializer` | `bytes` | `.txt` |
| `BuiltInContainerMaterializer` | `dict`, `list`, `set`, `tuple` | Directory |
| `NumpyMaterializer` | `np.ndarray` | `.npy` |
| `PandasMaterializer` | `pd.DataFrame`, `pd.Series` | `.csv` (or `.gzip` if `parquet` is installed) |
| `PydanticMaterializer` | `pydantic.BaseModel` | `.json` |
| `ServiceMaterializer` | `zenml.services.service.BaseService` | `.json` |
| `StructuredStringMaterializer` | `zenml.types.CSVString`, `zenml.types.HTMLString`, `zenml.types.MarkdownString` | `.csv` / `.html` / `.md` (depending on type) |
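ZenML selects one of these materializers automatically based on the annotated return type of a step. The following minimal sketch is illustrative only (the step and pipeline names are made up, and it assumes a recent ZenML release where `step` and `pipeline` are importable from the top-level `zenml` package and that a repository has been initialized with `zenml init`): the step returns a `dict`, so its output would be handled by the `BuiltInContainerMaterializer` with no extra configuration.

```python
from zenml import pipeline, step


@step
def summarize_metrics() -> dict:
    # The return annotation is matched against the table above, so this
    # artifact is stored by the BuiltInContainerMaterializer (a directory).
    return {"accuracy": 0.92, "loss": 0.31}


@pipeline
def metrics_pipeline() -> None:
    summarize_metrics()


if __name__ == "__main__":
    # Requires an initialized ZenML repository/stack (e.g. `zenml init`).
    metrics_pipeline()
```

Per the table, the container is written to the artifact store as a directory rather than a single file.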
| Integration | Materializer | Handled Data Types | Storage Format |
|---|---|---|---|
| bentoml | `BentoMaterializer` | `bentoml.Bento` | `.bento` |
| deepchecks | `DeepchecksResultMaterializer` | `deepchecks.CheckResult`, `deepchecks.SuiteResult` | `.json` |
| evidently | `EvidentlyProfileMaterializer` | `evidently.Profile` | `.json` |
| great_expectations | `GreatExpectationsMaterializer` | `great_expectations.ExpectationSuite`, `great_expectations.CheckpointResult` | `.json` |
| huggingface | `HFDatasetMaterializer` | `datasets.Dataset`, `datasets.DatasetDict` | Directory |
| huggingface | `HFPTModelMaterializer` | `transformers.PreTrainedModel` | Directory |
| huggingface | `HFTFModelMaterializer` | `transformers.TFPreTrainedModel` | Directory |
| huggingface | `HFTokenizerMaterializer` | `transformers.PreTrainedTokenizerBase` | Directory |
| lightgbm | `LightGBMBoosterMaterializer` | `lgbm.Booster` | `.txt` |
| lightgbm | `LightGBMDatasetMaterializer` | `lgbm.Dataset` | `.binary` |
| neural_prophet | `NeuralProphetMaterializer` | `NeuralProphet` | `.pt` |
| pillow | `PillowImageMaterializer` | `Pillow.Image` | `.PNG` |
| polars | `PolarsMaterializer` | `pl.DataFrame`, `pl.Series` | `.parquet` |
| pycaret | `PyCaretMaterializer` | Any `sklearn`, `xgboost`, `lightgbm`, or `catboost` model | `.pkl` |
| pytorch | `PyTorchDataLoaderMaterializer` | `torch.Dataset`, `torch.DataLoader` | `.pt` |
| pytorch | `PyTorchModuleMaterializer` | `torch.Module` | `.pt` |
| scipy | `SparseMaterializer` | `scipy.spmatrix` | `.npz` |
| spark | `SparkDataFrameMaterializer` | `pyspark.DataFrame` | `.parquet` |
| spark | `SparkModelMaterializer` | `pyspark.Transformer`, `pyspark.Estimator` | |
| tensorflow | `KerasMaterializer` | `tf.keras.Model` | Directory |
| tensorflow | `TensorflowDatasetMaterializer` | `tf.Dataset` | Directory |
| whylogs | `WhylogsMaterializer` | `whylogs.DatasetProfileView` | `.pb` |
| xgboost | `XgboostBoosterMaterializer` | `xgb.Booster` | `.json` |
| xgboost | `XgboostDMatrixMaterializer` | `xgb.DMatrix` | `.binary` |
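The integration-specific materializers above become available once the corresponding integration is installed, for example with `zenml integration install deepchecks`, and they are likewise matched against a step's return type. If you would rather pin a materializer explicitly than rely on return-type matching, the `@step` decorator accepts an `output_materializers` argument. The sketch below is illustrative rather than canonical: the step name is hypothetical, and the `PandasMaterializer` import path is an assumption consistent with the core-materializer table at the top of this page.

```python
import pandas as pd

from zenml import step
from zenml.materializers.pandas_materializer import (  # assumed import path
    PandasMaterializer,
)


@step(output_materializers=PandasMaterializer)
def load_training_data() -> pd.DataFrame:
    # Materializer pinned explicitly; per the first table the DataFrame is
    # written as .csv (or .gzip when a parquet library is available).
    return pd.DataFrame({"feature": [1, 2, 3], "target": [0, 1, 0]})
```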
+ ) + if any("adapter" in wn or "gating_factor" in wn for wn in lit_weights): + raise NotImplementedError("Converting adapter models is supported.") + + +@torch.inference_mode() +def convert_lit_checkpoint( + checkpoint_path: Path, output_path: Path, config_path: Path +) -> None: + config = Config.from_json(config_path) + + if "falcon" in config.name: + copy_fn = partial(copy_weights_falcon, config.name) + elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): + untie_weights = "Gemma" in config.name + copy_fn = partial( + copy_weights_llama, config, untie_weights=untie_weights + ) + elif "phi" in config.name: + copy_fn = partial(copy_weights_phi, config) + else: + copy_fn = copy_weights_gpt_neox + + # initialize a new empty state dict to hold our new weights + sd = {} + with incremental_save(output_path) as saver: + lit_weights = lazy_load(checkpoint_path) + lit_weights = lit_weights.get("model", lit_weights) + check_conversion_supported(lit_weights) + copy_fn(sd, lit_weights, saver=saver) + gc.collect() + saver.save(sd) + + +if __name__ == "__main__": + CLI(convert_lit_checkpoint) diff --git a/examples/llm_finetuning/scripts/convert_pretrained_checkpoint.py b/examples/llm_finetuning/scripts/convert_pretrained_checkpoint.py new file mode 100644 index 00000000000..a6c3093374a --- /dev/null +++ b/examples/llm_finetuning/scripts/convert_pretrained_checkpoint.py @@ -0,0 +1,88 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import json +import shutil +import sys +from dataclasses import asdict +from pathlib import Path + +import torch + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Config +from lit_gpt.utils import CLI, incremental_save + + +@torch.inference_mode() +def convert_checkpoint( + checkpoint_file: Path, + tokenizer_dir: Path, + config_name: str, + output_dir: Path, +) -> None: + """Convert a checkpoint after pretraining. + + The pretrained checkpoint contains optimizer states and several other metadata that are not needed after training + is finished. This script will export the state-dict of the model and place it in the chosen output folder together + with the tokenizer and model config, which then can be loaded by other scripts for inference, evaluation, etc. + + Args: + checkpoint_file: Path to a checkpoint file scripts produced by the scripts in ``lit_gpt/pretrain/``. + tokenizer_dir: A path to the folder that holds the tokenizer configuration files that were used to train + the model. All files with a name starting with 'tokenizer' will be copied to the output folder. + config_name: The name of the model loaded with the ``lit_gpt.Config``. The configuration will be saved as a + JSON file to the output folder. + output_dir: The output folder where model state-dict file, the tokenizer config file, and the model config + file will be saved. + """ + + if output_dir.is_dir() and output_dir.glob("*"): + raise FileExistsError( + f"The output folder exists and is not empty: {str(output_dir)}." + " Please delete it first or choose a different name." + ) + if not tokenizer_dir.is_dir(): + raise FileNotFoundError( + f"The tokenizer_dir must be a directory: {str(output_dir)}." 
+ ) + + output_dir.mkdir(parents=True) + output_checkpoint_file = output_dir / "lit_model.pth" + output_config_file = output_dir / "lit_config.json" + + # Save the config to output folder + config = Config.from_name(config_name) + with open(output_config_file, "w") as json_config: + json.dump(asdict(config), json_config) + + # Export the tokenizer configuration to output folder + for tokenizer_file in tokenizer_dir.glob("tokenizer*"): + shutil.copyfile(tokenizer_file, output_dir / tokenizer_file.name) + + # Copy config for tokenization if found + if (tokenizer_dir / "generation_config.json").is_file(): + shutil.copyfile( + tokenizer_dir / "generation_config.json", + output_dir / "generation_config.json", + ) + + # Extract the model state dict and save to output folder + with incremental_save(output_checkpoint_file) as saver: + print("Processing", checkpoint_file) + full_checkpoint = torch.load(str(checkpoint_file), mmap=True) + loaded_state_dict = full_checkpoint["model"] + converted_state_dict = {} + for param_name, param in loaded_state_dict.items(): + saver.store_early(param) + # remove prefix for compiled model (if any) + param_name = param_name.replace("_orig_mod.", "") + converted_state_dict[param_name] = param + print(f"Saving converted checkpoint to {str(output_checkpoint_file)}.") + saver.save(converted_state_dict) + + +if __name__ == "__main__": + CLI(convert_checkpoint) diff --git a/examples/llm_finetuning/scripts/download.py b/examples/llm_finetuning/scripts/download.py new file mode 100644 index 00000000000..e5a7459d2be --- /dev/null +++ b/examples/llm_finetuning/scripts/download.py @@ -0,0 +1,106 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import os +import sys +from pathlib import Path +from typing import Optional + +import torch +from lightning_utilities.core.imports import RequirementCache + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.utils import CLI + +_SAFETENSORS_AVAILABLE = RequirementCache("safetensors") +_HF_TRANSFER_AVAILABLE = RequirementCache("hf_transfer") + + +def download_from_hub( + repo_id: Optional[str] = None, + access_token: Optional[str] = os.getenv("HF_TOKEN"), + from_safetensors: bool = False, + tokenizer_only: bool = False, + checkpoint_dir: Path = Path("checkpoints"), +) -> None: + if repo_id is None: + from lit_gpt.config import configs + + options = [ + f"{config['hf_config']['org']}/{config['hf_config']['name']}" + for config in configs + ] + print("Please specify --repo_id. Available values:") + print("\n".join(options)) + return + + from huggingface_hub import snapshot_download + + if ( + "meta-llama" in repo_id or "falcon-180" in repo_id + ) and not access_token: + raise ValueError( + f"{repo_id} requires authentication, please set the `HF_TOKEN=your_token` environment" + " variable or pass --access_token=your_token. 
You can find your token by visiting" + " https://huggingface.co/settings/tokens" + ) + + download_files = ["tokenizer*", "generation_config.json"] + if not tokenizer_only: + if from_safetensors: + if not _SAFETENSORS_AVAILABLE: + raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE)) + download_files.append("*.safetensors") + else: + # covers `.bin` files and `.bin.index.json` + download_files.append("*.bin*") + elif from_safetensors: + raise ValueError( + "`--from_safetensors=True` won't have an effect with `--tokenizer_only=True`" + ) + + import huggingface_hub._snapshot_download as download + import huggingface_hub.constants as constants + + previous = constants.HF_HUB_ENABLE_HF_TRANSFER + if _HF_TRANSFER_AVAILABLE and not previous: + print("Setting HF_HUB_ENABLE_HF_TRANSFER=1") + constants.HF_HUB_ENABLE_HF_TRANSFER = True + download.HF_HUB_ENABLE_HF_TRANSFER = True + + directory = checkpoint_dir / repo_id + snapshot_download( + repo_id, + local_dir=directory, + local_dir_use_symlinks=False, + resume_download=True, + allow_patterns=download_files, + token=access_token, + ) + + constants.HF_HUB_ENABLE_HF_TRANSFER = previous + download.HF_HUB_ENABLE_HF_TRANSFER = previous + + # convert safetensors to PyTorch binaries + if from_safetensors: + from safetensors import SafetensorError + from safetensors.torch import load_file as safetensors_load + + print("Converting .safetensor files to PyTorch binaries (.bin)") + for safetensor_path in directory.glob("*.safetensors"): + bin_path = safetensor_path.with_suffix(".bin") + try: + result = safetensors_load(safetensor_path) + except SafetensorError as e: + raise RuntimeError( + f"{safetensor_path} is likely corrupted. Please try to re-download it." + ) from e + print(f"{safetensor_path} --> {bin_path}") + torch.save(result, bin_path) + os.remove(safetensor_path) + + +if __name__ == "__main__": + CLI(download_from_hub) diff --git a/examples/llm_finetuning/scripts/merge_lora.py b/examples/llm_finetuning/scripts/merge_lora.py new file mode 100644 index 00000000000..89818a999fa --- /dev/null +++ b/examples/llm_finetuning/scripts/merge_lora.py @@ -0,0 +1,94 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""This script merges the LoRA weights with the base model""" + +import sys +from pathlib import Path +from typing import Optional + +import lightning as L +import torch + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.lora import GPT, Config, lora_filter, merge_lora_weights +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + get_default_supported_precision, + lazy_load, +) + + +def merge_lora( + lora_path: Path = Path("out/lora/alpaca/lit_model_lora_finetuned.pth"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + out_dir: Path = Path("out/lora/checkpoint"), + precision: Optional[str] = None, + lora_r: int = 8, + lora_alpha: int = 16, + lora_dropout: float = 0.05, + lora_query: bool = True, + lora_key: bool = False, + lora_value: bool = True, + lora_projection: bool = False, + lora_mlp: bool = False, + lora_head: bool = False, +) -> None: + """Generates a response based on a given instruction and an optional input. + This script will only work with checkpoints from the instruction-tuned GPT-LoRA model. + See `finetune/lora.py`. + + Args: + lora_path: Path to the checkpoint with trained adapter weights, which are the output of + `finetune/lora.py`. 
+ checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights. + out_dir: The path to the merged model that is created by this script. + precision: Indicates the Fabric precision setting to use. + """ + check_valid_checkpoint_dir(checkpoint_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + precision = precision or get_default_supported_precision(training=False) + fabric = L.Fabric(devices=1, precision=precision) + + config = Config.from_json( + checkpoint_dir / "lit_config.json", + r=lora_r, + alpha=lora_alpha, + dropout=lora_dropout, + to_query=lora_query, + to_key=lora_key, + to_value=lora_value, + to_projection=lora_projection, + to_mlp=lora_mlp, + to_head=lora_head, + ) + + with fabric.init_module(empty_init=True): + model = GPT(config) + checkpoint_path = checkpoint_dir / "lit_model.pth" + checkpoint = lazy_load(checkpoint_path) + lora_checkpoint = lazy_load(lora_path) + checkpoint.update(lora_checkpoint.get("model", lora_checkpoint)) + model.load_state_dict(checkpoint) + + merge_lora_weights(model) + + save_path = out_dir / "lit_model.pth" + fabric.print(f"Saving weights to {str(save_path)!r}") + # remove lora parameters and the lora linear substring + state_dict = { + k.replace("linear.", ""): v + for k, v in model.state_dict().items() + if not lora_filter(k, v) + } + torch.save(state_dict, save_path) + + +if __name__ == "__main__": + CLI(merge_lora) diff --git a/examples/llm_finetuning/scripts/prepare_alpaca.py b/examples/llm_finetuning/scripts/prepare_alpaca.py new file mode 100644 index 00000000000..cde6fca1b67 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_alpaca.py @@ -0,0 +1,169 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation derived from https://github.com/tloen/alpaca-lora""" + +import json +import sys +from pathlib import Path +from typing import Optional + +import torch +from lightning_utilities.core.imports import RequirementCache +from torch.utils.data import random_split +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + + +def prepare( + destination_path: Path = Path("data/alpaca"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + test_split_fraction: float = 0.03865, # to get exactly 2000 test samples, + seed: int = 42, + mask_inputs: bool = False, # as in alpaca-lora + data_file_name: str = "alpaca_data_cleaned_archive.json", + data_file_url: str = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json", + ignore_index: int = -1, + max_seq_length: Optional[int] = None, +) -> None: + """Prepare the Alpaca dataset for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. 
+ """ + if max_seq_length is None: + with open( + checkpoint_dir / "lit_config.json", "r", encoding="utf-8" + ) as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + data_file_path = destination_path / data_file_name + print("Loading data file...") + download_if_missing(data_file_path, data_file_url) + with open(data_file_path, "r", encoding="utf-8") as file: + data = json.load(file) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + # Partition the dataset into train and test + train_set, test_set = random_split( + data, + [1.0 - test_split_fraction, test_split_fraction], + generator=torch.Generator().manual_seed(seed), + ) + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has {len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def download_if_missing(file_path: Path, file_url: str) -> None: + """Downloads the raw json data file and saves it in the given destination.""" + if file_path.exists() and file_path.stat().st_size > 0: + return + requests_available = RequirementCache("requests") + if not requests_available: + raise ModuleNotFoundError(str(requests_available)) + import requests + + with open(file_path, "w", encoding="utf-8") as f: + f.write(requests.get(file_url).text) + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +) -> dict: + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). 
+ """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + if example["input"]: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" + ) + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_csv.py b/examples/llm_finetuning/scripts/prepare_csv.py new file mode 100644 index 00000000000..bbd27074d52 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_csv.py @@ -0,0 +1,157 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import json +import logging +import sys +from pathlib import Path +from typing import Optional, Tuple + +import torch +from torch.utils.data import random_split +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +logger = logging.getLogger(__name__) +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + + +def prepare( + csv_path: Path, + destination_path: Path = Path("data/csv"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + test_split_fraction: float = 0.1, + seed: int = 42, + mask_inputs: bool = False, + ignore_index: int = -1, + max_seq_length: Optional[int] = None, + columns: Tuple[str, ...] = ("instruction", "input", "output"), +) -> None: + """Prepare a CSV dataset for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. 
+ """ + if max_seq_length is None: + with open(checkpoint_dir / "lit_config.json", "r") as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + logger.info("Loading data file ...") + import pandas as pd + + df = pd.read_csv(csv_path, dtype=str).fillna("") + if not (df.columns.values == columns).all(): + raise ValueError( + f"CSV columns must be {columns}, found {df.columns.values}" + ) + data = json.loads(df.to_json(orient="records", indent=4)) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + # Partition the dataset into train and test + train_set, test_set = random_split( + data, + [1.0 - test_split_fraction, test_split_fraction], + generator=torch.Generator().manual_seed(seed), + ) + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has {len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +) -> dict: + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + if example["input"]: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" + ) + return ( + "Below is an instruction that describes a task. 
" + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_dolly.py b/examples/llm_finetuning/scripts/prepare_dolly.py new file mode 100644 index 00000000000..8bb434398fa --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_dolly.py @@ -0,0 +1,163 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation derived from https://github.com/tloen/alpaca-lora""" + +import json +import sys +from pathlib import Path +from typing import Optional + +import torch +from torch.utils.data import random_split +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + +from scripts.prepare_alpaca import download_if_missing + + +def prepare( + destination_path: Path = Path("data/dolly"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + test_split_fraction: float = 0.1, + seed: int = 42, + mask_inputs: bool = False, + data_file_name: str = "dolly_data_cleaned.json", + data_file_url: str = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl", + ignore_index: int = -1, + max_seq_length: Optional[int] = None, +) -> None: + """Prepare the Dolly 15k dataset for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + """ + + if max_seq_length is None: + with open( + checkpoint_dir / "lit_config.json", "r", encoding="utf-8" + ) as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + data_file_path = destination_path / data_file_name + print("Loading data file...") + download_if_missing(data_file_path, data_file_url) + + with open(data_file_path, "r", encoding="utf-8") as file: + data = file.readlines() + data = [json.loads(line) for line in data] + for item in data: + item["input"] = item.pop("context") + item["output"] = item.pop("response") + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + # Partition the dataset into train and test + train_set, test_set = random_split( + data, + [1.0 - test_split_fraction, test_split_fraction], + generator=torch.Generator().manual_seed(seed), + ) + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has {len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +) -> dict: + """Processes a single sample. 
+ + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + if example["input"]: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" + ) + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_flan.py b/examples/llm_finetuning/scripts/prepare_flan.py new file mode 100644 index 00000000000..a34b547213b --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_flan.py @@ -0,0 +1,249 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation derived from https://github.com/tloen/alpaca-lora""" +import json +import sys +from pathlib import Path +from typing import Optional + +import torch +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + +from scripts.prepare_alpaca import download_if_missing + + +def load_jsonl(filename): + data = [] + with open(filename, "r", encoding="utf-8") as f: + for line in f: + data.append(json.loads(line)) + return data + + +def prepare( + destination_path: Path = Path("data/flan"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + mask_inputs: bool = False, # as in alpaca-lora + subsets: Optional[str] = None, + ignore_index: int = -1, + max_seq_length: Optional[int] = None, +) -> None: + """Prepare the FLAN-collection datasets for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. 
+ + Since the original test set does not have responses, the validation set + is used as the test set. + """ + + supported_subsets = { + "aeslc_10templates", + "ag_news_subset_10templates", + "anli_r1_10templates", + "anli_r2_10templates", + "anli_r3_10templates", + "arc_challenge_10templates", + "arc_easy_10templates", + "bool_q_10templates", + "cb_10templates", + "cnn_dailymail_10templates", + "cola_10templates", + "common_gen_10templates", + "copa_10templates", + "coqa_10templates", + "cosmos_qa_10templates", + "dart_10templates", + "definite_pronoun_resolution_10templates", + "drop_10templates", + "e2e_nlg_10templates", + "fix_punct_10templates", + "gigaword_10templates", + "glue_mrpc_10templates", + "glue_qqp_10templates", + "hellaswag_10templates", + "imdb_reviews_10templates", + "math_dataset_10templates", + "mnli_matched_10templates", + "mnli_mismatched_10templates", + "multi_news_10templates", + "multirc_10templates", + "natural_questions_10templates", + "openbookqa_10templates", + "opinion_abstracts_idebate_10templates", + "opinion_abstracts_rotten_tomatoes_10templates", + "para_crawl_enes_10templates", + "paws_wiki_10templates", + "piqa_10templates", + "qnli_10templates", + "quac_10templates", + "record_10templates", + "rte_10templates", + "samsum_10templates", + "sentiment140_10templates", + "snli_10templates", + "squad_v1_10templates", + "squad_v2_10templates", + "sst2_10templates", + "story_cloze_10templates", + "stsb_10templates", + "trec_10templates", + "trivia_qa_10templates", + "true_case_10templates", + "web_nlg_en_10templates", + "wic_10templates", + "wiki_lingua_english_en_10templates", + "wmt14_enfr_10templates", + "wmt16_translate_csen_10templates", + "wmt16_translate_deen_10templates", + "wmt16_translate_fien_10templates", + "wmt16_translate_roen_10templates", + "wmt16_translate_ruen_10templates", + "wmt16_translate_tren_10templates", + "wnli_10templates", + "word_segment_10templates", + "wsc_10templates", + "yelp_polarity_reviews_10templates", + } + + if subsets is not None: + subsets = subsets.split(",") + for sub in subsets: + if sub not in supported_subsets: + raise ValueError(f"{sub} not in {supported_subsets}") + else: + subsets = list(supported_subsets) + + if max_seq_length is None: + with open( + checkpoint_dir / "lit_config.json", "r", encoding="utf-8" + ) as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + print("Loading data file...") + + base_url = "https://huggingface.co/datasets/Muennighoff/flan/resolve/main/" + + train_set, test_set = [], [] + for sub in subsets: + train_sub = sub + "_train" + data_file_name = train_sub + ".jsonl" + data_file_path = destination_path / data_file_name + data_file_url = base_url + "train/" + data_file_name + + print(f"Loading training data file {sub}...") + download_if_missing(data_file_path, data_file_url) + sub_train_set = load_jsonl(data_file_path) + train_set.extend(sub_train_set) + + test_sub = sub + "_test" + data_file_name = test_sub + ".jsonl" + data_file_path = destination_path / data_file_name + data_file_url = base_url + "test/" + data_file_name + + print(f"Loading test data file {sub}...") + download_if_missing(data_file_path, data_file_url) + sub_test_set = load_jsonl(data_file_path) + test_set.extend(sub_test_set) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has 
{len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +): + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["targets"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example): + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['inputs']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_lima.py b/examples/llm_finetuning/scripts/prepare_lima.py new file mode 100644 index 00000000000..e27928ce9e2 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_lima.py @@ -0,0 +1,198 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
+ +"""Implementation derived from https://github.com/tloen/alpaca-lora""" + +import json +import os +import sys +from pathlib import Path +from typing import List, Optional + +import torch +from torch.utils.data import random_split +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + + +def prepare( + destination_path: Path = Path("data/lima"), + test_split_fraction: float = 0.1, + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + mask_inputs: bool = False, # as in alpaca-lora + seed: int = 42, + include_multiturn_conversations: bool = False, + data_repo_id: str = "GAIR/lima", + ignore_index: int = -1, + access_token: Optional[str] = os.getenv("HF_TOKEN"), + max_seq_length: Optional[int] = None, +) -> None: + """Prepare the LIMA dataset for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + """ + + if access_token is None: + raise ValueError( + "LIMA requires authentication, please set the `HF_TOKEN=your_token` environment" + " variable or pass --access_token=your_token. You can find your token by visiting" + " https://huggingface.co/settings/tokens" + ) + + if max_seq_length is None: + with open( + checkpoint_dir / "lit_config.json", "r", encoding="utf-8" + ) as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + print("Loading data file...") + + from datasets import load_dataset + + dataset = load_dataset(data_repo_id, token=access_token) + train_data = format_dataset( + dataset["train"], include_multiturn_conversations + ) + + # test set is present but doesn't have any solutions, so we cannot use it here + # but have to create our own + # for consistency with prepare_alpaca.py and prepare_dolly.py + # test_set = format_dataset(dataset["test"], include_multiturn_conversations) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + # Partition the dataset into train and test + train_set, test_set = random_split( + train_data, + [1.0 - test_split_fraction, test_split_fraction], + generator=torch.Generator().manual_seed(seed), + ) + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has {len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def format_dataset( + dataset_partition: dict, include_multi_turn_conversations: bool +) -> List[dict]: + formatted_ds = [] + + for entry in dataset_partition: + convo = entry["conversations"] + if include_multi_turn_conversations: + for i in range(0, len(convo) - 1, 2): + formatted_ds.append( + { + "instruction": convo[i], + "input": "", + "output": convo[i + 1], + } + ) + + else: + formatted_ds.append( + 
{"instruction": convo[0], "input": "", "output": convo[1]} + ) + + return formatted_ds + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +) -> dict: + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + if example["input"]: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" + ) + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_longform.py b/examples/llm_finetuning/scripts/prepare_longform.py new file mode 100644 index 00000000000..6327bad8654 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_longform.py @@ -0,0 +1,153 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation derived from https://github.com/tloen/alpaca-lora""" + +import json +import sys +from pathlib import Path +from typing import Optional + +import torch +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + +from scripts.prepare_alpaca import download_if_missing + + +def prepare( + destination_path: Path = Path("data/longform"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + mask_inputs: bool = False, # as in alpaca-lora + ignore_index: int = -1, + max_seq_length: Optional[int] = None, +) -> None: + """Prepare the Alpaca dataset for instruction tuning. 
+ + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + """ + if max_seq_length is None: + with open( + checkpoint_dir / "lit_config.json", "r", encoding="utf-8" + ) as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + + train_file_name = "train.json" + # val_file_name = "val.json" + test_file_name = "test.json" + + train_file_url = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset/train.json" + # val_file_url = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset/val.json" + test_file_url = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset/test.json" + + train_file_path = destination_path / train_file_name + print("Loading train data file...") + download_if_missing(train_file_path, train_file_url) + with open(train_file_path, "r", encoding="utf-8") as file: + train_data = json.load(file) + + test_file_path = destination_path / test_file_name + print("Loading test data file...") + download_if_missing(test_file_path, test_file_url) + with open(test_file_path, "r", encoding="utf-8") as file: + test_data = json.load(file) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + print(f"train has {len(train_data):,} samples") + print(f"test has {len(test_data):,} samples") + + print("Processing train set ...") + train_data = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_data) + ] + torch.save(train_data, destination_path / "train.pt") + + print("Processing test set ...") + test_data = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_data) + ] + torch.save(test_data, destination_path / "test.pt") + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +) -> dict: + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). 
+ """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction and a + 'response' field.""" + + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['input']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_openwebtext.py b/examples/llm_finetuning/scripts/prepare_openwebtext.py new file mode 100644 index 00000000000..fbb4a8d9d96 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_openwebtext.py @@ -0,0 +1,100 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +# saves the openwebtext dataset to a binary file for training. following was helpful: +# https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py +import os +import sys +from pathlib import Path +from typing import Union + +import numpy as np +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Tokenizer +from lit_gpt.utils import CLI + + +def prepare( + destination_path: Path = Path("data/openwebtext"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + seed: int = 42, + test_size: Union[float, int, None] = 0.0005, +) -> None: + from datasets import load_dataset # huggingface datasets + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(checkpoint_dir) + + # number of workers in .map() call + # good number to use is ~order number of cpu cores // 2 + num_proc = os.cpu_count() // 2 + + # number of workers in load_dataset() call + # best number might be different from num_proc above as it also depends on HW speed. + # it is better than 1 usually though + num_proc_load_dataset = num_proc + + # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) + dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset) + + # owt by default only contains the 'train' split, so create a test split + split_dataset = dataset["train"].train_test_split( + test_size=test_size, seed=seed, shuffle=True + ) + split_dataset["val"] = split_dataset.pop( + "test" + ) # rename the test split to val + + def process(example): + ids = tokenizer.encode(example["text"]).tolist() + ids.append(tokenizer.eos_id) + + # ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens + # ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe + # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
+ return {"ids": ids, "len": len(ids)} + + # tokenize the dataset + tokenized = split_dataset.map( + process, + remove_columns=["text"], + desc="tokenizing the splits", + num_proc=num_proc, + ) + + # concatenate all the ids in each dataset into one large file we can use for training + for split, dset in tokenized.items(): + arr_len = np.sum(dset["len"], dtype=np.uint64) + filename = destination_path / f"{split}.bin" + dtype = ( + np.uint16 + ) # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap( + str(filename), dtype=dtype, mode="w+", shape=(arr_len,) + ) + total_batches = 1024 + + idx = 0 + for batch_idx in tqdm( + range(total_batches), desc=f"writing {filename}" + ): + # Batch together samples for faster write + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) + arr.flush() + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_redpajama.py b/examples/llm_finetuning/scripts/prepare_redpajama.py new file mode 100644 index 00000000000..02044307797 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_redpajama.py @@ -0,0 +1,185 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import glob +import json +import os +import sys +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Config, Tokenizer +from lit_gpt.utils import CLI + +filenames_sample = [ + "arxiv_sample.jsonl", + "book_sample.jsonl", + "c4_sample.jsonl", + "cc_2019-30_sample.jsonl", + "cc_2020-05_sample.jsonl", + "cc_2021-04_sample.jsonl", + "cc_2022-05_sample.jsonl", + "cc_2023-06_sample.jsonl", + "github_sample.jsonl", + "stackexchange_sample.jsonl", + "wikipedia_sample.jsonl", +] + +filename_sets = { + "arxiv": "arxiv/arxiv*", + "book": "book/book*", + "c4": "c4/c4-train*", + "common_crawl": "common_crawl/*", + "github": "github/filtered*", + "stackexchange": "stackexchange/stackexchange*", + "wikipedia": "wikipedia/wiki*", +} + + +def prepare_sample( + source_path: Path, + checkpoint_dir: Path, + destination_path: Path, + chunk_size: int, + match: str = "", +) -> None: + """Prepare the "Red Pajama" dataset using the original tokenizer.""" + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(checkpoint_dir) + + for name in filenames_sample: + if match and match not in name: + continue + + filepath = source_path / name + + if not filepath.is_file(): + raise RuntimeError( + f"Input file not found at {filepath}. \nMake sure you download the data, e.g. 
wget -i" + " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" + ) + + prefix, _ = os.path.splitext(name) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=prefix, + chunk_size=chunk_size, + sep_token=tokenizer.eos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + print(f"Processing {name}") + + with open(filepath, encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + builder.write_reminder() + + +def prepare_full( + source_path: Path, + checkpoint_dir: Path, + destination_path: Path, + chunk_size: int, + match: str = "", +) -> None: + """Prepare the "Red Pajama" dataset using the original tokenizer.""" + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(checkpoint_dir) + + for set_name, pattern in filename_sets.items(): + if match and match not in set_name: + continue + + is_cc = set_name == "common_crawl" + + filenames = glob.glob( + os.path.join(source_path, pattern), recursive=True + ) + + if not filenames: + raise RuntimeError( + f"No files matching {pattern} found at {source_path}. \nMake sure you download the data, e.g. wget -i" + " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=set_name, + chunk_size=chunk_size, + sep_token=tokenizer.eos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for name in filenames: + filepath = source_path / name + + print(f"Processing {name}") + + if is_cc: + with zstd.open( + open(filepath, "rb"), "rt", encoding="utf-8" + ) as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array( + np.array(text_ids, dtype=builder.dtype) + ) + else: + with open(filepath, encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array( + np.array(text_ids, dtype=builder.dtype) + ) + + builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + destination_path: Path = Path("data/redpajama_sample"), + sample: bool = True, + match: str = "", +) -> None: + """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained.""" + config = Config.from_checkpoint(checkpoint_dir) + + prepare_fn = prepare_sample if sample else prepare_full + prepare_fn( + source_path=source_path, + checkpoint_dir=checkpoint_dir, + destination_path=destination_path, + chunk_size=(config.block_size + 1) + * 1024, # block size + 1 for causal, 1024 blocks + match=match, + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_slimpajama.py b/examples/llm_finetuning/scripts/prepare_slimpajama.py new file mode 100644 index 00000000000..0a80191f299 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_slimpajama.py @@ -0,0 +1,68 @@ +# Copyright Lightning AI. 
Licensed under the Apache License 2.0, see LICENSE file. + +import json +import os +import sys +import time +from pathlib import Path + +import zstandard as zstd +from lightning.data.streaming import DataChunkRecipe, DataProcessor + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Tokenizer +from lit_gpt.utils import CLI + + +class SlimPajamaDataRecipe(DataChunkRecipe): + def __init__(self, tokenizer: Tokenizer, chunk_size: int): + super().__init__(chunk_size) + self.tokenizer = tokenizer + + def prepare_structure(self, input_dir): + files = Path(input_dir).rglob("*.zst") + return [str(file) for file in files] + + def prepare_item(self, filepath): + with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: + for row in f: + text = json.loads(row)["text"] + if ( + json.loads(row)["meta"]["redpajama_set_name"] + == "RedPajamaGithub" + ): + continue # exclude the GitHub data since it overlaps with starcoder + text_ids = self.tokenizer.encode(text, bos=False, eos=True) + yield text_ids + + +def prepare( + input_dir: Path = Path("data/SlimPajama-627B/train"), + output_dir: Path = Path("data/slimpajama/train"), + tokenizer_path: Path = Path("checkpoints/Llama-2-7b-hf/"), + chunk_size: int = (2049 * 16384), + fast_dev_run: bool = False, +) -> None: + tokenizer = Tokenizer(tokenizer_path) + data_recipe = SlimPajamaDataRecipe( + tokenizer=tokenizer, chunk_size=chunk_size + ) + data_processor = DataProcessor( + input_dir=str(input_dir), + output_dir=str(output_dir), + fast_dev_run=fast_dev_run, + num_workers=os.cpu_count(), + num_downloaders=1, + ) + + start_time = time.time() + data_processor.run(data_recipe) + elapsed_time = time.time() - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_starcoder.py b/examples/llm_finetuning/scripts/prepare_starcoder.py new file mode 100644 index 00000000000..1f67c93e1fe --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_starcoder.py @@ -0,0 +1,78 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
+ +import os +import sys +import time +import traceback +from pathlib import Path + +import pyarrow.parquet as pq +from lightning.data.streaming import DataChunkRecipe, DataProcessor + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Tokenizer +from lit_gpt.utils import CLI + + +class StarcoderDataRecipe(DataChunkRecipe): + def __init__(self, tokenizer: Tokenizer, chunk_size: int): + super().__init__(chunk_size) + self.tokenizer = tokenizer + + def prepare_structure(self, input_dir): + files = Path(input_dir).rglob("*.parquet") + return [str(file) for file in files] + + def prepare_item(self, item_metadata): + filepath = item_metadata + start = time.time() + + try: + parquet_file = pq.ParquetFile(filepath) + # reduce RAM usage + for batch in parquet_file.iter_batches( + batch_size=8192, columns=["content"] + ): + for text in batch.to_pandas()["content"]: + yield self.tokenizer.encode(text, bos=False, eos=True) + + except Exception: + print(traceback.format_exc()) + print(f"Error reading {filepath}") + return + + parquet_file.close() + end = time.time() + print(f"Took {end - start:.2f} seconds total", filepath) + + +def prepare( + input_dir: Path = Path("data/starcoderdata"), + output_dir: Path = Path("data/starcoder"), + tokenizer_path: Path = Path("checkpoints/Llama-2-7b-hf/"), + chunk_size: int = (2049 * 8192), + fast_dev_run: bool = False, +) -> None: + tokenizer = Tokenizer(tokenizer_path) + data_recipe = StarcoderDataRecipe( + tokenizer=tokenizer, chunk_size=chunk_size + ) + data_processor = DataProcessor( + input_dir=str(input_dir), + output_dir=str(output_dir), + fast_dev_run=fast_dev_run, + num_workers=os.cpu_count(), + num_downloaders=1, + ) + + start_time = time.time() + data_processor.run(data_recipe) + elapsed_time = time.time() - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/steps/__init__.py b/examples/llm_finetuning/steps/__init__.py new file mode 100644 index 00000000000..c9630597e75 --- /dev/null +++ b/examples/llm_finetuning/steps/__init__.py @@ -0,0 +1,21 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from steps.evaluate import evaluate +from steps.feature_engineering import feature_engineering +from steps.finetune import finetune +from steps.merge import merge diff --git a/examples/llm_finetuning/steps/evaluate.py b/examples/llm_finetuning/steps/evaluate.py new file mode 100644 index 00000000000..f9570dee734 --- /dev/null +++ b/examples/llm_finetuning/steps/evaluate.py @@ -0,0 +1,143 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import shutil +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional + +import torch +from evaluate.lm_eval_harness import run_eval_harness +from huggingface_hub import snapshot_download +from pydantic import BaseModel +from scripts.download import download_from_hub +from scripts.merge_lora import merge_lora +from typing_extensions import Annotated + +from steps.params import LoraParameters +from steps.utils import ( + convert_to_lit_checkpoint_if_necessary, + get_huggingface_access_token, +) +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__file__) + + +class EvaluationParameters(BaseModel): + """Parameters for the evaluation step. + + If `adapter_repo` is set, it will be merged with the model. Otherwise + the model itself will be evaluated. + """ + + model_repo: str + from_safetensors: bool = False + adapter_repo: Optional[str] = None + + precision: Optional[str] = None + quantize: Optional[ + Literal[ + "bnb.nf4", + "bnb.nf4-dq", + "bnb.fp4", + "bnb.fp4-dq", + "bnb.int8-training", + ] + ] = None + + lora: LoraParameters = LoraParameters() + + eval_tasks: List[str] = [ + "arc_challenge", + "piqa", + "hellaswag", + "hendrycksTest-*", + ] + num_fewshot: int = 0 + limit: Optional[int] = None + bootstrap_iters: int = 100000 + no_cache: bool = True + + +@step +def evaluate( + config: EvaluationParameters, +) -> Annotated[Dict[str, Any], "evaluation_results"]: + """Evaluate model. + + Args: + config: Configuration for this step. + """ + torch.set_float32_matmul_precision("high") + + access_token = get_huggingface_access_token() + + checkpoint_root_dir = Path("checkpoints") + checkpoint_dir = checkpoint_root_dir / config.model_repo + + if checkpoint_dir.exists(): + logger.info( + "Checkpoint directory already exists, skipping download..." 
+ ) + else: + download_from_hub( + repo_id=config.model_repo, + from_safetensors=config.from_safetensors, + checkpoint_dir=checkpoint_root_dir, + access_token=access_token, + ) + + convert_to_lit_checkpoint_if_necessary(checkpoint_dir=checkpoint_dir) + + if config.adapter_repo: + adapter_dir = Path("adapters") / config.adapter_repo + merged_dir = Path("output/merged") + + snapshot_download( + config.adapter_repo, + local_dir=adapter_dir, + local_dir_use_symlinks=False, + resume_download=True, + token=access_token, + ) + + lora_path = adapter_dir / "lit_model_lora_finetuned.pth" + merge_lora( + lora_path=lora_path, + checkpoint_dir=checkpoint_dir, + out_dir=merged_dir, + precision=config.precision, + **config.lora.dict(), + ) + + for path in Path(checkpoint_dir).glob("*.json"): + destination = Path(merged_dir) / path.name + shutil.copy(src=path, dst=destination) + + checkpoint_dir = merged_dir + + output_path = Path("output.json") + run_eval_harness( + checkpoint_dir=checkpoint_dir, + save_filepath=output_path, + **config.dict(exclude={"model_repo", "adapter_repo", "lora"}), + ) + + with open(output_path, "r") as f: + return json.load(f) diff --git a/examples/llm_finetuning/steps/feature_engineering.py b/examples/llm_finetuning/steps/feature_engineering.py new file mode 100644 index 00000000000..c47eb8a28e3 --- /dev/null +++ b/examples/llm_finetuning/steps/feature_engineering.py @@ -0,0 +1,89 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib +import json +from dataclasses import asdict +from pathlib import Path +from typing import Any, Dict + +from lit_gpt import Config +from materializers.directory_materializer import DirectoryMaterializer +from pydantic import BaseModel +from scripts.download import download_from_hub +from typing_extensions import Annotated + +from steps.utils import get_huggingface_access_token +from zenml import log_artifact_metadata, step + + +class FeatureEngineeringParameters(BaseModel): + """Parameters for the feature engineering step.""" + + model_repo: str + dataset_name: str + + prepare_kwargs: Dict[str, Any] = {} + + +@step(output_materializers=DirectoryMaterializer) +def feature_engineering( + config: FeatureEngineeringParameters, +) -> Annotated[Path, "dataset"]: + """Prepare the dataset. + + Args: + config: Configuration for this step. 
+ """ + access_token = get_huggingface_access_token() + + checkpoint_root_dir = Path("checkpoints") + download_from_hub( + repo_id=config.model_repo, + tokenizer_only=True, + checkpoint_dir=checkpoint_root_dir, + access_token=access_token, + ) + + checkpoint_dir = checkpoint_root_dir / config.model_repo + + model_name = checkpoint_dir.name + lit_config = Config.from_name(model_name) + lit_config_dict = asdict(lit_config) + with open(checkpoint_dir / "lit_config.json", "w") as json_config: + json.dump(lit_config_dict, json_config) + + log_artifact_metadata( + metadata={ + "model_name": model_name, + "model_config": lit_config_dict, + "dataset_name": config.dataset_name, + } + ) + destination_dir = Path("data") / config.dataset_name + + helper_module = importlib.import_module( + f"scripts.prepare_{config.dataset_name}" + ) + prepare_function = getattr(helper_module, "prepare") + + prepare_function( + checkpoint_dir=checkpoint_dir, + destination_path=destination_dir, + **config.prepare_kwargs, + ) + return destination_dir diff --git a/examples/llm_finetuning/steps/finetune.py b/examples/llm_finetuning/steps/finetune.py new file mode 100644 index 00000000000..fa3a9305e8b --- /dev/null +++ b/examples/llm_finetuning/steps/finetune.py @@ -0,0 +1,249 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import shutil +from pathlib import Path +from typing import Literal, Optional + +import torch +from finetune.lora import setup +from huggingface_hub import upload_folder +from lit_gpt.args import EvalArgs, IOArgs, TrainArgs +from materializers.directory_materializer import DirectoryMaterializer +from pydantic import BaseModel +from scripts.convert_lit_checkpoint import convert_lit_checkpoint +from scripts.download import download_from_hub +from scripts.merge_lora import merge_lora +from scripts.prepare_alpaca import prepare +from typing_extensions import Annotated + +from steps.params import LoraParameters +from steps.utils import ( + convert_to_lit_checkpoint_if_necessary, + get_huggingface_access_token, +) +from zenml import get_step_context, log_model_metadata, step +from zenml.logger import get_logger +from zenml.materializers import BuiltInMaterializer + +logger = get_logger(__file__) + + +class DataParameters(BaseModel): + """Data preprocessing parameters.""" + + seed: int = 42 + test_split_fraction: float = 0.03865 + mask_inputs: bool = False + ignore_index: int = -1 + max_seq_length: Optional[int] = None + + +class TrainingParameters(BaseModel): + """Training parameters.""" + + save_interval: int = 1000 + log_interval: int = 1 + global_batch_size: int = 64 + micro_batch_size: int = 4 + lr_warmup_steps: int = 100 + epochs: Optional[int] = None + epoch_size: Optional[int] = None + max_tokens: Optional[int] = None + max_seq_length: Optional[int] = None + + learning_rate: float = 1e-3 + weight_decay: float = 0.02 + beta1: float = 0.9 + beta2: float = 0.95 + max_norm: Optional[float] = None + min_lr: float = 6e-5 + + +class EvalParameters(BaseModel): + """Mid-training evaluation parameters.""" + + interval: int = 100 + max_new_tokens: int = 100 + max_iters: int = 100 + + +class FinetuningParameters(BaseModel): + """Parameters for the finetuning step.""" + + base_model_repo: str + from_safetensors: bool = False + + adapter_output_repo: Optional[str] = None + merged_output_repo: Optional[str] = None + convert_to_hf_checkpoint: bool = False + + precision: Optional[str] = None + quantize: Optional[ + Literal[ + "bnb.nf4", + "bnb.nf4-dq", + "bnb.fp4", + "bnb.fp4-dq", + "bnb.int8-training", + ] + ] = None + + data: DataParameters = DataParameters() + training: TrainingParameters = TrainingParameters() + eval: EvalParameters = EvalParameters() + lora: LoraParameters = LoraParameters() + + +@step(output_materializers=[DirectoryMaterializer, BuiltInMaterializer]) +def finetune( + config: FinetuningParameters, dataset_directory: Optional[Path] = None +) -> Annotated[Optional[Path], "adapter"]: + """Finetune model using LoRA. + + Args: + config: Configuration for this step. + """ + torch.set_float32_matmul_precision("high") + + access_token = get_huggingface_access_token() + + checkpoint_root_dir = Path("checkpoints") + checkpoint_dir = checkpoint_root_dir / config.base_model_repo + + if checkpoint_dir.exists(): + logger.info( + "Checkpoint directory already exists, skipping download..." 
+ ) + else: + download_from_hub( + repo_id=config.base_model_repo, + from_safetensors=config.from_safetensors, + checkpoint_dir=checkpoint_root_dir, + access_token=access_token, + ) + + convert_to_lit_checkpoint_if_necessary(checkpoint_dir=checkpoint_dir) + + if dataset_directory: + try: + dataset_name = ( + get_step_context() + .inputs["dataset_directory"] + .run_metadata["dataset_name"] + .value + ) + except KeyError: + dataset_name = "unknown_dataset" + else: + dataset_directory = Path("data/alpaca") + dataset_name = dataset_directory.name + prepare( + destination_path=dataset_directory, + checkpoint_dir=checkpoint_dir, + test_split_fraction=config.data.test_split_fraction, + seed=config.data.seed, + mask_inputs=config.data.mask_inputs, + ignore_index=config.data.ignore_index, + max_seq_length=config.data.max_seq_length, + ) + + model_name = checkpoint_dir.name + + log_model_metadata( + metadata={"model_name": model_name, "dataset_name": dataset_name} + ) + adapter_output_dir = Path("output/lora") / dataset_name / model_name + + io_args = IOArgs( + train_data_dir=dataset_directory, + val_data_dir=dataset_directory, + checkpoint_dir=checkpoint_dir, + out_dir=adapter_output_dir, + ) + train_args = TrainArgs(**config.training.dict()) + eval_args = EvalArgs(**config.eval.dict()) + setup( + devices=1, + io=io_args, + train=train_args, + eval=eval_args, + precision=config.precision, + quantize=config.quantize, + **config.lora.dict(), + ) + + if config.merged_output_repo: + lora_path = adapter_output_dir / "lit_model_lora_finetuned.pth" + + merge_output_dir = ( + Path("output/lora_merged") / dataset_name / model_name + ) + merge_lora( + lora_path=lora_path, + checkpoint_dir=checkpoint_dir, + out_dir=merge_output_dir, + precision=config.precision, + **config.lora.dict(), + ) + + for path in Path(checkpoint_dir).glob("*.json"): + destination = Path(merge_output_dir) / path.name + shutil.copy(src=path, dst=destination) + + if config.convert_to_hf_checkpoint: + upload_dir = ( + Path("output/lora_merged_hf") / dataset_name / model_name + ) + upload_dir.mkdir(parents=True, exist_ok=True) + convert_lit_checkpoint( + checkpoint_path=config.merged_output_repo / "lit_model.pth", + config_path=config.merged_output_repo / "lit_config.json", + output_path=upload_dir / "pytorch_model", + ) + else: + upload_dir = merge_output_dir + + commit = upload_folder( + repo_id=config.merged_output_repo, + folder_path=upload_dir, + token=access_token, + ) + log_model_metadata( + metadata={ + "merged_model_huggingface_commit_hash": commit.oid, + "merged_model_huggingface_commit_url": commit.commit_url, + } + ) + + if config.adapter_output_repo: + commit = upload_folder( + repo_id=config.adapter_output_repo, + folder_path=adapter_output_dir, + token=access_token, + ) + log_model_metadata( + metadata={ + "adapter_huggingface_commit_hash": commit.oid, + "adapter_huggingface_commit_url": commit.commit_url, + } + ) + return None + else: + # If the adapter should not be uploaded to the HF Hub, we store it + # in the artifact store + return adapter_output_dir diff --git a/examples/llm_finetuning/steps/merge.py b/examples/llm_finetuning/steps/merge.py new file mode 100644 index 00000000000..bc8fa90f716 --- /dev/null +++ b/examples/llm_finetuning/steps/merge.py @@ -0,0 +1,124 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import shutil +from pathlib import Path +from typing import Optional + +from huggingface_hub import snapshot_download, upload_folder +from pydantic import BaseModel +from scripts.convert_lit_checkpoint import convert_lit_checkpoint +from scripts.download import download_from_hub +from scripts.merge_lora import merge_lora + +from steps.params import LoraParameters +from steps.utils import ( + convert_to_lit_checkpoint_if_necessary, + get_huggingface_access_token, +) +from zenml import log_model_metadata, step +from zenml.logger import get_logger + +logger = get_logger(__file__) + + +class MergeParameters(BaseModel): + """Parameters for the merging step.""" + + base_model_repo: str + from_safetensors: bool = False + + adapter_repo: str + output_repo: str + convert_to_hf_checkpoint: bool = False + + precision: Optional[str] = None + lora: LoraParameters = LoraParameters() + + +@step +def merge(config: MergeParameters) -> None: + """Merge base model and LoRA adapter. + + Args: + config: Configuration for this step. + """ + access_token = get_huggingface_access_token() + + checkpoint_root_dir = Path("checkpoints") + base_model_dir = checkpoint_root_dir / config.base_model_repo + adapter_dir = Path("adapters") / config.adapter_repo + + if base_model_dir.exists(): + logger.info( + "Checkpoint directory already exists, skipping download..." + ) + else: + download_from_hub( + repo_id=config.base_model_repo, + from_safetensors=config.from_safetensors, + checkpoint_dir=checkpoint_root_dir, + access_token=access_token, + ) + + convert_to_lit_checkpoint_if_necessary(checkpoint_dir=base_model_dir) + + snapshot_download( + config.adapter_repo, + local_dir=adapter_dir, + local_dir_use_symlinks=False, + resume_download=True, + token=access_token, + ) + + lora_path = adapter_dir / "lit_model_lora_finetuned.pth" + merged_dir = Path("output/merged") + + merge_lora( + lora_path=lora_path, + checkpoint_dir=base_model_dir, + out_dir=merged_dir, + precision=config.precision, + **config.lora.dict(), + ) + + for path in Path(base_model_dir).glob("*.json"): + destination = Path(merged_dir) / path.name + shutil.copy(src=path, dst=destination) + + if config.convert_to_hf_checkpoint: + model_name = base_model_dir.name + + output_dir = Path("output/lora_merged_hf") / model_name + output_dir.mkdir(parents=True, exist_ok=True) + convert_lit_checkpoint( + checkpoint_path=merged_dir / "lit_model.pth", + config_path=merged_dir / "lit_config.json", + output_path=output_dir / "pytorch_model", + ) + else: + output_dir = merged_dir + + commit = upload_folder( + repo_id=config.output_repo, folder_path=output_dir, token=access_token + ) + log_model_metadata( + metadata={ + "merged_model_huggingface_commit_hash": commit.oid, + "merged_model_huggingface_commit_url": commit.commit_url, + } + ) diff --git a/examples/llm_finetuning/steps/params.py b/examples/llm_finetuning/steps/params.py new file mode 100644 index 00000000000..52e450de206 --- /dev/null +++ b/examples/llm_finetuning/steps/params.py @@ -0,0 +1,32 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pydantic import BaseModel + + +class LoraParameters(BaseModel): + """Lora specific parameters.""" + + lora_r: int = 8 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_query: bool = True + lora_key: bool = False + lora_value: bool = True + lora_projection: bool = False + lora_mlp: bool = False + lora_head: bool = False diff --git a/examples/llm_finetuning/steps/utils.py b/examples/llm_finetuning/steps/utils.py new file mode 100644 index 00000000000..c81238fef5c --- /dev/null +++ b/examples/llm_finetuning/steps/utils.py @@ -0,0 +1,54 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from pathlib import Path +from typing import Optional + +from scripts.convert_hf_checkpoint import convert_hf_checkpoint + +from zenml.client import Client + + +def get_huggingface_access_token() -> Optional[str]: + """Get access token for huggingface. + + Returns: + The access token if one was found. + """ + try: + return ( + Client() + .get_secret("huggingface_credentials") + .secret_values["token"] + ) + except KeyError: + return os.getenv("HF_TOKEN") + + +def convert_to_lit_checkpoint_if_necessary(checkpoint_dir: Path) -> None: + """Convert an HF checkpoint to a lit checkpoint if necessary. + + Args: + checkpoint_dir: The directory of the HF checkpoint. + """ + lit_model_path = checkpoint_dir / "lit_model.pth" + + if lit_model_path.is_file(): + return + + convert_hf_checkpoint(checkpoint_dir=checkpoint_dir) diff --git a/pyproject.toml b/pyproject.toml index f86386c9498..dadc1fe0620 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zenml" -version = "0.55.5" +version = "0.56.2" packages = [{ include = "zenml", from = "src" }] description = "ZenML: Write production-ready ML code." 
authors = ["ZenML GmbH "] @@ -60,8 +60,8 @@ python = ">=3.8,<3.12" python-dateutil = "^2.8.1" pyyaml = ">=6.0.1" rich = { extras = ["jupyter"], version = ">=12.0.0" } -sqlalchemy_utils = "0.41.1" -sqlmodel = ">=0.0.9, <=0.0.16" +sqlalchemy_utils = "0.38.3" +sqlmodel = "0.0.8" importlib_metadata = { version = "<=7.0.0", python = "<3.10" } # Optional dependencies for the ZenServer @@ -69,6 +69,7 @@ fastapi = { version = ">=0.75,<0.100", optional = true } uvicorn = { extras = ["standard"], version = ">=0.17.5", optional = true } python-multipart = { version = "~0.0.5", optional = true } pyjwt = { extras = ["crypto"], version = "2.7.*", optional = true } +fastapi-utils = { version = "~0.2.1", optional = true } orjson = { version = "~3.8.3", optional = true } Jinja2 = { version = "*", optional = true } ipinfo = { version = ">=4.4.3", optional = true } @@ -303,6 +304,12 @@ exclude = [ "venv", '__init__.py', 'src/zenml/cli/version.py', + # LitGPT files from the LLM Finetuning example + 'examples/llm_finetuning/evaluate', + 'examples/llm_finetuning/finetune', + 'examples/llm_finetuning/generate', + 'examples/llm_finetuning/lit_gpt', + 'examples/llm_finetuning/scripts', ] src = ["src", "test"] @@ -439,6 +446,7 @@ module = [ "bentoml.*", "multipart.*", "jose.*", + "fastapi_utils.*", "sqlalchemy_utils.*", "sky.*", "copier.*", diff --git a/scripts/check-security.sh b/scripts/check-security.sh index de3fbef2060..8894b8d1eb2 100755 --- a/scripts/check-security.sh +++ b/scripts/check-security.sh @@ -8,4 +8,6 @@ SRC=${1:-"src/zenml tests examples"} export ZENML_DEBUG=1 export ZENML_ANALYTICS_OPT_IN=false -bandit -r $SRC -ll +bandit -r $SRC -ll \ + --exclude examples/llm_finetuning/scripts/prepare_alpaca.py + diff --git a/scripts/install-dashboard.sh b/scripts/install-dashboard.sh index 7c4d492a694..445097ff0d9 100755 --- a/scripts/install-dashboard.sh +++ b/scripts/install-dashboard.sh @@ -25,6 +25,17 @@ verifySupported() { fi } +# checkGitIgnore checks if the dashboard directories are ignored by Git +checkGitIgnore() { + if [ -f ".gitignore" ]; then + if grep -q -E "(^|\/)dashboard($|\/)" ".gitignore" || grep -q -E "(^|\/)src\/zenml\/zen_server\/dashboard($|\/)" ".gitignore"; then + echo "Error: The '/dashboard' or 'src/zenml/zen_server/dashboard' directory is ignored by Git." + echo "Please remove the corresponding entries from the .gitignore file to proceed with the installation." + exit 1 + fi + fi +} + # checkTagProvided checks whether TAG has provided as an environment variable # so we can skip checkLatestVersion checkTagProvided() { @@ -143,10 +154,11 @@ done set +u verifySupported +checkGitIgnore checkTagProvided || checkLatestVersion if [[ ! 
-z "$TAG" ]]; then downloadFile verifyFile installFile fi -cleanup \ No newline at end of file +cleanup diff --git a/scripts/test-migrations-mariadb.sh b/scripts/test-migrations-mariadb.sh index 12167e2894e..30494823381 100755 --- a/scripts/test-migrations-mariadb.sh +++ b/scripts/test-migrations-mariadb.sh @@ -7,22 +7,22 @@ function run_tests_for_version() { set -e # Exit immediately if a command exits with a non-zero status local VERSION=$1 + export ZENML_ANALYTICS_OPT_IN=false + export ZENML_DEBUG=true + echo "===== Testing version $VERSION =====" mkdir test_starter - zenml init --template starter --path test_starter --template-with-defaults --test + zenml init --template starter --path test_starter --template-with-defaults <<< $'my@mail.com\n' cd test_starter - export ZENML_ANALYTICS_OPT_IN=false - export ZENML_DEBUG=true - echo "===== Installing sklearn integration =====" zenml integration export-requirements sklearn --output-file sklearn-requirements.txt uv pip install -r sklearn-requirements.txt rm sklearn-requirements.txt echo "===== Running starter template pipeline =====" - python3 run.py + python3 run.py --feature-pipeline --training-pipeline --no-cache # Add additional CLI tests here zenml version diff --git a/scripts/test-migrations-mysql.sh b/scripts/test-migrations-mysql.sh index 4a52ecfa927..804e7cda48d 100755 --- a/scripts/test-migrations-mysql.sh +++ b/scripts/test-migrations-mysql.sh @@ -17,7 +17,11 @@ function run_tests_for_version() { local VERSION=$1 # versions pre-templates and pre-init test flag # (zenml init --test allows for a non-interactive init) - local PRE_TEMPLATE_VERSIONS=("0.40.0" "0.40.3" "0.41.0" "0.43.0" "0.44.1" "0.44.3" "0.45.2" "0.45.3" "0.45.4" "0.45.5" "0.45.6" "0.46.0" "0.47.0") + local PRE_TEMPLATE_VERSIONS=("0.40.0" "0.40.3" "0.41.0" "0.43.0") + local PRE_ARGS_VERSIONS=("0.40.0" "0.40.3" "0.41.0" "0.43.0" "0.44.1" "0.44.3" "0.45.2" "0.45.3" "0.45.4" "0.45.5" "0.45.6" "0.46.0" "0.47.0" "0.50.0" "0.51.0" "0.52.0") + + export ZENML_ANALYTICS_OPT_IN=false + export ZENML_DEBUG=true echo "===== Testing version $VERSION =====" @@ -26,7 +30,7 @@ function run_tests_for_version() { copier copy -l --trust -r release/0.43.0 https://github.com/zenml-io/template-starter.git test_starter else mkdir test_starter - zenml init --template starter --path test_starter --template-with-defaults --test + zenml init --template starter --path test_starter --template-with-defaults <<< $'my@mail.com\n' fi cd test_starter @@ -40,7 +44,11 @@ function run_tests_for_version() { rm sklearn-requirements.txt echo "===== Running starter template pipeline =====" - python3 run.py + if printf '%s\n' "${PRE_ARGS_VERSIONS[@]}" | grep -q "^$VERSION$"; then + python3 run.py --no-cache + else + python3 run.py --feature-pipeline --training-pipeline --no-cache + fi # Add additional CLI tests here zenml version @@ -88,10 +96,10 @@ do # Get the major and minor version of Python PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') - # Check if the Python version is 3.9 and VERSION is > 0.47.0 + # Check if the Python version is 3.9 and VERSION is > 0.44.0 if [[ "$PYTHON_VERSION" == "3.9" ]]; then case "$VERSION" in - "0.47.0"|"0.50.0"|"0.51.0"|"0.52.0") + "0.44.1"|"0.44.3"|"0.45.2"|"0.45.3"|"0.45.4"|"0.45.5"|"0.45.6"|"0.46.0"|"0.47.0"|"0.50.0"|"0.51.0"|"0.52.0") uv pip install importlib_metadata ;; esac diff --git a/src/zenml/VERSION b/src/zenml/VERSION index 9aaab801597..cc169d8ce70 100644 --- a/src/zenml/VERSION +++ b/src/zenml/VERSION @@ -1 +1 
@@ -0.55.5 +0.56.2 \ No newline at end of file diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index 9ae9d77abfd..e8a8b1655bd 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -83,6 +83,10 @@ def copier_github_url(self) -> str: github_url="zenml-io/template-nlp", github_tag="2024.01.12", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), + llm_finetuning=ZenMLProjectTemplateLocation( + github_url="zenml-io/template-llm-finetuning", + github_tag="2024.03.18", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + ), ) @@ -98,9 +102,9 @@ def copier_github_url(self) -> str: type=str, required=False, help="Name or URL of the ZenML project template to use to initialize the " - "repository, Can be a string like `e2e_batch`, `nlp`, `starter` etc. or a " - "copier URL like gh:owner/repo_name. If not specified, no template is " - "used.", + "repository, Can be a string like `e2e_batch`, `nlp`, `llm_finetuning`, " + "`starter` etc. or a copier URL like gh:owner/repo_name. If not specified, " + "no template is used.", ) @click.option( "--template-tag", diff --git a/src/zenml/cli/served_model.py b/src/zenml/cli/served_model.py index 77540039744..3d708651121 100644 --- a/src/zenml/cli/served_model.py +++ b/src/zenml/cli/served_model.py @@ -29,6 +29,7 @@ ) from zenml.console import console from zenml.enums import StackComponentType +from zenml.model_deployers import BaseModelDeployer if TYPE_CHECKING: from zenml.model_deployers import BaseModelDeployer @@ -71,14 +72,6 @@ def models(ctx: click.Context) -> None: help="Get a list of all served models within the model-deployer stack " "component.", ) - @click.option( - "--pipeline", - "-p", - type=click.STRING, - default=None, - help="Show only served models that were deployed by the indicated " - "pipeline.", - ) @click.option( "--step", "-s", @@ -88,13 +81,21 @@ def models(ctx: click.Context) -> None: "pipeline step.", ) @click.option( - "--run-name", + "--pipeline-run-id", "-r", type=click.STRING, default=None, help="Show only served models that were deployed by the indicated " "pipeline run.", ) + @click.option( + "--pipeline-name", + "-p", + type=click.STRING, + default=None, + help="Show only served models that were deployed by the indicated " + "pipeline.", + ) @click.option( "--model", "-m", @@ -102,6 +103,20 @@ def models(ctx: click.Context) -> None: default=None, help="Show only served model versions for the given model name.", ) + @click.option( + "--model-version", + "-v", + type=click.STRING, + default=None, + help="Show only served model versions for the given model version.", + ) + @click.option( + "--flavor", + "-f", + type=click.STRING, + default=None, + help="Show only served model versions for the given model flavor.", + ) @click.option( "--running", is_flag=True, @@ -110,31 +125,38 @@ def models(ctx: click.Context) -> None: @click.pass_obj def list_models( model_deployer: "BaseModelDeployer", - pipeline: Optional[str], step: Optional[str], - run_name: Optional[str], + pipeline_name: Optional[str], + pipeline_run_id: Optional[str], model: Optional[str], + model_version: Optional[str], + flavor: Optional[str], running: bool, ) -> None: """List of all served models within the model-deployer stack component. Args: model_deployer: The model-deployer stack component. - pipeline: Show only served models that were deployed by the - indicated pipeline. step: Show only served models that were deployed by the indicated pipeline step. 
- run_name: Show only served models that were deployed by the + pipeline_run_id: Show only served models that were deployed by the indicated pipeline run. + pipeline_name: Show only served models that were deployed by the + indicated pipeline. model: Show only served model versions for the given model name. running: Show only model servers that are currently running. + model_version: Show only served model versions for the given model + version. + flavor: Show only served model versions for the given model flavor. """ services = model_deployer.find_model_server( running=running, - pipeline_name=pipeline, - run_name=run_name, + pipeline_name=pipeline_name, + pipeline_run_id=pipeline_run_id if pipeline_run_id else None, pipeline_step_name=step, model_name=model, + model_version=model_version, + flavor=flavor, ) if services: pretty_print_model_deployer( @@ -386,14 +408,16 @@ def get_model_service_logs( ) return - for line in model_deployer.get_model_server_logs( + model_logs = model_deployer.get_model_server_logs( served_models[0].uuid, follow=follow, tail=tail - ): - # don't pretty-print log lines that are already pretty-printed - if raw or line.startswith("\x1b["): - console.print(line, markup=False) - else: - try: - console.print(line) - except MarkupError: + ) + if model_logs: + for line in model_logs: + # don't pretty-print log lines that are already pretty-printed + if raw or line.startswith("\x1b["): console.print(line, markup=False) + else: + try: + console.print(line) + except MarkupError: + console.print(line, markup=False) diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index c6879f42e04..bcefd8182d4 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -1128,6 +1128,10 @@ def get_service_state_emoji(state: "ServiceState") -> str: return ":pause_button:" if state == ServiceState.ERROR: return ":heavy_exclamation_mark:" + if state == ServiceState.PENDING_STARTUP: + return ":hourglass:" + if state == ServiceState.SCALED_TO_ZERO: + return ":chart_decreasing:" return ":hourglass_not_done:" @@ -1142,15 +1146,18 @@ def pretty_print_model_deployer( """ model_service_dicts = [] for model_service in model_services: - served_model_info = model_deployer.get_model_server_info(model_service) dict_uuid = str(model_service.uuid) dict_pl_name = model_service.config.pipeline_name dict_pl_stp_name = model_service.config.pipeline_step_name - dict_model_name = served_model_info.get("MODEL_NAME", "") + dict_model_name = model_service.config.model_name + type = model_service.SERVICE_TYPE.type + flavor = model_service.SERVICE_TYPE.flavor model_service_dicts.append( { "STATUS": get_service_state_emoji(model_service.status.state), "UUID": dict_uuid, + "TYPE": type, + "FLAVOR": flavor, "PIPELINE_NAME": dict_pl_name, "PIPELINE_STEP_NAME": dict_pl_stp_name, "MODEL_NAME": dict_model_name, @@ -1277,9 +1284,10 @@ def print_served_model_configuration( **served_model_info, "UUID": str(model_service.uuid), "STATUS": get_service_state_emoji(model_service.status.state), + "TYPE": model_service.SERVICE_TYPE.type, + "FLAVOR": model_service.SERVICE_TYPE.flavor, "STATUS_MESSAGE": model_service.status.last_error, "PIPELINE_NAME": model_service.config.pipeline_name, - "RUN_NAME": model_service.config.run_name, "PIPELINE_STEP_NAME": model_service.config.pipeline_step_name, } diff --git a/src/zenml/client.py b/src/zenml/client.py index 66f144b4ebe..a32fea00152 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -150,6 +150,10 @@ ServiceConnectorResponse, ServiceConnectorTypeModel, 
ServiceConnectorUpdate, + ServiceFilter, + ServiceRequest, + ServiceResponse, + ServiceUpdate, StackFilter, StackRequest, StackResponse, @@ -175,7 +179,11 @@ WorkspaceResponse, WorkspaceUpdate, ) +from zenml.services.service import ServiceConfig +from zenml.services.service_status import ServiceState +from zenml.services.service_type import ServiceType from zenml.utils import io_utils, source_utils +from zenml.utils.dict_utils import dict_to_bytes from zenml.utils.filesync_model import FileSyncModel from zenml.utils.pagination_utils import depaginate from zenml.utils.uuid_utils import is_valid_uuid @@ -1478,6 +1486,227 @@ def _validate_stack_configuration(self, stack: StackRequest) -> None: "an Orchestrator." ) + # ----------------------------- Services ----------------------------------- + + def create_service( + self, + config: ServiceConfig, + service_type: ServiceType, + model_version_id: Optional[UUID] = None, + ) -> ServiceResponse: + """Registers a service. + + Args: + config: The configuration of the service. + service_type: The type of the service. + model_version_id: The ID of the model version to associate with the + service. + + Returns: + The registered service. + """ + service_request = ServiceRequest( + name=config.service_name, + service_type=service_type, + config=config.dict(), + workspace=self.active_workspace.id, + user=self.active_user.id, + model_version_id=model_version_id, + ) + # Register the service + return self.zen_store.create_service(service_request) + + def get_service( + self, + name_id_or_prefix: Union[str, UUID], + allow_name_prefix_match: bool = True, + hydrate: bool = True, + type: Optional[str] = None, + ) -> ServiceResponse: + """Gets a service. + + Args: + name_id_or_prefix: The name or ID of the service. + allow_name_prefix_match: If True, allow matching by name prefix. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + type: The type of the service. + + Returns: + The Service + """ + + def type_scoped_list_method( + hydrate: bool = True, + **kwargs: Any, + ) -> Page[ServiceResponse]: + """Call `zen_store.list_services` with type scoping. + + Args: + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + **kwargs: Keyword arguments to pass to `ServiceFilterModel`. + + Returns: + The type-scoped list of services. 
+ """ + service_filter_model = ServiceFilter(**kwargs) + if type: + service_filter_model.set_type(type=type) + service_filter_model.set_scope_workspace(self.active_workspace.id) + return self.zen_store.list_services( + filter_model=service_filter_model, + hydrate=hydrate, + ) + + return self._get_entity_by_id_or_name_or_prefix( + get_method=self.zen_store.get_service, + list_method=type_scoped_list_method, + name_id_or_prefix=name_id_or_prefix, + allow_name_prefix_match=allow_name_prefix_match, + hydrate=hydrate, + ) + + def list_services( + self, + sort_by: str = "created", + page: int = PAGINATION_STARTING_PAGE, + size: int = PAGE_SIZE_DEFAULT, + logical_operator: LogicalOperators = LogicalOperators.AND, + id: Optional[Union[UUID, str]] = None, + created: Optional[datetime] = None, + updated: Optional[datetime] = None, + type: Optional[str] = None, + flavor: Optional[str] = None, + workspace_id: Optional[Union[str, UUID]] = None, + user_id: Optional[Union[str, UUID]] = None, + hydrate: bool = False, + running: Optional[bool] = None, + service_name: Optional[str] = None, + pipeline_name: Optional[str] = None, + pipeline_run_id: Optional[str] = None, + pipeline_step_name: Optional[str] = None, + model_version_id: Optional[Union[str, UUID]] = None, + config: Optional[Dict[str, Any]] = None, + ) -> Page[ServiceResponse]: + """List all services. + + Args: + sort_by: The column to sort by + page: The page of items + size: The maximum size of all pages + logical_operator: Which logical operator to use [and, or] + id: Use the id of services to filter by. + created: Use to filter by time of creation + updated: Use the last updated date for filtering + type: Use the service type for filtering + flavor: Use the service flavor for filtering + workspace_id: The id of the workspace to filter by. + user_id: The id of the user to filter by. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + running: Use the running status for filtering + pipeline_name: Use the pipeline name for filtering + service_name: Use the service name or model name + for filtering + pipeline_step_name: Use the pipeline step name for filtering + model_version_id: Use the model version id for filtering + config: Use the config for filtering + pipeline_run_id: Use the pipeline run id for filtering + + Returns: + The Service response page. + """ + service_filter_model = ServiceFilter( + sort_by=sort_by, + page=page, + size=size, + logical_operator=logical_operator, + id=id, + created=created, + updated=updated, + type=type, + flavor=flavor, + workspace_id=workspace_id, + user_id=user_id, + running=running, + name=service_name, + pipeline_name=pipeline_name, + pipeline_step_name=pipeline_step_name, + model_version_id=model_version_id, + pipeline_run_id=pipeline_run_id, + config=dict_to_bytes(config) if config else None, + ) + service_filter_model.set_scope_workspace(self.active_workspace.id) + return self.zen_store.list_services( + filter_model=service_filter_model, hydrate=hydrate + ) + + def update_service( + self, + id: UUID, + name: Optional[str] = None, + service_source: Optional[str] = None, + admin_state: Optional[ServiceState] = None, + status: Optional[Dict[str, Any]] = None, + endpoint: Optional[Dict[str, Any]] = None, + labels: Optional[Dict[str, str]] = None, + prediction_url: Optional[str] = None, + health_check_url: Optional[str] = None, + model_version_id: Optional[UUID] = None, + ) -> ServiceResponse: + """Update a service. 
+ + Args: + id: The ID of the service to update. + name: The new name of the service. + admin_state: The new admin state of the service. + status: The new status of the service. + endpoint: The new endpoint of the service. + service_source: The new service source of the service. + labels: The new labels of the service. + prediction_url: The new prediction url of the service. + health_check_url: The new health check url of the service. + model_version_id: The new model version id of the service. + + Returns: + The updated service. + """ + service_update = ServiceUpdate() + if name: + service_update.name = name + if service_source: + service_update.service_source = service_source + if admin_state: + service_update.admin_state = admin_state + if status: + service_update.status = status + if endpoint: + service_update.endpoint = endpoint + if labels: + service_update.labels = labels + if prediction_url: + service_update.prediction_url = prediction_url + if health_check_url: + service_update.health_check_url = health_check_url + if model_version_id: + service_update.model_version_id = model_version_id + return self.zen_store.update_service( + service_id=id, update=service_update + ) + + def delete_service(self, name_id_or_prefix: UUID) -> None: + """Delete a service. + + Args: + name_id_or_prefix: The name or ID of the service to delete. + """ + service = self.get_service( + name_id_or_prefix, + allow_name_prefix_match=False, + ) + self.zen_store.delete_service(service_id=service.id) + # -------------------------------- Components ------------------------------ def get_stack_component( diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index eee3844d3af..2fef2890e04 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -26,6 +26,8 @@ DEFAULT_ZENML_JWT_TOKEN_LEEWAY, DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING, DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT, + DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY, + DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE, DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS, DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW, ENV_ZENML_SERVER_PREFIX, @@ -85,13 +87,13 @@ class ServerConfiguration(BaseModel): construct the OAuth 2.0 device authorization endpoint. If not set, a partial URL is returned to the client which is used to construct the full URL based on the server's root URL path. - device_expiration: The time in minutes that an OAuth 2.0 device is + device_expiration_minutes: The time in minutes that an OAuth 2.0 device is allowed to be used to authenticate with the ZenML server. If not set or if `jwt_token_expire_minutes` is not set, the devices are allowed to be used indefinitely. This controls the expiration time of the JWT tokens issued to clients after they have authenticated with the ZenML server using an OAuth 2.0 device. - trusted_device_expiration: The time in minutes that a trusted OAuth 2.0 + trusted_device_expiration_minutes: The time in minutes that a trusted OAuth 2.0 device is allowed to be used to authenticate with the ZenML server. If not set or if `jwt_token_expire_minutes` is not set, the devices are allowed to be used indefinitely. This controls the expiration @@ -114,11 +116,18 @@ class ServerConfiguration(BaseModel): the RBAC interface defined by `zenml.zen_server.rbac_interface.RBACInterface`. If not specified, RBAC will not be enabled for this server. 
+ feature_gate_implementation_source: Source pointing to a class + implementing the feature gate interface defined by + `zenml.zen_server.feature_gate.feature_gate_interface.FeatureGateInterface`. + If not specified, feature usage will not be gated/tracked for this + server. workload_manager_implementation_source: Source pointing to a class implementing the workload management interface. pipeline_run_auth_window: The default time window in minutes for which a pipeline run action is allowed to authenticate with the ZenML server. + login_rate_limit_minute: The number of login attempts allowed per minute. + login_rate_limit_day: The number of login attempts allowed per day. """ deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER @@ -152,11 +161,16 @@ class ServerConfiguration(BaseModel): external_server_id: Optional[UUID] = None rbac_implementation_source: Optional[str] = None + feature_gate_implementation_source: Optional[str] = None workload_manager_implementation_source: Optional[str] = None pipeline_run_auth_window: int = ( DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW ) + rate_limit_enabled: bool = False + login_rate_limit_minute: int = DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE + login_rate_limit_day: int = DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY + _deployment_id: Optional[UUID] = None @root_validator(pre=True) @@ -236,6 +250,15 @@ def rbac_enabled(self) -> bool: """ return self.rbac_implementation_source is not None + @property + def feature_gate_enabled(self) -> bool: + """Whether feature gating is enabled on the server or not. + + Returns: + Whether feature gating is enabled on the server or not. + """ + return self.feature_gate_implementation_source is not None + @property def workload_manager_enabled(self) -> bool: """Whether workload management is enabled on the server or not. diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 4abf73cd3d7..6da982c828d 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -13,10 +13,61 @@ # permissions and limitations under the License. """ZenML constants.""" +import json +import logging import os +from typing import Any, List, Optional, Type, TypeVar from zenml.enums import AuthScheme +T = TypeVar("T") + + +def handle_json_env_var( + var: str, + expected_type: Type[T], + default: Optional[List[str]] = None, +) -> Any: + """Converts a json env var into a Python object. + + Args: + var: The environment variable to convert. + default: The default value to return if the env var is not set. + expected_type: The type of the expected Python object. + + Returns: + The converted list value. + + Raises: + TypeError: In case the value of the environment variable is not of a + valid type. + + """ + # this needs to be here to avoid mutable defaults + if default is None: + default = [] + + value = os.getenv(var) + if value: + try: + loaded_value = json.loads(value) + # check if loaded value is of correct type + if expected_type is None or isinstance( + loaded_value, expected_type + ): + return loaded_value + else: + raise TypeError # if not correct type, raise TypeError + except (TypeError, json.JSONDecodeError): + # Use raw logging to avoid cyclic dependency + logging.warning( + f"Environment Variable {var} could not be loaded, into type " + f"{expected_type}, defaulting to: {default}." + ) + return default + else: + return default + def handle_bool_env_var(var: str, default: bool = False) -> bool: """Converts normal env var to boolean. 
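The `handle_json_env_var` helper introduced in the `constants.py` hunk above gives the server a single convention for list-valued settings such as `ZENML_SERVER_REPORTABLE_RESOURCES`. Below is a minimal sketch of that convention, using only the variable name, type check, and defaults visible in the diff; setting the variable from Python is purely illustrative here, since in a real deployment it would come from the server environment.

```python
import os

from zenml.constants import handle_json_env_var

# The server expects a JSON-encoded list, e.g.
#   ZENML_SERVER_REPORTABLE_RESOURCES='["pipeline_run", "model"]'
os.environ["ZENML_SERVER_REPORTABLE_RESOURCES"] = '["pipeline_run", "model"]'

resources = handle_json_env_var(
    "ZENML_SERVER_REPORTABLE_RESOURCES",
    expected_type=list,
    default=["pipeline_run", "model"],
)
print(resources)  # ['pipeline_run', 'model']

# Malformed or wrongly-typed values fall back to the default instead of raising.
os.environ["ZENML_SERVER_REPORTABLE_RESOURCES"] = "not-json"
print(handle_json_env_var("ZENML_SERVER_REPORTABLE_RESOURCES", expected_type=list))
# [] (the built-in default when no explicit default is passed)
```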
@@ -100,6 +151,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_" ENV_ZENML_SERVER_DEPLOYMENT_TYPE = f"{ENV_ZENML_SERVER_PREFIX}DEPLOYMENT_TYPE" ENV_ZENML_SERVER_AUTH_SCHEME = f"{ENV_ZENML_SERVER_PREFIX}AUTH_SCHEME" +ENV_ZENML_SERVER_REPORTABLE_RESOURCES = ( + f"{ENV_ZENML_SERVER_PREFIX}REPORTABLE_RESOURCES" +) # Logging variables IS_DEBUG_ENV: bool = handle_bool_env_var(ENV_ZENML_DEBUG, default=False) @@ -178,6 +232,18 @@ def handle_int_env_var(var: str, default: int = 0) -> int: DEFAULT_HTTP_TIMEOUT = 30 ZENML_API_KEY_PREFIX = "ZENKEY_" DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW = 60 * 48 # 48 hours +DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE = 5 +DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY = 1000 + +# Configurations to decide which resources report their usage and check for +# entitlement in the case of a cloud deployment. Expected Format is this: +# ENV_ZENML_REPORTABLE_RESOURCES='["Foo", "bar"]' +REPORTABLE_RESOURCES: List[str] = handle_json_env_var( + ENV_ZENML_SERVER_REPORTABLE_RESOURCES, + expected_type=list, + default=["pipeline_run", "model"], +) +REQUIRES_CUSTOM_RESOURCE_REPORTING = ["pipeline"] # API Endpoint paths: ACTIVATE = "/activate" @@ -208,10 +274,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: LOGIN = "/login" LOGOUT = "/logout" LOGS = "/logs" -MODEL_VERSION_ARTIFACTS = "/model_version_artifacts" -MODEL_VERSION_PIPELINE_RUNS = "/model_version_pipeline_runs" -MODEL_VERSIONS = "/model_versions" -MODELS = "/models" PIPELINE_BUILDS = "/pipeline_builds" PIPELINE_CONFIGURATION = "/pipeline-configuration" PIPELINE_DEPLOYMENTS = "/pipeline_deployments" @@ -230,6 +292,12 @@ def handle_int_env_var(var: str, default: int = 0) -> int: SERVICE_CONNECTOR_RESOURCES = "/resources" SERVICE_CONNECTOR_TYPES = "/service_connector_types" SERVICE_CONNECTOR_VERIFY = "/verify" +SERVICE_CONNECTOR_RESOURCES = "/resources" +MODELS = "/models" +MODEL_VERSIONS = "/model_versions" +MODEL_VERSION_ARTIFACTS = "/model_version_artifacts" +MODEL_VERSION_PIPELINE_RUNS = "/model_version_pipeline_runs" +SERVICES = "/services" SERVICE_CONNECTORS = "/service_connectors" STACKS = "/stacks" STACK_COMPONENTS = "/components" diff --git a/src/zenml/container_registries/base_container_registry.py b/src/zenml/container_registries/base_container_registry.py index 4617b7db588..d8f641cf4b4 100644 --- a/src/zenml/container_registries/base_container_registry.py +++ b/src/zenml/container_registries/base_container_registry.py @@ -142,7 +142,9 @@ def docker_client(self) -> "DockerClient": ) self._docker_client = client else: - self._docker_client = DockerClient.from_env() + self._docker_client = ( + docker_utils._try_get_docker_client_from_env() + ) credentials = self.credentials if credentials: diff --git a/src/zenml/enums.py b/src/zenml/enums.py index e92da7e8871..67f6ace00f6 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -54,6 +54,13 @@ class VisualizationType(StrEnum): MARKDOWN = "markdown" +class ZenMLServiceType(StrEnum): + """All possible types a service can have.""" + + ZEN_SERVER = "zen_server" + MODEL_SERVING = "model-serving" + + class ExecutionStatus(StrEnum): """Enum that represents the current status of a step or pipeline run.""" diff --git a/src/zenml/event_sources/webhooks/base_webhook_event_source.py b/src/zenml/event_sources/webhooks/base_webhook_event_source.py index 0fac4f73592..035c4aabad9 100644 --- a/src/zenml/event_sources/webhooks/base_webhook_event_source.py +++ 
b/src/zenml/event_sources/webhooks/base_webhook_event_source.py @@ -154,10 +154,12 @@ def _validate_webhook_event_signature( Raises: AuthorizationException: If the signature validation fails. """ - signature_header = headers.get("x-hub-signature-256") + signature_header = headers.get("x-hub-signature-256") or headers.get( + "x-hub-signature" + ) if not signature_header: raise AuthorizationException( - "x-hub-signature-256 header is missing!" + "x-hub-signature-256 or x-hub-signature header is missing!" ) if not self.is_valid_signature( diff --git a/src/zenml/exceptions.py b/src/zenml/exceptions.py index f7e67339033..5ef6b0315af 100644 --- a/src/zenml/exceptions.py +++ b/src/zenml/exceptions.py @@ -253,6 +253,10 @@ class InputResolutionError(ZenMLBaseException): """Raised when step input resolving failed.""" +class SubscriptionUpgradeRequiredError(ZenMLBaseException): + """Raised when user tries to perform an action outside their current subscription tier.""" + + class HydrationError(ZenMLBaseException): """Raised when the model hydration failed.""" diff --git a/src/zenml/image_builders/local_image_builder.py b/src/zenml/image_builders/local_image_builder.py index 16a1fd29c1f..5a918e934f9 100644 --- a/src/zenml/image_builders/local_image_builder.py +++ b/src/zenml/image_builders/local_image_builder.py @@ -17,8 +17,6 @@ import tempfile from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast -from docker.client import DockerClient - from zenml.image_builders import ( BaseImageBuilder, BaseImageBuilderConfig, @@ -106,7 +104,7 @@ def build( # authenticated to access additional registries docker_client = container_registry.docker_client else: - docker_client = DockerClient.from_env() + docker_client = docker_utils._try_get_docker_client_from_env() with tempfile.TemporaryFile(mode="w+b") as f: build_context.write_archive(f) diff --git a/src/zenml/integrations/__init__.py b/src/zenml/integrations/__init__.py index 3b2e37ca377..4d1b4033eb2 100644 --- a/src/zenml/integrations/__init__.py +++ b/src/zenml/integrations/__init__.py @@ -23,6 +23,7 @@ from zenml.integrations.aws import AWSIntegration # noqa from zenml.integrations.azure import AzureIntegration # noqa from zenml.integrations.bentoml import BentoMLIntegration # noqa +from zenml.integrations.bitbucket import BitbucketIntegration # noqa from zenml.integrations.deepchecks import DeepchecksIntegration # noqa from zenml.integrations.discord import DiscordIntegration # noqa from zenml.integrations.evidently import EvidentlyIntegration # noqa diff --git a/src/zenml/integrations/airflow/__init__.py b/src/zenml/integrations/airflow/__init__.py index ddf8a79a1fd..7446a195a61 100644 --- a/src/zenml/integrations/airflow/__init__.py +++ b/src/zenml/integrations/airflow/__init__.py @@ -17,7 +17,7 @@ orchestrator. You can enable it by registering the Airflow orchestrator with the CLI tool, then bootstrap using the ``zenml orchestrator up`` command. """ -from typing import List, Type +from typing import List, Optional, Type from zenml.integrations.constants import AIRFLOW from zenml.integrations.integration import Integration @@ -32,14 +32,7 @@ class AirflowIntegration(Integration): NAME = AIRFLOW # remove pendulum version requirement once Airflow supports # pendulum>-3.0.0 - REQUIREMENTS = [ - "apache-airflow~=2.4.0", - "pendulum<3.0.0", - # We need to add this as an extra dependency to manually downgrade - # SQLModel. Otherwise, the initial installation of ZenML installs - # a higher version SQLModel and a version mismatch is created. 
- "sqlmodel>=0.0.9,<=0.0.16", - ] + REQUIREMENTS = ["apache-airflow~=2.4.0", "pendulum<3.0.0"] @classmethod def flavors(cls) -> List[Type[Flavor]]: diff --git a/src/zenml/integrations/bentoml/constants.py b/src/zenml/integrations/bentoml/constants.py index 318913cd19e..19395866834 100644 --- a/src/zenml/integrations/bentoml/constants.py +++ b/src/zenml/integrations/bentoml/constants.py @@ -15,5 +15,5 @@ DEFAULT_BENTO_FILENAME = "zenml_exported.bento" BENTOML_DEFAULT_PORT = 3000 -BENTOML_HEALTHCHECK_URL_PATH = "healthz" +BENTOML_HEALTHCHECK_URL_PATH = "readyz" BENTOML_PREDICTION_URL_PATH = "" diff --git a/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py b/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py index 746d13a8f67..6f782f6ce4d 100644 --- a/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py +++ b/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py @@ -15,13 +15,11 @@ import os import shutil -from pathlib import Path -from typing import ClassVar, Dict, List, Optional, Type, cast +from typing import ClassVar, Dict, Optional, Type, cast from uuid import UUID from zenml.config.global_config import GlobalConfiguration from zenml.constants import DEFAULT_SERVICE_START_STOP_TIMEOUT -from zenml.integrations.bentoml.constants import BENTOML_DEFAULT_PORT from zenml.integrations.bentoml.flavors.bentoml_model_deployer_flavor import ( BentoMLModelDeployerConfig, BentoMLModelDeployerFlavor, @@ -32,8 +30,6 @@ ) from zenml.logger import get_logger from zenml.model_deployers import BaseModelDeployer, BaseModelDeployerFlavor -from zenml.services import ServiceRegistry -from zenml.services.local.local_service import SERVICE_DAEMON_CONFIG_FILE_NAME from zenml.services.service import BaseService, ServiceConfig from zenml.utils.io_utils import create_dir_recursive_if_not_exists @@ -126,7 +122,8 @@ def get_model_server_info( # type: ignore[override] ) return { - "PREDICTION_URL": service_instance.prediction_url, + "HEALTH_CHECK_URL": service_instance.get_healthcheck_url(), + "PREDICTION_URL": service_instance.get_prediction_url(), "BENTO_TAG": service_instance.config.bento, "MODEL_NAME": service_instance.config.model_name, "MODEL_URI": service_instance.config.model_uri, @@ -136,10 +133,10 @@ def get_model_server_info( # type: ignore[override] "PREDICTION_APIS_URLS": predictions_apis_urls, } - def deploy_model( + def perform_deploy_model( self, + id: UUID, config: ServiceConfig, - replace: bool = False, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, ) -> BaseService: """Create a new BentoML deployment service or update an existing one. @@ -171,10 +168,8 @@ def deploy_model( and the others are deleted. Args: + id: the UUID of the BentoML model deployer. config: the configuration of the model to be deployed with BentoML. - replace: set this flag to True to find and update an equivalent - BentoML deployment server with the new model instead of - creating and starting a new deployment server. timeout: the timeout in seconds to wait for the BentoML server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the BentoML @@ -185,49 +180,11 @@ def deploy_model( interact with the BentoML model http server. 
""" config = cast(BentoMLDeploymentConfig, config) - service = None - - # if replace is True, remove all existing services - if replace is True: - existing_services = self.find_model_server( - pipeline_name=config.pipeline_name, - pipeline_step_name=config.pipeline_step_name, - model_name=config.model_name, - ) - - for existing_service in existing_services: - if service is None: - # keep the most recently created service - service = cast(BentoMLDeploymentService, existing_service) - try: - # delete the older services and don't wait for them to - # be deprovisioned - self._clean_up_existing_service( - existing_service=cast( - BentoMLDeploymentService, existing_service - ), - timeout=timeout, - force=True, - ) - except RuntimeError: - # ignore errors encountered while stopping old services - pass - if service: - logger.info( - f"Updating an existing BentoML deployment service: {service}" - ) - - # set the root runtime path with the stack component's UUID - config.root_runtime_path = self.local_path - service.stop(timeout=timeout, force=True) - service.update(config) - service.start(timeout=timeout) - else: - # create a new BentoMLDeploymentService instance - service = self._create_new_service(timeout, config) - logger.info(f"Created a new BentoML deployment service: {service}") - - return cast(BaseService, service) + service = self._create_new_service( + id=id, timeout=timeout, config=config + ) + logger.info(f"Created a new BentoML deployment service: {service}") + return service def _clean_up_existing_service( self, @@ -246,12 +203,13 @@ def _clean_up_existing_service( # of workers etc.the step implementation will create a new config using # all values from the user and add values like pipeline name, model_uri def _create_new_service( - self, timeout: int, config: BentoMLDeploymentConfig + self, id: UUID, timeout: int, config: BentoMLDeploymentConfig ) -> BentoMLDeploymentService: """Creates a new BentoMLDeploymentService. Args: - timeout: the timeout in seconds to wait for the BentoML http server + id: the ID of the BentoML deployment service to be created or updated. + timeout: the timeout in seconds to wait for the BentoML server to be provisioned and successfully started or updated. config: the configuration of the model to be deployed with BentoML. @@ -262,197 +220,61 @@ def _create_new_service( # set the root runtime path with the stack component's UUID config.root_runtime_path = self.local_path # create a new service for the new model - service = BentoMLDeploymentService(config) + service = BentoMLDeploymentService(uuid=id, config=config) service.start(timeout=timeout) return service - def find_model_server( + def perform_stop_model( self, - running: bool = False, - service_uuid: Optional[UUID] = None, - pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, - ) -> List[BaseService]: - """Finds one or more model servers that match the given criteria. - - Args: - running: If true, only running services will be returned. - service_uuid: The UUID of the service that was originally used - to deploy the model. - pipeline_name: Name of the pipeline that the deployed model was part - of. - run_name: ID of the pipeline run which the deployed model - was part of. - pipeline_step_name: The name of the pipeline model deployment step - that deployed the model. - model_name: Name of the deployed model. 
- model_uri: URI of the deployed model. - model_type: Type/format of the deployed model. Not used in this - BentoML case. - - Returns: - One or more Service objects representing model servers that match - the input search criteria. - - Raises: - TypeError: if any of the input arguments are of an invalid type. - """ - services = [] - config = BentoMLDeploymentConfig( - model_name=model_name or "", - bento="", - port=BENTOML_DEFAULT_PORT, - model_uri=model_uri or "", - working_dir="", - pipeline_name=pipeline_name or "", - pipeline_run_id=run_name or "", - run_name=run_name or "", - pipeline_step_name=pipeline_step_name or "", - ) - - # find all services that match the input criteria - for root, _, files in os.walk(self.local_path): - if service_uuid and Path(root).name != str(service_uuid): - continue - for file in files: - if file == SERVICE_DAEMON_CONFIG_FILE_NAME: - service_config_path = os.path.join(root, file) - logger.debug( - "Loading service daemon configuration from %s", - service_config_path, - ) - existing_service_config = None - with open(service_config_path, "r") as f: - existing_service_config = f.read() - existing_service = ( - ServiceRegistry().load_service_from_json( - existing_service_config - ) - ) - if not isinstance( - existing_service, BentoMLDeploymentService - ): - raise TypeError( - f"Expected service type BentoMLDeploymentService but got " - f"{type(existing_service)} instead" - ) - existing_service.update_status() - if self._matches_search_criteria(existing_service, config): - if not running or existing_service.is_running: - services.append( - cast(BaseService, existing_service) - ) - - return services - - def _matches_search_criteria( - self, - existing_service: BentoMLDeploymentService, - config: BentoMLDeploymentConfig, - ) -> bool: - """Returns true if a service matches the input criteria. - - If any of the values in the input criteria are None, they are ignored. - This allows listing services just by common pipeline names or step - names, etc. - - Args: - existing_service: The materialized Service instance derived from - the config of the older (existing) service - config: The BentoMlDeploymentConfig object passed to the - deploy_model function holding parameters of the new service - to be created. - - Returns: - True if the service matches the input criteria. - """ - existing_service_config = existing_service.config - - # check if the existing service matches the input criteria - if ( - ( - not config.pipeline_name - or existing_service_config.pipeline_name - == config.pipeline_name - ) - and ( - not config.model_name - or existing_service_config.model_name == config.model_name - ) - and ( - not config.pipeline_step_name - or existing_service_config.pipeline_step_name - == config.pipeline_step_name - ) - and ( - not config.run_name - or existing_service_config.run_name == config.run_name - ) - ): - return True - - return False - - def stop_model_server( - self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False, - ) -> None: + ) -> BaseService: """Method to stop a model server. Args: - uuid: UUID of the model server to stop. + service: The service to stop. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. 
- """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, stop it - if existing_services: - existing_services[0].stop(timeout=timeout, force=force) + Returns: + The stopped service. + """ + service.stop(timeout=timeout, force=force) + return service - def start_model_server( - self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT - ) -> None: + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, + ) -> BaseService: """Method to start a model server. Args: - uuid: UUID of the model server to start. + service: The service to start. timeout: Timeout in seconds to wait for the service to start. - """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, start it - if existing_services: - existing_services[0].start(timeout=timeout) + Returns: + The started service. + """ + service.start(timeout=timeout) + return service - def delete_model_server( + def perform_delete_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False, ) -> None: """Method to delete all configuration of a model server. Args: - uuid: UUID of the model server to delete. + service: The service to delete. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - - # if the service exists, clean it up - if existing_services: - service = cast(BentoMLDeploymentService, existing_services[0]) - self._clean_up_existing_service( - existing_service=service, timeout=timeout, force=force - ) + service = cast(BentoMLDeploymentService, service) + self._clean_up_existing_service( + existing_service=service, timeout=timeout, force=force + ) diff --git a/src/zenml/integrations/bentoml/services/bentoml_deployment.py b/src/zenml/integrations/bentoml/services/bentoml_deployment.py index 138d3039c9b..2a826fb5077 100644 --- a/src/zenml/integrations/bentoml/services/bentoml_deployment.py +++ b/src/zenml/integrations/bentoml/services/bentoml_deployment.py @@ -94,8 +94,8 @@ class SSLBentoMLParametersConfig(BaseModel): ssl_certfile: Optional[str] = None ssl_keyfile: Optional[str] = None ssl_keyfile_password: Optional[str] = None - ssl_version: Optional[str] = None - ssl_cert_reqs: Optional[str] = None + ssl_version: Optional[int] = None + ssl_cert_reqs: Optional[int] = None ssl_ca_certs: Optional[str] = None ssl_ciphers: Optional[str] = None @@ -121,9 +121,9 @@ class BentoMLDeploymentConfig(LocalDaemonServiceConfig): bento: str bento_uri: Optional[str] = None apis: List[str] = [] - workers: Optional[int] = 1 - port: Optional[int] = None - backlog: Optional[int] = 2048 + workers: int = 1 + port: int + backlog: int = 2048 production: bool = False working_dir: str host: Optional[str] = None @@ -147,6 +147,7 @@ class BentoMLDeploymentService(LocalDaemonService, BaseDeploymentService): type="model-serving", flavor="bentoml", description="BentoML prediction service", + logo_url="https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/bentoml.png", ) config: BentoMLDeploymentConfig @@ -203,9 +204,9 @@ def run(self) -> None: serve_http_production( self.config.bento, working_dir=self.config.working_dir, - port=self.endpoint.status.port, + port=self.config.port, api_workers=self.config.workers, - host=self.endpoint.status.hostname, + 
host=self.config.host or DEFAULT_LOCAL_SERVICE_IP_ADDRESS, backlog=self.config.backlog, ssl_certfile=ssl_params.ssl_certfile, ssl_keyfile=ssl_params.ssl_keyfile, diff --git a/src/zenml/integrations/bentoml/steps/bentoml_deployer.py b/src/zenml/integrations/bentoml/steps/bentoml_deployer.py index 225126233ed..4bb11e1957f 100644 --- a/src/zenml/integrations/bentoml/steps/bentoml_deployer.py +++ b/src/zenml/integrations/bentoml/steps/bentoml_deployer.py @@ -87,16 +87,8 @@ def bentoml_model_deployer_step( # get pipeline name, step name and run id step_context = get_step_context() pipeline_name = step_context.pipeline.name - run_name = step_context.pipeline_run.name step_name = step_context.step_run.name - # fetch existing services with same pipeline name, step name and model name - existing_services = model_deployer.find_model_server( - pipeline_name=pipeline_name, - pipeline_step_name=step_name, - model_name=model_name, - ) - # Return the apis endpoint of the defined service to use in the predict. # This is a workaround to get the endpoints of the service defined as functions # from the user code in the BentoML service. @@ -123,7 +115,6 @@ def service_apis(bento_tag: str) -> List[str]: working_dir=working_dir or source_utils.get_source_root(), port=port, pipeline_name=pipeline_name, - run_name=run_name, pipeline_step_name=step_name, ssl_parameters=SSLBentoMLParametersConfig( ssl_certfile=ssl_certfile, @@ -136,8 +127,13 @@ def service_apis(bento_tag: str) -> List[str]: ), ) + # fetch existing services with same pipeline name, step name and model name + existing_services = model_deployer.find_model_server( + config=predictor_cfg.dict(), + service_type=BentoMLDeploymentService.SERVICE_TYPE, + ) + # Creating a new service with inactive state and status by default - service = BentoMLDeploymentService(predictor_cfg) if existing_services: service = cast(BentoMLDeploymentService, existing_services[0]) @@ -159,6 +155,7 @@ def service_apis(bento_tag: str) -> List[str]: replace=True, config=predictor_cfg, timeout=timeout, + service_type=BentoMLDeploymentService.SERVICE_TYPE, ), ) diff --git a/src/zenml/integrations/bitbucket/__init__.py b/src/zenml/integrations/bitbucket/__init__.py new file mode 100644 index 00000000000..770f355a2ee --- /dev/null +++ b/src/zenml/integrations/bitbucket/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Initialization of the bitbucket ZenML integration.""" +from typing import List, Type + +from zenml.integrations.constants import BITBUCKET +from zenml.integrations.integration import Integration +from zenml.plugins.base_plugin_flavor import BasePluginFlavor + +BITBUCKET_EVENT_FLAVOR = "bitbucket" + + +class BitbucketIntegration(Integration): + """Definition of bitbucket integration for ZenML.""" + + NAME = BITBUCKET + REQUIREMENTS: List[str] = [] + + @classmethod + def plugin_flavors(cls) -> List[Type[BasePluginFlavor]]: + """Declare the event flavors for the bitbucket integration. 
+ + Returns: + List of stack component flavors for this integration. + """ + from zenml.integrations.bitbucket.plugins import BitbucketWebhookEventSourceFlavor + + return [BitbucketWebhookEventSourceFlavor] + + +BitbucketIntegration.check_installation() diff --git a/src/zenml/integrations/bitbucket/plugins/__init__.py b/src/zenml/integrations/bitbucket/plugins/__init__.py new file mode 100644 index 00000000000..c5eb3accaed --- /dev/null +++ b/src/zenml/integrations/bitbucket/plugins/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Bitbucket event flavors.""" + +from zenml.integrations.bitbucket.plugins.bitbucket_webhook_event_source_flavor import BitbucketWebhookEventSourceFlavor + +__all__ = [ + "BitbucketWebhookEventSourceFlavor" +] \ No newline at end of file diff --git a/src/zenml/integrations/bitbucket/plugins/bitbucket_webhook_event_source_flavor.py b/src/zenml/integrations/bitbucket/plugins/bitbucket_webhook_event_source_flavor.py new file mode 100644 index 00000000000..a389b6677ee --- /dev/null +++ b/src/zenml/integrations/bitbucket/plugins/bitbucket_webhook_event_source_flavor.py @@ -0,0 +1,43 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Bitbucket webhook event source flavor.""" + +from typing import ClassVar, Type + +from zenml.event_sources.webhooks.base_webhook_event_source import ( + BaseWebhookEventSourceFlavor, +) +from zenml.integrations.bitbucket import BITBUCKET_EVENT_FLAVOR +from zenml.integrations.bitbucket.plugins.event_sources.bitbucket_webhook_event_source import ( + BitbucketWebhookEventFilterConfiguration, + BitbucketWebhookEventSourceConfiguration, + BitbucketWebhookEventSourceHandler, +) + + +class BitbucketWebhookEventSourceFlavor(BaseWebhookEventSourceFlavor): + """Enables users to configure Bitbucket event sources.""" + + FLAVOR: ClassVar[str] = BITBUCKET_EVENT_FLAVOR + PLUGIN_CLASS: ClassVar[Type[BitbucketWebhookEventSourceHandler]] = ( + BitbucketWebhookEventSourceHandler + ) + + # EventPlugin specific + EVENT_SOURCE_CONFIG_CLASS: ClassVar[ + Type[BitbucketWebhookEventSourceConfiguration] + ] = BitbucketWebhookEventSourceConfiguration + EVENT_FILTER_CONFIG_CLASS: ClassVar[ + Type[BitbucketWebhookEventFilterConfiguration] + ] = BitbucketWebhookEventFilterConfiguration diff --git a/src/zenml/integrations/bitbucket/plugins/event_sources/__init__.py b/src/zenml/integrations/bitbucket/plugins/event_sources/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py b/src/zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py new file mode 100644 index 00000000000..c9a6c247958 --- /dev/null +++ b/src/zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py @@ -0,0 +1,490 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Implementation of the Bitbucket webhook event source.""" + +from typing import Any, Dict, List, Optional, Type, Union +from uuid import UUID + +from pydantic import BaseModel, Extra, Field + +from zenml.enums import SecretScope +from zenml.event_sources.base_event import ( + BaseEvent, +) +from zenml.event_sources.base_event_source import EventSourceConfig +from zenml.event_sources.webhooks.base_webhook_event_source import ( + BaseWebhookEventSourceFlavor, + BaseWebhookEventSourceHandler, + WebhookEventFilterConfig, + WebhookEventSourceConfig, +) +from zenml.exceptions import AuthorizationException +from zenml.logger import get_logger +from zenml.models import ( + EventSourceRequest, + EventSourceResponse, + EventSourceUpdate, + SecretRequest, + SecretUpdate, +) +from zenml.utils.enum_utils import StrEnum +from zenml.utils.string_utils import random_str + +logger = get_logger(__name__) + +# -------------------- Utils ----------------------------------- + + +class BitbucketEventType(StrEnum): + """Collection of all possible Bitbucket Events.""" + + PUSH_EVENT = "push_event" + TAG_EVENT = "tag_event" + + +# -------------------- Bitbucket Event Models ---------------------------------- + + +class User(BaseModel): + """Bitbucket User.""" + + name: Optional[str] + email: Optional[str] + username: Optional[str] + + +class Commit(BaseModel): + """Bitbucket Commit.""" + + hash: str + message: str + links: Dict[str, Any] + author: User + + +class Repository(BaseModel): + """Bitbucket Repository.""" + + uuid: str + name: str + full_name: str + links: Dict[str, Any] + + +class PushChange(BaseModel): + """Bitbucket Push Change.""" + + new: Optional[Dict[str, Any]] + old: Optional[Dict[str, Any]] + commits: List[Commit] + + +class Push(BaseModel): + """Bitbucket Push.""" + + changes: List[PushChange] + + +class BitbucketEvent(BaseEvent): + """Bitbucket Event.""" + + actor: User + repository: Repository + push: Push + + class Config: + """Pydantic configuration class.""" + + extra = Extra.allow + + @property + def branch(self) -> Optional[str]: + """The branch the event happened on. + + Returns: + The branch name. + """ + if self.push.changes[0].new: + branch = self.push.changes[0].new.get("name", None) + if self.push.changes[0].new.get("name", None): + return str(branch) + return None + + @property + def event_type(self) -> Union[BitbucketEventType, str]: + """The type of Bitbucket event. + + Args: + The type of the event based on Bitbucket specific fields. + + Returns: + The type of the event. + """ + is_push_event = all( + [change.new is not None for change in self.push.changes] + ) + is_tag_event = all( + [ + change.new.get("type") == "tag" + for change in self.push.changes + if change.new + ] + ) + + if is_push_event: + return BitbucketEventType.PUSH_EVENT + elif is_tag_event: + return BitbucketEventType.TAG_EVENT + else: + return "unknown" + + +# -------------------- Configuration Models ---------------------------------- + + +class BitbucketWebhookEventFilterConfiguration(WebhookEventFilterConfig): + """Configuration for Bitbucket event filters.""" + + repo: Optional[str] + branch: Optional[str] + event_type: Optional[BitbucketEventType] + + def event_matches_filter(self, event: BaseEvent) -> bool: + """Checks the filter against the inbound event. 
+ + Args: + event: The incoming event + + Returns: + Whether the event matches the filter + """ + if not isinstance(event, BitbucketEvent): + return False + if self.event_type and event.event_type != self.event_type: + # Mismatch for the action + return False + if self.repo and event.repository.full_name != self.repo: + # Mismatch for the repository + return False + if self.branch and event.branch != self.branch: + # Mismatch for the branch + return False + return True + + +class BitbucketWebhookEventSourceConfiguration(WebhookEventSourceConfig): + """Configuration for Bitbucket source filters.""" + + webhook_secret: Optional[str] = Field( + default=None, + title="The webhook secret for the event source.", + ) + webhook_secret_id: Optional[UUID] = Field( + default=None, + description="The ID of the secret containing the webhook secret.", + ) + rotate_secret: Optional[bool] = Field( + default=None, description="Set to rotate the webhook secret." + ) + + +# -------------------- Bitbucket Webhook Plugin ----------------------------------- + + +class BitbucketWebhookEventSourceHandler(BaseWebhookEventSourceHandler): + """Handler for all Bitbucket events.""" + + @property + def config_class(self) -> Type[BitbucketWebhookEventSourceConfiguration]: + """Returns the webhook event source configuration class. + + Returns: + The configuration. + """ + return BitbucketWebhookEventSourceConfiguration + + @property + def filter_class(self) -> Type[BitbucketWebhookEventFilterConfiguration]: + """Returns the webhook event filter configuration class. + + Returns: + The event filter configuration class. + """ + return BitbucketWebhookEventFilterConfiguration + + @property + def flavor_class(self) -> Type[BaseWebhookEventSourceFlavor]: + """Returns the flavor class of the plugin. + + Returns: + The flavor class of the plugin. + """ + from zenml.integrations.bitbucket.plugins.bitbucket_webhook_event_source_flavor import ( + BitbucketWebhookEventSourceFlavor, + ) + + return BitbucketWebhookEventSourceFlavor + + def _interpret_event(self, event: Dict[str, Any]) -> BitbucketEvent: + """Converts the generic event body into a event-source specific pydantic model. + + Args: + event: The generic event body + + Returns: + An instance of the event source specific pydantic model. + + Raises: + ValueError: If the event body can not be parsed into the pydantic model. + """ + try: + Bitbucket_event = BitbucketEvent(**event) + except ValueError: + raise ValueError("Event did not match the pydantic model.") + else: + return Bitbucket_event + + def _get_webhook_secret( + self, event_source: EventSourceResponse + ) -> Optional[str]: + """Get the webhook secret for the event source. + + Args: + event_source: The event source to retrieve the secret for. + + Returns: + The webhook secret associated with the event source, or None if a + secret is not applicable. + + Raises: + AuthorizationException: If the secret value could not be retrieved. + """ + # Temporary solution to get the secret value for the Event Source + config = self.validate_event_source_configuration( + event_source.configuration + ) + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + webhook_secret_id = config.webhook_secret_id + if webhook_secret_id is None: + raise AuthorizationException( + f"Webhook secret ID is missing from the event source " + f"configuration for event source '{event_source.id}'." 
+ ) + try: + return self.zen_store.get_secret( + secret_id=webhook_secret_id + ).secret_values["webhook_secret"] + except KeyError: + logger.exception( + f"Could not retrieve secret value for webhook secret id " + f"'{webhook_secret_id}'" + ) + raise AuthorizationException( + "Could not retrieve webhook signature." + ) + + def _validate_event_source_request( + self, event_source: EventSourceRequest, config: EventSourceConfig + ) -> None: + """Validate an event source request before it is created in the database. + + The `webhook_secret`, `webhook_secret_id`, and `rotate_secret` + fields are not allowed in the request. + + Args: + event_source: Event source request. + config: Event source configuration instantiated from the request. + + Raises: + ValueError: If any of the disallowed fields are present in the + request. + """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + for field in ["webhook_secret", "webhook_secret_id", "rotate_secret"]: + if getattr(config, field) is not None: + raise ValueError( + f"The `{field}` field is not allowed in the event source " + "request." + ) + + def _process_event_source_request( + self, event_source: EventSourceResponse, config: EventSourceConfig + ) -> None: + """Process an event source request after it is created in the database. + + Generates a webhook secret and stores it in a secret in the database, + then attaches the secret ID to the event source configuration. + + Args: + event_source: Newly created event source + config: Event source configuration instantiated from the response. + """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + assert ( + event_source.user is not None + ), "User is not set for event source" + + secret_key_value = random_str(12) + webhook_secret = SecretRequest( + name=f"event_source-{str(event_source.id)}-{random_str(4)}".lower(), + values={"webhook_secret": secret_key_value}, + workspace=event_source.workspace.id, + user=event_source.user.id, + scope=SecretScope.WORKSPACE, + ) + secret = self.zen_store.create_secret(webhook_secret) + + # Store the secret ID in the event source configuration in the database + event_source_update = EventSourceUpdate.from_response(event_source) + assert event_source_update.configuration is not None + event_source_update.configuration["webhook_secret_id"] = str(secret.id) + + self.zen_store.update_event_source( + event_source_id=event_source.id, + event_source_update=event_source_update, + ) + + # Set the webhook secret in the configuration returned to the user + config.webhook_secret = secret_key_value + # Remove hidden field from the response + config.rotate_secret = None + config.webhook_secret_id = None + + def _validate_event_source_update( + self, + event_source: EventSourceResponse, + config: EventSourceConfig, + event_source_update: EventSourceUpdate, + config_update: EventSourceConfig, + ) -> None: + """Validate an event source update before it is reflected in the database. + + Ensure the webhook secret ID is preserved in the updated event source + configuration. + + Args: + event_source: Original event source before the update. + config: Event source configuration instantiated from the original + event source. + event_source_update: Event source update request. + config_update: Event source configuration instantiated from the + updated event source. 
+ """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + assert isinstance( + config_update, BitbucketWebhookEventSourceConfiguration + ) + + config_update.webhook_secret_id = config.webhook_secret_id + + def _process_event_source_update( + self, + event_source: EventSourceResponse, + config: EventSourceConfig, + previous_event_source: EventSourceResponse, + previous_config: EventSourceConfig, + ) -> None: + """Process an event source after it is updated in the database. + + If the `rotate_secret` field is set to `True`, the webhook secret is + rotated and the new secret ID is attached to the event source + configuration. + + Args: + event_source: Event source after the update. + config: Event source configuration instantiated from the updated + event source. + previous_event_source: Original event source before the update. + previous_config: Event source configuration instantiated from the + original event source. + """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + assert isinstance( + previous_config, BitbucketWebhookEventSourceConfiguration + ) + assert config.webhook_secret_id is not None + + if config.rotate_secret: + # In case the secret is being rotated + secret_key_value = random_str(12) + webhook_secret = SecretUpdate( # type: ignore[call-arg] + values={"webhook_secret": secret_key_value} + ) + self.zen_store.update_secret( + secret_id=config.webhook_secret_id, + secret_update=webhook_secret, + ) + + # Remove the `rotate_secret` field from the configuration stored + # in the database + event_source_update = EventSourceUpdate.from_response(event_source) + assert event_source_update.configuration is not None + event_source_update.configuration.pop("rotate_secret") + self.zen_store.update_event_source( + event_source_id=event_source.id, + event_source_update=event_source_update, + ) + + # Set the new secret in the configuration returned to the user + config.webhook_secret = secret_key_value + + # Remove hidden fields from the response + config.rotate_secret = None + config.webhook_secret_id = None + + def _process_event_source_delete( + self, + event_source: EventSourceResponse, + config: EventSourceConfig, + force: Optional[bool] = False, + ) -> None: + """Process an event source before it is deleted from the database. + + Deletes the associated secret from the database. + + Args: + event_source: Event source before the deletion. + config: Validated instantiated event source configuration before + the deletion. + force: Whether to force deprovision the event source. + """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + if config.webhook_secret_id is not None: + try: + self.zen_store.delete_secret( + secret_id=config.webhook_secret_id + ) + except KeyError: + pass + + # Remove hidden fields from the response + config.rotate_secret = None + config.webhook_secret_id = None + + def _process_event_source_response( + self, event_source: EventSourceResponse, config: EventSourceConfig + ) -> None: + """Process an event source response before it is returned to the user. + + Removes hidden fields from the configuration. + + Args: + event_source: Event source response. + config: Event source configuration instantiated from the response. 
+ """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + # Remove hidden fields from the response + config.rotate_secret = None + config.webhook_secret_id = None + config.webhook_secret = None diff --git a/src/zenml/integrations/constants.py b/src/zenml/integrations/constants.py index cb800d9b251..0a486ae63be 100644 --- a/src/zenml/integrations/constants.py +++ b/src/zenml/integrations/constants.py @@ -18,6 +18,7 @@ AZURE = "azure" AZUREML = "azureml" BENTOML = "bentoml" +BITBUCKET = "bitbucket" DASH = "dash" DEEPCHECKS = "deepchecks" DISCORD = "discord" diff --git a/src/zenml/integrations/evidently/__init__.py b/src/zenml/integrations/evidently/__init__.py index 00e0e42b6d1..6912a9ef516 100644 --- a/src/zenml/integrations/evidently/__init__.py +++ b/src/zenml/integrations/evidently/__init__.py @@ -54,13 +54,7 @@ class EvidentlyIntegration(Integration): """[Evidently](https://github.com/evidentlyai/evidently) integration for ZenML.""" NAME = EVIDENTLY - REQUIREMENTS = [ - "evidently>0.2.6,<0.4.5", # supports pyyaml 6 - # We need to add this as an extra dependency to manually downgrade - # SQLModel. Otherwise, the initial installation of ZenML installs - # a higher version SQLModel and a version mismatch is created. - "sqlmodel>=0.0.9,<=0.0.16" - ] + REQUIREMENTS = ["evidently>0.2.6,<0.4.5"] # supports pyyaml 6 @classmethod def flavors(cls) -> List[Type[Flavor]]: diff --git a/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py b/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py index 568291bc626..5b321911ad5 100644 --- a/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py +++ b/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Example file of what an event Plugin could look like.""" +"""Github webhook event source flavor.""" from typing import ClassVar, Type diff --git a/src/zenml/integrations/great_expectations/__init__.py b/src/zenml/integrations/great_expectations/__init__.py index 4a4e630d9fa..500b197a93f 100644 --- a/src/zenml/integrations/great_expectations/__init__.py +++ b/src/zenml/integrations/great_expectations/__init__.py @@ -35,10 +35,6 @@ class GreatExpectationsIntegration(Integration): "great-expectations>=0.15.0,<=0.15.47", # typing_extensions 4.6.0 and above doesn't work with GE "typing_extensions<4.6.0", - # We need to add this as an extra dependency to manually downgrade - # SQLModel. Otherwise, the initial installation of ZenML installs - # a higher version SQLModel and a version mismatch is created. 
- "sqlmodel>=0.0.9,<=0.0.16", ] @staticmethod diff --git a/src/zenml/integrations/huggingface/__init__.py b/src/zenml/integrations/huggingface/__init__.py index c1a92f48e41..5f11ebc1cb3 100644 --- a/src/zenml/integrations/huggingface/__init__.py +++ b/src/zenml/integrations/huggingface/__init__.py @@ -30,6 +30,11 @@ class HuggingfaceIntegration(Integration): "transformers<=4.31", "datasets", "huggingface_hub>0.19.0", + # temporary fix for CI issue similar to: + # - https://github.com/huggingface/datasets/issues/6737 + # - https://github.com/huggingface/datasets/issues/6697 + # TODO try relaxing it back going forward + "fsspec<=2023.12.0", ] @classmethod diff --git a/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py b/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py index d9150fe9986..f9f98b65686 100644 --- a/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py +++ b/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py @@ -33,7 +33,6 @@ class HuggingFaceBaseConfig(BaseModel): """Hugging Face Inference Endpoint configuration.""" - endpoint_name: str = "zenml-" repository: Optional[str] = None framework: Optional[str] = None accelerator: Optional[str] = None @@ -41,15 +40,15 @@ class HuggingFaceBaseConfig(BaseModel): instance_type: Optional[str] = None region: Optional[str] = None vendor: Optional[str] = None - token: Optional[str] = None account_id: Optional[str] = None min_replica: int = 0 max_replica: int = 1 revision: Optional[str] = None task: Optional[str] = None custom_image: Optional[Dict[str, Any]] = None - namespace: Optional[str] = None endpoint_type: str = "public" + secret_name: Optional[str] = None + namespace: Optional[str] = None class HuggingFaceModelDeployerConfig( @@ -62,7 +61,7 @@ class HuggingFaceModelDeployerConfig( namespace: Hugging Face namespace used to list endpoints """ - token: str = SecretField() + token: Optional[str] = SecretField() # The namespace to list endpoints for. Set to `"*"` to list all endpoints # from all namespaces (i.e. personal namespace and all orgs the user belongs to). diff --git a/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py b/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py index 2ab93405864..eb551d5051a 100644 --- a/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py +++ b/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py @@ -13,12 +13,11 @@ # permissions and limitations under the License. 
"""Implementation of the Hugging Face Model Deployer.""" -from typing import Any, ClassVar, Dict, List, Optional, Type, cast +from typing import ClassVar, Dict, Optional, Tuple, Type, cast from uuid import UUID -from huggingface_hub import list_inference_endpoints - -from zenml.artifacts.utils import log_artifact_metadata, save_artifact +from zenml.analytics.enums import AnalyticsEvent +from zenml.analytics.utils import track_handler from zenml.client import Client from zenml.integrations.huggingface import HUGGINGFACE_SERVICE_ARTIFACT from zenml.integrations.huggingface.flavors.huggingface_model_deployer_flavor import ( @@ -35,13 +34,12 @@ DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, BaseModelDeployerFlavor, ) -from zenml.services import BaseService, ServiceConfig, ServiceRegistry +from zenml.services import BaseService, ServiceConfig +from zenml.stack.stack import Stack +from zenml.stack.stack_validator import StackValidator logger = get_logger(__name__) -ZENM_ENDPOINT_PREFIX: str = "zenml-" -UUID_SLICE_LENGTH: int = 8 - class HuggingFaceModelDeployer(BaseModelDeployer): """Hugging Face endpoint model deployer.""" @@ -61,45 +59,42 @@ def config(self) -> HuggingFaceModelDeployerConfig: return cast(HuggingFaceModelDeployerConfig, self._config) @property - def deployed_endpoints(self) -> Any: - """Get list of deployed endpoint from Hugging Face. + def validator(self) -> Optional[StackValidator]: + """Validates the stack. Returns: - List of deployed endpoints. + A validator that checks that the stack contains a remote artifact + store. """ - return list_inference_endpoints( - token=self.config.token, - namespace=self.config.namespace, - ) - - def modify_endpoint_name( - self, endpoint_name: str, artifact_version: str - ) -> str: - """Modify endpoint name by adding suffix and prefix. - - It adds a prefix "zenml-" if not present and a suffix - of first 8 characters of uuid. - Args: - endpoint_name : Name of the endpoint - artifact_version: Name of the artifact version - - Returns: - Modified endpoint name with added prefix and suffix - """ - # Add prefix if it does not start with ZENM_ENDPOINT_PREFIX - if not endpoint_name.startswith(ZENM_ENDPOINT_PREFIX): - endpoint_name = ZENM_ENDPOINT_PREFIX + endpoint_name + def _validate_if_secret_or_token_is_present( + stack: "Stack", + ) -> Tuple[bool, str]: + """Check if secret or token is present in the stack. + + Args: + stack: The stack to validate. + + Returns: + A tuple with a boolean indicating whether the stack is valid + and a message describing the validation result. + """ + return bool(self.config.token or self.config.secret_name), ( + "The Hugging Face model deployer requires either a secret name" + " or a token to be present in the stack." + ) - endpoint_name += artifact_version - return endpoint_name + return StackValidator( + custom_validation_function=_validate_if_secret_or_token_is_present, + ) def _create_new_service( - self, timeout: int, config: HuggingFaceServiceConfig + self, id: UUID, timeout: int, config: HuggingFaceServiceConfig ) -> HuggingFaceDeploymentService: """Creates a new Hugging FaceDeploymentService. Args: + id: the UUID of the model to be deployed with Hugging Face model deployer. timeout: the timeout in seconds to wait for the Hugging Face inference endpoint to be provisioned and successfully started or updated. config: the configuration of the model to be deployed with Hugging Face model deployer. @@ -109,36 +104,12 @@ def _create_new_service( with the Hugging Face inference endpoint. 
""" # create a new service for the new model - service = HuggingFaceDeploymentService(config) - - # Use first 8 characters of UUID as artifact version - artifact_version = str(service.dict()["uuid"])[:UUID_SLICE_LENGTH] - # Add same 8 characters as suffix to endpoint name - service.config.endpoint_name = self.modify_endpoint_name( - service.config.endpoint_name, artifact_version - ) + service = HuggingFaceDeploymentService(uuid=id, config=config) logger.info( f"Creating an artifact {HUGGINGFACE_SERVICE_ARTIFACT} with service instance attached as metadata." " If there's an active pipeline and/or model this artifact will be associated with it." ) - - save_artifact( - service, - HUGGINGFACE_SERVICE_ARTIFACT, - version=artifact_version, - is_deployment_artifact=True, - ) - - # Convert UUID object to be json serializable - service_metadata = service.dict() - service_metadata["uuid"] = str(service_metadata["uuid"]) - log_artifact_metadata( - artifact_name=HUGGINGFACE_SERVICE_ARTIFACT, - artifact_version=artifact_version, - metadata={HUGGINGFACE_SERVICE_ARTIFACT: service_metadata}, - ) - service.start(timeout=timeout) return service @@ -159,10 +130,10 @@ def _clean_up_existing_service( # stop the older service existing_service.stop(timeout=timeout, force=force) - def deploy_model( + def perform_deploy_model( self, + id: UUID, config: ServiceConfig, - replace: bool = True, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, ) -> BaseService: """Create a new Hugging Face deployment service or update an existing one. @@ -170,11 +141,8 @@ def deploy_model( This should serve the supplied model and deployment configuration. Args: + id: the UUID of the model to be deployed with Hugging Face. config: the configuration of the model to be deployed with Hugging Face. - Core - replace: set this flag to True to find and update an equivalent - Hugging Face deployment server with the new model instead of - starting a new deployment server. timeout: the timeout in seconds to wait for the Hugging Face endpoint to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the Hugging Face @@ -184,263 +152,82 @@ def deploy_model( The ZenML Hugging Face deployment service object that can be used to interact with the remote Hugging Face inference endpoint server. 
""" - config = cast(HuggingFaceServiceConfig, config) - service = None - - # if replace is True, remove all existing services - if replace: - existing_services = self.find_model_server( - pipeline_name=config.pipeline_name, - pipeline_step_name=config.pipeline_step_name, - ) - - for existing_service in existing_services: - if service is None: - # keep the most recently created service - service = cast( - HuggingFaceDeploymentService, existing_service - ) - try: - # delete the older services and don't wait for them to - # be deprovisioned - self._clean_up_existing_service( - existing_service=cast( - HuggingFaceDeploymentService, existing_service - ), - timeout=timeout, - force=True, - ) - except RuntimeError: - # ignore errors encountered while stopping old services - pass - - if service: - # update an equivalent service in place - logger.info( - f"Updating an existing Hugging Face deployment service: {service}" - ) - - service_metadata = service.dict() - artifact_version = str(service_metadata["uuid"])[ - :UUID_SLICE_LENGTH - ] - config.endpoint_name = self.modify_endpoint_name( - config.endpoint_name, artifact_version - ) - - service.stop(timeout=timeout, force=True) - service.update(config) - service.start(timeout=timeout) - else: + with track_handler(AnalyticsEvent.MODEL_DEPLOYED) as analytics_handler: + config = cast(HuggingFaceServiceConfig, config) # create a new HuggingFaceDeploymentService instance - service = self._create_new_service(timeout, config) + service = self._create_new_service( + id=id, timeout=timeout, config=config + ) logger.info( f"Creating a new Hugging Face inference endpoint service: {service}" ) + # Add telemetry with metadata that gets the stack metadata and + # differentiates between pure model and custom code deployments + stack = Client().active_stack + stack_metadata = { + component_type.value: component.flavor + for component_type, component in stack.components.items() + } + analytics_handler.metadata = { + "store_type": Client().zen_store.type.value, + **stack_metadata, + } - return cast(BaseService, service) - - def find_model_server( - self, - running: bool = False, - service_uuid: Optional[UUID] = None, - pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, - ) -> List[BaseService]: - """Find one or more Hugging Face model services that match the given criteria. - - Args: - running: if true, only running services will be returned. - service_uuid: the UUID of the Hugging Face service that was - originally used to create the Hugging Face deployment resource. - pipeline_name: name of the pipeline that the deployed model was part - of. - run_name: Name of the pipeline run which the deployed model was - part of. - pipeline_step_name: the name of the pipeline model deployment step - that deployed the model. - model_name: the name of the deployed model. - model_uri: URI of the deployed model. - model_type: the Hugging Face server implementation used to serve - the model - - Raises: - TypeError: If service type does not match HuggingFaceDeploymentService - - Returns: - One or more Hugging Face service objects representing Hugging Face - model servers that match the input search criteria. 
- """ - # Use a Hugging Face deployment service configuration to compute the labels - config = HuggingFaceServiceConfig( - pipeline_name=pipeline_name or "", - run_name=run_name or "", - pipeline_run_id=run_name or "", - pipeline_step_name=pipeline_step_name or "", - model_name=model_name or "", - model_uri=model_uri or "", - implementation=model_type or "", - ) - - services: List[BaseService] = [] - - # Find all services that match input criteria - for endpoint in self.deployed_endpoints: - if endpoint.name.startswith("zenml-"): - artifact_version = endpoint.name[-8:] - # If service_uuid is supplied, fetch service for that uuid - if ( - service_uuid is not None - and str(service_uuid)[:8] != artifact_version - ): - continue - - # Fetch the saved metadata artifact from zenml server to recreate service - client = Client() - try: - service_artifact = client.get_artifact_version( - HUGGINGFACE_SERVICE_ARTIFACT, artifact_version - ) - hf_deployment_service_dict = service_artifact.run_metadata[ - HUGGINGFACE_SERVICE_ARTIFACT - ].value - - existing_service = ( - ServiceRegistry().load_service_from_dict( - hf_deployment_service_dict # type: ignore - ) - ) - - if not isinstance( - existing_service, HuggingFaceDeploymentService - ): - raise TypeError( - f"Expected service type HuggingFaceDeploymentService but got " - f"{type(existing_service)} instead" - ) - - existing_service.update_status() - if self._matches_search_criteria(existing_service, config): - if not running or existing_service.is_running: - services.append( - cast(BaseService, existing_service) - ) - - # if endpoint is provisioned externally - # we do not have saved artifact for it. - except KeyError: - logger.error( - f"No key found for endpoint {endpoint.name} provisioned externally" - ) - - return services - - def _matches_search_criteria( - self, - existing_service: HuggingFaceDeploymentService, - config: HuggingFaceServiceConfig, - ) -> bool: - """Returns true if a service matches the input criteria. - - If any of the values in the input criteria are None, they are ignored. - This allows listing services just by common pipeline names or step - names, etc. - - Args: - existing_service: The materialized Service instance derived from - the config of the older (existing) service - config: The HuggingFaceServiceConfig object passed to the - deploy_model function holding parameters of the new service - to be created. - - Returns: - True if the service matches the input criteria. - """ - existing_service_config = existing_service.config - - # check if the existing service matches the input criteria - if ( - ( - not config.pipeline_name - or existing_service_config.pipeline_name - == config.pipeline_name - ) - and ( - not config.pipeline_step_name - or existing_service_config.pipeline_step_name - == config.pipeline_step_name - ) - and ( - not config.run_name - or existing_service_config.run_name == config.run_name - ) - ): - return True - - return False + return service - def stop_model_server( + def perform_stop_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, force: bool = False, - ) -> None: + ) -> BaseService: """Method to stop a model server. Args: - uuid: UUID of the model server to stop. + service: The service to stop. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. 
- """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, stop it - if existing_services: - existing_services[0].stop(timeout=timeout, force=force) + Returns: + The stopped service. + """ + service.stop(timeout=timeout, force=force) + return service - def start_model_server( - self, uuid: UUID, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT - ) -> None: + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: """Method to start a model server. Args: - uuid: UUID of the model server to start. + service: The service to start. timeout: Timeout in seconds to wait for the service to start. - """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, start it - if existing_services: - existing_services[0].start(timeout=timeout) + Returns: + The started service. + """ + service.start(timeout=timeout) + return service - def delete_model_server( + def perform_delete_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, force: bool = False, ) -> None: """Method to delete all configuration of a model server. Args: - uuid: UUID of the model server to delete. + service: The service to delete. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - - # if the service exists, clean it up - if existing_services: - service = cast(HuggingFaceDeploymentService, existing_services[0]) - self._clean_up_existing_service( - existing_service=service, timeout=timeout, force=force - ) + service = cast(HuggingFaceDeploymentService, service) + self._clean_up_existing_service( + existing_service=service, timeout=timeout, force=force + ) @staticmethod def get_model_server_info( # type: ignore[override] @@ -455,5 +242,6 @@ def get_model_server_info( # type: ignore[override] Model server information. """ return { - "PREDICTION_URL": service_instance.prediction_url, + "PREDICTION_URL": service_instance.get_prediction_url(), + "HEALTH_CHECK_URL": service_instance.get_healthcheck_url(), } diff --git a/src/zenml/integrations/huggingface/services/huggingface_deployment.py b/src/zenml/integrations/huggingface/services/huggingface_deployment.py index 26af08f7548..ed12e9954d1 100644 --- a/src/zenml/integrations/huggingface/services/huggingface_deployment.py +++ b/src/zenml/integrations/huggingface/services/huggingface_deployment.py @@ -26,6 +26,7 @@ from huggingface_hub.utils import HfHubHTTPError from pydantic import Field +from zenml.client import Client from zenml.integrations.huggingface.flavors.huggingface_model_deployer_flavor import ( HuggingFaceBaseConfig, ) @@ -36,16 +37,11 @@ logger = get_logger(__name__) POLLING_TIMEOUT = 1200 +UUID_SLICE_LENGTH: int = 8 class HuggingFaceServiceConfig(HuggingFaceBaseConfig, ServiceConfig): - """Hugging Face service configurations. - - Attributes: - model_name: the name of the model. - """ - - model_name: str = "default" + """Hugging Face service configurations.""" class HuggingFaceServiceStatus(ServiceStatus): @@ -81,6 +77,35 @@ def __init__(self, config: HuggingFaceServiceConfig, **attrs: Any): """ super().__init__(config=config, **attrs) + def get_token(self) -> str: + """Get the Hugging Face token. + + Raises: + ValueError: If token not found. 
+ + Returns: + Hugging Face token. + """ + client = Client() + token = None + if self.config.secret_name: + secret = client.get_secret(self.config.secret_name) + token = secret.secret_values["token"] + else: + from zenml.integrations.huggingface.model_deployers.huggingface_model_deployer import ( + HuggingFaceModelDeployer, + ) + + model_deployer = client.active_stack.model_deployer + if not isinstance(model_deployer, HuggingFaceModelDeployer): + raise ValueError( + "HuggingFaceModelDeployer is not active in the stack." + ) + token = model_deployer.config.token or None + if not token: + raise ValueError("Token not found.") + return token + @property def hf_endpoint(self) -> InferenceEndpoint: """Get the deployed Hugging Face inference endpoint. @@ -89,22 +114,20 @@ def hf_endpoint(self) -> InferenceEndpoint: Huggingface inference endpoint. """ return get_inference_endpoint( - name=self.config.endpoint_name, - token=self.config.token, + name=self._generate_an_endpoint_name(), + token=self.get_token(), namespace=self.config.namespace, ) @property - def prediction_url(self) -> Any: + def prediction_url(self) -> Optional[str]: """The prediction URI exposed by the prediction service. Returns: The prediction URI exposed by the prediction service, or None if the service is not yet ready. """ - if not self.is_running: - return None - return self.hf_endpoint.url + return self.hf_endpoint.url if self.is_running else None @property def inference_client(self) -> InferenceClient: @@ -123,8 +146,8 @@ def provision(self) -> None: """ try: # Attempt to create and wait for the inference endpoint - _ = create_inference_endpoint( - name=self.config.endpoint_name, + hf_endpoint = create_inference_endpoint( + name=self._generate_an_endpoint_name(), repository=self.config.repository, framework=self.config.framework, accelerator=self.config.accelerator, @@ -139,20 +162,10 @@ def provision(self) -> None: task=self.config.task, custom_image=self.config.custom_image, type=self.config.endpoint_type, + token=self.get_token(), namespace=self.config.namespace, - token=self.config.token, ).wait(timeout=POLLING_TIMEOUT) - # Check if the endpoint URL is available after provisioning - if self.hf_endpoint.url is not None: - logger.info( - "Hugging Face inference endpoint successfully deployed." - ) - else: - logger.error( - "Failed to start Hugging Face inference endpoint service: No URL available." - ) - except Exception as e: self.status.update_state( new_state=ServiceState.ERROR, error=str(e) @@ -162,6 +175,16 @@ def provision(self) -> None: f"An unexpected error occurred while provisioning the Hugging Face inference endpoint: {e}" ) + # Check if the endpoint URL is available after provisioning + if hf_endpoint.url: + logger.info( + f"Hugging Face inference endpoint successfully deployed and available. Endpoint URL: {hf_endpoint.url}" + ) + else: + logger.error( + "Failed to start Hugging Face inference endpoint service: No URL available, please check the Hugging Face console for more details." + ) + def check_status(self) -> Tuple[ServiceState, str]: """Check the the current operational state of the Hugging Face deployment. @@ -170,39 +193,30 @@ def check_status(self) -> Tuple[ServiceState, str]: providing additional information about that state (e.g. a description of the error, if one is encountered). 
""" - # TODO: Support all different InferenceEndpointStatus try: - _ = self.hf_endpoint.status - except (InferenceEndpointError, HfHubHTTPError): - return (ServiceState.INACTIVE, "") - - if self.hf_endpoint.status == InferenceEndpointStatus.RUNNING: - return ( - ServiceState.ACTIVE, - "Hugging Face Inference Endpoint deployment is available", - ) - - elif self.hf_endpoint.status == InferenceEndpointStatus.SCALED_TO_ZERO: - return ( - ServiceState.ACTIVE, - "Hugging Face Inference Endpoint deployment is scaled to zero", - ) - - elif self.hf_endpoint.status == InferenceEndpointStatus.FAILED: - return ( - ServiceState.ERROR, - "Hugging Face Inference Endpoint deployment failed: ", - ) + status = self.hf_endpoint.status + if status == InferenceEndpointStatus.RUNNING: + return (ServiceState.ACTIVE, "") + + elif status == InferenceEndpointStatus.SCALED_TO_ZERO: + return ( + ServiceState.SCALED_TO_ZERO, + "Hugging Face Inference Endpoint is scaled to zero, but still running. It will be started on demand.", + ) - elif self.hf_endpoint.status == InferenceEndpointStatus.PENDING: + elif status == InferenceEndpointStatus.FAILED: + return ( + ServiceState.ERROR, + "Hugging Face Inference Endpoint deployment is inactive or not found", + ) + elif status == InferenceEndpointStatus.PENDING: + return (ServiceState.PENDING_STARTUP, "") + return (ServiceState.PENDING_STARTUP, "") + except (InferenceEndpointError, HfHubHTTPError): return ( - ServiceState.PENDING_STARTUP, - "Hugging Face Inference Endpoint deployment is being created: ", + ServiceState.INACTIVE, + "Hugging Face Inference Endpoint deployment is inactive or not found", ) - return ( - ServiceState.PENDING_STARTUP, - "Hugging Face Inference Endpoint deployment is being created: ", - ) def deprovision(self, force: bool = False) -> None: """Deprovision the remote Hugging Face deployment instance. @@ -217,7 +231,6 @@ def deprovision(self, force: bool = False) -> None: logger.error( "Hugging Face Inference Endpoint is deleted or cannot be found." ) - pass def predict(self, data: "Any", max_new_tokens: int) -> "Any": """Make a prediction using the service. @@ -238,7 +251,7 @@ def predict(self, data: "Any", max_new_tokens: int) -> "Any": "Hugging Face endpoint inference service is not running. " "Please start the service before making predictions." ) - if self.hf_endpoint.prediction_url is not None: + if self.prediction_url is not None: if self.hf_endpoint.task == "text-generation": result = self.inference_client.task_generation( data, max_new_tokens=max_new_tokens @@ -267,3 +280,13 @@ def get_logs( "your Endpoints through the UI in the “Logs” tab of your Endpoint" ) return # type: ignore + + def _generate_an_endpoint_name(self) -> str: + """Generate a unique name for the Hugging Face Inference Endpoint. + + Returns: + A unique name for the Hugging Face Inference Endpoint. 
+ """ + return ( + f"{self.config.service_name}-{str(self.uuid)[:UUID_SLICE_LENGTH]}" + ) diff --git a/src/zenml/integrations/huggingface/steps/huggingface_deployer.py b/src/zenml/integrations/huggingface/steps/huggingface_deployer.py index fd123e88341..5303d89bda7 100644 --- a/src/zenml/integrations/huggingface/steps/huggingface_deployer.py +++ b/src/zenml/integrations/huggingface/steps/huggingface_deployer.py @@ -58,21 +58,17 @@ def huggingface_model_deployer_step( # get pipeline name, step name and run id context = get_step_context() pipeline_name = context.pipeline.name - run_name = context.pipeline_run.name step_name = context.step_run.name # update the step configuration with the real pipeline runtime information service_config = service_config.copy() service_config.pipeline_name = pipeline_name - service_config.run_name = run_name service_config.pipeline_step_name = step_name # fetch existing services with same pipeline name, step name and # model name existing_services = model_deployer.find_model_server( - pipeline_name=pipeline_name, - pipeline_step_name=step_name, - model_name=service_config.model_name, + config=service_config.dict() ) # even when the deploy decision is negative, if an existing model server @@ -99,7 +95,10 @@ def huggingface_model_deployer_step( service = cast( HuggingFaceDeploymentService, model_deployer.deploy_model( - service_config, replace=True, timeout=timeout + service_config, + replace=True, + timeout=timeout, + service_type=HuggingFaceDeploymentService.SERVICE_TYPE, ), ) diff --git a/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py b/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py index 0f90eec205b..18b62a0bacc 100644 --- a/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py +++ b/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py @@ -17,7 +17,7 @@ import re import tempfile from shlex import quote -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast +from typing import IO, TYPE_CHECKING, Any, Dict, Optional, Type, cast import paramiko import yaml @@ -129,6 +129,36 @@ def _escape_shell_command(self, command: str) -> str: """ return quote(command) + def _scp_to_hyperai_instance( + self, + paramiko_client: paramiko.SSHClient, + f: IO[str], + directory_name: str, + file_name: str, + description: str, + ) -> None: + """Copies a file to a HyperAI instance using SCP. + + Args: + paramiko_client: The SSH client to use for the SCP transfer. + f: The file to transfer. + directory_name: The directory on the HyperAI instance to transfer + the file to. + file_name: The name of the file being transferred. + description: A description of the file being transferred. + + Raises: + RuntimeError: If the file cannot be written to the HyperAI instance. + """ + try: + scp_client = paramiko_client.open_sftp() + scp_client.put(f.name, f"{directory_name}/{file_name}") + scp_client.close() + except FileNotFoundError: + raise RuntimeError( + f"Failed to write {description} to HyperAI instance. Does the user have permissions to write?" 
+ ) + def prepare_or_run_pipeline( self, deployment: "PipelineDeploymentResponse", @@ -230,17 +260,25 @@ def prepare_or_run_pipeline( # Add dependency on upstream steps if applicable upstream_steps = step.spec.upstream_steps - for upstream_step_name in upstream_steps: - upstream_container_name = ( - f"{deployment_id}-{upstream_step_name}" - ) + + if len(upstream_steps) > 0: compose_definition["services"][container_name][ "depends_on" - ] = { - upstream_container_name: { - "condition": "service_completed_successfully" - } - } + ] = {} + + for upstream_step_name in upstream_steps: + upstream_container_name = ( + f"{deployment_id}-{upstream_step_name}" + ) + compose_definition["services"][container_name][ + "depends_on" + ].update( + { + upstream_container_name: { + "condition": "service_completed_successfully" + } + } + ) # Convert into yaml logger.info("Finalizing Docker Compose definition.") @@ -373,14 +411,33 @@ def prepare_or_run_pipeline( f_.write(compose_definition_yaml) # Scp Docker Compose file to HyperAI instance - try: - scp_client = paramiko_client.open_sftp() - scp_client.put(f.name, f"{directory_name}/docker-compose.yaml") - scp_client.close() - except FileNotFoundError: - raise RuntimeError( - "Failed to write Docker Compose file to HyperAI instance. Does the user have permissions to write?" - ) + self._scp_to_hyperai_instance( + paramiko_client, + f, + directory_name, + file_name="docker-compose.yml", + description="Docker Compose file", + ) + + # Create temporary file and write script to it + with tempfile.NamedTemporaryFile(mode="w", delete=True) as f: + # Define bash line and command line + bash_line = "#!/bin/bash\n" + command_line = f'cd {directory_name} && echo {ENV_ZENML_HYPERAI_RUN_ID}="{deployment_id}_$(date +\%s)" > .env && docker compose up -d' + + # Write script to temporary file + with f.file as f_: + f_.write(bash_line) + f_.write(command_line) + + # Scp script to HyperAI instance + self._scp_to_hyperai_instance( + paramiko_client, + f, + directory_name, + file_name="run_pipeline.sh", + description="startup script", + ) # Run or schedule Docker Compose file depending on settings if not deployment.schedule: @@ -413,7 +470,7 @@ def prepare_or_run_pipeline( # Create cron job for scheduled pipeline on HyperAI instance stdin, stdout, stderr = paramiko_client.exec_command( # nosec - f"(crontab -l ; echo '{cron_expression} cd {directory_name} && echo {ENV_ZENML_HYPERAI_RUN_ID}=\"{deployment_id}_$(date +\%s)\" > .env && docker compose up -d') | crontab -" + f"(crontab -l ; echo '{cron_expression} bash {directory_name}/run_pipeline.sh') | crontab -" ) logger.info("Pipeline scheduled successfully.") diff --git a/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py b/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py index 7bc785dff46..09e13cad35e 100644 --- a/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py +++ b/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py @@ -16,7 +16,7 @@ import json from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union -from pydantic import validator +from pydantic import PositiveInt, validator from zenml.image_builders import BaseImageBuilderConfig, BaseImageBuilderFlavor from zenml.integrations.kaniko import KANIKO_IMAGE_BUILDER_FLAVOR @@ -29,6 +29,7 @@ DEFAULT_KANIKO_EXECUTOR_IMAGE = ( f"gcr.io/kaniko-project/executor:{KANIKO_EXECUTOR_IMAGE_TAG}" ) +DEFAULT_KANIKO_POD_RUNNING_TIMEOUT = 300 class KanikoImageBuilderConfig(BaseImageBuilderConfig): @@ -47,6 
+48,8 @@ class KanikoImageBuilderConfig(BaseImageBuilderConfig): Kaniko pod. This namespace will not be created and must already exist. executor_image: The image of the Kaniko executor to use. + pod_running_timeout: The timeout to wait until the pod is running + in seconds. Defaults to `300`. env: `env` section of the Kubernetes container spec. env_from: `envFrom` section of the Kubernetes container spec. volume_mounts: `volumeMounts` section of the Kubernetes container spec. @@ -67,6 +70,7 @@ class KanikoImageBuilderConfig(BaseImageBuilderConfig): kubernetes_context: str kubernetes_namespace: str = "zenml-kaniko" executor_image: str = DEFAULT_KANIKO_EXECUTOR_IMAGE + pod_running_timeout: PositiveInt = DEFAULT_KANIKO_POD_RUNNING_TIMEOUT env: List[Dict[str, Any]] = [] env_from: List[Dict[str, Any]] = [] diff --git a/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py b/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py index 314e31f0657..ebb3f09fefa 100644 --- a/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py +++ b/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py @@ -257,6 +257,8 @@ def _run_kaniko_build( self.config.executor_image, "--overrides", json.dumps(spec_overrides), + "--pod-running-timeout", + f"{self.config.pod_running_timeout}s", ] logger.debug("Running Kaniko build with command: %s", command) with subprocess.Popen( diff --git a/src/zenml/integrations/mlflow/__init__.py b/src/zenml/integrations/mlflow/__init__.py index c7cc82e790d..3cd8a0d146f 100644 --- a/src/zenml/integrations/mlflow/__init__.py +++ b/src/zenml/integrations/mlflow/__init__.py @@ -35,7 +35,7 @@ class MlflowIntegration(Integration): # does not pin it. They fixed this in a later version, so we can probably # remove this once we update the mlflow version. 
REQUIREMENTS = [ - "mlflow>=2.1.1,<=2.10.2", + "mlflow>=2.1.1,<=2.11.3", "mlserver>=1.3.3", "mlserver-mlflow>=1.3.3", # TODO: remove this requirement once rapidjson is fixed diff --git a/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py b/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py index d163ae09a5e..1c0de9b43d2 100644 --- a/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py +++ b/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py @@ -15,8 +15,7 @@ import os import shutil -from pathlib import Path -from typing import ClassVar, Dict, List, Optional, Type, cast +from typing import ClassVar, Dict, Optional, Type, cast from uuid import UUID from zenml.config.global_config import GlobalConfiguration @@ -31,8 +30,6 @@ ) from zenml.logger import get_logger from zenml.model_deployers import BaseModelDeployer, BaseModelDeployerFlavor -from zenml.services import ServiceRegistry -from zenml.services.local.local_service import SERVICE_DAEMON_CONFIG_FILE_NAME from zenml.services.service import BaseService, ServiceConfig from zenml.utils.io_utils import create_dir_recursive_if_not_exists @@ -120,12 +117,15 @@ def get_model_server_info( # type: ignore[override] "REGISTRY_MODEL_VERSION": service_instance.config.registry_model_version, "SERVICE_PATH": service_instance.status.runtime_path, "DAEMON_PID": str(service_instance.status.pid), + "HEALTH_CHECK_URL": service_instance.endpoint.monitor.get_healthcheck_uri( + service_instance.endpoint + ), } - def deploy_model( + def perform_deploy_model( self, + id: UUID, config: ServiceConfig, - replace: bool = False, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, ) -> BaseService: """Create a new MLflow deployment service or update an existing one. @@ -157,10 +157,8 @@ def deploy_model( and the others are deleted. Args: + id: the ID of the MLflow deployment service to be created or updated. config: the configuration of the model to be deployed with MLflow. - replace: set this flag to True to find and update an equivalent - MLflow deployment server with the new model instead of - creating and starting a new deployment server. timeout: the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow @@ -171,49 +169,11 @@ def deploy_model( interact with the MLflow model server. 
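Both the Hugging Face and MLflow deployers now report a `HEALTH_CHECK_URL` next to `PREDICTION_URL` in `get_model_server_info`. A hedged sketch of how a caller might poll that key before sending traffic; the dictionary values below are illustrative placeholders, not output captured from a real server.

```python
import time
import urllib.request


def wait_until_healthy(server_info: dict, retries: int = 5, delay: float = 2.0) -> bool:
    """Poll the HEALTH_CHECK_URL from get_model_server_info() until it answers 200."""
    url = server_info.get("HEALTH_CHECK_URL")
    if not url:
        return False
    for _ in range(retries):
        try:
            with urllib.request.urlopen(url, timeout=2) as response:
                if response.status == 200:
                    return True
        except OSError:
            # Connection refused / timeout: the server is not ready yet.
            pass
        time.sleep(delay)
    return False


# Illustrative values only; real URLs come from the active model deployer.
info = {
    "PREDICTION_URL": "http://localhost:8000/invocations",
    "HEALTH_CHECK_URL": "http://localhost:8000/health",
}
print(wait_until_healthy(info, retries=1, delay=0.1))
```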
""" config = cast(MLFlowDeploymentConfig, config) - service = None - - # if replace is True, remove all existing services - if replace is True: - existing_services = self.find_model_server( - pipeline_name=config.pipeline_name, - pipeline_step_name=config.pipeline_step_name, - model_name=config.model_name, - ) - - for existing_service in existing_services: - if service is None: - # keep the most recently created service - service = cast(MLFlowDeploymentService, existing_service) - try: - # delete the older services and don't wait for them to - # be deprovisioned - self._clean_up_existing_service( - existing_service=cast( - MLFlowDeploymentService, existing_service - ), - timeout=timeout, - force=True, - ) - except RuntimeError: - # ignore errors encountered while stopping old services - pass - if service: - logger.info( - f"Updating an existing MLflow deployment service: {service}" - ) - - # set the root runtime path with the stack component's UUID - config.root_runtime_path = self.local_path - service.stop(timeout=timeout, force=True) - service.update(config) - service.start(timeout=timeout) - else: - # create a new MLFlowDeploymentService instance - service = self._create_new_service(timeout, config) - logger.info(f"Created a new MLflow deployment service: {service}") - - return cast(BaseService, service) + service = self._create_new_service( + id=id, timeout=timeout, config=config + ) + logger.info(f"Created a new MLflow deployment service: {service}") + return service def _clean_up_existing_service( self, @@ -232,11 +192,12 @@ def _clean_up_existing_service( # of workers etc.the step implementation will create a new config using # all values from the user and add values like pipeline name, model_uri def _create_new_service( - self, timeout: int, config: MLFlowDeploymentConfig + self, id: UUID, timeout: int, config: MLFlowDeploymentConfig ) -> MLFlowDeploymentService: """Creates a new MLFlowDeploymentService. Args: + id: the ID of the MLflow deployment service to be created or updated. timeout: the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. config: the configuration of the model to be deployed with MLflow. @@ -248,213 +209,61 @@ def _create_new_service( # set the root runtime path with the stack component's UUID config.root_runtime_path = self.local_path # create a new service for the new model - service = MLFlowDeploymentService(config) + service = MLFlowDeploymentService(uuid=id, config=config) service.start(timeout=timeout) return service - def find_model_server( + def perform_stop_model( self, - running: bool = False, - service_uuid: Optional[UUID] = None, - pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, - registry_model_name: Optional[str] = None, - registry_model_version: Optional[str] = None, - ) -> List[BaseService]: - """Finds one or more model servers that match the given criteria. - - Args: - running: If true, only running services will be returned. - service_uuid: The UUID of the service that was originally used - to deploy the model. - pipeline_name: Name of the pipeline that the deployed model was part - of. - run_name: Name of the pipeline run which the deployed model - was part of. - pipeline_step_name: The name of the pipeline model deployment step - that deployed the model. - model_name: Name of the deployed model. 
- model_uri: URI of the deployed model. - model_type: Type/format of the deployed model. Not used in this - MLflow case. - registry_model_name: Name of the registered model that the - deployed model belongs to. - registry_model_version: Version of the registered model that - the deployed model belongs to. - - Returns: - One or more Service objects representing model servers that match - the input search criteria. - - Raises: - TypeError: if any of the input arguments are of an invalid type. - """ - services = [] - config = MLFlowDeploymentConfig( - model_name=model_name or "", - model_uri=model_uri or "", - pipeline_name=pipeline_name or "", - pipeline_run_id=run_name or "", - run_name=run_name or "", - pipeline_step_name=pipeline_step_name or "", - registry_model_name=registry_model_name, - registry_model_version=registry_model_version, - ) - - # find all services that match the input criteria - for root, _, files in os.walk(self.local_path): - if service_uuid and Path(root).name != str(service_uuid): - continue - for file in files: - if file == SERVICE_DAEMON_CONFIG_FILE_NAME: - service_config_path = os.path.join(root, file) - logger.debug( - "Loading service daemon configuration from %s", - service_config_path, - ) - existing_service_config = None - with open(service_config_path, "r") as f: - existing_service_config = f.read() - existing_service = ( - ServiceRegistry().load_service_from_json( - existing_service_config - ) - ) - if not isinstance( - existing_service, MLFlowDeploymentService - ): - raise TypeError( - f"Expected service type MLFlowDeploymentService but got " - f"{type(existing_service)} instead" - ) - existing_service.update_status() - if self._matches_search_criteria(existing_service, config): - if not running or existing_service.is_running: - services.append( - cast(BaseService, existing_service) - ) - - return services - - def _matches_search_criteria( - self, - existing_service: MLFlowDeploymentService, - config: MLFlowDeploymentConfig, - ) -> bool: - """Returns true if a service matches the input criteria. - - If any of the values in the input criteria are None, they are ignored. - This allows listing services just by common pipeline names or step - names, etc. - - Args: - existing_service: The materialized Service instance derived from - the config of the older (existing) service - config: The MLFlowDeploymentConfig object passed to the - deploy_model function holding parameters of the new service - to be created. - - Returns: - True if the service matches the input criteria. 
- """ - existing_service_config = existing_service.config - # check if the existing service matches the input criteria - if ( - ( - not config.pipeline_name - or existing_service_config.pipeline_name - == config.pipeline_name - ) - and ( - not config.model_name - or existing_service_config.model_name == config.model_name - ) - and ( - not config.pipeline_step_name - or existing_service_config.pipeline_step_name - == config.pipeline_step_name - ) - and ( - not config.run_name - or existing_service_config.run_name == config.run_name - ) - and ( - ( - not config.registry_model_name - and not config.registry_model_version - ) - or ( - existing_service_config.registry_model_name - == config.registry_model_name - and existing_service_config.registry_model_version - == config.registry_model_version - ) - ) - ): - return True - - return False - - def stop_model_server( - self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False, - ) -> None: + ) -> BaseService: """Method to stop a model server. Args: - uuid: UUID of the model server to stop. + service: The service to stop. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. - """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, stop it - if existing_services: - existing_services[0].stop(timeout=timeout, force=force) + Returns: + The service that was stopped. + """ + service.stop(timeout=timeout, force=force) + return service - def start_model_server( - self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT - ) -> None: + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, + ) -> BaseService: """Method to start a model server. Args: - uuid: UUID of the model server to start. + service: The service to start. timeout: Timeout in seconds to wait for the service to start. - """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, start it - if existing_services: - existing_services[0].start(timeout=timeout) + Returns: + The service that was started. + """ + service.start(timeout=timeout) + return service - def delete_model_server( + def perform_delete_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False, ) -> None: """Method to delete all configuration of a model server. Args: - uuid: UUID of the model server to delete. + service: The service to delete. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. 
""" - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - - # if the service exists, clean it up - if existing_services: - service = cast(MLFlowDeploymentService, existing_services[0]) - self._clean_up_existing_service( - existing_service=service, timeout=timeout, force=force - ) + service = cast(MLFlowDeploymentService, service) + self._clean_up_existing_service( + existing_service=service, timeout=timeout, force=force + ) diff --git a/src/zenml/integrations/mlflow/services/mlflow_deployment.py b/src/zenml/integrations/mlflow/services/mlflow_deployment.py index 114f7e66a16..2cdccdbbf09 100644 --- a/src/zenml/integrations/mlflow/services/mlflow_deployment.py +++ b/src/zenml/integrations/mlflow/services/mlflow_deployment.py @@ -101,8 +101,6 @@ class MLFlowDeploymentConfig(LocalDaemonServiceConfig): timeout: timeout in seconds for starting and stopping the service """ - # TODO: ServiceConfig should have additional fields such as "pipeline_run_uuid" - # and "pipeline_uuid" to allow for better tracking of the service. model_uri: str model_name: str registry_model_name: Optional[str] = None @@ -128,6 +126,7 @@ class MLFlowDeploymentService(LocalDaemonService, BaseDeploymentService): type="model-serving", flavor="mlflow", description="MLflow prediction service", + logo_url="https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png", ) config: MLFlowDeploymentConfig diff --git a/src/zenml/integrations/mlflow/steps/mlflow_deployer.py b/src/zenml/integrations/mlflow/steps/mlflow_deployer.py index a93a42d75e9..6d0fde1a9df 100644 --- a/src/zenml/integrations/mlflow/steps/mlflow_deployer.py +++ b/src/zenml/integrations/mlflow/steps/mlflow_deployer.py @@ -118,32 +118,30 @@ def mlflow_model_deployer_step( run_id=mlflow_run_id, artifact_path=model_name ) - # Fetch existing services with same pipeline name, step name and model name - existing_services = model_deployer.find_model_server( + predictor_cfg = MLFlowDeploymentConfig( + model_name=model_name or "", + model_uri=model_uri, + workers=workers, + mlserver=mlserver, pipeline_name=pipeline_name, pipeline_step_name=step_name, - model_name=model_name, + timeout=timeout, + ) + + # Fetch existing services with same pipeline name, step name and model name + existing_services = model_deployer.find_model_server( + config=predictor_cfg.dict(), ) # Check whether to deploy a new service if model_uri and deploy_decision: - predictor_cfg = MLFlowDeploymentConfig( - model_name=model_name or "", - model_uri=model_uri, - workers=workers, - mlserver=mlserver, - pipeline_name=pipeline_name, - run_name=run_name, - pipeline_run_id=run_name, - pipeline_step_name=step_name, - timeout=timeout, - ) new_service = cast( MLFlowDeploymentService, model_deployer.deploy_model( replace=True, config=predictor_cfg, timeout=timeout, + service_type=MLFlowDeploymentService.SERVICE_TYPE, ), ) logger.info( @@ -277,26 +275,25 @@ def mlflow_model_registry_deployer_step( f"using this step." 
) # fetch existing services with same pipeline name, step name and model name + existing_services = ( model_deployer.find_model_server( - registry_model_name=model_version.registered_model.name, + model_name=registry_model_name, + model_version=model_version.version, ) if replace_existing else None ) - # create a config for the new model service metadata = model_version.metadata or ModelRegistryModelMetadata() predictor_cfg = MLFlowDeploymentConfig( - model_name=model_name or "", + name=model_name or None, + model_name=registry_model_name, + model_version=model_version.version, model_uri=model_version.model_source_uri, - registry_model_name=model_version.registered_model.name, - registry_model_version=model_version.version, - registry_model_stage=model_version.stage.value, workers=workers, mlserver=mlserver, pipeline_name=metadata.zenml_pipeline_name or "", - run_name=metadata.zenml_run_name or "", pipeline_step_name=metadata.zenml_step_name or "", timeout=timeout, ) @@ -308,6 +305,7 @@ def mlflow_model_registry_deployer_step( replace=True, config=predictor_cfg, timeout=timeout, + service_type=MLFlowDeploymentService.SERVICE_TYPE, ), ) diff --git a/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py b/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py index 9529ccbcc85..8ae9282bee5 100644 --- a/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py +++ b/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py @@ -15,7 +15,6 @@ import json import re -from datetime import datetime from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Type, cast from uuid import UUID @@ -479,10 +478,10 @@ def _delete_kubernetes_secret(self, secret_name: str) -> None: return self.seldon_client.delete_secret(secret_name) - def deploy_model( + def perform_deploy_model( self, + id: UUID, config: ServiceConfig, - replace: bool = False, timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT, ) -> BaseService: """Create a new Seldon Core deployment or update an existing one. @@ -517,11 +516,9 @@ def deploy_model( to be updated and the others are deleted. Args: + id: the UUID of the model server to deploy. config: the configuration of the model to be deployed with Seldon. Core - replace: set this flag to True to find and update an equivalent - Seldon Core deployment server with the new model instead of - starting a new deployment server. timeout: the timeout in seconds to wait for the Seldon Core server to be provisioned and successfully started or updated. 
If set to 0, the method will return immediately after the Seldon Core @@ -541,31 +538,6 @@ def deploy_model( """ with track_handler(AnalyticsEvent.MODEL_DEPLOYED) as analytics_handler: config = cast(SeldonDeploymentConfig, config) - service = None - - # if replace is True, find equivalent Seldon Core deployments - if replace is True: - equivalent_services = self.find_model_server( - running=False, - pipeline_name=config.pipeline_name, - pipeline_step_name=config.pipeline_step_name, - model_name=config.model_name, - ) - - for equivalent_service in equivalent_services: - if service is None: - # keep the most recently created service - service = equivalent_service - else: - try: - # delete the older services and don't wait for - # them to be deprovisioned - service.stop() - except RuntimeError: - # ignore errors encountered while stopping old - # services - pass - # if a custom Kubernetes secret is not explicitly specified in the # SeldonDeploymentConfig, try to create one from the ZenML secret # configured for the model deployer @@ -573,19 +545,9 @@ def deploy_model( config.secret_name or self._create_or_update_kubernetes_secret() ) - - if service: - # update an equivalent service in place - service.update(config) - logger.info( - f"Updating an existing Seldon deployment service: {service}" - ) - else: - # create a new service - service = SeldonDeploymentService(config=config) - logger.info( - f"Creating a new Seldon deployment service: {service}" - ) + # create a new service + service = SeldonDeploymentService(uuid=id, config=config) + logger.info(f"Creating a new Seldon deployment service: {service}") # start the service which in turn provisions the Seldon Core # deployment server and waits for it to reach a ready state @@ -606,95 +568,16 @@ def deploy_model( return service - def find_model_server( - self, - running: bool = False, - service_uuid: Optional[UUID] = None, - pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, - ) -> List[BaseService]: - """Find one or more Seldon Core model services that match the given criteria. - - The Seldon Core deployment services that meet the search criteria are - returned sorted in descending order of their creation time (i.e. more - recent deployments first). - - Args: - running: if true, only running services will be returned. - service_uuid: the UUID of the Seldon Core service that was - originally used to create the Seldon Core deployment resource. - pipeline_name: name of the pipeline that the deployed model was part - of. - run_name: Name of the pipeline run which the deployed model was - part of. - pipeline_step_name: the name of the pipeline model deployment step - that deployed the model. - model_name: the name of the deployed model. - model_uri: URI of the deployed model. - model_type: the Seldon Core server implementation used to serve - the model - - Returns: - One or more Seldon Core service objects representing Seldon Core - model servers that match the input search criteria. 
- """ - # Use a Seldon deployment service configuration to compute the labels - config = SeldonDeploymentConfig( - pipeline_name=pipeline_name or "", - run_name=run_name or "", - pipeline_run_id=run_name or "", - pipeline_step_name=pipeline_step_name or "", - model_name=model_name or "", - model_uri=model_uri or "", - implementation=model_type or "", - ) - labels = config.get_seldon_deployment_labels() - if service_uuid: - # the service UUID is not a label covered by the Seldon - # deployment service configuration, so we need to add it - # separately - labels["zenml.service_uuid"] = str(service_uuid) - - deployments = self.seldon_client.find_deployments(labels=labels) - # sort the deployments in descending order of their creation time - deployments.sort( - key=lambda deployment: datetime.strptime( - deployment.metadata.creationTimestamp, - "%Y-%m-%dT%H:%M:%SZ", - ) - if deployment.metadata.creationTimestamp - else datetime.min, - reverse=True, - ) - - services: List[BaseService] = [] - for deployment in deployments: - # recreate the Seldon deployment service object from the Seldon - # deployment resource - service = SeldonDeploymentService.create_from_deployment( - deployment=deployment - ) - if running and not service.is_running: - # skip non-running services - continue - services.append(service) - - return services - - def stop_model_server( + def perform_stop_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT, force: bool = False, - ) -> None: + ) -> BaseService: """Stop a Seldon Core model server. Args: - uuid: UUID of the model server to stop. + service: The service to stop. timeout: timeout in seconds to wait for the service to stop. force: if True, force the service to stop. @@ -707,15 +590,15 @@ def stop_model_server( "deleting the Seldon Core model server instead." ) - def start_model_server( + def perform_start_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT, - ) -> None: + ) -> BaseService: """Start a Seldon Core model deployment server. Args: - uuid: UUID of the model server to start. + service: The service to start. timeout: timeout in seconds to wait for the service to become active. . If set to 0, the method will return immediately after provisioning the service, without waiting for it to become @@ -729,28 +612,22 @@ def start_model_server( "Starting Seldon Core model servers is not implemented" ) - def delete_model_server( + def perform_delete_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT, force: bool = False, ) -> None: """Delete a Seldon Core model deployment server. Args: - uuid: UUID of the model server to delete. + service: The service to delete. timeout: timeout in seconds to wait for the service to stop. If set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop. force: if True, force the service to stop. 
""" - services = self.find_model_server(service_uuid=uuid) - if len(services) == 0: - return - - service = services[0] - - assert isinstance(service, SeldonDeploymentService) + service = cast(SeldonDeploymentService, service) service.stop(timeout=timeout, force=force) if service.config.secret_name: diff --git a/src/zenml/integrations/seldon/services/seldon_deployment.py b/src/zenml/integrations/seldon/services/seldon_deployment.py index 28c6a1c1822..5d3c56a1b04 100644 --- a/src/zenml/integrations/seldon/services/seldon_deployment.py +++ b/src/zenml/integrations/seldon/services/seldon_deployment.py @@ -86,8 +86,6 @@ def get_seldon_deployment_labels(self) -> Dict[str, str]: labels = {} if self.pipeline_name: labels["zenml.pipeline_name"] = self.pipeline_name - if self.run_name: - labels["zenml.run_name"] = self.run_name if self.pipeline_step_name: labels["zenml.pipeline_step_name"] = self.pipeline_step_name if self.model_name: @@ -174,6 +172,7 @@ class SeldonDeploymentService(BaseDeploymentService): type="model-serving", flavor="seldon", description="Seldon Core prediction service", + logo_url="https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/seldon.png", ) config: SeldonDeploymentConfig diff --git a/src/zenml/integrations/seldon/steps/seldon_deployer.py b/src/zenml/integrations/seldon/steps/seldon_deployer.py index e89944e9f11..0b527252e4b 100644 --- a/src/zenml/integrations/seldon/steps/seldon_deployer.py +++ b/src/zenml/integrations/seldon/steps/seldon_deployer.py @@ -73,13 +73,11 @@ def seldon_model_deployer_step( # get pipeline name, step name and run id context = get_step_context() pipeline_name = context.pipeline.name - run_name = context.pipeline_run.name step_name = context.step_run.name # update the step configuration with the real pipeline runtime information service_config = service_config.copy() service_config.pipeline_name = pipeline_name - service_config.run_name = run_name service_config.pipeline_step_name = step_name def prepare_service_config(model_uri: str) -> SeldonDeploymentConfig: @@ -143,9 +141,7 @@ def prepare_service_config(model_uri: str) -> SeldonDeploymentConfig: # fetch existing services with same pipeline name, step name and # model name existing_services = model_deployer.find_model_server( - pipeline_name=pipeline_name, - pipeline_step_name=step_name, - model_name=service_config.model_name, + config=service_config.dict() ) # even when the deploy decision is negative, if an existing model server @@ -173,7 +169,10 @@ def prepare_service_config(model_uri: str) -> SeldonDeploymentConfig: service = cast( SeldonDeploymentService, model_deployer.deploy_model( - service_config, replace=True, timeout=timeout + service_config, + replace=True, + timeout=timeout, + service_type=SeldonDeploymentService.SERVICE_TYPE, ), ) @@ -231,21 +230,17 @@ def seldon_custom_model_deployer_step( # get pipeline name, step name, run id context = get_step_context() pipeline_name = context.pipeline.name - run_name = context.pipeline_run.name step_name = context.step_run.name # update the step configuration with the real pipeline runtime information service_config.pipeline_name = pipeline_name - service_config.run_name = run_name service_config.pipeline_step_name = step_name service_config.is_custom_deployment = True # fetch existing services with the same pipeline name, step name and # model name existing_services = model_deployer.find_model_server( - pipeline_name=pipeline_name, - pipeline_step_name=step_name, - model_name=service_config.model_name, + 
config=service_config.dict() ) # even when the deploy decision is negative if an existing model server # is not running for this pipeline/step, we still have to serve the @@ -325,7 +320,10 @@ def seldon_custom_model_deployer_step( service = cast( SeldonDeploymentService, model_deployer.deploy_model( - service_config, replace=True, timeout=timeout + service_config, + replace=True, + timeout=timeout, + service_type=SeldonDeploymentService.SERVICE_TYPE, ), ) @@ -476,7 +474,10 @@ def seldon_mlflow_registry_deployer_step( service = cast( SeldonDeploymentService, model_deployer.deploy_model( - service_config, replace=True, timeout=timeout + service_config, + replace=True, + timeout=timeout, + service_type=SeldonDeploymentService.SERVICE_TYPE, ), ) diff --git a/src/zenml/integrations/tensorboard/services/tensorboard_service.py b/src/zenml/integrations/tensorboard/services/tensorboard_service.py index adda572fd42..b6c61a8d017 100644 --- a/src/zenml/integrations/tensorboard/services/tensorboard_service.py +++ b/src/zenml/integrations/tensorboard/services/tensorboard_service.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Implementation of the TensorBoard service.""" +import uuid from typing import Any, Dict, Union from tensorboard import default, program # type: ignore [import-untyped] @@ -103,7 +104,7 @@ def __init__( ), ) attrs["endpoint"] = endpoint - super().__init__(config=config, **attrs) + super().__init__(config=config, uuid=uuid.uuid4(), **attrs) def run(self) -> None: """Initialize and run the TensorBoard server.""" diff --git a/src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py b/src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py index 5be0e01d878..bc9b0c20f00 100644 --- a/src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +++ b/src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py @@ -113,6 +113,7 @@ def visualize( service = TensorboardService( TensorboardServiceConfig( logdir=logdir, + name=f"zenml-tensorboard-{logdir}", ) ) service.start(timeout=60) diff --git a/src/zenml/materializers/service_materializer.py b/src/zenml/materializers/service_materializer.py index 7659cabe0ca..a8294433ab0 100644 --- a/src/zenml/materializers/service_materializer.py +++ b/src/zenml/materializers/service_materializer.py @@ -14,13 +14,13 @@ """Implementation of a materializer to read and write ZenML service instances.""" import os +import uuid from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type from zenml.client import Client from zenml.enums import ArtifactType from zenml.materializers.base_materializer import BaseMaterializer -from zenml.services.service import BaseService -from zenml.services.service_registry import ServiceRegistry +from zenml.services.service import BaseDeploymentService, BaseService if TYPE_CHECKING: from zenml.metadata.metadata_types import MetadataType @@ -49,8 +49,11 @@ def load(self, data_type: Type[Any]) -> BaseService: artifact_store = Client().active_stack.artifact_store filepath = os.path.join(self.uri, SERVICE_CONFIG_FILENAME) with artifact_store.open(filepath, "r") as f: - service = ServiceRegistry().load_service_from_json(f.read()) - return service + service_id = f.read().strip() + + client = Client() + service = client.get_service(name_id_or_prefix=uuid.UUID(service_id)) + return BaseDeploymentService.from_model(service) def save(self, service: BaseService) -> None: """Writes a ZenML service. 
@@ -64,7 +67,7 @@ def save(self, service: BaseService) -> None: artifact_store = Client().active_stack.artifact_store filepath = os.path.join(self.uri, SERVICE_CONFIG_FILENAME) with artifact_store.open(filepath, "w") as f: - f.write(service.json(indent=4)) + f.write(str(service.uuid)) def extract_metadata( self, service: BaseService @@ -79,6 +82,6 @@ def extract_metadata( """ from zenml.metadata.metadata_types import Uri - if service.endpoint and service.endpoint.status.uri: - return {"uri": Uri(service.endpoint.status.uri)} + if prediction_url := service.get_prediction_url() or None: + return {"uri": Uri(prediction_url)} return {} diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index 260d0d67c08..5374ffe2ce0 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -503,9 +503,13 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["suppress_class_validation_warnings"] = True return values - def _validate_config_in_runtime(self) -> None: - """Validate that config doesn't conflict with runtime environment.""" - self._get_or_create_model_version() + def _validate_config_in_runtime(self) -> "ModelVersionResponse": + """Validate that config doesn't conflict with runtime environment. + + Returns: + The model version based on configuration. + """ + return self._get_or_create_model_version() def _get_or_create_model(self) -> "ModelResponse": """This method should get or create a model from Model Control Plane. @@ -545,12 +549,10 @@ def _get_or_create_model(self) -> "ModelResponse": ) logger.info(f"New model `{self.name}` was created implicitly.") except EntityExistsError: - # this is backup logic, if model was created somehow in between get and create calls - pass - finally: model = zenml_client.zen_store.get_model( model_name_or_id=self.name ) + self._model_id = model.id return model @@ -722,7 +724,9 @@ def _get_or_create_model_version( retries_made += 1 self.version = model_version.name self.was_created_in_this_run = True + logger.info(f"New model version `{self.version}` was created.") + self._id = model_version.id self._model_id = model_version.model.id self._number = model_version.number diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index 824a4ef71e7..f947d182397 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -23,7 +23,10 @@ from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType from zenml.model.model import Model -from zenml.models import ModelVersionArtifactRequest +from zenml.models import ( + ModelVersionArtifactRequest, + ServiceUpdate, +) from zenml.new.steps.step_context import get_step_context logger = get_logger(__name__) @@ -219,3 +222,49 @@ def link_artifact_to_model( artifact_version_id=artifact_version_id, model=model, ) + + +def link_service_to_model( + service_id: UUID, + model: Optional["Model"] = None, + model_version_id: Optional[UUID] = None, +) -> None: + """Links a service to a model. + + Args: + service_id: The ID of the service to link to the model. + model: The model to link the service to. + model_version_id: The ID of the model version to link the service to. + + Raises: + RuntimeError: If no model is provided and the model context cannot be + identified. 
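With the materializer change above, a service artifact now stores only the service UUID, and `load` re-hydrates the service from the ZenML server via `Client().get_service(...)` instead of parsing a full JSON dump. A minimal sketch of that round trip over plain files, using hypothetical helper names rather than the actual materializer API.

```python
import os
import tempfile
import uuid

SERVICE_CONFIG_FILENAME = "service.json"


def save_service_reference(artifact_uri: str, service_uuid: uuid.UUID) -> None:
    """Persist only the UUID; the full service record lives in the ZenML server."""
    with open(os.path.join(artifact_uri, SERVICE_CONFIG_FILENAME), "w") as f:
        f.write(str(service_uuid))


def load_service_reference(artifact_uri: str) -> uuid.UUID:
    """Read the UUID back; the real load passes it on to Client().get_service(...)."""
    with open(os.path.join(artifact_uri, SERVICE_CONFIG_FILENAME), "r") as f:
        return uuid.UUID(f.read().strip())


with tempfile.TemporaryDirectory() as artifact_uri:
    original = uuid.uuid4()
    save_service_reference(artifact_uri, original)
    assert load_service_reference(artifact_uri) == original
```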
+ """ + client = Client() + + # If no model is provided, try to get it from the context + if not model and not model_version_id: + is_issue = False + try: + step_context = get_step_context() + model = step_context.model + except StepContextError: + is_issue = True + + if model is None or is_issue: + raise RuntimeError( + "`link_service_to_model` called without `model` parameter " + "and configured model context cannot be identified. Consider " + "passing the `model` explicitly or configuring it in " + "@step or @pipeline decorator." + ) + + model_version_id = ( + model_version_id or model._get_or_create_model_version().id + if model + else None + ) + update_service = ServiceUpdate(model_version_id=model_version_id) + client.zen_store.update_service( + service_id=service_id, update=update_service + ) diff --git a/src/zenml/model_deployers/base_model_deployer.py b/src/zenml/model_deployers/base_model_deployer.py index ccc3831b850..747c61fd674 100644 --- a/src/zenml/model_deployers/base_model_deployer.py +++ b/src/zenml/model_deployers/base_model_deployer.py @@ -13,9 +13,10 @@ # permissions and limitations under the License. """Base class for all ZenML model deployers.""" +import contextlib from abc import ABC, abstractmethod from typing import ( - TYPE_CHECKING, + Any, ClassVar, Dict, Generator, @@ -27,19 +28,16 @@ from uuid import UUID from zenml.client import Client -from zenml.constants import METADATA_DEPLOYED_MODEL_URL from zenml.enums import StackComponentType -from zenml.metadata.metadata_types import Uri +from zenml.logger import get_logger from zenml.services import BaseService, ServiceConfig from zenml.services.service import BaseDeploymentService +from zenml.services.service_type import ServiceType from zenml.stack import StackComponent from zenml.stack.flavor import Flavor from zenml.stack.stack_component import StackComponentConfig -if TYPE_CHECKING: - from zenml.config.step_run_info import StepRunInfo - from zenml.metadata.metadata_types import MetadataType - +logger = get_logger(__name__) DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT = 300 @@ -125,11 +123,118 @@ def get_active_model_deployer(cls) -> "BaseModelDeployer": return model_deployer - @abstractmethod def deploy_model( self, config: ServiceConfig, + service_type: ServiceType, replace: bool = False, + continuous_deployment_mode: bool = False, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: + """Deploy a model. + + the deploy_model method is the main entry point for deploying models + using the model deployer. It is used to deploy a model to a model server + instance that is running on a remote serving platform or service. The + method is responsible for detecting if there is an existing model server + instance running serving one or more previous versions of the same model + and deploying the model to the serving platform or updating the existing + model server instance to include the new model version. The method + returns a Service object that is a representation of the external model + server instance. The Service object must implement basic operational + state tracking and lifecycle management operations for the model server + (e.g. start, stop, etc.). + + Args: + config: Custom Service configuration parameters for the model + deployer. Can include the pipeline name, the run id, the step + name, the model name, the model uri, the model type etc. + replace: If True, it will replace any existing model server instances + that serve the same model. 
If False, it does not replace any + existing model server instance. + continuous_deployment_mode: If True, it will replace any existing + model server instances that serve the same model, regardless of + the configuration. If False, it will only replace existing model + server instances that serve the same model if the configuration + is exactly the same. + timeout: The maximum time in seconds to wait for the model server + to start serving the model. + service_type: The type of the service to deploy. If not provided, + the default service type of the model deployer will be used. + + Raises: + RuntimeError: if the model deployment fails. + + Returns: + The deployment Service object. + """ + # Instantiate the client + client = Client() + if not continuous_deployment_mode: + # Find existing model server + services = self.find_model_server( + config=config.dict(), + service_type=service_type, + ) + if len(services) > 0: + logger.info( + f"Existing model server found for {config.name or config.model_name} with the exact same configuration. Returning the existing service named {services[0].config.service_name}." + ) + return services[0] + else: + # Find existing model server + services = self.find_model_server( + pipeline_name=config.pipeline_name, + pipeline_step_name=config.pipeline_step_name, + model_name=config.model_name, + service_type=service_type, + ) + if len(services) > 0: + logger.info( + f"Existing model server found for {config.pipeline_name} and {config.pipeline_step_name}, since continuous deployment mode is enabled, replacing the existing service named {services[0].config.service_name}." + ) + service = services[0] + self.delete_model_server(service.uuid) + logger.info( + f"Deploying model server for {config.model_name} with the following configuration: {config.dict()}" + ) + service_response = client.create_service( + config=config, + service_type=service_type, + model_version_id=get_model_version_id_if_exists( + config.model_name, config.model_version + ), + ) + try: + service = self.perform_deploy_model( + id=service_response.id, + config=config, + timeout=timeout, + ) + except Exception as e: + client.delete_service(service_response.id) + raise RuntimeError( + f"Failed to deploy model server for {config.model_name}: {e}" + ) from e + # Update the service in store + client.update_service( + id=service.uuid, + name=service.config.service_name, + service_source=service.dict().get("type"), + admin_state=service.admin_state, + status=service.status.dict(), + endpoint=service.endpoint.dict() if service.endpoint else None, + # labels=service.config.get_service_labels() # TODO: fix labels in services and config + prediction_url=service.get_prediction_url(), + health_check_url=service.get_healthcheck_url(), + ) + return service + + @abstractmethod + def perform_deploy_model( + self, + id: UUID, + config: ServiceConfig, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, ) -> BaseService: """Abstract method to deploy a model. @@ -146,12 +251,10 @@ def deploy_model( start, stop, etc.) Args: + id: UUID of the service that was originally used to deploy the model. config: Custom Service configuration parameters for the model deployer. Can include the pipeline name, the run id, the step name, the model name, the model uri, the model type etc. - replace: If True, it will replace any existing model server instances - that serve the same model. If False, it does not replace any - existing model server instance. 
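`deploy_model` is now implemented once on the base deployer and delegates flavor-specific provisioning to `perform_deploy_model`, registering the service record first and rolling it back if provisioning fails. A simplified, self-contained sketch of that template-method split; the classes below are stand-ins, not ZenML's actual `BaseModelDeployer` or `BaseService`.

```python
from abc import ABC, abstractmethod
from typing import Dict
from uuid import UUID, uuid4


class SketchService:
    """Stand-in for BaseService: an id, a config, and a running flag."""

    def __init__(self, uuid: UUID, config: Dict[str, str]) -> None:
        self.uuid = uuid
        self.config = config
        self.running = False


class SketchModelDeployer(ABC):
    """The base class owns record-keeping; flavors implement only perform_deploy_model."""

    def __init__(self) -> None:
        self._store: Dict[UUID, SketchService] = {}  # stand-in for the ZenML client/store

    def deploy_model(self, config: Dict[str, str]) -> SketchService:
        service_id = uuid4()  # stand-in for client.create_service(...)
        try:
            service = self.perform_deploy_model(id=service_id, config=config)
        except Exception as e:
            self._store.pop(service_id, None)  # stand-in for client.delete_service(...)
            raise RuntimeError(f"Failed to deploy {config.get('model_name')}: {e}") from e
        self._store[service_id] = service  # stand-in for client.update_service(...)
        return service

    @abstractmethod
    def perform_deploy_model(self, id: UUID, config: Dict[str, str]) -> SketchService:
        """Flavor-specific provisioning, analogous to the MLflow/Seldon overrides."""


class LocalSketchDeployer(SketchModelDeployer):
    def perform_deploy_model(self, id: UUID, config: Dict[str, str]) -> SketchService:
        service = SketchService(uuid=id, config=config)
        service.running = True  # pretend provisioning succeeded
        return service


deployer = LocalSketchDeployer()
service = deployer.deploy_model({"model_name": "demo-model"})
print(service.uuid, service.running)
```

The real implementation additionally persists status, endpoint, and prediction/health-check URLs through the ZenML client, as the hunk above shows.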
timeout: The maximum time in seconds to wait for the model server to start serving the model. @@ -173,17 +276,20 @@ def get_model_server_info( A dictionary containing the relevant model server properties. """ - @abstractmethod def find_model_server( self, - running: bool = False, + config: Optional[Dict[str, Any]] = None, + running: Optional[bool] = None, service_uuid: Optional[UUID] = None, pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, pipeline_step_name: Optional[str] = None, + service_name: Optional[str] = None, model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, + model_version: Optional[str] = None, + service_type: Optional[ServiceType] = None, + type: Optional[str] = None, + flavor: Optional[str] = None, + pipeline_run_id: Optional[str] = None, ) -> List[BaseService]: """Abstract method to find one or more a model servers that match the given criteria. @@ -191,23 +297,91 @@ def find_model_server( running: If true, only running services will be returned. service_uuid: The UUID of the service that was originally used to deploy the model. - pipeline_name: name of the pipeline that the deployed model was part - of. - run_name: Name of the pipeline run which the deployed model was - part of. - pipeline_step_name: the name of the pipeline model deployment step - that deployed the model. - model_name: the name of the deployed model. - model_uri: URI of the deployed model. - model_type: the implementation specific type/format of the deployed - model. + pipeline_step_name: The name of the pipeline step that was originally used + to deploy the model. + pipeline_name: The name of the pipeline that was originally used to deploy + the model from the model registry. + model_name: The name of the model that was originally used to deploy + the model from the model registry. + model_version: The version of the model that was originally used to + deploy the model from the model registry. + service_type: The type of the service to find. + type: The type of the service to find. + flavor: The flavor of the service to find. + pipeline_run_id: The UUID of the pipeline run that was originally used + to deploy the model. + config: Custom Service configuration parameters for the model + deployer. Can include the pipeline name, the run id, the step + name, the model name, the model uri, the model type etc. + service_name: The name of the service to find. Returns: One or more Service objects representing model servers that match the input search criteria. 
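The updated deployer steps pass the whole service configuration to `find_model_server(config=service_config.dict())`, and an existing server is reused only when its stored configuration matches. A small sketch of that exact-match lookup over plain dictionaries; the field names are illustrative, not the full set of config keys.

```python
from typing import Dict, List


def find_matching_servers(
    existing: List[Dict[str, str]], config: Dict[str, str]
) -> List[Dict[str, str]]:
    """Keep only servers whose stored config contains every key/value in `config`."""
    return [
        stored
        for stored in existing
        if all(stored.get(key) == value for key, value in config.items())
    ]


existing_servers = [
    {"model_name": "churn-model", "pipeline_name": "training", "pipeline_step_name": "deploy"},
    {"model_name": "churn-model", "pipeline_name": "experiments", "pipeline_step_name": "deploy"},
]
wanted = {"model_name": "churn-model", "pipeline_name": "training", "pipeline_step_name": "deploy"}
print(len(find_matching_servers(existing_servers, wanted)))  # 1
```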
""" + client = Client() + service_responses = client.list_services( + sort_by="desc:created", + id=service_uuid, + running=running, + service_name=service_name, + pipeline_name=pipeline_name, + pipeline_step_name=pipeline_step_name, + model_version_id=get_model_version_id_if_exists( + model_name, model_version + ), + pipeline_run_id=pipeline_run_id, + config=config, + type=type or service_type.type if service_type else None, + flavor=flavor or service_type.flavor if service_type else None, + hydrate=True, + ) + services = [] + for service_response in service_responses.items: + if not service_response.service_source: + client.delete_service(service_response.id) + continue + service = BaseDeploymentService.from_model(service_response) + service.update_status() + if service.status.dict() != service_response.status: + client.update_service( + id=service.uuid, + admin_state=service.admin_state, + status=service.status.dict(), + endpoint=service.endpoint.dict() + if service.endpoint + else None, + ) + if running and not service.is_running: + logger.warning( + f"Service {service.uuid} is in an unexpected state. " + f"Expected running={running}, but found running={service.is_running}." + ) + continue + services.append(service) + return services @abstractmethod + def perform_stop_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> BaseService: + """Abstract method to stop a model server. + + This operation should be reversible. A stopped model server should still + show up in the list of model servers returned by `find_model_server` and + it should be possible to start it again by calling `start_model_server`. + + Args: + service: The service to stop. + timeout: timeout in seconds to wait for the service to stop. If + set to 0, the method will return immediately after + deprovisioning the service, without waiting for it to stop. + force: if True, force the service to stop. + """ + def stop_model_server( self, uuid: UUID, @@ -226,9 +400,43 @@ def stop_model_server( set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop. force: if True, force the service to stop. + + Raises: + RuntimeError: if the model server is not found. """ + client = Client() + try: + service = self.find_model_server(service_uuid=uuid)[0] + updated_service = self.perform_stop_model(service, timeout, force) + client.update_service( + id=updated_service.uuid, + admin_state=updated_service.admin_state, + status=updated_service.status.dict(), + endpoint=updated_service.endpoint.dict() + if updated_service.endpoint + else None, + ) + except Exception as e: + raise RuntimeError( + f"Failed to stop model server with UUID {uuid}: {e}" + ) from e @abstractmethod + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: + """Abstract method to start a model server. + + Args: + service: The service to start. + timeout: timeout in seconds to wait for the service to start. If + set to 0, the method will return immediately after + provisioning the service, without waiting for it to become + active. + """ + def start_model_server( self, uuid: UUID, @@ -242,9 +450,47 @@ def start_model_server( set to 0, the method will return immediately after provisioning the service, without waiting for it to become active. + + Raises: + RuntimeError: if the model server is not found. 
""" + client = Client() + try: + service = self.find_model_server(service_uuid=uuid)[0] + updated_service = self.perform_start_model(service, timeout) + client.update_service( + id=updated_service.uuid, + admin_state=updated_service.admin_state, + status=updated_service.status.dict(), + endpoint=updated_service.endpoint.dict() + if updated_service.endpoint + else None, + ) + except Exception as e: + raise RuntimeError( + f"Failed to start model server with UUID {uuid}: {e}" + ) from e @abstractmethod + def perform_delete_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> None: + """Abstract method to delete a model server. + + This operation is irreversible. A deleted model server must no longer + show up in the list of model servers returned by `find_model_server`. + + Args: + service: The service to delete. + timeout: timeout in seconds to wait for the service to stop. If + set to 0, the method will return immediately after + deprovisioning the service, without waiting for it to stop. + force: if True, force the service to stop. + """ + def delete_model_server( self, uuid: UUID, @@ -262,7 +508,19 @@ def delete_model_server( set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop. force: if True, force the service to stop. + + Raises: + RuntimeError: if the model server is not found. """ + client = Client() + try: + service = self.find_model_server(service_uuid=uuid)[0] + self.perform_delete_model(service, timeout, force) + client.delete_service(uuid) + except Exception as e: + raise RuntimeError( + f"Failed to delete model server with UUID {uuid}: {e}" + ) from e def get_model_server_logs( self, @@ -288,32 +546,21 @@ def get_model_server_logs( raise RuntimeError(f"No model server found with UUID {uuid}") return services[0].get_logs(follow=follow, tail=tail) - def get_step_run_metadata( - self, info: "StepRunInfo" - ) -> Dict[str, "MetadataType"]: - """Get component- and step-specific metadata after a step ran. - - For model deployers, this extracts the prediction URL of the deployed - model. + def load_service( + self, + service_id: UUID, + ) -> BaseService: + """Load a service from a URI. Args: - info: Info about the step that was executed. + service_id: The ID of the service to load. Returns: - A dictionary of metadata. + The loaded service. """ - existing_services = self.find_model_server( - run_name=info.run_name, - ) - if existing_services: - existing_service = existing_services[0] - if ( - isinstance(existing_service, BaseDeploymentService) - and existing_service.is_running - ): - deployed_model_url = existing_service.prediction_url - return {METADATA_DEPLOYED_MODEL_URL: Uri(deployed_model_url)} - return {} + client = Client() + service = client.get_service(service_id) + return BaseDeploymentService.from_model(service) class BaseModelDeployerFlavor(Flavor): @@ -341,3 +588,26 @@ def config_class(self) -> Type[BaseModelDeployerConfig]: @abstractmethod def implementation_class(self) -> Type[BaseModelDeployer]: """The class that implements the model deployer.""" + + +def get_model_version_id_if_exists( + model_name: Optional[str], + model_version: Optional[str], +) -> Optional[UUID]: + """Get the model version id if it exists. + + Args: + model_name: The name of the model. + model_version: The version of the model. + + Returns: + The model version id if it exists. 
+ """ + client = Client() + if model_name: + with contextlib.suppress(KeyError): + return client.get_model_version( + model_name_or_id=model_name, + model_version_name_or_number_or_id=model_version, + ).id + return None diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index db5ff386a69..9e9480f5797 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -89,6 +89,15 @@ ArtifactVisualizationResponseBody, ArtifactVisualizationResponseMetadata, ) +from zenml.models.v2.core.service import ( + ServiceResponse, + ServiceResponseBody, + ServiceResponseMetadata, + ServiceUpdate, + ServiceFilter, + ServiceRequest, + ServiceResponseResources, +) from zenml.models.v2.core.code_reference import ( CodeReferenceRequest, CodeReferenceResponse, @@ -157,6 +166,7 @@ ModelVersionResponseMetadata, ModelVersionFilter, ModelVersionUpdate, + ModelVersionResponseResources, ) from zenml.models.v2.core.model_version_artifact import ( ModelVersionArtifactFilter, @@ -402,6 +412,15 @@ FlavorResponseMetadata.update_forward_refs( WorkspaceResponse=WorkspaceResponse, ) +ServiceResponseBody.update_forward_refs( + UserResponse=UserResponse, +) +ServiceResponseMetadata.update_forward_refs( + WorkspaceResponse=WorkspaceResponse, +) +ServiceResponseResources.update_forward_refs( + ModelVersionResponse=ModelVersionResponse, +) ModelResponseBody.update_forward_refs( UserResponse=UserResponse, TagResponse=TagResponse, @@ -418,6 +437,9 @@ WorkspaceResponse=WorkspaceResponse, RunMetadataResponse=RunMetadataResponse, ) +ModelVersionResponseResources.update_forward_refs( + ServiceResponse=ServiceResponse, +) ModelVersionArtifactResponseBody.update_forward_refs( ArtifactVersionResponse=ArtifactVersionResponse, ) @@ -639,6 +661,7 @@ "ModelVersionResponse", "ModelVersionResponseBody", "ModelVersionResponseMetadata", + "ModelVersionResponseResources", "ModelVersionUpdate", "ModelVersionArtifactFilter", "ModelVersionArtifactRequest", @@ -765,6 +788,13 @@ "WorkspaceResponse", "WorkspaceResponseBody", "WorkspaceResponseMetadata", + "ServiceResponse", + "ServiceResponseBody", + "ServiceResponseMetadata", + "ServiceUpdate", + "ServiceFilter", + "ServiceRequest", + "ServiceResponseResources", # V2 Misc "AuthenticationMethodModel", "ServiceConnectorResourcesModel", diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 4e6dc97c489..04d8da143de 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -21,6 +21,7 @@ from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.enums import ModelStages from zenml.models.v2.base.filter import AnyQuery +from zenml.models.v2.base.page import Page from zenml.models.v2.base.scoped import ( WorkspaceScopedRequest, WorkspaceScopedResponse, @@ -29,6 +30,7 @@ WorkspaceScopedResponseResources, WorkspaceScopedTaggableFilter, ) +from zenml.models.v2.core.service import ServiceResponse from zenml.models.v2.core.tag import TagResponse if TYPE_CHECKING: @@ -176,6 +178,10 @@ class ModelVersionResponseMetadata(WorkspaceScopedResponseMetadata): class ModelVersionResponseResources(WorkspaceScopedResponseResources): """Class for all resource models associated with the model version entity.""" + services: Page[ServiceResponse] = Field( + description="Services linked to the model version", + ) + class ModelVersionResponse( WorkspaceScopedResponse[ diff --git a/src/zenml/models/v2/core/service.py b/src/zenml/models/v2/core/service.py new file mode 100644 
index 00000000000..b1bbc2c8210 --- /dev/null +++ b/src/zenml/models/v2/core/service.py @@ -0,0 +1,479 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Models representing Services.""" + +from datetime import datetime +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + Union, +) +from uuid import UUID + +from pydantic import BaseModel, Field +from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList +from sqlmodel import SQLModel + +from zenml.constants import STR_FIELD_MAX_LENGTH +from zenml.models.v2.base.scoped import ( + WorkspaceScopedFilter, + WorkspaceScopedRequest, + WorkspaceScopedResponse, + WorkspaceScopedResponseBody, + WorkspaceScopedResponseMetadata, + WorkspaceScopedResponseResources, + WorkspaceScopedTaggableFilter, +) +from zenml.services.service_status import ServiceState +from zenml.services.service_type import ServiceType + +if TYPE_CHECKING: + pass + +# ------------------ Request Model ------------------ + + +class ServiceRequest(WorkspaceScopedRequest): + """Request model for services.""" + + name: str = Field( + title="The name of the service.", + max_length=STR_FIELD_MAX_LENGTH, + ) + + service_type: ServiceType = Field( + title="The type of the service.", + ) + + service_source: Optional[str] = Field( + title="The class of the service.", + description="The fully qualified class name of the service implementation.", + ) + + admin_state: Optional[ServiceState] = Field( + title="The admin state of the service.", + description="The administrative state of the service, e.g., ACTIVE, INACTIVE.", + ) + + config: Dict[str, Any] = Field( + title="The service config.", + description="A dictionary containing configuration parameters for the service.", + ) + + labels: Optional[Dict[str, str]] = Field( + default=None, + title="The service labels.", + ) + + status: Optional[Dict[str, Any]] = Field( + title="The status of the service.", + ) + + endpoint: Optional[Dict[str, Any]] = Field( + default=None, + title="The service endpoint.", + ) + + prediction_url: Optional[str] = Field( + default=None, + title="The service endpoint URL.", + ) + + health_check_url: Optional[str] = Field( + default=None, + title="The service health check URL.", + ) + + model_version_id: Optional[UUID] = Field( + default=None, + title="The model version id linked to the service.", + ) + pipeline_run_id: Optional[Union[UUID, str]] = Field( + default=None, + description="By the event source this trigger is attached to.", + ) + + +# ------------------ Update Model ------------------ + + +class ServiceUpdate(BaseModel): + """Update model for stack components.""" + + name: Optional[str] = Field( + title="The name of the service.", + max_length=STR_FIELD_MAX_LENGTH, + ) + + admin_state: Optional[ServiceState] = Field( + title="The admin state of the service.", + description="The administrative state of the service, e.g., ACTIVE, INACTIVE.", + ) + + service_source: Optional[str] = Field( + 
title="The class of the service.", + description="The fully qualified class name of the service implementation.", + ) + + status: Optional[Dict[str, Any]] = Field( + title="The status of the service.", + ) + + endpoint: Optional[Dict[str, Any]] = Field( + title="The service endpoint.", + ) + + prediction_url: Optional[str] = Field( + title="The service endpoint URL.", + ) + + health_check_url: Optional[str] = Field( + title="The service health check URL.", + ) + + labels: Optional[Dict[str, str]] = Field( + default=None, + title="The service labels.", + ) + + model_version_id: Optional[UUID] = Field( + default=None, + title="The model version id linked to the service.", + ) + + +# ------------------ Response Model ------------------ + + +class ServiceResponseBody(WorkspaceScopedResponseBody): + """Response body for services.""" + + service_type: ServiceType = Field( + title="The type of the service.", + ) + labels: Optional[Dict[str, str]] = Field( + default=None, + title="The service labels.", + ) + created: datetime = Field( + title="The timestamp when this component was created." + ) + updated: datetime = Field( + title="The timestamp when this component was last updated.", + ) + state: Optional[ServiceState] = Field( + default=None, + title="The current state of the service.", + ) + + +class ServiceResponseMetadata(WorkspaceScopedResponseMetadata): + """Response metadata for services.""" + + service_source: Optional[str] = Field( + title="The class of the service.", + ) + admin_state: Optional[ServiceState] = Field( + title="The admin state of the service.", + ) + config: Dict[str, Any] = Field( + title="The service config.", + ) + status: Optional[Dict[str, Any]] = Field( + title="The status of the service.", + ) + endpoint: Optional[Dict[str, Any]] = Field( + default=None, + title="The service endpoint.", + ) + prediction_url: Optional[str] = Field( + default=None, + title="The service endpoint URL.", + ) + health_check_url: Optional[str] = Field( + default=None, + title="The service health check URL.", + ) + + +class ServiceResponseResources(WorkspaceScopedResponseResources): + """Class for all resource models associated with the service entity.""" + + +class ServiceResponse( + WorkspaceScopedResponse[ + ServiceResponseBody, ServiceResponseMetadata, ServiceResponseResources + ] +): + """Response model for services.""" + + name: str = Field( + title="The name of the service.", + max_length=STR_FIELD_MAX_LENGTH, + ) + + def get_hydrated_version(self) -> "ServiceResponse": + """Get the hydrated version of this artifact. + + Returns: + an instance of the same entity with the metadata field attached. + """ + from zenml.client import Client + + return Client().zen_store.get_service(self.id) + + # Body and metadata properties + + @property + def service_type(self) -> ServiceType: + """The `service_type` property. + + Returns: + the value of the property. + """ + return self.get_body().service_type + + @property + def labels(self) -> Optional[Dict[str, str]]: + """The `labels` property. + + Returns: + the value of the property. + """ + return self.get_body().labels + + @property + def service_source(self) -> Optional[str]: + """The `service_source` property. + + Returns: + the value of the property. + """ + return self.get_metadata().service_source + + @property + def config(self) -> Dict[str, Any]: + """The `config` property. + + Returns: + the value of the property. 
+ """ + return self.get_metadata().config + + @property + def status(self) -> Optional[Dict[str, Any]]: + """The `status` property. + + Returns: + the value of the property. + """ + return self.get_metadata().status + + @property + def endpoint(self) -> Optional[Dict[str, Any]]: + """The `endpoint` property. + + Returns: + the value of the property. + """ + return self.get_metadata().endpoint + + @property + def created(self) -> datetime: + """The `created` property. + + Returns: + the value of the property. + """ + return self.get_body().created + + @property + def updated(self) -> datetime: + """The `updated` property. + + Returns: + the value of the property. + """ + return self.get_body().updated + + @property + def admin_state(self) -> Optional[ServiceState]: + """The `admin_state` property. + + Returns: + the value of the property. + """ + return self.get_metadata().admin_state + + @property + def prediction_url(self) -> Optional[str]: + """The `prediction_url` property. + + Returns: + the value of the property. + """ + return self.get_metadata().prediction_url + + @property + def health_check_url(self) -> Optional[str]: + """The `health_check_url` property. + + Returns: + the value of the property. + """ + return self.get_metadata().health_check_url + + @property + def state(self) -> Optional[ServiceState]: + """The `state` property. + + Returns: + the value of the property. + """ + return self.get_body().state + + +# ------------------ Filter Model ------------------ + + +class ServiceFilter(WorkspaceScopedFilter): + """Model to enable advanced filtering of services. + + The Service needs additional scoping. As such the `_scope_user` field + can be set to the user that is doing the filtering. The + `generate_filter()` method of the baseclass is overwritten to include the + scoping. + """ + + name: Optional[str] = Field( + description="Name of the service. Use this to filter services by their name.", + ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, description="Workspace of the service" + ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, description="User of the service" + ) + type: Optional[str] = Field( + default=None, + description="Type of the service. Filter services by their type.", + ) + flavor: Optional[str] = Field( + default=None, + description="Flavor of the service. Use this to filter services by their flavor.", + ) + config: Optional[bytes] = Field( + default=None, + description="Config of the service. Use this to filter services by their config.", + ) + pipeline_name: Optional[str] = Field( + default=None, + description="Pipeline name responsible for deploying the service", + ) + pipeline_step_name: Optional[str] = Field( + default=None, + description="Pipeline step name responsible for deploying the service", + ) + running: Optional[bool] = Field( + default=None, description="Whether the service is running" + ) + model_version_id: Optional[Union[UUID, str]] = Field( + default=None, + description="By the model version this service is attached to.", + ) + pipeline_run_id: Optional[Union[UUID, str]] = Field( + default=None, + description="By the pipeline run this service is attached to.", + ) + + def set_type(self, type: str) -> None: + """Set the type of the service. + + Args: + type: The type of the service. + """ + self.type = type + + def set_flavor(self, flavor: str) -> None: + """Set the flavor of the service. + + Args: + flavor: The flavor of the service. 
+ """ + self.flavor = flavor + + # Artifact name and type are not DB fields and need to be handled separately + FILTER_EXCLUDE_FIELDS = [ + *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, + "flavor", + "type", + "pipeline_step_name", + "running", + "pipeline_name", + "config", + ] + CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS, + "workspace_id", + "user_id", + "flavor", + "type", + "pipeline_step_name", + "running", + "pipeline_name", + ] + + def generate_filter( + self, table: Type["SQLModel"] + ) -> Union["BinaryExpression[Any]", "BooleanClauseList[Any]"]: + """Generate the filter for the query. + + Services can be scoped by type to narrow the search. + + Args: + table: The Table that is being queried from. + + Returns: + The filter expression for the query. + """ + from sqlalchemy import and_ + + base_filter = super().generate_filter(table) + + if self.type: + type_filter = getattr(table, "type") == self.type + base_filter = and_(base_filter, type_filter) + + if self.flavor: + flavor_filter = getattr(table, "flavor") == self.flavor + base_filter = and_(base_filter, flavor_filter) + + if self.pipeline_name: + pipeline_name_filter = ( + getattr(table, "pipeline_name") == self.pipeline_name + ) + base_filter = and_(base_filter, pipeline_name_filter) + + if self.pipeline_step_name: + pipeline_step_name_filter = ( + getattr(table, "pipeline_step_name") == self.pipeline_step_name + ) + base_filter = and_(base_filter, pipeline_step_name_filter) + + return base_filter diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index c7803f2665c..7ff767327fb 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -748,10 +748,6 @@ def _run( run_id=run.id if run else None, ) - deploy_pipeline( - deployment=deployment_model, stack=stack, placeholder_run=run - ) - if run: run_url = dashboard_utils.get_run_url(run) if run_url: @@ -763,6 +759,10 @@ def _run( "`zenml up`." ) + deploy_pipeline( + deployment=deployment_model, stack=stack, placeholder_run=run + ) + return run @staticmethod diff --git a/src/zenml/new/pipelines/run_utils.py b/src/zenml/new/pipelines/run_utils.py index e98caef2792..2b3d750ba6a 100644 --- a/src/zenml/new/pipelines/run_utils.py +++ b/src/zenml/new/pipelines/run_utils.py @@ -28,6 +28,7 @@ from zenml.new.pipelines.model_utils import NewModelRequest from zenml.orchestrators.utils import get_run_name from zenml.stack import Stack +from zenml.utils import cloud_utils if TYPE_CHECKING: from zenml.config.source import Source @@ -232,6 +233,7 @@ def _validate_new_version_requests( new_versions_requested: A dict of new model version request objects. """ + is_cloud_model = True for key, data in new_versions_requested.items(): model_name, model_version = key if len(data.requesters) > 1: @@ -241,4 +243,12 @@ def _validate_new_version_requests( "that `Model` requesting new version is configured only in one " "place of the pipeline." ) - data.model._validate_config_in_runtime() + model_version_response = data.model._validate_config_in_runtime() + is_cloud_model &= cloud_utils.is_cloud_model_version( + model_version_response + ) + if not is_cloud_model: + logger.info( + "Models can be viewed in the dashboard using ZenML Cloud. 
Sign up " + "for a free trial at https://www.zenml.io/cloud/" + ) diff --git a/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py b/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py index c5cd80bbc43..cd5d254d442 100644 --- a/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py +++ b/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py @@ -38,7 +38,7 @@ ContainerizedOrchestrator, ) from zenml.stack import Stack, StackValidator -from zenml.utils import string_utils +from zenml.utils import docker_utils, string_utils if TYPE_CHECKING: from zenml.models import PipelineDeploymentResponse @@ -117,9 +117,8 @@ def prepare_or_run_pipeline( "and the pipeline will be run immediately." ) - from docker.client import DockerClient + docker_client = docker_utils._try_get_docker_client_from_env() - docker_client = DockerClient.from_env() entrypoint = StepEntrypointConfiguration.get_entrypoint_command() # Add the local stores path as a volume mount diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 4ab21d96767..b47d0c8aa37 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -44,7 +44,9 @@ from zenml.logger import get_logger from zenml.logging.step_logging import StepLogsStorageContext, redirected from zenml.materializers.base_materializer import BaseMaterializer -from zenml.model.utils import link_step_artifacts_to_model +from zenml.model.utils import ( + link_step_artifacts_to_model, +) from zenml.new.steps.step_context import StepContext, get_step_context from zenml.orchestrators.publish_utils import ( publish_step_run_metadata, diff --git a/src/zenml/service_connectors/docker_service_connector.py b/src/zenml/service_connectors/docker_service_connector.py index 70f6c92f26d..13f9d632035 100644 --- a/src/zenml/service_connectors/docker_service_connector.py +++ b/src/zenml/service_connectors/docker_service_connector.py @@ -37,6 +37,7 @@ AuthenticationConfig, ServiceConnector, ) +from zenml.utils import docker_utils from zenml.utils.enum_utils import StrEnum logger = get_logger(__name__) @@ -258,7 +259,9 @@ def _connect_to_resource( An authenticated python-docker client object. 
""" assert self.resource_id is not None - docker_client = DockerClient.from_env() + + docker_client = docker_utils._try_get_docker_client_from_env() + self._authorize_client(docker_client, self.resource_id) return docker_client diff --git a/src/zenml/services/__init__.py b/src/zenml/services/__init__.py index 55ef932dc48..95646020d5e 100644 --- a/src/zenml/services/__init__.py +++ b/src/zenml/services/__init__.py @@ -51,7 +51,6 @@ TCPEndpointHealthMonitor, TCPEndpointHealthMonitorConfig, ) -from zenml.services.service_registry import ServiceRegistry from zenml.services.service_status import ServiceState, ServiceStatus from zenml.services.service_type import ServiceType @@ -84,5 +83,4 @@ "LocalDaemonServiceEndpointConfig", "LocalDaemonServiceEndpointStatus", "LocalDaemonServiceEndpoint", - "ServiceRegistry", ] diff --git a/src/zenml/services/container/container_service.py b/src/zenml/services/container/container_service.py index 5c8dcb3b8cb..28089b1bdf9 100644 --- a/src/zenml/services/container/container_service.py +++ b/src/zenml/services/container/container_service.py @@ -33,6 +33,7 @@ ) from zenml.services.service import BaseService, ServiceConfig from zenml.services.service_status import ServiceState, ServiceStatus +from zenml.utils import docker_utils from zenml.utils.io_utils import ( create_dir_recursive_if_not_exists, get_global_config_directory, @@ -177,7 +178,9 @@ def docker_client(self) -> DockerClient: The docker client. """ if self._docker_client is None: - self._docker_client = DockerClient.from_env() + self._docker_client = ( + docker_utils._try_get_docker_client_from_env() + ) return self._docker_client @property diff --git a/src/zenml/services/container/entrypoint.py b/src/zenml/services/container/entrypoint.py index 2f0956a192e..b7476bc19b9 100644 --- a/src/zenml/services/container/entrypoint.py +++ b/src/zenml/services/container/entrypoint.py @@ -19,6 +19,7 @@ import os import sys +from typing import cast import click @@ -50,7 +51,7 @@ def launch_service(service_config_file: str) -> None: # with messages before daemonization is complete from zenml.integrations.registry import integration_registry from zenml.logger import get_logger - from zenml.services import ContainerService, ServiceRegistry + from zenml.services import ContainerService logger = get_logger(__name__) @@ -63,7 +64,7 @@ def launch_service(service_config_file: str) -> None: logger.debug( "Running containerized service with configuration:\n %s", config ) - service = ServiceRegistry().load_service_from_json(config) + service = cast("ContainerService", ContainerService.from_json(config)) if not isinstance(service, ContainerService): raise TypeError( f"Expected service type ContainerService but got " diff --git a/src/zenml/services/local/local_daemon_entrypoint.py b/src/zenml/services/local/local_daemon_entrypoint.py index 3d2cf42f8a3..33d03685cd9 100644 --- a/src/zenml/services/local/local_daemon_entrypoint.py +++ b/src/zenml/services/local/local_daemon_entrypoint.py @@ -18,6 +18,7 @@ """ import os +from typing import cast import click @@ -68,7 +69,7 @@ def launch_service(service_config_file: str) -> None: # with messages before daemonization is complete from zenml.integrations.registry import integration_registry from zenml.logger import get_logger - from zenml.services import LocalDaemonService, ServiceRegistry + from zenml.services import LocalDaemonService logger = get_logger(__name__) @@ -81,7 +82,9 @@ def launch_service(service_config_file: str) -> None: integration_registry.activate_integrations() 
logger.debug("Running service daemon with configuration:\n %s", config) - service = ServiceRegistry().load_service_from_json(config) + service = cast( + "LocalDaemonService", LocalDaemonService.from_json(config) + ) if not isinstance(service, LocalDaemonService): raise TypeError( f"Expected service type LocalDaemonService but got " diff --git a/src/zenml/services/service.py b/src/zenml/services/service.py index ba2664be586..446cee28709 100644 --- a/src/zenml/services/service.py +++ b/src/zenml/services/service.py @@ -13,10 +13,12 @@ # permissions and limitations under the License. """Implementation of the ZenML Service class.""" +import json import time from abc import abstractmethod from functools import wraps from typing import ( + TYPE_CHECKING, Any, Callable, ClassVar, @@ -26,24 +28,27 @@ Tuple, Type, TypeVar, - cast, ) -from uuid import UUID, uuid4 - -from pydantic import Field +from uuid import UUID from zenml.console import console from zenml.logger import get_logger from zenml.services.service_endpoint import BaseServiceEndpoint -from zenml.services.service_registry import ServiceRegistry +from zenml.services.service_monitor import HTTPEndpointHealthMonitor from zenml.services.service_status import ServiceState, ServiceStatus from zenml.services.service_type import ServiceType -from zenml.utils.typed_model import BaseTypedModel, BaseTypedModelMeta +from zenml.utils import source_utils +from zenml.utils.typed_model import BaseTypedModel logger = get_logger(__name__) T = TypeVar("T", bound=Callable[..., Any]) +if TYPE_CHECKING: + from zenml.models.v2.core.service import ServiceResponse + +ZENM_ENDPOINT_PREFIX = "zenml-" + def update_service_status( pre_status: Optional[ServiceState] = None, @@ -108,107 +113,42 @@ class ServiceConfig(BaseTypedModel): description: str = "" pipeline_name: str = "" pipeline_step_name: str = "" - run_name: str = "" + model_name: str = "" + model_version: str = "" + service_name: str = "" - -class BaseServiceMeta(BaseTypedModelMeta): - """Metaclass responsible for registering different BaseService subclasses. - - This metaclass has two main responsibilities: - 1. register all BaseService types in the service registry. This is relevant - when services are deserialized and instantiated from their JSON or dict - representation, because their type needs to be known beforehand. - 2. ensuring BaseService instance uniqueness by enforcing that no two - service instances have the same UUID value. Implementing this at the - constructor level guarantees that deserializing a service instance from - a JSON representation multiple times always returns the same service object. - """ - - def __new__( - mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any] - ) -> "BaseServiceMeta": - """Creates a BaseService class and registers it in the `ServiceRegistry`. + def __init__(self, **data: Any): + """Initialize the service configuration. Args: - name: name of the class. - bases: tuple of base classes. - dct: dictionary of class attributes. - - Returns: - the created BaseServiceMeta class. + **data: keyword arguments. Raises: - TypeError: if the 'service_type' reserved attribute name is used. + ValueError: if neither 'name' nor 'model_name' is set. 
""" - service_type = dct.get("SERVICE_TYPE", None) - - # register only classes of concrete service implementations - if service_type: - # add the service type class attribute to the class as a regular - # immutable attribute to include it in the JSON representation - if "service_type" in dct: - raise TypeError( - "`service_type` is a reserved attribute name for BaseService " - "subclasses" - ) - dct.setdefault("__annotations__", dict())["service_type"] = ( - ServiceType + super().__init__(**data) + if self.name or self.model_name: + self.service_name = data.get( + "service_name", + f"{ZENM_ENDPOINT_PREFIX}{self.name or self.model_name}", ) - dct["service_type"] = Field(service_type, allow_mutation=False) - - cls = cast(Type["BaseService"], super().__new__(mcs, name, bases, dct)) - - # register only classes of concrete service implementations - if service_type: - # register the service type in the service registry - ServiceRegistry().register_service_type(cls) - return cls - - def __call__(cls, *args: Any, **kwargs: Any) -> "BaseServiceMeta": - """Validate the creation of a service. + else: + raise ValueError("Either 'name' or 'model_name' must be set.") - Args: - *args: positional arguments. - **kwargs: keyword arguments. + def get_service_labels(self) -> Dict[str, str]: + """Get the service labels. Returns: - the created BaseServiceMeta class. - - Raises: - AttributeError: if the service UUID is untyped. - ValueError: if the service UUID is not a UUID type. + a dictionary of service labels. """ - if not getattr(cls, "SERVICE_TYPE", None): - raise AttributeError( - f"Untyped service instances are not allowed. Please set the " - f"SERVICE_TYPE class attribute for {cls}." - ) - uuid = kwargs.get("uuid", None) - if uuid: - if isinstance(uuid, str): - uuid = UUID(uuid) - if not isinstance(uuid, UUID): - raise ValueError( - f"The `uuid` argument for {cls} must be a UUID instance or a " - f"string representation of a UUID." - ) - - # if a service instance with the same UUID is already registered, - # return the existing instance rather than the newly created one - existing_service = ServiceRegistry().get_service(uuid) - if existing_service: - logger.debug( - f"Reusing existing service '{existing_service}' " - f"instead of creating a new service with the same UUID." - ) - return cast("BaseServiceMeta", existing_service) - - svc = cast("BaseService", super().__call__(*args, **kwargs)) - ServiceRegistry().register_service(svc) - return cast("BaseServiceMeta", svc) + labels = {} + for k, v in self.dict().items(): + label = f"zenml_{k}".upper() + labels[label] = str(v) + return labels -class BaseService(BaseTypedModel, metaclass=BaseServiceMeta): +class BaseService(BaseTypedModel): """Base service class. This class implements generic functionality concerning the life-cycle @@ -227,7 +167,7 @@ class BaseService(BaseTypedModel, metaclass=BaseServiceMeta): SERVICE_TYPE: ClassVar[ServiceType] - uuid: UUID = Field(default_factory=uuid4, allow_mutation=False) + uuid: UUID admin_state: ServiceState = ServiceState.INACTIVE config: ServiceConfig status: ServiceStatus @@ -246,6 +186,49 @@ def __init__( super().__init__(**attrs) self.config.name = self.config.name or self.__class__.__name__ + @classmethod + def from_model(cls, model: "ServiceResponse") -> "BaseService": + """Loads a service from a model. + + Args: + model: The ServiceResponse to load from. + + Returns: + The loaded service object. + + Raises: + ValueError: if the service source is not found in the model. 
+ """ + if not model.service_source: + raise ValueError("Service source not found in the model.") + class_: Type[BaseService] = source_utils.load_and_validate_class( + source=model.service_source, expected_class=BaseService + ) + return class_( + uuid=model.id, + admin_state=model.admin_state, + config=model.config, + status=model.status, + service_type=model.service_type.dict(), + endpoint=model.endpoint, + ) + + @classmethod + def from_json(cls, json_str: str) -> "BaseTypedModel": + """Loads a service from a JSON string. + + Args: + json_str: the JSON string to load from. + + Returns: + The loaded service object. + """ + service_dict = json.loads(json_str) + class_: Type[BaseService] = source_utils.load_and_validate_class( + source=service_dict["type"], expected_class=BaseService + ) + return class_.from_dict(service_dict) + @abstractmethod def check_status(self) -> Tuple[ServiceState, str]: """Check the the current operational state of the external service. @@ -449,19 +432,15 @@ def start(self, timeout: int = 0) -> None: timeout: amount of time to wait for the service to become active. If set to 0, the method will return immediately after checking the service status. - - Raises: - RuntimeError: if the service cannot be started """ with console.status(f"Starting service '{self}'.\n"): self.admin_state = ServiceState.ACTIVE self.provision() - if timeout > 0: - if not self.poll_service_status(timeout): - raise RuntimeError( - f"Failed to start service {self}\n" - + self.get_service_status_message() - ) + if timeout > 0 and not self.poll_service_status(timeout): + logger.error( + f"Failed to start service {self}\n" + + self.get_service_status_message() + ) @update_service_status( pre_status=ServiceState.PENDING_SHUTDOWN, @@ -476,9 +455,6 @@ def stop(self, timeout: int = 0, force: bool = False) -> None: the service status. force: if True, the service will be stopped even if it is not currently running. - - Raises: - RuntimeError: if the service cannot be stopped """ with console.status(f"Stopping service '{self}'.\n"): self.admin_state = ServiceState.INACTIVE @@ -486,12 +462,40 @@ def stop(self, timeout: int = 0, force: bool = False) -> None: if timeout > 0: self.poll_service_status(timeout) if not self.is_stopped: - raise RuntimeError( + logger.error( f"Failed to stop service {self}. Last state: " f"'{self.status.state.value}'. Last error: " f"'{self.status.last_error}'" ) + def get_prediction_url(self) -> Optional[str]: + """Gets the prediction URL for the endpoint. + + Returns: + the prediction URL for the endpoint + """ + prediction_url = None + if isinstance(self, BaseDeploymentService) and self.prediction_url: + prediction_url = self.prediction_url + elif self.endpoint: + prediction_url = ( + self.endpoint.status.uri if self.endpoint.status else None + ) + return prediction_url + + def get_healthcheck_url(self) -> Optional[str]: + """Gets the healthcheck URL for the endpoint. + + Returns: + the healthcheck URL for the endpoint + """ + return ( + self.endpoint.monitor.get_healthcheck_uri(self.endpoint) + if (self.endpoint and self.endpoint.monitor) + and isinstance(self.endpoint.monitor, HTTPEndpointHealthMonitor) + else None + ) + def __repr__(self) -> str: """String representation of the service. @@ -529,3 +533,12 @@ def prediction_url(self) -> Optional[str]: the prediction URL for the endpoint """ return None + + @property + def healthcheck_url(self) -> Optional[str]: + """Gets the healthcheck URL for the endpoint. 
+ + Returns: + the healthcheck URL for the endpoint + """ + return None diff --git a/src/zenml/services/service_registry.py b/src/zenml/services/service_registry.py deleted file mode 100644 index c88cdab8b5b..00000000000 --- a/src/zenml/services/service_registry.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Implementation of the ZenML service registry.""" - -import json -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast -from uuid import UUID - -from zenml.logger import get_logger -from zenml.services.service_type import ServiceType -from zenml.utils.singleton import SingletonMetaClass - -logger = get_logger(__name__) - -if TYPE_CHECKING: - from zenml.services.service import BaseService - - -class ServiceRegistry(metaclass=SingletonMetaClass): - """Registry of service types and service instances. - - The service registry provides a central place to register service types - as well as service instances. - """ - - def __init__(self) -> None: - """Initialize the service registry.""" - self.service_types: Dict[ServiceType, Type["BaseService"]] = {} - self.services: Dict[UUID, "BaseService"] = {} - - def register_service_type(self, cls: Type["BaseService"]) -> None: - """Registers a new service type. - - Args: - cls: a BaseService subclass. - - Raises: - TypeError: if the service type is already registered. - """ - service_type = cls.SERVICE_TYPE - if service_type not in self.service_types: - self.service_types[service_type] = cls - logger.debug( - f"Registered service class {cls} for " - f"service type `{service_type}`" - ) - else: - raise TypeError( - f"Found existing service type for {service_type}: " - f"{self.service_types[service_type]}. Skipping registration " - f"of {cls}." - ) - - def get_service_type( - self, service_type: ServiceType - ) -> Optional[Type["BaseService"]]: - """Get the service class registered for a service type. - - Args: - service_type: service type. - - Returns: - `BaseService` subclass that was registered for the service type or - None, if no service class was registered for the service type. - """ - return self.service_types.get(service_type) - - def get_service_types( - self, - ) -> Dict[ServiceType, Type["BaseService"]]: - """Get all registered service types. - - Returns: - Dictionary of service types indexed by their service type. - """ - return self.service_types.copy() - - def service_type_is_registered(self, service_type: ServiceType) -> bool: - """Check if a service type is registered. - - Args: - service_type: service type. - - Returns: - True, if a service type is registered for the service type, False - otherwise. - """ - return service_type in self.service_types - - def register_service(self, service: "BaseService") -> None: - """Registers a new service instance. - - Args: - service: a BaseService instance. - - Raises: - TypeError: if the service instance has a service type that is not - registered. 
- Exception: if a preexisting service is found for that UUID. - """ - service_type = service.SERVICE_TYPE - if service_type not in self.service_types: - raise TypeError( - f"Service type `{service_type}` is not registered." - ) - - if service.uuid not in self.services: - self.services[service.uuid] = service - logger.debug(f"Registered service {service}") - else: - existing_service = self.services[service.uuid] - raise Exception( - f"Found existing service {existing_service} for UUID: " - f"{service.uuid}. Skipping registration for service " - f"{service}." - ) - - def get_service(self, uuid: UUID) -> Optional["BaseService"]: - """Get the service instance registered for a UUID. - - Args: - uuid: service instance identifier. - - Returns: - `BaseService` instance that was registered for the UUID or - None, if no matching service instance was found. - """ - return self.services.get(uuid) - - def get_services(self) -> Dict[UUID, "BaseService"]: - """Get all service instances currently registered. - - Returns: - Dictionary of `BaseService` instances indexed by their UUID with - all services that are currently registered. - """ - return self.services.copy() - - def service_is_registered(self, uuid: UUID) -> bool: - """Check if a service instance is registered. - - Args: - uuid: service instance identifier. - - Returns: - True, if a service instance is registered for the UUID, False - otherwise. - """ - return uuid in self.services - - def load_service_from_dict( - self, service_dict: Dict[str, Any] - ) -> "BaseService": - """Load a service instance from its dict representation. - - Creates, registers and returns a service instantiated from the dict - representation of the service configuration and last known status - information. - - If an existing service instance with the same UUID is already - present in the service registry, it is returned instead. - - Args: - service_dict: dict representation of the service configuration and - last known status - - Returns: - A new or existing ZenML service instance. - - Raises: - TypeError: if the service type is not registered. - ValueError: if the service type is not valid. - """ - service_type = service_dict.get("service_type") - if not service_type: - raise ValueError( - "Service type not present in the service dictionary" - ) - service_type = ServiceType.parse_obj(service_type) - service_class = self.get_service_type(service_type) - if not service_class: - raise TypeError( - f"Cannot load service with unregistered service " - f"type: {service_type}" - ) - service = cast("BaseService", service_class.from_dict(service_dict)) - return service - - def load_service_from_json(self, json_str: str) -> "BaseService": - """Load a service instance from its JSON representation. - - Creates and returns a service instantiated from the JSON serialized - service configuration and last known status information. - - Args: - json_str: JSON string representation of the service configuration - and last known status - - Returns: - A ZenML service instance. 
- """ - service_dict = json.loads(json_str) - return self.load_service_from_dict(service_dict) diff --git a/src/zenml/services/service_status.py b/src/zenml/services/service_status.py index 368a96f90f0..fc21e3f328e 100644 --- a/src/zenml/services/service_status.py +++ b/src/zenml/services/service_status.py @@ -25,11 +25,12 @@ class ServiceState(StrEnum): """Possible states for the service and service endpoint.""" + INACTIVE = "inactive" ACTIVE = "active" PENDING_STARTUP = "pending_startup" - INACTIVE = "inactive" PENDING_SHUTDOWN = "pending_shutdown" ERROR = "error" + SCALED_TO_ZERO = "scaled_to_zero" class ServiceStatus(BaseTypedModel): diff --git a/src/zenml/services/service_type.py b/src/zenml/services/service_type.py index 8942c87bbda..a83539d336d 100644 --- a/src/zenml/services/service_type.py +++ b/src/zenml/services/service_type.py @@ -24,12 +24,14 @@ class ServiceType(BaseModel): flavor: service flavor name: name of the service type description: description of the service type + logo_url: logo of the service type """ type: str flavor: str name: str = "" description: str = "" + logo_url: str = "" class Config: """Pydantic configuration class.""" diff --git a/src/zenml/utils/cloud_utils.py b/src/zenml/utils/cloud_utils.py new file mode 100644 index 00000000000..cad6b9dcb98 --- /dev/null +++ b/src/zenml/utils/cloud_utils.py @@ -0,0 +1,40 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utilities for ZenML Cloud.""" + +from zenml.logger import get_logger +from zenml.models.v2.core.model_version import ModelVersionResponse +from zenml.utils.dashboard_utils import get_model_version_url + +logger = get_logger(__name__) + + +def is_cloud_model_version(model_version: ModelVersionResponse) -> bool: + """Check if a model version is from a ZenML Cloud server. + + Args: + model_version: The model version to check. + + Returns: + True if the model version is from a ZenML Cloud server, else False. 
+ """ + model_version_url = get_model_version_url(model_version.id) + if model_version_url: + logger.info( + f"Dashboard URL for Model Version with name {model_version.name} " + f": {model_version_url}" + ) + return True + else: + return False diff --git a/src/zenml/utils/dashboard_utils.py b/src/zenml/utils/dashboard_utils.py index 172dfc5805b..23b59bdc3cc 100644 --- a/src/zenml/utils/dashboard_utils.py +++ b/src/zenml/utils/dashboard_utils.py @@ -14,6 +14,7 @@ """Utility class to help with interacting with the dashboard.""" from typing import Optional +from uuid import UUID from zenml import constants from zenml.client import Client @@ -34,6 +35,17 @@ def get_base_url() -> Optional[str]: client = Client() if client.zen_store.type == StoreType.REST: + # if the server config has a base URL use that + server_model = client.zen_store.get_store_info() + if server_model.base_url: + url = server_model.base_url + # if the base url has cloud.zenml.io in it, then it is a cloud + # deployment and there isn't a workspace in the URL + if "cloud.zenml.io" in url: + return url + return ( + url + f"{constants.WORKSPACES}/{client.active_workspace.name}" + ) url = ( client.zen_store.url + f"{constants.WORKSPACES}/{client.active_workspace.name}" @@ -85,8 +97,13 @@ def get_run_url(run: PipelineRunResponse) -> Optional[str]: Returns: the URL to the pipeline run if the dashboard is available, else None. """ + client = Client() base_url = get_base_url() if base_url: + server_model = client.zen_store.get_store_info() + # if the server is a zenml cloud tenant, use a different URL + if server_model.metadata.get("organization_id"): + return f"{base_url}{constants.RUNS}/{run.id}" if run.pipeline: return f"{base_url}{constants.PIPELINES}/{run.pipeline.id}{constants.RUNS}/{run.id}/dag" else: @@ -94,6 +111,28 @@ def get_run_url(run: PipelineRunResponse) -> Optional[str]: return None +def get_model_version_url(model_version_id: UUID) -> Optional[str]: + """Function to get the dashboard URL of a given model version. + + Args: + model_version_id: the id of the model version. + + Returns: + the URL to the model version if the dashboard is available, else None. + """ + client = Client() + server_model = client.zen_store.get_store_info() + # if organization_id exists as key in server_config.metadata + # only then output a URL. + if server_model.metadata.get("organization_id"): + base_url = get_base_url() + if base_url: + # TODO MODEL_VERSIONS resolves to /model_versions but on the + # cloud, the URL is /model-versions. This should be fixed? + return f"{base_url}/model-versions/{str(model_version_id)}" + return None + + def show_dashboard(url: str) -> None: """Show the ZenML dashboard at the given URL. diff --git a/src/zenml/utils/dict_utils.py b/src/zenml/utils/dict_utils.py index 5c14b548968..fe5e9fb6dfe 100644 --- a/src/zenml/utils/dict_utils.py +++ b/src/zenml/utils/dict_utils.py @@ -13,8 +13,12 @@ # permissions and limitations under the License. """Util functions for dictionaries.""" +import base64 +import json from typing import Any, Dict +from pydantic.json import pydantic_encoder + def recursive_update( original: Dict[str, Any], update: Dict[str, Any] @@ -69,3 +73,21 @@ def _maybe_recurse(value: Any) -> Any: return value return {k: _maybe_recurse(v) for k, v in dict_.items() if v is not None} + + +def dict_to_bytes(dict_: Dict[str, Any]) -> bytes: + """Converts a dictionary to bytes. + + Args: + dict_: The dictionary to convert. + + Returns: + The dictionary as bytes. 
+ """ + return base64.b64encode( + json.dumps( + dict_, + sort_keys=False, + default=pydantic_encoder, + ).encode("utf-8") + ) diff --git a/src/zenml/utils/docker_utils.py b/src/zenml/utils/docker_utils.py index 225a1ab5e09..4b8097542dc 100644 --- a/src/zenml/utils/docker_utils.py +++ b/src/zenml/utils/docker_utils.py @@ -29,6 +29,7 @@ ) from docker.client import DockerClient +from docker.errors import DockerException from docker.utils import build as docker_build_utils from zenml.io import fileio @@ -227,7 +228,8 @@ def build_image( logger.info("Building the image might take a while...") - docker_client = DockerClient.from_env() + docker_client = _try_get_docker_client_from_env() + # We use the client api directly here, so we can stream the logs output_stream = docker_client.images.client.api.build( fileobj=build_context, @@ -258,7 +260,7 @@ def push_image( RuntimeError: If fetching the repository digest of the image failed. """ logger.info("Pushing Docker image `%s`.", image_name) - docker_client = docker_client or DockerClient.from_env() + docker_client = _try_get_docker_client_from_env() output_stream = docker_client.images.push(image_name, stream=True) aux_info = _process_stream(output_stream) logger.info("Finished pushing Docker image.") @@ -283,7 +285,7 @@ def tag_image(image_name: str, target: str) -> None: image_name: The name of the image to tag. target: The full target name including a tag. """ - docker_client = DockerClient.from_env() + docker_client = _try_get_docker_client_from_env() image = docker_client.images.get(image_name) image.tag(target) @@ -298,7 +300,8 @@ def get_image_digest(image_name: str) -> Optional[str]: Returns the repo digest for the given image if there exists exactly one. If there are zero or multiple repo digests, returns `None`. """ - docker_client = DockerClient.from_env() + docker_client = _try_get_docker_client_from_env() + image = docker_client.images.get(image_name) repo_digests = image.attrs["RepoDigests"] if len(repo_digests) == 1: @@ -321,7 +324,7 @@ def is_local_image(image_name: str) -> bool: Returns: `True` if the image was pulled from a registry, `False` otherwise. """ - docker_client = DockerClient.from_env() + docker_client = _try_get_docker_client_from_env() images = docker_client.images.list(name=image_name) if images: # An image with this name is available locally -> now check whether it @@ -333,6 +336,23 @@ def is_local_image(image_name: str) -> bool: return False +def _try_get_docker_client_from_env() -> DockerClient: + """Tries to create a Docker client from the environment. + + Raises: + RuntimeError: If creating a Docker client from the environment failed. + + Returns: + A Docker client created from the environment. + """ + try: + return DockerClient.from_env() + except DockerException as e: + raise RuntimeError( + "Could not create a Docker client from the environment. Is your Docker daemon running?" + ) from e + + def _process_stream(stream: Iterable[bytes]) -> List[Dict[str, Any]]: """Processes the output stream of a docker command call. 
diff --git a/src/zenml/utils/pipeline_docker_image_builder.py b/src/zenml/utils/pipeline_docker_image_builder.py index 32de37d04a5..88a090b8599 100644 --- a/src/zenml/utils/pipeline_docker_image_builder.py +++ b/src/zenml/utils/pipeline_docker_image_builder.py @@ -626,25 +626,24 @@ def _generate_zenml_pipeline_dockerfile( f"--no-install-recommends {apt_packages}" ) + if ( + docker_settings.python_package_installer + == PythonPackageInstaller.PIP + ): + install_command = "pip install --default-timeout=60" + elif ( + docker_settings.python_package_installer + == PythonPackageInstaller.UV + ): + lines.append("RUN pip install uv") + install_command = "uv pip install --system" + else: + raise ValueError("Unsupported python package installer.") + for file, _, options in requirements_files: lines.append(f"COPY {file} .") - option_string = " ".join(options) - if ( - docker_settings.python_package_installer - == PythonPackageInstaller.PIP - ): - install_command = "pip install --default-timeout=60" - elif ( - docker_settings.python_package_installer - == PythonPackageInstaller.UV - ): - lines.append("RUN pip install uv") - install_command = "uv pip install --system" - else: - raise ValueError("Unsupported python package installer.") - lines.append( f"RUN {install_command} --no-cache-dir " f"{option_string} -r {file}" diff --git a/src/zenml/utils/source_utils.py b/src/zenml/utils/source_utils.py index ea1fcfe6d0d..ba9e9e3e91f 100644 --- a/src/zenml/utils/source_utils.py +++ b/src/zenml/utils/source_utils.py @@ -231,7 +231,9 @@ def get_source_root() -> str: raise RuntimeError( "Unable to determine source root because the main module does not " "have an associated file. This could be because you're running in " - "an interactive Python environment." + "an interactive Python environment. If you are trying to run from " + "within a Jupyter notebook, please run `zenml init` from the root " + "where your notebook is located and restart your notebook server. " ) path = Path(main_module.__file__).resolve().parent diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py new file mode 100644 index 00000000000..eabac1396de --- /dev/null +++ b/src/zenml/zen_server/cloud_utils.py @@ -0,0 +1,201 @@ +"""Utils concerning anything concerning the cloud control plane backend.""" + +import os +from typing import Any, Dict, Optional + +import requests +from pydantic import BaseModel, validator +from requests.adapters import HTTPAdapter, Retry + +from zenml.exceptions import SubscriptionUpgradeRequiredError + +ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_" + + +class ZenMLCloudConfiguration(BaseModel): + """ZenML Cloud RBAC configuration.""" + + api_url: str + + oauth2_client_id: str + oauth2_client_secret: str + oauth2_audience: str + auth0_domain: str + + @validator("api_url") + def _strip_trailing_slashes_url(cls, url: str) -> str: + """Strip any trailing slashes on the API URL. + + Args: + url: The API URL. + + Returns: + The API URL with potential trailing slashes removed. + """ + return url.rstrip("/") + + @classmethod + def from_environment(cls) -> "ZenMLCloudConfiguration": + """Get the RBAC configuration from environment variables. + + Returns: + The RBAC configuration. 
+ """ + env_config: Dict[str, Any] = {} + for k, v in os.environ.items(): + if v == "": + continue + if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX): + env_config[k[len(ZENML_CLOUD_RBAC_ENV_PREFIX) :].lower()] = v + + return ZenMLCloudConfiguration(**env_config) + + class Config: + """Pydantic configuration class.""" + + # Allow extra attributes from configs of previous ZenML versions to + # permit downgrading + extra = "allow" + + +class ZenMLCloudSession: + """Class to use for communication between server and control plane.""" + + def __init__(self) -> None: + """Initialize the RBAC component.""" + self._config = ZenMLCloudConfiguration.from_environment() + self._session: Optional[requests.Session] = None + + def _get( + self, endpoint: str, params: Optional[Dict[str, Any]] + ) -> requests.Response: + """Send a GET request using the active session. + + Args: + endpoint: The endpoint to send the request to. This will be appended + to the base URL. + params: Parameters to include in the request. + + Raises: + RuntimeError: If the request failed. + SubscriptionUpgradeRequiredError: In case the current subscription + tier is insufficient for the attempted operation. + + Returns: + The response. + """ + url = self._config.api_url + endpoint + + response = self.session.get(url=url, params=params, timeout=7) + if response.status_code == 401: + # Refresh the auth token and try again + self._clear_session() + response = self.session.get(url=url, params=params, timeout=7) + + try: + response.raise_for_status() + except requests.HTTPError: + if response.status_code == 402: + raise SubscriptionUpgradeRequiredError(response.json()) + else: + raise RuntimeError( + f"Failed with the following error {response.json()}" + ) + + return response + + def _post( + self, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + ) -> requests.Response: + """Send a POST request using the active session. + + Args: + endpoint: The endpoint to send the request to. This will be appended + to the base URL. + params: Parameters to include in the request. + data: Data to include in the request. + + Raises: + RuntimeError: If the request failed. + + Returns: + The response. + """ + url = self._config.api_url + endpoint + + response = self.session.post( + url=url, params=params, json=data, timeout=7 + ) + if response.status_code == 401: + # Refresh the auth token and try again + self._clear_session() + response = self.session.post( + url=url, params=params, json=data, timeout=7 + ) + + try: + response.raise_for_status() + except requests.HTTPError as e: + raise RuntimeError( + f"Failed while trying to contact the central zenml cloud " + f"service: {e}" + ) + + return response + + @property + def session(self) -> requests.Session: + """Authenticate to the ZenML Cloud API. + + Returns: + A requests session with the authentication token. + """ + if self._session is None: + self._session = requests.Session() + token = self._fetch_auth_token() + self._session.headers.update({"Authorization": "Bearer " + token}) + + retries = Retry(total=5, backoff_factor=0.1) + self._session.mount("https://", HTTPAdapter(max_retries=retries)) + + return self._session + + def _clear_session(self) -> None: + """Clear the authentication session.""" + self._session = None + + def _fetch_auth_token(self) -> str: + """Fetch an auth token for the Cloud API from auth0. + + Raises: + RuntimeError: If the auth token can't be fetched. + + Returns: + Auth token. 
+ """ + # Get an auth token from auth0 + auth0_url = f"https://{self._config.auth0_domain}/oauth/token" + headers = {"content-type": "application/x-www-form-urlencoded"} + payload = { + "client_id": self._config.oauth2_client_id, + "client_secret": self._config.oauth2_client_secret, + "audience": self._config.oauth2_audience, + "grant_type": "client_credentials", + } + try: + response = requests.post( + auth0_url, headers=headers, data=payload, timeout=7 + ) + response.raise_for_status() + except Exception as e: + raise RuntimeError(f"Error fetching auth token from auth0: {e}") + + access_token = response.json().get("access_token", "") + + if not access_token or not isinstance(access_token, str): + raise RuntimeError("Could not fetch auth token from auth0.") + + return str(access_token) diff --git a/src/zenml/zen_server/deploy/docker/docker_provider.py b/src/zenml/zen_server/deploy/docker/docker_provider.py index aae7060bc96..2353ecf30ba 100644 --- a/src/zenml/zen_server/deploy/docker/docker_provider.py +++ b/src/zenml/zen_server/deploy/docker/docker_provider.py @@ -15,6 +15,7 @@ import shutil from typing import ClassVar, List, Optional, Tuple, Type, cast +from uuid import uuid4 from zenml.enums import ServerProviderType from zenml.logger import get_logger @@ -131,7 +132,9 @@ def _create_service( config=monitor_cfg, ), ) - service = DockerZenServer(config=service_config, endpoint=endpoint) + service = DockerZenServer( + uuid=uuid4(), config=service_config, endpoint=endpoint + ) service.start(timeout=timeout) return service diff --git a/src/zenml/zen_server/deploy/docker/docker_zen_server.py b/src/zenml/zen_server/deploy/docker/docker_zen_server.py index 188aed6f15f..58c02165833 100644 --- a/src/zenml/zen_server/deploy/docker/docker_zen_server.py +++ b/src/zenml/zen_server/deploy/docker/docker_zen_server.py @@ -132,14 +132,11 @@ def get_service(cls) -> Optional["DockerZenServer"]: The docker ZenML server service or None, if the docker server deployment is not found. """ - from zenml.services import ServiceRegistry - config_filename = os.path.join(cls.config_path(), "service.json") try: with open(config_filename, "r") as f: return cast( - DockerZenServer, - ServiceRegistry().load_service_from_json(f.read()), + "DockerZenServer", DockerZenServer.from_json(f.read()) ) except FileNotFoundError: return None diff --git a/src/zenml/zen_server/deploy/helm/Chart.yaml b/src/zenml/zen_server/deploy/helm/Chart.yaml index 673505615ab..e6bc01d1a2b 100644 --- a/src/zenml/zen_server/deploy/helm/Chart.yaml +++ b/src/zenml/zen_server/deploy/helm/Chart.yaml @@ -1,6 +1,6 @@ apiVersion: v2 name: zenml -version: "0.55.5" +version: "0.56.2" description: Open source MLOps framework for portable production ready ML pipelines keywords: - mlops diff --git a/src/zenml/zen_server/deploy/helm/README.md b/src/zenml/zen_server/deploy/helm/README.md index 2b228e3f33e..0f678b869bf 100644 --- a/src/zenml/zen_server/deploy/helm/README.md +++ b/src/zenml/zen_server/deploy/helm/README.md @@ -20,8 +20,8 @@ ZenML is an open-source MLOps framework designed to help you create robust, main To install the ZenML chart directly from Amazon ECR, use the following command: ```bash -# example command for version 0.55.5 -helm install my-zenml oci://public.ecr.aws/zenml/zenml --version 0.55.5 +# example command for version 0.56.2 +helm install my-zenml oci://public.ecr.aws/zenml/zenml --version 0.56.2 ``` Note: Ensure you have OCI support enabled in your Helm client and that you are authenticated with Amazon ECR. 
diff --git a/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml b/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml index 9af1688e6ea..e71d4fd1d49 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml @@ -110,16 +110,16 @@ spec: envFrom: - secretRef: name: {{ include "zenml.fullname" . }}-db-migration - {{- with .Values.resources }} - resources: - {{- toYaml . | nindent 12 }} + {{- with .Values.resources }} + resources: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- with .Values.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} {{- end }} - {{- with .Values.tolerations }} - tolerations: - {{- toYaml . | nindent 8 }} - {{- end }} - {{- with .Values.nodeSelector }} - nodeSelector: - {{- toYaml . | nindent 8 }} - {{- end }} {{- end }} \ No newline at end of file diff --git a/src/zenml/zen_server/deploy/local/local_provider.py b/src/zenml/zen_server/deploy/local/local_provider.py index 3d8a9b8fe45..b380017ae7e 100644 --- a/src/zenml/zen_server/deploy/local/local_provider.py +++ b/src/zenml/zen_server/deploy/local/local_provider.py @@ -15,6 +15,7 @@ import shutil from typing import ClassVar, List, Optional, Tuple, Type, cast +from uuid import uuid4 from zenml import __version__ from zenml.enums import ServerProviderType @@ -61,6 +62,7 @@ def check_local_server_dependencies() -> None: try: # Make sure the ZenML Server dependencies are installed import fastapi # noqa + import fastapi_utils # noqa import jwt # noqa import multipart # noqa import uvicorn # noqa @@ -92,7 +94,6 @@ def _get_service_configuration( The service, service endpoint and endpoint monitor configuration. """ assert isinstance(server_config, LocalServerDeploymentConfig) - return ( LocalZenServerConfig( root_runtime_path=LocalZenServer.config_path(), @@ -156,7 +157,9 @@ def _create_service( config=monitor_cfg, ), ) - service = LocalZenServer(config=service_config, endpoint=endpoint) + service = LocalZenServer( + uuid=uuid4(), config=service_config, endpoint=endpoint + ) service.start(timeout=timeout) return service diff --git a/src/zenml/zen_server/deploy/local/local_zen_server.py b/src/zenml/zen_server/deploy/local/local_zen_server.py index 6425b2829bc..8f5041d9de1 100644 --- a/src/zenml/zen_server/deploy/local/local_zen_server.py +++ b/src/zenml/zen_server/deploy/local/local_zen_server.py @@ -127,14 +127,11 @@ def get_service(cls) -> Optional["LocalZenServer"]: The local ZenML server service or None, if the local server deployment is not found. 
""" - from zenml.services import ServiceRegistry - config_filename = os.path.join(cls.config_path(), "service.json") try: with open(config_filename, "r") as f: return cast( - LocalZenServer, - ServiceRegistry().load_service_from_json(f.read()), + "LocalZenServer", LocalZenServer.from_json(f.read()) ) except FileNotFoundError: return None diff --git a/src/zenml/zen_server/deploy/terraform/providers/terraform_provider.py b/src/zenml/zen_server/deploy/terraform/providers/terraform_provider.py index 0215e7d929c..7f25e4fb87d 100644 --- a/src/zenml/zen_server/deploy/terraform/providers/terraform_provider.py +++ b/src/zenml/zen_server/deploy/terraform/providers/terraform_provider.py @@ -15,6 +15,7 @@ import os from typing import ClassVar, List, Optional, Tuple, Type, cast +from uuid import uuid4 from zenml.config.global_config import GlobalConfiguration from zenml.logger import get_logger @@ -153,7 +154,7 @@ def _create_service( monitor_cfg, ) = self._get_service_configuration(config) - service = TerraformZenServer(config=service_config) + service = TerraformZenServer(uuid=uuid4(), config=service_config) service.start(timeout=timeout) return service diff --git a/src/zenml/zen_server/deploy/terraform/terraform_zen_server.py b/src/zenml/zen_server/deploy/terraform/terraform_zen_server.py index 1b1441ddaf0..61b838afdd9 100644 --- a/src/zenml/zen_server/deploy/terraform/terraform_zen_server.py +++ b/src/zenml/zen_server/deploy/terraform/terraform_zen_server.py @@ -184,13 +184,10 @@ def get_service(cls) -> Optional["TerraformZenServer"]: The terraform ZenML server service or None, if the terraform server deployment is not found. """ - from zenml.services import ServiceRegistry - try: with open(TERRAFORM_ZENML_SERVER_CONFIG_FILENAME, "r") as f: return cast( - TerraformZenServer, - ServiceRegistry().load_service_from_json(f.read()), + TerraformZenServer, TerraformZenServer.from_json(f.read()) ) except FileNotFoundError: return None diff --git a/src/zenml/zen_server/exceptions.py b/src/zenml/zen_server/exceptions.py index 31d3464d82d..0a3d379fc93 100644 --- a/src/zenml/zen_server/exceptions.py +++ b/src/zenml/zen_server/exceptions.py @@ -27,6 +27,7 @@ SecretExistsError, StackComponentExistsError, StackExistsError, + SubscriptionUpgradeRequiredError, ValidationError, ZenKeyError, ) @@ -77,6 +78,8 @@ class ErrorModel(BaseModel): (IllegalOperationError, 403), # 401 Unauthorized (AuthorizationException, 401), + # 402 Payment required + (SubscriptionUpgradeRequiredError, 402), # 404 Not Found (DoesNotExistException, 404), (ZenKeyError, 404), diff --git a/src/zenml/zen_server/feature_gate/__init__.py b/src/zenml/zen_server/feature_gate/__init__.py new file mode 100644 index 00000000000..b6bdfa91873 --- /dev/null +++ b/src/zenml/zen_server/feature_gate/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
\ No newline at end of file diff --git a/src/zenml/zen_server/feature_gate/endpoint_utils.py b/src/zenml/zen_server/feature_gate/endpoint_utils.py new file mode 100644 index 00000000000..3b509e9a494 --- /dev/null +++ b/src/zenml/zen_server/feature_gate/endpoint_utils.py @@ -0,0 +1,59 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""All endpoint utils for the feature gate implementations.""" + +from uuid import UUID + +from zenml.zen_server.rbac.models import ResourceType +from zenml.zen_server.utils import feature_gate, server_config + + +def check_entitlement(resource_type: ResourceType) -> None: + """Queries the feature gate to see if the operation falls within the tenants entitlements. + + Raises an exception if the user is not entitled to create an instance of the + resource. Otherwise, simply returns. + + Args: + resource_type: The type of resource to check for. + """ + if not server_config().feature_gate_enabled: + return + return feature_gate().check_entitlement(resource=resource_type) + + +def report_usage(resource_type: ResourceType, resource_id: UUID) -> None: + """Reports the creation/usage of a feature/resource. + + Args: + resource_type: The type of resource to report a usage for + resource_id: ID of the resource that was created. + """ + if not server_config().feature_gate_enabled: + return + feature_gate().report_event( + resource=resource_type, resource_id=resource_id + ) + + +def report_decrement(resource_type: ResourceType, resource_id: UUID) -> None: + """Reports the deletion/deactivation of a feature/resource. + + Args: + resource_type: The type of resource to report a decrement in count for. + resource_id: ID of the resource that was deleted. + """ + feature_gate().report_event( + resource=resource_type, resource_id=resource_id, is_decrement=True + ) diff --git a/src/zenml/zen_server/feature_gate/feature_gate_interface.py b/src/zenml/zen_server/feature_gate/feature_gate_interface.py new file mode 100644 index 00000000000..df4a5d3fc70 --- /dev/null +++ b/src/zenml/zen_server/feature_gate/feature_gate_interface.py @@ -0,0 +1,49 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Definition of the feature gate interface.""" + +from abc import ABC, abstractmethod +from uuid import UUID + +from zenml.zen_server.rbac.models import ResourceType + + +class FeatureGateInterface(ABC): + """RBAC interface definition.""" + + @abstractmethod + def check_entitlement(self, resource: ResourceType) -> None: + """Checks if a user is entitled to create a resource. + + Args: + resource: The resource the user wants to create + + Raises: + UpgradeRequiredError in case a subscription limit is reached + """ + + @abstractmethod + def report_event( + self, + resource: ResourceType, + resource_id: UUID, + is_decrement: bool = False, + ) -> None: + """Reports the usage of a feature to the aggregator backend. + + Args: + resource: The resource the user created + resource_id: ID of the resource that was created/deleted. + is_decrement: In case this event reports an actual decrement of usage + """ diff --git a/src/zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py b/src/zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py new file mode 100644 index 00000000000..f928539ad4b --- /dev/null +++ b/src/zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py @@ -0,0 +1,119 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""ZenML Cloud implementation of the feature gate.""" + +from typing import Any, Dict +from uuid import UUID + +from pydantic import BaseModel, Field + +from zenml.config.server_config import ServerConfiguration +from zenml.exceptions import SubscriptionUpgradeRequiredError +from zenml.logger import get_logger +from zenml.zen_server.cloud_utils import ZenMLCloudSession +from zenml.zen_server.feature_gate.feature_gate_interface import ( + FeatureGateInterface, +) +from zenml.zen_server.rbac.models import ResourceType + +logger = get_logger(__name__) + +server_config = ServerConfiguration.get_server_config() + +ORGANIZATION_ID = server_config.metadata.get("organization_id", "unknown") + +USAGE_EVENT_ENDPOINT = "/usage-event" +ENTITLEMENT_ENDPOINT = f"/organizations/{ORGANIZATION_ID}/entitlement" + + +class RawUsageEvent(BaseModel): + """Model for reporting raw usage of a feature. + + In case of consumables the UsageReport allows the Pricing Backend to + increment the usage per time-frame by 1. + """ + + organization_id: str = Field( + description="The organization that this usage can be attributed to.", + ) + feature: ResourceType = Field( + description="The feature whose usage is being reported.", + ) + total: int = Field( + description="The total amount of entities of this type." + ) + metadata: Dict[str, Any] = Field( + default={}, + description="Allows attaching additional metadata to events.", + ) + + +class ZenMLCloudFeatureGateInterface(FeatureGateInterface, ZenMLCloudSession): + """Feature Gate interface definition.""" + + def check_entitlement(self, resource: ResourceType) -> None: + """Checks if a user is entitled to create a resource. 
+ + Args: + resource: The resource the user wants to create + + Raises: + SubscriptionUpgradeRequiredError: in case a subscription limit is reached + """ + try: + response = self._get( + endpoint=ENTITLEMENT_ENDPOINT + "/" + resource, params=None + ) + except SubscriptionUpgradeRequiredError: + raise SubscriptionUpgradeRequiredError( + f"Your subscription reached its `{resource}` limit. Please " + f"upgrade your subscription or reach out to us." + ) + + if response.status_code != 200: + logger.warning( + "Unexpected response status code from entitlement " + f"endpoint: {response.status_code}. Message: " + f"{response.json()}" + ) + + def report_event( + self, + resource: ResourceType, + resource_id: UUID, + is_decrement: bool = False, + ) -> None: + """Reports the usage of a feature to the aggregator backend. + + Args: + resource: The resource the user created + resource_id: ID of the resource that was created/deleted. + is_decrement: In case this event reports an actual decrement of usage + """ + data = RawUsageEvent( + organization_id=ORGANIZATION_ID, + feature=resource, + total=1 if not is_decrement else -1, + metadata={ + "tenant_id": str(server_config.external_server_id), + "resource_id": str(resource_id), + }, + ).dict() + response = self._post(endpoint=USAGE_EVENT_ENDPOINT, data=data) + if response.status_code != 200: + logger.error( + "Usage report not accepted by upstream backend. " + f"Status Code: {response.status_code}, Message: " + f"{response.json()}." + ) diff --git a/src/zenml/zen_server/rate_limit.py b/src/zenml/zen_server/rate_limit.py new file mode 100644 index 00000000000..520025778d6 --- /dev/null +++ b/src/zenml/zen_server/rate_limit.py @@ -0,0 +1,184 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Rate limiting for the ZenML Server.""" + +import inspect +import time +from collections import defaultdict +from functools import wraps +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + TypeVar, + cast, +) + +from starlette.requests import Request + +from zenml.logger import get_logger +from zenml.zen_server.utils import server_config + +logger = get_logger(__name__) +F = TypeVar("F", bound=Callable[..., Any]) + + +class RequestLimiter: + """Simple in-memory rate limiter.""" + + def __init__( + self, + day_limit: Optional[int] = None, + minute_limit: Optional[int] = None, + ): + """Initializes the limiter. + + Args: + day_limit: The number of requests allowed per day. + minute_limit: The number of requests allowed per minute. + + Raises: + ValueError: If both day_limit and minute_limit are None. 
+ """ + self.limiting_enabled = server_config().rate_limit_enabled + if not self.limiting_enabled: + return + if day_limit is None and minute_limit is None: + raise ValueError("Pass either day or minuter limits, or both.") + self.day_limit = day_limit + self.minute_limit = minute_limit + self.limiter: Dict[str, List[float]] = defaultdict(list) + + def hit_limiter(self, request: Request) -> None: + """Increase the number of hits in the limiter. + + Args: + request: Request object. + + Raises: + HTTPException: If the request limit is exceeded. + """ + if not self.limiting_enabled: + return + from fastapi import HTTPException + + requester = self._get_ipaddr(request) + now = time.time() + minute_ago = now - 60 + day_ago = now - 60 * 60 * 24 + self.limiter[requester].append(now) + + from bisect import bisect_left + + # remove failures older than a day + older_index = bisect_left(self.limiter[requester], day_ago) + self.limiter[requester] = self.limiter[requester][older_index:] + + if self.day_limit and len(self.limiter[requester]) > self.day_limit: + raise HTTPException( + status_code=429, detail="Daily request limit exceeded." + ) + minute_requests = len( + [ + limiter_hit + for limiter_hit in self.limiter[requester][::-1] + if limiter_hit >= minute_ago + ] + ) + if self.minute_limit and minute_requests > self.minute_limit: + raise HTTPException( + status_code=429, detail="Minute request limit exceeded." + ) + + def reset_limiter(self, request: Request) -> None: + """Resets the limiter on successful request. + + Args: + request: Request object. + """ + if self.limiting_enabled: + requester = self._get_ipaddr(request) + if requester in self.limiter: + del self.limiter[requester] + + def _get_ipaddr(self, request: Request) -> str: + """Returns the IP address for the current request. + + Based on the X-Forwarded-For headers or client information. + + Args: + request: The request object. + + Returns: + The ip address for the current request (or 127.0.0.1 if none found). + """ + if "X_FORWARDED_FOR" in request.headers: + return request.headers["X_FORWARDED_FOR"] + else: + if not request.client or not request.client.host: + return "127.0.0.1" + + return request.client.host + + +def rate_limit_requests( + day_limit: Optional[int] = None, + minute_limit: Optional[int] = None, +) -> Callable[..., Any]: + """Decorator to handle exceptions in the API. + + Args: + day_limit: Number of requests allowed per day. + minute_limit: Number of requests allowed per minute. + + Returns: + Decorated function. + """ + limiter = RequestLimiter(day_limit=day_limit, minute_limit=minute_limit) + + def decorator(func: F) -> F: + request_arg, request_kwarg = None, None + parameters = inspect.signature(func).parameters + for arg_num, arg_name in enumerate(parameters): + if parameters[arg_name].annotation == Request: + request_arg = arg_num + request_kwarg = arg_name + break + if request_arg is None or request_kwarg is None: + raise ValueError( + "Rate limiting APIs must have argument of `Request` type." 
+ ) + + @wraps(func) + def decorated( + *args: Any, + **kwargs: Any, + ) -> Any: + if request_kwarg in kwargs: + request = kwargs[request_kwarg] + else: + request = args[request_arg] + limiter.hit_limiter(request) + + ret = func(*args, **kwargs) + + # if request was successful - reset limiter + limiter.reset_limiter(request) + return ret + + return cast(F, decorated) + + return decorator diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index 6cc78ddcc97..1f8abe8d6ea 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -1,3 +1,16 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. """High-level helper functions to write endpoints with RBAC.""" from typing import Any, Callable, TypeVar, Union @@ -5,6 +18,10 @@ from pydantic import BaseModel +from zenml.constants import ( + REPORTABLE_RESOURCES, + REQUIRES_CUSTOM_RESOURCE_REPORTING, +) from zenml.exceptions import IllegalOperationError from zenml.models import ( BaseFilter, @@ -14,6 +31,10 @@ UserScopedRequest, ) from zenml.zen_server.auth import get_auth_context +from zenml.zen_server.feature_gate.endpoint_utils import ( + check_entitlement, + report_usage, +) from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( dehydrate_page, @@ -58,12 +79,21 @@ def verify_permissions_and_create_entity( f"Not allowed to create resource '{resource_type}' for a " "different user." ) + verify_permission(resource_type=resource_type, action=Action.CREATE) - verify_permission( - resource_type=resource_type, - action=Action.CREATE, + needs_usage_increment = ( + resource_type in REPORTABLE_RESOURCES + and resource_type not in REQUIRES_CUSTOM_RESOURCE_REPORTING ) - return create_method(request_model) + if needs_usage_increment: + check_entitlement(resource_type) + + created = create_method(request_model) + + if needs_usage_increment: + report_usage(resource_type, resource_id=created.id) + + return created def verify_permissions_and_get_entity( @@ -141,18 +171,23 @@ def verify_permissions_and_delete_entity( id: UUIDOrStr, get_method: Callable[[UUIDOrStr], AnyResponse], delete_method: Callable[[UUIDOrStr], None], -) -> None: +) -> AnyResponse: """Verify permissions and delete an entity. Args: id: The ID of the entity to delete. get_method: The method to fetch the entity. delete_method: The method to delete the entity. + + Returns: + The deleted entity. 
""" model = get_method(id) verify_permission_for_model(model, action=Action.DELETE) delete_method(model.id) + return model + def verify_permissions_and_prune_entities( resource_type: ResourceType, diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index 4a7459db1a5..eb136685e77 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -58,6 +58,7 @@ class ResourceType(StrEnum): PIPELINE_DEPLOYMENT = "pipeline_deployment" PIPELINE_BUILD = "pipeline_build" USER = "user" + SERVICE = "service" RUN_METADATA = "run_metadata" SECRET = "secret" SERVICE_ACCOUNT = "service_account" diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index da64e417899..692b7f8d89c 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -400,6 +400,7 @@ def get_resource_type_for_model( SecretResponse, ServiceAccountResponse, ServiceConnectorResponse, + ServiceResponse, StackResponse, TagResponse, UserResponse, @@ -429,6 +430,7 @@ def get_resource_type_for_model( PipelineRunResponse: ResourceType.PIPELINE_RUN, TagResponse: ResourceType.TAG, ServiceAccountResponse: ResourceType.SERVICE_ACCOUNT, + ServiceResponse: ResourceType.SERVICE, } return mapping.get(type(model)) @@ -536,6 +538,7 @@ def get_schema_for_resource_type( RunMetadataSchema, SecretSchema, ServiceConnectorSchema, + ServiceSchema, StackComponentSchema, StackSchema, TagSchema, @@ -555,6 +558,7 @@ def get_schema_for_resource_type( ResourceType.ARTIFACT: ArtifactSchema, ResourceType.ARTIFACT_VERSION: ArtifactVersionSchema, ResourceType.SECRET: SecretSchema, + ResourceType.SERVICE: ServiceSchema, ResourceType.TAG: TagSchema, ResourceType.SERVICE_ACCOUNT: UserSchema, ResourceType.WORKSPACE: WorkspaceSchema, diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index deeed246c51..fd534b313a9 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -13,13 +13,9 @@ # permissions and limitations under the License. """Cloud RBAC implementation.""" -import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple - -import requests -from pydantic import BaseModel, validator -from requests.adapters import HTTPAdapter, Retry +from typing import TYPE_CHECKING, Dict, List, Set, Tuple +from zenml.zen_server.cloud_utils import ZenMLCloudSession from zenml.zen_server.rbac.models import Action, Resource from zenml.zen_server.rbac.rbac_interface import RBACInterface from zenml.zen_server.utils import server_config @@ -28,7 +24,6 @@ from zenml.models import UserResponse -ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_" PERMISSIONS_ENDPOINT = "/rbac/check_permissions" ALLOWED_RESOURCE_IDS_ENDPOINT = "/rbac/allowed_resource_ids" RESOURCE_MEMBERSHIP_ENDPOINT = "/rbac/resource_members" @@ -79,60 +74,9 @@ def _convert_from_cloud_resource(cloud_resource: str) -> Resource: return Resource(type=resource_type_and_id) -class ZenMLCloudRBACConfiguration(BaseModel): - """ZenML Cloud RBAC configuration.""" - - api_url: str - - oauth2_client_id: str - oauth2_client_secret: str - oauth2_audience: str - auth0_domain: str - - @validator("api_url") - def _strip_trailing_slashes_url(cls, url: str) -> str: - """Strip any trailing slashes on the API URL. - - Args: - url: The API URL. - - Returns: - The API URL with potential trailing slashes removed. 
- """ - return url.rstrip("/") - - @classmethod - def from_environment(cls) -> "ZenMLCloudRBACConfiguration": - """Get the RBAC configuration from environment variables. - - Returns: - The RBAC configuration. - """ - env_config: Dict[str, Any] = {} - for k, v in os.environ.items(): - if v == "": - continue - if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX): - env_config[k[len(ZENML_CLOUD_RBAC_ENV_PREFIX) :].lower()] = v - - return ZenMLCloudRBACConfiguration(**env_config) - - class Config: - """Pydantic configuration class.""" - - # Allow extra attributes from configs of previous ZenML versions to - # permit downgrading - extra = "allow" - - -class ZenMLCloudRBAC(RBACInterface): +class ZenMLCloudRBAC(RBACInterface, ZenMLCloudSession): """RBAC implementation that uses the ZenML Cloud API as a backend.""" - def __init__(self) -> None: - """Initialize the RBAC component.""" - self._config = ZenMLCloudRBACConfiguration.from_environment() - self._session: Optional[requests.Session] = None - def check_permissions( self, user: "UserResponse", resources: Set[Resource], action: Action ) -> Dict[Resource, bool]: @@ -234,129 +178,3 @@ def update_resource_membership( "actions": [str(action) for action in actions], } self._post(endpoint=RESOURCE_MEMBERSHIP_ENDPOINT, data=data) - - def _get(self, endpoint: str, params: Dict[str, Any]) -> requests.Response: - """Send a GET request using the active session. - - Args: - endpoint: The endpoint to send the request to. This will be appended - to the base URL. - params: Parameters to include in the request. - - Raises: - RuntimeError: If the request failed. - - Returns: - The response. - """ - url = self._config.api_url + endpoint - - response = self.session.get(url=url, params=params, timeout=7) - if response.status_code == 401: - # Refresh the auth token and try again - self._clear_session() - response = self.session.get(url=url, params=params, timeout=7) - - try: - response.raise_for_status() - except requests.HTTPError as e: - raise RuntimeError( - f"Failed while trying to contact RBAC service: {e}" - ) - - return response - - def _post( - self, - endpoint: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, - ) -> requests.Response: - """Send a POST request using the active session. - - Args: - endpoint: The endpoint to send the request to. This will be appended - to the base URL. - params: Parameters to include in the request. - data: Data to include in the request. - - Raises: - RuntimeError: If the request failed. - - Returns: - The response. - """ - url = self._config.api_url + endpoint - - response = self.session.post( - url=url, params=params, json=data, timeout=7 - ) - if response.status_code == 401: - # Refresh the auth token and try again - self._clear_session() - response = self.session.post( - url=url, params=params, json=data, timeout=7 - ) - - try: - response.raise_for_status() - except requests.HTTPError as e: - raise RuntimeError( - f"Failed while trying to contact RBAC service: {e}" - ) - - return response - - @property - def session(self) -> requests.Session: - """Authenticate to the ZenML Cloud API. - - Returns: - A requests session with the authentication token. 
- """ - if self._session is None: - self._session = requests.Session() - token = self._fetch_auth_token() - self._session.headers.update({"Authorization": "Bearer " + token}) - - retries = Retry(total=5, backoff_factor=0.1) - self._session.mount("https://", HTTPAdapter(max_retries=retries)) - - return self._session - - def _clear_session(self) -> None: - """Clear the authentication session.""" - self._session = None - - def _fetch_auth_token(self) -> str: - """Fetch an auth token for the Cloud API from auth0. - - Raises: - RuntimeError: If the auth token can't be fetched. - - Returns: - Auth token. - """ - # Get an auth token from auth0 - auth0_url = f"https://{self._config.auth0_domain}/oauth/token" - headers = {"content-type": "application/x-www-form-urlencoded"} - payload = { - "client_id": self._config.oauth2_client_id, - "client_secret": self._config.oauth2_client_secret, - "audience": self._config.oauth2_audience, - "grant_type": "client_credentials", - } - try: - response = requests.post( - auth0_url, headers=headers, data=payload, timeout=7 - ) - response.raise_for_status() - except Exception as e: - raise RuntimeError(f"Error fetching auth token from auth0: {e}") - - access_token = response.json().get("access_token", "") - - if not access_token or not isinstance(access_token, str): - raise RuntimeError("Could not fetch auth token from auth0.") - - return str(access_token) diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index f6b2d289ecb..41137a1e18a 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -65,6 +65,7 @@ ) from zenml.zen_server.exceptions import error_response from zenml.zen_server.jwt import JWTToken +from zenml.zen_server.rate_limit import rate_limit_requests from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import verify_permission from zenml.zen_server.utils import ( @@ -255,6 +256,10 @@ def generate_access_token( LOGIN, response_model=Union[OAuthTokenResponse, OAuthRedirectResponse], ) +@rate_limit_requests( + day_limit=server_config().login_rate_limit_day, + minute_limit=server_config().login_rate_limit_minute, +) @handle_exceptions def token( request: Request, diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index c43026d0bdf..124660e5cfc 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -22,6 +22,7 @@ API, MODEL_VERSIONS, MODELS, + REPORTABLE_RESOURCES, VERSION_1, ) from zenml.models import ( @@ -34,6 +35,7 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.feature_gate.endpoint_utils import report_decrement from zenml.zen_server.rbac.endpoint_utils import ( verify_permissions_and_delete_entity, verify_permissions_and_get_entity, @@ -48,6 +50,7 @@ from zenml.zen_server.utils import ( handle_exceptions, make_dependable, + server_config, zen_store, ) @@ -160,12 +163,16 @@ def delete_model( Args: model_name_or_id: The name or ID of the model to delete. 
""" - verify_permissions_and_delete_entity( + model = verify_permissions_and_delete_entity( id=model_name_or_id, get_method=zen_store().get_model, delete_method=zen_store().delete_model, ) + if server_config().feature_gate_enabled: + if ResourceType.MODEL in REPORTABLE_RESOURCES: + report_decrement(ResourceType.MODEL, resource_id=model.id) + ################# # Model Versions diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index f4e5ad27808..fb1510ac772 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -18,7 +18,14 @@ from fastapi import APIRouter, Depends, Security from zenml.config.pipeline_spec import PipelineSpec -from zenml.constants import API, PIPELINE_SPEC, PIPELINES, RUNS, VERSION_1 +from zenml.constants import ( + API, + PIPELINE_SPEC, + PIPELINES, + REPORTABLE_RESOURCES, + RUNS, + VERSION_1, +) from zenml.models import ( Page, PipelineFilter, @@ -31,6 +38,7 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.feature_gate.endpoint_utils import report_decrement from zenml.zen_server.rbac.endpoint_utils import ( verify_permissions_and_delete_entity, verify_permissions_and_get_entity, @@ -154,12 +162,20 @@ def delete_pipeline( Args: pipeline_id: ID of the pipeline to delete. """ - verify_permissions_and_delete_entity( + pipeline = verify_permissions_and_delete_entity( id=pipeline_id, get_method=zen_store().get_pipeline, delete_method=zen_store().delete_pipeline, ) + should_decrement = ( + ResourceType.PIPELINE in REPORTABLE_RESOURCES + and zen_store().count_pipelines(PipelineFilter(name=pipeline.name)) + == 0 + ) + if should_decrement: + report_decrement(ResourceType.PIPELINE, resource_id=pipeline_id) + @router.get( "/{pipeline_id}" + RUNS, diff --git a/src/zenml/zen_server/routers/service_endpoints.py b/src/zenml/zen_server/routers/service_endpoints.py new file mode 100644 index 00000000000..1d7925494df --- /dev/null +++ b/src/zenml/zen_server/routers/service_endpoints.py @@ -0,0 +1,180 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Endpoint definitions for services.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, Security + +from zenml.constants import API, SERVICES, VERSION_1 +from zenml.models import ( + Page, + ServiceFilter, + ServiceResponse, + ServiceUpdate, +) +from zenml.models.v2.core.service import ServiceRequest +from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import ResourceType +from zenml.zen_server.utils import ( + handle_exceptions, + make_dependable, + zen_store, +) + +router = APIRouter( + prefix=API + VERSION_1 + SERVICES, + tags=["services"], + responses={401: error_response, 403: error_response}, +) + + +@router.post( + "", + response_model=ServiceResponse, + responses={401: error_response, 422: error_response}, +) +@handle_exceptions +def create_service( + service: ServiceRequest, + _: AuthContext = Security(authorize), +) -> ServiceResponse: + """Creates a new service. + + Args: + service: The model containing the attributes of the new service. + + Returns: + The created service object. + """ + return verify_permissions_and_create_entity( + request_model=service, + create_method=zen_store().create_service, + resource_type=ResourceType.SERVICE, + ) + + +@router.get( + "", + response_model=Page[ServiceResponse], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_services( + filter_model: ServiceFilter = Depends(make_dependable(ServiceFilter)), + hydrate: bool = False, + _: AuthContext = Security(authorize), +) -> Page[ServiceResponse]: + """Gets a page of service objects. + + Args: + filter_model: Filter model used for pagination, sorting, + filtering. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + Page of service objects. + """ + return verify_permissions_and_list_entities( + filter_model=filter_model, + resource_type=ResourceType.SERVICE, + list_method=zen_store().list_services, + hydrate=hydrate, + ) + + +@router.get( + "/{service_id}", + response_model=ServiceResponse, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def get_service( + service_id: UUID, + hydrate: bool = True, + _: AuthContext = Security(authorize), +) -> ServiceResponse: + """Gets a specific service using its unique ID. + + Args: + service_id: The ID of the service to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A specific service object. + """ + return verify_permissions_and_get_entity( + id=service_id, + get_method=zen_store().get_service, + hydrate=hydrate, + ) + + +@router.put( + "/{service_id}", + response_model=ServiceResponse, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def update_service( + service_id: UUID, + update: ServiceUpdate, + _: AuthContext = Security(authorize), +) -> ServiceResponse: + """Updates a service. + + Args: + service_id: The ID of the service to update. + update: The model containing the attributes to update. + + Returns: + The updated service object. 
+ """ + return verify_permissions_and_update_entity( + id=service_id, + update_model=update, + get_method=zen_store().get_service, + update_method=zen_store().update_service, + ) + + +@router.delete( + "/{service_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_service( + service_id: UUID, + _: AuthContext = Security(authorize), +) -> None: + """Deletes a specific service. + + Args: + service_id: The ID of the service to delete. + """ + verify_permissions_and_delete_entity( + id=service_id, + get_method=zen_store().get_service, + delete_method=zen_store().delete_service, + ) diff --git a/src/zenml/zen_server/routers/webhook_endpoints.py b/src/zenml/zen_server/routers/webhook_endpoints.py index 2f115a535b1..9cb167a5d34 100644 --- a/src/zenml/zen_server/routers/webhook_endpoints.py +++ b/src/zenml/zen_server/routers/webhook_endpoints.py @@ -13,9 +13,10 @@ # permissions and limitations under the License. """Endpoint definitions for webhooks.""" +from typing import Dict from uuid import UUID -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, BackgroundTasks, Depends, Request from zenml.constants import API, VERSION_1, WEBHOOKS from zenml.enums import PluginSubType, PluginType @@ -52,20 +53,26 @@ async def get_body(request: Request) -> bytes: @router.post( "/{event_source_id}", + response_model=Dict[str, str], ) @handle_exceptions def webhook( event_source_id: UUID, request: Request, + background_tasks: BackgroundTasks, raw_body: bytes = Depends(get_body), -) -> None: +) -> Dict[str, str]: """Webhook to receive events from external event sources. Args: event_source_id: The event_source_id request: The request object + background_tasks: Background task handler raw_body: The raw request body + Returns: + Static dict stating that event is received. + Raises: AuthorizationException: If the Event Source does not exist. KeyError: If no appropriate Plugin found in the plugin registry @@ -111,8 +118,11 @@ def webhook( ) # Pass the raw event and headers to the plugin - plugin.process_webhook_event( + background_tasks.add_task( + plugin.process_webhook_event, event_source=event_source, raw_body=raw_body, headers=dict(request.headers.items()), ) + + return {"status": "Event Received."} diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 042565b6275..4b747ae6f0c 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -28,12 +28,14 @@ PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, PIPELINES, + REPORTABLE_RESOURCES, RUN_METADATA, RUNS, SCHEDULES, SECRETS, SERVICE_CONNECTOR_RESOURCES, SERVICE_CONNECTORS, + SERVICES, STACK_COMPONENTS, STACKS, STATISTICS, @@ -80,6 +82,8 @@ ServiceConnectorRequest, ServiceConnectorResourcesModel, ServiceConnectorResponse, + ServiceRequest, + ServiceResponse, StackFilter, StackRequest, StackResponse, @@ -90,6 +94,10 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.feature_gate.endpoint_utils import ( + check_entitlement, + report_usage, +) from zenml.zen_server.rbac.endpoint_utils import ( verify_permissions_and_create_entity, verify_permissions_and_delete_entity, @@ -509,12 +517,30 @@ def create_pipeline( f"not supported." 
) - return verify_permissions_and_create_entity( + # We limit pipeline namespaces, not pipeline versions + needs_usage_increment = ( + ResourceType.PIPELINE in REPORTABLE_RESOURCES + and zen_store().count_pipelines(PipelineFilter(name=pipeline.name)) + == 0 + ) + + if needs_usage_increment: + check_entitlement(ResourceType.PIPELINE) + + pipeline_response = verify_permissions_and_create_entity( request_model=pipeline, resource_type=ResourceType.PIPELINE, create_method=zen_store().create_pipeline, ) + if needs_usage_increment: + report_usage( + resource_type=ResourceType.PIPELINE, + resource_id=pipeline_response.id, + ) + + return pipeline_response + @router.get( WORKSPACES + "/{workspace_name_or_id}" + PIPELINE_BUILDS, @@ -1431,3 +1457,44 @@ def create_model_version_pipeline_run_link( model_version_pipeline_run_link ) return mv + + +@router.post( + WORKSPACES + "/{workspace_name_or_id}" + SERVICES, + response_model=ServiceResponse, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_service( + workspace_name_or_id: Union[str, UUID], + service: ServiceRequest, + _: AuthContext = Security(authorize), +) -> ServiceResponse: + """Create a new service. + + Args: + workspace_name_or_id: Name or ID of the workspace. + service: The service to create. + + Returns: + The created service. + + Raises: + IllegalOperationError: If the workspace or user specified in the + model does not match the current workspace or authenticated + user. + """ + workspace = zen_store().get_workspace(workspace_name_or_id) + + if service.workspace != workspace.id: + raise IllegalOperationError( + "Creating models outside of the workspace scope " + f"of this endpoint `{workspace_name_or_id}` is " + f"not supported." + ) + + return verify_permissions_and_create_entity( + request_model=service, + resource_type=ResourceType.SERVICE, + create_method=zen_store().create_service, + ) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index bc6be68e24b..2d6d9af132e 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -16,7 +16,15 @@ import inspect import os from functools import wraps -from typing import Any, Callable, Optional, Tuple, Type, TypeVar, cast +from typing import ( + Any, + Callable, + Optional, + Tuple, + Type, + TypeVar, + cast, +) from urllib.parse import urlparse from pydantic import BaseModel, ValidationError @@ -35,6 +43,9 @@ LocalServerDeploymentConfig, ) from zenml.zen_server.exceptions import http_exception_from_error +from zenml.zen_server.feature_gate.feature_gate_interface import ( + FeatureGateInterface, +) from zenml.zen_server.pipeline_deployment.workload_manager_interface import ( WorkloadManagerInterface, ) @@ -45,6 +56,7 @@ _zen_store: Optional["SqlZenStore"] = None _rbac: Optional[RBACInterface] = None +_feature_gate: Optional[FeatureGateInterface] = None _workload_manager: Optional[WorkloadManagerInterface] = None _plugin_flavor_registry: Optional[PluginFlavorRegistry] = None @@ -92,6 +104,50 @@ def rbac() -> RBACInterface: return _rbac +def initialize_rbac() -> None: + """Initialize the RBAC component.""" + global _rbac + + if rbac_source := server_config().rbac_implementation_source: + from zenml.utils import source_utils + + implementation_class = source_utils.load_and_validate_class( + rbac_source, expected_class=RBACInterface + ) + _rbac = implementation_class() + + +def feature_gate() -> FeatureGateInterface: + """Return the initialized Feature Gate component. 
+ + Raises: + RuntimeError: If the RBAC component is not initialized. + + Returns: + The RBAC component. + """ + global _feature_gate + if _feature_gate is None: + raise RuntimeError("Feature gate component not initialized.") + return _feature_gate + + +def initialize_feature_gate() -> None: + """Initialize the Feature Gate component.""" + global _feature_gate + + if ( + feature_gate_source + := server_config().feature_gate_implementation_source + ): + from zenml.utils import source_utils + + implementation_class = source_utils.load_and_validate_class( + feature_gate_source, expected_class=FeatureGateInterface + ) + _feature_gate = implementation_class() + + def workload_manager() -> WorkloadManagerInterface: """Return the initialized workload manager component. @@ -107,19 +163,6 @@ def workload_manager() -> WorkloadManagerInterface: return _workload_manager -def initialize_rbac() -> None: - """Initialize the RBAC component.""" - global _rbac - - if rbac_source := server_config().rbac_implementation_source: - from zenml.utils import source_utils - - implementation_class = source_utils.load_and_validate_class( - rbac_source, expected_class=RBACInterface - ) - _rbac = implementation_class() - - def initialize_workload_manager() -> None: """Initialize the workload manager component. diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 79c9c01f7c0..b5f01f940da 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -52,6 +52,7 @@ server_endpoints, service_accounts_endpoints, service_connectors_endpoints, + service_endpoints, stack_components_endpoints, stacks_endpoints, steps_endpoints, @@ -62,6 +63,7 @@ workspaces_endpoints, ) from zenml.zen_server.utils import ( + initialize_feature_gate, initialize_plugins, initialize_rbac, initialize_workload_manager, @@ -158,6 +160,7 @@ def initialize() -> None: # race conditions initialize_zen_store() initialize_rbac() + initialize_feature_gate() initialize_workload_manager() initialize_plugins() @@ -234,6 +237,7 @@ def dashboard(request: Request) -> Any: app.include_router(service_accounts_endpoints.router) app.include_router(service_connectors_endpoints.router) app.include_router(service_connectors_endpoints.types_router) +app.include_router(service_endpoints.router) app.include_router(stacks_endpoints.router) app.include_router(stack_components_endpoints.router) app.include_router(stack_components_endpoints.types_router) diff --git a/src/zenml/zen_stores/migrations/utils.py b/src/zenml/zen_stores/migrations/utils.py index f1300946ee5..6ee4af7b3b5 100644 --- a/src/zenml/zen_stores/migrations/utils.py +++ b/src/zenml/zen_stores/migrations/utils.py @@ -236,9 +236,17 @@ def backup_database_to_storage( # correct order, since some tables have inner foreign key # constraints. if "created" in table.columns: - order_by = table.columns["created"] + order_by = [table.columns["created"]] else: - order_by = None + order_by = [] + if "id" in table.columns: + # If the table has an `id` column, we also use it to sort + # the rows in the table, even if we already use "created" + # to sort the rows. We need a unique field to sort the rows, + # to break the tie between rows with the same "created" + # date, otherwise the same entry might end up multiple times + # in subsequent pages. 
+ order_by.append(table.columns["id"]) # Fetch the number of rows in the table row_count = conn.scalar( @@ -250,7 +258,7 @@ def backup_database_to_storage( for i in range(0, row_count, batch_size): rows = conn.execute( table.select() - .order_by(order_by) + .order_by(*order_by) .limit(batch_size) .offset(i) ).fetchall() diff --git a/src/zenml/zen_stores/migrations/versions/0.56.0_release.py b/src/zenml/zen_stores/migrations/versions/0.56.0_release.py new file mode 100644 index 00000000000..85dc2ccdf5e --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/0.56.0_release.py @@ -0,0 +1,23 @@ +"""Release [0.56.0]. + +Revision ID: 0.56.0 +Revises: 1a9a9d2a836d +Create Date: 2024-03-20 13:30:40.013587 + +""" + +# revision identifiers, used by Alembic. +revision = "0.56.0" +down_revision = "1a9a9d2a836d" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + pass + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + pass diff --git a/src/zenml/zen_stores/migrations/versions/0.56.1_release.py b/src/zenml/zen_stores/migrations/versions/0.56.1_release.py new file mode 100644 index 00000000000..d1eb6c0c982 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/0.56.1_release.py @@ -0,0 +1,23 @@ +"""Release [0.56.1]. + +Revision ID: 0.56.1 +Revises: 0.56.0 +Create Date: 2024-03-21 14:50:20.869911 + +""" + +# revision identifiers, used by Alembic. +revision = "0.56.1" +down_revision = "0.56.0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + pass + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + pass diff --git a/src/zenml/zen_stores/migrations/versions/0.56.2_release.py b/src/zenml/zen_stores/migrations/versions/0.56.2_release.py new file mode 100644 index 00000000000..47431e949fe --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/0.56.2_release.py @@ -0,0 +1,23 @@ +"""Release [0.56.2]. + +Revision ID: 0.56.2 +Revises: 0701da9951a0 +Create Date: 2024-03-25 14:49:49.021147 + +""" + +# revision identifiers, used by Alembic. +revision = "0.56.2" +down_revision = "0701da9951a0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + pass + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + pass diff --git a/src/zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py b/src/zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py new file mode 100644 index 00000000000..b32a6fe8b72 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py @@ -0,0 +1,94 @@ +"""Added service table [0701da9951a0]. + +Revision ID: 0701da9951a0 +Revises: 0.56.1 +Create Date: 2024-03-25 12:24:32.928543 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.engine.reflection import Inspector + +# revision identifiers, used by Alembic. +revision = "0701da9951a0" +down_revision = "0.56.1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # If the tables already exist, skip this migration. 
+ conn = op.get_bind() + inspector = Inspector.from_engine(conn) + tables = inspector.get_table_names() + if "service" in tables: + return + + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "service", + sa.Column( + "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("service_source", sa.TEXT(), nullable=True), + sa.Column("service_type", sa.TEXT(), nullable=False), + sa.Column("type", sa.TEXT(), nullable=False), + sa.Column("flavor", sa.TEXT(), nullable=False), + sa.Column("admin_state", sa.TEXT(), nullable=True), + sa.Column("state", sa.TEXT(), nullable=True), + sa.Column("prediction_url", sa.TEXT(), nullable=True), + sa.Column("health_check_url", sa.TEXT(), nullable=True), + sa.Column("pipeline_name", sa.TEXT(), nullable=True), + sa.Column("pipeline_step_name", sa.TEXT(), nullable=True), + sa.Column( + "model_version_id", sqlmodel.sql.sqltypes.GUID(), nullable=True + ), + sa.Column( + "pipeline_run_id", sqlmodel.sql.sqltypes.GUID(), nullable=True + ), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("labels", sa.LargeBinary(), nullable=True), + sa.Column("config", sa.LargeBinary(), nullable=False), + sa.Column("status", sa.LargeBinary(), nullable=True), + sa.Column("endpoint", sa.LargeBinary(), nullable=True), + sa.ForeignKeyConstraint( + ["model_version_id"], + ["model_version.id"], + name="fk_service_model_version_id_model_version", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["pipeline_run_id"], + ["pipeline_run.id"], + name="fk_service_pipeline_run_id_pipeline_run", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_service_user_id_user", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_service_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("service") + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 8a52010daf3..d03cec9f79b 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -80,6 +80,7 @@ SERVICE_CONNECTOR_TYPES, SERVICE_CONNECTOR_VERIFY, SERVICE_CONNECTORS, + SERVICES, STACK_COMPONENTS, STACKS, STEPS, @@ -189,6 +190,10 @@ ServiceConnectorResponse, ServiceConnectorTypeModel, ServiceConnectorUpdate, + ServiceFilter, + ServiceRequest, + ServiceResponse, + ServiceUpdate, StackFilter, StackRequest, StackResponse, @@ -590,6 +595,93 @@ def delete_api_key( route=f"{SERVICE_ACCOUNTS}/{str(service_account_id)}{API_KEYS}", ) + # ----------------------------- Services ----------------------------- + + def create_service( + self, service_request: ServiceRequest + ) -> ServiceResponse: + """Create a new service. + + Args: + service_request: The service to create. + + Returns: + The created service. 
+ """ + return self._create_resource( + resource=service_request, + response_model=ServiceResponse, + route=SERVICES, + ) + + def get_service( + self, service_id: UUID, hydrate: bool = True + ) -> ServiceResponse: + """Get a service. + + Args: + service_id: The ID of the service to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The service. + """ + return self._get_resource( + resource_id=service_id, + route=SERVICES, + response_model=ServiceResponse, + params={"hydrate": hydrate}, + ) + + def list_services( + self, filter_model: ServiceFilter, hydrate: bool = False + ) -> Page[ServiceResponse]: + """List all services matching the given filter criteria. + + Args: + filter_model: All filter parameters including pagination + params. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A list of all services matching the filter criteria. + """ + return self._list_paginated_resources( + route=SERVICES, + response_model=ServiceResponse, + filter_model=filter_model, + params={"hydrate": hydrate}, + ) + + def update_service( + self, service_id: UUID, update: ServiceUpdate + ) -> ServiceResponse: + """Update a service. + + Args: + service_id: The ID of the service to update. + update: The update to be applied to the service. + + Returns: + The updated service. + """ + return self._update_resource( + resource_id=service_id, + resource_update=update, + response_model=ServiceResponse, + route=SERVICES, + ) + + def delete_service(self, service_id: UUID) -> None: + """Delete a service. + + Args: + service_id: The ID of the service to delete. + """ + self._delete_resource(resource_id=service_id, route=SERVICES) + # ----------------------------- Artifacts ----------------------------- def create_artifact(self, artifact: ArtifactRequest) -> ArtifactResponse: @@ -3816,6 +3908,7 @@ def _create_resource( The created resource. 
""" response_body = self.post(f"{route}", body=resource, params=params) + return response_model.parse_obj(response_body) def _create_workspace_scoped_resource( diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index 0ec208fff81..5957605c0c7 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -41,6 +41,7 @@ from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema from zenml.zen_stores.schemas.secret_schemas import SecretSchema +from zenml.zen_stores.schemas.service_schemas import ServiceSchema from zenml.zen_stores.schemas.service_connector_schemas import ( ServiceConnectorSchema, ) @@ -90,6 +91,7 @@ "ScheduleSchema", "SecretSchema", "ServiceConnectorSchema", + "ServiceSchema", "StackComponentSchema", "StackCompositionSchema", "StackSchema", diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index 40bb32d2a57..6aeebd556ca 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -171,7 +171,7 @@ class ArtifactVersionSchema(BaseSchema, table=True): # Fields version: str version_number: Optional[int] - type: str + type: ArtifactType uri: str = Field(sa_column=Column(TEXT, nullable=False)) materializer: str = Field(sa_column=Column(TEXT, nullable=False)) data_type: str = Field(sa_column=Column(TEXT, nullable=False)) @@ -277,7 +277,7 @@ def from_request( artifact_store_id=artifact_version_request.artifact_store_id, workspace_id=artifact_version_request.workspace, user_id=artifact_version_request.user, - type=artifact_version_request.type.value, + type=artifact_version_request.type, uri=artifact_version_request.uri, materializer=artifact_version_request.materializer.json(), data_type=artifact_version_request.data_type.json(), @@ -328,7 +328,7 @@ def to_model( version=self.version_number or self.version, user=self.user.to_model() if self.user else None, uri=self.uri, - type=ArtifactType(self.type), + type=self.type, materializer=materializer, data_type=data_type, created=self.created, diff --git a/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py b/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py index 79862fc0376..6447bf93d25 100644 --- a/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py @@ -37,7 +37,7 @@ class ArtifactVisualizationSchema(BaseSchema, table=True): __tablename__ = "artifact_visualization" # Fields - type: str + type: VisualizationType uri: str = Field(sa_column=Column(TEXT, nullable=False)) # Foreign Keys @@ -71,7 +71,7 @@ def from_model( The `ArtifactVisualizationSchema`. """ return cls( - type=artifact_visualization_request.type.value, + type=artifact_visualization_request.type, uri=artifact_visualization_request.uri, artifact_version_id=artifact_version_id, ) @@ -95,7 +95,7 @@ def to_model( The `Visualization`. 
""" body = ArtifactVisualizationResponseBody( - type=VisualizationType(self.type), + type=self.type, uri=self.uri, created=self.created, updated=self.updated, diff --git a/src/zenml/zen_stores/schemas/component_schemas.py b/src/zenml/zen_stores/schemas/component_schemas.py index ca50ac29e37..f3b5b44ea72 100644 --- a/src/zenml/zen_stores/schemas/component_schemas.py +++ b/src/zenml/zen_stores/schemas/component_schemas.py @@ -49,7 +49,7 @@ class StackComponentSchema(NamedSchema, table=True): __tablename__ = "stack_component" - type: str + type: StackComponentType flavor: str configuration: bytes labels: Optional[bytes] @@ -127,8 +127,6 @@ def update( self.labels = base64.b64encode( json.dumps(component_update.labels).encode("utf-8") ) - elif field == "type": - self.type = component_update.type.value else: setattr(self, field, value) @@ -153,7 +151,7 @@ def to_model( A `ComponentModel` """ body = ComponentResponseBody( - type=StackComponentType(self.type), + type=self.type, flavor=self.flavor, user=self.user.to_model() if self.user else None, created=self.created, diff --git a/src/zenml/zen_stores/schemas/device_schemas.py b/src/zenml/zen_stores/schemas/device_schemas.py index abb9e6551c7..93ebc69556f 100644 --- a/src/zenml/zen_stores/schemas/device_schemas.py +++ b/src/zenml/zen_stores/schemas/device_schemas.py @@ -44,7 +44,7 @@ class OAuthDeviceSchema(BaseSchema, table=True): client_id: UUID user_code: str device_code: str - status: str + status: OAuthDeviceStatus failed_auth_attempts: int = 0 expires: Optional[datetime] = None last_login: Optional[datetime] = None @@ -121,7 +121,7 @@ def from_request( client_id=request.client_id, user_code=hashed_user_code, device_code=hashed_device_code, - status=OAuthDeviceStatus.PENDING.value, + status=OAuthDeviceStatus.PENDING, failed_auth_attempts=0, expires=now + timedelta(seconds=request.expires_in), os=request.os, @@ -153,9 +153,9 @@ def update(self, device_update: OAuthDeviceUpdate) -> "OAuthDeviceSchema": setattr(self, field, value) if device_update.locked is True: - self.status = OAuthDeviceStatus.LOCKED.value + self.status = OAuthDeviceStatus.LOCKED elif device_update.locked is False: - self.status = OAuthDeviceStatus.ACTIVE.value + self.status = OAuthDeviceStatus.ACTIVE self.updated = datetime.utcnow() return self @@ -233,7 +233,7 @@ def to_model( client_id=self.client_id, expires=self.expires, trusted_device=self.trusted_device, - status=OAuthDeviceStatus(self.status), + status=self.status, os=self.os, ip_address=self.ip_address, hostname=self.hostname, diff --git a/src/zenml/zen_stores/schemas/flavor_schemas.py b/src/zenml/zen_stores/schemas/flavor_schemas.py index 7ace6f97716..edb9c3d8b37 100644 --- a/src/zenml/zen_stores/schemas/flavor_schemas.py +++ b/src/zenml/zen_stores/schemas/flavor_schemas.py @@ -46,7 +46,7 @@ class FlavorSchema(NamedSchema, table=True): __tablename__ = "flavor" - type: str + type: StackComponentType source: str config_schema: str = Field(sa_column=Column(TEXT, nullable=False)) integration: Optional[str] = Field(default="") @@ -98,8 +98,6 @@ def update(self, flavor_update: "FlavorUpdate") -> "FlavorSchema": ).items(): if field == "config_schema": setattr(self, field, json.dumps(value)) - elif field == "type": - setattr(self, field, value.value) else: setattr(self, field, value) @@ -125,7 +123,7 @@ def to_model( """ body = FlavorResponseBody( user=self.user.to_model() if self.user else None, - type=StackComponentType(self.type), + type=self.type, integration=self.integration, logo_url=self.logo_url, 
created=self.created, diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 52e0355b9eb..4658d094281 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -14,7 +14,7 @@ """SQLModel implementation of model tables.""" from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from uuid import UUID from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column @@ -38,6 +38,8 @@ ModelVersionResponse, ModelVersionResponseBody, ModelVersionResponseMetadata, + ModelVersionResponseResources, + Page, ) from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema @@ -46,8 +48,12 @@ from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.utils import get_page_from_list from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema +if TYPE_CHECKING: + from zenml.zen_stores.schemas import ServiceSchema + class ModelSchema(NamedSchema, table=True): """SQL Model for model.""" @@ -263,6 +269,10 @@ class ModelVersionSchema(NamedSchema, table=True): ), ) + services: List["ServiceSchema"] = Relationship( + back_populates="model_version", + ) + number: int = Field(sa_column=Column(INTEGER, nullable=False)) description: str = Field(sa_column=Column(TEXT, nullable=True)) stage: str = Field(sa_column=Column(TEXT, nullable=True)) @@ -315,6 +325,8 @@ def to_model( Returns: The created `ModelVersionResponse`. """ + from zenml.models import ServiceResponse + # Construct {name: {version: id}} dicts for all linked artifacts model_artifact_ids: Dict[str, Dict[str, UUID]] = {} deployment_artifact_ids: Dict[str, Dict[str, UUID]] = {} @@ -347,7 +359,6 @@ def to_model( pipeline_run_ids[pipeline_run.name] = pipeline_run.id metadata = None - if include_metadata: metadata = ModelVersionResponseMetadata( workspace=self.workspace.to_model(), @@ -358,6 +369,21 @@ def to_model( }, ) + resources = None + if include_resources: + services = cast( + Page[ServiceResponse], + get_page_from_list( + items_list=self.services, + response_model=ServiceResponse, + include_resources=include_resources, + include_metadata=include_metadata, + ), + ) + resources = ModelVersionResponseResources( + services=services, + ) + body = ModelVersionResponseBody( user=self.user.to_model() if self.user else None, created=self.created, @@ -377,6 +403,7 @@ def to_model( name=self.name, body=body, metadata=metadata, + resources=resources, ) def update( diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 966952d4416..c27cc34bb74 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -49,6 +49,7 @@ ModelVersionPipelineRunSchema, ) from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema + from zenml.zen_stores.schemas.service_schemas import ServiceSchema from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema @@ -68,7 +69,7 @@ class PipelineRunSchema(NamedSchema, table=True): orchestrator_run_id: Optional[str] = Field(nullable=True) start_time: Optional[datetime] = Field(nullable=True) end_time: Optional[datetime] = 
Field(nullable=True, default=None) - status: str = Field(nullable=False) + status: ExecutionStatus = Field(nullable=False) orchestrator_environment: Optional[str] = Field( sa_column=Column(TEXT, nullable=True) ) @@ -182,6 +183,10 @@ class PipelineRunSchema(NamedSchema, table=True): pipeline: Optional["PipelineSchema"] = Relationship(back_populates="runs") trigger_execution: Optional["TriggerExecutionSchema"] = Relationship() + services: List["ServiceSchema"] = Relationship( + back_populates="pipeline_run", + ) + @classmethod def from_request( cls, request: "PipelineRunRequest" @@ -203,7 +208,7 @@ def from_request( orchestrator_run_id=request.orchestrator_run_id, orchestrator_environment=orchestrator_environment, start_time=request.start_time, - status=request.status.value, + status=request.status, pipeline_id=request.pipeline, deployment_id=request.deployment, trigger_execution_id=request.trigger_execution_id, @@ -277,7 +282,7 @@ def to_model( body = PipelineRunResponseBody( user=self.user.to_model() if self.user else None, - status=ExecutionStatus(self.status), + status=self.status, stack=stack, pipeline=pipeline, build=build, @@ -322,7 +327,7 @@ def update(self, run_update: "PipelineRunUpdate") -> "PipelineRunSchema": The updated `PipelineRunSchema`. """ if run_update.status: - self.status = run_update.status.value + self.status = run_update.status self.end_time = run_update.end_time self.updated = datetime.utcnow() @@ -367,7 +372,7 @@ def update_placeholder( self.orchestrator_run_id = request.orchestrator_run_id self.orchestrator_environment = orchestrator_environment - self.status = request.status.value + self.status = request.status self.updated = datetime.utcnow() diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index f84e210d97d..ade0bb1449a 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -109,7 +109,7 @@ class RunMetadataSchema(BaseSchema, table=True): key: str value: str = Field(sa_column=Column(TEXT, nullable=False)) - type: str + type: MetadataTypeEnum def to_model( self, @@ -134,7 +134,7 @@ def to_model( created=self.created, updated=self.updated, value=json.loads(self.value), - type=MetadataTypeEnum(self.type), + type=self.type, ) metadata = None if include_metadata: diff --git a/src/zenml/zen_stores/schemas/secret_schemas.py b/src/zenml/zen_stores/schemas/secret_schemas.py index 94059c6b102..468318c87c8 100644 --- a/src/zenml/zen_stores/schemas/secret_schemas.py +++ b/src/zenml/zen_stores/schemas/secret_schemas.py @@ -55,7 +55,7 @@ class SecretSchema(NamedSchema, table=True): __tablename__ = "secret" - scope: str + scope: SecretScope values: Optional[bytes] = Field(sa_column=Column(TEXT, nullable=True)) @@ -177,7 +177,7 @@ def from_request( assert secret.user is not None, "User must be set for secret creation." return cls( name=secret.name, - scope=secret.scope.value, + scope=secret.scope, workspace_id=secret.workspace, user_id=secret.user, # Don't store secret values implicitly in the secret. 
The @@ -204,10 +204,7 @@ def update( for field, value in secret_update.dict( exclude_unset=True, exclude={"workspace", "user", "values"} ).items(): - if field == "scope": - setattr(self, field, value.value) - else: - setattr(self, field, value) + setattr(self, field, value) self.updated = datetime.utcnow() return self @@ -242,7 +239,7 @@ def to_model( user=self.user.to_model() if self.user else None, created=self.created, updated=self.updated, - scope=SecretScope(self.scope), + scope=self.scope, ) return SecretResponse( id=self.id, diff --git a/src/zenml/zen_stores/schemas/service_schemas.py b/src/zenml/zen_stores/schemas/service_schemas.py new file mode 100644 index 00000000000..a38c0b68425 --- /dev/null +++ b/src/zenml/zen_stores/schemas/service_schemas.py @@ -0,0 +1,249 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""SQLModel implementation of service table.""" + +import base64 +import json +from datetime import datetime +from typing import Any, Optional +from uuid import UUID + +from sqlalchemy import TEXT, Column +from sqlmodel import Field, Relationship + +from zenml.models.v2.core.service import ( + ServiceRequest, + ServiceResponse, + ServiceResponseBody, + ServiceResponseMetadata, + ServiceResponseResources, + ServiceUpdate, +) +from zenml.utils.dict_utils import dict_to_bytes +from zenml.zen_stores.schemas.base_schemas import NamedSchema +from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema +from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema +from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field +from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema + + +class ServiceSchema(NamedSchema, table=True): + """SQL Model for service.""" + + __tablename__ = "service" + + workspace_id: UUID = build_foreign_key_field( + source=__tablename__, + target=WorkspaceSchema.__tablename__, + source_column="workspace_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + workspace: "WorkspaceSchema" = Relationship(back_populates="services") + + user_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=UserSchema.__tablename__, + source_column="user_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + user: Optional["UserSchema"] = Relationship(back_populates="services") + service_source: Optional[str] = Field( + sa_column=Column(TEXT, nullable=True) + ) + service_type: str = Field(sa_column=Column(TEXT, nullable=False)) + type: str = Field(sa_column=Column(TEXT, nullable=False)) + flavor: str = Field(sa_column=Column(TEXT, nullable=False)) + admin_state: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + state: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + labels: Optional[bytes] + config: bytes + status: Optional[bytes] + endpoint: Optional[bytes] + prediction_url: Optional[str] = Field( + 
sa_column=Column(TEXT, nullable=True) + ) + health_check_url: Optional[str] = Field( + sa_column=Column(TEXT, nullable=True) + ) + pipeline_name: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + pipeline_step_name: Optional[str] = Field( + sa_column=Column(TEXT, nullable=True) + ) + model_version_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=ModelVersionSchema.__tablename__, + source_column="model_version_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + model_version: Optional["ModelVersionSchema"] = Relationship( + back_populates="services", + ) + pipeline_run_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target="pipeline_run", + source_column="pipeline_run_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + pipeline_run: Optional["PipelineRunSchema"] = Relationship( + back_populates="services", + ) + + def to_model( + self, + include_metadata: bool = False, + include_resources: bool = False, + **kwargs: Any, + ) -> ServiceResponse: + """Convert an `ServiceSchema` to an `ServiceResponse`. + + Args: + include_metadata: Whether to include metadata in the response. + include_resources: Whether to include resources in the response. + kwargs: Additional keyword arguments. + + Returns: + The created `ServiceResponse`. + """ + body = ServiceResponseBody( + user=self.user.to_model() if self.user else None, + workspace=self.workspace.to_model(), + created=self.created, + updated=self.updated, + service_type=json.loads(self.service_type), + labels=json.loads(base64.b64decode(self.labels).decode()) + if self.labels + else None, + state=self.state, + ) + metadata = None + if include_metadata: + metadata = ServiceResponseMetadata( + workspace=self.workspace.to_model(), + service_source=self.service_source, + config=json.loads(base64.b64decode(self.config).decode()), + status=json.loads(base64.b64decode(self.status).decode()) + if self.status + else None, + endpoint=json.loads(base64.b64decode(self.endpoint).decode()) + if self.endpoint + else None, + admin_state=self.admin_state or None, + prediction_url=self.prediction_url or None, + health_check_url=self.health_check_url, + ) + resources = None + if include_resources: + resources = ServiceResponseResources( + model_version=self.model_version.to_model() + if self.model_version + else None, + pipeline_run=self.pipeline_run.to_model() + if self.pipeline_run + else None, + ) + return ServiceResponse( + id=self.id, + name=self.name, + body=body, + metadata=metadata, + resources=resources, + ) + + def update( + self, + update: ServiceUpdate, + ) -> "ServiceSchema": + """Updates a `ServiceSchema` from a `ServiceUpdate`. + + Args: + update: The `ServiceUpdate` to update from. + + Returns: + The updated `ServiceSchema`. 
+ """ + for field, value in update.dict( + exclude_unset=True, exclude_none=True + ).items(): + if field == "labels": + self.labels = ( + dict_to_bytes(update.labels) if update.labels else None + ) + elif field == "status": + self.status = ( + dict_to_bytes(update.status) if update.status else None + ) + self.state = ( + update.status.get("state") if update.status else None + ) + elif field == "endpoint": + self.endpoint = ( + dict_to_bytes(update.endpoint) if update.endpoint else None + ) + else: + setattr(self, field, value) + self.updated = datetime.utcnow() + return self + + @classmethod + def from_request( + cls, service_request: "ServiceRequest" + ) -> "ServiceSchema": + """Convert a `ServiceRequest` to a `ServiceSchema`. + + Args: + service_request: The request model to convert. + + Returns: + The converted schema. + """ + return cls( + name=service_request.name, + workspace_id=service_request.workspace, + user_id=service_request.user, + service_source=service_request.service_source, + service_type=service_request.service_type.json(), + type=service_request.service_type.type, + flavor=service_request.service_type.flavor, + admin_state=service_request.admin_state, + config=dict_to_bytes(service_request.config), + labels=dict_to_bytes(service_request.labels) + if service_request.labels + else None, + status=dict_to_bytes(service_request.status) + if service_request.status + else None, + endpoint=dict_to_bytes(service_request.endpoint) + if service_request.endpoint + else None, + state=service_request.status.get("state") + if service_request.status + else None, + model_version_id=service_request.model_version_id, + pipeline_run_id=service_request.pipeline_run_id, + prediction_url=service_request.prediction_url, + health_check_url=service_request.health_check_url, + pipeline_name=service_request.config.get("pipeline_name"), + pipeline_step_name=service_request.config.get( + "pipeline_step_name" + ), + ) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 4ae1d111f90..8ba628fc92a 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -27,6 +27,8 @@ from zenml.enums import ( ExecutionStatus, MetadataResourceTypes, + StepRunInputArtifactType, + StepRunOutputArtifactType, ) from zenml.models import ( StepRunRequest, @@ -58,7 +60,7 @@ class StepRunSchema(NamedSchema, table=True): # Fields start_time: Optional[datetime] = Field(nullable=True) end_time: Optional[datetime] = Field(nullable=True) - status: str = Field(nullable=False) + status: ExecutionStatus = Field(nullable=False) docstring: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) cache_key: Optional[str] = Field(nullable=True) @@ -163,7 +165,7 @@ def from_request(cls, request: StepRunRequest) -> "StepRunSchema": user_id=request.user, start_time=request.start_time, end_time=request.end_time, - status=request.status.value, + status=request.status, original_step_run_id=request.original_step_run_id, pipeline_run_id=request.pipeline_run_id, deployment_id=request.deployment, @@ -223,7 +225,7 @@ def to_model( body = StepRunResponseBody( user=self.user.to_model() if self.user else None, - status=ExecutionStatus(self.status), + status=self.status, inputs=input_artifacts, outputs=output_artifacts, created=self.created, @@ -268,7 +270,7 @@ def update(self, step_update: "StepRunUpdate") -> "StepRunSchema": exclude_unset=True, exclude_none=True ).items(): if key == "status": - self.status = value.value + 
self.status = value if key == "end_time": self.end_time = value @@ -310,7 +312,7 @@ class StepRunInputArtifactSchema(SQLModel, table=True): # Fields name: str = Field(nullable=False, primary_key=True) - type: str + type: StepRunInputArtifactType # Foreign keys step_id: UUID = build_foreign_key_field( @@ -346,7 +348,7 @@ class StepRunOutputArtifactSchema(SQLModel, table=True): # Fields name: str - type: str + type: StepRunOutputArtifactType # Foreign keys step_id: UUID = build_foreign_key_field( diff --git a/src/zenml/zen_stores/schemas/tag_schemas.py b/src/zenml/zen_stores/schemas/tag_schemas.py index 1cfbfc29c55..803a53805c5 100644 --- a/src/zenml/zen_stores/schemas/tag_schemas.py +++ b/src/zenml/zen_stores/schemas/tag_schemas.py @@ -108,11 +108,7 @@ def update(self, update: TagUpdate) -> "TagSchema": The updated `TagSchema`. """ for field, value in update.dict(exclude_unset=True).items(): - if field == "color": - setattr(self, field, value.value) - else: - setattr(self, field, value) - + setattr(self, field, value) self.updated = datetime.utcnow() return self diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index 610e45d4c18..72737ffa9b8 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -54,6 +54,7 @@ ScheduleSchema, SecretSchema, ServiceConnectorSchema, + ServiceSchema, StackComponentSchema, StackSchema, StepRunSchema, @@ -124,6 +125,7 @@ class UserSchema(NamedSchema, table=True): code_repositories: List["CodeRepositorySchema"] = Relationship( back_populates="user", ) + services: List["ServiceSchema"] = Relationship(back_populates="user") service_connectors: List["ServiceConnectorSchema"] = Relationship( back_populates="user", ) diff --git a/src/zenml/zen_stores/schemas/workspace_schemas.py b/src/zenml/zen_stores/schemas/workspace_schemas.py index aa9fd28f16c..3da451ac6c1 100644 --- a/src/zenml/zen_stores/schemas/workspace_schemas.py +++ b/src/zenml/zen_stores/schemas/workspace_schemas.py @@ -45,6 +45,7 @@ ScheduleSchema, SecretSchema, ServiceConnectorSchema, + ServiceSchema, StackComponentSchema, StackSchema, StepRunSchema, @@ -120,6 +121,10 @@ class WorkspaceSchema(NamedSchema, table=True): back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, ) + services: List["ServiceSchema"] = Relationship( + back_populates="workspace", + sa_relationship_kwargs={"cascade": "delete"}, + ) service_connectors: List["ServiceConnectorSchema"] = Relationship( back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 5feb10d93a2..7cf37dcad84 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -36,11 +36,11 @@ TypeVar, Union, cast, - get_origin, ) from uuid import UUID from pydantic import Field, SecretStr, root_validator, validator +from pydantic.json import pydantic_encoder from sqlalchemy import asc, desc, func from sqlalchemy.engine import URL, Engine, make_url from sqlalchemy.exc import ( @@ -48,7 +48,7 @@ IntegrityError, NoResultFound, ) -from sqlalchemy.orm import Mapped, noload +from sqlalchemy.orm import noload from sqlmodel import ( Session, SQLModel, @@ -209,6 +209,10 @@ ServiceConnectorResponse, ServiceConnectorTypeModel, ServiceConnectorUpdate, + ServiceFilter, + ServiceRequest, + ServiceResponse, + ServiceUpdate, StackFilter, StackRequest, StackResponse, @@ -299,6 +303,7 @@ 
ArtifactVisualizationSchema, ) from zenml.zen_stores.schemas.logs_schemas import LogsSchema +from zenml.zen_stores.schemas.service_schemas import ServiceSchema from zenml.zen_stores.schemas.trigger_schemas import TriggerSchema from zenml.zen_stores.secrets_stores.base_secrets_store import BaseSecretsStore from zenml.zen_stores.secrets_stores.sql_secrets_store import ( @@ -861,18 +866,23 @@ def filter_and_paginate( custom_fetch_result = custom_fetch(session, query, filter_model) total = len(custom_fetch_result) else: - total = ( - session.query(func.count()) - .select_from(query.options(noload("*")).subquery()) - .scalar() + total = session.scalar( + select([func.count("*")]).select_from( + query.options(noload("*")).subquery() + ) ) # Sorting column, operand = filter_model.sorting_params if operand == SorterOps.DESCENDING: - query = query.order_by(desc(getattr(table, column))) + sort_clause = desc(getattr(table, column)) else: - query = query.order_by(asc(getattr(table, column))) + sort_clause = asc(getattr(table, column)) + + # We always add the `id` column as a tiebreaker to ensure a stable, + # repeatable order of items, otherwise subsequent pages might contain + # the same items. + query = query.order_by(sort_clause, asc(table.id)) # Get the total amount of pages in the database for a given query if total == 0: @@ -1363,7 +1373,9 @@ def migrate_database(self) -> None: # identity table with needed info. logger.info("Creating database tables") with self.engine.begin() as conn: - SQLModel.metadata.create_all(conn) + conn.run_callable( + SQLModel.metadata.create_all # type: ignore[arg-type] + ) with Session(self.engine) as session: session.add( IdentitySchema( @@ -1757,6 +1769,175 @@ def delete_api_key( session.delete(api_key) session.commit() + # -------------------- Services -------------------- + + @staticmethod + def _fail_if_service_with_config_exists( + service_request: ServiceRequest, session: Session + ) -> None: + """Raise an exception if a service with same name/config exists. + + Args: + service_request: The service to check for. + session: The database session to use for the query. + + Raises: + EntityExistsError: If a service with the given name and + type already exists. + """ + # Check if service with the same domain key (name, config, workspace) + # already exists + + existing_domain_service = session.exec( + select(ServiceSchema).where( + ServiceSchema.config + == base64.b64encode( + json.dumps( + service_request.config, + sort_keys=False, + default=pydantic_encoder, + ).encode("utf-8") + ) + ) + ).first() + + if existing_domain_service: + raise EntityExistsError( + f"Unable to create service '{service_request.name}' with the given configuration: " + "A service with the same configuration already exists." + ) + + def create_service(self, service: ServiceRequest) -> ServiceResponse: + """Create a new service. + + Args: + service: The service to create. + + Returns: + The newly created service. + """ + with Session(self.engine) as session: + # Check if a service with the given name already exists + self._fail_if_service_with_config_exists( + service_request=service, + session=session, + ) + + # Create the service. + service_schema = ServiceSchema.from_request(service) + logger.debug("Creating service: %s", service_schema) + session.add(service_schema) + session.commit() + + return service_schema.to_model( + include_metadata=True, include_resources=True + ) + + def get_service( + self, service_id: UUID, hydrate: bool = True + ) -> ServiceResponse: + """Get a service. 
+ + Args: + service_id: The ID of the service to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The service. + + Raises: + KeyError: if the service doesn't exist. + """ + with Session(self.engine) as session: + service = session.exec( + select(ServiceSchema).where(ServiceSchema.id == service_id) + ).first() + if service is None: + raise KeyError( + f"Unable to get service with ID {service_id}: No " + "service with this ID found." + ) + return service.to_model( + include_metadata=hydrate, include_resources=hydrate + ) + + def list_services( + self, filter_model: ServiceFilter, hydrate: bool = False + ) -> Page[ServiceResponse]: + """List all services matching the given filter criteria. + + Args: + filter_model: All filter parameters including pagination + params. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A list of all services matching the filter criteria. + """ + with Session(self.engine) as session: + query = select(ServiceSchema) + return self.filter_and_paginate( + session=session, + query=query, + table=ServiceSchema, + filter_model=filter_model, + hydrate=hydrate, + ) + + def update_service( + self, service_id: UUID, update: ServiceUpdate + ) -> ServiceResponse: + """Update a service. + + Args: + service_id: The ID of the service to update. + update: The update to be applied to the service. + + Returns: + The updated service. + + Raises: + KeyError: if the service doesn't exist. + """ + with Session(self.engine) as session: + existing_service = session.exec( + select(ServiceSchema).where(ServiceSchema.id == service_id) + ).first() + if not existing_service: + raise KeyError(f"Service with ID {service_id} not found.") + + # Update the schema itself. + existing_service.update(update=update) + logger.debug("Updated service: %s", existing_service) + session.add(existing_service) + session.commit() + session.refresh(existing_service) + return existing_service.to_model( + include_metadata=True, include_resources=True + ) + + def delete_service(self, service_id: UUID) -> None: + """Delete a service. + + Args: + service_id: The ID of the service to delete. + + Raises: + KeyError: if the service doesn't exist. 
+ """ + with Session(self.engine) as session: + existing_service = session.exec( + select(ServiceSchema).where(ServiceSchema.id == service_id) + ).first() + if not existing_service: + raise KeyError(f"Service with ID {service_id} not found.") + + # Delete the service + session.delete(existing_service) + session.commit() + # -------------------- Artifacts -------------------- def create_artifact(self, artifact: ArtifactRequest) -> ArtifactResponse: @@ -2566,9 +2747,7 @@ def update_stack_component( if existing_component.name != component_update.name: self._fail_if_component_with_name_type_exists( name=component_update.name, - component_type=StackComponentType( - existing_component.type - ), + component_type=existing_component.type, workspace_id=existing_component.workspace_id, session=session, ) @@ -3320,6 +3499,7 @@ def _custom_fetch( PipelineRunSchema.created == max_date_subquery.c.max_created, ) + .order_by(desc(PipelineRunSchema.updated)) ) return self.filter_and_paginate( @@ -6865,9 +7045,7 @@ def _update_pipeline_run_status( assert pipeline_run.deployment num_steps = len(pipeline_run.deployment.to_model().step_configurations) new_status = get_pipeline_run_status( - step_statuses=[ - ExecutionStatus(step_run.status) for step_run in step_runs - ], + step_statuses=[step_run.status for step_run in step_runs], num_steps=num_steps, ) @@ -7277,8 +7455,6 @@ def _get_resource_references( for resource_attr in resource_attrs: # Extract the target schema from the annotation annotation = UserSchema.__annotations__[resource_attr] - if get_origin(annotation) == Mapped: - annotation = annotation.__args__[0] # The annotation must be of the form # `typing.List[ForwardRef(' ')]` @@ -7336,13 +7512,11 @@ def _account_owns_resources( resource_attrs = self._get_resource_references() for schema, resource_attr in resource_attrs: # Check if the user owns any resources of this type - count = ( - session.query(func.count()) + count = session.scalar( + select([func.count("*")]) .select_from(schema) .where(getattr(schema, resource_attr) == account.id) - .scalar() ) - if count > 0: logger.debug( f"User {account.name} owns {count} resources of type " @@ -8421,7 +8595,9 @@ def get_model_version( f"`{model_version_id}`: No model version with this " f"ID found." ) - return model_version.to_model(include_metadata=hydrate) + return model_version.to_model( + include_metadata=hydrate, include_resources=hydrate + ) def list_model_versions( self, @@ -8570,7 +8746,9 @@ def update_model_version( session.commit() session.refresh(existing_model_version) - return existing_model_version.to_model(include_metadata=True) + return existing_model_version.to_model( + include_metadata=True, include_resources=True + ) # ------------------------ Model Versions Artifacts ------------------------ diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 7914a5681bd..7163936d506 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -104,6 +104,10 @@ ServiceConnectorResponse, ServiceConnectorTypeModel, ServiceConnectorUpdate, + ServiceFilter, + ServiceRequest, + ServiceResponse, + ServiceUpdate, StackFilter, StackRequest, StackResponse, @@ -359,6 +363,87 @@ def delete_api_key( for the given service account. """ + # -------------------- Services -------------------- + + @abstractmethod + def create_service( + self, + service: ServiceRequest, + ) -> ServiceResponse: + """Create a new service. + + Args: + service: The service to create. 
+ + Returns: + The newly created service. + + Raises: + EntityExistsError: If a service with the same name already exists. + """ + + @abstractmethod + def get_service( + self, service_id: UUID, hydrate: bool = True + ) -> ServiceResponse: + """Get a service by ID. + + Args: + service_id: The ID of the service to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The service. + + Raises: + KeyError: if the service doesn't exist. + """ + + @abstractmethod + def list_services( + self, filter_model: ServiceFilter, hydrate: bool = False + ) -> Page[ServiceResponse]: + """List all services matching the given filter criteria. + + Args: + filter_model: All filter parameters including pagination + params. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A list of all services matching the filter criteria. + """ + + @abstractmethod + def update_service( + self, service_id: UUID, update: ServiceUpdate + ) -> ServiceResponse: + """Update an existing service. + + Args: + service_id: The ID of the service to update. + update: The update to be applied to the service. + + Returns: + The updated service. + + Raises: + KeyError: if the service doesn't exist. + """ + + @abstractmethod + def delete_service(self, service_id: UUID) -> None: + """Delete a service. + + Args: + service_id: The ID of the service to delete. + + Raises: + KeyError: if the service doesn't exist. + """ + # -------------------- Artifacts -------------------- @abstractmethod diff --git a/tests/integration/examples/bentoml/steps/prediction_service_loader.py b/tests/integration/examples/bentoml/steps/prediction_service_loader.py index 3871fe1e8ab..1fa9e669a9c 100644 --- a/tests/integration/examples/bentoml/steps/prediction_service_loader.py +++ b/tests/integration/examples/bentoml/steps/prediction_service_loader.py @@ -29,7 +29,7 @@ def bentoml_prediction_service_loader( """Get the BentoML prediction service started by the deployment pipeline. Args: - pipeline_name: name of the pipeline that deployed the model. + pipeline_name: name of the pipeline_name that deployed the model. step_name: the name of the step that deployed the model. model_name: the name of the model that was deployed. """ diff --git a/tests/integration/examples/huggingface/steps/prediction_service_loader/prediction_service_loader.py b/tests/integration/examples/huggingface/steps/prediction_service_loader/prediction_service_loader.py index 49a763bdca1..91ea669c9cc 100644 --- a/tests/integration/examples/huggingface/steps/prediction_service_loader/prediction_service_loader.py +++ b/tests/integration/examples/huggingface/steps/prediction_service_loader/prediction_service_loader.py @@ -43,19 +43,16 @@ def prediction_service_loader( # get the Huggingface model deployer stack component model_deployer = HuggingFaceModelDeployer.get_active_model_deployer() - # fetch existing services with same pipeline name, step name and model name - services = model_deployer.find_model_server( + if services := model_deployer.find_model_server( pipeline_name=pipeline_name, pipeline_step_name=pipeline_step_name, model_name=model_name, running=running, - ) - - if not services: + ): + return cast(HuggingFaceDeploymentService, services[0]) + else: raise RuntimeError( f"No Huggingface inference endpoint deployed by step " f"'{pipeline_step_name}' in pipeline '{pipeline_name}' with name " f"'{model_name}' is currently running." 
) - - return cast(HuggingFaceDeploymentService, services[0]) diff --git a/tests/integration/examples/mlflow/pipelines/deployment_pipelines/deployment_inference_pipeline.py b/tests/integration/examples/mlflow/pipelines/deployment_pipelines/deployment_inference_pipeline.py index fad0bd06331..29bb1b57887 100644 --- a/tests/integration/examples/mlflow/pipelines/deployment_pipelines/deployment_inference_pipeline.py +++ b/tests/integration/examples/mlflow/pipelines/deployment_pipelines/deployment_inference_pipeline.py @@ -36,6 +36,5 @@ def mlflow_deployment_inference_pipeline( model_deployment_service = prediction_service_loader( pipeline_name=pipeline_name, pipeline_step_name=pipeline_step_name, - running=False, ) predictor(model_deployment_service, inference_data) diff --git a/tests/integration/examples/mlflow/steps/prediction_service_loader_step.py b/tests/integration/examples/mlflow/steps/prediction_service_loader_step.py index 36067d0dfbb..4e4e8427b37 100644 --- a/tests/integration/examples/mlflow/steps/prediction_service_loader_step.py +++ b/tests/integration/examples/mlflow/steps/prediction_service_loader_step.py @@ -24,7 +24,6 @@ def prediction_service_loader( pipeline_name: str, pipeline_step_name: str, running: bool = True, - model_name: str = "model", ) -> MLFlowDeploymentService: """Get the prediction service started by the deployment pipeline. @@ -40,19 +39,13 @@ def prediction_service_loader( model_deployer = MLFlowModelDeployer.get_active_model_deployer() # fetch existing services with same pipeline name, step name and model name - existing_services = model_deployer.find_model_server( - pipeline_name=pipeline_name, - pipeline_step_name=pipeline_step_name, - model_name=model_name, - running=running, - ) + existing_services = model_deployer.find_model_server() if not existing_services: raise RuntimeError( f"No MLflow prediction service deployed by the " f"{pipeline_step_name} step in the {pipeline_name} " - f"pipeline for the '{model_name}' model is currently " - f"running." 
+ f"pipeline" ) return existing_services[0] diff --git a/tests/integration/functional/zen_server/test_zen_server.py b/tests/integration/functional/zen_server/test_zen_server.py index 93290aa22d8..322635c8e8b 100644 --- a/tests/integration/functional/zen_server/test_zen_server.py +++ b/tests/integration/functional/zen_server/test_zen_server.py @@ -17,8 +17,13 @@ import pytest import requests +from zenml.client import Client +from zenml.constants import DEFAULT_USERNAME +from zenml.enums import StoreType from zenml.utils.networking_utils import scan_for_available_port from zenml.zen_server.deploy import ServerDeployer, ServerDeploymentConfig +from zenml.zen_server.utils import server_config +from zenml.zen_stores.rest_zen_store import RestZenStore SERVER_START_STOP_TIMEOUT = 60 @@ -73,3 +78,17 @@ def test_server_up_down(clean_client, mocker): print(line) raise assert deployer.list_servers() == [] + + +def test_rate_limit_is_not_impacted_by_successful_requests(): + zen_store = Client().zen_store + if zen_store.type == StoreType.SQL: + pytest.skip("SQL ZenStore does not support rate limiting.") + + assert Client().active_user.name == DEFAULT_USERNAME + zen_store: RestZenStore = zen_store + + repeat = server_config().login_rate_limit_minute * 2 + for _ in range(repeat): + zen_store.clear_session() + zen_store.session diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 930d8a48929..df139355033 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -69,10 +69,17 @@ WorkspaceResponseBody, WorkspaceResponseMetadata, ) +from zenml.models.v2.core.service import ( + ServiceResponse, + ServiceResponseBody, + ServiceResponseMetadata, +) from zenml.new.pipelines.pipeline import Pipeline from zenml.orchestrators.base_orchestrator import BaseOrchestratorConfig from zenml.orchestrators.local.local_orchestrator import LocalOrchestrator from zenml.pipelines import pipeline +from zenml.services.service_status import ServiceState +from zenml.services.service_type import ServiceType from zenml.stack.stack import Stack from zenml.stack.stack_component import ( StackComponentConfig, @@ -693,3 +700,64 @@ def sample_hub_plugin_response_model() -> HubPluginResponseModel: updated=datetime.now(), requirements=["ploogin==0.0.1", "zenml>=0.1.0"], ) + + +# Test data +service_id = "12345678-1234-5678-1234-567812345678" +service_name = "test_service" +service_type = ServiceType( + type="model-serving", flavor="test_flavor", name="test_name" +) +service_source = "tests.unit.services.test_service.TestService" +admin_state = ServiceState.ACTIVE +config = { + "type": "zenml.services.service.ServiceConfig", + "name": "test_service", + "description": "", + "pipeline_name": "", + "pipeline_step_name": "", + "model_name": "", + "model_version": "", + "service_name": "zenml-test_service", +} +labels = {"label1": "value1", "label2": "value2"} +status = { + "type": "zenml.services.service_status.ServiceStatus", + "state": ServiceState.ACTIVE, + "last_state": ServiceState.INACTIVE, + "last_error": "", +} +endpoint = None +prediction_url = "http://example.com/predict" +health_check_url = "http://example.com/health" +created_time = datetime(2024, 3, 14, 10, 30) +updated_time = datetime(2024, 3, 14, 11, 45) + + +@pytest.fixture +def service_response( + sample_workspace_model, +): + body = ServiceResponseBody( + service_type=service_type, + labels=labels, + created=created_time, + updated=updated_time, + state=admin_state, + ) + metadata = ServiceResponseMetadata( + service_source=service_source, + 
admin_state=admin_state, + config=config, + status=status, + endpoint=endpoint, + prediction_url=prediction_url, + health_check_url=health_check_url, + workspace=sample_workspace_model, + ) + return ServiceResponse( + id=service_id, + name=service_name, + body=body, + metadata=metadata, + ) diff --git a/tests/unit/models/test_service_models.py b/tests/unit/models/test_service_models.py new file mode 100644 index 00000000000..2148d576222 --- /dev/null +++ b/tests/unit/models/test_service_models.py @@ -0,0 +1,130 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from datetime import datetime + +import pytest + +from zenml.constants import STR_FIELD_MAX_LENGTH +from zenml.models import ( + ServiceRequest, + ServiceResponse, + ServiceResponseBody, + ServiceResponseMetadata, +) +from zenml.services.service_status import ServiceState +from zenml.services.service_type import ServiceType + +# Test data +service_id = "12345678-1234-5678-1234-567812345678" +service_name = "test_service" +service_type = ServiceType( + type="model-serving", flavor="test_flavor", name="test_name" +) +service_source = "tests.unit.services.test_service.TestService" +admin_state = ServiceState.ACTIVE +config = { + "type": "zenml.services.service.ServiceConfig", + "name": "test_service", + "description": "", + "pipeline_name": "", + "pipeline_step_name": "", + "model_name": "", + "model_version": "", + "service_name": "zenml-test_service", +} +labels = {"label1": "value1", "label2": "value2"} +status = { + "type": "zenml.services.service_status.ServiceStatus", + "state": ServiceState.ACTIVE, + "last_state": ServiceState.INACTIVE, + "last_error": "", +} +endpoint = None +prediction_url = "http://example.com/predict" +health_check_url = "http://example.com/health" +created_time = datetime(2023, 3, 14, 10, 30) +updated_time = datetime(2023, 3, 14, 11, 45) + + +@pytest.fixture +def service_response( + sample_workspace_model, +): + body = ServiceResponseBody( + service_type=service_type, + labels=labels, + created=created_time, + updated=updated_time, + state=admin_state, + ) + metadata = ServiceResponseMetadata( + service_source=service_source, + admin_state=admin_state, + config=config, + status=status, + endpoint=endpoint, + prediction_url=prediction_url, + health_check_url=health_check_url, + workspace=sample_workspace_model, + ) + return ServiceResponse( + id=service_id, + name=service_name, + body=body, + metadata=metadata, + ) + + +def test_service_response_properties(service_response): + assert service_response.service_type == service_type + assert service_response.labels == labels + assert service_response.service_source == service_source + assert service_response.config == config + assert service_response.status == status + assert service_response.endpoint == endpoint + assert service_response.created == created_time + assert service_response.updated == updated_time + assert service_response.admin_state == admin_state + assert service_response.prediction_url == 
prediction_url + assert service_response.health_check_url == health_check_url + assert service_response.state == admin_state + + +def test_service_request_name_too_long(): + # Test that the service name cannot be longer than the maximum allowed length + long_name = "a" * (STR_FIELD_MAX_LENGTH + 1) + with pytest.raises(ValueError): + ServiceRequest( + name=long_name, + service_type=ServiceType( + type="model-serving", flavor="test_flavor", name="test_name" + ), + service_source="path.to.ServiceClass", + admin_state=ServiceState.ACTIVE, + config={"param1": "value1"}, + ) + + +def test_service_request_invalid_service_type(): + # Test that an invalid service type raises an error + invalid_service_type = "invalid_type" + with pytest.raises(ValueError): + ServiceRequest( + name="test_service", + service_type=invalid_service_type, + service_source="path.to.ServiceClass", + admin_state=ServiceState.ACTIVE, + config={"param1": "value1"}, + ) diff --git a/tests/unit/services/__init__.py b/tests/unit/services/__init__.py new file mode 100644 index 00000000000..cd90a82cfc2 --- /dev/null +++ b/tests/unit/services/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. diff --git a/tests/unit/services/test_service.py b/tests/unit/services/test_service.py new file mode 100644 index 00000000000..b0875e9d62d --- /dev/null +++ b/tests/unit/services/test_service.py @@ -0,0 +1,112 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+from typing import Generator, Optional, Tuple +from uuid import UUID + +import pytest + +from zenml.services import ( + BaseService, + ServiceConfig, + ServiceState, + ServiceStatus, +) +from zenml.services.service import ZENM_ENDPOINT_PREFIX + + +# Create a concrete subclass of BaseService +class TestService(BaseService): + """Test service class for testing BaseService.""" + + SERVICE_TYPE = { + "type": "model-serving", + "flavor": "test_flavor", + "name": "test_name", + } + + @property + def is_running(self): + return True + + @property + def is_stopped(self): + return not self.is_running + + @property + def is_failed(self): + return False + + def check_status(self) -> Tuple[ServiceState, str]: + return ServiceState.ACTIVE, "Service is running" + + def get_logs( + self, follow: bool = False, tail: Optional[int] = None + ) -> Generator[str, bool, None]: + return (f"log line {i}" for i in range(5)) + + +# Modify the base_service fixture to use the TestService subclass +@pytest.fixture +def base_service(): + return TestService( + uuid=UUID("12345678-1234-5678-1234-567812345678"), + admin_state=ServiceState.ACTIVE, + config=ServiceConfig(name="test_service", param1="value1", param2=2), + status=ServiceStatus( + state=ServiceState.ACTIVE, + last_error="", + last_status=ServiceState.INACTIVE, + ), + endpoint=None, + ) + + +# Update the test_from_model to handle the case when service_source is missing +def test_from_model(service_response): + service = BaseService.from_model(service_response) + assert isinstance(service, TestService) + assert service.uuid == service_response.id + assert service.admin_state == service_response.admin_state + assert service.config == service_response.config + assert service.status == service_response.status + assert service.SERVICE_TYPE["type"] == service_response.service_type.type + assert ( + service.SERVICE_TYPE["flavor"] == service_response.service_type.flavor + ) + assert service.endpoint == service_response.endpoint + + +def test_update_status(base_service, monkeypatch): + def mock_check_status(self): + return ServiceState.ACTIVE, "Service is running" + + monkeypatch.setattr(BaseService, "check_status", mock_check_status) + base_service.update_status() + + assert base_service.status.state == ServiceState.ACTIVE + assert base_service.status.last_error == "Service is running" + + +def test_service_config_init_without_name_or_model_name(): + """Test initialization without name or model_name.""" + with pytest.raises(ValueError) as excinfo: + ServiceConfig() + assert "Either 'name' or 'model_name' must be set." in str(excinfo.value) + + +def test_service_config_init_with_name(): + """Test initialization with name.""" + config = ServiceConfig(name="test-service") + assert config.name == "test-service" + assert config.service_name == f"{ZENM_ENDPOINT_PREFIX}test-service" diff --git a/tests/unit/test_constants.py b/tests/unit/test_constants.py index 1a5e76faa3b..78ab52076fa 100644 --- a/tests/unit/test_constants.py +++ b/tests/unit/test_constants.py @@ -12,19 +12,50 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. 
-import os -from zenml.constants import handle_int_env_var +from zenml.constants import handle_int_env_var, handle_json_env_var -def test_handle_int_env_var(): +def test_handle_int_env_var(monkeypatch): """Check handle_int_env_var in all cases.""" env_var = "ZENML_TEST_HANDLE_INT_ENV_VAR" # check value error (when it can't be converted to int) - os.environ[env_var] = "test" + monkeypatch.setenv(env_var, "test") assert 0 == handle_int_env_var(env_var, 0) # check if it isn't there (in case it doesn't exist) - del os.environ[env_var] + monkeypatch.delenv(env_var, raising=False) assert 0 == handle_int_env_var(env_var, 0) + + +def test_handle_json_env_var(monkeypatch): + # Given an environment variable that is json + monkeypatch.setenv("TEST_VAR", '["hello", "world"]') + + # When we ask for that variable and expect it to be a List + result = handle_json_env_var("TEST_VAR", expected_type=list) + + # Then we should get the list ["hello", "world"] + assert result == ["hello", "world"] + + # Given an environment variable that is not json + monkeypatch.setenv("TEST_VAR", "hello world") + + # When we ask for that variable and expect it to be a List + result = handle_json_env_var("TEST_VAR", expected_type=list) + + # Then we should get an empty list (the default) + assert result == [] + + # Given an environment variable that is json but not the expected type + monkeypatch.setenv("TEST_VAR", '{"hello": "world"}') + + # When we ask for that variable and expect it to be a List + result = handle_json_env_var("TEST_VAR", expected_type=list) + + # Then we should get an empty list (the default) + assert result == [] + + # Unset environment variable + monkeypatch.delenv("TEST_VAR", raising=False)
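
The `handle_json_env_var` tests added above pin down the helper's contract: the environment variable value is parsed as JSON and only returned when it matches `expected_type`; malformed JSON or a type mismatch falls back to a default (an empty list in these tests). Below is a small usage sketch based only on the behavior those tests assert; the variable name and values are illustrative, not taken from ZenML.

```python
import os

from zenml.constants import handle_json_env_var

# A well-formed JSON list is parsed and returned as-is.
os.environ["MY_PLUGIN_SOURCES"] = '["pkg.module.ClassA", "pkg.module.ClassB"]'
assert handle_json_env_var("MY_PLUGIN_SOURCES", expected_type=list) == [
    "pkg.module.ClassA",
    "pkg.module.ClassB",
]

# Malformed JSON, or JSON of the wrong type, falls back to the default
# (an empty list when expected_type=list, per the tests above).
os.environ["MY_PLUGIN_SOURCES"] = "not json"
assert handle_json_env_var("MY_PLUGIN_SOURCES", expected_type=list) == []

os.environ["MY_PLUGIN_SOURCES"] = '{"hello": "world"}'
assert handle_json_env_var("MY_PLUGIN_SOURCES", expected_type=list) == []
```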
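
Two hunks above apply the same fix for unstable pagination: `backup_database_to_storage` in `migrations/utils.py` and `filter_and_paginate` in `sql_zen_store.py` both append the unique `id` column to the `ORDER BY` clause, so rows that share a `created` timestamp cannot be duplicated or dropped across LIMIT/OFFSET pages. Here is a minimal, self-contained sketch of that pattern, using an illustrative SQLAlchemy table rather than the actual ZenML schema.

```python
import sqlalchemy as sa

metadata = sa.MetaData()

# Illustrative table: many rows may share the same "created" timestamp.
items = sa.Table(
    "items",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("created", sa.DateTime, nullable=False),
)


def fetch_page(conn: sa.engine.Connection, page: int, size: int):
    """Fetch one page of rows in a deterministic order.

    Sorting by "created" alone is ambiguous when timestamps collide, so
    subsequent pages could repeat or skip rows; appending the unique "id"
    column breaks those ties and makes the ordering stable.
    """
    stmt = (
        sa.select(items)
        .order_by(sa.asc(items.c.created), sa.asc(items.c.id))
        .limit(size)
        .offset(page * size)
    )
    return conn.execute(stmt).fetchall()
```

The same reasoning is why `order_by` in the backup utility became a list that is splatted into `.order_by(*order_by)`.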
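
The new `ServiceSchema` persists `config`, `labels`, `status`, and `endpoint` as binary columns written with `dict_to_bytes` and read back via `json.loads(base64.b64decode(...))` in `to_model`. The decoding side implies that `dict_to_bytes` amounts to a base64-encoded JSON dump; the round trip below is a sketch under that assumption, with hypothetical helper names rather than the real `zenml.utils.dict_utils` API.

```python
import base64
import json
from typing import Any, Dict, Optional


def encode_dict(data: Optional[Dict[str, Any]]) -> Optional[bytes]:
    """JSON-dump a dict and base64-encode it, mirroring how the service
    schema appears to store config/labels/status/endpoint (None stays None)."""
    if data is None:
        return None
    return base64.b64encode(json.dumps(data).encode("utf-8"))


def decode_dict(raw: Optional[bytes]) -> Optional[Dict[str, Any]]:
    """Reverse of encode_dict, mirroring the json.loads(base64.b64decode(...))
    calls in ServiceSchema.to_model."""
    if raw is None:
        return None
    return json.loads(base64.b64decode(raw).decode("utf-8"))


# Round trip for a service config payload:
config = {"name": "test_service", "pipeline_name": "demo_pipeline"}
assert decode_dict(encode_dict(config)) == config
```

Storing the serialized config this way is also what makes the duplicate check in `_fail_if_service_with_config_exists` possible: it compares the base64-encoded JSON bytes directly instead of re-parsing every stored configuration.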