From 82bdf956fd9e6ee546b20839cd163d26e781305a Mon Sep 17 00:00:00 2001 From: Steve Yoo Date: Wed, 30 Apr 2025 12:57:52 -0400 Subject: [PATCH] Add session id to user agent string --- .../enhancement-UserAgent-35014.json | 5 + awscli/clidriver.py | 2 + awscli/telemetry.py | 243 +++++++++++++++ .../build_system/functional/test_utils.py | 2 +- .../build_system/unit/test_install.py | 2 +- tests/functional/test_telemetry.py | 294 ++++++++++++++++++ tests/{backends/build_system => }/markers.py | 0 .../ec2instanceconnect/test_websocket.py | 10 +- tests/unit/test_clidriver.py | 14 +- 9 files changed, 560 insertions(+), 12 deletions(-) create mode 100644 .changes/next-release/enhancement-UserAgent-35014.json create mode 100644 awscli/telemetry.py create mode 100644 tests/functional/test_telemetry.py rename tests/{backends/build_system => }/markers.py (100%) diff --git a/.changes/next-release/enhancement-UserAgent-35014.json b/.changes/next-release/enhancement-UserAgent-35014.json new file mode 100644 index 000000000000..59802b5f4320 --- /dev/null +++ b/.changes/next-release/enhancement-UserAgent-35014.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "User Agent", + "description": "Append session id to user agent string" +} diff --git a/awscli/clidriver.py b/awscli/clidriver.py index 261a472fbae2..6eb5586f3423 100644 --- a/awscli/clidriver.py +++ b/awscli/clidriver.py @@ -75,6 +75,7 @@ set_stream_logger, ) from awscli.plugin import load_plugins +from awscli.telemetry import add_session_id_component_to_user_agent_extra from awscli.utils import ( IMDSRegionProvider, OutputStreamFactory, @@ -176,6 +177,7 @@ def _set_user_agent_for_session(session): session.user_agent_version = __version__ _add_distribution_source_to_user_agent(session) _add_linux_distribution_to_user_agent(session) + add_session_id_component_to_user_agent_extra(session) def no_pager_handler(session, parsed_args, **kwargs): diff --git a/awscli/telemetry.py b/awscli/telemetry.py new file mode 100644 
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Generate and cache the CLI session id added to the user agent string.

A session id is the MD5 hex digest of the hostname, the terminal attached
to standard input (when there is one) and the time the id was generated.
Ids are cached in a small SQLite database, keyed by hostname and terminal,
and are reused until they have gone unused for longer than
``_SESSION_LENGTH_SECONDS``.
"""
import hashlib
import io
import logging
import os
import socket
import sqlite3
import sys
import threading
import time
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path

from botocore.useragent import UserAgentComponent

from awscli.compat import is_windows
from awscli.utils import add_component_to_user_agent_extra

LOG = logging.getLogger(__name__)

_CACHE_DIR = Path.home() / '.aws' / 'cli' / 'cache'
_DATABASE_FILENAME = 'session.db'
# A cached session id is reused until it has gone unused for this long.
_SESSION_LENGTH_SECONDS = 60 * 30

try:
    _CACHE_DIR.mkdir(parents=True, exist_ok=True)
except OSError:
    # The session id is optional, so an unwritable home directory (for
    # example a read-only file system) must not make this module fail to
    # import. Opening the database fails later instead, and that failure
    # is handled in add_session_id_component_to_user_agent_extra().
    LOG.debug('Unable to create the CLI cache directory.', exc_info=True)


@dataclass
class CLISessionData:
    """A single row of the ``session`` table."""

    # Cache key derived from the hostname and tty.
    key: str
    # Session id that is reported in the user agent string.
    session_id: str
    # Unix time, in whole seconds, at which the session was last used.
    timestamp: int


class CLISessionDatabaseConnection:
    """Wrapper around the SQLite connection that stores CLI sessions.

    Creating an instance makes sure the ``session`` table exists and
    tries to switch the database to write-ahead logging.
    """

    _CREATE_TABLE = """
        CREATE TABLE IF NOT EXISTS session (
          key TEXT PRIMARY KEY,
          session_id TEXT NOT NULL,
          timestamp INTEGER NOT NULL
        )
    """
    _ENABLE_WAL = 'PRAGMA journal_mode=WAL'

    def __init__(self, connection=None):
        """
        :param connection: An already opened ``sqlite3.Connection``. When
            omitted, the session database in the CLI cache directory is
            opened.
        :raises sqlite3.Error: If the default database cannot be opened.
        """
        if connection is None:
            connection = sqlite3.connect(
                _CACHE_DIR / _DATABASE_FILENAME,
                # The sweeper uses this connection from a background
                # thread (see CLISessionOrchestrator._sweep_cache).
                check_same_thread=False,
                # Autocommit: every statement is committed immediately.
                isolation_level=None,
            )
        self._connection = connection
        self._ensure_database_setup()

    def execute(self, query, *parameters):
        """Run ``query`` and return the resulting cursor.

        ``sqlite3.OperationalError`` (raised, for example, when the
        process times out waiting for the database lock) is swallowed and
        an empty cursor is returned instead, so callers always receive an
        object they can fetch from.
        """
        try:
            return self._connection.execute(query, *parameters)
        except sqlite3.OperationalError:
            # Return an empty `Cursor` object instead of raising.
            return sqlite3.Cursor(self._connection)

    def _ensure_database_setup(self):
        self._create_record_table()
        self._try_to_enable_wal()

    def _create_record_table(self):
        self.execute(self._CREATE_TABLE)

    def _try_to_enable_wal(self):
        try:
            self.execute(self._ENABLE_WAL)
        except sqlite3.Error:
            # This is just a performance enhancement so it is optional.
            # Not all systems will have a sqlite compiled with the WAL
            # enabled.
            pass


class CLISessionDatabaseWriter:
    """Inserts or replaces session records."""

    _WRITE_RECORD = """
        INSERT OR REPLACE INTO session (
            key, session_id, timestamp
        ) VALUES (?, ?, ?)
    """

    def __init__(self, connection):
        self._connection = connection

    def write(self, data):
        """Store ``data``, replacing any record that has the same key.

        :param data: The ``CLISessionData`` to persist.
        """
        self._connection.execute(
            self._WRITE_RECORD,
            (data.key, data.session_id, data.timestamp),
        )


class CLISessionDatabaseReader:
    """Looks up session records by cache key."""

    # Columns are listed explicitly, in CLISessionData field order, so
    # that unpacking a row does not depend on the table's column order.
    _READ_RECORD = """
        SELECT key, session_id, timestamp
        FROM session
        WHERE key = ?
    """

    def __init__(self, connection):
        self._connection = connection

    def read(self, key):
        """Return the ``CLISessionData`` stored under ``key``.

        :returns: The matching record, or ``None`` if there is none.
        """
        cursor = self._connection.execute(self._READ_RECORD, (key,))
        row = cursor.fetchone()
        if row is None:
            return None
        return CLISessionData(*row)


class CLISessionDatabaseSweeper:
    """Deletes session records that have not been used recently."""

    _DELETE_RECORDS = """
        DELETE FROM session
        WHERE timestamp < ?
    """

    def __init__(self, connection):
        self._connection = connection

    def sweep(self, timestamp):
        """Delete every record whose timestamp is before ``timestamp``.

        This method never raises; any error is silently discarded.
        """
        try:
            self._connection.execute(self._DELETE_RECORDS, (timestamp,))
        except Exception:
            # This is just a background cleanup task. No need to
            # handle it or direct to stderr.
            return


class CLISessionGenerator:
    """Derives session ids and cache keys from host information."""

    def generate_session_id(self, hostname, tty, timestamp):
        """Return the session id for ``hostname``/``tty`` at ``timestamp``."""
        return self._generate_md5_hash(hostname, tty, timestamp)

    def generate_cache_key(self, hostname, tty):
        """Return the key under which the session for this tty is cached."""
        return self._generate_md5_hash(hostname, tty)

    def _generate_md5_hash(self, *args):
        """Return the MD5 hex digest of the concatenated arguments.

        Each argument is converted with ``str()``; ``None`` values are
        skipped.
        """
        str_to_hash = "".join(str(arg) for arg in args if arg is not None)
        payload = str_to_hash.encode('utf-8')
        try:
            # The digest is only an identifier. Saying so keeps MD5
            # usable on builds that restrict it for security purposes
            # (e.g. FIPS mode) without changing the resulting value.
            digest = hashlib.md5(payload, usedforsecurity=False)
        except TypeError:
            # Python versions before 3.9 do not accept the keyword.
            digest = hashlib.md5(payload)
        return digest.hexdigest()


class CLISessionOrchestrator:
    """Resolves the session id for the current host and terminal.

    Creating an instance starts a background sweep that removes records
    older than ``_SESSION_LENGTH_SECONDS``.
    """

    def __init__(self, generator, writer, reader, sweeper):
        self._generator = generator
        self._writer = writer
        self._reader = reader
        self._sweeper = sweeper

        self._sweep_cache()

    @cached_property
    def cache_key(self):
        """Key of this terminal's record in the session cache."""
        return self._generator.generate_cache_key(self._hostname, self._tty)

    @cached_property
    def _session_id(self):
        # A freshly generated id; only used when the cache has no valid
        # id for this terminal.
        return self._generator.generate_session_id(
            self._hostname, self._tty, self._timestamp
        )

    @cached_property
    def session_id(self):
        """The session id to report, read from or written to the cache.

        A cached id is reused unless it was last used more than
        ``_SESSION_LENGTH_SECONDS`` ago. In every case the record is
        written back with the current time as its timestamp.
        """
        cached_data = self._reader.read(self.cache_key)
        if cached_data is None:
            # Cache miss, generate a new record.
            cached_data = CLISessionData(
                self.cache_key, self._session_id, self._timestamp
            )
        else:
            expires_at = cached_data.timestamp + _SESSION_LENGTH_SECONDS
            if expires_at < self._timestamp:
                # Cache hit, but the session id is expired.
                cached_data.session_id = self._session_id
            # Always update the timestamp to last used.
            cached_data.timestamp = self._timestamp
        self._writer.write(cached_data)
        return cached_data.session_id

    @cached_property
    def _tty(self):
        # os.ttyname is only available on Unix platforms.
        if is_windows:
            return None
        try:
            return os.ttyname(sys.stdin.fileno())
        except (
            AttributeError,
            ValueError,
            OSError,
            io.UnsupportedOperation,
        ):
            # AttributeError: sys.stdin is None.
            # ValueError: standard input has been closed.
            # OSError / UnsupportedOperation: standard input was
            # redirected to a pseudofile. This can happen when running
            # tests on IDEs or running scripts with redirected input.
            return None

    @cached_property
    def _hostname(self):
        return socket.gethostname()

    @cached_property
    def _timestamp(self):
        return int(time.time())

    def _sweep_cache(self):
        # Daemon thread: the cleanup must never delay or block CLI exit.
        sweeper_thread = threading.Thread(
            target=self._sweeper.sweep,
            args=(self._timestamp - _SESSION_LENGTH_SECONDS,),
            daemon=True,
        )
        sweeper_thread.start()


def _get_cli_session_orchestrator():
    """Build an orchestrator backed by the on-disk session database."""
    conn = CLISessionDatabaseConnection()
    return CLISessionOrchestrator(
        CLISessionGenerator(),
        CLISessionDatabaseWriter(conn),
        CLISessionDatabaseReader(conn),
        CLISessionDatabaseSweeper(conn),
    )


def add_session_id_component_to_user_agent_extra(session, orchestrator=None):
    """Add a ``sid`` component carrying the CLI session id to ``session``.

    :param session: The session whose user agent extra is extended.
    :param orchestrator: Object exposing a ``session_id`` attribute.
        Defaults to an orchestrator backed by the on-disk session cache.

    If the session id cannot be determined, the failure is logged at
    debug level and ``session`` is left untouched.
    """
    try:
        if orchestrator is None:
            orchestrator = _get_cli_session_orchestrator()
        session_id = orchestrator.session_id
    except Exception:
        # The session id is optional. A problem with the cache (an
        # unwritable directory, a corrupt database file, ...) must not
        # stop the CLI from running the requested command.
        LOG.debug('Unable to determine the CLI session id.', exc_info=True)
        return
    add_component_to_user_agent_extra(
        session, UserAgentComponent("sid", session_id)
    )
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Functional tests for the CLI session id cache in awscli.telemetry."""
import sqlite3
from unittest.mock import MagicMock, patch

import pytest
from botocore.session import Session

from awscli.telemetry import (
    CLISessionData,
    CLISessionDatabaseConnection,
    CLISessionDatabaseReader,
    CLISessionDatabaseSweeper,
    CLISessionDatabaseWriter,
    CLISessionGenerator,
    CLISessionOrchestrator,
    add_session_id_component_to_user_agent_extra,
)
from tests.markers import skip_if_windows

# Record every test database starts out with.
_SEED_RECORD = """
    INSERT OR REPLACE INTO session (
        key, session_id, timestamp
    ) VALUES ('first_key', 'first_id', 5555555555)
"""
# Lists the `session` table if it has been created.
_FIND_SESSION_TABLE = """
    SELECT name
    FROM sqlite_master
    WHERE type='table'
    AND name='session';
"""


def _make_orchestrator(generator, writer, reader, sweeper):
    """Build an orchestrator from the individual fixtures."""
    return CLISessionOrchestrator(generator, writer, reader, sweeper)


def _resolve(orchestrator, reader):
    """Return the orchestrator's session id and its cached record."""
    session_id = orchestrator.session_id
    return session_id, reader.read(orchestrator.cache_key)


@pytest.fixture
def session_conn():
    """An in-memory session database holding one initial record."""
    # Use an in-memory db for testing.
    raw_connection = sqlite3.connect(
        ':memory:',
        check_same_thread=False,
        isolation_level=None,
    )
    conn = CLISessionDatabaseConnection(connection=raw_connection)
    # Write an initial record.
    conn.execute(_SEED_RECORD)
    return conn


@pytest.fixture
def session_writer(session_conn):
    return CLISessionDatabaseWriter(session_conn)


@pytest.fixture
def session_reader(session_conn):
    return CLISessionDatabaseReader(session_conn)


@pytest.fixture
def session_sweeper(session_conn):
    return CLISessionDatabaseSweeper(session_conn)


@pytest.fixture
def session_generator():
    return CLISessionGenerator()


@pytest.fixture
def expired_data(session_writer, session_reader, session_sweeper):
    """Store an old record for the test and sweep it away afterwards."""
    # Write an expired record.
    old_record = CLISessionData(
        key='expired_key',
        session_id='expired_id',
        timestamp=1000000000,
    )
    session_writer.write(old_record)
    # Ensure expired record exists.
    assert session_reader.read('expired_key') is not None
    yield
    # Ensure cleanup after test is run.
    session_sweeper.sweep(1000000001)


class TestCLISessionDatabaseConnection:
    def test_ensure_database_setup(self, session_conn):
        cursor = session_conn.execute(_FIND_SESSION_TABLE)
        assert cursor.fetchall() == [('session',)]

    def test_timeout_does_not_raise_exception(self, session_conn):
        class FakeConnection(sqlite3.Connection):
            def execute(self, query, *parameters):
                # Simulate timeout by always raising.
                raise sqlite3.OperationalError()

        fake_conn = CLISessionDatabaseConnection(FakeConnection(":memory:"))
        cursor = fake_conn.execute(_FIND_SESSION_TABLE)
        assert cursor.fetchall() == []


class TestCLISessionDatabaseWriter:
    def test_write(self, session_writer, session_reader, session_sweeper):
        new_record = CLISessionData(
            key='new-key',
            session_id='new-id',
            timestamp=1000000000,
        )
        session_writer.write(new_record)

        session_data = session_reader.read('new-key')
        assert session_data.key == 'new-key'
        assert session_data.session_id == 'new-id'
        assert session_data.timestamp == 1000000000
        session_sweeper.sweep(1000000001)


class TestCLISessionDatabaseReader:
    def test_read(self, session_reader):
        session_data = session_reader.read('first_key')
        assert session_data.key == 'first_key'
        assert session_data.session_id == 'first_id'
        assert session_data.timestamp == 5555555555

    def test_read_nonexistent_record(self, session_reader):
        assert session_reader.read('bad_key') is None


class TestCLISessionDatabaseSweeper:
    def test_sweep(self, expired_data, session_reader, session_sweeper):
        session_sweeper.sweep(1000000001)
        assert session_reader.read('expired_key') is None

    def test_sweep_not_expired(
        self, expired_data, session_reader, session_sweeper
    ):
        # The record's timestamp equals the cutoff, so it must survive.
        session_sweeper.sweep(1000000000)
        assert session_reader.read('expired_key') is not None

    def test_sweep_never_raises(self, session_sweeper):
        # Normally this would raise `sqlite3.ProgrammingError`,
        # but the `sweep` method catches bare exceptions.
        session_sweeper.sweep({'bad': 'input'})


class TestCLISessionGenerator:
    def test_generate_session_id(self, session_generator):
        session_id = session_generator.generate_session_id(
            'my-hostname', 'my-tty', 1000000000
        )
        assert session_id == 'd949713b13ee3fb52983b04316e8e6b5'

    def test_generate_cache_key(self, session_generator):
        cache_key = session_generator.generate_cache_key(
            'my-hostname', 'my-tty'
        )
        assert cache_key == 'b1ca2be0ffac12f172933b6777e06f2c'


@skip_if_windows
@patch('sys.stdin')
@patch('time.time', return_value=5555555555)
@patch('socket.gethostname', return_value='my-hostname')
@patch('os.ttyname', return_value='my-tty')
class TestCLISessionOrchestrator:
    def test_session_id_gets_cached(
        self,
        patched_tty_name,
        patched_hostname,
        patched_time,
        patched_stdin,
        session_sweeper,
        session_generator,
        session_reader,
        session_writer,
    ):
        patched_stdin.fileno.return_value = None
        orchestrator = _make_orchestrator(
            session_generator, session_writer, session_reader, session_sweeper
        )
        assert orchestrator.session_id == '881cea8546fa4888970cce8d133c3bf9'

        session_data = session_reader.read(orchestrator.cache_key)
        assert session_data.key == orchestrator.cache_key
        assert session_data.session_id == orchestrator.session_id
        assert session_data.timestamp == 5555555555

    def test_cached_session_id_updated_if_expired(
        self,
        patched_tty_name,
        patched_hostname,
        patched_time,
        patched_stdin,
        session_sweeper,
        session_generator,
        session_reader,
        session_writer,
    ):
        patched_stdin.fileno.return_value = None

        # First, generate and cache a session id.
        session_id_1, session_data_1 = _resolve(
            _make_orchestrator(
                session_generator,
                session_writer,
                session_reader,
                session_sweeper,
            ),
            session_reader,
        )
        assert session_data_1.session_id == session_id_1

        # Update the timestamp and get the new session id.
        patched_time.return_value = 7777777777
        session_id_2, session_data_2 = _resolve(
            _make_orchestrator(
                session_generator,
                session_writer,
                session_reader,
                session_sweeper,
            ),
            session_reader,
        )

        # Cache key should be the same.
        assert session_data_2.key == session_data_1.key
        # Session id and timestamp should be updated.
        assert session_data_2.session_id == session_id_2
        assert session_data_2.session_id != session_data_1.session_id
        assert session_data_2.timestamp == 7777777777
        assert session_data_2.timestamp != session_data_1.timestamp

    def test_cached_session_id_not_updated_if_valid(
        self,
        patched_tty_name,
        patched_hostname,
        patched_time,
        patched_stdin,
        session_sweeper,
        session_generator,
        session_reader,
        session_writer,
    ):
        patched_stdin.fileno.return_value = None

        # First, generate and cache a session id.
        session_id_1, session_data_1 = _resolve(
            _make_orchestrator(
                session_generator,
                session_writer,
                session_reader,
                session_sweeper,
            ),
            session_reader,
        )
        assert session_data_1.session_id == session_id_1

        # Update the timestamp.
        patched_time.return_value = 5555555556
        session_id_2, session_data_2 = _resolve(
            _make_orchestrator(
                session_generator,
                session_writer,
                session_reader,
                session_sweeper,
            ),
            session_reader,
        )

        # Cache key should be the same.
        assert session_data_2.key == session_data_1.key
        # Session id should not be updated.
        assert session_data_2.session_id == session_id_2
        assert session_data_2.session_id == session_data_1.session_id
        # Only timestamp should be updated.
        assert session_data_2.timestamp == 5555555556
        assert session_data_2.timestamp != session_data_1.timestamp


def test_add_session_id_component_to_user_agent_extra():
    session = MagicMock(Session)
    session.user_agent_extra = ''
    orchestrator = MagicMock(CLISessionOrchestrator)
    orchestrator.session_id = 'my-session-id'
    add_session_id_component_to_user_agent_extra(session, orchestrator)
    assert session.user_agent_extra == 'sid/my-session-id'
100644 --- a/tests/unit/test_clidriver.py +++ b/tests/unit/test_clidriver.py @@ -274,6 +274,12 @@ def _run_main(self, args, parsed_globals): return 0 +class FakeCLISessionOrchestrator: + @property + def session_id(self): + return 'mysessionid' + + class TestCliDriver: def setup_method(self): self.session = FakeSession() @@ -774,13 +780,19 @@ def test_idempotency_token_is_not_required_in_help_text(self): self.assertEqual(rc, 252) self.assertNotIn('--idempotency-token', self.stderr.getvalue()) + @mock.patch( + 'awscli.telemetry._get_cli_session_orchestrator', + return_value=FakeCLISessionOrchestrator(), + ) @mock.patch('awscli.clidriver.platform.system', return_value='Linux') @mock.patch('awscli.clidriver.platform.machine', return_value='x86_64') @mock.patch('awscli.clidriver.distro.id', return_value='amzn') @mock.patch('awscli.clidriver.distro.major_version', return_value='1') def test_user_agent_for_linux(self, *args): driver = create_clidriver() - expected_user_agent = 'md/installer#source md/distrib#amzn.1' + expected_user_agent = ( + 'md/installer#source md/distrib#amzn.1 sid/mysessionid' + ) self.assertEqual(expected_user_agent, driver.session.user_agent_extra) def test_user_agent(self, *args):