Commit 8c431fb

Merge pull request #3 from CreditMutuelArkea/fixes_doc_update
🔧 Run configuration for PyCharm
2 parents f422cb2 + 90c1908 · commit 8c431fb

6 files changed: +153 -12 lines
.run/llm_inference - EMBEDDING.run.xml

+36
@@ -0,0 +1,36 @@
+<component name="ProjectRunConfigurationManager">
+  <configuration default="false" name="llm_inference - EMBEDDING" type="PythonConfigurationType" factoryName="Python">
+    <module name="llm-inference" />
+    <option name="INTERPRETER_OPTIONS" value="" />
+    <option name="PARENT_ENVS" value="true" />
+    <envs>
+      <env name="PYTHONUNBUFFERED" value="1" />
+    </envs>
+    <option name="SDK_HOME" value="$PROJECT_DIR$/venv/bin/python" />
+    <option name="SDK_NAME" value="Python 3.9 (llm-inference)" />
+    <option name="WORKING_DIRECTORY" value="" />
+    <option name="IS_MODULE_SDK" value="false" />
+    <option name="ADD_CONTENT_ROOTS" value="true" />
+    <option name="ADD_SOURCE_ROOTS" value="true" />
+    <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+    <EXTENSION ID="net.ashald.envfile">
+      <option name="IS_ENABLED" value="true" />
+      <option name="IS_SUBST" value="false" />
+      <option name="IS_PATH_MACRO_SUPPORTED" value="false" />
+      <option name="IS_IGNORE_MISSING_FILES" value="false" />
+      <option name="IS_ENABLE_EXPERIMENTAL_INTEGRATIONS" value="false" />
+      <ENTRIES>
+        <ENTRY IS_ENABLED="true" PARSER="runconfig" IS_EXECUTABLE="false" />
+        <ENTRY IS_ENABLED="true" PARSER="env" IS_EXECUTABLE="false" PATH=".env" />
+      </ENTRIES>
+    </EXTENSION>
+    <option name="SCRIPT_NAME" value="llm_inference" />
+    <option name="PARAMETERS" value="--task EMBEDDING --port 8081 --model cmarkea/bloomz-560m-retriever-v2" />
+    <option name="SHOW_COMMAND_LINE" value="false" />
+    <option name="EMULATE_TERMINAL" value="false" />
+    <option name="MODULE_MODE" value="true" />
+    <option name="REDIRECT_INPUT" value="false" />
+    <option name="INPUT_FILE" value="" />
+    <method v="2" />
+  </configuration>
+</component>
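For readers not using PyCharm: `MODULE_MODE="true"` together with `SCRIPT_NAME="llm_inference"` is PyCharm's way of saying `python -m llm_inference`, and the `PARAMETERS` option carries the CLI arguments. A minimal sketch of the equivalent invocation, assuming the project's venv interpreter is active and the working directory is the project root:

```python
# Rough CLI equivalent of the "llm_inference - EMBEDDING" run configuration.
# MODULE_MODE="true" + SCRIPT_NAME="llm_inference" -> `python -m llm_inference`;
# the arguments mirror the PARAMETERS option. Assumes the venv interpreter
# ($PROJECT_DIR$/venv/bin/python in the config) resolves as `python`.
import subprocess

subprocess.run(
    [
        "python", "-m", "llm_inference",
        "--task", "EMBEDDING",
        "--port", "8081",
        "--model", "cmarkea/bloomz-560m-retriever-v2",
    ],
    check=True,
)
```

The GUARDRAIL and SCORING configurations below differ only in the `--task`, `--port`, and `--model` arguments.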
.run/llm_inference - GUARDRAIL.run.xml

+36
@@ -0,0 +1,36 @@
+<component name="ProjectRunConfigurationManager">
+  <configuration default="false" name="llm_inference - GUARDRAIL" type="PythonConfigurationType" factoryName="Python">
+    <module name="llm-inference" />
+    <option name="INTERPRETER_OPTIONS" value="" />
+    <option name="PARENT_ENVS" value="true" />
+    <envs>
+      <env name="PYTHONUNBUFFERED" value="1" />
+    </envs>
+    <option name="SDK_HOME" value="$PROJECT_DIR$/venv/bin/python" />
+    <option name="SDK_NAME" value="Python 3.9 (llm-inference)" />
+    <option name="WORKING_DIRECTORY" value="" />
+    <option name="IS_MODULE_SDK" value="false" />
+    <option name="ADD_CONTENT_ROOTS" value="true" />
+    <option name="ADD_SOURCE_ROOTS" value="true" />
+    <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+    <EXTENSION ID="net.ashald.envfile">
+      <option name="IS_ENABLED" value="true" />
+      <option name="IS_SUBST" value="false" />
+      <option name="IS_PATH_MACRO_SUPPORTED" value="false" />
+      <option name="IS_IGNORE_MISSING_FILES" value="false" />
+      <option name="IS_ENABLE_EXPERIMENTAL_INTEGRATIONS" value="false" />
+      <ENTRIES>
+        <ENTRY IS_ENABLED="true" PARSER="runconfig" IS_EXECUTABLE="false" />
+        <ENTRY IS_ENABLED="true" PARSER="env" IS_EXECUTABLE="false" PATH=".env" />
+      </ENTRIES>
+    </EXTENSION>
+    <option name="SCRIPT_NAME" value="llm_inference" />
+    <option name="PARAMETERS" value="--task GUARDRAIL --port 8083 --model cmarkea/bloomz-560m-guardrail" />
+    <option name="SHOW_COMMAND_LINE" value="false" />
+    <option name="EMULATE_TERMINAL" value="false" />
+    <option name="MODULE_MODE" value="true" />
+    <option name="REDIRECT_INPUT" value="false" />
+    <option name="INPUT_FILE" value="" />
+    <method v="2" />
+  </configuration>
+</component>

.run/llm_inference - SCORING.run.xml

+36
@@ -0,0 +1,36 @@
+<component name="ProjectRunConfigurationManager">
+  <configuration default="false" name="llm_inference - SCORING" type="PythonConfigurationType" factoryName="Python">
+    <module name="llm-inference" />
+    <option name="INTERPRETER_OPTIONS" value="" />
+    <option name="PARENT_ENVS" value="true" />
+    <envs>
+      <env name="PYTHONUNBUFFERED" value="1" />
+    </envs>
+    <option name="SDK_HOME" value="$PROJECT_DIR$/venv/bin/python" />
+    <option name="SDK_NAME" value="Python 3.9 (llm-inference)" />
+    <option name="WORKING_DIRECTORY" value="" />
+    <option name="IS_MODULE_SDK" value="false" />
+    <option name="ADD_CONTENT_ROOTS" value="true" />
+    <option name="ADD_SOURCE_ROOTS" value="true" />
+    <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+    <EXTENSION ID="net.ashald.envfile">
+      <option name="IS_ENABLED" value="true" />
+      <option name="IS_SUBST" value="false" />
+      <option name="IS_PATH_MACRO_SUPPORTED" value="false" />
+      <option name="IS_IGNORE_MISSING_FILES" value="false" />
+      <option name="IS_ENABLE_EXPERIMENTAL_INTEGRATIONS" value="false" />
+      <ENTRIES>
+        <ENTRY IS_ENABLED="true" PARSER="runconfig" IS_EXECUTABLE="false" />
+        <ENTRY IS_ENABLED="true" PARSER="env" IS_EXECUTABLE="false" PATH=".env" />
+      </ENTRIES>
+    </EXTENSION>
+    <option name="SCRIPT_NAME" value="llm_inference" />
+    <option name="PARAMETERS" value="--task SCORING --port 8082 --model cmarkea/bloomz-560m-reranking" />
+    <option name="SHOW_COMMAND_LINE" value="false" />
+    <option name="EMULATE_TERMINAL" value="false" />
+    <option name="MODULE_MODE" value="true" />
+    <option name="REDIRECT_INPUT" value="false" />
+    <option name="INPUT_FILE" value="" />
+    <method v="2" />
+  </configuration>
+</component>

README.md

+34 -8
@@ -39,14 +39,40 @@ HUGGING_FACE_HUB_TOKEN="<YOUR HF HUB TOKEN>"
 
 ### Running the Server
 
+Each task requires a specific model. Here are some Open Source models based on the bloomz architecture that you
+could use; you will find the latest models in [Crédit Mutuel Arkéa's Hugging Face ODQA collection](https://huggingface.co/collections/cmarkea/odqa-65f56ecd2b3e8e993a9982d6).
+
+If you are using PyCharm, the run configurations in `.run/*.run.xml` should appear directly in PyCharm; they use the smallest models.
+
+#### Embedding server
+Used to vectorise documents. Search for the `*-retriever` [models](https://huggingface.co/collections/cmarkea/odqa-65f56ecd2b3e8e993a9982d6), then start the inference server like this (smallest model):
+```bash
+python -m llm_inference --task EMBEDDING --port 8081 --model cmarkea/bloomz-560m-retriever-v2
+```
+
+Then go to http://localhost:8081/docs.
+
+#### Reranking / Scoring server
+Used to rank several contexts according to a specific query. Search for the `*-reranking` [models](https://huggingface.co/collections/cmarkea/odqa-65f56ecd2b3e8e993a9982d6), then start the inference server like this (smallest model):
 ```bash
-python -m llm_inference --model "cmarkea/bloomz-3b-retriever-v2" --task EMBEDDING
+python -m llm_inference --task SCORING --port 8082 --model cmarkea/bloomz-560m-reranking
 ```
 
-The server is designed to run one task at a time. There are three different tasks:
-- EMBEDDING
-- SCORING
-- GUARDRAIL
+Then go to http://localhost:8082/docs.
+
+Be sure to check the examples in the model card of the model you use to understand the meaning of the output labels.
+For instance, for [**cmarkea/bloomz-560m-reranking**](https://huggingface.co/cmarkea/bloomz-560m-reranking), a `LABEL1`
+score near 1 means that the context is really similar to the query, as [described in the model card](https://huggingface.co/cmarkea/bloomz-560m-reranking#:~:text=context%20in%20contexts%0A%20%20%20%20%5D%0A)-,contexts_reranked,-%3D%20sorted().
+
+#### Guardrail
+
+Used to detect responses that would be toxic, for instance: insult, obscene, sexual_explicit, identity_attack...
+Our guardrail models are published as `*-guardrail` [models](https://huggingface.co/collections/cmarkea/odqa-65f56ecd2b3e8e993a9982d6).
+
+```bash
+python -m llm_inference --task GUARDRAIL --port 8083 --model cmarkea/bloomz-560m-guardrail
+```
+Then go to http://localhost:8083/docs.
 
 ### API Endpoints
 
@@ -78,9 +104,9 @@ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file
 
 ## Acknowledgments
 
-- [Bloomz](https://bloomz.ai) for providing the pre-trained models.
-- [Your Organization](https://yourorganization.com) for supporting this project.
+- [BigScience](https://bigscience.huggingface.co/) for providing the pre-trained models.
+- [Crédit Mutuel Arkéa](https://www.cm-arkea.com/) for supporting this project.
 
 ## Contact
 
-For any inquiries or support, please contact [your email](mailto:[email protected]).
+For any inquiries or support, open an issue on this repository.
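As a quick smoke test once one of these servers is up, a request shaped like `EmbeddingRequest` from `llm_inference/routes/models.py` should work. The `/embedding` route path below is an assumption (the diff shows the router but not where it is mounted), so verify the real path in http://localhost:8081/docs; the scoring and guardrail servers follow the same pattern on ports 8082 and 8083 with their own request schemas.

```python
# Hypothetical smoke test for the embedding server started above.
# The "/embedding" path is an assumption; the body matches EmbeddingRequest
# (text: List[str], pooling: "mean" | "last") from llm_inference/routes/models.py.
import requests

resp = requests.post(
    "http://localhost:8081/embedding",  # assumed route; verify in /docs
    json={"text": ["Hello world"], "pooling": "mean"},
)
resp.raise_for_status()
print(resp.json())
```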

llm_inference/routes/embedding.py

+3 -3
@@ -6,7 +6,7 @@
 
 from llm_inference import metrics
 from llm_inference.model import ServerPipeline
-from llm_inference.routes.models import EmbeddingResponse, EmbeddingRequest
+from llm_inference.routes.models import EmbeddingResponse, EmbeddingRequest, EmbeddingPooling
 
 router = APIRouter(tags=["Embedding"])
 logger = logging.getLogger(__name__)
@@ -28,9 +28,9 @@ def inference(request: EmbeddingRequest):
     outputs = ServerPipeline().pipeline(request.text)
 
     for i in range(len(outputs)):
-        if request.pooling == "mean":
+        if request.pooling == EmbeddingPooling.MEAN:
             outputs[i] = np.mean(outputs[i][0], axis=0).tolist()
-        elif request.pooling == "last":
+        elif request.pooling == EmbeddingPooling.LAST:
             outputs[i] = outputs[i][0][-1]
         else:
             return Response("Unsupported pooling method.", status_code=400)
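For context on the two branches this hunk touches: `MEAN` averages the token-level vectors into a single embedding, while `LAST` keeps only the final token's vector. A standalone sketch with invented shapes (the extra `[0]` in the route reflects the pipeline's nested output):

```python
# Standalone illustration of the two pooling strategies above.
# Shapes are invented for the example; the hidden size depends on the model.
import numpy as np

token_embeddings = np.random.rand(7, 1024)  # 7 tokens, hidden size 1024

mean_pooled = np.mean(token_embeddings, axis=0).tolist()  # average over tokens -> 1024 floats
last_pooled = token_embeddings[-1]                        # final token's vector
```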

llm_inference/routes/models.py

+8 -1
@@ -1,4 +1,5 @@
 import logging
+from enum import Enum
 from typing import List
 
 from pydantic import BaseModel
@@ -19,9 +20,15 @@ class ClassificationItem(BaseModel):
 class ScoringRequest(BaseModel):
     contexts: List[ScoringItem]
 
+
+class EmbeddingPooling(str, Enum):
+    MEAN = "mean"
+    LAST = "last"
+
+
 class EmbeddingRequest(BaseModel):
     text: List[str]
-    pooling: str
+    pooling: EmbeddingPooling
 
 
 class GuardrailRequest(BaseModel):
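The practical effect of this change: with `pooling` typed as an enum instead of a plain `str`, pydantic rejects unsupported values while parsing the request, before they could reach the route's 400 branch. A minimal sketch mirroring the models above:

```python
# Minimal sketch of the stricter validation the enum buys: pydantic coerces
# "mean"/"last" to enum members and rejects anything else at parse time.
from enum import Enum
from typing import List

from pydantic import BaseModel, ValidationError


class EmbeddingPooling(str, Enum):
    MEAN = "mean"
    LAST = "last"


class EmbeddingRequest(BaseModel):
    text: List[str]
    pooling: EmbeddingPooling


print(EmbeddingRequest(text=["hi"], pooling="mean").pooling)  # EmbeddingPooling.MEAN

try:
    EmbeddingRequest(text=["hi"], pooling="max")  # not a member -> rejected
except ValidationError as err:
    print(err)
```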

0 commit comments
