diff --git a/Dockerfile.base b/Dockerfile.base
index 9e0db298..7f5c97c3 100644
--- a/Dockerfile.base
+++ b/Dockerfile.base
@@ -1,4 +1,4 @@
-FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
+FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
 
 ENV INSTALL_OPTIONAL=TRUE
 ENV MAX_JOBS=8
@@ -13,24 +13,28 @@ RUN DEBIAN_FRONTEND="noninteractive" TZ=Etc/UTC apt-get install -y \
     ruby-full \
     ruby-bundler \
     build-essential \
-    cmake \
     pkg-config \
     libicu-dev \
     zlib1g-dev \
     libcurl4-openssl-dev \
     libssl-dev \
     && rm -rf /var/lib/{apt,dpkg,cache,log}
 
+RUN DEBIAN_FRONTEND="noninteractive" TZ=Etc/UTC apt remove cmake -y
+RUN pip install cmake --upgrade
+
 RUN git clone https://github.com/smallcloudai/linguist.git /tmp/linguist \
     && cd /tmp/linguist \
     && bundle install \
     && rake build_gem
 ENV PATH="${PATH}:/tmp/linguist/bin"
 
-RUN pip install --no-cache-dir torch==2.3.0 --index-url https://download.pytorch.org/whl/cu118
-RUN pip install --no-cache-dir xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu118
+RUN pip install --no-cache-dir torch==2.5.0
+RUN pip install --no-cache-dir xformers==v0.0.28.post2
 RUN pip install ninja
-RUN VLLM_INSTALL_PUNICA_KERNELS=1 pip install -v --no-build-isolation git+https://github.com/smallcloudai/vllm@refact_v0.4.2_06052024
+RUN pip install setuptools_scm
+ENV CMAKE_ARGS="-DLLAMA_CUBLAS=on -DCMAKE_CUDA_ARCHITECTURES=60;61;70;75;80;86;89;90+PTX"
+RUN pip install -v --no-build-isolation git+https://github.com/smallcloudai/vllm@refact_v0.6.3_2adb440
 
-# there is no prebuild auto-gptq with torch 2.3.0 support
+# there is no prebuild auto-gptq with torch 2.5.0 support
 ENV TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX"
 RUN BUILD_CUDA_EXT=1 pip install -v --no-build-isolation git+https://github.com/PanQiWei/AutoGPTQ.git@v0.7.1
diff --git a/README.md b/README.md
index ce9cedbd..d6a965fc 100644
--- a/README.md
+++ b/README.md
@@ -103,21 +103,35 @@ Extensions > Refact.ai Assistant > Settings > Infurl
 
 ## Supported models
 
-| Model | Completion | Chat | Fine-tuning | [Deprecated](## "Will be removed in next versions") |
-|-------|------------|------|-------------|-----------------------------------------------------|
-| [Refact/1.6B](https://huggingface.co/smallcloudai/Refact-1_6B-fim) | + |  | + |  |
-| [starcoder2/3b/base](https://huggingface.co/bigcode/starcoder2-3b) | + |  | + |  |
-| [starcoder2/7b/base](https://huggingface.co/bigcode/starcoder2-7b) | + |  | + |  |
-| [starcoder2/15b/base](https://huggingface.co/bigcode/starcoder2-15b) | + |  | + |  |
-| [deepseek-coder/1.3b/base](https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-base) | + |  | + |  |
-| [deepseek-coder/5.7b/mqa-base](https://huggingface.co/deepseek-ai/deepseek-coder-5.7bmqa-base) | + |  | + |  |
-| [magicoder/6.7b](https://huggingface.co/TheBloke/Magicoder-S-DS-6.7B-GPTQ) |  | + |  |  |
-| [mistral/7b/instruct-v0.1](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GPTQ) |  | + |  |  |
-| [mixtral/8x7b/instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) |  | + |  |  |
-| [deepseek-coder/6.7b/instruct](https://huggingface.co/TheBloke/deepseek-coder-6.7B-instruct-GPTQ) |  | + |  |  |
-| [deepseek-coder/33b/instruct](https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct) |  | + |  |  |
-| [stable/3b/code](https://huggingface.co/stabilityai/stable-code-3b) | + |  |  |  |
-| [llama3/8b/instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |  | + |  |  |
+| Model | Completion | Chat | Fine-tuning | [Deprecated](## "Will be removed in next versions") |
+|-------|------------|------|-------------|-----------------------------------------------------|
+| [Refact/1.6B](https://huggingface.co/smallcloudai/Refact-1_6B-fim) | + |  | + |  |
+| [starcoder2/3b/base](https://huggingface.co/bigcode/starcoder2-3b) | + |  | + |  |
+| [starcoder2/7b/base](https://huggingface.co/bigcode/starcoder2-7b) | + |  | + |  |
+| [starcoder2/15b/base](https://huggingface.co/bigcode/starcoder2-15b) | + |  | + |  |
+| [deepseek-coder/1.3b/base](https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-base) | + |  | + |  |
+| [deepseek-coder/5.7b/mqa-base](https://huggingface.co/deepseek-ai/deepseek-coder-5.7bmqa-base) | + |  | + |  |
+| [magicoder/6.7b](https://huggingface.co/TheBloke/Magicoder-S-DS-6.7B-GPTQ) |  | + |  | + |
+| [mistral/7b/instruct-v0.1](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GPTQ) |  | + |  | + |
+| [mixtral/8x7b/instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) |  | + |  |  |
+| [deepseek-coder/6.7b/instruct](https://huggingface.co/TheBloke/deepseek-coder-6.7B-instruct-GPTQ) |  | + |  | + |
+| [deepseek-coder/33b/instruct](https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct) |  | + |  |  |
+| [stable/3b/code](https://huggingface.co/stabilityai/stable-code-3b) | + |  |  |  |
+| [llama3/8b/instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | + | + |  |  |
+| [llama3.1/8b/instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) | + | + |  |  |
+| [llama3.2/1b/instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) | + | + |  |  |
+| [llama3.2/3b/instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) | + | + |  |  |
+| [qwen2.5/coder/0.5b/base](https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B) | + |  | + |  |
+| [qwen2.5/coder/1.5b/base](https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B) | + |  | + |  |
+| [qwen2.5/coder/3b/base](https://huggingface.co/Qwen/Qwen2.5-Coder-3B) | + |  | + |  |
+| [qwen2.5/coder/7b/base](https://huggingface.co/Qwen/Qwen2.5-Coder-7B) | + |  | + |  |
+| [qwen2.5/coder/14b/base](https://huggingface.co/Qwen/Qwen2.5-Coder-14B) | + |  | + |  |
+| [qwen2.5/coder/32b/base](https://huggingface.co/Qwen/Qwen2.5-Coder-32B) | + |  | + |  |
+| [qwen2.5/coder/1.5b/instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B-Instruct) | + | + |  |  |
+| [qwen2.5/coder/3b/instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct) | + | + |  |  |
+| [qwen2.5/coder/7b/instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-7B-Instruct) | + | + |  |  |
+| [qwen2.5/coder/14b/instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-14B-Instruct) | + | + |  |  |
+| [qwen2.5/coder/32b/instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) | + | + |  |  |
 
 ## Usage
"backend": "transformers", @@ -50,6 +52,7 @@ "required_memory_mb": 8000, "T": 4096, # in fact this model allows 16k context, but we have 4k context at max in hf inference "filter_caps": ["chat"], + "deprecated": True }, "deepseek-coder/33b/instruct": { "backend": "transformers", @@ -113,16 +116,126 @@ }, "required_memory_mb": 20000, "T": 8192, - "filter_caps": ["chat"], + "filter_caps": ["completion", "chat"], + }, + "llama3.1/8b/instruct": { + "backend": "transformers", + "model_path": "meta-llama/Llama-3.1-8B-Instruct", + "model_class_kwargs": { + "torch_dtype": "bf16", + }, + "required_memory_mb": 20000, + "T": 16384, # in fact this model can handle 128K context + "filter_caps": ["completion", "chat"], + }, + "llama3.2/3b/instruct": { + "backend": "transformers", + "model_path": "meta-llama/Llama-3.2-3B-Instruct", + "model_class_kwargs": { + "torch_dtype": "bf16", + }, + "required_memory_mb": 12000, + "T": 16384, # in fact this model can handle 128K context + "filter_caps": ["completion", "chat"], }, - "deepseek-coder-v2/16b/instruct": { + "llama3.2/1b/instruct": { "backend": "transformers", - "model_path": "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", + "model_path": "meta-llama/Llama-3.2-1B-Instruct", "model_class_kwargs": { "torch_dtype": "bf16", }, - "required_memory_mb": 80000, + "required_memory_mb": 8000, "T": 16384, # in fact this model can handle 128K context "filter_caps": ["completion", "chat"], }, + # qwen 2.5-coder instruct models + "qwen2.5/coder/32b/instruct": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-32B-Instruct", + "model_class_kwargs": {}, + "required_memory_mb": 45000, + "T": 32768, + "filter_caps": ["completion", "chat"], + }, + "qwen2.5/coder/14b/instruct": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-14B-Instruct", + "model_class_kwargs": {}, + "required_memory_mb": 45000, + "T": 32768, + "filter_caps": ["completion", "chat"], + }, + "qwen2.5/coder/7b/instruct": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-7B-Instruct", + "model_class_kwargs": {}, + "required_memory_mb": 45000, + "T": 32768, + "filter_caps": ["completion", "chat"], + }, + "qwen2.5/coder/3b/instruct": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-3B-Instruct", + "model_class_kwargs": {}, + "required_memory_mb": 45000, + "T": 32768, + "filter_caps": ["completion", "chat"], + }, + "qwen2.5/coder/1.5b/instruct": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-1.5B-Instruct", + "model_class_kwargs": {}, + "required_memory_mb": 45000, + "T": 32768, + "filter_caps": ["completion", "chat"], + }, + # qwen 2.5-coder completion models + "qwen2.5/coder/32b/base": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-32B", + "model_class_kwargs": {}, + "required_memory_mb": 45000, + "T": 32768, + "filter_caps": ["completion", "finetune"], + }, + "qwen2.5/coder/14b/base": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-14B", + "model_class_kwargs": {}, + "required_memory_mb": 35000, + "T": 32768, + "filter_caps": ["completion", "finetune"], + }, + "qwen2.5/coder/7b/base": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-7B", + "model_class_kwargs": {}, + "required_memory_mb": 20000, + "T": 32768, + "filter_caps": ["completion", "finetune"], + }, + "qwen2.5/coder/3b/base": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-3B", + "model_class_kwargs": {}, + "required_memory_mb": 15000, + "T": 32768, + "filter_caps": ["completion", 
"finetune"], + }, + "qwen2.5/coder/1.5b/base": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-1.5B", + "model_class_kwargs": {}, + "required_memory_mb": 10000, + "T": 32768, + "filter_caps": ["completion", "finetune"], + }, + "qwen2.5/coder/0.5b/base": { + "backend": "transformers", + "model_path": "Qwen/Qwen2.5-Coder-0.5B", + "model_class_kwargs": {}, + "required_memory_mb": 7000, + "T": 32768, + "filter_caps": ["completion", "finetune"], + }, } diff --git a/refact_known_models/passthrough.py b/refact_known_models/passthrough.py index 183cc440..a2f6f0a7 100644 --- a/refact_known_models/passthrough.py +++ b/refact_known_models/passthrough.py @@ -10,7 +10,7 @@ "T_out": 4096, "pp1000t_prompt": 5_000, "pp1000t_generated": 15_000, # $15.00 / 1M tokens (2024 may) - "filter_caps": ["chat", "tools"], + "filter_caps": ["chat", "tools", "completion"], }, "gpt-4-turbo": { "backend": "litellm", @@ -21,18 +21,7 @@ "T_out": 4096, "pp1000t_prompt": 10_000, "pp1000t_generated": 30_000, # $30.00 / 1M tokens (2024 may) - "filter_caps": ["chat", "tools"], - }, - "gpt-4": { - "backend": "litellm", - "provider": "openai", - "tokenizer_path": "Xenova/gpt-4", - "resolve_as": "gpt-4-0125-preview", - "T": 128_000, - "T_out": 4096, - "pp1000t_prompt": 10_000, - "pp1000t_generated": 30_000, - "filter_caps": ["chat", "tools"], + "filter_caps": ["chat", "tools", "completion"], }, "gpt-3.5-turbo": { "backend": "litellm", @@ -43,7 +32,7 @@ "T_out": 4096, "pp1000t_prompt": 1000, "pp1000t_generated": 2000, - "filter_caps": ["chat", "tools"], + "filter_caps": ["chat", "tools", "completion"], }, "claude-3-5-sonnet": { "backend": "litellm", @@ -54,7 +43,7 @@ "T_out": 4096, "pp1000t_prompt": 3_000, # $3.00 / 1M tokens (2024 jun) "pp1000t_generated": 15_000, # $15.00 / 1M tokens (2024 jun) - "filter_caps": ["chat", "tools"], + "filter_caps": ["chat", "tools", "completion"], }, "claude-3-haiku": { "backend": "litellm", @@ -65,7 +54,7 @@ "T_out": 4096, "pp1000t_prompt": 250, "pp1000t_generated": 1_250, - "filter_caps": ["chat", "tools"], + "filter_caps": ["chat", "tools", "completion"], }, "claude-3-opus": { "backend": "litellm", @@ -76,7 +65,7 @@ "T_out": 4096, "pp1000t_prompt": 15_000, "pp1000t_generated": 75_000, - "filter_caps": ["chat", "tools"], + "filter_caps": ["chat", "tools", "completion"], }, "claude-3-sonnet": { "backend": "litellm", @@ -87,7 +76,7 @@ "T_out": 4096, "pp1000t_prompt": 3_000, "pp1000t_generated": 15_000, - "filter_caps": ["chat", "tools"], + "filter_caps": ["chat", "tools", "completion"], }, "gpt-4o-2024-05-13": { "backend": "litellm", @@ -98,7 +87,7 @@ "T_out": 4096, "pp1000t_prompt": 5_000, "pp1000t_generated": 15_000, # $15.00 / 1M tokens - "filter_caps": ["chat", "tools"], + "filter_caps": ["chat", "tools", "completion"], }, "gpt-4o-2024-08-06": { "backend": "litellm", @@ -109,7 +98,7 @@ "T_out": 4096, "pp1000t_prompt": 2_500, "pp1000t_generated": 10_000, # $15.00 / 1M tokens - "filter_caps": ["chat", "tools"] + "filter_caps": ["chat", "tools", "completion"] }, "gpt-4o-mini": { "backend": "litellm", @@ -120,7 +109,7 @@ "T_out": 4096, "pp1000t_prompt": 150, "pp1000t_generated": 600, # $0.60 / 1M tokens - "filter_caps": ["chat", "tools"], + "filter_caps": ["chat", "tools", "completion"], }, "claude-3-5-sonnet-20241022": { "backend": "litellm", @@ -131,6 +120,94 @@ "T_out": 4096, "pp1000t_prompt": 3_000, # $3.00 / 1M tokens (2024 oct) "pp1000t_generated": 15_000, # $15.00 / 1M tokens (2024 oct) - "filter_caps": ["chat", "tools"], + "filter_caps": ["chat", "tools", "completion"], + 
diff --git a/refact_known_models/passthrough.py b/refact_known_models/passthrough.py
index 183cc440..a2f6f0a7 100644
--- a/refact_known_models/passthrough.py
+++ b/refact_known_models/passthrough.py
@@ -10,7 +10,7 @@
         "T_out": 4096,
         "pp1000t_prompt": 5_000,
         "pp1000t_generated": 15_000,  # $15.00 / 1M tokens (2024 may)
-        "filter_caps": ["chat", "tools"],
+        "filter_caps": ["chat", "tools", "completion"],
     },
     "gpt-4-turbo": {
         "backend": "litellm",
@@ -21,18 +21,7 @@
         "T_out": 4096,
         "pp1000t_prompt": 10_000,
         "pp1000t_generated": 30_000,  # $30.00 / 1M tokens (2024 may)
-        "filter_caps": ["chat", "tools"],
-    },
-    "gpt-4": {
-        "backend": "litellm",
-        "provider": "openai",
-        "tokenizer_path": "Xenova/gpt-4",
-        "resolve_as": "gpt-4-0125-preview",
-        "T": 128_000,
-        "T_out": 4096,
-        "pp1000t_prompt": 10_000,
-        "pp1000t_generated": 30_000,
-        "filter_caps": ["chat", "tools"],
+        "filter_caps": ["chat", "tools", "completion"],
     },
     "gpt-3.5-turbo": {
         "backend": "litellm",
@@ -43,7 +32,7 @@
         "T_out": 4096,
         "pp1000t_prompt": 1000,
         "pp1000t_generated": 2000,
-        "filter_caps": ["chat", "tools"],
+        "filter_caps": ["chat", "tools", "completion"],
     },
     "claude-3-5-sonnet": {
         "backend": "litellm",
@@ -54,7 +43,7 @@
         "T_out": 4096,
         "pp1000t_prompt": 3_000,  # $3.00 / 1M tokens (2024 jun)
         "pp1000t_generated": 15_000,  # $15.00 / 1M tokens (2024 jun)
-        "filter_caps": ["chat", "tools"],
+        "filter_caps": ["chat", "tools", "completion"],
     },
     "claude-3-haiku": {
         "backend": "litellm",
@@ -65,7 +54,7 @@
         "T_out": 4096,
         "pp1000t_prompt": 250,
         "pp1000t_generated": 1_250,
-        "filter_caps": ["chat", "tools"],
+        "filter_caps": ["chat", "tools", "completion"],
     },
     "claude-3-opus": {
         "backend": "litellm",
@@ -76,7 +65,7 @@
         "T_out": 4096,
         "pp1000t_prompt": 15_000,
         "pp1000t_generated": 75_000,
-        "filter_caps": ["chat", "tools"],
+        "filter_caps": ["chat", "tools", "completion"],
     },
     "claude-3-sonnet": {
         "backend": "litellm",
@@ -87,7 +76,7 @@
         "T_out": 4096,
         "pp1000t_prompt": 3_000,
         "pp1000t_generated": 15_000,
-        "filter_caps": ["chat", "tools"],
+        "filter_caps": ["chat", "tools", "completion"],
     },
     "gpt-4o-2024-05-13": {
         "backend": "litellm",
@@ -98,7 +87,7 @@
         "T_out": 4096,
         "pp1000t_prompt": 5_000,
         "pp1000t_generated": 15_000,  # $15.00 / 1M tokens
-        "filter_caps": ["chat", "tools"],
+        "filter_caps": ["chat", "tools", "completion"],
     },
     "gpt-4o-2024-08-06": {
         "backend": "litellm",
@@ -109,7 +98,7 @@
         "T_out": 4096,
         "pp1000t_prompt": 2_500,
         "pp1000t_generated": 10_000,  # $15.00 / 1M tokens
-        "filter_caps": ["chat", "tools"]
+        "filter_caps": ["chat", "tools", "completion"]
     },
     "gpt-4o-mini": {
         "backend": "litellm",
@@ -120,7 +109,7 @@
         "T_out": 4096,
         "pp1000t_prompt": 150,
         "pp1000t_generated": 600,  # $0.60 / 1M tokens
-        "filter_caps": ["chat", "tools"],
+        "filter_caps": ["chat", "tools", "completion"],
     },
     "claude-3-5-sonnet-20241022": {
         "backend": "litellm",
@@ -131,6 +120,94 @@
         "T_out": 4096,
         "pp1000t_prompt": 3_000,  # $3.00 / 1M tokens (2024 oct)
         "pp1000t_generated": 15_000,  # $15.00 / 1M tokens (2024 oct)
-        "filter_caps": ["chat", "tools"],
+        "filter_caps": ["chat", "tools", "completion"],
+    },
+    "groq-llama-3.1-8b": {
+        "backend": "litellm",
+        "provider": "groq",
+        "tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
+        "resolve_as": "groq/llama-3.1-8b-instant",
+        "T": 128_000,
+        "T_out": 8000,
+        "pp1000t_prompt": 150,
+        "pp1000t_generated": 600,  # TODO: don't know the price
+        "filter_caps": ["chat", "completion"],
+    },
+    "groq-llama-3.1-70b": {
+        "backend": "litellm",
+        "provider": "groq",
+        "tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
+        "resolve_as": "groq/llama-3.1-70b-versatile",
+        "T": 128_000,
+        "T_out": 8000,
+        "pp1000t_prompt": 150,
+        "pp1000t_generated": 600,  # TODO: don't know the price
+        "filter_caps": ["chat", "completion"],
+    },
+    "groq-llama-3.2-1b": {
+        "backend": "litellm",
+        "provider": "groq",
+        "tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
+        "resolve_as": "groq/llama-3.2-1b-preview",
+        "T": 128_000,
+        "T_out": 8000,
+        "pp1000t_prompt": 150,
+        "pp1000t_generated": 600,  # TODO: don't know the price
+        "filter_caps": ["chat", "completion"],
     },
+    "groq-llama-3.2-3b": {
+        "backend": "litellm",
+        "provider": "groq",
+        "tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
+        "resolve_as": "groq/llama-3.2-3b-preview",
+        "T": 128_000,
+        "T_out": 8000,
+        "pp1000t_prompt": 150,
+        "pp1000t_generated": 600,  # TODO: don't know the price
+        "filter_caps": ["chat", "completion"],
+    },
+    "groq-llama-3.2-11b-vision": {
+        "backend": "litellm",
+        "provider": "groq",
+        "tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
+        "resolve_as": "groq/llama-3.2-11b-vision-preview",
+        "T": 128_000,
+        "T_out": 8000,
+        "pp1000t_prompt": 150,
+        "pp1000t_generated": 600,  # TODO: don't know the price
+        "filter_caps": ["chat", "completion"],
+    },
+    "groq-llama-3.2-90b-vision": {
+        "backend": "litellm",
+        "provider": "groq",
+        "tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
+        "resolve_as": "groq/llama-3.2-90b-vision-preview",
+        "T": 128_000,
+        "T_out": 8000,
+        "pp1000t_prompt": 150,
+        "pp1000t_generated": 600,  # TODO: don't know the price
+        "filter_caps": ["chat", "completion"],
+    },
+    "cerebras-llama3.1-8b": {
+        "backend": "litellm",
+        "provider": "cerebras",
+        "tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
+        "resolve_as": "cerebras/llama3.1-8b",
+        "T": 8192,
+        "T_out": 4096,
+        "pp1000t_prompt": 150,
+        "pp1000t_generated": 600,  # TODO: don't know the price
+        "filter_caps": ["chat", "completion"],
+    },
+    "cerebras-llama3.1-70b": {
+        "backend": "litellm",
+        "provider": "cerebras",
+        "tokenizer_path": "Xenova/Meta-Llama-3.1-Tokenizer",
+        "resolve_as": "cerebras/llama3.1-70b",
+        "T": 8192,
+        "T_out": 4096,
+        "pp1000t_prompt": 150,
+        "pp1000t_generated": 600,  # TODO: don't know the price
+        "filter_caps": ["chat", "completion"],
+    }
 }
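The `pp1000t_prompt` / `pp1000t_generated` fields appear to encode the price of 1000 tokens in millionths of a dollar, which is consistent with the inline comments (`15_000` next to "$15.00 / 1M tokens"). Under that assumption, a small illustrative sketch of turning token counts into a cost estimate:

```python
# Illustrative sketch, not repository code. Assumes pp1000t_* encodes the price
# of 1000 tokens in millionths of a dollar, matching the inline comments above
# (pp1000t_generated = 15_000  ->  $15.00 per 1M tokens).
def estimate_cost_usd(entry: dict, prompt_tokens: int, generated_tokens: int) -> float:
    prompt_cost = prompt_tokens / 1000 * entry["pp1000t_prompt"]
    generated_cost = generated_tokens / 1000 * entry["pp1000t_generated"]
    return (prompt_cost + generated_cost) / 1_000_000


gpt4o_like = {"pp1000t_prompt": 5_000, "pp1000t_generated": 15_000}
print(estimate_cost_usd(gpt4o_like, prompt_tokens=12_000, generated_tokens=1_500))
# 0.0825  (i.e. about 8 cents for this hypothetical request)
```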
diff --git a/refact_utils/finetune/utils.py b/refact_utils/finetune/utils.py
index edec5f26..a196f050 100644
--- a/refact_utils/finetune/utils.py
+++ b/refact_utils/finetune/utils.py
@@ -100,6 +100,8 @@ def _add_results_for_passthrough_provider(provider: str) -> None:
         for k, v in model_assigner.passthrough_mini_db.items():
             if v.get('provider') == provider:
                 result['chat'].append(k)
+                if 'completion' in v.get('filter_caps', []):
+                    result['completion'].append(k)
 
     if data.get("openai_api_enable"):
         _add_results_for_passthrough_provider('openai')
@@ -107,6 +109,12 @@ def _add_results_for_passthrough_provider(provider: str) -> None:
     if data.get('anthropic_api_enable'):
         _add_results_for_passthrough_provider('anthropic')
 
+    if data.get('cerebras_api_enable'):
+        _add_results_for_passthrough_provider('cerebras')
+
+    if data.get('groq_api_enable'):
+        _add_results_for_passthrough_provider('groq')
+
     for k, v in data.get("model_assign", {}).items():
         if model_dict := [d for d in data['models'] if d['name'] == k]:
             model_dict = model_dict[0]
diff --git a/refact_webgui/webgui/selfhost_fastapi_completions.py b/refact_webgui/webgui/selfhost_fastapi_completions.py
index f9d62e34..667c9ca5 100644
--- a/refact_webgui/webgui/selfhost_fastapi_completions.py
+++ b/refact_webgui/webgui/selfhost_fastapi_completions.py
@@ -231,6 +231,8 @@ def _integrations_env_setup(env_var_name: str, api_key_name: str, api_enable_nam
         litellm.modify_params = True  # NOTE: for Anthropic API
         _integrations_env_setup("OPENAI_API_KEY", "openai_api_key", "openai_api_enable")
         _integrations_env_setup("ANTHROPIC_API_KEY", "anthropic_api_key", "anthropic_api_enable")
+        _integrations_env_setup("GROQ_API_KEY", "groq_api_key", "groq_api_enable")
+        _integrations_env_setup("CEREBRAS_API_KEY", "cerebras_api_key", "cerebras_api_enable")
 
     def _models_available_dict_rewrite(self, models_available: List[str]) -> Dict[str, Any]:
         rewrite_dict = {}
@@ -248,6 +250,7 @@ def _caps_base_data(self) -> Dict[str, Any]:
         running = running_models_and_loras(self._model_assigner)
         models_available = self._inference_queue.models_available(force_read=True)
         code_completion_default_model, _ = self._inference_queue.completion_model()
+        multiline_code_completion_default_model, _ = self._inference_queue.multiline_completion_default_model()
         code_chat_default_model = ""
         embeddings_default_model = ""
         for model_name in models_available:
@@ -267,6 +270,7 @@ def _caps_base_data(self) -> Dict[str, Any]:
             "telemetry_basic_retrieve_my_own": "/stats/rh-stats",
             "running_models": [r for r in [*running['completion'], *running['chat']]],
             "code_completion_default_model": code_completion_default_model,
+            "multiline_code_completion_default_model": multiline_code_completion_default_model,
             "code_chat_default_model": code_chat_default_model,
 
             "models_dict_patch": self._models_available_dict_rewrite(models_available),
@@ -303,6 +307,10 @@ def _select_default_lora_if_exists(model_name: str, running_models: List[str]):
             data["code_completion_default_model"],
             running['completion'],
         )
+        data["multiline_code_completion_default_model"] = _select_default_lora_if_exists(
+            data["multiline_code_completion_default_model"],
+            running['completion'],
+        )
         data["code_chat_default_model"] = _select_default_lora_if_exists(
             data["code_chat_default_model"],
             running['chat'],
@@ -609,7 +617,10 @@ async def chat_completion_streamer():
                 log(err_msg)
                 yield prefix + json.dumps({"error": err_msg}) + postfix
 
-        if model_dict.get('backend') == 'litellm' and (model_name := model_dict.get('resolve_as', post.model)) in litellm.model_list:
+        if model_dict.get('backend') == 'litellm':
+            model_name = model_dict.get('resolve_as', post.model)
+            if model_name not in litellm.model_list:
+                log(f"warning: requested model {model_name} is not in the litellm.model_list (this might not be the issue for some providers)")
             log(f"chat/completions: model resolve {post.model} -> {model_name}")
             prompt_tokens_n = litellm.token_counter(model_name, messages=messages)
             if post.tools:
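With this change the caps payload exposes `multiline_code_completion_default_model` next to the existing `code_completion_default_model`. A hypothetical client-side sketch of reading the new field with a fallback for servers that predate it (the endpoint path and the fallback behaviour are assumptions; only the field names come from the diff):

```python
# Hypothetical client-side sketch, not part of this PR.
import requests


def pick_completion_models(base_url: str) -> tuple:
    caps = requests.get(f"{base_url}/v1/caps").json()  # endpoint path is an assumption
    single_line = caps["code_completion_default_model"]
    # Older servers do not send the multiline field; fall back to the single-line default.
    multiline = caps.get("multiline_code_completion_default_model") or single_line
    return single_line, multiline
```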
diff --git a/refact_webgui/webgui/selfhost_model_assigner.py b/refact_webgui/webgui/selfhost_model_assigner.py
index 856cecea..7709854e 100644
--- a/refact_webgui/webgui/selfhost_model_assigner.py
+++ b/refact_webgui/webgui/selfhost_model_assigner.py
@@ -184,6 +184,8 @@ def first_run(self):
             },
             "openai_api_enable": False,
             "anthropic_api_enable": False,
+            "groq_api_enable": False,
+            "cerebras_api_enable": False,
         }
 
         self.models_to_watchdog_configs(default_config)
@@ -255,6 +257,8 @@ def models_info(self):
     def model_assignment(self):
         if os.path.exists(env.CONFIG_INFERENCE):
             j = json.load(open(env.CONFIG_INFERENCE, "r"))
+            j["groq_api_enable"] = j.get("groq_api_enable", False)
+            j["cerebras_api_enable"] = j.get("cerebras_api_enable", False)
         else:
             j = {"model_assign": {}}
diff --git a/refact_webgui/webgui/selfhost_queue.py b/refact_webgui/webgui/selfhost_queue.py
index fb365c15..8dfafb27 100644
--- a/refact_webgui/webgui/selfhost_queue.py
+++ b/refact_webgui/webgui/selfhost_queue.py
@@ -64,6 +64,10 @@ def _add_models_for_passthrough_provider(provider):
             _add_models_for_passthrough_provider('openai')
         if j.get("anthropic_api_enable"):
             _add_models_for_passthrough_provider('anthropic')
+        if j.get("groq_api_enable"):
+            _add_models_for_passthrough_provider('groq')
+        if j.get("cerebras_api_enable"):
+            _add_models_for_passthrough_provider('cerebras')
 
         return self._models_available
 
@@ -76,3 +80,14 @@ def completion_model(self) -> Tuple[str, str]:
                 return model, ""
 
         return "", f"completion model is not set"
+
+
+    def multiline_completion_default_model(self) -> Tuple[str, str]:
+
+        if os.path.exists(env.CONFIG_INFERENCE):
+            j = json.load(open(env.CONFIG_INFERENCE, 'r'))
+            for model in j["model_assign"]:
+                if "completion" in self._model_assigner.models_db.get(model, {}).get("filter_caps", {}):
+                    return model, ""
+
+        return "", f"completion model is not set"
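The new `multiline_completion_default_model()` above walks the assigned models in config order and returns the first one whose `filter_caps` includes "completion". A standalone sketch of that selection rule with hypothetical data:

```python
# Standalone sketch of the selection rule implemented by
# multiline_completion_default_model() above; the data below is hypothetical.
def first_completion_model(model_assign: dict, models_db: dict) -> str:
    for model in model_assign:  # dicts preserve insertion (config) order
        caps = models_db.get(model, {}).get("filter_caps", [])
        if "completion" in caps:
            return model
    return ""  # corresponds to the "completion model is not set" case


models_db = {
    "llama3.1/8b/instruct": {"filter_caps": ["completion", "chat"]},
    "qwen2.5/coder/1.5b/base": {"filter_caps": ["completion", "finetune"]},
}
model_assign = {"llama3.1/8b/instruct": {}, "qwen2.5/coder/1.5b/base": {}}
print(first_completion_model(model_assign, models_db))  # llama3.1/8b/instruct
```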
diff --git a/refact_webgui/webgui/static/tab-finetune.html b/refact_webgui/webgui/static/tab-finetune.html
index ea0a6b1a..b273a6dd 100644
--- a/refact_webgui/webgui/static/tab-finetune.html
+++ b/refact_webgui/webgui/static/tab-finetune.html
@@ -191,7 +191,7 @@
-
+
diff --git a/refact_webgui/webgui/static/tab-model-hosting.html b/refact_webgui/webgui/static/tab-model-hosting.html
index 48714516..6ecc86d7 100644
--- a/refact_webgui/webgui/static/tab-model-hosting.html
+++ b/refact_webgui/webgui/static/tab-model-hosting.html
@@ -38,6 +38,14 @@
 
 3rd Party APIs
 
+
+
+
+
+
+
+
+
 To enable Chat GPT add your API key in the API Keys tab.
diff --git a/refact_webgui/webgui/static/tab-model-hosting.js b/refact_webgui/webgui/static/tab-model-hosting.js
index eb6affca..dc2f36d7 100644
--- a/refact_webgui/webgui/static/tab-model-hosting.js
+++ b/refact_webgui/webgui/static/tab-model-hosting.js
@@ -117,6 +117,8 @@ function get_models()
 
         integration_switch_init('enable_chat_gpt', models_data['openai_api_enable']);
         integration_switch_init('enable_anthropic', models_data['anthropic_api_enable']);
+        integration_switch_init('enable_groq', models_data['groq_api_enable']);
+        integration_switch_init('enable_cerebras', models_data['cerebras_api_enable']);
 
         const more_gpus_notification = document.querySelector('.model-hosting-error');
         if(data.hasOwnProperty('more_models_than_gpus') && data.more_models_than_gpus) {
@@ -140,12 +142,16 @@ function get_models()
 function save_model_assigned() {
     const openai_enable = document.querySelector('#enable_chat_gpt');
     const anthropic_enable = document.querySelector('#enable_anthropic');
+    const groq_enable = document.querySelector('#enable_groq');
+    const cerebras_enable = document.querySelector('#enable_cerebras');
     const data = {
         model_assign: {
             ...models_data.model_assign,
         },
         openai_api_enable: openai_enable.checked,
         anthropic_api_enable: anthropic_enable.checked,
+        groq_api_enable: groq_enable.checked,
+        cerebras_api_enable: cerebras_enable.checked,
     };
     console.log(data);
     fetch("/tab-host-models-assign", {
diff --git a/refact_webgui/webgui/static/tab-settings.html b/refact_webgui/webgui/static/tab-settings.html
index 1d699342..18a730b7 100644
--- a/refact_webgui/webgui/static/tab-settings.html
+++ b/refact_webgui/webgui/static/tab-settings.html
@@ -6,6 +6,10 @@
 
 API Integrations
 
+
+
+
+
diff --git a/refact_webgui/webgui/static/tab-settings.js b/refact_webgui/webgui/static/tab-settings.js
index 0ade4005..c9597f2e 100644
--- a/refact_webgui/webgui/static/tab-settings.js
+++ b/refact_webgui/webgui/static/tab-settings.js
@@ -172,6 +172,8 @@ function throw_int_saved_success_toast(msg) {
 function save_integration_api_keys() {
     const openai_api_key = document.getElementById('openai_api_key');
     const anthropic_api_key = document.getElementById('anthropic_api_key');
+    const groq_api_key = document.getElementById('groq_api_key');
+    const cerebras_api_key = document.getElementById('cerebras_api_key');
     const huggingface_api_key = document.getElementById('huggingface_api_key');
     fetch("/tab-settings-integrations-save", {
         method: "POST",
@@ -181,6 +183,8 @@ function save_integration_api_keys() {
         body: JSON.stringify({
             openai_api_key: openai_api_key.getAttribute('data-value'),
             anthropic_api_key: anthropic_api_key.getAttribute('data-value'),
+            groq_api_key: groq_api_key.getAttribute('data-value'),
+            cerebras_api_key: cerebras_api_key.getAttribute('data-value'),
             huggingface_api_key: huggingface_api_key.getAttribute('data-value'),
         })
     })
@@ -189,6 +193,8 @@ function save_integration_api_keys() {
         throw_int_saved_success_toast('API Key saved')
         openai_api_key.setAttribute('data-saved-value', openai_api_key.getAttribute('data-value'))
         anthropic_api_key.setAttribute('data-saved-value', anthropic_api_key.getAttribute('data-value'))
+        groq_api_key.setAttribute('data-saved-value', groq_api_key.getAttribute('data-value'))
+        cerebras_api_key.setAttribute('data-saved-value', cerebras_api_key.getAttribute('data-value'))
         huggingface_api_key.setAttribute('data-saved-value', huggingface_api_key.getAttribute('data-value'))
     });
 }
@@ -222,6 +228,8 @@ export function tab_settings_integrations_get() {
     .then(function(data) {
         integrations_input_init(document.getElementById('openai_api_key'), data['openai_api_key']);
         integrations_input_init(document.getElementById('anthropic_api_key'), data['anthropic_api_key']);
+        integrations_input_init(document.getElementById('groq_api_key'), data['groq_api_key']);
+        integrations_input_init(document.getElementById('cerebras_api_key'), data['cerebras_api_key']);
         integrations_input_init(document.getElementById('huggingface_api_key'), data['huggingface_api_key']);
     });
 }
diff --git a/refact_webgui/webgui/tab_models_host.py b/refact_webgui/webgui/tab_models_host.py
index af8068a8..2f1e241b 100644
--- a/refact_webgui/webgui/tab_models_host.py
+++ b/refact_webgui/webgui/tab_models_host.py
@@ -42,6 +42,8 @@ class TabHostModelsAssign(BaseModel):
     # integrations
     openai_api_enable: bool = False
     anthropic_api_enable: bool = False
+    groq_api_enable: bool = False
+    cerebras_api_enable: bool = False
 
     model_config = ConfigDict(protected_namespaces=())  # avoiding model_ namespace protection
diff --git a/refact_webgui/webgui/tab_settings.py b/refact_webgui/webgui/tab_settings.py
index ec365d7c..3be0a3f5 100644
--- a/refact_webgui/webgui/tab_settings.py
+++ b/refact_webgui/webgui/tab_settings.py
@@ -22,6 +22,8 @@ class SSHKey(BaseModel):
 class Integrations(BaseModel):
     openai_api_key: Optional[str] = None
     anthropic_api_key: Optional[str] = None
+    groq_api_key: Optional[str] = None
+    cerebras_api_key: Optional[str] = None
     huggingface_api_key: Optional[str] = None
 
     def __init__(self, models_assigner: ModelAssigner, *args, **kwargs):
diff --git a/self_hosting_machinery/finetune/configuration/supported_models.py b/self_hosting_machinery/finetune/configuration/supported_models.py
index 97badd59..8bb5a529 100644
--- a/self_hosting_machinery/finetune/configuration/supported_models.py
+++ b/self_hosting_machinery/finetune/configuration/supported_models.py
@@ -90,6 +90,36 @@
     ],
     "force_enable_checkpointing": False
 }
+_qwen_base = {
+    "lora_target_modules_mapping": {
+        "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
+        "out": ["self_attn.o_proj"],
+        "backproj": ["self_attn.o_proj"],
+        "mlp": ["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"],
+    },
+    "freeze_exceptions_mapping": {
+        "wte": ["embed_tokens"],
+        "lm_head": ["lm_head"],
+        "lora": ["lora"]
+    },
+    "tokenizer": {
+        "eot_idx": 151643,  # `<|endoftext|>`
+        "padding_idx": 151662,  # `<|fim_pad|>`
+        "fim_prefix": 151659,  # `<|fim_prefix|>`
+        "fim_middle": 151660,  # `<|fim_middle|>`
+        "fim_suffix": 151661,  # `<|fim_suffix|>`
+        "escape": 32013,  # using `<|begin▁of▁sentence|>` token for now
+    },
+    "train_ds_pipeline": {
+        "ds_opts": f"{_fim_train_ds_pipeline['ds_opts']},spm_prob=0.0",
+        "ds_name": _fim_train_ds_pipeline["ds_name"]
+    },
+    "test_ds_pipeline": _fim_test_ds_pipeline,
+    "train_model_modifiers": [
+        "flash_sa.apply_flash_mha_to_codellama_model"
+    ],
+    "force_enable_checkpointing": False
+}
 
 config = {
     "Refact/1.6B": {
@@ -179,5 +209,31 @@
     "deepseek-coder/6.7b/base": {
         **_deepseek_base,
         "force_enable_checkpointing": True
+    },
+
+    # qwen models
+    "qwen2.5/coder/32b/base": {
+        **_qwen_base,
+        "force_enable_checkpointing": True
+    },
+    "qwen2.5/coder/14b/base": {
+        **_qwen_base,
+        "force_enable_checkpointing": True
+    },
+    "qwen2.5/coder/7b/base": {
+        **_qwen_base,
+        "force_enable_checkpointing": True
+    },
+    "qwen2.5/coder/3b/base": {
+        **_qwen_base,
+        "force_enable_checkpointing": False
+    },
+    "qwen2.5/coder/1.5b/base": {
+        **_qwen_base,
+        "force_enable_checkpointing": False
+    },
+    "qwen2.5/coder/0.5b/base": {
+        **_qwen_base,
+        "force_enable_checkpointing": False
+    }
 }
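The `_qwen_base` block wires Qwen2.5-Coder's fill-in-the-middle special tokens into the fine-tune pipeline. An illustrative sketch of how those tokens are conventionally arranged into a prefix/suffix/middle (PSM) training sample; the token strings match the ids referenced above, while the assembly function itself is a generic FIM convention, not code from this repository:

```python
# Illustrative PSM (prefix-suffix-middle) sample construction using the
# Qwen2.5-Coder special tokens referenced by _qwen_base above
# (ids 151659 / 151661 / 151660 / 151643). Not repository code.
FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE, EOT = (
    "<|fim_prefix|>", "<|fim_suffix|>", "<|fim_middle|>", "<|endoftext|>",
)


def build_fim_sample(prefix: str, middle: str, suffix: str) -> str:
    # The model is trained to emit `middle` after seeing the prefix and suffix.
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}{middle}{EOT}"


print(build_fim_sample(
    prefix="def add(a, b):\n    ",
    middle="return a + b",
    suffix="\n\nprint(add(2, 3))\n",
))
```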
diff --git a/self_hosting_machinery/finetune/modelling/flash_sa.py b/self_hosting_machinery/finetune/modelling/flash_sa.py
index d7b8b483..b00e9db7 100644
--- a/self_hosting_machinery/finetune/modelling/flash_sa.py
+++ b/self_hosting_machinery/finetune/modelling/flash_sa.py
@@ -212,8 +212,8 @@ def _forward(
     k = einops.rearrange(k, "b t (h d) -> b h t d", h=self.num_key_value_heads)
     v = einops.rearrange(v, "b t (h d) -> b t h d", h=self.num_key_value_heads)
 
-    cos, sin = self.rotary_emb(v, seq_len=k.shape[-2])
-    q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
+    cos, sin = self.rotary_emb(v, position_ids)
+    q, k = apply_rotary_pos_emb(q, k, cos, sin)
 
     q = einops.rearrange(q, "b h t d -> b t h d")
     k = einops.rearrange(k, "b h t d -> b t h d")
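This hunk tracks the newer transformers rotary-embedding calling convention: `rotary_emb` takes `position_ids` directly and `apply_rotary_pos_emb` no longer does. A self-contained sketch of the rotate-half RoPE math these helpers apply; shapes and names are illustrative, not the transformers implementation:

```python
# Self-contained sketch of the rotate-half RoPE application performed by
# apply_rotary_pos_emb(q, k, cos, sin); illustrative shapes and names only.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q, k, cos, sin):
    # q, k: (batch, heads, seq, head_dim); cos, sin: (batch, seq, head_dim)
    cos = cos.unsqueeze(1)  # broadcast over the head dimension
    sin = sin.unsqueeze(1)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```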
diff --git a/self_hosting_machinery/inference/inference_hf.py b/self_hosting_machinery/inference/inference_hf.py
index 5646886d..d12f996b 100644
--- a/self_hosting_machinery/inference/inference_hf.py
+++ b/self_hosting_machinery/inference/inference_hf.py
@@ -149,38 +149,32 @@ def __init__(self,
         self._device = "cuda:0"
 
         token = huggingface_hub_token()
-        for local_files_only in [True, False]:
-            try:
-                logging.getLogger("MODEL").info("loading model local_files_only=%i" % local_files_only)
-                self._tokenizer = AutoTokenizer.from_pretrained(
-                    self._model_dict["model_path"], cache_dir=self.cache_dir, trust_remote_code=True,
-                    local_files_only=local_files_only, token=token,
-                )
-                if model_dict["backend"] == "transformers":
-                    torch_dtype_mapping = {
-                        "auto": "auto",
-                        "fp16": torch.float16,
-                        "bf16": torch.bfloat16,
-                    }
-                    torch_dtype = self._model_dict["model_class_kwargs"].pop("torch_dtype", "auto")
-                    torch_dtype = torch_dtype_mapping[torch_dtype]
-                    self._model = AutoModelForCausalLM.from_pretrained(
-                        self._model_dict["model_path"], cache_dir=self.cache_dir,
-                        device_map="auto", torch_dtype=torch_dtype, trust_remote_code=True,
-                        local_files_only=local_files_only, token=token,
-                        **self._model_dict["model_class_kwargs"])
-                elif model_dict["backend"] == "autogptq":
-                    self._model = CustomAutoGPTQForCausalLM.from_quantized(
-                        self._model_dict["model_path"], cache_dir=self.cache_dir, device=self._device,
-                        trust_remote_code=True,
-                        local_files_only=local_files_only, token=token,
-                        **self._model_dict["model_class_kwargs"])
-                else:
-                    raise RuntimeError(f"unknown model backend {model_dict['backend']}")
-                break
-            except IOError as e:
-                if local_files_only == False:
-                    raise e
+        logging.getLogger("MODEL").info("loading model")
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self._model_dict["model_path"], cache_dir=self.cache_dir,
+            trust_remote_code=True, token=token,
+        )
+        if model_dict["backend"] == "transformers":
+            torch_dtype_mapping = {
+                "auto": "auto",
+                "fp16": torch.float16,
+                "bf16": torch.bfloat16,
+            }
+            torch_dtype = self._model_dict["model_class_kwargs"].pop("torch_dtype", "auto")
+            torch_dtype = torch_dtype_mapping[torch_dtype]
+            self._model = AutoModelForCausalLM.from_pretrained(
+                self._model_dict["model_path"], cache_dir=self.cache_dir,
+                device_map="auto", torch_dtype=torch_dtype, trust_remote_code=True,
+                token=token, **self._model_dict["model_class_kwargs"]
+            )
+        elif model_dict["backend"] == "autogptq":
+            self._model = CustomAutoGPTQForCausalLM.from_quantized(
+                self._model_dict["model_path"], cache_dir=self.cache_dir, device=self._device,
+                trust_remote_code=True, token=token,
+                **self._model_dict["model_class_kwargs"]
+            )
+        else:
+            raise RuntimeError(f"unknown model backend {model_dict['backend']}")
 
         self._dump_embeddings()
 
     @property
diff --git a/setup.py b/setup.py
index 775eb623..fef2a05b 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@ class PyPackage:
     "refact_known_models": PyPackage(),
     "refact_utils": PyPackage(),
     "refact_data_pipeline": PyPackage(
-        requires=["numpy", "tokenizers>=0.15.0", "torch", "requests>=2.31.0", "cloudpickle", "blobfile",
+        requires=["numpy", "tokenizers>=0.20.1", "torch", "requests>=2.31.0", "cloudpickle", "blobfile",
                   "tqdm", "dataclasses_json", "termcolor", 'more_itertools', "cdifflib", "ujson",
                   "zstandard", "scipy", "einops", "matplotlib", "giturlparse", "jsonlines", "binpacking",
                   "filelock", "tables==3.8.0", "pygments", "kshingle"],
@@ -41,10 +41,10 @@ class PyPackage:
                     "webgui/static/dashboards/*", "webgui/static/assets/*", "webgui/static/utils/*",]),
     "self_hosting_machinery": PyPackage(
         requires=["python-multipart", "auto-gptq==0.7.1", "accelerate",
-                  "termcolor", "torch", "transformers>=4.39.3",
+                  "termcolor", "torch", "transformers>=4.46.0",
                   "bitsandbytes", "safetensors", "peft", "triton",
-                  "torchinfo", "mpi4py", "deepspeed==0.14.2",
-                  "sentence-transformers", "huggingface-hub>=0.19.3",
+                  "torchinfo", "mpi4py", "deepspeed>=0.15.3",
+                  "sentence-transformers", "huggingface-hub>=0.26.2",
                   "aiohttp", "setproctitle"],
         optional=["ninja", "flash-attn"],
         requires_packages=["refact_known_models", "refact_data_pipeline",
@@ -91,7 +91,7 @@ def get_install_requires(packages):
 
 setup(
     name="refact-self-hosting",
-    version="1.7.0",
+    version="1.8.0",
     py_modules=list(setup_packages.keys()),
     package_data={
         name: py_package.data