diff --git a/superset/views/sql_lab/views.py b/superset/views/sql_lab/views.py index 012fdc2dc50c3..00ba4b8a5f0fc 100644 --- a/superset/views/sql_lab/views.py +++ b/superset/views/sql_lab/views.py @@ -161,6 +161,10 @@ def put(self, tab_state_id: int) -> FlaskResponse: return Response(status=403) fields = {k: json.loads(v) for k, v in request.form.to_dict().items()} + if client_id := fields.get("latest_query_id"): + query = db.session.query(Query).filter_by(client_id=client_id).one_or_none() + if not query: + return self.json_response({"error": "Bad request"}, status=400) db.session.query(TabState).filter_by(id=tab_state_id).update(fields) db.session.commit() return json_success(json.dumps(tab_state_id)) diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 6d06e46fa30fa..3157ddd649e34 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -996,6 +996,41 @@ def test_tabstate_with_name(self): self.assertEqual(payload["label"], "Untitled Query foo") + def test_tabstate_update(self): + username = "admin" + self.login(username) + # create a tab + data = { + "queryEditor": json.dumps( + { + "name": "Untitled Query foo", + "dbId": 1, + "schema": None, + "autorun": False, + "sql": "SELECT ...", + "queryLimit": 1000, + } + ) + } + resp = self.get_json_resp("/tabstateview/", data=data) + tab_state_id = resp["id"] + # update tab state with non-existing client_id + client_id = "asdfasdf" + data = {"sql": json.dumps("select 1"), "latest_query_id": json.dumps(client_id)} + response = self.client.put(f"/tabstateview/{tab_state_id}", data=data) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.json["error"], "Bad request") + # generate query + db.session.add(Query(client_id=client_id, database_id=1)) + db.session.commit() + # update tab state with a valid client_id + response = self.client.put(f"/tabstateview/{tab_state_id}", data=data) + self.assertEqual(response.status_code, 200) + # nulls should be ok too + data["latest_query_id"] = "null" + response = self.client.put(f"/tabstateview/{tab_state_id}", data=data) + self.assertEqual(response.status_code, 200) + def test_virtual_table_explore_visibility(self): # test that default visibility it set to True database = superset.utils.database.get_example_database()