Skip to content

Commit

Permalink
Fix: path param to dir_path (#1601)
Browse files Browse the repository at this point in the history
## Description
Change the `path` param to `dir_path`

## Type of change

Please check the options that are relevant:

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Model update
- [ ] Infrastructure change

## Checklist

- [ ] My code follows Phidata's style guidelines and best practices
- [ ] I have performed a self-review of my code
- [ ] I have added docstrings and comments for complex logic
- [ ] My changes generate no new warnings or errors
- [ ] I have added cookbook examples for my new addition (if needed)
- [ ] I have updated requirements.txt/pyproject.toml (if needed)
- [ ] I have verified my changes in a clean environment

## Additional Notes

Include any deployment notes, performance implications, or other
relevant information:
  • Loading branch information
manthanguptaa authored Dec 18, 2024
1 parent 9ef259f commit 65ebb7f
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 24 deletions.
2 changes: 1 addition & 1 deletion cookbook/storage/json_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from phi.storage.agent.json import JsonFileAgentStorage

agent = Agent(
storage=JsonFileAgentStorage(path="tmp/agent_sessions_json"),
storage=JsonFileAgentStorage(dir_path="tmp/agent_sessions_json"),
tools=[DuckDuckGo()],
add_history_to_messages=True,
)
Expand Down
2 changes: 1 addition & 1 deletion cookbook/storage/yaml_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from phi.storage.agent.yaml import YamlFileAgentStorage

agent = Agent(
storage=YamlFileAgentStorage(path="tmp/agent_sessions_yaml"),
storage=YamlFileAgentStorage(dir_path="tmp/agent_sessions_yaml"),
tools=[DuckDuckGo()],
add_history_to_messages=True,
)
Expand Down
22 changes: 11 additions & 11 deletions phi/storage/agent/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@


class JsonFileAgentStorage(AgentStorage):
def __init__(self, path: Union[str, Path]):
self.path = Path(path)
self.path.mkdir(parents=True, exist_ok=True)
def __init__(self, dir_path: Union[str, Path]):
self.dir_path = Path(dir_path)
self.dir_path.mkdir(parents=True, exist_ok=True)

def serialize(self, data: dict) -> str:
return json.dumps(data, ensure_ascii=False, indent=4)
Expand All @@ -21,13 +21,13 @@ def deserialize(self, data: str) -> dict:

def create(self) -> None:
"""Create the storage if it doesn't exist."""
if not self.path.exists():
self.path.mkdir(parents=True, exist_ok=True)
if not self.dir_path.exists():
self.dir_path.mkdir(parents=True, exist_ok=True)

def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[AgentSession]:
"""Read an AgentSession from storage."""
try:
with open(self.path / f"{session_id}.json", "r", encoding="utf-8") as f:
with open(self.dir_path / f"{session_id}.json", "r", encoding="utf-8") as f:
data = self.deserialize(f.read())
if user_id and data["user_id"] != user_id:
return None
Expand All @@ -38,7 +38,7 @@ def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[Agent
def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[str]:
"""Get all session IDs, optionally filtered by user_id and/or agent_id."""
session_ids = []
for file in self.path.glob("*.json"):
for file in self.dir_path.glob("*.json"):
with open(file, "r", encoding="utf-8") as f:
data = self.deserialize(f.read())
if (not user_id or data["user_id"] == user_id) and (not agent_id or data["agent_id"] == agent_id):
Expand All @@ -48,7 +48,7 @@ def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[
def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[AgentSession]:
"""Get all sessions, optionally filtered by user_id and/or agent_id."""
sessions = []
for file in self.path.glob("*.json"):
for file in self.dir_path.glob("*.json"):
with open(file, "r", encoding="utf-8") as f:
data = self.deserialize(f.read())
if (not user_id or data["user_id"] == user_id) and (not agent_id or data["agent_id"] == agent_id):
Expand All @@ -63,7 +63,7 @@ def upsert(self, session: AgentSession) -> Optional[AgentSession]:
if "created_at" not in data:
data["created_at"] = data["updated_at"]

with open(self.path / f"{session.session_id}.json", "w", encoding="utf-8") as f:
with open(self.dir_path / f"{session.session_id}.json", "w", encoding="utf-8") as f:
f.write(self.serialize(data))
return session
except Exception as e:
Expand All @@ -75,13 +75,13 @@ def delete_session(self, session_id: Optional[str] = None):
if session_id is None:
return
try:
(self.path / f"{session_id}.json").unlink(missing_ok=True)
(self.dir_path / f"{session_id}.json").unlink(missing_ok=True)
except Exception as e:
logger.error(f"Error deleting session: {e}")

def drop(self) -> None:
"""Drop all sessions from storage."""
for file in self.path.glob("*.json"):
for file in self.dir_path.glob("*.json"):
file.unlink()

def upgrade_schema(self) -> None:
Expand Down
22 changes: 11 additions & 11 deletions phi/storage/agent/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@


class YamlFileAgentStorage(AgentStorage):
def __init__(self, path: Union[str, Path]):
self.path = Path(path)
self.path.mkdir(parents=True, exist_ok=True)
def __init__(self, dir_path: Union[str, Path]):
self.dir_path = Path(dir_path)
self.dir_path.mkdir(parents=True, exist_ok=True)

def serialize(self, data: dict) -> str:
return yaml.dump(data, default_flow_style=False)
Expand All @@ -21,13 +21,13 @@ def deserialize(self, data: str) -> dict:

def create(self) -> None:
"""Create the storage if it doesn't exist."""
if not self.path.exists():
self.path.mkdir(parents=True, exist_ok=True)
if not self.dir_path.exists():
self.dir_path.mkdir(parents=True, exist_ok=True)

def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[AgentSession]:
"""Read an AgentSession from storage."""
try:
with open(self.path / f"{session_id}.yaml", "r", encoding="utf-8") as f:
with open(self.dir_path / f"{session_id}.yaml", "r", encoding="utf-8") as f:
data = self.deserialize(f.read())
if user_id and data["user_id"] != user_id:
return None
Expand All @@ -38,7 +38,7 @@ def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[Agent
def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[str]:
"""Get all session IDs, optionally filtered by user_id and/or agent_id."""
session_ids = []
for file in self.path.glob("*.yaml"):
for file in self.dir_path.glob("*.yaml"):
with open(file, "r", encoding="utf-8") as f:
data = self.deserialize(f.read())
if (not user_id or data["user_id"] == user_id) and (not agent_id or data["agent_id"] == agent_id):
Expand All @@ -48,7 +48,7 @@ def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[
def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[AgentSession]:
"""Get all sessions, optionally filtered by user_id and/or agent_id."""
sessions = []
for file in self.path.glob("*.yaml"):
for file in self.dir_path.glob("*.yaml"):
with open(file, "r", encoding="utf-8") as f:
data = self.deserialize(f.read())
if (not user_id or data["user_id"] == user_id) and (not agent_id or data["agent_id"] == agent_id):
Expand All @@ -63,7 +63,7 @@ def upsert(self, session: AgentSession) -> Optional[AgentSession]:
if "created_at" not in data:
data["created_at"] = data["updated_at"]

with open(self.path / f"{session.session_id}.yaml", "w", encoding="utf-8") as f:
with open(self.dir_path / f"{session.session_id}.yaml", "w", encoding="utf-8") as f:
f.write(self.serialize(data))
return session
except Exception as e:
Expand All @@ -75,13 +75,13 @@ def delete_session(self, session_id: Optional[str] = None):
if session_id is None:
return
try:
(self.path / f"{session_id}.yaml").unlink(missing_ok=True)
(self.dir_path / f"{session_id}.yaml").unlink(missing_ok=True)
except Exception as e:
logger.error(f"Error deleting session: {e}")

def drop(self) -> None:
"""Drop all sessions from storage."""
for file in self.path.glob("*.yaml"):
for file in self.dir_path.glob("*.yaml"):
file.unlink()

def upgrade_schema(self) -> None:
Expand Down

0 comments on commit 65ebb7f

Please sign in to comment.