diff --git a/tableauserverclient/server/endpoint/auth_endpoint.py b/tableauserverclient/server/endpoint/auth_endpoint.py index 4211bb7e..35dfa5d7 100644 --- a/tableauserverclient/server/endpoint/auth_endpoint.py +++ b/tableauserverclient/server/endpoint/auth_endpoint.py @@ -84,9 +84,10 @@ def sign_in(self, auth_req: "Credentials") -> contextmgr: self._check_status(server_response, url) parsed_response = fromstring(server_response.content) site_id = parsed_response.find(".//t:site", namespaces=self.parent_srv.namespace).get("id", None) + site_url = parsed_response.find(".//t:site", namespaces=self.parent_srv.namespace).get("contentUrl", None) user_id = parsed_response.find(".//t:user", namespaces=self.parent_srv.namespace).get("id", None) auth_token = parsed_response.find("t:credentials", namespaces=self.parent_srv.namespace).get("token", None) - self.parent_srv._set_auth(site_id, user_id, auth_token) + self.parent_srv._set_auth(site_id, user_id, auth_token, site_url) logger.info(f"Signed into {self.parent_srv.server_address} as user with id {user_id}") return Auth.contextmgr(self.sign_out) @@ -155,9 +156,10 @@ def switch_site(self, site_item: "SiteItem") -> contextmgr: self._check_status(server_response, url) parsed_response = fromstring(server_response.content) site_id = parsed_response.find(".//t:site", namespaces=self.parent_srv.namespace).get("id", None) + site_url = parsed_response.find(".//t:site", namespaces=self.parent_srv.namespace).get("contentUrl", None) user_id = parsed_response.find(".//t:user", namespaces=self.parent_srv.namespace).get("id", None) auth_token = parsed_response.find("t:credentials", namespaces=self.parent_srv.namespace).get("token", None) - self.parent_srv._set_auth(site_id, user_id, auth_token) + self.parent_srv._set_auth(site_id, user_id, auth_token, site_url) logger.info(f"Signed into {self.parent_srv.server_address} as user with id {user_id}") return Auth.contextmgr(self.sign_out) diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index 4eeefcaf..02abb3fe 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -207,12 +207,14 @@ def _clear_auth(self): self._site_id = None self._user_id = None self._auth_token = None + self._site_url = None self._session = self._session_factory() - def _set_auth(self, site_id, user_id, auth_token): + def _set_auth(self, site_id, user_id, auth_token, site_url=None): self._site_id = site_id self._user_id = user_id self._auth_token = auth_token + self._site_url = site_url def _get_legacy_version(self): # the serverInfo call was introduced in 2.4, earlier than that we have this different call @@ -282,6 +284,13 @@ def site_id(self): raise NotSignedInError(error) return self._site_id + @property + def site_url(self): + if self._site_url is None: + error = "Missing site URL. You must sign in first." + raise NotSignedInError(error) + return self._site_url + @property def user_id(self): if self._user_id is None: diff --git a/test/test_auth.py b/test/test_auth.py index 48100ad8..09e3e251 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -27,6 +27,7 @@ def test_sign_in(self): self.assertEqual("eIX6mvFsqyansa4KqEI1UwOpS8ggRs2l", self.server.auth_token) self.assertEqual("6b7179ba-b82b-4f0f-91ed-812074ac5da6", self.server.site_id) + self.assertEqual("Samples", self.server.site_url) self.assertEqual("1a96d216-e9b8-497b-a82a-0b899a965e01", self.server.user_id) def test_sign_in_with_personal_access_tokens(self): @@ -41,6 +42,7 @@ def test_sign_in_with_personal_access_tokens(self): self.assertEqual("eIX6mvFsqyansa4KqEI1UwOpS8ggRs2l", self.server.auth_token) self.assertEqual("6b7179ba-b82b-4f0f-91ed-812074ac5da6", self.server.site_id) + self.assertEqual("Samples", self.server.site_url) self.assertEqual("1a96d216-e9b8-497b-a82a-0b899a965e01", self.server.user_id) def test_sign_in_impersonate(self): @@ -93,6 +95,7 @@ def test_sign_out(self): self.assertIsNone(self.server._auth_token) self.assertIsNone(self.server._site_id) + self.assertIsNone(self.server._site_url) self.assertIsNone(self.server._user_id) def test_switch_site(self): @@ -109,6 +112,7 @@ def test_switch_site(self): self.assertEqual("eIX6mvFsqyansa4KqEI1UwOpS8ggRs2l", self.server.auth_token) self.assertEqual("6b7179ba-b82b-4f0f-91ed-812074ac5da6", self.server.site_id) + self.assertEqual("Samples", self.server.site_url) self.assertEqual("1a96d216-e9b8-497b-a82a-0b899a965e01", self.server.user_id) def test_revoke_all_server_admin_tokens(self): @@ -125,4 +129,5 @@ def test_revoke_all_server_admin_tokens(self): self.assertEqual("eIX6mvFsqyansa4KqEI1UwOpS8ggRs2l", self.server.auth_token) self.assertEqual("6b7179ba-b82b-4f0f-91ed-812074ac5da6", self.server.site_id) + self.assertEqual("Samples", self.server.site_url) self.assertEqual("1a96d216-e9b8-497b-a82a-0b899a965e01", self.server.user_id)