From 0ed890d24f6797359792e0ba9b61052dc3c9a38d Mon Sep 17 00:00:00 2001
From: Justin Hayes <52832301+justinh-rahb@users.noreply.github.com>
Date: Sun, 30 Jun 2024 01:23:57 -0400
Subject: [PATCH 1/9] Add MLX Manifold Pipeline

---
 .../providers/mlx_manifold_pipeline.py | 185 ++++++++++++++++++
 1 file changed, 185 insertions(+)
 create mode 100644 examples/pipelines/providers/mlx_manifold_pipeline.py

diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py
new file mode 100644
index 00000000..6c8aa16b
--- /dev/null
+++ b/examples/pipelines/providers/mlx_manifold_pipeline.py
@@ -0,0 +1,185 @@
+"""
+title: MLX Manifold Pipeline
+author: justinh-rahb
+date: 2024-05-28
+version: 2.0
+license: MIT
+description: A pipeline for generating text using Apple MLX Framework with dynamic model loading.
+requirements: requests, mlx-lm, huggingface-hub, psutil
+environment_variables: MLX_HOST, MLX_PORT, MLX_SUBPROCESS
+"""
+
+from typing import List, Union, Generator, Iterator
+from schemas import OpenAIChatMessage
+from pydantic import BaseModel
+import requests
+import os
+import subprocess
+import logging
+from huggingface_hub import login
+import time
+import psutil
+
+class Pipeline:
+    class Valves(BaseModel):
+        MLX_STOP: str = "[INST]"
+        HUGGINGFACE_TOKEN: str = ""
+        MLX_MODEL_PATTERN: str = "mistralai"
+        MLX_DEFAULT_MODEL: str = "mistralai/Mistral-7B-Instruct-v0.3"
+
+    def __init__(self):
+        self.type = "manifold"
+        self.id = "mlx"
+        self.name = "MLX/"
+
+        self.valves = self.Valves()
+        self.update_valves()
+
+        self.host = os.getenv("MLX_HOST", "localhost")
+        self.port = os.getenv("MLX_PORT", "8080")
+        self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true"
+
+        self.models = self.get_mlx_models()
+        self.current_model = None
+        self.server_process = None
+
+        if self.subprocess:
+            self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL)
+
+    def update_valves(self):
+        if self.valves.HUGGINGFACE_TOKEN:
+            login(self.valves.HUGGINGFACE_TOKEN)
+        self.stop_sequence = self.valves.MLX_STOP.split(",")
+
+    def get_mlx_models(self):
+        try:
+            cmd = [
+                'mlx_lm.manage',
+                '--scan',
+                '--pattern', self.valves.MLX_MODEL_PATTERN,
+            ]
+            result = subprocess.run(cmd, capture_output=True, text=True)
+            lines = result.stdout.strip().split('\n')
+
+            # Skip header lines and the line with dashes
+            content_lines = [line for line in lines if line and not line.startswith('-')]
+
+            models = []
+            for line in content_lines[2:]:  # Skip the first two lines (header)
+                parts = line.split()
+                if len(parts) >= 2:
+                    repo_id = parts[0]
+                    models.append({
+                        "id": f"{repo_id.split('/')[-1].lower()}",
+                        "name": repo_id
+                    })
+            if not models:
+                # Add default model if no models are found
+                models.append({
+                    "id": f"{self.valves.MLX_DEFAULT_MODEL.split('/')[-1].lower()}",  # Same id format as scanned models
+                    "name": self.valves.MLX_DEFAULT_MODEL
+                })
+            return models
+        except Exception as e:
+            logging.error(f"Error fetching MLX models: {e}")
+            # Return default model on error
+            return [{
+                "id": f"{self.valves.MLX_DEFAULT_MODEL.split('/')[-1].lower()}",  # Same id format as scanned models
+                "name": self.valves.MLX_DEFAULT_MODEL
+            }]
+
+    def pipelines(self) -> List[dict]:
+        return self.models
+
+    def start_mlx_server(self, model_name):
+        model_id = f"mlx.{model_name.split('/')[-1].lower()}"
+        if self.current_model == model_id and self.server_process and self.server_process.poll() is None:
+            logging.info(f"MLX server already running with model {model_name}")
+            return
+
+        self.stop_mlx_server()
+
+        if not os.getenv("MLX_PORT"):
+            self.port = 
self.find_free_port() + command = f"mlx_lm.server --model {model_name} --port {self.port}" + logging.info(f"Starting MLX server with command: {command}") + self.server_process = subprocess.Popen(command, shell=True) + self.current_model = model_id + logging.info(f"Started MLX server for model {model_name} on port {self.port}") + time.sleep(5) # Give the server some time to start up + + def stop_mlx_server(self): + if self.server_process: + try: + process = psutil.Process(self.server_process.pid) + for proc in process.children(recursive=True): + proc.terminate() + process.terminate() + process.wait(timeout=10) # Wait for the process to terminate + except psutil.NoSuchProcess: + pass # Process already terminated + except psutil.TimeoutExpired: + logging.warning("Timeout while terminating MLX server process") + finally: + self.server_process = None + self.current_model = None + logging.info(f"Stopped MLX server on port {self.port}") + + def find_free_port(self): + import socket + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + port = s.getsockname()[1] + s.close() + return port + + async def on_startup(self): + logging.info(f"on_startup:{__name__}") + + async def on_shutdown(self): + if self.subprocess: + self.stop_mlx_server() + + async def on_valves_updated(self): + self.update_valves() + self.models = self.get_mlx_models() + if self.subprocess: + self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL) + + def pipe( + self, user_message: str, model_id: str, messages: List[dict], body: dict + ) -> Union[str, Generator, Iterator]: + logging.info(f"pipe:{__name__}") + + if model_id != self.current_model: + model_name = next((model['name'] for model in self.models if model['id'] == model_id), self.valves.MLX_DEFAULT_MODEL) + self.start_mlx_server(model_name) + + url = f"http://{self.host}:{self.port}/v1/chat/completions" + headers = {"Content-Type": "application/json"} + + max_tokens = body.get("max_tokens", 4096) + temperature = body.get("temperature", 0.8) + repeat_penalty = body.get("repeat_penalty", 1.0) + + payload = { + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + "repetition_penalty": repeat_penalty, + "stop": self.stop_sequence, + "stream": body.get("stream", False), + } + + try: + r = requests.post( + url, headers=headers, json=payload, stream=body.get("stream", False) + ) + r.raise_for_status() + + if body.get("stream", False): + return r.iter_lines() + else: + return r.json() + except Exception as e: + return f"Error: {e}" \ No newline at end of file From 72e933cd6b40f65e4cea349fd6e89779fca3131a Mon Sep 17 00:00:00 2001 From: Justin Hayes <52832301+justinh-rahb@users.noreply.github.com> Date: Mon, 1 Jul 2024 10:59:31 -0400 Subject: [PATCH 2/9] Fix: more robust MLX model switching --- .../providers/mlx_manifold_pipeline.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py index 6c8aa16b..51181a51 100644 --- a/examples/pipelines/providers/mlx_manifold_pipeline.py +++ b/examples/pipelines/providers/mlx_manifold_pipeline.py @@ -9,6 +9,7 @@ environment_variables: MLX_HOST, MLX_PORT, MLX_SUBPROCESS """ +import argparse from typing import List, Union, Generator, Iterator from schemas import OpenAIChatMessage from pydantic import BaseModel @@ -19,13 +20,16 @@ from huggingface_hub import login import time import psutil +import json class Pipeline: class Valves(BaseModel): - MLX_STOP: str = 
"[INST]" + MLX_STOP: str = "<|start_header_id|>,<|end_header_id|>,<|eot_id|>" HUGGINGFACE_TOKEN: str = "" - MLX_MODEL_PATTERN: str = "mistralai" - MLX_DEFAULT_MODEL: str = "mistralai/Mistral-7B-Instruct-v0.3" + MLX_MODEL_PATTERN: str = "meta-llama" + MLX_DEFAULT_MODEL: str = "meta-llama/Meta-Llama-3-8B-Instruct" + MLX_CHAT_TEMPLATE: str = "" + MLX_USE_DEFAULT_CHAT_TEMPLATE: bool = False def __init__(self): self.type = "manifold" @@ -101,9 +105,20 @@ def start_mlx_server(self, model_name): if not os.getenv("MLX_PORT"): self.port = self.find_free_port() - command = f"mlx_lm.server --model {model_name} --port {self.port}" - logging.info(f"Starting MLX server with command: {command}") - self.server_process = subprocess.Popen(command, shell=True) + + command = [ + "mlx_lm.server", + "--model", model_name, + "--port", str(self.port), + ] + + if self.valves.MLX_CHAT_TEMPLATE: + command.extend(["--chat-template", self.valves.MLX_CHAT_TEMPLATE]) + elif self.valves.MLX_USE_DEFAULT_CHAT_TEMPLATE: + command.append("--use-default-chat-template") + + logging.info(f"Starting MLX server with command: {' '.join(command)}") + self.server_process = subprocess.Popen(command) self.current_model = model_id logging.info(f"Started MLX server for model {model_name} on port {self.port}") time.sleep(5) # Give the server some time to start up From 83db4c2035269eb4747a32bd81762db8aeed21d6 Mon Sep 17 00:00:00 2001 From: Justin Hayes <52832301+justinh-rahb@users.noreply.github.com> Date: Mon, 1 Jul 2024 11:13:40 -0400 Subject: [PATCH 3/9] Remove `MLX_SUBPROCESS=True` option --- .../providers/mlx_manifold_pipeline.py | 45 ++++++++++++------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py index 51181a51..e8c27d11 100644 --- a/examples/pipelines/providers/mlx_manifold_pipeline.py +++ b/examples/pipelines/providers/mlx_manifold_pipeline.py @@ -6,10 +6,8 @@ license: MIT description: A pipeline for generating text using Apple MLX Framework with dynamic model loading. 
requirements: requests, mlx-lm, huggingface-hub, psutil -environment_variables: MLX_HOST, MLX_PORT, MLX_SUBPROCESS """ -import argparse from typing import List, Union, Generator, Iterator from schemas import OpenAIChatMessage from pydantic import BaseModel @@ -32,30 +30,35 @@ class Valves(BaseModel): MLX_USE_DEFAULT_CHAT_TEMPLATE: bool = False def __init__(self): + # Pipeline identification self.type = "manifold" self.id = "mlx" self.name = "MLX/" + # Initialize valves and update them self.valves = self.Valves() self.update_valves() - self.host = os.getenv("MLX_HOST", "localhost") - self.port = os.getenv("MLX_PORT", "8080") - self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true" + # Server configuration + self.host = "localhost" # Always use localhost for security + self.port = None # Port will be dynamically assigned + # Model management self.models = self.get_mlx_models() self.current_model = None self.server_process = None - if self.subprocess: - self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL) + # Start the MLX server with the default model + self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL) def update_valves(self): + """Update pipeline configuration based on valve settings.""" if self.valves.HUGGINGFACE_TOKEN: login(self.valves.HUGGINGFACE_TOKEN) self.stop_sequence = self.valves.MLX_STOP.split(",") def get_mlx_models(self): + """Fetch available MLX models based on the specified pattern.""" try: cmd = [ 'mlx_lm.manage', @@ -65,11 +68,10 @@ def get_mlx_models(self): result = subprocess.run(cmd, capture_output=True, text=True) lines = result.stdout.strip().split('\n') - # Skip header lines and the line with dashes content_lines = [line for line in lines if line and not line.startswith('-')] models = [] - for line in content_lines[2:]: # Skip the first two lines (header) + for line in content_lines[2:]: # Skip header lines parts = line.split() if len(parts) >= 2: repo_id = parts[0] @@ -93,9 +95,11 @@ def get_mlx_models(self): }] def pipelines(self) -> List[dict]: + """Return the list of available models as pipelines.""" return self.models def start_mlx_server(self, model_name): + """Start the MLX server with the specified model.""" model_id = f"mlx.{model_name.split('/')[-1].lower()}" if self.current_model == model_id and self.server_process and self.server_process.poll() is None: logging.info(f"MLX server already running with model {model_name}") @@ -103,8 +107,7 @@ def start_mlx_server(self, model_name): self.stop_mlx_server() - if not os.getenv("MLX_PORT"): - self.port = self.find_free_port() + self.port = self.find_free_port() command = [ "mlx_lm.server", @@ -112,6 +115,7 @@ def start_mlx_server(self, model_name): "--port", str(self.port), ] + # Add chat template options if specified if self.valves.MLX_CHAT_TEMPLATE: command.extend(["--chat-template", self.valves.MLX_CHAT_TEMPLATE]) elif self.valves.MLX_USE_DEFAULT_CHAT_TEMPLATE: @@ -124,6 +128,7 @@ def start_mlx_server(self, model_name): time.sleep(5) # Give the server some time to start up def stop_mlx_server(self): + """Stop the currently running MLX server.""" if self.server_process: try: process = psutil.Process(self.server_process.pid) @@ -138,9 +143,11 @@ def stop_mlx_server(self): finally: self.server_process = None self.current_model = None - logging.info(f"Stopped MLX server on port {self.port}") + self.port = None + logging.info("Stopped MLX server") def find_free_port(self): + """Find and return a free port to use for the MLX server.""" import socket s = socket.socket(socket.AF_INET, 
socket.SOCK_STREAM) s.bind(("", 0)) @@ -149,23 +156,26 @@ def find_free_port(self): return port async def on_startup(self): + """Perform any necessary startup operations.""" logging.info(f"on_startup:{__name__}") async def on_shutdown(self): - if self.subprocess: - self.stop_mlx_server() + """Perform cleanup operations on shutdown.""" + self.stop_mlx_server() async def on_valves_updated(self): + """Handle updates to the pipeline configuration.""" self.update_valves() self.models = self.get_mlx_models() - if self.subprocess: - self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL) + self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL) def pipe( self, user_message: str, model_id: str, messages: List[dict], body: dict ) -> Union[str, Generator, Iterator]: + """Process a request through the MLX pipeline.""" logging.info(f"pipe:{__name__}") + # Switch model if necessary if model_id != self.current_model: model_name = next((model['name'] for model in self.models if model['id'] == model_id), self.valves.MLX_DEFAULT_MODEL) self.start_mlx_server(model_name) @@ -173,6 +183,7 @@ def pipe( url = f"http://{self.host}:{self.port}/v1/chat/completions" headers = {"Content-Type": "application/json"} + # Prepare the payload for the MLX server max_tokens = body.get("max_tokens", 4096) temperature = body.get("temperature", 0.8) repeat_penalty = body.get("repeat_penalty", 1.0) @@ -187,11 +198,13 @@ def pipe( } try: + # Send request to MLX server r = requests.post( url, headers=headers, json=payload, stream=body.get("stream", False) ) r.raise_for_status() + # Return streamed response or full JSON response if body.get("stream", False): return r.iter_lines() else: From f1dab7a6aaac19a8646cc1608ea07738c1afc769 Mon Sep 17 00:00:00 2001 From: Justin Hayes <52832301+justinh-rahb@users.noreply.github.com> Date: Mon, 1 Jul 2024 11:32:12 -0400 Subject: [PATCH 4/9] Cleanup unused imports --- examples/pipelines/providers/mlx_manifold_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py index e8c27d11..0e6822bf 100644 --- a/examples/pipelines/providers/mlx_manifold_pipeline.py +++ b/examples/pipelines/providers/mlx_manifold_pipeline.py @@ -12,13 +12,11 @@ from schemas import OpenAIChatMessage from pydantic import BaseModel import requests -import os import subprocess import logging from huggingface_hub import login import time import psutil -import json class Pipeline: class Valves(BaseModel): From c44fb71930dec976e415448ac707485abb4839a1 Mon Sep 17 00:00:00 2001 From: Justin Hayes <52832301+justinh-rahb@users.noreply.github.com> Date: Mon, 1 Jul 2024 12:29:42 -0400 Subject: [PATCH 5/9] Update default MLX model --- examples/pipelines/providers/mlx_manifold_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py index 0e6822bf..5490f9ef 100644 --- a/examples/pipelines/providers/mlx_manifold_pipeline.py +++ b/examples/pipelines/providers/mlx_manifold_pipeline.py @@ -22,8 +22,8 @@ class Pipeline: class Valves(BaseModel): MLX_STOP: str = "<|start_header_id|>,<|end_header_id|>,<|eot_id|>" HUGGINGFACE_TOKEN: str = "" - MLX_MODEL_PATTERN: str = "meta-llama" - MLX_DEFAULT_MODEL: str = "meta-llama/Meta-Llama-3-8B-Instruct" + MLX_MODEL_PATTERN: str = "mlx-community" + MLX_DEFAULT_MODEL: str = "mlx-community/Meta-Llama-3-8B-Instruct-8bit" MLX_CHAT_TEMPLATE: str = "" 
MLX_USE_DEFAULT_CHAT_TEMPLATE: bool = False From f7ac26be13e1cbece44d34a644be6e413ac1a026 Mon Sep 17 00:00:00 2001 From: Justin Hayes <52832301+justinh-rahb@users.noreply.github.com> Date: Mon, 1 Jul 2024 23:24:32 -0400 Subject: [PATCH 6/9] cleanup: MLX valves --- examples/pipelines/providers/mlx_manifold_pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py index 5490f9ef..0ddb83cd 100644 --- a/examples/pipelines/providers/mlx_manifold_pipeline.py +++ b/examples/pipelines/providers/mlx_manifold_pipeline.py @@ -20,12 +20,12 @@ class Pipeline: class Valves(BaseModel): - MLX_STOP: str = "<|start_header_id|>,<|end_header_id|>,<|eot_id|>" - HUGGINGFACE_TOKEN: str = "" - MLX_MODEL_PATTERN: str = "mlx-community" + MLX_MODEL_FILTER: str = "mlx-community" MLX_DEFAULT_MODEL: str = "mlx-community/Meta-Llama-3-8B-Instruct-8bit" + MLX_STOP: str = "<|start_header_id|>,<|end_header_id|>,<|eot_id|>" MLX_CHAT_TEMPLATE: str = "" MLX_USE_DEFAULT_CHAT_TEMPLATE: bool = False + HUGGINGFACE_TOKEN: str = "" def __init__(self): # Pipeline identification @@ -61,7 +61,7 @@ def get_mlx_models(self): cmd = [ 'mlx_lm.manage', '--scan', - '--pattern', self.valves.MLX_MODEL_PATTERN, + '--pattern', self.valves.MLX_MODEL_FILTER, ] result = subprocess.run(cmd, capture_output=True, text=True) lines = result.stdout.strip().split('\n') From 47de6f1bb13ac99025dfd99a341c7f1b846799e7 Mon Sep 17 00:00:00 2001 From: Justin Hayes <52832301+justinh-rahb@users.noreply.github.com> Date: Wed, 3 Jul 2024 10:54:53 -0400 Subject: [PATCH 7/9] Update mlx_manifold_pipeline.py --- examples/pipelines/providers/mlx_manifold_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py index 0ddb83cd..86dfb432 100644 --- a/examples/pipelines/providers/mlx_manifold_pipeline.py +++ b/examples/pipelines/providers/mlx_manifold_pipeline.py @@ -20,8 +20,8 @@ class Pipeline: class Valves(BaseModel): - MLX_MODEL_FILTER: str = "mlx-community" MLX_DEFAULT_MODEL: str = "mlx-community/Meta-Llama-3-8B-Instruct-8bit" + MLX_MODEL_FILTER: str = "mlx-community" MLX_STOP: str = "<|start_header_id|>,<|end_header_id|>,<|eot_id|>" MLX_CHAT_TEMPLATE: str = "" MLX_USE_DEFAULT_CHAT_TEMPLATE: bool = False @@ -208,4 +208,4 @@ def pipe( else: return r.json() except Exception as e: - return f"Error: {e}" \ No newline at end of file + return f"Error: {e}" From 54e83986e3d5ccfacbc42028ff4a01aaae180fd3 Mon Sep 17 00:00:00 2001 From: Justin Hayes <52832301+justinh-rahb@users.noreply.github.com> Date: Wed, 3 Jul 2024 11:06:22 -0400 Subject: [PATCH 8/9] Fix: make some MLX valves optional --- examples/pipelines/providers/mlx_manifold_pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py index 86dfb432..02b30ca1 100644 --- a/examples/pipelines/providers/mlx_manifold_pipeline.py +++ b/examples/pipelines/providers/mlx_manifold_pipeline.py @@ -23,9 +23,9 @@ class Valves(BaseModel): MLX_DEFAULT_MODEL: str = "mlx-community/Meta-Llama-3-8B-Instruct-8bit" MLX_MODEL_FILTER: str = "mlx-community" MLX_STOP: str = "<|start_header_id|>,<|end_header_id|>,<|eot_id|>" - MLX_CHAT_TEMPLATE: str = "" - MLX_USE_DEFAULT_CHAT_TEMPLATE: bool = False - HUGGINGFACE_TOKEN: str = "" + 
MLX_CHAT_TEMPLATE: str | None = None
+        MLX_USE_DEFAULT_CHAT_TEMPLATE: bool | None = None
+        HUGGINGFACE_TOKEN: str | None = None
 
     def __init__(self):

From f34771c647d3ccc81109674f3a42cbeedc1c4e00 Mon Sep 17 00:00:00 2001
From: Justin Hayes <52832301+justinh-rahb@users.noreply.github.com>
Date: Wed, 3 Jul 2024 11:13:07 -0400
Subject: [PATCH 9/9] Fix: default MLX_USE_DEFAULT_CHAT_TEMPLATE to False

---
 examples/pipelines/providers/mlx_manifold_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py
index 02b30ca1..26260902 100644
--- a/examples/pipelines/providers/mlx_manifold_pipeline.py
+++ b/examples/pipelines/providers/mlx_manifold_pipeline.py
@@ -24,7 +24,7 @@ class Valves(BaseModel):
         MLX_MODEL_FILTER: str = "mlx-community"
         MLX_STOP: str = "<|start_header_id|>,<|end_header_id|>,<|eot_id|>"
         MLX_CHAT_TEMPLATE: str | None = None
-        MLX_USE_DEFAULT_CHAT_TEMPLATE: bool | None = None
+        MLX_USE_DEFAULT_CHAT_TEMPLATE: bool | None = False
         HUGGINGFACE_TOKEN: str | None = None
 
     def __init__(self):
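
A note for anyone reviewing or testing the final state of this series: the sketch below reproduces, outside of Open WebUI, the request that `pipe()` sends to the spawned `mlx_lm.server`, using the same payload keys and the Llama 3 stop sequences from the default `MLX_STOP` valve. It assumes you started the server yourself on port 8080 (e.g. `mlx_lm.server --model mlx-community/Meta-Llama-3-8B-Instruct-8bit --port 8080`) -- the pipeline itself assigns a free port dynamically -- and that the server returns the usual OpenAI-compatible response shape. A minimal sketch, not part of the patches:

    import requests

    # Mirrors the payload built in Pipeline.pipe(); port 8080 is an assumption,
    # since the pipeline normally picks a free port at runtime.
    payload = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 4096,
        "temperature": 0.8,
        "repetition_penalty": 1.0,
        "stop": ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"],
        "stream": False,
    }

    r = requests.post(
        "http://localhost:8080/v1/chat/completions",
        headers={"Content-Type": "application/json"},
        json=payload,
    )
    r.raise_for_status()
    # Assumes an OpenAI-compatible response body
    print(r.json()["choices"][0]["message"]["content"])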