Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix prompt helper init #11379

Merged
merged 5 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from llama_index.core.callbacks.schema import CBEventType, EventPayload
from llama_index.core.data_structs.table import StructDatapoint
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.node_parser.interface import TextSplitter
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.prompts.default_prompt_selectors import (
Expand All @@ -26,7 +27,6 @@
Settings,
callback_manager_from_settings_or_context,
llm_from_settings_or_context,
prompt_helper_from_settings_or_context,
)
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from llama_index.core.utils import truncate_text
Expand Down Expand Up @@ -67,8 +67,10 @@ def __init__(
self._sql_database = sql_database
self._text_splitter = text_splitter
self._llm = llm or llm_from_settings_or_context(Settings, service_context)
self._prompt_helper = prompt_helper_from_settings_or_context(
Settings, service_context
self._prompt_helper = PromptHelper.from_llm_metadata(
self._llm.metadata,
num_output=Settings.num_output,
context_window=Settings.context_window,
)
self._callback_manager = callback_manager_from_settings_or_context(
Settings, service_context
Expand Down
9 changes: 5 additions & 4 deletions llama-index-core/llama_index/core/indices/common_tree/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Common classes/functions for tree index operations."""


import asyncio
import logging
from typing import Dict, List, Optional, Sequence, Tuple

from llama_index.core.async_utils import run_async_tasks
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from llama_index.core.data_structs.data_structs import IndexGraph
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.indices.utils import get_sorted_node_list, truncate_text
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import BasePromptTemplate
Expand All @@ -17,7 +17,6 @@
Settings,
callback_manager_from_settings_or_context,
llm_from_settings_or_context,
prompt_helper_from_settings_or_context,
)
from llama_index.core.storage.docstore import BaseDocumentStore
from llama_index.core.storage.docstore.registry import get_default_docstore
Expand Down Expand Up @@ -50,8 +49,10 @@ def __init__(
self.num_children = num_children
self.summary_prompt = summary_prompt
self._llm = llm or llm_from_settings_or_context(Settings, service_context)
self._prompt_helper = prompt_helper_from_settings_or_context(
Settings, service_context
self._prompt_helper = PromptHelper.from_llm_metadata(
self._llm.metadata,
num_output=Settings.num_output,
context_window=Settings.context_window,
)
self._callback_manager = callback_manager_from_settings_or_context(
Settings, service_context
Expand Down
14 changes: 9 additions & 5 deletions llama-index-core/llama_index/core/indices/prompt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,21 @@ def from_llm_metadata(
chunk_size_limit: Optional[int] = None,
tokenizer: Optional[Callable[[str], List]] = None,
separator: str = " ",
context_window: Optional[int] = None,
num_output: Optional[int] = None,
) -> "PromptHelper":
"""Create from llm predictor.

This will autofill values like context_window and num_output.

"""
context_window = llm_metadata.context_window
if llm_metadata.num_output == -1:
num_output = DEFAULT_NUM_OUTPUTS
else:
num_output = llm_metadata.num_output
context_window = context_window or llm_metadata.context_window

if num_output is None:
if llm_metadata.num_output == -1:
num_output = DEFAULT_NUM_OUTPUTS
else:
num_output = llm_metadata.num_output

return cls(
context_window=context_window,
Expand Down
8 changes: 5 additions & 3 deletions llama-index-core/llama_index/core/indices/tree/inserter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Sequence

from llama_index.core.data_structs.data_structs import IndexGraph
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.indices.tree.utils import get_numbered_text_from_nodes
from llama_index.core.indices.utils import (
extract_numbers_given_response,
Expand All @@ -19,7 +20,6 @@
from llama_index.core.settings import (
Settings,
llm_from_settings_or_context,
prompt_helper_from_settings_or_context,
)
from llama_index.core.storage.docstore import BaseDocumentStore
from llama_index.core.storage.docstore.registry import get_default_docstore
Expand All @@ -46,8 +46,10 @@ def __init__(
self.insert_prompt = insert_prompt
self.index_graph = index_graph
self._llm = llm or llm_from_settings_or_context(Settings, service_context)
self._prompt_helper = prompt_helper_from_settings_or_context(
Settings, service_context
self._prompt_helper = PromptHelper.from_llm_metadata(
self._llm.metadata,
num_output=Settings.num_output,
context_window=Settings.context_window,
)
self._docstore = docstore or get_default_docstore()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.base.response.schema import Response
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.indices.query.schema import QueryBundle
from llama_index.core.indices.tree.base import TreeIndex
from llama_index.core.indices.tree.utils import get_numbered_text_from_nodes
Expand All @@ -32,7 +33,6 @@
from llama_index.core.settings import (
Settings,
callback_manager_from_settings_or_context,
prompt_helper_from_settings_or_context,
)
from llama_index.core.utils import print_text, truncate_text

Expand Down Expand Up @@ -93,8 +93,10 @@ def __init__(
self._index_struct = index.index_struct
self._docstore = index.docstore
self._service_context = index.service_context
self._prompt_helper = prompt_helper_from_settings_or_context(
Settings, index.service_context
self._prompt_helper = PromptHelper.from_llm_metadata(
self._llm.metadata,
num_output=Settings.num_output,
context_window=Settings.context_window,
)

self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
2) create and refine separately over each chunk, 3) tree summarization.

"""

import logging
from abc import abstractmethod
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
Expand Down Expand Up @@ -41,7 +42,6 @@
Settings,
callback_manager_from_settings_or_context,
llm_from_settings_or_context,
prompt_helper_from_settings_or_context,
)
from llama_index.core.types import RESPONSE_TEXT_TYPE

Expand Down Expand Up @@ -69,8 +69,10 @@ def __init__(
callback_manager
or callback_manager_from_settings_or_context(Settings, service_context)
)
self._prompt_helper = prompt_helper or prompt_helper_from_settings_or_context(
Settings, service_context
self._prompt_helper = PromptHelper.from_llm_metadata(
self._llm.metadata,
num_output=Settings.num_output,
context_window=Settings.context_window,
)

self._streaming = streaming
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
Settings,
callback_manager_from_settings_or_context,
llm_from_settings_or_context,
prompt_helper_from_settings_or_context,
)
from llama_index.core.types import BasePydanticProgram

Expand Down Expand Up @@ -63,8 +62,10 @@ def get_response_synthesizer(
Settings, service_context
)
llm = llm or llm_from_settings_or_context(Settings, service_context)
prompt_helper = prompt_helper or prompt_helper_from_settings_or_context(
Settings, service_context
prompt_helper = prompt_helper or PromptHelper.from_llm_metadata(
llm.metadata,
num_output=Settings.num_output,
context_window=Settings.context_window,
)

if response_mode == ResponseMode.REFINE:
Expand Down
50 changes: 3 additions & 47 deletions llama-index-core/llama_index/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class _Settings:
_prompt_helper: Optional[PromptHelper] = None
_transformations: Optional[List[TransformComponent]] = None

num_output: Optional[int] = None
context_window: Optional[int] = None

# ---- LLM ----

@property
Expand Down Expand Up @@ -197,43 +200,6 @@ def text_splitter(self, text_splitter: NodeParser) -> None:
"""Set the text splitter."""
self.node_parser = text_splitter

# ---- Prompt helper ----

@property
def prompt_helper(self) -> PromptHelper:
"""Get the prompt helper."""
if self._llm is not None and self._prompt_helper is None:
self._prompt_helper = PromptHelper.from_llm_metadata(self._llm.metadata)
elif self._prompt_helper is None:
self._prompt_helper = PromptHelper()

return self._prompt_helper

@prompt_helper.setter
def prompt_helper(self, prompt_helper: PromptHelper) -> None:
"""Set the prompt helper."""
self._prompt_helper = prompt_helper

@property
def num_output(self) -> int:
"""Get the number of outputs."""
return self.prompt_helper.num_output

@num_output.setter
def num_output(self, num_output: int) -> None:
"""Set the number of outputs."""
self.prompt_helper.num_output = num_output

@property
def context_window(self) -> int:
logan-markewich marked this conversation as resolved.
Show resolved Hide resolved
"""Get the context window."""
return self.prompt_helper.context_window

@context_window.setter
def context_window(self, context_window: int) -> None:
"""Set the context window."""
self.prompt_helper.context_window = context_window

# ---- Transformations ----

@property
Expand Down Expand Up @@ -296,16 +262,6 @@ def node_parser_from_settings_or_context(
return settings.node_parser


def prompt_helper_from_settings_or_context(
settings: _Settings, context: Optional["ServiceContext"]
) -> PromptHelper:
"""Get settings from either settings or context."""
if context is not None:
return context.prompt_helper

return settings.prompt_helper


def transformations_from_settings_or_context(
settings: _Settings, context: Optional["ServiceContext"]
) -> List[TransformComponent]:
Expand Down
Loading