
Commit

Add RET ruff rule (#1065)
collindutter authored Aug 15, 2024
1 parent 9111570 commit 4d71eaf
Showing 38 changed files with 69 additions and 128 deletions.
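For context: ruff's RET rules are its port of flake8-return, and the pattern fixed most often in this commit is RET504 (a value assigned to a variable only to be returned on the next line), with a handful of RET503 (missing explicit return) and RET505/RET506 (unnecessary else after return/raise) fixes alongside it. The rule family is typically enabled by adding "RET" to the lint rule selection in pyproject.toml; that config change is not shown in this excerpt. Below is a minimal before/after sketch of the RET504 rewrite that most of the hunks follow; the function name is hypothetical and not taken from the repository.

# Before: flagged by RET504, the value is assigned only to be returned on the next line.
def build_greeting(name: str) -> str:
    greeting = f"Hello, {name}!"
    return greeting


# After: return the expression directly, as the hunks below do.
def build_greeting(name: str) -> str:
    return f"Hello, {name}!"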
8 changes: 2 additions & 6 deletions docs/examples/src/multi_agent_workflow_1.py
@@ -26,7 +26,7 @@

 def build_researcher() -> Agent:
     """Builds a Researcher Structure."""
-    researcher = Agent(
+    return Agent(
         id="researcher",
         tools=[
             WebSearchTool(
@@ -78,8 +78,6 @@ def build_researcher() -> Agent:
         ],
     )
-
-    return researcher


 def build_writer(role: str, goal: str, backstory: str) -> Agent:
     """Builds a Writer Structure.
@@ -89,7 +87,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent:
         goal: The goal of the writer.
         backstory: The backstory of the writer.
     """
-    writer = Agent(
+    return Agent(
         id=role.lower().replace(" ", "_"),
         rulesets=[
             Ruleset(
@@ -123,8 +121,6 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent:
         ],
     )
-
-    return writer


 if __name__ == "__main__":
     # Build the team
@@ -5,28 +5,24 @@


 def build_joke_teller() -> Agent:
-    joke_teller = Agent(
+    return Agent(
         rules=[
             Rule(
                 value="You are very funny.",
             )
         ],
     )
-
-    return joke_teller


 def build_joke_rewriter() -> Agent:
-    joke_rewriter = Agent(
+    return Agent(
         rules=[
             Rule(
                 value="You are the editor of a joke book. But you only speak in riddles",
             )
         ],
     )
-
-    return joke_rewriter


 joke_coordinator = Pipeline(
     tasks=[
8 changes: 2 additions & 6 deletions docs/griptape-framework/structures/src/tasks_16.py
@@ -12,7 +12,7 @@


 def build_researcher() -> Agent:
-    researcher = Agent(
+    return Agent(
         tools=[
             WebSearchTool(
                 web_search_driver=GoogleWebSearchDriver(
@@ -63,11 +63,9 @@ def build_researcher() -> Agent:
         ],
     )
-
-    return researcher


 def build_writer() -> Agent:
-    writer = Agent(
+    return Agent(
         input="Instructions: {{args[0]}}\nContext: {{args[1]}}",
         rulesets=[
             Ruleset(
@@ -106,8 +104,6 @@ def build_writer() -> Agent:
         ],
     )
-
-    return writer


 team = Pipeline(
     tasks=[
7 changes: 3 additions & 4 deletions docs/plugins/swagger_ui_plugin.py
@@ -20,8 +20,7 @@ def generate_page_contents(page: Any) -> str:
     env.filters["markdown"] = lambda text: Markup(md.convert(text))

     template = env.get_template(tmpl_url)
-    tmpl_out = template.render(spec_url=spec_url)
-    return tmpl_out
+    return template.render(spec_url=spec_url)


 def on_config(config: Any) -> None:
@@ -32,5 +31,5 @@ def on_page_read_source(page: Any, config: Any) -> Any:
     index_path = os.path.join(config["docs_dir"], config_scheme["outfile"])
     page_path = os.path.join(config["docs_dir"], page.file.src_path)
     if index_path == page_path:
-        contents = generate_page_contents(page)
-        return contents
+        return generate_page_contents(page)
+    return None
5 changes: 1 addition & 4 deletions griptape/common/prompt_stack/prompt_stack.py
@@ -71,10 +71,7 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage
             return [ActionResultMessageContent(output, action=action)]
         elif isinstance(artifact, ListArtifact):
             processed_contents = [self.__to_message_content(artifact) for artifact in artifact.value]
-            flattened_content = [
-                sub_content for processed_content in processed_contents for sub_content in processed_content
-            ]
+            return [sub_content for processed_content in processed_contents for sub_content in processed_content]

-            return flattened_content
         else:
             raise ValueError(f"Unsupported artifact type: {type(artifact)}")
@@ -40,8 +40,7 @@ def try_list_files(self, path: str) -> list[str]:
         if len(files_and_dirs) == 0:
             if len(self._list_files_and_dirs(full_key.rstrip("/"), max_items=1)) > 0:
                 raise NotADirectoryError
-            else:
-                raise FileNotFoundError
+            raise FileNotFoundError
         return files_and_dirs

     def try_load_file(self, path: str) -> bytes:
@@ -57,8 +56,7 @@ def try_load_file(self, path: str) -> bytes:
         except botocore.exceptions.ClientError as e:
             if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                 raise FileNotFoundError from e
-            else:
-                raise e
+            raise e

     def try_save_file(self, path: str, value: bytes) -> None:
         full_key = self._to_full_key(path)
@@ -141,5 +139,4 @@ def _normpath(self, path: str) -> str:
             else:
                 stack.append(part)

-        normalized_path = "/".join(stack)
-        return normalized_path
+        return "/".join(stack)
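Several hunks in this commit, including the botocore-based file manager driver above and the Trafilatura scraper and extraction engine further down, drop an else branch that immediately follows a raise or return; that is the RET505/RET506 family of checks. A minimal sketch of the rewrite, with hypothetical names and assuming the usual flake8-return semantics that ruff ports:

# Before: flagged by RET506, the `else` is unnecessary because the guard raises.
def read_text(path: str) -> str:
    if not path.endswith(".txt"):
        raise ValueError("expected a .txt file")
    else:
        with open(path) as f:
            return f.read()


# After: the guard raises, so the happy path can be dedented one level.
def read_text(path: str) -> str:
    if not path.endswith(".txt"):
        raise ValueError("expected a .txt file")
    with open(path) as f:
        return f.read()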
@@ -139,17 +139,15 @@ def _request_parameters(
             request["mask_source"] = mask_source
             request["mask_image"] = mask.base64

-        request = {k: v for k, v in request.items() if v is not None}
-
-        return request
+        return {k: v for k, v in request.items() if v is not None}

     def get_generated_image(self, response: dict) -> bytes:
         image_response = response["artifacts"][0]

         # finishReason may be SUCCESS, CONTENT_FILTERED, or ERROR.
         if image_response.get("finishReason") == "ERROR":
             raise Exception(f"Image generation failed: {image_response.get('finishReason')}")
-        elif image_response.get("finishReason") == "CONTENT_FILTERED":
+        if image_response.get("finishReason") == "CONTENT_FILTERED":
             logging.warning("Image generation triggered content filter and may be blurred")

         return base64.decodebytes(bytes(image_response.get("base64"), "utf-8"))
4 changes: 1 addition & 3 deletions griptape/drivers/image_query/anthropic_image_query_driver.py
@@ -47,9 +47,7 @@ def _base_params(self, text_query: str, images: list[ImageArtifact]) -> dict:
         content = [self._construct_image_message(image) for image in images]
         content.append(self._construct_text_message(text_query))
         messages = self._construct_messages(content)
-        params = {"model": self.model, "messages": messages, "max_tokens": self.max_tokens}
-
-        return params
+        return {"model": self.model, "messages": messages, "max_tokens": self.max_tokens}

     def _construct_image_message(self, image_data: ImageArtifact) -> dict:
         data = image_data.base64
@@ -14,9 +14,7 @@ def image_query_request_parameters(self, query: str, images: list[ImageArtifact]
         content = [self._construct_image_message(image) for image in images]
         content.append(self._construct_text_message(query))
         messages = self._construct_messages(content)
-        input_params = {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens}
-
-        return input_params
+        return {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens}

     def process_output(self, output: dict) -> TextArtifact:
         content_blocks = output["content"]
12 changes: 3 additions & 9 deletions griptape/drivers/prompt/base_prompt_driver.py
@@ -108,9 +108,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
     def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...

     def __process_run(self, prompt_stack: PromptStack) -> Message:
-        result = self.try_run(prompt_stack)
-
-        return result
+        return self.try_run(prompt_stack)

     def __process_stream(self, prompt_stack: PromptStack) -> Message:
         delta_contents: dict[int, list[BaseDeltaMessageContent]] = {}
@@ -136,9 +134,7 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message:
             event_bus.publish_event(CompletionChunkEvent(token=content.partial_input))

         # Build a complete content from the content deltas
-        result = self.__build_message(list(delta_contents.values()), usage)
-
-        return result
+        return self.__build_message(list(delta_contents.values()), usage)

     def __build_message(
         self, delta_contents: list[list[BaseDeltaMessageContent]], usage: DeltaMessage.Usage
@@ -153,10 +149,8 @@ def __build_message(
         if action_deltas:
             content.append(ActionCallMessageContent.from_deltas(action_deltas))

-        result = Message(
+        return Message(
             content=content,
             role=Message.ASSISTANT_ROLE,
             usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
         )
-
-        return result
4 changes: 1 addition & 3 deletions griptape/drivers/prompt/google_prompt_driver.py
@@ -155,7 +155,7 @@ def _default_model_client(self) -> GenerativeModel:
     def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
         types = import_optional_dependency("google.generativeai.types")

-        inputs = [
+        return [
             types.ContentDict(
                 {
                     "role": self.__to_google_role(message),
@@ -166,8 +166,6 @@ def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
             if not message.is_system()
         ]

-        return inputs
-
     def __to_google_role(self, message: Message) -> str:
         if message.is_assistant():
             return "model"
3 changes: 2 additions & 1 deletion griptape/drivers/sql/amazon_redshift_sql_driver.py
@@ -29,7 +29,7 @@ class AmazonRedshiftSqlDriver(BaseSqlDriver):
     def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None:
         if not self.cluster_identifier and not self.workgroup_name:
             raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
-        elif self.cluster_identifier and self.workgroup_name:
+        if self.cluster_identifier and self.workgroup_name:
             raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

     @classmethod
@@ -92,6 +92,7 @@ def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]

             elif statement["Status"] in ["FAILED", "ABORTED"]:
                 return None
+        return None

     def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
         function_kwargs = {"Database": self.database, "Table": table_name}
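The return None lines added here, in sql_driver.py below, and in the skip_loading_images route handler later in the diff look like RET503 fixes: when any branch of a function returns a value, every other path should end with an explicit return None instead of falling off the end. A minimal sketch with hypothetical names, assuming standard flake8-return behavior:

from typing import Optional


# Before: flagged by RET503, the loop can fall through and return None implicitly.
def find_row(rows: list[dict], key: str) -> Optional[dict]:
    for row in rows:
        if row.get("id") == key:
            return row


# After: the fall-through path is made explicit.
def find_row(rows: list[dict], key: str) -> Optional[dict]:
    for row in rows:
        if row.get("id") == key:
            return row
    return None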
1 change: 1 addition & 0 deletions griptape/drivers/sql/sql_driver.py
@@ -41,6 +41,7 @@ def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]
                     return [dict(result._mapping) for result in results]
                 else:
                     con.commit()
+                    return None
             else:
                 raise ValueError("No result found")

@@ -105,8 +105,7 @@ def query(

         response = requests.post(url, json=request, headers=self.headers).json()
         entries = response.get("entries", [])
-        entry_list = [BaseVectorStoreDriver.Entry.from_dict(entry) for entry in entries]
-        return entry_list
+        return [BaseVectorStoreDriver.Entry.from_dict(entry) for entry in entries]

     def delete_vector(self, vector_id: str) -> NoReturn:
         raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
4 changes: 1 addition & 3 deletions griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
@@ -160,7 +160,7 @@ def query(
         if namespace:
             pipeline[0]["$vectorSearch"]["filter"] = {"namespace": namespace}

-        results = [
+        return [
             BaseVectorStoreDriver.Entry(
                 id=str(doc["_id"]),
                 vector=doc[self.vector_path] if include_vectors else [],
@@ -171,8 +171,6 @@ def query(
             for doc in collection.aggregate(pipeline)
         ]

-        return results
-
     def delete_vector(self, vector_id: str) -> None:
         """Deletes the vector from the collection."""
         collection = self.get_collection()
6 changes: 2 additions & 4 deletions griptape/drivers/vector/opensearch_vector_store_driver.py
@@ -83,13 +83,12 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti

             if response["hits"]["total"]["value"] > 0:
                 vector_data = response["hits"]["hits"][0]["_source"]
-                entry = BaseVectorStoreDriver.Entry(
+                return BaseVectorStoreDriver.Entry(
                     id=vector_id,
                     meta=vector_data.get("metadata"),
                     vector=vector_data.get("vector"),
                     namespace=vector_data.get("namespace"),
                 )
-                return entry
             else:
                 return None
         except Exception as e:
@@ -109,7 +108,7 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto

         response = self.client.search(index=self.index_name, body=query_body)

-        entries = [
+        return [
             BaseVectorStoreDriver.Entry(
                 id=hit["_id"],
                 vector=hit["_source"].get("vector"),
@@ -118,7 +117,6 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
                 meta=hit["_source"].get("metadata"),
             )
             for hit in response["hits"]["hits"]
         ]
-        return entries

     def query(
3 changes: 1 addition & 2 deletions griptape/drivers/vector/qdrant_vector_store_driver.py
@@ -114,7 +114,7 @@ def query(
         results = self.client.search(**request)

         # Convert results to QueryResult objects
-        query_results = [
+        return [
             BaseVectorStoreDriver.Entry(
                 id=result.id,
                 vector=result.vector if include_vectors else [],
@@ -123,7 +123,6 @@ def query(
             )
             for result in results
         ]
-        return query_results

     def upsert_vector(
         self,
@@ -60,6 +60,7 @@ def skip_loading_images(route: Any) -> Any:
             if route.request.resource_type == "image":
                 return route.abort()
             route.continue_()
+            return None

         page.route("**/*", skip_loading_images)

13 changes: 6 additions & 7 deletions griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py
@@ -29,13 +29,12 @@ def scrape_url(self, url: str) -> TextArtifact:

         if page is None:
             raise Exception("can't access URL")
-        else:
-            extracted_page = trafilatura.extract(
-                page,
-                include_links=self.include_links,
-                output_format="json",
-                config=config,
-            )
+        extracted_page = trafilatura.extract(
+            page,
+            include_links=self.include_links,
+            output_format="json",
+            config=config,
+        )

         if not extracted_page:
             raise Exception("can't extract page")
3 changes: 1 addition & 2 deletions griptape/drivers/web_search/google_web_search_driver.py
@@ -33,9 +33,8 @@ def _search_google(self, query: str, **kwargs) -> list[dict]:
         if response.status_code == 200:
             data = response.json()

-            links = [{"url": r["link"], "title": r["title"], "description": r["snippet"]} for r in data["items"]]
+            return [{"url": r["link"], "title": r["title"], "description": r["snippet"]} for r in data["items"]]

-            return links
         else:
             raise Exception(
                 f"Google Search API returned an error with status code "
2 changes: 1 addition & 1 deletion griptape/engines/extraction/base_extraction_engine.py
@@ -31,7 +31,7 @@ class BaseExtractionEngine(ABC):
     def validate_max_token_multiplier(self, _: Attribute, max_token_multiplier: int) -> None:
         if max_token_multiplier > 1:
             raise ValueError("has to be less than or equal to 1")
-        elif max_token_multiplier <= 0:
+        if max_token_multiplier <= 0:
             raise ValueError("has to be greater than 0")

     @property

0 comments on commit 4d71eaf
