diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..a2a06fe --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,25 @@ +name: deploy + +on: + workflow_run: + workflows: ["Tests"] + branches: [main] + types: + - completed + +permissions: + contents: write + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v1 + with: + version: "latest" + + - name: Deploy with MkDocs + run: uv run mkdocs gh-deploy --force --config-file docs/mkdocs.yml \ No newline at end of file diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..eb65a89 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,48 @@ +name: Publish + +on: + release: + types: [published] + +jobs: + publish: + permissions: + id-token: write + contents: read + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Get current version from pyproject.toml + id: get_version + run: | + echo "VERSION=$(grep -m 1 'version =' "pyproject.toml" | awk -F'"' '{print $2}')" >> $GITHUB_ENV + + - name: Extract tag version + id: extract_tag + run: | + TAG_VERSION=$(echo "${{ github.event.release.tag_name }}" | sed -E 's/v(.*)/\1/') + echo "TAG_VERSION=$TAG_VERSION" >> $GITHUB_ENV + + - name: Check if tag matches version from pyproject.toml + id: check_tag + run: | + if [ "${{ env.TAG_VERSION }}" != "${{ env.VERSION }}" ]; then + echo "::error::Tag version (${{ env.TAG_VERSION }}) does not match version in pyproject.toml (${{ env.VERSION }})." 
+ exit 1 + fi + + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v1 + with: + version: "latest" + + - name: Build Package + run: uv build + + - name: Upload to GitHub Release + run: | + gh release upload ${{ github.event.release.tag_name }} dist/* + + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml new file mode 100644 index 0000000..776cb6e --- /dev/null +++ b/.github/workflows/python-app.yml @@ -0,0 +1,35 @@ +name: Tests + +on: + push: + branches: '*' + pull_request: + branches: '*' + +jobs: + tests: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v1 + with: + version: "latest" + + - name: Test with Python ${{ matrix.python-version }} + run: | + uv run --python ${{ matrix.python-version }} pytest tests + + - name: Ruff + run: | + uvx ruff check + uvx ruff format --check + + - name: Test build + run: uv build \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..575d8d6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,144 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +**/.ruff_cache +lcov.info + + +# C extensions +*.so + +# Hermit cache +**/.hermit/ + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +*.ipynb +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Datafiles +*.csv +*.gz +*.h5 +*.pkl +*.pk +*.db +*.db-journal + +# Configuration +.vscode/ +.idea/ +.workstations/configs/test-workstation.yml + +# ignore integration tests +test_cli.py + +# ignore lockfile +uv.lock \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..205fc90 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,13 @@ +Copyright 2024 Block, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..b5b5e22 --- /dev/null +++ b/README.md @@ -0,0 +1,6 @@ +### Cloud-Workstations +

+ +

+ +Please see the docs [here](https://square.github.io/cloud-workstation/). diff --git a/bin/.just-1.35.0.pkg b/bin/.just-1.35.0.pkg new file mode 120000 index 0000000..383f451 --- /dev/null +++ b/bin/.just-1.35.0.pkg @@ -0,0 +1 @@ +hermit \ No newline at end of file diff --git a/bin/.ruff-0.6.4.pkg b/bin/.ruff-0.6.4.pkg new file mode 120000 index 0000000..383f451 --- /dev/null +++ b/bin/.ruff-0.6.4.pkg @@ -0,0 +1 @@ +hermit \ No newline at end of file diff --git a/bin/.uv-0.4.9.pkg b/bin/.uv-0.4.9.pkg new file mode 120000 index 0000000..383f451 --- /dev/null +++ b/bin/.uv-0.4.9.pkg @@ -0,0 +1 @@ +hermit \ No newline at end of file diff --git a/bin/README.hermit.md b/bin/README.hermit.md new file mode 100644 index 0000000..e889550 --- /dev/null +++ b/bin/README.hermit.md @@ -0,0 +1,7 @@ +# Hermit environment + +This is a [Hermit](https://github.com/cashapp/hermit) bin directory. + +The symlinks in this directory are managed by Hermit and will automatically +download and install Hermit itself as well as packages. These packages are +local to this environment. diff --git a/bin/activate-hermit b/bin/activate-hermit new file mode 100755 index 0000000..fe28214 --- /dev/null +++ b/bin/activate-hermit @@ -0,0 +1,21 @@ +#!/bin/bash +# This file must be used with "source bin/activate-hermit" from bash or zsh. 
+# You cannot run it directly +# +# THIS FILE IS GENERATED; DO NOT MODIFY + +if [ "${BASH_SOURCE-}" = "$0" ]; then + echo "You must source this script: \$ source $0" >&2 + exit 33 +fi + +BIN_DIR="$(dirname "${BASH_SOURCE[0]:-${(%):-%x}}")" +if "${BIN_DIR}/hermit" noop > /dev/null; then + eval "$("${BIN_DIR}/hermit" activate "${BIN_DIR}/..")" + + if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ]; then + hash -r 2>/dev/null + fi + + echo "Hermit environment $("${HERMIT_ENV}"/bin/hermit env HERMIT_ENV) activated" +fi diff --git a/bin/hermit b/bin/hermit new file mode 100755 index 0000000..7fef769 --- /dev/null +++ b/bin/hermit @@ -0,0 +1,43 @@ +#!/bin/bash +# +# THIS FILE IS GENERATED; DO NOT MODIFY + +set -eo pipefail + +export HERMIT_USER_HOME=~ + +if [ -z "${HERMIT_STATE_DIR}" ]; then + case "$(uname -s)" in + Darwin) + export HERMIT_STATE_DIR="${HERMIT_USER_HOME}/Library/Caches/hermit" + ;; + Linux) + export HERMIT_STATE_DIR="${XDG_CACHE_HOME:-${HERMIT_USER_HOME}/.cache}/hermit" + ;; + esac +fi + +export HERMIT_DIST_URL="${HERMIT_DIST_URL:-https://github.com/cashapp/hermit/releases/download/stable}" +HERMIT_CHANNEL="$(basename "${HERMIT_DIST_URL}")" +export HERMIT_CHANNEL +export HERMIT_EXE=${HERMIT_EXE:-${HERMIT_STATE_DIR}/pkg/hermit@${HERMIT_CHANNEL}/hermit} + +if [ ! 
-x "${HERMIT_EXE}" ]; then + echo "Bootstrapping ${HERMIT_EXE} from ${HERMIT_DIST_URL}" 1>&2 + INSTALL_SCRIPT="$(mktemp)" + # This value must match that of the install script + INSTALL_SCRIPT_SHA256="180e997dd837f839a3072a5e2f558619b6d12555cd5452d3ab19d87720704e38" + if [ "${INSTALL_SCRIPT_SHA256}" = "BYPASS" ]; then + curl -fsSL "${HERMIT_DIST_URL}/install.sh" -o "${INSTALL_SCRIPT}" + else + # Install script is versioned by its sha256sum value + curl -fsSL "${HERMIT_DIST_URL}/install-${INSTALL_SCRIPT_SHA256}.sh" -o "${INSTALL_SCRIPT}" + # Verify install script's sha256sum + openssl dgst -sha256 "${INSTALL_SCRIPT}" | \ + awk -v EXPECTED="$INSTALL_SCRIPT_SHA256" \ + '$2!=EXPECTED {print "Install script sha256 " $2 " does not match " EXPECTED; exit 1}' + fi + /bin/bash "${INSTALL_SCRIPT}" 1>&2 +fi + +exec "${HERMIT_EXE}" --level=fatal exec "$0" -- "$@" diff --git a/bin/hermit.hcl b/bin/hermit.hcl new file mode 100644 index 0000000..081cbe8 --- /dev/null +++ b/bin/hermit.hcl @@ -0,0 +1,2 @@ +github-token-auth { +} diff --git a/bin/just b/bin/just new file mode 120000 index 0000000..12088e1 --- /dev/null +++ b/bin/just @@ -0,0 +1 @@ +.just-1.35.0.pkg \ No newline at end of file diff --git a/bin/ruff b/bin/ruff new file mode 120000 index 0000000..3eb7e98 --- /dev/null +++ b/bin/ruff @@ -0,0 +1 @@ +.ruff-0.6.4.pkg \ No newline at end of file diff --git a/bin/uv b/bin/uv new file mode 120000 index 0000000..0de9b70 --- /dev/null +++ b/bin/uv @@ -0,0 +1 @@ +.uv-0.4.9.pkg \ No newline at end of file diff --git a/bin/uvx b/bin/uvx new file mode 120000 index 0000000..0de9b70 --- /dev/null +++ b/bin/uvx @@ -0,0 +1 @@ +.uv-0.4.9.pkg \ No newline at end of file diff --git a/docs/docs/api.md b/docs/docs/api.md new file mode 100644 index 0000000..aa9be02 --- /dev/null +++ b/docs/docs/api.md @@ -0,0 +1,5 @@ +# API +::: workstation + options: + show_submodules: true + docstring_style: numpy \ No newline at end of file diff --git a/docs/docs/cli.md b/docs/docs/cli.md new file mode 
100644 index 0000000..7026335 --- /dev/null +++ b/docs/docs/cli.md @@ -0,0 +1,10 @@ +# Command Line Usage + +This page provides usage for workstation CLI. + +::: mkdocs-click + :module: workstation.cli + :command: cli + :prog_name: workstation + :style: table + :depth: 1 diff --git a/docs/docs/css/app.css b/docs/docs/css/app.css new file mode 100644 index 0000000..8d3f229 --- /dev/null +++ b/docs/docs/css/app.css @@ -0,0 +1,87 @@ +:root, [data-md-color-scheme="default"] { + --md-primary-fg-color: #00D64F; +} + +[data-md-color-scheme="slate"] { + --md-primary-fg-color: #15c13e; + --md-default-bg-color: #121212; + --md-code-bg-color: hsla(var(--md-hue),25%,25%,1); +} + +@font-face { + font-family: cash-market; + src: url("https://cash-f.squarecdn.com/static/fonts/cash-market/v2/CashMarket-Regular.woff2") format("woff2"); + font-weight: 400; + font-style: normal +} + +@font-face { + font-family: cash-market; + src: url("https://cash-f.squarecdn.com/static/fonts/cash-market/v2/CashMarket-Medium.woff2") format("woff2"); + font-weight: 500; + font-style: normal +} + +@font-face { + font-family: cash-market; + src: url("https://cash-f.squarecdn.com/static/fonts/cash-market/v2/CashMarket-Bold.woff2") format("woff2"); + font-weight: 700; + font-style: normal +} + +/* Use Cash fonts. */ +body, input { + font-family: cash-market,"Helvetica Neue",helvetica,sans-serif; +} + +/* The material theme uses lighter weights for h1-h4 by default. Use bolder weights instead. */ +.md-typeset h1, .md-typeset h2, .md-typeset h3, .md-typeset h4 { + line-height: normal; + font-weight: bold; +} + +/* The material theme uses a lighter colour for h1 by default. Use the fg colour instead. */ +.md-typeset h1 { + color: var(--md-default-fg-color); +} + +/* Make links really look like links. They need to be bolded because the accent colour isn't + very readable with a lighter weight on white. 
*/ +.md-typeset a { + text-decoration: underline; + font-weight: bold; +} + +/* Remove highlights from search results. */ +.md-typeset mark { + background-color: transparent; +} + +/* Header links need to be bolded to be readable. */ +.md-header__title { + font-weight: bold; +} +.md-tabs__link { + font-weight: bold; +} + +button.dl { + font-weight: 300; + font-size: 25px; + line-height: 40px; + padding: 3px 10px; + display: inline-block; + border-radius: 6px; + margin: 5px 0; + width: auto; +} + +.logo { + text-align: center; + margin-top: 150px; +} + +/* Make admonitions maintain normal font size so they are readable */ +.md-typeset .admonition { + font-size: inherit; +} diff --git a/docs/docs/img/cashapp.png b/docs/docs/img/cashapp.png new file mode 100644 index 0000000..d9f6e25 Binary files /dev/null and b/docs/docs/img/cashapp.png differ diff --git a/docs/docs/img/favicon.ico b/docs/docs/img/favicon.ico new file mode 100644 index 0000000..4d23597 Binary files /dev/null and b/docs/docs/img/favicon.ico differ diff --git a/docs/docs/img/gui-ssh.png b/docs/docs/img/gui-ssh.png new file mode 100644 index 0000000..4109212 Binary files /dev/null and b/docs/docs/img/gui-ssh.png differ diff --git a/docs/docs/img/install.png b/docs/docs/img/install.png new file mode 100644 index 0000000..602765c Binary files /dev/null and b/docs/docs/img/install.png differ diff --git a/docs/docs/img/ssh-remote1.png b/docs/docs/img/ssh-remote1.png new file mode 100644 index 0000000..31270ba Binary files /dev/null and b/docs/docs/img/ssh-remote1.png differ diff --git a/docs/docs/img/ssh-remote2.png b/docs/docs/img/ssh-remote2.png new file mode 100644 index 0000000..835f74a Binary files /dev/null and b/docs/docs/img/ssh-remote2.png differ diff --git a/docs/docs/img/ssh-remote3.png b/docs/docs/img/ssh-remote3.png new file mode 100644 index 0000000..3b419d5 Binary files /dev/null and b/docs/docs/img/ssh-remote3.png differ diff --git a/docs/docs/index.md b/docs/docs/index.md new file mode 
100644 index 0000000..4c8ee8d --- /dev/null +++ b/docs/docs/index.md @@ -0,0 +1,206 @@ +# Cloud Workstations +Cloud Workstations are remote development machines hosted in Google cloud. Workstations provide a full-featured development experience and allow users to connect a remote machine to VSCode. Workstations provide a development experience similar to your laptop with some key differences. + +Workstations... + +* run inside your VPC or cloud network so you can access internal sources +* are ephemeral and shut down after a maximum runtime +* run a linux OS and allow you to attach accelerators or change the machine type +* provide preconfigured environments freeing users from having to worry about their local setup + + + +## More details: +### Authentication + +* Users authenticate to a workstation via IAM using their LDAP. This means that running `gcloud auth login`, etc. is necessary to access google services from the workstation. Like on your laptop, this cached credential is valid for 12 hours and persists between workstation sessions. +* Compared to VertexAI notebooks and CustomJobs, which are typically run as a service account, workstation sessions have more narrowly scoped access. By default, only the creator of a workstation has access to that workstation instance. + + +### Configuration + +* Like your laptop, workstations will go to sleep/shutdown after a period of inactivity. Additionally, there is a maximum runtime before it is shut off. These values are set in the workstation definition (config), and we have started with some defaults. This activity is measured via SSH or HTTP traffic. + Furthermore, anything in `/home/` will persist between sessions. + + +## IDE + +* There are two ways to use the IDE: + * via the web browser, which does not require an SSH connection, ensuring low latency + * by attaching your IDE to a remote workstation via the remote SSH extension +* Google has some examples of the open source VSCode and Jetbrains. 
+ + + +## Setup +Install the workstation CLI using pip or (recommended) pipx + +```shell +pipx install cloud-workstations +``` + +## Usage + +### Determine the type of workstation you want to create +Using the command `workstation list-configs` you can see all configs that are available to you. New workstation configs can be added by an admin. If you have questions about which config to use or have a use case for a new config, please reach out to #machine-learning-tools. + +```shell +workstation list-configs +``` +For example you might get a result like this: +``` +├── Config: medium +│ ├── 💽 Image: us-central1-docker.pkg.dev/cloud-workstations-images/predefined/code-oss:latest +│ ├── 💻 Machine Type: n2-standard-16 +│ ├── ⏳ Idle Timeout (hours): 2.0 +│ └── ⏳ Max Runtime (hours): 12.0 +``` + +Which indicates that you have the config `medium` available, and that this will give you a `n2-standard-16` machine with a 2 hour idle timeout and a 12 hour max runtime. + + +### Create your workstation and start it +```shell +❯ workstation create --config medium --name +❯ workstation start --name +``` +If you make a VSCode based IDE image you can also use the `--browser` option, which will open VSCode directly in your browser. If you'd like to connect your remote machine to your Desktop VSCode app follow the setup directions for local [VS Code][connect-to-a-workstation-with-local-vs-code] and then start your workstation with the `--code` option. + +```shell +❯ workstation start --name --code +``` + +You can also ssh directly to your remote machine. For example, running commands like `ssh ` in a local terminal will open an SSH connection to your remote machine. + + +You can use either the `workstation` CLI or the browser to list, start, and stop workstations. You can view active workstations in the browser at https://console.cloud.google.com/workstations/ + +![ssh-gui](img/gui-ssh.png) + + + +### Connect to a Workstation with local VS Code +1. 
To enable local VS Code to connect to a Workstation, you will need to install the Remote - SSH extension. +2. Set up your `~/.ssh/config` file to include the following: + ``` + Include ~/.workstations/configs/*.config + ``` +3. Once installed, you can connect to a Workstation by clicking on the green icon in the bottom left corner of VS Code and selecting "Remote-SSH: Connect to Host...". You can then select the workstation host from the list. +![workstation option](img/ssh-remote1.png) +![workstation IDE](img/ssh-remote2.png) +4. Install your extensions on the remote host. VS Code doesn't install some like copilot or Python by default on the remote host. But click on the cloud icon and select which extensions you want to install. +![extensions](img/ssh-remote3.png) + + +### List your workstations +You can go to https://console.cloud.google.com/workstations/ to see your workstations and status or use the ` workstation list --user ` command + +```shell +❯ workstation list --user damien +Workstations +├── Workstation: damien-medium + ├── 🛑 Stopped + ├── User: damien + ├── 💽 Image: us-central1-docker.pkg.dev/cloud-workstations-images/predefined/code-oss:latest + ├── 💻 Machine Type: n2-standard-16 + ├── ⏳ Idle Timeout (hours): 2.0 + └── ⏳ Max Runtime (hours): 12.0 +Total Workstations: 1 +``` + +Json output is also supported with the `--json` flag. 
+```shell +❯ workstation list --json | jq '.[] | select(.user == "damien" and .state == "STATE_STOPPED")' +``` +Which gives you the output: +```json +{ + "name": "damien-medium", + "user": "damien", + "project": "example", + "location": "us-central1", + "config": "vscode-medium", + "cluster": "example", + "state": "STATE_STOPPED", + "idle_timeout": 2.0, + "max_runtime": 12.0, + "type": "n2-standard-16", + "image": "us-central1-docker.pkg.dev/cloud-workstations-images/predefined/code-oss:latest" +} +``` + +### Idle vs Max Runtime +Idle Timeout is the amount of time the workstation will wait before shutting down if there is no activity based on the config (in the previous config example that is 2 hours). This is measured via SSH or HTTP traffic. Max Runtime is the maximum amount of time the workstation will run before shutting down regardless of activity. It is possible to create configs that never idle, or have no max runtime. + + +### Syncing files to the workstation +There are a few ways to sync files to the workstation. A built-in way is to use the `workstation sync` command. This command will sync the files in `~/remote-machines/workstation/` to the workstation. This is useful for syncing your keys and other files that you might like for a customized experience like a custom ~/.zshrc. However there are drag and drop methods, or just doing it directly from the terminal with ssh or rsync. + +1. Option 1: Use `workstation sync --name ` to sync the keys and anything else in `~/remote-machines/workstation/` to your home directory on the workstation, with additional options for different directories. +2. Option 2: Turn on the workstation and drag and drop the files into the file explorer in VS Code. +3. Option 3: Use scp or rsync if you set up your ~/.ssh/config as above and use the commands like normal. 
For example: + + ```shell + rsync -avz --exclude=".*" --exclude "*/__pycache__" /path/to/local/folder :/path/to/remote/folder + ``` + +#### Example of using workstation sync to sync credentials +You can sync files from your laptop to the workstation using the `workstation sync` command. This command will sync the files in your local `~/remote-machines/workstation/` directory to the workstation. This is useful for syncing your keys and other files that you might like for a customized experience like a custom ~/.zshrc. However there are drag and drop methods, or just doing it directly from the terminal. + +For example, if you have followed the directions on [Accessing GitHub](#accessing-github) and saved the Github certificate in `~/remote-machines/workstation/` the following command will sync that cert to the remote machine and enable github access. Syncing credentials or configuration files only needs to happen once after creating the workstation instance since these will persist on the disk. + +```shell +❯ workstation sync --name damien-medium +building file list ... done +./ +.p10k.zsh +.zsh_plugins.txt +.zshrc +.ssh/ +.ssh/id_ed25519 +.ssh/id_ed25519.pub +test/ + +sent 95505 bytes received 182 bytes 63791.33 bytes/sec +total size is 94956 speedup is 0.99 +``` + + + + +## Report issues +Use the issue link above to take you to the repo issue [page](https://github.com/square/workstations/issues). Please include the version of the CLI you are using `❯ workstation --version` and which config. + + + +## Development and Testing + +We use [Hermit](https://cashapp.github.io/hermit/) to manage the development environment. To get started, install Hermit and run the following command + +```shell +curl -fsSL https://github.com/cashapp/hermit/releases/download/stable/install.sh | /bin/bash +``` + +After cloning the repo do `. bin/activate-hermit` which will make the environment available with UV, Ruff, and Just. 
+ +### Render Docs +`just docs` + +### Run Tests +`just tests` + +### Run local copy +`uv run workstation --version` +### Using VSCode Debugger + +If you want to use VS Code debug mode use this profile. +```json + { + "name": "Python: Module", + "type": "debugpy", + "request": "launch", + "module": "workstation.cli", + "console": "integratedTerminal", + "args": ["list-configs", "--cluster", ""] + } +``` \ No newline at end of file diff --git a/docs/docs/js/all-pages.js b/docs/docs/js/all-pages.js new file mode 100644 index 0000000..e9934d8 --- /dev/null +++ b/docs/docs/js/all-pages.js @@ -0,0 +1,12 @@ +//open external links in a new window +function external_new_window() { + for(var c = document.getElementsByTagName("a"), a = 0;a < c.length;a++) { + var b = c[a]; + if(b.getAttribute("href") && b.hostname !== location.hostname) { + b.target = "_blank"; + b.rel = "noopener"; + } + } +} + +external_new_window(); diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml new file mode 100644 index 0000000..fc66761 --- /dev/null +++ b/docs/mkdocs.yml @@ -0,0 +1,120 @@ +site_name: Workstation +site_description: Documentation for Workstation +repo_url: https://github.com/square/workstations +repo_name: "square/workstations" +edit_uri: "https://github.com/square/workstations/blob/main/docs/docs/" +strict: false +#test +#################################################################################################### +# Glad you're here. Some house rules: +# - The top-level tabs should stay lean. If you're adding more, get someone to peer review. +# - Maintain alphabetical ordering. Each top-level section has the following in this order: +# - A welcome page +# - Second-level sections +# - Standalone pages +# - Feel free to add a page to multiple top-level sections, if it's appropriate, but please try to keep a maximum of 3 levels. +# - If you are moving a page's URL (i.e. its location in the repo), add a redirect. There's a place +# list of redirects below. 
+# - Suffix external links in nav with ↗. +#################################################################################################### + +nav: + - Home: + - index.md + - Issues: + - https://github.com/square/workstations/issues + - Developer Docs: api.md + - Command Line Usage: + - cli.md + +theme: + name: material + favicon: img/favicon.ico + logo: img/cashapp.png + features: + - search.highlight + - search.suggest + - navigation.sections + - navigation.tabs + - navigation.tabs.sticky + - navigation.top + # - navigation.expand + - content.tabs.link + - navigation.indexes + palette: + - media: "(prefers-color-scheme: light)" + scheme: default + accent: green + toggle: + icon: material/eye-outline + name: Switch to dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + accent: green + toggle: + icon: material/eye + name: Switch to light mode + static_templates: + - 404.html +extra_css: + - "css/app.css" +extra_javascript: + - "js/all-pages.js" +extra: + repo_icon: bitbucket + search: + tokenizer: '[\s\-\.]+' + prebuild_index: true + analytics: + provider: google + property: "UA-163700149-1" +markdown_extensions: + mkdocs-click: + pymdownx.snippets: + auto_append: + - includes/abbreviations.md + abbr: + footnotes: + admonition: + attr_list: + codehilite: + guess_lang: false + def_list: + markdown_include.include: + md_in_html: + meta: + pymdownx.betterem: + smart_enable: all + pymdownx.caret: + pymdownx.inlinehilite: + pymdownx.magiclink: + repo_url_shortener: true + repo_url_shorthand: true + social_url_shorthand: true + social_url_shortener: true + user: squareup + normalize_issue_symbols: true + pymdownx.smartsymbols: + pymdownx.superfences: + pymdownx.details: + pymdownx.critic: + pymdownx.tabbed: + alternate_style: true + smarty: + tables: + pymdownx.tasklist: + clickable_checkbox: true + custom_checkbox: true + toc: + permalink: true +plugins: + autorefs: + tags: + mkdocstrings: + handlers: + python: + paths: ["../src/workstation",] + 
search: + lang: en + redirects: + redirect_maps: diff --git a/justfile b/justfile new file mode 100644 index 0000000..ccdf378 --- /dev/null +++ b/justfile @@ -0,0 +1,13 @@ +# This is the default recipe when no arguments are provided +[private] +default: + @just --list --unsorted + +docs: + #!/usr/bin/env zsh + cd docs + uv run mkdocs serve + +tests: + #!/usr/bin/env zsh + uv run pytest --cov=workstation tests diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..75b2dab --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,76 @@ +[project] +name = "cloud-workstation" +version = "0.1.0" +requires-python = ">=3.8" +readme = "README.md" +description = "" +authors = [ + { name = "Damien Ramunno-Johnson", email = "damien@squareup.com" } # Use this object format for authors +] + + +dependencies = [ + "google-cloud-workstations>=0.5.6,<1.0.0", + "click>=8.1.7,<9.0.0", + "pyyaml>=6.0.1,<7.0.0", + "wheel>=0.43.0,<0.44.0", + "rich>=13.7.1,<14.0.0", + "grpcio==1.64.1", # Pin grpcio to version 1.64.1 as 1.65.0+ have issues https://github.com/grpc/grpc/issues/37178 + "google-cloud-logging>=3.10.0,<4.0.0" +] + +[tool.uv] +dev-dependencies = [ + "pytest>=8.2.0,<9.0.0", + "mypy>=1.10.0,<2.0.0", + "mkdocstrings-python>=1.10.0,<2.0.0", + "mkdocs-section-index>=0.3.9,<0.4.0", + "mkdocs>=1.6.0,<2.0.0", + "mkdocs-material>=9.5.20,<10.0.0", + "markdown-include>=0.8.1,<0.9.0", + "mkdocs-gen-files>=0.5.0,<0.6.0", + "mkdocs-literate-nav>=0.6.1,<0.7.0", + "pymdown-extensions>=10.8.1,<11.0.0", + "mkdocs-include-markdown-plugin>=6.0.5,<7.0.0", + "mkdocs-redirects>=1.2.1,<2.0.0", + "mkdocs-click>=0.8.1,<0.9.0", + "mkdocs-autorefs>=1.0.1,<2.0.0", + "pytest-xdist>=3.6.1,<4.0.0", + "types-pyyaml>=6.0.12.20240311,<7.0.0", + "pytest-cov", + "pytest-mock", +] + +[project.scripts] +workstation = "workstation.cli:cli" + +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.ruff.format] +docstring-code-format = true + 
+[tool.ruff.lint] +select = ["I", "D"] +ignore = ["D104", "D100"] + +[tool.pytest.ini_options] +addopts = "-n 4" +markers = ["integration: mark a test as an integration test."] +filterwarnings = ["ignore::DeprecationWarning"] + +[tool.ruff.lint.isort] +force-sort-within-sections = false + +[tool.ruff.lint.pydocstyle] +convention = "numpy" + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["D103"] + + +[tool.coverage.run] +omit = [ + "src/workstation/cli/__main__.py" +] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..df0951c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,25 @@ +google-cloud-workstations>=0.5.6,<1.0.0 +click>=8.1.7,<9.0.0 +pyyaml>=6.0.1,<7.0.0 +wheel>=0.43.0,<0.44.0 +rich>=13.7.1,<14.0.0 +grpcio==1.64.1 +google-cloud-logging>=3.10.0,<4.0.0 +pytest>=8.2.0,<9.0.0 +mypy>=1.10.0,<2.0.0 +mkdocstrings-python>=1.10.0,<2.0.0 +mkdocs-section-index>=0.3.9,<0.4.0 +mkdocs>=1.6.0,<2.0.0 +mkdocs-material>=9.5.20,<10.0.0 +markdown-include>=0.8.1,<0.9.0 +mkdocs-gen-files>=0.5.0,<0.6.0 +mkdocs-literate-nav>=0.6.1,<0.7.0 +pymdown-extensions>=10.8.1,<11.0.0 +mkdocs-include-markdown-plugin>=6.0.5,<7.0.0 +mkdocs-redirects>=1.2.1,<2.0.0 +mkdocs-click>=0.8.1,<0.9.0 +mkdocs-autorefs>=1.0.1,<2.0.0 +pytest-xdist>=3.6.1,<4.0.0 +types-pyyaml>=6.0.12.20240311,<7.0.0 +pytest-cov +pytest-mock diff --git a/src/workstation/__init__.py b/src/workstation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/workstation/cli/__init__.py b/src/workstation/cli/__init__.py new file mode 100644 index 0000000..cbad9ee --- /dev/null +++ b/src/workstation/cli/__init__.py @@ -0,0 +1,35 @@ +import click + +from .crud import create, delete, list, list_configs, logs, start, stop, sync + +try: + from block.clitools.clock import group as base_group + + namespace = "mlds" +except ImportError: + from click import group as base_group + + namespace = None + + +def group_wrapper(*args, **kwargs): # noqa: D103 + if namespace: + kwargs["namespace"] = 
namespace + return base_group(*args, **kwargs) + + +@group_wrapper(name="workstation") +@click.version_option(package_name="cloud-workstation") +@click.pass_context +def cli(context: click.Context): + """Create and manage Google Cloud Workstation.""" + + +cli.add_command(create) +cli.add_command(list_configs) +cli.add_command(list) +cli.add_command(start) +cli.add_command(stop) +cli.add_command(delete) +cli.add_command(sync) +cli.add_command(logs) diff --git a/src/workstation/cli/__main__.py b/src/workstation/cli/__main__.py new file mode 100644 index 0000000..2f5c613 --- /dev/null +++ b/src/workstation/cli/__main__.py @@ -0,0 +1,6 @@ +import sys + +from . import cli + +if __name__ == "__main__": + sys.exit(cli()) diff --git a/src/workstation/cli/crud.py b/src/workstation/cli/crud.py new file mode 100644 index 0000000..560a24d --- /dev/null +++ b/src/workstation/cli/crud.py @@ -0,0 +1,507 @@ +""" +crud module provides command-line interface (CLI) commands for managing workstations. + +Include functionalities to create, delete, list, start, and stop workstations, +as well as manage workstation configurations. + +Functions +--------- +get_gcloud_config(project: Optional[str], location: Optional[str]) -> tuple + Retrieve GCP configuration details including project, location, and account. +common_options(func: callable) -> callable + Apply common CLI options to commands. +create(context: click.Context, cluster: Optional[str], config: str, location: Optional[str], name: str, project: Optional[str], proxy: Optional[str], no_proxy: Optional[str], **kwargs) + Create a workstation. +list_configs(context: click.Context, project: Optional[str], location: Optional[str], **kwargs) + List workstation configurations. +list(context: click.Context, project: Optional[str], location: Optional[str], all: bool, user: str, export_json: bool, cluster: Optional[str], **kwargs) + List workstations. 
def get_gcloud_config(project: Optional[str], location: Optional[str]):
    """
    Retrieve GCP configuration details including project, location, and account.

    Any argument that is not provided falls back to the value in the local
    gcloud configuration.

    Parameters
    ----------
    project : str, optional
        GCP project name.
    location : str, optional
        GCP location.

    Returns
    -------
    tuple
        A tuple containing project, location, and account details.

    Raises
    ------
    ValueError
        If the project, location, or account cannot be resolved from the
        arguments or the gcloud configuration.
    """
    config_project, config_location, account = read_gcloud_config()

    # Prefer explicit arguments; fall back to the gcloud config values.
    if project is None:
        project = config_project
    if project is None:
        raise ValueError("Project not found in gcloud config and was not passed in.")

    if location is None:
        location = config_location
    if location is None:
        raise ValueError("Location not found in gcloud config and was not passed in.")

    if account is None:
        raise ValueError("Account not found in gcloud config.")

    return project, location, account
+ """ + for option in reversed(_common_options): + func = option(func) + return func + + +@command() +@common_options +@click.option( + "--name", + help="Name of the workstation to create.", + type=str, + metavar="", +) +@click.option( + "--proxy", + help="proxy setting.", + type=str, + metavar="", +) +@click.option( + "--no-proxy", + help="No proxy setting.", + type=str, + metavar="", +) +@click.pass_context +def create( + context: click.Context, + cluster: Optional[str], + config: str, + location: Optional[str], + name: str, + project: Optional[str], + proxy: Optional[str], + no_proxy: Optional[str], + **kwargs, +): + """Create a workstation.""" + # Make sure the user is authenticated + check_gcloud_auth() + + project, location, account = get_gcloud_config(project=project, location=location) + + # Ensure USER is set on laptop + user = getpass.getuser() + + try: + from block.mlds.proxy.block import Proxy + + proxies = Proxy(project=project, name=name) + proxy = proxies.proxy + no_proxy = proxies.no_proxy + except ImportError: + pass + + if config_manager.check_if_config_exists(name): + console.print(f"Workstation config for {name} already exists.") + overwrite = Confirm.ask("Overwrite config?") + if not overwrite: + console.print(f"Exiting without creating workstation {name}.") + sys.exit(0) + + _ = create_workstation( + cluster=cluster, + config=config, + name=name, + user=user, + account=account, + project=project, + location=location, + proxy=proxy, + no_proxy=no_proxy, + ) + + config_manager.write_ssh_config( + name=name, + user=user, + cluster=cluster, + region=location, + project=project, + config=config, + ) + + console.print(f"Workstation {name} created.") + + +@command() +@common_options +@click.pass_context +def list_configs( + context: click.Context, + project: Optional[str], + location: Optional[str], + **kwargs, +): + """List workstation configurations.""" + # Make sure the user is authenticated + check_gcloud_auth() + + project, location, account = 
@command()
@common_options
@click.option(
    "--json",
    "export_json",
    default=False,
    is_flag=True,
    help="print json output",
)
@click.option(
    "-u",
    "--user",
    default=getpass.getuser(),
    help="Lists workstations only from a given user.",
)
@click.option(
    "-a", "--all", is_flag=True, default=False, help="List workstations from all users."
)
@click.pass_context
def list(
    context: click.Context,
    project: Optional[str],
    location: Optional[str],
    all: bool,
    user: str,
    export_json: bool,
    cluster: Optional[str],
    **kwargs,
):
    """List workstations."""
    # Make sure the user is authenticated
    check_gcloud_auth()

    project, location, account = get_gcloud_config(project=project, location=location)

    workstations = list_workstations(
        cluster=cluster,
        project=project,
        location=location,
    )

    if not export_json:
        tree = Tree("Workstations", style="bold blue")

        for workstation in workstations:
            if not all and workstation.get("env", {}).get("LDAP") != user:
                continue

            # Map the workstation state to a display label. The final else
            # keeps `status` bound for states such as STATE_STARTING; the
            # original raised UnboundLocalError for any state other than
            # running/stopped.
            if workstation["state"].name == "STATE_RUNNING":
                status = ":play_button: Running"
            elif workstation["state"].name == "STATE_STOPPED":
                status = ":stop_sign: Stopped"
            else:
                status = workstation["state"].name

            config_branch = tree.add(
                f"Workstation: {workstation['name'].split('/')[-1]}"
            )
            config_branch.add(f"{status}", style="white")
            config_branch.add(f"User: {workstation['env']['LDAP']}", style="white")
            config_branch.add(f":minidisc: Image: {workstation['config']['image']}")
            config_branch.add(
                f":computer: Machine Type: {workstation['config']['machine_type']}"
            )
            config_branch.add(
                f":hourglass_flowing_sand: Idle Timeout (s): {str(workstation['config']['idle_timeout'])}"
            )
            config_branch.add(
                f":hourglass_flowing_sand: Max Runtime (s): {str(workstation['config']['max_runtime'])}"
            )

        console.print(tree)
        console.print("Total Workstations: ", len(tree.children))
    else:
        results = []
        for workstation in workstations:
            # Build a plain-dict summary for JSON export. (The original
            # assigned the "user" key twice; the duplicate is removed.)
            result = {}
            result["name"] = workstation["name"].split("/")[-1]
            result["user"] = workstation["env"]["LDAP"]
            result["project"] = workstation["project"]
            result["location"] = workstation["location"]
            result["config"] = workstation["config"]["name"].split("/")[-1]
            result["cluster"] = workstation["cluster"]
            result["state"] = workstation["state"].name
            result["idle_timeout"] = workstation["config"]["idle_timeout"]
            result["max_runtime"] = workstation["config"]["max_runtime"]
            result["type"] = workstation["config"]["machine_type"]
            result["image"] = workstation["config"]["image"]
            results.append(result)

        json_data = json.dumps(results, indent=4)
        console.print(json_data)
@command()
@click.option(
    "--name",
    help="Name of the workstation to stop.",
    type=str,
    metavar="",
)
@click.pass_context
def stop(context: click.Context, **kwargs):
    """Stop workstation."""
    # Refresh gcloud credentials before calling the API.
    check_gcloud_auth()

    # Resolve the saved configuration for this workstation and issue the
    # stop request; print the resulting name/state for the user.
    details = config_manager.read_configuration(kwargs["name"])
    result = stop_workstation(**details)
    console.print(result.name, result.state)
@command()
@click.argument(
    "name",
    type=str,
)
@click.option(
    "--project",
    help="Name of the workstation GCP project.",
    type=str,
    metavar="",
)
def logs(name: str, project: str, **kwargs):
    """Open logs for the workstation.

    Parameters
    ----------
    name : str
        Name of the workstation whose logs should be opened.
    project : str
        Name of the workstation GCP project.
    """
    # Make sure the user is authenticated
    check_gcloud_auth()
    instances = get_instance_assignment(project=project, name=name)
    instance = instances.get(name, None)
    # Bug fix: test the looked-up *instance*, not the mapping. The original
    # checked `instances is None` (never true after a successful call) and
    # then crashed on `instance.get(...)` when the workstation was missing.
    if instance is None:
        console.print(f"Workstation {name} not found.")
        return
    console.print(f"Logs for instance: {instance.get('instance_name')} opening")
    webbrowser.open(instance.get("logs_url"))
+ """ + + name: str + location: str + cluster: str + config: str + project: str + + def generate_workstation_yml(self) -> Path: + """Generate a YAML configuration file for the workstation. + + Returns + ------- + Path + The path to the generated YAML file. + """ + write_path = Path(".", f"{self.name}.yml") + with open(write_path, "w") as file: + yaml.dump(asdict(self), file, sort_keys=False) + + return write_path + + +class ConfigManager: + """A class to manage Workstation configurations. + + Attributes + ---------- + workstation_data_dir : Path + The directory where workstation data is stored. + workstation_configs : Path + The directory where individual workstation configurations are stored. + + Methods + ------- + check_if_config_exists(name: str) -> bool + Checks if a configuration file with the given name exists. + write_configuration(project: str, name: str, location: str, cluster: str, config: str) -> Path + Writes the configuration to a YAML file and returns the path to it. + read_configuration(name: str) -> dict + Reads the configuration for the given name and returns it as a dictionary. + delete_configuration(name: str) -> None + Deletes the configuration file and its corresponding YAML file for the given name. + write_ssh_config(name: str, user: str, project: str, cluster: str, config: str, region: str) + Writes the SSH configuration for the workstation. + """ + + def __init__(self): + self.workstation_data_dir = Path.home() / ".workstations" + self.workstation_configs = self.workstation_data_dir / "configs" + + def check_if_config_exists(self, name: str) -> bool: + """Check if a configuration file with the given name exists. + + Parameters + ---------- + name : str + The name of the configuration to check. + + Returns + ------- + bool + True if the configuration exists, False otherwise. 
+ """ + return (self.workstation_configs / (name + ".yml")).exists() + + def write_configuration( + self, project: str, name: str, location: str, cluster: str, config: str + ) -> Path: + """Write the configuration to a YAML file. + + Parameters + ---------- + project : str + The project name. + name : str + The name of the workstation. + location : str + The location of the workstation. + cluster : str + The cluster associated with the workstation. + config : str + The specific configuration settings. + + Returns + ------- + Path + The path to the written YAML file. + + Raises + ------ + Exception + If any error occurs during the writing process. + """ + self.workstation_configs.mkdir(parents=True, exist_ok=True) + + current_dir = Path.cwd() + os.chdir(self.workstation_configs) + try: + workstation = WorkstationConfig( + project=project, + name=name, + location=location, + cluster=cluster, + config=config, + ) + + workstation_path = workstation.generate_workstation_yml() + return self.workstation_configs / workstation_path + except Exception as e: + os.chdir(current_dir) + raise e + + def read_configuration(self, name: str) -> dict: + """Read the configuration for the given name. + + Parameters + ---------- + name : str + The name of the configuration to read. + + Returns + ------- + dict + The contents of the configuration file as a dictionary. + + Raises + ------ + FileNotFoundError + If the configuration file does not exist. + KeyError + If required keys are missing from the configuration file. + """ + workstation_config = self.workstation_configs / (name + ".yml") + + if not workstation_config.exists(): + raise FileNotFoundError( + f"Configuration {name} not found, please check if {workstation_config} exists." 
+ ) + + with open(workstation_config, "r") as file: + contents = yaml.safe_load(file) + + # check that project, name, location, cluster, and config are in the file + # For the error say what keys are missing + if not all( + key in contents + for key in ["project", "name", "location", "cluster", "config"] + ): + missing_keys = [ + key + for key in ["project", "name", "location", "cluster", "config"] + if key not in contents + ] + raise KeyError(f"Configuration file {name} is missing keys {missing_keys}") + + return contents + + def delete_configuration(self, name: str) -> None: + """Delete the configuration file and its corresponding YAML file. + + Parameters + ---------- + name : str + The name of the configuration to delete. + + Raises + ------ + FileNotFoundError + If the configuration file does not exist. + """ + workstation_yml = self.workstation_configs / (name + ".yml") + workstation_config = self.workstation_configs / (name + ".config") + + if not workstation_config.exists(): + raise FileNotFoundError(f"Configuration {name} not found") + if not workstation_yml.exists(): + raise FileNotFoundError(f"Configuration {name} not found") + + workstation_config.unlink() + workstation_yml.unlink() + + def write_ssh_config( + self, + name: str, + user: str, + project: str, + cluster: str, + config: str, + region: str, + ): + """Write the SSH configuration for the workstation. + + Parameters + ---------- + name : str + The name of the workstation. + user : str + The user for SSH connection. + project : str + The project name. + cluster : str + The cluster associated with the workstation. + config : str + The specific configuration settings. + region : str + The region where the workstation is deployed. + + Raises + ------ + NoPortFree + If no free port is found after checking 20 ports. 
+ """ + workstation_config = self.workstation_configs / (name + ".config") + + # get all of the ports that are currently in use from the config files + ports = [] + for config_file in self.workstation_configs.glob("*.config"): + with open(config_file, "r") as file: + contents = file.read() + # Check if the match is not None before calling group + match = re.search(r"\n\s*Port\s+(\d+)", contents) + if match is not None: + port = int(match.group(1)) + ports.append(port) + + if len(ports) == 0: + port = 6000 + else: + port = max(ports) + 1 + + for _ in range(20): + if check_socket("localhost", port): + break + port += 1 + else: + raise NoPortFree("Could not find a free port after checking 20 ports.") + + proxy_command = ( + "sh -c '" + "cleanup() { pkill -P $$; }; " + "trap cleanup EXIT; " + "gcloud workstations start-tcp-tunnel " + f"--project={project} " + f"--cluster={cluster} " + f"--config={config} " + f"--region={region} " + "--local-host-port=localhost:%p %h 22 & " + "timeout=10; " + "while ! 
nc -z localhost %p; do " + "sleep 1; " + "timeout=$((timeout - 1)); " + "if [ $timeout -le 0 ]; then " + "exit 1; " + "fi; " + "done; " + "nc localhost %p'" + ) + + config_content = dedent( + f""" + Host {name} + HostName {name} + Port {port} + User {user} + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + ControlMaster auto + ControlPersist 30m + ControlPath ~/.ssh/cm/%r@%h:%p + ProxyCommand {proxy_command} + """ + ).strip() + + with open(workstation_config, "w") as file: + file.write(config_content) diff --git a/src/workstation/core.py b/src/workstation/core.py new file mode 100644 index 0000000..3c9b2f1 --- /dev/null +++ b/src/workstation/core.py @@ -0,0 +1,348 @@ +import sys +from typing import Dict, List, Optional + +from google.api_core.exceptions import AlreadyExists +from google.api_core.operation import Operation +from google.cloud import workstations_v1beta +from google.cloud.workstations_v1beta.types import Workstation +from rich.console import Console + +from workstation.config import ConfigManager +from workstation.machines import machine_types +from workstation.utils import get_logger + +console = Console() +config_manager = ConfigManager() +logger = get_logger() + + +def list_workstation_clusters(project: str, location: str) -> List[Dict]: + """ + List workstation clusters in a specific project and location. + + Parameters + ---------- + project : str + The Google Cloud project ID. + location : str + The Google Cloud location. + + Returns + ------- + List[Dict] + A list of workstation cluster configurations. 
+ """ + client = workstations_v1beta.WorkstationsClient() + + request = workstations_v1beta.ListWorkstationClustersRequest( + parent=f"projects/{project}/locations/{location}", + ) + page_result = client.list_workstation_clusters(request=request) + + configs = [] + for config in page_result: + configs.append( + { + "name": config.name, + "image": config.subnetwork, + } + ) + return configs + + +def list_workstation_configs(project: str, location: str, cluster: str) -> List[Dict]: + """ + List usable workstation configurations in a specific project, location, and cluster. + + Parameters + ---------- + project : str + The Google Cloud project ID. + location : str + The Google Cloud location. + cluster : str + The workstation cluster name. + + Returns + ------- + List[Dict] + A list of usable workstation configurations. + """ + client = workstations_v1beta.WorkstationsClient() + + request = workstations_v1beta.ListUsableWorkstationConfigsRequest( + parent=f"projects/{project}/locations/{location}/workstationClusters/{cluster}", + ) + page_result = client.list_usable_workstation_configs(request=request) + + configs = [] + for config in page_result: + if config.host.gce_instance.machine_type not in machine_types: + logger.debug( + f"{config.host.gce_instance.machine_type} not exist in machine_types in machines.py" + ) + continue + machine_details = machine_types[config.host.gce_instance.machine_type] + machine_specs = f"machine_specs[{machine_details['vCPUs']} vCPUs, {machine_details['Memory (GB)']} GB]" + configs.append( + { + "name": config.name, + "image": config.container.image, + "machine_type": config.host.gce_instance.machine_type, + "idle_timeout": config.idle_timeout.total_seconds(), + "max_runtime": config.running_timeout.total_seconds(), + "machine_specs": machine_specs, + } + ) + return configs + + +def create_workstation( + project: str, + location: str, + cluster: str, + config: str, + name: str, + account: str, + user: str, + proxy: Optional[str] = None, 
+ no_proxy: Optional[str] = None, +) -> Workstation: + """ + Create a new workstation with the specified configuration. + + Parameters + ---------- + project : str + The Google Cloud project ID. + location : str + The Google Cloud location. + cluster : str + The workstation cluster name. + config : str + The workstation configuration name. + name : str + The name of the new workstation. + account : str + The account associated with the workstation. + user : str + The user associated with the workstation. + proxy : Optional[str], optional + Proxy settings, by default None. + no_proxy : Optional[str], optional + No-proxy settings, by default None. + + Returns + ------- + Workstation + Response from the workstation creation request. + """ + client = workstations_v1beta.WorkstationsClient() + env = { + "LDAP": user, + "ACCOUNT": account, + } + + if proxy: + env["http_proxy"] = proxy + env["HTTPS_PROXY"] = proxy + env["https_proxy"] = proxy + env["HTTP_PROXY"] = proxy + env["no_proxy"] = no_proxy + env["NO_PROXY"] = no_proxy + + request = workstations_v1beta.CreateWorkstationRequest( + parent=f"projects/{project}/locations/{location}/workstationClusters/{cluster}/workstationConfigs/{config}", + workstation_id=name, + workstation=Workstation( + display_name=name, + env=env, + ), + ) + + try: + operation = client.create_workstation(request=request) + response = operation.result() + except AlreadyExists: + console.print(f"Workstation [bold blue]{name}[/bold blue] already exists") + sys.exit(1) + + config_manager.write_configuration( + project=project, + name=name, + location=location, + cluster=cluster, + config=config, + ) + + return response + + +def start_workstation( + project: str, + name: str, + location: str, + cluster: str, + config: str, +) -> Operation: + """ + Start an existing workstation. + + Parameters + ---------- + project : str + The Google Cloud project ID. + name : str + The name of the workstation. + location : str + The Google Cloud location. 
def stop_workstation(
    project: str,
    name: str,
    location: str,
    cluster: str,
    config: str,
) -> Operation:
    """
    Stop an existing workstation.

    Parameters
    ----------
    project : str
        The Google Cloud project ID.
    name : str
        The name of the workstation.
    location : str
        The Google Cloud location.
    cluster : str
        The workstation cluster name.
    config : str
        The workstation configuration name.

    Returns
    -------
    Operation
        Response from the workstation stop request.
    """
    # Fully-qualified resource path for the target workstation.
    resource = (
        f"projects/{project}/locations/{location}"
        f"/workstationClusters/{cluster}/workstationConfigs/{config}"
        f"/workstations/{name}"
    )

    client = workstations_v1beta.WorkstationsClient()
    request = workstations_v1beta.StopWorkstationRequest(name=resource)

    operation = client.stop_workstation(request=request)
    console.print("Waiting for operation to complete...")
    return operation.result()
def list_workstations(project: str, location: str, cluster: str) -> List[Dict]:
    """
    List all workstations in a specific project, location, and cluster.

    Parameters
    ----------
    project : str
        The Google Cloud project ID.
    location : str
        The Google Cloud location.
    cluster : str
        The workstation cluster name.

    Returns
    -------
    List[Dict]
        A list of workstation configurations.
    """
    # Workstations are listed per usable configuration, so resolve the
    # configurations first and then fan out one list request per config.
    configs = list_workstation_configs(
        project=project, location=location, cluster=cluster
    )

    client = workstations_v1beta.WorkstationsClient()

    workstations = []
    for cfg in configs:
        request = workstations_v1beta.ListWorkstationsRequest(parent=cfg.get("name"))
        for ws in client.list_workstations(request=request):
            workstations.append(
                {
                    "name": ws.name,
                    "state": ws.state,
                    "env": ws.env,
                    "config": cfg,
                    "project": project,
                    "location": location,
                    "cluster": cluster,
                }
            )
    return workstations
"e2-standard-16": {"Type": "E2", "vCPUs": 16, "Memory (GB)": 64}, + "e2-standard-32": {"Type": "E2", "vCPUs": 32, "Memory (GB)": 128}, + "n1-standard-1": {"Type": "N1", "vCPUs": 1, "Memory (GB)": 3.75}, + "n1-standard-2": {"Type": "N1", "vCPUs": 2, "Memory (GB)": 7.5}, + "n1-standard-4": {"Type": "N1", "vCPUs": 4, "Memory (GB)": 15}, + "n1-standard-8": {"Type": "N1", "vCPUs": 8, "Memory (GB)": 30}, + "n1-standard-16": {"Type": "N1", "vCPUs": 16, "Memory (GB)": 60}, + "n1-standard-32": {"Type": "N1", "vCPUs": 32, "Memory (GB)": 120}, + "n1-standard-64": {"Type": "N1", "vCPUs": 64, "Memory (GB)": 240}, + "n1-standard-96": {"Type": "N1", "vCPUs": 96, "Memory (GB)": 360}, + "n2-standard-2": {"Type": "N2", "vCPUs": 2, "Memory (GB)": 8}, + "n2-standard-4": {"Type": "N2", "vCPUs": 4, "Memory (GB)": 16}, + "n2-standard-8": {"Type": "N2", "vCPUs": 8, "Memory (GB)": 32}, + "n2-standard-16": {"Type": "N2", "vCPUs": 16, "Memory (GB)": 64}, + "n2-standard-32": {"Type": "N2", "vCPUs": 32, "Memory (GB)": 128}, + "n2d-standard-2": {"Type": "N2D", "vCPUs": 2, "Memory (GB)": 8}, + "n2d-standard-4": {"Type": "N2D", "vCPUs": 4, "Memory (GB)": 16}, + "n2d-standard-8": {"Type": "N2D", "vCPUs": 8, "Memory (GB)": 32}, + "n2d-standard-16": {"Type": "N2D", "vCPUs": 16, "Memory (GB)": 64}, + "n2d-standard-32": {"Type": "N2D", "vCPUs": 32, "Memory (GB)": 128}, + "n2d-highmem-2": {"Type": "N2D Highmem", "vCPUs": 2, "Memory (GB)": 16}, + "n2d-highmem-4": {"Type": "N2D Highmem", "vCPUs": 4, "Memory (GB)": 32}, + "n2d-highmem-8": {"Type": "N2D Highmem", "vCPUs": 8, "Memory (GB)": 64}, + "n2d-highmem-16": {"Type": "N2D Highmem", "vCPUs": 16, "Memory (GB)": 128}, + "n2d-highmem-32": {"Type": "N2D Highmem", "vCPUs": 32, "Memory (GB)": 256}, + "n2d-highmem-48": {"Type": "N2D Highmem", "vCPUs": 48, "Memory (GB)": 384}, + "n2d-highmem-64": {"Type": "N2D Highmem", "vCPUs": 64, "Memory (GB)": 512}, + "n2d-highmem-80": {"Type": "N2D Highmem", "vCPUs": 80, "Memory (GB)": 640}, + "n2d-highmem-96": {"Type": 
def default_serializer(obj):
    """
    Serialize object types that ``json`` cannot handle by default.

    Currently this only knows about protobuf ``ScalarMapContainer``-like
    objects, which are converted to a plain dict of their attributes.

    Parameters
    ----------
    obj : Any
        The object to serialize.

    Returns
    -------
    Any
        A JSON-friendly representation (e.g., a dictionary).

    Raises
    ------
    TypeError
        If the object type is not serializable.
    """
    # Detect protobuf scalar-map containers either by their marker
    # attribute or by their fully qualified type name.
    looks_like_scalar_map = hasattr(obj, "MapContainer") or (
        "google._upb._message.ScalarMapContainer" in str(type(obj))
    )
    if looks_like_scalar_map:
        # Drop the internal container attribute; keep the real payload.
        return {
            key: value for key, value in obj.__dict__.items() if key != "MapContainer"
        }
    raise TypeError(f"Type {type(obj)} not serializable")
def config_tree(configs: list) -> Tree:
    """
    Build a Rich tree view of workstation configurations.

    Parameters
    ----------
    configs : list
        A list of workstation configuration dicts with ``name``, ``image``,
        ``machine_type``, ``machine_specs``, ``idle_timeout`` and
        ``max_runtime`` keys.

    Returns
    -------
    Tree
        A Rich Tree object representing the configurations.
    """
    root = Tree("Configs", style="bold blue")

    for cfg in configs:
        # The API returns fully qualified names; show only the last segment.
        short_name = cfg["name"].split("/")[-1]
        branch = root.add(f"Config: {short_name}")
        branch.add(f":minidisc: Image: {cfg['image']}")
        branch.add(f":computer: Machine Type: {cfg['machine_type']}")
        branch.add(f":computer: Machine Specs: {cfg['machine_specs']}")
        branch.add(
            f":hourglass_flowing_sand: Idle Timeout (s): {str(cfg['idle_timeout'])}"
        )
        branch.add(
            f":hourglass_flowing_sand: Max Runtime (s): {str(cfg['max_runtime'])}"
        )

    return root
def sync_files_workstation(
    project: str,
    name: str,
    location: str,
    cluster: str,
    config: str,
    source: str,
    destination: str,
):
    """
    Synchronize files from the local system to the workstation using rsync over an SSH tunnel.

    Parameters
    ----------
    project : str
        The Google Cloud project ID.
    name : str
        The name of the workstation.
    location : str
        The Google Cloud location (region).
    cluster : str
        The workstation cluster name.
    config : str
        The workstation configuration name.
    source : str
        The source directory on the local system.
    destination : str
        The destination directory on the workstation.

    Returns
    -------
    subprocess.CompletedProcess
        The result of the rsync command.

    Raises
    ------
    NoPortFree
        If no free local port is found for the SSH tunnel.
    subprocess.CalledProcessError
        If the gcloud tunnel process exits immediately with a non-zero code.
    """
    # Find a free local port for the tunnel, starting at 61000.
    port = 61000
    for _ in range(20):
        if check_socket("localhost", port):
            break
        port += 1
    else:
        raise NoPortFree("Could not find a free port after checking 20 ports.")

    # Start the gcloud TCP tunnel to the workstation's SSH port (22) in the
    # background; rsync connects through it on the chosen local port.
    process = subprocess.Popen(
        [
            "gcloud",
            "workstations",
            "start-tcp-tunnel",
            f"--project={project}",
            f"--cluster={cluster}",
            f"--config={config}",
            # BUG FIX: this flag was previously passed twice.
            f"--region={location}",
            f"{name}",
            "22",
            f"--local-host-port=:{port}",
        ],
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # If the tunnel process died immediately, surface a proper error.
    if process.poll() is not None and process.returncode != 0:
        # BUG FIX: CalledProcessError requires (returncode, cmd); previously
        # the stderr text alone was passed, which itself raised TypeError.
        raise CalledProcessError(
            process.returncode, process.args, stderr=process.stderr.read()
        )

    # Use rsync to sync files from local to the workstation via the tunnel.
    source_path = os.path.expanduser(source)
    destination_path = f"localhost:{destination}"

    command = [
        "rsync",
        "-av",
        "--exclude=.venv",
        "--exclude=.git",
        "--exclude=.DS_Store",
        "-e",
        f"ssh -p {port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null",
        source_path,
        destination_path,
    ]

    # Wait (up to ~10s) for the tunnel to bind the local port: check_socket
    # returns True while the port is still free, i.e. the tunnel is not up yet.
    counter = 0
    while check_socket("localhost", port):
        if counter >= 10:
            break
        time.sleep(1)
        counter += 1

    result = subprocess.run(command, capture_output=True, text=True)
    process.kill()
    return result
def process_entry(entry, project):
    """
    Extract workstation/instance details from a VM-assignment log entry.

    Parameters
    ----------
    entry
        A log entry object exposing ``resource.labels`` and ``labels``
        mappings.
    project : str
        The Google Cloud project ID.

    Returns
    -------
    Tuple[str, Dict]
        The workstation ID and a dict with the instance name, instance ID,
        and a Cloud Console logs URL filtered to that instance.
    """
    workstation_id = entry.resource.labels.get("workstation_id")
    instance_name = entry.labels.get("instance_name")
    instance_id = entry.labels.get("instance_id")

    # URL-encoded Logs Explorer query filtering on this GCE instance.
    resource_type = "gce_instance"
    query_prefix = (
        "https://console.cloud.google.com/logs/query;query=resource.type%3D%22"
        f"{resource_type}%22%0Aresource.labels.instance_id%3D%22"
    )
    logs_url = f"{query_prefix}{instance_id}%22?project={project}"

    return workstation_id, {
        "instance_name": instance_name,
        "instance_id": instance_id,
        "logs_url": logs_url,
    }
def get_logger():
    """
    Return a module logger whose level comes from the LOG_LEVEL env var.

    Defaults to INFO when LOG_LEVEL is unset. Useful for debugging; the
    value should be one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.
    """
    level_name = os.getenv("LOG_LEVEL", "INFO").upper()
    logger = logging.getLogger(__name__)
    logger.setLevel(level_name)

    # Only attach a handler on first call to avoid duplicate log lines.
    if not logger.handlers:
        console_handler = logging.StreamHandler()  # log to console
        console_handler.setLevel(level_name)
        console_handler.setFormatter(
            logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(console_handler)

    return logger
def test_write_ssh_config_no_existing_configs(temp_workstation_dir):
    """write_ssh_config creates a fresh <name>.config file containing every
    field when no SSH config for the workstation exists yet."""
    manager = temp_workstation_dir

    name = "workstation_no_existing"
    user = "user_no_existing"
    project = "project_no_existing"
    cluster = "cluster_no_existing"
    config = "config_no_existing"
    region = "region_no_existing"

    manager.write_ssh_config(name, user, project, cluster, config, region)

    # The SSH config is written under the manager's config directory.
    config_file_path = manager.workstation_configs / f"{name}.config"
    with open(config_file_path, "r") as f:
        content = f.read()

    # Every input value should appear somewhere in the generated config.
    assert user in content
    assert cluster in content
    assert config in content
    assert region in content
    assert project in content
    assert name in content
def test_delete_configuration(temp_workstation_dir):
    """delete_configuration removes both the YAML configuration and the
    matching SSH config file for the named workstation."""
    manager = temp_workstation_dir

    name = "test_workstation"
    project = "test_project"
    location = "test_location"
    cluster = "test_cluster"
    config = "test_config"

    # Create both artifacts that delete_configuration is expected to remove.
    manager.write_configuration(project, name, location, cluster, config)
    manager.write_ssh_config(name, "test_user", project, cluster, config, location)

    manager.delete_configuration(name)

    yml_file_path = manager.workstation_configs / f"{name}.yml"
    config_file_path = manager.workstation_configs / f"{name}.config"

    # Both files must be gone after deletion.
    assert not yml_file_path.exists()
    assert not config_file_path.exists()
@patch("workstation.cli.crud.list_workstations")
@patch("workstation.cli.crud.check_gcloud_auth")
@patch("workstation.cli.crud.get_gcloud_config")
def test_list(mock_get_gcloud_config, mock_check_gcloud_auth, mock_list_workstations):
    """The `list` CLI command renders a running workstation's name and a
    human-readable state for the filtered user.

    Note: decorators apply bottom-up, so the parameter order mirrors the
    reversed decorator order (get_gcloud_config first).
    """
    runner = CliRunner()
    # Pretend gcloud is configured and authenticated.
    mock_get_gcloud_config.return_value = (
        "test-project",
        "us-central1",
        "test-account",
    )
    mock_check_gcloud_auth.return_value = True
    # state is an enum-like object whose .name the CLI inspects; a plain
    # MagicMock with .name set is enough here.
    workstation_state = MagicMock()
    workstation_state.name = "STATE_RUNNING"
    mock_list_workstations.return_value = [
        {
            "name": "workstation1",
            "state": workstation_state,
            "env": {"LDAP": "test-user"},
            "config": {
                "image": "test-image",
                "machine_type": "n1-standard-4",
                "idle_timeout": 3600,
                "max_runtime": 7200,
            },
        }
    ]

    result = runner.invoke(crud.list, ["--user", "test-user"])
    print(result.output)

    assert result.exit_code == 0
    assert "workstation1" in result.output
    # "\u25b6" is the ▶ glyph the CLI prints next to a running workstation.
    assert "\u25b6 Running" in result.output
def test_get_instance_assignment(mocker: MockerFixture):
    """get_instance_assignment maps a workstation ID from a mocked Cloud
    Logging entry to its instance name, ID, and console logs URL."""
    # Mock the check_gcloud_auth function so no real auth happens.
    mocker.patch("workstation.utils.check_gcloud_auth", return_value=True)
    # Mock the Client and its list_entries method.
    mock_client = mocker.patch("workstation.utils.cloud_logging.Client")
    mock_instance = mock_client.return_value
    entry_mock = mocker.MagicMock()
    entry_mock.resource.labels.get.return_value = "workstation-id"
    # side_effect order matches the two .get calls: name first, then ID.
    entry_mock.labels.get.side_effect = ["instance-name", "instance-id"]
    mock_instance.list_entries.return_value = [entry_mock]

    project = "test-project"
    name = "workstation-id"

    result = get_instance_assignment(project, name)

    expected_result = {
        "workstation-id": {
            "instance_name": "instance-name",
            "instance_id": "instance-id",
            "logs_url": "https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%0Aresource.labels.instance_id%3D%22instance-id%22?project=test-project",
        }
    }

    assert result == expected_result