Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Reddit Tool #1666

Merged
merged 16 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions cookbook/examples/agents/10_reddit_post_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from phi.agent import Agent
from phi.tools.duckduckgo import DuckDuckGo
from phi.tools.reddit import RedditTools


web_searcher = Agent(
name="Web Searcher",
role="Searches the web for information on a topic",
description="An intelligent agent that performs comprehensive web searches to gather current and accurate information",
tools=[DuckDuckGo()],
instructions=[
"1. Perform focused web searches using relevant keywords",
"2. Filter results for credibility and recency",
"3. Extract key information and main points",
"4. Organize information in a logical structure",
"5. Verify facts from multiple sources when possible",
"6. Focus on authoritative and reliable sources",
],
)

reddit_agent = Agent(
name="Reddit Agent",
role="Uploads post on Reddit",
description="Specialized agent for crafting and publishing engaging Reddit posts",
tools=[RedditTools()],
instructions=[
"1. Get information regarding the subreddit",
"2. Create attention-grabbing yet accurate titles",
"3. Format posts using proper Reddit markdown",
"4. Avoid including links ",
"5. Follow subreddit-specific rules and guidelines",
"6. Structure content for maximum readability",
"7. Add appropriate tags and flairs if required",
],
show_tool_calls=True,
)

post_team = Agent(
team=[web_searcher, reddit_agent],
instructions=[
"Work together to create engaging and informative Reddit posts",
"Start by researching the topic thoroughly using web searches",
"Craft a well-structured post with accurate information and sources",
"Follow Reddit guidelines and best practices for posting",
],
show_tool_calls=True,
markdown=True,
)

post_team.print_response(
"Create a post on web technologies and frameworks to focus in 2025 on the subreddit r/webdev ",
stream=True,
)
39 changes: 39 additions & 0 deletions cookbook/tools/reddit_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Steps to get Reddit credentials:

1. Create/Login to Reddit account
- Go to https://www.reddit.com

2. Create a Reddit App
- Go to https://www.reddit.com/prefs/apps
- Click "Create App" or "Create Another App" button
- Fill in required details:
* Name: Your app name
* App type: Select "script"
* Description: Brief description
* About url: Your website (can be http://localhost)
* Redirect uri: http://localhost:8080
- Click "Create app" button

3. Get credentials
- client_id: Found under your app name (looks like a random string)
- client_secret: Listed as "secret"
- user_agent: Format as: "platform:app_id:version (by /u/username)"
- username: Your Reddit username
- password: Your Reddit account password

"""
from phi.agent import Agent
from phi.tools.reddit import RedditTools

agent = Agent(
instructions=[
"Use your tools to answer questions about Reddit content and statistics",
"Respect Reddit's content policies and NSFW restrictions",
"When analyzing subreddits, provide relevant statistics and trends",
],
tools=[RedditTools()],
show_tool_calls=True,
)

agent.print_response("What are the top 5 posts on r/SAAS this week ?", stream=True)
288 changes: 288 additions & 0 deletions phi/tools/reddit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
import json
from os import getenv
from typing import Optional, Dict, List, Union
from phi.tools import Toolkit
from phi.utils.log import logger

try:
import praw # type: ignore
except ImportError:
raise ImportError("praw` not installed. Please install using `pip install praw`")


class RedditTools(Toolkit):
def __init__(
self,
reddit_instance: Optional[praw.Reddit] = None,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
user_agent: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
get_user_info: bool = True,
get_top_posts: bool = True,
get_subreddit_info: bool = True,
get_trending_subreddits: bool = True,
get_subreddit_stats: bool = True,
create_post: bool = True,
srexrg marked this conversation as resolved.
Show resolved Hide resolved
):
super().__init__(name="reddit")

if reddit_instance is not None:
logger.info("Using provided Reddit instance")
self.reddit = reddit_instance
else:
# Get credentials from environment variables if not provided
self.client_id = client_id or getenv("REDDIT_CLIENT_ID")
self.client_secret = client_secret or getenv("REDDIT_CLIENT_SECRET")
self.user_agent = user_agent or getenv("REDDIT_USER_AGENT", "RedditTools v1.0")
self.username = username or getenv("REDDIT_USERNAME")
self.password = password or getenv("REDDIT_PASSWORD")

self.reddit = None
# Check if we have all required credentials
if all([self.client_id, self.client_secret]):
# Initialize with read-only access if no user credentials
if not all([self.username, self.password]):
logger.info("Initializing Reddit client with read-only access")
self.reddit = praw.Reddit(
client_id=self.client_id,
client_secret=self.client_secret,
user_agent=self.user_agent,
)
# Initialize with user authentication if credentials provided
else:
logger.info(f"Initializing Reddit client with user authentication for u/{self.username}")
self.reddit = praw.Reddit(
client_id=self.client_id,
client_secret=self.client_secret,
user_agent=self.user_agent,
username=self.username,
password=self.password,
)
else:
logger.warning("Missing Reddit API credentials")

if get_user_info:
self.register(self.get_user_info)
if get_top_posts:
self.register(self.get_top_posts)
if get_subreddit_info:
self.register(self.get_subreddit_info)
if get_trending_subreddits:
self.register(self.get_trending_subreddits)
if get_subreddit_stats:
self.register(self.get_subreddit_stats)
if create_post:
self.register(self.create_post)

def _check_user_auth(self) -> bool:
"""
Check if user authentication is available for actions that require it.
Returns:
bool: True if user is authenticated, False otherwise
"""
if not self.reddit:
logger.error("Reddit client not initialized")
return False

if not all([self.username, self.password]):
logger.error("User authentication required. Please provide username and password.")
return False

try:
# Verify authentication by checking if we can get the authenticated user
self.reddit.user.me()
return True
except Exception as e:
logger.error(f"Authentication error: {e}")
return False

def get_user_info(self, username: str) -> str:
"""Get information about a Reddit user."""
if not self.reddit:
return "Please provide Reddit API credentials"

try:
logger.info(f"Getting info for u/{username}")

user = self.reddit.redditor(username)
info: Dict[str, Union[str, int, bool, float]] = {
"name": user.name,
"comment_karma": user.comment_karma,
"link_karma": user.link_karma,
"is_mod": user.is_mod,
"is_gold": user.is_gold,
"is_employee": user.is_employee,
"created_utc": user.created_utc,
}

return json.dumps(info)

except Exception as e:
return f"Error getting user info: {e}"

def get_top_posts(self, subreddit: str, time_filter: str = "week", limit: int = 10) -> str:
"""
Get top posts from a subreddit for a specific time period.
Args:
subreddit (str): Name of the subreddit.
time_filter (str): Time period to filter posts.
limit (int): Number of posts to fetch.
Returns:
str: JSON string containing top posts.
"""
if not self.reddit:
return "Please provide Reddit API credentials"

try:
posts = self.reddit.subreddit(subreddit).top(time_filter=time_filter, limit=limit)
top_posts: List[Dict[str, Union[str, int, float]]] = [
{
"title": post.title,
"score": post.score,
"url": post.url,
"author": str(post.author),
"created_utc": post.created_utc,
}
for post in posts
]
return json.dumps({"top_posts": top_posts})
except Exception as e:
return f"Error getting top posts: {e}"

def get_subreddit_info(self, subreddit_name: str) -> str:
"""
Get information about a specific subreddit.
Args:
subreddit_name (str): Name of the subreddit.
Returns:
str: JSON string containing subreddit information.
"""
if not self.reddit:
return "Please provide Reddit API credentials"

try:
logger.info(f"Getting info for r/{subreddit_name}")

subreddit = self.reddit.subreddit(subreddit_name)
flairs = [flair["text"] for flair in subreddit.flair.link_templates]
info: Dict[str, Union[str, int, bool, float, List[str]]] = {
"display_name": subreddit.display_name,
"title": subreddit.title,
"description": subreddit.description,
"subscribers": subreddit.subscribers,
"created_utc": subreddit.created_utc,
"over18": subreddit.over18,
"available_flairs": flairs,
"public_description": subreddit.public_description,
"url": subreddit.url,
}

return json.dumps(info)

except Exception as e:
return f"Error getting subreddit info: {e}"

def get_trending_subreddits(self) -> str:
"""Get currently trending subreddits."""
if not self.reddit:
return "Please provide Reddit API credentials"

try:
popular_subreddits = self.reddit.subreddits.popular(limit=5)
trending: List[str] = [subreddit.display_name for subreddit in popular_subreddits]
return json.dumps({"trending_subreddits": trending})
except Exception as e:
return f"Error getting trending subreddits: {e}"

def get_subreddit_stats(self, subreddit: str) -> str:
"""
Get statistics about a subreddit.
Args:
subreddit (str): Name of the subreddit.
Returns:
str: JSON string containing subreddit statistics
"""
if not self.reddit:
return "Please provide Reddit API credentials"

try:
sub = self.reddit.subreddit(subreddit)
stats: Dict[str, Union[str, int, bool, float]] = {
"display_name": sub.display_name,
"subscribers": sub.subscribers,
"active_users": sub.active_user_count,
"description": sub.description,
"created_utc": sub.created_utc,
"over18": sub.over18,
"public_description": sub.public_description,
}
return json.dumps({"subreddit_stats": stats})
except Exception as e:
return f"Error getting subreddit stats: {e}"

def create_post(
self,
subreddit: str,
title: str,
content: str,
flair: Optional[str] = None,
is_self: bool = True,
) -> str:
"""
Create a new post in a subreddit.

Args:
subreddit (str): Name of the subreddit to post in.
title (str): Title of the post.
content (str): Content of the post (text for self posts, URL for link posts).
flair (Optional[str]): Flair to add to the post. Must be an available flair in the subreddit.
is_self (bool): Whether this is a self (text) post (True) or link post (False).
Returns:
str: JSON string containing the created post information.
"""
if not self.reddit:
return "Please provide Reddit API credentials"

if not self._check_user_auth():
return "User authentication required for posting. Please provide username and password."

try:
logger.info(f"Creating post in r/{subreddit}")

subreddit_obj = self.reddit.subreddit(subreddit)

if flair:
available_flairs = [f["text"] for f in subreddit_obj.flair.link_templates]
if flair not in available_flairs:
return f"Invalid flair. Available flairs: {', '.join(available_flairs)}"

if is_self:
submission = subreddit_obj.submit(
title=title,
selftext=content,
flair_id=flair,
)
else:
submission = subreddit_obj.submit(
title=title,
url=content,
flair_id=flair,
)
logger.info(f"Post created: {submission.permalink}")

post_info: Dict[str, Union[str, int, float]] = {
"id": submission.id,
"title": submission.title,
"url": submission.url,
"permalink": submission.permalink,
"created_utc": submission.created_utc,
"author": str(submission.author),
"flair": submission.link_flair_text,
}

return json.dumps({"post": post_info})

except Exception as e:
return f"Error creating post: {e}"