From bc73531a7215234552dd4d74e16c55673bbfa7b2 Mon Sep 17 00:00:00 2001 From: Joel Croteau Date: Thu, 28 Jul 2022 23:13:45 -0700 Subject: [PATCH 1/2] Add option to PhysicsSensors to include angular velocities I find this useful for controlling handles and other objects which could potentially be rapidly spinning. It allows the model to make better predictions of object location. --- .../Sensors/ArticulationBodyPoseExtractor.cs | 6 ++ .../Runtime/Sensors/PhysicsSensorSettings.cs | 42 ++++++++++-- .../Runtime/Sensors/PoseExtractor.cs | 66 ++++++++++++++++++- .../Runtime/Sensors/RigidBodyPoseExtractor.cs | 21 +++++- 4 files changed, 124 insertions(+), 11 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs index ceec6e0013..65310a2757 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs @@ -63,6 +63,12 @@ protected internal override Vector3 GetLinearVelocityAt(int index) return m_Bodies[index].velocity; } + /// + protected internal override Vector3 GetAngularVelocityAt(int index) + { + return m_Bodies[index].angularVelocity; + } + /// protected internal override Pose GetPoseAt(int index) { diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs index d9f9c0d441..8984685f83 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs @@ -34,11 +34,21 @@ public struct PhysicsSensorSettings /// public bool UseModelSpaceLinearVelocity; + /// + /// Whether to use model space (relative to the root body) angular velocities as observations. + /// + public bool UseModelSpaceAngularVelocity; + /// /// Whether to use local space (relative to the parent body) linear velocities as observations. /// public bool UseLocalSpaceLinearVelocity; + /// + /// Whether to use local space (relative to the parent body) angular velocities as observations. + /// + public bool UseLocalSpaceAngularVelocity; + /// /// Whether to use joint-specific positions and angles as observations. /// @@ -67,7 +77,8 @@ public static PhysicsSensorSettings Default() /// public bool UseModelSpace { - get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; } + get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity || + UseModelSpaceAngularVelocity; } } /// @@ -75,7 +86,8 @@ public bool UseModelSpace /// public bool UseLocalSpace { - get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; } + get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity || + UseLocalSpaceAngularVelocity; } } } @@ -109,9 +121,18 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting } } - foreach (var vel in poseExtractor.GetEnabledModelSpaceVelocities()) + if (settings.UseModelSpaceLinearVelocity) { - if (settings.UseModelSpaceLinearVelocity) + foreach (var vel in poseExtractor.GetEnabledModelSpaceVelocities()) + { + writer.Add(vel, offset); + offset += 3; + } + } + + if (settings.UseModelSpaceAngularVelocity) + { + foreach (var vel in poseExtractor.GetEnabledModelSpaceAngularVelocities()) { writer.Add(vel, offset); offset += 3; @@ -136,9 +157,18 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting } } - foreach (var vel in poseExtractor.GetEnabledLocalSpaceVelocities()) + if (settings.UseLocalSpaceLinearVelocity) + { + foreach (var vel in poseExtractor.GetEnabledLocalSpaceVelocities()) + { + writer.Add(vel, offset); + offset += 3; + } + } + + if (settings.UseLocalSpaceAngularVelocity) { - if (settings.UseLocalSpaceLinearVelocity) + foreach (var vel in poseExtractor.GetEnabledLocalSpaceAngularVelocities()) { writer.Add(vel, offset); offset += 3; diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs index 059804377b..b673589ec5 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs @@ -21,7 +21,9 @@ public abstract class PoseExtractor Pose[] m_LocalSpacePoses; Vector3[] m_ModelSpaceLinearVelocities; + Vector3[] m_ModelSpaceAngularVelocities; Vector3[] m_LocalSpaceLinearVelocities; + Vector3[] m_LocalSpaceAngularVelocities; bool[] m_PoseEnabled; @@ -83,6 +85,25 @@ public IEnumerable GetEnabledModelSpaceVelocities() } } + /// + /// Read iterator for the enabled model space angular velocities. + /// + public IEnumerable GetEnabledModelSpaceAngularVelocities() + { + if (m_ModelSpaceAngularVelocities == null) + { + yield break; + } + + for (var i = 0; i < m_ModelSpaceAngularVelocities.Length; i++) + { + if (m_PoseEnabled[i]) + { + yield return m_ModelSpaceAngularVelocities[i]; + } + } + } + /// /// Read iterator for the enabled local space linear velocities. /// @@ -102,6 +123,25 @@ public IEnumerable GetEnabledLocalSpaceVelocities() } } + /// + /// Read iterator for the enabled local space angular velocities. + /// + public IEnumerable GetEnabledLocalSpaceAngularVelocities() + { + if (m_LocalSpaceAngularVelocities == null) + { + yield break; + } + + for (var i = 0; i < m_LocalSpaceAngularVelocities.Length; i++) + { + if (m_PoseEnabled[i]) + { + yield return m_LocalSpaceAngularVelocities[i]; + } + } + } + /// /// Number of enabled poses in the hierarchy (read-only). /// @@ -181,7 +221,9 @@ protected void Setup(int[] parentIndices) m_LocalSpacePoses = new Pose[numPoses]; m_ModelSpaceLinearVelocities = new Vector3[numPoses]; + m_ModelSpaceAngularVelocities = new Vector3[numPoses]; m_LocalSpaceLinearVelocities = new Vector3[numPoses]; + m_LocalSpaceAngularVelocities = new Vector3[numPoses]; m_PoseEnabled = new bool[numPoses]; // All poses are enabled by default. Generally we'll want to disable the root though. @@ -205,6 +247,13 @@ protected void Setup(int[] parentIndices) /// protected internal abstract Vector3 GetLinearVelocityAt(int index); + /// + /// Return the world space angular velocity of the i'th object. + /// + /// + /// + protected internal abstract Vector3 GetAngularVelocityAt(int index); + /// /// Return the underlying object at the given index. This is only /// used for display in the inspector. @@ -232,6 +281,7 @@ public void UpdateModelSpacePoses() var rootWorldTransform = GetPoseAt(0); var worldToModel = rootWorldTransform.Inverse(); var rootLinearVel = GetLinearVelocityAt(0); + var rootAngularVel = GetAngularVelocityAt(0); for (var i = 0; i < m_ModelSpacePoses.Length; i++) { @@ -240,8 +290,11 @@ public void UpdateModelSpacePoses() m_ModelSpacePoses[i] = currentModelSpacePose; var currentBodyLinearVel = GetLinearVelocityAt(i); - var relativeVelocity = currentBodyLinearVel - rootLinearVel; - m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity; + var relativeLinearVel = currentBodyLinearVel - rootLinearVel; + m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeLinearVel; + var currentBodyAngularVel = GetAngularVelocityAt(i); + var relativeAngularVel = currentBodyAngularVel - rootAngularVel; + m_ModelSpaceAngularVelocities[i] = worldToModel.rotation * relativeAngularVel; } } } @@ -272,11 +325,15 @@ public void UpdateLocalSpacePoses() var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]); var currentLinearVel = GetLinearVelocityAt(i); m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel); + var parentAngularVel = GetAngularVelocityAt(m_ParentIndices[i]); + var currentAngularVel = GetAngularVelocityAt(i); + m_LocalSpaceAngularVelocities[i] = invParent.rotation * (currentAngularVel - parentAngularVel); } else { m_LocalSpacePoses[i] = Pose.identity; m_LocalSpaceLinearVelocities[i] = Vector3.zero; + m_LocalSpaceAngularVelocities[i] = Vector3.zero; } } } @@ -296,7 +353,9 @@ public int GetNumPoseObservations(PhysicsSensorSettings settings) obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0; obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0; + obsPerPose += settings.UseModelSpaceAngularVelocity ? 3 : 0; obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0; + obsPerPose += settings.UseLocalSpaceAngularVelocity ? 3 : 0; return NumEnabledPoses * obsPerPose; } @@ -363,6 +422,7 @@ internal IList GetDisplayNodes() { return Array.Empty(); } + var nodesOut = new List(NumPoses); // List of children for each node @@ -379,6 +439,7 @@ internal IList GetDisplayNodes() { tree[parent] = new List(); } + tree[parent].Add(i); } @@ -422,7 +483,6 @@ internal IList GetDisplayNodes() return nodesOut; } - } /// diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs index b54b0b5713..d47301a5b2 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs @@ -3,7 +3,6 @@ namespace Unity.MLAgents.Extensions.Sensors { - /// /// Utility class to track a hierarchy of RigidBodies. These are assumed to have a root node, /// and child nodes are connect to their parents via Joints. @@ -129,9 +128,22 @@ protected internal override Vector3 GetLinearVelocityAt(int index) // No velocity on the virtual root return Vector3.zero; } + return m_Bodies[index].velocity; } + /// + protected internal override Vector3 GetAngularVelocityAt(int index) + { + if (index == 0 && m_VirtualRoot != null) + { + // No velocity on the virtual root + return Vector3.zero; + } + + return m_Bodies[index].angularVelocity; + } + /// protected internal override Pose GetPoseAt(int index) { @@ -156,6 +168,7 @@ protected internal override Object GetObjectAt(int index) { return m_VirtualRoot; } + return m_Bodies[index]; } @@ -167,6 +180,11 @@ protected internal override Object GetObjectAt(int index) /// internal Dictionary GetBodyPosesEnabled() { + if (m_Bodies == null) + { + return new Dictionary(); + } + var bodyPosesEnabled = new Dictionary(m_Bodies.Length); for (var i = 0; i < m_Bodies.Length; i++) { @@ -205,5 +223,4 @@ internal IEnumerable GetEnabledRigidbodies() } } } - } From ff2f7a2a77ec0cf75faf06cc86c6bbb6a0b80f73 Mon Sep 17 00:00:00 2001 From: Joel Croteau Date: Wed, 17 Aug 2022 14:35:54 -0700 Subject: [PATCH 2/2] Add stubs to PoseExtractorTests for angular velocity methods --- .../Tests/Runtime/Sensors/PoseExtractorTests.cs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/PoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/PoseExtractorTests.cs index 782b7da3a9..f7a7c32852 100644 --- a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/PoseExtractorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/PoseExtractorTests.cs @@ -19,6 +19,11 @@ protected internal override Vector3 GetLinearVelocityAt(int index) { return Vector3.zero; } + + protected internal override Vector3 GetAngularVelocityAt(int index) + { + return Vector3.zero; + } } class UselessPoseExtractor : BasicPoseExtractor @@ -114,6 +119,10 @@ protected internal override Vector3 GetLinearVelocityAt(int index) return Vector3.zero; } + protected internal override Vector3 GetAngularVelocityAt(int index) + { + return Vector3.zero; + } } [Test]