diff --git a/spatialmath/base/argcheck.py b/spatialmath/base/argcheck.py index 40f94336..85a25a0b 100644 --- a/spatialmath/base/argcheck.py +++ b/spatialmath/base/argcheck.py @@ -283,8 +283,7 @@ def getvector( dim: Optional[Union[int, None]] = None, out: str = "array", dtype: DTypeLike = np.float64, -) -> NDArray: - ... +) -> NDArray: ... @overload @@ -293,8 +292,7 @@ def getvector( dim: Optional[Union[int, None]] = None, out: str = "list", dtype: DTypeLike = np.float64, -) -> List[float]: - ... +) -> List[float]: ... @overload @@ -303,8 +301,7 @@ def getvector( dim: Optional[Union[int, None]] = None, out: str = "sequence", dtype: DTypeLike = np.float64, -) -> Tuple[float, ...]: - ... +) -> Tuple[float, ...]: ... @overload @@ -313,8 +310,7 @@ def getvector( dim: Optional[Union[int, None]] = None, out: str = "sequence", dtype: DTypeLike = np.float64, -) -> List[float]: - ... +) -> List[float]: ... def getvector( @@ -522,7 +518,9 @@ def isvector(v: Any, dim: Optional[int] = None) -> bool: return False -def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]: +def getunit( + v: ArrayLike, unit: str = "rad", dim: Optional[int] = None, vector: bool = True +) -> Union[float, NDArray]: """ Convert values according to angular units @@ -530,8 +528,10 @@ def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]: :type v: array_like(m) :param unit: the angular unit, "rad" or "deg" :type unit: str - :param dim: expected dimension of input, defaults to None + :param dim: expected dimension of input, defaults to don't check (None) :type dim: int, optional + :param vector: return a scalar as a 1d vector, defaults to True + :type vector: bool, optional :return: the converted value in radians :rtype: ndarray(m) or float :raises ValueError: argument is not a valid angular unit @@ -543,30 +543,44 @@ def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]: >>> from spatialmath.base import getunit >>> import numpy as np >>> getunit(1.5, 'rad') - >>> getunit(1.5, 'rad', dim=0) - >>> # getunit([1.5], 'rad', dim=0) --> ValueError >>> getunit(90, 'deg') + >>> getunit(90, 'deg', vector=False) # force a scalar output + >>> getunit(1.5, 'rad', dim=0) # check argument is scalar + >>> getunit(1.5, 'rad', dim=3) # check argument is a 3-vector + >>> getunit([1.5], 'rad', dim=1) # check argument is a 1-vector + >>> getunit([1.5], 'rad', dim=3) # check argument is a 3-vector >>> getunit([90, 180], 'deg') - >>> getunit(np.r_[0.5, 1], 'rad') >>> getunit(np.r_[90, 180], 'deg') - >>> getunit(np.r_[90, 180], 'deg', dim=2) - >>> # getunit([90, 180], 'deg', dim=3) --> ValueError + >>> getunit(np.r_[90, 180], 'deg', dim=2) # check argument is a 2-vector + >>> getunit([90, 180], 'deg', dim=3) # check argument is a 3-vector :note: - the input value is processed by :func:`getvector` and the argument ``dim`` can - be used to check that ``v`` is the desired length. - - the output is always an ndarray except if the input is a scalar and ``dim=0``. + be used to check that ``v`` is the desired length. Note that 0 means a scalar, + whereas 1 means a 1-element array. + - the output is always an ndarray except if the input is a scalar and ``vector=False``. :seealso: :func:`getvector` """ - if not isinstance(v, Iterable) and dim == 0: - # scalar in, scalar out - if unit == "rad": - return v - elif unit == "deg": - return np.deg2rad(v) + if not isinstance(v, Iterable): + # scalar input + if dim is not None and dim != 0: + raise ValueError("for dim==0 input must be a scalar") + if vector: + # scalar in, vector out + if unit == "deg": + v = np.deg2rad(v) + elif unit != "rad": + raise ValueError("invalid angular units") + return np.array([v]) else: - raise ValueError("invalid angular units") + # scalar in, scalar out + if unit == "rad": + return v + elif unit == "deg": + return np.deg2rad(v) + else: + raise ValueError("invalid angular units") else: # scalar or iterable in, ndarray out diff --git a/spatialmath/base/transforms2d.py b/spatialmath/base/transforms2d.py index 682ea0ca..a6955914 100644 --- a/spatialmath/base/transforms2d.py +++ b/spatialmath/base/transforms2d.py @@ -63,7 +63,7 @@ def rot2(theta: float, unit: str = "rad") -> SO2Array: >>> rot2(0.3) >>> rot2(45, 'deg') """ - theta = smb.getunit(theta, unit, dim=0) + theta = smb.getunit(theta, unit, vector=False) ct = smb.sym.cos(theta) st = smb.sym.sin(theta) # fmt: off @@ -172,18 +172,15 @@ def tr2xyt(T: SE2Array, unit: str = "rad") -> R3: # ---------------------------------------------------------------------------------------# @overload # pragma: no cover -def transl2(x: float, y: float) -> SE2Array: - ... +def transl2(x: float, y: float) -> SE2Array: ... @overload # pragma: no cover -def transl2(x: ArrayLike2) -> SE2Array: - ... +def transl2(x: ArrayLike2) -> SE2Array: ... @overload # pragma: no cover -def transl2(x: SE2Array) -> R2: - ... +def transl2(x: SE2Array) -> R2: ... def transl2(x, y=None): @@ -446,8 +443,7 @@ def trlog2( twist: bool = False, check: bool = True, tol: float = 20, -) -> so2Array: - ... +) -> so2Array: ... @overload # pragma: no cover @@ -456,8 +452,7 @@ def trlog2( twist: bool = False, check: bool = True, tol: float = 20, -) -> se2Array: - ... +) -> se2Array: ... @overload # pragma: no cover @@ -466,8 +461,7 @@ def trlog2( twist: bool = True, check: bool = True, tol: float = 20, -) -> float: - ... +) -> float: ... @overload # pragma: no cover @@ -476,8 +470,7 @@ def trlog2( twist: bool = True, check: bool = True, tol: float = 20, -) -> R3: - ... +) -> R3: ... def trlog2( @@ -563,13 +556,15 @@ def trlog2( # ---------------------------------------------------------------------------------------# @overload # pragma: no cover -def trexp2(S: so2Array, theta: Optional[float] = None, check: bool = True) -> SO2Array: - ... +def trexp2( + S: so2Array, theta: Optional[float] = None, check: bool = True +) -> SO2Array: ... @overload # pragma: no cover -def trexp2(S: se2Array, theta: Optional[float] = None, check: bool = True) -> SE2Array: - ... +def trexp2( + S: se2Array, theta: Optional[float] = None, check: bool = True +) -> SE2Array: ... def trexp2( @@ -692,8 +687,7 @@ def trexp2( @overload # pragma: no cover -def trnorm2(R: SO2Array) -> SO2Array: - ... +def trnorm2(R: SO2Array) -> SO2Array: ... def trnorm2(T: SE2Array) -> SE2Array: @@ -758,13 +752,11 @@ def trnorm2(T: SE2Array) -> SE2Array: @overload # pragma: no cover -def tradjoint2(T: SO2Array) -> R1x1: - ... +def tradjoint2(T: SO2Array) -> R1x1: ... @overload # pragma: no cover -def tradjoint2(T: SE2Array) -> R3x3: - ... +def tradjoint2(T: SE2Array) -> R3x3: ... def tradjoint2(T): @@ -853,13 +845,15 @@ def tr2jac2(T: SE2Array) -> R3x3: @overload -def trinterp2(start: Optional[SO2Array], end: SO2Array, s: float, shortest: bool = True) -> SO2Array: - ... +def trinterp2( + start: Optional[SO2Array], end: SO2Array, s: float, shortest: bool = True +) -> SO2Array: ... @overload -def trinterp2(start: Optional[SE2Array], end: SE2Array, s: float, shortest: bool = True) -> SE2Array: - ... +def trinterp2( + start: Optional[SE2Array], end: SE2Array, s: float, shortest: bool = True +) -> SE2Array: ... def trinterp2(start, end, s, shortest: bool = True): diff --git a/spatialmath/base/transforms3d.py b/spatialmath/base/transforms3d.py index 3617f965..57ba6ce9 100644 --- a/spatialmath/base/transforms3d.py +++ b/spatialmath/base/transforms3d.py @@ -79,7 +79,7 @@ def rotx(theta: float, unit: str = "rad") -> SO3Array: :SymPy: supported """ - theta = getunit(theta, unit, dim=0) + theta = getunit(theta, unit, vector=False) ct = sym.cos(theta) st = sym.sin(theta) # fmt: off @@ -118,7 +118,7 @@ def roty(theta: float, unit: str = "rad") -> SO3Array: :SymPy: supported """ - theta = getunit(theta, unit, dim=0) + theta = getunit(theta, unit, vector=False) ct = sym.cos(theta) st = sym.sin(theta) # fmt: off @@ -152,7 +152,7 @@ def rotz(theta: float, unit: str = "rad") -> SO3Array: :seealso: :func:`~trotz` :SymPy: supported """ - theta = getunit(theta, unit, dim=0) + theta = getunit(theta, unit, vector=False) ct = sym.cos(theta) st = sym.sin(theta) # fmt: off @@ -263,18 +263,15 @@ def trotz(theta: float, unit: str = "rad", t: Optional[ArrayLike3] = None) -> SE @overload # pragma: no cover -def transl(x: float, y: float, z: float) -> SE3Array: - ... +def transl(x: float, y: float, z: float) -> SE3Array: ... @overload # pragma: no cover -def transl(x: ArrayLike3) -> SE3Array: - ... +def transl(x: ArrayLike3) -> SE3Array: ... @overload # pragma: no cover -def transl(x: SE3Array) -> R3: - ... +def transl(x: SE3Array) -> R3: ... def transl(x, y=None, z=None): @@ -327,6 +324,7 @@ def transl(x, y=None, z=None): .. note:: This function is compatible with the MATLAB version of the Toolbox. It is unusual/weird in doing two completely different things inside the one function. + :seealso: :func:`~spatialmath.base.transforms2d.transl2` :SymPy: supported """ @@ -426,8 +424,7 @@ def isrot(R: Any, check: bool = False, tol: float = 20) -> bool: @overload # pragma: no cover def rpy2r( roll: float, pitch: float, yaw: float, *, unit: str = "rad", order: str = "zyx" -) -> SO3Array: - ... +) -> SO3Array: ... @overload # pragma: no cover @@ -438,8 +435,7 @@ def rpy2r( *, unit: str = "rad", order: str = "zyx", -) -> SO3Array: - ... +) -> SO3Array: ... def rpy2r( @@ -517,8 +513,7 @@ def rpy2r( @overload # pragma: no cover def rpy2tr( roll: float, pitch: float, yaw: float, unit: str = "rad", order: str = "zyx" -) -> SE3Array: - ... +) -> SE3Array: ... @overload # pragma: no cover @@ -528,8 +523,7 @@ def rpy2tr( yaw: None = None, unit: str = "rad", order: str = "zyx", -) -> SE3Array: - ... +) -> SE3Array: ... def rpy2tr( @@ -593,15 +587,13 @@ def rpy2tr( @overload # pragma: no cover -def eul2r(phi: float, theta: float, psi: float, unit: str = "rad") -> SO3Array: - ... +def eul2r(phi: float, theta: float, psi: float, unit: str = "rad") -> SO3Array: ... @overload # pragma: no cover def eul2r( phi: ArrayLike3, theta: None = None, psi: None = None, unit: str = "rad" -) -> SO3Array: - ... +) -> SO3Array: ... def eul2r( @@ -654,13 +646,11 @@ def eul2r( # ---------------------------------------------------------------------------------------# @overload # pragma: no cover -def eul2tr(phi: float, theta: float, psi: float, unit: str = "rad") -> SE3Array: - ... +def eul2tr(phi: float, theta: float, psi: float, unit: str = "rad") -> SE3Array: ... @overload # pragma: no cover -def eul2tr(phi: ArrayLike3, theta=None, psi=None, unit: str = "rad") -> SE3Array: - ... +def eul2tr(phi: ArrayLike3, theta=None, psi=None, unit: str = "rad") -> SE3Array: ... def eul2tr( @@ -753,7 +743,7 @@ def angvec2r(theta: float, v: ArrayLike3, unit="rad", tol: float = 20) -> SO3Arr if np.linalg.norm(v) < tol * _eps: return np.eye(3) - θ = getunit(theta, unit) + θ = getunit(theta, unit, vector=False) # Rodrigue's equation @@ -1272,25 +1262,25 @@ def tr2rpy( @overload # pragma: no cover def trlog( T: SO3Array, check: bool = True, twist: bool = False, tol: float = 20 -) -> so3Array: - ... +) -> so3Array: ... @overload # pragma: no cover def trlog( T: SE3Array, check: bool = True, twist: bool = False, tol: float = 20 -) -> se3Array: - ... +) -> se3Array: ... @overload # pragma: no cover -def trlog(T: SO3Array, check: bool = True, twist: bool = True, tol: float = 20) -> R3: - ... +def trlog( + T: SO3Array, check: bool = True, twist: bool = True, tol: float = 20 +) -> R3: ... @overload # pragma: no cover -def trlog(T: SE3Array, check: bool = True, twist: bool = True, tol: float = 20) -> R6: - ... +def trlog( + T: SE3Array, check: bool = True, twist: bool = True, tol: float = 20 +) -> R6: ... def trlog( @@ -1405,23 +1395,23 @@ def trlog( # ---------------------------------------------------------------------------------------# @overload # pragma: no cover -def trexp(S: so3Array, theta: Optional[float] = None, check: bool = True) -> SO3Array: - ... +def trexp( + S: so3Array, theta: Optional[float] = None, check: bool = True +) -> SO3Array: ... @overload # pragma: no cover -def trexp(S: se3Array, theta: Optional[float] = None, check: bool = True) -> SE3Array: - ... +def trexp( + S: se3Array, theta: Optional[float] = None, check: bool = True +) -> SE3Array: ... @overload # pragma: no cover -def trexp(S: ArrayLike3, theta: Optional[float] = None, check=True) -> SO3Array: - ... +def trexp(S: ArrayLike3, theta: Optional[float] = None, check=True) -> SO3Array: ... @overload # pragma: no cover -def trexp(S: ArrayLike6, theta: Optional[float] = None, check=True) -> SE3Array: - ... +def trexp(S: ArrayLike6, theta: Optional[float] = None, check=True) -> SE3Array: ... def trexp(S, theta=None, check=True): @@ -1542,8 +1532,7 @@ def trexp(S, theta=None, check=True): @overload # pragma: no cover -def trnorm(R: SO3Array) -> SO3Array: - ... +def trnorm(R: SO3Array) -> SO3Array: ... def trnorm(T: SE3Array) -> SE3Array: @@ -1605,13 +1594,15 @@ def trnorm(T: SE3Array) -> SE3Array: @overload -def trinterp(start: Optional[SO3Array], end: SO3Array, s: float, shortest: bool = True) -> SO3Array: - ... +def trinterp( + start: Optional[SO3Array], end: SO3Array, s: float, shortest: bool = True +) -> SO3Array: ... @overload -def trinterp(start: Optional[SE3Array], end: SE3Array, s: float, shortest: bool = True) -> SE3Array: - ... +def trinterp( + start: Optional[SE3Array], end: SE3Array, s: float, shortest: bool = True +) -> SE3Array: ... def trinterp(start, end, s, shortest=True): @@ -2224,8 +2215,7 @@ def rotvelxform( inverse: bool = False, full: bool = False, representation="rpy/xyz", -) -> R3x3: - ... +) -> R3x3: ... @overload # pragma: no cover @@ -2233,8 +2223,7 @@ def rotvelxform( 𝚪: SO3Array, inverse: bool = False, full: bool = False, -) -> R3x3: - ... +) -> R3x3: ... @overload # pragma: no cover @@ -2243,8 +2232,7 @@ def rotvelxform( inverse: bool = False, full: bool = True, representation="rpy/xyz", -) -> R6x6: - ... +) -> R6x6: ... @overload # pragma: no cover @@ -2252,8 +2240,7 @@ def rotvelxform( 𝚪: SO3Array, inverse: bool = False, full: bool = True, -) -> R6x6: - ... +) -> R6x6: ... def rotvelxform( @@ -2465,15 +2452,13 @@ def rotvelxform( @overload # pragma: no cover def rotvelxform_inv_dot( 𝚪: ArrayLike3, 𝚪d: ArrayLike3, full: bool = False, representation: str = "rpy/xyz" -) -> R3x3: - ... +) -> R3x3: ... @overload # pragma: no cover def rotvelxform_inv_dot( 𝚪: ArrayLike3, 𝚪d: ArrayLike3, full: bool = True, representation: str = "rpy/xyz" -) -> R6x6: - ... +) -> R6x6: ... def rotvelxform_inv_dot( @@ -2670,13 +2655,11 @@ def rotvelxform_inv_dot( @overload # pragma: no cover -def tr2adjoint(T: SO3Array) -> R3x3: - ... +def tr2adjoint(T: SO3Array) -> R3x3: ... @overload # pragma: no cover -def tr2adjoint(T: SE3Array) -> R6x6: - ... +def tr2adjoint(T: SE3Array) -> R6x6: ... def tr2adjoint(T): @@ -2709,7 +2692,7 @@ def tr2adjoint(T): :Reference: - Robotics, Vision & Control for Python, Section 3, P. Corke, Springer 2023. - - `Lie groups for 2D and 3D Transformations _ + - `Lie groups for 2D and 3D Transformations `_ :SymPy: supported """ @@ -3002,29 +2985,36 @@ def trplot( - ``width`` of line - ``length`` of line - ``style`` which is one of: + - ``'arrow'`` [default], draw line with arrow head in ``color`` - ``'line'``, draw line with no arrow head in ``color`` - ``'rgb'``, frame axes are lines with no arrow head and red for X, green - for Y, blue for Z; no origin dot + for Y, blue for Z; no origin dot - ``'rviz'``, frame axes are thick lines with no arrow head and red for X, - green for Y, blue for Z; no origin dot + green for Y, blue for Z; no origin dot + - coordinate axis labels depend on: + - ``axislabel`` if True [default] label the axis, default labels are X, Y, Z - ``labels`` 3-list of alternative axis labels - ``textcolor`` which defaults to ``color`` - ``axissubscript`` if True [default] add the frame label ``frame`` as a subscript - for each axis label + for each axis label + - coordinate frame label depends on: + - `frame` the label placed inside {} near the origin of the frame + - a dot at the origin + - ``originsize`` size of the dot, if zero no dot - ``origincolor`` color of the dot, defaults to ``color`` Examples:: - trplot(T, frame='A') - trplot(T, frame='A', color='green') - trplot(T1, 'labels', 'UVW'); + trplot(T, frame='A') + trplot(T, frame='A', color='green') + trplot(T1, 'labels', 'UVW'); .. plot:: @@ -3383,12 +3373,12 @@ def tranimate(T: Union[SO3Array, SE3Array], **kwargs) -> str: :param **kwargs: arguments passed to ``trplot`` - ``tranimate(T)`` where ``T`` is an SO(3) or SE(3) matrix, animates a 3D - coordinate frame moving from the world frame to the frame ``T`` in - ``nsteps``. + coordinate frame moving from the world frame to the frame ``T`` in + ``nsteps``. - ``tranimate(I)`` where ``I`` is an iterable or generator, animates a 3D - coordinate frame representing the pose of each element in the sequence of - SO(3) or SE(3) matrices. + coordinate frame representing the pose of each element in the sequence of + SO(3) or SE(3) matrices. Examples: diff --git a/spatialmath/base/vectors.py b/spatialmath/base/vectors.py index f29740a3..2cc74325 100644 --- a/spatialmath/base/vectors.py +++ b/spatialmath/base/vectors.py @@ -530,6 +530,7 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]: :param theta: input angle :type theta: scalar or ndarray :return: angle wrapped into range :math:`[0, \pi)` + :rtype: scalar or ndarray This is used to fold angles of colatitude. If zero is the angle of the north pole, colatitude increases to :math:`\pi` at the south pole then @@ -537,7 +538,7 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]: :seealso: :func:`wrap_mpi2_pi2` :func:`wrap_0_2pi` :func:`wrap_mpi_pi` :func:`angle_wrap` """ - theta = np.abs(theta) + theta = np.abs(getvector(theta)) n = theta / np.pi if isinstance(n, np.ndarray): n = n.astype(int) @@ -546,7 +547,7 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]: y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, (n + 1) * np.pi - theta) if isinstance(y, np.ndarray) and y.size == 1: - return float(y) + return float(y[0]) else: return y @@ -558,6 +559,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> Union[float, NDArray]: :param theta: input angle :type theta: scalar or ndarray :return: angle wrapped into range :math:`[-\pi/2, \pi/2]` + :rtype: scalar or ndarray This is used to fold angles of latitude. @@ -573,7 +575,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> Union[float, NDArray]: y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, n * np.pi - theta) if isinstance(y, np.ndarray) and len(y) == 1: - return float(y) + return float(y[0]) else: return y @@ -585,13 +587,14 @@ def wrap_0_2pi(theta: ArrayLike) -> Union[float, NDArray]: :param theta: input angle :type theta: scalar or ndarray :return: angle wrapped into range :math:`[0, 2\pi)` + :rtype: scalar or ndarray :seealso: :func:`wrap_mpi_pi` :func:`wrap_0_pi` :func:`wrap_mpi2_pi2` :func:`angle_wrap` """ theta = getvector(theta) y = theta - 2.0 * math.pi * np.floor(theta / 2.0 / np.pi) if isinstance(y, np.ndarray) and len(y) == 1: - return float(y) + return float(y[0]) else: return y @@ -603,13 +606,14 @@ def wrap_mpi_pi(theta: ArrayLike) -> Union[float, NDArray]: :param theta: input angle :type theta: scalar or ndarray :return: angle wrapped into range :math:`[-\pi, \pi)` + :rtype: scalar or ndarray :seealso: :func:`wrap_0_2pi` :func:`wrap_0_pi` :func:`wrap_mpi2_pi2` :func:`angle_wrap` """ theta = getvector(theta) y = np.mod(theta + math.pi, 2 * math.pi) - np.pi if isinstance(y, np.ndarray) and len(y) == 1: - return float(y) + return float(y[0]) else: return y @@ -620,13 +624,11 @@ def wrap_mpi_pi(theta: ArrayLike) -> Union[float, NDArray]: @overload -def angdiff(a: ArrayLike, b: ArrayLike) -> NDArray: - ... +def angdiff(a: ArrayLike, b: ArrayLike) -> NDArray: ... @overload -def angdiff(a: ArrayLike) -> NDArray: - ... +def angdiff(a: ArrayLike) -> NDArray: ... def angdiff(a, b=None): @@ -643,6 +645,7 @@ def angdiff(a, b=None): - ``angdiff(a, b)`` is the difference ``a - b`` wrapped to the range :math:`[-\pi, \pi)`. This is the operator :math:`a \circleddash b` used in the RVC book + - If ``a`` and ``b`` are both scalars, the result is scalar - If ``a`` is array_like, the result is a NumPy array ``a[i]-b`` - If ``a`` is array_like, the result is a NumPy array ``a-b[i]`` @@ -651,6 +654,7 @@ def angdiff(a, b=None): - ``angdiff(a)`` is the angle or vector of angles ``a`` wrapped to the range :math:`[-\pi, \pi)`. + - If ``a`` is a scalar, the result is scalar - If ``a`` is array_like, the result is a NumPy array @@ -671,7 +675,7 @@ def angdiff(a, b=None): y = np.mod(a + math.pi, 2 * math.pi) - math.pi if isinstance(y, np.ndarray) and len(y) == 1: - return float(y) + return float(y[0]) else: return y diff --git a/spatialmath/quaternion.py b/spatialmath/quaternion.py index 51561036..87e25beb 100644 --- a/spatialmath/quaternion.py +++ b/spatialmath/quaternion.py @@ -13,6 +13,7 @@ :top-classes: collections.UserList :parts: 1 """ + # pylint: disable=invalid-name from __future__ import annotations import math @@ -78,7 +79,7 @@ def __init__(self, s: Any = None, v=None, check: Optional[bool] = True): super().__init__() if s is None and smb.isvector(v, 4): - v,s = (s,v) + v, s = (s, v) if v is None: # single argument @@ -982,10 +983,10 @@ def __init__( """ super().__init__() - # handle: UnitQuaternion(v)`` constructs a unit quaternion with specified elements + # handle: UnitQuaternion(v)`` constructs a unit quaternion with specified elements # from ``v`` which is a 4-vector given as a list, tuple, or ndarray(4) if s is None and smb.isvector(v, 4): - v,s = (s,v) + v, s = (s, v) if v is None: # single argument @@ -1225,7 +1226,9 @@ def Rz(cls, angles: ArrayLike, unit: Optional[str] = "rad") -> UnitQuaternion: ) @classmethod - def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str = "rad") -> UnitQuaternion: + def Rand( + cls, N: int = 1, *, theta_range: Optional[ArrayLike2] = None, unit: str = "rad" + ) -> UnitQuaternion: """ Construct a new random unit quaternion @@ -1252,7 +1255,10 @@ def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str :seealso: :meth:`UnitQuaternion.Rand` """ - return cls([smb.qrand(theta_range=theta_range, unit=unit) for i in range(0, N)], check=False) + return cls( + [smb.qrand(theta_range=theta_range, unit=unit) for i in range(0, N)], + check=False, + ) @classmethod def Eul(cls, *angles: List[float], unit: Optional[str] = "rad") -> UnitQuaternion: @@ -1411,7 +1417,7 @@ def AngVec( :seealso: :meth:`UnitQuaternion.angvec` :meth:`UnitQuaternion.exp` :func:`~spatialmath.base.transforms3d.angvec2r` """ v = smb.getvector(v, 3) - theta = smb.getunit(theta, unit, dim=0) + theta = smb.getunit(theta, unit, vector=False) return cls( s=math.cos(theta / 2), v=math.sin(theta / 2) * v, norm=False, check=False ) diff --git a/tests/base/test_argcheck.py b/tests/base/test_argcheck.py index 685393b5..39c943d1 100755 --- a/tests/base/test_argcheck.py +++ b/tests/base/test_argcheck.py @@ -122,11 +122,49 @@ def test_verifymatrix(self): verifymatrix(a, (3, 4)) def test_unit(self): - self.assertIsInstance(getunit(1), np.ndarray) + # scalar -> vector + self.assertEqual(getunit(1), np.array([1])) + self.assertEqual(getunit(1, dim=0), np.array([1])) + with self.assertRaises(ValueError): + self.assertEqual(getunit(1, dim=1), np.array([1])) + + self.assertEqual(getunit(1, unit="deg"), np.array([1 * math.pi / 180.0])) + self.assertEqual(getunit(1, dim=0, unit="deg"), np.array([1 * math.pi / 180.0])) + with self.assertRaises(ValueError): + self.assertEqual( + getunit(1, dim=1, unit="deg"), np.array([1 * math.pi / 180.0]) + ) + + # scalar -> scalar + self.assertEqual(getunit(1, vector=False), 1) + self.assertEqual(getunit(1, dim=0, vector=False), 1) + with self.assertRaises(ValueError): + self.assertEqual(getunit(1, dim=1, vector=False), 1) + + self.assertIsInstance(getunit(1.0, vector=False), float) + self.assertIsInstance(getunit(1, vector=False), int) + + self.assertEqual(getunit(1, vector=False, unit="deg"), 1 * math.pi / 180.0) + self.assertEqual( + getunit(1, dim=0, vector=False, unit="deg"), 1 * math.pi / 180.0 + ) + with self.assertRaises(ValueError): + self.assertEqual( + getunit(1, dim=1, vector=False, unit="deg"), 1 * math.pi / 180.0 + ) + + self.assertIsInstance(getunit(1.0, vector=False, unit="deg"), float) + self.assertIsInstance(getunit(1, vector=False, unit="deg"), float) + + # vector -> vector + self.assertEqual(getunit([1]), np.array([1])) + self.assertEqual(getunit([1], dim=1), np.array([1])) + with self.assertRaises(ValueError): + getunit([1], dim=0) + self.assertIsInstance(getunit([1, 2]), np.ndarray) self.assertIsInstance(getunit((1, 2)), np.ndarray) self.assertIsInstance(getunit(np.r_[1, 2]), np.ndarray) - self.assertIsInstance(getunit(1.0, dim=0), float) nt.assert_equal(getunit(5, "rad"), 5) nt.assert_equal(getunit(5, "deg"), 5 * math.pi / 180.0) diff --git a/tests/base/test_quaternions.py b/tests/base/test_quaternions.py index c512c6d2..0aed943d 100644 --- a/tests/base/test_quaternions.py +++ b/tests/base/test_quaternions.py @@ -227,14 +227,16 @@ def test_r2q(self): def test_qangle(self): # Test function that calculates angle between quaternions - q1 = [1., 0, 0, 0] - q2 = [1 / np.sqrt(2), 0, 1 / np.sqrt(2), 0] # 90deg rotation about y-axis + q1 = [1.0, 0, 0, 0] + q2 = [1 / np.sqrt(2), 0, 1 / np.sqrt(2), 0] # 90deg rotation about y-axis nt.assert_almost_equal(qangle(q1, q2), np.pi / 2) - q1 = [1., 0, 0, 0] - q2 = [1 / np.sqrt(2), 1 / np.sqrt(2), 0, 0] # 90deg rotation about x-axis + q1 = [1.0, 0, 0, 0] + q2 = [1 / np.sqrt(2), 1 / np.sqrt(2), 0, 0] # 90deg rotation about x-axis nt.assert_almost_equal(qangle(q1, q2), np.pi / 2) if __name__ == "__main__": + # run tests with warnings enabled + unittest.main()