Skip to content

Commit

Permalink
Merge pull request #652 from tableau/remote_server
Browse files Browse the repository at this point in the history
Add capability to deploy models remotely
  • Loading branch information
jakeichikawasalesforce authored Nov 25, 2024
2 parents 59f4056 + 3bb4d37 commit fe613a1
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,4 @@ tabpy/tabpy_server/staging
# etc
setup.bat
*~
tabpy_log.log.1
tabpy_log.log.*
7 changes: 7 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## v2.13.0

### Improvements

- Add support for deploying functions to a remote TabPy server by setting
`remote_server=True` when creating the Client instance.

## v2.12.0

### Improvements
Expand Down
10 changes: 10 additions & 0 deletions docs/tabpy-tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ The URL and port are where the Tableau-Python-Server process has been started -
more info can be found in the
[Starting TabPy](server-install.md#starting-tabpy) section of the documentation.

When connecting to a remote TabPy server, configure the following parameters:

- Set `remote_server` to `True` to indicate a remote connection
- Set `localhost_endpoint` to the specific localhost address used by the remote server
- **Note:** The protocol and port may differ from the main endpoint

```python
client = Client('https://example.com:443/', remote_server=True, localhost_endpoint='http://localhost:9004/')
```

## Authentication

When TabPy is configured with the authentication feature on, client code
Expand Down
2 changes: 1 addition & 1 deletion tabpy/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.12.1
2.13.0
6 changes: 4 additions & 2 deletions tabpy/tabpy_server/handlers/endpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,13 @@ def delete(self, name):

# delete files
if endpoint_info["type"] != "alias":
delete_path = get_query_object_path(
query_path = get_query_object_path(
self.settings["state_file_path"], name, None
)
staging_path = query_path.replace("/query_objects/", "/staging/endpoints/")
try:
yield self._delete_po_future(delete_path)
yield self._delete_po_future(query_path)
yield self._delete_po_future(staging_path)
except Exception as e:
self.error_out(400, f"Error while deleting: {e}")
self.finish()
Expand Down
4 changes: 2 additions & 2 deletions tabpy/tabpy_server/management/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def add_endpoint(
Name of the endpoint
description : str, optional
Description of this endpoint
doc_string : str, optional
docstring : str, optional
The doc string for this endpoint, if needed.
endpoint_type : str
The endpoint type (model, alias)
Expand Down Expand Up @@ -309,7 +309,7 @@ def update_endpoint(
Name of the endpoint
description : str, optional
Description of this endpoint
doc_string : str, optional
docstring : str, optional
The doc string for this endpoint, if needed.
endpoint_type : str, optional
The endpoint type (model, alias)
Expand Down
88 changes: 87 additions & 1 deletion tabpy/tabpy_tools/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import inspect
from re import compile
import time
import requests
Expand Down Expand Up @@ -49,7 +50,9 @@ def _check_endpoint_name(name):


class Client:
def __init__(self, endpoint, query_timeout=1000):
def __init__(
self, endpoint, query_timeout=1000, remote_server=False, localhost_endpoint=None
):
"""
Connects to a running server.
Expand All @@ -63,10 +66,19 @@ def __init__(self, endpoint, query_timeout=1000):
query_timeout : float, optional
The timeout for query operations.
remote_server : bool, optional
Whether the client connects to a remote TabPy server.
localhost_endpoint : str, optional
The localhost endpoint with potentially different protocol and
port compared to the main endpoint parameter.
"""
_check_hostname(endpoint)

self._endpoint = endpoint
self._remote_server = remote_server
self._localhost_endpoint = localhost_endpoint

session = requests.session()
session.verify = False
Expand Down Expand Up @@ -232,6 +244,12 @@ def deploy(self, name, obj, description="", schema=None, override=False, is_publ
--------
remove, get_endpoints
"""
if self._remote_server:
return self._remote_deploy(
name, obj,
description=description, schema=schema, override=override, is_public=is_public
)

endpoint = self.get_endpoints().get(name)
version = 1
if endpoint:
Expand Down Expand Up @@ -390,6 +408,7 @@ def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_publi
"methods": endpoint_object.get_methods(),
"required_files": [],
"required_packages": [],
"docstring": endpoint_object.get_docstring(),
"schema": copy.copy(schema),
"is_public": is_public,
}
Expand Down Expand Up @@ -419,6 +438,7 @@ def _wait_for_endpoint_deployment(
logger.info(
f"Waiting for endpoint {endpoint_name} to deploy to " f"version {version}"
)
time.sleep(interval)
start = time.time()
while True:
ep_status = self.get_status()
Expand Down Expand Up @@ -447,6 +467,72 @@ def _wait_for_endpoint_deployment(
logger.info(f"Sleeping {interval}...")
time.sleep(interval)

def _remote_deploy(
    self, name, obj, description="", schema=None, override=False, is_public=False
):
    """
    Remotely deploy a Python function using the /evaluate endpoint.

    Builds a script made of the client bootstrap from
    ``_gen_remote_script``, the source code of ``obj`` and a final
    ``client.deploy(...)`` call, then hands it to
    ``_evaluate_remote_script``. Takes the same inputs as deploy.

    Parameters
    ----------
    name : str
        Endpoint name passed to the generated ``client.deploy`` call.
    obj : function
        Function to deploy. Its source is read with ``inspect.getsource``
        and it is referenced in the script by ``obj.__name__``.
    description : str, optional
        Endpoint description.
    schema : optional
        Endpoint schema; embedded in the script via its ``repr``.
    override : bool, optional
        Forwarded to the generated ``client.deploy`` call.
    is_public : bool, optional
        Forwarded to the generated ``client.deploy`` call.

    Returns
    -------
    Whatever ``_evaluate_remote_script`` returns for the generated script.
    """
    import textwrap

    # getsource() keeps the indentation of nested functions and methods;
    # dedent so the definition is valid at the top level of the script.
    source = textwrap.dedent(inspect.getsource(obj))

    # NOTE(review): obj.__name__ is "<lambda>" for lambdas, which is not a
    # usable identifier in the generated call - confirm whether lambdas
    # are meant to be supported.
    # repr() yields correctly quoted/escaped literals, so quotes or
    # backslashes in the arguments cannot break (or inject into) the script.
    deploy_call = (
        f"client.deploy("
        f"{name!r}, {obj.__name__}, {description!r}, "
        f"override={override!r}, is_public={is_public!r}, schema={schema!r}"
        f")"
    )

    remote_script = f"{self._gen_remote_script()}{source}\n{deploy_call}"
    return self._evaluate_remote_script(remote_script)

def _gen_remote_script(self):
    """
    Generates a remote script for TabPy client connection with credential handling.

    The script imports ``Client``, instantiates it against
    ``self._localhost_endpoint`` (falling back to ``self._endpoint`` when
    that is unset or empty) and, when the network wrapper carries an auth
    object, replays its ``username`` / ``password`` via
    ``client.set_credentials``.

    Returns:
        str: A Python script to establish a TabPy client connection
    """
    endpoint = self._localhost_endpoint or self._endpoint

    # repr() emits properly quoted/escaped string literals, so quotes or
    # backslashes in the endpoint or credentials cannot break the script.
    lines = [
        "from tabpy.tabpy_tools.client import Client",
        f"client = Client({endpoint!r})",
    ]

    auth = self._service.service_client.network_wrapper.auth
    if auth:
        lines.append(
            f"client.set_credentials({auth.username!r}, {auth.password!r})"
        )

    return "\n".join(lines) + "\n"

def _evaluate_remote_script(self, remote_script):
    """
    Uses TabPy /evaluate endpoint to execute a remote TabPy client script.

    Parameters
    ----------
    remote_script : str
        The script to execute remotely.

    Returns
    -------
    str
        The HTTP status code and the response message joined by " - " and
        terminated by a newline. A response body that is exactly ``null``
        is reported as ``Success``.
    """
    # NOTE(review): the script can contain credentials added by
    # _gen_remote_script; printing it echoes them to stdout - confirm
    # this is acceptable.
    print(f"Remote script:\n{remote_script}\n")

    # The endpoint may be configured with or without a trailing slash;
    # normalise so the path is always ".../evaluate".
    url = f"{self._endpoint.rstrip('/')}/evaluate"
    headers = {"Content-Type": "application/json"}
    payload = {"data": {}, "script": remote_script}

    response = requests.post(
        url,
        headers=headers,
        auth=self._service.service_client.network_wrapper.auth,
        json=payload
    )

    msg = response.text
    # Only a body that is exactly "null" is rewritten to "Success";
    # a "null" appearing inside an error message must be left untouched.
    if msg.strip() == "null":
        msg = msg.replace("null", "Success")
    if "Ad-hoc scripts have been disabled" in msg:
        msg += "\n[Deployment to remote tabpy client not allowed.]"

    status_message = f"{response.status_code} - {msg}\n"
    print(status_message)
    return status_message

def set_credentials(self, username, password):
"""
Set credentials for all the TabPy client-server communication
Expand Down
16 changes: 11 additions & 5 deletions tabpy/tabpy_tools/custom_query_object.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import platform
import sys
from .query_object import QueryObject as _QueryObject


Expand Down Expand Up @@ -69,12 +71,16 @@ def query(self, *args, **kwargs):
)
raise

def get_doc_string(self):
def get_docstring(self):
"""Get doc string from customized query"""
if self.custom_query.__doc__ is not None:
return self.custom_query.__doc__
else:
return "-- no docstring found in query function --"
default_docstring = "-- no docstring found in query function --"

# TODO: fix docstring parsing on Windows systems
if sys.platform == 'win32':
return default_docstring

ds = getattr(self.custom_query, '__doc__', None)
return ds if ds and isinstance(ds, str) else default_docstring

def get_methods(self):
return [self.get_query_method()]
Expand Down
4 changes: 2 additions & 2 deletions tabpy/tabpy_tools/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class QueryObject(abc.ABC):
"""
Derived class needs to implement the following interface:
* query() -- given input, return query result
* get_doc_string() -- returns documentation for the Query Object
* get_docstring() -- returns documentation for the Query Object
"""

def __init__(self, description=""):
Expand All @@ -30,7 +30,7 @@ def query(self, input):
pass

@abc.abstractmethod
def get_doc_string(self):
def get_docstring(self):
"""Returns documentation for the query object
By default, this method returns the docstring for 'query' method
Expand Down
2 changes: 2 additions & 0 deletions tabpy/tabpy_tools/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Endpoint(RESTObject):
version = RESTProperty(int)
description = RESTProperty(str)
dependencies = RESTProperty(list)
docstring = RESTProperty(str)
methods = RESTProperty(list)
creation_time = RESTProperty(datetime, from_epoch, to_epoch)
last_modified_time = RESTProperty(datetime, from_epoch, to_epoch)
Expand All @@ -64,6 +65,7 @@ def __eq__(self, other):
and self.version == other.version
and self.description == other.description
and self.dependencies == other.dependencies
and self.docstring == other.docstring
and self.methods == other.methods
and self.evaluator == other.evaluator
and self.schema_version == other.schema_version
Expand Down
31 changes: 28 additions & 3 deletions tests/unit/tools_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,25 @@ def setUp(self):

def test_init(self):
client = Client("http://example.com:9004")

self.assertEqual(client._endpoint, "http://example.com:9004")
self.assertEqual(client._remote_server, False)

client = Client("http://example.com/", 10.0)

self.assertEqual(client._endpoint, "http://example.com/")

client = Client(endpoint="https://example.com/", query_timeout=-10.0)

self.assertEqual(client._endpoint, "https://example.com/")
self.assertEqual(client.query_timeout, 0.0)

client = Client(
"http://example.com:442/",
remote_server=True,
localhost_endpoint="http://localhost:9004/"
)
self.assertEqual(client._endpoint, "http://example.com:442/")
self.assertEqual(client._remote_server, True)
self.assertEqual(client._localhost_endpoint, "http://localhost:9004/")

# valid name tests
with self.assertRaises(ValueError):
Client("")
Expand Down Expand Up @@ -90,3 +97,21 @@ def test_check_invalid_endpoint_name(self):
f"endpoint name {endpoint_name } can only contain: "
"a-z, A-Z, 0-9, underscore, hyphens and spaces.",
)

def test_deploy_with_remote_server(self):
    """deploy() on a remote-server client must go through _evaluate_remote_script."""
    remote_client = Client("http://example.com:9004/", remote_server=True)
    # Replace the HTTP call with a mock so no request is sent.
    remote_client._evaluate_remote_script = Mock()

    remote_client.deploy('name', lambda: True, 'description')

    remote_client._evaluate_remote_script.assert_called()

def test_gen_remote_script(self):
    """The generated script bootstraps a Client and replays credentials once set."""
    remote_client = Client("http://example.com:9004/", remote_server=True)

    # Without credentials: import + Client construction only.
    generated = remote_client._gen_remote_script()
    self.assertIn("from tabpy.tabpy_tools.client import Client", generated)
    self.assertIn("client = Client('http://example.com:9004/')", generated)
    self.assertNotIn("client.set_credentials", generated)

    # With credentials: the script must also set them on the remote client.
    remote_client.set_credentials("username", "password")
    generated = remote_client._gen_remote_script()
    self.assertIn("client.set_credentials('username', 'password')", generated)

0 comments on commit fe613a1

Please sign in to comment.