-
-
Notifications
You must be signed in to change notification settings - Fork 276
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement longllmlingua compressor module (#459)
* just commit * implement longllmlingua but i cant test it because i dont have cuda * add split * delete llm parameter * add \n\n join * add support.py * add base.py * fix pure result type List * make list of list at longllmlingua node result * add pytest skip * add docs * edit requirements.txt * add cuda cache delete and full.yaml
- Loading branch information
Showing
11 changed files
with
164 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .longllmlingua import longllmlingua | ||
from .pass_compressor import pass_compressor | ||
from .refine import refine | ||
from .tree_summarize import tree_summarize |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from typing import List, Optional | ||
|
||
import torch | ||
from llmlingua import PromptCompressor | ||
|
||
from autorag.nodes.passagecompressor.base import passage_compressor_node | ||
|
||
|
||
@passage_compressor_node
def longllmlingua(queries: List[str],
                  contents: List[List[str]],
                  scores,
                  ids,
                  model_name: str = "NousResearch/Llama-2-7b-hf",
                  instructions: Optional[str] = None,
                  target_token: int = 300,
                  **kwargs,
                  ) -> List[str]:
    """
    Compress the retrieved texts using LongLLMLingua.
    For more information, visit https://github.com/microsoft/LLMLingua.

    :param queries: The queries for retrieved passages.
    :param contents: The contents of retrieved passages, one list of passages per query.
    :param scores: The scores of retrieved passages.
        Not used by this function, so you can pass an empty list.
    :param ids: The ids of retrieved passages.
        Not used by this function, so you can pass an empty list.
    :param model_name: The model name to use for compression.
        Default is "NousResearch/Llama-2-7b-hf".
    :param instructions: The instructions for compression.
        Default is None. When it is None, a default instruction is used.
    :param target_token: The target token count for compression.
        Default is 300.
    :param kwargs: Additional keyword arguments, forwarded to ``llmlingua_pure``
        for every query.
    :return: The list of compressed texts, one per (query, contents) pair.
    """
    if instructions is None:
        instructions = "Given the context, please answer the final question"
    llm_lingua = PromptCompressor(
        model_name=model_name,
    )
    try:
        results = [llmlingua_pure(query, contents_, llm_lingua, instructions, target_token, **kwargs)
                   for query, contents_ in zip(queries, contents)]
    finally:
        # Run the cleanup on failure as well as success, so an exception while
        # compressing one query does not skip the CUDA cache release.
        del llm_lingua
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return results
|
||
|
||
def llmlingua_pure(query: str,
                   contents: List[str],
                   llm_lingua: PromptCompressor,
                   instructions: str,
                   target_token: int = 300,
                   **kwargs,
                   ) -> str:
    """
    Compress the passages retrieved for a single query and return the compressed text.

    :param query: The query for retrieved passages.
    :param contents: The contents of retrieved passages.
    :param llm_lingua: The compressor instance whose ``compress_prompt`` is called.
    :param instructions: The instructions for compression.
    :param target_token: The target token count for compression.
        Default is 300.
    :param kwargs: Additional keyword arguments, forwarded to ``compress_prompt``.
    :return: The compressed text.
    """
    # Feed the compressor chunks split on "\n\n" (recommended by LongLLMLingua authors).
    chunks = []
    for passage in contents:
        chunks.extend(passage.split("\n\n"))

    compression = llm_lingua.compress_prompt(
        chunks,
        question=query,
        instruction=instructions,
        rank_method="longllmlingua",
        target_token=target_token,
        **kwargs,
    )

    # Drop the first and last "\n\n"-separated segments; this is intended to
    # separate out the instruction and the question from the compressed context.
    segments = compression["compressed_prompt"].split("\n\n")
    return '\n\n'.join(segments[1:-1])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Long LLM Lingua | ||
|
||
The `Long LLM Lingua` module is a compressor based on [llmlingua](https://github.com/microsoft/LLMLingua). | ||
|
||
Compresses the retrieved texts using LongLLMLingua. | ||
|
||
## **Module Parameters** | ||
|
||
**model_name**: The name of the LLM to be used for compression, defaulting to "NousResearch/Llama-2-7b-hf". | ||
|
||
**instructions**: Optional instructions for the LLM, defaulting to "Given the context, please answer the final | ||
question". | ||
|
||
**target_token**: The target token count for the output, defaulting to 300. | ||
|
||
- **Additional Parameters**: | ||
Any additional parameters you set are passed through to LLMLingua's `compress_prompt`. | ||
See the available parameters [here](https://github.com/microsoft/LLMLingua). | ||
|
||
## **Example config.yaml** | ||
|
||
```yaml | ||
modules: | ||
- module_type: longllmlingua | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,4 +53,5 @@ maxdepth: 1 | |
--- | ||
tree_summarize.md | ||
refine.md | ||
longllmlingua.md | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
tests/autorag/nodes/passagecompressor/test_longllmlingua.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from autorag.nodes.passagecompressor import longllmlingua | ||
from tests.autorag.nodes.passagecompressor.test_base_passage_compressor import (queries, retrieved_contents, | ||
check_result, df) | ||
|
||
|
||
@pytest.mark.skip(reason="This test needs CUDA enabled machine.")
def test_longllmlingua():
    """Call the function exposed via ``__wrapped__`` on the shared fixtures and validate its output."""
    # Empty lists are passed for the scores and ids arguments.
    compressed = longllmlingua.__wrapped__(queries, retrieved_contents, [], [])
    check_result(compressed)
|
||
|
||
@pytest.mark.skip(reason="This test needs CUDA enabled machine.")
def test_longllmlingua_node():
    """Run the node-level entry point on the fixture DataFrame and check the output shape."""
    result_df = longllmlingua("project_dir", df, target_token=75)
    assert isinstance(result_df, pd.DataFrame)

    # One row of compressed contents is expected per query.
    compressed_rows = result_df['retrieved_contents'].tolist()
    assert isinstance(compressed_rows, list)
    assert len(compressed_rows) == len(queries)

    # Each row must be a single-element list holding one compressed passage.
    first_row = compressed_rows[0]
    assert isinstance(first_row, list)
    assert len(first_row) == 1

    # The compressed passage must be a non-empty string.
    first_text = first_row[0]
    assert isinstance(first_text, str)
    assert bool(first_text) is True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters