Implement longllmlingua compressor module (#459)
* just commit

* implement longllmlingua, but I can't test it because I don't have CUDA

* add split

* delete llm parameter

* add \n\n join

* add support.py

* add base.py

* fix pure result type to List

* make the longllmlingua node result a list of lists

* add pytest skip

* add docs

* edit requirements.txt

* add CUDA cache clearing and full.yaml
bwook00 authored May 23, 2024
1 parent 252a151 commit 7930b51
Showing 11 changed files with 164 additions and 0 deletions.
1 change: 1 addition & 0 deletions autorag/nodes/passagecompressor/__init__.py
@@ -1,3 +1,4 @@
from .longllmlingua import longllmlingua
from .pass_compressor import pass_compressor
from .refine import refine
from .tree_summarize import tree_summarize
9 changes: 9 additions & 0 deletions autorag/nodes/passagecompressor/base.py
@@ -46,6 +46,15 @@ def wrapper(
        )
        del llm
        result = list(map(lambda x: [x], result))
    elif func.__name__ == 'longllmlingua':
        result = func(
            queries=queries,
            contents=retrieved_contents,
            scores=retrieve_scores,
            ids=retrieved_ids,
            **kwargs
        )
        result = list(map(lambda x: [x], result))
    elif func.__name__ == 'pass_compressor':
        result = func(contents=retrieved_contents)
    else:
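The dispatch added here hands the full lists of queries, contents, scores, and ids to the compressor, then wraps each compressed string in a single-element list so every query still maps to a list of passages downstream. A standalone sketch of that shaping step (illustrative only, not the project's code):

```python
# One compressed string comes back per query...
compressed = ["compressed passage for q1", "compressed passage for q2"]

# ...but downstream nodes expect a list of passages per query,
# so each string is wrapped in its own single-element list.
result = list(map(lambda x: [x], compressed))
assert result == [["compressed passage for q1"], ["compressed passage for q2"]]
```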
87 changes: 87 additions & 0 deletions autorag/nodes/passagecompressor/longllmlingua.py
@@ -0,0 +1,87 @@
from typing import List, Optional

import torch
from llmlingua import PromptCompressor

from autorag.nodes.passagecompressor.base import passage_compressor_node


@passage_compressor_node
def longllmlingua(queries: List[str],
                  contents: List[List[str]],
                  scores,
                  ids,
                  model_name: str = "NousResearch/Llama-2-7b-hf",
                  instructions: Optional[str] = None,
                  target_token: int = 300,
                  **kwargs,
                  ) -> List[str]:
    """
    Compress the retrieved texts using LongLLMLingua.
    For more information, visit https://github.com/microsoft/LLMLingua.

    :param queries: The queries for the retrieved passages.
    :param contents: The contents of the retrieved passages.
    :param scores: The scores of the retrieved passages.
        Not used in this function, so you can pass an empty list.
    :param ids: The ids of the retrieved passages.
        Not used in this function, so you can pass an empty list.
    :param model_name: The model name to use for compression.
        Default is "NousResearch/Llama-2-7b-hf".
    :param instructions: The instructions for compression.
        Default is None. When None, default instructions are used.
    :param target_token: The target token count for compression.
        Default is 300.
    :param kwargs: Additional keyword arguments.
    :return: The list of compressed texts.
    """
    if instructions is None:
        instructions = "Given the context, please answer the final question"
    llm_lingua = PromptCompressor(
        model_name=model_name,
    )
    results = [llmlingua_pure(query, contents_, llm_lingua, instructions, target_token, **kwargs)
               for query, contents_ in zip(queries, contents)]

    del llm_lingua
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results


def llmlingua_pure(query: str,
                   contents: List[str],
                   llm_lingua: PromptCompressor,
                   instructions: str,
                   target_token: int = 300,
                   **kwargs,
                   ) -> str:
    """
    Return the compressed text.

    :param query: The query for the retrieved passages.
    :param contents: The contents of the retrieved passages.
    :param llm_lingua: The PromptCompressor instance that will be used to compress.
    :param instructions: The instructions for compression.
    :param target_token: The target token count for compression.
        Default is 300.
    :param kwargs: Additional keyword arguments.
    :return: The compressed text.
    """
    # Split by "\n\n" (recommended by the LongLLMLingua authors).
    new_context_texts = [c for context in contents for c in context.split("\n\n")]
    compressed_prompt = llm_lingua.compress_prompt(
        new_context_texts,
        question=query,
        instruction=instructions,
        rank_method="longllmlingua",
        target_token=target_token,
        **kwargs,
    )
    compressed_prompt_txt = compressed_prompt["compressed_prompt"]

    # Strip the instruction and the question, keeping only the compressed contexts.
    result = '\n\n'.join(compressed_prompt_txt.split("\n\n")[1:-1])

    return result
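Since the decorator exposes the original function via `__wrapped__` (the new test below relies on this as well), the compressor can be exercised directly. A minimal sketch with made-up data, assuming a machine that can load the compression model:

```python
from autorag.nodes.passagecompressor.longllmlingua import longllmlingua

# Illustrative inputs only: one query with two retrieved passages.
queries = ["What is the capital of France?"]
contents = [[
    "Paris is the capital and most populous city of France.\n\nIt lies on the Seine.",
    "France is a country in Western Europe with several major cities.",
]]

# scores and ids are unused by this compressor, so empty lists are fine.
compressed = longllmlingua.__wrapped__(queries, contents, [], [], target_token=100)
print(compressed[0])  # one compressed string per query
```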
1 change: 1 addition & 0 deletions autorag/support.py
@@ -53,6 +53,7 @@ def get_support_modules(module_name: str) -> Callable:
        'tree_summarize': ('autorag.nodes.passagecompressor', 'tree_summarize'),
        'pass_compressor': ('autorag.nodes.passagecompressor', 'pass_compressor'),
        'refine': ('autorag.nodes.passagecompressor', 'refine'),
        'longllmlingua': ('autorag.nodes.passagecompressor', 'longllmlingua'),
        # prompt_maker
        'fstring': ('autorag.nodes.promptmaker', 'fstring'),
        'long_context_reorder': ('autorag.nodes.promptmaker', 'long_context_reorder'),
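This registry entry is what lets a `module_type: longllmlingua` line in a config YAML resolve to the new function. A small sketch of the lookup:

```python
from autorag.support import get_support_modules

# The module_type string from config.yaml maps to the compressor callable.
compressor = get_support_modules('longllmlingua')
print(compressor.__name__)  # longllmlingua
```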
8 changes: 8 additions & 0 deletions docs/source/api_spec/autorag.nodes.passagecompressor.rst
@@ -12,6 +12,14 @@ autorag.nodes.passagecompressor.base module
   :undoc-members:
   :show-inheritance:

autorag.nodes.passagecompressor.longllmlingua module
----------------------------------------------------

.. automodule:: autorag.nodes.passagecompressor.longllmlingua
   :members:
   :undoc-members:
   :show-inheritance:

autorag.nodes.passagecompressor.pass\_compressor module
-------------------------------------------------------

25 changes: 25 additions & 0 deletions docs/source/nodes/passage_compressor/longllmlingua.md
@@ -0,0 +1,25 @@
# Long LLM Lingua

The `Long LLM Lingua` module is a compressor based on [LLMLingua](https://github.com/microsoft/LLMLingua).

It compresses the retrieved texts using LongLLMLingua.

## **Module Parameters**

**model_name**: The name of the LLM to be used for compression, defaulting to "NousResearch/Llama-2-7b-hf".

**instructions**: Optional instructions for the LLM, defaulting to "Given the context, please answer the final question".

**target_token**: The target token count for the output, defaulting to 300.

- **Additional Parameters**:
  You can pass any additional parameters supported by llm_lingua.
  Find them [here](https://github.com/microsoft/LLMLingua).

## **Example config.yaml**

```yaml
modules:
- module_type: longllmlingua
```
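The parameters above can also be set directly on the module entry. A fuller sketch of a config with illustrative values (only `module_type` is required; everything else falls back to the defaults listed above):

```yaml
modules:
  - module_type: longllmlingua
    model_name: NousResearch/Llama-2-7b-hf
    instructions: "Given the context, please answer the final question"
    target_token: 300
```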
1 change: 1 addition & 0 deletions docs/source/nodes/passage_compressor/passage_compressor.md
@@ -53,4 +53,5 @@ maxdepth: 1
---
tree_summarize.md
refine.md
longllmlingua.md
```
1 change: 1 addition & 0 deletions requirements.txt
@@ -29,6 +29,7 @@ FlagEmbedding  # for flag embedding reranker
ragas # evaluation data generation & evaluation
ray # for parallel processing
kiwipiepy # for BM25 Korean tokenizer
llmlingua # for longllmlingua

### LlamaIndex ###
llama-index>=0.10.1
1 change: 1 addition & 0 deletions sample_config/full.yaml
@@ -108,6 +108,7 @@ node_lines:
          - module_type: refine
            llm: openai
            model: gpt-3.5-turbo-16k
          - module_type: longllmlingua
  - node_line_name: post_retrieve_node_line  # Arbitrary node line name
    nodes:
      - node_type: prompt_maker
29 changes: 29 additions & 0 deletions tests/autorag/nodes/passagecompressor/test_longllmlingua.py
@@ -0,0 +1,29 @@
import pandas as pd
import pytest

from autorag.nodes.passagecompressor import longllmlingua
from tests.autorag.nodes.passagecompressor.test_base_passage_compressor import (queries, retrieved_contents,
                                                                                 check_result, df)


@pytest.mark.skip(reason="This test needs a CUDA-enabled machine.")
def test_longllmlingua():
    result = longllmlingua.__wrapped__(queries, retrieved_contents, [], [])
    check_result(result)


@pytest.mark.skip(reason="This test needs a CUDA-enabled machine.")
def test_longllmlingua_node():
    result = longllmlingua(
        "project_dir",
        df,
        target_token=75,
    )
    assert isinstance(result, pd.DataFrame)
    contents = result['retrieved_contents'].tolist()
    assert isinstance(contents, list)
    assert len(contents) == len(queries)
    assert isinstance(contents[0], list)
    assert len(contents[0]) == 1
    assert isinstance(contents[0][0], str)
    assert bool(contents[0][0]) is True
1 change: 1 addition & 0 deletions tests/resources/full.yaml
@@ -56,6 +56,7 @@ node_lines:
          - module_type: tree_summarize
            llm: openai
            model: gpt-3.5-turbo-16k
          - module_type: longllmlingua
  - node_line_name: post_retrieve_node_line  # Arbitrary node line name
    nodes:
      - node_type: prompt_maker
