diff --git a/docs/examples/src/multi_agent_workflow_1.py b/docs/examples/src/multi_agent_workflow_1.py index e1f00a9bd..ad9436a55 100644 --- a/docs/examples/src/multi_agent_workflow_1.py +++ b/docs/examples/src/multi_agent_workflow_1.py @@ -26,7 +26,7 @@ def build_researcher() -> Agent: """Builds a Researcher Structure.""" - researcher = Agent( + return Agent( id="researcher", tools=[ WebSearchTool( @@ -78,8 +78,6 @@ def build_researcher() -> Agent: ], ) - return researcher - def build_writer(role: str, goal: str, backstory: str) -> Agent: """Builds a Writer Structure. @@ -89,7 +87,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent: goal: The goal of the writer. backstory: The backstory of the writer. """ - writer = Agent( + return Agent( id=role.lower().replace(" ", "_"), rulesets=[ Ruleset( @@ -123,8 +121,6 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent: ], ) - return writer - if __name__ == "__main__": # Build the team diff --git a/docs/griptape-framework/drivers/src/structure_run_drivers_1.py b/docs/griptape-framework/drivers/src/structure_run_drivers_1.py index 00ab6d60f..a29bfbedf 100644 --- a/docs/griptape-framework/drivers/src/structure_run_drivers_1.py +++ b/docs/griptape-framework/drivers/src/structure_run_drivers_1.py @@ -5,7 +5,7 @@ def build_joke_teller() -> Agent: - joke_teller = Agent( + return Agent( rules=[ Rule( value="You are very funny.", @@ -13,11 +13,9 @@ def build_joke_teller() -> Agent: ], ) - return joke_teller - def build_joke_rewriter() -> Agent: - joke_rewriter = Agent( + return Agent( rules=[ Rule( value="You are the editor of a joke book. But you only speak in riddles", @@ -25,8 +23,6 @@ def build_joke_rewriter() -> Agent: ], ) - return joke_rewriter - joke_coordinator = Pipeline( tasks=[ diff --git a/docs/griptape-framework/structures/src/tasks_16.py b/docs/griptape-framework/structures/src/tasks_16.py index a6da835a6..7496d2d9c 100644 --- a/docs/griptape-framework/structures/src/tasks_16.py +++ b/docs/griptape-framework/structures/src/tasks_16.py @@ -12,7 +12,7 @@ def build_researcher() -> Agent: - researcher = Agent( + return Agent( tools=[ WebSearchTool( web_search_driver=GoogleWebSearchDriver( @@ -63,11 +63,9 @@ def build_researcher() -> Agent: ], ) - return researcher - def build_writer() -> Agent: - writer = Agent( + return Agent( input="Instructions: {{args[0]}}\nContext: {{args[1]}}", rulesets=[ Ruleset( @@ -106,8 +104,6 @@ def build_writer() -> Agent: ], ) - return writer - team = Pipeline( tasks=[ diff --git a/docs/plugins/swagger_ui_plugin.py b/docs/plugins/swagger_ui_plugin.py index 499d74cf5..2f2ca2c4e 100644 --- a/docs/plugins/swagger_ui_plugin.py +++ b/docs/plugins/swagger_ui_plugin.py @@ -20,8 +20,7 @@ def generate_page_contents(page: Any) -> str: env.filters["markdown"] = lambda text: Markup(md.convert(text)) template = env.get_template(tmpl_url) - tmpl_out = template.render(spec_url=spec_url) - return tmpl_out + return template.render(spec_url=spec_url) def on_config(config: Any) -> None: @@ -32,5 +31,5 @@ def on_page_read_source(page: Any, config: Any) -> Any: index_path = os.path.join(config["docs_dir"], config_scheme["outfile"]) page_path = os.path.join(config["docs_dir"], page.file.src_path) if index_path == page_path: - contents = generate_page_contents(page) - return contents + return generate_page_contents(page) + return None diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 3186dac89..cf6b67040 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -71,10 +71,7 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage return [ActionResultMessageContent(output, action=action)] elif isinstance(artifact, ListArtifact): processed_contents = [self.__to_message_content(artifact) for artifact in artifact.value] - flattened_content = [ - sub_content for processed_content in processed_contents for sub_content in processed_content - ] + return [sub_content for processed_content in processed_contents for sub_content in processed_content] - return flattened_content else: raise ValueError(f"Unsupported artifact type: {type(artifact)}") diff --git a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py index e58e46d37..20e432c0b 100644 --- a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py +++ b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py @@ -40,8 +40,7 @@ def try_list_files(self, path: str) -> list[str]: if len(files_and_dirs) == 0: if len(self._list_files_and_dirs(full_key.rstrip("/"), max_items=1)) > 0: raise NotADirectoryError - else: - raise FileNotFoundError + raise FileNotFoundError return files_and_dirs def try_load_file(self, path: str) -> bytes: @@ -57,8 +56,7 @@ def try_load_file(self, path: str) -> bytes: except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] in {"NoSuchKey", "404"}: raise FileNotFoundError from e - else: - raise e + raise e def try_save_file(self, path: str, value: bytes) -> None: full_key = self._to_full_key(path) @@ -141,5 +139,4 @@ def _normpath(self, path: str) -> str: else: stack.append(part) - normalized_path = "/".join(stack) - return normalized_path + return "/".join(stack) diff --git a/griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py b/griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py index 92428e157..0ec7d03d2 100644 --- a/griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py +++ b/griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py @@ -139,9 +139,7 @@ def _request_parameters( request["mask_source"] = mask_source request["mask_image"] = mask.base64 - request = {k: v for k, v in request.items() if v is not None} - - return request + return {k: v for k, v in request.items() if v is not None} def get_generated_image(self, response: dict) -> bytes: image_response = response["artifacts"][0] @@ -149,7 +147,7 @@ def get_generated_image(self, response: dict) -> bytes: # finishReason may be SUCCESS, CONTENT_FILTERED, or ERROR. if image_response.get("finishReason") == "ERROR": raise Exception(f"Image generation failed: {image_response.get('finishReason')}") - elif image_response.get("finishReason") == "CONTENT_FILTERED": + if image_response.get("finishReason") == "CONTENT_FILTERED": logging.warning("Image generation triggered content filter and may be blurred") return base64.decodebytes(bytes(image_response.get("base64"), "utf-8")) diff --git a/griptape/drivers/image_query/anthropic_image_query_driver.py b/griptape/drivers/image_query/anthropic_image_query_driver.py index bd19862ec..a50685724 100644 --- a/griptape/drivers/image_query/anthropic_image_query_driver.py +++ b/griptape/drivers/image_query/anthropic_image_query_driver.py @@ -47,9 +47,7 @@ def _base_params(self, text_query: str, images: list[ImageArtifact]) -> dict: content = [self._construct_image_message(image) for image in images] content.append(self._construct_text_message(text_query)) messages = self._construct_messages(content) - params = {"model": self.model, "messages": messages, "max_tokens": self.max_tokens} - - return params + return {"model": self.model, "messages": messages, "max_tokens": self.max_tokens} def _construct_image_message(self, image_data: ImageArtifact) -> dict: data = image_data.base64 diff --git a/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py b/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py index 8260ce3d5..1785550a0 100644 --- a/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py +++ b/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py @@ -14,9 +14,7 @@ def image_query_request_parameters(self, query: str, images: list[ImageArtifact] content = [self._construct_image_message(image) for image in images] content.append(self._construct_text_message(query)) messages = self._construct_messages(content) - input_params = {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens} - - return input_params + return {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens} def process_output(self, output: dict) -> TextArtifact: content_blocks = output["content"] diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 8044469b5..43b31306c 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -108,9 +108,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ... def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ... def __process_run(self, prompt_stack: PromptStack) -> Message: - result = self.try_run(prompt_stack) - - return result + return self.try_run(prompt_stack) def __process_stream(self, prompt_stack: PromptStack) -> Message: delta_contents: dict[int, list[BaseDeltaMessageContent]] = {} @@ -136,9 +134,7 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: event_bus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas - result = self.__build_message(list(delta_contents.values()), usage) - - return result + return self.__build_message(list(delta_contents.values()), usage) def __build_message( self, delta_contents: list[list[BaseDeltaMessageContent]], usage: DeltaMessage.Usage @@ -153,10 +149,8 @@ def __build_message( if action_deltas: content.append(ActionCallMessageContent.from_deltas(action_deltas)) - result = Message( + return Message( content=content, role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), ) - - return result diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 06f9dfbe6..57ebbd338 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -155,7 +155,7 @@ def _default_model_client(self) -> GenerativeModel: def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType: types = import_optional_dependency("google.generativeai.types") - inputs = [ + return [ types.ContentDict( { "role": self.__to_google_role(message), @@ -166,8 +166,6 @@ def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType: if not message.is_system() ] - return inputs - def __to_google_role(self, message: Message) -> str: if message.is_assistant(): return "model" diff --git a/griptape/drivers/sql/amazon_redshift_sql_driver.py b/griptape/drivers/sql/amazon_redshift_sql_driver.py index 5ae85c495..837405e83 100644 --- a/griptape/drivers/sql/amazon_redshift_sql_driver.py +++ b/griptape/drivers/sql/amazon_redshift_sql_driver.py @@ -29,7 +29,7 @@ class AmazonRedshiftSqlDriver(BaseSqlDriver): def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None: if not self.cluster_identifier and not self.workgroup_name: raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`") - elif self.cluster_identifier and self.workgroup_name: + if self.cluster_identifier and self.workgroup_name: raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both") @classmethod @@ -92,6 +92,7 @@ def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any] elif statement["Status"] in ["FAILED", "ABORTED"]: return None + return None def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]: function_kwargs = {"Database": self.database, "Table": table_name} diff --git a/griptape/drivers/sql/sql_driver.py b/griptape/drivers/sql/sql_driver.py index 0e3d1d4b7..d2293f94d 100644 --- a/griptape/drivers/sql/sql_driver.py +++ b/griptape/drivers/sql/sql_driver.py @@ -41,6 +41,7 @@ def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any] return [dict(result._mapping) for result in results] else: con.commit() + return None else: raise ValueError("No result found") diff --git a/griptape/drivers/vector/griptape_cloud_knowledge_base_vector_store_driver.py b/griptape/drivers/vector/griptape_cloud_knowledge_base_vector_store_driver.py index 34b646846..a3bd6a011 100644 --- a/griptape/drivers/vector/griptape_cloud_knowledge_base_vector_store_driver.py +++ b/griptape/drivers/vector/griptape_cloud_knowledge_base_vector_store_driver.py @@ -105,8 +105,7 @@ def query( response = requests.post(url, json=request, headers=self.headers).json() entries = response.get("entries", []) - entry_list = [BaseVectorStoreDriver.Entry.from_dict(entry) for entry in entries] - return entry_list + return [BaseVectorStoreDriver.Entry.from_dict(entry) for entry in entries] def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") diff --git a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py index 34b1d3a5e..bc3f1e22f 100644 --- a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py +++ b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py @@ -160,7 +160,7 @@ def query( if namespace: pipeline[0]["$vectorSearch"]["filter"] = {"namespace": namespace} - results = [ + return [ BaseVectorStoreDriver.Entry( id=str(doc["_id"]), vector=doc[self.vector_path] if include_vectors else [], @@ -171,8 +171,6 @@ def query( for doc in collection.aggregate(pipeline) ] - return results - def delete_vector(self, vector_id: str) -> None: """Deletes the vector from the collection.""" collection = self.get_collection() diff --git a/griptape/drivers/vector/opensearch_vector_store_driver.py b/griptape/drivers/vector/opensearch_vector_store_driver.py index 267b549b7..cf944116a 100644 --- a/griptape/drivers/vector/opensearch_vector_store_driver.py +++ b/griptape/drivers/vector/opensearch_vector_store_driver.py @@ -83,13 +83,12 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti if response["hits"]["total"]["value"] > 0: vector_data = response["hits"]["hits"][0]["_source"] - entry = BaseVectorStoreDriver.Entry( + return BaseVectorStoreDriver.Entry( id=vector_id, meta=vector_data.get("metadata"), vector=vector_data.get("vector"), namespace=vector_data.get("namespace"), ) - return entry else: return None except Exception as e: @@ -109,7 +108,7 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto response = self.client.search(index=self.index_name, body=query_body) - entries = [ + return [ BaseVectorStoreDriver.Entry( id=hit["_id"], vector=hit["_source"].get("vector"), @@ -118,7 +117,6 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto ) for hit in response["hits"]["hits"] ] - return entries def query( self, diff --git a/griptape/drivers/vector/qdrant_vector_store_driver.py b/griptape/drivers/vector/qdrant_vector_store_driver.py index c33b7eb2e..154e54af7 100644 --- a/griptape/drivers/vector/qdrant_vector_store_driver.py +++ b/griptape/drivers/vector/qdrant_vector_store_driver.py @@ -114,7 +114,7 @@ def query( results = self.client.search(**request) # Convert results to QueryResult objects - query_results = [ + return [ BaseVectorStoreDriver.Entry( id=result.id, vector=result.vector if include_vectors else [], @@ -123,7 +123,6 @@ def query( ) for result in results ] - return query_results def upsert_vector( self, diff --git a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py index 556d5e06e..b54ff072f 100644 --- a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py @@ -60,6 +60,7 @@ def skip_loading_images(route: Any) -> Any: if route.request.resource_type == "image": return route.abort() route.continue_() + return None page.route("**/*", skip_loading_images) diff --git a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py index 0763155d5..06f5573a4 100644 --- a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py @@ -29,13 +29,12 @@ def scrape_url(self, url: str) -> TextArtifact: if page is None: raise Exception("can't access URL") - else: - extracted_page = trafilatura.extract( - page, - include_links=self.include_links, - output_format="json", - config=config, - ) + extracted_page = trafilatura.extract( + page, + include_links=self.include_links, + output_format="json", + config=config, + ) if not extracted_page: raise Exception("can't extract page") diff --git a/griptape/drivers/web_search/google_web_search_driver.py b/griptape/drivers/web_search/google_web_search_driver.py index b5ba01cb6..012c52307 100644 --- a/griptape/drivers/web_search/google_web_search_driver.py +++ b/griptape/drivers/web_search/google_web_search_driver.py @@ -33,9 +33,8 @@ def _search_google(self, query: str, **kwargs) -> list[dict]: if response.status_code == 200: data = response.json() - links = [{"url": r["link"], "title": r["title"], "description": r["snippet"]} for r in data["items"]] + return [{"url": r["link"], "title": r["title"], "description": r["snippet"]} for r in data["items"]] - return links else: raise Exception( f"Google Search API returned an error with status code " diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index 4b1184e5e..8f61bb764 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -31,7 +31,7 @@ class BaseExtractionEngine(ABC): def validate_max_token_multiplier(self, _: Attribute, max_token_multiplier: int) -> None: if max_token_multiplier > 1: raise ValueError("has to be less than or equal to 1") - elif max_token_multiplier <= 0: + if max_token_multiplier <= 0: raise ValueError("has to be greater than 0") @property diff --git a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py index b79668583..4f53cc5f9 100644 --- a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py @@ -40,7 +40,6 @@ def run(self, context: RagContext) -> Sequence[TextArtifact]: if isinstance(loader_output, ErrorArtifact): raise Exception(loader_output.to_text() if loader_output.exception is None else loader_output.exception) - else: - self.vector_store_driver.upsert_text_artifacts({namespace: loader_output}) + self.vector_store_driver.upsert_text_artifacts({namespace: loader_output}) - return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params)) + return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params)) diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 82c33a0ad..065677e1b 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -35,7 +35,7 @@ class PromptSummaryEngine(BaseSummaryEngine): def validate_allowlist(self, _: Attribute, max_token_multiplier: int) -> None: if max_token_multiplier > 1: raise ValueError("has to be less than or equal to 1") - elif max_token_multiplier <= 0: + if max_token_multiplier <= 0: raise ValueError("has to be greater than 0") @property diff --git a/griptape/loaders/audio_loader.py b/griptape/loaders/audio_loader.py index 532662e79..84d6b767a 100644 --- a/griptape/loaders/audio_loader.py +++ b/griptape/loaders/audio_loader.py @@ -14,9 +14,7 @@ class AudioLoader(BaseLoader): """Loads audio content into audio artifacts.""" def load(self, source: bytes, *args, **kwargs) -> AudioArtifact: - audio_artifact = AudioArtifact(source, format=import_optional_dependency("filetype").guess(source).extension) - - return audio_artifact + return AudioArtifact(source, format=import_optional_dependency("filetype").guess(source).extension) def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, AudioArtifact]: return cast(dict[str, AudioArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/image_loader.py b/griptape/loaders/image_loader.py index b7a277edb..83060dfa8 100644 --- a/griptape/loaders/image_loader.py +++ b/griptape/loaders/image_loader.py @@ -42,9 +42,7 @@ def load(self, source: bytes, *args, **kwargs) -> ImageArtifact: image = pil_image.open(byte_stream) source = byte_stream.getvalue() - image_artifact = ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height) - - return image_artifact + return ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height) def _get_mime_type(self, image_format: str | None) -> str: if image_format is None: diff --git a/griptape/mixins/activity_mixin.py b/griptape/mixins/activity_mixin.py index b137fe44e..61e8076b1 100644 --- a/griptape/mixins/activity_mixin.py +++ b/griptape/mixins/activity_mixin.py @@ -77,19 +77,17 @@ def find_activity(self, name: str) -> Optional[Callable]: def activity_name(self, activity: Callable) -> str: if activity is None or not getattr(activity, "is_activity", False): raise Exception("This method is not an activity.") - else: - return getattr(activity, "name") + return getattr(activity, "name") def activity_description(self, activity: Callable) -> str: if activity is None or not getattr(activity, "is_activity", False): raise Exception("This method is not an activity.") - else: - return Template(getattr(activity, "config")["description"]).render({"_self": self}) + return Template(getattr(activity, "config")["description"]).render({"_self": self}) def activity_schema(self, activity: Callable) -> Optional[Schema]: if activity is None or not getattr(activity, "is_activity", False): raise Exception("This method is not an activity.") - elif getattr(activity, "config")["schema"] is not None: + if getattr(activity, "config")["schema"] is not None: # Need to deepcopy to avoid modifying the original schema config_schema = deepcopy(getattr(activity, "config")["schema"]) activity_name = self.activity_name(activity) diff --git a/griptape/tokenizers/amazon_bedrock_tokenizer.py b/griptape/tokenizers/amazon_bedrock_tokenizer.py index 292dcde17..bd758c554 100644 --- a/griptape/tokenizers/amazon_bedrock_tokenizer.py +++ b/griptape/tokenizers/amazon_bedrock_tokenizer.py @@ -37,6 +37,4 @@ class AmazonBedrockTokenizer(BaseTokenizer): characters_per_token: int = field(default=4, kw_only=True) def count_tokens(self, text: str) -> int: - num_tokens = (len(text) + self.characters_per_token - 1) // self.characters_per_token - - return num_tokens + return (len(text) + self.characters_per_token - 1) // self.characters_per_token diff --git a/griptape/tokenizers/simple_tokenizer.py b/griptape/tokenizers/simple_tokenizer.py index 214e5be2d..97053acb2 100644 --- a/griptape/tokenizers/simple_tokenizer.py +++ b/griptape/tokenizers/simple_tokenizer.py @@ -11,6 +11,4 @@ class SimpleTokenizer(BaseTokenizer): characters_per_token: int = field(kw_only=True) def count_tokens(self, text: str) -> int: - num_tokens = (len(text) + self.characters_per_token - 1) // self.characters_per_token - - return num_tokens + return (len(text) + self.characters_per_token - 1) // self.characters_per_token diff --git a/griptape/utils/j2.py b/griptape/utils/j2.py index 70cf936db..3aecd8e3c 100644 --- a/griptape/utils/j2.py +++ b/griptape/utils/j2.py @@ -23,8 +23,7 @@ class J2: def render(self, **kwargs) -> str: if self.template_name is None: raise ValueError("template_name is required.") - else: - return self.environment.get_template(self.template_name).render(kwargs).rstrip() + return self.environment.get_template(self.template_name).render(kwargs).rstrip() def render_from_string(self, value: str, **kwargs) -> str: return self.environment.from_string(value).render(kwargs) diff --git a/griptape/utils/structure_visualizer.py b/griptape/utils/structure_visualizer.py index f24443cd6..260f6efb8 100644 --- a/griptape/utils/structure_visualizer.py +++ b/griptape/utils/structure_visualizer.py @@ -34,8 +34,7 @@ def to_url(self) -> str: graph_bytes = graph.encode("utf-8") base64_string = base64.b64encode(graph_bytes).decode("utf-8") - url = f"https://mermaid.ink/svg/{base64_string}" - return url + return f"https://mermaid.ink/svg/{base64_string}" def __render_task(self, task: BaseTask) -> str: if task.children: diff --git a/pyproject.toml b/pyproject.toml index 54b984261..147d59a9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -271,6 +271,7 @@ select = [ "G", # flake8-logging-format "T20", # flake8-print "PT", # flake8-pytest-style + "RET", # flake8-return "SIM", # flake8-simplify "TID", # flake8-tidy-imports "TCH", # flake8-type-checking @@ -299,6 +300,7 @@ ignore = [ "ANN102", # missing-type-cls "ANN401", # any-type "PT011", # pytest-raises-too-broad + "RET505" # superfluous-else-return ] [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/tests/integration/drivers/vector/test_astra_db_vector_store_driver.py b/tests/integration/drivers/vector/test_astra_db_vector_store_driver.py index 94dbb8570..caa89144c 100644 --- a/tests/integration/drivers/vector/test_astra_db_vector_store_driver.py +++ b/tests/integration/drivers/vector/test_astra_db_vector_store_driver.py @@ -45,14 +45,13 @@ def vector_store_collection(self): @pytest.fixture() def vector_store_driver(self, embedding_driver, vector_store_collection): - driver = AstraDbVectorStoreDriver( + return AstraDbVectorStoreDriver( api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], collection_name=vector_store_collection.name, astra_db_namespace=os.environ.get("ASTRA_DB_KEYSPACE"), embedding_driver=embedding_driver, ) - return driver def test_vector_crud(self, vector_store_driver, vector_store_collection, embedding_driver): """Test basic vector CRUD, various call patterns.""" diff --git a/tests/mocks/mock_failing_prompt_driver.py b/tests/mocks/mock_failing_prompt_driver.py index 18895fdc9..9c760aab6 100644 --- a/tests/mocks/mock_failing_prompt_driver.py +++ b/tests/mocks/mock_failing_prompt_driver.py @@ -25,12 +25,11 @@ def try_run(self, prompt_stack: PromptStack) -> Message: self.current_attempt += 1 raise Exception("failed attempt") - else: - return Message( - content=[TextMessageContent(TextArtifact("success"))], - role=Message.ASSISTANT_ROLE, - usage=Message.Usage(input_tokens=100, output_tokens=100), - ) + return Message( + content=[TextMessageContent(TextArtifact("success"))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), + ) def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: if self.current_attempt < self.max_failures: diff --git a/tests/unit/common/test_observable.py b/tests/unit/common/test_observable.py index f48c3086c..c06be5bb1 100644 --- a/tests/unit/common/test_observable.py +++ b/tests/unit/common/test_observable.py @@ -19,6 +19,7 @@ def bar(*args, **kwargs): """Bar's docstring.""" if args: return args[0] + return None assert bar() is None assert bar("a") == "a" @@ -48,6 +49,7 @@ def test_observable_function_empty_parenthesis(self, observe_spy): def bar(*args, **kwargs): if args: return args[0] + return None assert bar() is None assert bar("a") == "a" @@ -73,6 +75,7 @@ def test_observable_function_args(self, observe_spy): def bar(*args, **kwargs): if args: return args[0] + return None assert bar() is None assert bar("a") == "a" diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index 5323f5d2d..0ece6c976 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -7,8 +7,7 @@ class TestHuggingFacePipelinePromptDriver: @pytest.fixture(autouse=True) def mock_pipeline(self, mocker): - mock_pipeline = mocker.patch("transformers.pipeline") - return mock_pipeline + return mocker.patch("transformers.pipeline") @pytest.fixture(autouse=True) def mock_generator(self, mock_pipeline): diff --git a/tests/unit/drivers/sql/test_snowflake_sql_driver.py b/tests/unit/drivers/sql/test_snowflake_sql_driver.py index a2887cb12..a758bb3a2 100644 --- a/tests/unit/drivers/sql/test_snowflake_sql_driver.py +++ b/tests/unit/drivers/sql/test_snowflake_sql_driver.py @@ -21,8 +21,7 @@ class Column: name: str type: str = "VARCHAR" - mock_table = mocker.MagicMock(name="table", columns=[Column("first_name"), Column("last_name")]) - return mock_table + return mocker.MagicMock(name="table", columns=[Column("first_name"), Column("last_name")]) @pytest.fixture() def mock_metadata(self, mocker): @@ -49,27 +48,22 @@ def mock_snowflake_engine(self, mocker): @pytest.fixture() def mock_snowflake_connection(self, mocker): - mock_connection = mocker.MagicMock(spec=SnowflakeConnection, name="connection") - return mock_connection + return mocker.MagicMock(spec=SnowflakeConnection, name="connection") @pytest.fixture() def mock_snowflake_connection_no_schema(self, mocker): - mock_connection = mocker.MagicMock(spec=SnowflakeConnection, name="connection_no_schema", schema=None) - return mock_connection + return mocker.MagicMock(spec=SnowflakeConnection, name="connection_no_schema", schema=None) @pytest.fixture() def mock_snowflake_connection_no_database(self, mocker): - mock_connection = mocker.MagicMock(spec=SnowflakeConnection, name="connection_no_database", database=None) - return mock_connection + return mocker.MagicMock(spec=SnowflakeConnection, name="connection_no_database", database=None) @pytest.fixture() def driver(self, mock_snowflake_engine, mock_snowflake_connection): def get_connection(): return mock_snowflake_connection - new_driver = SnowflakeSqlDriver(connection_func=get_connection, engine=mock_snowflake_engine) - - return new_driver + return SnowflakeSqlDriver(connection_func=get_connection, engine=mock_snowflake_engine) def test_connection_function_wrong_return_type(self): def get_connection() -> Any: diff --git a/tests/unit/drivers/vector/test_astra_db_vector_store_driver.py b/tests/unit/drivers/vector/test_astra_db_vector_store_driver.py index 16e6530b3..b544a3494 100644 --- a/tests/unit/drivers/vector/test_astra_db_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_astra_db_vector_store_driver.py @@ -9,10 +9,7 @@ class TestAstraDbVectorStoreDriver: @pytest.fixture(autouse=True) def base_mock_collection(self, mocker): - mock_get_collection = mocker.patch( - "astrapy.DataAPIClient" - ).return_value.get_database.return_value.get_collection - return mock_get_collection + return mocker.patch("astrapy.DataAPIClient").return_value.get_database.return_value.get_collection @pytest.fixture() def mock_collection(self, base_mock_collection, one_document): diff --git a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py index 0b22784eb..ffb359953 100644 --- a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py @@ -20,7 +20,7 @@ def mock_engine(self): @pytest.fixture(autouse=True) def driver(self, embedding_driver, mocker): mocker.patch("qdrant_client.QdrantClient") - driver = QdrantVectorStoreDriver( + return QdrantVectorStoreDriver( url="http://some_url", port=8080, grpc_port=50051, @@ -36,7 +36,6 @@ def driver(self, embedding_driver, mocker): content_payload_key="data", embedding_driver=embedding_driver, ) - return driver def test_attrs_post_init(self, driver): with patch("griptape.drivers.vector.qdrant_vector_store_driver.import_optional_dependency") as mock_import: