Skip to content

Commit

Permalink
chore[Security]: restrict libs to allow specific functionalities
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem committed Nov 12, 2024
1 parent 719043c commit 3850595
Show file tree
Hide file tree
Showing 12 changed files with 721 additions and 23 deletions.
31 changes: 22 additions & 9 deletions pandasai/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,34 @@

# List of Python packages that are whitelisted for import in generated code
WHITELISTED_LIBRARIES = [
"sklearn",
"statsmodels",
"seaborn",
"plotly",
"ggplot",
"matplotlib",
"numpy",
"datetime",
"json",
"io",
"base64",
"scipy",
"streamlit",
"modin",
"scikit-learn",
"pandas",
]

# List of restricted libs
RESTRICTED_LIBS = [
"os", # OS-level operations (file handling, environment variables)
"sys", # System-level access
"subprocess", # Run system commands
"shutil", # File operations, including delete
"multiprocessing", # Spawn new processes
"threading", # Thread-level operations
"socket", # Network connections
"http", # HTTP requests
"ftplib", # FTP connections
"paramiko", # SSH operations
"tempfile", # Create temporary files
"pathlib", # Filesystem path handling
"resource", # Access resource usage limits (system-related)
"ssl", # SSL socket connections
"pickle", # Unsafe object serialization
"ctypes", # C-level interaction with memory
"psutil", # System and process utilities
]

PANDASBI_SETUP_MESSAGE = (
Expand Down
36 changes: 29 additions & 7 deletions pandasai/helpers/optional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
import warnings
from typing import TYPE_CHECKING, List

import matplotlib.pyplot as plt
import numpy as np
from pandas.util.version import Version

import pandasai.pandas as pd
from pandasai.constants import WHITELISTED_BUILTINS
from pandasai.safe_libs.restricted_base64 import RestrictedBase64
from pandasai.safe_libs.restricted_datetime import RestrictedDatetime
from pandasai.safe_libs.restricted_json import RestrictedJson
from pandasai.safe_libs.restricted_matplotlib import RestrictedMatplotlib
from pandasai.safe_libs.restricted_numpy import RestrictedNumpy
from pandasai.safe_libs.restricted_pandas import RestrictedPandas
from pandasai.safe_libs.restricted_seaborn import RestrictedSeaborn

if TYPE_CHECKING:
import types
Expand Down Expand Up @@ -54,10 +58,7 @@ def get_environment(additional_deps: List[dict]) -> dict:
Returns (dict): A dictionary of environment variables
"""
return {
"pd": pd,
"plt": plt,
"np": np,
env = {
**{
lib["alias"]: (
getattr(import_dependency(lib["module"]), lib["name"])
Expand All @@ -73,6 +74,27 @@ def get_environment(additional_deps: List[dict]) -> dict:
},
}

env["pd"] = RestrictedPandas()
env["plt"] = RestrictedMatplotlib()
env["np"] = RestrictedNumpy()
if any(lib["name"] == "seaborn" for lib in additional_deps):
env["sns"] = RestrictedSeaborn()

for lib in additional_deps:
if lib["name"] == "seaborn":
env["sns"] = RestrictedSeaborn()

if lib["name"] == "datetime":
env["datetime"] = RestrictedDatetime()

if lib["name"] == "json":
env["json"] = RestrictedJson()

if lib["name"] == "base64":
env["base64"] = RestrictedBase64()

return env


def import_dependency(
name: str,
Expand Down
55 changes: 54 additions & 1 deletion pandasai/pipelines/chat/code_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ...connectors import BaseConnector
from ...connectors.sql import SQLConnector
from ...constants import WHITELISTED_BUILTINS, WHITELISTED_LIBRARIES
from ...constants import RESTRICTED_LIBS, WHITELISTED_BUILTINS, WHITELISTED_LIBRARIES
from ...exceptions import (
BadImportError,
ExecuteSQLQueryNotUsed,
Expand Down Expand Up @@ -161,6 +161,58 @@ def get_code_to_run(self, code: str, context: CodeExecutionContext) -> Any:
return code_to_run

def _is_malicious_code(self, code) -> bool:
tree = ast.parse(code)

# Check for private attributes and access of restricted libs
def check_restricted_access(node):
"""Check if the node accesses restricted modules or private attributes."""
if isinstance(node, ast.Attribute):
attr_chain = []
while isinstance(node, ast.Attribute):
if node.attr.startswith("_"):
raise MaliciousQueryError(
f"Access to private attribute '{node.attr}' is not allowed."
)
attr_chain.insert(0, node.attr)
node = node.value
if isinstance(node, ast.Name):
attr_chain.insert(0, node.id)
if any(module in RESTRICTED_LIBS for module in attr_chain):
raise MaliciousQueryError(
f"Restricted access detected in attribute chain: {'.'.join(attr_chain)}"
)

elif isinstance(node, ast.Subscript) and isinstance(
node.value, ast.Attribute
):
check_restricted_access(node.value)

for node in ast.walk(tree):
# Check 'import ...' statements
if isinstance(node, ast.Import):
for alias in node.names:
sub_module_names = alias.name.split(".")
if any(module in RESTRICTED_LIBS for module in sub_module_names):
raise MaliciousQueryError(
f"Restricted library import detected: {alias.name}"
)

# Check 'from ... import ...' statements
elif isinstance(node, ast.ImportFrom):
sub_module_names = node.module.split(".")
if any(module in RESTRICTED_LIBS for module in sub_module_names):
raise MaliciousQueryError(
f"Restricted library import detected: {node.module}"
)
if any(alias.name in RESTRICTED_LIBS for alias in node.names):
raise MaliciousQueryError(
"Restricted library import detected in 'from ... import ...'"
)

# Check attribute access for restricted libraries (e.g., scipy.sparse._sputils.sys)
elif isinstance(node, (ast.Attribute, ast.Subscript)):
check_restricted_access(node)

dangerous_modules = [
" os",
" io",
Expand All @@ -176,6 +228,7 @@ def _is_malicious_code(self, code) -> bool:
"(chr",
"b64decode",
]

return any(
re.search(r"\b" + re.escape(module) + r"\b", code)
for module in dangerous_modules
Expand Down
27 changes: 27 additions & 0 deletions pandasai/safe_libs/base_restricted_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
class BaseRestrictedModule:
def _wrap_function(self, func):
def wrapper(*args, **kwargs):
# Check for any suspicious arguments that might be used for importing
for arg in args + tuple(kwargs.values()):
if isinstance(arg, str) and any(
module in arg.lower()
for module in ["io", "os", "subprocess", "sys", "importlib"]
):
raise SecurityError(
f"Potential security risk: '{arg}' is not allowed"
)
return func(*args, **kwargs)

return wrapper

def _wrap_class(self, cls):
class WrappedClass(cls):
def __getattribute__(self, name):
attr = super().__getattribute__(name)
return self._wrap_function(self, attr) if callable(attr) else attr

return WrappedClass


class SecurityError(Exception):
pass
21 changes: 21 additions & 0 deletions pandasai/safe_libs/restricted_base64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import base64

from .base_restricted_module import BaseRestrictedModule


class RestrictedBase64(BaseRestrictedModule):
def __init__(self):
self.allowed_functions = [
"b64encode", # Safe function to encode data into base64
"b64decode", # Safe function to decode base64 encoded data
]

# Bind the allowed functions to the object
for func in self.allowed_functions:
if hasattr(base64, func):
setattr(self, func, self._wrap_function(getattr(base64, func)))

def __getattr__(self, name):
if name not in self.allowed_functions:
raise AttributeError(f"'{name}' is not allowed in RestrictedBase64")
return getattr(base64, name)
64 changes: 64 additions & 0 deletions pandasai/safe_libs/restricted_datetime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import datetime

from .base_restricted_module import BaseRestrictedModule


class RestrictedDatetime(BaseRestrictedModule):
def __init__(self):
self.allowed_attributes = [
# Classes
"date",
"time",
"datetime",
"timedelta",
"tzinfo",
"timezone",
# Constants
"MINYEAR",
"MAXYEAR",
# Time zone constants
"UTC",
# Functions
"now",
"utcnow",
"today",
"fromtimestamp",
"utcfromtimestamp",
"fromordinal",
"combine",
"strptime",
# Timedelta operations
"timedelta",
# Date operations
"weekday",
"isoweekday",
"isocalendar",
"isoformat",
"ctime",
"strftime",
"year",
"month",
"day",
"hour",
"minute",
"second",
"microsecond",
# Time operations
"replace",
"tzname",
"dst",
"utcoffset",
# Comparison methods
"min",
"max",
]

for attr in self.allowed_attributes:
if hasattr(datetime, attr):
setattr(self, attr, self._wrap_function(getattr(datetime, attr)))

def __getattr__(self, name):
if name not in self.allowed_attributes:
raise AttributeError(f"'{name}' is not allowed in RestrictedDatetime")

return getattr(datetime, name)
23 changes: 23 additions & 0 deletions pandasai/safe_libs/restricted_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import json

from .base_restricted_module import BaseRestrictedModule


class RestrictedJson(BaseRestrictedModule):
def __init__(self):
self.allowed_functions = [
"load",
"loads",
"dump",
"dumps",
]

# Bind the allowed functions to the object
for func in self.allowed_functions:
if hasattr(json, func):
setattr(self, func, self._wrap_function(getattr(json, func)))

def __getattr__(self, name):
if name not in self.allowed_functions:
raise AttributeError(f"'{name}' is not allowed in RestrictedJson")
return getattr(json, name)
75 changes: 75 additions & 0 deletions pandasai/safe_libs/restricted_matplotlib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import matplotlib.pyplot as plt
import matplotlib.figure as figure
import matplotlib.axes as axes
from .base_restricted_module import BaseRestrictedModule


class RestrictedMatplotlib(BaseRestrictedModule):
def __init__(self):
self.allowed_attributes = [
# Figure and Axes creation
"figure",
"subplots",
"subplot",
# Plotting functions
"plot",
"scatter",
"bar",
"barh",
"hist",
"boxplot",
"violinplot",
"pie",
"errorbar",
"contour",
"contourf",
"imshow",
"pcolor",
"pcolormesh",
# Axis manipulation
"xlabel",
"ylabel",
"title",
"legend",
"xlim",
"ylim",
"axis",
"xticks",
"yticks",
"grid",
"axhline",
"axvline",
# Colorbar
"colorbar",
# Text and annotations
"text",
"annotate",
# Styling
"style",
# Save and show
"show",
"savefig",
# Color maps
"get_cmap",
# 3D plotting
"axes3d",
# Utility functions
"close",
"clf",
"cla",
# Constants
"rcParams",
]

for attr in self.allowed_attributes:
if hasattr(plt, attr):
setattr(self, attr, self._wrap_function(getattr(plt, attr)))

# Special handling for figure and axes
self.Figure = self._wrap_class(figure.Figure)
self.Axes = self._wrap_class(axes.Axes)

def __getattr__(self, name):
if name not in self.allowed_attributes:
raise AttributeError(f"'{name}' is not allowed in RestrictedMatplotlib")
return getattr(plt, name)
Loading

0 comments on commit 3850595

Please sign in to comment.