Skip to content

Issue 144 #152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 38 additions & 24 deletions spatialmath/base/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ def getvector(
dim: Optional[Union[int, None]] = None,
out: str = "array",
dtype: DTypeLike = np.float64,
) -> NDArray:
...
) -> NDArray: ...


@overload
Expand All @@ -293,8 +292,7 @@ def getvector(
dim: Optional[Union[int, None]] = None,
out: str = "list",
dtype: DTypeLike = np.float64,
) -> List[float]:
...
) -> List[float]: ...


@overload
Expand All @@ -303,8 +301,7 @@ def getvector(
dim: Optional[Union[int, None]] = None,
out: str = "sequence",
dtype: DTypeLike = np.float64,
) -> Tuple[float, ...]:
...
) -> Tuple[float, ...]: ...


@overload
Expand All @@ -313,8 +310,7 @@ def getvector(
dim: Optional[Union[int, None]] = None,
out: str = "sequence",
dtype: DTypeLike = np.float64,
) -> List[float]:
...
) -> List[float]: ...


def getvector(
Expand Down Expand Up @@ -522,16 +518,20 @@ def isvector(v: Any, dim: Optional[int] = None) -> bool:
return False


def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
def getunit(
v: ArrayLike, unit: str = "rad", dim: Optional[int] = None, vector: bool = True
) -> Union[float, NDArray]:
"""
Convert values according to angular units

:param v: the value in radians or degrees
:type v: array_like(m)
:param unit: the angular unit, "rad" or "deg"
:type unit: str
:param dim: expected dimension of input, defaults to None
:param dim: expected dimension of input, defaults to don't check (None)
:type dim: int, optional
:param vector: return a scalar as a 1d vector, defaults to True
:type vector: bool, optional
:return: the converted value in radians
:rtype: ndarray(m) or float
:raises ValueError: argument is not a valid angular unit
Expand All @@ -543,30 +543,44 @@ def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
>>> from spatialmath.base import getunit
>>> import numpy as np
>>> getunit(1.5, 'rad')
>>> getunit(1.5, 'rad', dim=0)
>>> # getunit([1.5], 'rad', dim=0) --> ValueError
>>> getunit(90, 'deg')
>>> getunit(90, 'deg', vector=False) # force a scalar output
>>> getunit(1.5, 'rad', dim=0) # check argument is scalar
>>> getunit(1.5, 'rad', dim=3) # check argument is a 3-vector
>>> getunit([1.5], 'rad', dim=1) # check argument is a 1-vector
>>> getunit([1.5], 'rad', dim=3) # check argument is a 3-vector
>>> getunit([90, 180], 'deg')
>>> getunit(np.r_[0.5, 1], 'rad')
>>> getunit(np.r_[90, 180], 'deg')
>>> getunit(np.r_[90, 180], 'deg', dim=2)
>>> # getunit([90, 180], 'deg', dim=3) --> ValueError
>>> getunit(np.r_[90, 180], 'deg', dim=2) # check argument is a 2-vector
>>> getunit([90, 180], 'deg', dim=3) # check argument is a 3-vector

:note:
- the input value is processed by :func:`getvector` and the argument ``dim`` can
be used to check that ``v`` is the desired length.
- the output is always an ndarray except if the input is a scalar and ``dim=0``.
be used to check that ``v`` is the desired length. Note that 0 means a scalar,
whereas 1 means a 1-element array.
- the output is always an ndarray except if the input is a scalar and ``vector=False``.

:seealso: :func:`getvector`
"""
if not isinstance(v, Iterable) and dim == 0:
# scalar in, scalar out
if unit == "rad":
return v
elif unit == "deg":
return np.deg2rad(v)
if not isinstance(v, Iterable):
# scalar input
if dim is not None and dim != 0:
raise ValueError("for dim==0 input must be a scalar")
if vector:
# scalar in, vector out
if unit == "deg":
v = np.deg2rad(v)
elif unit != "rad":
raise ValueError("invalid angular units")
return np.array([v])
else:
raise ValueError("invalid angular units")
# scalar in, scalar out
if unit == "rad":
return v
elif unit == "deg":
return np.deg2rad(v)
else:
raise ValueError("invalid angular units")

else:
# scalar or iterable in, ndarray out
Expand Down
52 changes: 23 additions & 29 deletions spatialmath/base/transforms2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def rot2(theta: float, unit: str = "rad") -> SO2Array:
>>> rot2(0.3)
>>> rot2(45, 'deg')
"""
theta = smb.getunit(theta, unit, dim=0)
theta = smb.getunit(theta, unit, vector=False)
ct = smb.sym.cos(theta)
st = smb.sym.sin(theta)
# fmt: off
Expand Down Expand Up @@ -172,18 +172,15 @@ def tr2xyt(T: SE2Array, unit: str = "rad") -> R3:

# ---------------------------------------------------------------------------------------#
@overload # pragma: no cover
def transl2(x: float, y: float) -> SE2Array:
...
def transl2(x: float, y: float) -> SE2Array: ...


@overload # pragma: no cover
def transl2(x: ArrayLike2) -> SE2Array:
...
def transl2(x: ArrayLike2) -> SE2Array: ...


@overload # pragma: no cover
def transl2(x: SE2Array) -> R2:
...
def transl2(x: SE2Array) -> R2: ...


def transl2(x, y=None):
Expand Down Expand Up @@ -446,8 +443,7 @@ def trlog2(
twist: bool = False,
check: bool = True,
tol: float = 20,
) -> so2Array:
...
) -> so2Array: ...


@overload # pragma: no cover
Expand All @@ -456,8 +452,7 @@ def trlog2(
twist: bool = False,
check: bool = True,
tol: float = 20,
) -> se2Array:
...
) -> se2Array: ...


@overload # pragma: no cover
Expand All @@ -466,8 +461,7 @@ def trlog2(
twist: bool = True,
check: bool = True,
tol: float = 20,
) -> float:
...
) -> float: ...


@overload # pragma: no cover
Expand All @@ -476,8 +470,7 @@ def trlog2(
twist: bool = True,
check: bool = True,
tol: float = 20,
) -> R3:
...
) -> R3: ...


def trlog2(
Expand Down Expand Up @@ -563,13 +556,15 @@ def trlog2(

# ---------------------------------------------------------------------------------------#
@overload # pragma: no cover
def trexp2(S: so2Array, theta: Optional[float] = None, check: bool = True) -> SO2Array:
...
def trexp2(
S: so2Array, theta: Optional[float] = None, check: bool = True
) -> SO2Array: ...


@overload # pragma: no cover
def trexp2(S: se2Array, theta: Optional[float] = None, check: bool = True) -> SE2Array:
...
def trexp2(
S: se2Array, theta: Optional[float] = None, check: bool = True
) -> SE2Array: ...


def trexp2(
Expand Down Expand Up @@ -692,8 +687,7 @@ def trexp2(


@overload # pragma: no cover
def trnorm2(R: SO2Array) -> SO2Array:
...
def trnorm2(R: SO2Array) -> SO2Array: ...


def trnorm2(T: SE2Array) -> SE2Array:
Expand Down Expand Up @@ -758,13 +752,11 @@ def trnorm2(T: SE2Array) -> SE2Array:


@overload # pragma: no cover
def tradjoint2(T: SO2Array) -> R1x1:
...
def tradjoint2(T: SO2Array) -> R1x1: ...


@overload # pragma: no cover
def tradjoint2(T: SE2Array) -> R3x3:
...
def tradjoint2(T: SE2Array) -> R3x3: ...


def tradjoint2(T):
Expand Down Expand Up @@ -853,13 +845,15 @@ def tr2jac2(T: SE2Array) -> R3x3:


@overload
def trinterp2(start: Optional[SO2Array], end: SO2Array, s: float, shortest: bool = True) -> SO2Array:
...
def trinterp2(
start: Optional[SO2Array], end: SO2Array, s: float, shortest: bool = True
) -> SO2Array: ...


@overload
def trinterp2(start: Optional[SE2Array], end: SE2Array, s: float, shortest: bool = True) -> SE2Array:
...
def trinterp2(
start: Optional[SE2Array], end: SE2Array, s: float, shortest: bool = True
) -> SE2Array: ...


def trinterp2(start, end, s, shortest: bool = True):
Expand Down
Loading
Loading