diff --git a/.devcontainer/build/devcontainer.json b/.devcontainer/build/devcontainer.json index 5debe40f27a6c..d92b7a8595881 100644 --- a/.devcontainer/build/devcontainer.json +++ b/.devcontainer/build/devcontainer.json @@ -23,7 +23,7 @@ }, "postCreateCommand": { - "post_create": "bash .devcontainer/post_create_commands.sh", + "post_create": "bash .devcontainer/post_create_commands.sh", "bashrc": "echo \"alias python=python3\" >> ~/.bashrc" }, diff --git a/.devcontainer/build_gpu/devcontainer.json b/.devcontainer/build_gpu/devcontainer.json index cbe6ee4d3ecfe..9dd6e95b2d1bf 100644 --- a/.devcontainer/build_gpu/devcontainer.json +++ b/.devcontainer/build_gpu/devcontainer.json @@ -24,7 +24,7 @@ }, "postCreateCommand": { - "post_create": "bash .devcontainer/post_create_commands.sh", + "post_create": "bash .devcontainer/post_create_commands.sh", "bashrc": "echo \"alias python=python3\" >> ~/.bashrc" }, diff --git a/.devcontainer/image/devcontainer.json b/.devcontainer/image/devcontainer.json index f5ef5400a41e9..4e54acc447657 100644 --- a/.devcontainer/image/devcontainer.json +++ b/.devcontainer/image/devcontainer.json @@ -14,20 +14,20 @@ }, "postCreateCommand": { - "post_create": "bash .devcontainer/post_create_commands.sh", + "post_create": "bash .devcontainer/post_create_commands.sh", "bashrc": "echo \"alias python=python3\" >> ~/.bashrc" }, "initializeCommand": "docker pull unifyai/ivy:latest", // Use 'forwardPorts' to make a list of ports inside the container available locally. // "forwardPorts": [], - + // Uncomment when using a ptrace-based debugger like C++, Go, and Rust // "runArgs": [ "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined" ], - + // Uncomment to use the Docker CLI from inside the container. See https://aka.ms/vscode-remote/samples/docker-from-docker. // "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind" ], - + // Uncomment to connect as a non-root user if you've added one. See https://aka.ms/vscode-remote/containers/non-root. // "remoteUser": "vscode", "features": { diff --git a/.devcontainer/image_gpu/devcontainer.json b/.devcontainer/image_gpu/devcontainer.json index a3bbfba0c34a4..c932c92507b2f 100644 --- a/.devcontainer/image_gpu/devcontainer.json +++ b/.devcontainer/image_gpu/devcontainer.json @@ -16,20 +16,20 @@ "runArgs": ["--gpus","all"], "postCreateCommand": { - "post_create": "bash .devcontainer/post_create_commands.sh", + "post_create": "bash .devcontainer/post_create_commands.sh", "bashrc": "echo \"alias python=python3\" >> ~/.bashrc" }, "initializeCommand": "docker pull unifyai/ivy:latest", - + // Use 'forwardPorts' to make a list of ports inside the container available locally. // "forwardPorts": [], - + // Uncomment when using a ptrace-based debugger like C++, Go, and Rust // "runArgs": [ "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined" ], - + // Uncomment to use the Docker CLI from inside the container. See https://aka.ms/vscode-remote/samples/docker-from-docker. // "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind" ], - + // Uncomment to connect as a non-root user if you've added one. See https://aka.ms/vscode-remote/containers/non-root. 
// "remoteUser": "vscode", "features": { diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a2175bcbfaf5f..f375d2cf90da1 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,7 +5,7 @@ ivy/utils/backend @VedPatwardhan @CatB1t ivy/utils/backend/ast_helpers.py @CatB1t # Ivy Testing -ivy_tests/test_ivy/helpers/ @CatB1t +ivy_tests/test_ivy/helpers/ @sherry30 @CatB1t ivy_tests/array_api_testing/ @aarsh2001 @hirwa-nshuti # Docs builder diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index aeae956c725f6..b1d1b60abc1ca 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,31 +1,31 @@ - -# PR Description +# PR Description - -## Related Issue +## Related Issue - Close # -## Checklist +## Checklist - [ ] Did you add a function? - [ ] Did you add the tests? - [ ] Did you follow the steps we provided? -### Socials: +### Socials: - diff --git a/.github/workflows/array-api-intelligent-tests.yml b/.github/workflows/array-api-intelligent-tests.yml index 604448e4750dd..5d18b3e64c9a3 100644 --- a/.github/workflows/array-api-intelligent-tests.yml +++ b/.github/workflows/array-api-intelligent-tests.yml @@ -30,7 +30,7 @@ jobs: SSH_DEPLOY_KEY: ${{ secrets.SSH_DEPLOY_KEY }} run: | source ./ivy/clone_mapping.sh main - pip install pydriller pymongo + pip install pydriller pymongo cp Mapping/tests.pbz2 ivy/ cd ivy python run_tests_CLI/array_api_determine_tests.py diff --git a/.github/workflows/auto-comment.yml b/.github/workflows/auto-comment.yml index c66bce4507e22..178124f651f00 100644 --- a/.github/workflows/auto-comment.yml +++ b/.github/workflows/auto-comment.yml @@ -16,7 +16,7 @@ jobs: issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, - + body: `Thanks for contributing to Ivy! 😊👏 Here are some of the important points from our Contributing Guidelines 📝: 1. Feel free to ignore the \`run_tests (1)\`, \`run_tests (2)\`, … jobs, and only look at the \`display_test_results\` job. 
👀 It contains the following two sections: diff --git a/.github/workflows/checklist_actions.yml b/.github/workflows/checklist_actions.yml index 6bfc557186be7..fcba0e24baa74 100644 --- a/.github/workflows/checklist_actions.yml +++ b/.github/workflows/checklist_actions.yml @@ -38,7 +38,7 @@ jobs: comment-id: ${{ github.event.comment.id }} body: ${{ steps.template.outputs.result }} edit-mode: replace - + frontend_pr_commented: name: Frontend PR comment if: ${{ github.event.issue.pull_request && github.event.comment.body == 'add_frontend_checklist' }} diff --git a/.github/workflows/dockerfile-push.yml b/.github/workflows/dockerfile-push.yml index 7b61ef63b6c3b..06be15ea1e427 100644 --- a/.github/workflows/dockerfile-push.yml +++ b/.github/workflows/dockerfile-push.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-20.04 steps: - - + - name: Checkout 🛎 Ivy uses: actions/checkout@v3 diff --git a/.github/workflows/intelligent-tests.yml b/.github/workflows/intelligent-tests.yml index 6c87d170816b1..41a8c4b7e7cbe 100644 --- a/.github/workflows/intelligent-tests.yml +++ b/.github/workflows/intelligent-tests.yml @@ -63,7 +63,7 @@ jobs: SSH_DEPLOY_KEY: ${{ secrets.SSH_DEPLOY_KEY }} run: | source ./ivy/clone_mapping.sh master${{ matrix.branch }} - pip install pydriller pymongo + pip install pydriller pymongo cp Mapping/tests.pbz2 ivy/ cd ivy mkdir .ivy diff --git a/.github/workflows/manual-tests-pr.yml b/.github/workflows/manual-tests-pr.yml index b818bd37e614a..1e74fbff6d274 100644 --- a/.github/workflows/manual-tests-pr.yml +++ b/.github/workflows/manual-tests-pr.yml @@ -34,7 +34,7 @@ jobs: mkdir .ivy touch .ivy/key.pem echo -n ${{ secrets.USER_API_KEY }} > .ivy/key.pem - python setup_tests.py ${{ github.event.inputs.test }} + python setup_tests.py ${{ github.event.inputs.test }} python run_tests_pr.py new_failures.txt continue-on-error: true diff --git a/.github/workflows/manual-tests.yml b/.github/workflows/manual-tests.yml index 24f2c92ee2f4f..099de5b4c4d53 100644 --- a/.github/workflows/manual-tests.yml +++ b/.github/workflows/manual-tests.yml @@ -52,7 +52,7 @@ jobs: mkdir .ivy touch .ivy/key.pem echo -n ${{ secrets.USER_API_KEY }} > .ivy/key.pem - python3 setup_tests.py ${{ github.event.inputs.test }} + python3 setup_tests.py ${{ github.event.inputs.test }} python3 run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.event.inputs.gpu }} ${{ github.run_id }} 'false' ${{ steps.jobs.outputs.html_url }} continue-on-error: true @@ -86,7 +86,7 @@ jobs: mkdir .ivy touch .ivy/key.pem echo -n ${{ secrets.USER_API_KEY }} > .ivy/key.pem - python setup_tests.py "${{ github.event.inputs.test }}" + python setup_tests.py "${{ github.event.inputs.test }}" python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} ${{ github.event.inputs.version}} 'false' ${{ github.run_id }} 'false' ${{ steps.jobs.outputs.html_url }} continue-on-error: true diff --git a/.gitmodules b/.gitmodules index 7aebe443b29c1..3f06491476bf6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,6 @@ [submodule "ivy_tests/array_api_testing/test_array_api"] path = ivy_tests/array_api_testing/test_array_api url = https://github.com/data-apis/array-api-tests.git +[submodule "docs/demos"] + path = docs/demos + url = https://github.com/unifyai/demos.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 07e0419aad0eb..19bd018094670 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ 
-1,4 +1,10 @@ repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + - id: trailing-whitespace + - id: check-toml - repo: https://github.com/psf/black rev: 23.7.0 hooks: @@ -7,7 +13,7 @@ repos: args: - "--preview" - repo: https://github.com/PyCQA/autoflake - rev: v2.2.0 + rev: v2.2.1 hooks: - id: autoflake - repo: https://github.com/pycqa/flake8 diff --git a/LICENSE b/LICENSE index 1bed16118cb75..0c0c4b97e0b30 100644 --- a/LICENSE +++ b/LICENSE @@ -174,9 +174,9 @@ Copyright 2021 The Ivy Authors. All rights reserved. defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. - - 10. The software in this directory and its subdirectories is licensed under the Apache License, - Version 2.0, except for the software contained within the ivy/compiler directory, + + 10. The software in this directory and its subdirectories is licensed under the Apache License, + Version 2.0, except for the software contained within the ivy/compiler directory, which is subject to the license set forth in the LICENSE file located within that directory. END OF TERMS AND CONDITIONS diff --git a/README.md b/README.md index be051879882f0..e4ce1c4e04438 100644 --- a/README.md +++ b/README.md @@ -322,63 +322,6 @@ The model\'s output can be visualized as follows: -Last but not least, we are also working on specific extensions totally -written in Ivy and therefore usable within any framework, covering -topics like [Mechanics](https://github.com/unifyai/mech), [Computer -Vision](https://github.com/unifyai/vision), -[Robotics](https://github.com/unifyai/robot), a [Reinforcement Learning -Gym](https://github.com/unifyai/gym), -[Memory](https://github.com/unifyai/memory) and implementation of -various [Models](https://github.com/unifyai/models) or [Builder -tools](https://github.com/unifyai/builder) with trainers, data loaders -and more! - -
As always, you can find more information about [Ivy as a framework in the @@ -437,7 +380,7 @@ expected. :sweat_smile: ``` bash git clone https://github.com/unifyai/ivy.git -cd ivy +cd ivy pip install --user -e . ``` @@ -1526,7 +1469,7 @@ device = "cuda:0" if ivy.gpu_is_available() else "cpu" # training hyperparams optimizer= ivy.Adam(1e-4) -batch_size = 64 +batch_size = 64 num_epochs = 20 num_classes = 10 @@ -1600,7 +1543,7 @@ def train(images, classes, epochs, model, device, num_classes=10, batch_size=32) f.writerows(metrics) -# assuming the dataset(images and classes) are already prepared in a folder +# assuming the dataset(images and classes) are already prepared in a folder train(images, classes, num_epochs, model, device, num_classes = num_classes, batch_size = batch_size) ``` diff --git a/automation_tools/checklists/frontend_checklist.md b/automation_tools/checklists/frontend_checklist.md index ecfe3226b0dc7..53fc5f25e99b1 100644 --- a/automation_tools/checklists/frontend_checklist.md +++ b/automation_tools/checklists/frontend_checklist.md @@ -6,7 +6,7 @@ The [Ivy Docs](https://unify.ai/docs/ivy/) represent the ground truth for the ta Please note that the contributor is not expected to understand everything in the checklist. It's mainly here for the reviewer to make sure everything has been done correctly 🙂 #### LEGEND 🗺: -- ❌ : Check item is not completed. +- ❌ : Check item is not completed. - ✅ : Check item is ready for review. - 🆘 : Stuck/Doubting implementation (PR author should add comments explaining why). - ⏩ : Check is not applicable to function (skip). @@ -24,15 +24,15 @@ Please note that the contributor is not expected to understand everything in the 2. - [ ] ❌: A ToDo comment has been added prompting to pass the frontend argument to the ivy function whose behavior is to be extended. 6. - [ ] ❌: In case a frontend function is being added: 1. - [ ] ❌: It is a composition of ivy functions. - 2. - [ ] ❌: In case the needed composition is long (using numerous ivy functions), a `Missing Function Suggestion` issue has been opened to suggest a new ivy function should be added to shorten the frontend implementation. + 2. - [ ] ❌: In case the needed composition is long (using numerous ivy functions), a `Missing Function Suggestion` issue has been opened to suggest a new ivy function should be added to shorten the frontend implementation. 3. - [ ] ❌: `@to_ivy_arrays_and_back` has been added to the function. 7. - [ ] ❌: In case a frontend method is being added: - 1. - [ ] ❌: It is composed of existing frontend functions or methods. + 1. - [ ] ❌: It is composed of existing frontend functions or methods. 2. - [ ] ❌: If a required frontend function has not yet been added, the method may be implemented as a composition of ivy functions, making sure that: - [ ] ❌: `@to_ivy_arrays_and_back` has been added to the method. - [ ] ❌: A ToDo comment has been made prompting to remove the decorator and update the implementation as soon as the missing function has been added. 8. - [ ] ❌: The function/method's test has been added (except in the alias case mentioned in <2>): - 1. - [ ] ❌: All supported arguments are being generated in `handle_frontend_test`/`handle_frontend_method` and passed to `test_frontend_function`/`test_frontend_method`. + 1. - [ ] ❌: All supported arguments are being generated in `handle_frontend_test`/`handle_frontend_method` and passed to `test_frontend_function`/`test_frontend_method`. 2. - [ ] ❌: The argument generation covers all possible supported values. 
Array sizes, dimensions, and axes adhere to the full supported set of the original function/method. 3. - [ ] ❌: The `available_dtypes` parameter passed to the helper generating the function/method's input array is set to `helpers.get_dtypes("valid")`. If there are unsupported dtypes that cause the test to fail, they should be handled by adding `@with_supported_dtypes`/`@with_unsupported_dtype` to the function/method. 9. - [ ] ❌: The PR is not introducing any test failures. diff --git a/automation_tools/checklists/reformat_checklist.md b/automation_tools/checklists/reformat_checklist.md index d3b185224f6fd..0e40e9804211a 100644 --- a/automation_tools/checklists/reformat_checklist.md +++ b/automation_tools/checklists/reformat_checklist.md @@ -4,7 +4,7 @@ The [Ivy Docs](https://unify.ai/docs/ivy/) represent the ground truth for the task descriptions and this checklist should only be used as a supplementary item to aid with the review process. #### LEGEND 🗺: -- ❌ : Check item is not completed. +- ❌ : Check item is not completed. - ✅ : Check item is ready for review. - 🆘 : Stuck/Doubting implementation (PR author should add comments explaining why). - ⏩ : Check is not applicable to function (skip). @@ -16,7 +16,7 @@ The [Ivy Docs](https://unify.ai/docs/ivy/) represent the ground truth for the ta - [ ] ❌: [ivy/functional/backends/numpy/{{ .category_name }}.py](https://github.com/unifyai/ivy/tree/main/ivy/functional/backends/numpy/{{ .category_name }}.py). - [ ] ❌: [ivy/functional/backends/tensorflow/{{ .category_name }}.py](https://github.com/unifyai/ivy/tree/main/ivy/functional/backends/tensorflow/{{ .category_name }}.py). - [ ] ❌: [ivy/functional/backends/torch/{{ .category_name }}.py](https://github.com/unifyai/ivy/tree/main/ivy/functional/backends/torch/{{ .category_name }}.py). -2. - [ ] ❌: Implement the following if they don't exist: +2. - [ ] ❌: Implement the following if they don't exist: 1. - [ ] ❌: The `ivy.Array` instance method in [ivy/data_classes/array/{{ .category_name }}.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/array/{{ .category_name }}.py). 2. - [ ] ❌: The `ivy.Array` special method in [ivy/data_classes/array/array.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/array/array.py). 3. - [ ] ❌: The `ivy.Array` reverse special method in [ivy/data_classes/array/array.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/array/array.py). @@ -25,11 +25,11 @@ The [Ivy Docs](https://unify.ai/docs/ivy/) represent the ground truth for the ta 6. - [ ] ❌: The `ivy.Container` special method in [ivy/data_classes/container/container.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/container/container.py). 7. - [ ] ❌: The `ivy.Container` reverse special method in [ivy/data_classes/container/container.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/container/container.py). 3. - [ ] ❌: Make sure that the aforementioned methods are added into the correct category-specific parent class, such as `ivy.ArrayWithElementwise`, `ivy.ContainerWithManipulation` etc. -4. - [ ] ❌: Correct all of the [Function Arguments and the type hints](https://unify.ai/docs/ivy/overview/deep_dive/function_arguments.html#function-arguments) for every function **and** its _relevant methods_, including those you did not implement yourself. -5. - [ ] ❌: Add the correct [Docstrings](https://unify.ai/docs/ivy/overview/deep_dive/docstrings.html#docstrings) to every function **and** its _relevant methods_, including those you did not implement yourself. 
The following should be added: - 1. - [ ] ❌: The function's [Array API standard](https://data-apis.org/array-api/latest/index.html) description in [ivy/functional/{{ .category_name }}.py](https://github.com/unifyai/ivy/blob/main/ivy/functional/ivy/{{ .category_name }}.py). If the function is not part of the Array API standard then a description of similar style should be added to the same file. +4. - [ ] ❌: Correct all of the [Function Arguments and the type hints](https://unify.ai/docs/ivy/overview/deep_dive/function_arguments.html#function-arguments) for every function **and** its _relevant methods_, including those you did not implement yourself. +5. - [ ] ❌: Add the correct [Docstrings](https://unify.ai/docs/ivy/overview/deep_dive/docstrings.html#docstrings) to every function **and** its _relevant methods_, including those you did not implement yourself. The following should be added: + 1. - [ ] ❌: The function's [Array API standard](https://data-apis.org/array-api/latest/index.html) description in [ivy/functional/{{ .category_name }}.py](https://github.com/unifyai/ivy/blob/main/ivy/functional/ivy/{{ .category_name }}.py). If the function is not part of the Array API standard then a description of similar style should be added to the same file. The following modifications should be made to the description: - - [ ] ❌: Remove type definitions in the `Parameters` and `Returns` sections. + - [ ] ❌: Remove type definitions in the `Parameters` and `Returns` sections. - [ ] ❌: Add `out` to the `Parameters` section if function accepts an `out` argument. - [ ] ❌: Replace `out` with `ret` in the `Returns` section. 2. - [ ] ❌: Reference to docstring for ivy.function_name ([5.a](#ref1)) for the function description **and** modified `Parameters` and `Returns` sections as described in [the docs](https://unify.ai/docs/ivy/overview/deep_dive/docstrings.html#docstrings) in: @@ -40,48 +40,48 @@ The [Ivy Docs](https://unify.ai/docs/ivy/) represent the ground truth for the ta - [ ] ❌: [ivy/container/container.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/container/container.py) if the function has a special method ( like `__function_name__` ). - [ ] ❌: [ivy/container/container.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/container/container.py) if the function has a reverse special method ( like `__rfunction_name__` ). 6. - [ ] ❌: Add thorough [Docstring Examples](https://unify.ai/docs/ivy/overview/deep_dive/docstring_examples.html#docstring-examples) for every function **and** its _relevant methods_ and ensure they pass the docstring tests. - + **Functional Examples** in [ivy/functional/{{ .category_name }}.py](https://github.com/unifyai/ivy/blob/main/ivy/functional/ivy/{{ .category_name }}.py). - 1. - [ ] ❌: Cover all possible variants for each of the arguments independently (not combinatorily). + 1. - [ ] ❌: Cover all possible variants for each of the arguments independently (not combinatorily). 2. - [ ] ❌: Vary the values and input shapes considerably between examples. 3. - [ ] ❌: Start out simple and get more complex with each example. - 4. - [ ] ❌: Show an example with: - - [ ] ❌: `out` unused. + 4. - [ ] ❌: Show an example with: + - [ ] ❌: `out` unused. - [ ] ❌: `out` used to update a new array y. - - [ ] ❌: `out` used to inplace update the input array x (if x has the same dtype and shape as the return). - 5. - [ ] ❌: If broadcasting is relevant for the function, then show examples which highlight this. 
- + - [ ] ❌: `out` used to inplace update the input array x (if x has the same dtype and shape as the return). + 5. - [ ] ❌: If broadcasting is relevant for the function, then show examples which highlight this. + **Nestable Function Examples** in [ivy/functional/{{ .category_name }}.py](https://github.com/unifyai/ivy/blob/main/ivy/functional/ivy/{{ .category_name }}.py). Only if the function supports nestable operations. - - 6. - [ ] ❌: Add an example that passes in an `ivy.Container` instance in place of one of the arguments. + + 6. - [ ] ❌: Add an example that passes in an `ivy.Container` instance in place of one of the arguments. 7. - [ ] ❌: Add an example passes in `ivy.Container` instances for multiple arguments. - + **Container Static Method Examples** in [ivy/container/{{ .category_name }}.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/container/{{ .category_name }}.py). 8. - [ ] ❌: The example from point ([6.f](#ref2)) should be replicated, but added to the `ivy.Container` **static method** docstring in with `ivy.` replaced with `ivy.Container.static_` in the example. 9. - [ ] ❌: The example from point ([6.g](#ref3)) should be replicated, but added to the `ivy.Container` **static method** docstring, with `ivy.` replaced with `ivy.Container.static_` in the example. - + **Array Instance Method Example** in [ivy/array/{{ .category_name }}.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/array/{{ .category_name }}). - + 10. - [ ] ❌: Call this instance method of the `ivy.Array` class. - + **Container Instance Method Example** in [ivy/container/{{ .category_name }}.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/container/{{ .category_name }}.py). - + 11. - [ ] ❌: Call this instance method of the `ivy.Container` class. - + **Array Operator Examples** in [ivy/array/array.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/array/array.py). - + 12. - [ ] ❌: Call the operator on two `ivy.Array` instances. 13. - [ ] ❌: Call the operator with an `ivy.Array` instance on the left and `ivy.Container` on the right. - + **Array Reverse Operator Example** in [ivy/array/array.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/array/array.py). - + 14. - [ ] ❌: Call the operator with a `Number` on the left and an `ivy.Array` instance on the right. - + **Container Operator Examples** in [ivy/container/container.py](https://github.com/unifyai/ivy/blob/main/ivy/data_classes/container/container.py). - + 15. - [ ] ❌: Call the operator on two `ivy.Container` instances containing Number instances at the leaves. 16. - [ ] ❌: Call the operator on two `ivy.Container` instances containing `ivy.Array` instances at the leaves. 17. - [ ] ❌: Call the operator with an `ivy.Container` instance on the left and `ivy.Array` on the right. diff --git a/automation_tools/dashboard_automation/update_db.py b/automation_tools/dashboard_automation/update_db.py index 42c3f19cefa84..092e313725f91 100644 --- a/automation_tools/dashboard_automation/update_db.py +++ b/automation_tools/dashboard_automation/update_db.py @@ -22,9 +22,10 @@ def make_clickable(url, name): - return ''.format(name) + return ( + f'' + ) def update_test_results(): @@ -45,7 +46,7 @@ def update_test_results(): res = make_clickable(action_url + run_id, result_config[result]) collection.update_one( {"_id": test_configs[workflow][1]}, - {"$set": {backend + "." 
+ submodule: res}}, + {"$set": {f"{backend}.{submodule}": res}}, upsert=True, ) return diff --git a/determine_test_coverage.py b/determine_test_coverage.py index e2fcbf0ab6691..588399a9ab8e5 100644 --- a/determine_test_coverage.py +++ b/determine_test_coverage.py @@ -17,8 +17,7 @@ test_names = get_all_tests() # Create a Dictionary of Test Names to Index -tests["index_mapping"] = test_names -tests["tests_mapping"] = {} +tests = {"index_mapping": test_names, "tests_mapping": {}} for i in range(len(test_names)): tests["tests_mapping"][test_names[i]] = i @@ -47,7 +46,7 @@ for directory in directories: for file_name in os.listdir(directory): if file_name.endswith("cover"): - file_name = directory + "/" + file_name + file_name = f"{directory}/{file_name}" if file_name not in tests: tests[file_name] = [] with open(file_name) as f: diff --git a/determine_tests.py b/determine_tests.py index d3a94ddc91fab..c422d2bee1e2e 100644 --- a/determine_tests.py +++ b/determine_tests.py @@ -44,15 +44,15 @@ def main(): modified_files = commit._parse_diff(diff_index) for file in modified_files: try: - file_name = file.new_path + ",cover" + file_name = f"{file.new_path},cover" except: # noqa continue if file_name not in tests.keys(): continue tests_file = tests[file_name] change = file.diff_parsed - added = set([x - 1 for (x, _) in change["added"]]) - deleted = set([x - 1 for (x, _) in change["deleted"]]) + added = {x - 1 for (x, _) in change["added"]} + deleted = {x - 1 for (x, _) in change["deleted"]} updated = added.intersection(deleted) added = added.difference(updated) deleted = deleted.difference(updated) @@ -121,9 +121,8 @@ def main(): relevant_added_tests.append(test) break added_tests = relevant_added_tests - else: - if len(added_tests) > 50: - added_tests = added_tests[:50] + elif len(added_tests) > 50: + added_tests = added_tests[:50] # Add these new_tests in the Mapping old_num_tests = len(old_tests) tests["index_mapping"] += added_tests diff --git a/docs/_templates/top_data_toc.rst b/docs/_templates/top_data_toc.rst index fd991a4641b6c..5d6c8bb764121 100644 --- a/docs/_templates/top_data_toc.rst +++ b/docs/_templates/top_data_toc.rst @@ -6,9 +6,9 @@ {% block options %}{{ super() }} :hide-table: {% endblock %} -{# - As this toc generates files a little differently, we added this to fix linking - issues +{# + As this toc generates files a little differently, we added this to fix linking + issues #} {% block custom_content %} .. autosummary:: diff --git a/docs/compiler/setting_up.rst b/docs/compiler/setting_up.rst deleted file mode 100644 index 8fb9a17bbcb2c..0000000000000 --- a/docs/compiler/setting_up.rst +++ /dev/null @@ -1,35 +0,0 @@ -Setting Up -========== - -To use Ivy's compiler and transpiler, you'll need an **API key**. We are starting to -grant pilot access to certain users, so you can `join the waitlist `_ -if you want to get one! - -Ivy Folder ----------- - -When importing Ivy for the first time, a ``.ivy`` folder will be created in your -working directory. If you want to keep this folder in a different location, -you can set an ``IVY_ROOT`` environment variable with the path of your ``.ivy`` folder. - -Setting Up the API key ----------------------- - -Once the ``.ivy`` folder has been created (either manually or automatically by -importing Ivy), you will have to paste your API key as the content of the ``key.pem`` file. -For reference, this would be equivalent to: - -.. 
code-block:: console - - echo -n API_KEY > .ivy/key.pem - -Issues and Questions --------------------- - -If you find any issue or bug while using the compiler and/or the transpiler, please -raise an `issue in GitHub `_ and add the ``compiler`` -or the ``transpiler`` label accordingly. A member of the team will get back to you ASAP! - -Otherwise, if you haven't found a bug but want to ask a question, suggest something, or get help -from the team directly, feel free to open a new post at the ``pilot-access`` forum in -`Ivy's discord server! `_ \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index dbfd99969b2bb..e88f08969b560 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,25 +3,40 @@ .. include:: ../README.md :parser: myst_parser.sphinx_ + +.. toctree:: + :hidden: + :maxdepth: -1 + + Home + + .. toctree:: :hidden: :maxdepth: -1 - :caption: Overview + :caption: The Basics overview/get_started.rst - Examples - overview/glossary.rst - overview/faq.rst + demos/quickstart.ipynb .. toctree:: :hidden: :maxdepth: -1 - :caption: Users + :caption: Demos + + demos/learn_the_basics.rst + demos/guides.rst + demos/examples_and_demos.rst - overview/background.rst + +.. toctree:: + :hidden: + :maxdepth: -1 + :caption: Background + + overview/motivation.rst overview/related_work.rst - overview/extensions.rst .. toctree:: @@ -32,22 +47,21 @@ overview/design.rst overview/contributing.rst overview/deep_dive.rst + overview/glossary.rst + overview/faq.rst .. toctree:: :hidden: :maxdepth: -1 - :caption: Compiling and Transpiling + :caption: API Reference - compiler/setting_up.rst - compiler/compiler.rst - compiler/transpiler.rst + overview/one_liners.rst .. autosummary:: :toctree: docs/functional :template: top_functional_toc.rst - :caption: API Reference :recursive: :hide-table: diff --git a/docs/overview/contributing.rst b/docs/overview/contributing.rst index c5b3b1aca787c..6e988bba87b52 100644 --- a/docs/overview/contributing.rst +++ b/docs/overview/contributing.rst @@ -15,28 +15,25 @@ We want our ML unification journey to be as inclusive as possible, this is all o The contributor guide is split into the sections below, it's best to go from start to finish, but you can also dive in at any stage! We're excited for you to get involved! 🦾 -| (a) `Setting Up `_ +| (a) `Setting Up `_ | Building the right environment 🏛️ | -| (b) :ref:`The Basics` +| (b) `The Basics `_ | Managing your fork 🇾, creating issues ⭕, and creating pull-requests ⬆️ | -| (c) :ref:`Building the Docs` +| (c) `Building the Docs `_ | How to build the documentation locally 🏗️ | -| (d) :ref:`Deep Dive` +| (d) `Deep Dive `_ | Take a deep dive into the codebase 🤿 | -| (e) :ref:`Open Tasks` +| (e) `Open Tasks `_ | See where you can help us out! 🙋 | -| (f) :ref:`Applied Libraries` -| Getting started with our applied libraries! 📚 -| -| (g) :ref:`Helpful Resources` +| (f) `Helpful Resources `_ | Resources you would find useful when learning Ivy 📖 -| -| (g) :ref:`Error Handling` +| +| (g) `Error Handling `_ | Common errors you will be facing contributing to Ivy ❌ .. 
toctree:: @@ -49,7 +46,6 @@ The contributor guide is split into the sections below, it's best to go from sta contributing/building_the_docs.rst Deep Dive contributing/open_tasks.rst - contributing/applied_libraries.rst contributing/helpful_resources.rst contributing/error_handling.rst diff --git a/docs/overview/contributing/applied_libraries.rst b/docs/overview/contributing/applied_libraries.rst deleted file mode 100644 index 11de60c24d5cb..0000000000000 --- a/docs/overview/contributing/applied_libraries.rst +++ /dev/null @@ -1,142 +0,0 @@ -Applied Libraries -================= - -.. _`Ivy Robot`: https://unify.ai/docs/robot/ -.. _`Mech`: https://unify.ai/docs/mech/ -.. _`Vision`: https://unify.ai/docs/vision/ -.. _`Demo Utils`: https://github.com/unifyai/demo-utils -.. _`Ivy`: https://github.com/unifyai/ivy -.. _`Docker Desktop`: https://www.docker.com/products/docker-desktop/ -.. _`discord`: https://discord.gg/sXyFF8tDtm -.. _`pycharm channel`: https://discord.com/channels/799879767196958751/942114831039856730 -.. _`docker channel`: https://discord.com/channels/799879767196958751/942114744691740772 -.. _`pre-commit channel`: https://discord.com/channels/799879767196958751/982725464110034944 -.. _`pip packages channel`: https://discord.com/channels/799879767196958751/942114789642080317 - -Introduction ------------- - -Helping to contribute towards the ivy libraries requires a slightly more complex setup than is needed for contributing to ivy alone. -For instance, `Ivy Robot`_ depends on `Mech`_, `Vision`_ and `Demo Utils`_. -Thus, the related repositories have to be pulled into the same local folder, and `Ivy`_ must also be pulled into this same folder. - -To have a better grasp, let's look at an example of Ivy Robot in the next section! - -Example - Ivy Robot -------------------- - -**Directory Tree** - -1. Due to dependencies, the related Ivy repositories have to be placed in the same local directory: - -.. code-block:: none - - |-- your-local-dir - | |-- ivy - | |-- mech - | |-- vision - | |-- robot - | |-- demo-utils - -2. Clone all repositories into a mutual directory: - - .. code-block:: none - - git clone https://github.com/unifyai/ivy.git - - .. code-block:: none - - git clone https://github.com/unifyai/mech.git - - .. code-block:: none - - git clone https://github.com/unifyai/vision.git - - .. code-block:: none - - git clone https://github.com/unifyai/robot.git - - .. code-block:: none - - git clone https://github.com/unifyai/demo-utils.git - -3. The next steps will depend on your type of development. - -**Local Development** - -1. Create a virtual environment (venv) in the same directory: - - .. code-block:: none - - python3 -m venv ivy_dev - -2. Activate the environment: - - (on Windows) - .. code-block:: none - - ivy_dev\Scripts\activate.bat - - (on Mac/Linux) - .. code-block:: none - - source ivy_dev/bin/activate - -3. Go into each directory and install packages in develop/editable mode: - - .. code-block:: none - - cd ivy - python3 -m pip install --user -e . - - (repeat for all repositories) - - **NOTE:** In develop mode, packages are linked to their local directory. - Therefore, changes or edits are reflected immediately when in use. - -4. To use: - - .. code-block:: none - - python3 - - .. code-block:: python - - import ivy_robot - -**Docker Development** - -1. Install `Docker Desktop`_ - -2. Go into the :code:`robot` repository and build the docker image: - - .. code-block:: none - - cd robot - docker build -t my-robot . - -3. 
To use, first mount the local directories, then start up :code:`python3` with Docker: - - (in the folder containing all repositories) - .. code-block:: none - - docker run --rm -it -v `pwd`/ivy:/ivy -v `pwd`/mech:/mech -v `pwd`/vision:/vision -v `pwd`/robot:/robot -v `pwd`/demo-utils:/demo-utils my-robot python3 - - **NOTE:** Mounting allows the docker container to use local folder as volumes, thus reflecting the local changes or edits made. - Users are not required to rebuild the docker image after every change. - -**IDE Development** - -1. For **PyCharm**, configurations are saved in the :code:`.idea` folder (part of the ivy repo). - -2. For **VSCode**, configurations can be found in the :code:`.devcontainer` folder (not part of the ivy repo). - -**NOTE:** To use the development container in VSCode, the extension "Remote - Containers" needs to be installed. - -**NOTE:** When using GitHub Codespaces, the :code:`mounts` config in :code:`.devcontainer/devcontainer.json` is not supported. - -**Round Up** - -These examples should hopefully give you a good understanding of what is required when developing the Ivy applied libraries. - -If you have any questions, please feel free to reach out on `discord`_ in the `pycharm channel`_, `docker channel`_, `pre-commit channel`_, `pip packages channel`_ or `other channel`_, depending on the question! diff --git a/docs/overview/contributing/building_the_docs.rst b/docs/overview/contributing/building_the_docs.rst index 680e21b099089..12f59b1dbd618 100644 --- a/docs/overview/contributing/building_the_docs.rst +++ b/docs/overview/contributing/building_the_docs.rst @@ -2,7 +2,8 @@ Building the Docs ================= This document describes how to build the Ivy docs. If you want to know more about how -our custom building pipeline work, check our :ref:`Building the Docs Pipeline` deep dive +our custom building pipeline work, check our `Building the Docs Pipeline +<../deep_dive/building_the_docs_pipline.rst>`_ deep dive Building the Docs using Docker ------------------------------ @@ -22,20 +23,20 @@ This script will build the docs for Ivy and store it in ``docs/build``. Using existing image on Docker Hub ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -You can also use the ``unifyai/doc-builder`` image hosted on -`Docker Hub `_ to build the -docs. This will be helpful if you want to build the docs for Ivy applied libraries. +You can also use the ``unifyai/doc-builder`` image hosted on +`Docker Hub `_ to build the +docs. -Run ``docker run`` to build the docs. The following command will build the docs for +Run ``docker run`` to build the docs. The following command will build the docs for the project in the current directory and output them to ``docs/build``. .. code-block:: bash - cd + cd docker run --rm -v $(pwd):/project unifyai/doc-builder This command will mount the module directory to ``/project`` in the container, the -current directory can be the root of ``ivy`` or any applied library such as ``mech``. +current directory should be the root of ``ivy``. Building the image locally ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -66,7 +67,7 @@ Building the Docs without Docker -------------------------------- You can also build the docs without Docker. You will first need to clone the -``unifyai/doc-builder`` repository. Then use the convenience script +``unifyai/doc-builder`` repository. Then use the convenience script ``make_docs_without_docker.sh``. Run this command if you are using HTTPS: @@ -86,8 +87,8 @@ Then, run the following command to build the docs: .. 
code-block:: bash cd doc-builder - ./make_docs_without_docker.sh + ./make_docs_without_docker.sh The script will install the required dependencies for `sphinx `_ -which is used to build the docs, as well as dependencies required by Ivy or the Ivy -applied library. Then it will build the docs for Ivy and store it in ``docs/build``. +which is used to build the docs, as well as dependencies required by Ivy. Then it will +build the docs for Ivy and store it in ``docs/build``. diff --git a/docs/overview/contributing/error_handling.rst b/docs/overview/contributing/error_handling.rst index 44f4df41fee1b..ff4bbefbf7d8d 100644 --- a/docs/overview/contributing/error_handling.rst +++ b/docs/overview/contributing/error_handling.rst @@ -10,7 +10,7 @@ Error Handling This section, "Error Handling" aims to assist you in navigating through some common errors you might encounter while working with the Ivy's Functional API. We'll go through some common errors which you might encounter while working as a contributor or a developer. -#. This is the case where we pass in a dtype to `torch` which is not actually supported by the torch's native framework itself. The function which was +#. This is the case where we pass in a dtype to `torch` which is not actually supported by the torch's native framework itself. The function which was .. code-block:: python @@ -33,17 +33,17 @@ This section, "Error Handling" aims to assist you in navigating through some com E ), E fn_name='logaddexp2', E ) - E + E E You can reproduce this example by temporarily adding @reproduce_failure('6.82.4', b'AXicY2BkAAMoBaaR2WAAAACVAAY=') as a decorator on your test case #. This is the case where the value from the ground-truth backend(tensorflow) does not match the value of the backend(jax) we are testing for this case. .. code-block:: python - + E AssertionError: the results from backend jax and ground truth framework tensorflow do not match - E 0.25830078125!=0.258544921875 - E - E + E 0.25830078125!=0.258544921875 + E + E E Falsifying example: test_acosh( E backend_fw='jax', E on_device='cpu', @@ -61,7 +61,7 @@ This section, "Error Handling" aims to assist you in navigating through some com E ), E fn_name='acosh', E ) - E + E E You can reproduce this example by temporarily adding @reproduce_failure('6.82.4', b'AXicY2BAABYQwQgiAABDAAY=') as a decorator on your test case #. 
This is a similar assertion as stated in point 2 but with torch and ground-truth tensorflow not matching but the matrices are quite different so there should be an issue in the backends rather than a numerical instability here: @@ -73,9 +73,9 @@ This section, "Error Handling" aims to assist you in navigating through some com E [1.41421356 1.41421356 1.41421356] E [1.41421356 inf 1.41421356]]!=[[1.41421356e+000 1.41421356e+000 1.41421356e+000] E [1.41421356e+000 1.41421356e+000 1.41421356e+000] - E [1.41421356e+000 1.34078079e+154 1.41421356e+000]] - E - E + E [1.41421356e+000 1.34078079e+154 1.41421356e+000]] + E + E E Falsifying example: test_abs( E backend_fw='torch', E on_device='cpu', @@ -96,7 +96,7 @@ This section, "Error Handling" aims to assist you in navigating through some com E container=[False], E ), E ) - E + E E You can reproduce this example by temporarily adding @reproduce_failure('6.82.4', b'AXicY2ZkYAIiBiBgZIAAxqHEXsAAB7jUQAAAMtEAzQ==') as a decorator on your test case diff --git a/docs/overview/contributing/open_tasks.rst b/docs/overview/contributing/open_tasks.rst index 186b3ee912cc0..081f7946d6488 100644 --- a/docs/overview/contributing/open_tasks.rst +++ b/docs/overview/contributing/open_tasks.rst @@ -7,13 +7,14 @@ Open Tasks .. _`issue description`: https://github.com/unifyai/ivy/issues/1526 .. _`reference API`: https://numpy.org/doc/stable/reference/routines.linalg.html .. _`imports`: https://github.com/unifyai/ivy/blob/38dbb607334cb32eb513630c4496ad0024f80e1c/ivy/functional/frontends/numpy/__init__.py#L27 +.. _`Deep Dive`: ../deep_dive.rst Here, we explain all tasks which are currently open for contributions from the community! This section of the docs will be updated frequently, whereby new tasks will be added and completed tasks will be removed. The tasks outlined here are generally broad high-level tasks, each of which is made up of many individual sub-tasks, distributed across task-specific `ToDo List Issues `_. -Please read about `ToDo List Issues `_ in detail before continuing. +Please read about :ref:`overview/contributing/the_basics:ToDo List Issues` in detail before continuing. All tasks should be selected and allocated as described in the ToDo List Issues section. We make no mention of task selection and allocation in the explanations below, which instead focus on the steps to complete only once a sub-task has been allocated to you. @@ -32,7 +33,7 @@ Function Formatting Currently, we have many ToDo list issues `open `_ for a general function formatting task, which is explained below. -Each function in each submodule should be updated to follow the implementation instructions given in the :ref:`Deep Dive` section. +Each function in each submodule should be updated to follow the implementation instructions given in the `Deep Dive`_ section. The updates should be applied for the: #. ivy API @@ -44,30 +45,30 @@ The updates should be applied for the: #. container operators #. container reverse operators -The :ref:`Deep Dive` is an **essential** resource for learning how each of these functions/methods should be implemented. -Before starting any contribution task, you should go through the :ref:`Deep Dive`, and familiarize yourself with the content. +The `Deep Dive`_ is an **essential** resource for learning how each of these functions/methods should be implemented. +Before starting any contribution task, you should go through the `Deep Dive`_, and familiarize yourself with the content. 
At the time of writing, many of the functions are not implemented as they should be. -You will need to make changes to the current implementations, but you do not need to address *all* sections of the :ref:`Deep Dive` in detail. +You will need to make changes to the current implementations, but you do not need to address *all* sections of the `Deep Dive`_ in detail. Specifically, you **do not** need to address the following: #. Implement the hypothesis testing for the function #. Get the tests passing for your function, if they are failing before you start -However, everything else covered in the :ref:`Deep Dive` must be addressed. +However, everything else covered in the `Deep Dive`_ must be addressed. Some common important tasks are: #. Remove all :code:`lambda` and direct bindings for the backend functions (in :code:`ivy.functional.backends`), with each function instead defined using :code:`def`. #. Implement the following if they don't exist but should do: :class:`ivy.Array` instance method, :class:`ivy.Container` instance method, :class:`ivy.Array` special method, :class:`ivy.Array` reverse special method, :class:`ivy.Container` special method, :class:`ivy.Container` reverse special method. #. Make sure that the aforementioned methods are added into the correct category-specific parent class, such as :class:`ivy.ArrayWithElementwise`, :class:`ivy.ContainerWithManipulation` etc. -#. Correct all of the :ref:`Function Arguments` and the type hints for every function **and** its *relevant methods*, including those you did not implement yourself. -#. Add the correct :ref:`Docstrings` to every function **and** its *relevant methods*, including those you did not implement yourself. -#. Add thorough :ref:`Docstring Examples` for every function **and** its *relevant methods* and ensure they pass the docstring tests. +#. Correct all of the `Function Arguments <../deep_dive/function_arguments.rst>`_ and the type hints for every function **and** its *relevant methods*, including those you did not implement yourself. +#. Add the correct `Docstrings <../deep_dive/docstrings.rst>`_ to every function **and** its *relevant methods*, including those you did not implement yourself. +#. Add thorough `Docstring Examples <../deep_dive/docstring_examples.rst>`_ for every function **and** its *relevant methods* and ensure they pass the docstring tests. Formatting checklist ~~~~~~~~~~~~~~~~~~~~ -After creating your Pull Request on github, you should then produce the checklist for the formatting task as follows: +After creating your Pull Request on github, you should then produce the checklist for the formatting task as follows: 1. Add a comment with the following format: :code:`add_reformatting_checklist_` on your PR, where ** is the name of the category that the function belongs to. An example of this is shown below. @@ -94,7 +95,7 @@ The PR assignee will then see this comment and address your issues. .. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/contributing/open_tasks/checklist_SOS.png?raw=true :width: 420 -**Notes**: +**Notes**: 1. It is important that the PR author is the one to add the checklist generating comment in order to ensure they will have access to edit and update it later. 2. The checklist items' statuses should be manually updated by the PR author. @@ -106,15 +107,15 @@ The PR assignee will then see this comment and address your issues. 
Frontend APIs ------------- -For this task, the goal will be to implement functions for each of the frontend functional APIs (see :ref:`Ivy as a Transpiler`), with frontend APIs implemented for: :code:`JAX`, :code:`NumPy`, :code:`TensorFlow` :code:`PyTorch`, :code:`Paddle`, :code:`Scipy`, :code:`MXNet` and :code:`MindSpore`. +For this task, the goal will be to implement functions for each of the frontend functional APIs (see `Ivy as a Transpiler <../design/ivy_as_a_transpiler.rst>`_), with frontend APIs implemented for: :code:`JAX`, :code:`NumPy`, :code:`TensorFlow` :code:`PyTorch`, :code:`Paddle`, :code:`Scipy`, :code:`MXNet` and :code:`MindSpore`. Currently, we have many ToDo list issues `open `_ for this task. The general workflow for this task is: -#. Find the correct location for the function by following the :ref:`Where to place a frontend function` subsection below -#. Implement the function by following the :ref:`Ivy Frontends` guide -#. Write tests for your function by following the :ref:`Ivy Frontend Tests` guide +#. Find the correct location for the function by following the :ref:`overview/contributing/open_tasks:Where to place a frontend function` subsection below +#. Implement the function by following the `Ivy Frontends <../deep_dive/ivy_frontends.rst>`_ guide +#. Write tests for your function by following the `Ivy Frontend Tests <../deep_dive/ivy_frontends_tests.rst>`_ guide #. Verify that the tests for your function are passing If you feel as though there is an ivy function :code:`ivy.` clearly missing, which would make your frontend function much simpler to implement, then you should first do the following: @@ -128,7 +129,7 @@ At some point, a member of our team will assess whether it should be added, and After this, you then have two options for how to proceed: -#. Try to implement the function as a composition of currently present ivy functions, as explained in the "Temporary Compositions" sub-section of the :ref:`Ivy Frontends` guide, and add the :code:`#ToDo` comment in the implementation as explained. +#. Try to implement the function as a composition of currently present ivy functions, as explained in the :ref:`overview/deep_dive/ivy_frontends:Short Frontend Implementations` sub-section of the `Ivy Frontends <../deep_dive/ivy_frontends.rst>`_ guide, and add the :code:`#ToDo` comment in the implementation as explained. Once the PR is merged, your sub-task issue will then be closed as normal. #. Alternatively, if you do not want to try and implement the frontend function compositionally, or if this is not feasible, then you can simply choose another frontend function to work on. You could also choose to work on another open task entirely at this point if you wanted to. @@ -217,7 +218,7 @@ However, you can still use the checklist as a reference in cases where you do un **Notes**: -1. More details on how to update the checklist items can be found in the :ref:`Formatting checklist` part of our docs. +1. More details on how to update the checklist items can be found in the :ref:`overview/contributing/open_tasks:Formatting checklist` part of our docs. 2. Do not edit the checklist text, only the emoji symbols. 3. Please refrain from using the checkboxes next to checklist items. @@ -233,29 +234,29 @@ There is only one central ToDo list `issue `_ -#. Every function will have a different file structure according to the function type, refer to :ref:`Where to place a backend function` subsection below. -#. 
Implement the container instance method in :mod:`ivy/container/experimental/[relevant_submodule].py` and the array instance method +#. Analyze the function type, we have a very detailed section for it in the deep dive `Function Types Guide <../deep_dive/function_types.rst>`_ +#. Every function will have a different file structure according to the function type, refer to :ref:`overview/contributing/open_tasks:Where to place a backend function` subsection below. +#. Implement the container instance method in :mod:`ivy/container/experimental/[relevant_submodule].py` and the array instance method in :mod:`ivy/array/experimental/[relevant_submodule].py` -#. Write tests for the function using the :ref:`Ivy Tests` guide, and make sure they are passing. +#. Write tests for the function using the `Ivy Tests <../deep_dive/ivy_tests.rst>`_ guide, and make sure they are passing. A few points to keep in mind while doing this: #. Make sure all the positional arguments are positional-only and optional arguments are keyword-only. #. In case some tests require function-specific parameters, you can create composite hypothesis strategies using the :code:`draw` function in the hypothesis library. -If you’re stuck on a function which requires complex compositions, feel free to reselect a function +If you’re stuck on a function which requires complex compositions, feel free to reselect a function Extending the Ivy API ~~~~~~~~~~~~~~~~~~~~~~~ -We primarily invite contributors to work on the tasks listed as :ref:`Open Tasks`, as these are on our current roadmap. As a result of this, we prompt everyone interested in contributing to our Experimental API to do so under the `Ivy Experimental API Open Task`_. +We primarily invite contributors to work on the tasks listed as :ref:`overview/contributing/open_tasks:Open Tasks`, as these are on our current roadmap. As a result of this, we prompt everyone interested in contributing to our Experimental API to do so under the :ref:`Ivy Experimental API Open Task `. -However, if you would like to extend Ivy's functionality with a new function, you are invited to open an issue using the *Missing Function Suggestion* template as described in `Creating an Issue on Ivy’s GitHub using a Template `_. +However, if you would like to extend Ivy's functionality with a new function, you are invited to open an issue using the *Missing Function Suggestion* template as described in :ref:`overview/contributing/open_tasks:Creating an Issue on Ivy's GitHub using a Template`. In this template form, you'll be asked to fill in the reason you think we should implement the suggested function, as well as the links to any native implementations of the suggested function. -We will review your issue as soon as possible and let you know if it's been accepted or not. In case we deem that the suggested function fits our roadmap, we will add it as a subtask to the `Ivy Experimental API Open Task`_. +We will review your issue as soon as possible and let you know if it's been accepted or not. In case we deem that the suggested function fits our roadmap, we will add it as a subtask to the `Ivy Experimental API Open Task `_. 
Where to place a backend function ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -267,15 +268,15 @@ There are multiple types of backend functions as discussed above, we will go thr **Primary Functions** Implement the function in :mod:`ivy/functional/ivy/experimental/[relevant_submodule].py` simply deferring to their backend-specific implementation -(where ivy.current_backend(x).function_name() is called), refer to the `Ivy API Guide `_ +(where ivy.current_backend(x).function_name() is called), refer to the :ref:`Ivy API Guide ` to get a clearer picture of how this must be done. Then, implement the functions in each of the backend files :mod:`ivy/functional/backends/backend_name/experimental/[relevant_submodule].py`, -you can refer to the `Backend API Guide `_ for this. +you can refer to the :ref:`Backend API Guide ` for this. **Compositional Functions** Implement the function in :mod:`ivy/functional/ivy/experimental/[relevant_submodule].py`, we will not use the primary function approach in this case, the implementation will be a composition of functions from Ivy's functional API. You can refer to -`Compositional Functions Guide `_ for a better understanding of this. +:ref:`overview/deep_dive/function_types:Compositional Functions` for a better understanding of this. You don't need to add any implementation in any other file in this case. **Mixed Functions** @@ -287,8 +288,8 @@ will be a composition of functions from Ivy's functional API. After you are done **Other Function Types** -`Standalone Functions `_, `Nestable Functions `_ and -`Convenience Functions `_ are the ones which you will rarely come across +:ref:`overview/deep_dive/function_types:Standalone Functions`, :ref:`overview/deep_dive/function_types:Nestable Functions` and +:ref:`overview/deep_dive/function_types:Convenience Functions` are the ones which you will rarely come across while implementing a function from the ToDo List but they are an essential part of the Ivy API. diff --git a/docs/overview/contributing/setting_up.rst b/docs/overview/contributing/setting_up.rst index 6987e5e046ee1..188f1d04d095b 100644 --- a/docs/overview/contributing/setting_up.rst +++ b/docs/overview/contributing/setting_up.rst @@ -107,7 +107,7 @@ Using miniconda #. Create the environment by running the command (:code:`ivy_dev` is the name of the environment) .. code-block:: none - + conda create --name ivy_dev python=3.10.0 #. Activate the environment by: @@ -129,27 +129,27 @@ Using miniconda a. Going to settings -> project -> Python Interpreter b. Clicking add interpreter (currently by clicking the ⚙ icon on the right side) which should open a new window. - + c. Choosing "conda environment" from the left panel. Choose the existing environment and select the drop down and you should find the path python in the environment. #. VSCode a. Go to the command palette (Ctrl+Shift+P) or (⌘+shift+p) for Mac and type "Python: Select Interpreter" and select the environment you created. - + If you don't find a path to your created python environment, you can run :code:`where python` in the conda command line while the environment is activate and it should give the path which can be added manually. #. Installing the development dependencies. a. On Linux, Windows, or Intel Mac, you will need to use the `optional.txt` requirements file. To install dependencies. - + .. code-block:: none - + pip install -r requirements/optional.txt - + b. On M1 Mac, you will need to use the optional_apple_silicon_1 and optional_apple_silicon_2 requirements files. To install dependencies. 
- + .. code-block:: none - + pip install -r requirements/optional_apple_silicon_1.txt pip install -r requirements/optional_apple_silicon_2.txt @@ -208,26 +208,26 @@ This is a builtin package and doesn't require explicit installation. a. Go to the command palette (Ctrl+Shift+P) or (⌘+shift+p) for Mac and type `Python: Select Interpreter` and select the environment you created. #. Installing the development dependencies. - + a. On Linux, Windows, or Intel Mac, you will need to use the `optional.txt` requirements file. To install the dependencies: - + .. code-block:: none - + pip install -r requirements/optional.txt - Note: In case you are using Ubuntu 22.04, PaddlePaddle won't install properly. You have to download it from the source. - + Note: In case you are using Ubuntu 22.04, PaddlePaddle won't install properly; you have to download it from source. + .. code-block:: none - + wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb sudo dpkg -i libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb - + PS: If the link expires at some point in the future, check http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/?C=M;O=D for a valid one. b. On M1 Mac, you will need to use the optional_apple_silicon_1 and optional_apple_silicon_2 requirements files. To install the dependencies: - + .. code-block:: none - + pip install -r requirements/optional_apple_silicon_1.txt pip install -r requirements/optional_apple_silicon_2.txt @@ -423,14 +423,14 @@ Ubuntu **Docker Connection not Successful** This is a common error which you might face. If you are not able to connect docker with PyCharm (point 4a) and your docker is also running, the issue is that you are not able to use your docker socket, so executing the two commands below should solve this. - + .. code-block:: none - + sudo chmod a+rwx /var/run/docker.sock - + .. code-block:: none - - sudo chmod a+rwx /var/run/docker.pid + + sudo chmod a+rwx /var/run/docker.pid For questions, please reach out on `discord`_ in the `docker channel`_! @@ -601,14 +601,14 @@ For windows users, the file path should be entered with "/" (forward-slashes), f WSL *** - + It is understandable that working with computationally heavy tools like Docker and PyCharm is not always comfortable for developers. -By utilizing WSL, you can run a Linux distribution on your Windows machine, and in addition, venv is leveraged to create -isolated Python environments eliminating the need for a full-fledged containerization solution like Docker, and with VSCode being an appropriate alternative to PyCharm, +By utilizing WSL, you can run a Linux distribution on your Windows machine, while venv is leveraged to create +isolated Python environments, eliminating the need for a full-fledged containerization solution like Docker. With VSCode being an appropriate alternative to PyCharm, the steps explained below will help you in setting up a less resource-intensive Ivy environment. #. Install `WSL `_. -#. Install `Visual Studio Code `_. +#. Install `Visual Studio Code `_. You can follow `this guide `_ to integrate WSL into VSCode. #. Open the WSL terminal by typing in the name of your Linux distribution in the Windows start menu (e.g. :code:`Ubuntu`). #. Create a virtual environment by following the steps below: @@ -638,7 +638,7 @@ the steps explained below will help you in setting up a less resource-intensive pip install git+https://github.com/unifyai/ivy.git -#.
If you want to set up a local repository, you can do so by following `this guide `_ +#. If you want to set up a local repository, you can do so by following :ref:`this guide ` as explained above and install the required development dependencies by running: .. code-block:: none @@ -650,12 +650,12 @@ the steps explained below will help you in setting up a less resource-intensive pip install -r requirements/requirements.txt #. Once done, you can open VSCode right from your terminal and get started with your development by just running: - + .. code-block:: none code . -#. To set up the Python Interpreter in VSCode, go to the command palette (Ctrl+Shift+P) and type **Python: Select Interpreter** and select the environment you created. +#. To set up the Python Interpreter in VSCode, go to the command palette (Ctrl+Shift+P) and type **Python: Select Interpreter** and select the environment you created. For a more detailed explanation, you can follow `this guide `_. #. Now that your development environment is set up, you can run tests locally by running :code:`pytest test_file_path::test_fn_name` in the terminal or, if you want to set up testing in VSCode, you may follow the guide **Setting Up Testing** for VSCode as explained below, next to this subsection. @@ -700,14 +700,14 @@ Just follow the steps outlined below: :width: 420 3. Then you will head to the dropdown of "Dev container configuration", then select an image to set up with. There are six options available as of now: - + - :code:`Default project configuration` - This is the default option; it will set up with the default codespaces environment. - :code:`Ivy Development Environment (build)` - This will set up the development environment of ivy for CPU and build image from :code:`ivy/docker/Dockerfile`. - :code:`Ivy GPU Development Environment (build)` - This will set up the development environment of ivy for GPU and build image from :code:`ivy/docker/DockerfileGPU`. - :code:`Ivy Development Environment for Multiver...` - This will set up the development environment of multiversion support with ivy and build image from :code:`ivy/docker/DockerfileMultiversion`. - :code:`Ivy Development Environment (image)` - This will set up the development environment of ivy for CPU and build image from the latest image from dockerhub. - :code:`Ivy GPU Development Environment (image)` - This will set up the development environment of ivy for GPU and build image from the latest image from dockerhub. - + For now, we will select :code:`Ivy Development Environment (image)`. Select your region and preferred machine type, then click on "Create Codespace". @@ -745,7 +745,7 @@ The configuration files install all the required packages, and extensions for yo If you want to set up a GPU instance on codespaces and also have access to it, kindly follow the guidelines below: -1. Points 1 and 2 are the same from ref:`Setting up Codespaces` section above. You will be on a screen shown below. Just select the Machine Type to be "6-Core (1 GPU)". +1. Points 1 and 2 are the same as in the :ref:`Setting up Codespaces` section above. You will be on a screen shown below. Just select the Machine Type to be "6-Core (1 GPU)". ..
image:: https://raw.githubusercontent.com/unifyai/unifyai.github.io/main/img/externally_linked/contributing/setting_up/github_codespaces/Selecting_the_GPU.png?raw=true :width: 420 diff --git a/docs/overview/contributing/the_basics.rst b/docs/overview/contributing/the_basics.rst index 2965789636ba0..96de0bfc02e9c 100644 --- a/docs/overview/contributing/the_basics.rst +++ b/docs/overview/contributing/the_basics.rst @@ -10,7 +10,6 @@ The Basics .. _`commit frequency channel`: https://discord.com/channels/799879767196958751/982728822317256712 .. _`PyCharm blog`: https://www.jetbrains.com/help/pycharm/finding-and-replacing-text-in-file.html .. _`Debugging`: https://www.jetbrains.com/help/pycharm/debugging-code.html -.. _`Ivy Experimental API Open Task`: https://unify.ai/docs/ivy/overview/contributing/open_tasks.html#ivy-experimental-api Getting Help ------------ @@ -54,7 +53,7 @@ We make extensive use of `ToDo list issues `_, `frontend APIs `_ and `ivy experimental API `_. +a. Find a task to work on which (i) is not marked as completed with a tick, (ii) does not have an issue created, and (iii) is not mentioned in the comments. Currently, there are three open tasks: :ref:`overview/contributing/open_tasks:Function Formatting`, :ref:`overview/contributing/open_tasks:Frontend APIs` and :ref:`overview/contributing/open_tasks:Ivy Experimental API`. b. Create a new issue with the title being just the name of the sub-task you would like to work on. @@ -71,7 +70,7 @@ d. Start working on the task, and open a PR as soon as you have a full or partia :code:`Close #Issue_number` - This is important, so that the merging of your PR will automatically close the associated issue. Make sure this is in the + This is important, so that the merging of your PR will automatically close the associated issue. Make sure this is in the description of the PR; otherwise, it might not link correctly. If you have a partial solution, the Ivy team can help to guide you through the process of getting it working 🙂 Also, remember to give the PR a descriptive name, and if there are details that support your changes, add them to the description of the PR. @@ -340,16 +339,16 @@ With Docker #. With PyCharm (With or without docker): 1. PyCharm enables users to run pytest using the green button present near every function declaration inside the :code:`ivy_tests` folder. - + .. image:: https://raw.githubusercontent.com/unifyai/unifyai.github.io/main/img/externally_linked/contributing/the_basics/pytest_with_pycharm/pytest_button_pycharm.png?raw=true :width: 420 - + 2. Testing can be done for the entire project, individual submodules, individual files, and individual tests. This can be done by selecting the appropriate configuration from the top pane in PyCharm. - + .. image:: https://raw.githubusercontent.com/unifyai/unifyai.github.io/main/img/externally_linked/contributing/the_basics/pytest_with_pycharm/pytest_with_pycharm.png?raw=true :width: 420 - + #. Through the command line (With docker): 1. We need to replace the folder inside the container with the current local ivy directory to run tests on the current local code. .. code-block:: none docker exec rm -rf ivy - docker cp ivy :/ + docker cp ivy :/ 2. We then need to enter the docker container and change into the :code:`ivy` directory using the following command. .. code-block:: none - docker exec -it ivy_container bash + docker exec -it ivy_container bash cd ivy 3. Run the test using the pytest command. 1. Ivy Tests: - 1.
For a single function: + 1. For a single function: .. code-block:: none - + pytest ivy_tests/test_ivy/test_functional/test_core/test_image.py::test_random_crop --no-header --no-summary -q - + 2. For a single file: .. code-block:: none - + pytest ivy_tests/test_ivy/test_functional/test_core/test_image.py --no-header --no-summary -q 3. For all tests: @@ -390,28 +389,28 @@ With Docker 2. Array API Tests: - 1. For a single function: + 1. For a single function: .. code-block:: none - + pytest ivy_tests/array_api_testing/test_array_api/array_api_tests/test_creation_functions.py::test_arange --no-header --no-summary -q - + 2. For a single file: .. code-block:: none - + pytest ivy_tests/array_api_testing/test_array_api/array_api_tests/test_creation_functions.py --no-header --no-summary -q - + 3. For all tests: .. code-block:: none pytest ivy_tests/array_api_testing/test_array_api/ --no-header --no-summary -q - + 3. For the entire project: .. code-block:: none - + pytest ivy_tests/ --no-header --no-summary -q #. Through the command line (Without docker): @@ -435,16 +434,16 @@ With Docker 1. Ivy Tests: - 1. For a single function: + 1. For a single function: .. code-block:: none - + python -m pytest ivy_tests/test_ivy/test_functional/test_core/test_image.py::test_random_crop --no-header --no-summary -q - + 2. For a single file: .. code-block:: none - + python -m pytest ivy_tests/test_ivy/test_functional/test_core/test_image.py --no-header --no-summary -q 3. For all tests: @@ -453,34 +452,34 @@ With Docker python -m pytest ivy_tests/test_ivy/ --no-header --no-summary -q - 2. Array API Tests + 2. Array API Tests - 1. For a single function: + 1. For a single function: .. code-block:: none - + python -m pytest ivy_tests/array_api_testing/test_array_api/array_api_tests/test_creation_functions.py::test_arange --no-header --no-summary -q - + 2. For a single file: .. code-block:: none - + python -m pytest ivy_tests/array_api_testing/test_array_api/array_api_tests/test_creation_functions.py --no-header --no-summary -q - + 3. For all tests: .. code-block:: none python -m pytest ivy_tests/array_api_testing/test_array_api/ --no-header --no-summary -q - + 3. For the entire project .. code-block:: none - + python -m pytest ivy_tests/ --no-header --no-summary -q #. Optional Flags: Various optional flags are available for running the tests, such as :code:`device`, :code:`backend`, etc. - 1. :code:`device`: + 1. :code:`device`: 1. This flag enables the setting of the device where the tests would be run. 2. Possible values are :code:`cpu` and :code:`gpu`. 3. The default value is :code:`cpu`. @@ -542,7 +541,7 @@ with PyCharm :align: center 3. Stepping through the code: - 1. Step over: + 1. Step over: Steps over the current line of code and takes you to the next line even if the highlighted line has method calls in it. 1. Click the Step Over button or press :code:`F8` @@ -568,7 +567,7 @@ with PyCharm 2. Click the desired method. - 4. Python Console: + 4. Python Console: 1. Click the Console option on the Debug Tool Window: This stores the variables and their values up to the point where the code has been executed. You can print outputs and debug the code further on.
_`pull-requests`: https://github.com/unifyai/ivy/pulls -For general users of the framework, who are mainly concerned with learning how to *use* Ivy, then the :ref:`Design` section is the best place to start 🙂 +For general users of the framework, who are mainly concerned with learning how to *use* Ivy, the `Design `_ section is the best place to start 🙂 This *deep dive* section is more targeted at people who would like to dive deeper into how Ivy actually works under the hood 🔧 @@ -13,74 +13,79 @@ Going through the sections outlined below will get you right into the weeds of t It's best to go through the sub-sections from start to finish, but you can also dive in at any stage! We're excited for you to get involved! 🦾 -| (a) :ref:`Navigating the Code` 🧭 +| (a) `Navigating the Code `_ 🧭 | A quick tour through the codebase | -| (b) :ref:`Function Types` 🧮 +| (b) `Function Types `_ 🧮 | Primary, compositional, mixed, and nestable functions | -| (c) :ref:`Superset Behaviour` ⊃ +| (c) `Superset Behaviour `_ ⊃ | Ivy goes for the superset when unifying the backend functions | -| (d) :ref:`Backend Setting` ⚙ +| (d) `Backend Setting `_ ⚙ | How the backend is set, and what this means for each function type️ | -| (e) :ref:`Arrays` 🔢 +| (e) `Arrays `_ 🔢 | Different types of arrays, and how they're handled | -| (f) :ref:`Containers` 🗂 +| (f) `Containers `_ 🗂 | What the :class:`ivy.Container` does | -| (g) :ref:`Data Types` 💾 +| (g) `Data Types `_ 💾 | How functions infer the correct data type | -| (h) :ref:`Devices` 📱 +| (h) `Devices `_ 📱 | How functions infer the correct device | -| (i) :ref:`Inplace Updates` 🎯 +| (i) `Inplace Updates `_ 🎯 | How the :code:`out` argument is used to specify the output target | -| (j) :ref:`Function Wrapping` 🎁 +| (j) `Function Wrapping `_ 🎁 | How functions are dynamically wrapped at runtime | -| (k) :ref:`Formatting` 📋 +| (k) `Formatting `_ 📋 | How the code is automatically formatted | -| (l) :ref:`Function Arguments` 📑 +| (l) `Ivy Lint `_ 🧹 +| Ivy's custom code formatters +| +| (m) `Function Arguments `_ 📑 | How to add the correct function arguments | -| (m) :ref:`Docstrings` 📄 +| (n) `Docstrings `_ 📄 | How to properly write docstrings | -| (n) :ref:`Docstring Examples` 💯 +| (o) `Docstring Examples `_ 💯 | How to add useful examples to the docstrings | -| (o) :ref:`Array API Tests` 🤝 +| (p) `Array API Tests `_ 🤝 | How we're borrowing the test suite from the Array API Standard | -| (p) :ref:`Ivy Tests` 🧪 +| (q) `Ivy Tests `_ 🧪 | How to add new tests for each Ivy function | -| (q) :ref:`Ivy Frontends` ➡ +| (r) `Ivy Frontends `_ ➡ | How to implement frontend functions | -| (r) :ref:`Ivy Frontend Tests` 🧪 +| (s) `Ivy Frontend Tests `_ 🧪 | How to add new tests for each frontend function | -| (s) :ref:`Exception Handling` ⚠ +| (t) `Exception Handling `_ ⚠ | How to handle exceptions and assertions in a function | -| (t) :ref:`Continuous Integration` 🔁 +| (u) `Continuous Integration `_ 🔁 | Ivy Tests running on the Repository | -| (u) :ref:`Gradients` 🔁 +| (v) `Gradients `_ 🔁 | Everything about our Gradients API | -| (v) :ref:`Operating Modes` 🧮 +| (w) `Operating Modes `_ 🧮 | Everything about modes Ivy can operate in, along with their purposes | -| (w) :ref:`Building the Docs Pipeline` 📚 +| (x) `Building the Docs Pipeline `_ 📚 | How we are building our docs + + .. toctree:: :hidden: :maxdepth: -1 @@ -97,6 +102,7 @@ We're excited for you to get involved!
🦾 deep_dive/inplace_updates.rst deep_dive/function_wrapping.rst deep_dive/formatting.rst + deep_dive/ivy_lint.rst deep_dive/function_arguments.rst deep_dive/docstrings.rst deep_dive/docstring_examples.rst @@ -108,4 +114,4 @@ We're excited for you to get involved! 🦾 deep_dive/continuous_integration.rst deep_dive/gradients.rst deep_dive/operating_modes.rst - deep_dive/building_the_docs_pipline.rst + deep_dive/building_the_docs_pipeline.rst diff --git a/docs/overview/deep_dive/array_api_tests.rst b/docs/overview/deep_dive/array_api_tests.rst index 6bc203297d22b..e9225fedc5c83 100644 --- a/docs/overview/deep_dive/array_api_tests.rst +++ b/docs/overview/deep_dive/array_api_tests.rst @@ -12,12 +12,10 @@ Array API Tests .. _`array-api test repository`: https://github.com/data-apis/array-api/tree/main .. _`issue`: https://github.com/numpy/numpy/issues/21213 .. _`ivy_tests/array_api_testing/test_array_api/array_api_tests/test_special_cases.py`: https://github.com/data-apis/array-api-tests/blob/ddd3b7a278cd0c0b68c0e4666b2c9f4e67b7b284/array_api_tests/test_special_cases.py -.. _`here`: https://unify.ai/docs/ivy/overview/contributing/the_basics.html#running-tests-locally .. _`git website`: https://www.git-scm.com/book/en/v2/Git-Tools-Submodules .. _`hypothesis`: https://hypothesis.readthedocs.io/en/latest/ -.. _`ivy tests`: https://unify.ai/docs/ivy/overview/deep_dive/ivy_tests.html -.. _`final section`: https://unify.ai/docs/ivy/overview/deep_dive/ivy_tests.html#re-running-failed-ivy-tests -.. _`CI Pipeline`: https://unify.ai/docs/ivy/overview/deep_dive/continuous_integration.html +.. _`ivy tests`: ivy_tests.rst +.. _`CI Pipeline`: continuous_integration.html In conjunction with our own ivy unit tests, we import the array-api `test suite`_. These tests check that all ivy backend libraries behave according to the `Array API Standard`_, which was established in May 2020 by a group of maintainers. @@ -99,7 +97,7 @@ Using the IDE You can also run a specific test or test file by using your IDE. To make this work, you should set the backend explicitly in the `_array_module.py` file as explained in the previous subsection. After that, you can run the API test files as you typically would with other tests. -See `here`_ for instructions on how to run tests in ivy more generally. +See :ref:`here ` for instructions on how to run tests in ivy more generally. *NB*: make sure not to add any changes to the array-api files to your commit. @@ -107,7 +105,7 @@ Regenerating Test Failures -------------------------- Array-API tests are written using `hypothesis`_ to perform property-based testing, just like the `ivy tests`_. However, unlike the ivy tests, the Array-API tests make liberal use of :code:`data.draw` in the main body of the test function instead of generating the data in the :code:`@given` decorator that wraps it. -This means that failed tests cannot be re-run with the :code:`@example` decorator, as explained in the `final section`_ of the ivy tests deep dive. +This means that failed tests cannot be re-run with the :code:`@example` decorator, as explained in the :ref:`final section ` of the ivy tests deep dive. Fortunately, it is possible to regenerate test failures using a unique decorator that appears in the final line of the falsifying example in the error stack trace: ..
code-block:: none diff --git a/docs/overview/deep_dive/arrays.rst b/docs/overview/deep_dive/arrays.rst index 7029354e696a5..02dccc5352c56 100644 --- a/docs/overview/deep_dive/arrays.rst +++ b/docs/overview/deep_dive/arrays.rst @@ -85,7 +85,7 @@ Therefore, most functions in Ivy must adopt the following pipeline: #. call the backend-specific function, passing in these :class:`ivy.NativeArray` instances #. convert all of the :class:`ivy.NativeArray` instances which are returned from the backend function back into :class:`ivy.Array` instances, and return -Given the repeating nature of these steps, this is all entirely handled in the `inputs_to_native_arrays`_ and `outputs_to_ivy_arrays`_ wrappers, as explained in the :ref:`Function Wrapping` section. +Given the repeating nature of these steps, this is all entirely handled in the `inputs_to_native_arrays`_ and `outputs_to_ivy_arrays`_ wrappers, as explained in the `Function Wrapping `_ section. All Ivy functions *also* accept :class:`ivy.NativeArray` instances in the input. This is for a couple of reasons. @@ -93,11 +93,11 @@ Firstly, :class:`ivy.Array` instances must be converted to :class:`ivy.NativeArr Secondly, this makes it easier to combine backend-specific code with Ivy code, without needing to explicitly wrap any arrays before calling sections of Ivy code. Therefore, all input arrays to Ivy functions have type :code:`Union[ivy.Array, ivy.NativeArray]`, whereas the output arrays have type :class:`ivy.Array`. -This is further explained in the :ref:`Function Arguments` section. +This is further explained in the `Function Arguments `_ section. However, :class:`ivy.NativeArray` instances are not permitted for the :code:`out` argument, which is used in most functions. This is because the :code:`out` argument dictates the array to which the result should be written, and so it effectively serves the same purpose as the function return. -This is further explained in the :ref:`Inplace Updates` section. +This is further explained in the `Inplace Updates `_ section. As a final point, extra attention is required for *compositional* functions, as these do not directly defer to a backend implementation. If the first line of code in a compositional function performs operations on the input array, then this will call the special methods on an :class:`ivy.NativeArray` and not on an :class:`ivy.Array`. @@ -113,7 +113,7 @@ Ivy's functional API and its functions can easily be integrated with non-Ivy cla To make use of that feature, the class must contain an implementation for these functions and it must contain an implementation for the function :code:`__ivy_array_function__`. If a non-Ivy class is passed to an Ivy function, a call to this class's :code:`__ivy_array_function__` is made, which directs Ivy's function to handle that input type correctly. This allows users to define custom implementations for any of the functions that can be found in Ivy's functional API, which further makes it easy to integrate those classes with other Ivy projects. **Note** -This functionality is inspired by `NumPy's`_ :code:`__ivy_array_function__` and `PyTorch's`_ :code:`__torch_function__`. +This functionality is inspired by `NumPy's`_ :code:`__array_function__` and `PyTorch's`_ :code:`__torch_function__`. As an example, consider the following class :code:`MyArray` with the following definition: @@ -138,12 +138,12 @@ There are different ways to do so.
One way is to use a global dict :code:`HANDLED_FUNCTIONS`, checked inside :code:`__ivy_array_function__`: if func not in HANDLED_FUNCTIONS: return NotImplemented if not all(issubclass(t, (MyArray, ivy.Array, ivy.NativeArray)) for t in types): return NotImplemented - return HANDLED_FUNCTIONS[func](*args, **kwargs) + return HANDLED_FUNCTIONS[func](*args, **kwargs) -:code:`__ivy_array_function__` accepts four parameters: :code:`func` representing a reference to the array API function being +:code:`__ivy_array_function__` accepts four parameters: :code:`func` representing a reference to the array API function being overridden, :code:`types` a list of the types of objects implementing :code:`__ivy_array_function__`, :code:`args` a tuple of arguments supplied to the function, and :code:`kwargs` being a dictionary of keyword arguments passed to the function. While this class contains an implementation for :code:`__ivy_array_function__`, it is still not enough; it is also necessary to implement any needed Ivy functions with the new :code:`MyArray` class as input(s) for the code to run successfully. -We will define a decorator function :code:`implements` that can be used to add functions to :code:`HANDLED_FUNCTIONS`: +We will define a decorator function :code:`implements` that can be used to add functions to :code:`HANDLED_FUNCTIONS`: .. code-block:: python @@ -151,7 +151,7 @@ We will define a decorator function :code:`implements` that can be used to add f def decorator(func): HANDLED_FUNCTIONS[ivy_function] = func return func - return decorator + return decorator Lastly, we need to apply that decorator to the override function. Let’s consider for example a function that overrides :code:`ivy.abs`: @@ -168,7 +168,7 @@ Now that we have added the function to :code:`HANDLED_FUNCTIONS`, we can now use X = MyArray(-3) X = ivy.abs(X) -Of course :code:`ivy.abs` is an example of a function that is easy to override since it only requires one operand. The same approach can be used to override functions with multiple operands, including arrays or array-like objects that define :code:`__ivy_array_function__`. +Of course, :code:`ivy.abs` is an example of a function that is easy to override, since it only requires one operand. The same approach can be used to override functions with multiple operands, including arrays or array-like objects that define :code:`__ivy_array_function__`. Note again that any function not stored inside the dict :code:`HANDLED_FUNCTIONS` will not work, and the operands passed to the function must match those of the function stored in the dict. For instance, :code:`my_abs` takes only one parameter, which is a :code:`MyArray` object, so passing any other operands to the function will result in an :code:`IvyBackendException` being raised. Lastly, for a custom class to be covered completely with Ivy's functional API, it is necessary to create an implementation for all the relevant functions within the API that will be used by this custom class. That can be all the functions in the API or only a subset of them. diff --git a/docs/overview/deep_dive/backend_setting.rst b/docs/overview/deep_dive/backend_setting.rst index ac42524e05c80..b2be3a2f0a7a5 100644 --- a/docs/overview/deep_dive/backend_setting.rst +++ b/docs/overview/deep_dive/backend_setting.rst @@ -23,30 +23,30 @@ When calling `this function`_ for setting the backend, the following steps are p #.
loop through the original :code:`ivy_original_dict` (which has all functions, including compositional), and (a) add the primary function from the backend if it exists, (b) else add the compositional function from :code:`ivy_original_dict`. #. `wrap the functions`_ where necessary, extending them with shared repeated functionality and `writing the function`_ to :attr:`ivy.__dict__`. Wrapping is used in order to avoid excessive code duplication in every backend function implementation. - This is explained in more detail in the next section: :ref:`Function Wrapping`. + This is explained in more detail in the next section: `Function Wrapping `_. It's helpful to look at an example: .. code-block:: python x = ivy.array([[2., 3.]]) - ivy.get_backend() + ivy.current_backend() .. code-block:: python y = ivy.multiply(torch.Tensor([3.]), torch.Tensor([4.])) - ivy.get_backend() + ivy.current_backend() .. code-block:: python ivy.set_backend('jax') z = ivy.matmul(jax.numpy.array([[2.,3.]]), jax.numpy.array([[5.],[6.]])) - ivy.get_backend() + ivy.current_backend() ivy.previous_backend() - ivy.get_backend() + ivy.current_backend() In the last example above, the moment any backend is set, it will be used over the `implicit_backend`_. @@ -74,7 +74,7 @@ Essentially, when the user calls :code:`ivy.set_backend(, dynamic=True) #. Next, the global :code:`ivy.__dict__` is updated to the new backend as mentioned in the Backend Setting section above. #. Finally, the objects are `converted from numpy`_ to the target backend using the newly set backend. -By default, the dynamic backend attribute is set to True when you create an ivy array (e.g., :code:`x = ivy.array([1,2,3])`), but the attribute is mutable and can be changed after the ivy array is created (e.g., :code:`x.dynamic_backend= True`). +By default, the dynamic backend attribute is set to True when you create an ivy array (e.g., :code:`x = ivy.array([1,2,3])`), but the attribute is mutable and can be changed after the ivy array is created (e.g., :code:`x.dynamic_backend = True`). Here's an example to illustrate how this works in practice: .. code-block:: python diff --git a/docs/overview/deep_dive/building_the_docs_pipline.rst b/docs/overview/deep_dive/building_the_docs_pipeline.rst similarity index 92% rename from docs/overview/deep_dive/building_the_docs_pipline.rst rename to docs/overview/deep_dive/building_the_docs_pipeline.rst index 464bfa63eb372..e38f01b892941 100644 --- a/docs/overview/deep_dive/building_the_docs_pipline.rst +++ b/docs/overview/deep_dive/building_the_docs_pipeline.rst @@ -7,7 +7,7 @@ Building the Docs Pipeline .. _doc-builder repository: https://github.com/unifyai/doc-builder To build our docs, we use `Sphinx`_. Sphinx is an extendable documentation generator -for Python. As our building pipeline is complex, we heavily customize Sphinx using +for Python. As our building pipeline is complex, we heavily customize Sphinx using custom and third-party extensions, as well as a convenience script to build the docs. @@ -43,7 +43,7 @@ document. The project should have the following characteristics: 5. It can contain an optional ``docs/partial_conf.py`` which is a partial `Sphinx configuration file`_. - This file will be imported with the default ``conf.py`` file located in the + This file will be imported with the default ``conf.py`` file located in the ``doc-builder`` repo.
Running the script: @@ -52,7 +52,7 @@ Running the script: ./make_docs_without_docker.sh /path/to/project -will result in the creation of documentation for the project in the directory +will result in the creation of documentation for the project in the directory ``docs/build``. Options @@ -62,7 +62,7 @@ Options -C, --no-cleanup Disable the backup/cleanup procedure -g, --git-add Stage changed files before generating the docs -s, --skip-dependencies-install Skip installing dependencies using pip --j, --jobs N Build in parallel with N processes where possible +-j, --jobs N Build in parallel with N processes where possible (special value ``auto`` will set N to cpu-count) -D setting Override a setting in ``conf.py`` @@ -71,7 +71,7 @@ The Docker image The Docker image `unifyai/doc-builder `_ works as a wrapper around the ``make_docs_without_docker.sh`` script. It runs the script -on the ``/project`` directory, located in the container `as shown here +on the ``/project`` directory, located in the container `as shown here `_: .. code-block:: bash @@ -84,10 +84,10 @@ To build the docs through docker you use this command: docker run -v /path/to/project:/project unifyai/doc-builder -You can also add options described in the :ref:`The convenience script` section. +You can also add options described in the :ref:`overview/deep_dive/building_the_docs_pipeline:The convenience script` section. .. code-block:: bash - + docker run -v /path/to/project:/project unifyai/doc-builder --no-cleanup How Ivy's docs are structured ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -160,22 +160,22 @@ the files that should be included in the table of contents. Which in recursively to every page in this documentation; for example, this page is included in the ``toctree`` of ``overview/deep_dive.rst``, which is included in the ``toctree`` of ``index.rst``. You can read more about the ``toctree`` directive in `sphinx docs -`_, from +`_; from now on, we'll only explain the directives that are custom to Ivy's doc-builder. The last directive is ``autosummary``, which is used to automatically generate a table of contents for a module, as well as the documentation itself automatically by discovering the docstrings of the module. This is a custom directive, built on the original `autosummary`_ -extension. We will explain in detail how did we change it, in :ref:`Custom Extensions`. +extension. We will explain in detail how we changed it in :ref:`overview/deep_dive/building_the_docs_pipeline:Custom Extensions`. ``partial_conf.py`` ~~~~~~~~~~~~~~~~~~~ -This is a partial `Sphinx configuration file`_. Which is being imported in the +This is a partial `Sphinx configuration file`_, which is imported by the `conf.py `_ and used to customize options that are specific to the project being documented, -While importing common configurations such as the theme, the extensions, etc in the +while common configurations such as the theme, the extensions, etc. are imported in the original ``conf.py``. This is a part of ``partial_conf.py``: @@ -190,18 +190,18 @@ This is a part of ``partial_conf.py``: "ivy_tests.test_ivy.helpers": "Testing", } -Here we are overriding the ``ivy_toctree_caption_map`` configuration, which is used to -customize the title of the table of contents for each module. +Here we are overriding the ``ivy_toctree_caption_map`` configuration, which is used to +customize the title of the table of contents for each module.
``ivy_toctree_caption_map`` is one of the configuration options we have in our -``custom_autosummary`` extension, which will be covered extensively in -:ref:`Custom Extensions`. +``custom_autosummary`` extension, which will be covered extensively in +:ref:`overview/deep_dive/building_the_docs_pipeline:Custom Extensions`. ``prebuild.sh`` ~~~~~~~~~~~~~~~ This is an optional file, which is executed before the docs are built. This is useful -if you need to install some dependencies for the docs to build. In Ivy's case, we -install ``torch`` then ``torch-scatter`` sequentially to avoid a bug in +if you need to install some dependencies for the docs to build. In Ivy's case, we +install ``torch`` then ``torch-scatter`` sequentially to avoid a bug in ``torch-scatter``'s setup. It is also where we can make any changes to the docker container before building the docs. @@ -250,7 +250,7 @@ The directive is included like this: .. discussion-links:: module.foo -First it will look for the ``discussion_channel_map`` configuration, in Ivy it looks like +First, it will look for the ``discussion_channel_map`` configuration; in Ivy it looks like this: .. code-block:: python @@ -263,15 +263,15 @@ this: } The key is the module name; if it's not found, the ``discussion-link`` directive will -render an empty node. The first and only value in the list is the channel id of the +render an empty node. The first and only value in the list is the channel id of the module; it is in a list because we used to have forums as well, but they have been removed. The output string is generated by a series of replaces on template strings, which are customizable using the config. To understand how it works, let's look at the default configurations and their values: -- ``discussion_paragraph``: ``"This should have hopefully given you an overview of the - {{submodule}} submodule, if you have any questions, please feel free to reach out on +- ``discussion_paragraph``: ``"This should have hopefully given you an overview of the + {{submodule}} submodule, if you have any questions, please feel free to reach out on our [discord]({{discord_link}}) in the [{{submodule}} channel]({{channel_link}})!"`` - ``discord_link``: ``"https://discord.gg/ZVQdvbzNQJ"`` - ``channel_link``: ``"https://discord.com/channels/799879767196958751/{{channel_id}}"`` @@ -283,39 +283,39 @@ Here is an example of how it works for ``ivy.functional.ivy.creation``: The result will be like this: - This should have hopefully given you an overview of the - **creation** submodule, if you have any questions, please feel free to reach out on + This should have hopefully given you an overview of the + **creation** submodule, if you have any questions, please feel free to reach out on our [discord]({{discord_link}}) in the [**creation** channel]({{channel_link}})! 2. Then we resolve the ``{{discord_link}}`` template string. The result will be like this: - - This should have hopefully given you an overview of the - creation submodule, if you have any questions, please feel free to reach out on + + This should have hopefully given you an overview of the + creation submodule, if you have any questions, please feel free to reach out on our [discord](**https://discord.gg/ZVQdvbzNQJ**) in the [creation channel]({{channel_link}})! 3. Then we resolve the ``{{channel_link}}`` template string.
The result will be like this: - - This should have hopefully given you an overview of the - creation submodule, if you have any questions, please feel free to reach out on + + This should have hopefully given you an overview of the + creation submodule, if you have any questions, please feel free to reach out on our [discord](\https://discord.gg/ZVQdvbzNQJ) in the [creation channel](**https://discord.com/channels/799879767196958751/{{channel_id}}**)! 4. We finally resolve ``{{channel_id}}`` template strings. The result will be like this: - - This should have hopefully given you an overview of the - creation submodule, if you have any questions, please feel free to reach out on + + This should have hopefully given you an overview of the + creation submodule, if you have any questions, please feel free to reach out on our [discord](\https://discord.gg/ZVQdvbzNQJ) in the [creation channel](\https://discord.com/channels/799879767196958751/**1000043690254946374**)! 5. After that, we render the node paragraph as if it were Markdown text, resulting in this: - This should have hopefully given you an overview of the - creation submodule, if you have any questions, please feel free to reach out on - our `discord `_ in the `creation channel + This should have hopefully given you an overview of the + creation submodule, if you have any questions, please feel free to reach out on + our `discord `_ in the `creation channel `_! All of the above template strings can be customized using the configuration, so feel free to change them to your liking. ``skippable_function`` ~~~~~~~~~~~~~~~~~~~~~~ -This extension provides a custom auto documenter ``autoskippablemethod`` that skip +This extension provides a custom auto documenter ``autoskippablemethod`` that skips functions matching values in the ``skippable_method_attributes`` configuration. This is an example of ``skippable_method_attributes`` configuration in @@ -338,14 +338,14 @@ This is an example of ``skippable_method_attributes`` configuration in } ] -This will remove any function that has ``__qualname__`` attribute equal to +This will remove any function that has a ``__qualname__`` attribute equal to ``_wrap_function..new_function``. ``ivy_data`` ~~~~~~~~~~~~ This is a custom documenter for ``autodoc`` that documents Ivy data attributes that live -in ``ivy.functional.ivy``, it will replace the module to ``ivy.`` instead of +in ``ivy.functional.ivy``; it will present the module as ``ivy.`` instead of ``ivy.functional.ivy.``. It's used instead of simply using ``ivy.`` because data attributes have @@ -353,6 +353,6 @@ no ``__doc__`` atribute, instead docs are discovered by parsing the source code So for Sphinx to find the required docs, it needs to be supplied the full module name; then the ``autoivydata`` directive will replace the module name with ``ivy.``. -Please refer to the `auto documenter guide in sphinx documentation +Please refer to the `auto documenter guide in sphinx documentation `_ for more info. diff --git a/docs/overview/deep_dive/containers.rst b/docs/overview/deep_dive/containers.rst index 0b975eb990421..13521ec772f17 100644 --- a/docs/overview/deep_dive/containers.rst +++ b/docs/overview/deep_dive/containers.rst @@ -185,18 +185,18 @@ As for the special methods which are `implemented`_ in the main :class:`ivy.Cont As a result, the operator functions will make use of the special methods of the lefthand passed input objects if available; otherwise, they will make use of the reverse special method of the righthand operand.
For instance, if the lefthand operand at any given leaf of the container is an :class:`ivy.Array`, then the operator function will make calls to the special methods of this array object. -As explained in the :ref:`Arrays` section of the Deep Dive, these special methods will in turn call the corresponding functions from the ivy functional API. - +As explained in the `Arrays `_ section of the Deep Dive, these special methods will in turn call the corresponding functions from the ivy functional API. + Examples include `__add__`_, `__sub__`_, `__mul__`_ and `__truediv__`_ which will make calls to :func:`ivy.add`, :func:`ivy.subtract`, :func:`ivy.multiply` and :func:`ivy.divide` respectively if the lefthand operand is an :class:`ivy.Array` object. Otherwise, these special methods will be called on whatever objects are at the leaves of the container, such as int, float, :class:`ivy.NativeArray` etc. Nestable Functions ------------------ -As introduced in the :ref:`Function Types` section, most functions in Ivy are *nestable*, which means that they can accept :class:`ivy.Container` instances in place of **any** of the arguments. +As introduced in the `Function Types `_ section, most functions in Ivy are *nestable*, which means that they can accept :class:`ivy.Container` instances in place of **any** of the arguments. Here, we expand on this explanation. -Please check out the explanation in the :ref:`Function Types` section first. +Please check out the explanation in the `Function Types `_ section first. **Explicitly Nestable Functions** diff --git a/docs/overview/deep_dive/continuous_integration.rst b/docs/overview/deep_dive/continuous_integration.rst index bfabdc6973c00..e639efa7bcfca 100644 --- a/docs/overview/deep_dive/continuous_integration.rst +++ b/docs/overview/deep_dive/continuous_integration.rst @@ -289,7 +289,7 @@ Array API Tests --------------- The `array-api-intelligent-tests.yml (Push) `_ and the `array-api-intelligent-tests-pr.yml (Pull Request) `_ workflows run the Array API Tests. Similar to Ivy Tests, the Array API tests are also determined intelligently and only relevant tests are triggered on each commit. -More details about the Array API Tests are available `here `_. +More details about the Array API Tests are available `here `_. Periodic Testing ---------------- diff --git a/docs/overview/deep_dive/data_types.rst b/docs/overview/deep_dive/data_types.rst index aecaebd30e9b2..f2f72dea820c2 100644 --- a/docs/overview/deep_dive/data_types.rst +++ b/docs/overview/deep_dive/data_types.rst @@ -80,7 +80,7 @@ Data Type Module The `data_type.py`_ module provides a variety of functions for working with data types. A few examples include :func:`ivy.astype` which copies an array to a specified data type, :func:`ivy.broadcast_to` which broadcasts an array to a specified shape, and :func:`ivy.result_type` which returns the dtype that results from applying the type promotion rules to the arguments. -Many functions in the :mod:`data_type.py` module are *convenience* functions, which means that they do not directly modify arrays, as explained in the :ref:`Function Types` section. +Many functions in the :mod:`data_type.py` module are *convenience* functions, which means that they do not directly modify arrays, as explained in the `Function Types `_ section.
For example, the following are all convenience functions: `ivy.can_cast`_, which determines if one data type can be cast to another data type according to type-promotion rules, `ivy.dtype `__, which gets the data type for the input array, `ivy.set_default_dtype`_, which sets the global default data dtype, and `ivy.default_dtype`_, which returns the correct data type to use. @@ -95,8 +95,7 @@ Data Type Promotion In order to ensure that the same data type is always returned when operations are performed on arrays with different data types, regardless of which backend framework is set, Ivy has its own set of data type promotion rules and corresponding functions. These rules build directly on top of the `rules `_ outlined in the `Array API Standard`_. -The rules are simple: all data type promotions in Ivy should adhere to this `promotion table `_, -which is the union of the Array API Standard `promotion table `_ and an extra `promotion table `_. +The rules are simple: all data type promotions in Ivy should adhere to a promotion table that extends the Array API Standard `promotion table `_ using this `promotion table `_, and one of two extra `promotion tables `_ depending on the precision mode, which will be explained in the following section. In order to ensure adherence to this promotion table, many backend functions make use of the functions `ivy.promote_types `_, `ivy.type_promote_arrays `_, or `ivy.promote_types_of_inputs `_. These functions promote data types in the inputs and return the new data types, promote the data types of the arrays in the input and return new arrays, and promote the data types of the numeric or array values inputs and return new type promoted values, respectively. @@ -182,10 +181,39 @@ Whenever the user defines data with a specific data type, they expect a certain The user expects specific behaviour and memory constraints whenever they specify and use concrete data types, and those decisions should be respected. Therefore, Ivy does not upcast specific values to improve the stability or precision of the computation. +Precise Mode +~~~~~~~~~~~~~~~ -Arguments in other Functions ---------------------------- +There are cases that arise in mixed promotion (integer and float, complex and float) that aren't covered by the Array API Standard promotion table, and the mixed promotion rules differ between use cases, as observed in different frameworks. For example, TensorFlow leaves integer/floating mixed promotion undefined to make behavior utterly predictable (at some cost to user convenience), while NumPy avoids precision loss at all costs, even if that means casting the arrays to wider-than-necessary dtypes. + +Precise Promotion Table +""""""""""""""""""""""""" + +This table focuses on numerical accuracy at the cost of a higher memory footprint. A 16-bit signed or unsigned integer cannot be represented at full precision by a 16-bit float, which has only 10 bits of mantissa. Therefore, it might make sense to promote integers to floats represented by twice the number of bits. There are two disadvantages of this approach: + +#. It still leaves ``int64`` and ``uint64`` promotion undefined, because there is no standard floating point type with enough bits of mantissa to represent their full range of values. We could relax the precision constraint and use ``float64`` as the upper bound for this case. +#.
Some operations result in types that are much wider than necessary; for example, mixed operations between ``uint16`` and ``float16`` would promote all the way to ``float64``, which is not ideal. + +.. code-block:: python + + with ivy.PreciseMode(True): + print(ivy.promote_types("float32","int32")) + # float64 +Non-Precise Promotion Table +""""""""""""""""""""""""""""""""" +The advantage of this approach is that, outside unsigned ints, it avoids all wider-than-necessary promotions: you can never get a ``float64`` output without a 64-bit input, and you can never get a ``float32`` output without a 32-bit input; this results in convenient semantics for working on accelerators while avoiding unwanted 64-bit values. This feature of giving primacy to floating point types resembles the type promotion behavior of PyTorch. +The disadvantage of this approach is that mixed float/integer promotion is very prone to precision loss: for example, ``int64`` (with a maximum value of 9.2*10^18) can be promoted to ``float16`` (with a maximum value of 6.5*10^4), meaning most representable values will become inf, but we are fine accepting potential loss of precision (but not loss of magnitude) in mixed type promotion, which satisfies most of the use cases in deep learning scenarios. + +.. code-block:: python + + with ivy.PreciseMode(False): + print(ivy.promote_types("float32","int32")) + # float32 + +Arguments in other Functions +---------------------------- All ``dtype`` arguments are keyword-only. All creation functions include the ``dtype`` argument, for specifying the data type of the created array. Some other non-creation functions also support the ``dtype`` argument, such as :func:`ivy.prod` and :func:`ivy.sum`, but most functions do not include it. @@ -193,7 +221,7 @@ The non-creation functions which do support it are generally functions that invo The ``dtype`` argument is handled in the `infer_dtype`_ wrapper, for all functions which have the decorator :code:`@infer_dtype`. This function calls `ivy.default_dtype`_ in order to determine the correct data type. -As discussed in the :ref:`Function Wrapping` section, this is applied to all applicable functions dynamically during `backend setting`_. +As discussed in the `Function Wrapping `_ section, this is applied to all applicable functions dynamically during `backend setting`_. Overall, `ivy.default_dtype`_ infers the data type as follows: @@ -653,4 +681,4 @@ If you have any questions, please feel free to reach out on `discord`_ in the `d \ No newline at end of file + diff --git a/docs/overview/deep_dive/devices.rst b/docs/overview/deep_dive/devices.rst index 4f2cbbbcf24ce..1159535728bdb 100644 --- a/docs/overview/deep_dive/devices.rst +++ b/docs/overview/deep_dive/devices.rst @@ -33,7 +33,7 @@ The devices currently supported by Ivy are as follows: * gpu:idx * tpu:idx -In a similar manner to the :class:`ivy.Dtype` and :class:`ivy.NativeDtype` classes (see :ref:`Data Types`), there is both an `ivy.Device`_ class and an :class:`ivy.NativeDevice` class, with :class:`ivy.NativeDevice` initially set as an `empty class`_. +In a similar manner to the :class:`ivy.Dtype` and :class:`ivy.NativeDtype` classes (see `Data Types `_), there is both an `ivy.Device`_ class and an :class:`ivy.NativeDevice` class, with :class:`ivy.NativeDevice` initially set as an `empty class`_. The :class:`ivy.Device` class derives from :code:`str`, and has simple logic in the constructor to verify that the string formatting is correct.
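To make this concrete, a minimal sketch of :class:`ivy.Device` behaving as a validated string; the exact device strings here are illustrative, assuming a standard Ivy install:

.. code-block:: python

    import ivy

    dev = ivy.Device("gpu:0")

    # ivy.Device derives from str, so it can be used anywhere
    # a plain device string is expected.
    assert isinstance(dev, str)

    # A malformed string such as "foo123" is expected to fail the
    # constructor's formatting check and raise an exception.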
When a backend is set, the :class:`ivy.NativeDevice` is replaced with the backend-specific `device class`_. @@ -43,7 +43,7 @@ Device Module The `device.py`_ module provides a variety of functions for working with devices. A few examples include :func:`ivy.get_all_ivy_arrays_on_dev` which gets all arrays which are currently alive on the specified device, :func:`ivy.dev` which gets the device for the input array, and :func:`ivy.num_gpus` which determines the number of available GPUs for use with the backend framework. -Many functions in the :mod:`device.py` module are *convenience* functions, which means that they do not directly modify arrays, as explained in the :ref:`Function Types` section. +Many functions in the :mod:`device.py` module are *convenience* functions, which means that they do not directly modify arrays, as explained in the `Function Types `_ section. For example, the following are all convenience functions: `ivy.total_mem_on_dev`_, which gets the total amount of memory for a given device, `ivy.dev_util`_, which gets the current utilization (%) for a given device, `ivy.num_cpu_cores`_, which determines the number of cores available in the CPU, and `ivy.default_device`_, which returns the correct device to use. @@ -64,7 +64,7 @@ In cases where the input arrays are located on different devices, an error will The :code:`device` argument is handled in `infer_device`_ for all functions which have the :code:`@infer_device` decorator, similar to how :code:`dtype` is handled. This function calls `ivy.default_device`_ in order to determine the correct device. -As discussed in the :ref:`Function Wrapping` section, this is applied to all applicable functions dynamically during `backend setting`_. +As discussed in the `Function Wrapping `_ section, this is applied to all applicable functions dynamically during `backend setting`_. Overall, `ivy.default_device`_ infers the device as follows: @@ -77,7 +77,7 @@ Overall, `ivy.default_device`_ infers the device as follows: For the majority of functions which defer to `infer_device`_ for handling the device, these steps will have been followed and the :code:`device` argument will be populated with the correct value before the backend-specific implementation is even entered into. Therefore, whereas the :code:`device` argument is listed as optional in the ivy API at :mod:`ivy/functional/ivy/category_name.py`, the argument is listed as required in the backend-specific implementations at :mod:`ivy/functional/backends/backend_name/category_name.py`. -This is exactly the same as with the :code:`dtype` argument, as explained in the :ref:`Data Types` section. +This is exactly the same as with the :code:`dtype` argument, as explained in the `Data Types `_ section. Let's take a look at the function :func:`ivy.zeros` as an example. @@ -155,7 +155,7 @@ doesn't care about this, it moves all the tensors to the same device before perf **Controlling Device Handling Behaviour** -In Ivy, users can control the device on which the operation is to be executed using `ivy.set_soft_device_mode`_ flag. There are two cases for this, +In Ivy, users can control the device on which the operation is to be executed using the `ivy.set_soft_device_mode`_ flag. There are two cases for this: either the soft device mode is set to :code:`True` or :code:`False`.
**When ivy.set_soft_device_mode(True)**: @@ -167,7 +167,7 @@ In the example below, even though the input arrays :code:`x` and :code:`y` are c are moved to :code:`ivy.default_device()` while performing :code:`ivy.add` operation, and the output array will be on this device. .. code-block:: python - + ivy.set_backend("torch") ivy.set_soft_device_mode(True) x = ivy.array([1], device="cpu") @@ -214,7 +214,7 @@ This is the exception you will get while running the code above: File "/content/ivy/ivy/func_wrapper.py", line 863, in _handle_device_shifting raise ivy.utils.exceptions.IvyException( During the handling of the above exception, another exception occurred: - Expected all input arrays to be on the same device, but found atleast two devices - ('cpu', 'gpu:0'), + Expected all input arrays to be on the same device, but found atleast two devices - ('cpu', 'gpu:0'), set `ivy.set_soft_device_mode(True)` to handle this problem. b. If all the input arrays are on the same device, the operation is executed without raising any device exceptions. @@ -278,5 +278,5 @@ If you have any questions, please feel free to reach out on `discord`_ in the `d .. raw:: html diff --git a/docs/overview/deep_dive/exception_handling.rst b/docs/overview/deep_dive/exception_handling.rst index 9c192a3f0e3a7..8695c318fe42d 100644 --- a/docs/overview/deep_dive/exception_handling.rst +++ b/docs/overview/deep_dive/exception_handling.rst @@ -57,14 +57,14 @@ For a more general case, the :code:`IvyError` class can be used. def __init__(self, *messages, include_backend=False): super().__init__(*messages, include_backend=include_backend) -More Custom Exception classes were created to unify sub-categories of errors. We try our best to ensure that the same type of +More Custom Exception classes were created to unify sub-categories of errors. We try our best to ensure that the same type of Exception is raised for the same type of Error regardless of the backend. This will ensure that the exceptions are truly unified for all the different types of errors. The implementations of these custom classes are exactly the same as :code:`IvyError` class. Currently there are 5 custom exception classes in ivy. 1. :code:`IvyIndexError`: This Error is raised for anything Indexing related. For Instance, providing out of bound axis in any function. -2. :code:`IvyValueError`: This is for anything related to providing wrong values. For instance, passing :code:`high` value +2. :code:`IvyValueError`: This is for anything related to providing wrong values. For instance, passing :code:`high` value smaller than :code:`low` value in :code:`ivy.random_uniform`. 3. :code:`IvyAttributeError`: This is raised when an undefined attribute is referenced. 4. :code:`IvyBroadcastShapeError`: This is raised whenever 2 shapes are expected to be broadcastable but are not. @@ -75,12 +75,12 @@ The correct type of Exception class should be used for the corresponding type of Configurable Mode for Stack Trace --------------------------------- -Ivy's transpilation nature allows users to write code in their preferred frontend -framework and then execute it with a different backend framework. For example, a -user who is comfortable with NumPy can use Ivy's NumPy frontend to run their code -with a JAX backend. However, since they may have no prior experience with JAX or -other backend frameworks, they may not want to encounter stack traces that traverse -Ivy and JAX functions. 
In such cases, it may be preferable for the user to avoid +Ivy's transpilation nature allows users to write code in their preferred frontend +framework and then execute it with a different backend framework. For example, a +user who is comfortable with NumPy can use Ivy's NumPy frontend to run their code +with a JAX backend. However, since they may have no prior experience with JAX or +other backend frameworks, they may prefer to avoid encountering stack traces that extend through Ivy and JAX functions. Therefore, options are made available for the stack traces to either truncate @@ -400,7 +400,7 @@ Let's look at the comparison of before and after adding the decorator. In NumPy, .. code-block:: none - + >>> x = ivy.array([0,0,1]) >>> ivy.all(x, axis=2) @@ -449,7 +449,7 @@ and for Numpy, :code:`AxisError` is raised. To unify the behaviour, we raise :co In Numpy, .. code-block:: python - + # in ivy/functional/backends/numpy/utility.py def all( x: np.ndarray, diff --git a/docs/overview/deep_dive/formatting.rst b/docs/overview/deep_dive/formatting.rst index 307f059575dd3..c561f1fe04b97 100644 --- a/docs/overview/deep_dive/formatting.rst +++ b/docs/overview/deep_dive/formatting.rst @@ -3,7 +3,6 @@ Formatting .. _`flake8`: https://flake8.pycqa.org/en/latest/index.html .. _`black`: https://black.readthedocs.io/en/stable/index.html -.. _`pre-commit guide`: https://unify.ai/docs/ivy/overview/contributing/setting_up.html#pre-commit .. _`formatting channel`: https://discord.com/channels/799879767196958751/1028266706436624456 .. _`discord`: https://discord.gg/sXyFF8tDtm @@ -19,7 +18,7 @@ Lint Checks In addition to `black`_ and `flake8`_, Ivy uses other linters to help automate the formatting process, especially for issues `flake8`_ detects but doesn't fix automatically. In addition to that, we validate docstrings as part of our -linting process. You can learn more about our docstring formatting in the :ref:`Docstrings` section. +linting process. You can learn more about our docstring formatting in the `Docstrings `_ section. We use the following linters:
You will notice that some files have changed if you check ``git status``; you'll need to add them and commit again. @@ -168,22 +169,22 @@ We have a GitHub action that runs: 1. Every day at 08:00 UTC 2. Manually invoked by making a comment with ``ivy-gardener`` on a PR -The first action is to ensure that the code in the whole codebase is always formatted correctly. The second action -is to reformat the files you changed in your PR directly on GitHub. This is useful in case if you didn't setup +The first action is to ensure that the code in the whole codebase is always formatted correctly. The second action +is to reformat the files you changed in your PR directly on GitHub. This is useful in case you didn't set up pre-commit correctly or if you or one of our maintainers want to reformat your code remotely. -Under the hood, when ``ivy-gardener`` is found in a comment, an ivy bot will trigger the same set of lint checks +Under the hood, when ``ivy-gardener`` is found in a comment, an ivy bot will trigger the same set of lint checks as in the pre-commit process. Then the suggested changes produced in the checks will be applied automatically as -a new commit if there is any. +a new commit if there is any. -However, it is possible for the linters run in the ``ivy-gardener`` and the GitHub action every day to face -formatting errors that need human intervention like typos and uninitialized arguments. In this case, errors will -be thrown by the linters and by the lint checks that runs later, while fixes to other simpler errors will still +However, it is possible for the linters run in the ``ivy-gardener`` and the daily GitHub action to face +formatting errors that need human intervention, like typos and uninitialized arguments. In this case, errors will +be thrown by the linters and by the lint checks that run later, while fixes to other simpler errors will still be applied by the ``ivy-gardener`` properly. -On the other hand, ``ivy-gardener`` itself can fail if the bot handling it (ivy-branch) can not apply the changes -suggested by the linters, for example, when it does not have access to edit the target branch. In this case, you -should try to give the maintainer bot the access to your branch (which is an option shown in GitHub UI) and give it +On the other hand, ``ivy-gardener`` itself can fail if the bot handling it (ivy-branch) cannot apply the changes +suggested by the linters, for example, when it does not have access to edit the target branch. In this case, you +should try to give the maintainer bot access to your branch (which is an option shown in the GitHub UI) and give it another try, or manually resolve the formatting errors by committing the changes yourself. **Round Up** diff --git a/docs/overview/deep_dive/function_arguments.rst b/docs/overview/deep_dive/function_arguments.rst index 537b45dcb9bf5..988db211600ec 100644 --- a/docs/overview/deep_dive/function_arguments.rst +++ b/docs/overview/deep_dive/function_arguments.rst @@ -10,7 +10,7 @@ Function Arguments Here, we explain how the function arguments differ between the placeholder implementation at :mod:`ivy/functional/ivy/category_name.py`, and the backend-specific implementation at :mod:`ivy/functional/backends/backend_name/category_name.py`. -Many of these points are already addressed in the previous sections: :ref:`Arrays`, :ref:`Data Types`, :ref:`Devices` and :ref:`Inplace Updates`. +Many of these points are already addressed in the previous sections: `Arrays `_, `Data Types `_, `Devices `_ and `Inplace Updates `_.
However, we thought it would be convenient to revisit all of these considerations in a single section, dedicated to function arguments. As for type-hints, all functions in the Ivy API at :mod:`ivy/functional/ivy/category_name.py` should have full and thorough type-hints. @@ -161,13 +161,13 @@ For example, calling any of (:code:`+`, :code:`-`, :code:`*`, :code:`/` etc.) on :class:`ivy.NativeArray` instances are also not permitted for the :code:`out` argument, which is used in many functions. This is because the :code:`out` argument dictates the array to which the result should be written, and so it effectively serves the same purpose as the function return when no :code:`out` argument is specified. -This is all explained in more detail in the :ref:`Arrays` section. +This is all explained in more detail in the `Arrays `_ section. out Argument ------------ The :code:`out` argument should always be provided as a keyword-only argument, and it should be added to all functions in the Ivy API and backend API which support inplace updates, with a default value of :code:`None` in all cases. -The :code:`out` argument is explained in more detail in the :ref:`Inplace Updates` section. +The :code:`out` argument is explained in more detail in the `Inplace Updates `_ section. dtype and device arguments -------------------------- @@ -175,7 +175,7 @@ dtype and device arguments In the Ivy API at :mod:`ivy/functional/ivy/category_name.py`, the :code:`dtype` and :code:`device` arguments should both always be provided as keyword-only arguments, with a default value of :code:`None`. In contrast, these arguments should both be added as required arguments in the backend implementation at :mod:`ivy/functional/backends/backend_name/category_name.py`. In a nutshell, by the time the backend implementation is entered, the correct :code:`dtype` and :code:`device` to use have both already been correctly handled by code which is wrapped around the backend implementation. -This is further explained in the :ref:`Data Types` and :ref:`Devices` sections respectively. +This is further explained in the `Data Types `_ and `Devices `_ sections respectively. Numbers in Operator Functions ----------------------------- diff --git a/docs/overview/deep_dive/function_types.rst b/docs/overview/deep_dive/function_types.rst index 9765d6383aee6..aba496df485d1 100644 --- a/docs/overview/deep_dive/function_types.rst +++ b/docs/overview/deep_dive/function_types.rst @@ -81,8 +81,8 @@ The backend-specific implementation of :func:`ivy.tan` for PyTorch in :mod:`ivy x = _cast_for_unary_op(x) return torch.tan(x, out=out) -The reason that the Ivy implementation has type hint :code:`Union[ivy.Array, ivy.NativeArray]` but PyTorch implementation has :class:`torch.Tensor` is explained in the :ref:`Arrays` section. -Likewise, the reason that the :code:`out` argument in the Ivy implementation has array type hint :class:`ivy.Array` whereas :code:`x` has :code:`Union[ivy.Array, ivy.NativeArray]` is also explained in the :ref:`Arrays` section. +The reason that the Ivy implementation has type hint :code:`Union[ivy.Array, ivy.NativeArray]` but PyTorch implementation has :class:`torch.Tensor` is explained in the `Arrays `_ section. +Likewise, the reason that the :code:`out` argument in the Ivy implementation has array type hint :class:`ivy.Array` whereas :code:`x` has :code:`Union[ivy.Array, ivy.NativeArray]` is also explained in the `Arrays `_ section. 
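For orientation, a minimal sketch of what the corresponding Ivy placeholder implementation looks like is given below; this is illustrative only, and the exact signature in the Ivy source may differ slightly.

.. code-block:: python

    # illustrative sketch of the Ivy API placeholder, not the verbatim source
    def tan(
        x: Union[ivy.Array, ivy.NativeArray],
        /,
        *,
        out: Optional[ivy.Array] = None,
    ) -> ivy.Array:
        # defer to the implementation of whichever backend is currently set
        return ivy.current_backend(x).tan(x, out=out)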
Compositional Functions ----------------------- @@ -116,7 +116,7 @@ Mixed Functions --------------- --------------- -Sometimes, a function may only be provided by some of the supported backends. In this case, we have to take a mixed approach. We should always have a backend-specific implementation if there is a similar function provided by a certain backend. This maximises runtime efficiency, as the function in the backend will be implemented directly in C or C++. Such functions have some backend-specific implementations in :mod:`ivy/functional/backends/backend_name/category_name.py`, but not for all backends. To support backends that do not have a backend-specific implementation, a compositional implementation is also provided in :mod:`ivy/functional/ivy/category_name.py`. Compositional functions should only be used when there is no similar function to wrap in the backend. +Sometimes, a function may only be provided by some of the supported backends. In this case, we have to take a mixed approach. We should always have a backend-specific implementation if there is a similar function provided by a certain backend. This maximises runtime efficiency, as the function in the backend will be implemented directly in C or C++. Such functions have some backend-specific implementations in :mod:`ivy/functional/backends/backend_name/category_name.py`, but not for all backends. To support backends that do not have a backend-specific implementation, a compositional implementation is also provided in :mod:`ivy/functional/ivy/category_name.py`. Compositional functions should only be used when there is no similar function to wrap in the backend. Because these functions include both a compositional implementation and also at least one backend-specific implementation, these functions are referred to as *mixed*. @@ -135,7 +135,7 @@ One example of this is :code:`ivy.linear` for which the torch native function :c to be a 2 dimensional tensor, whereas ivy also allows the :code:`weight` argument to be 3 dimensional. While achieving the objective of having superset behaviour across the backends, the native functionality of frameworks should be made use of as much as possible. Even if a framework-specific function doesn't provide complete superset behaviour, we should still make use of the partial behaviour that it provides and then add more logic for the -remaining part. This is explained in detail in the :ref:`Maximizing Usage of Native Functionality` section. Ivy allows this partial support with the help of the `partial_mixed_handler`_ +remaining part. This is explained in detail in the :ref:`overview/deep_dive/superset_behaviour:Maximizing Usage of Native Functionality` section. Ivy allows this partial support with the help of the `partial_mixed_handler`_ attribute, which should be added to the backend implementation with a boolean function that specifies some condition on the inputs to switch between the compositional and primary implementations. For example, the :code:`torch` backend implementation of :code:`linear` looks like: @@ -159,7 +159,7 @@ the :code:`handle_partial_mixed_function` decorator first evaluates the boolean is `True` and the compositional implementation otherwise. -For further information on decorators, please refer to the :ref:`Function Wrapping` section. +For further information on decorators, please refer to the `Function Wrapping `_ section.
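To make the mechanism concrete, the attribute is simply a boolean-returning lambda attached to the backend function; a rough sketch is shown below (the exact condition used in the torch backend may differ).

.. code-block:: python

    # rough sketch: use the primary (torch) implementation only when the
    # weight is 2-dimensional, since torch.nn.functional.linear requires that,
    # and fall back to the compositional implementation otherwise
    linear.partial_mixed_handler = lambda x, weight, **kwargs: weight.ndim == 2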
For all mixed functions, we must add the :code:`mixed_backend_wrappers` attribute to the compositional implementation to specify which additional wrappers need to be applied to the primary implementation and which ones from the compositional implementation should be skipped. We do this by creating a dictionary of two keys, :code:`to_add` and :code:`to_skip`, each containing the tuple of wrappers to be added or skipped, respectively. In general, :code:`handle_out_argument`, :code:`inputs_to_native_arrays` and :code:`outputs_to_ivy_arrays` @@ -215,9 +215,9 @@ This *nestable* property of Ivy functions means that the same function can be us This added support for handling :class:`ivy.Container` instances is all handled automatically when `_wrap_function`_ is applied to every function in the :code:`ivy` module during `backend setting`_. This will add the `handle_nestable`_ wrapping to the function if it has the :code:`@handle_nestable` decorator. -This function wrapping process is covered in a bit more detail in the :ref:`Function Wrapping` section. +This function wrapping process is covered in a bit more detail in the `Function Wrapping `_ section. -Nestable functions are explained in more detail in the :ref:`Containers` section. +Nestable functions are explained in more detail in the `Containers `_ section. Convenience Functions --------------------- diff --git a/docs/overview/deep_dive/function_wrapping.rst b/docs/overview/deep_dive/function_wrapping.rst index c07c0132298a1..8a585244222e6 100644 --- a/docs/overview/deep_dive/function_wrapping.rst +++ b/docs/overview/deep_dive/function_wrapping.rst @@ -17,8 +17,9 @@ Function Wrapping .. _`handle_nestable`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/func_wrapper.py#L896 .. _`inputs_to_native_shapes`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/func_wrapper.py#L488 .. _`outputs_to_ivy_shapes`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/func_wrapper.py#L501 +.. _`to_native_shapes_and_back`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/func_wrapper.py#L514 .. _`handle_view`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/func_wrapper.py#L627 -.. _`handle_view_indexing`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/func_wrapper.py#L659 +.. _`handle_view_indexing`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/func_wrapper.py#L659 .. _`handle_array_function`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/func_wrapper.py#L299 .. _`handle_complex_input`: https://github.com/unifyai/ivy/blob/bd9b5b1080d33004e821a48c486b3a879b9d6616/ivy/func_wrapper.py#L1393 .. _`repo`: https://github.com/unifyai/ivy @@ -29,6 +30,11 @@ Function Wrapping .. _`ivy.linear`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/functional/ivy/layers.py#L81 .. _`handle_exceptions`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/utils/exceptions.py#L189 .. _`example`: https://github.com/unifyai/ivy/blob/5658401b266352d3bf72c95e4af6ae9233115722/ivy/functional/backends/torch/layers.py#L30 +.. _`Arrays`: arrays.rst +.. _`Inplace Updates`: inplace_updates.rst +.. _`Data Types`: data_types.rst +.. _`Devices`: devices.rst +..
_`Backend Setting`: backend_setting.rst When a backend framework is set by calling :code:`ivy.set_backend(backend_name)`, then all Ivy functions are `wrapped`_. This is achieved by calling `_wrap_function`_, which will apply the appropriate wrapping to the given function, based on what decorators it has. @@ -37,7 +43,7 @@ For example, `abs`_ has the decorators :code:`@to_native_arrays_and_back` and :c The new function returned by :code:`_wrap_function` is a replacement of the original function with extra code added to support requirements common to many functions in the API. This is the main purpose of the wrapping, to avoid code duplication which would exist if we added identical logic in every single function independently. -Depending on the function being wrapped, the new function might handle :ref:`Arrays`, :ref:`Inplace Updates`, :ref:`Data Types` and/or :ref:`Devices`. +Depending on the function being wrapped, the new function might handle `Arrays`_, `Inplace Updates`_, `Data Types`_ and/or `Devices`_. Our test decorators actually transform into :code:`@given` decorators at Pytest collection time, which allows us to use other **Hypothesis** decorators like :code:`@reproduce_failure`, :code:`@settings`, and :code:`@seed`. @@ -74,11 +80,11 @@ This recommended order is followed to ensure that tests are efficient and accura Conversion Wrappers ^^^^^^^^^^^^^^^^^^^ -#. `inputs_to_native_arrays`_ : This wrapping function converts all :class:`ivy.Array` instances in the arguments to their :class:`ivy.NativeArray` counterparts, based on the :ref:`Backend Setting` before calling the function. -#. `inputs_to_ivy_arrays`_ : This wrapping function converts all :class:`ivy.NativeArray` instances in the arguments to their :class:`ivy.Array` counterparts, based on the :ref:`Backend Setting` before calling the function. -#. `outputs_to_ivy_arrays`_ : This wrapping function converts all :class:`ivy.NativeArray` instances in the outputs to their :class:`ivy.Array` counterparts, based on the :ref:`Backend Setting` before calling the function. +#. `inputs_to_native_arrays`_ : This wrapping function converts all :class:`ivy.Array` instances in the arguments to their :class:`ivy.NativeArray` counterparts, based on the `Backend Setting`_ before calling the function. +#. `inputs_to_ivy_arrays`_ : This wrapping function converts all :class:`ivy.NativeArray` instances in the arguments to their :class:`ivy.Array` counterparts, based on the `Backend Setting`_ before calling the function. +#. `outputs_to_ivy_arrays`_ : This wrapping function converts all :class:`ivy.NativeArray` instances in the outputs to their :class:`ivy.Array` counterparts, based on the `Backend Setting`_ before calling the function. #. `to_native_arrays_and_back`_ : This wrapping function converts all :class:`ivy.Array` instances in the arguments to their :class:`ivy.NativeArray` counterparts, calls the function with those arguments and then converts the :class:`ivy.NativeArray` instances in the output back to :class:`ivy.Array`. - This wrapping function is heavily used because it enables achieving the objective of ensuring that every ivy function could accept an :class:`ivy.Array` and return an :class:`ivy.Array`, making it independent of the :ref:`Backend Setting`. + This wrapping function is heavily used because it enables achieving the objective of ensuring that every ivy function could accept an :class:`ivy.Array` and return an :class:`ivy.Array`, making it independent of the `Backend Setting`_.
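Conceptually, `to_native_arrays_and_back`_ behaves like the simplified decorator below; this is a sketch only, and the real wrapper in :mod:`ivy/func_wrapper.py` handles many more edge cases.

.. code-block:: python

    import functools

    def to_native_arrays_and_back(fn):
        # simplified sketch of the conversion logic, not the actual wrapper
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # ivy.Array inputs -> native arrays (e.g. torch.Tensor)
            args, kwargs = ivy.args_to_native(*args, **kwargs)
            ret = fn(*args, **kwargs)
            # native arrays in the output -> ivy.Array
            return ivy.to_ivy(ret, nested=True)
        return wrapper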
Inference Wrappers ^^^^^^^^^^^^^^^^^^ @@ -95,7 +101,7 @@ Out Argument Support #. `handle_out_argument`_ : This wrapping function is used in nearly all ivy functions. It enables appropriate handling of the :code:`out` argument of functions. In cases where the backend framework natively supports the :code:`out` argument for a function, we prefer to use it as it's a more efficient implementation of the :code:`out` argument for that particular backend framework. - But in cases when it isn't supported, we support it anyway with :ref:`Inplace Updates`. + But in cases when it isn't supported, we support it anyway with `Inplace Updates`_. Nestable Support ^^^^^^^^^^^^^^^^ @@ -105,7 +111,7 @@ Nestable Support Partial Mixed Function Support ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -#. `handle_partial_mixed_function`_: This wrapping function enables switching between compositional and primary implementations of :ref:`Mixed Functions` based on some condition on the arguments of the function. +#. `handle_partial_mixed_function`_: This wrapping function enables switching between compositional and primary implementations of :ref:`overview/deep_dive/function_types:Mixed Functions` based on some condition on the arguments of the function. #. The condition is specified through a lambda function which, when it evaluates to `True`, causes the primary implementation to run; otherwise, the compositional implementation is executed. #. For backends that have a primary implementation of a mixed function, the reference to the compositional implementation is `stored as an attribute`_ inside the backend function during backend setting. To make use of this decorator, one must #. add the :code:`partial_mixed_handler` attribute containing the lambda function to the backend implementation. Here's an `example`_ from the torch backend implementation of linear. @@ -113,29 +119,29 @@ Partial Mixed Function Support Shape Conversion ^^^^^^^^^^^^^^^^ -#. `inputs_to_native_shapes`_ : This wrapping function converts all :class:`ivy.Shape` instances in the arguments to their :class:`ivy.NativeShape` counterparts, based on the :ref:`Backend Setting` before calling the function. -#. `outputs_to_ivy_shapes`_ : This wrapping function converts all :class:`ivy.NativeShape` instances in the outputs to their :class:`ivy.Shape` counterparts, based on the :ref:`Backend Setting` before calling the function. +#. `inputs_to_native_shapes`_ : This wrapping function converts all :class:`ivy.Shape` instances in the arguments to their :class:`ivy.NativeShape` counterparts, based on the `Backend Setting`_ before calling the function. +#. `outputs_to_ivy_shapes`_ : This wrapping function converts all :class:`ivy.NativeShape` instances in the outputs to their :class:`ivy.Shape` counterparts, based on the `Backend Setting`_ before calling the function. #. `to_native_shapes_and_back`_ : This wrapping function converts all :class:`ivy.Shape` instances in the arguments to their :class:`ivy.NativeShape` counterparts, calls the function with those arguments and then converts the :class:`ivy.NativeShape` instances in the output back to :class:`ivy.Shape`. View Handling ^^^^^^^^^^^^^ -#. `handle_view`_ : This wrapping function performs view handling based on our :ref:`Views` policy. +#. `handle_view`_ : This wrapping function performs view handling based on our :ref:`overview/deep_dive/inplace_updates:Views` policy. #. `handle_view_indexing`_ : This wrapping function is aimed at handling views for indexing. -Exception Handling +Exception Handling ^^^^^^^^^^^^^^^^^^ -#.
`handle_exceptions`_ : This wrapping function helps in catching native exceptions and unifying them into `IvyException` or the relevant subclasses. More information can be found in the :ref:`Exception Handling` section. +#. `handle_exceptions`_ : This wrapping function helps in catching native exceptions and unifying them into `IvyException` or the relevant subclasses. More information can be found in the :ref:`overview/deep_dive/exception_handling:Exception Handling` section. -Miscellaneous Wrappers +Miscellaneous Wrappers ^^^^^^^^^^^^^^^^^^^^^^ -#. `handle_array_function`_ : This wrapping function enables :ref:`Integrating custom classes with Ivy` +#. `handle_array_function`_ : This wrapping function enables :ref:`overview/deep_dive/arrays:Integrating custom classes with Ivy`. #. `handle_complex_input`_ : This wrapping function enables handling of complex numbers. It introduces a keyword argument :code:`complex_mode`, which is used to choose the function's behaviour as per the wrapper's docstring. -When calling `_wrap_function`_ during :ref:`Backend Setting`, firstly the attributes of the functions are checked to get all the wrapping functions for a particular function. +When calling `_wrap_function`_ during `Backend Setting`_, first the attributes of the function are checked to get all the wrapping functions for a particular function. Then all the wrapping functions applicable to a function are used to wrap the function. Each of these topics and each associated piece of logic added by the various wrapper functions are covered in more detail in the next sections. diff --git a/docs/overview/deep_dive/gradients.rst b/docs/overview/deep_dive/gradients.rst index 5ec644aa6a015..d33d17a7d2e7d 100644 --- a/docs/overview/deep_dive/gradients.rst +++ b/docs/overview/deep_dive/gradients.rst @@ -7,7 +7,7 @@ Gradients Overview -------- -Gradients are a crucial aspect of all modern deep learning workflows. +Gradients are a crucial aspect of all modern deep learning workflows. Different frameworks provide different APIs for gradient computation and there were a few considerations to be made while building a unified gradients API in Ivy. There are a number of functions added in ivy to allow gradient computation, but we'll mainly focus on the most commonly used and the most general function :func:`ivy.execute_with_gradients`. This is because the other gradient functions such as :func:`ivy.value_and_grad` and :func:`ivy.grad` can be considered as providing a subset of the functionality that :func:`ivy.execute_with_gradients` provides. @@ -21,13 +21,13 @@ The :func:`ivy.execute_with_gradients` function signature Following is the pseudo function signature for the :func:`ivy.execute_with_gradients` function, .. code-block:: python - + def execute_with_gradients ( func : Callable, xs : Any arbitrary nest, xs_grad_idxs : Input indices, ret_grad_idxs : Output indices, - ) : + ) : return func_ret, grads The :code:`func` in the input can be any user-defined function that returns a single scalar or any arbitrary nest of scalars. @@ -36,13 +36,13 @@ By scalars, we are referring to zero-dimensional arrays. So for example, the following are some valid outputs by the :code:`func`, .. code-block:: python - + ivy.array(12.) - + # OR ivy.Container( - a=ivy.array(12.), + a=ivy.array(12.), b=ivy.Container( c=ivy.array(15.), d=ivy.array(32.)
@@ -74,8 +74,8 @@ An example using :func:`ivy.execute_with_gradients` xs = [x, y] ret, grads = ivy.execute_with_gradients( - func, - xs, + func, + xs, xs_grad_idxs=[[0]], ret_grad_idxs=[["a"]] ) @@ -126,7 +126,7 @@ Our policy on gradients * The gradient API is fully-functional in ivy. * There is no explicit variable class or any public-facing function for adding gradient support to an ivy.Array. * The gradient functions in ivy implicitly convert all arrays to support gradient computation before computing gradients and detach all arrays after computing gradients. -* We don't retain any previously tracked computations in arrays by frameworks like torch for e.g. +* We don't retain any previously tracked computations in arrays by frameworks like torch, for example. * This makes our gradient API unambiguous, flexible, and easy to debug. * Any framework-specific tracking of computations or variable classes should be handled in the corresponding frontends. diff --git a/docs/overview/deep_dive/inplace_updates.rst b/docs/overview/deep_dive/inplace_updates.rst index 2faaf7c9b3a77..42df6520a9668 100644 --- a/docs/overview/deep_dive/inplace_updates.rst +++ b/docs/overview/deep_dive/inplace_updates.rst @@ -9,7 +9,6 @@ Inplace Updates .. _`jax.numpy.tan`: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.tan.html?highlight=tan .. _`presence of this attribute`: https://github.com/unifyai/ivy/blob/8ded4a5fc13a278bcbf2d76d1fa58ab41f5797d0/ivy/func_wrapper.py#L341 .. _`by the backend function`: https://github.com/unifyai/ivy/blob/8ded4a5fc13a278bcbf2d76d1fa58ab41f5797d0/ivy/func_wrapper.py#L372 -.. _`by the wrapper`: https://github.com/unifyai/ivy/blob/8ded4a5fc13a278bcbf2d76d1fa58ab41f5797d0/ivy/func_wrapper.py#L377 .. _`handled by the wrapper`: https://github.com/unifyai/ivy/blob/8ded4a5fc13a278bcbf2d76d1fa58ab41f5797d0/ivy/func_wrapper.py#L373 .. _`_wrap_fn`: https://github.com/unifyai/ivy/blob/6497b8a3d6b0d8aac735a158cd03c8f98eb288c2/ivy/container/wrapping.py#L69 .. _`NON_WRAPPED_FUNCTIONS`: https://github.com/unifyai/ivy/blob/fdaea62380c9892e679eba37f26c14a7333013fe/ivy/func_wrapper.py#L9 @@ -17,8 +16,6 @@ Inplace Updates .. _`ivy.reshape`: https://github.com/unifyai/ivy/blob/633eb420c5006a0a17c238bfa794cf5b6add8598/ivy/functional/ivy/manipulation.py#L418 .. _`ivy.astype`: https://github.com/unifyai/ivy/blob/8482eb3fcadd0721f339a1a55c3f3b9f5c86d8ba/ivy/functional/ivy/data_type.py#L46 .. _`ivy.asarray`: https://github.com/unifyai/ivy/blob/8482eb3fcadd0721f339a1a55c3f3b9f5c86d8ba/ivy/functional/ivy/creation.py#L114 -.. _`wrapping`: -.. _`ivy.inplace_update`: https://github.com/unifyai/ivy/blob/3a21a6bef52b93989f2fa2fa90e3b0f08cc2eb1b/ivy/functional/ivy/general.py#L1137 .. _`repo`: https://github.com/unifyai/ivy .. _`discord`: https://discord.gg/sXyFF8tDtm .. _`inplace updates channel`: https://discord.com/channels/799879767196958751/1028681763947552778 @@ -28,7 +25,7 @@ Inplace Updates Inplace updates enable users to overwrite the contents of existing arrays with new data. This enables much more control over the memory-efficiency of the program, preventing old unused arrays from being kept in memory for any longer than is strictly necessary. -The function `ivy.inplace_update`_ enables explicit inplace updates. +The function :func:`ivy.inplace_update` enables explicit inplace updates. :func:`ivy.inplace_update` is a *primary* function, and the backend-specific implementations for each backend are presented below.
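At the user level, the behaviour being implemented is simply the following (a minimal usage sketch, assuming a torch backend):

.. code-block:: python

    # minimal usage sketch, assuming a torch backend
    import ivy

    ivy.set_backend("torch")
    x = ivy.array([0.0, 1.0, 2.0])
    y = ivy.array([3.0, 4.0, 5.0])
    ivy.inplace_update(x, y)  # x now holds [3., 4., 5.]; no new array is kept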
We also explain the rationale for why each implementation is the way it is, and the important differences. @@ -256,14 +253,14 @@ This could for example be the input array itself, but can also be any other arra All Ivy functions which return a single array should support inplace updates via the :code:`out` argument. The type hint of the :code:`out` argument is :code:`Optional[ivy.Array]`. However, as discussed above, if the function is *nestable* then :class:`ivy.Container` instances are also supported. -:class:`ivy.Container` is omitted from the type hint in such cases, as explained in the :ref:`Function Arguments` section. +:class:`ivy.Container` is omitted from the type hint in such cases, as explained in the `Function Arguments `_ section. When the :code:`out` argument is unspecified, then the return is simply provided in a newly created :class:`ivy.Array` (or :class:`ivy.Container` if *nestable*). However, when :code:`out` is specified, then the return is provided as an inplace update of the :code:`out` argument provided. This can for example be the same as the input to the function, resulting in a simple inplace update of the input. In the case of :class:`ivy.Array` return types, the :code:`out` argument is predominantly handled in `handle_out_argument`_. -As explained in the :ref:`Function Wrapping` section, this wrapping is applied to every function with the :code:`@handle_out_argument` decorator dynamically during `backend setting`_. +As explained in the `Function Wrapping `_ section, this wrapping is applied to every function with the :code:`@handle_out_argument` decorator dynamically during `backend setting`_. **Primary Functions** @@ -306,14 +303,14 @@ The implementations of :func:`ivy.tan` for each backend are as follows. **PyTorch** (includes :code:`support_native_out` attribute): .. code-block:: python - + def tan(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: x = _cast_for_unary_op(x) return torch.tan(x, out=out) tan.support_native_out = True -It's very important to ensure the :code:`support_native_out` attribute is not added to backend implementations that do not handle the :code:`out` argument, as the `presence of this attribute`_ dictates whether the argument should be handled `by the backend function`_ or `by the wrapper`_. +It's very important to ensure the :code:`support_native_out` attribute is not added to backend implementations that do not handle the :code:`out` argument, as the `presence of this attribute`_ dictates whether the argument should be handled `by the backend function`_ or `by the wrapper `_. This distinction only concerns how the inplace update is applied to the native array, which is operated upon directly by the backend. If :code:`out` is specified in an Ivy function, then an inplace update is always **also** performed on the :class:`ivy.Array` instance itself, which is how :code:`out` is provided to the function originally. @@ -351,10 +348,10 @@ Here we still have the :attr:`support_native_out` attribute since we want to tak However, in the :code:`else` statement, the last operation is :func:`torch.transpose` which does not support the :code:`out` argument, and so the native inplace update can't be performed by torch here. This is why we need to call :func:`ivy.inplace_update` explicitly here, to ensure the native inplace update is performed, as well as the :class:`ivy.Array` inplace update. 
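To make this dual update concrete, a minimal sketch of the :code:`out` argument at the Ivy API level is shown below (illustrative only).

.. code-block:: python

    # illustrative sketch of out-argument semantics at the Ivy API level
    x = ivy.array([0.0, 1.0])
    out = ivy.zeros((2,))
    ret = ivy.tan(x, out=out)
    # the result is written into `out` inplace, and the same array is returned
    assert ret is out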
-Another case where we need to use :func:`ivy.inplace_update`_ with a function that has :attr:`support_native_out` is for the example of the :code:`torch` backend implementation of the :func:`ivy.remainder` function +Another case where we need to use :func:`ivy.inplace_update` with a function that has :attr:`support_native_out` is the :code:`torch` backend implementation of the :func:`ivy.remainder` function: .. code-block:: python - + def remainder( x1: Union[float, torch.Tensor], x2: Union[float, torch.Tensor], @@ -435,7 +432,7 @@ Technically, this could be handled using the `handle_out_argument`_ wrapping, bu **Mixed Functions** -As explained in the :ref:`Function Types` section, *mixed* functions can effectively behave as either compositional or primary functions, depending on the backend that is selected. We must add the :code:`handle_out_argument` to the :code:`add_wrappers`key of +As explained in the `Function Types `_ section, *mixed* functions can effectively behave as either compositional or primary functions, depending on the backend that is selected. We must add the :code:`handle_out_argument` to the :code:`to_add` key of the :code:`mixed_backend_wrappers` attribute so that the decorator gets added to the primary implementation when the backend is set. Here's an `example`_ from the linear function. @@ -451,7 +448,7 @@ When :code:`copy` is not specified explicitly, then an inplace update is perform Setting :code:`copy=False` is equivalent to passing :code:`out=input_array`. If only one of :code:`copy` or :code:`out` is specified, then this specified argument is given priority. If both are specified, then priority is given to the more general :code:`out` argument. -As with the :code:`out` argument, the :code:`copy` argument is also handled `by the wrapper `_. +As with the :code:`out` argument, the :code:`copy` argument is also handled `by the wrapper `_. Views @@ -484,7 +481,7 @@ Here's a brief description of the additional attributes added to :class:`ivy.Arr #. PyTorch reference stack (:code:`._torch_view_refs`): Functional views referencing this array in its PyTorch base, only populated for original arrays or functional views. #. PyTorch manipulation cache (:code:`._torch_manipulation`): Tuple storing array or view and function which made the functional view, only populated for functional views -.. note:: +.. note:: Parts of an array's metadata, like :code:`stride`, are attributed to the low-level memory layout of arrays, while views in :code:`ivy` operate at a higher level of abstraction. As a result, :func:`ivy.strides` isn't guaranteed to produce an output reflective of the underlying memory layout if the :class:`ivy.Array` passed in is a view (or in other words has a :code:`_base`).
_`Array creation routines`: https://numpy.org/doc/stable/reference/routines.array-creation.html @@ -62,7 +61,7 @@ Therefore, in order to avoid this potential conflict: * You should ensure that the tests are passing before merging any frontend PRs. The only exception to this rule is if the test is failing due to a bug in the Ivy functional API, which does not need to be solved as part of the frontend task. -There will be some implicit discussion of the locations of frontend functions in these examples, however an explicit explanation of how to place a frontend function can be found in a sub-section of the Frontend APIs `open task`_. +There will be some implicit discussion of the locations of frontend functions in these examples; however, an explicit explanation of how to place a frontend function can be found in a sub-section of the Frontend APIs :ref:`open task `. **NOTE:** Type hints, docstrings, and examples are not required when working on frontend functions. @@ -73,7 +72,7 @@ There will be some implicit discussion of the locations of frontend functions in The native arrays of each framework have their own attributes and instance methods which differ from the attributes and instance methods of :class:`ivy.Array`. As such we have implemented framework-specific array classes: :class:`tf_frontend.Tensor`, :class:`torch_frontend.Tensor`, :class:`numpy_frontend.ndarray`, and :class:`jax_frontend.DeviceArray`. These classes simply wrap an :class:`ivy.Array`, which is stored in the :code:`ivy_array` attribute, and behave as closely as possible to the native framework array classes. -This is explained further in the `Classes and Instance Methods `_ section. +This is explained further in the :ref:`overview/deep_dive/ivy_frontends:Classes and Instance Methods` section. As we aim to replicate the frontend frameworks as closely as possible, all functions accept their frontend array class (as well as :class:`ivy.Array` and :class:`ivy.NativeArray`) and return a frontend array. However, since most logic in each function is handled by Ivy, the :class:`ivy.Array` must be extracted from any frontend array inputs. @@ -262,7 +261,7 @@ However, these functions are specified to have key-word only arguments and in so In order to tackle these variations in behaviour, the :code:`map_raw_ops_alias` decorator was designed to wrap the functions that exist in the TensorFlow namespace, thus reducing unnecessary re-implementations. .. code-block:: python - + # in ivy/functional/frontends/tensorflow/math.py @to_ivy_arrays_and_back def argmax(input, axis, output_type=None, name=None): @@ -318,7 +317,7 @@ Short Frontend Implementations Ideally, all frontend functions should call the equivalent Ivy function and only be one line long. This is mainly because compositional implementations are bound to be slower than direct backend implementation calls. -In case a frontend function is complex and there is no equivalent Ivy function to use, it is strongly advised to add that function to our Experimental API. To do so, you are invited to open a *Missing Function Suggestion* issue as described in the `Open Tasks `_ section. A member of our team will then review your issue, and if the proposed addition is deemed to be timely and sensible, we will add the function to the "Extend Ivy Functional API" `ToDo list issue `_. +In case a frontend function is complex and there is no equivalent Ivy function to use, it is strongly advised to add that function to our Experimental API.
To do so, you are invited to open a *Missing Function Suggestion* issue as described in the `Open Tasks <../contributing/open_tasks.rst>`_ section. A member of our team will then review your issue, and if the proposed addition is deemed to be timely and sensible, we will add the function to the "Extend Ivy Functional API" `ToDo list issue `_. If you would rather not wait around for a member of our team to review your suggestion, you can instead go straight ahead and add the frontend function as a heavy composition of the existing Ivy functions, with a :code:`#ToDo` comment included, explaining that this frontend implementation will be simplified when :func:`ivy.func_name` is added. @@ -348,7 +347,7 @@ The native TensorFlow function :func:`tf.reduce_logsumexp` does not have an equi Through compositions, we can easily meet the required input-output behaviour for the TensorFlow frontend function. -The entire workflow for extending the Ivy Frontends as an external contributor is explained in more detail in the `Open Tasks `_ section. +The entire workflow for extending the Ivy Frontends as an external contributor is explained in more detail in the :ref:`Open Tasks ` section. Unused Arguments ---------------- @@ -409,7 +408,7 @@ Classes and Instance Methods ---------------------------- Most frameworks include instance methods and special methods on their array class for common array processing functions, such as :func:`reshape`, :func:`expand_dims` and :func:`add`. -This simple design choice comes with many advantages, some of which are explained in our :ref:`Ivy Array` section. +This simple design choice comes with many advantages, some of which are explained in our `Ivy Array <../design/ivy_as_a_framework/ivy_array.rst>`_ section. **Important Note** Before implementing the instance method or special method, make sure that the regular function in the specific frontend is already implemented. @@ -516,7 +515,7 @@ For example, :class:`numpy.matrix` has an instance method of :meth:`any`: return any(self.A, axis=axis, out=out) We need to create these frontend array classes and all of their instance methods and also their special methods such that we are able to transpile code which makes use of these methods. -As explained in :ref:`Ivy as a Transpiler`, when transpiling code we first extract the computation graph in the source framework. +As explained in `Ivy as a Transpiler <../design/ivy_as_a_transpiler.rst>`_, when transpiling code we first extract the computation graph in the source framework. In the case of instance methods, we then replace each of the original instance methods in the extracted computation graph with these new instance methods defined in the Ivy frontend class. Frontend Data Type Promotion Rules diff --git a/docs/overview/deep_dive/ivy_frontends_tests.rst b/docs/overview/deep_dive/ivy_frontends_tests.rst index 5b27829d9e5ff..97a6f7a6ab2c2 100644 --- a/docs/overview/deep_dive/ivy_frontends_tests.rst +++ b/docs/overview/deep_dive/ivy_frontends_tests.rst @@ -1,17 +1,16 @@ Ivy Frontend Tests ================== -.. _`here`: https://unify.ai/docs/ivy/design/ivy_as_a_transpiler.html +.. _`here`: ../design/ivy_as_a_transpiler.rst .. _`ivy frontends tests channel`: https://discord.com/channels/799879767196958751/1028267758028337193 .. _`test ivy`: https://github.com/unifyai/ivy/tree/db9a22d96efd3820fb289e9997eb41dda6570868/ivy_tests/test_ivy .. _`test_frontend_function`: https://github.com/unifyai/ivy/blob/591ac37a664ebdf2ca50a5b0751a3a54ee9d5934/ivy_tests/test_ivy/helpers.py#L1047 .. 
_`discord`: https://discord.gg/sXyFF8tDtm -.. _`Function Wrapping`: https://unify.ai/docs/ivy/overview/deep_dive/function_wrapping.html -.. _`open task`: https://unify.ai/docs/ivy/overview/contributing/open_tasks.html -.. _`Ivy Tests`: https://unify.ai/docs/ivy/overview/deep_dive/ivy_tests.html +.. _`Function Wrapping`: function_wrapping.rst +.. _`open task`: ../contributing/open_tasks.rst +.. _`Ivy Tests`: ivy_tests.rst .. _`Function Testing Helpers`: https://github.com/unifyai/ivy/blob/bf0becd459004ae6cffeb3c38c02c94eab5b7721/ivy_tests/test_ivy/helpers/function_testing.py -.. _`CI Pipeline`: https://unify.ai/docs/ivy/overview/deep_dive/continuous_integration.html -.. _`setting up`: https://unify.ai/docs/ivy/compiler/setting_up.html#setting-up-testing +.. _`CI Pipeline`: continuous_integration.rst Introduction @@ -168,7 +167,7 @@ ivy.tan() **TensorFlow** .. code-block:: python - + # ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @handle_frontend_test( fn_tree="tensorflow.math.tan", @@ -625,9 +624,9 @@ for example, :code:`ndarray.__add__` would expect an array as input, despite the **Important Helper Functions** :func:`@handle_frontend_method` requires 3 keyword-only parameters: - - :code:`class_tree` A full path to the array class in **Ivy** namespace. + - :code:`class_tree` A full path to the array class in the **Ivy** namespace. - :code:`init_tree` A full path to the initialization function. - - :code:`method_name` The name of the method to test. + - :code:`method_name` The name of the method to test. :func:`helpers.test_frontend_method` is used to test frontend instance methods. It is used in the same way as :func:`helpers.test_frontend_function`. A few important arguments for this function are the following: - :code:`init_input_dtypes` Input dtypes of the arguments with which we are initializing the array. @@ -802,7 +801,7 @@ The CI Pipeline runs the entire collection of Frontend Tests for the frontend th You will need to make sure the Frontend Test is passing for each Ivy Frontend function you introduce/modify. If a test fails on the CI, you can see details about the failure under `Details -> Run Frontend Tests` as shown in `CI Pipeline`_. -You can also run the tests locally before making a PR. See the relevant `setting up`_ section for instructions on how to do so. +You can also run the tests locally before making a PR. See the relevant :ref:`overview/contributing/setting_up:Setting Up Testing in PyCharm` section for instructions on how to do so. Frontend Framework Testing Configuration ---------------------------------------- diff --git a/docs/overview/deep_dive/ivy_lint.rst b/docs/overview/deep_dive/ivy_lint.rst new file mode 100644 index 0000000000000..388e67bfd58b2 --- /dev/null +++ b/docs/overview/deep_dive/ivy_lint.rst @@ -0,0 +1,58 @@ +Ivy-Lint: Ivy's Custom Code Formatters +====================================== + +Overview +-------- + +``ivy-lint`` is a specialized suite of formatters crafted for the Ivy codebase. It addresses unique formatting requirements not catered to by standard Python formatters. While the suite currently highlights the ``FunctionOrderingFormatter``, we're continually expanding to include more formatters tailored to Ivy's needs. + +Existing Formatters +------------------- + +FunctionOrderingFormatter +~~~~~~~~~~~~~~~~~~~~~~~~~ + +This formatter ensures a standardized order of declarations within Python files, organizing functions, classes, and assignments based on a hierarchy designed for the Ivy codebase.
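+
+For instance, the intended ordering can be sketched as follows (an illustrative example of the target layout with hypothetical names, not the formatter's exact output):
+
+.. code-block:: python
+
+    # illustrative target layout: imports first, then assignments,
+    # then helper functions, then primary functions
+    import ivy
+
+    SOME_CONSTANT = 1.0
+
+    # --- Helpers --- #
+
+    def _helper(x):
+        return x + SOME_CONSTANT
+
+    # --- Main --- #
+
+    def public_fn(x):
+        return _helper(x)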
+ +**Purpose**: To bring a sense of uniformity and structure to the code files by sorting various Python declarations. + +**Target Files**: Specifically designed for frontends and tests. + +How the Formatter Works: +~~~~~~~~~~~~~~~~~~~~~~~~ + +1. **Header Management**: + - Removes pre-existing headers in the source code based on specific patterns. + +2. **Comments Handling**: + - Extracts code components along with their leading comments, ensuring that relevant comments are retained during the reordering process. + +3. **Dependency Handling**: + - Constructs dependency graphs to understand and maintain the relationships between classes and assignments. + +4. **Sorting Logic**: + - Prioritizes imports, followed by assignments based on certain dependencies, then classes, and finally functions. + - Preserves module-level docstrings at the top of the file. + - Organizes helper functions and primary functions into separate sections for clarity. + +5. **File Processing**: + - Processes files that align with certain patterns, rearranging their content as needed. + +Integration and Usage +--------------------- + +To get the best out of ``ivy-lint``, integrate it within a pre-commit hook. This ensures that whenever code changes are about to be committed, the suite checks and, if needed, formats the files to align with Ivy's standards. + +For comprehensive details on weaving ``ivy-lint`` into your development practices, kindly refer to our `formatting guide `_. + +Contribution +------------ + +We’re always thrilled to welcome contributions to ``ivy-lint``. If you're brimming with ideas for a new formatter or can enhance our existing ones, please connect with us either on our GitHub repository or our `discord `_ channel. + +Round Up +-------- + +``ivy-lint`` stands as a testament to Ivy's commitment to code clarity and uniformity. As the landscape of our needs shifts, we foresee further refining and expanding our suite of formatters. + +For all discussions or inquiries, you're always welcome on `discord `_ in the `formatting channel `_. diff --git a/docs/overview/deep_dive/ivy_tests.rst b/docs/overview/deep_dive/ivy_tests.rst index c151c7a7e8fb7..be78008687a57 100644 --- a/docs/overview/deep_dive/ivy_tests.rst +++ b/docs/overview/deep_dive/ivy_tests.rst @@ -13,7 +13,7 @@ Ivy Tests .. _`methods`: https://hypothesis.readthedocs.io/en/latest/data.html .. _`finfo`: https://github.com/unifyai/ivy/blob/d8f1ffe8ebf38fa75161c1a9459170e95f3c82b6/ivy/functional/ivy/data_type.py#L276 .. _`data generation`: https://github.com/unifyai/ivy/blob/7063bf4475b93f87a4a96ef26c56c2bd309a2338/ivy_tests/test_ivy/test_functional/test_core/test_dtype.py#L337 -.. _`Function Types`: https://unify.ai/docs/ivy/overview/deep_dive/function_types.html +.. _`Function Types`: function_types.rst .. _`test_default_int_dtype`: https://github.com/unifyai/ivy/blob/7063bf4475b93f87a4a96ef26c56c2bd309a2338/ivy_tests/test_ivy/test_functional/test_core/test_dtype.py#L835 .. _`sampled_from`: https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.sampled_from .. _`lists`: https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.lists @@ -53,9 +53,15 @@ Ivy Tests .. _`dtype_and_values`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py#L83 .. _`dtype_values_axis`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py#L235 .. 
_`array_values`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py#L543 -.. _`CI Pipeline`: https://unify.ai/docs/ivy/overview/deep_dive/continuous_integration.html -.. _`Setting Up Testing in PyCharm`: https://unify.ai/docs/ivy/overview/contributing/setting_up.html#setting-up-testing-in-pycharm -.. _`Setting up for Free`: https://unify.ai/docs/ivy/overview/contributing/setting_up.html#setting-up-for-free +.. _`array_dtypes`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py#L15 +.. _`array_bools`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py#L17 +.. _`reshape_shapes`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/hypothesis_helpers/general_helpers.py#L16 +.. _`get_axis`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/hypothesis_helpers/general_helpers.py#L178 +.. _`get_shape`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/hypothesis_helpers/general_helpers.py#L67 +.. _`get_bounds`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/hypothesis_helpers/general_helpers.py#L145 +.. _`subsets`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/hypothesis_helpers/general_helpers.py#L48 +.. _`num_positional_args`: https://github.com/unifyai/ivy/blob/e50f71e283313caa9737f3c284496022ac67b58b/ivy_tests/test_ivy/helpers/testing_helpers.py#L78 +.. _`CI Pipeline`: continuous_integration.rst .. _`Hypothesis docs`: https://hypothesis.readthedocs.io/en/latest/data.html#core-strategies On top of the Array API `test suite`_, which is included as a submodule mapped to the folder :code:`test_array_api`, there is also a collection of Ivy tests, located in subfolder `test_ivy`_. @@ -252,7 +258,7 @@ Writing Ivy Tests ^^^^^^^^^^^^^^^^^ As mentioned previously, testing Ivy functions needs a lot of pre-processing and post-processing, so using only the :code:`given` decorator would not be sufficient -to write an effective test, the following example describes how to implement a test for the function :code:`ivy.abs, using our test decorators and test helpers. +to write an effective test. The following example describes how to implement a test for the function :code:`ivy.abs`, using our test decorators and test helpers. .. code-block:: python @handle_test( @@ -306,7 +312,7 @@ One thing to note here is the :code:`test_flags` variable in the test function. The test flags can also be generated explicitly like this -: .. code-block:: python - + @handle_test( as_variable_flags = st.lists(st.booleans(), min_size = 1, max_size = 1), native_array_flags = st.lists(st.booleans(), min_size = 1, max_size = 1), @@ -485,7 +491,7 @@ Meaning if the input is to be treated as a container, at the same time, is it a The generated values are then passed to the array creation functions inside the test function as tuples. -9. `valid_axes`_ - This function generates valid axes for a given array dimension. +9. valid_axes - This function generates valid axes for a given array dimension. For example -: ..
code-block:: python @@ -561,7 +567,7 @@ This function should be used in places where the result doesn’t depend on the **Note** - Under the hood, **array_values** strategy is called if the data type is *integer*, and **none_or_list_of_floats** is called when the data type is *float*. -15. `get_probs`_ - This is used to generate a tuple containing two values. +15. get_probs - This is used to generate a tuple containing two values. The first is the *unnormalized probabilities* for all elements in a population, and the second is the *population size*. For example-: @@ -616,7 +622,7 @@ It would be helpful to keep in mind the following points while writing test -: Testing Partial Mixed Functions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -As explained in the :ref:`Function Types` section, partial mixed functions are a special type of mixed functions that either utilize the compositional implementation +As explained in the `Function Types `_ section, partial mixed functions are a special type of mixed functions that either utilize the compositional implementation or the primary implementation depending on some conditions on the input. Therefore, the data-types supported by partial mixed functions depend on which implementation will be used for the given input. For example, when :code:`function_supported_dtypes` is called with respect to `ivy.linear` with the torch backend, the following output is returned: .. code-block:: python {'compositional': ('float32', 'int8', 'uint8', 'float64', 'int16', 'int32', 'int64'), 'primary': ('bool', 'float32', 'int8', 'uint8', 'float64', 'int64', 'int16', 'int32')} As can be seen from the above output, the data-types supported will depend on the implementation used for the given input. It is for this reason that we need a slightly -different pipeline for testing partial mixed functions. Basically, while writing the strategies for the tests of these functions, we need to first determine which implementation +different pipeline for testing partial mixed functions. Basically, while writing the strategies for the tests of these functions, we need to first determine which implementation will be used and then based on that generate the data to test the function. Here's an example from the test of :code:`ivy.linear` function: @@ -678,8 +684,8 @@ will be used and then based on that generate the data to test the function. Here As can be seen from the above code, a boolean parameter :code:`mixed_fn_compos` is generated first to determine whether to generate test data for the compositional implementation or the primary one. When it is equal to :code:`True`, the relevant data for the compositional implementation should -be generated and when :code:`False`, data corresponding to the primary implementation should be generated. Another boolean, :code:`is_torch_backend` -is to be used to determine if the current backend is :code:`torch`. Then these booleans are used together in this :code:`if` condition: +be generated and when :code:`False`, data corresponding to the primary implementation should be generated. Another boolean, :code:`is_torch_backend`, +is used to determine whether the current backend is :code:`torch`. Then these booleans are used together in this :code:`if` condition: :code:`if is_torch_backend and not mixed_fn_compos` and :code:`weight_shape` is updated to be 2 dimensional because the torch backend implementation only supports 2 dimensional weights.
Notice that the parameter :code:`mixed_fn_compos` is also passed to the :code:`helpers.get_dtypes` and :code:`helpers.ints` functions so that the dtypes corresponding to the implementation to be tested are returned. In general, :code:`helpers.get_dtypes`, @@ -843,7 +849,7 @@ You will need to make sure the Ivy Test is passing for each Ivy function you int If a test fails on the CI, you can see details about the failure under `Details -> Run Ivy Tests` as shown in `CI Pipeline`_. You can also run the tests locally before making a PR. The instructions differ according to the IDE you are using. For -PyCharm and Visual Studio Code you can refer to the `Setting Up Testing in PyCharm`_ section and `Setting up for Free`_ +PyCharm and Visual Studio Code you can refer to the :ref:`overview/contributing/setting_up:Setting Up Testing in PyCharm` section and :ref:`overview/contributing/setting_up:Setting up for Free` section respectively. Re-Running Failed Ivy Tests diff --git a/docs/overview/deep_dive/navigating_the_code.rst b/docs/overview/deep_dive/navigating_the_code.rst index 29049c1bd6748..1bd9b0ae9dd7d 100644 --- a/docs/overview/deep_dive/navigating_the_code.rst +++ b/docs/overview/deep_dive/navigating_the_code.rst @@ -7,7 +7,6 @@ Navigating the Code .. _`navigating the code channel`: https://discord.com/channels/799879767196958751/982737793476345888 .. _`Array API Standard convention`: https://data-apis.org/array-api/2021.12/API_specification/array_object.html#api-specification-array-object--page-root .. _`flake8`: https://flake8.pycqa.org/en/latest/index.html -.. _`pre-commit guide`: https://unify.ai/docs/ivy/overview/contributing/setting_up.html#pre-commit Categorization -------------- diff --git a/docs/overview/deep_dive/superset_behaviour.rst b/docs/overview/deep_dive/superset_behaviour.rst index 446544c9fb936..5e232c7ceabd3 100644 --- a/docs/overview/deep_dive/superset_behaviour.rst +++ b/docs/overview/deep_dive/superset_behaviour.rst @@ -73,7 +73,7 @@ We explore this through the example of :func:`softplus`. **ivy.softplus** -When looking at the :func:`softplus` (or closest equivalent) implementations for `Ivy `_, `JAX `_, `TensorFlow `_, and `PyTorch `_, we can see that torch is the only framework which supports the inclusion of the :code:`beta` and :code:`threshold` arguments, which are added for improved numerical stability. +When looking at the :func:`softplus` (or closest equivalent) implementations for `Ivy <../../docs/functional/ivy/activations/ivy.functional.ivy.activations.softplus.rst>`_, `JAX `_, `TensorFlow `_, and `PyTorch `_, we can see that torch is the only framework which supports the inclusion of the :code:`beta` and :code:`threshold` arguments, which are added for improved numerical stability. We can also see that numpy does not support a :func:`softplus` function at all. Ivy should also support the :code:`beta` and :code:`threshold` arguments, in order to provide the generalized superset implementation among the backend frameworks. @@ -143,25 +143,25 @@ The first three examples are more-or-less superset examples, while the last exam **ivy.linspace** -When looking at the :func:`linspace` (or closest equivalent) implementations for `Ivy `_, `JAX `_, `NumPy `_, `TensorFlow `_, and `PyTorch `_, we can see that torch does not support arrays for the :code:`start` and :code:`end` arguments, while JAX, numpy, and tensorflow all do.
+When looking at the :func:`linspace` (or closest equivalent) implementations for `Ivy <../../docs/functional/ivy/creation/ivy.functional.ivy.creation.linspace.rst>`_, `JAX `_, `NumPy `_, `TensorFlow `_, and `PyTorch `_, we can see that torch does not support arrays for the :code:`start` and :code:`end` arguments, while JAX, numpy, and tensorflow all do. Likewise, Ivy also supports arrays for the :code:`start` and :code:`stop` arguments, and in doing so provides the generalized superset implementation among the backend frameworks. **ivy.eye** -When looking at the :func:`eye` (or closest equivalent) implementations for `Ivy `_, `JAX `_, `NumPy `_, `TensorFlow `_, and `PyTorch `_, we can see that tensorflow is the only framework which supports a :code:`batch_shape` argument. +When looking at the :func:`eye` (or closest equivalent) implementations for `Ivy <../../docs/functional/ivy/creation/ivy.functional.ivy.creation.eye.rst>`_, `JAX `_, `NumPy `_, `TensorFlow `_, and `PyTorch `_, we can see that tensorflow is the only framework which supports a :code:`batch_shape` argument. Likewise, Ivy also supports a :code:`batch_shape` argument, and in doing so provides the generalized superset implementation among the backend frameworks. **ivy.scatter_nd** -When looking at the :func:`scatter_nd` (or closest equivalent) implementations for `Ivy `_, `JAX `_, `NumPy `_, `TensorFlow `_, and `PyTorch `_, we can see that torch only supports scattering along a single dimension, while all other frameworks support scattering across multiple dimensions at once. +When looking at the :func:`scatter_nd` (or closest equivalent) implementations for `Ivy <../../docs/functional/ivy/general/ivy.functional.ivy.general.scatter_nd.rst>`_, `JAX `_, `NumPy `_, `TensorFlow `_, and `PyTorch `_, we can see that torch only supports scattering along a single dimension, while all other frameworks support scattering across multiple dimensions at once. Likewise, Ivy also supports scattering across multiple dimensions at once, and in doing so provides the generalized superset implementation among the backend frameworks. **ivy.logical_and** -When looking at the :func:`logical_and` (or closest equivalent) implementations for `Ivy `_, `JAX `_, `NumPy `_, `TensorFlow `_, and `PyTorch `_, we can see that numpy and torch support the :code:`out` argument for performing inplace updates, while JAX and tensorflow do not. +When looking at the :func:`logical_and` (or closest equivalent) implementations for `Ivy <../../docs/functional/ivy/elementwise/ivy.functional.ivy.elementwise.logical_and.rst>`_, `JAX `_, `NumPy `_, `TensorFlow `_, and `PyTorch `_, we can see that numpy and torch support the :code:`out` argument for performing inplace updates, while JAX and tensorflow do not. With regards to the supported data types, JAX, numpy, and torch support numeric arrays, while tensorflow supports only boolean arrays. With regards to both of these points, Ivy provides the generalized superset implementation among the backend frameworks, with support for the :code:`out` argument and also support for both numeric and boolean arrays in the input. @@ -174,8 +174,8 @@ Maximizing Usage of Native Functionality While achieving the objective of having superset behaviour across the backends, the native functionality of frameworks should be made use of as much as possible. 
Even if a framework-specific function doesn't provide complete superset behaviour, we should still make use of the partial behaviour that it provides and then add more logic for the remaining part. -This is for efficiency reasons and is more explained under the `Mixed Function `_ section. -In cases when a framework-specific function exists for one or two backends but not the others, we implement a `Mixed Function `_. +This is for efficiency reasons, and is explained in more detail under the :ref:`Mixed Function ` section. +In cases where a framework-specific function exists for one or two backends but not the others, we implement a :ref:`Mixed Function `. But when the framework-specific functions do not cover all superset functionality, Ivy also allows for a mixed-compositional hybrid approach. Consider the example of :func:`interpolate`. diff --git a/docs/overview/design.rst b/docs/overview/design.rst index d487262d5f4d2..ea32c66512596 100644 --- a/docs/overview/design.rst +++ b/docs/overview/design.rst @@ -1,11 +1,13 @@ Design ====== +.. _`Deep Dive`: deep_dive.rst + This section is aimed at general users, who would like to learn how to use Ivy, and are less concerned about how it all works under the hood 🔧 -The :ref:`Deep Dive` section is more targeted at potential contributors, and at users who would like to dive deeper into the weeds of the framework🌱, and gain a better understanding of what is actually going on behind the scenes 🎬 +The `Deep Dive`_ section is more targeted at potential contributors, and at users who would like to dive deeper into the weeds of the framework🌱, and gain a better understanding of what is actually going on behind the scenes 🎬 -If that sounds like you, feel free to check out the :ref:`Deep Dive` section after you've gone through the higher level overview which is covered in this *design* section! +If that sounds like you, feel free to check out the `Deep Dive`_ section after you've gone through the higher level overview which is covered in this *design* section! | So, starting off with our higher level *design* section, Ivy can fulfill two distinct purposes: | @@ -23,16 +25,16 @@ If that sounds like you, feel free to check out the :ref:`Deep Dive` section aft :align: center :width: 100% -| (a) :ref:`Building Blocks` +| (a) `Building Blocks `_ | back-end functional APIs ✅ | Ivy functional API ✅ | Framework Handler ✅ | Ivy Compiler 🚧 | -| (b) :ref:`Ivy as a Transpiler` +| (b) `Ivy as a Transpiler `_ | front-end functional APIs 🚧 | -| (c) :ref:`Ivy as a Framework` +| (c) `Ivy as a Framework `_ | Ivy stateful API ✅ | Ivy Container ✅ | Ivy Array ✅ .. toctree:: diff --git a/docs/overview/design/building_blocks.rst b/docs/overview/design/building_blocks.rst index 3b77f3b1d26aa..0d1e2b22250d8 100644 --- a/docs/overview/design/building_blocks.rst +++ b/docs/overview/design/building_blocks.rst @@ -1,8 +1,6 @@ Building Blocks =============== -.. _`out argument`: https://unify.ai/docs/ivy/overview/deep_dive/inplace_updates.html#out-argument - Here we explain the components of Ivy which are fundamental to its usage either as a code converter or as a fully-fledged framework-agnostic ML framework. These are the 4 parts labelled as (a) in the image below: @@ -73,7 +71,7 @@ There are separate backend modules for JAX, TensorFlow, PyTorch, and NumPy, and stack.support_native_out = True -There were no changes required for this function, however NumPy and PyTorch both had to be marked as supporting the `out argument`_ natively.
+There were no changes required for this function; however, NumPy and PyTorch both had to be marked as supporting the :ref:`overview/deep_dive/inplace_updates:out argument` natively. For more complicated functions, we need to do more than simply wrap and maybe change the name. For functions with differing behavior, we must modify the function to fit the unified in-out behavior of Ivy’s API. @@ -491,7 +489,7 @@ The example above further emphasizes that the graph compiler creates a computati Specifically, the same Ivy code compiles to different graphs depending on the selected backend. However, when compiling native framework code, we are only able to compile a graph for that same framework. For example, we cannot take torch code and compile this into tensorflow code. -However, we can transpile torch code into tensorflow code (see :ref:`Ivy as a Transpiler` for more details). +However, we can transpile torch code into tensorflow code (see `Ivy as a Transpiler `_ for more details). The graph compiler does not compile to C++, CUDA, or any other lower level language. It simply traces the backend functional methods in the graph, stores this graph, and then efficiently traverses this graph at execution time, all in Python. diff --git a/docs/overview/design/ivy_as_a_framework.rst b/docs/overview/design/ivy_as_a_framework.rst index 5d5e5cb73b0ec..bf1201048a94b 100644 --- a/docs/overview/design/ivy_as_a_framework.rst +++ b/docs/overview/design/ivy_as_a_framework.rst @@ -1,10 +1,10 @@ Ivy as a Framework ================== -On the :ref:`Building Blocks` page, we explored the role of the backend functional APIs, the Ivy functional API, the framework handler, and the graph compiler. +On the `Building Blocks `_ page, we explored the role of the backend functional APIs, the Ivy functional API, the framework handler, and the graph compiler. These are parts labeled as (a) in the image below. -On the :ref:`Ivy as a Transpiler` page, we explained the role of the backend-specific frontends in Ivy, and how these enable automatic code conversions between different ML frameworks. +On the `Ivy as a Transpiler `_ page, we explained the role of the backend-specific frontends in Ivy, and how these enable automatic code conversions between different ML frameworks. This part is labeled as (b) in the image below. So far, by considering parts (a) and (b), we have mainly treated Ivy as a fully functional framework with code conversion abilities. @@ -19,13 +19,13 @@ These parts are labeled as (c) in the image below. You may choose from the following upcoming discussions or click next. -| (a) :ref:`Ivy Container` +| (a) `Ivy Container `_ | Hierarchical container solving almost everything behind the scenes in Ivy | -| (b) :ref:`Ivy Stateful API` +| (b) `Ivy Stateful API `_ | Trainable Layers, Modules, Optimizers, and more built on the functional API and the Ivy Container | -| (c) :ref:`Ivy Array` +| (c) `Ivy Array `_ | Bringing methods as array attributes to Ivy, cleaning up and simplifying code .. toctree:: diff --git a/docs/overview/design/ivy_as_a_framework/ivy_array.rst b/docs/overview/design/ivy_as_a_framework/ivy_array.rst index cdce00d179a58..96d0ba2e6ef76 100644 --- a/docs/overview/design/ivy_as_a_framework/ivy_array.rst +++ b/docs/overview/design/ivy_as_a_framework/ivy_array.rst @@ -197,7 +197,7 @@ API Monkey Patching All ivy functions with array inputs/outputs have been wrapped to return :class:`ivy.Array` instances while accepting both :class:`ivy.Array` and :class:`ivy.NativeArray` instances.
This allows for the control required to provide a unified array interface. -For more details on wrapping, see the `Function Wrapping `_ page in deep dive. +For more details on wrapping, see the `Function Wrapping <../../deep_dive/function_wrapping.rst>`_ page in the deep dive section. Instance Methods diff --git a/docs/overview/design/ivy_as_a_framework/ivy_container.rst b/docs/overview/design/ivy_as_a_framework/ivy_container.rst index f566472a6147d..fad13c67f9f26 100644 --- a/docs/overview/design/ivy_as_a_framework/ivy_container.rst +++ b/docs/overview/design/ivy_as_a_framework/ivy_container.rst @@ -145,7 +145,7 @@ Or we can flip each sub-array: } } -There are about 200 such functions for the :class:`ivy.Container` class in total, check out the `code `_ or `docs `_ to see what they are! +There are about 200 such functions for the :class:`ivy.Container` class in total; check out the `code `_ or `docs <../../../docs/data_classes/data_classes/ivy.data_classes.container.rst>`_ to see what they are! Built-ins ---------- @@ -458,7 +458,7 @@ All nested structures above this height are truncated into single keys with a These are very useful methods when stepping through code and debugging complex nested structures such as the weights of a network. There are also methods: :code:`cont_with_print_limit` for controlling the printable size of arrays before the shape is instead displayed, :code:`cont_with_key_length_limit` for setting the maximum key length before string clipping, :code:`cont_with_print_indent` for controlling the nested indent, and many more. -Check out the `docs `_ for more details! +Check out the `docs <../../../docs/data_classes/data_classes/ivy.data_classes.container.rst>`_ for more details! Use Cases --------- diff --git a/docs/overview/design/ivy_as_a_framework/ivy_stateful_api.rst b/docs/overview/design/ivy_as_a_framework/ivy_stateful_api.rst index 2993a8f5279e2..3c6574b884d04 100644 --- a/docs/overview/design/ivy_as_a_framework/ivy_stateful_api.rst +++ b/docs/overview/design/ivy_as_a_framework/ivy_stateful_api.rst @@ -370,7 +370,7 @@ The actual implementation for the :class:`ivy.Linear` layer exposed in the Ivy s self.v.b if self._with_bias else None) The :class:`ivy.Initializer` class has a single abstract method, :code:`create_variables(var_shape, dev, fan_out=None, fan_in=None, *args, **kwargs)`. -Check out the `code `_ or `docs `_ for more details. +Check out the `code `_ or :ref:`docs ` for more details. The default initializer for the weights is :class:`ivy.GlorotUniform` and for the bias is :class:`ivy.Zeros`. Let’s take a quick look at what these look like. :class:`ivy.GlorotUniform` derives from a more general :class:`ivy.Uniform` initializer class, and is then simply implemented as follows: diff --git a/docs/overview/design/ivy_as_a_transpiler.rst b/docs/overview/design/ivy_as_a_transpiler.rst index 56467a8d6e53b..50dd33d747ada 100644 --- a/docs/overview/design/ivy_as_a_transpiler.rst +++ b/docs/overview/design/ivy_as_a_transpiler.rst @@ -1,7 +1,7 @@ Ivy as a Transpiler =================== -On the :ref:`Building Blocks` page, we explored the role of the backend functional APIs, the Ivy functional API, the backend handler, and the graph compiler. +On the `Building Blocks `_ page, we explored the role of the backend functional APIs, the Ivy functional API, the backend handler, and the graph compiler. These parts are labelled (a) in the image below.
Here, we explain the role of the backend-specific frontends in Ivy, and how these enable automatic code conversions between different ML frameworks. diff --git a/docs/overview/extensions.rst b/docs/overview/extensions.rst deleted file mode 100644 index dce2922e795d8..0000000000000 --- a/docs/overview/extensions.rst +++ /dev/null @@ -1,15 +0,0 @@ -Extensions -========== - -| (a) :ref:`Applied Libraries` ✅ -| Ivy libraries in mechanics, vision, robotics, memory and other areas -| -| (b) **Builder [page coming soon!]** ✅ -| :code:`ivy.Trainer`, :code:`ivy.Dataset`, :code:`ivy.Dataloader` and other helpful classes and functions for creating training workflows in only a few lines of code - -.. toctree:: - :hidden: - :maxdepth: -1 - :caption: Extensions - - extensions/applied_libraries.rst diff --git a/docs/overview/extensions/applied_libraries.rst b/docs/overview/extensions/applied_libraries.rst deleted file mode 100644 index 3798463466737..0000000000000 --- a/docs/overview/extensions/applied_libraries.rst +++ /dev/null @@ -1,113 +0,0 @@ -Applied Libraries -================= - -In other parts of the overview, we have focused on the the Ivy framework itself. -Here, we explore how Ivy has been used to create a suite of libraries in various fields related to ML. -Aside from being useful tools for ML developers in any framework, these libraries are a perfect showcase of what is possible using Ivy! - -Currently, there are Ivy libraries for: Mechanics, 3D Vision, Robotics, Gym Environments, and Differentiable Memory. -We run through some demos from these libraries now, and encourage you to pip install the libraries and run the demos yourself if you like what you see! - -Ivy Mechanics -------------- - -`Ivy Mechanics `_ provides functions for conversions of orientation, pose, and positional representations, as well as transformations, and some other more applied functions. -The orientation module is the largest, with conversions to and from all Euler conventions, quaternions, rotation matrices, rotation vectors, and axis-angle representations. - -For example, this demo shows the use of :code:`ivy_mech.target_facing_rotation_matrix`: - -.. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_mech/demo_a.gif?raw=true - :align: center - :width: 100% - -This demo shows the use of :code:`ivy_mech.polar_to_cartesian_coords`: - -.. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_mech/demo_b.gif?raw=true - :align: center - :width: 100% - -Ivy Vision ----------- - -`Ivy Vision `_ focuses predominantly on 3D vision, with functions for image projections, co-ordinate frame transformation, forward warping, inverse warping, optical flow, depth generation, voxel grids, point clouds, and others. - -For example, this demo shows the use of :code:`ivy_vision.coords_to_voxel_grid`: - -.. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_vision/voxel_grid_demo.gif?raw=true - :align: center - :width: 100% - -This demo shows the use of :code:`ivy_vision.render_pixel_coords`: - -.. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_vision/point_render_demo.gif?raw=true - :align: center - :width: 100% - -This demo shows Neural Radiance Fields (NeRF): - -.. 
image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_vision/nerf_demo.gif?raw=true - :align: center - :width: 100% - - -Ivy Robot ---------- - -`Ivy Robot `_ provides functions and classes for gradient-based trajectory optimization and motion planning. -Classes are provided both for mobile robots and robot manipulators. - -For example, this demo shows the use of :code:`ivy_robot.sample_spline_path` and :code:`ivy_robot.RigidMobile.sample_body` for gradient-based motion planning of a drone. - -.. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_robot/demo_a.gif?raw=true - :align: center - :width: 100% - -This demo shows the use of :code:`ivy_robot.sample_spline_path` and :code:`ivy_robot.Manipulator.sample_links` for gradient-based motion planning of a robot manipulator: - -.. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_robot/demo_b.gif?raw=true - :align: center - :width: 100% - -Ivy Gym -------- - -`Ivy Gym `_ provides differentiable implementations of the control environments provided by OpenAI Gym, as well as a new “Swimmer” task which illustrates the simplicity of creating new tasks. -The differentiable nature of the environments means that the cumulative reward can be directly optimized in a supervised manner, without the need for reinforcement learning. -Ivy Gym opens the door for intersectional research between supervised learning, trajectory optimization, and reinforcement learning. - -For example, we show demos of each of the environments :code:`cartpole`, :code:`mountain_car`, :code:`pendulum`, :code:`reacher`, and :code:`swimmer` solved using direct trajectory optimization below. -We optimize for a specific starting state of the environment: - -.. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_gym/demo_a.gif?raw=true - :align: center - :width: 100% - -We show demos of each of the environments :code:`cartpole`, :code:`mountain_car`, :code:`pendulum`, :code:`reacher`, and :code:`swimmer` solved using supervised learning via a policy network. -We train a policy which is conditioned on the environment state, and the starting state is then randomized between training steps: - -.. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_gym/demo_b.gif?raw=true - :align: center - :width: 100% - -Ivy Memory ----------- - -`Ivy Memory `_ provides differentiable memory modules, including learnt modules such as Neural Turing Machines (NTM), but also parameter-free modules such as End-to-End Egospheric Spatial Memory (ESM). - -For example, in this demo we learn to copy a sequence using :code:`ivy_memory.NTM`: - -.. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_memory/demo_a.gif?raw=true - :align: center - :width: 100% - -In this demo we create an egocentric 3D map of a room using :code:`ivy_memory.ESM`: - -.. image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/ivy_memory/demo_b.gif?raw=true - :align: center - :width: 100% - -**Round Up** - -Hopefully, this has given you an idea of what’s possible using Ivy’s collection of applied libraries, and more importantly, given you inspiration for what’s possible using Ivy 🙂 - -Please reach out on `discord `_ if you have any questions! 
diff --git a/docs/overview/faq.rst b/docs/overview/faq.rst index 9c47fb1e15ba1..e74df1b21dff7 100644 --- a/docs/overview/faq.rst +++ b/docs/overview/faq.rst @@ -4,7 +4,6 @@ FAQ .. _`dex`: https://github.com/dexidp/dex .. _`API for distributed training`: https://github.com/unifyai/ivy/blob/a2f37b1bae232b7ba5257e59f8b46a0374cca9f1/ivy/functional/ivy/device.py#L660 .. _`fully support these`: https://pytorch.org/tutorials/prototype/vmap_recipe.html -.. _`Ivy Builder`: https://github.com/unifyai/builder .. _`README`: https://github.com/unifyai/ivy These are some of the most common technical questions that continue to arise when we're discussing Ivy with developers in the community. @@ -53,7 +52,7 @@ For some backends, shape-checking will be performed during the compilation phase GPU handling ------------ -**Q:** How does Ivy handle GPU usage? +**Q:** How does Ivy handle GPU usage? **A:** Ivy handles GPU usage by simply wrapping the backend frameworks, so Ivy will use GPUs in the same manner as the backend framework does. E.g., when using a torch backend, torch will be a dependency of Ivy, and its handling of GPU functionalities will be inherited and extended upon by Ivy. @@ -156,7 +155,6 @@ The Pipeline **A:** We are not advocating replacing all code with Ivy. We would encourage users to continue using whatever data loaders they want to, and perhaps just use an Ivy model, or use Ivy to convert a model, or even just a single function from a library. -If users want to use Ivy more deeply, then they can use `Ivy Builder`_, which includes framework-agnostic abstract data loaders, trainers, and other higher level classes for composing full training pipelines. State ----- diff --git a/docs/overview/get_started.rst b/docs/overview/get_started.rst index 5e850df91bb20..9d891f143e5c6 100644 --- a/docs/overview/get_started.rst +++ b/docs/overview/get_started.rst @@ -3,8 +3,8 @@ Get Started .. - If you want to use **Ivy's compiler and transpiler**, make sure to follow the - `setting up instructions for the API key `_ + If you want to use **Ivy's compiler and transpiler**, make sure to follow the + :ref:`setting up instructions for the API key ` after installing Ivy! @@ -24,7 +24,7 @@ Keep in mind that this **won't** install any framework other than NumPy! Docker ------ -If you prefer to use containers, we also have pre-built Docker images with all the +If you prefer to use containers, we also have pre-built Docker images with all the supported frameworks and some relevant packages already installed, which you can pull from: .. code-block:: bash @@ -40,17 +40,54 @@ If you are working on a GPU device, you can pull from: Installing from source ---------------------- -You can also install Ivy from source if you want to take advantage of the latest +You can also install Ivy from source if you want to take advantage of the latest changes, but we can't ensure everything will work as expected! .. code-block:: bash git clone https://github.com/unifyai/ivy.git - cd ivy + cd ivy pip install --user -e . -If you are planning to contribute, you want to run the tests, or you are looking -for more in-depth instructions, it's probably best to check out -the `Contributing - Setting Up `_ page, +If you are planning to contribute, want to run the tests, or are looking +for more in-depth instructions, it's probably best to check out +the `Contributing - Setting Up `_ page, where OS-specific and IDE-specific instructions and video tutorials to install Ivy are available!
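+Once the editable install has finished, a quick sanity check confirms that Ivy is importable and working. This is just a minimal sketch; any backend you have installed can be used in place of NumPy:
+
+.. code-block:: python
+
+    import ivy
+
+    # NumPy ships as a dependency, so it is always a safe backend choice
+    ivy.set_backend("numpy")
+
+    x = ivy.array([1., 2., 3.])
+    print(ivy.sum(x))  # ivy.array(6.)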
+ + +Ivy's compiler and transpiler +----------------------------- + +To use Ivy's compiler and transpiler, you'll need an **API key**. We are starting to +grant pilot access to certain users, so you can `join the waitlist `_ +if you want to get one! + +Ivy Folder +~~~~~~~~~~ + +When importing Ivy for the first time, a ``.ivy`` folder will be created in your +working directory. If you want to keep this folder in a different location, +you can set an ``IVY_ROOT`` environment variable with the path of your ``.ivy`` folder. + +Setting Up the API key +~~~~~~~~~~~~~~~~~~~~~~ + +Once the ``.ivy`` folder has been created (either manually or automatically by +importing Ivy), you will have to paste your API key as the content of the ``key.pem`` file. +For reference, this would be equivalent to: + +.. code-block:: bash + + echo -n API_KEY > .ivy/key.pem + +Issues and Questions +~~~~~~~~~~~~~~~~~~~~ + +If you find any issue or bug while using the compiler and/or the transpiler, please +raise an `issue in GitHub `_ and add the ``compiler`` +or the ``transpiler`` label accordingly. A member of the team will get back to you ASAP! + +Otherwise, if you haven't found a bug but want to ask a question, suggest something, or get help +from the team directly, feel free to open a new post at the ``pilot-access`` forum in +`Ivy's discord server! `_ diff --git a/docs/overview/glossary.rst b/docs/overview/glossary.rst index 606dbf1c8a934..956654fd5758f 100644 --- a/docs/overview/glossary.rst +++ b/docs/overview/glossary.rst @@ -45,12 +45,6 @@ All of these new words can get confusing! We've created a glossary to help nail Automatic Code Conversions Allows code to be converted from one framework to another whilst retaining its functional assets. - Applied Libraries - Suite of various machine learning libraries that have been built using the Ivy framework. - - Ivy Builder - Helpful classes and functions for creating training workflows. - Primary Functions Primary functions are the lowest level building blocks in Ivy and are generally implemented as light wrapping around an existing function in the backend framework, which serves a near-identical purpose. diff --git a/docs/overview/background.rst b/docs/overview/motivation.rst similarity index 51% rename from docs/overview/background.rst rename to docs/overview/motivation.rst index d203a51b18c52..a6f9278225af7 100644 --- a/docs/overview/background.rst +++ b/docs/overview/motivation.rst @@ -1,13 +1,13 @@ -Background +Motivation ========== -| (a) :ref:`ML Explosion` +| (a) `ML Explosion `_ | A huge number of ML tools have exploded onto the scene! | -| (b) :ref:`Why Unify?` +| (b) `Why Unify? `_ | Why should we try to unify them? | -| (c) :ref:`Standardization` +| (c) `Standardization `_ | We’re collaborating with The `Consortium for Python Data API Standards `_ .. 
toctree:: @@ -15,6 +15,6 @@ Background :maxdepth: -1 :caption: Background - background/ml_explosion.rst - background/why_unify.rst - background/standardization.rst + motivation/ml_explosion.rst + motivation/why_unify.rst + motivation/standardization.rst diff --git a/docs/overview/background/ml_explosion.rst b/docs/overview/motivation/ml_explosion.rst similarity index 100% rename from docs/overview/background/ml_explosion.rst rename to docs/overview/motivation/ml_explosion.rst diff --git a/docs/overview/background/standardization.rst b/docs/overview/motivation/standardization.rst similarity index 100% rename from docs/overview/background/standardization.rst rename to docs/overview/motivation/standardization.rst diff --git a/docs/overview/background/why_unify.rst b/docs/overview/motivation/why_unify.rst similarity index 100% rename from docs/overview/background/why_unify.rst rename to docs/overview/motivation/why_unify.rst diff --git a/docs/overview/one_liners.rst b/docs/overview/one_liners.rst new file mode 100644 index 0000000000000..0b11527b0b132 --- /dev/null +++ b/docs/overview/one_liners.rst @@ -0,0 +1,30 @@ +One liners +---------- + +.. grid:: 1 1 3 3 + :gutter: 4 + + .. grid-item-card:: ``ivy.compile()`` + :link: one_liners/compile.rst + + Compiles a ``Callable`` or set of them into an Ivy graph. + + .. grid-item-card:: ``ivy.transpile()`` + :link: one_liners/transpile.rst + + Transpiles a ``Callable`` or set of them from a ``source`` framework to another + framework. + + .. grid-item-card:: ``ivy.unify()`` + :link: one_liners/unify.rst + + Transpiles an object into Ivy code. It's an alias to + ``ivy.transpile(..., to="ivy", ...)`` + +.. toctree:: + :hidden: + :maxdepth: -1 + + one_liners/compile.rst + one_liners/transpile.rst + one_liners/unify.rst diff --git a/docs/compiler/compiler.rst b/docs/overview/one_liners/compile.rst similarity index 89% rename from docs/compiler/compiler.rst rename to docs/overview/one_liners/compile.rst index f6f0b5ab2a9e4..98d3cfd826a3a 100644 --- a/docs/compiler/compiler.rst +++ b/docs/overview/one_liners/compile.rst @@ -1,34 +1,34 @@ -Graph Compiler -============== +``ivy.compile()`` +================= .. ⚠️ **Warning**: The compiler and the transpiler are not publicly available yet, so certain parts of this doc won't work as expected as of now! -When we call an Ivy function, there is always a small performance hit due to added -Python wrapping. This overhead becomes increasingly noticeable when we use large -models with multiple function calls. The Graph Compiler improves the performance of -Ivy by removing the extra wrapping around each function call. +When we call an Ivy function, there is always a small performance hit due to added +Python wrapping. This overhead becomes increasingly noticeable when we use large +models with multiple function calls. The Graph Compiler improves the performance of +Ivy by removing the extra wrapping around each function call. 
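+As a rough illustration of this overhead, you can time a function before and after compilation. This is a minimal sketch, assuming a ``torch`` backend and access to the compiler; absolute numbers will vary across machines:
+
+.. code-block:: python
+
+    import time
+    import ivy
+
+    ivy.set_backend("torch")
+    x = ivy.random_uniform(shape=(1000, 1000))
+
+    def fn(x):
+        return ivy.sum(ivy.multiply(x, x))
+
+    # eager compilation, since example args are provided
+    comp_fn = ivy.compile(fn, args=(x,))
+
+    for f in (fn, comp_fn):
+        start = time.perf_counter()
+        for _ in range(100):
+            f(x)
+        print(f, time.perf_counter() - start)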
-The Graph Compiler takes in any Ivy function, framework-specific (backend) function, -or composition of both, and produces a simplified executable computation graph composed +The Graph Compiler takes in any Ivy function, framework-specific (backend) function, +or composition of both, and produces a simplified executable computation graph composed of functions from the backend functional API only, which results in: -- Simplified code: The Graph Compiler simplifies the code by removing all the wrapping +- Simplified code: The Graph Compiler simplifies the code by removing all the wrapping and functions that don't contribute to the output: print statements, loggers, etc. -- Improved performance: The compiled graph has no performance overhead due to Ivy's - function wrapping, likewise, redundant operations from the original function are also +- Improved performance: The compiled graph has no performance overhead due to Ivy's + function wrapping; likewise, redundant operations from the original function are also removed, increasing its overall performance. Compiler API ------------ .. py:function:: ivy.compile(*objs, stateful = None, arg_stateful_idxs = None, kwarg_stateful_idxs = None, to = None, include_generators = True, array_caching = True, return_backend_compiled_fn = False, static_argnums = None, static_argnames = None, args = None, kwargs = None,) - - Compiles a ``Callable`` or set of them into an Ivy graph. If ``args`` or ``kwargs`` are specified, + + Compiles a ``Callable`` or set of them into an Ivy graph. If ``args`` or ``kwargs`` are specified, compilation is performed eagerly; otherwise, compilation will happen lazily. - + :param objs: Callable(s) to compile and create a graph of. :type objs: ``Callable`` :param stateful: List of instances to be considered stateful during the graph compilation. @@ -91,12 +91,12 @@ In this case, the compiled graph would be: From the graph, we can observe that: 1. As ``x`` and ``y`` are the only variables used when calculating the returned value ``z``, - the non-contributing variable(s), ``k`` was not included in the graph. Function calls that + the non-contributing variable ``k`` was not included in the graph. Function calls that don't contribute to the output, like the ``print`` function, were also excluded. -2. As we set the backend to ``torch`` during the compilation process, the compiled +2. As we set the backend to ``torch`` during the compilation process, the compiled functions are torch functions, and the input and output types are torch tensors. -3. The tensor shape in the graph only indicates the shape of the inputs the graph was - traced with. The compiler doesn't impose additional restrictions on the shape or +3. The tensor shape in the graph only indicates the shape of the inputs the graph was + traced with. The compiler doesn't impose additional restrictions on the shape or datatype of the input array(s). .. code-block:: python @@ -114,17 +114,17 @@ From the graph, we can observe that: Eager vs lazy Compilation ~~~~~~~~~~~~~~~~~~~~~~~~~ -The graph compiler runs the original function under the hood and tracks its computation -to create the compiled graph. The **eager compilation** method traces the graph in the -corresponding function call with the specified inputs before we use the compiled +The graph compiler runs the original function under the hood and tracks its computation +to create the compiled graph.
The **eager compilation** method traces the graph in the +corresponding function call with the specified inputs before we use the compiled function. -Instead of compiling functions before using them, Ivy also allows you to compile the -function dynamically. This can be done by passing only the function to the -compile method and not including the function arguments. In this case, the output will be a -``LazyGraph`` instead of a ``Graph`` instance. When this ``LazyGraph`` object is first invoked with -function arguments, it compiles the function and returns the output of the compiled -function. Once the graph has been initialized, calls to the ``LazyGraph`` object will +Instead of compiling functions before using them, Ivy also allows you to compile the +function dynamically. This can be done by passing only the function to the +compile method and not including the function arguments. In this case, the output will be a +``LazyGraph`` instead of a ``Graph`` instance. When this ``LazyGraph`` object is first invoked with +function arguments, it compiles the function and returns the output of the compiled +function. Once the graph has been initialized, calls to the ``LazyGraph`` object will use the compiled function to compute the outputs directly. .. code-block:: python @@ -138,18 +138,18 @@ use the compiled function to compute the outputs directly. # Compile and return the output out = lazy_graph(x, y) -To sum up, lazy compilation enables you to delay the compilation process until you have -the necessary inputs during execution. This is particularly useful in cases like -compiling libraries, where it’s not feasible to provide valid arguments for every +To sum up, lazy compilation enables you to delay the compilation process until you have +the necessary inputs during execution. This is particularly useful in cases like +compiling libraries, where it’s not feasible to provide valid arguments for every function call. -Now let's look at additional functionalities that you can find in the +Now let's look at additional functionalities that you can find in the compiler. Array caching ~~~~~~~~~~~~~ -The compiler is able to cache constant arrays and their operations through the +The compiler is able to cache constant arrays and their operations through the ``array_caching`` flag, reducing computation time after compilation. .. code-block:: python @@ -166,15 +166,15 @@ The compiler is able to cache constant arrays and their operations through the comp_func = ivy.compile(fn, args=(x,)) -When calling ``ivy.compile()``, the ``array_caching`` argument is set to ``True`` by +When calling ``ivy.compile()``, the ``array_caching`` argument is set to ``True`` by default, which returns the following graph. .. image:: https://raw.githubusercontent.com/unifyai/unifyai.github.io/main/img/externally_linked/compiler/figure2.png -This shows that by caching the constant operation in the graph, a simpler graph can be -obtained. However, if desired, this argument can be set to ``False``, which results in the -graph below. This ultimately results in a trade-off between time and memory, as -cached results need to be stored in memory but if they are not cached these operations +This shows that by caching the constant operation in the graph, a simpler graph can be +obtained. However, if desired, this argument can be set to ``False``, which results in the +graph below. 
This ultimately results in a trade-off between time and memory, as +cached results need to be stored in memory, but if they are not cached, these operations need to be performed. .. image:: https://raw.githubusercontent.com/unifyai/unifyai.github.io/main/img/externally_linked/compiler/figure3.png @@ -195,7 +195,7 @@ are included as nodes or "baked" into the graph. a = torch.randint(0, 100, size=[1]) z = x ** a return z + torch.rand([1]) - + comp_func = ivy.compile(fn, include_generators=True, args=(x,)) Returns: @@ -224,7 +224,7 @@ Returns: Stateful ~~~~~~~~ -Finally, you can also track ``__setattr__`` and ``__getattr__`` methods of +Finally, you can also track ``__setattr__`` and ``__getattr__`` methods of arbitrary classes using the ``stateful`` parameters. .. code-block:: python @@ -248,34 +248,34 @@ arbitrary classes using the ``stateful`` parameters. Sharp bits ---------- -As some parts of the graph compiler are still under development, there are some sharp -bits to take into account when using it. All of these points are WIP, so they'll be +As some parts of the graph compiler are still under development, there are some sharp +bits to take into account when using it. All of these points are WIP, so they'll be removed soon! -1. **Dynamic control flow**: The compiled graph is built using function tracing at the - moment, so dynamic control flow such as conditional branches or conditional loops - will not be registered correctly. As an example, if there is a while loop in your - code that depends on a changing value, the number of iterations in the final graph - will be the same as the number of iterations performed with the input passed to the +1. **Dynamic control flow**: The compiled graph is built using function tracing at the + moment, so dynamic control flow such as conditional branches or conditional loops + will not be registered correctly. As an example, if there is a while loop in your + code that depends on a changing value, the number of iterations in the final graph + will be the same as the number of iterations performed with the input passed to the compile function. -2. **Non-framework-specific code**: As the compiler traces the function using the - functional API of the underlying framework, any piece of code inside the model that - is not from the said framework will not be correctly registered, this includes other - frameworks code (such as NumPy statements inside a torch model) or python statements +2. **Non-framework-specific code**: As the compiler traces the function using the + functional API of the underlying framework, any piece of code inside the model that + is not from the said framework will not be correctly registered; this includes other + frameworks' code (such as NumPy statements inside a torch model) or Python statements such as len(). -3. **Incorrectly cached parts of the graph**: There are certain cases where compilation can succeed but hide some cached parts of the graph which shouldn't really be cached. - To check this, it's recommended to compile with a noise array of the same shape and + To check this, it's recommended to compile with a noise array of the same shape and then check if the output of the original function and the compiled graph with another - input is the same.
If you find out that the graph is not right, feel free to open an + `issue `_ with a minimal example and we'll look into it! Examples -------- -Below, we compile a ResNet50 model from -`Hugging Face `_ and use it to classify the +Below, we compile a ResNet50 model from +`Hugging Face `_ and use it to classify the breed of a cat. .. code-block:: python diff --git a/docs/compiler/transpiler.rst b/docs/overview/one_liners/transpile.rst similarity index 76% rename from docs/compiler/transpiler.rst rename to docs/overview/one_liners/transpile.rst index 92e55711364b9..701be359e3165 100644 --- a/docs/compiler/transpiler.rst +++ b/docs/overview/one_liners/transpile.rst @@ -1,32 +1,32 @@ -Transpiler -========== +``ivy.transpile()`` +=================== .. ⚠️ **Warning**: The compiler and the transpiler are not publicly available yet, so certain parts of this doc won't work as expected as of now! -Ivy's Transpiler converts a function written in any framework into your framework of -choice, preserving all the logic between frameworks. +Ivy's Transpiler converts a function written in any framework into your framework of +choice, preserving all the logic between frameworks. As the output of transpilation is native code in the target framework, it -can be used as if it was originally developed in that framework, +can be used as if it was originally developed in that framework, allowing you to apply and use framework-specific optimizations or tools. -This makes all ML-related projects available to you, independently of the framework you +This makes all ML-related projects available to you, independently of the framework you want to use to research, develop, or deploy systems. So if you want to: -- Use functions and building blocks like neural networks, layers, activations, and - training pipelines published in other frameworks. Ex: Using Haiku modules in PyTorch to +- Use functions and building blocks like neural networks, layers, activations, and + training pipelines published in other frameworks. Ex: Using Haiku modules in PyTorch to get access to the latest model. -- Integrate code developed in other frameworks into your code. Ex: Use the Kornia +- Integrate code developed in other frameworks into your code. Ex: Use the Kornia library in Jax for extra performance. -- Take advantage of specific features in other frameworks. Ex: Convert Jax code to Tensorflow for deployment. +- Take advantage of specific features in other frameworks. Ex: Convert Jax code to Tensorflow for deployment. Ivy's Transpiler is definitely the tool for the job 🔧 -To convert the code, it traces a computational graph using the Graph Compiler and -leverages Ivy's frontends and backends to link one framework to another. After swapping -each function node in the computational graph with their equivalent Ivy frontend +To convert the code, it traces a computational graph using the Graph Compiler and +leverages Ivy's frontends and backends to link one framework to another. After swapping +each function node in the computational graph with their equivalent Ivy frontend functions, the compiler removes all the wrapping in the frontends and replaces them with the native functions of the target framework. @@ -36,7 +36,7 @@ Transpiler API .. py:function:: ivy.transpile(*objs, source = None, to = None, debug_mode = False, args = None, kwargs = None, params_v = None,) - Transpiles a ``Callable`` or set of them from a ``source`` framework to another framework.
If ``args`` or ``kwargs`` are specified, + Transpiles a ``Callable`` or set of them from a ``source`` framework to another framework. If ``args`` or ``kwargs`` are specified, transpilation is performed eagerly, otherwise, transpilation will happen lazily. :param objs: Native callable(s) to transpile. @@ -58,33 +58,15 @@ Transpiler API :rtype: ``Union[Graph, LazyGraph, ModuleType, ivy.Module, torch.nn.Module, tf.keras.Model, hk.Module]`` :return: A transpiled ``Graph`` or a non-initialized ``LazyGraph``. If the object is a native trainable module, the corresponding module in the target framework will be returned. If the object is a ``ModuleType``, the function will return a copy of the module with every method lazily transpiled. -.. py:function:: ivy.unify(*objs, source = None, args = None, kwargs = None, **transpile_kwargs,) - - Transpiles an object into Ivy code. It's an alias to - ``ivy.transpile(..., to="ivy", ...)`` - - :param objs: Native callable(s) to transpile. - :type objs: ``Callable`` - :param source: The framework that ``obj`` is from. This must be provided unless ``obj`` is a framework-specific module. - :type source: ``Optional[str]`` - :param args: If specified, arguments that will be used to unify eagerly. - :type args: ``Optional[Tuple]`` - :param kwargs: If specified, keyword arguments that will be used to unify eagerly. - :type kwargs: ``Optional[dict]`` - :param transpile_kwargs: Arbitrary keyword arguments that will be passed to ``ivy.transpile``. - - :rtype: ``Union[Graph, LazyGraph, ModuleType, ivy.Module]`` - :return: A transpiled ``Graph`` or a non-initialized ``LazyGraph``. If the object is a native trainable module, the corresponding module in the target framework will be returned. If the object is a ``ModuleType``, the function will return a copy of the module with every method lazily transpiled. - Using the transpiler -------------------- Similar to the ``ivy.compile`` function, ``ivy.unify`` and ``ivy.transpile`` can be used -eagerly and lazily. If you pass the necessary arguments, the function will be called -instantly, otherwise, transpilation will happen the first time you invoke the function -with the proper arguments. +eagerly and lazily. If you pass the necessary arguments, the function will be called +instantly, otherwise, transpilation will happen the first time you invoke the function +with the proper arguments. -In both cases, arguments or keyword arguments can be arrays from +In both cases, arguments or keyword arguments can be arrays from either the ``source`` framework or the target (``to``) framework. Transpiling functions @@ -111,7 +93,7 @@ a small JAX function to Torch both eagerly and lazily. ret = eager_graph(x1) # Arguments are not available -> transpilation happens lazily - lazy_graph = ivy.transpile(test_fn, source="jax", to="torch") + lazy_graph = ivy.transpile(test_fn, source="jax", to="torch") # The transpiled graph is initialized, transpilation will happen here ret = lazy_graph(x1) @@ -122,7 +104,7 @@ a small JAX function to Torch both eagerly and lazily. Transpiling Libraries ~~~~~~~~~~~~~~~~~~~~~ -Likewise, you can use ``ivy.transpile`` to convert entire libraries and modules with just one line of +Likewise, you can use ``ivy.transpile`` to convert entire libraries and modules with just one line of code! .. code-block:: python @@ -150,8 +132,8 @@ code! 
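+For instance, a library-level transpilation can be sketched as follows (assuming ``kornia`` as the source library here; any framework-specific module can be passed in the same way):
+
+.. code-block:: python
+
+    import ivy
+    import kornia
+
+    # every function in the module is transpiled lazily, on its first call
+    jax_kornia = ivy.transpile(kornia, source="torch", to="jax")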
Transpiling Modules ~~~~~~~~~~~~~~~~~~~ -Last but not least, Ivy can also transpile trainable modules from one framework to -another, at the moment we support ``torch.nn.Module`` when ``to="torch"``, +Last but not least, Ivy can also transpile trainable modules from one framework to +another; at the moment we support ``torch.nn.Module`` when ``to="torch"``, ``tf.keras.Model`` when ``to="tensorflow"``, and haiku models when ``to="jax"``. .. code-block:: @@ -193,61 +175,41 @@ another, at the moment we support ``torch.nn.Module`` when ``to="torch"``, ret = forward_classifier.apply(params, None, x) -Ivy.unify -~~~~~~~~~ - -As mentioned above, ``ivy.unify`` is an alias for transpilation to Ivy, so you can use it -exactly in the same way to convert framework specific code to Ivy. - -.. code-block:: python - - import ivy - ivy.set_backend("jax") - - def test_fn(x): - return jax.numpy.sum(x) - - x1 = ivy.array([1., 2.]) - - # transpiled_func and unified_func will have the same result - transpiled_func = ivy.transpile(test_fn, to="ivy", args=(x1,)) - unified_func = ivy.unify(test_fn, args=(x1,)) - Sharp bits ---------- -In a similar fashion to the compiler, the transpiler is under development and we are +In a similar fashion to the compiler, the transpiler is under development and we are still working on some rough edges. These include: -1. **Keras model subclassing**: If a model is transpiled to keras, the resulting - ``tf.keras.Model`` can not be used within a keras sequential model at the moment. If - you want to use the transpiled model as part of a more complex keras model, you can - `create a Model subclass - `_. +1. **Keras model subclassing**: If a model is transpiled to keras, the resulting + ``tf.keras.Model`` cannot be used within a keras sequential model at the moment. If + you want to use the transpiled model as part of a more complex keras model, you can + `create a Model subclass + `_. Due to this, any training of a keras model should be done using a TensorFlow training pipeline instead of the keras utils. -2. **Keras arguments**: Keras models require at least an argument to be passed, so if a - model from another framework that only takes ``kwargs`` is transpiled to keras, - you'll need to pass a ``None`` argument to the transpiled model before the +2. **Keras arguments**: Keras models require at least one argument to be passed, so if a + model from another framework that only takes ``kwargs`` is transpiled to keras, + you'll need to pass a ``None`` argument to the transpiled model before the corresponding ``kwargs``. -3. **Haiku transform with state**: As of now, we only support the transpilation of - transformed Haiku modules, this means that ``transformed_with_state`` objects will +3. **Haiku transform with state**: As of now, we only support the transpilation of + transformed Haiku modules; this means that ``transformed_with_state`` objects will not be correctly transpiled. -4. **Array format between frameworks**: As the compiler outputs a 1-to-1 mapping of the - compiled function, the format of the tensors is preserved when transpiling from a - framework to another.
As an example, if you transpile a convolutional block from PyTorch (which uses ``N, C, H, W``) to TensorFlow (which uses ``N, H, W, C``) and want - to use it as part of a bigger (TensorFlow) model, you'll need to include a permute statement for - the inference to be correct. + to use it as part of a bigger (TensorFlow) model, you'll need to include a permute statement for + the inference to be correct. -Keep in mind that the transpiler uses the graph compiler under the hood, so the -`sharp bits of the compiler `_ +Keep in mind that the transpiler uses the graph compiler under the hood, so the +:ref:`sharp bits of the compiler ` apply here as well! Examples -------- -Here, we are transpiling a HF model from torch to tensorflow and then using the +Here, we are transpiling an HF model from torch to tensorflow and then using the resulting model with tensorflow tensors directly: .. code-block:: python diff --git a/docs/overview/one_liners/unify.rst b/docs/overview/one_liners/unify.rst new file mode 100644 index 0000000000000..a07ac2fbf5b40 --- /dev/null +++ b/docs/overview/one_liners/unify.rst @@ -0,0 +1,107 @@ +``ivy.unify()`` +================ + +.. + + ⚠️ **Warning**: The compiler and the transpiler are not publicly available yet, so certain parts of this doc won't work as expected as of now! + +Ivy's Unify function is an alias for ``ivy.transpile(..., to="ivy", ...)``. You can learn +more about the transpiler on the `transpile() `_ page. + +Unify API +--------- + +.. py:function:: ivy.unify(*objs, source = None, args = None, kwargs = None, **transpile_kwargs,) + + Transpiles an object into Ivy code. It's an alias to + ``ivy.transpile(..., to="ivy", ...)`` + + :param objs: Native callable(s) to transpile. + :type objs: ``Callable`` + :param source: The framework that ``obj`` is from. This must be provided unless ``obj`` is a framework-specific module. + :type source: ``Optional[str]`` + :param args: If specified, arguments that will be used to unify eagerly. + :type args: ``Optional[Tuple]`` + :param kwargs: If specified, keyword arguments that will be used to unify eagerly. + :type kwargs: ``Optional[dict]`` + :param transpile_kwargs: Arbitrary keyword arguments that will be passed to ``ivy.transpile``. + + :rtype: ``Union[Graph, LazyGraph, ModuleType, ivy.Module]`` + :return: A transpiled ``Graph`` or a non-initialized ``LazyGraph``. If the object is a native trainable module, the corresponding module in the target framework will be returned. If the object is a ``ModuleType``, the function will return a copy of the module with every method lazily transpiled. + +Usage +----- + +As we mentioned, ``ivy.unify()`` is an alias for ``ivy.transpile(..., to="ivy", ...)``. +So you can use it in the same way as ``ivy.transpile()``. In this case, instead of +getting a graph composed of functions from the functional API of the target framework, +the function will return a graph fully composed of ivy functions, allowing you to run +the graph in any framework directly. + +.. code-block:: python + + import ivy + ivy.set_backend("jax") + + def test_fn(x): + return jax.numpy.sum(x) + + x1 = ivy.array([1., 2.]) + + # transpiled_func and unified_func will have the same result + transpiled_func = ivy.transpile(test_fn, to="ivy", args=(x1,)) + unified_func = ivy.unify(test_fn, args=(x1,)) + +Sharp bits +---------- + +``ivy.unify()`` has the same sharp bits as ``ivy.transpile()``. You can learn more about +them in the :ref:`overview/one_liners/transpile:Sharp bits` section of the transpiler.
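+Like ``ivy.transpile()``, unification can also be lazy: if ``args`` and ``kwargs`` are omitted, a ``LazyGraph`` is returned and the conversion is deferred until the first call. A minimal sketch, assuming a torch source function:
+
+.. code-block:: python
+
+    import ivy
+    import torch
+
+    def scale(x):
+        return torch.mul(x, 2)
+
+    # no args are given -> a LazyGraph is returned
+    lazy_scale = ivy.unify(scale, source="torch")
+
+    # unification actually happens on the first invocation
+    out = lazy_scale(torch.tensor([1., 2.]))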
+ +Examples +-------- + +Below, we will define a function in torch and try to call it with different native +arguments. + +Here we will define the torch function and unify it: + +.. code-block:: python + + import ivy + import torch + + def normalize(x): + mean = torch.mean(x) + std = torch.std(x) + return torch.div(torch.sub(x, mean), std) + + normalize = ivy.unify(normalize, source="torch") + +Now we can call the function with different ivy backends: + +.. code-block:: python + + import numpy as np + import jax.numpy as jnp + import tensorflow as tf + + # create random numpy arrays for testing + x = np.random.uniform(size=10).astype(np.float32) + ivy.set_backend("numpy") + print(normalize(x)) + + # jax + x_ = jnp.array(x) + ivy.set_backend("jax") + print(normalize(x_)) + + # tensorflow + x_ = tf.constant(x) + ivy.set_backend("tensorflow") + print(normalize(x_)) + + # torch + x_ = torch.tensor(x) + ivy.set_backend("torch") + print(normalize(x_)) diff --git a/docs/overview/related_work.rst b/docs/overview/related_work.rst index 1d82f83fcbe0c..51b6b10746179 100644 --- a/docs/overview/related_work.rst +++ b/docs/overview/related_work.rst @@ -1,11 +1,23 @@ Related Work ============ +.. _`RWorks API Standards`: related_work/api_standards.rst +.. _`RWorks Wrapper Frameworks`: related_work/wrapper_frameworks.rst +.. _`RWorks Frameworks`: related_work/frameworks.rst +.. _`RWorks Graph Tracers`: related_work/graph_tracers.rst +.. _`RWorks Exchange Formats`: related_work/exchange_formats.rst +.. _`RWorks Compiler Infrastructure`: related_work/compiler_infrastructure.rst +.. _`RWorks Multi-Vendor Compiler Frameworks`: related_work/multi_vendor_compiler_frameworks.rst +.. _`RWorks Vendor-Specific APIs`: related_work/vendor_specific_apis.rst +.. _`RWorks Vendor-Specific Compilers`: related_work/vendor_specific_compilers.rst +.. _`RWorks ML-Unifying Companies`: related_work/ml_unifying_companies.rst +.. _`RWorks What does Ivy Add?`: related_work/what_does_ivy_add.rst + In this section, we explain how Ivy compares to many other very important and related pieces of work, which also address fragmentation but in other areas within the ML stack. Firstly, we need to look at the overall ML stack, and understand how the high level frameworks relate to the low level components. -In order to conceptualize this rather complex hierarchy, we have broken the ML stack into 9 groups, which are: :ref:`RWorks API Standards`, :ref:`RWorks Wrapper Frameworks`, :ref:`RWorks Frameworks`, :ref:`RWorks Graph Tracers`, :ref:`RWorks Exchange Formats`, :ref:`RWorks Compiler Infrastructure`, :ref:`RWorks Multi-Vendor Compiler Frameworks`, :ref:`RWorks Vendor-Specific APIs` and :ref:`RWorks Vendor-Specific Compilers`, going from high level to low level respectively. +In order to conceptualize this rather complex hierarchy, we have broken the ML stack into 9 groups, which are: `RWorks API Standards`_, `RWorks Wrapper Frameworks`_, `RWorks Frameworks`_, `RWorks Graph Tracers`_, `RWorks Exchange Formats`_, `RWorks Compiler Infrastructure`_, `RWorks Multi-Vendor Compiler Frameworks`_, `RWorks Vendor-Specific APIs`_ and `RWorks Vendor-Specific Compilers`_, going from high level to low level respectively. .. 
image:: https://github.com/unifyai/unifyai.github.io/blob/main/img/externally_linked/related_work/ml_stack.png?raw=true :width: 100% @@ -18,37 +30,37 @@ We see these efforts as being very complementary to Ivy's vision for high level Finally, we discuss how Ivy compares to each of these important works at all levels within the ML stack. -| (a) :ref:`RWorks API Standards` 🤝🏽 +| (a) `RWorks API Standards`_ 🤝🏽 | Standardized APIs which similar libraries should adhere to | -| (b) :ref:`RWorks Wrapper Frameworks` 🎁 +| (b) `RWorks Wrapper Frameworks`_ 🎁 | Frameworks which wrap other ML frameworks | -| (c) :ref:`RWorks Frameworks` 🔢 +| (c) `RWorks Frameworks`_ 🔢 | Standalone ML Frameworks | -| (d) :ref:`RWorks Graph Tracers` 🕸️ +| (d) `RWorks Graph Tracers`_ 🕸️ | Extracting acyclic directed computation graphs from code | -| (e) :ref:`RWorks Exchange Formats` 💱 +| (e) `RWorks Exchange Formats`_ 💱 | File formats to exchange neural networks between frameworks | -| (f) :ref:`RWorks Compiler Infrastructure` 🔟️🏗️ +| (f) `RWorks Compiler Infrastructure`_ 🔟️🏗️ | Infrastructure and standards to simplify the lives of compiler designers | -| (g) :ref:`RWorks Multi-Vendor Compiler Frameworks` 🖥️💻🔟 +| (g) `RWorks Multi-Vendor Compiler Frameworks`_ 🖥️💻🔟 | Executing ML code on a variety of hardware targets | -| (h) :ref:`RWorks Vendor-Specific APIs` 💻🔢 +| (h) `RWorks Vendor-Specific APIs`_ 💻🔢 | Interfacing with specific hardware in an intuitive manner | -| (i) :ref:`RWorks Vendor-Specific Compilers` 💻🔟 +| (i) `RWorks Vendor-Specific Compilers`_ 💻🔟 | Compiling code to specific hardware | -| (j) :ref:`RWorks ML-Unifying Companies` 📈 +| (j) `RWorks ML-Unifying Companies`_ 📈 | Companies working towards unification in ML | -| (k) :ref:`RWorks What does Ivy Add?` 🟢 +| (k) `RWorks What does Ivy Add?`_ 🟢 | How does Ivy fit into all of this? .. toctree:: diff --git a/docs/overview/related_work/compiler_infrastructure.rst b/docs/overview/related_work/compiler_infrastructure.rst index a0ff2ee411b32..b5a11bb113493 100644 --- a/docs/overview/related_work/compiler_infrastructure.rst +++ b/docs/overview/related_work/compiler_infrastructure.rst @@ -39,7 +39,7 @@ OneAPI The set of APIs spans several domains that benefit from acceleration, including libraries for linear algebra math, deep learning, machine learning, video processing, and others. `OneDNN`_ is particularly relevant, focusing on neural network functions for deep learning training and inference. Intel CPUs and GPUs have accelerators for Deep Learning software, and OneDNN provides a unified interface to utilize these accelerators, with much of the hardware-specific complexity abstracted away. -In a similar manner to `MLIR`_, OneAPI is also designed to operate at a lower level than the Neural Network :ref:`Exchange Formats`. +In a similar manner to `MLIR`_, OneAPI is also designed to operate at a lower level than the Neural Network :ref:`overview/related_work/what_does_ivy_add:Exchange Formats`. The interface is lower level and more primitive than the neural network exchange formats, with a focus on the core low-level operations such as convolutions, matrix multiplications, batch normalization etc. This makes OneDNN very much complementary to these formats, where OneDNN can sit below the exchange formats in the overall stack, enabling accelerators to be fully leveraged with minimal hardware-specific considerations, with this all helpfully being abstracted by the OneDNN API. 
Indeed, OneAPI and MLIR can work together in tandem, and OneDNN is working to `integrate Tensor Processing Primitives in the MLIR compilers used underneath TensorFlow `_. diff --git a/docs/overview/related_work/what_does_ivy_add.rst b/docs/overview/related_work/what_does_ivy_add.rst index 475f2c6a1ff62..14a407d24a751 100644 --- a/docs/overview/related_work/what_does_ivy_add.rst +++ b/docs/overview/related_work/what_does_ivy_add.rst @@ -38,7 +38,7 @@ However, for the time being, we are focusing exclusively on Python, in order to Wrapper Frameworks ------------------ Ivy is itself a Python Wrapper Framework. -The biggest difference between Ivy and all others listed in the :ref:`Wrapper Frameworks` section is that Ivy supports transpilations between frameworks, while all other frameworks only enable the creation of entirely new code which itself is framework-agnostic. +The biggest difference between Ivy and all others listed in the `Wrapper Frameworks `_ section is that Ivy supports transpilations between frameworks, while all other frameworks only enable the creation of entirely new code which itself is framework-agnostic. There are also other more subtle differences. For example, Ivy includes both a low level fully functional API and a high level stateful API, offering both low level control and high level convenience. In contrast, `EagerPy`_ and `TensorLy`_ both only include functional APIs, `Thinc`_ only includes a high level stateful API, and `NeuroPod`_ only supports an even higher level wrapper for deployment. @@ -51,7 +51,7 @@ It therefore extends what is possible in any of the specific individual frameworks Graph Tracers ------------- -Ivy’s :ref:`Graph Compiler` exhibits similar properties to many of the framework-specific graph tracers. +Ivy’s `Graph Compiler <../one_liners/compile>`_ exhibits similar properties to many of the framework-specific graph tracers. Ivy’s graph compiler employs function tracing for computing the graph, and uses this graph as an intermediate representation during the transpilation process. Of all the graph tracers, Ivy’s graph compiler is most similar to `torch.fx`_. This is because :code:`torch.fx` also operates entirely in Python, without deferring to lower level languages for tracing and extracting the computation graph or the intermediate representation. @@ -99,7 +99,7 @@ However, again they do nothing to address the challenge of running code from one ML-Unifying Companies --------------------- -The ML-unifying companies `Quansight`_, `OctoML`_ and `Modular`_ are/were directly involved with the `Array API Standard`_, `Apache TVM`_ and `MLIR`_ respectively, as explained in the :ref:`ML-Unifying Companies` section. +The ML-unifying companies `Quansight`_, `OctoML`_ and `Modular`_ are/were directly involved with the `Array API Standard`_, `Apache TVM`_ and `MLIR`_ respectively, as explained in the `ML-Unifying Companies `_ section. For the same reasons that Ivy as a framework is complementary to these three frameworks, Ivy as a company is also complementary to these three companies. Firstly, we are adhering to the `Array API Standard`_ defined by Quansight. In essence, they have written the standard and we have implemented it, which is pretty much as complementary as it gets. 
diff --git a/docs/partial_conf.py b/docs/partial_conf.py index ce4ade7c51ba7..ffb6e62ce4c1f 100644 --- a/docs/partial_conf.py +++ b/docs/partial_conf.py @@ -42,9 +42,15 @@ skippable_method_attributes = [{"__qualname__": "_wrap_function..new_function"}] +autosectionlabel_prefix_document = True + # Retrieve html_theme_options from docs/conf.py from docs.conf import html_theme_options html_theme_options["switcher"]["json_url"] = "https://unify.ai/docs/versions/ivy.json" +html_sidebars = {"**": ["custom-toc-tree"]} repo_name = "ivy" + +# Retrieve demos specific configuration +from docs.demos.demos_conf import * # noqa diff --git a/docs/prebuild.sh b/docs/prebuild.sh index d22d94c22a279..af122eafaaf58 100755 --- a/docs/prebuild.sh +++ b/docs/prebuild.sh @@ -1,6 +1,6 @@ #!/bin/bash -e -# For some reason torch needed to be installed sequentially before installing from +# For some reason torch needed to be installed sequentially before installing from # requirements.txt pip install torch || exit 1 pip install torch-scatter || exit 1 diff --git a/duplicate.py b/duplicate.py index e7ac5767d8b2c..bd084f4ca411b 100644 --- a/duplicate.py +++ b/duplicate.py @@ -8,7 +8,7 @@ def get_all_functions_from_directory(root_dir, startswith="test"): print("Invalid directory") exit(1) functions_names = [] - for filename in glob.iglob(root_dir + "/**/*.py", recursive=True): + for filename in glob.iglob(f"{root_dir}/**/*.py", recursive=True): if len(filename) >= 2 and filename[:2] == "./": filename = filename[2:] filename = filename.replace(".py", "") diff --git a/install_dependencies.sh b/install_dependencies.sh index 40115a8571594..b620d7d381e3e 100755 --- a/install_dependencies.sh +++ b/install_dependencies.sh @@ -1,3 +1,5 @@ +sudo apt-get update +sudo apt-get install pandoc -y pip install -r requirements/requirements.txt if [[ $(arch) == 'arm64' ]]; then pip install -r requirements/optional_apple_silicon_1.txt diff --git a/ivy/__init__.py b/ivy/__init__.py index f84ccbaacb1a6..d00fb23359c30 100644 --- a/ivy/__init__.py +++ b/ivy/__init__.py @@ -791,6 +791,12 @@ class Node(str): add_array_specs() _imported_frameworks_before_compiler = list(sys.modules.keys()) + +try: + from .engines import XLA as xla + from .engines import ivy2xla +except: # noqa: E722 + pass try: from .compiler.compiler import transpile, compile, unify except: # noqa: E722 diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py index 8bc0745927d91..6cb2aeaf68ddf 100644 --- a/ivy/data_classes/array/activations.py +++ b/ivy/data_classes/array/activations.py @@ -166,6 +166,7 @@ def softmax( /, *, axis: Optional[int] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, ) -> ivy.Array: """ @@ -179,6 +180,9 @@ def softmax( input array. axis the axis or axes along which the softmax should be computed + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. 
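As context for the hunks above and below, here is a minimal sketch of how the new ``complex_mode`` argument to ``softmax`` might be called (illustrative only; the complex-number handling itself is delegated to ``ivy.func_wrapper.handle_complex_input``):

.. code-block:: python

    import ivy

    x = ivy.array([1.0 + 1.0j, 2.0 - 1.0j, 3.0 + 0.0j])
    # "jax" is the default; "split" and "magnitude" are the other accepted modes
    y = x.softmax(axis=0, complex_mode="magnitude")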
@@ -195,7 +199,7 @@ def softmax( >>> print(y) ivy.array([0.422, 0.155, 0.422]) """ - return ivy.softmax(self._data, axis=axis, out=out) + return ivy.softmax(self._data, axis=axis, complex_mode=complex_mode, out=out) def softplus( self: ivy.Array, diff --git a/ivy/data_classes/array/conversions.py b/ivy/data_classes/array/conversions.py index 30ac4445c0192..8da2954f31249 100644 --- a/ivy/data_classes/array/conversions.py +++ b/ivy/data_classes/array/conversions.py @@ -53,7 +53,7 @@ def _to_ivy(x: Any) -> Any: def to_ivy( x: Union[ivy.Array, ivy.NativeArray, Iterable], nested: bool = False, - include_derived: Optional[Dict[type, bool]] = None, + include_derived: Optional[Dict[str, bool]] = None, ) -> Union[ivy.Array, ivy.NativeArray, Iterable]: """ Return the input array converted to an ivy.Array instance if it is a native array @@ -84,7 +84,7 @@ def to_ivy( def args_to_ivy( *args: Iterable[Any], - include_derived: Optional[Dict[type, bool]] = None, + include_derived: Optional[Dict[str, bool]] = None, **kwargs: Dict[str, Any], ) -> Tuple[Iterable[Any], Dict[str, Any]]: """ @@ -115,7 +115,7 @@ def args_to_ivy( def to_native( x: Union[ivy.Array, ivy.NativeArray, Iterable], nested: bool = False, - include_derived: Optional[Dict[type, bool]] = None, + include_derived: Optional[Dict[str, bool]] = None, cont_inplace: bool = False, to_ignore: Optional[Union[type, Tuple[type]]] = None, ) -> Union[ivy.Array, ivy.NativeArray, Iterable]: @@ -157,7 +157,7 @@ def to_native( def args_to_native( *args: Iterable[Any], - include_derived: Dict[type, bool] = None, + include_derived: Dict[str, bool] = None, cont_inplace: bool = False, to_ignore: Optional[Union[type, Tuple[type]]] = None, **kwargs: Dict[str, Any], diff --git a/ivy/data_classes/array/elementwise.py b/ivy/data_classes/array/elementwise.py index 794da705e123c..c9ce05589e45b 100644 --- a/ivy/data_classes/array/elementwise.py +++ b/ivy/data_classes/array/elementwise.py @@ -2055,7 +2055,7 @@ def positive(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array: def pow( self: ivy.Array, - x2: Union[ivy.Array, ivy.NativeArray], + x2: Union[int, float, ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None, diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index 699022d70bdfa..f27ef89516bbb 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -1,6 +1,6 @@ # global import abc -from typing import Optional, Union +from typing import Optional, Union, Literal # local import ivy @@ -8,7 +8,12 @@ class _ArrayWithActivationsExperimental(abc.ABC): def logit( - self, /, *, eps: Optional[float] = None, out: Optional[ivy.Array] = None + self, + /, + *, + eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ ivy.Array instance method variant of ivy.logit. This method simply wraps the @@ -23,6 +28,9 @@ def logit( When eps is None the function outputs NaN where x < 0 or x > 1, and inf or -inf where x = 1 or x = 0, respectively. Otherwise if eps is defined, x is clamped to [eps, 1 - eps] + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out Optional output array. 
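Similarly, a short sketch of the updated ``logit`` call with the new keyword (illustrative only):

.. code-block:: python

    import ivy

    x = ivy.array([0.3, 0.9, 1.2])
    # without eps, values outside [0, 1] map to NaN; eps clamps x to [eps, 1 - eps]
    y = x.logit(eps=1e-6, complex_mode="jax")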
@@ -43,7 +51,7 @@ def logit( >>> print(z) ivy.array([ 1.38629448, 1.38629448, -1.38629436]) """ - return ivy.logit(self, eps=eps, out=out) + return ivy.logit(self, eps=eps, complex_mode=complex_mode, out=out) def thresholded_relu( self: ivy.Array, @@ -152,6 +160,7 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: def logsigmoid( self: ivy.Array, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ ivy.Array instance method variant of ivy.logsigmoid. This method simply wraps @@ -162,6 +171,9 @@ def logsigmoid( ---------- self Input array. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -179,7 +191,7 @@ def logsigmoid( >>> print(z) ivy.array([-2.57888985, -0.31326169, -0.69314718, -0.01104775]) """ - return ivy.logsigmoid(self._data) + return ivy.logsigmoid(self._data, complex_mode=complex_mode) def selu(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: """ diff --git a/ivy/data_classes/array/experimental/creation.py b/ivy/data_classes/array/experimental/creation.py index 8942f95f39706..1bd85ba5170c3 100644 --- a/ivy/data_classes/array/experimental/creation.py +++ b/ivy/data_classes/array/experimental/creation.py @@ -211,3 +211,56 @@ def trilu( on the same device as ``self``. """ return ivy.trilu(self._data, k=k, upper=upper, out=out) + + @staticmethod + def mel_weight_matrix( + num_mel_bins: Union[int, ivy.Array], + dft_length: Union[int, ivy.Array], + sample_rate: Union[int, ivy.Array], + lower_edge_hertz: Optional[Union[float, ivy.Array]] = 0.0, + upper_edge_hertz: Optional[Union[float, ivy.Array]] = 3000.0, + ): + """ + Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a + linearly sampled frequency spectrum (from DFT or STFT) into num_mel_bins + frequency information based on the [lower_edge_hertz, upper_edge_hertz] + range on the mel scale. This function defines the mel scale + in terms of a frequency in hertz according to the following + formula: mel(f) = 2595 * log10(1 + f/700) + + Parameters + ---------- + num_mel_bins + The number of bands in the mel spectrum. + dft_length + The size of the original DFT obtained from (n_fft / 2 + 1). + sample_rate + Samples per second of the input signal. + lower_edge_hertz + Lower bound on the frequencies to be included in the mel spectrum. + upper_edge_hertz + The desired top edge of the highest frequency band. + + Returns + ------- + ret + MelWeightMatrix of shape: [frames, num_mel_bins]. + + Examples + -------- + >>> x = ivy.array([[1, 2, 3], + ... [1, 1, 1], + ... [5, 6, 7]]) + >>> x.mel_weight_matrix(3, 3, 8000) + ivy.array([[0. ,0. , 0.], + [0. ,0. , 0.75694758], + [0. ,0. , 0. 
]]) + """ + return ivy.mel_weight_matrix( + num_mel_bins, + dft_length, + sample_rate, + lower_edge_hertz, + upper_edge_hertz, + ) diff --git a/ivy/data_classes/array/experimental/linear_algebra.py b/ivy/data_classes/array/experimental/linear_algebra.py index ffc50ed222021..75c56cdf9c9da 100644 --- a/ivy/data_classes/array/experimental/linear_algebra.py +++ b/ivy/data_classes/array/experimental/linear_algebra.py @@ -728,3 +728,59 @@ def dot( ivy.array([[-15.28]]) """ return ivy.dot(self._data, b, out=out) + + def general_inner_product( + self: Union[ivy.Array, ivy.NativeArray], + b: Union[ivy.Array, ivy.NativeArray], + n_modes: Optional[int] = None, + /, + *, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ + ivy.Array instance method variant of ivy.general_inner_product. This method + simply wraps the function, and so the docstring for ivy.general_inner_product + also applies to this method with minimal changes. + + Parameters + ---------- + self + first input tensor. + b + second input tensor. + n_modes + int, default is None. If None, the traditional inner product is returned + (i.e. a float); otherwise, the product between the `n_modes` last modes of + `a` and the `n_modes` first modes of `b` is returned. The resulting tensor's + order is `len(a) - n_modes`. + out + Optional output array. If provided, the output array to store the result. + + Returns + ------- + The inner product of the input arrays. + + Examples + -------- + With :class:`ivy.Array` inputs: + + >>> a = ivy.array([1, 2, 3]) + >>> b = ivy.array([4, 5, 6]) + >>> result = a.general_inner_product(b, n_modes=1) + >>> print(result) + ivy.array(32) + + >>> a = ivy.array([1, 2]) + >>> b = ivy.array([4, 5]) + >>> result = a.general_inner_product(b) + >>> print(result) + ivy.array(14) + + >>> a = ivy.array([[1, 1], [1, 1]]) + >>> b = ivy.array([[1, 2, 3, 4],[1, 1, 1, 1]]) + >>> result = a.general_inner_product(b, n_modes=1) + >>> print(result) + ivy.array([[2, 3, 4, 5], + [2, 3, 4, 5]]) + """ + return ivy.general_inner_product(self, b, n_modes, out=out) diff --git a/ivy/data_classes/array/experimental/losses.py b/ivy/data_classes/array/experimental/losses.py index 45f01d9760fd0..68265a85a16e7 100644 --- a/ivy/data_classes/array/experimental/losses.py +++ b/ivy/data_classes/array/experimental/losses.py @@ -23,9 +23,9 @@ def l1_loss( Parameters ---------- self - input array. + input array containing true labels. target - input array containing the targeted values. + input array containing targeted labels. reduction ``'mean'``: The output will be averaged. ``'sum'``: The output will be summed. @@ -49,9 +49,73 @@ def l1_loss( """ return ivy.l1_loss(self._data, target, reduction=reduction, out=out) + def log_poisson_loss( + self: Union[ivy.Array, ivy.NativeArray], + target: Union[ivy.Array, ivy.NativeArray], + /, + *, + compute_full_loss: bool = False, + axis: int = -1, + reduction: str = "none", + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ + ivy.Array instance method variant of ivy.log_poisson_loss. This method simply + wraps the function, and so the docstring for ivy.log_poisson_loss also applies + to this method with minimal changes. + + Parameters + ---------- + self + input array containing true labels. + target + input array containing targeted labels. + compute_full_loss + whether to compute the full loss. If false, a constant term is dropped + in favor of more efficient optimization. Default: ``False``. + axis + the axis along which to compute the log-likelihood loss. 
If axis is ``-1``, + the log-likelihood loss will be computed along the last dimension. + Default: ``-1``. + reduction + ``'none'``: No reduction will be applied to the output. + ``'mean'``: The output will be averaged. + ``'sum'``: The output will be summed. Default: ``'none'``. + out + optional output array, for writing the result to. It must have a shape + that the inputs broadcast to. + + Returns + ------- + ret + The binary log-likelihood loss between the given distributions. + + + Examples + -------- + >>> x = ivy.array([0, 0, 1, 0]) + >>> y = ivy.array([0.25, 0.25, 0.25, 0.25]) + >>> loss = x.log_poisson_loss(y) + >>> print(loss) + ivy.array([1.28402555, 1.28402555, 1.03402555, 1.28402555]) + + >>> z = ivy.array([0.1, 0.1, 0.7, 0.1]) + >>> loss = x.log_poisson_loss(z, reduction='mean') + >>> print(loss) + ivy.array(1.1573164) + """ + return ivy.log_poisson_loss( + self._data, + target, + compute_full_loss=compute_full_loss, + axis=axis, + reduction=reduction, + out=out, + ) + def huber_loss( self: ivy.Array, - pred: Union[ivy.Array, ivy.NativeArray], + target: Union[ivy.Array, ivy.NativeArray], /, *, reduction: Optional[str] = "mean", @@ -66,9 +130,9 @@ def huber_loss( Parameters ---------- self - The true (ground truth) values. - pred - The predicted values by the model. + input array containing true labels. + target + input array containing targeted labels. reduction : str, optional The type of reduction to apply to the loss. Possible values are "mean" (default) @@ -94,7 +158,7 @@ def huber_loss( ivy.array([0.125, 0.125, 0.5 , 0.125]) """ return ivy.huber_loss( - self._data, pred, reduction=reduction, delta=delta, out=out + self._data, target, reduction=reduction, delta=delta, out=out ) def smooth_l1_loss( @@ -187,3 +251,47 @@ def soft_margin_loss( ivy.array([0.35667497, 0.22314353, 1.60943791]) """ return ivy.soft_margin_loss(self._data, target, reduction=reduction, out=out) + + def kl_div( + self: ivy.Array, + target: Union[ivy.Array, ivy.NativeArray], + /, + *, + reduction: Optional[str] = "mean", + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ + ivy.Array instance method variant of ivy.kl_div. This method simply wraps the + function, and so the docstring for ivy.kl_div also applies to this method with + minimal changes. + + Parameters + ---------- + self + Array containing input probability distribution. + target + Array containing target probability distribution. + reduction + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'batchmean': The output will be divided by batch size. + 'sum': The output will be summed. + Default: 'mean'. + out + Optional output array, for writing the result to. + It must have a shape that the inputs broadcast to. + + Returns + ------- + ret + The Kullback-Leibler divergence loss between the two input arrays. 
+ + Examples + -------- + >>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]]) + >>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]]) + >>> output_array = input.kl_div(target) + >>> print(output_array) + ivy.array(0.0916) + """ + return ivy.kl_div(self._data, target, reduction=reduction, out=out) diff --git a/ivy/data_classes/array/experimental/manipulation.py b/ivy/data_classes/array/experimental/manipulation.py index 48d9586c24dba..3acf71a02d305 100644 --- a/ivy/data_classes/array/experimental/manipulation.py +++ b/ivy/data_classes/array/experimental/manipulation.py @@ -1334,3 +1334,39 @@ def soft_thresholding( thresholded tensor on which the operator has been applied """ return ivy.soft_thresholding(self._data, threshold, out=out) + + def column_stack( + self: ivy.Array, + arrays: Sequence[Union[ivy.Array, ivy.NativeArray]], + /, + *, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ + ivy.Array instance method variant of ivy.column_stack. + + This method simply wraps the function, and so the docstring for + ivy.column_stack also applies to this method with minimal + changes. + + Parameters + ---------- + self + Array that will be stacked at the beginning of the provided array iterable. + arrays + Arrays to be stacked. + out + Output array. + + Returns + ------- + ret + Stacked input. + """ + if not isinstance(arrays, (list, tuple)): + arrays = [arrays] + if isinstance(arrays, tuple): + x = (self._data,) + arrays + else: + x = [self._data] + arrays + return ivy.column_stack(x, out=out) diff --git a/ivy/data_classes/array/experimental/statistical.py b/ivy/data_classes/array/experimental/statistical.py index 857fd61d3f37c..0a2ab13449157 100644 --- a/ivy/data_classes/array/experimental/statistical.py +++ b/ivy/data_classes/array/experimental/statistical.py @@ -180,6 +180,66 @@ def nanmean( self._data, axis=axis, keepdims=keepdims, dtype=dtype, out=out ) + def nanprod( + self: ivy.Array, + /, + *, + axis: Optional[Union[Tuple[int], int]] = None, + dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None, + out: Optional[ivy.Array] = None, + keepdims: Optional[bool] = False, + initial: Optional[Union[int, float, complex]] = None, + where: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ + ivy.Array instance method variant of ivy.nanprod. This method simply wraps the + function, and so the docstring for ivy.nanprod also applies to this method with + minimal changes. + + Parameters + ---------- + self + Input array. + axis + Axis or axes along which the product is computed. + The default is to compute the product of the flattened array. + dtype + The desired data type of returned array. Default is None. + out + optional output array, for writing the result to. + keepdims + If this is set to True, the axes which are reduced are left in the result + as dimensions with size one. With this option, the result will broadcast + correctly against the original a. + initial + The starting value for this product. 
+ where + Elements to include in the product + + Returns + ------- + ret + The product of array elements over a given axis treating + Not a Number values (NaNs) as ones + + Examples + -------- + >>> a = ivy.array([[1, 2], [3, ivy.nan]]) + >>> a.nanprod() + 6.0 + >>> a.nanprod(axis=0) + ivy.array([3., 2.]) + """ + return ivy.nanprod( + self._data, + axis=axis, + keepdims=keepdims, + dtype=dtype, + out=out, + initial=initial, + where=where, + ) + def quantile( self: ivy.Array, q: Union[ivy.Array, float], diff --git a/ivy/data_classes/array/general.py b/ivy/data_classes/array/general.py index 7c94e779a673f..6320790c8b13d 100644 --- a/ivy/data_classes/array/general.py +++ b/ivy/data_classes/array/general.py @@ -807,6 +807,12 @@ def array_equal(self: ivy.Array, x: Union[ivy.Array, ivy.NativeArray], /) -> bool >>> c = a.array_equal(b) >>> print(c) True + + >>> i = ivy.array([1, 2]) + >>> j = ivy.array([1, 2, 3]) + >>> k = i.array_equal(j) + >>> print(k) + False """ return ivy.array_equal(self, x) @@ -988,7 +994,7 @@ def value_is_nan(self: ivy.Array, /, *, include_infs: bool = True) -> bool: """ return ivy.value_is_nan(self, include_infs=include_infs) - def exists(self: ivy.Array) -> bool: + def exists(self: ivy.Array, /) -> bool: """ ivy.Array instance method variant of ivy.exists. This method simply wraps the function, and so the docstring for ivy.exists also applies to this method with @@ -1002,7 +1008,7 @@ def exists(self: ivy.Array) -> bool: Returns ------- ret - True if x is not None, else False. + True if input is not None, else False. Examples -------- @@ -1347,7 +1353,7 @@ def get_num_dims(self: ivy.Array, /, *, as_array: bool = False) -> int: >>> b = x.get_num_dims(as_array=False) >>> print(b) 3 - + >>> b = x.get_num_dims(as_array=True) >>> print(b) ivy.array(3) diff --git a/ivy/data_classes/array/layers.py b/ivy/data_classes/array/layers.py index 7d635d97f8fa3..1bfbb3cc805db 100644 --- a/ivy/data_classes/array/layers.py +++ b/ivy/data_classes/array/layers.py @@ -1,6 +1,6 @@ # global import abc -from typing import Optional, Tuple, Union, List, Sequence, Dict +from typing import Optional, Tuple, Union, List, Sequence # local import ivy @@ -393,12 +393,12 @@ def scaled_dot_product_attention( ) def multi_head_attention( - self: Union[ivy.Array, ivy.NativeArray], + self: ivy.Array, key: Optional[Union[ivy.Array, ivy.NativeArray]] = None, value: Optional[Union[ivy.Array, ivy.NativeArray]] = None, /, *, - num_heads: Optional[int] = 8, + num_heads: int = 8, scale: Optional[float] = None, attention_mask: Optional[Union[ivy.Array, ivy.NativeArray]] = None, in_proj_weights: Optional[Union[ivy.Array, ivy.NativeArray]] = None, @@ -408,21 +408,17 @@ def multi_head_attention( out_proj_weights: Optional[Union[ivy.Array, ivy.NativeArray]] = None, in_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None, out_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None, - is_causal: Optional[bool] = False, - return_attention_weights: Optional[bool] = False, - average_attention_weights: Optional[bool] = True, - dropout: Optional[float] = 0.0, - training: Optional[bool] = False, - key_chains: Optional[Union[List[str], Dict[str, str]]] = None, - to_apply: bool = True, - prune_unapplied: bool = False, - map_sequences: bool = False, - out: Optional[Union[ivy.Array, ivy.Container]] = None, + is_causal: bool = False, + return_attention_weights: bool = False, + average_attention_weights: bool = True, + dropout: float = 0.0, + training: bool = False, + out: Optional[ivy.Array] = None, ) -> ivy.Array: return 
ivy.multi_head_attention( self._data, - key, - value, + key=key, + value=value, num_heads=num_heads, scale=scale, attention_mask=attention_mask, diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py index af9e04c0f6f15..938ea09fc50fe 100644 --- a/ivy/data_classes/container/activations.py +++ b/ivy/data_classes/container/activations.py @@ -538,6 +538,7 @@ def _static_softmax( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -562,6 +563,9 @@ def _static_softmax( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -589,6 +593,7 @@ def _static_softmax( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) @@ -600,6 +605,7 @@ def softmax( key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", map_sequences: Union[bool, ivy.Container] = False, out: Optional[ivy.Container] = None, ) -> ivy.Container: @@ -625,6 +631,9 @@ def softmax( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. 
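A brief sketch of the container-level ``softmax`` call with the new keyword (illustrative only):

.. code-block:: python

    import ivy

    c = ivy.Container(a=ivy.array([1.0, 2.0, 3.0]),
                      b=ivy.array([0.5 + 0.5j, 1.0 - 1.0j]))
    # complex_mode is forwarded to ivy.softmax at every leaf
    out = c.softmax(complex_mode="split")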
@@ -651,6 +660,7 @@ def softmax( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) diff --git a/ivy/data_classes/container/conversions.py b/ivy/data_classes/container/conversions.py index 910729cc613fe..00cdf6cdcfef8 100644 --- a/ivy/data_classes/container/conversions.py +++ b/ivy/data_classes/container/conversions.py @@ -18,7 +18,7 @@ class _ContainerWithConversions(ContainerBase): def _static_to_native( x: Union[ivy.Array, ivy.NativeArray, ivy.Container], nested: Union[bool, ivy.Container] = False, - include_derived: Optional[Union[Dict[type, bool], ivy.Container]] = None, + include_derived: Optional[Union[Dict[str, bool], ivy.Container]] = None, key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, @@ -78,7 +78,7 @@ def _static_to_native( def to_native( self: ivy.Container, nested: Union[bool, ivy.Container] = False, - include_derived: Optional[Union[Dict[type, bool], ivy.Container]] = None, + include_derived: Optional[Union[Dict[str, bool], ivy.Container]] = None, key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, @@ -138,7 +138,7 @@ def to_native( def _static_to_ivy( x: Union[ivy.Array, ivy.NativeArray, ivy.Container], nested: Union[bool, ivy.Container] = False, - include_derived: Optional[Union[Dict[type, bool], ivy.Container]] = None, + include_derived: Optional[Union[Dict[str, bool], ivy.Container]] = None, key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, @@ -199,7 +199,7 @@ def _static_to_ivy( def to_ivy( self: ivy.Container, nested: Union[bool, ivy.Container] = False, - include_derived: Optional[Union[Dict[type, bool], ivy.Container]] = None, + include_derived: Optional[Union[Dict[str, bool], ivy.Container]] = None, key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, diff --git a/ivy/data_classes/container/creation.py b/ivy/data_classes/container/creation.py index e6296fb0d3e37..ea238eb55c733 100644 --- a/ivy/data_classes/container/creation.py +++ b/ivy/data_classes/container/creation.py @@ -1303,7 +1303,7 @@ def _static_one_hot( ret container with tensors of zeros with the same shape and type as the inputs, unless dtype provided which overrides. 
- + Examples -------- With :class:`ivy.Container` inputs: @@ -1314,11 +1314,11 @@ def _static_one_hot( >>> z = ivy.Container.static_one_hot(x, y) >>> print(z) { - a: ivy.array([[0., 1., 0., 0., 0.], + a: ivy.array([[0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]]), - b: ivy.array([[0., 0., 0., 1., 0.], + b: ivy.array([[0., 0., 0., 1., 0.], [0., 1., 0., 0., 0.]]), - c: ivy.array([[0., 0., 1., 0., 0.], + c: ivy.array([[0., 0., 1., 0., 0.], [0., 0., 0., 1., 0.]]) } @@ -1328,7 +1328,7 @@ def _static_one_hot( >>> z = ivy.Container.static_one_hot(x, y) >>> print(z) { - a: ivy.array([[0., 1., 0., 0., 0.], + a: ivy.array([[0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]]), b: ivy.array([], shape=(0, 5)), c: ivy.array([[0., 0., 0., 0., 1.]]) @@ -1417,11 +1417,11 @@ def one_hot( >>> z = x.one_hot(y) >>> print(z) { - a: ivy.array([[0., 1., 0., 0., 0.], + a: ivy.array([[0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]]), - b: ivy.array([[0., 0., 0., 1., 0.], + b: ivy.array([[0., 0., 0., 1., 0.], [0., 1., 0., 0., 0.]]), - c: ivy.array([[0., 0., 1., 0., 0.], + c: ivy.array([[0., 0., 1., 0., 0.], [0., 0., 0., 1., 0.]]) } @@ -1431,7 +1431,7 @@ def one_hot( >>> z = x.one_hot(y) >>> print(z) { - a: ivy.array([[0., 1., 0., 0., 0.], + a: ivy.array([[0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]]), b: ivy.array([], shape=(0, 5)), c: ivy.array([[0., 0., 0., 0., 1.]]) diff --git a/ivy/data_classes/container/elementwise.py b/ivy/data_classes/container/elementwise.py index 226a9fe15edf2..2c4432952b979 100644 --- a/ivy/data_classes/container/elementwise.py +++ b/ivy/data_classes/container/elementwise.py @@ -6243,7 +6243,7 @@ def _static_multiply( a container containing the element-wise results. The returned container must have a data type determined by :ref:`type-promotion`. - + Examples -------- With :code:`ivy.Container` inputs: @@ -6755,7 +6755,7 @@ def positive( @staticmethod def _static_pow( x1: Union[ivy.Array, ivy.NativeArray, ivy.Container], - x2: Union[ivy.Array, ivy.NativeArray, ivy.Container], + x2: Union[int, float, ivy.Array, ivy.NativeArray, ivy.Container], /, *, key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, @@ -6823,7 +6823,7 @@ def _static_pow( def pow( self: ivy.Container, - x2: Union[ivy.Container, ivy.Array, ivy.NativeArray], + x2: Union[int, float, ivy.Container, ivy.Array, ivy.NativeArray], /, *, key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index aa60fc5bd9dfd..082fb5e062b40 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -1,5 +1,5 @@ # global -from typing import Union, Optional, List, Dict +from typing import Union, Optional, List, Dict, Literal # local import ivy @@ -13,6 +13,7 @@ def static_logit( /, *, eps: Optional[Union[float, ivy.Container]] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -28,6 +29,9 @@ def static_logit( When eps is None the function outputs NaN where x < 0 or x > 1, and inf or -inf where x = 1 or x = 0, respectively. Otherwise if eps is defined, x is clamped to [eps, 1 - eps] + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out Optional output Container. 
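And a corresponding sketch for the static container variant of ``logit`` (illustrative only):

.. code-block:: python

    import ivy

    c = ivy.Container(a=ivy.array([0.3, 0.9]), b=ivy.array([0.1, 1.2]))
    out = ivy.Container.static_logit(c, eps=1e-6, complex_mode="jax")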
@@ -62,6 +66,7 @@ def static_logit( "logit", x, eps=eps, + complex_mode=complex_mode, out=out, ) @@ -70,6 +75,7 @@ def logit( /, *, eps: Optional[Union[float, ivy.Container]] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -85,6 +91,9 @@ def logit( When eps is None the function outputs NaN where x < 0 or x > 1, and inf or -inf where x = 1 or x = 0, respectively. Otherwise if eps is defined, x is clamped to [eps, 1 - eps] + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out Optional output Container. @@ -115,7 +124,7 @@ def logit( b: ivy.array([-1.38629436, 1.38629448, -1.38629436]) } """ - return self.static_logit(self, eps=eps, out=out) + return self.static_logit(self, eps=eps, complex_mode=complex_mode, out=out) @staticmethod def static_thresholded_relu( @@ -442,6 +451,7 @@ def static_logsigmoid( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container static method variant of ivy.logsigmoid. This method simply wraps @@ -463,6 +473,9 @@ def static_logsigmoid( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -497,6 +510,7 @@ def static_logsigmoid( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, ) def logsigmoid( @@ -507,6 +521,7 @@ def logsigmoid( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ Apply element-wise Log-sigmoid of x i.e. log(1 / (1 + exp(-x)). @@ -515,6 +530,9 @@ def logsigmoid( ---------- self Input container. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -537,6 +555,7 @@ def logsigmoid( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, ) @staticmethod diff --git a/ivy/data_classes/container/experimental/creation.py b/ivy/data_classes/container/experimental/creation.py index 8dd5ac76ad3a3..63eb416191d7b 100644 --- a/ivy/data_classes/container/experimental/creation.py +++ b/ivy/data_classes/container/experimental/creation.py @@ -1101,3 +1101,102 @@ def trilu( upper=upper, out=out, ) + + @staticmethod + def static_mel_weight_matrix( + num_mel_bins: Union[int, ivy.Container], + dft_length: Union[int, ivy.Container], + sample_rate: Union[int, ivy.Container], + lower_edge_hertz: Optional[Union[float, ivy.Container]] = 0.0, + upper_edge_hertz: Optional[Union[float, ivy.Container]] = 3000.0, + *, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + ) -> ivy.Container: + r""" + ivy.Container static method variant of ivy.mel_weight_matrix. 
This method + simply wraps the function, and so the docstring for ivy.mel_weight_matrix also + applies to this method with minimal changes. + + Parameters + ---------- + num_mel_bins + The number of bands in the mel spectrum. + dft_length + The size of the original DFT obtained from (n_fft / 2 + 1). + sample_rate + Samples per second of the input signal. + lower_edge_hertz + Lower bound on the frequencies to be included in the mel spectrum. + upper_edge_hertz + The desired top edge of the highest frequency band. + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. + to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + + Returns + ------- + ret + MelWeightMatrix of shape: [frames, num_mel_bins] + """ + return ContainerBase.cont_multi_map_in_function( + "mel_weight_matrix", + num_mel_bins, + dft_length, + sample_rate, + lower_edge_hertz, + upper_edge_hertz, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + ) + + def mel_weight_matrix( + self: ivy.Container, + num_mel_bins: int, + dft_length: int, + sample_rate: int, + lower_edge_hertz: Optional[float] = 0.0, + upper_edge_hertz: Optional[float] = 3000.0, + ): + r""" + ivy.Container instance method variant of ivy.mel_weight_matrix. This method + simply wraps the function, and so the docstring for ivy.mel_weight_matrix also + applies to this method with minimal changes. + + Parameters + ---------- + num_mel_bins + The number of bands in the mel spectrum. + dft_length + The size of the original DFT obtained from (n_fft / 2 + 1). + sample_rate + Samples per second of the input signal. + lower_edge_hertz + Lower bound on the frequencies to be included in the mel spectrum. + upper_edge_hertz + The desired top edge of the highest frequency band. + + Returns + ------- + ret + MelWeightMatrix of shape: [frames, num_mel_bins] + """ + return self.static_mel_weight_matrix( + num_mel_bins, + dft_length, + sample_rate, + lower_edge_hertz, + upper_edge_hertz, + ) diff --git a/ivy/data_classes/container/experimental/losses.py b/ivy/data_classes/container/experimental/losses.py index b0cf6c1c4225e..e4ee3c40b45d7 100644 --- a/ivy/data_classes/container/experimental/losses.py +++ b/ivy/data_classes/container/experimental/losses.py @@ -160,6 +160,181 @@ def l1_loss( out=out, ) + @staticmethod + def _static_log_poisson_loss( + input: Union[ivy.Container, ivy.Array, ivy.NativeArray], + target: Union[ivy.Container, ivy.Array, ivy.NativeArray], + /, + *, + compute_full_loss: bool = False, + axis: int = -1, + reduction: Optional[Union[str, ivy.Container]] = "mean", + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container static method variant of ivy.log_poisson_loss. This method simply + wraps the function, and so the docstring for ivy.log_poisson_loss also applies + to this method with minimal changes. + + Parameters + ---------- + input + input array or container. + target + input array or container containing the targeted values. 
+ compute_full_loss + whether to compute the full loss. If false, a constant term is dropped + in favor of more efficient optimization. Default: ``False``. + axis + the axis along which to compute the log-likelihood loss. If axis is ``-1``, + the log-likelihood loss will be computed along the last dimension. + Default: ``-1``. + reduction + ``'mean'``: The output will be averaged. + ``'sum'``: The output will be summed. + ``'none'``: No reduction will be applied to the output. Default: ``'mean'``. + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. + to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. + + Returns + ------- + ret + The log poisson loss between the input array and the targeted values. + + Examples + -------- + With :class:`ivy.Container` inputs: + + >>> x = ivy.Container(a=ivy.array([1, 2, 3]), b=ivy.array([4, 5, 6])) + >>> y = ivy.Container(a=ivy.array([2, 2, 2]), b=ivy.array([5, 5, 5])) + >>> z = ivy.Container.static_log_poisson_loss(x, y, reduction='mean') + >>> print(z) + { + a: ivy.array(1.), + b: ivy.array(0.) + } + + With a mix of :class:`ivy.Array` and :class:`ivy.Container` inputs: + + >>> x = ivy.array([1, 2, 3]) + >>> y = ivy.Container(a=ivy.array([2, 2, 2]), b=ivy.array([5, 5, 5])) + >>> z = ivy.Container.static_log_poisson_loss(x, y, reduction='mean') + >>> print(z) + { + a: ivy.array(1.), + b: ivy.array(4.) + } + """ + return ContainerBase.cont_multi_map_in_function( + "log_poisson_loss", + input, + target, + compute_full_loss=compute_full_loss, + axis=axis, + reduction=reduction, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def log_poisson_loss( + self: ivy.Container, + target: Union[ivy.Container, ivy.Array, ivy.NativeArray], + /, + *, + compute_full_loss: bool = False, + axis: int = -1, + reduction: Optional[Union[str, ivy.Container]] = "mean", + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container instance method variant of ivy.log_poisson_loss. This method + simply wraps the function, and so the docstring for ivy.log_poisson_loss also + applies to this method with minimal changes. + + Parameters + ---------- + self + input container. + target + input array or container containing the targeted values. + compute_full_loss + whether to compute the full loss. If false, a constant term is dropped + in favor of more efficient optimization. Default: ``False``. + axis + the axis along which to compute the log-likelihood loss. If axis is ``-1``, + the log-likelihood loss will be computed along the last dimension. + Default: ``-1``. + reduction + ``'mean'``: The output will be averaged. + ``'sum'``: The output will be summed. + ``'none'``: No reduction will be applied to the output. Default: ``'mean'``. + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. 
+ to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. + + Returns + ------- + ret + The log poisson loss between the input array and the targeted values. + + Examples + -------- + >>> x = ivy.Container(a=ivy.array([1, 2, 3]), b=ivy.array([4, 5, 6])) + >>> y = ivy.Container(a=ivy.array([2, 2, 2]), b=ivy.array([5, 5, 5])) + >>> z = x.log_poisson_loss(y) + >>> print(z) + { + a: ivy.array(1.), + b: ivy.array(0.) + } + """ + return self._static_log_poisson_loss( + self, + target, + compute_full_loss=compute_full_loss, + axis=axis, + reduction=reduction, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + @staticmethod + def _static_kl_div( + input: Union[ivy.Container, ivy.Array, ivy.NativeArray], + target: Union[ivy.Container, ivy.Array, ivy.NativeArray], + /, + *, + reduction: Optional[Union[str, ivy.Container]] = "mean", + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container static method variant of ivy.kl_div. This method simply wraps the + function, and so the docstring for ivy.kl_div also applies to this method with + minimal changes. + + Parameters + ---------- + input + input array or container containing input distribution. + target + input array or container containing target distribution. + reduction + the reduction method. Default: "mean". + key_chains + The key-chains to apply or not apply the method to. Default is None. + to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is True. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is False. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is False. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. + + Returns + ------- + ret + The Kullback-Leibler divergence loss between the given distributions. 
+ """ + return ContainerBase.cont_multi_map_in_function( + "kl_div", + input, + target, + reduction=reduction, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def kl_div( + self: ivy.Container, + target: Union[ivy.Container, ivy.Array, ivy.NativeArray], + /, + *, + reduction: Optional[Union[str, ivy.Container]] = "mean", + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container instance method variant of ivy.kl_div. This method simply wraps + the function, and so the docstring for ivy.kl_div also applies to this method + with minimal changes. + + Parameters + ---------- + self + input container containing input distribution. + target + input array or container containing target distribution. + reduction + the reduction method. Default: "mean". + key_chains + The key-chains to apply or not apply the method to. Default is None. + to_apply + If input, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is input. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is False. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is False. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. + + Returns + ------- + ret + The Kullback-Leibler divergence loss between the given distributions. + """ + return self._static_kl_div( + self, + target, + reduction=reduction, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) diff --git a/ivy/data_classes/container/experimental/manipulation.py b/ivy/data_classes/container/experimental/manipulation.py index 25ba79bec9fd4..a8a43a3f75322 100644 --- a/ivy/data_classes/container/experimental/manipulation.py +++ b/ivy/data_classes/container/experimental/manipulation.py @@ -554,7 +554,7 @@ def static_rot90( ------- ret Container with a rotated view of m. - + Examples -------- >>> m = ivy.Container(a=ivy.array([[1,2], [3,4]]),\ @@ -2057,7 +2057,7 @@ def atleast_3d( container with array inputs. arys one or more container with array inputs. - + key_chains The keychains to apply or not apply the method to. Default is ``None``. to_apply @@ -3678,3 +3678,110 @@ def soft_thresholding( thresholded tensor on which the operator has been applied """ return self.static_soft_thresholding(self, threshold, out=out) + + @staticmethod + def static_column_stack( + xs: Sequence[Union[ivy.Array, ivy.NativeArray, ivy.Container]], + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container static method variant of ivy.column_stack. + + This method simply wraps the function, and so the docstring for + ivy.column_stack also applies to this method with minimal + changes. + + Parameters + ---------- + xs + Container with leaves to stack. + + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. 
+ to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + Optional output array, for writing the result to. + + Returns + ------- + ret + An output container with the results. + """ + return ContainerBase.cont_multi_map_in_function( + "column_stack", + xs, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def column_stack( + self: ivy.Container, + /, + xs: Sequence[Union[ivy.Array, ivy.NativeArray, ivy.Container]], + *, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container instance method variant of ivy.column_stack. + + This method simply wraps the function, and so the docstring for + ivy.column_stack also applies to this method with minimal + changes. + + Parameters + ---------- + self + Container with leaves to stack with leaves of other arrays/containers. + xs + Container with other leaves to join. + + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. + to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + Optional output array, for writing the result to. + + Returns + ------- + ret + An output container with the results. + """ + new_xs = xs.cont_copy() if ivy.is_ivy_container(xs) else list(xs).copy() + new_xs.insert(0, self.cont_copy()) + return self.static_column_stack( + new_xs, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) diff --git a/ivy/data_classes/container/experimental/random.py b/ivy/data_classes/container/experimental/random.py index 9b60d9742c94e..7544f92d3fafa 100644 --- a/ivy/data_classes/container/experimental/random.py +++ b/ivy/data_classes/container/experimental/random.py @@ -30,10 +30,10 @@ def static_dirichlet( Parameters ---------- alpha - Sequence of floats of length k + Sequence of floats of length k size - optional container including ints or tuple of ints, - Output shape for the arrays in the input container. + optional container including ints or tuple of ints, + Output shape for the arrays in the input container. dtype output container array data type. If ``dtype`` is ``None``, the output data type will be the default floating-point data type. Default ``None`` @@ -97,10 +97,10 @@ def dirichlet( Parameters ---------- self - Sequence of floats of length k + Sequence of floats of length k size - optional container including ints or tuple of ints, - Output shape for the arrays in the input container. + optional container including ints or tuple of ints, + Output shape for the arrays in the input container. dtype output container array data type. If ``dtype`` is ``None``, the output data type will be the default floating-point data type. 
Default ``None``
diff --git a/ivy/data_classes/container/experimental/statistical.py b/ivy/data_classes/container/experimental/statistical.py
index f4f61e636ea94..edf20317bbdc9 100644
--- a/ivy/data_classes/container/experimental/statistical.py
+++ b/ivy/data_classes/container/experimental/statistical.py
@@ -356,8 +356,8 @@ def static_nanmean(
 If this is set to True, the axes which are reduced are left in the result
 as dimensions with size one. With this option, the result will broadcast
 correctly against the original a. If the value is anything but the default,
- then keepdims will be passed through to the mean or sum methods of
- sub-classes of ndarray. If the sub-classes methods does not implement
+ then keepdims will be passed through to the mean or sum methods of
+ sub-classes of ndarray. If the sub-classes' methods do not implement
 keepdims any exceptions will be raised.
 dtype
 The desired data type of returned tensor. Default is None.
@@ -417,8 +417,8 @@ def nanmean(
 If this is set to True, the axes which are reduced are left in the result
 as dimensions with size one. With this option, the result will broadcast
 correctly against the original a. If the value is anything but the default,
- then keepdims will be passed through to the mean or sum methods of
- sub-classes of ndarray. If the sub-classes methods does not implement
+ then keepdims will be passed through to the mean or sum methods of
+ sub-classes of ndarray. If the sub-classes' methods do not implement
 keepdims any exceptions will be raised.
 dtype
 The desired data type of returned tensor. Default is None.
@@ -444,6 +444,140 @@ def nanmean(
 self, axis=axis, keepdims=keepdims, dtype=dtype, out=out
 )
+
+ @staticmethod
+ def static_nanprod(
+ input: ivy.Container,
+ /,
+ *,
+ axis: Optional[Union[Tuple[int], int, ivy.Container]] = None,
+ dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype, ivy.Container]] = None,
+ keepdims: Optional[Union[bool, ivy.Container]] = False,
+ key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
+ to_apply: Union[bool, ivy.Container] = True,
+ prune_unapplied: Union[bool, ivy.Container] = False,
+ map_sequences: Union[bool, ivy.Container] = False,
+ out: Optional[Union[ivy.Array, ivy.Container]] = None,
+ initial: Optional[Union[int, float, complex, ivy.Container]] = 1,
+ where: Optional[Union[ivy.Array, ivy.Container]] = None,
+ ) -> ivy.Container:
+ """
+ ivy.Container static method variant of ivy.nanprod. This method simply wraps the
+ function, and so the docstring for ivy.nanprod also applies to this method with
+ minimal changes.
+
+ Parameters
+ ----------
+ input
+ Input container including arrays.
+ axis
+ Axis or axes along which the product is computed.
+ The default is to compute the product of the flattened array.
+ dtype
+ The desired data type of the returned array. Default is None.
+ out
+ optional output array, for writing the result to.
+ keepdims
+ If this is set to True, the axes which are reduced are left in the result
+ as dimensions with size one. With this option, the result will broadcast
+ correctly against the original a.
+ initial
+ The starting value for this product.
+ where
+ Elements to include in the product.
+
+ Returns
+ -------
+ ret
+ The product of array elements over a given axis, treating
+ Not a Number (NaN) values as ones.
+
+ Examples
+ --------
+ >>> a = ivy.Container(x=ivy.array([[1, 2], [3, ivy.nan]]),\
+ y=ivy.array([[ivy.nan, 1, 2], [1, 2, 3]]))
+ >>> ivy.Container.static_nanprod(a)
+ {
+ x: 12.0,
+ y: 12.0
+ }
+ """
+ return ContainerBase.cont_multi_map_in_function(
+ "nanprod",
+ input,
+ axis=axis,
+ keepdims=keepdims,
+ dtype=dtype,
+ key_chains=key_chains,
+ to_apply=to_apply,
+ prune_unapplied=prune_unapplied,
+ map_sequences=map_sequences,
+ out=out,
+ initial=initial,
+ where=where,
+ )
+
+ def nanprod(
+ self: ivy.Container,
+ /,
+ *,
+ axis: Optional[Union[Tuple[int], int, ivy.Container]] = None,
+ dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype, ivy.Container]] = None,
+ keepdims: Optional[Union[bool, ivy.Container]] = False,
+ out: Optional[ivy.Container] = None,
+ initial: Optional[Union[int, float, complex, ivy.Container]] = None,
+ where: Optional[Union[ivy.Array, ivy.Container]] = None,
+ ) -> ivy.Container:
+ """
+ ivy.Container instance method variant of ivy.nanprod. This method simply wraps
+ the function, and so the docstring for ivy.nanprod also applies to this method
+ with minimal changes.
+
+ Parameters
+ ----------
+ self
+ Input container including arrays.
+ axis
+ Axis or axes along which the product is computed.
+ The default is to compute the product of the flattened array.
+ dtype
+ The desired data type of the returned array. Default is None.
+ out
+ optional output array, for writing the result to.
+ keepdims
+ If this is set to True, the axes which are reduced are left in the result
+ as dimensions with size one. With this option, the result will broadcast
+ correctly against the original a.
+ initial
+ The starting value for this product.
+ where
+ Elements to include in the product.
+
+ Returns
+ -------
+ ret
+ The product of array elements over a given axis, treating
+ Not a Number (NaN) values as ones.
+
+ Examples
+ --------
+ >>> a = ivy.Container(x=ivy.array([[1, 2], [3, ivy.nan]]),\
+ y=ivy.array([[ivy.nan, 1, 2], [1, 2, 3]]))
+ >>> a.nanprod()
+ {
+ x: 12.0,
+ y: 12.0
+ }
+ """
+ return self.static_nanprod(
+ self,
+ axis=axis,
+ keepdims=keepdims,
+ dtype=dtype,
+ out=out,
+ initial=initial,
+ where=where,
+ )
+
 @staticmethod
 def static_quantile(
 a: Union[ivy.Container, ivy.Array, ivy.NativeArray],
@@ -735,9 +869,9 @@ def static_corrcoef(
 z=ivy.array([[0., 1., 2.], [2., 1., 0.]]))
 >>> ivy.Container.corrcoef(a)
 {
- w: ivy.array([[1., 1.],
+ w: ivy.array([[1., 1.],
 [1., 1.]]),
- z: ivy.array([[1., -1.],
+ z: ivy.array([[1., -1.],
 [-1., 1.]])
 }
 """
@@ -788,9 +922,9 @@ def corrcoef(
 z=ivy.array([[0., 1., 2.], [2., 1., 0.]]))
 >>> ivy.Container.corrcoef(a)
 {
- w: ivy.array([[1., 1.],
+ w: ivy.array([[1., 1.],
 [1., 1.]]),
- z: ivy.array([[1., -1.],
+ z: ivy.array([[1., -1.],
 [-1., 1.]])
 }
 """
diff --git a/ivy/data_classes/container/general.py b/ivy/data_classes/container/general.py
index 8f094f546ade0..32c021edc72a4 100644
--- a/ivy/data_classes/container/general.py
+++ b/ivy/data_classes/container/general.py
@@ -4031,14 +4031,25 @@ def array_equal(
 >>> b = ivy.array([[-2., 1.], [1. ,2.]])
 >>> c = ivy.array([[0., 1.], [1. ,0.]])
 >>> d = ivy.array([[2., 1.], [1. ,2.]])
- >>> a0 = ivy.Container(a = a, b = b)
- >>> a1 = ivy.Container(a = c, b = d)
- >>> y = a0.array_equal(a1)
+ >>> a1 = ivy.Container(a = a, b = b)
+ >>> a2 = ivy.Container(a = c, b = d)
+ >>> y = a1.array_equal(a2)
 >>> print(y)
 {
 a: True,
 b: False
 }
+
+ >>> x1 = ivy.Container(a=ivy.native_array([1, 0, 0]),
+ b=ivy.array([1, 2, 3]))
+ >>> x2 = ivy.Container(a=ivy.native_array([1, 0, 1]),
+ b=ivy.array([1, 2, 3]))
+ >>> y = x1.array_equal(x2)
+ >>> print(y)
+ {
+ a: False,
+ b: True
+ }
 """
 return _ContainerWithGeneral._static_array_equal(
 self,
@@ -4245,3 +4256,127 @@ def strides(
 A tuple containing the strides.
 """
 return self.static_strides(self)
+
+ @staticmethod
+ def _static_exists(
+ x: ivy.Container,
+ /,
+ *,
+ key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
+ to_apply: Union[bool, ivy.Container] = True,
+ prune_unapplied: Union[bool, ivy.Container] = False,
+ map_sequences: Union[bool, ivy.Container] = False,
+ ) -> ivy.Container:
+ """
+ ivy.Container static method variant of ivy.exists. This method simply wraps
+ the function, and so the docstring for ivy.exists also applies to this method
+ with minimal changes.
+
+ Parameters
+ ----------
+ x
+ The input container.
+ key_chains
+ The key-chains to apply or not apply the method to. Default is ``None``.
+ to_apply
+ If True, the method will be applied to key_chains, otherwise key_chains
+ will be skipped. Default is ``True``.
+ prune_unapplied
+ Whether to prune key_chains for which the function was not applied.
+ Default is ``False``.
+ map_sequences
+ Whether to also map method to sequences (lists, tuples).
+ Default is ``False``.
+
+ Returns
+ -------
+ ret
+ A boolean container detailing if any of the leaf nodes are None.
+ True if not None, False if None.
+
+ Examples
+ --------
+ >>> x = ivy.Container(a=ivy.array([0,4,5]), b=ivy.array([2,2,0]))
+ >>> y = x._static_exists(x)
+ >>> print(y)
+ { a: True, b: True }
+
+ >>> x = ivy.Container(a=[1,2], b=None)
+ >>> y = x._static_exists(x)
+ >>> print(y)
+ { a: True, b: False }
+
+ >>> x = ivy.Container(a={"d": 1, "c": 3}, b={"d": 20, "c": None})
+ >>> y = x._static_exists(x)
+ >>> print(y)
+ { a: { c: True, d: True }, b: { c: False, d: True } }
+ """
+ return ContainerBase.cont_multi_map_in_function(
+ "exists",
+ x,
+ key_chains=key_chains,
+ to_apply=to_apply,
+ prune_unapplied=prune_unapplied,
+ map_sequences=map_sequences,
+ )
+
+ def exists(
+ self: ivy.Container,
+ /,
+ *,
+ key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
+ to_apply: Union[bool, ivy.Container] = True,
+ prune_unapplied: Union[bool, ivy.Container] = False,
+ map_sequences: Union[bool, ivy.Container] = False,
+ ) -> ivy.Container:
+ """
+ ivy.Container instance method variant of ivy.exists. This method simply wraps
+ the function, and so the docstring for ivy.exists also applies to this method
+ with minimal changes.
+
+ Parameters
+ ----------
+ self
+ The input container.
+ key_chains
+ The key-chains to apply or not apply the method to. Default is ``None``.
+ to_apply
+ If True, the method will be applied to key_chains, otherwise key_chains
+ will be skipped. Default is ``True``.
+ prune_unapplied
+ Whether to prune key_chains for which the function was not applied.
+ Default is ``False``.
+ map_sequences
+ Whether to also map method to sequences (lists, tuples).
+ Default is ``False``.
+
+ Returns
+ -------
+ ret
+ A boolean container detailing if any of the leaf nodes are None.
+ True if not None, False if None.
+ + Examples + -------- + >>> x = ivy.Container(a=[1,2,3,4], b=[]) + >>> y = x.exists() + >>> print(y) + { a: True, b: True } + + >>> x = ivy.Container(a=None, b=[1,2]) + >>> y = x.exists() + >>> print(y) + { a: False, b: True } + + >>> x = ivy.Container(a={"d": 1, "c": 3}, b=None) + >>> y = x.exists() + >>> print(y) + { a: { c: True, d: True }, b: False } + """ + return self._static_exists( + self, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + ) diff --git a/ivy/data_classes/container/layers.py b/ivy/data_classes/container/layers.py index d094de5ac4a64..8e28923ad204a 100644 --- a/ivy/data_classes/container/layers.py +++ b/ivy/data_classes/container/layers.py @@ -1026,11 +1026,11 @@ def scaled_dot_product_attention( @staticmethod def _static_multi_head_attention( query: Union[ivy.Array, ivy.NativeArray, ivy.Container], - key: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, - value: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, /, *, - num_heads: Optional[Union[int, ivy.Container]] = 8, + key: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + value: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + num_heads: Union[int, ivy.Container] = 8, scale: Optional[Union[float, ivy.Container]] = None, attention_mask: Optional[ Union[ivy.Array, ivy.NativeArray, ivy.Container] @@ -1054,22 +1054,22 @@ def _static_multi_head_attention( out_proj_bias: Optional[ Union[ivy.Array, ivy.NativeArray, ivy.Container] ] = None, - is_causal: Optional[Union[bool, ivy.Container]] = False, - return_attention_weights: Optional[Union[bool, ivy.Container]] = False, - average_attention_weights: Optional[Union[bool, ivy.Container]] = True, - dropout: Optional[Union[float, ivy.Container]] = 0.0, - training: Optional[Union[bool, ivy.Container]] = False, + is_causal: Union[bool, ivy.Container] = False, + return_attention_weights: Union[bool, ivy.Container] = False, + average_attention_weights: Union[bool, ivy.Container] = True, + dropout: Union[float, ivy.Container] = 0.0, + training: Union[bool, ivy.Container] = False, key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, out: Optional[Union[ivy.Array, ivy.Container]] = None, - ) -> Union[ivy.Array, ivy.NativeArray, ivy.Container]: + ) -> ivy.Container: return ContainerBase.cont_multi_map_in_function( "multi_head_attention", query, - key, - value, + key=key, + value=value, num_heads=num_heads, scale=scale, attention_mask=attention_mask, @@ -1098,7 +1098,7 @@ def multi_head_attention( value: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, /, *, - num_heads: Optional[Union[int, ivy.Container]] = 8, + num_heads: Union[int, ivy.Container] = 8, scale: Optional[Union[float, ivy.Container]] = None, attention_mask: Optional[ Union[ivy.Array, ivy.NativeArray, ivy.Container] @@ -1122,21 +1122,21 @@ def multi_head_attention( out_proj_bias: Optional[ Union[ivy.Array, ivy.NativeArray, ivy.Container] ] = None, - is_causal: Optional[Union[bool, ivy.Container]] = False, - return_attention_weights: Optional[Union[bool, ivy.Container]] = False, - average_attention_weights: Optional[Union[bool, ivy.Container]] = True, - dropout: Optional[Union[float, ivy.Container]] = 0.0, - training: Optional[Union[bool, ivy.Container]] = False, + is_causal: Union[bool, 
ivy.Container] = False,
+ return_attention_weights: Union[bool, ivy.Container] = False,
+ average_attention_weights: Union[bool, ivy.Container] = True,
+ dropout: Union[float, ivy.Container] = 0.0,
+ training: Union[bool, ivy.Container] = False,
 key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
 to_apply: Union[bool, ivy.Container] = True,
 prune_unapplied: Union[bool, ivy.Container] = False,
 map_sequences: Union[bool, ivy.Container] = False,
 out: Optional[Union[ivy.Array, ivy.Container]] = None,
- ) -> Union[ivy.Array, ivy.NativeArray, ivy.Container]:
+ ) -> ivy.Container:
 return self._static_multi_head_attention(
 self,
- key,
- value,
+ key=key,
+ value=value,
 num_heads=num_heads,
 scale=scale,
 attention_mask=attention_mask,
@@ -2105,7 +2105,7 @@ def _static_conv3d(
 filter_format
 Either "channel_first" or "channel_last". Defaults to "channel_last".
 x_dilations
- The dilation factor for each dimension of input. (Default value = 1)
+ The dilation factor for each dimension of input. (Default value = 1)
 dilations
 The dilation factor for each dimension of input. (Default value = 1)
 bias
@@ -2190,7 +2190,7 @@ def conv3d(
 filter_format
 Either "channel_first" or "channel_last". Defaults to "channel_last".
 x_dilations
- The dilation factor for each dimension of input. (Default value = 1)
+ The dilation factor for each dimension of input. (Default value = 1)
 dilations
 The dilation factor for each dimension of input. (Default value = 1)
 bias
diff --git a/ivy/data_classes/container/linear_algebra.py b/ivy/data_classes/container/linear_algebra.py
index e68d1e92d86b6..d9b7fa142b492 100644
--- a/ivy/data_classes/container/linear_algebra.py
+++ b/ivy/data_classes/container/linear_algebra.py
@@ -1416,7 +1416,7 @@ def _static_matrix_norm(
 Parameters
 ----------
 x
- Input array having shape (..., M, N) and whose innermost two deimensions
+ Input array having shape (..., M, N) and whose innermost two dimensions
 form MxN matrices. Should have a floating-point data type.
 ord
 Order of the norm. Default is "fro".
@@ -1445,7 +1445,7 @@ def _static_matrix_norm(
 -------
 ret
 Matrix norm of the array at specified axes.
-
+
 Examples
 --------
 >>> x = ivy.Container(a=ivy.array([[1.1, 2.2], [1., 2.]]), \
@@ -1466,7 +1466,7 @@ def _static_matrix_norm(
 >>> print(y)
 {
 a: ivy.array([4.24, 11.4, 19.2]),
- b: ivy.array([[[3.7]],
+ b: ivy.array([[[3.7]],
 [[11.2]]])
 }
 """
@@ -1504,7 +1504,7 @@ def matrix_norm(
 Parameters
 ----------
 self
- Container having shape (..., M, N) and whose innermost two dimensions
+ Container having shape (..., M, N) and whose innermost two dimensions
 form MxN matrices. Should have a floating-point data type.
 ord
 Order of the norm. Default is "fro".
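The multi_head_attention hunks above move ``key`` and ``value`` to keyword-only in the static variant (the instance method keeps them positional-only but now forwards them by keyword) and narrow the return annotation to ivy.Container. A minimal sketch of the resulting calling conventions; the container name and shapes are hypothetical, and the self-attention fallback for omitted key/value is an assumption, not taken from this diff:

import ivy

# hypothetical (batch, sequence, embed_dim) activations; embed_dim must be
# divisible by num_heads
q = ivy.Container(a=ivy.random_normal(shape=(2, 4, 16)))

# instance method: key/value stay positional-only; omitting them is assumed
# to fall back to self-attention on the query
out = q.multi_head_attention(q, q, num_heads=8)

# static variant: key/value are now keyword-only
out = ivy.Container._static_multi_head_attention(q, key=q, value=q, num_heads=8)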
@@ -1553,8 +1553,8 @@ def matrix_norm(
 >>> y = x.matrix_norm(ord=ord, axis=axis, keepdims=k)
 >>> print(y)
 {
- a: ivy.array([[[4.24]],
- [[11.4]],
+ a: ivy.array([[[4.24]],
+ [[11.4]],
 [[19.2]]]),
 b: ivy.array([4., 12.])
 }
@@ -3313,3 +3313,106 @@ def vander(
 increasing=increasing,
 out=out,
 )
+
+ @staticmethod
+ def static_general_inner_product(
+ x1: Union[ivy.Container, ivy.Array, ivy.NativeArray],
+ x2: Union[ivy.Container, ivy.Array, ivy.NativeArray],
+ n_modes: Optional[Union[int, ivy.Container]] = None,
+ /,
+ *,
+ key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
+ to_apply: Union[bool, ivy.Container] = True,
+ prune_unapplied: Union[bool, ivy.Container] = False,
+ map_sequences: Union[bool, ivy.Container] = False,
+ out: Optional[ivy.Container] = None,
+ ) -> ivy.Container:
+ """
+ ivy.Container static method variant of ivy.general_inner_product. This method
+ simply wraps the function, and so the docstring for ivy.general_inner_product
+ also applies to this method with minimal changes.
+
+ Parameters
+ ----------
+ x1
+ First input container containing input array.
+ x2
+ Second input container containing input array.
+ n_modes
+ int, default is None. If None, the traditional inner product is returned
+ (i.e. a float); otherwise, the product between the `n_modes` last modes of
+ `x1` and the `n_modes` first modes of `x2` is returned. The resulting
+ tensor's order is `len(x1) - n_modes`.
+ key_chains
+ The key-chains to apply or not apply the method to. Default is ``None``.
+ to_apply
+ If True, the method will be applied to key_chains, otherwise key_chains
+ will be skipped. Default is ``True``.
+ prune_unapplied
+ Whether to prune key_chains for which the function was not applied.
+ Default is ``False``.
+ map_sequences
+ Whether to also map method to sequences (lists, tuples).
+ Default is ``False``.
+ out
+ Alternate output container in which to place the result.
+ The default is None.
+
+ Returns
+ -------
+ ret
+ Container including the inner product tensor.
+
+ Examples
+ --------
+ >>> x = ivy.Container(
+ a=ivy.reshape(ivy.arange(4), (2, 2)),
+ b=ivy.reshape(ivy.arange(8), (2, 4)),
+ )
+ >>> ivy.Container.general_inner_product(x, 1)
+ {
+ a: ivy.array(6),
+ b: ivy.array(28)
+ }
+ """
+ return ContainerBase.cont_multi_map_in_function(
+ "general_inner_product",
+ x1,
+ x2,
+ n_modes,
+ key_chains=key_chains,
+ to_apply=to_apply,
+ prune_unapplied=prune_unapplied,
+ map_sequences=map_sequences,
+ out=out,
+ )
+
+ def general_inner_product(
+ self: Union[ivy.Container, ivy.Array, ivy.NativeArray],
+ x2: Union[ivy.Container, ivy.Array, ivy.NativeArray],
+ n_modes: Optional[Union[int, ivy.Container]] = None,
+ /,
+ *,
+ key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
+ to_apply: Union[bool, ivy.Container] = True,
+ prune_unapplied: Union[bool, ivy.Container] = False,
+ map_sequences: Union[bool, ivy.Container] = False,
+ out: Optional[ivy.Container] = None,
+ ) -> ivy.Container:
+ """
+ ivy.Container instance method variant of ivy.general_inner_product.
+
+ This method simply wraps the function, and so the docstring for
+ ivy.general_inner_product also applies to this method with
+ minimal changes.
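+
+ Examples
+ --------
+ A sketch mirroring the static example above (``x2 = 1`` broadcasts, so
+ each leaf reduces to a plain sum):
+
+ >>> x = ivy.Container(
+ ... a=ivy.reshape(ivy.arange(4), (2, 2)),
+ ... b=ivy.reshape(ivy.arange(8), (2, 4)),
+ ... )
+ >>> x.general_inner_product(1)
+ {
+ a: ivy.array(6),
+ b: ivy.array(28)
+ }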
+ """ + return self.static_general_inner_product( + self, + x2, + n_modes, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) diff --git a/ivy/engines/XLA/__init__.py b/ivy/engines/XLA/__init__.py new file mode 100644 index 0000000000000..45a9584091f50 --- /dev/null +++ b/ivy/engines/XLA/__init__.py @@ -0,0 +1,14 @@ +from .rust_api.python_frontend.xla_core import * +from .rust_api.python_frontend.layers import * +from .rust_api.python_frontend.manipulation import * +from .rust_api.python_frontend.activations import * +from .rust_api.python_frontend.norms import * +from .rust_api.python_frontend.stateful_layers import * + +# from .rust_api.python_frontend.sequential_handler import * +from .rust_api.python_frontend.general import * +from .rust_api.python_frontend.manipulation import * +from .rust_api.python_frontend.creation import * +from .rust_api.python_frontend.linear_algebra import * +from .rust_api.python_frontend.elementwise import * +from .rust_api.python_frontend.statistical import * diff --git a/ivy/engines/XLA/rust_api/Cargo.toml b/ivy/engines/XLA/rust_api/Cargo.toml new file mode 100644 index 0000000000000..1d77407f33527 --- /dev/null +++ b/ivy/engines/XLA/rust_api/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "xlar" +version = "0.1.0" +edition = "2021" + +[lib] +name = "xlar" +crate-type = ["cdylib"] + +[dependencies] +thiserror = "1" +libc = "0.2" +num-traits = "0.2" +num-derive = "0.3" +zip = "0.6.4" +pyo3 = { version = "0.19.1", features = ["extension-module"] } +ndarray = "0.15.6" +numpy = "0.19.0" +half = "2.3.1" + +[build-dependencies] +bindgen = "0.64" +cc = "1.0" + +[dev-dependencies] +anyhow = "1.0" +clap = { version = "4.2.4", features = ["derive"] } +fancy-regex = "0.11.0" +rand = "0.8.5" +serde_json = "1.0.96" \ No newline at end of file diff --git a/ivy/engines/XLA/rust_api/build.rs b/ivy/engines/XLA/rust_api/build.rs new file mode 100644 index 0000000000000..db625ed20570d --- /dev/null +++ b/ivy/engines/XLA/rust_api/build.rs @@ -0,0 +1,69 @@ +extern crate bindgen; + +use std::env; +use std::path::{Path, PathBuf}; + +fn make_shared_lib>(xla_dir: P) { + let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); + println!("cargo:rerun-if-changed=xla_rs/xla_rs.cc"); + println!("cargo:rerun-if-changed=xla_rs/xla_rs.h"); + match os.as_str() { + "linux" | "macos" => { + cc::Build::new() + .cpp(true) + .pic(true) + .warnings(false) + .include(xla_dir.as_ref().join("include")) + .flag("-std=c++17") + .flag("-Wno-deprecated-declarations") + .flag("-DLLVM_ON_UNIX=1") + .file("xla_rs/xla_rs.cc") + .compile("xla_rs"); + } + "windows" => { + cc::Build::new() + .cpp(true) + .pic(true) + .warnings(false) + .include(xla_dir.as_ref().join("include")) + .file("xla_rs/xla_rs.cc") + .compile("xla_rs"); + } + _ => panic!("Unsupported OS"), + }; +} + +fn env_var_rerun(name: &str) -> Option { + println!("cargo:rerun-if-env-changed={name}"); + env::var(name).ok() +} + +fn main() { + let xla_dir = env_var_rerun("XLA_EXTENSION_DIR") + .map_or_else(|| env::current_dir().unwrap().join("xla_extension"), PathBuf::from); + + println!("cargo:rerun-if-changed=xla_rs/xla_rs.h"); + println!("cargo:rerun-if-changed=xla_rs/xla_rs.cc"); + let bindings = bindgen::Builder::default() + .header("xla_rs/xla_rs.h") + .parse_callbacks(Box::new(bindgen::CargoCallbacks)) + .generate() + .expect("Unable to generate bindings"); + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + 
bindings.write_to_file(out_path.join("c_xla.rs")).expect("Couldn't write bindings!"); + + // Exit early on docs.rs as the C++ library would not be available. + if std::env::var("DOCS_RS").is_ok() { + return; + } + make_shared_lib(&xla_dir); + // The --copy-dt-needed-entries -lstdc++ are helpful to get around some + // "DSO missing from command line" error + // undefined reference to symbol '_ZStlsIcSt11char_traitsIcESaIcEERSt13basic_ostreamIT_T0_ES7_RKNSt7__cxx1112basic_stringIS4_S5_T1_EE@@GLIBCXX_3.4.21' + println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries"); + println!("cargo:rustc-link-arg=-Wl,-lstdc++"); + println!("cargo:rustc-link-search=native={}", xla_dir.join("lib").display()); + println!("cargo:rustc-link-lib=static=xla_rs"); + println!("cargo:rustc-link-arg=-Wl,-rpath={}", xla_dir.join("lib").display()); + println!("cargo:rustc-link-lib=xla_extension"); +} \ No newline at end of file diff --git a/ivy/engines/XLA/rust_api/pyproject.toml b/ivy/engines/XLA/rust_api/pyproject.toml new file mode 100644 index 0000000000000..e0dd83db89328 --- /dev/null +++ b/ivy/engines/XLA/rust_api/pyproject.toml @@ -0,0 +1,12 @@ +[build-system] +requires = ["maturin>=1,<2"] +build-backend = "maturin" + +[project] +name = "pyo3_example" +requires-python = ">=3.7" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] \ No newline at end of file diff --git a/ivy/engines/XLA/rust_api/python_frontend/activations.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/activations.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..d8a48848a7c25 Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/activations.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/creation.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/creation.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..6da24a886f08c Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/creation.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/elementwise.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/elementwise.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..20b1a80e9dddf Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/elementwise.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/general.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/general.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..5e29f2149416f Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/general.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/layers.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/layers.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..5ec7994989713 Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/layers.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/linear_algebra.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/linear_algebra.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..14e20168c7c81 Binary files /dev/null and 
b/ivy/engines/XLA/rust_api/python_frontend/linear_algebra.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/manipulation.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/manipulation.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..2fd2e6219b492 Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/manipulation.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/norms.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/norms.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..7464aa005cede Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/norms.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/stateful_activations.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/stateful_activations.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..60cf781e0bcd5 Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/stateful_activations.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/stateful_layers.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/stateful_layers.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..1ed2a0484ebc9 Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/stateful_layers.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/stateful_norms.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/stateful_norms.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..7685c5504ce4c Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/stateful_norms.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/statistical.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/statistical.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..b17c5c0d30b01 Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/statistical.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/python_frontend/xla_core.cpython-310-x86_64-linux-gnu.so b/ivy/engines/XLA/rust_api/python_frontend/xla_core.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000..a60b735859455 Binary files /dev/null and b/ivy/engines/XLA/rust_api/python_frontend/xla_core.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/XLA/rust_api/src/c_lib.rs b/ivy/engines/XLA/rust_api/src/c_lib.rs new file mode 100644 index 0000000000000..556531efa3533 --- /dev/null +++ b/ivy/engines/XLA/rust_api/src/c_lib.rs @@ -0,0 +1,6 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(dead_code)] + +include!(concat!(env!("OUT_DIR"), "/c_xla.rs")); diff --git a/ivy/engines/XLA/rust_api/src/error.rs b/ivy/engines/XLA/rust_api/src/error.rs new file mode 100644 index 0000000000000..08c53368034f7 --- /dev/null +++ b/ivy/engines/XLA/rust_api/src/error.rs @@ -0,0 +1,89 @@ +use pyo3::prelude::*; +use pyo3::exceptions::{PyOSError}; +use std::str::Utf8Error; + +/// Main library error type. +#[derive(thiserror::Error, Debug)] +pub enum Error { + /// Incorrect number of elements. 
+ #[error("wrong element count {element_count} for dims {dims:?}")] + WrongElementCount { dims: Vec, element_count: usize }, + + /// Error from the xla C++ library. + #[error("xla error {msg}\n{backtrace}")] + XlaError { msg: String, backtrace: String }, + + #[error("unexpected element type {0}")] + UnexpectedElementType(i32), + + #[error("unexpected number of dimensions, expected: {expected}, got: {got} ({dims:?})")] + UnexpectedNumberOfDims { expected: usize, got: usize, dims: Vec }, + + #[error("not an element type, got: {got:?}")] + NotAnElementType { got: crate::PrimitiveType }, + + #[error("not an array, expected: {expected:?}, got: {got:?}")] + NotAnArray { expected: Option, got: crate::Shape }, + + #[error("cannot handle unsupported shapes {shape:?}")] + UnsupportedShape { shape: crate::Shape }, + + #[error("unexpected number of tuple elements, expected: {expected}, got: {got}")] + UnexpectedNumberOfElemsInTuple { expected: usize, got: usize }, + + #[error("element type mismatch, on-device: {on_device:?}, on-host: {on_host:?}")] + ElementTypeMismatch { on_device: crate::ElementType, on_host: crate::ElementType }, + + #[error("unsupported element type for {op}: {ty:?}")] + UnsupportedElementType { ty: crate::PrimitiveType, op: &'static str }, + + #[error( + "target buffer is too large, offset {offset}, shape {shape:?}, buffer_len: {buffer_len}" + )] + TargetBufferIsTooLarge { offset: usize, shape: crate::ArrayShape, buffer_len: usize }, + + #[error("binary buffer is too large, element count {element_count}, buffer_len: {buffer_len}")] + BinaryBufferIsTooLarge { element_count: usize, buffer_len: usize }, + + #[error("empty literal")] + EmptyLiteral, + + #[error("index out of bounds {index}, rank {rank}")] + IndexOutOfBounds { index: i64, rank: usize }, + + #[error("npy/npz error {0}")] + Npy(String), + + /// I/O error. + #[error(transparent)] + Io(#[from] std::io::Error), + + /// Zip file format error. + #[error(transparent)] + Zip(#[from] zip::result::ZipError), + + /// Integer parse error. 
+ #[error(transparent)]
+ ParseInt(#[from] std::num::ParseIntError),
+
+ #[error("cannot create literal with shape {ty:?} {dims:?} from bytes data with len {data_len_in_bytes}")]
+ CannotCreateLiteralWithData {
+ data_len_in_bytes: usize,
+ ty: crate::PrimitiveType,
+ dims: Vec<usize>,
+ },
+
+ #[error("invalid dimensions in matmul, lhs: {lhs_dims:?}, rhs: {rhs_dims:?}, {msg}")]
+ MatMulIncorrectDims { lhs_dims: Vec<i64>, rhs_dims: Vec<i64>, msg: &'static str },
+
+ #[error("Invalid UTF-8 data: {0}")]
+ Utf8Error(#[from] Utf8Error),
+}
+
+impl From<Error> for PyErr {
+ fn from(err: Error) -> PyErr {
+ PyOSError::new_err(err.to_string())
+ }
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
diff --git a/ivy/engines/XLA/rust_api/src/lib.rs b/ivy/engines/XLA/rust_api/src/lib.rs
new file mode 100644
index 0000000000000..06df9110db4ff
--- /dev/null
+++ b/ivy/engines/XLA/rust_api/src/lib.rs
@@ -0,0 +1,1821 @@
+mod c_lib;
+mod error;
+mod wrappers;
+
+use std::rc::Rc;
+pub use error::{Error, Result};
+pub use wrappers::*;
+use pyo3::prelude::*;
+use ndarray::{ArrayD};
+use numpy::{PyArrayDyn, ToPyArray};
+use half::{f16, bf16};
+use pyo3::exceptions::PyTypeError;
+use pyo3::{exceptions, wrap_pyfunction};
+
+
+#[derive(Debug, Copy, Clone)]
+pub enum TfLogLevel {
+ Info,
+ Warning,
+ Error,
+ Fatal,
+}
+
+impl TfLogLevel {
+ fn as_env_variable_str(&self) -> &'static str {
+ match self {
+ Self::Info => "0",
+ Self::Warning => "1",
+ Self::Error => "2",
+ Self::Fatal => "3",
+ }
+ }
+}
+
+pub fn set_tf_min_log_level(log_level: TfLogLevel) {
+ std::env::set_var("TF_CPP_MIN_LOG_LEVEL", log_level.as_env_variable_str())
+}
+
+
+#[derive(Debug)]
+enum ArrayDyn {
+ Pred(ArrayD<bool>),
+ I8(ArrayD<i8>),
+ I16(ArrayD<i16>),
+ I32(ArrayD<i32>),
+ I64(ArrayD<i64>),
+ U8(ArrayD<u8>),
+ U16(ArrayD<u16>),
+ U32(ArrayD<u32>),
+ U64(ArrayD<u64>),
+ Bf16(ArrayD<bf16>),
+ F16(ArrayD<f16>),
+ F32(ArrayD<f32>),
+ F64(ArrayD<f64>),
+}
+
+#[derive(Debug)]
+#[pyclass(unsendable)]
+pub struct Tensor {
+ x: ArrayDyn
+}
+
+impl From<ArrayD<bool>> for Tensor {
+ fn from(x: ArrayD<bool>) -> Self {
+ Tensor {
+ x: ArrayDyn::Pred(x),
+ }
+ }
+}
+
+impl From<ArrayD<i8>> for Tensor {
+ fn from(x: ArrayD<i8>) -> Self {
+ Tensor {
+ x: ArrayDyn::I8(x),
+ }
+ }
+}
+
+impl From<ArrayD<i16>> for Tensor {
+ fn from(x: ArrayD<i16>) -> Self {
+ Tensor {
+ x: ArrayDyn::I16(x),
+ }
+ }
+}
+
+impl From<ArrayD<i32>> for Tensor {
+ fn from(x: ArrayD<i32>) -> Self {
+ Tensor {
+ x: ArrayDyn::I32(x),
+ }
+ }
+}
+
+impl From<ArrayD<i64>> for Tensor {
+ fn from(x: ArrayD<i64>) -> Self {
+ Tensor {
+ x: ArrayDyn::I64(x),
+ }
+ }
+}
+
+impl From<ArrayD<u8>> for Tensor {
+ fn from(x: ArrayD<u8>) -> Self {
+ Tensor {
+ x: ArrayDyn::U8(x),
+ }
+ }
+}
+
+impl From<ArrayD<u16>> for Tensor {
+ fn from(x: ArrayD<u16>) -> Self {
+ Tensor {
+ x: ArrayDyn::U16(x),
+ }
+ }
+}
+
+impl From<ArrayD<u32>> for Tensor {
+ fn from(x: ArrayD<u32>) -> Self {
+ Tensor {
+ x: ArrayDyn::U32(x),
+ }
+ }
+}
+
+impl From<ArrayD<u64>> for Tensor {
+ fn from(x: ArrayD<u64>) -> Self {
+ Tensor {
+ x: ArrayDyn::U64(x),
+ }
+ }
+}
+
+impl From<ArrayD<bf16>> for Tensor {
+ fn from(x: ArrayD<bf16>) -> Self {
+ Tensor {
+ x: ArrayDyn::Bf16(x),
+ }
+ }
+}
+
+impl From<ArrayD<f16>> for Tensor {
+ fn from(x: ArrayD<f16>) -> Self {
+ Tensor {
+ x: ArrayDyn::F16(x),
+ }
+ }
+}
+
+impl From<ArrayD<f32>> for Tensor {
+ fn from(x: ArrayD<f32>) -> Self {
+ Tensor {
+ x: ArrayDyn::F32(x),
+ }
+ }
+}
+
+impl From<ArrayD<f64>> for Tensor {
+ fn from(x: ArrayD<f64>) -> Self {
+ Tensor {
+ x: ArrayDyn::F64(x),
+ }
+ }
+}
+
+
+#[pymethods]
+impl Tensor {
+ fn __repr__(&self) -> PyResult<String> {
+ let desc = match &self.x {
+ ArrayDyn::Pred(array) => format!("{:?}", array),
+ ArrayDyn::I8(array) => format!("{:?}", array),
+ ArrayDyn::I16(array) => format!("{:?}", array),
+ ArrayDyn::I32(array) => format!("{:?}",
array), + ArrayDyn::I64(array) => format!("{:?}", array), + ArrayDyn::U8(array) => format!("{:?}", array), + ArrayDyn::U16(array) => format!("{:?}", array), + ArrayDyn::U32(array) => format!("{:?}", array), + ArrayDyn::U64(array) => format!("{:?}", array), + ArrayDyn::Bf16(array) => format!("{:?}", array), + ArrayDyn::F16(array) => format!("{:?}", array), + ArrayDyn::F32(array) => format!("{:?}", array), + ArrayDyn::F64(array) => format!("{:?}", array), + }; + Ok(format!("Tensor({})", desc)) + } +} + +#[derive(Clone, Debug)] +#[pyclass(unsendable)] +struct Bf16Array { + x: Py> +} +impl From>> for Bf16Array { + fn from(x: Py>) -> Self { + Bf16Array { + x + } + } +} + +#[derive(Clone, Debug)] +#[pyclass(unsendable)] +struct F16Array { + x: Py> +} +impl From>> for F16Array { + fn from(x: Py>) -> Self { + F16Array { + x + } + } +} + +#[pyfunction] +fn create_bf16_array(x: Py>) -> PyResult { + let x = Bf16Array{x}; + Ok(x) +} + +#[pyfunction] +fn create_f16_array(x: Py>) -> PyResult { + let x = F16Array{x}; + Ok(x) +} + +#[derive(Debug)] +enum DynamicPyArray { + Pred(Py>), + I8(Py>), + I16(Py>), + I32(Py>), + I64(Py>), + U8(Py>), + U16(Py>), + U32(Py>), + U64(Py>), + Bf16(Bf16Array), + F16(F16Array), + F32(Py>), + F64(Py>), +} + +impl<'source> FromPyObject<'source> for DynamicPyArray { + fn extract(obj: &'source PyAny) -> PyResult { + if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::Pred(arr)) + } + else if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::I8(arr)) + } + else if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::I16(arr)) + } + else if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::I32(arr)) + } + else if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::I64(arr)) + } + else if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::U8(arr)) + } + else if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::U16(arr)) + } + else if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::U32(arr)) + } + else if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::U64(arr)) + } + else if let Ok(arr) = obj.extract::() { + Ok(DynamicPyArray::Bf16(arr)) + } + else if let Ok(arr) = obj.extract::() { + Ok(DynamicPyArray::F16(arr)) + } + else if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::F32(arr)) + } + else if let Ok(arr) = obj.extract::>>() { + Ok(DynamicPyArray::F64(arr)) + } + else { + Err(PyErr::from(PyTypeError::new_err( + "Expected a numpy array of one of the valid types", + ))) + } + } +} + +#[pyfunction] +fn constant_array(py: Python, array: DynamicPyArray, builder: XlaBuilder) -> PyResult { + match array { + DynamicPyArray::Pred(py_array) => { + let x = Literal::vec1(unsafe { py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice() }); + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::I8(py_array) => { + let x = Literal::vec1(unsafe { py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice() }); + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::I16(py_array) => { + let x = Literal::vec1(unsafe { py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice() }); + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::I32(py_array) => { + let x = Literal::vec1(unsafe { py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice() }); + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::I64(py_array) => { + let x = Literal::vec1(unsafe { py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice() }); + let x = 
builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::U8(py_array) => { + let x = Literal::vec1(unsafe { py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice() }); + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::U16(py_array) => { + let x = Literal::vec1(unsafe { py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice() }); + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::U32(py_array) => { + let x = Literal::vec1(unsafe { py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice() }); + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::U64(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::Bf16(py_array) => { + let x = Literal::vec1(unsafe {py_array.x.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}).convert(PrimitiveType::Bf16)?; + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::F16(py_array) => { + let x = Literal::vec1(unsafe {py_array.x.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}).convert(PrimitiveType::F16)?; + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::F32(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + let x = builder.constant_literal(&x)?; + Ok(x) + }, + DynamicPyArray::F64(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + let x = builder.constant_literal(&x)?; + Ok(x) + }, + } +} + + +#[pyfunction] +fn gather_params(py: Python, arrays: Vec) -> PyResult> { + let mut literals = Vec::with_capacity(arrays.len()); + for array in arrays { + match array { + DynamicPyArray::Pred(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + DynamicPyArray::I8(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + DynamicPyArray::I16(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + DynamicPyArray::I32(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + DynamicPyArray::I64(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + DynamicPyArray::U8(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + DynamicPyArray::U16(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + DynamicPyArray::U32(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + DynamicPyArray::U64(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + DynamicPyArray::Bf16(py_array) => { + let x = Literal::vec1(unsafe {py_array.x.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}).convert(PrimitiveType::Bf16)?; + literals.push(x); + }, + 
DynamicPyArray::F16(py_array) => { + let x = Literal::vec1(unsafe {py_array.x.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}).convert(PrimitiveType::F16)?; + literals.push(x); + }, + DynamicPyArray::F32(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + DynamicPyArray::F64(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + literals.push(x); + }, + } + } + Ok(literals) +} + +#[pyfunction] +fn new_input(py: Python, input: DynamicPyArray) -> PyResult { + match input { + DynamicPyArray::Pred(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + DynamicPyArray::I8(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + DynamicPyArray::I16(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + DynamicPyArray::I32(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + DynamicPyArray::I64(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + DynamicPyArray::U8(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + DynamicPyArray::U16(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + DynamicPyArray::U32(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + DynamicPyArray::U64(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + DynamicPyArray::Bf16(py_array) => { + let x = Literal::vec1(unsafe {py_array.x.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}).convert(PrimitiveType::Bf16)?; + Ok(x) + }, + DynamicPyArray::F16(py_array) => { + let x = Literal::vec1(unsafe {py_array.x.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}).convert(PrimitiveType::F16)?; + Ok(x) + }, + DynamicPyArray::F32(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + DynamicPyArray::F64(py_array) => { + let x = Literal::vec1(unsafe {py_array.as_ref(py).as_array().to_owned().into_raw_vec().as_slice()}); + Ok(x) + }, + } +} + +#[pyfunction] +fn swap_param(x: Literal, mut params: Vec) -> PyResult> { + params[0] = x; + Ok(params) +} + +#[pyfunction] +fn to_tensor(literal: Literal) -> PyResult { + let shape = literal.shape().unwrap(); + let shape = ArrayShape::try_from(&shape).unwrap(); + let shape: Vec = shape.dims().iter().map(|&x| x as usize).collect(); + + match literal.ty().unwrap() { + ElementType::Pred => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::S8 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::S16 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + 
ElementType::S32 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::S64 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::U8 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::U16 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::U32 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::U64 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::Bf16 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::F16 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::F32 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + ElementType::F64 => { + let data: Vec = literal.to_vec().unwrap(); + let array = ArrayD::from_shape_vec(shape, data).unwrap(); + Ok(Tensor::from(array)) + } + _ => Err(PyErr::from(PyTypeError::new_err( + "Unsupported date type", + ))) + + } +} + +#[pyfunction] +fn to_numpy(py: Python, literal: Literal) -> PyResult { + let shape = literal.shape().unwrap(); + let shape = ArrayShape::try_from(&shape).unwrap(); + let shape: Vec = shape.dims().iter().map(|&x| x as usize).collect(); + + match literal.ty().unwrap() { + ElementType::Pred => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::S8 => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::S16 => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::S32 => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::S64 => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::U8 => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::U16 => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::U32 => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::U64 => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::Bf16 | ElementType::F16 => { + let literal = literal.convert(PrimitiveType::F32)?; + let data: Vec 
= literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::F32 => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + ElementType::F64 => { + let data: Vec = literal.to_vec()?; + let array = ArrayD::from_shape_vec(shape, data).unwrap().to_pyarray(py); + Ok(array.to_object(py)) + } + _ => Err(PyErr::from(PyTypeError::new_err( + "Unsupported data type", + ))) + } +} + +#[pyfunction] +fn to_tuple(literal: Literal) -> PyResult> { + let y = literal.to_tuple()?; + Ok(y) +} + + +macro_rules! param_gen { + ($name:ident, $type:ty) => { + #[pyfunction] + fn $name(builder: XlaBuilder, param_number: i64, dims: Vec, name: &str) -> PyResult { + let shape = &Shape::array::<$type>(dims); + let param = builder.parameter_s(param_number, shape, name)?; + Ok(param) + } + } +} + +param_gen!(param_pred, bool); +param_gen!(param_i8, i8); +param_gen!(param_i16, i16); +param_gen!(param_i32, i32); +param_gen!(param_i64, i64); +param_gen!(param_u8, u8); +param_gen!(param_u16, u16); +param_gen!(param_u32, u32); +param_gen!(param_u64, u64); +param_gen!(param_bf16, Bf16); +param_gen!(param_f16, F16); +param_gen!(param_f32, f32); +param_gen!(param_f64, f64); + + +macro_rules! constant { + ($name:ident, $type:ty) => { + #[pyfunction] + fn $name(b: XlaBuilder, v: $type) -> PyResult { + let c = b.c0(v)?; + Ok(c) + } + }; +} + +constant!(constant_bool, bool); +constant!(constant_i8, i8); +constant!(constant_i16, i16); +constant!(constant_i32, i32); +constant!(constant_i64, i64); +constant!(constant_u8, u8); +constant!(constant_u16, u16); +constant!(constant_u32, u32); +constant!(constant_u64, u64); +constant!(constant_f32, f32); +constant!(constant_f64, f64); + + +macro_rules! 
astype { + ($name:ident, $primitive:ident) => { + #[pyfunction] + fn $name(x: XlaOp) -> PyResult { + let y = x.astype(PrimitiveType::$primitive)?; + Ok(y) + } + }; +} + +astype!(astype_bool, Pred); +astype!(astype_i8, S8); +astype!(astype_i16, S16); +astype!(astype_i32, S32); +astype!(astype_i64, S64); +astype!(astype_u8, U8); +astype!(astype_u16, U16); +astype!(astype_u32, U32); +astype!(astype_u64, U64); +astype!(astype_bf16, Bf16); +astype!(astype_f16, F16); +astype!(astype_f32, F32); +astype!(astype_f64, F64); + + +#[pyfunction] +fn cpu_client() -> PyResult { + let client = PjRtClient::cpu()?; + Ok(client) +} + +#[pyfunction] +fn gpu_client(memory_fraction: f64, preallocate: bool) -> PyResult { + let client = PjRtClient::gpu(memory_fraction, preallocate)?; + Ok(client) +} + +#[pyfunction] +fn xla_builder(name: &str) -> PyResult { + let builder = XlaBuilder::new(name); + Ok(builder) +} + +#[pyfunction] +fn build(op: XlaOp) -> PyResult { + let computation = op.build()?; + Ok(computation) +} + +#[pyfunction] +fn get_hlo_proto(comp: &XlaComputation) -> PyResult { + let hlo_proto = comp.proto(); + Ok(hlo_proto) +} + +#[pyfunction] +fn hlo_module_from_proto(proto: &HloModuleProto) -> PyResult { + let hlo_module = HloModule::from_proto(proto)?; + Ok(hlo_module) +} + +#[pyfunction] +fn hlo_module_to_string(module: &HloModule) -> PyResult { + let module_str = module.to_string()?; + Ok(module_str) +} + +#[pyfunction] +fn get_hlo_module_entry_computation(module: &HloModule) -> PyResult { + let hlo_comp = module.get_entry_computation()?; + Ok(hlo_comp) +} + +#[pyfunction] +fn computation_count(module: &HloModule) -> PyResult { + let comp_count = module.computation_count()?; + Ok(comp_count) +} + +#[pyfunction] +fn instruction_count(module: &HloModule) -> PyResult { + let instruct_count = module.instruction_count()?; + Ok(instruct_count) +} + +#[pyfunction] +fn compile(client: PjRtClient, computation: &XlaComputation) -> PyResult { + let executable = client.compile(computation)?; + Ok(executable) +} + +#[pyfunction] +fn execute(executable: &PjRtLoadedExecutable, args: Vec) -> PyResult { + let buffer = executable.execute::(args.as_slice())?[0].remove(0); + Ok(buffer) +} + +#[pyfunction] +fn to_literal(buffer: &PjRtBuffer) -> PyResult { + let literal = buffer.to_literal_sync()?; + Ok(literal) +} + +#[pyfunction] +fn add(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.add_(rhs)?; + Ok(y) +} + +#[pyfunction] +fn sub(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.sub_(rhs)?; + Ok(y) +} + +#[pyfunction] +fn mul(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.mul_(rhs)?; + Ok(y) +} + +#[pyfunction] +fn div(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.div_(rhs)?; + Ok(y) +} + +#[pyfunction] +fn rem(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.rem_(rhs)?; + Ok(y) +} + +#[pyfunction] +fn pow(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.pow(rhs)?; + Ok(y) +} + +#[pyfunction] +fn max(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.max(rhs)?; + Ok(y) +} + +#[pyfunction] +fn min(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.min(rhs)?; + Ok(y) +} + +#[pyfunction] +fn _and(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.and(rhs)?; + Ok(y) +} + +#[pyfunction] +fn _or(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.or(rhs)?; + Ok(y) +} + +#[pyfunction] +fn xor(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.xor(rhs)?; + Ok(y) +} + +#[pyfunction] +fn eq(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.eq(rhs)?; + Ok(y) +} + +#[pyfunction] +fn 
ne(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.ne(rhs)?; + Ok(y) +} + +#[pyfunction] +fn ge(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.ge(rhs)?; + Ok(y) +} + +#[pyfunction] +fn gt(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.gt(rhs)?; + Ok(y) +} + +#[pyfunction] +fn le(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.le(rhs)?; + Ok(y) +} + +#[pyfunction] +fn lt(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.lt(rhs)?; + Ok(y) +} + +#[pyfunction] +fn lshift(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.lshift(rhs)?; + Ok(y) +} + +#[pyfunction] +fn rshift(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.rshift_arith(rhs)?; + Ok(y) +} + +#[pyfunction] +fn atan2(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.atan2(rhs)?; + Ok(y) +} + +#[pyfunction] +fn dot(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.dot(rhs)?; + Ok(y) +} + +#[pyfunction] +fn matmul(lhs: XlaOp, rhs: &XlaOp) -> PyResult { + let y = lhs.matmul(rhs)?; + Ok(y) +} + +#[pyfunction] +fn population_count(x: XlaOp) -> PyResult { + let y = x.population_count()?; + Ok(y) +} + +#[pyfunction] +fn _not(x: XlaOp) -> PyResult { + let y = x.not()?; + Ok(y) +} + +#[pyfunction] +fn neg(x: XlaOp) -> PyResult { + let y = x.neg()?; + Ok(y) +} + +#[pyfunction] +fn abs(x: XlaOp) -> PyResult { + let y = x.abs()?; + Ok(y) +} + +#[pyfunction] +fn floor(x: XlaOp) -> PyResult { + let y = x.floor()?; + Ok(y) +} + +#[pyfunction] +fn ceil(x: XlaOp) -> PyResult { + let y = x.ceil()?; + Ok(y) +} + +#[pyfunction] +fn round(x: XlaOp) -> PyResult { + let y = x.round()?; + Ok(y) +} + +#[pyfunction] +fn round_nearest_even(x: XlaOp) -> PyResult { + let y = x.round_nearest_even()?; + Ok(y) +} + +#[pyfunction] +fn exp(x: XlaOp) -> PyResult { + let y = x.exp()?; + Ok(y) +} + +#[pyfunction] +fn expm1(x: XlaOp) -> PyResult { + let y = x.expm1()?; + Ok(y) +} + +#[pyfunction] +fn log(x: XlaOp) -> PyResult { + let y = x.log()?; + Ok(y) +} + +#[pyfunction] +fn log1p(x: XlaOp) -> PyResult { + let y = x.log1p()?; + Ok(y) +} + +#[pyfunction] +fn logistic(x: XlaOp) -> PyResult { + let y = x.logistic()?; + Ok(y) +} + +#[pyfunction] +fn sign(x: XlaOp) -> PyResult { + let y = x.sign()?; + Ok(y) +} + +#[pyfunction] +fn clz(x: XlaOp) -> PyResult { + let y = x.clz()?; + Ok(y) +} + +#[pyfunction] +fn sin(x: XlaOp) -> PyResult { + let y = x.sin()?; + Ok(y) +} + +#[pyfunction] +fn cos(x: XlaOp) -> PyResult { + let y = x.cos()?; + Ok(y) +} + +#[pyfunction] +fn tanh(x: XlaOp) -> PyResult { + let y = x.tanh()?; + Ok(y) +} + +#[pyfunction] +fn real(x: XlaOp) -> PyResult { + let y = x.real()?; + Ok(y) +} + +#[pyfunction] +fn imag(x: XlaOp) -> PyResult { + let y = x.imag()?; + Ok(y) +} + +#[pyfunction] +fn conj(x: XlaOp) -> PyResult { + let y = x.conj()?; + Ok(y) +} + +#[pyfunction] +fn square(x: XlaOp) -> PyResult { + let y = x.square()?; + Ok(y) +} + +#[pyfunction] +fn sqrt(x: XlaOp) -> PyResult { + let y = x.sqrt()?; + Ok(y) +} + +#[pyfunction] +fn rsqrt(x: XlaOp) -> PyResult { + let y = x.rsqrt()?; + Ok(y) +} + +#[pyfunction] +fn cbrt(x: XlaOp) -> PyResult { + let y = x.cbrt()?; + Ok(y) +} + +#[pyfunction] +fn upper_triangle(x: XlaOp) -> PyResult { + let y = x.upper_triangle()?; + Ok(y) +} + +#[pyfunction] +fn lower_triangle(x: XlaOp) -> PyResult { + let y = x.lower_triangle()?; + Ok(y) +} + +#[pyfunction] +fn erf(x: XlaOp) -> PyResult { + let y = x.erf()?; + Ok(y) +} + +#[pyfunction] +fn is_finite(x: XlaOp) -> PyResult { + let y = x.is_finite()?; + Ok(y) +} + +#[pyfunction] +fn zeros_like(x: XlaOp) -> PyResult { + let y = 
x.zeros_like()?; + Ok(y) +} + +#[pyfunction] +fn copy(x: XlaOp) -> PyResult { + let y = x.copy()?; + Ok(y) +} + +#[pyfunction] +fn sigmoid(x: XlaOp) -> PyResult { + let y = x.sigmoid()?; + Ok(y) +} + +#[pyfunction] +fn silu(x: XlaOp) -> PyResult { + let y = x.silu()?; + Ok(y) +} + +#[pyfunction] +fn relu(x: XlaOp) -> PyResult { + let y = x.relu()?; + Ok(y) +} + +#[pyfunction] +fn gelu(x: XlaOp) -> PyResult { + let y = x.gelu()?; + Ok(y) +} + +#[pyfunction] +fn gelu_approx(x: XlaOp) -> PyResult { + let y = x.gelu_approx()?; + Ok(y) +} + +#[pyfunction] +fn einsum1(x: XlaOp, config: &str) -> PyResult { + let y = x.einsum1(config)?; + Ok(y) +} + +#[pyfunction] +fn einsum2(x: XlaOp, rhs: &XlaOp, config: &str) -> PyResult { + let y = x.einsum2(rhs, config)?; + Ok(y) +} + +#[pyfunction] +fn reshape(x: XlaOp, dims: Vec) -> PyResult { + let dims = dims.as_slice(); + let y = x.reshape(dims)?; + Ok(y) +} + +#[pyfunction] +fn dynamic_reshape( + x: XlaOp, + dim_sizes: Vec, + new_size_bounds: Vec, + dims_are_dynamic: Vec +) -> PyResult { + let dim_sizes = dim_sizes.as_slice(); + let new_size_bounds = new_size_bounds.as_slice(); + let y = x.dynamic_reshape(dim_sizes, new_size_bounds, dims_are_dynamic)?; + Ok(y) +} + +#[pyfunction] +fn broadcast(x: XlaOp, dims: Vec) -> PyResult { + let dims = dims.as_slice(); + let y = x.broadcast(dims)?; + Ok(y) +} + +#[pyfunction] +fn broadcast_in_dim(x: XlaOp, out_dims: Vec, broadcast_dims: Vec) -> PyResult { + let out_dims = out_dims.as_slice(); + let broadcast_dims = broadcast_dims.as_slice(); + let y = x.broadcast_in_dim(out_dims, broadcast_dims)?; + Ok(y) +} + +#[pyfunction] +fn collapse(x: XlaOp, dims: Vec) -> PyResult { + let dims = dims.as_slice(); + let y = x.collapse(dims)?; + Ok(y) +} + +#[pyfunction] +fn transpose(x: XlaOp, index_perm: Vec) -> PyResult { + let index_perm = index_perm.as_slice(); + let y = x.transpose(index_perm)?; + Ok(y) +} + +#[pyfunction] +fn swap_dims(x: XlaOp, index1: i64, index2: i64) -> PyResult { + let y = x.swap_dims(index1, index2)?; + Ok(y) +} + +#[pyfunction] +fn pad(x: XlaOp, padding_value: &XlaOp, padding_config:Vec<(i64, i64, i64)> ) -> PyResult { + let y = x.pad(padding_value, padding_config)?; + Ok(y) +} + +#[pyfunction] +fn pad_in_dim(x: XlaOp, padding_value: &XlaOp, dinmo: i64, pad_low: i64, pad_high: i64) -> PyResult { + let y = x.pad_in_dim(padding_value, dinmo, pad_low, pad_high)?; + Ok(y) +} + +#[pyfunction] +fn slice(x: XlaOp, start_indices: Vec, limit_indices: Vec, strides: Vec) -> PyResult { + let start_indices = start_indices.as_slice(); + let limit_indices = limit_indices.as_slice(); + let strides = strides.as_slice(); + let y = x.slice(start_indices, limit_indices, strides)?; + Ok(y) +} + +#[pyfunction] +fn slice_in_dim(x: XlaOp, start_index: i64, stop_index: i64, stride: i64, dim: i64) -> PyResult { + let y = x.slice_in_dim(start_index, stop_index, stride, dim)?; + Ok(y) +} + +#[pyfunction] +fn dynamic_slice(x: XlaOp, start_indices: Vec, slice_indices: Vec) -> PyResult { + let start_indices = start_indices.as_slice(); + let slice_indices = slice_indices.as_slice(); + let y = x.dynamic_slice(start_indices, slice_indices)?; + Ok(y) +} + +#[pyfunction] +fn dynamic_update_slice(x: XlaOp, update: &XlaOp, start_indices: Vec) -> PyResult { + let start_indices = start_indices.as_slice(); + let y = x.dynamic_update_slice(update, start_indices)?; + Ok(y) +} + +#[pyfunction] +fn at(x: XlaOp, index_in_dim: i64, dim_index: i64) -> PyResult { + let y = x.at(index_in_dim, dim_index)?; + Ok(y) +} + +#[pyfunction] +fn squeeze(x: 
XlaOp, index: i64) -> PyResult { + let y = x.squeeze(index)?; + Ok(y) +} + +#[pyfunction] +fn clamp(x: XlaOp, min: &XlaOp, max: &XlaOp) -> PyResult { + let y = x.clamp(min, max)?; + Ok(y) +} + +#[pyfunction] +fn concat(x: XlaOp, args: Vec, dim: i64) -> PyResult { + let args = args.as_slice(); + let y = x.concat_in_dim(args, dim)?; + Ok(y) +} + +#[pyfunction] +fn get_tuple_element(x: XlaOp, index: i64) -> PyResult { + let y = x.get_tuple_element(index)?; + Ok(y) +} + +#[pyfunction] +fn rng_uniform(min: &XlaOp, max: &XlaOp, shape: &ArrayShape) -> PyResult { + let y = XlaOp::rng_uniform(min, max, shape)?; + Ok(y) +} + +#[pyfunction] +fn rng_normal(mu: &XlaOp, sigma: &XlaOp, shape: &ArrayShape) -> PyResult { + let y = XlaOp::rng_normal(mu, sigma, shape)?; + Ok(y) +} + +#[pyfunction] +fn astype(x: XlaOp, ty: PrimitiveType) -> PyResult { + let y = x.astype(ty)?; + Ok(y) +} + +#[pyfunction] +fn dimension_size(x: XlaOp, index: i64) -> PyResult { + let y = x.dimensions_size(index)?; + Ok(y) +} + +#[pyfunction] +fn reduce( + x: XlaOp, + init_value: XlaOp, + comp: &XlaComputation, + dims: Vec, + keep_dims: bool, +) -> PyResult { + let dims = dims.as_slice(); + let y = x.reduce(init_value, comp, dims, keep_dims)?; + Ok(y) +} + +#[pyfunction] +fn call(builder: XlaBuilder, computation: &XlaComputation, operands: Vec) -> PyResult { + let operands = operands.as_slice(); + let y = builder.call(computation, operands)?; + Ok(y) +} + +#[pyfunction] +fn map(builder: XlaBuilder, + operands: Vec, + computation: &XlaComputation, + dims: Vec, + static_operands: Vec +) -> PyResult { + let operands = operands.as_slice(); + let dims = dims.as_slice(); + let static_operands = static_operands.as_slice(); + let y = builder.map(operands, computation, dims, static_operands)?; + Ok(y) +} + +#[pyfunction] +fn select(x: XlaOp, on_true: &XlaOp, on_false: &XlaOp) -> PyResult { + let y = x.select(on_true, on_false)?; + Ok(y) +} + +#[pyfunction] +fn while_loop(cond: &XlaComputation, body: &XlaComputation, init: XlaOp) -> PyResult { + let y = XlaOp::while_(cond, body, init)?; + Ok(y) +} + +#[pyfunction] +fn conditional( + x: XlaOp, + true_op: XlaOp, + true_comp: &XlaComputation, + false_op: XlaOp, + false_comp: &XlaComputation, +) -> PyResult { + let y = x.conditional(true_op, true_comp,false_op, false_comp)?; + Ok(y) +} + +#[pyfunction] +fn conv( + x: XlaOp, + rhs: &XlaOp, + window_strides: Vec, + padding: &str, + feature_group_count: i64, + batch_group_count: i64, +) -> PyResult { + let window_strides = window_strides.as_slice(); + let y = x.conv(rhs, window_strides, padding, feature_group_count, batch_group_count)?; + Ok(y) +} + +#[pyfunction] +fn conv_general_dilated( + x: XlaOp, + rhs: &XlaOp, + window_strides: Vec, + padding: Vec<(i64, i64)>, + lhs_dilations: Vec, + rhs_dilations: Vec, + input_batch_dim: i64, + input_feature_dim: i64, + input_spatial_dims: Vec, + output_batch_dim: i64, + output_feature_dim: i64, + output_spatial_dims: Vec, + kernel_input_feature_dim: i64, + kernel_output_feature_dim: i64, + kernel_spatial_dims: Vec, + feature_group_count: i64, + batch_group_count: i64 +) -> PyResult { + let window_strides = window_strides.as_slice(); + let padding = padding.as_slice(); + let lhs_dilations = lhs_dilations.as_slice(); + let rhs_dilations = rhs_dilations.as_slice(); + let input_spatial_dims = input_spatial_dims.as_slice(); + let output_spatial_dims = output_spatial_dims.as_slice(); + let kernel_spatial_dims = kernel_spatial_dims.as_slice(); + let y = x.conv_general_dilated( + rhs, + window_strides, + padding, 
+ lhs_dilations, + rhs_dilations, + &input_batch_dim, + &input_feature_dim, + input_spatial_dims, + &output_batch_dim, + &output_feature_dim, + output_spatial_dims, + &kernel_input_feature_dim, + &kernel_output_feature_dim, + kernel_spatial_dims, + feature_group_count, + batch_group_count, + )?; + Ok(y) +} + +#[pyfunction] +fn batch_norm_inference( + x: XlaOp, + scale: &XlaOp, + offset: &XlaOp, + mean: &XlaOp, + variance: &XlaOp, + epsilon: f32, + feature_index: i64, +) -> PyResult { + let y = x.batch_norm_inference( + scale, offset, mean, variance, epsilon, feature_index + )?; + Ok(y) +} + +#[pyfunction] +fn dot_general( + x: XlaOp, + rhs: &XlaOp, + lhs_contracting_dims: Vec, + rhs_contracting_dims: Vec, + lhs_batch_dims: Vec, + rhs_batch_dims: Vec, +) -> PyResult { + let lhs_contracting_dims = lhs_contracting_dims.as_slice(); + let rhs_contracting_dims = rhs_contracting_dims.as_slice(); + let lhs_batch_dims = lhs_batch_dims.as_slice(); + let rhs_batch_dims = rhs_batch_dims.as_slice(); + let y = x.dot_general( + rhs, + lhs_contracting_dims, + rhs_contracting_dims, + lhs_batch_dims, + rhs_batch_dims + )?; + Ok(y) +} + +#[pyfunction] +fn gather( + x: XlaOp, + start_indices: &XlaOp, + offset_dims: Vec, + collapsed_slice_dims: Vec, + start_index_map: Vec, + slice_sizes: Vec, + set_index_vector_dim: Option, +) -> PyResult { + let offset_dims = offset_dims.as_slice(); + let collapsed_slice_dims = collapsed_slice_dims.as_slice(); + let start_index_map = start_index_map.as_slice(); + let slice_sizes = slice_sizes.as_slice(); + let y = x.gather( + start_indices, + offset_dims, + collapsed_slice_dims, + start_index_map, + set_index_vector_dim, + slice_sizes, + )?; + Ok(y) +} + +#[pyfunction] +fn scatter( + operands: Vec, + scatter_indices: &XlaOp, + updates: Vec, + update_computation: &XlaComputation, + update_window_dims: Vec, + inserted_window_dims: Vec, + scatter_dims_to_operand_dims: Vec, + index_vector_dim: i64 +) -> PyResult { + let operands = operands.as_slice(); + let updates = updates.as_slice(); + let update_window_dims = update_window_dims.as_slice(); + let inserted_window_dims = inserted_window_dims.as_slice(); + let scatter_dims_to_operand_dims = scatter_dims_to_operand_dims.as_slice(); + let y = XlaOp::scatter( + operands, + scatter_indices, + updates, + update_computation, + update_window_dims, + inserted_window_dims, + scatter_dims_to_operand_dims, + index_vector_dim + )?; + Ok(y) +} + +#[pyfunction] +fn take(x: XlaOp, indices: &XlaOp, axis: i64) -> PyResult { + let y = x.take(indices, axis)?; + Ok(y) +} + +#[pyfunction] +fn reduce_sum(x: XlaOp, dims: Vec, keep_dims: bool) -> PyResult { + let dims = dims.as_slice(); + let y = x.reduce_sum(dims, keep_dims)?; + Ok(y) +} + +#[pyfunction] +fn reduce_mean(x: XlaOp, dims: Vec, keep_dims: bool) -> PyResult { + let dims = dims.as_slice(); + let y = x.reduce_mean(dims, keep_dims)?; + Ok(y) +} + +#[pyfunction] +fn reduce_max(x: XlaOp, dims: Vec, keep_dims: bool) -> PyResult { + let dims = dims.as_slice(); + let y = x.reduce_max(dims, keep_dims)?; + Ok(y) +} + +#[pyfunction] +fn reduce_min(x: XlaOp, dims: Vec, keep_dims: bool) -> PyResult { + let dims = dims.as_slice(); + let y = x.reduce_min(dims, keep_dims)?; + Ok(y) +} + +#[pyfunction] +fn softmax(x: XlaOp, axis: i64) -> PyResult { + let y = x.softmax(axis)?; + Ok(y) +} + +#[pyfunction] +fn layer_norm(x: XlaOp, dims: Vec, scale: &XlaOp, bias: &XlaOp, eps: f64) -> PyResult { + let dims = dims.as_slice(); + let y = x.layer_norm(dims, scale, bias, eps)?; + Ok(y) +} + +#[pyfunction] +fn 
primitive_type(x: XlaOp) -> PyResult { + let prim_type = x.primitive_type()?; + Ok(prim_type) +} + +#[pyfunction] +fn element_type(x: XlaOp) -> PyResult { + let elem_type = PrimitiveType::element_type(x.ty()?)?; + Ok(elem_type) +} + +#[pyfunction] +fn dims(x: XlaOp) -> PyResult> { + let dims = x.dims()?; + Ok(dims) +} + +#[pyfunction] +fn rank(x: XlaOp) -> PyResult { + let rank = x.rank()?; + Ok(rank) +} + +#[pyfunction] +fn shape(x: XlaOp) -> PyResult> { + let shape = x.shape()?; + let shape = ArrayShape::try_from(&shape)?; + let shape: Vec = shape.dims().iter().map(|&x| x as usize).collect(); + Ok(shape) +} + +#[pyfunction] +fn array_shape(x: XlaOp) -> PyResult { + let shape = x.array_shape()?; + Ok(shape) +} + +#[pyfunction] +fn create_array_shape(ty: ElementType, dims: Vec) -> PyResult { + let shape = ArrayShape::new_with_type(ty, dims); + Ok(shape) +} + +#[pyfunction] +fn last_dim(x: XlaOp) -> PyResult { + let shape = x.shape()?; + let shape = ArrayShape::try_from(&shape)?; + let last_dim = shape.last_dim().ok_or_else(|| PyErr::new::("Shape has no dimensions"))?; + Ok(last_dim) +} + +#[pyfunction] +fn tuple(builder: XlaBuilder, args: Vec) -> PyResult { + let y = builder.tuple(&args)?; + Ok(y) +} + +#[pyfunction] +fn get_builder(x: XlaOp) -> PyResult { + let b = Rc::new(x.builder().clone()); + match Rc::try_unwrap(b) { + Ok(builder) => Ok(builder), + Err(_) => Err(PyErr::new::("Could not unwrap XlaBuilder")), + } +} + + +#[pymodule] +#[pyo3(name="xlar")] +fn module(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(xla_builder, m)?)?; + m.add_function(wrap_pyfunction!(constant_array, m)?)?; + m.add_function(wrap_pyfunction!(gather_params, m)?)?; + m.add_function(wrap_pyfunction!(swap_param, m)?)?; + m.add_function(wrap_pyfunction!(new_input, m)?)?; + m.add_function(wrap_pyfunction!(create_bf16_array, m)?)?; + m.add_function(wrap_pyfunction!(create_f16_array, m)?)?; + m.add_function(wrap_pyfunction!(to_tensor, m)?)?; + m.add_function(wrap_pyfunction!(to_numpy, m)?)?; + m.add_function(wrap_pyfunction!(to_tuple, m)?)?; + m.add_function(wrap_pyfunction!(param_pred, m)?)?; + m.add_function(wrap_pyfunction!(param_i8, m)?)?; + m.add_function(wrap_pyfunction!(param_i16, m)?)?; + m.add_function(wrap_pyfunction!(param_i32, m)?)?; + m.add_function(wrap_pyfunction!(param_i64, m)?)?; + m.add_function(wrap_pyfunction!(param_u8, m)?)?; + m.add_function(wrap_pyfunction!(param_u16, m)?)?; + m.add_function(wrap_pyfunction!(param_u32, m)?)?; + m.add_function(wrap_pyfunction!(param_u64, m)?)?; + m.add_function(wrap_pyfunction!(param_bf16, m)?)?; + m.add_function(wrap_pyfunction!(param_f16, m)?)?; + m.add_function(wrap_pyfunction!(param_f32, m)?)?; + m.add_function(wrap_pyfunction!(param_f64, m)?)?; + m.add_function(wrap_pyfunction!(cpu_client, m)?)?; + m.add_function(wrap_pyfunction!(gpu_client, m)?)?; + m.add_function(wrap_pyfunction!(build, m)?)?; + m.add_function(wrap_pyfunction!(get_hlo_proto, m)?)?; + m.add_function(wrap_pyfunction!(hlo_module_from_proto, m)?)?; + m.add_function(wrap_pyfunction!(hlo_module_to_string, m)?)?; + m.add_function(wrap_pyfunction!(get_hlo_module_entry_computation, m)?)?; + m.add_function(wrap_pyfunction!(computation_count, m)?)?; + m.add_function(wrap_pyfunction!(instruction_count, m)?)?; + m.add_function(wrap_pyfunction!(compile, m)?)?; + m.add_function(wrap_pyfunction!(execute, m)?)?; + m.add_function(wrap_pyfunction!(to_literal, m)?)?; + m.add_function(wrap_pyfunction!(add, m)?)?; + m.add_function(wrap_pyfunction!(sub, m)?)?; + 
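(Stepping outside the registration list for a moment.) Taken together, the wrappers registered here chain into a small build, compile, and execute pipeline. Below is a minimal Rust-side sketch of that flow, assuming the crate's wrapper types are in scope and that the crate's `Error` converts into `Box<dyn Error>`; the `parameter_s` call mirrors what the `param_gen!` macro above does:

```rust
use std::error::Error;

// Illustrative sketch only: chains the same wrapper calls that param_f32,
// add, build, cpu_client, compile, execute and to_literal expose to Python.
fn demo() -> Result<(), Box<dyn Error>> {
    let builder = XlaBuilder::new("demo");
    let shape = Shape::array::<f32>(vec![3]);
    let x = builder.parameter_s(0, &shape, "x")?; // what param_f32 wraps
    let y = x.add_(&x)?;                          // what add wraps
    let computation = y.build()?;                 // what build wraps

    let client = PjRtClient::cpu()?;              // what cpu_client wraps
    let exe = client.compile(&computation)?;      // what compile wraps

    // execute returns one Vec<PjRtBuffer> per replica; take the first output.
    let arg = Literal::vec1(&[1f32, 2.0, 3.0]);
    let mut outputs = exe.execute::<Literal>(&[arg])?;
    let result = outputs[0].remove(0).to_literal_sync()?.to_vec::<f32>()?;
    assert_eq!(result, vec![2.0, 4.0, 6.0]);
    Ok(())
}
```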
m.add_function(wrap_pyfunction!(mul, m)?)?; + m.add_function(wrap_pyfunction!(div, m)?)?; + m.add_function(wrap_pyfunction!(rem, m)?)?; + m.add_function(wrap_pyfunction!(pow, m)?)?; + m.add_function(wrap_pyfunction!(max, m)?)?; + m.add_function(wrap_pyfunction!(min, m)?)?; + m.add_function(wrap_pyfunction!(_and, m)?)?; + m.add_function(wrap_pyfunction!(_or, m)?)?; + m.add_function(wrap_pyfunction!(xor, m)?)?; + m.add_function(wrap_pyfunction!(eq, m)?)?; + m.add_function(wrap_pyfunction!(ne, m)?)?; + m.add_function(wrap_pyfunction!(ge, m)?)?; + m.add_function(wrap_pyfunction!(gt, m)?)?; + m.add_function(wrap_pyfunction!(le, m)?)?; + m.add_function(wrap_pyfunction!(lt, m)?)?; + m.add_function(wrap_pyfunction!(lshift, m)?)?; + m.add_function(wrap_pyfunction!(rshift, m)?)?; + m.add_function(wrap_pyfunction!(atan2, m)?)?; + m.add_function(wrap_pyfunction!(dot, m)?)?; + m.add_function(wrap_pyfunction!(matmul, m)?)?; + m.add_function(wrap_pyfunction!(population_count, m)?)?; + m.add_function(wrap_pyfunction!(_not, m)?)?; + m.add_function(wrap_pyfunction!(neg, m)?)?; + m.add_function(wrap_pyfunction!(abs, m)?)?; + m.add_function(wrap_pyfunction!(floor, m)?)?; + m.add_function(wrap_pyfunction!(ceil, m)?)?; + m.add_function(wrap_pyfunction!(round, m)?)?; + m.add_function(wrap_pyfunction!(round_nearest_even, m)?)?; + m.add_function(wrap_pyfunction!(exp, m)?)?; + m.add_function(wrap_pyfunction!(expm1, m)?)?; + m.add_function(wrap_pyfunction!(log, m)?)?; + m.add_function(wrap_pyfunction!(log1p, m)?)?; + m.add_function(wrap_pyfunction!(logistic, m)?)?; + m.add_function(wrap_pyfunction!(sign, m)?)?; + m.add_function(wrap_pyfunction!(clz, m)?)?; + m.add_function(wrap_pyfunction!(sin, m)?)?; + m.add_function(wrap_pyfunction!(cos, m)?)?; + m.add_function(wrap_pyfunction!(tanh, m)?)?; + m.add_function(wrap_pyfunction!(real, m)?)?; + m.add_function(wrap_pyfunction!(imag, m)?)?; + m.add_function(wrap_pyfunction!(conj, m)?)?; + m.add_function(wrap_pyfunction!(square, m)?)?; + m.add_function(wrap_pyfunction!(sqrt, m)?)?; + m.add_function(wrap_pyfunction!(rsqrt, m)?)?; + m.add_function(wrap_pyfunction!(cbrt, m)?)?; + m.add_function(wrap_pyfunction!(upper_triangle, m)?)?; + m.add_function(wrap_pyfunction!(lower_triangle, m)?)?; + m.add_function(wrap_pyfunction!(erf, m)?)?; + m.add_function(wrap_pyfunction!(is_finite, m)?)?; + m.add_function(wrap_pyfunction!(zeros_like, m)?)?; + m.add_function(wrap_pyfunction!(copy, m)?)?; + m.add_function(wrap_pyfunction!(sigmoid, m)?)?; + m.add_function(wrap_pyfunction!(silu, m)?)?; + m.add_function(wrap_pyfunction!(relu, m)?)?; + m.add_function(wrap_pyfunction!(gelu, m)?)?; + m.add_function(wrap_pyfunction!(gelu_approx, m)?)?; + m.add_function(wrap_pyfunction!(einsum1, m)?)?; + m.add_function(wrap_pyfunction!(einsum2, m)?)?; + m.add_function(wrap_pyfunction!(reshape, m)?)?; + m.add_function(wrap_pyfunction!(dynamic_reshape, m)?)?; + m.add_function(wrap_pyfunction!(broadcast, m)?)?; + m.add_function(wrap_pyfunction!(broadcast_in_dim, m)?)?; + m.add_function(wrap_pyfunction!(collapse, m)?)?; + m.add_function(wrap_pyfunction!(transpose, m)?)?; + m.add_function(wrap_pyfunction!(swap_dims, m)?)?; + m.add_function(wrap_pyfunction!(pad, m)?)?; + m.add_function(wrap_pyfunction!(pad_in_dim, m)?)?; + m.add_function(wrap_pyfunction!(slice, m)?)?; + m.add_function(wrap_pyfunction!(slice_in_dim, m)?)?; + m.add_function(wrap_pyfunction!(dynamic_slice, m)?)?; + m.add_function(wrap_pyfunction!(dynamic_update_slice, m)?)?; + m.add_function(wrap_pyfunction!(at, m)?)?; + 
+    m.add_function(wrap_pyfunction!(squeeze, m)?)?;
+    m.add_function(wrap_pyfunction!(clamp, m)?)?;
+    m.add_function(wrap_pyfunction!(concat, m)?)?;
+    m.add_function(wrap_pyfunction!(get_tuple_element, m)?)?;
+    m.add_function(wrap_pyfunction!(rng_uniform, m)?)?;
+    m.add_function(wrap_pyfunction!(rng_normal, m)?)?;
+    m.add_function(wrap_pyfunction!(astype, m)?)?;
+    m.add_function(wrap_pyfunction!(dimension_size, m)?)?;
+    m.add_function(wrap_pyfunction!(reduce, m)?)?;
+    m.add_function(wrap_pyfunction!(call, m)?)?;
+    m.add_function(wrap_pyfunction!(map, m)?)?;
+    m.add_function(wrap_pyfunction!(select, m)?)?;
+    m.add_function(wrap_pyfunction!(while_loop, m)?)?;
+    m.add_function(wrap_pyfunction!(conditional, m)?)?;
+    m.add_function(wrap_pyfunction!(conv, m)?)?;
+    m.add_function(wrap_pyfunction!(conv_general_dilated, m)?)?;
+    m.add_function(wrap_pyfunction!(batch_norm_inference, m)?)?;
+    m.add_function(wrap_pyfunction!(dot_general, m)?)?;
+    m.add_function(wrap_pyfunction!(gather, m)?)?;
+    m.add_function(wrap_pyfunction!(scatter, m)?)?;
+    m.add_function(wrap_pyfunction!(take, m)?)?;
+    m.add_function(wrap_pyfunction!(reduce_sum, m)?)?;
+    m.add_function(wrap_pyfunction!(reduce_mean, m)?)?;
+    m.add_function(wrap_pyfunction!(reduce_max, m)?)?;
+    m.add_function(wrap_pyfunction!(reduce_min, m)?)?;
+    m.add_function(wrap_pyfunction!(softmax, m)?)?;
+    m.add_function(wrap_pyfunction!(layer_norm, m)?)?;
+    m.add_function(wrap_pyfunction!(primitive_type, m)?)?;
+    m.add_function(wrap_pyfunction!(element_type, m)?)?;
+    m.add_function(wrap_pyfunction!(rank, m)?)?;
+    m.add_function(wrap_pyfunction!(shape, m)?)?;
+    m.add_function(wrap_pyfunction!(array_shape, m)?)?;
+    m.add_function(wrap_pyfunction!(dims, m)?)?;
+    m.add_function(wrap_pyfunction!(last_dim, m)?)?;
+    m.add_function(wrap_pyfunction!(tuple, m)?)?;
+    m.add_function(wrap_pyfunction!(get_builder, m)?)?;
+    m.add_function(wrap_pyfunction!(create_array_shape, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_bool, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_i8, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_i16, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_i32, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_i64, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_u8, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_u16, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_u32, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_u64, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_f32, m)?)?;
+    m.add_function(wrap_pyfunction!(constant_f64, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_bool, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_i8, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_i16, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_i32, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_i64, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_u8, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_u16, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_u32, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_u64, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_bf16, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_f16, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_f32, m)?)?;
+    m.add_function(wrap_pyfunction!(astype_f64, m)?)?;
+    Ok(())
+}
+
diff --git a/ivy/engines/XLA/rust_api/src/wrappers/literal.rs b/ivy/engines/XLA/rust_api/src/wrappers/literal.rs
new file mode 100644
index 0000000000000..cc5dd8466ff77
--- /dev/null
+++ b/ivy/engines/XLA/rust_api/src/wrappers/literal.rs
@@ -0,0 +1,287 @@
+use super::{
+    ArrayElement, ArrayShape, ElementType, FromPrimitive, NativeType, PrimitiveType, Shape,
+};
+use crate::{c_lib, Error, Result};
+use pyo3::prelude::*;
+
+/// A literal represents a value, typically a multi-dimensional array, stored on the host device.
+#[derive(Debug)]
+#[pyclass(unsendable)]
+pub struct Literal(pub(super) c_lib::literal);
+
+impl Clone for Literal {
+    fn clone(&self) -> Self {
+        let v = unsafe { c_lib::literal_clone(self.0) };
+        Self(v)
+    }
+}
+
+impl Literal {
+    /// Create an uninitialized literal based on some primitive type and some dimensions.
+    pub fn create_from_shape(ty: PrimitiveType, dims: &[usize]) -> Self {
+        let dims: Vec<_> = dims.iter().map(|x| *x as i64).collect();
+        let v = unsafe { c_lib::literal_create_from_shape(ty as i32, dims.as_ptr(), dims.len()) };
+        Self(v)
+    }
+
+    /// Create a literal based on some element type, some dimensions, and some data.
+    /// The data is untyped, i.e. it is a sequence of bytes represented as a slice of `u8` even if
+    /// the primitive type is not `U8`.
+    pub fn create_from_shape_and_untyped_data(
+        ty: ElementType,
+        dims: &[usize],
+        untyped_data: &[u8],
+    ) -> Result<Self> {
+        let dims64: Vec<_> = dims.iter().map(|x| *x as i64).collect();
+        let ty = ty.primitive_type();
+        let v = unsafe {
+            c_lib::literal_create_from_shape_and_data(
+                ty as i32,
+                dims64.as_ptr(),
+                dims64.len(),
+                untyped_data.as_ptr() as *const libc::c_void,
+                untyped_data.len(),
+            )
+        };
+        if v.is_null() {
+            return Err(Error::CannotCreateLiteralWithData {
+                data_len_in_bytes: untyped_data.len(),
+                ty,
+                dims: dims.to_vec(),
+            });
+        }
+        Ok(Self(v))
+    }
+
+    /// Get the first element from a literal. This returns an error if type `T` is not the
+    /// primitive type that the literal uses.
+    pub fn get_first_element<T: NativeType + ArrayElement>(&self) -> Result<T> {
+        let ty = self.ty()?;
+        if ty != T::TY {
+            Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::TY })?
+        }
+        if self.element_count() == 0 {
+            Err(Error::EmptyLiteral)?
+        }
+        let v = unsafe { T::literal_get_first_element(self.0) };
+        Ok(v)
+    }
+
+    /// The number of elements stored in the literal.
+    pub fn element_count(&self) -> usize {
+        unsafe { c_lib::literal_element_count(self.0) as usize }
+    }
+
+    /// The primitive type of the elements stored in this literal.
+    pub fn primitive_type(&self) -> Result<PrimitiveType> {
+        let ty = unsafe { c_lib::literal_element_type(self.0) };
+        match FromPrimitive::from_i32(ty) {
+            None => Err(Error::UnexpectedElementType(ty)),
+            Some(ty) => Ok(ty),
+        }
+    }
+
+    /// The element type of the elements stored in this literal.
+    pub fn element_type(&self) -> Result<ElementType> {
+        self.primitive_type()?.element_type()
+    }
+
+    /// The element type of the elements stored in this literal, shortcut for `element_type`.
+    pub fn ty(&self) -> Result<ElementType> {
+        self.element_type()
+    }
+
+    /// The literal size in bytes; this is the same as `element_count` multiplied by
+    /// `element_size_in_bytes`.
+    pub fn size_bytes(&self) -> usize {
+        unsafe { c_lib::literal_size_bytes(self.0) as usize }
+    }
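A quick sketch of the accessors above, using `vec1` and `get_first_element` from elsewhere in this file (the crate's `Result` alias is assumed to be in scope):

```rust
// Sketch: build a rank-1 f32 literal and inspect it.
fn literal_demo() -> Result<()> {
    let lit = Literal::vec1(&[1f32, 2.0, 3.0]);
    assert_eq!(lit.element_count(), 3);
    assert_eq!(lit.ty()?, ElementType::F32);
    assert_eq!(lit.size_bytes(), 3 * 4); // element_count * 4 bytes per f32
    assert_eq!(lit.get_first_element::<f32>()?, 1.0);
    Ok(())
}
```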
+    /// The [`Shape`] of the literal; this contains information about the dimensions of the
+    /// underlying array, as well as the primitive type of the array's elements.
+    pub fn shape(&self) -> Result<Shape> {
+        let mut out: c_lib::shape = std::ptr::null_mut();
+        unsafe { c_lib::literal_shape(self.0, &mut out) };
+        let c_shape = super::shape::CShape::from_ptr(out);
+        c_shape.shape()
+    }
+
+    pub fn array_shape(&self) -> Result<ArrayShape> {
+        ArrayShape::try_from(&self.shape()?)
+    }
+
+    /// Copy the literal data to a slice. This returns an error if the primitive type used by the
+    /// literal is not `T` or if the number of elements in the slice and literal are different.
+    pub fn copy_raw_to<T: ArrayElement>(&self, dst: &mut [T]) -> Result<()> {
+        let ty = self.ty()?;
+        let element_count = self.element_count();
+        if ty != T::TY {
+            Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::TY })?
+        }
+        if dst.len() > element_count {
+            Err(Error::BinaryBufferIsTooLarge { element_count, buffer_len: dst.len() })?
+        }
+        unsafe {
+            c_lib::literal_copy_to(
+                self.0,
+                dst.as_mut_ptr() as *mut libc::c_void,
+                element_count * T::ELEMENT_SIZE_IN_BYTES,
+            )
+        };
+        Ok(())
+    }
+
+    /// Copy data from a slice to the literal. This returns an error if the primitive type used
+    /// by the literal is not `T` or if the number of elements in the slice and the literal are
+    /// different.
+    pub fn copy_raw_from<T: ArrayElement>(&mut self, src: &[T]) -> Result<()> {
+        let ty = self.ty()?;
+        let element_count = self.element_count();
+        if ty != T::TY {
+            Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::TY })?
+        }
+        if src.len() > element_count {
+            Err(Error::BinaryBufferIsTooLarge { element_count, buffer_len: src.len() })?
+        }
+        unsafe {
+            c_lib::literal_copy_from(
+                self.0,
+                src.as_ptr() as *const libc::c_void,
+                element_count * T::ELEMENT_SIZE_IN_BYTES,
+            )
+        };
+        Ok(())
+    }
+
+    /// Copy the values stored in the literal into a newly created vector. The data is flattened
+    /// out for literals with more than one dimension.
+    pub fn to_vec<T: ArrayElement>(&self) -> Result<Vec<T>> {
+        let element_count = self.element_count();
+        // Maybe we should use an uninitialized vec instead?
+        let mut data = vec![T::ZERO; element_count];
+        self.copy_raw_to(&mut data)?;
+        Ok(data)
+    }
+
+    /// Create a literal from a scalar value; the resulting literal has zero dimensions and stores
+    /// a single element.
+    pub fn scalar<T: NativeType>(t: T) -> Self {
+        let ptr = unsafe { T::create_r0(t) };
+        Literal(ptr)
+    }
+
+    /// Create a literal from a slice of data; the resulting literal has one dimension whose size
+    /// matches that of the slice passed as argument.
+    pub fn vec1<T: NativeType>(f: &[T]) -> Self {
+        let ptr = unsafe { T::create_r1(f.as_ptr(), f.len()) };
+        Literal(ptr)
+    }
+
+    /// Create a new literal containing the same data but using a different shape. This returns an
+    /// error if the number of elements in the literal is different from the product of the target
+    /// dimension sizes.
+    pub fn reshape(&self, dims: &[i64]) -> Result<Literal> {
+        let mut result: c_lib::literal = std::ptr::null_mut();
+        let status =
+            unsafe { c_lib::literal_reshape(self.0, dims.as_ptr(), dims.len(), &mut result) };
+        super::handle_status(status)?;
+        Ok(Literal(result))
+    }
+
+    /// Create a new literal containing the data from the original literal cast to a new
+    /// primitive type. The dimensions of the resulting literal are the same as the dimensions of
+    /// the original literal.
+    pub fn convert(&self, ty: PrimitiveType) -> Result<Literal> {
+        let mut result: c_lib::literal = std::ptr::null_mut();
+        let status = unsafe { c_lib::literal_convert(self.0, ty as i32, &mut result) };
+        super::handle_status(status)?;
+        Ok(Literal(result))
+    }
+
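The tuple helpers defined next compose with `scalar` above; a small sketch of the round trip, again assuming the crate's `Result` alias:

```rust
// Sketch: pack two scalars into a tuple literal and take them apart again.
fn tuple_demo() -> Result<()> {
    let t = Literal::tuple(vec![Literal::scalar(1f32), Literal::scalar(2f32)]);
    let (a, b) = t.to_tuple2()?;
    assert_eq!(a.get_first_element::<f32>()?, 1.0);
    assert_eq!(b.get_first_element::<f32>()?, 2.0);
    Ok(())
}
```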
+    /// When the input is a tuple, return a vector of its elements. This replaces the original
+    /// value by an empty tuple; no copy is performed.
+    pub fn decompose_tuple(&mut self) -> Result<Vec<Self>> {
+        match self.shape()? {
+            Shape::Array(_) | Shape::Unsupported(_) => Ok(vec![]),
+            Shape::Tuple(shapes) => {
+                let tuple_len = shapes.len();
+                let mut outputs: Vec<c_lib::literal> = vec![std::ptr::null_mut(); tuple_len];
+                unsafe { c_lib::literal_decompose_tuple(self.0, outputs.as_mut_ptr(), tuple_len) };
+                Ok(outputs.into_iter().map(Literal).collect())
+            }
+        }
+    }
+
+    pub fn to_tuple(mut self) -> Result<Vec<Self>> {
+        self.decompose_tuple()
+    }
+
+    pub fn to_tuple1(mut self) -> Result<Self> {
+        let mut tuple = self.decompose_tuple()?;
+        if tuple.len() != 1 {
+            Err(Error::UnexpectedNumberOfElemsInTuple { expected: 1, got: tuple.len() })?
+        }
+        let v1 = tuple.pop().unwrap();
+        Ok(v1)
+    }
+
+    pub fn to_tuple2(mut self) -> Result<(Self, Self)> {
+        let mut tuple = self.decompose_tuple()?;
+        if tuple.len() != 2 {
+            Err(Error::UnexpectedNumberOfElemsInTuple { expected: 2, got: tuple.len() })?
+        }
+        let v2 = tuple.pop().unwrap();
+        let v1 = tuple.pop().unwrap();
+        Ok((v1, v2))
+    }
+
+    pub fn to_tuple3(mut self) -> Result<(Self, Self, Self)> {
+        let mut tuple = self.decompose_tuple()?;
+        if tuple.len() != 3 {
+            Err(Error::UnexpectedNumberOfElemsInTuple { expected: 3, got: tuple.len() })?
+        }
+        let v3 = tuple.pop().unwrap();
+        let v2 = tuple.pop().unwrap();
+        let v1 = tuple.pop().unwrap();
+        Ok((v1, v2, v3))
+    }
+
+    pub fn to_tuple4(mut self) -> Result<(Self, Self, Self, Self)> {
+        let mut tuple = self.decompose_tuple()?;
+        if tuple.len() != 4 {
+            Err(Error::UnexpectedNumberOfElemsInTuple { expected: 4, got: tuple.len() })?
+        }
+        let v4 = tuple.pop().unwrap();
+        let v3 = tuple.pop().unwrap();
+        let v2 = tuple.pop().unwrap();
+        let v1 = tuple.pop().unwrap();
+        Ok((v1, v2, v3, v4))
+    }
+
+    pub fn tuple(elems: Vec<Self>) -> Self {
+        let elem_ptrs: Vec<_> = elems.iter().map(|e| e.0).collect();
+        let literal =
+            unsafe { c_lib::literal_make_tuple_owned(elem_ptrs.as_ptr(), elem_ptrs.len()) };
+        // Ensure that elems are only dropped after the pointers have been used.
+        drop(elems);
+        Self(literal)
+    }
+}
+
+impl<T: NativeType> From<T> for Literal {
+    fn from(f: T) -> Self {
+        Literal::scalar(f)
+    }
+}
+
+impl<T: NativeType> From<&[T]> for Literal {
+    fn from(f: &[T]) -> Self {
+        Literal::vec1(f)
+    }
+}
+
+impl Drop for Literal {
+    fn drop(&mut self) {
+        unsafe { c_lib::literal_free(self.0) }
+    }
+}
diff --git a/ivy/engines/XLA/rust_api/src/wrappers/mod.rs b/ivy/engines/XLA/rust_api/src/wrappers/mod.rs
new file mode 100644
index 0000000000000..cff2017e46d19
--- /dev/null
+++ b/ivy/engines/XLA/rust_api/src/wrappers/mod.rs
@@ -0,0 +1,508 @@
+mod literal;
+mod pjrt_buffer;
+mod pjrt_client;
+mod pjrt_device;
+mod pjrt_loaded_executable;
+mod shape;
+mod xla_builder;
+mod xla_op;
+
+use crate::c_lib;
+use crate::error::{Error, Result};
+use num_derive::FromPrimitive;
+use num_traits::FromPrimitive;
+
+pub use literal::Literal;
+pub use pjrt_buffer::PjRtBuffer;
+pub use pjrt_client::PjRtClient;
+pub use pjrt_device::PjRtDevice;
+pub use pjrt_loaded_executable::PjRtLoadedExecutable;
+pub use shape::{ArrayShape, Shape};
+pub use xla_builder::XlaBuilder;
+pub use xla_op::XlaOp;
+use pyo3::prelude::*;
+
+pub(self) unsafe fn c_ptr_to_string(ptr: *const std::ffi::c_char) -> String {
+    let str = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned();
+    libc::free(ptr as *mut libc::c_void);
+    str
+}
+
+/// The primitive types supported by XLA. `S8` is a signed 1-byte integer,
+/// `U32` is an unsigned 4-byte integer, etc.
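Before the enum itself, a sketch of how `PrimitiveType` and `ElementType` (both defined below) map onto each other; `Tuple`, `Token` and the other non-array types are deliberately excluded from the mapping:

```rust
// Sketch: the PrimitiveType <-> ElementType round trip.
fn type_mapping_demo() -> Result<()> {
    let pt = PrimitiveType::F32;
    let et = pt.element_type()?;               // PrimitiveType -> ElementType
    assert_eq!(et.primitive_type(), pt);       // and back again
    assert_eq!(et.element_size_in_bytes(), 4); // f32 occupies 4 bytes
    // Non-array types have no element type and yield an error instead:
    assert!(PrimitiveType::Tuple.element_type().is_err());
    Ok(())
}
```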
+#[derive(Clone, Copy, PartialEq, Eq, Debug, FromPrimitive)] +#[pyclass(unsendable)] +pub enum PrimitiveType { + Invalid = 0, + Pred = 1, + S8 = 2, + S16 = 3, + S32 = 4, + S64 = 5, + U8 = 6, + U16 = 7, + U32 = 8, + U64 = 9, + F16 = 10, + F32 = 11, + Bf16 = 16, + F64 = 12, + C64 = 15, + C128 = 18, + Tuple = 13, + OpaqueType = 14, + Token = 17, +} + +impl PrimitiveType { + pub fn element_type(self) -> Result { + match self { + Self::Pred => Ok(ElementType::Pred), + Self::S8 => Ok(ElementType::S8), + Self::S16 => Ok(ElementType::S16), + Self::S32 => Ok(ElementType::S32), + Self::S64 => Ok(ElementType::S64), + Self::U8 => Ok(ElementType::U8), + Self::U16 => Ok(ElementType::U16), + Self::U32 => Ok(ElementType::U32), + Self::U64 => Ok(ElementType::U64), + Self::F16 => Ok(ElementType::F16), + Self::F32 => Ok(ElementType::F32), + Self::Bf16 => Ok(ElementType::Bf16), + Self::F64 => Ok(ElementType::F64), + Self::C64 => Ok(ElementType::C64), + Self::C128 => Ok(ElementType::C128), + Self::Invalid | Self::Tuple | Self::OpaqueType | Self::Token => { + Err(Error::NotAnElementType { got: self }) + } + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[pyclass(unsendable)] +pub enum ElementType { + Pred, + S8, + S16, + S32, + S64, + U8, + U16, + U32, + U64, + F16, + F32, + Bf16, + F64, + C64, + C128, +} + +impl ElementType { + /// The size for this element type in bytes. + pub fn element_size_in_bytes(&self) -> usize { + match self { + Self::Pred => 1, + Self::S8 => 1, + Self::S16 => 2, + Self::S32 => 4, + Self::S64 => 8, + Self::U8 => 1, + Self::U16 => 2, + Self::U32 => 4, + Self::U64 => 8, + Self::F16 => 2, + Self::F32 => 4, + Self::Bf16 => 2, + Self::F64 => 8, + Self::C64 => 8, + Self::C128 => 16, + } + } + + pub fn primitive_type(&self) -> PrimitiveType { + match self { + Self::Pred => PrimitiveType::Pred, + Self::S8 => PrimitiveType::S8, + Self::S16 => PrimitiveType::S16, + Self::S32 => PrimitiveType::S32, + Self::S64 => PrimitiveType::S64, + Self::U8 => PrimitiveType::U8, + Self::U16 => PrimitiveType::U16, + Self::U32 => PrimitiveType::U32, + Self::U64 => PrimitiveType::U64, + Self::F16 => PrimitiveType::F16, + Self::F32 => PrimitiveType::F32, + Self::Bf16 => PrimitiveType::Bf16, + Self::F64 => PrimitiveType::F64, + Self::C64 => PrimitiveType::C64, + Self::C128 => PrimitiveType::C128, + } + } +} + +pub trait ArrayElement: Copy { + const TY: ElementType; + const ELEMENT_SIZE_IN_BYTES: usize; + const ZERO: Self; +} + +#[allow(clippy::missing_safety_doc)] +/// A type implementing the `NativeType` trait can be directly converted to constant ops or +/// literals. +pub trait NativeType: Copy { + unsafe fn constant_r0(b: c_lib::xla_builder, v: Self) -> c_lib::xla_op; + unsafe fn constant_r1(b: c_lib::xla_builder, v: *const Self, l: usize) -> c_lib::xla_op; + unsafe fn constant_r1c(b: c_lib::xla_builder, v: Self, l: usize) -> c_lib::xla_op; + unsafe fn create_r0(v: Self) -> c_lib::literal; + unsafe fn create_r1(v: *const Self, l: usize) -> c_lib::literal; + unsafe fn literal_get_first_element(l: c_lib::literal) -> Self; +} + +macro_rules! 
native_type { + ($ty:ty, $cst0:ident, $cst1:ident, $cst1c:ident, $cre0:ident, $cre1:ident, $gf:ident) => { + impl NativeType for $ty { + unsafe fn constant_r0(b: c_lib::xla_builder, v: Self) -> c_lib::xla_op { + c_lib::$cst0(b, v) + } + unsafe fn constant_r1( + b: c_lib::xla_builder, + v: *const Self, + l: usize, + ) -> c_lib::xla_op { + c_lib::$cst1(b, v, l) + } + unsafe fn constant_r1c(b: c_lib::xla_builder, v: Self, l: usize) -> c_lib::xla_op { + c_lib::$cst1c(b, v, l) + } + unsafe fn create_r0(v: Self) -> c_lib::literal { + c_lib::$cre0(v) + } + unsafe fn create_r1(v: *const Self, l: usize) -> c_lib::literal { + c_lib::$cre1(v, l) + } + unsafe fn literal_get_first_element(l: c_lib::literal) -> Self { + c_lib::$gf(l) + } + } + }; +} + +native_type!( + bool, + constant_r0_bool, + constant_r1_bool, + constant_r1c_bool, + create_r0_bool, + create_r1_bool, + literal_get_first_element_bool +); + +native_type!( + i8, + constant_r0_int8_t, + constant_r1_int8_t, + constant_r1c_int8_t, + create_r0_int8_t, + create_r1_int8_t, + literal_get_first_element_int8_t +); + +native_type!( + i16, + constant_r0_int16_t, + constant_r1_int16_t, + constant_r1c_int16_t, + create_r0_int16_t, + create_r1_int16_t, + literal_get_first_element_int16_t +); + +native_type!( + i32, + constant_r0_int32_t, + constant_r1_int32_t, + constant_r1c_int32_t, + create_r0_int32_t, + create_r1_int32_t, + literal_get_first_element_int32_t +); + +native_type!( + i64, + constant_r0_int64_t, + constant_r1_int64_t, + constant_r1c_int64_t, + create_r0_int64_t, + create_r1_int64_t, + literal_get_first_element_int64_t +); + +native_type!( + u8, + constant_r0_uint8_t, + constant_r1_uint8_t, + constant_r1c_uint8_t, + create_r0_uint8_t, + create_r1_uint8_t, + literal_get_first_element_uint8_t +); + +native_type!( + u16, + constant_r0_uint16_t, + constant_r1_uint16_t, + constant_r1c_uint16_t, + create_r0_uint16_t, + create_r1_uint16_t, + literal_get_first_element_uint16_t +); + +native_type!( + u32, + constant_r0_uint32_t, + constant_r1_uint32_t, + constant_r1c_uint32_t, + create_r0_uint32_t, + create_r1_uint32_t, + literal_get_first_element_uint32_t +); + +native_type!( + u64, + constant_r0_uint64_t, + constant_r1_uint64_t, + constant_r1c_uint64_t, + create_r0_uint64_t, + create_r1_uint64_t, + literal_get_first_element_uint64_t +); + +native_type!( + f32, + constant_r0_float, + constant_r1_float, + constant_r1c_float, + create_r0_float, + create_r1_float, + literal_get_first_element_float +); + +native_type!( + f64, + constant_r0_double, + constant_r1_double, + constant_r1c_double, + create_r0_double, + create_r1_double, + literal_get_first_element_double +); + +macro_rules! element_type { + ($ty:ty, $v:ident, $sz:tt, $zero:expr) => { + impl ArrayElement for $ty { + const TY: ElementType = ElementType::$v; + const ELEMENT_SIZE_IN_BYTES: usize = $sz; + const ZERO: Self = $zero; + } + }; +} + +// Dummy F16 type. +#[derive(Copy, Clone, Debug)] +pub struct F16; + +impl ArrayElement for F16 { + const TY: ElementType = ElementType::F16; + const ELEMENT_SIZE_IN_BYTES: usize = 2; + const ZERO: Self = Self; +} + +// Dummy BF16 type. 
+#[derive(Copy, Clone, Debug)] +pub struct Bf16; + +impl ArrayElement for Bf16 { + const TY: ElementType = ElementType::Bf16; + const ELEMENT_SIZE_IN_BYTES: usize = 2; + const ZERO: Self = Self; +} + +element_type!(bool, Pred, 1, false); +element_type!(u8, U8, 1, 0); +element_type!(u16, U16, 2, 0); +element_type!(u32, U32, 4, 0); +element_type!(u64, U64, 8, 0); +element_type!(i8, S8, 1, 0); +element_type!(i16, S16, 2, 0); +element_type!(i32, S32, 4, 0); +element_type!(i64, S64, 8, 0); +element_type!(f32, F32, 4, 0.0f32); +element_type!(f64, F64, 8, 0.0f64); + + +/// A computation is built from a root [`XlaOp`]. Computations are device independent and can be +/// specialized to a given device through a compilation step. +#[derive(Clone)] +#[pyclass(unsendable)] +pub struct XlaComputation(c_lib::xla_computation); + +pub(self) fn handle_status(status: c_lib::status) -> Result<()> { + if status.is_null() { + Ok(()) + } else { + let msg = unsafe { + let error_message_ptr = c_lib::status_error_message(status); + let error_message = c_ptr_to_string(error_message_ptr); + c_lib::status_free(status); + error_message + }; + let backtrace = std::backtrace::Backtrace::capture().to_string(); + Err(Error::XlaError { msg, backtrace }) + } +} + +impl XlaComputation { + pub fn from_proto(proto: &HloModuleProto) -> Self { + let ptr = unsafe { c_lib::xla_computation_from_hlo_module_proto(proto.0) }; + Self(ptr) + } + + /// The computation name. + pub fn name(&self) -> String { + unsafe { + let ptr = c_lib::xla_computation_name(self.0); + c_ptr_to_string(ptr) + } + } + + /// Compile this computation for the specified client. + pub fn compile(&self, client: &PjRtClient) -> Result { + client.compile(self) + } + + /// Get the HloModuleProto for the computation. + pub fn proto(&self) -> HloModuleProto { + let ptr = unsafe { c_lib::xla_computation_proto(self.0) }; + HloModuleProto(ptr) + } + +} + +impl Drop for XlaComputation { + fn drop(&mut self) { + unsafe { c_lib::xla_computation_free(self.0) } + } +} + +#[pyclass(unsendable)] +pub struct HloModuleProto(c_lib::hlo_module_proto); + +impl HloModuleProto { + /// Read a HLO module from a text file. + pub fn from_text_file>(path: P) -> Result { + use std::io::Read; + let mut file = std::fs::File::open(path.as_ref())?; + let mut content = Vec::new(); + file.read_to_end(&mut content)?; + Self::parse_and_return_unverified_module(&content) + } + + /// Read a HLO module from a proto file, either in binary or pbtxt format. 
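A built `XlaComputation` can be serialized to its proto form and re-inspected through `HloModule` (whose impl follows below); a hedged sketch of that round trip, assuming the crate's `Result` alias:

```rust
// Sketch: dump a computation's HLO and count its parts.
fn hlo_demo(computation: &XlaComputation) -> Result<()> {
    let proto = computation.proto();
    let module = HloModule::from_proto(&proto)?;
    println!(
        "{} computation(s), {} instruction(s)",
        module.computation_count()?,
        module.instruction_count()?
    );
    println!("{}", module.to_string()?);
    Ok(())
}
```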
+ pub fn from_proto_file>(path: P, binary: bool) -> Result { + use std::io::Read; + let mut file = std::fs::File::open(path.as_ref())?; + let mut content = Vec::new(); + file.read_to_end(&mut content)?; + Self::parse_proto(&content, binary) + } + + pub fn parse_and_return_unverified_module(data: &[u8]) -> Result { + let mut ptr: c_lib::hlo_module_proto = std::ptr::null_mut(); + let status = unsafe { + c_lib::hlo_module_proto_parse_and_return_unverified_module( + data.as_ptr() as *const libc::c_char, + data.len(), + &mut ptr, + ) + }; + handle_status(status)?; + Ok(Self(ptr)) + } + + pub fn parse_proto(data: &[u8], binary: bool) -> Result { + let mut ptr: c_lib::hlo_module_proto = std::ptr::null_mut(); + let status = unsafe { + c_lib::hlo_module_proto_parse_proto( + data.as_ptr() as *const libc::c_char, + data.len(), + binary, + &mut ptr, + ) + }; + handle_status(status)?; + Ok(Self(ptr)) + } +} + +impl Drop for HloModuleProto { + fn drop(&mut self) { + unsafe { c_lib::hlo_module_proto_free(self.0) } + } +} + +#[pyclass(unsendable)] +pub struct HloModule(c_lib::hlo_module); + +impl HloModule { + pub fn from_proto(proto: &HloModuleProto) -> Result { + let mut ptr = std::ptr::null_mut(); + let status = unsafe { + c_lib::hlo_module_from_proto( + proto.0, + &mut ptr, + )}; + handle_status(status)?; + Ok(Self(ptr)) + } + + pub fn to_string(&self) -> Result { + let str_ptr = unsafe { + c_lib::hlo_module_to_string(self.0) + }; + let module_str = unsafe { + let c_str = std::ffi::CStr::from_ptr(str_ptr); + let result = c_str.to_str()?.to_string(); + libc::free(str_ptr as *mut _); + Ok(result) + }; + module_str + } + + + pub fn get_entry_computation(&self) -> Result { + let entry_comp = unsafe { + c_lib::hlo_module_entry_computation(self.0) + }; + Ok(HloComputation(entry_comp)) + } + + pub fn computation_count(&self) -> Result { + let comp_count = unsafe { + c_lib::hlo_module_computation_count(self.0) + }; + Ok(comp_count) + } + + pub fn instruction_count(&self) -> Result { + let instruct_count = unsafe { + c_lib::hlo_module_instruction_count(self.0) + }; + Ok(instruct_count) + } +} + + +#[pyclass(unsendable)] +pub struct HloComputation(c_lib::hlo_computation); diff --git a/ivy/engines/XLA/rust_api/src/wrappers/pjrt_buffer.rs b/ivy/engines/XLA/rust_api/src/wrappers/pjrt_buffer.rs new file mode 100644 index 0000000000000..460725e6df5be --- /dev/null +++ b/ivy/engines/XLA/rust_api/src/wrappers/pjrt_buffer.rs @@ -0,0 +1,76 @@ +//! A view on a memory slice hosted on a device. +use super::{ArrayElement, ArrayShape, Literal, PjRtDevice, Shape}; +use crate::{c_lib, Error, Result}; +use pyo3::prelude::*; + +/// A buffer represents a view on a memory slice hosted on a device. +#[derive(Clone)] +#[pyclass(unsendable)] +pub struct PjRtBuffer { + pub(super) buffer: c_lib::pjrt_buffer, + pub(super) client: super::PjRtClient, +} + +impl PjRtBuffer { + /// The client that owns this buffer. + pub fn client(&self) -> &super::PjRtClient { + &self.client + } + + /// Copy the buffer to a different device. + pub fn copy_to_device(&self, device: PjRtDevice) -> Result { + let mut buffer: c_lib::pjrt_buffer = std::ptr::null_mut(); + let status = + unsafe { c_lib::pjrt_buffer_copy_to_device(self.buffer, device.device, &mut buffer) }; + super::handle_status(status)?; + Ok(Self { buffer, client: self.client.clone() }) + } + + /// Copy the buffer back to the host as a literal. 
+ pub fn to_literal_sync(&self) -> Result { + let mut result: c_lib::literal = std::ptr::null_mut(); + let status = unsafe { c_lib::pjrt_buffer_to_literal_sync(self.buffer, &mut result) }; + super::handle_status(status)?; + Ok(Literal(result)) + } + + /// Retrieve the shape used by this buffer. + pub fn on_device_shape(&self) -> Result { + let shape = unsafe { c_lib::pjrt_buffer_on_device_shape(self.buffer) }; + let c_shape = super::shape::CShape::from_ptr(shape); + c_shape.shape() + } + + /// Copy the data stored in a buffer to host memory in a blocking way. + pub fn copy_raw_to_host_sync( + &self, + dst: &mut [T], + offset: usize, + ) -> Result<()> { + let shape = ArrayShape::try_from(&self.on_device_shape()?)?; + let on_host = T::TY; + let on_device = shape.primitive_type().element_type()?; + if on_device != on_host { + Err(Error::ElementTypeMismatch { on_device, on_host })? + } + if offset + dst.len() > shape.element_count() { + Err(Error::TargetBufferIsTooLarge { offset, shape, buffer_len: dst.len() })? + } + let status = unsafe { + c_lib::pjrt_buffer_copy_raw_to_host_sync( + self.buffer, + dst.as_mut_ptr() as *mut libc::c_void, + offset, + dst.len() * T::ELEMENT_SIZE_IN_BYTES, + ) + }; + super::handle_status(status)?; + Ok(()) + } +} + +impl Drop for PjRtBuffer { + fn drop(&mut self) { + unsafe { c_lib::pjrt_buffer_free(self.buffer) } + } +} diff --git a/ivy/engines/XLA/rust_api/src/wrappers/pjrt_client.rs b/ivy/engines/XLA/rust_api/src/wrappers/pjrt_client.rs new file mode 100644 index 0000000000000..75636cc19915e --- /dev/null +++ b/ivy/engines/XLA/rust_api/src/wrappers/pjrt_client.rs @@ -0,0 +1,189 @@ +//! A device (CPUs, GPUs, TPUs) where computations can be run. +use super::{ArrayElement, Literal, PjRtBuffer, PjRtDevice, PjRtLoadedExecutable, XlaComputation}; +use crate::{c_lib, Error, Result}; +use std::marker::PhantomData; +use std::rc::Rc; +use pyo3::prelude::*; + +pub(super) struct PjRtClientInternal(pub(self) c_lib::pjrt_client); + +/// A client represents a device that can be used to run some computations. A computation graph is +/// compiled in a way that is specific to a device before it can be run. +#[derive(Clone)] +#[pyclass(unsendable)] +pub struct PjRtClient(Rc); + +impl PjRtClient { + /// A CPU client, this can run computations on multiple CPUs at the same time. + pub fn cpu() -> Result { + let mut ptr: c_lib::pjrt_client = std::ptr::null_mut(); + let status = unsafe { c_lib::pjrt_cpu_client_create(&mut ptr) }; + super::handle_status(status)?; + Ok(Self(Rc::new(PjRtClientInternal(ptr)))) + } + + /// A GPU client, the memory requirements are limited by the specified `memory_fraction` and + /// this memory can either be allocated dynamically or pre-allocated depending on + /// `preallocate`. + pub fn gpu(memory_fraction: f64, preallocate: bool) -> Result { + let mut ptr: c_lib::pjrt_client = std::ptr::null_mut(); + let status = + unsafe { c_lib::pjrt_gpu_client_create(&mut ptr, memory_fraction, preallocate) }; + super::handle_status(status)?; + Ok(Self(Rc::new(PjRtClientInternal(ptr)))) + } + + /// A TPU client. + pub fn tpu(max_inflight_computations: usize) -> Result { + let mut ptr: c_lib::pjrt_client = std::ptr::null_mut(); + let status = + unsafe { c_lib::pjrt_tpu_client_create(&mut ptr, max_inflight_computations as i32) }; + super::handle_status(status)?; + Ok(Self(Rc::new(PjRtClientInternal(ptr)))) + } + + fn ptr(&self) -> c_lib::pjrt_client { + self.0 .0 + } + + /// Compile a computation for this device, and return the executable. 
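As a sketch of how a client is typically selected (the GPU-then-CPU fallback policy here is illustrative, not something this file prescribes), using the accessors defined just below and in pjrt_device.rs:

```rust
// Sketch: try a GPU client capped at 90% of memory, grown on demand,
// and fall back to the CPU client if no GPU is usable.
fn client_demo() -> Result<()> {
    let client = PjRtClient::gpu(0.9, false).or_else(|_| PjRtClient::cpu())?;
    println!("platform: {} {}", client.platform_name(), client.platform_version());
    for device in client.devices() {
        println!("device {} ({})", device.id(), device.kind());
    }
    Ok(())
}
```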
+ pub fn compile(&self, c: &XlaComputation) -> Result { + let mut exe: c_lib::pjrt_loaded_executable = std::ptr::null_mut(); + let status = unsafe { c_lib::compile(self.ptr(), c.0, &mut exe) }; + super::handle_status(status)?; + Ok(PjRtLoadedExecutable { exe, client: self.clone() }) + } + + /// The number of devices that this client has detected, e.g. the number of GPUs. + pub fn device_count(&self) -> usize { + unsafe { c_lib::pjrt_client_device_count(self.ptr()) as usize } + } + + /// The number of devices that this client can use. + pub fn addressable_device_count(&self) -> usize { + unsafe { c_lib::pjrt_client_addressable_device_count(self.ptr()) as usize } + } + + /// The name of the platform. + pub fn platform_name(&self) -> String { + unsafe { + let ptr = c_lib::pjrt_client_platform_name(self.ptr()); + super::c_ptr_to_string(ptr) + } + } + + /// The version of the platform. + pub fn platform_version(&self) -> String { + unsafe { + let ptr = c_lib::pjrt_client_platform_version(self.ptr()); + super::c_ptr_to_string(ptr) + } + } + + /// A list of devices attached to this client. + pub fn devices(&self) -> Vec { + let device_count = self.device_count(); + let mut device_ptrs = vec![std::ptr::null_mut(); device_count]; + unsafe { c_lib::pjrt_client_devices(self.ptr(), device_ptrs.as_mut_ptr()) }; + device_ptrs.into_iter().map(|device| PjRtDevice { device, marker: PhantomData }).collect() + } + + /// A list of devices that can be used by this client. + pub fn addressable_devices(&self) -> Vec { + let device_count = self.addressable_device_count(); + let mut device_ptrs = vec![std::ptr::null_mut(); device_count]; + unsafe { c_lib::pjrt_client_addressable_devices(self.ptr(), device_ptrs.as_mut_ptr()) }; + device_ptrs.into_iter().map(|device| PjRtDevice { device, marker: PhantomData }).collect() + } + + /// Transfer some data from the host to a `PjRtBuffer` stored on the target device. If the + /// device is not specified, the default device is used. + /// The source data is passed as a slice of the specified primitive type, as well as the + /// dimensions. The dimensions have to match the number of elements in the source data, + /// otherwise an error is returned. + pub fn buffer_from_host_buffer( + &self, + data: &[T], + dims: &[usize], + device: Option<&PjRtDevice>, + ) -> Result { + let mut buffer: c_lib::pjrt_buffer = std::ptr::null_mut(); + let element_count: usize = dims.iter().product(); + if element_count != data.len() { + Err(Error::WrongElementCount { dims: dims.to_vec(), element_count })? + } + let device = device.map_or(std::ptr::null_mut(), |d| d.device); + let dims: Vec<_> = dims.iter().map(|d| *d as i64).collect(); + let status = unsafe { + c_lib::pjrt_buffer_from_host_buffer( + self.ptr(), + device, + data.as_ptr() as *const libc::c_void, + T::TY.primitive_type() as i32, + dims.len() as i32, + dims.as_ptr(), + &mut buffer, + ) + }; + super::handle_status(status)?; + Ok(PjRtBuffer { buffer, client: self.clone() }) + } + + /// Transfer some data from the host to a `PjRtBuffer` stored on the target device. If the + /// device is not specified, the default device is used. + /// The source data is passed as a slice of raw bytes, as well as the dimensions. The + /// dimensions have to match the number of bytes in the source data, otherwise an error + /// is returned. 
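A sketch of the host-to-device path above combined with the read-back helper from pjrt_buffer.rs (crate `Result` assumed; `None` selects the default device):

```rust
// Sketch: upload a 2x3 f32 array to the default device and read it back.
fn upload_demo(client: &PjRtClient) -> Result<()> {
    let data = [1f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    // The dims must multiply out to data.len(): 2 * 3 == 6.
    let buffer = client.buffer_from_host_buffer(&data, &[2, 3], None)?;
    let mut back = [0f32; 6];
    buffer.copy_raw_to_host_sync(&mut back, 0)?;
    assert_eq!(back, data);
    Ok(())
}
```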
+ pub fn buffer_from_host_raw_bytes( + &self, + ty: super::ElementType, + data: &[u8], + dims: &[usize], + device: Option<&PjRtDevice>, + ) -> Result { + let mut buffer: c_lib::pjrt_buffer = std::ptr::null_mut(); + let element_count: usize = dims.iter().product(); + let element_size_in_bytes = ty.element_size_in_bytes(); + if element_count * element_size_in_bytes != data.len() { + Err(Error::WrongElementCount { dims: dims.to_vec(), element_count })? + } + let device = device.map_or(std::ptr::null_mut(), |d| d.device); + let dims: Vec<_> = dims.iter().map(|d| *d as i64).collect(); + let status = unsafe { + c_lib::pjrt_buffer_from_host_buffer( + self.ptr(), + device, + data.as_ptr() as *const libc::c_void, + ty as i32, + dims.len() as i32, + dims.as_ptr(), + &mut buffer, + ) + }; + super::handle_status(status)?; + Ok(PjRtBuffer { buffer, client: self.clone() }) + } + + /// Transfer some data from the host to a `PjRtBuffer` stored on the target device. If the + /// device is not specified, the default device is used. + /// The source data is passed as a literal. + pub fn buffer_from_host_literal( + &self, + device: Option<&PjRtDevice>, + literal: &Literal, + ) -> Result { + let mut buffer: c_lib::pjrt_buffer = std::ptr::null_mut(); + let device = device.map_or(std::ptr::null_mut(), |d| d.device); + let status = unsafe { + c_lib::pjrt_buffer_from_host_literal(self.ptr(), device, literal.0, &mut buffer) + }; + super::handle_status(status)?; + Ok(PjRtBuffer { buffer, client: self.clone() }) + } +} + +impl Drop for PjRtClientInternal { + fn drop(&mut self) { + unsafe { c_lib::pjrt_client_free(self.0) } + } +} diff --git a/ivy/engines/XLA/rust_api/src/wrappers/pjrt_device.rs b/ivy/engines/XLA/rust_api/src/wrappers/pjrt_device.rs new file mode 100644 index 0000000000000..2a24bfd8a9c06 --- /dev/null +++ b/ivy/engines/XLA/rust_api/src/wrappers/pjrt_device.rs @@ -0,0 +1,58 @@ +use crate::{c_lib, Result}; +use std::marker::PhantomData; + +/// A device attached to a [`super::PjRtClient`]. +pub struct PjRtDevice<'a> { + pub(super) device: c_lib::pjrt_device, + pub(super) marker: PhantomData<&'a super::PjRtClient>, +} + +impl PjRtDevice<'_> { + /// The device unique identifier. + pub fn id(&self) -> usize { + (unsafe { c_lib::pjrt_device_id(self.device) }) as usize + } + + pub fn process_index(&self) -> usize { + (unsafe { c_lib::pjrt_device_process_index(self.device) }) as usize + } + + pub fn local_hardware_id(&self) -> usize { + (unsafe { c_lib::pjrt_device_local_hardware_id(self.device) }) as usize + } + + #[allow(clippy::inherent_to_string)] + pub fn to_string(&self) -> String { + unsafe { + let ptr = c_lib::pjrt_device_to_string(self.device); + super::c_ptr_to_string(ptr) + } + } + + pub fn kind(&self) -> String { + unsafe { + let ptr = c_lib::pjrt_device_kind(self.device); + super::c_ptr_to_string(ptr) + } + } + + pub fn debug_string(&self) -> String { + unsafe { + let ptr = c_lib::pjrt_device_debug_string(self.device); + super::c_ptr_to_string(ptr) + } + } + + pub fn transfer_to_infeed(&self, src: &super::Literal) -> Result<()> { + let status = unsafe { c_lib::pjrt_device_transfer_to_infeed(self.device, src.0) }; + super::handle_status(status)?; + Ok(()) + } + + /// Transfer and return a value for the given shape from the outfeed queue. 
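The infeed/outfeed pair above and below moves literals between the host and a running computation; a sketch under the assumption that some compiled computation is actually draining the infeed and filling the outfeed with an f32 vector of length 3:

```rust
// Sketch: push a literal into the device infeed, then pull a same-shaped
// result back from the outfeed.
fn feed_demo(device: &PjRtDevice) -> Result<()> {
    let src = Literal::vec1(&[1f32, 2.0, 3.0]);
    device.transfer_to_infeed(&src)?;
    let mut dst = Literal::create_from_shape(PrimitiveType::F32, &[3]);
    device.transfer_from_outfeed(&mut dst)?;
    Ok(())
}
```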
+    pub fn transfer_from_outfeed(&self, dst: &mut super::Literal) -> Result<()> {
+        let status = unsafe { c_lib::pjrt_device_transfer_from_outfeed(self.device, dst.0) };
+        super::handle_status(status)?;
+        Ok(())
+    }
+}
diff --git a/ivy/engines/XLA/rust_api/src/wrappers/pjrt_loaded_executable.rs b/ivy/engines/XLA/rust_api/src/wrappers/pjrt_loaded_executable.rs
new file mode 100644
index 0000000000000..bb7b500c1e178
--- /dev/null
+++ b/ivy/engines/XLA/rust_api/src/wrappers/pjrt_loaded_executable.rs
@@ -0,0 +1,74 @@
+use super::{Literal, PjRtBuffer};
+use crate::{c_lib, Result};
+use pyo3::prelude::*;
+
+#[derive(Clone)]
+#[pyclass(unsendable)]
+pub struct PjRtLoadedExecutable {
+    pub(super) exe: c_lib::pjrt_loaded_executable,
+    pub(super) client: super::PjRtClient,
+}
+
+impl PjRtLoadedExecutable {
+    /// The client that owns this executable.
+    pub fn client(&self) -> &super::PjRtClient {
+        &self.client
+    }
+
+    fn process_execute_outputs(
+        &self,
+        outputs: *mut *mut c_lib::pjrt_buffer,
+    ) -> Vec<Vec<PjRtBuffer>> {
+        unsafe {
+            let mut vec = vec![];
+            loop {
+                let outputs = *outputs.add(vec.len());
+                if outputs.is_null() {
+                    break;
+                }
+                let mut replica_vec = vec![];
+                loop {
+                    let buffer = *outputs.add(replica_vec.len());
+                    if buffer.is_null() {
+                        break;
+                    }
+                    replica_vec.push(PjRtBuffer { buffer, client: self.client.clone() });
+                }
+                libc::free(outputs as *mut libc::c_void);
+                vec.push(replica_vec);
+            }
+            libc::free(outputs as *mut libc::c_void);
+            vec
+        }
+    }
+
+    pub fn execute<L: std::borrow::Borrow<Literal>>(
+        &self,
+        args: &[L],
+    ) -> Result<Vec<Vec<PjRtBuffer>>> {
+        let mut outputs = std::ptr::null_mut();
+        let args: Vec<_> = args.iter().map(|x| x.borrow().0).collect();
+        let status =
+            unsafe { c_lib::execute(self.exe, args.as_ptr(), args.len() as i32, &mut outputs) };
+        super::handle_status(status)?;
+        Ok(self.process_execute_outputs(outputs))
+    }
+
+    pub fn execute_b<B: std::borrow::Borrow<PjRtBuffer>>(
+        &self,
+        args: &[B],
+    ) -> Result<Vec<Vec<PjRtBuffer>>> {
+        let mut outputs = std::ptr::null_mut();
+        let args: Vec<_> = args.iter().map(|x| x.borrow().buffer).collect();
+        let status =
+            unsafe { c_lib::execute_b(self.exe, args.as_ptr(), args.len() as i32, &mut outputs) };
+        super::handle_status(status)?;
+        Ok(self.process_execute_outputs(outputs))
+    }
+}
+
+impl Drop for PjRtLoadedExecutable {
+    fn drop(&mut self) {
+        unsafe { c_lib::pjrt_loaded_executable_free(self.exe) }
+    }
+}
diff --git a/ivy/engines/XLA/rust_api/src/wrappers/shape.rs b/ivy/engines/XLA/rust_api/src/wrappers/shape.rs
new file mode 100644
index 0000000000000..de45c6bc76b50
--- /dev/null
+++ b/ivy/engines/XLA/rust_api/src/wrappers/shape.rs
@@ -0,0 +1,218 @@
+use super::{ArrayElement, ElementType, PrimitiveType};
+use crate::{c_lib, Error, Result};
+use pyo3::prelude::*;
+
+#[derive(Clone, PartialEq, Eq, Debug)]
+#[pyclass(unsendable)]
+pub struct ArrayShape {
+    ty: ElementType,
+    dims: Vec<i64>,
+}
+
+impl ArrayShape {
+    /// Create a new array shape.
+    pub fn new<E: ArrayElement>(dims: Vec<i64>) -> Self {
+        Self { ty: E::TY, dims }
+    }
+
+    /// Create a new array shape.
+    pub fn new_with_type(ty: ElementType, dims: Vec<i64>) -> Self {
+        Self { ty, dims }
+    }
+
+    pub fn element_type(&self) -> ElementType {
+        self.ty
+    }
+
+    pub fn ty(&self) -> ElementType {
+        self.ty
+    }
+
+    /// The stored primitive type.
+    pub fn primitive_type(&self) -> PrimitiveType {
+        self.ty.primitive_type()
+    }
+
+    /// The number of elements stored in arrays that use this shape: the product of sizes
+    /// across each dimension.
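+    ///
+    /// For example, dims `[2, 3, 4]` give `2 * 3 * 4 = 24` elements; the
+    /// `ElementType::F32` variant is assumed here for illustration:
+    ///
+    /// ```ignore
+    /// let shape = ArrayShape::new_with_type(ElementType::F32, vec![2, 3, 4]);
+    /// assert_eq!(shape.element_count(), 24);
+    /// ```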
+    pub fn element_count(&self) -> usize {
+        self.dims.iter().map(|d| *d as usize).product::<usize>()
+    }
+
+    pub fn dims(&self) -> &[i64] {
+        &self.dims
+    }
+
+    pub fn first_dim(&self) -> Option<i64> {
+        self.dims.first().copied()
+    }
+
+    pub fn last_dim(&self) -> Option<i64> {
+        self.dims.last().copied()
+    }
+}
+
+/// A shape specifies a primitive type as well as some array dimensions.
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub enum Shape {
+    Tuple(Vec<Shape>),
+    Array(ArrayShape),
+    Unsupported(PrimitiveType),
+}
+
+impl Shape {
+    /// Create a new array shape.
+    pub fn array<E: ArrayElement>(dims: Vec<i64>) -> Self {
+        Self::Array(ArrayShape { ty: E::TY, dims })
+    }
+
+    /// Create a new array shape.
+    pub fn array_with_type(ty: ElementType, dims: Vec<i64>) -> Self {
+        Self::Array(ArrayShape { ty, dims })
+    }
+
+    /// Create a new tuple shape.
+    pub fn tuple(shapes: Vec<Shape>) -> Self {
+        Self::Tuple(shapes)
+    }
+
+    /// The stored primitive type.
+    pub fn primitive_type(&self) -> PrimitiveType {
+        match self {
+            Self::Tuple(_) => PrimitiveType::Tuple,
+            Self::Array(a) => a.ty.primitive_type(),
+            Self::Unsupported(ty) => *ty,
+        }
+    }
+
+    pub fn is_tuple(&self) -> bool {
+        match self {
+            Self::Tuple(_) => true,
+            Self::Array { .. } | Self::Unsupported(_) => false,
+        }
+    }
+
+    pub fn tuple_size(&self) -> Option<usize> {
+        match self {
+            Self::Tuple(shapes) => Some(shapes.len()),
+            Self::Array { .. } | Self::Unsupported(_) => None,
+        }
+    }
+
+    #[allow(dead_code)]
+    pub(crate) fn c_shape(&self) -> Result<CShape> {
+        match self {
+            Self::Tuple(shapes) => {
+                let shapes = shapes.iter().map(|s| s.c_shape()).collect::<Result<Vec<_>>>()?;
+                let ptrs: Vec<_> = shapes.iter().map(|s| s.0).collect();
+                let c_shape = CShape(unsafe { c_lib::make_shape_tuple(ptrs.len(), ptrs.as_ptr()) });
+                drop(shapes);
+                Ok(c_shape)
+            }
+            Self::Array(a) => {
+                let dims = a.dims();
+                Ok(CShape(unsafe {
+                    c_lib::make_shape_array(a.primitive_type() as i32, dims.len(), dims.as_ptr())
+                }))
+            }
+            Self::Unsupported(_) => Err(Error::UnsupportedShape { shape: self.clone() }),
+        }
+    }
+}
+
+impl TryFrom<&Shape> for ArrayShape {
+    type Error = Error;
+
+    fn try_from(value: &Shape) -> Result<Self> {
+        match value {
+            Shape::Tuple(_) | Shape::Unsupported(_) => {
+                Err(Error::NotAnArray { expected: None, got: value.clone() })
+            }
+            Shape::Array(a) => Ok(a.clone()),
+        }
+    }
+}
+
+macro_rules!
extract_dims { + ($cnt:tt, $dims:expr, $out_type:ty) => { + impl TryFrom<&ArrayShape> for $out_type { + type Error = Error; + + fn try_from(value: &ArrayShape) -> Result { + if value.dims.len() != $cnt { + Err(Error::UnexpectedNumberOfDims { + expected: $cnt, + got: value.dims.len(), + dims: value.dims.clone(), + }) + } else { + Ok($dims(&value.dims)) + } + } + } + + impl TryFrom<&Shape> for $out_type { + type Error = Error; + + fn try_from(value: &Shape) -> Result { + match value { + Shape::Tuple(_) | Shape::Unsupported(_) => { + Err(Error::NotAnArray { expected: Some($cnt), got: value.clone() }) + } + Shape::Array(a) => Self::try_from(a), + } + } + } + }; +} + +extract_dims!(1, |d: &Vec| d[0], i64); +extract_dims!(2, |d: &Vec| (d[0], d[1]), (i64, i64)); +extract_dims!(3, |d: &Vec| (d[0], d[1], d[2]), (i64, i64, i64)); +extract_dims!(4, |d: &Vec| (d[0], d[1], d[2], d[3]), (i64, i64, i64, i64)); +extract_dims!(5, |d: &Vec| (d[0], d[1], d[2], d[3], d[4]), (i64, i64, i64, i64, i64)); + +pub(crate) struct CShape(c_lib::shape); + +impl CShape { + pub(crate) fn from_ptr(ptr: c_lib::shape) -> Self { + Self(ptr) + } + + pub(crate) fn shape(&self) -> Result { + fn from_ptr_rec(ptr: c_lib::shape) -> Result { + let ty = unsafe { c_lib::shape_element_type(ptr) }; + let ty = super::FromPrimitive::from_i32(ty) + .ok_or_else(|| Error::UnexpectedElementType(ty))?; + match ty { + PrimitiveType::Tuple => { + let elem_cnt = unsafe { c_lib::shape_tuple_shapes_size(ptr) }; + let shapes: Result> = (0..elem_cnt) + .map(|i| from_ptr_rec(unsafe { c_lib::shape_tuple_shapes(ptr, i as i32) })) + .collect(); + Ok(Shape::Tuple(shapes?)) + } + ty => match ty.element_type() { + Ok(ty) => { + let rank = unsafe { c_lib::shape_dimensions_size(ptr) }; + let dims: Vec<_> = + (0..rank).map(|i| unsafe { c_lib::shape_dimensions(ptr, i) }).collect(); + Ok(Shape::Array(ArrayShape { ty, dims })) + } + Err(_) => Ok(Shape::Unsupported(ty)), + }, + } + } + from_ptr_rec(self.0) + } + + pub(crate) fn as_ptr(&self) -> c_lib::shape { + self.0 + } +} + +impl Drop for CShape { + fn drop(&mut self) { + unsafe { c_lib::shape_free(self.0) }; + } +} diff --git a/ivy/engines/XLA/rust_api/src/wrappers/xla_builder.rs b/ivy/engines/XLA/rust_api/src/wrappers/xla_builder.rs new file mode 100644 index 0000000000000..dcfdb48757967 --- /dev/null +++ b/ivy/engines/XLA/rust_api/src/wrappers/xla_builder.rs @@ -0,0 +1,288 @@ +use super::{ + handle_status, FromPrimitive, Literal, NativeType, PrimitiveType, Shape, XlaComputation, XlaOp, +}; +use crate::{c_lib, Error, Result}; +use std::rc::Rc; +use pyo3::prelude::*; + +/// A builder is used to keep track of a computation graph while it's being built. +pub(super) struct XlaBuilderInternal(c_lib::xla_builder); + +#[derive(Clone)] +#[pyclass(unsendable)] +pub struct XlaBuilder(Rc); + +impl XlaBuilder { + /// Create a new builder with the associated name, the name is only used for debugging + /// purposes. + pub fn new(name: &str) -> XlaBuilder { + let name = std::ffi::CString::new(name).unwrap(); + let xla_builder = unsafe { c_lib::xla_builder_create(name.as_ptr()) }; + XlaBuilder(Rc::new(XlaBuilderInternal(xla_builder))) + } + + fn ptr(&self) -> c_lib::xla_builder { + self.0 .0 + } + + /// Build a computation from the specified root node. This can only be called once. 
+ pub fn build(&self, op: &XlaOp) -> Result { + let mut result: c_lib::xla_computation = std::ptr::null_mut(); + let status = unsafe { c_lib::build(self.ptr(), op.op, &mut result) }; + handle_status(status)?; + Ok(XlaComputation(result)) + } + + /// This returns `Ok(())` if the graph creation has not generated any error so far. Otherwise + /// the first error is returned. + pub fn first_error(&self) -> Result<()> { + let status = unsafe { c_lib::first_error(self.ptr()) }; + handle_status(status)?; + Ok(()) + } + + /// This returns `Ok(())` if the graph creation has not generated any error so far. Otherwise + /// the current status is returned. + pub fn get_current_status(&self) -> Result<()> { + let status = unsafe { c_lib::get_current_status(self.ptr()) }; + handle_status(status)?; + Ok(()) + } + + /// Create a node with a constant value defined by the specified literal. + pub fn constant_literal(&self, literal: &Literal) -> Result { + let op = unsafe { c_lib::constant_literal(self.ptr(), literal.0) }; + self.wrap(op) + } + + /// Create a node with a constant scalar value using the type of the element that is passed as + /// argument. + pub fn constant_r0(&self, f: T) -> Result { + let op = unsafe { T::constant_r0(self.ptr(), f) }; + self.wrap(op) + } + + /// A shorter notation for `constant_r0`. + pub fn c0(&self, f: T) -> Result { + self.constant_r0(f) + } + + pub fn wrap(&self, op: c_lib::xla_op) -> Result { + self.get_current_status()?; + Ok(XlaOp { op, builder: self.clone() }) + } + + /// Create an input node with the specified type and dimensions. A literal has to be passed for + /// each of the parameter in the graph when calling the `execute` function, the parameter + /// number are specified as incrementing values from 0 and represent the index of the + /// associated literal in the slice passed to `execute`. + pub fn parameter( + &self, + parameter_number: i64, + ty: super::ElementType, + dims: &[i64], + name: &str, + ) -> Result { + let name = std::ffi::CString::new(name).unwrap(); + let op = unsafe { + c_lib::parameter( + self.ptr(), + parameter_number, + ty.primitive_type() as i32, + dims.len() as i32, + dims.as_ptr(), + name.as_ptr(), + ) + }; + self.wrap(op) + } + + /// Read a single value from the implicit streaming interface of the device. + pub fn infeed(&self, ty: PrimitiveType, dims: &[i64], config: &str) -> Result { + let config = std::ffi::CString::new(config).unwrap(); + let op = unsafe { + c_lib::infeed(self.ptr(), ty as i32, dims.len() as i32, dims.as_ptr(), config.as_ptr()) + }; + self.wrap(op) + } + + pub fn parameter_s(&self, parameter_number: i64, shape: &Shape, name: &str) -> Result { + let c_shape = shape.c_shape()?; + let name = std::ffi::CString::new(name).unwrap(); + let op = unsafe { + c_lib::parameter_s(self.ptr(), parameter_number, c_shape.as_ptr(), name.as_ptr()) + }; + drop(c_shape); + self.wrap(op) + } + + pub fn constant_r1c(&self, f: T, len: usize) -> Result { + let op = unsafe { T::constant_r1c(self.ptr(), f, len) }; + self.wrap(op) + } + + /// A one dimension constant node based on some slice stored on the host. + pub fn constant_r1(&self, f: &[T]) -> Result { + let op = unsafe { T::constant_r1(self.ptr(), f.as_ptr(), f.len()) }; + self.wrap(op) + } + + /// Shorthand function for `constant_r1`. + pub fn c1(&self, f: &[T]) -> Result { + self.constant_r1(f) + } + + /// A scalar node with the zero value for the associated type. 
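+    ///
+    /// Sketch; the `builder` binding and the `ElementType::F32` variant are
+    /// assumed here for illustration:
+    ///
+    /// ```ignore
+    /// let zero = builder.zero(ElementType::F32)?; // scalar f32 node holding 0.0
+    /// ```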
+ pub fn zero(&self, ty: super::ElementType) -> Result { + let op = unsafe { c_lib::op_zero(self.ptr(), ty.primitive_type() as i32) }; + self.wrap(op) + } + + /// A scalar node with the one value for the associated type. + pub fn one(&self, ty: super::ElementType) -> Result { + let op = unsafe { c_lib::op_one(self.ptr(), ty.primitive_type() as i32) }; + self.wrap(op) + } + + /// A scalar node with the minimum value for the associated type. + pub fn min_value(&self, ty: super::ElementType) -> Result { + let op = unsafe { c_lib::op_min_value(self.ptr(), ty.primitive_type() as i32) }; + self.wrap(op) + } + + /// A scalar node with the maximum value for the associated type. + pub fn max_value(&self, ty: super::ElementType) -> Result { + let op = unsafe { c_lib::op_max_value(self.ptr(), ty.primitive_type() as i32) }; + self.wrap(op) + } + + /// A constant node with the specified shape that holds increasing values starting from 0 along + /// the iota dimension. + pub fn iota(&self, ty: super::ElementType, dims: &[i64], iota_dimension: i64) -> Result { + let op = unsafe { + c_lib::op_iota( + self.ptr(), + ty.primitive_type() as i32, + dims.len(), + dims.as_ptr(), + iota_dimension, + ) + }; + self.wrap(op) + } + + /// A constant node for a unidimensional array of increasing values starting from 0. + pub fn iota1(&self, ty: super::ElementType, size: usize) -> Result { + let op = unsafe { c_lib::op_iota1(self.ptr(), ty.primitive_type() as i32, size) }; + self.wrap(op) + } + + pub fn call(&self, computation: &XlaComputation, operands: &[XlaOp]) -> Result { + let operands: Vec<_> = operands.iter().map(|a| a.op).collect(); + let op = unsafe { + c_lib::op_call(self.ptr(), computation.0, operands.len(), operands.as_ptr()) + }; + self.wrap(op) + } + + pub fn map( + &self, + operands: &[XlaOp], + computation: &XlaComputation, + dims: &[i64], + static_operands: &[XlaOp] + ) -> Result { + let operands: Vec<_> = operands.iter().map(|a| a.op).collect(); + let static_operands: Vec<_> = static_operands.iter().map(|a| a.op).collect(); + let op = unsafe { + c_lib::op_map( + self.ptr(), + operands.len(), + operands.as_ptr(), + computation.0, + dims.len(), + dims.as_ptr(), + static_operands.len(), + static_operands.as_ptr(), + ) + }; + self.wrap(op) + } + + /// An error node, using the 'internal error' error type. + pub fn internal_error(&self, msg: &str) -> XlaOp { + let msg = std::ffi::CString::new(msg).unwrap(); + let op = unsafe { c_lib::op_internal_error(self.ptr(), msg.as_ptr()) }; + XlaOp { op, builder: self.clone() } + } + + /// An error node, using the 'unknown error' error type. + pub fn unknown_error(&self, msg: &str) -> XlaOp { + let msg = std::ffi::CString::new(msg).unwrap(); + let op = unsafe { c_lib::op_unknown_error(self.ptr(), msg.as_ptr()) }; + XlaOp { op, builder: self.clone() } + } + + /// An error node, using the 'invalid argument error' error type. + pub fn invalid_argument_error(&self, msg: &str) -> XlaOp { + let msg = std::ffi::CString::new(msg).unwrap(); + let op = unsafe { c_lib::op_invalid_argument_error(self.ptr(), msg.as_ptr()) }; + XlaOp { op, builder: self.clone() } + } + + /// Wrap a potential error in an error node. If the argument is `Ok(op)` then `op` is passed + /// back as the result. + pub fn wrap_error(&self, op: Result) -> XlaOp { + match op { + Ok(op) => op, + Err(err) => self.internal_error(&err.to_string()), + } + } + + /// The shape associated with this op. 
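+    ///
+    /// Sketch; the `builder` binding and the `ElementType::F32` variant are
+    /// assumed here for illustration:
+    ///
+    /// ```ignore
+    /// let x = builder.parameter(0, ElementType::F32, &[2, 3], "x")?;
+    /// let shape = builder.get_shape(&x)?; // a `Shape::Array` with dims [2, 3]
+    /// ```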
+ pub fn get_shape(&self, op: &XlaOp) -> Result { + let mut out: c_lib::shape = std::ptr::null_mut(); + let status = unsafe { c_lib::get_shape(self.ptr(), op.op, &mut out) }; + handle_status(status)?; + let c_shape = super::shape::CShape::from_ptr(out); + c_shape.shape() + } + + /// The dimension sizes associated with this op. + pub fn get_dims(&self, op: &XlaOp) -> Result> { + let rank = self.get_dimensions_size(op)?; + let mut dims = vec![0; rank]; + let status = unsafe { c_lib::get_dimensions(self.ptr(), op.op, dims.as_mut_ptr()) }; + handle_status(status)?; + Ok(dims) + } + + /// The element type associated with this op. + pub fn get_primitive_type(&self, op: &XlaOp) -> Result { + let mut ty = 0i32; + let status = unsafe { c_lib::get_element_type(self.ptr(), op.op, &mut ty) }; + handle_status(status)?; + FromPrimitive::from_i32(ty).ok_or(Error::UnexpectedElementType(ty)) + } + + /// The number of dimensions (a.k.a the rank) associated with this op. + pub fn get_dimensions_size(&self, op: &XlaOp) -> Result { + let mut dsize = 0i32; + let status = unsafe { c_lib::get_dimensions_size(self.ptr(), op.op, &mut dsize) }; + handle_status(status)?; + Ok(dsize as usize) + } + + /// Build a tuple from multiple operands. + pub fn tuple>(&self, args: &[B]) -> Result { + let args: Vec<_> = args.iter().map(|a| a.borrow().op).collect(); + let op = unsafe { c_lib::op_tuple(self.ptr(), args.as_ptr(), args.len()) }; + self.wrap(op) + } +} + +impl Drop for XlaBuilderInternal { + fn drop(&mut self) { + unsafe { c_lib::xla_builder_free(self.0) } + } +} diff --git a/ivy/engines/XLA/rust_api/src/wrappers/xla_op.rs b/ivy/engines/XLA/rust_api/src/wrappers/xla_op.rs new file mode 100644 index 0000000000000..5a4094b6e497e --- /dev/null +++ b/ivy/engines/XLA/rust_api/src/wrappers/xla_op.rs @@ -0,0 +1,922 @@ +//! Nodes from the computation graph. +//! +//! An `XlaOp` value represents a node/operand in the computation graph, e.g. it can be the sum of two +//! other nodes, a constant value, an input parameter, etc. +//! +//! For details on the semantics, see +//! [operation_semantics](https://www.tensorflow.org/xla/operation_semantics). +use super::{ArrayShape, PrimitiveType, Shape, XlaBuilder, XlaComputation}; +use crate::{c_lib, Error, Result}; +use pyo3::prelude::*; + +#[pyclass(unsendable)] +pub struct XlaOp { + pub(super) op: c_lib::xla_op, + pub(super) builder: XlaBuilder, +} + +macro_rules! extract_dims { + ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => { + pub fn $fn_name(&self) -> Result<$out_type> { + let dims = self.builder.get_dims(self)?; + if dims.len() != $cnt { + let dims: Vec<_> = dims.iter().map(|d| *d as i64).collect(); + Err(Error::UnexpectedNumberOfDims { expected: $cnt, got: dims.len(), dims }) + } else { + let dims = $dims(dims); + Ok(dims) + } + } + }; +} + +macro_rules! binary_op { + ($func_name:ident, $expression:expr) => { + pub fn $func_name(&self, op: &XlaOp) -> Result { + let op = unsafe { $expression(self.op, op.op) }; + self.wrap(op) + } + }; +} + +macro_rules! 
unary_op { + ($func_name:ident, $expression:expr) => { + pub fn $func_name(&self) -> Result { + let op = unsafe { $expression(self.op) }; + self.wrap(op) + } + }; +} + +impl Clone for XlaOp { + fn clone(&self) -> Self { + let op = unsafe { c_lib::op_clone(self.op) }; + Self { op, builder: self.builder.clone() } + } +} +impl XlaOp { + pub(super) fn wrap(&self, op: c_lib::xla_op) -> Result { + self.builder.get_current_status()?; + Ok(XlaOp { op, builder: self.builder.clone() }) + } + + pub fn builder(&self) -> &XlaBuilder { + &self.builder + } + + binary_op!(add_, c_lib::op_add); + binary_op!(sub_, c_lib::op_sub); + binary_op!(mul_, c_lib::op_mul); + binary_op!(div_, c_lib::op_div); + binary_op!(rem_, c_lib::op_rem); + binary_op!(max, c_lib::op_max); + binary_op!(min, c_lib::op_min); + binary_op!(and, c_lib::op_and); + binary_op!(or, c_lib::op_or); + binary_op!(xor, c_lib::op_xor); + binary_op!(atan2, c_lib::op_atan2); + binary_op!(pow, c_lib::op_pow); + binary_op!(dot, c_lib::op_dot); + binary_op!(eq, c_lib::op_eq); + binary_op!(ne, c_lib::op_ne); + binary_op!(ge, c_lib::op_ge); + binary_op!(gt, c_lib::op_gt); + binary_op!(le, c_lib::op_le); + binary_op!(lt, c_lib::op_lt); + binary_op!(lshift, c_lib::op_shift_left); + binary_op!(rshift_arith, c_lib::op_shift_right_arith); + binary_op!(rshift_logic, c_lib::op_shift_right_logic); + + unary_op!(population_count, c_lib::op_population_count); + unary_op!(not, c_lib::op_not); + unary_op!(abs, c_lib::op_abs); + unary_op!(exp, c_lib::op_exp); + unary_op!(expm1, c_lib::op_expm1); + unary_op!(floor, c_lib::op_floor); + unary_op!(ceil, c_lib::op_ceil); + unary_op!(round, c_lib::op_round); + unary_op!(round_nearest_even, c_lib::op_round_nearest_even); + unary_op!(log, c_lib::op_log); + unary_op!(log1p, c_lib::op_log1p); + unary_op!(logistic, c_lib::op_logistic); + unary_op!(sign, c_lib::op_sign); + unary_op!(clz, c_lib::op_clz); + unary_op!(cos, c_lib::op_cos); + unary_op!(sin, c_lib::op_sin); + unary_op!(tanh, c_lib::op_tanh); + unary_op!(real, c_lib::op_real); + unary_op!(imag, c_lib::op_imag); + unary_op!(conj, c_lib::op_conj); + unary_op!(square, c_lib::op_square); + unary_op!(sqrt, c_lib::op_sqrt); + unary_op!(rsqrt, c_lib::op_rsqrt); + unary_op!(cbrt, c_lib::op_cbrt); + unary_op!(is_finite, c_lib::op_is_finite); + unary_op!(neg, c_lib::op_neg); + unary_op!(lower_triangle, c_lib::op_lower_triangle); + unary_op!(upper_triangle, c_lib::op_upper_triangle); + unary_op!(erf, c_lib::op_erf); + unary_op!(copy, c_lib::op_copy); + unary_op!(zeros_like, c_lib::op_zeros_like); + + /// Sigmoid activation function. + /// + /// This computes the element-wise sigmoid. + pub fn sigmoid(&self) -> Result { + self.logistic() + } + + /// SiLU activation function. + /// + /// This computes the element-wise SiLU activation, x.sigmoid(x). + pub fn silu(&self) -> Result { + self * self.logistic() + } + + pub fn relu(&self) -> Result { + self.max(&self.zeros_like()?) 
+ } + + pub fn gelu(&self) -> Result { + let prim_type = self.primitive_type()?; + let elem_type = prim_type.element_type()?; + let b = self.builder(); + let sqrt_two = b.c0(2)?.astype(prim_type)?.sqrt()?; + let one_half = b.c0(0.5)?.astype(prim_type)?; + let gauss_cdf = self.div_(&sqrt_two)?.erf()?.add_(&b.one(elem_type)?)?.mul_(&one_half)?; + self.mul_(&gauss_cdf) + } + + pub fn gelu_approx(&self) -> Result { + let prim_type = self.primitive_type()?; + let b = self.builder(); + let sqrt_two_over_pi = b.c0(2f32 / std::f32::consts::PI)?.astype(prim_type)?.sqrt()?; + let v = (sqrt_two_over_pi * ((b.c0(0.044715)?.astype(prim_type)? * self.pow(&b.c0(3f32)?)?)? + self)?)?; + (b.c0(0.5)?.astype(prim_type)? * self)? * (v.tanh()? + b.c0(1)?.astype(prim_type)?)? + } + + /// A node that applies the specified Einstein summation formula to this node. + pub fn einsum1(&self, config: &str) -> Result { + let config = std::ffi::CString::new(config).unwrap(); + let op = unsafe { c_lib::op_einsum1(self.op, config.as_ptr()) }; + self.wrap(op) + } + + /// A node that applies the specified Einstein summation formula to this node and the other + /// argument node. + pub fn einsum2(&self, rhs: &XlaOp, config: &str) -> Result { + let config = std::ffi::CString::new(config).unwrap(); + let op = unsafe { c_lib::op_einsum2(self.op, rhs.op, config.as_ptr()) }; + self.wrap(op) + } + + /// Reshape this node to a different set of dimension sizes, the number of element between the + /// two different shapes has to match. + pub fn reshape(&self, dims: &[i64]) -> Result { + let op = unsafe { c_lib::op_reshape(self.op, dims.len(), dims.as_ptr()) }; + self.wrap(op) + } + + pub fn dynamic_reshape( + &self, + dim_sizes: &[XlaOp], + new_size_bounds: &[i64], + dims_are_dynamic: Vec + ) -> Result { + let dim_sizes: Vec<_> = dim_sizes.iter().map(|a| a.op).collect(); + let op = unsafe {c_lib::op_dynamic_reshape( + self.op, dim_sizes.len(), dim_sizes.as_ptr(), + new_size_bounds.len(), new_size_bounds.as_ptr(), + dims_are_dynamic.as_ptr()) + }; + self.wrap(op) + } + + /// Add some broadcasting dimensions at the beginning of the current node shape. + pub fn broadcast(&self, dims: &[i64]) -> Result { + let op = unsafe { c_lib::op_broadcast(self.op, dims.len(), dims.as_ptr()) }; + self.wrap(op) + } + + /// Add some broadcasting dimensions at arbitrary positions. + /// + /// See the [semantics](https://www.tensorflow.org/xla/operation_semantics#broadcastindim). + pub fn broadcast_in_dim(&self, out_dims: &[i64], broadcast_dims: &[i64]) -> Result { + let op = unsafe { + c_lib::op_broadcast_in_dim( + self.op, + out_dims.len(), + out_dims.as_ptr(), + broadcast_dims.len(), + broadcast_dims.as_ptr(), + ) + }; + self.wrap(op) + } + + /// Collapse the dimensions of this node into a single dimension, [xla + /// documentation](https://www.tensorflow.org/xla/operation_semantics#collapse). + pub fn collapse(&self, dims: &[i64]) -> Result { + let op = unsafe { c_lib::op_collapse(self.op, dims.len(), dims.as_ptr()) }; + self.wrap(op) + } + + /// Permute the dimension with the specified indexes. + pub fn transpose(&self, index_perm: &[i64]) -> Result { + let op = unsafe { c_lib::op_transpose(self.op, index_perm.len(), index_perm.as_ptr()) }; + self.wrap(op) + } + + /// Permute two dimensions, this is a specialized version of `transpose`. 
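+    ///
+    /// Sketch: for a node `x` of dims `[2, 3, 4]`, swapping the first and last
+    /// dimensions (negative indexes count from the end) gives dims `[4, 3, 2]`:
+    ///
+    /// ```ignore
+    /// let swapped = x.swap_dims(0, -1)?;
+    /// ```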
+    pub fn swap_dims(&self, index1: i64, index2: i64) -> Result<Self> {
+        let index1 = self.normalize_index(index1)?;
+        let index2 = self.normalize_index(index2)?;
+        let rank = self.rank()?;
+        let mut index_perm: Vec<_> = (0..rank as i64).collect();
+        index_perm[index1 as usize] = index2;
+        index_perm[index2 as usize] = index1;
+        self.transpose(&index_perm)
+    }
+
+    pub fn pad(&self, padding_value: &XlaOp, padding_config: Vec<(i64, i64, i64)>) -> Result<Self> {
+        let lows: Vec<_> = padding_config.iter().map(|x| x.0).collect();
+        let highs: Vec<_> = padding_config.iter().map(|x| x.1).collect();
+        let interiors: Vec<_> = padding_config.iter().map(|x| x.2).collect();
+        let op = unsafe {
+            c_lib::op_pad(
+                self.op,
+                padding_value.op,
+                padding_config.len(),
+                lows.as_ptr(),
+                highs.as_ptr(),
+                interiors.as_ptr(),
+            )
+        };
+        self.wrap(op)
+    }
+
+    pub fn pad_in_dim(&self, padding_value: &XlaOp, dimno: i64, pad_low: i64, pad_high: i64) -> Result<Self> {
+        let op = unsafe { c_lib::op_pad_in_dim(self.op, padding_value.op, dimno, pad_low, pad_high) };
+        self.wrap(op)
+    }
+
+    pub fn slice(&self, start_indices: &[i64], limit_indices: &[i64], strides: &[i64]) -> Result<Self> {
+        let op = unsafe {
+            c_lib::op_slice(
+                self.op,
+                start_indices.len(),
+                start_indices.as_ptr(),
+                limit_indices.len(),
+                limit_indices.as_ptr(),
+                strides.len(),
+                strides.as_ptr(),
+            )
+        };
+        self.wrap(op)
+    }
+
+    /// Create a node that has a partial view on the data of the original node. Indexes on the
+    /// target dimension `dim` are restricted to the values between `start_index` (inclusive) and
+    /// `stop_index` (exclusive), using the associated `stride` as a step between two values.
+    pub fn slice_in_dim(
+        &self,
+        start_index: i64,
+        stop_index: i64,
+        stride: i64,
+        dim: i64,
+    ) -> Result<Self> {
+        let dim = self.normalize_index(dim)?;
+        let op = unsafe { c_lib::op_slice_in_dim(self.op, start_index, stop_index, stride, dim) };
+        self.wrap(op)
+    }
+
+    /// A specialized version of `slice_in_dim` using a stride of one, so with all values with an
+    /// index between `start_index` (inclusive) and `stop_index` (exclusive).
+    pub fn slice_in_dim1(&self, start_index: i64, stop_index: i64, dim: i64) -> Result<Self> {
+        self.slice_in_dim(start_index, stop_index, 1, dim)
+    }
+
+    pub fn dynamic_slice(
+        &self,
+        start_indices: &[XlaOp],
+        slice_indices: &[i64],
+    ) -> Result<Self> {
+        let start_indices: Vec<_> = start_indices.iter().map(|a| a.op).collect();
+        let op = unsafe {
+            c_lib::op_dynamic_slice(
+                self.op,
+                start_indices.len(),
+                start_indices.as_ptr(),
+                slice_indices.len(),
+                slice_indices.as_ptr(),
+            )
+        };
+        self.wrap(op)
+    }
+
+    pub fn dynamic_update_slice(
+        &self,
+        update: &XlaOp,
+        start_indices: &[XlaOp],
+    ) -> Result<Self> {
+        let start_indices: Vec<_> = start_indices.iter().map(|a| a.op).collect();
+        let op = unsafe {
+            c_lib::op_dynamic_update_slice(self.op, update.op, start_indices.len(), start_indices.as_ptr())
+        };
+        self.wrap(op)
+    }
+
+    /// A new node containing only values for index `index_in_dim` on the dimension `dim_index`.
+    /// The target dimension is squeezed so the resulting node has one less dimension than the
+    /// original node.
+    pub fn at(&self, index_in_dim: i64, dim_index: i64) -> Result<Self> {
+        let slice = self.slice_in_dim(index_in_dim, index_in_dim + 1, 1, dim_index)?;
+        slice.squeeze(dim_index)
+    }
+
+    /// Squeeze the dimension at the target index, i.e. if this dimension has size one, remove it
+    /// from the generated node. The target dimension index can be specified as a negative value,
+    /// e.g. -1 for the last dimension.
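+    ///
+    /// Sketch: for a node `x` of dims `[2, 1, 3]`, squeezing index `1` (or,
+    /// equivalently, `-2`) yields dims `[2, 3]`:
+    ///
+    /// ```ignore
+    /// let squeezed = x.squeeze(1)?;
+    /// ```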
+ pub fn squeeze(&self, index: i64) -> Result { + let index = self.normalize_index(index)?; + let dims = self.dims()?; + let mut new_dims = vec![]; + for (i, d) in dims.iter().enumerate() { + if i as i64 != index || *d != 1 { + new_dims.push(*d as i64) + } + } + self.reshape(&new_dims) + } + + /// Concat multiple nodes (together with the `self` node) along the target dimension. + pub fn concat_in_dim( + &self, + args: &[XlaOp], + dim: i64, + ) -> Result { + let dim = self.normalize_index(dim)?; + let args: Vec<_> = args.iter().map(|a| a.op).collect(); + let op = unsafe { c_lib::op_concat_in_dim(self.op, args.as_ptr(), args.len(), dim) }; + self.wrap(op) + } + + /// Index into tuples. + pub fn get_tuple_element(&self, index: i64) -> Result { + let op = unsafe { c_lib::op_get_tuple_element(self.op, index) }; + self.wrap(op) + } + + /// Clamp the values in the original node to be between `min` and `max`. + pub fn clamp(&self, min: &Self, max: &Self) -> Result { + let op = unsafe { c_lib::op_clamp(min.op, self.op, max.op) }; + self.wrap(op) + } + + /// Select values from the original tensor to be values from `on_true` if the associated + /// value in `self` is true, and the values from `on_false` otherwise. + pub fn select(&self, on_true: &Self, on_false: &Self) -> Result { + let op = unsafe { c_lib::op_select(self.op, on_true.op, on_false.op) }; + self.wrap(op) + } + + /// A node that when executed generates values using a random uniform distribution. + pub fn rng_uniform(min: &Self, max: &Self, shape: &ArrayShape) -> Result { + let dims = shape.dims(); + let op = unsafe { + c_lib::op_rng_uniform( + min.op, + max.op, + shape.primitive_type() as i32, + dims.len() as i32, + dims.as_ptr(), + ) + }; + min.wrap(op) + } + + /// A node that when executed generates values using a random normal distribution. + pub fn rng_normal(mu: &Self, sigma: &Self, shape: &ArrayShape) -> Result { + let dims = shape.dims(); + let op = unsafe { + c_lib::op_rng_normal( + mu.op, + sigma.op, + shape.primitive_type() as i32, + dims.len() as i32, + dims.as_ptr(), + ) + }; + mu.wrap(op) + } + + /// Create a new node by casting the elements of the original node to a new primitive type. + pub fn astype(&self, ty: PrimitiveType) -> Result { + let op = unsafe { c_lib::op_convert_element_type(self.op, ty as i32) }; + self.wrap(op) + } + + fn normalize_indexes(&self, indexes: &[i64]) -> Result> { + let rank = self.rank()?; + indexes + .iter() + .map(|&index| { + if index >= rank as i64 { + Err(Error::IndexOutOfBounds { index, rank }) + } else if index >= 0 { + Ok(index) + } else if index + rank as i64 >= 0 { + Ok(index + rank as i64) + } else { + Err(Error::IndexOutOfBounds { index, rank }) + } + }) + .collect() + } + + fn normalize_index(&self, index: i64) -> Result { + let rank = self.rank()?; + if index >= rank as i64 { + Err(Error::IndexOutOfBounds { index, rank }) + } else if index >= 0 { + Ok(index) + } else if index + rank as i64 >= 0 { + Ok(index + rank as i64) + } else { + Err(Error::IndexOutOfBounds { index, rank }) + } + } + + /// A node that contains the size of the dimension with the target index as a `S32` scalar + /// value. + pub fn dimensions_size(&self, index: i64) -> Result { + let index = self.normalize_index(index)?; + let op = unsafe { c_lib::op_dimensions_size(self.op, index) }; + self.wrap(op) + } + + /// Create a node by folding a computation across some target dimensions. 
If `keep_dims` is + /// `true`, the resulting node has a dimension of size one for the target dimensions, when + /// using `false` these dimensions are squeezed so the resulting node has a rank that is the + /// original node rank minus the number of elements in `dims`. + pub fn reduce( + &self, + init_value: Self, + comp: &XlaComputation, + dims: &[i64], + keep_dims: bool, + ) -> Result { + let dims = self.normalize_indexes(dims)?; + let op = + unsafe { c_lib::op_reduce(self.op, init_value.op, comp.0, dims.as_ptr(), dims.len()) }; + let op = self.wrap(op)?; + self.maybe_keep_dims(op, &dims, keep_dims) + } + + /// Sequentially execute `body` until `cond` fails. + /// + /// - `init` argument has a type `T`. + /// - `cond` is a computation with a single argument of type `T` producing a value of type + /// `PRED`. + /// - `body` is a computation with a single argument of type `T` producing a value of type + /// `T`. + pub fn while_(cond: &XlaComputation, body: &XlaComputation, init: Self) -> Result { + let op = unsafe { c_lib::op_while(cond.0, body.0, init.op) }; + init.wrap(op) + } + + /// Execute `true_comp` if `self` is true, `false_comp` if `self` is false, and return the result. + /// `self` has to be a scalar of type `PRED`. + /// `true_op` is used as the single argument to `true_comp` and `false_op` as the single + /// argument to `false_comp`. + pub fn conditional( + &self, + true_op: Self, + true_comp: &XlaComputation, + false_op: Self, + false_comp: &XlaComputation, + ) -> Result { + let op = unsafe { + c_lib::op_conditional(self.op, true_op.op, true_comp.0, false_op.op, false_comp.0) + }; + self.wrap(op) + } + + pub fn conv( + &self, + rhs: &XlaOp, + window_strides: &[i64], + padding: &str, + feature_group_count: i64, + batch_group_count: i64 + ) -> Result { + let padding_config = std::ffi::CString::new(padding).unwrap(); + let op = unsafe { + c_lib::op_conv( + self.op, rhs.op, + window_strides.len(), + window_strides.as_ptr(), + padding_config.as_ptr(), + feature_group_count, + batch_group_count + ) + }; + self.wrap(op) + } + + pub fn conv_general_dilated( + &self, + rhs: &XlaOp, + window_strides: &[i64], + padding: &[(i64, i64)], + lhs_dilations: &[i64], + rhs_dilations: &[i64], + input_batch_dim: &i64, + input_feature_dim: &i64, + input_spatial_dims: &[i64], + output_batch_dim: &i64, + output_feature_dim: &i64, + output_spatial_dims: &[i64], + kernel_input_feature_dim: &i64, + kernel_output_feature_dim: &i64, + kernel_spatial_dims: &[i64], + feature_group_count: i64, + batch_group_count: i64 + ) -> Result { + let padding: Vec = padding.iter().flat_map(|(a, b)| vec![*a, *b]).collect(); + let op = unsafe { + c_lib::op_conv_general_dilated( + self.op, + rhs.op, + window_strides.len(), + window_strides.as_ptr(), + padding.len() / 2, + padding.as_ptr(), + lhs_dilations.len(), + lhs_dilations.as_ptr(), + rhs_dilations.len(), + rhs_dilations.as_ptr(), + input_batch_dim, + input_feature_dim, + input_spatial_dims.len(), + input_spatial_dims.as_ptr(), + output_batch_dim, + output_feature_dim, + output_spatial_dims.len(), + output_spatial_dims.as_ptr(), + kernel_input_feature_dim, + kernel_output_feature_dim, + kernel_spatial_dims.len(), + kernel_spatial_dims.as_ptr(), + feature_group_count, + batch_group_count, + ) + }; + self.wrap(op) + } + + pub fn batch_norm_inference( + &self, + scale: &XlaOp, + offset: &XlaOp, + mean: &XlaOp, + variance: &XlaOp, + epsilon: f32, + feature_index: i64, + ) -> Result { + let op = unsafe { + c_lib::op_batch_norm_inference( + self.op, + scale.op, + 
offset.op, + mean.op, + variance.op, + epsilon, + feature_index, + ) + }; + self.wrap(op) + } + + pub fn outfeed(&self, ty: PrimitiveType, dims: &[i64], config: &str) { + let config = std::ffi::CString::new(config).unwrap(); + unsafe { + c_lib::outfeed(self.op, ty as i32, dims.len() as i32, dims.as_ptr(), config.as_ptr()) + } + } + + /// The kind of elements that are computed by this operand. + pub fn primitive_type(&self) -> Result { + self.builder.get_primitive_type(self) + } + + /// The kind of elements that are computed by this operand, shortcut for `primitive_type`. + pub fn ty(&self) -> Result { + self.primitive_type() + } + + /// The number of dimensions for this node. + pub fn rank(&self) -> Result { + self.builder.get_dimensions_size(self) + } + + pub fn shape(&self) -> Result { + self.builder.get_shape(self) + } + + pub fn array_shape(&self) -> Result { + ArrayShape::try_from(&self.builder.get_shape(self)?) + } + + pub fn dims(&self) -> Result> { + self.builder.get_dims(self) + } + + extract_dims!(dim1, 1, |d: Vec| d[0], usize); + extract_dims!(dim2, 2, |d: Vec| (d[0], d[1]), (usize, usize)); + extract_dims!(dim3, 3, |d: Vec| (d[0], d[1], d[2]), (usize, usize, usize)); + extract_dims!(dim4, 4, |d: Vec| (d[0], d[1], d[2], d[3]), (usize, usize, usize, usize)); + extract_dims!( + dim5, + 5, + |d: Vec| (d[0], d[1], d[2], d[3], d[4]), + (usize, usize, usize, usize, usize) + ); + + /// General dot multiplication between two nodes, specifying the dimensions that get contracted + /// as well as the batch dimensions. + pub fn dot_general( + &self, + rhs: &XlaOp, + lhs_contracting_dims: &[i64], + rhs_contracting_dims: &[i64], + lhs_batch_dims: &[i64], + rhs_batch_dims: &[i64], + ) -> Result { + let op = unsafe { + c_lib::op_dot_general( + self.op, + rhs.op, + lhs_contracting_dims.as_ptr(), + lhs_contracting_dims.len(), + rhs_contracting_dims.as_ptr(), + rhs_contracting_dims.len(), + lhs_batch_dims.as_ptr(), + lhs_batch_dims.len(), + rhs_batch_dims.as_ptr(), + rhs_batch_dims.len(), + ) + }; + self.wrap(op) + } + + pub fn gather( + &self, + start_indices: &XlaOp, + offset_dims: &[i64], + collapsed_slice_dims: &[i64], + start_index_map: &[i64], + set_index_vector_dim: Option, + slice_sizes: &[i64], + ) -> Result { + let set_index_vector_dim_ptr = + set_index_vector_dim.as_ref().map(|p| p as *const _).unwrap_or(std::ptr::null()); + let op = unsafe { + c_lib::op_gather( + self.op, + start_indices.op, + offset_dims.as_ptr(), + offset_dims.len(), + collapsed_slice_dims.as_ptr(), + collapsed_slice_dims.len(), + start_index_map.as_ptr(), + start_index_map.len(), + set_index_vector_dim_ptr, + slice_sizes.as_ptr(), + slice_sizes.len(), + ) + }; + self.wrap(op) + } + + pub fn scatter( + operands: &[XlaOp], + scatter_indices: &XlaOp, + updates: &[XlaOp], + update_computation: &XlaComputation, + update_window_dims: &[i64], + inserted_window_dims: &[i64], + scatter_dims_to_operand_dims: &[i64], + index_vector_dim: i64 + ) -> Result { + let operands: Vec<_> = operands.iter().map(|a| a.op).collect(); + let updates: Vec<_> = updates.iter().map(|a| a.op).collect(); + let op = unsafe { + c_lib::op_scatter( + operands.len(), + operands.as_ptr(), + scatter_indices.op, + updates.len(), + updates.as_ptr(), + update_computation.0, + update_window_dims.len(), + update_window_dims.as_ptr(), + inserted_window_dims.len(), + inserted_window_dims.as_ptr(), + scatter_dims_to_operand_dims.len(), + scatter_dims_to_operand_dims.as_ptr(), + index_vector_dim + ) + }; + scatter_indices.wrap(op) + } + + pub fn take(&self, 
indices: &XlaOp, axis: i64) -> Result { + let axis = self.normalize_index(axis)?; + let shape = self.array_shape()?; + let indices_shape = indices.array_shape()?; + let index_dims = indices_shape.dims(); + let dims = shape.dims(); + let offset_dims: Vec<_> = (0..((dims.len() + index_dims.len()) as i64 - 1)) + .filter(|x| *x < axis || *x >= axis + index_dims.len() as i64) + .collect(); + let mut slice_sizes: Vec<_> = dims.to_vec(); + slice_sizes[axis as usize] = 1; + let mut index_dims_plus_1 = index_dims.to_vec(); + index_dims_plus_1.push(1); + let indices = indices.reshape(&index_dims_plus_1)?; + // Same as in Jax: always use the last dimension for index_vector_dim. + let index_vector_dim = Some(index_dims.len() as i64); + self.gather(&indices, &offset_dims, &[axis], &[axis], index_vector_dim, &slice_sizes) + } + + fn maybe_keep_dims(&self, res: XlaOp, dims_to_keep: &[i64], keep_dims: bool) -> Result { + if keep_dims && !dims_to_keep.is_empty() { + let shape = self.array_shape()?; + let mut dims = shape.dims().to_vec(); + for d in dims_to_keep.iter() { + dims[*d as usize] = 1; + } + res.reshape(&dims) + } else { + Ok(res) + } + } + + /// A node that computes the sum across the specified dimensions, e.g. if all the dimensions + /// are passed as an argument the result is a scalar with the sum of all the elements in the + /// original node. + pub fn reduce_sum(&self, dims: &[i64], keep_dims: bool) -> Result { + let builder = XlaBuilder::new("Sum"); + let ty = self.primitive_type()?.element_type()?; + let x = builder.parameter(0, ty, &[], "x")?; + let y = builder.parameter(1, ty, &[], "y")?; + let sum = x.add_(&y)?.build()?; + let init_value = self.builder.zero(ty)?; + self.reduce(init_value, &sum, dims, keep_dims) + } + + /// A node that computes the average value across the specified dimensions. + pub fn reduce_mean(&self, dims: &[i64], keep_dims: bool) -> Result { + let b = &self.builder(); + let ty = self.primitive_type()?; + let mut scale = b.one(crate::ElementType::S32)?; + for d in dims.iter() { + scale = (scale * self.dimensions_size(*d)?)?; + } + let sum = self.reduce_sum(dims, keep_dims)?; + sum / scale.astype(ty)? + } + + /// A node that computes the maximum value across the specified dimensions. + pub fn reduce_max(&self, dims: &[i64], keep_dims: bool) -> Result { + let builder = XlaBuilder::new("Max"); + let ty = self.primitive_type()?.element_type()?; + let x = builder.parameter(0, ty, &[], "x")?; + let y = builder.parameter(1, ty, &[], "y")?; + let sum = x.max(&y)?.build()?; + let init_value = self.builder.min_value(ty)?; + self.reduce(init_value, &sum, dims, keep_dims) + } + + /// A node that computes the minimum value across the specified dimensions. + pub fn reduce_min(&self, dims: &[i64], keep_dims: bool) -> Result { + let builder = XlaBuilder::new("Min"); + let ty = self.primitive_type()?.element_type()?; + let x = builder.parameter(0, ty, &[], "x")?; + let y = builder.parameter(1, ty, &[], "y")?; + let sum = x.min(&y)?.build()?; + let init_value = self.builder.max_value(ty)?; + self.reduce(init_value, &sum, dims, keep_dims) + } + + pub fn softmax(&self, dim: i64) -> Result { + let max = self.reduce_max(&[dim], true)?; + let unnormalized = (self - max)?.exp()?; + let sum = unnormalized.reduce_sum(&[dim], true)?; + unnormalized / sum + } + + /// Layer normalization, this normalizes values on the target dimension to be of zero mean and + /// standard deviation one, and then scales the result by `scale` and adds `bias`. 
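+    ///
+    /// Sketch of normalizing the last dimension; the `x`, `scale`, and `bias`
+    /// nodes are assumed here for illustration, with `scale` and `bias` shaped
+    /// so they combine element-wise with `x`:
+    ///
+    /// ```ignore
+    /// let normed = x.layer_norm(&[-1], &scale, &bias, 1e-5)?;
+    /// ```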
+    pub fn layer_norm(&self, dims: &[i64], scale: &XlaOp, bias: &XlaOp, eps: f64) -> Result<Self> {
+        let ty = self.primitive_type().unwrap_or(PrimitiveType::F32);
+        let eps = self.builder().c0(eps)?.astype(ty)?;
+        let mean = self.reduce_mean(dims, true)?;
+        let mean2 = (self * self)?.reduce_mean(dims, true)?;
+        let var = (mean2 - (&mean * &mean)?)?;
+        let mul = (var + eps)?.rsqrt()?;
+        bias + ((self - mean)? * mul)? * scale
+    }
+
+    /// Matrix multiplication, this is a specialized version of `dot_general` to be used for
+    /// matrix-matrix or matrix-vector multiplications.
+    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
+        // Similar to the jax implementation but without the squeezing.
+        // https://github.com/google/jax/blob/849e47f79ac64ccba1a762804217c00a9905025b/jax/_src/numpy/lax_numpy.py#L3028
+        let lhs_shape = self.array_shape()?;
+        let rhs_shape = rhs.array_shape()?;
+        let lhs_dims = lhs_shape.dims();
+        let rhs_dims = rhs_shape.dims();
+        let lhs_ndims = lhs_dims.len();
+        let rhs_ndims = rhs_dims.len();
+        if lhs_ndims < 1 || rhs_ndims < 1 {
+            Err(Error::MatMulIncorrectDims {
+                lhs_dims: lhs_dims.to_vec(),
+                rhs_dims: rhs_dims.to_vec(),
+                msg: "empty dimension",
+            })?
+        }
+
+        let rhs_is_mat = rhs_ndims > 1;
+        let lhs_batch_ndims = lhs_ndims.saturating_sub(2);
+        let rhs_batch_ndims = rhs_ndims.saturating_sub(2);
+        let max_ndims = usize::max(lhs_batch_ndims, rhs_batch_ndims);
+        let mut lhs_batch_dims = vec![];
+        let mut rhs_batch_dims = vec![];
+        for idx in 0..max_ndims {
+            let lhs_idx = (idx + lhs_batch_ndims) as i64 - max_ndims as i64;
+            let rhs_idx = (idx + rhs_batch_ndims) as i64 - max_ndims as i64;
+            // Only one of lhs_idx and rhs_idx can be negative.
+            if lhs_idx < 0 && rhs_idx < 0 {
+                panic!("internal error: negative dim idxs {lhs_dims:?} {rhs_dims:?}")
+            } else if lhs_idx < 0 && rhs_idx >= 0 {
+                rhs_batch_dims.push(rhs_idx)
+            } else if lhs_idx >= 0 && rhs_idx < 0 {
+                lhs_batch_dims.push(lhs_idx)
+            } else if lhs_dims[lhs_idx as usize] == rhs_dims[rhs_idx as usize] {
+                lhs_batch_dims.push(lhs_idx);
+                rhs_batch_dims.push(rhs_idx);
+            } else {
+                Err(Error::MatMulIncorrectDims {
+                    lhs_dims: lhs_dims.to_vec(),
+                    rhs_dims: rhs_dims.to_vec(),
+                    msg: "incompatible batch dimensions",
+                })?
+            }
+        }
+        self.dot_general(
+            rhs,
+            &[lhs_ndims as i64 - 1],
+            &[rhs_ndims as i64 - 1 - i64::from(rhs_is_mat)],
+            &lhs_batch_dims,
+            &rhs_batch_dims,
+        )
+    }
+
+    /// Generate a computation whose root value is this node.
+    pub fn build(&self) -> Result<XlaComputation> {
+        self.builder.build(self)
+    }
+}
+
+impl Drop for XlaOp {
+    fn drop(&mut self) {
+        unsafe { c_lib::xla_op_free(self.op) }
+    }
+}
+
+macro_rules!
bin_trait { + ($trait:ident, $fn1:ident, $fn2:ident) => { + impl> std::ops::$trait for XlaOp { + type Output = Result; + + fn $fn1(self, rhs: B) -> Self::Output { + (&self).$fn1(rhs) + } + } + + impl> std::ops::$trait for &XlaOp { + type Output = Result; + + fn $fn1(self, rhs: B) -> Self::Output { + self.$fn2(rhs.borrow()) + } + } + + impl> std::ops::$trait> for XlaOp { + type Output = Result; + + fn $fn1(self, rhs: Result) -> Self::Output { + (&self).$fn1(rhs) + } + } + + impl> std::ops::$trait> for &XlaOp { + type Output = Result; + + fn $fn1(self, rhs: Result) -> Self::Output { + self.$fn2(rhs?.borrow()) + } + } + }; +} + +bin_trait!(Add, add, add_); +bin_trait!(Sub, sub, sub_); +bin_trait!(Mul, mul, mul_); +bin_trait!(Div, div, div_); diff --git a/ivy/engines/XLA/rust_api/xla_rs/xla_rs.cc b/ivy/engines/XLA/rust_api/xla_rs/xla_rs.cc new file mode 100644 index 0000000000000..8eabce7517d9f --- /dev/null +++ b/ivy/engines/XLA/rust_api/xla_rs/xla_rs.cc @@ -0,0 +1,1402 @@ +#include "xla_rs.h" + +#define ASSIGN_OR_RETURN_STATUS(lhs, rexpr) \ + ASSIGN_OR_RETURN_STATUS_IMPL( \ + TF_STATUS_MACROS_CONCAT_NAME(_statusor, __COUNTER__), lhs, rexpr) + +#define ASSIGN_OR_RETURN_STATUS_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + if (!statusor.ok()) \ + return new Status(statusor.status()); \ + auto lhs = std::move(statusor.value()); + +#define MAYBE_RETURN_STATUS(rexpr) \ + MAYBE_RETURN_STATUS_IMPL(TF_STATUS_MACROS_CONCAT_NAME(_status, __COUNTER__), \ + rexpr) + +#define MAYBE_RETURN_STATUS_IMPL(statusor, rexpr) \ + auto statusor = (rexpr); \ + if (!statusor.ok()) \ + return new Status(statusor); + +#define BEGIN_PROTECT_OP try { +#define END_PROTECT_OP_B(builder) \ + } \ + catch (std::exception e) { \ + return new XlaOp(builder->ReportError(tsl::errors::Internal(e.what()))); \ + } +#define END_PROTECT_OP(arg) \ + } \ + catch (std::exception e) { \ + return new XlaOp( \ + arg->builder()->ReportError(tsl::errors::Internal(e.what()))); \ + } + +status pjrt_cpu_client_create(pjrt_client *output) { + ASSIGN_OR_RETURN_STATUS(client, xla::GetTfrtCpuClient(false)); + *output = new std::shared_ptr(std::move(client)); + return nullptr; +} + +status pjrt_gpu_client_create(pjrt_client *output, double memory_fraction, + bool preallocate) { + xla::GpuAllocatorConfig allocator = {.memory_fraction = memory_fraction, + .preallocate = preallocate}; + ASSIGN_OR_RETURN_STATUS( + client, xla::GetStreamExecutorGpuClient(false, allocator, nullptr, 0)); + *output = new std::shared_ptr(std::move(client)); + return nullptr; +} + +status pjrt_tpu_client_create(pjrt_client *output, + int max_inflight_computations) { + ASSIGN_OR_RETURN_STATUS(client, xla::GetTpuClient(max_inflight_computations)); + *output = new std::shared_ptr(std::move(client)); + return nullptr; +} + +int pjrt_client_device_count(pjrt_client c) { return (*c)->device_count(); } + +int pjrt_client_addressable_device_count(pjrt_client c) { + return (*c)->addressable_device_count(); +} + +void pjrt_client_devices(pjrt_client c, pjrt_device *outputs) { + size_t index = 0; + for (auto device : (*c)->devices()) { + outputs[index++] = device; + } +} + +void pjrt_client_addressable_devices(pjrt_client c, pjrt_device *outputs) { + size_t index = 0; + for (auto device : (*c)->addressable_devices()) { + outputs[index++] = device; + } +} + +char *pjrt_client_platform_name(pjrt_client c) { + // TODO: Avoid the double allocation when converting string views. 
+ return strdup(std::string((*c)->platform_name()).c_str()); +} + +char *pjrt_client_platform_version(pjrt_client c) { + return strdup(std::string((*c)->platform_version()).c_str()); +} + +void pjrt_client_free(pjrt_client b) { delete b; } + +void pjrt_loaded_executable_free(pjrt_loaded_executable b) { delete b; } + +status pjrt_buffer_from_host_buffer(const pjrt_client client, + const pjrt_device device, const void *d, + int pr_type, int dsize, const int64_t *ds, + pjrt_buffer *output) { + PjRtDevice *device_ = device == nullptr ? (*client)->devices()[0] : device; + ASSIGN_OR_RETURN_STATUS( + buffer, + (*client)->BufferFromHostBuffer( + d, (PrimitiveType)pr_type, absl::Span(ds, dsize), {}, + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, []() {}, + device_)); + *output = buffer.release(); + return nullptr; +} + +status pjrt_buffer_from_host_literal(const pjrt_client client, + const pjrt_device device, const literal l, + pjrt_buffer *output) { + PjRtDevice *d = device == nullptr ? (*client)->devices()[0] : device; + ASSIGN_OR_RETURN_STATUS(buffer, (*client)->BufferFromHostLiteral(*l, d)); + *output = buffer.release(); + return nullptr; +} + +status pjrt_buffer_to_literal_sync(pjrt_buffer b, literal *output) { + ASSIGN_OR_RETURN_STATUS(literal, b->ToLiteralSync()); + *output = new Literal(); + **output = std::move(*literal); + return nullptr; +} + +shape pjrt_buffer_on_device_shape(pjrt_buffer b) { + return new Shape(b->on_device_shape()); +} + +status pjrt_buffer_copy_to_device(pjrt_buffer b, pjrt_device device, + pjrt_buffer *output) { + ASSIGN_OR_RETURN_STATUS(copied_b, b->CopyToDevice(device)); + *output = copied_b.release(); + return nullptr; +} + +status pjrt_buffer_copy_raw_to_host_sync(pjrt_buffer b, void *dst, + size_t offset, size_t transfer_size) { + MAYBE_RETURN_STATUS(b->CopyRawToHost(dst, offset, transfer_size).Await()); + return nullptr; +} + +void pjrt_buffer_free(pjrt_buffer b) { delete b; } + +int pjrt_device_id(pjrt_device d) { return d->id(); } + +int pjrt_device_process_index(pjrt_device d) { return d->process_index(); } + +int pjrt_device_local_hardware_id(pjrt_device d) { + return d->local_hardware_id(); +} + +status pjrt_device_transfer_to_infeed(pjrt_device d, const literal l) { + MAYBE_RETURN_STATUS(d->TransferToInfeed(*l)); + return nullptr; +} + +status pjrt_device_transfer_from_outfeed(pjrt_device d, literal l) { + MAYBE_RETURN_STATUS(d->TransferFromOutfeed(l)); + return nullptr; +} + +char *pjrt_device_kind(pjrt_device d) { + return strdup(std::string(d->device_kind()).c_str()); +} + +char *pjrt_device_debug_string(pjrt_device d) { + return strdup(std::string(d->DebugString()).c_str()); +} + +char *pjrt_device_to_string(pjrt_device d) { + return strdup(std::string(d->ToString()).c_str()); +} + +xla_builder xla_builder_create(const char *name) { + return new XlaBuilder(name); +} + +void xla_builder_free(xla_builder b) { delete b; } + +xla_op constant_literal(const xla_builder b, const literal l) { + BEGIN_PROTECT_OP + return new XlaOp(ConstantLiteral(b, *l)); + END_PROTECT_OP_B(b) +} + +#define CONST_OP_R01(native_type, primitive_type) \ + xla_op constant_r0_##native_type(const xla_builder b, native_type f) { \ + return new XlaOp(ConstantR0(b, f)); \ + } \ + xla_op constant_r1c_##native_type(const xla_builder b, native_type f, \ + size_t len) { \ + return new XlaOp(ConstantR1(b, len, f)); \ + } \ + xla_op constant_r1_##native_type(const xla_builder b, const native_type *f, \ + size_t len) { \ + return new XlaOp( \ + ConstantR1(b, absl::Span(f, 
len))); \ + } \ + literal create_r0_##native_type(native_type f) { \ + return new Literal(LiteralUtil::CreateR0(f)); \ + } \ + literal create_r1_##native_type(const native_type *f, size_t nel) { \ + return new Literal(LiteralUtil::CreateR1( \ + absl::Span(f, nel))); \ + } \ + native_type literal_get_first_element_##native_type(const literal l) { \ + return l->GetFirstElement(); \ + } + +FOR_EACH_NATIVE_TYPE(CONST_OP_R01) +#undef CONST_OP_R01 + +Shape make_shape_internal(int pr_type, int dsize, const int64_t *ds) { + bool has_negative_dim = false; + for (int i = 0; i < dsize; ++i) { + if (ds[i] < 0) { + has_negative_dim = true; + break; + } + } + Shape shape; + if (has_negative_dim) { + std::vector dynamic; + std::vector bounds; + for (int i = 0; i < dsize; ++i) { + if (ds[i] < 0) { + bounds.push_back(-ds[i]); + dynamic.push_back(true); + } else { + bounds.push_back(ds[i]); + dynamic.push_back(false); + } + } + shape = ShapeUtil::MakeShape( + (PrimitiveType)pr_type, + absl::Span(bounds.data(), bounds.size()), dynamic); + } else { + shape = ShapeUtil::MakeShape((PrimitiveType)pr_type, + absl::Span(ds, dsize)); + } + return shape; +} + +shape make_shape_array(int pr_type, size_t dsize, const int64_t *ds) { + return new Shape(make_shape_internal(pr_type, dsize, ds)); +} + +shape make_shape_tuple(size_t dsize, const shape *ds) { + std::vector elts; + for (size_t i = 0; i < dsize; ++i) { + elts.push_back(*ds[i]); + } + return new Shape(ShapeUtil::MakeTupleShape(elts)); +} + +xla_op parameter(const xla_builder b, int64_t id, int pr_type, int dsize, + const int64_t *ds, const char *name) { + BEGIN_PROTECT_OP + Shape shape = make_shape_internal(pr_type, dsize, ds); + return new XlaOp(Parameter(b, id, shape, std::string(name))); + END_PROTECT_OP_B(b) +} + +xla_op parameter_s(const xla_builder b, int64_t id, const shape s, + const char *name) { + BEGIN_PROTECT_OP + return new XlaOp(Parameter(b, id, *s, std::string(name))); + END_PROTECT_OP_B(b) +} + +xla_op infeed(const xla_builder b, int pr_type, int dsize, const int64_t *ds, + const char *config) { + BEGIN_PROTECT_OP + Shape shape = make_shape_internal(pr_type, dsize, ds); + return new XlaOp(Infeed(b, shape, std::string(config))); + END_PROTECT_OP_B(b) +} + +void outfeed(const xla_op op, int pr_type, int dsize, const int64_t *ds, + const char *outfeed_config) { + Shape shape = make_shape_internal(pr_type, dsize, ds); + Outfeed(*op, shape, std::string(outfeed_config)); +} + +xla_op op_add(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Add(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_sub(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Sub(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_mul(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Mul(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_div(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Div(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_rem(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Rem(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_max(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Max(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_min(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Min(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_and(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(And(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op 
op_or(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Or(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_xor(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Xor(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_atan2(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Atan2(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_pow(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Pow(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_dot(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Dot(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_dot_general(const xla_op lhs, const xla_op rhs, const int64_t *lhs_c, + size_t nlhs_c, const int64_t *rhs_c, size_t nrhs_c, + const int64_t *lhs_b, size_t nlhs_b, const int64_t *rhs_b, + size_t nrhs_b) { + BEGIN_PROTECT_OP + DotDimensionNumbers dnums; + for (size_t i = 0; i < nlhs_c; ++i) + dnums.add_lhs_contracting_dimensions(lhs_c[i]); + for (size_t i = 0; i < nrhs_c; ++i) + dnums.add_rhs_contracting_dimensions(rhs_c[i]); + for (size_t i = 0; i < nlhs_b; ++i) + dnums.add_lhs_batch_dimensions(lhs_b[i]); + for (size_t i = 0; i < nrhs_b; ++i) + dnums.add_rhs_batch_dimensions(rhs_b[i]); + return new XlaOp(DotGeneral(*lhs, *rhs, dnums)); + END_PROTECT_OP(lhs) +} + +xla_op op_eq(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Eq(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_ne(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Ne(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_ge(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Ge(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_gt(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Gt(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_le(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Le(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_lt(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(Lt(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_shift_left(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(ShiftLeft(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_shift_right_arith(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(ShiftRightArithmetic(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_shift_right_logic(const xla_op lhs, const xla_op rhs) { + BEGIN_PROTECT_OP + return new XlaOp(ShiftRightLogical(*lhs, *rhs)); + END_PROTECT_OP(lhs) +} + +xla_op op_population_count(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(PopulationCount(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_not(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Not(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_abs(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Abs(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_exp(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Exp(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_expm1(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Expm1(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_floor(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Floor(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_ceil(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Ceil(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_round(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Round(*arg)); + 
END_PROTECT_OP(arg) +} + +xla_op op_round_nearest_even(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(RoundNearestEven(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_log(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Log(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_log1p(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Log1p(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_logistic(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Logistic(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_sign(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Sign(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_clz(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Clz(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_cos(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Cos(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_sin(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Sin(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_tanh(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Tanh(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_real(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Real(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_imag(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Imag(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_conj(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Conj(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_square(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Square(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_sqrt(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Sqrt(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_rsqrt(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Rsqrt(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_cbrt(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Cbrt(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_is_finite(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(IsFinite(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_neg(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Neg(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_lower_triangle(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(LowerTriangle(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_upper_triangle(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(UpperTriangle(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_erf(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Erf(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_einsum1(const xla_op arg, const char *config) { + BEGIN_PROTECT_OP + return new XlaOp(Einsum(*arg, config)); + END_PROTECT_OP(arg) +} + +xla_op op_einsum2(const xla_op arg1, const xla_op arg2, const char *config) { + BEGIN_PROTECT_OP + return new XlaOp(Einsum(*arg1, *arg2, config)); + END_PROTECT_OP(arg1) +} + +xla_op op_copy(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(Copy(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_clone(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(*arg); + END_PROTECT_OP(arg) +} + +xla_op op_zeros_like(const xla_op arg) { + BEGIN_PROTECT_OP + return new XlaOp(ZerosLike(*arg)); + END_PROTECT_OP(arg) +} + +xla_op op_zero_like(const xla_op arg) { + BEGIN_PROTECT_OP + const Shape *shape = arg->builder()->GetShapePtr(*arg).value(); + return new XlaOp(Zero(arg->builder(), shape->element_type())); + END_PROTECT_OP(arg) +} + +xla_op op_reshape(const xla_op arg, size_t dsize, const int64_t *ds) { + BEGIN_PROTECT_OP + return new XlaOp(Reshape(*arg, absl::Span<const int64_t>(ds, dsize))); + END_PROTECT_OP(arg) +}
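+// The BEGIN_PROTECT_OP / END_PROTECT_OP(_B) pair wrapped around every op in
+// this file is defined in an earlier hunk that is not shown here. As a rough
+// sketch of the pattern -- catching C++ exceptions at the FFI boundary and
+// reporting them through the builder rather than unwinding into the caller --
+// the expansion plausibly looks like this (hypothetical, not the verbatim
+// macros):
+//
+//   #define BEGIN_PROTECT_OP try {
+//   #define END_PROTECT_OP(arg) \
+//     } catch (std::exception &e) { \
+//       return new XlaOp((arg)->builder()->ReportError( \
+//           tsl::errors::Internal(e.what()))); \
+//     }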
+xla_op op_dynamic_reshape(const xla_op arg, size_t n_ops, const xla_op *ds, + size_t n_new_size_bounds, const int64_t *new_size_bounds, + const bool *dims_are_dynamic) { + BEGIN_PROTECT_OP + std::vector<XlaOp> vec_dim_sizes; + for (size_t i = 0; i < n_ops; ++i) { + vec_dim_sizes.push_back(*ds[i]); + } + + std::vector<bool> vec_dims_are_dynamic; + for (size_t i = 0; i < n_ops; ++i) { + vec_dims_are_dynamic.push_back(dims_are_dynamic[i]); + } + + return new XlaOp( + DynamicReshape(*arg, + absl::Span<const XlaOp>(vec_dim_sizes), + absl::Span<const int64_t>(new_size_bounds, n_new_size_bounds), + vec_dims_are_dynamic)); + END_PROTECT_OP(arg) +} + +xla_op op_broadcast(const xla_op arg, size_t dsize, const int64_t *ds) { + BEGIN_PROTECT_OP + return new XlaOp(Broadcast(*arg, absl::Span<const int64_t>(ds, dsize))); + END_PROTECT_OP(arg) +} + +xla_op op_broadcast_in_dim(const xla_op arg, size_t out_dsize, + const int64_t *out_ds, size_t broadcast_dsize, + const int64_t *broadcast_ds) { + BEGIN_PROTECT_OP + return new XlaOp( + BroadcastInDim(*arg, absl::Span<const int64_t>(out_ds, out_dsize), + absl::Span<const int64_t>(broadcast_ds, broadcast_dsize))); + END_PROTECT_OP(arg) +} + +xla_op op_collapse(const xla_op arg, size_t dsize, const int64_t *ds) { + BEGIN_PROTECT_OP + return new XlaOp(Collapse(*arg, absl::Span<const int64_t>(ds, dsize))); + END_PROTECT_OP(arg) +} + +xla_op op_transpose(const xla_op arg, size_t dsize, const int64_t *ds) { + BEGIN_PROTECT_OP + return new XlaOp(Transpose(*arg, absl::Span<const int64_t>(ds, dsize))); + END_PROTECT_OP(arg) +} + +xla_op op_clamp(const xla_op arg1, const xla_op arg2, const xla_op arg3) { + BEGIN_PROTECT_OP + return new XlaOp(Clamp(*arg1, *arg2, *arg3)); + END_PROTECT_OP(arg1) +} + +xla_op op_select(const xla_op arg1, const xla_op arg2, const xla_op arg3) { + BEGIN_PROTECT_OP + return new XlaOp(Select(*arg1, *arg2, *arg3)); + END_PROTECT_OP(arg1) +} + +xla_op op_call(const xla_builder b, const xla_computation f, size_t n_ops, const xla_op *args) { + BEGIN_PROTECT_OP + std::vector<XlaOp> args_; + for (size_t i = 0; i < n_ops; ++i) { + args_.push_back(*args[i]); + } + return new XlaOp(Call(b, *f, absl::Span<const XlaOp>(args_))); + END_PROTECT_OP_B(b) +} + +xla_op op_map(const xla_builder b, size_t n_ops, const xla_op *ops, const xla_computation f, + size_t n_dims, const int64_t *dims, size_t n_static_ops, const xla_op *static_ops) { + BEGIN_PROTECT_OP + std::vector<XlaOp> ops_; + for (size_t i = 0; i < n_ops; ++i) { + ops_.push_back(*ops[i]); + } + std::vector<XlaOp> static_ops_; + for (size_t i = 0; i < n_static_ops; ++i) { + static_ops_.push_back(*static_ops[i]); + } + return new XlaOp(Map(b, absl::Span<const XlaOp>(ops_), *f, + absl::Span<const int64_t>(dims, n_dims), + absl::Span<const XlaOp>(static_ops_))); + END_PROTECT_OP_B(b) +} + +xla_op op_rng_uniform(const xla_op arg1, const xla_op arg2, int pr_type, + int dsize, const int64_t *ds) { + BEGIN_PROTECT_OP + auto shape = ShapeUtil::MakeShape((PrimitiveType)pr_type, + absl::Span<const int64_t>(ds, dsize)); + return new XlaOp(RngUniform(*arg1, *arg2, shape)); + END_PROTECT_OP(arg1) +} + +xla_op op_rng_normal(const xla_op arg1, const xla_op arg2, int pr_type, + int dsize, const int64_t *ds) { + BEGIN_PROTECT_OP + auto shape = ShapeUtil::MakeShape((PrimitiveType)pr_type, + absl::Span<const int64_t>(ds, dsize)); + return new XlaOp(RngNormal(*arg1, *arg2, shape)); + END_PROTECT_OP(arg1) +}
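+// Caller-side sketch for the two RNG wrappers above (illustrative only; the
+// builder b and the F32 == 11 PrimitiveType encoding are assumptions, not
+// part of this file): the distribution parameters are ordinary ops, while the
+// result shape travels as a primitive-type code plus a dimension array.
+//
+//   int64_t dims[2] = {2, 3};
+//   xla_op lo = constant_r0_float(b, 0.0f);
+//   xla_op hi = constant_r0_float(b, 1.0f);
+//   xla_op u = op_rng_uniform(lo, hi, /*F32=*/11, 2, dims);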
+xla_op op_pad(const xla_op arg, + const xla_op padding_value, + size_t n_dims, + const int64_t *edge_low, + const int64_t *edge_high, + const int64_t *interior) { + BEGIN_PROTECT_OP + PaddingConfig config; + for (size_t i = 0; i < n_dims; ++i) { + auto dim = config.add_dimensions(); + dim->set_edge_padding_low(edge_low[i]); + dim->set_edge_padding_high(edge_high[i]); + dim->set_interior_padding(interior[i]); + } + return new XlaOp(Pad(*arg, *padding_value, config)); + END_PROTECT_OP(arg) +} + +xla_op op_pad_in_dim(const xla_op arg, const xla_op padding_value, + int64_t dinmo, int64_t pad_lo, int64_t pad_hi) { + BEGIN_PROTECT_OP + return new XlaOp(PadInDim(*arg, *padding_value, dinmo, pad_lo, pad_hi)); + END_PROTECT_OP(arg) +} + +xla_op op_slice(const xla_op arg, size_t start_dsize, const int64_t *start_ds, + size_t limit_dsize, const int64_t *limit_ds, + size_t stride_dsize, const int64_t *stride_ds) { + BEGIN_PROTECT_OP + return new XlaOp(Slice(*arg, absl::Span<const int64_t>(start_ds, start_dsize), + absl::Span<const int64_t>(limit_ds, limit_dsize), + absl::Span<const int64_t>(stride_ds, stride_dsize))); + END_PROTECT_OP(arg) +} + +xla_op op_slice_in_dim(const xla_op arg, int64_t start, int64_t stop, + int64_t stride, int64_t dim) { + BEGIN_PROTECT_OP + return new XlaOp(SliceInDim(*arg, start, stop, stride, dim)); + END_PROTECT_OP(arg) +} + +xla_op op_dynamic_slice(const xla_op arg, size_t n_ops, + const xla_op *start_indices, + size_t slice_dsize, const int64_t *slice_ds) { + BEGIN_PROTECT_OP + std::vector<XlaOp> indices; + for (size_t i = 0; i < n_ops; ++i) { + indices.push_back(*start_indices[i]); + } + return new XlaOp( + DynamicSlice(*arg, absl::Span<const XlaOp>(indices), + absl::Span<const int64_t>(slice_ds, slice_dsize))); + END_PROTECT_OP(arg) +} + +xla_op op_dynamic_update_slice(const xla_op arg, const xla_op update, + size_t n_ops, const xla_op *start_indices) { + BEGIN_PROTECT_OP + std::vector<XlaOp> indices; + for (size_t i = 0; i < n_ops; ++i) { + indices.push_back(*start_indices[i]); + } + return new XlaOp(DynamicUpdateSlice(*arg, *update, absl::Span<const XlaOp>(indices))); + END_PROTECT_OP(arg) +} + +xla_op op_concat_in_dim(const xla_op arg, const xla_op *args, size_t nargs, + int64_t dim) { + BEGIN_PROTECT_OP + std::vector<XlaOp> args_ = {*arg}; + for (size_t i = 0; i < nargs; ++i) { + args_.push_back(*args[i]); + } + return new XlaOp( + ConcatInDim(arg->builder(), absl::Span<const XlaOp>(args_), dim)); + END_PROTECT_OP(arg) +} + +xla_op op_tuple(const xla_builder b, const xla_op *args, size_t nargs) { + BEGIN_PROTECT_OP + std::vector<XlaOp> args_; + for (size_t i = 0; i < nargs; ++i) { + args_.push_back(*args[i]); + } + return new XlaOp(Tuple(b, absl::Span<const XlaOp>(args_))); + END_PROTECT_OP_B(b) +} + +xla_op op_get_tuple_element(const xla_op arg, int64_t index) { + BEGIN_PROTECT_OP + return new XlaOp(GetTupleElement(*arg, index)); + END_PROTECT_OP(arg) +} + +xla_op op_gather(const xla_op arg1, const xla_op arg2, + const int64_t *offset_dims, size_t noffset_dims, + const int64_t *collapsed_slice_dims, + size_t ncollapsed_slice_dims, const int64_t *start_index_map, + size_t nstart_index_map, const int64_t *set_index_vector_dim, + const int64_t *slice_sizes, size_t nslice_sizes) { + BEGIN_PROTECT_OP + GatherDimensionNumbers dnums; + for (size_t i = 0; i < noffset_dims; ++i) { + dnums.add_offset_dims(offset_dims[i]); + } + for (size_t i = 0; i < ncollapsed_slice_dims; ++i) { + dnums.add_collapsed_slice_dims(collapsed_slice_dims[i]); + } + for (size_t i = 0; i < nstart_index_map; ++i) { + dnums.add_start_index_map(start_index_map[i]); + } + if (set_index_vector_dim) { + dnums.set_index_vector_dim(*set_index_vector_dim); + } + auto ss = absl::Span<const int64_t>(slice_sizes, nslice_sizes); + return new XlaOp(Gather(*arg1, *arg2, dnums, ss)); + END_PROTECT_OP(arg1) +}
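+// Dimension-number sketch for op_gather in the common "take rows" case
+// (illustrative values, not from this file): for an operand of shape [N, D]
+// and integer indices of shape [K], selecting whole rows corresponds to
+//   offset_dims           = {1}     // output dim 1 carries the row contents
+//   collapsed_slice_dims  = {0}     // the size-1 sliced row dim is squeezed
+//   start_index_map       = {0}     // each index addresses operand dim 0
+//   *set_index_vector_dim = 1      // indices treated as shape [K, 1]
+//   slice_sizes           = {1, D}
+// which yields an output of shape [K, D].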
+xla_op op_scatter(size_t n_ops, + const xla_op *operands, + const xla_op scatter_indices, + size_t n_updates, + const xla_op *updates, + const xla_computation comp, + size_t n_update_window_dims, + const int64_t *update_window_dims, + size_t n_inserted_window_dims, + const int64_t *inserted_window_dims, + size_t n_scatter_dims_to_operand_dims, + const int64_t *scatter_dims_to_operand_dims, + int64_t index_vector_dim + ) { + BEGIN_PROTECT_OP + std::vector<XlaOp> operands_; + for (size_t i = 0; i < n_ops; ++i) { + operands_.push_back(*operands[i]); + } + std::vector<XlaOp> updates_; + for (size_t i = 0; i < n_updates; ++i) { + updates_.push_back(*updates[i]); + } + ScatterDimensionNumbers dnums; + for (size_t i = 0; i < n_update_window_dims; ++i) { + dnums.add_update_window_dims(update_window_dims[i]); + } + for (size_t i = 0; i < n_inserted_window_dims; ++i) { + dnums.add_inserted_window_dims(inserted_window_dims[i]); + } + for (size_t i = 0; i < n_scatter_dims_to_operand_dims; ++i) { + dnums.add_scatter_dims_to_operand_dims(scatter_dims_to_operand_dims[i]); + } + dnums.set_index_vector_dim(index_vector_dim); + return new XlaOp(Scatter(operands_, *scatter_indices, updates_, *comp, dnums)); + END_PROTECT_OP(scatter_indices) +} + +xla_op op_convert_element_type(const xla_op arg, int pr_type) { + BEGIN_PROTECT_OP + return new XlaOp(ConvertElementType(*arg, (PrimitiveType)pr_type)); + END_PROTECT_OP(arg) +} + +xla_op op_dimensions_size(const xla_op arg, int64_t dim) { + BEGIN_PROTECT_OP + return new XlaOp(GetDimensionSize(*arg, dim)); + END_PROTECT_OP(arg) +} + +xla_op op_reduce(const xla_op arg, const xla_op init, + const xla_computation comp, const int64_t *dims, + size_t ndims) { + BEGIN_PROTECT_OP + return new XlaOp( + Reduce(*arg, *init, *comp, absl::Span<const int64_t>(dims, ndims))); + END_PROTECT_OP(arg) +} + +xla_op op_internal_error(const xla_builder b, const char *error) { + BEGIN_PROTECT_OP + return new XlaOp(b->ReportError(tsl::errors::Internal(error))); + END_PROTECT_OP_B(b) +} + +xla_op op_unknown_error(const xla_builder b, const char *error) { + BEGIN_PROTECT_OP + return new XlaOp(b->ReportError(tsl::errors::Unknown(error))); + END_PROTECT_OP_B(b) +} + +xla_op op_invalid_argument_error(const xla_builder b, const char *error) { + BEGIN_PROTECT_OP + return new XlaOp(b->ReportError(tsl::errors::InvalidArgument(error))); + END_PROTECT_OP_B(b) +} + +xla_op op_zero(const xla_builder b, int pr_type) { + BEGIN_PROTECT_OP + return new XlaOp(Zero(b, (PrimitiveType)pr_type)); + END_PROTECT_OP_B(b) +} + +xla_op op_one(const xla_builder b, int pr_type) { + BEGIN_PROTECT_OP + return new XlaOp(One(b, (PrimitiveType)pr_type)); + END_PROTECT_OP_B(b) +} + +xla_op op_min_value(const xla_builder b, int pr_type) { + BEGIN_PROTECT_OP + return new XlaOp(MinValue(b, (PrimitiveType)pr_type)); + END_PROTECT_OP_B(b) +} + +xla_op op_max_value(const xla_builder b, int pr_type) { + BEGIN_PROTECT_OP + return new XlaOp(MaxValue(b, (PrimitiveType)pr_type)); + END_PROTECT_OP_B(b) +} + +xla_op op_iota1(const xla_builder b, int pr_type, size_t sz) { + BEGIN_PROTECT_OP + return new XlaOp(Iota(b, (PrimitiveType)pr_type, (int64_t)sz)); + END_PROTECT_OP_B(b) +} + +xla_op op_iota(const xla_builder b, int pr_type, size_t dsize, + const int64_t *ds, int64_t increasing_dim) { + BEGIN_PROTECT_OP + auto shape = ShapeUtil::MakeShape((PrimitiveType)pr_type, + absl::Span<const int64_t>(ds, dsize)); + return new XlaOp(Iota(b, shape, increasing_dim)); + END_PROTECT_OP_B(b) +} + +xla_op op_while(const xla_computation cond, const xla_computation body, + const xla_op init) { + BEGIN_PROTECT_OP + return new XlaOp(While(*cond, *body, *init)); + END_PROTECT_OP(init) +}
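+// Caller-side sketch for op_while (hypothetical code; cb, body and init are
+// assumed names, not part of this file): the condition and body are separate
+// computations, each produced by build() on its own builder with one
+// parameter shaped like the loop state; op_tuple / op_get_tuple_element can
+// thread multiple values through as a single state op.
+//
+//   xla_builder cb = xla_builder_create("cond");
+//   xla_op i = parameter(cb, 0, /*S32=*/4, 0, NULL, "i");
+//   xla_computation cond;
+//   build(cb, op_lt(i, constant_r0_int32_t(cb, 10)), &cond);
+//   // build `body` the same way, then: op_while(cond, body, init);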
+xla_op op_conditional(const xla_op pred, const xla_op true_op, + const xla_computation true_comp, const xla_op false_op, + const xla_computation false_comp) { + BEGIN_PROTECT_OP + return new XlaOp( + Conditional(*pred, *true_op, *true_comp, *false_op, *false_comp)); + END_PROTECT_OP(pred) +} + +Padding ParsePadding(const char* padding_config) { + if (std::string(padding_config) == "same") { + return Padding::kSame; + } + if (std::string(padding_config) == "valid") { + return Padding::kValid; + } + throw std::runtime_error("Invalid padding config: " + std::string(padding_config)); +} + +xla_op op_conv(const xla_op lhs, + const xla_op rhs, + size_t n_strides, + const int64_t *window_strides, + const char *padding_config, + int64_t feature_group_count, + int64_t batch_group_count) { + BEGIN_PROTECT_OP + Padding padding = ParsePadding(padding_config); + return new XlaOp( + Conv(*lhs, *rhs, absl::Span<const int64_t>(window_strides, n_strides), padding, feature_group_count, batch_group_count)); + END_PROTECT_OP(lhs) +} + +xla_op op_conv_general_dilated(const xla_op lhs, + const xla_op rhs, + size_t n_strides, + const int64_t *window_strides, + size_t n_padding_pairs, + const int64_t *padding_values, + size_t n_lhs_dilations, + const int64_t *lhs_dilations, + size_t n_rhs_dilations, + const int64_t *rhs_dilations, + const int64_t *ibdim, + const int64_t *ifdim, + size_t n_isdims, + const int64_t *isdims, + const int64_t *obdim, + const int64_t *ofdim, + size_t n_osdims, + const int64_t *osdims, + const int64_t *kifdim, + const int64_t *kofdim, + size_t n_ksdims, + const int64_t *ksdims, + int64_t feature_group_count, + int64_t batch_group_count) { + BEGIN_PROTECT_OP + std::vector<std::pair<int64_t, int64_t>> padding_pairs; + for (size_t i = 0; i < 2 * n_padding_pairs; i += 2) { + padding_pairs.emplace_back(padding_values[i], padding_values[i + 1]); + } + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(*ibdim); + dnums.set_input_feature_dimension(*ifdim); + for (size_t i = 0; i < n_isdims; ++i) { + dnums.add_input_spatial_dimensions(isdims[i]); + } + dnums.set_output_batch_dimension(*obdim); + dnums.set_output_feature_dimension(*ofdim); + for (size_t i = 0; i < n_osdims; ++i) { + dnums.add_output_spatial_dimensions(osdims[i]); + } + dnums.set_kernel_input_feature_dimension(*kifdim); + dnums.set_kernel_output_feature_dimension(*kofdim); + for (size_t i = 0; i < n_ksdims; ++i) { + dnums.add_kernel_spatial_dimensions(ksdims[i]); + } + return new XlaOp( + ConvGeneralDilated(*lhs, + *rhs, + absl::Span<const int64_t>(window_strides, n_strides), + absl::Span<const std::pair<int64_t, int64_t>>(padding_pairs), + absl::Span<const int64_t>(lhs_dilations, n_lhs_dilations), + absl::Span<const int64_t>(rhs_dilations, n_rhs_dilations), + dnums, + feature_group_count, + batch_group_count)); + END_PROTECT_OP(lhs) +} + + +xla_op op_batch_norm_inference(const xla_op operand, + const xla_op scale, + const xla_op offset, + const xla_op mean, + const xla_op variance, + float epsilon, + int64_t feature_index) { + BEGIN_PROTECT_OP + return new XlaOp(BatchNormInference(*operand, *scale, *offset, *mean, *variance, epsilon, feature_index)); + END_PROTECT_OP(operand) +} + +xla_builder op_builder(const xla_op arg) { return arg->builder(); } + +int xla_op_valid(const xla_op op) { return op->valid(); } + +void xla_op_free(xla_op o) { delete o; } + +size_t shape_tuple_shapes_size(const shape s) { return s->tuple_shapes_size(); } + +shape shape_tuple_shapes(const shape s, int i) { + return (shape)&s->tuple_shapes(i); +} + +int shape_dimensions_size(const shape s) { return s->dimensions_size(); }
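+// Dimension-number sketch for op_conv_general_dilated with NCHW inputs and
+// OIHW kernels (illustrative values, not from this file):
+//   *ibdim = 0;  *ifdim = 1;  isdims = {2, 3};  // input: batch, feature, spatial
+//   *kifdim = 1; *kofdim = 0; ksdims = {2, 3};  // kernel: in-feat, out-feat, spatial
+//   *obdim = 0;  *ofdim = 1;  osdims = {2, 3};  // output: batch, feature, spatial
+// padding_values holds n_padding_pairs flattened (low, high) pairs, one pair
+// per spatial dimension.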
+int shape_element_type(const shape s) { return s->element_type(); } + +int64_t shape_dimensions(const shape s, int i) { return s->dimensions(i); } + +void shape_free(shape s) { delete s; } + +status get_shape(const xla_builder b, const xla_op o, shape *out_shape) { + ASSIGN_OR_RETURN_STATUS(shape, b->GetShape(*o)); + *out_shape = new Shape(shape); + return nullptr; +} + +status get_element_type(const xla_builder b, const xla_op o, + int *out_element_type) { + ASSIGN_OR_RETURN_STATUS(shape, b->GetShapePtr(*o)); + *out_element_type = shape->element_type(); + return nullptr; +} + +status get_dimensions_size(const xla_builder b, const xla_op o, int *out_rank) { + ASSIGN_OR_RETURN_STATUS(shape, b->GetShapePtr(*o)); + *out_rank = shape->dimensions_size(); + return nullptr; +} + +status get_dimensions(const xla_builder b, const xla_op o, size_t *out_dims) { + ASSIGN_OR_RETURN_STATUS(shape, b->GetShapePtr(*o)); + size_t dim_size = shape->dimensions_size(); + for (size_t i = 0; i < dim_size; ++i) { + out_dims[i] = shape->dimensions(i); + } + return nullptr; +} + +status build(const xla_builder b, const xla_op o, xla_computation *output) { + ASSIGN_OR_RETURN_STATUS(computation, b->Build(*o)); + *output = new XlaComputation(); + **output = std::move(computation); + return nullptr; +} + +status compile(const pjrt_client client, const xla_computation computation, + pjrt_loaded_executable *output) { + CompileOptions options; + ASSIGN_OR_RETURN_STATUS(executable, + (*client)->Compile(*computation, options)); + *output = executable.release(); + return nullptr; +} + +status first_error(const xla_builder b) { + MAYBE_RETURN_STATUS(b->first_error()); + return nullptr; +} + +status get_current_status(const xla_builder b) { + MAYBE_RETURN_STATUS(b->GetCurrentStatus()); + return nullptr; +}
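+// Status convention for everything below: functions return nullptr on success
+// and a heap-allocated Status on failure, which callers read with
+// status_error_message() and release with status_free(). The
+// ASSIGN_OR_RETURN_STATUS / MAYBE_RETURN_STATUS helpers are defined in an
+// earlier hunk; a plausible sketch, not the verbatim definitions:
+//
+//   #define MAYBE_RETURN_STATUS(expr) \
+//     { auto s_ = (expr); if (!s_.ok()) return new Status(s_); }
+//   #define ASSIGN_OR_RETURN_STATUS(lhs, rexpr) \
+//     auto lhs##_or = (rexpr); \
+//     if (!lhs##_or.ok()) return new Status(lhs##_or.status()); \
+//     auto lhs = std::move(lhs##_or).value();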
+status execute(const pjrt_loaded_executable exe, const literal *inputs, + int ninputs, pjrt_buffer ***outputs) { + auto client = exe->client(); + ExecuteOptions options; + options.strict_shape_checking = false; + std::vector<PjRtBuffer *> input_buffer_ptrs; + PjRtDevice *device = client->devices()[0]; + for (int i = 0; i < ninputs; ++i) { + ASSIGN_OR_RETURN_STATUS(buffer, + client->BufferFromHostLiteral(*inputs[i], device)); + // Wait for the transfer to have completed to avoid the literal potentially + // getting out of scope before it has been transferred. + MAYBE_RETURN_STATUS(buffer->GetReadyFuture().Await()); + input_buffer_ptrs.push_back(buffer.release()); + } + ASSIGN_OR_RETURN_STATUS(results, exe->Execute({input_buffer_ptrs}, options)); + pjrt_buffer **out = + (pjrt_buffer **)malloc((results.size() + 1) * sizeof(pjrt_buffer *)); + for (size_t i = 0; i < results.size(); ++i) { + auto &replica_results = results[i]; + pjrt_buffer *per_replica_outputs = (pjrt_buffer *)malloc( + (replica_results.size() + 1) * sizeof(pjrt_buffer)); + for (size_t j = 0; j < replica_results.size(); ++j) { + per_replica_outputs[j] = replica_results[j].release(); + } + per_replica_outputs[replica_results.size()] = nullptr; + out[i] = per_replica_outputs; + } + out[results.size()] = nullptr; + *outputs = out; + return nullptr; +} + +status execute_b(const pjrt_loaded_executable exe, const pjrt_buffer *inputs, + int ninputs, pjrt_buffer ***outputs) { + auto client = exe->client(); + ExecuteOptions options; + options.strict_shape_checking = false; + std::vector<PjRtBuffer *> input_buffer_ptrs(inputs, inputs + ninputs); + ASSIGN_OR_RETURN_STATUS(results, exe->Execute({input_buffer_ptrs}, options)); + pjrt_buffer **out = + (pjrt_buffer **)malloc((results.size() + 1) * sizeof(pjrt_buffer *)); + for (size_t i = 0; i < results.size(); ++i) { + auto &replica_results = results[i]; + pjrt_buffer *per_replica_outputs = (pjrt_buffer *)malloc( + (replica_results.size() + 1) * sizeof(pjrt_buffer)); + for (size_t j = 0; j < replica_results.size(); ++j) { + per_replica_outputs[j] = replica_results[j].release(); + } + per_replica_outputs[replica_results.size()] = nullptr; + out[i] = per_replica_outputs; + } + out[results.size()] = nullptr; + *outputs = out; + return nullptr; +} + +literal literal_create_from_shape(int pr_type, const int64_t *dims, + size_t ndims) { + auto shape = ShapeUtil::MakeShape((PrimitiveType)pr_type, + absl::Span<const int64_t>(dims, ndims)); + Literal l = Literal::CreateFromShape(shape); + return new Literal(std::move(l)); +} + +literal literal_create_from_shape_and_data(int pr_type, const int64_t *dims, + size_t ndims, const void *data, + size_t data_len) { + auto shape = ShapeUtil::MakeShape((PrimitiveType)pr_type, + absl::Span<const int64_t>(dims, ndims)); + Literal l = Literal::CreateFromShape(shape); + if (l.size_bytes() != data_len) { + return nullptr; + } + memcpy(l.untyped_data(), data, data_len); + return new Literal(std::move(l)); +} + +literal literal_clone(const literal l) { + return new Literal(std::move(l->Clone())); +} + +status literal_reshape(const literal l, const int64_t *dims, size_t ndims, + literal *output) { + ASSIGN_OR_RETURN_STATUS(literal, + l->Reshape(absl::Span<const int64_t>(dims, ndims))); + *output = new Literal(std::move(literal)); + return nullptr; +} + +status literal_convert(const literal l, int pr_type, literal *output) { + ASSIGN_OR_RETURN_STATUS(literal, l->Convert((PrimitiveType)pr_type)); + *output = new Literal(std::move(literal)); + return nullptr; +} + +int64_t literal_element_count(const literal l) { return l->element_count(); } + +int64_t literal_size_bytes(const literal l) { return l->size_bytes(); } + +void literal_shape(const literal l, shape *out_shape) { + *out_shape = new Shape(l->shape()); +} + +void literal_decompose_tuple(literal l, literal *outputs, size_t noutputs) { + auto tuple = l->DecomposeTuple(); + for (size_t i = 0; i < std::min(noutputs, tuple.size()); ++i) { + outputs[i] = new Literal(std::move(tuple[i])); + } +} + +int literal_element_type(const literal l) { return l->shape().element_type(); }
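+// literal_create_from_shape_and_data above rejects mismatched payloads by
+// returning nullptr, so callers should size the host buffer from the shape.
+// Usage sketch (illustrative; F32 == 11 is an assumed PrimitiveType value):
+//
+//   int64_t dims[2] = {2, 3};
+//   float host[6] = {1, 2, 3, 4, 5, 6};
+//   literal l = literal_create_from_shape_and_data(11, dims, 2, host,
+//                                                  sizeof host);
+//   // sizeof host == 2 * 3 * sizeof(float) == 24 == literal_size_bytes(l)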
+void literal_copy_to(const literal l, void *dst, size_t size_in_bytes) { + std::memcpy(dst, l->untyped_data(), size_in_bytes); +} + +void literal_copy_from(literal l, const void *src, size_t size_in_bytes) { + std::memcpy(l->untyped_data(), src, size_in_bytes); +} + +literal literal_make_tuple(const literal *l, size_t n) { + Literal out = LiteralUtil::MakeTuple(absl::MakeSpan(l, n)); + return new Literal(std::move(out)); +} + +literal literal_make_tuple_owned(const literal *l, size_t n) { + std::vector<Literal> elems; + for (size_t i = 0; i < n; ++i) { + elems.push_back(std::move(*(l[i]))); + } + Literal out = LiteralUtil::MakeTupleOwned(std::move(elems)); + return new Literal(std::move(out)); +} + +void literal_free(literal l) { delete l; } + +void status_free(status s) { delete s; } + +char *xla_computation_name(xla_computation c) { + return strdup(std::string(c->name()).c_str()); +} + +void xla_computation_free(xla_computation c) { delete c; } + +char *status_error_message(status s) { + return strdup(s->error_message().c_str()); +} + +status hlo_module_proto_parse_and_return_unverified_module( + const char *data, size_t len, hlo_module_proto *output) { + ASSIGN_OR_RETURN_STATUS( + hmp, ParseAndReturnUnverifiedModule(std::string(data, len))); + *output = new HloModuleProto(hmp->ToProto()); + return nullptr; +} + +status hlo_module_proto_parse_proto(const char *d, size_t len, bool binary, + hlo_module_proto *output) { + std::string data(d, len); + HloSnapshot proto; + if (binary) { + if (!proto.ParseFromString(data) && + !proto.mutable_hlo()->ParseFromString(data) && + !proto.mutable_hlo()->mutable_hlo_module()->ParseFromString(data)) { + return new Status( + InvalidArgument("Failed to parse input as HLO protobuf binary")); + } + } else { + if (!tsl::protobuf::TextFormat::ParseFromString(data, &proto) && + !tsl::protobuf::TextFormat::ParseFromString(data, + proto.mutable_hlo()) && + !tsl::protobuf::TextFormat::ParseFromString( + data, proto.mutable_hlo()->mutable_hlo_module())) { + return new Status( + InvalidArgument("Failed to parse input as HLO protobuf text")); + } + } + ASSIGN_OR_RETURN_STATUS(config, HloModule::CreateModuleConfigFromProto( + proto.hlo().hlo_module(), {})); + ASSIGN_OR_RETURN_STATUS( + hmp, HloModule::CreateFromProto(proto.hlo().hlo_module(), config)); + *output = new HloModuleProto(hmp->ToProto()); + return nullptr; +} + +status hlo_module_from_proto(const hlo_module_proto input_proto, hlo_module *output) { + ASSIGN_OR_RETURN_STATUS( + config, HloModule::CreateModuleConfigFromProto(*input_proto, {})); + ASSIGN_OR_RETURN_STATUS( + hm, HloModule::CreateFromProto(*input_proto, config)); + *output = hm.release(); + return nullptr; +} + +hlo_computation hlo_module_entry_computation(const hlo_module module) { + return module->entry_computation(); +} + +int64_t hlo_module_computation_count(const hlo_module module) { + return module->computation_count(); +} + +int64_t hlo_module_instruction_count(const hlo_module module) { + return module->instruction_count(); +} + +char* hlo_module_to_string(const hlo_module module) { + std::string result = module->ToString(); + char* output = new char[result.length() + 1]; + std::strcpy(output, result.c_str()); + return output; +} + +xla_computation xla_computation_from_hlo_module_proto(const hlo_module_proto p) { + return new XlaComputation(*p); +} + +void hlo_module_proto_free(hlo_module_proto p) { delete p; } + +hlo_module_proto xla_computation_proto(const xla_computation c) { + return new HloModuleProto(c->proto()); +} \ No newline at end of file
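+// End-to-end sketch of the C API defined in this file (illustrative only; the
+// returned status pointers should be checked, and F32 == 11 is an assumed
+// PrimitiveType value):
+//
+//   pjrt_client client;  pjrt_cpu_client_create(&client);
+//   xla_builder b = xla_builder_create("double");
+//   int64_t dims[1] = {2};
+//   xla_op x = parameter(b, 0, /*F32=*/11, 1, dims, "x");
+//   xla_op y = op_add(x, x);
+//   xla_computation comp;  build(b, y, &comp);
+//   pjrt_loaded_executable exe;  compile(client, comp, &exe);
+//   // feed literals with execute(...), then read results back through
+//   // pjrt_buffer_to_literal_sync and literal_copy_to.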
diff --git a/ivy/engines/XLA/rust_api/xla_rs/xla_rs.h b/ivy/engines/XLA/rust_api/xla_rs/xla_rs.h new file mode 100644 index 0000000000000..e414657e701e9 --- /dev/null +++ b/ivy/engines/XLA/rust_api/xla_rs/xla_rs.h @@ -0,0 +1,324 @@ +#include <stdint.h> +#include <stddef.h> +#include <stdbool.h> +#ifdef __cplusplus +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wuninitialized" +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#pragma GCC diagnostic ignored "-Winvalid-offsetof" +#pragma GCC diagnostic ignored "-Wreturn-type" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.h" +#include "tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" +#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/tpu_client.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#pragma GCC diagnostic pop +using namespace xla; + +extern "C" { +typedef std::shared_ptr<PjRtClient> *pjrt_client; +typedef PjRtLoadedExecutable *pjrt_loaded_executable; +typedef PjRtDevice *pjrt_device; +typedef PjRtBuffer *pjrt_buffer; +typedef XlaBuilder *xla_builder; +typedef XlaOp *xla_op; +typedef Status *status; +typedef Shape *shape; +typedef Literal *literal; +typedef XlaComputation *xla_computation; +typedef HloModule *hlo_module; +typedef HloModuleProto *hlo_module_proto; +typedef HloComputation *hlo_computation; +#else +typedef struct _pjrt_client *pjrt_client; +typedef struct _pjrt_loaded_executable *pjrt_loaded_executable; +typedef struct _pjrt_device *pjrt_device; +typedef struct _pjrt_buffer *pjrt_buffer; +typedef struct _xla_builder *xla_builder; +typedef struct _xla_op *xla_op; +typedef struct _status *status; +typedef struct _shape *shape; +typedef struct _literal *literal; +typedef struct _xla_computation *xla_computation; +typedef struct _hlo_module *hlo_module; +typedef struct _hlo_module_proto *hlo_module_proto; +typedef struct _hlo_computation *hlo_computation; +#endif + +status pjrt_cpu_client_create(pjrt_client *); +status pjrt_gpu_client_create(pjrt_client *, double, bool); +status pjrt_tpu_client_create(pjrt_client *, int); +void pjrt_client_free(pjrt_client); +int pjrt_client_device_count(pjrt_client); +int pjrt_client_addressable_device_count(pjrt_client); +void pjrt_client_devices(pjrt_client, pjrt_device *); +void pjrt_client_addressable_devices(pjrt_client, pjrt_device *); +char *pjrt_client_platform_name(pjrt_client); +char *pjrt_client_platform_version(pjrt_client); + +void pjrt_loaded_executable_free(pjrt_loaded_executable); + +int pjrt_device_id(pjrt_device); +int pjrt_device_process_index(pjrt_device); +int pjrt_device_local_hardware_id(pjrt_device); +status pjrt_device_transfer_to_infeed(pjrt_device, const literal); +status pjrt_device_transfer_from_outfeed(pjrt_device, literal); +char *pjrt_device_kind(pjrt_device); +char *pjrt_device_debug_string(pjrt_device); +char *pjrt_device_to_string(pjrt_device); + +status pjrt_buffer_from_host_literal(const pjrt_client,
const pjrt_device, + const literal, pjrt_buffer *); +status pjrt_buffer_from_host_buffer(const pjrt_client, const pjrt_device, + const void *, int, int, const int64_t *, + pjrt_buffer *); +status pjrt_buffer_to_literal_sync(pjrt_buffer, literal *); +status pjrt_buffer_copy_raw_to_host_sync(pjrt_buffer, void *, size_t, size_t); +shape pjrt_buffer_on_device_shape(pjrt_buffer); +status pjrt_buffer_copy_to_device(pjrt_buffer, pjrt_device, pjrt_buffer *); +void pjrt_buffer_free(pjrt_buffer); + +xla_builder xla_builder_create(const char *); +void xla_builder_free(xla_builder); + +xla_op constant_literal(const xla_builder, const literal); +xla_op parameter(const xla_builder, int64_t, int, int, const int64_t *, + const char *); +xla_op parameter_s(const xla_builder, int64_t, const shape, const char *); +xla_op infeed(const xla_builder, int, int, const int64_t *, const char *); +void outfeed(const xla_op, int, int, const int64_t *, const char *); + +// Ops +xla_op op_add(const xla_op, const xla_op); +xla_op op_sub(const xla_op, const xla_op); +xla_op op_mul(const xla_op, const xla_op); +xla_op op_div(const xla_op, const xla_op); +xla_op op_rem(const xla_op, const xla_op); +xla_op op_max(const xla_op, const xla_op); +xla_op op_min(const xla_op, const xla_op); +xla_op op_and(const xla_op, const xla_op); +xla_op op_or(const xla_op, const xla_op); +xla_op op_xor(const xla_op, const xla_op); +xla_op op_atan2(const xla_op, const xla_op); +xla_op op_pow(const xla_op, const xla_op); +xla_op op_dot(const xla_op, const xla_op); +xla_op op_dot_general(const xla_op, const xla_op, const int64_t *, size_t, + const int64_t *, size_t, const int64_t *, size_t, + const int64_t *, size_t); +xla_op op_eq(const xla_op, const xla_op); +xla_op op_ne(const xla_op, const xla_op); +xla_op op_ge(const xla_op, const xla_op); +xla_op op_gt(const xla_op, const xla_op); +xla_op op_le(const xla_op, const xla_op); +xla_op op_lt(const xla_op, const xla_op); +xla_op op_shift_left(const xla_op, const xla_op); +xla_op op_shift_right_arith(const xla_op, const xla_op); +xla_op op_shift_right_logic(const xla_op, const xla_op); +xla_op op_population_count(const xla_op); +xla_op op_not(const xla_op); +xla_op op_abs(const xla_op); +xla_op op_exp(const xla_op); +xla_op op_expm1(const xla_op); +xla_op op_floor(const xla_op); +xla_op op_ceil(const xla_op); +xla_op op_round(const xla_op); +xla_op op_round_nearest_even(const xla_op); +xla_op op_log(const xla_op); +xla_op op_log1p(const xla_op); +xla_op op_logistic(const xla_op); +xla_op op_sign(const xla_op); +xla_op op_clz(const xla_op); +xla_op op_cos(const xla_op); +xla_op op_sin(const xla_op); +xla_op op_tanh(const xla_op); +xla_op op_real(const xla_op); +xla_op op_imag(const xla_op); +xla_op op_conj(const xla_op); +xla_op op_square(const xla_op); +xla_op op_sqrt(const xla_op); +xla_op op_rsqrt(const xla_op); +xla_op op_cbrt(const xla_op); +xla_op op_is_finite(const xla_op); +xla_op op_neg(const xla_op); +xla_op op_lower_triangle(const xla_op); +xla_op op_upper_triangle(const xla_op); +xla_op op_erf(const xla_op); +xla_op op_einsum1(const xla_op, const char *); +xla_op op_einsum2(const xla_op, const xla_op, const char *); +xla_op op_copy(const xla_op); +xla_op op_clone(const xla_op); +xla_op op_zeros_like(const xla_op); +xla_op op_zero_like(const xla_op); +xla_op op_zero(const xla_builder, int); +xla_op op_one(const xla_builder, int); +xla_op op_min_value(const xla_builder, int); +xla_op op_max_value(const xla_builder, int); +xla_op op_reshape(const xla_op, size_t, const int64_t *); +xla_op 
op_dynamic_reshape(const xla_op, size_t, const xla_op *, size_t, const int64_t *, const bool *); +xla_op op_broadcast(const xla_op, size_t, const int64_t *); +xla_op op_broadcast_in_dim(const xla_op, size_t, const int64_t *, size_t, + const int64_t *); +xla_op op_collapse(const xla_op, size_t, const int64_t *); +xla_op op_transpose(const xla_op, size_t, const int64_t *); +xla_op op_clamp(const xla_op, const xla_op, const xla_op); +xla_op op_select(const xla_op, const xla_op, const xla_op); +xla_op op_call(const xla_builder, const xla_computation, size_t, const xla_op *); +xla_op op_map(const xla_builder, size_t, const xla_op *, const xla_computation, size_t, const int64_t *, size_t, const xla_op *); +xla_op op_rng_uniform(const xla_op, const xla_op, int, int, const int64_t *); +xla_op op_rng_normal(const xla_op, const xla_op, int, int, const int64_t *); +xla_op op_pad(const xla_op, const xla_op, size_t, const int64_t *, const int64_t *, const int64_t *); +xla_op op_pad_in_dim(const xla_op, const xla_op, int64_t, int64_t, int64_t); +xla_op op_slice(const xla_op, size_t, const int64_t *, size_t, const int64_t *, size_t, const int64_t *); +xla_op op_slice_in_dim(const xla_op, int64_t, int64_t, int64_t, int64_t); +xla_op op_dynamic_slice(const xla_op, size_t, const xla_op *, size_t, const int64_t *); +xla_op op_dynamic_update_slice(const xla_op, const xla_op, size_t, const xla_op *); +xla_op op_concat_in_dim(const xla_op, const xla_op *, size_t, int64_t); +xla_op op_tuple(const xla_builder, const xla_op *, size_t); +xla_op op_get_tuple_element(const xla_op, int64_t); +xla_op op_gather(const xla_op, const xla_op, const int64_t *, size_t, + const int64_t *, size_t, const int64_t *, size_t, + const int64_t *, const int64_t *, size_t); +xla_op op_scatter(size_t, const xla_op *, const xla_op, size_t, const xla_op *, const xla_computation, + size_t, const int64_t *, size_t, const int64_t *, size_t, const int64_t *, int64_t); +xla_op op_convert_element_type(const xla_op, int); +xla_op op_dimensions_size(const xla_op, int64_t); +xla_op op_reduce(const xla_op, const xla_op, const xla_computation, + const int64_t *, size_t); +xla_op op_internal_error(const xla_builder, const char *); +xla_op op_unknown_error(const xla_builder, const char *); +xla_op op_invalid_argument_error(const xla_builder, const char *); +xla_op op_iota1(const xla_builder, int, size_t); +xla_op op_iota(const xla_builder, int, size_t, const int64_t *, int64_t); +xla_op op_while(const xla_computation, const xla_computation, const xla_op); +xla_op op_conditional(const xla_op, const xla_op, const xla_computation, + const xla_op, const xla_computation); +xla_op op_conv(const xla_op, const xla_op, size_t, const int64_t *, const char*, int64_t, int64_t); +xla_op op_conv_general_dilated(const xla_op, const xla_op, + size_t, const int64_t *, + size_t, const int64_t *, + size_t, const int64_t *, + size_t, const int64_t *, + const int64_t *, + const int64_t *, + size_t, const int64_t *, + const int64_t *, + const int64_t *, + size_t, const int64_t *, + const int64_t *, + const int64_t *, + size_t, const int64_t *, + int64_t, int64_t); +xla_op op_batch_norm_inference(const xla_op, + const xla_op, + const xla_op, + const xla_op, + const xla_op, + float, + int64_t); + +xla_builder op_builder(const xla_op); + +int xla_op_valid(const xla_op); +void xla_op_free(xla_op); + +int shape_dimensions_size(const shape); +size_t shape_tuple_shapes_size(const shape); +shape shape_tuple_shapes(const shape, int); +int shape_element_type(const shape); +int64_t 
shape_dimensions(const shape, int); +void shape_free(shape); +shape make_shape_array(int, size_t, const int64_t *); +shape make_shape_tuple(size_t, const shape *); + +status get_shape(const xla_builder, const xla_op, shape *); +status get_element_type(const xla_builder, const xla_op, int *); +status get_dimensions_size(const xla_builder, const xla_op, int *); +status get_dimensions(const xla_builder, const xla_op, size_t *); + +status build(const xla_builder, const xla_op, xla_computation *); +status compile(const pjrt_client, const xla_computation, + pjrt_loaded_executable *); +status execute(const pjrt_loaded_executable, const literal *, int, + pjrt_buffer ***); +status execute_b(const pjrt_loaded_executable, const pjrt_buffer *, int, + pjrt_buffer ***); +status first_error(const xla_builder); +status get_current_status(const xla_builder); + +literal literal_create_from_shape(int, const int64_t *, size_t); +literal literal_create_from_shape_and_data(int, const int64_t *, size_t, + const void *, size_t); +literal literal_clone(const literal); +status literal_reshape(const literal, const int64_t *, size_t, literal *); +status literal_convert(const literal, int, literal *); +int64_t literal_element_count(const literal); +int literal_element_type(const literal); +void literal_shape(const literal, shape *); +void literal_decompose_tuple(literal, literal *, size_t); +int64_t literal_size_bytes(const literal); +void literal_copy_to(const literal, void *, size_t); +void literal_copy_from(literal, const void *, size_t); +literal literal_make_tuple(const literal *, size_t); +literal literal_make_tuple_owned(const literal *, size_t); +void literal_free(literal); + +status hlo_module_proto_parse_and_return_unverified_module(const char *, size_t, + hlo_module_proto *); +status hlo_module_proto_parse_proto(const char *, size_t, bool, + hlo_module_proto *); +status hlo_module_from_proto(const hlo_module_proto, hlo_module *); + +hlo_computation hlo_module_entry_computation(const hlo_module); +int64_t hlo_module_computation_count(const hlo_module); +int64_t hlo_module_instruction_count(const hlo_module); +char *hlo_module_to_string(const hlo_module); + +xla_computation xla_computation_from_hlo_module_proto(const hlo_module_proto); +void hlo_module_proto_free(hlo_module_proto); + +char *xla_computation_name(xla_computation); +hlo_module_proto xla_computation_proto(const xla_computation); +void xla_computation_free(xla_computation); + +void status_free(status); +char *status_error_message(status); + +#define FOR_EACH_NATIVE_TYPE(_) \ + _(bool, PRED) \ + _(int8_t, S8) \ + _(int16_t, S16) \ + _(int32_t, S32) \ + _(int64_t, S64) \ + _(uint8_t, U8) \ + _(uint16_t, U16) \ + _(uint32_t, U32) \ + _(uint64_t, U64) \ + _(float, F32) \ + _(double, F64) + +#define CONST_OP_R01(native_type, primitive_type) \ + xla_op constant_r0_##native_type(const xla_builder, native_type); \ + xla_op constant_r1c_##native_type(const xla_builder, native_type, size_t); \ + xla_op constant_r1_##native_type(const xla_builder, const native_type *, \ + size_t); \ + literal create_r0_##native_type(native_type); \ + literal create_r1_##native_type(const native_type *, size_t); \ + native_type literal_get_first_element_##native_type(const literal); + +FOR_EACH_NATIVE_TYPE(CONST_OP_R01) +#undef CONST_OP_R01 + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/ivy/engines/__init__.py b/ivy/engines/__init__.py new file mode 100644 index 0000000000000..5ca851e646876 --- /dev/null +++ b/ivy/engines/__init__.py @@ -0,0 +1,2 @@ 
+from . import XLA +from .XLA import * diff --git a/ivy/engines/ivy2xla.cpython-310-x86_64-linux-gnu.so b/ivy/engines/ivy2xla.cpython-310-x86_64-linux-gnu.so new file mode 100755 index 0000000000000..39de8148f389e Binary files /dev/null and b/ivy/engines/ivy2xla.cpython-310-x86_64-linux-gnu.so differ diff --git a/ivy/engines/setup_xla.sh b/ivy/engines/setup_xla.sh new file mode 100644 index 0000000000000..26429770b8dc2 --- /dev/null +++ b/ivy/engines/setup_xla.sh @@ -0,0 +1,14 @@ +#!/bin/bash +#pip install virtualenv +cd XLA/rust_api/ +#mkdir xla_build && virtualenv xla_build +#source xla_build/bin/activate +wget https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cuda111.tar.gz +tar -xzvf xla_extension-x86_64-linux-gnu-cuda111.tar.gz +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +source "$HOME/.cargo/env" +pip install maturin +apt-get update +apt install llvm-dev libclang-dev clang +export LIBCLANG_PATH=/usr/local/lib +# maturin develop diff --git a/ivy/func_wrapper.py b/ivy/func_wrapper.py index 6ff998368f2e6..e7fa2ca9acce7 100644 --- a/ivy/func_wrapper.py +++ b/ivy/func_wrapper.py @@ -494,7 +494,7 @@ def _inputs_to_ivy_arrays(*args, **kwargs): has_out = True # convert all arrays in the inputs to ivy.Array instances ivy_args, ivy_kwargs = ivy.args_to_ivy( - *args, **kwargs, include_derived={tuple: True} + *args, **kwargs, include_derived={"tuple": True} ) if has_out: ivy_kwargs["out"] = out @@ -564,7 +564,7 @@ def _outputs_to_ivy_arrays(*args, **kwargs): ret = fn(*args, **kwargs) # convert all arrays in the return to `ivy.Array` instances return ( - ivy.to_ivy(ret, nested=True, include_derived={tuple: True}) + ivy.to_ivy(ret, nested=True, include_derived={"tuple": True}) if ivy.array_mode else ret ) @@ -594,7 +594,7 @@ def output_to_native_arrays(fn: Callable) -> Callable: @functools.wraps(fn) def _output_to_native_arrays(*args, **kwargs): ret = fn(*args, **kwargs) - return ivy.to_native(ret, nested=True, include_derived={tuple: True}) + return ivy.to_native(ret, nested=True, include_derived={"tuple": True}) _output_to_native_arrays.outputs_to_native_arrays = True return _output_to_native_arrays diff --git a/ivy/functional/backends/jax/__init__.py b/ivy/functional/backends/jax/__init__.py index 036fe96894758..9ad6e9403b338 100644 --- a/ivy/functional/backends/jax/__init__.py +++ b/ivy/functional/backends/jax/__init__.py @@ -23,6 +23,19 @@ ) +# make ivy.Array compatible with jax pytree traversal +def _array_flatten(tree): + return ((tree.data,), None) + + +def _array_unflatten(aux_data, children): + if type(*children) == object: + return children + return ivy.Array(*children) + + +register_pytree_node(ivy.Array, _array_flatten, _array_unflatten) + # noinspection PyUnresolvedReferences if not ivy.is_local(): _module_in_memory = sys.modules[__name__] diff --git a/ivy/functional/backends/jax/creation.py b/ivy/functional/backends/jax/creation.py index 4fbc39f493756..06240484f03b9 100644 --- a/ivy/functional/backends/jax/creation.py +++ b/ivy/functional/backends/jax/creation.py @@ -13,13 +13,13 @@ from ivy import as_native_dtype from ivy.functional.backends.jax import JaxArray from ivy.functional.ivy.creation import ( - asarray_to_native_arrays_and_back, - asarray_infer_device, - asarray_infer_dtype, - asarray_handle_nestable, + _asarray_to_native_arrays_and_back, + _asarray_infer_device, + _asarray_infer_dtype, + _asarray_handle_nestable, NestedSequence, SupportsBufferProtocol, - asarray_inputs_to_native_shapes, + 
_asarray_inputs_to_native_shapes, ) @@ -49,11 +49,11 @@ def arange( return res -@asarray_to_native_arrays_and_back -@asarray_infer_device -@asarray_handle_nestable -@asarray_inputs_to_native_shapes -@asarray_infer_dtype +@_asarray_to_native_arrays_and_back +@_asarray_infer_device +@_asarray_handle_nestable +@_asarray_inputs_to_native_shapes +@_asarray_infer_dtype def asarray( obj: Union[ JaxArray, diff --git a/ivy/functional/backends/jax/elementwise.py b/ivy/functional/backends/jax/elementwise.py index 6de81f65a16eb..84f0f7d0b6518 100644 --- a/ivy/functional/backends/jax/elementwise.py +++ b/ivy/functional/backends/jax/elementwise.py @@ -14,6 +14,7 @@ from ivy.functional.backends.jax import JaxArray from ivy.func_wrapper import with_unsupported_dtypes from . import backend_version +from ...ivy.elementwise import _complex_to_inf def abs( @@ -401,13 +402,20 @@ def positive( def pow( - x1: Union[float, JaxArray], - x2: Union[float, JaxArray], + x1: JaxArray, + x2: Union[int, float, JaxArray], /, *, out: Optional[JaxArray] = None, ) -> JaxArray: x1, x2 = ivy.promote_types_of_inputs(x1, x2) + if ivy.is_complex_dtype(x1) and ivy.any(ivy.isinf(x2)): + inf_indices = jnp.nonzero(jnp.isinf(x2)) + ret = jnp.power(x1, x2) + ret = ret.at[inf_indices].set(_complex_to_inf(ret[inf_indices])) + return ret + if ivy.is_int_dtype(x1) and ivy.any(x2 < 0): + return jnp.float_power(x1, x2).astype(x1.dtype) return jnp.power(x1, x2) diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py index 788126f4b44cf..ccb239dcbc430 100644 --- a/ivy/functional/backends/jax/experimental/activations.py +++ b/ivy/functional/backends/jax/experimental/activations.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, Literal # global import jax @@ -13,6 +13,7 @@ def logit( /, *, eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[JaxArray] = None, ): if eps is None: @@ -48,7 +49,9 @@ def thresholded_relu( return jnp.where(x > threshold, x, 0).astype(x.dtype) -def logsigmoid(input: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def logsigmoid( + input: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None +) -> JaxArray: return jax.nn.log_sigmoid(input) diff --git a/ivy/functional/backends/jax/experimental/creation.py b/ivy/functional/backends/jax/experimental/creation.py index 2bbc0b1b895e6..341d4f378c8a1 100644 --- a/ivy/functional/backends/jax/experimental/creation.py +++ b/ivy/functional/backends/jax/experimental/creation.py @@ -129,3 +129,33 @@ def trilu( if upper: return jnp.triu(x, k) return jnp.tril(x, k) + + +def mel_weight_matrix( + num_mel_bins: int, + dft_length: int, + sample_rate: int, + lower_edge_hertz: float = 0.0, + upper_edge_hertz: float = 3000.0, +): + lower_edge_hertz = jnp.array(lower_edge_hertz) + upper_edge_hertz = jnp.array(upper_edge_hertz) + zero = jnp.array(0.0) + hz_to_mel = lambda f: 2595 * jnp.log10(1 + f / 700) + nyquist_hz = sample_rate / 2 + linear_freqs = jnp.linspace(0, nyquist_hz, dft_length, dtype=jnp.float32)[1:] + spec_bin_mels = hz_to_mel(linear_freqs)[..., None] + mel_edges = jnp.linspace( + hz_to_mel(lower_edge_hertz), + hz_to_mel(upper_edge_hertz), + num_mel_bins + 2, + dtype=jnp.float32, + ) + mel_edges = jnp.stack([mel_edges[i : i + 3] for i in range(num_mel_bins)]) + lower_edge_mel, center_mel, upper_edge_mel = [ + t.reshape((1, num_mel_bins)) for t in jnp.split(mel_edges, 3, axis=1) + ] + lower_slopes = (spec_bin_mels - lower_edge_mel) / (center_mel - lower_edge_mel) + upper_slopes = (upper_edge_mel - spec_bin_mels) / (upper_edge_mel - center_mel) + mel_weights = jnp.maximum(zero, jnp.minimum(lower_slopes, upper_slopes)) + return jnp.pad(mel_weights, [[1, 0], [0, 0]])
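+# How the triangular mel filterbank above is assembled (comment sketch for
+# intuition; the example numbers are assumptions, not part of this diff):
+# frequencies are warped onto the mel scale with m = 2595 * log10(1 + f / 700),
+# num_mel_bins + 2 evenly spaced mel edges define one (lower, center, upper)
+# triple per bin, and each spectrogram bin receives weight
+# max(0, min(rising_slope, falling_slope)). The final pad restores the DC bin
+# dropped by the [1:] slice.
+#
+#   w = mel_weight_matrix(num_mel_bins=8, dft_length=65, sample_rate=16000)
+#   w.shape  # (65, 8): a (..., 65) spectrogram @ w gives (..., 8) mel bands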
diff --git a/ivy/functional/backends/jax/experimental/losses.py b/ivy/functional/backends/jax/experimental/losses.py index a5470e0afbc39..3fa778c0dbb9a 100644 --- a/ivy/functional/backends/jax/experimental/losses.py +++ b/ivy/functional/backends/jax/experimental/losses.py @@ -56,3 +56,23 @@ def soft_margin_loss( return jnp.sum(loss) else: return loss + + +def kl_div( + input: JaxArray, + target: JaxArray, + /, + *, + reduction: Optional[str] = "mean", +) -> JaxArray: + size = jnp.shape(input) + loss = jnp.sum(input * jnp.log(input / target), axis=-1) + + if reduction == "mean": + loss = jnp.mean(loss) + elif reduction == "sum": + loss = jnp.sum(loss) + elif reduction == "batchmean": + loss = jnp.divide(jnp.sum(loss), size[0]) + + return loss diff --git a/ivy/functional/backends/jax/experimental/manipulation.py b/ivy/functional/backends/jax/experimental/manipulation.py index 7258a7ae7ebe7..bd7575f2532f4 100644 --- a/ivy/functional/backends/jax/experimental/manipulation.py +++ b/ivy/functional/backends/jax/experimental/manipulation.py @@ -412,3 +412,9 @@ def fill_diagonal( a = a.at[:end:step].set(jnp.array(v).astype(a.dtype)) a = jnp.reshape(a, shape) return a + + +def column_stack( + arrays: Sequence[JaxArray], /, *, out: Optional[JaxArray] = None +) -> JaxArray: + return jnp.column_stack(arrays) diff --git a/ivy/functional/backends/jax/experimental/norms.py b/ivy/functional/backends/jax/experimental/norms.py index 68a0060cc5802..ad616122fec3d 100644 --- a/ivy/functional/backends/jax/experimental/norms.py +++ b/ivy/functional/backends/jax/experimental/norms.py @@ -1,11 +1,8 @@ import jax.numpy as jnp from typing import Optional from ivy.functional.backends.jax import JaxArray -from ivy.func_wrapper import with_unsupported_dtypes -from .
import backend_version -@with_unsupported_dtypes({"0.4.14 and below": "uint8"}, backend_version) def l1_normalize( x: JaxArray, /, diff --git a/ivy/functional/backends/jax/experimental/statistical.py b/ivy/functional/backends/jax/experimental/statistical.py index 64f1bfc031973..9d35e51ae6017 100644 --- a/ivy/functional/backends/jax/experimental/statistical.py +++ b/ivy/functional/backends/jax/experimental/statistical.py @@ -162,6 +162,26 @@ def nanmean( return jnp.nanmean(a, axis=axis, keepdims=keepdims, dtype=dtype, out=out) +def nanprod( + a: JaxArray, + /, + *, + axis: Optional[Union[int, Sequence[int]]] = None, + dtype: Optional[jnp.dtype] = None, + keepdims: Optional[bool] = False, + out: Optional[JaxArray] = None, + initial: Optional[Union[int, float, complex]] = None, + where: Optional[JaxArray] = None, +) -> JaxArray: + dtype = ivy.as_native_dtype(dtype) + if dtype is None: + dtype = _infer_dtype(a.dtype) + axis = tuple(axis) if isinstance(axis, list) else axis + return jnp.nanprod( + a, axis=axis, keepdims=keepdims, dtype=dtype, out=out, initial=initial + ) + + def quantile( a: JaxArray, q: Union[float, JaxArray], diff --git a/ivy/functional/backends/mxnet/creation.py b/ivy/functional/backends/mxnet/creation.py index 632b1daca8f9e..1575fbb463349 100644 --- a/ivy/functional/backends/mxnet/creation.py +++ b/ivy/functional/backends/mxnet/creation.py @@ -8,12 +8,12 @@ import ivy from ivy.utils.exceptions import IvyNotImplementedException from ivy.functional.ivy.creation import ( - asarray_to_native_arrays_and_back, - asarray_infer_device, - asarray_handle_nestable, + _asarray_to_native_arrays_and_back, + _asarray_infer_device, + _asarray_handle_nestable, NestedSequence, SupportsBufferProtocol, - asarray_inputs_to_native_shapes, + _asarray_inputs_to_native_shapes, ) @@ -30,10 +30,10 @@ def arange( raise IvyNotImplementedException() -@asarray_to_native_arrays_and_back -@asarray_infer_device -@asarray_handle_nestable -@asarray_inputs_to_native_shapes +@_asarray_to_native_arrays_and_back +@_asarray_infer_device +@_asarray_handle_nestable +@_asarray_inputs_to_native_shapes def asarray( obj: Union[ ( diff --git a/ivy/functional/backends/mxnet/elementwise.py b/ivy/functional/backends/mxnet/elementwise.py index 42a8a2788b605..5be5e5d59cf5c 100644 --- a/ivy/functional/backends/mxnet/elementwise.py +++ b/ivy/functional/backends/mxnet/elementwise.py @@ -463,8 +463,8 @@ def positive( def pow( - x1: Union[(float, None, mx.ndarray.NDArray)], - x2: Union[(float, None, mx.ndarray.NDArray)], + x1: Union[(None, mx.ndarray.NDArray)], + x2: Union[(int, float, None, mx.ndarray.NDArray)], /, *, out: Optional[Union[(None, mx.ndarray.NDArray)]] = None, diff --git a/ivy/functional/backends/numpy/creation.py b/ivy/functional/backends/numpy/creation.py index 7600f3cf9dba0..3835ce807cae1 100644 --- a/ivy/functional/backends/numpy/creation.py +++ b/ivy/functional/backends/numpy/creation.py @@ -8,13 +8,13 @@ import ivy from ivy.functional.backends.numpy.device import _to_device from ivy.functional.ivy.creation import ( - asarray_to_native_arrays_and_back, - asarray_infer_device, - asarray_infer_dtype, - asarray_handle_nestable, + _asarray_to_native_arrays_and_back, + _asarray_infer_device, + _asarray_infer_dtype, + _asarray_handle_nestable, NestedSequence, SupportsBufferProtocol, - asarray_inputs_to_native_shapes, + _asarray_inputs_to_native_shapes, ) from .data_type import as_native_dtype @@ -44,11 +44,11 @@ def arange( return res -@asarray_to_native_arrays_and_back -@asarray_infer_device 
-@asarray_handle_nestable -@asarray_inputs_to_native_shapes -@asarray_infer_dtype +@_asarray_to_native_arrays_and_back +@_asarray_infer_device +@_asarray_handle_nestable +@_asarray_inputs_to_native_shapes +@_asarray_infer_dtype def asarray( obj: Union[ np.ndarray, bool, int, float, tuple, NestedSequence, SupportsBufferProtocol diff --git a/ivy/functional/backends/numpy/elementwise.py b/ivy/functional/backends/numpy/elementwise.py index 6c4459bbaaf4f..72060aae2fed1 100644 --- a/ivy/functional/backends/numpy/elementwise.py +++ b/ivy/functional/backends/numpy/elementwise.py @@ -604,14 +604,16 @@ def positive( @_scalar_output_to_0d_array def pow( - x1: Union[float, np.ndarray], - x2: Union[float, np.ndarray], + x1: np.ndarray, + x2: Union[int, float, np.ndarray], /, *, out: Optional[np.ndarray] = None, ) -> np.ndarray: x1, x2 = ivy.promote_types_of_inputs(x1, x2) - return np.power(x1, x2, out=out) + if ivy.is_int_dtype(x1) and ivy.any(x2 < 0): + return np.float_power(x1, x2, casting='unsafe').astype(x1.dtype) + return np.power(x1, x2) pow.support_native_out = True diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py index d3f8282a22de2..707ecd8a4f4df 100644 --- a/ivy/functional/backends/numpy/experimental/activations.py +++ b/ivy/functional/backends/numpy/experimental/activations.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, Literal # global import numpy as np @@ -15,6 +15,7 @@ def logit( /, *, eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[np.ndarray] = None, ): x_dtype = x.dtype @@ -52,7 +53,9 @@ def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: @with_unsupported_dtypes({"1.25.2 and below": ("bool",)}, backend_version) @_scalar_output_to_0d_array -def logsigmoid(input: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def logsigmoid( + input: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None +) -> np.ndarray: return -(np.log1p(np.exp(-(input)))) diff --git a/ivy/functional/backends/numpy/experimental/creation.py b/ivy/functional/backends/numpy/experimental/creation.py index c2e01f77c55e8..6b111ca513bcb 100644 --- a/ivy/functional/backends/numpy/experimental/creation.py +++ b/ivy/functional/backends/numpy/experimental/creation.py @@ -171,3 +171,33 @@ def trilu( if upper: return np.triu(x, k) return np.tril(x, k) + + +def mel_weight_matrix( + num_mel_bins: int, + dft_length: int, + sample_rate: int, + lower_edge_hertz: float = 125.0, + upper_edge_hertz: float = 3000.0, +): + lower_edge_hertz = np.array(lower_edge_hertz) + upper_edge_hertz = np.array(upper_edge_hertz) + zero = np.array(0.0) + hz_to_mel = lambda f: 2595 * np.log10(1 + f / 700) + nyquist_hz = sample_rate / 2 + linear_freqs = np.linspace(0, nyquist_hz, dft_length, dtype=np.float32)[1:] + spec_bin_mels = hz_to_mel(linear_freqs)[..., None] + mel_edges = np.linspace( + hz_to_mel(lower_edge_hertz), + hz_to_mel(upper_edge_hertz), + num_mel_bins + 2, + dtype=np.float32, + ) + mel_edges = np.stack([mel_edges[i : i + 3] for i in range(num_mel_bins)]) + lower_edge_mel, center_mel, upper_edge_mel = [ + t.reshape((1, num_mel_bins)) for t in np.split(mel_edges, 3, axis=1) + ] + lower_slopes = (spec_bin_mels - lower_edge_mel) / (center_mel - lower_edge_mel) + upper_slopes = (upper_edge_mel - spec_bin_mels) / (upper_edge_mel - center_mel) + mel_weights = np.maximum(zero, np.minimum(lower_slopes, 
upper_slopes)) + return np.pad(mel_weights, [[1, 0], [0, 0]]) diff --git a/ivy/functional/backends/numpy/experimental/losses.py b/ivy/functional/backends/numpy/experimental/losses.py index 0a33b6bcf1dad..76a266109e785 100644 --- a/ivy/functional/backends/numpy/experimental/losses.py +++ b/ivy/functional/backends/numpy/experimental/losses.py @@ -70,3 +70,26 @@ def soft_margin_loss( return np.sum(loss) else: return loss + + +@with_unsupported_dtypes({"1.25.2 and below": ("bool", "bfloat16")}, backend_version) +@_scalar_output_to_0d_array +def kl_div( + input: np.ndarray, + target: np.ndarray, + /, + *, + reduction: Optional[str] = "mean", +) -> np.ndarray: + size = np.shape(input) + + loss = np.sum(input * np.log(input / target), axis=-1) + + if reduction == "mean": + loss = np.mean(loss) + elif reduction == "sum": + loss = np.sum(loss) + elif reduction == "batchmean": + loss = np.divide(np.sum(loss), size[0]) + + return loss diff --git a/ivy/functional/backends/numpy/experimental/manipulation.py b/ivy/functional/backends/numpy/experimental/manipulation.py index 53436987e2187..fe6971128e20c 100644 --- a/ivy/functional/backends/numpy/experimental/manipulation.py +++ b/ivy/functional/backends/numpy/experimental/manipulation.py @@ -477,3 +477,9 @@ def fill_diagonal( ) -> np.ndarray: np.fill_diagonal(a, v, wrap=wrap) return a + + +def column_stack( + arrays: Sequence[np.ndarray], /, *, out: Optional[np.ndarray] = None +) -> np.ndarray: + return np.column_stack(arrays) diff --git a/ivy/functional/backends/numpy/experimental/statistical.py b/ivy/functional/backends/numpy/experimental/statistical.py index 12c3565148ec7..667c6106c61a3 100644 --- a/ivy/functional/backends/numpy/experimental/statistical.py +++ b/ivy/functional/backends/numpy/experimental/statistical.py @@ -167,6 +167,31 @@ def nanmean( nanmean.support_native_out = True +def nanprod( + a: np.ndarray, + /, + *, + axis: Optional[Union[int, Sequence[int]]] = None, + dtype: Optional[np.dtype] = None, + keepdims: Optional[bool] = False, + out: Optional[np.ndarray] = None, + initial: Optional[Union[int, float, complex]] = None, + where: Optional[np.ndarray] = None, +) -> np.ndarray: + dtype = ivy.as_native_dtype(dtype) + if dtype is None: + dtype = _infer_dtype(a.dtype) + axis = tuple(axis) if isinstance(axis, list) else axis + return np.asarray( + np.nanprod( + a=a, axis=axis, dtype=dtype, keepdims=keepdims, out=out, initial=initial + ) + ) + + +nanprod.support_native_out = True + + def _validate_quantile(q): if isinstance(q, float): q = np.asarray(q) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index dabf30aec9d22..42fa990c6177b 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -97,7 +97,7 @@ def sigmoid( @with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("float16",)}}, backend_version + {"2.5.1 and below": {"cpu": ("float16", "bfloat16")}}, backend_version ) def softmax( x: paddle.Tensor, @@ -108,12 +108,13 @@ def softmax( ) -> paddle.Tensor: if axis is None: axis = -1 - exp_x = paddle_backend.exp( - paddle_backend.subtract(x, paddle_backend.max(x, axis=axis, keepdims=True)) - ) - return paddle_backend.divide( - exp_x, paddle_backend.sum(exp_x, axis=axis, keepdims=True) - ) + + if paddle.is_complex(x): + amax = paddle_backend.max(x, axis=axis, keepdims=True) + else: + amax = paddle.max(x, axis, keepdim=True) + exp_x = paddle_backend.exp(paddle.subtract(x, amax)) + return paddle.divide(exp_x, 
paddle.sum(exp_x, axis=axis, keepdim=True)) def softplus( diff --git a/ivy/functional/backends/paddle/creation.py b/ivy/functional/backends/paddle/creation.py index 7b01cae902d53..ffbb017068a8f 100644 --- a/ivy/functional/backends/paddle/creation.py +++ b/ivy/functional/backends/paddle/creation.py @@ -13,13 +13,13 @@ with_unsupported_device_and_dtypes, ) from ivy.functional.ivy.creation import ( - asarray_to_native_arrays_and_back, - asarray_infer_device, - asarray_handle_nestable, - asarray_infer_dtype, + _asarray_to_native_arrays_and_back, + _asarray_infer_device, + _asarray_handle_nestable, + _asarray_infer_dtype, NestedSequence, SupportsBufferProtocol, - asarray_inputs_to_native_shapes, + _asarray_inputs_to_native_shapes, _remove_np_bfloat16, ) from . import backend_version @@ -64,11 +64,11 @@ def arange( return paddle.arange(start, stop, step).cast(dtype) -@asarray_to_native_arrays_and_back -@asarray_infer_device -@asarray_handle_nestable -@asarray_inputs_to_native_shapes -@asarray_infer_dtype +@_asarray_to_native_arrays_and_back +@_asarray_infer_device +@_asarray_handle_nestable +@_asarray_inputs_to_native_shapes +@_asarray_infer_dtype def asarray( obj: Union[ paddle.Tensor, diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py index d5dda72b4a901..09210b97a29f0 100644 --- a/ivy/functional/backends/paddle/elementwise.py +++ b/ivy/functional/backends/paddle/elementwise.py @@ -10,6 +10,7 @@ # local from . import backend_version +from ...ivy.elementwise import _complex_to_inf def _elementwise_helper(x1, x2): @@ -800,13 +801,18 @@ def square( {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version ) def pow( - x1: Union[float, paddle.Tensor], - x2: Union[float, paddle.Tensor], + x1: paddle.Tensor, + x2: Union[int, float, paddle.Tensor], /, *, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: x1, x2, ret_dtype = _elementwise_helper(x1, x2) + if ivy.is_complex_dtype(x1) and ivy.any(ivy.isinf(x2)): + inf_indices = paddle.nonzero(paddle.isinf(x2)) + ret = paddle.pow(x1, x2) + ret[inf_indices] = _complex_to_inf(ret[inf_indices]) + return ret if x1.dtype in [ paddle.int8, paddle.int16, diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 81bc6bdf25cb7..33f1d80bbf8ef 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -1,5 +1,5 @@ # global -from typing import Optional, Union +from typing import Optional, Union, Literal import paddle import paddle.nn.functional as F @@ -10,9 +10,16 @@ @with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("float16",)}}, backend_version + {"2.5.1 and below": {"cpu": ("float16", "bfloat16")}}, backend_version ) -def logit(x: paddle.Tensor, /, *, eps: Optional[float] = None, out=None): +def logit( + x: paddle.Tensor, + /, + *, + eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out=None, +): if x.dtype in [paddle.float32, paddle.float64]: return paddle.logit(x, eps) if eps is None: @@ -55,15 +62,18 @@ def relu6(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle return F.relu6(x.cast("float32")).cast(x.dtype) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version +) def logsigmoid( - input: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None + input: paddle.Tensor, /, *, complex_mode="jax", out: 
Optional[paddle.Tensor] = None ) -> paddle.Tensor: if input.dtype in [paddle.float32, paddle.float64]: return F.log_sigmoid(input) if paddle.is_complex(input): return paddle_backend.log( paddle_backend.divide( - 1.0, (paddle_backend.add(1.0, paddle_backend.exp(input))) + 1.0, (paddle_backend.add(1.0, paddle_backend.exp(-input))) ) ) return F.log_sigmoid(input.cast("float32")).cast(input.dtype) diff --git a/ivy/functional/backends/paddle/experimental/creation.py b/ivy/functional/backends/paddle/experimental/creation.py index 04d15e026a29d..1de567c08e0bc 100644 --- a/ivy/functional/backends/paddle/experimental/creation.py +++ b/ivy/functional/backends/paddle/experimental/creation.py @@ -212,3 +212,21 @@ def trilu( if upper: return paddle.triu(x=x, diagonal=k) return paddle.tril(x=x, diagonal=k) + + +def mel_weight_matrix( + num_mel_bins: int, + dft_length: int, + sample_rate: int, + lower_edge_hertz: float = 0.0, + upper_edge_hertz: float = 3000.0, +): + n_fft = (dft_length - 1) * 2 + mel_mat = paddle.audio.functional.compute_fbank_matrix( + sample_rate, + n_fft, + num_mel_bins, + lower_edge_hertz, + upper_edge_hertz, + ) + return paddle.transpose(mel_mat, (1, 0)) diff --git a/ivy/functional/backends/paddle/experimental/losses.py b/ivy/functional/backends/paddle/experimental/losses.py index b0be8bda8d502..a582b43e15fb2 100644 --- a/ivy/functional/backends/paddle/experimental/losses.py +++ b/ivy/functional/backends/paddle/experimental/losses.py @@ -120,3 +120,29 @@ def soft_margin_loss( reduction: Optional[str] = "mean", ) -> paddle.Tensor: return paddle.nn.functional.soft_margin_loss(input, label, reduction=reduction) + + +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "bfloat16", + "float16", + "int8", + "int16", + "int32", + "int64", + "uint8", + "complex64", + "complex128", + "bool", + ) + } + }, + backend_version, +) +def kl_div( + input: paddle.Tensor, target: paddle.Tensor, /, *, reduction: Optional[str] = "mean" +) -> paddle.Tensor: + loss = F.kl_div(input, target, reduction=reduction) + return loss diff --git a/ivy/functional/backends/paddle/experimental/manipulation.py b/ivy/functional/backends/paddle/experimental/manipulation.py index 8b2a81dcb4d01..88099753ec83b 100644 --- a/ivy/functional/backends/paddle/experimental/manipulation.py +++ b/ivy/functional/backends/paddle/experimental/manipulation.py @@ -215,6 +215,10 @@ def vstack( return ivy.stack(arrays, axis=0) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("int16", "bfloat16")}}, + backend_version, +) def hstack( arrays: Sequence[paddle.Tensor], /, diff --git a/ivy/functional/backends/paddle/experimental/statistical.py b/ivy/functional/backends/paddle/experimental/statistical.py index 0cc4e40739733..336fe507f94bc 100644 --- a/ivy/functional/backends/paddle/experimental/statistical.py +++ b/ivy/functional/backends/paddle/experimental/statistical.py @@ -6,7 +6,7 @@ from copy import deepcopy # local -from ivy.func_wrapper import with_unsupported_device_and_dtypes +from ivy.func_wrapper import with_unsupported_device_and_dtypes, with_supported_dtypes from ivy.utils.exceptions import IvyNotImplementedException from . 
import backend_version @@ -96,6 +96,13 @@ def nanmean( return ret.astype(ret_dtype) +def _infer_dtype(dtype: paddle.dtype): + default_dtype = ivy.infer_default_dtype(dtype) + if ivy.dtype_bits(dtype) < ivy.dtype_bits(default_dtype): + return default_dtype + return dtype + + def _validate_quantile(q): if isinstance(q, float): q = paddle.to_tensor(q) @@ -109,6 +116,42 @@ def _validate_quantile(q): return True +@with_supported_dtypes( + {"2.5.1 and below": ("float64", "float32")}, + backend_version, +) +def nanprod( + a: paddle.Tensor, + /, + *, + axis: Optional[Union[int, Tuple[int]]] = None, + keepdims: Optional[bool] = False, + dtype: Optional[paddle.dtype] = None, + out: Optional[paddle.Tensor] = None, + initial: Optional[Union[int, float, complex]] = None, + where: Optional[paddle.Tensor] = None, +) -> paddle.Tensor: + dtype = ivy.as_native_dtype(dtype) + if dtype is None: + dtype = _infer_dtype(a.dtype) + a = a.cast(dtype) + if initial is None: + initial = 1 + if a.dtype not in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]: + a = paddle.nan_to_num(a.cast("float64"), nan=1.0) + ret = paddle.prod(a, axis=axis, keepdim=keepdims) * initial + else: + a = paddle.nan_to_num(a, nan=1.0) + ret = paddle.prod(a, axis=axis, keepdim=keepdims) * initial + + if isinstance(axis, Sequence): + if len(axis) == a.ndim: + axis = None + if (a.ndim == 1 or axis is None) and not keepdims: + ret = ret.squeeze() + return ret.cast(dtype) + + def _to_positive_axis(axis, ndim): if not isinstance(axis, (list, tuple)): axis = [axis] diff --git a/ivy/functional/backends/paddle/statistical.py b/ivy/functional/backends/paddle/statistical.py index 4d0297b1b7aac..0f294cde0e285 100644 --- a/ivy/functional/backends/paddle/statistical.py +++ b/ivy/functional/backends/paddle/statistical.py @@ -77,9 +77,17 @@ def max( paddle.bool, ]: if paddle.is_complex(x): - real_part = paddle.amax(x.real(), axis=axis, keepdim=keepdims) - imag_part = paddle.amax(x.imag(), axis=axis, keepdim=keepdims) - ret = paddle.complex(real_part, imag_part) + const = paddle.to_tensor(1j, dtype=x.dtype) + real_max = paddle.max(x.real(), axis=axis, keepdim=keepdims) + imag = paddle.where( + x.real() == real_max, x.imag(), paddle.full_like(x.imag(), -1e10) + ) + # we consider the number with the biggest real and imag part + img_max = paddle.max(imag, axis=axis, keepdim=keepdims) + img_max = paddle.cast(img_max, x.dtype) + return paddle.add( + paddle.cast(real_max, x.dtype), paddle.multiply(img_max, const) + ) else: ret = paddle.amax(x.cast("float32"), axis=axis, keepdim=keepdims) else: diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index 9fc9325977165..6fb69cc2c7fb0 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -15,6 +15,7 @@ import ivy from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes from . 
import backend_version +import ivy.functional.backends.tensorflow as tf_backend def gelu( @@ -51,10 +52,18 @@ def sigmoid(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: return tf.nn.sigmoid(x) -@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) def softmax( x: Tensor, /, *, axis: Optional[int] = None, out: Optional[Tensor] = None ) -> Tensor: + if axis is None: + axis = -1 + dtype = x.dtype + if "complex" in str(dtype): + amax = tf_backend.max(x, axis=axis, keepdims=True) + normalized = tf.exp(tf.subtract(x, amax)) + return tf.divide( + normalized, tf.reduce_sum(normalized, axis=axis, keepdims=True) + ) return tf.nn.softmax(x, axis) diff --git a/ivy/functional/backends/tensorflow/creation.py b/ivy/functional/backends/tensorflow/creation.py index 25ce085eb766d..05f3487191575 100644 --- a/ivy/functional/backends/tensorflow/creation.py +++ b/ivy/functional/backends/tensorflow/creation.py @@ -9,13 +9,13 @@ import ivy from ivy.func_wrapper import with_unsupported_dtypes from ivy.functional.ivy.creation import ( - asarray_to_native_arrays_and_back, - asarray_infer_device, - asarray_infer_dtype, - asarray_handle_nestable, + _asarray_to_native_arrays_and_back, + _asarray_infer_device, + _asarray_infer_dtype, + _asarray_handle_nestable, NestedSequence, SupportsBufferProtocol, - asarray_inputs_to_native_shapes, + _asarray_inputs_to_native_shapes, ) from . import backend_version @@ -65,11 +65,11 @@ def arange( return tf.range(start, stop, delta=step, dtype=dtype) -@asarray_to_native_arrays_and_back -@asarray_infer_device -@asarray_handle_nestable -@asarray_inputs_to_native_shapes -@asarray_infer_dtype +@_asarray_to_native_arrays_and_back +@_asarray_infer_device +@_asarray_handle_nestable +@_asarray_inputs_to_native_shapes +@_asarray_infer_dtype def asarray( obj: Union[ tf.Tensor, @@ -93,6 +93,9 @@ def asarray( try: ret = tf.convert_to_tensor(obj, dtype) except (TypeError, ValueError): + obj = ( + obj if isinstance(obj, tf.Tensor) else tf.convert_to_tensor(obj, tf.float64) + ) ret = tf.cast(obj, dtype) return tf.identity(ret) if copy else ret diff --git a/ivy/functional/backends/tensorflow/elementwise.py b/ivy/functional/backends/tensorflow/elementwise.py index 1e060e8e8cae3..751a42d4a24e6 100644 --- a/ivy/functional/backends/tensorflow/elementwise.py +++ b/ivy/functional/backends/tensorflow/elementwise.py @@ -605,8 +605,8 @@ def positive( backend_version, ) def pow( - x1: Union[float, tf.Tensor, tf.Variable], - x2: Union[float, tf.Tensor, tf.Variable], + x1: Union[tf.Tensor, tf.Variable], + x2: Union[int, float, tf.Tensor, tf.Variable], /, *, out: Optional[Union[tf.Tensor, tf.Variable]] = None, @@ -620,7 +620,14 @@ def pow( if x2.dtype.is_unsigned: x2 = tf.cast(x2, tf.float64) return tf.cast(tf.experimental.numpy.power(x1, x2), promoted_type) - return tf.experimental.numpy.power(x1, x2) + orig_x1_dtype = None + if ivy.is_int_dtype(x1) and ivy.any(x2 < 0): + orig_x1_dtype = x1.dtype + x1 = tf.cast(x1, tf.float32) + ret = tf.experimental.numpy.power(x1, x2) + if orig_x1_dtype is not None: + return tf.cast(ret, orig_x1_dtype) + return ret @with_unsupported_dtypes({"2.13.0 and below": ("bfloat16", "complex")}, backend_version) diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index f798d7b907d98..709ae09314cd1 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -1,4 +1,4 @@ -from 
typing import Optional, Union +from typing import Optional, Union, Literal # global import tensorflow as tf @@ -15,6 +15,7 @@ def logit( /, *, eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[Tensor] = None, ) -> Tensor: x_dtype = x.dtype @@ -43,7 +44,11 @@ def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: @with_supported_dtypes({"2.13.0 and below": ("float",)}, backend_version) -def logsigmoid(input: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: +def logsigmoid( + input: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None +) -> Tensor: + if input.dtype in [tf.complex64, tf.complex128]: + return tf.math.log(tf.nn.sigmoid(input)) return tf.math.log_sigmoid(input) diff --git a/ivy/functional/backends/tensorflow/experimental/creation.py b/ivy/functional/backends/tensorflow/experimental/creation.py index 96cd0c15b84be..6546675630211 100644 --- a/ivy/functional/backends/tensorflow/experimental/creation.py +++ b/ivy/functional/backends/tensorflow/experimental/creation.py @@ -138,3 +138,19 @@ def trilu( if upper: return tf.experimental.numpy.triu(x, k) return tf.experimental.numpy.tril(x, k) + + +def mel_weight_matrix( + num_mel_bins: int, + dft_length: int, + sample_rate: int, + lower_edge_hertz: float = 125.0, + upper_edge_hertz: float = 3000.0, +): + return tf.signal.linear_to_mel_weight_matrix( + num_mel_bins, + dft_length, + sample_rate, + lower_edge_hertz=lower_edge_hertz, + upper_edge_hertz=upper_edge_hertz, + ) diff --git a/ivy/functional/backends/tensorflow/experimental/losses.py b/ivy/functional/backends/tensorflow/experimental/losses.py index fdd493e40b8e8..e0c2da76b9958 100644 --- a/ivy/functional/backends/tensorflow/experimental/losses.py +++ b/ivy/functional/backends/tensorflow/experimental/losses.py @@ -62,3 +62,25 @@ def soft_margin_loss( return tf.reduce_mean(loss) else: return loss + + +@with_unsupported_dtypes({"2.13.0 and below": ("bool", "bfloat16")}, backend_version) +def kl_div( + input: tf.Tensor, + target: tf.Tensor, + /, + *, + reduction: Optional[str] = "mean", +) -> tf.Tensor: + size = tf.shape(input) + + loss = tf.reduce_sum(input * tf.math.log(input / target), axis=-1) + + if reduction == "mean": + loss = tf.math.reduce_mean(loss) + elif reduction == "sum": + loss = tf.math.reduce_sum(loss) + elif reduction == "batchmean": + loss = tf.math.reduce_sum(loss) / tf.cast(size[0], dtype=tf.float32) + + return loss diff --git a/ivy/functional/backends/tensorflow/experimental/norms.py b/ivy/functional/backends/tensorflow/experimental/norms.py index 5fc38b60283af..bd3a1abb624fa 100644 --- a/ivy/functional/backends/tensorflow/experimental/norms.py +++ b/ivy/functional/backends/tensorflow/experimental/norms.py @@ -4,6 +4,7 @@ from . 
import backend_version +@with_unsupported_dtypes({"2.13.0 and below": "uint8"}, backend_version) def l1_normalize( x: Union[tf.Tensor, tf.Variable], /, diff --git a/ivy/functional/backends/tensorflow/experimental/statistical.py b/ivy/functional/backends/tensorflow/experimental/statistical.py index 9cd91437a5f4a..7c7d45204ff3f 100644 --- a/ivy/functional/backends/tensorflow/experimental/statistical.py +++ b/ivy/functional/backends/tensorflow/experimental/statistical.py @@ -115,6 +115,37 @@ def nanmean( return tf.experimental.numpy.nanmean(a, axis=axis, keepdims=keepdims, dtype=dtype) +def _infer_dtype(dtype: tf.DType): + default_dtype = ivy.infer_default_dtype(dtype) + if ivy.dtype_bits(dtype) < ivy.dtype_bits(default_dtype): + return default_dtype + return dtype + + +def nanprod( + a: Union[tf.Tensor, tf.Variable], + /, + *, + axis: Optional[Union[int, Sequence[int]]] = None, + dtype: Optional[tf.DType] = None, + keepdims: Optional[bool] = False, + out: Optional[Union[tf.Tensor, tf.Variable]] = None, + initial: Optional[Union[int, float, complex]] = None, + where: Optional[Union[tf.Tensor, tf.Variable]] = None, +) -> Union[tf.Tensor, tf.Variable]: + np_math_ops.enable_numpy_methods_on_tensor() + dtype = ivy.as_native_dtype(dtype) + if dtype is None: + dtype = _infer_dtype(a.dtype) + if initial is None: + initial = 1 + axis = tuple(axis) if isinstance(axis, list) else axis + return ( + tf.experimental.numpy.nanprod(a, axis=axis, keepdims=keepdims, dtype=dtype) + * initial + ) + + def _validate_quantile(q): if tf.experimental.numpy.ndim(q) == 1 and tf.size(q) < 10: for i in range(tf.size(q)): diff --git a/ivy/functional/backends/tensorflow/general.py b/ivy/functional/backends/tensorflow/general.py index ba239549e3fdd..cb24872314909 100644 --- a/ivy/functional/backends/tensorflow/general.py +++ b/ivy/functional/backends/tensorflow/general.py @@ -37,7 +37,7 @@ def array_equal( /, ) -> bool: x0, x1 = ivy.promote_types_of_inputs(x0, x1) - return bool((tf.experimental.numpy.array_equal(x0, x1))) + return bool(tf.experimental.numpy.array_equal(x0, x1)) def container_types(): diff --git a/ivy/functional/backends/tensorflow/statistical.py b/ivy/functional/backends/tensorflow/statistical.py index 53fc8b5bb57e6..874be24f1dad1 100644 --- a/ivy/functional/backends/tensorflow/statistical.py +++ b/ivy/functional/backends/tensorflow/statistical.py @@ -26,7 +26,6 @@ def min( return tf.math.reduce_min(x, axis=axis, keepdims=keepdims) -@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) def max( x: Union[tf.Tensor, tf.Variable], /, @@ -35,6 +34,20 @@ def max( keepdims: bool = False, out: Optional[Union[tf.Tensor, tf.Variable]] = None, ) -> Union[tf.Tensor, tf.Variable]: + if "complex" in str(x.dtype): + real = tf.math.real(x) + img = tf.math.imag(x) + const = tf.constant(1j, dtype=x.dtype) + real_max = tf.reduce_max(real, axis=axis, keepdims=keepdims) + imag = tf.where( + real == real_max, + img, + tf.experimental.numpy.finfo(img.dtype).min, + ) + # we consider the number with the biggest real and imag part + img_max = tf.reduce_max(imag, axis=axis, keepdims=keepdims) + img_max = tf.cast(img_max, x.dtype) + return tf.add(tf.cast(real_max, x.dtype), tf.multiply(img_max, const)) axis = tuple(axis) if isinstance(axis, list) else axis return tf.math.reduce_max(x, axis=axis, keepdims=keepdims) diff --git a/ivy/functional/backends/torch/__init__.py b/ivy/functional/backends/torch/__init__.py index 11459cc8a797b..1dd6577901944 100644 --- a/ivy/functional/backends/torch/__init__.py +++ 
b/ivy/functional/backends/torch/__init__.py @@ -8,6 +8,10 @@ backend_version = {"version": torch.__version__.split("+")[0]} +# Registering ivy.Array as trackable submodule +if hasattr(torch, "_dynamo"): + torch._dynamo.config.traceable_tensor_subclasses = (ivy.Array,) + # noinspection PyUnresolvedReferences if not ivy.is_local(): _module_in_memory = sys.modules[__name__] diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py index 53aed823a1593..0e01beb0f7131 100644 --- a/ivy/functional/backends/torch/activations.py +++ b/ivy/functional/backends/torch/activations.py @@ -15,6 +15,7 @@ import ivy from ivy.func_wrapper import with_unsupported_dtypes from . import backend_version +import ivy.functional.backends.torch as torch_backend @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) @@ -62,7 +63,7 @@ def sigmoid(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch. sigmoid.support_native_out = True -@with_unsupported_dtypes({"2.0.1 and below": ("complex", "float16")}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("bfloat16", "float16")}, backend_version) def softmax( x: torch.Tensor, /, @@ -72,6 +73,10 @@ def softmax( ) -> torch.Tensor: if axis is None: axis = -1 + if torch.is_complex(x): + amax = torch_backend.max(x, axis=axis, keepdims=True) + exp_x = torch.exp(torch.subtract(x, amax)) + return torch.divide(exp_x, torch.sum(exp_x, dim=axis, keepdim=True)) return torch.nn.functional.softmax(x, axis) diff --git a/ivy/functional/backends/torch/creation.py b/ivy/functional/backends/torch/creation.py index 7c5a41482e44c..80a063961e0e1 100644 --- a/ivy/functional/backends/torch/creation.py +++ b/ivy/functional/backends/torch/creation.py @@ -13,13 +13,13 @@ with_unsupported_device_and_dtypes, ) from ivy.functional.ivy.creation import ( - asarray_to_native_arrays_and_back, - asarray_infer_device, - asarray_infer_dtype, - asarray_handle_nestable, + _asarray_to_native_arrays_and_back, + _asarray_infer_device, + _asarray_infer_dtype, + _asarray_handle_nestable, NestedSequence, SupportsBufferProtocol, - asarray_inputs_to_native_shapes, + _asarray_inputs_to_native_shapes, _remove_np_bfloat16, ) from . 
import backend_version @@ -96,11 +96,11 @@ def _stack_tensors(x, dtype): @with_unsupported_dtypes({"2.0.1 and below": ("bfloat16",)}, backend_version) -@asarray_to_native_arrays_and_back -@asarray_infer_device -@asarray_handle_nestable -@asarray_inputs_to_native_shapes -@asarray_infer_dtype +@_asarray_to_native_arrays_and_back +@_asarray_infer_device +@_asarray_handle_nestable +@_asarray_inputs_to_native_shapes +@_asarray_infer_dtype def asarray( obj: Union[ torch.Tensor, diff --git a/ivy/functional/backends/torch/device.py b/ivy/functional/backends/torch/device.py index 0a0832e371615..ebf677c06d728 100644 --- a/ivy/functional/backends/torch/device.py +++ b/ivy/functional/backends/torch/device.py @@ -74,7 +74,7 @@ def as_native_dev( ) -> Optional[torch.device]: if not isinstance(device, str): return device - if device == "mps": + if torch.backends.mps.is_available(): return torch.device(ivy.Device(device).replace("gpu", "mps")) return torch.device(ivy.Device(device).replace("gpu", "cuda")) diff --git a/ivy/functional/backends/torch/elementwise.py b/ivy/functional/backends/torch/elementwise.py index a0a3f5c11e053..4d7c767fd5b8c 100644 --- a/ivy/functional/backends/torch/elementwise.py +++ b/ivy/functional/backends/torch/elementwise.py @@ -583,8 +583,8 @@ def square(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.T @handle_numpy_arrays_in_specific_backend def pow( - x1: Union[float, torch.Tensor], - x2: Union[float, torch.Tensor], + x1: torch.Tensor, + x2: Union[int, float, torch.Tensor], /, *, out: Optional[torch.Tensor] = None, diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 98d72ac526c45..5969804298c59 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, Literal # global import torch @@ -16,6 +16,7 @@ def logit( /, *, eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.logit(x, eps=eps, out=out) @@ -39,8 +40,10 @@ def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Te @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def logsigmoid( - input: torch.Tensor, /, *, out: Optional[torch.Tensor] = None + input: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None ) -> torch.Tensor: + if torch.is_complex(input): + return torch.log(torch.sigmoid(input)) return torch.nn.functional.logsigmoid(input) diff --git a/ivy/functional/backends/torch/experimental/creation.py b/ivy/functional/backends/torch/experimental/creation.py index a1b94a8a894cf..ff46646abc6b5 100644 --- a/ivy/functional/backends/torch/experimental/creation.py +++ b/ivy/functional/backends/torch/experimental/creation.py @@ -210,3 +210,35 @@ def trilu( trilu.support_native_out = True + + +def mel_weight_matrix( + num_mel_bins: int, + dft_length: int, + sample_rate: int, + lower_edge_hertz: float = 125.0, + upper_edge_hertz: float = 3000.0, +): + # transform the inputs to tensors + lower_edge_hertz = torch.tensor(lower_edge_hertz) + upper_edge_hertz = torch.tensor(upper_edge_hertz) + zero = torch.tensor(0.0) + # mel transform lambda function + hz_to_mel = lambda f: 2595 * torch.log10(1 + f / 700) + nyquist_hz = sample_rate / 2 + # define a range of frequencies in HZ + linear_freqs = 
torch.linspace(0, nyquist_hz, dft_length)[1:] + # transform the frequencies from HZ to mels + spec_bin_mels = hz_to_mel(linear_freqs).unsqueeze(1) + mel_edges = torch.linspace( + hz_to_mel(lower_edge_hertz), hz_to_mel(upper_edge_hertz), num_mel_bins + 2 + ) + # create overlapping frames of size 3 + mel_edges = mel_edges.unfold(0, size=3, step=1) + lower_edge_mel, center_mel, upper_edge_mel = [ + t.reshape((1, num_mel_bins)) for t in mel_edges.split(1, dim=1) + ] + lower_slopes = (spec_bin_mels - lower_edge_mel) / (center_mel - lower_edge_mel) + upper_slopes = (upper_edge_mel - spec_bin_mels) / (upper_edge_mel - center_mel) + mel_weights = torch.maximum(zero, torch.minimum(lower_slopes, upper_slopes)) + return torch.nn.functional.pad(mel_weights, (0, 0, 1, 0)) diff --git a/ivy/functional/backends/torch/experimental/losses.py b/ivy/functional/backends/torch/experimental/losses.py index 2c24c6afd01e2..a37365922a1b6 100644 --- a/ivy/functional/backends/torch/experimental/losses.py +++ b/ivy/functional/backends/torch/experimental/losses.py @@ -97,3 +97,32 @@ def soft_margin_loss( target, reduction=reduction, ) + + +@with_unsupported_dtypes( + { + "2.0.1 and below": ( + "float16", + "uint8", + "int8", + "int16", + "int32", + "int64", + "bool", + ) + }, + backend_version, +) +def kl_div( + input: torch.Tensor, + target: torch.Tensor, + /, + *, + reduction: Optional[str] = "mean", +) -> torch.Tensor: + loss = torch.nn.functional.kl_div( + input, + target, + reduction=reduction, + ) + return loss diff --git a/ivy/functional/backends/torch/experimental/manipulation.py b/ivy/functional/backends/torch/experimental/manipulation.py index 16c7cc3e5df76..2b6d0aa525429 100644 --- a/ivy/functional/backends/torch/experimental/manipulation.py +++ b/ivy/functional/backends/torch/experimental/manipulation.py @@ -462,3 +462,9 @@ def fill_diagonal( a = torch.where(w, v, a) a = torch.reshape(a, shape) return a + + +def column_stack( + arrays: Sequence[torch.Tensor], /, *, out: Optional[torch.Tensor] = None +) -> torch.Tensor: + return torch.column_stack(arrays) diff --git a/ivy/functional/backends/torch/experimental/statistical.py b/ivy/functional/backends/torch/experimental/statistical.py index 2cdc576eec124..b2aee02bcd414 100644 --- a/ivy/functional/backends/torch/experimental/statistical.py +++ b/ivy/functional/backends/torch/experimental/statistical.py @@ -185,6 +185,42 @@ def nanmean( nanmean.support_native_out = True +def nanprod( + a: torch.Tensor, + /, + *, + axis: Optional[Union[int, Sequence[int]]] = None, + dtype: Optional[torch.dtype] = None, + keepdims: Optional[bool] = False, + out: Optional[torch.Tensor] = None, + initial: Optional[Union[int, float, complex, ivy.Container]] = None, + where: Optional[torch.Tensor] = None, +) -> torch.Tensor: + dtype = ivy.as_native_dtype(dtype) + if dtype is None: + dtype = _infer_dtype(a.dtype) + if initial is None: + initial = 1 + a = a.type(dtype) + a = torch.nan_to_num(a, nan=1.0) + if a.dtype == torch.float16: + a = a.type(torch.float32) + if axis == (): + return a.type(dtype) + if axis is None: + return torch.prod(input=a, out=out).type(dtype) * initial + if isinstance(axis, tuple) or isinstance(axis, list): + for i in axis: + a = torch.prod(a, dim=i, keepdim=keepdims, out=out).type(dtype) + if a.dtype == torch.float16: + a = a.type(torch.float32) + return a.type(dtype) * initial + return torch.prod(a, dim=axis, keepdim=keepdims, out=out).type(dtype) * initial + + +nanprod.support_native_out = True + + def _validate_quantile(q): if isinstance(q, float): q = 
torch.as_tensor(q) diff --git a/ivy/functional/backends/torch/statistical.py b/ivy/functional/backends/torch/statistical.py index 04694a0bdca3f..1e31b8854cbb2 100644 --- a/ivy/functional/backends/torch/statistical.py +++ b/ivy/functional/backends/torch/statistical.py @@ -36,7 +36,6 @@ def min( min.support_native_out = True -@with_unsupported_dtypes({"2.0.1 and below": ("complex",)}, backend_version) def max( x: torch.Tensor, /, @@ -45,6 +44,15 @@ def max( keepdims: bool = False, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if torch.is_complex(x): + const = torch.tensor(1j, device=x.device, dtype=x.dtype) + real_max = torch.max(x.real, dim=axis, keepdim=keepdims).values + min_val = torch.finfo(x.real.dtype).min + imag = torch.where(x.real == real_max, x.imag, min_val) + # we consider the number with the biggest real and imag part + img_max = torch.max(imag, dim=axis, keepdim=keepdims).values + img_max = img_max.to(x.dtype) + return torch.add(real_max.to(x.dtype), torch.multiply(img_max, const)) if axis == (): if ivy.exists(out): return ivy.inplace_update(out, x) diff --git a/ivy/functional/frontends/jax/func_wrapper.py b/ivy/functional/frontends/jax/func_wrapper.py index cace90dd4fe33..ba89504f7243c 100644 --- a/ivy/functional/frontends/jax/func_wrapper.py +++ b/ivy/functional/frontends/jax/func_wrapper.py @@ -111,10 +111,10 @@ def _inputs_to_ivy_arrays_jax(*args, **kwargs): has_out = True # convert all arrays in the inputs to ivy.Array instances new_args = ivy.nested_map( - args, _to_ivy_array, include_derived={tuple: True}, shallow=False + args, _to_ivy_array, include_derived={"tuple": True}, shallow=False ) new_kwargs = ivy.nested_map( - kwargs, _to_ivy_array, include_derived={tuple: True}, shallow=False + kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False ) # add the original out argument back to the keyword arguments if has_out: @@ -153,10 +153,10 @@ def _outputs_to_frontend_arrays_jax(*args, **kwargs): return _from_ivy_array_to_jax_frontend_array_weak_type( ret, nested=True, - include_derived={tuple: True}, + include_derived={"tuple": True}, ) return _from_ivy_array_to_jax_frontend_array( - ret, nested=True, include_derived={tuple: True} + ret, nested=True, include_derived={"tuple": True} ) return _outputs_to_frontend_arrays_jax diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index 2f6ad3a036c69..8d10397702cec 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -186,7 +186,7 @@ def leaky_relu(x, negative_slope=0.01): @to_ivy_arrays_and_back def log_sigmoid(x): x = _type_conversion(x) - return ivy.negative(ivy.softplus(ivy.negative(x))).astype(x.dtype) + return ivy.logsigmoid(x, complex_mode="jax").astype(x.dtype) @to_ivy_arrays_and_back diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index 7a62b524d67e8..0ba0fa31f9c94 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -11,6 +11,13 @@ def fft(a, n=None, axis=-1, norm=None): return ivy.fft(a, axis, norm=norm, n=n) +@to_ivy_arrays_and_back +def fft2(a, s=None, axes=(-2, -1), norm=None): + if norm is None: + norm = "backward" + return ivy.array(ivy.fft2(a, s=s, dim=axes, norm=norm), dtype=ivy.dtype(a)) + + @to_ivy_arrays_and_back @with_unsupported_dtypes({"2.4.2 and below": ("float16", "bfloat16")}, "paddle") def fftshift(x, axes=None, 
name=None): @@ -27,3 +34,10 @@ def fftshift(x, axes=None, name=None): roll = ivy.roll(x, shifts, axis=axes) return roll + + +@to_ivy_arrays_and_back +def ifft(a, n=None, axis=-1, norm=None): + if norm is None: + norm = "backward" + return ivy.ifft(a, axis, norm=norm, n=n) diff --git a/ivy/functional/frontends/jax/random.py b/ivy/functional/frontends/jax/random.py index 449b6f35a5cc8..694b5f1004ff3 100644 --- a/ivy/functional/frontends/jax/random.py +++ b/ivy/functional/frontends/jax/random.py @@ -145,6 +145,24 @@ def dirichlet(key, alpha, shape=None, dtype="float32"): return ivy.dirichlet(alpha, size=shape, dtype=dtype, seed=seed) +@handle_jax_dtype +@to_ivy_arrays_and_back +@with_unsupported_dtypes( + {"0.4.14 and below": "uint32"}, + "jax", +) +def double_sided_maxwell(key, loc, scale, shape=(), dtype="float64"): + params_shapes = ivy.broadcast_shapes(ivy.shape(loc), ivy.shape(scale)) + if not shape: + shape = params_shapes + + shape = shape + params_shapes + maxwell_rvs = maxwell(key, shape=shape, dtype=dtype) + random_sign = rademacher(key, shape=shape, dtype=dtype) + + return random_sign * maxwell_rvs * scale + loc + + @handle_jax_dtype @to_ivy_arrays_and_back @with_unsupported_dtypes( @@ -247,6 +265,18 @@ def loggamma(key, a, shape=None, dtype="float64"): return ivy.log(ivy.gamma(a, 1.0, shape=shape, dtype=dtype, seed=seed)) +@handle_jax_dtype +@to_ivy_arrays_and_back +@with_unsupported_dtypes( + {"0.4.14 and below": ("float16", "bfloat16")}, + "jax", +) +def logistic(key, shape=(), dtype="float64"): + seed = _get_seed(key) + uniform_x = ivy.random_uniform(seed=seed, shape=shape, dtype=dtype) + return ivy.log(ivy.divide(uniform_x, ivy.subtract(1.0, uniform_x))) + + @handle_jax_dtype @to_ivy_arrays_and_back @with_unsupported_dtypes( @@ -258,13 +288,11 @@ def loggamma(key, a, shape=None, dtype="float64"): }, "jax", ) -def maxwell(key, shape=None, dtype="float64"): +def maxwell(key, shape, dtype="float64"): seed = _get_seed(key) - # generate uniform random numbers between 0 and 1 - z = ivy.random_uniform(seed=seed, shape=shape, dtype=dtype) - # applying inverse transform sampling - x = (z**2) * ivy.exp(-(z**2) / 2) - return x + shape = shape + (3,) + random_normal = ivy.random_normal(seed=seed, shape=shape, dtype=dtype) + return ivy.vector_norm(random_normal, axis=-1) @handle_jax_dtype @@ -377,7 +405,8 @@ def poisson(key, lam, shape=None, dtype=None): ) def rademacher(key, shape, dtype="int64"): seed = _get_seed(key) - b = ivy.bernoulli(ivy.array([0.5]), shape=shape, dtype="float32", seed=seed) + prob = ivy.full(shape, 0.5, dtype="float32") + b = ivy.bernoulli(prob, shape=shape, dtype="float32", seed=seed) b = ivy.astype(b, dtype) return 2 * b - 1 diff --git a/ivy/functional/frontends/mxnet/func_wrapper.py b/ivy/functional/frontends/mxnet/func_wrapper.py index 20b358488ea89..9beab9f42365d 100644 --- a/ivy/functional/frontends/mxnet/func_wrapper.py +++ b/ivy/functional/frontends/mxnet/func_wrapper.py @@ -81,10 +81,10 @@ def _inputs_to_ivy_arrays_mxnet(*args, **kwargs): """ # convert all arrays in the inputs to ivy.Array instances new_args = ivy.nested_map( - args, _to_ivy_array, include_derived={tuple: True}, shallow=False + args, _to_ivy_array, include_derived={"tuple": True}, shallow=False ) new_kwargs = ivy.nested_map( - kwargs, _to_ivy_array, include_derived={tuple: True}, shallow=False + kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False ) return fn(*new_args, **new_kwargs) @@ -105,7 +105,7 @@ def _outputs_to_frontend_arrays_mxnet(*args, **kwargs): ret = fn(*args, 
**kwargs) # convert all arrays in the return to `frontend.Tensorflow.tensor` instances - return ivy.nested_map(ret, _ivy_array_to_mxnet, include_derived={tuple: True}) + return ivy.nested_map(ret, _ivy_array_to_mxnet, include_derived={"tuple": True}) _outputs_to_frontend_arrays_mxnet.outputs_to_frontend_arrays = True return _outputs_to_frontend_arrays_mxnet diff --git a/ivy/functional/frontends/numpy/func_wrapper.py b/ivy/functional/frontends/numpy/func_wrapper.py index 495820f649e0f..a001f245ec3d5 100644 --- a/ivy/functional/frontends/numpy/func_wrapper.py +++ b/ivy/functional/frontends/numpy/func_wrapper.py @@ -194,7 +194,7 @@ def _set_order(args, order): ) if order in ["K", "A", None]: check_order = ivy.nested_map( - args, _check_C_order, include_derived={tuple: True}, shallow=False + args, _check_C_order, include_derived={"tuple": True}, shallow=False ) if all(v is None for v in check_order) or any( ivy.multi_index_nest(check_order, ivy.all_nested_indices(check_order)) @@ -447,9 +447,9 @@ def _inputs_to_ivy_arrays_np(*args, **kwargs): The return of the function, with ivy arrays passed in the arguments. """ # convert all arrays in the inputs to ivy.Array instances - ivy_args = ivy.nested_map(args, _to_ivy_array, include_derived={tuple: True}) + ivy_args = ivy.nested_map(args, _to_ivy_array, include_derived={"tuple": True}) ivy_kwargs = ivy.nested_map( - kwargs, _to_ivy_array, include_derived={tuple: True} + kwargs, _to_ivy_array, include_derived={"tuple": True} ) return fn(*ivy_args, **ivy_kwargs) @@ -509,10 +509,10 @@ def _outputs_to_frontend_arrays(*args, order="K", **kwargs): # convert all returned arrays to `ndarray` instances if order == "F": return ivy.nested_map( - ret, _ivy_to_numpy_order_F, include_derived={tuple: True} + ret, _ivy_to_numpy_order_F, include_derived={"tuple": True} ) else: - return ivy.nested_map(ret, _ivy_to_numpy, include_derived={tuple: True}) + return ivy.nested_map(ret, _ivy_to_numpy, include_derived={"tuple": True}) if "order" in list(inspect.signature(fn).parameters.keys()): contains_order = True diff --git a/ivy/functional/frontends/numpy/linalg/matrix_eigenvalues.py b/ivy/functional/frontends/numpy/linalg/matrix_eigenvalues.py index d0b98b00264c7..6c3a6e61f1265 100644 --- a/ivy/functional/frontends/numpy/linalg/matrix_eigenvalues.py +++ b/ivy/functional/frontends/numpy/linalg/matrix_eigenvalues.py @@ -17,6 +17,11 @@ def eigh(a, /, UPLO="L"): return ivy.eigh(a, UPLO=UPLO) +@to_ivy_arrays_and_back +def eigvals(a): + return ivy.eig(a)[0] + + @to_ivy_arrays_and_back @from_zero_dim_arrays_to_scalar def eigvalsh(a, /, UPLO="L"): diff --git a/ivy/functional/frontends/numpy/linalg/norms_and_other_numbers.py b/ivy/functional/frontends/numpy/linalg/norms_and_other_numbers.py index 505fee51f9ee6..b02b6fe921046 100644 --- a/ivy/functional/frontends/numpy/linalg/norms_and_other_numbers.py +++ b/ivy/functional/frontends/numpy/linalg/norms_and_other_numbers.py @@ -27,7 +27,7 @@ def matrix_rank(A, tol=None, hermitian=False): @to_ivy_arrays_and_back @from_zero_dim_arrays_to_scalar def norm(x, ord=None, axis=None, keepdims=False): - if axis is None and not (ord is None): + if axis is None and (ord is not None): if x.ndim not in (1, 2): raise ValueError("Improper number of dimensions to norm.") else: diff --git a/ivy/functional/frontends/numpy/mathematical_functions/other_special_functions.py b/ivy/functional/frontends/numpy/mathematical_functions/other_special_functions.py index 6d2c7685519fb..1a13fe4964081 100644 --- 
a/ivy/functional/frontends/numpy/mathematical_functions/other_special_functions.py +++ b/ivy/functional/frontends/numpy/mathematical_functions/other_special_functions.py @@ -8,6 +8,13 @@ ) +@to_ivy_arrays_and_back +def sinc(x): + if ivy.get_num_dims(x) == 0: + x = ivy.astype(x, ivy.float64) + return ivy.sinc(x) + + @to_ivy_arrays_and_back @from_zero_dim_arrays_to_scalar def unwrap(p, discont=None, axis=-1, *, period=2 * ivy.pi): @@ -35,10 +42,3 @@ def unwrap(p, discont=None, axis=-1, *, period=2 * ivy.pi): up = ivy.array(p, copy=True, dtype=dtype) up[slice1] = p[slice1] + ph_correct.cumsum(axis) return up - - -@to_ivy_arrays_and_back -def sinc(x): - if ivy.get_num_dims(x) == 0: - x = ivy.astype(x, ivy.float64) - return ivy.sinc(x) diff --git a/ivy/functional/frontends/numpy/mathematical_functions/sums_products_differences.py b/ivy/functional/frontends/numpy/mathematical_functions/sums_products_differences.py index 361a55c5c0dd0..a863284396fd9 100644 --- a/ivy/functional/frontends/numpy/mathematical_functions/sums_products_differences.py +++ b/ivy/functional/frontends/numpy/mathematical_functions/sums_products_differences.py @@ -8,6 +8,7 @@ from_zero_dim_arrays_to_scalar, handle_numpy_out, ) +import ivy.functional.frontends.numpy as np_frontend @handle_numpy_out @@ -114,9 +115,10 @@ def prod( initial=None, where=True, ): - if ivy.is_array(where): + if where is not True: x = ivy.where(where, x, ivy.default(out, ivy.ones_like(x)), out=out) if initial is not None: + initial = np_frontend.array(initial, dtype=dtype).tolist() if axis is not None: s = ivy.to_list(ivy.shape(x, as_array=True)) s[axis] = 1 diff --git a/ivy/functional/frontends/onnx/func_wrapper.py b/ivy/functional/frontends/onnx/func_wrapper.py index ced490f7f96a1..d1e784c7a197e 100644 --- a/ivy/functional/frontends/onnx/func_wrapper.py +++ b/ivy/functional/frontends/onnx/func_wrapper.py @@ -58,10 +58,10 @@ def _inputs_to_ivy_arrays_onnx(*args, **kwargs): """ # convert all arrays in the inputs to ivy.Array instances new_args = ivy.nested_map( - args, _to_ivy_array, include_derived={tuple: True}, shallow=False + args, _to_ivy_array, include_derived={"tuple": True}, shallow=False ) new_kwargs = ivy.nested_map( - kwargs, _to_ivy_array, include_derived={tuple: True}, shallow=False + kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False ) return fn(*new_args, **new_kwargs) @@ -82,7 +82,7 @@ def _outputs_to_frontend_arrays_onnx(*args, **kwargs): # convert all arrays in the return to `frontend.onnx.Tensor` instances return _from_ivy_array_to_onnx_frontend_tensor( - ret, nested=True, include_derived={tuple: True} + ret, nested=True, include_derived={"tuple": True} ) return _outputs_to_frontend_arrays_onnx diff --git a/ivy/functional/frontends/paddle/__init__.py b/ivy/functional/frontends/paddle/__init__.py index 21853bfe07ff2..8b6e9382abe19 100644 --- a/ivy/functional/frontends/paddle/__init__.py +++ b/ivy/functional/frontends/paddle/__init__.py @@ -215,25 +215,20 @@ def promote_types_of_paddle_inputs( return x1, x2 -from . import vision from . import nn -from .nn.functional.activation import tanh -from . import linalg -from . import fft -from . import signal - -from .tensor.attribute import * -from .tensor.creation import * -from .tensor.linalg import * -from .tensor.logic import * -from .tensor.manipulation import * -from .tensor.math import * -from .tensor.random import * -from .tensor.search import * -from .tensor.einsum import * -from .tensor.stat import * - +from . 
import tensor from .tensor.tensor import Tensor +from . import vision +from .attribute import * +from .creation import * +from .fft import * +from .linalg import * +from .logic import * +from .manipulation import * +from .math import * +from .random import * +from .search import * +from .stat import * _frontend_array = Tensor diff --git a/ivy/functional/frontends/paddle/attribute.py b/ivy/functional/frontends/paddle/attribute.py new file mode 100644 index 0000000000000..a94192737bdcb --- /dev/null +++ b/ivy/functional/frontends/paddle/attribute.py @@ -0,0 +1,35 @@ +# global +import ivy +from ivy.functional.frontends.paddle.func_wrapper import ( + to_ivy_arrays_and_back, +) + + +@to_ivy_arrays_and_back +def imag(x): + return ivy.imag(x) + + +@to_ivy_arrays_and_back +def is_complex(x): + return ivy.is_complex_dtype(x) + + +@to_ivy_arrays_and_back +def is_floating_point(x): + return ivy.is_float_dtype(x) + + +@to_ivy_arrays_and_back +def is_integer(x): + return ivy.is_int_dtype(x) + + +@to_ivy_arrays_and_back +def rank(input): + return ivy.get_num_dims(input) + + +@to_ivy_arrays_and_back +def real(x): + return ivy.real(x) diff --git a/ivy/functional/frontends/paddle/creation.py b/ivy/functional/frontends/paddle/creation.py new file mode 100644 index 0000000000000..5a110fb73d326 --- /dev/null +++ b/ivy/functional/frontends/paddle/creation.py @@ -0,0 +1,222 @@ +# global +import ivy +from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes +import ivy.functional.frontends.paddle as paddle_frontend +from ivy.functional.frontends.paddle.func_wrapper import ( + to_ivy_arrays_and_back, +) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def arange(start, end=None, step=1, dtype=None, name=None): + return ivy.arange(start, end, step=step, dtype=dtype) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64", "bool")}, + "paddle", +) +@to_ivy_arrays_and_back +def assign(x, output=None): + if len(ivy.shape(x)) == 0: + x = ivy.reshape(ivy.Array(x), (1,)) + if ivy.exists(output): + output = ivy.reshape(ivy.Array(output), (1,)) + else: + x = ivy.reshape(x, ivy.shape(x)) + ret = ivy.copy_array(x, to_ivy_array=False, out=output) + return ret + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("bfloat16", "uint16", "uint32", "uint64")}, "paddle" +) +@to_ivy_arrays_and_back +def clone(x): + return ivy.copy_array(x) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64")}, + "paddle", +) +@to_ivy_arrays_and_back +def complex(real, imag, name=None): + assert real.dtype == imag.dtype, ( + "(InvalidArgument) The type of data we are trying to retrieve does not match" + " the type of data currently contained in the container." 
+ ) + complex_dtype = "complex64" if real.dtype == "float32" else "complex128" + imag_cmplx = ivy.astype(imag, complex_dtype) * 1j + complex_array = real + imag_cmplx + return complex_array + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def diag(x, offset=0, padding_value=0, name=None): + if len(x.shape) == 1: + padding_value = ivy.astype(padding_value, ivy.dtype(x)) + ret = ivy.diagflat(x, offset=offset, padding_value=padding_value) + if len(ret.shape) != 2: + ret = ivy.reshape(ret, (1, 1)) + else: + ret = ivy.diag(x, k=offset) + return ret + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def diagflat(x, offset=0, name=None): + arr = ivy.diagflat(x, offset=offset) + return arr + + +@to_ivy_arrays_and_back +def empty(shape, dtype=None): + return ivy.empty(shape=shape, dtype=dtype) + + +@to_ivy_arrays_and_back +def empty_like(x, dtype=None, name=None): + return ivy.empty_like(x, dtype=dtype) + + +@to_ivy_arrays_and_back +def eye(num_rows, num_columns=None, dtype=None, name=None): + return ivy.eye(num_rows, num_columns, dtype=dtype) + + +@to_ivy_arrays_and_back +def full(shape, fill_value, /, *, dtype=None, name=None): + dtype = "float32" if dtype is None else dtype + return ivy.full(shape, fill_value, dtype=dtype) + + +@to_ivy_arrays_and_back +def full_like(x, fill_value, /, *, dtype=None, name=None): + dtype = x.dtype if dtype is None else dtype + return ivy.full_like(x, fill_value, dtype=dtype) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def linspace(start, stop, num, dtype=None, name=None): + return ivy.linspace(start, stop, num=num, dtype=dtype) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def logspace(start, stop, num, base=10.0, dtype=None, name=None): + return ivy.logspace(start, stop, num=num, base=base, dtype=dtype) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def meshgrid(*args, **kwargs): + return ivy.meshgrid(*args, indexing="ij") + + +@with_unsupported_dtypes({"2.5.1 and below": "int8"}, "paddle") +@to_ivy_arrays_and_back +def ones(shape, /, *, dtype=None, name=None): + dtype = "float32" if dtype is None else dtype + return ivy.ones(shape, dtype=dtype) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("uint8", "int8", "complex64", "complex128")}, "paddle" +) +@to_ivy_arrays_and_back +def ones_like(x, /, *, dtype=None, name=None): + dtype = x.dtype if dtype is None else dtype + return ivy.ones_like(x, dtype=dtype) + + +@to_ivy_arrays_and_back +def to_tensor(data, /, *, dtype=None, place=None, stop_gradient=True): + array = ivy.array(data, dtype=dtype, device=place) + return paddle_frontend.Tensor(array, dtype=dtype, place=place) + + +@with_unsupported_dtypes( + { + "2.5.1 and below": ( + "uint8", + "int8", + "int16", + "float16", + "complex64", + "complex128", + "bool", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def tril(x, diagonal=0, name=None): + return ivy.tril(x, k=diagonal) + + +@with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") +@to_ivy_arrays_and_back +def tril_indices(row, col, offset=0, dtype="int64"): + arr = ivy.tril_indices(row, col, offset) + arr = ivy.astype(arr, dtype) + return arr + + 
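The new `ivy/functional/frontends/paddle/creation.py` functions above each wrap an ivy primitive behind paddle's creation API. A minimal usage sketch follows (an editor's illustration, not part of the patch; it assumes an installed `ivy` with the NumPy backend available — only the frontend function names and signatures come from this diff):

```python
import ivy
import ivy.functional.frontends.paddle as paddle_frontend

ivy.set_backend("numpy")  # any backend ivy supports would do

# to_tensor wraps the data in a paddle-frontend Tensor backed by an ivy array
x = paddle_frontend.to_tensor([[1.0, 2.0], [3.0, 4.0]])

lower = paddle_frontend.tril(x, diagonal=0)   # entries above the main diagonal zeroed
idx = paddle_frontend.tril_indices(2, 2, 0)   # int64 row/col indices of the lower triangle
sevens = paddle_frontend.full_like(x, 7.0)    # same shape and dtype as x, filled with 7.0
```

The `to_ivy_arrays_and_back` decorator seen on each definition converts frontend `Tensor` inputs to `ivy.Array` on the way in and re-wraps the result on the way out, which is why each function body can call `ivy.*` directly.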
+@with_unsupported_dtypes( + { + "2.5.1 and below": ( + "uint8", + "int8", + "int16", + "float16", + "complex64", + "complex128", + "bool", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def triu(x, diagonal=0, name=None): + return ivy.triu(x, k=diagonal) + + +@with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") +@to_ivy_arrays_and_back +def triu_indices(row, col=None, offset=0, dtype="int64"): + arr = ivy.triu_indices(row, col, offset) + if not ivy.to_scalar(ivy.shape(arr[0], as_array=True)): + return arr + arr = ivy.astype(arr, dtype) + return arr + + +@with_unsupported_dtypes({"2.5.1 and below": "int8"}, "paddle") +@to_ivy_arrays_and_back +def zeros(shape, /, *, dtype=None, name=None): + dtype = "float32" if dtype is None else dtype + return ivy.zeros(shape, dtype=dtype) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("uint8", "int8", "complex64", "complex128")}, "paddle" +) +@to_ivy_arrays_and_back +def zeros_like(x, /, *, dtype=None, name=None): + dtype = x.dtype if dtype is None else dtype + return ivy.zeros_like(x, dtype=dtype) diff --git a/ivy/functional/frontends/paddle/fft.py b/ivy/functional/frontends/paddle/fft.py index bd40f59060eb0..4355c8d954eef 100644 --- a/ivy/functional/frontends/paddle/fft.py +++ b/ivy/functional/frontends/paddle/fft.py @@ -122,3 +122,12 @@ def irfft(x, n=None, axis=-1.0, norm="backward", name=None): if ivy.isreal(x): time_domain = ivy.real(time_domain) return time_domain + + +@to_ivy_arrays_and_back +def rfftfreq(n, d=1.0, dtype=None, name=None): + dtype = ivy.default_dtype() + val = 1.0 / (n * d) + pos_max = n // 2 + 1 + indices = ivy.arange(0, pos_max, dtype=dtype) + return indices * val diff --git a/ivy/functional/frontends/paddle/func_wrapper.py b/ivy/functional/frontends/paddle/func_wrapper.py index 1af63572cc413..b0ddb992e05e9 100644 --- a/ivy/functional/frontends/paddle/func_wrapper.py +++ b/ivy/functional/frontends/paddle/func_wrapper.py @@ -50,10 +50,10 @@ def new_fn(*args, **kwargs): """ # convert all input arrays to ivy.Array instances new_args = ivy.nested_map( - args, _to_ivy_array, include_derived={tuple: True}, shallow=False + args, _to_ivy_array, include_derived={"tuple": True}, shallow=False ) new_kwargs = ivy.nested_map( - kwargs, _to_ivy_array, include_derived={tuple: True}, shallow=False + kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False ) return fn(*new_args, **new_kwargs) @@ -82,7 +82,7 @@ def new_fn(*args, **kwargs): ivy.unset_default_float_dtype() # convert all arrays in the return to `paddle_frontend.Tensor` instances return _from_ivy_array_to_paddle_frontend_tensor( - ret, nested=True, include_derived={tuple: True} + ret, nested=True, include_derived={"tuple": True} ) return new_fn diff --git a/ivy/functional/frontends/paddle/linalg.py b/ivy/functional/frontends/paddle/linalg.py index a13f5b7eb9c71..34eff9474cb1b 100644 --- a/ivy/functional/frontends/paddle/linalg.py +++ b/ivy/functional/frontends/paddle/linalg.py @@ -1,5 +1,196 @@ -# Note: All functions are supposed to be added -# to `paddle.tensor.linalg` namespace and will -# be imported here for the sake of namespace consistency. -from . 
import tensor # noqa: F401 -from .tensor.linalg import * # noqa: F401, F403 +# global +import ivy +from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes +from ivy.functional.frontends.paddle import promote_types_of_paddle_inputs +from ivy.functional.frontends.paddle.func_wrapper import ( + to_ivy_arrays_and_back, +) + + +@with_supported_dtypes({"2.4.1 and above": ("int64",)}, "paddle") +@to_ivy_arrays_and_back +def bincount(x, weights=None, minlength=0, name=None): + return ivy.bincount(x, weights=weights, minlength=minlength) + + +# bmm +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def bmm(x, y, transpose_x=False, transpose_y=False, name=None): + if len(ivy.shape(x)) != 3 or len(ivy.shape(y)) != 3: + raise RuntimeError("input must be 3D matrices") + x, y = promote_types_of_paddle_inputs(x, y) + return ivy.matmul(x, y, transpose_a=transpose_x, transpose_b=transpose_y) + + +# cholesky +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def cholesky(x, /, *, upper=False, name=None): + return ivy.cholesky(x, upper=upper) + + +# cholesky_solve +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def cholesky_solve(x, y, /, *, upper=False, name=None): + if upper: + y = ivy.matrix_transpose(y) + Y = ivy.solve(y, x) + return ivy.solve(ivy.matrix_transpose(y), Y) + + +# cond +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def cond(x, p=None, name=None): + ret = ivy.cond(x, p=p, out=name) + if ret.shape == (): + ret = ret.reshape((1,)) + return ret + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def cross(x, y, /, *, axis=9, name=None): + x, y = promote_types_of_paddle_inputs(x, y) + return ivy.cross(x, y, axis=axis) + + +@with_supported_dtypes({"2.4.1 and above": ("float64", "float32")}, "paddle") +@to_ivy_arrays_and_back +def dist(x, y, p=2): + ret = ivy.vector_norm(ivy.subtract(x, y), ord=p) + return ivy.reshape(ret, (1,)) + + +# dot +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def dot(x, y, name=None): + x, y = promote_types_of_paddle_inputs(x, y) + out = ivy.multiply(x, y) + return ivy.sum(out, axis=ivy.get_num_dims(x) - 1, keepdims=False) + + +# eig +@to_ivy_arrays_and_back +def eig(x, name=None): + return ivy.eig(x) + + +# eigh +@to_ivy_arrays_and_back +def eigh(x, UPLO="L", name=None): + return ivy.eigh(x, UPLO=UPLO) + + +# eigvals +@to_ivy_arrays_and_back +def eigvals(x, name=None): + return ivy.eigvals(x) + + +# eigvalsh +@to_ivy_arrays_and_back +def eigvalsh(x, UPLO="L", name=None): + return ivy.eigvalsh(x, UPLO=UPLO) + + +# matmul +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def matmul(x, y, transpose_x=False, transpose_y=False, name=None): + x, y = promote_types_of_paddle_inputs(x, y) + return ivy.matmul(x, y, transpose_a=transpose_x, transpose_b=transpose_y) + + +# matrix_power +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def matrix_power(x, n, name=None): + return ivy.matrix_power(x, n) + + +# norm +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def norm(x, p="fro", axis=None, keepdim=False, name=None): + if axis is None 
and p is not None: + if p == "fro": + p = 2 + ret = ivy.vector_norm(x.flatten(), ord=p, axis=-1) + if keepdim: + ret = ret.reshape([1] * len(x.shape)) + return ret + + if isinstance(axis, tuple): + axis = list(axis) + if isinstance(axis, list) and len(axis) == 1: + axis = axis[0] + + if isinstance(axis, int): + if p == "fro": + p = 2 + if p in [0, 1, 2, ivy.inf, -ivy.inf]: + ret = ivy.vector_norm(x, ord=p, axis=axis, keepdims=keepdim) + elif isinstance(p, (int, float)): + ret = ivy.pow( + ivy.sum(ivy.pow(ivy.abs(x), p), axis=axis, keepdims=keepdim), + float(1.0 / p), + ) + + elif isinstance(axis, list) and len(axis) == 2: + if p == 0: + raise ValueError + elif p == 1: + ret = ivy.sum(ivy.abs(x), axis=axis, keepdims=keepdim) + elif p == 2 or p == "fro": + ret = ivy.matrix_norm(x, ord="fro", axis=axis, keepdims=keepdim) + elif p == ivy.inf: + ret = ivy.max(ivy.abs(x), axis=axis, keepdims=keepdim) + elif p == -ivy.inf: + ret = ivy.min(ivy.abs(x), axis=axis, keepdims=keepdim) + elif isinstance(p, (int, float)) and p > 0: + ret = ivy.pow( + ivy.sum(ivy.pow(ivy.abs(x), p), axis=axis, keepdims=keepdim), + float(1.0 / p), + ) + else: + raise ValueError + + else: + raise ValueError + + return ret + + +# pinv +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def pinv(x, rcond=1e-15, hermitian=False, name=None): + # TODO: Add hermitian functionality + return ivy.pinv(x, rtol=rcond) + + +# qr +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def qr(x, mode="reduced", name=None): + return ivy.qr(x, mode=mode) + + +# solve +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def solve(x, y, name=None): + return ivy.solve(x, y) + + +# transpose +@with_unsupported_dtypes({"2.5.1 and below": ("uint8", "int8", "int16")}, "paddle") +@to_ivy_arrays_and_back +def transpose(x, perm, name=None): + return ivy.permute_dims(x, axes=perm) diff --git a/ivy/functional/frontends/paddle/logic.py b/ivy/functional/frontends/paddle/logic.py new file mode 100644 index 0000000000000..1b2ea90fa4b06 --- /dev/null +++ b/ivy/functional/frontends/paddle/logic.py @@ -0,0 +1,286 @@ +# global +import ivy +import ivy.functional.frontends.paddle as paddle +from ivy.func_wrapper import ( + with_unsupported_dtypes, + handle_out_argument, + with_supported_dtypes, +) +from ivy.functional.frontends.paddle.func_wrapper import ( + to_ivy_arrays_and_back, +) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "float32", + "float64", + "bool", + "uint8", + "int8", + "int16", + "int32", + "int64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +@handle_out_argument +def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): + ret = ivy.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + return paddle.to_tensor([ret]) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "uint8", + "int8", + "int16", + "int32", + "int64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +@handle_out_argument +def bitwise_and(x, y, /, *, name=None, out=None): + return ivy.bitwise_and(x, y, out=out) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "uint8", + "int8", + "int16", + "int32", + "int64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +@handle_out_argument +def bitwise_not(x, out=None, name=None): + return ivy.bitwise_invert(x, out=out) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "uint8", + "int8", + 
"int16", + "int32", + "int64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +@handle_out_argument +def bitwise_or(x, y, name=None, out=None): + return ivy.bitwise_or(x, y, out=out) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "uint8", + "int8", + "int16", + "int32", + "int64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +@handle_out_argument +def bitwise_xor(x, y, /, *, name=None, out=None): + return ivy.bitwise_xor(x, y, out=out) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("uint8", "int8", "int16", "complex64", "complex128")}, "paddle" +) +@to_ivy_arrays_and_back +def equal(x, y, /, *, name=None): + return ivy.equal(x, y) + + +@with_unsupported_dtypes( + { + "2.5.1 and below": ( + "uint8", + "int8", + "int16", + "float16", + "complex64", + "complex128", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def equal_all(x, y, /, *, name=None): + return paddle.to_tensor([ivy.array_equal(x, y)]) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "complex64", "complex128")}, + "paddle", +) +@to_ivy_arrays_and_back +def greater_equal(x, y, /, *, name=None): + return ivy.greater_equal(x, y) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "complex64", "complex128")}, + "paddle", +) +@to_ivy_arrays_and_back +def greater_than(x, y, /, *, name=None): + return ivy.greater(x, y) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("uint8", "int8", "int16", "complex64", "complex128")}, "paddle" +) +@to_ivy_arrays_and_back +def is_empty(x, name=None): + return ivy.is_empty(x) + + +@to_ivy_arrays_and_back +def is_tensor(x): + return ivy.is_array(x) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "float32", + "float64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): + return ivy.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "complex64", "complex128")}, + "paddle", +) +@to_ivy_arrays_and_back +def less_equal(x, y, /, *, name=None): + return ivy.less_equal(x, y) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "complex64", "complex128")}, + "paddle", +) +@to_ivy_arrays_and_back +def less_than(x, y, /, *, name=None): + return ivy.less(x, y) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +@handle_out_argument +def logical_and(x, y, /, *, name=None, out=None): + return ivy.logical_and(x, y, out=out) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +@handle_out_argument +def logical_not(x, /, *, name=None, out=None): + return ivy.logical_not(x, out=out) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +@handle_out_argument +def logical_or(x, y, /, *, name=None, out=None): + return ivy.logical_or(x, y, out=out) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +@handle_out_argument +def logical_xor(x, y, /, *, name=None, out=None): + return 
ivy.logical_xor(x, y, out=out) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("uint8", "int8", "int16", "complex64", "complex128")}, "paddle" +) +@to_ivy_arrays_and_back +def not_equal(x, y, /, *, name=None): + return ivy.not_equal(x, y) diff --git a/ivy/functional/frontends/paddle/manipulation.py b/ivy/functional/frontends/paddle/manipulation.py new file mode 100644 index 0000000000000..c62dda831e633 --- /dev/null +++ b/ivy/functional/frontends/paddle/manipulation.py @@ -0,0 +1,175 @@ +# global +import ivy +from ivy.functional.frontends.paddle.func_wrapper import ( + to_ivy_arrays_and_back, +) +from ivy.func_wrapper import ( + with_unsupported_dtypes, + with_supported_dtypes, + with_supported_device_and_dtypes, +) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def abs(x, name=None): + return ivy.abs(x) + + +@with_supported_dtypes( + {"2.5.1 and below": ("bool", "float32", "float64", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def broadcast_to(x, shape, name=None): + return ivy.broadcast_to(x, shape) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint8", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def cast(x, dtype): + return ivy.astype(x, dtype) + + +@with_unsupported_dtypes({"2.5.1 and below": ("int8", "int16")}, "paddle") +@to_ivy_arrays_and_back +def concat(x, axis, name=None): + return ivy.concat(x, axis=axis) + + +@with_supported_dtypes( + {"2.5.1 and below": ("bool", "float32", "float64", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def expand(x, shape, name=None): + return ivy.expand(x, shape) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("int8", "uint8", "int16", "float16")}, + "paddle", +) +@to_ivy_arrays_and_back +def flip(x, axis, name=None): + return ivy.flip(x, axis=axis) + + +@with_supported_dtypes( + {"2.5.1 and below": ("bool", "float32", "float64", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def gather(params, indices, axis=-1, batch_dims=0, name=None): + return ivy.gather(params, indices, axis=axis, batch_dims=batch_dims) + + +@to_ivy_arrays_and_back +def reshape(x, shape): + return ivy.reshape(x, shape) + + +@with_supported_dtypes( + { + "2.5.0 and below": ( + "float32", + "float64", + "int32", + "int64", + "complex64", + "complex128", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def roll(x, shifts, axis=None, name=None): + return ivy.roll(x, shifts, axis=axis) + + +@with_supported_device_and_dtypes( + { + "2.5.1 and above": { + "cpu": ( + "bool", + "int32", + "int64", + "float32", + "float64", + ), + "gpu": ("float16",), + }, + }, + "paddle", +) +@to_ivy_arrays_and_back +def rot90(x, k=1, axes=(0, 1), name=None): + return ivy.rot90(x, k=k, axes=axes) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("int16", "complex64", "complex128")}, + "paddle", +) +@to_ivy_arrays_and_back +def split(x, num_or_sections, axis=0, name=None): + return ivy.split(x, num_or_size_splits=num_or_sections, axis=axis) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("float16", "bfloat16", "int8", "int16")}, + "paddle", +) +@to_ivy_arrays_and_back +def squeeze(x, axis=None, name=None): + return ivy.squeeze(x, axis=axis) + + +@to_ivy_arrays_and_back +def stack(x, axis=0, name=None): + return ivy.stack(x, axis=axis) + + +def take_along_axis(arr, indices, axis): + return ivy.take_along_axis(arr, indices, axis) + + +@with_unsupported_dtypes( + {"2.5.1 and below": 
("int8", "uint8", "int16", "float16")}, + "paddle", +) +@to_ivy_arrays_and_back +def tile(x, repeat_times, name=None): + return ivy.tile(x, repeats=repeat_times) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "float32", + "float64", + "int32", + "int64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def unstack(x, axis=0, name=None): + return ivy.unstack(x, axis=axis) + + +absolute = abs diff --git a/ivy/functional/frontends/paddle/math.py b/ivy/functional/frontends/paddle/math.py new file mode 100644 index 0000000000000..c4d3a0130f79b --- /dev/null +++ b/ivy/functional/frontends/paddle/math.py @@ -0,0 +1,547 @@ +# global +import ivy +from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes +from ivy.functional.frontends.paddle.func_wrapper import to_ivy_arrays_and_back + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def abs(x, name=None): + return ivy.abs(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def acos(x, name=None): + return ivy.acos(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def acosh(x, name=None): + return ivy.acosh(x) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("bool", "unsigned", "int8", "float16", "bfloat16")}, "paddle" +) +@to_ivy_arrays_and_back +def add(x, y, name=None): + return ivy.add(x, y) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def addmm(input, x, y, beta=1.0, alpha=1.0, name=None): + value = alpha * ivy.matmul(x, y) + (beta * input) + return value + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def amax(x, axis=None, keepdims=False): + if axis is None: + return ivy.max(x) + if isinstance(axis, int): + axis = [axis] + for i in range(len(axis)): + if axis[i] < 0: + axis[i] += x.ndim + for i in axis: + if i < 0 or i >= x.ndim: + raise ValueError("axis {} is out of range [-{}:{}]".format(i, 0, x.ndim)) + return ivy.max(x, axis=axis, keepdims=keepdims) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def amin(x, axis=None, keepdim=False, name=None): + return ivy.min(x, axis=axis, keepdims=keepdim) + + +@with_supported_dtypes( + {"2.5.1 and below": ("complex64", "complex128", "float32", "float64")}, + "paddle", +) +@to_ivy_arrays_and_back +def angle(x, name=None): + return ivy.angle(x) + + +@with_supported_dtypes({"2.5.0 and below": "bool"}, "paddle") +@to_ivy_arrays_and_back +def any(x, axis=None, keepdim=False, name=None): + return ivy.any(x, axis=axis, keepdims=keepdim) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def asin(x, name=None): + return ivy.asin(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def asinh(x, name=None): + return ivy.asinh(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def atan(x, name=None): + return ivy.atan(x) + + +@with_unsupported_dtypes({"2.4.2 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def atan2(x, y, name=None): + return ivy.atan2(x, y) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, 
"paddle") +@to_ivy_arrays_and_back +def atanh(x, name=None): + return ivy.atanh(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def ceil(x, name=None): + return ivy.ceil(x) + + +@with_unsupported_dtypes({"2.4.2 and below": ("int16", "float16")}, "paddle") +@to_ivy_arrays_and_back +def conj(x, name=None): + return ivy.conj(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def cos(x, name=None): + return ivy.cos(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def cosh(x, name=None): + return ivy.cosh(x) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "int32", + "int64", + "float32", + "float64", + "complex64", + "complex128", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def cumprod(x, dim=None, dtype=None, name=None): + return ivy.cumprod(x, axis=dim, dtype=dtype) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def deg2rad(x, name=None): + return ivy.deg2rad(x) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def diff(x, n=1, axis=-1, prepend=None, append=None, name=None): + return ivy.diff(x, n=n, axis=axis, prepend=prepend, append=append) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def digamma(x, name=None): + digamma_fun = ivy.digamma + return ivy.array(digamma_fun(x), dtype=x.dtype) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def divide(x, y, name=None): + return ivy.divide(x, y) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def erf(x, name=None): + return ivy.erf(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def exp(x, name=None): + return ivy.exp(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float16", "float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def expm1(x, name=None): + return ivy.expm1(x) + + +@with_supported_dtypes( + {"2.5.1 and below": ("bfloat16", "float32", "float64")}, "paddle" +) +@to_ivy_arrays_and_back +def floor(x, name=None): + return ivy.floor(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def floor_divide(x, y, name=None): + return ivy.floor_divide(x, y) + + +@with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") +@to_ivy_arrays_and_back +def fmax(x, y, name=None): + return ivy.fmax(x, y) + + +@with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") +@to_ivy_arrays_and_back +def fmin(x, y, name=None): + return ivy.fmin(x, y) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def frac(x, name=None): + y = ivy.trunc(x) + return ivy.subtract(x, y) + + +@with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") +@to_ivy_arrays_and_back +def gcd(x, y, name=None): + return ivy.gcd(x, y) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def heaviside(x, y, name=None): + return ivy.heaviside(x, y) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, 
"paddle") +@to_ivy_arrays_and_back +def inner(x, y, name=None): + result = ivy.inner(x, y) + if (x.shape == () and y.shape == (1,)) or (x.shape == (1,) and y.shape == ()): + result = result.reshape((1,)) + elif x.shape == (1,) and y.shape == (1,): + result = result.reshape((1,)) + return result + + +@with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def isfinite(x, name=None): + return ivy.isfinite(x) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def isinf(x, name=None): + return ivy.isinf(x) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def isnan(x, name=None): + return ivy.isnan(x) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def kron(x, y, name=None): + return ivy.kron(x, y) + + +@with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") +@to_ivy_arrays_and_back +def lcm(x, y, name=None): + return ivy.lcm(x, y) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def lerp(x, y, weight, name=None): + return ivy.lerp(x, y, weight) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def lgamma(x, name=None): + return ivy.lgamma(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def log(x, name=None): + return ivy.log(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def log1p(x, name=None): + return ivy.log1p(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def log2(x, name=None): + return ivy.log2(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def logit(x, eps=None, name=None): + return ivy.logit(x, eps=eps) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def max(x, axis=None, keepdim=False, name=None): + return ivy.max(x, axis=axis, keepdims=keepdim) + + +# maximum +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def maximum(x, y, name=None): + return ivy.maximum(x, y) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def min(x, axis=None, keepdim=False, name=None): + return ivy.min(x, axis=axis, keepdims=keepdim) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def minimum(x, y, name=None): + return ivy.minimum(x, y) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def mm(input, mat2, name=None): + return ivy.matmul(input, mat2) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def multiply(x, y, name=None): + return ivy.multiply(x, y) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def nanmean(x, axis=None, keepdims=False): + return 
ivy.nanmean(x, axis=axis, keepdims=keepdims) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def nansum(x, axis=None, dtype=None, name=None): + return ivy.nansum(x, axis=axis, dtype=dtype) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int8", "int16", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def neg(x, name=None): + return ivy.negative(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def outer(x, y, name=None): + return ivy.outer(x, y) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def pow(x, y, name=None): + return ivy.pow(x, y) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def prod(x, axis=None, keepdim=False, dtype=None, name=None): + return ivy.prod(x, axis=axis, keepdims=keepdim, dtype=dtype) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def rad2deg(x, name=None): + return ivy.rad2deg(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def reciprocal(x, name=None): + return ivy.reciprocal(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def remainder(x, y, name=None): + return ivy.remainder(x, y) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def round(x, name=None): + return ivy.round(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def rsqrt(x, name=None): + return 1 / ivy.sqrt(x) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def sgn(x, name=None): + return ivy.sign(x, np_variant=True) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def sign(x, name=None): + return ivy.sign(x, np_variant=False) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def sin(x, name=None): + return ivy.sin(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def sinh(x, name=None): + return ivy.sinh(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def sqrt(x, name=None): + return ivy.sqrt(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def square(x, name=None): + return ivy.square(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def stanh(x, scale_a=0.67, scale_b=1.7159, name=None): + # TODO: simplify this once ivy.stanh(x, a, b) is added + exp_ax = ivy.exp(ivy.multiply(scale_a, x)) + exp_minus_ax = ivy.exp(ivy.multiply(-scale_a, x)) + numerator = ivy.subtract(exp_ax, exp_minus_ax) + denominator = ivy.add(exp_ax, exp_minus_ax) + ret = ivy.multiply(scale_b, ivy.divide(numerator, denominator)) + return ret + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def subtract(x, y, name=None): + return ivy.subtract(x, y) + +
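A quick sanity check on the stanh implementation above: it builds the closed form stanh(x) = scale_b * tanh(scale_a * x) out of exponentials. The minimal sketch below compares the two forms. It assumes an installed ivy with its numpy backend available, imports stanh from the module path this diff adds, and unwraps the returned paddle frontend Tensor via its usual .ivy_array property; the input values are illustrative only.

import ivy
from ivy.functional.frontends.paddle.math import stanh

ivy.set_backend("numpy")

x = ivy.array([-1.0, 0.0, 2.5])
# the frontend wrapper returns a paddle frontend Tensor wrapping an ivy array
out = stanh(x, scale_a=0.67, scale_b=1.7159)

# closed form of the same quantity
ref = 1.7159 * ivy.tanh(0.67 * x)
assert ivy.allclose(out.ivy_array, ref)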
+@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def take( + x, + index, + mode="raise", + name=None, +): + if mode not in ["raise", "wrap", "clip"]: + raise ValueError( + "'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {}." + .format(mode) + ) + x = ivy.reshape(x, (-1,)) + if mode == "clip": + index = ivy.clip(index, 0, x.shape[-1] - 1) + elif mode == "wrap": + index = ivy.where(index < 0, index % x.shape[-1], index) + index = ivy.where(index >= x.shape[-1], index % x.shape[-1], index) + return ivy.gather(x, index, axis=0) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def tan(x, name=None): + return ivy.tan(x) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def tanh(x, name=None): + return ivy.tanh(x) + + +@with_supported_dtypes( + {"2.4.2 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def trunc(x, name=None): + return ivy.trunc(x) diff --git a/ivy/functional/frontends/paddle/nn/functional/activation.py b/ivy/functional/frontends/paddle/nn/functional/activation.py index cbff12b6b5197..9e774211febc6 100644 --- a/ivy/functional/frontends/paddle/nn/functional/activation.py +++ b/ivy/functional/frontends/paddle/nn/functional/activation.py @@ -245,6 +245,12 @@ def softsign( return ivy.divide(x, ivy.add(1, ivy.abs(x))) +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def swish(x, name=None): + return ivy.multiply(x, ivy.sigmoid(x)) + + @with_supported_dtypes({"2.4.2 and below": ("float32", "float64")}, "paddle") @to_ivy_arrays_and_back def tanh_(x, name=None): diff --git a/ivy/functional/frontends/paddle/nn/functional/loss.py b/ivy/functional/frontends/paddle/nn/functional/loss.py index 774e3d8e43cf4..04377a35403df 100644 --- a/ivy/functional/frontends/paddle/nn/functional/loss.py +++ b/ivy/functional/frontends/paddle/nn/functional/loss.py @@ -472,3 +472,21 @@ def triplet_margin_loss( loss = reduction(loss).astype(input.dtype) return loss + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def multi_label_soft_margin_loss( + input, label, weight=None, reduction="mean", name=None +): + reduction = _get_reduction_func(reduction) + loss = -( + label * ivy.log(ivy.sigmoid(input)) + + (1 - label) * ivy.log(1 - ivy.sigmoid(input)) + ) + + if weight is not None: + loss = ivy.multiply(weight, loss) + loss = ivy.mean(loss, axis=-1) + ret = reduction(loss).astype(input.dtype) + return ret diff --git a/ivy/functional/frontends/paddle/nn/functional/norm.py b/ivy/functional/frontends/paddle/nn/functional/norm.py index 6c6fad6a1f439..76bc210cabbb6 100644 --- a/ivy/functional/frontends/paddle/nn/functional/norm.py +++ b/ivy/functional/frontends/paddle/nn/functional/norm.py @@ -8,3 +8,11 @@ @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") def layer_norm(x, normalized_shape, weight=None, bias=None, epsilon=1e-05, name=None): return ivy.layer_norm(x, normalized_shape, weight, bias, epsilon) + + +@to_ivy_arrays_and_back +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +def normalize(x, p=2, axis=1, epsilon=1e-12, name=None): + if axis < 0: + axis = ivy.get_num_dims(x) + axis + return ivy.lp_normalize(x, p=p, axis=axis) diff --git
a/ivy/functional/frontends/paddle/nn/functional/vision.py b/ivy/functional/frontends/paddle/nn/functional/vision.py index 653ed76189da5..4e98183b4ce4d 100644 --- a/ivy/functional/frontends/paddle/nn/functional/vision.py +++ b/ivy/functional/frontends/paddle/nn/functional/vision.py @@ -170,3 +170,45 @@ def pixel_shuffle(x, upscale_factor, data_format="NCHW"): return ivy.reshape( ivy.permute_dims(input_reshaped, (0, 1, 4, 2, 5, 3)), (b, oh, ow, oc) ) + + +@to_ivy_arrays_and_back +def pixel_unshuffle(x, downscale_factor, data_format="NCHW"): + if len(ivy.shape(x)) != 4: + raise ValueError( + "Input x should be a 4D tensor, but received x with the shape of {}".format( + ivy.shape(x) + ) + ) + + if not isinstance(downscale_factor, int): + raise ValueError("Downscale factor must be an int") + + if downscale_factor <= 0: + raise ValueError("Downscale factor must be positive") + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCHW' or 'NHWC'." + "But received Attr(data_format): {} ".format(data_format) + ) + + if data_format == "NCHW": + b, c, h, w = ivy.shape(x) + oc = c * downscale_factor**2 + oh = h // downscale_factor + ow = w // downscale_factor + + x = ivy.reshape(x, (b, c, oh, downscale_factor, ow, downscale_factor)) + x = ivy.permute_dims(x, (0, 1, 3, 5, 2, 4)) + x = ivy.reshape(x, (b, oc, oh, ow)) + else: + b, h, w, c = ivy.shape(x) + oc = c * downscale_factor**2 + oh = h // downscale_factor + ow = w // downscale_factor + + x = ivy.reshape(x, (b, oh, downscale_factor, ow, downscale_factor, c))  # h -> (oh, df), w -> (ow, df) + x = ivy.permute_dims(x, (0, 1, 3, 5, 2, 4)) + x = ivy.reshape(x, (b, oh, ow, oc)) + return x diff --git a/ivy/functional/frontends/paddle/random.py b/ivy/functional/frontends/paddle/random.py new file mode 100644 index 0000000000000..f5e84b9c43c65 --- /dev/null +++ b/ivy/functional/frontends/paddle/random.py @@ -0,0 +1,106 @@ +# global +import ivy +from ivy.func_wrapper import with_supported_dtypes +from ivy.func_wrapper import with_supported_device_and_dtypes, with_unsupported_dtypes +from ivy.functional.frontends.paddle.func_wrapper import ( + to_ivy_arrays_and_back, +) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64")}, + "paddle", +) +@to_ivy_arrays_and_back +def normal(mean=0.0, std=1.0, shape=None, name=None): + return ivy.random_normal(mean=mean, std=std, shape=shape) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64")}, + "paddle", +) +@to_ivy_arrays_and_back +def poisson(x, name=None): + return ivy.poisson(x, shape=None, device=None, dtype=None, seed=None, out=None) + + +@with_supported_device_and_dtypes( + { + "2.5.1 and above": { + "cpu": ( + "bfloat16", + "float32", + "float64", + ), + "gpu": ( + "bfloat16", + "float16", + "float32", + "float64", + ), + }, + "2.4.2 and below": { + "cpu": ( + "float32", + "float64", + ), + "gpu": ( + "float16", + "float32", + "float64", + ), + }, + }, + "paddle", +) +@to_ivy_arrays_and_back +def rand(shape, dtype=None, name=None): + return ivy.random_uniform(low=0.0, high=1.0, shape=shape, dtype=dtype, seed=None) + + +@to_ivy_arrays_and_back +def randint(low=0, high=None, shape=[1], dtype=None, name=None): + return ivy.randint(low, high, shape=shape, dtype=dtype) + + +@with_unsupported_dtypes( + {"2.5.1 and below": ("int16", "float16", "bfloat16", "uint8")}, + "paddle", +) +@to_ivy_arrays_and_back +def randint_like(x, low=0, high=None, dtype=None, name=None): + if high is None: + high = low + low = 0 + if high <= 0: + raise ivy.exceptions.IvyError( + "If high is
None, low must be greater than 0, but received low = 0." + ) + return ivy.randint(low, high, shape=x.shape, dtype=dtype, seed=None) + + +def randn(shape, dtype=None, name=None): + if dtype is None: + dtype = "float32"  # paddle defaults randn to float32 + if dtype not in ["float32", "float64"]: + raise ivy.exceptions.IvyError( + "Unsupported dtype for randn, only float32 and float64 are supported." + ) + return ivy.random_normal(shape=shape, dtype=dtype, seed=None) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64")}, + "paddle", +) +@to_ivy_arrays_and_back +def standard_normal(shape, dtype=None, name=None): + return ivy.random_normal(mean=0, std=1, shape=shape, dtype=dtype) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64")}, + "paddle", +) +@to_ivy_arrays_and_back +def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): + return ivy.random_uniform(low=min, high=max, shape=shape, dtype=dtype, seed=seed) diff --git a/ivy/functional/frontends/paddle/search.py b/ivy/functional/frontends/paddle/search.py new file mode 100644 index 0000000000000..205f2b6d1e5b7 --- /dev/null +++ b/ivy/functional/frontends/paddle/search.py @@ -0,0 +1,98 @@ +# global +import ivy +from ivy.func_wrapper import with_supported_dtypes +from ivy.functional.frontends.paddle.func_wrapper import ( + to_ivy_arrays_and_back, +) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int16", "int32", "int64", "uint8")}, + "paddle", +) +@to_ivy_arrays_and_back +def argmax(x, /, *, axis=None, keepdim=False, dtype="int64", name=None): + return ivy.argmax(x, axis=axis, keepdims=keepdim, dtype=dtype) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int16", "int32", "int64", "uint8")}, + "paddle", +) +@to_ivy_arrays_and_back +def argmin(x, /, *, axis=None, keepdim=False, dtype="int64", name=None): + return ivy.argmin(x, axis=axis, keepdims=keepdim, dtype=dtype) + + +@with_supported_dtypes( + {"2.4.2 and below": ("float32", "float64", "int16", "int32", "int64", "uint8")}, + "paddle", +) +@to_ivy_arrays_and_back +def argsort(x, /, *, axis=-1, descending=False, name=None): + return ivy.argsort(x, axis=axis, descending=descending) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def masked_select(x, mask, name=None): + return ivy.flatten(x[mask]) + + +@with_supported_dtypes( + {"2.4.2 and below": ("float32", "float64", "int16", "int32", "int64", "uint8")}, + "paddle", +) +@to_ivy_arrays_and_back +def nonzero(input, /, *, as_tuple=False): + ret = ivy.nonzero(input) + if as_tuple is False: + ret = ivy.matrix_transpose(ivy.stack(ret)) + return ret + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def searchsorted(sorted_sequence, values, out_int32=False, right=False, name=None): + if right: + side = "right" + else: + side = "left" + ret = ivy.searchsorted(sorted_sequence, values, side=side) + if out_int32: + ret = ivy.astype(ret, "int32") + return ret + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def sort(x, /, *, axis=-1, descending=False, name=None): + return ivy.sort(x, axis=axis, descending=descending) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def topk(x, k, axis=None, largest=True, sorted=True, name=None): + return ivy.top_k(x, k, axis=axis, largest=largest,
sorted=sorted) + + +# where +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def where(condition, x, y, name=None): + return ivy.where(condition, x, y) diff --git a/ivy/functional/frontends/paddle/signal.py b/ivy/functional/frontends/paddle/signal.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/ivy/functional/frontends/paddle/stat.py b/ivy/functional/frontends/paddle/stat.py new file mode 100644 index 0000000000000..91b69d7441091 --- /dev/null +++ b/ivy/functional/frontends/paddle/stat.py @@ -0,0 +1,76 @@ +# global +import ivy +from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes +from ivy.functional.frontends.paddle.func_wrapper import ( + to_ivy_arrays_and_back, +) + + +@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@to_ivy_arrays_and_back +def mean(input, axis=None, keepdim=False, out=None): + ret = ivy.mean(input, axis=axis, keepdims=keepdim, out=out) + return ret + + +@with_supported_dtypes( + {"2.5.1 and below": ("bool", "float16", "float32", "float64", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def median(x, axis=None, keepdim=False, name=None): + x = ( + ivy.astype(x, ivy.float64) + if ivy.dtype(x) == "float64" + else ivy.astype(x, ivy.float32) + ) + return ivy.median(x, axis=axis, keepdims=keepdim) + + +@with_supported_dtypes( + {"2.5.0 and below": ("float16", "float32", "float64", "uint16")}, + "paddle", +) +@to_ivy_arrays_and_back +def nanmedian(x, axis=None, keepdim=True, name=None): + x = ( + ivy.astype(x, ivy.float64) + if ivy.dtype(x) == "float64" + else ivy.astype(x, ivy.float32) + ) + return ivy.nanmedian(x, axis=axis, keepdims=keepdim)  # skip NaNs, matching paddle.nanmedian + + +@with_unsupported_dtypes({"2.5.1 and below": ("complex", "int8")}, "paddle") +@to_ivy_arrays_and_back +def numel(x, name=None): + prod = ivy.prod(x.size, dtype=ivy.int64) + try: + length = len(x) + except (ValueError, TypeError): + length = 1  # a 0-dimensional tensor holds a single element + return ivy.array(prod if prod > 0 else ivy.array(length, dtype=ivy.int64)) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "uint16")}, + "paddle", +) +@to_ivy_arrays_and_back +def std(x, axis=None, unbiased=True, keepdim=False, name=None): + x = ( + ivy.astype(x, ivy.float64) + if ivy.dtype(x) == "float64" + else ivy.astype(x, ivy.float32) + ) + return ivy.std(x, axis=axis, correction=int(unbiased), keepdims=keepdim) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def var(x, axis=None, unbiased=True, keepdim=False, name=None): + if unbiased: + correction = 1 + else: + correction = 0 + return ivy.var(x, axis=axis, correction=correction, keepdims=keepdim) diff --git a/ivy/functional/frontends/paddle/tensor/__init__.py b/ivy/functional/frontends/paddle/tensor/__init__.py index cefb9c41204cb..3b1f01ba23fee 100644 --- a/ivy/functional/frontends/paddle/tensor/__init__.py +++ b/ivy/functional/frontends/paddle/tensor/__init__.py @@ -2,8 +2,6 @@ from .attribute import * from . import creation from .creation import * -from . import einsum -from .einsum import * from . import linalg from .linalg import * from .
import logic diff --git a/ivy/functional/frontends/paddle/tensor/attribute.py b/ivy/functional/frontends/paddle/tensor/attribute.py index a94192737bdcb..605913c9f2c85 100644 --- a/ivy/functional/frontends/paddle/tensor/attribute.py +++ b/ivy/functional/frontends/paddle/tensor/attribute.py @@ -1,35 +1,2 @@ -# global -import ivy -from ivy.functional.frontends.paddle.func_wrapper import ( - to_ivy_arrays_and_back, -) - - -@to_ivy_arrays_and_back -def imag(x): - return ivy.imag(x) - - -@to_ivy_arrays_and_back -def is_complex(x): - return ivy.is_complex_dtype(x) - - -@to_ivy_arrays_and_back -def is_floating_point(x): - return ivy.is_float_dtype(x) - - -@to_ivy_arrays_and_back -def is_integer(x): - return ivy.is_int_dtype(x) - - -@to_ivy_arrays_and_back -def rank(input): - return ivy.get_num_dims(input) - - -@to_ivy_arrays_and_back -def real(x): - return ivy.real(x) +# local +from ..attribute import * # noqa: F401 diff --git a/ivy/functional/frontends/paddle/tensor/creation.py b/ivy/functional/frontends/paddle/tensor/creation.py index 7ed0e10847fac..fa08078aa84b2 100644 --- a/ivy/functional/frontends/paddle/tensor/creation.py +++ b/ivy/functional/frontends/paddle/tensor/creation.py @@ -1,222 +1,2 @@ -# global -import ivy -from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes -from .tensor import Tensor -from ivy.functional.frontends.paddle.func_wrapper import ( - to_ivy_arrays_and_back, -) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def arange(start, end=None, step=1, dtype=None, name=None): - return ivy.arange(start, end, step=step, dtype=dtype) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64", "bool")}, - "paddle", -) -@to_ivy_arrays_and_back -def assign(x, output=None): - if len(ivy.shape(x)) == 0: - x = ivy.reshape(ivy.Array(x), (1,)) - if ivy.exists(output): - output = ivy.reshape(ivy.Array(output), (1,)) - else: - x = ivy.reshape(x, ivy.shape(x)) - ret = ivy.copy_array(x, to_ivy_array=False, out=output) - return ret - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("bfloat16", "uint16", "uint32", "uint64")}, "paddle" -) -@to_ivy_arrays_and_back -def clone(x): - return ivy.copy_array(x) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, - "paddle", -) -@to_ivy_arrays_and_back -def complex(real, imag, name=None): - assert real.dtype == imag.dtype, ( - "(InvalidArgument) The type of data we are trying to retrieve does not match" - " the type of data currently contained in the container." 
- ) - complex_dtype = "complex64" if real.dtype == "float32" else "complex128" - imag_cmplx = ivy.astype(imag, complex_dtype) * 1j - complex_array = real + imag_cmplx - return complex_array - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def diag(x, offset=0, padding_value=0, name=None): - if len(x.shape) == 1: - padding_value = ivy.astype(padding_value, ivy.dtype(x)) - ret = ivy.diagflat(x, offset=offset, padding_value=padding_value) - if len(ret.shape) != 2: - ret = ivy.reshape(ret, (1, 1)) - else: - ret = ivy.diag(x, k=offset) - return ret - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def diagflat(x, offset=0, name=None): - arr = ivy.diagflat(x, offset=offset) - return arr - - -@to_ivy_arrays_and_back -def empty(shape, dtype=None): - return ivy.empty(shape=shape, dtype=dtype) - - -@to_ivy_arrays_and_back -def empty_like(x, dtype=None, name=None): - return ivy.empty_like(x, dtype=dtype) - - -@to_ivy_arrays_and_back -def eye(num_rows, num_columns=None, dtype=None, name=None): - return ivy.eye(num_rows, num_columns, dtype=dtype) - - -@to_ivy_arrays_and_back -def full(shape, fill_value, /, *, dtype=None, name=None): - dtype = "float32" if dtype is None else dtype - return ivy.full(shape, fill_value, dtype=dtype) - - -@to_ivy_arrays_and_back -def full_like(x, fill_value, /, *, dtype=None, name=None): - dtype = x.dtype if dtype is None else dtype - return ivy.full_like(x, fill_value, dtype=dtype) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def linspace(start, stop, num, dtype=None, name=None): - return ivy.linspace(start, stop, num=num, dtype=dtype) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def logspace(start, stop, num, base=10.0, dtype=None, name=None): - return ivy.logspace(start, stop, num=num, base=base, dtype=dtype) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def meshgrid(*args, **kwargs): - return ivy.meshgrid(*args, indexing="ij") - - -@with_unsupported_dtypes({"2.5.1 and below": "int8"}, "paddle") -@to_ivy_arrays_and_back -def ones(shape, /, *, dtype=None, name=None): - dtype = "float32" if dtype is None else dtype - return ivy.ones(shape, dtype=dtype) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("uint8", "int8", "complex64", "complex128")}, "paddle" -) -@to_ivy_arrays_and_back -def ones_like(x, /, *, dtype=None, name=None): - dtype = x.dtype if dtype is None else dtype - return ivy.ones_like(x, dtype=dtype) - - -@to_ivy_arrays_and_back -def to_tensor(data, /, *, dtype=None, place=None, stop_gradient=True): - array = ivy.array(data, dtype=dtype, device=place) - return Tensor(array, dtype=dtype, place=place) - - -@with_unsupported_dtypes( - { - "2.5.1 and below": ( - "uint8", - "int8", - "int16", - "float16", - "complex64", - "complex128", - "bool", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -def tril(x, diagonal=0, name=None): - return ivy.tril(x, k=diagonal) - - -@with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") -@to_ivy_arrays_and_back -def tril_indices(row, col, offset=0, dtype="int64"): - arr = ivy.tril_indices(row, col, offset) - arr = ivy.astype(arr, dtype) - return arr - - -@with_unsupported_dtypes( - { - 
"2.5.1 and below": ( - "uint8", - "int8", - "int16", - "float16", - "complex64", - "complex128", - "bool", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -def triu(x, diagonal=0, name=None): - return ivy.triu(x, k=diagonal) - - -@with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") -@to_ivy_arrays_and_back -def triu_indices(row, col=None, offset=0, dtype="int64"): - arr = ivy.triu_indices(row, col, offset) - if not ivy.to_scalar(ivy.shape(arr[0], as_array=True)): - return arr - arr = ivy.astype(arr, dtype) - return arr - - -@with_unsupported_dtypes({"2.5.1 and below": "int8"}, "paddle") -@to_ivy_arrays_and_back -def zeros(shape, /, *, dtype=None, name=None): - dtype = "float32" if dtype is None else dtype - return ivy.zeros(shape, dtype=dtype) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("uint8", "int8", "complex64", "complex128")}, "paddle" -) -@to_ivy_arrays_and_back -def zeros_like(x, /, *, dtype=None, name=None): - dtype = x.dtype if dtype is None else dtype - return ivy.zeros_like(x, dtype=dtype) +# local +from ..creation import * # noqa: F401 diff --git a/ivy/functional/frontends/paddle/tensor/einsum.py b/ivy/functional/frontends/paddle/tensor/einsum.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/ivy/functional/frontends/paddle/tensor/linalg.py b/ivy/functional/frontends/paddle/tensor/linalg.py index 4ae10e9824324..4859ac7195588 100644 --- a/ivy/functional/frontends/paddle/tensor/linalg.py +++ b/ivy/functional/frontends/paddle/tensor/linalg.py @@ -1,196 +1,2 @@ -# global -import ivy -from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes -from ivy.functional.frontends.paddle import promote_types_of_paddle_inputs -from ivy.functional.frontends.paddle.func_wrapper import ( - to_ivy_arrays_and_back, -) - - -@with_supported_dtypes({"2.4.1 and above": ("int64",)}, "paddle") -@to_ivy_arrays_and_back -def bincount(x, weights=None, minlength=0, name=None): - return ivy.bincount(x, weights=weights, minlength=minlength) - - -# bmm -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def bmm(x, y, transpose_x=False, transpose_y=False, name=None): - if len(ivy.shape(x)) != 3 or len(ivy.shape(y)) != 3: - raise RuntimeError("input must be 3D matrices") - x, y = promote_types_of_paddle_inputs(x, y) - return ivy.matmul(x, y, transpose_a=transpose_x, transpose_b=transpose_y) - - -# cholesky -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def cholesky(x, /, *, upper=False, name=None): - return ivy.cholesky(x, upper=upper) - - -# cholesky_solve -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def cholesky_solve(x, y, /, *, upper=False, name=None): - if upper: - y = ivy.matrix_transpose(y) - Y = ivy.solve(y, x) - return ivy.solve(ivy.matrix_transpose(y), Y) - - -# cond -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def cond(x, p=None, name=None): - ret = ivy.cond(x, p=p, out=name) - if ret.shape == (): - ret = ret.reshape((1,)) - return ret - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def cross(x, y, /, *, axis=9, name=None): - x, y = promote_types_of_paddle_inputs(x, y) - return ivy.cross(x, y, axis=axis) - - -@with_supported_dtypes({"2.4.1 and above": ("float64", "float32")}, "paddle") -@to_ivy_arrays_and_back -def 
dist(x, y, p=2): - ret = ivy.vector_norm(ivy.subtract(x, y), ord=p) - return ivy.reshape(ret, (1,)) - - -# dot -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def dot(x, y, name=None): - x, y = promote_types_of_paddle_inputs(x, y) - out = ivy.multiply(x, y) - return ivy.sum(out, axis=ivy.get_num_dims(x) - 1, keepdims=False) - - -# eig -@to_ivy_arrays_and_back -def eig(x, name=None): - return ivy.eig(x) - - -# eigh -@to_ivy_arrays_and_back -def eigh(x, UPLO="L", name=None): - return ivy.eigh(x, UPLO=UPLO) - - -# eigvals -@to_ivy_arrays_and_back -def eigvals(x, name=None): - return ivy.eigvals(x) - - -# eigvalsh -@to_ivy_arrays_and_back -def eigvalsh(x, UPLO="L", name=None): - return ivy.eigvalsh(x, UPLO=UPLO) - - -# matmul -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def matmul(x, y, transpose_x=False, transpose_y=False, name=None): - x, y = promote_types_of_paddle_inputs(x, y) - return ivy.matmul(x, y, transpose_a=transpose_x, transpose_b=transpose_y) - - -# matrix_power -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def matrix_power(x, n, name=None): - return ivy.matrix_power(x, n) - - -# norm -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def norm(x, p="fro", axis=None, keepdim=False, name=None): - if axis is None and p is not None: - if p == "fro": - p = 2 - ret = ivy.vector_norm(x.flatten(), ord=p, axis=-1) - if keepdim: - ret = ret.reshape([1] * len(x.shape)) - return ret - - if isinstance(axis, tuple): - axis = list(axis) - if isinstance(axis, list) and len(axis) == 1: - axis = axis[0] - - if isinstance(axis, int): - if p == "fro": - p = 2 - if p in [0, 1, 2, ivy.inf, -ivy.inf]: - ret = ivy.vector_norm(x, ord=p, axis=axis, keepdims=keepdim) - elif isinstance(p, (int, float)): - ret = ivy.pow( - ivy.sum(ivy.pow(ivy.abs(x), p), axis=axis, keepdims=keepdim), - float(1.0 / p), - ) - - elif isinstance(axis, list) and len(axis) == 2: - if p == 0: - raise ValueError - elif p == 1: - ret = ivy.sum(ivy.abs(x), axis=axis, keepdims=keepdim) - elif p == 2 or p == "fro": - ret = ivy.matrix_norm(x, ord="fro", axis=axis, keepdims=keepdim) - elif p == ivy.inf: - ret = ivy.max(ivy.abs(x), axis=axis, keepdims=keepdim) - elif p == -ivy.inf: - ret = ivy.min(ivy.abs(x), axis=axis, keepdims=keepdim) - elif isinstance(p, (int, float)) and p > 0: - ret = ivy.pow( - ivy.sum(ivy.pow(ivy.abs(x), p), axis=axis, keepdims=keepdim), - float(1.0 / p), - ) - else: - raise ValueError - - else: - raise ValueError - - return ret - - -# pinv -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def pinv(x, rcond=1e-15, hermitian=False, name=None): - # TODO: Add hermitian functionality - return ivy.pinv(x, rtol=rcond) - - -# qr -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def qr(x, mode="reduced", name=None): - return ivy.qr(x, mode=mode) - - -# solve -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def solve(x1, x2, name=None): - return ivy.solve(x1, x2) - - -# transpose -@with_unsupported_dtypes({"2.5.1 and below": ("uint8", "int8", "int16")}, "paddle") -@to_ivy_arrays_and_back -def transpose(x, perm, name=None): - return ivy.permute_dims(x, axes=perm) +# local +from ..linalg import * # noqa: F401 diff --git 
a/ivy/functional/frontends/paddle/tensor/logic.py b/ivy/functional/frontends/paddle/tensor/logic.py index 1b2ea90fa4b06..2b61d7f97cd96 100644 --- a/ivy/functional/frontends/paddle/tensor/logic.py +++ b/ivy/functional/frontends/paddle/tensor/logic.py @@ -1,286 +1,2 @@ -# global -import ivy -import ivy.functional.frontends.paddle as paddle -from ivy.func_wrapper import ( - with_unsupported_dtypes, - handle_out_argument, - with_supported_dtypes, -) -from ivy.functional.frontends.paddle.func_wrapper import ( - to_ivy_arrays_and_back, -) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float32", - "float64", - "bool", - "uint8", - "int8", - "int16", - "int32", - "int64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -@handle_out_argument -def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): - ret = ivy.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) - return paddle.to_tensor([ret]) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "uint8", - "int8", - "int16", - "int32", - "int64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -@handle_out_argument -def bitwise_and(x, y, /, *, name=None, out=None): - return ivy.bitwise_and(x, y, out=out) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "uint8", - "int8", - "int16", - "int32", - "int64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -@handle_out_argument -def bitwise_not(x, out=None, name=None): - return ivy.bitwise_invert(x, out=out) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "uint8", - "int8", - "int16", - "int32", - "int64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -@handle_out_argument -def bitwise_or(x, y, name=None, out=None): - return ivy.bitwise_or(x, y, out=out) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "uint8", - "int8", - "int16", - "int32", - "int64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -@handle_out_argument -def bitwise_xor(x, y, /, *, name=None, out=None): - return ivy.bitwise_xor(x, y, out=out) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("uint8", "int8", "int16", "complex64", "complex128")}, "paddle" -) -@to_ivy_arrays_and_back -def equal(x, y, /, *, name=None): - return ivy.equal(x, y) - - -@with_unsupported_dtypes( - { - "2.5.1 and below": ( - "uint8", - "int8", - "int16", - "float16", - "complex64", - "complex128", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -def equal_all(x, y, /, *, name=None): - return paddle.to_tensor([ivy.array_equal(x, y)]) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "complex64", "complex128")}, - "paddle", -) -@to_ivy_arrays_and_back -def greater_equal(x, y, /, *, name=None): - return ivy.greater_equal(x, y) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "complex64", "complex128")}, - "paddle", -) -@to_ivy_arrays_and_back -def greater_than(x, y, /, *, name=None): - return ivy.greater(x, y) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("uint8", "int8", "int16", "complex64", "complex128")}, "paddle" -) -@to_ivy_arrays_and_back -def is_empty(x, name=None): - return ivy.is_empty(x) - - -@to_ivy_arrays_and_back -def is_tensor(x): - return ivy.is_array(x) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float32", - "float64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): - return ivy.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) - - 
-@with_unsupported_dtypes( - {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "complex64", "complex128")}, - "paddle", -) -@to_ivy_arrays_and_back -def less_equal(x, y, /, *, name=None): - return ivy.less_equal(x, y) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "complex64", "complex128")}, - "paddle", -) -@to_ivy_arrays_and_back -def less_than(x, y, /, *, name=None): - return ivy.less(x, y) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -@handle_out_argument -def logical_and(x, y, /, *, name=None, out=None): - return ivy.logical_and(x, y, out=out) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -@handle_out_argument -def logical_not(x, /, *, name=None, out=None): - return ivy.logical_not(x, out=out) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -@handle_out_argument -def logical_or(x, y, /, *, name=None, out=None): - return ivy.logical_or(x, y, out=out) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -@handle_out_argument -def logical_xor(x, y, /, *, name=None, out=None): - return ivy.logical_xor(x, y, out=out) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("uint8", "int8", "int16", "complex64", "complex128")}, "paddle" -) -@to_ivy_arrays_and_back -def not_equal(x, y, /, *, name=None): - return ivy.not_equal(x, y) +# local +from ..logic import * # noqa: F401 diff --git a/ivy/functional/frontends/paddle/tensor/manipulation.py b/ivy/functional/frontends/paddle/tensor/manipulation.py index ae2259b53edac..35ce51915f140 100644 --- a/ivy/functional/frontends/paddle/tensor/manipulation.py +++ b/ivy/functional/frontends/paddle/tensor/manipulation.py @@ -1,85 +1,14 @@ -# global +# local +from ..manipulation import * # noqa: F401 import ivy from ivy.functional.frontends.paddle.func_wrapper import ( to_ivy_arrays_and_back, ) -from ivy.func_wrapper import ( - with_unsupported_dtypes, - with_supported_dtypes, - with_supported_device_and_dtypes, -) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def abs(x, name=None): - return ivy.abs(x) - - -@with_supported_dtypes( - {"2.5.1 and below": ("bool", "float32", "float64", "int32", "int64")}, - "paddle", -) -@to_ivy_arrays_and_back -def broadcast_to(x, shape, name=None): - return ivy.broadcast_to(x, shape) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "float16", - "float32", - "float64", - "int32", - "int64", - "uint8", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -def cast(x, dtype): - return ivy.astype(x, dtype) - - -@with_unsupported_dtypes({"2.5.1 and below": ("int8", "int16")}, "paddle") -@to_ivy_arrays_and_back -def concat(x, axis, name=None): - return ivy.concat(x, axis=axis) - - -@with_supported_dtypes( - {"2.5.1 and below": ("bool", "float32", "float64", "int32", "int64")}, - "paddle", -) -@to_ivy_arrays_and_back -def expand(x, shape, name=None): - return ivy.expand(x, shape) +from ivy.func_wrapper import with_unsupported_dtypes - -@with_unsupported_dtypes( - {"2.5.1 
and below": ("int8", "uint8", "int16", "float16")}, - "paddle", -) -@to_ivy_arrays_and_back -def flip(x, axis, name=None): - return ivy.flip(x, axis=axis) - - -@with_supported_dtypes( - {"2.5.1 and below": ("bool", "float32", "float64", "int32", "int64")}, - "paddle", -) -@to_ivy_arrays_and_back -def gather(params, indices, axis=-1, batch_dims=0, name=None): - return ivy.gather(params, indices, axis=axis, batch_dims=batch_dims) - - -@to_ivy_arrays_and_back -def reshape(x, shape): - return ivy.reshape(x, shape) +# NOTE: +# Only inplace functions are to be added in this file. +# Please add non-inplace counterparts to `/frontends/paddle/manipulation.py`. @with_unsupported_dtypes( @@ -91,96 +20,3 @@ def reshape_(x, shape): ret = ivy.reshape(x, shape) ivy.inplace_update(x, ret) return x - - -@with_supported_dtypes( - { - "2.5.0 and below": ( - "float32", - "float64", - "int32", - "int64", - "complex64", - "complex128", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -def roll(x, shifts, axis=None, name=None): - return ivy.roll(x, shifts, axis=axis) - - -@with_supported_device_and_dtypes( - { - "2.5.1 and above": { - "cpu": ( - "bool", - "int32", - "int64", - "float32", - "float64", - ), - "gpu": ("float16",), - }, - }, - "paddle", -) -@to_ivy_arrays_and_back -def rot90(x, k=1, axes=(0, 1), name=None): - return ivy.rot90(x, k=k, axes=axes) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("int16", "complex64", "complex128")}, - "paddle", -) -@to_ivy_arrays_and_back -def split(x, num_or_sections, axis=0, name=None): - return ivy.split(x, num_or_size_splits=num_or_sections, axis=axis) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("float16", "bfloat16", "int8", "int16")}, - "paddle", -) -@to_ivy_arrays_and_back -def squeeze(x, axis=None, name=None): - return ivy.squeeze(x, axis=axis) - - -@to_ivy_arrays_and_back -def stack(x, axis=0, name=None): - return ivy.stack(x, axis=axis) - - -def take_along_axis(arr, indices, axis): - return ivy.take_along_axis(arr, indices, axis) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("int8", "uint8", "int16", "float16")}, - "paddle", -) -@to_ivy_arrays_and_back -def tile(x, repeat_times, name=None): - return ivy.tile(x, repeats=repeat_times) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float32", - "float64", - "int32", - "int64", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -def unstack(x, axis=0, name=None): - return ivy.unstack(x, axis=axis) - - -absolute = abs diff --git a/ivy/functional/frontends/paddle/tensor/math.py b/ivy/functional/frontends/paddle/tensor/math.py index f28334cfcefe8..0f4695a7c8e14 100644 --- a/ivy/functional/frontends/paddle/tensor/math.py +++ b/ivy/functional/frontends/paddle/tensor/math.py @@ -1,119 +1,12 @@ -# global +# local +from ..math import * # noqa: F401 import ivy from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes from ivy.functional.frontends.paddle.func_wrapper import to_ivy_arrays_and_back - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def abs(x, name=None): - return ivy.abs(x) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def acos(x, name=None): - return ivy.acos(x) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def acosh(x, name=None): - return ivy.acosh(x) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("bool", "unsigned", "int8", "float16", "bfloat16")}, "paddle" -) 
-@to_ivy_arrays_and_back -def add(x, y, name=None): - return ivy.add(x, y) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def addmm(input, x, y, beta=1.0, alpha=1.0, name=None): - value = alpha * ivy.matmul(x, y) + (beta * input) - return value - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def amax(x, axis=None, keepdims=False): - if axis is None: - return ivy.max(x) - if isinstance(axis, int): - axis = [axis] - for i in range(len(axis)): - if axis[i] < 0: - axis[i] += x.ndim - for i in axis: - if i < 0 or i >= x.ndim: - raise ValueError("axis {} is out of range [-{}:{}]".format(i, 0, x.ndim)) - return ivy.max(x, axis=axis, keepdims=keepdims) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def amin(x, axis=None, keepdim=False, name=None): - return ivy.min(x, axis=axis, keepdims=keepdim) - - -@with_supported_dtypes( - {"2.5.1 and below": ("complex64", "complex128", "float32", "float64")}, - "paddle", -) -@to_ivy_arrays_and_back -def angle(x, name=None): - return ivy.angle(x) - - -@with_supported_dtypes({"2.5.0 and below": "bool"}, "paddle") -@to_ivy_arrays_and_back -def any(x, axis=None, keepdim=False, name=None): - return ivy.any(x, axis=axis, keepdims=keepdim) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def asin(x, name=None): - return ivy.asin(x) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def asinh(x, name=None): - return ivy.asinh(x) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def atan(x, name=None): - return ivy.atan(x) - - -@with_unsupported_dtypes({"2.4.2 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def atan2(x, y, name=None): - return ivy.atan2(x, y) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def atanh(x, name=None): - return ivy.atanh(x) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def ceil(x, name=None): - return ivy.ceil(x) +# NOTE: +# Only inplace functions are to be added in this file. +# Please add non-inplace counterparts to `/frontends/paddle/math.py`. 
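The NOTE above fixes the convention for this file: an inplace frontend computes its result with the corresponding out-of-place op and then writes it back into the input's buffer, exactly as the retained `ceil_`, `exp_` and friends below do. A minimal sketch of the pattern, using a hypothetical `cos_` that is not part of this patch:

@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle")
@to_ivy_arrays_and_back
def cos_(x, name=None):
    # delegate to the non-inplace op, then update x's buffer in place
    return ivy.inplace_update(x, ivy.cos(x))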
@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") @@ -122,313 +15,16 @@ def ceil_(x, name=None): return ivy.ceil(x, out=x) -@with_unsupported_dtypes({"2.4.2 and below": ("int16", "float16")}, "paddle") -@to_ivy_arrays_and_back -def conj(x, name=None): - return ivy.conj(x) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def cos(x, name=None): - return ivy.cos(x) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def cosh(x, name=None): - return ivy.cosh(x) - - -@with_supported_dtypes( - { - "2.5.1 and below": ( - "int32", - "int64", - "float32", - "float64", - "complex64", - "complex128", - ) - }, - "paddle", -) -@to_ivy_arrays_and_back -def cumprod(x, dim=None, dtype=None, name=None): - return ivy.cumprod(x, axis=dim, dtype=dtype) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def deg2rad(x, name=None): - return ivy.deg2rad(x) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def diff(x, n=1, axis=-1, prepend=None, append=None, name=None): - return ivy.diff(x, n=n, axis=axis, prepend=prepend, append=append) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def digamma(x, name=None): - digamma_fun = ivy.digamma - return ivy.array(digamma_fun(x), dtype=x.dtype) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def divide(x, y, name=None): - return ivy.divide(x, y) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def erf(x, name=None): - return ivy.erf(x) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def exp(x, name=None): - return ivy.exp(x) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") @to_ivy_arrays_and_back def exp_(x, name=None): return ivy.inplace_update(x, exp(x)) -@with_supported_dtypes({"2.5.1 and below": ("float16", "float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def expm1(x, name=None): - return ivy.expm1(x) - - -@with_supported_dtypes( - {"2.5.1 and below": ("bfloat16", "float32", "float64")}, "paddle" -) -@to_ivy_arrays_and_back -def floor(x, name=None): - return ivy.floor(x) - - -@with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") -@to_ivy_arrays_and_back -def fmax(x, y, name=None): - return ivy.fmax(x, y) - - -@with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") -@to_ivy_arrays_and_back -def fmin(x, y, name=None): - return ivy.fmin(x, y) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def frac(x, name=None): - y = ivy.trunc(x) - return ivy.subtract(x, y) - - -@with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") -@to_ivy_arrays_and_back -def gcd(x, y, name=None): - return ivy.gcd(x, y) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def heaviside(x, y, name=None): - return ivy.heaviside(x, y) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def inner(x, y, name=None): - result = ivy.inner(x, y) - if (x.shape 
== () and y.shape == (1,)) or (x.shape == (1,) and y.shape == ()): - result = result.reshape((1,)) - elif x.shape == (1,) and y.shape == (1,): - result = result.reshape((1,)) - return result - - -@with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def isfinite(x, name=None): - return ivy.isfinite(x) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def isinf(x, name=None): - return ivy.isinf(x) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def isnan(x, name=None): - return ivy.isnan(x) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def kron(x, y, name=None): - return ivy.kron(x, y) - - -@with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") -@to_ivy_arrays_and_back -def lcm(x, y, name=None): - return ivy.lcm(x, y) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def lerp(x, y, weight, name=None): - return ivy.lerp(x, y, weight) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def lgamma(x, name=None): - return ivy.lgamma(x) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def log(x, name=None): - return ivy.log(x) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def log1p(x, name=None): - return ivy.log1p(x) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def log2(x, name=None): - return ivy.log2(x) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def logit(x, eps=None, name=None): - return ivy.logit(x, eps=eps) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def max(x, axis=None, keepdim=False, name=None): - return ivy.max(x, axis=axis, keepdims=keepdim) - - -# maximum -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def maximum(x, y, name=None): - return ivy.maximum(x, y) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def min(x, axis=None, keepdim=False, name=None): - return ivy.min(x, axis=axis, keepdims=keepdim) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def minimum(x, y, name=None): - return ivy.minimum(x, y) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def mm(input, mat2, name=None): - return ivy.matmul(input, mat2) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def multiply(x, y, name=None): - return ivy.multiply(x, y) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def nansum(x, axis=None, dtype=None, name=None): - return ivy.nansum(x, axis=axis, dtype=dtype) - - -@with_supported_dtypes( - 
{"2.5.1 and below": ("float32", "float64", "int8", "int16", "int32", "int64")}, - "paddle", -) -@to_ivy_arrays_and_back -def neg(x, name=None): - return ivy.negative(x) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") @to_ivy_arrays_and_back -def outer(x, y, name=None): - return ivy.outer(x, y) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def pow(x, y, name=None): - return ivy.pow(x, y) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def prod(x, axis=None, keepdim=False, dtype=None, name=None): - return ivy.prod(x, axis=axis, keepdims=keepdim, dtype=dtype) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def rad2deg(x, name=None): - return ivy.rad2deg(x) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def reciprocal(x, name=None): - return ivy.reciprocal(x) +def lerp_(x, y, weight, name=None): + return ivy.inplace_update(x, lerp(x, y, weight)) @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") @@ -437,140 +33,18 @@ def reciprocal_(x, name=None): return ivy.inplace_update(x, reciprocal(x)) -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def remainder(x, y, name=None): - return ivy.remainder(x, y) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def round(x, name=None): - return ivy.round(x) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") @to_ivy_arrays_and_back def round_(x, name=None): return ivy.inplace_update(x, round(x)) -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def rsqrt(x, name=None): - return 1 / ivy.sqrt(x) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") def rsqrt_(x, name=None): return ivy.inplace_update(x, reciprocal(sqrt(x))) -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def sgn(x, name=None): - return ivy.sign(x, np_variant=True) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def sign(x, name=None): - return ivy.sign(x, np_variant=False) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def sin(x, name=None): - return ivy.sin(x) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def sinh(x, name=None): - return ivy.sinh(x) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def sqrt(x, name=None): - return ivy.sqrt(x) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") @to_ivy_arrays_and_back def sqrt_(x, name=None): return ivy.inplace_update(x, sqrt(x)) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def square(x, name=None): - return ivy.square(x) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def floor_divide(x, y, name=None): - return ivy.floor_divide(x, y) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") 
-@to_ivy_arrays_and_back -def stanh(x, scale_a=0.67, scale_b=1.7159, name=None): - # TODO this function will be simplified as soon as the ivy.stanh(x,a,b) is added - exp_ax = ivy.exp(ivy.multiply(scale_a, x)) - exp_minus_ax = ivy.exp(ivy.multiply(-scale_a, x)) - numerator = ivy.subtract(exp_ax, exp_minus_ax) - denominator = ivy.add(exp_ax, exp_minus_ax) - ret = ivy.multiply(scale_b, ivy.divide(numerator, denominator)) - return ret - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def subtract(x, y, name=None): - return ivy.subtract(x, y) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int6")}, "paddle" -) -@to_ivy_arrays_and_back -def take( - x, - index, - mode="raise", - name=None, -): - if mode not in ["raise", "wrap", "clip"]: - raise ValueError( - "'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {}." - .format(mode) - ) - x = ivy.reshape(x, (-1,)) - if mode == "clip": - index = ivy.clip(index, 0, x.shape[-1] - 1) - elif mode == "wrap": - index = ivy.where(index < 0, index % x.shape[-1], index) - index = ivy.where(index >= x.shape[-1], index % x.shape[-1], index) - return ivy.gather(x, index, axis=0) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def tan(x, name=None): - return ivy.tan(x) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def tanh(x, name=None): - return ivy.tanh(x) - - -@with_supported_dtypes( - {"2.4.2 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def trunc(x, name=None): - return ivy.trunc(x) diff --git a/ivy/functional/frontends/paddle/tensor/random.py b/ivy/functional/frontends/paddle/tensor/random.py index b69d73d8a5e25..ea6bd38157195 100644 --- a/ivy/functional/frontends/paddle/tensor/random.py +++ b/ivy/functional/frontends/paddle/tensor/random.py @@ -1,100 +1,14 @@ # global +from ..random import * # noqa: F401 import ivy from ivy.func_wrapper import with_supported_dtypes -from ivy.func_wrapper import with_supported_device_and_dtypes, with_unsupported_dtypes from ivy.functional.frontends.paddle.func_wrapper import ( to_ivy_arrays_and_back, ) - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, - "paddle", -) -@to_ivy_arrays_and_back -def normal(mean=0.0, std=1.0, shape=None, name=None): - return ivy.random_normal(mean=mean, std=std, shape=shape) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, - "paddle", -) -@to_ivy_arrays_and_back -def poisson(x, name=None): - return ivy.poisson(x, shape=None, device=None, dtype=None, seed=None, out=None) - - -@with_supported_device_and_dtypes( - { - "2.5.1 and above": { - "cpu": ( - "bfloat16", - "float32", - "float64", - ), - "gpu": ( - "bfloat16", - "float16", - "float32", - "float64", - ), - }, - "2.4.2 and below": { - "cpu": ( - "float32", - "float64", - ), - "gpu": ( - "float16", - "float32", - "float64", - ), - }, - }, - "paddle", -) -@to_ivy_arrays_and_back -def rand(shape, dtype=None, name=None): - return ivy.random_uniform(low=0.0, high=1.0, shape=shape, dtype=dtype, seed=None) - - -@to_ivy_arrays_and_back -def randint(low=0, high=None, shape=[1], dtype=None, name=None): - return ivy.randint(low, high, shape=shape, dtype=dtype) - - -@with_unsupported_dtypes( - {"2.5.1 and below": ("int16", "float16", "bfloat16", "uint8")}, - "paddle", -) -@to_ivy_arrays_and_back -def randint_like(x, low=0, 
high=None, dtype=None, name=None): - if high is None: - high = low - low = 0 - if high <= 0: - raise ivy.exceptions.IvyError( - "If high is None, low must be greater than 0, but received low = 0." - ) - return ivy.randint(low, high, shape=x.shape, dtype=dtype, seed=None) - - -def randn(shape, dtype=None, name=None): - if dtype not in ["float32", "float64"]: - raise ivy.exceptions.IvyError( - "Unsupported dtype for randn, only float32 and float64 are supported, " - ) - return ivy.random_normal(shape=shape, dtype=dtype, seed=None) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, - "paddle", -) -@to_ivy_arrays_and_back -def standard_normal(shape, dtype=None, name=None): - return ivy.random_normal(mean=0, std=1, shape=shape, dtype=dtype) +# NOTE: +# Only inplace functions are to be added in this file. +# Please add non-inplace counterparts to `/frontends/paddle/random.py`. @with_supported_dtypes( @@ -102,8 +16,8 @@ def standard_normal(shape, dtype=None, name=None): "paddle", ) @to_ivy_arrays_and_back -def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): - return ivy.random_uniform(low=min, high=max, shape=shape, dtype=dtype, seed=seed) +def exponential_(x, lam=1.0, name=None): + return ivy.multiply(lam, ivy.exp(ivy.multiply(-lam, x))) @with_supported_dtypes( diff --git a/ivy/functional/frontends/paddle/tensor/search.py b/ivy/functional/frontends/paddle/tensor/search.py index 205f2b6d1e5b7..9fde09978f216 100644 --- a/ivy/functional/frontends/paddle/tensor/search.py +++ b/ivy/functional/frontends/paddle/tensor/search.py @@ -1,98 +1,2 @@ # global -import ivy -from ivy.func_wrapper import with_supported_dtypes -from ivy.functional.frontends.paddle.func_wrapper import ( - to_ivy_arrays_and_back, -) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int16", "int32", "int64", "uint8")}, - "paddle", -) -@to_ivy_arrays_and_back -def argmax(x, /, *, axis=None, keepdim=False, dtype="int64", name=None): - return ivy.argmax(x, axis=axis, keepdims=keepdim, dtype=dtype) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int16", "int32", "int64", "uint8")}, - "paddle", -) -@to_ivy_arrays_and_back -def argmin(x, /, *, axis=None, keepdim=False, dtype="int64", name=None): - return ivy.argmin(x, axis=axis, keepdims=keepdim, dtype=dtype) - - -@with_supported_dtypes( - {"2.4.2 and below": ("float32", "float64", "int16", "int32", "int64", "uint8")}, - "paddle", -) -@to_ivy_arrays_and_back -def argsort(x, /, *, axis=-1, descending=False, name=None): - return ivy.argsort(x, axis=axis, descending=descending) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, - "paddle", -) -@to_ivy_arrays_and_back -def masked_select(x, mask, name=None): - return ivy.flatten(x[mask]) - - -@with_supported_dtypes( - {"2.4.2 and below": ("float32", "float64", "int16", "int32", "int64", "uint8")}, - "paddle", -) -@to_ivy_arrays_and_back -def nonzero(input, /, *, as_tuple=False): - ret = ivy.nonzero(input) - if as_tuple is False: - ret = ivy.matrix_transpose(ivy.stack(ret)) - return ret - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, - "paddle", -) -@to_ivy_arrays_and_back -def searchsorted(sorted_sequence, values, out_int32=False, right=False, name=None): - if right: - side = "right" - else: - side = "left" - ret = ivy.searchsorted(sorted_sequence, values, side=side) - if out_int32: - ret = ivy.astype(ret, "int32") - return ret - - -@with_supported_dtypes( - 
{"2.5.1 and below": ("float32", "float64", "int32", "int64")}, - "paddle", -) -@to_ivy_arrays_and_back -def sort(x, /, *, axis=-1, descending=False, name=None): - return ivy.sort(x, axis=axis, descending=descending) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, - "paddle", -) -@to_ivy_arrays_and_back -def topk(x, k, axis=None, largest=True, sorted=True, name=None): - return ivy.top_k(x, k, axis=axis, largest=largest, sorted=sorted) - - -# where -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, - "paddle", -) -@to_ivy_arrays_and_back -def where(condition, x, y, name=None): - return ivy.where(condition, x, y) +from ..search import * # noqa: F401 diff --git a/ivy/functional/frontends/paddle/tensor/stat.py b/ivy/functional/frontends/paddle/tensor/stat.py index bfee9e9d65d2c..4e0771fbae0b7 100644 --- a/ivy/functional/frontends/paddle/tensor/stat.py +++ b/ivy/functional/frontends/paddle/tensor/stat.py @@ -1,77 +1,2 @@ # global -import ivy -from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes -from ivy.functional.frontends.paddle.func_wrapper import ( - to_ivy_arrays_and_back, -) - - -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") -@to_ivy_arrays_and_back -def mean(input, axis=None, keepdim=False, out=None): - ret = ivy.mean(input, axis=axis, keepdims=keepdim, out=out) - ret = ivy.expand_dims(ret, axis=-1) if ret.ndim == 0 else ret - return ret - - -@with_supported_dtypes( - {"2.5.1 and below": ("bool", "float16", "float32", "float64", "int32", "int64")}, - "paddle", -) -@to_ivy_arrays_and_back -def median(x, axis=None, keepdim=False, name=None): - x = ( - ivy.astype(x, ivy.float64) - if ivy.dtype(x) == "float64" - else ivy.astype(x, ivy.float32) - ) - return ivy.median(x, axis=axis, keepdims=keepdim) - - -@with_supported_dtypes( - {"2.5.0 and below": ("float16", "float32", "float64", "uint16")}, - "paddle", -) -@to_ivy_arrays_and_back -def nanmedian(x, axis=None, keepdim=True, name=None): - x = ( - ivy.astype(x, ivy.float64) - if ivy.dtype(x) == "float64" - else ivy.astype(x, ivy.float32) - ) - return ivy.median(x, axis=axis, keepdims=keepdim) - - -@with_unsupported_dtypes({"2.5.1 and below": ("complex", "int8")}, "paddle") -@to_ivy_arrays_and_back -def numel(x, name=None): - prod = ivy.prod(x.size, dtype=ivy.int64) - try: - length = len(x) - except (ValueError, TypeError): - length = 1 # if 0 dimensional tensor with 1 element - return ivy.array(prod if prod > 0 else ivy.array(length, dtype=ivy.int64)) - - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "uint16")}, - "paddle", -) -@to_ivy_arrays_and_back -def std(x, axis=None, unbiased=True, keepdim=False, name=None): - x = ( - ivy.astype(x, ivy.float64) - if ivy.dtype(x) == "float64" - else ivy.astype(x, ivy.float32) - ) - return ivy.std(x, axis=axis, correction=int(unbiased), keepdims=keepdim) - - -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") -@to_ivy_arrays_and_back -def var(x, axis=None, unbiased=True, keepdim=False, name=None): - if unbiased: - correction = 1 - else: - correction = 0 - return ivy.var(x, axis=axis, correction=correction, keepdims=keepdim) +from ..stat import * # noqa: F401 diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py index 1c3e399aef9ed..ad68778a1c0fe 100644 --- a/ivy/functional/frontends/paddle/tensor/tensor.py +++ 
b/ivy/functional/frontends/paddle/tensor/tensor.py @@ -141,6 +141,10 @@ def sin(self, name=None): def sinh(self, name=None): return paddle_frontend.Tensor(ivy.sinh(self._ivy_array)) + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def lerp(self, y, weight, name=None): + return paddle_frontend.lerp(self, y, weight) + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") def lerp_(self, y, weight, name=None): self.ivy_array = paddle_frontend.lerp(self, y, weight).ivy_array @@ -694,6 +698,12 @@ def tolist(self): def min(self, axis=None, keepdim=False, name=None): return ivy.min(self._ivy_array, axis=axis, keepdims=keepdim) + @with_supported_dtypes( + {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, "paddle" + ) + def pow(self, y, name=None): + return paddle_frontend.Tensor(ivy.pow(self._ivy_array, _to_ivy_array(y))) + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") def atan(self, name=None): return ivy.atan(self._ivy_array) @@ -714,6 +724,18 @@ def std(self, axis=None, unbiased=True, keepdim=False, name=None): def trunc(self, name=None): return paddle_frontend.Tensor(ivy.trunc(self._ivy_array)) + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def stanh(self, scale_a=0.67, scale_b=1.7159, name=None): + return paddle_frontend.stanh(self, scale_a=scale_a, scale_b=scale_b) + + @with_supported_dtypes( + {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, "paddle" + ) + def trace(self, offset=0, axis1=0, axis2=1, name=None): + return paddle_frontend.Tensor( + ivy.trace(self._ivy_array, offset=offset, axis1=axis1, axis2=axis2) + ) + @with_supported_dtypes( {"2.5.1 and below": ("float32", "float64", "int16", "int32", "int64", "uint8")}, "paddle", diff --git a/ivy/functional/frontends/sklearn/metrics/_classification.py b/ivy/functional/frontends/sklearn/metrics/_classification.py index e6679505631b5..b5fece615691b 100644 --- a/ivy/functional/frontends/sklearn/metrics/_classification.py +++ b/ivy/functional/frontends/sklearn/metrics/_classification.py @@ -1,6 +1,7 @@ import ivy from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back from sklearn.utils.multiclass import type_of_target +from ivy.utils.exceptions import IvyValueError @to_ivy_arrays_and_back @@ -17,3 +18,28 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None): ret = ret / y_true.shape[0] ret = ret.astype("float64") return ret + + +@to_ivy_arrays_and_back +def precision_score(y_true, y_pred, *, average="binary", sample_weight=None): + # TODO: implement sample_weight + y_type = type_of_target(y_true) + if y_type.startswith("multilabel"): + true_positives = ivy.count_nonzero( + ivy.equal(y_true, y_pred).astype("int64"), axis=0 + ) + all_positives = ivy.count_nonzero(y_pred, axis=0) + else: + true_positives = ivy.count_nonzero( + ivy.equal(y_true, y_pred).astype("int64"), axis=1 + ) + all_positives = ivy.count_nonzero(y_pred) + if average == "binary": + precision = true_positives / all_positives + elif average == "micro": + precision = ivy.sum(true_positives) / ivy.sum(all_positives) + elif average == "macro": + precision = ivy.mean(true_positives / all_positives) + else: + raise IvyValueError("Invalid value for 'average'.") + return precision diff --git a/ivy/functional/frontends/tensorflow/func_wrapper.py b/ivy/functional/frontends/tensorflow/func_wrapper.py index ca14f5fc1845a..06fee643e81e0 100644 --- 
a/ivy/functional/frontends/tensorflow/func_wrapper.py +++ b/ivy/functional/frontends/tensorflow/func_wrapper.py @@ -217,7 +217,7 @@ def _outputs_to_frontend_arrays_tf(*args, **kwargs): # convert all arrays in the return to `frontend.Tensorflow.tensor` instances return ivy.nested_map( - ret, _ivy_array_to_tensorflow, include_derived={tuple: True} + ret, _ivy_array_to_tensorflow, include_derived={"tuple": True} ) _outputs_to_frontend_arrays_tf.outputs_to_frontend_arrays = True diff --git a/ivy/functional/frontends/tensorflow/general_functions.py b/ivy/functional/frontends/tensorflow/general_functions.py index 11fa0a9437816..8cc476dfe4a28 100644 --- a/ivy/functional/frontends/tensorflow/general_functions.py +++ b/ivy/functional/frontends/tensorflow/general_functions.py @@ -285,6 +285,16 @@ def linspace(start, stop, num, name=None, axis=0): return ivy.linspace(start, stop, num, axis=axis) +@to_ivy_arrays_and_back +def meshgrid(*args, **kwargs): + sparse = False + indexing = "xy" + if "indexing" in kwargs: + indexing = kwargs["indexing"] + + return ivy.meshgrid(*args, sparse=sparse, indexing=indexing) + + @to_ivy_arrays_and_back def no_op(name=None): return diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py index 4283c2c151633..2779220159faf 100644 --- a/ivy/functional/frontends/tensorflow/math.py +++ b/ivy/functional/frontends/tensorflow/math.py @@ -226,6 +226,11 @@ def cumsum(x, axis, exclusive=False, reverse=False, name=None): ) +@to_ivy_arrays_and_back +def digamma(x, name=None): + return ivy.digamma(x) + + @to_ivy_arrays_and_back def divide(x, y, name=None): x, y = check_tensorflow_casting(x, y) diff --git a/ivy/functional/frontends/tensorflow/raw_ops.py b/ivy/functional/frontends/tensorflow/raw_ops.py index 3120d8395cab5..354afa54bdb71 100644 --- a/ivy/functional/frontends/tensorflow/raw_ops.py +++ b/ivy/functional/frontends/tensorflow/raw_ops.py @@ -49,6 +49,7 @@ Cosh = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.cosh)) Cumprod = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.cumprod)) Cumsum = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.cumsum)) +Digamma = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.digamma)) Div = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.divide)) Einsum = to_ivy_arrays_and_back( with_supported_dtypes( diff --git a/ivy/functional/frontends/torch/creation_ops.py b/ivy/functional/frontends/torch/creation_ops.py index ec2f79b4eb476..b275cfb9d5db8 100644 --- a/ivy/functional/frontends/torch/creation_ops.py +++ b/ivy/functional/frontends/torch/creation_ops.py @@ -5,6 +5,7 @@ to_ivy_shape, ) from ivy.func_wrapper import with_unsupported_dtypes +import ivy.functional.frontends.torch as torch_frontend @to_ivy_arrays_and_back @@ -46,6 +47,16 @@ def as_tensor( dtype=None, device=None, ): + if dtype is None: + if isinstance(data, int): + dtype = ivy.int64 + elif isinstance(data, float): + dtype = torch_frontend.get_default_dtype() + elif isinstance(data, (list, tuple)): + if all(isinstance(d, int) for d in data): + dtype = ivy.int64 + else: + dtype = torch_frontend.get_default_dtype() return ivy.asarray(data, dtype=dtype, device=device) diff --git a/ivy/functional/frontends/torch/func_wrapper.py b/ivy/functional/frontends/torch/func_wrapper.py index 84e16bb9d797d..da2fad9b18517 100644 --- a/ivy/functional/frontends/torch/func_wrapper.py +++ b/ivy/functional/frontends/torch/func_wrapper.py @@ -145,10 +145,10 @@ def _inputs_to_ivy_arrays_torch(*args, **kwargs): ) 
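    # note: the two hunks below make the same change as the tensorflow wrapper
    # above, passing include_derived keys as type-name strings ("tuple")
    # rather than type objects (tuple)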
# convert all input arrays to ivy.Array instances new_args = ivy.nested_map( - args, _to_ivy_array, include_derived={tuple: True}, shallow=False + args, _to_ivy_array, include_derived={"tuple": True}, shallow=False ) new_kwargs = ivy.nested_map( - kwargs, _to_ivy_array, include_derived={tuple: True}, shallow=False + kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False ) return fn(*new_args, **new_kwargs) @@ -202,7 +202,7 @@ def outputs_to_frontend_arrays_torch(*args, **kwargs): ret = _from_ivy_array_to_torch_frontend_tensor( ret, nested=True, - include_derived={tuple: True}, + include_derived={"tuple": True}, requires_grad=kwargs.get( "requires_grad", any( diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 46b8609b82012..1055938e53a4d 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -76,6 +76,14 @@ def eigh(a, /, UPLO="L", out=None): return ivy.eigh(a, UPLO=UPLO, out=out) +@with_supported_dtypes( + {"2.0.1 and below": ("float32", "float64", "complex32", "complex64", "complex128")}, + "torch", +) +def eigh(A, UPLO="L", *, out=None): + return ivy.eigh(A, UPLO=UPLO, out=out) + + @to_ivy_arrays_and_back @with_supported_dtypes( {"2.0.1 and below": ("float32", "float64", "complex32", "complex64")}, "torch" @@ -185,6 +193,26 @@ def multi_dot(tensors, *, out=None): return ivy.multi_dot(tensors, out=out) +@to_ivy_arrays_and_back +@with_supported_dtypes( + {"2.0.1 and below": ("float32", "float64", "complex64", "complex128")}, "torch" +) +def norm(input, ord=None, dim=None, keepdim=False, *, dtype=None, out=None): + if dim is None and (ord is not None): + if input.ndim == 1: + ret = ivy.vector_norm(input, axis=dim, keepdims=keepdim, ord=ord) + else: + ret = ivy.matrix_norm(input, keepdims=keepdim, ord=ord) + elif dim is None and ord is None: + input = ivy.flatten(input) + ret = ivy.vector_norm(input, axis=0, keepdims=keepdim, ord=2) + if isinstance(dim, int): + ret = ivy.vector_norm(input, axis=dim, keepdims=keepdim, ord=ord) + elif isinstance(dim, tuple): + ret = ivy.matrix_norm(input, axis=dim, keepdims=keepdim, ord=ord) + return ret + + @to_ivy_arrays_and_back @with_supported_dtypes( {"2.0.1 and below": ("float32", "float64", "complex32", "complex64")}, "torch" diff --git a/ivy/functional/frontends/torch/nn/functional/vision_functions.py b/ivy/functional/frontends/torch/nn/functional/vision_functions.py index 80f93ebbc52de..9d0aa88c30d1a 100644 --- a/ivy/functional/frontends/torch/nn/functional/vision_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/vision_functions.py @@ -3,11 +3,15 @@ # local import ivy -from ivy import with_unsupported_dtypes +from ivy import with_unsupported_dtypes, with_supported_dtypes from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back from ivy.utils.exceptions import IvyNotImplementedException +cubic_conv1 = lambda A, x: ((A + 2) * x - (A + 3)) * x * x + 1 +cubic_conv2 = lambda A, x: (((A * x) - (5 * A)) * x + (8 * A)) * x - (4 * A) + + # --- Helpers --- # # --------------- # @@ -78,6 +82,269 @@ def affine_grid(theta, size, align_corners=False): return grid.view((N, D, H, W, 3)) +def bicubic_interp(x, t, alpha=-0.75): + n, h, w = t.shape + coeffs = [] + coeffs.append(ivy.reshape(cubic_conv2(alpha, t + 1), (n, 1, h, w))) + coeffs.append(ivy.reshape(cubic_conv1(alpha, t), (n, 1, h, w))) + coeffs.append(ivy.reshape(cubic_conv1(alpha, 1 - t), (n, 1, h, w))) + coeffs.append(ivy.reshape(cubic_conv2(alpha, 2 - t), 
(n, 1, h, w))) + return x[0] * coeffs[0] + x[1] * coeffs[1] + x[2] * coeffs[2] + x[3] * coeffs[3] + + +@with_supported_dtypes({"2.0.1 and below": ("float32", "float64")}, "torch") +@to_ivy_arrays_and_back +def grid_sample( + input, grid, mode="bilinear", padding_mode="zeros", align_corners=False +): + input_clone = ivy.copy_array(input) + grid_clone = ivy.copy_array(grid) + + if ivy.get_num_dims(input_clone) == 4: # sample from 2D images + n, c, h, w = input_clone.shape + n, to_h, to_w, gc = grid_clone.shape + + # Un-normalize 2D grid + if align_corners: # to range[0, size - 1] + grid_clone[..., 0] = ((grid_clone[..., 0] + 1) / 2) * (w - 1) + grid_clone[..., 1] = ((grid_clone[..., 1] + 1) / 2) * (h - 1) + + elif not align_corners: # to range[0.5, size - 0.5] + grid_clone[..., 0] = ((grid_clone[..., 0] + 1) * w - 1) / 2 + grid_clone[..., 1] = ((grid_clone[..., 1] + 1) * h - 1) / 2 + + batch_coor = ivy.reshape(ivy.arange(n), (-1, 1)) + batch_coor = ivy.repeat(batch_coor, to_h * to_w, axis=1) + batch_coor = ivy.reshape(batch_coor, (n, to_h, to_w)) + padding = [(0, 0) for _ in range(2)] + [(4, 4) for _ in range(2)] + input_clone = ivy.pad(input_clone, padding, mode="constant", constant_values=0) + + if mode == "bicubic": + grid_floor = ivy.floor(grid_clone) + distance = grid_clone - grid_floor + + tx, ty = distance[..., 0], distance[..., 1] + + grid_floor -= 1 + grid_floor = [ + grid_sample_padding( + grid_floor + i, padding_mode, align_corners, borders=[w, h] + ) + for i in range(4) + ] + + w_cubic = [ + ivy.astype(grid_floor[i][..., 0] + 4, ivy.int64) for i in range(4) + ] + h_cubic = [ + ivy.astype(grid_floor[i][..., 1] + 4, ivy.int64) for i in range(4) + ] + + coeffs = [ + bicubic_interp( + [ + ivy.permute_dims( + input_clone[batch_coor, :, h_cubic[i], w_cubic[0]], + (0, 3, 1, 2), + ), + ivy.permute_dims( + input_clone[batch_coor, :, h_cubic[i], w_cubic[1]], + (0, 3, 1, 2), + ), + ivy.permute_dims( + input_clone[batch_coor, :, h_cubic[i], w_cubic[2]], + (0, 3, 1, 2), + ), + ivy.permute_dims( + input_clone[batch_coor, :, h_cubic[i], w_cubic[3]], + (0, 3, 1, 2), + ), + ], + tx, + ) + for i in range(4) + ] + return bicubic_interp(coeffs, ty) + + else: + grid_clone = grid_sample_padding( + grid_clone, padding_mode, align_corners, borders=[w, h] + ) + + if mode == "bilinear": + grid_clone += 4 + w_coor = ivy.reshape(grid_clone[..., 0], (n, to_h, to_w)) + h_coor = ivy.reshape(grid_clone[..., 1], (n, to_h, to_w)) + + w0 = ivy.astype(ivy.floor(w_coor), ivy.int64) + h0 = ivy.astype(ivy.floor(h_coor), ivy.int64) + w1 = w0 + 1 + h1 = h0 + 1 + + v00 = ivy.permute_dims(input_clone[batch_coor, :, h0, w0], (0, 3, 1, 2)) + v01 = ivy.permute_dims(input_clone[batch_coor, :, h0, w1], (0, 3, 1, 2)) + v10 = ivy.permute_dims(input_clone[batch_coor, :, h1, w0], (0, 3, 1, 2)) + v11 = ivy.permute_dims(input_clone[batch_coor, :, h1, w1], (0, 3, 1, 2)) + + alpha = ivy.reshape(w_coor - w0, (n, 1, to_h, to_w)) + beta = ivy.reshape(h_coor - h0, (n, 1, to_h, to_w)) + + alpha = ivy.astype(alpha, ivy.float32) + beta = ivy.astype(beta, ivy.float32) + + v0 = v00 * (1 - alpha) + v01 * alpha + v1 = v10 * (1 - alpha) + v11 * alpha + + return v0 * (1 - beta) + v1 * beta + + elif mode == "nearest": + w_coor = ivy.reshape(grid_clone[..., 0], (n, to_h, to_w)) + h_coor = ivy.reshape(grid_clone[..., 1], (n, to_h, to_w)) + + w_coor = ivy.astype(ivy.round(w_coor), ivy.int64) + 4 + h_coor = ivy.astype(ivy.round(h_coor), ivy.int64) + 4 + return ivy.permute_dims( + input_clone[batch_coor, :, h_coor, w_coor], (0, 3, 1, 2) + ) + + else: + 
raise ivy.exceptions.IvyError(f"Not supported mode {mode}") + + elif ivy.get_num_dims(input_clone) == 5: # sample from 3D images + n, c, d, h, w = input_clone.shape + n, to_d, to_h, to_w, gc = grid_clone.shape + + # Un-normalize 3D grid + if align_corners: # to range[0, size - 1] + grid_clone[..., 0] = ((grid_clone[..., 0] + 1) / 2) * (w - 1) + grid_clone[..., 1] = ((grid_clone[..., 1] + 1) / 2) * (h - 1) + grid_clone[..., 2] = ((grid_clone[..., 2] + 1) / 2) * (d - 1) + elif not align_corners: # to range[0.5, size - 0.5] + grid_clone[..., 0] = ((grid_clone[..., 0] + 1) * w - 1) / 2 + grid_clone[..., 1] = ((grid_clone[..., 1] + 1) * h - 1) / 2 + grid_clone[..., 2] = ((grid_clone[..., 2] + 1) * d - 1) / 2 + + batch_coor = ivy.reshape(ivy.arange(n), (-1, 1)) + batch_coor = ivy.repeat(batch_coor, to_d * to_h * to_w, axis=1) + batch_coor = ivy.reshape(batch_coor, (n, to_d, to_h, to_w)) + padding = [(0, 0) for _ in range(2)] + [(3, 3) for _ in range(3)] + input_clone = ivy.pad(input_clone, padding, mode="constant", constant_values=0) + + grid_clone = grid_sample_padding( + grid_clone, padding_mode, align_corners, borders=[w, h, d] + ) + + if mode == "bilinear": + grid_clone += 3 + w_coor = ivy.reshape(grid_clone[..., 0], (n, to_d, to_h, to_w)) + h_coor = ivy.reshape(grid_clone[..., 1], (n, to_d, to_h, to_w)) + d_coor = ivy.reshape(grid_clone[..., 2], (n, to_d, to_h, to_w)) + + w0 = ivy.astype(ivy.floor(w_coor), ivy.int64) + h0 = ivy.astype(ivy.floor(h_coor), ivy.int64) + d0 = ivy.astype(ivy.floor(d_coor), ivy.int64) + w1 = w0 + 1 + h1 = h0 + 1 + d1 = d0 + 1 + + v000 = ivy.permute_dims( + input_clone[batch_coor, :, d0, h0, w0], (0, 4, 1, 2, 3) + ) # tnw + v001 = ivy.permute_dims( + input_clone[batch_coor, :, d0, h0, w1], (0, 4, 1, 2, 3) + ) # tne + v010 = ivy.permute_dims( + input_clone[batch_coor, :, d0, h1, w0], (0, 4, 1, 2, 3) + ) # tsw + v011 = ivy.permute_dims( + input_clone[batch_coor, :, d0, h1, w1], (0, 4, 1, 2, 3) + ) # tse + v100 = ivy.permute_dims( + input_clone[batch_coor, :, d1, h0, w0], (0, 4, 1, 2, 3) + ) # bnw + v101 = ivy.permute_dims( + input_clone[batch_coor, :, d1, h0, w1], (0, 4, 1, 2, 3) + ) # bne + v110 = ivy.permute_dims( + input_clone[batch_coor, :, d1, h1, w0], (0, 4, 1, 2, 3) + ) # bsw + v111 = ivy.permute_dims( + input_clone[batch_coor, :, d1, h1, w1], (0, 4, 1, 2, 3) + ) # bse + + alpha = ivy.reshape(w_coor - w0, (n, 1, to_d, to_h, to_w)) + beta = ivy.reshape(h_coor - h0, (n, 1, to_d, to_h, to_w)) + gamma = ivy.reshape(d_coor - d0, (n, 1, to_d, to_h, to_w)) + + alpha = ivy.astype(alpha, ivy.float32) + beta = ivy.astype(beta, ivy.float32) + gamma = ivy.astype(gamma, ivy.float32) + + v = (alpha * beta * gamma) * v111 + v += ((1 - alpha) * beta * gamma) * v110 + v += (alpha * (1 - beta) * gamma) * v101 + v += ((1 - alpha) * (1 - beta) * gamma) * v100 + + v += (alpha * beta * (1 - gamma)) * v011 + v += ((1 - alpha) * beta * (1 - gamma)) * v010 + v += (alpha * (1 - beta) * (1 - gamma)) * v001 + v += ((1 - alpha) * (1 - beta) * (1 - gamma)) * v000 + return v + + elif mode == "nearest": + ceil_mask = grid_clone % 1 == 0.5 + grid_clone[ceil_mask] = ivy.astype( + ivy.ceil(grid_clone[ceil_mask]), ivy.int64 + ) + + w_coor = ivy.reshape(grid_clone[..., 0], (n, to_d, to_h, to_w)) + h_coor = ivy.reshape(grid_clone[..., 1], (n, to_d, to_h, to_w)) + d_coor = ivy.reshape(grid_clone[..., 2], (n, to_d, to_h, to_w)) + + w_coor = ivy.astype(ivy.round(w_coor), ivy.int64) + 3 + h_coor = ivy.astype(ivy.round(h_coor), ivy.int64) + 3 + d_coor = ivy.astype(ivy.round(d_coor), ivy.int64) + 3 + 
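+            # the +3 offset below re-indexes grid coordinates into the
+            # zero-padded input (three cells of padding per spatial side
+            # were added for the 5-D case above)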
return ivy.permute_dims(
+                input_clone[batch_coor, :, d_coor, h_coor, w_coor], (0, 4, 1, 2, 3)
+            )
+
+        elif mode == "bicubic":
+            raise ivy.exceptions.IvyError("Bicubic is not supported in 3D grid sampling")
+
+    else:
+        raise ivy.exceptions.IvyError(f"Unsupported input shape {input_clone.shape}")
+
+
+def grid_sample_padding(grid, padding_mode, align_corners, borders=None):
+    if padding_mode == "reflection":
+        if align_corners:
+            for idx, border in enumerate(borders):
+                grid[..., idx] = reflect(grid[..., idx], 0, 2 * (border - 1))
+                grid[..., idx] = ivy.clip(grid[..., idx], 0, border - 1)
+
+        else:
+            for idx, border in enumerate(borders):
+                grid[..., idx] = reflect(grid[..., idx], -1, 2 * border - 1)
+                grid[..., idx] = ivy.clip(grid[..., idx], 0, border - 1)
+
+    elif padding_mode == "border":
+        for idx, border in enumerate(borders):
+            grid[..., idx] = ivy.clip(grid[..., idx], 0, border - 1)
+
+    masks = []
+    for idx, border in enumerate(borders):
+        masks.append(ivy.bitwise_or(grid[..., idx] < -4, grid[..., idx] > border + 2))
+        borders[idx] += 1
+
+    zeros_mask = masks[0]
+    for i in range(1, len(borders)):
+        zeros_mask = ivy.bitwise_or(zeros_mask, masks[i])
+
+    if grid[zeros_mask].shape[0] > 0:
+        grid[zeros_mask] = ivy.array(borders)
+    return grid
+
+
 @with_unsupported_dtypes(
     {
         "2.0.1 and below": (
@@ -184,7 +451,7 @@ def interpolate(
     if (
         bool(antialias)
-        and not (mode in ["bilinear", "bicubic"])
+        and (mode not in ["bilinear", "bicubic"])
         and ivy.get_num_dims(input) == 4
     ):
         raise ivy.utils.exceptions.IvyException(
@@ -325,6 +592,18 @@ def pixel_unshuffle(input, downscale_factor):
     )
+
+
+def reflect(x, low2, high2):
+    min = low2 / 2
+    span = (high2 - low2) / 2
+    x = ivy.abs(x - min)
+    frac_in = ivy.abs(x / span)
+    extra = (frac_in - ivy.floor(frac_in)) * ivy.abs(span)
+    flips = ivy.floor(x / span)
+    x[flips % 2 == 0] = (extra + min)[flips % 2 == 0]
+    x[flips % 2 != 0] = (span - extra + min)[flips % 2 != 0]
+    return x
+
+
 @with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, "torch")
 @to_ivy_arrays_and_back
 def upsample(
diff --git a/ivy/functional/frontends/torch/pointwise_ops.py b/ivy/functional/frontends/torch/pointwise_ops.py
index 2cc35bf32960a..e04dddf76454c 100644
--- a/ivy/functional/frontends/torch/pointwise_ops.py
+++ b/ivy/functional/frontends/torch/pointwise_ops.py
@@ -384,6 +384,22 @@ def mul(input, other, *, out=None):
     return ivy.multiply(input, other, out=out)


+@to_ivy_arrays_and_back
+@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, "torch")
+def mvlgamma(input, p, *, out=None):
+    ivy.assertions.check_greater(
+        p, 1, allow_equal=True, message="p has to be greater than or equal to 1"
+    )
+    c = 0.25 * p * (p - 1) * ivy.log(ivy.pi, out=out)
+    b = 0.5 * ivy.arange((1 - p), 1, 1, dtype=input.dtype, device=input.device, out=out)
+    return (
+        ivy.sum(
+            ivy.lgamma(ivy.expand_dims(input, axis=-1) + b, out=out), axis=-1, out=out
+        )
+        + c
+    )
+
+
 @with_unsupported_dtypes({"2.0.1 and below": ("bfloat16",)}, "tensorflow")
 @to_ivy_arrays_and_back
 def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):
diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py
index ad750eb9aedb1..20c0c5326a8df 100644
--- a/ivy/functional/frontends/torch/tensor.py
+++ b/ivy/functional/frontends/torch/tensor.py
@@ -1833,6 +1833,27 @@ def char(self):
     def lcm(self, other, *, out=None):
         return torch_frontend.lcm(self, other, out=out)

+    @with_unsupported_dtypes(
+        {
+            "2.0.1 and below": (
+                "float16",
+                "bfloat16",
+                "float32",
+
"float64", + "complex", + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + ) + }, + "torch", + ) + def lcm_(self, other, *, out=None): + self.ivy_array = self.lcm(other, out=out).ivy_array + return self + @with_unsupported_dtypes( { "2.0.1 and below": ( @@ -1965,6 +1986,25 @@ def unique_consecutive(self, return_inverse, return_counts, dim): def cummax(self, dim): return torch_frontend.cummax(self, dim) + @with_unsupported_dtypes( + { + "2.0.1 and below": ( + "bfloat16", + "int8", + "uint8", + "uint32", + "uint16", + "uint64", + "int16", + "complex128", + "complex64", + ) + }, + "torch", + ) + def triu(self, diagonal=0): + return torch_frontend.triu(self, diagonal) + @with_unsupported_dtypes( {"2.0.1 and below": ("bfloat16",)}, "torch", diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index 8ec46b3f94ba6..1010b947a19b9 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -452,11 +452,13 @@ def sigmoid( @to_native_arrays_and_back @handle_array_function @handle_device_shifting +@handle_complex_input def softmax( x: Union[ivy.Array, ivy.NativeArray], /, *, axis: Optional[int] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, ) -> ivy.Array: """ @@ -468,6 +470,9 @@ def softmax( Input array. axis The dimension softmax would be performed on. The default is ``None``. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. diff --git a/ivy/functional/ivy/creation.py b/ivy/functional/ivy/creation.py index 23ac94a9d5e3d..eeb8758f43b70 100644 --- a/ivy/functional/ivy/creation.py +++ b/ivy/functional/ivy/creation.py @@ -39,11 +39,11 @@ # --------# -def asarray_handle_nestable(fn: Callable) -> Callable: +def _asarray_handle_nestable(fn: Callable) -> Callable: fn_name = fn.__name__ @functools.wraps(fn) - def _asarray_handle_nestable(*args, **kwargs): + def _asarray_handle_nestable_wrapper(*args, **kwargs): """ Call `fn` with the *nestable* property of the function correctly handled. This means mapping the function to the container leaves if any containers are passed @@ -71,8 +71,8 @@ def _asarray_handle_nestable(*args, **kwargs): # the passed arguments, returning an ivy or a native array. return fn(*args, **kwargs) - _asarray_handle_nestable.handle_nestable = True - return _asarray_handle_nestable + _asarray_handle_nestable_wrapper.handle_nestable = True + return _asarray_handle_nestable_wrapper def _ivy_to_native(x): @@ -126,9 +126,9 @@ def _remove_np_bfloat16(obj): return obj -def asarray_to_native_arrays_and_back(fn: Callable) -> Callable: +def _asarray_to_native_arrays_and_back(fn: Callable) -> Callable: @functools.wraps(fn) - def _asarray_to_native_arrays_and_back(*args, dtype=None, **kwargs): + def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): """ Wrap `fn` so that input arrays are all converted to `ivy.NativeArray` instances and return arrays are all converted to `ivy.Array` instances. 
@@ -147,12 +147,12 @@ def _asarray_to_native_arrays_and_back(*args, dtype=None, **kwargs): dtype = ivy.default_dtype(dtype=dtype, as_native=True) return to_ivy(fn(*new_args, dtype=dtype, **kwargs)) - return _asarray_to_native_arrays_and_back + return _asarray_to_native_arrays_and_back_wrapper -def asarray_infer_dtype(fn: Callable) -> Callable: +def _asarray_infer_dtype(fn: Callable) -> Callable: @functools.wraps(fn) - def _asarray_infer_dtype(*args, dtype=None, **kwargs): + def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): """ Determine the correct `dtype`, and then calls the function with the `dtype` passed explicitly. This wrapper is specifically for the backend implementations @@ -204,13 +204,13 @@ def _infer_dtype(obj): # call the function with dtype provided explicitly return fn(*args, dtype=dtype, **kwargs) - _asarray_infer_dtype.infer_dtype = True - return _asarray_infer_dtype + _asarray_infer_dtype_wrapper.infer_dtype = True + return _asarray_infer_dtype_wrapper -def asarray_infer_device(fn: Callable) -> Callable: +def _asarray_infer_device(fn: Callable) -> Callable: @functools.wraps(fn) - def _asarray_infer_device(*args, device=None, **kwargs): + def _asarray_infer_device_wrapper(*args, device=None, **kwargs): """ Determine the correct `device`, and then calls the function with the `device` passed explicitly. This wrapper is specifically for the backend implementations @@ -243,11 +243,11 @@ def _asarray_infer_device(*args, device=None, **kwargs): # call the function with device provided explicitly return fn(*args, device=device, **kwargs) - _asarray_infer_device.infer_device = True - return _asarray_infer_device + _asarray_infer_device_wrapper.infer_device = True + return _asarray_infer_device_wrapper -def asarray_inputs_to_native_shapes(fn: Callable) -> Callable: +def _asarray_inputs_to_native_shapes(fn: Callable) -> Callable: @functools.wraps(fn) def _inputs_to_native_shapes(*args, **kwargs): new_arg = _shape_to_native(args[0]) @@ -1961,7 +1961,7 @@ def one_hot( ret Tensor of zeros with the same shape and type as a, unless dtype provided which overrides. 
- + Examples -------- With :class:`ivy.Array` inputs: @@ -1993,11 +1993,11 @@ def one_hot( >>> z = x.one_hot(y) >>> print(z) { - a: ivy.array([[0., 1., 0., 0., 0.], + a: ivy.array([[0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]]), - b: ivy.array([[0., 0., 0., 1., 0.], + b: ivy.array([[0., 0., 0., 1., 0.], [0., 1., 0., 0., 0.]]), - c: ivy.array([[0., 0., 1., 0., 0.], + c: ivy.array([[0., 0., 1., 0., 0.], [0., 0., 0., 1., 0.]]) } diff --git a/ivy/functional/ivy/device.py b/ivy/functional/ivy/device.py index df875906cd6cd..dd35ffa45aa82 100644 --- a/ivy/functional/ivy/device.py +++ b/ivy/functional/ivy/device.py @@ -831,7 +831,7 @@ def default_device( return ivy.dev(item, as_native=as_native) global default_device_stack if not default_device_stack: - ret = "gpu:0" if ivy.gpu_is_available() else "cpu" + ret = "cpu" else: ret = default_device_stack[-1] if as_native: diff --git a/ivy/functional/ivy/elementwise.py b/ivy/functional/ivy/elementwise.py index 3d9eebc0aa71e..461b1adacc179 100644 --- a/ivy/functional/ivy/elementwise.py +++ b/ivy/functional/ivy/elementwise.py @@ -5189,8 +5189,8 @@ def positive( @handle_array_function @handle_device_shifting def pow( - x1: Union[float, ivy.Array, ivy.NativeArray], - x2: Union[float, ivy.Array, ivy.NativeArray], + x1: Union[ivy.Array, ivy.NativeArray], + x2: Union[int, float, ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None, @@ -5201,13 +5201,6 @@ def pow( (the exponent), where ``x2_i`` is the corresponding element of the input array ``x2``. - .. note:: - If both ``x1`` and ``x2`` have integer data types, the result of ``pow`` when - ``x2_i`` is negative (i.e., less than zero) is unspecified and thus - implementation-dependent. If ``x1`` has an integer data type and ``x2`` has a - floating-point data type, behavior is implementation-dependent (type promotion - between data type "kinds" (integer versus floating-point) is unspecified). 
- **Special cases** For floating-point operands, @@ -5328,6 +5321,13 @@ def pow( pow.unsupported_gradients = {"torch": ["float16"]} +def _complex_to_inf(exponent): + if exponent < 0: + return float('inf') + ivy.nan * 1j + else: + return -0 * 1j + + @handle_exceptions @handle_backend_invalid @handle_nestable diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 34203cffbaf6d..d4f701cf45b0d 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -1,5 +1,5 @@ # global -from typing import Union, Optional +from typing import Union, Optional, Callable, Literal # local import ivy @@ -14,9 +14,29 @@ inputs_to_ivy_arrays, handle_device_shifting, handle_backend_invalid, + handle_complex_input, ) +def _logit_jax_like( + x: Union[float, int, ivy.Array], + /, + *, + fn_original: Optional[Callable] = None, + eps: Optional[float] = None, + out: Optional[ivy.Array] = None, +): + real = ivy.real(x) + imag = ivy.imag(x) + if eps is None: + real = ivy.where(ivy.logical_or(real > 1, real < 0), ivy.nan, real) + else: + real = ivy.clip(real, eps, 1 - eps) + z = ivy.add(real, ivy.multiply(ivy.array(1j, dtype=x.dtype), imag)) + z = ivy.log(z / (1 - z)) + return z + + @handle_exceptions @handle_backend_invalid @handle_nestable @@ -24,11 +44,13 @@ @handle_out_argument @to_native_arrays_and_back @handle_device_shifting +@handle_complex_input def logit( x: Union[float, int, ivy.Array], /, *, eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, ) -> ivy.Array: """ @@ -44,6 +66,9 @@ def logit( When eps is None the function outpus NaN where x < 0 or x > 1. and inf or -inf where x = 1 or x = 0, respectively. Otherwise if eps is defined, x is clamped to [eps, 1 - eps] + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out Optional output array. @@ -67,6 +92,9 @@ def logit( return current_backend(x).logit(x, eps=eps, out=out) +logit.jax_like = _logit_jax_like + + @handle_exceptions @handle_nestable @handle_array_like_without_promotion @@ -239,8 +267,13 @@ def relu6( @handle_out_argument @to_native_arrays_and_back @handle_device_shifting +@handle_complex_input def logsigmoid( - input: Union[ivy.NativeArray, ivy.Array], /, *, out: Optional[ivy.Array] = None + input: Union[ivy.NativeArray, ivy.Array], + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply element-wise Log-sigmoid of x. @@ -251,6 +284,9 @@ def logsigmoid( ---------- input Input array. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -452,37 +488,3 @@ def elu( } """ return current_backend(x).elu(x, alpha=alpha, out=out) - - -def sequence_length( - x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None -) -> ivy.int64: - """ - Produce a scalar (tensor of empty shape) containing the number of tensors in the ivy - array input. - - Parameters - ---------- - x - Can be a sequence of any tensor type: bool, complex128, - complex64, double, float, float16, int16, int32, int64, - int8, string, uint16, uint32, uint64, uint8 - - Returns - ------- - length - Length of the input sequence, as a scalar (empty shape tensor). 
-
-    Examples
-    --------
-    >>> x = ivy.array([True, False, True])
-    >>> y = ivy.sequence_length(x)
-    >>> print(y)
-    3
-
-    >>> x = [1.0, 2.5, -3.4, 5.6, -85.3]
-    >>> y = ivy.sequence_length(x)
-    >>> print(y)
-    5
-    """
-    return current_backend(x).sequence_length(x, out=out)
diff --git a/ivy/functional/ivy/experimental/creation.py b/ivy/functional/ivy/experimental/creation.py
index e4527da1b6e58..fafcef98e44ff 100644
--- a/ivy/functional/ivy/experimental/creation.py
+++ b/ivy/functional/ivy/experimental/creation.py
@@ -970,3 +970,55 @@ def trilu(
         instances in place of any of the arguments.
     """
     return current_backend(x).trilu(x, k=k, upper=upper, out=out)
+
+
+@handle_exceptions
+@handle_nestable
+@to_native_arrays_and_back
+def mel_weight_matrix(
+    num_mel_bins: int,
+    dft_length: int,
+    sample_rate: int,
+    lower_edge_hertz: float = 0.0,
+    upper_edge_hertz: float = 3000.0,
+):
+    """
+    Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a
+    linearly sampled frequency spectrum (from DFT or STFT) into num_mel_bins frequency
+    information based on the [lower_edge_hertz, upper_edge_hertz] range on the mel
+    scale. This function defines the mel scale in terms of a frequency in hertz
+    according to the following formula: mel(f) = 2595 * log10(1 + f/700)
+
+    Parameters
+    ----------
+    num_mel_bins
+        The number of bands in the mel spectrum.
+    dft_length
+        The size of the original DFT obtained from (n_fft / 2 + 1).
+    sample_rate
+        Samples per second of the input signal.
+    lower_edge_hertz
+        Lower bound on the frequencies to be included in the mel spectrum.
+    upper_edge_hertz
+        The desired top edge of the highest frequency band.
+
+    Returns
+    -------
+    ret
+        MelWeightMatrix of shape: [frames, num_mel_bins].
+
+    Examples
+    --------
+    >>> ivy.mel_weight_matrix(3,3,8000)
+    ivy.array([[0.        , 0.        , 0.        ],
+               [0.        , 0.        , 0.75694758],
+               [0.        , 0.        , 0.        ]])
+    """
+    return ivy.current_backend().mel_weight_matrix(
+        num_mel_bins,
+        dft_length,
+        sample_rate,
+        lower_edge_hertz,
+        upper_edge_hertz,
+    )
diff --git a/ivy/functional/ivy/experimental/layers.py b/ivy/functional/ivy/experimental/layers.py
index 120ebe0279784..5003ef0472ddf 100644
--- a/ivy/functional/ivy/experimental/layers.py
+++ b/ivy/functional/ivy/experimental/layers.py
@@ -114,77 +114,6 @@ def max_pool1d(
     )


-@handle_backend_invalid
-@handle_nestable
-@handle_out_argument
-@to_native_arrays_and_back
-@handle_device_shifting
-def max_unpool1d(
-    x: ivy.Union[ivy.Array, ivy.NativeArray],
-    indices: Union[ivy.Array, ivy.NativeArray],
-    kernel: Union[int, Tuple[int]],
-    strides: Union[int, Tuple[int]],
-    padding: str,
-    /,
-    *,
-    data_format: str = "NWC",
-    out: Optional[ivy.Array] = None,
-) -> ivy.Array:
-    """
-    Compute a 1-D max unpooling given the 1-D pooled input x and its indices.
-
-    Parameters
-    ----------
-    x
-        Pooled input image *[batch_size, w, d_in]*.
-    indices
-        Indices obtained from the corresponding max pooling operation.
-    kernel
-        Size of the kernel i.e., the sliding window for each
-        dimension of input. *[w]*.
-    strides
-        The stride of the sliding window for each dimension of input.
-    padding
-        SAME" or "VALID" indicating the algorithm, or list
-        indicating the per-dimension paddings.
-    data_format
-        NWC" or "NCW". Defaults to "NWC".
-    out
-        optional output array, for writing the result to.
-
-    Returns
-    -------
-    ret
-        The result of the unpooling operation.
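A quick sanity check of the mel-scale formula quoted in the `mel_weight_matrix` docstring above, in plain Python (standard library only; my own verification, not part of the diff):

    import math

    def hz_to_mel(f_hz: float) -> float:
        # mel(f) = 2595 * log10(1 + f / 700), as stated in the docstring
        return 2595.0 * math.log10(1.0 + f_hz / 700.0)

    # With the defaults above, the mel filter bank spans these edge values;
    # num_mel_bins bands are spaced uniformly between them on the mel scale.
    print(hz_to_mel(0.0))     # 0.0
    print(hz_to_mel(3000.0))  # ~1876.5 mels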
-
-    Both the description and the type hints above assume an array input
-    for simplicity, but this function is *nestable*, and therefore
-    also accepts :class:`ivy.Container` instances in place of any of
-    the arguments.
-
-    Examples
-    --------
-    >>> x = ivy.arange(0, 24.).reshape((2, 3, 4))
-    >>> pool_result = ivy.max_pool1d(x, 2, 2, 'SAME')
-    >>> print(pool_result)
-    ivy.array([[[ 4.,  5.,  6.,  7.],
-            [ 8.,  9., 10., 11.]],
-
-           [[16., 17., 18., 19.],
-            [20., 21., 22., 23.]]])
-    >>> unpool_result = ivy.max_unpool1d(pool_result, indices, 2, 2, 'SAME')
-    >>> print(unpool_result)
-    ivy.array([[[ 0.,  4.,  0.,  5.,  0.,  6.,  0.,  7.,  0.,  0.,  0.,  0.],
-            [ 0.,  0.,  0.,  0.,  8.,  0.,  9.,  0., 10.,  0., 11.,  0.]],
-
-           [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 16.,  0., 17.,  0.],
-            [ 0., 18.,  0., 19.,  0.,  0.,  0.,  0., 20.,  0., 21.,  0.]]])
-    """
-    return ivy.current_backend(x).max_unpool1d(
-        x, indices, kernel, strides, padding, data_format=data_format, out=out
-    )
-
-
 @handle_backend_invalid
 @handle_nestable
 @handle_out_argument
 @to_native_arrays_and_back
 @handle_device_shifting
diff --git a/ivy/functional/ivy/experimental/linear_algebra.py b/ivy/functional/ivy/experimental/linear_algebra.py
index 405ad8bc46e3c..d4a511c30353a 100644
--- a/ivy/functional/ivy/experimental/linear_algebra.py
+++ b/ivy/functional/ivy/experimental/linear_algebra.py
@@ -1666,3 +1666,90 @@ def dot(
     ivy.array([[-15.28]])
     """
     return current_backend(a, b).dot(a, b, out=out)
+
+
+@handle_exceptions
+@handle_nestable
+@handle_array_like_without_promotion
+@inputs_to_ivy_arrays
+@handle_array_function
+@handle_device_shifting
+def general_inner_product(
+    a: Union[ivy.Array, ivy.NativeArray],
+    b: Union[ivy.Array, ivy.NativeArray],
+    n_modes: Optional[int] = None,
+    /,
+    *,
+    out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+    """
+    Generalised inner products between tensors.
+
+    Takes the inner product between the last (respectively first)
+    `n_modes` of `a` (respectively `b`).
+
+    Parameters
+    ----------
+    a
+        first input tensor.
+    b
+        second input tensor.
+    n_modes
+        int, default is None. If None, the traditional inner product is returned
+        (i.e. a float); otherwise, the product between the `n_modes` last modes of
+        `a` and the `n_modes` first modes of `b` is returned. The resulting tensor's
+        order is `len(a) - n_modes`.
+    out
+        Optional output array. If provided, the output array to store the result.
+
+    Returns
+    -------
+    The inner product of the input arrays.
+
+    Examples
+    --------
+    With :class:`ivy.Array` inputs:
+
+    >>> a = ivy.array([1, 2, 3])
+    >>> b = ivy.array([4, 5, 6])
+    >>> result = ivy.general_inner_product(a, b, n_modes=1)
+    >>> print(result)
+    ivy.array(32)
+
+    >>> a = ivy.array([1, 2])
+    >>> b = ivy.array([4, 5])
+    >>> result = ivy.general_inner_product(a, b)
+    >>> print(result)
+    ivy.array(14)
+
+    >>> a = ivy.array([[1, 1], [1, 1]])
+    >>> b = ivy.array([[1, 2, 3, 4],[1, 1, 1, 1]])
+    >>> result = ivy.general_inner_product(a, b, n_modes=1)
+    >>> print(result)
+    ivy.array([[2, 3, 4, 5],
+           [2, 3, 4, 5]])
+    """
+    shape_a = a.shape
+    shape_b = b.shape
+    if n_modes is None:
+        if shape_a != shape_b:
+            raise ValueError(
+                "Taking a generalised product between two tensors without specifying"
+                " common modes is equivalent to taking the inner product. This"
+                f" requires a.shape == b.shape. However, got shapes {a.shape} and"
+                f" {b.shape}"
+            )
+        return ivy.sum(ivy.multiply(a, b), out=out)
+
+    common_modes = shape_a[len(shape_a) - n_modes :]
+    if common_modes != shape_b[:n_modes]:
+        raise ValueError(
+            f"Incorrect shapes for inner product along {n_modes} common modes."
+ f"Shapes {shape_a.shape} and {shape_b.shape}" + ) + + common_size = int(ivy.prod(common_modes)) if len(common_modes) != 0 else 0 + output_shape = shape_a[:-n_modes] + shape_b[n_modes:] + inner_product = ivy.dot( + ivy.reshape(a, (-1, common_size)), ivy.reshape(b, (common_size, -1)) + ) + return ivy.reshape(inner_product, output_shape, out=out) diff --git a/ivy/functional/ivy/experimental/losses.py b/ivy/functional/ivy/experimental/losses.py index 6070101b45ae4..f2950d682bb0d 100644 --- a/ivy/functional/ivy/experimental/losses.py +++ b/ivy/functional/ivy/experimental/losses.py @@ -66,7 +66,7 @@ def log_poisson_loss( -------- >>> x = ivy.array([0, 0, 1, 0]) >>> y = ivy.array([0.25, 0.25, 0.25, 0.25]) - >>> print(ivy.log_poisson_loss(x, z)) + >>> print(ivy.log_poisson_loss(x, y)) ivy.array([1.28402555, 1.28402555, 1.03402555, 1.28402555]) >>> z = ivy.array([0.1, 0.1, 0.7, 0.1]) @@ -408,3 +408,72 @@ def soft_margin_loss( return ivy.mean(loss, out=out) else: return ivy.inplace_update(out, loss) if out is not None else loss + + +@handle_exceptions +@handle_nestable +@inputs_to_ivy_arrays +@handle_array_function +def kl_div( + input: Union[ivy.Array, ivy.NativeArray], + target: Union[ivy.Array, ivy.NativeArray], + /, + *, + reduction: Optional[str] = "mean", + out: Optional[ivy.Array] = None, +) -> ivy.Array: + """ + Compute the Kullback-Leibler divergence loss between two input tensors + (conventionally, probability distributions). + + Parameters + ---------- + input : array_like + Input probability distribution (first tensor). + target : array_like + Target probability distribution (second tensor). + reduction : {'mean', 'sum', 'batchmean', 'none'}, optional + Type of reduction to apply to the output. Default is 'mean'. + out : array_like, optional + Optional output array, for writing the result to. + It must have a shape that the inputs broadcast to. + + Returns + ------- + ret : array + The Kullback-Leibler divergence loss between the two input tensors. + + Examples + -------- + >>> input = ivy.array([0.2, 0.8], [0.5, 0.5]) + >>> target = ivy.array([0.6, 0.4], [0.3, 0.7]) + >>> ivy.kl_div(input, target) + ivy.array(0.0916) + + >>> input = ivy.array([0.2, 0.8], [0.5, 0.5]) + >>> target = ivy.array([0.6, 0.4], [0.3, 0.7]) + >>> ivy.kl_div(input, target, reduction='sum') + ivy.array(0.1832) + + >>> input = ivy.array([0.2, 0.8], [0.5, 0.5]) + >>> target = ivy.array([0.6, 0.4], [0.3, 0.7]) + >>> ivy.kl_div(input, target, reduction='batchmean') + ivy.array(0.0916) + + >>> input = ivy.array([0.2, 0.8], [0.5, 0.5]) + >>> target = ivy.array([0.6, 0.4], [0.3, 0.7]) + >>> ivy.kl_div(input, target, reduction='none') + ivy.array([0.0378], [0.1453]) + """ + size = ivy.shape(input) + + loss = ivy.sum(input * ivy.log(input / target), axis=-1) + + if reduction == "sum": + loss = ivy.sum(loss, out=out) + elif reduction == "mean": + loss = ivy.mean(loss, out=out) + elif reduction == "batchmean": + loss = ivy.sum(loss, out=out) / size[0] + + return ivy.inplace_update(out, loss) if out is not None else loss diff --git a/ivy/functional/ivy/experimental/manipulation.py b/ivy/functional/ivy/experimental/manipulation.py index f8d4e90f73388..438b90ad4c195 100644 --- a/ivy/functional/ivy/experimental/manipulation.py +++ b/ivy/functional/ivy/experimental/manipulation.py @@ -496,8 +496,8 @@ def rot90( m Input array of two or more dimensions. copy - boolean indicating whether or not to copy the input array. - If True, the function must always copy. + boolean indicating whether or not to copy the input array. 
+ If True, the function must always copy. If False, the function must never copy. In case copy is False we avoid copying by returning a view of the input array. k @@ -2632,3 +2632,70 @@ def choose( ivy.array([20, 1, 12, 3]) """ return ivy.current_backend(arr).choose(arr, choices, out=out, mode=mode) + + +@handle_array_function +@inputs_to_ivy_arrays +@handle_nestable +@handle_exceptions +@handle_device_shifting +def column_stack( + arrays: Sequence[Union[ivy.Array, ivy.NativeArray]], + /, + *, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + """ + Create a new array by horizontally stacking the arrays in arrays. + + Equivalent to `ivy.hstack(arrays)`, except each zero or one dimensional + array `x` in arrays is first reshaped into a `(x.size(), 1)` column + before being stacked horizontally. + + Parameters + ---------- + arrays + Arrays to be stacked. + out + Output array. + + Returns + ------- + ret + Stacked input. + + Examples + -------- + Arrays of different dtypes up to dimension 2. + >>> a0 = ivy.array(True) + >>> a1 = ivy.array([7]) + >>> a2 = ivy.array([[11.3, 13.7]]) + >>> ivy.column_stack((a0, a1, a2)) + ivy.array([[ 1. , 7. , 11.30000019, 13.69999981]]) + + Arrays of dimension 3. + >>> a = ivy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + >>> b = ivy.array([[[11, 12]], [[13, 14]]]) + >>> ivy.column_stack((a, b)) + ivy.array([[[ 1, 2], + [ 3, 4], + [11, 12]], + + [[ 5, 6], + [ 7, 8], + [13, 14]]]) + """ + arrays = [ivy.reshape(x, shape=(-1, 1)) if x.ndim < 2 else x for x in arrays] + + return ivy.hstack(arrays, out=out) + + +column_stack.mixed_backend_wrappers = { + "to_add": ( + "handle_backend_invalid", + "inputs_to_native_arrays", + "outputs_to_ivy_arrays", + "handle_out_argument", + ), + "to_skip": ("inputs_to_ivy_arrays",), +} diff --git a/ivy/functional/ivy/experimental/sparse_array.py b/ivy/functional/ivy/experimental/sparse_array.py index 5c67cc823e7c1..49a13248921c8 100644 --- a/ivy/functional/ivy/experimental/sparse_array.py +++ b/ivy/functional/ivy/experimental/sparse_array.py @@ -316,7 +316,7 @@ def _is_valid_format( ): valid_formats = ["coo", "csr", "csc", "csc", "bsc", "bsr"] - if not isinstance(format, str) or not format.lower() in valid_formats: + if not isinstance(format, str) or format.lower() not in valid_formats: return False if format.endswith("o"): diff --git a/ivy/functional/ivy/experimental/statistical.py b/ivy/functional/ivy/experimental/statistical.py index d76cfd1a334e4..3c998de8656e8 100644 --- a/ivy/functional/ivy/experimental/statistical.py +++ b/ivy/functional/ivy/experimental/statistical.py @@ -258,6 +258,74 @@ def nanmean( @handle_nestable @handle_out_argument @to_native_arrays_and_back +@infer_dtype +@handle_device_shifting +def nanprod( + a: ivy.Array, + /, + *, + axis: Optional[Union[Tuple[int], int]] = None, + keepdims: Optional[bool] = False, + dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None, + out: Optional[ivy.Array] = None, + initial: Optional[Union[int, float, complex]] = None, + where: Optional[ivy.Array] = None, +) -> ivy.Array: + """ + Compute the product of array elements over a given axis treating Not a Numbers + (NaNs) as ones. + + Parameters + ---------- + a + Input array. + axis + Axis or axes along which the product is computed. + The default is to compute the product of the flattened array. + dtype + The desired data type of returned array. Default is None. + out + optional output array, for writing the result to. 
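The `column_stack` addition above mirrors NumPy's rule: 0-D and 1-D inputs are reshaped into `(size, 1)` columns before an `hstack`, while inputs of two or more dimensions pass through unchanged. The NumPy analogue of the docstring's first example, for comparison:

    import numpy as np

    a0 = np.array(True)            # 0-D -> treated as a (1, 1) column
    a1 = np.array([7])             # 1-D -> reshaped to a (1, 1) column
    a2 = np.array([[11.3, 13.7]])  # 2-D -> stacked as-is

    print(np.column_stack((a0, a1, a2)))
    # [[ 1.   7.  11.3 13.7]]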
+ keepdims + If this is set to True, the axes which are reduced are left in the result + as dimensions with size one. With this option, the result will broadcast + correctly against the original a. + initial + The starting value for this product. + where + Elements to include in the product + + Returns + ------- + ret + The product of array elements over a given axis treating + Not a Numbers (NaNs) as ones + + Functional Examples + ------------------- + >>> a = ivy.array([[1, ivy.nan], [3, 4]]) + >>> ivy.nanprod(a) + 12.0 + >>> ivy.nanprod(a, axis=0) + [3. 4.] + >>> ivy.nanprod(a, axis=0, keepdims=True) + [[3. 4.]] + """ + return ivy.current_backend(a).nanprod( + a, + axis=axis, + keepdims=keepdims, + dtype=dtype, + out=out, + initial=initial, + where=where, + ) + + +@handle_exceptions +@handle_nestable +@handle_out_argument +@to_native_arrays_and_back @handle_device_shifting def quantile( a: ivy.Array, diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index f80a3fae31133..5b63f9e020d66 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -73,7 +73,7 @@ class PreciseMode: """Precise Mode Context Manager.""" # noinspection PyShadowingNames - def __init__(self, precise_mode): + def __init__(self, precise_mode: bool): self._precise_mode = precise_mode def __enter__(self): @@ -180,23 +180,49 @@ def __exit__(self, exc_type, exc_val, exc_tb): def get_referrers_recursive( - item, depth=0, max_depth=None, seen_set=None, local_set=None -): + item: object, + depth: int = 0, + max_depth: int = None, + seen_set: set = None, + local_set: set = None, +) -> ivy.Container: """ - Summary. + Recursively retrieve referrers for an object. + + This function recursively fetches referrers for the specified `item` up to a given `max_depth`. Parameters ---------- - item + item : object + The object for which referrers should be retrieved. + depth : int, optional + Current depth in the recursion. (default is 0) + max_depth : int, optional + Maximum depth of recursion. If `None`, there's no depth limit. (default is None) + seen_set : set, optional + Set of seen referrer IDs to prevent duplicates. (default is None) + local_set : set, optional + Set of local referrer IDs to avoid redundancy. (default is None) + + Returns + ------- + ivy.Container + A container representing referrers and their sub-referrers, respecting the `max_depth`. - depth - (Default value = 0) - max_depth - (Default value = None) - seen_set - (Default value = None) - local_set - (Default value = None`) + Examples + -------- + >>> import gc + >>> def example_function(): + ... obj = [1, 2, 3] + ... 
return get_referrers_recursive(obj, max_depth=2) + >>> result = example_function() + >>> print(result) + Container( + 'ref_id_1': Container( + 'ref_id_2': 'tracked', + 'ref_id_3': 'tracked' + ) + ) """ seen_set = ivy.default(seen_set, set()) local_set = ivy.default(local_set, set()) @@ -205,6 +231,7 @@ def get_referrers_recursive( alphabetical_keys=False, keyword_color_dict={"repr": "magenta"}, ) + referrers = [ ref for ref in gc.get_referrers(item) @@ -213,6 +240,7 @@ def get_referrers_recursive( and min([k in ref for k in ["depth", "max_depth", "seen_set", "local_set"]]) ) ] + local_set.add(str(id(referrers))) for ref in referrers: ref_id = str(id(ref)) @@ -220,22 +248,28 @@ def get_referrers_recursive( continue seen = ref_id in seen_set seen_set.add(ref_id) - refs_rec = lambda: get_referrers_recursive( - ref, depth + 1, max_depth, seen_set, local_set - ) + + def get_referrers_recursive_inner(): + return get_referrers_recursive( + ref, depth + 1, max_depth, seen_set, local_set + ) + this_repr = "tracked" if seen else str(ref).replace(" ", "") + if not seen and (not max_depth or depth < max_depth): val = ivy.Container( repr=this_repr, alphabetical_keys=False, keyword_color_dict={"repr": "magenta"}, ) - refs = refs_rec() + + refs = get_referrers_recursive_inner() for k, v in refs.items(): val[k] = v else: val = this_repr ret_cont[str(ref_id)] = val + return ret_cont @@ -1351,7 +1385,7 @@ def has_nans( @handle_exceptions -def exists(x: Any) -> bool: +def exists(x: Any, /) -> bool: """ Check as to whether the input is None or not. @@ -2940,7 +2974,7 @@ def _broadcast_to(input, target_shape): @handle_nestable @inputs_to_ivy_arrays @handle_array_function -# @handle_device_shifting +@handle_device_shifting def inplace_update( x: Union[ivy.Array, ivy.NativeArray], val: Union[ivy.Array, ivy.NativeArray], diff --git a/ivy/functional/ivy/layers.py b/ivy/functional/ivy/layers.py index c8096d4d1fe26..a378c626db5da 100644 --- a/ivy/functional/ivy/layers.py +++ b/ivy/functional/ivy/layers.py @@ -714,7 +714,7 @@ def multi_head_attention( value: Optional[Union[ivy.Array, ivy.NativeArray]] = None, /, *, - num_heads: Optional[int] = 8, + num_heads: int = 8, scale: Optional[float] = None, attention_mask: Optional[Union[ivy.Array, ivy.NativeArray]] = None, in_proj_weights: Optional[Union[ivy.Array, ivy.NativeArray]] = None, @@ -724,12 +724,12 @@ def multi_head_attention( out_proj_weights: Optional[Union[ivy.Array, ivy.NativeArray]] = None, in_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None, out_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None, - is_causal: Optional[bool] = False, - return_attention_weights: Optional[bool] = False, - average_attention_weights: Optional[bool] = True, - dropout: Optional[float] = 0.0, - training: Optional[bool] = False, - out: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + is_causal: bool = False, + return_attention_weights: bool = False, + average_attention_weights: bool = True, + dropout: float = 0.0, + training: bool = False, + out: Optional[ivy.Array] = None, ) -> Union[ivy.Array, ivy.NativeArray]: """ Apply multi-head attention to inputs x. 
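The `multi_head_attention` signature cleanup above follows a standard typing convention: `Optional[T]` should be reserved for parameters that genuinely accept `None`, which these boolean and float flags never did. A minimal illustration of the distinction, using a hypothetical stub rather than Ivy's actual function:

    from typing import Optional

    # `scale` may legitimately be None (derived from the head dim when absent),
    # so Optional[float] is right; `is_causal` is always a real bool, so a bare
    # `bool` annotation with a default documents the contract correctly.
    def attention_stub(scale: Optional[float] = None, is_causal: bool = False) -> None:
        effective_scale = 1.0 if scale is None else scale
        print(effective_scale, is_causal)

    attention_stub()             # 1.0 False
    attention_stub(0.125, True)  # 0.125 True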
This is an implementation of multi-headed
diff --git a/ivy/functional/ivy/linear_algebra.py b/ivy/functional/ivy/linear_algebra.py
index 3708af0b00732..b7c3f450047ea 100644
--- a/ivy/functional/ivy/linear_algebra.py
+++ b/ivy/functional/ivy/linear_algebra.py
@@ -1397,11 +1397,11 @@ def matrix_rank(
         where ``eps`` must be the machine epsilon associated with the floating-point
         data type determined by :ref:`type-promotion` (as applied to ``x``).
         Default: ``None``.
-
+
     hermitian
         indicates whether ``x`` is Hermitian. When ``hermitian=True``, ``x`` is
         assumed to be Hermitian, enabling a more efficient method for finding
-        eigenvalues, but x is not checked inside the function.
+        eigenvalues, but x is not checked inside the function.
         Instead, we just use the lower triangular portion of the matrix in the
         computation.
         Default: ``False``.
     out
diff --git a/ivy/functional/ivy/nest.py b/ivy/functional/ivy/nest.py
index 5965c48416e08..c0826e5e8d653 100644
--- a/ivy/functional/ivy/nest.py
+++ b/ivy/functional/ivy/nest.py
@@ -1036,7 +1036,7 @@ def nested_map(
     x: Union[ivy.Array, ivy.NativeArray, Iterable],
     /,
     fn: Callable,
-    include_derived: Optional[Union[Dict[type, bool], bool]] = None,
+    include_derived: Optional[Union[Dict[str, bool], bool]] = None,
     to_ignore: Optional[Union[type, Tuple[type]]] = None,
     to_mutable: bool = False,
     max_depth: Optional[int] = None,
@@ -1160,12 +1160,13 @@ def nested_map(
     to_ignore = ivy.default(to_ignore, ())
     extra_nest_types = ivy.default(extra_nest_types, ())
     if include_derived is True:
-        include_derived = {tuple: True, list: True, dict: True}
+        include_derived = {"tuple": True, "list": True, "dict": True}
     elif not include_derived:
         include_derived = {}
-    for t in (tuple, list, dict):
+    for t in ("tuple", "list", "dict"):
         if t not in include_derived:
             include_derived[t] = False
+    # to ensure all keys are strings
     if ivy.exists(max_depth) and _depth > max_depth:
         return x
     class_instance = type(x)
@@ -1182,7 +1183,7 @@ def nested_map(
         _tuple_check_fn,
         (
             (lambda x_, t_: isinstance(x_, t_))
-            if include_derived[tuple]
+            if include_derived["tuple"]
             else (lambda x_, t_: type(x_) is t_)
         ),
     )
@@ -1190,7 +1191,7 @@ def nested_map(
         _list_check_fn,
         (
             (lambda x_, t_: isinstance(x_, t_))
-            if include_derived[list]
+            if include_derived["list"]
             else (lambda x_, t_: type(x_) is t_)
         ),
     )
@@ -1198,11 +1199,10 @@ def nested_map(
         _dict_check_fn,
         (
             (lambda x_, t_: isinstance(x_, t_))
-            if include_derived[dict]
+            if include_derived["dict"]
             else (lambda x_, t_: type(x_) is t_)
         ),
     )
-
     if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore):
         ret_list = [
             nested_map(
diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py
index b9977f6f17b1f..7b9c67035e117 100644
--- a/ivy/stateful/activations.py
+++ b/ivy/stateful/activations.py
@@ -170,10 +170,25 @@ def _forward(self, x):

 class Softmax(Module):
-    def __init__(self, axis: int = -1):
-        """Apply the SOFTMAX activation function."""
+    def __init__(
+        self,
+        axis: int = -1,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    ):
+        """
+        Apply the SOFTMAX activation function.
+
+        Parameters
+        ----------
+        axis
+            The axis along which to apply the softmax.
+        complex_mode
+            Specifies how to handle complex input. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
+ """ Module.__init__(self) self._axis = axis + self._complex_mode = complex_mode def _forward(self, x): """ @@ -191,7 +206,7 @@ def _forward(self, x): The outputs following the SOFTMAX activation *[batch_shape, d]* """ - return ivy.softmax(x, axis=self._axis) + return ivy.softmax(x, axis=self._axis, complex_mode=self._complex_mode) class Softplus(Module): @@ -359,10 +374,25 @@ def _forward(self, x): class Logit(Module): - def __init__(self, eps=None): - """Apply the LOGIT activation function.""" + def __init__( + self, + eps=None, + complex_mode="jax", + ): + """ + Apply the LOGIT activation function. + + Parameters + ---------- + eps + The epsilon value for the logit formation. Default: ``None``. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. + """ Module.__init__(self) self._eps = eps + self._complex_mode = complex_mode def _forward(self, x): """ @@ -371,15 +401,17 @@ def _forward(self, x): ---------- x Inputs to process *[batch_shape, d]*. - eps - The epsilon value for the logit formation. Default: ``None``. Returns ------- ret The outputs following the LOGIT activation *[batch_shape, d]* """ - return ivy.logit(x, eps=self._eps) + return ivy.logit( + x, + eps=self._eps, + complex_mode=self._complex_mode, + ) class PReLU(Module): @@ -450,8 +482,17 @@ def _forward(self, x): class LogSigmoid(Module): - def __init__(self): - """Apply the LogSigmoid activation function.""" + def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): + """ + Apply the LogSigmoid activation function. + + Parameter + ---------- + complex_mode + Specifies how to handle complex input. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. + """ + self._complex_mode = complex_mode Module.__init__(self) def _forward(self, x): @@ -467,4 +508,4 @@ def _forward(self, x): ret The outputs following the LogSigmoid activation *[batch_shape, d]* """ - return ivy.logsigmoid(x) + return ivy.logsigmoid(x, complex_mode=self._complex_mode) diff --git a/ivy/utils/_importlib.py b/ivy/utils/_importlib.py index f6cb3ad3a25c3..d65e7caecc270 100644 --- a/ivy/utils/_importlib.py +++ b/ivy/utils/_importlib.py @@ -11,7 +11,7 @@ # If they do, the behavior of ivy.with_backend is undefined and may not function as # expected. Import these modules along with Ivy initialization, as the import logic # assumes they exist in sys.modules. 
-MODULES_TO_SKIP = ["ivy.compiler"] +MODULES_TO_SKIP = ["ivy.compiler", "ivy.engines"] IS_COMPILING_WITH_BACKEND = False diff --git a/ivy/utils/backend/handler.py b/ivy/utils/backend/handler.py index 1442205401e99..9693cb035fb17 100644 --- a/ivy/utils/backend/handler.py +++ b/ivy/utils/backend/handler.py @@ -626,8 +626,6 @@ def with_backend(backend: str, cached: bool = True): # Use already compiled object if cached and backend in compiled_backends.keys(): cached_backend = compiled_backends[backend][-1] - if not cached_backend.native_inplace_support: - _handle_inplace_mode() return cached_backend with _importlib.LocalIvyImporter(): ivy_pack = _importlib._import_module("ivy") @@ -657,5 +655,7 @@ def with_backend(backend: str, cached: bool = True): compiled_backends[backend].append(ivy_pack) except KeyError: compiled_backends[backend] = [ivy_pack] - _handle_inplace_mode() + if ivy.backend != backend: + # to avoid warning users when not using set_backend with ivy.Array.__repr__ + _handle_inplace_mode(ivy_pack=ivy_pack) return ivy_pack diff --git a/ivy/utils/exceptions.py b/ivy/utils/exceptions.py index 2d77f18ae6de9..886c41400d36d 100644 --- a/ivy/utils/exceptions.py +++ b/ivy/utils/exceptions.py @@ -391,14 +391,14 @@ def _handle_exceptions(*args, **kwargs): except InvalidBackendException as e: _configure_stack_trace(e.__traceback__) raise e - except (Exception, IvyBackendException) as e: + except InplaceUpdateException as e: _configure_stack_trace(e.__traceback__) - raise ivy.utils.exceptions.IvyBackendException( + raise ivy.utils.exceptions.InplaceUpdateException( fn.__name__, str(e), include_backend=True ) - except InplaceUpdateException as e: + except (Exception, IvyBackendException) as e: _configure_stack_trace(e.__traceback__) - raise ivy.utils.exceptions.InplaceUpdateException( + raise ivy.utils.exceptions.IvyBackendException( fn.__name__, str(e), include_backend=True ) @@ -409,9 +409,11 @@ def _handle_exceptions(*args, **kwargs): # Inplace Update -def _handle_inplace_mode(): - current_backend = ivy.current_backend_str() - if not ivy.native_inplace_support and ivy.inplace_mode == "lenient": +def _handle_inplace_mode(ivy_pack=None): + if not ivy_pack: + ivy_pack = ivy + current_backend = ivy_pack.current_backend_str() + if not ivy_pack.native_inplace_support and ivy_pack.inplace_mode == "lenient": warnings.warn( f"The current backend: '{current_backend}' does not support " "inplace updates natively. 
Ivy would quietly create new arrays when " diff --git a/ivy_tests/array_api_testing/write_array_api_tests_k_flag.py b/ivy_tests/array_api_testing/write_array_api_tests_k_flag.py index 61b678fda7f08..a41b8243ce608 100644 --- a/ivy_tests/array_api_testing/write_array_api_tests_k_flag.py +++ b/ivy_tests/array_api_testing/write_array_api_tests_k_flag.py @@ -40,7 +40,7 @@ continue if ("#" not in s) or ( "#" in s - and not (framework in s.lower()) + and (framework not in s.lower()) and any(f in s.lower() for f in framework_tests_to_run) ): tests_to_run += ( diff --git a/ivy_tests/test_ivy/helpers/function_testing.py b/ivy_tests/test_ivy/helpers/function_testing.py index 19a3044fcb961..072aa0449e6be 100644 --- a/ivy_tests/test_ivy/helpers/function_testing.py +++ b/ivy_tests/test_ivy/helpers/function_testing.py @@ -915,7 +915,7 @@ def test_frontend_function( if not test_values: ret = ivy_backend.nested_map( - ret, _frontend_array_to_ivy, include_derived={tuple: True} + ret, _frontend_array_to_ivy, include_derived={"tuple": True} ) def arrays_to_numpy(x): @@ -965,7 +965,7 @@ def arrays_to_numpy(x): # change ivy device to native devices if "device" in kwargs_frontend: - kwargs_frontend["device"] = frontend_config.as_native_dev( + kwargs_frontend["device"] = frontend_config.as_native_device( kwargs_frontend["device"] ) @@ -2092,7 +2092,7 @@ def test_frontend_method( frontend_config.as_native_dtype(x) if isinstance(x, frontend_config.Dtype) else ( - frontend_config.as_native_dev(x) + frontend_config.as_native_device(x) if isinstance(x, frontend_config.Device) else x ) @@ -2114,7 +2114,7 @@ def test_frontend_method( # change ivy device to native devices if "device" in kwargs_method_frontend: - kwargs_method_frontend["device"] = frontend_config.as_native_dev( + kwargs_method_frontend["device"] = frontend_config.as_native_device( kwargs_method_frontend["device"] ) frontend_creation_fn = getattr( @@ -2340,7 +2340,8 @@ def flatten_and_to_np(*, backend: str, ret): # flatten the return ret_flat = flatten(backend=backend, ret=ret) with BackendHandler.update_backend(backend) as ivy_backend: - return [ivy_backend.to_numpy(x) for x in ret_flat] + ret = [ivy_backend.to_numpy(x) for x in ret_flat] + return ret def flatten_frontend_to_np(*, backend: str, ret, frontend_array_fn=None): @@ -2376,7 +2377,7 @@ def map_fn(x): return ivy_backend.to_ivy(x) return x - ret = ivy_backend.nested_map(ret, map_fn, include_derived={tuple: True}) + ret = ivy_backend.nested_map(ret, map_fn, include_derived={"tuple": True}) return ret, flatten_and_to_np(backend=backend_to_test, ret=ret) @@ -2396,24 +2397,24 @@ def get_frontend_ret( with BackendHandler.update_backend(backend) as ivy_backend: if not as_ivy_arrays and test_compile: args, kwargs = ivy_backend.nested_map( - (args, kwargs), _frontend_array_to_ivy, include_derived={tuple: True} + (args, kwargs), _frontend_array_to_ivy, include_derived={"tuple": True} ) with ivy_backend.PreciseMode(precision_mode): ret = frontend_fn(*args, **kwargs) if test_compile and frontend_array_function is not None: if as_ivy_arrays: ret = ivy_backend.nested_map( - ret, ivy_backend.asarray, include_derived={tuple: True} + ret, ivy_backend.asarray, include_derived={"tuple": True} ) else: ret = ivy_backend.nested_map( ret, arrays_to_frontend(backend, frontend_array_function), - include_derived={tuple: True}, + include_derived={"tuple": True}, ) elif as_ivy_arrays: ret = ivy_backend.nested_map( - ret, _frontend_array_to_ivy, include_derived={tuple: True} + ret, _frontend_array_to_ivy, 
include_derived={"tuple": True} ) return ret diff --git a/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py index ce163e7d8a072..ecab2b63b01c0 100644 --- a/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py +++ b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py @@ -2162,3 +2162,78 @@ def einsum_helper(draw): eq = "".join(eq_1) + "," + "".join(eq_2) + "->" + output_eq return eq, (value_1[0], value_2[0]), [dtype_1[0], dtype_2[0]] + + +@st.composite +def create_concatenable_arrays_dtypes( + draw, + min_num_dims, + max_num_dims, + min_num_arrays, + max_num_arrays, + concat_dim, + dtypes, + common_shape=None, +): + """ + Draws a random number of arrays with concatenable or stackable dimensions. Arrays + have same number of dimensions, but their shape can differ along a specified + dimension (concat_dim). If concat_dim is None, arrays have the same shape. Dtypes of + arrays can differ. + + Parameters + ---------- + min_num_dims + minimum number of dimensions + max_num_dims + maximum number of dimensions + min_num_arrays + minimum number of arrays + max_num_arrays + maximum number of arrays + concat_dim + dimension along which the shape of arrays can differ, + if None all the arrays will have the same shape + dtypes + list of dtypes from which array dtypes will be draws, + each array can have different dtype + given_common_shape + if not None, specifies the shape of the arrays + (dimension concat_dim can still be modified) + """ + num_arrays = draw(helpers.ints(min_value=min_num_arrays, max_value=max_num_arrays)) + if common_shape is None: + num_dims = draw(helpers.ints(min_value=min_num_dims, max_value=max_num_dims)) + common_shape = draw( + helpers.list_of_size( + x=helpers.ints(min_value=1, max_value=5), + size=num_dims, + ) + ) + else: + num_dims = len(common_shape) + input_dtypes = draw( + helpers.array_dtypes(num_arrays=num_arrays, available_dtypes=dtypes) + ) + array_shapes = [common_shape.copy() for i in range(num_arrays)] + if num_dims > 0 and concat_dim is not None: + unique_dims = draw( + helpers.list_of_size( + x=helpers.ints(min_value=1, max_value=5), + size=num_arrays, + ) + ) + for i in range(num_arrays): + array_shapes[i][concat_dim] = unique_dims[i] + + xs = list() + + for sh, dt in zip(array_shapes, input_dtypes): + x = draw( + helpers.array_values( + shape=sh, + dtype=dt, + ) + ) + xs.append(x) + return xs, input_dtypes diff --git a/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py b/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py index fe2d37e739c75..e725ddb7e3fe7 100644 --- a/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py +++ b/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py @@ -432,4 +432,9 @@ def cast_filter_helper(d, dtype, x, current_backend): max_x <= max_val and min_x >= min_val and bound_dtype_bits(d) >= bound_dtype_bits(dtype) + and ( + ivy_backend.is_complex_dtype(d) + or not ivy_backend.is_complex_dtype(dtype) + ) + and (min_x > 0 or not ivy_backend.is_uint_dtype(dtype)) ) diff --git a/ivy_tests/test_ivy/helpers/test_parameter_flags.py b/ivy_tests/test_ivy/helpers/test_parameter_flags.py index 2ab790f34b83f..32dfbdb8aacf5 100644 --- a/ivy_tests/test_ivy/helpers/test_parameter_flags.py +++ b/ivy_tests/test_ivy/helpers/test_parameter_flags.py @@ -3,6 +3,14 @@ from . 
import globals as test_globals from .pipeline_helper import BackendHandler +from dataclasses import dataclass +from hypothesis.strategies import SearchStrategy + + +@dataclass +class DynamicFlag: + strategy: SearchStrategy + @st.composite def _gradient_strategy(draw): @@ -27,17 +35,17 @@ def _as_varaible_strategy(draw): return draw(st.lists(st.booleans(), min_size=1, max_size=1)) -BuiltNativeArrayStrategy = st.lists(st.booleans(), min_size=1, max_size=1) -BuiltAsVariableStrategy = _as_varaible_strategy() -BuiltContainerStrategy = st.lists(st.booleans(), min_size=1, max_size=1) -BuiltInstanceStrategy = st.booleans() -BuiltInplaceStrategy = st.just(False) -BuiltGradientStrategy = _gradient_strategy() -BuiltWithOutStrategy = st.booleans() -BuiltCompileStrategy = st.just(False) -BuiltFrontendArrayStrategy = st.booleans() -BuiltTranspileStrategy = st.just(False) -BuiltPrecisionModeStrategy = st.booleans() +BuiltNativeArrayStrategy = DynamicFlag(st.lists(st.booleans(), min_size=1, max_size=1)) +BuiltAsVariableStrategy = DynamicFlag(_as_varaible_strategy()) +BuiltContainerStrategy = DynamicFlag(st.lists(st.booleans(), min_size=1, max_size=1)) +BuiltInstanceStrategy = DynamicFlag(st.booleans()) +BuiltInplaceStrategy = DynamicFlag(st.just(False)) +BuiltGradientStrategy = DynamicFlag(_gradient_strategy()) +BuiltWithOutStrategy = DynamicFlag(st.booleans()) +BuiltCompileStrategy = DynamicFlag(st.just(False)) +BuiltFrontendArrayStrategy = DynamicFlag(st.booleans()) +BuiltTranspileStrategy = DynamicFlag(st.just(False)) +BuiltPrecisionModeStrategy = DynamicFlag(st.booleans()) flags_mapping = { @@ -61,7 +69,7 @@ def build_flag(key: str, value: bool): assert ( flags_mapping[key] in globals().keys() ), f"{flags_mapping[key]} is not a valid flag variable." - globals()[flags_mapping[key]] = value + globals()[flags_mapping[key]].strategy = value # Strategy Helpers # diff --git a/ivy_tests/test_ivy/helpers/testing_helpers.py b/ivy_tests/test_ivy/helpers/testing_helpers.py index 43de1243a3e8f..0bbd9d3ab21e0 100644 --- a/ivy_tests/test_ivy/helpers/testing_helpers.py +++ b/ivy_tests/test_ivy/helpers/testing_helpers.py @@ -17,6 +17,7 @@ from . 
import test_globals as t_globals from .pipeline_helper import BackendHandler from ivy_tests.test_ivy.helpers.test_parameter_flags import ( + DynamicFlag, BuiltInstanceStrategy, BuiltAsVariableStrategy, BuiltNativeArrayStrategy, @@ -51,6 +52,10 @@ ) +def _get_runtime_flag_value(flag): + return flag.strategy if isinstance(flag, DynamicFlag) else flag + + @st.composite def num_positional_args_method(draw, *, method): """ @@ -395,14 +400,14 @@ def handle_test( possible_arguments["test_flags"] = pf.function_flags( ground_truth_backend=st.just(ground_truth_backend), num_positional_args=number_positional_args, - instance_method=test_instance_method, - with_out=test_with_out, - test_gradients=test_gradients, - test_compile=test_compile, - as_variable=as_variable_flags, - native_arrays=native_array_flags, - container_flags=container_flags, - precision_mode=precision_mode, + instance_method=_get_runtime_flag_value(test_instance_method), + with_out=_get_runtime_flag_value(test_with_out), + test_gradients=_get_runtime_flag_value(test_gradients), + test_compile=_get_runtime_flag_value(test_compile), + as_variable=_get_runtime_flag_value(as_variable_flags), + native_arrays=_get_runtime_flag_value(native_array_flags), + container_flags=_get_runtime_flag_value(container_flags), + precision_mode=_get_runtime_flag_value(precision_mode), ) def test_wrapper(test_fn): @@ -526,14 +531,14 @@ def handle_frontend_test( # Generate the test flags strategy test_flags = pf.frontend_function_flags( num_positional_args=number_positional_args, - with_out=test_with_out, - inplace=test_inplace, - as_variable=as_variable_flags, - native_arrays=native_array_flags, - test_compile=test_compile, - generate_frontend_arrays=generate_frontend_arrays, - transpile=transpile, - precision_mode=precision_mode, + with_out=_get_runtime_flag_value(test_with_out), + inplace=_get_runtime_flag_value(test_inplace), + as_variable=_get_runtime_flag_value(as_variable_flags), + native_arrays=_get_runtime_flag_value(native_array_flags), + test_compile=_get_runtime_flag_value(test_compile), + generate_frontend_arrays=_get_runtime_flag_value(generate_frontend_arrays), + transpile=_get_runtime_flag_value(transpile), + precision_mode=_get_runtime_flag_value(precision_mode), ) def test_wrapper(test_fn): @@ -635,9 +640,9 @@ def handle_method( is_hypothesis_test = len(_given_kwargs) != 0 possible_arguments = { "ground_truth_backend": st.just(ground_truth_backend), - "test_gradients": test_gradients, - "test_compile": test_compile, - "precision_mode": precision_mode, + "test_gradients": _get_runtime_flag_value(test_gradients), + "test_compile": _get_runtime_flag_value(test_compile), + "precision_mode": _get_runtime_flag_value(precision_mode), } if is_hypothesis_test and is_method_tree_provided: @@ -650,9 +655,9 @@ def handle_method( possible_arguments["init_flags"] = pf.init_method_flags( num_positional_args=init_num_positional_args, - as_variable=init_as_variable_flags, - native_arrays=init_native_arrays, - precision_mode=precision_mode, + as_variable=_get_runtime_flag_value(init_as_variable_flags), + native_arrays=_get_runtime_flag_value(init_native_arrays), + precision_mode=_get_runtime_flag_value(precision_mode), ) if method_num_positional_args is None: @@ -662,10 +667,10 @@ def handle_method( possible_arguments["method_flags"] = pf.method_flags( num_positional_args=method_num_positional_args, - as_variable=method_as_variable_flags, - native_arrays=method_native_arrays, - container_flags=method_container_flags, - precision_mode=precision_mode, + 
as_variable=_get_runtime_flag_value(method_as_variable_flags), + native_arrays=_get_runtime_flag_value(method_native_arrays), + container_flags=_get_runtime_flag_value(method_container_flags), + precision_mode=_get_runtime_flag_value(precision_mode), ) def test_wrapper(test_fn): @@ -783,18 +788,18 @@ def test_wrapper(test_fn): param_names = inspect.signature(test_fn).parameters.keys() init_flags = pf.frontend_method_flags( num_positional_args=init_num_positional_args, - as_variable=init_as_variable_flags, - native_arrays=init_native_arrays, - test_compile=test_compile, - precision_mode=precision_mode, + as_variable=_get_runtime_flag_value(init_as_variable_flags), + native_arrays=_get_runtime_flag_value(init_native_arrays), + test_compile=_get_runtime_flag_value(test_compile), + precision_mode=_get_runtime_flag_value(precision_mode), ) method_flags = pf.frontend_method_flags( num_positional_args=method_num_positional_args, - as_variable=method_as_variable_flags, - native_arrays=method_native_arrays, - test_compile=test_compile, - precision_mode=precision_mode, + as_variable=_get_runtime_flag_value(method_as_variable_flags), + native_arrays=_get_runtime_flag_value(method_native_arrays), + test_compile=_get_runtime_flag_value(test_compile), + precision_mode=_get_runtime_flag_value(precision_mode), ) ivy_init_modules = str(ivy_init_module) framework_init_modules = str(framework_init_module) diff --git a/ivy_tests/test_ivy/test_frontends/config/base.py b/ivy_tests/test_ivy/test_frontends/config/base.py index 815b1a74812e9..30c279418d94b 100644 --- a/ivy_tests/test_ivy/test_frontends/config/base.py +++ b/ivy_tests/test_ivy/test_frontends/config/base.py @@ -128,7 +128,7 @@ def as_native_dtype(self, dtype: str): return self.backend.as_native_dtype(dtype) def as_native_device(self, device: str): - return self.backend_as_native_dev(device) + return self.backend.as_native_dev(device) def isscalar(self, x): return self.backend.isscalar(x) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 19554c3ccc1e3..98660c17328ed 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -348,7 +348,7 @@ def test_jax_leaky_relu( @handle_frontend_test( fn_tree="jax.nn.log_sigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), min_value=-100, max_value=100, large_abs_safety_factor=8, @@ -735,9 +735,9 @@ def test_jax_soft_sign( @handle_frontend_test( fn_tree="jax.nn.softmax", dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), min_num_dims=2, - max_axes_size=1, + max_axes_size=2, force_int_axis=True, valid_axis=True, ), diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index ed316ca5a544e..f4fc0911ae534 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -48,6 +48,57 @@ def test_jax_numpy_fft( ) +# fft2 +@handle_frontend_test( + fn_tree="jax.numpy.fft.fft2", + dtype_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("complex"), + num_arrays=1, + 
min_value=-1e5, + max_value=1e5, + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=5, + allow_inf=False, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + ), + axes=st.sampled_from([(0, 1), (-1, -2), (1, 0)]), + s=st.tuples( + st.integers(min_value=2, max_value=256), st.integers(min_value=2, max_value=256) + ), + norm=st.sampled_from(["backward", "ortho", "forward", None]), +) +def test_jax_numpy_fft2( + dtype_values, + s, + axes, + norm, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + dtype, values = dtype_values + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=values[0], + s=s, + axes=axes, + norm=norm, + atol=1e-02, + rtol=1e-02, + ) + + # fftshift @handle_frontend_test( fn_tree="jax.numpy.fft.fftshift", @@ -70,3 +121,45 @@ def test_jax_numpy_fftshift( x=arr[0], axes=None, ) + + +# ifft +@handle_frontend_test( + fn_tree="jax.numpy.fft.ifft", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("complex"), + num_arrays=1, + min_value=-1e5, + max_value=1e5, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + allow_inf=False, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + valid_axis=True, + force_int_axis=True, + ), + n=st.integers(min_value=2, max_value=10), + norm=st.sampled_from(["backward", "ortho", "forward", None]), +) +def test_jax_numpy_ifft( + dtype_values_axis, n, norm, frontend, backend_fw, test_flags, fn_tree, on_device +): + dtype, values, axis = dtype_values_axis + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=values[0], + n=n, + axis=axis, + norm=norm, + atol=1e-02, + rtol=1e-02, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py index 10ddbd13320db..7f9778edacea6 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py @@ -478,6 +478,68 @@ def call(): assert u.shape == v.shape +@pytest.mark.xfail +@handle_frontend_test( + fn_tree="jax.random.double_sided_maxwell", + dtype_key=helpers.dtype_and_values( + available_dtypes=["uint32"], + min_value=1, + max_value=2000, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + max_dim_size=2, + ), + shape=helpers.get_shape(), + dtype=helpers.get_dtypes("float", full=False), + loc=st.integers(min_value=10, max_value=100), + scale=st.floats(min_value=0, max_value=100, exclude_min=True), + test_with_out=st.just(False), +) +def test_jax_double_sided_maxwell( + *, + dtype_key, + loc, + scale, + shape, + dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, key = dtype_key + + def call(): + return helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + backend_to_test=backend_fw, + key=key[0], + loc=loc, + scale=scale, + shape=shape, + dtype=dtype[0], + ) + + ret = call() + + if not ivy.exists(ret): + return + + ret_np, ret_from_np = ret + ret_np = helpers.flatten_and_to_np(backend=backend_fw, ret=ret_np) + ret_from_np = helpers.flatten_and_to_np(backend=backend_fw, ret=ret_from_np) + for u, v in 
zip(ret_np, ret_from_np): + assert u.dtype == v.dtype + assert u.shape == v.shape + + @pytest.mark.xfail @handle_frontend_test( fn_tree="jax.random.exponential", @@ -820,6 +882,62 @@ def call(): assert u.shape == v.shape +@pytest.mark.xfail +@handle_frontend_test( + fn_tree="jax.random.logistic", + dtype_key=helpers.dtype_and_values( + available_dtypes=["uint32"], + min_value=0, + max_value=2000, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + max_dim_size=2, + ), + shape=helpers.get_shape(allow_none=False, min_num_dims=1, min_dim_size=1), + dtype=helpers.get_dtypes("float", full=False), + test_with_out=st.just(False), +) +def test_jax_logistic( + *, + dtype_key, + shape, + dtype, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, key = dtype_key + + def call(): + return helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + key=key[0], + shape=shape, + dtype=dtype[0], + test_values=False, + ) + + ret = call() + + if not ivy.exists(ret): + return + + ret_np, ret_from_np = ret + ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) + ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) + for u, v in zip(ret_np, ret_from_np): + assert u.dtype == v.dtype + assert u.shape == v.shape + + @pytest.mark.xfail @handle_frontend_test( fn_tree="jax.random.maxwell", @@ -1466,7 +1584,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.uniform", + fn_tree="jax.random.ball", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_eigenvalues.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_eigenvalues.py index d240cef1dc48b..b76655cf0897d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_eigenvalues.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_eigenvalues.py @@ -130,6 +130,55 @@ def test_numpy_eigh( ) +@handle_frontend_test( + fn_tree="numpy.linalg.eigvals", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=4).map(lambda x: tuple([x, x])), + ).filter( + lambda x: "float16" not in x[0] + and "bfloat16" not in x[0] + and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon + and np.linalg.det(np.asarray(x[1][0])) != 0 + ), + test_with_out=st.just(False), +) +def test_numpy_eigvals( + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, x = dtype_and_x + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + a=x, + ) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + ret = np.sort( + np.array([ivy_backend.to_numpy(x).astype(np.float128) for x in ret]) + ) + frontend_ret = np.sort(np.array([x.astype(np.float128) for x in frontend_ret])) + assert_all_close( + ret_np=ret, + ret_from_gt_np=frontend_ret, + backend=backend_fw, + ground_truth_backend=frontend, + atol=1e-2, + rtol=1e-2, + ) + + # eigvalsh @handle_frontend_test( fn_tree="numpy.linalg.eigvalsh", diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_attribute.py 
b/ivy_tests/test_ivy/test_frontends/test_paddle/test_attribute.py similarity index 92% rename from ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_attribute.py rename to ivy_tests/test_ivy/test_frontends/test_paddle/test_attribute.py index f0a471073ab0e..c23662391620c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_attribute.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_attribute.py @@ -6,7 +6,7 @@ @handle_frontend_test( - fn_tree="paddle.tensor.attribute.imag", + fn_tree="paddle.imag", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), @@ -33,7 +33,7 @@ def test_paddle_imag( @handle_frontend_test( - fn_tree="paddle.tensor.attribute.is_complex", + fn_tree="paddle.is_complex", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), @@ -60,7 +60,7 @@ def test_paddle_is_complex( @handle_frontend_test( - fn_tree="paddle.tensor.attribute.is_floating_point", + fn_tree="paddle.is_floating_point", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), @@ -87,7 +87,7 @@ def test_paddle_is_floating_point( @handle_frontend_test( - fn_tree="paddle.tensor.attribute.is_integer", + fn_tree="paddle.is_integer", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), @@ -114,7 +114,7 @@ def test_paddle_is_integer( @handle_frontend_test( - fn_tree="paddle.tensor.attribute.rank", + fn_tree="paddle.rank", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), @@ -141,7 +141,7 @@ def test_paddle_rank( @handle_frontend_test( - fn_tree="paddle.tensor.attribute.real", + fn_tree="paddle.real", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_creation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_creation.py similarity index 100% rename from ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_creation.py rename to ivy_tests/test_ivy/test_frontends/test_paddle/test_creation.py diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py index e136ea66d7209..a03816f2f7151 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py @@ -220,3 +220,25 @@ def test_paddle_irfft( valid_axis=True, force_int_axis=True, ) + + +@handle_frontend_test( + fn_tree="paddle.fft.rfftfreq", + n=st.integers(min_value=1, max_value=1000), + sample_rate=st.integers(min_value=1, max_value=20), +) +def test_paddle_rfftfreq( + n, sample_rate, backend_fw, frontend, test_flags, fn_tree, on_device +): + d = 1 / sample_rate + helpers.test_frontend_function( + input_dtypes=[int], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + n=n, + d=d, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_linalg.py similarity index 95% rename from ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_linalg.py rename to ivy_tests/test_ivy/test_frontends/test_paddle/test_linalg.py index 1290c944bef8d..ceff4a4f94323 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_linalg.py @@ -12,6 +12,7 @@ ) 
from ivy_tests.test_ivy.test_frontends.test_tensorflow.test_linalg import ( + _get_first_matrix, _get_second_matrix, _get_cholesky_matrix, ) @@ -287,7 +288,6 @@ def test_paddle_bincount( min_value=-10, max_value=10, ), - aliases=["paddle.tensor.linalg.bmm"], test_with_out=st.just(False), ) def test_paddle_bmm( @@ -314,7 +314,7 @@ def test_paddle_bmm( # cholesky @handle_frontend_test( - fn_tree="paddle.tensor.linalg.cholesky", + fn_tree="paddle.cholesky", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), min_value=0, @@ -349,7 +349,7 @@ def test_paddle_cholesky( @handle_frontend_test( - fn_tree="paddle.tensor.linalg.cholesky_solve", + fn_tree="paddle.cholesky_solve", x=_get_second_matrix(), y=_get_paddle_cholesky_matrix(), test_with_out=st.just(False), @@ -382,7 +382,7 @@ def test_paddle_cholesky_solve( @handle_frontend_test( - fn_tree="paddle.tensor.linalg.cond", + fn_tree="paddle.cond", dtype_and_x=_get_dtype_and_matrix_non_singular(dtypes=["float32", "float64"]), p=st.sampled_from([None, "fro", "nuc", np.inf, -np.inf, 1, -1, 2, -2]), test_with_out=st.just(False), @@ -415,7 +415,7 @@ def test_paddle_cond( # cross @handle_frontend_test( - fn_tree="paddle.tensor.linalg.cross", + fn_tree="paddle.cross", dtype_x_y_axis=dtype_value1_value2_axis( available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, @@ -489,7 +489,7 @@ def test_paddle_dist( # dot @handle_frontend_test( - fn_tree="paddle.tensor.linalg.dot", + fn_tree="paddle.dot", dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, @@ -523,7 +523,7 @@ def test_paddle_dot( # eig @handle_frontend_test( - fn_tree="paddle.tensor.linalg.eig", + fn_tree="paddle.eig", dtype_and_input=_get_dtype_and_square_matrix(real_and_complex_only=True), test_with_out=st.just(False), ) @@ -570,7 +570,7 @@ def test_paddle_eig( # eigh @handle_frontend_test( - fn_tree="paddle.tensor.linalg.eigh", + fn_tree="paddle.eigh", dtype_and_input=_get_dtype_and_square_matrix(real_and_complex_only=True), UPLO=st.sampled_from(("L", "U")), test_with_out=st.just(False), @@ -620,7 +620,7 @@ def test_paddle_eigh( # eigvals @handle_frontend_test( - fn_tree="paddle.tensor.linalg.eigvals", + fn_tree="paddle.eigvals", dtype_x=_get_dtype_and_square_matrix(real_and_complex_only=True), test_with_out=st.just(False), ) @@ -652,7 +652,7 @@ def test_paddle_eigvals( # eigvalsh @handle_frontend_test( - fn_tree="paddle.tensor.linalg.eigvalsh", + fn_tree="paddle.eigvalsh", dtype_x=_get_dtype_and_square_matrix(real_and_complex_only=True), UPLO=st.sampled_from(("L", "U")), test_with_out=st.just(False), @@ -696,7 +696,6 @@ def test_paddle_eigvalsh( min_value=-10, max_value=10, ), - aliases=["paddle.tensor.linalg.matmul"], transpose_x=st.booleans(), transpose_y=st.booleans(), test_with_out=st.just(False), @@ -729,7 +728,7 @@ def test_paddle_matmul( # matrix_power @handle_frontend_test( - fn_tree="paddle.tensor.linalg.matrix_power", + fn_tree="paddle.matrix_power", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, @@ -763,7 +762,7 @@ def test_paddle_matrix_power( # norm @handle_frontend_test( - fn_tree="paddle.tensor.linalg.norm", + fn_tree="paddle.norm", dtype_values_axis=_dtype_values_axis(), keepdims=st.booleans(), test_with_out=st.just(False), @@ -796,7 +795,7 @@ def test_paddle_norm( # pinv @handle_frontend_test( - fn_tree="paddle.tensor.linalg.pinv", + fn_tree="paddle.pinv", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), 
min_num_dims=2, @@ -841,7 +840,7 @@ def test_paddle_pinv( # qr @handle_frontend_test( - fn_tree="paddle.tensor.linalg.qr", + fn_tree="paddle.qr", dtype_and_x=_get_dtype_and_matrix(), mode=st.sampled_from(("reduced", "complete")), test_with_out=st.just(False), @@ -873,35 +872,33 @@ def test_paddle_qr( # solve @handle_frontend_test( fn_tree="paddle.solve", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_value=-10, - max_value=10, - ), - aliases=["paddle.tensor.linalg.solve"], + x=_get_first_matrix(), + y=_get_second_matrix(), test_with_out=st.just(False), ) def test_paddle_solve( *, - dtype_x, + x, + y, frontend, - test_flags, backend_fw, + test_flags, fn_tree, on_device, ): - input_dtype, x = dtype_x + input_dtype1, x1 = x + input_dtype2, x2 = y helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, + input_dtypes=[input_dtype1, input_dtype2], backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], + rtol=1e-3, + atol=1e-3, + x=x1, + y=x2, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_logic.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_logic.py similarity index 100% rename from ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_logic.py rename to ivy_tests/test_ivy/test_frontends/test_paddle/test_logic.py diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py new file mode 100644 index 0000000000000..b42f5286b70fc --- /dev/null +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py @@ -0,0 +1,758 @@ +# global +from hypothesis import strategies as st +import math + +# local +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test +from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_manipulation import ( # noqa + _get_dtype_values_k_axes_for_rot90, +) +from ivy_tests.test_ivy.test_frontends.test_torch.test_miscellaneous_ops import ( + _get_repeat_interleaves_args, +) + + +# --- Helpers --- # +# --------------- # + + +# stack +@st.composite +def _arrays_axis_n_dtypes(draw): + num_dims = draw(st.shared(helpers.ints(min_value=2, max_value=5), key="num_dims")) + num_arrays = draw( + st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") + ) + common_shape = draw( + helpers.list_of_size( + x=helpers.ints(min_value=2, max_value=3), + size=num_dims - 1, + ) + ) + axis = draw(st.sampled_from(list(range(num_dims)))) + xs = [] + input_dtypes = draw( + helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("numeric"))) + ) + dtype = draw(st.sampled_from(input_dtypes)) + for _ in range(num_arrays): + x = draw( + helpers.array_values( + shape=common_shape, + dtype=dtype, + ) + ) + xs.append(x) + input_dtypes = [dtype] * len(input_dtypes) + return xs, input_dtypes, axis + + +# concat +@st.composite +def _arrays_idx_n_dtypes(draw): + num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) + num_arrays = draw( + st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") + ) + common_shape = draw( + helpers.list_of_size( + x=helpers.ints(min_value=2, max_value=3), + size=num_dims - 1, + ) + ) + unique_idx = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) + unique_dims = draw( + helpers.list_of_size( + x=helpers.ints(min_value=2, 
max_value=3), + size=num_arrays, + ) + ) + xs = [] + input_dtypes = draw( + helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("valid"))) + ) + dtype = draw(st.sampled_from(input_dtypes)) + for ud in unique_dims: + x = draw( + helpers.array_values( + shape=common_shape[:unique_idx] + [ud] + common_shape[unique_idx:], + dtype=dtype, + ) + ) + xs.append(x) + input_dtypes = [dtype] * len(input_dtypes) + return xs, input_dtypes, unique_idx + + +@st.composite +def _broadcast_to_helper(draw): + dtype_and_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ) + ) + + dtype, x = dtype_and_x + input_shape = x[0].shape + + max_num_dims = 6 - len(input_shape) + shape = draw(helpers.get_shape(max_num_dims=max_num_dims)) + input_shape + + return dtype, x, shape + + +# flip +@st.composite +def _dtype_x_axis(draw, **kwargs): + dtype, x, shape = draw(helpers.dtype_and_values(**kwargs, ret_shape=True)) + axis = draw( + st.lists( + helpers.ints(min_value=0, max_value=len(shape) - 1), + min_size=len(shape), + max_size=len(shape), + unique=True, + ) + ) + return dtype, x, axis + + +# expand +@st.composite +def _expand_helper(draw): + dtype_and_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ) + ) + + dtype, x = dtype_and_x + input_shape = x[0].shape + + max_num_dims = 6 - len(input_shape) + shape = draw(helpers.get_shape(max_num_dims=max_num_dims)) + input_shape + + return dtype, x, shape + + +@st.composite +def _gather_helper(draw): + dtype_and_param = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ) + ) + + dtype_and_indices = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ) + ) + dtype, param = dtype_and_param + dtype, indices = dtype_and_indices + return dtype, param, indices + + +# split +@st.composite +def _split_helper(draw): + dtypes, values, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=4, + min_dim_size=2, + max_dim_size=4, + ret_shape=True, + ) + ) + axis = draw(st.sampled_from(range(len(shape)))) + num_eles = shape[axis] + splits = [i for i in range(1, num_eles + 1) if num_eles % i == 0] + num_splits = draw(st.sampled_from(splits)) + return dtypes, values, num_splits, axis + + +# squeeze +@st.composite +def _squeeze_helper(draw): + shape = draw(st.shared(helpers.get_shape(), key="value_shape")) + valid_axes = [] + for index, axis in enumerate(shape): + if axis == 1: + valid_axes.append(index) + valid_axes.insert(0, None) + + return draw(st.sampled_from(valid_axes)) + + +# tile +@st.composite +def _tile_helper(draw): + dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=4, + min_dim_size=2, + max_dim_size=3, + ret_shape=True, + ) + ) + repeats = draw( + helpers.list_of_size( + x=helpers.ints(min_value=1, max_value=3), + size=len(shape), + ) + ) + return dtype, x, repeats + + +# Helpers # +# ------ # + + +@st.composite +def dtypes_x_reshape(draw): + shape = draw(helpers.get_shape(min_num_dims=1)) + dtypes, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=shape, + ) + ) + shape = draw( + helpers.get_shape(min_num_dims=1).filter( + lambda s: math.prod(s) == math.prod(shape) + ) + ) + return dtypes, x, 
shape + + +# --- Main --- # +# ------------ # + + +# abs +@handle_frontend_test( + fn_tree="paddle.abs", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_abs( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +@handle_frontend_test( + fn_tree="paddle.broadcast_to", + dtype_x_and_shape=_broadcast_to_helper(), +) +def test_paddle_broadcast_to( + *, + dtype_x_and_shape, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + input_dtype, x, shape = dtype_x_and_shape + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + shape=shape, + ) + + +# cast +@handle_frontend_test( + fn_tree="paddle.cast", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + dtype=helpers.get_dtypes("valid", full=False), +) +def test_paddle_cast( + *, + dtype_and_x, + dtype, + on_device, + backend_fw, + fn_tree, + frontend, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + dtype=dtype[0], + ) + + +@handle_frontend_test( + fn_tree="paddle.concat", + xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), + test_with_out=st.just(False), +) +def test_paddle_concat( + *, + xs_n_input_dtypes_n_unique_idx, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=xs, + axis=unique_idx, + ) + + +@handle_frontend_test( + fn_tree="paddle.expand", + dtype_x_and_shape=_expand_helper(), +) +def test_paddle_expand( + *, + dtype_x_and_shape, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + input_dtype, x, shape = dtype_x_and_shape + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + shape=shape, + ) + + +@handle_frontend_test( + fn_tree="paddle.flip", + dtype_x_axis=_dtype_x_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_dim_size=1, + ), + test_with_out=st.just(False), +) +def test_paddle_flip( + *, + dtype_x_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + ) + + +@handle_frontend_test( + fn_tree="paddle.gather", + dtype_param_and_indices=_gather_helper(), +) +def test_paddle_gather( + *, + dtype_param_and_indices, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, param, indices = dtype_param_and_indices + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, 
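+        # NOTE: _gather_helper draws `param` and `indices` independently and
+        # overwrites `dtype`, so only the indices' dtype reaches input_dtypes
+        # and the index values are not guaranteed to be in range for `param`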
+ frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + param=param[0], + indices=indices[0], + ) + + +# repeat_interleave +@handle_frontend_test( + fn_tree="paddle.repeat_interleave", + dtype_values_repeats_axis_output_size=_get_repeat_interleaves_args( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + max_num_dims=4, + max_dim_size=4, + ), +) +def test_paddle_repeat_interleave( + *, + dtype_values_repeats_axis_output_size, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, values, repeats, axis, _ = dtype_values_repeats_axis_output_size + + helpers.test_frontend_function( + input_dtypes=[dtype[0][0], dtype[1][0]], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=values[0], + repeats=repeats[0], + axis=axis, + ) + + +# Tests # +# ----- # + + +# reshape +@handle_frontend_test( + fn_tree="paddle.reshape", + dtypes_x_reshape=dtypes_x_reshape(), +) +def test_paddle_reshape( + *, + dtypes_x_reshape, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, shape = dtypes_x_reshape + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + shape=shape, + ) + + +# roll +@handle_frontend_test( + fn_tree="paddle.roll", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + min_dim_size=2, + ), + shift=helpers.ints(min_value=1, max_value=10), + axis=helpers.ints(min_value=-1, max_value=1), + test_with_out=st.just(False), +) +def test_paddle_roll( + *, + dtype_and_x, + shift, + axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + shifts=shift, + axis=axis, + ) + + +# rot90 +@handle_frontend_test( + fn_tree="paddle.rot90", + dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90( + available_dtypes=helpers.get_dtypes(kind="valid"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), +) +def test_paddle_rot90( + *, + dtype_m_k_axes, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, m, k, axes = dtype_m_k_axes + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=m, + k=k, + axes=tuple(axes), + ) + + +@handle_frontend_test( + fn_tree="paddle.split", + dt_x_num_splits_axis=_split_helper(), + test_with_out=st.just(False), +) +def test_paddle_split( + *, + dt_x_num_splits_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, x, num_splits, axis = dt_x_num_splits_axis + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + num_or_sections=num_splits, + axis=axis, + ) + + +@handle_frontend_test( + fn_tree="paddle.squeeze", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="value_shape"), + ), + axis=_squeeze_helper(), +) +def test_paddle_squeeze( + *, + dtype_and_x, + axis, + 
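# axis comes from _squeeze_helper: either None or the index of a size-1 dim
+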
on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + ) + + +@handle_frontend_test( + fn_tree="paddle.stack", + _arrays_n_dtypes_axis=_arrays_axis_n_dtypes(), + test_with_out=st.just(False), +) +def test_paddle_stack( + *, + _arrays_n_dtypes_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + xs, input_dtypes, axis = _arrays_n_dtypes_axis + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=xs, + axis=axis, + ) + + +# take_along_axis +@handle_frontend_test( + fn_tree="paddle.take_along_axis", + dtype_indices_axis=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes(kind="valid"), + indices_dtypes=["int64"], + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + indices_same_dims=True, + ), +) +def test_paddle_take_along_axis( + *, + dtype_indices_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, value, indices, axis, _ = dtype_indices_axis + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + arr=value, + indices=indices, + axis=axis, + ) + + +@handle_frontend_test( + fn_tree="paddle.tile", + dt_x_repeats=_tile_helper(), + test_with_out=st.just(False), +) +def test_paddle_tile( + *, + dt_x_repeats, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtypes, x, repeats = dt_x_repeats + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + repeat_times=repeats, + ) + + +# unstack +@handle_frontend_test( + fn_tree="paddle.unstack", + dtypes_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=2, + max_dim_size=1, + ), + number_positional_args=st.just(1), + axis=st.integers(-1, 0), + test_with_out=st.just(False), +) +def test_paddle_unstack( + *, + dtypes_values, + axis, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + x_dtype, x = dtypes_values + axis = axis + helpers.test_frontend_function( + input_dtypes=x_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py new file mode 100644 index 0000000000000..926836080a9b3 --- /dev/null +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py @@ -0,0 +1,2323 @@ +# global +from hypothesis import strategies as st + +# local +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test +from ivy_tests.test_ivy.test_frontends.test_torch.test_blas_and_lapack_ops import ( + _get_dtype_input_and_matrices, + _get_dtype_and_3dbatch_matrices, +) + + +# --- Helpers --- # +# --------------- # + + +@st.composite +def _test_paddle_take_helper(draw): + mode = draw(st.sampled_from(["raise", "clip", "wrap"])) + + safe_bounds = mode == "raise" + + dtypes, 
xs, indices, _, _ = draw( + helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("float_and_integer"), + indices_dtypes=["int32", "int64"], + valid_bounds=safe_bounds, + ) + ) + + return dtypes, xs, indices, mode + + +# --- Main --- # +# ------------ # + + +# abs +@handle_frontend_test( + fn_tree="paddle.abs", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), +) +def test_paddle_abs( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# acos +@handle_frontend_test( + fn_tree="paddle.acos", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_acos( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + atol=1e-2, + x=x[0], + ) + + +# acosh +@handle_frontend_test( + fn_tree="paddle.acosh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_acosh( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + atol=1e-2, + x=x[0], + ) + + +# add +@handle_frontend_test( + fn_tree="paddle.add", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, + ), +) +def test_paddle_add( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# addmm +@handle_frontend_test( + fn_tree="paddle.addmm", + dtype_input_xy=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), +) +def test_paddle_addmm( + *, + dtype_input_xy, + beta, + alpha, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input, x, y = dtype_input_xy + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + x=x[0], + y=y[0], + beta=beta, + alpha=alpha, + ) + + +# amax +@handle_frontend_test( + fn_tree="paddle.amax", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + allow_inf=False, + shared_dtype=True, + ), +) +def test_paddle_amax( + *, + dtype_and_x, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + 
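# NOTE: the strategy draws two shared-dtype arrays but only x[0] is used,
+    # and no axis/keepdim is passed, so amax reduces over the whole tensor
+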
input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + ) + + +# amin +@handle_frontend_test( + fn_tree="paddle.amin", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + valid_axis=True, + ), + keepdim=st.booleans(), +) +def test_paddle_amin( + *, + dtype_and_x, + keepdim, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + axis=axis, + keepdim=keepdim, + ) + + +@handle_frontend_test( + fn_tree="paddle.angle", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["float64", "complex64", "complex128"], + ), +) +def test_paddle_angle( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# any +@handle_frontend_test( + fn_tree="paddle.any", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=["bool"], + valid_axis=True, + allow_neg_axes=True, + force_int_axis=True, + min_num_dims=1, + ), +) +def test_paddle_any( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + backend_to_test=backend_fw, + x=x[0], + axis=axis, + keepdim=False, + ) + + +# asin +@handle_frontend_test( + fn_tree="paddle.asin", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_asin( + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# asinh +@handle_frontend_test( + fn_tree="paddle.asinh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_asinh( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + atol=1e-2, + x=x[0], + ) + + +# atan +@handle_frontend_test( + fn_tree="paddle.atan", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_atan( + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# atan2 +@handle_frontend_test( + fn_tree="paddle.atan2", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + 
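# log-scale safety factors keep operands away from extreme magnitudes,
+        # where backends may disagree on atan2 in the last few ulps
+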
num_arrays=2, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, + ), +) +def test_paddle_atan2( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# atanh +@handle_frontend_test( + fn_tree="paddle.atanh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_atanh( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# ceil +@handle_frontend_test( + fn_tree="paddle.ceil", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_ceil( + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# conj +@handle_frontend_test( + fn_tree="paddle.conj", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ), +) +def test_paddle_conj( + *, + dtype_and_input, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# cos +@handle_frontend_test( + fn_tree="paddle.cos", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_cos( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# cosh +@handle_frontend_test( + fn_tree="paddle.cosh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_cosh( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + atol=1e-2, + x=x[0], + ) + + +# cumprod +@handle_frontend_test( + fn_tree="paddle.cumprod", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, + ), +) +def test_paddle_cumprod( + *, + dtype_x_axis, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, 
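+        # the strategy bounds values to [-5, 5], keeping cumulative products
+        # finite along any axis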
+ on_device=on_device, + x=x[0], + dim=axis, + ) + + +# deg2rad +@handle_frontend_test( + fn_tree="paddle.deg2rad", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_deg2rad( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# diff +@handle_frontend_test( + fn_tree="paddle.diff", + dtype_n_x_n_axis=helpers.dtype_values_axis( + available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), + n=st.integers(min_value=1, max_value=1), + dtype_prepend=helpers.dtype_and_values( + available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), + min_num_dims=1, + max_num_dims=1, + ), + dtype_append=helpers.dtype_and_values( + available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), + min_num_dims=1, + max_num_dims=1, + ), +) +def test_paddle_diff( + *, + dtype_n_x_n_axis, + n, + dtype_prepend, + dtype_append, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, +): + input_dtype, x, axis = dtype_n_x_n_axis + _, prepend = dtype_prepend + _, append = dtype_append + helpers.test_frontend_function( + input_dtypes=input_dtype, + test_flags=test_flags, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + n=n, + axis=axis, + prepend=prepend[0], + append=append[0], + ) + + +# digamma +@handle_frontend_test( + fn_tree="paddle.digamma", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + safety_factor_scale="log", + ), +) +def test_paddle_digamma( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + atol=1e-4, + x=x[0], + ) + + +# divide +@handle_frontend_test( + fn_tree="paddle.divide", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, + ), +) +def test_paddle_divide( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# erf +@handle_frontend_test( + fn_tree="paddle.erf", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_erf( + *, + dtype_and_x, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# exp +@handle_frontend_test( + fn_tree="paddle.exp", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_exp( + *, + dtype_and_x, + on_device, 
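+    # the remaining parameters are fixtures injected by handle_frontend_test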
+ fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# expm1 +@handle_frontend_test( + fn_tree="paddle.expm1", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_expm1( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# floor +@handle_frontend_test( + fn_tree="paddle.floor", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_floor( + *, + dtype_and_x, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# floor_divide +@handle_frontend_test( + fn_tree="paddle.floor_divide", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-10, + max_value=10, + num_arrays=2, + allow_inf=False, + shared_dtype=True, + ), +) +def test_paddle_floor_divide( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + atol=1e-5, + ) + + +@handle_frontend_test( + fn_tree="paddle.fmax", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True + ), +) +def test_paddle_fmax( + *, + dtypes_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtypes_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +@handle_frontend_test( + fn_tree="paddle.fmin", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True + ), +) +def test_paddle_fmin( + *, + dtypes_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtypes_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# frac +@handle_frontend_test( + fn_tree="paddle.frac", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + max_value=1e6, + min_value=-1e6, + ), +) +def test_paddle_frac( + *, + dtype_and_x, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# gcd +@handle_frontend_test( + 
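# inputs are capped at |x| <= 100: larger magnitudes add no gcd coverage
+    # and slow down shrinking (a judgment call, not a paddle limitation)
+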
fn_tree="paddle.gcd", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-100, + max_value=100, + min_num_dims=1, + min_dim_size=1, + num_arrays=2, + shared_dtype=True, + ), +) +def test_paddle_gcd( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# heaviside +@handle_frontend_test( + fn_tree="paddle.heaviside", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, + ), +) +def test_paddle_heaviside( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# inner +@handle_frontend_test( + fn_tree="paddle.inner", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-10, + max_value=10, + num_arrays=2, + shared_dtype=True, + ), +) +def test_paddle_inner( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# isfinite +@handle_frontend_test( + fn_tree="paddle.isfinite", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_isfinite( + *, + dtype_and_x, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# isinf +@handle_frontend_test( + fn_tree="paddle.isinf", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_isinf( + *, + dtype_and_x, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# isnan +@handle_frontend_test( + fn_tree="paddle.isnan", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_isnan( + *, + dtype_and_x, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# kron +@handle_frontend_test( + fn_tree="paddle.kron", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + allow_inf=False, + shared_dtype=True, + ), +) +def test_paddle_kron( + *, + 
dtype_and_x, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# lcm +@handle_frontend_test( + fn_tree="paddle.lcm", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_num_dims=1, + safety_factor_scale="log", + large_abs_safety_factor=2, + shared_dtype=True, + ), +) +def test_paddle_lcm( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# lerp +@handle_frontend_test( + fn_tree="paddle.lerp", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, + ), +) +def test_paddle_lerp( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + weight=x[2], + ) + + +# lgamma +@handle_frontend_test( + fn_tree="paddle.lgamma", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + safety_factor_scale="log", + ), +) +def test_paddle_lgamma( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + atol=1e-4, + x=x[0], + ) + + +# log +@handle_frontend_test( + fn_tree="paddle.log", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_log( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# log1p +@handle_frontend_test( + fn_tree="paddle.log1p", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + max_value=1e5, + ), +) +def test_paddle_log1p( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# log2 +@handle_frontend_test( + fn_tree="paddle.log2", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_log2( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + 
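# "valid" dtypes include non-positive values, where log2 yields -inf or
+        # nan; frontend and backend are expected to agree elementwise on these
+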
test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# logit +@handle_frontend_test( + fn_tree="paddle.logit", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_logit( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + eps=1e-2, + ) + + +# max +@handle_frontend_test( + fn_tree="paddle.max", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=False, + ), +) +def test_paddle_max( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + keepdim=False, + ) + + +# maximum +@handle_frontend_test( + fn_tree="paddle.maximum", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + ), +) +def test_paddle_maximum( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# min +@handle_frontend_test( + fn_tree="paddle.min", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=False, + ), +) +def test_paddle_min( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + keepdim=False, + ) + + +@handle_frontend_test( + fn_tree="paddle.minimum", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + ), +) +def test_paddle_minimum( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# mm +@handle_frontend_test( + fn_tree="paddle.mm", + dtype_xy=_get_dtype_input_and_matrices(), +) +def test_paddle_mm( + *, + dtype_xy, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, y = dtype_xy + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x, + mat2=y, + ) + + +# multiply +@handle_frontend_test( + fn_tree="paddle.multiply", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=2, + 
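# paired log-scale safety factors damp overflow/underflow in the product
+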
small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, + ), +) +def test_paddle_multiply( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +@handle_frontend_test( + fn_tree="paddle.nanmean", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + allow_nan=True, + ), +) +def test_paddle_nanmean( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + rtol=1e-04, + atol=1e-04, + ) + + +# nansum +@handle_frontend_test( + fn_tree="paddle.nansum", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + allow_nan=True, + ), +) +def test_paddle_nansum( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + rtol=1e-04, + atol=1e-04, + ) + + +# neg +@handle_frontend_test( + fn_tree="paddle.neg", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["float32", "float64", "int8", "int16", "int32", "int64"], + ), +) +def test_paddle_neg( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# outer +@handle_frontend_test( + fn_tree="paddle.outer", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_num_dims=1, + max_num_dims=1, + shared_dtype=True, + ), +) +def test_paddle_outer( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# pow +@handle_frontend_test( + fn_tree="paddle.pow", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + allow_inf=False, + shared_dtype=True, + ), +) +def test_paddle_pow( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# prod +@handle_frontend_test( + fn_tree="paddle.prod", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + min_value=-10, + max_value=10, + force_int_axis=False, + allow_nan=False, + ), +) 
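+# values are confined to [-10, 10] above because products grow quickly;
+# only keepdim=False is exercised below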
+def test_paddle_prod( + *, + dtype_and_x, + on_device, + backend_fw, + fn_tree, + frontend, + test_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + keepdim=False, + backend_to_test=backend_fw, + ) + + +# rad2deg +@handle_frontend_test( + fn_tree="paddle.rad2deg", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_rad2deg( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# reciprocal +@handle_frontend_test( + fn_tree="paddle.reciprocal", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_paddle_reciprocal( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# remainder +@handle_frontend_test( + fn_tree="paddle.remainder", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, + ), +) +def test_paddle_remainder( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# round +@handle_frontend_test( + fn_tree="paddle.round", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=1, + ), +) +def test_paddle_round( + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# rsqrt +@handle_frontend_test( + fn_tree="paddle.rsqrt", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_rsqrt( + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +@handle_frontend_test( + fn_tree="paddle.sgn", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_complex"), + min_num_dims=1, + max_num_dims=1, + min_dim_size=1, + max_dim_size=1, + abs_smallest_val=1e-10, + min_value=-10, + max_value=10, + ), +) +def test_paddle_sgn( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + 
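# sgn is sampled on a single element (1x1 shape bounds above), with
+        # abs_smallest_val keeping the draw away from the discontinuity at 0
+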
frontend=frontend,
+        backend_to_test=backend_fw,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+    )
+
+
+# sign
+@handle_frontend_test(
+    fn_tree="paddle.sign",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+    ),
+)
+def test_paddle_sign(
+    *,
+    dtype_and_x,
+    on_device,
+    fn_tree,
+    frontend,
+    backend_fw,
+    test_flags,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        frontend=frontend,
+        backend_to_test=backend_fw,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+    )
+
+
+# sin
+@handle_frontend_test(
+    fn_tree="paddle.sin",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+    ),
+)
+def test_paddle_sin(
+    *,
+    dtype_and_x,
+    on_device,
+    fn_tree,
+    frontend,
+    test_flags,
+    backend_fw,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        backend_to_test=backend_fw,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+    )
+
+
+# sinh
+@handle_frontend_test(
+    fn_tree="paddle.sinh",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+    ),
+)
+def test_paddle_sinh(
+    *,
+    dtype_and_x,
+    on_device,
+    fn_tree,
+    frontend,
+    test_flags,
+    backend_fw,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        backend_to_test=backend_fw,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+    )
+
+
+# sqrt
+@handle_frontend_test(
+    fn_tree="paddle.sqrt",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("valid"),
+    ),
+)
+def test_paddle_sqrt(
+    *,
+    dtype_and_x,
+    frontend,
+    test_flags,
+    fn_tree,
+    backend_fw,
+    on_device,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        backend_to_test=backend_fw,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+    )
+
+
+# square
+@handle_frontend_test(
+    fn_tree="paddle.square",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+    ),
+)
+def test_paddle_square(
+    *,
+    dtype_and_x,
+    on_device,
+    fn_tree,
+    frontend,
+    backend_fw,
+    test_flags,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        frontend=frontend,
+        backend_to_test=backend_fw,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+    )
+
+
+# stanh
+@handle_frontend_test(
+    fn_tree="paddle.stanh",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+    ),
+    scale_a=st.floats(1e-5, 1e5),
+    scale_b=st.floats(1e-5, 1e5),
+)
+def test_paddle_stanh(
+    *,
+    dtype_and_x,
+    on_device,
+    fn_tree,
+    frontend,
+    test_flags,
+    backend_fw,
+    scale_a,
+    scale_b,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        frontend=frontend,
+        test_flags=test_flags,
+        backend_to_test=backend_fw,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+        scale_a=scale_a,
+        scale_b=scale_b,
+    )
+
+
+# subtract
+@handle_frontend_test(
+    fn_tree="paddle.subtract",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        num_arrays=2,
+        allow_inf=False,
+        large_abs_safety_factor=2,
+        small_abs_safety_factor=2,
+        safety_factor_scale="log",
+        shared_dtype=True,
+    ),
+)
+def test_paddle_subtract(
+    *,
+    dtype_and_x,
+    on_device,
+    fn_tree,
+    frontend,
+    test_flags,
+    backend_fw,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        backend_to_test=backend_fw,
+        frontend=frontend,
+        fn_tree=fn_tree,
+        test_flags=test_flags,
+        on_device=on_device,
+        x=x[0],
+        y=x[1],
+    )
+
+
+# take
+@handle_frontend_test(
+    fn_tree="paddle.take", dtype_and_values=_test_paddle_take_helper()
+)
+def test_paddle_take(
+    *,
+    dtype_and_values,
+    on_device,
+    fn_tree,
+    backend_fw,
+    frontend,
+    test_flags,
+):
+    dtypes, xs, indices, modes = dtype_and_values
+    helpers.test_frontend_function(
+        input_dtypes=dtypes,
+        backend_to_test=backend_fw,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=xs,
+        index=indices,
+        mode=modes,
+    )
+
+
+# tan
+@handle_frontend_test(
+    fn_tree="paddle.tan",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+    ),
+)
+def test_paddle_tan(
+    *,
+    dtype_and_x,
+    on_device,
+    fn_tree,
+    frontend,
+    backend_fw,
+    test_flags,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        frontend=frontend,
+        backend_to_test=backend_fw,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        atol=1e-2,
+        x=x[0],
+    )
+
+
+# tanh
+@handle_frontend_test(
+    fn_tree="paddle.tanh",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("valid"),
+    ),
+)
+def test_paddle_tanh(
+    *,
+    dtype_and_x,
+    on_device,
+    fn_tree,
+    frontend,
+    test_flags,
+    backend_fw,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        backend_to_test=backend_fw,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        atol=1e-2,
+        x=x[0],
+    )
+
+
+# trunc
+@handle_frontend_test(
+    fn_tree="paddle.trunc",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_integer"),
+    ),
+)
+def test_paddle_trunc(
+    *,
+    dtype_and_x,
+    on_device,
+    fn_tree,
+    frontend,
+    backend_fw,
+    test_flags,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        frontend=frontend,
+        backend_to_test=backend_fw,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+    )
diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py
index a1b17a2d3dbf2..799755d5378e9 100644
--- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py
+++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py
@@ -760,6 +760,34 @@ def test_paddle_softsign(
 )
 
 
+# swish
+@handle_frontend_test(
+    fn_tree="paddle.nn.functional.swish",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+    ),
+)
+def test_paddle_swish(
+    *,
+    dtype_and_x,
+    on_device,
+    fn_tree,
+    frontend,
+    test_flags,
+    backend_fw,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        backend_to_test=backend_fw,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+    )
+
+
 # tanh_
 @handle_frontend_test(
     fn_tree="paddle.nn.functional.tanh_",
diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_loss.py
b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_loss.py index c6792a06158c3..2a8a9486576d2 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_loss.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_loss.py @@ -720,3 +720,55 @@ def test_paddle_triplet_margin_loss( swap=swap, reduction=reduction, ) + + +@handle_frontend_test( + fn_tree="paddle.nn.functional.multi_label_soft_margin_loss", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-2, + max_value=2, + shared_dtype=True, + allow_inf=False, + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + dtype_and_weight=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + min_value=-2, + max_value=2, + ), + reduction=st.sampled_from(["mean", "none", "sum"]), +) +def test_paddle_multi_label_soft_margin_loss( + dtype_and_x, + dtype_and_weight, + reduction, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + x_dtype, x = dtype_and_x + weight_dtype, weight = dtype_and_weight + helpers.test_frontend_function( + input_dtypes=[ + x_dtype[0], + x_dtype[1], + weight_dtype[0], + ], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + label=x[1], + weight=weight[0], + reduction=reduction, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_norm.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_norm.py index aae30f98f79c5..9a311dd8e526d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_norm.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_norm.py @@ -43,3 +43,44 @@ def test_paddle_layer_norm( bias=offset[0], epsilon=eps, ) + + +# normalize +@handle_frontend_test( + fn_tree="paddle.nn.functional.normalize", + dtype_and_x_and_axis=helpers.arrays_and_axes( + available_dtypes=helpers.get_dtypes(kind="valid"), + num=1, + return_dtype=True, + force_int_axis=True, + ), + p=st.floats(min_value=0.1, max_value=2), + negative_axis=st.booleans(), +) +def test_paddle_normalize( + *, + dtype_and_x_and_axis, + p, + negative_axis, + test_flags, + frontend, + backend_fw, + on_device, + fn_tree, +): + dtype, x, axis = dtype_and_x_and_axis + if axis: + axis = -axis if negative_axis else axis + else: + axis = 0 + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + on_device=on_device, + fn_tree=fn_tree, + x=x[0], + p=p, + axis=axis, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py index 0659e33a95bf2..a0841ee2e9a82 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py @@ -174,3 +174,47 @@ def test_paddle_pixel_shuffle( data_format=data_format, backend_to_test=backend_fw, ) + + +# pixel_unshuffle +@handle_frontend_test( + fn_tree="paddle.nn.functional.pixel_unshuffle", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["float32", "float64"], + min_value=0, + min_num_dims=4, + max_num_dims=4, + min_dim_size=3, + ), + factor=helpers.ints(min_value=1), + 
data_format=st.sampled_from(["NCHW", "NHWC"]), +) +def test_paddle_pixel_unshuffle( + *, + dtype_and_x, + factor, + data_format, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + if data_format == "NCHW": + assume(ivy.shape(x[0])[2] % factor == 0) + assume(ivy.shape(x[0])[3] % factor == 0) + else: + assume(ivy.shape(x[0])[1] % factor == 0) + assume(ivy.shape(x[0])[2] % factor == 0) + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + downscale_factor=factor, + data_format=data_format, + backend_to_test=backend_fw, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_random.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_random.py new file mode 100644 index 0000000000000..a9c2e27164924 --- /dev/null +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_random.py @@ -0,0 +1,273 @@ +# global +from hypothesis import strategies as st + +# local + +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test + + +@handle_frontend_test( + fn_tree="paddle.normal", + input_dtypes=st.sampled_from([["float32"], ["float64"]]), + shape=helpers.get_shape( + min_num_dims=1, + min_dim_size=1, + ), + mean=st.floats( + min_value=-10, + max_value=10, + ), + std=st.floats( + min_value=0, + max_value=10, + ), +) +def test_paddle_normal( + input_dtypes, + shape, + mean, + std, + frontend, + backend_fw, + test_flags, + fn_tree, +): + helpers.test_frontend_function( + input_dtypes=input_dtypes, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + test_values=False, + mean=mean, + std=std, + shape=shape, + ) + + +@handle_frontend_test( + fn_tree="paddle.poisson", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=1000, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + max_dim_size=2, + ), +) +def test_paddle_poisson(dtype_and_x, backend_fw, frontend, test_flags, fn_tree): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + test_values=False, + x=x[0], + ) + + +@handle_frontend_test( + fn_tree="paddle.rand", + input_dtypes=st.sampled_from(["int32", "int64"]), + shape=helpers.get_shape( + allow_none=False, + min_num_dims=0, + min_dim_size=1, + ), + dtype=helpers.get_dtypes("valid", full=False), +) +def test_paddle_rand( + *, + input_dtypes, + shape, + dtype, + frontend, + backend_fw, + test_flags, + fn_tree, +): + helpers.test_frontend_function( + input_dtypes=[input_dtypes], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + test_values=False, + shape=shape, + dtype=dtype[0], + ) + + +# randint +@handle_frontend_test( + fn_tree="paddle.randint", + low=helpers.ints(min_value=0, max_value=10), + high=helpers.ints(min_value=11, max_value=20), + dtype=helpers.get_dtypes("integer"), + shape=helpers.get_shape( + allow_none=False, min_num_dims=2, max_num_dims=7, min_dim_size=2 + ), +) +def test_paddle_randint( + low, + high, + dtype, + backend_fw, + frontend, + test_flags, + shape, + fn_tree, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_values=False, + fn_tree=fn_tree, + test_flags=test_flags, + low=low, + high=high, + shape=shape, + ) + + 
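The random-op tests added in test_random.py above (test_paddle_normal, test_paddle_poisson, test_paddle_rand, test_paddle_randint) all pass test_values=False to helpers.test_frontend_function: stochastic outputs cannot be compared element-wise across backends, so the harness only verifies structural properties such as shape and dtype. A minimal sketch of that idea, using NumPy stand-ins (check_random_frontend is hypothetical and not part of the suite):

import numpy as np

def check_random_frontend(shape, dtype):
    # Hypothetical stand-ins for the same paddle.randint(low=0, high=10, ...)
    # call run under two different backends: the drawn values differ,
    # but the shape and dtype of the results must agree.
    a = np.random.randint(0, 10, size=shape).astype(dtype)
    b = np.random.randint(0, 10, size=shape).astype(dtype)
    # With test_values=False the harness skips value comparison and
    # checks structure only, which is what these assertions mimic.
    assert a.shape == b.shape
    assert a.dtype == b.dtype

check_random_frontend((2, 3), "int32")

test_paddle_randint_like, added next, follows the same test_values=False pattern while additionally threading on_device through to the helper.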
+@handle_frontend_test( + fn_tree="paddle.randint_like", + input_dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=helpers.get_shape( + allow_none=False, min_num_dims=2, max_num_dims=7, min_dim_size=2 + ), + ), + low=st.integers(min_value=0, max_value=10), + high=st.integers(min_value=11, max_value=20), + dtype=helpers.get_dtypes("integer"), +) +def test_paddle_randint_like( + input_dtype_and_x, + low, + high, + dtype, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = input_dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + x=x[0], + low=low, + high=high, + dtype=dtype[0], + ) + + +@handle_frontend_test( + fn_tree="paddle.randn", + input_dtypes=st.sampled_from(["int32", "int64"]), + shape=helpers.get_shape( + allow_none=False, min_num_dims=1, max_num_dims=1, min_dim_size=2 + ), + dtype=st.sampled_from(["float32", "float64"]), +) +def test_paddle_randn( + *, + input_dtypes, + shape, + dtype, + frontend, + backend_fw, + test_flags, + fn_tree, +): + helpers.test_frontend_function( + input_dtypes=[input_dtypes], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + test_values=False, + shape=shape, + dtype=dtype, + ) + + +@handle_frontend_test( + fn_tree="paddle.standard_normal", + input_dtypes=st.sampled_from([["int32"], ["int64"]]), + shape=helpers.get_shape( + min_num_dims=1, + min_dim_size=1, + ), + dtype=helpers.get_dtypes("valid", full=False), +) +def test_paddle_standard_normal( + input_dtypes, + shape, + dtype, + frontend, + backend_fw, + test_flags, + fn_tree, +): + helpers.test_frontend_function( + input_dtypes=input_dtypes, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + test_values=False, + shape=shape, + dtype=dtype[0], + ) + + +@handle_frontend_test( + fn_tree="paddle.uniform", + input_dtypes=helpers.get_dtypes("float"), + shape=st.tuples( + st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) + ), + dtype=helpers.get_dtypes("valid", full=False), + min=st.floats(allow_nan=False, allow_infinity=False, width=32), + max=st.floats(allow_nan=False, allow_infinity=False, width=32), + seed=st.integers(min_value=2, max_value=5), +) +def test_paddle_uniform( + input_dtypes, + shape, + dtype, + min, + max, + seed, + frontend, + backend_fw, + test_flags, + fn_tree, +): + helpers.test_frontend_function( + input_dtypes=input_dtypes, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + test_values=False, + shape=shape, + dtype=dtype[0], + min=min, + max=max, + seed=seed, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_search.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_search.py similarity index 100% rename from ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_search.py rename to ivy_tests/test_ivy/test_frontends/test_paddle/test_search.py diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_stat.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_stat.py similarity index 100% rename from ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_stat.py rename to ivy_tests/test_ivy/test_frontends/test_paddle/test_stat.py diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_einsum.py 
b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_einsum.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py index be36b2a116f1c..c1265c500864e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py @@ -1,6 +1,5 @@ # global from hypothesis import strategies as st -import math # local import ivy_tests.test_ivy.helpers as helpers @@ -14,226 +13,6 @@ # --------------- # -# stack -@st.composite -def _arrays_axis_n_dtypes(draw): - num_dims = draw(st.shared(helpers.ints(min_value=2, max_value=5), key="num_dims")) - num_arrays = draw( - st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") - ) - common_shape = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_dims - 1, - ) - ) - axis = draw(st.sampled_from(list(range(num_dims)))) - xs = [] - input_dtypes = draw( - helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("numeric"))) - ) - dtype = draw(st.sampled_from(input_dtypes)) - for _ in range(num_arrays): - x = draw( - helpers.array_values( - shape=common_shape, - dtype=dtype, - ) - ) - xs.append(x) - input_dtypes = [dtype] * len(input_dtypes) - return xs, input_dtypes, axis - - -# concat -@st.composite -def _arrays_idx_n_dtypes(draw): - num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) - num_arrays = draw( - st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") - ) - common_shape = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_dims - 1, - ) - ) - unique_idx = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) - unique_dims = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_arrays, - ) - ) - xs = [] - input_dtypes = draw( - helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("valid"))) - ) - dtype = draw(st.sampled_from(input_dtypes)) - for ud in unique_dims: - x = draw( - helpers.array_values( - shape=common_shape[:unique_idx] + [ud] + common_shape[unique_idx:], - dtype=dtype, - ) - ) - xs.append(x) - input_dtypes = [dtype] * len(input_dtypes) - return xs, input_dtypes, unique_idx - - -@st.composite -def _broadcast_to_helper(draw): - dtype_and_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=6, - ) - ) - - dtype, x = dtype_and_x - input_shape = x[0].shape - - max_num_dims = 6 - len(input_shape) - shape = draw(helpers.get_shape(max_num_dims=max_num_dims)) + input_shape - - return dtype, x, shape - - -# flip -@st.composite -def _dtype_x_axis(draw, **kwargs): - dtype, x, shape = draw(helpers.dtype_and_values(**kwargs, ret_shape=True)) - axis = draw( - st.lists( - helpers.ints(min_value=0, max_value=len(shape) - 1), - min_size=len(shape), - max_size=len(shape), - unique=True, - ) - ) - return dtype, x, axis - - -# expand -@st.composite -def _expand_helper(draw): - dtype_and_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=6, - ) - ) - - dtype, x = dtype_and_x - input_shape = x[0].shape - - max_num_dims = 6 - len(input_shape) - shape = draw(helpers.get_shape(max_num_dims=max_num_dims)) + input_shape - - return dtype, x, shape - - -@st.composite -def 
_gather_helper(draw): - dtype_and_param = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=6, - ) - ) - - dtype_and_indices = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=6, - ) - ) - dtype, param = dtype_and_param - dtype, indices = dtype_and_indices - return dtype, param, indices - - -# split -@st.composite -def _split_helper(draw): - dtypes, values, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, - max_num_dims=4, - min_dim_size=2, - max_dim_size=4, - ret_shape=True, - ) - ) - axis = draw(st.sampled_from(range(len(shape)))) - num_eles = shape[axis] - splits = [i for i in range(1, num_eles + 1) if num_eles % i == 0] - num_splits = draw(st.sampled_from(splits)) - return dtypes, values, num_splits, axis - - -# squeeze -@st.composite -def _squeeze_helper(draw): - shape = draw(st.shared(helpers.get_shape(), key="value_shape")) - valid_axes = [] - for index, axis in enumerate(shape): - if axis == 1: - valid_axes.append(index) - valid_axes.insert(0, None) - - return draw(st.sampled_from(valid_axes)) - - -# tile -@st.composite -def _tile_helper(draw): - dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=4, - min_dim_size=2, - max_dim_size=3, - ret_shape=True, - ) - ) - repeats = draw( - helpers.list_of_size( - x=helpers.ints(min_value=1, max_value=3), - size=len(shape), - ) - ) - return dtype, x, repeats - - -# Helpers # -# ------ # - - -@st.composite -def dtypes_x_reshape(draw): - shape = draw(helpers.get_shape(min_num_dims=1)) - dtypes, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=shape, - ) - ) - shape = draw( - helpers.get_shape(min_num_dims=1).filter( - lambda s: math.prod(s) == math.prod(shape) - ) - ) - return dtypes, x, shape - - @st.composite def dtypes_x_reshape_(draw): shape = draw(helpers.get_shape(min_num_dims=1)) @@ -246,239 +25,9 @@ def dtypes_x_reshape_(draw): return dtypes, x, shape -# --- Main --- # -# ------------ # - - -# abs -@handle_frontend_test( - fn_tree="paddle.abs", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_abs( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -@handle_frontend_test( - fn_tree="paddle.broadcast_to", - dtype_x_and_shape=_broadcast_to_helper(), -) -def test_paddle_broadcast_to( - *, - dtype_x_and_shape, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, -): - input_dtype, x, shape = dtype_x_and_shape - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - shape=shape, - ) - - -# cast -@handle_frontend_test( - fn_tree="paddle.cast", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), - dtype=helpers.get_dtypes("valid", full=False), -) -def test_paddle_cast( - *, - dtype_and_x, - dtype, - on_device, - backend_fw, - fn_tree, - frontend, - test_flags, -): - input_dtype, x = dtype_and_x - 
helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - dtype=dtype[0], - ) - - -@handle_frontend_test( - fn_tree="paddle.concat", - xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), - test_with_out=st.just(False), -) -def test_paddle_concat( - *, - xs_n_input_dtypes_n_unique_idx, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=xs, - axis=unique_idx, - ) - - -@handle_frontend_test( - fn_tree="paddle.expand", - dtype_x_and_shape=_expand_helper(), -) -def test_paddle_expand( - *, - dtype_x_and_shape, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, -): - input_dtype, x, shape = dtype_x_and_shape - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - shape=shape, - ) - - -@handle_frontend_test( - fn_tree="paddle.flip", - dtype_x_axis=_dtype_x_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_dim_size=1, - ), - test_with_out=st.just(False), -) -def test_paddle_flip( - *, - dtype_x_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_x_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - axis=axis, - ) - - -@handle_frontend_test( - fn_tree="paddle.gather", - dtype_param_and_indices=_gather_helper(), -) -def test_paddle_gather( - *, - dtype_param_and_indices, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, param, indices = dtype_param_and_indices - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - param=param[0], - indices=indices[0], - ) - - -# Tests # -# ----- # - - -# reshape -@handle_frontend_test( - fn_tree="paddle.reshape", - dtypes_x_reshape=dtypes_x_reshape(), -) -def test_paddle_reshape( - *, - dtypes_x_reshape, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, shape = dtypes_x_reshape - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - shape=shape, - ) - - # reshape_ @handle_frontend_test( - fn_tree="paddle.reshape_", + fn_tree="paddle.tensor.manipulation.reshape_", dtypes_x_reshape=dtypes_x_reshape_(), ) def test_paddle_reshape_( @@ -501,260 +50,3 @@ def test_paddle_reshape_( x=x[0], shape=shape, ) - - -# roll -@handle_frontend_test( - fn_tree="paddle.roll", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - min_dim_size=2, - ), - shift=helpers.ints(min_value=1, max_value=10), - axis=helpers.ints(min_value=-1, max_value=1), - test_with_out=st.just(False), -) -def test_paddle_roll( - *, - dtype_and_x, - shift, - axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - 
helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - shifts=shift, - axis=axis, - ) - - -# rot90 -@handle_frontend_test( - fn_tree="paddle.rot90", - dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90( - available_dtypes=helpers.get_dtypes(kind="valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), -) -def test_paddle_rot90( - *, - dtype_m_k_axes, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, m, k, axes = dtype_m_k_axes - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=m, - k=k, - axes=tuple(axes), - ) - - -@handle_frontend_test( - fn_tree="paddle.split", - dt_x_num_splits_axis=_split_helper(), - test_with_out=st.just(False), -) -def test_paddle_split( - *, - dt_x_num_splits_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtypes, x, num_splits, axis = dt_x_num_splits_axis - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - num_or_sections=num_splits, - axis=axis, - ) - - -@handle_frontend_test( - fn_tree="paddle.squeeze", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="value_shape"), - ), - axis=_squeeze_helper(), -) -def test_paddle_squeeze( - *, - dtype_and_x, - axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - axis=axis, - ) - - -@handle_frontend_test( - fn_tree="paddle.stack", - _arrays_n_dtypes_axis=_arrays_axis_n_dtypes(), - test_with_out=st.just(False), -) -def test_paddle_stack( - *, - _arrays_n_dtypes_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - xs, input_dtypes, axis = _arrays_n_dtypes_axis - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=xs, - axis=axis, - ) - - -# take_along_axis -@handle_frontend_test( - fn_tree="paddle.take_along_axis", - dtype_indices_axis=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes(kind="valid"), - indices_dtypes=["int64"], - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - indices_same_dims=True, - ), -) -def test_paddle_take_along_axis( - *, - dtype_indices_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtypes, value, indices, axis, _ = dtype_indices_axis - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - arr=value, - indices=indices, - axis=axis, - ) - - -@handle_frontend_test( - fn_tree="paddle.tile", - dt_x_repeats=_tile_helper(), - test_with_out=st.just(False), -) -def test_paddle_tile( - *, - dt_x_repeats, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtypes, x, repeats = dt_x_repeats - helpers.test_frontend_function( - 
input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - repeat_times=repeats, - ) - - -# unstack -@handle_frontend_test( - fn_tree="paddle.unstack", - dtypes_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - max_num_dims=2, - max_dim_size=1, - ), - number_positional_args=st.just(1), - axis=st.integers(-1, 0), - test_with_out=st.just(False), -) -def test_paddle_unstack( - *, - dtypes_values, - axis, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, -): - x_dtype, x = dtypes_values - axis = axis - helpers.test_frontend_function( - input_dtypes=x_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - axis=axis, - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py index 387200e0c8590..692fec8fa6e5d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py @@ -1,110 +1,23 @@ -# global -from hypothesis import strategies as st - # local import ivy_tests.test_ivy.helpers as helpers from ivy_tests.test_ivy.helpers import handle_frontend_test -from ivy_tests.test_ivy.test_frontends.test_torch.test_blas_and_lapack_ops import ( - _get_dtype_input_and_matrices, - _get_dtype_and_3dbatch_matrices, -) - - -# --- Helpers --- # -# --------------- # - - -@st.composite -def _test_paddle_take_helper(draw): - mode = draw(st.sampled_from(["raise", "clip", "wrap"])) - - safe_bounds = mode == "raise" - - dtypes, xs, indices, _, _ = draw( - helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("float_and_integer"), - indices_dtypes=["int32", "int64"], - valid_bounds=safe_bounds, - ) - ) - - return dtypes, xs, indices, mode - - -# --- Main --- # -# ------------ # - - -# abs -@handle_frontend_test( - fn_tree="paddle.tensor.math.abs", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), -) -def test_paddle_abs( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) -# acos +# ceil_ @handle_frontend_test( - fn_tree="paddle.acos", + fn_tree="paddle.tensor.math.ceil_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_acos( +def test_paddle_ceil_( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - atol=1e-2, - x=x[0], - ) - - -# acosh -@handle_frontend_test( - fn_tree="paddle.tensor.math.acosh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_acosh( - *, - dtype_and_x, - on_device, fn_tree, - frontend, - test_flags, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -114,163 +27,52 @@ def test_paddle_acosh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, 
x=x[0], ) -# add +# exp_ @handle_frontend_test( - fn_tree="paddle.add", + fn_tree="paddle.tensor.math.exp_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, ), ) -def test_paddle_add( +def test_paddle_exp_( *, dtype_and_x, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# addmm -@handle_frontend_test( - fn_tree="paddle.tensor.math.addmm", - dtype_input_xy=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), -) -def test_paddle_addmm( - *, - dtype_input_xy, - beta, - alpha, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, input, x, y = dtype_input_xy - helpers.test_frontend_function( - input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], x=x[0], - y=y[0], - beta=beta, - alpha=alpha, ) -# amax +# lerp_ @handle_frontend_test( - fn_tree="paddle.tensor.math.amax", + fn_tree="paddle.tensor.math.lerp_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_paddle_amax( - *, - dtype_and_x, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - ) - - -# amin -@handle_frontend_test( - fn_tree="paddle.tensor.math.amin", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, - ), - keepdim=st.booleans(), -) -def test_paddle_amin( - *, - dtype_and_x, - keepdim, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, -): - input_dtype, x, axis = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - axis=axis, - keepdim=keepdim, - ) - - -@handle_frontend_test( - fn_tree="paddle.tensor.math.angle", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float64", "complex64", "complex128"], - ), -) -def test_paddle_angle( +def test_paddle_lerp_( *, dtype_and_x, on_device, @@ -284,83 +86,23 @@ def test_paddle_angle( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# any -@handle_frontend_test( - fn_tree="paddle.tensor.math.any", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=["bool"], - valid_axis=True, - allow_neg_axes=True, - force_int_axis=True, - min_num_dims=1, - ), -) -def 
test_paddle_any( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - backend_to_test=backend_fw, - x=x[0], - axis=axis, - keepdim=False, - ) - - -# asin -@handle_frontend_test( - fn_tree="paddle.tensor.math.asin", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_asin( - *, - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, - fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], + weight=x[2], ) -# asinh +# reciprocal_ @handle_frontend_test( - fn_tree="paddle.tensor.math.asinh", + fn_tree="paddle.tensor.math.reciprocal_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_asinh( +def test_paddle_reciprocal_( *, dtype_and_x, on_device, @@ -377,19 +119,19 @@ def test_paddle_asinh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, x=x[0], ) -# atan +# round_ @handle_frontend_test( - fn_tree="paddle.tensor.math.atan", + fn_tree="paddle.tensor.math.round_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_value=1, ), ) -def test_paddle_atan( +def test_paddle_round_( *, dtype_and_x, frontend, @@ -410,55 +152,20 @@ def test_paddle_atan( ) -# atan2 +# rsqrt_ @handle_frontend_test( - fn_tree="paddle.atan2", + fn_tree="paddle.tensor.math.rsqrt_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_atan2( +def test_paddle_rsqrt_( *, dtype_and_x, - on_device, - fn_tree, frontend, - backend_fw, test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# atanh -@handle_frontend_test( - fn_tree="paddle.tensor.math.atanh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_atanh( - *, - dtype_and_x, - on_device, fn_tree, - frontend, - test_flags, + on_device, backend_fw, ): input_dtype, x = dtype_and_x @@ -473,19 +180,19 @@ def test_paddle_atanh( ) -# ceil +# sqrt_ @handle_frontend_test( - fn_tree="paddle.tensor.math.ceil", + fn_tree="paddle.tensor.math.sqrt_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_ceil( +def test_paddle_sqrt_( *, dtype_and_x, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, on_device, ): @@ -499,1964 +206,3 @@ def test_paddle_ceil( on_device=on_device, x=x[0], ) - - -# ceil_ -@handle_frontend_test( - fn_tree="paddle.tensor.math.ceil_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_ceil_( - *, - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - 
helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# conj -@handle_frontend_test( - fn_tree="paddle.tensor.math.conj", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ), -) -def test_paddle_conj( - *, - dtype_and_input, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_input - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# cos -@handle_frontend_test( - fn_tree="paddle.cos", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_cos( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# cosh -@handle_frontend_test( - fn_tree="paddle.tensor.math.cosh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_cosh( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - atol=1e-2, - x=x[0], - ) - - -# cumprod -@handle_frontend_test( - fn_tree="paddle.tensor.math.cumprod", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, - force_int_axis=True, - min_num_dims=1, - min_value=-5, - max_value=5, - ), -) -def test_paddle_cumprod( - *, - dtype_x_axis, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x, axis = dtype_x_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - dim=axis, - ) - - -# deg2rad -@handle_frontend_test( - fn_tree="paddle.deg2rad", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_deg2rad( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# diff -@handle_frontend_test( - fn_tree="paddle.tensor.math.diff", - dtype_n_x_n_axis=helpers.dtype_values_axis( - available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - ), - n=st.integers(min_value=1, max_value=1), - dtype_prepend=helpers.dtype_and_values( - available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), - min_num_dims=1, - max_num_dims=1, - ), - dtype_append=helpers.dtype_and_values( - available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), - min_num_dims=1, - max_num_dims=1, - ), -) -def test_paddle_diff( - *, - dtype_n_x_n_axis, - n, - dtype_prepend, - dtype_append, - 
test_flags, - frontend, - backend_fw, - fn_tree, - on_device, -): - input_dtype, x, axis = dtype_n_x_n_axis - _, prepend = dtype_prepend - _, append = dtype_append - helpers.test_frontend_function( - input_dtypes=input_dtype, - test_flags=test_flags, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - n=n, - axis=axis, - prepend=prepend[0], - append=append[0], - ) - - -# digamma -@handle_frontend_test( - fn_tree="paddle.tensor.math.digamma", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - safety_factor_scale="log", - ), -) -def test_paddle_digamma( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - atol=1e-4, - x=x[0], - ) - - -# divide -@handle_frontend_test( - fn_tree="paddle.divide", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, - ), -) -def test_paddle_divide( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# erf -@handle_frontend_test( - fn_tree="paddle.tensor.math.erf", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_erf( - *, - dtype_and_x, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# exp -@handle_frontend_test( - fn_tree="paddle.exp", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_exp( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# exp_ -@handle_frontend_test( - fn_tree="paddle.tensor.math.exp_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_exp_( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# expm1 -@handle_frontend_test( - fn_tree="paddle.expm1", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_expm1( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - 
test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# floor -@handle_frontend_test( - fn_tree="paddle.tensor.math.floor", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_floor( - *, - dtype_and_x, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -@handle_frontend_test( - fn_tree="paddle.fmax", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True - ), -) -def test_paddle_fmax( - *, - dtypes_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtypes_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -@handle_frontend_test( - fn_tree="paddle.fmin", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True - ), -) -def test_paddle_fmin( - *, - dtypes_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtypes_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# frac -@handle_frontend_test( - fn_tree="paddle.tensor.math.frac", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - max_value=1e6, - min_value=-1e6, - ), -) -def test_paddle_frac( - *, - dtype_and_x, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# gcd -@handle_frontend_test( - fn_tree="paddle.gcd", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-100, - max_value=100, - min_num_dims=1, - min_dim_size=1, - num_arrays=2, - shared_dtype=True, - ), -) -def test_paddle_gcd( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# heaviside -@handle_frontend_test( - fn_tree="paddle.tensor.math.heaviside", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, - ), -) -def test_paddle_heaviside( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# inner -@handle_frontend_test( - 
fn_tree="paddle.tensor.math.inner", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-10, - max_value=10, - num_arrays=2, - shared_dtype=True, - ), -) -def test_paddle_inner( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# isfinite -@handle_frontend_test( - fn_tree="paddle.isfinite", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_isfinite( - *, - dtype_and_x, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# isinf -@handle_frontend_test( - fn_tree="paddle.isinf", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_isinf( - *, - dtype_and_x, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# isnan -@handle_frontend_test( - fn_tree="paddle.isnan", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_isnan( - *, - dtype_and_x, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# kron -@handle_frontend_test( - fn_tree="paddle.tensor.math.kron", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - shared_dtype=True, - ), -) -def test_paddle_kron( - *, - dtype_and_x, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# lcm -@handle_frontend_test( - fn_tree="paddle.lcm", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_num_dims=1, - safety_factor_scale="log", - large_abs_safety_factor=2, - shared_dtype=True, - ), -) -def test_paddle_lcm( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# lerp -@handle_frontend_test( - fn_tree="paddle.tensor.math.lerp", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, - ), -) -def 
test_paddle_lerp( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - weight=x[2], - ) - - -# lgamma -@handle_frontend_test( - fn_tree="paddle.tensor.math.lgamma", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - safety_factor_scale="log", - ), -) -def test_paddle_lgamma( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - atol=1e-4, - x=x[0], - ) - - -# log -@handle_frontend_test( - fn_tree="paddle.log", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_log( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# log1p -@handle_frontend_test( - fn_tree="paddle.log1p", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - max_value=1e5, - ), -) -def test_paddle_log1p( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# log2 -@handle_frontend_test( - fn_tree="paddle.log2", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_log2( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# logit -@handle_frontend_test( - fn_tree="paddle.logit", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_logit( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - eps=1e-2, - ) - - -# max -@handle_frontend_test( - fn_tree="paddle.tensor.math.max", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=False, - ), -) -def test_paddle_max( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x, axis = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - axis=axis, - keepdim=False, - ) - - -# maximum 
-@handle_frontend_test( - fn_tree="paddle.maximum", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - ), -) -def test_paddle_maximum( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# min -@handle_frontend_test( - fn_tree="paddle.tensor.math.min", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=False, - ), -) -def test_paddle_min( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x, axis = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - axis=axis, - keepdim=False, - ) - - -@handle_frontend_test( - fn_tree="paddle.minimum", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - ), -) -def test_paddle_minimum( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# mm -@handle_frontend_test( - fn_tree="paddle.tensor.math.mm", - dtype_xy=_get_dtype_input_and_matrices(), -) -def test_paddle_mm( - *, - dtype_xy, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, y = dtype_xy - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x, - mat2=y, - ) - - -# multiply -@handle_frontend_test( - fn_tree="paddle.multiply", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, - ), -) -def test_paddle_multiply( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# nansum -@handle_frontend_test( - fn_tree="paddle.tensor.math.nansum", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - valid_axis=True, - force_int_axis=True, - min_num_dims=1, - allow_nan=True, - ), -) -def test_paddle_nansum( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x, axis = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - axis=axis, - rtol=1e-04, - atol=1e-04, - ) - - -# neg -@handle_frontend_test( - fn_tree="paddle.neg", - dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=["float32", "float64", "int8", "int16", "int32", "int64"], - ), -) -def test_paddle_neg( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# outer -@handle_frontend_test( - fn_tree="paddle.tensor.math.outer", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_num_dims=1, - max_num_dims=1, - shared_dtype=True, - ), -) -def test_paddle_outer( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# pow -@handle_frontend_test( - fn_tree="paddle.pow", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - allow_inf=False, - shared_dtype=True, - ), -) -def test_paddle_pow( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# prod -@handle_frontend_test( - fn_tree="paddle.tensor.math.prod", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - min_value=-10, - max_value=10, - force_int_axis=False, - allow_nan=False, - ), -) -def test_paddle_prod( - *, - dtype_and_x, - on_device, - backend_fw, - fn_tree, - frontend, - test_flags, -): - input_dtype, x, axis = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - axis=axis, - keepdim=False, - backend_to_test=backend_fw, - ) - - -# rad2deg -@handle_frontend_test( - fn_tree="paddle.rad2deg", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_rad2deg( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# reciprocal -@handle_frontend_test( - fn_tree="paddle.reciprocal", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_reciprocal( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# reciprocal_ -@handle_frontend_test( - fn_tree="paddle.tensor.math.reciprocal_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_reciprocal_( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - 
input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# remainder -@handle_frontend_test( - fn_tree="paddle.remainder", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, - ), -) -def test_paddle_remainder( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# round -@handle_frontend_test( - fn_tree="paddle.tensor.math.round", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=1, - ), -) -def test_paddle_round( - *, - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# round_ -@handle_frontend_test( - fn_tree="paddle.tensor.math.round_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=1, - ), -) -def test_paddle_round_( - *, - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# rsqrt -@handle_frontend_test( - fn_tree="paddle.tensor.math.rsqrt", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_rsqrt( - *, - dtype_and_x, - frontend, - test_flags, - fn_tree, - on_device, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# rsqrt_ -@handle_frontend_test( - fn_tree="paddle.tensor.math.rsqrt_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_rsqrt_( - *, - dtype_and_x, - frontend, - test_flags, - fn_tree, - on_device, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -@handle_frontend_test( - fn_tree="paddle.tensor.math.sgn", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - max_dim_size=1, - abs_smallest_val=1e-10, - min_value=-10, - max_value=10, - ), -) -def test_paddle_sgn( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - 
on_device=on_device, - x=x[0], - ) - - -# sign -@handle_frontend_test( - fn_tree="paddle.tensor.math.sign", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_sign( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# sin -@handle_frontend_test( - fn_tree="paddle.sin", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_sin( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# floor_divide -@handle_frontend_test( - fn_tree="paddle.tensor.math.floor_divide", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-10, - max_value=10, - num_arrays=2, - allow_inf=False, - shared_dtype=True, - ), -) -def test_paddle_floor_divide( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - atol=1e-5, - ) - - -# diff -# sinh -@handle_frontend_test( - fn_tree="paddle.sinh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_sinh( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# sqrt -@handle_frontend_test( - fn_tree="paddle.tensor.math.sqrt", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_sqrt( - *, - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# sqrt_ -@handle_frontend_test( - fn_tree="paddle.tensor.math.sqrt_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_sqrt_( - *, - dtype_and_x, - fn_tree, - frontend, - test_flags, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# square -@handle_frontend_test( - fn_tree="paddle.tensor.math.square", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_square( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - 
input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# stanh -@handle_frontend_test( - fn_tree="paddle.tensor.math.stanh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), - scale_a=st.floats(1e-5, 1e5), - scale_b=st.floats(1e-5, 1e5), -) -def test_paddle_stanh( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, - scale_a, - scale_b, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - scale_a=scale_a, - scale_b=scale_b, - ) - - -# subtract -@handle_frontend_test( - fn_tree="paddle.subtract", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, - ), -) -def test_paddle_subtract( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# take -@handle_frontend_test( - fn_tree="paddle.take", dtype_and_values=_test_paddle_take_helper() -) -def test_paddle_take( - *, - dtype_and_values, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, -): - dtypes, xs, indices, modes = dtype_and_values - helpers.test_frontend_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=xs, - index=indices, - mode=modes, - ) - - -# tan -@handle_frontend_test( - fn_tree="paddle.tensor.math.tan", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_tan( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - atol=1e-2, - x=x[0], - ) - - -# tanh -@handle_frontend_test( - fn_tree="paddle.tensor.math.tanh", - aliases=["paddle.tanh", "paddle.nn.functional.tanh"], - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_tanh( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - atol=1e-2, - x=x[0], - ) - - -# trunc -@handle_frontend_test( - fn_tree="paddle.trunc", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", "int"), - ), -) -def test_paddle_trunc( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) diff --git 
a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_random.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_random.py index 93fdb86789953..8880032bf12b2 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_random.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_random.py @@ -2,274 +2,38 @@ from hypothesis import strategies as st # local - import ivy_tests.test_ivy.helpers as helpers from ivy_tests.test_ivy.helpers import handle_frontend_test @handle_frontend_test( - fn_tree="paddle.normal", - input_dtypes=st.sampled_from([["float32"], ["float64"]]), - shape=helpers.get_shape( - min_num_dims=1, - min_dim_size=1, - ), - mean=st.floats( - min_value=-10, - max_value=10, - ), - std=st.floats( - min_value=0, - max_value=10, - ), -) -def test_paddle_normal( - input_dtypes, - shape, - mean, - std, - frontend, - backend_fw, - test_flags, - fn_tree, -): - helpers.test_frontend_function( - input_dtypes=input_dtypes, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - test_values=False, - mean=mean, - std=std, - shape=shape, - ) - - -@handle_frontend_test( - fn_tree="paddle.poisson", + fn_tree="paddle.tensor.random.exponential_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, max_value=1000, min_num_dims=1, - max_num_dims=1, + max_num_dims=10, min_dim_size=2, - max_dim_size=2, - ), -) -def test_paddle_poisson(dtype_and_x, backend_fw, frontend, test_flags, fn_tree): - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - test_values=False, - x=x[0], - ) - - -@handle_frontend_test( - fn_tree="paddle.rand", - input_dtypes=st.sampled_from(["int32", "int64"]), - shape=helpers.get_shape( - allow_none=False, - min_num_dims=0, - min_dim_size=1, - ), - dtype=helpers.get_dtypes("valid", full=False), -) -def test_paddle_rand( - *, - input_dtypes, - shape, - dtype, - frontend, - backend_fw, - test_flags, - fn_tree, -): - helpers.test_frontend_function( - input_dtypes=[input_dtypes], - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - test_values=False, - shape=shape, - dtype=dtype[0], - ) - - -# randint -@handle_frontend_test( - fn_tree="paddle.randint", - low=helpers.ints(min_value=0, max_value=10), - high=helpers.ints(min_value=11, max_value=20), - dtype=helpers.get_dtypes("integer"), - shape=helpers.get_shape( - allow_none=False, min_num_dims=2, max_num_dims=7, min_dim_size=2 + max_dim_size=10, ), ) -def test_paddle_randint( - low, - high, - dtype, - backend_fw, - frontend, - test_flags, - shape, +def test_paddle_exponential_( fn_tree, -): - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_values=False, - fn_tree=fn_tree, - test_flags=test_flags, - low=low, - high=high, - shape=shape, - ) - - -@handle_frontend_test( - fn_tree="paddle.randint_like", - input_dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=helpers.get_shape( - allow_none=False, min_num_dims=2, max_num_dims=7, min_dim_size=2 - ), - ), - low=st.integers(min_value=0, max_value=10), - high=st.integers(min_value=11, max_value=20), - dtype=helpers.get_dtypes("integer"), -) -def test_paddle_randint_like( - input_dtype_and_x, - low, - high, - dtype, + dtype_and_x, frontend, backend_fw, test_flags, - fn_tree, - 
on_device, ): - input_dtype, x = input_dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, test_values=False, x=x[0], - low=low, - high=high, - dtype=dtype[0], - ) - - -@handle_frontend_test( - fn_tree="paddle.randn", - input_dtypes=st.sampled_from(["int32", "int64"]), - shape=helpers.get_shape( - allow_none=False, min_num_dims=1, max_num_dims=1, min_dim_size=2 - ), - dtype=st.sampled_from(["float32", "float64"]), -) -def test_paddle_randn( - *, - input_dtypes, - shape, - dtype, - frontend, - backend_fw, - test_flags, - fn_tree, -): - helpers.test_frontend_function( - input_dtypes=[input_dtypes], - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - test_values=False, - shape=shape, - dtype=dtype, - ) - - -@handle_frontend_test( - fn_tree="paddle.standard_normal", - input_dtypes=st.sampled_from([["int32"], ["int64"]]), - shape=helpers.get_shape( - min_num_dims=1, - min_dim_size=1, - ), - dtype=helpers.get_dtypes("valid", full=False), -) -def test_paddle_standard_normal( - input_dtypes, - shape, - dtype, - frontend, - backend_fw, - test_flags, - fn_tree, -): - helpers.test_frontend_function( - input_dtypes=input_dtypes, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - test_values=False, - shape=shape, - dtype=dtype[0], - ) - - -@handle_frontend_test( - fn_tree="paddle.uniform", - input_dtypes=helpers.get_dtypes("float"), - shape=st.tuples( - st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) - ), - dtype=helpers.get_dtypes("valid", full=False), - min=st.floats(allow_nan=False, allow_infinity=False, width=32), - max=st.floats(allow_nan=False, allow_infinity=False, width=32), - seed=st.integers(min_value=2, max_value=5), -) -def test_paddle_uniform( - input_dtypes, - shape, - dtype, - min, - max, - seed, - frontend, - backend_fw, - test_flags, - fn_tree, -): - helpers.test_frontend_function( - input_dtypes=input_dtypes, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - test_values=False, - shape=shape, - dtype=dtype[0], - min=min, - max=max, - seed=seed, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py index 56f4cbe2b990f..7ceeba34f4bff 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py @@ -127,6 +127,30 @@ def _get_dtype_and_square_matrix(draw): return dtype, mat +@st.composite +def _get_dtype_and_values_for_lerp(draw): + is_tensor = draw(st.booleans()) + if is_tensor: + input_dtype, x = draw( + helpers.dtype_and_values( + num_arrays=3, + available_dtypes=helpers.get_dtypes("valid"), + shared_dtype=True, + ) + ) + return input_dtype, x[0], x[1], x[2] + else: + input_dtype, x = draw( + helpers.dtype_and_values( + num_arrays=2, + available_dtypes=helpers.get_dtypes("valid"), + shared_dtype=True, + ) + ) + weight = draw(st.floats()) + return input_dtype, x[0], x[1], weight + + @st.composite def _reshape_helper(draw): # generate a shape s.t len(shape) > 0 @@ -2154,6 +2178,40 @@ def test_paddle_tensor_isnan( ) +# lerp +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="lerp", + 
dtypes_and_x=_get_dtype_and_values_for_lerp(), +) +def test_paddle_tensor_lerp( + dtypes_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x, y, weight = dtypes_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "y": y, + "weight": weight, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # lerp_ @handle_frontend_method( class_tree=CLASS_TREE, @@ -2629,6 +2687,42 @@ def test_paddle_tensor_numel( ) +# pow +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="pow", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + allow_inf=False, + shared_dtype=True, + ), +) +def test_paddle_tensor_pow( + dtypes_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtypes_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"y": x[1]}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # rad2deg @handle_frontend_method( class_tree=CLASS_TREE, @@ -3199,6 +3293,48 @@ def test_paddle_tensor_squeeze_( ) +# stanh +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="stanh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + scale_a=st.floats(1e-5, 1e5), + scale_b=st.floats(1e-5, 1e5), +) +def test_paddle_tensor_stanh( + dtype_and_x, + frontend_method_data, + scale_a, + scale_b, + init_flags, + method_flags, + frontend, + backend_fw, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + backend_to_test=backend_fw, + method_all_as_kwargs_np={ + "scale_a": scale_a, + "scale_b": scale_b, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # subtract @handle_frontend_method( class_tree=CLASS_TREE, @@ -3300,6 +3436,56 @@ def test_paddle_tensor_tanh( ) +# trace +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="trace", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + min_num_dims=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, + ), + offset=st.integers(min_value=-1e04, max_value=1e04), + axis1=st.integers(min_value=0, max_value=0), + axis2=st.integers(min_value=1, max_value=1), +) +def test_paddle_tensor_trace( + dtype_and_x, + offset, + axis1, + axis2, + frontend, + backend_fw, + frontend_method_data, + init_flags, + method_flags, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + init_all_as_kwargs_np={ + "value": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "offset": offset, + "axis1": axis1, + "axis2": axis2, + }, + 
frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + backend_to_test=backend_fw, + on_device=on_device, + ) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", diff --git a/ivy_tests/test_ivy/test_frontends/test_scipy/test_linalg/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_scipy/test_linalg/test_linalg.py index bb2f61a1c45b0..d73718550615a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_scipy/test_linalg/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_scipy/test_linalg/test_linalg.py @@ -362,31 +362,36 @@ def test_scipy_svd( # svdvals @handle_frontend_test( fn_tree="scipy.linalg.svdvals", - dtype_x=helpers.dtype_and_values( + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=0, + min_value=0.1, max_value=50, - min_num_dims=2, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), ), + check_finite=st.booleans(), test_with_out=st.just(False), ) def test_scipy_svdvals( - dtype_x, + dtype_and_x, + check_finite, frontend, test_flags, fn_tree, - on_device, backend_fw, + on_device, ): - dtype, x = dtype_x + dtype, x = dtype_and_x + x = x[0] helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, + test_values=False, fn_tree=fn_tree, on_device=on_device, - a=x[0], + a=x, + check_finite=check_finite, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py index 7977c77df3d03..d94f489a169b0 100644 --- a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py @@ -43,3 +43,44 @@ def test_sklearn_accuracy_score( normalize=normalize, sample_weight=None, ) + + +@handle_frontend_test( + fn_tree="sklearn.metrics.precision_score", + arrays_and_dtypes=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_integer"), + num_arrays=2, + min_value=0, + max_value=1, # Precision score typically works with binary classification + shared_dtype=True, + shape=(helpers.ints(min_value=2, max_value=5)), + ), + average=st.sampled_from(["micro", "macro", "weighted"]), +) +def test_sklearn_precision_score( + arrays_and_dtypes, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, + average, +): + dtypes, values = arrays_and_dtypes + + # Ensure binary classification labels (0 and 1) + for i in range(2): + values[i] = np.round(values[i]) + + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + frontend=frontend, + on_device=on_device, + y_true=values[0], + y_pred=values[1], + average=average, + sample_weight=None, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py index 9e114650dcc4f..b703153a9896f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py @@ -1151,6 +1151,51 @@ def test_tensorflow_linspace( ) +# meshgrid +@handle_frontend_test( + fn_tree="tensorflow.meshgrid", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + max_num_dims=2, + 
min_num_dims=2, + min_dim_size=2, + max_dim_size=5, + ), + indexing=st.sampled_from(["xy", "ij"]), + test_with_out=st.just(False), +) +def test_tensorflow_meshgrid( + *, + dtype_and_values, + indexing, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + dtype, arrays = dtype_and_values + arrays = arrays[0] + kwargs = {} + + for i, array in enumerate(arrays): + kwargs[f"a{i}"] = array + + kwargs["indexing"] = indexing + + test_flags.num_positional_args = len(arrays) + test_flags.generate_frontend_arrays = False + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + **kwargs, + ) + + # no_op @handle_frontend_test( fn_tree="tensorflow.no_op", diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index db0e0a9de988f..48935cb1b6021 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -761,6 +761,36 @@ def test_tensorflow_cumsum( # NOQA ) +# digamma +@handle_frontend_test( + fn_tree="tensorflow.math.digamma", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + ), + test_with_out=st.just(False), +) +def test_tensorflow_digamma( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + # divide @handle_frontend_test( fn_tree="tensorflow.math.divide", diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py index 0ef72b5d8634c..e734dfa5406f7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py @@ -236,7 +236,7 @@ def _pow_helper_shared_dtype(draw): dtype1, dtype2 = dtype x1, x2 = x if "int" in dtype2: - x2 = ivy.nested_map(x2, lambda x: abs(x), include_derived={list: True}) + x2 = ivy.nested_map(x2, lambda x: abs(x), include_derived={"list": True}) if ivy.is_int_dtype(dtype2): max_val = ivy.iinfo(dtype2).max diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py index fcbd20246b6b3..ebcbee64b678d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py @@ -41,6 +41,39 @@ def _as_strided_helper(draw): return x_dtype, x, size, stride, offset +@st.composite +def _as_tensor_helper(draw): + dtype_and_x = draw( + st.one_of( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + st.floats(), + st.integers(), + st.lists(st.one_of(st.floats(), st.integers()), min_size=1), + ) + ) + if isinstance(dtype_and_x, tuple): + input_dtype = dtype_and_x[0] + x = dtype_and_x[1][0] + else: + input_dtype = [] + x = dtype_and_x + dtype = draw( + st.one_of( + helpers.get_castable_dtype( + draw(helpers.get_dtypes("valid")), + dtype=draw(helpers.get_dtypes("valid", full=False))[0], + x=x, + ), + st.none(), + ) + ) + if isinstance(dtype, tuple): + dtype = dtype[0] + return input_dtype, x, dtype + + # Helper 
functions @@ -188,31 +221,39 @@ def test_torch_as_strided( # as_tensor @handle_frontend_test( fn_tree="torch.as_tensor", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - dtype=helpers.get_dtypes("valid", full=False), + dtype_x_dtype=_as_tensor_helper(), ) def test_torch_as_tensor( *, - dtype_and_x, - dtype, + dtype_x_dtype, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - data=input[0], - dtype=dtype[0], - device=on_device, - ) + input_dtype, x, dtype = dtype_x_dtype + # ToDo: fix get_castable_dtype to avoid the exceptions + try: + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + data=x, + dtype=dtype, + device=on_device, + ) + except Exception as e: + if any( + error_string in str(e) + for error_string in ["overflow", "too large to convert to"] + ): + assume(False) + else: + raise # asarray diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 577967c5f3e1d..7d11bfa7b5531 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -60,6 +60,39 @@ def _generate_multi_dot_dtype_and_arrays(draw): return input_dtype, [matrix_1[1][0], matrix_2[1][0], matrix_3[1][0]] +@st.composite +def _get_axis_and_p(draw): + p = draw(st.sampled_from(["fro", "nuc", 1, 2, -1, -2, float("inf"), -float("inf")])) + if p == "fro" or p == "nuc": + max_axes_size = 2 + min_axes_size = 2 + else: + min_axes_size = 1 + max_axes_size = 5 + x_dtype, values, axis = draw( + helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + valid_axis=True, + min_value=-1e04, + max_value=1e04, + min_axes_size=min_axes_size, + max_axes_size=max_axes_size, + large_abs_safety_factor=2, + safety_factor_scale="log", + ) + ) + axis = axis[0] if isinstance(axis, tuple) and len(axis) == 1 else axis + # ToDo: fix the castable dtype helper. 
Right now using `dtype` causes errors + # dtype should be real for real inputs, but got ComplexDouble + x_dtype, values, dtype = draw( + helpers.get_castable_dtype( + draw(helpers.get_dtypes("valid")), x_dtype[0], values[0] + ) + ) + return p, x_dtype, values, axis, x_dtype + + # helpers @st.composite def _get_dtype_and_matrix( @@ -483,6 +516,50 @@ def test_torch_eigh( ) +@handle_frontend_test( + fn_tree="torch.linalg.eigh", + dtype_and_x=_get_dtype_and_matrix(dtype="valid", square=True, invertible=True), + UPLO=st.sampled_from(("L", "U")), +) +def test_torch_eigh( + *, + dtype_and_x, + UPLO, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, x = dtype_and_x + x = np.array(x[0], dtype=dtype[0]) + # make symmetric positive-definite beforehand + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + a=x, + UPLO=UPLO, + ) + ret = [ivy.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] + + L, Q = ret + frontend_L, frontend_Q = frontend_ret + + assert_all_close( + ret_np=Q @ np.diag(L) @ Q.T, + ret_from_gt_np=frontend_Q @ np.diag(frontend_L) @ frontend_Q.T, + atol=1e-02, + ) + + # eigvals @handle_frontend_test( fn_tree="torch.linalg.eigvals", @@ -868,6 +945,40 @@ def test_torch_multi_dot( ) +@handle_frontend_test( + fn_tree="torch.linalg.norm", + args=_get_axis_and_p(), + keepdim=st.booleans(), + test_with_out=st.just(False), +) +def test_torch_norm( + *, + args, + keepdim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + p, x_dtype, x, axis, dtype = args + helpers.test_frontend_function( + input_dtypes=[x_dtype], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-01, + atol=1e-08, + input=x, + ord=p, + dim=axis, + keepdim=keepdim, + dtype=dtype, + ) + + # pinv # TODO: add testing for hermitian @handle_frontend_test( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_vision_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_vision_functions.py index 5dbf4104d5d95..9c34752b567e5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_vision_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_vision_functions.py @@ -112,6 +112,63 @@ def _pad_helper(draw): return dtype, input[0], padding, value, mode +@st.composite +def grid_sample_helper(draw, dtype, mode, mode_3d, padding_mode): + dtype = draw(dtype) + align_corners = draw(st.booleans()) + dims = draw(st.integers(4, 5)) + height = draw(helpers.ints(min_value=5, max_value=10)) + width = draw(helpers.ints(min_value=5, max_value=10)) + channels = draw(helpers.ints(min_value=1, max_value=3)) + + grid_h = draw(helpers.ints(min_value=2, max_value=4)) + grid_w = draw(helpers.ints(min_value=2, max_value=4)) + batch = draw(helpers.ints(min_value=1, max_value=5)) + + padding_mode = draw(st.sampled_from(padding_mode)) + if dims == 4: + mode = draw(st.sampled_from(mode)) + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=[batch, channels, height, width], + min_value=-1, + max_value=1, + ) + ) + + grid = draw( + helpers.array_values( + dtype=dtype[0], + shape=[batch, grid_h, grid_w, 2], + min_value=-1, + max_value=1, + ) + ) + elif dims 
== 5: + mode = draw(st.sampled_from(mode_3d)) + depth = draw(helpers.ints(min_value=10, max_value=15)) + grid_d = draw(helpers.ints(min_value=5, max_value=10)) + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=[batch, channels, depth, height, width], + min_value=-1, + max_value=1, + ) + ) + + grid = draw( + helpers.array_values( + dtype=dtype[0], + shape=[batch, grid_d, grid_h, grid_w, 3], + min_value=-1, + max_value=1, + ) + ) + return dtype, x, grid, mode, padding_mode, align_corners + + # --- Main --- # # ------------ # @@ -144,6 +201,40 @@ def test_torch_affine_grid( ) +@handle_frontend_test( + fn_tree="torch.nn.functional.grid_sample", + dtype_x_grid_modes=grid_sample_helper( + dtype=helpers.get_dtypes("valid", full=False), + mode=["nearest", "bilinear", "bicubic"], + mode_3d=["nearest", "bilinear"], + padding_mode=["border", "zeros", "reflection"], + ), +) +def test_torch_grid_sample( + *, + dtype_x_grid_modes, + on_device, + backend_fw, + fn_tree, + frontend, + test_flags, +): + dtype, x, grid, mode, padding_mode, align_corners = dtype_x_grid_modes + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x, + grid=grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + + @handle_frontend_test( fn_tree="torch.nn.functional.interpolate", dtype_and_input_and_other=_interp_args( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py index a502749812473..0577ad06d862b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py @@ -2153,6 +2153,30 @@ def test_torch_mul( ) +# mvlgamma +@handle_frontend_test( + fn_tree="torch.mvlgamma", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float") + ), + p=helpers.ints(min_value=1, max_value=11), +) +def test_torch_mvlgamma( + *, dtype_and_input, frontend, test_flags, fn_tree, backend_fw, on_device, p +): + input_dtype, input = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + p=p, + ) + + @handle_frontend_test( fn_tree="torch.nan_to_num", dtype_and_x=helpers.dtype_and_values( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index bbca3cff7ec10..cdd54b5cf2ccd 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -7851,6 +7851,51 @@ def test_torch_tensor_lcm( ) +# lcm_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="lcm_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + shared_dtype=True, + ), +) +def test_torch_tensor_lcm_( + dtype_and_x, + frontend, + frontend_method_data, + init_flags, + method_flags, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + 
method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, + ) + + # less @handle_frontend_method( class_tree=CLASS_TREE, @@ -12774,6 +12819,48 @@ def test_torch_tensor_zero_( ) +# triu +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="triu", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ), + diagonal=st.integers( + min_value=-4, + max_value=4, + ), +) +def test_torch_triu( + dtype_x, + diagonal, + frontend, + frontend_method_data, + init_flags, + method_flags, + on_device, + backend_fw, +): + input_dtype, x = dtype_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"diagonal": diagonal}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # triu_ @handle_frontend_method( class_tree=CLASS_TREE, diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_device.py b/ivy_tests/test_ivy/test_functional/test_core/test_device.py index b4bf40ba7fe16..ec6c8b6a84351 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_device.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_device.py @@ -691,7 +691,7 @@ def test_to_device( # check if native arrays are the same # these backends do not support native inplace updates - assume(not (backend_fw in ["tensorflow", "jax"])) + assume(backend_fw not in ["tensorflow", "jax"]) assert x_on_dev.data is out.data diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py index 22cc88c61f6be..5fd064506ed80 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py @@ -126,21 +126,26 @@ def cast_filter(dtype1_x1_dtype2): max_value = int(math.log(max_val) / math.log(max_x1)) if abs(max_value) > abs(max_val) / 40 or max_value < 0: max_value = None - dtype2, x2 = draw( - helpers.dtype_and_values( - small_abs_safety_factor=16, - large_abs_safety_factor=16, - safety_factor_scale="log", - max_value=max_value, - dtype=[dtype2], + dtype_and_x2 = draw( + st.one_of( + helpers.dtype_and_values( + small_abs_safety_factor=16, + large_abs_safety_factor=16, + safety_factor_scale="log", + max_value=max_value, + dtype=[dtype2], + ), + st.floats(max_value=max_value), + st.integers(max_value=max_value), ) ) - dtype2 = dtype2[0] - if "int" in dtype2: - x2 = ivy.nested_map( - x2[0], lambda x: abs(x), include_derived={list: True}, shallow=False - ) - return [dtype1, dtype2], [x1, x2] + input_dtypes = [dtype1] + if isinstance(dtype_and_x2, tuple): + input_dtypes += dtype_and_x2[0] + x2 = dtype_and_x2[1][0] + else: + x2 = dtype_and_x2 + return input_dtypes, [x1[0], x2] # --- Main --- # @@ -148,8 +153,7 @@ def cast_filter(dtype1_x1_dtype2): def not_too_close_to_zero(x): - f = np.vectorize(lambda item: item + (_one if np.isclose(item, 0) else _zero)) - return f(x) + return np.where(np.isclose(x, 0), x + 1, x) # abs @@ -936,7 +940,7 @@ def test_gcd(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): def test_greater(*, 
dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x # bfloat16 is not supported - assume(not ("bfloat16" in input_dtype)) + assume("bfloat16" not in input_dtype) helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, @@ -959,7 +963,7 @@ def test_greater(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): def test_greater_equal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) + assume("bfloat16" not in input_dtype) # make sure they're not too close together assume(not (np.any(np.isclose(x[0], x[1])) or np.any(np.isclose(x[1], x[0])))) helpers.test_function( @@ -1142,7 +1146,7 @@ def test_lcm(dtype_and_x, test_flags, backend_fw, fn_name, on_device): def test_less(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) + assume("bfloat16" not in input_dtype) # make sure they're not too close together assume(not (np.any(np.isclose(x[0], x[1])) or np.any(np.isclose(x[1], x[0])))) helpers.test_function( @@ -1167,7 +1171,7 @@ def test_less(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): def test_less_equal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) + assume("bfloat16" not in input_dtype) # make sure they're not too close together assume(not (np.any(np.isclose(x[0], x[1])) or np.any(np.isclose(x[1], x[0])))) helpers.test_function( @@ -1590,22 +1594,12 @@ def test_positive(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.pow", dtype_and_x=pow_helper(), + test_gradients=st.just(False), ) def test_pow(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x - # bfloat16 is not supported by numpy assume(not ("bfloat16" in input_dtype)) - - # Make sure x2 isn't a float when x1 is integer - assume( - not (ivy.is_int_dtype(input_dtype[0] and ivy.is_float_dtype(input_dtype[1]))) - ) - - # Make sure x2 is non-negative when both is integer - if ivy.is_int_dtype(input_dtype[1]) and ivy.is_int_dtype(input_dtype[0]): - x[1] = np.abs(x[1]) - x[0] = not_too_close_to_zero(x[0]) x[1] = not_too_close_to_zero(x[1]) helpers.test_function( diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_general.py b/ivy_tests/test_ivy/test_functional/test_core/test_general.py index 7f22a60bf9d39..65a1114e4c653 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_general.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_general.py @@ -1621,12 +1621,6 @@ def test_scatter_nd(x, reduction, test_flags, backend_fw, fn_name, on_device): # ------# -@given(fw_str=st.sampled_from(["numpy", "jax", "torch", "tensorflow"])) -def test_set_framework(fw_str): - ivy.set_backend(fw_str) - ivy.previous_backend() - - @pytest.mark.parametrize("mode", ["lenient", "strict"]) def test_set_inplace_mode(mode): ivy.set_inplace_mode(mode) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py index c3d5686656b19..a982a667e5c85 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py @@ 
-308,6 +308,46 @@ def test_kaiser_window( ) +# mel_weight_matrix +@handle_test( + fn_tree="functional.ivy.experimental.mel_weight_matrix", + num_mel_bins=helpers.ints(min_value=5, max_value=10), + dft_length=helpers.ints(min_value=5, max_value=10), + sample_rate=helpers.ints(min_value=1000, max_value=2000), + lower_edge_hertz=helpers.floats(min_value=0.0, max_value=5.0), + upper_edge_hertz=helpers.floats(min_value=5.0, max_value=10.0), + test_with_out=st.just(False), + test_gradients=st.just(False), + test_instance_method=st.just(False), +) +def test_mel_weight_matrix( + *, + num_mel_bins, + dft_length, + sample_rate, + lower_edge_hertz, + upper_edge_hertz, + test_flags, + backend_fw, + fn_name, + on_device, +): + helpers.test_function( + input_dtypes=[], + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + rtol_=0.05, + atol_=0.05, + fn_name=fn_name, + num_mel_bins=num_mel_bins, + dft_length=dft_length, + sample_rate=sample_rate, + lower_edge_hertz=lower_edge_hertz, + upper_edge_hertz=upper_edge_hertz, + ) + + # ndenumerate @handle_test( fn_tree="functional.ivy.experimental.ndenumerate", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py index 80e0317124486..b1b01c92b0a53 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py @@ -233,6 +233,25 @@ def _generate_eigh_tridiagonal_args(draw): return dtype, alpha, beta, eigvals_only, select, select_range, tol +@st.composite +def _generate_general_inner_product_args(draw): + dim = draw(st.integers(min_value=1, max_value=3)) + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=(dim, dim), + min_value=1, + max_value=10.0, + num_arrays=2, + shared_dtype=True, + allow_nan=False, + ) + ) + max_value = dim - 1 if dim > 1 else dim + n_modes = draw(st.integers(min_value=1, max_value=max_value) | st.just(None)) + return x_dtype, x, n_modes + + # multi_dot @st.composite def _generate_multi_dot_dtype_and_arrays(draw): @@ -995,6 +1014,26 @@ def test_eigvals(dtype_x, test_flags, backend_fw, fn_name): ) +@handle_test( + fn_tree="functional.ivy.experimental.general_inner_product", + data=_generate_general_inner_product_args(), +) +def test_general_inner_product(*, data, test_flags, backend_fw, fn_name, on_device): + input_dtypes, x, n_modes = data + helpers.test_function( + backend_to_test=backend_fw, + test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-1, + atol_=1e-1, + input_dtypes=input_dtypes, + a=x[0], + b=x[1], + n_modes=n_modes, + ) + + @handle_test( fn_tree="functional.ivy.experimental.initialize_tucker", data=_initialize_tucker_data(), diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py index 234c75451f748..2b06ac879c412 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py @@ -6,7 +6,7 @@ # local import ivy import ivy_tests.test_ivy.helpers as helpers -from ivy_tests.test_ivy.helpers import handle_test +from ivy_tests.test_ivy.helpers import handle_test, create_concatenable_arrays_dtypes from ivy.functional.ivy.experimental.manipulation import 
_check_bounds from ivy_tests.test_ivy.test_functional.test_core.test_manipulation import _get_splits @@ -375,6 +375,60 @@ def _soft_thresholding_data(draw): return x_dtype + t_dtype, x, threshold +@st.composite +def _st_col_row_stack_arrays(draw, stack_dim): + ndim = draw(st.integers(min_value=2, max_value=5)) + dtype = draw(st.sampled_from(draw(helpers.get_dtypes("valid")))) + arrays, dtypes = draw( + create_concatenable_arrays_dtypes( + min_num_dims=ndim, + max_num_dims=ndim, + min_num_arrays=1, + max_num_arrays=3, + concat_dim=stack_dim, + dtypes=[dtype], + ) + ) + if ndim == 2: + non_stack_dim_len = arrays[0].shape[1 - stack_dim] + add_1D = draw(st.booleans()) + if add_1D: + arrays_1D, dtypes_1D = draw( + create_concatenable_arrays_dtypes( + min_num_dims=None, + max_num_dims=None, + min_num_arrays=1, + max_num_arrays=2, + concat_dim=None, + dtypes=[dtype], + common_shape=[non_stack_dim_len], + ) + ) + arrays += arrays_1D + dtypes += dtypes_1D + + if non_stack_dim_len == 1: + add_0D = draw(st.booleans()) + if add_0D: + arrays_0D, dtypes_0D = draw( + create_concatenable_arrays_dtypes( + min_num_dims=0, + max_num_dims=0, + min_num_arrays=1, + max_num_arrays=2, + concat_dim=None, + dtypes=[dtype], + ) + ) + arrays += arrays_0D + dtypes += dtypes_0D + + arrays_dtypes = draw(st.permutations(list(zip(arrays, dtypes)))) + arrays, dtypes = list(zip(*arrays_dtypes)) + + return list(arrays), list(dtypes) + + def _st_tuples_or_int(n_pairs, min_val=0): return st.one_of( st_tuples( @@ -552,6 +606,24 @@ def test_broadcast_shapes(*, shapes, test_flags, backend_fw, fn_name, on_device) ) +# column_stack +@handle_test( + fn_tree="functional.ivy.experimental.column_stack", + arrays_dtypes=_st_col_row_stack_arrays(stack_dim=1), + test_gradients=st.just(False), +) +def test_column_stack(*, arrays_dtypes, test_flags, backend_fw, fn_name, on_device): + arrays, dtypes = arrays_dtypes + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + arrays=arrays, + ) + + # concat_from_sequence @handle_test( fn_tree="functional.ivy.experimental.concat_from_sequence", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py index 1b3f7b173043e..716ff0ae7353c 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py @@ -15,6 +15,30 @@ # --------------- # +@st.composite +def _get_castable_float_dtype_nan(draw, min_value=None, max_value=None): + available_dtypes = helpers.get_dtypes("float") + shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=4, max_dim_size=6)) + dtype, values = draw( + helpers.dtype_and_values( + available_dtypes=available_dtypes, + num_arrays=1, + large_abs_safety_factor=6, + small_abs_safety_factor=24, + safety_factor_scale="log", + shape=shape, + min_value=min_value, + max_value=max_value, + allow_nan=True, + ) + ) + axis = draw(helpers.get_axis(shape=shape, force_int=True)) + dtype1, values, dtype2 = draw( + helpers.get_castable_dtype(draw(available_dtypes), dtype[0], values[0]) + ) + return dtype1, [values], axis, dtype2 + + @st.composite def _get_dtype_value1_value2_cov( draw, @@ -622,6 +646,41 @@ def test_nanmedian( ) +@handle_test( + fn_tree="functional.ivy.experimental.nanprod", + dtype_x_axis_castable=_get_castable_float_dtype_nan(), + 
keep_dims=st.booleans(), + test_gradients=st.just(False), + initial=st.integers(min_value=-5, max_value=5), +) +def test_nanprod( + *, + dtype_x_axis_castable, + keep_dims, + test_flags, + initial, + backend_fw, + fn_name, + on_device, +): + input_dtype, x, axis, castable_dtype = dtype_x_axis_castable + x = x[0] + helpers.test_function( + input_dtypes=[input_dtype], + test_flags=test_flags, + rtol_=1e-1, + atol_=1e-1, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + a=x, + axis=axis, + keepdims=keep_dims, + dtype=castable_dtype, + initial=initial, + ) + + # quantile @handle_test( fn_tree="functional.ivy.experimental.quantile", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index f940f291a7714..7c213994ed770 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -44,7 +44,7 @@ def test_elu( @handle_test( fn_tree="functional.ivy.experimental.logit", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", @@ -89,7 +89,7 @@ def test_logsigmoid(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="prelu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), shape=st.shared(helpers.get_shape(), key="prelu"), large_abs_safety_factor=8, small_abs_safety_factor=8, @@ -117,7 +117,7 @@ def test_prelu(*, dtype_and_x, slope, test_flags, backend_fw, fn_name, on_device @handle_test( fn_tree="functional.ivy.experimental.relu6", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("valid"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="log", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py index 4c963d7bdd711..2e131e2c7ff48 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py @@ -1190,39 +1190,6 @@ def test_max_pool3d( ) -@handle_test( - fn_tree="functional.ivy.experimental.layers.max_unpool1d", - x_k_s_p=helpers.arrays_for_pooling(min_dims=3, max_dims=3, min_side=1, max_side=4), - indices=st.lists(st.integers(0, 1), min_size=1, max_size=4), - ground_truth_backend="jax", - test_gradients=st.just(False), -) -def test_max_unpool1d( - *, - x_k_s_p, - indices, - test_flags, - backend_fw, - fn_name, - on_device, -): - dtype, x, kernel, stride, pad = x_k_s_p - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - on_device=on_device, - fn_name=fn_name, - rtol_=1e-2, - atol_=1e-2, - x=x[0], - kernel=kernel, - strides=stride, - padding=pad, - indices=indices, - ) - - @handle_test( fn_tree="functional.ivy.experimental.reduce_window", all_args=_reduce_window_helper(_get_reduce_func), diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py index c41b3f0d2d63c..cbddd092c511e 
100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py @@ -55,6 +55,55 @@ def test_huber_loss( ) +# kl_div +@handle_test( + fn_tree="functional.ivy.experimental.kl_div", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=1e-04, + max_value=1, + allow_inf=False, + min_num_dims=1, + max_num_dims=3, + min_dim_size=3, + ), + dtype_and_target=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=1e-04, + max_value=1, + allow_inf=False, + min_num_dims=1, + max_num_dims=3, + min_dim_size=3, + ), + reduction=st.sampled_from(["none", "sum", "batchmean", "mean"]), + test_with_out=st.just(False), +) +def test_kl_div( + dtype_and_input, + dtype_and_target, + reduction, + test_flags, + backend_fw, + fn_name, + on_device, +): + input_dtype, input = dtype_and_input + target_dtype, target = dtype_and_target + + helpers.test_function( + input_dtypes=input_dtype + target_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + atol_=1e-02, + input=input[0], + target=target[0], + reduction=reduction, + ) + + @handle_test( fn_tree="functional.ivy.experimental.l1_loss", dtype_input=helpers.dtype_and_values( diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py index 2399daf623170..ac23f384b6398 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py @@ -208,10 +208,10 @@ def test_sigmoid(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.softmax", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), min_num_dims=1, large_abs_safety_factor=8, - small_abs_safety_factor=8, + small_abs_safety_factor=4, safety_factor_scale="log", ), axis=st.one_of( diff --git a/ivy_tests/test_ivy/test_misc/test_array.py b/ivy_tests/test_ivy/test_misc/test_array.py index 93b15c77a456e..853bb85708e63 100644 --- a/ivy_tests/test_ivy/test_misc/test_array.py +++ b/ivy_tests/test_ivy/test_misc/test_array.py @@ -873,7 +873,7 @@ def test_array__ipow__( input_dtype, x = dtype_and_x # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) + assume("bfloat16" not in input_dtype) # Make sure x2 isn't a float when x1 is integer assume( @@ -1502,7 +1502,7 @@ def test_array__pow__( input_dtype, x = dtype_and_x # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) + assume("bfloat16" not in input_dtype) # Make sure x2 isn't a float when x1 is integer assume( @@ -1881,7 +1881,7 @@ def test_array__rpow__( input_dtype, x = dtype_and_x # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) + assume("bfloat16" not in input_dtype) # Make sure x2 isn't a float when x1 is integer assume( diff --git a/ivy_tests/test_ivy/test_misc/test_shape.py b/ivy_tests/test_ivy/test_misc/test_shape.py index 266292f7e2500..ec84ce3ac3bfd 100644 --- a/ivy_tests/test_ivy/test_misc/test_shape.py +++ b/ivy_tests/test_ivy/test_misc/test_shape.py @@ -545,7 +545,7 @@ def test_shape__pow__( input_dtype, x = dtype_and_x # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) + assume("bfloat16" not in input_dtype) # Make sure x2 
diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
index 2399daf623170..ac23f384b6398 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
@@ -208,10 +208,10 @@ def test_sigmoid(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.softmax",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         min_num_dims=1,
         large_abs_safety_factor=8,
-        small_abs_safety_factor=8,
+        small_abs_safety_factor=4,
         safety_factor_scale="log",
     ),
     axis=st.one_of(
diff --git a/ivy_tests/test_ivy/test_misc/test_array.py b/ivy_tests/test_ivy/test_misc/test_array.py
index 93b15c77a456e..853bb85708e63 100644
--- a/ivy_tests/test_ivy/test_misc/test_array.py
+++ b/ivy_tests/test_ivy/test_misc/test_array.py
@@ -873,7 +873,7 @@ def test_array__ipow__(
     input_dtype, x = dtype_and_x
     # bfloat16 is not supported by numpy
-    assume(not ("bfloat16" in input_dtype))
+    assume("bfloat16" not in input_dtype)
     # Make sure x2 isn't a float when x1 is integer
     assume(
@@ -1502,7 +1502,7 @@ def test_array__pow__(
     input_dtype, x = dtype_and_x
     # bfloat16 is not supported by numpy
-    assume(not ("bfloat16" in input_dtype))
+    assume("bfloat16" not in input_dtype)
     # Make sure x2 isn't a float when x1 is integer
     assume(
@@ -1881,7 +1881,7 @@ def test_array__rpow__(
     input_dtype, x = dtype_and_x
     # bfloat16 is not supported by numpy
-    assume(not ("bfloat16" in input_dtype))
+    assume("bfloat16" not in input_dtype)
     # Make sure x2 isn't a float when x1 is integer
     assume(
diff --git a/ivy_tests/test_ivy/test_misc/test_shape.py b/ivy_tests/test_ivy/test_misc/test_shape.py
index 266292f7e2500..ec84ce3ac3bfd 100644
--- a/ivy_tests/test_ivy/test_misc/test_shape.py
+++ b/ivy_tests/test_ivy/test_misc/test_shape.py
@@ -545,7 +545,7 @@ def test_shape__pow__(
     input_dtype, x = dtype_and_x
     # bfloat16 is not supported by numpy
-    assume(not ("bfloat16" in input_dtype))
+    assume("bfloat16" not in input_dtype)
     # Make sure x2 isn't a float when x1 is integer
     with BackendHandler.update_backend(backend_fw) as ivy_backend:
diff --git a/ivy_tests/test_ivy/test_stateful/test_activations.py b/ivy_tests/test_ivy/test_stateful/test_activations.py
index f9bf437ae759b..3630421a1b7cc 100644
--- a/ivy_tests/test_ivy/test_stateful/test_activations.py
+++ b/ivy_tests/test_ivy/test_stateful/test_activations.py
@@ -288,7 +288,7 @@ def test_log_softmax(
 @handle_method(
     method_tree="stateful.activations.Logit.__call__",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         large_abs_safety_factor=8,
         small_abs_safety_factor=8,
         safety_factor_scale="log",
@@ -671,10 +671,10 @@ def test_silu(
 @handle_method(
     method_tree="stateful.activations.Softmax.__call__",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         min_num_dims=1,
-        large_abs_safety_factor=8,
-        small_abs_safety_factor=8,
+        large_abs_safety_factor=10,
+        small_abs_safety_factor=10,
         safety_factor_scale="log",
     ),
     axis=helpers.ints(min_value=-1, max_value=0),
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 818e9c52d7a2d..e14c24869e34d 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -10,3 +10,4 @@ google-auth # mod_name=google.auth
 requests
 pyvis
 dill
+astunparse
diff --git a/run_tests.py b/run_tests.py
index c67e3dbdab35c..0832fd9cdbfe2 100644
--- a/run_tests.py
+++ b/run_tests.py
@@ -14,6 +14,11 @@
     "test_stateful",
     "test_misc",
     "test_scipy",
+    "test_pandas",
+    "test_mindspore",
+    "test_onnx",
+    "test_sklearn",
+    "test_xgboost",
 )
 db_dict = {
     "test_functional/test_core": ["core", 10],
@@ -28,6 +33,11 @@
     "test_misc": ["misc", 19],
     "test_paddle": ["paddle", 20],
     "test_scipy": ["scipy", 21],
+    "test_pandas": ["pandas", 22],
+    "test_mindspore": ["mindspore", 23],
+    "test_onnx": ["onnx", 24],
+    "test_sklearn": ["sklearn", 25],
+    "test_xgboost": ["xgboost", 26],
 }
 result_config = {
     "success": "https://img.shields.io/badge/-success-success",
@@ -36,9 +46,10 @@ def make_clickable(url, name):
-    return ''.format(name)
+    return (
+        f''
+    )
 
 
 def get_submodule(test_path):
@@ -46,9 +57,10 @@ def get_submodule(test_path):
     for name in submodules:
         if name in test_path:
             if name == "test_functional":
-                coll = db_dict["test_functional/" + test_path[-2]]
-            elif name == "test_experimental":
-                coll = db_dict["test_experimental/" + test_path[-2]]
+                if len(test_path) > 3 and test_path[3] == "test_experimental":
+                    coll = db_dict[f"test_experimental/{test_path[4]}"]
+                else:
+                    coll = db_dict[f"test_functional/{test_path[-2]}"]
             else:
                 coll = db_dict[name]
             break
@@ -69,16 +81,16 @@ def update_individual_test_results(
     frontend_version=None,
     device=None,
 ):
-    key = submod + "." + backend
+    key = f"{submod}.{backend}"
     if backend_version is not None:
         backend_version = backend_version.replace(".", "_")
-        key += "." + backend_version
+        key += f".{backend_version}"
     if frontend_version is not None:
         frontend_version = frontend_version.replace(".", "_")
-        key += "." + frontend_version
-    key += "." + test
+        key += f".{frontend_version}"
+    key += f".{test}"
     if device:
-        key += "." + device
+        key += f".{device}"
     collection.update_one(
         {"_id": id},
         {"$set": {key: result}},
@@ -88,10 +100,7 @@ def remove_from_db(collection, id, submod, backend, test):
-    collection.update_one(
-        {"_id": id},
-        {"$unset": {submod + "." + backend + ".": test}},
-    )
+    collection.update_one({"_id": id}, {"$unset": {f"{submod}.{backend}.": test}})
     return
@@ -158,7 +167,7 @@ def run_multiversion_testing():
     if len(sys.argv) > 8 and sys.argv[8] != "null":
         run_id = sys.argv[8]
     else:
-        run_id = "https://github.com/unifyai/ivy/actions/runs/" + workflow_id
+        run_id = f"https://github.com/unifyai/ivy/actions/runs/{workflow_id}"
     failed = False
     # GPU Testing
     with_gpu = False
@@ -210,13 +219,7 @@ def run_multiversion_testing():
         else:
             res = make_clickable(run_id, result_config["success"])
         frontend_version = None
-        if (
-            coll[0] == "numpy"
-            or coll[0] == "jax"
-            or coll[0] == "tensorflow"
-            or coll[0] == "torch"
-            or coll[0] == "paddle"
-        ):
+        if coll[0] in ["numpy", "jax", "tensorflow", "torch", "paddle"]:
             frontend_version = "latest-stable"
         if priority_flag:
             print("Updating Priority DB")
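The refactored update_individual_test_results above builds a dotted path that MongoDB's $set interprets as a nested-document update, which is how one flat key records a result per submodule, backend, version, test, and device. An illustration with made-up values:

    submod, backend, test, device = "test_core", "numpy", "test_abs", "gpu"
    backend_version = "1.24.2".replace(".", "_")
    key = f"{submod}.{backend}.{backend_version}.{test}.{device}"
    print(key)  # test_core.numpy.1_24_2.test_abs.gpu
    # collection.update_one({"_id": id}, {"$set": {key: result}}, ...) stores the result
    # as {"test_core": {"numpy": {"1_24_2": {"test_abs": {"gpu": result}}}}}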
diff --git a/run_tests_CLI/array_api_det_coverage.py b/run_tests_CLI/array_api_det_coverage.py
index 463f4f5944281..712ce6c8dbee1 100644
--- a/run_tests_CLI/array_api_det_coverage.py
+++ b/run_tests_CLI/array_api_det_coverage.py
@@ -33,7 +33,7 @@ def main():
                 continue
             if ("#" not in s) or (
                 "#" in s
-                and not (framework in s.lower())
+                and (framework not in s.lower())
                 and any(f in s.lower() for f in framework_tests_to_run)
             ):
                 submod = f"ivy_tests/array_api_testing/test_array_api/array_api_tests/test_{fname.replace('.txt', '.py')}"  # noqa
diff --git a/run_tests_CLI/array_api_determine_tests.py b/run_tests_CLI/array_api_determine_tests.py
index d15665b13330d..fade279139968 100644
--- a/run_tests_CLI/array_api_determine_tests.py
+++ b/run_tests_CLI/array_api_determine_tests.py
@@ -41,15 +41,15 @@ def determine_tests_line(_tests_file, _line, _tests_to_run):
     modified_files = commit._parse_diff(diff_index)
     for file in modified_files:
         try:
-            file_name = file.new_path + ",cover"
+            file_name = f"{file.new_path},cover"
         except:  # noqa
             continue
         if file_name not in tests.keys():
             continue
         tests_file = tests[file_name]
         change = file.diff_parsed
-        added = set([x - 1 for (x, _) in change["added"]])
-        deleted = set([x - 1 for (x, _) in change["deleted"]])
+        added = {x - 1 for (x, _) in change["added"]}
+        deleted = {x - 1 for (x, _) in change["deleted"]}
         updated = added.intersection(deleted)
         added = added.difference(updated)
         deleted = deleted.difference(updated)
diff --git a/run_tests_CLI/array_api_run_tests.py b/run_tests_CLI/array_api_run_tests.py
index d17ae3c4528aa..bf960761b1a64 100644
--- a/run_tests_CLI/array_api_run_tests.py
+++ b/run_tests_CLI/array_api_run_tests.py
@@ -11,9 +11,10 @@ def make_clickable(url, name):
-    return ''.format(name)
+    return (
+        f''
+    )
 
 
 def get_submodule(test_path):
@@ -34,14 +35,14 @@ def update_individual_test_results(
     backend_version=None,
     frontend_version=None,
 ):
-    key = submod + "." + backend
+    key = f"{submod}.{backend}"
     if backend_version is not None:
         backend_version = backend_version.replace(".", "_")
-        key += "." + backend_version
+        key += f".{backend_version}"
     if frontend_version is not None:
         frontend_version = frontend_version.replace(".", "_")
-        key += "." + frontend_version
-    key += "." + test
+        key += f".{frontend_version}"
+    key += f".{test}"
     collection.update_one(
         {"_id": id},
         {"$set": {key: result}},
@@ -60,7 +61,7 @@ def main():
     if len(sys.argv) > 5:
         run_id = sys.argv[5]
     else:
-        run_id = "https://github.com/unifyai/ivy/actions/runs/" + workflow_id
+        run_id = f"https://github.com/unifyai/ivy/actions/runs/{workflow_id}"
     failed = False
     cluster = MongoClient(
         f"mongodb+srv://deep-ivy:{mongo_key}@cluster0.qdvf8q3.mongodb.net/?retryWrites=true&w=majority"  # noqa
diff --git a/run_tests_CLI/cron_tests_multi_version.py b/run_tests_CLI/cron_tests_multi_version.py
index 8efbadb0ef9e4..a3ad33eaa3f74 100644
--- a/run_tests_CLI/cron_tests_multi_version.py
+++ b/run_tests_CLI/cron_tests_multi_version.py
@@ -52,7 +52,7 @@
     "numpy/1.24.2",
 ]
 jax_req = [
-    jax_ver + "/" + jaxlib_ver for jax_ver in jax_only_req for jaxlib_ver in jaxlib_req
+    f"{jax_ver}/{jaxlib_ver}" for jax_ver in jax_only_req for jaxlib_ver in jaxlib_req
 ]
 
 framework_versions = {
@@ -81,21 +81,19 @@
         test_names_without_backend.append(test_name)
 
 for test_name in test_names_without_backend:
-    for backend, backend_versions in framework_versions.items():
+    for backend_versions in framework_versions.values():
         for backend_version in backend_versions:
-            test_backend = test_name + "," + backend_version
+            test_backend = f"{test_name},{backend_version}"
             if "test_frontends" in test_name:
                 frontend = test_name[39:]
                 frontend = frontend[: frontend.find("/")]
                 frontend_versions = framework_versions.get(frontend, [])
                 for frontend_version in frontend_versions:
-                    test_names.append(test_backend + ";" + frontend_version)
+                    test_names.append(f"{test_backend};{frontend_version}")
             else:
                 test_names.append(test_backend)
 
-test_names = list(set(test_names))
-test_names.sort()
-
+test_names = sorted(set(test_names))
 # Run 150 tests in each iteration of the cron job
 num_tests = len(test_names)
 print(num_tests)
diff --git a/run_tests_CLI/get_all_tests.py b/run_tests_CLI/get_all_tests.py
index 213eb0cb0d8ba..3568424441b94 100644
--- a/run_tests_CLI/get_all_tests.py
+++ b/run_tests_CLI/get_all_tests.py
@@ -37,14 +37,13 @@ def extract_tests_from_dir(directory):
 
 
 def get_all_tests():
     test_names_without_backend = extract_tests_from_dir("ivy_tests/test_ivy")
-    test_names_without_backend = list(set(test_names_without_backend))
-    test_names_without_backend.sort()
+    test_names_without_backend = sorted(set(test_names_without_backend))
     random.Random(4).shuffle(test_names_without_backend)
     test_names = []
     for test_name in test_names_without_backend:
         for backend in BACKENDS:
-            test_backend = test_name + "," + backend
+            test_backend = f"{test_name},{backend}"
             test_names.append(test_backend)
 
     return test_names
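One detail worth noting in get_all_tests above: sorted(set(...)) dedupes and fixes the order before the seeded shuffle. This matters because set iteration order varies between Python processes, while random.Random(4) only reproduces the same permutation when its input order is itself stable. A quick demonstration:

    import random

    names = ["b", "a", "c", "a"]
    ordered = sorted(set(names))         # deterministic: ['a', 'b', 'c']
    random.Random(4).shuffle(ordered)    # same permutation on every run
    print(ordered)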
diff --git a/run_tests_CLI/setup_priority_tests.py b/run_tests_CLI/setup_priority_tests.py
index 3fbff574d2a48..c0bb3f1442304 100644
--- a/run_tests_CLI/setup_priority_tests.py
+++ b/run_tests_CLI/setup_priority_tests.py
@@ -3,15 +3,14 @@
 
 
 def main():
-    write_file = open("tests_to_run", "w")
-    with open(sys.argv[1], "r") as f:
-        for test in f:
-            test = test.strip()
-            if test.startswith("ivy/"):
-                test = test[4:]
-            for backend in BACKENDS:
-                write_file.write(f"{test},{backend}\n")
-    write_file.close()
+    with open("tests_to_run", "w") as write_file:
+        with open(sys.argv[1], "r") as f:
+            for test in f:
+                test = test.strip()
+                if test.startswith("ivy/"):
+                    test = test[4:]
+                for backend in BACKENDS:
+                    write_file.write(f"{test},{backend}\n")
 
 
 if __name__ == "__main__":
diff --git a/run_tests_CLI/synchronize_db.py b/run_tests_CLI/synchronize_db.py
index fad171323d6ba..a21ee7a88337f 100644
--- a/run_tests_CLI/synchronize_db.py
+++ b/run_tests_CLI/synchronize_db.py
@@ -1,3 +1,4 @@
+import json
 import sys
 from pymongo import MongoClient
 from get_all_tests import get_all_tests
@@ -5,8 +6,8 @@
 module_map = {
     "core": "test_functional/test_core",
-    "exp_core": "test_experimental/test_core",
-    "nn": "test_functional/test_nn",
+    "exp_core": "test_functional/test_experimental/test_core",
+    "nn": "test_functional/test_experimental/test_nn",
     "exp_nn": "test_experimental/test_nn",
     "stateful": "test_stateful",
     "torch": "test_frontends/test_torch",
@@ -20,7 +21,9 @@
 
 
 def keys_to_delete_from_db(all_tests, module, data, current_key=""):
-    """Recursively navigate and identify keys not in the list."""
+    """
+    Recursively navigate and identify keys not in the list.
+    """
     keys_for_deletion = []
 
     for key, value in data.items():
@@ -28,7 +31,9 @@ def keys_to_delete_from_db(all_tests, module, data, current_key=""):
 
         # If this is a dictionary, recurse deeper
         if isinstance(value, dict):
-            keys_for_deletion.extend(keys_to_delete_from_db(all_tests, value, new_key))
+            keys_for_deletion.extend(
+                keys_to_delete_from_db(all_tests, module, value, new_key)
+            )
         # If the new_key is not in keys_to_keep, mark it for deletion
         elif key != "_id":
             components = new_key.split(".")
@@ -102,6 +107,29 @@ def process_test(test):
     return coll[0] + "/" + submod + "::" + test_fn
 
 
+def remove_empty_objects(document, key_prefix=""):
+    # Base case: if the document is not a dictionary, return an empty list
+    if not isinstance(document, dict):
+        return []
+
+    # List to store keys associated with empty objects
+    empty_keys = []
+
+    for key, value in document.items():
+        # Generate the full key path
+        full_key = key_prefix + "." + key if key_prefix else key
+
+        # If the value is a dictionary, recursively check for empty objects
+        if isinstance(value, dict):
+            # If the dictionary is empty, store its key
+            if not value:
+                empty_keys.append(full_key)
+            else:
+                empty_keys.extend(remove_empty_objects(value, full_key))
+
+    return empty_keys
+
+
 def main():
     all_tests = get_all_tests()
     all_tests = set([process_test(test.split(",")[0].strip()) for test in all_tests])
@@ -114,11 +142,26 @@ def main():
         collection = db[collection_name]
         for document in collection.find({}):
             undesired_keys = keys_to_delete_from_db(
-                all_tests, module_map[collection_name], document
+                all_tests, collection_name, document
            )
             for key in undesired_keys:
-                print(key)
-                # collection.update_one({"_id": document["_id"]}, {"$unset": {key: 1}})
+                collection.update_one({"_id": document["_id"]}, {"$unset": {key: 1}})
+
+    for collection_name in db.list_collection_names():
+        collection = db[collection_name]
+        break_flag = False
+        while True:
+            for document in collection.find({}):
+                keys_to_remove = remove_empty_objects(document)
+                if keys_to_remove:
+                    update_operation = {"$unset": {key: 1 for key in keys_to_remove}}
+                    collection.update_one({"_id": document["_id"]}, update_operation)
+                else:
+                    break_flag = True
+                    break
+            if break_flag:
+                break_flag = False
+                break
 
 
 if __name__ == "__main__":
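Since remove_empty_objects only collects key paths for a single pass, unsetting a leaf can leave its parent newly empty, which is why main() above re-scans each collection until a pass finds nothing to remove. A quick illustration of the traversal on a plain dict:

    doc = {"_id": 1, "core": {"numpy": {}, "torch": {"test_abs": {}}}, "stateful": {}}
    print(remove_empty_objects(doc))
    # ['core.numpy', 'core.torch.test_abs', 'stateful']
    # After these are $unset, "core.torch" becomes empty and is caught on the next pass.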
diff --git a/run_tests_CLI/test_dependencies.py b/run_tests_CLI/test_dependencies.py
index 52ba1c061ab4d..348a63be15e6c 100644
--- a/run_tests_CLI/test_dependencies.py
+++ b/run_tests_CLI/test_dependencies.py
@@ -34,8 +34,8 @@ def test_imports(fname, assert_version, update_versions):
     global WARN
     global WARN_MSG
     global PRINT_MSG
-    versions_to_update = dict()
-    msg = "\nasserting imports work for: {}\n\n".format(fname)
+    versions_to_update = {}
+    msg = f"\nasserting imports work for: {fname}\n\n"
     PRINT_MSG += msg
     ERROR_MSG += msg
     WARN_MSG += msg
@@ -48,7 +48,7 @@ def test_imports(fname, assert_version, update_versions):
             mod = importlib.import_module(mod_name)
         except Exception as e:
             ERROR = True
-            msg = "{} could not be imported: {}\n".format(mod_name, e)
+            msg = f"{mod_name} could not be imported: {e}\n"
             ERROR_MSG += msg
             PRINT_MSG += msg
             continue
@@ -65,13 +65,11 @@ def test_imports(fname, assert_version, update_versions):
             detected_version = None
         if detected_version and expected_version:
             if detected_version == expected_version:
-                msg = "{} detected correct version: {}\n".format(
-                    mod_name, detected_version
-                )
+                msg = f"{mod_name} detected correct version: {detected_version}\n"
             else:
                 msg = (
-                    "expected version {} for module {}, but detected version "
-                    "{}\n".format(expected_version, mod_name, detected_version)
+                    f"expected version {expected_version} for module {mod_name}, but"
+                    f" detected version {detected_version}\n"
                 )
                 versions_to_update[line_num] = {
                     "expected": expected_version,
@@ -87,17 +85,18 @@ def test_imports(fname, assert_version, update_versions):
         else:
             if detected_version:
                 msg = (
-                    "{} detected version: {}, but no expected version "
-                    "provided\n".format(mod_name, detected_version)
+                    f"{mod_name} detected version: {detected_version}, but no expected"
+                    " version provided\n"
                 )
             elif expected_version:
-                msg = "{} expected version: {}, but unable to detect version\n".format(
-                    mod_name, expected_version
+                msg = (
+                    f"{mod_name} expected version: {expected_version}, but unable to"
+                    " detect version\n"
                 )
             else:
                 msg = (
-                    "no expected version provided, and unable to detect "
-                    "version for {}\n".format(mod_name)
+                    "no expected version provided, and unable to detect version for"
+                    f" {mod_name}\n"
                 )
             WARN = True
             PRINT_MSG += msg
diff --git a/run_tests_pr.py b/run_tests_pr.py
index 167e33a26d8c5..3c5ee3391454a 100644
--- a/run_tests_pr.py
+++ b/run_tests_pr.py
@@ -40,9 +40,9 @@ def get_mod_submod_test(test_path):
     for name in modules:
         if name in test_path:
             if name == "test_functional":
-                module = module_map["test_functional/" + test_path[-2]]
+                module = module_map[f"test_functional/{test_path[-2]}"]
             elif name == "test_experimental":
-                module = module_map["test_experimental/" + test_path[-2]]
+                module = module_map[f"test_experimental/{test_path[-2]}"]
             else:
                 module = module_map[name]
             break
@@ -54,32 +54,29 @@
 
 if __name__ == "__main__":
     failed = False
-    f_write = open(sys.argv[1], "w")
-    with open("tests_to_run", "r") as f:
-        for line in f:
-            test, backend = line.split(",")
-            print(f"\n{'*' * 100}")
-            print(f"{line[:-1]}")
-            print(f"{'*' * 100}\n")
-            sys.stdout.flush()
-            ret = os.system(
-                f'docker run --rm -v "$(pwd)":/ivy -v "$(pwd)"/.hypothesis:/.hypothesis unifyai/ivy:latest python3 -m pytest --tb=short {test} --skip-compile-testing --backend {backend}'  # noqa
-            )
-            if ret != 0:
-                failed = True
-                module, submodule, test = get_mod_submod_test(test)
-                params = {
-                    "module": module,
-                    "submodule": submodule,
-                    "backend": backend[:-1],
-                    "test": test,
-                }
-                response = requests.get(url, params=params)
-                if response.status_code == 200:
-                    if response.json():
-                        # The test passes on main but fails in this fork/branch
+    with open(sys.argv[1], "w") as f_write:
+        with open("tests_to_run", "r") as f:
+            for line in f:
+                test, backend = line.split(",")
+                print(f"\n{'*' * 100}")
+                print(f"{line[:-1]}")
+                print(f"{'*' * 100}\n")
+                sys.stdout.flush()
+                ret = os.system(
+                    f'docker run --rm -v "$(pwd)":/ivy -v "$(pwd)"/.hypothesis:/.hypothesis unifyai/ivy:latest python3 -m pytest --tb=short {test} --skip-compile-testing --backend {backend}'  # noqa
+                )
+                if ret != 0:
+                    failed = True
+                    module, submodule, test = get_mod_submod_test(test)
+                    params = {
+                        "module": module,
+                        "submodule": submodule,
+                        "backend": backend[:-1],
+                        "test": test,
+                    }
+                    response = requests.get(url, params=params)
+                    if response.status_code == 200 and response.json():
                         f_write.write(line)
-    f_write.close()
     if failed:
         exit(1)
diff --git a/scripts/backend_generation/generate.py b/scripts/backend_generation/generate.py
index 812bae002b91d..c68757a53b8d4 100644
--- a/scripts/backend_generation/generate.py
+++ b/scripts/backend_generation/generate.py
@@ -113,23 +113,23 @@ def _update_native_config_value(key):
             try:
                 obj = __builtins__.__dict__[parsed[-1]]
             except KeyError:
-                print(Fore.RED + f"{parsed[-1]} is not a primitive object.")
+                print(f"{Fore.RED}{parsed[-1]} is not a primitive object.")
                 return False
         else:
             try:
                 mod = import_module(parsed[0])
             except ModuleNotFoundError:
-                print(Fore.RED + f"failed to import {parsed[0]}")
+                print(f"{Fore.RED}failed to import {parsed[0]}")
                 return False
             try:
                 obj = getattr(mod, parsed[-1])
             except AttributeError:
-                print(Fore.RED + f"{parsed[-1]} is not found in module.")
+                print(f"{Fore.RED}{parsed[-1]} is not found in module.")
                 return False
         if not inspect.isclass(obj):
-            print(Fore.RED + f"{obj} is not a class.")
+            print(f"{Fore.RED}{obj} is not a class.")
             return False
-        print(Fore.GREEN + f"Found class: {obj}")
+        print(f"{Fore.GREEN}Found class: {obj}")
         # Use alias if exists
         if backend["alias"] is not None:
             modified_namespace = parsed[0].replace(
@@ -140,7 +140,7 @@ def _update_native_config_value(key):
             )
             return True
         except KeyError:
-            print(Fore.RED + f"Couldn't find {ret}")
+            print(f"{Fore.RED}Couldn't find {ret}")
             return False
     return True
@@ -163,7 +163,7 @@ def _should_install_backend(package_name):
                 reqr_file.write("\n" + package_name + "\n")
         except subprocess.CalledProcessError as e:
             raise RuntimeError(
-                Fore.RED + f"Installing {package_name} failed. {e}"
+                f"{Fore.RED}Installing {package_name} failed. {e}"
             ) from e
     elif ret.lower() == "n":
         print(
@@ -172,7 +172,7 @@ def _should_install_backend(package_name):
             "type checking won't be available.\n"
         )
     else:
-        print(Fore.RED + f"{ret} not understood.")
+        print(f"{Fore.RED}{ret} not understood.")
         return False
     return True
@@ -221,12 +221,12 @@ def _import_name():
     _get_user_input(_import_name)
 
     global _imported_backend
-    print(Style.BRIGHT + f"Importing {backend['name']} for type checking...")
+    print(f"{Style.BRIGHT}Importing {backend['name']} for type checking...")
     try:
         _imported_backend = import_module(backend["name"])
         return True
     except Exception as e:
-        print(Fore.RED + f"Failed to import {backend['name']}:{e}")
+        print(f"{Fore.RED}Failed to import {backend['name']}:{e}")
         return False
     return True
@@ -253,9 +253,9 @@ def _update_flag_config_value(key):
     if ret == "y":
         config_flags[key] = not config_flags[key]
         return True
-    elif ret == "n" or ret == "":
+    elif ret in ["n", ""]:
         return True
-    print(Fore.RED + f"{ret} not understood.")
+    print(f"{Fore.RED}{ret} not understood.")
     return False
@@ -327,15 +327,15 @@ def _call_generate_tree(config_name: str):
     for key, value in config_valids.copy().items():
         all_items = fullset_mapping[key]
         invalid_items = list(set(all_items).difference(value))
-        config_valids["in" + key] = invalid_items
+        config_valids[f"in{key}"] = invalid_items
 
     for key in config_valids["valid_dtypes"]:
-        new_key = "native_" + key
+        new_key = f"native_{key}"
         config_natives[new_key] = asdict(BackendNativeObject(name="None", namespace=""))
         _get_user_input(_update_native_config_value, new_key)
 
     for key in config_valids["invalid_dtypes"]:
-        new_key = "native_" + key
+        new_key = f"native_{key}"
         config_natives[new_key] = asdict(BackendNativeObject(name="None", namespace=""))
 
     print("\n:: Backend\n")
@@ -348,10 +348,10 @@ def _call_generate_tree(config_name: str):
         if key.startswith("in"):
             continue
         valid_items = config_valids[key]
-        invalid_items = config_valids["in" + key]
+        invalid_items = config_valids[f"in{key}"]
         print("\n:: " + key.partition("_")[-1])
-        print(Fore.GREEN + "valid > " + valid_items.__str__())
-        print(Fore.RED + "invalid > " + invalid_items.__str__())
+        print(f"{Fore.GREEN}valid > {valid_items.__str__()}")
+        print(f"{Fore.RED}invalid > {invalid_items.__str__()}")
 
     # Print flags
     for key, value in config_flags.items():
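A note on the Fore / Style refactor above: these constants (presumably colorama's, going by the usage) are plain ANSI escape strings, so concatenation and f-string interpolation emit identical bytes; the f-string form is purely a readability change. For example:

    from colorama import Fore, init

    init(autoreset=True)  # optional: reset styling after each print
    name = "jax"
    assert Fore.RED + f"failed to import {name}" == f"{Fore.RED}failed to import {name}"
    print(f"{Fore.RED}failed to import {name}")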
diff --git a/scripts/eager_mode_benchmark/benchmark.py b/scripts/eager_mode_benchmark/benchmark.py
index f55f7fb2f484e..c2c8fe22cfaad 100644
--- a/scripts/eager_mode_benchmark/benchmark.py
+++ b/scripts/eager_mode_benchmark/benchmark.py
@@ -84,7 +84,7 @@ def _read_or_create_csv(output_path="./report.csv"):
 
 
 def _write_to_csv(df, row_list, output_path="./report.csv"):
-    row = {k: v for k, v in zip(COLUMNS, row_list)}
+    row = dict(zip(COLUMNS, row_list))
     df = df.append(row, ignore_index=True)
     df.to_csv(output_path, index=False)
@@ -206,16 +206,16 @@ def eager_benchmark(
     devices = ivy.default(devices, [])
     output_path = ivy.default(output_path, "./report.csv")
     print("\nBenchmarking backends : " + " ".join(backends))
-    print("Number of experiments : {}".format(num_experiments) + "\n")
+    print(f"Number of experiments : {num_experiments}" + "\n")
     for i in range(num_experiments):
         if num_experiments > 1:
             print("====================")
-            print("Experiment {}".format(i + 1))
+            print(f"Experiment {i + 1}")
             print("====================\n")
         for backend in backends:
             with _AvoidGPUPreallocation(backend) as _:
                 print("------------------------------------------------\n")
-                print("backend : {}".format(backend))
+                print(f"backend : {backend}")
                 ivy.set_backend(backend, dynamic=True)
                 valid_devices = [
                     device
@@ -223,7 +223,7 @@ def eager_benchmark(
                     if device.split(":")[0] not in ivy.invalid_devices
                 ]
                 for device in valid_devices:
-                    print("device : {}".format(device))
+                    print(f"device : {device}")
                     obj_call = obj
                     if functional_api:
                         obj_call = ivy.__dict__[obj]
@@ -264,9 +264,9 @@ def eager_benchmark(
                     )
                     ivy.clear_cached_mem_on_dev(device)
                     print(LINE_UP * (len(valid_devices) - i), end=LINE_CLEAR)
-                    print("device : {}\t --> done\n".format(device))
+                    print(f"device : {device}\t --> done\n")
         ivy.unset_backend()
-    print("Results written to {} ...".format(output_path))
+    print(f"Results written to {output_path} ...")
 
 
 def visualize_speed_up(
@@ -325,7 +325,7 @@ def visualize_speed_up(
     fig.set_figwidth(30)
     fig.set_figheight(12)
     fig.tight_layout(pad=10.0)
-    axes = np.asarray([axes]) if not isinstance(axes, np.ndarray) else axes
+    axes = axes if isinstance(axes, np.ndarray) else np.asarray([axes])
     while len(axes.shape) < 2:
         if len(devices) > len(backends):
             axes = np.expand_dims(axes, len(axes.shape))
@@ -333,7 +333,7 @@
             axes = np.expand_dims(axes, 0)
     for device, axis in zip(devices, axes):
         for backend, ax in zip(backends, axis):
-            ax.set_title("{} : {}".format(backend, device), {"fontsize": 18})
+            ax.set_title(f"{backend} : {device}", {"fontsize": 18})
             ax.set_ylabel("Percent Speed up on compiling", {"fontsize": 18})
             ax.tick_params(axis="both", labelsize=15)
             query = df.query("backend == @backend and device == @device")
@@ -341,8 +341,8 @@
             if len(query) > 0:
                 ax.violinplot(query["percent_speed_up"])
             else:
                 warnings.warn(
-                    "No records matching the filters passed"
-                    "backend={} and device={}".format(backend, device)
+                    f"No records matching the filters passed backend={backend}"
+                    f" and device={device}"
                 )
     plt.savefig(output_path)
-    print("plot saved to {} ...".format(output_path))
+    print(f"plot saved to {output_path} ...")
diff --git a/setup_tests.py b/setup_tests.py
index 10c9f5cabb102..9f2a357d9a829 100644
--- a/setup_tests.py
+++ b/setup_tests.py
@@ -6,11 +6,10 @@ def main():
     if len(sys.argv) < 2:
         return
     test = sys.argv[1]
-    if "," in test:
-        with open("tests_to_run", "w") as f:
+    with open("tests_to_run", "w") as f:
+        if "," in test:
             f.write(test + "\n")
-    else:
-        with open("tests_to_run", "w") as f:
+        else:
             for backend in BACKENDS:
                 f.write(f"{test},{backend}\n")
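The setup scripts and run_tests_pr.py above agree on the tests_to_run interchange format: one "test-path,backend" pair per line, written by setup_tests.py / setup_priority_tests.py and split back apart with line.split(","). A hypothetical file for a single test (the backend list here is illustrative, not the actual BACKENDS constant) would look like:

    ivy_tests/test_ivy/test_misc/test_array.py::test_array__pow__,numpy
    ivy_tests/test_ivy/test_misc/test_array.py::test_array__pow__,jax
    ivy_tests/test_ivy/test_misc/test_array.py::test_array__pow__,tensorflow
    ivy_tests/test_ivy/test_misc/test_array.py::test_array__pow__,torch

Each line then drives one docker/pytest invocation in run_tests_pr.py.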