Skip to content

Commit

Permalink
Update ruff linter rules according to documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jul 5, 2024
1 parent c785aba commit db8a2e4
Show file tree
Hide file tree
Showing 30 changed files with 98 additions and 120 deletions.
5 changes: 1 addition & 4 deletions griptape/artifacts/base_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ def value_to_bytes(cls, value: Any) -> bytes:

@classmethod
def value_to_dict(cls, value: Any) -> dict:
if isinstance(value, dict):
dict_value = value
else:
dict_value = json.loads(value)
dict_value = value if isinstance(value, dict) else json.loads(value)

return {k: v for k, v in dict_value.items()}

Expand Down
5 changes: 1 addition & 4 deletions griptape/chunkers/base_chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ def _chunk_recursively(self, chunk: str, current_separator: Optional[ChunkSepara
# Iterate through the subchunks and calculate token counts.
for index, subchunk in enumerate(subchunks):
if index < len(subchunks):
if separator.is_prefix:
subchunk = separator.value + subchunk
else:
subchunk = subchunk + separator.value
subchunk = separator.value + subchunk if separator.is_prefix else subchunk + separator.value

tokens_count += self.tokenizer.count_tokens(subchunk)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def try_load_file(self, path: str) -> bytes:
return response["Body"].read()
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
raise FileNotFoundError
raise FileNotFoundError from e
else:
raise e

Expand Down
7 changes: 2 additions & 5 deletions griptape/drivers/file_manager/base_file_manager_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,8 @@ def save_file(self, path: str, value: bytes | str) -> InfoArtifact | ErrorArtifa
encoding = None if loader is None else loader.encoding

if isinstance(value, str):
if encoding is None:
value = value.encode()
else:
value = value.encode(encoding=encoding)
elif isinstance(value, bytearray) or isinstance(value, memoryview):
value = value.encode() if encoding is None else value.encode(encoding=encoding)
elif isinstance(value, (bytearray, memoryview)):
raise ValueError(f"Unsupported type: {type(value)}")

self.try_save_file(path, value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,6 @@ def _make_request(self, request: dict) -> bytes:
try:
image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
except Exception as e:
raise ValueError(f"Inpainting generation failed: {e}")
raise ValueError(f"Inpainting generation failed: {e}") from e

Check warning on line 125 in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py#L125

Added line #L125 was not covered by tests

return image_bytes
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
try:
return self.image_query_model_driver.process_output(response_body)
except Exception as e:
raise ValueError(f"Output is unable to be processed as returned {e}")
raise ValueError(f"Output is unable to be processed as returned {e}") from e

Check warning on line 35 in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/image_query/amazon_bedrock_image_query_driver.py#L35

Added line #L35 was not covered by tests
5 changes: 1 addition & 4 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
)

system_messages = prompt_stack.system_messages
if system_messages:
system_message = system_messages[0].to_text()
else:
system_message = None
system_message = system_messages[0].to_text() if system_messages else None

return {
"model": self.model,
Expand Down
5 changes: 1 addition & 4 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ def run(self, prompt_stack: PromptStack) -> Message:
with attempt:
self.before_run(prompt_stack)

if self.stream:
result = self.__process_stream(prompt_stack)
else:
result = self.__process_run(prompt_stack)
result = self.__process_stream(prompt_stack) if self.stream else self.__process_run(prompt_stack)

self.after_run(result)

Expand Down
5 changes: 1 addition & 4 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
)

system_messages = prompt_stack.system_messages
if system_messages:
preamble = system_messages[0].to_text()
else:
preamble = None
preamble = system_messages[0].to_text() if system_messages else None

return {
"message": user_message,
Expand Down
5 changes: 1 addition & 4 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,7 @@ def upsert_text_artifact(
else:
meta["artifact"] = artifact.to_json()

if artifact.embedding:
vector = artifact.embedding
else:
vector = artifact.generate_embedding(self.embedding_driver)
vector = artifact.embedding if artifact.embedding else artifact.generate_embedding(self.embedding_driver)

if isinstance(vector, list):
return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs)
Expand Down
5 changes: 1 addition & 4 deletions griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,7 @@ def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreD
Entries can optionally be filtered by namespace.
"""
collection = self.get_collection()
if namespace is None:
cursor = collection.find()
else:
cursor = collection.find({"namespace": namespace})
cursor = collection.find() if namespace is None else collection.find({"namespace": namespace})

return [
BaseVectorStoreDriver.Entry(
Expand Down
67 changes: 33 additions & 34 deletions griptape/drivers/web_scraper/markdownify_web_scraper_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,51 +48,50 @@ def convert_a(self, el, text, convert_as_inline):
return super().convert_a(el, text, convert_as_inline)
return text

with sync_playwright() as p:
with p.chromium.launch(headless=True) as browser:
page = browser.new_page()
with sync_playwright() as p, p.chromium.launch(headless=True) as browser:
page = browser.new_page()

def skip_loading_images(route):
if route.request.resource_type == "image":
return route.abort()
route.continue_()
def skip_loading_images(route):
if route.request.resource_type == "image":
return route.abort()
route.continue_()

Check warning on line 57 in griptape/drivers/web_scraper/markdownify_web_scraper_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/web_scraper/markdownify_web_scraper_driver.py#L56-L57

Added lines #L56 - L57 were not covered by tests

page.route("**/*", skip_loading_images)
page.route("**/*", skip_loading_images)

page.goto(url)
page.goto(url)

# Some websites require a delay before the content is fully loaded
# even after the browser has emitted "load" event.
if self.timeout:
page.wait_for_timeout(self.timeout)
# Some websites require a delay before the content is fully loaded
# even after the browser has emitted "load" event.
if self.timeout:
page.wait_for_timeout(self.timeout)

Check warning on line 66 in griptape/drivers/web_scraper/markdownify_web_scraper_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/web_scraper/markdownify_web_scraper_driver.py#L66

Added line #L66 was not covered by tests

content = page.content()
content = page.content()

if not content:
raise Exception("can't access URL")
if not content:
raise Exception("can't access URL")

soup = BeautifulSoup(content, "html.parser")
soup = BeautifulSoup(content, "html.parser")

# Remove unwanted elements
exclude_selector = ",".join(
self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids]
)
if exclude_selector:
for s in soup.select(exclude_selector):
s.extract()
# Remove unwanted elements
exclude_selector = ",".join(
self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids]
)
if exclude_selector:
for s in soup.select(exclude_selector):
s.extract()

text = OptionalLinksMarkdownConverter().convert_soup(soup)
text = OptionalLinksMarkdownConverter().convert_soup(soup)

# Remove leading and trailing whitespace from the entire text
text = text.strip()
# Remove leading and trailing whitespace from the entire text
text = text.strip()

# Remove trailing whitespace from each line
text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE)
# Remove trailing whitespace from each line
text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE)

# Indent using 2 spaces instead of tabs
text = re.sub(r"(\n?\s*?)\t", r"\1 ", text)
# Indent using 2 spaces instead of tabs
text = re.sub(r"(\n?\s*?)\t", r"\1 ", text)

# Remove triple+ newlines (keep double newlines for paragraphs)
text = re.sub(r"\n\n+", "\n\n", text)
# Remove triple+ newlines (keep double newlines for paragraphs)
text = re.sub(r"\n\n+", "\n\n", text)

return TextArtifact(text)
return TextArtifact(text)
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def search(self, query: str, **kwargs) -> ListArtifact:
]
)
except Exception as e:
raise Exception(f"Error searching '{query}' with DuckDuckGo: {e}")
raise Exception(f"Error searching '{query}' with DuckDuckGo: {e}") from e
5 changes: 1 addition & 4 deletions griptape/loaders/base_text_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ def load_collection(self, sources: list[Any], *args, **kwargs) -> dict[str, Erro
def _text_to_artifacts(self, text: str) -> list[TextArtifact]:
artifacts = []

if self.chunker:
chunks = self.chunker.chunk(text)
else:
chunks = [TextArtifact(text)]
chunks = self.chunker.chunk(text) if self.chunker else [TextArtifact(text)]

if self.embedding_driver:
for chunk in chunks:
Expand Down
5 changes: 1 addition & 4 deletions griptape/loaders/sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]:
rows = self.sql_driver.execute_query(source)
artifacts = []

if rows:
chunks = [CsvRowArtifact(row.cells) for row in rows]
else:
chunks = []
chunks = [CsvRowArtifact(row.cells) for row in rows] if rows else []

if self.embedding_driver:
for chunk in chunks:
Expand Down
5 changes: 1 addition & 4 deletions griptape/mixins/rule_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@ def all_rulesets(self) -> list[Ruleset]:
if self.rulesets:
task_rulesets = self.rulesets
elif self.rules:
if structure_rulesets:
task_ruleset_name = self.ADDITIONAL_RULESET_NAME
else:
task_ruleset_name = self.DEFAULT_RULESET_NAME
task_ruleset_name = self.ADDITIONAL_RULESET_NAME if structure_rulesets else self.DEFAULT_RULESET_NAME

task_rulesets = [Ruleset(name=task_ruleset_name, rules=self.rules)]

Expand Down
2 changes: 1 addition & 1 deletion griptape/mixins/serializable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_schema(cls: type[T], subclass_name: Optional[str] = None) -> Schema:

@classmethod
def from_dict(cls: type[T], data: dict) -> T:
return cast(T, cls.get_schema(subclass_name=data["type"] if "type" in data else None).load(data))
return cast(T, cls.get_schema(subclass_name=data.get("type")).load(data))

@classmethod
def from_json(cls: type[T], data: str) -> T:
Expand Down
8 changes: 4 additions & 4 deletions griptape/schemas/polymorphic_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def _dump(self, obj, *, update_fields=True, **kwargs):
obj_type = self.get_obj_type(obj)

if not obj_type:
return (None, {"_schema": "Unknown object class: %s" % obj.__class__.__name__})
return (None, {"_schema": f"Unknown object class: {obj.__class__.__name__}"})

Check warning on line 61 in griptape/schemas/polymorphic_schema.py

View check run for this annotation

Codecov / codecov/patch

griptape/schemas/polymorphic_schema.py#L61

Added line #L61 was not covered by tests

type_schema = BaseSchema.from_attrs_cls(obj.__class__)

if not type_schema:
return None, {"_schema": "Unsupported object type: %s" % obj_type}
return None, {"_schema": f"Unsupported object type: {obj_type}"}

Check warning on line 66 in griptape/schemas/polymorphic_schema.py

View check run for this annotation

Codecov / codecov/patch

griptape/schemas/polymorphic_schema.py#L66

Added line #L66 was not covered by tests

schema = type_schema if isinstance(type_schema, Schema) else type_schema()

Expand Down Expand Up @@ -110,7 +110,7 @@ def load(self, data, *, many=None, partial=None, unknown=None, **kwargs):

def _load(self, data, *, partial=None, unknown=None, **kwargs):
if not isinstance(data, dict):
raise ValidationError({"_schema": "Invalid data type: %s" % data})
raise ValidationError({"_schema": f"Invalid data type: {data}"})

Check warning on line 113 in griptape/schemas/polymorphic_schema.py

View check run for this annotation

Codecov / codecov/patch

griptape/schemas/polymorphic_schema.py#L113

Added line #L113 was not covered by tests

data = dict(data)
unknown = unknown or self.unknown
Expand All @@ -121,7 +121,7 @@ def _load(self, data, *, partial=None, unknown=None, **kwargs):

type_schema = self.inner_class.get_schema(data_type)
if not type_schema:
raise ValidationError({self.type_field: ["Unsupported value: %s" % data_type]})
raise ValidationError({self.type_field: [f"Unsupported value: {data_type}"]})

Check warning on line 124 in griptape/schemas/polymorphic_schema.py

View check run for this annotation

Codecov / codecov/patch

griptape/schemas/polymorphic_schema.py#L124

Added line #L124 was not covered by tests

schema = type_schema if isinstance(type_schema, Schema) else type_schema()

Expand Down
12 changes: 3 additions & 9 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,9 @@ def default_config(self) -> BaseStructureConfig:
if self.prompt_driver is not None or self.embedding_driver is not None or self.stream is not None:
config = StructureConfig()

if self.prompt_driver is None:
prompt_driver = OpenAiChatPromptDriver(model="gpt-4o")
else:
prompt_driver = self.prompt_driver

if self.embedding_driver is None:
embedding_driver = OpenAiEmbeddingDriver()
else:
embedding_driver = self.embedding_driver
prompt_driver = OpenAiChatPromptDriver(model="gpt-4o") if self.prompt_driver is None else self.prompt_driver

embedding_driver = OpenAiEmbeddingDriver() if self.embedding_driver is None else self.embedding_driver

if self.stream is not None:
prompt_driver.stream = self.stream
Expand Down
4 changes: 2 additions & 2 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def insert_task(

try:
parent_index = self.tasks.index(parent_task)
except ValueError:
raise ValueError(f"Parent task {parent_task.id} not found in workflow.")
except ValueError as exc:
raise ValueError(f"Parent task {parent_task.id} not found in workflow.") from exc
else:
if parent_index > last_parent_index:
last_parent_index = parent_index
Expand Down
14 changes: 6 additions & 8 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,10 @@ def run(self) -> BaseArtifact:
self.structure.logger.error(f"Subtask {self.id}\n{e}", exc_info=True)

self.output = ErrorArtifact(str(e), exception=e)
finally:
if self.output is not None:
return self.output
else:
return ErrorArtifact("no tool output")
if self.output is not None:
return self.output
else:
return ErrorArtifact("no tool output")

Check warning on line 118 in griptape/tasks/actions_subtask.py

View check run for this annotation

Codecov / codecov/patch

griptape/tasks/actions_subtask.py#L118

Added line #L118 was not covered by tests

def execute_actions(self, actions: list[Action]) -> list[tuple[str, BaseArtifact]]:
with self.futures_executor_fn() as executor:
Expand Down Expand Up @@ -234,9 +233,8 @@ def __parse_actions(self, actions_matches: list[str]) -> None:
tag=action_tag, name=action_name, path=action_path, input=action_input, tool=tool
)

if new_action.tool:
if new_action.input:
self.__validate_action(new_action)
if new_action.tool and new_action.input:
self.__validate_action(new_action)

# Don't forget to add it to the subtask actions list!
self.actions.append(new_action)
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def execute(self) -> Optional[BaseArtifact]:
finally:
self.state = BaseTask.State.FINISHED

return self.output
return self.output

def can_execute(self) -> bool:
return self.state == BaseTask.State.PENDING and all(parent.is_finished() for parent in self.parents)
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _process_task_input(
return self._process_task_input(task_input(self))
elif isinstance(task_input, BaseArtifact):
return task_input
elif isinstance(task_input, list) or isinstance(task_input, tuple):
elif isinstance(task_input, (list, tuple)):
return ListArtifact([self._process_task_input(elem) for elem in task_input])
else:
return self._process_task_input(TextArtifact(task_input))
3 changes: 2 additions & 1 deletion griptape/tokenizers/base_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import logging
from abc import ABC
from abc import ABC, abstractmethod
from attrs import define, field, Factory


Expand Down Expand Up @@ -40,6 +40,7 @@ def count_output_tokens_left(self, text: str) -> int:
else:
return 0

@abstractmethod
def count_tokens(self, text: str) -> int: ...

def _default_max_input_tokens(self) -> int:
Expand Down
Loading

0 comments on commit db8a2e4

Please sign in to comment.