Python scalars in elementwise functions #807

Open
shoyer opened this issue May 15, 2024 · 19 comments
Labels
API change Changes to existing functions or objects in the API.

Comments

@shoyer
Contributor

shoyer commented May 15, 2024

The array API supports Python scalars in arithmetic only, i.e., operations like x + 1.

For the same readability reasons that supporting scalars in arithmetic is valuable, it would be nice to also support Python scalars in other elementwise functions, at least those that take multiple arguments like maximum(x, 0) or where(y, x, 0).

@rgommers
Member

Hmm, I am in two minds about reconsidering this choice.

On the con side: non-array input to functions is going against the design we have had from the start, it makes static typing a bit harder (we'd need both an Array protocol and an ArrayOrScalar union), and not all libraries support it yet - PyTorch in particular. E.g.:

>>> import torch
>>> t = torch.ones(3)
>>> torch.maximum(t, 1.5)
...
TypeError: maximum(): argument 'other' (position 2) must be Tensor, not float

It looks like PyTorch is fine with adding this in principle, but it's a nontrivial amount of work and no one is working on it as far as I know: pytorch/pytorch#110636. PyTorch does support it in functions matching operators (e.g., torch.add) and in torch.where.

TensorFlow also doesn't support it (except in their experimental.numpy namespace, IIRC), but that's less relevant now since it doesn't look like they're going to implement anything.

For the same readability reasons that supporting scalars in arithmetic is valuable

The readability argument is less prominent for functions than for operators, though. First, x + 1 is very short, so the relative increase in characters from spelling out an array is larger than for function calls (where modname.funcname is already long). Second, scalars are less commonly used in function calls.


On the pro side: I agree that it is pretty annoying to get right in a completely portable and generic way. In the cases where one does need it, the natural choice of asarray(scalar) often doesn't work; it should also use the dtype and device. So xp.maximum(x, 1) becomes:

xp.maximum(x, xp.asarray(1, dtype=x.dtype, device=x.device))

Hence if this is a pattern that a project happens to need a lot, it will probably create a utility function like:

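from array_api_compat import array_namespace


def as_zerodim(value, x, /, xp=None):
    # Convert a Python scalar into a 0-D array with x's dtype and device.
    if xp is None:
        xp = array_namespace(x)
    return xp.asarray(value, dtype=x.dtype, device=x.device)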


# Usage:
xp.maximum(x, as_zerodim(1, x))

PyTorch support comes through array-api-compat at this point, so wrapping the PyTorch functions isn't too hard. So it is doable. I think I'm +0.5 on balance. It's not the highest-prio item, but it's nice to have if it works for all implementing libraries.

@rgommers rgommers added the API change Changes to existing functions or objects in the API. label May 17, 2024
@asmeurer
Member

We could support them in a bespoke way for specific arguments of useful functions, like where. We already added scalar support specifically to the min and max arguments of clip: https://data-apis.org/array-api/latest/API_specification/generated/array_api.clip.html

@shoyer
Contributor Author

shoyer commented May 17, 2024

On the pro side: I agree that it is pretty annoying to get right in a completely portable and generic way. In the cases where one does need it, the natural choice of asarray(scalar) often doesn't work; it should also use the dtype and device. So xp.maximum(x, 1) becomes:

xp.maximum(x, xp.asarray(1, dtype=x.dtype, device=x.device))

It's even a little messier in the case Xarray is currently facing:

  1. We want this to work in a completely portable and generic way, with the minimum array-API requirements.
  2. We also still want to allow libraries like NumPy to figure out the result type itself. For example, consider maximum(x, 0.5) in the case where x is an integer dtype. In the array API, mixed dtype casting is undefined, but in most array libraries the result would be upcast to some form of float.

@asmeurer
Member

asmeurer commented May 17, 2024

In the array API, mixed dtype casting is undefined, but in most array libraries the result would be upcast to some form of float.

That's deviating from even the operator behavior in the array API. The specified scalar OP array behavior is to only upcast the scalar to the type of the array, not the other way around: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars. In other words, int OP float_array is OK, but float OP int_array is not. Implicitly casting an integer array to a floating-point dtype is cross-kind casting, and is something we've tried to explicitly avoid. (To be clear, these are all recommendations, not requirements; libraries like NumPy are free to implement this if they choose.)
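A minimal sketch of the scalar-compatibility rule described above, assuming dtypes are identified by plain strings such as "int32" or "float32" (a simplification; real libraries use dtype objects). The hypothetical helper `scalar_is_compatible` reports whether a Python scalar can be combined with an array of the given dtype by converting the scalar to that dtype. Treating Python `bool` as compatible only with `bool` arrays is an assumption of this sketch.

```python
# Hypothetical helper: dtype names are plain strings, not a real library's dtypes.
_INT_BOUNDS = {
    "int8": (-2**7, 2**7 - 1),
    "int16": (-2**15, 2**15 - 1),
    "int32": (-2**31, 2**31 - 1),
    "int64": (-2**63, 2**63 - 1),
    "uint8": (0, 2**8 - 1),
    "uint16": (0, 2**16 - 1),
    "uint32": (0, 2**32 - 1),
    "uint64": (0, 2**64 - 1),
}
_REAL_FLOATS = {"float32", "float64"}
_COMPLEX_FLOATS = {"complex64", "complex128"}


def scalar_is_compatible(scalar, dtype: str) -> bool:
    """Return True if the scalar can be converted to the array's dtype."""
    if dtype == "bool":
        return isinstance(scalar, bool)
    # bool is a subclass of int; this sketch rejects it for numeric arrays.
    if isinstance(scalar, bool):
        return False
    if dtype in _INT_BOUNDS:
        # Integer arrays accept only ints that fit in the dtype's range.
        lo, hi = _INT_BOUNDS[dtype]
        return isinstance(scalar, int) and lo <= scalar <= hi
    if dtype in _REAL_FLOATS:
        # int OP float_array is fine: the int is converted to the float dtype.
        return isinstance(scalar, (int, float))
    if dtype in _COMPLEX_FLOATS:
        return isinstance(scalar, (int, float, complex))
    raise ValueError(f"unknown dtype: {dtype}")
```

Under these rules, maximum(int_array, 0.5) is outside the specified behavior (a float scalar with an integer array), so upcasting to a float result is a library-specific choice.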

Similarly, clip, which as I mentioned is an example of a function that already allows Python scalars, leaves mixed kind scalars unspecified, although I personally think it should adopt the same logic as operators and allow int alongside floating-point arrays.

asmeurer added a commit to asmeurer/array-api that referenced this issue Jun 4, 2024
This is for consistency with operators, which allow combining an int with an
array that has a floating-point data type.

See the discussion at data-apis#807.
@asmeurer
Member

asmeurer commented Jun 10, 2024

List of potential APIs to support Python scalars in:

  • All elementwise functions that correspond to binary operators (add, multiply, etc.). Supporting scalars in these functions isn't strictly needed for usefulness, since you can just use the operators themselves, which already support scalars. But most libraries should just alias one to the other, so the support will already be there (this is the case in PyTorch, for instance).
  • where (arguments 2 and 3)
  • clip (arguments 2 and 3). This is already supported but mixed kind scalars are not, see feat!: allow clip to have int min or max when x is floating-point #811
  • copysign (argument 2)
  • maximum (this function is symmetrical, so it's not clear if support should be added to both arguments or just argument 2)
  • minimum (ditto)
  • nextafter (arguments 1 and 2)

I've omitted all 1-argument elementwise functions.

Also these multi-argument elementwise functions, which seem less useful at a glance, but let me know if adding scalar support to any of these would be useful:

  • atan2
  • hypot
  • logaddexp

By the way, repeat() is also prior art for allowing either an array or an int, for the repeats argument.

@rgommers
Member

Thanks Aaron. I agree with the choices here: where, clip, copysign, nextafter, minimum and maximum are quite naturally written with one scalar input, and it's easy to find real-world examples of that.

maximum (this function is symmetrical, so it's not clear if support should be added to both arguments or just argument 2)

Since there must always be at least one input that is an array, it would be more annoying to statically type if both arguments may be scalars. It then requires overloads; an array | scalar union for both arguments isn't enough. So it seems preferable to me to not make it symmetric.
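A sketch of what the overloads could look like if scalars were allowed in either argument of maximum but not both, assuming a hypothetical minimal `Array` protocol (a real protocol would have many more members). Annotating both parameters as `Array | RealScalar` would also accept `maximum(1, 2)`; the two overloads below exclude that call for a type checker, and the runtime check mirrors them. The function body is a placeholder, not an implementation.

```python
from typing import Protocol, Union, overload

# maximum is defined for real-valued dtypes, so complex scalars are excluded here.
RealScalar = Union[int, float]
_REAL_SCALAR_TYPES = (int, float)


class Array(Protocol):
    """Hypothetical minimal array protocol, for illustration only."""

    @property
    def dtype(self) -> object: ...


@overload
def maximum(x1: Array, x2: Union[Array, RealScalar], /) -> Array: ...
@overload
def maximum(x1: RealScalar, x2: Array, /) -> Array: ...
def maximum(x1, x2, /):
    # No overload matches (scalar, scalar); enforce the same rule at runtime.
    if isinstance(x1, _REAL_SCALAR_TYPES) and isinstance(x2, _REAL_SCALAR_TYPES):
        raise TypeError("maximum() requires at least one array argument")
    # A real implementation would dispatch to the array library here.
    raise NotImplementedError
```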

@asmeurer
Member

Do you agree that allowing scalars for both arguments of nextafter is useful? At a glance, it seems to me that one could end up using it for a scalar in either argument, but I also haven't made much use of the function myself.

If so, should we allow scalars in both arguments? It could be useful, for instance, to compute x + eps for a specific float x. The result should be a 0-D array. Also, unlike other cases, one wouldn't necessarily want to use math.nextafter, since if your default floating-point dtype is float32, you would want x + float32_eps. OTOH, the result would then necessarily use the default floating-point dtype and the default device. One can always manually cast one of the arguments to a 0-D array, though I don't know if that's an argument to allow or to not allow it.
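To illustrate the float32 point: Python's math.nextafter (Python 3.9+) always steps by one float64 ulp, so it cannot produce the float32 neighbor. The hypothetical helper below computes the float32 neighbor in pure Python by reinterpreting bit patterns with the stdlib `struct` module. It is a sketch only and assumes both arguments lie within float32 range (larger finite values make `struct.pack` raise `OverflowError`).

```python
import math
import struct


def _f32_bits(x: float) -> int:
    # Round x to float32 and return its raw 32-bit pattern.
    return struct.unpack("<I", struct.pack("<f", x))[0]


def _f32_from_bits(bits: int) -> float:
    # Interpret a 32-bit pattern as a float32 and widen it to a Python float.
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def nextafter_f32(x: float, toward: float) -> float:
    """Next float32 value after x in the direction of `toward`."""
    x32 = _f32_from_bits(_f32_bits(x))
    t32 = _f32_from_bits(_f32_bits(toward))
    if math.isnan(x32) or math.isnan(t32):
        return math.nan
    if x32 == t32:
        return t32
    if x32 == 0.0:
        # Step to the smallest float32 subnormal, with the sign of the direction.
        tiny = _f32_from_bits(1)
        return tiny if t32 > 0.0 else -tiny
    bits = _f32_bits(x32)
    # IEEE 754 floats are sign-magnitude: incrementing the bit pattern
    # increases |x|, decrementing it decreases |x|.
    moving_away_from_zero = (t32 > x32) == (x32 > 0.0)
    return _f32_from_bits(bits + 1 if moving_away_from_zero else bits - 1)
```

For example, nextafter_f32(1.0, 2.0) equals 1 + 2**-23, whereas math.nextafter(1.0, 2.0) equals 1 + 2**-52.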

@mdhaber
Contributor

mdhaber commented Jun 17, 2024

Just wanted to add a +1 to this effort; it would really simplify translation efforts. If you need additional opinions about fine points, LMK.

@shoyer
Contributor Author

shoyer commented Jun 27, 2024

One consideration that came up in discussion: How can users write new elementwise functions that support scalars in some arguments themselves using the array API?

e.g., suppose we want polar_to_cartesian() to work with either r or theta being a scalar:

def polar_to_cartesian(r, theta):
    xp = get_array_namespace(r, theta)
    return (r * xp.sin(theta), r * xp.cos(theta))

This seems to require supporting scalars even in single-argument elementwise operations like sin and cos.

@kgryte
Contributor

kgryte commented Jun 27, 2024

@shoyer Your example may run into issues due to device selection. On what device should libraries supporting multiple devices (e.g., PyTorch) allocate the arrays returned from xp.sin(theta) and xp.cos(theta)? And what if r is on a different device? That seems like it could be a recipe for issues.

@seberg
Contributor

seberg commented Jun 27, 2024

I find it a bit weird not to allow a scalar in both arguments (it would be nice if you could just add a (scalar, scalar) -> error overload, but I don't know if that is possible), and it probably doesn't matter much in practice.

Only allowing one of the two is a bit strange for asymmetric functions. nextafter is probably a bit niche, copysign maybe too, and atan2 and hypot may not matter in a first iteration. So I doubt it is a true long-term solution, but OK.

@asmeurer
Member

nextafter is used quite a bit in SciPy https://github.com/search?q=repo%3Ascipy%2Fscipy%20nextafter&type=code (although half the uses are in the tests), with things like nextafter(np.pi, np.inf), suggesting double scalar usage is common.

Ditto for copysign https://github.com/search?q=repo%3Ascipy%2Fscipy+copysign&type=code. In fact, I would say my suggestion above to only support argument 2 in copysign was wrong. If anything, it's more common to use copysign(0.0, x) to create a signed 0.
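For reference, the two usage patterns mentioned above, written with Python's scalar-only math functions (stdlib calls, not array API functions). In array code, x would be an array while the first argument stays a Python float, which is why scalar support in argument 1 matters here.

```python
import math

# Scalar in argument 1 of copysign: a zero that carries the sign of x.
x = -3.5
signed_zero = math.copysign(0.0, x)

# Scalars in both arguments of nextafter: the float64 neighbor of pi toward +inf.
pi_up = math.nextafter(math.pi, math.inf)
```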

Based on the discussions in the meeting today, we should still require at least one argument to be an array for now (or rather, leave full scalar inputs unspecified), but I don't see any reason to prefer arguments 1 or 2 for nextafter or copysign. There was also some sentiment that breaking symmetry for minimum and maximum would be confusing, so we should perhaps allow scalars in either argument for those functions.

Not including scalar support for argument 1 in where, clip, or repeat should be fine I'd imagine, but if anyone is aware of use-cases otherwise please point them out.

@rgommers
Member

rgommers commented Jun 27, 2024

Based on the discussions in the meeting today,

The brief summary of that was:

  • Let's go one step further than @asmeurer's proposal above and allow scalars for all functions as long as there is at least one array input
    • TBD if that should be symmetric or if the first input should be an array (there are pros and cons there; we can see on the PR that updates the standard, or discuss further here)
    • It seemed to attendees that this was easier to explain to users, with fairly limited extra implementation work, and some usability benefits
  • We also considered allowing scalars everywhere, but there are multiple potential issues with that which are hard to evaluate (determining dtype/device, promotion rules, more complex static typing, ...), so decided against doing that at least for now
    • An important argument is that we can still do this in the future, but if we allowed it now, there would be no way to go back if it turned out to be problematic
    • Hence, do the "needs at least one input array" flavor now. In case that turns out to not be enough down the line, we can re-evaluate.
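A sketch of how an implementation might apply the "at least one array input" rule and answer the device question raised earlier: Python scalars adopt the dtype and device of the array argument(s), all-scalar calls are rejected, and arrays on different devices are an error. `FakeArray`, `resolve_scalar_placement`, and the "first array's dtype wins" choice are hypothetical illustrations, not anything agreed in this issue; the standard leaves all-scalar inputs unspecified, and this sketch simply chooses to raise.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical stand-in for an array: only dtype and device matter here."""

    dtype: str
    device: str


_SCALAR_TYPES = (bool, int, float, complex)


def resolve_scalar_placement(*args):
    """Return the (dtype, device) that Python scalars in `args` should adopt."""
    arrays = [a for a in args if not isinstance(a, _SCALAR_TYPES)]
    if not arrays:
        # All-scalar call: there is no array to take a dtype or device from.
        raise TypeError("at least one argument must be an array")
    devices = {a.device for a in arrays}
    if len(devices) > 1:
        raise ValueError(f"arrays are on different devices: {sorted(devices)}")
    # Sketch assumption: scalars take the dtype of the first array argument;
    # a real library would apply its type promotion rules across all arrays.
    return arrays[0].dtype, arrays[0].device
```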

@asmeurer
Member

asmeurer commented Jun 27, 2024

There are also some functions I intentionally omitted from my list above because scalars don't really make sense for them: although they accept multiple arrays, they are array manipulation functions, like broadcast_arrays, stack, and concat. There are also functions that don't allow 0-D inputs at all, like searchsorted and all the linalg functions, which should obviously be omitted from this list.

@shoyer
Contributor Author

shoyer commented Jun 27, 2024

I can maybe see a case for supporting scalars in broadcast_arrays and stack (NumPy and JAX support it), though it's pretty marginal.

Scalars don't make sense in concat because it requires arrays with at least one dimension.

@betatim
Member

betatim commented Oct 22, 2024

Do we want to reboot this topic? Via "multi device" support in array-api-strict and the scikit-learn tests, we found scipy/scipy#21736, which I think will be a common bug without scalar support, if only because it is quite tedious to spell out the correct version each time.

I am motivated to get this done and implemented in array-api-compat. Maybe we can only do it for those functions that we can agree on/where it is a "no brainer"? Then do a second pass later for other functions?

@asmeurer
Member

We basically agreed to do this, so it just needs updates for the standard. We can implement things in compat even if they are only in a draft release of the standard. If anything it's better to implement things first because then we can catch potential issues.

@betatim
Member

betatim commented Oct 22, 2024

Do you want to make the change in compat or should I try?

@asmeurer
Member

If you want to make the change feel free. If you don't, I will, but I also have other things that I'm working on presently so I might not get to it until later this week or next week.

7 participants