diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
new file mode 100644
index 0000000000..73f1b9f237
--- /dev/null
+++ b/.devcontainer/Dockerfile
@@ -0,0 +1,27 @@
+# syntax=docker/dockerfile:1
+FROM debian:bookworm-slim
+
+RUN apt-get update && apt-get install -y \
+ libxkbcommon0 \
+ ca-certificates \
+ make \
+ curl \
+ git \
+ unzip \
+ libc++1 \
+ vim \
+ termcap \
+ && apt-get clean autoclean
+
+RUN curl -sSf https://rye-up.com/get | RYE_VERSION="0.15.2" RYE_INSTALL_OPTION="--yes" bash
+ENV PATH=/root/.rye/shims:$PATH
+
+WORKDIR /workspace
+
+COPY README.md .python-version pyproject.toml requirements.lock requirements-dev.lock /workspace/
+
+RUN rye sync --all-features
+
+COPY . /workspace
+
+CMD ["rye", "shell"]
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 0000000000..d55fc4d671
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,20 @@
+// For format details, see https://aka.ms/devcontainer.json. For config options, see the
+// README at: https://github.com/devcontainers/templates/tree/main/src/debian
+{
+ "name": "Debian",
+ "build": {
+ "dockerfile": "Dockerfile"
+ }
+
+ // Features to add to the dev container. More info: https://containers.dev/features.
+ // "features": {},
+
+ // Use 'forwardPorts' to make a list of ports inside the container available locally.
+ // "forwardPorts": [],
+
+ // Configure tool-specific properties.
+ // "customizations": {},
+
+ // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
+ // "remoteUser": "root"
+}
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
deleted file mode 100644
index 300ad9f0ae..0000000000
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ /dev/null
@@ -1,56 +0,0 @@
-name: Bug report
-description: Create a report to help us improve
-labels: ["bug"]
-body:
- - type: markdown
- attributes:
- value: |
- Thanks for taking the time to fill out this bug report! If you have questions about using the OpenAI Python library, please post on our [Community forum](https://community.openai.com).
- - type: textarea
- id: what-happened
- attributes:
- label: Describe the bug
- description: A clear and concise description of what the bug is, and any additional context.
- placeholder: Tell us what you see!
- validations:
- required: true
- - type: textarea
- id: repro-steps
- attributes:
- label: To Reproduce
- description: Steps to reproduce the behavior.
- placeholder: |
- 1. Fetch a '...'
- 2. Update the '....'
- 3. See error
- validations:
- required: true
- - type: textarea
- id: code-snippets
- attributes:
- label: Code snippets
- description: If applicable, add code snippets to help explain your problem.
- render: Python
- validations:
- required: false
- - type: input
- id: os
- attributes:
- label: OS
- placeholder: macOS
- validations:
- required: true
- - type: input
- id: language-version
- attributes:
- label: Python version
- placeholder: Python v3.7.1
- validations:
- required: true
- - type: input
- id: lib-version
- attributes:
- label: Library version
- placeholder: openai-python v0.26.4
- validations:
- required: true
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
deleted file mode 100644
index 5bedf975eb..0000000000
--- a/.github/ISSUE_TEMPLATE/config.yml
+++ /dev/null
@@ -1,7 +0,0 @@
-blank_issues_enabled: false
-contact_links:
- - name: OpenAI support
- url: https://help.openai.com/
- about: |
- Please only file issues here that you believe represent actual bugs or feature requests for the OpenAI Python library.
- If you're having general trouble with the OpenAI API, ChatGPT, etc, please visit our help center to get support.
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
deleted file mode 100644
index 2bd1c635ba..0000000000
--- a/.github/ISSUE_TEMPLATE/feature_request.yml
+++ /dev/null
@@ -1,20 +0,0 @@
-name: Feature request
-description: Suggest an idea for this library
-labels: ["feature-request"]
-body:
- - type: markdown
- attributes:
- value: |
- Thanks for taking the time to fill out this feature request! Please note, we are not able to accommodate all feature requests given limited bandwidth but we appreciate you taking the time to share with us how to improve the OpenAI Python library.
- - type: textarea
- id: feature
- attributes:
- label: Describe the feature or improvement you're requesting
- description: A clear and concise description of what you want to happen.
- validations:
- required: true
- - type: textarea
- id: context
- attributes:
- label: Additional context
- description: Add any other context about the feature request here.
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000000..c031d9a1d1
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,41 @@
+name: CI
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ lint:
+ name: lint
+ runs-on: ubuntu-latest
+ if: github.repository == 'openai/openai-python'
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Install Rye
+ run: |
+ curl -sSf https://rye-up.com/get | bash
+ echo "$HOME/.rye/shims" >> $GITHUB_PATH
+ env:
+ RYE_VERSION: 0.15.2
+ RYE_INSTALL_OPTION: "--yes"
+
+ - name: Install dependencies
+ run: |
+ rye sync --all-features
+
+ - name: Run ruff
+ run: |
+ rye run check:ruff
+
+ - name: Run type checking
+ run: |
+ rye run typecheck
+
+ - name: Ensure importable
+ run: |
+ rye run python -c 'import openai'
diff --git a/.gitignore b/.gitignore
index 7ad641a0c8..a4b2f8c0bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,12 +1,14 @@
-*.egg-info
-.idea
-.python-version
-/public/dist
+.vscode
+_dev
+
__pycache__
-build
-*.egg
-.vscode/settings.json
-.ipynb_checkpoints
-.vscode/launch.json
-examples/azure/training.jsonl
-examples/azure/validation.jsonl
+.mypy_cache
+
+dist
+
+.venv
+.idea
+
+.env
+.envrc
+codegen.log
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000000..43077b2460
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.9.18
diff --git a/.stats.yml b/.stats.yml
new file mode 100644
index 0000000000..f21eb8fef0
--- /dev/null
+++ b/.stats.yml
@@ -0,0 +1 @@
+configured_endpoints: 28
diff --git a/LICENSE b/LICENSE
index 4f14854c32..7b1b36a644 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,201 @@
-The MIT License
-
-Copyright (c) OpenAI (https://openai.com)
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2023 OpenAI
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Makefile b/Makefile
deleted file mode 100644
index b3ef11eea1..0000000000
--- a/Makefile
+++ /dev/null
@@ -1,11 +0,0 @@
-.PHONY: build upload
-
-build:
- rm -rf dist/ build/
- python -m pip install build
- python -m build .
-
-upload:
- python -m pip install twine
- python -m twine upload dist/openai-*
- rm -rf dist
diff --git a/README.md b/README.md
index 615160b3a4..a27375d598 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,16 @@
-# OpenAI Python Library
+# OpenAI Python API library
-The OpenAI Python library provides convenient access to the OpenAI API
-from applications written in the Python language. It includes a
-pre-defined set of classes for API resources that initialize
-themselves dynamically from API responses which makes it compatible
-with a wide range of versions of the OpenAI API.
+[![PyPI version](https://img.shields.io/pypi/v/openai.svg)](https://pypi.org/project/openai/)
-You can find usage examples for the OpenAI Python library in our [API reference](https://platform.openai.com/docs/api-reference?lang=python) and the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/).
+The OpenAI Python library provides convenient access to the OpenAI REST API from any Python 3.7+
+application. The library includes type definitions for all request params and response fields,
+and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx).
+
+It is generated from our [OpenAPI specification](https://github.com/openai/openai-openapi) with [Stainless](https://stainlessapi.com/).
+
+## Documentation
+
+The API documentation can be found [here](https://platform.openai.com/docs).
## Beta Release
@@ -23,255 +27,483 @@ And follow along with the [beta release notes](https://github.com/openai/openai-
## Installation
-To start, ensure you have Python 3.7.1 or newer. If you just
-want to use the package, run:
-
```sh
-pip install --upgrade openai
+pip install --pre openai
```
-After you have installed the package, import it at the top of a file:
+## Usage
+
+The full API of this library can be found in [api.md](https://www.github.com/openai/openai-python/blob/main/api.md).
```python
-import openai
+from openai import OpenAI
+
+client = OpenAI(
+ # defaults to os.environ.get("OPENAI_API_KEY")
+ api_key="My API Key",
+)
+
+chat_completion = client.chat.completions.create(
+ messages=[
+ {
+ "role": "user",
+ "content": "Say this is a test",
+ }
+ ],
+ model="gpt-3.5-turbo",
+)
```
-To install this package from source to make modifications to it, run the following command from the root of the repository:
+While you can provide an `api_key` keyword argument,
+we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/)
+to add `OPENAI_API_KEY="My API Key"` to your `.env` file
+so that your API Key is not stored in source control.
-```sh
-python setup.py install
-```
+## Async usage
-### Optional dependencies
+Simply import `AsyncOpenAI` instead of `OpenAI` and use `await` with each API call:
-Install dependencies for [`openai.embeddings_utils`](openai/embeddings_utils.py):
+```python
+import asyncio
+from openai import AsyncOpenAI
-```sh
-pip install openai[embeddings]
-```
+client = AsyncOpenAI(
+ # defaults to os.environ.get("OPENAI_API_KEY")
+ api_key="My API Key",
+)
-Install support for [Weights & Biases](https://wandb.me/openai-docs) which can be used for fine-tuning:
-```sh
-pip install openai[wandb]
-```
+async def main() -> None:
+ chat_completion = await client.chat.completions.create(
+ messages=[
+ {
+ "role": "user",
+ "content": "Say this is a test",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ )
-Data libraries like `numpy` and `pandas` are not installed by default due to their size. They’re needed for some functionality of this library, but generally not for talking to the API. If you encounter a `MissingDependencyError`, install them with:
-```sh
-pip install openai[datalib]
+asyncio.run(main())
```
-## Usage
+Functionality between the synchronous and asynchronous clients is otherwise identical.
+
+## Streaming Responses
+
+We provide support for streaming responses using Server Side Events (SSE).
-The library needs to be configured with your OpenAI account's private API key which is available on our [developer platform](https://platform.openai.com/account/api-keys). Either set it as the `OPENAI_API_KEY` environment variable before using the library:
+```python
+from openai import OpenAI
+
+client = OpenAI()
-```bash
-export OPENAI_API_KEY='sk-...'
+stream = client.chat.completions.create(
+ model="gpt-4",
+ messages=[{"role": "user", "content": "Say this is a test"}],
+ stream=True,
+)
+for part in stream:
+ print(part.choices[0].delta.content or "")
```
-Or set `openai.api_key` to its value:
+The async client uses the exact same interface.
```python
-openai.api_key = "sk-..."
+from openai import AsyncOpenAI
+
+client = AsyncOpenAI()
+
+stream = await client.chat.completions.create(
+ prompt="Say this is a test",
+ messages=[{"role": "user", "content": "Say this is a test"}],
+ stream=True,
+)
+async for part in stream:
+ print(part.choices[0].delta.content or "")
```
-Examples of how to use this library to accomplish various tasks can be found in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/). It contains code examples for: classification using fine-tuning, clustering, code search, customizing embeddings, question answering from a corpus of documents. recommendations, visualization of embeddings, and more.
+## Module-level client
-Most endpoints support a `request_timeout` param. This param takes a `Union[float, Tuple[float, float]]` and will raise an `openai.error.Timeout` error if the request exceeds that time in seconds (See: https://requests.readthedocs.io/en/latest/user/quickstart/#timeouts).
+> [!IMPORTANT]
+> We highly recommend instantiating client instances instead of relying on the global client.
-### Chat completions
+We also expose a global client instance that is accessible in a similar fashion to versions prior to v1.
-Chat models such as `gpt-3.5-turbo` and `gpt-4` can be called using the [chat completions endpoint](https://platform.openai.com/docs/api-reference/chat/create).
+```py
+import openai
-```python
-completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
+# optional; defaults to `os.environ['OPENAI_API_KEY']`
+openai.api_key = '...'
+
+# all client options can be configured just like the `OpenAI` instantiation counterpart
+openai.base_url = "https://..."
+openai.default_headers = {"x-foo": "true"}
+
+completion = openai.chat.completions.create(
+ model="gpt-4",
+ messages=[
+ {
+ "role": "user",
+ "content": "How do I output all files in a directory using Python?",
+ },
+ ],
+)
print(completion.choices[0].message.content)
```
-You can learn more in our [chat completions guide](https://platform.openai.com/docs/guides/gpt/chat-completions-api).
+The API is the exact same as the standard client instance based API.
-### Completions
+This is intended to be used within REPLs or notebooks for faster iteration, **not** in application code.
-Text models such as `babbage-002` or `davinci-002` (and our [legacy completions models](https://platform.openai.com/docs/deprecations/deprecation-history)) can be called using the completions endpoint.
+We recommend that you always instantiate a client (e.g., with `client = OpenAI()`) in application code because:
-```python
-completion = openai.Completion.create(model="davinci-002", prompt="Hello world")
-print(completion.choices[0].text)
-```
+- It can be difficult to reason about where client options are configured
+- It's not possible to change certain client options without potentially causing race conditions
+- It's harder to mock for testing purposes
+- It's not possible to control cleanup of network connections
+
+## Using types
+
+Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev), which provide helper methods for things like serializing back into JSON ([v1](https://docs.pydantic.dev/1.10/usage/models/), [v2](https://docs.pydantic.dev/latest/usage/serialization/)). To get a dictionary, call `model.model_dump()`.
-You can learn more in our [completions guide](https://platform.openai.com/docs/guides/gpt/completions-api).
+Typed requests and responses provide autocomplete and documentation within your editor. If you would like to see type errors in VS Code to help catch bugs earlier, set `python.analysis.typeCheckingMode` to `basic`.
-### Embeddings
+## Pagination
-Embeddings are designed to measure the similarity or relevance between text strings. To get an embedding for a text string, you can use following:
+List methods in the OpenAI API are paginated.
+
+This library provides auto-paginating iterators with each list response, so you do not have to request successive pages manually:
```python
-text_string = "sample text"
+import openai
-model_id = "text-embedding-ada-002"
+client = OpenAI()
-embedding = openai.Embedding.create(input=text_string, model=model_id)['data'][0]['embedding']
+all_jobs = []
+# Automatically fetches more pages as needed.
+for job in client.fine_tuning.jobs.list(
+ limit=20,
+):
+ # Do something with job here
+ all_jobs.append(job)
+print(all_jobs)
```
-You can learn more in our [embeddings guide](https://platform.openai.com/docs/guides/embeddings/embeddings).
+Or, asynchronously:
-### Fine-tuning
+```python
+import asyncio
+import openai
-Fine-tuning a model on training data can both improve the results (by giving the model more examples to learn from) and lower the cost/latency of API calls by reducing the need to include training examples in prompts.
+client = AsyncOpenAI()
-```python
-# Create a fine-tuning job with an already uploaded file
-openai.FineTuningJob.create(training_file="file-abc123", model="gpt-3.5-turbo")
-# List 10 fine-tuning jobs
-openai.FineTuningJob.list(limit=10)
+async def main() -> None:
+ all_jobs = []
+ # Iterate through items across all pages, issuing requests as needed.
+ async for job in client.fine_tuning.jobs.list(
+ limit=20,
+ ):
+ all_jobs.append(job)
+ print(all_jobs)
-# Retrieve the state of a fine-tune
-openai.FineTuningJob.retrieve("ft-abc123")
-# Cancel a job
-openai.FineTuningJob.cancel("ft-abc123")
+asyncio.run(main())
+```
-# List up to 10 events from a fine-tuning job
-openai.FineTuningJob.list_events(id="ft-abc123", limit=10)
+Alternatively, you can use the `.has_next_page()`, `.next_page_info()`, or `.get_next_page()` methods for more granular control working with pages:
-# Delete a fine-tuned model (must be an owner of the org the model was created in)
-openai.Model.delete("ft:gpt-3.5-turbo:acemeco:suffix:abc123")
+```python
+first_page = await client.fine_tuning.jobs.list(
+ limit=20,
+)
+if first_page.has_next_page():
+ print(f"will fetch next page using these details: {first_page.next_page_info()}")
+ next_page = await first_page.get_next_page()
+ print(f"number of items we just fetched: {len(next_page.data)}")
+
+# Remove `await` for non-async usage.
```
-You can learn more in our [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning).
+Or just work directly with the returned data:
-To log the training results from fine-tuning to Weights & Biases use:
+```python
+first_page = await client.fine_tuning.jobs.list(
+ limit=20,
+)
-```
-openai wandb sync
-```
+print(f"next page cursor: {first_page.after}") # => "next page cursor: ..."
+for job in first_page.data:
+ print(job.id)
-For more information, read the [wandb documentation](https://docs.wandb.ai/guides/integrations/openai) on Weights & Biases.
+# Remove `await` for non-async usage.
+```
-### Moderation
+## Nested params
-OpenAI provides a free Moderation endpoint that can be used to check whether content complies with the OpenAI [content policy](https://platform.openai.com/docs/usage-policies).
+Nested parameters are dictionaries, typed using `TypedDict`, for example:
```python
-moderation_resp = openai.Moderation.create(input="Here is some perfectly innocuous text that follows all OpenAI content policies.")
-```
+from openai import OpenAI
-You can learn more in our [moderation guide](https://platform.openai.com/docs/guides/moderation).
+client = OpenAI()
+
+page = client.files.list()
+```
-### Image generation (DALL·E)
+## File Uploads
-DALL·E is a generative image model that can create new images based on a prompt.
+Request parameters that correspond to file uploads can be passed as `bytes`, a [`PathLike`](https://docs.python.org/3/library/os.html#os.PathLike) instance or a tuple of `(filename, contents, media type)`.
```python
-image_resp = openai.Image.create(prompt="two dogs playing chess, oil painting", n=4, size="512x512")
+from pathlib import Path
+from openai import OpenAI
+
+client = OpenAI()
+
+client.files.create(
+ file=Path("input.jsonl"),
+ purpose="fine-tune",
+)
```
-You can learn more in our [image generation guide](https://platform.openai.com/docs/guides/images).
+The async client uses the exact same interface. If you pass a [`PathLike`](https://docs.python.org/3/library/os.html#os.PathLike) instance, the file contents will be read asynchronously automatically.
-### Audio (Whisper)
+## Handling errors
-The speech to text API provides two endpoints, transcriptions and translations, based on our state-of-the-art [open source large-v2 Whisper model](https://github.com/openai/whisper).
+When the library is unable to connect to the API (for example, due to network connection problems or a timeout), a subclass of `openai.APIConnectionError` is raised.
-```python
-f = open("path/to/file.mp3", "rb")
-transcript = openai.Audio.transcribe("whisper-1", f)
+When the API returns a non-success status code (that is, 4xx or 5xx
+response), a subclass of `openai.APIStatusError` is raised, containing `status_code` and `response` properties.
-transcript = openai.Audio.translate("whisper-1", f)
+All errors inherit from `openai.APIError`.
+
+```python
+import openai
+from openai import OpenAI
+
+client = OpenAI()
+
+try:
+ client.fine_tunes.create(
+ training_file="file-XGinujblHPwGLSztz8cPS8XY",
+ )
+except openai.APIConnectionError as e:
+ print("The server could not be reached")
+ print(e.__cause__) # an underlying Exception, likely raised within httpx.
+except openai.RateLimitError as e:
+ print("A 429 status code was received; we should back off a bit.")
+except openai.APIStatusError as e:
+ print("Another non-200-range status code was received")
+ print(e.status_code)
+ print(e.response)
```
-You can learn more in our [speech to text guide](https://platform.openai.com/docs/guides/speech-to-text).
+Error codes are as followed:
+
+| Status Code | Error Type |
+| ----------- | -------------------------- |
+| 400 | `BadRequestError` |
+| 401 | `AuthenticationError` |
+| 403 | `PermissionDeniedError` |
+| 404 | `NotFoundError` |
+| 422 | `UnprocessableEntityError` |
+| 429 | `RateLimitError` |
+| >=500 | `InternalServerError` |
+| N/A | `APIConnectionError` |
+
+### Retries
-### Async API
+Certain errors are automatically retried 2 times by default, with a short exponential backoff.
+Connection errors (for example, due to a network connectivity problem), 408 Request Timeout, 409 Conflict,
+429 Rate Limit, and >=500 Internal errors are all retried by default.
-Async support is available in the API by prepending `a` to a network-bound method:
+You can use the `max_retries` option to configure or disable retry settings:
```python
-async def create_chat_completion():
- chat_completion_resp = await openai.ChatCompletion.acreate(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
+from openai import OpenAI
+
+# Configure the default for all requests:
+client = OpenAI(
+ # default is 2
+ max_retries=0,
+)
+
+# Or, configure per-request:
+client.with_options(max_retries=5).chat.completions.create(
+ messages=[
+ {
+ "role": "user",
+ "content": "How can I get the name of the current day in Node.js?",
+ }
+ ],
+ model="gpt-3.5-turbo",
+)
```
-To make async requests more efficient, you can pass in your own
-`aiohttp.ClientSession`, but you must manually close the client session at the end
-of your program/event loop:
+### Timeouts
-```python
-from aiohttp import ClientSession
-openai.aiosession.set(ClientSession())
+By default requests time out after 10 minutes. You can configure this with a `timeout` option,
+which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/#fine-tuning-the-configuration) object:
-# At the end of your program, close the http session
-await openai.aiosession.get().close()
+```python
+from openai import OpenAI
+
+# Configure the default for all requests:
+client = OpenAI(
+ # default is 60s
+ timeout=20.0,
+)
+
+# More granular control:
+client = OpenAI(
+ timeout=httpx.Timeout(60.0, read=5.0, write=10.0, connect=2.0),
+)
+
+# Override per-request:
+client.with_options(timeout=5 * 1000).chat.completions.create(
+ messages=[
+ {
+ "role": "user",
+ "content": "How can I list all files in a directory using Python?",
+ }
+ ],
+ model="gpt-3.5-turbo",
+)
```
-### Command-line interface
+On timeout, an `APITimeoutError` is thrown.
-This library additionally provides an `openai` command-line utility
-which makes it easy to interact with the API from your terminal. Run
-`openai api -h` for usage.
+Note that requests that time out are [retried twice by default](#retries).
-```sh
-# list models
-openai api models.list
+## Advanced
-# create a chat completion (gpt-3.5-turbo, gpt-4, etc.)
-openai api chat_completions.create -m gpt-3.5-turbo -g user "Hello world"
+### Logging
-# create a completion (text-davinci-003, text-davinci-002, ada, babbage, curie, davinci, etc.)
-openai api completions.create -m ada -p "Hello world"
+We use the standard library [`logging`](https://docs.python.org/3/library/logging.html) module.
-# generate images via DALL·E API
-openai api image.create -p "two dogs playing chess, cartoon" -n 1
+You can enable logging by setting the environment variable `OPENAI_LOG` to `debug`.
-# using openai through a proxy
-openai --proxy=http://proxy.com api models.list
+```shell
+$ export OPENAI_LOG=debug
```
-### Microsoft Azure Endpoints
+### How to tell whether `None` means `null` or missing
-In order to use the library with Microsoft Azure endpoints, you need to set the `api_type`, `api_base` and `api_version` in addition to the `api_key`. The `api_type` must be set to 'azure' and the others correspond to the properties of your endpoint.
-In addition, the deployment name must be passed as the `deployment_id` parameter.
+In an API response, a field may be explicitly `null`, or missing entirely; in either case, its value is `None` in this library. You can differentiate the two cases with `.model_fields_set`:
-```python
-import openai
-openai.api_type = "azure"
-openai.api_key = "..."
-openai.api_base = "https://example-endpoint.openai.azure.com"
-openai.api_version = "2023-05-15"
+```py
+if response.my_field is None:
+ if 'my_field' not in response.model_fields_set:
+ print('Got json like {}, without a "my_field" key present at all.')
+ else:
+ print('Got json like {"my_field": null}.')
+```
-# create a chat completion
-chat_completion = openai.ChatCompletion.create(deployment_id="deployment-name", model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
+### Accessing raw response data (e.g. headers)
-# print the completion
-print(chat_completion.choices[0].message.content)
+The "raw" Response object can be accessed by prefixing `.with_raw_response.` to any HTTP method call.
+
+```py
+from openai import OpenAI
+
+client = OpenAI()
+response = client.chat.completions.with_raw_response.create(
+ messages=[{
+ "role": "user",
+ "content": "Say this is a test",
+ }],
+ model="gpt-3.5-turbo",
+)
+print(response.headers.get('X-My-Header'))
+
+completion = response.parse() # get the object that `chat.completions.create()` would have returned
+print(completion)
```
-Please note that for the moment, the Microsoft Azure endpoints can only be used for completion, embedding, and fine-tuning operations.
-For a detailed example of how to use fine-tuning and other operations using Azure endpoints, please check out the following Jupyter notebooks:
+These methods return an [`APIResponse`](https://github.com/openai/openai-python/tree/v1/src/openai/_response.py) object.
-- [Using Azure completions](https://github.com/openai/openai-cookbook/tree/main/examples/azure/completions.ipynb)
-- [Using Azure chat](https://github.com/openai/openai-cookbook/tree/main/examples/azure/chat.ipynb)
-- [Using Azure embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/azure/embeddings.ipynb)
+### Configuring the HTTP client
-### Microsoft Azure Active Directory Authentication
+You can directly override the [httpx client](https://www.python-httpx.org/api/#client) to customize it for your use case, including:
-In order to use Microsoft Active Directory to authenticate to your Azure endpoint, you need to set the `api_type` to "azure_ad" and pass the acquired credential token to `api_key`. The rest of the parameters need to be set as specified in the previous section.
+- Support for proxies
+- Custom transports
+- Additional [advanced](https://www.python-httpx.org/advanced/#client-instances) functionality
```python
-from azure.identity import DefaultAzureCredential
-import openai
+import httpx
+from openai import OpenAI
+
+client = OpenAI(
+ base_url="http://my.test.server.example.com:8083",
+ http_client=httpx.Client(
+ proxies="http://my.test.proxy.example.com",
+ transport=httpx.HTTPTransport(local_address="0.0.0.0"),
+ ),
+)
+```
+
+### Managing HTTP resources
+
+By default the library closes underlying HTTP connections whenever the client is [garbage collected](https://docs.python.org/3/reference/datamodel.html#object.__del__). You can manually close the client using the `.close()` method if desired, or with a context manager that closes when exiting.
+
+## Microsoft Azure OpenAI
+
+To use this library with [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview), use the `AzureOpenAI`
+class instead of the `OpenAI` class.
+
+> [!IMPORTANT]
+> The Azure API shape differs from the core API shape which means that the static types for responses / params
+> won't always be correct.
-# Request credential
-default_credential = DefaultAzureCredential()
-token = default_credential.get_token("https://cognitiveservices.azure.com/.default")
+```py
+from openai import AzureOpenAI
-# Setup parameters
-openai.api_type = "azure_ad"
-openai.api_key = token.token
-openai.api_base = "https://example-endpoint.openai.azure.com/"
-openai.api_version = "2023-05-15"
+# gets the API Key from environment variable AZURE_OPENAI_API_KEY
+client = AzureOpenAI(
+ # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
+ api_version="2023-07-01-preview"
+ # https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource
+ azure_endpoint="https://example-endpoint.openai.azure.com",
+)
+
+completion = client.chat.completions.create(
+ model="deployment-name", # e.g. gpt-35-instant
+ messages=[
+ {
+ "role": "user",
+ "content": "How do I output all files in a directory using Python?",
+ },
+ ],
+)
+print(completion.model_dump_json(indent=2))
```
-## Credit
+In addition to the options provided in the base `OpenAI` client, the following options are provided:
+
+- `azure_endpoint`
+- `azure_deployment`
+- `api_version`
+- `azure_ad_token`
+- `azure_ad_token_provider`
+
+An example of using the client with Azure Active Directory can be found [here](https://github.com/openai/openai-python/blob/v1/examples/azure_ad.py).
+
+## Versioning
+
+This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions:
+
+1. Changes that only affect static types, without breaking runtime behavior.
+2. Changes to library internals which are technically public but not intended or documented for external use. _(Please open a GitHub issue to let us know if you are relying on such internals)_.
+3. Changes that we do not expect to impact the vast majority of users in practice.
+
+We take backwards-compatibility seriously and work hard to ensure you can rely on a smooth upgrade experience.
+
+We are keen for your feedback; please open an [issue](https://www.github.com/openai/openai-python/issues) with questions, bugs, or suggestions.
+
+## Requirements
-This library is forked from the [Stripe Python Library](https://github.com/stripe/stripe-python).
+Python 3.7 or higher.
diff --git a/api.md b/api.md
new file mode 100644
index 0000000000..915a05479a
--- /dev/null
+++ b/api.md
@@ -0,0 +1,172 @@
+# Completions
+
+Types:
+
+```python
+from openai.types import Completion, CompletionChoice, CompletionUsage
+```
+
+Methods:
+
+- client.completions.create(\*\*params) -> Completion
+
+# Chat
+
+## Completions
+
+Types:
+
+```python
+from openai.types.chat import (
+ ChatCompletion,
+ ChatCompletionChunk,
+ ChatCompletionMessage,
+ ChatCompletionMessageParam,
+ ChatCompletionRole,
+)
+```
+
+Methods:
+
+- client.chat.completions.create(\*\*params) -> ChatCompletion
+
+# Edits
+
+Types:
+
+```python
+from openai.types import Edit
+```
+
+Methods:
+
+- client.edits.create(\*\*params) -> Edit
+
+# Embeddings
+
+Types:
+
+```python
+from openai.types import CreateEmbeddingResponse, Embedding
+```
+
+Methods:
+
+- client.embeddings.create(\*\*params) -> CreateEmbeddingResponse
+
+# Files
+
+Types:
+
+```python
+from openai.types import FileContent, FileDeleted, FileObject
+```
+
+Methods:
+
+- client.files.create(\*\*params) -> FileObject
+- client.files.retrieve(file_id) -> FileObject
+- client.files.list() -> SyncPage[FileObject]
+- client.files.delete(file_id) -> FileDeleted
+- client.files.retrieve_content(file_id) -> str
+- client.files.wait_for_processing(\*args) -> FileObject
+
+# Images
+
+Types:
+
+```python
+from openai.types import Image, ImagesResponse
+```
+
+Methods:
+
+- client.images.create_variation(\*\*params) -> ImagesResponse
+- client.images.edit(\*\*params) -> ImagesResponse
+- client.images.generate(\*\*params) -> ImagesResponse
+
+# Audio
+
+## Transcriptions
+
+Types:
+
+```python
+from openai.types.audio import Transcription
+```
+
+Methods:
+
+- client.audio.transcriptions.create(\*\*params) -> Transcription
+
+## Translations
+
+Types:
+
+```python
+from openai.types.audio import Translation
+```
+
+Methods:
+
+- client.audio.translations.create(\*\*params) -> Translation
+
+# Moderations
+
+Types:
+
+```python
+from openai.types import Moderation, ModerationCreateResponse
+```
+
+Methods:
+
+- client.moderations.create(\*\*params) -> ModerationCreateResponse
+
+# Models
+
+Types:
+
+```python
+from openai.types import Model, ModelDeleted
+```
+
+Methods:
+
+- client.models.retrieve(model) -> Model
+- client.models.list() -> SyncPage[Model]
+- client.models.delete(model) -> ModelDeleted
+
+# FineTuning
+
+## Jobs
+
+Types:
+
+```python
+from openai.types.fine_tuning import FineTuningJob, FineTuningJobEvent
+```
+
+Methods:
+
+- client.fine_tuning.jobs.create(\*\*params) -> FineTuningJob
+- client.fine_tuning.jobs.retrieve(fine_tuning_job_id) -> FineTuningJob
+- client.fine_tuning.jobs.list(\*\*params) -> SyncCursorPage[FineTuningJob]
+- client.fine_tuning.jobs.cancel(fine_tuning_job_id) -> FineTuningJob
+- client.fine_tuning.jobs.list_events(fine_tuning_job_id, \*\*params) -> SyncCursorPage[FineTuningJobEvent]
+
+# FineTunes
+
+Types:
+
+```python
+from openai.types import FineTune, FineTuneEvent, FineTuneEventsListResponse
+```
+
+Methods:
+
+- client.fine_tunes.create(\*\*params) -> FineTune
+- client.fine_tunes.retrieve(fine_tune_id) -> FineTune
+- client.fine_tunes.list() -> SyncPage[FineTune]
+- client.fine_tunes.cancel(fine_tune_id) -> FineTune
+- client.fine_tunes.list_events(fine_tune_id, \*\*params) -> FineTuneEventsListResponse
diff --git a/bin/blacken-docs.py b/bin/blacken-docs.py
new file mode 100644
index 0000000000..45d0ad1225
--- /dev/null
+++ b/bin/blacken-docs.py
@@ -0,0 +1,251 @@
+# fork of https://github.com/asottile/blacken-docs implementing https://github.com/asottile/blacken-docs/issues/170
+from __future__ import annotations
+
+import re
+import argparse
+import textwrap
+import contextlib
+from typing import Match, Optional, Sequence, Generator, NamedTuple, cast
+
+import black
+from black.mode import TargetVersion
+from black.const import DEFAULT_LINE_LENGTH
+
+MD_RE = re.compile(
+ r"(?P^(?P *)```\s*python\n)" r"(?P.*?)" r"(?P^(?P=indent)```\s*$)",
+ re.DOTALL | re.MULTILINE,
+)
+MD_PYCON_RE = re.compile(
+ r"(?P^(?P *)```\s*pycon\n)" r"(?P.*?)" r"(?P^(?P=indent)```.*$)",
+ re.DOTALL | re.MULTILINE,
+)
+RST_PY_LANGS = frozenset(("python", "py", "sage", "python3", "py3", "numpy"))
+BLOCK_TYPES = "(code|code-block|sourcecode|ipython)"
+DOCTEST_TYPES = "(testsetup|testcleanup|testcode)"
+RST_RE = re.compile(
+ rf"(?P"
+ rf"^(?P *)\.\. ("
+ rf"jupyter-execute::|"
+ rf"{BLOCK_TYPES}:: (?P\w+)|"
+ rf"{DOCTEST_TYPES}::.*"
+ rf")\n"
+ rf"((?P=indent) +:.*\n)*"
+ rf"\n*"
+ rf")"
+ rf"(?P(^((?P=indent) +.*)?\n)+)",
+ re.MULTILINE,
+)
+RST_PYCON_RE = re.compile(
+ r"(?P"
+ r"(?P *)\.\. ((code|code-block):: pycon|doctest::.*)\n"
+ r"((?P=indent) +:.*\n)*"
+ r"\n*"
+ r")"
+ r"(?P(^((?P=indent) +.*)?(\n|$))+)",
+ re.MULTILINE,
+)
+PYCON_PREFIX = ">>> "
+PYCON_CONTINUATION_PREFIX = "..."
+PYCON_CONTINUATION_RE = re.compile(
+ rf"^{re.escape(PYCON_CONTINUATION_PREFIX)}( |$)",
+)
+LATEX_RE = re.compile(
+ r"(?P^(?P *)\\begin{minted}{python}\n)"
+ r"(?P.*?)"
+ r"(?P^(?P=indent)\\end{minted}\s*$)",
+ re.DOTALL | re.MULTILINE,
+)
+LATEX_PYCON_RE = re.compile(
+ r"(?P^(?P *)\\begin{minted}{pycon}\n)" r"(?P.*?)" r"(?P^(?P=indent)\\end{minted}\s*$)",
+ re.DOTALL | re.MULTILINE,
+)
+PYTHONTEX_LANG = r"(?Ppyblock|pycode|pyconsole|pyverbatim)"
+PYTHONTEX_RE = re.compile(
+ rf"(?P^(?P *)\\begin{{{PYTHONTEX_LANG}}}\n)"
+ rf"(?P.*?)"
+ rf"(?P^(?P=indent)\\end{{(?P=lang)}}\s*$)",
+ re.DOTALL | re.MULTILINE,
+)
+INDENT_RE = re.compile("^ +(?=[^ ])", re.MULTILINE)
+TRAILING_NL_RE = re.compile(r"\n+\Z", re.MULTILINE)
+
+
+class CodeBlockError(NamedTuple):
+ offset: int
+ exc: Exception
+
+
+def format_str(
+ src: str,
+ black_mode: black.FileMode,
+) -> tuple[str, Sequence[CodeBlockError]]:
+ errors: list[CodeBlockError] = []
+
+ @contextlib.contextmanager
+ def _collect_error(match: Match[str]) -> Generator[None, None, None]:
+ try:
+ yield
+ except Exception as e:
+ errors.append(CodeBlockError(match.start(), e))
+
+ def _md_match(match: Match[str]) -> str:
+ code = textwrap.dedent(match["code"])
+ with _collect_error(match):
+ code = black.format_str(code, mode=black_mode)
+ code = textwrap.indent(code, match["indent"])
+ return f'{match["before"]}{code}{match["after"]}'
+
+ def _rst_match(match: Match[str]) -> str:
+ lang = match["lang"]
+ if lang is not None and lang not in RST_PY_LANGS:
+ return match[0]
+ min_indent = min(INDENT_RE.findall(match["code"]))
+ trailing_ws_match = TRAILING_NL_RE.search(match["code"])
+ assert trailing_ws_match
+ trailing_ws = trailing_ws_match.group()
+ code = textwrap.dedent(match["code"])
+ with _collect_error(match):
+ code = black.format_str(code, mode=black_mode)
+ code = textwrap.indent(code, min_indent)
+ return f'{match["before"]}{code.rstrip()}{trailing_ws}'
+
+ def _pycon_match(match: Match[str]) -> str:
+ code = ""
+ fragment = cast(Optional[str], None)
+
+ def finish_fragment() -> None:
+ nonlocal code
+ nonlocal fragment
+
+ if fragment is not None:
+ with _collect_error(match):
+ fragment = black.format_str(fragment, mode=black_mode)
+ fragment_lines = fragment.splitlines()
+ code += f"{PYCON_PREFIX}{fragment_lines[0]}\n"
+ for line in fragment_lines[1:]:
+ # Skip blank lines to handle Black adding a blank above
+ # functions within blocks. A blank line would end the REPL
+ # continuation prompt.
+ #
+ # >>> if True:
+ # ... def f():
+ # ... pass
+ # ...
+ if line:
+ code += f"{PYCON_CONTINUATION_PREFIX} {line}\n"
+ if fragment_lines[-1].startswith(" "):
+ code += f"{PYCON_CONTINUATION_PREFIX}\n"
+ fragment = None
+
+ indentation = None
+ for line in match["code"].splitlines():
+ orig_line, line = line, line.lstrip()
+ if indentation is None and line:
+ indentation = len(orig_line) - len(line)
+ continuation_match = PYCON_CONTINUATION_RE.match(line)
+ if continuation_match and fragment is not None:
+ fragment += line[continuation_match.end() :] + "\n"
+ else:
+ finish_fragment()
+ if line.startswith(PYCON_PREFIX):
+ fragment = line[len(PYCON_PREFIX) :] + "\n"
+ else:
+ code += orig_line[indentation:] + "\n"
+ finish_fragment()
+ return code
+
+ def _md_pycon_match(match: Match[str]) -> str:
+ code = _pycon_match(match)
+ code = textwrap.indent(code, match["indent"])
+ return f'{match["before"]}{code}{match["after"]}'
+
+ def _rst_pycon_match(match: Match[str]) -> str:
+ code = _pycon_match(match)
+ min_indent = min(INDENT_RE.findall(match["code"]))
+ code = textwrap.indent(code, min_indent)
+ return f'{match["before"]}{code}'
+
+ def _latex_match(match: Match[str]) -> str:
+ code = textwrap.dedent(match["code"])
+ with _collect_error(match):
+ code = black.format_str(code, mode=black_mode)
+ code = textwrap.indent(code, match["indent"])
+ return f'{match["before"]}{code}{match["after"]}'
+
+ def _latex_pycon_match(match: Match[str]) -> str:
+ code = _pycon_match(match)
+ code = textwrap.indent(code, match["indent"])
+ return f'{match["before"]}{code}{match["after"]}'
+
+ src = MD_RE.sub(_md_match, src)
+ src = MD_PYCON_RE.sub(_md_pycon_match, src)
+ src = RST_RE.sub(_rst_match, src)
+ src = RST_PYCON_RE.sub(_rst_pycon_match, src)
+ src = LATEX_RE.sub(_latex_match, src)
+ src = LATEX_PYCON_RE.sub(_latex_pycon_match, src)
+ src = PYTHONTEX_RE.sub(_latex_match, src)
+ return src, errors
+
+
+def format_file(
+ filename: str,
+ black_mode: black.FileMode,
+ skip_errors: bool,
+) -> int:
+ with open(filename, encoding="UTF-8") as f:
+ contents = f.read()
+ new_contents, errors = format_str(contents, black_mode)
+ for error in errors:
+ lineno = contents[: error.offset].count("\n") + 1
+ print(f"{filename}:{lineno}: code block parse error {error.exc}")
+ if errors and not skip_errors:
+ return 1
+ if contents != new_contents:
+ print(f"{filename}: Rewriting...")
+ with open(filename, "w", encoding="UTF-8") as f:
+ f.write(new_contents)
+ return 0
+ else:
+ return 0
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-l",
+ "--line-length",
+ type=int,
+ default=DEFAULT_LINE_LENGTH,
+ )
+ parser.add_argument(
+ "-t",
+ "--target-version",
+ action="append",
+ type=lambda v: TargetVersion[v.upper()],
+ default=[],
+ help=f"choices: {[v.name.lower() for v in TargetVersion]}",
+ dest="target_versions",
+ )
+ parser.add_argument(
+ "-S",
+ "--skip-string-normalization",
+ action="store_true",
+ )
+ parser.add_argument("-E", "--skip-errors", action="store_true")
+ parser.add_argument("filenames", nargs="*")
+ args = parser.parse_args(argv)
+
+ black_mode = black.FileMode(
+ target_versions=set(args.target_versions),
+ line_length=args.line_length,
+ string_normalization=not args.skip_string_normalization,
+ )
+
+ retv = 0
+ for filename in args.filenames:
+ retv |= format_file(filename, black_mode, skip_errors=args.skip_errors)
+ return retv
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/bin/check-test-server b/bin/check-test-server
new file mode 100755
index 0000000000..a6fa34950d
--- /dev/null
+++ b/bin/check-test-server
@@ -0,0 +1,50 @@
+#!/usr/bin/env bash
+
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[0;33m'
+NC='\033[0m' # No Color
+
+function prism_is_running() {
+ curl --silent "http://localhost:4010" >/dev/null 2>&1
+}
+
+function is_overriding_api_base_url() {
+ [ -n "$TEST_API_BASE_URL" ]
+}
+
+if is_overriding_api_base_url ; then
+ # If someone is running the tests against the live API, we can trust they know
+ # what they're doing and exit early.
+ echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}"
+
+ exit 0
+elif prism_is_running ; then
+ echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}"
+ echo
+
+ exit 0
+else
+ echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server"
+ echo -e "running against your OpenAPI spec."
+ echo
+ echo -e "${YELLOW}To fix:${NC}"
+ echo
+ echo -e "1. Install Prism (requires Node 16+):"
+ echo
+ echo -e " With npm:"
+ echo -e " \$ ${YELLOW}npm install -g @stoplight/prism-cli${NC}"
+ echo
+ echo -e " With yarn:"
+ echo -e " \$ ${YELLOW}yarn global add @stoplight/prism-cli${NC}"
+ echo
+ echo -e "2. Run the mock server"
+ echo
+ echo -e " To run the server, pass in the path of your OpenAPI"
+ echo -e " spec to the prism command:"
+ echo
+ echo -e " \$ ${YELLOW}prism mock path/to/your.openapi.yml${NC}"
+ echo
+
+ exit 1
+fi
diff --git a/bin/test b/bin/test
new file mode 100755
index 0000000000..60ede7a842
--- /dev/null
+++ b/bin/test
@@ -0,0 +1,3 @@
+#!/usr/bin/env bash
+
+bin/check-test-server && rye run pytest "$@"
diff --git a/chatml.md b/chatml.md
deleted file mode 100644
index 6689953adb..0000000000
--- a/chatml.md
+++ /dev/null
@@ -1,96 +0,0 @@
-> [!IMPORTANT]
-> This page is not currently maintained and is intended to provide general insight into the ChatML format, not current up-to-date information.
-
-(This document is a preview of the underlying format consumed by
-GPT models. As a developer, you can use our [higher-level
-API](https://platform.openai.com/docs/guides/chat) and won't need to
-interact directly with this format today — but expect to have the
-option in the future!)
-
-Traditionally, GPT models consumed unstructured text. ChatGPT models
-instead expect a structured format, called Chat Markup Language
-(ChatML for short).
-ChatML documents consist of a sequence of messages. Each message
-contains a header (which today consists of who said it, but in the
-future will contain other metadata) and contents (which today is a
-text payload, but in the future will contain other datatypes).
-We are still evolving ChatML, but the current version (ChatML v0) can
-be represented with our upcoming "list of dicts" JSON format as
-follows:
-```
-[
- {"token": "<|im_start|>"},
- "system\nYou are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2023-03-01",
- {"token": "<|im_end|>"}, "\n", {"token": "<|im_start|>"},
- "user\nHow are you",
- {"token": "<|im_end|>"}, "\n", {"token": "<|im_start|>"},
- "assistant\nI am doing well!",
- {"token": "<|im_end|>"}, "\n", {"token": "<|im_start|>"},
- "user\nHow are you now?",
- {"token": "<|im_end|>"}, "\n"
-]
-```
-You could also represent it in the classic "unsafe raw string"
-format. However, this format inherently allows injections from user
-input containing special-token syntax, similar to SQL injections:
-```
-<|im_start|>system
-You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible.
-Knowledge cutoff: 2021-09-01
-Current date: 2023-03-01<|im_end|>
-<|im_start|>user
-How are you<|im_end|>
-<|im_start|>assistant
-I am doing well!<|im_end|>
-<|im_start|>user
-How are you now?<|im_end|>
-```
-## Non-chat use-cases
-ChatML can be applied to classic GPT use-cases that are not
-traditionally thought of as chat. For example, instruction following
-(where a user requests for the AI to complete an instruction) can be
-implemented as a ChatML query like the following:
-```
-[
- {"token": "<|im_start|>"},
- "user\nList off some good ideas:",
- {"token": "<|im_end|>"}, "\n", {"token": "<|im_start|>"},
- "assistant"
-]
-```
-We do not currently allow autocompleting of partial messages,
-```
-[
- {"token": "<|im_start|>"},
- "system\nPlease autocomplete the user's message.",
- {"token": "<|im_end|>"}, "\n", {"token": "<|im_start|>"},
- "user\nThis morning I decided to eat a giant"
-]
-```
-Note that ChatML makes explicit to the model the source of each piece
-of text, and particularly shows the boundary between human and AI
-text. This gives an opportunity to mitigate and eventually solve
-injections, as the model can tell which instructions come from the
-developer, the user, or its own input.
-## Few-shot prompting
-In general, we recommend adding few-shot examples using separate
-`system` messages with a `name` field of `example_user` or
-`example_assistant`. For example, here is a 1-shot prompt:
-```
-<|im_start|>system
-Translate from English to French
-<|im_end|>
-<|im_start|>system name=example_user
-How are you?
-<|im_end|>
-<|im_start|>system name=example_assistant
-Comment allez-vous?
-<|im_end|>
-<|im_start|>user
-{{user input here}}<|im_end|>
-```
-If adding instructions in the `system` message doesn't work, you can
-also try putting them into a `user` message. (In the near future, we
-will train our models to be much more steerable via the system
-message. But to date, we have trained only on a few system messages,
-so the models pay much more attention to user examples.)
diff --git a/examples/README.md b/examples/README.md
deleted file mode 100644
index ffa3b42709..0000000000
--- a/examples/README.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Examples have moved to the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/)
-
-Looking for code examples? Visit the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/), which shares examples of how to use the OpenAI Python library to accomplish common tasks.
-
-Prior to July 2022, code examples were hosted in this examples folder; going forward, code examples will be hosted in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/).
-
-This separation will help keep the [OpenAI Python library](https://github.com/openai/openai-python) simple and small, without extra files or dependencies.
diff --git a/examples/async_demo.py b/examples/async_demo.py
new file mode 100644
index 0000000000..92c267c38f
--- /dev/null
+++ b/examples/async_demo.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env -S poetry run python
+
+import asyncio
+
+from openai import AsyncOpenAI
+
+# gets API Key from environment variable OPENAI_API_KEY
+client = AsyncOpenAI()
+
+
+async def main() -> None:
+ stream = await client.completions.create(
+ model="text-davinci-003",
+ prompt="Say this is a test",
+ stream=True,
+ )
+ async for completion in stream:
+ print(completion.choices[0].text, end="")
+ print()
+
+
+asyncio.run(main())
diff --git a/examples/azure.py b/examples/azure.py
new file mode 100644
index 0000000000..a28b8cc433
--- /dev/null
+++ b/examples/azure.py
@@ -0,0 +1,43 @@
+from openai import AzureOpenAI
+
+# may change in the future
+# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
+api_version = "2023-07-01-preview"
+
+# gets the API Key from environment variable AZURE_OPENAI_API_KEY
+client = AzureOpenAI(
+ api_version=api_version,
+ # https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource
+ azure_endpoint="https://example-endpoint.openai.azure.com",
+)
+
+completion = client.chat.completions.create(
+ model="deployment-name", # e.g. gpt-35-instant
+ messages=[
+ {
+ "role": "user",
+ "content": "How do I output all files in a directory using Python?",
+ },
+ ],
+)
+print(completion.model_dump_json(indent=2))
+
+
+deployment_client = AzureOpenAI(
+ api_version=api_version,
+ # https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource
+ azure_endpoint="https://example-resource.azure.openai.com/",
+ # Navigate to the Azure OpenAI Studio to deploy a model.
+ azure_deployment="deployment-name", # e.g. gpt-35-instant
+)
+
+completion = deployment_client.chat.completions.create(
+ model="",
+ messages=[
+ {
+ "role": "user",
+ "content": "How do I output all files in a directory using Python?",
+ },
+ ],
+)
+print(completion.model_dump_json(indent=2))
diff --git a/examples/azure/embeddings.ipynb b/examples/azure/embeddings.ipynb
deleted file mode 100644
index c350e597ac..0000000000
--- a/examples/azure/embeddings.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/azure/embeddings.ipynb](https://github.com/openai/openai-cookbook/tree/main/examples/azure/embeddings.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/azure/finetuning.ipynb b/examples/azure/finetuning.ipynb
deleted file mode 100644
index 07aa224e54..0000000000
--- a/examples/azure/finetuning.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/azure/finetuning.ipynb](https://github.com/openai/openai-cookbook/tree/main/examples/azure/finetuning.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/azure_ad.py b/examples/azure_ad.py
new file mode 100644
index 0000000000..f13079dd04
--- /dev/null
+++ b/examples/azure_ad.py
@@ -0,0 +1,30 @@
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+
+from openai import AzureOpenAI
+
+token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
+
+
+# may change in the future
+# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
+api_version = "2023-07-01-preview"
+
+# https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource
+endpoint = "https://my-resource.openai.azure.com"
+
+client = AzureOpenAI(
+ api_version=api_version,
+ azure_endpoint=endpoint,
+ azure_ad_token_provider=token_provider,
+)
+
+completion = client.chat.completions.create(
+ model="deployment-name", # e.g. gpt-35-instant
+ messages=[
+ {
+ "role": "user",
+ "content": "How do I output all files in a directory using Python?",
+ },
+ ],
+)
+print(completion.model_dump_json(indent=2))
diff --git a/examples/codex/backtranslation.py b/examples/codex/backtranslation.py
deleted file mode 100644
index 6390e5e174..0000000000
--- a/examples/codex/backtranslation.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook)
-# at [examples/Backtranslation_of_SQL_queries](https://github.com/openai/openai-cookbook/blob/main/examples/Backtranslation_of_SQL_queries.py)
diff --git a/examples/demo.py b/examples/demo.py
new file mode 100644
index 0000000000..37830e3e97
--- /dev/null
+++ b/examples/demo.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env -S poetry run python
+
+from openai import OpenAI
+
+# gets API Key from environment variable OPENAI_API_KEY
+client = OpenAI()
+
+# Non-streaming:
+print("----- standard request -----")
+completion = client.chat.completions.create(
+ model="gpt-4",
+ messages=[
+ {
+ "role": "user",
+ "content": "Say this is a test",
+ },
+ ],
+)
+print(completion.choices[0].message.content)
+
+# Streaming:
+print("----- streaming request -----")
+stream = client.chat.completions.create(
+ model="gpt-4",
+ messages=[
+ {
+ "role": "user",
+ "content": "How do I output all files in a directory using Python?",
+ },
+ ],
+ stream=True,
+)
+for chunk in stream:
+ if not chunk.choices:
+ continue
+
+ print(chunk.choices[0].delta.content, end="")
+print()
diff --git a/examples/embeddings/Classification.ipynb b/examples/embeddings/Classification.ipynb
deleted file mode 100644
index b44d6a76a5..0000000000
--- a/examples/embeddings/Classification.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Classification_using_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Classification_using_embeddings.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/embeddings/Clustering.ipynb b/examples/embeddings/Clustering.ipynb
deleted file mode 100644
index 7a4f14193d..0000000000
--- a/examples/embeddings/Clustering.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Clustering.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Clustering.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/embeddings/Code_search.ipynb b/examples/embeddings/Code_search.ipynb
deleted file mode 100644
index 440f8f56d5..0000000000
--- a/examples/embeddings/Code_search.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Code_search.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Code_search.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/embeddings/Get_embeddings.ipynb b/examples/embeddings/Get_embeddings.ipynb
deleted file mode 100644
index 199c2dd156..0000000000
--- a/examples/embeddings/Get_embeddings.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Get_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Get_embeddings.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/embeddings/Obtain_dataset.ipynb b/examples/embeddings/Obtain_dataset.ipynb
deleted file mode 100644
index 9d04f9bce9..0000000000
--- a/examples/embeddings/Obtain_dataset.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Obtain_dataset.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Obtain_dataset.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/embeddings/Recommendation.ipynb b/examples/embeddings/Recommendation.ipynb
deleted file mode 100644
index 7be5be31d7..0000000000
--- a/examples/embeddings/Recommendation.ipynb
+++ /dev/null
@@ -1,36 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Recommendation_using_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Recommendation_using_embeddings.ipynb)."
- ]
- }
- ],
- "metadata": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- },
- "kernelspec": {
- "display_name": "Python 3.9.9 64-bit ('openai': virtualenv)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/embeddings/Regression.ipynb b/examples/embeddings/Regression.ipynb
deleted file mode 100644
index 8d44cb97b4..0000000000
--- a/examples/embeddings/Regression.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Regression_using_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Regression_using_embeddings.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/embeddings/Semantic_text_search_using_embeddings.ipynb b/examples/embeddings/Semantic_text_search_using_embeddings.ipynb
deleted file mode 100644
index 78dbc35f35..0000000000
--- a/examples/embeddings/Semantic_text_search_using_embeddings.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Semantic_text_search_using_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/embeddings/User_and_product_embeddings.ipynb b/examples/embeddings/User_and_product_embeddings.ipynb
deleted file mode 100644
index 9ebd557b8f..0000000000
--- a/examples/embeddings/User_and_product_embeddings.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/User_and_product_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/User_and_product_embeddings.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/embeddings/Visualize_in_2d.ipynb b/examples/embeddings/Visualize_in_2d.ipynb
deleted file mode 100644
index 4638b58e95..0000000000
--- a/examples/embeddings/Visualize_in_2d.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Visualizing_embeddings_in_2D.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Visualizing_embeddings_in_2D.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/embeddings/Visualize_in_3d.ipynb b/examples/embeddings/Visualize_in_3d.ipynb
deleted file mode 100644
index df79b02e9b..0000000000
--- a/examples/embeddings/Visualize_in_3d.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "b87d69b2",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Visualizing_embeddings_in_3D.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Visualizing_embeddings_in_3D.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/examples/embeddings/Zero-shot_classification.ipynb b/examples/embeddings/Zero-shot_classification.ipynb
deleted file mode 100644
index d63561879a..0000000000
--- a/examples/embeddings/Zero-shot_classification.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Zero-shot_classification_with_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Zero-shot_classification_with_embeddings.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/finetuning/answers_with_ft.py b/examples/finetuning/answers_with_ft.py
deleted file mode 100644
index 43061f4c1b..0000000000
--- a/examples/finetuning/answers_with_ft.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook)
-# at [examples/fine-tuned_qa](https://github.com/openai/openai-cookbook/tree/main/examples/fine-tuned_qa)
diff --git a/examples/finetuning/finetuning-classification.ipynb b/examples/finetuning/finetuning-classification.ipynb
deleted file mode 100644
index e5ece174d9..0000000000
--- a/examples/finetuning/finetuning-classification.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Fine-tuned_classification.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Fine-tuned_classification.ipynb)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/finetuning/olympics-1-collect-data.ipynb b/examples/finetuning/olympics-1-collect-data.ipynb
deleted file mode 100644
index a0c55d438e..0000000000
--- a/examples/finetuning/olympics-1-collect-data.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/fine-tuned_qa/](https://github.com/openai/openai-cookbook/tree/main/examples/fine-tuned_qa)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/finetuning/olympics-2-create-qa.ipynb b/examples/finetuning/olympics-2-create-qa.ipynb
deleted file mode 100644
index a0c55d438e..0000000000
--- a/examples/finetuning/olympics-2-create-qa.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/fine-tuned_qa/](https://github.com/openai/openai-cookbook/tree/main/examples/fine-tuned_qa)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/finetuning/olympics-3-train-qa.ipynb b/examples/finetuning/olympics-3-train-qa.ipynb
deleted file mode 100644
index a0c55d438e..0000000000
--- a/examples/finetuning/olympics-3-train-qa.ipynb
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/fine-tuned_qa/](https://github.com/openai/openai-cookbook/tree/main/examples/fine-tuned_qa)."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.9 ('openai')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.9"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/module_client.py b/examples/module_client.py
new file mode 100644
index 0000000000..5f2fb79dcf
--- /dev/null
+++ b/examples/module_client.py
@@ -0,0 +1,25 @@
+import openai
+
+# will default to `os.environ['OPENAI_API_KEY']` if not explicitly set
+openai.api_key = "..."
+
+# all client options can be configured just like the `OpenAI` instantiation counterpart
+openai.base_url = "https://..."
+openai.default_headers = {"x-foo": "true"}
+
+# all API calls work in the exact same fashion as well
+stream = openai.chat.completions.create(
+ model="gpt-4",
+ messages=[
+ {
+ "role": "user",
+ "content": "How do I output all files in a directory using Python?",
+ },
+ ],
+ stream=True,
+)
+
+for chunk in stream:
+ print(chunk.choices[0].delta.content or "", end="", flush=True)
+
+print()
diff --git a/examples/streaming.py b/examples/streaming.py
new file mode 100755
index 0000000000..168877dfc5
--- /dev/null
+++ b/examples/streaming.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env -S poetry run python
+
+import asyncio
+
+from openai import OpenAI, AsyncOpenAI
+
+# This script assumes you have the OPENAI_API_KEY environment variable set to a valid OpenAI API key.
+#
+# You can run this script from the root directory like so:
+# `python examples/streaming.py`
+
+
+def sync_main() -> None:
+ client = OpenAI()
+ response = client.completions.create(
+ model="text-davinci-002",
+ prompt="1,2,3,",
+ max_tokens=5,
+ temperature=0,
+ stream=True,
+ )
+
+ # You can manually control iteration over the response
+ first = next(response)
+ print(f"got response data: {first.model_dump_json(indent=2)}")
+
+ # Or you could automatically iterate through all of data.
+ # Note that the for loop will not exit until *all* of the data has been processed.
+ for data in response:
+ print(data.model_dump_json())
+
+
+async def async_main() -> None:
+ client = AsyncOpenAI()
+ response = await client.completions.create(
+ model="text-davinci-002",
+ prompt="1,2,3,",
+ max_tokens=5,
+ temperature=0,
+ stream=True,
+ )
+
+ # You can manually control iteration over the response.
+ # In Python 3.10+ you can also use the `await anext(response)` builtin instead
+ first = await response.__anext__()
+ print(f"got response data: {first.model_dump_json(indent=2)}")
+
+ # Or you could automatically iterate through all of data.
+ # Note that the for loop will not exit until *all* of the data has been processed.
+ async for data in response:
+ print(data.model_dump_json())
+
+
+sync_main()
+
+asyncio.run(async_main())
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 0000000000..a4517a002d
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,47 @@
+[mypy]
+pretty = True
+show_error_codes = True
+
+# Exclude _files.py because mypy isn't smart enough to apply
+# the correct type narrowing and as this is an internal module
+# it's fine to just use Pyright.
+exclude = ^(src/openai/_files\.py|_dev/.*\.py)$
+
+strict_equality = True
+implicit_reexport = True
+check_untyped_defs = True
+no_implicit_optional = True
+
+warn_return_any = True
+warn_unreachable = True
+warn_unused_configs = True
+
+# Turn these options off as it could cause conflicts
+# with the Pyright options.
+warn_unused_ignores = False
+warn_redundant_casts = False
+
+disallow_any_generics = True
+disallow_untyped_defs = True
+disallow_untyped_calls = True
+disallow_subclassing_any = True
+disallow_incomplete_defs = True
+disallow_untyped_decorators = True
+cache_fine_grained = True
+
+# By default, mypy reports an error if you assign a value to the result
+# of a function call that doesn't return anything. We do this in our test
+# cases:
+# ```
+# result = ...
+# assert result is None
+# ```
+# Changing this codegen to make mypy happy would increase complexity
+# and would not be worth it.
+disable_error_code = func-returns-value
+
+# https://github.com/python/mypy/issues/12162
+[mypy.overrides]
+module = "black.files.*"
+ignore_errors = true
+ignore_missing_imports = true
diff --git a/noxfile.py b/noxfile.py
new file mode 100644
index 0000000000..53bca7ff2a
--- /dev/null
+++ b/noxfile.py
@@ -0,0 +1,9 @@
+import nox
+
+
+@nox.session(reuse_venv=True, name="test-pydantic-v1")
+def test_pydantic_v1(session: nox.Session) -> None:
+ session.install("-r", "requirements-dev.lock")
+ session.install("pydantic<2")
+
+ session.run("pytest", "--showlocals", "--ignore=tests/functional", *session.posargs)
diff --git a/openai/__init__.py b/openai/__init__.py
deleted file mode 100644
index b44e50f97f..0000000000
--- a/openai/__init__.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# OpenAI Python bindings.
-#
-# Originally forked from the MIT-licensed Stripe Python bindings.
-
-import os
-import sys
-from typing import TYPE_CHECKING, Optional, Union, Callable
-
-from contextvars import ContextVar
-
-if "pkg_resources" not in sys.modules:
- # workaround for the following:
- # https://github.com/benoitc/gunicorn/pull/2539
- sys.modules["pkg_resources"] = object() # type: ignore[assignment]
- import aiohttp
-
- del sys.modules["pkg_resources"]
-
-from openai.api_resources import (
- Audio,
- ChatCompletion,
- Completion,
- Customer,
- Deployment,
- Edit,
- Embedding,
- Engine,
- ErrorObject,
- File,
- FineTune,
- FineTuningJob,
- Image,
- Model,
- Moderation,
-)
-from openai.error import APIError, InvalidRequestError, OpenAIError
-from openai.version import VERSION
-
-if TYPE_CHECKING:
- import requests
- from aiohttp import ClientSession
-
-api_key = os.environ.get("OPENAI_API_KEY")
-# Path of a file with an API key, whose contents can change. Supercedes
-# `api_key` if set. The main use case is volume-mounted Kubernetes secrets,
-# which are updated automatically.
-api_key_path: Optional[str] = os.environ.get("OPENAI_API_KEY_PATH")
-
-organization = os.environ.get("OPENAI_ORGANIZATION")
-api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
-api_type = os.environ.get("OPENAI_API_TYPE", "open_ai")
-api_version = os.environ.get(
- "OPENAI_API_VERSION",
- ("2023-05-15" if api_type in ("azure", "azure_ad", "azuread") else None),
-)
-verify_ssl_certs = True # No effect. Certificates are always verified.
-proxy = None
-app_info = None
-enable_telemetry = False # Ignored; the telemetry feature was removed.
-ca_bundle_path = None # No longer used, feature was removed
-debug = False
-log = None # Set to either 'debug' or 'info', controls console logging
-
-requestssession: Optional[
- Union["requests.Session", Callable[[], "requests.Session"]]
-] = None # Provide a requests.Session or Session factory.
-
-aiosession: ContextVar[Optional["ClientSession"]] = ContextVar(
- "aiohttp-session", default=None
-) # Acts as a global aiohttp ClientSession that reuses connections.
-# This is user-supplied; otherwise, a session is remade for each request.
-
-__version__ = VERSION
-__all__ = [
- "APIError",
- "Audio",
- "ChatCompletion",
- "Completion",
- "Customer",
- "Edit",
- "Image",
- "Deployment",
- "Embedding",
- "Engine",
- "ErrorObject",
- "File",
- "FineTune",
- "FineTuningJob",
- "InvalidRequestError",
- "Model",
- "Moderation",
- "OpenAIError",
- "api_base",
- "api_key",
- "api_type",
- "api_key_path",
- "api_version",
- "app_info",
- "ca_bundle_path",
- "debug",
- "enable_telemetry",
- "log",
- "organization",
- "proxy",
- "verify_ssl_certs",
-]
diff --git a/openai/_openai_scripts.py b/openai/_openai_scripts.py
deleted file mode 100755
index 497de19fab..0000000000
--- a/openai/_openai_scripts.py
+++ /dev/null
@@ -1,89 +0,0 @@
-#!/usr/bin/env python
-import argparse
-import logging
-import sys
-
-import openai
-from openai import version
-from openai.cli import api_register, display_error, tools_register, wandb_register
-
-logger = logging.getLogger()
-formatter = logging.Formatter("[%(asctime)s] %(message)s")
-handler = logging.StreamHandler(sys.stderr)
-handler.setFormatter(formatter)
-logger.addHandler(handler)
-
-
-def main():
- parser = argparse.ArgumentParser(description=None)
- parser.add_argument(
- "-V",
- "--version",
- action="version",
- version="%(prog)s " + version.VERSION,
- )
- parser.add_argument(
- "-v",
- "--verbose",
- action="count",
- dest="verbosity",
- default=0,
- help="Set verbosity.",
- )
- parser.add_argument("-b", "--api-base", help="What API base url to use.")
- parser.add_argument("-k", "--api-key", help="What API key to use.")
- parser.add_argument("-p", "--proxy", nargs='+', help="What proxy to use.")
- parser.add_argument(
- "-o",
- "--organization",
- help="Which organization to run as (will use your default organization if not specified)",
- )
-
- def help(args):
- parser.print_help()
-
- parser.set_defaults(func=help)
-
- subparsers = parser.add_subparsers()
- sub_api = subparsers.add_parser("api", help="Direct API calls")
- sub_tools = subparsers.add_parser("tools", help="Client side tools for convenience")
- sub_wandb = subparsers.add_parser("wandb", help="Logging with Weights & Biases, see https://docs.wandb.ai/guides/integrations/openai for documentation")
-
- api_register(sub_api)
- tools_register(sub_tools)
- wandb_register(sub_wandb)
-
- args = parser.parse_args()
- if args.verbosity == 1:
- logger.setLevel(logging.INFO)
- elif args.verbosity >= 2:
- logger.setLevel(logging.DEBUG)
-
- openai.debug = True
- if args.api_key is not None:
- openai.api_key = args.api_key
- if args.api_base is not None:
- openai.api_base = args.api_base
- if args.organization is not None:
- openai.organization = args.organization
- if args.proxy is not None:
- openai.proxy = {}
- for proxy in args.proxy:
- if proxy.startswith('https'):
- openai.proxy['https'] = proxy
- elif proxy.startswith('http'):
- openai.proxy['http'] = proxy
-
- try:
- args.func(args)
- except openai.error.OpenAIError as e:
- display_error(e)
- return 1
- except KeyboardInterrupt:
- sys.stderr.write("\n")
- return 1
- return 0
-
-
-if __name__ == "__main__":
- sys.exit(main())
diff --git a/openai/api_requestor.py b/openai/api_requestor.py
deleted file mode 100644
index c051bc64f2..0000000000
--- a/openai/api_requestor.py
+++ /dev/null
@@ -1,799 +0,0 @@
-import asyncio
-import json
-import time
-import platform
-import sys
-import threading
-import time
-import warnings
-from json import JSONDecodeError
-from typing import (
- AsyncContextManager,
- AsyncGenerator,
- Callable,
- Dict,
- Iterator,
- Optional,
- Tuple,
- Union,
- overload,
-)
-from urllib.parse import urlencode, urlsplit, urlunsplit
-
-import aiohttp
-import requests
-
-if sys.version_info >= (3, 8):
- from typing import Literal
-else:
- from typing_extensions import Literal
-
-import openai
-from openai import error, util, version
-from openai.openai_response import OpenAIResponse
-from openai.util import ApiType
-
-TIMEOUT_SECS = 600
-MAX_SESSION_LIFETIME_SECS = 180
-MAX_CONNECTION_RETRIES = 2
-
-# Has one attribute per thread, 'session'.
-_thread_context = threading.local()
-
-
-def _build_api_url(url, query):
- scheme, netloc, path, base_query, fragment = urlsplit(url)
-
- if base_query:
- query = "%s&%s" % (base_query, query)
-
- return urlunsplit((scheme, netloc, path, query, fragment))
-
-
-def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]:
- """Returns a value suitable for the 'proxies' argument to 'requests.request."""
- if proxy is None:
- return None
- elif isinstance(proxy, str):
- return {"http": proxy, "https": proxy}
- elif isinstance(proxy, dict):
- return proxy.copy()
- else:
- raise ValueError(
- "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
- )
-
-
-def _aiohttp_proxies_arg(proxy) -> Optional[str]:
- """Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request."""
- if proxy is None:
- return None
- elif isinstance(proxy, str):
- return proxy
- elif isinstance(proxy, dict):
- return proxy["https"] if "https" in proxy else proxy["http"]
- else:
- raise ValueError(
- "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
- )
-
-
-def _make_session() -> requests.Session:
- if openai.requestssession:
- if isinstance(openai.requestssession, requests.Session):
- return openai.requestssession
- return openai.requestssession()
- if not openai.verify_ssl_certs:
- warnings.warn("verify_ssl_certs is ignored; openai always verifies.")
- s = requests.Session()
- proxies = _requests_proxies_arg(openai.proxy)
- if proxies:
- s.proxies = proxies
- s.mount(
- "https://",
- requests.adapters.HTTPAdapter(max_retries=MAX_CONNECTION_RETRIES),
- )
- return s
-
-
-def parse_stream_helper(line: bytes) -> Optional[str]:
- if line and line.startswith(b"data:"):
- if line.startswith(b"data: "):
- # SSE event may be valid when it contain whitespace
- line = line[len(b"data: "):]
- else:
- line = line[len(b"data:"):]
- if line.strip() == b"[DONE]":
- # return here will cause GeneratorExit exception in urllib3
- # and it will close http connection with TCP Reset
- return None
- else:
- return line.decode("utf-8")
- return None
-
-
-def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
- for line in rbody:
- _line = parse_stream_helper(line)
- if _line is not None:
- yield _line
-
-
-async def parse_stream_async(rbody: aiohttp.StreamReader):
- async for line in rbody:
- _line = parse_stream_helper(line)
- if _line is not None:
- yield _line
-
-
-class APIRequestor:
- def __init__(
- self,
- key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- ):
- self.api_base = api_base or openai.api_base
- self.api_key = key or util.default_api_key()
- self.api_type = (
- ApiType.from_str(api_type)
- if api_type
- else ApiType.from_str(openai.api_type)
- )
- self.api_version = api_version or openai.api_version
- self.organization = organization or openai.organization
-
- @classmethod
- def format_app_info(cls, info):
- str = info["name"]
- if info["version"]:
- str += "/%s" % (info["version"],)
- if info["url"]:
- str += " (%s)" % (info["url"],)
- return str
-
- def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]):
- if not predicate(response):
- return
- error_data = response.data['error']
- message = error_data.get('message', 'Operation failed')
- code = error_data.get('code')
- raise error.OpenAIError(message=message, code=code)
-
- def _poll(
- self,
- method,
- url,
- until,
- failed,
- params = None,
- headers = None,
- interval = None,
- delay = None
- ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
- if delay:
- time.sleep(delay)
-
- response, b, api_key = self.request(method, url, params, headers)
- self._check_polling_response(response, failed)
- start_time = time.time()
- while not until(response):
- if time.time() - start_time > TIMEOUT_SECS:
- raise error.Timeout("Operation polling timed out.")
-
- time.sleep(interval or response.retry_after or 10)
- response, b, api_key = self.request(method, url, params, headers)
- self._check_polling_response(response, failed)
-
- response.data = response.data['result']
- return response, b, api_key
-
- async def _apoll(
- self,
- method,
- url,
- until,
- failed,
- params = None,
- headers = None,
- interval = None,
- delay = None
- ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
- if delay:
- await asyncio.sleep(delay)
-
- response, b, api_key = await self.arequest(method, url, params, headers)
- self._check_polling_response(response, failed)
- start_time = time.time()
- while not until(response):
- if time.time() - start_time > TIMEOUT_SECS:
- raise error.Timeout("Operation polling timed out.")
-
- await asyncio.sleep(interval or response.retry_after or 10)
- response, b, api_key = await self.arequest(method, url, params, headers)
- self._check_polling_response(response, failed)
-
- response.data = response.data['result']
- return response, b, api_key
-
- @overload
- def request(
- self,
- method,
- url,
- params,
- headers,
- files,
- stream: Literal[True],
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
- pass
-
- @overload
- def request(
- self,
- method,
- url,
- params=...,
- headers=...,
- files=...,
- *,
- stream: Literal[True],
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
- pass
-
- @overload
- def request(
- self,
- method,
- url,
- params=...,
- headers=...,
- files=...,
- stream: Literal[False] = ...,
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[OpenAIResponse, bool, str]:
- pass
-
- @overload
- def request(
- self,
- method,
- url,
- params=...,
- headers=...,
- files=...,
- stream: bool = ...,
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
- pass
-
- def request(
- self,
- method,
- url,
- params=None,
- headers=None,
- files=None,
- stream: bool = False,
- request_id: Optional[str] = None,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
- ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
- result = self.request_raw(
- method.lower(),
- url,
- params=params,
- supplied_headers=headers,
- files=files,
- stream=stream,
- request_id=request_id,
- request_timeout=request_timeout,
- )
- resp, got_stream = self._interpret_response(result, stream)
- return resp, got_stream, self.api_key
-
- @overload
- async def arequest(
- self,
- method,
- url,
- params,
- headers,
- files,
- stream: Literal[True],
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
- pass
-
- @overload
- async def arequest(
- self,
- method,
- url,
- params=...,
- headers=...,
- files=...,
- *,
- stream: Literal[True],
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
- pass
-
- @overload
- async def arequest(
- self,
- method,
- url,
- params=...,
- headers=...,
- files=...,
- stream: Literal[False] = ...,
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[OpenAIResponse, bool, str]:
- pass
-
- @overload
- async def arequest(
- self,
- method,
- url,
- params=...,
- headers=...,
- files=...,
- stream: bool = ...,
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
- pass
-
- async def arequest(
- self,
- method,
- url,
- params=None,
- headers=None,
- files=None,
- stream: bool = False,
- request_id: Optional[str] = None,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
- ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
- ctx = AioHTTPSession()
- session = await ctx.__aenter__()
- result = None
- try:
- result = await self.arequest_raw(
- method.lower(),
- url,
- session,
- params=params,
- supplied_headers=headers,
- files=files,
- request_id=request_id,
- request_timeout=request_timeout,
- )
- resp, got_stream = await self._interpret_async_response(result, stream)
- except Exception:
- # Close the request before exiting session context.
- if result is not None:
- result.release()
- await ctx.__aexit__(None, None, None)
- raise
- if got_stream:
-
- async def wrap_resp():
- assert isinstance(resp, AsyncGenerator)
- try:
- async for r in resp:
- yield r
- finally:
- # Close the request before exiting session context. Important to do it here
- # as if stream is not fully exhausted, we need to close the request nevertheless.
- result.release()
- await ctx.__aexit__(None, None, None)
-
- return wrap_resp(), got_stream, self.api_key
- else:
- # Close the request before exiting session context.
- result.release()
- await ctx.__aexit__(None, None, None)
- return resp, got_stream, self.api_key
-
- def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
- try:
- error_data = resp["error"]
- except (KeyError, TypeError):
- raise error.APIError(
- "Invalid response object from API: %r (HTTP response code "
- "was %d)" % (rbody, rcode),
- rbody,
- rcode,
- resp,
- )
-
- if "internal_message" in error_data:
- error_data["message"] += "\n\n" + error_data["internal_message"]
-
- util.log_info(
- "OpenAI API error received",
- error_code=error_data.get("code"),
- error_type=error_data.get("type"),
- error_message=error_data.get("message"),
- error_param=error_data.get("param"),
- stream_error=stream_error,
- )
-
- # Rate limits were previously coded as 400's with code 'rate_limit'
- if rcode == 429:
- return error.RateLimitError(
- error_data.get("message"), rbody, rcode, resp, rheaders
- )
- elif rcode in [400, 404, 415]:
- return error.InvalidRequestError(
- error_data.get("message"),
- error_data.get("param"),
- error_data.get("code"),
- rbody,
- rcode,
- resp,
- rheaders,
- )
- elif rcode == 401:
- return error.AuthenticationError(
- error_data.get("message"), rbody, rcode, resp, rheaders
- )
- elif rcode == 403:
- return error.PermissionError(
- error_data.get("message"), rbody, rcode, resp, rheaders
- )
- elif rcode == 409:
- return error.TryAgain(
- error_data.get("message"), rbody, rcode, resp, rheaders
- )
- elif stream_error:
- # TODO: we will soon attach status codes to stream errors
- parts = [error_data.get("message"), "(Error occurred while streaming.)"]
- message = " ".join([p for p in parts if p is not None])
- return error.APIError(message, rbody, rcode, resp, rheaders)
- else:
- return error.APIError(
- f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}",
- rbody,
- rcode,
- resp,
- rheaders,
- )
-
- def request_headers(
- self, method: str, extra, request_id: Optional[str]
- ) -> Dict[str, str]:
- user_agent = "OpenAI/v1 PythonBindings/%s" % (version.VERSION,)
- if openai.app_info:
- user_agent += " " + self.format_app_info(openai.app_info)
-
- uname_without_node = " ".join(
- v for k, v in platform.uname()._asdict().items() if k != "node"
- )
- ua = {
- "bindings_version": version.VERSION,
- "httplib": "requests",
- "lang": "python",
- "lang_version": platform.python_version(),
- "platform": platform.platform(),
- "publisher": "openai",
- "uname": uname_without_node,
- }
- if openai.app_info:
- ua["application"] = openai.app_info
-
- headers = {
- "X-OpenAI-Client-User-Agent": json.dumps(ua),
- "User-Agent": user_agent,
- }
-
- headers.update(util.api_key_to_header(self.api_type, self.api_key))
-
- if self.organization:
- headers["OpenAI-Organization"] = self.organization
-
- if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
- headers["OpenAI-Version"] = self.api_version
- if request_id is not None:
- headers["X-Request-Id"] = request_id
- if openai.debug:
- headers["OpenAI-Debug"] = "true"
- headers.update(extra)
-
- return headers
-
- def _validate_headers(
- self, supplied_headers: Optional[Dict[str, str]]
- ) -> Dict[str, str]:
- headers: Dict[str, str] = {}
- if supplied_headers is None:
- return headers
-
- if not isinstance(supplied_headers, dict):
- raise TypeError("Headers must be a dictionary")
-
- for k, v in supplied_headers.items():
- if not isinstance(k, str):
- raise TypeError("Header keys must be strings")
- if not isinstance(v, str):
- raise TypeError("Header values must be strings")
- headers[k] = v
-
- # NOTE: It is possible to do more validation of the headers, but a request could always
- # be made to the API manually with invalid headers, so we need to handle them server side.
-
- return headers
-
- def _prepare_request_raw(
- self,
- url,
- supplied_headers,
- method,
- params,
- files,
- request_id: Optional[str],
- ) -> Tuple[str, Dict[str, str], Optional[bytes]]:
- abs_url = "%s%s" % (self.api_base, url)
- headers = self._validate_headers(supplied_headers)
-
- data = None
- if method == "get" or method == "delete":
- if params:
- encoded_params = urlencode(
- [(k, v) for k, v in params.items() if v is not None]
- )
- abs_url = _build_api_url(abs_url, encoded_params)
- elif method in {"post", "put"}:
- if params and files:
- data = params
- if params and not files:
- data = json.dumps(params).encode()
- headers["Content-Type"] = "application/json"
- else:
- raise error.APIConnectionError(
- "Unrecognized HTTP method %r. This may indicate a bug in the "
- "OpenAI bindings. Please contact us through our help center at help.openai.com for "
- "assistance." % (method,)
- )
-
- headers = self.request_headers(method, headers, request_id)
-
- util.log_debug("Request to OpenAI API", method=method, path=abs_url)
- util.log_debug("Post details", data=data, api_version=self.api_version)
-
- return abs_url, headers, data
-
- def request_raw(
- self,
- method,
- url,
- *,
- params=None,
- supplied_headers: Optional[Dict[str, str]] = None,
- files=None,
- stream: bool = False,
- request_id: Optional[str] = None,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
- ) -> requests.Response:
- abs_url, headers, data = self._prepare_request_raw(
- url, supplied_headers, method, params, files, request_id
- )
-
- if not hasattr(_thread_context, "session"):
- _thread_context.session = _make_session()
- _thread_context.session_create_time = time.time()
- elif (
- time.time() - getattr(_thread_context, "session_create_time", 0)
- >= MAX_SESSION_LIFETIME_SECS
- ):
- _thread_context.session.close()
- _thread_context.session = _make_session()
- _thread_context.session_create_time = time.time()
- try:
- result = _thread_context.session.request(
- method,
- abs_url,
- headers=headers,
- data=data,
- files=files,
- stream=stream,
- timeout=request_timeout if request_timeout else TIMEOUT_SECS,
- proxies=_thread_context.session.proxies,
- )
- except requests.exceptions.Timeout as e:
- raise error.Timeout("Request timed out: {}".format(e)) from e
- except requests.exceptions.RequestException as e:
- raise error.APIConnectionError(
- "Error communicating with OpenAI: {}".format(e)
- ) from e
- util.log_debug(
- "OpenAI API response",
- path=abs_url,
- response_code=result.status_code,
- processing_ms=result.headers.get("OpenAI-Processing-Ms"),
- request_id=result.headers.get("X-Request-Id"),
- )
- # Don't read the whole stream for debug logging unless necessary.
- if openai.log == "debug":
- util.log_debug(
- "API response body", body=result.content, headers=result.headers
- )
- return result
-
- async def arequest_raw(
- self,
- method,
- url,
- session,
- *,
- params=None,
- supplied_headers: Optional[Dict[str, str]] = None,
- files=None,
- request_id: Optional[str] = None,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
- ) -> aiohttp.ClientResponse:
- abs_url, headers, data = self._prepare_request_raw(
- url, supplied_headers, method, params, files, request_id
- )
-
- if isinstance(request_timeout, tuple):
- timeout = aiohttp.ClientTimeout(
- connect=request_timeout[0],
- total=request_timeout[1],
- )
- else:
- timeout = aiohttp.ClientTimeout(
- total=request_timeout if request_timeout else TIMEOUT_SECS
- )
-
- if files:
- # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
- # For now we use the private `requests` method that is known to have worked so far.
- data, content_type = requests.models.RequestEncodingMixin._encode_files( # type: ignore
- files, data
- )
- headers["Content-Type"] = content_type
- request_kwargs = {
- "method": method,
- "url": abs_url,
- "headers": headers,
- "data": data,
- "proxy": _aiohttp_proxies_arg(openai.proxy),
- "timeout": timeout,
- }
- try:
- result = await session.request(**request_kwargs)
- util.log_info(
- "OpenAI API response",
- path=abs_url,
- response_code=result.status,
- processing_ms=result.headers.get("OpenAI-Processing-Ms"),
- request_id=result.headers.get("X-Request-Id"),
- )
- # Don't read the whole stream for debug logging unless necessary.
- if openai.log == "debug":
- util.log_debug(
- "API response body", body=result.content, headers=result.headers
- )
- return result
- except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
- raise error.Timeout("Request timed out") from e
- except aiohttp.ClientError as e:
- raise error.APIConnectionError("Error communicating with OpenAI") from e
-
- def _interpret_response(
- self, result: requests.Response, stream: bool
- ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
- """Returns the response(s) and a bool indicating whether it is a stream."""
- if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
- return (
- self._interpret_response_line(
- line, result.status_code, result.headers, stream=True
- )
- for line in parse_stream(result.iter_lines())
- ), True
- else:
- return (
- self._interpret_response_line(
- result.content.decode("utf-8"),
- result.status_code,
- result.headers,
- stream=False,
- ),
- False,
- )
-
- async def _interpret_async_response(
- self, result: aiohttp.ClientResponse, stream: bool
- ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
- """Returns the response(s) and a bool indicating whether it is a stream."""
- if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
- return (
- self._interpret_response_line(
- line, result.status, result.headers, stream=True
- )
- async for line in parse_stream_async(result.content)
- ), True
- else:
- try:
- await result.read()
- except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
- raise error.Timeout("Request timed out") from e
- except aiohttp.ClientError as e:
- util.log_warn(e, body=result.content)
- return (
- self._interpret_response_line(
- (await result.read()).decode("utf-8"),
- result.status,
- result.headers,
- stream=False,
- ),
- False,
- )
-
- def _interpret_response_line(
- self, rbody: str, rcode: int, rheaders, stream: bool
- ) -> OpenAIResponse:
- # HTTP 204 response code does not have any content in the body.
- if rcode == 204:
- return OpenAIResponse(None, rheaders)
-
- if rcode == 503:
- raise error.ServiceUnavailableError(
- "The server is overloaded or not ready yet.",
- rbody,
- rcode,
- headers=rheaders,
- )
- try:
- if 'text/plain' in rheaders.get('Content-Type', ''):
- data = rbody
- else:
- data = json.loads(rbody)
- except (JSONDecodeError, UnicodeDecodeError) as e:
- raise error.APIError(
- f"HTTP code {rcode} from API ({rbody})", rbody, rcode, headers=rheaders
- ) from e
- resp = OpenAIResponse(data, rheaders)
- # In the future, we might add a "status" parameter to errors
- # to better handle the "error while streaming" case.
- stream_error = stream and "error" in resp.data
- if stream_error or not 200 <= rcode < 300:
- raise self.handle_error_response(
- rbody, rcode, resp.data, rheaders, stream_error=stream_error
- )
- return resp
-
-
-class AioHTTPSession(AsyncContextManager):
- def __init__(self):
- self._session = None
- self._should_close_session = False
-
- async def __aenter__(self):
- self._session = openai.aiosession.get()
- if self._session is None:
- self._session = await aiohttp.ClientSession().__aenter__()
- self._should_close_session = True
-
- return self._session
-
- async def __aexit__(self, exc_type, exc_value, traceback):
- if self._session is None:
- raise RuntimeError("Session is not initialized")
-
- if self._should_close_session:
- await self._session.__aexit__(exc_type, exc_value, traceback)
\ No newline at end of file
diff --git a/openai/api_resources/__init__.py b/openai/api_resources/__init__.py
deleted file mode 100644
index 78bad1a22a..0000000000
--- a/openai/api_resources/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from openai.api_resources.audio import Audio # noqa: F401
-from openai.api_resources.chat_completion import ChatCompletion # noqa: F401
-from openai.api_resources.completion import Completion # noqa: F401
-from openai.api_resources.customer import Customer # noqa: F401
-from openai.api_resources.deployment import Deployment # noqa: F401
-from openai.api_resources.edit import Edit # noqa: F401
-from openai.api_resources.embedding import Embedding # noqa: F401
-from openai.api_resources.engine import Engine # noqa: F401
-from openai.api_resources.error_object import ErrorObject # noqa: F401
-from openai.api_resources.file import File # noqa: F401
-from openai.api_resources.fine_tune import FineTune # noqa: F401
-from openai.api_resources.fine_tuning import FineTuningJob # noqa: F401
-from openai.api_resources.image import Image # noqa: F401
-from openai.api_resources.model import Model # noqa: F401
-from openai.api_resources.moderation import Moderation # noqa: F401
diff --git a/openai/api_resources/abstract/__init__.py b/openai/api_resources/abstract/__init__.py
deleted file mode 100644
index 48482bd87a..0000000000
--- a/openai/api_resources/abstract/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# flake8: noqa
-
-from openai.api_resources.abstract.api_resource import APIResource
-from openai.api_resources.abstract.createable_api_resource import CreateableAPIResource
-from openai.api_resources.abstract.deletable_api_resource import DeletableAPIResource
-from openai.api_resources.abstract.listable_api_resource import ListableAPIResource
-from openai.api_resources.abstract.nested_resource_class_methods import (
- nested_resource_class_methods,
-)
-from openai.api_resources.abstract.paginatable_api_resource import (
- PaginatableAPIResource,
-)
-from openai.api_resources.abstract.updateable_api_resource import UpdateableAPIResource
diff --git a/openai/api_resources/abstract/api_resource.py b/openai/api_resources/abstract/api_resource.py
deleted file mode 100644
index 5d54bb9fd8..0000000000
--- a/openai/api_resources/abstract/api_resource.py
+++ /dev/null
@@ -1,172 +0,0 @@
-from urllib.parse import quote_plus
-
-import openai
-from openai import api_requestor, error, util
-from openai.openai_object import OpenAIObject
-from openai.util import ApiType
-from typing import Optional
-
-
-class APIResource(OpenAIObject):
- api_prefix = ""
- azure_api_prefix = "openai"
- azure_deployments_prefix = "deployments"
-
- @classmethod
- def retrieve(
- cls, id, api_key=None, request_id=None, request_timeout=None, **params
- ):
- instance = cls(id=id, api_key=api_key, **params)
- instance.refresh(request_id=request_id, request_timeout=request_timeout)
- return instance
-
- @classmethod
- def aretrieve(
- cls, id, api_key=None, request_id=None, request_timeout=None, **params
- ):
- instance = cls(id=id, api_key=api_key, **params)
- return instance.arefresh(request_id=request_id, request_timeout=request_timeout)
-
- def refresh(self, request_id=None, request_timeout=None):
- self.refresh_from(
- self.request(
- "get",
- self.instance_url(),
- request_id=request_id,
- request_timeout=request_timeout,
- )
- )
- return self
-
- async def arefresh(self, request_id=None, request_timeout=None):
- self.refresh_from(
- await self.arequest(
- "get",
- self.instance_url(operation="refresh"),
- request_id=request_id,
- request_timeout=request_timeout,
- )
- )
- return self
-
- @classmethod
- def class_url(cls):
- if cls == APIResource:
- raise NotImplementedError(
- "APIResource is an abstract class. You should perform actions on its subclasses."
- )
- # Namespaces are separated in object names with periods (.) and in URLs
- # with forward slashes (/), so replace the former with the latter.
- base = cls.OBJECT_NAME.replace(".", "/") # type: ignore
- if cls.api_prefix:
- return "/%s/%s" % (cls.api_prefix, base)
- return "/%s" % (base)
-
- def instance_url(self, operation=None):
- id = self.get("id")
-
- if not isinstance(id, str):
- raise error.InvalidRequestError(
- "Could not determine which URL to request: %s instance "
- "has invalid ID: %r, %s. ID should be of type `str` (or"
- " `unicode`)" % (type(self).__name__, id, type(id)),
- "id",
- )
- api_version = self.api_version or openai.api_version
- extn = quote_plus(id)
-
- if self.typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- if not api_version:
- raise error.InvalidRequestError(
- "An API version is required for the Azure API type."
- )
-
- if not operation:
- base = self.class_url()
- return "/%s%s/%s?api-version=%s" % (
- self.azure_api_prefix,
- base,
- extn,
- api_version,
- )
-
- return "/%s/%s/%s/%s?api-version=%s" % (
- self.azure_api_prefix,
- self.azure_deployments_prefix,
- extn,
- operation,
- api_version,
- )
-
- elif self.typed_api_type == ApiType.OPEN_AI:
- base = self.class_url()
- return "%s/%s" % (base, extn)
-
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % self.api_type)
-
- # The `method_` and `url_` arguments are suffixed with an underscore to
- # avoid conflicting with actual request parameters in `params`.
- @classmethod
- def _static_request(
- cls,
- method_,
- url_,
- api_key=None,
- api_base=None,
- api_type=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_version=api_version,
- organization=organization,
- api_base=api_base,
- api_type=api_type,
- )
- response, _, api_key = requestor.request(
- method_, url_, params, request_id=request_id
- )
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- async def _astatic_request(
- cls,
- method_,
- url_,
- api_key=None,
- api_base=None,
- api_type=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_version=api_version,
- organization=organization,
- api_base=api_base,
- api_type=api_type,
- )
- response, _, api_key = await requestor.arequest(
- method_, url_, params, request_id=request_id
- )
- return response
-
- @classmethod
- def _get_api_type_and_version(
- cls, api_type: Optional[str] = None, api_version: Optional[str] = None
- ):
- typed_api_type = (
- ApiType.from_str(api_type)
- if api_type
- else ApiType.from_str(openai.api_type)
- )
- typed_api_version = api_version or openai.api_version
- return (typed_api_type, typed_api_version)
diff --git a/openai/api_resources/abstract/createable_api_resource.py b/openai/api_resources/abstract/createable_api_resource.py
deleted file mode 100644
index 1361c02627..0000000000
--- a/openai/api_resources/abstract/createable_api_resource.py
+++ /dev/null
@@ -1,98 +0,0 @@
-from openai import api_requestor, util, error
-from openai.api_resources.abstract.api_resource import APIResource
-from openai.util import ApiType
-
-
-class CreateableAPIResource(APIResource):
- plain_old_data = False
-
- @classmethod
- def __prepare_create_requestor(
- cls,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
- typed_api_type, api_version = cls._get_api_type_and_version(
- api_type, api_version
- )
-
- if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- base = cls.class_url()
- url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
- elif typed_api_type == ApiType.OPEN_AI:
- url = cls.class_url()
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % api_type)
- return requestor, url
-
- @classmethod
- def create(
- cls,
- api_key=None,
- api_base=None,
- api_type=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor, url = cls.__prepare_create_requestor(
- api_key,
- api_base,
- api_type,
- api_version,
- organization,
- )
-
- response, _, api_key = requestor.request(
- "post", url, params, request_id=request_id
- )
-
- return util.convert_to_openai_object(
- response,
- api_key,
- api_version,
- organization,
- plain_old_data=cls.plain_old_data,
- )
-
- @classmethod
- async def acreate(
- cls,
- api_key=None,
- api_base=None,
- api_type=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor, url = cls.__prepare_create_requestor(
- api_key,
- api_base,
- api_type,
- api_version,
- organization,
- )
-
- response, _, api_key = await requestor.arequest(
- "post", url, params, request_id=request_id
- )
-
- return util.convert_to_openai_object(
- response,
- api_key,
- api_version,
- organization,
- plain_old_data=cls.plain_old_data,
- )
diff --git a/openai/api_resources/abstract/deletable_api_resource.py b/openai/api_resources/abstract/deletable_api_resource.py
deleted file mode 100644
index a800ceb812..0000000000
--- a/openai/api_resources/abstract/deletable_api_resource.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from urllib.parse import quote_plus
-from typing import Awaitable
-
-from openai import error
-from openai.api_resources.abstract.api_resource import APIResource
-from openai.util import ApiType
-
-
-class DeletableAPIResource(APIResource):
- @classmethod
- def __prepare_delete(cls, sid, api_type=None, api_version=None):
- if isinstance(cls, APIResource):
- raise ValueError(".delete may only be called as a class method now.")
-
- base = cls.class_url()
- extn = quote_plus(sid)
-
- typed_api_type, api_version = cls._get_api_type_and_version(
- api_type, api_version
- )
- if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- url = "/%s%s/%s?api-version=%s" % (
- cls.azure_api_prefix,
- base,
- extn,
- api_version,
- )
- elif typed_api_type == ApiType.OPEN_AI:
- url = "%s/%s" % (base, extn)
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % api_type)
- return url
-
- @classmethod
- def delete(cls, sid, api_type=None, api_version=None, **params):
- url = cls.__prepare_delete(sid, api_type, api_version)
-
- return cls._static_request(
- "delete", url, api_type=api_type, api_version=api_version, **params
- )
-
- @classmethod
- def adelete(cls, sid, api_type=None, api_version=None, **params) -> Awaitable:
- url = cls.__prepare_delete(sid, api_type, api_version)
-
- return cls._astatic_request(
- "delete", url, api_type=api_type, api_version=api_version, **params
- )
diff --git a/openai/api_resources/abstract/engine_api_resource.py b/openai/api_resources/abstract/engine_api_resource.py
deleted file mode 100644
index bbef90e23e..0000000000
--- a/openai/api_resources/abstract/engine_api_resource.py
+++ /dev/null
@@ -1,328 +0,0 @@
-import time
-from pydoc import apropos
-from typing import Optional
-from urllib.parse import quote_plus
-
-import openai
-from openai import api_requestor, error, util
-from openai.api_resources.abstract.api_resource import APIResource
-from openai.openai_response import OpenAIResponse
-from openai.util import ApiType
-
-MAX_TIMEOUT = 20
-
-
-class EngineAPIResource(APIResource):
- plain_old_data = False
-
- def __init__(self, engine: Optional[str] = None, **kwargs):
- super().__init__(engine=engine, **kwargs)
-
- @classmethod
- def class_url(
- cls,
- engine: Optional[str] = None,
- api_type: Optional[str] = None,
- api_version: Optional[str] = None,
- ):
- # Namespaces are separated in object names with periods (.) and in URLs
- # with forward slashes (/), so replace the former with the latter.
- base = cls.OBJECT_NAME.replace(".", "/") # type: ignore
- typed_api_type, api_version = cls._get_api_type_and_version(
- api_type, api_version
- )
-
- if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- if not api_version:
- raise error.InvalidRequestError(
- "An API version is required for the Azure API type.",
- "api_version"
- )
- if engine is None:
- raise error.InvalidRequestError(
- "You must provide the deployment name in the 'engine' parameter to access the Azure OpenAI service",
- "engine"
- )
- extn = quote_plus(engine)
- return "/%s/%s/%s/%s?api-version=%s" % (
- cls.azure_api_prefix,
- cls.azure_deployments_prefix,
- extn,
- base,
- api_version,
- )
-
- elif typed_api_type == ApiType.OPEN_AI:
- if engine is None:
- return "/%s" % (base)
-
- extn = quote_plus(engine)
- return "/engines/%s/%s" % (extn, base)
-
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % api_type)
-
- @classmethod
- def __prepare_create_request(
- cls,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- **params,
- ):
- deployment_id = params.pop("deployment_id", None)
- engine = params.pop("engine", deployment_id)
- model = params.get("model", None)
- timeout = params.pop("timeout", None)
- stream = params.get("stream", False)
- headers = params.pop("headers", None)
- request_timeout = params.pop("request_timeout", None)
- typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0]
- if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- if deployment_id is None and engine is None:
- raise error.InvalidRequestError(
- "Must provide an 'engine' or 'deployment_id' parameter to create a %s"
- % cls,
- "engine",
- )
- else:
- if model is None and engine is None:
- raise error.InvalidRequestError(
- "Must provide an 'engine' or 'model' parameter to create a %s"
- % cls,
- "engine",
- )
-
- if timeout is None:
- # No special timeout handling
- pass
- elif timeout > 0:
- # API only supports timeouts up to MAX_TIMEOUT
- params["timeout"] = min(timeout, MAX_TIMEOUT)
- timeout = (timeout - params["timeout"]) or None
- elif timeout == 0:
- params["timeout"] = MAX_TIMEOUT
-
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
- url = cls.class_url(engine, api_type, api_version)
- return (
- deployment_id,
- engine,
- timeout,
- stream,
- headers,
- request_timeout,
- typed_api_type,
- requestor,
- url,
- params,
- )
-
- @classmethod
- def create(
- cls,
- api_key=None,
- api_base=None,
- api_type=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- (
- deployment_id,
- engine,
- timeout,
- stream,
- headers,
- request_timeout,
- typed_api_type,
- requestor,
- url,
- params,
- ) = cls.__prepare_create_request(
- api_key, api_base, api_type, api_version, organization, **params
- )
-
- response, _, api_key = requestor.request(
- "post",
- url,
- params=params,
- headers=headers,
- stream=stream,
- request_id=request_id,
- request_timeout=request_timeout,
- )
-
- if stream:
- # must be an iterator
- assert not isinstance(response, OpenAIResponse)
- return (
- util.convert_to_openai_object(
- line,
- api_key,
- api_version,
- organization,
- engine=engine,
- plain_old_data=cls.plain_old_data,
- )
- for line in response
- )
- else:
- obj = util.convert_to_openai_object(
- response,
- api_key,
- api_version,
- organization,
- engine=engine,
- plain_old_data=cls.plain_old_data,
- )
-
- if timeout is not None:
- obj.wait(timeout=timeout or None)
-
- return obj
-
- @classmethod
- async def acreate(
- cls,
- api_key=None,
- api_base=None,
- api_type=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- (
- deployment_id,
- engine,
- timeout,
- stream,
- headers,
- request_timeout,
- typed_api_type,
- requestor,
- url,
- params,
- ) = cls.__prepare_create_request(
- api_key, api_base, api_type, api_version, organization, **params
- )
- response, _, api_key = await requestor.arequest(
- "post",
- url,
- params=params,
- headers=headers,
- stream=stream,
- request_id=request_id,
- request_timeout=request_timeout,
- )
-
- if stream:
- # must be an iterator
- assert not isinstance(response, OpenAIResponse)
- return (
- util.convert_to_openai_object(
- line,
- api_key,
- api_version,
- organization,
- engine=engine,
- plain_old_data=cls.plain_old_data,
- )
- async for line in response
- )
- else:
- obj = util.convert_to_openai_object(
- response,
- api_key,
- api_version,
- organization,
- engine=engine,
- plain_old_data=cls.plain_old_data,
- )
-
- if timeout is not None:
- await obj.await_(timeout=timeout or None)
-
- return obj
-
- def instance_url(self):
- id = self.get("id")
-
- if not isinstance(id, str):
- raise error.InvalidRequestError(
- f"Could not determine which URL to request: {type(self).__name__} instance has invalid ID: {id}, {type(id)}. ID should be of type str.",
- "id",
- )
-
- extn = quote_plus(id)
- params_connector = "?"
-
- if self.typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- api_version = self.api_version or openai.api_version
- if not api_version:
- raise error.InvalidRequestError(
- "An API version is required for the Azure API type.",
- "api_version"
- )
- base = self.OBJECT_NAME.replace(".", "/")
- url = "/%s/%s/%s/%s/%s?api-version=%s" % (
- self.azure_api_prefix,
- self.azure_deployments_prefix,
- self.engine,
- base,
- extn,
- api_version,
- )
- params_connector = "&"
-
- elif self.typed_api_type == ApiType.OPEN_AI:
- base = self.class_url(self.engine, self.api_type, self.api_version)
- url = "%s/%s" % (base, extn)
-
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % self.api_type)
-
- timeout = self.get("timeout")
- if timeout is not None:
- timeout = quote_plus(str(timeout))
- url += params_connector + "timeout={}".format(timeout)
- return url
-
- def wait(self, timeout=None):
- start = time.time()
- while self.status != "complete":
- self.timeout = (
- min(timeout + start - time.time(), MAX_TIMEOUT)
- if timeout is not None
- else MAX_TIMEOUT
- )
- if self.timeout < 0:
- del self.timeout
- break
- self.refresh()
- return self
-
- async def await_(self, timeout=None):
- """Async version of `EngineApiResource.wait`"""
- start = time.time()
- while self.status != "complete":
- self.timeout = (
- min(timeout + start - time.time(), MAX_TIMEOUT)
- if timeout is not None
- else MAX_TIMEOUT
- )
- if self.timeout < 0:
- del self.timeout
- break
- await self.arefresh()
- return self
diff --git a/openai/api_resources/abstract/listable_api_resource.py b/openai/api_resources/abstract/listable_api_resource.py
deleted file mode 100644
index 3e59979f13..0000000000
--- a/openai/api_resources/abstract/listable_api_resource.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from openai import api_requestor, util, error
-from openai.api_resources.abstract.api_resource import APIResource
-from openai.util import ApiType
-
-
-class ListableAPIResource(APIResource):
- @classmethod
- def auto_paging_iter(cls, *args, **params):
- return cls.list(*args, **params).auto_paging_iter()
-
- @classmethod
- def __prepare_list_requestor(
- cls,
- api_key=None,
- api_version=None,
- organization=None,
- api_base=None,
- api_type=None,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base or cls.api_base(),
- api_version=api_version,
- api_type=api_type,
- organization=organization,
- )
-
- typed_api_type, api_version = cls._get_api_type_and_version(
- api_type, api_version
- )
-
- if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- base = cls.class_url()
- url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
- elif typed_api_type == ApiType.OPEN_AI:
- url = cls.class_url()
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % api_type)
- return requestor, url
-
- @classmethod
- def list(
- cls,
- api_key=None,
- request_id=None,
- api_version=None,
- organization=None,
- api_base=None,
- api_type=None,
- **params,
- ):
- requestor, url = cls.__prepare_list_requestor(
- api_key,
- api_version,
- organization,
- api_base,
- api_type,
- )
-
- response, _, api_key = requestor.request(
- "get", url, params, request_id=request_id
- )
- openai_object = util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
- openai_object._retrieve_params = params
- return openai_object
-
- @classmethod
- async def alist(
- cls,
- api_key=None,
- request_id=None,
- api_version=None,
- organization=None,
- api_base=None,
- api_type=None,
- **params,
- ):
- requestor, url = cls.__prepare_list_requestor(
- api_key,
- api_version,
- organization,
- api_base,
- api_type,
- )
-
- response, _, api_key = await requestor.arequest(
- "get", url, params, request_id=request_id
- )
- openai_object = util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
- openai_object._retrieve_params = params
- return openai_object
diff --git a/openai/api_resources/abstract/nested_resource_class_methods.py b/openai/api_resources/abstract/nested_resource_class_methods.py
deleted file mode 100644
index 68197ab1fa..0000000000
--- a/openai/api_resources/abstract/nested_resource_class_methods.py
+++ /dev/null
@@ -1,169 +0,0 @@
-from urllib.parse import quote_plus
-
-from openai import api_requestor, util
-
-
-def _nested_resource_class_methods(
- resource,
- path=None,
- operations=None,
- resource_plural=None,
- async_=False,
-):
- if resource_plural is None:
- resource_plural = "%ss" % resource
- if path is None:
- path = resource_plural
- if operations is None:
- raise ValueError("operations list required")
-
- def wrapper(cls):
- def nested_resource_url(cls, id, nested_id=None):
- url = "%s/%s/%s" % (cls.class_url(), quote_plus(id), quote_plus(path))
- if nested_id is not None:
- url += "/%s" % quote_plus(nested_id)
- return url
-
- resource_url_method = "%ss_url" % resource
- setattr(cls, resource_url_method, classmethod(nested_resource_url))
-
- def nested_resource_request(
- cls,
- method,
- url,
- api_base=None,
- api_key=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor = api_requestor.APIRequestor(
- api_key, api_base=api_base, api_version=api_version, organization=organization
- )
- response, _, api_key = requestor.request(
- method, url, params, request_id=request_id
- )
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- async def anested_resource_request(
- cls,
- method,
- url,
- api_key=None,
- api_base=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor = api_requestor.APIRequestor(
- api_key, api_base=api_base, api_version=api_version, organization=organization
- )
- response, _, api_key = await requestor.arequest(
- method, url, params, request_id=request_id
- )
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- resource_request_method = "%ss_request" % resource
- setattr(
- cls,
- resource_request_method,
- classmethod(
- anested_resource_request if async_ else nested_resource_request
- ),
- )
-
- for operation in operations:
- if operation == "create":
-
- def create_nested_resource(cls, id, **params):
- url = getattr(cls, resource_url_method)(id)
- return getattr(cls, resource_request_method)("post", url, **params)
-
- create_method = "create_%s" % resource
- setattr(cls, create_method, classmethod(create_nested_resource))
-
- elif operation == "retrieve":
-
- def retrieve_nested_resource(cls, id, nested_id, **params):
- url = getattr(cls, resource_url_method)(id, nested_id)
- return getattr(cls, resource_request_method)("get", url, **params)
-
- retrieve_method = "retrieve_%s" % resource
- setattr(cls, retrieve_method, classmethod(retrieve_nested_resource))
-
- elif operation == "update":
-
- def modify_nested_resource(cls, id, nested_id, **params):
- url = getattr(cls, resource_url_method)(id, nested_id)
- return getattr(cls, resource_request_method)("post", url, **params)
-
- modify_method = "modify_%s" % resource
- setattr(cls, modify_method, classmethod(modify_nested_resource))
-
- elif operation == "delete":
-
- def delete_nested_resource(cls, id, nested_id, **params):
- url = getattr(cls, resource_url_method)(id, nested_id)
- return getattr(cls, resource_request_method)(
- "delete", url, **params
- )
-
- delete_method = "delete_%s" % resource
- setattr(cls, delete_method, classmethod(delete_nested_resource))
-
- elif operation == "list":
-
- def list_nested_resources(cls, id, **params):
- url = getattr(cls, resource_url_method)(id)
- return getattr(cls, resource_request_method)("get", url, **params)
-
- list_method = "list_%s" % resource_plural
- setattr(cls, list_method, classmethod(list_nested_resources))
-
- elif operation == "paginated_list":
-
- def paginated_list_nested_resources(
- cls, id, limit=None, after=None, **params
- ):
- url = getattr(cls, resource_url_method)(id)
- return getattr(cls, resource_request_method)(
- "get", url, limit=limit, after=after, **params
- )
-
- list_method = "list_%s" % resource_plural
- setattr(cls, list_method, classmethod(paginated_list_nested_resources))
-
- else:
- raise ValueError("Unknown operation: %s" % operation)
-
- return cls
-
- return wrapper
-
-
-def nested_resource_class_methods(
- resource,
- path=None,
- operations=None,
- resource_plural=None,
-):
- return _nested_resource_class_methods(
- resource, path, operations, resource_plural, async_=False
- )
-
-
-def anested_resource_class_methods(
- resource,
- path=None,
- operations=None,
- resource_plural=None,
-):
- return _nested_resource_class_methods(
- resource, path, operations, resource_plural, async_=True
- )
diff --git a/openai/api_resources/abstract/paginatable_api_resource.py b/openai/api_resources/abstract/paginatable_api_resource.py
deleted file mode 100644
index 2d75744f23..0000000000
--- a/openai/api_resources/abstract/paginatable_api_resource.py
+++ /dev/null
@@ -1,125 +0,0 @@
-from openai import api_requestor, error, util
-from openai.api_resources.abstract.listable_api_resource import ListableAPIResource
-from openai.util import ApiType
-
-
-class PaginatableAPIResource(ListableAPIResource):
- @classmethod
- def auto_paging_iter(cls, *args, **params):
- next_cursor = None
- has_more = True
- if not params.get("limit"):
- params["limit"] = 20
- while has_more:
- if next_cursor:
- params["after"] = next_cursor
- response = cls.list(*args, **params)
-
- for item in response.data:
- yield item
-
- if response.data:
- next_cursor = response.data[-1].id
- has_more = response.has_more
-
- @classmethod
- def __prepare_list_requestor(
- cls,
- api_key=None,
- api_version=None,
- organization=None,
- api_base=None,
- api_type=None,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base or cls.api_base(),
- api_version=api_version,
- api_type=api_type,
- organization=organization,
- )
-
- typed_api_type, api_version = cls._get_api_type_and_version(
- api_type, api_version
- )
-
- if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- base = cls.class_url()
- url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
- elif typed_api_type == ApiType.OPEN_AI:
- url = cls.class_url()
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % api_type)
- return requestor, url
-
- @classmethod
- def list(
- cls,
- limit=None,
- starting_after=None,
- api_key=None,
- request_id=None,
- api_version=None,
- organization=None,
- api_base=None,
- api_type=None,
- **params,
- ):
- requestor, url = cls.__prepare_list_requestor(
- api_key,
- api_version,
- organization,
- api_base,
- api_type,
- )
-
- params = {
- **params,
- "limit": limit,
- "starting_after": starting_after,
- }
-
- response, _, api_key = requestor.request(
- "get", url, params, request_id=request_id
- )
- openai_object = util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
- openai_object._retrieve_params = params
- return openai_object
-
- @classmethod
- async def alist(
- cls,
- limit=None,
- starting_after=None,
- api_key=None,
- request_id=None,
- api_version=None,
- organization=None,
- api_base=None,
- api_type=None,
- **params,
- ):
- requestor, url = cls.__prepare_list_requestor(
- api_key,
- api_version,
- organization,
- api_base,
- api_type,
- )
-
- params = {
- **params,
- "limit": limit,
- "starting_after": starting_after,
- }
-
- response, _, api_key = await requestor.arequest(
- "get", url, params, request_id=request_id
- )
- openai_object = util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
- openai_object._retrieve_params = params
- return openai_object
diff --git a/openai/api_resources/abstract/updateable_api_resource.py b/openai/api_resources/abstract/updateable_api_resource.py
deleted file mode 100644
index 245f9b80b3..0000000000
--- a/openai/api_resources/abstract/updateable_api_resource.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from urllib.parse import quote_plus
-from typing import Awaitable
-
-from openai.api_resources.abstract.api_resource import APIResource
-
-
-class UpdateableAPIResource(APIResource):
- @classmethod
- def modify(cls, sid, **params):
- url = "%s/%s" % (cls.class_url(), quote_plus(sid))
- return cls._static_request("post", url, **params)
-
- @classmethod
- def amodify(cls, sid, **params) -> Awaitable:
- url = "%s/%s" % (cls.class_url(), quote_plus(sid))
- return cls._astatic_request("patch", url, **params)
diff --git a/openai/api_resources/audio.py b/openai/api_resources/audio.py
deleted file mode 100644
index cb316f07f1..0000000000
--- a/openai/api_resources/audio.py
+++ /dev/null
@@ -1,311 +0,0 @@
-from typing import Any, List
-
-import openai
-from openai import api_requestor, util
-from openai.api_resources.abstract import APIResource
-
-
-class Audio(APIResource):
- OBJECT_NAME = "audio"
-
- @classmethod
- def _get_url(cls, action, deployment_id=None, api_type=None, api_version=None):
- if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- return f"/{cls.azure_api_prefix}/deployments/{deployment_id}/audio/{action}?api-version={api_version}"
- return cls.class_url() + f"/{action}"
-
- @classmethod
- def _prepare_request(
- cls,
- file,
- filename,
- model,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base or openai.api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
- files: List[Any] = []
- data = {
- "model": model,
- **params,
- }
- files.append(("file", (filename, file, "application/octet-stream")))
- return requestor, files, data
-
- @classmethod
- def transcribe(
- cls,
- model,
- file,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- *,
- deployment_id=None,
- **params,
- ):
- requestor, files, data = cls._prepare_request(
- file=file,
- filename=file.name,
- model=model,
- api_key=api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- **params,
- )
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
- url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
- response, _, api_key = requestor.request("post", url, files=files, params=data)
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- def translate(
- cls,
- model,
- file,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- *,
- deployment_id=None,
- **params,
- ):
- requestor, files, data = cls._prepare_request(
- file=file,
- filename=file.name,
- model=model,
- api_key=api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- **params,
- )
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
- url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
- response, _, api_key = requestor.request("post", url, files=files, params=data)
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- def transcribe_raw(
- cls,
- model,
- file,
- filename,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- *,
- deployment_id=None,
- **params,
- ):
- requestor, files, data = cls._prepare_request(
- file=file,
- filename=filename,
- model=model,
- api_key=api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- **params,
- )
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
- url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
- response, _, api_key = requestor.request("post", url, files=files, params=data)
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- def translate_raw(
- cls,
- model,
- file,
- filename,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- *,
- deployment_id=None,
- **params,
- ):
- requestor, files, data = cls._prepare_request(
- file=file,
- filename=filename,
- model=model,
- api_key=api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- **params,
- )
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
- url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
- response, _, api_key = requestor.request("post", url, files=files, params=data)
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- async def atranscribe(
- cls,
- model,
- file,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- *,
- deployment_id=None,
- **params,
- ):
- requestor, files, data = cls._prepare_request(
- file=file,
- filename=file.name,
- model=model,
- api_key=api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- **params,
- )
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
- url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
- response, _, api_key = await requestor.arequest(
- "post", url, files=files, params=data
- )
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- async def atranslate(
- cls,
- model,
- file,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- *,
- deployment_id=None,
- **params,
- ):
- requestor, files, data = cls._prepare_request(
- file=file,
- filename=file.name,
- model=model,
- api_key=api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- **params,
- )
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
- url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
- response, _, api_key = await requestor.arequest(
- "post", url, files=files, params=data
- )
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- async def atranscribe_raw(
- cls,
- model,
- file,
- filename,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- *,
- deployment_id=None,
- **params,
- ):
- requestor, files, data = cls._prepare_request(
- file=file,
- filename=filename,
- model=model,
- api_key=api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- **params,
- )
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
- url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
- response, _, api_key = await requestor.arequest(
- "post", url, files=files, params=data
- )
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- async def atranslate_raw(
- cls,
- model,
- file,
- filename,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- *,
- deployment_id=None,
- **params,
- ):
- requestor, files, data = cls._prepare_request(
- file=file,
- filename=filename,
- model=model,
- api_key=api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- **params,
- )
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
- url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
- response, _, api_key = await requestor.arequest(
- "post", url, files=files, params=data
- )
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
diff --git a/openai/api_resources/chat_completion.py b/openai/api_resources/chat_completion.py
deleted file mode 100644
index 7e55f9e38f..0000000000
--- a/openai/api_resources/chat_completion.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import time
-
-from openai import util
-from openai.api_resources.abstract.engine_api_resource import EngineAPIResource
-from openai.error import TryAgain
-
-
-class ChatCompletion(EngineAPIResource):
- engine_required = False
- OBJECT_NAME = "chat.completions"
-
- @classmethod
- def create(cls, *args, **kwargs):
- """
- Creates a new chat completion for the provided messages and parameters.
-
- See https://platform.openai.com/docs/api-reference/chat/create
- for a list of valid parameters.
- """
- start = time.time()
- timeout = kwargs.pop("timeout", None)
-
- while True:
- try:
- return super().create(*args, **kwargs)
- except TryAgain as e:
- if timeout is not None and time.time() > start + timeout:
- raise
-
- util.log_info("Waiting for model to warm up", error=e)
-
- @classmethod
- async def acreate(cls, *args, **kwargs):
- """
- Creates a new chat completion for the provided messages and parameters.
-
- See https://platform.openai.com/docs/api-reference/chat/create
- for a list of valid parameters.
- """
- start = time.time()
- timeout = kwargs.pop("timeout", None)
-
- while True:
- try:
- return await super().acreate(*args, **kwargs)
- except TryAgain as e:
- if timeout is not None and time.time() > start + timeout:
- raise
-
- util.log_info("Waiting for model to warm up", error=e)
diff --git a/openai/api_resources/completion.py b/openai/api_resources/completion.py
deleted file mode 100644
index 7b9c44bd08..0000000000
--- a/openai/api_resources/completion.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import time
-
-from openai import util
-from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource
-from openai.api_resources.abstract.engine_api_resource import EngineAPIResource
-from openai.error import TryAgain
-
-
-class Completion(EngineAPIResource):
- OBJECT_NAME = "completions"
-
- @classmethod
- def create(cls, *args, **kwargs):
- """
- Creates a new completion for the provided prompt and parameters.
-
- See https://platform.openai.com/docs/api-reference/completions/create for a list
- of valid parameters.
- """
- start = time.time()
- timeout = kwargs.pop("timeout", None)
-
- while True:
- try:
- return super().create(*args, **kwargs)
- except TryAgain as e:
- if timeout is not None and time.time() > start + timeout:
- raise
-
- util.log_info("Waiting for model to warm up", error=e)
-
- @classmethod
- async def acreate(cls, *args, **kwargs):
- """
- Creates a new completion for the provided prompt and parameters.
-
- See https://platform.openai.com/docs/api-reference/completions/create for a list
- of valid parameters.
- """
- start = time.time()
- timeout = kwargs.pop("timeout", None)
-
- while True:
- try:
- return await super().acreate(*args, **kwargs)
- except TryAgain as e:
- if timeout is not None and time.time() > start + timeout:
- raise
-
- util.log_info("Waiting for model to warm up", error=e)
diff --git a/openai/api_resources/customer.py b/openai/api_resources/customer.py
deleted file mode 100644
index 8690d07b38..0000000000
--- a/openai/api_resources/customer.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from openai.openai_object import OpenAIObject
-
-
-class Customer(OpenAIObject):
- @classmethod
- def get_url(cls, customer, endpoint):
- return f"/customer/{customer}/{endpoint}"
-
- @classmethod
- def create(cls, customer, endpoint, **params):
- instance = cls()
- return instance.request("post", cls.get_url(customer, endpoint), params)
-
- @classmethod
- def acreate(cls, customer, endpoint, **params):
- instance = cls()
- return instance.arequest("post", cls.get_url(customer, endpoint), params)
diff --git a/openai/api_resources/deployment.py b/openai/api_resources/deployment.py
deleted file mode 100644
index 2f3fcd1307..0000000000
--- a/openai/api_resources/deployment.py
+++ /dev/null
@@ -1,119 +0,0 @@
-from openai import util
-from openai.api_resources.abstract import (
- DeletableAPIResource,
- ListableAPIResource,
- CreateableAPIResource,
-)
-from openai.error import InvalidRequestError, APIError
-
-
-class Deployment(CreateableAPIResource, ListableAPIResource, DeletableAPIResource):
- OBJECT_NAME = "deployments"
-
- @classmethod
- def _check_create(cls, *args, **kwargs):
- typed_api_type, _ = cls._get_api_type_and_version(
- kwargs.get("api_type", None), None
- )
- if typed_api_type not in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- raise APIError(
- "Deployment operations are only available for the Azure API type."
- )
-
- if kwargs.get("model", None) is None:
- raise InvalidRequestError(
- "Must provide a 'model' parameter to create a Deployment.",
- param="model",
- )
-
- scale_settings = kwargs.get("scale_settings", None)
- if scale_settings is None:
- raise InvalidRequestError(
- "Must provide a 'scale_settings' parameter to create a Deployment.",
- param="scale_settings",
- )
-
- if "scale_type" not in scale_settings or (
- scale_settings["scale_type"].lower() == "manual"
- and "capacity" not in scale_settings
- ):
- raise InvalidRequestError(
- "The 'scale_settings' parameter contains invalid or incomplete values.",
- param="scale_settings",
- )
-
- @classmethod
- def create(cls, *args, **kwargs):
- """
- Creates a new deployment for the provided prompt and parameters.
- """
- cls._check_create(*args, **kwargs)
- return super().create(*args, **kwargs)
-
- @classmethod
- def acreate(cls, *args, **kwargs):
- """
- Creates a new deployment for the provided prompt and parameters.
- """
- cls._check_create(*args, **kwargs)
- return super().acreate(*args, **kwargs)
-
- @classmethod
- def _check_list(cls, *args, **kwargs):
- typed_api_type, _ = cls._get_api_type_and_version(
- kwargs.get("api_type", None), None
- )
- if typed_api_type not in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- raise APIError(
- "Deployment operations are only available for the Azure API type."
- )
-
- @classmethod
- def list(cls, *args, **kwargs):
- cls._check_list(*args, **kwargs)
- return super().list(*args, **kwargs)
-
- @classmethod
- def alist(cls, *args, **kwargs):
- cls._check_list(*args, **kwargs)
- return super().alist(*args, **kwargs)
-
- @classmethod
- def _check_delete(cls, *args, **kwargs):
- typed_api_type, _ = cls._get_api_type_and_version(
- kwargs.get("api_type", None), None
- )
- if typed_api_type not in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- raise APIError(
- "Deployment operations are only available for the Azure API type."
- )
-
- @classmethod
- def delete(cls, *args, **kwargs):
- cls._check_delete(*args, **kwargs)
- return super().delete(*args, **kwargs)
-
- @classmethod
- def adelete(cls, *args, **kwargs):
- cls._check_delete(*args, **kwargs)
- return super().adelete(*args, **kwargs)
-
- @classmethod
- def _check_retrieve(cls, *args, **kwargs):
- typed_api_type, _ = cls._get_api_type_and_version(
- kwargs.get("api_type", None), None
- )
- if typed_api_type not in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- raise APIError(
- "Deployment operations are only available for the Azure API type."
- )
-
- @classmethod
- def retrieve(cls, *args, **kwargs):
- cls._check_retrieve(*args, **kwargs)
- return super().retrieve(*args, **kwargs)
-
- @classmethod
- def aretrieve(cls, *args, **kwargs):
- cls._check_retrieve(*args, **kwargs)
- return super().aretrieve(*args, **kwargs)
diff --git a/openai/api_resources/edit.py b/openai/api_resources/edit.py
deleted file mode 100644
index 985f062ddb..0000000000
--- a/openai/api_resources/edit.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import time
-
-from openai import util, error
-from openai.api_resources.abstract.engine_api_resource import EngineAPIResource
-from openai.error import TryAgain
-
-
-class Edit(EngineAPIResource):
- OBJECT_NAME = "edits"
-
- @classmethod
- def create(cls, *args, **kwargs):
- """
- Creates a new edit for the provided input, instruction, and parameters.
- """
- start = time.time()
- timeout = kwargs.pop("timeout", None)
-
- api_type = kwargs.pop("api_type", None)
- typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0]
- if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- raise error.InvalidAPIType(
- "This operation is not supported by the Azure OpenAI API yet."
- )
-
- while True:
- try:
- return super().create(*args, **kwargs)
- except TryAgain as e:
- if timeout is not None and time.time() > start + timeout:
- raise
-
- util.log_info("Waiting for model to warm up", error=e)
-
- @classmethod
- async def acreate(cls, *args, **kwargs):
- """
- Creates a new edit for the provided input, instruction, and parameters.
- """
- start = time.time()
- timeout = kwargs.pop("timeout", None)
-
- api_type = kwargs.pop("api_type", None)
- typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0]
- if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- raise error.InvalidAPIType(
- "This operation is not supported by the Azure OpenAI API yet."
- )
-
- while True:
- try:
- return await super().acreate(*args, **kwargs)
- except TryAgain as e:
- if timeout is not None and time.time() > start + timeout:
- raise
-
- util.log_info("Waiting for model to warm up", error=e)
diff --git a/openai/api_resources/embedding.py b/openai/api_resources/embedding.py
deleted file mode 100644
index e937636404..0000000000
--- a/openai/api_resources/embedding.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import base64
-import time
-
-from openai import util
-from openai.api_resources.abstract.engine_api_resource import EngineAPIResource
-from openai.datalib.numpy_helper import assert_has_numpy
-from openai.datalib.numpy_helper import numpy as np
-from openai.error import TryAgain
-
-
-class Embedding(EngineAPIResource):
- OBJECT_NAME = "embeddings"
-
- @classmethod
- def create(cls, *args, **kwargs):
- """
- Creates a new embedding for the provided input and parameters.
-
- See https://platform.openai.com/docs/api-reference/embeddings for a list
- of valid parameters.
- """
- start = time.time()
- timeout = kwargs.pop("timeout", None)
-
- user_provided_encoding_format = kwargs.get("encoding_format", None)
-
- # If encoding format was not explicitly specified, we opaquely use base64 for performance
- if not user_provided_encoding_format:
- kwargs["encoding_format"] = "base64"
-
- while True:
- try:
- response = super().create(*args, **kwargs)
-
- # If a user specifies base64, we'll just return the encoded string.
- # This is only for the default case.
- if not user_provided_encoding_format:
- for data in response.data:
-
- # If an engine isn't using this optimization, don't do anything
- if type(data["embedding"]) == str:
- assert_has_numpy()
- data["embedding"] = np.frombuffer(
- base64.b64decode(data["embedding"]), dtype="float32"
- ).tolist()
-
- return response
- except TryAgain as e:
- if timeout is not None and time.time() > start + timeout:
- raise
-
- util.log_info("Waiting for model to warm up", error=e)
-
- @classmethod
- async def acreate(cls, *args, **kwargs):
- """
- Creates a new embedding for the provided input and parameters.
-
- See https://platform.openai.com/docs/api-reference/embeddings for a list
- of valid parameters.
- """
- start = time.time()
- timeout = kwargs.pop("timeout", None)
-
- user_provided_encoding_format = kwargs.get("encoding_format", None)
-
- # If encoding format was not explicitly specified, we opaquely use base64 for performance
- if not user_provided_encoding_format:
- kwargs["encoding_format"] = "base64"
-
- while True:
- try:
- response = await super().acreate(*args, **kwargs)
-
- # If a user specifies base64, we'll just return the encoded string.
- # This is only for the default case.
- if not user_provided_encoding_format:
- for data in response.data:
-
- # If an engine isn't using this optimization, don't do anything
- if type(data["embedding"]) == str:
- data["embedding"] = np.frombuffer(
- base64.b64decode(data["embedding"]), dtype="float32"
- ).tolist()
-
- return response
- except TryAgain as e:
- if timeout is not None and time.time() > start + timeout:
- raise
-
- util.log_info("Waiting for model to warm up", error=e)
diff --git a/openai/api_resources/engine.py b/openai/api_resources/engine.py
deleted file mode 100644
index 5a0c467c2f..0000000000
--- a/openai/api_resources/engine.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import time
-import warnings
-
-from openai import util
-from openai.api_resources.abstract import ListableAPIResource, UpdateableAPIResource
-from openai.error import TryAgain
-
-
-class Engine(ListableAPIResource, UpdateableAPIResource):
- OBJECT_NAME = "engines"
-
- def generate(self, timeout=None, **params):
- start = time.time()
- while True:
- try:
- return self.request(
- "post",
- self.instance_url() + "/generate",
- params,
- stream=params.get("stream"),
- plain_old_data=True,
- )
- except TryAgain as e:
- if timeout is not None and time.time() > start + timeout:
- raise
-
- util.log_info("Waiting for model to warm up", error=e)
-
- async def agenerate(self, timeout=None, **params):
- start = time.time()
- while True:
- try:
- return await self.arequest(
- "post",
- self.instance_url() + "/generate",
- params,
- stream=params.get("stream"),
- plain_old_data=True,
- )
- except TryAgain as e:
- if timeout is not None and time.time() > start + timeout:
- raise
-
- util.log_info("Waiting for model to warm up", error=e)
-
- def embeddings(self, **params):
- warnings.warn(
- "Engine.embeddings is deprecated, use Embedding.create", DeprecationWarning
- )
- return self.request("post", self.instance_url() + "/embeddings", params)
diff --git a/openai/api_resources/error_object.py b/openai/api_resources/error_object.py
deleted file mode 100644
index 555dc35237..0000000000
--- a/openai/api_resources/error_object.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from typing import Optional
-
-from openai.openai_object import OpenAIObject
-from openai.util import merge_dicts
-
-
-class ErrorObject(OpenAIObject):
- def refresh_from(
- self,
- values,
- api_key=None,
- api_version=None,
- api_type=None,
- organization=None,
- response_ms: Optional[int] = None,
- ):
- # Unlike most other API resources, the API will omit attributes in
- # error objects when they have a null value. We manually set default
- # values here to facilitate generic error handling.
- values = merge_dicts({"message": None, "type": None}, values)
- return super(ErrorObject, self).refresh_from(
- values=values,
- api_key=api_key,
- api_version=api_version,
- api_type=api_type,
- organization=organization,
- response_ms=response_ms,
- )
diff --git a/openai/api_resources/experimental/__init__.py b/openai/api_resources/experimental/__init__.py
deleted file mode 100644
index d24c7b0cb5..0000000000
--- a/openai/api_resources/experimental/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from openai.api_resources.experimental.completion_config import ( # noqa: F401
- CompletionConfig,
-)
diff --git a/openai/api_resources/experimental/completion_config.py b/openai/api_resources/experimental/completion_config.py
deleted file mode 100644
index 5d4feb40e1..0000000000
--- a/openai/api_resources/experimental/completion_config.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from openai.api_resources.abstract import (
- CreateableAPIResource,
- DeletableAPIResource,
- ListableAPIResource,
-)
-
-
-class CompletionConfig(
- CreateableAPIResource, ListableAPIResource, DeletableAPIResource
-):
- OBJECT_NAME = "experimental.completion_configs"
diff --git a/openai/api_resources/file.py b/openai/api_resources/file.py
deleted file mode 100644
index dba2ee92e1..0000000000
--- a/openai/api_resources/file.py
+++ /dev/null
@@ -1,279 +0,0 @@
-import json
-import os
-from typing import cast
-import time
-
-import openai
-from openai import api_requestor, util, error
-from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource
-from openai.util import ApiType
-
-
-class File(ListableAPIResource, DeletableAPIResource):
- OBJECT_NAME = "files"
-
- @classmethod
- def __prepare_file_create(
- cls,
- file,
- purpose,
- model=None,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- user_provided_filename=None,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base or openai.api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
- typed_api_type, api_version = cls._get_api_type_and_version(
- api_type, api_version
- )
-
- if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- base = cls.class_url()
- url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
- elif typed_api_type == ApiType.OPEN_AI:
- url = cls.class_url()
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % api_type)
-
- # Set the filename on 'purpose' and 'model' to None so they are
- # interpreted as form data.
- files = [("purpose", (None, purpose))]
- if model is not None:
- files.append(("model", (None, model)))
- if user_provided_filename is not None:
- files.append(
- ("file", (user_provided_filename, file, "application/octet-stream"))
- )
- else:
- files.append(("file", ("file", file, "application/octet-stream")))
-
- return requestor, url, files
-
- @classmethod
- def create(
- cls,
- file,
- purpose,
- model=None,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- user_provided_filename=None,
- ):
- requestor, url, files = cls.__prepare_file_create(
- file,
- purpose,
- model,
- api_key,
- api_base,
- api_type,
- api_version,
- organization,
- user_provided_filename,
- )
- response, _, api_key = requestor.request("post", url, files=files)
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- async def acreate(
- cls,
- file,
- purpose,
- model=None,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- user_provided_filename=None,
- ):
- requestor, url, files = cls.__prepare_file_create(
- file,
- purpose,
- model,
- api_key,
- api_base,
- api_type,
- api_version,
- organization,
- user_provided_filename,
- )
- response, _, api_key = await requestor.arequest("post", url, files=files)
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- def __prepare_file_download(
- cls,
- id,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base or openai.api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
- typed_api_type, api_version = cls._get_api_type_and_version(
- api_type, api_version
- )
-
- if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- base = cls.class_url()
- url = f"/{cls.azure_api_prefix}{base}/{id}/content?api-version={api_version}"
- elif typed_api_type == ApiType.OPEN_AI:
- url = f"{cls.class_url()}/{id}/content"
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % api_type)
-
- return requestor, url
-
- @classmethod
- def download(
- cls,
- id,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- ):
- requestor, url = cls.__prepare_file_download(
- id, api_key, api_base, api_type, api_version, organization
- )
-
- result = requestor.request_raw("get", url)
- if not 200 <= result.status_code < 300:
- raise requestor.handle_error_response(
- result.content,
- result.status_code,
- json.loads(cast(bytes, result.content)),
- result.headers,
- stream_error=False,
- )
- return result.content
-
- @classmethod
- async def adownload(
- cls,
- id,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- ):
- requestor, url = cls.__prepare_file_download(
- id, api_key, api_base, api_type, api_version, organization
- )
-
- async with api_requestor.aiohttp_session() as session:
- result = await requestor.arequest_raw("get", url, session)
- if not 200 <= result.status < 300:
- raise requestor.handle_error_response(
- result.content,
- result.status,
- json.loads(cast(bytes, result.content)),
- result.headers,
- stream_error=False,
- )
- return result.content
-
- @classmethod
- def __find_matching_files(cls, name, bytes, all_files, purpose):
- matching_files = []
- basename = os.path.basename(name)
- for f in all_files:
- if f["purpose"] != purpose:
- continue
- file_basename = os.path.basename(f["filename"])
- if file_basename != basename:
- continue
- if "bytes" in f and f["bytes"] != bytes:
- continue
- if "size" in f and int(f["size"]) != bytes:
- continue
- matching_files.append(f)
- return matching_files
-
- @classmethod
- def find_matching_files(
- cls,
- name,
- bytes,
- purpose,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- ):
- """Find already uploaded files with the same name, size, and purpose."""
- all_files = cls.list(
- api_key=api_key,
- api_base=api_base or openai.api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- ).get("data", [])
- return cls.__find_matching_files(name, bytes, all_files, purpose)
-
- @classmethod
- async def afind_matching_files(
- cls,
- name,
- bytes,
- purpose,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- ):
- """Find already uploaded files with the same name, size, and purpose."""
- all_files = (
- await cls.alist(
- api_key=api_key,
- api_base=api_base or openai.api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
- ).get("data", [])
- return cls.__find_matching_files(name, bytes, all_files, purpose)
-
- @classmethod
- def wait_for_processing(cls, id, max_wait_seconds=30 * 60):
- TERMINAL_STATES = ["processed", "error", "deleted"]
-
- start = time.time()
- file = cls.retrieve(id=id)
- while file.status not in TERMINAL_STATES:
- file = cls.retrieve(id=id)
- time.sleep(5.0)
- if time.time() - start > max_wait_seconds:
- raise openai.error.OpenAIError(
- message="Giving up on waiting for file {id} to finish processing after {max_wait_seconds} seconds.".format(
- id=id, max_wait_seconds=max_wait_seconds
- )
- )
- return file.status
diff --git a/openai/api_resources/fine_tune.py b/openai/api_resources/fine_tune.py
deleted file mode 100644
index 45e3cf2af3..0000000000
--- a/openai/api_resources/fine_tune.py
+++ /dev/null
@@ -1,204 +0,0 @@
-from urllib.parse import quote_plus
-
-from openai import api_requestor, util, error
-from openai.api_resources.abstract import (
- CreateableAPIResource,
- ListableAPIResource,
- nested_resource_class_methods,
-)
-from openai.api_resources.abstract.deletable_api_resource import DeletableAPIResource
-from openai.openai_response import OpenAIResponse
-from openai.util import ApiType
-
-
-@nested_resource_class_methods("event", operations=["list"])
-class FineTune(ListableAPIResource, CreateableAPIResource, DeletableAPIResource):
- OBJECT_NAME = "fine-tunes"
-
- @classmethod
- def _prepare_cancel(
- cls,
- id,
- api_key=None,
- api_type=None,
- request_id=None,
- api_version=None,
- **params,
- ):
- base = cls.class_url()
- extn = quote_plus(id)
-
- typed_api_type, api_version = cls._get_api_type_and_version(
- api_type, api_version
- )
- if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- url = "/%s%s/%s/cancel?api-version=%s" % (
- cls.azure_api_prefix,
- base,
- extn,
- api_version,
- )
- elif typed_api_type == ApiType.OPEN_AI:
- url = "%s/%s/cancel" % (base, extn)
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % api_type)
-
- instance = cls(id, api_key, **params)
- return instance, url
-
- @classmethod
- def cancel(
- cls,
- id,
- api_key=None,
- api_type=None,
- request_id=None,
- api_version=None,
- **params,
- ):
- instance, url = cls._prepare_cancel(
- id,
- api_key,
- api_type,
- request_id,
- api_version,
- **params,
- )
- return instance.request("post", url, request_id=request_id)
-
- @classmethod
- def acancel(
- cls,
- id,
- api_key=None,
- api_type=None,
- request_id=None,
- api_version=None,
- **params,
- ):
- instance, url = cls._prepare_cancel(
- id,
- api_key,
- api_type,
- request_id,
- api_version,
- **params,
- )
- return instance.arequest("post", url, request_id=request_id)
-
- @classmethod
- def _prepare_stream_events(
- cls,
- id,
- api_key=None,
- api_base=None,
- api_type=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- base = cls.class_url()
- extn = quote_plus(id)
-
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
-
- typed_api_type, api_version = cls._get_api_type_and_version(
- api_type, api_version
- )
-
- if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- url = "/%s%s/%s/events?stream=true&api-version=%s" % (
- cls.azure_api_prefix,
- base,
- extn,
- api_version,
- )
- elif typed_api_type == ApiType.OPEN_AI:
- url = "%s/%s/events?stream=true" % (base, extn)
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % api_type)
-
- return requestor, url
-
- @classmethod
- def stream_events(
- cls,
- id,
- api_key=None,
- api_base=None,
- api_type=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor, url = cls._prepare_stream_events(
- id,
- api_key,
- api_base,
- api_type,
- request_id,
- api_version,
- organization,
- **params,
- )
-
- response, _, api_key = requestor.request(
- "get", url, params, stream=True, request_id=request_id
- )
-
- assert not isinstance(response, OpenAIResponse) # must be an iterator
- return (
- util.convert_to_openai_object(
- line,
- api_key,
- api_version,
- organization,
- )
- for line in response
- )
-
- @classmethod
- async def astream_events(
- cls,
- id,
- api_key=None,
- api_base=None,
- api_type=None,
- request_id=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor, url = cls._prepare_stream_events(
- id,
- api_key,
- api_base,
- api_type,
- request_id,
- api_version,
- organization,
- **params,
- )
-
- response, _, api_key = await requestor.arequest(
- "get", url, params, stream=True, request_id=request_id
- )
-
- assert not isinstance(response, OpenAIResponse) # must be an iterator
- return (
- util.convert_to_openai_object(
- line,
- api_key,
- api_version,
- organization,
- )
- async for line in response
- )
diff --git a/openai/api_resources/fine_tuning.py b/openai/api_resources/fine_tuning.py
deleted file mode 100644
index f03be56ab7..0000000000
--- a/openai/api_resources/fine_tuning.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from urllib.parse import quote_plus
-
-from openai import error
-from openai.api_resources.abstract import (
- CreateableAPIResource,
- PaginatableAPIResource,
- nested_resource_class_methods,
-)
-from openai.api_resources.abstract.deletable_api_resource import DeletableAPIResource
-from openai.util import ApiType
-
-
-@nested_resource_class_methods("event", operations=["paginated_list"])
-class FineTuningJob(
- PaginatableAPIResource, CreateableAPIResource, DeletableAPIResource
-):
- OBJECT_NAME = "fine_tuning.jobs"
-
- @classmethod
- def _prepare_cancel(
- cls,
- id,
- api_key=None,
- api_type=None,
- request_id=None,
- api_version=None,
- **params,
- ):
- base = cls.class_url()
- extn = quote_plus(id)
-
- typed_api_type, api_version = cls._get_api_type_and_version(
- api_type, api_version
- )
- if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
- url = "/%s%s/%s/cancel?api-version=%s" % (
- cls.azure_api_prefix,
- base,
- extn,
- api_version,
- )
- elif typed_api_type == ApiType.OPEN_AI:
- url = "%s/%s/cancel" % (base, extn)
- else:
- raise error.InvalidAPIType("Unsupported API type %s" % api_type)
-
- instance = cls(id, api_key, **params)
- return instance, url
-
- @classmethod
- def cancel(
- cls,
- id,
- api_key=None,
- api_type=None,
- request_id=None,
- api_version=None,
- **params,
- ):
- instance, url = cls._prepare_cancel(
- id,
- api_key,
- api_type,
- request_id,
- api_version,
- **params,
- )
- return instance.request("post", url, request_id=request_id)
-
- @classmethod
- def acancel(
- cls,
- id,
- api_key=None,
- api_type=None,
- request_id=None,
- api_version=None,
- **params,
- ):
- instance, url = cls._prepare_cancel(
- id,
- api_key,
- api_type,
- request_id,
- api_version,
- **params,
- )
- return instance.arequest("post", url, request_id=request_id)
diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py
deleted file mode 100644
index 1522923510..0000000000
--- a/openai/api_resources/image.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# WARNING: This interface is considered experimental and may changed in the future without warning.
-from typing import Any, List
-
-import openai
-from openai import api_requestor, error, util
-from openai.api_resources.abstract import APIResource
-
-
-class Image(APIResource):
- OBJECT_NAME = "images"
-
- @classmethod
- def _get_url(cls, action, azure_action, api_type, api_version):
- if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD) and azure_action is not None:
- return f"/{cls.azure_api_prefix}{cls.class_url()}/{action}:{azure_action}?api-version={api_version}"
- else:
- return f"{cls.class_url()}/{action}"
-
- @classmethod
- def create(
- cls,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base or openai.api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
-
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
-
- response, _, api_key = requestor.request(
- "post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params
- )
-
- if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- requestor.api_base = "" # operation_location is a full url
- response, _, api_key = requestor._poll(
- "get", response.operation_location,
- until=lambda response: response.data['status'] in [ 'succeeded' ],
- failed=lambda response: response.data['status'] in [ 'failed' ]
- )
-
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- async def acreate(
- cls,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- **params,
- ):
-
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base or openai.api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
-
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
-
- response, _, api_key = await requestor.arequest(
- "post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params
- )
-
- if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- requestor.api_base = "" # operation_location is a full url
- response, _, api_key = await requestor._apoll(
- "get", response.operation_location,
- until=lambda response: response.data['status'] in [ 'succeeded' ],
- failed=lambda response: response.data['status'] in [ 'failed' ]
- )
-
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- def _prepare_create_variation(
- cls,
- image,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base or openai.api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
-
- url = cls._get_url("variations", azure_action=None, api_type=api_type, api_version=api_version)
-
- files: List[Any] = []
- for key, value in params.items():
- files.append((key, (None, value)))
- files.append(("image", ("image", image, "application/octet-stream")))
- return requestor, url, files
-
- @classmethod
- def create_variation(
- cls,
- image,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- **params,
- ):
- if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.")
-
- requestor, url, files = cls._prepare_create_variation(
- image,
- api_key,
- api_base,
- api_type,
- api_version,
- organization,
- **params,
- )
-
- response, _, api_key = requestor.request("post", url, files=files)
-
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- async def acreate_variation(
- cls,
- image,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- **params,
- ):
- if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.")
-
- requestor, url, files = cls._prepare_create_variation(
- image,
- api_key,
- api_base,
- api_type,
- api_version,
- organization,
- **params,
- )
-
- response, _, api_key = await requestor.arequest("post", url, files=files)
-
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- def _prepare_create_edit(
- cls,
- image,
- mask=None,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- **params,
- ):
- requestor = api_requestor.APIRequestor(
- api_key,
- api_base=api_base or openai.api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- )
- api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
-
- url = cls._get_url("edits", azure_action=None, api_type=api_type, api_version=api_version)
-
- files: List[Any] = []
- for key, value in params.items():
- files.append((key, (None, value)))
- files.append(("image", ("image", image, "application/octet-stream")))
- if mask is not None:
- files.append(("mask", ("mask", mask, "application/octet-stream")))
- return requestor, url, files
-
- @classmethod
- def create_edit(
- cls,
- image,
- mask=None,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- **params,
- ):
- if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.")
-
- requestor, url, files = cls._prepare_create_edit(
- image,
- mask,
- api_key,
- api_base,
- api_type,
- api_version,
- organization,
- **params,
- )
-
- response, _, api_key = requestor.request("post", url, files=files)
-
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
-
- @classmethod
- async def acreate_edit(
- cls,
- image,
- mask=None,
- api_key=None,
- api_base=None,
- api_type=None,
- api_version=None,
- organization=None,
- **params,
- ):
- if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
- raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.")
-
- requestor, url, files = cls._prepare_create_edit(
- image,
- mask,
- api_key,
- api_base,
- api_type,
- api_version,
- organization,
- **params,
- )
-
- response, _, api_key = await requestor.arequest("post", url, files=files)
-
- return util.convert_to_openai_object(
- response, api_key, api_version, organization
- )
diff --git a/openai/api_resources/model.py b/openai/api_resources/model.py
deleted file mode 100644
index 9785e17fe1..0000000000
--- a/openai/api_resources/model.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource
-
-
-class Model(ListableAPIResource, DeletableAPIResource):
- OBJECT_NAME = "models"
diff --git a/openai/api_resources/moderation.py b/openai/api_resources/moderation.py
deleted file mode 100644
index bd19646b49..0000000000
--- a/openai/api_resources/moderation.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from typing import List, Optional, Union
-
-from openai.openai_object import OpenAIObject
-
-
-class Moderation(OpenAIObject):
- VALID_MODEL_NAMES: List[str] = ["text-moderation-stable", "text-moderation-latest"]
-
- @classmethod
- def get_url(cls):
- return "/moderations"
-
- @classmethod
- def _prepare_create(cls, input, model, api_key):
- if model is not None and model not in cls.VALID_MODEL_NAMES:
- raise ValueError(
- f"The parameter model should be chosen from {cls.VALID_MODEL_NAMES} "
- f"and it is default to be None."
- )
-
- instance = cls(api_key=api_key)
- params = {"input": input}
- if model is not None:
- params["model"] = model
- return instance, params
-
- @classmethod
- def create(
- cls,
- input: Union[str, List[str]],
- model: Optional[str] = None,
- api_key: Optional[str] = None,
- ):
- instance, params = cls._prepare_create(input, model, api_key)
- return instance.request("post", cls.get_url(), params)
-
- @classmethod
- def acreate(
- cls,
- input: Union[str, List[str]],
- model: Optional[str] = None,
- api_key: Optional[str] = None,
- ):
- instance, params = cls._prepare_create(input, model, api_key)
- return instance.arequest("post", cls.get_url(), params)
diff --git a/openai/cli.py b/openai/cli.py
deleted file mode 100644
index a6e99396ae..0000000000
--- a/openai/cli.py
+++ /dev/null
@@ -1,1416 +0,0 @@
-import datetime
-import os
-import signal
-import sys
-import warnings
-from typing import Optional
-
-import requests
-
-import openai
-from openai.upload_progress import BufferReader
-from openai.validators import (
- apply_necessary_remediation,
- apply_validators,
- get_validators,
- read_any_format,
- write_out_file,
-)
-
-
-class bcolors:
- HEADER = "\033[95m"
- OKBLUE = "\033[94m"
- OKGREEN = "\033[92m"
- WARNING = "\033[93m"
- FAIL = "\033[91m"
- ENDC = "\033[0m"
- BOLD = "\033[1m"
- UNDERLINE = "\033[4m"
-
-
-def organization_info(obj):
- organization = getattr(obj, "organization", None)
- if organization is not None:
- return "[organization={}] ".format(organization)
- else:
- return ""
-
-
-def display(obj):
- sys.stderr.write(organization_info(obj))
- sys.stderr.flush()
- print(obj)
-
-
-def display_error(e):
- extra = (
- " (HTTP status code: {})".format(e.http_status)
- if e.http_status is not None
- else ""
- )
- sys.stderr.write(
- "{}{}Error:{} {}{}\n".format(
- organization_info(e), bcolors.FAIL, bcolors.ENDC, e, extra
- )
- )
-
-
-class Engine:
- @classmethod
- def get(cls, args):
- engine = openai.Engine.retrieve(id=args.id)
- display(engine)
-
- @classmethod
- def update(cls, args):
- engine = openai.Engine.modify(args.id, replicas=args.replicas)
- display(engine)
-
- @classmethod
- def generate(cls, args):
- warnings.warn(
- "Engine.generate is deprecated, use Completion.create", DeprecationWarning
- )
- if args.completions and args.completions > 1 and args.stream:
- raise ValueError("Can't stream multiple completions with openai CLI")
-
- kwargs = {}
- if args.model is not None:
- kwargs["model"] = args.model
- resp = openai.Engine(id=args.id).generate(
- completions=args.completions,
- context=args.context,
- length=args.length,
- stream=args.stream,
- temperature=args.temperature,
- top_p=args.top_p,
- logprobs=args.logprobs,
- stop=args.stop,
- **kwargs,
- )
- if not args.stream:
- resp = [resp]
-
- for part in resp:
- completions = len(part["data"])
- for c_idx, c in enumerate(part["data"]):
- if completions > 1:
- sys.stdout.write("===== Completion {} =====\n".format(c_idx))
- sys.stdout.write("".join(c["text"]))
- if completions > 1:
- sys.stdout.write("\n")
- sys.stdout.flush()
-
- @classmethod
- def list(cls, args):
- engines = openai.Engine.list()
- display(engines)
-
-
-class ChatCompletion:
- @classmethod
- def create(cls, args):
- if args.n is not None and args.n > 1 and args.stream:
- raise ValueError(
- "Can't stream chat completions with n>1 with the current CLI"
- )
-
- messages = [
- {"role": role, "content": content} for role, content in args.message
- ]
-
- resp = openai.ChatCompletion.create(
- # Required
- model=args.model,
- engine=args.engine,
- messages=messages,
- # Optional
- n=args.n,
- max_tokens=args.max_tokens,
- temperature=args.temperature,
- top_p=args.top_p,
- stop=args.stop,
- stream=args.stream,
- )
- if not args.stream:
- resp = [resp]
-
- for part in resp:
- choices = part["choices"]
- for c_idx, c in enumerate(sorted(choices, key=lambda s: s["index"])):
- if len(choices) > 1:
- sys.stdout.write("===== Chat Completion {} =====\n".format(c_idx))
- if args.stream:
- delta = c["delta"]
- if "content" in delta:
- sys.stdout.write(delta["content"])
- else:
- sys.stdout.write(c["message"]["content"])
- if len(choices) > 1: # not in streams
- sys.stdout.write("\n")
- sys.stdout.flush()
-
-
-class Completion:
- @classmethod
- def create(cls, args):
- if args.n is not None and args.n > 1 and args.stream:
- raise ValueError("Can't stream completions with n>1 with the current CLI")
-
- if args.engine and args.model:
- warnings.warn(
- "In most cases, you should not be specifying both engine and model."
- )
-
- resp = openai.Completion.create(
- engine=args.engine,
- model=args.model,
- n=args.n,
- max_tokens=args.max_tokens,
- logprobs=args.logprobs,
- prompt=args.prompt,
- stream=args.stream,
- temperature=args.temperature,
- top_p=args.top_p,
- stop=args.stop,
- echo=True,
- )
- if not args.stream:
- resp = [resp]
-
- for part in resp:
- choices = part["choices"]
- for c_idx, c in enumerate(sorted(choices, key=lambda s: s["index"])):
- if len(choices) > 1:
- sys.stdout.write("===== Completion {} =====\n".format(c_idx))
- sys.stdout.write(c["text"])
- if len(choices) > 1:
- sys.stdout.write("\n")
- sys.stdout.flush()
-
-
-class Deployment:
- @classmethod
- def get(cls, args):
- resp = openai.Deployment.retrieve(id=args.id)
- print(resp)
-
- @classmethod
- def delete(cls, args):
- model = openai.Deployment.delete(args.id)
- print(model)
-
- @classmethod
- def list(cls, args):
- models = openai.Deployment.list()
- print(models)
-
- @classmethod
- def create(cls, args):
- models = openai.Deployment.create(
- model=args.model, scale_settings={"scale_type": args.scale_type}
- )
- print(models)
-
-
-class Model:
- @classmethod
- def get(cls, args):
- resp = openai.Model.retrieve(id=args.id)
- print(resp)
-
- @classmethod
- def delete(cls, args):
- model = openai.Model.delete(args.id)
- print(model)
-
- @classmethod
- def list(cls, args):
- models = openai.Model.list()
- print(models)
-
-
-class File:
- @classmethod
- def create(cls, args):
- with open(args.file, "rb") as file_reader:
- buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
- resp = openai.File.create(
- file=buffer_reader,
- purpose=args.purpose,
- user_provided_filename=args.file,
- )
- print(resp)
-
- @classmethod
- def get(cls, args):
- resp = openai.File.retrieve(id=args.id)
- print(resp)
-
- @classmethod
- def delete(cls, args):
- file = openai.File.delete(args.id)
- print(file)
-
- @classmethod
- def list(cls, args):
- file = openai.File.list()
- print(file)
-
-
-class Image:
- @classmethod
- def create(cls, args):
- resp = openai.Image.create(
- prompt=args.prompt,
- size=args.size,
- n=args.num_images,
- response_format=args.response_format,
- )
- print(resp)
-
- @classmethod
- def create_variation(cls, args):
- with open(args.image, "rb") as file_reader:
- buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
- resp = openai.Image.create_variation(
- image=buffer_reader,
- size=args.size,
- n=args.num_images,
- response_format=args.response_format,
- )
- print(resp)
-
- @classmethod
- def create_edit(cls, args):
- with open(args.image, "rb") as file_reader:
- image_reader = BufferReader(file_reader.read(), desc="Upload progress")
- mask_reader = None
- if args.mask is not None:
- with open(args.mask, "rb") as file_reader:
- mask_reader = BufferReader(file_reader.read(), desc="Upload progress")
- resp = openai.Image.create_edit(
- image=image_reader,
- mask=mask_reader,
- prompt=args.prompt,
- size=args.size,
- n=args.num_images,
- response_format=args.response_format,
- )
- print(resp)
-
-
-class Audio:
- @classmethod
- def transcribe(cls, args):
- with open(args.file, "rb") as r:
- file_reader = BufferReader(r.read(), desc="Upload progress")
-
- resp = openai.Audio.transcribe_raw(
- # Required
- model=args.model,
- file=file_reader,
- filename=args.file,
- # Optional
- response_format=args.response_format,
- language=args.language,
- temperature=args.temperature,
- prompt=args.prompt,
- )
- print(resp)
-
- @classmethod
- def translate(cls, args):
- with open(args.file, "rb") as r:
- file_reader = BufferReader(r.read(), desc="Upload progress")
- resp = openai.Audio.translate_raw(
- # Required
- model=args.model,
- file=file_reader,
- filename=args.file,
- # Optional
- response_format=args.response_format,
- language=args.language,
- temperature=args.temperature,
- prompt=args.prompt,
- )
- print(resp)
-
-
-class FineTune:
- @classmethod
- def list(cls, args):
- resp = openai.FineTune.list()
- print(resp)
-
- @classmethod
- def _is_url(cls, file: str):
- return file.lower().startswith("http")
-
- @classmethod
- def _download_file_from_public_url(cls, url: str) -> Optional[bytes]:
- resp = requests.get(url)
- if resp.status_code == 200:
- return resp.content
- else:
- return None
-
- @classmethod
- def _maybe_upload_file(
- cls,
- file: Optional[str] = None,
- content: Optional[bytes] = None,
- user_provided_file: Optional[str] = None,
- check_if_file_exists: bool = True,
- ):
- # Exactly one of `file` or `content` must be provided
- if (file is None) == (content is None):
- raise ValueError("Exactly one of `file` or `content` must be provided")
-
- if content is None:
- assert file is not None
- with open(file, "rb") as f:
- content = f.read()
-
- if check_if_file_exists:
- bytes = len(content)
- matching_files = openai.File.find_matching_files(
- name=user_provided_file or f.name, bytes=bytes, purpose="fine-tune"
- )
- if len(matching_files) > 0:
- file_ids = [f["id"] for f in matching_files]
- sys.stdout.write(
- "Found potentially duplicated files with name '{name}', purpose 'fine-tune' and size {size} bytes\n".format(
- name=os.path.basename(matching_files[0]["filename"]),
- size=matching_files[0]["bytes"]
- if "bytes" in matching_files[0]
- else matching_files[0]["size"],
- )
- )
- sys.stdout.write("\n".join(file_ids))
- while True:
- sys.stdout.write(
- "\nEnter file ID to reuse an already uploaded file, or an empty string to upload this file anyway: "
- )
- inp = sys.stdin.readline().strip()
- if inp in file_ids:
- sys.stdout.write(
- "Reusing already uploaded file: {id}\n".format(id=inp)
- )
- return inp
- elif inp == "":
- break
- else:
- sys.stdout.write(
- "File id '{id}' is not among the IDs of the potentially duplicated files\n".format(
- id=inp
- )
- )
-
- buffer_reader = BufferReader(content, desc="Upload progress")
- resp = openai.File.create(
- file=buffer_reader,
- purpose="fine-tune",
- user_provided_filename=user_provided_file or file,
- )
- sys.stdout.write(
- "Uploaded file from {file}: {id}\n".format(
- file=user_provided_file or file, id=resp["id"]
- )
- )
- return resp["id"]
-
- @classmethod
- def _get_or_upload(cls, file, check_if_file_exists=True):
- try:
- # 1. If it's a valid file, use it
- openai.File.retrieve(file)
- return file
- except openai.error.InvalidRequestError:
- pass
- if os.path.isfile(file):
- # 2. If it's a file on the filesystem, upload it
- return cls._maybe_upload_file(
- file=file, check_if_file_exists=check_if_file_exists
- )
- if cls._is_url(file):
- # 3. If it's a URL, download it temporarily
- content = cls._download_file_from_public_url(file)
- if content is not None:
- return cls._maybe_upload_file(
- content=content,
- check_if_file_exists=check_if_file_exists,
- user_provided_file=file,
- )
- return file
-
- @classmethod
- def create(cls, args):
- create_args = {
- "training_file": cls._get_or_upload(
- args.training_file, args.check_if_files_exist
- ),
- }
- if args.validation_file:
- create_args["validation_file"] = cls._get_or_upload(
- args.validation_file, args.check_if_files_exist
- )
-
- for hparam in (
- "model",
- "suffix",
- "n_epochs",
- "batch_size",
- "learning_rate_multiplier",
- "prompt_loss_weight",
- "compute_classification_metrics",
- "classification_n_classes",
- "classification_positive_class",
- "classification_betas",
- ):
- attr = getattr(args, hparam)
- if attr is not None:
- create_args[hparam] = attr
-
- resp = openai.FineTune.create(**create_args)
-
- if args.no_follow:
- print(resp)
- return
-
- sys.stdout.write(
- "Created fine-tune: {job_id}\n"
- "Streaming events until fine-tuning is complete...\n\n"
- "(Ctrl-C will interrupt the stream, but not cancel the fine-tune)\n".format(
- job_id=resp["id"]
- )
- )
- cls._stream_events(resp["id"])
-
- @classmethod
- def get(cls, args):
- resp = openai.FineTune.retrieve(id=args.id)
- print(resp)
-
- @classmethod
- def results(cls, args):
- fine_tune = openai.FineTune.retrieve(id=args.id)
- if "result_files" not in fine_tune or len(fine_tune["result_files"]) == 0:
- raise openai.error.InvalidRequestError(
- f"No results file available for fine-tune {args.id}", "id"
- )
- result_file = openai.FineTune.retrieve(id=args.id)["result_files"][0]
- resp = openai.File.download(id=result_file["id"])
- print(resp.decode("utf-8"))
-
- @classmethod
- def events(cls, args):
- if args.stream:
- raise openai.error.OpenAIError(
- message=(
- "The --stream parameter is deprecated, use fine_tunes.follow "
- "instead:\n\n"
- " openai api fine_tunes.follow -i {id}\n".format(id=args.id)
- ),
- )
-
- resp = openai.FineTune.list_events(id=args.id) # type: ignore
- print(resp)
-
- @classmethod
- def follow(cls, args):
- cls._stream_events(args.id)
-
- @classmethod
- def _stream_events(cls, job_id):
- def signal_handler(sig, frame):
- status = openai.FineTune.retrieve(job_id).status
- sys.stdout.write(
- "\nStream interrupted. Job is still {status}.\n"
- "To resume the stream, run:\n\n"
- " openai api fine_tunes.follow -i {job_id}\n\n"
- "To cancel your job, run:\n\n"
- " openai api fine_tunes.cancel -i {job_id}\n\n".format(
- status=status, job_id=job_id
- )
- )
- sys.exit(0)
-
- signal.signal(signal.SIGINT, signal_handler)
-
- events = openai.FineTune.stream_events(job_id)
- # TODO(rachel): Add a nifty spinner here.
- try:
- for event in events:
- sys.stdout.write(
- "[%s] %s"
- % (
- datetime.datetime.fromtimestamp(event["created_at"]),
- event["message"],
- )
- )
- sys.stdout.write("\n")
- sys.stdout.flush()
- except Exception:
- sys.stdout.write(
- "\nStream interrupted (client disconnected).\n"
- "To resume the stream, run:\n\n"
- " openai api fine_tunes.follow -i {job_id}\n\n".format(job_id=job_id)
- )
- return
-
- resp = openai.FineTune.retrieve(id=job_id)
- status = resp["status"]
- if status == "succeeded":
- sys.stdout.write("\nJob complete! Status: succeeded 🎉")
- sys.stdout.write(
- "\nTry out your fine-tuned model:\n\n"
- "openai api completions.create -m {model} -p ".format(
- model=resp["fine_tuned_model"]
- )
- )
- elif status == "failed":
- sys.stdout.write(
- "\nJob failed. Please contact us through our help center at help.openai.com if you need assistance."
- )
- sys.stdout.write("\n")
-
- @classmethod
- def cancel(cls, args):
- resp = openai.FineTune.cancel(id=args.id)
- print(resp)
-
- @classmethod
- def delete(cls, args):
- resp = openai.FineTune.delete(sid=args.id)
- print(resp)
-
- @classmethod
- def prepare_data(cls, args):
- sys.stdout.write("Analyzing...\n")
- fname = args.file
- auto_accept = args.quiet
- df, remediation = read_any_format(fname)
- apply_necessary_remediation(None, remediation)
-
- validators = get_validators()
-
- apply_validators(
- df,
- fname,
- remediation,
- validators,
- auto_accept,
- write_out_file_func=write_out_file,
- )
-
-
-class FineTuningJob:
- @classmethod
- def list(cls, args):
- has_ft_jobs = False
- for fine_tune_job in openai.FineTuningJob.auto_paging_iter():
- has_ft_jobs = True
- print(fine_tune_job)
- if not has_ft_jobs:
- print("No fine-tuning jobs found.")
-
- @classmethod
- def _is_url(cls, file: str):
- return file.lower().startswith("http")
-
- @classmethod
- def _download_file_from_public_url(cls, url: str) -> Optional[bytes]:
- resp = requests.get(url)
- if resp.status_code == 200:
- return resp.content
- else:
- return None
-
- @classmethod
- def _maybe_upload_file(
- cls,
- file: Optional[str] = None,
- content: Optional[bytes] = None,
- user_provided_file: Optional[str] = None,
- check_if_file_exists: bool = True,
- ):
- # Exactly one of `file` or `content` must be provided
- if (file is None) == (content is None):
- raise ValueError("Exactly one of `file` or `content` must be provided")
-
- if content is None:
- assert file is not None
- with open(file, "rb") as f:
- content = f.read()
-
- if check_if_file_exists:
- bytes = len(content)
- matching_files = openai.File.find_matching_files(
- name=user_provided_file or f.name,
- bytes=bytes,
- purpose="fine-tune",
- )
- if len(matching_files) > 0:
- file_ids = [f["id"] for f in matching_files]
- sys.stdout.write(
- "Found potentially duplicated files with name '{name}', purpose 'fine-tune', and size {size} bytes\n".format(
- name=os.path.basename(matching_files[0]["filename"]),
- size=matching_files[0]["bytes"]
- if "bytes" in matching_files[0]
- else matching_files[0]["size"],
- )
- )
- sys.stdout.write("\n".join(file_ids))
- while True:
- sys.stdout.write(
- "\nEnter file ID to reuse an already uploaded file, or an empty string to upload this file anyway: "
- )
- inp = sys.stdin.readline().strip()
- if inp in file_ids:
- sys.stdout.write(
- "Reusing already uploaded file: {id}\n".format(id=inp)
- )
- return inp
- elif inp == "":
- break
- else:
- sys.stdout.write(
- "File id '{id}' is not among the IDs of the potentially duplicated files\n".format(
- id=inp
- )
- )
-
- buffer_reader = BufferReader(content, desc="Upload progress")
- resp = openai.File.create(
- file=buffer_reader,
- purpose="fine-tune",
- user_provided_filename=user_provided_file or file,
- )
- sys.stdout.write(
- "Uploaded file from {file}: {id}\n".format(
- file=user_provided_file or file, id=resp["id"]
- )
- )
- sys.stdout.write("Waiting for file to finish processing before proceeding..\n")
- sys.stdout.flush()
- status = openai.File.wait_for_processing(resp["id"])
- if status != "processed":
- raise openai.error.OpenAIError(
- "File {id} failed to process, status={status}.".format(
- id=resp["id"], status=status
- )
- )
-
- sys.stdout.write(
- "File {id} finished processing and is ready for use in fine-tuning".format(
- id=resp["id"]
- )
- )
- sys.stdout.flush()
- return resp["id"]
-
- @classmethod
- def _get_or_upload(cls, file, check_if_file_exists=True):
- try:
- # 1. If it's a valid file, use it
- openai.File.retrieve(file)
- return file
- except openai.error.InvalidRequestError:
- pass
- if os.path.isfile(file):
- # 2. If it's a file on the filesystem, upload it
- return cls._maybe_upload_file(
- file=file, check_if_file_exists=check_if_file_exists
- )
- if cls._is_url(file):
- # 3. If it's a URL, download it temporarily
- content = cls._download_file_from_public_url(file)
- if content is not None:
- return cls._maybe_upload_file(
- content=content,
- check_if_file_exists=check_if_file_exists,
- user_provided_file=file,
- )
- return file
-
- @classmethod
- def create(cls, args):
- create_args = {
- "training_file": cls._get_or_upload(
- args.training_file, args.check_if_files_exist
- ),
- }
- if args.validation_file:
- create_args["validation_file"] = cls._get_or_upload(
- args.validation_file, args.check_if_files_exist
- )
-
- for param in ("model", "suffix"):
- attr = getattr(args, param)
- if attr is not None:
- create_args[param] = attr
-
- if getattr(args, "n_epochs"):
- create_args["hyperparameters"] = {
- "n_epochs": args.n_epochs,
- }
-
- resp = openai.FineTuningJob.create(**create_args)
- print(resp)
- return
-
- @classmethod
- def get(cls, args):
- resp = openai.FineTuningJob.retrieve(id=args.id)
- print(resp)
-
- @classmethod
- def results(cls, args):
- fine_tune = openai.FineTuningJob.retrieve(id=args.id)
- if "result_files" not in fine_tune or len(fine_tune["result_files"]) == 0:
- raise openai.error.InvalidRequestError(
- f"No results file available for fine-tune {args.id}", "id"
- )
- result_file = openai.FineTuningJob.retrieve(id=args.id)["result_files"][0]
- resp = openai.File.download(id=result_file)
- print(resp.decode("utf-8"))
-
- @classmethod
- def events(cls, args):
- seen, has_more, after = 0, True, None
- while has_more:
- resp = openai.FineTuningJob.list_events(id=args.id, after=after) # type: ignore
- for event in resp["data"]:
- print(event)
- seen += 1
- if args.limit is not None and seen >= args.limit:
- return
- has_more = resp.get("has_more", False)
- if resp["data"]:
- after = resp["data"][-1]["id"]
-
- @classmethod
- def follow(cls, args):
- raise openai.error.OpenAIError(
- message="Event streaming is not yet supported for `fine_tuning.job` events"
- )
-
- @classmethod
- def cancel(cls, args):
- resp = openai.FineTuningJob.cancel(id=args.id)
- print(resp)
-
-
-class WandbLogger:
- @classmethod
- def sync(cls, args):
- import openai.wandb_logger
-
- resp = openai.wandb_logger.WandbLogger.sync(
- id=args.id,
- n_fine_tunes=args.n_fine_tunes,
- project=args.project,
- entity=args.entity,
- force=args.force,
- )
- print(resp)
-
-
-def tools_register(parser):
- subparsers = parser.add_subparsers(
- title="Tools", help="Convenience client side tools"
- )
-
- def help(args):
- parser.print_help()
-
- parser.set_defaults(func=help)
-
- sub = subparsers.add_parser("fine_tunes.prepare_data")
- sub.add_argument(
- "-f",
- "--file",
- required=True,
- help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed."
- "This should be the local file path.",
- )
- sub.add_argument(
- "-q",
- "--quiet",
- required=False,
- action="store_true",
- help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
- )
- sub.set_defaults(func=FineTune.prepare_data)
-
-
-def api_register(parser):
- # Engine management
- subparsers = parser.add_subparsers(help="All API subcommands")
-
- def help(args):
- parser.print_help()
-
- parser.set_defaults(func=help)
-
- sub = subparsers.add_parser("engines.list")
- sub.set_defaults(func=Engine.list)
-
- sub = subparsers.add_parser("engines.get")
- sub.add_argument("-i", "--id", required=True)
- sub.set_defaults(func=Engine.get)
-
- sub = subparsers.add_parser("engines.update")
- sub.add_argument("-i", "--id", required=True)
- sub.add_argument("-r", "--replicas", type=int)
- sub.set_defaults(func=Engine.update)
-
- sub = subparsers.add_parser("engines.generate")
- sub.add_argument("-i", "--id", required=True)
- sub.add_argument(
- "--stream", help="Stream tokens as they're ready.", action="store_true"
- )
- sub.add_argument("-c", "--context", help="An optional context to generate from")
- sub.add_argument("-l", "--length", help="How many tokens to generate", type=int)
- sub.add_argument(
- "-t",
- "--temperature",
- help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
-
-Mutually exclusive with `top_p`.""",
- type=float,
- )
- sub.add_argument(
- "-p",
- "--top_p",
- help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
-
- Mutually exclusive with `temperature`.""",
- type=float,
- )
- sub.add_argument(
- "-n",
- "--completions",
- help="How many parallel completions to run on this context",
- type=int,
- )
- sub.add_argument(
- "--logprobs",
- help="Include the log probabilites on the `logprobs` most likely tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is supplied, the API will always return the logprob of the generated token, so there may be up to `logprobs+1` elements in the response.",
- type=int,
- )
- sub.add_argument(
- "--stop", help="A stop sequence at which to stop generating tokens."
- )
- sub.add_argument(
- "-m",
- "--model",
- required=False,
- help="A model (most commonly a model ID) to generate from. Defaults to the engine's default model.",
- )
- sub.set_defaults(func=Engine.generate)
-
- # Chat Completions
- sub = subparsers.add_parser("chat_completions.create")
-
- sub._action_groups.pop()
- req = sub.add_argument_group("required arguments")
- opt = sub.add_argument_group("optional arguments")
-
- req.add_argument(
- "-g",
- "--message",
- action="append",
- nargs=2,
- metavar=("ROLE", "CONTENT"),
- help="A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.",
- required=True,
- )
-
- group = opt.add_mutually_exclusive_group()
- group.add_argument(
- "-e",
- "--engine",
- help="The engine to use. See https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=programming-language-python for more about what engines are available.",
- )
- group.add_argument(
- "-m",
- "--model",
- help="The model to use.",
- )
-
- opt.add_argument(
- "-n",
- "--n",
- help="How many completions to generate for the conversation.",
- type=int,
- )
- opt.add_argument(
- "-M", "--max-tokens", help="The maximum number of tokens to generate.", type=int
- )
- opt.add_argument(
- "-t",
- "--temperature",
- help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
-
-Mutually exclusive with `top_p`.""",
- type=float,
- )
- opt.add_argument(
- "-P",
- "--top_p",
- help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
-
- Mutually exclusive with `temperature`.""",
- type=float,
- )
- opt.add_argument(
- "--stop",
- help="A stop sequence at which to stop generating tokens for the message.",
- )
- opt.add_argument(
- "--stream", help="Stream messages as they're ready.", action="store_true"
- )
- sub.set_defaults(func=ChatCompletion.create)
-
- # Completions
- sub = subparsers.add_parser("completions.create")
- sub.add_argument(
- "-e",
- "--engine",
- help="The engine to use. See https://platform.openai.com/docs/engines for more about what engines are available.",
- )
- sub.add_argument(
- "-m",
- "--model",
- help="The model to use. At most one of `engine` or `model` should be specified.",
- )
- sub.add_argument(
- "--stream", help="Stream tokens as they're ready.", action="store_true"
- )
- sub.add_argument("-p", "--prompt", help="An optional prompt to complete from")
- sub.add_argument(
- "-M", "--max-tokens", help="The maximum number of tokens to generate", type=int
- )
- sub.add_argument(
- "-t",
- "--temperature",
- help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
-
-Mutually exclusive with `top_p`.""",
- type=float,
- )
- sub.add_argument(
- "-P",
- "--top_p",
- help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
-
- Mutually exclusive with `temperature`.""",
- type=float,
- )
- sub.add_argument(
- "-n",
- "--n",
- help="How many sub-completions to generate for each prompt.",
- type=int,
- )
- sub.add_argument(
- "--logprobs",
- help="Include the log probabilites on the `logprobs` most likely tokens, as well the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.",
- type=int,
- )
- sub.add_argument(
- "--stop", help="A stop sequence at which to stop generating tokens."
- )
- sub.set_defaults(func=Completion.create)
-
- # Deployments
- sub = subparsers.add_parser("deployments.list")
- sub.set_defaults(func=Deployment.list)
-
- sub = subparsers.add_parser("deployments.get")
- sub.add_argument("-i", "--id", required=True, help="The deployment ID")
- sub.set_defaults(func=Deployment.get)
-
- sub = subparsers.add_parser("deployments.delete")
- sub.add_argument("-i", "--id", required=True, help="The deployment ID")
- sub.set_defaults(func=Deployment.delete)
-
- sub = subparsers.add_parser("deployments.create")
- sub.add_argument("-m", "--model", required=True, help="The model ID")
- sub.add_argument(
- "-s",
- "--scale_type",
- required=True,
- help="The scale type. Either 'manual' or 'standard'",
- )
- sub.set_defaults(func=Deployment.create)
-
- # Models
- sub = subparsers.add_parser("models.list")
- sub.set_defaults(func=Model.list)
-
- sub = subparsers.add_parser("models.get")
- sub.add_argument("-i", "--id", required=True, help="The model ID")
- sub.set_defaults(func=Model.get)
-
- sub = subparsers.add_parser("models.delete")
- sub.add_argument("-i", "--id", required=True, help="The model ID")
- sub.set_defaults(func=Model.delete)
-
- # Files
- sub = subparsers.add_parser("files.create")
-
- sub.add_argument(
- "-f",
- "--file",
- required=True,
- help="File to upload",
- )
- sub.add_argument(
- "-p",
- "--purpose",
- help="Why are you uploading this file? (see https://platform.openai.com/docs/api-reference/ for purposes)",
- required=True,
- )
- sub.set_defaults(func=File.create)
-
- sub = subparsers.add_parser("files.get")
- sub.add_argument("-i", "--id", required=True, help="The files ID")
- sub.set_defaults(func=File.get)
-
- sub = subparsers.add_parser("files.delete")
- sub.add_argument("-i", "--id", required=True, help="The files ID")
- sub.set_defaults(func=File.delete)
-
- sub = subparsers.add_parser("files.list")
- sub.set_defaults(func=File.list)
-
- # Finetune
- sub = subparsers.add_parser("fine_tunes.list")
- sub.set_defaults(func=FineTune.list)
-
- sub = subparsers.add_parser("fine_tunes.create")
- sub.add_argument(
- "-t",
- "--training_file",
- required=True,
- help="JSONL file containing prompt-completion examples for training. This can "
- "be the ID of a file uploaded through the OpenAI API (e.g. file-abcde12345), "
- 'a local file path, or a URL that starts with "http".',
- )
- sub.add_argument(
- "-v",
- "--validation_file",
- help="JSONL file containing prompt-completion examples for validation. This can "
- "be the ID of a file uploaded through the OpenAI API (e.g. file-abcde12345), "
- 'a local file path, or a URL that starts with "http".',
- )
- sub.add_argument(
- "--no_check_if_files_exist",
- dest="check_if_files_exist",
- action="store_false",
- help="If this argument is set and training_file or validation_file are file paths, immediately upload them. If this argument is not set, check if they may be duplicates of already uploaded files before uploading, based on file name and file size.",
- )
- sub.add_argument(
- "-m",
- "--model",
- help="The model to start fine-tuning from",
- )
- sub.add_argument(
- "--suffix",
- help="If set, this argument can be used to customize the generated fine-tuned model name."
- "All punctuation and whitespace in `suffix` will be replaced with a "
- "single dash, and the string will be lower cased. The max "
- "length of `suffix` is 40 chars. "
- "The generated name will match the form `{base_model}:ft-{org-title}:{suffix}-{timestamp}`. "
- 'For example, `openai api fine_tunes.create -t test.jsonl -m ada --suffix "custom model name" '
- "could generate a model with the name "
- "ada:ft-your-org:custom-model-name-2022-02-15-04-21-04",
- )
- sub.add_argument(
- "--no_follow",
- action="store_true",
- help="If set, returns immediately after creating the job. Otherwise, streams events and waits for the job to complete.",
- )
- sub.add_argument(
- "--n_epochs",
- type=int,
- help="The number of epochs to train the model for. An epoch refers to one "
- "full cycle through the training dataset.",
- )
- sub.add_argument(
- "--batch_size",
- type=int,
- help="The batch size to use for training. The batch size is the number of "
- "training examples used to train a single forward and backward pass.",
- )
- sub.add_argument(
- "--learning_rate_multiplier",
- type=float,
- help="The learning rate multiplier to use for training. The fine-tuning "
- "learning rate is determined by the original learning rate used for "
- "pretraining multiplied by this value.",
- )
- sub.add_argument(
- "--prompt_loss_weight",
- type=float,
- help="The weight to use for the prompt loss. The optimum value here depends "
- "depends on your use case. This determines how much the model prioritizes "
- "learning from prompt tokens vs learning from completion tokens.",
- )
- sub.add_argument(
- "--compute_classification_metrics",
- action="store_true",
- help="If set, we calculate classification-specific metrics such as accuracy "
- "and F-1 score using the validation set at the end of every epoch.",
- )
- sub.set_defaults(compute_classification_metrics=None)
- sub.add_argument(
- "--classification_n_classes",
- type=int,
- help="The number of classes in a classification task. This parameter is "
- "required for multiclass classification.",
- )
- sub.add_argument(
- "--classification_positive_class",
- help="The positive class in binary classification. This parameter is needed "
- "to generate precision, recall and F-1 metrics when doing binary "
- "classification.",
- )
- sub.add_argument(
- "--classification_betas",
- type=float,
- nargs="+",
- help="If this is provided, we calculate F-beta scores at the specified beta "
- "values. The F-beta score is a generalization of F-1 score. This is only "
- "used for binary classification.",
- )
- sub.set_defaults(func=FineTune.create)
-
- sub = subparsers.add_parser("fine_tunes.get")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
- sub.set_defaults(func=FineTune.get)
-
- sub = subparsers.add_parser("fine_tunes.results")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
- sub.set_defaults(func=FineTune.results)
-
- sub = subparsers.add_parser("fine_tunes.events")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
-
- # TODO(rachel): Remove this in 1.0
- sub.add_argument(
- "-s",
- "--stream",
- action="store_true",
- help="[DEPRECATED] If set, events will be streamed until the job is done. Otherwise, "
- "displays the event history to date.",
- )
- sub.set_defaults(func=FineTune.events)
-
- sub = subparsers.add_parser("fine_tunes.follow")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
- sub.set_defaults(func=FineTune.follow)
-
- sub = subparsers.add_parser("fine_tunes.cancel")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
- sub.set_defaults(func=FineTune.cancel)
-
- sub = subparsers.add_parser("fine_tunes.delete")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
- sub.set_defaults(func=FineTune.delete)
-
- # Image
- sub = subparsers.add_parser("image.create")
- sub.add_argument("-p", "--prompt", type=str, required=True)
- sub.add_argument("-n", "--num-images", type=int, default=1)
- sub.add_argument(
- "-s", "--size", type=str, default="1024x1024", help="Size of the output image"
- )
- sub.add_argument("--response-format", type=str, default="url")
- sub.set_defaults(func=Image.create)
-
- sub = subparsers.add_parser("image.create_edit")
- sub.add_argument("-p", "--prompt", type=str, required=True)
- sub.add_argument("-n", "--num-images", type=int, default=1)
- sub.add_argument(
- "-I",
- "--image",
- type=str,
- required=True,
- help="Image to modify. Should be a local path and a PNG encoded image.",
- )
- sub.add_argument(
- "-s", "--size", type=str, default="1024x1024", help="Size of the output image"
- )
- sub.add_argument("--response-format", type=str, default="url")
- sub.add_argument(
- "-M",
- "--mask",
- type=str,
- required=False,
- help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
- )
- sub.set_defaults(func=Image.create_edit)
-
- sub = subparsers.add_parser("image.create_variation")
- sub.add_argument("-n", "--num-images", type=int, default=1)
- sub.add_argument(
- "-I",
- "--image",
- type=str,
- required=True,
- help="Image to modify. Should be a local path and a PNG encoded image.",
- )
- sub.add_argument(
- "-s", "--size", type=str, default="1024x1024", help="Size of the output image"
- )
- sub.add_argument("--response-format", type=str, default="url")
- sub.set_defaults(func=Image.create_variation)
-
- # Audio
- # transcriptions
- sub = subparsers.add_parser("audio.transcribe")
- # Required
- sub.add_argument("-m", "--model", type=str, default="whisper-1")
- sub.add_argument("-f", "--file", type=str, required=True)
- # Optional
- sub.add_argument("--response-format", type=str)
- sub.add_argument("--language", type=str)
- sub.add_argument("-t", "--temperature", type=float)
- sub.add_argument("--prompt", type=str)
- sub.set_defaults(func=Audio.transcribe)
- # translations
- sub = subparsers.add_parser("audio.translate")
- # Required
- sub.add_argument("-m", "--model", type=str, default="whisper-1")
- sub.add_argument("-f", "--file", type=str, required=True)
- # Optional
- sub.add_argument("--response-format", type=str)
- sub.add_argument("--language", type=str)
- sub.add_argument("-t", "--temperature", type=float)
- sub.add_argument("--prompt", type=str)
- sub.set_defaults(func=Audio.translate)
-
- # FineTuning Jobs
- sub = subparsers.add_parser("fine_tuning.job.list")
- sub.set_defaults(func=FineTuningJob.list)
-
- sub = subparsers.add_parser("fine_tuning.job.create")
- sub.add_argument(
- "-t",
- "--training_file",
- required=True,
- help="JSONL file containing either chat-completion or prompt-completion examples for training. "
- "This can be the ID of a file uploaded through the OpenAI API (e.g. file-abcde12345), "
- 'a local file path, or a URL that starts with "http".',
- )
- sub.add_argument(
- "-v",
- "--validation_file",
- help="JSONL file containing either chat-completion or prompt-completion examples for validation. "
- "This can be the ID of a file uploaded through the OpenAI API (e.g. file-abcde12345), "
- 'a local file path, or a URL that starts with "http".',
- )
- sub.add_argument(
- "--no_check_if_files_exist",
- dest="check_if_files_exist",
- action="store_false",
- help="If this argument is set and training_file or validation_file are file paths, immediately upload them. If this argument is not set, check if they may be duplicates of already uploaded files before uploading, based on file name and file size.",
- )
- sub.add_argument(
- "-m",
- "--model",
- help="The model to start fine-tuning from",
- )
- sub.add_argument(
- "--suffix",
- help="If set, this argument can be used to customize the generated fine-tuned model name."
- "All punctuation and whitespace in `suffix` will be replaced with a "
- "single dash, and the string will be lower cased. The max "
- "length of `suffix` is 18 chars. "
- "The generated name will match the form `ft:{base_model}:{org-title}:{suffix}:{rstring}` where `rstring` "
- "is a random string sortable as a timestamp. "
- 'For example, `openai api fine_tuning.job.create -t test.jsonl -m gpt-3.5-turbo-0613 --suffix "first finetune!" '
- "could generate a model with the name "
- "ft:gpt-3.5-turbo-0613:your-org:first-finetune:7p4PqAoY",
- )
- sub.add_argument(
- "--n_epochs",
- type=int,
- help="The number of epochs to train the model for. An epoch refers to one "
- "full cycle through the training dataset.",
- )
- sub.set_defaults(func=FineTuningJob.create)
-
- sub = subparsers.add_parser("fine_tuning.job.get")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
- sub.set_defaults(func=FineTuningJob.get)
-
- sub = subparsers.add_parser("fine_tuning.job.results")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
- sub.set_defaults(func=FineTuningJob.results)
-
- sub = subparsers.add_parser("fine_tuning.job.events")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
- sub.add_argument(
- "--limit",
- type=int,
- required=False,
- help="The number of events to return, starting from most recent. If not specified, all events will be returned.",
- )
- sub.set_defaults(func=FineTuningJob.events)
-
- sub = subparsers.add_parser("fine_tuning.job.follow")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
- sub.set_defaults(func=FineTuningJob.follow)
-
- sub = subparsers.add_parser("fine_tuning.job.cancel")
- sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
- sub.set_defaults(func=FineTuningJob.cancel)
-
-
-def wandb_register(parser):
- subparsers = parser.add_subparsers(
- title="wandb", help="Logging with Weights & Biases, see https://docs.wandb.ai/guides/integrations/openai for documentation"
- )
-
- def help(args):
- parser.print_help()
-
- parser.set_defaults(func=help)
-
- sub = subparsers.add_parser("sync")
- sub.add_argument("-i", "--id", help="The id of the fine-tune job (optional)")
- sub.add_argument(
- "-n",
- "--n_fine_tunes",
- type=int,
- default=None,
- help="Number of most recent fine-tunes to log when an id is not provided. By default, every fine-tune is synced.",
- )
- sub.add_argument(
- "--project",
- default="OpenAI-Fine-Tune",
- help="""Name of the Weights & Biases project where you're sending runs. By default, it is "OpenAI-Fine-Tune".""",
- )
- sub.add_argument(
- "--entity",
- help="Weights & Biases username or team name where you're sending runs. By default, your default entity is used, which is usually your username.",
- )
- sub.add_argument(
- "--force",
- action="store_true",
- help="Forces logging and overwrite existing wandb run of the same fine-tune.",
- )
- sub.add_argument(
- "--legacy",
- action="store_true",
- help="Log results from legacy OpenAI /v1/fine-tunes api",
- )
- sub.set_defaults(force=False)
- sub.set_defaults(legacy=False)
- sub.set_defaults(func=WandbLogger.sync)
diff --git a/openai/datalib/__init__.py b/openai/datalib/__init__.py
deleted file mode 100644
index d02b49cfff..0000000000
--- a/openai/datalib/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-"""
-This module helps make data libraries like `numpy` and `pandas` optional dependencies.
-
-The libraries add up to 130MB+, which makes it challenging to deploy applications
-using this library in environments with code size constraints, like AWS Lambda.
-
-This module serves as an import proxy and provides a few utilities for dealing with the optionality.
-
-Since the primary use case of this library (talking to the OpenAI API) doesn't generally require data libraries,
-it's safe to make them optional. The rare case when data libraries are needed in the client is handled through
-assertions with instructive error messages.
-
-See also `setup.py`.
-"""
diff --git a/openai/datalib/common.py b/openai/datalib/common.py
deleted file mode 100644
index 96f9908a18..0000000000
--- a/openai/datalib/common.py
+++ /dev/null
@@ -1,17 +0,0 @@
-INSTRUCTIONS = """
-
-OpenAI error:
-
- missing `{library}`
-
-This feature requires additional dependencies:
-
- $ pip install openai[datalib]
-
-"""
-
-NUMPY_INSTRUCTIONS = INSTRUCTIONS.format(library="numpy")
-
-
-class MissingDependencyError(Exception):
- pass
diff --git a/openai/datalib/numpy_helper.py b/openai/datalib/numpy_helper.py
deleted file mode 100644
index fb80f2ae54..0000000000
--- a/openai/datalib/numpy_helper.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from openai.datalib.common import INSTRUCTIONS, MissingDependencyError
-
-try:
- import numpy
-except ImportError:
- numpy = None
-
-HAS_NUMPY = bool(numpy)
-
-NUMPY_INSTRUCTIONS = INSTRUCTIONS.format(library="numpy")
-
-
-def assert_has_numpy():
- if not HAS_NUMPY:
- raise MissingDependencyError(NUMPY_INSTRUCTIONS)
diff --git a/openai/datalib/pandas_helper.py b/openai/datalib/pandas_helper.py
deleted file mode 100644
index 4e86d7b4f9..0000000000
--- a/openai/datalib/pandas_helper.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from openai.datalib.common import INSTRUCTIONS, MissingDependencyError
-
-try:
- import pandas
-except ImportError:
- pandas = None
-
-HAS_PANDAS = bool(pandas)
-
-PANDAS_INSTRUCTIONS = INSTRUCTIONS.format(library="pandas")
-
-
-def assert_has_pandas():
- if not HAS_PANDAS:
- raise MissingDependencyError(PANDAS_INSTRUCTIONS)
diff --git a/openai/embeddings_utils.py b/openai/embeddings_utils.py
deleted file mode 100644
index dc26445c3c..0000000000
--- a/openai/embeddings_utils.py
+++ /dev/null
@@ -1,252 +0,0 @@
-import textwrap as tr
-from typing import List, Optional
-
-import matplotlib.pyplot as plt
-import plotly.express as px
-from scipy import spatial
-from sklearn.decomposition import PCA
-from sklearn.manifold import TSNE
-from sklearn.metrics import average_precision_score, precision_recall_curve
-from tenacity import retry, stop_after_attempt, wait_random_exponential
-
-import openai
-from openai.datalib.numpy_helper import numpy as np
-from openai.datalib.pandas_helper import pandas as pd
-
-
-@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-def get_embedding(text: str, engine="text-embedding-ada-002", **kwargs) -> List[float]:
-
- # replace newlines, which can negatively affect performance.
- text = text.replace("\n", " ")
-
- return openai.Embedding.create(input=[text], engine=engine, **kwargs)["data"][0]["embedding"]
-
-
-@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embedding(
- text: str, engine="text-embedding-ada-002", **kwargs
-) -> List[float]:
-
- # replace newlines, which can negatively affect performance.
- text = text.replace("\n", " ")
-
- return (await openai.Embedding.acreate(input=[text], engine=engine, **kwargs))["data"][0][
- "embedding"
- ]
-
-
-@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-def get_embeddings(
- list_of_text: List[str], engine="text-embedding-ada-002", **kwargs
-) -> List[List[float]]:
- assert len(list_of_text) <= 8191, "The batch size should not be larger than 8191."
-
- # replace newlines, which can negatively affect performance.
- list_of_text = [text.replace("\n", " ") for text in list_of_text]
-
- data = openai.Embedding.create(input=list_of_text, engine=engine, **kwargs).data
- return [d["embedding"] for d in data]
-
-
-@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embeddings(
- list_of_text: List[str], engine="text-embedding-ada-002", **kwargs
-) -> List[List[float]]:
- assert len(list_of_text) <= 8191, "The batch size should not be larger than 8191."
-
- # replace newlines, which can negatively affect performance.
- list_of_text = [text.replace("\n", " ") for text in list_of_text]
-
- data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, **kwargs)).data
- return [d["embedding"] for d in data]
-
-
-def cosine_similarity(a, b):
- return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
-
-
-def plot_multiclass_precision_recall(
- y_score, y_true_untransformed, class_list, classifier_name
-):
- """
- Precision-Recall plotting for a multiclass problem. It plots average precision-recall, per class precision recall and reference f1 contours.
-
- Code slightly modified, but heavily based on https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html
- """
- n_classes = len(class_list)
- y_true = pd.concat(
- [(y_true_untransformed == class_list[i]) for i in range(n_classes)], axis=1
- ).values
-
- # For each class
- precision = dict()
- recall = dict()
- average_precision = dict()
- for i in range(n_classes):
- precision[i], recall[i], _ = precision_recall_curve(y_true[:, i], y_score[:, i])
- average_precision[i] = average_precision_score(y_true[:, i], y_score[:, i])
-
- # A "micro-average": quantifying score on all classes jointly
- precision_micro, recall_micro, _ = precision_recall_curve(
- y_true.ravel(), y_score.ravel()
- )
- average_precision_micro = average_precision_score(y_true, y_score, average="micro")
- print(
- str(classifier_name)
- + " - Average precision score over all classes: {0:0.2f}".format(
- average_precision_micro
- )
- )
-
- # setup plot details
- plt.figure(figsize=(9, 10))
- f_scores = np.linspace(0.2, 0.8, num=4)
- lines = []
- labels = []
- for f_score in f_scores:
- x = np.linspace(0.01, 1)
- y = f_score * x / (2 * x - f_score)
- (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
- plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))
-
- lines.append(l)
- labels.append("iso-f1 curves")
- (l,) = plt.plot(recall_micro, precision_micro, color="gold", lw=2)
- lines.append(l)
- labels.append(
- "average Precision-recall (auprc = {0:0.2f})" "".format(average_precision_micro)
- )
-
- for i in range(n_classes):
- (l,) = plt.plot(recall[i], precision[i], lw=2)
- lines.append(l)
- labels.append(
- "Precision-recall for class `{0}` (auprc = {1:0.2f})"
- "".format(class_list[i], average_precision[i])
- )
-
- fig = plt.gcf()
- fig.subplots_adjust(bottom=0.25)
- plt.xlim([0.0, 1.0])
- plt.ylim([0.0, 1.05])
- plt.xlabel("Recall")
- plt.ylabel("Precision")
- plt.title(f"{classifier_name}: Precision-Recall curve for each class")
- plt.legend(lines, labels)
-
-
-def distances_from_embeddings(
- query_embedding: List[float],
- embeddings: List[List[float]],
- distance_metric="cosine",
-) -> List[List]:
- """Return the distances between a query embedding and a list of embeddings."""
- distance_metrics = {
- "cosine": spatial.distance.cosine,
- "L1": spatial.distance.cityblock,
- "L2": spatial.distance.euclidean,
- "Linf": spatial.distance.chebyshev,
- }
- distances = [
- distance_metrics[distance_metric](query_embedding, embedding)
- for embedding in embeddings
- ]
- return distances
-
-
-def indices_of_nearest_neighbors_from_distances(distances) -> np.ndarray:
- """Return a list of indices of nearest neighbors from a list of distances."""
- return np.argsort(distances)
-
-
-def pca_components_from_embeddings(
- embeddings: List[List[float]], n_components=2
-) -> np.ndarray:
- """Return the PCA components of a list of embeddings."""
- pca = PCA(n_components=n_components)
- array_of_embeddings = np.array(embeddings)
- return pca.fit_transform(array_of_embeddings)
-
-
-def tsne_components_from_embeddings(
- embeddings: List[List[float]], n_components=2, **kwargs
-) -> np.ndarray:
- """Returns t-SNE components of a list of embeddings."""
- # use better defaults if not specified
- if "init" not in kwargs.keys():
- kwargs["init"] = "pca"
- if "learning_rate" not in kwargs.keys():
- kwargs["learning_rate"] = "auto"
- tsne = TSNE(n_components=n_components, **kwargs)
- array_of_embeddings = np.array(embeddings)
- return tsne.fit_transform(array_of_embeddings)
-
-
-def chart_from_components(
- components: np.ndarray,
- labels: Optional[List[str]] = None,
- strings: Optional[List[str]] = None,
- x_title="Component 0",
- y_title="Component 1",
- mark_size=5,
- **kwargs,
-):
- """Return an interactive 2D chart of embedding components."""
- empty_list = ["" for _ in components]
- data = pd.DataFrame(
- {
- x_title: components[:, 0],
- y_title: components[:, 1],
- "label": labels if labels else empty_list,
- "string": ["
".join(tr.wrap(string, width=30)) for string in strings]
- if strings
- else empty_list,
- }
- )
- chart = px.scatter(
- data,
- x=x_title,
- y=y_title,
- color="label" if labels else None,
- symbol="label" if labels else None,
- hover_data=["string"] if strings else None,
- **kwargs,
- ).update_traces(marker=dict(size=mark_size))
- return chart
-
-
-def chart_from_components_3D(
- components: np.ndarray,
- labels: Optional[List[str]] = None,
- strings: Optional[List[str]] = None,
- x_title: str = "Component 0",
- y_title: str = "Component 1",
- z_title: str = "Compontent 2",
- mark_size: int = 5,
- **kwargs,
-):
- """Return an interactive 3D chart of embedding components."""
- empty_list = ["" for _ in components]
- data = pd.DataFrame(
- {
- x_title: components[:, 0],
- y_title: components[:, 1],
- z_title: components[:, 2],
- "label": labels if labels else empty_list,
- "string": ["
".join(tr.wrap(string, width=30)) for string in strings]
- if strings
- else empty_list,
- }
- )
- chart = px.scatter_3d(
- data,
- x=x_title,
- y=y_title,
- z=z_title,
- color="label" if labels else None,
- symbol="label" if labels else None,
- hover_data=["string"] if strings else None,
- **kwargs,
- ).update_traces(marker=dict(size=mark_size))
- return chart
diff --git a/openai/error.py b/openai/error.py
deleted file mode 100644
index 2928ef6aa6..0000000000
--- a/openai/error.py
+++ /dev/null
@@ -1,169 +0,0 @@
-import openai
-
-
-class OpenAIError(Exception):
- def __init__(
- self,
- message=None,
- http_body=None,
- http_status=None,
- json_body=None,
- headers=None,
- code=None,
- ):
- super(OpenAIError, self).__init__(message)
-
- if http_body and hasattr(http_body, "decode"):
- try:
- http_body = http_body.decode("utf-8")
- except BaseException:
- http_body = (
- ""
- )
-
- self._message = message
- self.http_body = http_body
- self.http_status = http_status
- self.json_body = json_body
- self.headers = headers or {}
- self.code = code
- self.request_id = self.headers.get("request-id", None)
- self.error = self.construct_error_object()
- self.organization = self.headers.get("openai-organization", None)
-
- def __str__(self):
- msg = self._message or ""
- if self.request_id is not None:
- return "Request {0}: {1}".format(self.request_id, msg)
- else:
- return msg
-
- # Returns the underlying `Exception` (base class) message, which is usually
- # the raw message returned by OpenAI's API. This was previously available
- # in python2 via `error.message`. Unlike `str(error)`, it omits "Request
- # req_..." from the beginning of the string.
- @property
- def user_message(self):
- return self._message
-
- def __repr__(self):
- return "%s(message=%r, http_status=%r, request_id=%r)" % (
- self.__class__.__name__,
- self._message,
- self.http_status,
- self.request_id,
- )
-
- def construct_error_object(self):
- if (
- self.json_body is None
- or not isinstance(self.json_body, dict)
- or "error" not in self.json_body
- or not isinstance(self.json_body["error"], dict)
- ):
- return None
-
- return openai.api_resources.error_object.ErrorObject.construct_from(
- self.json_body["error"]
- )
-
-
-class APIError(OpenAIError):
- pass
-
-
-class TryAgain(OpenAIError):
- pass
-
-
-class Timeout(OpenAIError):
- pass
-
-
-class APIConnectionError(OpenAIError):
- def __init__(
- self,
- message,
- http_body=None,
- http_status=None,
- json_body=None,
- headers=None,
- code=None,
- should_retry=False,
- ):
- super(APIConnectionError, self).__init__(
- message, http_body, http_status, json_body, headers, code
- )
- self.should_retry = should_retry
-
-
-class InvalidRequestError(OpenAIError):
- def __init__(
- self,
- message,
- param,
- code=None,
- http_body=None,
- http_status=None,
- json_body=None,
- headers=None,
- ):
- super(InvalidRequestError, self).__init__(
- message, http_body, http_status, json_body, headers, code
- )
- self.param = param
-
- def __repr__(self):
- return "%s(message=%r, param=%r, code=%r, http_status=%r, " "request_id=%r)" % (
- self.__class__.__name__,
- self._message,
- self.param,
- self.code,
- self.http_status,
- self.request_id,
- )
-
- def __reduce__(self):
- return type(self), (
- self._message,
- self.param,
- self.code,
- self.http_body,
- self.http_status,
- self.json_body,
- self.headers,
- )
-
-
-class AuthenticationError(OpenAIError):
- pass
-
-
-class PermissionError(OpenAIError):
- pass
-
-
-class RateLimitError(OpenAIError):
- pass
-
-
-class ServiceUnavailableError(OpenAIError):
- pass
-
-
-class InvalidAPIType(OpenAIError):
- pass
-
-
-class SignatureVerificationError(OpenAIError):
- def __init__(self, message, sig_header, http_body=None):
- super(SignatureVerificationError, self).__init__(message, http_body)
- self.sig_header = sig_header
-
- def __reduce__(self):
- return type(self), (
- self._message,
- self.sig_header,
- self.http_body,
- )
diff --git a/openai/object_classes.py b/openai/object_classes.py
deleted file mode 100644
index 08093650fd..0000000000
--- a/openai/object_classes.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from openai import api_resources
-from openai.api_resources.experimental.completion_config import CompletionConfig
-
-OBJECT_CLASSES = {
- "engine": api_resources.Engine,
- "experimental.completion_config": CompletionConfig,
- "file": api_resources.File,
- "fine-tune": api_resources.FineTune,
- "model": api_resources.Model,
- "deployment": api_resources.Deployment,
- "fine_tuning.job": api_resources.FineTuningJob,
-}
diff --git a/openai/openai_object.py b/openai/openai_object.py
deleted file mode 100644
index 95f8829742..0000000000
--- a/openai/openai_object.py
+++ /dev/null
@@ -1,347 +0,0 @@
-import json
-from copy import deepcopy
-from typing import Optional, Tuple, Union
-
-import openai
-from openai import api_requestor, util
-from openai.openai_response import OpenAIResponse
-from openai.util import ApiType
-
-
-class OpenAIObject(dict):
- api_base_override = None
-
- def __init__(
- self,
- id=None,
- api_key=None,
- api_version=None,
- api_type=None,
- organization=None,
- response_ms: Optional[int] = None,
- api_base=None,
- engine=None,
- **params,
- ):
- super(OpenAIObject, self).__init__()
-
- if response_ms is not None and not isinstance(response_ms, int):
- raise TypeError(f"response_ms is a {type(response_ms).__name__}.")
- self._response_ms = response_ms
-
- self._retrieve_params = params
-
- object.__setattr__(self, "api_key", api_key)
- object.__setattr__(self, "api_version", api_version)
- object.__setattr__(self, "api_type", api_type)
- object.__setattr__(self, "organization", organization)
- object.__setattr__(self, "api_base_override", api_base)
- object.__setattr__(self, "engine", engine)
-
- if id:
- self["id"] = id
-
- @property
- def response_ms(self) -> Optional[int]:
- return self._response_ms
-
- def __setattr__(self, k, v):
- if k[0] == "_" or k in self.__dict__:
- return super(OpenAIObject, self).__setattr__(k, v)
-
- self[k] = v
- return None
-
- def __getattr__(self, k):
- if k[0] == "_":
- raise AttributeError(k)
- try:
- return self[k]
- except KeyError as err:
- raise AttributeError(*err.args)
-
- def __delattr__(self, k):
- if k[0] == "_" or k in self.__dict__:
- return super(OpenAIObject, self).__delattr__(k)
- else:
- del self[k]
-
- def __setitem__(self, k, v):
- if v == "":
- raise ValueError(
- "You cannot set %s to an empty string. "
- "We interpret empty strings as None in requests."
- "You may set %s.%s = None to delete the property" % (k, str(self), k)
- )
- super(OpenAIObject, self).__setitem__(k, v)
-
- def __delitem__(self, k):
- raise NotImplementedError("del is not supported")
-
- # Custom unpickling method that uses `update` to update the dictionary
- # without calling __setitem__, which would fail if any value is an empty
- # string
- def __setstate__(self, state):
- self.update(state)
-
- # Custom pickling method to ensure the instance is pickled as a custom
- # class and not as a dict, otherwise __setstate__ would not be called when
- # unpickling.
- def __reduce__(self):
- reduce_value = (
- type(self), # callable
- ( # args
- self.get("id", None),
- self.api_key,
- self.api_version,
- self.api_type,
- self.organization,
- ),
- dict(self), # state
- )
- return reduce_value
-
- @classmethod
- def construct_from(
- cls,
- values,
- api_key: Optional[str] = None,
- api_version=None,
- organization=None,
- engine=None,
- response_ms: Optional[int] = None,
- ):
- instance = cls(
- values.get("id"),
- api_key=api_key,
- api_version=api_version,
- organization=organization,
- engine=engine,
- response_ms=response_ms,
- )
- instance.refresh_from(
- values,
- api_key=api_key,
- api_version=api_version,
- organization=organization,
- response_ms=response_ms,
- )
- return instance
-
- def refresh_from(
- self,
- values,
- api_key=None,
- api_version=None,
- api_type=None,
- organization=None,
- response_ms: Optional[int] = None,
- ):
- self.api_key = api_key or getattr(values, "api_key", None)
- self.api_version = api_version or getattr(values, "api_version", None)
- self.api_type = api_type or getattr(values, "api_type", None)
- self.organization = organization or getattr(values, "organization", None)
- self._response_ms = response_ms or getattr(values, "_response_ms", None)
-
- # Wipe old state before setting new.
- self.clear()
- for k, v in values.items():
- super(OpenAIObject, self).__setitem__(
- k, util.convert_to_openai_object(v, api_key, api_version, organization)
- )
-
- self._previous = values
-
- @classmethod
- def api_base(cls):
- return None
-
- def request(
- self,
- method,
- url,
- params=None,
- headers=None,
- stream=False,
- plain_old_data=False,
- request_id: Optional[str] = None,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
- ):
- if params is None:
- params = self._retrieve_params
- requestor = api_requestor.APIRequestor(
- key=self.api_key,
- api_base=self.api_base_override or self.api_base(),
- api_type=self.api_type,
- api_version=self.api_version,
- organization=self.organization,
- )
- response, stream, api_key = requestor.request(
- method,
- url,
- params=params,
- stream=stream,
- headers=headers,
- request_id=request_id,
- request_timeout=request_timeout,
- )
-
- if stream:
- assert not isinstance(response, OpenAIResponse) # must be an iterator
- return (
- util.convert_to_openai_object(
- line,
- api_key,
- self.api_version,
- self.organization,
- plain_old_data=plain_old_data,
- )
- for line in response
- )
- else:
- return util.convert_to_openai_object(
- response,
- api_key,
- self.api_version,
- self.organization,
- plain_old_data=plain_old_data,
- )
-
- async def arequest(
- self,
- method,
- url,
- params=None,
- headers=None,
- stream=False,
- plain_old_data=False,
- request_id: Optional[str] = None,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
- ):
- if params is None:
- params = self._retrieve_params
- requestor = api_requestor.APIRequestor(
- key=self.api_key,
- api_base=self.api_base_override or self.api_base(),
- api_type=self.api_type,
- api_version=self.api_version,
- organization=self.organization,
- )
- response, stream, api_key = await requestor.arequest(
- method,
- url,
- params=params,
- stream=stream,
- headers=headers,
- request_id=request_id,
- request_timeout=request_timeout,
- )
-
- if stream:
- assert not isinstance(response, OpenAIResponse) # must be an iterator
- return (
- util.convert_to_openai_object(
- line,
- api_key,
- self.api_version,
- self.organization,
- plain_old_data=plain_old_data,
- )
- for line in response
- )
- else:
- return util.convert_to_openai_object(
- response,
- api_key,
- self.api_version,
- self.organization,
- plain_old_data=plain_old_data,
- )
-
- def __repr__(self):
- ident_parts = [type(self).__name__]
-
- obj = self.get("object")
- if isinstance(obj, str):
- ident_parts.append(obj)
-
- if isinstance(self.get("id"), str):
- ident_parts.append("id=%s" % (self.get("id"),))
-
- unicode_repr = "<%s at %s> JSON: %s" % (
- " ".join(ident_parts),
- hex(id(self)),
- str(self),
- )
-
- return unicode_repr
-
- def __str__(self):
- obj = self.to_dict_recursive()
- return json.dumps(obj, indent=2)
-
- def to_dict(self):
- return dict(self)
-
- def to_dict_recursive(self):
- d = dict(self)
- for k, v in d.items():
- if isinstance(v, OpenAIObject):
- d[k] = v.to_dict_recursive()
- elif isinstance(v, list):
- d[k] = [
- e.to_dict_recursive() if isinstance(e, OpenAIObject) else e
- for e in v
- ]
- return d
-
- @property
- def openai_id(self):
- return self.id
-
- @property
- def typed_api_type(self):
- return (
- ApiType.from_str(self.api_type)
- if self.api_type
- else ApiType.from_str(openai.api_type)
- )
-
- # This class overrides __setitem__ to throw exceptions on inputs that it
- # doesn't like. This can cause problems when we try to copy an object
- # wholesale because some data that's returned from the API may not be valid
- # if it was set to be set manually. Here we override the class' copy
- # arguments so that we can bypass these possible exceptions on __setitem__.
- def __copy__(self):
- copied = OpenAIObject(
- self.get("id"),
- self.api_key,
- api_version=self.api_version,
- api_type=self.api_type,
- organization=self.organization,
- )
-
- copied._retrieve_params = self._retrieve_params
-
- for k, v in self.items():
- # Call parent's __setitem__ to avoid checks that we've added in the
- # overridden version that can throw exceptions.
- super(OpenAIObject, copied).__setitem__(k, v)
-
- return copied
-
- # This class overrides __setitem__ to throw exceptions on inputs that it
- # doesn't like. This can cause problems when we try to copy an object
- # wholesale because some data that's returned from the API may not be valid
- # if it was set to be set manually. Here we override the class' copy
- # arguments so that we can bypass these possible exceptions on __setitem__.
- def __deepcopy__(self, memo):
- copied = self.__copy__()
- memo[id(self)] = copied
-
- for k, v in self.items():
- # Call parent's __setitem__ to avoid checks that we've added in the
- # overridden version that can throw exceptions.
- super(OpenAIObject, copied).__setitem__(k, deepcopy(v, memo))
-
- return copied
diff --git a/openai/openai_response.py b/openai/openai_response.py
deleted file mode 100644
index d2230b1540..0000000000
--- a/openai/openai_response.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from typing import Optional
-
-
-class OpenAIResponse:
- def __init__(self, data, headers):
- self._headers = headers
- self.data = data
-
- @property
- def request_id(self) -> Optional[str]:
- return self._headers.get("request-id")
-
- @property
- def retry_after(self) -> Optional[int]:
- try:
- return int(self._headers.get("retry-after"))
- except TypeError:
- return None
-
- @property
- def operation_location(self) -> Optional[str]:
- return self._headers.get("operation-location")
-
- @property
- def organization(self) -> Optional[str]:
- return self._headers.get("OpenAI-Organization")
-
- @property
- def response_ms(self) -> Optional[int]:
- h = self._headers.get("Openai-Processing-Ms")
- return None if h is None else round(float(h))
diff --git a/openai/tests/__init__.py b/openai/tests/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/openai/tests/asyncio/__init__.py b/openai/tests/asyncio/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/openai/tests/asyncio/test_endpoints.py b/openai/tests/asyncio/test_endpoints.py
deleted file mode 100644
index 1b146e6749..0000000000
--- a/openai/tests/asyncio/test_endpoints.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import io
-import json
-
-import pytest
-from aiohttp import ClientSession
-
-import openai
-from openai import error
-
-pytestmark = [pytest.mark.asyncio]
-
-
-# FILE TESTS
-async def test_file_upload():
- result = await openai.File.acreate(
- file=io.StringIO(
- json.dumps({"prompt": "test file data", "completion": "tada"})
- ),
- purpose="fine-tune",
- )
- assert result.purpose == "fine-tune"
- assert "id" in result
-
- result = await openai.File.aretrieve(id=result.id)
- assert result.status == "uploaded"
-
-
-# COMPLETION TESTS
-async def test_completions():
- result = await openai.Completion.acreate(
- prompt="This was a test", n=5, engine="ada"
- )
- assert len(result.choices) == 5
-
-
-async def test_completions_multiple_prompts():
- result = await openai.Completion.acreate(
- prompt=["This was a test", "This was another test"], n=5, engine="ada"
- )
- assert len(result.choices) == 10
-
-
-async def test_completions_model():
- result = await openai.Completion.acreate(prompt="This was a test", n=5, model="ada")
- assert len(result.choices) == 5
- assert result.model.startswith("ada")
-
-
-async def test_timeout_raises_error():
- # A query that should take awhile to return
- with pytest.raises(error.Timeout):
- await openai.Completion.acreate(
- prompt="test" * 1000,
- n=10,
- model="ada",
- max_tokens=100,
- request_timeout=0.01,
- )
-
-
-async def test_timeout_does_not_error():
- # A query that should be fast
- await openai.Completion.acreate(
- prompt="test",
- model="ada",
- request_timeout=10,
- )
-
-
-async def test_completions_stream_finishes_global_session():
- async with ClientSession() as session:
- openai.aiosession.set(session)
-
- # A query that should be fast
- parts = []
- async for part in await openai.Completion.acreate(
- prompt="test", model="ada", request_timeout=3, stream=True
- ):
- parts.append(part)
- assert len(parts) > 1
-
-
-async def test_completions_stream_finishes_local_session():
- # A query that should be fast
- parts = []
- async for part in await openai.Completion.acreate(
- prompt="test", model="ada", request_timeout=3, stream=True
- ):
- parts.append(part)
- assert len(parts) > 1
diff --git a/openai/tests/test_api_requestor.py b/openai/tests/test_api_requestor.py
deleted file mode 100644
index 56e8ec89da..0000000000
--- a/openai/tests/test_api_requestor.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import json
-
-import pytest
-import requests
-from pytest_mock import MockerFixture
-
-from openai import Model
-from openai.api_requestor import APIRequestor
-
-
-@pytest.mark.requestor
-def test_requestor_sets_request_id(mocker: MockerFixture) -> None:
- # Fake out 'requests' and confirm that the X-Request-Id header is set.
-
- got_headers = {}
-
- def fake_request(self, *args, **kwargs):
- nonlocal got_headers
- got_headers = kwargs["headers"]
- r = requests.Response()
- r.status_code = 200
- r.headers["content-type"] = "application/json"
- r._content = json.dumps({}).encode("utf-8")
- return r
-
- mocker.patch("requests.sessions.Session.request", fake_request)
- fake_request_id = "1234"
- Model.retrieve("xxx", request_id=fake_request_id) # arbitrary API resource
- got_request_id = got_headers.get("X-Request-Id")
- assert got_request_id == fake_request_id
-
-
-@pytest.mark.requestor
-def test_requestor_open_ai_headers() -> None:
- api_requestor = APIRequestor(key="test_key", api_type="open_ai")
- headers = {"Test_Header": "Unit_Test_Header"}
- headers = api_requestor.request_headers(
- method="get", extra=headers, request_id="test_id"
- )
- assert "Test_Header" in headers
- assert headers["Test_Header"] == "Unit_Test_Header"
- assert "Authorization" in headers
- assert headers["Authorization"] == "Bearer test_key"
-
-
-@pytest.mark.requestor
-def test_requestor_azure_headers() -> None:
- api_requestor = APIRequestor(key="test_key", api_type="azure")
- headers = {"Test_Header": "Unit_Test_Header"}
- headers = api_requestor.request_headers(
- method="get", extra=headers, request_id="test_id"
- )
- assert "Test_Header" in headers
- assert headers["Test_Header"] == "Unit_Test_Header"
- assert "api-key" in headers
- assert headers["api-key"] == "test_key"
-
-
-@pytest.mark.requestor
-def test_requestor_azure_ad_headers() -> None:
- api_requestor = APIRequestor(key="test_key", api_type="azure_ad")
- headers = {"Test_Header": "Unit_Test_Header"}
- headers = api_requestor.request_headers(
- method="get", extra=headers, request_id="test_id"
- )
- assert "Test_Header" in headers
- assert headers["Test_Header"] == "Unit_Test_Header"
- assert "Authorization" in headers
- assert headers["Authorization"] == "Bearer test_key"
-
-
-@pytest.mark.requestor
-def test_requestor_cycle_sessions(mocker: MockerFixture) -> None:
- # HACK: we need to purge the _thread_context to not interfere
- # with other tests
- from openai.api_requestor import _thread_context
-
- delattr(_thread_context, "session")
-
- api_requestor = APIRequestor(key="test_key", api_type="azure_ad")
-
- mock_session = mocker.MagicMock()
- mocker.patch("openai.api_requestor._make_session", lambda: mock_session)
-
- # We don't call `session.close()` if not enough time has elapsed
- api_requestor.request_raw("get", "http://example.com")
- mock_session.request.assert_called()
- api_requestor.request_raw("get", "http://example.com")
- mock_session.close.assert_not_called()
-
- mocker.patch("openai.api_requestor.MAX_SESSION_LIFETIME_SECS", 0)
-
- # Due to 0 lifetime, the original session will be closed before the next call
- # and a new session will be created
- mock_session_2 = mocker.MagicMock()
- mocker.patch("openai.api_requestor._make_session", lambda: mock_session_2)
- api_requestor.request_raw("get", "http://example.com")
- mock_session.close.assert_called()
- mock_session_2.request.assert_called()
-
- delattr(_thread_context, "session")
diff --git a/openai/tests/test_endpoints.py b/openai/tests/test_endpoints.py
deleted file mode 100644
index 958e07f091..0000000000
--- a/openai/tests/test_endpoints.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import io
-import json
-
-import pytest
-import requests
-
-import openai
-from openai import error
-
-
-# FILE TESTS
-def test_file_upload():
- result = openai.File.create(
- file=io.StringIO(
- json.dumps({"prompt": "test file data", "completion": "tada"})
- ),
- purpose="fine-tune",
- )
- assert result.purpose == "fine-tune"
- assert "id" in result
-
- result = openai.File.retrieve(id=result.id)
- assert result.status == "uploaded"
-
-
-# CHAT COMPLETION TESTS
-def test_chat_completions():
- result = openai.ChatCompletion.create(
- model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello!"}]
- )
- assert len(result.choices) == 1
-
-
-def test_chat_completions_multiple():
- result = openai.ChatCompletion.create(
- model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello!"}], n=5
- )
- assert len(result.choices) == 5
-
-
-def test_chat_completions_streaming():
- result = None
- events = openai.ChatCompletion.create(
- model="gpt-3.5-turbo",
- messages=[{"role": "user", "content": "Hello!"}],
- stream=True,
- )
- for result in events:
- assert len(result.choices) == 1
-
-
-# COMPLETION TESTS
-def test_completions():
- result = openai.Completion.create(prompt="This was a test", n=5, engine="ada")
- assert len(result.choices) == 5
-
-
-def test_completions_multiple_prompts():
- result = openai.Completion.create(
- prompt=["This was a test", "This was another test"], n=5, engine="ada"
- )
- assert len(result.choices) == 10
-
-
-def test_completions_model():
- result = openai.Completion.create(prompt="This was a test", n=5, model="ada")
- assert len(result.choices) == 5
- assert result.model.startswith("ada")
-
-
-def test_timeout_raises_error():
- # A query that should take awhile to return
- with pytest.raises(error.Timeout):
- openai.Completion.create(
- prompt="test" * 1000,
- n=10,
- model="ada",
- max_tokens=100,
- request_timeout=0.01,
- )
-
-
-def test_timeout_does_not_error():
- # A query that should be fast
- openai.Completion.create(
- prompt="test",
- model="ada",
- request_timeout=10,
- )
-
-
-def test_user_session():
- with requests.Session() as session:
- openai.requestssession = session
-
- completion = openai.Completion.create(
- prompt="hello world",
- model="ada",
- )
- assert completion
-
-
-def test_user_session_factory():
- def factory():
- session = requests.Session()
- session.mount(
- "https://",
- requests.adapters.HTTPAdapter(max_retries=4),
- )
- return session
-
- openai.requestssession = factory
-
- completion = openai.Completion.create(
- prompt="hello world",
- model="ada",
- )
- assert completion
diff --git a/openai/tests/test_exceptions.py b/openai/tests/test_exceptions.py
deleted file mode 100644
index 7760cdc5f6..0000000000
--- a/openai/tests/test_exceptions.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import pickle
-
-import pytest
-
-import openai
-
-EXCEPTION_TEST_CASES = [
- openai.InvalidRequestError(
- "message",
- "param",
- code=400,
- http_body={"test": "test1"},
- http_status="fail",
- json_body={"text": "iono some text"},
- headers={"request-id": "asasd"},
- ),
- openai.error.AuthenticationError(),
- openai.error.PermissionError(),
- openai.error.RateLimitError(),
- openai.error.ServiceUnavailableError(),
- openai.error.SignatureVerificationError("message", "sig_header?"),
- openai.error.APIConnectionError("message!", should_retry=True),
- openai.error.TryAgain(),
- openai.error.Timeout(),
- openai.error.APIError(
- message="message",
- code=400,
- http_body={"test": "test1"},
- http_status="fail",
- json_body={"text": "iono some text"},
- headers={"request-id": "asasd"},
- ),
- openai.error.OpenAIError(),
-]
-
-
-class TestExceptions:
- @pytest.mark.parametrize("error", EXCEPTION_TEST_CASES)
- def test_exceptions_are_pickleable(self, error) -> None:
- assert error.__repr__() == pickle.loads(pickle.dumps(error)).__repr__()
diff --git a/openai/tests/test_file_cli.py b/openai/tests/test_file_cli.py
deleted file mode 100644
index 69ea29e2a0..0000000000
--- a/openai/tests/test_file_cli.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import json
-import subprocess
-import time
-from tempfile import NamedTemporaryFile
-
-STILL_PROCESSING = "File is still processing. Check back later."
-
-
-def test_file_cli() -> None:
- contents = json.dumps({"prompt": "1 + 3 =", "completion": "4"}) + "\n"
- with NamedTemporaryFile(suffix=".jsonl", mode="wb") as train_file:
- train_file.write(contents.encode("utf-8"))
- train_file.flush()
- create_output = subprocess.check_output(
- ["openai", "api", "files.create", "-f", train_file.name, "-p", "fine-tune"]
- )
- file_obj = json.loads(create_output)
- assert file_obj["bytes"] == len(contents)
- file_id: str = file_obj["id"]
- assert file_id.startswith("file-")
- start_time = time.time()
- while True:
- delete_result = subprocess.run(
- ["openai", "api", "files.delete", "-i", file_id],
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- encoding="utf-8",
- )
- if delete_result.returncode == 0:
- break
- elif STILL_PROCESSING in delete_result.stderr:
- time.sleep(0.5)
- if start_time + 60 < time.time():
- raise RuntimeError("timed out waiting for file to become available")
- continue
- else:
- raise RuntimeError(
- f"delete failed: stdout={delete_result.stdout} stderr={delete_result.stderr}"
- )
diff --git a/openai/tests/test_long_examples_validator.py b/openai/tests/test_long_examples_validator.py
deleted file mode 100644
index 949a7cbbae..0000000000
--- a/openai/tests/test_long_examples_validator.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import json
-import subprocess
-from tempfile import NamedTemporaryFile
-
-import pytest
-
-from openai.datalib.numpy_helper import HAS_NUMPY, NUMPY_INSTRUCTIONS
-from openai.datalib.pandas_helper import HAS_PANDAS, PANDAS_INSTRUCTIONS
-
-
-@pytest.mark.skipif(not HAS_PANDAS, reason=PANDAS_INSTRUCTIONS)
-@pytest.mark.skipif(not HAS_NUMPY, reason=NUMPY_INSTRUCTIONS)
-def test_long_examples_validator() -> None:
- """
- Ensures that long_examples_validator() handles previously applied recommendations,
- namely dropped duplicates, without resulting in a KeyError.
- """
-
- # data
- short_prompt = "a prompt "
- long_prompt = short_prompt * 500
-
- short_completion = "a completion "
- long_completion = short_completion * 500
-
- # the order of these matters
- unprepared_training_data = [
- {"prompt": long_prompt, "completion": long_completion}, # 1 of 2 duplicates
- {"prompt": short_prompt, "completion": short_completion},
- {"prompt": long_prompt, "completion": long_completion}, # 2 of 2 duplicates
- ]
-
- with NamedTemporaryFile(suffix=".jsonl", mode="w") as training_data:
- print(training_data.name)
- for prompt_completion_row in unprepared_training_data:
- training_data.write(json.dumps(prompt_completion_row) + "\n")
- training_data.flush()
-
- prepared_data_cmd_output = subprocess.run(
- [f"openai tools fine_tunes.prepare_data -f {training_data.name}"],
- stdout=subprocess.PIPE,
- text=True,
- input="y\ny\ny\ny\ny", # apply all recommendations, one at a time
- stderr=subprocess.PIPE,
- encoding="utf-8",
- shell=True,
- )
-
- # validate data was prepared successfully
- assert prepared_data_cmd_output.stderr == ""
- # validate get_long_indexes() applied during optional_fn() call in long_examples_validator()
- assert "indices of the long examples has changed" in prepared_data_cmd_output.stdout
-
- return prepared_data_cmd_output.stdout
diff --git a/openai/tests/test_url_composition.py b/openai/tests/test_url_composition.py
deleted file mode 100644
index 5034354a05..0000000000
--- a/openai/tests/test_url_composition.py
+++ /dev/null
@@ -1,209 +0,0 @@
-from sys import api_version
-
-import pytest
-
-from openai import Completion, Engine
-from openai.util import ApiType
-
-
-@pytest.mark.url
-def test_completions_url_composition_azure() -> None:
- url = Completion.class_url("test_engine", "azure", "2021-11-01-preview")
- assert (
- url
- == "/openai/deployments/test_engine/completions?api-version=2021-11-01-preview"
- )
-
-
-@pytest.mark.url
-def test_completions_url_composition_azure_ad() -> None:
- url = Completion.class_url("test_engine", "azure_ad", "2021-11-01-preview")
- assert (
- url
- == "/openai/deployments/test_engine/completions?api-version=2021-11-01-preview"
- )
-
-
-@pytest.mark.url
-def test_completions_url_composition_default() -> None:
- url = Completion.class_url("test_engine")
- assert url == "/engines/test_engine/completions"
-
-
-@pytest.mark.url
-def test_completions_url_composition_open_ai() -> None:
- url = Completion.class_url("test_engine", "open_ai")
- assert url == "/engines/test_engine/completions"
-
-
-@pytest.mark.url
-def test_completions_url_composition_invalid_type() -> None:
- with pytest.raises(Exception):
- url = Completion.class_url("test_engine", "invalid")
-
-
-@pytest.mark.url
-def test_completions_url_composition_instance_url_azure() -> None:
- completion = Completion(
- id="test_id",
- engine="test_engine",
- api_type="azure",
- api_version="2021-11-01-preview",
- )
- url = completion.instance_url()
- assert (
- url
- == "/openai/deployments/test_engine/completions/test_id?api-version=2021-11-01-preview"
- )
-
-
-@pytest.mark.url
-def test_completions_url_composition_instance_url_azure_ad() -> None:
- completion = Completion(
- id="test_id",
- engine="test_engine",
- api_type="azure_ad",
- api_version="2021-11-01-preview",
- )
- url = completion.instance_url()
- assert (
- url
- == "/openai/deployments/test_engine/completions/test_id?api-version=2021-11-01-preview"
- )
-
-
-@pytest.mark.url
-def test_completions_url_composition_instance_url_azure_no_version() -> None:
- completion = Completion(
- id="test_id", engine="test_engine", api_type="azure", api_version=None
- )
- with pytest.raises(Exception):
- completion.instance_url()
-
-
-@pytest.mark.url
-def test_completions_url_composition_instance_url_default() -> None:
- completion = Completion(id="test_id", engine="test_engine")
- url = completion.instance_url()
- assert url == "/engines/test_engine/completions/test_id"
-
-
-@pytest.mark.url
-def test_completions_url_composition_instance_url_open_ai() -> None:
- completion = Completion(
- id="test_id",
- engine="test_engine",
- api_type="open_ai",
- api_version="2021-11-01-preview",
- )
- url = completion.instance_url()
- assert url == "/engines/test_engine/completions/test_id"
-
-
-@pytest.mark.url
-def test_completions_url_composition_instance_url_invalid() -> None:
- completion = Completion(id="test_id", engine="test_engine", api_type="invalid")
- with pytest.raises(Exception):
- url = completion.instance_url()
-
-
-@pytest.mark.url
-def test_completions_url_composition_instance_url_timeout_azure() -> None:
- completion = Completion(
- id="test_id",
- engine="test_engine",
- api_type="azure",
- api_version="2021-11-01-preview",
- )
- completion["timeout"] = 12
- url = completion.instance_url()
- assert (
- url
- == "/openai/deployments/test_engine/completions/test_id?api-version=2021-11-01-preview&timeout=12"
- )
-
-
-@pytest.mark.url
-def test_completions_url_composition_instance_url_timeout_openai() -> None:
- completion = Completion(id="test_id", engine="test_engine", api_type="open_ai")
- completion["timeout"] = 12
- url = completion.instance_url()
- assert url == "/engines/test_engine/completions/test_id?timeout=12"
-
-
-@pytest.mark.url
-def test_engine_search_url_composition_azure() -> None:
- engine = Engine(id="test_id", api_type="azure", api_version="2021-11-01-preview")
- assert engine.api_type == "azure"
- assert engine.typed_api_type == ApiType.AZURE
- url = engine.instance_url("test_operation")
- assert (
- url
- == "/openai/deployments/test_id/test_operation?api-version=2021-11-01-preview"
- )
-
-
-@pytest.mark.url
-def test_engine_search_url_composition_azure_ad() -> None:
- engine = Engine(id="test_id", api_type="azure_ad", api_version="2021-11-01-preview")
- assert engine.api_type == "azure_ad"
- assert engine.typed_api_type == ApiType.AZURE_AD
- url = engine.instance_url("test_operation")
- assert (
- url
- == "/openai/deployments/test_id/test_operation?api-version=2021-11-01-preview"
- )
-
-
-@pytest.mark.url
-def test_engine_search_url_composition_azure_no_version() -> None:
- engine = Engine(id="test_id", api_type="azure", api_version=None)
- assert engine.api_type == "azure"
- assert engine.typed_api_type == ApiType.AZURE
- with pytest.raises(Exception):
- engine.instance_url("test_operation")
-
-
-@pytest.mark.url
-def test_engine_search_url_composition_azure_no_operation() -> None:
- engine = Engine(id="test_id", api_type="azure", api_version="2021-11-01-preview")
- assert engine.api_type == "azure"
- assert engine.typed_api_type == ApiType.AZURE
- assert (
- engine.instance_url()
- == "/openai/engines/test_id?api-version=2021-11-01-preview"
- )
-
-
-@pytest.mark.url
-def test_engine_search_url_composition_default() -> None:
- engine = Engine(id="test_id")
- assert engine.api_type == None
- assert engine.typed_api_type == ApiType.OPEN_AI
- url = engine.instance_url()
- assert url == "/engines/test_id"
-
-
-@pytest.mark.url
-def test_engine_search_url_composition_open_ai() -> None:
- engine = Engine(id="test_id", api_type="open_ai")
- assert engine.api_type == "open_ai"
- assert engine.typed_api_type == ApiType.OPEN_AI
- url = engine.instance_url()
- assert url == "/engines/test_id"
-
-
-@pytest.mark.url
-def test_engine_search_url_composition_invalid_type() -> None:
- engine = Engine(id="test_id", api_type="invalid")
- assert engine.api_type == "invalid"
- with pytest.raises(Exception):
- assert engine.typed_api_type == ApiType.OPEN_AI
-
-
-@pytest.mark.url
-def test_engine_search_url_composition_invalid_search() -> None:
- engine = Engine(id="test_id", api_type="invalid")
- assert engine.api_type == "invalid"
- with pytest.raises(Exception):
- engine.search()
diff --git a/openai/tests/test_util.py b/openai/tests/test_util.py
deleted file mode 100644
index 6220ccb7f4..0000000000
--- a/openai/tests/test_util.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import json
-from tempfile import NamedTemporaryFile
-
-import pytest
-
-import openai
-from openai import util
-
-
-@pytest.fixture(scope="function")
-def api_key_file():
- saved_path = openai.api_key_path
- try:
- with NamedTemporaryFile(prefix="openai-api-key", mode="wt") as tmp:
- openai.api_key_path = tmp.name
- yield tmp
- finally:
- openai.api_key_path = saved_path
-
-
-def test_openai_api_key_path(api_key_file) -> None:
- print("sk-foo", file=api_key_file)
- api_key_file.flush()
- assert util.default_api_key() == "sk-foo"
-
-
-def test_openai_api_key_path_with_malformed_key(api_key_file) -> None:
- print("malformed-api-key", file=api_key_file)
- api_key_file.flush()
- with pytest.raises(ValueError, match="Malformed API key"):
- util.default_api_key()
-
-
-def test_key_order_openai_object_rendering() -> None:
- sample_response = {
- "id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V",
- "object": "chat.completion",
- "created": 1685855844,
- "model": "gpt-3.5-turbo-0301",
- "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
- "choices": [
- {
- "message": {
- "role": "assistant",
- "content": "The 2020 World Series was played at Globe Life Field in Arlington, Texas. It was the first time that the World Series was played at a neutral site because of the COVID-19 pandemic.",
- },
- "finish_reason": "stop",
- "index": 0,
- }
- ],
- }
-
- oai_object = util.convert_to_openai_object(sample_response)
- # The `__str__` method was sorting while dumping to json
- assert list(json.loads(str(oai_object)).keys()) == list(sample_response.keys())
diff --git a/openai/upload_progress.py b/openai/upload_progress.py
deleted file mode 100644
index e4da62a4e0..0000000000
--- a/openai/upload_progress.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import io
-
-
-class CancelledError(Exception):
- def __init__(self, msg):
- self.msg = msg
- Exception.__init__(self, msg)
-
- def __str__(self):
- return self.msg
-
- __repr__ = __str__
-
-
-class BufferReader(io.BytesIO):
- def __init__(self, buf=b"", desc=None):
- self._len = len(buf)
- io.BytesIO.__init__(self, buf)
- self._progress = 0
- self._callback = progress(len(buf), desc=desc)
-
- def __len__(self):
- return self._len
-
- def read(self, n=-1):
- chunk = io.BytesIO.read(self, n)
- self._progress += len(chunk)
- if self._callback:
- try:
- self._callback(self._progress)
- except Exception as e: # catches exception from the callback
- raise CancelledError("The upload was cancelled: {}".format(e))
- return chunk
-
-
-def progress(total, desc):
- import tqdm # type: ignore
-
- meter = tqdm.tqdm(total=total, unit_scale=True, desc=desc)
-
- def incr(progress):
- meter.n = progress
- if progress == total:
- meter.close()
- else:
- meter.refresh()
-
- return incr
-
-
-def MB(i):
- return int(i // 1024**2)
diff --git a/openai/util.py b/openai/util.py
deleted file mode 100644
index 5501d5b67e..0000000000
--- a/openai/util.py
+++ /dev/null
@@ -1,188 +0,0 @@
-import logging
-import os
-import re
-import sys
-from enum import Enum
-from typing import Optional
-
-import openai
-
-OPENAI_LOG = os.environ.get("OPENAI_LOG")
-
-logger = logging.getLogger("openai")
-
-__all__ = [
- "log_info",
- "log_debug",
- "log_warn",
- "logfmt",
-]
-
-api_key_to_header = (
- lambda api, key: {"Authorization": f"Bearer {key}"}
- if api in (ApiType.OPEN_AI, ApiType.AZURE_AD)
- else {"api-key": f"{key}"}
-)
-
-
-class ApiType(Enum):
- AZURE = 1
- OPEN_AI = 2
- AZURE_AD = 3
-
- @staticmethod
- def from_str(label):
- if label.lower() == "azure":
- return ApiType.AZURE
- elif label.lower() in ("azure_ad", "azuread"):
- return ApiType.AZURE_AD
- elif label.lower() in ("open_ai", "openai"):
- return ApiType.OPEN_AI
- else:
- raise openai.error.InvalidAPIType(
- "The API type provided in invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'"
- )
-
-
-def _console_log_level():
- if openai.log in ["debug", "info"]:
- return openai.log
- elif OPENAI_LOG in ["debug", "info"]:
- return OPENAI_LOG
- else:
- return None
-
-
-def log_debug(message, **params):
- msg = logfmt(dict(message=message, **params))
- if _console_log_level() == "debug":
- print(msg, file=sys.stderr)
- logger.debug(msg)
-
-
-def log_info(message, **params):
- msg = logfmt(dict(message=message, **params))
- if _console_log_level() in ["debug", "info"]:
- print(msg, file=sys.stderr)
- logger.info(msg)
-
-
-def log_warn(message, **params):
- msg = logfmt(dict(message=message, **params))
- print(msg, file=sys.stderr)
- logger.warn(msg)
-
-
-def logfmt(props):
- def fmt(key, val):
- # Handle case where val is a bytes or bytesarray
- if hasattr(val, "decode"):
- val = val.decode("utf-8")
- # Check if val is already a string to avoid re-encoding into ascii.
- if not isinstance(val, str):
- val = str(val)
- if re.search(r"\s", val):
- val = repr(val)
- # key should already be a string
- if re.search(r"\s", key):
- key = repr(key)
- return "{key}={val}".format(key=key, val=val)
-
- return " ".join([fmt(key, val) for key, val in sorted(props.items())])
-
-
-def get_object_classes():
- # This is here to avoid a circular dependency
- from openai.object_classes import OBJECT_CLASSES
-
- return OBJECT_CLASSES
-
-
-def convert_to_openai_object(
- resp,
- api_key=None,
- api_version=None,
- organization=None,
- engine=None,
- plain_old_data=False,
-):
- # If we get a OpenAIResponse, we'll want to return a OpenAIObject.
-
- response_ms: Optional[int] = None
- if isinstance(resp, openai.openai_response.OpenAIResponse):
- organization = resp.organization
- response_ms = resp.response_ms
- resp = resp.data
-
- if plain_old_data:
- return resp
- elif isinstance(resp, list):
- return [
- convert_to_openai_object(
- i, api_key, api_version, organization, engine=engine
- )
- for i in resp
- ]
- elif isinstance(resp, dict) and not isinstance(
- resp, openai.openai_object.OpenAIObject
- ):
- resp = resp.copy()
- klass_name = resp.get("object")
- if isinstance(klass_name, str):
- klass = get_object_classes().get(
- klass_name, openai.openai_object.OpenAIObject
- )
- else:
- klass = openai.openai_object.OpenAIObject
-
- return klass.construct_from(
- resp,
- api_key=api_key,
- api_version=api_version,
- organization=organization,
- response_ms=response_ms,
- engine=engine,
- )
- else:
- return resp
-
-
-def convert_to_dict(obj):
- """Converts a OpenAIObject back to a regular dict.
-
- Nested OpenAIObjects are also converted back to regular dicts.
-
- :param obj: The OpenAIObject to convert.
-
- :returns: The OpenAIObject as a dict.
- """
- if isinstance(obj, list):
- return [convert_to_dict(i) for i in obj]
- # This works by virtue of the fact that OpenAIObjects _are_ dicts. The dict
- # comprehension returns a regular dict and recursively applies the
- # conversion to each value.
- elif isinstance(obj, dict):
- return {k: convert_to_dict(v) for k, v in obj.items()}
- else:
- return obj
-
-
-def merge_dicts(x, y):
- z = x.copy()
- z.update(y)
- return z
-
-
-def default_api_key() -> str:
- if openai.api_key_path:
- with open(openai.api_key_path, "rt") as k:
- api_key = k.read().strip()
- if not api_key.startswith("sk-"):
- raise ValueError(f"Malformed API key in {openai.api_key_path}.")
- return api_key
- elif openai.api_key is not None:
- return openai.api_key
- else:
- raise openai.error.AuthenticationError(
- "No API key provided. You can set your API key in code using 'openai.api_key = ', or you can set the environment variable OPENAI_API_KEY=). If your API key is stored in a file, you can point the openai module at it with 'openai.api_key_path = '. You can generate API keys in the OpenAI web interface. See https://platform.openai.com/account/api-keys for details."
- )
diff --git a/openai/version.py b/openai/version.py
deleted file mode 100644
index 51f3ce82ff..0000000000
--- a/openai/version.py
+++ /dev/null
@@ -1 +0,0 @@
-VERSION = "0.27.9"
diff --git a/openai/wandb_logger.py b/openai/wandb_logger.py
deleted file mode 100644
index d8e060c41b..0000000000
--- a/openai/wandb_logger.py
+++ /dev/null
@@ -1,314 +0,0 @@
-try:
- import wandb
-
- WANDB_AVAILABLE = True
-except:
- WANDB_AVAILABLE = False
-
-
-if WANDB_AVAILABLE:
- import datetime
- import io
- import json
- import re
- from pathlib import Path
-
- from openai import File, FineTune, FineTuningJob
- from openai.datalib.numpy_helper import numpy as np
- from openai.datalib.pandas_helper import assert_has_pandas, pandas as pd
-
-
-class WandbLogger:
- """
- Log fine-tunes to [Weights & Biases](https://wandb.me/openai-docs)
- """
-
- if not WANDB_AVAILABLE:
- print("Logging requires wandb to be installed. Run `pip install wandb`.")
- else:
- _wandb_api = None
- _logged_in = False
-
- @classmethod
- def sync(
- cls,
- id=None,
- n_fine_tunes=None,
- project="OpenAI-Fine-Tune",
- entity=None,
- force=False,
- legacy=False,
- **kwargs_wandb_init,
- ):
- """
- Sync fine-tunes to Weights & Biases.
- :param id: The id of the fine-tune (optional)
- :param n_fine_tunes: Number of most recent fine-tunes to log when an id is not provided. By default, every fine-tune is synced.
- :param project: Name of the project where you're sending runs. By default, it is "GPT-3".
- :param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.
- :param force: Forces logging and overwrite existing wandb run of the same fine-tune.
- """
-
- assert_has_pandas()
-
- if not WANDB_AVAILABLE:
- return
-
- if id:
- print("Retrieving fine-tune job...")
- if legacy:
- fine_tune = FineTune.retrieve(id=id)
- else:
- fine_tune = FineTuningJob.retrieve(id=id)
- fine_tune.pop("events", None)
- fine_tunes = [fine_tune]
- else:
- # get list of fine_tune to log
- if legacy:
- fine_tunes = FineTune.list()
- else:
- fine_tunes = list(FineTuningJob.auto_paging_iter())
- if not fine_tunes or fine_tunes.get("data") is None:
- print("No fine-tune has been retrieved")
- return
- fine_tunes = fine_tunes["data"][
- -n_fine_tunes if n_fine_tunes is not None else None :
- ]
-
- # log starting from oldest fine_tune
- show_individual_warnings = (
- False if id is None and n_fine_tunes is None else True
- )
- fine_tune_logged = [
- cls._log_fine_tune(
- fine_tune,
- project,
- entity,
- force,
- legacy,
- show_individual_warnings,
- **kwargs_wandb_init,
- )
- for fine_tune in fine_tunes
- ]
-
- if not show_individual_warnings and not any(fine_tune_logged):
- print("No new successful fine-tunes were found")
-
- return "🎉 wandb sync completed successfully"
-
- @classmethod
- def _log_fine_tune(
- cls,
- fine_tune,
- project,
- entity,
- force,
- legacy,
- show_individual_warnings,
- **kwargs_wandb_init,
- ):
- fine_tune_id = fine_tune.get("id")
- status = fine_tune.get("status")
-
- # check run completed successfully
- if status != "succeeded":
- if show_individual_warnings:
- print(
- f'Fine-tune {fine_tune_id} has the status "{status}" and will not be logged'
- )
- return
-
- # check results are present
- try:
- if legacy:
- results_id = fine_tune["result_files"][0]["id"]
- else:
- results_id = fine_tune["result_files"][0]
- results = File.download(id=results_id).decode("utf-8")
- except:
- if show_individual_warnings:
- print(f"Fine-tune {fine_tune_id} has no results and will not be logged")
- return
-
- # check run has not been logged already
- run_path = f"{project}/{fine_tune_id}"
- if entity is not None:
- run_path = f"{entity}/{run_path}"
- wandb_run = cls._get_wandb_run(run_path)
- if wandb_run:
- wandb_status = wandb_run.summary.get("status")
- if show_individual_warnings:
- if wandb_status == "succeeded":
- print(
- f"Fine-tune {fine_tune_id} has already been logged successfully at {wandb_run.url}"
- )
- if not force:
- print(
- 'Use "--force" in the CLI or "force=True" in python if you want to overwrite previous run'
- )
- else:
- print(
- f"A run for fine-tune {fine_tune_id} was previously created but didn't end successfully"
- )
- if wandb_status != "succeeded" or force:
- print(
- f"A new wandb run will be created for fine-tune {fine_tune_id} and previous run will be overwritten"
- )
- if wandb_status == "succeeded" and not force:
- return
-
- # start a wandb run
- wandb.init(
- job_type="fine-tune",
- config=cls._get_config(fine_tune),
- project=project,
- entity=entity,
- name=fine_tune_id,
- id=fine_tune_id,
- **kwargs_wandb_init,
- )
-
- # log results
- df_results = pd.read_csv(io.StringIO(results))
- for _, row in df_results.iterrows():
- metrics = {k: v for k, v in row.items() if not np.isnan(v)}
- step = metrics.pop("step")
- if step is not None:
- step = int(step)
- wandb.log(metrics, step=step)
- fine_tuned_model = fine_tune.get("fine_tuned_model")
- if fine_tuned_model is not None:
- wandb.summary["fine_tuned_model"] = fine_tuned_model
-
- # training/validation files and fine-tune details
- cls._log_artifacts(fine_tune, project, entity)
-
- # mark run as complete
- wandb.summary["status"] = "succeeded"
-
- wandb.finish()
- return True
-
- @classmethod
- def _ensure_logged_in(cls):
- if not cls._logged_in:
- if wandb.login():
- cls._logged_in = True
- else:
- raise Exception("You need to log in to wandb")
-
- @classmethod
- def _get_wandb_run(cls, run_path):
- cls._ensure_logged_in()
- try:
- if cls._wandb_api is None:
- cls._wandb_api = wandb.Api()
- return cls._wandb_api.run(run_path)
- except Exception:
- return None
-
- @classmethod
- def _get_wandb_artifact(cls, artifact_path):
- cls._ensure_logged_in()
- try:
- if cls._wandb_api is None:
- cls._wandb_api = wandb.Api()
- return cls._wandb_api.artifact(artifact_path)
- except Exception:
- return None
-
- @classmethod
- def _get_config(cls, fine_tune):
- config = dict(fine_tune)
- for key in ("training_files", "validation_files", "result_files"):
- if config.get(key) and len(config[key]):
- config[key] = config[key][0]
- if config.get("created_at"):
- config["created_at"] = datetime.datetime.fromtimestamp(config["created_at"])
- return config
-
- @classmethod
- def _log_artifacts(cls, fine_tune, project, entity):
- # training/validation files
- training_file = (
- fine_tune["training_files"][0]
- if fine_tune.get("training_files") and len(fine_tune["training_files"])
- else None
- )
- validation_file = (
- fine_tune["validation_files"][0]
- if fine_tune.get("validation_files") and len(fine_tune["validation_files"])
- else None
- )
- for file, prefix, artifact_type in (
- (training_file, "train", "training_files"),
- (validation_file, "valid", "validation_files"),
- ):
- if file is not None:
- cls._log_artifact_inputs(file, prefix, artifact_type, project, entity)
-
- # fine-tune details
- fine_tune_id = fine_tune.get("id")
- artifact = wandb.Artifact(
- "fine_tune_details",
- type="fine_tune_details",
- metadata=fine_tune,
- )
- with artifact.new_file(
- "fine_tune_details.json", mode="w", encoding="utf-8"
- ) as f:
- json.dump(fine_tune, f, indent=2)
- wandb.run.log_artifact(
- artifact,
- aliases=["latest", fine_tune_id],
- )
-
- @classmethod
- def _log_artifact_inputs(cls, file, prefix, artifact_type, project, entity):
- file_id = file["id"]
- filename = Path(file["filename"]).name
- stem = Path(file["filename"]).stem
-
- # get input artifact
- artifact_name = f"{prefix}-{filename}"
- # sanitize name to valid wandb artifact name
- artifact_name = re.sub(r"[^a-zA-Z0-9_\-.]", "_", artifact_name)
- artifact_alias = file_id
- artifact_path = f"{project}/{artifact_name}:{artifact_alias}"
- if entity is not None:
- artifact_path = f"{entity}/{artifact_path}"
- artifact = cls._get_wandb_artifact(artifact_path)
-
- # create artifact if file not already logged previously
- if artifact is None:
- # get file content
- try:
- file_content = File.download(id=file_id).decode("utf-8")
- except:
- print(
- f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
- )
- return
- artifact = wandb.Artifact(artifact_name, type=artifact_type, metadata=file)
- with artifact.new_file(filename, mode="w", encoding="utf-8") as f:
- f.write(file_content)
-
- # create a Table
- try:
- table, n_items = cls._make_table(file_content)
- artifact.add(table, stem)
- wandb.config.update({f"n_{prefix}": n_items})
- artifact.metadata["items"] = n_items
- except:
- print(f"File {file_id} could not be read as a valid JSON file")
- else:
- # log number of items
- wandb.config.update({f"n_{prefix}": artifact.metadata.get("items")})
-
- wandb.run.use_artifact(artifact, aliases=["latest", artifact_alias])
-
- @classmethod
- def _make_table(cls, file_content):
- df = pd.read_json(io.StringIO(file_content), orient="records", lines=True)
- return wandb.Table(dataframe=df), len(df)
diff --git a/public/Makefile b/public/Makefile
deleted file mode 100644
index 2862fd4261..0000000000
--- a/public/Makefile
+++ /dev/null
@@ -1,7 +0,0 @@
-.PHONY: build upload
-
-build:
- OPENAI_UPLOAD=y python setup.py sdist
-
-upload:
- OPENAI_UPLOAD=y twine upload dist/*
diff --git a/public/setup.py b/public/setup.py
deleted file mode 100644
index 0198a53361..0000000000
--- a/public/setup.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import os
-
-from setuptools import setup
-
-if os.getenv("OPENAI_UPLOAD") != "y":
- raise RuntimeError(
- "This package is a placeholder package on the public PyPI instance, and is not the correct version to install. If you are having trouble figuring out the correct package to install, please contact us."
- )
-
-setup(name="openai", description="Placeholder package", version="0.0.1")
diff --git a/pyproject.toml b/pyproject.toml
index 6116c7fa2f..7f6e3123d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,13 +1,160 @@
+[project]
+name = "openai"
+version = "1.0.0"
+description = "Client library for the openai API"
+readme = "README.md"
+license = "Apache-2.0"
+authors = [
+{ name = "OpenAI", email = "support@openai.com" },
+]
+dependencies = [
+ "httpx>=0.23.0, <1",
+ "pydantic>=1.9.0, <3",
+ "typing-extensions>=4.5, <5",
+ "anyio>=3.5.0, <4",
+ "distro>=1.7.0, <2",
+ "tqdm > 4"
+]
+requires-python = ">= 3.7.1"
+classifiers = [
+ "Typing :: Typed",
+ "Intended Audience :: Developers",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Operating System :: OS Independent",
+ "Operating System :: POSIX",
+ "Operating System :: MacOS",
+ "Operating System :: POSIX :: Linux",
+ "Operating System :: Microsoft :: Windows",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+]
+
+[project.optional-dependencies]
+datalib = ["numpy >= 1", "pandas >= 1.2.3", "pandas-stubs >= 1.1.0.11"]
+
+[project.urls]
+Homepage = "https://github.com/openai/openai-python"
+Repository = "https://github.com/openai/openai-python"
+
+[project.scripts]
+openai = "openai.cli:main"
+
+[tool.rye]
+managed = true
+dev-dependencies = [
+ "pyright==1.1.332",
+ "mypy==1.6.1",
+ "black==23.3.0",
+ "respx==0.19.2",
+ "pytest==7.1.1",
+ "pytest-asyncio==0.21.1",
+ "ruff==0.0.282",
+ "isort==5.10.1",
+ "time-machine==2.9.0",
+ "nox==2023.4.22",
+ "dirty-equals>=0.6.0",
+ "azure-identity >=1.14.1",
+ "types-tqdm > 4"
+]
+
+[tool.rye.scripts]
+format = { chain = [
+ "format:black",
+ "format:docs",
+ "format:ruff",
+ "format:isort",
+]}
+"format:black" = "black ."
+"format:docs" = "python bin/blacken-docs.py README.md api.md"
+"format:ruff" = "ruff --fix ."
+"format:isort" = "isort ."
+
+"check:ruff" = "ruff ."
+
+typecheck = { chain = [
+ "typecheck:pyright",
+ "typecheck:mypy"
+]}
+"typecheck:pyright" = "pyright"
+"typecheck:verify-types" = "pyright --verifytypes openai --ignoreexternal"
+"typecheck:mypy" = "mypy --enable-incomplete-feature=Unpack ."
+
[build-system]
-requires = ["setuptools"]
-build-backend = "setuptools.build_meta"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build]
+include = [
+ "src/*"
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/openai"]
[tool.black]
-target-version = ['py36']
-exclude = '.*\.ipynb'
+line-length = 120
+target-version = ["py37"]
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+addopts = "--tb=short"
+xfail_strict = true
+asyncio_mode = "auto"
+filterwarnings = [
+ "error"
+]
+
+[tool.pyright]
+# this enables practically every flag given by pyright.
+# there are a couple of flags that are still disabled by
+# default in strict mode as they are experimental and niche.
+typeCheckingMode = "strict"
+pythonVersion = "3.7"
+
+exclude = [
+ "_dev",
+ ".venv",
+ ".nox",
+]
+
+reportImplicitOverride = true
+
+reportImportCycles = false
+reportPrivateUsage = false
[tool.isort]
-py_version = 36
-include_trailing_comma = "true"
-line_length = 88
-multi_line_output = 3
+profile = "black"
+length_sort = true
+extra_standard_library = ["typing_extensions"]
+
+[tool.ruff]
+line-length = 120
+format = "grouped"
+target-version = "py37"
+select = [
+ # remove unused imports
+ "F401",
+ # bare except statements
+ "E722",
+ # unused arguments
+ "ARG",
+ # print statements
+ "T201",
+ "T203",
+]
+unfixable = [
+ # disable auto fix for print statements
+ "T201",
+ "T203",
+]
+ignore-init-module-imports = true
+
+
+[tool.ruff.per-file-ignores]
+"bin/**.py" = ["T201", "T203"]
+"tests/**.py" = ["T201", "T203"]
+"examples/**.py" = ["T201", "T203"]
diff --git a/pytest.ini b/pytest.ini
deleted file mode 100644
index 5b78d87c16..0000000000
--- a/pytest.ini
+++ /dev/null
@@ -1,4 +0,0 @@
-[pytest]
-markers =
- url: mark a test as part of the url composition tests.
- requestor: mark test as part of the api_requestor tests.
diff --git a/requirements-dev.lock b/requirements-dev.lock
new file mode 100644
index 0000000000..0747babdc5
--- /dev/null
+++ b/requirements-dev.lock
@@ -0,0 +1,74 @@
+# generated by rye
+# use `rye lock` or `rye sync` to update this lockfile
+#
+# last locked with the following flags:
+# pre: false
+# features: []
+# all-features: true
+
+-e file:.
+annotated-types==0.6.0
+anyio==3.7.1
+argcomplete==3.1.2
+attrs==23.1.0
+azure-core==1.29.5
+azure-identity==1.15.0
+black==23.3.0
+certifi==2023.7.22
+cffi==1.16.0
+charset-normalizer==3.3.1
+click==8.1.7
+colorlog==6.7.0
+cryptography==41.0.5
+dirty-equals==0.6.0
+distlib==0.3.7
+distro==1.8.0
+exceptiongroup==1.1.3
+filelock==3.12.4
+h11==0.12.0
+httpcore==0.15.0
+httpx==0.23.0
+idna==3.4
+iniconfig==2.0.0
+isort==5.10.1
+msal==1.24.1
+msal-extensions==1.0.0
+mypy==1.6.1
+mypy-extensions==1.0.0
+nodeenv==1.8.0
+nox==2023.4.22
+numpy==1.26.1
+packaging==23.2
+pandas==2.1.1
+pandas-stubs==2.1.1.230928
+pathspec==0.11.2
+platformdirs==3.11.0
+pluggy==1.3.0
+portalocker==2.8.2
+py==1.11.0
+pycparser==2.21
+pydantic==2.4.2
+pydantic-core==2.10.1
+pyjwt==2.8.0
+pyright==1.1.332
+pytest==7.1.1
+pytest-asyncio==0.21.1
+python-dateutil==2.8.2
+pytz==2023.3.post1
+requests==2.31.0
+respx==0.19.2
+rfc3986==1.5.0
+ruff==0.0.282
+six==1.16.0
+sniffio==1.3.0
+time-machine==2.9.0
+tomli==2.0.1
+tqdm==4.66.1
+types-pytz==2023.3.1.1
+types-tqdm==4.66.0.2
+typing-extensions==4.8.0
+tzdata==2023.3
+urllib3==2.0.7
+virtualenv==20.24.5
+# The following packages are considered to be unsafe in a requirements file:
+setuptools==68.2.2
diff --git a/requirements.lock b/requirements.lock
new file mode 100644
index 0000000000..be9606fc3c
--- /dev/null
+++ b/requirements.lock
@@ -0,0 +1,32 @@
+# generated by rye
+# use `rye lock` or `rye sync` to update this lockfile
+#
+# last locked with the following flags:
+# pre: false
+# features: []
+# all-features: true
+
+-e file:.
+annotated-types==0.6.0
+anyio==3.7.1
+certifi==2023.7.22
+distro==1.8.0
+exceptiongroup==1.1.3
+h11==0.12.0
+httpcore==0.15.0
+httpx==0.23.0
+idna==3.4
+numpy==1.26.1
+pandas==2.1.1
+pandas-stubs==2.1.1.230928
+pydantic==2.4.2
+pydantic-core==2.10.1
+python-dateutil==2.8.2
+pytz==2023.3.post1
+rfc3986==1.5.0
+six==1.16.0
+sniffio==1.3.0
+tqdm==4.66.1
+types-pytz==2023.3.1.1
+typing-extensions==4.8.0
+tzdata==2023.3
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 3729647b8d..0000000000
--- a/setup.cfg
+++ /dev/null
@@ -1,65 +0,0 @@
-[metadata]
-name = openai
-version = attr: openai.version.VERSION
-description = Python client library for the OpenAI API
-long_description = file: README.md
-long_description_content_type = text/markdown
-author = OpenAI
-author_email = support@openai.com
-url = https://github.com/openai/openai-python
-license_files = LICENSE
-classifiers =
- Programming Language :: Python :: 3
- License :: OSI Approved :: MIT License
- Operating System :: OS Independent
-
-[options]
-packages = find:
-python_requires = >=3.7.1
-zip_safe = True
-include_package_data = True
-install_requires =
- requests >= 2.20 # to get the patch for CVE-2018-18074
- tqdm # Needed for progress bars
- typing_extensions; python_version<"3.8" # Needed for type hints for mypy
- aiohttp # Needed for async support
-
-[options.extras_require]
-dev =
- black ~= 21.6b0
- pytest == 6.*
- pytest-asyncio
- pytest-mock
-datalib =
- numpy
- pandas >= 1.2.3 # Needed for CLI fine-tuning data preparation tool
- pandas-stubs >= 1.1.0.11 # Needed for type hints for mypy
- openpyxl >= 3.0.7 # Needed for CLI fine-tuning data preparation tool xlsx format
-wandb =
- wandb
- numpy
- pandas >= 1.2.3 # Needed for CLI fine-tuning data preparation tool
- pandas-stubs >= 1.1.0.11 # Needed for type hints for mypy
- openpyxl >= 3.0.7 # Needed for CLI fine-tuning data preparation tool xlsx format
-embeddings =
- scikit-learn >= 1.0.2 # Needed for embedding utils, versions >= 1.1 require python 3.8
- tenacity >= 8.0.1
- matplotlib
- plotly
- numpy
- scipy
- pandas >= 1.2.3 # Needed for CLI fine-tuning data preparation tool
- pandas-stubs >= 1.1.0.11 # Needed for type hints for mypy
- openpyxl >= 3.0.7 # Needed for CLI fine-tuning data preparation tool xlsx format
-
-[options.entry_points]
-console_scripts =
- openai = openai._openai_scripts:main
-
-[options.package_data]
- openai = py.typed
-
-[options.packages.find]
-exclude =
- tests
- tests.*
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 606849326a..0000000000
--- a/setup.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from setuptools import setup
-
-setup()
diff --git a/src/openai/__init__.py b/src/openai/__init__.py
new file mode 100644
index 0000000000..f033d8f26c
--- /dev/null
+++ b/src/openai/__init__.py
@@ -0,0 +1,342 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os as _os
+from typing_extensions import override
+
+from . import types
+from ._types import NoneType, Transport, ProxiesTypes
+from ._utils import file_from_path
+from ._client import (
+ Client,
+ OpenAI,
+ Stream,
+ Timeout,
+ Transport,
+ AsyncClient,
+ AsyncOpenAI,
+ AsyncStream,
+ RequestOptions,
+)
+from ._version import __title__, __version__
+from ._exceptions import (
+ APIError,
+ OpenAIError,
+ ConflictError,
+ NotFoundError,
+ APIStatusError,
+ RateLimitError,
+ APITimeoutError,
+ BadRequestError,
+ APIConnectionError,
+ AuthenticationError,
+ InternalServerError,
+ PermissionDeniedError,
+ UnprocessableEntityError,
+ APIResponseValidationError,
+)
+from ._utils._logs import setup_logging as _setup_logging
+
+__all__ = [
+ "types",
+ "__version__",
+ "__title__",
+ "NoneType",
+ "Transport",
+ "ProxiesTypes",
+ "OpenAIError",
+ "APIError",
+ "APIStatusError",
+ "APITimeoutError",
+ "APIConnectionError",
+ "APIResponseValidationError",
+ "BadRequestError",
+ "AuthenticationError",
+ "PermissionDeniedError",
+ "NotFoundError",
+ "ConflictError",
+ "UnprocessableEntityError",
+ "RateLimitError",
+ "InternalServerError",
+ "Timeout",
+ "RequestOptions",
+ "Client",
+ "AsyncClient",
+ "Stream",
+ "AsyncStream",
+ "OpenAI",
+ "AsyncOpenAI",
+ "file_from_path",
+]
+
+from .lib import azure as _azure
+from .version import VERSION as VERSION
+from .lib.azure import AzureOpenAI as AzureOpenAI
+from .lib.azure import AsyncAzureOpenAI as AsyncAzureOpenAI
+
+_setup_logging()
+
+# Update the __module__ attribute for exported symbols so that
+# error messages point to this module instead of the module
+# it was originally defined in, e.g.
+# openai._exceptions.NotFoundError -> openai.NotFoundError
+__locals = locals()
+for __name in __all__:
+ if not __name.startswith("__"):
+ try:
+ setattr(__locals[__name], "__module__", "openai")
+ except (TypeError, AttributeError):
+ # Some of our exported symbols are builtins which we can't set attributes for.
+ pass
+
+# ------ Module level client ------
+import typing as _t
+import typing_extensions as _te
+
+import httpx as _httpx
+
+from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
+
+api_key: str | None = None
+
+organization: str | None = None
+
+base_url: str | _httpx.URL | None = None
+
+timeout: float | Timeout | None = DEFAULT_TIMEOUT
+
+max_retries: int = DEFAULT_MAX_RETRIES
+
+default_headers: _t.Mapping[str, str] | None = None
+
+default_query: _t.Mapping[str, object] | None = None
+
+http_client: _httpx.Client | None = None
+
+_ApiType = _te.Literal["openai", "azure"]
+
+api_type: _ApiType | None = _t.cast(_ApiType, _os.environ.get("OPENAI_API_TYPE"))
+
+api_version: str | None = _os.environ.get("OPENAI_API_VERSION")
+
+azure_endpoint: str | None = _os.environ.get("AZURE_OPENAI_ENDPOINT")
+
+azure_ad_token: str | None = _os.environ.get("AZURE_OPENAI_AD_TOKEN")
+
+azure_ad_token_provider: _azure.AzureADTokenProvider | None = None
+
+
+class _ModuleClient(OpenAI):
+ # Note: we have to use type: ignores here as overriding class members
+ # with properties is technically unsafe but it is fine for our use case
+
+ @property # type: ignore
+ @override
+ def api_key(self) -> str | None:
+ return api_key
+
+ @api_key.setter # type: ignore
+ def api_key(self, value: str | None) -> None: # type: ignore
+ global api_key
+
+ api_key = value
+
+ @property # type: ignore
+ @override
+ def organization(self) -> str | None:
+ return organization
+
+ @organization.setter # type: ignore
+ def organization(self, value: str | None) -> None: # type: ignore
+ global organization
+
+ organization = value
+
+ @property
+ @override
+ def base_url(self) -> _httpx.URL:
+ if base_url is not None:
+ return _httpx.URL(base_url)
+
+ return super().base_url
+
+ @base_url.setter
+ def base_url(self, url: _httpx.URL | str) -> None:
+ super().base_url = url # type: ignore[misc]
+
+ @property # type: ignore
+ @override
+ def timeout(self) -> float | Timeout | None:
+ return timeout
+
+ @timeout.setter # type: ignore
+ def timeout(self, value: float | Timeout | None) -> None: # type: ignore
+ global timeout
+
+ timeout = value
+
+ @property # type: ignore
+ @override
+ def max_retries(self) -> int:
+ return max_retries
+
+ @max_retries.setter # type: ignore
+ def max_retries(self, value: int) -> None: # type: ignore
+ global max_retries
+
+ max_retries = value
+
+ @property # type: ignore
+ @override
+ def _custom_headers(self) -> _t.Mapping[str, str] | None:
+ return default_headers
+
+ @_custom_headers.setter # type: ignore
+ def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None: # type: ignore
+ global default_headers
+
+ default_headers = value
+
+ @property # type: ignore
+ @override
+ def _custom_query(self) -> _t.Mapping[str, object] | None:
+ return default_query
+
+ @_custom_query.setter # type: ignore
+ def _custom_query(self, value: _t.Mapping[str, object] | None) -> None: # type: ignore
+ global default_query
+
+ default_query = value
+
+ @property # type: ignore
+ @override
+ def _client(self) -> _httpx.Client:
+ return http_client or super()._client
+
+ @_client.setter # type: ignore
+ def _client(self, value: _httpx.Client) -> None: # type: ignore
+ global http_client
+
+ http_client = value
+
+ @override
+ def __del__(self) -> None:
+ try:
+ super().__del__()
+ except Exception:
+ pass
+
+
+class _AzureModuleClient(_ModuleClient, AzureOpenAI): # type: ignore
+ ...
+
+
+class _AmbiguousModuleClientUsageError(OpenAIError):
+ def __init__(self) -> None:
+ super().__init__(
+ "Ambiguous use of module client; please set `openai.api_type` or the `OPENAI_API_TYPE` environment variable to `openai` or `azure`"
+ )
+
+
+def _has_openai_credentials() -> bool:
+ return _os.environ.get("OPENAI_API_KEY") is not None
+
+
+def _has_azure_credentials() -> bool:
+ return azure_endpoint is not None or _os.environ.get("AZURE_OPENAI_API_KEY") is not None
+
+
+def _has_azure_ad_credentials() -> bool:
+ return (
+ _os.environ.get("AZURE_OPENAI_AD_TOKEN") is not None
+ or azure_ad_token is not None
+ or azure_ad_token_provider is not None
+ )
+
+
+_client: OpenAI | None = None
+
+
+def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]
+ global _client
+
+ if _client is None:
+ global api_type, azure_endpoint, azure_ad_token, api_version
+
+ if azure_endpoint is None:
+ azure_endpoint = _os.environ.get("AZURE_OPENAI_ENDPOINT")
+
+ if azure_ad_token is None:
+ azure_ad_token = _os.environ.get("AZURE_OPENAI_AD_TOKEN")
+
+ if api_version is None:
+ api_version = _os.environ.get("OPENAI_API_VERSION")
+
+ if api_type is None:
+ has_openai = _has_openai_credentials()
+ has_azure = _has_azure_credentials()
+ has_azure_ad = _has_azure_ad_credentials()
+
+ if has_openai and (has_azure or has_azure_ad):
+ raise _AmbiguousModuleClientUsageError()
+
+ if (azure_ad_token is not None or azure_ad_token_provider is not None) and _os.environ.get(
+ "AZURE_OPENAI_API_KEY"
+ ) is not None:
+ raise _AmbiguousModuleClientUsageError()
+
+ if has_azure or has_azure_ad:
+ api_type = "azure"
+ else:
+ api_type = "openai"
+
+ if api_type == "azure":
+ _client = _AzureModuleClient( # type: ignore
+ api_version=api_version,
+ azure_endpoint=azure_endpoint,
+ api_key=api_key,
+ azure_ad_token=azure_ad_token,
+ azure_ad_token_provider=azure_ad_token_provider,
+ organization=organization,
+ base_url=base_url,
+ timeout=timeout,
+ max_retries=max_retries,
+ default_headers=default_headers,
+ default_query=default_query,
+ http_client=http_client,
+ )
+ return _client
+
+ _client = _ModuleClient(
+ api_key=api_key,
+ organization=organization,
+ base_url=base_url,
+ timeout=timeout,
+ max_retries=max_retries,
+ default_headers=default_headers,
+ default_query=default_query,
+ http_client=http_client,
+ )
+ return _client
+
+ return _client
+
+
+def _reset_client() -> None: # type: ignore[reportUnusedFunction]
+ global _client
+
+ _client = None
+
+
+from ._module_client import chat as chat
+from ._module_client import audio as audio
+from ._module_client import edits as edits
+from ._module_client import files as files
+from ._module_client import images as images
+from ._module_client import models as models
+from ._module_client import embeddings as embeddings
+from ._module_client import fine_tunes as fine_tunes
+from ._module_client import completions as completions
+from ._module_client import fine_tuning as fine_tuning
+from ._module_client import moderations as moderations
diff --git a/src/openai/__main__.py b/src/openai/__main__.py
new file mode 100644
index 0000000000..4e28416e10
--- /dev/null
+++ b/src/openai/__main__.py
@@ -0,0 +1,3 @@
+from .cli import main
+
+main()
diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py
new file mode 100644
index 0000000000..22f90050d7
--- /dev/null
+++ b/src/openai/_base_client.py
@@ -0,0 +1,1768 @@
+from __future__ import annotations
+
+import os
+import json
+import time
+import uuid
+import email
+import inspect
+import logging
+import platform
+import warnings
+import email.utils
+from types import TracebackType
+from random import random
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Type,
+ Union,
+ Generic,
+ Mapping,
+ TypeVar,
+ Iterable,
+ Iterator,
+ Optional,
+ Generator,
+ AsyncIterator,
+ cast,
+ overload,
+)
+from functools import lru_cache
+from typing_extensions import Literal, override
+
+import anyio
+import httpx
+import distro
+import pydantic
+from httpx import URL, Limits
+from pydantic import PrivateAttr
+
+from . import _exceptions
+from ._qs import Querystring
+from ._files import to_httpx_files, async_to_httpx_files
+from ._types import (
+ NOT_GIVEN,
+ Body,
+ Omit,
+ Query,
+ ModelT,
+ Headers,
+ Timeout,
+ NotGiven,
+ ResponseT,
+ Transport,
+ AnyMapping,
+ PostParser,
+ ProxiesTypes,
+ RequestFiles,
+ AsyncTransport,
+ RequestOptions,
+ UnknownResponse,
+ ModelBuilderProtocol,
+ BinaryResponseContent,
+)
+from ._utils import is_dict, is_given, is_mapping
+from ._compat import model_copy, model_dump
+from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
+from ._response import APIResponse
+from ._constants import (
+ DEFAULT_LIMITS,
+ DEFAULT_TIMEOUT,
+ DEFAULT_MAX_RETRIES,
+ RAW_RESPONSE_HEADER,
+)
+from ._streaming import Stream, AsyncStream
+from ._exceptions import APIStatusError, APITimeoutError, APIConnectionError
+
+log: logging.Logger = logging.getLogger(__name__)
+
+# TODO: make base page type vars covariant
+SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
+AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]")
+
+
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+
+_StreamT = TypeVar("_StreamT", bound=Stream[Any])
+_AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any])
+
+if TYPE_CHECKING:
+ from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
+else:
+ try:
+ from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
+ except ImportError:
+ # taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366
+ HTTPX_DEFAULT_TIMEOUT = Timeout(5.0)
+
+
+class PageInfo:
+ """Stores the necesary information to build the request to retrieve the next page.
+
+ Either `url` or `params` must be set.
+ """
+
+ url: URL | NotGiven
+ params: Query | NotGiven
+
+ @overload
+ def __init__(
+ self,
+ *,
+ url: URL,
+ ) -> None:
+ ...
+
+ @overload
+ def __init__(
+ self,
+ *,
+ params: Query,
+ ) -> None:
+ ...
+
+ def __init__(
+ self,
+ *,
+ url: URL | NotGiven = NOT_GIVEN,
+ params: Query | NotGiven = NOT_GIVEN,
+ ) -> None:
+ self.url = url
+ self.params = params
+
+
+class BasePage(GenericModel, Generic[ModelT]):
+ """
+ Defines the core interface for pagination.
+
+ Type Args:
+ ModelT: The pydantic model that represents an item in the response.
+
+ Methods:
+ has_next_page(): Check if there is another page available
+ next_page_info(): Get the necessary information to make a request for the next page
+ """
+
+ _options: FinalRequestOptions = PrivateAttr()
+ _model: Type[ModelT] = PrivateAttr()
+
+ def has_next_page(self) -> bool:
+ items = self._get_page_items()
+ if not items:
+ return False
+ return self.next_page_info() is not None
+
+ def next_page_info(self) -> Optional[PageInfo]:
+ ...
+
+ def _get_page_items(self) -> Iterable[ModelT]: # type: ignore[empty-body]
+ ...
+
+ def _params_from_url(self, url: URL) -> httpx.QueryParams:
+ # TODO: do we have to preprocess params here?
+ return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params)
+
+ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
+ options = model_copy(self._options)
+ options._strip_raw_response_header()
+
+ if not isinstance(info.params, NotGiven):
+ options.params = {**options.params, **info.params}
+ return options
+
+ if not isinstance(info.url, NotGiven):
+ params = self._params_from_url(info.url)
+ url = info.url.copy_with(params=params)
+ options.params = dict(url.params)
+ options.url = str(url)
+ return options
+
+ raise ValueError("Unexpected PageInfo state")
+
+
+class BaseSyncPage(BasePage[ModelT], Generic[ModelT]):
+ _client: SyncAPIClient = pydantic.PrivateAttr()
+
+ def _set_private_attributes(
+ self,
+ client: SyncAPIClient,
+ model: Type[ModelT],
+ options: FinalRequestOptions,
+ ) -> None:
+ self._model = model
+ self._client = client
+ self._options = options
+
+ # Pydantic uses a custom `__iter__` method to support casting BaseModels
+ # to dictionaries. e.g. dict(model).
+ # As we want to support `for item in page`, this is inherently incompatible
+ # with the default pydantic behaviour. It is not possible to support both
+ # use cases at once. Fortunately, this is not a big deal as all other pydantic
+ # methods should continue to work as expected as there is an alternative method
+ # to cast a model to a dictionary, model.dict(), which is used internally
+ # by pydantic.
+ def __iter__(self) -> Iterator[ModelT]: # type: ignore
+ for page in self.iter_pages():
+ for item in page._get_page_items():
+ yield item
+
+ def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]:
+ page = self
+ while True:
+ yield page
+ if page.has_next_page():
+ page = page.get_next_page()
+ else:
+ return
+
+ def get_next_page(self: SyncPageT) -> SyncPageT:
+ info = self.next_page_info()
+ if not info:
+ raise RuntimeError(
+ "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`."
+ )
+
+ options = self._info_to_options(info)
+ return self._client._request_api_list(self._model, page=self.__class__, options=options)
+
+
+class AsyncPaginator(Generic[ModelT, AsyncPageT]):
+ def __init__(
+ self,
+ client: AsyncAPIClient,
+ options: FinalRequestOptions,
+ page_cls: Type[AsyncPageT],
+ model: Type[ModelT],
+ ) -> None:
+ self._model = model
+ self._client = client
+ self._options = options
+ self._page_cls = page_cls
+
+ def __await__(self) -> Generator[Any, None, AsyncPageT]:
+ return self._get_page().__await__()
+
+ async def _get_page(self) -> AsyncPageT:
+ def _parser(resp: AsyncPageT) -> AsyncPageT:
+ resp._set_private_attributes(
+ model=self._model,
+ options=self._options,
+ client=self._client,
+ )
+ return resp
+
+ self._options.post_parser = _parser
+
+ return await self._client.request(self._page_cls, self._options)
+
+ async def __aiter__(self) -> AsyncIterator[ModelT]:
+ # https://github.com/microsoft/pyright/issues/3464
+ page = cast(
+ AsyncPageT,
+ await self, # type: ignore
+ )
+ async for item in page:
+ yield item
+
+
+class BaseAsyncPage(BasePage[ModelT], Generic[ModelT]):
+ _client: AsyncAPIClient = pydantic.PrivateAttr()
+
+ def _set_private_attributes(
+ self,
+ model: Type[ModelT],
+ client: AsyncAPIClient,
+ options: FinalRequestOptions,
+ ) -> None:
+ self._model = model
+ self._client = client
+ self._options = options
+
+ async def __aiter__(self) -> AsyncIterator[ModelT]:
+ async for page in self.iter_pages():
+ for item in page._get_page_items():
+ yield item
+
+ async def iter_pages(self: AsyncPageT) -> AsyncIterator[AsyncPageT]:
+ page = self
+ while True:
+ yield page
+ if page.has_next_page():
+ page = await page.get_next_page()
+ else:
+ return
+
+ async def get_next_page(self: AsyncPageT) -> AsyncPageT:
+ info = self.next_page_info()
+ if not info:
+ raise RuntimeError(
+ "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`."
+ )
+
+ options = self._info_to_options(info)
+ return await self._client._request_api_list(self._model, page=self.__class__, options=options)
+
+
+_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
+_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])
+
+
+class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
+ _client: _HttpxClientT
+ _version: str
+ _base_url: URL
+ max_retries: int
+ timeout: Union[float, Timeout, None]
+ _limits: httpx.Limits
+ _proxies: ProxiesTypes | None
+ _transport: Transport | AsyncTransport | None
+ _strict_response_validation: bool
+ _idempotency_header: str | None
+ _default_stream_cls: type[_DefaultStreamT] | None = None
+
+ def __init__(
+ self,
+ *,
+ version: str,
+ base_url: str | URL,
+ _strict_response_validation: bool,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ timeout: float | Timeout | None = DEFAULT_TIMEOUT,
+ limits: httpx.Limits,
+ transport: Transport | AsyncTransport | None,
+ proxies: ProxiesTypes | None,
+ custom_headers: Mapping[str, str] | None = None,
+ custom_query: Mapping[str, object] | None = None,
+ ) -> None:
+ self._version = version
+ self._base_url = self._enforce_trailing_slash(URL(base_url))
+ self.max_retries = max_retries
+ self.timeout = timeout
+ self._limits = limits
+ self._proxies = proxies
+ self._transport = transport
+ self._custom_headers = custom_headers or {}
+ self._custom_query = custom_query or {}
+ self._strict_response_validation = _strict_response_validation
+ self._idempotency_header = None
+
+ def _enforce_trailing_slash(self, url: URL) -> URL:
+ if url.raw_path.endswith(b"/"):
+ return url
+ return url.copy_with(raw_path=url.raw_path + b"/")
+
+ def _make_status_error_from_response(
+ self,
+ response: httpx.Response,
+ ) -> APIStatusError:
+ err_text = response.text.strip()
+ body = err_text
+
+ try:
+ body = json.loads(err_text)
+ err_msg = f"Error code: {response.status_code} - {body}"
+ except Exception:
+ err_msg = err_text or f"Error code: {response.status_code}"
+
+ return self._make_status_error(err_msg, body=body, response=response)
+
+ def _make_status_error(
+ self,
+ err_msg: str,
+ *,
+ body: object,
+ response: httpx.Response,
+ ) -> _exceptions.APIStatusError:
+ raise NotImplementedError()
+
+ def _remaining_retries(
+ self,
+ remaining_retries: Optional[int],
+ options: FinalRequestOptions,
+ ) -> int:
+ return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)
+
+ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
+ custom_headers = options.headers or {}
+ headers_dict = _merge_mappings(self.default_headers, custom_headers)
+ self._validate_headers(headers_dict, custom_headers)
+
+ headers = httpx.Headers(headers_dict)
+
+ idempotency_header = self._idempotency_header
+ if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
+ if not options.idempotency_key:
+ options.idempotency_key = self._idempotency_key()
+
+ headers[idempotency_header] = options.idempotency_key
+
+ return headers
+
+ def _prepare_url(self, url: str) -> URL:
+ """
+ Merge a URL argument together with any 'base_url' on the client,
+ to create the URL used for the outgoing request.
+ """
+ # Copied from httpx's `_merge_url` method.
+ merge_url = URL(url)
+ if merge_url.is_relative_url:
+ merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b"/")
+ return self.base_url.copy_with(raw_path=merge_raw_path)
+
+ return merge_url
+
+ def _build_request(
+ self,
+ options: FinalRequestOptions,
+ ) -> httpx.Request:
+ if log.isEnabledFor(logging.DEBUG):
+ log.debug("Request options: %s", model_dump(options, exclude_unset=True))
+
+ kwargs: dict[str, Any] = {}
+
+ json_data = options.json_data
+ if options.extra_json is not None:
+ if json_data is None:
+ json_data = cast(Body, options.extra_json)
+ elif is_mapping(json_data):
+ json_data = _merge_mappings(json_data, options.extra_json)
+ else:
+ raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
+
+ headers = self._build_headers(options)
+ params = _merge_mappings(self._custom_query, options.params)
+
+ # If the given Content-Type header is multipart/form-data then it
+ # has to be removed so that httpx can generate the header with
+ # additional information for us as it has to be in this form
+ # for the server to be able to correctly parse the request:
+ # multipart/form-data; boundary=---abc--
+ if headers.get("Content-Type") == "multipart/form-data":
+ headers.pop("Content-Type")
+
+ # As we are now sending multipart/form-data instead of application/json
+ # we need to tell httpx to use it, https://www.python-httpx.org/advanced/#multipart-file-encoding
+ if json_data:
+ if not is_dict(json_data):
+ raise TypeError(
+ f"Expected query input to be a dictionary for multipart requests but got {type(json_data)} instead."
+ )
+ kwargs["data"] = self._serialize_multipartform(json_data)
+
+ # TODO: report this error to httpx
+ return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
+ headers=headers,
+ timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
+ method=options.method,
+ url=self._prepare_url(options.url),
+ # the `Query` type that we use is incompatible with qs'
+ # `Params` type as it needs to be typed as `Mapping[str, object]`
+ # so that passing a `TypedDict` doesn't cause an error.
+ # https://github.com/microsoft/pyright/issues/3526#event-6715453066
+ params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None,
+ json=json_data,
+ files=options.files,
+ **kwargs,
+ )
+
+ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
+ items = self.qs.stringify_items(
+ # TODO: type ignore is required as stringify_items is well typed but we can't be
+ # well typed without heavy validation.
+ data, # type: ignore
+ array_format="brackets",
+ )
+ serialized: dict[str, object] = {}
+ for key, value in items:
+ if key in serialized:
+ raise ValueError(f"Duplicate key encountered: {key}; This behaviour is not supported")
+ serialized[key] = value
+ return serialized
+
+ def _process_response(
+ self,
+ *,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ response: httpx.Response,
+ stream: bool,
+ stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
+ ) -> ResponseT:
+ api_response = APIResponse(
+ raw=response,
+ client=self,
+ cast_to=cast_to,
+ stream=stream,
+ stream_cls=stream_cls,
+ options=options,
+ )
+
+ if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
+ return cast(ResponseT, api_response)
+
+ return api_response.parse()
+
+ def _process_response_data(
+ self,
+ *,
+ data: object,
+ cast_to: type[ResponseT],
+ response: httpx.Response,
+ ) -> ResponseT:
+ if data is None:
+ return cast(ResponseT, None)
+
+ if cast_to is UnknownResponse:
+ return cast(ResponseT, data)
+
+ if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol):
+ return cast(ResponseT, cast_to.build(response=response, data=data))
+
+ if self._strict_response_validation:
+ return cast(ResponseT, validate_type(type_=cast_to, value=data))
+
+ return cast(ResponseT, construct_type(type_=cast_to, value=data))
+
+ @property
+ def qs(self) -> Querystring:
+ return Querystring()
+
+ @property
+ def custom_auth(self) -> httpx.Auth | None:
+ return None
+
+ @property
+ def auth_headers(self) -> dict[str, str]:
+ return {}
+
+ @property
+ def default_headers(self) -> dict[str, str | Omit]:
+ return {
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ "User-Agent": self.user_agent,
+ **self.platform_headers(),
+ **self.auth_headers,
+ **self._custom_headers,
+ }
+
+ def _validate_headers(
+ self,
+ headers: Headers, # noqa: ARG002
+ custom_headers: Headers, # noqa: ARG002
+ ) -> None:
+ """Validate the given default headers and custom headers.
+
+ Does nothing by default.
+ """
+ return
+
+ @property
+ def user_agent(self) -> str:
+ return f"{self.__class__.__name__}/Python {self._version}"
+
+ @property
+ def base_url(self) -> URL:
+ return self._base_url
+
+ @base_url.setter
+ def base_url(self, url: URL | str) -> None:
+ self._client.base_url = url if isinstance(url, URL) else URL(url)
+
+ @lru_cache(maxsize=None)
+ def platform_headers(self) -> Dict[str, str]:
+ return {
+ "X-Stainless-Lang": "python",
+ "X-Stainless-Package-Version": self._version,
+ "X-Stainless-OS": str(get_platform()),
+ "X-Stainless-Arch": str(get_architecture()),
+ "X-Stainless-Runtime": platform.python_implementation(),
+ "X-Stainless-Runtime-Version": platform.python_version(),
+ }
+
+ def _calculate_retry_timeout(
+ self,
+ remaining_retries: int,
+ options: FinalRequestOptions,
+ response_headers: Optional[httpx.Headers] = None,
+ ) -> float:
+ max_retries = options.get_max_retries(self.max_retries)
+ try:
+ # About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
+ #
+ # ". See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for
+ # details.
+ if response_headers is not None:
+ retry_header = response_headers.get("retry-after")
+ try:
+ retry_after = int(retry_header)
+ except Exception:
+ retry_date_tuple = email.utils.parsedate_tz(retry_header)
+ if retry_date_tuple is None:
+ retry_after = -1
+ else:
+ retry_date = email.utils.mktime_tz(retry_date_tuple)
+ retry_after = int(retry_date - time.time())
+ else:
+ retry_after = -1
+
+ except Exception:
+ retry_after = -1
+
+ # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
+ if 0 < retry_after <= 60:
+ return retry_after
+
+ initial_retry_delay = 0.5
+ max_retry_delay = 8.0
+ nb_retries = max_retries - remaining_retries
+
+ # Apply exponential backoff, but not more than the max.
+ sleep_seconds = min(initial_retry_delay * pow(2.0, nb_retries), max_retry_delay)
+
+ # Apply some jitter, plus-or-minus half a second.
+ jitter = 1 - 0.25 * random()
+ timeout = sleep_seconds * jitter
+ return timeout if timeout >= 0 else 0
+
+ def _should_retry(self, response: httpx.Response) -> bool:
+ # Note: this is not a standard header
+ should_retry_header = response.headers.get("x-should-retry")
+
+ # If the server explicitly says whether or not to retry, obey.
+ if should_retry_header == "true":
+ return True
+ if should_retry_header == "false":
+ return False
+
+ # Retry on request timeouts.
+ if response.status_code == 408:
+ return True
+
+ # Retry on lock timeouts.
+ if response.status_code == 409:
+ return True
+
+ # Retry on rate limits.
+ if response.status_code == 429:
+ return True
+
+ # Retry internal errors.
+ if response.status_code >= 500:
+ return True
+
+ return False
+
+ def _idempotency_key(self) -> str:
+ return f"stainless-python-retry-{uuid.uuid4()}"
+
+
+class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
+ _client: httpx.Client
+ _has_custom_http_client: bool
+ _default_stream_cls: type[Stream[Any]] | None = None
+
+ def __init__(
+ self,
+ *,
+ version: str,
+ base_url: str | URL,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ transport: Transport | None = None,
+ proxies: ProxiesTypes | None = None,
+ limits: Limits | None = None,
+ http_client: httpx.Client | None = None,
+ custom_headers: Mapping[str, str] | None = None,
+ custom_query: Mapping[str, object] | None = None,
+ _strict_response_validation: bool,
+ ) -> None:
+ if limits is not None:
+ warnings.warn(
+ "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
+ category=DeprecationWarning,
+ stacklevel=3,
+ )
+ if http_client is not None:
+ raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`")
+ else:
+ limits = DEFAULT_LIMITS
+
+ if transport is not None:
+ warnings.warn(
+ "The `transport` argument is deprecated. The `http_client` argument should be passed instead",
+ category=DeprecationWarning,
+ stacklevel=3,
+ )
+ if http_client is not None:
+ raise ValueError("The `http_client` argument is mutually exclusive with `transport`")
+
+ if proxies is not None:
+ warnings.warn(
+ "The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
+ category=DeprecationWarning,
+ stacklevel=3,
+ )
+ if http_client is not None:
+ raise ValueError("The `http_client` argument is mutually exclusive with `proxies`")
+
+ if not is_given(timeout):
+ # if the user passed in a custom http client with a non-default
+ # timeout set then we use that timeout.
+ #
+ # note: there is an edge case here where the user passes in a client
+ # where they've explicitly set the timeout to match the default timeout
+ # as this check is structural, meaning that we'll think they didn't
+ # pass in a timeout and will ignore it
+ if http_client and http_client.timeout != HTTPX_DEFAULT_TIMEOUT:
+ timeout = http_client.timeout
+ else:
+ timeout = DEFAULT_TIMEOUT
+
+ super().__init__(
+ version=version,
+ limits=limits,
+ # cast to a valid type because mypy doesn't understand our type narrowing
+ timeout=cast(Timeout, timeout),
+ proxies=proxies,
+ base_url=base_url,
+ transport=transport,
+ max_retries=max_retries,
+ custom_query=custom_query,
+ custom_headers=custom_headers,
+ _strict_response_validation=_strict_response_validation,
+ )
+ self._client = http_client or httpx.Client(
+ base_url=base_url,
+ # cast to a valid type because mypy doesn't understand our type narrowing
+ timeout=cast(Timeout, timeout),
+ proxies=proxies,
+ transport=transport,
+ limits=limits,
+ )
+ self._has_custom_http_client = bool(http_client)
+
+ def is_closed(self) -> bool:
+ return self._client.is_closed
+
+ def close(self) -> None:
+ """Close the underlying HTTPX client.
+
+ The client will *not* be usable after this.
+ """
+ # If an error is thrown while constructing a client, self._client
+ # may not be present
+ if hasattr(self, "_client"):
+ self._client.close()
+
+ def __enter__(self: _T) -> _T:
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ self.close()
+
+ def _prepare_options(
+ self,
+ options: FinalRequestOptions, # noqa: ARG002
+ ) -> None:
+ """Hook for mutating the given options"""
+ return None
+
+ def _prepare_request(
+ self,
+ request: httpx.Request, # noqa: ARG002
+ ) -> None:
+ """This method is used as a callback for mutating the `Request` object
+ after it has been constructed.
+ This is useful for cases where you want to add certain headers based off of
+ the request properties, e.g. `url`, `method` etc.
+ """
+ return None
+
+ @overload
+ def request(
+ self,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ remaining_retries: Optional[int] = None,
+ *,
+ stream: Literal[True],
+ stream_cls: Type[_StreamT],
+ ) -> _StreamT:
+ ...
+
+ @overload
+ def request(
+ self,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ remaining_retries: Optional[int] = None,
+ *,
+ stream: Literal[False] = False,
+ ) -> ResponseT:
+ ...
+
+ @overload
+ def request(
+ self,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ remaining_retries: Optional[int] = None,
+ *,
+ stream: bool = False,
+ stream_cls: Type[_StreamT] | None = None,
+ ) -> ResponseT | _StreamT:
+ ...
+
+ def request(
+ self,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ remaining_retries: Optional[int] = None,
+ *,
+ stream: bool = False,
+ stream_cls: type[_StreamT] | None = None,
+ ) -> ResponseT | _StreamT:
+ return self._request(
+ cast_to=cast_to,
+ options=options,
+ stream=stream,
+ stream_cls=stream_cls,
+ remaining_retries=remaining_retries,
+ )
+
+ def _request(
+ self,
+ *,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ remaining_retries: int | None,
+ stream: bool,
+ stream_cls: type[_StreamT] | None,
+ ) -> ResponseT | _StreamT:
+ self._prepare_options(options)
+
+ retries = self._remaining_retries(remaining_retries, options)
+ request = self._build_request(options)
+ self._prepare_request(request)
+
+ try:
+ response = self._client.send(request, auth=self.custom_auth, stream=stream)
+ log.debug(
+ 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
+ )
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
+ if retries > 0 and self._should_retry(err.response):
+ return self._retry_request(
+ options,
+ cast_to,
+ retries,
+ err.response.headers,
+ stream=stream,
+ stream_cls=stream_cls,
+ )
+
+ # If the response is streamed then we need to explicitly read the response
+ # to completion before attempting to access the response text.
+ err.response.read()
+ raise self._make_status_error_from_response(err.response) from None
+ except httpx.TimeoutException as err:
+ if retries > 0:
+ return self._retry_request(
+ options,
+ cast_to,
+ retries,
+ stream=stream,
+ stream_cls=stream_cls,
+ )
+ raise APITimeoutError(request=request) from err
+ except Exception as err:
+ if retries > 0:
+ return self._retry_request(
+ options,
+ cast_to,
+ retries,
+ stream=stream,
+ stream_cls=stream_cls,
+ )
+ raise APIConnectionError(request=request) from err
+
+ return self._process_response(
+ cast_to=cast_to,
+ options=options,
+ response=response,
+ stream=stream,
+ stream_cls=stream_cls,
+ )
+
+ def _retry_request(
+ self,
+ options: FinalRequestOptions,
+ cast_to: Type[ResponseT],
+ remaining_retries: int,
+ response_headers: Optional[httpx.Headers] = None,
+ *,
+ stream: bool,
+ stream_cls: type[_StreamT] | None,
+ ) -> ResponseT | _StreamT:
+ remaining = remaining_retries - 1
+ timeout = self._calculate_retry_timeout(remaining, options, response_headers)
+ log.info("Retrying request to %s in %f seconds", options.url, timeout)
+
+ # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
+ # different thread if necessary.
+ time.sleep(timeout)
+
+ return self._request(
+ options=options,
+ cast_to=cast_to,
+ remaining_retries=remaining,
+ stream=stream,
+ stream_cls=stream_cls,
+ )
+
+ def _request_api_list(
+ self,
+ model: Type[ModelT],
+ page: Type[SyncPageT],
+ options: FinalRequestOptions,
+ ) -> SyncPageT:
+ def _parser(resp: SyncPageT) -> SyncPageT:
+ resp._set_private_attributes(
+ client=self,
+ model=model,
+ options=options,
+ )
+ return resp
+
+ options.post_parser = _parser
+
+ return self.request(page, options, stream=False)
+
+ @overload
+ def get(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ options: RequestOptions = {},
+ stream: Literal[False] = False,
+ ) -> ResponseT:
+ ...
+
+ @overload
+ def get(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ options: RequestOptions = {},
+ stream: Literal[True],
+ stream_cls: type[_StreamT],
+ ) -> _StreamT:
+ ...
+
+ @overload
+ def get(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ options: RequestOptions = {},
+ stream: bool,
+ stream_cls: type[_StreamT] | None = None,
+ ) -> ResponseT | _StreamT:
+ ...
+
+ def get(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ options: RequestOptions = {},
+ stream: bool = False,
+ stream_cls: type[_StreamT] | None = None,
+ ) -> ResponseT | _StreamT:
+ opts = FinalRequestOptions.construct(method="get", url=path, **options)
+ # cast is required because mypy complains about returning Any even though
+ # it understands the type variables
+ return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls))
+
+ @overload
+ def post(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ files: RequestFiles | None = None,
+ stream: Literal[False] = False,
+ ) -> ResponseT:
+ ...
+
+ @overload
+ def post(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ files: RequestFiles | None = None,
+ stream: Literal[True],
+ stream_cls: type[_StreamT],
+ ) -> _StreamT:
+ ...
+
+ @overload
+ def post(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ files: RequestFiles | None = None,
+ stream: bool,
+ stream_cls: type[_StreamT] | None = None,
+ ) -> ResponseT | _StreamT:
+ ...
+
+ def post(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ files: RequestFiles | None = None,
+ stream: bool = False,
+ stream_cls: type[_StreamT] | None = None,
+ ) -> ResponseT | _StreamT:
+ opts = FinalRequestOptions.construct(
+ method="post", url=path, json_data=body, files=to_httpx_files(files), **options
+ )
+ return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls))
+
+ def patch(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ ) -> ResponseT:
+ opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options)
+ return self.request(cast_to, opts)
+
+ def put(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ files: RequestFiles | None = None,
+ options: RequestOptions = {},
+ ) -> ResponseT:
+ opts = FinalRequestOptions.construct(
+ method="put", url=path, json_data=body, files=to_httpx_files(files), **options
+ )
+ return self.request(cast_to, opts)
+
+ def delete(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ ) -> ResponseT:
+ opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options)
+ return self.request(cast_to, opts)
+
+ def get_api_list(
+ self,
+ path: str,
+ *,
+ model: Type[ModelT],
+ page: Type[SyncPageT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ method: str = "get",
+ ) -> SyncPageT:
+ opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)
+ return self._request_api_list(model, page, opts)
+
+
+class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
+ _client: httpx.AsyncClient
+ _has_custom_http_client: bool
+ _default_stream_cls: type[AsyncStream[Any]] | None = None
+
+ def __init__(
+ self,
+ *,
+ version: str,
+ base_url: str | URL,
+ _strict_response_validation: bool,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ transport: AsyncTransport | None = None,
+ proxies: ProxiesTypes | None = None,
+ limits: Limits | None = None,
+ http_client: httpx.AsyncClient | None = None,
+ custom_headers: Mapping[str, str] | None = None,
+ custom_query: Mapping[str, object] | None = None,
+ ) -> None:
+ if limits is not None:
+ warnings.warn(
+ "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
+ category=DeprecationWarning,
+ stacklevel=3,
+ )
+ if http_client is not None:
+ raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`")
+ else:
+ limits = DEFAULT_LIMITS
+
+ if transport is not None:
+ warnings.warn(
+ "The `transport` argument is deprecated. The `http_client` argument should be passed instead",
+ category=DeprecationWarning,
+ stacklevel=3,
+ )
+ if http_client is not None:
+ raise ValueError("The `http_client` argument is mutually exclusive with `transport`")
+
+ if proxies is not None:
+ warnings.warn(
+ "The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
+ category=DeprecationWarning,
+ stacklevel=3,
+ )
+ if http_client is not None:
+ raise ValueError("The `http_client` argument is mutually exclusive with `proxies`")
+
+ if not is_given(timeout):
+ # if the user passed in a custom http client with a non-default
+ # timeout set then we use that timeout.
+ #
+ # note: there is an edge case here where the user passes in a client
+ # where they've explicitly set the timeout to match the default timeout
+ # as this check is structural, meaning that we'll think they didn't
+ # pass in a timeout and will ignore it
+ if http_client and http_client.timeout != HTTPX_DEFAULT_TIMEOUT:
+ timeout = http_client.timeout
+ else:
+ timeout = DEFAULT_TIMEOUT
+
+ super().__init__(
+ version=version,
+ base_url=base_url,
+ limits=limits,
+ # cast to a valid type because mypy doesn't understand our type narrowing
+ timeout=cast(Timeout, timeout),
+ proxies=proxies,
+ transport=transport,
+ max_retries=max_retries,
+ custom_query=custom_query,
+ custom_headers=custom_headers,
+ _strict_response_validation=_strict_response_validation,
+ )
+ self._client = http_client or httpx.AsyncClient(
+ base_url=base_url,
+ # cast to a valid type because mypy doesn't understand our type narrowing
+ timeout=cast(Timeout, timeout),
+ proxies=proxies,
+ transport=transport,
+ limits=limits,
+ )
+ self._has_custom_http_client = bool(http_client)
+
+ def is_closed(self) -> bool:
+ return self._client.is_closed
+
+ async def close(self) -> None:
+ """Close the underlying HTTPX client.
+
+ The client will *not* be usable after this.
+ """
+ await self._client.aclose()
+
+ async def __aenter__(self: _T) -> _T:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ await self.close()
+
+ async def _prepare_options(
+ self,
+ options: FinalRequestOptions, # noqa: ARG002
+ ) -> None:
+ """Hook for mutating the given options"""
+ return None
+
+ async def _prepare_request(
+ self,
+ request: httpx.Request, # noqa: ARG002
+ ) -> None:
+ """This method is used as a callback for mutating the `Request` object
+ after it has been constructed.
+ This is useful for cases where you want to add certain headers based off of
+ the request properties, e.g. `url`, `method` etc.
+ """
+ return None
+
+ @overload
+ async def request(
+ self,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ *,
+ stream: Literal[False] = False,
+ remaining_retries: Optional[int] = None,
+ ) -> ResponseT:
+ ...
+
+ @overload
+ async def request(
+ self,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ *,
+ stream: Literal[True],
+ stream_cls: type[_AsyncStreamT],
+ remaining_retries: Optional[int] = None,
+ ) -> _AsyncStreamT:
+ ...
+
+ @overload
+ async def request(
+ self,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ *,
+ stream: bool,
+ stream_cls: type[_AsyncStreamT] | None = None,
+ remaining_retries: Optional[int] = None,
+ ) -> ResponseT | _AsyncStreamT:
+ ...
+
+ async def request(
+ self,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ *,
+ stream: bool = False,
+ stream_cls: type[_AsyncStreamT] | None = None,
+ remaining_retries: Optional[int] = None,
+ ) -> ResponseT | _AsyncStreamT:
+ return await self._request(
+ cast_to=cast_to,
+ options=options,
+ stream=stream,
+ stream_cls=stream_cls,
+ remaining_retries=remaining_retries,
+ )
+
+ async def _request(
+ self,
+ cast_to: Type[ResponseT],
+ options: FinalRequestOptions,
+ *,
+ stream: bool,
+ stream_cls: type[_AsyncStreamT] | None,
+ remaining_retries: int | None,
+ ) -> ResponseT | _AsyncStreamT:
+ await self._prepare_options(options)
+
+ retries = self._remaining_retries(remaining_retries, options)
+ request = self._build_request(options)
+ await self._prepare_request(request)
+
+ try:
+ response = await self._client.send(request, auth=self.custom_auth, stream=stream)
+ log.debug(
+ 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
+ )
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
+ if retries > 0 and self._should_retry(err.response):
+ return await self._retry_request(
+ options,
+ cast_to,
+ retries,
+ err.response.headers,
+ stream=stream,
+ stream_cls=stream_cls,
+ )
+
+ # If the response is streamed then we need to explicitly read the response
+ # to completion before attempting to access the response text.
+ await err.response.aread()
+ raise self._make_status_error_from_response(err.response) from None
+ except httpx.ConnectTimeout as err:
+ if retries > 0:
+ return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
+ raise APITimeoutError(request=request) from err
+ except httpx.ReadTimeout as err:
+ # We explicitly do not retry on ReadTimeout errors as this means
+ # that the server processing the request has taken 60 seconds
+ # (our default timeout). This likely indicates that something
+ # is not working as expected on the server side.
+ raise
+ except httpx.TimeoutException as err:
+ if retries > 0:
+ return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
+ raise APITimeoutError(request=request) from err
+ except Exception as err:
+ if retries > 0:
+ return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
+ raise APIConnectionError(request=request) from err
+
+ return self._process_response(
+ cast_to=cast_to,
+ options=options,
+ response=response,
+ stream=stream,
+ stream_cls=stream_cls,
+ )
+
+ async def _retry_request(
+ self,
+ options: FinalRequestOptions,
+ cast_to: Type[ResponseT],
+ remaining_retries: int,
+ response_headers: Optional[httpx.Headers] = None,
+ *,
+ stream: bool,
+ stream_cls: type[_AsyncStreamT] | None,
+ ) -> ResponseT | _AsyncStreamT:
+ remaining = remaining_retries - 1
+ timeout = self._calculate_retry_timeout(remaining, options, response_headers)
+ log.info("Retrying request to %s in %f seconds", options.url, timeout)
+
+ await anyio.sleep(timeout)
+
+ return await self._request(
+ options=options,
+ cast_to=cast_to,
+ remaining_retries=remaining,
+ stream=stream,
+ stream_cls=stream_cls,
+ )
+
+ def _request_api_list(
+ self,
+ model: Type[ModelT],
+ page: Type[AsyncPageT],
+ options: FinalRequestOptions,
+ ) -> AsyncPaginator[ModelT, AsyncPageT]:
+ return AsyncPaginator(client=self, options=options, page_cls=page, model=model)
+
+ @overload
+ async def get(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ options: RequestOptions = {},
+ stream: Literal[False] = False,
+ ) -> ResponseT:
+ ...
+
+ @overload
+ async def get(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ options: RequestOptions = {},
+ stream: Literal[True],
+ stream_cls: type[_AsyncStreamT],
+ ) -> _AsyncStreamT:
+ ...
+
+ @overload
+ async def get(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ options: RequestOptions = {},
+ stream: bool,
+ stream_cls: type[_AsyncStreamT] | None = None,
+ ) -> ResponseT | _AsyncStreamT:
+ ...
+
+ async def get(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ options: RequestOptions = {},
+ stream: bool = False,
+ stream_cls: type[_AsyncStreamT] | None = None,
+ ) -> ResponseT | _AsyncStreamT:
+ opts = FinalRequestOptions.construct(method="get", url=path, **options)
+ return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)
+
+ @overload
+ async def post(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ files: RequestFiles | None = None,
+ options: RequestOptions = {},
+ stream: Literal[False] = False,
+ ) -> ResponseT:
+ ...
+
+ @overload
+ async def post(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ files: RequestFiles | None = None,
+ options: RequestOptions = {},
+ stream: Literal[True],
+ stream_cls: type[_AsyncStreamT],
+ ) -> _AsyncStreamT:
+ ...
+
+ @overload
+ async def post(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ files: RequestFiles | None = None,
+ options: RequestOptions = {},
+ stream: bool,
+ stream_cls: type[_AsyncStreamT] | None = None,
+ ) -> ResponseT | _AsyncStreamT:
+ ...
+
+ async def post(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ files: RequestFiles | None = None,
+ options: RequestOptions = {},
+ stream: bool = False,
+ stream_cls: type[_AsyncStreamT] | None = None,
+ ) -> ResponseT | _AsyncStreamT:
+ opts = FinalRequestOptions.construct(
+ method="post", url=path, json_data=body, files=await async_to_httpx_files(files), **options
+ )
+ return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)
+
+ async def patch(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ ) -> ResponseT:
+ opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options)
+ return await self.request(cast_to, opts)
+
+ async def put(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ files: RequestFiles | None = None,
+ options: RequestOptions = {},
+ ) -> ResponseT:
+ opts = FinalRequestOptions.construct(
+ method="put", url=path, json_data=body, files=await async_to_httpx_files(files), **options
+ )
+ return await self.request(cast_to, opts)
+
+ async def delete(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ ) -> ResponseT:
+ opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options)
+ return await self.request(cast_to, opts)
+
+ def get_api_list(
+ self,
+ path: str,
+ *,
+ # TODO: support paginating `str`
+ model: Type[ModelT],
+ page: Type[AsyncPageT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ method: str = "get",
+ ) -> AsyncPaginator[ModelT, AsyncPageT]:
+ opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)
+ return self._request_api_list(model, page, opts)
+
+
+def make_request_options(
+ *,
+ query: Query | None = None,
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ idempotency_key: str | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ post_parser: PostParser | NotGiven = NOT_GIVEN,
+) -> RequestOptions:
+ """Create a dict of type RequestOptions without keys of NotGiven values."""
+ options: RequestOptions = {}
+ if extra_headers is not None:
+ options["headers"] = extra_headers
+
+ if extra_body is not None:
+ options["extra_json"] = cast(AnyMapping, extra_body)
+
+ if query is not None:
+ options["params"] = query
+
+ if extra_query is not None:
+ options["params"] = {**options.get("params", {}), **extra_query}
+
+ if not isinstance(timeout, NotGiven):
+ options["timeout"] = timeout
+
+ if idempotency_key is not None:
+ options["idempotency_key"] = idempotency_key
+
+ if is_given(post_parser):
+ # internal
+ options["post_parser"] = post_parser # type: ignore
+
+ return options
+
+
+class OtherPlatform:
+ def __init__(self, name: str) -> None:
+ self.name = name
+
+ @override
+ def __str__(self) -> str:
+ return f"Other:{self.name}"
+
+
+Platform = Union[
+ OtherPlatform,
+ Literal[
+ "MacOS",
+ "Linux",
+ "Windows",
+ "FreeBSD",
+ "OpenBSD",
+ "iOS",
+ "Android",
+ "Unknown",
+ ],
+]
+
+
+def get_platform() -> Platform:
+ system = platform.system().lower()
+ platform_name = platform.platform().lower()
+ if "iphone" in platform_name or "ipad" in platform_name:
+ # Tested using Python3IDE on an iPhone 11 and Pythonista on an iPad 7
+ # system is Darwin and platform_name is a string like:
+ # - Darwin-21.6.0-iPhone12,1-64bit
+ # - Darwin-21.6.0-iPad7,11-64bit
+ return "iOS"
+
+ if system == "darwin":
+ return "MacOS"
+
+ if system == "windows":
+ return "Windows"
+
+ if "android" in platform_name:
+ # Tested using Pydroid 3
+ # system is Linux and platform_name is a string like 'Linux-5.10.81-android12-9-00001-geba40aecb3b7-ab8534902-aarch64-with-libc'
+ return "Android"
+
+ if system == "linux":
+ # https://distro.readthedocs.io/en/latest/#distro.id
+ distro_id = distro.id()
+ if distro_id == "freebsd":
+ return "FreeBSD"
+
+ if distro_id == "openbsd":
+ return "OpenBSD"
+
+ return "Linux"
+
+ if platform_name:
+ return OtherPlatform(platform_name)
+
+ return "Unknown"
+
+
+class OtherArch:
+ def __init__(self, name: str) -> None:
+ self.name = name
+
+ @override
+ def __str__(self) -> str:
+ return f"other:{self.name}"
+
+
+Arch = Union[OtherArch, Literal["x32", "x64", "arm", "arm64", "unknown"]]
+
+
+def get_architecture() -> Arch:
+ python_bitness, _ = platform.architecture()
+ machine = platform.machine().lower()
+ if machine in ("arm64", "aarch64"):
+ return "arm64"
+
+ # TODO: untested
+ if machine == "arm":
+ return "arm"
+
+ if machine == "x86_64":
+ return "x64"
+
+ # TODO: untested
+ if python_bitness == "32bit":
+ return "x32"
+
+ if machine:
+ return OtherArch(machine)
+
+ return "unknown"
+
+
+def _merge_mappings(
+ obj1: Mapping[_T_co, Union[_T, Omit]],
+ obj2: Mapping[_T_co, Union[_T, Omit]],
+) -> Dict[_T_co, _T]:
+ """Merge two mappings of the same type, removing any values that are instances of `Omit`.
+
+ In cases with duplicate keys the second mapping takes precedence.
+ """
+ merged = {**obj1, **obj2}
+ return {key: value for key, value in merged.items() if not isinstance(value, Omit)}
+
+
+class HttpxBinaryResponseContent(BinaryResponseContent):
+ response: httpx.Response
+
+ def __init__(self, response: httpx.Response) -> None:
+ self.response = response
+
+ @property
+ @override
+ def content(self) -> bytes:
+ return self.response.content
+
+ @property
+ @override
+ def text(self) -> str:
+ return self.response.text
+
+ @property
+ @override
+ def encoding(self) -> Optional[str]:
+ return self.response.encoding
+
+ @property
+ @override
+ def charset_encoding(self) -> Optional[str]:
+ return self.response.charset_encoding
+
+ @override
+ def json(self, **kwargs: Any) -> Any:
+ return self.response.json(**kwargs)
+
+ @override
+ def read(self) -> bytes:
+ return self.response.read()
+
+ @override
+ def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
+ return self.response.iter_bytes(chunk_size)
+
+ @override
+ def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]:
+ return self.response.iter_text(chunk_size)
+
+ @override
+ def iter_lines(self) -> Iterator[str]:
+ return self.response.iter_lines()
+
+ @override
+ def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
+ return self.response.iter_raw(chunk_size)
+
+ @override
+ def stream_to_file(self, file: str | os.PathLike[str]) -> None:
+ with open(file, mode="wb") as f:
+ for data in self.response.iter_bytes():
+ f.write(data)
+
+ @override
+ def close(self) -> None:
+ return self.response.close()
+
+ @override
+ async def aread(self) -> bytes:
+ return await self.response.aread()
+
+ @override
+ async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
+ return self.response.aiter_bytes(chunk_size)
+
+ @override
+ async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]:
+ return self.response.aiter_text(chunk_size)
+
+ @override
+ async def aiter_lines(self) -> AsyncIterator[str]:
+ return self.response.aiter_lines()
+
+ @override
+ async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
+ return self.response.aiter_raw(chunk_size)
+
+ @override
+ async def astream_to_file(self, file: str | os.PathLike[str]) -> None:
+ path = anyio.Path(file)
+ async with await path.open(mode="wb") as f:
+ async for data in self.response.aiter_bytes():
+ await f.write(data)
+
+ @override
+ async def aclose(self) -> None:
+ return await self.response.aclose()
diff --git a/src/openai/_client.py b/src/openai/_client.py
new file mode 100644
index 0000000000..9df7eabf9a
--- /dev/null
+++ b/src/openai/_client.py
@@ -0,0 +1,488 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+import asyncio
+from typing import Union, Mapping
+from typing_extensions import override
+
+import httpx
+
+from . import resources, _exceptions
+from ._qs import Querystring
+from ._types import (
+ NOT_GIVEN,
+ Omit,
+ Timeout,
+ NotGiven,
+ Transport,
+ ProxiesTypes,
+ RequestOptions,
+)
+from ._utils import is_given
+from ._version import __version__
+from ._streaming import Stream as Stream
+from ._streaming import AsyncStream as AsyncStream
+from ._exceptions import OpenAIError, APIStatusError
+from ._base_client import DEFAULT_MAX_RETRIES, SyncAPIClient, AsyncAPIClient
+
+__all__ = [
+ "Timeout",
+ "Transport",
+ "ProxiesTypes",
+ "RequestOptions",
+ "resources",
+ "OpenAI",
+ "AsyncOpenAI",
+ "Client",
+ "AsyncClient",
+]
+
+
+class OpenAI(SyncAPIClient):
+ completions: resources.Completions
+ chat: resources.Chat
+ edits: resources.Edits
+ embeddings: resources.Embeddings
+ files: resources.Files
+ images: resources.Images
+ audio: resources.Audio
+ moderations: resources.Moderations
+ models: resources.Models
+ fine_tuning: resources.FineTuning
+ fine_tunes: resources.FineTunes
+ with_raw_response: OpenAIWithRawResponse
+
+ # client options
+ api_key: str
+ organization: str | None
+
+ def __init__(
+ self,
+ *,
+ api_key: str | None = None,
+ organization: str | None = None,
+ base_url: str | httpx.URL | None = None,
+ timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
+ http_client: httpx.Client | None = None,
+ # Enable or disable schema validation for data returned by the API.
+ # When enabled an error APIResponseValidationError is raised
+ # if the API responds with invalid data for the expected schema.
+ #
+ # This parameter may be removed or changed in the future.
+ # If you rely on this feature, please open a GitHub issue
+ # outlining your use-case to help us decide if it should be
+ # part of our public interface in the future.
+ _strict_response_validation: bool = False,
+ ) -> None:
+ """Construct a new synchronous openai client instance.
+
+ This automatically infers the following arguments from their corresponding environment variables if they are not provided:
+ - `api_key` from `OPENAI_API_KEY`
+ - `organization` from `OPENAI_ORG_ID`
+ """
+ if api_key is None:
+ api_key = os.environ.get("OPENAI_API_KEY")
+ if api_key is None:
+ raise OpenAIError(
+ "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
+ )
+ self.api_key = api_key
+
+ if organization is None:
+ organization = os.environ.get("OPENAI_ORG_ID")
+ self.organization = organization
+
+ if base_url is None:
+ base_url = f"https://api.openai.com/v1"
+
+ super().__init__(
+ version=__version__,
+ base_url=base_url,
+ max_retries=max_retries,
+ timeout=timeout,
+ http_client=http_client,
+ custom_headers=default_headers,
+ custom_query=default_query,
+ _strict_response_validation=_strict_response_validation,
+ )
+
+ self._default_stream_cls = Stream
+
+ self.completions = resources.Completions(self)
+ self.chat = resources.Chat(self)
+ self.edits = resources.Edits(self)
+ self.embeddings = resources.Embeddings(self)
+ self.files = resources.Files(self)
+ self.images = resources.Images(self)
+ self.audio = resources.Audio(self)
+ self.moderations = resources.Moderations(self)
+ self.models = resources.Models(self)
+ self.fine_tuning = resources.FineTuning(self)
+ self.fine_tunes = resources.FineTunes(self)
+ self.with_raw_response = OpenAIWithRawResponse(self)
+
+ @property
+ @override
+ def qs(self) -> Querystring:
+ return Querystring(array_format="comma")
+
+ @property
+ @override
+ def auth_headers(self) -> dict[str, str]:
+ api_key = self.api_key
+ return {"Authorization": f"Bearer {api_key}"}
+
+ @property
+ @override
+ def default_headers(self) -> dict[str, str | Omit]:
+ return {
+ **super().default_headers,
+ "OpenAI-Organization": self.organization if self.organization is not None else Omit(),
+ **self._custom_headers,
+ }
+
+ def copy(
+ self,
+ *,
+ api_key: str | None = None,
+ organization: str | None = None,
+ base_url: str | httpx.URL | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ http_client: httpx.Client | None = None,
+ max_retries: int | NotGiven = NOT_GIVEN,
+ default_headers: Mapping[str, str] | None = None,
+ set_default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ set_default_query: Mapping[str, object] | None = None,
+ ) -> OpenAI:
+ """
+ Create a new client instance re-using the same options given to the current client with optional overriding.
+
+ It should be noted that this does not share the underlying httpx client class which may lead
+ to performance issues.
+ """
+ if default_headers is not None and set_default_headers is not None:
+ raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
+
+ if default_query is not None and set_default_query is not None:
+ raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
+
+ headers = self._custom_headers
+ if default_headers is not None:
+ headers = {**headers, **default_headers}
+ elif set_default_headers is not None:
+ headers = set_default_headers
+
+ params = self._custom_query
+ if default_query is not None:
+ params = {**params, **default_query}
+ elif set_default_query is not None:
+ params = set_default_query
+
+ http_client = http_client or self._client
+ return self.__class__(
+ api_key=api_key or self.api_key,
+ organization=organization or self.organization,
+ base_url=base_url or str(self.base_url),
+ timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
+ http_client=http_client,
+ max_retries=max_retries if is_given(max_retries) else self.max_retries,
+ default_headers=headers,
+ default_query=params,
+ )
+
+ # Alias for `copy` for nicer inline usage, e.g.
+ # client.with_options(timeout=10).foo.create(...)
+ with_options = copy
+
+ def __del__(self) -> None:
+ if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close"):
+ # this can happen if the '__init__' method raised an error
+ return
+
+ if self._has_custom_http_client:
+ return
+
+ self.close()
+
+ @override
+ def _make_status_error(
+ self,
+ err_msg: str,
+ *,
+ body: object,
+ response: httpx.Response,
+ ) -> APIStatusError:
+ if response.status_code == 400:
+ return _exceptions.BadRequestError(err_msg, response=response, body=body)
+
+ if response.status_code == 401:
+ return _exceptions.AuthenticationError(err_msg, response=response, body=body)
+
+ if response.status_code == 403:
+ return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
+
+ if response.status_code == 404:
+ return _exceptions.NotFoundError(err_msg, response=response, body=body)
+
+ if response.status_code == 409:
+ return _exceptions.ConflictError(err_msg, response=response, body=body)
+
+ if response.status_code == 422:
+ return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
+
+ if response.status_code == 429:
+ return _exceptions.RateLimitError(err_msg, response=response, body=body)
+
+ if response.status_code >= 500:
+ return _exceptions.InternalServerError(err_msg, response=response, body=body)
+ return APIStatusError(err_msg, response=response, body=body)
+
+
+class AsyncOpenAI(AsyncAPIClient):
+ completions: resources.AsyncCompletions
+ chat: resources.AsyncChat
+ edits: resources.AsyncEdits
+ embeddings: resources.AsyncEmbeddings
+ files: resources.AsyncFiles
+ images: resources.AsyncImages
+ audio: resources.AsyncAudio
+ moderations: resources.AsyncModerations
+ models: resources.AsyncModels
+ fine_tuning: resources.AsyncFineTuning
+ fine_tunes: resources.AsyncFineTunes
+ with_raw_response: AsyncOpenAIWithRawResponse
+
+ # client options
+ api_key: str
+ organization: str | None
+
+ def __init__(
+ self,
+ *,
+ api_key: str | None = None,
+ organization: str | None = None,
+ base_url: str | httpx.URL | None = None,
+ timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details.
+ http_client: httpx.AsyncClient | None = None,
+ # Enable or disable schema validation for data returned by the API.
+ # When enabled an error APIResponseValidationError is raised
+ # if the API responds with invalid data for the expected schema.
+ #
+ # This parameter may be removed or changed in the future.
+ # If you rely on this feature, please open a GitHub issue
+ # outlining your use-case to help us decide if it should be
+ # part of our public interface in the future.
+ _strict_response_validation: bool = False,
+ ) -> None:
+ """Construct a new async openai client instance.
+
+ This automatically infers the following arguments from their corresponding environment variables if they are not provided:
+ - `api_key` from `OPENAI_API_KEY`
+ - `organization` from `OPENAI_ORG_ID`
+ """
+ if api_key is None:
+ api_key = os.environ.get("OPENAI_API_KEY")
+ if api_key is None:
+ raise OpenAIError(
+ "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
+ )
+ self.api_key = api_key
+
+ if organization is None:
+ organization = os.environ.get("OPENAI_ORG_ID")
+ self.organization = organization
+
+ if base_url is None:
+ base_url = f"https://api.openai.com/v1"
+
+ super().__init__(
+ version=__version__,
+ base_url=base_url,
+ max_retries=max_retries,
+ timeout=timeout,
+ http_client=http_client,
+ custom_headers=default_headers,
+ custom_query=default_query,
+ _strict_response_validation=_strict_response_validation,
+ )
+
+ self._default_stream_cls = AsyncStream
+
+ self.completions = resources.AsyncCompletions(self)
+ self.chat = resources.AsyncChat(self)
+ self.edits = resources.AsyncEdits(self)
+ self.embeddings = resources.AsyncEmbeddings(self)
+ self.files = resources.AsyncFiles(self)
+ self.images = resources.AsyncImages(self)
+ self.audio = resources.AsyncAudio(self)
+ self.moderations = resources.AsyncModerations(self)
+ self.models = resources.AsyncModels(self)
+ self.fine_tuning = resources.AsyncFineTuning(self)
+ self.fine_tunes = resources.AsyncFineTunes(self)
+ self.with_raw_response = AsyncOpenAIWithRawResponse(self)
+
+ @property
+ @override
+ def qs(self) -> Querystring:
+ return Querystring(array_format="comma")
+
+ @property
+ @override
+ def auth_headers(self) -> dict[str, str]:
+ api_key = self.api_key
+ return {"Authorization": f"Bearer {api_key}"}
+
+ @property
+ @override
+ def default_headers(self) -> dict[str, str | Omit]:
+ return {
+ **super().default_headers,
+ "OpenAI-Organization": self.organization if self.organization is not None else Omit(),
+ **self._custom_headers,
+ }
+
+ def copy(
+ self,
+ *,
+ api_key: str | None = None,
+ organization: str | None = None,
+ base_url: str | httpx.URL | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ http_client: httpx.AsyncClient | None = None,
+ max_retries: int | NotGiven = NOT_GIVEN,
+ default_headers: Mapping[str, str] | None = None,
+ set_default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ set_default_query: Mapping[str, object] | None = None,
+ ) -> AsyncOpenAI:
+ """
+ Create a new client instance re-using the same options given to the current client with optional overriding.
+
+ It should be noted that this does not share the underlying httpx client class which may lead
+ to performance issues.
+ """
+ if default_headers is not None and set_default_headers is not None:
+ raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
+
+ if default_query is not None and set_default_query is not None:
+ raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
+
+ headers = self._custom_headers
+ if default_headers is not None:
+ headers = {**headers, **default_headers}
+ elif set_default_headers is not None:
+ headers = set_default_headers
+
+ params = self._custom_query
+ if default_query is not None:
+ params = {**params, **default_query}
+ elif set_default_query is not None:
+ params = set_default_query
+
+ http_client = http_client or self._client
+ return self.__class__(
+ api_key=api_key or self.api_key,
+ organization=organization or self.organization,
+ base_url=base_url or str(self.base_url),
+ timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
+ http_client=http_client,
+ max_retries=max_retries if is_given(max_retries) else self.max_retries,
+ default_headers=headers,
+ default_query=params,
+ )
+
+ # Alias for `copy` for nicer inline usage, e.g.
+ # client.with_options(timeout=10).foo.create(...)
+ with_options = copy
+
+ def __del__(self) -> None:
+ if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close"):
+ # this can happen if the '__init__' method raised an error
+ return
+
+ if self._has_custom_http_client:
+ return
+
+ try:
+ asyncio.get_running_loop().create_task(self.close())
+ except Exception:
+ pass
+
+ @override
+ def _make_status_error(
+ self,
+ err_msg: str,
+ *,
+ body: object,
+ response: httpx.Response,
+ ) -> APIStatusError:
+ if response.status_code == 400:
+ return _exceptions.BadRequestError(err_msg, response=response, body=body)
+
+ if response.status_code == 401:
+ return _exceptions.AuthenticationError(err_msg, response=response, body=body)
+
+ if response.status_code == 403:
+ return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
+
+ if response.status_code == 404:
+ return _exceptions.NotFoundError(err_msg, response=response, body=body)
+
+ if response.status_code == 409:
+ return _exceptions.ConflictError(err_msg, response=response, body=body)
+
+ if response.status_code == 422:
+ return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
+
+ if response.status_code == 429:
+ return _exceptions.RateLimitError(err_msg, response=response, body=body)
+
+ if response.status_code >= 500:
+ return _exceptions.InternalServerError(err_msg, response=response, body=body)
+ return APIStatusError(err_msg, response=response, body=body)
+
+
+class OpenAIWithRawResponse:
+ def __init__(self, client: OpenAI) -> None:
+ self.completions = resources.CompletionsWithRawResponse(client.completions)
+ self.chat = resources.ChatWithRawResponse(client.chat)
+ self.edits = resources.EditsWithRawResponse(client.edits)
+ self.embeddings = resources.EmbeddingsWithRawResponse(client.embeddings)
+ self.files = resources.FilesWithRawResponse(client.files)
+ self.images = resources.ImagesWithRawResponse(client.images)
+ self.audio = resources.AudioWithRawResponse(client.audio)
+ self.moderations = resources.ModerationsWithRawResponse(client.moderations)
+ self.models = resources.ModelsWithRawResponse(client.models)
+ self.fine_tuning = resources.FineTuningWithRawResponse(client.fine_tuning)
+ self.fine_tunes = resources.FineTunesWithRawResponse(client.fine_tunes)
+
+
+class AsyncOpenAIWithRawResponse:
+ def __init__(self, client: AsyncOpenAI) -> None:
+ self.completions = resources.AsyncCompletionsWithRawResponse(client.completions)
+ self.chat = resources.AsyncChatWithRawResponse(client.chat)
+ self.edits = resources.AsyncEditsWithRawResponse(client.edits)
+ self.embeddings = resources.AsyncEmbeddingsWithRawResponse(client.embeddings)
+ self.files = resources.AsyncFilesWithRawResponse(client.files)
+ self.images = resources.AsyncImagesWithRawResponse(client.images)
+ self.audio = resources.AsyncAudioWithRawResponse(client.audio)
+ self.moderations = resources.AsyncModerationsWithRawResponse(client.moderations)
+ self.models = resources.AsyncModelsWithRawResponse(client.models)
+ self.fine_tuning = resources.AsyncFineTuningWithRawResponse(client.fine_tuning)
+ self.fine_tunes = resources.AsyncFineTunesWithRawResponse(client.fine_tunes)
+
+
+Client = OpenAI
+
+AsyncClient = AsyncOpenAI
diff --git a/src/openai/_compat.py b/src/openai/_compat.py
new file mode 100644
index 0000000000..34323c9b7e
--- /dev/null
+++ b/src/openai/_compat.py
@@ -0,0 +1,173 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Union, TypeVar, cast
+from datetime import date, datetime
+
+import pydantic
+from pydantic.fields import FieldInfo
+
+from ._types import StrBytesIntFloat
+
+_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
+
+# --------------- Pydantic v2 compatibility ---------------
+
+# Pyright incorrectly reports some of our functions as overriding a method when they don't
+# pyright: reportIncompatibleMethodOverride=false
+
+PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
+
+# v1 re-exports
+if TYPE_CHECKING:
+
+ def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
+ ...
+
+ def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
+ ...
+
+ def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
+ ...
+
+ def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001
+ ...
+
+ def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001
+ ...
+
+ def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001
+ ...
+
+ def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
+ ...
+
+else:
+ if PYDANTIC_V2:
+ from pydantic.v1.typing import get_args as get_args
+ from pydantic.v1.typing import is_union as is_union
+ from pydantic.v1.typing import get_origin as get_origin
+ from pydantic.v1.typing import is_typeddict as is_typeddict
+ from pydantic.v1.typing import is_literal_type as is_literal_type
+ from pydantic.v1.datetime_parse import parse_date as parse_date
+ from pydantic.v1.datetime_parse import parse_datetime as parse_datetime
+ else:
+ from pydantic.typing import get_args as get_args
+ from pydantic.typing import is_union as is_union
+ from pydantic.typing import get_origin as get_origin
+ from pydantic.typing import is_typeddict as is_typeddict
+ from pydantic.typing import is_literal_type as is_literal_type
+ from pydantic.datetime_parse import parse_date as parse_date
+ from pydantic.datetime_parse import parse_datetime as parse_datetime
+
+
+# refactored config
+if TYPE_CHECKING:
+ from pydantic import ConfigDict as ConfigDict
+else:
+ if PYDANTIC_V2:
+ from pydantic import ConfigDict
+ else:
+ # TODO: provide an error message here?
+ ConfigDict = None
+
+
+# renamed methods / properties
+def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
+ if PYDANTIC_V2:
+ return model.model_validate(value)
+ else:
+ return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
+
+
+def field_is_required(field: FieldInfo) -> bool:
+ if PYDANTIC_V2:
+ return field.is_required()
+ return field.required # type: ignore
+
+
+def field_get_default(field: FieldInfo) -> Any:
+ value = field.get_default()
+ if PYDANTIC_V2:
+ from pydantic_core import PydanticUndefined
+
+ if value == PydanticUndefined:
+ return None
+ return value
+ return value
+
+
+def field_outer_type(field: FieldInfo) -> Any:
+ if PYDANTIC_V2:
+ return field.annotation
+ return field.outer_type_ # type: ignore
+
+
+def get_model_config(model: type[pydantic.BaseModel]) -> Any:
+ if PYDANTIC_V2:
+ return model.model_config
+ return model.__config__ # type: ignore
+
+
+def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
+ if PYDANTIC_V2:
+ return model.model_fields
+ return model.__fields__ # type: ignore
+
+
+def model_copy(model: _ModelT) -> _ModelT:
+ if PYDANTIC_V2:
+ return model.model_copy()
+ return model.copy() # type: ignore
+
+
+def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
+ if PYDANTIC_V2:
+ return model.model_dump_json(indent=indent)
+ return model.json(indent=indent) # type: ignore
+
+
+def model_dump(
+ model: pydantic.BaseModel,
+ *,
+ exclude_unset: bool = False,
+ exclude_defaults: bool = False,
+) -> dict[str, Any]:
+ if PYDANTIC_V2:
+ return model.model_dump(
+ exclude_unset=exclude_unset,
+ exclude_defaults=exclude_defaults,
+ )
+ return cast(
+ "dict[str, Any]",
+ model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
+ exclude_unset=exclude_unset,
+ exclude_defaults=exclude_defaults,
+ ),
+ )
+
+
+def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
+ if PYDANTIC_V2:
+ return model.model_validate(data)
+ return model.parse_obj(data) # pyright: ignore[reportDeprecated]
+
+
+# generic models
+if TYPE_CHECKING:
+
+ class GenericModel(pydantic.BaseModel):
+ ...
+
+else:
+ if PYDANTIC_V2:
+ # there no longer needs to be a distinction in v2 but
+ # we still have to create our own subclass to avoid
+ # inconsistent MRO ordering errors
+ class GenericModel(pydantic.BaseModel):
+ ...
+
+ else:
+ import pydantic.generics
+
+ class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel):
+ ...
diff --git a/src/openai/_constants.py b/src/openai/_constants.py
new file mode 100644
index 0000000000..2e402300d3
--- /dev/null
+++ b/src/openai/_constants.py
@@ -0,0 +1,10 @@
+# File generated from our OpenAPI spec by Stainless.
+
+import httpx
+
+RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
+
+# default timeout is 10 minutes
+DEFAULT_TIMEOUT = httpx.Timeout(timeout=600.0, connect=5.0)
+DEFAULT_MAX_RETRIES = 2
+DEFAULT_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20)
diff --git a/src/openai/_exceptions.py b/src/openai/_exceptions.py
new file mode 100644
index 0000000000..b79ac5fd64
--- /dev/null
+++ b/src/openai/_exceptions.py
@@ -0,0 +1,123 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Any, Optional, cast
+from typing_extensions import Literal
+
+import httpx
+
+from ._utils import is_dict
+
+__all__ = [
+ "BadRequestError",
+ "AuthenticationError",
+ "PermissionDeniedError",
+ "NotFoundError",
+ "ConflictError",
+ "UnprocessableEntityError",
+ "RateLimitError",
+ "InternalServerError",
+]
+
+
+class OpenAIError(Exception):
+ pass
+
+
+class APIError(OpenAIError):
+ message: str
+ request: httpx.Request
+
+ body: object | None
+ """The API response body.
+
+ If the API responded with a valid JSON structure then this property will be the
+ decoded result.
+
+ If it isn't a valid JSON structure then this will be the raw response.
+
+ If there was no response associated with this error then it will be `None`.
+ """
+
+ code: Optional[str]
+ param: Optional[str]
+ type: Optional[str]
+
+ def __init__(self, message: str, request: httpx.Request, *, body: object | None) -> None:
+ super().__init__(message)
+ self.request = request
+ self.message = message
+
+ if is_dict(body):
+ self.code = cast(Any, body.get("code"))
+ self.param = cast(Any, body.get("param"))
+ self.type = cast(Any, body.get("type"))
+ else:
+ self.code = None
+ self.param = None
+ self.type = None
+
+
+class APIResponseValidationError(APIError):
+ response: httpx.Response
+ status_code: int
+
+ def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None:
+ super().__init__(message or "Data returned by API invalid for expected schema.", response.request, body=body)
+ self.response = response
+ self.status_code = response.status_code
+
+
+class APIStatusError(APIError):
+ """Raised when an API response has a status code of 4xx or 5xx."""
+
+ response: httpx.Response
+ status_code: int
+
+ def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
+ super().__init__(message, response.request, body=body)
+ self.response = response
+ self.status_code = response.status_code
+
+
+class APIConnectionError(APIError):
+ def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
+ super().__init__(message, request, body=None)
+
+
+class APITimeoutError(APIConnectionError):
+ def __init__(self, request: httpx.Request) -> None:
+ super().__init__(message="Request timed out.", request=request)
+
+
+class BadRequestError(APIStatusError):
+ status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class AuthenticationError(APIStatusError):
+ status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class PermissionDeniedError(APIStatusError):
+ status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class NotFoundError(APIStatusError):
+ status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class ConflictError(APIStatusError):
+ status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class UnprocessableEntityError(APIStatusError):
+ status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class RateLimitError(APIStatusError):
+ status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class InternalServerError(APIStatusError):
+ pass
diff --git a/src/openai/_extras/__init__.py b/src/openai/_extras/__init__.py
new file mode 100644
index 0000000000..dc6625c5dc
--- /dev/null
+++ b/src/openai/_extras/__init__.py
@@ -0,0 +1,3 @@
+from .numpy_proxy import numpy as numpy
+from .numpy_proxy import has_numpy as has_numpy
+from .pandas_proxy import pandas as pandas
diff --git a/src/openai/_extras/_common.py b/src/openai/_extras/_common.py
new file mode 100644
index 0000000000..6e71720e64
--- /dev/null
+++ b/src/openai/_extras/_common.py
@@ -0,0 +1,21 @@
+from .._exceptions import OpenAIError
+
+INSTRUCTIONS = """
+
+OpenAI error:
+
+ missing `{library}`
+
+This feature requires additional dependencies:
+
+ $ pip install openai[{extra}]
+
+"""
+
+
+def format_instructions(*, library: str, extra: str) -> str:
+ return INSTRUCTIONS.format(library=library, extra=extra)
+
+
+class MissingDependencyError(OpenAIError):
+ pass
diff --git a/src/openai/_extras/numpy_proxy.py b/src/openai/_extras/numpy_proxy.py
new file mode 100644
index 0000000000..408eaebd3b
--- /dev/null
+++ b/src/openai/_extras/numpy_proxy.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+from typing_extensions import ClassVar, override
+
+from .._utils import LazyProxy
+from ._common import MissingDependencyError, format_instructions
+
+if TYPE_CHECKING:
+ import numpy as numpy
+
+
+NUMPY_INSTRUCTIONS = format_instructions(library="numpy", extra="datalib")
+
+
+class NumpyProxy(LazyProxy[Any]):
+ should_cache: ClassVar[bool] = True
+
+ @override
+ def __load__(self) -> Any:
+ try:
+ import numpy
+ except ImportError:
+ raise MissingDependencyError(NUMPY_INSTRUCTIONS)
+
+ return numpy
+
+
+if not TYPE_CHECKING:
+ numpy = NumpyProxy()
+
+
+def has_numpy() -> bool:
+ try:
+ import numpy # noqa: F401 # pyright: ignore[reportUnusedImport]
+ except ImportError:
+ return False
+
+ return True
diff --git a/src/openai/_extras/pandas_proxy.py b/src/openai/_extras/pandas_proxy.py
new file mode 100644
index 0000000000..2fc0d2a7eb
--- /dev/null
+++ b/src/openai/_extras/pandas_proxy.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+from typing_extensions import ClassVar, override
+
+from .._utils import LazyProxy
+from ._common import MissingDependencyError, format_instructions
+
+if TYPE_CHECKING:
+ import pandas as pandas
+
+
+PANDAS_INSTRUCTIONS = format_instructions(library="pandas", extra="datalib")
+
+
+class PandasProxy(LazyProxy[Any]):
+ should_cache: ClassVar[bool] = True
+
+ @override
+ def __load__(self) -> Any:
+ try:
+ import pandas
+ except ImportError:
+ raise MissingDependencyError(PANDAS_INSTRUCTIONS)
+
+ return pandas
+
+
+if not TYPE_CHECKING:
+ pandas = PandasProxy()
diff --git a/src/openai/_files.py b/src/openai/_files.py
new file mode 100644
index 0000000000..49e3536243
--- /dev/null
+++ b/src/openai/_files.py
@@ -0,0 +1,122 @@
+from __future__ import annotations
+
+import io
+import os
+import pathlib
+from typing import overload
+from typing_extensions import TypeGuard
+
+import anyio
+
+from ._types import (
+ FileTypes,
+ FileContent,
+ RequestFiles,
+ HttpxFileTypes,
+ HttpxFileContent,
+ HttpxRequestFiles,
+)
+from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
+
+
+def is_file_content(obj: object) -> TypeGuard[FileContent]:
+ return (
+ isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
+ )
+
+
+def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
+ if not is_file_content(obj):
+ prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
+ raise RuntimeError(
+ f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/v1#file-uploads"
+ ) from None
+
+
+@overload
+def to_httpx_files(files: None) -> None:
+ ...
+
+
+@overload
+def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:
+ ...
+
+
+def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
+ if files is None:
+ return None
+
+ if is_mapping_t(files):
+ files = {key: _transform_file(file) for key, file in files.items()}
+ elif is_sequence_t(files):
+ files = [(key, _transform_file(file)) for key, file in files]
+ else:
+ raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
+
+ return files
+
+
+def _transform_file(file: FileTypes) -> HttpxFileTypes:
+ if is_file_content(file):
+ if isinstance(file, os.PathLike):
+ path = pathlib.Path(file)
+ return (path.name, path.read_bytes())
+
+ return file
+
+ if is_tuple_t(file):
+ return (file[0], _read_file_content(file[1]), *file[2:])
+
+ raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
+
+
+def _read_file_content(file: FileContent) -> HttpxFileContent:
+ if isinstance(file, os.PathLike):
+ return pathlib.Path(file).read_bytes()
+ return file
+
+
+@overload
+async def async_to_httpx_files(files: None) -> None:
+ ...
+
+
+@overload
+async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:
+ ...
+
+
+async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
+ if files is None:
+ return None
+
+ if is_mapping_t(files):
+ files = {key: await _async_transform_file(file) for key, file in files.items()}
+ elif is_sequence_t(files):
+ files = [(key, await _async_transform_file(file)) for key, file in files]
+ else:
+ raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
+
+ return files
+
+
+async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
+ if is_file_content(file):
+ if isinstance(file, os.PathLike):
+ path = anyio.Path(file)
+ return (path.name, await path.read_bytes())
+
+ return file
+
+ if is_tuple_t(file):
+ return (file[0], await _async_read_file_content(file[1]), *file[2:])
+
+ raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
+
+
+async def _async_read_file_content(file: FileContent) -> HttpxFileContent:
+ if isinstance(file, os.PathLike):
+ return await anyio.Path(file).read_bytes()
+
+ return file
diff --git a/src/openai/_models.py b/src/openai/_models.py
new file mode 100644
index 0000000000..00d787ca87
--- /dev/null
+++ b/src/openai/_models.py
@@ -0,0 +1,460 @@
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast
+from datetime import date, datetime
+from typing_extensions import (
+ Unpack,
+ Literal,
+ ClassVar,
+ Protocol,
+ Required,
+ TypedDict,
+ final,
+ override,
+ runtime_checkable,
+)
+
+import pydantic
+import pydantic.generics
+from pydantic.fields import FieldInfo
+
+from ._types import (
+ Body,
+ IncEx,
+ Query,
+ ModelT,
+ Headers,
+ Timeout,
+ NotGiven,
+ AnyMapping,
+ HttpxRequestFiles,
+)
+from ._utils import (
+ is_list,
+ is_given,
+ is_mapping,
+ parse_date,
+ parse_datetime,
+ strip_not_given,
+)
+from ._compat import PYDANTIC_V2, ConfigDict
+from ._compat import GenericModel as BaseGenericModel
+from ._compat import (
+ get_args,
+ is_union,
+ parse_obj,
+ get_origin,
+ is_literal_type,
+ get_model_config,
+ get_model_fields,
+ field_get_default,
+)
+from ._constants import RAW_RESPONSE_HEADER
+
+__all__ = ["BaseModel", "GenericModel"]
+
+_T = TypeVar("_T")
+
+
+@runtime_checkable
+class _ConfigProtocol(Protocol):
+ allow_population_by_field_name: bool
+
+
+class BaseModel(pydantic.BaseModel):
+ if PYDANTIC_V2:
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")
+ else:
+
+ @property
+ @override
+ def model_fields_set(self) -> set[str]:
+ # a forwards-compat shim for pydantic v2
+ return self.__fields_set__ # type: ignore
+
+ class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
+ extra: Any = pydantic.Extra.allow # type: ignore
+
+ @override
+ def __str__(self) -> str:
+ # mypy complains about an invalid self arg
+ return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc]
+
+ # Override the 'construct' method in a way that supports recursive parsing without validation.
+ # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
+ @classmethod
+ @override
+ def construct(
+ cls: Type[ModelT],
+ _fields_set: set[str] | None = None,
+ **values: object,
+ ) -> ModelT:
+ m = cls.__new__(cls)
+ fields_values: dict[str, object] = {}
+
+ config = get_model_config(cls)
+ populate_by_name = (
+ config.allow_population_by_field_name
+ if isinstance(config, _ConfigProtocol)
+ else config.get("populate_by_name")
+ )
+
+ if _fields_set is None:
+ _fields_set = set()
+
+ model_fields = get_model_fields(cls)
+ for name, field in model_fields.items():
+ key = field.alias
+ if key is None or (key not in values and populate_by_name):
+ key = name
+
+ if key in values:
+ fields_values[name] = _construct_field(value=values[key], field=field, key=key)
+ _fields_set.add(name)
+ else:
+ fields_values[name] = field_get_default(field)
+
+ _extra = {}
+ for key, value in values.items():
+ if key not in model_fields:
+ if PYDANTIC_V2:
+ _extra[key] = value
+ else:
+ fields_values[key] = value
+
+ object.__setattr__(m, "__dict__", fields_values)
+
+ if PYDANTIC_V2:
+ # these properties are copied from Pydantic's `model_construct()` method
+ object.__setattr__(m, "__pydantic_private__", None)
+ object.__setattr__(m, "__pydantic_extra__", _extra)
+ object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
+ else:
+ # init_private_attributes() does not exist in v2
+ m._init_private_attributes() # type: ignore
+
+ # copied from Pydantic v1's `construct()` method
+ object.__setattr__(m, "__fields_set__", _fields_set)
+
+ return m
+
+ if not TYPE_CHECKING:
+ # type checkers incorrectly complain about this assignment
+ # because the type signatures are technically different
+ # although not in practice
+ model_construct = construct
+
+ if not PYDANTIC_V2:
+ # we define aliases for some of the new pydantic v2 methods so
+ # that we can just document these methods without having to specify
+ # a specifc pydantic version as some users may not know which
+ # pydantic version they are currently using
+
+ @override
+ def model_dump(
+ self,
+ *,
+ mode: Literal["json", "python"] | str = "python",
+ include: IncEx = None,
+ exclude: IncEx = None,
+ by_alias: bool = False,
+ exclude_unset: bool = False,
+ exclude_defaults: bool = False,
+ exclude_none: bool = False,
+ round_trip: bool = False,
+ warnings: bool = True,
+ ) -> dict[str, Any]:
+ """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
+
+ Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
+
+ Args:
+ mode: The mode in which `to_python` should run.
+ If mode is 'json', the dictionary will only contain JSON serializable types.
+ If mode is 'python', the dictionary may contain any Python objects.
+ include: A list of fields to include in the output.
+ exclude: A list of fields to exclude from the output.
+ by_alias: Whether to use the field's alias in the dictionary key if defined.
+ exclude_unset: Whether to exclude fields that are unset or None from the output.
+ exclude_defaults: Whether to exclude fields that are set to their default value from the output.
+ exclude_none: Whether to exclude fields that have a value of `None` from the output.
+ round_trip: Whether to enable serialization and deserialization round-trip support.
+ warnings: Whether to log warnings when invalid fields are encountered.
+
+ Returns:
+ A dictionary representation of the model.
+ """
+ if mode != "python":
+ raise ValueError("mode is only supported in Pydantic v2")
+ if round_trip != False:
+ raise ValueError("round_trip is only supported in Pydantic v2")
+ if warnings != True:
+ raise ValueError("warnings is only supported in Pydantic v2")
+ return super().dict( # pyright: ignore[reportDeprecated]
+ include=include,
+ exclude=exclude,
+ by_alias=by_alias,
+ exclude_unset=exclude_unset,
+ exclude_defaults=exclude_defaults,
+ exclude_none=exclude_none,
+ )
+
+ @override
+ def model_dump_json(
+ self,
+ *,
+ indent: int | None = None,
+ include: IncEx = None,
+ exclude: IncEx = None,
+ by_alias: bool = False,
+ exclude_unset: bool = False,
+ exclude_defaults: bool = False,
+ exclude_none: bool = False,
+ round_trip: bool = False,
+ warnings: bool = True,
+ ) -> str:
+ """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
+
+ Generates a JSON representation of the model using Pydantic's `to_json` method.
+
+ Args:
+ indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
+ include: Field(s) to include in the JSON output. Can take either a string or set of strings.
+ exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
+ by_alias: Whether to serialize using field aliases.
+ exclude_unset: Whether to exclude fields that have not been explicitly set.
+ exclude_defaults: Whether to exclude fields that have the default value.
+ exclude_none: Whether to exclude fields that have a value of `None`.
+ round_trip: Whether to use serialization/deserialization between JSON and class instance.
+ warnings: Whether to show any warnings that occurred during serialization.
+
+ Returns:
+ A JSON string representation of the model.
+ """
+ if round_trip != False:
+ raise ValueError("round_trip is only supported in Pydantic v2")
+ if warnings != True:
+ raise ValueError("warnings is only supported in Pydantic v2")
+ return super().json( # type: ignore[reportDeprecated]
+ indent=indent,
+ include=include,
+ exclude=exclude,
+ by_alias=by_alias,
+ exclude_unset=exclude_unset,
+ exclude_defaults=exclude_defaults,
+ exclude_none=exclude_none,
+ )
+
+
+def _construct_field(value: object, field: FieldInfo, key: str) -> object:
+ if value is None:
+ return field_get_default(field)
+
+ if PYDANTIC_V2:
+ type_ = field.annotation
+ else:
+ type_ = cast(type, field.outer_type_) # type: ignore
+
+ if type_ is None:
+ raise RuntimeError(f"Unexpected field type is None for {key}")
+
+ return construct_type(value=value, type_=type_)
+
+
+def construct_type(*, value: object, type_: type) -> object:
+ """Loose coercion to the expected type with construction of nested values.
+
+ If the given value does not match the expected type then it is returned as-is.
+ """
+
+ # we need to use the origin class for any types that are subscripted generics
+ # e.g. Dict[str, object]
+ origin = get_origin(type_) or type_
+ args = get_args(type_)
+
+ if is_union(origin):
+ try:
+ return validate_type(type_=type_, value=value)
+ except Exception:
+ pass
+
+ # if the data is not valid, use the first variant that doesn't fail while deserializing
+ for variant in args:
+ try:
+ return construct_type(value=value, type_=variant)
+ except Exception:
+ continue
+
+ raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
+
+ if origin == dict:
+ if not is_mapping(value):
+ return value
+
+ _, items_type = get_args(type_) # Dict[_, items_type]
+ return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
+
+ if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
+ if is_list(value):
+ return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
+
+ if is_mapping(value):
+ if issubclass(type_, BaseModel):
+ return type_.construct(**value) # type: ignore[arg-type]
+
+ return cast(Any, type_).construct(**value)
+
+ if origin == list:
+ if not is_list(value):
+ return value
+
+ inner_type = args[0] # List[inner_type]
+ return [construct_type(value=entry, type_=inner_type) for entry in value]
+
+ if origin == float:
+ if isinstance(value, int):
+ coerced = float(value)
+ if coerced != value:
+ return value
+ return coerced
+
+ return value
+
+ if type_ == datetime:
+ try:
+ return parse_datetime(value) # type: ignore
+ except Exception:
+ return value
+
+ if type_ == date:
+ try:
+ return parse_date(value) # type: ignore
+ except Exception:
+ return value
+
+ return value
+
+
+def validate_type(*, type_: type[_T], value: object) -> _T:
+ """Strict validation that the given value matches the expected type"""
+ if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
+ return cast(_T, parse_obj(type_, value))
+
+ return cast(_T, _validate_non_model_type(type_=type_, value=value))
+
+
+# our use of subclasssing here causes weirdness for type checkers,
+# so we just pretend that we don't subclass
+if TYPE_CHECKING:
+ GenericModel = BaseModel
+else:
+
+ class GenericModel(BaseGenericModel, BaseModel):
+ pass
+
+
+if PYDANTIC_V2:
+ from pydantic import TypeAdapter
+
+ def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
+ return TypeAdapter(type_).validate_python(value)
+
+elif not TYPE_CHECKING: # TODO: condition is weird
+
+ class RootModel(GenericModel, Generic[_T]):
+ """Used as a placeholder to easily convert runtime types to a Pydantic format
+ to provide validation.
+
+ For example:
+ ```py
+ validated = RootModel[int](__root__='5').__root__
+ # validated: 5
+ ```
+ """
+
+ __root__: _T
+
+ def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
+ model = _create_pydantic_model(type_).validate(value)
+ return cast(_T, model.__root__)
+
+ def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
+ return RootModel[type_] # type: ignore
+
+
+class FinalRequestOptionsInput(TypedDict, total=False):
+ method: Required[str]
+ url: Required[str]
+ params: Query
+ headers: Headers
+ max_retries: int
+ timeout: float | Timeout | None
+ files: HttpxRequestFiles | None
+ idempotency_key: str
+ json_data: Body
+ extra_json: AnyMapping
+
+
+@final
+class FinalRequestOptions(pydantic.BaseModel):
+ method: str
+ url: str
+ params: Query = {}
+ headers: Union[Headers, NotGiven] = NotGiven()
+ max_retries: Union[int, NotGiven] = NotGiven()
+ timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
+ files: Union[HttpxRequestFiles, None] = None
+ idempotency_key: Union[str, None] = None
+ post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
+
+ # It should be noted that we cannot use `json` here as that would override
+ # a BaseModel method in an incompatible fashion.
+ json_data: Union[Body, None] = None
+ extra_json: Union[AnyMapping, None] = None
+
+ if PYDANTIC_V2:
+ model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
+ else:
+
+ class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
+ arbitrary_types_allowed: bool = True
+
+ def get_max_retries(self, max_retries: int) -> int:
+ if isinstance(self.max_retries, NotGiven):
+ return max_retries
+ return self.max_retries
+
+ def _strip_raw_response_header(self) -> None:
+ if not is_given(self.headers):
+ return
+
+ if self.headers.get(RAW_RESPONSE_HEADER):
+ self.headers = {**self.headers}
+ self.headers.pop(RAW_RESPONSE_HEADER)
+
+ # override the `construct` method so that we can run custom transformations.
+ # this is necessary as we don't want to do any actual runtime type checking
+ # (which means we can't use validators) but we do want to ensure that `NotGiven`
+ # values are not present
+ #
+ # type ignore required because we're adding explicit types to `**values`
+ @classmethod
+ def construct( # type: ignore
+ cls,
+ _fields_set: set[str] | None = None,
+ **values: Unpack[FinalRequestOptionsInput],
+ ) -> FinalRequestOptions:
+ kwargs: dict[str, Any] = {
+ # we unconditionally call `strip_not_given` on any value
+ # as it will just ignore any non-mapping types
+ key: strip_not_given(value)
+ for key, value in values.items()
+ }
+ if PYDANTIC_V2:
+ return super().model_construct(_fields_set, **kwargs)
+ return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
+
+ if not TYPE_CHECKING:
+ # type checkers incorrectly complain about this assignment
+ model_construct = construct
diff --git a/src/openai/_module_client.py b/src/openai/_module_client.py
new file mode 100644
index 0000000000..ca80468e88
--- /dev/null
+++ b/src/openai/_module_client.py
@@ -0,0 +1,85 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing_extensions import override
+
+from . import resources, _load_client
+from ._utils import LazyProxy
+
+
+class ChatProxy(LazyProxy[resources.Chat]):
+ @override
+ def __load__(self) -> resources.Chat:
+ return _load_client().chat
+
+
+class EditsProxy(LazyProxy[resources.Edits]):
+ @override
+ def __load__(self) -> resources.Edits:
+ return _load_client().edits
+
+
+class FilesProxy(LazyProxy[resources.Files]):
+ @override
+ def __load__(self) -> resources.Files:
+ return _load_client().files
+
+
+class AudioProxy(LazyProxy[resources.Audio]):
+ @override
+ def __load__(self) -> resources.Audio:
+ return _load_client().audio
+
+
+class ImagesProxy(LazyProxy[resources.Images]):
+ @override
+ def __load__(self) -> resources.Images:
+ return _load_client().images
+
+
+class ModelsProxy(LazyProxy[resources.Models]):
+ @override
+ def __load__(self) -> resources.Models:
+ return _load_client().models
+
+
+class EmbeddingsProxy(LazyProxy[resources.Embeddings]):
+ @override
+ def __load__(self) -> resources.Embeddings:
+ return _load_client().embeddings
+
+
+class FineTunesProxy(LazyProxy[resources.FineTunes]):
+ @override
+ def __load__(self) -> resources.FineTunes:
+ return _load_client().fine_tunes
+
+
+class CompletionsProxy(LazyProxy[resources.Completions]):
+ @override
+ def __load__(self) -> resources.Completions:
+ return _load_client().completions
+
+
+class ModerationsProxy(LazyProxy[resources.Moderations]):
+ @override
+ def __load__(self) -> resources.Moderations:
+ return _load_client().moderations
+
+
+class FineTuningProxy(LazyProxy[resources.FineTuning]):
+ @override
+ def __load__(self) -> resources.FineTuning:
+ return _load_client().fine_tuning
+
+
+chat: resources.Chat = ChatProxy().__as_proxied__()
+edits: resources.Edits = EditsProxy().__as_proxied__()
+files: resources.Files = FilesProxy().__as_proxied__()
+audio: resources.Audio = AudioProxy().__as_proxied__()
+images: resources.Images = ImagesProxy().__as_proxied__()
+models: resources.Models = ModelsProxy().__as_proxied__()
+embeddings: resources.Embeddings = EmbeddingsProxy().__as_proxied__()
+fine_tunes: resources.FineTunes = FineTunesProxy().__as_proxied__()
+completions: resources.Completions = CompletionsProxy().__as_proxied__()
+moderations: resources.Moderations = ModerationsProxy().__as_proxied__()
+fine_tuning: resources.FineTuning = FineTuningProxy().__as_proxied__()
diff --git a/src/openai/_qs.py b/src/openai/_qs.py
new file mode 100644
index 0000000000..274320ca5e
--- /dev/null
+++ b/src/openai/_qs.py
@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+from typing import Any, List, Tuple, Union, Mapping, TypeVar
+from urllib.parse import parse_qs, urlencode
+from typing_extensions import Literal, get_args
+
+from ._types import NOT_GIVEN, NotGiven, NotGivenOr
+from ._utils import flatten
+
+_T = TypeVar("_T")
+
+
+ArrayFormat = Literal["comma", "repeat", "indices", "brackets"]
+NestedFormat = Literal["dots", "brackets"]
+
+PrimitiveData = Union[str, int, float, bool, None]
+# this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"]
+# https://github.com/microsoft/pyright/issues/3555
+Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
+Params = Mapping[str, Data]
+
+
+class Querystring:
+ array_format: ArrayFormat
+ nested_format: NestedFormat
+
+ def __init__(
+ self,
+ *,
+ array_format: ArrayFormat = "repeat",
+ nested_format: NestedFormat = "brackets",
+ ) -> None:
+ self.array_format = array_format
+ self.nested_format = nested_format
+
+ def parse(self, query: str) -> Mapping[str, object]:
+ # Note: custom format syntax is not supported yet
+ return parse_qs(query)
+
+ def stringify(
+ self,
+ params: Params,
+ *,
+ array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
+ nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
+ ) -> str:
+ return urlencode(
+ self.stringify_items(
+ params,
+ array_format=array_format,
+ nested_format=nested_format,
+ )
+ )
+
+ def stringify_items(
+ self,
+ params: Params,
+ *,
+ array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
+ nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
+ ) -> list[tuple[str, str]]:
+ opts = Options(
+ qs=self,
+ array_format=array_format,
+ nested_format=nested_format,
+ )
+ return flatten([self._stringify_item(key, value, opts) for key, value in params.items()])
+
+ def _stringify_item(
+ self,
+ key: str,
+ value: Data,
+ opts: Options,
+ ) -> list[tuple[str, str]]:
+ if isinstance(value, Mapping):
+ items: list[tuple[str, str]] = []
+ nested_format = opts.nested_format
+ for subkey, subvalue in value.items():
+ items.extend(
+ self._stringify_item(
+ # TODO: error if unknown format
+ f"{key}.{subkey}" if nested_format == "dots" else f"{key}[{subkey}]",
+ subvalue,
+ opts,
+ )
+ )
+ return items
+
+ if isinstance(value, (list, tuple)):
+ array_format = opts.array_format
+ if array_format == "comma":
+ return [
+ (
+ key,
+ ",".join(self._primitive_value_to_str(item) for item in value if item is not None),
+ ),
+ ]
+ elif array_format == "repeat":
+ items = []
+ for item in value:
+ items.extend(self._stringify_item(key, item, opts))
+ return items
+ elif array_format == "indices":
+ raise NotImplementedError("The array indices format is not supported yet")
+ elif array_format == "brackets":
+ items = []
+ key = key + "[]"
+ for item in value:
+ items.extend(self._stringify_item(key, item, opts))
+ return items
+ else:
+ raise NotImplementedError(
+ f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}"
+ )
+
+ serialised = self._primitive_value_to_str(value)
+ if not serialised:
+ return []
+ return [(key, serialised)]
+
+ def _primitive_value_to_str(self, value: PrimitiveData) -> str:
+ # copied from httpx
+ if value is True:
+ return "true"
+ elif value is False:
+ return "false"
+ elif value is None:
+ return ""
+ return str(value)
+
+
+_qs = Querystring()
+parse = _qs.parse
+stringify = _qs.stringify
+stringify_items = _qs.stringify_items
+
+
+class Options:
+ array_format: ArrayFormat
+ nested_format: NestedFormat
+
+ def __init__(
+ self,
+ qs: Querystring = _qs,
+ *,
+ array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
+ nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
+ ) -> None:
+ self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format
+ self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format
diff --git a/src/openai/_resource.py b/src/openai/_resource.py
new file mode 100644
index 0000000000..db1b0fa45a
--- /dev/null
+++ b/src/openai/_resource.py
@@ -0,0 +1,42 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import time
+import asyncio
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from ._client import OpenAI, AsyncOpenAI
+
+
+class SyncAPIResource:
+ _client: OpenAI
+
+ def __init__(self, client: OpenAI) -> None:
+ self._client = client
+ self._get = client.get
+ self._post = client.post
+ self._patch = client.patch
+ self._put = client.put
+ self._delete = client.delete
+ self._get_api_list = client.get_api_list
+
+ def _sleep(self, seconds: float) -> None:
+ time.sleep(seconds)
+
+
+class AsyncAPIResource:
+ _client: AsyncOpenAI
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ self._client = client
+ self._get = client.get
+ self._post = client.post
+ self._patch = client.patch
+ self._put = client.put
+ self._delete = client.delete
+ self._get_api_list = client.get_api_list
+
+ async def _sleep(self, seconds: float) -> None:
+ await asyncio.sleep(seconds)
diff --git a/src/openai/_response.py b/src/openai/_response.py
new file mode 100644
index 0000000000..3cc8fd8cc1
--- /dev/null
+++ b/src/openai/_response.py
@@ -0,0 +1,252 @@
+from __future__ import annotations
+
+import inspect
+import datetime
+import functools
+from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast
+from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin
+
+import httpx
+import pydantic
+
+from ._types import NoneType, UnknownResponse, BinaryResponseContent
+from ._utils import is_given
+from ._models import BaseModel
+from ._constants import RAW_RESPONSE_HEADER
+from ._exceptions import APIResponseValidationError
+
+if TYPE_CHECKING:
+ from ._models import FinalRequestOptions
+ from ._base_client import Stream, BaseClient, AsyncStream
+
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class APIResponse(Generic[R]):
+ _cast_to: type[R]
+ _client: BaseClient[Any, Any]
+ _parsed: R | None
+ _stream: bool
+ _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
+ _options: FinalRequestOptions
+
+ http_response: httpx.Response
+
+ def __init__(
+ self,
+ *,
+ raw: httpx.Response,
+ cast_to: type[R],
+ client: BaseClient[Any, Any],
+ stream: bool,
+ stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
+ options: FinalRequestOptions,
+ ) -> None:
+ self._cast_to = cast_to
+ self._client = client
+ self._parsed = None
+ self._stream = stream
+ self._stream_cls = stream_cls
+ self._options = options
+ self.http_response = raw
+
+ def parse(self) -> R:
+ if self._parsed is not None:
+ return self._parsed
+
+ parsed = self._parse()
+ if is_given(self._options.post_parser):
+ parsed = self._options.post_parser(parsed)
+
+ self._parsed = parsed
+ return parsed
+
+ @property
+ def headers(self) -> httpx.Headers:
+ return self.http_response.headers
+
+ @property
+ def http_request(self) -> httpx.Request:
+ return self.http_response.request
+
+ @property
+ def status_code(self) -> int:
+ return self.http_response.status_code
+
+ @property
+ def url(self) -> httpx.URL:
+ return self.http_response.url
+
+ @property
+ def method(self) -> str:
+ return self.http_request.method
+
+ @property
+ def content(self) -> bytes:
+ return self.http_response.content
+
+ @property
+ def text(self) -> str:
+ return self.http_response.text
+
+ @property
+ def http_version(self) -> str:
+ return self.http_response.http_version
+
+ @property
+ def elapsed(self) -> datetime.timedelta:
+ """The time taken for the complete request/response cycle to complete."""
+ return self.http_response.elapsed
+
+ def _parse(self) -> R:
+ if self._stream:
+ if self._stream_cls:
+ return cast(
+ R,
+ self._stream_cls(
+ cast_to=_extract_stream_chunk_type(self._stream_cls),
+ response=self.http_response,
+ client=cast(Any, self._client),
+ ),
+ )
+
+ stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
+ if stream_cls is None:
+ raise MissingStreamClassError()
+
+ return cast(
+ R,
+ stream_cls(
+ cast_to=self._cast_to,
+ response=self.http_response,
+ client=cast(Any, self._client),
+ ),
+ )
+
+ cast_to = self._cast_to
+ if cast_to is NoneType:
+ return cast(R, None)
+
+ response = self.http_response
+ if cast_to == str:
+ return cast(R, response.text)
+
+ origin = get_origin(cast_to) or cast_to
+
+ if inspect.isclass(origin) and issubclass(origin, BinaryResponseContent):
+ return cast(R, cast_to(response)) # type: ignore
+
+ if origin == APIResponse:
+ raise RuntimeError("Unexpected state - cast_to is `APIResponse`")
+
+ if inspect.isclass(origin) and issubclass(origin, httpx.Response):
+ # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
+ # and pass that class to our request functions. We cannot change the variance to be either
+ # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
+ # the response class ourselves but that is something that should be supported directly in httpx
+ # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
+ if cast_to != httpx.Response:
+ raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
+ return cast(R, response)
+
+ # The check here is necessary as we are subverting the the type system
+ # with casts as the relationship between TypeVars and Types are very strict
+ # which means we must return *exactly* what was input or transform it in a
+ # way that retains the TypeVar state. As we cannot do that in this function
+ # then we have to resort to using `cast`. At the time of writing, we know this
+ # to be safe as we have handled all the types that could be bound to the
+ # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
+ # this function would become unsafe but a type checker would not report an error.
+ if (
+ cast_to is not UnknownResponse
+ and not origin is list
+ and not origin is dict
+ and not origin is Union
+ and not issubclass(origin, BaseModel)
+ ):
+ raise RuntimeError(
+ f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
+ )
+
+ # split is required to handle cases where additional information is included
+ # in the response, e.g. application/json; charset=utf-8
+ content_type, *_ = response.headers.get("content-type").split(";")
+ if content_type != "application/json":
+ if self._client._strict_response_validation:
+ raise APIResponseValidationError(
+ response=response,
+ message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
+ body=response.text,
+ )
+
+ # If the API responds with content that isn't JSON then we just return
+ # the (decoded) text without performing any parsing so that you can still
+ # handle the response however you need to.
+ return response.text # type: ignore
+
+ data = response.json()
+
+ try:
+ return self._client._process_response_data(
+ data=data,
+ cast_to=cast_to, # type: ignore
+ response=response,
+ )
+ except pydantic.ValidationError as err:
+ raise APIResponseValidationError(response=response, body=data) from err
+
+ @override
+ def __repr__(self) -> str:
+ return f""
+
+
+class MissingStreamClassError(TypeError):
+ def __init__(self) -> None:
+ super().__init__(
+ "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference",
+ )
+
+
+def _extract_stream_chunk_type(stream_cls: type) -> type:
+ args = get_args(stream_cls)
+ if not args:
+ raise TypeError(
+ f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
+ )
+ return cast(type, args[0])
+
+
+def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
+ """Higher order function that takes one of our bound API methods and wraps it
+ to support returning the raw `APIResponse` object directly.
+ """
+
+ @functools.wraps(func)
+ def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
+ extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
+ extra_headers[RAW_RESPONSE_HEADER] = "true"
+
+ kwargs["extra_headers"] = extra_headers
+
+ return cast(APIResponse[R], func(*args, **kwargs))
+
+ return wrapped
+
+
+def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[APIResponse[R]]]:
+ """Higher order function that takes one of our bound API methods and wraps it
+ to support returning the raw `APIResponse` object directly.
+ """
+
+ @functools.wraps(func)
+ async def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
+ extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})}
+ extra_headers[RAW_RESPONSE_HEADER] = "true"
+
+ kwargs["extra_headers"] = extra_headers
+
+ return cast(APIResponse[R], await func(*args, **kwargs))
+
+ return wrapped
diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py
new file mode 100644
index 0000000000..cee737f4f5
--- /dev/null
+++ b/src/openai/_streaming.py
@@ -0,0 +1,232 @@
+# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any, Generic, Iterator, AsyncIterator
+from typing_extensions import override
+
+import httpx
+
+from ._types import ResponseT
+from ._utils import is_mapping
+from ._exceptions import APIError
+
+if TYPE_CHECKING:
+ from ._base_client import SyncAPIClient, AsyncAPIClient
+
+
+class Stream(Generic[ResponseT]):
+ """Provides the core interface to iterate over a synchronous stream response."""
+
+ response: httpx.Response
+
+ def __init__(
+ self,
+ *,
+ cast_to: type[ResponseT],
+ response: httpx.Response,
+ client: SyncAPIClient,
+ ) -> None:
+ self.response = response
+ self._cast_to = cast_to
+ self._client = client
+ self._decoder = SSEDecoder()
+ self._iterator = self.__stream__()
+
+ def __next__(self) -> ResponseT:
+ return self._iterator.__next__()
+
+ def __iter__(self) -> Iterator[ResponseT]:
+ for item in self._iterator:
+ yield item
+
+ def _iter_events(self) -> Iterator[ServerSentEvent]:
+ yield from self._decoder.iter(self.response.iter_lines())
+
+ def __stream__(self) -> Iterator[ResponseT]:
+ cast_to = self._cast_to
+ response = self.response
+ process_data = self._client._process_response_data
+
+ for sse in self._iter_events():
+ if sse.data.startswith("[DONE]"):
+ break
+
+ if sse.event is None:
+ data = sse.json()
+ if is_mapping(data) and data.get("error"):
+ raise APIError(
+ message="An error ocurred during streaming",
+ request=self.response.request,
+ body=data["error"],
+ )
+
+ yield process_data(data=data, cast_to=cast_to, response=response)
+
+
+class AsyncStream(Generic[ResponseT]):
+ """Provides the core interface to iterate over an asynchronous stream response."""
+
+ response: httpx.Response
+
+ def __init__(
+ self,
+ *,
+ cast_to: type[ResponseT],
+ response: httpx.Response,
+ client: AsyncAPIClient,
+ ) -> None:
+ self.response = response
+ self._cast_to = cast_to
+ self._client = client
+ self._decoder = SSEDecoder()
+ self._iterator = self.__stream__()
+
+ async def __anext__(self) -> ResponseT:
+ return await self._iterator.__anext__()
+
+ async def __aiter__(self) -> AsyncIterator[ResponseT]:
+ async for item in self._iterator:
+ yield item
+
+ async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
+ async for sse in self._decoder.aiter(self.response.aiter_lines()):
+ yield sse
+
+ async def __stream__(self) -> AsyncIterator[ResponseT]:
+ cast_to = self._cast_to
+ response = self.response
+ process_data = self._client._process_response_data
+
+ async for sse in self._iter_events():
+ if sse.data.startswith("[DONE]"):
+ break
+
+ if sse.event is None:
+ data = sse.json()
+ if is_mapping(data) and data.get("error"):
+ raise APIError(
+ message="An error ocurred during streaming",
+ request=self.response.request,
+ body=data["error"],
+ )
+
+ yield process_data(data=data, cast_to=cast_to, response=response)
+
+
+class ServerSentEvent:
+ def __init__(
+ self,
+ *,
+ event: str | None = None,
+ data: str | None = None,
+ id: str | None = None,
+ retry: int | None = None,
+ ) -> None:
+ if data is None:
+ data = ""
+
+ self._id = id
+ self._data = data
+ self._event = event or None
+ self._retry = retry
+
+ @property
+ def event(self) -> str | None:
+ return self._event
+
+ @property
+ def id(self) -> str | None:
+ return self._id
+
+ @property
+ def retry(self) -> int | None:
+ return self._retry
+
+ @property
+ def data(self) -> str:
+ return self._data
+
+ def json(self) -> Any:
+ return json.loads(self.data)
+
+ @override
+ def __repr__(self) -> str:
+ return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
+
+
+class SSEDecoder:
+ _data: list[str]
+ _event: str | None
+ _retry: int | None
+ _last_event_id: str | None
+
+ def __init__(self) -> None:
+ self._event = None
+ self._data = []
+ self._last_event_id = None
+ self._retry = None
+
+ def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]:
+ """Given an iterator that yields lines, iterate over it & yield every event encountered"""
+ for line in iterator:
+ line = line.rstrip("\n")
+ sse = self.decode(line)
+ if sse is not None:
+ yield sse
+
+ async def aiter(self, iterator: AsyncIterator[str]) -> AsyncIterator[ServerSentEvent]:
+ """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
+ async for line in iterator:
+ line = line.rstrip("\n")
+ sse = self.decode(line)
+ if sse is not None:
+ yield sse
+
+ def decode(self, line: str) -> ServerSentEvent | None:
+ # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
+
+ if not line:
+ if not self._event and not self._data and not self._last_event_id and self._retry is None:
+ return None
+
+ sse = ServerSentEvent(
+ event=self._event,
+ data="\n".join(self._data),
+ id=self._last_event_id,
+ retry=self._retry,
+ )
+
+ # NOTE: as per the SSE spec, do not reset last_event_id.
+ self._event = None
+ self._data = []
+ self._retry = None
+
+ return sse
+
+ if line.startswith(":"):
+ return None
+
+ fieldname, _, value = line.partition(":")
+
+ if value.startswith(" "):
+ value = value[1:]
+
+ if fieldname == "event":
+ self._event = value
+ elif fieldname == "data":
+ self._data.append(value)
+ elif fieldname == "id":
+ if "\0" in value:
+ pass
+ else:
+ self._last_event_id = value
+ elif fieldname == "retry":
+ try:
+ self._retry = int(value)
+ except (TypeError, ValueError):
+ pass
+ else:
+ pass # Field is ignored.
+
+ return None
diff --git a/src/openai/_types.py b/src/openai/_types.py
new file mode 100644
index 0000000000..dabd15866f
--- /dev/null
+++ b/src/openai/_types.py
@@ -0,0 +1,343 @@
+from __future__ import annotations
+
+from os import PathLike
+from abc import ABC, abstractmethod
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ Type,
+ Tuple,
+ Union,
+ Mapping,
+ TypeVar,
+ Callable,
+ Iterator,
+ Optional,
+ Sequence,
+ AsyncIterator,
+)
+from typing_extensions import (
+ Literal,
+ Protocol,
+ TypeAlias,
+ TypedDict,
+ override,
+ runtime_checkable,
+)
+
+import pydantic
+from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
+
+if TYPE_CHECKING:
+ from ._models import BaseModel
+
+Transport = BaseTransport
+AsyncTransport = AsyncBaseTransport
+Query = Mapping[str, object]
+Body = object
+AnyMapping = Mapping[str, object]
+ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
+_T = TypeVar("_T")
+
+
+class BinaryResponseContent(ABC):
+ def __init__(
+ self,
+ response: Any,
+ ) -> None:
+ ...
+
+ @property
+ @abstractmethod
+ def content(self) -> bytes:
+ pass
+
+ @property
+ @abstractmethod
+ def text(self) -> str:
+ pass
+
+ @property
+ @abstractmethod
+ def encoding(self) -> Optional[str]:
+ """
+ Return an encoding to use for decoding the byte content into text.
+ The priority for determining this is given by...
+
+ * `.encoding = <>` has been set explicitly.
+ * The encoding as specified by the charset parameter in the Content-Type header.
+ * The encoding as determined by `default_encoding`, which may either be
+ a string like "utf-8" indicating the encoding to use, or may be a callable
+ which enables charset autodetection.
+ """
+ pass
+
+ @property
+ @abstractmethod
+ def charset_encoding(self) -> Optional[str]:
+ """
+ Return the encoding, as specified by the Content-Type header.
+ """
+ pass
+
+ @abstractmethod
+ def json(self, **kwargs: Any) -> Any:
+ pass
+
+ @abstractmethod
+ def read(self) -> bytes:
+ """
+ Read and return the response content.
+ """
+ pass
+
+ @abstractmethod
+ def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
+ """
+ A byte-iterator over the decoded response content.
+ This allows us to handle gzip, deflate, and brotli encoded responses.
+ """
+ pass
+
+ @abstractmethod
+ def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]:
+ """
+ A str-iterator over the decoded response content
+ that handles both gzip, deflate, etc but also detects the content's
+ string encoding.
+ """
+ pass
+
+ @abstractmethod
+ def iter_lines(self) -> Iterator[str]:
+ pass
+
+ @abstractmethod
+ def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
+ """
+ A byte-iterator over the raw response content.
+ """
+ pass
+
+ @abstractmethod
+ def stream_to_file(self, file: str | PathLike[str]) -> None:
+ """
+ Stream the output to the given file.
+ """
+ pass
+
+ @abstractmethod
+ def close(self) -> None:
+ """
+ Close the response and release the connection.
+ Automatically called if the response body is read to completion.
+ """
+ pass
+
+ @abstractmethod
+ async def aread(self) -> bytes:
+ """
+ Read and return the response content.
+ """
+ pass
+
+ @abstractmethod
+ async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
+ """
+ A byte-iterator over the decoded response content.
+ This allows us to handle gzip, deflate, and brotli encoded responses.
+ """
+ pass
+
+ @abstractmethod
+ async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]:
+ """
+ A str-iterator over the decoded response content
+ that handles both gzip, deflate, etc but also detects the content's
+ string encoding.
+ """
+ pass
+
+ @abstractmethod
+ async def aiter_lines(self) -> AsyncIterator[str]:
+ pass
+
+ @abstractmethod
+ async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
+ """
+ A byte-iterator over the raw response content.
+ """
+ pass
+
+ async def astream_to_file(self, file: str | PathLike[str]) -> None:
+ """
+ Stream the output to the given file.
+ """
+ pass
+
+ @abstractmethod
+ async def aclose(self) -> None:
+ """
+ Close the response and release the connection.
+ Automatically called if the response body is read to completion.
+ """
+ pass
+
+
+# Approximates httpx internal ProxiesTypes and RequestFiles types
+# while adding support for `PathLike` instances
+ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
+ProxiesTypes = Union[str, Proxy, ProxiesDict]
+if TYPE_CHECKING:
+ FileContent = Union[IO[bytes], bytes, PathLike[str]]
+else:
+ FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
+FileTypes = Union[
+ # file (or bytes)
+ FileContent,
+ # (filename, file (or bytes))
+ Tuple[Optional[str], FileContent],
+ # (filename, file (or bytes), content_type)
+ Tuple[Optional[str], FileContent, Optional[str]],
+ # (filename, file (or bytes), content_type, headers)
+ Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
+]
+RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
+
+# duplicate of the above but without our custom file support
+HttpxFileContent = Union[IO[bytes], bytes]
+HttpxFileTypes = Union[
+ # file (or bytes)
+ HttpxFileContent,
+ # (filename, file (or bytes))
+ Tuple[Optional[str], HttpxFileContent],
+ # (filename, file (or bytes), content_type)
+ Tuple[Optional[str], HttpxFileContent, Optional[str]],
+ # (filename, file (or bytes), content_type, headers)
+ Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
+]
+HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
+
+# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
+# where ResponseT includes `None`. In order to support directly
+# passing `None`, overloads would have to be defined for every
+# method that uses `ResponseT` which would lead to an unacceptable
+# amount of code duplication and make it unreadable. See _base_client.py
+# for example usage.
+#
+# This unfortunately means that you will either have
+# to import this type and pass it explicitly:
+#
+# from openai import NoneType
+# client.get('/foo', cast_to=NoneType)
+#
+# or build it yourself:
+#
+# client.get('/foo', cast_to=type(None))
+if TYPE_CHECKING:
+ NoneType: Type[None]
+else:
+ NoneType = type(None)
+
+
+class RequestOptions(TypedDict, total=False):
+ headers: Headers
+ max_retries: int
+ timeout: float | Timeout | None
+ params: Query
+ extra_json: AnyMapping
+ idempotency_key: str
+
+
+# Sentinel class used when the response type is an object with an unknown schema
+class UnknownResponse:
+ ...
+
+
+# Sentinel class used until PEP 0661 is accepted
+class NotGiven:
+ """
+ A sentinel singleton class used to distinguish omitted keyword arguments
+ from those passed in with the value None (which may have different behavior).
+
+ For example:
+
+ ```py
+ def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
+
+ get(timout=1) # 1s timeout
+ get(timout=None) # No timeout
+ get() # Default timeout behavior, which may not be statically known at the method definition.
+ ```
+ """
+
+ def __bool__(self) -> Literal[False]:
+ return False
+
+ @override
+ def __repr__(self) -> str:
+ return "NOT_GIVEN"
+
+
+NotGivenOr = Union[_T, NotGiven]
+NOT_GIVEN = NotGiven()
+
+
+class Omit:
+ """In certain situations you need to be able to represent a case where a default value has
+ to be explicitly removed and `None` is not an appropriate substitute, for example:
+
+ ```py
+ # as the default `Content-Type` header is `application/json` that will be sent
+ client.post('/upload/files', files={'file': b'my raw file content'})
+
+ # you can't explicitly override the header as it has to be dynamically generated
+ # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
+ client.post(..., headers={'Content-Type': 'multipart/form-data'})
+
+ # instead you can remove the default `application/json` header by passing Omit
+ client.post(..., headers={'Content-Type': Omit()})
+ ```
+ """
+
+ def __bool__(self) -> Literal[False]:
+ return False
+
+
+@runtime_checkable
+class ModelBuilderProtocol(Protocol):
+ @classmethod
+ def build(
+ cls: type[_T],
+ *,
+ response: Response,
+ data: object,
+ ) -> _T:
+ ...
+
+
+Headers = Mapping[str, Union[str, Omit]]
+
+
+class HeadersLikeProtocol(Protocol):
+ def get(self, __key: str) -> str | None:
+ ...
+
+
+HeadersLike = Union[Headers, HeadersLikeProtocol]
+
+ResponseT = TypeVar(
+ "ResponseT",
+ bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
+)
+
+StrBytesIntFloat = Union[str, bytes, int, float]
+
+# Note: copied from Pydantic
+# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49
+IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
+
+PostParser = Callable[[Any], Any]
diff --git a/src/openai/_utils/__init__.py b/src/openai/_utils/__init__.py
new file mode 100644
index 0000000000..d3397212de
--- /dev/null
+++ b/src/openai/_utils/__init__.py
@@ -0,0 +1,36 @@
+from ._proxy import LazyProxy as LazyProxy
+from ._utils import flatten as flatten
+from ._utils import is_dict as is_dict
+from ._utils import is_list as is_list
+from ._utils import is_given as is_given
+from ._utils import is_tuple as is_tuple
+from ._utils import is_mapping as is_mapping
+from ._utils import is_tuple_t as is_tuple_t
+from ._utils import parse_date as parse_date
+from ._utils import is_sequence as is_sequence
+from ._utils import coerce_float as coerce_float
+from ._utils import is_list_type as is_list_type
+from ._utils import is_mapping_t as is_mapping_t
+from ._utils import removeprefix as removeprefix
+from ._utils import removesuffix as removesuffix
+from ._utils import extract_files as extract_files
+from ._utils import is_sequence_t as is_sequence_t
+from ._utils import is_union_type as is_union_type
+from ._utils import required_args as required_args
+from ._utils import coerce_boolean as coerce_boolean
+from ._utils import coerce_integer as coerce_integer
+from ._utils import file_from_path as file_from_path
+from ._utils import parse_datetime as parse_datetime
+from ._utils import strip_not_given as strip_not_given
+from ._utils import deepcopy_minimal as deepcopy_minimal
+from ._utils import extract_type_arg as extract_type_arg
+from ._utils import is_required_type as is_required_type
+from ._utils import is_annotated_type as is_annotated_type
+from ._utils import maybe_coerce_float as maybe_coerce_float
+from ._utils import get_required_header as get_required_header
+from ._utils import maybe_coerce_boolean as maybe_coerce_boolean
+from ._utils import maybe_coerce_integer as maybe_coerce_integer
+from ._utils import strip_annotated_type as strip_annotated_type
+from ._transform import PropertyInfo as PropertyInfo
+from ._transform import transform as transform
+from ._transform import maybe_transform as maybe_transform
diff --git a/src/openai/_utils/_logs.py b/src/openai/_utils/_logs.py
new file mode 100644
index 0000000000..e5113fd8c0
--- /dev/null
+++ b/src/openai/_utils/_logs.py
@@ -0,0 +1,25 @@
+import os
+import logging
+
+logger: logging.Logger = logging.getLogger("openai")
+httpx_logger: logging.Logger = logging.getLogger("httpx")
+
+
+def _basic_config() -> None:
+ # e.g. [2023-10-05 14:12:26 - openai._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar "200 OK"
+ logging.basicConfig(
+ format="[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+
+def setup_logging() -> None:
+ env = os.environ.get("OPENAI_LOG")
+ if env == "debug":
+ _basic_config()
+ logger.setLevel(logging.DEBUG)
+ httpx_logger.setLevel(logging.DEBUG)
+ elif env == "info":
+ _basic_config()
+ logger.setLevel(logging.INFO)
+ httpx_logger.setLevel(logging.INFO)
diff --git a/src/openai/_utils/_proxy.py b/src/openai/_utils/_proxy.py
new file mode 100644
index 0000000000..aa934a3fbc
--- /dev/null
+++ b/src/openai/_utils/_proxy.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar, Iterable, cast
+from typing_extensions import ClassVar, override
+
+T = TypeVar("T")
+
+
+class LazyProxy(Generic[T], ABC):
+ """Implements data methods to pretend that an instance is another instance.
+
+ This includes forwarding attribute access and othe methods.
+ """
+
+ should_cache: ClassVar[bool] = False
+
+ def __init__(self) -> None:
+ self.__proxied: T | None = None
+
+ def __getattr__(self, attr: str) -> object:
+ return getattr(self.__get_proxied__(), attr)
+
+ @override
+ def __repr__(self) -> str:
+ return repr(self.__get_proxied__())
+
+ @override
+ def __str__(self) -> str:
+ return str(self.__get_proxied__())
+
+ @override
+ def __dir__(self) -> Iterable[str]:
+ return self.__get_proxied__().__dir__()
+
+ @property # type: ignore
+ @override
+ def __class__(self) -> type:
+ return self.__get_proxied__().__class__
+
+ def __get_proxied__(self) -> T:
+ if not self.should_cache:
+ return self.__load__()
+
+ proxied = self.__proxied
+ if proxied is not None:
+ return proxied
+
+ self.__proxied = proxied = self.__load__()
+ return proxied
+
+ def __set_proxied__(self, value: T) -> None:
+ self.__proxied = value
+
+ def __as_proxied__(self) -> T:
+ """Helper method that returns the current proxy, typed as the loaded object"""
+ return cast(T, self)
+
+ @abstractmethod
+ def __load__(self) -> T:
+ ...
diff --git a/src/openai/_utils/_transform.py b/src/openai/_utils/_transform.py
new file mode 100644
index 0000000000..db40bff27f
--- /dev/null
+++ b/src/openai/_utils/_transform.py
@@ -0,0 +1,214 @@
+from __future__ import annotations
+
+from typing import Any, List, Mapping, TypeVar, cast
+from datetime import date, datetime
+from typing_extensions import Literal, get_args, override, get_type_hints
+
+import pydantic
+
+from ._utils import (
+ is_list,
+ is_mapping,
+ is_list_type,
+ is_union_type,
+ extract_type_arg,
+ is_required_type,
+ is_annotated_type,
+ strip_annotated_type,
+)
+from .._compat import model_dump, is_typeddict
+
+_T = TypeVar("_T")
+
+
+# TODO: support for drilling globals() and locals()
+# TODO: ensure works correctly with forward references in all cases
+
+
+PropertyFormat = Literal["iso8601", "custom"]
+
+
+class PropertyInfo:
+ """Metadata class to be used in Annotated types to provide information about a given type.
+
+ For example:
+
+ class MyParams(TypedDict):
+ account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
+
+ This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
+ """
+
+ alias: str | None
+ format: PropertyFormat | None
+ format_template: str | None
+
+ def __init__(
+ self,
+ *,
+ alias: str | None = None,
+ format: PropertyFormat | None = None,
+ format_template: str | None = None,
+ ) -> None:
+ self.alias = alias
+ self.format = format
+ self.format_template = format_template
+
+ @override
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}')"
+
+
+def maybe_transform(
+ data: Mapping[str, object] | List[Any] | None,
+ expected_type: object,
+) -> Any | None:
+ """Wrapper over `transform()` that allows `None` to be passed.
+
+ See `transform()` for more details.
+ """
+ if data is None:
+ return None
+ return transform(data, expected_type)
+
+
+# Wrapper over _transform_recursive providing fake types
+def transform(
+ data: _T,
+ expected_type: object,
+) -> _T:
+ """Transform dictionaries based off of type information from the given type, for example:
+
+ ```py
+ class Params(TypedDict, total=False):
+ card_id: Required[Annotated[str, PropertyInfo(alias='cardID')]]
+
+ transformed = transform({'card_id': ''}, Params)
+ # {'cardID': ''}
+ ```
+
+ Any keys / data that does not have type information given will be included as is.
+
+ It should be noted that the transformations that this function does are not represented in the type system.
+ """
+ transformed = _transform_recursive(data, annotation=cast(type, expected_type))
+ return cast(_T, transformed)
+
+
+def _get_annoted_type(type_: type) -> type | None:
+ """If the given type is an `Annotated` type then it is returned, if not `None` is returned.
+
+ This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
+ """
+ if is_required_type(type_):
+ # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
+ type_ = get_args(type_)[0]
+
+ if is_annotated_type(type_):
+ return type_
+
+ return None
+
+
+def _maybe_transform_key(key: str, type_: type) -> str:
+ """Transform the given `data` based on the annotations provided in `type_`.
+
+ Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.
+ """
+ annotated_type = _get_annoted_type(type_)
+ if annotated_type is None:
+ # no `Annotated` definition for this type, no transformation needed
+ return key
+
+ # ignore the first argument as it is the actual type
+ annotations = get_args(annotated_type)[1:]
+ for annotation in annotations:
+ if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
+ return annotation.alias
+
+ return key
+
+
+def _transform_recursive(
+ data: object,
+ *,
+ annotation: type,
+ inner_type: type | None = None,
+) -> object:
+ """Transform the given data against the expected type.
+
+ Args:
+ annotation: The direct type annotation given to the particular piece of data.
+ This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
+
+ inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
+ is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
+ the list can be transformed using the metadata from the container type.
+
+ Defaults to the same value as the `annotation` argument.
+ """
+ if inner_type is None:
+ inner_type = annotation
+
+ stripped_type = strip_annotated_type(inner_type)
+ if is_typeddict(stripped_type) and is_mapping(data):
+ return _transform_typeddict(data, stripped_type)
+
+ if is_list_type(stripped_type) and is_list(data):
+ inner_type = extract_type_arg(stripped_type, 0)
+ return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
+
+ if is_union_type(stripped_type):
+ # For union types we run the transformation against all subtypes to ensure that everything is transformed.
+ #
+ # TODO: there may be edge cases where the same normalized field name will transform to two different names
+ # in different subtypes.
+ for subtype in get_args(stripped_type):
+ data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
+ return data
+
+ if isinstance(data, pydantic.BaseModel):
+ return model_dump(data, exclude_unset=True, exclude_defaults=True)
+
+ return _transform_value(data, annotation)
+
+
+def _transform_value(data: object, type_: type) -> object:
+ annotated_type = _get_annoted_type(type_)
+ if annotated_type is None:
+ return data
+
+ # ignore the first argument as it is the actual type
+ annotations = get_args(annotated_type)[1:]
+ for annotation in annotations:
+ if isinstance(annotation, PropertyInfo) and annotation.format is not None:
+ return _format_data(data, annotation.format, annotation.format_template)
+
+ return data
+
+
+def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
+ if isinstance(data, (date, datetime)):
+ if format_ == "iso8601":
+ return data.isoformat()
+
+ if format_ == "custom" and format_template is not None:
+ return data.strftime(format_template)
+
+ return data
+
+
+def _transform_typeddict(
+ data: Mapping[str, object],
+ expected_type: type,
+) -> Mapping[str, object]:
+ result: dict[str, object] = {}
+ annotations = get_type_hints(expected_type, include_extras=True)
+ for key, value in data.items():
+ type_ = annotations.get(key)
+ if type_ is None:
+ # we do not have a type annotation for this field, leave it as is
+ result[key] = value
+ else:
+ result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
+ return result
diff --git a/src/openai/_utils/_utils.py b/src/openai/_utils/_utils.py
new file mode 100644
index 0000000000..4b51dcb2e8
--- /dev/null
+++ b/src/openai/_utils/_utils.py
@@ -0,0 +1,408 @@
+from __future__ import annotations
+
+import os
+import re
+import inspect
+import functools
+from typing import (
+ Any,
+ Tuple,
+ Mapping,
+ TypeVar,
+ Callable,
+ Iterable,
+ Sequence,
+ cast,
+ overload,
+)
+from pathlib import Path
+from typing_extensions import Required, Annotated, TypeGuard, get_args, get_origin
+
+from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
+from .._compat import is_union as _is_union
+from .._compat import parse_date as parse_date
+from .._compat import parse_datetime as parse_datetime
+
+_T = TypeVar("_T")
+_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
+_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
+_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
+CallableT = TypeVar("CallableT", bound=Callable[..., Any])
+
+
+def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
+ return [item for sublist in t for item in sublist]
+
+
+def extract_files(
+ # TODO: this needs to take Dict but variance issues.....
+ # create protocol type ?
+ query: Mapping[str, object],
+ *,
+ paths: Sequence[Sequence[str]],
+) -> list[tuple[str, FileTypes]]:
+ """Recursively extract files from the given dictionary based on specified paths.
+
+ A path may look like this ['foo', 'files', '', 'data'].
+
+ Note: this mutates the given dictionary.
+ """
+ files: list[tuple[str, FileTypes]] = []
+ for path in paths:
+ files.extend(_extract_items(query, path, index=0, flattened_key=None))
+ return files
+
+
+def _extract_items(
+ obj: object,
+ path: Sequence[str],
+ *,
+ index: int,
+ flattened_key: str | None,
+) -> list[tuple[str, FileTypes]]:
+ try:
+ key = path[index]
+ except IndexError:
+ if isinstance(obj, NotGiven):
+ # no value was provided - we can safely ignore
+ return []
+
+ # cyclical import
+ from .._files import assert_is_file_content
+
+ # We have exhausted the path, return the entry we found.
+ assert_is_file_content(obj, key=flattened_key)
+ assert flattened_key is not None
+ return [(flattened_key, cast(FileTypes, obj))]
+
+ index += 1
+ if is_dict(obj):
+ try:
+ # We are at the last entry in the path so we must remove the field
+ if (len(path)) == index:
+ item = obj.pop(key)
+ else:
+ item = obj[key]
+ except KeyError:
+ # Key was not present in the dictionary, this is not indicative of an error
+ # as the given path may not point to a required field. We also do not want
+ # to enforce required fields as the API may differ from the spec in some cases.
+ return []
+ if flattened_key is None:
+ flattened_key = key
+ else:
+ flattened_key += f"[{key}]"
+ return _extract_items(
+ item,
+ path,
+ index=index,
+ flattened_key=flattened_key,
+ )
+ elif is_list(obj):
+ if key != "":
+ return []
+
+ return flatten(
+ [
+ _extract_items(
+ item,
+ path,
+ index=index,
+ flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
+ )
+ for item in obj
+ ]
+ )
+
+ # Something unexpected was passed, just ignore it.
+ return []
+
+
+def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
+ return not isinstance(obj, NotGiven)
+
+
+# Type safe methods for narrowing types with TypeVars.
+# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
+# however this cause Pyright to rightfully report errors. As we know we don't
+# care about the contained types we can safely use `object` in it's place.
+#
+# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
+# `is_*` is for when you're dealing with an unknown input
+# `is_*_t` is for when you're narrowing a known union type to a specific subset
+
+
+def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
+ return isinstance(obj, tuple)
+
+
+def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
+ return isinstance(obj, tuple)
+
+
+def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
+ return isinstance(obj, Sequence)
+
+
+def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
+ return isinstance(obj, Sequence)
+
+
+def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
+ return isinstance(obj, Mapping)
+
+
+def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
+ return isinstance(obj, Mapping)
+
+
+def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
+ return isinstance(obj, dict)
+
+
+def is_list(obj: object) -> TypeGuard[list[object]]:
+ return isinstance(obj, list)
+
+
+def is_annotated_type(typ: type) -> bool:
+ return get_origin(typ) == Annotated
+
+
+def is_list_type(typ: type) -> bool:
+ return (get_origin(typ) or typ) == list
+
+
+def is_union_type(typ: type) -> bool:
+ return _is_union(get_origin(typ))
+
+
+def is_required_type(typ: type) -> bool:
+ return get_origin(typ) == Required
+
+
+# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
+def strip_annotated_type(typ: type) -> type:
+ if is_required_type(typ) or is_annotated_type(typ):
+ return strip_annotated_type(cast(type, get_args(typ)[0]))
+
+ return typ
+
+
+def extract_type_arg(typ: type, index: int) -> type:
+ args = get_args(typ)
+ try:
+ return cast(type, args[index])
+ except IndexError:
+ raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not")
+
+
+def deepcopy_minimal(item: _T) -> _T:
+ """Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
+
+ - mappings, e.g. `dict`
+ - list
+
+ This is done for performance reasons.
+ """
+ if is_mapping(item):
+ return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
+ if is_list(item):
+ return cast(_T, [deepcopy_minimal(entry) for entry in item])
+ return item
+
+
+# copied from https://github.com/Rapptz/RoboDanny
+def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
+ size = len(seq)
+ if size == 0:
+ return ""
+
+ if size == 1:
+ return seq[0]
+
+ if size == 2:
+ return f"{seq[0]} {final} {seq[1]}"
+
+ return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
+
+
+def quote(string: str) -> str:
+ """Add single quotation marks around the given string. Does *not* do any escaping."""
+ return "'" + string + "'"
+
+
+def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
+ """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
+
+ Useful for enforcing runtime validation of overloaded functions.
+
+ Example usage:
+ ```py
+ @overload
+ def foo(*, a: str) -> str:
+ ...
+
+ @overload
+ def foo(*, b: bool) -> str:
+ ...
+
+ # This enforces the same constraints that a static type checker would
+ # i.e. that either a or b must be passed to the function
+ @required_args(['a'], ['b'])
+ def foo(*, a: str | None = None, b: bool | None = None) -> str:
+ ...
+ ```
+ """
+
+ def inner(func: CallableT) -> CallableT:
+ params = inspect.signature(func).parameters
+ positional = [
+ name
+ for name, param in params.items()
+ if param.kind
+ in {
+ param.POSITIONAL_ONLY,
+ param.POSITIONAL_OR_KEYWORD,
+ }
+ ]
+
+ @functools.wraps(func)
+ def wrapper(*args: object, **kwargs: object) -> object:
+ given_params: set[str] = set()
+ for i, _ in enumerate(args):
+ try:
+ given_params.add(positional[i])
+ except IndexError:
+ raise TypeError(f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given")
+
+ for key in kwargs.keys():
+ given_params.add(key)
+
+ for variant in variants:
+ matches = all((param in given_params for param in variant))
+ if matches:
+ break
+ else: # no break
+ if len(variants) > 1:
+ variations = human_join(
+ ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
+ )
+ msg = f"Missing required arguments; Expected either {variations} arguments to be given"
+ else:
+ # TODO: this error message is not deterministic
+ missing = list(set(variants[0]) - given_params)
+ if len(missing) > 1:
+ msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
+ else:
+ msg = f"Missing required argument: {quote(missing[0])}"
+ raise TypeError(msg)
+ return func(*args, **kwargs)
+
+ return wrapper # type: ignore
+
+ return inner
+
+
+_K = TypeVar("_K")
+_V = TypeVar("_V")
+
+
+@overload
+def strip_not_given(obj: None) -> None:
+ ...
+
+
+@overload
+def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]:
+ ...
+
+
+@overload
+def strip_not_given(obj: object) -> object:
+ ...
+
+
+def strip_not_given(obj: object | None) -> object:
+ """Remove all top-level keys where their values are instances of `NotGiven`"""
+ if obj is None:
+ return None
+
+ if not is_mapping(obj):
+ return obj
+
+ return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
+
+
+def coerce_integer(val: str) -> int:
+ return int(val, base=10)
+
+
+def coerce_float(val: str) -> float:
+ return float(val)
+
+
+def coerce_boolean(val: str) -> bool:
+ return val == "true" or val == "1" or val == "on"
+
+
+def maybe_coerce_integer(val: str | None) -> int | None:
+ if val is None:
+ return None
+ return coerce_integer(val)
+
+
+def maybe_coerce_float(val: str | None) -> float | None:
+ if val is None:
+ return None
+ return coerce_float(val)
+
+
+def maybe_coerce_boolean(val: str | None) -> bool | None:
+ if val is None:
+ return None
+ return coerce_boolean(val)
+
+
+def removeprefix(string: str, prefix: str) -> str:
+ """Remove a prefix from a string.
+
+ Backport of `str.removeprefix` for Python < 3.9
+ """
+ if string.startswith(prefix):
+ return string[len(prefix) :]
+ return string
+
+
+def removesuffix(string: str, suffix: str) -> str:
+ """Remove a suffix from a string.
+
+ Backport of `str.removesuffix` for Python < 3.9
+ """
+ if string.endswith(suffix):
+ return string[: -len(suffix)]
+ return string
+
+
+def file_from_path(path: str) -> FileTypes:
+ contents = Path(path).read_bytes()
+ file_name = os.path.basename(path)
+ return (file_name, contents)
+
+
+def get_required_header(headers: HeadersLike, header: str) -> str:
+ lower_header = header.lower()
+ if isinstance(headers, Mapping):
+ headers = cast(Headers, headers)
+ for k, v in headers.items():
+ if k.lower() == lower_header and isinstance(v, str):
+ return v
+
+ """ to deal with the case where the header looks like Stainless-Event-Id """
+ intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
+
+ for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
+ value = headers.get(normalized_header)
+ if value:
+ return value
+
+ raise ValueError(f"Could not find {header} header")
diff --git a/src/openai/_version.py b/src/openai/_version.py
new file mode 100644
index 0000000000..e9a3efc55c
--- /dev/null
+++ b/src/openai/_version.py
@@ -0,0 +1,4 @@
+# File generated from our OpenAPI spec by Stainless.
+
+__title__ = "openai"
+__version__ = "1.0.0"
diff --git a/src/openai/cli/__init__.py b/src/openai/cli/__init__.py
new file mode 100644
index 0000000000..d453d5e179
--- /dev/null
+++ b/src/openai/cli/__init__.py
@@ -0,0 +1 @@
+from ._cli import main as main
diff --git a/src/openai/cli/_api/__init__.py b/src/openai/cli/_api/__init__.py
new file mode 100644
index 0000000000..56a0260a6d
--- /dev/null
+++ b/src/openai/cli/_api/__init__.py
@@ -0,0 +1 @@
+from ._main import register_commands as register_commands
diff --git a/src/openai/cli/_api/_main.py b/src/openai/cli/_api/_main.py
new file mode 100644
index 0000000000..fe5a5e6fc0
--- /dev/null
+++ b/src/openai/cli/_api/_main.py
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from argparse import ArgumentParser
+
+from . import chat, audio, files, image, models, completions
+
+
+def register_commands(parser: ArgumentParser) -> None:
+ subparsers = parser.add_subparsers(help="All API subcommands")
+
+ chat.register(subparsers)
+ image.register(subparsers)
+ audio.register(subparsers)
+ files.register(subparsers)
+ models.register(subparsers)
+ completions.register(subparsers)
diff --git a/src/openai/cli/_api/audio.py b/src/openai/cli/_api/audio.py
new file mode 100644
index 0000000000..eaf57748ad
--- /dev/null
+++ b/src/openai/cli/_api/audio.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Optional, cast
+from argparse import ArgumentParser
+
+from .._utils import get_client, print_model
+from ..._types import NOT_GIVEN
+from .._models import BaseModel
+from .._progress import BufferReader
+
+if TYPE_CHECKING:
+ from argparse import _SubParsersAction
+
+
+def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
+ # transcriptions
+ sub = subparser.add_parser("audio.transcriptions.create")
+
+ # Required
+ sub.add_argument("-m", "--model", type=str, default="whisper-1")
+ sub.add_argument("-f", "--file", type=str, required=True)
+ # Optional
+ sub.add_argument("--response-format", type=str)
+ sub.add_argument("--language", type=str)
+ sub.add_argument("-t", "--temperature", type=float)
+ sub.add_argument("--prompt", type=str)
+ sub.set_defaults(func=CLIAudio.transcribe, args_model=CLITranscribeArgs)
+
+ # translations
+ sub = subparser.add_parser("audio.translations.create")
+
+ # Required
+ sub.add_argument("-f", "--file", type=str, required=True)
+ # Optional
+ sub.add_argument("-m", "--model", type=str, default="whisper-1")
+ sub.add_argument("--response-format", type=str)
+ # TODO: doesn't seem to be supported by the API
+ # sub.add_argument("--language", type=str)
+ sub.add_argument("-t", "--temperature", type=float)
+ sub.add_argument("--prompt", type=str)
+ sub.set_defaults(func=CLIAudio.translate, args_model=CLITranslationArgs)
+
+
+class CLITranscribeArgs(BaseModel):
+ model: str
+ file: str
+ response_format: Optional[str] = None
+ language: Optional[str] = None
+ temperature: Optional[float] = None
+ prompt: Optional[str] = None
+
+
+class CLITranslationArgs(BaseModel):
+ model: str
+ file: str
+ response_format: Optional[str] = None
+ language: Optional[str] = None
+ temperature: Optional[float] = None
+ prompt: Optional[str] = None
+
+
+class CLIAudio:
+ @staticmethod
+ def transcribe(args: CLITranscribeArgs) -> None:
+ with open(args.file, "rb") as file_reader:
+ buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
+
+ model = get_client().audio.transcriptions.create(
+ file=buffer_reader,
+ model=args.model,
+ language=args.language or NOT_GIVEN,
+ temperature=args.temperature or NOT_GIVEN,
+ prompt=args.prompt or NOT_GIVEN,
+ # casts required because the API is typed for enums
+ # but we don't want to validate that here for forwards-compat
+ response_format=cast(Any, args.response_format),
+ )
+ print_model(model)
+
+ @staticmethod
+ def translate(args: CLITranslationArgs) -> None:
+ with open(args.file, "rb") as file_reader:
+ buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
+
+ model = get_client().audio.translations.create(
+ file=buffer_reader,
+ model=args.model,
+ temperature=args.temperature or NOT_GIVEN,
+ prompt=args.prompt or NOT_GIVEN,
+ # casts required because the API is typed for enums
+ # but we don't want to validate that here for forwards-compat
+ response_format=cast(Any, args.response_format),
+ )
+ print_model(model)
diff --git a/src/openai/cli/_api/chat/__init__.py b/src/openai/cli/_api/chat/__init__.py
new file mode 100644
index 0000000000..87d971630a
--- /dev/null
+++ b/src/openai/cli/_api/chat/__init__.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from argparse import ArgumentParser
+
+from . import completions
+
+if TYPE_CHECKING:
+ from argparse import _SubParsersAction
+
+
+def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
+ completions.register(subparser)
diff --git a/src/openai/cli/_api/chat/completions.py b/src/openai/cli/_api/chat/completions.py
new file mode 100644
index 0000000000..e7566b143d
--- /dev/null
+++ b/src/openai/cli/_api/chat/completions.py
@@ -0,0 +1,154 @@
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, List, Optional, cast
+from argparse import ArgumentParser
+from typing_extensions import NamedTuple
+
+from ..._utils import get_client
+from ..._models import BaseModel
+from ...._streaming import Stream
+from ....types.chat import (
+ ChatCompletionRole,
+ ChatCompletionChunk,
+ CompletionCreateParams,
+)
+from ....types.chat.completion_create_params import (
+ CompletionCreateParamsStreaming,
+ CompletionCreateParamsNonStreaming,
+)
+
+if TYPE_CHECKING:
+ from argparse import _SubParsersAction
+
+
+def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
+ sub = subparser.add_parser("chat.completions.create")
+
+ sub._action_groups.pop()
+ req = sub.add_argument_group("required arguments")
+ opt = sub.add_argument_group("optional arguments")
+
+ req.add_argument(
+ "-g",
+ "--message",
+ action="append",
+ nargs=2,
+ metavar=("ROLE", "CONTENT"),
+ help="A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.",
+ required=True,
+ )
+ req.add_argument(
+ "-m",
+ "--model",
+ help="The model to use.",
+ required=True,
+ )
+
+ opt.add_argument(
+ "-n",
+ "--n",
+ help="How many completions to generate for the conversation.",
+ type=int,
+ )
+ opt.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate.", type=int)
+ opt.add_argument(
+ "-t",
+ "--temperature",
+ help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
+
+Mutually exclusive with `top_p`.""",
+ type=float,
+ )
+ opt.add_argument(
+ "-P",
+ "--top_p",
+ help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
+
+ Mutually exclusive with `temperature`.""",
+ type=float,
+ )
+ opt.add_argument(
+ "--stop",
+ help="A stop sequence at which to stop generating tokens for the message.",
+ )
+ opt.add_argument("--stream", help="Stream messages as they're ready.", action="store_true")
+ sub.set_defaults(func=CLIChatCompletion.create, args_model=CLIChatCompletionCreateArgs)
+
+
+class CLIMessage(NamedTuple):
+ role: ChatCompletionRole
+ content: str
+
+
+class CLIChatCompletionCreateArgs(BaseModel):
+ message: List[CLIMessage]
+ model: str
+ n: Optional[int] = None
+ max_tokens: Optional[int] = None
+ temperature: Optional[float] = None
+ top_p: Optional[float] = None
+ stop: Optional[str] = None
+ stream: bool = False
+
+
+class CLIChatCompletion:
+ @staticmethod
+ def create(args: CLIChatCompletionCreateArgs) -> None:
+ params: CompletionCreateParams = {
+ "model": args.model,
+ "messages": [{"role": message.role, "content": message.content} for message in args.message],
+ "n": args.n,
+ "temperature": args.temperature,
+ "top_p": args.top_p,
+ "stop": args.stop,
+ # type checkers are not good at inferring union types so we have to set stream afterwards
+ "stream": False,
+ }
+ if args.stream:
+ params["stream"] = args.stream # type: ignore
+ if args.max_tokens is not None:
+ params["max_tokens"] = args.max_tokens
+
+ if args.stream:
+ return CLIChatCompletion._stream_create(cast(CompletionCreateParamsStreaming, params))
+
+ return CLIChatCompletion._create(cast(CompletionCreateParamsNonStreaming, params))
+
+ @staticmethod
+ def _create(params: CompletionCreateParamsNonStreaming) -> None:
+ completion = get_client().chat.completions.create(**params)
+ should_print_header = len(completion.choices) > 1
+ for choice in completion.choices:
+ if should_print_header:
+ sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
+
+ content = choice.message.content if choice.message.content is not None else "None"
+ sys.stdout.write(content)
+
+ if should_print_header or not content.endswith("\n"):
+ sys.stdout.write("\n")
+
+ sys.stdout.flush()
+
+ @staticmethod
+ def _stream_create(params: CompletionCreateParamsStreaming) -> None:
+ # cast is required for mypy
+ stream = cast( # pyright: ignore[reportUnnecessaryCast]
+ Stream[ChatCompletionChunk], get_client().chat.completions.create(**params)
+ )
+ for chunk in stream:
+ should_print_header = len(chunk.choices) > 1
+ for choice in chunk.choices:
+ if should_print_header:
+ sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
+
+ content = choice.delta.content or ""
+ sys.stdout.write(content)
+
+ if should_print_header:
+ sys.stdout.write("\n")
+
+ sys.stdout.flush()
+
+ sys.stdout.write("\n")
diff --git a/src/openai/cli/_api/completions.py b/src/openai/cli/_api/completions.py
new file mode 100644
index 0000000000..ce1036b224
--- /dev/null
+++ b/src/openai/cli/_api/completions.py
@@ -0,0 +1,173 @@
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Optional, cast
+from argparse import ArgumentParser
+from functools import partial
+
+from openai.types.completion import Completion
+
+from .._utils import get_client
+from ..._types import NOT_GIVEN, NotGivenOr
+from ..._utils import is_given
+from .._errors import CLIError
+from .._models import BaseModel
+from ..._streaming import Stream
+
+if TYPE_CHECKING:
+ from argparse import _SubParsersAction
+
+
+def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
+ sub = subparser.add_parser("completions.create")
+
+ # Required
+ sub.add_argument(
+ "-m",
+ "--model",
+ help="The model to use",
+ required=True,
+ )
+
+ # Optional
+ sub.add_argument("-p", "--prompt", help="An optional prompt to complete from")
+ sub.add_argument("--stream", help="Stream tokens as they're ready.", action="store_true")
+ sub.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate", type=int)
+ sub.add_argument(
+ "-t",
+ "--temperature",
+ help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
+
+Mutually exclusive with `top_p`.""",
+ type=float,
+ )
+ sub.add_argument(
+ "-P",
+ "--top_p",
+ help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
+
+ Mutually exclusive with `temperature`.""",
+ type=float,
+ )
+ sub.add_argument(
+ "-n",
+ "--n",
+ help="How many sub-completions to generate for each prompt.",
+ type=int,
+ )
+ sub.add_argument(
+ "--logprobs",
+ help="Include the log probabilites on the `logprobs` most likely tokens, as well the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.",
+ type=int,
+ )
+ sub.add_argument(
+ "--best_of",
+ help="Generates `best_of` completions server-side and returns the 'best' (the one with the highest log probability per token). Results cannot be streamed.",
+ type=int,
+ )
+ sub.add_argument(
+ "--echo",
+ help="Echo back the prompt in addition to the completion",
+ action="store_true",
+ )
+ sub.add_argument(
+ "--frequency_penalty",
+ help="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
+ type=float,
+ )
+ sub.add_argument(
+ "--presence_penalty",
+ help="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
+ type=float,
+ )
+ sub.add_argument("--suffix", help="The suffix that comes after a completion of inserted text.")
+ sub.add_argument("--stop", help="A stop sequence at which to stop generating tokens.")
+ sub.add_argument(
+ "--user",
+ help="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.",
+ )
+ # TODO: add support for logit_bias
+ sub.set_defaults(func=CLICompletions.create, args_model=CLICompletionCreateArgs)
+
+
+class CLICompletionCreateArgs(BaseModel):
+ model: str
+ stream: bool = False
+
+ prompt: Optional[str] = None
+ n: NotGivenOr[int] = NOT_GIVEN
+ stop: NotGivenOr[str] = NOT_GIVEN
+ user: NotGivenOr[str] = NOT_GIVEN
+ echo: NotGivenOr[bool] = NOT_GIVEN
+ suffix: NotGivenOr[str] = NOT_GIVEN
+ best_of: NotGivenOr[int] = NOT_GIVEN
+ top_p: NotGivenOr[float] = NOT_GIVEN
+ logprobs: NotGivenOr[int] = NOT_GIVEN
+ max_tokens: NotGivenOr[int] = NOT_GIVEN
+ temperature: NotGivenOr[float] = NOT_GIVEN
+ presence_penalty: NotGivenOr[float] = NOT_GIVEN
+ frequency_penalty: NotGivenOr[float] = NOT_GIVEN
+
+
+class CLICompletions:
+ @staticmethod
+ def create(args: CLICompletionCreateArgs) -> None:
+ if is_given(args.n) and args.n > 1 and args.stream:
+ raise CLIError("Can't stream completions with n>1 with the current CLI")
+
+ make_request = partial(
+ get_client().completions.create,
+ n=args.n,
+ echo=args.echo,
+ stop=args.stop,
+ user=args.user,
+ model=args.model,
+ top_p=args.top_p,
+ prompt=args.prompt,
+ suffix=args.suffix,
+ best_of=args.best_of,
+ logprobs=args.logprobs,
+ max_tokens=args.max_tokens,
+ temperature=args.temperature,
+ presence_penalty=args.presence_penalty,
+ frequency_penalty=args.frequency_penalty,
+ )
+
+ if args.stream:
+ return CLICompletions._stream_create(
+ # mypy doesn't understand the `partial` function but pyright does
+ cast(Stream[Completion], make_request(stream=True)) # pyright: ignore[reportUnnecessaryCast]
+ )
+
+ return CLICompletions._create(make_request())
+
+ @staticmethod
+ def _create(completion: Completion) -> None:
+ should_print_header = len(completion.choices) > 1
+ for choice in completion.choices:
+ if should_print_header:
+ sys.stdout.write("===== Completion {} =====\n".format(choice.index))
+
+ sys.stdout.write(choice.text)
+
+ if should_print_header or not choice.text.endswith("\n"):
+ sys.stdout.write("\n")
+
+ sys.stdout.flush()
+
+ @staticmethod
+ def _stream_create(stream: Stream[Completion]) -> None:
+ for completion in stream:
+ should_print_header = len(completion.choices) > 1
+ for choice in sorted(completion.choices, key=lambda c: c.index):
+ if should_print_header:
+ sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
+
+ sys.stdout.write(choice.text)
+
+ if should_print_header:
+ sys.stdout.write("\n")
+
+ sys.stdout.flush()
+
+ sys.stdout.write("\n")
diff --git a/src/openai/cli/_api/files.py b/src/openai/cli/_api/files.py
new file mode 100644
index 0000000000..ae6dadf0f1
--- /dev/null
+++ b/src/openai/cli/_api/files.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from argparse import ArgumentParser
+
+from .._utils import get_client, print_model
+from .._models import BaseModel
+from .._progress import BufferReader
+
+if TYPE_CHECKING:
+ from argparse import _SubParsersAction
+
+
+def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
+ sub = subparser.add_parser("files.create")
+
+ sub.add_argument(
+ "-f",
+ "--file",
+ required=True,
+ help="File to upload",
+ )
+ sub.add_argument(
+ "-p",
+ "--purpose",
+ help="Why are you uploading this file? (see https://platform.openai.com/docs/api-reference/ for purposes)",
+ required=True,
+ )
+ sub.set_defaults(func=CLIFile.create, args_model=CLIFileCreateArgs)
+
+ sub = subparser.add_parser("files.retrieve")
+ sub.add_argument("-i", "--id", required=True, help="The files ID")
+ sub.set_defaults(func=CLIFile.get, args_model=CLIFileCreateArgs)
+
+ sub = subparser.add_parser("files.delete")
+ sub.add_argument("-i", "--id", required=True, help="The files ID")
+ sub.set_defaults(func=CLIFile.delete, args_model=CLIFileCreateArgs)
+
+ sub = subparser.add_parser("files.list")
+ sub.set_defaults(func=CLIFile.list)
+
+
+class CLIFileIDArgs(BaseModel):
+ id: str
+
+
+class CLIFileCreateArgs(BaseModel):
+ file: str
+ purpose: str
+
+
+class CLIFile:
+ @staticmethod
+ def create(args: CLIFileCreateArgs) -> None:
+ with open(args.file, "rb") as file_reader:
+ buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
+
+ file = get_client().files.create(file=(args.file, buffer_reader), purpose=args.purpose)
+ print_model(file)
+
+ @staticmethod
+ def get(args: CLIFileIDArgs) -> None:
+ file = get_client().files.retrieve(file_id=args.id)
+ print_model(file)
+
+ @staticmethod
+ def delete(args: CLIFileIDArgs) -> None:
+ file = get_client().files.delete(file_id=args.id)
+ print_model(file)
+
+ @staticmethod
+ def list() -> None:
+ files = get_client().files.list()
+ for file in files:
+ print_model(file)
diff --git a/src/openai/cli/_api/image.py b/src/openai/cli/_api/image.py
new file mode 100644
index 0000000000..e6149eeac4
--- /dev/null
+++ b/src/openai/cli/_api/image.py
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, cast
+from argparse import ArgumentParser
+
+from .._utils import get_client, print_model
+from ..._types import NOT_GIVEN, NotGiven, NotGivenOr
+from .._models import BaseModel
+from .._progress import BufferReader
+
+if TYPE_CHECKING:
+ from argparse import _SubParsersAction
+
+
+def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
+ sub = subparser.add_parser("images.generate")
+ sub.add_argument("-p", "--prompt", type=str, required=True)
+ sub.add_argument("-n", "--num-images", type=int, default=1)
+ sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
+ sub.add_argument("--response-format", type=str, default="url")
+ sub.set_defaults(func=CLIImage.create, args_model=CLIImageCreateArgs)
+
+ sub = subparser.add_parser("images.edit")
+ sub.add_argument("-p", "--prompt", type=str, required=True)
+ sub.add_argument("-n", "--num-images", type=int, default=1)
+ sub.add_argument(
+ "-I",
+ "--image",
+ type=str,
+ required=True,
+ help="Image to modify. Should be a local path and a PNG encoded image.",
+ )
+ sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
+ sub.add_argument("--response-format", type=str, default="url")
+ sub.add_argument(
+ "-M",
+ "--mask",
+ type=str,
+ required=False,
+ help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
+ )
+ sub.set_defaults(func=CLIImage.edit, args_model=CLIImageEditArgs)
+
+ sub = subparser.add_parser("images.create_variation")
+ sub.add_argument("-n", "--num-images", type=int, default=1)
+ sub.add_argument(
+ "-I",
+ "--image",
+ type=str,
+ required=True,
+ help="Image to modify. Should be a local path and a PNG encoded image.",
+ )
+ sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
+ sub.add_argument("--response-format", type=str, default="url")
+ sub.set_defaults(func=CLIImage.create_variation, args_model=CLIImageCreateVariationArgs)
+
+
+class CLIImageCreateArgs(BaseModel):
+ prompt: str
+ num_images: int
+ size: str
+ response_format: str
+
+
+class CLIImageCreateVariationArgs(BaseModel):
+ image: str
+ num_images: int
+ size: str
+ response_format: str
+
+
+class CLIImageEditArgs(BaseModel):
+ image: str
+ num_images: int
+ size: str
+ response_format: str
+ prompt: str
+ mask: NotGivenOr[str] = NOT_GIVEN
+
+
+class CLIImage:
+ @staticmethod
+ def create(args: CLIImageCreateArgs) -> None:
+ image = get_client().images.generate(
+ prompt=args.prompt,
+ n=args.num_images,
+ # casts required because the API is typed for enums
+ # but we don't want to validate that here for forwards-compat
+ size=cast(Any, args.size),
+ response_format=cast(Any, args.response_format),
+ )
+ print_model(image)
+
+ @staticmethod
+ def create_variation(args: CLIImageCreateVariationArgs) -> None:
+ with open(args.image, "rb") as file_reader:
+ buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
+
+ image = get_client().images.create_variation(
+ image=("image", buffer_reader),
+ n=args.num_images,
+ # casts required because the API is typed for enums
+ # but we don't want to validate that here for forwards-compat
+ size=cast(Any, args.size),
+ response_format=cast(Any, args.response_format),
+ )
+ print_model(image)
+
+ @staticmethod
+ def edit(args: CLIImageEditArgs) -> None:
+ with open(args.image, "rb") as file_reader:
+ buffer_reader = BufferReader(file_reader.read(), desc="Image upload progress")
+
+ if isinstance(args.mask, NotGiven):
+ mask: NotGivenOr[BufferReader] = NOT_GIVEN
+ else:
+ with open(args.mask, "rb") as file_reader:
+ mask = BufferReader(file_reader.read(), desc="Mask progress")
+
+ image = get_client().images.edit(
+ prompt=args.prompt,
+ image=("image", buffer_reader),
+ n=args.num_images,
+ mask=("mask", mask) if not isinstance(mask, NotGiven) else mask,
+ # casts required because the API is typed for enums
+ # but we don't want to validate that here for forwards-compat
+ size=cast(Any, args.size),
+ response_format=cast(Any, args.response_format),
+ )
+ print_model(image)
diff --git a/src/openai/cli/_api/models.py b/src/openai/cli/_api/models.py
new file mode 100644
index 0000000000..017218fa6e
--- /dev/null
+++ b/src/openai/cli/_api/models.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from argparse import ArgumentParser
+
+from .._utils import get_client, print_model
+from .._models import BaseModel
+
+if TYPE_CHECKING:
+ from argparse import _SubParsersAction
+
+
+def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
+ sub = subparser.add_parser("models.list")
+ sub.set_defaults(func=CLIModels.list)
+
+ sub = subparser.add_parser("models.retrieve")
+ sub.add_argument("-i", "--id", required=True, help="The model ID")
+ sub.set_defaults(func=CLIModels.get, args_model=CLIModelIDArgs)
+
+ sub = subparser.add_parser("models.delete")
+ sub.add_argument("-i", "--id", required=True, help="The model ID")
+ sub.set_defaults(func=CLIModels.delete, args_model=CLIModelIDArgs)
+
+
+class CLIModelIDArgs(BaseModel):
+ id: str
+
+
+class CLIModels:
+ @staticmethod
+ def get(args: CLIModelIDArgs) -> None:
+ model = get_client().models.retrieve(model=args.id)
+ print_model(model)
+
+ @staticmethod
+ def delete(args: CLIModelIDArgs) -> None:
+ model = get_client().models.delete(model=args.id)
+ print_model(model)
+
+ @staticmethod
+ def list() -> None:
+ models = get_client().models.list()
+ for model in models:
+ print_model(model)
diff --git a/src/openai/cli/_cli.py b/src/openai/cli/_cli.py
new file mode 100644
index 0000000000..72e5c923bd
--- /dev/null
+++ b/src/openai/cli/_cli.py
@@ -0,0 +1,234 @@
+from __future__ import annotations
+
+import sys
+import logging
+import argparse
+from typing import Any, List, Type, Optional
+from typing_extensions import ClassVar
+
+import httpx
+import pydantic
+
+import openai
+
+from . import _tools
+from .. import _ApiType, __version__
+from ._api import register_commands
+from ._utils import can_use_http2
+from .._types import ProxiesDict
+from ._errors import CLIError, display_error
+from .._compat import PYDANTIC_V2, ConfigDict, model_parse
+from .._models import BaseModel
+from .._exceptions import APIError
+
+logger = logging.getLogger()
+formatter = logging.Formatter("[%(asctime)s] %(message)s")
+handler = logging.StreamHandler(sys.stderr)
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+
+
+class Arguments(BaseModel):
+ if PYDANTIC_V2:
+ model_config: ClassVar[ConfigDict] = ConfigDict(
+ extra="ignore",
+ )
+ else:
+
+ class Config(pydantic.BaseConfig): # type: ignore
+ extra: Any = pydantic.Extra.ignore # type: ignore
+
+ verbosity: int
+ version: Optional[str] = None
+
+ api_key: Optional[str]
+ api_base: Optional[str]
+ organization: Optional[str]
+ proxy: Optional[List[str]]
+ api_type: Optional[_ApiType] = None
+ api_version: Optional[str] = None
+
+ # azure
+ azure_endpoint: Optional[str] = None
+ azure_ad_token: Optional[str] = None
+
+ # internal, set by subparsers to parse their specific args
+ args_model: Optional[Type[BaseModel]] = None
+
+ # internal, used so that subparsers can forward unknown arguments
+ unknown_args: List[str] = []
+ allow_unknown_args: bool = False
+
+
+def _build_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(description=None, prog="openai")
+ parser.add_argument(
+ "-v",
+ "--verbose",
+ action="count",
+ dest="verbosity",
+ default=0,
+ help="Set verbosity.",
+ )
+ parser.add_argument("-b", "--api-base", help="What API base url to use.")
+ parser.add_argument("-k", "--api-key", help="What API key to use.")
+ parser.add_argument("-p", "--proxy", nargs="+", help="What proxy to use.")
+ parser.add_argument(
+ "-o",
+ "--organization",
+ help="Which organization to run as (will use your default organization if not specified)",
+ )
+ parser.add_argument(
+ "-t",
+ "--api-type",
+ type=str,
+ choices=("openai", "azure"),
+ help="The backend API to call, must be `openai` or `azure`",
+ )
+ parser.add_argument(
+ "--api-version",
+ help="The Azure API version, e.g. 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'",
+ )
+
+ # azure
+ parser.add_argument(
+ "--azure-endpoint",
+ help="The Azure endpoint, e.g. 'https://endpoint.openai.azure.com'",
+ )
+ parser.add_argument(
+ "--azure-ad-token",
+ help="A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id",
+ )
+
+ # prints the package version
+ parser.add_argument(
+ "-V",
+ "--version",
+ action="version",
+ version="%(prog)s " + __version__,
+ )
+
+ def help() -> None:
+ parser.print_help()
+
+ parser.set_defaults(func=help)
+
+ subparsers = parser.add_subparsers()
+ sub_api = subparsers.add_parser("api", help="Direct API calls")
+
+ register_commands(sub_api)
+
+ sub_tools = subparsers.add_parser("tools", help="Client side tools for convenience")
+ _tools.register_commands(sub_tools, subparsers)
+
+ return parser
+
+
+def main() -> int:
+ try:
+ _main()
+ except (APIError, CLIError, pydantic.ValidationError) as err:
+ display_error(err)
+ return 1
+ except KeyboardInterrupt:
+ sys.stderr.write("\n")
+ return 1
+ return 0
+
+
+def _parse_args(parser: argparse.ArgumentParser) -> tuple[argparse.Namespace, Arguments, list[str]]:
+ # argparse by default will strip out the `--` but we want to keep it for unknown arguments
+ if "--" in sys.argv:
+ idx = sys.argv.index("--")
+ known_args = sys.argv[1:idx]
+ unknown_args = sys.argv[idx:]
+ else:
+ known_args = sys.argv[1:]
+ unknown_args = []
+
+ parsed, remaining_unknown = parser.parse_known_args(known_args)
+
+ # append any remaining unknown arguments from the initial parsing
+ remaining_unknown.extend(unknown_args)
+
+ args = model_parse(Arguments, vars(parsed))
+ if not args.allow_unknown_args:
+ # we have to parse twice to ensure any unknown arguments
+ # result in an error if that behaviour is desired
+ parser.parse_args()
+
+ return parsed, args, remaining_unknown
+
+
+def _main() -> None:
+ parser = _build_parser()
+ parsed, args, unknown = _parse_args(parser)
+
+ if args.verbosity != 0:
+ sys.stderr.write("Warning: --verbosity isn't supported yet\n")
+
+ proxies: ProxiesDict = {}
+ if args.proxy is not None:
+ for proxy in args.proxy:
+ key = "https://" if proxy.startswith("https") else "http://"
+ if key in proxies:
+ raise CLIError(f"Multiple {key} proxies given - only the last one would be used")
+
+ proxies[key] = proxy
+
+ http_client = httpx.Client(
+ proxies=proxies or None,
+ http2=can_use_http2(),
+ )
+ openai.http_client = http_client
+
+ if args.organization:
+ openai.organization = args.organization
+
+ if args.api_key:
+ openai.api_key = args.api_key
+
+ if args.api_base:
+ openai.base_url = args.api_base
+
+ # azure
+ if args.api_type is not None:
+ openai.api_type = args.api_type
+
+ if args.azure_endpoint is not None:
+ openai.azure_endpoint = args.azure_endpoint
+
+ if args.api_version is not None:
+ openai.api_version = args.api_version
+
+ if args.azure_ad_token is not None:
+ openai.azure_ad_token = args.azure_ad_token
+
+ try:
+ if args.args_model:
+ parsed.func(
+ model_parse(
+ args.args_model,
+ {
+ **{
+ # we omit None values so that they can be defaulted to `NotGiven`
+ # and we'll strip it from the API request
+ key: value
+ for key, value in vars(parsed).items()
+ if value is not None
+ },
+ "unknown_args": unknown,
+ },
+ )
+ )
+ else:
+ parsed.func()
+ finally:
+ try:
+ http_client.close()
+ except Exception:
+ pass
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/src/openai/cli/_errors.py b/src/openai/cli/_errors.py
new file mode 100644
index 0000000000..ac2a3780d0
--- /dev/null
+++ b/src/openai/cli/_errors.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+import sys
+
+import pydantic
+
+from ._utils import Colours, organization_info
+from .._exceptions import APIError, OpenAIError
+
+
+class CLIError(OpenAIError):
+ ...
+
+
+class SilentCLIError(CLIError):
+ ...
+
+
+def display_error(err: CLIError | APIError | pydantic.ValidationError) -> None:
+ if isinstance(err, SilentCLIError):
+ return
+
+ sys.stderr.write("{}{}Error:{} {}\n".format(organization_info(), Colours.FAIL, Colours.ENDC, err))
diff --git a/src/openai/cli/_models.py b/src/openai/cli/_models.py
new file mode 100644
index 0000000000..5583db2609
--- /dev/null
+++ b/src/openai/cli/_models.py
@@ -0,0 +1,17 @@
+from typing import Any
+from typing_extensions import ClassVar
+
+import pydantic
+
+from .. import _models
+from .._compat import PYDANTIC_V2, ConfigDict
+
+
+class BaseModel(_models.BaseModel):
+ if PYDANTIC_V2:
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore", arbitrary_types_allowed=True)
+ else:
+
+ class Config(pydantic.BaseConfig): # type: ignore
+ extra: Any = pydantic.Extra.ignore # type: ignore
+ arbitrary_types_allowed: bool = True
diff --git a/src/openai/cli/_progress.py b/src/openai/cli/_progress.py
new file mode 100644
index 0000000000..390aaa9dfe
--- /dev/null
+++ b/src/openai/cli/_progress.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import io
+from typing import Callable
+from typing_extensions import override
+
+
+class CancelledError(Exception):
+ def __init__(self, msg: str) -> None:
+ self.msg = msg
+ super().__init__(msg)
+
+ @override
+ def __str__(self) -> str:
+ return self.msg
+
+ __repr__ = __str__
+
+
+class BufferReader(io.BytesIO):
+ def __init__(self, buf: bytes = b"", desc: str | None = None) -> None:
+ super().__init__(buf)
+ self._len = len(buf)
+ self._progress = 0
+ self._callback = progress(len(buf), desc=desc)
+
+ def __len__(self) -> int:
+ return self._len
+
+ @override
+ def read(self, n: int | None = -1) -> bytes:
+ chunk = io.BytesIO.read(self, n)
+ self._progress += len(chunk)
+
+ try:
+ self._callback(self._progress)
+ except Exception as e: # catches exception from the callback
+ raise CancelledError("The upload was cancelled: {}".format(e))
+
+ return chunk
+
+
+def progress(total: float, desc: str | None) -> Callable[[float], None]:
+ import tqdm
+
+ meter = tqdm.tqdm(total=total, unit_scale=True, desc=desc)
+
+ def incr(progress: float) -> None:
+ meter.n = progress
+ if progress == total:
+ meter.close()
+ else:
+ meter.refresh()
+
+ return incr
+
+
+def MB(i: int) -> int:
+ return int(i // 1024**2)
diff --git a/src/openai/cli/_tools/__init__.py b/src/openai/cli/_tools/__init__.py
new file mode 100644
index 0000000000..56a0260a6d
--- /dev/null
+++ b/src/openai/cli/_tools/__init__.py
@@ -0,0 +1 @@
+from ._main import register_commands as register_commands
diff --git a/src/openai/cli/_tools/_main.py b/src/openai/cli/_tools/_main.py
new file mode 100644
index 0000000000..bd6cda408f
--- /dev/null
+++ b/src/openai/cli/_tools/_main.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from argparse import ArgumentParser
+
+from . import migrate, fine_tunes
+
+if TYPE_CHECKING:
+ from argparse import _SubParsersAction
+
+
+def register_commands(parser: ArgumentParser, subparser: _SubParsersAction[ArgumentParser]) -> None:
+ migrate.register(subparser)
+
+ namespaced = parser.add_subparsers(title="Tools", help="Convenience client side tools")
+
+ fine_tunes.register(namespaced)
diff --git a/src/openai/cli/_tools/fine_tunes.py b/src/openai/cli/_tools/fine_tunes.py
new file mode 100644
index 0000000000..2128b88952
--- /dev/null
+++ b/src/openai/cli/_tools/fine_tunes.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING
+from argparse import ArgumentParser
+
+from .._models import BaseModel
+from ...lib._validators import (
+ get_validators,
+ write_out_file,
+ read_any_format,
+ apply_validators,
+ apply_necessary_remediation,
+)
+
+if TYPE_CHECKING:
+ from argparse import _SubParsersAction
+
+
+def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
+ sub = subparser.add_parser("fine_tunes.prepare_data")
+ sub.add_argument(
+ "-f",
+ "--file",
+ required=True,
+ help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed."
+ "This should be the local file path.",
+ )
+ sub.add_argument(
+ "-q",
+ "--quiet",
+ required=False,
+ action="store_true",
+ help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
+ )
+ sub.set_defaults(func=prepare_data, args_model=PrepareDataArgs)
+
+
+class PrepareDataArgs(BaseModel):
+ file: str
+
+ quiet: bool
+
+
+def prepare_data(args: PrepareDataArgs) -> None:
+ sys.stdout.write("Analyzing...\n")
+ fname = args.file
+ auto_accept = args.quiet
+ df, remediation = read_any_format(fname)
+ apply_necessary_remediation(None, remediation)
+
+ validators = get_validators()
+
+ assert df is not None
+
+ apply_validators(
+ df,
+ fname,
+ remediation,
+ validators,
+ auto_accept,
+ write_out_file_func=write_out_file,
+ )
diff --git a/src/openai/cli/_tools/migrate.py b/src/openai/cli/_tools/migrate.py
new file mode 100644
index 0000000000..714bead8e3
--- /dev/null
+++ b/src/openai/cli/_tools/migrate.py
@@ -0,0 +1,181 @@
+from __future__ import annotations
+
+import os
+import sys
+import json
+import shutil
+import tarfile
+import platform
+import subprocess
+from typing import TYPE_CHECKING, List
+from pathlib import Path
+from argparse import ArgumentParser
+
+import httpx
+
+from .._errors import CLIError, SilentCLIError
+from .._models import BaseModel
+
+if TYPE_CHECKING:
+ from argparse import _SubParsersAction
+
+
+def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
+ sub = subparser.add_parser("migrate")
+ sub.set_defaults(func=migrate, args_model=MigrateArgs, allow_unknown_args=True)
+
+ sub = subparser.add_parser("grit")
+ sub.set_defaults(func=grit, args_model=GritArgs, allow_unknown_args=True)
+
+
+class GritArgs(BaseModel):
+ # internal
+ unknown_args: List[str] = []
+
+
+def grit(args: GritArgs) -> None:
+ grit_path = install()
+
+ try:
+ subprocess.check_call([grit_path, *args.unknown_args])
+ except subprocess.CalledProcessError:
+ # stdout and stderr are forwarded by subprocess so an error will already
+ # have been displayed
+ raise SilentCLIError()
+
+
+class MigrateArgs(BaseModel):
+ # internal
+ unknown_args: List[str] = []
+
+
+def migrate(args: MigrateArgs) -> None:
+ grit_path = install()
+
+ try:
+ subprocess.check_call([grit_path, "apply", "openai", *args.unknown_args])
+ except subprocess.CalledProcessError:
+ # stdout and stderr are forwarded by subprocess so an error will already
+ # have been displayed
+ raise SilentCLIError()
+
+
+# handles downloading the Grit CLI until they provide their own PyPi package
+
+KEYGEN_ACCOUNT = "custodian-dev"
+
+
+def _cache_dir() -> Path:
+ xdg = os.environ.get("XDG_CACHE_HOME")
+ if xdg is not None:
+ return Path(xdg)
+
+ return Path.home() / ".cache"
+
+
+def _debug(message: str) -> None:
+ if not os.environ.get("DEBUG"):
+ return
+
+ sys.stdout.write(f"[DEBUG]: {message}\n")
+
+
+def install() -> Path:
+ """Installs the Grit CLI and returns the location of the binary"""
+ if sys.platform == "win32":
+ raise CLIError("Windows is not supported yet in the migration CLI")
+
+ platform = "macos" if sys.platform == "darwin" else "linux"
+
+ dir_name = _cache_dir() / "openai-python"
+ install_dir = dir_name / ".install"
+ target_dir = install_dir / "bin"
+
+ target_path = target_dir / "marzano"
+ temp_file = target_dir / "marzano.tmp"
+
+ if target_path.exists():
+ _debug(f"{target_path} already exists")
+ sys.stdout.flush()
+ return target_path
+
+ _debug(f"Using Grit CLI path: {target_path}")
+
+ target_dir.mkdir(parents=True, exist_ok=True)
+
+ if temp_file.exists():
+ temp_file.unlink()
+
+ arch = _get_arch()
+ _debug(f"Using architecture {arch}")
+
+ file_name = f"marzano-{platform}-{arch}"
+ meta_url = f"https://api.keygen.sh/v1/accounts/{KEYGEN_ACCOUNT}/artifacts/{file_name}"
+
+ sys.stdout.write(f"Retrieving Grit CLI metadata from {meta_url}\n")
+ with httpx.Client() as client:
+ response = client.get(meta_url) # pyright: ignore[reportUnknownMemberType]
+
+ data = response.json()
+ errors = data.get("errors")
+ if errors:
+ for error in errors:
+ sys.stdout.write(f"{error}\n")
+
+ raise CLIError("Could not locate Grit CLI binary - see above errors")
+
+ write_manifest(install_dir, data["data"]["relationships"]["release"]["data"]["id"])
+
+ link = data["data"]["links"]["redirect"]
+ _debug(f"Redirect URL {link}")
+
+ download_response = client.get(link) # pyright: ignore[reportUnknownMemberType]
+ with open(temp_file, "wb") as file:
+ for chunk in download_response.iter_bytes():
+ file.write(chunk)
+
+ unpacked_dir = target_dir / "cli-bin"
+ unpacked_dir.mkdir(parents=True, exist_ok=True)
+
+ with tarfile.open(temp_file, "r:gz") as archive:
+ archive.extractall(unpacked_dir)
+
+ for item in unpacked_dir.iterdir():
+ item.rename(target_dir / item.name)
+
+ shutil.rmtree(unpacked_dir)
+ os.remove(temp_file)
+ os.chmod(target_path, 0o755)
+
+ sys.stdout.flush()
+
+ return target_path
+
+
+def _get_arch() -> str:
+ architecture = platform.machine().lower()
+
+ # Map the architecture names to Node.js equivalents
+ arch_map = {
+ "x86_64": "x64",
+ "amd64": "x64",
+ "armv7l": "arm",
+ "aarch64": "arm64",
+ }
+
+ return arch_map.get(architecture, architecture)
+
+
+def write_manifest(install_path: Path, release: str) -> None:
+ manifest = {
+ "installPath": str(install_path),
+ "binaries": {
+ "marzano": {
+ "name": "marzano",
+ "release": release,
+ },
+ },
+ }
+ manifest_path = Path(install_path) / "manifests.json"
+ with open(manifest_path, "w") as f:
+ json.dump(manifest, f, indent=2)
diff --git a/src/openai/cli/_utils.py b/src/openai/cli/_utils.py
new file mode 100644
index 0000000000..027ab08de3
--- /dev/null
+++ b/src/openai/cli/_utils.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+import sys
+
+import openai
+
+from .. import OpenAI, _load_client
+from .._compat import model_json
+from .._models import BaseModel
+
+
+class Colours:
+ HEADER = "\033[95m"
+ OKBLUE = "\033[94m"
+ OKGREEN = "\033[92m"
+ WARNING = "\033[93m"
+ FAIL = "\033[91m"
+ ENDC = "\033[0m"
+ BOLD = "\033[1m"
+ UNDERLINE = "\033[4m"
+
+
+def get_client() -> OpenAI:
+ return _load_client()
+
+
+def organization_info() -> str:
+ organization = openai.organization
+ if organization is not None:
+ return "[organization={}] ".format(organization)
+
+ return ""
+
+
+def print_model(model: BaseModel) -> None:
+ sys.stdout.write(model_json(model, indent=2) + "\n")
+
+
+def can_use_http2() -> bool:
+ try:
+ import h2 # type: ignore # noqa
+ except ImportError:
+ return False
+
+ return True
diff --git a/openai/validators.py b/src/openai/lib/_validators.py
similarity index 80%
rename from openai/validators.py
rename to src/openai/lib/_validators.py
index 078179a44b..8e4ed3c9f4 100644
--- a/openai/validators.py
+++ b/src/openai/lib/_validators.py
@@ -1,9 +1,12 @@
+# pyright: basic
+from __future__ import annotations
+
import os
import sys
-from typing import Any, Callable, NamedTuple, Optional
+from typing import Any, TypeVar, Callable, Optional, NamedTuple
+from typing_extensions import TypeAlias
-from openai.datalib.pandas_helper import assert_has_pandas
-from openai.datalib.pandas_helper import pandas as pd
+from .._extras import pandas as pd
class Remediation(NamedTuple):
@@ -16,7 +19,10 @@ class Remediation(NamedTuple):
error_msg: Optional[str] = None
-def num_examples_validator(df):
+OptionalDataFrameT = TypeVar("OptionalDataFrameT", bound="Optional[pd.DataFrame]")
+
+
+def num_examples_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will only print out the number of examples and recommend to the user to increase the number of examples if less than 100.
"""
@@ -26,18 +32,16 @@ def num_examples_validator(df):
if len(df) >= MIN_EXAMPLES
else ". In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples"
)
- immediate_msg = (
- f"\n- Your file contains {len(df)} prompt-completion pairs{optional_suggestion}"
- )
+ immediate_msg = f"\n- Your file contains {len(df)} prompt-completion pairs{optional_suggestion}"
return Remediation(name="num_examples", immediate_msg=immediate_msg)
-def necessary_column_validator(df, necessary_column):
+def necessary_column_validator(df: pd.DataFrame, necessary_column: str) -> Remediation:
"""
This validator will ensure that the necessary column is present in the dataframe.
"""
- def lower_case_column(df, column):
+ def lower_case_column(df: pd.DataFrame, column: Any) -> pd.DataFrame:
cols = [c for c in df.columns if str(c).lower() == column]
df.rename(columns={cols[0]: column.lower()}, inplace=True)
return df
@@ -50,13 +54,11 @@ def lower_case_column(df, column):
if necessary_column not in df.columns:
if necessary_column in [str(c).lower() for c in df.columns]:
- def lower_case_column_creator(df):
+ def lower_case_column_creator(df: pd.DataFrame) -> pd.DataFrame:
return lower_case_column(df, necessary_column)
necessary_fn = lower_case_column_creator
- immediate_msg = (
- f"\n- The `{necessary_column}` column/key should be lowercase"
- )
+ immediate_msg = f"\n- The `{necessary_column}` column/key should be lowercase"
necessary_msg = f"Lower case column name to `{necessary_column}`"
else:
error_msg = f"`{necessary_column}` column/key is missing. Please make sure you name your columns/keys appropriately, then retry"
@@ -70,14 +72,15 @@ def lower_case_column_creator(df):
)
-def additional_column_validator(df, fields=["prompt", "completion"]):
+def additional_column_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
"""
This validator will remove additional columns from the dataframe.
"""
additional_columns = []
necessary_msg = None
immediate_msg = None
- necessary_fn = None
+ necessary_fn = None # type: ignore
+
if len(df.columns) > 2:
additional_columns = [c for c in df.columns if c not in fields]
warn_message = ""
@@ -88,7 +91,7 @@ def additional_column_validator(df, fields=["prompt", "completion"]):
immediate_msg = f"\n- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: {additional_columns}{warn_message}"
necessary_msg = f"Remove additional columns/keys: {additional_columns}"
- def necessary_fn(x):
+ def necessary_fn(x: Any) -> Any:
return x[fields]
return Remediation(
@@ -99,12 +102,12 @@ def necessary_fn(x):
)
-def non_empty_field_validator(df, field="completion"):
+def non_empty_field_validator(df: pd.DataFrame, field: str = "completion") -> Remediation:
"""
This validator will ensure that no completion is empty.
"""
necessary_msg = None
- necessary_fn = None
+ necessary_fn = None # type: ignore
immediate_msg = None
if df[field].apply(lambda x: x == "").any() or df[field].isnull().any():
@@ -112,10 +115,11 @@ def non_empty_field_validator(df, field="completion"):
empty_indexes = df.reset_index().index[empty_rows].tolist()
immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}"
- def necessary_fn(x):
+ def necessary_fn(x: Any) -> Any:
return x[x[field] != ""].dropna(subset=[field])
necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s"
+
return Remediation(
name=f"empty_{field}",
immediate_msg=immediate_msg,
@@ -124,7 +128,7 @@ def necessary_fn(x):
)
-def duplicated_rows_validator(df, fields=["prompt", "completion"]):
+def duplicated_rows_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
"""
This validator will suggest to the user to remove duplicate rows if they exist.
"""
@@ -132,13 +136,13 @@ def duplicated_rows_validator(df, fields=["prompt", "completion"]):
duplicated_indexes = df.reset_index().index[duplicated_rows].tolist()
immediate_msg = None
optional_msg = None
- optional_fn = None
+ optional_fn = None # type: ignore
if len(duplicated_indexes) > 0:
immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}"
optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows"
- def optional_fn(x):
+ def optional_fn(x: Any) -> Any:
return x.drop_duplicates(subset=fields)
return Remediation(
@@ -149,21 +153,19 @@ def optional_fn(x):
)
-def long_examples_validator(df):
+def long_examples_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to the user to remove examples that are too long.
"""
immediate_msg = None
optional_msg = None
- optional_fn = None
+ optional_fn = None # type: ignore
ft_type = infer_task_type(df)
if ft_type != "open-ended generation":
- def get_long_indexes(d):
- long_examples = d.apply(
- lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1
- )
+ def get_long_indexes(d: pd.DataFrame) -> Any:
+ long_examples = d.apply(lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1)
return d.reset_index().index[long_examples].tolist()
long_indexes = get_long_indexes(df)
@@ -172,8 +174,7 @@ def get_long_indexes(d):
immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens."
optional_msg = f"Remove {len(long_indexes)} long examples"
- def optional_fn(x):
-
+ def optional_fn(x: Any) -> Any:
long_indexes_to_drop = get_long_indexes(x)
if long_indexes != long_indexes_to_drop:
sys.stdout.write(
@@ -189,14 +190,14 @@ def optional_fn(x):
)
-def common_prompt_suffix_validator(df):
+def common_prompt_suffix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to add a common suffix to the prompt if one doesn't already exist in case of classification or conditional generation.
"""
error_msg = None
immediate_msg = None
optional_msg = None
- optional_fn = None
+ optional_fn = None # type: ignore
# Find a suffix which is not contained within the prompt otherwise
suggested_suffix = "\n\n### =>\n\n"
@@ -222,7 +223,7 @@ def common_prompt_suffix_validator(df):
if ft_type == "open-ended generation":
return Remediation(name="common_suffix")
- def add_suffix(x, suffix):
+ def add_suffix(x: Any, suffix: Any) -> Any:
x["prompt"] += suffix
return x
@@ -233,27 +234,19 @@ def add_suffix(x, suffix):
if common_suffix != "":
common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
- immediate_msg = (
- f"\n- All prompts end with suffix `{common_suffix_new_line_handled}`"
- )
+ immediate_msg = f"\n- All prompts end with suffix `{common_suffix_new_line_handled}`"
if len(common_suffix) > 10:
immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
- if (
- df.prompt.str[: -len(common_suffix)]
- .str.contains(common_suffix, regex=False)
- .any()
- ):
+ if df.prompt.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"
else:
immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty"
if common_suffix == "":
- optional_msg = (
- f"Add a suffix separator `{display_suggested_suffix}` to all prompts"
- )
+ optional_msg = f"Add a suffix separator `{display_suggested_suffix}` to all prompts"
- def optional_fn(x):
+ def optional_fn(x: Any) -> Any:
return add_suffix(x, suggested_suffix)
return Remediation(
@@ -265,7 +258,7 @@ def optional_fn(x):
)
-def common_prompt_prefix_validator(df):
+def common_prompt_prefix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to remove a common prefix from the prompt if a long one exist.
"""
@@ -273,13 +266,13 @@ def common_prompt_prefix_validator(df):
immediate_msg = None
optional_msg = None
- optional_fn = None
+ optional_fn = None # type: ignore
common_prefix = get_common_xfix(df.prompt, xfix="prefix")
if common_prefix == "":
return Remediation(name="common_prefix")
- def remove_common_prefix(x, prefix):
+ def remove_common_prefix(x: Any, prefix: Any) -> Any:
x["prompt"] = x["prompt"].str[len(prefix) :]
return x
@@ -293,7 +286,7 @@ def remove_common_prefix(x, prefix):
immediate_msg += ". Fine-tuning doesn't require the instruction specifying the task, or a few-shot example scenario. Most of the time you should only add the input data into the prompt, and the desired output into the completion"
optional_msg = f"Remove prefix `{common_prefix}` from all prompts"
- def optional_fn(x):
+ def optional_fn(x: Any) -> Any:
return remove_common_prefix(x, common_prefix)
return Remediation(
@@ -304,7 +297,7 @@ def optional_fn(x):
)
-def common_completion_prefix_validator(df):
+def common_completion_prefix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to remove a common prefix from the completion if a long one exist.
"""
@@ -315,7 +308,7 @@ def common_completion_prefix_validator(df):
if len(common_prefix) < MAX_PREFIX_LEN:
return Remediation(name="common_prefix")
- def remove_common_prefix(x, prefix, ws_prefix):
+ def remove_common_prefix(x: Any, prefix: Any, ws_prefix: Any) -> Any:
x["completion"] = x["completion"].str[len(prefix) :]
if ws_prefix:
# keep the single whitespace as prefix
@@ -329,7 +322,7 @@ def remove_common_prefix(x, prefix, ws_prefix):
immediate_msg = f"\n- All completions start with prefix `{common_prefix}`. Most of the time you should only add the output data into the completion, without any prefix"
optional_msg = f"Remove prefix `{common_prefix}` from all completions"
- def optional_fn(x):
+ def optional_fn(x: Any) -> Any:
return remove_common_prefix(x, common_prefix, ws_prefix)
return Remediation(
@@ -340,14 +333,14 @@ def optional_fn(x):
)
-def common_completion_suffix_validator(df):
+def common_completion_suffix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.
"""
error_msg = None
immediate_msg = None
optional_msg = None
- optional_fn = None
+ optional_fn = None # type: ignore
ft_type = infer_task_type(df)
if ft_type == "open-ended generation" or ft_type == "classification":
@@ -378,33 +371,25 @@ def common_completion_suffix_validator(df):
break
display_suggested_suffix = suggested_suffix.replace("\n", "\\n")
- def add_suffix(x, suffix):
+ def add_suffix(x: Any, suffix: Any) -> Any:
x["completion"] += suffix
return x
if common_suffix != "":
common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
- immediate_msg = (
- f"\n- All completions end with suffix `{common_suffix_new_line_handled}`"
- )
+ immediate_msg = f"\n- All completions end with suffix `{common_suffix_new_line_handled}`"
if len(common_suffix) > 10:
immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
- if (
- df.completion.str[: -len(common_suffix)]
- .str.contains(common_suffix, regex=False)
- .any()
- ):
+ if df.completion.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending"
else:
immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples."
if common_suffix == "":
- optional_msg = (
- f"Add a suffix ending `{display_suggested_suffix}` to all completions"
- )
+ optional_msg = f"Add a suffix ending `{display_suggested_suffix}` to all completions"
- def optional_fn(x):
+ def optional_fn(x: Any) -> Any:
return add_suffix(x, suggested_suffix)
return Remediation(
@@ -416,15 +401,13 @@ def optional_fn(x):
)
-def completions_space_start_validator(df):
+def completions_space_start_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to add a space at the start of the completion if it doesn't already exist. This helps with tokenization.
"""
- def add_space_start(x):
- x["completion"] = x["completion"].apply(
- lambda x: ("" if x[0] == " " else " ") + x
- )
+ def add_space_start(x: Any) -> Any:
+ x["completion"] = x["completion"].apply(lambda x: ("" if x[0] == " " else " ") + x)
return x
optional_msg = None
@@ -443,25 +426,17 @@ def add_space_start(x):
)
-def lower_case_validator(df, column):
+def lower_case_validator(df: pd.DataFrame, column: Any) -> Remediation | None:
"""
This validator will suggest to lowercase the column values, if more than a third of letters are uppercase.
"""
- def lower_case(x):
+ def lower_case(x: Any) -> Any:
x[column] = x[column].str.lower()
return x
- count_upper = (
- df[column]
- .apply(lambda x: sum(1 for c in x if c.isalpha() and c.isupper()))
- .sum()
- )
- count_lower = (
- df[column]
- .apply(lambda x: sum(1 for c in x if c.isalpha() and c.islower()))
- .sum()
- )
+ count_upper = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.isupper())).sum()
+ count_lower = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.islower())).sum()
if count_upper * 2 > count_lower:
return Remediation(
@@ -470,15 +445,17 @@ def lower_case(x):
optional_msg=f"Lowercase all your data in column/key `{column}`",
optional_fn=lower_case,
)
+ return None
-def read_any_format(fname, fields=["prompt", "completion"]):
+def read_any_format(
+ fname: str, fields: list[str] = ["prompt", "completion"]
+) -> tuple[pd.DataFrame | None, Remediation]:
"""
This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
- for .xlsx it will read the first sheet
- for .txt it will assume completions and split on newline
"""
- assert_has_pandas()
remediation = None
necessary_msg = None
immediate_msg = None
@@ -488,13 +465,11 @@ def read_any_format(fname, fields=["prompt", "completion"]):
if os.path.isfile(fname):
try:
if fname.lower().endswith(".csv") or fname.lower().endswith(".tsv"):
- file_extension_str, separator = (
- ("CSV", ",") if fname.lower().endswith(".csv") else ("TSV", "\t")
- )
- immediate_msg = f"\n- Based on your file extension, your file is formatted as a {file_extension_str} file"
- necessary_msg = (
- f"Your format `{file_extension_str}` will be converted to `JSONL`"
+ file_extension_str, separator = ("CSV", ",") if fname.lower().endswith(".csv") else ("TSV", "\t")
+ immediate_msg = (
+ f"\n- Based on your file extension, your file is formatted as a {file_extension_str} file"
)
+ necessary_msg = f"Your format `{file_extension_str}` will be converted to `JSONL`"
df = pd.read_csv(fname, sep=separator, dtype=str).fillna("")
elif fname.lower().endswith(".xlsx"):
immediate_msg = "\n- Based on your file extension, your file is formatted as an Excel file"
@@ -505,9 +480,7 @@ def read_any_format(fname, fields=["prompt", "completion"]):
immediate_msg += "\n- Your Excel file contains more than one sheet. Please either save as csv or ensure all data is present in the first sheet. WARNING: Reading only the first sheet..."
df = pd.read_excel(fname, dtype=str).fillna("")
elif fname.lower().endswith(".txt"):
- immediate_msg = (
- "\n- Based on your file extension, you provided a text file"
- )
+ immediate_msg = "\n- Based on your file extension, you provided a text file"
necessary_msg = "Your format `TXT` will be converted to `JSONL`"
with open(fname, "r") as f:
content = f.read()
@@ -517,32 +490,32 @@ def read_any_format(fname, fields=["prompt", "completion"]):
dtype=str,
).fillna("")
elif fname.lower().endswith(".jsonl"):
- df = pd.read_json(fname, lines=True, dtype=str).fillna("")
- if len(df) == 1:
+ df = pd.read_json(fname, lines=True, dtype=str).fillna("") # type: ignore
+ if len(df) == 1: # type: ignore
# this is NOT what we expect for a .jsonl file
immediate_msg = "\n- Your JSONL file appears to be in a JSON format. Your file will be converted to JSONL format"
necessary_msg = "Your format `JSON` will be converted to `JSONL`"
- df = pd.read_json(fname, dtype=str).fillna("")
+ df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
else:
pass # this is what we expect for a .jsonl file
elif fname.lower().endswith(".json"):
try:
# to handle case where .json file is actually a .jsonl file
- df = pd.read_json(fname, lines=True, dtype=str).fillna("")
- if len(df) == 1:
+ df = pd.read_json(fname, lines=True, dtype=str).fillna("") # type: ignore
+ if len(df) == 1: # type: ignore
# this code path corresponds to a .json file that has one line
- df = pd.read_json(fname, dtype=str).fillna("")
+ df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
else:
# this is NOT what we expect for a .json file
immediate_msg = "\n- Your JSON file appears to be in a JSONL format. Your file will be converted to JSONL format"
- necessary_msg = (
- "Your format `JSON` will be converted to `JSONL`"
- )
+ necessary_msg = "Your format `JSON` will be converted to `JSONL`"
except ValueError:
# this code path corresponds to a .json file that has multiple lines (i.e. it is indented)
- df = pd.read_json(fname, dtype=str).fillna("")
+ df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
else:
- error_msg = "Your file must have one of the following extensions: .CSV, .TSV, .XLSX, .TXT, .JSON or .JSONL"
+ error_msg = (
+ "Your file must have one of the following extensions: .CSV, .TSV, .XLSX, .TXT, .JSON or .JSONL"
+ )
if "." in fname:
error_msg += f" Your file `{fname}` ends with the extension `.{fname.split('.')[-1]}` which is not supported."
else:
@@ -564,7 +537,7 @@ def read_any_format(fname, fields=["prompt", "completion"]):
return df, remediation
-def format_inferrer_validator(df):
+def format_inferrer_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.
It will also suggest to use ada and explain train/validation split benefits.
@@ -576,14 +549,12 @@ def format_inferrer_validator(df):
return Remediation(name="num_examples", immediate_msg=immediate_msg)
-def apply_necessary_remediation(df, remediation):
+def apply_necessary_remediation(df: OptionalDataFrameT, remediation: Remediation) -> OptionalDataFrameT:
"""
This function will apply a necessary remediation to a dataframe, or print an error message if one exists.
"""
if remediation.error_msg is not None:
- sys.stderr.write(
- f"\n\nERROR in {remediation.name} validator: {remediation.error_msg}\n\nAborting..."
- )
+ sys.stderr.write(f"\n\nERROR in {remediation.name} validator: {remediation.error_msg}\n\nAborting...")
sys.exit(1)
if remediation.immediate_msg is not None:
sys.stdout.write(remediation.immediate_msg)
@@ -592,7 +563,7 @@ def apply_necessary_remediation(df, remediation):
return df
-def accept_suggestion(input_text, auto_accept):
+def accept_suggestion(input_text: str, auto_accept: bool) -> bool:
sys.stdout.write(input_text)
if auto_accept:
sys.stdout.write("Y\n")
@@ -600,7 +571,9 @@ def accept_suggestion(input_text, auto_accept):
return input().lower() != "n"
-def apply_optional_remediation(df, remediation, auto_accept):
+def apply_optional_remediation(
+ df: pd.DataFrame, remediation: Remediation, auto_accept: bool
+) -> tuple[pd.DataFrame, bool]:
"""
This function will apply an optional remediation to a dataframe, based on the user input.
"""
@@ -608,6 +581,7 @@ def apply_optional_remediation(df, remediation, auto_accept):
input_text = f"- [Recommended] {remediation.optional_msg} [Y/n]: "
if remediation.optional_msg is not None:
if accept_suggestion(input_text, auto_accept):
+ assert remediation.optional_fn is not None
df = remediation.optional_fn(df)
optional_applied = True
if remediation.necessary_msg is not None:
@@ -615,7 +589,7 @@ def apply_optional_remediation(df, remediation, auto_accept):
return df, optional_applied
-def estimate_fine_tuning_time(df):
+def estimate_fine_tuning_time(df: pd.DataFrame) -> None:
"""
Estimate the time it'll take to fine-tune the dataset
"""
@@ -628,7 +602,7 @@ def estimate_fine_tuning_time(df):
size = df.memory_usage(index=True).sum()
expected_time = size * 0.0515
- def format_time(time):
+ def format_time(time: float) -> str:
if time < 60:
return f"{round(time, 2)} seconds"
elif time < 3600:
@@ -644,21 +618,20 @@ def format_time(time):
)
-def get_outfnames(fname, split):
+def get_outfnames(fname: str, split: bool) -> list[str]:
suffixes = ["_train", "_valid"] if split else [""]
i = 0
while True:
index_suffix = f" ({i})" if i > 0 else ""
candidate_fnames = [
- os.path.splitext(fname)[0] + "_prepared" + suffix + index_suffix + ".jsonl"
- for suffix in suffixes
+ os.path.splitext(fname)[0] + "_prepared" + suffix + index_suffix + ".jsonl" for suffix in suffixes
]
if not any(os.path.isfile(f) for f in candidate_fnames):
return candidate_fnames
i += 1
-def get_classification_hyperparams(df):
+def get_classification_hyperparams(df: pd.DataFrame) -> tuple[int, object]:
n_classes = df.completion.nunique()
pos_class = None
if n_classes == 2:
@@ -666,7 +639,7 @@ def get_classification_hyperparams(df):
return n_classes, pos_class
-def write_out_file(df, fname, any_remediations, auto_accept):
+def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_accept: bool) -> None:
"""
This function will write out a dataframe to a file, if the user would like to proceed, and also offer a fine-tuning command with the newly created file.
For classification it will optionally ask the user if they would like to split the data into train/valid files, and modify the suggested command to include the valid set.
@@ -683,9 +656,7 @@ def write_out_file(df, fname, any_remediations, auto_accept):
additional_params = ""
common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n")
- common_completion_suffix_new_line_handled = common_completion_suffix.replace(
- "\n", "\\n"
- )
+ common_completion_suffix_new_line_handled = common_completion_suffix.replace("\n", "\\n")
optional_ending_string = (
f' Make sure to include `stop=["{common_completion_suffix_new_line_handled}"]` so that the generated texts ends at the expected place.'
if len(common_completion_suffix_new_line_handled) > 0
@@ -708,12 +679,10 @@ def write_out_file(df, fname, any_remediations, auto_accept):
n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))
df_train = df.sample(n=n_train, random_state=42)
df_valid = df.drop(df_train.index)
- df_train[["prompt", "completion"]].to_json(
+ df_train[["prompt", "completion"]].to_json( # type: ignore
fnames[0], lines=True, orient="records", force_ascii=False
)
- df_valid[["prompt", "completion"]].to_json(
- fnames[1], lines=True, orient="records", force_ascii=False
- )
+ df_valid[["prompt", "completion"]].to_json(fnames[1], lines=True, orient="records", force_ascii=False)
n_classes, pos_class = get_classification_hyperparams(df)
additional_params += " --compute_classification_metrics"
@@ -723,9 +692,7 @@ def write_out_file(df, fname, any_remediations, auto_accept):
additional_params += f" --classification_n_classes {n_classes}"
else:
assert len(fnames) == 1
- df[["prompt", "completion"]].to_json(
- fnames[0], lines=True, orient="records", force_ascii=False
- )
+ df[["prompt", "completion"]].to_json(fnames[0], lines=True, orient="records", force_ascii=False)
# Add -v VALID_FILE if we split the file into train / valid
files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames))
@@ -743,7 +710,7 @@ def write_out_file(df, fname, any_remediations, auto_accept):
sys.stdout.write("Aborting... did not write the file\n")
-def infer_task_type(df):
+def infer_task_type(df: pd.DataFrame) -> str:
"""
Infer the likely fine-tuning task type from the data
"""
@@ -757,31 +724,28 @@ def infer_task_type(df):
return "conditional generation"
-def get_common_xfix(series, xfix="suffix"):
+def get_common_xfix(series: Any, xfix: str = "suffix") -> str:
"""
Finds the longest common suffix or prefix of all the values in a series
"""
common_xfix = ""
while True:
common_xfixes = (
- series.str[-(len(common_xfix) + 1) :]
- if xfix == "suffix"
- else series.str[: len(common_xfix) + 1]
+ series.str[-(len(common_xfix) + 1) :] if xfix == "suffix" else series.str[: len(common_xfix) + 1]
) # first few or last few characters
- if (
- common_xfixes.nunique() != 1
- ): # we found the character at which we don't have a unique xfix anymore
+ if common_xfixes.nunique() != 1: # we found the character at which we don't have a unique xfix anymore
break
- elif (
- common_xfix == common_xfixes.values[0]
- ): # the entire first row is a prefix of every other row
+ elif common_xfix == common_xfixes.values[0]: # the entire first row is a prefix of every other row
break
else: # the first or last few characters are still common across all rows - let's try to add one more
common_xfix = common_xfixes.values[0]
return common_xfix
-def get_validators():
+Validator: TypeAlias = "Callable[[pd.DataFrame], Remediation | None]"
+
+
+def get_validators() -> list[Validator]:
return [
num_examples_validator,
lambda x: necessary_column_validator(x, "prompt"),
@@ -802,14 +766,14 @@ def get_validators():
def apply_validators(
- df,
- fname,
- remediation,
- validators,
- auto_accept,
- write_out_file_func,
-):
- optional_remediations = []
+ df: pd.DataFrame,
+ fname: str,
+ remediation: Remediation | None,
+ validators: list[Validator],
+ auto_accept: bool,
+ write_out_file_func: Callable[..., Any],
+) -> None:
+ optional_remediations: list[Remediation] = []
if remediation is not None:
optional_remediations.append(remediation)
for validator in validators:
@@ -822,27 +786,18 @@ def apply_validators(
[
remediation
for remediation in optional_remediations
- if remediation.optional_msg is not None
- or remediation.necessary_msg is not None
+ if remediation.optional_msg is not None or remediation.necessary_msg is not None
]
)
any_necessary_applied = any(
- [
- remediation
- for remediation in optional_remediations
- if remediation.necessary_msg is not None
- ]
+ [remediation for remediation in optional_remediations if remediation.necessary_msg is not None]
)
any_optional_applied = False
if any_optional_or_necessary_remediations:
- sys.stdout.write(
- "\n\nBased on the analysis we will perform the following actions:\n"
- )
+ sys.stdout.write("\n\nBased on the analysis we will perform the following actions:\n")
for remediation in optional_remediations:
- df, optional_applied = apply_optional_remediation(
- df, remediation, auto_accept
- )
+ df, optional_applied = apply_optional_remediation(df, remediation, auto_accept)
any_optional_applied = any_optional_applied or optional_applied
else:
sys.stdout.write("\n\nNo remediations found.\n")
diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py
new file mode 100644
index 0000000000..f5fcd24fd1
--- /dev/null
+++ b/src/openai/lib/azure.py
@@ -0,0 +1,439 @@
+from __future__ import annotations
+
+import os
+import inspect
+from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload
+from typing_extensions import override
+
+import httpx
+
+from .._types import NOT_GIVEN, Omit, Timeout, NotGiven
+from .._utils import is_given, is_mapping
+from .._client import OpenAI, AsyncOpenAI
+from .._models import FinalRequestOptions
+from .._streaming import Stream, AsyncStream
+from .._exceptions import OpenAIError
+from .._base_client import DEFAULT_MAX_RETRIES, BaseClient
+
+_deployments_endpoints = set(
+ [
+ "/completions",
+ "/chat/completions",
+ "/embeddings",
+ "/audio/transcriptions",
+ "/audio/translations",
+ ]
+)
+
+
+AzureADTokenProvider = Callable[[], str]
+AsyncAzureADTokenProvider = Callable[[], "str | Awaitable[str]"]
+_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
+_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])
+
+
+# we need to use a sentinel API key value for Azure AD
+# as we don't want to make the `api_key` in the main client Optional
+# and Azure AD tokens may be retrieved on a per-request basis
+API_KEY_SENTINEL = "".join(["<", "missing API key", ">"])
+
+
+class MutuallyExclusiveAuthError(OpenAIError):
+ def __init__(self) -> None:
+ super().__init__(
+ "The `api_key`, `azure_ad_token` and `azure_ad_token_provider` arguments are mutually exclusive; Only one can be passed at a time"
+ )
+
+
+class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
+ @override
+ def _build_request(
+ self,
+ options: FinalRequestOptions,
+ ) -> httpx.Request:
+ if options.url in _deployments_endpoints and is_mapping(options.json_data):
+ model = options.json_data.get("model")
+ if model is not None and not "/deployments" in str(self.base_url):
+ options.url = f"/deployments/{model}{options.url}"
+
+ return super()._build_request(options)
+
+
+class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
+ @overload
+ def __init__(
+ self,
+ *,
+ azure_endpoint: str,
+ azure_deployment: str | None = None,
+ api_version: str | None = None,
+ api_key: str | None = None,
+ azure_ad_token: str | None = None,
+ azure_ad_token_provider: AzureADTokenProvider | None = None,
+ organization: str | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ http_client: httpx.Client | None = None,
+ _strict_response_validation: bool = False,
+ ) -> None:
+ ...
+
+ @overload
+ def __init__(
+ self,
+ *,
+ azure_deployment: str | None = None,
+ api_version: str | None = None,
+ api_key: str | None = None,
+ azure_ad_token: str | None = None,
+ azure_ad_token_provider: AzureADTokenProvider | None = None,
+ organization: str | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ http_client: httpx.Client | None = None,
+ _strict_response_validation: bool = False,
+ ) -> None:
+ ...
+
+ @overload
+ def __init__(
+ self,
+ *,
+ base_url: str,
+ api_version: str | None = None,
+ api_key: str | None = None,
+ azure_ad_token: str | None = None,
+ azure_ad_token_provider: AzureADTokenProvider | None = None,
+ organization: str | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ http_client: httpx.Client | None = None,
+ _strict_response_validation: bool = False,
+ ) -> None:
+ ...
+
+ def __init__(
+ self,
+ *,
+ api_version: str | None = None,
+ azure_endpoint: str | None = None,
+ azure_deployment: str | None = None,
+ api_key: str | None = None,
+ azure_ad_token: str | None = None,
+ azure_ad_token_provider: AzureADTokenProvider | None = None,
+ organization: str | None = None,
+ base_url: str | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ http_client: httpx.Client | None = None,
+ _strict_response_validation: bool = False,
+ ) -> None:
+ """Construct a new synchronous azure openai client instance.
+
+ This automatically infers the following arguments from their corresponding environment variables if they are not provided:
+ - `api_key` from `AZURE_OPENAI_API_KEY`
+ - `organization` from `OPENAI_ORG_ID`
+ - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
+ - `api_version` from `OPENAI_API_VERSION`
+ - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
+
+ Args:
+ azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`
+
+ azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
+
+ azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.
+
+ azure_deployment: A model deployment, if given sets the base client URL to include `/deployments/{azure_deployment}`.
+ Note: this means you won't be able to use non-deployment endpoints.
+ """
+ if api_key is None:
+ api_key = os.environ.get("AZURE_OPENAI_API_KEY")
+
+ if azure_ad_token is None:
+ azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")
+
+ if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
+ raise OpenAIError(
+ "Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
+ )
+
+ if api_version is None:
+ api_version = os.environ.get("OPENAI_API_VERSION")
+
+ if api_version is None:
+ raise ValueError(
+ "Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
+ )
+
+ if default_query is None:
+ default_query = {"api-version": api_version}
+ else:
+ default_query = {"api-version": api_version, **default_query}
+
+ if base_url is None:
+ if azure_endpoint is None:
+ azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
+
+ if azure_endpoint is None:
+ raise ValueError(
+ "Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
+ )
+
+ if azure_deployment is not None:
+ base_url = f"{azure_endpoint}/openai/deployments/{azure_deployment}"
+ else:
+ base_url = f"{azure_endpoint}/openai"
+ else:
+ if azure_endpoint is not None:
+ raise ValueError("base_url and azure_endpoint are mutually exclusive")
+
+ if api_key is None:
+ # define a sentinel value to avoid any typing issues
+ api_key = API_KEY_SENTINEL
+
+ super().__init__(
+ api_key=api_key,
+ organization=organization,
+ base_url=base_url,
+ timeout=timeout,
+ max_retries=max_retries,
+ default_headers=default_headers,
+ default_query=default_query,
+ http_client=http_client,
+ _strict_response_validation=_strict_response_validation,
+ )
+ self._azure_ad_token = azure_ad_token
+ self._azure_ad_token_provider = azure_ad_token_provider
+
+ def _get_azure_ad_token(self) -> str | None:
+ if self._azure_ad_token is not None:
+ return self._azure_ad_token
+
+ provider = self._azure_ad_token_provider
+ if provider is not None:
+ token = provider()
+ if not token or not isinstance(token, str): # pyright: ignore[reportUnnecessaryIsInstance]
+ raise ValueError(
+ f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
+ )
+ return token
+
+ return None
+
+ @override
+ def _prepare_options(self, options: FinalRequestOptions) -> None:
+ headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
+ options.headers = headers
+
+ azure_ad_token = self._get_azure_ad_token()
+ if azure_ad_token is not None:
+ if headers.get("Authorization") is None:
+ headers["Authorization"] = f"Bearer {azure_ad_token}"
+ elif self.api_key is not API_KEY_SENTINEL:
+ if headers.get("api-key") is None:
+ headers["api-key"] = self.api_key
+ else:
+ # should never be hit
+ raise ValueError("Unable to handle auth")
+
+ return super()._prepare_options(options)
+
+
+class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
+ @overload
+ def __init__(
+ self,
+ *,
+ azure_endpoint: str,
+ azure_deployment: str | None = None,
+ api_version: str | None = None,
+ api_key: str | None = None,
+ azure_ad_token: str | None = None,
+ azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
+ organization: str | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ http_client: httpx.AsyncClient | None = None,
+ _strict_response_validation: bool = False,
+ ) -> None:
+ ...
+
+ @overload
+ def __init__(
+ self,
+ *,
+ azure_deployment: str | None = None,
+ api_version: str | None = None,
+ api_key: str | None = None,
+ azure_ad_token: str | None = None,
+ azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
+ organization: str | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ http_client: httpx.AsyncClient | None = None,
+ _strict_response_validation: bool = False,
+ ) -> None:
+ ...
+
+ @overload
+ def __init__(
+ self,
+ *,
+ base_url: str,
+ api_version: str | None = None,
+ api_key: str | None = None,
+ azure_ad_token: str | None = None,
+ azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
+ organization: str | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ http_client: httpx.AsyncClient | None = None,
+ _strict_response_validation: bool = False,
+ ) -> None:
+ ...
+
+ def __init__(
+ self,
+ *,
+ azure_endpoint: str | None = None,
+ azure_deployment: str | None = None,
+ api_version: str | None = None,
+ api_key: str | None = None,
+ azure_ad_token: str | None = None,
+ azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
+ organization: str | None = None,
+ base_url: str | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ http_client: httpx.AsyncClient | None = None,
+ _strict_response_validation: bool = False,
+ ) -> None:
+ """Construct a new asynchronous azure openai client instance.
+
+ This automatically infers the following arguments from their corresponding environment variables if they are not provided:
+ - `api_key` from `AZURE_OPENAI_API_KEY`
+ - `organization` from `OPENAI_ORG_ID`
+ - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
+ - `api_version` from `OPENAI_API_VERSION`
+ - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
+
+ Args:
+ azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`
+
+ azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
+
+ azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.
+
+ azure_deployment: A model deployment, if given sets the base client URL to include `/deployments/{azure_deployment}`.
+ Note: this means you won't be able to use non-deployment endpoints.
+ """
+ if api_key is None:
+ api_key = os.environ.get("AZURE_OPENAI_API_KEY")
+
+ if azure_ad_token is None:
+ azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")
+
+ if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
+ raise OpenAIError(
+ "Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
+ )
+
+ if api_version is None:
+ api_version = os.environ.get("OPENAI_API_VERSION")
+
+ if api_version is None:
+ raise ValueError(
+ "Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
+ )
+
+ if default_query is None:
+ default_query = {"api-version": api_version}
+ else:
+ default_query = {"api-version": api_version, **default_query}
+
+ if base_url is None:
+ if azure_endpoint is None:
+ azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
+
+ if azure_endpoint is None:
+ raise ValueError(
+ "Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
+ )
+
+ if azure_deployment is not None:
+ base_url = f"{azure_endpoint}/openai/deployments/{azure_deployment}"
+ else:
+ base_url = f"{azure_endpoint}/openai"
+ else:
+ if azure_endpoint is not None:
+ raise ValueError("base_url and azure_endpoint are mutually exclusive")
+
+ if api_key is None:
+ # define a sentinel value to avoid any typing issues
+ api_key = API_KEY_SENTINEL
+
+ super().__init__(
+ api_key=api_key,
+ organization=organization,
+ base_url=base_url,
+ timeout=timeout,
+ max_retries=max_retries,
+ default_headers=default_headers,
+ default_query=default_query,
+ http_client=http_client,
+ _strict_response_validation=_strict_response_validation,
+ )
+ self._azure_ad_token = azure_ad_token
+ self._azure_ad_token_provider = azure_ad_token_provider
+
+ async def _get_azure_ad_token(self) -> str | None:
+ if self._azure_ad_token is not None:
+ return self._azure_ad_token
+
+ provider = self._azure_ad_token_provider
+ if provider is not None:
+ token = provider()
+ if inspect.isawaitable(token):
+ token = await token
+ if not token or not isinstance(token, str):
+ raise ValueError(
+ f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
+ )
+ return token
+
+ return None
+
+ @override
+ async def _prepare_options(self, options: FinalRequestOptions) -> None:
+ headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
+ options.headers = headers
+
+ azure_ad_token = await self._get_azure_ad_token()
+ if azure_ad_token is not None:
+ if headers.get("Authorization") is None:
+ headers["Authorization"] = f"Bearer {azure_ad_token}"
+ elif self.api_key is not API_KEY_SENTINEL:
+ if headers.get("api-key") is None:
+ headers["api-key"] = self.api_key
+ else:
+ # should never be hit
+ raise ValueError("Unable to handle auth")
+
+ return await super()._prepare_options(options)
diff --git a/src/openai/pagination.py b/src/openai/pagination.py
new file mode 100644
index 0000000000..ff45f39517
--- /dev/null
+++ b/src/openai/pagination.py
@@ -0,0 +1,95 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import Any, List, Generic, TypeVar, Optional, cast
+from typing_extensions import Protocol, override, runtime_checkable
+
+from ._types import ModelT
+from ._models import BaseModel
+from ._base_client import BasePage, PageInfo, BaseSyncPage, BaseAsyncPage
+
+__all__ = ["SyncPage", "AsyncPage", "SyncCursorPage", "AsyncCursorPage"]
+
+_BaseModelT = TypeVar("_BaseModelT", bound=BaseModel)
+
+
+@runtime_checkable
+class CursorPageItem(Protocol):
+ id: str
+
+
+class SyncPage(BaseSyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
+ """Note: no pagination actually occurs yet, this is for forwards-compatibility."""
+
+ data: List[ModelT]
+ object: str
+
+ @override
+ def _get_page_items(self) -> List[ModelT]:
+ return self.data
+
+ @override
+ def next_page_info(self) -> None:
+ """
+ This page represents a response that isn't actually paginated at the API level
+ so there will never be a next page.
+ """
+ return None
+
+
+class AsyncPage(BaseAsyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
+ """Note: no pagination actually occurs yet, this is for forwards-compatibility."""
+
+ data: List[ModelT]
+ object: str
+
+ @override
+ def _get_page_items(self) -> List[ModelT]:
+ return self.data
+
+ @override
+ def next_page_info(self) -> None:
+ """
+ This page represents a response that isn't actually paginated at the API level
+ so there will never be a next page.
+ """
+ return None
+
+
+class SyncCursorPage(BaseSyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
+ data: List[ModelT]
+
+ @override
+ def _get_page_items(self) -> List[ModelT]:
+ return self.data
+
+ @override
+ def next_page_info(self) -> Optional[PageInfo]:
+ if not self.data:
+ return None
+
+ item = cast(Any, self.data[-1])
+ if not isinstance(item, CursorPageItem):
+ # TODO emit warning log
+ return None
+
+ return PageInfo(params={"after": item.id})
+
+
+class AsyncCursorPage(BaseAsyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
+ data: List[ModelT]
+
+ @override
+ def _get_page_items(self) -> List[ModelT]:
+ return self.data
+
+ @override
+ def next_page_info(self) -> Optional[PageInfo]:
+ if not self.data:
+ return None
+
+ item = cast(Any, self.data[-1])
+ if not isinstance(item, CursorPageItem):
+ # TODO emit warning log
+ return None
+
+ return PageInfo(params={"after": item.id})
diff --git a/openai/py.typed b/src/openai/py.typed
similarity index 100%
rename from openai/py.typed
rename to src/openai/py.typed
diff --git a/src/openai/resources/__init__.py b/src/openai/resources/__init__.py
new file mode 100644
index 0000000000..e0a26c72d2
--- /dev/null
+++ b/src/openai/resources/__init__.py
@@ -0,0 +1,95 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from .chat import Chat, AsyncChat, ChatWithRawResponse, AsyncChatWithRawResponse
+from .audio import Audio, AsyncAudio, AudioWithRawResponse, AsyncAudioWithRawResponse
+from .edits import Edits, AsyncEdits, EditsWithRawResponse, AsyncEditsWithRawResponse
+from .files import Files, AsyncFiles, FilesWithRawResponse, AsyncFilesWithRawResponse
+from .images import (
+ Images,
+ AsyncImages,
+ ImagesWithRawResponse,
+ AsyncImagesWithRawResponse,
+)
+from .models import (
+ Models,
+ AsyncModels,
+ ModelsWithRawResponse,
+ AsyncModelsWithRawResponse,
+)
+from .embeddings import (
+ Embeddings,
+ AsyncEmbeddings,
+ EmbeddingsWithRawResponse,
+ AsyncEmbeddingsWithRawResponse,
+)
+from .fine_tunes import (
+ FineTunes,
+ AsyncFineTunes,
+ FineTunesWithRawResponse,
+ AsyncFineTunesWithRawResponse,
+)
+from .completions import (
+ Completions,
+ AsyncCompletions,
+ CompletionsWithRawResponse,
+ AsyncCompletionsWithRawResponse,
+)
+from .fine_tuning import (
+ FineTuning,
+ AsyncFineTuning,
+ FineTuningWithRawResponse,
+ AsyncFineTuningWithRawResponse,
+)
+from .moderations import (
+ Moderations,
+ AsyncModerations,
+ ModerationsWithRawResponse,
+ AsyncModerationsWithRawResponse,
+)
+
+__all__ = [
+ "Completions",
+ "AsyncCompletions",
+ "CompletionsWithRawResponse",
+ "AsyncCompletionsWithRawResponse",
+ "Chat",
+ "AsyncChat",
+ "ChatWithRawResponse",
+ "AsyncChatWithRawResponse",
+ "Edits",
+ "AsyncEdits",
+ "EditsWithRawResponse",
+ "AsyncEditsWithRawResponse",
+ "Embeddings",
+ "AsyncEmbeddings",
+ "EmbeddingsWithRawResponse",
+ "AsyncEmbeddingsWithRawResponse",
+ "Files",
+ "AsyncFiles",
+ "FilesWithRawResponse",
+ "AsyncFilesWithRawResponse",
+ "Images",
+ "AsyncImages",
+ "ImagesWithRawResponse",
+ "AsyncImagesWithRawResponse",
+ "Audio",
+ "AsyncAudio",
+ "AudioWithRawResponse",
+ "AsyncAudioWithRawResponse",
+ "Moderations",
+ "AsyncModerations",
+ "ModerationsWithRawResponse",
+ "AsyncModerationsWithRawResponse",
+ "Models",
+ "AsyncModels",
+ "ModelsWithRawResponse",
+ "AsyncModelsWithRawResponse",
+ "FineTuning",
+ "AsyncFineTuning",
+ "FineTuningWithRawResponse",
+ "AsyncFineTuningWithRawResponse",
+ "FineTunes",
+ "AsyncFineTunes",
+ "FineTunesWithRawResponse",
+ "AsyncFineTunesWithRawResponse",
+]
diff --git a/src/openai/resources/audio/__init__.py b/src/openai/resources/audio/__init__.py
new file mode 100644
index 0000000000..771bfe9da2
--- /dev/null
+++ b/src/openai/resources/audio/__init__.py
@@ -0,0 +1,30 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from .audio import Audio, AsyncAudio, AudioWithRawResponse, AsyncAudioWithRawResponse
+from .translations import (
+ Translations,
+ AsyncTranslations,
+ TranslationsWithRawResponse,
+ AsyncTranslationsWithRawResponse,
+)
+from .transcriptions import (
+ Transcriptions,
+ AsyncTranscriptions,
+ TranscriptionsWithRawResponse,
+ AsyncTranscriptionsWithRawResponse,
+)
+
+__all__ = [
+ "Transcriptions",
+ "AsyncTranscriptions",
+ "TranscriptionsWithRawResponse",
+ "AsyncTranscriptionsWithRawResponse",
+ "Translations",
+ "AsyncTranslations",
+ "TranslationsWithRawResponse",
+ "AsyncTranslationsWithRawResponse",
+ "Audio",
+ "AsyncAudio",
+ "AudioWithRawResponse",
+ "AsyncAudioWithRawResponse",
+]
diff --git a/src/openai/resources/audio/audio.py b/src/openai/resources/audio/audio.py
new file mode 100644
index 0000000000..8e8872c5b5
--- /dev/null
+++ b/src/openai/resources/audio/audio.py
@@ -0,0 +1,60 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from ..._resource import SyncAPIResource, AsyncAPIResource
+from .translations import (
+ Translations,
+ AsyncTranslations,
+ TranslationsWithRawResponse,
+ AsyncTranslationsWithRawResponse,
+)
+from .transcriptions import (
+ Transcriptions,
+ AsyncTranscriptions,
+ TranscriptionsWithRawResponse,
+ AsyncTranscriptionsWithRawResponse,
+)
+
+if TYPE_CHECKING:
+ from ..._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Audio", "AsyncAudio"]
+
+
+class Audio(SyncAPIResource):
+ transcriptions: Transcriptions
+ translations: Translations
+ with_raw_response: AudioWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.transcriptions = Transcriptions(client)
+ self.translations = Translations(client)
+ self.with_raw_response = AudioWithRawResponse(self)
+
+
+class AsyncAudio(AsyncAPIResource):
+ transcriptions: AsyncTranscriptions
+ translations: AsyncTranslations
+ with_raw_response: AsyncAudioWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.transcriptions = AsyncTranscriptions(client)
+ self.translations = AsyncTranslations(client)
+ self.with_raw_response = AsyncAudioWithRawResponse(self)
+
+
+class AudioWithRawResponse:
+ def __init__(self, audio: Audio) -> None:
+ self.transcriptions = TranscriptionsWithRawResponse(audio.transcriptions)
+ self.translations = TranslationsWithRawResponse(audio.translations)
+
+
+class AsyncAudioWithRawResponse:
+ def __init__(self, audio: AsyncAudio) -> None:
+ self.transcriptions = AsyncTranscriptionsWithRawResponse(audio.transcriptions)
+ self.translations = AsyncTranslationsWithRawResponse(audio.translations)
diff --git a/src/openai/resources/audio/transcriptions.py b/src/openai/resources/audio/transcriptions.py
new file mode 100644
index 0000000000..ca61f8bd42
--- /dev/null
+++ b/src/openai/resources/audio/transcriptions.py
@@ -0,0 +1,206 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Union, Mapping, cast
+from typing_extensions import Literal
+
+from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
+from ..._utils import extract_files, maybe_transform, deepcopy_minimal
+from ..._resource import SyncAPIResource, AsyncAPIResource
+from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from ...types.audio import Transcription, transcription_create_params
+from ..._base_client import make_request_options
+
+if TYPE_CHECKING:
+ from ..._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Transcriptions", "AsyncTranscriptions"]
+
+
+class Transcriptions(SyncAPIResource):
+ with_raw_response: TranscriptionsWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = TranscriptionsWithRawResponse(self)
+
+ def create(
+ self,
+ *,
+ file: FileTypes,
+ model: Union[str, Literal["whisper-1"]],
+ language: str | NotGiven = NOT_GIVEN,
+ prompt: str | NotGiven = NOT_GIVEN,
+ response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
+ temperature: float | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Transcription:
+ """
+ Transcribes audio into the input language.
+
+ Args:
+ file:
+ The audio file object (not file name) to transcribe, in one of these formats:
+ flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
+
+ model: ID of the model to use. Only `whisper-1` is currently available.
+
+ language: The language of the input audio. Supplying the input language in
+ [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will
+ improve accuracy and latency.
+
+ prompt: An optional text to guide the model's style or continue a previous audio
+ segment. The
+ [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting)
+ should match the audio language.
+
+ response_format: The format of the transcript output, in one of these options: json, text, srt,
+ verbose_json, or vtt.
+
+ temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
+ output more random, while lower values like 0.2 will make it more focused and
+ deterministic. If set to 0, the model will use
+ [log probability](https://en.wikipedia.org/wiki/Log_probability) to
+ automatically increase the temperature until certain thresholds are hit.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "file": file,
+ "model": model,
+ "language": language,
+ "prompt": prompt,
+ "response_format": response_format,
+ "temperature": temperature,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+ return self._post(
+ "/audio/transcriptions",
+ body=maybe_transform(body, transcription_create_params.TranscriptionCreateParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Transcription,
+ )
+
+
+class AsyncTranscriptions(AsyncAPIResource):
+ with_raw_response: AsyncTranscriptionsWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncTranscriptionsWithRawResponse(self)
+
+ async def create(
+ self,
+ *,
+ file: FileTypes,
+ model: Union[str, Literal["whisper-1"]],
+ language: str | NotGiven = NOT_GIVEN,
+ prompt: str | NotGiven = NOT_GIVEN,
+ response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
+ temperature: float | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Transcription:
+ """
+ Transcribes audio into the input language.
+
+ Args:
+ file:
+ The audio file object (not file name) to transcribe, in one of these formats:
+ flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
+
+ model: ID of the model to use. Only `whisper-1` is currently available.
+
+ language: The language of the input audio. Supplying the input language in
+ [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will
+ improve accuracy and latency.
+
+ prompt: An optional text to guide the model's style or continue a previous audio
+ segment. The
+ [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting)
+ should match the audio language.
+
+ response_format: The format of the transcript output, in one of these options: json, text, srt,
+ verbose_json, or vtt.
+
+ temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
+ output more random, while lower values like 0.2 will make it more focused and
+ deterministic. If set to 0, the model will use
+ [log probability](https://en.wikipedia.org/wiki/Log_probability) to
+ automatically increase the temperature until certain thresholds are hit.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "file": file,
+ "model": model,
+ "language": language,
+ "prompt": prompt,
+ "response_format": response_format,
+ "temperature": temperature,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+ return await self._post(
+ "/audio/transcriptions",
+ body=maybe_transform(body, transcription_create_params.TranscriptionCreateParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Transcription,
+ )
+
+
+class TranscriptionsWithRawResponse:
+ def __init__(self, transcriptions: Transcriptions) -> None:
+ self.create = to_raw_response_wrapper(
+ transcriptions.create,
+ )
+
+
+class AsyncTranscriptionsWithRawResponse:
+ def __init__(self, transcriptions: AsyncTranscriptions) -> None:
+ self.create = async_to_raw_response_wrapper(
+ transcriptions.create,
+ )
diff --git a/src/openai/resources/audio/translations.py b/src/openai/resources/audio/translations.py
new file mode 100644
index 0000000000..0b499b9865
--- /dev/null
+++ b/src/openai/resources/audio/translations.py
@@ -0,0 +1,192 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Union, Mapping, cast
+from typing_extensions import Literal
+
+from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
+from ..._utils import extract_files, maybe_transform, deepcopy_minimal
+from ..._resource import SyncAPIResource, AsyncAPIResource
+from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from ...types.audio import Translation, translation_create_params
+from ..._base_client import make_request_options
+
+if TYPE_CHECKING:
+ from ..._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Translations", "AsyncTranslations"]
+
+
+class Translations(SyncAPIResource):
+ with_raw_response: TranslationsWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = TranslationsWithRawResponse(self)
+
+ def create(
+ self,
+ *,
+ file: FileTypes,
+ model: Union[str, Literal["whisper-1"]],
+ prompt: str | NotGiven = NOT_GIVEN,
+ response_format: str | NotGiven = NOT_GIVEN,
+ temperature: float | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Translation:
+ """
+ Translates audio into English.
+
+ Args:
+ file: The audio file object (not file name) translate, in one of these formats: flac,
+ mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
+
+ model: ID of the model to use. Only `whisper-1` is currently available.
+
+ prompt: An optional text to guide the model's style or continue a previous audio
+ segment. The
+ [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting)
+ should be in English.
+
+ response_format: The format of the transcript output, in one of these options: json, text, srt,
+ verbose_json, or vtt.
+
+ temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
+ output more random, while lower values like 0.2 will make it more focused and
+ deterministic. If set to 0, the model will use
+ [log probability](https://en.wikipedia.org/wiki/Log_probability) to
+ automatically increase the temperature until certain thresholds are hit.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "file": file,
+ "model": model,
+ "prompt": prompt,
+ "response_format": response_format,
+ "temperature": temperature,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+ return self._post(
+ "/audio/translations",
+ body=maybe_transform(body, translation_create_params.TranslationCreateParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Translation,
+ )
+
+
+class AsyncTranslations(AsyncAPIResource):
+ with_raw_response: AsyncTranslationsWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncTranslationsWithRawResponse(self)
+
+ async def create(
+ self,
+ *,
+ file: FileTypes,
+ model: Union[str, Literal["whisper-1"]],
+ prompt: str | NotGiven = NOT_GIVEN,
+ response_format: str | NotGiven = NOT_GIVEN,
+ temperature: float | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Translation:
+ """
+ Translates audio into English.
+
+ Args:
+ file: The audio file object (not file name) translate, in one of these formats: flac,
+ mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
+
+ model: ID of the model to use. Only `whisper-1` is currently available.
+
+ prompt: An optional text to guide the model's style or continue a previous audio
+ segment. The
+ [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting)
+ should be in English.
+
+ response_format: The format of the transcript output, in one of these options: json, text, srt,
+ verbose_json, or vtt.
+
+ temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
+ output more random, while lower values like 0.2 will make it more focused and
+ deterministic. If set to 0, the model will use
+ [log probability](https://en.wikipedia.org/wiki/Log_probability) to
+ automatically increase the temperature until certain thresholds are hit.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "file": file,
+ "model": model,
+ "prompt": prompt,
+ "response_format": response_format,
+ "temperature": temperature,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+ return await self._post(
+ "/audio/translations",
+ body=maybe_transform(body, translation_create_params.TranslationCreateParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Translation,
+ )
+
+
+class TranslationsWithRawResponse:
+ def __init__(self, translations: Translations) -> None:
+ self.create = to_raw_response_wrapper(
+ translations.create,
+ )
+
+
+class AsyncTranslationsWithRawResponse:
+ def __init__(self, translations: AsyncTranslations) -> None:
+ self.create = async_to_raw_response_wrapper(
+ translations.create,
+ )
diff --git a/src/openai/resources/chat/__init__.py b/src/openai/resources/chat/__init__.py
new file mode 100644
index 0000000000..2e56c0cbfa
--- /dev/null
+++ b/src/openai/resources/chat/__init__.py
@@ -0,0 +1,20 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from .chat import Chat, AsyncChat, ChatWithRawResponse, AsyncChatWithRawResponse
+from .completions import (
+ Completions,
+ AsyncCompletions,
+ CompletionsWithRawResponse,
+ AsyncCompletionsWithRawResponse,
+)
+
+__all__ = [
+ "Completions",
+ "AsyncCompletions",
+ "CompletionsWithRawResponse",
+ "AsyncCompletionsWithRawResponse",
+ "Chat",
+ "AsyncChat",
+ "ChatWithRawResponse",
+ "AsyncChatWithRawResponse",
+]
diff --git a/src/openai/resources/chat/chat.py b/src/openai/resources/chat/chat.py
new file mode 100644
index 0000000000..3847b20512
--- /dev/null
+++ b/src/openai/resources/chat/chat.py
@@ -0,0 +1,48 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from ..._resource import SyncAPIResource, AsyncAPIResource
+from .completions import (
+ Completions,
+ AsyncCompletions,
+ CompletionsWithRawResponse,
+ AsyncCompletionsWithRawResponse,
+)
+
+if TYPE_CHECKING:
+ from ..._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Chat", "AsyncChat"]
+
+
+class Chat(SyncAPIResource):
+ completions: Completions
+ with_raw_response: ChatWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.completions = Completions(client)
+ self.with_raw_response = ChatWithRawResponse(self)
+
+
+class AsyncChat(AsyncAPIResource):
+ completions: AsyncCompletions
+ with_raw_response: AsyncChatWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.completions = AsyncCompletions(client)
+ self.with_raw_response = AsyncChatWithRawResponse(self)
+
+
+class ChatWithRawResponse:
+ def __init__(self, chat: Chat) -> None:
+ self.completions = CompletionsWithRawResponse(chat.completions)
+
+
+class AsyncChatWithRawResponse:
+ def __init__(self, chat: AsyncChat) -> None:
+ self.completions = AsyncCompletionsWithRawResponse(chat.completions)
diff --git a/src/openai/resources/chat/completions.py b/src/openai/resources/chat/completions.py
new file mode 100644
index 0000000000..e6e6ce52b8
--- /dev/null
+++ b/src/openai/resources/chat/completions.py
@@ -0,0 +1,942 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Dict, List, Union, Optional, overload
+from typing_extensions import Literal
+
+from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from ..._utils import required_args, maybe_transform
+from ..._resource import SyncAPIResource, AsyncAPIResource
+from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from ..._streaming import Stream, AsyncStream
+from ...types.chat import (
+ ChatCompletion,
+ ChatCompletionChunk,
+ ChatCompletionMessageParam,
+ completion_create_params,
+)
+from ..._base_client import make_request_options
+
+if TYPE_CHECKING:
+ from ..._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Completions", "AsyncCompletions"]
+
+
+class Completions(SyncAPIResource):
+ with_raw_response: CompletionsWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = CompletionsWithRawResponse(self)
+
+ @overload
+ def create(
+ self,
+ *,
+ messages: List[ChatCompletionMessageParam],
+ model: Union[
+ str,
+ Literal[
+ "gpt-4",
+ "gpt-4-0314",
+ "gpt-4-0613",
+ "gpt-4-32k",
+ "gpt-4-32k-0314",
+ "gpt-4-32k-0613",
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-16k",
+ "gpt-3.5-turbo-0301",
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k-0613",
+ ],
+ ],
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN,
+ functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+ stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ChatCompletion:
+ """
+ Creates a model response for the given chat conversation.
+
+ Args:
+ messages: A list of messages comprising the conversation so far.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
+
+ model: ID of the model to use. See the
+ [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility)
+ table for details on which models work with the Chat API.
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ function_call: Controls how the model calls functions. "none" means the model will not call a
+ function and instead generates a message. "auto" means the model can pick
+ between generating a message or calling a function. Specifying a particular
+ function via `{"name": "my_function"}` forces the model to call that function.
+ "none" is the default when no functions are present. "auto" is the default if
+ functions are present.
+
+ functions: A list of functions the model may generate JSON inputs for.
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the
+ tokenizer) to an associated bias value from -100 to 100. Mathematically, the
+ bias is added to the logits generated by the model prior to sampling. The exact
+ effect will vary per model, but values between -1 and 1 should decrease or
+ increase likelihood of selection; values like -100 or 100 should result in a ban
+ or exclusive selection of the relevant token.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the chat completion.
+
+ The total length of input tokens and generated tokens is limited by the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many chat completion choices to generate for each input message.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens.
+
+ stream: If set, partial message deltas will be sent, like in ChatGPT. Tokens will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ def create(
+ self,
+ *,
+ messages: List[ChatCompletionMessageParam],
+ model: Union[
+ str,
+ Literal[
+ "gpt-4",
+ "gpt-4-0314",
+ "gpt-4-0613",
+ "gpt-4-32k",
+ "gpt-4-32k-0314",
+ "gpt-4-32k-0613",
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-16k",
+ "gpt-3.5-turbo-0301",
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k-0613",
+ ],
+ ],
+ stream: Literal[True],
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN,
+ functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Stream[ChatCompletionChunk]:
+ """
+ Creates a model response for the given chat conversation.
+
+ Args:
+ messages: A list of messages comprising the conversation so far.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
+
+ model: ID of the model to use. See the
+ [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility)
+ table for details on which models work with the Chat API.
+
+ stream: If set, partial message deltas will be sent, like in ChatGPT. Tokens will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ function_call: Controls how the model calls functions. "none" means the model will not call a
+ function and instead generates a message. "auto" means the model can pick
+ between generating a message or calling a function. Specifying a particular
+ function via `{"name": "my_function"}` forces the model to call that function.
+ "none" is the default when no functions are present. "auto" is the default if
+ functions are present.
+
+ functions: A list of functions the model may generate JSON inputs for.
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the
+ tokenizer) to an associated bias value from -100 to 100. Mathematically, the
+ bias is added to the logits generated by the model prior to sampling. The exact
+ effect will vary per model, but values between -1 and 1 should decrease or
+ increase likelihood of selection; values like -100 or 100 should result in a ban
+ or exclusive selection of the relevant token.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the chat completion.
+
+ The total length of input tokens and generated tokens is limited by the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many chat completion choices to generate for each input message.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ def create(
+ self,
+ *,
+ messages: List[ChatCompletionMessageParam],
+ model: Union[
+ str,
+ Literal[
+ "gpt-4",
+ "gpt-4-0314",
+ "gpt-4-0613",
+ "gpt-4-32k",
+ "gpt-4-32k-0314",
+ "gpt-4-32k-0613",
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-16k",
+ "gpt-3.5-turbo-0301",
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k-0613",
+ ],
+ ],
+ stream: bool,
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN,
+ functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ChatCompletion | Stream[ChatCompletionChunk]:
+ """
+ Creates a model response for the given chat conversation.
+
+ Args:
+ messages: A list of messages comprising the conversation so far.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
+
+ model: ID of the model to use. See the
+ [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility)
+ table for details on which models work with the Chat API.
+
+ stream: If set, partial message deltas will be sent, like in ChatGPT. Tokens will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ function_call: Controls how the model calls functions. "none" means the model will not call a
+ function and instead generates a message. "auto" means the model can pick
+ between generating a message or calling a function. Specifying a particular
+ function via `{"name": "my_function"}` forces the model to call that function.
+ "none" is the default when no functions are present. "auto" is the default if
+ functions are present.
+
+ functions: A list of functions the model may generate JSON inputs for.
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the
+ tokenizer) to an associated bias value from -100 to 100. Mathematically, the
+ bias is added to the logits generated by the model prior to sampling. The exact
+ effect will vary per model, but values between -1 and 1 should decrease or
+ increase likelihood of selection; values like -100 or 100 should result in a ban
+ or exclusive selection of the relevant token.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the chat completion.
+
+ The total length of input tokens and generated tokens is limited by the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many chat completion choices to generate for each input message.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @required_args(["messages", "model"], ["messages", "model", "stream"])
+ def create(
+ self,
+ *,
+ messages: List[ChatCompletionMessageParam],
+ model: Union[
+ str,
+ Literal[
+ "gpt-4",
+ "gpt-4-0314",
+ "gpt-4-0613",
+ "gpt-4-32k",
+ "gpt-4-32k-0314",
+ "gpt-4-32k-0613",
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-16k",
+ "gpt-3.5-turbo-0301",
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k-0613",
+ ],
+ ],
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN,
+ functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+ stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ChatCompletion | Stream[ChatCompletionChunk]:
+ return self._post(
+ "/chat/completions",
+ body=maybe_transform(
+ {
+ "messages": messages,
+ "model": model,
+ "frequency_penalty": frequency_penalty,
+ "function_call": function_call,
+ "functions": functions,
+ "logit_bias": logit_bias,
+ "max_tokens": max_tokens,
+ "n": n,
+ "presence_penalty": presence_penalty,
+ "stop": stop,
+ "stream": stream,
+ "temperature": temperature,
+ "top_p": top_p,
+ "user": user,
+ },
+ completion_create_params.CompletionCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ChatCompletion,
+ stream=stream or False,
+ stream_cls=Stream[ChatCompletionChunk],
+ )
+
+
+class AsyncCompletions(AsyncAPIResource):
+ with_raw_response: AsyncCompletionsWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncCompletionsWithRawResponse(self)
+
+ @overload
+ async def create(
+ self,
+ *,
+ messages: List[ChatCompletionMessageParam],
+ model: Union[
+ str,
+ Literal[
+ "gpt-4",
+ "gpt-4-0314",
+ "gpt-4-0613",
+ "gpt-4-32k",
+ "gpt-4-32k-0314",
+ "gpt-4-32k-0613",
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-16k",
+ "gpt-3.5-turbo-0301",
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k-0613",
+ ],
+ ],
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN,
+ functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+ stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ChatCompletion:
+ """
+ Creates a model response for the given chat conversation.
+
+ Args:
+ messages: A list of messages comprising the conversation so far.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
+
+ model: ID of the model to use. See the
+ [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility)
+ table for details on which models work with the Chat API.
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ function_call: Controls how the model calls functions. "none" means the model will not call a
+ function and instead generates a message. "auto" means the model can pick
+ between generating a message or calling a function. Specifying a particular
+ function via `{"name": "my_function"}` forces the model to call that function.
+ "none" is the default when no functions are present. "auto" is the default if
+ functions are present.
+
+ functions: A list of functions the model may generate JSON inputs for.
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the
+ tokenizer) to an associated bias value from -100 to 100. Mathematically, the
+ bias is added to the logits generated by the model prior to sampling. The exact
+ effect will vary per model, but values between -1 and 1 should decrease or
+ increase likelihood of selection; values like -100 or 100 should result in a ban
+ or exclusive selection of the relevant token.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the chat completion.
+
+ The total length of input tokens and generated tokens is limited by the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many chat completion choices to generate for each input message.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens.
+
+ stream: If set, partial message deltas will be sent, like in ChatGPT. Tokens will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ async def create(
+ self,
+ *,
+ messages: List[ChatCompletionMessageParam],
+ model: Union[
+ str,
+ Literal[
+ "gpt-4",
+ "gpt-4-0314",
+ "gpt-4-0613",
+ "gpt-4-32k",
+ "gpt-4-32k-0314",
+ "gpt-4-32k-0613",
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-16k",
+ "gpt-3.5-turbo-0301",
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k-0613",
+ ],
+ ],
+ stream: Literal[True],
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN,
+ functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> AsyncStream[ChatCompletionChunk]:
+ """
+ Creates a model response for the given chat conversation.
+
+ Args:
+ messages: A list of messages comprising the conversation so far.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
+
+ model: ID of the model to use. See the
+ [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility)
+ table for details on which models work with the Chat API.
+
+ stream: If set, partial message deltas will be sent, like in ChatGPT. Tokens will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ function_call: Controls how the model calls functions. "none" means the model will not call a
+ function and instead generates a message. "auto" means the model can pick
+ between generating a message or calling a function. Specifying a particular
+ function via `{"name": "my_function"}` forces the model to call that function.
+ "none" is the default when no functions are present. "auto" is the default if
+ functions are present.
+
+ functions: A list of functions the model may generate JSON inputs for.
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the
+ tokenizer) to an associated bias value from -100 to 100. Mathematically, the
+ bias is added to the logits generated by the model prior to sampling. The exact
+ effect will vary per model, but values between -1 and 1 should decrease or
+ increase likelihood of selection; values like -100 or 100 should result in a ban
+ or exclusive selection of the relevant token.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the chat completion.
+
+ The total length of input tokens and generated tokens is limited by the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many chat completion choices to generate for each input message.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ async def create(
+ self,
+ *,
+ messages: List[ChatCompletionMessageParam],
+ model: Union[
+ str,
+ Literal[
+ "gpt-4",
+ "gpt-4-0314",
+ "gpt-4-0613",
+ "gpt-4-32k",
+ "gpt-4-32k-0314",
+ "gpt-4-32k-0613",
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-16k",
+ "gpt-3.5-turbo-0301",
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k-0613",
+ ],
+ ],
+ stream: bool,
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN,
+ functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
+ """
+ Creates a model response for the given chat conversation.
+
+ Args:
+ messages: A list of messages comprising the conversation so far.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
+
+ model: ID of the model to use. See the
+ [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility)
+ table for details on which models work with the Chat API.
+
+ stream: If set, partial message deltas will be sent, like in ChatGPT. Tokens will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ function_call: Controls how the model calls functions. "none" means the model will not call a
+ function and instead generates a message. "auto" means the model can pick
+ between generating a message or calling a function. Specifying a particular
+ function via `{"name": "my_function"}` forces the model to call that function.
+ "none" is the default when no functions are present. "auto" is the default if
+ functions are present.
+
+ functions: A list of functions the model may generate JSON inputs for.
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the
+ tokenizer) to an associated bias value from -100 to 100. Mathematically, the
+ bias is added to the logits generated by the model prior to sampling. The exact
+ effect will vary per model, but values between -1 and 1 should decrease or
+ increase likelihood of selection; values like -100 or 100 should result in a ban
+ or exclusive selection of the relevant token.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the chat completion.
+
+ The total length of input tokens and generated tokens is limited by the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many chat completion choices to generate for each input message.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @required_args(["messages", "model"], ["messages", "model", "stream"])
+ async def create(
+ self,
+ *,
+ messages: List[ChatCompletionMessageParam],
+ model: Union[
+ str,
+ Literal[
+ "gpt-4",
+ "gpt-4-0314",
+ "gpt-4-0613",
+ "gpt-4-32k",
+ "gpt-4-32k-0314",
+ "gpt-4-32k-0613",
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-16k",
+ "gpt-3.5-turbo-0301",
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k-0613",
+ ],
+ ],
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN,
+ functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+ stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
+ return await self._post(
+ "/chat/completions",
+ body=maybe_transform(
+ {
+ "messages": messages,
+ "model": model,
+ "frequency_penalty": frequency_penalty,
+ "function_call": function_call,
+ "functions": functions,
+ "logit_bias": logit_bias,
+ "max_tokens": max_tokens,
+ "n": n,
+ "presence_penalty": presence_penalty,
+ "stop": stop,
+ "stream": stream,
+ "temperature": temperature,
+ "top_p": top_p,
+ "user": user,
+ },
+ completion_create_params.CompletionCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ChatCompletion,
+ stream=stream or False,
+ stream_cls=AsyncStream[ChatCompletionChunk],
+ )
+
+
+class CompletionsWithRawResponse:
+ def __init__(self, completions: Completions) -> None:
+ self.create = to_raw_response_wrapper(
+ completions.create,
+ )
+
+
+class AsyncCompletionsWithRawResponse:
+ def __init__(self, completions: AsyncCompletions) -> None:
+ self.create = async_to_raw_response_wrapper(
+ completions.create,
+ )
diff --git a/src/openai/resources/completions.py b/src/openai/resources/completions.py
new file mode 100644
index 0000000000..26a34524c6
--- /dev/null
+++ b/src/openai/resources/completions.py
@@ -0,0 +1,1117 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Dict, List, Union, Optional, overload
+from typing_extensions import Literal
+
+from ..types import Completion, completion_create_params
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._utils import required_args, maybe_transform
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from .._streaming import Stream, AsyncStream
+from .._base_client import make_request_options
+
+if TYPE_CHECKING:
+ from .._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Completions", "AsyncCompletions"]
+
+
+class Completions(SyncAPIResource):
+ with_raw_response: CompletionsWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = CompletionsWithRawResponse(self)
+
+ @overload
+ def create(
+ self,
+ *,
+ model: Union[
+ str,
+ Literal[
+ "babbage-002",
+ "davinci-002",
+ "gpt-3.5-turbo-instruct",
+ "text-davinci-003",
+ "text-davinci-002",
+ "text-davinci-001",
+ "code-davinci-002",
+ "text-curie-001",
+ "text-babbage-001",
+ "text-ada-001",
+ ],
+ ],
+ prompt: Union[str, List[str], List[int], List[List[int]], None],
+ best_of: Optional[int] | NotGiven = NOT_GIVEN,
+ echo: Optional[bool] | NotGiven = NOT_GIVEN,
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
+ stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Completion:
+ """
+ Creates a completion for the provided prompt and parameters.
+
+ Args:
+ model: ID of the model to use. You can use the
+ [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+ see all of your available models, or see our
+ [Model overview](https://platform.openai.com/docs/models/overview) for
+ descriptions of them.
+
+ prompt: The prompt(s) to generate completions for, encoded as a string, array of
+ strings, array of tokens, or array of token arrays.
+
+ Note that <|endoftext|> is the document separator that the model sees during
+ training, so if a prompt is not specified the model will generate as if from the
+ beginning of a new document.
+
+ best_of: Generates `best_of` completions server-side and returns the "best" (the one with
+ the highest log probability per token). Results cannot be streamed.
+
+ When used with `n`, `best_of` controls the number of candidate completions and
+ `n` specifies how many to return – `best_of` must be greater than `n`.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ echo: Echo back the prompt in addition to the completion
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the GPT
+ tokenizer) to an associated bias value from -100 to 100. You can use this
+ [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to
+ convert text to token IDs. Mathematically, the bias is added to the logits
+ generated by the model prior to sampling. The exact effect will vary per model,
+ but values between -1 and 1 should decrease or increase likelihood of selection;
+ values like -100 or 100 should result in a ban or exclusive selection of the
+ relevant token.
+
+ As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
+ from being generated.
+
+ logprobs: Include the log probabilities on the `logprobs` most likely tokens, as well the
+ chosen tokens. For example, if `logprobs` is 5, the API will return a list of
+ the 5 most likely tokens. The API will always return the `logprob` of the
+ sampled token, so there may be up to `logprobs+1` elements in the response.
+
+ The maximum value for `logprobs` is 5.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the completion.
+
+ The token count of your prompt plus `max_tokens` cannot exceed the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many completions to generate for each prompt.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens. The
+ returned text will not contain the stop sequence.
+
+ stream: Whether to stream back partial progress. If set, tokens will be sent as
+ data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ suffix: The suffix that comes after a completion of inserted text.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ def create(
+ self,
+ *,
+ model: Union[
+ str,
+ Literal[
+ "babbage-002",
+ "davinci-002",
+ "gpt-3.5-turbo-instruct",
+ "text-davinci-003",
+ "text-davinci-002",
+ "text-davinci-001",
+ "code-davinci-002",
+ "text-curie-001",
+ "text-babbage-001",
+ "text-ada-001",
+ ],
+ ],
+ prompt: Union[str, List[str], List[int], List[List[int]], None],
+ stream: Literal[True],
+ best_of: Optional[int] | NotGiven = NOT_GIVEN,
+ echo: Optional[bool] | NotGiven = NOT_GIVEN,
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Stream[Completion]:
+ """
+ Creates a completion for the provided prompt and parameters.
+
+ Args:
+ model: ID of the model to use. You can use the
+ [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+ see all of your available models, or see our
+ [Model overview](https://platform.openai.com/docs/models/overview) for
+ descriptions of them.
+
+ prompt: The prompt(s) to generate completions for, encoded as a string, array of
+ strings, array of tokens, or array of token arrays.
+
+ Note that <|endoftext|> is the document separator that the model sees during
+ training, so if a prompt is not specified the model will generate as if from the
+ beginning of a new document.
+
+ stream: Whether to stream back partial progress. If set, tokens will be sent as
+ data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ best_of: Generates `best_of` completions server-side and returns the "best" (the one with
+ the highest log probability per token). Results cannot be streamed.
+
+ When used with `n`, `best_of` controls the number of candidate completions and
+ `n` specifies how many to return – `best_of` must be greater than `n`.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ echo: Echo back the prompt in addition to the completion
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the GPT
+ tokenizer) to an associated bias value from -100 to 100. You can use this
+ [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to
+ convert text to token IDs. Mathematically, the bias is added to the logits
+ generated by the model prior to sampling. The exact effect will vary per model,
+ but values between -1 and 1 should decrease or increase likelihood of selection;
+ values like -100 or 100 should result in a ban or exclusive selection of the
+ relevant token.
+
+ As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
+ from being generated.
+
+ logprobs: Include the log probabilities on the `logprobs` most likely tokens, as well the
+ chosen tokens. For example, if `logprobs` is 5, the API will return a list of
+ the 5 most likely tokens. The API will always return the `logprob` of the
+ sampled token, so there may be up to `logprobs+1` elements in the response.
+
+ The maximum value for `logprobs` is 5.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the completion.
+
+ The token count of your prompt plus `max_tokens` cannot exceed the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many completions to generate for each prompt.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens. The
+ returned text will not contain the stop sequence.
+
+ suffix: The suffix that comes after a completion of inserted text.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ def create(
+ self,
+ *,
+ model: Union[
+ str,
+ Literal[
+ "babbage-002",
+ "davinci-002",
+ "gpt-3.5-turbo-instruct",
+ "text-davinci-003",
+ "text-davinci-002",
+ "text-davinci-001",
+ "code-davinci-002",
+ "text-curie-001",
+ "text-babbage-001",
+ "text-ada-001",
+ ],
+ ],
+ prompt: Union[str, List[str], List[int], List[List[int]], None],
+ stream: bool,
+ best_of: Optional[int] | NotGiven = NOT_GIVEN,
+ echo: Optional[bool] | NotGiven = NOT_GIVEN,
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Completion | Stream[Completion]:
+ """
+ Creates a completion for the provided prompt and parameters.
+
+ Args:
+ model: ID of the model to use. You can use the
+ [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+ see all of your available models, or see our
+ [Model overview](https://platform.openai.com/docs/models/overview) for
+ descriptions of them.
+
+ prompt: The prompt(s) to generate completions for, encoded as a string, array of
+ strings, array of tokens, or array of token arrays.
+
+ Note that <|endoftext|> is the document separator that the model sees during
+ training, so if a prompt is not specified the model will generate as if from the
+ beginning of a new document.
+
+ stream: Whether to stream back partial progress. If set, tokens will be sent as
+ data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ best_of: Generates `best_of` completions server-side and returns the "best" (the one with
+ the highest log probability per token). Results cannot be streamed.
+
+ When used with `n`, `best_of` controls the number of candidate completions and
+ `n` specifies how many to return – `best_of` must be greater than `n`.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ echo: Echo back the prompt in addition to the completion
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the GPT
+ tokenizer) to an associated bias value from -100 to 100. You can use this
+ [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to
+ convert text to token IDs. Mathematically, the bias is added to the logits
+ generated by the model prior to sampling. The exact effect will vary per model,
+ but values between -1 and 1 should decrease or increase likelihood of selection;
+ values like -100 or 100 should result in a ban or exclusive selection of the
+ relevant token.
+
+ As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
+ from being generated.
+
+ logprobs: Include the log probabilities on the `logprobs` most likely tokens, as well the
+ chosen tokens. For example, if `logprobs` is 5, the API will return a list of
+ the 5 most likely tokens. The API will always return the `logprob` of the
+ sampled token, so there may be up to `logprobs+1` elements in the response.
+
+ The maximum value for `logprobs` is 5.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the completion.
+
+ The token count of your prompt plus `max_tokens` cannot exceed the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many completions to generate for each prompt.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens. The
+ returned text will not contain the stop sequence.
+
+ suffix: The suffix that comes after a completion of inserted text.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @required_args(["model", "prompt"], ["model", "prompt", "stream"])
+ def create(
+ self,
+ *,
+ model: Union[
+ str,
+ Literal[
+ "babbage-002",
+ "davinci-002",
+ "gpt-3.5-turbo-instruct",
+ "text-davinci-003",
+ "text-davinci-002",
+ "text-davinci-001",
+ "code-davinci-002",
+ "text-curie-001",
+ "text-babbage-001",
+ "text-ada-001",
+ ],
+ ],
+ prompt: Union[str, List[str], List[int], List[List[int]], None],
+ best_of: Optional[int] | NotGiven = NOT_GIVEN,
+ echo: Optional[bool] | NotGiven = NOT_GIVEN,
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
+ stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Completion | Stream[Completion]:
+ return self._post(
+ "/completions",
+ body=maybe_transform(
+ {
+ "model": model,
+ "prompt": prompt,
+ "best_of": best_of,
+ "echo": echo,
+ "frequency_penalty": frequency_penalty,
+ "logit_bias": logit_bias,
+ "logprobs": logprobs,
+ "max_tokens": max_tokens,
+ "n": n,
+ "presence_penalty": presence_penalty,
+ "stop": stop,
+ "stream": stream,
+ "suffix": suffix,
+ "temperature": temperature,
+ "top_p": top_p,
+ "user": user,
+ },
+ completion_create_params.CompletionCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Completion,
+ stream=stream or False,
+ stream_cls=Stream[Completion],
+ )
+
+
+class AsyncCompletions(AsyncAPIResource):
+ with_raw_response: AsyncCompletionsWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncCompletionsWithRawResponse(self)
+
+ @overload
+ async def create(
+ self,
+ *,
+ model: Union[
+ str,
+ Literal[
+ "babbage-002",
+ "davinci-002",
+ "gpt-3.5-turbo-instruct",
+ "text-davinci-003",
+ "text-davinci-002",
+ "text-davinci-001",
+ "code-davinci-002",
+ "text-curie-001",
+ "text-babbage-001",
+ "text-ada-001",
+ ],
+ ],
+ prompt: Union[str, List[str], List[int], List[List[int]], None],
+ best_of: Optional[int] | NotGiven = NOT_GIVEN,
+ echo: Optional[bool] | NotGiven = NOT_GIVEN,
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
+ stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Completion:
+ """
+ Creates a completion for the provided prompt and parameters.
+
+ Args:
+ model: ID of the model to use. You can use the
+ [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+ see all of your available models, or see our
+ [Model overview](https://platform.openai.com/docs/models/overview) for
+ descriptions of them.
+
+ prompt: The prompt(s) to generate completions for, encoded as a string, array of
+ strings, array of tokens, or array of token arrays.
+
+ Note that <|endoftext|> is the document separator that the model sees during
+ training, so if a prompt is not specified the model will generate as if from the
+ beginning of a new document.
+
+ best_of: Generates `best_of` completions server-side and returns the "best" (the one with
+ the highest log probability per token). Results cannot be streamed.
+
+ When used with `n`, `best_of` controls the number of candidate completions and
+ `n` specifies how many to return – `best_of` must be greater than `n`.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ echo: Echo back the prompt in addition to the completion
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the GPT
+ tokenizer) to an associated bias value from -100 to 100. You can use this
+ [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to
+ convert text to token IDs. Mathematically, the bias is added to the logits
+ generated by the model prior to sampling. The exact effect will vary per model,
+ but values between -1 and 1 should decrease or increase likelihood of selection;
+ values like -100 or 100 should result in a ban or exclusive selection of the
+ relevant token.
+
+ As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
+ from being generated.
+
+ logprobs: Include the log probabilities on the `logprobs` most likely tokens, as well the
+ chosen tokens. For example, if `logprobs` is 5, the API will return a list of
+ the 5 most likely tokens. The API will always return the `logprob` of the
+ sampled token, so there may be up to `logprobs+1` elements in the response.
+
+ The maximum value for `logprobs` is 5.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the completion.
+
+ The token count of your prompt plus `max_tokens` cannot exceed the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many completions to generate for each prompt.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens. The
+ returned text will not contain the stop sequence.
+
+ stream: Whether to stream back partial progress. If set, tokens will be sent as
+ data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ suffix: The suffix that comes after a completion of inserted text.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ async def create(
+ self,
+ *,
+ model: Union[
+ str,
+ Literal[
+ "babbage-002",
+ "davinci-002",
+ "gpt-3.5-turbo-instruct",
+ "text-davinci-003",
+ "text-davinci-002",
+ "text-davinci-001",
+ "code-davinci-002",
+ "text-curie-001",
+ "text-babbage-001",
+ "text-ada-001",
+ ],
+ ],
+ prompt: Union[str, List[str], List[int], List[List[int]], None],
+ stream: Literal[True],
+ best_of: Optional[int] | NotGiven = NOT_GIVEN,
+ echo: Optional[bool] | NotGiven = NOT_GIVEN,
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> AsyncStream[Completion]:
+ """
+ Creates a completion for the provided prompt and parameters.
+
+ Args:
+ model: ID of the model to use. You can use the
+ [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+ see all of your available models, or see our
+ [Model overview](https://platform.openai.com/docs/models/overview) for
+ descriptions of them.
+
+ prompt: The prompt(s) to generate completions for, encoded as a string, array of
+ strings, array of tokens, or array of token arrays.
+
+ Note that <|endoftext|> is the document separator that the model sees during
+ training, so if a prompt is not specified the model will generate as if from the
+ beginning of a new document.
+
+ stream: Whether to stream back partial progress. If set, tokens will be sent as
+ data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ best_of: Generates `best_of` completions server-side and returns the "best" (the one with
+ the highest log probability per token). Results cannot be streamed.
+
+ When used with `n`, `best_of` controls the number of candidate completions and
+ `n` specifies how many to return – `best_of` must be greater than `n`.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ echo: Echo back the prompt in addition to the completion
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the GPT
+ tokenizer) to an associated bias value from -100 to 100. You can use this
+ [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to
+ convert text to token IDs. Mathematically, the bias is added to the logits
+ generated by the model prior to sampling. The exact effect will vary per model,
+ but values between -1 and 1 should decrease or increase likelihood of selection;
+ values like -100 or 100 should result in a ban or exclusive selection of the
+ relevant token.
+
+ As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
+ from being generated.
+
+ logprobs: Include the log probabilities on the `logprobs` most likely tokens, as well the
+ chosen tokens. For example, if `logprobs` is 5, the API will return a list of
+ the 5 most likely tokens. The API will always return the `logprob` of the
+ sampled token, so there may be up to `logprobs+1` elements in the response.
+
+ The maximum value for `logprobs` is 5.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the completion.
+
+ The token count of your prompt plus `max_tokens` cannot exceed the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many completions to generate for each prompt.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens. The
+ returned text will not contain the stop sequence.
+
+ suffix: The suffix that comes after a completion of inserted text.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ async def create(
+ self,
+ *,
+ model: Union[
+ str,
+ Literal[
+ "babbage-002",
+ "davinci-002",
+ "gpt-3.5-turbo-instruct",
+ "text-davinci-003",
+ "text-davinci-002",
+ "text-davinci-001",
+ "code-davinci-002",
+ "text-curie-001",
+ "text-babbage-001",
+ "text-ada-001",
+ ],
+ ],
+ prompt: Union[str, List[str], List[int], List[List[int]], None],
+ stream: bool,
+ best_of: Optional[int] | NotGiven = NOT_GIVEN,
+ echo: Optional[bool] | NotGiven = NOT_GIVEN,
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Completion | AsyncStream[Completion]:
+ """
+ Creates a completion for the provided prompt and parameters.
+
+ Args:
+ model: ID of the model to use. You can use the
+ [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+ see all of your available models, or see our
+ [Model overview](https://platform.openai.com/docs/models/overview) for
+ descriptions of them.
+
+ prompt: The prompt(s) to generate completions for, encoded as a string, array of
+ strings, array of tokens, or array of token arrays.
+
+ Note that <|endoftext|> is the document separator that the model sees during
+ training, so if a prompt is not specified the model will generate as if from the
+ beginning of a new document.
+
+ stream: Whether to stream back partial progress. If set, tokens will be sent as
+ data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+
+ best_of: Generates `best_of` completions server-side and returns the "best" (the one with
+ the highest log probability per token). Results cannot be streamed.
+
+ When used with `n`, `best_of` controls the number of candidate completions and
+ `n` specifies how many to return – `best_of` must be greater than `n`.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ echo: Echo back the prompt in addition to the completion
+
+ frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their
+ existing frequency in the text so far, decreasing the model's likelihood to
+ repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ logit_bias: Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the GPT
+ tokenizer) to an associated bias value from -100 to 100. You can use this
+ [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to
+ convert text to token IDs. Mathematically, the bias is added to the logits
+ generated by the model prior to sampling. The exact effect will vary per model,
+ but values between -1 and 1 should decrease or increase likelihood of selection;
+ values like -100 or 100 should result in a ban or exclusive selection of the
+ relevant token.
+
+ As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
+ from being generated.
+
+ logprobs: Include the log probabilities on the `logprobs` most likely tokens, as well the
+ chosen tokens. For example, if `logprobs` is 5, the API will return a list of
+ the 5 most likely tokens. The API will always return the `logprob` of the
+ sampled token, so there may be up to `logprobs+1` elements in the response.
+
+ The maximum value for `logprobs` is 5.
+
+ max_tokens: The maximum number of [tokens](/tokenizer) to generate in the completion.
+
+ The token count of your prompt plus `max_tokens` cannot exceed the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ n: How many completions to generate for each prompt.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+
+ presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
+ whether they appear in the text so far, increasing the model's likelihood to
+ talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+
+ stop: Up to 4 sequences where the API will stop generating further tokens. The
+ returned text will not contain the stop sequence.
+
+ suffix: The suffix that comes after a completion of inserted text.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @required_args(["model", "prompt"], ["model", "prompt", "stream"])
+ async def create(
+ self,
+ *,
+ model: Union[
+ str,
+ Literal[
+ "babbage-002",
+ "davinci-002",
+ "gpt-3.5-turbo-instruct",
+ "text-davinci-003",
+ "text-davinci-002",
+ "text-davinci-001",
+ "code-davinci-002",
+ "text-curie-001",
+ "text-babbage-001",
+ "text-ada-001",
+ ],
+ ],
+ prompt: Union[str, List[str], List[int], List[List[int]], None],
+ best_of: Optional[int] | NotGiven = NOT_GIVEN,
+ echo: Optional[bool] | NotGiven = NOT_GIVEN,
+ frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+ logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+ max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+ stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
+ stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Completion | AsyncStream[Completion]:
+ return await self._post(
+ "/completions",
+ body=maybe_transform(
+ {
+ "model": model,
+ "prompt": prompt,
+ "best_of": best_of,
+ "echo": echo,
+ "frequency_penalty": frequency_penalty,
+ "logit_bias": logit_bias,
+ "logprobs": logprobs,
+ "max_tokens": max_tokens,
+ "n": n,
+ "presence_penalty": presence_penalty,
+ "stop": stop,
+ "stream": stream,
+ "suffix": suffix,
+ "temperature": temperature,
+ "top_p": top_p,
+ "user": user,
+ },
+ completion_create_params.CompletionCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Completion,
+ stream=stream or False,
+ stream_cls=AsyncStream[Completion],
+ )
+
+
+class CompletionsWithRawResponse:
+ def __init__(self, completions: Completions) -> None:
+ self.create = to_raw_response_wrapper(
+ completions.create,
+ )
+
+
+class AsyncCompletionsWithRawResponse:
+ def __init__(self, completions: AsyncCompletions) -> None:
+ self.create = async_to_raw_response_wrapper(
+ completions.create,
+ )
diff --git a/src/openai/resources/edits.py b/src/openai/resources/edits.py
new file mode 100644
index 0000000000..5c114c915f
--- /dev/null
+++ b/src/openai/resources/edits.py
@@ -0,0 +1,191 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import typing_extensions
+from typing import TYPE_CHECKING, Union, Optional
+from typing_extensions import Literal
+
+from ..types import Edit, edit_create_params
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._utils import maybe_transform
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from .._base_client import make_request_options
+
+if TYPE_CHECKING:
+ from .._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Edits", "AsyncEdits"]
+
+
+class Edits(SyncAPIResource):
+ with_raw_response: EditsWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = EditsWithRawResponse(self)
+
+ @typing_extensions.deprecated(
+ "The Edits API is deprecated; please use Chat Completions instead.\n\nhttps://openai.com/blog/gpt-4-api-general-availability#deprecation-of-the-edits-api\n"
+ )
+ def create(
+ self,
+ *,
+ instruction: str,
+ model: Union[str, Literal["text-davinci-edit-001", "code-davinci-edit-001"]],
+ input: Optional[str] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Edit:
+ """
+ Creates a new edit for the provided input, instruction, and parameters.
+
+ Args:
+ instruction: The instruction that tells the model how to edit the prompt.
+
+ model: ID of the model to use. You can use the `text-davinci-edit-001` or
+ `code-davinci-edit-001` model with this endpoint.
+
+ input: The input text to use as a starting point for the edit.
+
+ n: How many edits to generate for the input and instruction.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._post(
+ "/edits",
+ body=maybe_transform(
+ {
+ "instruction": instruction,
+ "model": model,
+ "input": input,
+ "n": n,
+ "temperature": temperature,
+ "top_p": top_p,
+ },
+ edit_create_params.EditCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Edit,
+ )
+
+
+class AsyncEdits(AsyncAPIResource):
+ with_raw_response: AsyncEditsWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncEditsWithRawResponse(self)
+
+ @typing_extensions.deprecated(
+ "The Edits API is deprecated; please use Chat Completions instead.\n\nhttps://openai.com/blog/gpt-4-api-general-availability#deprecation-of-the-edits-api\n"
+ )
+ async def create(
+ self,
+ *,
+ instruction: str,
+ model: Union[str, Literal["text-davinci-edit-001", "code-davinci-edit-001"]],
+ input: Optional[str] | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ temperature: Optional[float] | NotGiven = NOT_GIVEN,
+ top_p: Optional[float] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Edit:
+ """
+ Creates a new edit for the provided input, instruction, and parameters.
+
+ Args:
+ instruction: The instruction that tells the model how to edit the prompt.
+
+ model: ID of the model to use. You can use the `text-davinci-edit-001` or
+ `code-davinci-edit-001` model with this endpoint.
+
+ input: The input text to use as a starting point for the edit.
+
+ n: How many edits to generate for the input and instruction.
+
+ temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ make the output more random, while lower values like 0.2 will make it more
+ focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+
+ top_p: An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._post(
+ "/edits",
+ body=maybe_transform(
+ {
+ "instruction": instruction,
+ "model": model,
+ "input": input,
+ "n": n,
+ "temperature": temperature,
+ "top_p": top_p,
+ },
+ edit_create_params.EditCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Edit,
+ )
+
+
+class EditsWithRawResponse:
+ def __init__(self, edits: Edits) -> None:
+ self.create = to_raw_response_wrapper( # pyright: ignore[reportDeprecated]
+ edits.create # pyright: ignore[reportDeprecated],
+ )
+
+
+class AsyncEditsWithRawResponse:
+ def __init__(self, edits: AsyncEdits) -> None:
+ self.create = async_to_raw_response_wrapper( # pyright: ignore[reportDeprecated]
+ edits.create # pyright: ignore[reportDeprecated],
+ )
diff --git a/src/openai/resources/embeddings.py b/src/openai/resources/embeddings.py
new file mode 100644
index 0000000000..dd540fc796
--- /dev/null
+++ b/src/openai/resources/embeddings.py
@@ -0,0 +1,221 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import base64
+from typing import TYPE_CHECKING, List, Union, cast
+from typing_extensions import Literal
+
+from ..types import CreateEmbeddingResponse, embedding_create_params
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._utils import is_given, maybe_transform
+from .._extras import numpy as np
+from .._extras import has_numpy
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from .._base_client import make_request_options
+
+if TYPE_CHECKING:
+ from .._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Embeddings", "AsyncEmbeddings"]
+
+
+class Embeddings(SyncAPIResource):
+ with_raw_response: EmbeddingsWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = EmbeddingsWithRawResponse(self)
+
+ def create(
+ self,
+ *,
+ input: Union[str, List[str], List[int], List[List[int]]],
+ model: Union[str, Literal["text-embedding-ada-002"]],
+ encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> CreateEmbeddingResponse:
+ """
+ Creates an embedding vector representing the input text.
+
+ Args:
+ input: Input text to embed, encoded as a string or array of tokens. To embed multiple
+ inputs in a single request, pass an array of strings or array of token arrays.
+ The input must not exceed the max input tokens for the model (8192 tokens for
+ `text-embedding-ada-002`) and cannot be an empty string.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ model: ID of the model to use. You can use the
+ [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+ see all of your available models, or see our
+ [Model overview](https://platform.openai.com/docs/models/overview) for
+ descriptions of them.
+
+ encoding_format: The format to return the embeddings in. Can be either `float` or
+ [`base64`](https://pypi.org/project/pybase64/).
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ params = {
+ "input": input,
+ "model": model,
+ "user": user,
+ "encoding_format": encoding_format,
+ }
+ if not is_given(encoding_format) and has_numpy():
+ params["encoding_format"] = "base64"
+
+ def parser(obj: CreateEmbeddingResponse) -> CreateEmbeddingResponse:
+ if is_given(encoding_format):
+ # don't modify the response object if a user explicitly asked for a format
+ return obj
+
+ for embedding in obj.data:
+ data = cast(object, embedding.embedding)
+ if not isinstance(data, str):
+ # numpy is not installed / base64 optimisation isn't enabled for this model yet
+ continue
+
+ embedding.embedding = np.frombuffer( # type: ignore[no-untyped-call]
+ base64.b64decode(data), dtype="float32"
+ ).tolist()
+
+ return obj
+
+ return self._post(
+ "/embeddings",
+ body=maybe_transform(params, embedding_create_params.EmbeddingCreateParams),
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ post_parser=parser,
+ ),
+ cast_to=CreateEmbeddingResponse,
+ )
+
+
+class AsyncEmbeddings(AsyncAPIResource):
+ with_raw_response: AsyncEmbeddingsWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncEmbeddingsWithRawResponse(self)
+
+ async def create(
+ self,
+ *,
+ input: Union[str, List[str], List[int], List[List[int]]],
+ model: Union[str, Literal["text-embedding-ada-002"]],
+ encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> CreateEmbeddingResponse:
+ """
+ Creates an embedding vector representing the input text.
+
+ Args:
+ input: Input text to embed, encoded as a string or array of tokens. To embed multiple
+ inputs in a single request, pass an array of strings or array of token arrays.
+ The input must not exceed the max input tokens for the model (8192 tokens for
+ `text-embedding-ada-002`) and cannot be an empty string.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+
+ model: ID of the model to use. You can use the
+ [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+ see all of your available models, or see our
+ [Model overview](https://platform.openai.com/docs/models/overview) for
+ descriptions of them.
+
+ encoding_format: The format to return the embeddings in. Can be either `float` or
+ [`base64`](https://pypi.org/project/pybase64/).
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ params = {
+ "input": input,
+ "model": model,
+ "user": user,
+ "encoding_format": encoding_format,
+ }
+ if not is_given(encoding_format) and has_numpy():
+ params["encoding_format"] = "base64"
+
+ def parser(obj: CreateEmbeddingResponse) -> CreateEmbeddingResponse:
+ if is_given(encoding_format):
+ # don't modify the response object if a user explicitly asked for a format
+ return obj
+
+ for embedding in obj.data:
+ data = cast(object, embedding.embedding)
+ if not isinstance(data, str):
+ # numpy is not installed / base64 optimisation isn't enabled for this model yet
+ continue
+
+ embedding.embedding = np.frombuffer( # type: ignore[no-untyped-call]
+ base64.b64decode(data), dtype="float32"
+ ).tolist()
+
+ return obj
+
+ return await self._post(
+ "/embeddings",
+ body=maybe_transform(params, embedding_create_params.EmbeddingCreateParams),
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ post_parser=parser,
+ ),
+ cast_to=CreateEmbeddingResponse,
+ )
+
+
+class EmbeddingsWithRawResponse:
+ def __init__(self, embeddings: Embeddings) -> None:
+ self.create = to_raw_response_wrapper(
+ embeddings.create,
+ )
+
+
+class AsyncEmbeddingsWithRawResponse:
+ def __init__(self, embeddings: AsyncEmbeddings) -> None:
+ self.create = async_to_raw_response_wrapper(
+ embeddings.create,
+ )
diff --git a/src/openai/resources/files.py b/src/openai/resources/files.py
new file mode 100644
index 0000000000..d2e674c942
--- /dev/null
+++ b/src/openai/resources/files.py
@@ -0,0 +1,471 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import time
+from typing import TYPE_CHECKING, Mapping, cast
+
+from ..types import FileObject, FileDeleted, file_create_params
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
+from .._utils import extract_files, maybe_transform, deepcopy_minimal
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from ..pagination import SyncPage, AsyncPage
+from .._base_client import AsyncPaginator, make_request_options
+
+if TYPE_CHECKING:
+ from .._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Files", "AsyncFiles"]
+
+
+class Files(SyncAPIResource):
+ with_raw_response: FilesWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = FilesWithRawResponse(self)
+
+ def create(
+ self,
+ *,
+ file: FileTypes,
+ purpose: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FileObject:
+ """Upload a file that can be used across various endpoints/features.
+
+ Currently, the
+ size of all the files uploaded by one organization can be up to 1 GB. Please
+ [contact us](https://help.openai.com/) if you need to increase the storage
+ limit.
+
+ Args:
+ file: The file object (not file name) to be uploaded.
+
+ If the `purpose` is set to "fine-tune", the file will be used for fine-tuning.
+
+ purpose: The intended purpose of the uploaded file.
+
+ Use "fine-tune" for
+ [fine-tuning](https://platform.openai.com/docs/api-reference/fine-tuning). This
+ allows us to validate the format of the uploaded file is correct for
+ fine-tuning.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "file": file,
+ "purpose": purpose,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+ return self._post(
+ "/files",
+ body=maybe_transform(body, file_create_params.FileCreateParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FileObject,
+ )
+
+ def retrieve(
+ self,
+ file_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FileObject:
+ """
+ Returns information about a specific file.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._get(
+ f"/files/{file_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FileObject,
+ )
+
+ def list(
+ self,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> SyncPage[FileObject]:
+ """Returns a list of files that belong to the user's organization."""
+ return self._get_api_list(
+ "/files",
+ page=SyncPage[FileObject],
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ model=FileObject,
+ )
+
+ def delete(
+ self,
+ file_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FileDeleted:
+ """
+ Delete a file.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._delete(
+ f"/files/{file_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FileDeleted,
+ )
+
+ def retrieve_content(
+ self,
+ file_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> str:
+ """
+ Returns the contents of the specified file.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {"Accept": "application/json", **(extra_headers or {})}
+ return self._get(
+ f"/files/{file_id}/content",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=str,
+ )
+
+ def wait_for_processing(
+ self,
+ id: str,
+ *,
+ poll_interval: float = 5.0,
+ max_wait_seconds: float = 30 * 60,
+ ) -> FileObject:
+ """Waits for the given file to be processed, default timeout is 30 mins."""
+ TERMINAL_STATES = {"processed", "error", "deleted"}
+
+ start = time.time()
+ file = self.retrieve(id)
+ while file.status not in TERMINAL_STATES:
+ self._sleep(poll_interval)
+
+ file = self.retrieve(id)
+ if time.time() - start > max_wait_seconds:
+ raise RuntimeError(
+ f"Giving up on waiting for file {id} to finish processing after {max_wait_seconds} seconds."
+ )
+
+ return file
+
+
+class AsyncFiles(AsyncAPIResource):
+ with_raw_response: AsyncFilesWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncFilesWithRawResponse(self)
+
+ async def create(
+ self,
+ *,
+ file: FileTypes,
+ purpose: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FileObject:
+ """Upload a file that can be used across various endpoints/features.
+
+ Currently, the
+ size of all the files uploaded by one organization can be up to 1 GB. Please
+ [contact us](https://help.openai.com/) if you need to increase the storage
+ limit.
+
+ Args:
+ file: The file object (not file name) to be uploaded.
+
+ If the `purpose` is set to "fine-tune", the file will be used for fine-tuning.
+
+ purpose: The intended purpose of the uploaded file.
+
+ Use "fine-tune" for
+ [fine-tuning](https://platform.openai.com/docs/api-reference/fine-tuning). This
+ allows us to validate the format of the uploaded file is correct for
+ fine-tuning.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "file": file,
+ "purpose": purpose,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+ return await self._post(
+ "/files",
+ body=maybe_transform(body, file_create_params.FileCreateParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FileObject,
+ )
+
+ async def retrieve(
+ self,
+ file_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FileObject:
+ """
+ Returns information about a specific file.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._get(
+ f"/files/{file_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FileObject,
+ )
+
+ def list(
+ self,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> AsyncPaginator[FileObject, AsyncPage[FileObject]]:
+ """Returns a list of files that belong to the user's organization."""
+ return self._get_api_list(
+ "/files",
+ page=AsyncPage[FileObject],
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ model=FileObject,
+ )
+
+ async def delete(
+ self,
+ file_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FileDeleted:
+ """
+ Delete a file.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._delete(
+ f"/files/{file_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FileDeleted,
+ )
+
+ async def retrieve_content(
+ self,
+ file_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> str:
+ """
+ Returns the contents of the specified file.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {"Accept": "application/json", **(extra_headers or {})}
+ return await self._get(
+ f"/files/{file_id}/content",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=str,
+ )
+
+ async def wait_for_processing(
+ self,
+ id: str,
+ *,
+ poll_interval: float = 5.0,
+ max_wait_seconds: float = 30 * 60,
+ ) -> FileObject:
+ """Waits for the given file to be processed, default timeout is 30 mins."""
+ TERMINAL_STATES = {"processed", "error", "deleted"}
+
+ start = time.time()
+ file = await self.retrieve(id)
+ while file.status not in TERMINAL_STATES:
+ await self._sleep(poll_interval)
+
+ file = await self.retrieve(id)
+ if time.time() - start > max_wait_seconds:
+ raise RuntimeError(
+ f"Giving up on waiting for file {id} to finish processing after {max_wait_seconds} seconds."
+ )
+
+ return file
+
+
+class FilesWithRawResponse:
+ def __init__(self, files: Files) -> None:
+ self.create = to_raw_response_wrapper(
+ files.create,
+ )
+ self.retrieve = to_raw_response_wrapper(
+ files.retrieve,
+ )
+ self.list = to_raw_response_wrapper(
+ files.list,
+ )
+ self.delete = to_raw_response_wrapper(
+ files.delete,
+ )
+ self.retrieve_content = to_raw_response_wrapper(
+ files.retrieve_content,
+ )
+
+
+class AsyncFilesWithRawResponse:
+ def __init__(self, files: AsyncFiles) -> None:
+ self.create = async_to_raw_response_wrapper(
+ files.create,
+ )
+ self.retrieve = async_to_raw_response_wrapper(
+ files.retrieve,
+ )
+ self.list = async_to_raw_response_wrapper(
+ files.list,
+ )
+ self.delete = async_to_raw_response_wrapper(
+ files.delete,
+ )
+ self.retrieve_content = async_to_raw_response_wrapper(
+ files.retrieve_content,
+ )
diff --git a/src/openai/resources/fine_tunes.py b/src/openai/resources/fine_tunes.py
new file mode 100644
index 0000000000..28f4225102
--- /dev/null
+++ b/src/openai/resources/fine_tunes.py
@@ -0,0 +1,820 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Union, Optional, overload
+from typing_extensions import Literal
+
+from ..types import (
+ FineTune,
+ FineTuneEvent,
+ FineTuneEventsListResponse,
+ fine_tune_create_params,
+ fine_tune_list_events_params,
+)
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._utils import maybe_transform
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from .._streaming import Stream, AsyncStream
+from ..pagination import SyncPage, AsyncPage
+from .._base_client import AsyncPaginator, make_request_options
+
+if TYPE_CHECKING:
+ from .._client import OpenAI, AsyncOpenAI
+
+__all__ = ["FineTunes", "AsyncFineTunes"]
+
+
+class FineTunes(SyncAPIResource):
+ with_raw_response: FineTunesWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = FineTunesWithRawResponse(self)
+
+ def create(
+ self,
+ *,
+ training_file: str,
+ batch_size: Optional[int] | NotGiven = NOT_GIVEN,
+ classification_betas: Optional[List[float]] | NotGiven = NOT_GIVEN,
+ classification_n_classes: Optional[int] | NotGiven = NOT_GIVEN,
+ classification_positive_class: Optional[str] | NotGiven = NOT_GIVEN,
+ compute_classification_metrics: Optional[bool] | NotGiven = NOT_GIVEN,
+ hyperparameters: fine_tune_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
+ learning_rate_multiplier: Optional[float] | NotGiven = NOT_GIVEN,
+ model: Union[str, Literal["ada", "babbage", "curie", "davinci"], None] | NotGiven = NOT_GIVEN,
+ prompt_loss_weight: Optional[float] | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ validation_file: Optional[str] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTune:
+ """
+ Creates a job that fine-tunes a specified model from a given dataset.
+
+ Response includes details of the enqueued job including job status and the name
+ of the fine-tuned models once complete.
+
+ [Learn more about fine-tuning](https://platform.openai.com/docs/guides/legacy-fine-tuning)
+
+ Args:
+ training_file: The ID of an uploaded file that contains training data.
+
+ See [upload file](https://platform.openai.com/docs/api-reference/files/upload)
+ for how to upload a file.
+
+ Your dataset must be formatted as a JSONL file, where each training example is a
+ JSON object with the keys "prompt" and "completion". Additionally, you must
+ upload your file with the purpose `fine-tune`.
+
+ See the
+ [fine-tuning guide](https://platform.openai.com/docs/guides/legacy-fine-tuning/creating-training-data)
+ for more details.
+
+ batch_size: The batch size to use for training. The batch size is the number of training
+ examples used to train a single forward and backward pass.
+
+ By default, the batch size will be dynamically configured to be ~0.2% of the
+ number of examples in the training set, capped at 256 - in general, we've found
+ that larger batch sizes tend to work better for larger datasets.
+
+ classification_betas: If this is provided, we calculate F-beta scores at the specified beta values.
+ The F-beta score is a generalization of F-1 score. This is only used for binary
+ classification.
+
+ With a beta of 1 (i.e. the F-1 score), precision and recall are given the same
+ weight. A larger beta score puts more weight on recall and less on precision. A
+ smaller beta score puts more weight on precision and less on recall.
+
+ classification_n_classes: The number of classes in a classification task.
+
+ This parameter is required for multiclass classification.
+
+ classification_positive_class: The positive class in binary classification.
+
+ This parameter is needed to generate precision, recall, and F1 metrics when
+ doing binary classification.
+
+ compute_classification_metrics: If set, we calculate classification-specific metrics such as accuracy and F-1
+ score using the validation set at the end of every epoch. These metrics can be
+ viewed in the
+ [results file](https://platform.openai.com/docs/guides/legacy-fine-tuning/analyzing-your-fine-tuned-model).
+
+ In order to compute classification metrics, you must provide a
+ `validation_file`. Additionally, you must specify `classification_n_classes` for
+ multiclass classification or `classification_positive_class` for binary
+ classification.
+
+ hyperparameters: The hyperparameters used for the fine-tuning job.
+
+ learning_rate_multiplier: The learning rate multiplier to use for training. The fine-tuning learning rate
+ is the original learning rate used for pretraining multiplied by this value.
+
+ By default, the learning rate multiplier is the 0.05, 0.1, or 0.2 depending on
+ final `batch_size` (larger learning rates tend to perform better with larger
+ batch sizes). We recommend experimenting with values in the range 0.02 to 0.2 to
+ see what produces the best results.
+
+ model: The name of the base model to fine-tune. You can select one of "ada", "babbage",
+ "curie", "davinci", or a fine-tuned model created after 2022-04-21 and before
+ 2023-08-22. To learn more about these models, see the
+ [Models](https://platform.openai.com/docs/models) documentation.
+
+ prompt_loss_weight: The weight to use for loss on the prompt tokens. This controls how much the
+ model tries to learn to generate the prompt (as compared to the completion which
+ always has a weight of 1.0), and can add a stabilizing effect to training when
+ completions are short.
+
+ If prompts are extremely long (relative to completions), it may make sense to
+ reduce this weight so as to avoid over-prioritizing learning the prompt.
+
+ suffix: A string of up to 40 characters that will be added to your fine-tuned model
+ name.
+
+ For example, a `suffix` of "custom-model-name" would produce a model name like
+ `ada:ft-your-org:custom-model-name-2022-02-15-04-21-04`.
+
+ validation_file: The ID of an uploaded file that contains validation data.
+
+ If you provide this file, the data is used to generate validation metrics
+ periodically during fine-tuning. These metrics can be viewed in the
+ [fine-tuning results file](https://platform.openai.com/docs/guides/legacy-fine-tuning/analyzing-your-fine-tuned-model).
+ Your train and validation data should be mutually exclusive.
+
+ Your dataset must be formatted as a JSONL file, where each validation example is
+ a JSON object with the keys "prompt" and "completion". Additionally, you must
+ upload your file with the purpose `fine-tune`.
+
+ See the
+ [fine-tuning guide](https://platform.openai.com/docs/guides/legacy-fine-tuning/creating-training-data)
+ for more details.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._post(
+ "/fine-tunes",
+ body=maybe_transform(
+ {
+ "training_file": training_file,
+ "batch_size": batch_size,
+ "classification_betas": classification_betas,
+ "classification_n_classes": classification_n_classes,
+ "classification_positive_class": classification_positive_class,
+ "compute_classification_metrics": compute_classification_metrics,
+ "hyperparameters": hyperparameters,
+ "learning_rate_multiplier": learning_rate_multiplier,
+ "model": model,
+ "prompt_loss_weight": prompt_loss_weight,
+ "suffix": suffix,
+ "validation_file": validation_file,
+ },
+ fine_tune_create_params.FineTuneCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTune,
+ )
+
+ def retrieve(
+ self,
+ fine_tune_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTune:
+ """
+ Gets info about the fine-tune job.
+
+ [Learn more about fine-tuning](https://platform.openai.com/docs/guides/legacy-fine-tuning)
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._get(
+ f"/fine-tunes/{fine_tune_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTune,
+ )
+
+ def list(
+ self,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> SyncPage[FineTune]:
+ """List your organization's fine-tuning jobs"""
+ return self._get_api_list(
+ "/fine-tunes",
+ page=SyncPage[FineTune],
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ model=FineTune,
+ )
+
+ def cancel(
+ self,
+ fine_tune_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTune:
+ """
+ Immediately cancel a fine-tune job.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._post(
+ f"/fine-tunes/{fine_tune_id}/cancel",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTune,
+ )
+
+ @overload
+ def list_events(
+ self,
+ fine_tune_id: str,
+ *,
+ stream: Literal[False] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = 86400,
+ ) -> FineTuneEventsListResponse:
+ """
+ Get fine-grained status updates for a fine-tune job.
+
+ Args:
+ stream: Whether to stream events for the fine-tune job. If set to true, events will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available. The stream will terminate with a `data: [DONE]`
+ message when the job is finished (succeeded, cancelled, or failed).
+
+ If set to false, only events generated so far will be returned.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ def list_events(
+ self,
+ fine_tune_id: str,
+ *,
+ stream: Literal[True],
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = 86400,
+ ) -> Stream[FineTuneEvent]:
+ """
+ Get fine-grained status updates for a fine-tune job.
+
+ Args:
+ stream: Whether to stream events for the fine-tune job. If set to true, events will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available. The stream will terminate with a `data: [DONE]`
+ message when the job is finished (succeeded, cancelled, or failed).
+
+ If set to false, only events generated so far will be returned.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ def list_events(
+ self,
+ fine_tune_id: str,
+ *,
+ stream: bool,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = 86400,
+ ) -> FineTuneEventsListResponse | Stream[FineTuneEvent]:
+ """
+ Get fine-grained status updates for a fine-tune job.
+
+ Args:
+ stream: Whether to stream events for the fine-tune job. If set to true, events will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available. The stream will terminate with a `data: [DONE]`
+ message when the job is finished (succeeded, cancelled, or failed).
+
+ If set to false, only events generated so far will be returned.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ def list_events(
+ self,
+ fine_tune_id: str,
+ *,
+ stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = 86400,
+ ) -> FineTuneEventsListResponse | Stream[FineTuneEvent]:
+ return self._get(
+ f"/fine-tunes/{fine_tune_id}/events",
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform({"stream": stream}, fine_tune_list_events_params.FineTuneListEventsParams),
+ ),
+ cast_to=FineTuneEventsListResponse,
+ stream=stream or False,
+ stream_cls=Stream[FineTuneEvent],
+ )
+
+
+class AsyncFineTunes(AsyncAPIResource):
+ with_raw_response: AsyncFineTunesWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncFineTunesWithRawResponse(self)
+
+ async def create(
+ self,
+ *,
+ training_file: str,
+ batch_size: Optional[int] | NotGiven = NOT_GIVEN,
+ classification_betas: Optional[List[float]] | NotGiven = NOT_GIVEN,
+ classification_n_classes: Optional[int] | NotGiven = NOT_GIVEN,
+ classification_positive_class: Optional[str] | NotGiven = NOT_GIVEN,
+ compute_classification_metrics: Optional[bool] | NotGiven = NOT_GIVEN,
+ hyperparameters: fine_tune_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
+ learning_rate_multiplier: Optional[float] | NotGiven = NOT_GIVEN,
+ model: Union[str, Literal["ada", "babbage", "curie", "davinci"], None] | NotGiven = NOT_GIVEN,
+ prompt_loss_weight: Optional[float] | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ validation_file: Optional[str] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTune:
+ """
+ Creates a job that fine-tunes a specified model from a given dataset.
+
+ Response includes details of the enqueued job including job status and the name
+ of the fine-tuned models once complete.
+
+ [Learn more about fine-tuning](https://platform.openai.com/docs/guides/legacy-fine-tuning)
+
+ Args:
+ training_file: The ID of an uploaded file that contains training data.
+
+ See [upload file](https://platform.openai.com/docs/api-reference/files/upload)
+ for how to upload a file.
+
+ Your dataset must be formatted as a JSONL file, where each training example is a
+ JSON object with the keys "prompt" and "completion". Additionally, you must
+ upload your file with the purpose `fine-tune`.
+
+ See the
+ [fine-tuning guide](https://platform.openai.com/docs/guides/legacy-fine-tuning/creating-training-data)
+ for more details.
+
+ batch_size: The batch size to use for training. The batch size is the number of training
+ examples used to train a single forward and backward pass.
+
+ By default, the batch size will be dynamically configured to be ~0.2% of the
+ number of examples in the training set, capped at 256 - in general, we've found
+ that larger batch sizes tend to work better for larger datasets.
+
+ classification_betas: If this is provided, we calculate F-beta scores at the specified beta values.
+ The F-beta score is a generalization of F-1 score. This is only used for binary
+ classification.
+
+ With a beta of 1 (i.e. the F-1 score), precision and recall are given the same
+ weight. A larger beta score puts more weight on recall and less on precision. A
+ smaller beta score puts more weight on precision and less on recall.
+
+ classification_n_classes: The number of classes in a classification task.
+
+ This parameter is required for multiclass classification.
+
+ classification_positive_class: The positive class in binary classification.
+
+ This parameter is needed to generate precision, recall, and F1 metrics when
+ doing binary classification.
+
+ compute_classification_metrics: If set, we calculate classification-specific metrics such as accuracy and F-1
+ score using the validation set at the end of every epoch. These metrics can be
+ viewed in the
+ [results file](https://platform.openai.com/docs/guides/legacy-fine-tuning/analyzing-your-fine-tuned-model).
+
+ In order to compute classification metrics, you must provide a
+ `validation_file`. Additionally, you must specify `classification_n_classes` for
+ multiclass classification or `classification_positive_class` for binary
+ classification.
+
+ hyperparameters: The hyperparameters used for the fine-tuning job.
+
+ learning_rate_multiplier: The learning rate multiplier to use for training. The fine-tuning learning rate
+ is the original learning rate used for pretraining multiplied by this value.
+
+ By default, the learning rate multiplier is the 0.05, 0.1, or 0.2 depending on
+ final `batch_size` (larger learning rates tend to perform better with larger
+ batch sizes). We recommend experimenting with values in the range 0.02 to 0.2 to
+ see what produces the best results.
+
+ model: The name of the base model to fine-tune. You can select one of "ada", "babbage",
+ "curie", "davinci", or a fine-tuned model created after 2022-04-21 and before
+ 2023-08-22. To learn more about these models, see the
+ [Models](https://platform.openai.com/docs/models) documentation.
+
+ prompt_loss_weight: The weight to use for loss on the prompt tokens. This controls how much the
+ model tries to learn to generate the prompt (as compared to the completion which
+ always has a weight of 1.0), and can add a stabilizing effect to training when
+ completions are short.
+
+ If prompts are extremely long (relative to completions), it may make sense to
+ reduce this weight so as to avoid over-prioritizing learning the prompt.
+
+ suffix: A string of up to 40 characters that will be added to your fine-tuned model
+ name.
+
+ For example, a `suffix` of "custom-model-name" would produce a model name like
+ `ada:ft-your-org:custom-model-name-2022-02-15-04-21-04`.
+
+ validation_file: The ID of an uploaded file that contains validation data.
+
+ If you provide this file, the data is used to generate validation metrics
+ periodically during fine-tuning. These metrics can be viewed in the
+ [fine-tuning results file](https://platform.openai.com/docs/guides/legacy-fine-tuning/analyzing-your-fine-tuned-model).
+ Your train and validation data should be mutually exclusive.
+
+ Your dataset must be formatted as a JSONL file, where each validation example is
+ a JSON object with the keys "prompt" and "completion". Additionally, you must
+ upload your file with the purpose `fine-tune`.
+
+ See the
+ [fine-tuning guide](https://platform.openai.com/docs/guides/legacy-fine-tuning/creating-training-data)
+ for more details.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._post(
+ "/fine-tunes",
+ body=maybe_transform(
+ {
+ "training_file": training_file,
+ "batch_size": batch_size,
+ "classification_betas": classification_betas,
+ "classification_n_classes": classification_n_classes,
+ "classification_positive_class": classification_positive_class,
+ "compute_classification_metrics": compute_classification_metrics,
+ "hyperparameters": hyperparameters,
+ "learning_rate_multiplier": learning_rate_multiplier,
+ "model": model,
+ "prompt_loss_weight": prompt_loss_weight,
+ "suffix": suffix,
+ "validation_file": validation_file,
+ },
+ fine_tune_create_params.FineTuneCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTune,
+ )
+
+ async def retrieve(
+ self,
+ fine_tune_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTune:
+ """
+ Gets info about the fine-tune job.
+
+ [Learn more about fine-tuning](https://platform.openai.com/docs/guides/legacy-fine-tuning)
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._get(
+ f"/fine-tunes/{fine_tune_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTune,
+ )
+
+ def list(
+ self,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> AsyncPaginator[FineTune, AsyncPage[FineTune]]:
+ """List your organization's fine-tuning jobs"""
+ return self._get_api_list(
+ "/fine-tunes",
+ page=AsyncPage[FineTune],
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ model=FineTune,
+ )
+
+ async def cancel(
+ self,
+ fine_tune_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTune:
+ """
+ Immediately cancel a fine-tune job.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._post(
+ f"/fine-tunes/{fine_tune_id}/cancel",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTune,
+ )
+
+ @overload
+ async def list_events(
+ self,
+ fine_tune_id: str,
+ *,
+ stream: Literal[False] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = 86400,
+ ) -> FineTuneEventsListResponse:
+ """
+ Get fine-grained status updates for a fine-tune job.
+
+ Args:
+ stream: Whether to stream events for the fine-tune job. If set to true, events will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available. The stream will terminate with a `data: [DONE]`
+ message when the job is finished (succeeded, cancelled, or failed).
+
+ If set to false, only events generated so far will be returned.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ async def list_events(
+ self,
+ fine_tune_id: str,
+ *,
+ stream: Literal[True],
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = 86400,
+ ) -> AsyncStream[FineTuneEvent]:
+ """
+ Get fine-grained status updates for a fine-tune job.
+
+ Args:
+ stream: Whether to stream events for the fine-tune job. If set to true, events will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available. The stream will terminate with a `data: [DONE]`
+ message when the job is finished (succeeded, cancelled, or failed).
+
+ If set to false, only events generated so far will be returned.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ async def list_events(
+ self,
+ fine_tune_id: str,
+ *,
+ stream: bool,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = 86400,
+ ) -> FineTuneEventsListResponse | AsyncStream[FineTuneEvent]:
+ """
+ Get fine-grained status updates for a fine-tune job.
+
+ Args:
+ stream: Whether to stream events for the fine-tune job. If set to true, events will be
+ sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available. The stream will terminate with a `data: [DONE]`
+ message when the job is finished (succeeded, cancelled, or failed).
+
+ If set to false, only events generated so far will be returned.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ async def list_events(
+ self,
+ fine_tune_id: str,
+ *,
+ stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = 86400,
+ ) -> FineTuneEventsListResponse | AsyncStream[FineTuneEvent]:
+ return await self._get(
+ f"/fine-tunes/{fine_tune_id}/events",
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform({"stream": stream}, fine_tune_list_events_params.FineTuneListEventsParams),
+ ),
+ cast_to=FineTuneEventsListResponse,
+ stream=stream or False,
+ stream_cls=AsyncStream[FineTuneEvent],
+ )
+
+
+class FineTunesWithRawResponse:
+ def __init__(self, fine_tunes: FineTunes) -> None:
+ self.create = to_raw_response_wrapper(
+ fine_tunes.create,
+ )
+ self.retrieve = to_raw_response_wrapper(
+ fine_tunes.retrieve,
+ )
+ self.list = to_raw_response_wrapper(
+ fine_tunes.list,
+ )
+ self.cancel = to_raw_response_wrapper(
+ fine_tunes.cancel,
+ )
+ self.list_events = to_raw_response_wrapper(
+ fine_tunes.list_events,
+ )
+
+
+class AsyncFineTunesWithRawResponse:
+ def __init__(self, fine_tunes: AsyncFineTunes) -> None:
+ self.create = async_to_raw_response_wrapper(
+ fine_tunes.create,
+ )
+ self.retrieve = async_to_raw_response_wrapper(
+ fine_tunes.retrieve,
+ )
+ self.list = async_to_raw_response_wrapper(
+ fine_tunes.list,
+ )
+ self.cancel = async_to_raw_response_wrapper(
+ fine_tunes.cancel,
+ )
+ self.list_events = async_to_raw_response_wrapper(
+ fine_tunes.list_events,
+ )
diff --git a/src/openai/resources/fine_tuning/__init__.py b/src/openai/resources/fine_tuning/__init__.py
new file mode 100644
index 0000000000..9133c25d4a
--- /dev/null
+++ b/src/openai/resources/fine_tuning/__init__.py
@@ -0,0 +1,20 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from .jobs import Jobs, AsyncJobs, JobsWithRawResponse, AsyncJobsWithRawResponse
+from .fine_tuning import (
+ FineTuning,
+ AsyncFineTuning,
+ FineTuningWithRawResponse,
+ AsyncFineTuningWithRawResponse,
+)
+
+__all__ = [
+ "Jobs",
+ "AsyncJobs",
+ "JobsWithRawResponse",
+ "AsyncJobsWithRawResponse",
+ "FineTuning",
+ "AsyncFineTuning",
+ "FineTuningWithRawResponse",
+ "AsyncFineTuningWithRawResponse",
+]
diff --git a/src/openai/resources/fine_tuning/fine_tuning.py b/src/openai/resources/fine_tuning/fine_tuning.py
new file mode 100644
index 0000000000..2e5f36e546
--- /dev/null
+++ b/src/openai/resources/fine_tuning/fine_tuning.py
@@ -0,0 +1,43 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .jobs import Jobs, AsyncJobs, JobsWithRawResponse, AsyncJobsWithRawResponse
+from ..._resource import SyncAPIResource, AsyncAPIResource
+
+if TYPE_CHECKING:
+ from ..._client import OpenAI, AsyncOpenAI
+
+__all__ = ["FineTuning", "AsyncFineTuning"]
+
+
+class FineTuning(SyncAPIResource):
+ jobs: Jobs
+ with_raw_response: FineTuningWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.jobs = Jobs(client)
+ self.with_raw_response = FineTuningWithRawResponse(self)
+
+
+class AsyncFineTuning(AsyncAPIResource):
+ jobs: AsyncJobs
+ with_raw_response: AsyncFineTuningWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.jobs = AsyncJobs(client)
+ self.with_raw_response = AsyncFineTuningWithRawResponse(self)
+
+
+class FineTuningWithRawResponse:
+ def __init__(self, fine_tuning: FineTuning) -> None:
+ self.jobs = JobsWithRawResponse(fine_tuning.jobs)
+
+
+class AsyncFineTuningWithRawResponse:
+ def __init__(self, fine_tuning: AsyncFineTuning) -> None:
+ self.jobs = AsyncJobsWithRawResponse(fine_tuning.jobs)
diff --git a/src/openai/resources/fine_tuning/jobs.py b/src/openai/resources/fine_tuning/jobs.py
new file mode 100644
index 0000000000..b721c892b5
--- /dev/null
+++ b/src/openai/resources/fine_tuning/jobs.py
@@ -0,0 +1,567 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Union, Optional
+from typing_extensions import Literal
+
+from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from ..._utils import maybe_transform
+from ..._resource import SyncAPIResource, AsyncAPIResource
+from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from ...pagination import SyncCursorPage, AsyncCursorPage
+from ..._base_client import AsyncPaginator, make_request_options
+from ...types.fine_tuning import (
+ FineTuningJob,
+ FineTuningJobEvent,
+ job_list_params,
+ job_create_params,
+ job_list_events_params,
+)
+
+if TYPE_CHECKING:
+ from ..._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Jobs", "AsyncJobs"]
+
+
+class Jobs(SyncAPIResource):
+ with_raw_response: JobsWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = JobsWithRawResponse(self)
+
+ def create(
+ self,
+ *,
+ model: Union[str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo"]],
+ training_file: str,
+ hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ validation_file: Optional[str] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTuningJob:
+ """
+ Creates a job that fine-tunes a specified model from a given dataset.
+
+ Response includes details of the enqueued job including job status and the name
+ of the fine-tuned models once complete.
+
+ [Learn more about fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
+
+ Args:
+ model: The name of the model to fine-tune. You can select one of the
+ [supported models](https://platform.openai.com/docs/guides/fine-tuning/what-models-can-be-fine-tuned).
+
+ training_file: The ID of an uploaded file that contains training data.
+
+ See [upload file](https://platform.openai.com/docs/api-reference/files/upload)
+ for how to upload a file.
+
+ Your dataset must be formatted as a JSONL file. Additionally, you must upload
+ your file with the purpose `fine-tune`.
+
+ See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+ for more details.
+
+ hyperparameters: The hyperparameters used for the fine-tuning job.
+
+ suffix: A string of up to 18 characters that will be added to your fine-tuned model
+ name.
+
+ For example, a `suffix` of "custom-model-name" would produce a model name like
+ `ft:gpt-3.5-turbo:openai:custom-model-name:7p4lURel`.
+
+ validation_file: The ID of an uploaded file that contains validation data.
+
+ If you provide this file, the data is used to generate validation metrics
+ periodically during fine-tuning. These metrics can be viewed in the fine-tuning
+ results file. The same data should not be present in both train and validation
+ files.
+
+ Your dataset must be formatted as a JSONL file. You must upload your file with
+ the purpose `fine-tune`.
+
+ See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+ for more details.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._post(
+ "/fine_tuning/jobs",
+ body=maybe_transform(
+ {
+ "model": model,
+ "training_file": training_file,
+ "hyperparameters": hyperparameters,
+ "suffix": suffix,
+ "validation_file": validation_file,
+ },
+ job_create_params.JobCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTuningJob,
+ )
+
+ def retrieve(
+ self,
+ fine_tuning_job_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTuningJob:
+ """
+ Get info about a fine-tuning job.
+
+ [Learn more about fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._get(
+ f"/fine_tuning/jobs/{fine_tuning_job_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTuningJob,
+ )
+
+ def list(
+ self,
+ *,
+ after: str | NotGiven = NOT_GIVEN,
+ limit: int | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> SyncCursorPage[FineTuningJob]:
+ """
+ List your organization's fine-tuning jobs
+
+ Args:
+ after: Identifier for the last job from the previous pagination request.
+
+ limit: Number of fine-tuning jobs to retrieve.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._get_api_list(
+ "/fine_tuning/jobs",
+ page=SyncCursorPage[FineTuningJob],
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform(
+ {
+ "after": after,
+ "limit": limit,
+ },
+ job_list_params.JobListParams,
+ ),
+ ),
+ model=FineTuningJob,
+ )
+
+ def cancel(
+ self,
+ fine_tuning_job_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTuningJob:
+ """
+ Immediately cancel a fine-tune job.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._post(
+ f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTuningJob,
+ )
+
+ def list_events(
+ self,
+ fine_tuning_job_id: str,
+ *,
+ after: str | NotGiven = NOT_GIVEN,
+ limit: int | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> SyncCursorPage[FineTuningJobEvent]:
+ """
+ Get status updates for a fine-tuning job.
+
+ Args:
+ after: Identifier for the last event from the previous pagination request.
+
+ limit: Number of events to retrieve.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._get_api_list(
+ f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
+ page=SyncCursorPage[FineTuningJobEvent],
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform(
+ {
+ "after": after,
+ "limit": limit,
+ },
+ job_list_events_params.JobListEventsParams,
+ ),
+ ),
+ model=FineTuningJobEvent,
+ )
+
+
+class AsyncJobs(AsyncAPIResource):
+ with_raw_response: AsyncJobsWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncJobsWithRawResponse(self)
+
+ async def create(
+ self,
+ *,
+ model: Union[str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo"]],
+ training_file: str,
+ hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
+ suffix: Optional[str] | NotGiven = NOT_GIVEN,
+ validation_file: Optional[str] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTuningJob:
+ """
+ Creates a job that fine-tunes a specified model from a given dataset.
+
+ Response includes details of the enqueued job including job status and the name
+ of the fine-tuned models once complete.
+
+ [Learn more about fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
+
+ Args:
+ model: The name of the model to fine-tune. You can select one of the
+ [supported models](https://platform.openai.com/docs/guides/fine-tuning/what-models-can-be-fine-tuned).
+
+ training_file: The ID of an uploaded file that contains training data.
+
+ See [upload file](https://platform.openai.com/docs/api-reference/files/upload)
+ for how to upload a file.
+
+ Your dataset must be formatted as a JSONL file. Additionally, you must upload
+ your file with the purpose `fine-tune`.
+
+ See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+ for more details.
+
+ hyperparameters: The hyperparameters used for the fine-tuning job.
+
+ suffix: A string of up to 18 characters that will be added to your fine-tuned model
+ name.
+
+ For example, a `suffix` of "custom-model-name" would produce a model name like
+ `ft:gpt-3.5-turbo:openai:custom-model-name:7p4lURel`.
+
+ validation_file: The ID of an uploaded file that contains validation data.
+
+ If you provide this file, the data is used to generate validation metrics
+ periodically during fine-tuning. These metrics can be viewed in the fine-tuning
+ results file. The same data should not be present in both train and validation
+ files.
+
+ Your dataset must be formatted as a JSONL file. You must upload your file with
+ the purpose `fine-tune`.
+
+ See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+ for more details.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._post(
+ "/fine_tuning/jobs",
+ body=maybe_transform(
+ {
+ "model": model,
+ "training_file": training_file,
+ "hyperparameters": hyperparameters,
+ "suffix": suffix,
+ "validation_file": validation_file,
+ },
+ job_create_params.JobCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTuningJob,
+ )
+
+ async def retrieve(
+ self,
+ fine_tuning_job_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTuningJob:
+ """
+ Get info about a fine-tuning job.
+
+ [Learn more about fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._get(
+ f"/fine_tuning/jobs/{fine_tuning_job_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTuningJob,
+ )
+
+ def list(
+ self,
+ *,
+ after: str | NotGiven = NOT_GIVEN,
+ limit: int | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> AsyncPaginator[FineTuningJob, AsyncCursorPage[FineTuningJob]]:
+ """
+ List your organization's fine-tuning jobs
+
+ Args:
+ after: Identifier for the last job from the previous pagination request.
+
+ limit: Number of fine-tuning jobs to retrieve.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._get_api_list(
+ "/fine_tuning/jobs",
+ page=AsyncCursorPage[FineTuningJob],
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform(
+ {
+ "after": after,
+ "limit": limit,
+ },
+ job_list_params.JobListParams,
+ ),
+ ),
+ model=FineTuningJob,
+ )
+
+ async def cancel(
+ self,
+ fine_tuning_job_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> FineTuningJob:
+ """
+ Immediately cancel a fine-tune job.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._post(
+ f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=FineTuningJob,
+ )
+
+ def list_events(
+ self,
+ fine_tuning_job_id: str,
+ *,
+ after: str | NotGiven = NOT_GIVEN,
+ limit: int | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> AsyncPaginator[FineTuningJobEvent, AsyncCursorPage[FineTuningJobEvent]]:
+ """
+ Get status updates for a fine-tuning job.
+
+ Args:
+ after: Identifier for the last event from the previous pagination request.
+
+ limit: Number of events to retrieve.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._get_api_list(
+ f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
+ page=AsyncCursorPage[FineTuningJobEvent],
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform(
+ {
+ "after": after,
+ "limit": limit,
+ },
+ job_list_events_params.JobListEventsParams,
+ ),
+ ),
+ model=FineTuningJobEvent,
+ )
+
+
+class JobsWithRawResponse:
+ def __init__(self, jobs: Jobs) -> None:
+ self.create = to_raw_response_wrapper(
+ jobs.create,
+ )
+ self.retrieve = to_raw_response_wrapper(
+ jobs.retrieve,
+ )
+ self.list = to_raw_response_wrapper(
+ jobs.list,
+ )
+ self.cancel = to_raw_response_wrapper(
+ jobs.cancel,
+ )
+ self.list_events = to_raw_response_wrapper(
+ jobs.list_events,
+ )
+
+
+class AsyncJobsWithRawResponse:
+ def __init__(self, jobs: AsyncJobs) -> None:
+ self.create = async_to_raw_response_wrapper(
+ jobs.create,
+ )
+ self.retrieve = async_to_raw_response_wrapper(
+ jobs.retrieve,
+ )
+ self.list = async_to_raw_response_wrapper(
+ jobs.list,
+ )
+ self.cancel = async_to_raw_response_wrapper(
+ jobs.cancel,
+ )
+ self.list_events = async_to_raw_response_wrapper(
+ jobs.list_events,
+ )
diff --git a/src/openai/resources/images.py b/src/openai/resources/images.py
new file mode 100644
index 0000000000..1fd39b43a6
--- /dev/null
+++ b/src/openai/resources/images.py
@@ -0,0 +1,479 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Mapping, Optional, cast
+from typing_extensions import Literal
+
+from ..types import (
+ ImagesResponse,
+ image_edit_params,
+ image_generate_params,
+ image_create_variation_params,
+)
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
+from .._utils import extract_files, maybe_transform, deepcopy_minimal
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from .._base_client import make_request_options
+
+if TYPE_CHECKING:
+ from .._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Images", "AsyncImages"]
+
+
+class Images(SyncAPIResource):
+ with_raw_response: ImagesWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = ImagesWithRawResponse(self)
+
+ def create_variation(
+ self,
+ *,
+ image: FileTypes,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ response_format: Optional[Literal["url", "b64_json"]] | NotGiven = NOT_GIVEN,
+ size: Optional[Literal["256x256", "512x512", "1024x1024"]] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ImagesResponse:
+ """
+ Creates a variation of a given image.
+
+ Args:
+ image: The image to use as the basis for the variation(s). Must be a valid PNG file,
+ less than 4MB, and square.
+
+ n: The number of images to generate. Must be between 1 and 10.
+
+ response_format: The format in which the generated images are returned. Must be one of `url` or
+ `b64_json`.
+
+ size: The size of the generated images. Must be one of `256x256`, `512x512`, or
+ `1024x1024`.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "image": image,
+ "n": n,
+ "response_format": response_format,
+ "size": size,
+ "user": user,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["image"]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+ return self._post(
+ "/images/variations",
+ body=maybe_transform(body, image_create_variation_params.ImageCreateVariationParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ImagesResponse,
+ )
+
+ def edit(
+ self,
+ *,
+ image: FileTypes,
+ prompt: str,
+ mask: FileTypes | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ response_format: Optional[Literal["url", "b64_json"]] | NotGiven = NOT_GIVEN,
+ size: Optional[Literal["256x256", "512x512", "1024x1024"]] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ImagesResponse:
+ """
+ Creates an edited or extended image given an original image and a prompt.
+
+ Args:
+ image: The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask
+ is not provided, image must have transparency, which will be used as the mask.
+
+ prompt: A text description of the desired image(s). The maximum length is 1000
+ characters.
+
+ mask: An additional image whose fully transparent areas (e.g. where alpha is zero)
+ indicate where `image` should be edited. Must be a valid PNG file, less than
+ 4MB, and have the same dimensions as `image`.
+
+ n: The number of images to generate. Must be between 1 and 10.
+
+ response_format: The format in which the generated images are returned. Must be one of `url` or
+ `b64_json`.
+
+ size: The size of the generated images. Must be one of `256x256`, `512x512`, or
+ `1024x1024`.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "image": image,
+ "prompt": prompt,
+ "mask": mask,
+ "n": n,
+ "response_format": response_format,
+ "size": size,
+ "user": user,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["image"], ["mask"]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+ return self._post(
+ "/images/edits",
+ body=maybe_transform(body, image_edit_params.ImageEditParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ImagesResponse,
+ )
+
+ def generate(
+ self,
+ *,
+ prompt: str,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ response_format: Optional[Literal["url", "b64_json"]] | NotGiven = NOT_GIVEN,
+ size: Optional[Literal["256x256", "512x512", "1024x1024"]] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ImagesResponse:
+ """
+ Creates an image given a prompt.
+
+ Args:
+ prompt: A text description of the desired image(s). The maximum length is 1000
+ characters.
+
+ n: The number of images to generate. Must be between 1 and 10.
+
+ response_format: The format in which the generated images are returned. Must be one of `url` or
+ `b64_json`.
+
+ size: The size of the generated images. Must be one of `256x256`, `512x512`, or
+ `1024x1024`.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._post(
+ "/images/generations",
+ body=maybe_transform(
+ {
+ "prompt": prompt,
+ "n": n,
+ "response_format": response_format,
+ "size": size,
+ "user": user,
+ },
+ image_generate_params.ImageGenerateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ImagesResponse,
+ )
+
+
+class AsyncImages(AsyncAPIResource):
+ with_raw_response: AsyncImagesWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncImagesWithRawResponse(self)
+
+ async def create_variation(
+ self,
+ *,
+ image: FileTypes,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ response_format: Optional[Literal["url", "b64_json"]] | NotGiven = NOT_GIVEN,
+ size: Optional[Literal["256x256", "512x512", "1024x1024"]] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ImagesResponse:
+ """
+ Creates a variation of a given image.
+
+ Args:
+ image: The image to use as the basis for the variation(s). Must be a valid PNG file,
+ less than 4MB, and square.
+
+ n: The number of images to generate. Must be between 1 and 10.
+
+ response_format: The format in which the generated images are returned. Must be one of `url` or
+ `b64_json`.
+
+ size: The size of the generated images. Must be one of `256x256`, `512x512`, or
+ `1024x1024`.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "image": image,
+ "n": n,
+ "response_format": response_format,
+ "size": size,
+ "user": user,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["image"]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+ return await self._post(
+ "/images/variations",
+ body=maybe_transform(body, image_create_variation_params.ImageCreateVariationParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ImagesResponse,
+ )
+
+ async def edit(
+ self,
+ *,
+ image: FileTypes,
+ prompt: str,
+ mask: FileTypes | NotGiven = NOT_GIVEN,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ response_format: Optional[Literal["url", "b64_json"]] | NotGiven = NOT_GIVEN,
+ size: Optional[Literal["256x256", "512x512", "1024x1024"]] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ImagesResponse:
+ """
+ Creates an edited or extended image given an original image and a prompt.
+
+ Args:
+ image: The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask
+ is not provided, image must have transparency, which will be used as the mask.
+
+ prompt: A text description of the desired image(s). The maximum length is 1000
+ characters.
+
+ mask: An additional image whose fully transparent areas (e.g. where alpha is zero)
+ indicate where `image` should be edited. Must be a valid PNG file, less than
+ 4MB, and have the same dimensions as `image`.
+
+ n: The number of images to generate. Must be between 1 and 10.
+
+ response_format: The format in which the generated images are returned. Must be one of `url` or
+ `b64_json`.
+
+ size: The size of the generated images. Must be one of `256x256`, `512x512`, or
+ `1024x1024`.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "image": image,
+ "prompt": prompt,
+ "mask": mask,
+ "n": n,
+ "response_format": response_format,
+ "size": size,
+ "user": user,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["image"], ["mask"]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+ return await self._post(
+ "/images/edits",
+ body=maybe_transform(body, image_edit_params.ImageEditParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ImagesResponse,
+ )
+
+ async def generate(
+ self,
+ *,
+ prompt: str,
+ n: Optional[int] | NotGiven = NOT_GIVEN,
+ response_format: Optional[Literal["url", "b64_json"]] | NotGiven = NOT_GIVEN,
+ size: Optional[Literal["256x256", "512x512", "1024x1024"]] | NotGiven = NOT_GIVEN,
+ user: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ImagesResponse:
+ """
+ Creates an image given a prompt.
+
+ Args:
+ prompt: A text description of the desired image(s). The maximum length is 1000
+ characters.
+
+ n: The number of images to generate. Must be between 1 and 10.
+
+ response_format: The format in which the generated images are returned. Must be one of `url` or
+ `b64_json`.
+
+ size: The size of the generated images. Must be one of `256x256`, `512x512`, or
+ `1024x1024`.
+
+ user: A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._post(
+ "/images/generations",
+ body=maybe_transform(
+ {
+ "prompt": prompt,
+ "n": n,
+ "response_format": response_format,
+ "size": size,
+ "user": user,
+ },
+ image_generate_params.ImageGenerateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ImagesResponse,
+ )
+
+
+class ImagesWithRawResponse:
+ def __init__(self, images: Images) -> None:
+ self.create_variation = to_raw_response_wrapper(
+ images.create_variation,
+ )
+ self.edit = to_raw_response_wrapper(
+ images.edit,
+ )
+ self.generate = to_raw_response_wrapper(
+ images.generate,
+ )
+
+
+class AsyncImagesWithRawResponse:
+ def __init__(self, images: AsyncImages) -> None:
+ self.create_variation = async_to_raw_response_wrapper(
+ images.create_variation,
+ )
+ self.edit = async_to_raw_response_wrapper(
+ images.edit,
+ )
+ self.generate = async_to_raw_response_wrapper(
+ images.generate,
+ )
diff --git a/src/openai/resources/models.py b/src/openai/resources/models.py
new file mode 100644
index 0000000000..689bbd6621
--- /dev/null
+++ b/src/openai/resources/models.py
@@ -0,0 +1,235 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from ..types import Model, ModelDeleted
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from ..pagination import SyncPage, AsyncPage
+from .._base_client import AsyncPaginator, make_request_options
+
+if TYPE_CHECKING:
+ from .._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Models", "AsyncModels"]
+
+
+class Models(SyncAPIResource):
+ with_raw_response: ModelsWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = ModelsWithRawResponse(self)
+
+ def retrieve(
+ self,
+ model: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Model:
+ """
+ Retrieves a model instance, providing basic information about the model such as
+ the owner and permissioning.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._get(
+ f"/models/{model}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Model,
+ )
+
+ def list(
+ self,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> SyncPage[Model]:
+ """
+ Lists the currently available models, and provides basic information about each
+ one such as the owner and availability.
+ """
+ return self._get_api_list(
+ "/models",
+ page=SyncPage[Model],
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ model=Model,
+ )
+
+ def delete(
+ self,
+ model: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ModelDeleted:
+ """Delete a fine-tuned model.
+
+ You must have the Owner role in your organization to
+ delete a model.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._delete(
+ f"/models/{model}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ModelDeleted,
+ )
+
+
+class AsyncModels(AsyncAPIResource):
+ with_raw_response: AsyncModelsWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncModelsWithRawResponse(self)
+
+ async def retrieve(
+ self,
+ model: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> Model:
+ """
+ Retrieves a model instance, providing basic information about the model such as
+ the owner and permissioning.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._get(
+ f"/models/{model}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Model,
+ )
+
+ def list(
+ self,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> AsyncPaginator[Model, AsyncPage[Model]]:
+ """
+ Lists the currently available models, and provides basic information about each
+ one such as the owner and availability.
+ """
+ return self._get_api_list(
+ "/models",
+ page=AsyncPage[Model],
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ model=Model,
+ )
+
+ async def delete(
+ self,
+ model: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ModelDeleted:
+ """Delete a fine-tuned model.
+
+ You must have the Owner role in your organization to
+ delete a model.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._delete(
+ f"/models/{model}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ModelDeleted,
+ )
+
+
+class ModelsWithRawResponse:
+ def __init__(self, models: Models) -> None:
+ self.retrieve = to_raw_response_wrapper(
+ models.retrieve,
+ )
+ self.list = to_raw_response_wrapper(
+ models.list,
+ )
+ self.delete = to_raw_response_wrapper(
+ models.delete,
+ )
+
+
+class AsyncModelsWithRawResponse:
+ def __init__(self, models: AsyncModels) -> None:
+ self.retrieve = async_to_raw_response_wrapper(
+ models.retrieve,
+ )
+ self.list = async_to_raw_response_wrapper(
+ models.list,
+ )
+ self.delete = async_to_raw_response_wrapper(
+ models.delete,
+ )
diff --git a/src/openai/resources/moderations.py b/src/openai/resources/moderations.py
new file mode 100644
index 0000000000..1ee3e72564
--- /dev/null
+++ b/src/openai/resources/moderations.py
@@ -0,0 +1,148 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Union
+from typing_extensions import Literal
+
+from ..types import ModerationCreateResponse, moderation_create_params
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._utils import maybe_transform
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper
+from .._base_client import make_request_options
+
+if TYPE_CHECKING:
+ from .._client import OpenAI, AsyncOpenAI
+
+__all__ = ["Moderations", "AsyncModerations"]
+
+
+class Moderations(SyncAPIResource):
+ with_raw_response: ModerationsWithRawResponse
+
+ def __init__(self, client: OpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = ModerationsWithRawResponse(self)
+
+ def create(
+ self,
+ *,
+ input: Union[str, List[str]],
+ model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ModerationCreateResponse:
+ """
+ Classifies if text violates OpenAI's Content Policy
+
+ Args:
+ input: The input text to classify
+
+ model: Two content moderations models are available: `text-moderation-stable` and
+ `text-moderation-latest`.
+
+ The default is `text-moderation-latest` which will be automatically upgraded
+ over time. This ensures you are always using our most accurate model. If you use
+ `text-moderation-stable`, we will provide advanced notice before updating the
+ model. Accuracy of `text-moderation-stable` may be slightly lower than for
+ `text-moderation-latest`.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._post(
+ "/moderations",
+ body=maybe_transform(
+ {
+ "input": input,
+ "model": model,
+ },
+ moderation_create_params.ModerationCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ModerationCreateResponse,
+ )
+
+
+class AsyncModerations(AsyncAPIResource):
+ with_raw_response: AsyncModerationsWithRawResponse
+
+ def __init__(self, client: AsyncOpenAI) -> None:
+ super().__init__(client)
+ self.with_raw_response = AsyncModerationsWithRawResponse(self)
+
+ async def create(
+ self,
+ *,
+ input: Union[str, List[str]],
+ model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | None | NotGiven = NOT_GIVEN,
+ ) -> ModerationCreateResponse:
+ """
+ Classifies if text violates OpenAI's Content Policy
+
+ Args:
+ input: The input text to classify
+
+ model: Two content moderations models are available: `text-moderation-stable` and
+ `text-moderation-latest`.
+
+ The default is `text-moderation-latest` which will be automatically upgraded
+ over time. This ensures you are always using our most accurate model. If you use
+ `text-moderation-stable`, we will provide advanced notice before updating the
+ model. Accuracy of `text-moderation-stable` may be slightly lower than for
+ `text-moderation-latest`.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._post(
+ "/moderations",
+ body=maybe_transform(
+ {
+ "input": input,
+ "model": model,
+ },
+ moderation_create_params.ModerationCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ModerationCreateResponse,
+ )
+
+
+class ModerationsWithRawResponse:
+ def __init__(self, moderations: Moderations) -> None:
+ self.create = to_raw_response_wrapper(
+ moderations.create,
+ )
+
+
+class AsyncModerationsWithRawResponse:
+ def __init__(self, moderations: AsyncModerations) -> None:
+ self.create = async_to_raw_response_wrapper(
+ moderations.create,
+ )
diff --git a/src/openai/types/__init__.py b/src/openai/types/__init__.py
new file mode 100644
index 0000000000..defaf13446
--- /dev/null
+++ b/src/openai/types/__init__.py
@@ -0,0 +1,42 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from .edit import Edit as Edit
+from .image import Image as Image
+from .model import Model as Model
+from .embedding import Embedding as Embedding
+from .fine_tune import FineTune as FineTune
+from .completion import Completion as Completion
+from .moderation import Moderation as Moderation
+from .file_object import FileObject as FileObject
+from .file_content import FileContent as FileContent
+from .file_deleted import FileDeleted as FileDeleted
+from .model_deleted import ModelDeleted as ModelDeleted
+from .fine_tune_event import FineTuneEvent as FineTuneEvent
+from .images_response import ImagesResponse as ImagesResponse
+from .completion_usage import CompletionUsage as CompletionUsage
+from .completion_choice import CompletionChoice as CompletionChoice
+from .image_edit_params import ImageEditParams as ImageEditParams
+from .edit_create_params import EditCreateParams as EditCreateParams
+from .file_create_params import FileCreateParams as FileCreateParams
+from .image_generate_params import ImageGenerateParams as ImageGenerateParams
+from .embedding_create_params import EmbeddingCreateParams as EmbeddingCreateParams
+from .fine_tune_create_params import FineTuneCreateParams as FineTuneCreateParams
+from .completion_create_params import CompletionCreateParams as CompletionCreateParams
+from .moderation_create_params import ModerationCreateParams as ModerationCreateParams
+from .create_embedding_response import (
+ CreateEmbeddingResponse as CreateEmbeddingResponse,
+)
+from .moderation_create_response import (
+ ModerationCreateResponse as ModerationCreateResponse,
+)
+from .fine_tune_list_events_params import (
+ FineTuneListEventsParams as FineTuneListEventsParams,
+)
+from .image_create_variation_params import (
+ ImageCreateVariationParams as ImageCreateVariationParams,
+)
+from .fine_tune_events_list_response import (
+ FineTuneEventsListResponse as FineTuneEventsListResponse,
+)
diff --git a/src/openai/types/audio/__init__.py b/src/openai/types/audio/__init__.py
new file mode 100644
index 0000000000..469bc6f25b
--- /dev/null
+++ b/src/openai/types/audio/__init__.py
@@ -0,0 +1,12 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from .translation import Translation as Translation
+from .transcription import Transcription as Transcription
+from .translation_create_params import (
+ TranslationCreateParams as TranslationCreateParams,
+)
+from .transcription_create_params import (
+ TranscriptionCreateParams as TranscriptionCreateParams,
+)
diff --git a/src/openai/types/audio/transcription.py b/src/openai/types/audio/transcription.py
new file mode 100644
index 0000000000..d2274faa0e
--- /dev/null
+++ b/src/openai/types/audio/transcription.py
@@ -0,0 +1,9 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from ..._models import BaseModel
+
+__all__ = ["Transcription"]
+
+
+class Transcription(BaseModel):
+ text: str
diff --git a/src/openai/types/audio/transcription_create_params.py b/src/openai/types/audio/transcription_create_params.py
new file mode 100644
index 0000000000..f8f193484a
--- /dev/null
+++ b/src/openai/types/audio/transcription_create_params.py
@@ -0,0 +1,52 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Union
+from typing_extensions import Literal, Required, TypedDict
+
+from ..._types import FileTypes
+
+__all__ = ["TranscriptionCreateParams"]
+
+
+class TranscriptionCreateParams(TypedDict, total=False):
+ file: Required[FileTypes]
+ """
+ The audio file object (not file name) to transcribe, in one of these formats:
+ flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
+ """
+
+ model: Required[Union[str, Literal["whisper-1"]]]
+ """ID of the model to use. Only `whisper-1` is currently available."""
+
+ language: str
+ """The language of the input audio.
+
+ Supplying the input language in
+ [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will
+ improve accuracy and latency.
+ """
+
+ prompt: str
+ """An optional text to guide the model's style or continue a previous audio
+ segment.
+
+ The [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting)
+ should match the audio language.
+ """
+
+ response_format: Literal["json", "text", "srt", "verbose_json", "vtt"]
+ """
+ The format of the transcript output, in one of these options: json, text, srt,
+ verbose_json, or vtt.
+ """
+
+ temperature: float
+ """The sampling temperature, between 0 and 1.
+
+ Higher values like 0.8 will make the output more random, while lower values like
+ 0.2 will make it more focused and deterministic. If set to 0, the model will use
+ [log probability](https://en.wikipedia.org/wiki/Log_probability) to
+ automatically increase the temperature until certain thresholds are hit.
+ """
diff --git a/src/openai/types/audio/translation.py b/src/openai/types/audio/translation.py
new file mode 100644
index 0000000000..a01d622abc
--- /dev/null
+++ b/src/openai/types/audio/translation.py
@@ -0,0 +1,9 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from ..._models import BaseModel
+
+__all__ = ["Translation"]
+
+
+class Translation(BaseModel):
+ text: str
diff --git a/src/openai/types/audio/translation_create_params.py b/src/openai/types/audio/translation_create_params.py
new file mode 100644
index 0000000000..bfa5fc56d2
--- /dev/null
+++ b/src/openai/types/audio/translation_create_params.py
@@ -0,0 +1,44 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Union
+from typing_extensions import Literal, Required, TypedDict
+
+from ..._types import FileTypes
+
+__all__ = ["TranslationCreateParams"]
+
+
+class TranslationCreateParams(TypedDict, total=False):
+ file: Required[FileTypes]
+ """
+ The audio file object (not file name) translate, in one of these formats: flac,
+ mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
+ """
+
+ model: Required[Union[str, Literal["whisper-1"]]]
+ """ID of the model to use. Only `whisper-1` is currently available."""
+
+ prompt: str
+ """An optional text to guide the model's style or continue a previous audio
+ segment.
+
+ The [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting)
+ should be in English.
+ """
+
+ response_format: str
+ """
+ The format of the transcript output, in one of these options: json, text, srt,
+ verbose_json, or vtt.
+ """
+
+ temperature: float
+ """The sampling temperature, between 0 and 1.
+
+ Higher values like 0.8 will make the output more random, while lower values like
+ 0.2 will make it more focused and deterministic. If set to 0, the model will use
+ [log probability](https://en.wikipedia.org/wiki/Log_probability) to
+ automatically increase the temperature until certain thresholds are hit.
+ """
diff --git a/src/openai/types/chat/__init__.py b/src/openai/types/chat/__init__.py
new file mode 100644
index 0000000000..2f23cf3ca4
--- /dev/null
+++ b/src/openai/types/chat/__init__.py
@@ -0,0 +1,12 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from .chat_completion import ChatCompletion as ChatCompletion
+from .chat_completion_role import ChatCompletionRole as ChatCompletionRole
+from .chat_completion_chunk import ChatCompletionChunk as ChatCompletionChunk
+from .chat_completion_message import ChatCompletionMessage as ChatCompletionMessage
+from .completion_create_params import CompletionCreateParams as CompletionCreateParams
+from .chat_completion_message_param import (
+ ChatCompletionMessageParam as ChatCompletionMessageParam,
+)
diff --git a/src/openai/types/chat/chat_completion.py b/src/openai/types/chat/chat_completion.py
new file mode 100644
index 0000000000..8d7a0b9716
--- /dev/null
+++ b/src/openai/types/chat/chat_completion.py
@@ -0,0 +1,50 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List, Optional
+from typing_extensions import Literal
+
+from ..._models import BaseModel
+from ..completion_usage import CompletionUsage
+from .chat_completion_message import ChatCompletionMessage
+
+__all__ = ["ChatCompletion", "Choice"]
+
+
+class Choice(BaseModel):
+ finish_reason: Literal["stop", "length", "function_call", "content_filter"]
+ """The reason the model stopped generating tokens.
+
+ This will be `stop` if the model hit a natural stop point or a provided stop
+ sequence, `length` if the maximum number of tokens specified in the request was
+ reached, `content_filter` if content was omitted due to a flag from our content
+ filters, or `function_call` if the model called a function.
+ """
+
+ index: int
+ """The index of the choice in the list of choices."""
+
+ message: ChatCompletionMessage
+ """A chat completion message generated by the model."""
+
+
+class ChatCompletion(BaseModel):
+ id: str
+ """A unique identifier for the chat completion."""
+
+ choices: List[Choice]
+ """A list of chat completion choices.
+
+ Can be more than one if `n` is greater than 1.
+ """
+
+ created: int
+ """The Unix timestamp (in seconds) of when the chat completion was created."""
+
+ model: str
+ """The model used for the chat completion."""
+
+ object: str
+ """The object type, which is always `chat.completion`."""
+
+ usage: Optional[CompletionUsage] = None
+ """Usage statistics for the completion request."""
diff --git a/src/openai/types/chat/chat_completion_chunk.py b/src/openai/types/chat/chat_completion_chunk.py
new file mode 100644
index 0000000000..66610898b4
--- /dev/null
+++ b/src/openai/types/chat/chat_completion_chunk.py
@@ -0,0 +1,76 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List, Optional
+from typing_extensions import Literal
+
+from ..._models import BaseModel
+from .chat_completion_role import ChatCompletionRole
+
+__all__ = ["ChatCompletionChunk", "Choice", "ChoiceDelta", "ChoiceDeltaFunctionCall"]
+
+
+class ChoiceDeltaFunctionCall(BaseModel):
+ arguments: Optional[str] = None
+ """
+ The arguments to call the function with, as generated by the model in JSON
+ format. Note that the model does not always generate valid JSON, and may
+ hallucinate parameters not defined by your function schema. Validate the
+ arguments in your code before calling your function.
+ """
+
+ name: Optional[str] = None
+ """The name of the function to call."""
+
+
+class ChoiceDelta(BaseModel):
+ content: Optional[str] = None
+ """The contents of the chunk message."""
+
+ function_call: Optional[ChoiceDeltaFunctionCall] = None
+ """
+ The name and arguments of a function that should be called, as generated by the
+ model.
+ """
+
+ role: Optional[ChatCompletionRole] = None
+ """The role of the author of this message."""
+
+
+class Choice(BaseModel):
+ delta: ChoiceDelta
+ """A chat completion delta generated by streamed model responses."""
+
+ finish_reason: Optional[Literal["stop", "length", "function_call", "content_filter"]]
+ """The reason the model stopped generating tokens.
+
+ This will be `stop` if the model hit a natural stop point or a provided stop
+ sequence, `length` if the maximum number of tokens specified in the request was
+ reached, `content_filter` if content was omitted due to a flag from our content
+ filters, or `function_call` if the model called a function.
+ """
+
+ index: int
+ """The index of the choice in the list of choices."""
+
+
+class ChatCompletionChunk(BaseModel):
+ id: str
+ """A unique identifier for the chat completion. Each chunk has the same ID."""
+
+ choices: List[Choice]
+ """A list of chat completion choices.
+
+ Can be more than one if `n` is greater than 1.
+ """
+
+ created: int
+ """The Unix timestamp (in seconds) of when the chat completion was created.
+
+ Each chunk has the same timestamp.
+ """
+
+ model: str
+ """The model to generate the completion."""
+
+ object: str
+ """The object type, which is always `chat.completion.chunk`."""
diff --git a/src/openai/types/chat/chat_completion_message.py b/src/openai/types/chat/chat_completion_message.py
new file mode 100644
index 0000000000..531eb3d43c
--- /dev/null
+++ b/src/openai/types/chat/chat_completion_message.py
@@ -0,0 +1,35 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import Optional
+
+from ..._models import BaseModel
+from .chat_completion_role import ChatCompletionRole
+
+__all__ = ["ChatCompletionMessage", "FunctionCall"]
+
+
+class FunctionCall(BaseModel):
+ arguments: str
+ """
+ The arguments to call the function with, as generated by the model in JSON
+ format. Note that the model does not always generate valid JSON, and may
+ hallucinate parameters not defined by your function schema. Validate the
+ arguments in your code before calling your function.
+ """
+
+ name: str
+ """The name of the function to call."""
+
+
+class ChatCompletionMessage(BaseModel):
+ content: Optional[str]
+ """The contents of the message."""
+
+ role: ChatCompletionRole
+ """The role of the author of this message."""
+
+ function_call: Optional[FunctionCall] = None
+ """
+ The name and arguments of a function that should be called, as generated by the
+ model.
+ """
diff --git a/src/openai/types/chat/chat_completion_message_param.py b/src/openai/types/chat/chat_completion_message_param.py
new file mode 100644
index 0000000000..29b8882573
--- /dev/null
+++ b/src/openai/types/chat/chat_completion_message_param.py
@@ -0,0 +1,50 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Optional
+from typing_extensions import Literal, Required, TypedDict
+
+__all__ = ["ChatCompletionMessageParam", "FunctionCall"]
+
+
+class FunctionCall(TypedDict, total=False):
+ arguments: Required[str]
+ """
+ The arguments to call the function with, as generated by the model in JSON
+ format. Note that the model does not always generate valid JSON, and may
+ hallucinate parameters not defined by your function schema. Validate the
+ arguments in your code before calling your function.
+ """
+
+ name: Required[str]
+ """The name of the function to call."""
+
+
+class ChatCompletionMessageParam(TypedDict, total=False):
+ content: Required[Optional[str]]
+ """The contents of the message.
+
+ `content` is required for all messages, and may be null for assistant messages
+ with function calls.
+ """
+
+ role: Required[Literal["system", "user", "assistant", "function"]]
+ """The role of the messages author.
+
+ One of `system`, `user`, `assistant`, or `function`.
+ """
+
+ function_call: FunctionCall
+ """
+ The name and arguments of a function that should be called, as generated by the
+ model.
+ """
+
+ name: str
+ """The name of the author of this message.
+
+ `name` is required if role is `function`, and it should be the name of the
+ function whose response is in the `content`. May contain a-z, A-Z, 0-9, and
+ underscores, with a maximum length of 64 characters.
+ """
diff --git a/src/openai/types/chat/chat_completion_role.py b/src/openai/types/chat/chat_completion_role.py
new file mode 100644
index 0000000000..da8896a072
--- /dev/null
+++ b/src/openai/types/chat/chat_completion_role.py
@@ -0,0 +1,7 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing_extensions import Literal
+
+__all__ = ["ChatCompletionRole"]
+
+ChatCompletionRole = Literal["system", "user", "assistant", "function"]
diff --git a/src/openai/types/chat/completion_create_params.py b/src/openai/types/chat/completion_create_params.py
new file mode 100644
index 0000000000..d681a90cd6
--- /dev/null
+++ b/src/openai/types/chat/completion_create_params.py
@@ -0,0 +1,194 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Dict, List, Union, Optional
+from typing_extensions import Literal, Required, TypedDict
+
+from .chat_completion_message_param import ChatCompletionMessageParam
+
+__all__ = [
+ "CompletionCreateParamsBase",
+ "FunctionCall",
+ "FunctionCallFunctionCallOption",
+ "Function",
+ "CompletionCreateParamsNonStreaming",
+ "CompletionCreateParamsStreaming",
+]
+
+
+class CompletionCreateParamsBase(TypedDict, total=False):
+ messages: Required[List[ChatCompletionMessageParam]]
+ """A list of messages comprising the conversation so far.
+
+ [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
+ """
+
+ model: Required[
+ Union[
+ str,
+ Literal[
+ "gpt-4",
+ "gpt-4-0314",
+ "gpt-4-0613",
+ "gpt-4-32k",
+ "gpt-4-32k-0314",
+ "gpt-4-32k-0613",
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-16k",
+ "gpt-3.5-turbo-0301",
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k-0613",
+ ],
+ ]
+ ]
+ """ID of the model to use.
+
+ See the
+ [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility)
+ table for details on which models work with the Chat API.
+ """
+
+ frequency_penalty: Optional[float]
+ """Number between -2.0 and 2.0.
+
+ Positive values penalize new tokens based on their existing frequency in the
+ text so far, decreasing the model's likelihood to repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+ """
+
+ function_call: FunctionCall
+ """Controls how the model calls functions.
+
+ "none" means the model will not call a function and instead generates a message.
+ "auto" means the model can pick between generating a message or calling a
+ function. Specifying a particular function via `{"name": "my_function"}` forces
+ the model to call that function. "none" is the default when no functions are
+ present. "auto" is the default if functions are present.
+ """
+
+ functions: List[Function]
+ """A list of functions the model may generate JSON inputs for."""
+
+ logit_bias: Optional[Dict[str, int]]
+ """Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the
+ tokenizer) to an associated bias value from -100 to 100. Mathematically, the
+ bias is added to the logits generated by the model prior to sampling. The exact
+ effect will vary per model, but values between -1 and 1 should decrease or
+ increase likelihood of selection; values like -100 or 100 should result in a ban
+ or exclusive selection of the relevant token.
+ """
+
+ max_tokens: Optional[int]
+ """The maximum number of [tokens](/tokenizer) to generate in the chat completion.
+
+ The total length of input tokens and generated tokens is limited by the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+ """
+
+ n: Optional[int]
+ """How many chat completion choices to generate for each input message."""
+
+ presence_penalty: Optional[float]
+ """Number between -2.0 and 2.0.
+
+ Positive values penalize new tokens based on whether they appear in the text so
+ far, increasing the model's likelihood to talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+ """
+
+ stop: Union[Optional[str], List[str]]
+ """Up to 4 sequences where the API will stop generating further tokens."""
+
+ temperature: Optional[float]
+ """What sampling temperature to use, between 0 and 2.
+
+ Higher values like 0.8 will make the output more random, while lower values like
+ 0.2 will make it more focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+ """
+
+ top_p: Optional[float]
+ """
+ An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+ """
+
+ user: str
+ """
+ A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+ """
+
+
+class FunctionCallFunctionCallOption(TypedDict, total=False):
+ name: Required[str]
+ """The name of the function to call."""
+
+
+FunctionCall = Union[Literal["none", "auto"], FunctionCallFunctionCallOption]
+
+
+class Function(TypedDict, total=False):
+ name: Required[str]
+ """The name of the function to be called.
+
+ Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
+ of 64.
+ """
+
+ parameters: Required[Dict[str, object]]
+ """The parameters the functions accepts, described as a JSON Schema object.
+
+ See the [guide](https://platform.openai.com/docs/guides/gpt/function-calling)
+ for examples, and the
+ [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for
+ documentation about the format.
+
+ To describe a function that accepts no parameters, provide the value
+ `{"type": "object", "properties": {}}`.
+ """
+
+ description: str
+ """
+ A description of what the function does, used by the model to choose when and
+ how to call the function.
+ """
+
+
+class CompletionCreateParamsNonStreaming(CompletionCreateParamsBase):
+ stream: Optional[Literal[False]]
+ """If set, partial message deltas will be sent, like in ChatGPT.
+
+ Tokens will be sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+ """
+
+
+class CompletionCreateParamsStreaming(CompletionCreateParamsBase):
+ stream: Required[Literal[True]]
+ """If set, partial message deltas will be sent, like in ChatGPT.
+
+ Tokens will be sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+ """
+
+
+CompletionCreateParams = Union[CompletionCreateParamsNonStreaming, CompletionCreateParamsStreaming]
diff --git a/src/openai/types/completion.py b/src/openai/types/completion.py
new file mode 100644
index 0000000000..0a90838fd4
--- /dev/null
+++ b/src/openai/types/completion.py
@@ -0,0 +1,29 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List, Optional
+
+from .._models import BaseModel
+from .completion_usage import CompletionUsage
+from .completion_choice import CompletionChoice
+
+__all__ = ["Completion"]
+
+
+class Completion(BaseModel):
+ id: str
+ """A unique identifier for the completion."""
+
+ choices: List[CompletionChoice]
+ """The list of completion choices the model generated for the input prompt."""
+
+ created: int
+ """The Unix timestamp (in seconds) of when the completion was created."""
+
+ model: str
+ """The model used for completion."""
+
+ object: str
+ """The object type, which is always "text_completion" """
+
+ usage: Optional[CompletionUsage] = None
+ """Usage statistics for the completion request."""
diff --git a/src/openai/types/completion_choice.py b/src/openai/types/completion_choice.py
new file mode 100644
index 0000000000..e86d706ed1
--- /dev/null
+++ b/src/openai/types/completion_choice.py
@@ -0,0 +1,35 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import Dict, List, Optional
+from typing_extensions import Literal
+
+from .._models import BaseModel
+
+__all__ = ["CompletionChoice", "Logprobs"]
+
+
+class Logprobs(BaseModel):
+ text_offset: Optional[List[int]] = None
+
+ token_logprobs: Optional[List[float]] = None
+
+ tokens: Optional[List[str]] = None
+
+ top_logprobs: Optional[List[Dict[str, int]]] = None
+
+
+class CompletionChoice(BaseModel):
+ finish_reason: Literal["stop", "length", "content_filter"]
+ """The reason the model stopped generating tokens.
+
+ This will be `stop` if the model hit a natural stop point or a provided stop
+ sequence, `length` if the maximum number of tokens specified in the request was
+ reached, or `content_filter` if content was omitted due to a flag from our
+ content filters.
+ """
+
+ index: int
+
+ logprobs: Optional[Logprobs]
+
+ text: str
diff --git a/src/openai/types/completion_create_params.py b/src/openai/types/completion_create_params.py
new file mode 100644
index 0000000000..023c087d5f
--- /dev/null
+++ b/src/openai/types/completion_create_params.py
@@ -0,0 +1,184 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Dict, List, Union, Optional
+from typing_extensions import Literal, Required, TypedDict
+
+__all__ = ["CompletionCreateParamsBase", "CompletionCreateParamsNonStreaming", "CompletionCreateParamsStreaming"]
+
+
+class CompletionCreateParamsBase(TypedDict, total=False):
+ model: Required[
+ Union[
+ str,
+ Literal[
+ "babbage-002",
+ "davinci-002",
+ "gpt-3.5-turbo-instruct",
+ "text-davinci-003",
+ "text-davinci-002",
+ "text-davinci-001",
+ "code-davinci-002",
+ "text-curie-001",
+ "text-babbage-001",
+ "text-ada-001",
+ ],
+ ]
+ ]
+ """ID of the model to use.
+
+ You can use the
+ [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+ see all of your available models, or see our
+ [Model overview](https://platform.openai.com/docs/models/overview) for
+ descriptions of them.
+ """
+
+ prompt: Required[Union[str, List[str], List[int], List[List[int]], None]]
+ """
+ The prompt(s) to generate completions for, encoded as a string, array of
+ strings, array of tokens, or array of token arrays.
+
+ Note that <|endoftext|> is the document separator that the model sees during
+ training, so if a prompt is not specified the model will generate as if from the
+ beginning of a new document.
+ """
+
+ best_of: Optional[int]
+ """
+ Generates `best_of` completions server-side and returns the "best" (the one with
+ the highest log probability per token). Results cannot be streamed.
+
+ When used with `n`, `best_of` controls the number of candidate completions and
+ `n` specifies how many to return – `best_of` must be greater than `n`.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+ """
+
+ echo: Optional[bool]
+ """Echo back the prompt in addition to the completion"""
+
+ frequency_penalty: Optional[float]
+ """Number between -2.0 and 2.0.
+
+ Positive values penalize new tokens based on their existing frequency in the
+ text so far, decreasing the model's likelihood to repeat the same line verbatim.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+ """
+
+ logit_bias: Optional[Dict[str, int]]
+ """Modify the likelihood of specified tokens appearing in the completion.
+
+ Accepts a json object that maps tokens (specified by their token ID in the GPT
+ tokenizer) to an associated bias value from -100 to 100. You can use this
+ [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to
+ convert text to token IDs. Mathematically, the bias is added to the logits
+ generated by the model prior to sampling. The exact effect will vary per model,
+ but values between -1 and 1 should decrease or increase likelihood of selection;
+ values like -100 or 100 should result in a ban or exclusive selection of the
+ relevant token.
+
+ As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
+ from being generated.
+ """
+
+ logprobs: Optional[int]
+ """
+ Include the log probabilities on the `logprobs` most likely tokens, as well the
+ chosen tokens. For example, if `logprobs` is 5, the API will return a list of
+ the 5 most likely tokens. The API will always return the `logprob` of the
+ sampled token, so there may be up to `logprobs+1` elements in the response.
+
+ The maximum value for `logprobs` is 5.
+ """
+
+ max_tokens: Optional[int]
+ """The maximum number of [tokens](/tokenizer) to generate in the completion.
+
+ The token count of your prompt plus `max_tokens` cannot exceed the model's
+ context length.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+ """
+
+ n: Optional[int]
+ """How many completions to generate for each prompt.
+
+ **Note:** Because this parameter generates many completions, it can quickly
+ consume your token quota. Use carefully and ensure that you have reasonable
+ settings for `max_tokens` and `stop`.
+ """
+
+ presence_penalty: Optional[float]
+ """Number between -2.0 and 2.0.
+
+ Positive values penalize new tokens based on whether they appear in the text so
+ far, increasing the model's likelihood to talk about new topics.
+
+ [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
+ """
+
+ stop: Union[Optional[str], List[str], None]
+ """Up to 4 sequences where the API will stop generating further tokens.
+
+ The returned text will not contain the stop sequence.
+ """
+
+ suffix: Optional[str]
+ """The suffix that comes after a completion of inserted text."""
+
+ temperature: Optional[float]
+ """What sampling temperature to use, between 0 and 2.
+
+ Higher values like 0.8 will make the output more random, while lower values like
+ 0.2 will make it more focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+ """
+
+ top_p: Optional[float]
+ """
+ An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+ """
+
+ user: str
+ """
+ A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+ """
+
+
+class CompletionCreateParamsNonStreaming(CompletionCreateParamsBase):
+ stream: Optional[Literal[False]]
+ """Whether to stream back partial progress.
+
+ If set, tokens will be sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+ """
+
+
+class CompletionCreateParamsStreaming(CompletionCreateParamsBase):
+ stream: Required[Literal[True]]
+ """Whether to stream back partial progress.
+
+ If set, tokens will be sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available, with the stream terminated by a `data: [DONE]`
+ message.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
+ """
+
+
+CompletionCreateParams = Union[CompletionCreateParamsNonStreaming, CompletionCreateParamsStreaming]
diff --git a/src/openai/types/completion_usage.py b/src/openai/types/completion_usage.py
new file mode 100644
index 0000000000..b825d5529f
--- /dev/null
+++ b/src/openai/types/completion_usage.py
@@ -0,0 +1,16 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from .._models import BaseModel
+
+__all__ = ["CompletionUsage"]
+
+
+class CompletionUsage(BaseModel):
+ completion_tokens: int
+ """Number of tokens in the generated completion."""
+
+ prompt_tokens: int
+ """Number of tokens in the prompt."""
+
+ total_tokens: int
+ """Total number of tokens used in the request (prompt + completion)."""
diff --git a/src/openai/types/create_embedding_response.py b/src/openai/types/create_embedding_response.py
new file mode 100644
index 0000000000..eccd148d3c
--- /dev/null
+++ b/src/openai/types/create_embedding_response.py
@@ -0,0 +1,30 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List
+
+from .._models import BaseModel
+from .embedding import Embedding
+
+__all__ = ["CreateEmbeddingResponse", "Usage"]
+
+
+class Usage(BaseModel):
+ prompt_tokens: int
+ """The number of tokens used by the prompt."""
+
+ total_tokens: int
+ """The total number of tokens used by the request."""
+
+
+class CreateEmbeddingResponse(BaseModel):
+ data: List[Embedding]
+ """The list of embeddings generated by the model."""
+
+ model: str
+ """The name of the model used to generate the embedding."""
+
+ object: str
+ """The object type, which is always "embedding"."""
+
+ usage: Usage
+ """The usage information for the request."""
diff --git a/src/openai/types/edit.py b/src/openai/types/edit.py
new file mode 100644
index 0000000000..41b327534e
--- /dev/null
+++ b/src/openai/types/edit.py
@@ -0,0 +1,40 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List
+from typing_extensions import Literal
+
+from .._models import BaseModel
+from .completion_usage import CompletionUsage
+
+__all__ = ["Edit", "Choice"]
+
+
+class Choice(BaseModel):
+ finish_reason: Literal["stop", "length"]
+ """The reason the model stopped generating tokens.
+
+ This will be `stop` if the model hit a natural stop point or a provided stop
+ sequence, `length` if the maximum number of tokens specified in the request was
+ reached, or `content_filter` if content was omitted due to a flag from our
+ content filters.
+ """
+
+ index: int
+ """The index of the choice in the list of choices."""
+
+ text: str
+ """The edited result."""
+
+
+class Edit(BaseModel):
+ choices: List[Choice]
+ """A list of edit choices. Can be more than one if `n` is greater than 1."""
+
+ created: int
+ """The Unix timestamp (in seconds) of when the edit was created."""
+
+ object: str
+ """The object type, which is always `edit`."""
+
+ usage: CompletionUsage
+ """Usage statistics for the completion request."""
diff --git a/src/openai/types/edit_create_params.py b/src/openai/types/edit_create_params.py
new file mode 100644
index 0000000000..a23b79c369
--- /dev/null
+++ b/src/openai/types/edit_create_params.py
@@ -0,0 +1,44 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Union, Optional
+from typing_extensions import Literal, Required, TypedDict
+
+__all__ = ["EditCreateParams"]
+
+
+class EditCreateParams(TypedDict, total=False):
+ instruction: Required[str]
+ """The instruction that tells the model how to edit the prompt."""
+
+ model: Required[Union[str, Literal["text-davinci-edit-001", "code-davinci-edit-001"]]]
+ """ID of the model to use.
+
+ You can use the `text-davinci-edit-001` or `code-davinci-edit-001` model with
+ this endpoint.
+ """
+
+ input: Optional[str]
+ """The input text to use as a starting point for the edit."""
+
+ n: Optional[int]
+ """How many edits to generate for the input and instruction."""
+
+ temperature: Optional[float]
+ """What sampling temperature to use, between 0 and 2.
+
+ Higher values like 0.8 will make the output more random, while lower values like
+ 0.2 will make it more focused and deterministic.
+
+ We generally recommend altering this or `top_p` but not both.
+ """
+
+ top_p: Optional[float]
+ """
+ An alternative to sampling with temperature, called nucleus sampling, where the
+ model considers the results of the tokens with top_p probability mass. So 0.1
+ means only the tokens comprising the top 10% probability mass are considered.
+
+ We generally recommend altering this or `temperature` but not both.
+ """
diff --git a/src/openai/types/embedding.py b/src/openai/types/embedding.py
new file mode 100644
index 0000000000..4579b9bb57
--- /dev/null
+++ b/src/openai/types/embedding.py
@@ -0,0 +1,22 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List
+
+from .._models import BaseModel
+
+__all__ = ["Embedding"]
+
+
+class Embedding(BaseModel):
+ embedding: List[float]
+ """The embedding vector, which is a list of floats.
+
+ The length of vector depends on the model as listed in the
+ [embedding guide](https://platform.openai.com/docs/guides/embeddings).
+ """
+
+ index: int
+ """The index of the embedding in the list of embeddings."""
+
+ object: str
+ """The object type, which is always "embedding"."""
diff --git a/src/openai/types/embedding_create_params.py b/src/openai/types/embedding_create_params.py
new file mode 100644
index 0000000000..bc8535f880
--- /dev/null
+++ b/src/openai/types/embedding_create_params.py
@@ -0,0 +1,43 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import List, Union
+from typing_extensions import Literal, Required, TypedDict
+
+__all__ = ["EmbeddingCreateParams"]
+
+
+class EmbeddingCreateParams(TypedDict, total=False):
+ input: Required[Union[str, List[str], List[int], List[List[int]]]]
+ """Input text to embed, encoded as a string or array of tokens.
+
+ To embed multiple inputs in a single request, pass an array of strings or array
+ of token arrays. The input must not exceed the max input tokens for the model
+ (8192 tokens for `text-embedding-ada-002`) and cannot be an empty string.
+ [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
+ for counting tokens.
+ """
+
+ model: Required[Union[str, Literal["text-embedding-ada-002"]]]
+ """ID of the model to use.
+
+ You can use the
+ [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+ see all of your available models, or see our
+ [Model overview](https://platform.openai.com/docs/models/overview) for
+ descriptions of them.
+ """
+
+ encoding_format: Literal["float", "base64"]
+ """The format to return the embeddings in.
+
+ Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
+ """
+
+ user: str
+ """
+ A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+ """
diff --git a/src/openai/types/file_content.py b/src/openai/types/file_content.py
new file mode 100644
index 0000000000..92b316b9eb
--- /dev/null
+++ b/src/openai/types/file_content.py
@@ -0,0 +1,6 @@
+# File generated from our OpenAPI spec by Stainless.
+
+
+__all__ = ["FileContent"]
+
+FileContent = str
diff --git a/src/openai/types/file_create_params.py b/src/openai/types/file_create_params.py
new file mode 100644
index 0000000000..07b068c5c6
--- /dev/null
+++ b/src/openai/types/file_create_params.py
@@ -0,0 +1,26 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing_extensions import Required, TypedDict
+
+from .._types import FileTypes
+
+__all__ = ["FileCreateParams"]
+
+
+class FileCreateParams(TypedDict, total=False):
+ file: Required[FileTypes]
+ """The file object (not file name) to be uploaded.
+
+ If the `purpose` is set to "fine-tune", the file will be used for fine-tuning.
+ """
+
+ purpose: Required[str]
+ """The intended purpose of the uploaded file.
+
+ Use "fine-tune" for
+ [fine-tuning](https://platform.openai.com/docs/api-reference/fine-tuning). This
+ allows us to validate the format of the uploaded file is correct for
+ fine-tuning.
+ """
diff --git a/src/openai/types/file_deleted.py b/src/openai/types/file_deleted.py
new file mode 100644
index 0000000000..a526b2b986
--- /dev/null
+++ b/src/openai/types/file_deleted.py
@@ -0,0 +1,13 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from .._models import BaseModel
+
+__all__ = ["FileDeleted"]
+
+
+class FileDeleted(BaseModel):
+ id: str
+
+ deleted: bool
+
+ object: str
diff --git a/src/openai/types/file_object.py b/src/openai/types/file_object.py
new file mode 100644
index 0000000000..dac24a88c5
--- /dev/null
+++ b/src/openai/types/file_object.py
@@ -0,0 +1,40 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import Optional
+
+from .._models import BaseModel
+
+__all__ = ["FileObject"]
+
+
+class FileObject(BaseModel):
+ id: str
+ """The file identifier, which can be referenced in the API endpoints."""
+
+ bytes: int
+ """The size of the file in bytes."""
+
+ created_at: int
+ """The Unix timestamp (in seconds) for when the file was created."""
+
+ filename: str
+ """The name of the file."""
+
+ object: str
+ """The object type, which is always "file"."""
+
+ purpose: str
+ """The intended purpose of the file. Currently, only "fine-tune" is supported."""
+
+ status: Optional[str] = None
+ """
+ The current status of the file, which can be either `uploaded`, `processed`,
+ `pending`, `error`, `deleting` or `deleted`.
+ """
+
+ status_details: Optional[str] = None
+ """Additional details about the status of the file.
+
+ If the file is in the `error` state, this will include a message describing the
+ error.
+ """
diff --git a/src/openai/types/fine_tune.py b/src/openai/types/fine_tune.py
new file mode 100644
index 0000000000..4124def2f5
--- /dev/null
+++ b/src/openai/types/fine_tune.py
@@ -0,0 +1,93 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List, Optional
+
+from .._models import BaseModel
+from .file_object import FileObject
+from .fine_tune_event import FineTuneEvent
+
+__all__ = ["FineTune", "Hyperparams"]
+
+
+class Hyperparams(BaseModel):
+ batch_size: int
+ """The batch size to use for training.
+
+ The batch size is the number of training examples used to train a single forward
+ and backward pass.
+ """
+
+ learning_rate_multiplier: float
+ """The learning rate multiplier to use for training."""
+
+ n_epochs: int
+ """The number of epochs to train the model for.
+
+ An epoch refers to one full cycle through the training dataset.
+ """
+
+ prompt_loss_weight: float
+ """The weight to use for loss on the prompt tokens."""
+
+ classification_n_classes: Optional[int] = None
+ """The number of classes to use for computing classification metrics."""
+
+ classification_positive_class: Optional[str] = None
+ """The positive class to use for computing classification metrics."""
+
+ compute_classification_metrics: Optional[bool] = None
+ """
+ The classification metrics to compute using the validation dataset at the end of
+ every epoch.
+ """
+
+
+class FineTune(BaseModel):
+ id: str
+ """The object identifier, which can be referenced in the API endpoints."""
+
+ created_at: int
+ """The Unix timestamp (in seconds) for when the fine-tuning job was created."""
+
+ fine_tuned_model: Optional[str]
+ """The name of the fine-tuned model that is being created."""
+
+ hyperparams: Hyperparams
+ """The hyperparameters used for the fine-tuning job.
+
+ See the
+ [fine-tuning guide](https://platform.openai.com/docs/guides/legacy-fine-tuning/hyperparameters)
+ for more details.
+ """
+
+ model: str
+ """The base model that is being fine-tuned."""
+
+ object: str
+ """The object type, which is always "fine-tune"."""
+
+ organization_id: str
+ """The organization that owns the fine-tuning job."""
+
+ result_files: List[FileObject]
+ """The compiled results files for the fine-tuning job."""
+
+ status: str
+ """
+ The current status of the fine-tuning job, which can be either `created`,
+ `running`, `succeeded`, `failed`, or `cancelled`.
+ """
+
+ training_files: List[FileObject]
+ """The list of files used for training."""
+
+ updated_at: int
+ """The Unix timestamp (in seconds) for when the fine-tuning job was last updated."""
+
+ validation_files: List[FileObject]
+ """The list of files used for validation."""
+
+ events: Optional[List[FineTuneEvent]] = None
+ """
+ The list of events that have been observed in the lifecycle of the FineTune job.
+ """
diff --git a/src/openai/types/fine_tune_create_params.py b/src/openai/types/fine_tune_create_params.py
new file mode 100644
index 0000000000..1be9c9ea04
--- /dev/null
+++ b/src/openai/types/fine_tune_create_params.py
@@ -0,0 +1,140 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import List, Union, Optional
+from typing_extensions import Literal, Required, TypedDict
+
+__all__ = ["FineTuneCreateParams", "Hyperparameters"]
+
+
+class FineTuneCreateParams(TypedDict, total=False):
+ training_file: Required[str]
+ """The ID of an uploaded file that contains training data.
+
+ See [upload file](https://platform.openai.com/docs/api-reference/files/upload)
+ for how to upload a file.
+
+ Your dataset must be formatted as a JSONL file, where each training example is a
+ JSON object with the keys "prompt" and "completion". Additionally, you must
+ upload your file with the purpose `fine-tune`.
+
+ See the
+ [fine-tuning guide](https://platform.openai.com/docs/guides/legacy-fine-tuning/creating-training-data)
+ for more details.
+ """
+
+ batch_size: Optional[int]
+ """The batch size to use for training.
+
+ The batch size is the number of training examples used to train a single forward
+ and backward pass.
+
+ By default, the batch size will be dynamically configured to be ~0.2% of the
+ number of examples in the training set, capped at 256 - in general, we've found
+ that larger batch sizes tend to work better for larger datasets.
+ """
+
+ classification_betas: Optional[List[float]]
+ """If this is provided, we calculate F-beta scores at the specified beta values.
+
+ The F-beta score is a generalization of F-1 score. This is only used for binary
+ classification.
+
+ With a beta of 1 (i.e. the F-1 score), precision and recall are given the same
+ weight. A larger beta score puts more weight on recall and less on precision. A
+ smaller beta score puts more weight on precision and less on recall.
+ """
+
+ classification_n_classes: Optional[int]
+ """The number of classes in a classification task.
+
+ This parameter is required for multiclass classification.
+ """
+
+ classification_positive_class: Optional[str]
+ """The positive class in binary classification.
+
+ This parameter is needed to generate precision, recall, and F1 metrics when
+ doing binary classification.
+ """
+
+ compute_classification_metrics: Optional[bool]
+ """
+ If set, we calculate classification-specific metrics such as accuracy and F-1
+ score using the validation set at the end of every epoch. These metrics can be
+ viewed in the
+ [results file](https://platform.openai.com/docs/guides/legacy-fine-tuning/analyzing-your-fine-tuned-model).
+
+ In order to compute classification metrics, you must provide a
+ `validation_file`. Additionally, you must specify `classification_n_classes` for
+ multiclass classification or `classification_positive_class` for binary
+ classification.
+ """
+
+ hyperparameters: Hyperparameters
+ """The hyperparameters used for the fine-tuning job."""
+
+ learning_rate_multiplier: Optional[float]
+ """
+ The learning rate multiplier to use for training. The fine-tuning learning rate
+ is the original learning rate used for pretraining multiplied by this value.
+
+ By default, the learning rate multiplier is the 0.05, 0.1, or 0.2 depending on
+ final `batch_size` (larger learning rates tend to perform better with larger
+ batch sizes). We recommend experimenting with values in the range 0.02 to 0.2 to
+ see what produces the best results.
+ """
+
+ model: Union[str, Literal["ada", "babbage", "curie", "davinci"], None]
+ """The name of the base model to fine-tune.
+
+ You can select one of "ada", "babbage", "curie", "davinci", or a fine-tuned
+ model created after 2022-04-21 and before 2023-08-22. To learn more about these
+ models, see the [Models](https://platform.openai.com/docs/models) documentation.
+ """
+
+ prompt_loss_weight: Optional[float]
+ """The weight to use for loss on the prompt tokens.
+
+ This controls how much the model tries to learn to generate the prompt (as
+ compared to the completion which always has a weight of 1.0), and can add a
+ stabilizing effect to training when completions are short.
+
+ If prompts are extremely long (relative to completions), it may make sense to
+ reduce this weight so as to avoid over-prioritizing learning the prompt.
+ """
+
+ suffix: Optional[str]
+ """
+ A string of up to 40 characters that will be added to your fine-tuned model
+ name.
+
+ For example, a `suffix` of "custom-model-name" would produce a model name like
+ `ada:ft-your-org:custom-model-name-2022-02-15-04-21-04`.
+ """
+
+ validation_file: Optional[str]
+ """The ID of an uploaded file that contains validation data.
+
+ If you provide this file, the data is used to generate validation metrics
+ periodically during fine-tuning. These metrics can be viewed in the
+ [fine-tuning results file](https://platform.openai.com/docs/guides/legacy-fine-tuning/analyzing-your-fine-tuned-model).
+ Your train and validation data should be mutually exclusive.
+
+ Your dataset must be formatted as a JSONL file, where each validation example is
+ a JSON object with the keys "prompt" and "completion". Additionally, you must
+ upload your file with the purpose `fine-tune`.
+
+ See the
+ [fine-tuning guide](https://platform.openai.com/docs/guides/legacy-fine-tuning/creating-training-data)
+ for more details.
+ """
+
+
+class Hyperparameters(TypedDict, total=False):
+ n_epochs: Union[Literal["auto"], int]
+ """The number of epochs to train the model for.
+
+ An epoch refers to one full cycle through the training dataset.
+ """
diff --git a/src/openai/types/fine_tune_event.py b/src/openai/types/fine_tune_event.py
new file mode 100644
index 0000000000..6499def98d
--- /dev/null
+++ b/src/openai/types/fine_tune_event.py
@@ -0,0 +1,15 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from .._models import BaseModel
+
+__all__ = ["FineTuneEvent"]
+
+
+class FineTuneEvent(BaseModel):
+ created_at: int
+
+ level: str
+
+ message: str
+
+ object: str
diff --git a/src/openai/types/fine_tune_events_list_response.py b/src/openai/types/fine_tune_events_list_response.py
new file mode 100644
index 0000000000..ca159d8772
--- /dev/null
+++ b/src/openai/types/fine_tune_events_list_response.py
@@ -0,0 +1,14 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List
+
+from .._models import BaseModel
+from .fine_tune_event import FineTuneEvent
+
+__all__ = ["FineTuneEventsListResponse"]
+
+
+class FineTuneEventsListResponse(BaseModel):
+ data: List[FineTuneEvent]
+
+ object: str
diff --git a/src/openai/types/fine_tune_list_events_params.py b/src/openai/types/fine_tune_list_events_params.py
new file mode 100644
index 0000000000..1f23b108e6
--- /dev/null
+++ b/src/openai/types/fine_tune_list_events_params.py
@@ -0,0 +1,41 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Union
+from typing_extensions import Literal, Required, TypedDict
+
+__all__ = ["FineTuneListEventsParamsBase", "FineTuneListEventsParamsNonStreaming", "FineTuneListEventsParamsStreaming"]
+
+
+class FineTuneListEventsParamsBase(TypedDict, total=False):
+ pass
+
+
+class FineTuneListEventsParamsNonStreaming(FineTuneListEventsParamsBase):
+ stream: Literal[False]
+ """Whether to stream events for the fine-tune job.
+
+ If set to true, events will be sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available. The stream will terminate with a `data: [DONE]`
+ message when the job is finished (succeeded, cancelled, or failed).
+
+ If set to false, only events generated so far will be returned.
+ """
+
+
+class FineTuneListEventsParamsStreaming(FineTuneListEventsParamsBase):
+ stream: Required[Literal[True]]
+ """Whether to stream events for the fine-tune job.
+
+ If set to true, events will be sent as data-only
+ [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+ as they become available. The stream will terminate with a `data: [DONE]`
+ message when the job is finished (succeeded, cancelled, or failed).
+
+ If set to false, only events generated so far will be returned.
+ """
+
+
+FineTuneListEventsParams = Union[FineTuneListEventsParamsNonStreaming, FineTuneListEventsParamsStreaming]
diff --git a/src/openai/types/fine_tuning/__init__.py b/src/openai/types/fine_tuning/__init__.py
new file mode 100644
index 0000000000..d24160c5bd
--- /dev/null
+++ b/src/openai/types/fine_tuning/__init__.py
@@ -0,0 +1,9 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from .fine_tuning_job import FineTuningJob as FineTuningJob
+from .job_list_params import JobListParams as JobListParams
+from .job_create_params import JobCreateParams as JobCreateParams
+from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
+from .job_list_events_params import JobListEventsParams as JobListEventsParams
diff --git a/src/openai/types/fine_tuning/fine_tuning_job.py b/src/openai/types/fine_tuning/fine_tuning_job.py
new file mode 100644
index 0000000000..2ae1cbb473
--- /dev/null
+++ b/src/openai/types/fine_tuning/fine_tuning_job.py
@@ -0,0 +1,107 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List, Union, Optional
+from typing_extensions import Literal
+
+from ..._models import BaseModel
+
+__all__ = ["FineTuningJob", "Error", "Hyperparameters"]
+
+
+class Error(BaseModel):
+ code: str
+ """A machine-readable error code."""
+
+ message: str
+ """A human-readable error message."""
+
+ param: Optional[str]
+ """The parameter that was invalid, usually `training_file` or `validation_file`.
+
+ This field will be null if the failure was not parameter-specific.
+ """
+
+
+class Hyperparameters(BaseModel):
+ n_epochs: Union[Literal["auto"], int]
+ """The number of epochs to train the model for.
+
+ An epoch refers to one full cycle through the training dataset. "auto" decides
+ the optimal number of epochs based on the size of the dataset. If setting the
+ number manually, we support any number between 1 and 50 epochs.
+ """
+
+
+class FineTuningJob(BaseModel):
+ id: str
+ """The object identifier, which can be referenced in the API endpoints."""
+
+ created_at: int
+ """The Unix timestamp (in seconds) for when the fine-tuning job was created."""
+
+ error: Optional[Error]
+ """
+ For fine-tuning jobs that have `failed`, this will contain more information on
+ the cause of the failure.
+ """
+
+ fine_tuned_model: Optional[str]
+ """The name of the fine-tuned model that is being created.
+
+ The value will be null if the fine-tuning job is still running.
+ """
+
+ finished_at: Optional[int]
+ """The Unix timestamp (in seconds) for when the fine-tuning job was finished.
+
+ The value will be null if the fine-tuning job is still running.
+ """
+
+ hyperparameters: Hyperparameters
+ """The hyperparameters used for the fine-tuning job.
+
+ See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+ for more details.
+ """
+
+ model: str
+ """The base model that is being fine-tuned."""
+
+ object: str
+ """The object type, which is always "fine_tuning.job"."""
+
+ organization_id: str
+ """The organization that owns the fine-tuning job."""
+
+ result_files: List[str]
+ """The compiled results file ID(s) for the fine-tuning job.
+
+ You can retrieve the results with the
+ [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
+ """
+
+ status: str
+ """
+ The current status of the fine-tuning job, which can be either
+ `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`.
+ """
+
+ trained_tokens: Optional[int]
+ """The total number of billable tokens processed by this fine-tuning job.
+
+ The value will be null if the fine-tuning job is still running.
+ """
+
+ training_file: str
+ """The file ID used for training.
+
+ You can retrieve the training data with the
+ [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
+ """
+
+ validation_file: Optional[str]
+ """The file ID used for validation.
+
+ You can retrieve the validation results with the
+ [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
+ """
diff --git a/src/openai/types/fine_tuning/fine_tuning_job_event.py b/src/openai/types/fine_tuning/fine_tuning_job_event.py
new file mode 100644
index 0000000000..c21a0503ab
--- /dev/null
+++ b/src/openai/types/fine_tuning/fine_tuning_job_event.py
@@ -0,0 +1,19 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing_extensions import Literal
+
+from ..._models import BaseModel
+
+__all__ = ["FineTuningJobEvent"]
+
+
+class FineTuningJobEvent(BaseModel):
+ id: str
+
+ created_at: int
+
+ level: Literal["info", "warn", "error"]
+
+ message: str
+
+ object: str
diff --git a/src/openai/types/fine_tuning/job_create_params.py b/src/openai/types/fine_tuning/job_create_params.py
new file mode 100644
index 0000000000..2a67b81817
--- /dev/null
+++ b/src/openai/types/fine_tuning/job_create_params.py
@@ -0,0 +1,65 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Union, Optional
+from typing_extensions import Literal, Required, TypedDict
+
+__all__ = ["JobCreateParams", "Hyperparameters"]
+
+
+class JobCreateParams(TypedDict, total=False):
+ model: Required[Union[str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo"]]]
+ """The name of the model to fine-tune.
+
+ You can select one of the
+ [supported models](https://platform.openai.com/docs/guides/fine-tuning/what-models-can-be-fine-tuned).
+ """
+
+ training_file: Required[str]
+ """The ID of an uploaded file that contains training data.
+
+ See [upload file](https://platform.openai.com/docs/api-reference/files/upload)
+ for how to upload a file.
+
+ Your dataset must be formatted as a JSONL file. Additionally, you must upload
+ your file with the purpose `fine-tune`.
+
+ See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+ for more details.
+ """
+
+ hyperparameters: Hyperparameters
+ """The hyperparameters used for the fine-tuning job."""
+
+ suffix: Optional[str]
+ """
+ A string of up to 18 characters that will be added to your fine-tuned model
+ name.
+
+ For example, a `suffix` of "custom-model-name" would produce a model name like
+ `ft:gpt-3.5-turbo:openai:custom-model-name:7p4lURel`.
+ """
+
+ validation_file: Optional[str]
+ """The ID of an uploaded file that contains validation data.
+
+ If you provide this file, the data is used to generate validation metrics
+ periodically during fine-tuning. These metrics can be viewed in the fine-tuning
+ results file. The same data should not be present in both train and validation
+ files.
+
+ Your dataset must be formatted as a JSONL file. You must upload your file with
+ the purpose `fine-tune`.
+
+ See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+ for more details.
+ """
+
+
+class Hyperparameters(TypedDict, total=False):
+ n_epochs: Union[Literal["auto"], int]
+ """The number of epochs to train the model for.
+
+ An epoch refers to one full cycle through the training dataset.
+ """
diff --git a/src/openai/types/fine_tuning/job_list_events_params.py b/src/openai/types/fine_tuning/job_list_events_params.py
new file mode 100644
index 0000000000..7be3d53315
--- /dev/null
+++ b/src/openai/types/fine_tuning/job_list_events_params.py
@@ -0,0 +1,15 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing_extensions import TypedDict
+
+__all__ = ["JobListEventsParams"]
+
+
+class JobListEventsParams(TypedDict, total=False):
+ after: str
+ """Identifier for the last event from the previous pagination request."""
+
+ limit: int
+ """Number of events to retrieve."""
diff --git a/src/openai/types/fine_tuning/job_list_params.py b/src/openai/types/fine_tuning/job_list_params.py
new file mode 100644
index 0000000000..8160136901
--- /dev/null
+++ b/src/openai/types/fine_tuning/job_list_params.py
@@ -0,0 +1,15 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing_extensions import TypedDict
+
+__all__ = ["JobListParams"]
+
+
+class JobListParams(TypedDict, total=False):
+ after: str
+ """Identifier for the last job from the previous pagination request."""
+
+ limit: int
+ """Number of fine-tuning jobs to retrieve."""
diff --git a/src/openai/types/image.py b/src/openai/types/image.py
new file mode 100644
index 0000000000..4b8d1aaf18
--- /dev/null
+++ b/src/openai/types/image.py
@@ -0,0 +1,18 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import Optional
+
+from .._models import BaseModel
+
+__all__ = ["Image"]
+
+
+class Image(BaseModel):
+ b64_json: Optional[str] = None
+ """
+ The base64-encoded JSON of the generated image, if `response_format` is
+ `b64_json`.
+ """
+
+ url: Optional[str] = None
+ """The URL of the generated image, if `response_format` is `url` (default)."""
diff --git a/src/openai/types/image_create_variation_params.py b/src/openai/types/image_create_variation_params.py
new file mode 100644
index 0000000000..d3b439070e
--- /dev/null
+++ b/src/openai/types/image_create_variation_params.py
@@ -0,0 +1,40 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Optional
+from typing_extensions import Literal, Required, TypedDict
+
+from .._types import FileTypes
+
+__all__ = ["ImageCreateVariationParams"]
+
+
+class ImageCreateVariationParams(TypedDict, total=False):
+ image: Required[FileTypes]
+ """The image to use as the basis for the variation(s).
+
+ Must be a valid PNG file, less than 4MB, and square.
+ """
+
+ n: Optional[int]
+ """The number of images to generate. Must be between 1 and 10."""
+
+ response_format: Optional[Literal["url", "b64_json"]]
+ """The format in which the generated images are returned.
+
+ Must be one of `url` or `b64_json`.
+ """
+
+ size: Optional[Literal["256x256", "512x512", "1024x1024"]]
+ """The size of the generated images.
+
+ Must be one of `256x256`, `512x512`, or `1024x1024`.
+ """
+
+ user: str
+ """
+ A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+ """
diff --git a/src/openai/types/image_edit_params.py b/src/openai/types/image_edit_params.py
new file mode 100644
index 0000000000..ce07a9cb30
--- /dev/null
+++ b/src/openai/types/image_edit_params.py
@@ -0,0 +1,54 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Optional
+from typing_extensions import Literal, Required, TypedDict
+
+from .._types import FileTypes
+
+__all__ = ["ImageEditParams"]
+
+
+class ImageEditParams(TypedDict, total=False):
+ image: Required[FileTypes]
+ """The image to edit.
+
+ Must be a valid PNG file, less than 4MB, and square. If mask is not provided,
+ image must have transparency, which will be used as the mask.
+ """
+
+ prompt: Required[str]
+ """A text description of the desired image(s).
+
+ The maximum length is 1000 characters.
+ """
+
+ mask: FileTypes
+ """An additional image whose fully transparent areas (e.g.
+
+ where alpha is zero) indicate where `image` should be edited. Must be a valid
+ PNG file, less than 4MB, and have the same dimensions as `image`.
+ """
+
+ n: Optional[int]
+ """The number of images to generate. Must be between 1 and 10."""
+
+ response_format: Optional[Literal["url", "b64_json"]]
+ """The format in which the generated images are returned.
+
+ Must be one of `url` or `b64_json`.
+ """
+
+ size: Optional[Literal["256x256", "512x512", "1024x1024"]]
+ """The size of the generated images.
+
+ Must be one of `256x256`, `512x512`, or `1024x1024`.
+ """
+
+ user: str
+ """
+ A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+ """
diff --git a/src/openai/types/image_generate_params.py b/src/openai/types/image_generate_params.py
new file mode 100644
index 0000000000..4999ed958d
--- /dev/null
+++ b/src/openai/types/image_generate_params.py
@@ -0,0 +1,38 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import Optional
+from typing_extensions import Literal, Required, TypedDict
+
+__all__ = ["ImageGenerateParams"]
+
+
+class ImageGenerateParams(TypedDict, total=False):
+ prompt: Required[str]
+ """A text description of the desired image(s).
+
+ The maximum length is 1000 characters.
+ """
+
+ n: Optional[int]
+ """The number of images to generate. Must be between 1 and 10."""
+
+ response_format: Optional[Literal["url", "b64_json"]]
+ """The format in which the generated images are returned.
+
+ Must be one of `url` or `b64_json`.
+ """
+
+ size: Optional[Literal["256x256", "512x512", "1024x1024"]]
+ """The size of the generated images.
+
+ Must be one of `256x256`, `512x512`, or `1024x1024`.
+ """
+
+ user: str
+ """
+ A unique identifier representing your end-user, which can help OpenAI to monitor
+ and detect abuse.
+ [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
+ """
diff --git a/src/openai/types/images_response.py b/src/openai/types/images_response.py
new file mode 100644
index 0000000000..9d1bc95a42
--- /dev/null
+++ b/src/openai/types/images_response.py
@@ -0,0 +1,14 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List
+
+from .image import Image
+from .._models import BaseModel
+
+__all__ = ["ImagesResponse"]
+
+
+class ImagesResponse(BaseModel):
+ created: int
+
+ data: List[Image]
diff --git a/src/openai/types/model.py b/src/openai/types/model.py
new file mode 100644
index 0000000000..29e71b81a0
--- /dev/null
+++ b/src/openai/types/model.py
@@ -0,0 +1,19 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from .._models import BaseModel
+
+__all__ = ["Model"]
+
+
+class Model(BaseModel):
+ id: str
+ """The model identifier, which can be referenced in the API endpoints."""
+
+ created: int
+ """The Unix timestamp (in seconds) when the model was created."""
+
+ object: str
+ """The object type, which is always "model"."""
+
+ owned_by: str
+ """The organization that owns the model."""
diff --git a/src/openai/types/model_deleted.py b/src/openai/types/model_deleted.py
new file mode 100644
index 0000000000..5329da1378
--- /dev/null
+++ b/src/openai/types/model_deleted.py
@@ -0,0 +1,13 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from .._models import BaseModel
+
+__all__ = ["ModelDeleted"]
+
+
+class ModelDeleted(BaseModel):
+ id: str
+
+ deleted: bool
+
+ object: str
diff --git a/src/openai/types/moderation.py b/src/openai/types/moderation.py
new file mode 100644
index 0000000000..bf586fc24a
--- /dev/null
+++ b/src/openai/types/moderation.py
@@ -0,0 +1,120 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from pydantic import Field as FieldInfo
+
+from .._models import BaseModel
+
+__all__ = ["Moderation", "Categories", "CategoryScores"]
+
+
+class Categories(BaseModel):
+ harassment: bool
+ """
+ Content that expresses, incites, or promotes harassing language towards any
+ target.
+ """
+
+ harassment_threatening: bool = FieldInfo(alias="harassment/threatening")
+ """
+ Harassment content that also includes violence or serious harm towards any
+ target.
+ """
+
+ hate: bool
+ """
+ Content that expresses, incites, or promotes hate based on race, gender,
+ ethnicity, religion, nationality, sexual orientation, disability status, or
+ caste. Hateful content aimed at non-protected groups (e.g., chess players) is
+ harrassment.
+ """
+
+ hate_threatening: bool = FieldInfo(alias="hate/threatening")
+ """
+ Hateful content that also includes violence or serious harm towards the targeted
+ group based on race, gender, ethnicity, religion, nationality, sexual
+ orientation, disability status, or caste.
+ """
+
+ self_minus_harm: bool = FieldInfo(alias="self-harm")
+ """
+ Content that promotes, encourages, or depicts acts of self-harm, such as
+ suicide, cutting, and eating disorders.
+ """
+
+ self_minus_harm_instructions: bool = FieldInfo(alias="self-harm/instructions")
+ """
+ Content that encourages performing acts of self-harm, such as suicide, cutting,
+ and eating disorders, or that gives instructions or advice on how to commit such
+ acts.
+ """
+
+ self_minus_harm_intent: bool = FieldInfo(alias="self-harm/intent")
+ """
+ Content where the speaker expresses that they are engaging or intend to engage
+ in acts of self-harm, such as suicide, cutting, and eating disorders.
+ """
+
+ sexual: bool
+ """
+ Content meant to arouse sexual excitement, such as the description of sexual
+ activity, or that promotes sexual services (excluding sex education and
+ wellness).
+ """
+
+ sexual_minors: bool = FieldInfo(alias="sexual/minors")
+ """Sexual content that includes an individual who is under 18 years old."""
+
+ violence: bool
+ """Content that depicts death, violence, or physical injury."""
+
+ violence_graphic: bool = FieldInfo(alias="violence/graphic")
+ """Content that depicts death, violence, or physical injury in graphic detail."""
+
+
+class CategoryScores(BaseModel):
+ harassment: float
+ """The score for the category 'harassment'."""
+
+ harassment_threatening: float = FieldInfo(alias="harassment/threatening")
+ """The score for the category 'harassment/threatening'."""
+
+ hate: float
+ """The score for the category 'hate'."""
+
+ hate_threatening: float = FieldInfo(alias="hate/threatening")
+ """The score for the category 'hate/threatening'."""
+
+ self_minus_harm: float = FieldInfo(alias="self-harm")
+ """The score for the category 'self-harm'."""
+
+ self_minus_harm_instructions: float = FieldInfo(alias="self-harm/instructions")
+ """The score for the category 'self-harm/instructions'."""
+
+ self_minus_harm_intent: float = FieldInfo(alias="self-harm/intent")
+ """The score for the category 'self-harm/intent'."""
+
+ sexual: float
+ """The score for the category 'sexual'."""
+
+ sexual_minors: float = FieldInfo(alias="sexual/minors")
+ """The score for the category 'sexual/minors'."""
+
+ violence: float
+ """The score for the category 'violence'."""
+
+ violence_graphic: float = FieldInfo(alias="violence/graphic")
+ """The score for the category 'violence/graphic'."""
+
+
+class Moderation(BaseModel):
+ categories: Categories
+ """A list of the categories, and whether they are flagged or not."""
+
+ category_scores: CategoryScores
+ """A list of the categories along with their scores as predicted by model."""
+
+ flagged: bool
+ """
+ Whether the content violates
+ [OpenAI's usage policies](/policies/usage-policies).
+ """
diff --git a/src/openai/types/moderation_create_params.py b/src/openai/types/moderation_create_params.py
new file mode 100644
index 0000000000..25ed3ce940
--- /dev/null
+++ b/src/openai/types/moderation_create_params.py
@@ -0,0 +1,25 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import List, Union
+from typing_extensions import Literal, Required, TypedDict
+
+__all__ = ["ModerationCreateParams"]
+
+
+class ModerationCreateParams(TypedDict, total=False):
+ input: Required[Union[str, List[str]]]
+ """The input text to classify"""
+
+ model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]]
+ """
+ Two content moderations models are available: `text-moderation-stable` and
+ `text-moderation-latest`.
+
+ The default is `text-moderation-latest` which will be automatically upgraded
+ over time. This ensures you are always using our most accurate model. If you use
+ `text-moderation-stable`, we will provide advanced notice before updating the
+ model. Accuracy of `text-moderation-stable` may be slightly lower than for
+ `text-moderation-latest`.
+ """
diff --git a/src/openai/types/moderation_create_response.py b/src/openai/types/moderation_create_response.py
new file mode 100644
index 0000000000..0962cdbfd9
--- /dev/null
+++ b/src/openai/types/moderation_create_response.py
@@ -0,0 +1,19 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import List
+
+from .._models import BaseModel
+from .moderation import Moderation
+
+__all__ = ["ModerationCreateResponse"]
+
+
+class ModerationCreateResponse(BaseModel):
+ id: str
+ """The unique identifier for the moderation request."""
+
+ model: str
+ """The model used to generate the moderation results."""
+
+ results: List[Moderation]
+ """A list of moderation objects."""
diff --git a/src/openai/version.py b/src/openai/version.py
new file mode 100644
index 0000000000..01a08ab5a9
--- /dev/null
+++ b/src/openai/version.py
@@ -0,0 +1,3 @@
+from ._version import __version__
+
+VERSION: str = __version__
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000000..1016754ef3
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# File generated from our OpenAPI spec by Stainless.
diff --git a/tests/api_resources/__init__.py b/tests/api_resources/__init__.py
new file mode 100644
index 0000000000..1016754ef3
--- /dev/null
+++ b/tests/api_resources/__init__.py
@@ -0,0 +1 @@
+# File generated from our OpenAPI spec by Stainless.
diff --git a/tests/api_resources/audio/__init__.py b/tests/api_resources/audio/__init__.py
new file mode 100644
index 0000000000..1016754ef3
--- /dev/null
+++ b/tests/api_resources/audio/__init__.py
@@ -0,0 +1 @@
+# File generated from our OpenAPI spec by Stainless.
diff --git a/tests/api_resources/audio/test_transcriptions.py b/tests/api_resources/audio/test_transcriptions.py
new file mode 100644
index 0000000000..aefdf1790f
--- /dev/null
+++ b/tests/api_resources/audio/test_transcriptions.py
@@ -0,0 +1,87 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai._client import OpenAI, AsyncOpenAI
+from openai.types.audio import Transcription
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestTranscriptions:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create(self, client: OpenAI) -> None:
+ transcription = client.audio.transcriptions.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ )
+ assert_matches_type(Transcription, transcription, path=["response"])
+
+ @parametrize
+ def test_method_create_with_all_params(self, client: OpenAI) -> None:
+ transcription = client.audio.transcriptions.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ language="string",
+ prompt="string",
+ response_format="json",
+ temperature=0,
+ )
+ assert_matches_type(Transcription, transcription, path=["response"])
+
+ @parametrize
+ def test_raw_response_create(self, client: OpenAI) -> None:
+ response = client.audio.transcriptions.with_raw_response.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ transcription = response.parse()
+ assert_matches_type(Transcription, transcription, path=["response"])
+
+
+class TestAsyncTranscriptions:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create(self, client: AsyncOpenAI) -> None:
+ transcription = await client.audio.transcriptions.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ )
+ assert_matches_type(Transcription, transcription, path=["response"])
+
+ @parametrize
+ async def test_method_create_with_all_params(self, client: AsyncOpenAI) -> None:
+ transcription = await client.audio.transcriptions.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ language="string",
+ prompt="string",
+ response_format="json",
+ temperature=0,
+ )
+ assert_matches_type(Transcription, transcription, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create(self, client: AsyncOpenAI) -> None:
+ response = await client.audio.transcriptions.with_raw_response.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ transcription = response.parse()
+ assert_matches_type(Transcription, transcription, path=["response"])
diff --git a/tests/api_resources/audio/test_translations.py b/tests/api_resources/audio/test_translations.py
new file mode 100644
index 0000000000..0657e80eb8
--- /dev/null
+++ b/tests/api_resources/audio/test_translations.py
@@ -0,0 +1,85 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai._client import OpenAI, AsyncOpenAI
+from openai.types.audio import Translation
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestTranslations:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create(self, client: OpenAI) -> None:
+ translation = client.audio.translations.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ )
+ assert_matches_type(Translation, translation, path=["response"])
+
+ @parametrize
+ def test_method_create_with_all_params(self, client: OpenAI) -> None:
+ translation = client.audio.translations.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ prompt="string",
+ response_format="string",
+ temperature=0,
+ )
+ assert_matches_type(Translation, translation, path=["response"])
+
+ @parametrize
+ def test_raw_response_create(self, client: OpenAI) -> None:
+ response = client.audio.translations.with_raw_response.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ translation = response.parse()
+ assert_matches_type(Translation, translation, path=["response"])
+
+
+class TestAsyncTranslations:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create(self, client: AsyncOpenAI) -> None:
+ translation = await client.audio.translations.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ )
+ assert_matches_type(Translation, translation, path=["response"])
+
+ @parametrize
+ async def test_method_create_with_all_params(self, client: AsyncOpenAI) -> None:
+ translation = await client.audio.translations.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ prompt="string",
+ response_format="string",
+ temperature=0,
+ )
+ assert_matches_type(Translation, translation, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create(self, client: AsyncOpenAI) -> None:
+ response = await client.audio.translations.with_raw_response.create(
+ file=b"raw file contents",
+ model="whisper-1",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ translation = response.parse()
+ assert_matches_type(Translation, translation, path=["response"])
diff --git a/tests/api_resources/chat/__init__.py b/tests/api_resources/chat/__init__.py
new file mode 100644
index 0000000000..1016754ef3
--- /dev/null
+++ b/tests/api_resources/chat/__init__.py
@@ -0,0 +1 @@
+# File generated from our OpenAPI spec by Stainless.
diff --git a/tests/api_resources/chat/test_completions.py b/tests/api_resources/chat/test_completions.py
new file mode 100644
index 0000000000..dacf5d2596
--- /dev/null
+++ b/tests/api_resources/chat/test_completions.py
@@ -0,0 +1,281 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai._client import OpenAI, AsyncOpenAI
+from openai.types.chat import ChatCompletion
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestCompletions:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create_overload_1(self, client: OpenAI) -> None:
+ completion = client.chat.completions.create(
+ messages=[
+ {
+ "content": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ )
+ assert_matches_type(ChatCompletion, completion, path=["response"])
+
+ @parametrize
+ def test_method_create_with_all_params_overload_1(self, client: OpenAI) -> None:
+ completion = client.chat.completions.create(
+ messages=[
+ {
+ "content": "string",
+ "function_call": {
+ "arguments": "string",
+ "name": "string",
+ },
+ "name": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ frequency_penalty=-2,
+ function_call="none",
+ functions=[
+ {
+ "description": "string",
+ "name": "string",
+ "parameters": {"foo": "bar"},
+ }
+ ],
+ logit_bias={"foo": 0},
+ max_tokens=0,
+ n=1,
+ presence_penalty=-2,
+ stop="string",
+ stream=False,
+ temperature=1,
+ top_p=1,
+ user="user-1234",
+ )
+ assert_matches_type(ChatCompletion, completion, path=["response"])
+
+ @parametrize
+ def test_raw_response_create_overload_1(self, client: OpenAI) -> None:
+ response = client.chat.completions.with_raw_response.create(
+ messages=[
+ {
+ "content": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ completion = response.parse()
+ assert_matches_type(ChatCompletion, completion, path=["response"])
+
+ @parametrize
+ def test_method_create_overload_2(self, client: OpenAI) -> None:
+ client.chat.completions.create(
+ messages=[
+ {
+ "content": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ stream=True,
+ )
+
+ @parametrize
+ def test_method_create_with_all_params_overload_2(self, client: OpenAI) -> None:
+ client.chat.completions.create(
+ messages=[
+ {
+ "content": "string",
+ "function_call": {
+ "arguments": "string",
+ "name": "string",
+ },
+ "name": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ stream=True,
+ frequency_penalty=-2,
+ function_call="none",
+ functions=[
+ {
+ "description": "string",
+ "name": "string",
+ "parameters": {"foo": "bar"},
+ }
+ ],
+ logit_bias={"foo": 0},
+ max_tokens=0,
+ n=1,
+ presence_penalty=-2,
+ stop="string",
+ temperature=1,
+ top_p=1,
+ user="user-1234",
+ )
+
+ @parametrize
+ def test_raw_response_create_overload_2(self, client: OpenAI) -> None:
+ response = client.chat.completions.with_raw_response.create(
+ messages=[
+ {
+ "content": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ stream=True,
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ response.parse()
+
+
+class TestAsyncCompletions:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create_overload_1(self, client: AsyncOpenAI) -> None:
+ completion = await client.chat.completions.create(
+ messages=[
+ {
+ "content": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ )
+ assert_matches_type(ChatCompletion, completion, path=["response"])
+
+ @parametrize
+ async def test_method_create_with_all_params_overload_1(self, client: AsyncOpenAI) -> None:
+ completion = await client.chat.completions.create(
+ messages=[
+ {
+ "content": "string",
+ "function_call": {
+ "arguments": "string",
+ "name": "string",
+ },
+ "name": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ frequency_penalty=-2,
+ function_call="none",
+ functions=[
+ {
+ "description": "string",
+ "name": "string",
+ "parameters": {"foo": "bar"},
+ }
+ ],
+ logit_bias={"foo": 0},
+ max_tokens=0,
+ n=1,
+ presence_penalty=-2,
+ stop="string",
+ stream=False,
+ temperature=1,
+ top_p=1,
+ user="user-1234",
+ )
+ assert_matches_type(ChatCompletion, completion, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create_overload_1(self, client: AsyncOpenAI) -> None:
+ response = await client.chat.completions.with_raw_response.create(
+ messages=[
+ {
+ "content": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ completion = response.parse()
+ assert_matches_type(ChatCompletion, completion, path=["response"])
+
+ @parametrize
+ async def test_method_create_overload_2(self, client: AsyncOpenAI) -> None:
+ await client.chat.completions.create(
+ messages=[
+ {
+ "content": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ stream=True,
+ )
+
+ @parametrize
+ async def test_method_create_with_all_params_overload_2(self, client: AsyncOpenAI) -> None:
+ await client.chat.completions.create(
+ messages=[
+ {
+ "content": "string",
+ "function_call": {
+ "arguments": "string",
+ "name": "string",
+ },
+ "name": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ stream=True,
+ frequency_penalty=-2,
+ function_call="none",
+ functions=[
+ {
+ "description": "string",
+ "name": "string",
+ "parameters": {"foo": "bar"},
+ }
+ ],
+ logit_bias={"foo": 0},
+ max_tokens=0,
+ n=1,
+ presence_penalty=-2,
+ stop="string",
+ temperature=1,
+ top_p=1,
+ user="user-1234",
+ )
+
+ @parametrize
+ async def test_raw_response_create_overload_2(self, client: AsyncOpenAI) -> None:
+ response = await client.chat.completions.with_raw_response.create(
+ messages=[
+ {
+ "content": "string",
+ "role": "system",
+ }
+ ],
+ model="gpt-3.5-turbo",
+ stream=True,
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ response.parse()
diff --git a/tests/api_resources/fine_tuning/__init__.py b/tests/api_resources/fine_tuning/__init__.py
new file mode 100644
index 0000000000..1016754ef3
--- /dev/null
+++ b/tests/api_resources/fine_tuning/__init__.py
@@ -0,0 +1 @@
+# File generated from our OpenAPI spec by Stainless.
diff --git a/tests/api_resources/fine_tuning/test_jobs.py b/tests/api_resources/fine_tuning/test_jobs.py
new file mode 100644
index 0000000000..9defcadab6
--- /dev/null
+++ b/tests/api_resources/fine_tuning/test_jobs.py
@@ -0,0 +1,240 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai._client import OpenAI, AsyncOpenAI
+from openai.pagination import SyncCursorPage, AsyncCursorPage
+from openai.types.fine_tuning import FineTuningJob, FineTuningJobEvent
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestJobs:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create(self, client: OpenAI) -> None:
+ job = client.fine_tuning.jobs.create(
+ model="gpt-3.5-turbo",
+ training_file="file-abc123",
+ )
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ def test_method_create_with_all_params(self, client: OpenAI) -> None:
+ job = client.fine_tuning.jobs.create(
+ model="gpt-3.5-turbo",
+ training_file="file-abc123",
+ hyperparameters={"n_epochs": "auto"},
+ suffix="x",
+ validation_file="file-abc123",
+ )
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ def test_raw_response_create(self, client: OpenAI) -> None:
+ response = client.fine_tuning.jobs.with_raw_response.create(
+ model="gpt-3.5-turbo",
+ training_file="file-abc123",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ job = response.parse()
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ def test_method_retrieve(self, client: OpenAI) -> None:
+ job = client.fine_tuning.jobs.retrieve(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ def test_raw_response_retrieve(self, client: OpenAI) -> None:
+ response = client.fine_tuning.jobs.with_raw_response.retrieve(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ job = response.parse()
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ def test_method_list(self, client: OpenAI) -> None:
+ job = client.fine_tuning.jobs.list()
+ assert_matches_type(SyncCursorPage[FineTuningJob], job, path=["response"])
+
+ @parametrize
+ def test_method_list_with_all_params(self, client: OpenAI) -> None:
+ job = client.fine_tuning.jobs.list(
+ after="string",
+ limit=0,
+ )
+ assert_matches_type(SyncCursorPage[FineTuningJob], job, path=["response"])
+
+ @parametrize
+ def test_raw_response_list(self, client: OpenAI) -> None:
+ response = client.fine_tuning.jobs.with_raw_response.list()
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ job = response.parse()
+ assert_matches_type(SyncCursorPage[FineTuningJob], job, path=["response"])
+
+ @parametrize
+ def test_method_cancel(self, client: OpenAI) -> None:
+ job = client.fine_tuning.jobs.cancel(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ def test_raw_response_cancel(self, client: OpenAI) -> None:
+ response = client.fine_tuning.jobs.with_raw_response.cancel(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ job = response.parse()
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ def test_method_list_events(self, client: OpenAI) -> None:
+ job = client.fine_tuning.jobs.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(SyncCursorPage[FineTuningJobEvent], job, path=["response"])
+
+ @parametrize
+ def test_method_list_events_with_all_params(self, client: OpenAI) -> None:
+ job = client.fine_tuning.jobs.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ after="string",
+ limit=0,
+ )
+ assert_matches_type(SyncCursorPage[FineTuningJobEvent], job, path=["response"])
+
+ @parametrize
+ def test_raw_response_list_events(self, client: OpenAI) -> None:
+ response = client.fine_tuning.jobs.with_raw_response.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ job = response.parse()
+ assert_matches_type(SyncCursorPage[FineTuningJobEvent], job, path=["response"])
+
+
+class TestAsyncJobs:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create(self, client: AsyncOpenAI) -> None:
+ job = await client.fine_tuning.jobs.create(
+ model="gpt-3.5-turbo",
+ training_file="file-abc123",
+ )
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ async def test_method_create_with_all_params(self, client: AsyncOpenAI) -> None:
+ job = await client.fine_tuning.jobs.create(
+ model="gpt-3.5-turbo",
+ training_file="file-abc123",
+ hyperparameters={"n_epochs": "auto"},
+ suffix="x",
+ validation_file="file-abc123",
+ )
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tuning.jobs.with_raw_response.create(
+ model="gpt-3.5-turbo",
+ training_file="file-abc123",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ job = response.parse()
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ async def test_method_retrieve(self, client: AsyncOpenAI) -> None:
+ job = await client.fine_tuning.jobs.retrieve(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ async def test_raw_response_retrieve(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tuning.jobs.with_raw_response.retrieve(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ job = response.parse()
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ async def test_method_list(self, client: AsyncOpenAI) -> None:
+ job = await client.fine_tuning.jobs.list()
+ assert_matches_type(AsyncCursorPage[FineTuningJob], job, path=["response"])
+
+ @parametrize
+ async def test_method_list_with_all_params(self, client: AsyncOpenAI) -> None:
+ job = await client.fine_tuning.jobs.list(
+ after="string",
+ limit=0,
+ )
+ assert_matches_type(AsyncCursorPage[FineTuningJob], job, path=["response"])
+
+ @parametrize
+ async def test_raw_response_list(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tuning.jobs.with_raw_response.list()
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ job = response.parse()
+ assert_matches_type(AsyncCursorPage[FineTuningJob], job, path=["response"])
+
+ @parametrize
+ async def test_method_cancel(self, client: AsyncOpenAI) -> None:
+ job = await client.fine_tuning.jobs.cancel(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ async def test_raw_response_cancel(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tuning.jobs.with_raw_response.cancel(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ job = response.parse()
+ assert_matches_type(FineTuningJob, job, path=["response"])
+
+ @parametrize
+ async def test_method_list_events(self, client: AsyncOpenAI) -> None:
+ job = await client.fine_tuning.jobs.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(AsyncCursorPage[FineTuningJobEvent], job, path=["response"])
+
+ @parametrize
+ async def test_method_list_events_with_all_params(self, client: AsyncOpenAI) -> None:
+ job = await client.fine_tuning.jobs.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ after="string",
+ limit=0,
+ )
+ assert_matches_type(AsyncCursorPage[FineTuningJobEvent], job, path=["response"])
+
+ @parametrize
+ async def test_raw_response_list_events(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tuning.jobs.with_raw_response.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ job = response.parse()
+ assert_matches_type(AsyncCursorPage[FineTuningJobEvent], job, path=["response"])
diff --git a/tests/api_resources/test_completions.py b/tests/api_resources/test_completions.py
new file mode 100644
index 0000000000..7b48e88ed2
--- /dev/null
+++ b/tests/api_resources/test_completions.py
@@ -0,0 +1,185 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai.types import Completion
+from openai._client import OpenAI, AsyncOpenAI
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestCompletions:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create_overload_1(self, client: OpenAI) -> None:
+ completion = client.completions.create(
+ model="string",
+ prompt="This is a test.",
+ )
+ assert_matches_type(Completion, completion, path=["response"])
+
+ @parametrize
+ def test_method_create_with_all_params_overload_1(self, client: OpenAI) -> None:
+ completion = client.completions.create(
+ model="string",
+ prompt="This is a test.",
+ best_of=0,
+ echo=True,
+ frequency_penalty=-2,
+ logit_bias={"foo": 0},
+ logprobs=0,
+ max_tokens=16,
+ n=1,
+ presence_penalty=-2,
+ stop="\n",
+ stream=False,
+ suffix="test.",
+ temperature=1,
+ top_p=1,
+ user="user-1234",
+ )
+ assert_matches_type(Completion, completion, path=["response"])
+
+ @parametrize
+ def test_raw_response_create_overload_1(self, client: OpenAI) -> None:
+ response = client.completions.with_raw_response.create(
+ model="string",
+ prompt="This is a test.",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ completion = response.parse()
+ assert_matches_type(Completion, completion, path=["response"])
+
+ @parametrize
+ def test_method_create_overload_2(self, client: OpenAI) -> None:
+ client.completions.create(
+ model="string",
+ prompt="This is a test.",
+ stream=True,
+ )
+
+ @parametrize
+ def test_method_create_with_all_params_overload_2(self, client: OpenAI) -> None:
+ client.completions.create(
+ model="string",
+ prompt="This is a test.",
+ stream=True,
+ best_of=0,
+ echo=True,
+ frequency_penalty=-2,
+ logit_bias={"foo": 0},
+ logprobs=0,
+ max_tokens=16,
+ n=1,
+ presence_penalty=-2,
+ stop="\n",
+ suffix="test.",
+ temperature=1,
+ top_p=1,
+ user="user-1234",
+ )
+
+ @parametrize
+ def test_raw_response_create_overload_2(self, client: OpenAI) -> None:
+ response = client.completions.with_raw_response.create(
+ model="string",
+ prompt="This is a test.",
+ stream=True,
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ response.parse()
+
+
+class TestAsyncCompletions:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create_overload_1(self, client: AsyncOpenAI) -> None:
+ completion = await client.completions.create(
+ model="string",
+ prompt="This is a test.",
+ )
+ assert_matches_type(Completion, completion, path=["response"])
+
+ @parametrize
+ async def test_method_create_with_all_params_overload_1(self, client: AsyncOpenAI) -> None:
+ completion = await client.completions.create(
+ model="string",
+ prompt="This is a test.",
+ best_of=0,
+ echo=True,
+ frequency_penalty=-2,
+ logit_bias={"foo": 0},
+ logprobs=0,
+ max_tokens=16,
+ n=1,
+ presence_penalty=-2,
+ stop="\n",
+ stream=False,
+ suffix="test.",
+ temperature=1,
+ top_p=1,
+ user="user-1234",
+ )
+ assert_matches_type(Completion, completion, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create_overload_1(self, client: AsyncOpenAI) -> None:
+ response = await client.completions.with_raw_response.create(
+ model="string",
+ prompt="This is a test.",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ completion = response.parse()
+ assert_matches_type(Completion, completion, path=["response"])
+
+ @parametrize
+ async def test_method_create_overload_2(self, client: AsyncOpenAI) -> None:
+ await client.completions.create(
+ model="string",
+ prompt="This is a test.",
+ stream=True,
+ )
+
+ @parametrize
+ async def test_method_create_with_all_params_overload_2(self, client: AsyncOpenAI) -> None:
+ await client.completions.create(
+ model="string",
+ prompt="This is a test.",
+ stream=True,
+ best_of=0,
+ echo=True,
+ frequency_penalty=-2,
+ logit_bias={"foo": 0},
+ logprobs=0,
+ max_tokens=16,
+ n=1,
+ presence_penalty=-2,
+ stop="\n",
+ suffix="test.",
+ temperature=1,
+ top_p=1,
+ user="user-1234",
+ )
+
+ @parametrize
+ async def test_raw_response_create_overload_2(self, client: AsyncOpenAI) -> None:
+ response = await client.completions.with_raw_response.create(
+ model="string",
+ prompt="This is a test.",
+ stream=True,
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ response.parse()
diff --git a/tests/api_resources/test_edits.py b/tests/api_resources/test_edits.py
new file mode 100644
index 0000000000..76069d6b83
--- /dev/null
+++ b/tests/api_resources/test_edits.py
@@ -0,0 +1,95 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai.types import Edit
+from openai._client import OpenAI, AsyncOpenAI
+
+# pyright: reportDeprecated=false
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestEdits:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create(self, client: OpenAI) -> None:
+ with pytest.warns(DeprecationWarning):
+ edit = client.edits.create(
+ instruction="Fix the spelling mistakes.",
+ model="text-davinci-edit-001",
+ )
+ assert_matches_type(Edit, edit, path=["response"])
+
+ @parametrize
+ def test_method_create_with_all_params(self, client: OpenAI) -> None:
+ with pytest.warns(DeprecationWarning):
+ edit = client.edits.create(
+ instruction="Fix the spelling mistakes.",
+ model="text-davinci-edit-001",
+ input="What day of the wek is it?",
+ n=1,
+ temperature=1,
+ top_p=1,
+ )
+ assert_matches_type(Edit, edit, path=["response"])
+
+ @parametrize
+ def test_raw_response_create(self, client: OpenAI) -> None:
+ with pytest.warns(DeprecationWarning):
+ response = client.edits.with_raw_response.create(
+ instruction="Fix the spelling mistakes.",
+ model="text-davinci-edit-001",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ edit = response.parse()
+ assert_matches_type(Edit, edit, path=["response"])
+
+
+class TestAsyncEdits:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create(self, client: AsyncOpenAI) -> None:
+ with pytest.warns(DeprecationWarning):
+ edit = await client.edits.create(
+ instruction="Fix the spelling mistakes.",
+ model="text-davinci-edit-001",
+ )
+ assert_matches_type(Edit, edit, path=["response"])
+
+ @parametrize
+ async def test_method_create_with_all_params(self, client: AsyncOpenAI) -> None:
+ with pytest.warns(DeprecationWarning):
+ edit = await client.edits.create(
+ instruction="Fix the spelling mistakes.",
+ model="text-davinci-edit-001",
+ input="What day of the wek is it?",
+ n=1,
+ temperature=1,
+ top_p=1,
+ )
+ assert_matches_type(Edit, edit, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create(self, client: AsyncOpenAI) -> None:
+ with pytest.warns(DeprecationWarning):
+ response = await client.edits.with_raw_response.create(
+ instruction="Fix the spelling mistakes.",
+ model="text-davinci-edit-001",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ edit = response.parse()
+ assert_matches_type(Edit, edit, path=["response"])
diff --git a/tests/api_resources/test_embeddings.py b/tests/api_resources/test_embeddings.py
new file mode 100644
index 0000000000..faf07ffb7c
--- /dev/null
+++ b/tests/api_resources/test_embeddings.py
@@ -0,0 +1,83 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai.types import CreateEmbeddingResponse
+from openai._client import OpenAI, AsyncOpenAI
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestEmbeddings:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create(self, client: OpenAI) -> None:
+ embedding = client.embeddings.create(
+ input="The quick brown fox jumped over the lazy dog",
+ model="text-embedding-ada-002",
+ )
+ assert_matches_type(CreateEmbeddingResponse, embedding, path=["response"])
+
+ @parametrize
+ def test_method_create_with_all_params(self, client: OpenAI) -> None:
+ embedding = client.embeddings.create(
+ input="The quick brown fox jumped over the lazy dog",
+ model="text-embedding-ada-002",
+ encoding_format="float",
+ user="user-1234",
+ )
+ assert_matches_type(CreateEmbeddingResponse, embedding, path=["response"])
+
+ @parametrize
+ def test_raw_response_create(self, client: OpenAI) -> None:
+ response = client.embeddings.with_raw_response.create(
+ input="The quick brown fox jumped over the lazy dog",
+ model="text-embedding-ada-002",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ embedding = response.parse()
+ assert_matches_type(CreateEmbeddingResponse, embedding, path=["response"])
+
+
+class TestAsyncEmbeddings:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create(self, client: AsyncOpenAI) -> None:
+ embedding = await client.embeddings.create(
+ input="The quick brown fox jumped over the lazy dog",
+ model="text-embedding-ada-002",
+ )
+ assert_matches_type(CreateEmbeddingResponse, embedding, path=["response"])
+
+ @parametrize
+ async def test_method_create_with_all_params(self, client: AsyncOpenAI) -> None:
+ embedding = await client.embeddings.create(
+ input="The quick brown fox jumped over the lazy dog",
+ model="text-embedding-ada-002",
+ encoding_format="float",
+ user="user-1234",
+ )
+ assert_matches_type(CreateEmbeddingResponse, embedding, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create(self, client: AsyncOpenAI) -> None:
+ response = await client.embeddings.with_raw_response.create(
+ input="The quick brown fox jumped over the lazy dog",
+ model="text-embedding-ada-002",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ embedding = response.parse()
+ assert_matches_type(CreateEmbeddingResponse, embedding, path=["response"])
diff --git a/tests/api_resources/test_files.py b/tests/api_resources/test_files.py
new file mode 100644
index 0000000000..389763586e
--- /dev/null
+++ b/tests/api_resources/test_files.py
@@ -0,0 +1,184 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai.types import FileObject, FileDeleted
+from openai._client import OpenAI, AsyncOpenAI
+from openai.pagination import SyncPage, AsyncPage
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestFiles:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create(self, client: OpenAI) -> None:
+ file = client.files.create(
+ file=b"raw file contents",
+ purpose="string",
+ )
+ assert_matches_type(FileObject, file, path=["response"])
+
+ @parametrize
+ def test_raw_response_create(self, client: OpenAI) -> None:
+ response = client.files.with_raw_response.create(
+ file=b"raw file contents",
+ purpose="string",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ file = response.parse()
+ assert_matches_type(FileObject, file, path=["response"])
+
+ @parametrize
+ def test_method_retrieve(self, client: OpenAI) -> None:
+ file = client.files.retrieve(
+ "string",
+ )
+ assert_matches_type(FileObject, file, path=["response"])
+
+ @parametrize
+ def test_raw_response_retrieve(self, client: OpenAI) -> None:
+ response = client.files.with_raw_response.retrieve(
+ "string",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ file = response.parse()
+ assert_matches_type(FileObject, file, path=["response"])
+
+ @parametrize
+ def test_method_list(self, client: OpenAI) -> None:
+ file = client.files.list()
+ assert_matches_type(SyncPage[FileObject], file, path=["response"])
+
+ @parametrize
+ def test_raw_response_list(self, client: OpenAI) -> None:
+ response = client.files.with_raw_response.list()
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ file = response.parse()
+ assert_matches_type(SyncPage[FileObject], file, path=["response"])
+
+ @parametrize
+ def test_method_delete(self, client: OpenAI) -> None:
+ file = client.files.delete(
+ "string",
+ )
+ assert_matches_type(FileDeleted, file, path=["response"])
+
+ @parametrize
+ def test_raw_response_delete(self, client: OpenAI) -> None:
+ response = client.files.with_raw_response.delete(
+ "string",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ file = response.parse()
+ assert_matches_type(FileDeleted, file, path=["response"])
+
+ @parametrize
+ def test_method_retrieve_content(self, client: OpenAI) -> None:
+ file = client.files.retrieve_content(
+ "string",
+ )
+ assert_matches_type(str, file, path=["response"])
+
+ @parametrize
+ def test_raw_response_retrieve_content(self, client: OpenAI) -> None:
+ response = client.files.with_raw_response.retrieve_content(
+ "string",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ file = response.parse()
+ assert_matches_type(str, file, path=["response"])
+
+
+class TestAsyncFiles:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create(self, client: AsyncOpenAI) -> None:
+ file = await client.files.create(
+ file=b"raw file contents",
+ purpose="string",
+ )
+ assert_matches_type(FileObject, file, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create(self, client: AsyncOpenAI) -> None:
+ response = await client.files.with_raw_response.create(
+ file=b"raw file contents",
+ purpose="string",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ file = response.parse()
+ assert_matches_type(FileObject, file, path=["response"])
+
+ @parametrize
+ async def test_method_retrieve(self, client: AsyncOpenAI) -> None:
+ file = await client.files.retrieve(
+ "string",
+ )
+ assert_matches_type(FileObject, file, path=["response"])
+
+ @parametrize
+ async def test_raw_response_retrieve(self, client: AsyncOpenAI) -> None:
+ response = await client.files.with_raw_response.retrieve(
+ "string",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ file = response.parse()
+ assert_matches_type(FileObject, file, path=["response"])
+
+ @parametrize
+ async def test_method_list(self, client: AsyncOpenAI) -> None:
+ file = await client.files.list()
+ assert_matches_type(AsyncPage[FileObject], file, path=["response"])
+
+ @parametrize
+ async def test_raw_response_list(self, client: AsyncOpenAI) -> None:
+ response = await client.files.with_raw_response.list()
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ file = response.parse()
+ assert_matches_type(AsyncPage[FileObject], file, path=["response"])
+
+ @parametrize
+ async def test_method_delete(self, client: AsyncOpenAI) -> None:
+ file = await client.files.delete(
+ "string",
+ )
+ assert_matches_type(FileDeleted, file, path=["response"])
+
+ @parametrize
+ async def test_raw_response_delete(self, client: AsyncOpenAI) -> None:
+ response = await client.files.with_raw_response.delete(
+ "string",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ file = response.parse()
+ assert_matches_type(FileDeleted, file, path=["response"])
+
+ @parametrize
+ async def test_method_retrieve_content(self, client: AsyncOpenAI) -> None:
+ file = await client.files.retrieve_content(
+ "string",
+ )
+ assert_matches_type(str, file, path=["response"])
+
+ @parametrize
+ async def test_raw_response_retrieve_content(self, client: AsyncOpenAI) -> None:
+ response = await client.files.with_raw_response.retrieve_content(
+ "string",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ file = response.parse()
+ assert_matches_type(str, file, path=["response"])
diff --git a/tests/api_resources/test_fine_tunes.py b/tests/api_resources/test_fine_tunes.py
new file mode 100644
index 0000000000..edaf784848
--- /dev/null
+++ b/tests/api_resources/test_fine_tunes.py
@@ -0,0 +1,274 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai.types import FineTune, FineTuneEventsListResponse
+from openai._client import OpenAI, AsyncOpenAI
+from openai.pagination import SyncPage, AsyncPage
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestFineTunes:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create(self, client: OpenAI) -> None:
+ fine_tune = client.fine_tunes.create(
+ training_file="file-abc123",
+ )
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ def test_method_create_with_all_params(self, client: OpenAI) -> None:
+ fine_tune = client.fine_tunes.create(
+ training_file="file-abc123",
+ batch_size=0,
+ classification_betas=[0.6, 1, 1.5, 2],
+ classification_n_classes=0,
+ classification_positive_class="string",
+ compute_classification_metrics=True,
+ hyperparameters={"n_epochs": "auto"},
+ learning_rate_multiplier=0,
+ model="curie",
+ prompt_loss_weight=0,
+ suffix="x",
+ validation_file="file-abc123",
+ )
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ def test_raw_response_create(self, client: OpenAI) -> None:
+ response = client.fine_tunes.with_raw_response.create(
+ training_file="file-abc123",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ fine_tune = response.parse()
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ def test_method_retrieve(self, client: OpenAI) -> None:
+ fine_tune = client.fine_tunes.retrieve(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ def test_raw_response_retrieve(self, client: OpenAI) -> None:
+ response = client.fine_tunes.with_raw_response.retrieve(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ fine_tune = response.parse()
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ def test_method_list(self, client: OpenAI) -> None:
+ fine_tune = client.fine_tunes.list()
+ assert_matches_type(SyncPage[FineTune], fine_tune, path=["response"])
+
+ @parametrize
+ def test_raw_response_list(self, client: OpenAI) -> None:
+ response = client.fine_tunes.with_raw_response.list()
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ fine_tune = response.parse()
+ assert_matches_type(SyncPage[FineTune], fine_tune, path=["response"])
+
+ @parametrize
+ def test_method_cancel(self, client: OpenAI) -> None:
+ fine_tune = client.fine_tunes.cancel(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ def test_raw_response_cancel(self, client: OpenAI) -> None:
+ response = client.fine_tunes.with_raw_response.cancel(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ fine_tune = response.parse()
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @pytest.mark.skip(reason="Prism chokes on this")
+ @parametrize
+ def test_method_list_events_overload_1(self, client: OpenAI) -> None:
+ fine_tune = client.fine_tunes.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(FineTuneEventsListResponse, fine_tune, path=["response"])
+
+ @pytest.mark.skip(reason="Prism chokes on this")
+ @parametrize
+ def test_method_list_events_with_all_params_overload_1(self, client: OpenAI) -> None:
+ fine_tune = client.fine_tunes.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ stream=False,
+ )
+ assert_matches_type(FineTuneEventsListResponse, fine_tune, path=["response"])
+
+ @pytest.mark.skip(reason="Prism chokes on this")
+ @parametrize
+ def test_raw_response_list_events_overload_1(self, client: OpenAI) -> None:
+ response = client.fine_tunes.with_raw_response.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ fine_tune = response.parse()
+ assert_matches_type(FineTuneEventsListResponse, fine_tune, path=["response"])
+
+ @pytest.mark.skip(reason="Prism chokes on this")
+ @parametrize
+ def test_method_list_events_overload_2(self, client: OpenAI) -> None:
+ client.fine_tunes.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ stream=True,
+ )
+
+ @pytest.mark.skip(reason="Prism chokes on this")
+ @parametrize
+ def test_raw_response_list_events_overload_2(self, client: OpenAI) -> None:
+ response = client.fine_tunes.with_raw_response.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ stream=True,
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ response.parse()
+
+
+class TestAsyncFineTunes:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create(self, client: AsyncOpenAI) -> None:
+ fine_tune = await client.fine_tunes.create(
+ training_file="file-abc123",
+ )
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ async def test_method_create_with_all_params(self, client: AsyncOpenAI) -> None:
+ fine_tune = await client.fine_tunes.create(
+ training_file="file-abc123",
+ batch_size=0,
+ classification_betas=[0.6, 1, 1.5, 2],
+ classification_n_classes=0,
+ classification_positive_class="string",
+ compute_classification_metrics=True,
+ hyperparameters={"n_epochs": "auto"},
+ learning_rate_multiplier=0,
+ model="curie",
+ prompt_loss_weight=0,
+ suffix="x",
+ validation_file="file-abc123",
+ )
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tunes.with_raw_response.create(
+ training_file="file-abc123",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ fine_tune = response.parse()
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ async def test_method_retrieve(self, client: AsyncOpenAI) -> None:
+ fine_tune = await client.fine_tunes.retrieve(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ async def test_raw_response_retrieve(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tunes.with_raw_response.retrieve(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ fine_tune = response.parse()
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ async def test_method_list(self, client: AsyncOpenAI) -> None:
+ fine_tune = await client.fine_tunes.list()
+ assert_matches_type(AsyncPage[FineTune], fine_tune, path=["response"])
+
+ @parametrize
+ async def test_raw_response_list(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tunes.with_raw_response.list()
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ fine_tune = response.parse()
+ assert_matches_type(AsyncPage[FineTune], fine_tune, path=["response"])
+
+ @parametrize
+ async def test_method_cancel(self, client: AsyncOpenAI) -> None:
+ fine_tune = await client.fine_tunes.cancel(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @parametrize
+ async def test_raw_response_cancel(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tunes.with_raw_response.cancel(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ fine_tune = response.parse()
+ assert_matches_type(FineTune, fine_tune, path=["response"])
+
+ @pytest.mark.skip(reason="Prism chokes on this")
+ @parametrize
+ async def test_method_list_events_overload_1(self, client: AsyncOpenAI) -> None:
+ fine_tune = await client.fine_tunes.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert_matches_type(FineTuneEventsListResponse, fine_tune, path=["response"])
+
+ @pytest.mark.skip(reason="Prism chokes on this")
+ @parametrize
+ async def test_method_list_events_with_all_params_overload_1(self, client: AsyncOpenAI) -> None:
+ fine_tune = await client.fine_tunes.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ stream=False,
+ )
+ assert_matches_type(FineTuneEventsListResponse, fine_tune, path=["response"])
+
+ @pytest.mark.skip(reason="Prism chokes on this")
+ @parametrize
+ async def test_raw_response_list_events_overload_1(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tunes.with_raw_response.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ fine_tune = response.parse()
+ assert_matches_type(FineTuneEventsListResponse, fine_tune, path=["response"])
+
+ @pytest.mark.skip(reason="Prism chokes on this")
+ @parametrize
+ async def test_method_list_events_overload_2(self, client: AsyncOpenAI) -> None:
+ await client.fine_tunes.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ stream=True,
+ )
+
+ @pytest.mark.skip(reason="Prism chokes on this")
+ @parametrize
+ async def test_raw_response_list_events_overload_2(self, client: AsyncOpenAI) -> None:
+ response = await client.fine_tunes.with_raw_response.list_events(
+ "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
+ stream=True,
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ response.parse()
diff --git a/tests/api_resources/test_images.py b/tests/api_resources/test_images.py
new file mode 100644
index 0000000000..fa7fb6d533
--- /dev/null
+++ b/tests/api_resources/test_images.py
@@ -0,0 +1,197 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai.types import ImagesResponse
+from openai._client import OpenAI, AsyncOpenAI
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestImages:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create_variation(self, client: OpenAI) -> None:
+ image = client.images.create_variation(
+ image=b"raw file contents",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ def test_method_create_variation_with_all_params(self, client: OpenAI) -> None:
+ image = client.images.create_variation(
+ image=b"raw file contents",
+ n=1,
+ response_format="url",
+ size="1024x1024",
+ user="user-1234",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ def test_raw_response_create_variation(self, client: OpenAI) -> None:
+ response = client.images.with_raw_response.create_variation(
+ image=b"raw file contents",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ image = response.parse()
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ def test_method_edit(self, client: OpenAI) -> None:
+ image = client.images.edit(
+ image=b"raw file contents",
+ prompt="A cute baby sea otter wearing a beret",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ def test_method_edit_with_all_params(self, client: OpenAI) -> None:
+ image = client.images.edit(
+ image=b"raw file contents",
+ prompt="A cute baby sea otter wearing a beret",
+ mask=b"raw file contents",
+ n=1,
+ response_format="url",
+ size="1024x1024",
+ user="user-1234",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ def test_raw_response_edit(self, client: OpenAI) -> None:
+ response = client.images.with_raw_response.edit(
+ image=b"raw file contents",
+ prompt="A cute baby sea otter wearing a beret",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ image = response.parse()
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ def test_method_generate(self, client: OpenAI) -> None:
+ image = client.images.generate(
+ prompt="A cute baby sea otter",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ def test_method_generate_with_all_params(self, client: OpenAI) -> None:
+ image = client.images.generate(
+ prompt="A cute baby sea otter",
+ n=1,
+ response_format="url",
+ size="1024x1024",
+ user="user-1234",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ def test_raw_response_generate(self, client: OpenAI) -> None:
+ response = client.images.with_raw_response.generate(
+ prompt="A cute baby sea otter",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ image = response.parse()
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+
+class TestAsyncImages:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create_variation(self, client: AsyncOpenAI) -> None:
+ image = await client.images.create_variation(
+ image=b"raw file contents",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ async def test_method_create_variation_with_all_params(self, client: AsyncOpenAI) -> None:
+ image = await client.images.create_variation(
+ image=b"raw file contents",
+ n=1,
+ response_format="url",
+ size="1024x1024",
+ user="user-1234",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create_variation(self, client: AsyncOpenAI) -> None:
+ response = await client.images.with_raw_response.create_variation(
+ image=b"raw file contents",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ image = response.parse()
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ async def test_method_edit(self, client: AsyncOpenAI) -> None:
+ image = await client.images.edit(
+ image=b"raw file contents",
+ prompt="A cute baby sea otter wearing a beret",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ async def test_method_edit_with_all_params(self, client: AsyncOpenAI) -> None:
+ image = await client.images.edit(
+ image=b"raw file contents",
+ prompt="A cute baby sea otter wearing a beret",
+ mask=b"raw file contents",
+ n=1,
+ response_format="url",
+ size="1024x1024",
+ user="user-1234",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ async def test_raw_response_edit(self, client: AsyncOpenAI) -> None:
+ response = await client.images.with_raw_response.edit(
+ image=b"raw file contents",
+ prompt="A cute baby sea otter wearing a beret",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ image = response.parse()
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ async def test_method_generate(self, client: AsyncOpenAI) -> None:
+ image = await client.images.generate(
+ prompt="A cute baby sea otter",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ async def test_method_generate_with_all_params(self, client: AsyncOpenAI) -> None:
+ image = await client.images.generate(
+ prompt="A cute baby sea otter",
+ n=1,
+ response_format="url",
+ size="1024x1024",
+ user="user-1234",
+ )
+ assert_matches_type(ImagesResponse, image, path=["response"])
+
+ @parametrize
+ async def test_raw_response_generate(self, client: AsyncOpenAI) -> None:
+ response = await client.images.with_raw_response.generate(
+ prompt="A cute baby sea otter",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ image = response.parse()
+ assert_matches_type(ImagesResponse, image, path=["response"])
diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py
new file mode 100644
index 0000000000..3998809610
--- /dev/null
+++ b/tests/api_resources/test_models.py
@@ -0,0 +1,116 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai.types import Model, ModelDeleted
+from openai._client import OpenAI, AsyncOpenAI
+from openai.pagination import SyncPage, AsyncPage
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestModels:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_retrieve(self, client: OpenAI) -> None:
+ model = client.models.retrieve(
+ "gpt-3.5-turbo",
+ )
+ assert_matches_type(Model, model, path=["response"])
+
+ @parametrize
+ def test_raw_response_retrieve(self, client: OpenAI) -> None:
+ response = client.models.with_raw_response.retrieve(
+ "gpt-3.5-turbo",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert_matches_type(Model, model, path=["response"])
+
+ @parametrize
+ def test_method_list(self, client: OpenAI) -> None:
+ model = client.models.list()
+ assert_matches_type(SyncPage[Model], model, path=["response"])
+
+ @parametrize
+ def test_raw_response_list(self, client: OpenAI) -> None:
+ response = client.models.with_raw_response.list()
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert_matches_type(SyncPage[Model], model, path=["response"])
+
+ @parametrize
+ def test_method_delete(self, client: OpenAI) -> None:
+ model = client.models.delete(
+ "ft:gpt-3.5-turbo:acemeco:suffix:abc123",
+ )
+ assert_matches_type(ModelDeleted, model, path=["response"])
+
+ @parametrize
+ def test_raw_response_delete(self, client: OpenAI) -> None:
+ response = client.models.with_raw_response.delete(
+ "ft:gpt-3.5-turbo:acemeco:suffix:abc123",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert_matches_type(ModelDeleted, model, path=["response"])
+
+
+class TestAsyncModels:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_retrieve(self, client: AsyncOpenAI) -> None:
+ model = await client.models.retrieve(
+ "gpt-3.5-turbo",
+ )
+ assert_matches_type(Model, model, path=["response"])
+
+ @parametrize
+ async def test_raw_response_retrieve(self, client: AsyncOpenAI) -> None:
+ response = await client.models.with_raw_response.retrieve(
+ "gpt-3.5-turbo",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert_matches_type(Model, model, path=["response"])
+
+ @parametrize
+ async def test_method_list(self, client: AsyncOpenAI) -> None:
+ model = await client.models.list()
+ assert_matches_type(AsyncPage[Model], model, path=["response"])
+
+ @parametrize
+ async def test_raw_response_list(self, client: AsyncOpenAI) -> None:
+ response = await client.models.with_raw_response.list()
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert_matches_type(AsyncPage[Model], model, path=["response"])
+
+ @parametrize
+ async def test_method_delete(self, client: AsyncOpenAI) -> None:
+ model = await client.models.delete(
+ "ft:gpt-3.5-turbo:acemeco:suffix:abc123",
+ )
+ assert_matches_type(ModelDeleted, model, path=["response"])
+
+ @parametrize
+ async def test_raw_response_delete(self, client: AsyncOpenAI) -> None:
+ response = await client.models.with_raw_response.delete(
+ "ft:gpt-3.5-turbo:acemeco:suffix:abc123",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert_matches_type(ModelDeleted, model, path=["response"])
diff --git a/tests/api_resources/test_moderations.py b/tests/api_resources/test_moderations.py
new file mode 100644
index 0000000000..502030d614
--- /dev/null
+++ b/tests/api_resources/test_moderations.py
@@ -0,0 +1,75 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from tests.utils import assert_matches_type
+from openai.types import ModerationCreateResponse
+from openai._client import OpenAI, AsyncOpenAI
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+class TestModerations:
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ def test_method_create(self, client: OpenAI) -> None:
+ moderation = client.moderations.create(
+ input="I want to kill them.",
+ )
+ assert_matches_type(ModerationCreateResponse, moderation, path=["response"])
+
+ @parametrize
+ def test_method_create_with_all_params(self, client: OpenAI) -> None:
+ moderation = client.moderations.create(
+ input="I want to kill them.",
+ model="text-moderation-stable",
+ )
+ assert_matches_type(ModerationCreateResponse, moderation, path=["response"])
+
+ @parametrize
+ def test_raw_response_create(self, client: OpenAI) -> None:
+ response = client.moderations.with_raw_response.create(
+ input="I want to kill them.",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ moderation = response.parse()
+ assert_matches_type(ModerationCreateResponse, moderation, path=["response"])
+
+
+class TestAsyncModerations:
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ loose_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+
+ @parametrize
+ async def test_method_create(self, client: AsyncOpenAI) -> None:
+ moderation = await client.moderations.create(
+ input="I want to kill them.",
+ )
+ assert_matches_type(ModerationCreateResponse, moderation, path=["response"])
+
+ @parametrize
+ async def test_method_create_with_all_params(self, client: AsyncOpenAI) -> None:
+ moderation = await client.moderations.create(
+ input="I want to kill them.",
+ model="text-moderation-stable",
+ )
+ assert_matches_type(ModerationCreateResponse, moderation, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create(self, client: AsyncOpenAI) -> None:
+ response = await client.moderations.with_raw_response.create(
+ input="I want to kill them.",
+ )
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ moderation = response.parse()
+ assert_matches_type(ModerationCreateResponse, moderation, path=["response"])
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000000..c3a1efe9df
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,16 @@
+import asyncio
+import logging
+from typing import Iterator
+
+import pytest
+
+pytest.register_assert_rewrite("tests.utils")
+
+logging.getLogger("openai").setLevel(logging.DEBUG)
+
+
+@pytest.fixture(scope="session")
+def event_loop() -> Iterator[asyncio.AbstractEventLoop]:
+ loop = asyncio.new_event_loop()
+ yield loop
+ loop.close()
diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py
new file mode 100644
index 0000000000..b0bd87571b
--- /dev/null
+++ b/tests/lib/test_azure.py
@@ -0,0 +1,36 @@
+from typing import Union
+
+import pytest
+
+from openai._models import FinalRequestOptions
+from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
+
+Client = Union[AzureOpenAI, AsyncAzureOpenAI]
+
+
+sync_client = AzureOpenAI(
+ api_version="2023-07-01",
+ api_key="example API key",
+ azure_endpoint="https://example-resource.azure.openai.com",
+)
+
+async_client = AsyncAzureOpenAI(
+ api_version="2023-07-01",
+ api_key="example API key",
+ azure_endpoint="https://example-resource.azure.openai.com",
+)
+
+
+@pytest.mark.parametrize("client", [sync_client, async_client])
+def test_implicit_deployment_path(client: Client) -> None:
+ req = client._build_request(
+ FinalRequestOptions.construct(
+ method="post",
+ url="/chat/completions",
+ json_data={"model": "my-deployment-model"},
+ )
+ )
+ assert (
+ req.url
+ == "https://example-resource.azure.openai.com/openai/deployments/my-deployment-model/chat/completions?api-version=2023-07-01"
+ )
diff --git a/tests/test_client.py b/tests/test_client.py
new file mode 100644
index 0000000000..3b70594ecd
--- /dev/null
+++ b/tests/test_client.py
@@ -0,0 +1,1110 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+import json
+import asyncio
+import inspect
+from typing import Any, Dict, Union, cast
+from unittest import mock
+
+import httpx
+import pytest
+from respx import MockRouter
+from pydantic import ValidationError
+
+from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
+from openai._client import OpenAI, AsyncOpenAI
+from openai._models import BaseModel, FinalRequestOptions
+from openai._streaming import Stream, AsyncStream
+from openai._exceptions import APIResponseValidationError
+from openai._base_client import (
+ DEFAULT_TIMEOUT,
+ HTTPX_DEFAULT_TIMEOUT,
+ BaseClient,
+ make_request_options,
+)
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+api_key = "My API Key"
+
+
+def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ url = httpx.URL(request.url)
+ return dict(url.params)
+
+
+class TestOpenAI:
+ client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+
+ @pytest.mark.respx(base_url=base_url)
+ def test_raw_response(self, respx_mock: MockRouter) -> None:
+ respx_mock.post("/foo").mock(return_value=httpx.Response(200, json='{"foo": "bar"}'))
+
+ response = self.client.post("/foo", cast_to=httpx.Response)
+ assert response.status_code == 200
+ assert isinstance(response, httpx.Response)
+ assert response.json() == '{"foo": "bar"}'
+
+ @pytest.mark.respx(base_url=base_url)
+ def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
+ respx_mock.post("/foo").mock(
+ return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
+ )
+
+ response = self.client.post("/foo", cast_to=httpx.Response)
+ assert response.status_code == 200
+ assert isinstance(response, httpx.Response)
+ assert response.json() == '{"foo": "bar"}'
+
+ def test_copy(self) -> None:
+ copied = self.client.copy()
+ assert id(copied) != id(self.client)
+
+ copied = self.client.copy(api_key="another My API Key")
+ assert copied.api_key == "another My API Key"
+ assert self.client.api_key == "My API Key"
+
+ def test_copy_default_options(self) -> None:
+ # options that have a default are overridden correctly
+ copied = self.client.copy(max_retries=7)
+ assert copied.max_retries == 7
+ assert self.client.max_retries == 2
+
+ copied2 = copied.copy(max_retries=6)
+ assert copied2.max_retries == 6
+ assert copied.max_retries == 7
+
+ # timeout
+ assert isinstance(self.client.timeout, httpx.Timeout)
+ copied = self.client.copy(timeout=None)
+ assert copied.timeout is None
+ assert isinstance(self.client.timeout, httpx.Timeout)
+
+ def test_copy_default_headers(self) -> None:
+ client = OpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
+ )
+ assert client.default_headers["X-Foo"] == "bar"
+
+ # does not override the already given value when not specified
+ copied = client.copy()
+ assert copied.default_headers["X-Foo"] == "bar"
+
+ # merges already given headers
+ copied = client.copy(default_headers={"X-Bar": "stainless"})
+ assert copied.default_headers["X-Foo"] == "bar"
+ assert copied.default_headers["X-Bar"] == "stainless"
+
+ # uses new values for any already given headers
+ copied = client.copy(default_headers={"X-Foo": "stainless"})
+ assert copied.default_headers["X-Foo"] == "stainless"
+
+ # set_default_headers
+
+ # completely overrides already set values
+ copied = client.copy(set_default_headers={})
+ assert copied.default_headers.get("X-Foo") is None
+
+ copied = client.copy(set_default_headers={"X-Bar": "Robert"})
+ assert copied.default_headers["X-Bar"] == "Robert"
+
+ with pytest.raises(
+ ValueError,
+ match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
+ ):
+ client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
+
+ def test_copy_default_query(self) -> None:
+ client = OpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
+ )
+ assert _get_params(client)["foo"] == "bar"
+
+ # does not override the already given value when not specified
+ copied = client.copy()
+ assert _get_params(copied)["foo"] == "bar"
+
+ # merges already given params
+ copied = client.copy(default_query={"bar": "stainless"})
+ params = _get_params(copied)
+ assert params["foo"] == "bar"
+ assert params["bar"] == "stainless"
+
+ # uses new values for any already given headers
+ copied = client.copy(default_query={"foo": "stainless"})
+ assert _get_params(copied)["foo"] == "stainless"
+
+ # set_default_query
+
+ # completely overrides already set values
+ copied = client.copy(set_default_query={})
+ assert _get_params(copied) == {}
+
+ copied = client.copy(set_default_query={"bar": "Robert"})
+ assert _get_params(copied)["bar"] == "Robert"
+
+ with pytest.raises(
+ ValueError,
+ # TODO: update
+ match="`default_query` and `set_default_query` arguments are mutually exclusive",
+ ):
+ client.copy(set_default_query={}, default_query={"foo": "Bar"})
+
+ def test_copy_signature(self) -> None:
+ # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
+ init_signature = inspect.signature(
+ # mypy doesn't like that we access the `__init__` property.
+ self.client.__init__, # type: ignore[misc]
+ )
+ copy_signature = inspect.signature(self.client.copy)
+ exclude_params = {"transport", "proxies", "_strict_response_validation"}
+
+ for name in init_signature.parameters.keys():
+ if name in exclude_params:
+ continue
+
+ copy_param = copy_signature.parameters.get(name)
+ assert copy_param is not None, f"copy() signature is missing the {name} param"
+
+ def test_request_timeout(self) -> None:
+ request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == DEFAULT_TIMEOUT
+
+ request = self.client._build_request(
+ FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
+ )
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == httpx.Timeout(100.0)
+
+ def test_client_timeout_option(self) -> None:
+ client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))
+
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == httpx.Timeout(0)
+
+ def test_http_client_timeout_option(self) -> None:
+ # custom timeout given to the httpx client should be used
+ with httpx.Client(timeout=None) as http_client:
+ client = OpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
+ )
+
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == httpx.Timeout(None)
+
+ # no timeout given to the httpx client should not use the httpx default
+ with httpx.Client() as http_client:
+ client = OpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
+ )
+
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == DEFAULT_TIMEOUT
+
+ # explicitly passing the default timeout currently results in it being ignored
+ with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
+ client = OpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
+ )
+
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == DEFAULT_TIMEOUT # our default
+
+ def test_default_headers_option(self) -> None:
+ client = OpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
+ )
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ assert request.headers.get("x-foo") == "bar"
+ assert request.headers.get("x-stainless-lang") == "python"
+
+ client2 = OpenAI(
+ base_url=base_url,
+ api_key=api_key,
+ _strict_response_validation=True,
+ default_headers={
+ "X-Foo": "stainless",
+ "X-Stainless-Lang": "my-overriding-header",
+ },
+ )
+ request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
+ assert request.headers.get("x-foo") == "stainless"
+ assert request.headers.get("x-stainless-lang") == "my-overriding-header"
+
+ def test_validate_headers(self) -> None:
+ client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ assert request.headers.get("Authorization") == f"Bearer {api_key}"
+
+ with pytest.raises(Exception):
+ client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
+ _ = client2
+
+ def test_default_query_option(self) -> None:
+ client = OpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
+ )
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ url = httpx.URL(request.url)
+ assert dict(url.params) == {"query_param": "bar"}
+
+ request = client._build_request(
+ FinalRequestOptions(
+ method="get",
+ url="/foo",
+ params={"foo": "baz", "query_param": "overriden"},
+ )
+ )
+ url = httpx.URL(request.url)
+ assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
+
+ def test_request_extra_json(self) -> None:
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ json_data={"foo": "bar"},
+ extra_json={"baz": False},
+ ),
+ )
+ data = json.loads(request.content.decode("utf-8"))
+ assert data == {"foo": "bar", "baz": False}
+
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ extra_json={"baz": False},
+ ),
+ )
+ data = json.loads(request.content.decode("utf-8"))
+ assert data == {"baz": False}
+
+ # `extra_json` takes priority over `json_data` when keys clash
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ json_data={"foo": "bar", "baz": True},
+ extra_json={"baz": None},
+ ),
+ )
+ data = json.loads(request.content.decode("utf-8"))
+ assert data == {"foo": "bar", "baz": None}
+
+ def test_request_extra_headers(self) -> None:
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ **make_request_options(extra_headers={"X-Foo": "Foo"}),
+ ),
+ )
+ assert request.headers.get("X-Foo") == "Foo"
+
+ # `extra_headers` takes priority over `default_headers` when keys clash
+ request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ **make_request_options(
+ extra_headers={"X-Bar": "false"},
+ ),
+ ),
+ )
+ assert request.headers.get("X-Bar") == "false"
+
+ def test_request_extra_query(self) -> None:
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ **make_request_options(
+ extra_query={"my_query_param": "Foo"},
+ ),
+ ),
+ )
+ params = cast(Dict[str, str], dict(request.url.params))
+ assert params == {"my_query_param": "Foo"}
+
+ # if both `query` and `extra_query` are given, they are merged
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ **make_request_options(
+ query={"bar": "1"},
+ extra_query={"foo": "2"},
+ ),
+ ),
+ )
+ params = cast(Dict[str, str], dict(request.url.params))
+ assert params == {"bar": "1", "foo": "2"}
+
+ # `extra_query` takes priority over `query` when keys clash
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ **make_request_options(
+ query={"foo": "1"},
+ extra_query={"foo": "2"},
+ ),
+ ),
+ )
+ params = cast(Dict[str, str], dict(request.url.params))
+ assert params == {"foo": "2"}
+
+ @pytest.mark.respx(base_url=base_url)
+ def test_basic_union_response(self, respx_mock: MockRouter) -> None:
+ class Model1(BaseModel):
+ name: str
+
+ class Model2(BaseModel):
+ foo: str
+
+ respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
+
+ response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ assert isinstance(response, Model2)
+ assert response.foo == "bar"
+
+ @pytest.mark.respx(base_url=base_url)
+ def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
+ """Union of objects with the same field name using a different type"""
+
+ class Model1(BaseModel):
+ foo: int
+
+ class Model2(BaseModel):
+ foo: str
+
+ respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
+
+ response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ assert isinstance(response, Model2)
+ assert response.foo == "bar"
+
+ respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
+
+ response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ assert isinstance(response, Model1)
+ assert response.foo == 1
+
+ @pytest.mark.parametrize(
+ "client",
+ [
+ OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
+ OpenAI(
+ base_url="http://localhost:5000/custom/path/",
+ api_key=api_key,
+ _strict_response_validation=True,
+ http_client=httpx.Client(),
+ ),
+ ],
+ ids=["standard", "custom http client"],
+ )
+ def test_base_url_trailing_slash(self, client: OpenAI) -> None:
+ request = client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ json_data={"foo": "bar"},
+ ),
+ )
+ assert request.url == "http://localhost:5000/custom/path/foo"
+
+ @pytest.mark.parametrize(
+ "client",
+ [
+ OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
+ OpenAI(
+ base_url="http://localhost:5000/custom/path/",
+ api_key=api_key,
+ _strict_response_validation=True,
+ http_client=httpx.Client(),
+ ),
+ ],
+ ids=["standard", "custom http client"],
+ )
+ def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
+ request = client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ json_data={"foo": "bar"},
+ ),
+ )
+ assert request.url == "http://localhost:5000/custom/path/foo"
+
+ @pytest.mark.parametrize(
+ "client",
+ [
+ OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
+ OpenAI(
+ base_url="http://localhost:5000/custom/path/",
+ api_key=api_key,
+ _strict_response_validation=True,
+ http_client=httpx.Client(),
+ ),
+ ],
+ ids=["standard", "custom http client"],
+ )
+ def test_absolute_request_url(self, client: OpenAI) -> None:
+ request = client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="https://myapi.com/foo",
+ json_data={"foo": "bar"},
+ ),
+ )
+ assert request.url == "https://myapi.com/foo"
+
+ def test_client_del(self) -> None:
+ client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not client.is_closed()
+
+ client.__del__()
+
+ assert client.is_closed()
+
+ def test_copied_client_does_not_close_http(self) -> None:
+ client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not client.is_closed()
+
+ copied = client.copy()
+ assert copied is not client
+
+ copied.__del__()
+
+ assert not copied.is_closed()
+ assert not client.is_closed()
+
+ def test_client_context_manager(self) -> None:
+ client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ with client as c2:
+ assert c2 is client
+ assert not c2.is_closed()
+ assert not client.is_closed()
+ assert client.is_closed()
+
+ @pytest.mark.respx(base_url=base_url)
+ def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
+ class Model(BaseModel):
+ foo: str
+
+ respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
+
+ with pytest.raises(APIResponseValidationError) as exc:
+ self.client.get("/foo", cast_to=Model)
+
+ assert isinstance(exc.value.__cause__, ValidationError)
+
+ @pytest.mark.respx(base_url=base_url)
+ def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
+ class Model(BaseModel):
+ name: str
+
+ respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
+
+ response = self.client.post("/foo", cast_to=Model, stream=True)
+ assert isinstance(response, Stream)
+
+ @pytest.mark.respx(base_url=base_url)
+ def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
+ class Model(BaseModel):
+ name: str
+
+ respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
+
+ strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+
+ with pytest.raises(APIResponseValidationError):
+ strict_client.get("/foo", cast_to=Model)
+
+ client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+
+ response = client.get("/foo", cast_to=Model)
+ assert isinstance(response, str) # type: ignore[unreachable]
+
+ @pytest.mark.parametrize(
+ "remaining_retries,retry_after,timeout",
+ [
+ [3, "20", 20],
+ [3, "0", 0.5],
+ [3, "-10", 0.5],
+ [3, "60", 60],
+ [3, "61", 0.5],
+ [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
+ [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
+ [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
+ [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
+ [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
+ [3, "99999999999999999999999999999999999", 0.5],
+ [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
+ [3, "", 0.5],
+ [2, "", 0.5 * 2.0],
+ [1, "", 0.5 * 4.0],
+ ],
+ )
+ @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
+ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
+ client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+
+ headers = httpx.Headers({"retry-after": retry_after})
+ options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
+ calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
+ assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
+
+
+class TestAsyncOpenAI:
+ client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+
+ @pytest.mark.respx(base_url=base_url)
+ @pytest.mark.asyncio
+ async def test_raw_response(self, respx_mock: MockRouter) -> None:
+ respx_mock.post("/foo").mock(return_value=httpx.Response(200, json='{"foo": "bar"}'))
+
+ response = await self.client.post("/foo", cast_to=httpx.Response)
+ assert response.status_code == 200
+ assert isinstance(response, httpx.Response)
+ assert response.json() == '{"foo": "bar"}'
+
+ @pytest.mark.respx(base_url=base_url)
+ @pytest.mark.asyncio
+ async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
+ respx_mock.post("/foo").mock(
+ return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
+ )
+
+ response = await self.client.post("/foo", cast_to=httpx.Response)
+ assert response.status_code == 200
+ assert isinstance(response, httpx.Response)
+ assert response.json() == '{"foo": "bar"}'
+
+ def test_copy(self) -> None:
+ copied = self.client.copy()
+ assert id(copied) != id(self.client)
+
+ copied = self.client.copy(api_key="another My API Key")
+ assert copied.api_key == "another My API Key"
+ assert self.client.api_key == "My API Key"
+
+ def test_copy_default_options(self) -> None:
+ # options that have a default are overridden correctly
+ copied = self.client.copy(max_retries=7)
+ assert copied.max_retries == 7
+ assert self.client.max_retries == 2
+
+ copied2 = copied.copy(max_retries=6)
+ assert copied2.max_retries == 6
+ assert copied.max_retries == 7
+
+ # timeout
+ assert isinstance(self.client.timeout, httpx.Timeout)
+ copied = self.client.copy(timeout=None)
+ assert copied.timeout is None
+ assert isinstance(self.client.timeout, httpx.Timeout)
+
+ def test_copy_default_headers(self) -> None:
+ client = AsyncOpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
+ )
+ assert client.default_headers["X-Foo"] == "bar"
+
+ # does not override the already given value when not specified
+ copied = client.copy()
+ assert copied.default_headers["X-Foo"] == "bar"
+
+ # merges already given headers
+ copied = client.copy(default_headers={"X-Bar": "stainless"})
+ assert copied.default_headers["X-Foo"] == "bar"
+ assert copied.default_headers["X-Bar"] == "stainless"
+
+ # uses new values for any already given headers
+ copied = client.copy(default_headers={"X-Foo": "stainless"})
+ assert copied.default_headers["X-Foo"] == "stainless"
+
+ # set_default_headers
+
+ # completely overrides already set values
+ copied = client.copy(set_default_headers={})
+ assert copied.default_headers.get("X-Foo") is None
+
+ copied = client.copy(set_default_headers={"X-Bar": "Robert"})
+ assert copied.default_headers["X-Bar"] == "Robert"
+
+ with pytest.raises(
+ ValueError,
+ match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
+ ):
+ client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
+
+ def test_copy_default_query(self) -> None:
+ client = AsyncOpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
+ )
+ assert _get_params(client)["foo"] == "bar"
+
+ # does not override the already given value when not specified
+ copied = client.copy()
+ assert _get_params(copied)["foo"] == "bar"
+
+ # merges already given params
+ copied = client.copy(default_query={"bar": "stainless"})
+ params = _get_params(copied)
+ assert params["foo"] == "bar"
+ assert params["bar"] == "stainless"
+
+ # uses new values for any already given headers
+ copied = client.copy(default_query={"foo": "stainless"})
+ assert _get_params(copied)["foo"] == "stainless"
+
+ # set_default_query
+
+ # completely overrides already set values
+ copied = client.copy(set_default_query={})
+ assert _get_params(copied) == {}
+
+ copied = client.copy(set_default_query={"bar": "Robert"})
+ assert _get_params(copied)["bar"] == "Robert"
+
+ with pytest.raises(
+ ValueError,
+ # TODO: update
+ match="`default_query` and `set_default_query` arguments are mutually exclusive",
+ ):
+ client.copy(set_default_query={}, default_query={"foo": "Bar"})
+
+ def test_copy_signature(self) -> None:
+ # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
+ init_signature = inspect.signature(
+ # mypy doesn't like that we access the `__init__` property.
+ self.client.__init__, # type: ignore[misc]
+ )
+ copy_signature = inspect.signature(self.client.copy)
+ exclude_params = {"transport", "proxies", "_strict_response_validation"}
+
+ for name in init_signature.parameters.keys():
+ if name in exclude_params:
+ continue
+
+ copy_param = copy_signature.parameters.get(name)
+ assert copy_param is not None, f"copy() signature is missing the {name} param"
+
+ async def test_request_timeout(self) -> None:
+ request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == DEFAULT_TIMEOUT
+
+ request = self.client._build_request(
+ FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
+ )
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == httpx.Timeout(100.0)
+
+ async def test_client_timeout_option(self) -> None:
+ client = AsyncOpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
+ )
+
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == httpx.Timeout(0)
+
+ async def test_http_client_timeout_option(self) -> None:
+ # custom timeout given to the httpx client should be used
+ async with httpx.AsyncClient(timeout=None) as http_client:
+ client = AsyncOpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
+ )
+
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == httpx.Timeout(None)
+
+ # no timeout given to the httpx client should not use the httpx default
+ async with httpx.AsyncClient() as http_client:
+ client = AsyncOpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
+ )
+
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == DEFAULT_TIMEOUT
+
+ # explicitly passing the default timeout currently results in it being ignored
+ async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
+ client = AsyncOpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
+ )
+
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
+ assert timeout == DEFAULT_TIMEOUT # our default
+
+ def test_default_headers_option(self) -> None:
+ client = AsyncOpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
+ )
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ assert request.headers.get("x-foo") == "bar"
+ assert request.headers.get("x-stainless-lang") == "python"
+
+ client2 = AsyncOpenAI(
+ base_url=base_url,
+ api_key=api_key,
+ _strict_response_validation=True,
+ default_headers={
+ "X-Foo": "stainless",
+ "X-Stainless-Lang": "my-overriding-header",
+ },
+ )
+ request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
+ assert request.headers.get("x-foo") == "stainless"
+ assert request.headers.get("x-stainless-lang") == "my-overriding-header"
+
+ def test_validate_headers(self) -> None:
+ client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ assert request.headers.get("Authorization") == f"Bearer {api_key}"
+
+ with pytest.raises(Exception):
+ client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
+ _ = client2
+
+ def test_default_query_option(self) -> None:
+ client = AsyncOpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
+ )
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ url = httpx.URL(request.url)
+ assert dict(url.params) == {"query_param": "bar"}
+
+ request = client._build_request(
+ FinalRequestOptions(
+ method="get",
+ url="/foo",
+ params={"foo": "baz", "query_param": "overriden"},
+ )
+ )
+ url = httpx.URL(request.url)
+ assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
+
+ def test_request_extra_json(self) -> None:
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ json_data={"foo": "bar"},
+ extra_json={"baz": False},
+ ),
+ )
+ data = json.loads(request.content.decode("utf-8"))
+ assert data == {"foo": "bar", "baz": False}
+
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ extra_json={"baz": False},
+ ),
+ )
+ data = json.loads(request.content.decode("utf-8"))
+ assert data == {"baz": False}
+
+ # `extra_json` takes priority over `json_data` when keys clash
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ json_data={"foo": "bar", "baz": True},
+ extra_json={"baz": None},
+ ),
+ )
+ data = json.loads(request.content.decode("utf-8"))
+ assert data == {"foo": "bar", "baz": None}
+
+ def test_request_extra_headers(self) -> None:
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ **make_request_options(extra_headers={"X-Foo": "Foo"}),
+ ),
+ )
+ assert request.headers.get("X-Foo") == "Foo"
+
+ # `extra_headers` takes priority over `default_headers` when keys clash
+ request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ **make_request_options(
+ extra_headers={"X-Bar": "false"},
+ ),
+ ),
+ )
+ assert request.headers.get("X-Bar") == "false"
+
+ def test_request_extra_query(self) -> None:
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ **make_request_options(
+ extra_query={"my_query_param": "Foo"},
+ ),
+ ),
+ )
+ params = cast(Dict[str, str], dict(request.url.params))
+ assert params == {"my_query_param": "Foo"}
+
+ # if both `query` and `extra_query` are given, they are merged
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ **make_request_options(
+ query={"bar": "1"},
+ extra_query={"foo": "2"},
+ ),
+ ),
+ )
+ params = cast(Dict[str, str], dict(request.url.params))
+ assert params == {"bar": "1", "foo": "2"}
+
+ # `extra_query` takes priority over `query` when keys clash
+ request = self.client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ **make_request_options(
+ query={"foo": "1"},
+ extra_query={"foo": "2"},
+ ),
+ ),
+ )
+ params = cast(Dict[str, str], dict(request.url.params))
+ assert params == {"foo": "2"}
+
+ @pytest.mark.respx(base_url=base_url)
+ async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
+ class Model1(BaseModel):
+ name: str
+
+ class Model2(BaseModel):
+ foo: str
+
+ respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
+
+ response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ assert isinstance(response, Model2)
+ assert response.foo == "bar"
+
+ @pytest.mark.respx(base_url=base_url)
+ async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
+ """Union of objects with the same field name using a different type"""
+
+ class Model1(BaseModel):
+ foo: int
+
+ class Model2(BaseModel):
+ foo: str
+
+ respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
+
+ response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ assert isinstance(response, Model2)
+ assert response.foo == "bar"
+
+ respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
+
+ response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ assert isinstance(response, Model1)
+ assert response.foo == 1
+
+ @pytest.mark.parametrize(
+ "client",
+ [
+ AsyncOpenAI(
+ base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
+ ),
+ AsyncOpenAI(
+ base_url="http://localhost:5000/custom/path/",
+ api_key=api_key,
+ _strict_response_validation=True,
+ http_client=httpx.AsyncClient(),
+ ),
+ ],
+ ids=["standard", "custom http client"],
+ )
+ def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
+ request = client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ json_data={"foo": "bar"},
+ ),
+ )
+ assert request.url == "http://localhost:5000/custom/path/foo"
+
+ @pytest.mark.parametrize(
+ "client",
+ [
+ AsyncOpenAI(
+ base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
+ ),
+ AsyncOpenAI(
+ base_url="http://localhost:5000/custom/path/",
+ api_key=api_key,
+ _strict_response_validation=True,
+ http_client=httpx.AsyncClient(),
+ ),
+ ],
+ ids=["standard", "custom http client"],
+ )
+ def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
+ request = client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="/foo",
+ json_data={"foo": "bar"},
+ ),
+ )
+ assert request.url == "http://localhost:5000/custom/path/foo"
+
+ @pytest.mark.parametrize(
+ "client",
+ [
+ AsyncOpenAI(
+ base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
+ ),
+ AsyncOpenAI(
+ base_url="http://localhost:5000/custom/path/",
+ api_key=api_key,
+ _strict_response_validation=True,
+ http_client=httpx.AsyncClient(),
+ ),
+ ],
+ ids=["standard", "custom http client"],
+ )
+ def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
+ request = client._build_request(
+ FinalRequestOptions(
+ method="post",
+ url="https://myapi.com/foo",
+ json_data={"foo": "bar"},
+ ),
+ )
+ assert request.url == "https://myapi.com/foo"
+
+ async def test_client_del(self) -> None:
+ client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not client.is_closed()
+
+ client.__del__()
+
+ await asyncio.sleep(0.2)
+ assert client.is_closed()
+
+ async def test_copied_client_does_not_close_http(self) -> None:
+ client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not client.is_closed()
+
+ copied = client.copy()
+ assert copied is not client
+
+ copied.__del__()
+
+ await asyncio.sleep(0.2)
+ assert not copied.is_closed()
+ assert not client.is_closed()
+
+ async def test_client_context_manager(self) -> None:
+ client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ async with client as c2:
+ assert c2 is client
+ assert not c2.is_closed()
+ assert not client.is_closed()
+ assert client.is_closed()
+
+ @pytest.mark.respx(base_url=base_url)
+ @pytest.mark.asyncio
+ async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
+ class Model(BaseModel):
+ foo: str
+
+ respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
+
+ with pytest.raises(APIResponseValidationError) as exc:
+ await self.client.get("/foo", cast_to=Model)
+
+ assert isinstance(exc.value.__cause__, ValidationError)
+
+ @pytest.mark.respx(base_url=base_url)
+ @pytest.mark.asyncio
+ async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
+ class Model(BaseModel):
+ name: str
+
+ respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
+
+ response = await self.client.post("/foo", cast_to=Model, stream=True)
+ assert isinstance(response, AsyncStream)
+
+ @pytest.mark.respx(base_url=base_url)
+ @pytest.mark.asyncio
+ async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
+ class Model(BaseModel):
+ name: str
+
+ respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
+
+ strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+
+ with pytest.raises(APIResponseValidationError):
+ await strict_client.get("/foo", cast_to=Model)
+
+ client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+
+ response = await client.get("/foo", cast_to=Model)
+ assert isinstance(response, str) # type: ignore[unreachable]
+
+ @pytest.mark.parametrize(
+ "remaining_retries,retry_after,timeout",
+ [
+ [3, "20", 20],
+ [3, "0", 0.5],
+ [3, "-10", 0.5],
+ [3, "60", 60],
+ [3, "61", 0.5],
+ [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
+ [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
+ [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
+ [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
+ [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
+ [3, "99999999999999999999999999999999999", 0.5],
+ [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
+ [3, "", 0.5],
+ [2, "", 0.5 * 2.0],
+ [1, "", 0.5 * 4.0],
+ ],
+ )
+ @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
+ @pytest.mark.asyncio
+ async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
+ client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+
+ headers = httpx.Headers({"retry-after": retry_after})
+ options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
+ calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
+ assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
diff --git a/tests/test_deepcopy.py b/tests/test_deepcopy.py
new file mode 100644
index 0000000000..8cf65ce94e
--- /dev/null
+++ b/tests/test_deepcopy.py
@@ -0,0 +1,59 @@
+from openai._utils import deepcopy_minimal
+
+
+def assert_different_identities(obj1: object, obj2: object) -> None:
+ assert obj1 == obj2
+ assert id(obj1) != id(obj2)
+
+
+def test_simple_dict() -> None:
+ obj1 = {"foo": "bar"}
+ obj2 = deepcopy_minimal(obj1)
+ assert_different_identities(obj1, obj2)
+
+
+def test_nested_dict() -> None:
+ obj1 = {"foo": {"bar": True}}
+ obj2 = deepcopy_minimal(obj1)
+ assert_different_identities(obj1, obj2)
+ assert_different_identities(obj1["foo"], obj2["foo"])
+
+
+def test_complex_nested_dict() -> None:
+ obj1 = {"foo": {"bar": [{"hello": "world"}]}}
+ obj2 = deepcopy_minimal(obj1)
+ assert_different_identities(obj1, obj2)
+ assert_different_identities(obj1["foo"], obj2["foo"])
+ assert_different_identities(obj1["foo"]["bar"], obj2["foo"]["bar"])
+ assert_different_identities(obj1["foo"]["bar"][0], obj2["foo"]["bar"][0])
+
+
+def test_simple_list() -> None:
+ obj1 = ["a", "b", "c"]
+ obj2 = deepcopy_minimal(obj1)
+ assert_different_identities(obj1, obj2)
+
+
+def test_nested_list() -> None:
+ obj1 = ["a", [1, 2, 3]]
+ obj2 = deepcopy_minimal(obj1)
+ assert_different_identities(obj1, obj2)
+ assert_different_identities(obj1[1], obj2[1])
+
+
+class MyObject:
+ ...
+
+
+def test_ignores_other_types() -> None:
+ # custom classes
+ my_obj = MyObject()
+ obj1 = {"foo": my_obj}
+ obj2 = deepcopy_minimal(obj1)
+ assert_different_identities(obj1, obj2)
+ assert obj1["foo"] is my_obj
+
+ # tuples
+ obj3 = ("a", "b")
+ obj4 = deepcopy_minimal(obj3)
+ assert obj3 is obj4
diff --git a/tests/test_extract_files.py b/tests/test_extract_files.py
new file mode 100644
index 0000000000..554487da42
--- /dev/null
+++ b/tests/test_extract_files.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+from typing import Sequence
+
+import pytest
+
+from openai._types import FileTypes
+from openai._utils import extract_files
+
+
+def test_removes_files_from_input() -> None:
+ query = {"foo": "bar"}
+ assert extract_files(query, paths=[]) == []
+ assert query == {"foo": "bar"}
+
+ query2 = {"foo": b"Bar", "hello": "world"}
+ assert extract_files(query2, paths=[["foo"]]) == [("foo", b"Bar")]
+ assert query2 == {"hello": "world"}
+
+ query3 = {"foo": {"foo": {"bar": b"Bar"}}, "hello": "world"}
+ assert extract_files(query3, paths=[["foo", "foo", "bar"]]) == [("foo[foo][bar]", b"Bar")]
+ assert query3 == {"foo": {"foo": {}}, "hello": "world"}
+
+ query4 = {"foo": {"bar": b"Bar", "baz": "foo"}, "hello": "world"}
+ assert extract_files(query4, paths=[["foo", "bar"]]) == [("foo[bar]", b"Bar")]
+ assert query4 == {"hello": "world", "foo": {"baz": "foo"}}
+
+
+def test_multiple_files() -> None:
+ query = {"documents": [{"file": b"My first file"}, {"file": b"My second file"}]}
+ assert extract_files(query, paths=[["documents", "", "file"]]) == [
+ ("documents[][file]", b"My first file"),
+ ("documents[][file]", b"My second file"),
+ ]
+ assert query == {"documents": [{}, {}]}
+
+
+@pytest.mark.parametrize(
+ "query,paths,expected",
+ [
+ [
+ {"foo": {"bar": "baz"}},
+ [["foo", "", "bar"]],
+ [],
+ ],
+ [
+ {"foo": ["bar", "baz"]},
+ [["foo", "bar"]],
+ [],
+ ],
+ [
+ {"foo": {"bar": "baz"}},
+ [["foo", "foo"]],
+ [],
+ ],
+ ],
+ ids=["dict expecting array", "arraye expecting dict", "unknown keys"],
+)
+def test_ignores_incorrect_paths(
+ query: dict[str, object],
+ paths: Sequence[Sequence[str]],
+ expected: list[tuple[str, FileTypes]],
+) -> None:
+ assert extract_files(query, paths=paths) == expected
diff --git a/tests/test_files.py b/tests/test_files.py
new file mode 100644
index 0000000000..15d5c6a811
--- /dev/null
+++ b/tests/test_files.py
@@ -0,0 +1,51 @@
+from pathlib import Path
+
+import anyio
+import pytest
+from dirty_equals import IsDict, IsList, IsBytes, IsTuple
+
+from openai._files import to_httpx_files, async_to_httpx_files
+
+readme_path = Path(__file__).parent.parent.joinpath("README.md")
+
+
+def test_pathlib_includes_file_name() -> None:
+ result = to_httpx_files({"file": readme_path})
+ print(result)
+ assert result == IsDict({"file": IsTuple("README.md", IsBytes())})
+
+
+def test_tuple_input() -> None:
+ result = to_httpx_files([("file", readme_path)])
+ print(result)
+ assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes())))
+
+
+@pytest.mark.asyncio
+async def test_async_pathlib_includes_file_name() -> None:
+ result = await async_to_httpx_files({"file": readme_path})
+ print(result)
+ assert result == IsDict({"file": IsTuple("README.md", IsBytes())})
+
+
+@pytest.mark.asyncio
+async def test_async_supports_anyio_path() -> None:
+ result = await async_to_httpx_files({"file": anyio.Path(readme_path)})
+ print(result)
+ assert result == IsDict({"file": IsTuple("README.md", IsBytes())})
+
+
+@pytest.mark.asyncio
+async def test_async_tuple_input() -> None:
+ result = await async_to_httpx_files([("file", readme_path)])
+ print(result)
+ assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes())))
+
+
+def test_string_not_allowed() -> None:
+ with pytest.raises(TypeError, match="Expected file types input to be a FileContent type or to be a tuple"):
+ to_httpx_files(
+ {
+ "file": "foo", # type: ignore
+ }
+ )
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000000..713bd2cb1b
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,573 @@
+import json
+from typing import Any, Dict, List, Union, Optional, cast
+from datetime import datetime, timezone
+from typing_extensions import Literal
+
+import pytest
+import pydantic
+from pydantic import Field
+
+from openai._compat import PYDANTIC_V2, parse_obj, model_dump, model_json
+from openai._models import BaseModel
+
+
+class BasicModel(BaseModel):
+ foo: str
+
+
+@pytest.mark.parametrize("value", ["hello", 1], ids=["correct type", "mismatched"])
+def test_basic(value: object) -> None:
+ m = BasicModel.construct(foo=value)
+ assert m.foo == value
+
+
+def test_directly_nested_model() -> None:
+ class NestedModel(BaseModel):
+ nested: BasicModel
+
+ m = NestedModel.construct(nested={"foo": "Foo!"})
+ assert m.nested.foo == "Foo!"
+
+ # mismatched types
+ m = NestedModel.construct(nested="hello!")
+ assert m.nested == "hello!"
+
+
+def test_optional_nested_model() -> None:
+ class NestedModel(BaseModel):
+ nested: Optional[BasicModel]
+
+ m1 = NestedModel.construct(nested=None)
+ assert m1.nested is None
+
+ m2 = NestedModel.construct(nested={"foo": "bar"})
+ assert m2.nested is not None
+ assert m2.nested.foo == "bar"
+
+ # mismatched types
+ m3 = NestedModel.construct(nested={"foo"})
+ assert isinstance(cast(Any, m3.nested), set)
+ assert m3.nested == {"foo"}
+
+
+def test_list_nested_model() -> None:
+ class NestedModel(BaseModel):
+ nested: List[BasicModel]
+
+ m = NestedModel.construct(nested=[{"foo": "bar"}, {"foo": "2"}])
+ assert m.nested is not None
+ assert isinstance(m.nested, list)
+ assert len(m.nested) == 2
+ assert m.nested[0].foo == "bar"
+ assert m.nested[1].foo == "2"
+
+ # mismatched types
+ m = NestedModel.construct(nested=True)
+ assert cast(Any, m.nested) is True
+
+ m = NestedModel.construct(nested=[False])
+ assert cast(Any, m.nested) == [False]
+
+
+def test_optional_list_nested_model() -> None:
+ class NestedModel(BaseModel):
+ nested: Optional[List[BasicModel]]
+
+ m1 = NestedModel.construct(nested=[{"foo": "bar"}, {"foo": "2"}])
+ assert m1.nested is not None
+ assert isinstance(m1.nested, list)
+ assert len(m1.nested) == 2
+ assert m1.nested[0].foo == "bar"
+ assert m1.nested[1].foo == "2"
+
+ m2 = NestedModel.construct(nested=None)
+ assert m2.nested is None
+
+ # mismatched types
+ m3 = NestedModel.construct(nested={1})
+ assert cast(Any, m3.nested) == {1}
+
+ m4 = NestedModel.construct(nested=[False])
+ assert cast(Any, m4.nested) == [False]
+
+
+def test_list_optional_items_nested_model() -> None:
+ class NestedModel(BaseModel):
+ nested: List[Optional[BasicModel]]
+
+ m = NestedModel.construct(nested=[None, {"foo": "bar"}])
+ assert m.nested is not None
+ assert isinstance(m.nested, list)
+ assert len(m.nested) == 2
+ assert m.nested[0] is None
+ assert m.nested[1] is not None
+ assert m.nested[1].foo == "bar"
+
+ # mismatched types
+ m3 = NestedModel.construct(nested="foo")
+ assert cast(Any, m3.nested) == "foo"
+
+ m4 = NestedModel.construct(nested=[False])
+ assert cast(Any, m4.nested) == [False]
+
+
+def test_list_mismatched_type() -> None:
+ class NestedModel(BaseModel):
+ nested: List[str]
+
+ m = NestedModel.construct(nested=False)
+ assert cast(Any, m.nested) is False
+
+
+def test_raw_dictionary() -> None:
+ class NestedModel(BaseModel):
+ nested: Dict[str, str]
+
+ m = NestedModel.construct(nested={"hello": "world"})
+ assert m.nested == {"hello": "world"}
+
+ # mismatched types
+ m = NestedModel.construct(nested=False)
+ assert cast(Any, m.nested) is False
+
+
+def test_nested_dictionary_model() -> None:
+ class NestedModel(BaseModel):
+ nested: Dict[str, BasicModel]
+
+ m = NestedModel.construct(nested={"hello": {"foo": "bar"}})
+ assert isinstance(m.nested, dict)
+ assert m.nested["hello"].foo == "bar"
+
+ # mismatched types
+ m = NestedModel.construct(nested={"hello": False})
+ assert cast(Any, m.nested["hello"]) is False
+
+
+def test_unknown_fields() -> None:
+ m1 = BasicModel.construct(foo="foo", unknown=1)
+ assert m1.foo == "foo"
+ assert cast(Any, m1).unknown == 1
+
+ m2 = BasicModel.construct(foo="foo", unknown={"foo_bar": True})
+ assert m2.foo == "foo"
+ assert cast(Any, m2).unknown == {"foo_bar": True}
+
+ assert model_dump(m2) == {"foo": "foo", "unknown": {"foo_bar": True}}
+
+
+def test_strict_validation_unknown_fields() -> None:
+ class Model(BaseModel):
+ foo: str
+
+ model = parse_obj(Model, dict(foo="hello!", user="Robert"))
+ assert model.foo == "hello!"
+ assert cast(Any, model).user == "Robert"
+
+ assert model_dump(model) == {"foo": "hello!", "user": "Robert"}
+
+
+def test_aliases() -> None:
+ class Model(BaseModel):
+ my_field: int = Field(alias="myField")
+
+ m = Model.construct(myField=1)
+ assert m.my_field == 1
+
+ # mismatched types
+ m = Model.construct(myField={"hello": False})
+ assert cast(Any, m.my_field) == {"hello": False}
+
+
+def test_repr() -> None:
+ model = BasicModel(foo="bar")
+ assert str(model) == "BasicModel(foo='bar')"
+ assert repr(model) == "BasicModel(foo='bar')"
+
+
+def test_repr_nested_model() -> None:
+ class Child(BaseModel):
+ name: str
+ age: int
+
+ class Parent(BaseModel):
+ name: str
+ child: Child
+
+ model = Parent(name="Robert", child=Child(name="Foo", age=5))
+ assert str(model) == "Parent(name='Robert', child=Child(name='Foo', age=5))"
+ assert repr(model) == "Parent(name='Robert', child=Child(name='Foo', age=5))"
+
+
+def test_optional_list() -> None:
+ class Submodel(BaseModel):
+ name: str
+
+ class Model(BaseModel):
+ items: Optional[List[Submodel]]
+
+ m = Model.construct(items=None)
+ assert m.items is None
+
+ m = Model.construct(items=[])
+ assert m.items == []
+
+ m = Model.construct(items=[{"name": "Robert"}])
+ assert m.items is not None
+ assert len(m.items) == 1
+ assert m.items[0].name == "Robert"
+
+
+def test_nested_union_of_models() -> None:
+ class Submodel1(BaseModel):
+ bar: bool
+
+ class Submodel2(BaseModel):
+ thing: str
+
+ class Model(BaseModel):
+ foo: Union[Submodel1, Submodel2]
+
+ m = Model.construct(foo={"thing": "hello"})
+ assert isinstance(m.foo, Submodel2)
+ assert m.foo.thing == "hello"
+
+
+def test_nested_union_of_mixed_types() -> None:
+ class Submodel1(BaseModel):
+ bar: bool
+
+ class Model(BaseModel):
+ foo: Union[Submodel1, Literal[True], Literal["CARD_HOLDER"]]
+
+ m = Model.construct(foo=True)
+ assert m.foo is True
+
+ m = Model.construct(foo="CARD_HOLDER")
+ assert m.foo is "CARD_HOLDER"
+
+ m = Model.construct(foo={"bar": False})
+ assert isinstance(m.foo, Submodel1)
+ assert m.foo.bar is False
+
+
+def test_nested_union_multiple_variants() -> None:
+ class Submodel1(BaseModel):
+ bar: bool
+
+ class Submodel2(BaseModel):
+ thing: str
+
+ class Submodel3(BaseModel):
+ foo: int
+
+ class Model(BaseModel):
+ foo: Union[Submodel1, Submodel2, None, Submodel3]
+
+ m = Model.construct(foo={"thing": "hello"})
+ assert isinstance(m.foo, Submodel2)
+ assert m.foo.thing == "hello"
+
+ m = Model.construct(foo=None)
+ assert m.foo is None
+
+ m = Model.construct()
+ assert m.foo is None
+
+ m = Model.construct(foo={"foo": "1"})
+ assert isinstance(m.foo, Submodel3)
+ assert m.foo.foo == 1
+
+
+def test_nested_union_invalid_data() -> None:
+ class Submodel1(BaseModel):
+ level: int
+
+ class Submodel2(BaseModel):
+ name: str
+
+ class Model(BaseModel):
+ foo: Union[Submodel1, Submodel2]
+
+ m = Model.construct(foo=True)
+ assert cast(bool, m.foo) is True
+
+ m = Model.construct(foo={"name": 3})
+ if PYDANTIC_V2:
+ assert isinstance(m.foo, Submodel1)
+ assert m.foo.name == 3 # type: ignore
+ else:
+ assert isinstance(m.foo, Submodel2)
+ assert m.foo.name == "3"
+
+
+def test_list_of_unions() -> None:
+ class Submodel1(BaseModel):
+ level: int
+
+ class Submodel2(BaseModel):
+ name: str
+
+ class Model(BaseModel):
+ items: List[Union[Submodel1, Submodel2]]
+
+ m = Model.construct(items=[{"level": 1}, {"name": "Robert"}])
+ assert len(m.items) == 2
+ assert isinstance(m.items[0], Submodel1)
+ assert m.items[0].level == 1
+ assert isinstance(m.items[1], Submodel2)
+ assert m.items[1].name == "Robert"
+
+ m = Model.construct(items=[{"level": -1}, 156])
+ assert len(m.items) == 2
+ assert isinstance(m.items[0], Submodel1)
+ assert m.items[0].level == -1
+ assert m.items[1] == 156
+
+
+def test_union_of_lists() -> None:
+ class SubModel1(BaseModel):
+ level: int
+
+ class SubModel2(BaseModel):
+ name: str
+
+ class Model(BaseModel):
+ items: Union[List[SubModel1], List[SubModel2]]
+
+ # with one valid entry
+ m = Model.construct(items=[{"name": "Robert"}])
+ assert len(m.items) == 1
+ assert isinstance(m.items[0], SubModel2)
+ assert m.items[0].name == "Robert"
+
+ # with two entries pointing to different types
+ m = Model.construct(items=[{"level": 1}, {"name": "Robert"}])
+ assert len(m.items) == 2
+ assert isinstance(m.items[0], SubModel1)
+ assert m.items[0].level == 1
+ assert isinstance(m.items[1], SubModel1)
+ assert cast(Any, m.items[1]).name == "Robert"
+
+ # with two entries pointing to *completely* different types
+ m = Model.construct(items=[{"level": -1}, 156])
+ assert len(m.items) == 2
+ assert isinstance(m.items[0], SubModel1)
+ assert m.items[0].level == -1
+ assert m.items[1] == 156
+
+
+def test_dict_of_union() -> None:
+ class SubModel1(BaseModel):
+ name: str
+
+ class SubModel2(BaseModel):
+ foo: str
+
+ class Model(BaseModel):
+ data: Dict[str, Union[SubModel1, SubModel2]]
+
+ m = Model.construct(data={"hello": {"name": "there"}, "foo": {"foo": "bar"}})
+ assert len(list(m.data.keys())) == 2
+ assert isinstance(m.data["hello"], SubModel1)
+ assert m.data["hello"].name == "there"
+ assert isinstance(m.data["foo"], SubModel2)
+ assert m.data["foo"].foo == "bar"
+
+ # TODO: test mismatched type
+
+
+def test_double_nested_union() -> None:
+ class SubModel1(BaseModel):
+ name: str
+
+ class SubModel2(BaseModel):
+ bar: str
+
+ class Model(BaseModel):
+ data: Dict[str, List[Union[SubModel1, SubModel2]]]
+
+ m = Model.construct(data={"foo": [{"bar": "baz"}, {"name": "Robert"}]})
+ assert len(m.data["foo"]) == 2
+
+ entry1 = m.data["foo"][0]
+ assert isinstance(entry1, SubModel2)
+ assert entry1.bar == "baz"
+
+ entry2 = m.data["foo"][1]
+ assert isinstance(entry2, SubModel1)
+ assert entry2.name == "Robert"
+
+ # TODO: test mismatched type
+
+
+def test_union_of_dict() -> None:
+ class SubModel1(BaseModel):
+ name: str
+
+ class SubModel2(BaseModel):
+ foo: str
+
+ class Model(BaseModel):
+ data: Union[Dict[str, SubModel1], Dict[str, SubModel2]]
+
+ m = Model.construct(data={"hello": {"name": "there"}, "foo": {"foo": "bar"}})
+ assert len(list(m.data.keys())) == 2
+ assert isinstance(m.data["hello"], SubModel1)
+ assert m.data["hello"].name == "there"
+ assert isinstance(m.data["foo"], SubModel1)
+ assert cast(Any, m.data["foo"]).foo == "bar"
+
+
+def test_iso8601_datetime() -> None:
+ class Model(BaseModel):
+ created_at: datetime
+
+ expected = datetime(2019, 12, 27, 18, 11, 19, 117000, tzinfo=timezone.utc)
+
+ if PYDANTIC_V2:
+ expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}'
+ else:
+ expected_json = '{"created_at": "2019-12-27T18:11:19.117000+00:00"}'
+
+ model = Model.construct(created_at="2019-12-27T18:11:19.117Z")
+ assert model.created_at == expected
+ assert model_json(model) == expected_json
+
+ model = parse_obj(Model, dict(created_at="2019-12-27T18:11:19.117Z"))
+ assert model.created_at == expected
+ assert model_json(model) == expected_json
+
+
+def test_does_not_coerce_int() -> None:
+ class Model(BaseModel):
+ bar: int
+
+ assert Model.construct(bar=1).bar == 1
+ assert Model.construct(bar=10.9).bar == 10.9
+ assert Model.construct(bar="19").bar == "19" # type: ignore[comparison-overlap]
+ assert Model.construct(bar=False).bar is False
+
+
+def test_int_to_float_safe_conversion() -> None:
+ class Model(BaseModel):
+ float_field: float
+
+ m = Model.construct(float_field=10)
+ assert m.float_field == 10.0
+ assert isinstance(m.float_field, float)
+
+ m = Model.construct(float_field=10.12)
+ assert m.float_field == 10.12
+ assert isinstance(m.float_field, float)
+
+ # number too big
+ m = Model.construct(float_field=2**53 + 1)
+ assert m.float_field == 2**53 + 1
+ assert isinstance(m.float_field, int)
+
+
+def test_deprecated_alias() -> None:
+ class Model(BaseModel):
+ resource_id: str = Field(alias="model_id")
+
+ @property
+ def model_id(self) -> str:
+ return self.resource_id
+
+ m = Model.construct(model_id="id")
+ assert m.model_id == "id"
+ assert m.resource_id == "id"
+ assert m.resource_id is m.model_id
+
+ m = parse_obj(Model, {"model_id": "id"})
+ assert m.model_id == "id"
+ assert m.resource_id == "id"
+ assert m.resource_id is m.model_id
+
+
+def test_omitted_fields() -> None:
+ class Model(BaseModel):
+ resource_id: Optional[str] = None
+
+ m = Model.construct()
+ assert "resource_id" not in m.model_fields_set
+
+ m = Model.construct(resource_id=None)
+ assert "resource_id" in m.model_fields_set
+
+ m = Model.construct(resource_id="foo")
+ assert "resource_id" in m.model_fields_set
+
+
+def test_forwards_compat_model_dump_method() -> None:
+ class Model(BaseModel):
+ foo: Optional[str] = Field(alias="FOO", default=None)
+
+ m = Model(FOO="hello")
+ assert m.model_dump() == {"foo": "hello"}
+ assert m.model_dump(include={"bar"}) == {}
+ assert m.model_dump(exclude={"foo"}) == {}
+ assert m.model_dump(by_alias=True) == {"FOO": "hello"}
+
+ m2 = Model()
+ assert m2.model_dump() == {"foo": None}
+ assert m2.model_dump(exclude_unset=True) == {}
+ assert m2.model_dump(exclude_none=True) == {}
+ assert m2.model_dump(exclude_defaults=True) == {}
+
+ m3 = Model(FOO=None)
+ assert m3.model_dump() == {"foo": None}
+ assert m3.model_dump(exclude_none=True) == {}
+
+ if not PYDANTIC_V2:
+ with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
+ m.model_dump(mode="json")
+
+ with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
+ m.model_dump(round_trip=True)
+
+ with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
+ m.model_dump(warnings=False)
+
+
+def test_forwards_compat_model_dump_json_method() -> None:
+ class Model(BaseModel):
+ foo: Optional[str] = Field(alias="FOO", default=None)
+
+ m = Model(FOO="hello")
+ assert json.loads(m.model_dump_json()) == {"foo": "hello"}
+ assert json.loads(m.model_dump_json(include={"bar"})) == {}
+ assert json.loads(m.model_dump_json(include={"foo"})) == {"foo": "hello"}
+ assert json.loads(m.model_dump_json(by_alias=True)) == {"FOO": "hello"}
+
+ assert m.model_dump_json(indent=2) == '{\n "foo": "hello"\n}'
+
+ m2 = Model()
+ assert json.loads(m2.model_dump_json()) == {"foo": None}
+ assert json.loads(m2.model_dump_json(exclude_unset=True)) == {}
+ assert json.loads(m2.model_dump_json(exclude_none=True)) == {}
+ assert json.loads(m2.model_dump_json(exclude_defaults=True)) == {}
+
+ m3 = Model(FOO=None)
+ assert json.loads(m3.model_dump_json()) == {"foo": None}
+ assert json.loads(m3.model_dump_json(exclude_none=True)) == {}
+
+ if not PYDANTIC_V2:
+ with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
+ m.model_dump_json(round_trip=True)
+
+ with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
+ m.model_dump_json(warnings=False)
+
+
+def test_type_compat() -> None:
+ # our model type can be assigned to Pydantic's model type
+
+ def takes_pydantic(model: pydantic.BaseModel) -> None: # noqa: ARG001
+ ...
+
+ class OurModel(BaseModel):
+ foo: Optional[str] = None
+
+ takes_pydantic(OurModel())
diff --git a/tests/test_module_client.py b/tests/test_module_client.py
new file mode 100644
index 0000000000..0beca37f61
--- /dev/null
+++ b/tests/test_module_client.py
@@ -0,0 +1,179 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os as _os
+
+import httpx
+import pytest
+from httpx import URL
+
+import openai
+from openai import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
+
+
+def reset_state() -> None:
+ openai._reset_client()
+ openai.api_key = None or "My API Key"
+ openai.organization = None
+ openai.base_url = None
+ openai.timeout = DEFAULT_TIMEOUT
+ openai.max_retries = DEFAULT_MAX_RETRIES
+ openai.default_headers = None
+ openai.default_query = None
+ openai.http_client = None
+ openai.api_type = _os.environ.get("OPENAI_API_TYPE") # type: ignore
+ openai.api_version = None
+ openai.azure_endpoint = None
+ openai.azure_ad_token = None
+ openai.azure_ad_token_provider = None
+
+
+@pytest.fixture(autouse=True)
+def reset_state_fixture() -> None:
+ reset_state()
+
+
+def test_base_url_option() -> None:
+ assert openai.base_url is None
+ assert openai.completions._client.base_url == URL("https://api.openai.com/v1/")
+
+ openai.base_url = "http://foo.com"
+
+ assert openai.base_url == URL("http://foo.com")
+ assert openai.completions._client.base_url == URL("http://foo.com")
+
+
+def test_timeout_option() -> None:
+ assert openai.timeout == openai.DEFAULT_TIMEOUT
+ assert openai.completions._client.timeout == openai.DEFAULT_TIMEOUT
+
+ openai.timeout = 3
+
+ assert openai.timeout == 3
+ assert openai.completions._client.timeout == 3
+
+
+def test_max_retries_option() -> None:
+ assert openai.max_retries == openai.DEFAULT_MAX_RETRIES
+ assert openai.completions._client.max_retries == openai.DEFAULT_MAX_RETRIES
+
+ openai.max_retries = 1
+
+ assert openai.max_retries == 1
+ assert openai.completions._client.max_retries == 1
+
+
+def test_default_headers_option() -> None:
+ assert openai.default_headers == None
+
+ openai.default_headers = {"Foo": "Bar"}
+
+ assert openai.default_headers["Foo"] == "Bar"
+ assert openai.completions._client.default_headers["Foo"] == "Bar"
+
+
+def test_default_query_option() -> None:
+ assert openai.default_query is None
+ assert openai.completions._client._custom_query == {}
+
+ openai.default_query = {"Foo": {"nested": 1}}
+
+ assert openai.default_query["Foo"] == {"nested": 1}
+ assert openai.completions._client._custom_query["Foo"] == {"nested": 1}
+
+
+def test_http_client_option() -> None:
+ assert openai.http_client is None
+
+ original_http_client = openai.completions._client._client
+ assert original_http_client is not None
+
+ new_client = httpx.Client()
+ openai.http_client = new_client
+
+ assert openai.completions._client._client is new_client
+
+
+import contextlib
+from typing import Iterator
+
+from openai.lib.azure import AzureOpenAI
+
+
+@contextlib.contextmanager
+def fresh_env() -> Iterator[None]:
+ old = _os.environ.copy()
+
+ try:
+ _os.environ.clear()
+ yield
+ finally:
+ _os.environ.update(old)
+
+
+def test_only_api_key_results_in_openai_api() -> None:
+ with fresh_env():
+ openai.api_type = None
+ openai.api_key = "example API key"
+
+ assert type(openai.completions._client).__name__ == "_ModuleClient"
+
+
+def test_azure_api_key_env_without_api_version() -> None:
+ with fresh_env():
+ openai.api_type = None
+ _os.environ["AZURE_OPENAI_API_KEY"] = "example API key"
+
+ with pytest.raises(ValueError, match=r"Expected `api_version` to be given for the Azure client"):
+ openai.completions._client
+
+
+def test_azure_api_key_and_version_env() -> None:
+ with fresh_env():
+ openai.api_type = None
+ _os.environ["AZURE_OPENAI_API_KEY"] = "example API key"
+ _os.environ["OPENAI_API_VERSION"] = "example-version"
+
+ with pytest.raises(
+ ValueError,
+ match=r"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `OPENAI_BASE_URL`",
+ ):
+ openai.completions._client
+
+
+def test_azure_api_key_version_and_endpoint_env() -> None:
+ with fresh_env():
+ openai.api_type = None
+ _os.environ["AZURE_OPENAI_API_KEY"] = "example API key"
+ _os.environ["OPENAI_API_VERSION"] = "example-version"
+ _os.environ["AZURE_OPENAI_ENDPOINT"] = "https://www.example"
+
+ openai.completions._client
+
+ assert openai.api_type == "azure"
+
+
+def test_azure_azure_ad_token_version_and_endpoint_env() -> None:
+ with fresh_env():
+ openai.api_type = None
+ _os.environ["AZURE_OPENAI_AD_TOKEN"] = "example AD token"
+ _os.environ["OPENAI_API_VERSION"] = "example-version"
+ _os.environ["AZURE_OPENAI_ENDPOINT"] = "https://www.example"
+
+ client = openai.completions._client
+ assert isinstance(client, AzureOpenAI)
+ assert client._azure_ad_token == "example AD token"
+
+
+def test_azure_azure_ad_token_provider_version_and_endpoint_env() -> None:
+ with fresh_env():
+ openai.api_type = None
+ _os.environ["OPENAI_API_VERSION"] = "example-version"
+ _os.environ["AZURE_OPENAI_ENDPOINT"] = "https://www.example"
+ openai.azure_ad_token_provider = lambda: "token"
+
+ client = openai.completions._client
+ assert isinstance(client, AzureOpenAI)
+ assert client._azure_ad_token_provider is not None
+ assert client._azure_ad_token_provider() == "token"
diff --git a/tests/test_qs.py b/tests/test_qs.py
new file mode 100644
index 0000000000..697b8a95ec
--- /dev/null
+++ b/tests/test_qs.py
@@ -0,0 +1,78 @@
+from typing import Any, cast
+from functools import partial
+from urllib.parse import unquote
+
+import pytest
+
+from openai._qs import Querystring, stringify
+
+
+def test_empty() -> None:
+ assert stringify({}) == ""
+ assert stringify({"a": {}}) == ""
+ assert stringify({"a": {"b": {"c": {}}}}) == ""
+
+
+def test_basic() -> None:
+ assert stringify({"a": 1}) == "a=1"
+ assert stringify({"a": "b"}) == "a=b"
+ assert stringify({"a": True}) == "a=true"
+ assert stringify({"a": False}) == "a=false"
+ assert stringify({"a": 1.23456}) == "a=1.23456"
+ assert stringify({"a": None}) == ""
+
+
+@pytest.mark.parametrize("method", ["class", "function"])
+def test_nested_dotted(method: str) -> None:
+ if method == "class":
+ serialise = Querystring(nested_format="dots").stringify
+ else:
+ serialise = partial(stringify, nested_format="dots")
+
+ assert unquote(serialise({"a": {"b": "c"}})) == "a.b=c"
+ assert unquote(serialise({"a": {"b": "c", "d": "e", "f": "g"}})) == "a.b=c&a.d=e&a.f=g"
+ assert unquote(serialise({"a": {"b": {"c": {"d": "e"}}}})) == "a.b.c.d=e"
+ assert unquote(serialise({"a": {"b": True}})) == "a.b=true"
+
+
+def test_nested_brackets() -> None:
+ assert unquote(stringify({"a": {"b": "c"}})) == "a[b]=c"
+ assert unquote(stringify({"a": {"b": "c", "d": "e", "f": "g"}})) == "a[b]=c&a[d]=e&a[f]=g"
+ assert unquote(stringify({"a": {"b": {"c": {"d": "e"}}}})) == "a[b][c][d]=e"
+ assert unquote(stringify({"a": {"b": True}})) == "a[b]=true"
+
+
+@pytest.mark.parametrize("method", ["class", "function"])
+def test_array_comma(method: str) -> None:
+ if method == "class":
+ serialise = Querystring(array_format="comma").stringify
+ else:
+ serialise = partial(stringify, array_format="comma")
+
+ assert unquote(serialise({"in": ["foo", "bar"]})) == "in=foo,bar"
+ assert unquote(serialise({"a": {"b": [True, False]}})) == "a[b]=true,false"
+ assert unquote(serialise({"a": {"b": [True, False, None, True]}})) == "a[b]=true,false,true"
+
+
+def test_array_repeat() -> None:
+ assert unquote(stringify({"in": ["foo", "bar"]})) == "in=foo&in=bar"
+ assert unquote(stringify({"a": {"b": [True, False]}})) == "a[b]=true&a[b]=false"
+ assert unquote(stringify({"a": {"b": [True, False, None, True]}})) == "a[b]=true&a[b]=false&a[b]=true"
+ assert unquote(stringify({"in": ["foo", {"b": {"c": ["d", "e"]}}]})) == "in=foo&in[b][c]=d&in[b][c]=e"
+
+
+@pytest.mark.parametrize("method", ["class", "function"])
+def test_array_brackets(method: str) -> None:
+ if method == "class":
+ serialise = Querystring(array_format="brackets").stringify
+ else:
+ serialise = partial(stringify, array_format="brackets")
+
+ assert unquote(serialise({"in": ["foo", "bar"]})) == "in[]=foo&in[]=bar"
+ assert unquote(serialise({"a": {"b": [True, False]}})) == "a[b][]=true&a[b][]=false"
+ assert unquote(serialise({"a": {"b": [True, False, None, True]}})) == "a[b][]=true&a[b][]=false&a[b][]=true"
+
+
+def test_unknown_array_format() -> None:
+ with pytest.raises(NotImplementedError, match="Unknown array_format value: foo, choose from comma, repeat"):
+ stringify({"a": ["foo", "bar"]}, array_format=cast(Any, "foo"))
diff --git a/tests/test_required_args.py b/tests/test_required_args.py
new file mode 100644
index 0000000000..1de017db24
--- /dev/null
+++ b/tests/test_required_args.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import pytest
+
+from openai._utils import required_args
+
+
+def test_too_many_positional_params() -> None:
+ @required_args(["a"])
+ def foo(a: str | None = None) -> str | None:
+ return a
+
+ with pytest.raises(TypeError, match=r"foo\(\) takes 1 argument\(s\) but 2 were given"):
+ foo("a", "b") # type: ignore
+
+
+def test_positional_param() -> None:
+ @required_args(["a"])
+ def foo(a: str | None = None) -> str | None:
+ return a
+
+ assert foo("a") == "a"
+ assert foo(None) is None
+ assert foo(a="b") == "b"
+
+ with pytest.raises(TypeError, match="Missing required argument: 'a'"):
+ foo()
+
+
+def test_keyword_only_param() -> None:
+ @required_args(["a"])
+ def foo(*, a: str | None = None) -> str | None:
+ return a
+
+ assert foo(a="a") == "a"
+ assert foo(a=None) is None
+ assert foo(a="b") == "b"
+
+ with pytest.raises(TypeError, match="Missing required argument: 'a'"):
+ foo()
+
+
+def test_multiple_params() -> None:
+ @required_args(["a", "b", "c"])
+ def foo(a: str = "", *, b: str = "", c: str = "") -> str | None:
+ return a + " " + b + " " + c
+
+ assert foo(a="a", b="b", c="c") == "a b c"
+
+ error_message = r"Missing required arguments.*"
+
+ with pytest.raises(TypeError, match=error_message):
+ foo()
+
+ with pytest.raises(TypeError, match=error_message):
+ foo(a="a")
+
+ with pytest.raises(TypeError, match=error_message):
+ foo(b="b")
+
+ with pytest.raises(TypeError, match=error_message):
+ foo(c="c")
+
+ with pytest.raises(TypeError, match=r"Missing required argument: 'a'"):
+ foo(b="a", c="c")
+
+ with pytest.raises(TypeError, match=r"Missing required argument: 'b'"):
+ foo("a", c="c")
+
+
+def test_multiple_variants() -> None:
+ @required_args(["a"], ["b"])
+ def foo(*, a: str | None = None, b: str | None = None) -> str | None:
+ return a if a is not None else b
+
+ assert foo(a="foo") == "foo"
+ assert foo(b="bar") == "bar"
+ assert foo(a=None) is None
+ assert foo(b=None) is None
+
+ # TODO: this error message could probably be improved
+ with pytest.raises(
+ TypeError,
+ match=r"Missing required arguments; Expected either \('a'\) or \('b'\) arguments to be given",
+ ):
+ foo()
+
+
+def test_multiple_params_multiple_variants() -> None:
+ @required_args(["a", "b"], ["c"])
+ def foo(*, a: str | None = None, b: str | None = None, c: str | None = None) -> str | None:
+ if a is not None:
+ return a
+ if b is not None:
+ return b
+ return c
+
+ error_message = r"Missing required arguments; Expected either \('a' and 'b'\) or \('c'\) arguments to be given"
+
+ with pytest.raises(TypeError, match=error_message):
+ foo(a="foo")
+
+ with pytest.raises(TypeError, match=error_message):
+ foo(b="bar")
+
+ with pytest.raises(TypeError, match=error_message):
+ foo()
+
+ assert foo(a=None, b="bar") == "bar"
+ assert foo(c=None) is None
+ assert foo(c="foo") == "foo"
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
new file mode 100644
index 0000000000..75e4ca2699
--- /dev/null
+++ b/tests/test_streaming.py
@@ -0,0 +1,104 @@
+from typing import Iterator, AsyncIterator
+
+import pytest
+
+from openai._streaming import SSEDecoder
+
+
+@pytest.mark.asyncio
+async def test_basic_async() -> None:
+ async def body() -> AsyncIterator[str]:
+ yield "event: completion"
+ yield 'data: {"foo":true}'
+ yield ""
+
+ async for sse in SSEDecoder().aiter(body()):
+ assert sse.event == "completion"
+ assert sse.json() == {"foo": True}
+
+
+def test_basic() -> None:
+ def body() -> Iterator[str]:
+ yield "event: completion"
+ yield 'data: {"foo":true}'
+ yield ""
+
+ it = SSEDecoder().iter(body())
+ sse = next(it)
+ assert sse.event == "completion"
+ assert sse.json() == {"foo": True}
+
+ with pytest.raises(StopIteration):
+ next(it)
+
+
+def test_data_missing_event() -> None:
+ def body() -> Iterator[str]:
+ yield 'data: {"foo":true}'
+ yield ""
+
+ it = SSEDecoder().iter(body())
+ sse = next(it)
+ assert sse.event is None
+ assert sse.json() == {"foo": True}
+
+ with pytest.raises(StopIteration):
+ next(it)
+
+
+def test_event_missing_data() -> None:
+ def body() -> Iterator[str]:
+ yield "event: ping"
+ yield ""
+
+ it = SSEDecoder().iter(body())
+ sse = next(it)
+ assert sse.event == "ping"
+ assert sse.data == ""
+
+ with pytest.raises(StopIteration):
+ next(it)
+
+
+def test_multiple_events() -> None:
+ def body() -> Iterator[str]:
+ yield "event: ping"
+ yield ""
+ yield "event: completion"
+ yield ""
+
+ it = SSEDecoder().iter(body())
+
+ sse = next(it)
+ assert sse.event == "ping"
+ assert sse.data == ""
+
+ sse = next(it)
+ assert sse.event == "completion"
+ assert sse.data == ""
+
+ with pytest.raises(StopIteration):
+ next(it)
+
+
+def test_multiple_events_with_data() -> None:
+ def body() -> Iterator[str]:
+ yield "event: ping"
+ yield 'data: {"foo":true}'
+ yield ""
+ yield "event: completion"
+ yield 'data: {"bar":false}'
+ yield ""
+
+ it = SSEDecoder().iter(body())
+
+ sse = next(it)
+ assert sse.event == "ping"
+ assert sse.json() == {"foo": True}
+
+ sse = next(it)
+ assert sse.event == "completion"
+ assert sse.json() == {"bar": False}
+
+ with pytest.raises(StopIteration):
+ next(it)
diff --git a/tests/test_transform.py b/tests/test_transform.py
new file mode 100644
index 0000000000..3fc89bb093
--- /dev/null
+++ b/tests/test_transform.py
@@ -0,0 +1,232 @@
+from __future__ import annotations
+
+from typing import Any, List, Union, Optional
+from datetime import date, datetime
+from typing_extensions import Required, Annotated, TypedDict
+
+import pytest
+
+from openai._utils import PropertyInfo, transform, parse_datetime
+from openai._models import BaseModel
+
+
+class Foo1(TypedDict):
+ foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
+
+
+def test_top_level_alias() -> None:
+ assert transform({"foo_bar": "hello"}, expected_type=Foo1) == {"fooBar": "hello"}
+
+
+class Foo2(TypedDict):
+ bar: Bar2
+
+
+class Bar2(TypedDict):
+ this_thing: Annotated[int, PropertyInfo(alias="this__thing")]
+ baz: Annotated[Baz2, PropertyInfo(alias="Baz")]
+
+
+class Baz2(TypedDict):
+ my_baz: Annotated[str, PropertyInfo(alias="myBaz")]
+
+
+def test_recursive_typeddict() -> None:
+ assert transform({"bar": {"this_thing": 1}}, Foo2) == {"bar": {"this__thing": 1}}
+ assert transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2) == {"bar": {"Baz": {"myBaz": "foo"}}}
+
+
+class Foo3(TypedDict):
+ things: List[Bar3]
+
+
+class Bar3(TypedDict):
+ my_field: Annotated[str, PropertyInfo(alias="myField")]
+
+
+def test_list_of_typeddict() -> None:
+ result = transform({"things": [{"my_field": "foo"}, {"my_field": "foo2"}]}, expected_type=Foo3)
+ assert result == {"things": [{"myField": "foo"}, {"myField": "foo2"}]}
+
+
+class Foo4(TypedDict):
+ foo: Union[Bar4, Baz4]
+
+
+class Bar4(TypedDict):
+ foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
+
+
+class Baz4(TypedDict):
+ foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
+
+
+def test_union_of_typeddict() -> None:
+ assert transform({"foo": {"foo_bar": "bar"}}, Foo4) == {"foo": {"fooBar": "bar"}}
+ assert transform({"foo": {"foo_baz": "baz"}}, Foo4) == {"foo": {"fooBaz": "baz"}}
+ assert transform({"foo": {"foo_baz": "baz", "foo_bar": "bar"}}, Foo4) == {"foo": {"fooBaz": "baz", "fooBar": "bar"}}
+
+
+class Foo5(TypedDict):
+ foo: Annotated[Union[Bar4, List[Baz4]], PropertyInfo(alias="FOO")]
+
+
+class Bar5(TypedDict):
+ foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
+
+
+class Baz5(TypedDict):
+ foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
+
+
+def test_union_of_list() -> None:
+ assert transform({"foo": {"foo_bar": "bar"}}, Foo5) == {"FOO": {"fooBar": "bar"}}
+ assert transform(
+ {
+ "foo": [
+ {"foo_baz": "baz"},
+ {"foo_baz": "baz"},
+ ]
+ },
+ Foo5,
+ ) == {"FOO": [{"fooBaz": "baz"}, {"fooBaz": "baz"}]}
+
+
+class Foo6(TypedDict):
+ bar: Annotated[str, PropertyInfo(alias="Bar")]
+
+
+def test_includes_unknown_keys() -> None:
+ assert transform({"bar": "bar", "baz_": {"FOO": 1}}, Foo6) == {
+ "Bar": "bar",
+ "baz_": {"FOO": 1},
+ }
+
+
+class Foo7(TypedDict):
+ bar: Annotated[List[Bar7], PropertyInfo(alias="bAr")]
+ foo: Bar7
+
+
+class Bar7(TypedDict):
+ foo: str
+
+
+def test_ignores_invalid_input() -> None:
+ assert transform({"bar": ""}, Foo7) == {"bAr": ""}
+ assert transform({"foo": ""}, Foo7) == {"foo": ""}
+
+
+class DatetimeDict(TypedDict, total=False):
+ foo: Annotated[datetime, PropertyInfo(format="iso8601")]
+
+ bar: Annotated[Optional[datetime], PropertyInfo(format="iso8601")]
+
+ required: Required[Annotated[Optional[datetime], PropertyInfo(format="iso8601")]]
+
+ list_: Required[Annotated[Optional[List[datetime]], PropertyInfo(format="iso8601")]]
+
+ union: Annotated[Union[int, datetime], PropertyInfo(format="iso8601")]
+
+
+class DateDict(TypedDict, total=False):
+ foo: Annotated[date, PropertyInfo(format="iso8601")]
+
+
+def test_iso8601_format() -> None:
+ dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
+ assert transform({"foo": dt}, DatetimeDict) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
+
+ dt = dt.replace(tzinfo=None)
+ assert transform({"foo": dt}, DatetimeDict) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap]
+
+ assert transform({"foo": None}, DateDict) == {"foo": None} # type: ignore[comparison-overlap]
+ assert transform({"foo": date.fromisoformat("2023-02-23")}, DateDict) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap]
+
+
+def test_optional_iso8601_format() -> None:
+ dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
+ assert transform({"bar": dt}, DatetimeDict) == {"bar": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
+
+ assert transform({"bar": None}, DatetimeDict) == {"bar": None}
+
+
+def test_required_iso8601_format() -> None:
+ dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
+ assert transform({"required": dt}, DatetimeDict) == {"required": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
+
+ assert transform({"required": None}, DatetimeDict) == {"required": None}
+
+
+def test_union_datetime() -> None:
+ dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
+ assert transform({"union": dt}, DatetimeDict) == { # type: ignore[comparison-overlap]
+ "union": "2023-02-23T14:16:36.337692+00:00"
+ }
+
+ assert transform({"union": "foo"}, DatetimeDict) == {"union": "foo"}
+
+
+def test_nested_list_iso6801_format() -> None:
+ dt1 = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
+ dt2 = parse_datetime("2022-01-15T06:34:23Z")
+ assert transform({"list_": [dt1, dt2]}, DatetimeDict) == { # type: ignore[comparison-overlap]
+ "list_": ["2023-02-23T14:16:36.337692+00:00", "2022-01-15T06:34:23+00:00"]
+ }
+
+
+def test_datetime_custom_format() -> None:
+ dt = parse_datetime("2022-01-15T06:34:23Z")
+
+ result = transform(dt, Annotated[datetime, PropertyInfo(format="custom", format_template="%H")])
+ assert result == "06" # type: ignore[comparison-overlap]
+
+
+class DateDictWithRequiredAlias(TypedDict, total=False):
+ required_prop: Required[Annotated[date, PropertyInfo(format="iso8601", alias="prop")]]
+
+
+def test_datetime_with_alias() -> None:
+ assert transform({"required_prop": None}, DateDictWithRequiredAlias) == {"prop": None} # type: ignore[comparison-overlap]
+ assert transform({"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias) == {"prop": "2023-02-23"} # type: ignore[comparison-overlap]
+
+
+class MyModel(BaseModel):
+ foo: str
+
+
+def test_pydantic_model_to_dictionary() -> None:
+ assert transform(MyModel(foo="hi!"), Any) == {"foo": "hi!"}
+ assert transform(MyModel.construct(foo="hi!"), Any) == {"foo": "hi!"}
+
+
+def test_pydantic_empty_model() -> None:
+ assert transform(MyModel.construct(), Any) == {}
+
+
+def test_pydantic_unknown_field() -> None:
+ assert transform(MyModel.construct(my_untyped_field=True), Any) == {"my_untyped_field": True}
+
+
+def test_pydantic_mismatched_types() -> None:
+ model = MyModel.construct(foo=True)
+ with pytest.warns(UserWarning):
+ params = transform(model, Any)
+ assert params == {"foo": True}
+
+
+def test_pydantic_mismatched_object_type() -> None:
+ model = MyModel.construct(foo=MyModel.construct(hello="world"))
+ with pytest.warns(UserWarning):
+ params = transform(model, Any)
+ assert params == {"foo": {"hello": "world"}}
+
+
+class ModelNestedObjects(BaseModel):
+ nested: MyModel
+
+
+def test_pydantic_nested_objects() -> None:
+ model = ModelNestedObjects.construct(nested={"foo": "stainless"})
+ assert isinstance(model.nested, MyModel)
+ assert transform(model, Any) == {"nested": {"foo": "stainless"}}
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000000..3cccab223a
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+import traceback
+from typing import Any, TypeVar, cast
+from datetime import date, datetime
+from typing_extensions import Literal, get_args, get_origin, assert_type
+
+from openai._types import NoneType
+from openai._utils import is_dict, is_list, is_list_type, is_union_type
+from openai._compat import PYDANTIC_V2, field_outer_type, get_model_fields
+from openai._models import BaseModel
+
+BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
+
+
+def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool:
+ for name, field in get_model_fields(model).items():
+ field_value = getattr(value, name)
+ if PYDANTIC_V2:
+ allow_none = False
+ else:
+ # in v1 nullability was structured differently
+ # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields
+ allow_none = getattr(field, "allow_none", False)
+
+ assert_matches_type(
+ field_outer_type(field),
+ field_value,
+ path=[*path, name],
+ allow_none=allow_none,
+ )
+
+ return True
+
+
+# Note: the `path` argument is only used to improve error messages when `--showlocals` is used
+def assert_matches_type(
+ type_: Any,
+ value: object,
+ *,
+ path: list[str],
+ allow_none: bool = False,
+) -> None:
+ if allow_none and value is None:
+ return
+
+ if type_ is None or type_ is NoneType:
+ assert value is None
+ return
+
+ origin = get_origin(type_) or type_
+
+ if is_list_type(type_):
+ return _assert_list_type(type_, value)
+
+ if origin == str:
+ assert isinstance(value, str)
+ elif origin == int:
+ assert isinstance(value, int)
+ elif origin == bool:
+ assert isinstance(value, bool)
+ elif origin == float:
+ assert isinstance(value, float)
+ elif origin == datetime:
+ assert isinstance(value, datetime)
+ elif origin == date:
+ assert isinstance(value, date)
+ elif origin == object:
+ # nothing to do here, the expected type is unknown
+ pass
+ elif origin == Literal:
+ assert value in get_args(type_)
+ elif origin == dict:
+ assert is_dict(value)
+
+ args = get_args(type_)
+ key_type = args[0]
+ items_type = args[1]
+
+ for key, item in value.items():
+ assert_matches_type(key_type, key, path=[*path, ""])
+ assert_matches_type(items_type, item, path=[*path, ""])
+ elif is_union_type(type_):
+ for i, variant in enumerate(get_args(type_)):
+ try:
+ assert_matches_type(variant, value, path=[*path, f"variant {i}"])
+ return
+ except AssertionError:
+ traceback.print_exc()
+ continue
+
+ assert False, "Did not match any variants"
+ elif issubclass(origin, BaseModel):
+ assert isinstance(value, type_)
+ assert assert_matches_model(type_, cast(Any, value), path=path)
+ else:
+ assert None, f"Unhandled field type: {type_}"
+
+
+def _assert_list_type(type_: type[object], value: object) -> None:
+ assert is_list(value)
+
+ inner_type = get_args(type_)[0]
+ for entry in value:
+ assert_type(inner_type, entry) # type: ignore