Skip to content

Commit

Permalink
refactor: Make backend orchestrator, parser, and plugins PEP8 (#879)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurarpit authored May 14, 2024
1 parent fab857b commit e60816d
Show file tree
Hide file tree
Showing 37 changed files with 38 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ...document_chunking.chunking_strategy import ChunkingStrategy, ChunkingSettings
from ...document_loading import LoadingSettings, LoadingStrategy
from .embedding_config import EmbeddingConfig
from ...orchestrator.OrchestrationStrategy import OrchestrationStrategy
from ...orchestrator.orchestration_strategy import OrchestrationStrategy
from ...orchestrator import OrchestrationSettings
from ..env_helper import EnvHelper

Expand Down
4 changes: 2 additions & 2 deletions code/backend/batch/utilities/helpers/orchestrator_helper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List

from ..orchestrator.OrchestrationStrategy import OrchestrationStrategy
from ..orchestrator.orchestration_strategy import OrchestrationStrategy
from ..orchestrator import OrchestrationSettings
from ..orchestrator.Strategies import get_orchestrator
from ..orchestrator.strategies import get_orchestrator

__all__ = ["OrchestrationStrategy"]

Expand Down
2 changes: 1 addition & 1 deletion code/backend/batch/utilities/orchestrator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List
import os.path
import pkgutil
from .OrchestrationStrategy import OrchestrationStrategy
from .orchestration_strategy import OrchestrationStrategy


class OrchestrationSettings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain.chains import LLMChain
from langchain_community.callbacks import get_openai_callback

from .OrchestratorBase import OrchestratorBase
from .orchestrator_base import OrchestratorBase
from ..helpers.llm_helper import LLMHelper
from ..tools.PostPromptTool import PostPromptTool
from ..tools.QuestionAnswerTool import QuestionAnswerTool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List
import json

from .OrchestratorBase import OrchestratorBase
from .orchestrator_base import OrchestratorBase
from ..helpers.llm_helper import LLMHelper
from ..tools.PostPromptTool import PostPromptTool
from ..tools.QuestionAnswerTool import QuestionAnswerTool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from ..loggers.ConversationLogger import ConversationLogger
from ..helpers.config.config_helper import ConfigHelper
from ..parser.OutputParserTool import OutputParserTool
from ..parser.output_parser_tool import OutputParserTool
from ..tools.ContentSafetyChecker import ContentSafetyChecker

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

from ..common.Answer import Answer
from ..helpers.llm_helper import LLMHelper
from ..plugins.ChatPlugin import ChatPlugin
from ..plugins.PostAnsweringPlugin import PostAnsweringPlugin
from .OrchestratorBase import OrchestratorBase
from ..plugins.chat_plugin import ChatPlugin
from ..plugins.post_answering_plugin import PostAnsweringPlugin
from .orchestrator_base import OrchestratorBase

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .OrchestrationStrategy import OrchestrationStrategy
from .OpenAIFunctions import OpenAIFunctionsOrchestrator
from .LangChainAgent import LangChainAgent
from .SemanticKernel import SemanticKernelOrchestrator
from .orchestration_strategy import OrchestrationStrategy
from .open_ai_functions import OpenAIFunctionsOrchestrator
from .lang_chain_agent import LangChainAgent
from .semantic_kernel import SemanticKernelOrchestrator


def get_orchestrator(orchestration_strategy: str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import re
import json
from .ParserBase import ParserBase
from .parser_base import ParserBase
from ..common.SourceDocument import SourceDocument

logger = logging.getLogger(__name__)
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion code/tests/test_OutputParserTool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import List

from backend.batch.utilities.parser.OutputParserTool import OutputParserTool
from backend.batch.utilities.parser.output_parser_tool import OutputParserTool
from backend.batch.utilities.common.SourceDocument import SourceDocument


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest

from backend.batch.utilities.orchestrator.LangChainAgent import LangChainAgent
from backend.batch.utilities.orchestrator.lang_chain_agent import LangChainAgent
from backend.batch.utilities.common.Answer import Answer


Expand Down Expand Up @@ -69,8 +69,8 @@ def test_run_text_processing_tool_returns_answer_json():
)


@patch("backend.batch.utilities.orchestrator.LangChainAgent.ZeroShotAgent")
@patch("backend.batch.utilities.orchestrator.LangChainAgent.LLMChain")
@patch("backend.batch.utilities.orchestrator.lang_chain_agent.ZeroShotAgent")
@patch("backend.batch.utilities.orchestrator.lang_chain_agent.LLMChain")
@patch("langchain.agents.AgentExecutor.from_agent_and_tools")
@pytest.mark.asyncio
async def test_orchestrate_langchain_to_orchestrate_chat(
Expand Down Expand Up @@ -100,8 +100,8 @@ async def test_orchestrate_langchain_to_orchestrate_chat(
)


@patch("backend.batch.utilities.orchestrator.LangChainAgent.ZeroShotAgent")
@patch("backend.batch.utilities.orchestrator.LangChainAgent.LLMChain")
@patch("backend.batch.utilities.orchestrator.lang_chain_agent.ZeroShotAgent")
@patch("backend.batch.utilities.orchestrator.lang_chain_agent.LLMChain")
@patch("langchain.agents.AgentExecutor.from_agent_and_tools")
@pytest.mark.asyncio
async def test_orchestrate_returns_error_message_on_Exception(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from unittest.mock import MagicMock, patch

import pytest
from backend.batch.utilities.orchestrator.OpenAIFunctions import (
from backend.batch.utilities.orchestrator.open_ai_functions import (
OpenAIFunctionsOrchestrator,
)
from backend.batch.utilities.parser.OutputParserTool import OutputParserTool
from backend.batch.utilities.parser.output_parser_tool import OutputParserTool


@pytest.fixture(autouse=True)
def llm_helper_mock():
with patch(
"backend.batch.utilities.orchestrator.OpenAIFunctions.LLMHelper"
"backend.batch.utilities.orchestrator.open_ai_functions.LLMHelper"
) as mock:
llm_helper = mock.return_value

Expand All @@ -20,7 +20,7 @@ def llm_helper_mock():
@pytest.fixture()
def orchestrator():
with patch(
"backend.batch.utilities.orchestrator.OpenAIFunctions.OrchestratorBase.__init__"
"backend.batch.utilities.orchestrator.open_ai_functions.OrchestratorBase.__init__"
):
orchestrator = OpenAIFunctionsOrchestrator()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch

import pytest
from backend.batch.utilities.orchestrator.OrchestratorBase import OrchestratorBase
from backend.batch.utilities.orchestrator.orchestrator_base import OrchestratorBase


class MockOrchestrator(OrchestratorBase):
Expand All @@ -14,7 +14,7 @@ async def orchestrate(
@pytest.fixture(autouse=True)
def config_mock():
with patch(
"backend.batch.utilities.orchestrator.OrchestratorBase.ConfigHelper"
"backend.batch.utilities.orchestrator.orchestrator_base.ConfigHelper"
) as mock:
config = mock.get_active_config_or_default.return_value
yield config
Expand All @@ -23,7 +23,7 @@ def config_mock():
@pytest.fixture(autouse=True)
def conversation_logger_mock():
with patch(
"backend.batch.utilities.orchestrator.OrchestratorBase.ConversationLogger"
"backend.batch.utilities.orchestrator.orchestrator_base.ConversationLogger"
) as mock:
conversation_logger = mock.return_value
yield conversation_logger
Expand All @@ -32,7 +32,7 @@ def conversation_logger_mock():
@pytest.fixture(autouse=True)
def content_safety_checker_mock():
with patch(
"backend.batch.utilities.orchestrator.OrchestratorBase.ContentSafetyChecker"
"backend.batch.utilities.orchestrator.orchestrator_base.ContentSafetyChecker"
) as mock:
content_safety_checker = mock.return_value
yield content_safety_checker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import pytest
from backend.batch.utilities.common.Answer import Answer
from backend.batch.utilities.orchestrator.SemanticKernel import (
from backend.batch.utilities.orchestrator.semantic_kernel import (
SemanticKernelOrchestrator,
)
from backend.batch.utilities.parser.OutputParserTool import OutputParserTool
from backend.batch.utilities.parser.output_parser_tool import OutputParserTool
from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
from semantic_kernel.connectors.ai.function_call_behavior import EnabledFunctions
Expand All @@ -31,7 +31,7 @@

@pytest.fixture(autouse=True)
def llm_helper_mock():
with patch("backend.batch.utilities.orchestrator.SemanticKernel.LLMHelper") as mock:
with patch("backend.batch.utilities.orchestrator.semantic_kernel.LLMHelper") as mock:
llm_helper = mock.return_value

llm_helper.get_sk_chat_completion_service.return_value = AzureChatCompletion(
Expand All @@ -53,7 +53,7 @@ def llm_helper_mock():
@pytest.fixture()
def orchestrator():
with patch(
"backend.batch.utilities.orchestrator.SemanticKernel.OrchestratorBase.__init__"
"backend.batch.utilities.orchestrator.semantic_kernel.OrchestratorBase.__init__"
):
orchestrator = SemanticKernelOrchestrator()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import pytest
from backend.batch.utilities.common.Answer import Answer
from backend.batch.utilities.plugins.ChatPlugin import ChatPlugin
from backend.batch.utilities.plugins.chat_plugin import ChatPlugin
from semantic_kernel import Kernel


@patch("backend.batch.utilities.plugins.ChatPlugin.QuestionAnswerTool")
@patch("backend.batch.utilities.plugins.chat_plugin.QuestionAnswerTool")
@pytest.mark.asyncio
async def test_search_documents(QuestionAnswerToolMock: MagicMock):
# given
Expand Down Expand Up @@ -40,7 +40,7 @@ async def test_search_documents(QuestionAnswerToolMock: MagicMock):
)


@patch("backend.batch.utilities.plugins.ChatPlugin.TextProcessingTool")
@patch("backend.batch.utilities.plugins.chat_plugin.TextProcessingTool")
@pytest.mark.asyncio
async def test_text_processing(TextProcessingToolMock: MagicMock):
# given
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import pytest
from backend.batch.utilities.common.Answer import Answer
from backend.batch.utilities.plugins.PostAnsweringPlugin import PostAnsweringPlugin
from backend.batch.utilities.plugins.post_answering_plugin import PostAnsweringPlugin
from semantic_kernel import Kernel


@patch("backend.batch.utilities.plugins.PostAnsweringPlugin.PostPromptTool")
@patch("backend.batch.utilities.plugins.post_answering_plugin.PostPromptTool")
@pytest.mark.asyncio
async def test_validate_answer(PostPromptToolMock: MagicMock):
# given
Expand Down

0 comments on commit e60816d

Please sign in to comment.