diff --git a/tableauserverclient/__init__.py b/tableauserverclient/__init__.py index f8549992..bab2cf05 100644 --- a/tableauserverclient/__init__.py +++ b/tableauserverclient/__init__.py @@ -43,6 +43,7 @@ TaskItem, UserItem, ViewItem, + VirtualConnectionItem, WebhookItem, WeeklyInterval, WorkbookItem, @@ -124,4 +125,5 @@ "LinkedTaskItem", "LinkedTaskStepItem", "LinkedTaskFlowRunItem", + "VirtualConnectionItem", ] diff --git a/tableauserverclient/models/__init__.py b/tableauserverclient/models/__init__.py index 41676da2..e4131b72 100644 --- a/tableauserverclient/models/__init__.py +++ b/tableauserverclient/models/__init__.py @@ -45,6 +45,7 @@ from tableauserverclient.models.task_item import TaskItem from tableauserverclient.models.user_item import UserItem from tableauserverclient.models.view_item import ViewItem +from tableauserverclient.models.virtual_connection_item import VirtualConnectionItem from tableauserverclient.models.webhook_item import WebhookItem from tableauserverclient.models.workbook_item import WorkbookItem @@ -96,6 +97,7 @@ "TaskItem", "UserItem", "ViewItem", + "VirtualConnectionItem", "WebhookItem", "WorkbookItem", "LinkedTaskItem", diff --git a/tableauserverclient/models/connection_item.py b/tableauserverclient/models/connection_item.py index 29ffd270..62ff530c 100644 --- a/tableauserverclient/models/connection_item.py +++ b/tableauserverclient/models/connection_item.py @@ -66,12 +66,14 @@ def from_response(cls, resp, ns) -> List["ConnectionItem"]: for connection_xml in all_connection_xml: connection_item = cls() connection_item._id = connection_xml.get("id", None) - connection_item._connection_type = connection_xml.get("type", None) + connection_item._connection_type = connection_xml.get("type", connection_xml.get("dbClass", None)) connection_item.embed_password = string_to_bool(connection_xml.get("embedPassword", "")) connection_item.server_address = connection_xml.get("serverAddress", None) connection_item.server_port = connection_xml.get("serverPort", None) connection_item.username = connection_xml.get("userName", None) - connection_item._query_tagging = string_to_bool(connection_xml.get("queryTaggingEnabled", None)) + connection_item._query_tagging = ( + string_to_bool(s) if (s := connection_xml.get("queryTagging", None)) else None + ) datasource_elem = connection_xml.find(".//t:datasource", namespaces=ns) if datasource_elem is not None: connection_item._datasource_id = datasource_elem.get("id", None) diff --git a/tableauserverclient/models/tableau_types.py b/tableauserverclient/models/tableau_types.py index 33fe5eb0..bac07207 100644 --- a/tableauserverclient/models/tableau_types.py +++ b/tableauserverclient/models/tableau_types.py @@ -1,11 +1,12 @@ from typing import Union -from .datasource_item import DatasourceItem -from .flow_item import FlowItem -from .project_item import ProjectItem -from .view_item import ViewItem -from .workbook_item import WorkbookItem -from .metric_item import MetricItem +from tableauserverclient.models.datasource_item import DatasourceItem +from tableauserverclient.models.flow_item import FlowItem +from tableauserverclient.models.project_item import ProjectItem +from tableauserverclient.models.view_item import ViewItem +from tableauserverclient.models.workbook_item import WorkbookItem +from tableauserverclient.models.metric_item import MetricItem +from tableauserverclient.models.virtual_connection_item import VirtualConnectionItem class Resource: @@ -18,12 +19,13 @@ class Resource: Metric = "metric" Project = "project" View = "view" + VirtualConnection = "virtualConnection" Workbook = "workbook" # resource types that have permissions, can be renamed, etc # todo: refactoring: should actually define TableauItem as an interface and let all these implement it -TableauItem = Union[DatasourceItem, FlowItem, MetricItem, ProjectItem, ViewItem, WorkbookItem] +TableauItem = Union[DatasourceItem, FlowItem, MetricItem, ProjectItem, ViewItem, WorkbookItem, VirtualConnectionItem] def plural_type(content_type: Resource) -> str: diff --git a/tableauserverclient/models/virtual_connection_item.py b/tableauserverclient/models/virtual_connection_item.py new file mode 100644 index 00000000..76a3b5de --- /dev/null +++ b/tableauserverclient/models/virtual_connection_item.py @@ -0,0 +1,77 @@ +import datetime as dt +import json +from typing import Callable, Dict, Iterable, List, Optional +from xml.etree.ElementTree import Element + +from defusedxml.ElementTree import fromstring + +from tableauserverclient.datetime_helpers import parse_datetime +from tableauserverclient.models.connection_item import ConnectionItem +from tableauserverclient.models.exceptions import UnpopulatedPropertyError +from tableauserverclient.models.permissions_item import PermissionsRule + + +class VirtualConnectionItem: + def __init__(self, name: str) -> None: + self.name = name + self.created_at: Optional[dt.datetime] = None + self.has_extracts: Optional[bool] = None + self._id: Optional[str] = None + self.is_certified: Optional[bool] = None + self.updated_at: Optional[dt.datetime] = None + self.webpage_url: Optional[str] = None + self._connections: Optional[Callable[[], Iterable[ConnectionItem]]] = None + self.project_id: Optional[str] = None + self.owner_id: Optional[str] = None + self.content: Optional[Dict[str, dict]] = None + self.certification_note: Optional[str] = None + + def __str__(self) -> str: + return f"{self.__class__.__qualname__}(name={self.name})" + + def __repr__(self) -> str: + return f"<{self!s}>" + + def _set_permissions(self, permissions): + self._permissions = permissions + + @property + def id(self) -> Optional[str]: + return self._id + + @property + def permissions(self) -> List[PermissionsRule]: + if self._permissions is None: + error = "Workbook item must be populated with permissions first." + raise UnpopulatedPropertyError(error) + return self._permissions() + + @property + def connections(self) -> Iterable[ConnectionItem]: + if self._connections is None: + raise AttributeError("connections not populated. Call populate_connections() first.") + return self._connections() + + @classmethod + def from_response(cls, response: bytes, ns: Dict[str, str]) -> List["VirtualConnectionItem"]: + parsed_response = fromstring(response) + return [cls.from_xml(xml, ns) for xml in parsed_response.findall(".//t:virtualConnection[@name]", ns)] + + @classmethod + def from_xml(cls, xml: Element, ns: Dict[str, str]) -> "VirtualConnectionItem": + v_conn = cls(xml.get("name", "")) + v_conn._id = xml.get("id", None) + v_conn.webpage_url = xml.get("webpageUrl", None) + v_conn.created_at = parse_datetime(xml.get("createdAt", None)) + v_conn.updated_at = parse_datetime(xml.get("updatedAt", None)) + v_conn.is_certified = string_to_bool(s) if (s := xml.get("isCertified", None)) else None + v_conn.certification_note = xml.get("certificationNote", None) + v_conn.has_extracts = string_to_bool(s) if (s := xml.get("hasExtracts", None)) else None + v_conn.project_id = p.get("id", None) if ((p := xml.find(".//t:project[@id]", ns)) is not None) else None + v_conn.owner_id = o.get("id", None) if ((o := xml.find(".//t:owner[@id]", ns)) is not None) else None + v_conn.content = json.loads(c.text or "{}") if ((c := xml.find(".//t:content", ns)) is not None) else None + return v_conn + + +def string_to_bool(s: str) -> bool: + return s.lower() in ["true", "1", "t", "y", "yes"] diff --git a/tableauserverclient/server/endpoint/__init__.py b/tableauserverclient/server/endpoint/__init__.py index 30eae922..b05b9add 100644 --- a/tableauserverclient/server/endpoint/__init__.py +++ b/tableauserverclient/server/endpoint/__init__.py @@ -27,6 +27,7 @@ from tableauserverclient.server.endpoint.tasks_endpoint import Tasks from tableauserverclient.server.endpoint.users_endpoint import Users from tableauserverclient.server.endpoint.views_endpoint import Views +from tableauserverclient.server.endpoint.virtual_connections_endpoint import VirtualConnections from tableauserverclient.server.endpoint.webhooks_endpoint import Webhooks from tableauserverclient.server.endpoint.workbooks_endpoint import Workbooks @@ -62,6 +63,7 @@ "Tasks", "Users", "Views", + "VirtualConnections", "Webhooks", "Workbooks", ] diff --git a/tableauserverclient/server/endpoint/virtual_connections_endpoint.py b/tableauserverclient/server/endpoint/virtual_connections_endpoint.py new file mode 100644 index 00000000..f71db00c --- /dev/null +++ b/tableauserverclient/server/endpoint/virtual_connections_endpoint.py @@ -0,0 +1,173 @@ +from functools import partial +import json +from pathlib import Path +from typing import Iterable, List, Optional, Set, TYPE_CHECKING, Tuple, Union + +from tableauserverclient.models.connection_item import ConnectionItem +from tableauserverclient.models.pagination_item import PaginationItem +from tableauserverclient.models.revision_item import RevisionItem +from tableauserverclient.models.virtual_connection_item import VirtualConnectionItem +from tableauserverclient.server.request_factory import RequestFactory +from tableauserverclient.server.request_options import RequestOptions +from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api +from tableauserverclient.server.endpoint.permissions_endpoint import _PermissionsEndpoint +from tableauserverclient.server.endpoint.resource_tagger import TaggingMixin +from tableauserverclient.server.pager import Pager + +if TYPE_CHECKING: + from tableauserverclient.server import Server + + +class VirtualConnections(QuerysetEndpoint[VirtualConnectionItem], TaggingMixin): + def __init__(self, parent_srv: "Server") -> None: + super().__init__(parent_srv) + self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) + + @property + def baseurl(self) -> str: + return f"{self.parent_srv.baseurl}/sites/{self.parent_srv.site_id}/virtualConnections" + + @api(version="3.18") + def get(self, req_options: Optional[RequestOptions] = None) -> Tuple[List[VirtualConnectionItem], PaginationItem]: + server_response = self.get_request(self.baseurl, req_options) + pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) + virtual_connections = VirtualConnectionItem.from_response(server_response.content, self.parent_srv.namespace) + return virtual_connections, pagination_item + + @api(version="3.18") + def populate_connections(self, virtual_connection: VirtualConnectionItem) -> VirtualConnectionItem: + def _connection_fetcher(): + return Pager(partial(self._get_virtual_database_connections, virtual_connection)) + + virtual_connection._connections = _connection_fetcher + return virtual_connection + + def _get_virtual_database_connections( + self, virtual_connection: VirtualConnectionItem, req_options: Optional[RequestOptions] = None + ) -> Tuple[List[ConnectionItem], PaginationItem]: + server_response = self.get_request(f"{self.baseurl}/{virtual_connection.id}/connections", req_options) + connections = ConnectionItem.from_response(server_response.content, self.parent_srv.namespace) + pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) + + return connections, pagination_item + + @api(version="3.18") + def update_connection_db_connection( + self, virtual_connection: Union[str, VirtualConnectionItem], connection: ConnectionItem + ) -> ConnectionItem: + vconn_id = getattr(virtual_connection, "id", virtual_connection) + url = f"{self.baseurl}/{vconn_id}/connections/{connection.id}/modify" + xml_request = RequestFactory.VirtualConnection.update_db_connection(connection) + server_response = self.put_request(url, xml_request) + return ConnectionItem.from_response(server_response.content, self.parent_srv.namespace)[0] + + @api(version="3.23") + def get_by_id(self, virtual_connection: Union[str, VirtualConnectionItem]) -> VirtualConnectionItem: + vconn_id = getattr(virtual_connection, "id", virtual_connection) + url = f"{self.baseurl}/{vconn_id}" + server_response = self.get_request(url) + return VirtualConnectionItem.from_response(server_response.content, self.parent_srv.namespace)[0] + + @api(version="3.23") + def download(self, virtual_connection: Union[str, VirtualConnectionItem]) -> str: + v_conn = self.get_by_id(virtual_connection) + return json.dumps(v_conn.content) + + @api(version="3.23") + def update(self, virtual_connection: VirtualConnectionItem) -> VirtualConnectionItem: + url = f"{self.baseurl}/{virtual_connection.id}" + xml_request = RequestFactory.VirtualConnection.update(virtual_connection) + server_response = self.put_request(url, xml_request) + return VirtualConnectionItem.from_response(server_response.content, self.parent_srv.namespace)[0] + + @api(version="3.23") + def get_revisions( + self, virtual_connection: VirtualConnectionItem, req_options: Optional[RequestOptions] = None + ) -> Tuple[List[RevisionItem], PaginationItem]: + server_response = self.get_request(f"{self.baseurl}/{virtual_connection.id}/revisions", req_options) + pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) + revisions = RevisionItem.from_response(server_response.content, self.parent_srv.namespace, virtual_connection) + return revisions, pagination_item + + @api(version="3.23") + def download_revision(self, virtual_connection: VirtualConnectionItem, revision_number: int) -> str: + url = f"{self.baseurl}/{virtual_connection.id}/revisions/{revision_number}" + server_response = self.get_request(url) + virtual_connection = VirtualConnectionItem.from_response(server_response.content, self.parent_srv.namespace)[0] + return json.dumps(virtual_connection.content) + + @api(version="3.23") + def delete(self, virtual_connection: Union[VirtualConnectionItem, str]) -> None: + vconn_id = getattr(virtual_connection, "id", virtual_connection) + self.delete_request(f"{self.baseurl}/{vconn_id}") + + @api(version="3.23") + def publish( + self, + virtual_connection: VirtualConnectionItem, + virtual_connection_content: str, + mode: str = "CreateNew", + publish_as_draft: bool = False, + ) -> VirtualConnectionItem: + """ + Publish a virtual connection to the server. + + For the virtual_connection object, name, project_id, and owner_id are + required. + + The virtual_connection_content can be a json string or a file path to a + json file. + + The mode can be "CreateNew" or "Overwrite". If mode is + "Overwrite" and the virtual connection already exists, it will be + overwritten. + + If publish_as_draft is True, the virtual connection will be published + as a draft, and the id of the draft will be on the response object. + """ + try: + json.loads(virtual_connection_content) + except json.JSONDecodeError: + file = Path(virtual_connection_content) + if not file.exists(): + raise RuntimeError(f"{virtual_connection_content} is not valid json nor an existing file path") + content = file.read_text() + else: + content = virtual_connection_content + + if mode not in ["CreateNew", "Overwrite"]: + raise ValueError(f"Invalid mode: {mode}") + overwrite = mode == "Overwrite" + + url = f"{self.baseurl}?overwrite={str(overwrite).lower()}&publishAsDraft={str(publish_as_draft).lower()}" + xml_request = RequestFactory.VirtualConnection.publish(virtual_connection, content) + server_response = self.post_request(url, xml_request) + return VirtualConnectionItem.from_response(server_response.content, self.parent_srv.namespace)[0] + + @api(version="3.22") + def populate_permissions(self, item: VirtualConnectionItem) -> None: + self._permissions.populate(item) + + @api(version="3.22") + def add_permissions(self, resource, rules): + return self._permissions.update(resource, rules) + + @api(version="3.22") + def delete_permission(self, item, capability_item): + return self._permissions.delete(item, capability_item) + + @api(version="3.23") + def add_tags( + self, virtual_connection: Union[VirtualConnectionItem, str], tags: Union[Iterable[str], str] + ) -> Set[str]: + return super().add_tags(virtual_connection, tags) + + @api(version="3.23") + def delete_tags( + self, virtual_connection: Union[VirtualConnectionItem, str], tags: Union[Iterable[str], str] + ) -> None: + return super().delete_tags(virtual_connection, tags) + + @api(version="3.23") + def update_tags(self, virtual_connection: VirtualConnectionItem) -> None: + raise NotImplementedError("Update tags is not implemented for Virtual Connections") diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 7fc9c955..96fa1468 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -1356,6 +1356,65 @@ def update_request(self, xml_request: ET.Element, group_set_item: "GroupSetItem" return ET.tostring(xml_request) +class VirtualConnectionRequest: + @_tsrequest_wrapped + def update_db_connection(self, xml_request: ET.Element, connection_item: ConnectionItem) -> bytes: + connection_element = ET.SubElement(xml_request, "connection") + if connection_item.server_address is not None: + connection_element.attrib["serverAddress"] = connection_item.server_address + if connection_item.server_port is not None: + connection_element.attrib["serverPort"] = str(connection_item.server_port) + if connection_item.username is not None: + connection_element.attrib["userName"] = connection_item.username + if connection_item.password is not None: + connection_element.attrib["password"] = connection_item.password + + return ET.tostring(xml_request) + + @_tsrequest_wrapped + def update(self, xml_request: ET.Element, virtual_connection: VirtualConnectionItem) -> bytes: + vc_element = ET.SubElement(xml_request, "virtualConnection") + if virtual_connection.name is not None: + vc_element.attrib["name"] = virtual_connection.name + if virtual_connection.is_certified is not None: + vc_element.attrib["isCertified"] = str(virtual_connection.is_certified).lower() + if virtual_connection.certification_note is not None: + vc_element.attrib["certificationNote"] = virtual_connection.certification_note + if virtual_connection.project_id is not None: + project_element = ET.SubElement(vc_element, "project") + project_element.attrib["id"] = virtual_connection.project_id + if virtual_connection.owner_id is not None: + owner_element = ET.SubElement(vc_element, "owner") + owner_element.attrib["id"] = virtual_connection.owner_id + + return ET.tostring(xml_request) + + @_tsrequest_wrapped + def publish(self, xml_request: ET.Element, virtual_connection: VirtualConnectionItem, content: str) -> bytes: + vc_element = ET.SubElement(xml_request, "virtualConnection") + if virtual_connection.name is not None: + vc_element.attrib["name"] = virtual_connection.name + else: + raise ValueError("Virtual Connection must have a name.") + if virtual_connection.project_id is not None: + project_element = ET.SubElement(vc_element, "project") + project_element.attrib["id"] = virtual_connection.project_id + else: + raise ValueError("Virtual Connection must have a project id.") + if virtual_connection.owner_id is not None: + owner_element = ET.SubElement(vc_element, "owner") + owner_element.attrib["id"] = virtual_connection.owner_id + else: + raise ValueError("Virtual Connection must have an owner id.") + if content is not None: + content_element = ET.SubElement(vc_element, "content") + content_element.text = content + else: + raise ValueError("Virtual Connection must have content.") + + return ET.tostring(xml_request) + + class RequestFactory(object): Auth = AuthRequest() Connection = Connection() @@ -1382,5 +1441,6 @@ class RequestFactory(object): Tag = TagRequest() Task = TaskRequest() User = UserRequest() + VirtualConnection = VirtualConnectionRequest() Workbook = WorkbookRequest() Webhook = WebhookRequest() diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index 20a7dc3d..e563a713 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -36,6 +36,7 @@ LinkedTasks, GroupSets, Tags, + VirtualConnections, ) from tableauserverclient.server.exceptions import ( ServerInfoEndpointNotFoundError, @@ -105,6 +106,7 @@ def __init__(self, server_address, use_server_version=False, http_options=None, self.linked_tasks = LinkedTasks(self) self.group_sets = GroupSets(self) self.tags = Tags(self) + self.virtual_connections = VirtualConnections(self) self._session = self._session_factory() self._http_options = dict() # must set this before making a server call diff --git a/test/assets/virtual_connection_add_permissions.xml b/test/assets/virtual_connection_add_permissions.xml new file mode 100644 index 00000000..d8b05284 --- /dev/null +++ b/test/assets/virtual_connection_add_permissions.xml @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/test/assets/virtual_connection_database_connection_update.xml b/test/assets/virtual_connection_database_connection_update.xml new file mode 100644 index 00000000..a6135d60 --- /dev/null +++ b/test/assets/virtual_connection_database_connection_update.xml @@ -0,0 +1,6 @@ + + + + diff --git a/test/assets/virtual_connection_populate_connections.xml b/test/assets/virtual_connection_populate_connections.xml new file mode 100644 index 00000000..77d89952 --- /dev/null +++ b/test/assets/virtual_connection_populate_connections.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/test/assets/virtual_connections_download.xml b/test/assets/virtual_connections_download.xml new file mode 100644 index 00000000..889e70ce --- /dev/null +++ b/test/assets/virtual_connections_download.xml @@ -0,0 +1,7 @@ + + + + + {"policyCollection":{"luid":"34ae5eb9-ceac-4158-86f1-a5d8163d5261","policies":[]},"revision":{"luid":"1b2e2aae-b904-4f5a-aa4d-9f114b8e5f57","revisableProperties":{}}} + + diff --git a/test/assets/virtual_connections_get.xml b/test/assets/virtual_connections_get.xml new file mode 100644 index 00000000..f1f410e4 --- /dev/null +++ b/test/assets/virtual_connections_get.xml @@ -0,0 +1,14 @@ + + + + + + + diff --git a/test/assets/virtual_connections_publish.xml b/test/assets/virtual_connections_publish.xml new file mode 100644 index 00000000..889e70ce --- /dev/null +++ b/test/assets/virtual_connections_publish.xml @@ -0,0 +1,7 @@ + + + + + {"policyCollection":{"luid":"34ae5eb9-ceac-4158-86f1-a5d8163d5261","policies":[]},"revision":{"luid":"1b2e2aae-b904-4f5a-aa4d-9f114b8e5f57","revisableProperties":{}}} + + diff --git a/test/assets/virtual_connections_revisions.xml b/test/assets/virtual_connections_revisions.xml new file mode 100644 index 00000000..37411342 --- /dev/null +++ b/test/assets/virtual_connections_revisions.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/test/assets/virtual_connections_update.xml b/test/assets/virtual_connections_update.xml new file mode 100644 index 00000000..60d5d169 --- /dev/null +++ b/test/assets/virtual_connections_update.xml @@ -0,0 +1,8 @@ + + + + + + + diff --git a/test/test_tagging.py b/test/test_tagging.py index fc88eea8..0184af41 100644 --- a/test/test_tagging.py +++ b/test/test_tagging.py @@ -83,6 +83,12 @@ def make_flow() -> TSC.FlowItem: return flow +def make_vconn() -> TSC.VirtualConnectionItem: + vconn = TSC.VirtualConnectionItem("test") + vconn._id = str(uuid.uuid4()) + return vconn + + sample_taggable_items = ( [ ("workbooks", make_workbook()), @@ -97,6 +103,8 @@ def make_flow() -> TSC.FlowItem: ("databases", "some_id"), ("flows", make_flow()), ("flows", "some_id"), + ("virtual_connections", make_vconn()), + ("virtual_connections", "some_id"), ], ) diff --git a/test/test_virtual_connection.py b/test/test_virtual_connection.py new file mode 100644 index 00000000..975033d2 --- /dev/null +++ b/test/test_virtual_connection.py @@ -0,0 +1,242 @@ +import json +from pathlib import Path +import unittest + +import requests_mock + +import tableauserverclient as TSC +from tableauserverclient.datetime_helpers import parse_datetime +from tableauserverclient.models.virtual_connection_item import VirtualConnectionItem + +ASSET_DIR = Path(__file__).parent / "assets" + +VIRTUAL_CONNECTION_GET_XML = ASSET_DIR / "virtual_connections_get.xml" +VIRTUAL_CONNECTION_POPULATE_CONNECTIONS = ASSET_DIR / "virtual_connection_populate_connections.xml" +VC_DB_CONN_UPDATE = ASSET_DIR / "virtual_connection_database_connection_update.xml" +VIRTUAL_CONNECTION_DOWNLOAD = ASSET_DIR / "virtual_connections_download.xml" +VIRTUAL_CONNECTION_UPDATE = ASSET_DIR / "virtual_connections_update.xml" +VIRTUAL_CONNECTION_REVISIONS = ASSET_DIR / "virtual_connections_revisions.xml" +VIRTUAL_CONNECTION_PUBLISH = ASSET_DIR / "virtual_connections_publish.xml" +ADD_PERMISSIONS = ASSET_DIR / "virtual_connection_add_permissions.xml" + + +class TestVirtualConnections(unittest.TestCase): + def setUp(self) -> None: + self.server = TSC.Server("http://test") + + self.server._site_id = "dad65087-b08b-4603-af4e-2887b8aafc67" + self.server._auth_token = "j80k54ll2lfMZ0tv97mlPvvSCRyD0DOM" + self.server.version = "3.23" + + self.baseurl = f"{self.server.baseurl}/sites/{self.server.site_id}/virtualConnections" + return super().setUp() + + def test_from_xml(self): + items = VirtualConnectionItem.from_response(VIRTUAL_CONNECTION_GET_XML.read_bytes(), self.server.namespace) + + assert len(items) == 1 + virtual_connection = items[0] + assert virtual_connection.created_at == parse_datetime("2024-05-30T09:00:00Z") + assert not virtual_connection.has_extracts + assert virtual_connection.id == "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + assert virtual_connection.is_certified + assert virtual_connection.name == "vconn" + assert virtual_connection.updated_at == parse_datetime("2024-06-18T09:00:00Z") + assert virtual_connection.webpage_url == "https://test/#/site/site-name/virtualconnections/3" + + def test_virtual_connection_get(self): + with requests_mock.mock() as m: + m.get(self.baseurl, text=VIRTUAL_CONNECTION_GET_XML.read_text()) + items, pagination_item = self.server.virtual_connections.get() + + assert len(items) == 1 + assert pagination_item.total_available == 1 + assert items[0].name == "vconn" + + def test_virtual_connection_populate_connections(self): + vconn = VirtualConnectionItem("vconn") + vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + with requests_mock.mock() as m: + m.get(f"{self.baseurl}/{vconn.id}/connections", text=VIRTUAL_CONNECTION_POPULATE_CONNECTIONS.read_text()) + vc_out = self.server.virtual_connections.populate_connections(vconn) + connection_list = list(vconn.connections) + + assert vc_out is vconn + assert vc_out._connections is not None + + assert len(connection_list) == 1 + connection = connection_list[0] + assert connection.id == "37ca6ced-58d7-4dcf-99dc-f0a85223cbef" + assert connection.connection_type == "postgres" + assert connection.server_address == "localhost" + assert connection.server_port == "5432" + assert connection.username == "pgadmin" + + def test_virtual_connection_update_connection_db_connection(self): + vconn = VirtualConnectionItem("vconn") + vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + connection = TSC.ConnectionItem() + connection._id = "37ca6ced-58d7-4dcf-99dc-f0a85223cbef" + connection.server_address = "localhost" + connection.server_port = "5432" + connection.username = "pgadmin" + connection.password = "password" + with requests_mock.mock() as m: + m.put(f"{self.baseurl}/{vconn.id}/connections/{connection.id}/modify", text=VC_DB_CONN_UPDATE.read_text()) + updated_connection = self.server.virtual_connections.update_connection_db_connection(vconn, connection) + + assert updated_connection.id == "37ca6ced-58d7-4dcf-99dc-f0a85223cbef" + assert updated_connection.server_address == "localhost" + assert updated_connection.server_port == "5432" + assert updated_connection.username == "pgadmin" + assert updated_connection.password is None + + def test_virtual_connection_get_by_id(self): + vconn = VirtualConnectionItem("vconn") + vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + with requests_mock.mock() as m: + m.get(f"{self.baseurl}/{vconn.id}", text=VIRTUAL_CONNECTION_DOWNLOAD.read_text()) + vconn = self.server.virtual_connections.get_by_id(vconn) + + assert vconn.content + assert vconn.created_at is None + assert vconn.id is None + assert "policyCollection" in vconn.content + assert "revision" in vconn.content + + def test_virtual_connection_update(self): + vconn = VirtualConnectionItem("vconn") + vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + vconn.is_certified = True + vconn.certification_note = "demo certification note" + vconn.project_id = "5286d663-8668-4ac2-8c8d-91af7d585f6b" + vconn.owner_id = "9324cf6b-ba72-4b8e-b895-ac3f28d2f0e0" + with requests_mock.mock() as m: + m.put(f"{self.baseurl}/{vconn.id}", text=VIRTUAL_CONNECTION_UPDATE.read_text()) + vconn = self.server.virtual_connections.update(vconn) + + assert not vconn.has_extracts + assert vconn.id is None + assert vconn.is_certified + assert vconn.name == "testv1" + assert vconn.certification_note == "demo certification note" + assert vconn.project_id == "5286d663-8668-4ac2-8c8d-91af7d585f6b" + assert vconn.owner_id == "9324cf6b-ba72-4b8e-b895-ac3f28d2f0e0" + + def test_virtual_connection_get_revisions(self): + vconn = VirtualConnectionItem("vconn") + vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + with requests_mock.mock() as m: + m.get(f"{self.baseurl}/{vconn.id}/revisions", text=VIRTUAL_CONNECTION_REVISIONS.read_text()) + revisions, pagination_item = self.server.virtual_connections.get_revisions(vconn) + + assert len(revisions) == 3 + assert pagination_item.total_available == 3 + assert revisions[0].resource_id == vconn.id + assert revisions[0].resource_name == vconn.name + assert revisions[0].created_at == parse_datetime("2016-07-26T20:34:56Z") + assert revisions[0].revision_number == "1" + assert not revisions[0].current + assert not revisions[0].deleted + assert revisions[0].user_name == "Cassie" + assert revisions[0].user_id == "5de011f8-5aa9-4d5b-b991-f462c8dd6bb7" + assert revisions[1].resource_id == vconn.id + assert revisions[1].resource_name == vconn.name + assert revisions[1].created_at == parse_datetime("2016-07-27T20:34:56Z") + assert revisions[1].revision_number == "2" + assert not revisions[1].current + assert not revisions[1].deleted + assert revisions[2].resource_id == vconn.id + assert revisions[2].resource_name == vconn.name + assert revisions[2].created_at == parse_datetime("2016-07-28T20:34:56Z") + assert revisions[2].revision_number == "3" + assert revisions[2].current + assert not revisions[2].deleted + assert revisions[2].user_name == "Cassie" + assert revisions[2].user_id == "5de011f8-5aa9-4d5b-b991-f462c8dd6bb7" + + def test_virtual_connection_download_revision(self): + vconn = VirtualConnectionItem("vconn") + vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + with requests_mock.mock() as m: + m.get(f"{self.baseurl}/{vconn.id}/revisions/1", text=VIRTUAL_CONNECTION_DOWNLOAD.read_text()) + content = self.server.virtual_connections.download_revision(vconn, 1) + + assert content + assert "policyCollection" in content + data = json.loads(content) + assert "policyCollection" in data + assert "revision" in data + + def test_virtual_connection_delete(self): + vconn = VirtualConnectionItem("vconn") + vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + with requests_mock.mock() as m: + m.delete(f"{self.baseurl}/{vconn.id}") + self.server.virtual_connections.delete(vconn) + self.server.virtual_connections.delete(vconn.id) + + assert m.call_count == 2 + + def test_virtual_connection_publish(self): + vconn = VirtualConnectionItem("vconn") + vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + vconn.project_id = "9836791c-9468-40f0-b7f3-d10b9562a046" + vconn.owner_id = "ee8bc9ca-77fe-4ae0-8093-cf77f0ee67a9" + with requests_mock.mock() as m: + m.post(f"{self.baseurl}?overwrite=false&publishAsDraft=false", text=VIRTUAL_CONNECTION_PUBLISH.read_text()) + vconn = self.server.virtual_connections.publish( + vconn, '{"test": 0}', mode="CreateNew", publish_as_draft=False + ) + + assert vconn.name == "vconn_test" + assert vconn.owner_id == "ee8bc9ca-77fe-4ae0-8093-cf77f0ee67a9" + assert vconn.project_id == "9836791c-9468-40f0-b7f3-d10b9562a046" + assert vconn.content + assert "policyCollection" in vconn.content + assert "revision" in vconn.content + + def test_virtual_connection_publish_draft_overwrite(self): + vconn = VirtualConnectionItem("vconn") + vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + vconn.project_id = "9836791c-9468-40f0-b7f3-d10b9562a046" + vconn.owner_id = "ee8bc9ca-77fe-4ae0-8093-cf77f0ee67a9" + with requests_mock.mock() as m: + m.post(f"{self.baseurl}?overwrite=true&publishAsDraft=true", text=VIRTUAL_CONNECTION_PUBLISH.read_text()) + vconn = self.server.virtual_connections.publish( + vconn, '{"test": 0}', mode="Overwrite", publish_as_draft=True + ) + + assert vconn.name == "vconn_test" + assert vconn.owner_id == "ee8bc9ca-77fe-4ae0-8093-cf77f0ee67a9" + assert vconn.project_id == "9836791c-9468-40f0-b7f3-d10b9562a046" + assert vconn.content + assert "policyCollection" in vconn.content + assert "revision" in vconn.content + + def test_add_permissions(self) -> None: + with open(ADD_PERMISSIONS, "rb") as f: + response_xml = f.read().decode("utf-8") + + single_virtual_connection = TSC.VirtualConnectionItem("test") + single_virtual_connection._id = "21778de4-b7b9-44bc-a599-1506a2639ace" + + bob = TSC.UserItem.as_reference("7c37ee24-c4b1-42b6-a154-eaeab7ee330a") + group_of_people = TSC.GroupItem.as_reference("5e5e1978-71fa-11e4-87dd-7382f5c437af") + + new_permissions = [ + TSC.PermissionsRule(bob, {"Write": "Allow"}), + TSC.PermissionsRule(group_of_people, {"Read": "Deny"}), + ] + + with requests_mock.mock() as m: + m.put(self.baseurl + "/21778de4-b7b9-44bc-a599-1506a2639ace/permissions", text=response_xml) + permissions = self.server.virtual_connections.add_permissions(single_virtual_connection, new_permissions) + + self.assertEqual(permissions[0].grantee.tag_name, "group") + self.assertEqual(permissions[0].grantee.id, "5e5e1978-71fa-11e4-87dd-7382f5c437af") + self.assertDictEqual(permissions[0].capabilities, {TSC.Permission.Capability.Read: TSC.Permission.Mode.Deny}) + + self.assertEqual(permissions[1].grantee.tag_name, "user") + self.assertEqual(permissions[1].grantee.id, "7c37ee24-c4b1-42b6-a154-eaeab7ee330a") + self.assertDictEqual(permissions[1].capabilities, {TSC.Permission.Capability.Write: TSC.Permission.Mode.Allow})