+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importtorch
+fromabcimportABC,abstractmethod
+fromcollections.abcimportSequence
+fromtypingimportTYPE_CHECKING
+
+fromomni.isaac.core.utils.typesimportArticulationActions
+
+importomni.isaac.lab.utils.stringasstring_utils
+
+ifTYPE_CHECKING:
+ from.actuator_cfgimportActuatorBaseCfg
+
+
+
[文档]classActuatorBase(ABC):
+"""Base class for actuator models over a collection of actuated joints in an articulation.
+
+ Actuator models augment the simulated articulation joints with an external drive dynamics model.
+ The model is used to convert the user-provided joint commands (positions, velocities and efforts)
+ into the desired joint positions, velocities and efforts that are applied to the simulated articulation.
+
+ The base class provides the interface for the actuator models. It is responsible for parsing the
+ actuator parameters from the configuration and storing them as buffers. It also provides the
+ interface for resetting the actuator state and computing the desired joint commands for the simulation.
+
+ For each actuator model, a corresponding configuration class is provided. The configuration class
+ is used to parse the actuator parameters from the configuration. It also specifies the joint names
+ for which the actuator model is applied. These names can be specified as regular expressions, which
+ are matched against the joint names in the articulation.
+
+ To see how the class is used, check the :class:`omni.isaac.lab.assets.Articulation` class.
+ """
+
+ computed_effort:torch.Tensor
+"""The computed effort for the actuator group. Shape is (num_envs, num_joints)."""
+ applied_effort:torch.Tensor
+"""The applied effort for the actuator group. Shape is (num_envs, num_joints)."""
+ effort_limit:torch.Tensor
+"""The effort limit for the actuator group. Shape is (num_envs, num_joints)."""
+ velocity_limit:torch.Tensor
+"""The velocity limit for the actuator group. Shape is (num_envs, num_joints)."""
+ stiffness:torch.Tensor
+"""The stiffness (P gain) of the PD controller. Shape is (num_envs, num_joints)."""
+ damping:torch.Tensor
+"""The damping (D gain) of the PD controller. Shape is (num_envs, num_joints)."""
+ armature:torch.Tensor
+"""The armature of the actuator joints. Shape is (num_envs, num_joints)."""
+ friction:torch.Tensor
+"""The joint friction of the actuator joints. Shape is (num_envs, num_joints)."""
+
+
[文档]def__init__(
+ self,
+ cfg:ActuatorBaseCfg,
+ joint_names:list[str],
+ joint_ids:slice|Sequence[int],
+ num_envs:int,
+ device:str,
+ stiffness:torch.Tensor|float=0.0,
+ damping:torch.Tensor|float=0.0,
+ armature:torch.Tensor|float=0.0,
+ friction:torch.Tensor|float=0.0,
+ effort_limit:torch.Tensor|float=torch.inf,
+ velocity_limit:torch.Tensor|float=torch.inf,
+ ):
+"""Initialize the actuator.
+
+ Note:
+ The actuator parameters are parsed from the configuration and stored as buffers. If the parameters
+ are not specified in the configuration, then the default values provided in the arguments are used.
+
+ Args:
+ cfg: The configuration of the actuator model.
+ joint_names: The joint names in the articulation.
+ joint_ids: The joint indices in the articulation. If :obj:`slice(None)`, then all
+ the joints in the articulation are part of the group.
+ num_envs: Number of articulations in the view.
+ device: Device used for processing.
+ stiffness: The default joint stiffness (P gain). Defaults to 0.0.
+ If a tensor, then the shape is (num_envs, num_joints).
+ damping: The default joint damping (D gain). Defaults to 0.0.
+ If a tensor, then the shape is (num_envs, num_joints).
+ armature: The default joint armature. Defaults to 0.0.
+ If a tensor, then the shape is (num_envs, num_joints).
+ friction: The default joint friction. Defaults to 0.0.
+ If a tensor, then the shape is (num_envs, num_joints).
+ effort_limit: The default effort limit. Defaults to infinity.
+ If a tensor, then the shape is (num_envs, num_joints).
+ velocity_limit: The default velocity limit. Defaults to infinity.
+ If a tensor, then the shape is (num_envs, num_joints).
+ """
+ # save parameters
+ self.cfg=cfg
+ self._num_envs=num_envs
+ self._device=device
+ self._joint_names=joint_names
+ self._joint_indices=joint_ids
+
+ # parse joint stiffness and damping
+ self.stiffness=self._parse_joint_parameter(self.cfg.stiffness,stiffness)
+ self.damping=self._parse_joint_parameter(self.cfg.damping,damping)
+ # parse joint armature and friction
+ self.armature=self._parse_joint_parameter(self.cfg.armature,armature)
+ self.friction=self._parse_joint_parameter(self.cfg.friction,friction)
+ # parse joint limits
+ # note: for velocity limits, we don't have USD parameter, so default is infinity
+ self.effort_limit=self._parse_joint_parameter(self.cfg.effort_limit,effort_limit)
+ self.velocity_limit=self._parse_joint_parameter(self.cfg.velocity_limit,velocity_limit)
+
+ # create commands buffers for allocation
+ self.computed_effort=torch.zeros(self._num_envs,self.num_joints,device=self._device)
+ self.applied_effort=torch.zeros_like(self.computed_effort)
+
+ def__str__(self)->str:
+"""Returns: A string representation of the actuator group."""
+ # resolve joint indices for printing
+ joint_indices=self.joint_indices
+ ifjoint_indices==slice(None):
+ joint_indices=list(range(self.num_joints))
+ return(
+ f"<class {self.__class__.__name__}> object:\n"
+ f"\tNumber of joints : {self.num_joints}\n"
+ f"\tJoint names expression: {self.cfg.joint_names_expr}\n"
+ f"\tJoint names : {self.joint_names}\n"
+ f"\tJoint indices : {joint_indices}\n"
+ )
+
+"""
+ Properties.
+ """
+
+ @property
+ defnum_joints(self)->int:
+"""Number of actuators in the group."""
+ returnlen(self._joint_names)
+
+ @property
+ defjoint_names(self)->list[str]:
+"""Articulation's joint names that are part of the group."""
+ returnself._joint_names
+
+ @property
+ defjoint_indices(self)->slice|Sequence[int]:
+"""Articulation's joint indices that are part of the group.
+
+ Note:
+ If :obj:`slice(None)` is returned, then the group contains all the joints in the articulation.
+ We do this to avoid unnecessary indexing of the joints for performance reasons.
+ """
+ returnself._joint_indices
+
+"""
+ Operations.
+ """
+
+
[文档]@abstractmethod
+ defreset(self,env_ids:Sequence[int]):
+"""Reset the internals within the group.
+
+ Args:
+ env_ids: List of environment IDs to reset.
+ """
+ raiseNotImplementedError
+
+
[文档]@abstractmethod
+ defcompute(
+ self,control_action:ArticulationActions,joint_pos:torch.Tensor,joint_vel:torch.Tensor
+ )->ArticulationActions:
+"""Process the actuator group actions and compute the articulation actions.
+
+ It computes the articulation actions based on the actuator model type
+
+ Args:
+ control_action: The joint action instance comprising of the desired joint positions, joint velocities
+ and (feed-forward) joint efforts.
+ joint_pos: The current joint positions of the joints in the group. Shape is (num_envs, num_joints).
+ joint_vel: The current joint velocities of the joints in the group. Shape is (num_envs, num_joints).
+
+ Returns:
+ The computed desired joint positions, joint velocities and joint efforts.
+ """
+ raiseNotImplementedError
+
+"""
+ Helper functions.
+ """
+
+ def_parse_joint_parameter(
+ self,cfg_value:float|dict[str,float]|None,default_value:float|torch.Tensor|None
+ )->torch.Tensor:
+"""Parse the joint parameter from the configuration.
+
+ Args:
+ cfg_value: The parameter value from the configuration. If None, then use the default value.
+ default_value: The default value to use if the parameter is None. If it is also None,
+ then an error is raised.
+
+ Returns:
+ The parsed parameter value.
+
+ Raises:
+ TypeError: If the parameter value is not of the expected type.
+ TypeError: If the default value is not of the expected type.
+ ValueError: If the parameter value is None and no default value is provided.
+ """
+ # create parameter buffer
+ param=torch.zeros(self._num_envs,self.num_joints,device=self._device)
+ # parse the parameter
+ ifcfg_valueisnotNone:
+ ifisinstance(cfg_value,(float,int)):
+ # if float, then use the same value for all joints
+ param[:]=float(cfg_value)
+ elifisinstance(cfg_value,dict):
+ # if dict, then parse the regular expression
+ indices,_,values=string_utils.resolve_matching_names_values(cfg_value,self.joint_names)
+ # note: need to specify type to be safe (e.g. values are ints, but we want floats)
+ param[:,indices]=torch.tensor(values,dtype=torch.float,device=self._device)
+ else:
+ raiseTypeError(f"Invalid type for parameter value: {type(cfg_value)}. Expected float or dict.")
+ elifdefault_valueisnotNone:
+ ifisinstance(default_value,(float,int)):
+ # if float, then use the same value for all joints
+ param[:]=float(default_value)
+ elifisinstance(default_value,torch.Tensor):
+ # if tensor, then use the same tensor for all joints
+ param[:]=default_value.float()
+ else:
+ raiseTypeError(f"Invalid type for default value: {type(default_value)}. Expected float or Tensor.")
+ else:
+ raiseValueError("The parameter value is None and no default value is provided.")
+
+ returnparam
+
+ def_clip_effort(self,effort:torch.Tensor)->torch.Tensor:
+"""Clip the desired torques based on the motor limits.
+
+ Args:
+ desired_torques: The desired torques to clip.
+
+ Returns:
+ The clipped torques.
+ """
+ returntorch.clip(effort,min=-self.effort_limit,max=self.effort_limit)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+fromcollections.abcimportIterable
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.importactuator_net,actuator_pd
+from.actuator_baseimportActuatorBase
+
+
+
[文档]@configclass
+classActuatorBaseCfg:
+"""Configuration for default actuators in an articulation."""
+
+ class_type:type[ActuatorBase]=MISSING
+"""The associated actuator class.
+
+ The class should inherit from :class:`omni.isaac.lab.actuators.ActuatorBase`.
+ """
+
+ joint_names_expr:list[str]=MISSING
+"""Articulation's joint names that are part of the group.
+
+ Note:
+ This can be a list of joint names or a list of regex expressions (e.g. ".*").
+ """
+
+ effort_limit:dict[str,float]|float|None=None
+"""Force/Torque limit of the joints in the group. Defaults to None.
+
+ If None, the limit is set to the value specified in the USD joint prim.
+ """
+
+ velocity_limit:dict[str,float]|float|None=None
+"""Velocity limit of the joints in the group. Defaults to None.
+
+ If None, the limit is set to the value specified in the USD joint prim.
+ """
+
+ stiffness:dict[str,float]|float|None=MISSING
+"""Stiffness gains (also known as p-gain) of the joints in the group.
+
+ If None, the stiffness is set to the value from the USD joint prim.
+ """
+
+ damping:dict[str,float]|float|None=MISSING
+"""Damping gains (also known as d-gain) of the joints in the group.
+
+ If None, the damping is set to the value from the USD joint prim.
+ """
+
+ armature:dict[str,float]|float|None=None
+"""Armature of the joints in the group. Defaults to None.
+
+ If None, the armature is set to the value from the USD joint prim.
+ """
+
+ friction:dict[str,float]|float|None=None
+"""Joint friction of the joints in the group. Defaults to None.
+
+ If None, the joint friction is set to the value from the USD joint prim.
+ """
+
+
+"""
+Implicit Actuator Models.
+"""
+
+
+
[文档]@configclass
+classImplicitActuatorCfg(ActuatorBaseCfg):
+"""Configuration for an implicit actuator.
+
+ Note:
+ The PD control is handled implicitly by the simulation.
+ """
+
+ class_type:type=actuator_pd.ImplicitActuator
+
+
+"""
+Explicit Actuator Models.
+"""
+
+
+
[文档]@configclass
+classIdealPDActuatorCfg(ActuatorBaseCfg):
+"""Configuration for an ideal PD actuator."""
+
+ class_type:type=actuator_pd.IdealPDActuator
+
+
+
[文档]@configclass
+classDCMotorCfg(IdealPDActuatorCfg):
+"""Configuration for direct control (DC) motor actuator model."""
+
+ class_type:type=actuator_pd.DCMotor
+
+ saturation_effort:float=MISSING
+"""Peak motor force/torque of the electric DC motor (in N-m)."""
+
+
+
[文档]@configclass
+classActuatorNetLSTMCfg(DCMotorCfg):
+"""Configuration for LSTM-based actuator model."""
+
+ class_type:type=actuator_net.ActuatorNetLSTM
+ # we don't use stiffness and damping for actuator net
+ stiffness=None
+ damping=None
+
+ network_file:str=MISSING
+"""Path to the file containing network weights."""
+
+
+
[文档]@configclass
+classActuatorNetMLPCfg(DCMotorCfg):
+"""Configuration for MLP-based actuator model."""
+
+ class_type:type=actuator_net.ActuatorNetMLP
+ # we don't use stiffness and damping for actuator net
+ stiffness=None
+ damping=None
+
+ network_file:str=MISSING
+"""Path to the file containing network weights."""
+
+ pos_scale:float=MISSING
+"""Scaling of the joint position errors input to the network."""
+ vel_scale:float=MISSING
+"""Scaling of the joint velocities input to the network."""
+ torque_scale:float=MISSING
+"""Scaling of the joint efforts output from the network."""
+
+ input_order:Literal["pos_vel","vel_pos"]=MISSING
+"""Order of the inputs to the network.
+
+ The order can be one of the following:
+
+ * ``"pos_vel"``: joint position errors followed by joint velocities
+ * ``"vel_pos"``: joint velocities followed by joint position errors
+ """
+
+ input_idx:Iterable[int]=MISSING
+"""
+ Indices of the actuator history buffer passed as inputs to the network.
+
+ The index *0* corresponds to current time-step, while *n* corresponds to n-th
+ time-step in the past. The allocated history length is `max(input_idx) + 1`.
+ """
+
+
+
[文档]@configclass
+classDelayedPDActuatorCfg(IdealPDActuatorCfg):
+"""Configuration for a delayed PD actuator."""
+
+ class_type:type=actuator_pd.DelayedPDActuator
+
+ min_delay:int=0
+"""Minimum number of physics time-steps with which the actuator command may be delayed. Defaults to 0."""
+
+ max_delay:int=0
+"""Maximum number of physics time-steps with which the actuator command may be delayed. Defaults to 0."""
+
+
+
[文档]@configclass
+classRemotizedPDActuatorCfg(DelayedPDActuatorCfg):
+"""Configuration for a remotized PD actuator.
+
+ Note:
+ The torque output limits for this actuator is derived from a linear interpolation of a lookup table
+ in :attr:`joint_parameter_lookup`. This table describes the relationship between joint angles and
+ the output torques.
+ """
+
+ class_type:type=actuator_pd.RemotizedPDActuator
+
+ joint_parameter_lookup:torch.Tensor=MISSING
+"""Joint parameter lookup table. Shape is (num_lookup_points, 3).
+
+ This tensor describes the relationship between the joint angle (rad), the transmission ratio (in/out),
+ and the output torque (N*m). The table is used to interpolate the output torque based on the joint angle.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Neural network models for actuators.
+
+Currently, the following models are supported:
+
+* Multi-Layer Perceptron (MLP)
+* Long Short-Term Memory (LSTM)
+
+"""
+
+from__future__importannotations
+
+importtorch
+fromcollections.abcimportSequence
+fromtypingimportTYPE_CHECKING
+
+fromomni.isaac.core.utils.typesimportArticulationActions
+
+fromomni.isaac.lab.utils.assetsimportread_file
+
+from.actuator_pdimportDCMotor
+
+ifTYPE_CHECKING:
+ from.actuator_cfgimportActuatorNetLSTMCfg,ActuatorNetMLPCfg
+
+
+
[文档]classActuatorNetLSTM(DCMotor):
+"""Actuator model based on recurrent neural network (LSTM).
+
+ Unlike the MLP implementation :cite:t:`hwangbo2019learning`, this class implements
+ the learned model as a temporal neural network (LSTM) based on the work from
+ :cite:t:`rudin2022learning`. This removes the need of storing a history as the
+ hidden states of the recurrent network captures the history.
+
+ Note:
+ Only the desired joint positions are used as inputs to the network.
+ """
+
+ cfg:ActuatorNetLSTMCfg
+"""The configuration of the actuator model."""
+
+
[文档]def__init__(self,cfg:ActuatorNetLSTMCfg,*args,**kwargs):
+ super().__init__(cfg,*args,**kwargs)
+
+ # load the model from JIT file
+ file_bytes=read_file(self.cfg.network_file)
+ self.network=torch.jit.load(file_bytes,map_location=self._device)
+
+ # extract number of lstm layers and hidden dim from the shape of weights
+ num_layers=len(self.network.lstm.state_dict())//4
+ hidden_dim=self.network.lstm.state_dict()["weight_hh_l0"].shape[1]
+ # create buffers for storing LSTM inputs
+ self.sea_input=torch.zeros(self._num_envs*self.num_joints,1,2,device=self._device)
+ self.sea_hidden_state=torch.zeros(
+ num_layers,self._num_envs*self.num_joints,hidden_dim,device=self._device
+ )
+ self.sea_cell_state=torch.zeros(num_layers,self._num_envs*self.num_joints,hidden_dim,device=self._device)
+ # reshape via views (doesn't change the actual memory layout)
+ layer_shape_per_env=(num_layers,self._num_envs,self.num_joints,hidden_dim)
+ self.sea_hidden_state_per_env=self.sea_hidden_state.view(layer_shape_per_env)
+ self.sea_cell_state_per_env=self.sea_cell_state.view(layer_shape_per_env)
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]):
+ # reset the hidden and cell states for the specified environments
+ withtorch.no_grad():
+ self.sea_hidden_state_per_env[:,env_ids]=0.0
+ self.sea_cell_state_per_env[:,env_ids]=0.0
+
+
[文档]defcompute(
+ self,control_action:ArticulationActions,joint_pos:torch.Tensor,joint_vel:torch.Tensor
+ )->ArticulationActions:
+ # compute network inputs
+ self.sea_input[:,0,0]=(control_action.joint_positions-joint_pos).flatten()
+ self.sea_input[:,0,1]=joint_vel.flatten()
+ # save current joint vel for dc-motor clipping
+ self._joint_vel[:]=joint_vel
+
+ # run network inference
+ withtorch.inference_mode():
+ torques,(self.sea_hidden_state[:],self.sea_cell_state[:])=self.network(
+ self.sea_input,(self.sea_hidden_state,self.sea_cell_state)
+ )
+ self.computed_effort=torques.reshape(self._num_envs,self.num_joints)
+
+ # clip the computed effort based on the motor limits
+ self.applied_effort=self._clip_effort(self.computed_effort)
+
+ # return torques
+ control_action.joint_efforts=self.applied_effort
+ control_action.joint_positions=None
+ control_action.joint_velocities=None
+ returncontrol_action
+
+
+
[文档]classActuatorNetMLP(DCMotor):
+"""Actuator model based on multi-layer perceptron and joint history.
+
+ Many times the analytical model is not sufficient to capture the actuator dynamics, the
+ delay in the actuator response, or the non-linearities in the actuator. In these cases,
+ a neural network model can be used to approximate the actuator dynamics. This model is
+ trained using data collected from the physical actuator and maps the joint state and the
+ desired joint command to the produced torque by the actuator.
+
+ This class implements the learned model as a neural network based on the work from
+ :cite:t:`hwangbo2019learning`. The class stores the history of the joint positions errors
+ and velocities which are used to provide input to the neural network. The model is loaded
+ as a TorchScript.
+
+ Note:
+ Only the desired joint positions are used as inputs to the network.
+
+ """
+
+ cfg:ActuatorNetMLPCfg
+"""The configuration of the actuator model."""
+
+
[文档]def__init__(self,cfg:ActuatorNetMLPCfg,*args,**kwargs):
+ super().__init__(cfg,*args,**kwargs)
+
+ # load the model from JIT file
+ file_bytes=read_file(self.cfg.network_file)
+ self.network=torch.jit.load(file_bytes,map_location=self._device)
+
+ # create buffers for MLP history
+ history_length=max(self.cfg.input_idx)+1
+ self._joint_pos_error_history=torch.zeros(
+ self._num_envs,history_length,self.num_joints,device=self._device
+ )
+ self._joint_vel_history=torch.zeros(self._num_envs,history_length,self.num_joints,device=self._device)
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]):
+ # reset the history for the specified environments
+ self._joint_pos_error_history[env_ids]=0.0
+ self._joint_vel_history[env_ids]=0.0
+
+
[文档]defcompute(
+ self,control_action:ArticulationActions,joint_pos:torch.Tensor,joint_vel:torch.Tensor
+ )->ArticulationActions:
+ # move history queue by 1 and update top of history
+ # -- positions
+ self._joint_pos_error_history=self._joint_pos_error_history.roll(1,1)
+ self._joint_pos_error_history[:,0]=control_action.joint_positions-joint_pos
+ # -- velocity
+ self._joint_vel_history=self._joint_vel_history.roll(1,1)
+ self._joint_vel_history[:,0]=joint_vel
+ # save current joint vel for dc-motor clipping
+ self._joint_vel[:]=joint_vel
+
+ # compute network inputs
+ # -- positions
+ pos_input=torch.cat([self._joint_pos_error_history[:,i].unsqueeze(2)foriinself.cfg.input_idx],dim=2)
+ pos_input=pos_input.view(self._num_envs*self.num_joints,-1)
+ # -- velocity
+ vel_input=torch.cat([self._joint_vel_history[:,i].unsqueeze(2)foriinself.cfg.input_idx],dim=2)
+ vel_input=vel_input.view(self._num_envs*self.num_joints,-1)
+ # -- scale and concatenate inputs
+ ifself.cfg.input_order=="pos_vel":
+ network_input=torch.cat([pos_input*self.cfg.pos_scale,vel_input*self.cfg.vel_scale],dim=1)
+ elifself.cfg.input_order=="vel_pos":
+ network_input=torch.cat([vel_input*self.cfg.vel_scale,pos_input*self.cfg.pos_scale],dim=1)
+ else:
+ raiseValueError(
+ f"Invalid input order for MLP actuator net: {self.cfg.input_order}. Must be 'pos_vel' or 'vel_pos'."
+ )
+
+ # run network inference
+ torques=self.network(network_input).view(self._num_envs,self.num_joints)
+ self.computed_effort=torques.view(self._num_envs,self.num_joints)*self.cfg.torque_scale
+
+ # clip the computed effort based on the motor limits
+ self.applied_effort=self._clip_effort(self.computed_effort)
+
+ # return torques
+ control_action.joint_efforts=self.applied_effort
+ control_action.joint_positions=None
+ control_action.joint_velocities=None
+ returncontrol_action
[文档]classImplicitActuator(ActuatorBase):
+"""Implicit actuator model that is handled by the simulation.
+
+ This performs a similar function as the :class:`IdealPDActuator` class. However, the PD control is handled
+ implicitly by the simulation which performs continuous-time integration of the PD control law. This is
+ generally more accurate than the explicit PD control law used in :class:`IdealPDActuator` when the simulation
+ time-step is large.
+
+ .. note::
+
+ The articulation class sets the stiffness and damping parameters from the configuration into the simulation.
+ Thus, the parameters are not used in this class.
+
+ .. caution::
+
+ The class is only provided for consistency with the other actuator models. It does not implement any
+ functionality and should not be used. All values should be set to the simulation directly.
+ """
+
+ cfg:ImplicitActuatorCfg
+"""The configuration for the actuator model."""
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,*args,**kwargs):
+ # This is a no-op. There is no state to reset for implicit actuators.
+ pass
+
+
[文档]defcompute(
+ self,control_action:ArticulationActions,joint_pos:torch.Tensor,joint_vel:torch.Tensor
+ )->ArticulationActions:
+"""Compute the aproximmate torques for the actuated joint (physX does not compute this explicitly)."""
+ # store approximate torques for reward computation
+ error_pos=control_action.joint_positions-joint_pos
+ error_vel=control_action.joint_velocities-joint_vel
+ self.computed_effort=self.stiffness*error_pos+self.damping*error_vel+control_action.joint_efforts
+ # clip the torques based on the motor limits
+ self.applied_effort=self._clip_effort(self.computed_effort)
+ returncontrol_action
+
+
+"""
+Explicit Actuator Models.
+"""
+
+
+
[文档]classIdealPDActuator(ActuatorBase):
+r"""Ideal torque-controlled actuator model with a simple saturation model.
+
+ It employs the following model for computing torques for the actuated joint :math:`j`:
+
+ .. math::
+
+ \tau_{j, computed} = k_p * (q - q_{des}) + k_d * (\dot{q} - \dot{q}_{des}) + \tau_{ff}
+
+ where, :math:`k_p` and :math:`k_d` are joint stiffness and damping gains, :math:`q` and :math:`\dot{q}`
+ are the current joint positions and velocities, :math:`q_{des}`, :math:`\dot{q}_{des}` and :math:`\tau_{ff}`
+ are the desired joint positions, velocities and torques commands.
+
+ The clipping model is based on the maximum torque applied by the motor. It is implemented as:
+
+ .. math::
+
+ \tau_{j, max} & = \gamma \times \tau_{motor, max} \\
+ \tau_{j, applied} & = clip(\tau_{computed}, -\tau_{j, max}, \tau_{j, max})
+
+ where the clipping function is defined as :math:`clip(x, x_{min}, x_{max}) = min(max(x, x_{min}), x_{max})`.
+ The parameters :math:`\gamma` is the gear ratio of the gear box connecting the motor and the actuated joint ends,
+ and :math:`\tau_{motor, max}` is the maximum motor effort possible. These parameters are read from
+ the configuration instance passed to the class.
+ """
+
+ cfg:IdealPDActuatorCfg
+"""The configuration for the actuator model."""
+
+"""
+ Operations.
+ """
+
+
[文档]defcompute(
+ self,control_action:ArticulationActions,joint_pos:torch.Tensor,joint_vel:torch.Tensor
+ )->ArticulationActions:
+ # compute errors
+ error_pos=control_action.joint_positions-joint_pos
+ error_vel=control_action.joint_velocities-joint_vel
+ # calculate the desired joint torques
+ self.computed_effort=self.stiffness*error_pos+self.damping*error_vel+control_action.joint_efforts
+ # clip the torques based on the motor limits
+ self.applied_effort=self._clip_effort(self.computed_effort)
+ # set the computed actions back into the control action
+ control_action.joint_efforts=self.applied_effort
+ control_action.joint_positions=None
+ control_action.joint_velocities=None
+ returncontrol_action
+
+
+
[文档]classDCMotor(IdealPDActuator):
+r"""Direct control (DC) motor actuator model with velocity-based saturation model.
+
+ It uses the same model as the :class:`IdealActuator` for computing the torques from input commands.
+ However, it implements a saturation model defined by DC motor characteristics.
+
+ A DC motor is a type of electric motor that is powered by direct current electricity. In most cases,
+ the motor is connected to a constant source of voltage supply, and the current is controlled by a rheostat.
+ Depending on various design factors such as windings and materials, the motor can draw a limited maximum power
+ from the electronic source, which limits the produced motor torque and speed.
+
+ A DC motor characteristics are defined by the following parameters:
+
+ * Continuous-rated speed (:math:`\dot{q}_{motor, max}`) : The maximum-rated speed of the motor.
+ * Continuous-stall torque (:math:`\tau_{motor, max}`): The maximum-rated torque produced at 0 speed.
+ * Saturation torque (:math:`\tau_{motor, sat}`): The maximum torque that can be outputted for a short period.
+
+ Based on these parameters, the instantaneous minimum and maximum torques are defined as follows:
+
+ .. math::
+
+ \tau_{j, max}(\dot{q}) & = clip \left (\tau_{j, sat} \times \left(1 -
+ \frac{\dot{q}}{\dot{q}_{j, max}}\right), 0.0, \tau_{j, max} \right) \\
+ \tau_{j, min}(\dot{q}) & = clip \left (\tau_{j, sat} \times \left( -1 -
+ \frac{\dot{q}}{\dot{q}_{j, max}}\right), - \tau_{j, max}, 0.0 \right)
+
+ where :math:`\gamma` is the gear ratio of the gear box connecting the motor and the actuated joint ends,
+ :math:`\dot{q}_{j, max} = \gamma^{-1} \times \dot{q}_{motor, max}`, :math:`\tau_{j, max} =
+ \gamma \times \tau_{motor, max}` and :math:`\tau_{j, peak} = \gamma \times \tau_{motor, peak}`
+ are the maximum joint velocity, maximum joint torque and peak torque, respectively. These parameters
+ are read from the configuration instance passed to the class.
+
+ Using these values, the computed torques are clipped to the minimum and maximum values based on the
+ instantaneous joint velocity:
+
+ .. math::
+
+ \tau_{j, applied} = clip(\tau_{computed}, \tau_{j, min}(\dot{q}), \tau_{j, max}(\dot{q}))
+
+ """
+
+ cfg:DCMotorCfg
+"""The configuration for the actuator model."""
+
+
[文档]def__init__(self,cfg:DCMotorCfg,*args,**kwargs):
+ super().__init__(cfg,*args,**kwargs)
+ # parse configuration
+ ifself.cfg.saturation_effortisnotNone:
+ self._saturation_effort=self.cfg.saturation_effort
+ else:
+ self._saturation_effort=torch.inf
+ # prepare joint vel buffer for max effort computation
+ self._joint_vel=torch.zeros_like(self.computed_effort)
+ # create buffer for zeros effort
+ self._zeros_effort=torch.zeros_like(self.computed_effort)
+ # check that quantities are provided
+ ifself.cfg.velocity_limitisNone:
+ raiseValueError("The velocity limit must be provided for the DC motor actuator model.")
+
+"""
+ Operations.
+ """
+
+
[文档]defcompute(
+ self,control_action:ArticulationActions,joint_pos:torch.Tensor,joint_vel:torch.Tensor
+ )->ArticulationActions:
+ # save current joint vel
+ self._joint_vel[:]=joint_vel
+ # calculate the desired joint torques
+ returnsuper().compute(control_action,joint_pos,joint_vel)
+
+"""
+ Helper functions.
+ """
+
+ def_clip_effort(self,effort:torch.Tensor)->torch.Tensor:
+ # compute torque limits
+ # -- max limit
+ max_effort=self._saturation_effort*(1.0-self._joint_vel/self.velocity_limit)
+ max_effort=torch.clip(max_effort,min=self._zeros_effort,max=self.effort_limit)
+ # -- min limit
+ min_effort=self._saturation_effort*(-1.0-self._joint_vel/self.velocity_limit)
+ min_effort=torch.clip(min_effort,min=-self.effort_limit,max=self._zeros_effort)
+
+ # clip the torques based on the motor limits
+ returntorch.clip(effort,min=min_effort,max=max_effort)
+
+
+
[文档]classDelayedPDActuator(IdealPDActuator):
+"""Ideal PD actuator with delayed command application.
+
+ This class extends the :class:`IdealPDActuator` class by adding a delay to the actuator commands. The delay
+ is implemented using a circular buffer that stores the actuator commands for a certain number of physics steps.
+ The most recent actuation value is pushed to the buffer at every physics step, but the final actuation value
+ applied to the simulation is lagged by a certain number of physics steps.
+
+ The amount of time lag is configurable and can be set to a random value between the minimum and maximum time
+ lag bounds at every reset. The minimum and maximum time lag values are set in the configuration instance passed
+ to the class.
+ """
+
+ cfg:DelayedPDActuatorCfg
+"""The configuration for the actuator model."""
+
+
[文档]def__init__(self,cfg:DelayedPDActuatorCfg,*args,**kwargs):
+ super().__init__(cfg,*args,**kwargs)
+ # instantiate the delay buffers
+ self.positions_delay_buffer=DelayBuffer(cfg.max_delay,self._num_envs,device=self._device)
+ self.velocities_delay_buffer=DelayBuffer(cfg.max_delay,self._num_envs,device=self._device)
+ self.efforts_delay_buffer=DelayBuffer(cfg.max_delay,self._num_envs,device=self._device)
+ # all of the envs
+ self._ALL_INDICES=torch.arange(self._num_envs,dtype=torch.long,device=self._device)
+
+
[文档]defreset(self,env_ids:Sequence[int]):
+ super().reset(env_ids)
+ # number of environments (since env_ids can be a slice)
+ ifenv_idsisNoneorenv_ids==slice(None):
+ num_envs=self._num_envs
+ else:
+ num_envs=len(env_ids)
+ # set a new random delay for environments in env_ids
+ time_lags=torch.randint(
+ low=self.cfg.min_delay,
+ high=self.cfg.max_delay+1,
+ size=(num_envs,),
+ dtype=torch.int,
+ device=self._device,
+ )
+ # set delays
+ self.positions_delay_buffer.set_time_lag(time_lags,env_ids)
+ self.velocities_delay_buffer.set_time_lag(time_lags,env_ids)
+ self.efforts_delay_buffer.set_time_lag(time_lags,env_ids)
+ # reset buffers
+ self.positions_delay_buffer.reset(env_ids)
+ self.velocities_delay_buffer.reset(env_ids)
+ self.efforts_delay_buffer.reset(env_ids)
+
+
[文档]defcompute(
+ self,control_action:ArticulationActions,joint_pos:torch.Tensor,joint_vel:torch.Tensor
+ )->ArticulationActions:
+ # apply delay based on the delay the model for all the setpoints
+ control_action.joint_positions=self.positions_delay_buffer.compute(control_action.joint_positions)
+ control_action.joint_velocities=self.velocities_delay_buffer.compute(control_action.joint_velocities)
+ control_action.joint_efforts=self.efforts_delay_buffer.compute(control_action.joint_efforts)
+ # compte actuator model
+ returnsuper().compute(control_action,joint_pos,joint_vel)
+
+
+
[文档]classRemotizedPDActuator(DelayedPDActuator):
+"""Ideal PD actuator with angle-dependent torque limits.
+
+ This class extends the :class:`DelayedPDActuator` class by adding angle-dependent torque limits to the actuator.
+ The torque limits are applied by querying a lookup table describing the relationship between the joint angle
+ and the maximum output torque. The lookup table is provided in the configuration instance passed to the class.
+
+ The torque limits are interpolated based on the current joint positions and applied to the actuator commands.
+ """
+
+
[文档]def__init__(
+ self,
+ cfg:RemotizedPDActuatorCfg,
+ joint_names:list[str],
+ joint_ids:Sequence[int],
+ num_envs:int,
+ device:str,
+ stiffness:torch.Tensor|float=0.0,
+ damping:torch.Tensor|float=0.0,
+ armature:torch.Tensor|float=0.0,
+ friction:torch.Tensor|float=0.0,
+ effort_limit:torch.Tensor|float=torch.inf,
+ velocity_limit:torch.Tensor|float=torch.inf,
+ ):
+ # remove effort and velocity box constraints from the base class
+ cfg.effort_limit=torch.inf
+ cfg.velocity_limit=torch.inf
+ # call the base method and set default effort_limit and velocity_limit to inf
+ super().__init__(
+ cfg,joint_names,joint_ids,num_envs,device,stiffness,damping,armature,friction,torch.inf,torch.inf
+ )
+ self._joint_parameter_lookup=cfg.joint_parameter_lookup.to(device=device)
+ # define remotized joint torque limit
+ self._torque_limit=LinearInterpolation(self.angle_samples,self.max_torque_samples,device=device)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-package with the utility class to configure the :class:`omni.isaac.kit.SimulationApp`.
+
+The :class:`AppLauncher` parses environment variables and input CLI arguments to launch the simulator in
+various different modes. This includes with or without GUI and switching between different Omniverse remote
+clients. Some of these require the extensions to be loaded in a specific order, otherwise a segmentation
+fault occurs. The launched :class:`omni.isaac.kit.SimulationApp` instance is accessible via the
+:attr:`AppLauncher.app` property.
+"""
+
+importargparse
+importcontextlib
+importos
+importre
+importsignal
+importsys
+fromtypingimportAny,Literal
+
+withcontextlib.suppress(ModuleNotFoundError):
+ importisaacsim# noqa: F401
+
+fromomni.isaac.kitimportSimulationApp
+
+
+
[文档]classAppLauncher:
+"""A utility class to launch Isaac Sim application based on command-line arguments and environment variables.
+
+ The class resolves the simulation app settings that appear through environments variables,
+ command-line arguments (CLI) or as input keyword arguments. Based on these settings, it launches the
+ simulation app and configures the extensions to load (as a part of post-launch setup).
+
+ The input arguments provided to the class are given higher priority than the values set
+ from the corresponding environment variables. This provides flexibility to deal with different
+ users' preferences.
+
+ .. note::
+ Explicitly defined arguments are only given priority when their value is set to something outside
+ their default configuration. For example, the ``livestream`` argument is -1 by default. It only
+ overrides the ``LIVESTREAM`` environment variable when ``livestream`` argument is set to a
+ value >-1. In other words, if ``livestream=-1``, then the value from the environment variable
+ ``LIVESTREAM`` is used.
+
+ """
+
+
[文档]def__init__(self,launcher_args:argparse.Namespace|dict|None=None,**kwargs):
+"""Create a `SimulationApp`_ instance based on the input settings.
+
+ Args:
+ launcher_args: Input arguments to parse using the AppLauncher and set into the SimulationApp.
+ Defaults to None, which is equivalent to passing an empty dictionary. A detailed description of
+ the possible arguments is available in the `SimulationApp`_ documentation.
+ **kwargs : Additional keyword arguments that will be merged into :attr:`launcher_args`.
+ They serve as a convenience for those who want to pass some arguments using the argparse
+ interface and others directly into the AppLauncher. Duplicated arguments with
+ the :attr:`launcher_args` will raise a ValueError.
+
+ Raises:
+ ValueError: If there are common/duplicated arguments between ``launcher_args`` and ``kwargs``.
+ ValueError: If combination of ``launcher_args`` and ``kwargs`` are missing the necessary arguments
+ that are needed by the AppLauncher to resolve the desired app configuration.
+ ValueError: If incompatible or undefined values are assigned to relevant environment values,
+ such as ``LIVESTREAM``.
+
+ .. _argparse.Namespace: https://docs.python.org/3/library/argparse.html?highlight=namespace#argparse.Namespace
+ .. _SimulationApp: https://docs.omniverse.nvidia.com/py/isaacsim/source/extensions/omni.isaac.kit/docs/index.html
+ """
+ # We allow users to pass either a dict or an argparse.Namespace into
+ # __init__, anticipating that these will be all of the argparse arguments
+ # used by the calling script. Those which we appended via add_app_launcher_args
+ # will be used to control extension loading logic. Additional arguments are allowed,
+ # and will be passed directly to the SimulationApp initialization.
+ #
+ # We could potentially require users to enter each argument they want passed here
+ # as a kwarg, but this would require them to pass livestream, headless, and
+ # any other options we choose to add here explicitly, and with the correct keywords.
+ #
+ # @hunter: I feel that this is cumbersome and could introduce error, and would prefer to do
+ # some sanity checking in the add_app_launcher_args function
+ iflauncher_argsisNone:
+ launcher_args={}
+ elifisinstance(launcher_args,argparse.Namespace):
+ launcher_args=launcher_args.__dict__
+
+ # Check that arguments are unique
+ iflen(kwargs)>0:
+ ifnotset(kwargs.keys()).isdisjoint(launcher_args.keys()):
+ overlapping_args=set(kwargs.keys()).intersection(launcher_args.keys())
+ raiseValueError(
+ f"Input `launcher_args` and `kwargs` both provided common attributes: {overlapping_args}."
+ " Please ensure that each argument is supplied to only one of them, as the AppLauncher cannot"
+ " discern priority between them."
+ )
+ launcher_args.update(kwargs)
+
+ # Define config members that are read from env-vars or keyword args
+ self._headless:bool# 0: GUI, 1: Headless
+ self._livestream:Literal[0,1,2]# 0: Disabled, 1: Native, 2: WebRTC
+ self._offscreen_render:bool# 0: Disabled, 1: Enabled
+ self._sim_experience_file:str# Experience file to load
+
+ # Exposed to train scripts
+ self.device_id:int# device ID for GPU simulation (defaults to 0)
+ self.local_rank:int# local rank of GPUs in the current node
+ self.global_rank:int# global rank for multi-node training
+
+ # Integrate env-vars and input keyword args into simulation app config
+ self._config_resolution(launcher_args)
+ # Create SimulationApp, passing the resolved self._config to it for initialization
+ self._create_app()
+ # Load IsaacSim extensions
+ self._load_extensions()
+ # Hide the stop button in the toolbar
+ self._hide_stop_button()
+
+ # Set up signal handlers for graceful shutdown
+ # -- during interrupts
+ signal.signal(signal.SIGINT,self._interrupt_signal_handle_callback)
+ # -- during explicit `kill` commands
+ signal.signal(signal.SIGTERM,self._abort_signal_handle_callback)
+ # -- during segfaults
+ signal.signal(signal.SIGABRT,self._abort_signal_handle_callback)
+ signal.signal(signal.SIGSEGV,self._abort_signal_handle_callback)
+
+"""
+ Properties.
+ """
+
+ @property
+ defapp(self)->SimulationApp:
+"""The launched SimulationApp."""
+ ifself._appisnotNone:
+ returnself._app
+ else:
+ raiseRuntimeError("The `AppLauncher.app` member cannot be retrieved until the class is initialized.")
+
+"""
+ Operations.
+ """
+
+
[文档]@staticmethod
+ defadd_app_launcher_args(parser:argparse.ArgumentParser)->None:
+"""Utility function to configure AppLauncher arguments with an existing argument parser object.
+
+ This function takes an ``argparse.ArgumentParser`` object and does some sanity checking on the existing
+ arguments for ingestion by the SimulationApp. It then appends custom command-line arguments relevant
+ to the SimulationApp to the input :class:`argparse.ArgumentParser` instance. This allows overriding the
+ environment variables using command-line arguments.
+
+ Currently, it adds the following parameters to the argparser object:
+
+ * ``headless`` (bool): If True, the app will be launched in headless (no-gui) mode. The values map the same
+ as that for the ``HEADLESS`` environment variable. If False, then headless mode is determined by the
+ ``HEADLESS`` environment variable.
+ * ``livestream`` (int): If one of {1, 2}, then livestreaming and headless mode is enabled. The values
+ map the same as that for the ``LIVESTREAM`` environment variable. If :obj:`-1`, then livestreaming is
+ determined by the ``LIVESTREAM`` environment variable.
+ Valid options are:
+
+ - ``0``: Disabled
+ - ``1``: `Native <https://docs.omniverse.nvidia.com/extensions/latest/ext_livestream/native.html>`_
+ - ``2``: `WebRTC <https://docs.omniverse.nvidia.com/extensions/latest/ext_livestream/webrtc.html>`_
+
+ * ``enable_cameras`` (bool): If True, the app will enable camera sensors and render them, even when in
+ headless mode. This flag must be set to True if the environments contains any camera sensors.
+ The values map the same as that for the ``ENABLE_CAMERAS`` environment variable.
+ If False, then enable_cameras mode is determined by the ``ENABLE_CAMERAS`` environment variable.
+ * ``device`` (str): The device to run the simulation on.
+ Valid options are:
+
+ - ``cpu``: Use CPU.
+ - ``cuda``: Use GPU with device ID ``0``.
+ - ``cuda:N``: Use GPU, where N is the device ID. For example, "cuda:0".
+
+ * ``experience`` (str): The experience file to load when launching the SimulationApp. If a relative path
+ is provided, it is resolved relative to the ``apps`` folder in Isaac Sim and Isaac Lab (in that order).
+
+ If provided as an empty string, the experience file is determined based on the headless flag:
+
+ * If headless and enable_cameras are True, the experience file is set to ``isaaclab.python.headless.rendering.kit``.
+ * If headless is False and enable_cameras is True, the experience file is set to ``isaaclab.python.rendering.kit``.
+ * If headless is False and enable_cameras is False, the experience file is set to ``isaaclab.python.kit``.
+ * If headless is True and enable_cameras is False, the experience file is set to ``isaaclab.python.headless.kit``.
+
+ Args:
+ parser: An argument parser instance to be extended with the AppLauncher specific options.
+ """
+ # If the passed parser has an existing _HelpAction when passed,
+ # we here remove the options which would invoke it,
+ # to be added back after the additional AppLauncher args
+ # have been added. This is equivalent to
+ # initially constructing the ArgParser with add_help=False,
+ # but this means we don't have to require that behavior
+ # in users and can handle it on our end.
+ # We do this because calling parse_known_args() will handle
+ # any -h/--help options being passed and then exit immediately,
+ # before the additional arguments can be added to the help readout.
+ parser_help=None
+ iflen(parser._actions)>0andisinstance(parser._actions[0],argparse._HelpAction):# type: ignore
+ parser_help=parser._actions[0]
+ parser._option_string_actions.pop("-h")
+ parser._option_string_actions.pop("--help")
+
+ # Parse known args for potential name collisions/type mismatches
+ # between the config fields SimulationApp expects and the ArgParse
+ # arguments that the user passed.
+ known,_=parser.parse_known_args()
+ config=vars(known)
+ iflen(config)==0:
+ print(
+ "[WARN][AppLauncher]: There are no arguments attached to the ArgumentParser object."
+ " If you have your own arguments, please load your own arguments before calling the"
+ " `AppLauncher.add_app_launcher_args` method. This allows the method to check the validity"
+ " of the arguments and perform checks for argument names."
+ )
+ else:
+ AppLauncher._check_argparser_config_params(config)
+
+ # Add custom arguments to the parser
+ arg_group=parser.add_argument_group(
+ "app_launcher arguments",
+ description="Arguments for the AppLauncher. For more details, please check the documentation.",
+ )
+ arg_group.add_argument(
+ "--headless",
+ action="store_true",
+ default=AppLauncher._APPLAUNCHER_CFG_INFO["headless"][1],
+ help="Force display off at all times.",
+ )
+ arg_group.add_argument(
+ "--livestream",
+ type=int,
+ default=AppLauncher._APPLAUNCHER_CFG_INFO["livestream"][1],
+ choices={0,1,2},
+ help="Force enable livestreaming. Mapping corresponds to that for the `LIVESTREAM` environment variable.",
+ )
+ arg_group.add_argument(
+ "--enable_cameras",
+ action="store_true",
+ default=AppLauncher._APPLAUNCHER_CFG_INFO["enable_cameras"][1],
+ help="Enable camera sensors and relevant extension dependencies.",
+ )
+ arg_group.add_argument(
+ "--device",
+ type=str,
+ default=AppLauncher._APPLAUNCHER_CFG_INFO["device"][1],
+ help='The device to run the simulation on. Can be "cpu", "cuda", "cuda:N", where N is the device ID',
+ )
+ # Add the deprecated cpu flag to raise an error if it is used
+ arg_group.add_argument("--cpu",action="store_true",help=argparse.SUPPRESS)
+ arg_group.add_argument(
+ "--verbose",# Note: This is read by SimulationApp through sys.argv
+ action="store_true",
+ help="Enable verbose terminal output from the SimulationApp.",
+ )
+ arg_group.add_argument(
+ "--experience",
+ type=str,
+ default="",
+ help=(
+ "The experience file to load when launching the SimulationApp. If an empty string is provided,"
+ " the experience file is determined based on the headless flag. If a relative path is provided,"
+ " it is resolved relative to the `apps` folder in Isaac Sim and Isaac Lab (in that order)."
+ ),
+ )
+
+ # Corresponding to the beginning of the function,
+ # if we have removed -h/--help handling, we add it back.
+ ifparser_helpisnotNone:
+ parser._option_string_actions["-h"]=parser_help
+ parser._option_string_actions["--help"]=parser_help
+
+"""
+ Internal functions.
+ """
+
+ _APPLAUNCHER_CFG_INFO:dict[str,tuple[list[type],Any]]={
+ "headless":([bool],False),
+ "livestream":([int],-1),
+ "enable_cameras":([bool],False),
+ "device":([str],"cuda:0"),
+ "experience":([str],""),
+ }
+"""A dictionary of arguments added manually by the :meth:`AppLauncher.add_app_launcher_args` method.
+
+ The values are a tuple of the expected type and default value. This is used to check against name collisions
+ for arguments passed to the :class:`AppLauncher` class as well as for type checking.
+
+ They have corresponding environment variables as detailed in the documentation.
+ """
+
+ # TODO: Find some internally managed NVIDIA list of these types.
+ # SimulationApp.DEFAULT_LAUNCHER_CONFIG almost works, except that
+ # it is ambiguous where the default types are None
+ _SIM_APP_CFG_TYPES:dict[str,list[type]]={
+ "headless":[bool],
+ "hide_ui":[bool,type(None)],
+ "active_gpu":[int,type(None)],
+ "physics_gpu":[int],
+ "multi_gpu":[bool],
+ "sync_loads":[bool],
+ "width":[int],
+ "height":[int],
+ "window_width":[int],
+ "window_height":[int],
+ "display_options":[int],
+ "subdiv_refinement_level":[int],
+ "renderer":[str],
+ "anti_aliasing":[int],
+ "samples_per_pixel_per_frame":[int],
+ "denoiser":[bool],
+ "max_bounces":[int],
+ "max_specular_transmission_bounces":[int],
+ "max_volume_bounces":[int],
+ "open_usd":[str,type(None)],
+ "livesync_usd":[str,type(None)],
+ "fast_shutdown":[bool],
+ "experience":[str],
+ }
+"""A dictionary containing the type of arguments passed to SimulationApp.
+
+ This is used to check against name collisions for arguments passed to the :class:`AppLauncher` class
+ as well as for type checking. It corresponds closely to the :attr:`SimulationApp.DEFAULT_LAUNCHER_CONFIG`,
+ but specifically denotes where None types are allowed.
+ """
+
+ @staticmethod
+ def_check_argparser_config_params(config:dict)->None:
+"""Checks that input argparser object has parameters with valid settings with no name conflicts.
+
+ First, we inspect the dictionary to ensure that the passed ArgParser object is not attempting to add arguments
+ which should be assigned by calling :meth:`AppLauncher.add_app_launcher_args`.
+
+ Then, we check that if the key corresponds to a config setting expected by SimulationApp, then the type of
+ that key's value corresponds to the type expected by the SimulationApp. If it passes the check, the function
+ prints out that the setting with be passed to the SimulationApp. Otherwise, we raise a ValueError exception.
+
+ Args:
+ config: A configuration parameters which will be passed to the SimulationApp constructor.
+
+ Raises:
+ ValueError: If a key is an already existing field in the configuration parameters but
+ should be added by calling the :meth:`AppLauncher.add_app_launcher_args.
+ ValueError: If keys corresponding to those used to initialize SimulationApp
+ (as found in :attr:`_SIM_APP_CFG_TYPES`) are of the wrong value type.
+ """
+ # check that no config key conflicts with AppLauncher config names
+ applauncher_keys=set(AppLauncher._APPLAUNCHER_CFG_INFO.keys())
+ forkey,valueinconfig.items():
+ ifkeyinapplauncher_keys:
+ raiseValueError(
+ f"The passed ArgParser object already has the field '{key}'. This field will be added by"
+ " `AppLauncher.add_app_launcher_args()`, and should not be added directly. Please remove the"
+ " argument or rename it to a non-conflicting name."
+ )
+ # check that type of the passed keys are valid
+ simulationapp_keys=set(AppLauncher._SIM_APP_CFG_TYPES.keys())
+ forkey,valueinconfig.items():
+ ifkeyinsimulationapp_keys:
+ given_type=type(value)
+ expected_types=AppLauncher._SIM_APP_CFG_TYPES[key]
+ iftype(value)notinset(expected_types):
+ raiseValueError(
+ f"Invalid value type for the argument '{key}': {given_type}. Expected one of {expected_types},"
+ " if intended to be ingested by the SimulationApp object. Please change the type if this"
+ " intended for the SimulationApp or change the name of the argument to avoid name conflicts."
+ )
+ # Print out values which will be used
+ print(f"[INFO][AppLauncher]: The argument '{key}' will be used to configure the SimulationApp.")
+
+ def_config_resolution(self,launcher_args:dict):
+"""Resolve the input arguments and environment variables.
+
+ Args:
+ launcher_args: A dictionary of all input arguments passed to the class object.
+ """
+ # Handle all control logic resolution
+
+ # --LIVESTREAM logic--
+ #
+ livestream_env=int(os.environ.get("LIVESTREAM",0))
+ livestream_arg=launcher_args.pop("livestream",AppLauncher._APPLAUNCHER_CFG_INFO["livestream"][1])
+ livestream_valid_vals={0,1,2}
+ # Value checking on LIVESTREAM
+ iflivestream_envnotinlivestream_valid_vals:
+ raiseValueError(
+ f"Invalid value for environment variable `LIVESTREAM`: {livestream_env} ."
+ f" Expected: {livestream_valid_vals}."
+ )
+ # We allow livestream kwarg to supersede LIVESTREAM envvar
+ iflivestream_arg>=0:
+ iflivestream_arginlivestream_valid_vals:
+ self._livestream=livestream_arg
+ # print info that we overrode the env-var
+ print(
+ f"[INFO][AppLauncher]: Input keyword argument `livestream={livestream_arg}` has overridden"
+ f" the environment variable `LIVESTREAM={livestream_env}`."
+ )
+ else:
+ raiseValueError(
+ f"Invalid value for input keyword argument `livestream`: {livestream_arg} ."
+ f" Expected: {livestream_valid_vals}."
+ )
+ else:
+ self._livestream=livestream_env
+
+ # --HEADLESS logic--
+ #
+ # Resolve headless execution of simulation app
+ # HEADLESS is initially passed as an int instead of
+ # the bool of headless_arg to avoid messy string processing,
+ headless_env=int(os.environ.get("HEADLESS",0))
+ headless_arg=launcher_args.pop("headless",AppLauncher._APPLAUNCHER_CFG_INFO["headless"][1])
+ headless_valid_vals={0,1}
+ # Value checking on HEADLESS
+ ifheadless_envnotinheadless_valid_vals:
+ raiseValueError(
+ f"Invalid value for environment variable `HEADLESS`: {headless_env} . Expected: {headless_valid_vals}."
+ )
+ # We allow headless kwarg to supersede HEADLESS envvar if headless_arg does not have the default value
+ # Note: Headless is always true when livestreaming
+ ifheadless_argisTrue:
+ self._headless=headless_arg
+ elifself._livestreamin{1,2}:
+ # we are always headless on the host machine
+ self._headless=True
+ # inform who has toggled the headless flag
+ ifself._livestream==livestream_arg:
+ print(
+ f"[INFO][AppLauncher]: Input keyword argument `livestream={self._livestream}` has implicitly"
+ f" overridden the environment variable `HEADLESS={headless_env}` to True."
+ )
+ elifself._livestream==livestream_env:
+ print(
+ f"[INFO][AppLauncher]: Environment variable `LIVESTREAM={self._livestream}` has implicitly"
+ f" overridden the environment variable `HEADLESS={headless_env}` to True."
+ )
+ else:
+ # Headless needs to be a bool to be ingested by SimulationApp
+ self._headless=bool(headless_env)
+ # Headless needs to be passed to the SimulationApp so we keep it here
+ launcher_args["headless"]=self._headless
+
+ # --enable_cameras logic--
+ #
+ enable_cameras_env=int(os.environ.get("ENABLE_CAMERAS",0))
+ enable_cameras_arg=launcher_args.pop("enable_cameras",AppLauncher._APPLAUNCHER_CFG_INFO["enable_cameras"][1])
+ enable_cameras_valid_vals={0,1}
+ ifenable_cameras_envnotinenable_cameras_valid_vals:
+ raiseValueError(
+ f"Invalid value for environment variable `ENABLE_CAMERAS`: {enable_cameras_env} ."
+ f"Expected: {enable_cameras_valid_vals} ."
+ )
+ # We allow enable_cameras kwarg to supersede ENABLE_CAMERAS envvar
+ ifenable_cameras_argisTrue:
+ self._enable_cameras=enable_cameras_arg
+ else:
+ self._enable_cameras=bool(enable_cameras_env)
+ self._offscreen_render=False
+ ifself._enable_camerasandself._headless:
+ self._offscreen_render=True
+
+ # Check if we can disable the viewport to improve performance
+ # This should only happen if we are running headless and do not require livestreaming or video recording
+ # This is different from offscreen_render because this only affects the default viewport and not other renderproducts in the scene
+ self._render_viewport=True
+ ifself._headlessandnotself._livestreamandnotlauncher_args.get("video",False):
+ self._render_viewport=False
+
+ # hide_ui flag
+ launcher_args["hide_ui"]=False
+ ifself._headlessandnotself._livestream:
+ launcher_args["hide_ui"]=True
+
+ # --simulation GPU device logic --
+ self.device_id=0
+ device=launcher_args.get("device",AppLauncher._APPLAUNCHER_CFG_INFO["device"][1])
+ if"cuda"notindeviceand"cpu"notindevice:
+ raiseValueError(
+ f"Invalid value for input keyword argument `device`: {device}."
+ " Expected: a string with the format 'cuda', 'cuda:<device_id>', or 'cpu'."
+ )
+ if"cuda:"indevice:
+ self.device_id=int(device.split(":")[-1])
+
+ # Raise an error for the deprecated cpu flag
+ iflauncher_args.get("cpu",False):
+ raiseValueError("The `--cpu` flag is deprecated. Please use `--device cpu` instead.")
+
+ if"distributed"inlauncher_argsandlauncher_args["distributed"]:
+ # local rank (GPU id) in a current multi-gpu mode
+ self.local_rank=int(os.getenv("LOCAL_RANK","0"))+int(os.getenv("JAX_LOCAL_RANK","0"))
+ # global rank (GPU id) in multi-gpu multi-node mode
+ self.global_rank=int(os.getenv("RANK","0"))+int(os.getenv("JAX_RANK","0"))
+
+ self.device_id=self.local_rank
+ launcher_args["multi_gpu"]=False
+ # limit CPU threads to minimize thread context switching
+ # this ensures processes do not take up all available threads and fight for resources
+ num_cpu_cores=os.cpu_count()
+ num_threads_per_process=num_cpu_cores//int(os.getenv("WORLD_SIZE",1))
+ # set environment variables to limit CPU threads
+ os.environ["PXR_WORK_THREAD_LIMIT"]=str(num_threads_per_process)
+ os.environ["OPENBLAS_NUM_THREADS"]=str(num_threads_per_process)
+ # pass command line variable to kit
+ sys.argv.append(f"--/plugins/carb.tasking.plugin/threadCount={num_threads_per_process}")
+
+ # set physics and rendering device
+ launcher_args["physics_gpu"]=self.device_id
+ launcher_args["active_gpu"]=self.device_id
+
+ # Check if input keywords contain an 'experience' file setting
+ # Note: since experience is taken as a separate argument by Simulation App, we store it separately
+ self._sim_experience_file=launcher_args.pop("experience","")
+
+ # If nothing is provided resolve the experience file based on the headless flag
+ kit_app_exp_path=os.environ["EXP_PATH"]
+ isaaclab_app_exp_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),*[".."]*6,"apps")
+ ifself._sim_experience_file=="":
+ # check if the headless flag is setS
+ ifself._enable_cameras:
+ ifself._headlessandnotself._livestream:
+ self._sim_experience_file=os.path.join(
+ isaaclab_app_exp_path,"isaaclab.python.headless.rendering.kit"
+ )
+ else:
+ self._sim_experience_file=os.path.join(isaaclab_app_exp_path,"isaaclab.python.rendering.kit")
+ elifself._headlessandnotself._livestream:
+ self._sim_experience_file=os.path.join(isaaclab_app_exp_path,"isaaclab.python.headless.kit")
+ else:
+ self._sim_experience_file=os.path.join(isaaclab_app_exp_path,"isaaclab.python.kit")
+ elifnotos.path.isabs(self._sim_experience_file):
+ option_1_app_exp_path=os.path.join(kit_app_exp_path,self._sim_experience_file)
+ option_2_app_exp_path=os.path.join(isaaclab_app_exp_path,self._sim_experience_file)
+ ifos.path.exists(option_1_app_exp_path):
+ self._sim_experience_file=option_1_app_exp_path
+ elifos.path.exists(option_2_app_exp_path):
+ self._sim_experience_file=option_2_app_exp_path
+ else:
+ raiseFileNotFoundError(
+ f"Invalid value for input keyword argument `experience`: {self._sim_experience_file}."
+ "\n No such file exists in either the Kit or Isaac Lab experience paths. Checked paths:"
+ f"\n\t [1]: {option_1_app_exp_path}"
+ f"\n\t [2]: {option_2_app_exp_path}"
+ )
+ elifnotos.path.exists(self._sim_experience_file):
+ raiseFileNotFoundError(
+ f"Invalid value for input keyword argument `experience`: {self._sim_experience_file}."
+ " The file does not exist."
+ )
+
+ print(f"[INFO][AppLauncher]: Loading experience file: {self._sim_experience_file}")
+ # Remove all values from input keyword args which are not meant for SimulationApp
+ # Assign all the passed settings to a dictionary for the simulation app
+ self._sim_app_config={
+ key:launcher_args[key]forkeyinset(AppLauncher._SIM_APP_CFG_TYPES.keys())&set(launcher_args.keys())
+ }
+
+ def_create_app(self):
+"""Launch and create the SimulationApp based on the parsed simulation config."""
+ # Initialize SimulationApp
+ # hack sys module to make sure that the SimulationApp is initialized correctly
+ # this is to avoid the warnings from the simulation app about not ok modules
+ r=re.compile(".*lab.*")
+ found_modules=list(filter(r.match,list(sys.modules.keys())))
+ found_modules+=["omni.isaac.kit.app_framework"]
+ # remove Isaac Lab modules from sys.modules
+ hacked_modules=dict()
+ forkeyinfound_modules:
+ hacked_modules[key]=sys.modules[key]
+ delsys.modules[key]
+ # launch simulation app
+ self._app=SimulationApp(self._sim_app_config,experience=self._sim_experience_file)
+ # add Isaac Lab modules back to sys.modules
+ forkey,valueinhacked_modules.items():
+ sys.modules[key]=value
+ # remove the threadCount argument from sys.argv if it was added for distributed training
+ pattern=r"--/plugins/carb\.tasking\.plugin/threadCount=\d+"
+ sys.argv=[argforarginsys.argvifnotre.match(pattern,arg)]
+
+ def_rendering_enabled(self)->bool:
+"""Check if rendering is required by the app."""
+ # Indicates whether rendering is required by the app.
+ # Extensions required for rendering bring startup and simulation costs, so we do not enable them if not required.
+ returnnotself._headlessorself._livestream>=1orself._enable_cameras
+
+ def_load_extensions(self):
+"""Load correct extensions based on AppLauncher's resolved config member variables."""
+ # These have to be loaded after SimulationApp is initialized
+ importcarb
+ importomni.physx.bindings._physxasphysx_impl
+ fromomni.isaac.core.utils.extensionsimportenable_extension
+
+ # Retrieve carb settings for modification
+ carb_settings_iface=carb.settings.get_settings()
+
+ ifself._livestream>=1:
+ # Ensure that a viewport exists in case an experience has been
+ # loaded which does not load it by default
+ enable_extension("omni.kit.viewport.window")
+ # Set carb settings to allow for livestreaming
+ carb_settings_iface.set_bool("/app/livestream/enabled",True)
+ carb_settings_iface.set_bool("/app/window/drawMouse",True)
+ carb_settings_iface.set_bool("/ngx/enabled",False)
+ carb_settings_iface.set_string("/app/livestream/proto","ws")
+ carb_settings_iface.set_int("/app/livestream/websocket/framerate_limit",120)
+ # Note: Only one livestream extension can be enabled at a time
+ ifself._livestream==1:
+ # Enable Native Livestream extension
+ # Default App: Streaming Client from the Omniverse Launcher
+ enable_extension("omni.kit.streamsdk.plugins-3.2.1")
+ enable_extension("omni.kit.livestream.core-3.2.0")
+ enable_extension("omni.kit.livestream.native-4.1.0")
+ elifself._livestream==2:
+ # Enable WebRTC Livestream extension
+ # Default URL: http://localhost:8211/streaming/webrtc-client/
+ enable_extension("omni.services.streamclient.webrtc")
+ else:
+ raiseValueError(f"Invalid value for livestream: {self._livestream}. Expected: 1, 2 .")
+ else:
+ carb_settings_iface.set_bool("/app/livestream/enabled",False)
+
+ # set carb setting to indicate Isaac Lab's offscreen_render pipeline should be enabled
+ # this flag is used by the SimulationContext class to enable the offscreen_render pipeline
+ # when the render() method is called.
+ carb_settings_iface.set_bool("/isaaclab/render/offscreen",self._offscreen_render)
+
+ # set carb setting to indicate Isaac Lab's render_viewport pipeline should be enabled
+ # this flag is used by the SimulationContext class to enable the render_viewport pipeline
+ # when the render() method is called.
+ carb_settings_iface.set_bool("/isaaclab/render/active_viewport",self._render_viewport)
+
+ # set carb setting to indicate no RTX sensors are used
+ # this flag is set to True when an RTX-rendering related sensor is created
+ # for example: the `Camera` sensor class
+ carb_settings_iface.set_bool("/isaaclab/render/rtx_sensors",False)
+
+ # set fabric update flag to disable updating transforms when rendering is disabled
+ carb_settings_iface.set_bool("/physics/fabricUpdateTransformations",self._rendering_enabled())
+
+ # set the nucleus directory manually to the latest published Nucleus
+ # note: this is done to ensure prior versions of Isaac Sim still use the latest assets
+ assets_path="http://omniverse-content-production.s3-us-west-2.amazonaws.com/Assets/Isaac/4.2"
+ carb_settings_iface.set_string("/persistent/isaac/asset_root/default",assets_path)
+ carb_settings_iface.set_string("/persistent/isaac/asset_root/cloud",assets_path)
+ carb_settings_iface.set_string("/persistent/isaac/asset_root/nvidia",assets_path)
+
+ # disable physics backwards compatibility check
+ carb_settings_iface.set_int(physx_impl.SETTING_BACKWARD_COMPATIBILITY,0)
+
+ def_hide_stop_button(self):
+"""Hide the stop button in the toolbar.
+
+ For standalone executions, having a stop button is confusing since it invalidates the whole simulation.
+ Thus, we hide the button so that users don't accidentally click it.
+ """
+ # when we are truly headless, then we can't import the widget toolbar
+ # thus, we only hide the stop button when we are not headless (i.e. GUI is enabled)
+ ifself._livestream>=1ornotself._headless:
+ importomni.kit.widget.toolbar
+
+ # grey out the stop button because we don't want to stop the simulation manually in standalone mode
+ toolbar=omni.kit.widget.toolbar.get_instance()
+ play_button_group=toolbar._builtin_tools._play_button_group# type: ignore
+ ifplay_button_groupisnotNone:
+ play_button_group._stop_button.visible=False# type: ignore
+ play_button_group._stop_button.enabled=False# type: ignore
+ play_button_group._stop_button=None# type: ignore
+
+ def_interrupt_signal_handle_callback(self,signal,frame):
+"""Handle the interrupt signal from the keyboard."""
+ # close the app
+ self._app.close()
+ # raise the error for keyboard interrupt
+ raiseKeyboardInterrupt
+
+ def_abort_signal_handle_callback(self,signal,frame):
+"""Handle the abort/segmentation/kill signals."""
+ # close the app
+ self._app.close()
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+# Flag for pyright to ignore type errors in this file.
+# pyright: reportPrivateUsage=false
+
+from__future__importannotations
+
+importtorch
+fromcollections.abcimportSequence
+fromprettytableimportPrettyTable
+fromtypingimportTYPE_CHECKING
+
+importcarb
+importomni.isaac.core.utils.stageasstage_utils
+importomni.physics.tensors.impl.apiasphysx
+fromomni.isaac.core.utils.typesimportArticulationActions
+frompxrimportPhysxSchema,UsdPhysics
+
+importomni.isaac.lab.simassim_utils
+importomni.isaac.lab.utils.mathasmath_utils
+importomni.isaac.lab.utils.stringasstring_utils
+fromomni.isaac.lab.actuatorsimportActuatorBase,ActuatorBaseCfg,ImplicitActuator
+
+from..asset_baseimportAssetBase
+from.articulation_dataimportArticulationData
+
+ifTYPE_CHECKING:
+ from.articulation_cfgimportArticulationCfg
+
+
+
[文档]classArticulation(AssetBase):
+"""An articulation asset class.
+
+ An articulation is a collection of rigid bodies connected by joints. The joints can be either
+ fixed or actuated. The joints can be of different types, such as revolute, prismatic, D-6, etc.
+ However, the articulation class has currently been tested with revolute and prismatic joints.
+ The class supports both floating-base and fixed-base articulations. The type of articulation
+ is determined based on the root joint of the articulation. If the root joint is fixed, then
+ the articulation is considered a fixed-base system. Otherwise, it is considered a floating-base
+ system. This can be checked using the :attr:`Articulation.is_fixed_base` attribute.
+
+ For an asset to be considered an articulation, the root prim of the asset must have the
+ `USD ArticulationRootAPI`_. This API is used to define the sub-tree of the articulation using
+ the reduced coordinate formulation. On playing the simulation, the physics engine parses the
+ articulation root prim and creates the corresponding articulation in the physics engine. The
+ articulation root prim can be specified using the :attr:`AssetBaseCfg.prim_path` attribute.
+
+ The articulation class also provides the functionality to augment the simulation of an articulated
+ system with custom actuator models. These models can either be explicit or implicit, as detailed in
+ the :mod:`omni.isaac.lab.actuators` module. The actuator models are specified using the
+ :attr:`ArticulationCfg.actuators` attribute. These are then parsed and used to initialize the
+ corresponding actuator models, when the simulation is played.
+
+ During the simulation step, the articulation class first applies the actuator models to compute
+ the joint commands based on the user-specified targets. These joint commands are then applied
+ into the simulation. The joint commands can be either position, velocity, or effort commands.
+ As an example, the following snippet shows how this can be used for position commands:
+
+ .. code-block:: python
+
+ # an example instance of the articulation class
+ my_articulation = Articulation(cfg)
+
+ # set joint position targets
+ my_articulation.set_joint_position_target(position)
+ # propagate the actuator models and apply the computed commands into the simulation
+ my_articulation.write_data_to_sim()
+
+ # step the simulation using the simulation context
+ sim_context.step()
+
+ # update the articulation state, where dt is the simulation time step
+ my_articulation.update(dt)
+
+ .. _`USD ArticulationRootAPI`: https://openusd.org/dev/api/class_usd_physics_articulation_root_a_p_i.html
+
+ """
+
+ cfg:ArticulationCfg
+"""Configuration instance for the articulations."""
+
+ actuators:dict[str,ActuatorBase]
+"""Dictionary of actuator instances for the articulation.
+
+ The keys are the actuator names and the values are the actuator instances. The actuator instances
+ are initialized based on the actuator configurations specified in the :attr:`ArticulationCfg.actuators`
+ attribute. They are used to compute the joint commands during the :meth:`write_data_to_sim` function.
+ """
+
+
[文档]def__init__(self,cfg:ArticulationCfg):
+"""Initialize the articulation.
+
+ Args:
+ cfg: A configuration instance.
+ """
+ super().__init__(cfg)
+
+"""
+ Properties
+ """
+
+ @property
+ defdata(self)->ArticulationData:
+ returnself._data
+
+ @property
+ defnum_instances(self)->int:
+ returnself.root_physx_view.count
+
+ @property
+ defis_fixed_base(self)->bool:
+"""Whether the articulation is a fixed-base or floating-base system."""
+ returnself.root_physx_view.shared_metatype.fixed_base
+
+ @property
+ defnum_joints(self)->int:
+"""Number of joints in articulation."""
+ returnself.root_physx_view.shared_metatype.dof_count
+
+ @property
+ defnum_fixed_tendons(self)->int:
+"""Number of fixed tendons in articulation."""
+ returnself.root_physx_view.max_fixed_tendons
+
+ @property
+ defnum_bodies(self)->int:
+"""Number of bodies in articulation."""
+ returnself.root_physx_view.shared_metatype.link_count
+
+ @property
+ defjoint_names(self)->list[str]:
+"""Ordered names of joints in articulation."""
+ returnself.root_physx_view.shared_metatype.dof_names
+
+ @property
+ deffixed_tendon_names(self)->list[str]:
+"""Ordered names of fixed tendons in articulation."""
+ returnself._fixed_tendon_names
+
+ @property
+ defbody_names(self)->list[str]:
+"""Ordered names of bodies in articulation."""
+ returnself.root_physx_view.shared_metatype.link_names
+
+ @property
+ defroot_physx_view(self)->physx.ArticulationView:
+"""Articulation view for the asset (PhysX).
+
+ Note:
+ Use this view with caution. It requires handling of tensors in a specific way.
+ """
+ returnself._root_physx_view
+
+"""
+ Operations.
+ """
+
+
[文档]defwrite_data_to_sim(self):
+"""Write external wrenches and joint commands to the simulation.
+
+ If any explicit actuators are present, then the actuator models are used to compute the
+ joint commands. Otherwise, the joint commands are directly set into the simulation.
+
+ Note:
+ We write external wrench to the simulation here since this function is called before the simulation step.
+ This ensures that the external wrench is applied at every simulation step.
+ """
+ # write external wrench
+ ifself.has_external_wrench:
+ self.root_physx_view.apply_forces_and_torques_at_position(
+ force_data=self._external_force_b.view(-1,3),
+ torque_data=self._external_torque_b.view(-1,3),
+ position_data=None,
+ indices=self._ALL_INDICES,
+ is_global=False,
+ )
+
+ # apply actuator models
+ self._apply_actuator_model()
+ # write actions into simulation
+ self.root_physx_view.set_dof_actuation_forces(self._joint_effort_target_sim,self._ALL_INDICES)
+ # position and velocity targets only for implicit actuators
+ ifself._has_implicit_actuators:
+ self.root_physx_view.set_dof_position_targets(self._joint_pos_target_sim,self._ALL_INDICES)
+ self.root_physx_view.set_dof_velocity_targets(self._joint_vel_target_sim,self._ALL_INDICES)
[文档]deffind_bodies(self,name_keys:str|Sequence[str],preserve_order:bool=False)->tuple[list[int],list[str]]:
+"""Find bodies in the articulation based on the name keys.
+
+ Please check the :meth:`omni.isaac.lab.utils.string_utils.resolve_matching_names` function for more
+ information on the name matching.
+
+ Args:
+ name_keys: A regular expression or a list of regular expressions to match the body names.
+ preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False.
+
+ Returns:
+ A tuple of lists containing the body indices and names.
+ """
+ returnstring_utils.resolve_matching_names(name_keys,self.body_names,preserve_order)
+
+
[文档]deffind_joints(
+ self,name_keys:str|Sequence[str],joint_subset:list[str]|None=None,preserve_order:bool=False
+ )->tuple[list[int],list[str]]:
+"""Find joints in the articulation based on the name keys.
+
+ Please see the :func:`omni.isaac.lab.utils.string.resolve_matching_names` function for more information
+ on the name matching.
+
+ Args:
+ name_keys: A regular expression or a list of regular expressions to match the joint names.
+ joint_subset: A subset of joints to search for. Defaults to None, which means all joints
+ in the articulation are searched.
+ preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False.
+
+ Returns:
+ A tuple of lists containing the joint indices and names.
+ """
+ ifjoint_subsetisNone:
+ joint_subset=self.joint_names
+ # find joints
+ returnstring_utils.resolve_matching_names(name_keys,joint_subset,preserve_order)
+
+
[文档]deffind_fixed_tendons(
+ self,name_keys:str|Sequence[str],tendon_subsets:list[str]|None=None,preserve_order:bool=False
+ )->tuple[list[int],list[str]]:
+"""Find fixed tendons in the articulation based on the name keys.
+
+ Please see the :func:`omni.isaac.lab.utils.string.resolve_matching_names` function for more information
+ on the name matching.
+
+ Args:
+ name_keys: A regular expression or a list of regular expressions to match the joint
+ names with fixed tendons.
+ tendon_subsets: A subset of joints with fixed tendons to search for. Defaults to None, which means
+ all joints in the articulation are searched.
+ preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False.
+
+ Returns:
+ A tuple of lists containing the tendon indices and names.
+ """
+ iftendon_subsetsisNone:
+ # tendons follow the joint names they are attached to
+ tendon_subsets=self.fixed_tendon_names
+ # find tendons
+ returnstring_utils.resolve_matching_names(name_keys,tendon_subsets,preserve_order)
+
+"""
+ Operations - Writers.
+ """
+
+
[文档]defwrite_root_state_to_sim(self,root_state:torch.Tensor,env_ids:Sequence[int]|None=None):
+"""Set the root state over selected environment indices into the simulation.
+
+ The root state comprises of the cartesian position, quaternion orientation in (w, x, y, z), and linear
+ and angular velocity. All the quantities are in the simulation frame.
+
+ Args:
+ root_state: Root state in simulation frame. Shape is (len(env_ids), 13).
+ env_ids: Environment indices. If None, then all indices are used.
+ """
+ # set into simulation
+ self.write_root_pose_to_sim(root_state[:,:7],env_ids=env_ids)
+ self.write_root_velocity_to_sim(root_state[:,7:],env_ids=env_ids)
+
+
[文档]defwrite_root_pose_to_sim(self,root_pose:torch.Tensor,env_ids:Sequence[int]|None=None):
+"""Set the root pose over selected environment indices into the simulation.
+
+ The root pose comprises of the cartesian position and quaternion orientation in (w, x, y, z).
+
+ Args:
+ root_pose: Root poses in simulation frame. Shape is (len(env_ids), 7).
+ env_ids: Environment indices. If None, then all indices are used.
+ """
+ # resolve all indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ # note: we need to do this here since tensors are not set into simulation until step.
+ # set into internal buffers
+ self._data.root_state_w[env_ids,:7]=root_pose.clone()
+ # convert root quaternion from wxyz to xyzw
+ root_poses_xyzw=self._data.root_state_w[:,:7].clone()
+ root_poses_xyzw[:,3:]=math_utils.convert_quat(root_poses_xyzw[:,3:],to="xyzw")
+ # Need to invalidate the buffer to trigger the update with the new root pose.
+ self._data._body_state_w.timestamp=-1.0
+ # set into simulation
+ self.root_physx_view.set_root_transforms(root_poses_xyzw,indices=physx_env_ids)
+
+
[文档]defwrite_root_velocity_to_sim(self,root_velocity:torch.Tensor,env_ids:Sequence[int]|None=None):
+"""Set the root velocity over selected environment indices into the simulation.
+
+ Args:
+ root_velocity: Root velocities in simulation frame. Shape is (len(env_ids), 6).
+ env_ids: Environment indices. If None, then all indices are used.
+ """
+ # resolve all indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ # note: we need to do this here since tensors are not set into simulation until step.
+ # set into internal buffers
+ self._data.root_state_w[env_ids,7:]=root_velocity.clone()
+ self._data.body_acc_w[env_ids]=0.0
+ # set into simulation
+ self.root_physx_view.set_root_velocities(self._data.root_state_w[:,7:],indices=physx_env_ids)
+
+
[文档]defwrite_joint_state_to_sim(
+ self,
+ position:torch.Tensor,
+ velocity:torch.Tensor,
+ joint_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|slice|None=None,
+ ):
+"""Write joint positions and velocities to the simulation.
+
+ Args:
+ position: Joint positions. Shape is (len(env_ids), len(joint_ids)).
+ velocity: Joint velocities. Shape is (len(env_ids), len(joint_ids)).
+ joint_ids: The joint indices to set the targets for. Defaults to None (all joints).
+ env_ids: The environment indices to set the targets for. Defaults to None (all environments).
+ """
+ # resolve indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ ifjoint_idsisNone:
+ joint_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andjoint_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set into internal buffers
+ self._data.joint_pos[env_ids,joint_ids]=position
+ self._data.joint_vel[env_ids,joint_ids]=velocity
+ self._data._previous_joint_vel[env_ids,joint_ids]=velocity
+ self._data.joint_acc[env_ids,joint_ids]=0.0
+ # Need to invalidate the buffer to trigger the update with the new root pose.
+ self._data._body_state_w.timestamp=-1.0
+ # set into simulation
+ self.root_physx_view.set_dof_positions(self._data.joint_pos,indices=physx_env_ids)
+ self.root_physx_view.set_dof_velocities(self._data.joint_vel,indices=physx_env_ids)
+
+
[文档]defwrite_joint_stiffness_to_sim(
+ self,
+ stiffness:torch.Tensor|float,
+ joint_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Write joint stiffness into the simulation.
+
+ Args:
+ stiffness: Joint stiffness. Shape is (len(env_ids), len(joint_ids)).
+ joint_ids: The joint indices to set the stiffness for. Defaults to None (all joints).
+ env_ids: The environment indices to set the stiffness for. Defaults to None (all environments).
+ """
+ # note: This function isn't setting the values for actuator models. (#128)
+ # resolve indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ ifjoint_idsisNone:
+ joint_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andjoint_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set into internal buffers
+ self._data.joint_stiffness[env_ids,joint_ids]=stiffness
+ # set into simulation
+ self.root_physx_view.set_dof_stiffnesses(self._data.joint_stiffness.cpu(),indices=physx_env_ids.cpu())
+
+
[文档]defwrite_joint_damping_to_sim(
+ self,
+ damping:torch.Tensor|float,
+ joint_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Write joint damping into the simulation.
+
+ Args:
+ damping: Joint damping. Shape is (len(env_ids), len(joint_ids)).
+ joint_ids: The joint indices to set the damping for.
+ Defaults to None (all joints).
+ env_ids: The environment indices to set the damping for.
+ Defaults to None (all environments).
+ """
+ # note: This function isn't setting the values for actuator models. (#128)
+ # resolve indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ ifjoint_idsisNone:
+ joint_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andjoint_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set into internal buffers
+ self._data.joint_damping[env_ids,joint_ids]=damping
+ # set into simulation
+ self.root_physx_view.set_dof_dampings(self._data.joint_damping.cpu(),indices=physx_env_ids.cpu())
+
+
[文档]defwrite_joint_effort_limit_to_sim(
+ self,
+ limits:torch.Tensor|float,
+ joint_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Write joint effort limits into the simulation.
+
+ Args:
+ limits: Joint torque limits. Shape is (len(env_ids), len(joint_ids)).
+ joint_ids: The joint indices to set the joint torque limits for. Defaults to None (all joints).
+ env_ids: The environment indices to set the joint torque limits for. Defaults to None (all environments).
+ """
+ # note: This function isn't setting the values for actuator models. (#128)
+ # resolve indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ ifjoint_idsisNone:
+ joint_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andjoint_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # move tensor to cpu if needed
+ ifisinstance(limits,torch.Tensor):
+ limits=limits.cpu()
+ # set into internal buffers
+ torque_limit_all=self.root_physx_view.get_dof_max_forces()
+ torque_limit_all[env_ids,joint_ids]=limits
+ # set into simulation
+ self.root_physx_view.set_dof_max_forces(torque_limit_all.cpu(),indices=physx_env_ids.cpu())
+
+
[文档]defwrite_joint_armature_to_sim(
+ self,
+ armature:torch.Tensor|float,
+ joint_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Write joint armature into the simulation.
+
+ Args:
+ armature: Joint armature. Shape is (len(env_ids), len(joint_ids)).
+ joint_ids: The joint indices to set the joint torque limits for. Defaults to None (all joints).
+ env_ids: The environment indices to set the joint torque limits for. Defaults to None (all environments).
+ """
+ # resolve indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ ifjoint_idsisNone:
+ joint_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andjoint_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set into internal buffers
+ self._data.joint_armature[env_ids,joint_ids]=armature
+ # set into simulation
+ self.root_physx_view.set_dof_armatures(self._data.joint_armature.cpu(),indices=physx_env_ids.cpu())
+
+
[文档]defwrite_joint_friction_to_sim(
+ self,
+ joint_friction:torch.Tensor|float,
+ joint_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Write joint friction into the simulation.
+
+ Args:
+ joint_friction: Joint friction. Shape is (len(env_ids), len(joint_ids)).
+ joint_ids: The joint indices to set the joint torque limits for. Defaults to None (all joints).
+ env_ids: The environment indices to set the joint torque limits for. Defaults to None (all environments).
+ """
+ # resolve indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ ifjoint_idsisNone:
+ joint_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andjoint_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set into internal buffers
+ self._data.joint_friction[env_ids,joint_ids]=joint_friction
+ # set into simulation
+ self.root_physx_view.set_dof_friction_coefficients(self._data.joint_friction.cpu(),indices=physx_env_ids.cpu())
+
+
[文档]defwrite_joint_limits_to_sim(
+ self,
+ limits:torch.Tensor|float,
+ joint_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Write joint limits into the simulation.
+
+ Args:
+ limits: Joint limits. Shape is (len(env_ids), len(joint_ids), 2).
+ joint_ids: The joint indices to set the limits for. Defaults to None (all joints).
+ env_ids: The environment indices to set the limits for. Defaults to None (all environments).
+ """
+ # note: This function isn't setting the values for actuator models. (#128)
+ # resolve indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ ifjoint_idsisNone:
+ joint_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andjoint_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set into internal buffers
+ self._data.joint_limits[env_ids,joint_ids]=limits
+ # set into simulation
+ self.root_physx_view.set_dof_limits(self._data.joint_limits.cpu(),indices=physx_env_ids.cpu())
+
+"""
+ Operations - Setters.
+ """
+
+
[文档]defset_external_force_and_torque(
+ self,
+ forces:torch.Tensor,
+ torques:torch.Tensor,
+ body_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Set external force and torque to apply on the asset's bodies in their local frame.
+
+ For many applications, we want to keep the applied external force on rigid bodies constant over a period of
+ time (for instance, during the policy control). This function allows us to store the external force and torque
+ into buffers which are then applied to the simulation at every step.
+
+ .. caution::
+ If the function is called with empty forces and torques, then this function disables the application
+ of external wrench to the simulation.
+
+ .. code-block:: python
+
+ # example of disabling external wrench
+ asset.set_external_force_and_torque(forces=torch.zeros(0, 3), torques=torch.zeros(0, 3))
+
+ .. note::
+ This function does not apply the external wrench to the simulation. It only fills the buffers with
+ the desired values. To apply the external wrench, call the :meth:`write_data_to_sim` function
+ right before the simulation step.
+
+ Args:
+ forces: External forces in bodies' local frame. Shape is (len(env_ids), len(body_ids), 3).
+ torques: External torques in bodies' local frame. Shape is (len(env_ids), len(body_ids), 3).
+ body_ids: Body indices to apply external wrench to. Defaults to None (all bodies).
+ env_ids: Environment indices to apply external wrench to. Defaults to None (all instances).
+ """
+ ifforces.any()ortorques.any():
+ self.has_external_wrench=True
+ # resolve all indices
+ # -- env_ids
+ ifenv_idsisNone:
+ env_ids=self._ALL_INDICES
+ elifnotisinstance(env_ids,torch.Tensor):
+ env_ids=torch.tensor(env_ids,dtype=torch.long,device=self.device)
+ # -- body_ids
+ ifbody_idsisNone:
+ body_ids=torch.arange(self.num_bodies,dtype=torch.long,device=self.device)
+ elifisinstance(body_ids,slice):
+ body_ids=torch.arange(self.num_bodies,dtype=torch.long,device=self.device)[body_ids]
+ elifnotisinstance(body_ids,torch.Tensor):
+ body_ids=torch.tensor(body_ids,dtype=torch.long,device=self.device)
+
+ # note: we need to do this complicated indexing since torch doesn't support multi-indexing
+ # create global body indices from env_ids and env_body_ids
+ # (env_id * total_bodies_per_env) + body_id
+ indices=body_ids.repeat(len(env_ids),1)+env_ids.unsqueeze(1)*self.num_bodies
+ indices=indices.view(-1)
+ # set into internal buffers
+ # note: these are applied in the write_to_sim function
+ self._external_force_b.flatten(0,1)[indices]=forces.flatten(0,1)
+ self._external_torque_b.flatten(0,1)[indices]=torques.flatten(0,1)
+ else:
+ self.has_external_wrench=False
+
+
[文档]defset_joint_position_target(
+ self,target:torch.Tensor,joint_ids:Sequence[int]|slice|None=None,env_ids:Sequence[int]|None=None
+ ):
+"""Set joint position targets into internal buffers.
+
+ .. note::
+ This function does not apply the joint targets to the simulation. It only fills the buffers with
+ the desired values. To apply the joint targets, call the :meth:`write_data_to_sim` function.
+
+ Args:
+ target: Joint position targets. Shape is (len(env_ids), len(joint_ids)).
+ joint_ids: The joint indices to set the targets for. Defaults to None (all joints).
+ env_ids: The environment indices to set the targets for. Defaults to None (all environments).
+ """
+ # resolve indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ ifjoint_idsisNone:
+ joint_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andjoint_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set targets
+ self._data.joint_pos_target[env_ids,joint_ids]=target
+
+
[文档]defset_joint_velocity_target(
+ self,target:torch.Tensor,joint_ids:Sequence[int]|slice|None=None,env_ids:Sequence[int]|None=None
+ ):
+"""Set joint velocity targets into internal buffers.
+
+ .. note::
+ This function does not apply the joint targets to the simulation. It only fills the buffers with
+ the desired values. To apply the joint targets, call the :meth:`write_data_to_sim` function.
+
+ Args:
+ target: Joint velocity targets. Shape is (len(env_ids), len(joint_ids)).
+ joint_ids: The joint indices to set the targets for. Defaults to None (all joints).
+ env_ids: The environment indices to set the targets for. Defaults to None (all environments).
+ """
+ # resolve indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ ifjoint_idsisNone:
+ joint_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andjoint_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set targets
+ self._data.joint_vel_target[env_ids,joint_ids]=target
+
+
[文档]defset_joint_effort_target(
+ self,target:torch.Tensor,joint_ids:Sequence[int]|slice|None=None,env_ids:Sequence[int]|None=None
+ ):
+"""Set joint efforts into internal buffers.
+
+ .. note::
+ This function does not apply the joint targets to the simulation. It only fills the buffers with
+ the desired values. To apply the joint targets, call the :meth:`write_data_to_sim` function.
+
+ Args:
+ target: Joint effort targets. Shape is (len(env_ids), len(joint_ids)).
+ joint_ids: The joint indices to set the targets for. Defaults to None (all joints).
+ env_ids: The environment indices to set the targets for. Defaults to None (all environments).
+ """
+ # resolve indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ ifjoint_idsisNone:
+ joint_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andjoint_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set targets
+ self._data.joint_effort_target[env_ids,joint_ids]=target
+
+"""
+ Operations - Tendons.
+ """
+
+
[文档]defset_fixed_tendon_stiffness(
+ self,
+ stiffness:torch.Tensor,
+ fixed_tendon_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Set fixed tendon stiffness into internal buffers.
+
+ .. note::
+ This function does not apply the tendon stiffness to the simulation. It only fills the buffers with
+ the desired values. To apply the tendon stiffness, call the :meth:`write_fixed_tendon_properties_to_sim` function.
+
+ Args:
+ stiffness: Fixed tendon stiffness. Shape is (len(env_ids), len(fixed_tendon_ids)).
+ fixed_tendon_ids: The tendon indices to set the stiffness for. Defaults to None (all fixed tendons).
+ env_ids: The environment indices to set the stiffness for. Defaults to None (all environments).
+ """
+ # resolve indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ iffixed_tendon_idsisNone:
+ fixed_tendon_ids=slice(None)
+ ifenv_ids!=slice(None)andfixed_tendon_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set stiffness
+ self._data.fixed_tendon_stiffness[env_ids,fixed_tendon_ids]=stiffness
+
+
[文档]defset_fixed_tendon_damping(
+ self,
+ damping:torch.Tensor,
+ fixed_tendon_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Set fixed tendon damping into internal buffers.
+
+ .. note::
+ This function does not apply the tendon damping to the simulation. It only fills the buffers with
+ the desired values. To apply the tendon damping, call the :meth:`write_fixed_tendon_properties_to_sim` function.
+
+ Args:
+ damping: Fixed tendon damping. Shape is (len(env_ids), len(fixed_tendon_ids)).
+ fixed_tendon_ids: The tendon indices to set the damping for. Defaults to None (all fixed tendons).
+ env_ids: The environment indices to set the damping for. Defaults to None (all environments).
+ """
+ # resolve indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ iffixed_tendon_idsisNone:
+ fixed_tendon_ids=slice(None)
+ ifenv_ids!=slice(None)andfixed_tendon_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set damping
+ self._data.fixed_tendon_damping[env_ids,fixed_tendon_ids]=damping
+
+
[文档]defset_fixed_tendon_limit_stiffness(
+ self,
+ limit_stiffness:torch.Tensor,
+ fixed_tendon_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Set fixed tendon limit stiffness efforts into internal buffers.
+
+ .. note::
+ This function does not apply the tendon limit stiffness to the simulation. It only fills the buffers with
+ the desired values. To apply the tendon limit stiffness, call the :meth:`write_fixed_tendon_properties_to_sim` function.
+
+ Args:
+ limit_stiffness: Fixed tendon limit stiffness. Shape is (len(env_ids), len(fixed_tendon_ids)).
+ fixed_tendon_ids: The tendon indices to set the limit stiffness for. Defaults to None (all fixed tendons).
+ env_ids: The environment indices to set the limit stiffness for. Defaults to None (all environments).
+ """
+ # resolve indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ iffixed_tendon_idsisNone:
+ fixed_tendon_ids=slice(None)
+ ifenv_ids!=slice(None)andfixed_tendon_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set limit_stiffness
+ self._data.fixed_tendon_limit_stiffness[env_ids,fixed_tendon_ids]=limit_stiffness
+
+
[文档]defset_fixed_tendon_limit(
+ self,
+ limit:torch.Tensor,
+ fixed_tendon_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Set fixed tendon limit efforts into internal buffers.
+
+ .. note::
+ This function does not apply the tendon limit to the simulation. It only fills the buffers with
+ the desired values. To apply the tendon limit, call the :meth:`write_fixed_tendon_properties_to_sim` function.
+
+ Args:
+ limit: Fixed tendon limit. Shape is (len(env_ids), len(fixed_tendon_ids)).
+ fixed_tendon_ids: The tendon indices to set the limit for. Defaults to None (all fixed tendons).
+ env_ids: The environment indices to set the limit for. Defaults to None (all environments).
+ """
+ # resolve indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ iffixed_tendon_idsisNone:
+ fixed_tendon_ids=slice(None)
+ ifenv_ids!=slice(None)andfixed_tendon_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set limit
+ self._data.fixed_tendon_limit[env_ids,fixed_tendon_ids]=limit
+
+
[文档]defset_fixed_tendon_rest_length(
+ self,
+ rest_length:torch.Tensor,
+ fixed_tendon_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Set fixed tendon rest length efforts into internal buffers.
+
+ .. note::
+ This function does not apply the tendon rest length to the simulation. It only fills the buffers with
+ the desired values. To apply the tendon rest length, call the :meth:`write_fixed_tendon_properties_to_sim` function.
+
+ Args:
+ rest_length: Fixed tendon rest length. Shape is (len(env_ids), len(fixed_tendon_ids)).
+ fixed_tendon_ids: The tendon indices to set the rest length for. Defaults to None (all fixed tendons).
+ env_ids: The environment indices to set the rest length for. Defaults to None (all environments).
+ """
+ # resolve indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ iffixed_tendon_idsisNone:
+ fixed_tendon_ids=slice(None)
+ ifenv_ids!=slice(None)andfixed_tendon_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set rest_length
+ self._data.fixed_tendon_rest_length[env_ids,fixed_tendon_ids]=rest_length
+
+
[文档]defset_fixed_tendon_offset(
+ self,
+ offset:torch.Tensor,
+ fixed_tendon_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Set fixed tendon offset efforts into internal buffers.
+
+ .. note::
+ This function does not apply the tendon offset to the simulation. It only fills the buffers with
+ the desired values. To apply the tendon offset, call the :meth:`write_fixed_tendon_properties_to_sim` function.
+
+ Args:
+ offset: Fixed tendon offset. Shape is (len(env_ids), len(fixed_tendon_ids)).
+ fixed_tendon_ids: The tendon indices to set the offset for. Defaults to None (all fixed tendons).
+ env_ids: The environment indices to set the offset for. Defaults to None (all environments).
+ """
+ # resolve indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ iffixed_tendon_idsisNone:
+ fixed_tendon_ids=slice(None)
+ ifenv_ids!=slice(None)andfixed_tendon_ids!=slice(None):
+ env_ids=env_ids[:,None]
+ # set offset
+ self._data.fixed_tendon_offset[env_ids,fixed_tendon_ids]=offset
+
+
[文档]defwrite_fixed_tendon_properties_to_sim(
+ self,
+ fixed_tendon_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Write fixed tendon properties into the simulation.
+
+ Args:
+ fixed_tendon_ids: The fixed tendon indices to set the limits for. Defaults to None (all fixed tendons).
+ env_ids: The environment indices to set the limits for. Defaults to None (all environments).
+ """
+ # resolve indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ physx_env_ids=self._ALL_INDICES
+ iffixed_tendon_idsisNone:
+ fixed_tendon_ids=slice(None)
+
+ # set into simulation
+ self.root_physx_view.set_fixed_tendon_properties(
+ self._data.fixed_tendon_stiffness,
+ self._data.fixed_tendon_damping,
+ self._data.fixed_tendon_limit_stiffness,
+ self._data.fixed_tendon_limit,
+ self._data.fixed_tendon_rest_length,
+ self._data.fixed_tendon_offset,
+ indices=physx_env_ids,
+ )
+
+"""
+ Internal helper.
+ """
+
+ def_initialize_impl(self):
+ # create simulation view
+ self._physics_sim_view=physx.create_simulation_view(self._backend)
+ self._physics_sim_view.set_subspace_roots("/")
+ # obtain the first prim in the regex expression (all others are assumed to be a copy of this)
+ template_prim=sim_utils.find_first_matching_prim(self.cfg.prim_path)
+ iftemplate_primisNone:
+ raiseRuntimeError(f"Failed to find prim for expression: '{self.cfg.prim_path}'.")
+ template_prim_path=template_prim.GetPath().pathString
+
+ # find articulation root prims
+ root_prims=sim_utils.get_all_matching_child_prims(
+ template_prim_path,predicate=lambdaprim:prim.HasAPI(UsdPhysics.ArticulationRootAPI)
+ )
+ iflen(root_prims)==0:
+ raiseRuntimeError(
+ f"Failed to find an articulation when resolving '{self.cfg.prim_path}'."
+ " Please ensure that the prim has 'USD ArticulationRootAPI' applied."
+ )
+ iflen(root_prims)>1:
+ raiseRuntimeError(
+ f"Failed to find a single articulation when resolving '{self.cfg.prim_path}'."
+ f" Found multiple '{root_prims}' under '{template_prim_path}'."
+ " Please ensure that there is only one articulation in the prim path tree."
+ )
+
+ # resolve articulation root prim back into regex expression
+ root_prim_path=root_prims[0].GetPath().pathString
+ root_prim_path_expr=self.cfg.prim_path+root_prim_path[len(template_prim_path):]
+ # -- articulation
+ self._root_physx_view=self._physics_sim_view.create_articulation_view(root_prim_path_expr.replace(".*","*"))
+
+ # check if the articulation was created
+ ifself._root_physx_view._backendisNone:
+ raiseRuntimeError(f"Failed to create articulation at: {self.cfg.prim_path}. Please check PhysX logs.")
+
+ # log information about the articulation
+ carb.log_info(f"Articulation initialized at: {self.cfg.prim_path} with root '{root_prim_path_expr}'.")
+ carb.log_info(f"Is fixed root: {self.is_fixed_base}")
+ carb.log_info(f"Number of bodies: {self.num_bodies}")
+ carb.log_info(f"Body names: {self.body_names}")
+ carb.log_info(f"Number of joints: {self.num_joints}")
+ carb.log_info(f"Joint names: {self.joint_names}")
+ carb.log_info(f"Number of fixed tendons: {self.num_fixed_tendons}")
+
+ # container for data access
+ self._data=ArticulationData(self.root_physx_view,self.device)
+
+ # create buffers
+ self._create_buffers()
+ # process configuration
+ self._process_cfg()
+ self._process_actuators_cfg()
+ self._process_fixed_tendons()
+ # validate configuration
+ self._validate_cfg()
+ # update the robot data
+ self.update(0.0)
+ # log joint information
+ self._log_articulation_joint_info()
+
+ def_create_buffers(self):
+ # constants
+ self._ALL_INDICES=torch.arange(self.num_instances,dtype=torch.long,device=self.device)
+
+ # external forces and torques
+ self.has_external_wrench=False
+ self._external_force_b=torch.zeros((self.num_instances,self.num_bodies,3),device=self.device)
+ self._external_torque_b=torch.zeros_like(self._external_force_b)
+
+ # asset data
+ # -- properties
+ self._data.joint_names=self.joint_names
+ self._data.body_names=self.body_names
+
+ # -- bodies
+ self._data.default_mass=self.root_physx_view.get_masses().clone()
+ self._data.default_inertia=self.root_physx_view.get_inertias().clone()
+
+ # -- default joint state
+ self._data.default_joint_pos=torch.zeros(self.num_instances,self.num_joints,device=self.device)
+ self._data.default_joint_vel=torch.zeros_like(self._data.default_joint_pos)
+
+ # -- joint commands
+ self._data.joint_pos_target=torch.zeros_like(self._data.default_joint_pos)
+ self._data.joint_vel_target=torch.zeros_like(self._data.default_joint_pos)
+ self._data.joint_effort_target=torch.zeros_like(self._data.default_joint_pos)
+ self._data.joint_stiffness=torch.zeros_like(self._data.default_joint_pos)
+ self._data.joint_damping=torch.zeros_like(self._data.default_joint_pos)
+ self._data.joint_armature=torch.zeros_like(self._data.default_joint_pos)
+ self._data.joint_friction=torch.zeros_like(self._data.default_joint_pos)
+ self._data.joint_limits=torch.zeros(self.num_instances,self.num_joints,2,device=self.device)
+
+ # -- joint commands (explicit)
+ self._data.computed_torque=torch.zeros_like(self._data.default_joint_pos)
+ self._data.applied_torque=torch.zeros_like(self._data.default_joint_pos)
+
+ # -- tendons
+ ifself.num_fixed_tendons>0:
+ self._data.fixed_tendon_stiffness=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,device=self.device
+ )
+ self._data.fixed_tendon_damping=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,device=self.device
+ )
+ self._data.fixed_tendon_limit_stiffness=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,device=self.device
+ )
+ self._data.fixed_tendon_limit=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,2,device=self.device
+ )
+ self._data.fixed_tendon_rest_length=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,device=self.device
+ )
+ self._data.fixed_tendon_offset=torch.zeros(self.num_instances,self.num_fixed_tendons,device=self.device)
+
+ # -- other data
+ self._data.soft_joint_pos_limits=torch.zeros(self.num_instances,self.num_joints,2,device=self.device)
+ self._data.soft_joint_vel_limits=torch.zeros(self.num_instances,self.num_joints,device=self.device)
+ self._data.gear_ratio=torch.ones(self.num_instances,self.num_joints,device=self.device)
+
+ # -- initialize default buffers related to joint properties
+ self._data.default_joint_stiffness=torch.zeros(self.num_instances,self.num_joints,device=self.device)
+ self._data.default_joint_damping=torch.zeros(self.num_instances,self.num_joints,device=self.device)
+ self._data.default_joint_armature=torch.zeros(self.num_instances,self.num_joints,device=self.device)
+ self._data.default_joint_friction=torch.zeros(self.num_instances,self.num_joints,device=self.device)
+ self._data.default_joint_limits=torch.zeros(self.num_instances,self.num_joints,2,device=self.device)
+
+ # -- initialize default buffers related to fixed tendon properties
+ ifself.num_fixed_tendons>0:
+ self._data.default_fixed_tendon_stiffness=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,device=self.device
+ )
+ self._data.default_fixed_tendon_damping=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,device=self.device
+ )
+ self._data.default_fixed_tendon_limit_stiffness=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,device=self.device
+ )
+ self._data.default_fixed_tendon_limit=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,2,device=self.device
+ )
+ self._data.default_fixed_tendon_rest_length=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,device=self.device
+ )
+ self._data.default_fixed_tendon_offset=torch.zeros(
+ self.num_instances,self.num_fixed_tendons,device=self.device
+ )
+
+ # soft joint position limits (recommended not to be too close to limits).
+ joint_pos_limits=self.root_physx_view.get_dof_limits()
+ joint_pos_mean=(joint_pos_limits[...,0]+joint_pos_limits[...,1])/2
+ joint_pos_range=joint_pos_limits[...,1]-joint_pos_limits[...,0]
+ soft_limit_factor=self.cfg.soft_joint_pos_limit_factor
+ # add to data
+ self._data.soft_joint_pos_limits[...,0]=joint_pos_mean-0.5*joint_pos_range*soft_limit_factor
+ self._data.soft_joint_pos_limits[...,1]=joint_pos_mean+0.5*joint_pos_range*soft_limit_factor
+
+ # create buffers to store processed actions from actuator models
+ self._joint_pos_target_sim=torch.zeros_like(self._data.joint_pos_target)
+ self._joint_vel_target_sim=torch.zeros_like(self._data.joint_pos_target)
+ self._joint_effort_target_sim=torch.zeros_like(self._data.joint_pos_target)
+
+ def_process_cfg(self):
+"""Post processing of configuration parameters."""
+ # default state
+ # -- root state
+ # note: we cast to tuple to avoid torch/numpy type mismatch.
+ default_root_state=(
+ tuple(self.cfg.init_state.pos)
+ +tuple(self.cfg.init_state.rot)
+ +tuple(self.cfg.init_state.lin_vel)
+ +tuple(self.cfg.init_state.ang_vel)
+ )
+ default_root_state=torch.tensor(default_root_state,dtype=torch.float,device=self.device)
+ self._data.default_root_state=default_root_state.repeat(self.num_instances,1)
+ # -- joint state
+ # joint pos
+ indices_list,_,values_list=string_utils.resolve_matching_names_values(
+ self.cfg.init_state.joint_pos,self.joint_names
+ )
+ self._data.default_joint_pos[:,indices_list]=torch.tensor(values_list,device=self.device)
+ # joint vel
+ indices_list,_,values_list=string_utils.resolve_matching_names_values(
+ self.cfg.init_state.joint_vel,self.joint_names
+ )
+ self._data.default_joint_vel[:,indices_list]=torch.tensor(values_list,device=self.device)
+
+ # -- joint limits
+ self._data.default_joint_limits=self.root_physx_view.get_dof_limits().to(device=self.device).clone()
+ self._data.joint_limits=self._data.default_joint_limits.clone()
+
+"""
+ Internal simulation callbacks.
+ """
+
+ def_invalidate_initialize_callback(self,event):
+"""Invalidates the scene elements."""
+ # call parent
+ super()._invalidate_initialize_callback(event)
+ # set all existing views to None to invalidate them
+ self._physics_sim_view=None
+ self._root_physx_view=None
+
+"""
+ Internal helpers -- Actuators.
+ """
+
+ def_process_actuators_cfg(self):
+"""Process and apply articulation joint properties."""
+ # create actuators
+ self.actuators=dict()
+ # flag for implicit actuators
+ # if this is false, we by-pass certain checks when doing actuator-related operations
+ self._has_implicit_actuators=False
+
+ # cache the values coming from the usd
+ usd_stiffness=self.root_physx_view.get_dof_stiffnesses().clone()
+ usd_damping=self.root_physx_view.get_dof_dampings().clone()
+ usd_armature=self.root_physx_view.get_dof_armatures().clone()
+ usd_friction=self.root_physx_view.get_dof_friction_coefficients().clone()
+ usd_effort_limit=self.root_physx_view.get_dof_max_forces().clone()
+ usd_velocity_limit=self.root_physx_view.get_dof_max_velocities().clone()
+
+ # iterate over all actuator configurations
+ foractuator_name,actuator_cfginself.cfg.actuators.items():
+ # type annotation for type checkers
+ actuator_cfg:ActuatorBaseCfg
+ # create actuator group
+ joint_ids,joint_names=self.find_joints(actuator_cfg.joint_names_expr)
+ # check if any joints are found
+ iflen(joint_names)==0:
+ raiseValueError(
+ f"No joints found for actuator group: {actuator_name} with joint name expression:"
+ f" {actuator_cfg.joint_names_expr}."
+ )
+ # create actuator collection
+ # note: for efficiency avoid indexing when over all indices
+ actuator:ActuatorBase=actuator_cfg.class_type(
+ cfg=actuator_cfg,
+ joint_names=joint_names,
+ joint_ids=(
+ slice(None)iflen(joint_names)==self.num_jointselsetorch.tensor(joint_ids,device=self.device)
+ ),
+ num_envs=self.num_instances,
+ device=self.device,
+ stiffness=usd_stiffness[:,joint_ids],
+ damping=usd_damping[:,joint_ids],
+ armature=usd_armature[:,joint_ids],
+ friction=usd_friction[:,joint_ids],
+ effort_limit=usd_effort_limit[:,joint_ids],
+ velocity_limit=usd_velocity_limit[:,joint_ids],
+ )
+ # log information on actuator groups
+ carb.log_info(
+ f"Actuator collection: {actuator_name} with model '{actuator_cfg.class_type.__name__}' and"
+ f" joint names: {joint_names} [{joint_ids}]."
+ )
+ # store actuator group
+ self.actuators[actuator_name]=actuator
+ # set the passed gains and limits into the simulation
+ ifisinstance(actuator,ImplicitActuator):
+ self._has_implicit_actuators=True
+ # the gains and limits are set into the simulation since actuator model is implicit
+ self.write_joint_stiffness_to_sim(actuator.stiffness,joint_ids=actuator.joint_indices)
+ self.write_joint_damping_to_sim(actuator.damping,joint_ids=actuator.joint_indices)
+ self.write_joint_effort_limit_to_sim(actuator.effort_limit,joint_ids=actuator.joint_indices)
+ self.write_joint_armature_to_sim(actuator.armature,joint_ids=actuator.joint_indices)
+ self.write_joint_friction_to_sim(actuator.friction,joint_ids=actuator.joint_indices)
+ else:
+ # the gains and limits are processed by the actuator model
+ # we set gains to zero, and torque limit to a high value in simulation to avoid any interference
+ self.write_joint_stiffness_to_sim(0.0,joint_ids=actuator.joint_indices)
+ self.write_joint_damping_to_sim(0.0,joint_ids=actuator.joint_indices)
+ self.write_joint_effort_limit_to_sim(1.0e9,joint_ids=actuator.joint_indices)
+ self.write_joint_armature_to_sim(actuator.armature,joint_ids=actuator.joint_indices)
+ self.write_joint_friction_to_sim(actuator.friction,joint_ids=actuator.joint_indices)
+
+ # set the default joint parameters based on the changes from the actuators
+ self._data.default_joint_stiffness=self.root_physx_view.get_dof_stiffnesses().to(device=self.device).clone()
+ self._data.default_joint_damping=self.root_physx_view.get_dof_dampings().to(device=self.device).clone()
+ self._data.default_joint_armature=self.root_physx_view.get_dof_armatures().to(device=self.device).clone()
+ self._data.default_joint_friction=(
+ self.root_physx_view.get_dof_friction_coefficients().to(device=self.device).clone()
+ )
+
+ # perform some sanity checks to ensure actuators are prepared correctly
+ total_act_joints=sum(actuator.num_jointsforactuatorinself.actuators.values())
+ iftotal_act_joints!=(self.num_joints-self.num_fixed_tendons):
+ carb.log_warn(
+ "Not all actuators are configured! Total number of actuated joints not equal to number of"
+ f" joints available: {total_act_joints} != {self.num_joints-self.num_fixed_tendons}."
+ )
+
+ def_process_fixed_tendons(self):
+"""Process fixed tendons."""
+ # create a list to store the fixed tendon names
+ self._fixed_tendon_names=list()
+
+ # parse fixed tendons properties if they exist
+ ifself.num_fixed_tendons>0:
+ stage=stage_utils.get_current_stage()
+
+ # iterate over all joints to find tendons attached to them
+ forjinrange(self.num_joints):
+ usd_joint_path=self.root_physx_view.dof_paths[0][j]
+ # check whether joint has tendons - tendon name follows the joint name it is attached to
+ joint=UsdPhysics.Joint.Get(stage,usd_joint_path)
+ ifjoint.GetPrim().HasAPI(PhysxSchema.PhysxTendonAxisRootAPI):
+ joint_name=usd_joint_path.split("/")[-1]
+ self._fixed_tendon_names.append(joint_name)
+
+ self._data.fixed_tendon_names=self._fixed_tendon_names
+ self._data.default_fixed_tendon_stiffness=self.root_physx_view.get_fixed_tendon_stiffnesses().clone()
+ self._data.default_fixed_tendon_damping=self.root_physx_view.get_fixed_tendon_dampings().clone()
+ self._data.default_fixed_tendon_limit_stiffness=(
+ self.root_physx_view.get_fixed_tendon_limit_stiffnesses().clone()
+ )
+ self._data.default_fixed_tendon_limit=self.root_physx_view.get_fixed_tendon_limits().clone()
+ self._data.default_fixed_tendon_rest_length=self.root_physx_view.get_fixed_tendon_rest_lengths().clone()
+ self._data.default_fixed_tendon_offset=self.root_physx_view.get_fixed_tendon_offsets().clone()
+
+ def_apply_actuator_model(self):
+"""Processes joint commands for the articulation by forwarding them to the actuators.
+
+ The actions are first processed using actuator models. Depending on the robot configuration,
+ the actuator models compute the joint level simulation commands and sets them into the PhysX buffers.
+ """
+ # process actions per group
+ foractuatorinself.actuators.values():
+ # prepare input for actuator model based on cached data
+ # TODO : A tensor dict would be nice to do the indexing of all tensors together
+ control_action=ArticulationActions(
+ joint_positions=self._data.joint_pos_target[:,actuator.joint_indices],
+ joint_velocities=self._data.joint_vel_target[:,actuator.joint_indices],
+ joint_efforts=self._data.joint_effort_target[:,actuator.joint_indices],
+ joint_indices=actuator.joint_indices,
+ )
+ # compute joint command from the actuator model
+ control_action=actuator.compute(
+ control_action,
+ joint_pos=self._data.joint_pos[:,actuator.joint_indices],
+ joint_vel=self._data.joint_vel[:,actuator.joint_indices],
+ )
+ # update targets (these are set into the simulation)
+ ifcontrol_action.joint_positionsisnotNone:
+ self._joint_pos_target_sim[:,actuator.joint_indices]=control_action.joint_positions
+ ifcontrol_action.joint_velocitiesisnotNone:
+ self._joint_vel_target_sim[:,actuator.joint_indices]=control_action.joint_velocities
+ ifcontrol_action.joint_effortsisnotNone:
+ self._joint_effort_target_sim[:,actuator.joint_indices]=control_action.joint_efforts
+ # update state of the actuator model
+ # -- torques
+ self._data.computed_torque[:,actuator.joint_indices]=actuator.computed_effort
+ self._data.applied_torque[:,actuator.joint_indices]=actuator.applied_effort
+ # -- actuator data
+ self._data.soft_joint_vel_limits[:,actuator.joint_indices]=actuator.velocity_limit
+ # TODO: find a cleaner way to handle gear ratio. Only needed for variable gear ratio actuators.
+ ifhasattr(actuator,"gear_ratio"):
+ self._data.gear_ratio[:,actuator.joint_indices]=actuator.gear_ratio
+
+"""
+ Internal helpers -- Debugging.
+ """
+
+ def_validate_cfg(self):
+"""Validate the configuration after processing.
+
+ Note:
+ This function should be called only after the configuration has been processed and the buffers have been
+ created. Otherwise, some settings that are altered during processing may not be validated.
+ For instance, the actuator models may change the joint max velocity limits.
+ """
+ # check that the default values are within the limits
+ joint_pos_limits=self.root_physx_view.get_dof_limits()[0].to(self.device)
+ out_of_range=self._data.default_joint_pos[0]<joint_pos_limits[:,0]
+ out_of_range|=self._data.default_joint_pos[0]>joint_pos_limits[:,1]
+ violated_indices=torch.nonzero(out_of_range,as_tuple=False).squeeze(-1)
+ # throw error if any of the default joint positions are out of the limits
+ iflen(violated_indices)>0:
+ # prepare message for violated joints
+ msg="The following joints have default positions out of the limits: \n"
+ foridxinviolated_indices:
+ joint_name=self.data.joint_names[idx]
+ joint_limits=joint_pos_limits[idx]
+ joint_pos=self.data.default_joint_pos[0,idx]
+ # add to message
+ msg+=f"\t- '{joint_name}': {joint_pos:.3f} not in [{joint_limits[0]:.3f}, {joint_limits[1]:.3f}]\n"
+ raiseValueError(msg)
+
+ # check that the default joint velocities are within the limits
+ joint_max_vel=self.root_physx_view.get_dof_max_velocities()[0].to(self.device)
+ out_of_range=torch.abs(self._data.default_joint_vel[0])>joint_max_vel
+ violated_indices=torch.nonzero(out_of_range,as_tuple=False).squeeze(-1)
+ iflen(violated_indices)>0:
+ # prepare message for violated joints
+ msg="The following joints have default velocities out of the limits: \n"
+ foridxinviolated_indices:
+ joint_name=self.data.joint_names[idx]
+ joint_limits=[-joint_max_vel[idx],joint_max_vel[idx]]
+ joint_vel=self.data.default_joint_vel[0,idx]
+ # add to message
+ msg+=f"\t- '{joint_name}': {joint_vel:.3f} not in [{joint_limits[0]:.3f}, {joint_limits[1]:.3f}]\n"
+ raiseValueError(msg)
+
+ def_log_articulation_joint_info(self):
+"""Log information about the articulation's simulated joints."""
+ # read out all joint parameters from simulation
+ # -- gains
+ stiffnesses=self.root_physx_view.get_dof_stiffnesses()[0].tolist()
+ dampings=self.root_physx_view.get_dof_dampings()[0].tolist()
+ # -- properties
+ armatures=self.root_physx_view.get_dof_armatures()[0].tolist()
+ frictions=self.root_physx_view.get_dof_friction_coefficients()[0].tolist()
+ # -- limits
+ position_limits=self.root_physx_view.get_dof_limits()[0].tolist()
+ velocity_limits=self.root_physx_view.get_dof_max_velocities()[0].tolist()
+ effort_limits=self.root_physx_view.get_dof_max_forces()[0].tolist()
+ # create table for term information
+ table=PrettyTable(float_format=".3f")
+ table.title=f"Simulation Joint Information (Prim path: {self.cfg.prim_path})"
+ table.field_names=[
+ "Index",
+ "Name",
+ "Stiffness",
+ "Damping",
+ "Armature",
+ "Friction",
+ "Position Limits",
+ "Velocity Limits",
+ "Effort Limits",
+ ]
+ # set alignment of table columns
+ table.align["Name"]="l"
+ # add info on each term
+ forindex,nameinenumerate(self.joint_names):
+ table.add_row([
+ index,
+ name,
+ stiffnesses[index],
+ dampings[index],
+ armatures[index],
+ frictions[index],
+ position_limits[index],
+ velocity_limits[index],
+ effort_limits[index],
+ ])
+ # convert table to string
+ carb.log_info(f"Simulation parameters for joints in {self.cfg.prim_path}:\n"+table.get_string())
+
+ # read out all tendon parameters from simulation
+ ifself.num_fixed_tendons>0:
+ # -- gains
+ ft_stiffnesses=self.root_physx_view.get_fixed_tendon_stiffnesses()[0].tolist()
+ ft_dampings=self.root_physx_view.get_fixed_tendon_dampings()[0].tolist()
+ # -- limits
+ ft_limit_stiffnesses=self.root_physx_view.get_fixed_tendon_limit_stiffnesses()[0].tolist()
+ ft_limits=self.root_physx_view.get_fixed_tendon_limits()[0].tolist()
+ ft_rest_lengths=self.root_physx_view.get_fixed_tendon_rest_lengths()[0].tolist()
+ ft_offsets=self.root_physx_view.get_fixed_tendon_offsets()[0].tolist()
+ # create table for term information
+ tendon_table=PrettyTable(float_format=".3f")
+ tendon_table.title=f"Simulation Tendon Information (Prim path: {self.cfg.prim_path})"
+ tendon_table.field_names=[
+ "Index",
+ "Stiffness",
+ "Damping",
+ "Limit Stiffness",
+ "Limit",
+ "Rest Length",
+ "Offset",
+ ]
+ # add info on each term
+ forindexinrange(self.num_fixed_tendons):
+ tendon_table.add_row([
+ index,
+ ft_stiffnesses[index],
+ ft_dampings[index],
+ ft_limit_stiffnesses[index],
+ ft_limits[index],
+ ft_rest_lengths[index],
+ ft_offsets[index],
+ ])
+ # convert table to string
+ carb.log_info(f"Simulation parameters for tendons in {self.cfg.prim_path}:\n"+tendon_table.get_string())
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.actuatorsimportActuatorBaseCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+from..asset_base_cfgimportAssetBaseCfg
+from.articulationimportArticulation
+
+
+
[文档]@configclass
+classArticulationCfg(AssetBaseCfg):
+"""Configuration parameters for an articulation."""
+
+
[文档]@configclass
+ classInitialStateCfg(AssetBaseCfg.InitialStateCfg):
+"""Initial state of the articulation."""
+
+ # root velocity
+ lin_vel:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Linear velocity of the root in simulation world frame. Defaults to (0.0, 0.0, 0.0)."""
+ ang_vel:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Angular velocity of the root in simulation world frame. Defaults to (0.0, 0.0, 0.0)."""
+
+ # joint state
+ joint_pos:dict[str,float]={".*":0.0}
+"""Joint positions of the joints. Defaults to 0.0 for all joints."""
+ joint_vel:dict[str,float]={".*":0.0}
+"""Joint velocities of the joints. Defaults to 0.0 for all joints."""
+
+ ##
+ # Initialize configurations.
+ ##
+
+ class_type:type=Articulation
+
+ init_state:InitialStateCfg=InitialStateCfg()
+"""Initial state of the articulated object. Defaults to identity pose with zero velocity and zero joint state."""
+
+ soft_joint_pos_limit_factor:float=1.0
+"""Fraction specifying the range of DOF position limits (parsed from the asset) to use. Defaults to 1.0.
+
+ The joint position limits are scaled by this factor to allow for a limited range of motion.
+ This is accessible in the articulation data through :attr:`ArticulationData.soft_joint_pos_limits` attribute.
+ """
+
+ actuators:dict[str,ActuatorBaseCfg]=MISSING
+"""Actuators for the robot with corresponding joint names."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+importweakref
+
+importomni.physics.tensors.impl.apiasphysx
+
+importomni.isaac.lab.utils.mathasmath_utils
+fromomni.isaac.lab.utils.buffersimportTimestampedBuffer
+
+
+
[文档]classArticulationData:
+"""Data container for an articulation.
+
+ This class contains the data for an articulation in the simulation. The data includes the state of
+ the root rigid body, the state of all the bodies in the articulation, and the joint state. The data is
+ stored in the simulation world frame unless otherwise specified.
+
+ An articulation is comprised of multiple rigid bodies or links. For a rigid body, there are two frames
+ of reference that are used:
+
+ - Actor frame: The frame of reference of the rigid body prim. This typically corresponds to the Xform prim
+ with the rigid body schema.
+ - Center of mass frame: The frame of reference of the center of mass of the rigid body.
+
+ Depending on the settings, the two frames may not coincide with each other. In the robotics sense, the actor frame
+ can be interpreted as the link frame.
+ """
+
+ def__init__(self,root_physx_view:physx.ArticulationView,device:str):
+"""Initializes the articulation data.
+
+ Args:
+ root_physx_view: The root articulation view.
+ device: The device used for processing.
+ """
+ # Set the parameters
+ self.device=device
+ # Set the root articulation view
+ # note: this is stored as a weak reference to avoid circular references between the asset class
+ # and the data container. This is important to avoid memory leaks.
+ self._root_physx_view:physx.ArticulationView=weakref.proxy(root_physx_view)
+
+ # Set initial time stamp
+ self._sim_timestamp=0.0
+
+ # Obtain global physics sim view
+ self._physics_sim_view=physx.create_simulation_view("torch")
+ self._physics_sim_view.set_subspace_roots("/")
+ gravity=self._physics_sim_view.get_gravity()
+ # Convert to direction vector
+ gravity_dir=torch.tensor((gravity[0],gravity[1],gravity[2]),device=self.device)
+ gravity_dir=math_utils.normalize(gravity_dir.unsqueeze(0)).squeeze(0)
+
+ # Initialize constants
+ self.GRAVITY_VEC_W=gravity_dir.repeat(self._root_physx_view.count,1)
+ self.FORWARD_VEC_B=torch.tensor((1.0,0.0,0.0),device=self.device).repeat(self._root_physx_view.count,1)
+
+ # Initialize history for finite differencing
+ self._previous_joint_vel=self._root_physx_view.get_dof_velocities().clone()
+
+ # Initialize the lazy buffers.
+ self._root_state_w=TimestampedBuffer()
+ self._body_state_w=TimestampedBuffer()
+ self._body_acc_w=TimestampedBuffer()
+ self._joint_pos=TimestampedBuffer()
+ self._joint_acc=TimestampedBuffer()
+ self._joint_vel=TimestampedBuffer()
+
+ defupdate(self,dt:float):
+ # update the simulation timestamp
+ self._sim_timestamp+=dt
+ # Trigger an update of the joint acceleration buffer at a higher frequency
+ # since we do finite differencing.
+ self.joint_acc
+
+ ##
+ # Names.
+ ##
+
+ body_names:list[str]=None
+"""Body names in the order parsed by the simulation view."""
+
+ joint_names:list[str]=None
+"""Joint names in the order parsed by the simulation view."""
+
+ fixed_tendon_names:list[str]=None
+"""Fixed tendon names in the order parsed by the simulation view."""
+
+ ##
+ # Defaults.
+ ##
+
+ default_root_state:torch.Tensor=None
+"""Default root state ``[pos, quat, lin_vel, ang_vel]`` in local environment frame. Shape is (num_instances, 13).
+
+ The position and quaternion are of the articulation root's actor frame. Meanwhile, the linear and angular
+ velocities are of its center of mass frame.
+ """
+
+ default_mass:torch.Tensor=None
+"""Default mass read from the simulation. Shape is (num_instances, num_bodies)."""
+
+ default_inertia:torch.Tensor=None
+"""Default inertia read from the simulation. Shape is (num_instances, num_bodies, 9).
+
+ The inertia is the inertia tensor relative to the center of mass frame. The values are stored in
+ the order :math:`[I_{xx}, I_{xy}, I_{xz}, I_{yx}, I_{yy}, I_{yz}, I_{zx}, I_{zy}, I_{zz}]`.
+ """
+
+ default_joint_pos:torch.Tensor=None
+"""Default joint positions of all joints. Shape is (num_instances, num_joints)."""
+
+ default_joint_vel:torch.Tensor=None
+"""Default joint velocities of all joints. Shape is (num_instances, num_joints)."""
+
+ default_joint_stiffness:torch.Tensor=None
+"""Default joint stiffness of all joints. Shape is (num_instances, num_joints)."""
+
+ default_joint_damping:torch.Tensor=None
+"""Default joint damping of all joints. Shape is (num_instances, num_joints)."""
+
+ default_joint_armature:torch.Tensor=None
+"""Default joint armature of all joints. Shape is (num_instances, num_joints)."""
+
+ default_joint_friction:torch.Tensor=None
+"""Default joint friction of all joints. Shape is (num_instances, num_joints)."""
+
+ default_joint_limits:torch.Tensor=None
+"""Default joint limits of all joints. Shape is (num_instances, num_joints, 2)."""
+
+ default_fixed_tendon_stiffness:torch.Tensor=None
+"""Default tendon stiffness of all tendons. Shape is (num_instances, num_fixed_tendons)."""
+
+ default_fixed_tendon_damping:torch.Tensor=None
+"""Default tendon damping of all tendons. Shape is (num_instances, num_fixed_tendons)."""
+
+ default_fixed_tendon_limit_stiffness:torch.Tensor=None
+"""Default tendon limit stiffness of all tendons. Shape is (num_instances, num_fixed_tendons)."""
+
+ default_fixed_tendon_rest_length:torch.Tensor=None
+"""Default tendon rest length of all tendons. Shape is (num_instances, num_fixed_tendons)."""
+
+ default_fixed_tendon_offset:torch.Tensor=None
+"""Default tendon offset of all tendons. Shape is (num_instances, num_fixed_tendons)."""
+
+ default_fixed_tendon_limit:torch.Tensor=None
+"""Default tendon limits of all tendons. Shape is (num_instances, num_fixed_tendons, 2)."""
+
+ ##
+ # Joint commands -- Set into simulation.
+ ##
+
+ joint_pos_target:torch.Tensor=None
+"""Joint position targets commanded by the user. Shape is (num_instances, num_joints).
+
+ For an implicit actuator model, the targets are directly set into the simulation.
+ For an explicit actuator model, the targets are used to compute the joint torques (see :attr:`applied_torque`),
+ which are then set into the simulation.
+ """
+
+ joint_vel_target:torch.Tensor=None
+"""Joint velocity targets commanded by the user. Shape is (num_instances, num_joints).
+
+ For an implicit actuator model, the targets are directly set into the simulation.
+ For an explicit actuator model, the targets are used to compute the joint torques (see :attr:`applied_torque`),
+ which are then set into the simulation.
+ """
+
+ joint_effort_target:torch.Tensor=None
+"""Joint effort targets commanded by the user. Shape is (num_instances, num_joints).
+
+ For an implicit actuator model, the targets are directly set into the simulation.
+ For an explicit actuator model, the targets are used to compute the joint torques (see :attr:`applied_torque`),
+ which are then set into the simulation.
+ """
+
+ ##
+ # Joint commands -- Explicit actuators.
+ ##
+
+ computed_torque:torch.Tensor=None
+"""Joint torques computed from the actuator model (before clipping). Shape is (num_instances, num_joints).
+
+ This quantity is the raw torque output from the actuator mode, before any clipping is applied.
+ It is exposed for users who want to inspect the computations inside the actuator model.
+ For instance, to penalize the learning agent for a difference between the computed and applied torques.
+
+ Note: The torques are zero for implicit actuator models.
+ """
+
+ applied_torque:torch.Tensor=None
+"""Joint torques applied from the actuator model (after clipping). Shape is (num_instances, num_joints).
+
+ These torques are set into the simulation, after clipping the :attr:`computed_torque` based on the
+ actuator model.
+
+ Note: The torques are zero for implicit actuator models.
+ """
+
+ ##
+ # Joint properties.
+ ##
+
+ joint_stiffness:torch.Tensor=None
+"""Joint stiffness provided to simulation. Shape is (num_instances, num_joints)."""
+
+ joint_damping:torch.Tensor=None
+"""Joint damping provided to simulation. Shape is (num_instances, num_joints)."""
+
+ joint_armature:torch.Tensor=None
+"""Joint armature provided to simulation. Shape is (num_instances, num_joints)."""
+
+ joint_friction:torch.Tensor=None
+"""Joint friction provided to simulation. Shape is (num_instances, num_joints)."""
+
+ joint_limits:torch.Tensor=None
+"""Joint limits provided to simulation. Shape is (num_instances, num_joints, 2)."""
+
+ ##
+ # Fixed tendon properties.
+ ##
+
+ fixed_tendon_stiffness:torch.Tensor=None
+"""Fixed tendon stiffness provided to simulation. Shape is (num_instances, num_fixed_tendons)."""
+
+ fixed_tendon_damping:torch.Tensor=None
+"""Fixed tendon damping provided to simulation. Shape is (num_instances, num_fixed_tendons)."""
+
+ fixed_tendon_limit_stiffness:torch.Tensor=None
+"""Fixed tendon limit stiffness provided to simulation. Shape is (num_instances, num_fixed_tendons)."""
+
+ fixed_tendon_rest_length:torch.Tensor=None
+"""Fixed tendon rest length provided to simulation. Shape is (num_instances, num_fixed_tendons)."""
+
+ fixed_tendon_offset:torch.Tensor=None
+"""Fixed tendon offset provided to simulation. Shape is (num_instances, num_fixed_tendons)."""
+
+ fixed_tendon_limit:torch.Tensor=None
+"""Fixed tendon limits provided to simulation. Shape is (num_instances, num_fixed_tendons, 2)."""
+
+ ##
+ # Other Data.
+ ##
+
+ soft_joint_pos_limits:torch.Tensor=None
+"""Joint positions limits for all joints. Shape is (num_instances, num_joints, 2)."""
+
+ soft_joint_vel_limits:torch.Tensor=None
+"""Joint velocity limits for all joints. Shape is (num_instances, num_joints)."""
+
+ gear_ratio:torch.Tensor=None
+"""Gear ratio for relating motor torques to applied Joint torques. Shape is (num_instances, num_joints)."""
+
+ ##
+ # Properties.
+ ##
+
+ @property
+ defroot_state_w(self):
+"""Root state ``[pos, quat, lin_vel, ang_vel]`` in simulation world frame. Shape is (num_instances, 13).
+
+ The position and quaternion are of the articulation root's actor frame. Meanwhile, the linear and angular
+ velocities are of the articulation root's center of mass frame.
+ """
+ ifself._root_state_w.timestamp<self._sim_timestamp:
+ # read data from simulation
+ pose=self._root_physx_view.get_root_transforms().clone()
+ pose[:,3:7]=math_utils.convert_quat(pose[:,3:7],to="wxyz")
+ velocity=self._root_physx_view.get_root_velocities()
+ # set the buffer data and timestamp
+ self._root_state_w.data=torch.cat((pose,velocity),dim=-1)
+ self._root_state_w.timestamp=self._sim_timestamp
+ returnself._root_state_w.data
+
+ @property
+ defbody_state_w(self):
+"""State of all bodies `[pos, quat, lin_vel, ang_vel]` in simulation world frame.
+ Shape is (num_instances, num_bodies, 13).
+
+ The position and quaternion are of all the articulation links's actor frame. Meanwhile, the linear and angular
+ velocities are of the articulation links's center of mass frame.
+ """
+ ifself._body_state_w.timestamp<self._sim_timestamp:
+ self._physics_sim_view.update_articulations_kinematic()
+ # read data from simulation
+ poses=self._root_physx_view.get_link_transforms().clone()
+ poses[...,3:7]=math_utils.convert_quat(poses[...,3:7],to="wxyz")
+ velocities=self._root_physx_view.get_link_velocities()
+ # set the buffer data and timestamp
+ self._body_state_w.data=torch.cat((poses,velocities),dim=-1)
+ self._body_state_w.timestamp=self._sim_timestamp
+ returnself._body_state_w.data
+
+ @property
+ defbody_acc_w(self):
+"""Acceleration of all bodies. Shape is (num_instances, num_bodies, 6).
+
+ This quantity is the acceleration of the articulation links' center of mass frame.
+ """
+ ifself._body_acc_w.timestamp<self._sim_timestamp:
+ # read data from simulation and set the buffer data and timestamp
+ self._body_acc_w.data=self._root_physx_view.get_link_accelerations()
+ self._body_acc_w.timestamp=self._sim_timestamp
+ returnself._body_acc_w.data
+
+ @property
+ defprojected_gravity_b(self):
+"""Projection of the gravity direction on base frame. Shape is (num_instances, 3)."""
+ returnmath_utils.quat_rotate_inverse(self.root_quat_w,self.GRAVITY_VEC_W)
+
+ @property
+ defheading_w(self):
+"""Yaw heading of the base frame (in radians). Shape is (num_instances,).
+
+ Note:
+ This quantity is computed by assuming that the forward-direction of the base
+ frame is along x-direction, i.e. :math:`(1, 0, 0)`.
+ """
+ forward_w=math_utils.quat_apply(self.root_quat_w,self.FORWARD_VEC_B)
+ returntorch.atan2(forward_w[:,1],forward_w[:,0])
+
+ @property
+ defjoint_pos(self):
+"""Joint positions of all joints. Shape is (num_instances, num_joints)."""
+ ifself._joint_pos.timestamp<self._sim_timestamp:
+ # read data from simulation and set the buffer data and timestamp
+ self._joint_pos.data=self._root_physx_view.get_dof_positions()
+ self._joint_pos.timestamp=self._sim_timestamp
+ returnself._joint_pos.data
+
+ @property
+ defjoint_vel(self):
+"""Joint velocities of all joints. Shape is (num_instances, num_joints)."""
+ ifself._joint_vel.timestamp<self._sim_timestamp:
+ # read data from simulation and set the buffer data and timestamp
+ self._joint_vel.data=self._root_physx_view.get_dof_velocities()
+ self._joint_vel.timestamp=self._sim_timestamp
+ returnself._joint_vel.data
+
+ @property
+ defjoint_acc(self):
+"""Joint acceleration of all joints. Shape is (num_instances, num_joints)."""
+ ifself._joint_acc.timestamp<self._sim_timestamp:
+ # note: we use finite differencing to compute acceleration
+ time_elapsed=self._sim_timestamp-self._joint_acc.timestamp
+ self._joint_acc.data=(self.joint_vel-self._previous_joint_vel)/time_elapsed
+ self._joint_acc.timestamp=self._sim_timestamp
+ # update the previous joint velocity
+ self._previous_joint_vel[:]=self.joint_vel
+ returnself._joint_acc.data
+
+ ##
+ # Derived properties.
+ ##
+
+ @property
+ defroot_pos_w(self)->torch.Tensor:
+"""Root position in simulation world frame. Shape is (num_instances, 3).
+
+ This quantity is the position of the actor frame of the articulation root.
+ """
+ returnself.root_state_w[:,:3]
+
+ @property
+ defroot_quat_w(self)->torch.Tensor:
+"""Root orientation (w, x, y, z) in simulation world frame. Shape is (num_instances, 4).
+
+ This quantity is the orientation of the actor frame of the articulation root.
+ """
+ returnself.root_state_w[:,3:7]
+
+ @property
+ defroot_vel_w(self)->torch.Tensor:
+"""Root velocity in simulation world frame. Shape is (num_instances, 6).
+
+ This quantity contains the linear and angular velocities of the articulation root's center of
+ mass frame.
+ """
+ returnself.root_state_w[:,7:13]
+
+ @property
+ defroot_lin_vel_w(self)->torch.Tensor:
+"""Root linear velocity in simulation world frame. Shape is (num_instances, 3).
+
+ This quantity is the linear velocity of the articulation root's center of mass frame.
+ """
+ returnself.root_state_w[:,7:10]
+
+ @property
+ defroot_ang_vel_w(self)->torch.Tensor:
+"""Root angular velocity in simulation world frame. Shape is (num_instances, 3).
+
+ This quantity is the angular velocity of the articulation root's center of mass frame.
+ """
+ returnself.root_state_w[:,10:13]
+
+ @property
+ defroot_lin_vel_b(self)->torch.Tensor:
+"""Root linear velocity in base frame. Shape is (num_instances, 3).
+
+ This quantity is the linear velocity of the articulation root's center of mass frame with
+ respect to the articulation root's actor frame.
+ """
+ returnmath_utils.quat_rotate_inverse(self.root_quat_w,self.root_lin_vel_w)
+
+ @property
+ defroot_ang_vel_b(self)->torch.Tensor:
+"""Root angular velocity in base world frame. Shape is (num_instances, 3).
+
+ This quantity is the angular velocity of the articulation root's center of mass frame with respect to the
+ articulation root's actor frame.
+ """
+ returnmath_utils.quat_rotate_inverse(self.root_quat_w,self.root_ang_vel_w)
+
+ @property
+ defbody_pos_w(self)->torch.Tensor:
+"""Positions of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3).
+
+ This quantity is the position of the rigid bodies' actor frame.
+ """
+ returnself.body_state_w[...,:3]
+
+ @property
+ defbody_quat_w(self)->torch.Tensor:
+"""Orientation (w, x, y, z) of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 4).
+
+ This quantity is the orientation of the rigid bodies' actor frame.
+ """
+ returnself.body_state_w[...,3:7]
+
+ @property
+ defbody_vel_w(self)->torch.Tensor:
+"""Velocity of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 6).
+
+ This quantity contains the linear and angular velocities of the rigid bodies' center of mass frame.
+ """
+ returnself.body_state_w[...,7:13]
+
+ @property
+ defbody_lin_vel_w(self)->torch.Tensor:
+"""Linear velocity of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3).
+
+ This quantity is the linear velocity of the rigid bodies' center of mass frame.
+ """
+ returnself.body_state_w[...,7:10]
+
+ @property
+ defbody_ang_vel_w(self)->torch.Tensor:
+"""Angular velocity of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3).
+
+ This quantity is the angular velocity of the rigid bodies' center of mass frame.
+ """
+ returnself.body_state_w[...,10:13]
+
+ @property
+ defbody_lin_acc_w(self)->torch.Tensor:
+"""Linear acceleration of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3).
+
+ This quantity is the linear acceleration of the rigid bodies' center of mass frame.
+ """
+ returnself.body_acc_w[...,0:3]
+
+ @property
+ defbody_ang_acc_w(self)->torch.Tensor:
+"""Angular acceleration of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3).
+
+ This quantity is the angular acceleration of the rigid bodies' center of mass frame.
+ """
+ returnself.body_acc_w[...,3:6]
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importinspect
+importre
+importweakref
+fromabcimportABC,abstractmethod
+fromcollections.abcimportSequence
+fromtypingimportTYPE_CHECKING,Any
+
+importomni.kit.app
+importomni.timeline
+
+importomni.isaac.lab.simassim_utils
+
+ifTYPE_CHECKING:
+ from.asset_base_cfgimportAssetBaseCfg
+
+
+
[文档]classAssetBase(ABC):
+"""The base interface class for assets.
+
+ An asset corresponds to any physics-enabled object that can be spawned in the simulation. These include
+ rigid objects, articulated objects, deformable objects etc. The core functionality of an asset is to
+ provide a set of buffers that can be used to interact with the simulator. The buffers are updated
+ by the asset class and can be written into the simulator using the their respective ``write`` methods.
+ This allows a convenient way to perform post-processing operations on the buffers before writing them
+ into the simulator and obtaining the corresponding simulation results.
+
+ The class handles both the spawning of the asset into the USD stage as well as initialization of necessary
+ physics handles to interact with the asset. Upon construction of the asset instance, the prim corresponding
+ to the asset is spawned into the USD stage if the spawn configuration is not None. The spawn configuration
+ is defined in the :attr:`AssetBaseCfg.spawn` attribute. In case the configured :attr:`AssetBaseCfg.prim_path`
+ is an expression, then the prim is spawned at all the matching paths. Otherwise, a single prim is spawned
+ at the configured path. For more information on the spawn configuration, see the
+ :mod:`omni.isaac.lab.sim.spawners` module.
+
+ Unlike Isaac Sim interface, where one usually needs to call the
+ :meth:`omni.isaac.core.prims.XFormPrimView.initialize` method to initialize the PhysX handles, the asset
+ class automatically initializes and invalidates the PhysX handles when the stage is played/stopped. This
+ is done by registering callbacks for the stage play/stop events.
+
+ Additionally, the class registers a callback for debug visualization of the asset if a debug visualization
+ is implemented in the asset class. This can be enabled by setting the :attr:`AssetBaseCfg.debug_vis` attribute
+ to True. The debug visualization is implemented through the :meth:`_set_debug_vis_impl` and
+ :meth:`_debug_vis_callback` methods.
+ """
+
+
[文档]def__init__(self,cfg:AssetBaseCfg):
+"""Initialize the asset base.
+
+ Args:
+ cfg: The configuration class for the asset.
+
+ Raises:
+ RuntimeError: If no prims found at input prim path or prim path expression.
+ """
+ # store inputs
+ self.cfg=cfg
+ # flag for whether the asset is initialized
+ self._is_initialized=False
+
+ # check if base asset path is valid
+ # note: currently the spawner does not work if there is a regex pattern in the leaf
+ # For example, if the prim path is "/World/Robot_[1,2]" since the spawner will not
+ # know which prim to spawn. This is a limitation of the spawner and not the asset.
+ asset_path=self.cfg.prim_path.split("/")[-1]
+ asset_path_is_regex=re.match(r"^[a-zA-Z0-9/_]+$",asset_path)isNone
+ # spawn the asset
+ ifself.cfg.spawnisnotNoneandnotasset_path_is_regex:
+ self.cfg.spawn.func(
+ self.cfg.prim_path,
+ self.cfg.spawn,
+ translation=self.cfg.init_state.pos,
+ orientation=self.cfg.init_state.rot,
+ )
+ # check that spawn was successful
+ matching_prims=sim_utils.find_matching_prims(self.cfg.prim_path)
+ iflen(matching_prims)==0:
+ raiseRuntimeError(f"Could not find prim with path {self.cfg.prim_path}.")
+
+ # note: Use weakref on all callbacks to ensure that this object can be deleted when its destructor is called.
+ # add callbacks for stage play/stop
+ # The order is set to 10 which is arbitrary but should be lower priority than the default order of 0
+ timeline_event_stream=omni.timeline.get_timeline_interface().get_timeline_event_stream()
+ self._initialize_handle=timeline_event_stream.create_subscription_to_pop_by_type(
+ int(omni.timeline.TimelineEventType.PLAY),
+ lambdaevent,obj=weakref.proxy(self):obj._initialize_callback(event),
+ order=10,
+ )
+ self._invalidate_initialize_handle=timeline_event_stream.create_subscription_to_pop_by_type(
+ int(omni.timeline.TimelineEventType.STOP),
+ lambdaevent,obj=weakref.proxy(self):obj._invalidate_initialize_callback(event),
+ order=10,
+ )
+ # add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
+ self._debug_vis_handle=None
+ # set initial state of debug visualization
+ self.set_debug_vis(self.cfg.debug_vis)
+
+ def__del__(self):
+"""Unsubscribe from the callbacks."""
+ # clear physics events handles
+ ifself._initialize_handle:
+ self._initialize_handle.unsubscribe()
+ self._initialize_handle=None
+ ifself._invalidate_initialize_handle:
+ self._invalidate_initialize_handle.unsubscribe()
+ self._invalidate_initialize_handle=None
+ # clear debug visualization
+ ifself._debug_vis_handle:
+ self._debug_vis_handle.unsubscribe()
+ self._debug_vis_handle=None
+
+"""
+ Properties
+ """
+
+ @property
+ defis_initialized(self)->bool:
+"""Whether the asset is initialized.
+
+ Returns True if the asset is initialized, False otherwise.
+ """
+ returnself._is_initialized
+
+ @property
+ @abstractmethod
+ defnum_instances(self)->int:
+"""Number of instances of the asset.
+
+ This is equal to the number of asset instances per environment multiplied by the number of environments.
+ """
+ returnNotImplementedError
+
+ @property
+ defdevice(self)->str:
+"""Memory device for computation."""
+ returnself._device
+
+ @property
+ @abstractmethod
+ defdata(self)->Any:
+"""Data related to the asset."""
+ returnNotImplementedError
+
+ @property
+ defhas_debug_vis_implementation(self)->bool:
+"""Whether the asset has a debug visualization implemented."""
+ # check if function raises NotImplementedError
+ source_code=inspect.getsource(self._set_debug_vis_impl)
+ return"NotImplementedError"notinsource_code
+
+"""
+ Operations.
+ """
+
+
[文档]defset_debug_vis(self,debug_vis:bool)->bool:
+"""Sets whether to visualize the asset data.
+
+ Args:
+ debug_vis: Whether to visualize the asset data.
+
+ Returns:
+ Whether the debug visualization was successfully set. False if the asset
+ does not support debug visualization.
+ """
+ # check if debug visualization is supported
+ ifnotself.has_debug_vis_implementation:
+ returnFalse
+ # toggle debug visualization objects
+ self._set_debug_vis_impl(debug_vis)
+ # toggle debug visualization handles
+ ifdebug_vis:
+ # create a subscriber for the post update event if it doesn't exist
+ ifself._debug_vis_handleisNone:
+ app_interface=omni.kit.app.get_app_interface()
+ self._debug_vis_handle=app_interface.get_post_update_event_stream().create_subscription_to_pop(
+ lambdaevent,obj=weakref.proxy(self):obj._debug_vis_callback(event)
+ )
+ else:
+ # remove the subscriber if it exists
+ ifself._debug_vis_handleisnotNone:
+ self._debug_vis_handle.unsubscribe()
+ self._debug_vis_handle=None
+ # return success
+ returnTrue
+
+
[文档]@abstractmethod
+ defreset(self,env_ids:Sequence[int]|None=None):
+"""Resets all internal buffers of selected environments.
+
+ Args:
+ env_ids: The indices of the object to reset. Defaults to None (all instances).
+ """
+ raiseNotImplementedError
+
+
[文档]@abstractmethod
+ defwrite_data_to_sim(self):
+"""Writes data to the simulator."""
+ raiseNotImplementedError
+
+
[文档]@abstractmethod
+ defupdate(self,dt:float):
+"""Update the internal buffers.
+
+ The time step ``dt`` is used to compute numerical derivatives of quantities such as joint
+ accelerations which are not provided by the simulator.
+
+ Args:
+ dt: The amount of time passed from last ``update`` call.
+ """
+ raiseNotImplementedError
+
+"""
+ Implementation specific.
+ """
+
+ @abstractmethod
+ def_initialize_impl(self):
+"""Initializes the PhysX handles and internal buffers."""
+ raiseNotImplementedError
+
+ def_set_debug_vis_impl(self,debug_vis:bool):
+"""Set debug visualization into visualization objects.
+
+ This function is responsible for creating the visualization objects if they don't exist
+ and input ``debug_vis`` is True. If the visualization objects exist, the function should
+ set their visibility into the stage.
+ """
+ raiseNotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
+
+ def_debug_vis_callback(self,event):
+"""Callback for debug visualization.
+
+ This function calls the visualization objects and sets the data to visualize into them.
+ """
+ raiseNotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
+
+"""
+ Internal simulation callbacks.
+ """
+
+ def_initialize_callback(self,event):
+"""Initializes the scene elements.
+
+ Note:
+ PhysX handles are only enabled once the simulator starts playing. Hence, this function needs to be
+ called whenever the simulator "plays" from a "stop" state.
+ """
+ ifnotself._is_initialized:
+ # obtain simulation related information
+ sim=sim_utils.SimulationContext.instance()
+ ifsimisNone:
+ raiseRuntimeError("SimulationContext is not initialized! Please initialize SimulationContext first.")
+ self._backend=sim.backend
+ self._device=sim.device
+ # initialize the asset
+ self._initialize_impl()
+ # set flag
+ self._is_initialized=True
+
+ def_invalidate_initialize_callback(self,event):
+"""Invalidates the scene elements."""
+ self._is_initialized=False
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.simimportSpawnerCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.asset_baseimportAssetBase
+
+
+
[文档]@configclass
+classAssetBaseCfg:
+"""The base configuration class for an asset's parameters.
+
+ Please see the :class:`AssetBase` class for more information on the asset class.
+ """
+
+
[文档]@configclass
+ classInitialStateCfg:
+"""Initial state of the asset.
+
+ This defines the default initial state of the asset when it is spawned into the simulation, as
+ well as the default state when the simulation is reset.
+
+ After parsing the initial state, the asset class stores this information in the :attr:`data`
+ attribute of the asset class. This can then be accessed by the user to modify the state of the asset
+ during the simulation, for example, at resets.
+ """
+
+ # root position
+ pos:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Position of the root in simulation world frame. Defaults to (0.0, 0.0, 0.0)."""
+ rot:tuple[float,float,float,float]=(1.0,0.0,0.0,0.0)
+"""Quaternion rotation (w, x, y, z) of the root in simulation world frame.
+ Defaults to (1.0, 0.0, 0.0, 0.0).
+ """
+
+ class_type:type[AssetBase]=MISSING
+"""The associated asset class.
+
+ The class should inherit from :class:`omni.isaac.lab.assets.asset_base.AssetBase`.
+ """
+
+ prim_path:str=MISSING
+"""Prim path (or expression) to the asset.
+
+ .. note::
+ The expression can contain the environment namespace regex ``{ENV_REGEX_NS}`` which
+ will be replaced with the environment namespace.
+
+ Example: ``{ENV_REGEX_NS}/Robot`` will be replaced with ``/World/envs/env_.*/Robot``.
+ """
+
+ spawn:SpawnerCfg|None=None
+"""Spawn configuration for the asset. Defaults to None.
+
+ If None, then no prims are spawned by the asset class. Instead, it is assumed that the
+ asset is already present in the scene.
+ """
+
+ init_state:InitialStateCfg=InitialStateCfg()
+"""Initial state of the rigid object. Defaults to identity pose."""
+
+ collision_group:Literal[0,-1]=0
+"""Collision group of the asset. Defaults to ``0``.
+
+ * ``-1``: global collision group (collides with all assets in the scene).
+ * ``0``: local collision group (collides with other assets in the same environment).
+ """
+
+ debug_vis:bool=False
+"""Whether to enable debug visualization for the asset. Defaults to ``False``."""
[文档]classDeformableObject(AssetBase):
+"""A deformable object asset class.
+
+ Deformable objects are assets that can be deformed in the simulation. They are typically used for
+ soft bodies, such as stuffed animals and food items.
+
+ Unlike rigid object assets, deformable objects have a more complex structure and require additional
+ handling for simulation. The simulation of deformable objects follows a finite element approach, where
+ the object is discretized into a mesh of nodes and elements. The nodes are connected by elements, which
+ define the material properties of the object. The nodes can be moved and deformed, and the elements
+ respond to these changes.
+
+ The state of a deformable object comprises of its nodal positions and velocities, and not the object's root
+ position and orientation. The nodal positions and velocities are in the simulation frame.
+
+ Soft bodies can be `partially kinematic`_, where some nodes are driven by kinematic targets, and the rest are
+ simulated. The kinematic targets are the desired positions of the nodes, and the simulation drives the nodes
+ towards these targets. This is useful for partial control of the object, such as moving a stuffed animal's
+ head while the rest of the body is simulated.
+
+ .. attention::
+ This class is experimental and subject to change due to changes on the underlying PhysX API on which
+ it depends. We will try to maintain backward compatibility as much as possible but some changes may be
+ necessary.
+
+ .. _partially kinematic: https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/docs/SoftBodies.html#kinematic-soft-bodies
+ """
+
+ cfg:DeformableObjectCfg
+"""Configuration instance for the deformable object."""
+
+
[文档]def__init__(self,cfg:DeformableObjectCfg):
+"""Initialize the deformable object.
+
+ Args:
+ cfg: A configuration instance.
+ """
+ super().__init__(cfg)
+
+"""
+ Properties
+ """
+
+ @property
+ defdata(self)->DeformableObjectData:
+ returnself._data
+
+ @property
+ defnum_instances(self)->int:
+ returnself.root_physx_view.count
+
+ @property
+ defnum_bodies(self)->int:
+"""Number of bodies in the asset.
+
+ This is always 1 since each object is a single deformable body.
+ """
+ return1
+
+ @property
+ defroot_physx_view(self)->physx.SoftBodyView:
+"""Deformable body view for the asset (PhysX).
+
+ Note:
+ Use this view with caution. It requires handling of tensors in a specific way.
+ """
+ returnself._root_physx_view
+
+ @property
+ defmaterial_physx_view(self)->physx.SoftBodyMaterialView|None:
+"""Deformable material view for the asset (PhysX).
+
+ This view is optional and may not be available if the material is not bound to the deformable body.
+ If the material is not available, then the material properties will be set to default values.
+
+ Note:
+ Use this view with caution. It requires handling of tensors in a specific way.
+ """
+ returnself._material_physx_view
+
+ @property
+ defmax_sim_elements_per_body(self)->int:
+"""The maximum number of simulation mesh elements per deformable body."""
+ returnself.root_physx_view.max_sim_elements_per_body
+
+ @property
+ defmax_collision_elements_per_body(self)->int:
+"""The maximum number of collision mesh elements per deformable body."""
+ returnself.root_physx_view.max_elements_per_body
+
+ @property
+ defmax_sim_vertices_per_body(self)->int:
+"""The maximum number of simulation mesh vertices per deformable body."""
+ returnself.root_physx_view.max_sim_vertices_per_body
+
+ @property
+ defmax_collision_vertices_per_body(self)->int:
+"""The maximum number of collision mesh vertices per deformable body."""
+ returnself.root_physx_view.max_vertices_per_body
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+ # Think: Should we reset the kinematic targets when resetting the object?
+ # This is not done in the current implementation. We assume users will reset the kinematic targets.
+ pass
[文档]defwrite_nodal_state_to_sim(self,nodal_state:torch.Tensor,env_ids:Sequence[int]|None=None):
+"""Set the nodal state over selected environment indices into the simulation.
+
+ The nodal state comprises of the nodal positions and velocities. Since these are nodes, the velocity only has
+ a translational component. All the quantities are in the simulation frame.
+
+ Args:
+ nodal_state: Nodal state in simulation frame.
+ Shape is (len(env_ids), max_sim_vertices_per_body, 6).
+ env_ids: Environment indices. If None, then all indices are used.
+ """
+ # set into simulation
+ self.write_nodal_pos_to_sim(nodal_state[...,:3],env_ids=env_ids)
+ self.write_nodal_velocity_to_sim(nodal_state[...,3:],env_ids=env_ids)
+
+
[文档]defwrite_nodal_pos_to_sim(self,nodal_pos:torch.Tensor,env_ids:Sequence[int]|None=None):
+"""Set the nodal positions over selected environment indices into the simulation.
+
+ The nodal position comprises of individual nodal positions of the simulation mesh for the deformable body.
+ The positions are in the simulation frame.
+
+ Args:
+ nodal_pos: Nodal positions in simulation frame.
+ Shape is (len(env_ids), max_sim_vertices_per_body, 3).
+ env_ids: Environment indices. If None, then all indices are used.
+ """
+ # resolve all indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ # note: we need to do this here since tensors are not set into simulation until step.
+ # set into internal buffers
+ self._data.nodal_pos_w[env_ids]=nodal_pos.clone()
+ # set into simulation
+ self.root_physx_view.set_sim_nodal_positions(self._data.nodal_pos_w,indices=physx_env_ids)
+
+
[文档]defwrite_nodal_velocity_to_sim(self,nodal_vel:torch.Tensor,env_ids:Sequence[int]|None=None):
+"""Set the nodal velocity over selected environment indices into the simulation.
+
+ The nodal velocity comprises of individual nodal velocities of the simulation mesh for the deformable
+ body. Since these are nodes, the velocity only has a translational component. The velocities are in the
+ simulation frame.
+
+ Args:
+ nodal_vel: Nodal velocities in simulation frame.
+ Shape is (len(env_ids), max_sim_vertices_per_body, 3).
+ env_ids: Environment indices. If None, then all indices are used.
+ """
+ # resolve all indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ # note: we need to do this here since tensors are not set into simulation until step.
+ # set into internal buffers
+ self._data.nodal_vel_w[env_ids]=nodal_vel.clone()
+ # set into simulation
+ self.root_physx_view.set_sim_nodal_velocities(self._data.nodal_vel_w,indices=physx_env_ids)
+
+
[文档]defwrite_nodal_kinematic_target_to_sim(self,targets:torch.Tensor,env_ids:Sequence[int]|None=None):
+"""Set the kinematic targets of the simulation mesh for the deformable bodies indicated by the indices.
+
+ The kinematic targets comprise of individual nodal positions of the simulation mesh for the deformable body
+ and a flag indicating whether the node is kinematically driven or not. The positions are in the simulation frame.
+
+ Note:
+ The flag is set to 0.0 for kinematically driven nodes and 1.0 for free nodes.
+
+ Args:
+ targets: The kinematic targets comprising of nodal positions and flags.
+ Shape is (len(env_ids), max_sim_vertices_per_body, 4).
+ env_ids: Environment indices. If None, then all indices are used.
+ """
+ # resolve all indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ # store into internal buffers
+ self._data.nodal_kinematic_target[env_ids]=targets.clone()
+ # set into simulation
+ self.root_physx_view.set_sim_kinematic_targets(self._data.nodal_kinematic_target,indices=physx_env_ids)
+
+"""
+ Operations - Helper.
+ """
+
+
[文档]deftransform_nodal_pos(
+ self,nodal_pos:torch.tensor,pos:torch.Tensor|None=None,quat:torch.Tensor|None=None
+ )->torch.Tensor:
+"""Transform the nodal positions based on the pose transformation.
+
+ This function computes the transformation of the nodal positions based on the pose transformation.
+ It multiplies the nodal positions with the rotation matrix of the pose and adds the translation.
+ Internally, it calls the :meth:`omni.isaac.lab.utils.math.transform_points` function.
+
+ Args:
+ nodal_pos: The nodal positions in the simulation frame. Shape is (N, max_sim_vertices_per_body, 3).
+ pos: The position transformation. Shape is (N, 3).
+ Defaults to None, in which case the position is assumed to be zero.
+ quat: The orientation transformation as quaternion (w, x, y, z). Shape is (N, 4).
+ Defaults to None, in which case the orientation is assumed to be identity.
+
+ Returns:
+ The transformed nodal positions. Shape is (N, max_sim_vertices_per_body, 3).
+ """
+ # offset the nodal positions to center them around the origin
+ mean_nodal_pos=nodal_pos.mean(dim=1,keepdim=True)
+ nodal_pos=nodal_pos-mean_nodal_pos
+ # transform the nodal positions based on the pose around the origin
+ returnmath_utils.transform_points(nodal_pos,pos,quat)+mean_nodal_pos
+
+"""
+ Internal helper.
+ """
+
+ def_initialize_impl(self):
+ # create simulation view
+ self._physics_sim_view=physx.create_simulation_view(self._backend)
+ self._physics_sim_view.set_subspace_roots("/")
+ # obtain the first prim in the regex expression (all others are assumed to be a copy of this)
+ template_prim=sim_utils.find_first_matching_prim(self.cfg.prim_path)
+ iftemplate_primisNone:
+ raiseRuntimeError(f"Failed to find prim for expression: '{self.cfg.prim_path}'.")
+ template_prim_path=template_prim.GetPath().pathString
+
+ # find deformable root prims
+ root_prims=sim_utils.get_all_matching_child_prims(
+ template_prim_path,predicate=lambdaprim:prim.HasAPI(PhysxSchema.PhysxDeformableBodyAPI)
+ )
+ iflen(root_prims)==0:
+ raiseRuntimeError(
+ f"Failed to find a deformable body when resolving '{self.cfg.prim_path}'."
+ " Please ensure that the prim has 'PhysxSchema.PhysxDeformableBodyAPI' applied."
+ )
+ iflen(root_prims)>1:
+ raiseRuntimeError(
+ f"Failed to find a single deformable body when resolving '{self.cfg.prim_path}'."
+ f" Found multiple '{root_prims}' under '{template_prim_path}'."
+ " Please ensure that there is only one deformable body in the prim path tree."
+ )
+ # we only need the first one from the list
+ root_prim=root_prims[0]
+
+ # find deformable material prims
+ material_prim=None
+ # obtain material prim from the root prim
+ # note: here we assume that all the root prims have their material prims at similar paths
+ # and we only need to find the first one. This may not be the case for all scenarios.
+ # However, the checks in that case get cumbersome and are not included here.
+ ifroot_prim.HasAPI(UsdShade.MaterialBindingAPI):
+ # check the materials that are bound with the purpose 'physics'
+ material_paths=UsdShade.MaterialBindingAPI(root_prim).GetDirectBindingRel("physics").GetTargets()
+ # iterate through targets and find the deformable body material
+ iflen(material_paths)>0:
+ format_pathinmaterial_paths:
+ mat_prim=root_prim.GetStage().GetPrimAtPath(mat_path)
+ ifmat_prim.HasAPI(PhysxSchema.PhysxDeformableBodyMaterialAPI):
+ material_prim=mat_prim
+ break
+ ifmaterial_primisNone:
+ carb.log_info(
+ f"Failed to find a deformable material binding for '{root_prim.GetPath().pathString}'."
+ " The material properties will be set to default values and are not modifiable at runtime."
+ " If you want to modify the material properties, please ensure that the material is bound"
+ " to the deformable body."
+ )
+
+ # resolve root path back into regex expression
+ # -- root prim expression
+ root_prim_path=root_prim.GetPath().pathString
+ root_prim_path_expr=self.cfg.prim_path+root_prim_path[len(template_prim_path):]
+ # -- object view
+ self._root_physx_view=self._physics_sim_view.create_soft_body_view(root_prim_path_expr.replace(".*","*"))
+
+ # Return if the asset is not found
+ ifself._root_physx_view._backendisNone:
+ raiseRuntimeError(f"Failed to create deformable body at: {self.cfg.prim_path}. Please check PhysX logs.")
+
+ # resolve material path back into regex expression
+ ifmaterial_primisnotNone:
+ # -- material prim expression
+ material_prim_path=material_prim.GetPath().pathString
+ # check if the material prim is under the template prim
+ # if not then we are assuming that the single material prim is used for all the deformable bodies
+ iftemplate_prim_pathinmaterial_prim_path:
+ material_prim_path_expr=self.cfg.prim_path+material_prim_path[len(template_prim_path):]
+ else:
+ material_prim_path_expr=material_prim_path
+ # -- material view
+ self._material_physx_view=self._physics_sim_view.create_soft_body_material_view(
+ material_prim_path_expr.replace(".*","*")
+ )
+ else:
+ self._material_physx_view=None
+
+ # log information about the deformable body
+ carb.log_info(f"Deformable body initialized at: {root_prim_path_expr}")
+ carb.log_info(f"Number of instances: {self.num_instances}")
+ carb.log_info(f"Number of bodies: {self.num_bodies}")
+ ifself._material_physx_viewisnotNone:
+ carb.log_info(f"Deformable material initialized at: {material_prim_path_expr}")
+ carb.log_info(f"Number of instances: {self._material_physx_view.count}")
+ else:
+ carb.log_info("No deformable material found. Material properties will be set to default values.")
+
+ # container for data access
+ self._data=DeformableObjectData(self.root_physx_view,self.device)
+
+ # create buffers
+ self._create_buffers()
+ # update the deformable body data
+ self.update(0.0)
+
+ def_create_buffers(self):
+"""Create buffers for storing data."""
+ # constants
+ self._ALL_INDICES=torch.arange(self.num_instances,dtype=torch.long,device=self.device)
+
+ # default state
+ # we use the initial nodal positions at spawn time as the default state
+ # note: these are all in the simulation frame
+ nodal_positions=self.root_physx_view.get_sim_nodal_positions()
+ nodal_velocities=torch.zeros_like(nodal_positions)
+ self._data.default_nodal_state_w=torch.cat((nodal_positions,nodal_velocities),dim=-1)
+
+ # kinematic targets
+ self._data.nodal_kinematic_target=self.root_physx_view.get_sim_kinematic_targets()
+ # set all nodes as non-kinematic targets by default
+ self._data.nodal_kinematic_target[...,-1]=1.0
+
+"""
+ Internal simulation callbacks.
+ """
+
+ def_set_debug_vis_impl(self,debug_vis:bool):
+ # set visibility of markers
+ # note: parent only deals with callbacks. not their visibility
+ ifdebug_vis:
+ ifnothasattr(self,"target_visualizer"):
+ self.target_visualizer=VisualizationMarkers(self.cfg.visualizer_cfg)
+ # set their visibility to true
+ self.target_visualizer.set_visibility(True)
+ else:
+ ifhasattr(self,"target_visualizer"):
+ self.target_visualizer.set_visibility(False)
+
+ def_debug_vis_callback(self,event):
+ # check where to visualize
+ targets_enabled=self.data.nodal_kinematic_target[:,:,3]==0.0
+ num_enabled=int(torch.sum(targets_enabled).item())
+ # get positions if any targets are enabled
+ ifnum_enabled==0:
+ # create a marker below the ground
+ positions=torch.tensor([[0.0,0.0,-10.0]],device=self.device)
+ else:
+ positions=self.data.nodal_kinematic_target[targets_enabled][...,:3]
+ # show target visualizer
+ self.target_visualizer.visualize(positions)
+
+ def_invalidate_initialize_callback(self,event):
+"""Invalidates the scene elements."""
+ # call parent
+ super()._invalidate_initialize_callback(event)
+ # set all existing views to None to invalidate them
+ self._physics_sim_view=None
+ self._root_physx_view=None
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromomni.isaac.lab.markersimportVisualizationMarkersCfg
+fromomni.isaac.lab.markers.configimportDEFORMABLE_TARGET_MARKER_CFG
+fromomni.isaac.lab.utilsimportconfigclass
+
+from..asset_base_cfgimportAssetBaseCfg
+from.deformable_objectimportDeformableObject
+
+
+
[文档]@configclass
+classDeformableObjectCfg(AssetBaseCfg):
+"""Configuration parameters for a deformable object."""
+
+ class_type:type=DeformableObject
+
+ visualizer_cfg:VisualizationMarkersCfg=DEFORMABLE_TARGET_MARKER_CFG.replace(
+ prim_path="/Visuals/DeformableTarget"
+ )
+"""The configuration object for the visualization markers. Defaults to DEFORMABLE_TARGET_MARKER_CFG.
+
+ Note:
+ This attribute is only used when debug visualization is enabled.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+importweakref
+
+importomni.physics.tensors.impl.apiasphysx
+
+importomni.isaac.lab.utils.mathasmath_utils
+fromomni.isaac.lab.utils.buffersimportTimestampedBuffer
+
+
+
[文档]classDeformableObjectData:
+"""Data container for a deformable object.
+
+ This class contains the data for a deformable object in the simulation. The data includes the nodal states of
+ the root deformable body in the object. The data is stored in the simulation world frame unless otherwise specified.
+
+ A deformable object in PhysX uses two tetrahedral meshes to represent the object:
+
+ 1. **Simulation mesh**: This mesh is used for the simulation and is the one that is deformed by the solver.
+ 2. **Collision mesh**: This mesh only needs to match the surface of the simulation mesh and is used for
+ collision detection.
+
+ The APIs exposed provides the data for both the simulation and collision meshes. These are specified
+ by the `sim` and `collision` prefixes in the property names.
+
+ The data is lazily updated, meaning that the data is only updated when it is accessed. This is useful
+ when the data is expensive to compute or retrieve. The data is updated when the timestamp of the buffer
+ is older than the current simulation timestamp. The timestamp is updated whenever the data is updated.
+ """
+
+ def__init__(self,root_physx_view:physx.SoftBodyView,device:str):
+"""Initializes the deformable object data.
+
+ Args:
+ root_physx_view: The root deformable body view of the object.
+ device: The device used for processing.
+ """
+ # Set the parameters
+ self.device=device
+ # Set the root deformable body view
+ # note: this is stored as a weak reference to avoid circular references between the asset class
+ # and the data container. This is important to avoid memory leaks.
+ self._root_physx_view:physx.SoftBodyView=weakref.proxy(root_physx_view)
+
+ # Set initial time stamp
+ self._sim_timestamp=0.0
+
+ # Initialize the lazy buffers.
+ # -- node state in simulation world frame
+ self._nodal_pos_w=TimestampedBuffer()
+ self._nodal_vel_w=TimestampedBuffer()
+ self._nodal_state_w=TimestampedBuffer()
+ # -- mesh element-wise rotations
+ self._sim_element_quat_w=TimestampedBuffer()
+ self._collision_element_quat_w=TimestampedBuffer()
+ # -- mesh element-wise deformation gradients
+ self._sim_element_deform_gradient_w=TimestampedBuffer()
+ self._collision_element_deform_gradient_w=TimestampedBuffer()
+ # -- mesh element-wise stresses
+ self._sim_element_stress_w=TimestampedBuffer()
+ self._collision_element_stress_w=TimestampedBuffer()
+
+
[文档]defupdate(self,dt:float):
+"""Updates the data for the deformable object.
+
+ Args:
+ dt: The time step for the update. This must be a positive value.
+ """
+ # update the simulation timestamp
+ self._sim_timestamp+=dt
+
+ ##
+ # Defaults.
+ ##
+
+ default_nodal_state_w:torch.Tensor=None
+"""Default nodal state ``[nodal_pos, nodal_vel]`` in simulation world frame.
+ Shape is (num_instances, max_sim_vertices_per_body, 6).
+ """
+
+ ##
+ # Kinematic commands
+ ##
+
+ nodal_kinematic_target:torch.Tensor=None
+"""Simulation mesh kinematic targets for the deformable bodies.
+ Shape is (num_instances, max_sim_vertices_per_body, 4).
+
+ The kinematic targets are used to drive the simulation mesh vertices to the target positions.
+ The targets are stored as (x, y, z, is_not_kinematic) where "is_not_kinematic" is a binary
+ flag indicating whether the vertex is kinematic or not. The flag is set to 0 for kinematic vertices
+ and 1 for non-kinematic vertices.
+ """
+
+ ##
+ # Properties.
+ ##
+
+ @property
+ defnodal_pos_w(self):
+"""Nodal positions in simulation world frame. Shape is (num_instances, max_sim_vertices_per_body, 3)."""
+ ifself._nodal_pos_w.timestamp<self._sim_timestamp:
+ self._nodal_pos_w.data=self._root_physx_view.get_sim_nodal_positions()
+ self._nodal_pos_w.timestamp=self._sim_timestamp
+ returnself._nodal_pos_w.data
+
+ @property
+ defnodal_vel_w(self):
+"""Nodal velocities in simulation world frame. Shape is (num_instances, max_sim_vertices_per_body, 3)."""
+ ifself._nodal_vel_w.timestamp<self._sim_timestamp:
+ self._nodal_vel_w.data=self._root_physx_view.get_sim_nodal_velocities()
+ self._nodal_vel_w.timestamp=self._sim_timestamp
+ returnself._nodal_vel_w.data
+
+ @property
+ defnodal_state_w(self):
+"""Nodal state ``[nodal_pos, nodal_vel]`` in simulation world frame.
+ Shape is (num_instances, max_sim_vertices_per_body, 6).
+ """
+ ifself._nodal_state_w.timestamp<self._sim_timestamp:
+ nodal_positions=self.nodal_pos_w
+ nodal_velocities=self.nodal_vel_w
+ # set the buffer data and timestamp
+ self._nodal_state_w.data=torch.cat((nodal_positions,nodal_velocities),dim=-1)
+ self._nodal_state_w.timestamp=self._sim_timestamp
+ returnself._nodal_state_w.data
+
+ @property
+ defsim_element_quat_w(self):
+"""Simulation mesh element-wise rotations as quaternions for the deformable bodies in simulation world frame.
+ Shape is (num_instances, max_sim_elements_per_body, 4).
+
+ The rotations are stored as quaternions in the order (w, x, y, z).
+ """
+ ifself._sim_element_quat_w.timestamp<self._sim_timestamp:
+ # convert from xyzw to wxyz
+ quats=self._root_physx_view.get_sim_element_rotations().view(self._root_physx_view.count,-1,4)
+ quats=math_utils.convert_quat(quats,to="wxyz")
+ # set the buffer data and timestamp
+ self._sim_element_quat_w.data=quats
+ self._sim_element_quat_w.timestamp=self._sim_timestamp
+ returnself._sim_element_quat_w.data
+
+ @property
+ defcollision_element_quat_w(self):
+"""Collision mesh element-wise rotations as quaternions for the deformable bodies in simulation world frame.
+ Shape is (num_instances, max_collision_elements_per_body, 4).
+
+ The rotations are stored as quaternions in the order (w, x, y, z).
+ """
+ ifself._collision_element_quat_w.timestamp<self._sim_timestamp:
+ # convert from xyzw to wxyz
+ quats=self._root_physx_view.get_element_rotations().view(self._root_physx_view.count,-1,4)
+ quats=math_utils.convert_quat(quats,to="wxyz")
+ # set the buffer data and timestamp
+ self._collision_element_quat_w.data=quats
+ self._collision_element_quat_w.timestamp=self._sim_timestamp
+ returnself._collision_element_quat_w.data
+
+ @property
+ defsim_element_deform_gradient_w(self):
+"""Simulation mesh element-wise second-order deformation gradient tensors for the deformable bodies
+ in simulation world frame. Shape is (num_instances, max_sim_elements_per_body, 3, 3).
+ """
+ ifself._sim_element_deform_gradient_w.timestamp<self._sim_timestamp:
+ # set the buffer data and timestamp
+ self._sim_element_deform_gradient_w.data=(
+ self._root_physx_view.get_sim_element_deformation_gradients().view(
+ self._root_physx_view.count,-1,3,3
+ )
+ )
+ self._sim_element_deform_gradient_w.timestamp=self._sim_timestamp
+ returnself._sim_element_deform_gradient_w.data
+
+ @property
+ defcollision_element_deform_gradient_w(self):
+"""Collision mesh element-wise second-order deformation gradient tensors for the deformable bodies
+ in simulation world frame. Shape is (num_instances, max_collision_elements_per_body, 3, 3).
+ """
+ ifself._collision_element_deform_gradient_w.timestamp<self._sim_timestamp:
+ # set the buffer data and timestamp
+ self._collision_element_deform_gradient_w.data=(
+ self._root_physx_view.get_element_deformation_gradients().view(self._root_physx_view.count,-1,3,3)
+ )
+ self._collision_element_deform_gradient_w.timestamp=self._sim_timestamp
+ returnself._collision_element_deform_gradient_w.data
+
+ @property
+ defsim_element_stress_w(self):
+"""Simulation mesh element-wise second-order Cauchy stress tensors for the deformable bodies
+ in simulation world frame. Shape is (num_instances, max_sim_elements_per_body, 3, 3).
+ """
+ ifself._sim_element_stress_w.timestamp<self._sim_timestamp:
+ # set the buffer data and timestamp
+ self._sim_element_stress_w.data=self._root_physx_view.get_sim_element_stresses().view(
+ self._root_physx_view.count,-1,3,3
+ )
+ self._sim_element_stress_w.timestamp=self._sim_timestamp
+ returnself._sim_element_stress_w.data
+
+ @property
+ defcollision_element_stress_w(self):
+"""Collision mesh element-wise second-order Cauchy stress tensors for the deformable bodies
+ in simulation world frame. Shape is (num_instances, max_collision_elements_per_body, 3, 3).
+ """
+ ifself._collision_element_stress_w.timestamp<self._sim_timestamp:
+ # set the buffer data and timestamp
+ self._collision_element_stress_w.data=self._root_physx_view.get_element_stresses().view(
+ self._root_physx_view.count,-1,3,3
+ )
+ self._collision_element_stress_w.timestamp=self._sim_timestamp
+ returnself._collision_element_stress_w.data
+
+ ##
+ # Derived properties.
+ ##
+
+ @property
+ defroot_pos_w(self)->torch.Tensor:
+"""Root position from nodal positions of the simulation mesh for the deformable bodies in simulation world frame.
+ Shape is (num_instances, 3).
+
+ This quantity is computed as the mean of the nodal positions.
+ """
+ returnself.nodal_pos_w.mean(dim=1)
+
+ @property
+ defroot_vel_w(self)->torch.Tensor:
+"""Root velocity from vertex velocities for the deformable bodies in simulation world frame.
+ Shape is (num_instances, 3).
+
+ This quantity is computed as the mean of the nodal velocities.
+ """
+ returnself.nodal_vel_w.mean(dim=1)
[文档]classRigidObject(AssetBase):
+"""A rigid object asset class.
+
+ Rigid objects are assets comprising of rigid bodies. They can be used to represent dynamic objects
+ such as boxes, spheres, etc. A rigid body is described by its pose, velocity and mass distribution.
+
+ For an asset to be considered a rigid object, the root prim of the asset must have the `USD RigidBodyAPI`_
+ applied to it. This API is used to define the simulation properties of the rigid body. On playing the
+ simulation, the physics engine will automatically register the rigid body and create a corresponding
+ rigid body handle. This handle can be accessed using the :attr:`root_physx_view` attribute.
+
+ .. note::
+
+ For users familiar with Isaac Sim, the PhysX view class API is not the exactly same as Isaac Sim view
+ class API. Similar to Isaac Lab, Isaac Sim wraps around the PhysX view API. However, as of now (2023.1 release),
+ we see a large difference in initializing the view classes in Isaac Sim. This is because the view classes
+ in Isaac Sim perform additional USD-related operations which are slow and also not required.
+
+ .. _`USD RigidBodyAPI`: https://openusd.org/dev/api/class_usd_physics_rigid_body_a_p_i.html
+ """
+
+ cfg:RigidObjectCfg
+"""Configuration instance for the rigid object."""
+
+
[文档]def__init__(self,cfg:RigidObjectCfg):
+"""Initialize the rigid object.
+
+ Args:
+ cfg: A configuration instance.
+ """
+ super().__init__(cfg)
+
+"""
+ Properties
+ """
+
+ @property
+ defdata(self)->RigidObjectData:
+ returnself._data
+
+ @property
+ defnum_instances(self)->int:
+ returnself.root_physx_view.count
+
+ @property
+ defnum_bodies(self)->int:
+"""Number of bodies in the asset.
+
+ This is always 1 since each object is a single rigid body.
+ """
+ return1
+
+ @property
+ defbody_names(self)->list[str]:
+"""Ordered names of bodies in the rigid object."""
+ prim_paths=self.root_physx_view.prim_paths[:self.num_bodies]
+ return[path.split("/")[-1]forpathinprim_paths]
+
+ @property
+ defroot_physx_view(self)->physx.RigidBodyView:
+"""Rigid body view for the asset (PhysX).
+
+ Note:
+ Use this view with caution. It requires handling of tensors in a specific way.
+ """
+ returnself._root_physx_view
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+ # resolve all indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # reset external wrench
+ self._external_force_b[env_ids]=0.0
+ self._external_torque_b[env_ids]=0.0
+
+
[文档]defwrite_data_to_sim(self):
+"""Write external wrench to the simulation.
+
+ Note:
+ We write external wrench to the simulation here since this function is called before the simulation step.
+ This ensures that the external wrench is applied at every simulation step.
+ """
+ # write external wrench
+ ifself.has_external_wrench:
+ self.root_physx_view.apply_forces_and_torques_at_position(
+ force_data=self._external_force_b.view(-1,3),
+ torque_data=self._external_torque_b.view(-1,3),
+ position_data=None,
+ indices=self._ALL_INDICES,
+ is_global=False,
+ )
[文档]deffind_bodies(self,name_keys:str|Sequence[str],preserve_order:bool=False)->tuple[list[int],list[str]]:
+"""Find bodies in the rigid body based on the name keys.
+
+ Please check the :meth:`omni.isaac.lab.utils.string_utils.resolve_matching_names` function for more
+ information on the name matching.
+
+ Args:
+ name_keys: A regular expression or a list of regular expressions to match the body names.
+ preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False.
+
+ Returns:
+ A tuple of lists containing the body indices and names.
+ """
+ returnstring_utils.resolve_matching_names(name_keys,self.body_names,preserve_order)
[文档]defwrite_root_state_to_sim(self,root_state:torch.Tensor,env_ids:Sequence[int]|None=None):
+"""Set the root state over selected environment indices into the simulation.
+
+ The root state comprises of the cartesian position, quaternion orientation in (w, x, y, z), and linear
+ and angular velocity. All the quantities are in the simulation frame.
+
+ Args:
+ root_state: Root state in simulation frame. Shape is (len(env_ids), 13).
+ env_ids: Environment indices. If None, then all indices are used.
+ """
+ # set into simulation
+ self.write_root_pose_to_sim(root_state[:,:7],env_ids=env_ids)
+ self.write_root_velocity_to_sim(root_state[:,7:],env_ids=env_ids)
+
+
[文档]defwrite_root_pose_to_sim(self,root_pose:torch.Tensor,env_ids:Sequence[int]|None=None):
+"""Set the root pose over selected environment indices into the simulation.
+
+ The root pose comprises of the cartesian position and quaternion orientation in (w, x, y, z).
+
+ Args:
+ root_pose: Root poses in simulation frame. Shape is (len(env_ids), 7).
+ env_ids: Environment indices. If None, then all indices are used.
+ """
+ # resolve all indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ # note: we need to do this here since tensors are not set into simulation until step.
+ # set into internal buffers
+ self._data.root_state_w[env_ids,:7]=root_pose.clone()
+ # convert root quaternion from wxyz to xyzw
+ root_poses_xyzw=self._data.root_state_w[:,:7].clone()
+ root_poses_xyzw[:,3:]=math_utils.convert_quat(root_poses_xyzw[:,3:],to="xyzw")
+ # set into simulation
+ self.root_physx_view.set_transforms(root_poses_xyzw,indices=physx_env_ids)
+
+
[文档]defwrite_root_velocity_to_sim(self,root_velocity:torch.Tensor,env_ids:Sequence[int]|None=None):
+"""Set the root velocity over selected environment indices into the simulation.
+
+ Args:
+ root_velocity: Root velocities in simulation frame. Shape is (len(env_ids), 6).
+ env_ids: Environment indices. If None, then all indices are used.
+ """
+ # resolve all indices
+ physx_env_ids=env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ physx_env_ids=self._ALL_INDICES
+ # note: we need to do this here since tensors are not set into simulation until step.
+ # set into internal buffers
+ self._data.root_state_w[env_ids,7:]=root_velocity.clone()
+ self._data.body_acc_w[env_ids]=0.0
+ # set into simulation
+ self.root_physx_view.set_velocities(self._data.root_state_w[:,7:],indices=physx_env_ids)
+
+"""
+ Operations - Setters.
+ """
+
+
[文档]defset_external_force_and_torque(
+ self,
+ forces:torch.Tensor,
+ torques:torch.Tensor,
+ body_ids:Sequence[int]|slice|None=None,
+ env_ids:Sequence[int]|None=None,
+ ):
+"""Set external force and torque to apply on the asset's bodies in their local frame.
+
+ For many applications, we want to keep the applied external force on rigid bodies constant over a period of
+ time (for instance, during the policy control). This function allows us to store the external force and torque
+ into buffers which are then applied to the simulation at every step.
+
+ .. caution::
+ If the function is called with empty forces and torques, then this function disables the application
+ of external wrench to the simulation.
+
+ .. code-block:: python
+
+ # example of disabling external wrench
+ asset.set_external_force_and_torque(forces=torch.zeros(0, 3), torques=torch.zeros(0, 3))
+
+ .. note::
+ This function does not apply the external wrench to the simulation. It only fills the buffers with
+ the desired values. To apply the external wrench, call the :meth:`write_data_to_sim` function
+ right before the simulation step.
+
+ Args:
+ forces: External forces in bodies' local frame. Shape is (len(env_ids), len(body_ids), 3).
+ torques: External torques in bodies' local frame. Shape is (len(env_ids), len(body_ids), 3).
+ body_ids: Body indices to apply external wrench to. Defaults to None (all bodies).
+ env_ids: Environment indices to apply external wrench to. Defaults to None (all instances).
+ """
+ ifforces.any()ortorques.any():
+ self.has_external_wrench=True
+ # resolve all indices
+ # -- env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # -- body_ids
+ ifbody_idsisNone:
+ body_ids=slice(None)
+ # broadcast env_ids if needed to allow double indexing
+ ifenv_ids!=slice(None)andbody_ids!=slice(None):
+ env_ids=env_ids[:,None]
+
+ # set into internal buffers
+ self._external_force_b[env_ids,body_ids]=forces
+ self._external_torque_b[env_ids,body_ids]=torques
+ else:
+ self.has_external_wrench=False
+
+"""
+ Internal helper.
+ """
+
+ def_initialize_impl(self):
+ # create simulation view
+ self._physics_sim_view=physx.create_simulation_view(self._backend)
+ self._physics_sim_view.set_subspace_roots("/")
+ # obtain the first prim in the regex expression (all others are assumed to be a copy of this)
+ template_prim=sim_utils.find_first_matching_prim(self.cfg.prim_path)
+ iftemplate_primisNone:
+ raiseRuntimeError(f"Failed to find prim for expression: '{self.cfg.prim_path}'.")
+ template_prim_path=template_prim.GetPath().pathString
+
+ # find rigid root prims
+ root_prims=sim_utils.get_all_matching_child_prims(
+ template_prim_path,predicate=lambdaprim:prim.HasAPI(UsdPhysics.RigidBodyAPI)
+ )
+ iflen(root_prims)==0:
+ raiseRuntimeError(
+ f"Failed to find a rigid body when resolving '{self.cfg.prim_path}'."
+ " Please ensure that the prim has 'USD RigidBodyAPI' applied."
+ )
+ iflen(root_prims)>1:
+ raiseRuntimeError(
+ f"Failed to find a single rigid body when resolving '{self.cfg.prim_path}'."
+ f" Found multiple '{root_prims}' under '{template_prim_path}'."
+ " Please ensure that there is only one rigid body in the prim path tree."
+ )
+
+ # resolve root prim back into regex expression
+ root_prim_path=root_prims[0].GetPath().pathString
+ root_prim_path_expr=self.cfg.prim_path+root_prim_path[len(template_prim_path):]
+ # -- object view
+ self._root_physx_view=self._physics_sim_view.create_rigid_body_view(root_prim_path_expr.replace(".*","*"))
+
+ # check if the rigid body was created
+ ifself._root_physx_view._backendisNone:
+ raiseRuntimeError(f"Failed to create rigid body at: {self.cfg.prim_path}. Please check PhysX logs.")
+
+ # log information about the rigid body
+ carb.log_info(f"Rigid body initialized at: {self.cfg.prim_path} with root '{root_prim_path_expr}'.")
+ carb.log_info(f"Number of instances: {self.num_instances}")
+ carb.log_info(f"Number of bodies: {self.num_bodies}")
+ carb.log_info(f"Body names: {self.body_names}")
+
+ # container for data access
+ self._data=RigidObjectData(self.root_physx_view,self.device)
+
+ # create buffers
+ self._create_buffers()
+ # process configuration
+ self._process_cfg()
+ # update the rigid body data
+ self.update(0.0)
+
+ def_create_buffers(self):
+"""Create buffers for storing data."""
+ # constants
+ self._ALL_INDICES=torch.arange(self.num_instances,dtype=torch.long,device=self.device)
+
+ # external forces and torques
+ self.has_external_wrench=False
+ self._external_force_b=torch.zeros((self.num_instances,self.num_bodies,3),device=self.device)
+ self._external_torque_b=torch.zeros_like(self._external_force_b)
+
+ # set information about rigid body into data
+ self._data.body_names=self.body_names
+ self._data.default_mass=self.root_physx_view.get_masses().clone()
+ self._data.default_inertia=self.root_physx_view.get_inertias().clone()
+
+ def_process_cfg(self):
+"""Post processing of configuration parameters."""
+ # default state
+ # -- root state
+ # note: we cast to tuple to avoid torch/numpy type mismatch.
+ default_root_state=(
+ tuple(self.cfg.init_state.pos)
+ +tuple(self.cfg.init_state.rot)
+ +tuple(self.cfg.init_state.lin_vel)
+ +tuple(self.cfg.init_state.ang_vel)
+ )
+ default_root_state=torch.tensor(default_root_state,dtype=torch.float,device=self.device)
+ self._data.default_root_state=default_root_state.repeat(self.num_instances,1)
+
+"""
+ Internal simulation callbacks.
+ """
+
+ def_invalidate_initialize_callback(self,event):
+"""Invalidates the scene elements."""
+ # call parent
+ super()._invalidate_initialize_callback(event)
+ # set all existing views to None to invalidate them
+ self._physics_sim_view=None
+ self._root_physx_view=None
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from..asset_base_cfgimportAssetBaseCfg
+from.rigid_objectimportRigidObject
+
+
+
[文档]@configclass
+classRigidObjectCfg(AssetBaseCfg):
+"""Configuration parameters for a rigid object."""
+
+
[文档]@configclass
+ classInitialStateCfg(AssetBaseCfg.InitialStateCfg):
+"""Initial state of the rigid body."""
+
+ lin_vel:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Linear velocity of the root in simulation world frame. Defaults to (0.0, 0.0, 0.0)."""
+ ang_vel:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Angular velocity of the root in simulation world frame. Defaults to (0.0, 0.0, 0.0)."""
+
+ ##
+ # Initialize configurations.
+ ##
+
+ class_type:type=RigidObject
+
+ init_state:InitialStateCfg=InitialStateCfg()
+"""Initial state of the rigid object. Defaults to identity pose with zero velocity."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+importweakref
+
+importomni.physics.tensors.impl.apiasphysx
+
+importomni.isaac.lab.utils.mathasmath_utils
+fromomni.isaac.lab.utils.buffersimportTimestampedBuffer
+
+
+
[文档]classRigidObjectData:
+"""Data container for a rigid object.
+
+ This class contains the data for a rigid object in the simulation. The data includes the state of
+ the root rigid body and the state of all the bodies in the object. The data is stored in the simulation
+ world frame unless otherwise specified.
+
+ For a rigid body, there are two frames of reference that are used:
+
+ - Actor frame: The frame of reference of the rigid body prim. This typically corresponds to the Xform prim
+ with the rigid body schema.
+ - Center of mass frame: The frame of reference of the center of mass of the rigid body.
+
+ Depending on the settings of the simulation, the actor frame and the center of mass frame may be the same.
+ This needs to be taken into account when interpreting the data.
+
+ The data is lazily updated, meaning that the data is only updated when it is accessed. This is useful
+ when the data is expensive to compute or retrieve. The data is updated when the timestamp of the buffer
+ is older than the current simulation timestamp. The timestamp is updated whenever the data is updated.
+ """
+
+ def__init__(self,root_physx_view:physx.RigidBodyView,device:str):
+"""Initializes the rigid object data.
+
+ Args:
+ root_physx_view: The root rigid body view.
+ device: The device used for processing.
+ """
+ # Set the parameters
+ self.device=device
+ # Set the root rigid body view
+ # note: this is stored as a weak reference to avoid circular references between the asset class
+ # and the data container. This is important to avoid memory leaks.
+ self._root_physx_view:physx.RigidBodyView=weakref.proxy(root_physx_view)
+
+ # Set initial time stamp
+ self._sim_timestamp=0.0
+
+ # Obtain global physics sim view
+ physics_sim_view=physx.create_simulation_view("torch")
+ physics_sim_view.set_subspace_roots("/")
+ gravity=physics_sim_view.get_gravity()
+ # Convert to direction vector
+ gravity_dir=torch.tensor((gravity[0],gravity[1],gravity[2]),device=self.device)
+ gravity_dir=math_utils.normalize(gravity_dir.unsqueeze(0)).squeeze(0)
+
+ # Initialize constants
+ self.GRAVITY_VEC_W=gravity_dir.repeat(self._root_physx_view.count,1)
+ self.FORWARD_VEC_B=torch.tensor((1.0,0.0,0.0),device=self.device).repeat(self._root_physx_view.count,1)
+
+ # Initialize the lazy buffers.
+ self._root_state_w=TimestampedBuffer()
+ self._body_acc_w=TimestampedBuffer()
+
+
[文档]defupdate(self,dt:float):
+"""Updates the data for the rigid object.
+
+ Args:
+ dt: The time step for the update. This must be a positive value.
+ """
+ # update the simulation timestamp
+ self._sim_timestamp+=dt
+
+ ##
+ # Names.
+ ##
+
+ body_names:list[str]=None
+"""Body names in the order parsed by the simulation view."""
+
+ ##
+ # Defaults.
+ ##
+
+ default_root_state:torch.Tensor=None
+"""Default root state ``[pos, quat, lin_vel, ang_vel]`` in local environment frame. Shape is (num_instances, 13).
+
+ The position and quaternion are of the rigid body's actor frame. Meanwhile, the linear and angular velocities are
+ of the center of mass frame.
+ """
+
+ default_mass:torch.Tensor=None
+"""Default mass read from the simulation. Shape is (num_instances, 1)."""
+
+ default_inertia:torch.Tensor=None
+"""Default inertia tensor read from the simulation. Shape is (num_instances, 9).
+
+ The inertia is the inertia tensor relative to the center of mass frame. The values are stored in
+ the order :math:`[I_{xx}, I_{xy}, I_{xz}, I_{yx}, I_{yy}, I_{yz}, I_{zx}, I_{zy}, I_{zz}]`.
+ """
+
+ ##
+ # Properties.
+ ##
+
+ @property
+ defroot_state_w(self):
+"""Root state ``[pos, quat, lin_vel, ang_vel]`` in simulation world frame. Shape is (num_instances, 13).
+
+ The position and orientation are of the rigid body's actor frame. Meanwhile, the linear and angular
+ velocities are of the rigid body's center of mass frame.
+ """
+ ifself._root_state_w.timestamp<self._sim_timestamp:
+ # read data from simulation
+ pose=self._root_physx_view.get_transforms().clone()
+ pose[:,3:7]=math_utils.convert_quat(pose[:,3:7],to="wxyz")
+ velocity=self._root_physx_view.get_velocities()
+ # set the buffer data and timestamp
+ self._root_state_w.data=torch.cat((pose,velocity),dim=-1)
+ self._root_state_w.timestamp=self._sim_timestamp
+ returnself._root_state_w.data
+
+ @property
+ defbody_state_w(self):
+"""State of all bodies `[pos, quat, lin_vel, ang_vel]` in simulation world frame. Shape is (num_instances, 1, 13).
+
+ The position and orientation are of the rigid bodies' actor frame. Meanwhile, the linear and angular
+ velocities are of the rigid bodies' center of mass frame.
+ """
+ returnself.root_state_w.view(-1,1,13)
+
+ @property
+ defbody_acc_w(self):
+"""Acceleration of all bodies. Shape is (num_instances, 1, 6).
+
+ This quantity is the acceleration of the rigid bodies' center of mass frame.
+ """
+ ifself._body_acc_w.timestamp<self._sim_timestamp:
+ # note: we use finite differencing to compute acceleration
+ self._body_acc_w.data=self._root_physx_view.get_accelerations().unsqueeze(1)
+ self._body_acc_w.timestamp=self._sim_timestamp
+ returnself._body_acc_w.data
+
+ @property
+ defprojected_gravity_b(self):
+"""Projection of the gravity direction on base frame. Shape is (num_instances, 3)."""
+ returnmath_utils.quat_rotate_inverse(self.root_quat_w,self.GRAVITY_VEC_W)
+
+ @property
+ defheading_w(self):
+"""Yaw heading of the base frame (in radians). Shape is (num_instances,).
+
+ Note:
+ This quantity is computed by assuming that the forward-direction of the base
+ frame is along x-direction, i.e. :math:`(1, 0, 0)`.
+ """
+ forward_w=math_utils.quat_apply(self.root_quat_w,self.FORWARD_VEC_B)
+ returntorch.atan2(forward_w[:,1],forward_w[:,0])
+
+ ##
+ # Derived properties.
+ ##
+
+ @property
+ defroot_pos_w(self)->torch.Tensor:
+"""Root position in simulation world frame. Shape is (num_instances, 3).
+
+ This quantity is the position of the actor frame of the root rigid body.
+ """
+ returnself.root_state_w[:,:3]
+
+ @property
+ defroot_quat_w(self)->torch.Tensor:
+"""Root orientation (w, x, y, z) in simulation world frame. Shape is (num_instances, 4).
+
+ This quantity is the orientation of the actor frame of the root rigid body.
+ """
+ returnself.root_state_w[:,3:7]
+
+ @property
+ defroot_vel_w(self)->torch.Tensor:
+"""Root velocity in simulation world frame. Shape is (num_instances, 6).
+
+ This quantity contains the linear and angular velocities of the root rigid body's center of mass frame.
+ """
+ returnself.root_state_w[:,7:13]
+
+ @property
+ defroot_lin_vel_w(self)->torch.Tensor:
+"""Root linear velocity in simulation world frame. Shape is (num_instances, 3).
+
+ This quantity is the linear velocity of the root rigid body's center of mass frame.
+ """
+ returnself.root_state_w[:,7:10]
+
+ @property
+ defroot_ang_vel_w(self)->torch.Tensor:
+"""Root angular velocity in simulation world frame. Shape is (num_instances, 3).
+
+ This quantity is the angular velocity of the root rigid body's center of mass frame.
+ """
+ returnself.root_state_w[:,10:13]
+
+ @property
+ defroot_lin_vel_b(self)->torch.Tensor:
+"""Root linear velocity in base frame. Shape is (num_instances, 3).
+
+ This quantity is the linear velocity of the root rigid body's center of mass frame with respect to the
+ rigid body's actor frame.
+ """
+ returnmath_utils.quat_rotate_inverse(self.root_quat_w,self.root_lin_vel_w)
+
+ @property
+ defroot_ang_vel_b(self)->torch.Tensor:
+"""Root angular velocity in base world frame. Shape is (num_instances, 3).
+
+ This quantity is the angular velocity of the root rigid body's center of mass frame with respect to the
+ rigid body's actor frame.
+ """
+ returnmath_utils.quat_rotate_inverse(self.root_quat_w,self.root_ang_vel_w)
+
+ @property
+ defbody_pos_w(self)->torch.Tensor:
+"""Positions of all bodies in simulation world frame. Shape is (num_instances, 1, 3).
+
+ This quantity is the position of the rigid bodies' actor frame.
+ """
+ returnself.body_state_w[...,:3]
+
+ @property
+ defbody_quat_w(self)->torch.Tensor:
+"""Orientation (w, x, y, z) of all bodies in simulation world frame. Shape is (num_instances, 1, 4).
+
+ This quantity is the orientation of the rigid bodies' actor frame.
+ """
+ returnself.body_state_w[...,3:7]
+
+ @property
+ defbody_vel_w(self)->torch.Tensor:
+"""Velocity of all bodies in simulation world frame. Shape is (num_instances, 1, 6).
+
+ This quantity contains the linear and angular velocities of the rigid bodies' center of mass frame.
+ """
+ returnself.body_state_w[...,7:13]
+
+ @property
+ defbody_lin_vel_w(self)->torch.Tensor:
+"""Linear velocity of all bodies in simulation world frame. Shape is (num_instances, 1, 3).
+
+ This quantity is the linear velocity of the rigid bodies' center of mass frame.
+ """
+ returnself.body_state_w[...,7:10]
+
+ @property
+ defbody_ang_vel_w(self)->torch.Tensor:
+"""Angular velocity of all bodies in simulation world frame. Shape is (num_instances, 1, 3).
+
+ This quantity is the angular velocity of the rigid bodies' center of mass frame.
+ """
+ returnself.body_state_w[...,10:13]
+
+ @property
+ defbody_lin_acc_w(self)->torch.Tensor:
+"""Linear acceleration of all bodies in simulation world frame. Shape is (num_instances, 1, 3).
+
+ This quantity is the linear acceleration of the rigid bodies' center of mass frame.
+ """
+ returnself.body_acc_w[...,0:3]
+
+ @property
+ defbody_ang_acc_w(self)->torch.Tensor:
+"""Angular acceleration of all bodies in simulation world frame. Shape is (num_instances, 1, 3).
+
+ This quantity is the angular acceleration of the rigid bodies' center of mass frame.
+ """
+ returnself.body_acc_w[...,3:6]
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importtorch
+fromtypingimportTYPE_CHECKING
+
+fromomni.isaac.lab.utils.mathimportapply_delta_pose,compute_pose_error
+
+ifTYPE_CHECKING:
+ from.differential_ik_cfgimportDifferentialIKControllerCfg
+
+
+
[文档]classDifferentialIKController:
+r"""Differential inverse kinematics (IK) controller.
+
+ This controller is based on the concept of differential inverse kinematics [1, 2] which is a method for computing
+ the change in joint positions that yields the desired change in pose.
+
+ .. math::
+
+ \Delta \mathbf{q} &= \mathbf{J}^{\dagger} \Delta \mathbf{x} \\
+ \mathbf{q}_{\text{desired}} &= \mathbf{q}_{\text{current}} + \Delta \mathbf{q}
+
+ where :math:`\mathbf{J}^{\dagger}` is the pseudo-inverse of the Jacobian matrix :math:`\mathbf{J}`,
+ :math:`\Delta \mathbf{x}` is the desired change in pose, and :math:`\mathbf{q}_{\text{current}}`
+ is the current joint positions.
+
+ To deal with singularity in Jacobian, the following methods are supported for computing inverse of the Jacobian:
+
+ - "pinv": Moore-Penrose pseudo-inverse
+ - "svd": Adaptive singular-value decomposition (SVD)
+ - "trans": Transpose of matrix
+ - "dls": Damped version of Moore-Penrose pseudo-inverse (also called Levenberg-Marquardt)
+
+
+ .. caution::
+ The controller does not assume anything about the frames of the current and desired end-effector pose,
+ or the joint-space velocities. It is up to the user to ensure that these quantities are given
+ in the correct format.
+
+ Reference:
+
+ 1. `Robot Dynamics Lecture Notes <https://ethz.ch/content/dam/ethz/special-interest/mavt/robotics-n-intelligent-systems/rsl-dam/documents/RobotDynamics2017/RD_HS2017script.pdf>`_
+ by Marco Hutter (ETH Zurich)
+ 2. `Introduction to Inverse Kinematics <https://www.cs.cmu.edu/~15464-s13/lectures/lecture6/iksurvey.pdf>`_
+ by Samuel R. Buss (University of California, San Diego)
+
+ """
+
+
[文档]def__init__(self,cfg:DifferentialIKControllerCfg,num_envs:int,device:str):
+"""Initialize the controller.
+
+ Args:
+ cfg: The configuration for the controller.
+ num_envs: The number of environments.
+ device: The device to use for computations.
+ """
+ # store inputs
+ self.cfg=cfg
+ self.num_envs=num_envs
+ self._device=device
+ # create buffers
+ self.ee_pos_des=torch.zeros(self.num_envs,3,device=self._device)
+ self.ee_quat_des=torch.zeros(self.num_envs,4,device=self._device)
+ # -- input command
+ self._command=torch.zeros(self.num_envs,self.action_dim,device=self._device)
[文档]defreset(self,env_ids:torch.Tensor=None):
+"""Reset the internals.
+
+ Args:
+ env_ids: The environment indices to reset. If None, then all environments are reset.
+ """
+ pass
+
+
[文档]defset_command(
+ self,command:torch.Tensor,ee_pos:torch.Tensor|None=None,ee_quat:torch.Tensor|None=None
+ ):
+"""Set target end-effector pose command.
+
+ Based on the configured command type and relative mode, the method computes the desired end-effector pose.
+ It is up to the user to ensure that the command is given in the correct frame. The method only
+ applies the relative mode if the command type is ``position_rel`` or ``pose_rel``.
+
+ Args:
+ command: The input command in shape (N, 3) or (N, 6) or (N, 7).
+ ee_pos: The current end-effector position in shape (N, 3).
+ This is only needed if the command type is ``position_rel`` or ``pose_rel``.
+ ee_quat: The current end-effector orientation (w, x, y, z) in shape (N, 4).
+ This is only needed if the command type is ``position_*`` or ``pose_rel``.
+
+ Raises:
+ ValueError: If the command type is ``position_*`` and :attr:`ee_quat` is None.
+ ValueError: If the command type is ``position_rel`` and :attr:`ee_pos` is None.
+ ValueError: If the command type is ``pose_rel`` and either :attr:`ee_pos` or :attr:`ee_quat` is None.
+ """
+ # store command
+ self._command[:]=command
+ # compute the desired end-effector pose
+ ifself.cfg.command_type=="position":
+ # we need end-effector orientation even though we are in position mode
+ # this is only needed for display purposes
+ ifee_quatisNone:
+ raiseValueError("End-effector orientation can not be None for `position_*` command type!")
+ # compute targets
+ ifself.cfg.use_relative_mode:
+ ifee_posisNone:
+ raiseValueError("End-effector position can not be None for `position_rel` command type!")
+ self.ee_pos_des[:]=ee_pos+self._command
+ self.ee_quat_des[:]=ee_quat
+ else:
+ self.ee_pos_des[:]=self._command
+ self.ee_quat_des[:]=ee_quat
+ else:
+ # compute targets
+ ifself.cfg.use_relative_mode:
+ ifee_posisNoneoree_quatisNone:
+ raiseValueError(
+ "Neither end-effector position nor orientation can be None for `pose_rel` command type!"
+ )
+ self.ee_pos_des,self.ee_quat_des=apply_delta_pose(ee_pos,ee_quat,self._command)
+ else:
+ self.ee_pos_des=self._command[:,0:3]
+ self.ee_quat_des=self._command[:,3:7]
+
+
[文档]defcompute(
+ self,ee_pos:torch.Tensor,ee_quat:torch.Tensor,jacobian:torch.Tensor,joint_pos:torch.Tensor
+ )->torch.Tensor:
+"""Computes the target joint positions that will yield the desired end effector pose.
+
+ Args:
+ ee_pos: The current end-effector position in shape (N, 3).
+ ee_quat: The current end-effector orientation in shape (N, 4).
+ jacobian: The geometric jacobian matrix in shape (N, 6, num_joints).
+ joint_pos: The current joint positions in shape (N, num_joints).
+
+ Returns:
+ The target joint positions commands in shape (N, num_joints).
+ """
+ # compute the delta in joint-space
+ if"position"inself.cfg.command_type:
+ position_error=self.ee_pos_des-ee_pos
+ jacobian_pos=jacobian[:,0:3]
+ delta_joint_pos=self._compute_delta_joint_pos(delta_pose=position_error,jacobian=jacobian_pos)
+ else:
+ position_error,axis_angle_error=compute_pose_error(
+ ee_pos,ee_quat,self.ee_pos_des,self.ee_quat_des,rot_error_type="axis_angle"
+ )
+ pose_error=torch.cat((position_error,axis_angle_error),dim=1)
+ delta_joint_pos=self._compute_delta_joint_pos(delta_pose=pose_error,jacobian=jacobian)
+ # return the desired joint positions
+ returnjoint_pos+delta_joint_pos
+
+"""
+ Helper functions.
+ """
+
+ def_compute_delta_joint_pos(self,delta_pose:torch.Tensor,jacobian:torch.Tensor)->torch.Tensor:
+"""Computes the change in joint position that yields the desired change in pose.
+
+ The method uses the Jacobian mapping from joint-space velocities to end-effector velocities
+ to compute the delta-change in the joint-space that moves the robot closer to a desired
+ end-effector position.
+
+ Args:
+ delta_pose: The desired delta pose in shape (N, 3) or (N, 6).
+ jacobian: The geometric jacobian matrix in shape (N, 3, num_joints) or (N, 6, num_joints).
+
+ Returns:
+ The desired delta in joint space. Shape is (N, num-jointsß).
+ """
+ ifself.cfg.ik_paramsisNone:
+ raiseRuntimeError(f"Inverse-kinematics parameters for method '{self.cfg.ik_method}' is not defined!")
+ # compute the delta in joint-space
+ ifself.cfg.ik_method=="pinv":# Jacobian pseudo-inverse
+ # parameters
+ k_val=self.cfg.ik_params["k_val"]
+ # computation
+ jacobian_pinv=torch.linalg.pinv(jacobian)
+ delta_joint_pos=k_val*jacobian_pinv@delta_pose.unsqueeze(-1)
+ delta_joint_pos=delta_joint_pos.squeeze(-1)
+ elifself.cfg.ik_method=="svd":# adaptive SVD
+ # parameters
+ k_val=self.cfg.ik_params["k_val"]
+ min_singular_value=self.cfg.ik_params["min_singular_value"]
+ # computation
+ # U: 6xd, S: dxd, V: d x num-joint
+ U,S,Vh=torch.linalg.svd(jacobian)
+ S_inv=1.0/S
+ S_inv=torch.where(S>min_singular_value,S_inv,torch.zeros_like(S_inv))
+ jacobian_pinv=(
+ torch.transpose(Vh,dim0=1,dim1=2)[:,:,:6]
+ @torch.diag_embed(S_inv)
+ @torch.transpose(U,dim0=1,dim1=2)
+ )
+ delta_joint_pos=k_val*jacobian_pinv@delta_pose.unsqueeze(-1)
+ delta_joint_pos=delta_joint_pos.squeeze(-1)
+ elifself.cfg.ik_method=="trans":# Jacobian transpose
+ # parameters
+ k_val=self.cfg.ik_params["k_val"]
+ # computation
+ jacobian_T=torch.transpose(jacobian,dim0=1,dim1=2)
+ delta_joint_pos=k_val*jacobian_T@delta_pose.unsqueeze(-1)
+ delta_joint_pos=delta_joint_pos.squeeze(-1)
+ elifself.cfg.ik_method=="dls":# damped least squares
+ # parameters
+ lambda_val=self.cfg.ik_params["lambda_val"]
+ # computation
+ jacobian_T=torch.transpose(jacobian,dim0=1,dim1=2)
+ lambda_matrix=(lambda_val**2)*torch.eye(n=jacobian.shape[1],device=self._device)
+ delta_joint_pos=(
+ jacobian_T@torch.inverse(jacobian@jacobian_T+lambda_matrix)@delta_pose.unsqueeze(-1)
+ )
+ delta_joint_pos=delta_joint_pos.squeeze(-1)
+ else:
+ raiseValueError(f"Unsupported inverse-kinematics method: {self.cfg.ik_method}")
+
+ returndelta_joint_pos
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.differential_ikimportDifferentialIKController
+
+
+
[文档]@configclass
+classDifferentialIKControllerCfg:
+"""Configuration for differential inverse kinematics controller."""
+
+ class_type:type=DifferentialIKController
+"""The associated controller class."""
+
+ command_type:Literal["position","pose"]=MISSING
+"""Type of task-space command to control the articulation's body.
+
+ If "position", then the controller only controls the position of the articulation's body.
+ Otherwise, the controller controls the pose of the articulation's body.
+ """
+
+ use_relative_mode:bool=False
+"""Whether to use relative mode for the controller. Defaults to False.
+
+ If True, then the controller treats the input command as a delta change in the position/pose.
+ Otherwise, the controller treats the input command as the absolute position/pose.
+ """
+
+ ik_method:Literal["pinv","svd","trans","dls"]=MISSING
+"""Method for computing inverse of Jacobian."""
+
+ ik_params:dict[str,float]|None=None
+"""Parameters for the inverse-kinematics method. Defaults to None, in which case the default
+ parameters for the method are used.
+
+ - Moore-Penrose pseudo-inverse ("pinv"):
+ - "k_val": Scaling of computed delta-joint positions (default: 1.0).
+ - Adaptive Singular Value Decomposition ("svd"):
+ - "k_val": Scaling of computed delta-joint positions (default: 1.0).
+ - "min_singular_value": Single values less than this are suppressed to zero (default: 1e-5).
+ - Jacobian transpose ("trans"):
+ - "k_val": Scaling of computed delta-joint positions (default: 1.0).
+ - Damped Moore-Penrose pseudo-inverse ("dls"):
+ - "lambda_val": Damping coefficient (default: 0.01).
+ """
+
+ def__post_init__(self):
+ # check valid input
+ ifself.command_typenotin["position","pose"]:
+ raiseValueError(f"Unsupported inverse-kinematics command: {self.command_type}.")
+ ifself.ik_methodnotin["pinv","svd","trans","dls"]:
+ raiseValueError(f"Unsupported inverse-kinematics method: {self.ik_method}.")
+ # default parameters for different inverse kinematics approaches.
+ default_ik_params={
+ "pinv":{"k_val":1.0},
+ "svd":{"k_val":1.0,"min_singular_value":1e-5},
+ "trans":{"k_val":1.0},
+ "dls":{"lambda_val":0.01},
+ }
+ # update parameters for IK-method if not provided
+ ik_params=default_ik_params[self.ik_method].copy()
+ ifself.ik_paramsisnotNone:
+ ik_params.update(self.ik_params)
+ self.ik_params=ik_params
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Base class for teleoperation interface."""
+
+fromabcimportABC,abstractmethod
+fromcollections.abcimportCallable
+fromtypingimportAny
+
+
+
[文档]classDeviceBase(ABC):
+"""An interface class for teleoperation devices."""
+
+
[文档]def__init__(self):
+"""Initialize the teleoperation interface."""
+ pass
+
+ def__str__(self)->str:
+"""Returns: A string containing the information of joystick."""
+ returnf"{self.__class__.__name__}"
+
+"""
+ Operations
+ """
+
+
[文档]@abstractmethod
+ defreset(self):
+"""Reset the internals."""
+ raiseNotImplementedError
+
+
[文档]@abstractmethod
+ defadd_callback(self,key:Any,func:Callable):
+"""Add additional functions to bind keyboard.
+
+ Args:
+ key: The button to check against.
+ func: The function to call when key is pressed. The callback function should not
+ take any arguments.
+ """
+ raiseNotImplementedError
+
+
[文档]@abstractmethod
+ defadvance(self)->Any:
+"""Provides the joystick event state.
+
+ Returns:
+ The processed output form the joystick.
+ """
+ raiseNotImplementedError
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Gamepad controller for SE(2) control."""
+
+importnumpyasnp
+importweakref
+fromcollections.abcimportCallable
+
+importcarb
+importomni
+
+from..device_baseimportDeviceBase
+
+
+
[文档]classSe2Gamepad(DeviceBase):
+r"""A gamepad controller for sending SE(2) commands as velocity commands.
+
+ This class is designed to provide a gamepad controller for mobile base (such as quadrupeds).
+ It uses the Omniverse gamepad interface to listen to gamepad events and map them to robot's
+ task-space commands.
+
+ The command comprises of the base linear and angular velocity: :math:`(v_x, v_y, \omega_z)`.
+
+ Key bindings:
+ ====================== ========================= ========================
+ Command Key (+ve axis) Key (-ve axis)
+ ====================== ========================= ========================
+ Move along x-axis left stick up left stick down
+ Move along y-axis left stick right left stick left
+ Rotate along z-axis right stick right right stick left
+ ====================== ========================= ========================
+
+ .. seealso::
+
+ The official documentation for the gamepad interface: `Carb Gamepad Interface <https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/gamepad.html>`__.
+
+ """
+
+
[文档]def__init__(
+ self,
+ v_x_sensitivity:float=1.0,
+ v_y_sensitivity:float=1.0,
+ omega_z_sensitivity:float=1.0,
+ dead_zone:float=0.01,
+ ):
+"""Initialize the gamepad layer.
+
+ Args:
+ v_x_sensitivity: Magnitude of linear velocity along x-direction scaling. Defaults to 1.0.
+ v_y_sensitivity: Magnitude of linear velocity along y-direction scaling. Defaults to 1.0.
+ omega_z_sensitivity: Magnitude of angular velocity along z-direction scaling. Defaults to 1.0.
+ dead_zone: Magnitude of dead zone for gamepad. An event value from the gamepad less than
+ this value will be ignored. Defaults to 0.01.
+ """
+ # turn off simulator gamepad control
+ carb_settings_iface=carb.settings.get_settings()
+ carb_settings_iface.set_bool("/persistent/app/omniverse/gamepadCameraControl",False)
+ # store inputs
+ self.v_x_sensitivity=v_x_sensitivity
+ self.v_y_sensitivity=v_y_sensitivity
+ self.omega_z_sensitivity=omega_z_sensitivity
+ self.dead_zone=dead_zone
+ # acquire omniverse interfaces
+ self._appwindow=omni.appwindow.get_default_app_window()
+ self._input=carb.input.acquire_input_interface()
+ self._gamepad=self._appwindow.get_gamepad(0)
+ # note: Use weakref on callbacks to ensure that this object can be deleted when its destructor is called
+ self._gamepad_sub=self._input.subscribe_to_gamepad_events(
+ self._gamepad,
+ lambdaevent,*args,obj=weakref.proxy(self):obj._on_gamepad_event(event,*args),
+ )
+ # bindings for gamepad to command
+ self._create_key_bindings()
+ # command buffers
+ # When using the gamepad, two values are provided for each axis.
+ # For example: when the left stick is moved down, there are two evens: `left_stick_down = 0.8`
+ # and `left_stick_up = 0.0`. If only the value of left_stick_up is used, the value will be 0.0,
+ # which is not the desired behavior. Therefore, we save both the values into the buffer and use
+ # the maximum value.
+ # (positive, negative), (x, y, yaw)
+ self._base_command_raw=np.zeros([2,3])
+ # dictionary for additional callbacks
+ self._additional_callbacks=dict()
+
+ def__del__(self):
+"""Unsubscribe from gamepad events."""
+ self._input.unsubscribe_from_gamepad_events(self._gamepad,self._gamepad_sub)
+ self._gamepad_sub=None
+
+ def__str__(self)->str:
+"""Returns: A string containing the information of joystick."""
+ msg=f"Gamepad Controller for SE(2): {self.__class__.__name__}\n"
+ msg+=f"\tDevice name: {self._input.get_gamepad_name(self._gamepad)}\n"
+ msg+="\t----------------------------------------------\n"
+ msg+="\tMove in X-Y plane: left stick\n"
+ msg+="\tRotate in Z-axis: right stick\n"
+ returnmsg
+
+"""
+ Operations
+ """
+
+
[文档]defadd_callback(self,key:carb.input.GamepadInput,func:Callable):
+"""Add additional functions to bind gamepad.
+
+ A list of available gamepad keys are present in the
+ `carb documentation <https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/gamepad.html>`__.
+
+ Args:
+ key: The gamepad button to check against.
+ func: The function to call when key is pressed. The callback function should not
+ take any arguments.
+ """
+ self._additional_callbacks[key]=func
+
+
[文档]defadvance(self)->np.ndarray:
+"""Provides the result from gamepad event state.
+
+ Returns:
+ A 3D array containing the linear (x,y) and angular velocity (z).
+ """
+ returnself._resolve_command_buffer(self._base_command_raw)
+
+"""
+ Internal helpers.
+ """
+
+ def_on_gamepad_event(self,event:carb.input.GamepadEvent,*args,**kwargs):
+"""Subscriber callback to when kit is updated.
+
+ Reference:
+ https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/gamepad.html
+ """
+
+ # check if the event is a button press
+ cur_val=event.value
+ ifabs(cur_val)<self.dead_zone:
+ cur_val=0
+ # -- left and right stick
+ ifevent.inputinself._INPUT_STICK_VALUE_MAPPING:
+ direction,axis,value=self._INPUT_STICK_VALUE_MAPPING[event.input]
+ # change the value only if the stick is moved (soft press)
+ self._base_command_raw[direction,axis]=value*cur_val
+
+ # additional callbacks
+ ifevent.inputinself._additional_callbacks:
+ self._additional_callbacks[event.input]()
+
+ # since no error, we are fine :)
+ returnTrue
+
+ def_create_key_bindings(self):
+"""Creates default key binding."""
+ self._INPUT_STICK_VALUE_MAPPING={
+ # forward command
+ carb.input.GamepadInput.LEFT_STICK_UP:(0,0,self.v_x_sensitivity),
+ # backward command
+ carb.input.GamepadInput.LEFT_STICK_DOWN:(1,0,self.v_x_sensitivity),
+ # right command
+ carb.input.GamepadInput.LEFT_STICK_RIGHT:(0,1,self.v_y_sensitivity),
+ # left command
+ carb.input.GamepadInput.LEFT_STICK_LEFT:(1,1,self.v_y_sensitivity),
+ # yaw command (positive)
+ carb.input.GamepadInput.RIGHT_STICK_RIGHT:(0,2,self.omega_z_sensitivity),
+ # yaw command (negative)
+ carb.input.GamepadInput.RIGHT_STICK_LEFT:(1,2,self.omega_z_sensitivity),
+ }
+
+ def_resolve_command_buffer(self,raw_command:np.ndarray)->np.ndarray:
+"""Resolves the command buffer.
+
+ Args:
+ raw_command: The raw command from the gamepad. Shape is (2, 3)
+ This is a 2D array since gamepad dpad/stick returns two values corresponding to
+ the positive and negative direction. The first index is the direction (0: positive, 1: negative)
+ and the second index is value (absolute) of the command.
+
+ Returns:
+ Resolved command. Shape is (3,)
+ """
+ # compare the positive and negative value decide the sign of the value
+ # if the positive value is larger, the sign is positive (i.e. False, 0)
+ # if the negative value is larger, the sign is positive (i.e. True, 1)
+ command_sign=raw_command[1,:]>raw_command[0,:]
+ # extract the command value
+ command=raw_command.max(axis=0)
+ # apply the sign
+ # if the sign is positive, the value is already positive.
+ # if the sign is negative, the value is negative after applying the sign.
+ command[command_sign]*=-1
+
+ returncommand
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Gamepad controller for SE(3) control."""
+
+importnumpyasnp
+importweakref
+fromcollections.abcimportCallable
+fromscipy.spatial.transformimportRotation
+
+importcarb
+importomni
+
+from..device_baseimportDeviceBase
+
+
+
[文档]classSe3Gamepad(DeviceBase):
+"""A gamepad controller for sending SE(3) commands as delta poses and binary command (open/close).
+
+ This class is designed to provide a gamepad controller for a robotic arm with a gripper.
+ It uses the gamepad interface to listen to gamepad events and map them to the robot's
+ task-space commands.
+
+ The command comprises of two parts:
+
+ * delta pose: a 6D vector of (x, y, z, roll, pitch, yaw) in meters and radians.
+ * gripper: a binary command to open or close the gripper.
+
+ Stick and Button bindings:
+ ============================ ========================= =========================
+ Description Stick/Button (+ve axis) Stick/Button (-ve axis)
+ ============================ ========================= =========================
+ Toggle gripper(open/close) X Button X Button
+ Move along x-axis Left Stick Up Left Stick Down
+ Move along y-axis Left Stick Left Left Stick Right
+ Move along z-axis Right Stick Up Right Stick Down
+ Rotate along x-axis D-Pad Left D-Pad Right
+ Rotate along y-axis D-Pad Down D-Pad Up
+ Rotate along z-axis Right Stick Left Right Stick Right
+ ============================ ========================= =========================
+
+ .. seealso::
+
+ The official documentation for the gamepad interface: `Carb Gamepad Interface <https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/gamepad.html>`__.
+
+ """
+
+
[文档]def__init__(self,pos_sensitivity:float=1.0,rot_sensitivity:float=1.6,dead_zone:float=0.01):
+"""Initialize the gamepad layer.
+
+ Args:
+ pos_sensitivity: Magnitude of input position command scaling. Defaults to 1.0.
+ rot_sensitivity: Magnitude of scale input rotation commands scaling. Defaults to 1.6.
+ dead_zone: Magnitude of dead zone for gamepad. An event value from the gamepad less than
+ this value will be ignored. Defaults to 0.01.
+ """
+ # turn off simulator gamepad control
+ carb_settings_iface=carb.settings.get_settings()
+ carb_settings_iface.set_bool("/persistent/app/omniverse/gamepadCameraControl",False)
+ # store inputs
+ self.pos_sensitivity=pos_sensitivity
+ self.rot_sensitivity=rot_sensitivity
+ self.dead_zone=dead_zone
+ # acquire omniverse interfaces
+ self._appwindow=omni.appwindow.get_default_app_window()
+ self._input=carb.input.acquire_input_interface()
+ self._gamepad=self._appwindow.get_gamepad(0)
+ # note: Use weakref on callbacks to ensure that this object can be deleted when its destructor is called
+ self._gamepad_sub=self._input.subscribe_to_gamepad_events(
+ self._gamepad,
+ lambdaevent,*args,obj=weakref.proxy(self):obj._on_gamepad_event(event,*args),
+ )
+ # bindings for gamepad to command
+ self._create_key_bindings()
+ # command buffers
+ self._close_gripper=False
+ # When using the gamepad, two values are provided for each axis.
+ # For example: when the left stick is moved down, there are two evens: `left_stick_down = 0.8`
+ # and `left_stick_up = 0.0`. If only the value of left_stick_up is used, the value will be 0.0,
+ # which is not the desired behavior. Therefore, we save both the values into the buffer and use
+ # the maximum value.
+ # (positive, negative), (x, y, z, roll, pitch, yaw)
+ self._delta_pose_raw=np.zeros([2,6])
+ # dictionary for additional callbacks
+ self._additional_callbacks=dict()
+
+ def__del__(self):
+"""Unsubscribe from gamepad events."""
+ self._input.unsubscribe_from_gamepad_events(self._gamepad,self._gamepad_sub)
+ self._gamepad_sub=None
+
+ def__str__(self)->str:
+"""Returns: A string containing the information of joystick."""
+ msg=f"Gamepad Controller for SE(3): {self.__class__.__name__}\n"
+ msg+=f"\tDevice name: {self._input.get_gamepad_name(self._gamepad)}\n"
+ msg+="\t----------------------------------------------\n"
+ msg+="\tToggle gripper (open/close): X\n"
+ msg+="\tMove arm along x-axis: Left Stick Up/Down\n"
+ msg+="\tMove arm along y-axis: Left Stick Left/Right\n"
+ msg+="\tMove arm along z-axis: Right Stick Up/Down\n"
+ msg+="\tRotate arm along x-axis: D-Pad Right/Left\n"
+ msg+="\tRotate arm along y-axis: D-Pad Down/Up\n"
+ msg+="\tRotate arm along z-axis: Right Stick Left/Right\n"
+ returnmsg
+
+"""
+ Operations
+ """
+
+
[文档]defadd_callback(self,key:carb.input.GamepadInput,func:Callable):
+"""Add additional functions to bind gamepad.
+
+ A list of available gamepad keys are present in the
+ `carb documentation <https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/gamepad.html>`__.
+
+ Args:
+ key: The gamepad button to check against.
+ func: The function to call when key is pressed. The callback function should not
+ take any arguments.
+ """
+ self._additional_callbacks[key]=func
+
+
[文档]defadvance(self)->tuple[np.ndarray,bool]:
+"""Provides the result from gamepad event state.
+
+ Returns:
+ A tuple containing the delta pose command and gripper commands.
+ """
+ # -- resolve position command
+ delta_pos=self._resolve_command_buffer(self._delta_pose_raw[:,:3])
+ # -- resolve rotation command
+ delta_rot=self._resolve_command_buffer(self._delta_pose_raw[:,3:])
+ # -- convert to rotation vector
+ rot_vec=Rotation.from_euler("XYZ",delta_rot).as_rotvec()
+ # return the command and gripper state
+ returnnp.concatenate([delta_pos,rot_vec]),self._close_gripper
+
+"""
+ Internal helpers.
+ """
+
+ def_on_gamepad_event(self,event,*args,**kwargs):
+"""Subscriber callback to when kit is updated.
+
+ Reference:
+ https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/gamepad.html
+ """
+ # check if the event is a button press
+ cur_val=event.value
+ ifabs(cur_val)<self.dead_zone:
+ cur_val=0
+ # -- button
+ ifevent.input==carb.input.GamepadInput.X:
+ # toggle gripper based on the button pressed
+ ifcur_val>0.5:
+ self._close_gripper=notself._close_gripper
+ # -- left and right stick
+ ifevent.inputinself._INPUT_STICK_VALUE_MAPPING:
+ direction,axis,value=self._INPUT_STICK_VALUE_MAPPING[event.input]
+ # change the value only if the stick is moved (soft press)
+ self._delta_pose_raw[direction,axis]=value*cur_val
+ # -- dpad (4 arrow buttons on the console)
+ ifevent.inputinself._INPUT_DPAD_VALUE_MAPPING:
+ direction,axis,value=self._INPUT_DPAD_VALUE_MAPPING[event.input]
+ # change the value only if button is pressed on the DPAD
+ ifcur_val>0.5:
+ self._delta_pose_raw[direction,axis]=value
+ self._delta_pose_raw[1-direction,axis]=0
+ else:
+ self._delta_pose_raw[:,axis]=0
+ # additional callbacks
+ ifevent.inputinself._additional_callbacks:
+ self._additional_callbacks[event.input]()
+
+ # since no error, we are fine :)
+ returnTrue
+
+ def_create_key_bindings(self):
+"""Creates default key binding."""
+ # map gamepad input to the element in self._delta_pose_raw
+ # the first index is the direction (0: positive, 1: negative)
+ # the second index is the axis (0: x, 1: y, 2: z, 3: roll, 4: pitch, 5: yaw)
+ # the third index is the sensitivity of the command
+ self._INPUT_STICK_VALUE_MAPPING={
+ # forward command
+ carb.input.GamepadInput.LEFT_STICK_UP:(0,0,self.pos_sensitivity),
+ # backward command
+ carb.input.GamepadInput.LEFT_STICK_DOWN:(1,0,self.pos_sensitivity),
+ # right command
+ carb.input.GamepadInput.LEFT_STICK_RIGHT:(0,1,self.pos_sensitivity),
+ # left command
+ carb.input.GamepadInput.LEFT_STICK_LEFT:(1,1,self.pos_sensitivity),
+ # upward command
+ carb.input.GamepadInput.RIGHT_STICK_UP:(0,2,self.pos_sensitivity),
+ # downward command
+ carb.input.GamepadInput.RIGHT_STICK_DOWN:(1,2,self.pos_sensitivity),
+ # yaw command (positive)
+ carb.input.GamepadInput.RIGHT_STICK_RIGHT:(0,5,self.rot_sensitivity),
+ # yaw command (negative)
+ carb.input.GamepadInput.RIGHT_STICK_LEFT:(1,5,self.rot_sensitivity),
+ }
+
+ self._INPUT_DPAD_VALUE_MAPPING={
+ # pitch command (positive)
+ carb.input.GamepadInput.DPAD_UP:(1,4,self.rot_sensitivity*0.8),
+ # pitch command (negative)
+ carb.input.GamepadInput.DPAD_DOWN:(0,4,self.rot_sensitivity*0.8),
+ # roll command (positive)
+ carb.input.GamepadInput.DPAD_RIGHT:(1,3,self.rot_sensitivity*0.8),
+ # roll command (negative)
+ carb.input.GamepadInput.DPAD_LEFT:(0,3,self.rot_sensitivity*0.8),
+ }
+
+ def_resolve_command_buffer(self,raw_command:np.ndarray)->np.ndarray:
+"""Resolves the command buffer.
+
+ Args:
+ raw_command: The raw command from the gamepad. Shape is (2, 3)
+ This is a 2D array since gamepad dpad/stick returns two values corresponding to
+ the positive and negative direction. The first index is the direction (0: positive, 1: negative)
+ and the second index is value (absolute) of the command.
+
+ Returns:
+ Resolved command. Shape is (3,)
+ """
+ # compare the positive and negative value decide the sign of the value
+ # if the positive value is larger, the sign is positive (i.e. False, 0)
+ # if the negative value is larger, the sign is positive (i.e. True, 1)
+ delta_command_sign=raw_command[1,:]>raw_command[0,:]
+ # extract the command value
+ delta_command=raw_command.max(axis=0)
+ # apply the sign
+ # if the sign is positive, the value is already positive.
+ # if the sign is negative, the value is negative after applying the sign.
+ delta_command[delta_command_sign]*=-1
+
+ returndelta_command
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Keyboard controller for SE(2) control."""
+
+importnumpyasnp
+importweakref
+fromcollections.abcimportCallable
+
+importcarb
+importomni
+
+from..device_baseimportDeviceBase
+
+
+
[文档]classSe2Keyboard(DeviceBase):
+r"""A keyboard controller for sending SE(2) commands as velocity commands.
+
+ This class is designed to provide a keyboard controller for mobile base (such as quadrupeds).
+ It uses the Omniverse keyboard interface to listen to keyboard events and map them to robot's
+ task-space commands.
+
+ The command comprises of the base linear and angular velocity: :math:`(v_x, v_y, \omega_z)`.
+
+ Key bindings:
+ ====================== ========================= ========================
+ Command Key (+ve axis) Key (-ve axis)
+ ====================== ========================= ========================
+ Move along x-axis Numpad 8 / Arrow Up Numpad 2 / Arrow Down
+ Move along y-axis Numpad 4 / Arrow Right Numpad 6 / Arrow Left
+ Rotate along z-axis Numpad 7 / Z Numpad 9 / X
+ ====================== ========================= ========================
+
+ .. seealso::
+
+ The official documentation for the keyboard interface: `Carb Keyboard Interface <https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/keyboard.html>`__.
+
+ """
+
+
[文档]def__init__(self,v_x_sensitivity:float=0.8,v_y_sensitivity:float=0.4,omega_z_sensitivity:float=1.0):
+"""Initialize the keyboard layer.
+
+ Args:
+ v_x_sensitivity: Magnitude of linear velocity along x-direction scaling. Defaults to 0.8.
+ v_y_sensitivity: Magnitude of linear velocity along y-direction scaling. Defaults to 0.4.
+ omega_z_sensitivity: Magnitude of angular velocity along z-direction scaling. Defaults to 1.0.
+ """
+ # store inputs
+ self.v_x_sensitivity=v_x_sensitivity
+ self.v_y_sensitivity=v_y_sensitivity
+ self.omega_z_sensitivity=omega_z_sensitivity
+ # acquire omniverse interfaces
+ self._appwindow=omni.appwindow.get_default_app_window()
+ self._input=carb.input.acquire_input_interface()
+ self._keyboard=self._appwindow.get_keyboard()
+ # note: Use weakref on callbacks to ensure that this object can be deleted when its destructor is called
+ self._keyboard_sub=self._input.subscribe_to_keyboard_events(
+ self._keyboard,
+ lambdaevent,*args,obj=weakref.proxy(self):obj._on_keyboard_event(event,*args),
+ )
+ # bindings for keyboard to command
+ self._create_key_bindings()
+ # command buffers
+ self._base_command=np.zeros(3)
+ # dictionary for additional callbacks
+ self._additional_callbacks=dict()
[文档]defadd_callback(self,key:str,func:Callable):
+"""Add additional functions to bind keyboard.
+
+ A list of available keys are present in the
+ `carb documentation <https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/keyboard.html>`__.
+
+ Args:
+ key: The keyboard button to check against.
+ func: The function to call when key is pressed. The callback function should not
+ take any arguments.
+ """
+ self._additional_callbacks[key]=func
+
+
[文档]defadvance(self)->np.ndarray:
+"""Provides the result from keyboard event state.
+
+ Returns:
+ 3D array containing the linear (x,y) and angular velocity (z).
+ """
+ returnself._base_command
+
+"""
+ Internal helpers.
+ """
+
+ def_on_keyboard_event(self,event,*args,**kwargs):
+"""Subscriber callback to when kit is updated.
+
+ Reference:
+ https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/keyboard.html
+ """
+ # apply the command when pressed
+ ifevent.type==carb.input.KeyboardEventType.KEY_PRESS:
+ ifevent.input.name=="L":
+ self.reset()
+ elifevent.input.nameinself._INPUT_KEY_MAPPING:
+ self._base_command+=self._INPUT_KEY_MAPPING[event.input.name]
+ # remove the command when un-pressed
+ ifevent.type==carb.input.KeyboardEventType.KEY_RELEASE:
+ ifevent.input.nameinself._INPUT_KEY_MAPPING:
+ self._base_command-=self._INPUT_KEY_MAPPING[event.input.name]
+ # additional callbacks
+ ifevent.type==carb.input.KeyboardEventType.KEY_PRESS:
+ ifevent.input.nameinself._additional_callbacks:
+ self._additional_callbacks[event.input.name]()
+
+ # since no error, we are fine :)
+ returnTrue
+
+ def_create_key_bindings(self):
+"""Creates default key binding."""
+ self._INPUT_KEY_MAPPING={
+ # forward command
+ "NUMPAD_8":np.asarray([1.0,0.0,0.0])*self.v_x_sensitivity,
+ "UP":np.asarray([1.0,0.0,0.0])*self.v_x_sensitivity,
+ # back command
+ "NUMPAD_2":np.asarray([-1.0,0.0,0.0])*self.v_x_sensitivity,
+ "DOWN":np.asarray([-1.0,0.0,0.0])*self.v_x_sensitivity,
+ # right command
+ "NUMPAD_4":np.asarray([0.0,1.0,0.0])*self.v_y_sensitivity,
+ "LEFT":np.asarray([0.0,1.0,0.0])*self.v_y_sensitivity,
+ # left command
+ "NUMPAD_6":np.asarray([0.0,-1.0,0.0])*self.v_y_sensitivity,
+ "RIGHT":np.asarray([0.0,-1.0,0.0])*self.v_y_sensitivity,
+ # yaw command (positive)
+ "NUMPAD_7":np.asarray([0.0,0.0,1.0])*self.omega_z_sensitivity,
+ "Z":np.asarray([0.0,0.0,1.0])*self.omega_z_sensitivity,
+ # yaw command (negative)
+ "NUMPAD_9":np.asarray([0.0,0.0,-1.0])*self.omega_z_sensitivity,
+ "X":np.asarray([0.0,0.0,-1.0])*self.omega_z_sensitivity,
+ }
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Keyboard controller for SE(3) control."""
+
+importnumpyasnp
+importweakref
+fromcollections.abcimportCallable
+fromscipy.spatial.transformimportRotation
+
+importcarb
+importomni
+
+from..device_baseimportDeviceBase
+
+
+
[文档]classSe3Keyboard(DeviceBase):
+"""A keyboard controller for sending SE(3) commands as delta poses and binary command (open/close).
+
+ This class is designed to provide a keyboard controller for a robotic arm with a gripper.
+ It uses the Omniverse keyboard interface to listen to keyboard events and map them to robot's
+ task-space commands.
+
+ The command comprises of two parts:
+
+ * delta pose: a 6D vector of (x, y, z, roll, pitch, yaw) in meters and radians.
+ * gripper: a binary command to open or close the gripper.
+
+ Key bindings:
+ ============================== ================= =================
+ Description Key (+ve axis) Key (-ve axis)
+ ============================== ================= =================
+ Toggle gripper (open/close) K
+ Move along x-axis W S
+ Move along y-axis A D
+ Move along z-axis Q E
+ Rotate along x-axis Z X
+ Rotate along y-axis T G
+ Rotate along z-axis C V
+ ============================== ================= =================
+
+ .. seealso::
+
+ The official documentation for the keyboard interface: `Carb Keyboard Interface <https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/keyboard.html>`__.
+
+ """
+
+
[文档]def__init__(self,pos_sensitivity:float=0.4,rot_sensitivity:float=0.8):
+"""Initialize the keyboard layer.
+
+ Args:
+ pos_sensitivity: Magnitude of input position command scaling. Defaults to 0.05.
+ rot_sensitivity: Magnitude of scale input rotation commands scaling. Defaults to 0.5.
+ """
+ # store inputs
+ self.pos_sensitivity=pos_sensitivity
+ self.rot_sensitivity=rot_sensitivity
+ # acquire omniverse interfaces
+ self._appwindow=omni.appwindow.get_default_app_window()
+ self._input=carb.input.acquire_input_interface()
+ self._keyboard=self._appwindow.get_keyboard()
+ # note: Use weakref on callbacks to ensure that this object can be deleted when its destructor is called.
+ self._keyboard_sub=self._input.subscribe_to_keyboard_events(
+ self._keyboard,
+ lambdaevent,*args,obj=weakref.proxy(self):obj._on_keyboard_event(event,*args),
+ )
+ # bindings for keyboard to command
+ self._create_key_bindings()
+ # command buffers
+ self._close_gripper=False
+ self._delta_pos=np.zeros(3)# (x, y, z)
+ self._delta_rot=np.zeros(3)# (roll, pitch, yaw)
+ # dictionary for additional callbacks
+ self._additional_callbacks=dict()
+
+ def__del__(self):
+"""Release the keyboard interface."""
+ self._input.unsubscribe_from_keyboard_events(self._keyboard,self._keyboard_sub)
+ self._keyboard_sub=None
+
+ def__str__(self)->str:
+"""Returns: A string containing the information of joystick."""
+ msg=f"Keyboard Controller for SE(3): {self.__class__.__name__}\n"
+ msg+=f"\tKeyboard name: {self._input.get_keyboard_name(self._keyboard)}\n"
+ msg+="\t----------------------------------------------\n"
+ msg+="\tToggle gripper (open/close): K\n"
+ msg+="\tMove arm along x-axis: W/S\n"
+ msg+="\tMove arm along y-axis: A/D\n"
+ msg+="\tMove arm along z-axis: Q/E\n"
+ msg+="\tRotate arm along x-axis: Z/X\n"
+ msg+="\tRotate arm along y-axis: T/G\n"
+ msg+="\tRotate arm along z-axis: C/V"
+ returnmsg
+
+"""
+ Operations
+ """
+
+
[文档]defadd_callback(self,key:str,func:Callable):
+"""Add additional functions to bind keyboard.
+
+ A list of available keys are present in the
+ `carb documentation <https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/input-devices/keyboard.html>`__.
+
+ Args:
+ key: The keyboard button to check against.
+ func: The function to call when key is pressed. The callback function should not
+ take any arguments.
+ """
+ self._additional_callbacks[key]=func
+
+
[文档]defadvance(self)->tuple[np.ndarray,bool]:
+"""Provides the result from keyboard event state.
+
+ Returns:
+ A tuple containing the delta pose command and gripper commands.
+ """
+ # convert to rotation vector
+ rot_vec=Rotation.from_euler("XYZ",self._delta_rot).as_rotvec()
+ # return the command and gripper state
+ returnnp.concatenate([self._delta_pos,rot_vec]),self._close_gripper
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Spacemouse controller for SE(2) control."""
+
+importhid
+importnumpyasnp
+importthreading
+importtime
+fromcollections.abcimportCallable
+
+from..device_baseimportDeviceBase
+from.utilsimportconvert_buffer
+
+
+
[文档]classSe2SpaceMouse(DeviceBase):
+r"""A space-mouse controller for sending SE(2) commands as delta poses.
+
+ This class implements a space-mouse controller to provide commands to mobile base.
+ It uses the `HID-API`_ which interfaces with USD and Bluetooth HID-class devices across multiple platforms.
+
+ The command comprises of the base linear and angular velocity: :math:`(v_x, v_y, \omega_z)`.
+
+ Note:
+ The interface finds and uses the first supported device connected to the computer.
+
+ Currently tested for following devices:
+
+ - SpaceMouse Compact: https://3dconnexion.com/de/product/spacemouse-compact/
+
+ .. _HID-API: https://github.com/libusb/hidapi
+
+ """
+
+
[文档]def__init__(self,v_x_sensitivity:float=0.8,v_y_sensitivity:float=0.4,omega_z_sensitivity:float=1.0):
+"""Initialize the spacemouse layer.
+
+ Args:
+ v_x_sensitivity: Magnitude of linear velocity along x-direction scaling. Defaults to 0.8.
+ v_y_sensitivity: Magnitude of linear velocity along y-direction scaling. Defaults to 0.4.
+ omega_z_sensitivity: Magnitude of angular velocity along z-direction scaling. Defaults to 1.0.
+ """
+ # store inputs
+ self.v_x_sensitivity=v_x_sensitivity
+ self.v_y_sensitivity=v_y_sensitivity
+ self.omega_z_sensitivity=omega_z_sensitivity
+ # acquire device interface
+ self._device=hid.device()
+ self._find_device()
+ # command buffers
+ self._base_command=np.zeros(3)
+ # dictionary for additional callbacks
+ self._additional_callbacks=dict()
+ # run a thread for listening to device updates
+ self._thread=threading.Thread(target=self._run_device)
+ self._thread.daemon=True
+ self._thread.start()
+
+ def__del__(self):
+"""Destructor for the class."""
+ self._thread.join()
+
+ def__str__(self)->str:
+"""Returns: A string containing the information of joystick."""
+ msg=f"Spacemouse Controller for SE(2): {self.__class__.__name__}\n"
+ msg+=f"\tManufacturer: {self._device.get_manufacturer_string()}\n"
+ msg+=f"\tProduct: {self._device.get_product_string()}\n"
+ msg+="\t----------------------------------------------\n"
+ msg+="\tRight button: reset command\n"
+ msg+="\tMove mouse laterally: move base horizontally in x-y plane\n"
+ msg+="\tTwist mouse about z-axis: yaw base about a corresponding axis"
+ returnmsg
+
+"""
+ Operations
+ """
+
+
[文档]defadd_callback(self,key:str,func:Callable):
+ # check keys supported by callback
+ ifkeynotin["L","R"]:
+ raiseValueError(f"Only left (L) and right (R) buttons supported. Provided: {key}.")
+ # TODO: Improve this to allow multiple buttons on same key.
+ self._additional_callbacks[key]=func
+
+
[文档]defadvance(self)->np.ndarray:
+"""Provides the result from spacemouse event state.
+
+ Returns:
+ A 3D array containing the linear (x,y) and angular velocity (z).
+ """
+ returnself._base_command
+
+"""
+ Internal helpers.
+ """
+
+ def_find_device(self):
+"""Find the device connected to computer."""
+ found=False
+ # implement a timeout for device search
+ for_inrange(5):
+ fordeviceinhid.enumerate():
+ ifdevice["product_string"]=="SpaceMouse Compact":
+ # set found flag
+ found=True
+ vendor_id=device["vendor_id"]
+ product_id=device["product_id"]
+ # connect to the device
+ self._device.open(vendor_id,product_id)
+ # check if device found
+ ifnotfound:
+ time.sleep(1.0)
+ else:
+ break
+ # no device found: return false
+ ifnotfound:
+ raiseOSError("No device found by SpaceMouse. Is the device connected?")
+
+ def_run_device(self):
+"""Listener thread that keeps pulling new messages."""
+ # keep running
+ whileTrue:
+ # read the device data
+ data=self._device.read(13)
+ ifdataisnotNone:
+ # readings from 6-DoF sensor
+ ifdata[0]==1:
+ # along y-axis
+ self._base_command[1]=self.v_y_sensitivity*convert_buffer(data[1],data[2])
+ # along x-axis
+ self._base_command[0]=self.v_x_sensitivity*convert_buffer(data[3],data[4])
+ elifdata[0]==2:
+ # along z-axis
+ self._base_command[2]=self.omega_z_sensitivity*convert_buffer(data[3],data[4])
+ # readings from the side buttons
+ elifdata[0]==3:
+ # press left button
+ ifdata[1]==1:
+ # additional callbacks
+ if"L"inself._additional_callbacks:
+ self._additional_callbacks["L"]
+ # right button is for reset
+ ifdata[1]==2:
+ # reset layer
+ self.reset()
+ # additional callbacks
+ if"R"inself._additional_callbacks:
+ self._additional_callbacks["R"]
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Spacemouse controller for SE(3) control."""
+
+importhid
+importnumpyasnp
+importthreading
+importtime
+fromcollections.abcimportCallable
+fromscipy.spatial.transformimportRotation
+
+from..device_baseimportDeviceBase
+from.utilsimportconvert_buffer
+
+
+
[文档]classSe3SpaceMouse(DeviceBase):
+"""A space-mouse controller for sending SE(3) commands as delta poses.
+
+ This class implements a space-mouse controller to provide commands to a robotic arm with a gripper.
+ It uses the `HID-API`_ which interfaces with USD and Bluetooth HID-class devices across multiple platforms [1].
+
+ The command comprises of two parts:
+
+ * delta pose: a 6D vector of (x, y, z, roll, pitch, yaw) in meters and radians.
+ * gripper: a binary command to open or close the gripper.
+
+ Note:
+ The interface finds and uses the first supported device connected to the computer.
+
+ Currently tested for following devices:
+
+ - SpaceMouse Compact: https://3dconnexion.com/de/product/spacemouse-compact/
+
+ .. _HID-API: https://github.com/libusb/hidapi
+
+ """
+
+
[文档]def__init__(self,pos_sensitivity:float=0.4,rot_sensitivity:float=0.8):
+"""Initialize the space-mouse layer.
+
+ Args:
+ pos_sensitivity: Magnitude of input position command scaling. Defaults to 0.4.
+ rot_sensitivity: Magnitude of scale input rotation commands scaling. Defaults to 0.8.
+ """
+ # store inputs
+ self.pos_sensitivity=pos_sensitivity
+ self.rot_sensitivity=rot_sensitivity
+ # acquire device interface
+ self._device=hid.device()
+ self._find_device()
+ # read rotations
+ self._read_rotation=False
+
+ # command buffers
+ self._close_gripper=False
+ self._delta_pos=np.zeros(3)# (x, y, z)
+ self._delta_rot=np.zeros(3)# (roll, pitch, yaw)
+ # dictionary for additional callbacks
+ self._additional_callbacks=dict()
+ # run a thread for listening to device updates
+ self._thread=threading.Thread(target=self._run_device)
+ self._thread.daemon=True
+ self._thread.start()
+
+ def__del__(self):
+"""Destructor for the class."""
+ self._thread.join()
+
+ def__str__(self)->str:
+"""Returns: A string containing the information of joystick."""
+ msg=f"Spacemouse Controller for SE(3): {self.__class__.__name__}\n"
+ msg+=f"\tManufacturer: {self._device.get_manufacturer_string()}\n"
+ msg+=f"\tProduct: {self._device.get_product_string()}\n"
+ msg+="\t----------------------------------------------\n"
+ msg+="\tRight button: reset command\n"
+ msg+="\tLeft button: toggle gripper command (open/close)\n"
+ msg+="\tMove mouse laterally: move arm horizontally in x-y plane\n"
+ msg+="\tMove mouse vertically: move arm vertically\n"
+ msg+="\tTwist mouse about an axis: rotate arm about a corresponding axis"
+ returnmsg
+
+"""
+ Operations
+ """
+
+
[文档]defadd_callback(self,key:str,func:Callable):
+ # check keys supported by callback
+ ifkeynotin["L","R"]:
+ raiseValueError(f"Only left (L) and right (R) buttons supported. Provided: {key}.")
+ # TODO: Improve this to allow multiple buttons on same key.
+ self._additional_callbacks[key]=func
+
+
[文档]defadvance(self)->tuple[np.ndarray,bool]:
+"""Provides the result from spacemouse event state.
+
+ Returns:
+ A tuple containing the delta pose command and gripper commands.
+ """
+ rot_vec=Rotation.from_euler("XYZ",self._delta_rot).as_rotvec()
+ # if new command received, reset event flag to False until keyboard updated.
+ returnnp.concatenate([self._delta_pos,rot_vec]),self._close_gripper
+
+"""
+ Internal helpers.
+ """
+
+ def_find_device(self):
+"""Find the device connected to computer."""
+ found=False
+ # implement a timeout for device search
+ for_inrange(5):
+ fordeviceinhid.enumerate():
+ ifdevice["product_string"]=="SpaceMouse Compact":
+ # set found flag
+ found=True
+ vendor_id=device["vendor_id"]
+ product_id=device["product_id"]
+ # connect to the device
+ self._device.open(vendor_id,product_id)
+ # check if device found
+ ifnotfound:
+ time.sleep(1.0)
+ else:
+ break
+ # no device found: return false
+ ifnotfound:
+ raiseOSError("No device found by SpaceMouse. Is the device connected?")
+
+ def_run_device(self):
+"""Listener thread that keeps pulling new messages."""
+ # keep running
+ whileTrue:
+ # read the device data
+ data=self._device.read(7)
+ ifdataisnotNone:
+ # readings from 6-DoF sensor
+ ifdata[0]==1:
+ self._delta_pos[1]=self.pos_sensitivity*convert_buffer(data[1],data[2])
+ self._delta_pos[0]=self.pos_sensitivity*convert_buffer(data[3],data[4])
+ self._delta_pos[2]=self.pos_sensitivity*convert_buffer(data[5],data[6])*-1.0
+ elifdata[0]==2andnotself._read_rotation:
+ self._delta_rot[1]=self.rot_sensitivity*convert_buffer(data[1],data[2])
+ self._delta_rot[0]=self.rot_sensitivity*convert_buffer(data[3],data[4])
+ self._delta_rot[2]=self.rot_sensitivity*convert_buffer(data[5],data[6])
+ # readings from the side buttons
+ elifdata[0]==3:
+ # press left button
+ ifdata[1]==1:
+ # close gripper
+ self._close_gripper=notself._close_gripper
+ # additional callbacks
+ if"L"inself._additional_callbacks:
+ self._additional_callbacks["L"]
+ # right button is for reset
+ ifdata[1]==2:
+ # reset layer
+ self.reset()
+ # additional callbacks
+ if"R"inself._additional_callbacks:
+ self._additional_callbacks["R"]
+ ifdata[1]==3:
+ self._read_rotation=notself._read_rotation
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importtorch
+fromtypingimportDict,Literal,TypeVar
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+##
+# Configuration.
+##
+
+
+
[文档]@configclass
+classViewerCfg:
+"""Configuration of the scene viewport camera."""
+
+ eye:tuple[float,float,float]=(7.5,7.5,7.5)
+"""Initial camera position (in m). Default is (7.5, 7.5, 7.5)."""
+
+ lookat:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Initial camera target position (in m). Default is (0.0, 0.0, 0.0)."""
+
+ cam_prim_path:str="/OmniverseKit_Persp"
+"""The camera prim path to record images from. Default is "/OmniverseKit_Persp",
+ which is the default camera in the viewport.
+ """
+
+ resolution:tuple[int,int]=(1280,720)
+"""The resolution (width, height) of the camera specified using :attr:`cam_prim_path`.
+ Default is (1280, 720).
+ """
+
+ origin_type:Literal["world","env","asset_root"]="world"
+"""The frame in which the camera position (eye) and target (lookat) are defined in. Default is "world".
+
+ Available options are:
+
+ * ``"world"``: The origin of the world.
+ * ``"env"``: The origin of the environment defined by :attr:`env_index`.
+ * ``"asset_root"``: The center of the asset defined by :attr:`asset_name` in environment :attr:`env_index`.
+ """
+
+ env_index:int=0
+"""The environment index for frame origin. Default is 0.
+
+ This quantity is only effective if :attr:`origin` is set to "env" or "asset_root".
+ """
+
+ asset_name:str|None=None
+"""The asset name in the interactive scene for the frame origin. Default is None.
+
+ This quantity is only effective if :attr:`origin` is set to "asset_root".
+ """
+
+
+##
+# Types.
+##
+
+VecEnvObs=Dict[str,torch.Tensor|Dict[str,torch.Tensor]]
+"""Observation returned by the environment.
+
+The observations are stored in a dictionary. The keys are the group to which the observations belong.
+This is useful for various setups such as reinforcement learning with asymmetric actor-critic or
+multi-agent learning. For non-learning paradigms, this may include observations for different components
+of a system.
+
+Within each group, the observations can be stored either as a dictionary with keys as the names of each
+observation term in the group, or a single tensor obtained from concatenating all the observation terms.
+For example, for asymmetric actor-critic, the observation for the actor and the critic can be accessed
+using the keys ``"policy"`` and ``"critic"`` respectively.
+
+Note:
+ By default, most learning frameworks deal with default and privileged observations in different ways.
+ This handling must be taken care of by the wrapper around the :class:`ManagerBasedRLEnv` instance.
+
+ For included frameworks (RSL-RL, RL-Games, skrl), the observations must have the key "policy". In case,
+ the key "critic" is also present, then the critic observations are taken from the "critic" group.
+ Otherwise, they are the same as the "policy" group.
+
+"""
+
+VecEnvStepReturn=tuple[VecEnvObs,torch.Tensor,torch.Tensor,torch.Tensor,dict]
+"""The environment signals processed at the end of each step.
+
+The tuple contains batched information for each sub-environment. The information is stored in the following order:
+
+1. **Observations**: The observations from the environment.
+2. **Rewards**: The rewards from the environment.
+3. **Terminated Dones**: Whether the environment reached a terminal state, such as task success or robot falling etc.
+4. **Timeout Dones**: Whether the environment reached a timeout state, such as end of max episode length.
+5. **Extras**: A dictionary containing additional information from the environment.
+"""
+
+AgentID=TypeVar("AgentID")
+"""Unique identifier for an agent within a multi-agent environment.
+
+The identifier has to be an immutable object, typically a string (e.g.: ``"agent_0"``).
+"""
+
+ObsType=TypeVar("ObsType",torch.Tensor,Dict[str,torch.Tensor])
+"""A sentinel object to indicate the data type of the observation.
+"""
+
+ActionType=TypeVar("ActionType",torch.Tensor,Dict[str,torch.Tensor])
+"""A sentinel object to indicate the data type of the action.
+"""
+
+StateType=TypeVar("StateType",torch.Tensor,dict)
+"""A sentinel object to indicate the data type of the state.
+"""
+
+EnvStepReturn=tuple[
+ Dict[AgentID,ObsType],
+ Dict[AgentID,torch.Tensor],
+ Dict[AgentID,torch.Tensor],
+ Dict[AgentID,torch.Tensor],
+ Dict[AgentID,dict],
+]
+"""The environment signals processed at the end of each step.
+
+The tuple contains batched information for each sub-environment (keyed by the agent ID).
+The information is stored in the following order:
+
+1. **Observations**: The observations from the environment.
+2. **Rewards**: The rewards from the environment.
+3. **Terminated Dones**: Whether the environment reached a terminal state, such as task success or robot falling etc.
+4. **Timeout Dones**: Whether the environment reached a timeout state, such as end of max episode length.
+5. **Extras**: A dictionary containing additional information from the environment.
+"""
+
[文档]classDirectMARLEnv:
+"""The superclass for the direct workflow to design multi-agent environments.
+
+ This class implements the core functionality for multi-agent reinforcement learning (MARL)
+ environments. It is designed to be used with any RL library. The class is designed
+ to be used with vectorized environments, i.e., the environment is expected to be run
+ in parallel with multiple sub-environments.
+
+ The design of this class is based on the PettingZoo Parallel API.
+ While the environment itself is implemented as a vectorized environment, we do not
+ inherit from :class:`pettingzoo.ParallelEnv` or :class:`gym.vector.VectorEnv`. This is mainly
+ because the class adds various attributes and methods that are inconsistent with them.
+
+ Note:
+ For vectorized environments, it is recommended to **only** call the :meth:`reset`
+ method once before the first call to :meth:`step`, i.e. after the environment is created.
+ After that, the :meth:`step` function handles the reset of terminated sub-environments.
+ This is because the simulator does not support resetting individual sub-environments
+ in a vectorized environment.
+
+ """
+
+ metadata:ClassVar[dict[str,Any]]={
+ "render_modes":[None,"human","rgb_array"],
+ "isaac_sim_version":get_version(),
+ }
+"""Metadata for the environment."""
+
+
[文档]def__init__(self,cfg:DirectMARLEnvCfg,render_mode:str|None=None,**kwargs):
+"""Initialize the environment.
+
+ Args:
+ cfg: The configuration object for the environment.
+ render_mode: The render mode for the environment. Defaults to None, which
+ is similar to ``"human"``.
+
+ Raises:
+ RuntimeError: If a simulation context already exists. The environment must always create one
+ since it configures the simulation context and controls the simulation.
+ """
+ # store inputs to class
+ self.cfg=cfg
+ # store the render mode
+ self.render_mode=render_mode
+ # initialize internal variables
+ self._is_closed=False
+
+ # set the seed for the environment
+ ifself.cfg.seedisnotNone:
+ self.seed(self.cfg.seed)
+ else:
+ carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")
+
+ # create a simulation context to control the simulator
+ ifSimulationContext.instance()isNone:
+ self.sim:SimulationContext=SimulationContext(self.cfg.sim)
+ else:
+ raiseRuntimeError("Simulation context already exists. Cannot create a new one.")
+
+ # print useful information
+ print("[INFO]: Base environment:")
+ print(f"\tEnvironment device : {self.device}")
+ print(f"\tEnvironment seed : {self.cfg.seed}")
+ print(f"\tPhysics step-size : {self.physics_dt}")
+ print(f"\tRendering step-size : {self.physics_dt*self.cfg.sim.render_interval}")
+ print(f"\tEnvironment step-size : {self.step_dt}")
+
+ ifself.cfg.sim.render_interval<self.cfg.decimation:
+ msg=(
+ f"The render interval ({self.cfg.sim.render_interval}) is smaller than the decimation "
+ f"({self.cfg.decimation}). Multiple multiple render calls will happen for each environment step."
+ "If this is not intended, set the render interval to be equal to the decimation."
+ )
+ carb.log_warn(msg)
+
+ # generate scene
+ withTimer("[INFO]: Time taken for scene creation","scene_creation"):
+ self.scene=InteractiveScene(self.cfg.scene)
+ self._setup_scene()
+ print("[INFO]: Scene manager: ",self.scene)
+
+ # set up camera viewport controller
+ # viewport is not available in other rendering modes so the function will throw a warning
+ # FIXME: This needs to be fixed in the future when we unify the UI functionalities even for
+ # non-rendering modes.
+ ifself.sim.render_mode>=self.sim.RenderMode.PARTIAL_RENDERING:
+ self.viewport_camera_controller=ViewportCameraController(self,self.cfg.viewer)
+ else:
+ self.viewport_camera_controller=None
+
+ # play the simulator to activate physics handles
+ # note: this activates the physics simulation view that exposes TensorAPIs
+ # note: when started in extension mode, first call sim.reset_async() and then initialize the managers
+ ifbuiltins.ISAAC_LAUNCHED_FROM_TERMINALisFalse:
+ print("[INFO]: Starting the simulation. This may take a few seconds. Please wait...")
+ withTimer("[INFO]: Time taken for simulation start","simulation_start"):
+ self.sim.reset()
+
+ # -- event manager used for randomization
+ ifself.cfg.events:
+ self.event_manager=EventManager(self.cfg.events,self)
+ print("[INFO] Event Manager: ",self.event_manager)
+
+ # make sure torch is running on the correct device
+ if"cuda"inself.device:
+ torch.cuda.set_device(self.device)
+
+ # check if debug visualization is has been implemented by the environment
+ source_code=inspect.getsource(self._set_debug_vis_impl)
+ self.has_debug_vis_implementation="NotImplementedError"notinsource_code
+ self._debug_vis_handle=None
+
+ # extend UI elements
+ # we need to do this here after all the managers are initialized
+ # this is because they dictate the sensors and commands right now
+ ifself.sim.has_gui()andself.cfg.ui_window_class_typeisnotNone:
+ self._window=self.cfg.ui_window_class_type(self,window_name="IsaacLab")
+ else:
+ # if no window, then we don't need to store the window
+ self._window=None
+
+ # allocate dictionary to store metrics
+ self.extras={agent:{}foragentinself.cfg.possible_agents}
+
+ # initialize data and constants
+ # -- counter for simulation steps
+ self._sim_step_counter=0
+ # -- counter for curriculum
+ self.common_step_counter=0
+ # -- init buffers
+ self.episode_length_buf=torch.zeros(self.num_envs,device=self.device,dtype=torch.long)
+ self.reset_buf=torch.zeros(self.num_envs,dtype=torch.bool,device=self.sim.device)
+ self.actions={
+ agent:torch.zeros(self.num_envs,self.cfg.num_actions[agent],device=self.sim.device)
+ foragentinself.cfg.possible_agents
+ }
+
+ # setup the observation, state and action spaces
+ self._configure_env_spaces()
+
+ # setup noise cfg for adding action and observation noise
+ ifself.cfg.action_noise_model:
+ self._action_noise_model:dict[AgentID,NoiseModel]={
+ agent:noise_model.class_type(self.num_envs,noise_model,self.device)
+ foragent,noise_modelinself.cfg.action_noise_model.items()
+ ifnoise_modelisnotNone
+ }
+ ifself.cfg.observation_noise_model:
+ self._observation_noise_model:dict[AgentID,NoiseModel]={
+ agent:noise_model.class_type(self.num_envs,noise_model,self.device)
+ foragent,noise_modelinself.cfg.observation_noise_model.items()
+ ifnoise_modelisnotNone
+ }
+
+ # perform events at the start of the simulation
+ ifself.cfg.events:
+ if"startup"inself.event_manager.available_modes:
+ self.event_manager.apply(mode="startup")
+
+ # print the environment information
+ print("[INFO]: Completed setting up the environment...")
+
+ def__del__(self):
+"""Cleanup for the environment."""
+ self.close()
+
+"""
+ Properties.
+ """
+
+ @property
+ defnum_envs(self)->int:
+"""The number of instances of the environment that are running."""
+ returnself.scene.num_envs
+
+ @property
+ defnum_agents(self)->int:
+"""Number of current agents.
+
+ The number of current agents may change as the environment progresses (e.g.: agents can be added or removed).
+ """
+ returnlen(self.agents)
+
+ @property
+ defmax_num_agents(self)->int:
+"""Number of all possible agents the environment can generate.
+
+ This value remains constant as the environment progresses.
+ """
+ returnlen(self.possible_agents)
+
+ @property
+ defunwrapped(self)->DirectMARLEnv:
+"""Get the unwrapped environment underneath all the layers of wrappers."""
+ returnself
+
+ @property
+ defphysics_dt(self)->float:
+"""The physics time-step (in s).
+
+ This is the lowest time-decimation at which the simulation is happening.
+ """
+ returnself.cfg.sim.dt
+
+ @property
+ defstep_dt(self)->float:
+"""The environment stepping time-step (in s).
+
+ This is the time-step at which the environment steps forward.
+ """
+ returnself.cfg.sim.dt*self.cfg.decimation
+
+ @property
+ defdevice(self):
+"""The device on which the environment is running."""
+ returnself.sim.device
+
+ @property
+ defmax_episode_length_s(self)->float:
+"""Maximum episode length in seconds."""
+ returnself.cfg.episode_length_s
+
+ @property
+ defmax_episode_length(self):
+"""The maximum episode length in steps adjusted from s."""
+ returnmath.ceil(self.max_episode_length_s/(self.cfg.sim.dt*self.cfg.decimation))
+
+"""
+ Space methods
+ """
+
+
[文档]defobservation_space(self,agent:AgentID)->gym.Space:
+"""Get the observation space for the specified agent.
+
+ Returns:
+ The agent's observation space.
+ """
+ returnself.observation_spaces[agent]
+
+
[文档]defaction_space(self,agent:AgentID)->gym.Space:
+"""Get the action space for the specified agent.
+
+ Returns:
+ The agent's action space.
+ """
+ returnself.action_spaces[agent]
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(
+ self,seed:int|None=None,options:dict[str,Any]|None=None
+ )->tuple[dict[AgentID,ObsType],dict[AgentID,dict]]:
+"""Resets all the environments and returns observations.
+
+ Args:
+ seed: The seed to use for randomization. Defaults to None, in which case the seed is not set.
+ options: Additional information to specify how the environment is reset. Defaults to None.
+
+ Note:
+ This argument is used for compatibility with Gymnasium environment definition.
+
+ Returns:
+ A tuple containing the observations and extras (keyed by the agent ID).
+ """
+ # set the seed
+ ifseedisnotNone:
+ self.seed(seed)
+
+ # reset state of scene
+ indices=torch.arange(self.num_envs,dtype=torch.int64,device=self.device)
+ self._reset_idx(indices)
+
+ # update observations and the list of current agents (sorted as in possible_agents)
+ self.obs_dict=self._get_observations()
+ self.agents=[agentforagentinself.possible_agentsifagentinself.obs_dict]
+
+ # return observations
+ returnself.obs_dict,self.extras
+
+
[文档]defstep(self,actions:dict[AgentID,ActionType])->EnvStepReturn:
+"""Execute one time-step of the environment's dynamics.
+
+ The environment steps forward at a fixed time-step, while the physics simulation is decimated at a
+ lower time-step. This is to ensure that the simulation is stable. These two time-steps can be configured
+ independently using the :attr:`DirectMARLEnvCfg.decimation` (number of simulation steps per environment step)
+ and the :attr:`DirectMARLEnvCfg.sim.physics_dt` (physics time-step). Based on these parameters, the environment
+ time-step is computed as the product of the two.
+
+ This function performs the following steps:
+
+ 1. Pre-process the actions before stepping through the physics.
+ 2. Apply the actions to the simulator and step through the physics in a decimated manner.
+ 3. Compute the reward and done signals.
+ 4. Reset environments that have terminated or reached the maximum episode length.
+ 5. Apply interval events if they are enabled.
+ 6. Compute observations.
+
+ Args:
+ actions: The actions to apply on the environment (keyed by the agent ID).
+ Shape of individual tensors is (num_envs, action_dim).
+
+ Returns:
+ A tuple containing the observations, rewards, resets (terminated and truncated) and extras (keyed by the agent ID).
+ """
+ actions={agent:action.to(self.device)foragent,actioninactions.items()}
+
+ # add action noise
+ ifself.cfg.action_noise_model:
+ foragent,actioninactions.items():
+ ifagentinself._action_noise_model:
+ actions[agent]=self._action_noise_model[agent].apply(action)
+ # process actions
+ self._pre_physics_step(actions)
+
+ # check if we need to do rendering within the physics loop
+ # note: checked here once to avoid multiple checks within the loop
+ is_rendering=self.sim.has_gui()orself.sim.has_rtx_sensors()
+
+ # perform physics stepping
+ for_inrange(self.cfg.decimation):
+ self._sim_step_counter+=1
+ # set actions into buffers
+ self._apply_action()
+ # set actions into simulator
+ self.scene.write_data_to_sim()
+ # simulate
+ self.sim.step(render=False)
+ # render between steps only if the GUI or an RTX sensor needs it
+ # note: we assume the render interval to be the shortest accepted rendering interval.
+ # If a camera needs rendering at a faster frequency, this will lead to unexpected behavior.
+ ifself._sim_step_counter%self.cfg.sim.render_interval==0andis_rendering:
+ self.sim.render()
+ # update buffers at sim dt
+ self.scene.update(dt=self.physics_dt)
+
+ # post-step:
+ # -- update env counters (used for curriculum generation)
+ self.episode_length_buf+=1# step in current episode (per env)
+ self.common_step_counter+=1# total step (common for all envs)
+
+ self.terminated_dict,self.time_out_dict=self._get_dones()
+ self.reset_buf[:]=math.prod(self.terminated_dict.values())|math.prod(self.time_out_dict.values())
+ self.reward_dict=self._get_rewards()
+
+ # -- reset envs that terminated/timed-out and log the episode information
+ reset_env_ids=self.reset_buf.nonzero(as_tuple=False).squeeze(-1)
+ iflen(reset_env_ids)>0:
+ self._reset_idx(reset_env_ids)
+
+ # post-step: step interval event
+ ifself.cfg.events:
+ if"interval"inself.event_manager.available_modes:
+ self.event_manager.apply(mode="interval",dt=self.step_dt)
+
+ # update observations and the list of current agents (sorted as in possible_agents)
+ self.obs_dict=self._get_observations()
+ self.agents=[agentforagentinself.possible_agentsifagentinself.obs_dict]
+
+ # add observation noise
+ # note: we apply no noise to the state space (since it is used for centralized training or critic networks)
+ ifself.cfg.observation_noise_model:
+ foragent,obsinself.obs_dict.items():
+ ifagentinself._observation_noise_model:
+ self.obs_dict[agent]=self._observation_noise_model[agent].apply(obs)
+
+ # return observations, rewards, resets and extras
+ returnself.obs_dict,self.reward_dict,self.terminated_dict,self.time_out_dict,self.extras
+
+
[文档]defstate(self)->StateType|None:
+"""Returns the state for the environment.
+
+ The state-space is used for centralized training or asymmetric actor-critic architectures. It is configured
+ using the :attr:`DirectMARLEnvCfg.num_states` parameter.
+
+ Returns:
+ The states for the environment, or None if :attr:`DirectMARLEnvCfg.num_states` parameter is zero.
+ """
+ ifnotself.cfg.num_states:
+ returnNone
+ # concatenate and return the observations as state
+ ifself.cfg.num_states<0:
+ self.state_buf=torch.cat([self.obs_dict[agent]foragentinself.cfg.possible_agents],dim=-1)
+ # compute and return custom environment state
+ else:
+ self.state_buf=self._get_states()
+ returnself.state_buf
+
+
[文档]@staticmethod
+ defseed(seed:int=-1)->int:
+"""Set the seed for the environment.
+
+ Args:
+ seed: The seed for random generator. Defaults to -1.
+
+ Returns:
+ The seed used for random generator.
+ """
+ # set seed for replicator
+ try:
+ importomni.replicator.coreasrep
+
+ rep.set_global_seed(seed)
+ exceptModuleNotFoundError:
+ pass
+ # set seed for torch and other libraries
+ returntorch_utils.set_seed(seed)
+
+
[文档]defrender(self,recompute:bool=False)->np.ndarray|None:
+"""Run rendering without stepping through the physics.
+
+ By convention, if mode is:
+
+ - **human**: Render to the current display and return nothing. Usually for human consumption.
+ - **rgb_array**: Return an numpy.ndarray with shape (x, y, 3), representing RGB values for an
+ x-by-y pixel image, suitable for turning into a video.
+
+ Args:
+ recompute: Whether to force a render even if the simulator has already rendered the scene.
+ Defaults to False.
+
+ Returns:
+ The rendered image as a numpy array if mode is "rgb_array". Otherwise, returns None.
+
+ Raises:
+ RuntimeError: If mode is set to "rgb_data" and simulation render mode does not support it.
+ In this case, the simulation render mode must be set to ``RenderMode.PARTIAL_RENDERING``
+ or ``RenderMode.FULL_RENDERING``.
+ NotImplementedError: If an unsupported rendering mode is specified.
+ """
+ # run a rendering step of the simulator
+ # if we have rtx sensors, we do not need to render again sin
+ ifnotself.sim.has_rtx_sensors()andnotrecompute:
+ self.sim.render()
+ # decide the rendering mode
+ ifself.render_mode=="human"orself.render_modeisNone:
+ returnNone
+ elifself.render_mode=="rgb_array":
+ # check that if any render could have happened
+ ifself.sim.render_mode.value<self.sim.RenderMode.PARTIAL_RENDERING.value:
+ raiseRuntimeError(
+ f"Cannot render '{self.render_mode}' when the simulation render mode is"
+ f" '{self.sim.render_mode.name}'. Please set the simulation render mode to:"
+ f"'{self.sim.RenderMode.PARTIAL_RENDERING.name}' or '{self.sim.RenderMode.FULL_RENDERING.name}'."
+ " If running headless, make sure --enable_cameras is set."
+ )
+ # create the annotator if it does not exist
+ ifnothasattr(self,"_rgb_annotator"):
+ importomni.replicator.coreasrep
+
+ # create render product
+ self._render_product=rep.create.render_product(
+ self.cfg.viewer.cam_prim_path,self.cfg.viewer.resolution
+ )
+ # create rgb annotator -- used to read data from the render product
+ self._rgb_annotator=rep.AnnotatorRegistry.get_annotator("rgb",device="cpu")
+ self._rgb_annotator.attach([self._render_product])
+ # obtain the rgb data
+ rgb_data=self._rgb_annotator.get_data()
+ # convert to numpy array
+ rgb_data=np.frombuffer(rgb_data,dtype=np.uint8).reshape(*rgb_data.shape)
+ # return the rgb data
+ # note: initially the renderer is warming up and returns empty data
+ ifrgb_data.size==0:
+ returnnp.zeros((self.cfg.viewer.resolution[1],self.cfg.viewer.resolution[0],3),dtype=np.uint8)
+ else:
+ returnrgb_data[:,:,:3]
+ else:
+ raiseNotImplementedError(
+ f"Render mode '{self.render_mode}' is not supported. Please use: {self.metadata['render_modes']}."
+ )
+
+
[文档]defclose(self):
+"""Cleanup for the environment."""
+ ifnotself._is_closed:
+ # close entities related to the environment
+ # note: this is order-sensitive to avoid any dangling references
+ ifself.cfg.events:
+ delself.event_manager
+ delself.scene
+ ifself.viewport_camera_controllerisnotNone:
+ delself.viewport_camera_controller
+ # clear callbacks and instance
+ self.sim.clear_all_callbacks()
+ self.sim.clear_instance()
+ # destroy the window
+ ifself._windowisnotNone:
+ self._window=None
+ # update closing status
+ self._is_closed=True
[文档]defset_debug_vis(self,debug_vis:bool)->bool:
+"""Toggles the environment debug visualization.
+
+ Args:
+ debug_vis: Whether to visualize the environment debug visualization.
+
+ Returns:
+ Whether the debug visualization was successfully set. False if the environment
+ does not support debug visualization.
+ """
+ # check if debug visualization is supported
+ ifnotself.has_debug_vis_implementation:
+ returnFalse
+ # toggle debug visualization objects
+ self._set_debug_vis_impl(debug_vis)
+ # toggle debug visualization handles
+ ifdebug_vis:
+ # create a subscriber for the post update event if it doesn't exist
+ ifself._debug_vis_handleisNone:
+ app_interface=omni.kit.app.get_app_interface()
+ self._debug_vis_handle=app_interface.get_post_update_event_stream().create_subscription_to_pop(
+ lambdaevent,obj=weakref.proxy(self):obj._debug_vis_callback(event)
+ )
+ else:
+ # remove the subscriber if it exists
+ ifself._debug_vis_handleisnotNone:
+ self._debug_vis_handle.unsubscribe()
+ self._debug_vis_handle=None
+ # return success
+ returnTrue
+
+"""
+ Helper functions.
+ """
+
+ def_configure_env_spaces(self):
+"""Configure the spaces for the environment."""
+ self.agents=self.cfg.possible_agents
+ self.possible_agents=self.cfg.possible_agents
+
+ # set up observation and action spaces
+ self.observation_spaces={
+ agent:gym.spaces.Box(low=-np.inf,high=np.inf,shape=(self.cfg.num_observations[agent],))
+ foragentinself.cfg.possible_agents
+ }
+ self.action_spaces={
+ agent:gym.spaces.Box(low=-np.inf,high=np.inf,shape=(self.cfg.num_actions[agent],))
+ foragentinself.cfg.possible_agents
+ }
+
+ # set up state space
+ ifnotself.cfg.num_states:
+ self.state_space=None
+ ifself.cfg.num_states<0:
+ self.state_space=gym.spaces.Box(
+ low=-np.inf,high=np.inf,shape=(sum(self.cfg.num_observations.values()),)
+ )
+ else:
+ self.state_space=gym.spaces.Box(low=-np.inf,high=np.inf,shape=(self.cfg.num_states,))
+
+ def_reset_idx(self,env_ids:Sequence[int]):
+"""Reset environments based on specified indices.
+
+ Args:
+ env_ids: List of environment ids which must be reset
+ """
+ self.scene.reset(env_ids)
+
+ # apply events such as randomization for environments that need a reset
+ ifself.cfg.events:
+ if"reset"inself.event_manager.available_modes:
+ env_step_count=self._sim_step_counter//self.cfg.decimation
+ self.event_manager.apply(mode="reset",env_ids=env_ids,global_env_step_count=env_step_count)
+
+ # reset noise models
+ ifself.cfg.action_noise_model:
+ fornoise_modelinself._action_noise_model.values():
+ noise_model.reset(env_ids)
+ ifself.cfg.observation_noise_model:
+ fornoise_modelinself._observation_noise_model.values():
+ noise_model.reset(env_ids)
+
+ # reset the episode length buffer
+ self.episode_length_buf[env_ids]=0
+
+"""
+ Implementation-specific functions.
+ """
+
+ def_setup_scene(self):
+"""Setup the scene for the environment.
+
+ This function is responsible for creating the scene objects and setting up the scene for the environment.
+ The scene creation can happen through :class:`omni.isaac.lab.scene.InteractiveSceneCfg` or through
+ directly creating the scene objects and registering them with the scene manager.
+
+ We leave the implementation of this function to the derived classes. If the environment does not require
+ any explicit scene setup, the function can be left empty.
+ """
+ pass
+
+ @abstractmethod
+ def_pre_physics_step(self,actions:dict[AgentID,ActionType]):
+"""Pre-process actions before stepping through the physics.
+
+ This function is responsible for pre-processing the actions before stepping through the physics.
+ It is called before the physics stepping (which is decimated).
+
+ Args:
+ actions: The actions to apply on the environment (keyed by the agent ID).
+ Shape of individual tensors is (num_envs, action_dim).
+ """
+ raiseNotImplementedError(f"Please implement the '_pre_physics_step' method for {self.__class__.__name__}.")
+
+ @abstractmethod
+ def_apply_action(self):
+"""Apply actions to the simulator.
+
+ This function is responsible for applying the actions to the simulator. It is called at each
+ physics time-step.
+ """
+ raiseNotImplementedError(f"Please implement the '_apply_action' method for {self.__class__.__name__}.")
+
+ @abstractmethod
+ def_get_observations(self)->dict[AgentID,ObsType]:
+"""Compute and return the observations for the environment.
+
+ Returns:
+ The observations for the environment (keyed by the agent ID).
+ """
+ raiseNotImplementedError(f"Please implement the '_get_observations' method for {self.__class__.__name__}.")
+
+ @abstractmethod
+ def_get_states(self)->StateType:
+"""Compute and return the states for the environment.
+
+ This method is only called (and therefore has to be implemented) when the :attr:`DirectMARLEnvCfg.num_states`
+ parameter is greater than zero.
+
+ Returns:
+ The states for the environment.
+ """
+ raiseNotImplementedError(f"Please implement the '_get_states' method for {self.__class__.__name__}.")
+
+ @abstractmethod
+ def_get_rewards(self)->dict[AgentID,torch.Tensor]:
+"""Compute and return the rewards for the environment.
+
+ Returns:
+ The rewards for the environment (keyed by the agent ID).
+ Shape of individual tensors is (num_envs,).
+ """
+ raiseNotImplementedError(f"Please implement the '_get_rewards' method for {self.__class__.__name__}.")
+
+ @abstractmethod
+ def_get_dones(self)->tuple[dict[AgentID,torch.Tensor],dict[AgentID,torch.Tensor]]:
+"""Compute and return the done flags for the environment.
+
+ Returns:
+ A tuple containing the done flags for termination and time-out (keyed by the agent ID).
+ Shape of individual tensors is (num_envs,).
+ """
+ raiseNotImplementedError(f"Please implement the '_get_dones' method for {self.__class__.__name__}.")
+
+ def_set_debug_vis_impl(self,debug_vis:bool):
+"""Set debug visualization into visualization objects.
+
+ This function is responsible for creating the visualization objects if they don't exist
+ and input ``debug_vis`` is True. If the visualization objects exist, the function should
+ set their visibility into the stage.
+ """
+ raiseNotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.sceneimportInteractiveSceneCfg
+fromomni.isaac.lab.simimportSimulationCfg
+fromomni.isaac.lab.utilsimportconfigclass
+fromomni.isaac.lab.utils.noiseimportNoiseModelCfg
+
+from.commonimportAgentID,ViewerCfg
+from.uiimportBaseEnvWindow
+
+
+
[文档]@configclass
+classDirectMARLEnvCfg:
+"""Configuration for a MARL environment defined with the direct workflow.
+
+ Please refer to the :class:`omni.isaac.lab.envs.direct_marl_env.DirectMARLEnv` class for more details.
+ """
+
+ # simulation settings
+ viewer:ViewerCfg=ViewerCfg()
+"""Viewer configuration. Default is ViewerCfg()."""
+
+ sim:SimulationCfg=SimulationCfg()
+"""Physics simulation configuration. Default is SimulationCfg()."""
+
+ # ui settings
+ ui_window_class_type:type|None=BaseEnvWindow
+"""The class type of the UI window. Default is None.
+
+ If None, then no UI window is created.
+
+ Note:
+ If you want to make your own UI window, you can create a class that inherits from
+ from :class:`omni.isaac.lab.envs.ui.base_env_window.BaseEnvWindow`. Then, you can set
+ this attribute to your class type.
+ """
+
+ # general settings
+ seed:int|None=None
+"""The seed for the random number generator. Defaults to None, in which case the seed is not set.
+
+ Note:
+ The seed is set at the beginning of the environment initialization. This ensures that the environment
+ creation is deterministic and behaves similarly across different runs.
+ """
+
+ decimation:int=MISSING
+"""Number of control action updates @ sim dt per policy dt.
+
+ For instance, if the simulation dt is 0.01s and the policy dt is 0.1s, then the decimation is 10.
+ This means that the control action is updated every 10 simulation steps.
+ """
+
+ is_finite_horizon:bool=False
+"""Whether the learning task is treated as a finite or infinite horizon problem for the agent.
+ Defaults to False, which means the task is treated as an infinite horizon problem.
+
+ This flag handles the subtleties of finite and infinite horizon tasks:
+
+ * **Finite horizon**: no penalty or bootstrapping value is required by the the agent for
+ running out of time. However, the environment still needs to terminate the episode after the
+ time limit is reached.
+ * **Infinite horizon**: the agent needs to bootstrap the value of the state at the end of the episode.
+ This is done by sending a time-limit (or truncated) done signal to the agent, which triggers this
+ bootstrapping calculation.
+
+ If True, then the environment is treated as a finite horizon problem and no time-out (or truncated) done signal
+ is sent to the agent. If False, then the environment is treated as an infinite horizon problem and a time-out
+ (or truncated) done signal is sent to the agent.
+
+ Note:
+ The base :class:`ManagerBasedRLEnv` class does not use this flag directly. It is used by the environment
+ wrappers to determine what type of done signal to send to the corresponding learning agent.
+ """
+
+ episode_length_s:float=MISSING
+"""Duration of an episode (in seconds).
+
+ Based on the decimation rate and physics time step, the episode length is calculated as:
+
+ .. code-block:: python
+
+ episode_length_steps = ceil(episode_length_s / (decimation_rate * physics_time_step))
+
+ For example, if the decimation rate is 10, the physics time step is 0.01, and the episode length is 10 seconds,
+ then the episode length in steps is 100.
+ """
+
+ # environment settings
+ scene:InteractiveSceneCfg=MISSING
+"""Scene settings.
+
+ Please refer to the :class:`omni.isaac.lab.scene.InteractiveSceneCfg` class for more details.
+ """
+
+ events:object=None
+"""Event settings. Defaults to None, in which case no events are applied through the event manager.
+
+ Please refer to the :class:`omni.isaac.lab.managers.EventManager` class for more details.
+ """
+
+ num_observations:dict[AgentID,int]=MISSING
+"""The dimension of the observation space from each agent."""
+
+ num_states:int=MISSING
+"""The dimension of the state space from each environment instance.
+
+ The following values are supported:
+
+ * -1: All the observations from the different agents are automatically concatenated.
+ * 0: No state-space will be constructed (`state_space` is None).
+ This is useful to save computational resources when the algorithm to be trained does not need it.
+ * greater than 0: Custom state-space dimension to be provided by the task implementation.
+ """
+
+ observation_noise_model:dict[AgentID,NoiseModelCfg|None]|None=None
+"""The noise model to apply to the computed observations from the environment. Default is None, which means no noise is added.
+
+ Please refer to the :class:`omni.isaac.lab.utils.noise.NoiseModel` class for more details.
+ """
+
+ num_actions:dict[AgentID,int]=MISSING
+"""The dimension of the action space for each agent."""
+
+ action_noise_model:dict[AgentID,NoiseModelCfg|None]|None=None
+"""The noise model applied to the actions provided to the environment. Default is None, which means no noise is added.
+
+ Please refer to the :class:`omni.isaac.lab.utils.noise.NoiseModel` class for more details.
+ """
+
+ possible_agents:list[AgentID]=MISSING
+"""A list of all possible agents the environment could generate.
+
+ The contents of the list cannot be modified during the entire training process.
+ """
[文档]classDirectRLEnv(gym.Env):
+"""The superclass for the direct workflow to design environments.
+
+ This class implements the core functionality for reinforcement learning (RL)
+ environments. It is designed to be used with any RL library. The class is designed
+ to be used with vectorized environments, i.e., the environment is expected to be run
+ in parallel with multiple sub-environments.
+
+ While the environment itself is implemented as a vectorized environment, we do not
+ inherit from :class:`gym.vector.VectorEnv`. This is mainly because the class adds
+ various methods (for wait and asynchronous updates) which are not required.
+ Additionally, each RL library typically has its own definition for a vectorized
+ environment. Thus, to reduce complexity, we directly use the :class:`gym.Env` over
+ here and leave it up to library-defined wrappers to take care of wrapping this
+ environment for their agents.
+
+ Note:
+ For vectorized environments, it is recommended to **only** call the :meth:`reset`
+ method once before the first call to :meth:`step`, i.e. after the environment is created.
+ After that, the :meth:`step` function handles the reset of terminated sub-environments.
+ This is because the simulator does not support resetting individual sub-environments
+ in a vectorized environment.
+
+ """
+
+ is_vector_env:ClassVar[bool]=True
+"""Whether the environment is a vectorized environment."""
+ metadata:ClassVar[dict[str,Any]]={
+ "render_modes":[None,"human","rgb_array"],
+ "isaac_sim_version":get_version(),
+ }
+"""Metadata for the environment."""
+
+
[文档]def__init__(self,cfg:DirectRLEnvCfg,render_mode:str|None=None,**kwargs):
+"""Initialize the environment.
+
+ Args:
+ cfg: The configuration object for the environment.
+ render_mode: The render mode for the environment. Defaults to None, which
+ is similar to ``"human"``.
+
+ Raises:
+ RuntimeError: If a simulation context already exists. The environment must always create one
+ since it configures the simulation context and controls the simulation.
+ """
+ # store inputs to class
+ self.cfg=cfg
+ # store the render mode
+ self.render_mode=render_mode
+ # initialize internal variables
+ self._is_closed=False
+
+ # set the seed for the environment
+ ifself.cfg.seedisnotNone:
+ self.seed(self.cfg.seed)
+ else:
+ carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")
+
+ # create a simulation context to control the simulator
+ ifSimulationContext.instance()isNone:
+ self.sim:SimulationContext=SimulationContext(self.cfg.sim)
+ else:
+ raiseRuntimeError("Simulation context already exists. Cannot create a new one.")
+
+ # print useful information
+ print("[INFO]: Base environment:")
+ print(f"\tEnvironment device : {self.device}")
+ print(f"\tEnvironment seed : {self.cfg.seed}")
+ print(f"\tPhysics step-size : {self.physics_dt}")
+ print(f"\tRendering step-size : {self.physics_dt*self.cfg.sim.render_interval}")
+ print(f"\tEnvironment step-size : {self.step_dt}")
+
+ ifself.cfg.sim.render_interval<self.cfg.decimation:
+ msg=(
+ f"The render interval ({self.cfg.sim.render_interval}) is smaller than the decimation "
+ f"({self.cfg.decimation}). Multiple multiple render calls will happen for each environment step."
+ "If this is not intended, set the render interval to be equal to the decimation."
+ )
+ carb.log_warn(msg)
+
+ # generate scene
+ withTimer("[INFO]: Time taken for scene creation","scene_creation"):
+ self.scene=InteractiveScene(self.cfg.scene)
+ self._setup_scene()
+ print("[INFO]: Scene manager: ",self.scene)
+
+ # set up camera viewport controller
+ # viewport is not available in other rendering modes so the function will throw a warning
+ # FIXME: This needs to be fixed in the future when we unify the UI functionalities even for
+ # non-rendering modes.
+ ifself.sim.render_mode>=self.sim.RenderMode.PARTIAL_RENDERING:
+ self.viewport_camera_controller=ViewportCameraController(self,self.cfg.viewer)
+ else:
+ self.viewport_camera_controller=None
+
+ # play the simulator to activate physics handles
+ # note: this activates the physics simulation view that exposes TensorAPIs
+ # note: when started in extension mode, first call sim.reset_async() and then initialize the managers
+ ifbuiltins.ISAAC_LAUNCHED_FROM_TERMINALisFalse:
+ print("[INFO]: Starting the simulation. This may take a few seconds. Please wait...")
+ withTimer("[INFO]: Time taken for simulation start","simulation_start"):
+ self.sim.reset()
+
+ # -- event manager used for randomization
+ ifself.cfg.events:
+ self.event_manager=EventManager(self.cfg.events,self)
+ print("[INFO] Event Manager: ",self.event_manager)
+
+ # make sure torch is running on the correct device
+ if"cuda"inself.device:
+ torch.cuda.set_device(self.device)
+
+ # check if debug visualization is has been implemented by the environment
+ source_code=inspect.getsource(self._set_debug_vis_impl)
+ self.has_debug_vis_implementation="NotImplementedError"notinsource_code
+ self._debug_vis_handle=None
+
+ # extend UI elements
+ # we need to do this here after all the managers are initialized
+ # this is because they dictate the sensors and commands right now
+ ifself.sim.has_gui()andself.cfg.ui_window_class_typeisnotNone:
+ self._window=self.cfg.ui_window_class_type(self,window_name="IsaacLab")
+ else:
+ # if no window, then we don't need to store the window
+ self._window=None
+
+ # allocate dictionary to store metrics
+ self.extras={}
+
+ # initialize data and constants
+ # -- counter for simulation steps
+ self._sim_step_counter=0
+ # -- counter for curriculum
+ self.common_step_counter=0
+ # -- init buffers
+ self.episode_length_buf=torch.zeros(self.num_envs,device=self.device,dtype=torch.long)
+ self.reset_terminated=torch.zeros(self.num_envs,device=self.device,dtype=torch.bool)
+ self.reset_time_outs=torch.zeros_like(self.reset_terminated)
+ self.reset_buf=torch.zeros(self.num_envs,dtype=torch.bool,device=self.sim.device)
+ self.actions=torch.zeros(self.num_envs,self.cfg.num_actions,device=self.sim.device)
+
+ # setup the action and observation spaces for Gym
+ self._configure_gym_env_spaces()
+
+ # setup noise cfg for adding action and observation noise
+ ifself.cfg.action_noise_model:
+ self._action_noise_model:NoiseModel=self.cfg.action_noise_model.class_type(
+ self.cfg.action_noise_model,num_envs=self.num_envs,device=self.device
+ )
+ ifself.cfg.observation_noise_model:
+ self._observation_noise_model:NoiseModel=self.cfg.observation_noise_model.class_type(
+ self.cfg.observation_noise_model,num_envs=self.num_envs,device=self.device
+ )
+
+ # perform events at the start of the simulation
+ ifself.cfg.events:
+ if"startup"inself.event_manager.available_modes:
+ self.event_manager.apply(mode="startup")
+
+ # -- set the framerate of the gym video recorder wrapper so that the playback speed of the produced video matches the simulation
+ self.metadata["render_fps"]=1/self.step_dt
+
+ # print the environment information
+ print("[INFO]: Completed setting up the environment...")
+
+ def__del__(self):
+"""Cleanup for the environment."""
+ self.close()
+
+"""
+ Properties.
+ """
+
+ @property
+ defnum_envs(self)->int:
+"""The number of instances of the environment that are running."""
+ returnself.scene.num_envs
+
+ @property
+ defphysics_dt(self)->float:
+"""The physics time-step (in s).
+
+ This is the lowest time-decimation at which the simulation is happening.
+ """
+ returnself.cfg.sim.dt
+
+ @property
+ defstep_dt(self)->float:
+"""The environment stepping time-step (in s).
+
+ This is the time-step at which the environment steps forward.
+ """
+ returnself.cfg.sim.dt*self.cfg.decimation
+
+ @property
+ defdevice(self):
+"""The device on which the environment is running."""
+ returnself.sim.device
+
+ @property
+ defmax_episode_length_s(self)->float:
+"""Maximum episode length in seconds."""
+ returnself.cfg.episode_length_s
+
+ @property
+ defmax_episode_length(self):
+"""The maximum episode length in steps adjusted from s."""
+ returnmath.ceil(self.max_episode_length_s/(self.cfg.sim.dt*self.cfg.decimation))
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,seed:int|None=None,options:dict[str,Any]|None=None)->tuple[VecEnvObs,dict]:
+"""Resets all the environments and returns observations.
+
+ This function calls the :meth:`_reset_idx` function to reset all the environments.
+ However, certain operations, such as procedural terrain generation, that happened during initialization
+ are not repeated.
+
+ Args:
+ seed: The seed to use for randomization. Defaults to None, in which case the seed is not set.
+ options: Additional information to specify how the environment is reset. Defaults to None.
+
+ Note:
+ This argument is used for compatibility with Gymnasium environment definition.
+
+ Returns:
+ A tuple containing the observations and extras.
+ """
+ # set the seed
+ ifseedisnotNone:
+ self.seed(seed)
+
+ # reset state of scene
+ indices=torch.arange(self.num_envs,dtype=torch.int64,device=self.device)
+ self._reset_idx(indices)
+
+ # if sensors are added to the scene, make sure we render to reflect changes in reset
+ ifself.sim.has_rtx_sensors()andself.cfg.rerender_on_reset:
+ self.sim.render()
+
+ # return observations
+ returnself._get_observations(),self.extras
+
+
[文档]defstep(self,action:torch.Tensor)->VecEnvStepReturn:
+"""Execute one time-step of the environment's dynamics.
+
+ The environment steps forward at a fixed time-step, while the physics simulation is decimated at a
+ lower time-step. This is to ensure that the simulation is stable. These two time-steps can be configured
+ independently using the :attr:`DirectRLEnvCfg.decimation` (number of simulation steps per environment step)
+ and the :attr:`DirectRLEnvCfg.sim.physics_dt` (physics time-step). Based on these parameters, the environment
+ time-step is computed as the product of the two.
+
+ This function performs the following steps:
+
+ 1. Pre-process the actions before stepping through the physics.
+ 2. Apply the actions to the simulator and step through the physics in a decimated manner.
+ 3. Compute the reward and done signals.
+ 4. Reset environments that have terminated or reached the maximum episode length.
+ 5. Apply interval events if they are enabled.
+ 6. Compute observations.
+
+ Args:
+ action: The actions to apply on the environment. Shape is (num_envs, action_dim).
+
+ Returns:
+ A tuple containing the observations, rewards, resets (terminated and truncated) and extras.
+ """
+ action=action.to(self.device)
+ # add action noise
+ ifself.cfg.action_noise_model:
+ action=self._action_noise_model.apply(action)
+
+ # process actions
+ self._pre_physics_step(action)
+
+ # check if we need to do rendering within the physics loop
+ # note: checked here once to avoid multiple checks within the loop
+ is_rendering=self.sim.has_gui()orself.sim.has_rtx_sensors()
+
+ # perform physics stepping
+ for_inrange(self.cfg.decimation):
+ self._sim_step_counter+=1
+ # set actions into buffers
+ self._apply_action()
+ # set actions into simulator
+ self.scene.write_data_to_sim()
+ # simulate
+ self.sim.step(render=False)
+ # render between steps only if the GUI or an RTX sensor needs it
+ # note: we assume the render interval to be the shortest accepted rendering interval.
+ # If a camera needs rendering at a faster frequency, this will lead to unexpected behavior.
+ ifself._sim_step_counter%self.cfg.sim.render_interval==0andis_rendering:
+ self.sim.render()
+ # update buffers at sim dt
+ self.scene.update(dt=self.physics_dt)
+
+ # post-step:
+ # -- update env counters (used for curriculum generation)
+ self.episode_length_buf+=1# step in current episode (per env)
+ self.common_step_counter+=1# total step (common for all envs)
+
+ self.reset_terminated[:],self.reset_time_outs[:]=self._get_dones()
+ self.reset_buf=self.reset_terminated|self.reset_time_outs
+ self.reward_buf=self._get_rewards()
+
+ # -- reset envs that terminated/timed-out and log the episode information
+ reset_env_ids=self.reset_buf.nonzero(as_tuple=False).squeeze(-1)
+ iflen(reset_env_ids)>0:
+ self._reset_idx(reset_env_ids)
+ # if sensors are added to the scene, make sure we render to reflect changes in reset
+ ifself.sim.has_rtx_sensors()andself.cfg.rerender_on_reset:
+ self.sim.render()
+
+ # post-step: step interval event
+ ifself.cfg.events:
+ if"interval"inself.event_manager.available_modes:
+ self.event_manager.apply(mode="interval",dt=self.step_dt)
+
+ # update observations
+ self.obs_buf=self._get_observations()
+
+ # add observation noise
+ # note: we apply no noise to the state space (since it is used for critic networks)
+ ifself.cfg.observation_noise_model:
+ self.obs_buf["policy"]=self._observation_noise_model.apply(self.obs_buf["policy"])
+
+ # return observations, rewards, resets and extras
+ returnself.obs_buf,self.reward_buf,self.reset_terminated,self.reset_time_outs,self.extras
+
+
[文档]@staticmethod
+ defseed(seed:int=-1)->int:
+"""Set the seed for the environment.
+
+ Args:
+ seed: The seed for random generator. Defaults to -1.
+
+ Returns:
+ The seed used for random generator.
+ """
+ # set seed for replicator
+ try:
+ importomni.replicator.coreasrep
+
+ rep.set_global_seed(seed)
+ exceptModuleNotFoundError:
+ pass
+ # set seed for torch and other libraries
+ returntorch_utils.set_seed(seed)
+
+
[文档]defrender(self,recompute:bool=False)->np.ndarray|None:
+"""Run rendering without stepping through the physics.
+
+ By convention, if mode is:
+
+ - **human**: Render to the current display and return nothing. Usually for human consumption.
+ - **rgb_array**: Return an numpy.ndarray with shape (x, y, 3), representing RGB values for an
+ x-by-y pixel image, suitable for turning into a video.
+
+ Args:
+ recompute: Whether to force a render even if the simulator has already rendered the scene.
+ Defaults to False.
+
+ Returns:
+ The rendered image as a numpy array if mode is "rgb_array". Otherwise, returns None.
+
+ Raises:
+ RuntimeError: If mode is set to "rgb_data" and simulation render mode does not support it.
+ In this case, the simulation render mode must be set to ``RenderMode.PARTIAL_RENDERING``
+ or ``RenderMode.FULL_RENDERING``.
+ NotImplementedError: If an unsupported rendering mode is specified.
+ """
+ # run a rendering step of the simulator
+ # if we have rtx sensors, we do not need to render again sin
+ ifnotself.sim.has_rtx_sensors()andnotrecompute:
+ self.sim.render()
+ # decide the rendering mode
+ ifself.render_mode=="human"orself.render_modeisNone:
+ returnNone
+ elifself.render_mode=="rgb_array":
+ # check that if any render could have happened
+ ifself.sim.render_mode.value<self.sim.RenderMode.PARTIAL_RENDERING.value:
+ raiseRuntimeError(
+ f"Cannot render '{self.render_mode}' when the simulation render mode is"
+ f" '{self.sim.render_mode.name}'. Please set the simulation render mode to:"
+ f"'{self.sim.RenderMode.PARTIAL_RENDERING.name}' or '{self.sim.RenderMode.FULL_RENDERING.name}'."
+ " If running headless, make sure --enable_cameras is set."
+ )
+ # create the annotator if it does not exist
+ ifnothasattr(self,"_rgb_annotator"):
+ importomni.replicator.coreasrep
+
+ # create render product
+ self._render_product=rep.create.render_product(
+ self.cfg.viewer.cam_prim_path,self.cfg.viewer.resolution
+ )
+ # create rgb annotator -- used to read data from the render product
+ self._rgb_annotator=rep.AnnotatorRegistry.get_annotator("rgb",device="cpu")
+ self._rgb_annotator.attach([self._render_product])
+ # obtain the rgb data
+ rgb_data=self._rgb_annotator.get_data()
+ # convert to numpy array
+ rgb_data=np.frombuffer(rgb_data,dtype=np.uint8).reshape(*rgb_data.shape)
+ # return the rgb data
+ # note: initially the renerer is warming up and returns empty data
+ ifrgb_data.size==0:
+ returnnp.zeros((self.cfg.viewer.resolution[1],self.cfg.viewer.resolution[0],3),dtype=np.uint8)
+ else:
+ returnrgb_data[:,:,:3]
+ else:
+ raiseNotImplementedError(
+ f"Render mode '{self.render_mode}' is not supported. Please use: {self.metadata['render_modes']}."
+ )
+
+
[文档]defclose(self):
+"""Cleanup for the environment."""
+ ifnotself._is_closed:
+ # close entities related to the environment
+ # note: this is order-sensitive to avoid any dangling references
+ ifself.cfg.events:
+ delself.event_manager
+ delself.scene
+ ifself.viewport_camera_controllerisnotNone:
+ delself.viewport_camera_controller
+ # clear callbacks and instance
+ self.sim.clear_all_callbacks()
+ self.sim.clear_instance()
+ # destroy the window
+ ifself._windowisnotNone:
+ self._window=None
+ # update closing status
+ self._is_closed=True
[文档]defset_debug_vis(self,debug_vis:bool)->bool:
+"""Toggles the environment debug visualization.
+
+ Args:
+ debug_vis: Whether to visualize the environment debug visualization.
+
+ Returns:
+ Whether the debug visualization was successfully set. False if the environment
+ does not support debug visualization.
+ """
+ # check if debug visualization is supported
+ ifnotself.has_debug_vis_implementation:
+ returnFalse
+ # toggle debug visualization objects
+ self._set_debug_vis_impl(debug_vis)
+ # toggle debug visualization handles
+ ifdebug_vis:
+ # create a subscriber for the post update event if it doesn't exist
+ ifself._debug_vis_handleisNone:
+ app_interface=omni.kit.app.get_app_interface()
+ self._debug_vis_handle=app_interface.get_post_update_event_stream().create_subscription_to_pop(
+ lambdaevent,obj=weakref.proxy(self):obj._debug_vis_callback(event)
+ )
+ else:
+ # remove the subscriber if it exists
+ ifself._debug_vis_handleisnotNone:
+ self._debug_vis_handle.unsubscribe()
+ self._debug_vis_handle=None
+ # return success
+ returnTrue
+
+"""
+ Helper functions.
+ """
+
+ def_configure_gym_env_spaces(self):
+"""Configure the action and observation spaces for the Gym environment."""
+ # observation space (unbounded since we don't impose any limits)
+ self.num_actions=self.cfg.num_actions
+ self.num_observations=self.cfg.num_observations
+ self.num_states=self.cfg.num_states
+
+ # set up spaces
+ self.single_observation_space=gym.spaces.Dict()
+ self.single_observation_space["policy"]=gym.spaces.Box(
+ low=-np.inf,high=np.inf,shape=(self.num_observations,)
+ )
+ self.single_action_space=gym.spaces.Box(low=-np.inf,high=np.inf,shape=(self.num_actions,))
+
+ # batch the spaces for vectorized environments
+ self.observation_space=gym.vector.utils.batch_space(self.single_observation_space["policy"],self.num_envs)
+ self.action_space=gym.vector.utils.batch_space(self.single_action_space,self.num_envs)
+
+ # optional state space for asymmetric actor-critic architectures
+ ifself.num_states>0:
+ self.single_observation_space["critic"]=gym.spaces.Box(low=-np.inf,high=np.inf,shape=(self.num_states,))
+ self.state_space=gym.vector.utils.batch_space(self.single_observation_space["critic"],self.num_envs)
+
+ def_reset_idx(self,env_ids:Sequence[int]):
+"""Reset environments based on specified indices.
+
+ Args:
+ env_ids: List of environment ids which must be reset
+ """
+ self.scene.reset(env_ids)
+
+ # apply events such as randomization for environments that need a reset
+ ifself.cfg.events:
+ if"reset"inself.event_manager.available_modes:
+ env_step_count=self._sim_step_counter//self.cfg.decimation
+ self.event_manager.apply(mode="reset",env_ids=env_ids,global_env_step_count=env_step_count)
+
+ # reset noise models
+ ifself.cfg.action_noise_model:
+ self._action_noise_model.reset(env_ids)
+ ifself.cfg.observation_noise_model:
+ self._observation_noise_model.reset(env_ids)
+
+ # reset the episode length buffer
+ self.episode_length_buf[env_ids]=0
+
+"""
+ Implementation-specific functions.
+ """
+
+ def_setup_scene(self):
+"""Setup the scene for the environment.
+
+ This function is responsible for creating the scene objects and setting up the scene for the environment.
+ The scene creation can happen through :class:`omni.isaac.lab.scene.InteractiveSceneCfg` or through
+ directly creating the scene objects and registering them with the scene manager.
+
+ We leave the implementation of this function to the derived classes. If the environment does not require
+ any explicit scene setup, the function can be left empty.
+ """
+ pass
+
+ @abstractmethod
+ def_pre_physics_step(self,actions:torch.Tensor):
+"""Pre-process actions before stepping through the physics.
+
+ This function is responsible for pre-processing the actions before stepping through the physics.
+ It is called before the physics stepping (which is decimated).
+
+ Args:
+ actions: The actions to apply on the environment. Shape is (num_envs, action_dim).
+ """
+ raiseNotImplementedError(f"Please implement the '_pre_physics_step' method for {self.__class__.__name__}.")
+
+ @abstractmethod
+ def_apply_action(self):
+"""Apply actions to the simulator.
+
+ This function is responsible for applying the actions to the simulator. It is called at each
+ physics time-step.
+ """
+ raiseNotImplementedError(f"Please implement the '_apply_action' method for {self.__class__.__name__}.")
+
+ @abstractmethod
+ def_get_observations(self)->VecEnvObs:
+"""Compute and return the observations for the environment.
+
+ Returns:
+ The observations for the environment.
+ """
+ raiseNotImplementedError(f"Please implement the '_get_observations' method for {self.__class__.__name__}.")
+
+ def_get_states(self)->VecEnvObs|None:
+"""Compute and return the states for the environment.
+
+ The state-space is used for asymmetric actor-critic architectures. It is configured
+ using the :attr:`DirectRLEnvCfg.num_states` parameter.
+
+ Returns:
+ The states for the environment. If the environment does not have a state-space, the function
+ returns a None.
+ """
+ returnNone# noqa: R501
+
+ @abstractmethod
+ def_get_rewards(self)->torch.Tensor:
+"""Compute and return the rewards for the environment.
+
+ Returns:
+ The rewards for the environment. Shape is (num_envs,).
+ """
+ raiseNotImplementedError(f"Please implement the '_get_rewards' method for {self.__class__.__name__}.")
+
+ @abstractmethod
+ def_get_dones(self)->tuple[torch.Tensor,torch.Tensor]:
+"""Compute and return the done flags for the environment.
+
+ Returns:
+ A tuple containing the done flags for termination and time-out.
+ Shape of individual tensors is (num_envs,).
+ """
+ raiseNotImplementedError(f"Please implement the '_get_dones' method for {self.__class__.__name__}.")
+
+ def_set_debug_vis_impl(self,debug_vis:bool):
+"""Set debug visualization into visualization objects.
+
+ This function is responsible for creating the visualization objects if they don't exist
+ and input ``debug_vis`` is True. If the visualization objects exist, the function should
+ set their visibility into the stage.
+ """
+ raiseNotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.sceneimportInteractiveSceneCfg
+fromomni.isaac.lab.simimportSimulationCfg
+fromomni.isaac.lab.utilsimportconfigclass
+fromomni.isaac.lab.utils.noiseimportNoiseModelCfg
+
+from.commonimportViewerCfg
+from.uiimportBaseEnvWindow
+
+
+
[文档]@configclass
+classDirectRLEnvCfg:
+"""Configuration for an RL environment defined with the direct workflow.
+
+ Please refer to the :class:`omni.isaac.lab.envs.direct_rl_env.DirectRLEnv` class for more details.
+ """
+
+ # simulation settings
+ viewer:ViewerCfg=ViewerCfg()
+"""Viewer configuration. Default is ViewerCfg()."""
+
+ sim:SimulationCfg=SimulationCfg()
+"""Physics simulation configuration. Default is SimulationCfg()."""
+
+ # ui settings
+ ui_window_class_type:type|None=BaseEnvWindow
+"""The class type of the UI window. Default is None.
+
+ If None, then no UI window is created.
+
+ Note:
+ If you want to make your own UI window, you can create a class that inherits from
+ from :class:`omni.isaac.lab.envs.ui.base_env_window.BaseEnvWindow`. Then, you can set
+ this attribute to your class type.
+ """
+
+ # general settings
+ seed:int|None=None
+"""The seed for the random number generator. Defaults to None, in which case the seed is not set.
+
+ Note:
+ The seed is set at the beginning of the environment initialization. This ensures that the environment
+ creation is deterministic and behaves similarly across different runs.
+ """
+
+ decimation:int=MISSING
+"""Number of control action updates @ sim dt per policy dt.
+
+ For instance, if the simulation dt is 0.01s and the policy dt is 0.1s, then the decimation is 10.
+ This means that the control action is updated every 10 simulation steps.
+ """
+
+ is_finite_horizon:bool=False
+"""Whether the learning task is treated as a finite or infinite horizon problem for the agent.
+ Defaults to False, which means the task is treated as an infinite horizon problem.
+
+ This flag handles the subtleties of finite and infinite horizon tasks:
+
+ * **Finite horizon**: no penalty or bootstrapping value is required by the the agent for
+ running out of time. However, the environment still needs to terminate the episode after the
+ time limit is reached.
+ * **Infinite horizon**: the agent needs to bootstrap the value of the state at the end of the episode.
+ This is done by sending a time-limit (or truncated) done signal to the agent, which triggers this
+ bootstrapping calculation.
+
+ If True, then the environment is treated as a finite horizon problem and no time-out (or truncated) done signal
+ is sent to the agent. If False, then the environment is treated as an infinite horizon problem and a time-out
+ (or truncated) done signal is sent to the agent.
+
+ Note:
+ The base :class:`ManagerBasedRLEnv` class does not use this flag directly. It is used by the environment
+ wrappers to determine what type of done signal to send to the corresponding learning agent.
+ """
+
+ episode_length_s:float=MISSING
+"""Duration of an episode (in seconds).
+
+ Based on the decimation rate and physics time step, the episode length is calculated as:
+
+ .. code-block:: python
+
+ episode_length_steps = ceil(episode_length_s / (decimation_rate * physics_time_step))
+
+ For example, if the decimation rate is 10, the physics time step is 0.01, and the episode length is 10 seconds,
+ then the episode length in steps is 100.
+ """
+
+ # environment settings
+ scene:InteractiveSceneCfg=MISSING
+"""Scene settings.
+
+ Please refer to the :class:`omni.isaac.lab.scene.InteractiveSceneCfg` class for more details.
+ """
+
+ events:object=None
+"""Event settings. Defaults to None, in which case no events are applied through the event manager.
+
+ Please refer to the :class:`omni.isaac.lab.managers.EventManager` class for more details.
+ """
+
+ num_observations:int=MISSING
+"""The dimension of the observation space from each environment instance."""
+
+ num_states:int=0
+"""The dimension of the state-space from each environment instance. Default is 0, which means no state-space is defined.
+
+ This is useful for asymmetric actor-critic and defines the observation space for the critic.
+ """
+
+ observation_noise_model:NoiseModelCfg|None=None
+"""The noise model to apply to the computed observations from the environment. Default is None, which means no noise is added.
+
+ Please refer to the :class:`omni.isaac.lab.utils.noise.NoiseModel` class for more details.
+ """
+
+ num_actions:int=MISSING
+"""The dimension of the action space for each environment."""
+
+ action_noise_model:NoiseModelCfg|None=None
+"""The noise model applied to the actions provided to the environment. Default is None, which means no noise is added.
+
+ Please refer to the :class:`omni.isaac.lab.utils.noise.NoiseModel` class for more details.
+ """
+
+ rerender_on_reset:bool=False
+"""Whether a render step is performed again after at least one environment has been reset.
+ Defaults to False, which means no render step will be performed after reset.
+
+ * When this is False, data collected from sensors after performing reset will be stale and will not reflect the
+ latest states in simulation caused by the reset.
+ * When this is True, an extra render step will be performed to update the sensor data
+ to reflect the latest states from the reset. This comes at a cost of performance as an additional render
+ step will be performed after each time an environment is reset.
+
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importbuiltins
+importtorch
+fromcollections.abcimportSequence
+fromtypingimportAny
+
+importcarb
+importomni.isaac.core.utils.torchastorch_utils
+
+fromomni.isaac.lab.managersimportActionManager,EventManager,ObservationManager
+fromomni.isaac.lab.sceneimportInteractiveScene
+fromomni.isaac.lab.simimportSimulationContext
+fromomni.isaac.lab.utils.timerimportTimer
+
+from.commonimportVecEnvObs
+from.manager_based_env_cfgimportManagerBasedEnvCfg
+from.uiimportViewportCameraController
+
+
+
[文档]classManagerBasedEnv:
+"""The base environment encapsulates the simulation scene and the environment managers for the manager-based workflow.
+
+ While a simulation scene or world comprises of different components such as the robots, objects,
+ and sensors (cameras, lidars, etc.), the environment is a higher level abstraction
+ that provides an interface for interacting with the simulation. The environment is comprised of
+ the following components:
+
+ * **Scene**: The scene manager that creates and manages the virtual world in which the robot operates.
+ This includes defining the robot, static and dynamic objects, sensors, etc.
+ * **Observation Manager**: The observation manager that generates observations from the current simulation
+ state and the data gathered from the sensors. These observations may include privileged information
+ that is not available to the robot in the real world. Additionally, user-defined terms can be added
+ to process the observations and generate custom observations. For example, using a network to embed
+ high-dimensional observations into a lower-dimensional space.
+ * **Action Manager**: The action manager that processes the raw actions sent to the environment and
+ converts them to low-level commands that are sent to the simulation. It can be configured to accept
+ raw actions at different levels of abstraction. For example, in case of a robotic arm, the raw actions
+ can be joint torques, joint positions, or end-effector poses. Similarly for a mobile base, it can be
+ the joint torques, or the desired velocity of the floating base.
+ * **Event Manager**: The event manager orchestrates operations triggered based on simulation events.
+ This includes resetting the scene to a default state, applying random pushes to the robot at different intervals
+ of time, or randomizing properties such as mass and friction coefficients. This is useful for training
+ and evaluating the robot in a variety of scenarios.
+
+ The environment provides a unified interface for interacting with the simulation. However, it does not
+ include task-specific quantities such as the reward function, or the termination conditions. These
+ quantities are often specific to defining Markov Decision Processes (MDPs) while the base environment
+ is agnostic to the MDP definition.
+
+ The environment steps forward in time at a fixed time-step. The physics simulation is decimated at a
+ lower time-step. This is to ensure that the simulation is stable. These two time-steps can be configured
+ independently using the :attr:`ManagerBasedEnvCfg.decimation` (number of simulation steps per environment step)
+ and the :attr:`ManagerBasedEnvCfg.sim.dt` (physics time-step) parameters. Based on these parameters, the
+ environment time-step is computed as the product of the two. The two time-steps can be obtained by
+ querying the :attr:`physics_dt` and the :attr:`step_dt` properties respectively.
+ """
+
+
[文档]def__init__(self,cfg:ManagerBasedEnvCfg):
+"""Initialize the environment.
+
+ Args:
+ cfg: The configuration object for the environment.
+
+ Raises:
+ RuntimeError: If a simulation context already exists. The environment must always create one
+ since it configures the simulation context and controls the simulation.
+ """
+ # store inputs to class
+ self.cfg=cfg
+ # initialize internal variables
+ self._is_closed=False
+
+ # set the seed for the environment
+ ifself.cfg.seedisnotNone:
+ self.seed(self.cfg.seed)
+ else:
+ carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")
+
+ # create a simulation context to control the simulator
+ ifSimulationContext.instance()isNone:
+ # the type-annotation is required to avoid a type-checking error
+ # since it gets confused with Isaac Sim's SimulationContext class
+ self.sim:SimulationContext=SimulationContext(self.cfg.sim)
+ else:
+ # simulation context should only be created before the environment
+ # when in extension mode
+ ifnotbuiltins.ISAAC_LAUNCHED_FROM_TERMINAL:
+ raiseRuntimeError("Simulation context already exists. Cannot create a new one.")
+ self.sim:SimulationContext=SimulationContext.instance()
+
+ # print useful information
+ print("[INFO]: Base environment:")
+ print(f"\tEnvironment device : {self.device}")
+ print(f"\tEnvironment seed : {self.cfg.seed}")
+ print(f"\tPhysics step-size : {self.physics_dt}")
+ print(f"\tRendering step-size : {self.physics_dt*self.cfg.sim.render_interval}")
+ print(f"\tEnvironment step-size : {self.step_dt}")
+
+ ifself.cfg.sim.render_interval<self.cfg.decimation:
+ msg=(
+ f"The render interval ({self.cfg.sim.render_interval}) is smaller than the decimation "
+ f"({self.cfg.decimation}). Multiple multiple render calls will happen for each environment step. "
+ "If this is not intended, set the render interval to be equal to the decimation."
+ )
+ carb.log_warn(msg)
+
+ # counter for simulation steps
+ self._sim_step_counter=0
+
+ # generate scene
+ withTimer("[INFO]: Time taken for scene creation","scene_creation"):
+ self.scene=InteractiveScene(self.cfg.scene)
+ print("[INFO]: Scene manager: ",self.scene)
+
+ # set up camera viewport controller
+ # viewport is not available in other rendering modes so the function will throw a warning
+ # FIXME: This needs to be fixed in the future when we unify the UI functionalities even for
+ # non-rendering modes.
+ ifself.sim.render_mode>=self.sim.RenderMode.PARTIAL_RENDERING:
+ self.viewport_camera_controller=ViewportCameraController(self,self.cfg.viewer)
+ else:
+ self.viewport_camera_controller=None
+
+ # play the simulator to activate physics handles
+ # note: this activates the physics simulation view that exposes TensorAPIs
+ # note: when started in extension mode, first call sim.reset_async() and then initialize the managers
+ ifbuiltins.ISAAC_LAUNCHED_FROM_TERMINALisFalse:
+ print("[INFO]: Starting the simulation. This may take a few seconds. Please wait...")
+ withTimer("[INFO]: Time taken for simulation start","simulation_start"):
+ self.sim.reset()
+ # add timeline event to load managers
+ self.load_managers()
+
+ # make sure torch is running on the correct device
+ if"cuda"inself.device:
+ torch.cuda.set_device(self.device)
+
+ # extend UI elements
+ # we need to do this here after all the managers are initialized
+ # this is because they dictate the sensors and commands right now
+ ifself.sim.has_gui()andself.cfg.ui_window_class_typeisnotNone:
+ self._window=self.cfg.ui_window_class_type(self,window_name="IsaacLab")
+ else:
+ # if no window, then we don't need to store the window
+ self._window=None
+
+ # allocate dictionary to store metrics
+ self.extras={}
+
+ def__del__(self):
+"""Cleanup for the environment."""
+ self.close()
+
+"""
+ Properties.
+ """
+
+ @property
+ defnum_envs(self)->int:
+"""The number of instances of the environment that are running."""
+ returnself.scene.num_envs
+
+ @property
+ defphysics_dt(self)->float:
+"""The physics time-step (in s).
+
+ This is the lowest time-decimation at which the simulation is happening.
+ """
+ returnself.cfg.sim.dt
+
+ @property
+ defstep_dt(self)->float:
+"""The environment stepping time-step (in s).
+
+ This is the time-step at which the environment steps forward.
+ """
+ returnself.cfg.sim.dt*self.cfg.decimation
+
+ @property
+ defdevice(self):
+"""The device on which the environment is running."""
+ returnself.sim.device
+
+"""
+ Operations - Setup.
+ """
+
+
[文档]defload_managers(self):
+"""Load the managers for the environment.
+
+ This function is responsible for creating the various managers (action, observation,
+ events, etc.) for the environment. Since the managers require access to physics handles,
+ they can only be created after the simulator is reset (i.e. played for the first time).
+
+ .. note::
+ In case of standalone application (when running simulator from Python), the function is called
+ automatically when the class is initialized.
+
+ However, in case of extension mode, the user must call this function manually after the simulator
+ is reset. This is because the simulator is only reset when the user calls
+ :meth:`SimulationContext.reset_async` and it isn't possible to call async functions in the constructor.
+
+ """
+ # prepare the managers
+ # -- action manager
+ self.action_manager=ActionManager(self.cfg.actions,self)
+ print("[INFO] Action Manager: ",self.action_manager)
+ # -- observation manager
+ self.observation_manager=ObservationManager(self.cfg.observations,self)
+ print("[INFO] Observation Manager:",self.observation_manager)
+ # -- event manager
+ self.event_manager=EventManager(self.cfg.events,self)
+ print("[INFO] Event Manager: ",self.event_manager)
+
+ # perform events at the start of the simulation
+ # in-case a child implementation creates other managers, the randomization should happen
+ # when all the other managers are created
+ ifself.__class__==ManagerBasedEnvand"startup"inself.event_manager.available_modes:
+ self.event_manager.apply(mode="startup")
+
+"""
+ Operations - MDP.
+ """
+
+
[文档]defreset(self,seed:int|None=None,options:dict[str,Any]|None=None)->tuple[VecEnvObs,dict]:
+"""Resets all the environments and returns observations.
+
+ This function calls the :meth:`_reset_idx` function to reset all the environments.
+ However, certain operations, such as procedural terrain generation, that happened during initialization
+ are not repeated.
+
+ Args:
+ seed: The seed to use for randomization. Defaults to None, in which case the seed is not set.
+ options: Additional information to specify how the environment is reset. Defaults to None.
+
+ Note:
+ This argument is used for compatibility with Gymnasium environment definition.
+
+ Returns:
+ A tuple containing the observations and extras.
+ """
+ # set the seed
+ ifseedisnotNone:
+ self.seed(seed)
+
+ # reset state of scene
+ indices=torch.arange(self.num_envs,dtype=torch.int64,device=self.device)
+ self._reset_idx(indices)
+
+ # if sensors are added to the scene, make sure we render to reflect changes in reset
+ ifself.sim.has_rtx_sensors()andself.cfg.rerender_on_reset:
+ self.sim.render()
+
+ # return observations
+ returnself.observation_manager.compute(),self.extras
+
+
[文档]defstep(self,action:torch.Tensor)->tuple[VecEnvObs,dict]:
+"""Execute one time-step of the environment's dynamics.
+
+ The environment steps forward at a fixed time-step, while the physics simulation is
+ decimated at a lower time-step. This is to ensure that the simulation is stable. These two
+ time-steps can be configured independently using the :attr:`ManagerBasedEnvCfg.decimation` (number of
+ simulation steps per environment step) and the :attr:`ManagerBasedEnvCfg.sim.dt` (physics time-step).
+ Based on these parameters, the environment time-step is computed as the product of the two.
+
+ Args:
+ action: The actions to apply on the environment. Shape is (num_envs, action_dim).
+
+ Returns:
+ A tuple containing the observations and extras.
+ """
+ # process actions
+ self.action_manager.process_action(action.to(self.device))
+
+ # check if we need to do rendering within the physics loop
+ # note: checked here once to avoid multiple checks within the loop
+ is_rendering=self.sim.has_gui()orself.sim.has_rtx_sensors()
+
+ # perform physics stepping
+ for_inrange(self.cfg.decimation):
+ self._sim_step_counter+=1
+ # set actions into buffers
+ self.action_manager.apply_action()
+ # set actions into simulator
+ self.scene.write_data_to_sim()
+ # simulate
+ self.sim.step(render=False)
+ # render between steps only if the GUI or an RTX sensor needs it
+ # note: we assume the render interval to be the shortest accepted rendering interval.
+ # If a camera needs rendering at a faster frequency, this will lead to unexpected behavior.
+ ifself._sim_step_counter%self.cfg.sim.render_interval==0andis_rendering:
+ self.sim.render()
+ # update buffers at sim dt
+ self.scene.update(dt=self.physics_dt)
+
+ # post-step: step interval event
+ if"interval"inself.event_manager.available_modes:
+ self.event_manager.apply(mode="interval",dt=self.step_dt)
+
+ # return observations and extras
+ returnself.observation_manager.compute(),self.extras
+
+
[文档]@staticmethod
+ defseed(seed:int=-1)->int:
+"""Set the seed for the environment.
+
+ Args:
+ seed: The seed for random generator. Defaults to -1.
+
+ Returns:
+ The seed used for random generator.
+ """
+ # set seed for replicator
+ try:
+ importomni.replicator.coreasrep
+
+ rep.set_global_seed(seed)
+ exceptModuleNotFoundError:
+ pass
+ # set seed for torch and other libraries
+ returntorch_utils.set_seed(seed)
+
+
[文档]defclose(self):
+"""Cleanup for the environment."""
+ ifnotself._is_closed:
+ # destructor is order-sensitive
+ delself.viewport_camera_controller
+ delself.action_manager
+ delself.observation_manager
+ delself.event_manager
+ delself.scene
+ # clear callbacks and instance
+ self.sim.clear_all_callbacks()
+ self.sim.clear_instance()
+ # destroy the window
+ ifself._windowisnotNone:
+ self._window=None
+ # update closing status
+ self._is_closed=True
+
+"""
+ Helper functions.
+ """
+
+ def_reset_idx(self,env_ids:Sequence[int]):
+"""Reset environments based on specified indices.
+
+ Args:
+ env_ids: List of environment ids which must be reset
+ """
+ # reset the internal buffers of the scene elements
+ self.scene.reset(env_ids)
+
+ # apply events such as randomization for environments that need a reset
+ if"reset"inself.event_manager.available_modes:
+ env_step_count=self._sim_step_counter//self.cfg.decimation
+ self.event_manager.apply(mode="reset",env_ids=env_ids,global_env_step_count=env_step_count)
+
+ # iterate over all managers and reset them
+ # this returns a dictionary of information which is stored in the extras
+ # note: This is order-sensitive! Certain things need be reset before others.
+ self.extras["log"]=dict()
+ # -- observation manager
+ info=self.observation_manager.reset(env_ids)
+ self.extras["log"].update(info)
+ # -- action manager
+ info=self.action_manager.reset(env_ids)
+ self.extras["log"].update(info)
+ # -- event manager
+ info=self.event_manager.reset(env_ids)
+ self.extras["log"].update(info)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Base configuration of the environment.
+
+This module defines the general configuration of the environment. It includes parameters for
+configuring the environment instances, viewer settings, and simulation parameters.
+"""
+
+fromdataclassesimportMISSING
+
+importomni.isaac.lab.envs.mdpasmdp
+fromomni.isaac.lab.managersimportEventTermCfgasEventTerm
+fromomni.isaac.lab.sceneimportInteractiveSceneCfg
+fromomni.isaac.lab.simimportSimulationCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.commonimportViewerCfg
+from.uiimportBaseEnvWindow
+
+
+@configclass
+classDefaultEventManagerCfg:
+"""Configuration of the default event manager.
+
+ This manager is used to reset the scene to a default state. The default state is specified
+ by the scene configuration.
+ """
+
+ reset_scene_to_default=EventTerm(func=mdp.reset_scene_to_default,mode="reset")
+
+
+
[文档]@configclass
+classManagerBasedEnvCfg:
+"""Base configuration of the environment."""
+
+ # simulation settings
+ viewer:ViewerCfg=ViewerCfg()
+"""Viewer configuration. Default is ViewerCfg()."""
+
+ sim:SimulationCfg=SimulationCfg()
+"""Physics simulation configuration. Default is SimulationCfg()."""
+
+ # ui settings
+ ui_window_class_type:type|None=BaseEnvWindow
+"""The class type of the UI window. Default is None.
+
+ If None, then no UI window is created.
+
+ Note:
+ If you want to make your own UI window, you can create a class that inherits from
+ from :class:`omni.isaac.lab.envs.ui.base_env_window.BaseEnvWindow`. Then, you can set
+ this attribute to your class type.
+ """
+
+ # general settings
+ seed:int|None=None
+"""The seed for the random number generator. Defaults to None, in which case the seed is not set.
+
+ Note:
+ The seed is set at the beginning of the environment initialization. This ensures that the environment
+ creation is deterministic and behaves similarly across different runs.
+ """
+
+ decimation:int=MISSING
+"""Number of control action updates @ sim dt per policy dt.
+
+ For instance, if the simulation dt is 0.01s and the policy dt is 0.1s, then the decimation is 10.
+ This means that the control action is updated every 10 simulation steps.
+ """
+
+ # environment settings
+ scene:InteractiveSceneCfg=MISSING
+"""Scene settings.
+
+ Please refer to the :class:`omni.isaac.lab.scene.InteractiveSceneCfg` class for more details.
+ """
+
+ observations:object=MISSING
+"""Observation space settings.
+
+ Please refer to the :class:`omni.isaac.lab.managers.ObservationManager` class for more details.
+ """
+
+ actions:object=MISSING
+"""Action space settings.
+
+ Please refer to the :class:`omni.isaac.lab.managers.ActionManager` class for more details.
+ """
+
+ events:object=DefaultEventManagerCfg()
+"""Event settings. Defaults to the basic configuration that resets the scene to its default state.
+
+ Please refer to the :class:`omni.isaac.lab.managers.EventManager` class for more details.
+ """
+
+ rerender_on_reset:bool=False
+"""Whether a render step is performed again after at least one environment has been reset.
+ Defaults to False, which means no render step will be performed after reset.
+
+ * When this is False, data collected from sensors after performing reset will be stale and will not reflect the
+ latest states in simulation caused by the reset.
+ * When this is True, an extra render step will be performed to update the sensor data
+ to reflect the latest states from the reset. This comes at a cost of performance as an additional render
+ step will be performed after each time an environment is reset.
+
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+# needed to import for allowing type-hinting: np.ndarray | None
+from__future__importannotations
+
+importgymnasiumasgym
+importmath
+importnumpyasnp
+importtorch
+fromcollections.abcimportSequence
+fromtypingimportAny,ClassVar
+
+fromomni.isaac.versionimportget_version
+
+fromomni.isaac.lab.managersimportCommandManager,CurriculumManager,RewardManager,TerminationManager
+
+from.commonimportVecEnvStepReturn
+from.manager_based_envimportManagerBasedEnv
+from.manager_based_rl_env_cfgimportManagerBasedRLEnvCfg
+
+
+
[文档]classManagerBasedRLEnv(ManagerBasedEnv,gym.Env):
+"""The superclass for the manager-based workflow reinforcement learning-based environments.
+
+ This class inherits from :class:`ManagerBasedEnv` and implements the core functionality for
+ reinforcement learning-based environments. It is designed to be used with any RL
+ library. The class is designed to be used with vectorized environments, i.e., the
+ environment is expected to be run in parallel with multiple sub-environments. The
+ number of sub-environments is specified using the ``num_envs``.
+
+ Each observation from the environment is a batch of observations for each sub-
+ environments. The method :meth:`step` is also expected to receive a batch of actions
+ for each sub-environment.
+
+ While the environment itself is implemented as a vectorized environment, we do not
+ inherit from :class:`gym.vector.VectorEnv`. This is mainly because the class adds
+ various methods (for wait and asynchronous updates) which are not required.
+ Additionally, each RL library typically has its own definition for a vectorized
+ environment. Thus, to reduce complexity, we directly use the :class:`gym.Env` over
+ here and leave it up to library-defined wrappers to take care of wrapping this
+ environment for their agents.
+
+ Note:
+ For vectorized environments, it is recommended to **only** call the :meth:`reset`
+ method once before the first call to :meth:`step`, i.e. after the environment is created.
+ After that, the :meth:`step` function handles the reset of terminated sub-environments.
+ This is because the simulator does not support resetting individual sub-environments
+ in a vectorized environment.
+
+ """
+
+ is_vector_env:ClassVar[bool]=True
+"""Whether the environment is a vectorized environment."""
+ metadata:ClassVar[dict[str,Any]]={
+ "render_modes":[None,"human","rgb_array"],
+ "isaac_sim_version":get_version(),
+ }
+"""Metadata for the environment."""
+
+ cfg:ManagerBasedRLEnvCfg
+"""Configuration for the environment."""
+
+
[文档]def__init__(self,cfg:ManagerBasedRLEnvCfg,render_mode:str|None=None,**kwargs):
+"""Initialize the environment.
+
+ Args:
+ cfg: The configuration for the environment.
+ render_mode: The render mode for the environment. Defaults to None, which
+ is similar to ``"human"``.
+ """
+ # initialize the base class to setup the scene.
+ super().__init__(cfg=cfg)
+ # store the render mode
+ self.render_mode=render_mode
+
+ # initialize data and constants
+ # -- counter for curriculum
+ self.common_step_counter=0
+ # -- init buffers
+ self.episode_length_buf=torch.zeros(self.num_envs,device=self.device,dtype=torch.long)
+ # -- set the framerate of the gym video recorder wrapper so that the playback speed of the produced video matches the simulation
+ self.metadata["render_fps"]=1/self.step_dt
+
+ print("[INFO]: Completed setting up the environment...")
[文档]defload_managers(self):
+ # note: this order is important since observation manager needs to know the command and action managers
+ # and the reward manager needs to know the termination manager
+ # -- command manager
+ self.command_manager:CommandManager=CommandManager(self.cfg.commands,self)
+ print("[INFO] Command Manager: ",self.command_manager)
+
+ # call the parent class to load the managers for observations and actions.
+ super().load_managers()
+
+ # prepare the managers
+ # -- termination manager
+ self.termination_manager=TerminationManager(self.cfg.terminations,self)
+ print("[INFO] Termination Manager: ",self.termination_manager)
+ # -- reward manager
+ self.reward_manager=RewardManager(self.cfg.rewards,self)
+ print("[INFO] Reward Manager: ",self.reward_manager)
+ # -- curriculum manager
+ self.curriculum_manager=CurriculumManager(self.cfg.curriculum,self)
+ print("[INFO] Curriculum Manager: ",self.curriculum_manager)
+
+ # setup the action and observation spaces for Gym
+ self._configure_gym_env_spaces()
+
+ # perform events at the start of the simulation
+ if"startup"inself.event_manager.available_modes:
+ self.event_manager.apply(mode="startup")
+
+"""
+ Operations - MDP
+ """
+
+
[文档]defstep(self,action:torch.Tensor)->VecEnvStepReturn:
+"""Execute one time-step of the environment's dynamics and reset terminated environments.
+
+ Unlike the :class:`ManagerBasedEnv.step` class, the function performs the following operations:
+
+ 1. Process the actions.
+ 2. Perform physics stepping.
+ 3. Perform rendering if gui is enabled.
+ 4. Update the environment counters and compute the rewards and terminations.
+ 5. Reset the environments that terminated.
+ 6. Compute the observations.
+ 7. Return the observations, rewards, resets and extras.
+
+ Args:
+ action: The actions to apply on the environment. Shape is (num_envs, action_dim).
+
+ Returns:
+ A tuple containing the observations, rewards, resets (terminated and truncated) and extras.
+ """
+ # process actions
+ self.action_manager.process_action(action.to(self.device))
+
+ # check if we need to do rendering within the physics loop
+ # note: checked here once to avoid multiple checks within the loop
+ is_rendering=self.sim.has_gui()orself.sim.has_rtx_sensors()
+
+ # perform physics stepping
+ for_inrange(self.cfg.decimation):
+ self._sim_step_counter+=1
+ # set actions into buffers
+ self.action_manager.apply_action()
+ # set actions into simulator
+ self.scene.write_data_to_sim()
+ # simulate
+ self.sim.step(render=False)
+ # render between steps only if the GUI or an RTX sensor needs it
+ # note: we assume the render interval to be the shortest accepted rendering interval.
+ # If a camera needs rendering at a faster frequency, this will lead to unexpected behavior.
+ ifself._sim_step_counter%self.cfg.sim.render_interval==0andis_rendering:
+ self.sim.render()
+ # update buffers at sim dt
+ self.scene.update(dt=self.physics_dt)
+
+ # post-step:
+ # -- update env counters (used for curriculum generation)
+ self.episode_length_buf+=1# step in current episode (per env)
+ self.common_step_counter+=1# total step (common for all envs)
+ # -- check terminations
+ self.reset_buf=self.termination_manager.compute()
+ self.reset_terminated=self.termination_manager.terminated
+ self.reset_time_outs=self.termination_manager.time_outs
+ # -- reward computation
+ self.reward_buf=self.reward_manager.compute(dt=self.step_dt)
+
+ # -- reset envs that terminated/timed-out and log the episode information
+ reset_env_ids=self.reset_buf.nonzero(as_tuple=False).squeeze(-1)
+ iflen(reset_env_ids)>0:
+ self._reset_idx(reset_env_ids)
+ # if sensors are added to the scene, make sure we render to reflect changes in reset
+ ifself.sim.has_rtx_sensors()andself.cfg.rerender_on_reset:
+ self.sim.render()
+
+ # -- update command
+ self.command_manager.compute(dt=self.step_dt)
+ # -- step interval events
+ if"interval"inself.event_manager.available_modes:
+ self.event_manager.apply(mode="interval",dt=self.step_dt)
+ # -- compute observations
+ # note: done after reset to get the correct observations for reset envs
+ self.obs_buf=self.observation_manager.compute()
+
+ # return observations, rewards, resets and extras
+ returnself.obs_buf,self.reward_buf,self.reset_terminated,self.reset_time_outs,self.extras
+
+
[文档]defrender(self,recompute:bool=False)->np.ndarray|None:
+"""Run rendering without stepping through the physics.
+
+ By convention, if mode is:
+
+ - **human**: Render to the current display and return nothing. Usually for human consumption.
+ - **rgb_array**: Return an numpy.ndarray with shape (x, y, 3), representing RGB values for an
+ x-by-y pixel image, suitable for turning into a video.
+
+ Args:
+ recompute: Whether to force a render even if the simulator has already rendered the scene.
+ Defaults to False.
+
+ Returns:
+ The rendered image as a numpy array if mode is "rgb_array". Otherwise, returns None.
+
+ Raises:
+ RuntimeError: If mode is set to "rgb_data" and simulation render mode does not support it.
+ In this case, the simulation render mode must be set to ``RenderMode.PARTIAL_RENDERING``
+ or ``RenderMode.FULL_RENDERING``.
+ NotImplementedError: If an unsupported rendering mode is specified.
+ """
+ # run a rendering step of the simulator
+ # if we have rtx sensors, we do not need to render again sin
+ ifnotself.sim.has_rtx_sensors()andnotrecompute:
+ self.sim.render()
+ # decide the rendering mode
+ ifself.render_mode=="human"orself.render_modeisNone:
+ returnNone
+ elifself.render_mode=="rgb_array":
+ # check that if any render could have happened
+ ifself.sim.render_mode.value<self.sim.RenderMode.PARTIAL_RENDERING.value:
+ raiseRuntimeError(
+ f"Cannot render '{self.render_mode}' when the simulation render mode is"
+ f" '{self.sim.render_mode.name}'. Please set the simulation render mode to:"
+ f"'{self.sim.RenderMode.PARTIAL_RENDERING.name}' or '{self.sim.RenderMode.FULL_RENDERING.name}'."
+ " If running headless, make sure --enable_cameras is set."
+ )
+ # create the annotator if it does not exist
+ ifnothasattr(self,"_rgb_annotator"):
+ importomni.replicator.coreasrep
+
+ # create render product
+ self._render_product=rep.create.render_product(
+ self.cfg.viewer.cam_prim_path,self.cfg.viewer.resolution
+ )
+ # create rgb annotator -- used to read data from the render product
+ self._rgb_annotator=rep.AnnotatorRegistry.get_annotator("rgb",device="cpu")
+ self._rgb_annotator.attach([self._render_product])
+ # obtain the rgb data
+ rgb_data=self._rgb_annotator.get_data()
+ # convert to numpy array
+ rgb_data=np.frombuffer(rgb_data,dtype=np.uint8).reshape(*rgb_data.shape)
+ # return the rgb data
+ # note: initially the renerer is warming up and returns empty data
+ ifrgb_data.size==0:
+ returnnp.zeros((self.cfg.viewer.resolution[1],self.cfg.viewer.resolution[0],3),dtype=np.uint8)
+ else:
+ returnrgb_data[:,:,:3]
+ else:
+ raiseNotImplementedError(
+ f"Render mode '{self.render_mode}' is not supported. Please use: {self.metadata['render_modes']}."
+ )
+
+
[文档]defclose(self):
+ ifnotself._is_closed:
+ # destructor is order-sensitive
+ delself.command_manager
+ delself.reward_manager
+ delself.termination_manager
+ delself.curriculum_manager
+ # call the parent class to close the environment
+ super().close()
+
+"""
+ Helper functions.
+ """
+
+ def_configure_gym_env_spaces(self):
+"""Configure the action and observation spaces for the Gym environment."""
+ # observation space (unbounded since we don't impose any limits)
+ self.single_observation_space=gym.spaces.Dict()
+ forgroup_name,group_term_namesinself.observation_manager.active_terms.items():
+ # extract quantities about the group
+ has_concatenated_obs=self.observation_manager.group_obs_concatenate[group_name]
+ group_dim=self.observation_manager.group_obs_dim[group_name]
+ # check if group is concatenated or not
+ # if not concatenated, then we need to add each term separately as a dictionary
+ ifhas_concatenated_obs:
+ self.single_observation_space[group_name]=gym.spaces.Box(low=-np.inf,high=np.inf,shape=group_dim)
+ else:
+ self.single_observation_space[group_name]=gym.spaces.Dict({
+ term_name:gym.spaces.Box(low=-np.inf,high=np.inf,shape=term_dim)
+ forterm_name,term_diminzip(group_term_names,group_dim)
+ })
+ # action space (unbounded since we don't impose any limits)
+ action_dim=sum(self.action_manager.action_term_dim)
+ self.single_action_space=gym.spaces.Box(low=-np.inf,high=np.inf,shape=(action_dim,))
+
+ # batch the spaces for vectorized environments
+ self.observation_space=gym.vector.utils.batch_space(self.single_observation_space,self.num_envs)
+ self.action_space=gym.vector.utils.batch_space(self.single_action_space,self.num_envs)
+
+ def_reset_idx(self,env_ids:Sequence[int]):
+"""Reset environments based on specified indices.
+
+ Args:
+ env_ids: List of environment ids which must be reset
+ """
+ # update the curriculum for environments that need a reset
+ self.curriculum_manager.compute(env_ids=env_ids)
+ # reset the internal buffers of the scene elements
+ self.scene.reset(env_ids)
+ # apply events such as randomizations for environments that need a reset
+ if"reset"inself.event_manager.available_modes:
+ env_step_count=self._sim_step_counter//self.cfg.decimation
+ self.event_manager.apply(mode="reset",env_ids=env_ids,global_env_step_count=env_step_count)
+
+ # iterate over all managers and reset them
+ # this returns a dictionary of information which is stored in the extras
+ # note: This is order-sensitive! Certain things need be reset before others.
+ self.extras["log"]=dict()
+ # -- observation manager
+ info=self.observation_manager.reset(env_ids)
+ self.extras["log"].update(info)
+ # -- action manager
+ info=self.action_manager.reset(env_ids)
+ self.extras["log"].update(info)
+ # -- rewards manager
+ info=self.reward_manager.reset(env_ids)
+ self.extras["log"].update(info)
+ # -- curriculum manager
+ info=self.curriculum_manager.reset(env_ids)
+ self.extras["log"].update(info)
+ # -- command manager
+ info=self.command_manager.reset(env_ids)
+ self.extras["log"].update(info)
+ # -- event manager
+ info=self.event_manager.reset(env_ids)
+ self.extras["log"].update(info)
+ # -- termination manager
+ info=self.termination_manager.reset(env_ids)
+ self.extras["log"].update(info)
+
+ # reset the episode length buffer
+ self.episode_length_buf[env_ids]=0
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.manager_based_env_cfgimportManagerBasedEnvCfg
+from.uiimportManagerBasedRLEnvWindow
+
+
+
[文档]@configclass
+classManagerBasedRLEnvCfg(ManagerBasedEnvCfg):
+"""Configuration for a reinforcement learning environment with the manager-based workflow."""
+
+ # ui settings
+ ui_window_class_type:type|None=ManagerBasedRLEnvWindow
+
+ # general settings
+ is_finite_horizon:bool=False
+"""Whether the learning task is treated as a finite or infinite horizon problem for the agent.
+ Defaults to False, which means the task is treated as an infinite horizon problem.
+
+ This flag handles the subtleties of finite and infinite horizon tasks:
+
+ * **Finite horizon**: no penalty or bootstrapping value is required by the the agent for
+ running out of time. However, the environment still needs to terminate the episode after the
+ time limit is reached.
+ * **Infinite horizon**: the agent needs to bootstrap the value of the state at the end of the episode.
+ This is done by sending a time-limit (or truncated) done signal to the agent, which triggers this
+ bootstrapping calculation.
+
+ If True, then the environment is treated as a finite horizon problem and no time-out (or truncated) done signal
+ is sent to the agent. If False, then the environment is treated as an infinite horizon problem and a time-out
+ (or truncated) done signal is sent to the agent.
+
+ Note:
+ The base :class:`ManagerBasedRLEnv` class does not use this flag directly. It is used by the environment
+ wrappers to determine what type of done signal to send to the corresponding learning agent.
+ """
+
+ episode_length_s:float=MISSING
+"""Duration of an episode (in seconds).
+
+ Based on the decimation rate and physics time step, the episode length is calculated as:
+
+ .. code-block:: python
+
+ episode_length_steps = ceil(episode_length_s / (decimation_rate * physics_time_step))
+
+ For example, if the decimation rate is 10, the physics time step is 0.01, and the episode length is 10 seconds,
+ then the episode length in steps is 100.
+ """
+
+ # environment settings
+ rewards:object=MISSING
+"""Reward settings.
+
+ Please refer to the :class:`omni.isaac.lab.managers.RewardManager` class for more details.
+ """
+
+ terminations:object=MISSING
+"""Termination settings.
+
+ Please refer to the :class:`omni.isaac.lab.managers.TerminationManager` class for more details.
+ """
+
+ curriculum:object=MISSING
+"""Curriculum settings.
+
+ Please refer to the :class:`omni.isaac.lab.managers.CurriculumManager` class for more details.
+ """
+
+ commands:object=MISSING
+"""Command settings.
+
+ Please refer to the :class:`omni.isaac.lab.managers.CommandManager` class for more details.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.controllersimportDifferentialIKControllerCfg
+fromomni.isaac.lab.managers.action_managerimportActionTerm,ActionTermCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.importbinary_joint_actions,joint_actions,joint_actions_to_limits,non_holonomic_actions,task_space_actions
+
+##
+# Joint actions.
+##
+
+
+
[文档]@configclass
+classJointActionCfg(ActionTermCfg):
+"""Configuration for the base joint action term.
+
+ See :class:`JointAction` for more details.
+ """
+
+ joint_names:list[str]=MISSING
+"""List of joint names or regex expressions that the action will be mapped to."""
+ scale:float|dict[str,float]=1.0
+"""Scale factor for the action (float or dict of regex expressions). Defaults to 1.0."""
+ offset:float|dict[str,float]=0.0
+"""Offset factor for the action (float or dict of regex expressions). Defaults to 0.0."""
+ preserve_order:bool=False
+"""Whether to preserve the order of the joint names in the action output. Defaults to False."""
+
+
+
[文档]@configclass
+classJointPositionActionCfg(JointActionCfg):
+"""Configuration for the joint position action term.
+
+ See :class:`JointPositionAction` for more details.
+ """
+
+ class_type:type[ActionTerm]=joint_actions.JointPositionAction
+
+ use_default_offset:bool=True
+"""Whether to use default joint positions configured in the articulation asset as offset.
+ Defaults to True.
+
+ If True, this flag results in overwriting the values of :attr:`offset` to the default joint positions
+ from the articulation asset.
+ """
+
+
+
[文档]@configclass
+classRelativeJointPositionActionCfg(JointActionCfg):
+"""Configuration for the relative joint position action term.
+
+ See :class:`RelativeJointPositionAction` for more details.
+ """
+
+ class_type:type[ActionTerm]=joint_actions.RelativeJointPositionAction
+
+ use_zero_offset:bool=True
+"""Whether to ignore the offset defined in articulation asset. Defaults to True.
+
+ If True, this flag results in overwriting the values of :attr:`offset` to zero.
+ """
+
+
+
[文档]@configclass
+classJointVelocityActionCfg(JointActionCfg):
+"""Configuration for the joint velocity action term.
+
+ See :class:`JointVelocityAction` for more details.
+ """
+
+ class_type:type[ActionTerm]=joint_actions.JointVelocityAction
+
+ use_default_offset:bool=True
+"""Whether to use default joint velocities configured in the articulation asset as offset.
+ Defaults to True.
+
+ This overrides the settings from :attr:`offset` if set to True.
+ """
+
+
+
[文档]@configclass
+classJointEffortActionCfg(JointActionCfg):
+"""Configuration for the joint effort action term.
+
+ See :class:`JointEffortAction` for more details.
+ """
+
+ class_type:type[ActionTerm]=joint_actions.JointEffortAction
[文档]@configclass
+classJointPositionToLimitsActionCfg(ActionTermCfg):
+"""Configuration for the bounded joint position action term.
+
+ See :class:`JointPositionWithinLimitsAction` for more details.
+ """
+
+ class_type:type[ActionTerm]=joint_actions_to_limits.JointPositionToLimitsAction
+
+ joint_names:list[str]=MISSING
+"""List of joint names or regex expressions that the action will be mapped to."""
+
+ scale:float|dict[str,float]=1.0
+"""Scale factor for the action (float or dict of regex expressions). Defaults to 1.0."""
+
+ rescale_to_limits:bool=True
+"""Whether to rescale the action to the joint limits. Defaults to True.
+
+ If True, the input actions are rescaled to the joint limits, i.e., the action value in
+ the range [-1, 1] corresponds to the joint lower and upper limits respectively.
+
+ Note:
+ This operation is performed after applying the scale factor.
+ """
+
+
+
[文档]@configclass
+classEMAJointPositionToLimitsActionCfg(JointPositionToLimitsActionCfg):
+"""Configuration for the exponential moving average (EMA) joint position action term.
+
+ See :class:`EMAJointPositionToLimitsAction` for more details.
+ """
+
+ class_type:type[ActionTerm]=joint_actions_to_limits.EMAJointPositionToLimitsAction
+
+ alpha:float|dict[str,float]=1.0
+"""The weight for the moving average (float or dict of regex expressions). Defaults to 1.0.
+
+ If set to 1.0, the processed action is applied directly without any moving average window.
+ """
+
+
+##
+# Gripper actions.
+##
+
+
+
[文档]@configclass
+classBinaryJointActionCfg(ActionTermCfg):
+"""Configuration for the base binary joint action term.
+
+ See :class:`BinaryJointAction` for more details.
+ """
+
+ joint_names:list[str]=MISSING
+"""List of joint names or regex expressions that the action will be mapped to."""
+ open_command_expr:dict[str,float]=MISSING
+"""The joint command to move to *open* configuration."""
+ close_command_expr:dict[str,float]=MISSING
+"""The joint command to move to *close* configuration."""
+
+
+
[文档]@configclass
+classBinaryJointPositionActionCfg(BinaryJointActionCfg):
+"""Configuration for the binary joint position action term.
+
+ See :class:`BinaryJointPositionAction` for more details.
+ """
+
+ class_type:type[ActionTerm]=binary_joint_actions.BinaryJointPositionAction
+
+
+
[文档]@configclass
+classBinaryJointVelocityActionCfg(BinaryJointActionCfg):
+"""Configuration for the binary joint velocity action term.
+
+ See :class:`BinaryJointVelocityAction` for more details.
+ """
+
+ class_type:type[ActionTerm]=binary_joint_actions.BinaryJointVelocityAction
+
+
+##
+# Non-holonomic actions.
+##
+
+
+
[文档]@configclass
+classNonHolonomicActionCfg(ActionTermCfg):
+"""Configuration for the non-holonomic action term with dummy joints at the base.
+
+ See :class:`NonHolonomicAction` for more details.
+ """
+
+ class_type:type[ActionTerm]=non_holonomic_actions.NonHolonomicAction
+
+ body_name:str=MISSING
+"""Name of the body which has the dummy mechanism connected to."""
+ x_joint_name:str=MISSING
+"""The dummy joint name in the x direction."""
+ y_joint_name:str=MISSING
+"""The dummy joint name in the y direction."""
+ yaw_joint_name:str=MISSING
+"""The dummy joint name in the yaw direction."""
+ scale:tuple[float,float]=(1.0,1.0)
+"""Scale factor for the action. Defaults to (1.0, 1.0)."""
+ offset:tuple[float,float]=(0.0,0.0)
+"""Offset factor for the action. Defaults to (0.0, 0.0)."""
+
+
+##
+# Task-space Actions.
+##
+
+
+
[文档]@configclass
+classDifferentialInverseKinematicsActionCfg(ActionTermCfg):
+"""Configuration for inverse differential kinematics action term.
+
+ See :class:`DifferentialInverseKinematicsAction` for more details.
+ """
+
+
[文档]@configclass
+ classOffsetCfg:
+"""The offset pose from parent frame to child frame.
+
+ On many robots, end-effector frames are fictitious frames that do not have a corresponding
+ rigid body. In such cases, it is easier to define this transform w.r.t. their parent rigid body.
+ For instance, for the Franka Emika arm, the end-effector is defined at an offset to the the
+ "panda_hand" frame.
+ """
+
+ pos:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Translation w.r.t. the parent frame. Defaults to (0.0, 0.0, 0.0)."""
+ rot:tuple[float,float,float,float]=(1.0,0.0,0.0,0.0)
+"""Quaternion rotation ``(w, x, y, z)`` w.r.t. the parent frame. Defaults to (1.0, 0.0, 0.0, 0.0)."""
+
+ class_type:type[ActionTerm]=task_space_actions.DifferentialInverseKinematicsAction
+
+ joint_names:list[str]=MISSING
+"""List of joint names or regex expressions that the action will be mapped to."""
+ body_name:str=MISSING
+"""Name of the body or frame for which IK is performed."""
+ body_offset:OffsetCfg|None=None
+"""Offset of target frame w.r.t. to the body frame. Defaults to None, in which case no offset is applied."""
+ scale:float|tuple[float,...]=1.0
+"""Scale factor for the action. Defaults to 1.0."""
+ controller:DifferentialIKControllerCfg=MISSING
+"""The configuration for the differential IK controller."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importmath
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.managersimportCommandTermCfg
+fromomni.isaac.lab.markersimportVisualizationMarkersCfg
+fromomni.isaac.lab.markers.configimportBLUE_ARROW_X_MARKER_CFG,FRAME_MARKER_CFG,GREEN_ARROW_X_MARKER_CFG
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.null_commandimportNullCommand
+from.pose_2d_commandimportTerrainBasedPose2dCommand,UniformPose2dCommand
+from.pose_commandimportUniformPoseCommand
+from.velocity_commandimportNormalVelocityCommand,UniformVelocityCommand
+
+
+
[文档]@configclass
+classNullCommandCfg(CommandTermCfg):
+"""Configuration for the null command generator."""
+
+ class_type:type=NullCommand
+
+ def__post_init__(self):
+"""Post initialization."""
+ # set the resampling time range to infinity to avoid resampling
+ self.resampling_time_range=(math.inf,math.inf)
+
+
+
[文档]@configclass
+classUniformVelocityCommandCfg(CommandTermCfg):
+"""Configuration for the uniform velocity command generator."""
+
+ class_type:type=UniformVelocityCommand
+
+ asset_name:str=MISSING
+"""Name of the asset in the environment for which the commands are generated."""
+ heading_command:bool=MISSING
+"""Whether to use heading command or angular velocity command.
+
+ If True, the angular velocity command is computed from the heading error, where the
+ target heading is sampled uniformly from provided range. Otherwise, the angular velocity
+ command is sampled uniformly from provided range.
+ """
+ heading_control_stiffness:float=MISSING
+"""Scale factor to convert the heading error to angular velocity command."""
+ rel_standing_envs:float=MISSING
+"""Probability threshold for environments where the robots that are standing still."""
+ rel_heading_envs:float=MISSING
+"""Probability threshold for environments where the robots follow the heading-based angular velocity command
+ (the others follow the sampled angular velocity command)."""
+
+
[文档]@configclass
+ classRanges:
+"""Uniform distribution ranges for the velocity commands."""
+
+ lin_vel_x:tuple[float,float]=MISSING# min max [m/s]
+ lin_vel_y:tuple[float,float]=MISSING# min max [m/s]
+ ang_vel_z:tuple[float,float]=MISSING# min max [rad/s]
+ heading:tuple[float,float]=MISSING# min max [rad]
+
+ ranges:Ranges=MISSING
+"""Distribution ranges for the velocity commands."""
+
+ goal_vel_visualizer_cfg:VisualizationMarkersCfg=GREEN_ARROW_X_MARKER_CFG.replace(
+ prim_path="/Visuals/Command/velocity_goal"
+ )
+"""The configuration for the goal velocity visualization marker. Defaults to GREEN_ARROW_X_MARKER_CFG."""
+
+ current_vel_visualizer_cfg:VisualizationMarkersCfg=BLUE_ARROW_X_MARKER_CFG.replace(
+ prim_path="/Visuals/Command/velocity_current"
+ )
+"""The configuration for the current velocity visualization marker. Defaults to BLUE_ARROW_X_MARKER_CFG."""
+
+ # Set the scale of the visualization markers to (0.5, 0.5, 0.5)
+ goal_vel_visualizer_cfg.markers["arrow"].scale=(0.5,0.5,0.5)
+ current_vel_visualizer_cfg.markers["arrow"].scale=(0.5,0.5,0.5)
+
+
+
[文档]@configclass
+classNormalVelocityCommandCfg(UniformVelocityCommandCfg):
+"""Configuration for the normal velocity command generator."""
+
+ class_type:type=NormalVelocityCommand
+ heading_command:bool=False# --> we don't use heading command for normal velocity command.
+
+
[文档]@configclass
+ classRanges:
+"""Normal distribution ranges for the velocity commands."""
+
+ mean_vel:tuple[float,float,float]=MISSING
+"""Mean velocity for the normal distribution.
+
+ The tuple contains the mean linear-x, linear-y, and angular-z velocity.
+ """
+ std_vel:tuple[float,float,float]=MISSING
+"""Standard deviation for the normal distribution.
+
+ The tuple contains the standard deviation linear-x, linear-y, and angular-z velocity.
+ """
+ zero_prob:tuple[float,float,float]=MISSING
+"""Probability of zero velocity for the normal distribution.
+
+ The tuple contains the probability of zero linear-x, linear-y, and angular-z velocity.
+ """
+
+ ranges:Ranges=MISSING
+"""Distribution ranges for the velocity commands."""
+
+
+
[文档]@configclass
+classUniformPoseCommandCfg(CommandTermCfg):
+"""Configuration for uniform pose command generator."""
+
+ class_type:type=UniformPoseCommand
+
+ asset_name:str=MISSING
+"""Name of the asset in the environment for which the commands are generated."""
+ body_name:str=MISSING
+"""Name of the body in the asset for which the commands are generated."""
+
+ make_quat_unique:bool=False
+"""Whether to make the quaternion unique or not. Defaults to False.
+
+ If True, the quaternion is made unique by ensuring the real part is positive.
+ """
+
+
[文档]@configclass
+ classRanges:
+"""Uniform distribution ranges for the pose commands."""
+
+ pos_x:tuple[float,float]=MISSING# min max [m]
+ pos_y:tuple[float,float]=MISSING# min max [m]
+ pos_z:tuple[float,float]=MISSING# min max [m]
+ roll:tuple[float,float]=MISSING# min max [rad]
+ pitch:tuple[float,float]=MISSING# min max [rad]
+ yaw:tuple[float,float]=MISSING# min max [rad]
+
+ ranges:Ranges=MISSING
+"""Ranges for the commands."""
+
+ goal_pose_visualizer_cfg:VisualizationMarkersCfg=FRAME_MARKER_CFG.replace(prim_path="/Visuals/Command/goal_pose")
+"""The configuration for the goal pose visualization marker. Defaults to FRAME_MARKER_CFG."""
+
+ current_pose_visualizer_cfg:VisualizationMarkersCfg=FRAME_MARKER_CFG.replace(
+ prim_path="/Visuals/Command/body_pose"
+ )
+"""The configuration for the current pose visualization marker. Defaults to FRAME_MARKER_CFG."""
+
+ # Set the scale of the visualization markers to (0.1, 0.1, 0.1)
+ goal_pose_visualizer_cfg.markers["frame"].scale=(0.1,0.1,0.1)
+ current_pose_visualizer_cfg.markers["frame"].scale=(0.1,0.1,0.1)
+
+
+
[文档]@configclass
+classUniformPose2dCommandCfg(CommandTermCfg):
+"""Configuration for the uniform 2D-pose command generator."""
+
+ class_type:type=UniformPose2dCommand
+
+ asset_name:str=MISSING
+"""Name of the asset in the environment for which the commands are generated."""
+
+ simple_heading:bool=MISSING
+"""Whether to use simple heading or not.
+
+ If True, the heading is in the direction of the target position.
+ """
+
+
[文档]@configclass
+ classRanges:
+"""Uniform distribution ranges for the position commands."""
+
+ pos_x:tuple[float,float]=MISSING
+"""Range for the x position (in m)."""
+ pos_y:tuple[float,float]=MISSING
+"""Range for the y position (in m)."""
+ heading:tuple[float,float]=MISSING
+"""Heading range for the position commands (in rad).
+
+ Used only if :attr:`simple_heading` is False.
+ """
+
+ ranges:Ranges=MISSING
+"""Distribution ranges for the position commands."""
+
+ goal_pose_visualizer_cfg:VisualizationMarkersCfg=GREEN_ARROW_X_MARKER_CFG.replace(
+ prim_path="/Visuals/Command/pose_goal"
+ )
+"""The configuration for the goal pose visualization marker. Defaults to GREEN_ARROW_X_MARKER_CFG."""
+
+ # Set the scale of the visualization markers to (0.2, 0.2, 0.8)
+ goal_pose_visualizer_cfg.markers["arrow"].scale=(0.2,0.2,0.8)
+
+
+
[文档]@configclass
+classTerrainBasedPose2dCommandCfg(UniformPose2dCommandCfg):
+"""Configuration for the terrain-based position command generator."""
+
+ class_type=TerrainBasedPose2dCommand
+
+
[文档]@configclass
+ classRanges:
+"""Uniform distribution ranges for the position commands."""
+
+ heading:tuple[float,float]=MISSING
+"""Heading range for the position commands (in rad).
+
+ Used only if :attr:`simple_heading` is False.
+ """
+
+ ranges:Ranges=MISSING
+"""Distribution ranges for the sampled commands."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Common functions that can be used to create curriculum for the learning environment.
+
+The functions can be passed to the :class:`omni.isaac.lab.managers.CurriculumTermCfg` object to enable
+the curriculum introduced by the function.
+"""
+
+from__future__importannotations
+
+fromcollections.abcimportSequence
+fromtypingimportTYPE_CHECKING
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedRLEnv
+
+
+
[文档]defmodify_reward_weight(env:ManagerBasedRLEnv,env_ids:Sequence[int],term_name:str,weight:float,num_steps:int):
+"""Curriculum that modifies a reward weight a given number of steps.
+
+ Args:
+ env: The learning environment.
+ env_ids: Not used since all environments are affected.
+ term_name: The name of the reward term.
+ weight: The weight of the reward term.
+ num_steps: The number of steps after which the change should be applied.
+ """
+ ifenv.common_step_counter>num_steps:
+ # obtain term settings
+ term_cfg=env.reward_manager.get_term_cfg(term_name)
+ # update term settings
+ term_cfg.weight=weight
+ env.reward_manager.set_term_cfg(term_name,term_cfg)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Common functions that can be used to enable different events.
+
+Events include anything related to altering the simulation state. This includes changing the physics
+materials, applying external forces, and resetting the state of the asset.
+
+The functions can be passed to the :class:`omni.isaac.lab.managers.EventTermCfg` object to enable
+the event introduced by the function.
+"""
+
+from__future__importannotations
+
+importnumpyasnp
+importtorch
+fromtypingimportTYPE_CHECKING,Literal
+
+importcarb
+importomni.physics.tensors.impl.apiasphysx
+
+importomni.isaac.lab.simassim_utils
+importomni.isaac.lab.utils.mathasmath_utils
+fromomni.isaac.lab.actuatorsimportImplicitActuator
+fromomni.isaac.lab.assetsimportArticulation,DeformableObject,RigidObject
+fromomni.isaac.lab.managersimportSceneEntityCfg
+fromomni.isaac.lab.terrainsimportTerrainImporter
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedEnv
+
+
+
[文档]defrandomize_rigid_body_material(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor|None,
+ static_friction_range:tuple[float,float],
+ dynamic_friction_range:tuple[float,float],
+ restitution_range:tuple[float,float],
+ num_buckets:int,
+ asset_cfg:SceneEntityCfg,
+):
+"""Randomize the physics materials on all geometries of the asset.
+
+ This function creates a set of physics materials with random static friction, dynamic friction, and restitution
+ values. The number of materials is specified by ``num_buckets``. The materials are generated by sampling
+ uniform random values from the given ranges.
+
+ The material properties are then assigned to the geometries of the asset. The assignment is done by
+ creating a random integer tensor of shape (num_instances, max_num_shapes) where ``num_instances``
+ is the number of assets spawned and ``max_num_shapes`` is the maximum number of shapes in the asset (over
+ all bodies). The integer values are used as indices to select the material properties from the
+ material buckets.
+
+ .. attention::
+ This function uses CPU tensors to assign the material properties. It is recommended to use this function
+ only during the initialization of the environment. Otherwise, it may lead to a significant performance
+ overhead.
+
+ .. note::
+ PhysX only allows 64000 unique physics materials in the scene. If the number of materials exceeds this
+ limit, the simulation will crash.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject|Articulation=env.scene[asset_cfg.name]
+
+ ifnotisinstance(asset,(RigidObject,Articulation)):
+ raiseValueError(
+ f"Randomization term 'randomize_rigid_body_material' not supported for asset: '{asset_cfg.name}'"
+ f" with type: '{type(asset)}'."
+ )
+
+ # resolve environment ids
+ ifenv_idsisNone:
+ env_ids=torch.arange(env.scene.num_envs,device="cpu")
+ else:
+ env_ids=env_ids.cpu()
+
+ # retrieve material buffer
+ materials=asset.root_physx_view.get_material_properties()
+
+ # sample material properties from the given ranges
+ material_samples=np.zeros(materials[env_ids].shape)
+ material_samples[...,0]=np.random.uniform(*static_friction_range)
+ material_samples[...,1]=np.random.uniform(*dynamic_friction_range)
+ material_samples[...,2]=np.random.uniform(*restitution_range)
+
+ # create uniform range tensor for bucketing
+ lo=np.array([static_friction_range[0],dynamic_friction_range[0],restitution_range[0]])
+ hi=np.array([static_friction_range[1],dynamic_friction_range[1],restitution_range[1]])
+
+ # to avoid 64k material limit in physx, we bucket materials by binning randomized material properties
+ # into buckets based on the number of buckets specified
+ fordinrange(3):
+ buckets=np.array([(hi[d]-lo[d])*i/num_buckets+lo[d]foriinrange(num_buckets)])
+ material_samples[...,d]=buckets[np.searchsorted(buckets,material_samples[...,d])-1]
+
+ # update material buffer with new samples
+ ifisinstance(asset,Articulation)andasset_cfg.body_ids!=slice(None):
+ # obtain number of shapes per body (needed for indexing the material properties correctly)
+ # note: this is a workaround since the Articulation does not provide a direct way to obtain the number of shapes
+ # per body. We use the physics simulation view to obtain the number of shapes per body.
+ num_shapes_per_body=[]
+ forlink_pathinasset.root_physx_view.link_paths[0]:
+ link_physx_view=asset._physics_sim_view.create_rigid_body_view(link_path)# type: ignore
+ num_shapes_per_body.append(link_physx_view.max_shapes)
+
+ # sample material properties from the given ranges
+ forbody_idinasset_cfg.body_ids:
+ # start index of shape
+ start_idx=sum(num_shapes_per_body[:body_id])
+ # end index of shape
+ end_idx=start_idx+num_shapes_per_body[body_id]
+ # assign the new materials
+ # material ids are of shape: num_env_ids x num_shapes
+ # material_buckets are of shape: num_buckets x 3
+ materials[env_ids,start_idx:end_idx]=torch.from_numpy(material_samples[:,start_idx:end_idx]).to(
+ dtype=torch.float
+ )
+ else:
+ materials[env_ids]=torch.from_numpy(material_samples).to(dtype=torch.float)
+
+ # apply to simulation
+ asset.root_physx_view.set_material_properties(materials,env_ids)
+
+
+
[文档]defrandomize_rigid_body_mass(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor|None,
+ asset_cfg:SceneEntityCfg,
+ mass_distribution_params:tuple[float,float],
+ operation:Literal["add","scale","abs"],
+ distribution:Literal["uniform","log_uniform","gaussian"]="uniform",
+ recompute_inertia:bool=True,
+):
+"""Randomize the mass of the bodies by adding, scaling, or setting random values.
+
+ This function allows randomizing the mass of the bodies of the asset. The function samples random values from the
+ given distribution parameters and adds, scales, or sets the values into the physics simulation based on the operation.
+
+ If the ``recompute_inertia`` flag is set to ``True``, the function recomputes the inertia tensor of the bodies
+ after setting the mass. This is useful when the mass is changed significantly, as the inertia tensor depends
+ on the mass. It assumes the body is a uniform density object. If the body is not a uniform density object,
+ the inertia tensor may not be accurate.
+
+ .. tip::
+ This function uses CPU tensors to assign the body masses. It is recommended to use this function
+ only during the initialization of the environment.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject|Articulation=env.scene[asset_cfg.name]
+
+ # resolve environment ids
+ ifenv_idsisNone:
+ env_ids=torch.arange(env.scene.num_envs,device="cpu")
+ else:
+ env_ids=env_ids.cpu()
+
+ # resolve body indices
+ ifasset_cfg.body_ids==slice(None):
+ body_ids=torch.arange(asset.num_bodies,dtype=torch.int,device="cpu")
+ else:
+ body_ids=torch.tensor(asset_cfg.body_ids,dtype=torch.int,device="cpu")
+
+ # get the current masses of the bodies (num_assets, num_bodies)
+ masses=asset.root_physx_view.get_masses()
+
+ # apply randomization on default values
+ # this is to make sure when calling the function multiple times, the randomization is applied on the
+ # default values and not the previously randomized values
+ masses[env_ids[:,None],body_ids]=asset.data.default_mass[env_ids[:,None],body_ids].clone()
+
+ # sample from the given range
+ # note: we modify the masses in-place for all environments
+ # however, the setter takes care that only the masses of the specified environments are modified
+ masses=_randomize_prop_by_op(
+ masses,mass_distribution_params,env_ids,body_ids,operation=operation,distribution=distribution
+ )
+
+ # set the mass into the physics simulation
+ asset.root_physx_view.set_masses(masses,env_ids)
+
+ # recompute inertia tensors if needed
+ ifrecompute_inertia:
+ # compute the ratios of the new masses to the initial masses
+ ratios=masses[env_ids[:,None],body_ids]/asset.data.default_mass[env_ids[:,None],body_ids]
+ # scale the inertia tensors by the the ratios
+ # since mass randomization is done on default values, we can use the default inertia tensors
+ inertias=asset.root_physx_view.get_inertias()
+ ifisinstance(asset,Articulation):
+ # inertia has shape: (num_envs, num_bodies, 9) for articulation
+ inertias[env_ids[:,None],body_ids]=(
+ asset.data.default_inertia[env_ids[:,None],body_ids]*ratios[...,None]
+ )
+ else:
+ # inertia has shape: (num_envs, 9) for rigid object
+ inertias[env_ids]=asset.data.default_inertia[env_ids]*ratios
+ # set the inertia tensors into the physics simulation
+ asset.root_physx_view.set_inertias(inertias,env_ids)
+
+
+
[文档]defrandomize_physics_scene_gravity(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor|None,
+ gravity_distribution_params:tuple[list[float],list[float]],
+ operation:Literal["add","scale","abs"],
+ distribution:Literal["uniform","log_uniform","gaussian"]="uniform",
+):
+"""Randomize gravity by adding, scaling, or setting random values.
+
+ This function allows randomizing gravity of the physics scene. The function samples random values from the
+ given distribution parameters and adds, scales, or sets the values into the physics simulation based on the
+ operation.
+
+ The distribution parameters are lists of two elements each, representing the lower and upper bounds of the
+ distribution for the x, y, and z components of the gravity vector. The function samples random values for each
+ component independently.
+
+ .. attention::
+ This function applied the same gravity for all the environments.
+
+ .. tip::
+ This function uses CPU tensors to assign gravity.
+ """
+ # get the current gravity
+ gravity=torch.tensor(env.sim.cfg.gravity,device="cpu").unsqueeze(0)
+ dist_param_0=torch.tensor(gravity_distribution_params[0],device="cpu")
+ dist_param_1=torch.tensor(gravity_distribution_params[1],device="cpu")
+ gravity=_randomize_prop_by_op(
+ gravity,
+ (dist_param_0,dist_param_1),
+ None,
+ slice(None),
+ operation=operation,
+ distribution=distribution,
+ )
+ # unbatch the gravity tensor into a list
+ gravity=gravity[0].tolist()
+
+ # set the gravity into the physics simulation
+ physics_sim_view:physx.SimulationView=sim_utils.SimulationContext.instance().physics_sim_view
+ physics_sim_view.set_gravity(carb.Float3(*gravity))
+
+
+
[文档]defrandomize_actuator_gains(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor|None,
+ asset_cfg:SceneEntityCfg,
+ stiffness_distribution_params:tuple[float,float]|None=None,
+ damping_distribution_params:tuple[float,float]|None=None,
+ operation:Literal["add","scale","abs"]="abs",
+ distribution:Literal["uniform","log_uniform","gaussian"]="uniform",
+):
+"""Randomize the actuator gains in an articulation by adding, scaling, or setting random values.
+
+ This function allows randomizing the actuator stiffness and damping gains.
+
+ The function samples random values from the given distribution parameters and applies the operation to the joint properties.
+ It then sets the values into the actuator models. If the distribution parameters are not provided for a particular property,
+ the function does not modify the property.
+
+ .. tip::
+ For implicit actuators, this function uses CPU tensors to assign the actuator gains into the simulation.
+ In such cases, it is recommended to use this function only during the initialization of the environment.
+
+ Raises:
+ NotImplementedError: If the joint indices are in explicit motor mode. This operation is currently
+ not supported for explicit actuator models.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+
+ # resolve environment ids
+ ifenv_idsisNone:
+ env_ids=torch.arange(env.scene.num_envs,device=asset.device)
+
+ # resolve joint indices
+ ifasset_cfg.joint_ids==slice(None):
+ joint_ids_list=range(asset.num_joints)
+ joint_ids=slice(None)# for optimization purposes
+ else:
+ joint_ids_list=asset_cfg.joint_ids
+ joint_ids=torch.tensor(asset_cfg.joint_ids,dtype=torch.int,device=asset.device)
+
+ # check if none of the joint indices are in explicit motor mode
+ forjoint_indexinjoint_ids_list:
+ foract_name,actuatorinasset.actuators.items():
+ # if joint indices are a slice (i.e., all joints are captured) or the joint index is in the actuator
+ ifactuator.joint_indices==slice(None)orjoint_indexinactuator.joint_indices:
+ ifnotisinstance(actuator,ImplicitActuator):
+ raiseNotImplementedError(
+ "Event term 'randomize_actuator_stiffness_and_damping' is performed on asset"
+ f" '{asset_cfg.name}' on the joint '{asset.joint_names[joint_index]}' ('{joint_index}') which"
+ f" uses an explicit actuator model '{act_name}<{actuator.__class__.__name__}>'. This operation"
+ " is currently not supported for explicit actuator models."
+ )
+
+ # sample joint properties from the given ranges and set into the physics simulation
+ # -- stiffness
+ ifstiffness_distribution_paramsisnotNone:
+ stiffness=asset.data.default_joint_stiffness.to(asset.device).clone()
+ stiffness=_randomize_prop_by_op(
+ stiffness,stiffness_distribution_params,env_ids,joint_ids,operation=operation,distribution=distribution
+ )[env_ids][:,joint_ids]
+ asset.write_joint_stiffness_to_sim(stiffness,joint_ids=joint_ids,env_ids=env_ids)
+ # -- damping
+ ifdamping_distribution_paramsisnotNone:
+ damping=asset.data.default_joint_damping.to(asset.device).clone()
+ damping=_randomize_prop_by_op(
+ damping,damping_distribution_params,env_ids,joint_ids,operation=operation,distribution=distribution
+ )[env_ids][:,joint_ids]
+ asset.write_joint_damping_to_sim(damping,joint_ids=joint_ids,env_ids=env_ids)
+
+
+
[文档]defrandomize_joint_parameters(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor|None,
+ asset_cfg:SceneEntityCfg,
+ friction_distribution_params:tuple[float,float]|None=None,
+ armature_distribution_params:tuple[float,float]|None=None,
+ lower_limit_distribution_params:tuple[float,float]|None=None,
+ upper_limit_distribution_params:tuple[float,float]|None=None,
+ operation:Literal["add","scale","abs"]="abs",
+ distribution:Literal["uniform","log_uniform","gaussian"]="uniform",
+):
+"""Randomize the joint parameters of an articulation by adding, scaling, or setting random values.
+
+ This function allows randomizing the joint parameters of the asset.
+ These correspond to the physics engine joint properties that affect the joint behavior.
+
+ The function samples random values from the given distribution parameters and applies the operation to the joint properties.
+ It then sets the values into the physics simulation. If the distribution parameters are not provided for a
+ particular property, the function does not modify the property.
+
+ .. tip::
+ This function uses CPU tensors to assign the joint properties. It is recommended to use this function
+ only during the initialization of the environment.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+
+ # resolve environment ids
+ ifenv_idsisNone:
+ env_ids=torch.arange(env.scene.num_envs,device=asset.device)
+
+ # resolve joint indices
+ ifasset_cfg.joint_ids==slice(None):
+ joint_ids=slice(None)# for optimization purposes
+ else:
+ joint_ids=torch.tensor(asset_cfg.joint_ids,dtype=torch.int,device=asset.device)
+
+ # sample joint properties from the given ranges and set into the physics simulation
+ # -- friction
+ iffriction_distribution_paramsisnotNone:
+ friction=asset.data.default_joint_friction.to(asset.device).clone()
+ friction=_randomize_prop_by_op(
+ friction,friction_distribution_params,env_ids,joint_ids,operation=operation,distribution=distribution
+ )[env_ids][:,joint_ids]
+ asset.write_joint_friction_to_sim(friction,joint_ids=joint_ids,env_ids=env_ids)
+ # -- armature
+ ifarmature_distribution_paramsisnotNone:
+ armature=asset.data.default_joint_armature.to(asset.device).clone()
+ armature=_randomize_prop_by_op(
+ armature,armature_distribution_params,env_ids,joint_ids,operation=operation,distribution=distribution
+ )[env_ids][:,joint_ids]
+ asset.write_joint_armature_to_sim(armature,joint_ids=joint_ids,env_ids=env_ids)
+ # -- dof limits
+ iflower_limit_distribution_paramsisnotNoneorupper_limit_distribution_paramsisnotNone:
+ dof_limits=asset.data.default_joint_limits.to(asset.device).clone()
+ iflower_limit_distribution_paramsisnotNone:
+ lower_limits=dof_limits[...,0]
+ lower_limits=_randomize_prop_by_op(
+ lower_limits,
+ lower_limit_distribution_params,
+ env_ids,
+ joint_ids,
+ operation=operation,
+ distribution=distribution,
+ )[env_ids][:,joint_ids]
+ dof_limits[env_ids[:,None],joint_ids,0]=lower_limits
+ ifupper_limit_distribution_paramsisnotNone:
+ upper_limits=dof_limits[...,1]
+ upper_limits=_randomize_prop_by_op(
+ upper_limits,
+ upper_limit_distribution_params,
+ env_ids,
+ joint_ids,
+ operation=operation,
+ distribution=distribution,
+ )[env_ids][:,joint_ids]
+ dof_limits[env_ids[:,None],joint_ids,1]=upper_limits
+ if(dof_limits[env_ids[:,None],joint_ids,0]>dof_limits[env_ids[:,None],joint_ids,1]).any():
+ raiseValueError(
+ "Randomization term 'randomize_joint_parameters' is setting lower joint limits that are greater than"
+ " upper joint limits."
+ )
+
+ asset.write_joint_limits_to_sim(dof_limits[env_ids][:,joint_ids],joint_ids=joint_ids,env_ids=env_ids)
+
+
+
[文档]defrandomize_fixed_tendon_parameters(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor|None,
+ asset_cfg:SceneEntityCfg,
+ stiffness_distribution_params:tuple[float,float]|None=None,
+ damping_distribution_params:tuple[float,float]|None=None,
+ limit_stiffness_distribution_params:tuple[float,float]|None=None,
+ lower_limit_distribution_params:tuple[float,float]|None=None,
+ upper_limit_distribution_params:tuple[float,float]|None=None,
+ rest_length_distribution_params:tuple[float,float]|None=None,
+ offset_distribution_params:tuple[float,float]|None=None,
+ operation:Literal["add","scale","abs"]="abs",
+ distribution:Literal["uniform","log_uniform","gaussian"]="uniform",
+):
+"""Randomize the fixed tendon parameters of an articulation by adding, scaling, or setting random values.
+
+ This function allows randomizing the fixed tendon parameters of the asset.
+ These correspond to the physics engine tendon properties that affect the joint behavior.
+
+ The function samples random values from the given distribution parameters and applies the operation to the tendon properties.
+ It then sets the values into the physics simulation. If the distribution parameters are not provided for a
+ particular property, the function does not modify the property.
+
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+
+ # resolve environment ids
+ ifenv_idsisNone:
+ env_ids=torch.arange(env.scene.num_envs,device=asset.device)
+
+ # resolve joint indices
+ ifasset_cfg.fixed_tendon_ids==slice(None):
+ fixed_tendon_ids=slice(None)# for optimization purposes
+ else:
+ fixed_tendon_ids=torch.tensor(asset_cfg.fixed_tendon_ids,dtype=torch.int,device=asset.device)
+
+ # sample tendon properties from the given ranges and set into the physics simulation
+ # -- stiffness
+ ifstiffness_distribution_paramsisnotNone:
+ stiffness=asset.data.default_fixed_tendon_stiffness.clone()
+ stiffness=_randomize_prop_by_op(
+ stiffness,
+ stiffness_distribution_params,
+ env_ids,
+ fixed_tendon_ids,
+ operation=operation,
+ distribution=distribution,
+ )[env_ids][:,fixed_tendon_ids]
+ asset.set_fixed_tendon_stiffness(stiffness,fixed_tendon_ids,env_ids)
+ # -- damping
+ ifdamping_distribution_paramsisnotNone:
+ damping=asset.data.default_fixed_tendon_damping.clone()
+ damping=_randomize_prop_by_op(
+ damping,
+ damping_distribution_params,
+ env_ids,
+ fixed_tendon_ids,
+ operation=operation,
+ distribution=distribution,
+ )[env_ids][:,fixed_tendon_ids]
+ asset.set_fixed_tendon_damping(damping,fixed_tendon_ids,env_ids)
+ # -- limit stiffness
+ iflimit_stiffness_distribution_paramsisnotNone:
+ limit_stiffness=asset.data.default_fixed_tendon_limit_stiffness.clone()
+ limit_stiffness=_randomize_prop_by_op(
+ limit_stiffness,
+ limit_stiffness_distribution_params,
+ env_ids,
+ fixed_tendon_ids,
+ operation=operation,
+ distribution=distribution,
+ )[env_ids][:,fixed_tendon_ids]
+ asset.set_fixed_tendon_limit_stiffness(limit_stiffness,fixed_tendon_ids,env_ids)
+ # -- limits
+ iflower_limit_distribution_paramsisnotNoneorupper_limit_distribution_paramsisnotNone:
+ limit=asset.data.default_fixed_tendon_limit.clone()
+ # -- lower limit
+ iflower_limit_distribution_paramsisnotNone:
+ lower_limit=limit[...,0]
+ lower_limit=_randomize_prop_by_op(
+ lower_limit,
+ lower_limit_distribution_params,
+ env_ids,
+ fixed_tendon_ids,
+ operation=operation,
+ distribution=distribution,
+ )[env_ids][:,fixed_tendon_ids]
+ limit[env_ids[:,None],fixed_tendon_ids,0]=lower_limit
+ # -- upper limit
+ ifupper_limit_distribution_paramsisnotNone:
+ upper_limit=limit[...,1]
+ upper_limit=_randomize_prop_by_op(
+ upper_limit,
+ upper_limit_distribution_params,
+ env_ids,
+ fixed_tendon_ids,
+ operation=operation,
+ distribution=distribution,
+ )[env_ids][:,fixed_tendon_ids]
+ limit[env_ids[:,None],fixed_tendon_ids,1]=upper_limit
+ if(limit[env_ids[:,None],fixed_tendon_ids,0]>limit[env_ids[:,None],fixed_tendon_ids,1]).any():
+ raiseValueError(
+ "Randomization term 'randomize_fixed_tendon_parameters' is setting lower tendon limits that are greater"
+ " than upper tendon limits."
+ )
+ asset.set_fixed_tendon_limit(limit,fixed_tendon_ids,env_ids)
+ # -- rest length
+ ifrest_length_distribution_paramsisnotNone:
+ rest_length=asset.data.default_fixed_tendon_rest_length.clone()
+ rest_length=_randomize_prop_by_op(
+ rest_length,
+ rest_length_distribution_params,
+ env_ids,
+ fixed_tendon_ids,
+ operation=operation,
+ distribution=distribution,
+ )[env_ids][:,fixed_tendon_ids]
+ asset.set_fixed_tendon_rest_length(rest_length,fixed_tendon_ids,env_ids)
+ # -- offset
+ ifoffset_distribution_paramsisnotNone:
+ offset=asset.data.default_fixed_tendon_offset.clone()
+ offset=_randomize_prop_by_op(
+ offset,
+ offset_distribution_params,
+ env_ids,
+ fixed_tendon_ids,
+ operation=operation,
+ distribution=distribution,
+ )[env_ids][:,fixed_tendon_ids]
+ asset.set_fixed_tendon_offset(offset,fixed_tendon_ids,env_ids)
+
+ asset.write_fixed_tendon_properties_to_sim(fixed_tendon_ids,env_ids)
+
+
+
[文档]defapply_external_force_torque(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor,
+ force_range:tuple[float,float],
+ torque_range:tuple[float,float],
+ asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"),
+):
+"""Randomize the external forces and torques applied to the bodies.
+
+ This function creates a set of random forces and torques sampled from the given ranges. The number of forces
+ and torques is equal to the number of bodies times the number of environments. The forces and torques are
+ applied to the bodies by calling ``asset.set_external_force_and_torque``. The forces and torques are only
+ applied when ``asset.write_data_to_sim()`` is called in the environment.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject|Articulation=env.scene[asset_cfg.name]
+ # resolve environment ids
+ ifenv_idsisNone:
+ env_ids=torch.arange(env.scene.num_envs,device=asset.device)
+ # resolve number of bodies
+ num_bodies=len(asset_cfg.body_ids)ifisinstance(asset_cfg.body_ids,list)elseasset.num_bodies
+
+ # sample random forces and torques
+ size=(len(env_ids),num_bodies,3)
+ forces=math_utils.sample_uniform(*force_range,size,asset.device)
+ torques=math_utils.sample_uniform(*torque_range,size,asset.device)
+ # set the forces and torques into the buffers
+ # note: these are only applied when you call: `asset.write_data_to_sim()`
+ asset.set_external_force_and_torque(forces,torques,env_ids=env_ids,body_ids=asset_cfg.body_ids)
+
+
+
[文档]defpush_by_setting_velocity(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor,
+ velocity_range:dict[str,tuple[float,float]],
+ asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"),
+):
+"""Push the asset by setting the root velocity to a random value within the given ranges.
+
+ This creates an effect similar to pushing the asset with a random impulse that changes the asset's velocity.
+ It samples the root velocity from the given ranges and sets the velocity into the physics simulation.
+
+ The function takes a dictionary of velocity ranges for each axis and rotation. The keys of the dictionary
+ are ``x``, ``y``, ``z``, ``roll``, ``pitch``, and ``yaw``. The values are tuples of the form ``(min, max)``.
+ If the dictionary does not contain a key, the velocity is set to zero for that axis.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject|Articulation=env.scene[asset_cfg.name]
+
+ # velocities
+ vel_w=asset.data.root_vel_w[env_ids]
+ # sample random velocities
+ range_list=[velocity_range.get(key,(0.0,0.0))forkeyin["x","y","z","roll","pitch","yaw"]]
+ ranges=torch.tensor(range_list,device=asset.device)
+ vel_w[:]=math_utils.sample_uniform(ranges[:,0],ranges[:,1],vel_w.shape,device=asset.device)
+ # set the velocities into the physics simulation
+ asset.write_root_velocity_to_sim(vel_w,env_ids=env_ids)
+
+
+
[文档]defreset_root_state_uniform(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor,
+ pose_range:dict[str,tuple[float,float]],
+ velocity_range:dict[str,tuple[float,float]],
+ asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"),
+):
+"""Reset the asset root state to a random position and velocity uniformly within the given ranges.
+
+ This function randomizes the root position and velocity of the asset.
+
+ * It samples the root position from the given ranges and adds them to the default root position, before setting
+ them into the physics simulation.
+ * It samples the root orientation from the given ranges and sets them into the physics simulation.
+ * It samples the root velocity from the given ranges and sets them into the physics simulation.
+
+ The function takes a dictionary of pose and velocity ranges for each axis and rotation. The keys of the
+ dictionary are ``x``, ``y``, ``z``, ``roll``, ``pitch``, and ``yaw``. The values are tuples of the form
+ ``(min, max)``. If the dictionary does not contain a key, the position or velocity is set to zero for that axis.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject|Articulation=env.scene[asset_cfg.name]
+ # get default root state
+ root_states=asset.data.default_root_state[env_ids].clone()
+
+ # poses
+ range_list=[pose_range.get(key,(0.0,0.0))forkeyin["x","y","z","roll","pitch","yaw"]]
+ ranges=torch.tensor(range_list,device=asset.device)
+ rand_samples=math_utils.sample_uniform(ranges[:,0],ranges[:,1],(len(env_ids),6),device=asset.device)
+
+ positions=root_states[:,0:3]+env.scene.env_origins[env_ids]+rand_samples[:,0:3]
+ orientations_delta=math_utils.quat_from_euler_xyz(rand_samples[:,3],rand_samples[:,4],rand_samples[:,5])
+ orientations=math_utils.quat_mul(root_states[:,3:7],orientations_delta)
+ # velocities
+ range_list=[velocity_range.get(key,(0.0,0.0))forkeyin["x","y","z","roll","pitch","yaw"]]
+ ranges=torch.tensor(range_list,device=asset.device)
+ rand_samples=math_utils.sample_uniform(ranges[:,0],ranges[:,1],(len(env_ids),6),device=asset.device)
+
+ velocities=root_states[:,7:13]+rand_samples
+
+ # set into the physics simulation
+ asset.write_root_pose_to_sim(torch.cat([positions,orientations],dim=-1),env_ids=env_ids)
+ asset.write_root_velocity_to_sim(velocities,env_ids=env_ids)
+
+
+
[文档]defreset_root_state_with_random_orientation(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor,
+ pose_range:dict[str,tuple[float,float]],
+ velocity_range:dict[str,tuple[float,float]],
+ asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"),
+):
+"""Reset the asset root position and velocities sampled randomly within the given ranges
+ and the asset root orientation sampled randomly from the SO(3).
+
+ This function randomizes the root position and velocity of the asset.
+
+ * It samples the root position from the given ranges and adds them to the default root position, before setting
+ them into the physics simulation.
+ * It samples the root orientation uniformly from the SO(3) and sets them into the physics simulation.
+ * It samples the root velocity from the given ranges and sets them into the physics simulation.
+
+ The function takes a dictionary of position and velocity ranges for each axis and rotation:
+
+ * :attr:`pose_range` - a dictionary of position ranges for each axis. The keys of the dictionary are ``x``,
+ ``y``, and ``z``. The orientation is sampled uniformly from the SO(3).
+ * :attr:`velocity_range` - a dictionary of velocity ranges for each axis and rotation. The keys of the dictionary
+ are ``x``, ``y``, ``z``, ``roll``, ``pitch``, and ``yaw``.
+
+ The values are tuples of the form ``(min, max)``. If the dictionary does not contain a particular key,
+ the position is set to zero for that axis.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject|Articulation=env.scene[asset_cfg.name]
+ # get default root state
+ root_states=asset.data.default_root_state[env_ids].clone()
+
+ # poses
+ range_list=[pose_range.get(key,(0.0,0.0))forkeyin["x","y","z"]]
+ ranges=torch.tensor(range_list,device=asset.device)
+ rand_samples=math_utils.sample_uniform(ranges[:,0],ranges[:,1],(len(env_ids),3),device=asset.device)
+
+ positions=root_states[:,0:3]+env.scene.env_origins[env_ids]+rand_samples
+ orientations=math_utils.random_orientation(len(env_ids),device=asset.device)
+
+ # velocities
+ range_list=[velocity_range.get(key,(0.0,0.0))forkeyin["x","y","z","roll","pitch","yaw"]]
+ ranges=torch.tensor(range_list,device=asset.device)
+ rand_samples=math_utils.sample_uniform(ranges[:,0],ranges[:,1],(len(env_ids),6),device=asset.device)
+
+ velocities=root_states[:,7:13]+rand_samples
+
+ # set into the physics simulation
+ asset.write_root_pose_to_sim(torch.cat([positions,orientations],dim=-1),env_ids=env_ids)
+ asset.write_root_velocity_to_sim(velocities,env_ids=env_ids)
+
+
+
[文档]defreset_root_state_from_terrain(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor,
+ pose_range:dict[str,tuple[float,float]],
+ velocity_range:dict[str,tuple[float,float]],
+ asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"),
+):
+"""Reset the asset root state by sampling a random valid pose from the terrain.
+
+ This function samples a random valid pose(based on flat patches) from the terrain and sets the root state
+ of the asset to this position. The function also samples random velocities from the given ranges and sets them
+ into the physics simulation.
+
+ The function takes a dictionary of position and velocity ranges for each axis and rotation:
+
+ * :attr:`pose_range` - a dictionary of pose ranges for each axis. The keys of the dictionary are ``roll``,
+ ``pitch``, and ``yaw``. The position is sampled from the flat patches of the terrain.
+ * :attr:`velocity_range` - a dictionary of velocity ranges for each axis and rotation. The keys of the dictionary
+ are ``x``, ``y``, ``z``, ``roll``, ``pitch``, and ``yaw``.
+
+ The values are tuples of the form ``(min, max)``. If the dictionary does not contain a particular key,
+ the position is set to zero for that axis.
+
+ Note:
+ The function expects the terrain to have valid flat patches under the key "init_pos". The flat patches
+ are used to sample the random pose for the robot.
+
+ Raises:
+ ValueError: If the terrain does not have valid flat patches under the key "init_pos".
+ """
+ # access the used quantities (to enable type-hinting)
+ asset:RigidObject|Articulation=env.scene[asset_cfg.name]
+ terrain:TerrainImporter=env.scene.terrain
+
+ # obtain all flat patches corresponding to the valid poses
+ valid_positions:torch.Tensor=terrain.flat_patches.get("init_pos")
+ ifvalid_positionsisNone:
+ raiseValueError(
+ "The event term 'reset_root_state_from_terrain' requires valid flat patches under 'init_pos'."
+ f" Found: {list(terrain.flat_patches.keys())}"
+ )
+
+ # sample random valid poses
+ ids=torch.randint(0,valid_positions.shape[2],size=(len(env_ids),),device=env.device)
+ positions=valid_positions[terrain.terrain_levels[env_ids],terrain.terrain_types[env_ids],ids]
+ positions+=asset.data.default_root_state[env_ids,:3]
+
+ # sample random orientations
+ range_list=[pose_range.get(key,(0.0,0.0))forkeyin["roll","pitch","yaw"]]
+ ranges=torch.tensor(range_list,device=asset.device)
+ rand_samples=math_utils.sample_uniform(ranges[:,0],ranges[:,1],(len(env_ids),3),device=asset.device)
+
+ # convert to quaternions
+ orientations=math_utils.quat_from_euler_xyz(rand_samples[:,0],rand_samples[:,1],rand_samples[:,2])
+
+ # sample random velocities
+ range_list=[velocity_range.get(key,(0.0,0.0))forkeyin["x","y","z","roll","pitch","yaw"]]
+ ranges=torch.tensor(range_list,device=asset.device)
+ rand_samples=math_utils.sample_uniform(ranges[:,0],ranges[:,1],(len(env_ids),6),device=asset.device)
+
+ velocities=asset.data.default_root_state[:,7:13]+rand_samples
+
+ # set into the physics simulation
+ asset.write_root_pose_to_sim(torch.cat([positions,orientations],dim=-1),env_ids=env_ids)
+ asset.write_root_velocity_to_sim(velocities,env_ids=env_ids)
+
+
+
[文档]defreset_joints_by_scale(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor,
+ position_range:tuple[float,float],
+ velocity_range:tuple[float,float],
+ asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"),
+):
+"""Reset the robot joints by scaling the default position and velocity by the given ranges.
+
+ This function samples random values from the given ranges and scales the default joint positions and velocities
+ by these values. The scaled values are then set into the physics simulation.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ # get default joint state
+ joint_pos=asset.data.default_joint_pos[env_ids].clone()
+ joint_vel=asset.data.default_joint_vel[env_ids].clone()
+
+ # scale these values randomly
+ joint_pos*=math_utils.sample_uniform(*position_range,joint_pos.shape,joint_pos.device)
+ joint_vel*=math_utils.sample_uniform(*velocity_range,joint_vel.shape,joint_vel.device)
+
+ # clamp joint pos to limits
+ joint_pos_limits=asset.data.soft_joint_pos_limits[env_ids]
+ joint_pos=joint_pos.clamp_(joint_pos_limits[...,0],joint_pos_limits[...,1])
+ # clamp joint vel to limits
+ joint_vel_limits=asset.data.soft_joint_vel_limits[env_ids]
+ joint_vel=joint_vel.clamp_(-joint_vel_limits,joint_vel_limits)
+
+ # set into the physics simulation
+ asset.write_joint_state_to_sim(joint_pos,joint_vel,env_ids=env_ids)
+
+
+
[文档]defreset_joints_by_offset(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor,
+ position_range:tuple[float,float],
+ velocity_range:tuple[float,float],
+ asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"),
+):
+"""Reset the robot joints with offsets around the default position and velocity by the given ranges.
+
+ This function samples random values from the given ranges and biases the default joint positions and velocities
+ by these values. The biased values are then set into the physics simulation.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+
+ # get default joint state
+ joint_pos=asset.data.default_joint_pos[env_ids].clone()
+ joint_vel=asset.data.default_joint_vel[env_ids].clone()
+
+ # bias these values randomly
+ joint_pos+=math_utils.sample_uniform(*position_range,joint_pos.shape,joint_pos.device)
+ joint_vel+=math_utils.sample_uniform(*velocity_range,joint_vel.shape,joint_vel.device)
+
+ # clamp joint pos to limits
+ joint_pos_limits=asset.data.soft_joint_pos_limits[env_ids]
+ joint_pos=joint_pos.clamp_(joint_pos_limits[...,0],joint_pos_limits[...,1])
+ # clamp joint vel to limits
+ joint_vel_limits=asset.data.soft_joint_vel_limits[env_ids]
+ joint_vel=joint_vel.clamp_(-joint_vel_limits,joint_vel_limits)
+
+ # set into the physics simulation
+ asset.write_joint_state_to_sim(joint_pos,joint_vel,env_ids=env_ids)
+
+
+
[文档]defreset_nodal_state_uniform(
+ env:ManagerBasedEnv,
+ env_ids:torch.Tensor,
+ position_range:dict[str,tuple[float,float]],
+ velocity_range:dict[str,tuple[float,float]],
+ asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"),
+):
+"""Reset the asset nodal state to a random position and velocity uniformly within the given ranges.
+
+ This function randomizes the nodal position and velocity of the asset.
+
+ * It samples the root position from the given ranges and adds them to the default nodal position, before setting
+ them into the physics simulation.
+ * It samples the root velocity from the given ranges and sets them into the physics simulation.
+
+ The function takes a dictionary of position and velocity ranges for each axis. The keys of the
+ dictionary are ``x``, ``y``, ``z``. The values are tuples of the form ``(min, max)``.
+ If the dictionary does not contain a key, the position or velocity is set to zero for that axis.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:DeformableObject=env.scene[asset_cfg.name]
+ # get default root state
+ nodal_state=asset.data.default_nodal_state_w[env_ids].clone()
+
+ # position
+ range_list=[position_range.get(key,(0.0,0.0))forkeyin["x","y","z"]]
+ ranges=torch.tensor(range_list,device=asset.device)
+ rand_samples=math_utils.sample_uniform(ranges[:,0],ranges[:,1],(len(env_ids),1,3),device=asset.device)
+
+ nodal_state[...,:3]+=rand_samples
+
+ # velocities
+ range_list=[velocity_range.get(key,(0.0,0.0))forkeyin["x","y","z"]]
+ ranges=torch.tensor(range_list,device=asset.device)
+ rand_samples=math_utils.sample_uniform(ranges[:,0],ranges[:,1],(len(env_ids),1,3),device=asset.device)
+
+ nodal_state[...,3:]+=rand_samples
+
+ # set into the physics simulation
+ asset.write_nodal_state_to_sim(nodal_state,env_ids=env_ids)
+
+
+
[文档]defreset_scene_to_default(env:ManagerBasedEnv,env_ids:torch.Tensor):
+"""Reset the scene to the default state specified in the scene configuration."""
+ # rigid bodies
+ forrigid_objectinenv.scene.rigid_objects.values():
+ # obtain default and deal with the offset for env origins
+ default_root_state=rigid_object.data.default_root_state[env_ids].clone()
+ default_root_state[:,0:3]+=env.scene.env_origins[env_ids]
+ # set into the physics simulation
+ rigid_object.write_root_state_to_sim(default_root_state,env_ids=env_ids)
+ # articulations
+ forarticulation_assetinenv.scene.articulations.values():
+ # obtain default and deal with the offset for env origins
+ default_root_state=articulation_asset.data.default_root_state[env_ids].clone()
+ default_root_state[:,0:3]+=env.scene.env_origins[env_ids]
+ # set into the physics simulation
+ articulation_asset.write_root_state_to_sim(default_root_state,env_ids=env_ids)
+ # obtain default joint positions
+ default_joint_pos=articulation_asset.data.default_joint_pos[env_ids].clone()
+ default_joint_vel=articulation_asset.data.default_joint_vel[env_ids].clone()
+ # set into the physics simulation
+ articulation_asset.write_joint_state_to_sim(default_joint_pos,default_joint_vel,env_ids=env_ids)
+ # deformable objects
+ fordeformable_objectinenv.scene.deformable_objects.values():
+ # obtain default and set into the physics simulation
+ nodal_state=deformable_object.data.default_nodal_state_w[env_ids].clone()
+ deformable_object.write_nodal_state_to_sim(nodal_state,env_ids=env_ids)
+
+
+"""
+Internal helper functions.
+"""
+
+
+def_randomize_prop_by_op(
+ data:torch.Tensor,
+ distribution_parameters:tuple[float|torch.Tensor,float|torch.Tensor],
+ dim_0_ids:torch.Tensor|None,
+ dim_1_ids:torch.Tensor|slice,
+ operation:Literal["add","scale","abs"],
+ distribution:Literal["uniform","log_uniform","gaussian"],
+)->torch.Tensor:
+"""Perform data randomization based on the given operation and distribution.
+
+ Args:
+ data: The data tensor to be randomized. Shape is (dim_0, dim_1).
+ distribution_parameters: The parameters for the distribution to sample values from.
+ dim_0_ids: The indices of the first dimension to randomize.
+ dim_1_ids: The indices of the second dimension to randomize.
+ operation: The operation to perform on the data. Options: 'add', 'scale', 'abs'.
+ distribution: The distribution to sample the random values from. Options: 'uniform', 'log_uniform'.
+
+ Returns:
+ The data tensor after randomization. Shape is (dim_0, dim_1).
+
+ Raises:
+ NotImplementedError: If the operation or distribution is not supported.
+ """
+ # resolve shape
+ # -- dim 0
+ ifdim_0_idsisNone:
+ n_dim_0=data.shape[0]
+ dim_0_ids=slice(None)
+ else:
+ n_dim_0=len(dim_0_ids)
+ dim_0_ids=dim_0_ids[:,None]
+ # -- dim 1
+ ifisinstance(dim_1_ids,slice):
+ n_dim_1=data.shape[1]
+ else:
+ n_dim_1=len(dim_1_ids)
+
+ # resolve the distribution
+ ifdistribution=="uniform":
+ dist_fn=math_utils.sample_uniform
+ elifdistribution=="log_uniform":
+ dist_fn=math_utils.sample_log_uniform
+ elifdistribution=="gaussian":
+ dist_fn=math_utils.sample_gaussian
+ else:
+ raiseNotImplementedError(
+ f"Unknown distribution: '{distribution}' for joint properties randomization."
+ " Please use 'uniform', 'log_uniform', 'gaussian'."
+ )
+ # perform the operation
+ ifoperation=="add":
+ data[dim_0_ids,dim_1_ids]+=dist_fn(*distribution_parameters,(n_dim_0,n_dim_1),device=data.device)
+ elifoperation=="scale":
+ data[dim_0_ids,dim_1_ids]*=dist_fn(*distribution_parameters,(n_dim_0,n_dim_1),device=data.device)
+ elifoperation=="abs":
+ data[dim_0_ids,dim_1_ids]=dist_fn(*distribution_parameters,(n_dim_0,n_dim_1),device=data.device)
+ else:
+ raiseNotImplementedError(
+ f"Unknown operation: '{operation}' for property randomization. Please use 'add', 'scale', or 'abs'."
+ )
+ returndata
+
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Common functions that can be used to create observation terms.
+
+The functions can be passed to the :class:`omni.isaac.lab.managers.ObservationTermCfg` object to enable
+the observation introduced by the function.
+"""
+
+from__future__importannotations
+
+importtorch
+fromtypingimportTYPE_CHECKING
+
+importomni.isaac.lab.utils.mathasmath_utils
+fromomni.isaac.lab.assetsimportArticulation,RigidObject
+fromomni.isaac.lab.managersimportSceneEntityCfg
+fromomni.isaac.lab.sensorsimportCamera,RayCaster,RayCasterCamera,TiledCamera
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedEnv,ManagerBasedRLEnv
+
+"""
+Root state.
+"""
+
+
+
[文档]defbase_pos_z(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Root height in the simulation world frame."""
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ returnasset.data.root_pos_w[:,2].unsqueeze(-1)
+
+
+
[文档]defbase_lin_vel(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Root linear velocity in the asset's root frame."""
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returnasset.data.root_lin_vel_b
+
+
+
[文档]defbase_ang_vel(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Root angular velocity in the asset's root frame."""
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returnasset.data.root_ang_vel_b
+
+
+
[文档]defprojected_gravity(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Gravity projection on the asset's root frame."""
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returnasset.data.projected_gravity_b
+
+
+
[文档]defroot_pos_w(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Asset root position in the environment frame."""
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returnasset.data.root_pos_w-env.scene.env_origins
+
+
+
[文档]defroot_quat_w(
+ env:ManagerBasedEnv,make_quat_unique:bool=False,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""Asset root orientation (w, x, y, z) in the environment frame.
+
+ If :attr:`make_quat_unique` is True, then returned quaternion is made unique by ensuring
+ the quaternion has non-negative real component. This is because both ``q`` and ``-q`` represent
+ the same orientation.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+
+ quat=asset.data.root_quat_w
+ # make the quaternion real-part positive if configured
+ returnmath_utils.quat_unique(quat)ifmake_quat_uniqueelsequat
+
+
+
[文档]defroot_lin_vel_w(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Asset root linear velocity in the environment frame."""
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returnasset.data.root_lin_vel_w
+
+
+
[文档]defroot_ang_vel_w(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Asset root angular velocity in the environment frame."""
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returnasset.data.root_ang_vel_w
+
+
+"""
+Joint state.
+"""
+
+
+
[文档]defjoint_pos(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""The joint positions of the asset.
+
+ Note: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their positions returned.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ returnasset.data.joint_pos[:,asset_cfg.joint_ids]
+
+
+
[文档]defjoint_pos_rel(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""The joint positions of the asset w.r.t. the default joint positions.
+
+ Note: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their positions returned.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ returnasset.data.joint_pos[:,asset_cfg.joint_ids]-asset.data.default_joint_pos[:,asset_cfg.joint_ids]
+
+
+
[文档]defjoint_pos_limit_normalized(
+ env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""The joint positions of the asset normalized with the asset's joint limits.
+
+ Note: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their normalized positions returned.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ returnmath_utils.scale_transform(
+ asset.data.joint_pos[:,asset_cfg.joint_ids],
+ asset.data.soft_joint_pos_limits[:,asset_cfg.joint_ids,0],
+ asset.data.soft_joint_pos_limits[:,asset_cfg.joint_ids,1],
+ )
+
+
+
[文档]defjoint_vel(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")):
+"""The joint velocities of the asset.
+
+ Note: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their velocities returned.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ returnasset.data.joint_vel[:,asset_cfg.joint_ids]
+
+
+
[文档]defjoint_vel_rel(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")):
+"""The joint velocities of the asset w.r.t. the default joint velocities.
+
+ Note: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their velocities returned.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ returnasset.data.joint_vel[:,asset_cfg.joint_ids]-asset.data.default_joint_vel[:,asset_cfg.joint_ids]
+
+
+"""
+Sensors.
+"""
+
+
+
[文档]defheight_scan(env:ManagerBasedEnv,sensor_cfg:SceneEntityCfg,offset:float=0.5)->torch.Tensor:
+"""Height scan from the given sensor w.r.t. the sensor's frame.
+
+ The provided offset (Defaults to 0.5) is subtracted from the returned values.
+ """
+ # extract the used quantities (to enable type-hinting)
+ sensor:RayCaster=env.scene.sensors[sensor_cfg.name]
+ # height scan: height = sensor_height - hit_point_z - offset
+ returnsensor.data.pos_w[:,2].unsqueeze(1)-sensor.data.ray_hits_w[...,2]-offset
+
+
+
[文档]defbody_incoming_wrench(env:ManagerBasedEnv,asset_cfg:SceneEntityCfg)->torch.Tensor:
+"""Incoming spatial wrench on bodies of an articulation in the simulation world frame.
+
+ This is the 6-D wrench (force and torque) applied to the body link by the incoming joint force.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ # obtain the link incoming forces in world frame
+ link_incoming_forces=asset.root_physx_view.get_link_incoming_joint_force()[:,asset_cfg.body_ids]
+ returnlink_incoming_forces.view(env.num_envs,-1)
+
+
+
[文档]defgrab_images(
+ env:ManagerBasedEnv,
+ sensor_cfg:SceneEntityCfg=SceneEntityCfg("tiled_camera"),
+ data_type:str="rgb",
+ convert_perspective_to_orthogonal:bool=False,
+ normalize:bool=True,
+)->torch.Tensor:
+"""Grab all of the latest images of a specific datatype produced by a specific camera.
+
+ Args:
+ env: The environment the cameras are placed within.
+ sensor_cfg: The desired sensor to read from. Defaults to SceneEntityCfg("tiled_camera").
+ data_type: The data type to pull from the desired camera. Defaults to "rgb".
+ convert_perspective_to_orthogonal: Whether to convert perspective
+ depth images to orthogonal depth images. Defaults to False.
+ normalize: Set to True to normalize images. Defaults to True.
+
+ Returns:
+ The images produced at the last timestep
+ """
+ sensor:TiledCamera|Camera|RayCasterCamera=env.scene.sensors[sensor_cfg.name]
+ images=sensor.data.output[data_type]
+ if(data_type=="distance_to_camera")andconvert_perspective_to_orthogonal:
+ images=math_utils.convert_perspective_depth_to_orthogonal_depth(images,sensor.data.intrinsic_matrices)
+
+ ifnormalize:
+ ifdata_type=="rgb":
+ images=images/255
+ mean_tensor=torch.mean(images,dim=(1,2),keepdim=True)
+ images-=mean_tensor
+ elif"distance_to"indata_typeor"depth"indata_type:
+ images[images==float("inf")]=0
+ returnimages.clone()
+
+
+"""
+Actions.
+"""
+
+
+
[文档]deflast_action(env:ManagerBasedEnv,action_name:str|None=None)->torch.Tensor:
+"""The last input action to the environment.
+
+ The name of the action term for which the action is required. If None, the
+ entire action tensor is returned.
+ """
+ ifaction_nameisNone:
+ returnenv.action_manager.action
+ else:
+ returnenv.action_manager.get_term(action_name).raw_actions
+
+
+"""
+Commands.
+"""
+
+
+
[文档]defgenerated_commands(env:ManagerBasedRLEnv,command_name:str)->torch.Tensor:
+"""The generated command from command term in the command manager with the given name."""
+ returnenv.command_manager.get_command(command_name)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Common functions that can be used to enable reward functions.
+
+The functions can be passed to the :class:`omni.isaac.lab.managers.RewardTermCfg` object to include
+the reward introduced by the function.
+"""
+
+from__future__importannotations
+
+importtorch
+fromtypingimportTYPE_CHECKING
+
+fromomni.isaac.lab.assetsimportArticulation,RigidObject
+fromomni.isaac.lab.managersimportSceneEntityCfg
+fromomni.isaac.lab.managers.manager_baseimportManagerTermBase
+fromomni.isaac.lab.managers.manager_term_cfgimportRewardTermCfg
+fromomni.isaac.lab.sensorsimportContactSensor
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedRLEnv
+
+"""
+General.
+"""
+
+
+
[文档]defis_alive(env:ManagerBasedRLEnv)->torch.Tensor:
+"""Reward for being alive."""
+ return(~env.termination_manager.terminated).float()
+
+
+
[文档]defis_terminated(env:ManagerBasedRLEnv)->torch.Tensor:
+"""Penalize terminated episodes that don't correspond to episodic timeouts."""
+ returnenv.termination_manager.terminated.float()
+
+
+
[文档]classis_terminated_term(ManagerTermBase):
+"""Penalize termination for specific terms that don't correspond to episodic timeouts.
+
+ The parameters are as follows:
+
+ * attr:`term_keys`: The termination terms to penalize. This can be a string, a list of strings
+ or regular expressions. Default is ".*" which penalizes all terminations.
+
+ The reward is computed as the sum of the termination terms that are not episodic timeouts.
+ This means that the reward is 0 if the episode is terminated due to an episodic timeout. Otherwise,
+ if two termination terms are active, the reward is 2.
+ """
+
+
[文档]def__init__(self,cfg:RewardTermCfg,env:ManagerBasedRLEnv):
+ # initialize the base class
+ super().__init__(cfg,env)
+ # find and store the termination terms
+ term_keys=cfg.params.get("term_keys",".*")
+ self._term_names=env.termination_manager.find_terms(term_keys)
+
+ def__call__(self,env:ManagerBasedRLEnv,term_keys:str|list[str]=".*")->torch.Tensor:
+ # Return the unweighted reward for the termination terms
+ reset_buf=torch.zeros(env.num_envs,device=env.device)
+ forterminself._term_names:
+ # Sums over terminations term values to account for multiple terminations in the same step
+ reset_buf+=env.termination_manager.get_term(term)
+
+ return(reset_buf*(~env.termination_manager.time_outs)).float()
+
+
+"""
+Root penalties.
+"""
+
+
+
[文档]deflin_vel_z_l2(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Penalize z-axis base linear velocity using L2 squared kernel."""
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returntorch.square(asset.data.root_lin_vel_b[:,2])
+
+
+
[文档]defang_vel_xy_l2(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Penalize xy-axis base angular velocity using L2 squared kernel."""
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returntorch.sum(torch.square(asset.data.root_ang_vel_b[:,:2]),dim=1)
+
+
+
[文档]defflat_orientation_l2(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Penalize non-flat base orientation using L2 squared kernel.
+
+ This is computed by penalizing the xy-components of the projected gravity vector.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returntorch.sum(torch.square(asset.data.projected_gravity_b[:,:2]),dim=1)
+
+
+
[文档]defbase_height_l2(
+ env:ManagerBasedRLEnv,target_height:float,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""Penalize asset height from its target using L2 squared kernel.
+
+ Note:
+ Currently, it assumes a flat terrain, i.e. the target height is in the world frame.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ # TODO: Fix this for rough-terrain.
+ returntorch.square(asset.data.root_pos_w[:,2]-target_height)
+
+
+
[文档]defbody_lin_acc_l2(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Penalize the linear acceleration of bodies using L2-kernel."""
+ asset:Articulation=env.scene[asset_cfg.name]
+ returntorch.sum(torch.norm(asset.data.body_lin_acc_w[:,asset_cfg.body_ids,:],dim=-1),dim=1)
+
+
+"""
+Joint penalties.
+"""
+
+
+
[文档]defjoint_torques_l2(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Penalize joint torques applied on the articulation using L2 squared kernel.
+
+ NOTE: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their joint torques contribute to the term.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ returntorch.sum(torch.square(asset.data.applied_torque[:,asset_cfg.joint_ids]),dim=1)
+
+
+
[文档]defjoint_vel_l1(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg)->torch.Tensor:
+"""Penalize joint velocities on the articulation using an L1-kernel."""
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ returntorch.sum(torch.abs(asset.data.joint_vel[:,asset_cfg.joint_ids]),dim=1)
+
+
+
[文档]defjoint_vel_l2(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Penalize joint velocities on the articulation using L2 squared kernel.
+
+ NOTE: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their joint velocities contribute to the term.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ returntorch.sum(torch.square(asset.data.joint_vel[:,asset_cfg.joint_ids]),dim=1)
+
+
+
[文档]defjoint_acc_l2(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Penalize joint accelerations on the articulation using L2 squared kernel.
+
+ NOTE: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their joint accelerations contribute to the term.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ returntorch.sum(torch.square(asset.data.joint_acc[:,asset_cfg.joint_ids]),dim=1)
+
+
+
[文档]defjoint_deviation_l1(env,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Penalize joint positions that deviate from the default one."""
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ # compute out of limits constraints
+ angle=asset.data.joint_pos[:,asset_cfg.joint_ids]-asset.data.default_joint_pos[:,asset_cfg.joint_ids]
+ returntorch.sum(torch.abs(angle),dim=1)
+
+
+
[文档]defjoint_pos_limits(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Penalize joint positions if they cross the soft limits.
+
+ This is computed as a sum of the absolute value of the difference between the joint position and the soft limits.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ # compute out of limits constraints
+ out_of_limits=-(
+ asset.data.joint_pos[:,asset_cfg.joint_ids]-asset.data.soft_joint_pos_limits[:,asset_cfg.joint_ids,0]
+ ).clip(max=0.0)
+ out_of_limits+=(
+ asset.data.joint_pos[:,asset_cfg.joint_ids]-asset.data.soft_joint_pos_limits[:,asset_cfg.joint_ids,1]
+ ).clip(min=0.0)
+ returntorch.sum(out_of_limits,dim=1)
+
+
+
[文档]defjoint_vel_limits(
+ env:ManagerBasedRLEnv,soft_ratio:float,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""Penalize joint velocities if they cross the soft limits.
+
+ This is computed as a sum of the absolute value of the difference between the joint velocity and the soft limits.
+
+ Args:
+ soft_ratio: The ratio of the soft limits to be used.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ # compute out of limits constraints
+ out_of_limits=(
+ torch.abs(asset.data.joint_vel[:,asset_cfg.joint_ids])
+ -asset.data.soft_joint_vel_limits[:,asset_cfg.joint_ids]*soft_ratio
+ )
+ # clip to max error = 1 rad/s per joint to avoid huge penalties
+ out_of_limits=out_of_limits.clip_(min=0.0,max=1.0)
+ returntorch.sum(out_of_limits,dim=1)
+
+
+"""
+Action penalties.
+"""
+
+
+
[文档]defapplied_torque_limits(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Penalize applied torques if they cross the limits.
+
+ This is computed as a sum of the absolute value of the difference between the applied torques and the limits.
+
+ .. caution::
+ Currently, this only works for explicit actuators since we manually compute the applied torques.
+ For implicit actuators, we currently cannot retrieve the applied torques from the physics engine.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ # compute out of limits constraints
+ # TODO: We need to fix this to support implicit joints.
+ out_of_limits=torch.abs(
+ asset.data.applied_torque[:,asset_cfg.joint_ids]-asset.data.computed_torque[:,asset_cfg.joint_ids]
+ )
+ returntorch.sum(out_of_limits,dim=1)
+
+
+
[文档]defaction_rate_l2(env:ManagerBasedRLEnv)->torch.Tensor:
+"""Penalize the rate of change of the actions using L2 squared kernel."""
+ returntorch.sum(torch.square(env.action_manager.action-env.action_manager.prev_action),dim=1)
+
+
+
[文档]defaction_l2(env:ManagerBasedRLEnv)->torch.Tensor:
+"""Penalize the actions using L2 squared kernel."""
+ returntorch.sum(torch.square(env.action_manager.action),dim=1)
+
+
+"""
+Contact sensor.
+"""
+
+
+
[文档]defundesired_contacts(env:ManagerBasedRLEnv,threshold:float,sensor_cfg:SceneEntityCfg)->torch.Tensor:
+"""Penalize undesired contacts as the number of violations that are above a threshold."""
+ # extract the used quantities (to enable type-hinting)
+ contact_sensor:ContactSensor=env.scene.sensors[sensor_cfg.name]
+ # check if contact force is above threshold
+ net_contact_forces=contact_sensor.data.net_forces_w_history
+ is_contact=torch.max(torch.norm(net_contact_forces[:,:,sensor_cfg.body_ids],dim=-1),dim=1)[0]>threshold
+ # sum over contacts for each environment
+ returntorch.sum(is_contact,dim=1)
+
+
+
[文档]defcontact_forces(env:ManagerBasedRLEnv,threshold:float,sensor_cfg:SceneEntityCfg)->torch.Tensor:
+"""Penalize contact forces as the amount of violations of the net contact force."""
+ # extract the used quantities (to enable type-hinting)
+ contact_sensor:ContactSensor=env.scene.sensors[sensor_cfg.name]
+ net_contact_forces=contact_sensor.data.net_forces_w_history
+ # compute the violation
+ violation=torch.max(torch.norm(net_contact_forces[:,:,sensor_cfg.body_ids],dim=-1),dim=1)[0]-threshold
+ # compute the penalty
+ returntorch.sum(violation.clip(min=0.0),dim=1)
+
+
+"""
+Velocity-tracking rewards.
+"""
+
+
+
[文档]deftrack_lin_vel_xy_exp(
+ env:ManagerBasedRLEnv,std:float,command_name:str,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""Reward tracking of linear velocity commands (xy axes) using exponential kernel."""
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ # compute the error
+ lin_vel_error=torch.sum(
+ torch.square(env.command_manager.get_command(command_name)[:,:2]-asset.data.root_lin_vel_b[:,:2]),
+ dim=1,
+ )
+ returntorch.exp(-lin_vel_error/std**2)
+
+
+
[文档]deftrack_ang_vel_z_exp(
+ env:ManagerBasedRLEnv,std:float,command_name:str,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""Reward tracking of angular velocity commands (yaw) using exponential kernel."""
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ # compute the error
+ ang_vel_error=torch.square(env.command_manager.get_command(command_name)[:,2]-asset.data.root_ang_vel_b[:,2])
+ returntorch.exp(-ang_vel_error/std**2)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Common functions that can be used to activate certain terminations.
+
+The functions can be passed to the :class:`omni.isaac.lab.managers.TerminationTermCfg` object to enable
+the termination introduced by the function.
+"""
+
+from__future__importannotations
+
+importtorch
+fromtypingimportTYPE_CHECKING
+
+fromomni.isaac.lab.assetsimportArticulation,RigidObject
+fromomni.isaac.lab.managersimportSceneEntityCfg
+fromomni.isaac.lab.sensorsimportContactSensor
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedRLEnv
+ fromomni.isaac.lab.managers.command_managerimportCommandTerm
+
+"""
+MDP terminations.
+"""
+
+
+
[文档]deftime_out(env:ManagerBasedRLEnv)->torch.Tensor:
+"""Terminate the episode when the episode length exceeds the maximum episode length."""
+ returnenv.episode_length_buf>=env.max_episode_length
+
+
+
[文档]defcommand_resample(env:ManagerBasedRLEnv,command_name:str,num_resamples:int=1)->torch.Tensor:
+"""Terminate the episode based on the total number of times commands have been re-sampled.
+
+ This makes the maximum episode length fluid in nature as it depends on how the commands are
+ sampled. It is useful in situations where delayed rewards are used :cite:`rudin2022advanced`.
+ """
+ command:CommandTerm=env.command_manager.get_term(command_name)
+ returntorch.logical_and((command.time_left<=env.step_dt),(command.command_counter==num_resamples))
+
+
+"""
+Root terminations.
+"""
+
+
+
[文档]defbad_orientation(
+ env:ManagerBasedRLEnv,limit_angle:float,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""Terminate when the asset's orientation is too far from the desired orientation limits.
+
+ This is computed by checking the angle between the projected gravity vector and the z-axis.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returntorch.acos(-asset.data.projected_gravity_b[:,2]).abs()>limit_angle
+
+
+
[文档]defroot_height_below_minimum(
+ env:ManagerBasedRLEnv,minimum_height:float,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""Terminate when the asset's root height is below the minimum height.
+
+ Note:
+ This is currently only supported for flat terrains, i.e. the minimum height is in the world frame.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:RigidObject=env.scene[asset_cfg.name]
+ returnasset.data.root_pos_w[:,2]<minimum_height
+
+
+"""
+Joint terminations.
+"""
+
+
+
[文档]defjoint_pos_out_of_limit(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Terminate when the asset's joint positions are outside of the soft joint limits."""
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ # compute any violations
+ out_of_upper_limits=torch.any(asset.data.joint_pos>asset.data.soft_joint_pos_limits[...,1],dim=1)
+ out_of_lower_limits=torch.any(asset.data.joint_pos<asset.data.soft_joint_pos_limits[...,0],dim=1)
+ returntorch.logical_or(out_of_upper_limits[:,asset_cfg.joint_ids],out_of_lower_limits[:,asset_cfg.joint_ids])
+
+
+
[文档]defjoint_pos_out_of_manual_limit(
+ env:ManagerBasedRLEnv,bounds:tuple[float,float],asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""Terminate when the asset's joint positions are outside of the configured bounds.
+
+ Note:
+ This function is similar to :func:`joint_pos_out_of_limit` but allows the user to specify the bounds manually.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ ifasset_cfg.joint_idsisNone:
+ asset_cfg.joint_ids=slice(None)
+ # compute any violations
+ out_of_upper_limits=torch.any(asset.data.joint_pos[:,asset_cfg.joint_ids]>bounds[1],dim=1)
+ out_of_lower_limits=torch.any(asset.data.joint_pos[:,asset_cfg.joint_ids]<bounds[0],dim=1)
+ returntorch.logical_or(out_of_upper_limits,out_of_lower_limits)
+
+
+
[文档]defjoint_vel_out_of_limit(env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot"))->torch.Tensor:
+"""Terminate when the asset's joint velocities are outside of the soft joint limits."""
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ # compute any violations
+ limits=asset.data.soft_joint_vel_limits
+ returntorch.any(torch.abs(asset.data.joint_vel[:,asset_cfg.joint_ids])>limits[:,asset_cfg.joint_ids],dim=1)
+
+
+
[文档]defjoint_vel_out_of_manual_limit(
+ env:ManagerBasedRLEnv,max_velocity:float,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""Terminate when the asset's joint velocities are outside the provided limits."""
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ # compute any violations
+ returntorch.any(torch.abs(asset.data.joint_vel[:,asset_cfg.joint_ids])>max_velocity,dim=1)
+
+
+
[文档]defjoint_effort_out_of_limit(
+ env:ManagerBasedRLEnv,asset_cfg:SceneEntityCfg=SceneEntityCfg("robot")
+)->torch.Tensor:
+"""Terminate when effort applied on the asset's joints are outside of the soft joint limits.
+
+ In the actuators, the applied torque are the efforts applied on the joints. These are computed by clipping
+ the computed torques to the joint limits. Hence, we check if the computed torques are equal to the applied
+ torques.
+ """
+ # extract the used quantities (to enable type-hinting)
+ asset:Articulation=env.scene[asset_cfg.name]
+ # check if any joint effort is out of limit
+ out_of_limits=torch.isclose(
+ asset.data.computed_torque[:,asset_cfg.joint_ids],asset.data.applied_torque[:,asset_cfg.joint_ids]
+ )
+ returntorch.any(out_of_limits,dim=1)
+
+
+"""
+Contact sensor.
+"""
+
+
+
[文档]defillegal_contact(env:ManagerBasedRLEnv,threshold:float,sensor_cfg:SceneEntityCfg)->torch.Tensor:
+"""Terminate when the contact force on the sensor exceeds the force threshold."""
+ # extract the used quantities (to enable type-hinting)
+ contact_sensor:ContactSensor=env.scene.sensors[sensor_cfg.name]
+ net_contact_forces=contact_sensor.data.net_forces_w_history
+ # check if any contact force exceeds the threshold
+ returntorch.any(
+ torch.max(torch.norm(net_contact_forces[:,:,sensor_cfg.body_ids],dim=-1),dim=1)[0]>threshold,dim=1
+ )
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importasyncio
+importos
+importweakref
+fromdatetimeimportdatetime
+fromtypingimportTYPE_CHECKING
+
+importomni.kit.app
+importomni.kit.commands
+importomni.usd
+frompxrimportPhysxSchema,Sdf,Usd,UsdGeom,UsdPhysics
+
+ifTYPE_CHECKING:
+ importomni.ui
+
+ from..manager_based_envimportManagerBasedEnv
+
+
+
[文档]classBaseEnvWindow:
+"""Window manager for the basic environment.
+
+ This class creates a window that is used to control the environment. The window
+ contains controls for rendering, debug visualization, and other environment-specific
+ UI elements.
+
+ Users can add their own UI elements to the window by using the `with` context manager.
+ This can be done either be inheriting the class or by using the `env.window` object
+ directly from the standalone execution script.
+
+ Example for adding a UI element from the standalone execution script:
+ >>> with env.window.ui_window_elements["main_vstack"]:
+ >>> ui.Label("My UI element")
+
+ """
+
+
[文档]def__init__(self,env:ManagerBasedEnv,window_name:str="IsaacLab"):
+"""Initialize the window.
+
+ Args:
+ env: The environment object.
+ window_name: The name of the window. Defaults to "IsaacLab".
+ """
+ # store inputs
+ self.env=env
+ # prepare the list of assets that can be followed by the viewport camera
+ # note that the first two options are "World" and "Env" which are special cases
+ self._viewer_assets_options=[
+ "World",
+ "Env",
+ *self.env.scene.rigid_objects.keys(),
+ *self.env.scene.articulations.keys(),
+ ]
+
+ print("Creating window for environment.")
+ # create window for UI
+ self.ui_window=omni.ui.Window(
+ window_name,width=400,height=500,visible=True,dock_preference=omni.ui.DockPreference.RIGHT_TOP
+ )
+ # dock next to properties window
+ asyncio.ensure_future(self._dock_window(window_title=self.ui_window.title))
+
+ # keep a dictionary of stacks so that child environments can add their own UI elements
+ # this can be done by using the `with` context manager
+ self.ui_window_elements=dict()
+ # create main frame
+ self.ui_window_elements["main_frame"]=self.ui_window.frame
+ withself.ui_window_elements["main_frame"]:
+ # create main stack
+ self.ui_window_elements["main_vstack"]=omni.ui.VStack(spacing=5,height=0)
+ withself.ui_window_elements["main_vstack"]:
+ # create collapsable frame for simulation
+ self._build_sim_frame()
+ # create collapsable frame for viewer
+ self._build_viewer_frame()
+ # create collapsable frame for debug visualization
+ self._build_debug_vis_frame()
+
+ def__del__(self):
+"""Destructor for the window."""
+ # destroy the window
+ ifself.ui_windowisnotNone:
+ self.ui_window.visible=False
+ self.ui_window.destroy()
+ self.ui_window=None
+
+"""
+ Build sub-sections of the UI.
+ """
+
+ def_build_sim_frame(self):
+"""Builds the sim-related controls frame for the UI."""
+ # create collapsable frame for controls
+ self.ui_window_elements["sim_frame"]=omni.ui.CollapsableFrame(
+ title="Simulation Settings",
+ width=omni.ui.Fraction(1),
+ height=0,
+ collapsed=False,
+ style=omni.isaac.ui.ui_utils.get_style(),
+ horizontal_scrollbar_policy=omni.ui.ScrollBarPolicy.SCROLLBAR_AS_NEEDED,
+ vertical_scrollbar_policy=omni.ui.ScrollBarPolicy.SCROLLBAR_ALWAYS_ON,
+ )
+ withself.ui_window_elements["sim_frame"]:
+ # create stack for controls
+ self.ui_window_elements["sim_vstack"]=omni.ui.VStack(spacing=5,height=0)
+ withself.ui_window_elements["sim_vstack"]:
+ # create rendering mode dropdown
+ render_mode_cfg={
+ "label":"Rendering Mode",
+ "type":"dropdown",
+ "default_val":self.env.sim.render_mode.value,
+ "items":[member.nameformemberinself.env.sim.RenderModeifmember.value>=0],
+ "tooltip":"Select a rendering mode\n"+self.env.sim.RenderMode.__doc__,
+ "on_clicked_fn":lambdavalue:self.env.sim.set_render_mode(self.env.sim.RenderMode[value]),
+ }
+ self.ui_window_elements["render_dropdown"]=omni.isaac.ui.ui_utils.dropdown_builder(**render_mode_cfg)
+
+ # create animation recording box
+ record_animate_cfg={
+ "label":"Record Animation",
+ "type":"state_button",
+ "a_text":"START",
+ "b_text":"STOP",
+ "tooltip":"Record the animation of the scene. Only effective if fabric is disabled.",
+ "on_clicked_fn":lambdavalue:self._toggle_recording_animation_fn(value),
+ }
+ self.ui_window_elements["record_animation"]=omni.isaac.ui.ui_utils.state_btn_builder(
+ **record_animate_cfg
+ )
+ # disable the button if fabric is not enabled
+ self.ui_window_elements["record_animation"].enabled=notself.env.sim.is_fabric_enabled()
+
+ def_build_viewer_frame(self):
+"""Build the viewer-related control frame for the UI."""
+ # create collapsable frame for viewer
+ self.ui_window_elements["viewer_frame"]=omni.ui.CollapsableFrame(
+ title="Viewer Settings",
+ width=omni.ui.Fraction(1),
+ height=0,
+ collapsed=False,
+ style=omni.isaac.ui.ui_utils.get_style(),
+ horizontal_scrollbar_policy=omni.ui.ScrollBarPolicy.SCROLLBAR_AS_NEEDED,
+ vertical_scrollbar_policy=omni.ui.ScrollBarPolicy.SCROLLBAR_ALWAYS_ON,
+ )
+ withself.ui_window_elements["viewer_frame"]:
+ # create stack for controls
+ self.ui_window_elements["viewer_vstack"]=omni.ui.VStack(spacing=5,height=0)
+ withself.ui_window_elements["viewer_vstack"]:
+ # create a number slider to move to environment origin
+ # NOTE: slider is 1-indexed, whereas the env index is 0-indexed
+ viewport_origin_cfg={
+ "label":"Environment Index",
+ "type":"button",
+ "default_val":self.env.cfg.viewer.env_index+1,
+ "min":1,
+ "max":self.env.num_envs,
+ "tooltip":"The environment index to follow. Only effective if follow mode is not 'World'.",
+ }
+ self.ui_window_elements["viewer_env_index"]=omni.isaac.ui.ui_utils.int_builder(**viewport_origin_cfg)
+ # create a number slider to move to environment origin
+ self.ui_window_elements["viewer_env_index"].add_value_changed_fn(self._set_viewer_env_index_fn)
+
+ # create a tracker for the camera location
+ viewer_follow_cfg={
+ "label":"Follow Mode",
+ "type":"dropdown",
+ "default_val":0,
+ "items":[name.replace("_"," ").title()fornameinself._viewer_assets_options],
+ "tooltip":"Select the viewport camera following mode.",
+ "on_clicked_fn":self._set_viewer_origin_type_fn,
+ }
+ self.ui_window_elements["viewer_follow"]=omni.isaac.ui.ui_utils.dropdown_builder(**viewer_follow_cfg)
+
+ # add viewer default eye and lookat locations
+ self.ui_window_elements["viewer_eye"]=omni.isaac.ui.ui_utils.xyz_builder(
+ label="Camera Eye",
+ tooltip="Modify the XYZ location of the viewer eye.",
+ default_val=self.env.cfg.viewer.eye,
+ step=0.1,
+ on_value_changed_fn=[self._set_viewer_location_fn]*3,
+ )
+ self.ui_window_elements["viewer_lookat"]=omni.isaac.ui.ui_utils.xyz_builder(
+ label="Camera Target",
+ tooltip="Modify the XYZ location of the viewer target.",
+ default_val=self.env.cfg.viewer.lookat,
+ step=0.1,
+ on_value_changed_fn=[self._set_viewer_location_fn]*3,
+ )
+
+ def_build_debug_vis_frame(self):
+"""Builds the debug visualization frame for various scene elements.
+
+ This function inquires the scene for all elements that have a debug visualization
+ implemented and creates a checkbox to toggle the debug visualization for each element
+ that has it implemented. If the element does not have a debug visualization implemented,
+ a label is created instead.
+ """
+ # import omni.isaac.ui.ui_utils as ui_utils
+ # import omni.ui
+
+ # create collapsable frame for debug visualization
+ self.ui_window_elements["debug_frame"]=omni.ui.CollapsableFrame(
+ title="Scene Debug Visualization",
+ width=omni.ui.Fraction(1),
+ height=0,
+ collapsed=False,
+ style=omni.isaac.ui.ui_utils.get_style(),
+ horizontal_scrollbar_policy=omni.ui.ScrollBarPolicy.SCROLLBAR_AS_NEEDED,
+ vertical_scrollbar_policy=omni.ui.ScrollBarPolicy.SCROLLBAR_ALWAYS_ON,
+ )
+ withself.ui_window_elements["debug_frame"]:
+ # create stack for debug visualization
+ self.ui_window_elements["debug_vstack"]=omni.ui.VStack(spacing=5,height=0)
+ withself.ui_window_elements["debug_vstack"]:
+ elements=[
+ self.env.scene.terrain,
+ *self.env.scene.rigid_objects.values(),
+ *self.env.scene.articulations.values(),
+ *self.env.scene.sensors.values(),
+ ]
+ names=[
+ "terrain",
+ *self.env.scene.rigid_objects.keys(),
+ *self.env.scene.articulations.keys(),
+ *self.env.scene.sensors.keys(),
+ ]
+ # create one for the terrain
+ forelem,nameinzip(elements,names):
+ ifelemisnotNone:
+ self._create_debug_vis_ui_element(name,elem)
+
+"""
+ Custom callbacks for UI elements.
+ """
+
+ def_toggle_recording_animation_fn(self,value:bool):
+"""Toggles the animation recording."""
+ ifvalue:
+ # log directory to save the recording
+ ifnothasattr(self,"animation_log_dir"):
+ # create a new log directory
+ log_dir=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+ self.animation_log_dir=os.path.join(os.getcwd(),"recordings",log_dir)
+ # start the recording
+ _=omni.kit.commands.execute(
+ "StartRecording",
+ target_paths=[("/World",True)],
+ live_mode=True,
+ use_frame_range=False,
+ start_frame=0,
+ end_frame=0,
+ use_preroll=False,
+ preroll_frame=0,
+ record_to="FILE",
+ fps=0,
+ apply_root_anim=False,
+ increment_name=True,
+ record_folder=self.animation_log_dir,
+ take_name="TimeSample",
+ )
+ else:
+ # stop the recording
+ _=omni.kit.commands.execute("StopRecording")
+ # save the current stage
+ stage=omni.usd.get_context().get_stage()
+ source_layer=stage.GetRootLayer()
+ # output the stage to a file
+ stage_usd_path=os.path.join(self.animation_log_dir,"Stage.usd")
+ source_prim_path="/"
+ # creates empty anon layer
+ temp_layer=Sdf.Find(stage_usd_path)
+ iftemp_layerisNone:
+ temp_layer=Sdf.Layer.CreateNew(stage_usd_path)
+ temp_stage=Usd.Stage.Open(temp_layer)
+ # update stage data
+ UsdGeom.SetStageUpAxis(temp_stage,UsdGeom.GetStageUpAxis(stage))
+ UsdGeom.SetStageMetersPerUnit(temp_stage,UsdGeom.GetStageMetersPerUnit(stage))
+ # copy the prim
+ Sdf.CreatePrimInLayer(temp_layer,source_prim_path)
+ Sdf.CopySpec(source_layer,source_prim_path,temp_layer,source_prim_path)
+ # set the default prim
+ temp_layer.defaultPrim=Sdf.Path(source_prim_path).name
+ # remove all physics from the stage
+ forprimintemp_stage.TraverseAll():
+ # skip if the prim is an instance
+ ifprim.IsInstanceable():
+ continue
+ # if prim has articulation then disable it
+ ifprim.HasAPI(UsdPhysics.ArticulationRootAPI):
+ prim.RemoveAPI(UsdPhysics.ArticulationRootAPI)
+ prim.RemoveAPI(PhysxSchema.PhysxArticulationAPI)
+ # if prim has rigid body then disable it
+ ifprim.HasAPI(UsdPhysics.RigidBodyAPI):
+ prim.RemoveAPI(UsdPhysics.RigidBodyAPI)
+ prim.RemoveAPI(PhysxSchema.PhysxRigidBodyAPI)
+ # if prim is a joint type then disable it
+ ifprim.IsA(UsdPhysics.Joint):
+ prim.GetAttribute("physics:jointEnabled").Set(False)
+ # resolve all paths relative to layer path
+ omni.usd.resolve_paths(source_layer.identifier,temp_layer.identifier)
+ # save the stage
+ temp_layer.Save()
+ # print the path to the saved stage
+ print("Recording completed.")
+ print(f"\tSaved recorded stage to : {stage_usd_path}")
+ print(f"\tSaved recorded animation to: {os.path.join(self.animation_log_dir,'TimeSample_tk001.usd')}")
+ print("\nTo play the animation, check the instructions in the following link:")
+ print(
+ "\thttps://docs.omniverse.nvidia.com/extensions/latest/ext_animation_stage-recorder.html#using-the-captured-timesamples"
+ )
+ print("\n")
+ # reset the log directory
+ self.animation_log_dir=None
+
+ def_set_viewer_origin_type_fn(self,value:str):
+"""Sets the origin of the viewport's camera. This is based on the drop-down menu in the UI."""
+ # Extract the viewport camera controller from environment
+ vcc=self.env.viewport_camera_controller
+ ifvccisNone:
+ raiseValueError("Viewport camera controller is not initialized! Please check the rendering mode.")
+
+ # Based on origin type, update the camera view
+ ifvalue=="World":
+ vcc.update_view_to_world()
+ elifvalue=="Env":
+ vcc.update_view_to_env()
+ else:
+ # find which index the asset is
+ fancy_names=[name.replace("_"," ").title()fornameinself._viewer_assets_options]
+ # store the desired env index
+ viewer_asset_name=self._viewer_assets_options[fancy_names.index(value)]
+ # update the camera view
+ vcc.update_view_to_asset_root(viewer_asset_name)
+
+ def_set_viewer_location_fn(self,model:omni.ui.SimpleFloatModel):
+"""Sets the viewport camera location based on the UI."""
+ # access the viewport camera controller (for brevity)
+ vcc=self.env.viewport_camera_controller
+ ifvccisNone:
+ raiseValueError("Viewport camera controller is not initialized! Please check the rendering mode.")
+ # obtain the camera locations and set them in the viewpoint camera controller
+ eye=[self.ui_window_elements["viewer_eye"][i].get_value_as_float()foriinrange(3)]
+ lookat=[self.ui_window_elements["viewer_lookat"][i].get_value_as_float()foriinrange(3)]
+ # update the camera view
+ vcc.update_view_location(eye,lookat)
+
+ def_set_viewer_env_index_fn(self,model:omni.ui.SimpleIntModel):
+"""Sets the environment index and updates the camera if in 'env' origin mode."""
+ # access the viewport camera controller (for brevity)
+ vcc=self.env.viewport_camera_controller
+ ifvccisNone:
+ raiseValueError("Viewport camera controller is not initialized! Please check the rendering mode.")
+ # store the desired env index, UI is 1-indexed
+ vcc.set_view_env_index(model.as_int-1)
+
+"""
+ Helper functions - UI building.
+ """
+
+ def_create_debug_vis_ui_element(self,name:str,elem:object):
+"""Create a checkbox for toggling debug visualization for the given element."""
+ fromomni.kit.window.extensionsimportSimpleCheckBox
+
+ withomni.ui.HStack():
+ # create the UI element
+ text=(
+ "Toggle debug visualization."
+ ifelem.has_debug_vis_implementation
+ else"Debug visualization not implemented."
+ )
+ omni.ui.Label(
+ name.replace("_"," ").title(),
+ width=omni.isaac.ui.ui_utils.LABEL_WIDTH-12,
+ alignment=omni.ui.Alignment.LEFT_CENTER,
+ tooltip=text,
+ )
+ self.ui_window_elements[f"{name}_cb"]=SimpleCheckBox(
+ model=omni.ui.SimpleBoolModel(),
+ enabled=elem.has_debug_vis_implementation,
+ checked=elem.cfg.debug_vis,
+ on_checked_fn=lambdavalue,e=weakref.proxy(elem):e.set_debug_vis(value),
+ )
+ omni.isaac.ui.ui_utils.add_line_rect_flourish()
+
+ asyncdef_dock_window(self,window_title:str):
+"""Docks the custom UI window to the property window."""
+ # wait for the window to be created
+ for_inrange(5):
+ ifomni.ui.Workspace.get_window(window_title):
+ break
+ awaitself.env.sim.app.next_update_async()
+
+ # dock next to properties window
+ custom_window=omni.ui.Workspace.get_window(window_title)
+ property_window=omni.ui.Workspace.get_window("Property")
+ ifcustom_windowandproperty_window:
+ custom_window.dock_in(property_window,omni.ui.DockPosition.SAME,1.0)
+ custom_window.focus()
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromtypingimportTYPE_CHECKING
+
+from.base_env_windowimportBaseEnvWindow
+
+ifTYPE_CHECKING:
+ from..manager_based_rl_envimportManagerBasedRLEnv
+
+
+
[文档]classManagerBasedRLEnvWindow(BaseEnvWindow):
+"""Window manager for the RL environment.
+
+ On top of the basic environment window, this class adds controls for the RL environment.
+ This includes visualization of the command manager.
+ """
+
+
[文档]def__init__(self,env:ManagerBasedRLEnv,window_name:str="IsaacLab"):
+"""Initialize the window.
+
+ Args:
+ env: The environment object.
+ window_name: The name of the window. Defaults to "IsaacLab".
+ """
+ # initialize base window
+ super().__init__(env,window_name)
+
+ # add custom UI elements
+ withself.ui_window_elements["main_vstack"]:
+ withself.ui_window_elements["debug_frame"]:
+ withself.ui_window_elements["debug_vstack"]:
+ self._create_debug_vis_ui_element("commands",self.env.command_manager)
+ self._create_debug_vis_ui_element("actions",self.env.action_manager)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importcopy
+importnumpyasnp
+importtorch
+importweakref
+fromcollections.abcimportSequence
+fromtypingimportTYPE_CHECKING
+
+importomni.kit.app
+importomni.timeline
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportDirectRLEnv,ManagerBasedEnv,ViewerCfg
+
+
+
[文档]classViewportCameraController:
+"""This class handles controlling the camera associated with a viewport in the simulator.
+
+ It can be used to set the viewpoint camera to track different origin types:
+
+ - **world**: the center of the world (static)
+ - **env**: the center of an environment (static)
+ - **asset_root**: the root of an asset in the scene (e.g. tracking a robot moving in the scene)
+
+ On creation, the camera is set to track the origin type specified in the configuration.
+
+ For the :attr:`asset_root` origin type, the camera is updated at each rendering step to track the asset's
+ root position. For this, it registers a callback to the post update event stream from the simulation app.
+ """
+
+
[文档]def__init__(self,env:ManagerBasedEnv|DirectRLEnv,cfg:ViewerCfg):
+"""Initialize the ViewportCameraController.
+
+ Args:
+ env: The environment.
+ cfg: The configuration for the viewport camera controller.
+
+ Raises:
+ ValueError: If origin type is configured to be "env" but :attr:`cfg.env_index` is out of bounds.
+ ValueError: If origin type is configured to be "asset_root" but :attr:`cfg.asset_name` is unset.
+
+ """
+ # store inputs
+ self._env=env
+ self._cfg=copy.deepcopy(cfg)
+ # cast viewer eye and look-at to numpy arrays
+ self.default_cam_eye=np.array(self._cfg.eye)
+ self.default_cam_lookat=np.array(self._cfg.lookat)
+
+ # set the camera origins
+ ifself.cfg.origin_type=="env":
+ # check that the env_index is within bounds
+ self.set_view_env_index(self.cfg.env_index)
+ # set the camera origin to the center of the environment
+ self.update_view_to_env()
+ elifself.cfg.origin_type=="asset_root":
+ # note: we do not yet update camera for tracking an asset origin, as the asset may not yet be
+ # in the scene when this is called. Instead, we subscribe to the post update event to update the camera
+ # at each rendering step.
+ ifself.cfg.asset_nameisNone:
+ raiseValueError(f"No asset name provided for viewer with origin type: '{self.cfg.origin_type}'.")
+ else:
+ # set the camera origin to the center of the world
+ self.update_view_to_world()
+
+ # subscribe to post update event so that camera view can be updated at each rendering step
+ app_interface=omni.kit.app.get_app_interface()
+ app_event_stream=app_interface.get_post_update_event_stream()
+ self._viewport_camera_update_handle=app_event_stream.create_subscription_to_pop(
+ lambdaevent,obj=weakref.proxy(self):obj._update_tracking_callback(event)
+ )
+
+ def__del__(self):
+"""Unsubscribe from the callback."""
+ # use hasattr to handle case where __init__ has not completed before __del__ is called
+ ifhasattr(self,"_viewport_camera_update_handle")andself._viewport_camera_update_handleisnotNone:
+ self._viewport_camera_update_handle.unsubscribe()
+ self._viewport_camera_update_handle=None
+
+"""
+ Properties
+ """
+
+ @property
+ defcfg(self)->ViewerCfg:
+"""The configuration for the viewer."""
+ returnself._cfg
+
+"""
+ Public Functions
+ """
+
+
[文档]defset_view_env_index(self,env_index:int):
+"""Sets the environment index for the camera view.
+
+ Args:
+ env_index: The index of the environment to set the camera view to.
+
+ Raises:
+ ValueError: If the environment index is out of bounds. It should be between 0 and num_envs - 1.
+ """
+ # check that the env_index is within bounds
+ ifenv_index<0orenv_index>=self._env.num_envs:
+ raiseValueError(
+ f"Out of range value for attribute 'env_index': {env_index}."
+ f" Expected a value between 0 and {self._env.num_envs-1} for the current environment."
+ )
+ # update the environment index
+ self.cfg.env_index=env_index
+ # update the camera view if the origin is set to env type (since, the camera view is static)
+ # note: for assets, the camera view is updated at each rendering step
+ ifself.cfg.origin_type=="env":
+ self.update_view_to_env()
+
+
[文档]defupdate_view_to_world(self):
+"""Updates the viewer's origin to the origin of the world which is (0, 0, 0)."""
+ # set origin type to world
+ self.cfg.origin_type="world"
+ # update the camera origins
+ self.viewer_origin=torch.zeros(3)
+ # update the camera view
+ self.update_view_location()
+
+
[文档]defupdate_view_to_env(self):
+"""Updates the viewer's origin to the origin of the selected environment."""
+ # set origin type to world
+ self.cfg.origin_type="env"
+ # update the camera origins
+ self.viewer_origin=self._env.scene.env_origins[self.cfg.env_index]
+ # update the camera view
+ self.update_view_location()
+
+
[文档]defupdate_view_to_asset_root(self,asset_name:str):
+"""Updates the viewer's origin based upon the root of an asset in the scene.
+
+ Args:
+ asset_name: The name of the asset in the scene. The name should match the name of the
+ asset in the scene.
+
+ Raises:
+ ValueError: If the asset is not in the scene.
+ """
+ # check if the asset is in the scene
+ ifself.cfg.asset_name!=asset_name:
+ asset_entities=[*self._env.scene.rigid_objects.keys(),*self._env.scene.articulations.keys()]
+ ifasset_namenotinasset_entities:
+ raiseValueError(f"Asset '{asset_name}' is not in the scene. Available entities: {asset_entities}.")
+ # update the asset name
+ self.cfg.asset_name=asset_name
+ # set origin type to asset_root
+ self.cfg.origin_type="asset_root"
+ # update the camera origins
+ self.viewer_origin=self._env.scene[self.cfg.asset_name].data.root_pos_w[self.cfg.env_index]
+ # update the camera view
+ self.update_view_location()
+
+
[文档]defupdate_view_location(self,eye:Sequence[float]|None=None,lookat:Sequence[float]|None=None):
+"""Updates the camera view pose based on the current viewer origin and the eye and lookat positions.
+
+ Args:
+ eye: The eye position of the camera. If None, the current eye position is used.
+ lookat: The lookat position of the camera. If None, the current lookat position is used.
+ """
+ # store the camera view pose for later use
+ ifeyeisnotNone:
+ self.default_cam_eye=np.asarray(eye)
+ iflookatisnotNone:
+ self.default_cam_lookat=np.asarray(lookat)
+ # set the camera locations
+ viewer_origin=self.viewer_origin.detach().cpu().numpy()
+ cam_eye=viewer_origin+self.default_cam_eye
+ cam_target=viewer_origin+self.default_cam_lookat
+
+ # set the camera view
+ self._env.sim.set_camera_view(eye=cam_eye,target=cam_target)
+
+"""
+ Private Functions
+ """
+
+ def_update_tracking_callback(self,event):
+"""Updates the camera view at each rendering step."""
+ # update the camera view if the origin is set to asset_root
+ # in other cases, the camera view is static and does not need to be updated continuously
+ ifself.cfg.origin_type=="asset_root"andself.cfg.asset_nameisnotNone:
+ self.update_view_to_asset_root(self.cfg.asset_name)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Action manager for processing actions sent to the environment."""
+
+from__future__importannotations
+
+importinspect
+importtorch
+importweakref
+fromabcimportabstractmethod
+fromcollections.abcimportSequence
+fromprettytableimportPrettyTable
+fromtypingimportTYPE_CHECKING
+
+importomni.kit.app
+
+fromomni.isaac.lab.assetsimportAssetBase
+
+from.manager_baseimportManagerBase,ManagerTermBase
+from.manager_term_cfgimportActionTermCfg
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedEnv
+
+
+
[文档]classActionTerm(ManagerTermBase):
+"""Base class for action terms.
+
+ The action term is responsible for processing the raw actions sent to the environment
+ and applying them to the asset managed by the term. The action term is comprised of two
+ operations:
+
+ * Processing of actions: This operation is performed once per **environment step** and
+ is responsible for pre-processing the raw actions sent to the environment.
+ * Applying actions: This operation is performed once per **simulation step** and is
+ responsible for applying the processed actions to the asset managed by the term.
+ """
+
+
[文档]def__init__(self,cfg:ActionTermCfg,env:ManagerBasedEnv):
+"""Initialize the action term.
+
+ Args:
+ cfg: The configuration object.
+ env: The environment instance.
+ """
+ # call the base class constructor
+ super().__init__(cfg,env)
+ # parse config to obtain asset to which the term is applied
+ self._asset:AssetBase=self._env.scene[self.cfg.asset_name]
+
+ # add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
+ self._debug_vis_handle=None
+ # set initial state of debug visualization
+ self.set_debug_vis(self.cfg.debug_vis)
+
+ def__del__(self):
+"""Unsubscribe from the callbacks."""
+ ifself._debug_vis_handle:
+ self._debug_vis_handle.unsubscribe()
+ self._debug_vis_handle=None
+
+"""
+ Properties.
+ """
+
+ @property
+ @abstractmethod
+ defaction_dim(self)->int:
+"""Dimension of the action term."""
+ raiseNotImplementedError
+
+ @property
+ @abstractmethod
+ defraw_actions(self)->torch.Tensor:
+"""The input/raw actions sent to the term."""
+ raiseNotImplementedError
+
+ @property
+ @abstractmethod
+ defprocessed_actions(self)->torch.Tensor:
+"""The actions computed by the term after applying any processing."""
+ raiseNotImplementedError
+
+ @property
+ defhas_debug_vis_implementation(self)->bool:
+"""Whether the action term has a debug visualization implemented."""
+ # check if function raises NotImplementedError
+ source_code=inspect.getsource(self._set_debug_vis_impl)
+ return"NotImplementedError"notinsource_code
+
+"""
+ Operations.
+ """
+
+
[文档]defset_debug_vis(self,debug_vis:bool)->bool:
+"""Sets whether to visualize the action term data.
+ Args:
+ debug_vis: Whether to visualize the action term data.
+ Returns:
+ Whether the debug visualization was successfully set. False if the action term does
+ not support debug visualization.
+ """
+ # check if debug visualization is supported
+ ifnotself.has_debug_vis_implementation:
+ returnFalse
+ # toggle debug visualization objects
+ self._set_debug_vis_impl(debug_vis)
+ # toggle debug visualization handles
+ ifdebug_vis:
+ # create a subscriber for the post update event if it doesn't exist
+ ifself._debug_vis_handleisNone:
+ app_interface=omni.kit.app.get_app_interface()
+ self._debug_vis_handle=app_interface.get_post_update_event_stream().create_subscription_to_pop(
+ lambdaevent,obj=weakref.proxy(self):obj._debug_vis_callback(event)
+ )
+ else:
+ # remove the subscriber if it exists
+ ifself._debug_vis_handleisnotNone:
+ self._debug_vis_handle.unsubscribe()
+ self._debug_vis_handle=None
+ # return success
+ returnTrue
+
+
[文档]@abstractmethod
+ defprocess_actions(self,actions:torch.Tensor):
+"""Processes the actions sent to the environment.
+
+ Note:
+ This function is called once per environment step by the manager.
+
+ Args:
+ actions: The actions to process.
+ """
+ raiseNotImplementedError
+
+
[文档]@abstractmethod
+ defapply_actions(self):
+"""Applies the actions to the asset managed by the term.
+
+ Note:
+ This is called at every simulation step by the manager.
+ """
+ raiseNotImplementedError
+
+ def_set_debug_vis_impl(self,debug_vis:bool):
+"""Set debug visualization into visualization objects.
+ This function is responsible for creating the visualization objects if they don't exist
+ and input ``debug_vis`` is True. If the visualization objects exist, the function should
+ set their visibility into the stage.
+ """
+ raiseNotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
+
+ def_debug_vis_callback(self,event):
+"""Callback for debug visualization.
+ This function calls the visualization objects and sets the data to visualize into them.
+ """
+ raiseNotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
+
+
+
[文档]classActionManager(ManagerBase):
+"""Manager for processing and applying actions for a given world.
+
+ The action manager handles the interpretation and application of user-defined
+ actions on a given world. It is comprised of different action terms that decide
+ the dimension of the expected actions.
+
+ The action manager performs operations at two stages:
+
+ * processing of actions: It splits the input actions to each term and performs any
+ pre-processing needed. This should be called once at every environment step.
+ * apply actions: This operation typically sets the processed actions into the assets in the
+ scene (such as robots). It should be called before every simulation step.
+ """
+
+
[文档]def__init__(self,cfg:object,env:ManagerBasedEnv):
+"""Initialize the action manager.
+
+ Args:
+ cfg: The configuration object or dictionary (``dict[str, ActionTermCfg]``).
+ env: The environment instance.
+ """
+ super().__init__(cfg,env)
+ # create buffers to store actions
+ self._action=torch.zeros((self.num_envs,self.total_action_dim),device=self.device)
+ self._prev_action=torch.zeros_like(self._action)
+
+ self.cfg.debug_vis=False
+ forterminself._terms.values():
+ self.cfg.debug_vis|=term.cfg.debug_vis
+
+ def__str__(self)->str:
+"""Returns: A string representation for action manager."""
+ msg=f"<ActionManager> contains {len(self._term_names)} active terms.\n"
+
+ # create table for term information
+ table=PrettyTable()
+ table.title=f"Active Action Terms (shape: {self.total_action_dim})"
+ table.field_names=["Index","Name","Dimension"]
+ # set alignment of table columns
+ table.align["Name"]="l"
+ table.align["Dimension"]="r"
+ # add info on each term
+ forindex,(name,term)inenumerate(self._terms.items()):
+ table.add_row([index,name,term.action_dim])
+ # convert table to string
+ msg+=table.get_string()
+ msg+="\n"
+
+ returnmsg
+
+"""
+ Properties.
+ """
+
+ @property
+ deftotal_action_dim(self)->int:
+"""Total dimension of actions."""
+ returnsum(self.action_term_dim)
+
+ @property
+ defactive_terms(self)->list[str]:
+"""Name of active action terms."""
+ returnself._term_names
+
+ @property
+ defaction_term_dim(self)->list[int]:
+"""Shape of each action term."""
+ return[term.action_dimforterminself._terms.values()]
+
+ @property
+ defaction(self)->torch.Tensor:
+"""The actions sent to the environment. Shape is (num_envs, total_action_dim)."""
+ returnself._action
+
+ @property
+ defprev_action(self)->torch.Tensor:
+"""The previous actions sent to the environment. Shape is (num_envs, total_action_dim)."""
+ returnself._prev_action
+
+ @property
+ defhas_debug_vis_implementation(self)->bool:
+"""Whether the command terms have debug visualization implemented."""
+ # check if function raises NotImplementedError
+ has_debug_vis=False
+ forterminself._terms.values():
+ has_debug_vis|=term.has_debug_vis_implementation
+ returnhas_debug_vis
+
+"""
+ Operations.
+ """
+
+
[文档]defset_debug_vis(self,debug_vis:bool)->bool:
+"""Sets whether to visualize the action data.
+ Args:
+ debug_vis: Whether to visualize the action data.
+ Returns:
+ Whether the debug visualization was successfully set. False if the action
+ does not support debug visualization.
+ """
+ forterminself._terms.values():
+ term.set_debug_vis(debug_vis)
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None)->dict[str,torch.Tensor]:
+"""Resets the action history.
+
+ Args:
+ env_ids: The environment ids. Defaults to None, in which case
+ all environments are considered.
+
+ Returns:
+ An empty dictionary.
+ """
+ # resolve environment ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # reset the action history
+ self._prev_action[env_ids]=0.0
+ self._action[env_ids]=0.0
+ # reset all action terms
+ forterminself._terms.values():
+ term.reset(env_ids=env_ids)
+ # nothing to log here
+ return{}
+
+
[文档]defprocess_action(self,action:torch.Tensor):
+"""Processes the actions sent to the environment.
+
+ Note:
+ This function should be called once per environment step.
+
+ Args:
+ action: The actions to process.
+ """
+ # check if action dimension is valid
+ ifself.total_action_dim!=action.shape[1]:
+ raiseValueError(f"Invalid action shape, expected: {self.total_action_dim}, received: {action.shape[1]}.")
+ # store the input actions
+ self._prev_action[:]=self._action
+ self._action[:]=action.to(self.device)
+
+ # split the actions and apply to each tensor
+ idx=0
+ forterminself._terms.values():
+ term_actions=action[:,idx:idx+term.action_dim]
+ term.process_actions(term_actions)
+ idx+=term.action_dim
+
+
[文档]defapply_action(self)->None:
+"""Applies the actions to the environment/simulation.
+
+ Note:
+ This should be called at every simulation step.
+ """
+ forterminself._terms.values():
+ term.apply_actions()
+
+
[文档]defget_term(self,name:str)->ActionTerm:
+"""Returns the action term with the specified name.
+
+ Args:
+ name: The name of the action term.
+
+ Returns:
+ The action term with the specified name.
+ """
+ returnself._terms[name]
+
+"""
+ Helper functions.
+ """
+
+ def_prepare_terms(self):
+"""Prepares a list of action terms."""
+ # parse action terms from the config
+ self._term_names:list[str]=list()
+ self._terms:dict[str,ActionTerm]=dict()
+
+ # check if config is dict already
+ ifisinstance(self.cfg,dict):
+ cfg_items=self.cfg.items()
+ else:
+ cfg_items=self.cfg.__dict__.items()
+ forterm_name,term_cfgincfg_items:
+ # check if term config is None
+ ifterm_cfgisNone:
+ continue
+ # check valid type
+ ifnotisinstance(term_cfg,ActionTermCfg):
+ raiseTypeError(
+ f"Configuration for the term '{term_name}' is not of type ActionTermCfg."
+ f" Received: '{type(term_cfg)}'."
+ )
+ # create the action term
+ term=term_cfg.class_type(term_cfg,self._env)
+ # sanity check if term is valid type
+ ifnotisinstance(term,ActionTerm):
+ raiseTypeError(f"Returned object for the term '{term_name}' is not of type ActionType.")
+ # add term name and parameters
+ self._term_names.append(term_name)
+ self._terms[term_name]=term
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Command manager for generating and updating commands."""
+
+from__future__importannotations
+
+importinspect
+importtorch
+importweakref
+fromabcimportabstractmethod
+fromcollections.abcimportSequence
+fromprettytableimportPrettyTable
+fromtypingimportTYPE_CHECKING
+
+importomni.kit.app
+
+from.manager_baseimportManagerBase,ManagerTermBase
+from.manager_term_cfgimportCommandTermCfg
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedRLEnv
+
+
+
[文档]classCommandTerm(ManagerTermBase):
+"""The base class for implementing a command term.
+
+ A command term is used to generate commands for goal-conditioned tasks. For example,
+ in the case of a goal-conditioned navigation task, the command term can be used to
+ generate a target position for the robot to navigate to.
+
+ It implements a resampling mechanism that allows the command to be resampled at a fixed
+ frequency. The resampling frequency can be specified in the configuration object.
+ Additionally, it is possible to assign a visualization function to the command term
+ that can be used to visualize the command in the simulator.
+ """
+
+ def__init__(self,cfg:CommandTermCfg,env:ManagerBasedRLEnv):
+"""Initialize the command generator class.
+
+ Args:
+ cfg: The configuration parameters for the command generator.
+ env: The environment object.
+ """
+ super().__init__(cfg,env)
+
+ # create buffers to store the command
+ # -- metrics that can be used for logging
+ self.metrics=dict()
+ # -- time left before resampling
+ self.time_left=torch.zeros(self.num_envs,device=self.device)
+ # -- counter for the number of times the command has been resampled within the current episode
+ self.command_counter=torch.zeros(self.num_envs,device=self.device,dtype=torch.long)
+
+ # add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
+ self._debug_vis_handle=None
+ # set initial state of debug visualization
+ self.set_debug_vis(self.cfg.debug_vis)
+
+ def__del__(self):
+"""Unsubscribe from the callbacks."""
+ ifself._debug_vis_handle:
+ self._debug_vis_handle.unsubscribe()
+ self._debug_vis_handle=None
+
+"""
+ Properties
+ """
+
+ @property
+ @abstractmethod
+ defcommand(self)->torch.Tensor:
+"""The command tensor. Shape is (num_envs, command_dim)."""
+ raiseNotImplementedError
+
+ @property
+ defhas_debug_vis_implementation(self)->bool:
+"""Whether the command generator has a debug visualization implemented."""
+ # check if function raises NotImplementedError
+ source_code=inspect.getsource(self._set_debug_vis_impl)
+ return"NotImplementedError"notinsource_code
+
+"""
+ Operations.
+ """
+
+
[文档]defset_debug_vis(self,debug_vis:bool)->bool:
+"""Sets whether to visualize the command data.
+
+ Args:
+ debug_vis: Whether to visualize the command data.
+
+ Returns:
+ Whether the debug visualization was successfully set. False if the command
+ generator does not support debug visualization.
+ """
+ # check if debug visualization is supported
+ ifnotself.has_debug_vis_implementation:
+ returnFalse
+ # toggle debug visualization objects
+ self._set_debug_vis_impl(debug_vis)
+ # toggle debug visualization handles
+ ifdebug_vis:
+ # create a subscriber for the post update event if it doesn't exist
+ ifself._debug_vis_handleisNone:
+ app_interface=omni.kit.app.get_app_interface()
+ self._debug_vis_handle=app_interface.get_post_update_event_stream().create_subscription_to_pop(
+ lambdaevent,obj=weakref.proxy(self):obj._debug_vis_callback(event)
+ )
+ else:
+ # remove the subscriber if it exists
+ ifself._debug_vis_handleisnotNone:
+ self._debug_vis_handle.unsubscribe()
+ self._debug_vis_handle=None
+ # return success
+ returnTrue
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None)->dict[str,float]:
+"""Reset the command generator and log metrics.
+
+ This function resets the command counter and resamples the command. It should be called
+ at the beginning of each episode.
+
+ Args:
+ env_ids: The list of environment IDs to reset. Defaults to None.
+
+ Returns:
+ A dictionary containing the information to log under the "{name}" key.
+ """
+ # resolve the environment IDs
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # set the command counter to zero
+ self.command_counter[env_ids]=0
+ # resample the command
+ self._resample(env_ids)
+ # add logging metrics
+ extras={}
+ formetric_name,metric_valueinself.metrics.items():
+ # compute the mean metric value
+ extras[metric_name]=torch.mean(metric_value[env_ids]).item()
+ # reset the metric value
+ metric_value[env_ids]=0.0
+ returnextras
+
+
[文档]defcompute(self,dt:float):
+"""Compute the command.
+
+ Args:
+ dt: The time step passed since the last call to compute.
+ """
+ # update the metrics based on current state
+ self._update_metrics()
+ # reduce the time left before resampling
+ self.time_left-=dt
+ # resample the command if necessary
+ resample_env_ids=(self.time_left<=0.0).nonzero().flatten()
+ iflen(resample_env_ids)>0:
+ self._resample(resample_env_ids)
+ # update the command
+ self._update_command()
+
+"""
+ Helper functions.
+ """
+
+ def_resample(self,env_ids:Sequence[int]):
+"""Resample the command.
+
+ This function resamples the command and time for which the command is applied for the
+ specified environment indices.
+
+ Args:
+ env_ids: The list of environment IDs to resample.
+ """
+ # resample the time left before resampling
+ iflen(env_ids)!=0:
+ self.time_left[env_ids]=self.time_left[env_ids].uniform_(*self.cfg.resampling_time_range)
+ # increment the command counter
+ self.command_counter[env_ids]+=1
+ # resample the command
+ self._resample_command(env_ids)
+
+"""
+ Implementation specific functions.
+ """
+
+ @abstractmethod
+ def_update_metrics(self):
+"""Update the metrics based on the current state."""
+ raiseNotImplementedError
+
+ @abstractmethod
+ def_resample_command(self,env_ids:Sequence[int]):
+"""Resample the command for the specified environments."""
+ raiseNotImplementedError
+
+ @abstractmethod
+ def_update_command(self):
+"""Update the command based on the current state."""
+ raiseNotImplementedError
+
+ def_set_debug_vis_impl(self,debug_vis:bool):
+"""Set debug visualization into visualization objects.
+
+ This function is responsible for creating the visualization objects if they don't exist
+ and input ``debug_vis`` is True. If the visualization objects exist, the function should
+ set their visibility into the stage.
+ """
+ raiseNotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
+
+ def_debug_vis_callback(self,event):
+"""Callback for debug visualization.
+
+ This function calls the visualization objects and sets the data to visualize into them.
+ """
+ raiseNotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
+
+
+
[文档]classCommandManager(ManagerBase):
+"""Manager for generating commands.
+
+ The command manager is used to generate commands for an agent to execute. It makes it convenient to switch
+ between different command generation strategies within the same environment. For instance, in an environment
+ consisting of a quadrupedal robot, the command to it could be a velocity command or position command.
+ By keeping the command generation logic separate from the environment, it is easy to switch between different
+ command generation strategies.
+
+ The command terms are implemented as classes that inherit from the :class:`CommandTerm` class.
+ Each command generator term should also have a corresponding configuration class that inherits from the
+ :class:`CommandTermCfg` class.
+ """
+
+ _env:ManagerBasedRLEnv
+"""The environment instance."""
+
+
[文档]def__init__(self,cfg:object,env:ManagerBasedRLEnv):
+"""Initialize the command manager.
+
+ Args:
+ cfg: The configuration object or dictionary (``dict[str, CommandTermCfg]``).
+ env: The environment instance.
+ """
+ super().__init__(cfg,env)
+ # store the commands
+ self._commands=dict()
+ self.cfg.debug_vis=False
+ forterminself._terms.values():
+ self.cfg.debug_vis|=term.cfg.debug_vis
+
+ def__str__(self)->str:
+"""Returns: A string representation for the command manager."""
+ msg=f"<CommandManager> contains {len(self._terms.values())} active terms.\n"
+
+ # create table for term information
+ table=PrettyTable()
+ table.title="Active Command Terms"
+ table.field_names=["Index","Name","Type"]
+ # set alignment of table columns
+ table.align["Name"]="l"
+ # add info on each term
+ forindex,(name,term)inenumerate(self._terms.items()):
+ table.add_row([index,name,term.__class__.__name__])
+ # convert table to string
+ msg+=table.get_string()
+ msg+="\n"
+
+ returnmsg
+
+"""
+ Properties.
+ """
+
+ @property
+ defactive_terms(self)->list[str]:
+"""Name of active command terms."""
+ returnlist(self._terms.keys())
+
+ @property
+ defhas_debug_vis_implementation(self)->bool:
+"""Whether the command terms have debug visualization implemented."""
+ # check if function raises NotImplementedError
+ has_debug_vis=False
+ forterminself._terms.values():
+ has_debug_vis|=term.has_debug_vis_implementation
+ returnhas_debug_vis
+
+"""
+ Operations.
+ """
+
+
[文档]defset_debug_vis(self,debug_vis:bool)->bool:
+"""Sets whether to visualize the command data.
+
+ Args:
+ debug_vis: Whether to visualize the command data.
+
+ Returns:
+ Whether the debug visualization was successfully set. False if the command
+ generator does not support debug visualization.
+ """
+ forterminself._terms.values():
+ term.set_debug_vis(debug_vis)
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None)->dict[str,torch.Tensor]:
+"""Reset the command terms and log their metrics.
+
+ This function resets the command counter and resamples the command for each term. It should be called
+ at the beginning of each episode.
+
+ Args:
+ env_ids: The list of environment IDs to reset. Defaults to None.
+
+ Returns:
+ A dictionary containing the information to log under the "Metrics/{term_name}/{metric_name}" key.
+ """
+ # resolve environment ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # store information
+ extras={}
+ forname,terminself._terms.items():
+ # reset the command term
+ metrics=term.reset(env_ids=env_ids)
+ # compute the mean metric value
+ formetric_name,metric_valueinmetrics.items():
+ extras[f"Metrics/{name}/{metric_name}"]=metric_value
+ # return logged information
+ returnextras
+
+
[文档]defcompute(self,dt:float):
+"""Updates the commands.
+
+ This function calls each command term managed by the class.
+
+ Args:
+ dt: The time-step interval of the environment.
+
+ """
+ # iterate over all the command terms
+ forterminself._terms.values():
+ # compute term's value
+ term.compute(dt)
+
+
[文档]defget_command(self,name:str)->torch.Tensor:
+"""Returns the command for the specified command term.
+
+ Args:
+ name: The name of the command term.
+
+ Returns:
+ The command tensor of the specified command term.
+ """
+ returnself._terms[name].command
+
+
[文档]defget_term(self,name:str)->CommandTerm:
+"""Returns the command term with the specified name.
+
+ Args:
+ name: The name of the command term.
+
+ Returns:
+ The command term with the specified name.
+ """
+ returnself._terms[name]
+
+"""
+ Helper functions.
+ """
+
+ def_prepare_terms(self):
+"""Prepares a list of command terms."""
+ # parse command terms from the config
+ self._terms:dict[str,CommandTerm]=dict()
+
+ # check if config is dict already
+ ifisinstance(self.cfg,dict):
+ cfg_items=self.cfg.items()
+ else:
+ cfg_items=self.cfg.__dict__.items()
+ # iterate over all the terms
+ forterm_name,term_cfgincfg_items:
+ # check for non config
+ ifterm_cfgisNone:
+ continue
+ # check for valid config type
+ ifnotisinstance(term_cfg,CommandTermCfg):
+ raiseTypeError(
+ f"Configuration for the term '{term_name}' is not of type CommandTermCfg."
+ f" Received: '{type(term_cfg)}'."
+ )
+ # create the action term
+ term=term_cfg.class_type(term_cfg,self._env)
+ # add class to dict
+ self._terms[term_name]=term
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Curriculum manager for updating environment quantities subject to a training curriculum."""
+
+from__future__importannotations
+
+importtorch
+fromcollections.abcimportSequence
+fromprettytableimportPrettyTable
+fromtypingimportTYPE_CHECKING
+
+from.manager_baseimportManagerBase,ManagerTermBase
+from.manager_term_cfgimportCurriculumTermCfg
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedRLEnv
+
+
+
[文档]classCurriculumManager(ManagerBase):
+"""Manager to implement and execute specific curricula.
+
+ The curriculum manager updates various quantities of the environment subject to a training curriculum by
+ calling a list of terms. These help stabilize learning by progressively making the learning tasks harder
+ as the agent improves.
+
+ The curriculum terms are parsed from a config class containing the manager's settings and each term's
+ parameters. Each curriculum term should instantiate the :class:`CurriculumTermCfg` class.
+ """
+
+ _env:ManagerBasedRLEnv
+"""The environment instance."""
+
+
[文档]def__init__(self,cfg:object,env:ManagerBasedRLEnv):
+"""Initialize the manager.
+
+ Args:
+ cfg: The configuration object or dictionary (``dict[str, CurriculumTermCfg]``)
+ env: An environment object.
+
+ Raises:
+ TypeError: If curriculum term is not of type :class:`CurriculumTermCfg`.
+ ValueError: If curriculum term configuration does not satisfy its function signature.
+ """
+ super().__init__(cfg,env)
+ # prepare logging
+ self._curriculum_state=dict()
+ forterm_nameinself._term_names:
+ self._curriculum_state[term_name]=None
+
+ def__str__(self)->str:
+"""Returns: A string representation for curriculum manager."""
+ msg=f"<CurriculumManager> contains {len(self._term_names)} active terms.\n"
+
+ # create table for term information
+ table=PrettyTable()
+ table.title="Active Curriculum Terms"
+ table.field_names=["Index","Name"]
+ # set alignment of table columns
+ table.align["Name"]="l"
+ # add info on each term
+ forindex,nameinenumerate(self._term_names):
+ table.add_row([index,name])
+ # convert table to string
+ msg+=table.get_string()
+ msg+="\n"
+
+ returnmsg
+
+"""
+ Properties.
+ """
+
+ @property
+ defactive_terms(self)->list[str]:
+"""Name of active curriculum terms."""
+ returnself._term_names
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None)->dict[str,float]:
+"""Returns the current state of individual curriculum terms.
+
+ Note:
+ This function does not use the environment indices :attr:`env_ids`
+ and logs the state of all the terms. The argument is only present
+ to maintain consistency with other classes.
+
+ Returns:
+ Dictionary of curriculum terms and their states.
+ """
+ extras={}
+ forterm_name,term_stateinself._curriculum_state.items():
+ ifterm_stateisnotNone:
+ # deal with dict
+ ifisinstance(term_state,dict):
+ # each key is a separate state to log
+ forkey,valueinterm_state.items():
+ ifisinstance(value,torch.Tensor):
+ value=value.item()
+ extras[f"Curriculum/{term_name}/{key}"]=value
+ else:
+ # log directly if not a dict
+ ifisinstance(term_state,torch.Tensor):
+ term_state=term_state.item()
+ extras[f"Curriculum/{term_name}"]=term_state
+ # reset all the curriculum terms
+ forterm_cfginself._class_term_cfgs:
+ term_cfg.func.reset(env_ids=env_ids)
+ # return logged information
+ returnextras
+
+
[文档]defcompute(self,env_ids:Sequence[int]|None=None):
+"""Update the curriculum terms.
+
+ This function calls each curriculum term managed by the class.
+
+ Args:
+ env_ids: The list of environment IDs to update.
+ If None, all the environments are updated. Defaults to None.
+ """
+ # resolve environment indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # iterate over all the curriculum terms
+ forname,term_cfginzip(self._term_names,self._term_cfgs):
+ state=term_cfg.func(self._env,env_ids,**term_cfg.params)
+ self._curriculum_state[name]=state
+
+"""
+ Helper functions.
+ """
+
+ def_prepare_terms(self):
+ # parse remaining curriculum terms and decimate their information
+ self._term_names:list[str]=list()
+ self._term_cfgs:list[CurriculumTermCfg]=list()
+ self._class_term_cfgs:list[CurriculumTermCfg]=list()
+
+ # check if config is dict already
+ ifisinstance(self.cfg,dict):
+ cfg_items=self.cfg.items()
+ else:
+ cfg_items=self.cfg.__dict__.items()
+ # iterate over all the terms
+ forterm_name,term_cfgincfg_items:
+ # check for non config
+ ifterm_cfgisNone:
+ continue
+ # check if the term is a valid term config
+ ifnotisinstance(term_cfg,CurriculumTermCfg):
+ raiseTypeError(
+ f"Configuration for the term '{term_name}' is not of type CurriculumTermCfg."
+ f" Received: '{type(term_cfg)}'."
+ )
+ # resolve common parameters
+ self._resolve_common_term_cfg(term_name,term_cfg,min_argc=2)
+ # add name and config to list
+ self._term_names.append(term_name)
+ self._term_cfgs.append(term_cfg)
+ # check if the term is a class
+ ifisinstance(term_cfg.func,ManagerTermBase):
+ self._class_term_cfgs.append(term_cfg)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Event manager for orchestrating operations based on different simulation events."""
+
+from__future__importannotations
+
+importtorch
+fromcollections.abcimportSequence
+fromprettytableimportPrettyTable
+fromtypingimportTYPE_CHECKING
+
+importcarb
+
+from.manager_baseimportManagerBase,ManagerTermBase
+from.manager_term_cfgimportEventTermCfg
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedEnv
+
+
+
[文档]classEventManager(ManagerBase):
+"""Manager for orchestrating operations based on different simulation events.
+
+ The event manager applies operations to the environment based on different simulation events. For example,
+ changing the masses of objects or their friction coefficients during initialization/ reset, or applying random
+ pushes to the robot at a fixed interval of steps. The user can specify several modes of events to fine-tune the
+ behavior based on when to apply the event.
+
+ The event terms are parsed from a config class containing the manager's settings and each term's
+ parameters. Each event term should instantiate the :class:`EventTermCfg` class.
+
+ Event terms can be grouped by their mode. The mode is a user-defined string that specifies when
+ the event term should be applied. This provides the user complete control over when event
+ terms should be applied.
+
+ For a typical training process, you may want to apply events in the following modes:
+
+ - "startup": Event is applied once at the beginning of the training.
+ - "reset": Event is applied at every reset.
+ - "interval": Event is applied at pre-specified intervals of time.
+
+ However, you can also define your own modes and use them in the training process as you see fit.
+ For this you will need to add the triggering of that mode in the environment implementation as well.
+
+ .. note::
+
+ The triggering of operations corresponding to the mode ``"interval"`` are the only mode that are
+ directly handled by the manager itself. The other modes are handled by the environment implementation.
+
+ """
+
+ _env:ManagerBasedEnv
+"""The environment instance."""
+
+
[文档]def__init__(self,cfg:object,env:ManagerBasedEnv):
+"""Initialize the event manager.
+
+ Args:
+ cfg: A configuration object or dictionary (``dict[str, EventTermCfg]``).
+ env: An environment object.
+ """
+ super().__init__(cfg,env)
+
+ def__str__(self)->str:
+"""Returns: A string representation for event manager."""
+ msg=f"<EventManager> contains {len(self._mode_term_names)} active terms.\n"
+
+ # add info on each mode
+ formodeinself._mode_term_names:
+ # create table for term information
+ table=PrettyTable()
+ table.title=f"Active Event Terms in Mode: '{mode}'"
+ # add table headers based on mode
+ ifmode=="interval":
+ table.field_names=["Index","Name","Interval time range (s)"]
+ table.align["Name"]="l"
+ forindex,(name,cfg)inenumerate(zip(self._mode_term_names[mode],self._mode_term_cfgs[mode])):
+ table.add_row([index,name,cfg.interval_range_s])
+ else:
+ table.field_names=["Index","Name"]
+ table.align["Name"]="l"
+ forindex,nameinenumerate(self._mode_term_names[mode]):
+ table.add_row([index,name])
+ # convert table to string
+ msg+=table.get_string()
+ msg+="\n"
+
+ returnmsg
+
+"""
+ Properties.
+ """
+
+ @property
+ defactive_terms(self)->dict[str,list[str]]:
+"""Name of active event terms.
+
+ The keys are the modes of event and the values are the names of the event terms.
+ """
+ returnself._mode_term_names
+
+ @property
+ defavailable_modes(self)->list[str]:
+"""Modes of events."""
+ returnlist(self._mode_term_names.keys())
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None)->dict[str,float]:
+ # call all terms that are classes
+ formode_cfginself._mode_class_term_cfgs.values():
+ forterm_cfginmode_cfg:
+ term_cfg.func.reset(env_ids=env_ids)
+ # nothing to log here
+ return{}
+
+
[文档]defapply(
+ self,
+ mode:str,
+ env_ids:Sequence[int]|None=None,
+ dt:float|None=None,
+ global_env_step_count:int|None=None,
+ ):
+"""Calls each event term in the specified mode.
+
+ This function iterates over all the event terms in the specified mode and calls the function
+ corresponding to the term. The function is called with the environment instance and the environment
+ indices to apply the event to.
+
+ For the "interval" mode, the function is called when the time interval has passed. This requires
+ specifying the time step of the environment.
+
+ For the "reset" mode, the function is called when the mode is "reset" and the total number of environment
+ steps that have happened since the last trigger of the function is equal to its configured parameter for
+ the number of environment steps between resets.
+
+ Args:
+ mode: The mode of event.
+ env_ids: The indices of the environments to apply the event to.
+ Defaults to None, in which case the event is applied to all environments when applicable.
+ dt: The time step of the environment. This is only used for the "interval" mode.
+ Defaults to None to simplify the call for other modes.
+ global_env_step_count: The total number of environment steps that have happened. This is only used
+ for the "reset" mode. Defaults to None to simplify the call for other modes.
+
+ Raises:
+ ValueError: If the mode is ``"interval"`` and the time step is not provided.
+ ValueError: If the mode is ``"interval"`` and the environment indices are provided. This is an undefined
+ behavior as the environment indices are computed based on the time left for each environment.
+ ValueError: If the mode is ``"reset"`` and the total number of environment steps that have happened
+ is not provided.
+ """
+ # check if mode is valid
+ ifmodenotinself._mode_term_names:
+ carb.log_warn(f"Event mode '{mode}' is not defined. Skipping event.")
+ return
+ # check if mode is interval and dt is not provided
+ ifmode=="interval"anddtisNone:
+ raiseValueError(f"Event mode '{mode}' requires the time-step of the environment.")
+ ifmode=="interval"andenv_idsisnotNone:
+ raiseValueError(
+ f"Event mode '{mode}' does not require environment indices. This is an undefined behavior"
+ " as the environment indices are computed based on the time left for each environment."
+ )
+ # check if mode is reset and env step count is not provided
+ ifmode=="reset"andglobal_env_step_countisNone:
+ raiseValueError(f"Event mode '{mode}' requires the total number of environment steps to be provided.")
+
+ # iterate over all the event terms
+ forindex,term_cfginenumerate(self._mode_term_cfgs[mode]):
+ ifmode=="interval":
+ # extract time left for this term
+ time_left=self._interval_term_time_left[index]
+ # update the time left for each environment
+ time_left-=dt
+
+ # check if the interval has passed and sample a new interval
+ # note: we compare with a small value to handle floating point errors
+ ifterm_cfg.is_global_time:
+ iftime_left<1e-6:
+ lower,upper=term_cfg.interval_range_s
+ sampled_interval=torch.rand(1)*(upper-lower)+lower
+ self._interval_term_time_left[index][:]=sampled_interval
+
+ # call the event term (with None for env_ids)
+ term_cfg.func(self._env,None,**term_cfg.params)
+ else:
+ valid_env_ids=(time_left<1e-6).nonzero().flatten()
+ iflen(valid_env_ids)>0:
+ lower,upper=term_cfg.interval_range_s
+ sampled_time=torch.rand(len(valid_env_ids),device=self.device)*(upper-lower)+lower
+ self._interval_term_time_left[index][valid_env_ids]=sampled_time
+
+ # call the event term
+ term_cfg.func(self._env,valid_env_ids,**term_cfg.params)
+ elifmode=="reset":
+ # obtain the minimum step count between resets
+ min_step_count=term_cfg.min_step_count_between_reset
+ # resolve the environment indices
+ ifenv_idsisNone:
+ env_ids=slice(None)
+
+ # We bypass the trigger mechanism if min_step_count is zero, i.e. apply term on every reset call.
+ # This should avoid the overhead of checking the trigger condition.
+ ifmin_step_count==0:
+ self._reset_term_last_triggered_step_id[index][env_ids]=global_env_step_count
+ self._reset_term_last_triggered_once[index][env_ids]=True
+
+ # call the event term with the environment indices
+ term_cfg.func(self._env,env_ids,**term_cfg.params)
+ else:
+ # extract last reset step for this term
+ last_triggered_step=self._reset_term_last_triggered_step_id[index][env_ids]
+ triggered_at_least_once=self._reset_term_last_triggered_once[index][env_ids]
+ # compute the steps since last reset
+ steps_since_triggered=global_env_step_count-last_triggered_step
+
+ # check if the term can be applied after the minimum step count between triggers has passed
+ valid_trigger=steps_since_triggered>=min_step_count
+ # check if the term has not been triggered yet (in that case, we trigger it at least once)
+ # this is usually only needed at the start of the environment
+ valid_trigger|=(last_triggered_step==0)&~triggered_at_least_once
+
+ # select the valid environment indices based on the trigger
+ ifenv_ids==slice(None):
+ valid_env_ids=valid_trigger.nonzero().flatten()
+ else:
+ valid_env_ids=env_ids[valid_trigger]
+
+ # reset the last reset step for each environment to the current env step count
+ iflen(valid_env_ids)>0:
+ self._reset_term_last_triggered_once[index][valid_env_ids]=True
+ self._reset_term_last_triggered_step_id[index][valid_env_ids]=global_env_step_count
+
+ # call the event term
+ term_cfg.func(self._env,valid_env_ids,**term_cfg.params)
+ else:
+ # call the event term
+ term_cfg.func(self._env,env_ids,**term_cfg.params)
+
+"""
+ Operations - Term settings.
+ """
+
+
[文档]defset_term_cfg(self,term_name:str,cfg:EventTermCfg):
+"""Sets the configuration of the specified term into the manager.
+
+ The method finds the term by name by searching through all the modes.
+ It then updates the configuration of the term with the first matching name.
+
+ Args:
+ term_name: The name of the event term.
+ cfg: The configuration for the event term.
+
+ Raises:
+ ValueError: If the term name is not found.
+ """
+ term_found=False
+ formode,termsinself._mode_term_names.items():
+ ifterm_nameinterms:
+ self._mode_term_cfgs[mode][terms.index(term_name)]=cfg
+ term_found=True
+ break
+ ifnotterm_found:
+ raiseValueError(f"Event term '{term_name}' not found.")
+
+
[文档]defget_term_cfg(self,term_name:str)->EventTermCfg:
+"""Gets the configuration for the specified term.
+
+ The method finds the term by name by searching through all the modes.
+ It then returns the configuration of the term with the first matching name.
+
+ Args:
+ term_name: The name of the event term.
+
+ Returns:
+ The configuration of the event term.
+
+ Raises:
+ ValueError: If the term name is not found.
+ """
+ formode,termsinself._mode_term_names.items():
+ ifterm_nameinterms:
+ returnself._mode_term_cfgs[mode][terms.index(term_name)]
+ raiseValueError(f"Event term '{term_name}' not found.")
+
+"""
+ Helper functions.
+ """
+
+ def_prepare_terms(self):
+"""Prepares a list of event functions."""
+ # parse remaining event terms and decimate their information
+ self._mode_term_names:dict[str,list[str]]=dict()
+ self._mode_term_cfgs:dict[str,list[EventTermCfg]]=dict()
+ self._mode_class_term_cfgs:dict[str,list[EventTermCfg]]=dict()
+ # buffer to store the time left for "interval" mode
+ # if interval is global, then it is a single value, otherwise it is per environment
+ self._interval_term_time_left:list[torch.Tensor]=list()
+ # buffer to store the step count when the term was last triggered for each environment for "reset" mode
+ self._reset_term_last_triggered_step_id:list[torch.Tensor]=list()
+ self._reset_term_last_triggered_once:list[torch.Tensor]=list()
+
+ # check if config is dict already
+ ifisinstance(self.cfg,dict):
+ cfg_items=self.cfg.items()
+ else:
+ cfg_items=self.cfg.__dict__.items()
+ # iterate over all the terms
+ forterm_name,term_cfgincfg_items:
+ # check for non config
+ ifterm_cfgisNone:
+ continue
+ # check for valid config type
+ ifnotisinstance(term_cfg,EventTermCfg):
+ raiseTypeError(
+ f"Configuration for the term '{term_name}' is not of type EventTermCfg."
+ f" Received: '{type(term_cfg)}'."
+ )
+
+ ifterm_cfg.mode!="reset"andterm_cfg.min_step_count_between_reset!=0:
+ carb.log_warn(
+ f"Event term '{term_name}' has 'min_step_count_between_reset' set to a non-zero value"
+ " but the mode is not 'reset'. Ignoring the 'min_step_count_between_reset' value."
+ )
+
+ # resolve common parameters
+ self._resolve_common_term_cfg(term_name,term_cfg,min_argc=2)
+ # check if mode is a new mode
+ ifterm_cfg.modenotinself._mode_term_names:
+ # add new mode
+ self._mode_term_names[term_cfg.mode]=list()
+ self._mode_term_cfgs[term_cfg.mode]=list()
+ self._mode_class_term_cfgs[term_cfg.mode]=list()
+ # add term name and parameters
+ self._mode_term_names[term_cfg.mode].append(term_name)
+ self._mode_term_cfgs[term_cfg.mode].append(term_cfg)
+
+ # check if the term is a class
+ ifisinstance(term_cfg.func,ManagerTermBase):
+ self._mode_class_term_cfgs[term_cfg.mode].append(term_cfg)
+
+ # resolve the mode of the events
+ # -- interval mode
+ ifterm_cfg.mode=="interval":
+ ifterm_cfg.interval_range_sisNone:
+ raiseValueError(
+ f"Event term '{term_name}' has mode 'interval' but 'interval_range_s' is not specified."
+ )
+
+ # sample the time left for global
+ ifterm_cfg.is_global_time:
+ lower,upper=term_cfg.interval_range_s
+ time_left=torch.rand(1)*(upper-lower)+lower
+ self._interval_term_time_left.append(time_left)
+ else:
+ # sample the time left for each environment
+ lower,upper=term_cfg.interval_range_s
+ time_left=torch.rand(self.num_envs,device=self.device)*(upper-lower)+lower
+ self._interval_term_time_left.append(time_left)
+ # -- reset mode
+ elifterm_cfg.mode=="reset":
+ ifterm_cfg.min_step_count_between_reset<0:
+ raiseValueError(
+ f"Event term '{term_name}' has mode 'reset' but 'min_step_count_between_reset' is"
+ f" negative: {term_cfg.min_step_count_between_reset}. Please provide a non-negative value."
+ )
+
+ # initialize the current step count for each environment to zero
+ step_count=torch.zeros(self.num_envs,device=self.device,dtype=torch.int32)
+ self._reset_term_last_triggered_step_id.append(step_count)
+ # initialize the trigger flag for each environment to zero
+ no_trigger=torch.zeros(self.num_envs,device=self.device,dtype=torch.bool)
+ self._reset_term_last_triggered_once.append(no_trigger)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importcopy
+importinspect
+fromabcimportABC,abstractmethod
+fromcollections.abcimportSequence
+fromtypingimportTYPE_CHECKING,Any
+
+importcarb
+
+importomni.isaac.lab.utils.stringasstring_utils
+fromomni.isaac.lab.utilsimportstring_to_callable
+
+from.manager_term_cfgimportManagerTermBaseCfg
+from.scene_entity_cfgimportSceneEntityCfg
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedEnv
+
+
+
[文档]classManagerTermBase(ABC):
+"""Base class for manager terms.
+
+ Manager term implementations can be functions or classes. If the term is a class, it should
+ inherit from this base class and implement the required methods.
+
+ Each manager is implemented as a class that inherits from the :class:`ManagerBase` class. Each manager
+ class should also have a corresponding configuration class that defines the configuration terms for the
+ manager. Each term should the :class:`ManagerTermBaseCfg` class or its subclass.
+
+ Example pseudo-code for creating a manager:
+
+ .. code-block:: python
+
+ from omni.isaac.lab.utils import configclass
+ from omni.isaac.lab.utils.mdp import ManagerBase, ManagerTermBaseCfg
+
+ @configclass
+ class MyManagerCfg:
+
+ my_term_1: ManagerTermBaseCfg = ManagerTermBaseCfg(...)
+ my_term_2: ManagerTermBaseCfg = ManagerTermBaseCfg(...)
+ my_term_3: ManagerTermBaseCfg = ManagerTermBaseCfg(...)
+
+ # define manager instance
+ my_manager = ManagerBase(cfg=ManagerCfg(), env=env)
+
+ """
+
+
[文档]def__init__(self,cfg:ManagerTermBaseCfg,env:ManagerBasedEnv):
+"""Initialize the manager term.
+
+ Args:
+ cfg: The configuration object.
+ env: The environment instance.
+ """
+ # store the inputs
+ self.cfg=cfg
+ self._env=env
+
+"""
+ Properties.
+ """
+
+ @property
+ defnum_envs(self)->int:
+"""Number of environments."""
+ returnself._env.num_envs
+
+ @property
+ defdevice(self)->str:
+"""Device on which to perform computations."""
+ returnself._env.device
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None)->None:
+"""Resets the manager term.
+
+ Args:
+ env_ids: The environment ids. Defaults to None, in which case
+ all environments are considered.
+ """
+ pass
+
+ def__call__(self,*args)->Any:
+"""Returns the value of the term required by the manager.
+
+ In case of a class implementation, this function is called by the manager
+ to get the value of the term. The arguments passed to this function are
+ the ones specified in the term configuration (see :attr:`ManagerTermBaseCfg.params`).
+
+ .. attention::
+ To be consistent with memory-less implementation of terms with functions, it is
+ recommended to ensure that the returned mutable quantities are cloned before
+ returning them. For instance, if the term returns a tensor, it is recommended
+ to ensure that the returned tensor is a clone of the original tensor. This prevents
+ the manager from storing references to the tensors and altering the original tensors.
+
+ Args:
+ *args: Variable length argument list.
+
+ Returns:
+ The value of the term.
+ """
+ raiseNotImplementedError
+
+
+
[文档]classManagerBase(ABC):
+"""Base class for all managers."""
+
+
[文档]def__init__(self,cfg:object,env:ManagerBasedEnv):
+"""Initialize the manager.
+
+ Args:
+ cfg: The configuration object.
+ env: The environment instance.
+ """
+ # store the inputs
+ self.cfg=copy.deepcopy(cfg)
+ self._env=env
+ # parse config to create terms information
+ self._prepare_terms()
+
+"""
+ Properties.
+ """
+
+ @property
+ defnum_envs(self)->int:
+"""Number of environments."""
+ returnself._env.num_envs
+
+ @property
+ defdevice(self)->str:
+"""Device on which to perform computations."""
+ returnself._env.device
+
+ @property
+ @abstractmethod
+ defactive_terms(self)->list[str]|dict[str,list[str]]:
+"""Name of active terms."""
+ raiseNotImplementedError
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None)->dict[str,float]:
+"""Resets the manager and returns logging information for the current time-step.
+
+ Args:
+ env_ids: The environment ids for which to log data.
+ Defaults None, which logs data for all environments.
+
+ Returns:
+ Dictionary containing the logging information.
+ """
+ return{}
+
+
[文档]deffind_terms(self,name_keys:str|Sequence[str])->list[str]:
+"""Find terms in the manager based on the names.
+
+ This function searches the manager for terms based on the names. The names can be
+ specified as regular expressions or a list of regular expressions. The search is
+ performed on the active terms in the manager.
+
+ Please check the :meth:`omni.isaac.lab.utils.string_utils.resolve_matching_names` function for more
+ information on the name matching.
+
+ Args:
+ name_keys: A regular expression or a list of regular expressions to match the term names.
+
+ Returns:
+ A list of term names that match the input keys.
+ """
+ # resolve search keys
+ ifisinstance(self.active_terms,dict):
+ list_of_strings=[]
+ fornamesinself.active_terms.values():
+ list_of_strings.extend(names)
+ else:
+ list_of_strings=self.active_terms
+
+ # return the matching names
+ returnstring_utils.resolve_matching_names(name_keys,list_of_strings)[1]
+
+"""
+ Implementation specific.
+ """
+
+ @abstractmethod
+ def_prepare_terms(self):
+"""Prepare terms information from the configuration object."""
+ raiseNotImplementedError
+
+"""
+ Helper functions.
+ """
+
+ def_resolve_common_term_cfg(self,term_name:str,term_cfg:ManagerTermBaseCfg,min_argc:int=1):
+"""Resolve common term configuration.
+
+ Usually, called by the :meth:`_prepare_terms` method to resolve common term configuration.
+
+ Note:
+ By default, all term functions are expected to have at least one argument, which is the
+ environment object. Some other managers may expect functions to take more arguments, for
+ instance, the environment indices as the second argument. In such cases, the
+ ``min_argc`` argument can be used to specify the minimum number of arguments
+ required by the term function to be called correctly by the manager.
+
+ Args:
+ term_name: The name of the term.
+ term_cfg: The term configuration.
+ min_argc: The minimum number of arguments required by the term function to be called correctly
+ by the manager.
+
+ Raises:
+ TypeError: If the term configuration is not of type :class:`ManagerTermBaseCfg`.
+ ValueError: If the scene entity defined in the term configuration does not exist.
+ AttributeError: If the term function is not callable.
+ ValueError: If the term function's arguments are not matched by the parameters.
+ """
+ # check if the term is a valid term config
+ ifnotisinstance(term_cfg,ManagerTermBaseCfg):
+ raiseTypeError(
+ f"Configuration for the term '{term_name}' is not of type ManagerTermBaseCfg."
+ f" Received: '{type(term_cfg)}'."
+ )
+ # iterate over all the entities and parse the joint and body names
+ forkey,valueinterm_cfg.params.items():
+ # deal with string
+ ifisinstance(value,SceneEntityCfg):
+ # load the entity
+ try:
+ value.resolve(self._env.scene)
+ exceptValueErrorase:
+ raiseValueError(f"Error while parsing '{term_name}:{key}'. {e}")
+ # log the entity for checking later
+ msg=f"[{term_cfg.__class__.__name__}:{term_name}] Found entity '{value.name}'."
+ ifvalue.joint_idsisnotNone:
+ msg+=f"\n\tJoint names: {value.joint_names} [{value.joint_ids}]"
+ ifvalue.body_idsisnotNone:
+ msg+=f"\n\tBody names: {value.body_names} [{value.body_ids}]"
+ # print the information
+ carb.log_info(msg)
+ # store the entity
+ term_cfg.params[key]=value
+
+ # get the corresponding function or functional class
+ ifisinstance(term_cfg.func,str):
+ term_cfg.func=string_to_callable(term_cfg.func)
+
+ # initialize the term if it is a class
+ ifinspect.isclass(term_cfg.func):
+ ifnotissubclass(term_cfg.func,ManagerTermBase):
+ raiseTypeError(
+ f"Configuration for the term '{term_name}' is not of type ManagerTermBase."
+ f" Received: '{type(term_cfg.func)}'."
+ )
+ term_cfg.func=term_cfg.func(cfg=term_cfg,env=self._env)
+ # check if function is callable
+ ifnotcallable(term_cfg.func):
+ raiseAttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}")
+
+ # check if term's arguments are matched by params
+ term_params=list(term_cfg.params.keys())
+ args=inspect.signature(term_cfg.func).parameters
+ args_with_defaults=[argforarginargsifargs[arg].defaultisnotinspect.Parameter.empty]
+ args_without_defaults=[argforarginargsifargs[arg].defaultisinspect.Parameter.empty]
+ args=args_without_defaults+args_with_defaults
+ # ignore first two arguments for env and env_ids
+ # Think: Check for cases when kwargs are set inside the function?
+ iflen(args)>min_argc:
+ ifset(args[min_argc:])!=set(term_params+args_with_defaults):
+ raiseValueError(
+ f"The term '{term_name}' expects mandatory parameters: {args_without_defaults[min_argc:]}"
+ f" and optional parameters: {args_with_defaults}, but received: {term_params}."
+ )
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Configuration terms for different managers."""
+
+from__future__importannotations
+
+importtorch
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+fromtypingimportTYPE_CHECKING,Any
+
+fromomni.isaac.lab.utilsimportconfigclass
+fromomni.isaac.lab.utils.modifiersimportModifierCfg
+fromomni.isaac.lab.utils.noiseimportNoiseCfg
+
+from.scene_entity_cfgimportSceneEntityCfg
+
+ifTYPE_CHECKING:
+ from.action_managerimportActionTerm
+ from.command_managerimportCommandTerm
+ from.manager_baseimportManagerTermBase
+
+
+
[文档]@configclass
+classManagerTermBaseCfg:
+"""Configuration for a manager term."""
+
+ func:Callable|ManagerTermBase=MISSING
+"""The function or class to be called for the term.
+
+ The function must take the environment object as the first argument.
+ The remaining arguments are specified in the :attr:`params` attribute.
+
+ It also supports `callable classes`_, i.e. classes that implement the :meth:`__call__`
+ method. In this case, the class should inherit from the :class:`ManagerTermBase` class
+ and implement the required methods.
+
+ .. _`callable classes`: https://docs.python.org/3/reference/datamodel.html#object.__call__
+ """
+
+ params:dict[str,Any|SceneEntityCfg]=dict()
+"""The parameters to be passed to the function as keyword arguments. Defaults to an empty dict.
+
+ .. note::
+ If the value is a :class:`SceneEntityCfg` object, the manager will query the scene entity
+ from the :class:`InteractiveScene` and process the entity's joints and bodies as specified
+ in the :class:`SceneEntityCfg` object.
+ """
+
+
+##
+# Action manager.
+##
+
+
+
[文档]@configclass
+classActionTermCfg:
+"""Configuration for an action term."""
+
+ class_type:type[ActionTerm]=MISSING
+"""The associated action term class.
+
+ The class should inherit from :class:`omni.isaac.lab.managers.action_manager.ActionTerm`.
+ """
+
+ asset_name:str=MISSING
+"""The name of the scene entity.
+
+ This is the name defined in the scene configuration file. See the :class:`InteractiveSceneCfg`
+ class for more details.
+ """
+
+ debug_vis:bool=False
+"""Whether to visualize debug information. Defaults to False."""
+
+
+##
+# Command manager.
+##
+
+
+
[文档]@configclass
+classCommandTermCfg:
+"""Configuration for a command generator term."""
+
+ class_type:type[CommandTerm]=MISSING
+"""The associated command term class to use.
+
+ The class should inherit from :class:`omni.isaac.lab.managers.command_manager.CommandTerm`.
+ """
+
+ resampling_time_range:tuple[float,float]=MISSING
+"""Time before commands are changed [s]."""
+ debug_vis:bool=False
+"""Whether to visualize debug information. Defaults to False."""
+
+
+##
+# Curriculum manager.
+##
+
+
+
[文档]@configclass
+classCurriculumTermCfg(ManagerTermBaseCfg):
+"""Configuration for a curriculum term."""
+
+ func:Callable[...,float|dict[str,float]|None]=MISSING
+"""The name of the function to be called.
+
+ This function should take the environment object, environment indices
+ and any other parameters as input and return the curriculum state for
+ logging purposes. If the function returns None, the curriculum state
+ is not logged.
+ """
+
+
+##
+# Observation manager.
+##
+
+
+
[文档]@configclass
+classObservationTermCfg(ManagerTermBaseCfg):
+"""Configuration for an observation term."""
+
+ func:Callable[...,torch.Tensor]=MISSING
+"""The name of the function to be called.
+
+ This function should take the environment object and any other parameters
+ as input and return the observation signal as torch float tensors of
+ shape (num_envs, obs_term_dim).
+ """
+
+ modifiers:list[ModifierCfg]|None=None
+"""The list of data modifiers to apply to the observation in order. Defaults to None,
+ in which case no modifications will be applied.
+
+ Modifiers are applied in the order they are specified in the list. They can be stateless
+ or stateful, and can be used to apply transformations to the observation data. For example,
+ a modifier can be used to normalize the observation data or to apply a rolling average.
+
+ For more information on modifiers, see the :class:`~omni.isaac.lab.utils.modifiers.ModifierCfg` class.
+ """
+
+ noise:NoiseCfg|None=None
+"""The noise to add to the observation. Defaults to None, in which case no noise is added."""
+
+ clip:tuple[float,float]|None=None
+"""The clipping range for the observation after adding noise. Defaults to None,
+ in which case no clipping is applied."""
+
+ scale:float|None=None
+"""The scale to apply to the observation after clipping. Defaults to None,
+ in which case no scaling is applied (same as setting scale to :obj:`1`)."""
+
+
+
[文档]@configclass
+classObservationGroupCfg:
+"""Configuration for an observation group."""
+
+ concatenate_terms:bool=True
+"""Whether to concatenate the observation terms in the group. Defaults to True.
+
+ If true, the observation terms in the group are concatenated along the last dimension.
+ Otherwise, they are kept separate and returned as a dictionary.
+
+ If the observation group contains terms of different dimensions, it must be set to False.
+ """
+
+ enable_corruption:bool=False
+"""Whether to enable corruption for the observation group. Defaults to False.
+
+ If true, the observation terms in the group are corrupted by adding noise (if specified).
+ Otherwise, no corruption is applied.
+ """
+
+
+##
+# Event manager
+##
+
+
+
[文档]@configclass
+classEventTermCfg(ManagerTermBaseCfg):
+"""Configuration for a event term."""
+
+ func:Callable[...,None]=MISSING
+"""The name of the function to be called.
+
+ This function should take the environment object, environment indices
+ and any other parameters as input.
+ """
+
+ mode:str=MISSING
+"""The mode in which the event term is applied.
+
+ Note:
+ The mode name ``"interval"`` is a special mode that is handled by the
+ manager Hence, its name is reserved and cannot be used for other modes.
+ """
+
+ interval_range_s:tuple[float,float]|None=None
+"""The range of time in seconds at which the term is applied. Defaults to None.
+
+ Based on this, the interval is sampled uniformly between the specified
+ range for each environment instance. The term is applied on the environment
+ instances where the current time hits the interval time.
+
+ Note:
+ This is only used if the mode is ``"interval"``.
+ """
+
+ is_global_time:bool=False
+"""Whether randomization should be tracked on a per-environment basis. Defaults to False.
+
+ If True, the same interval time is used for all the environment instances.
+ If False, the interval time is sampled independently for each environment instance
+ and the term is applied when the current time hits the interval time for that instance.
+
+ Note:
+ This is only used if the mode is ``"interval"``.
+ """
+
+ min_step_count_between_reset:int=0
+"""The number of environment steps after which the term is applied since its last application. Defaults to 0.
+
+ When the mode is "reset", the term is only applied if the number of environment steps since
+ its last application exceeds this quantity. This helps to avoid calling the term too often,
+ thereby improving performance.
+
+ If the value is zero, the term is applied on every call to the manager with the mode "reset".
+
+ Note:
+ This is only used if the mode is ``"reset"``.
+ """
+
+
+##
+# Reward manager.
+##
+
+
+
[文档]@configclass
+classRewardTermCfg(ManagerTermBaseCfg):
+"""Configuration for a reward term."""
+
+ func:Callable[...,torch.Tensor]=MISSING
+"""The name of the function to be called.
+
+ This function should take the environment object and any other parameters
+ as input and return the reward signals as torch float tensors of
+ shape (num_envs,).
+ """
+
+ weight:float=MISSING
+"""The weight of the reward term.
+
+ This is multiplied with the reward term's value to compute the final
+ reward.
+
+ Note:
+ If the weight is zero, the reward term is ignored.
+ """
+
+
+##
+# Termination manager.
+##
+
+
+
[文档]@configclass
+classTerminationTermCfg(ManagerTermBaseCfg):
+"""Configuration for a termination term."""
+
+ func:Callable[...,torch.Tensor]=MISSING
+"""The name of the function to be called.
+
+ This function should take the environment object and any other parameters
+ as input and return the termination signals as torch boolean tensors of
+ shape (num_envs,).
+ """
+
+ time_out:bool=False
+"""Whether the termination term contributes towards episodic timeouts. Defaults to False.
+
+ Note:
+ These usually correspond to tasks that have a fixed time limit.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Observation manager for computing observation signals for a given world."""
+
+from__future__importannotations
+
+importinspect
+importtorch
+fromcollections.abcimportSequence
+fromprettytableimportPrettyTable
+fromtypingimportTYPE_CHECKING
+
+fromomni.isaac.lab.utilsimportmodifiers
+
+from.manager_baseimportManagerBase,ManagerTermBase
+from.manager_term_cfgimportObservationGroupCfg,ObservationTermCfg
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedEnv
+
+
+
[文档]classObservationManager(ManagerBase):
+"""Manager for computing observation signals for a given world.
+
+ Observations are organized into groups based on their intended usage. This allows having different observation
+ groups for different types of learning such as asymmetric actor-critic and student-teacher training. Each
+ group contains observation terms which contain information about the observation function to call, the noise
+ corruption model to use, and the sensor to retrieve data from.
+
+ Each observation group should inherit from the :class:`ObservationGroupCfg` class. Within each group, each
+ observation term should instantiate the :class:`ObservationTermCfg` class. Based on the configuration, the
+ observations in a group can be concatenated into a single tensor or returned as a dictionary with keys
+ corresponding to the term's name.
+
+ If the observations in a group are concatenated, the shape of the concatenated tensor is computed based on the
+ shapes of the individual observation terms. This information is stored in the :attr:`group_obs_dim` dictionary
+ with keys as the group names and values as the shape of the observation tensor. When the terms in a group are not
+ concatenated, the attribute stores a list of shapes for each term in the group.
+
+ .. note::
+ When the observation terms in a group do not have the same shape, the observation terms cannot be
+ concatenated. In this case, please set the :attr:`ObservationGroupCfg.concatenate_terms` attribute in the
+ group configuration to False.
+
+ The observation manager can be used to compute observations for all the groups or for a specific group. The
+ observations are computed by calling the registered functions for each term in the group. The functions are
+ called in the order of the terms in the group. The functions are expected to return a tensor with shape
+ (num_envs, ...).
+
+ If a noise model or custom modifier is registered for a term, the function is called to corrupt
+ the observation. The corruption function is expected to return a tensor with the same shape as the observation.
+ The observations are clipped and scaled as per the configuration settings.
+ """
+
+
[文档]def__init__(self,cfg:object,env:ManagerBasedEnv):
+"""Initialize observation manager.
+
+ Args:
+ cfg: The configuration object or dictionary (``dict[str, ObservationGroupCfg]``).
+ env: The environment instance.
+
+ Raises:
+ RuntimeError: If the shapes of the observation terms in a group are not compatible for concatenation
+ and the :attr:`~ObservationGroupCfg.concatenate_terms` attribute is set to True.
+ """
+ super().__init__(cfg,env)
+
+ # compute combined vector for obs group
+ self._group_obs_dim:dict[str,tuple[int,...]|list[tuple[int,...]]]=dict()
+ forgroup_name,group_term_dimsinself._group_obs_term_dim.items():
+ # if terms are concatenated, compute the combined shape into a single tuple
+ # otherwise, keep the list of shapes as is
+ ifself._group_obs_concatenate[group_name]:
+ try:
+ term_dims=[torch.tensor(dims,device="cpu")fordimsingroup_term_dims]
+ self._group_obs_dim[group_name]=tuple(torch.sum(torch.stack(term_dims,dim=0),dim=0).tolist())
+ exceptRuntimeError:
+ raiseRuntimeError(
+ f"Unable to concatenate observation terms in group '{group_name}'."
+ f" The shapes of the terms are: {group_term_dims}."
+ " Please ensure that the shapes are compatible for concatenation."
+ " Otherwise, set 'concatenate_terms' to False in the group configuration."
+ )
+ else:
+ self._group_obs_dim[group_name]=group_term_dims
+
+ def__str__(self)->str:
+"""Returns: A string representation for the observation manager."""
+ msg=f"<ObservationManager> contains {len(self._group_obs_term_names)} groups.\n"
+
+ # add info for each group
+ forgroup_name,group_diminself._group_obs_dim.items():
+ # create table for term information
+ table=PrettyTable()
+ table.title=f"Active Observation Terms in Group: '{group_name}'"
+ ifself._group_obs_concatenate[group_name]:
+ table.title+=f" (shape: {group_dim})"
+ table.field_names=["Index","Name","Shape"]
+ # set alignment of table columns
+ table.align["Name"]="l"
+ # add info for each term
+ obs_terms=zip(
+ self._group_obs_term_names[group_name],
+ self._group_obs_term_dim[group_name],
+ )
+ forindex,(name,dims)inenumerate(obs_terms):
+ # resolve inputs to simplify prints
+ tab_dims=tuple(dims)
+ # add row
+ table.add_row([index,name,tab_dims])
+ # convert table to string
+ msg+=table.get_string()
+ msg+="\n"
+
+ returnmsg
+
+"""
+ Properties.
+ """
+
+ @property
+ defactive_terms(self)->dict[str,list[str]]:
+"""Name of active observation terms in each group.
+
+ The keys are the group names and the values are the list of observation term names in the group.
+ """
+ returnself._group_obs_term_names
+
+ @property
+ defgroup_obs_dim(self)->dict[str,tuple[int,...]|list[tuple[int,...]]]:
+"""Shape of computed observations in each group.
+
+ The key is the group name and the value is the shape of the observation tensor.
+ If the terms in the group are concatenated, the value is a single tuple representing the
+ shape of the concatenated observation tensor. Otherwise, the value is a list of tuples,
+ where each tuple represents the shape of the observation tensor for a term in the group.
+ """
+ returnself._group_obs_dim
+
+ @property
+ defgroup_obs_term_dim(self)->dict[str,list[tuple[int,...]]]:
+"""Shape of individual observation terms in each group.
+
+ The key is the group name and the value is a list of tuples representing the shape of the observation terms
+ in the group. The order of the tuples corresponds to the order of the terms in the group.
+ This matches the order of the terms in the :attr:`active_terms`.
+ """
+ returnself._group_obs_term_dim
+
+ @property
+ defgroup_obs_concatenate(self)->dict[str,bool]:
+"""Whether the observation terms are concatenated in each group or not.
+
+ The key is the group name and the value is a boolean specifying whether the observation terms in the group
+ are concatenated into a single tensor. If True, the observations are concatenated along the last dimension.
+
+ The values are set based on the :attr:`~ObservationGroupCfg.concatenate_terms` attribute in the group
+ configuration.
+ """
+ returnself._group_obs_concatenate
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None)->dict[str,float]:
+ # call all terms that are classes
+ forgroup_cfginself._group_obs_class_term_cfgs.values():
+ forterm_cfgingroup_cfg:
+ term_cfg.func.reset(env_ids=env_ids)
+ # call all modifiers that are classes
+ formodinself._group_obs_class_modifiers:
+ mod.reset(env_ids=env_ids)
+ # nothing to log here
+ return{}
+
+
[文档]defcompute(self)->dict[str,torch.Tensor|dict[str,torch.Tensor]]:
+"""Compute the observations per group for all groups.
+
+ The method computes the observations for all the groups handled by the observation manager.
+ Please check the :meth:`compute_group` on the processing of observations per group.
+
+ Returns:
+ A dictionary with keys as the group names and values as the computed observations.
+ The observations are either concatenated into a single tensor or returned as a dictionary
+ with keys corresponding to the term's name.
+ """
+ # create a buffer for storing obs from all the groups
+ obs_buffer=dict()
+ # iterate over all the terms in each group
+ forgroup_nameinself._group_obs_term_names:
+ obs_buffer[group_name]=self.compute_group(group_name)
+ # otherwise return a dict with observations of all groups
+ returnobs_buffer
+
+
[文档]defcompute_group(self,group_name:str)->torch.Tensor|dict[str,torch.Tensor]:
+"""Computes the observations for a given group.
+
+ The observations for a given group are computed by calling the registered functions for each
+ term in the group. The functions are called in the order of the terms in the group. The functions
+ are expected to return a tensor with shape (num_envs, ...).
+
+ The following steps are performed for each observation term:
+
+ 1. Compute observation term by calling the function
+ 2. Apply custom modifiers in the order specified in :attr:`ObservationTermCfg.modifiers`
+ 3. Apply corruption/noise model based on :attr:`ObservationTermCfg.noise`
+ 4. Apply clipping based on :attr:`ObservationTermCfg.clip`
+ 5. Apply scaling based on :attr:`ObservationTermCfg.scale`
+
+ We apply noise to the computed term first to maintain the integrity of how noise affects the data
+ as it truly exists in the real world. If the noise is applied after clipping or scaling, the noise
+ could be artificially constrained or amplified, which might misrepresent how noise naturally occurs
+ in the data.
+
+ Args:
+ group_name: The name of the group for which to compute the observations. Defaults to None,
+ in which case observations for all the groups are computed and returned.
+
+ Returns:
+ Depending on the group's configuration, the tensors for individual observation terms are
+ concatenated along the last dimension into a single tensor. Otherwise, they are returned as
+ a dictionary with keys corresponding to the term's name.
+
+ Raises:
+ ValueError: If input ``group_name`` is not a valid group handled by the manager.
+ """
+ # check ig group name is valid
+ ifgroup_namenotinself._group_obs_term_names:
+ raiseValueError(
+ f"Unable to find the group '{group_name}' in the observation manager."
+ f" Available groups are: {list(self._group_obs_term_names.keys())}"
+ )
+ # iterate over all the terms in each group
+ group_term_names=self._group_obs_term_names[group_name]
+ # buffer to store obs per group
+ group_obs=dict.fromkeys(group_term_names,None)
+ # read attributes for each term
+ obs_terms=zip(group_term_names,self._group_obs_term_cfgs[group_name])
+
+ # evaluate terms: compute, add noise, clip, scale, custom modifiers
+ forname,term_cfginobs_terms:
+ # compute term's value
+ obs:torch.Tensor=term_cfg.func(self._env,**term_cfg.params).clone()
+ # apply post-processing
+ ifterm_cfg.modifiersisnotNone:
+ formodifierinterm_cfg.modifiers:
+ obs=modifier.func(obs,**modifier.params)
+ ifterm_cfg.noise:
+ obs=term_cfg.noise.func(obs,term_cfg.noise)
+ ifterm_cfg.clip:
+ obs=obs.clip_(min=term_cfg.clip[0],max=term_cfg.clip[1])
+ ifterm_cfg.scale:
+ obs=obs.mul_(term_cfg.scale)
+ # add value to list
+ group_obs[name]=obs
+
+ # concatenate all observations in the group together
+ ifself._group_obs_concatenate[group_name]:
+ returntorch.cat(list(group_obs.values()),dim=-1)
+ else:
+ returngroup_obs
+
+"""
+ Helper functions.
+ """
+
+ def_prepare_terms(self):
+"""Prepares a list of observation terms functions."""
+ # create buffers to store information for each observation group
+ # TODO: Make this more convenient by using data structures.
+ self._group_obs_term_names:dict[str,list[str]]=dict()
+ self._group_obs_term_dim:dict[str,list[tuple[int,...]]]=dict()
+ self._group_obs_term_cfgs:dict[str,list[ObservationTermCfg]]=dict()
+ self._group_obs_class_term_cfgs:dict[str,list[ObservationTermCfg]]=dict()
+ self._group_obs_concatenate:dict[str,bool]=dict()
+
+ # create a list to store modifiers that are classes
+ # we store it as a separate list to only call reset on them and prevent unnecessary calls
+ self._group_obs_class_modifiers:list[modifiers.ModifierBase]=list()
+
+ # check if config is dict already
+ ifisinstance(self.cfg,dict):
+ group_cfg_items=self.cfg.items()
+ else:
+ group_cfg_items=self.cfg.__dict__.items()
+ # iterate over all the groups
+ forgroup_name,group_cfgingroup_cfg_items:
+ # check for non config
+ ifgroup_cfgisNone:
+ continue
+ # check if the term is a curriculum term
+ ifnotisinstance(group_cfg,ObservationGroupCfg):
+ raiseTypeError(
+ f"Observation group '{group_name}' is not of type 'ObservationGroupCfg'."
+ f" Received: '{type(group_cfg)}'."
+ )
+ # initialize list for the group settings
+ self._group_obs_term_names[group_name]=list()
+ self._group_obs_term_dim[group_name]=list()
+ self._group_obs_term_cfgs[group_name]=list()
+ self._group_obs_class_term_cfgs[group_name]=list()
+ # read common config for the group
+ self._group_obs_concatenate[group_name]=group_cfg.concatenate_terms
+ # check if config is dict already
+ ifisinstance(group_cfg,dict):
+ group_cfg_items=group_cfg.items()
+ else:
+ group_cfg_items=group_cfg.__dict__.items()
+ # iterate over all the terms in each group
+ forterm_name,term_cfgingroup_cfg.__dict__.items():
+ # skip non-obs settings
+ ifterm_namein["enable_corruption","concatenate_terms"]:
+ continue
+ # check for non config
+ ifterm_cfgisNone:
+ continue
+ ifnotisinstance(term_cfg,ObservationTermCfg):
+ raiseTypeError(
+ f"Configuration for the term '{term_name}' is not of type ObservationTermCfg."
+ f" Received: '{type(term_cfg)}'."
+ )
+ # resolve common terms in the config
+ self._resolve_common_term_cfg(f"{group_name}/{term_name}",term_cfg,min_argc=1)
+
+ # check noise settings
+ ifnotgroup_cfg.enable_corruption:
+ term_cfg.noise=None
+ # add term config to list to list
+ self._group_obs_term_names[group_name].append(term_name)
+ self._group_obs_term_cfgs[group_name].append(term_cfg)
+
+ # call function the first time to fill up dimensions
+ obs_dims=tuple(term_cfg.func(self._env,**term_cfg.params).shape)
+ self._group_obs_term_dim[group_name].append(obs_dims[1:])
+
+ # prepare modifiers for each observation
+ ifterm_cfg.modifiersisnotNone:
+ # initialize list of modifiers for term
+ formod_cfginterm_cfg.modifiers:
+ # check if class modifier and initialize with observation size when adding
+ ifisinstance(mod_cfg,modifiers.ModifierCfg):
+ # to list of modifiers
+ ifinspect.isclass(mod_cfg.func):
+ ifnotissubclass(mod_cfg.func,modifiers.ModifierBase):
+ raiseTypeError(
+ f"Modifier function '{mod_cfg.func}' for observation term '{term_name}'"
+ f" is not a subclass of 'ModifierBase'. Received: '{type(mod_cfg.func)}'."
+ )
+ mod_cfg.func=mod_cfg.func(cfg=mod_cfg,data_dim=obs_dims,device=self._env.device)
+
+ # add to list of class modifiers
+ self._group_obs_class_modifiers.append(mod_cfg.func)
+ else:
+ raiseTypeError(
+ f"Modifier configuration '{mod_cfg}' of observation term '{term_name}' is not of"
+ f" required type ModifierCfg, Received: '{type(mod_cfg)}'"
+ )
+
+ # check if function is callable
+ ifnotcallable(mod_cfg.func):
+ raiseAttributeError(
+ f"Modifier '{mod_cfg}' of observation term '{term_name}' is not callable."
+ f" Received: {mod_cfg.func}"
+ )
+
+ # check if term's arguments are matched by params
+ term_params=list(mod_cfg.params.keys())
+ args=inspect.signature(mod_cfg.func).parameters
+ args_with_defaults=[argforarginargsifargs[arg].defaultisnotinspect.Parameter.empty]
+ args_without_defaults=[argforarginargsifargs[arg].defaultisinspect.Parameter.empty]
+ args=args_without_defaults+args_with_defaults
+ # ignore first two arguments for env and env_ids
+ # Think: Check for cases when kwargs are set inside the function?
+ iflen(args)>1:
+ ifset(args[1:])!=set(term_params+args_with_defaults):
+ raiseValueError(
+ f"Modifier '{mod_cfg}' of observation term '{term_name}' expects"
+ f" mandatory parameters: {args_without_defaults[1:]}"
+ f" and optional parameters: {args_with_defaults}, but received: {term_params}."
+ )
+
+ # add term in a separate list if term is a class
+ ifisinstance(term_cfg.func,ManagerTermBase):
+ self._group_obs_class_term_cfgs[group_name].append(term_cfg)
+ # call reset (in-case above call to get obs dims changed the state)
+ term_cfg.func.reset()
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Reward manager for computing reward signals for a given world."""
+
+from__future__importannotations
+
+importtorch
+fromcollections.abcimportSequence
+fromprettytableimportPrettyTable
+fromtypingimportTYPE_CHECKING
+
+from.manager_baseimportManagerBase,ManagerTermBase
+from.manager_term_cfgimportRewardTermCfg
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedRLEnv
+
+
+
[文档]classRewardManager(ManagerBase):
+"""Manager for computing reward signals for a given world.
+
+ The reward manager computes the total reward as a sum of the weighted reward terms. The reward
+ terms are parsed from a nested config class containing the reward manger's settings and reward
+ terms configuration.
+
+ The reward terms are parsed from a config class containing the manager's settings and each term's
+ parameters. Each reward term should instantiate the :class:`RewardTermCfg` class.
+
+ .. note::
+
+ The reward manager multiplies the reward term's ``weight`` with the time-step interval ``dt``
+ of the environment. This is done to ensure that the computed reward terms are balanced with
+ respect to the chosen time-step interval in the environment.
+
+ """
+
+ _env:ManagerBasedRLEnv
+"""The environment instance."""
+
+
[文档]def__init__(self,cfg:object,env:ManagerBasedRLEnv):
+"""Initialize the reward manager.
+
+ Args:
+ cfg: The configuration object or dictionary (``dict[str, RewardTermCfg]``).
+ env: The environment instance.
+ """
+ super().__init__(cfg,env)
+ # prepare extra info to store individual reward term information
+ self._episode_sums=dict()
+ forterm_nameinself._term_names:
+ self._episode_sums[term_name]=torch.zeros(self.num_envs,dtype=torch.float,device=self.device)
+ # create buffer for managing reward per environment
+ self._reward_buf=torch.zeros(self.num_envs,dtype=torch.float,device=self.device)
+
+ def__str__(self)->str:
+"""Returns: A string representation for reward manager."""
+ msg=f"<RewardManager> contains {len(self._term_names)} active terms.\n"
+
+ # create table for term information
+ table=PrettyTable()
+ table.title="Active Reward Terms"
+ table.field_names=["Index","Name","Weight"]
+ # set alignment of table columns
+ table.align["Name"]="l"
+ table.align["Weight"]="r"
+ # add info on each term
+ forindex,(name,term_cfg)inenumerate(zip(self._term_names,self._term_cfgs)):
+ table.add_row([index,name,term_cfg.weight])
+ # convert table to string
+ msg+=table.get_string()
+ msg+="\n"
+
+ returnmsg
+
+"""
+ Properties.
+ """
+
+ @property
+ defactive_terms(self)->list[str]:
+"""Name of active reward terms."""
+ returnself._term_names
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None)->dict[str,torch.Tensor]:
+"""Returns the episodic sum of individual reward terms.
+
+ Args:
+ env_ids: The environment ids for which the episodic sum of
+ individual reward terms is to be returned. Defaults to all the environment ids.
+
+ Returns:
+ Dictionary of episodic sum of individual reward terms.
+ """
+ # resolve environment ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # store information
+ extras={}
+ forkeyinself._episode_sums.keys():
+ # store information
+ # r_1 + r_2 + ... + r_n
+ episodic_sum_avg=torch.mean(self._episode_sums[key][env_ids])
+ extras["Episode_Reward/"+key]=episodic_sum_avg/self._env.max_episode_length_s
+ # reset episodic sum
+ self._episode_sums[key][env_ids]=0.0
+ # reset all the reward terms
+ forterm_cfginself._class_term_cfgs:
+ term_cfg.func.reset(env_ids=env_ids)
+ # return logged information
+ returnextras
+
+
[文档]defcompute(self,dt:float)->torch.Tensor:
+"""Computes the reward signal as a weighted sum of individual terms.
+
+ This function calls each reward term managed by the class and adds them to compute the net
+ reward signal. It also updates the episodic sums corresponding to individual reward terms.
+
+ Args:
+ dt: The time-step interval of the environment.
+
+ Returns:
+ The net reward signal of shape (num_envs,).
+ """
+ # reset computation
+ self._reward_buf[:]=0.0
+ # iterate over all the reward terms
+ forname,term_cfginzip(self._term_names,self._term_cfgs):
+ # skip if weight is zero (kind of a micro-optimization)
+ ifterm_cfg.weight==0.0:
+ continue
+ # compute term's value
+ value=term_cfg.func(self._env,**term_cfg.params)*term_cfg.weight*dt
+ # update total reward
+ self._reward_buf+=value
+ # update episodic sum
+ self._episode_sums[name]+=value
+
+ returnself._reward_buf
+
+"""
+ Operations - Term settings.
+ """
+
+
[文档]defset_term_cfg(self,term_name:str,cfg:RewardTermCfg):
+"""Sets the configuration of the specified term into the manager.
+
+ Args:
+ term_name: The name of the reward term.
+ cfg: The configuration for the reward term.
+
+ Raises:
+ ValueError: If the term name is not found.
+ """
+ ifterm_namenotinself._term_names:
+ raiseValueError(f"Reward term '{term_name}' not found.")
+ # set the configuration
+ self._term_cfgs[self._term_names.index(term_name)]=cfg
+
+
[文档]defget_term_cfg(self,term_name:str)->RewardTermCfg:
+"""Gets the configuration for the specified term.
+
+ Args:
+ term_name: The name of the reward term.
+
+ Returns:
+ The configuration of the reward term.
+
+ Raises:
+ ValueError: If the term name is not found.
+ """
+ ifterm_namenotinself._term_names:
+ raiseValueError(f"Reward term '{term_name}' not found.")
+ # return the configuration
+ returnself._term_cfgs[self._term_names.index(term_name)]
+
+"""
+ Helper functions.
+ """
+
+ def_prepare_terms(self):
+"""Prepares a list of reward functions."""
+ # parse remaining reward terms and decimate their information
+ self._term_names:list[str]=list()
+ self._term_cfgs:list[RewardTermCfg]=list()
+ self._class_term_cfgs:list[RewardTermCfg]=list()
+
+ # check if config is dict already
+ ifisinstance(self.cfg,dict):
+ cfg_items=self.cfg.items()
+ else:
+ cfg_items=self.cfg.__dict__.items()
+ # iterate over all the terms
+ forterm_name,term_cfgincfg_items:
+ # check for non config
+ ifterm_cfgisNone:
+ continue
+ # check for valid config type
+ ifnotisinstance(term_cfg,RewardTermCfg):
+ raiseTypeError(
+ f"Configuration for the term '{term_name}' is not of type RewardTermCfg."
+ f" Received: '{type(term_cfg)}'."
+ )
+ # check for valid weight type
+ ifnotisinstance(term_cfg.weight,(float,int)):
+ raiseTypeError(
+ f"Weight for the term '{term_name}' is not of type float or int."
+ f" Received: '{type(term_cfg.weight)}'."
+ )
+ # resolve common parameters
+ self._resolve_common_term_cfg(term_name,term_cfg,min_argc=1)
+ # add function to list
+ self._term_names.append(term_name)
+ self._term_cfgs.append(term_cfg)
+ # check if the term is a class
+ ifisinstance(term_cfg.func,ManagerTermBase):
+ self._class_term_cfgs.append(term_cfg)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Configuration terms for different managers."""
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.assetsimportArticulation,RigidObject
+fromomni.isaac.lab.sceneimportInteractiveScene
+fromomni.isaac.lab.utilsimportconfigclass
+
+
+
[文档]@configclass
+classSceneEntityCfg:
+"""Configuration for a scene entity that is used by the manager's term.
+
+ This class is used to specify the name of the scene entity that is queried from the
+ :class:`InteractiveScene` and passed to the manager's term function.
+ """
+
+ name:str=MISSING
+"""The name of the scene entity.
+
+ This is the name defined in the scene configuration file. See the :class:`InteractiveSceneCfg`
+ class for more details.
+ """
+
+ joint_names:str|list[str]|None=None
+"""The names of the joints from the scene entity. Defaults to None.
+
+ The names can be either joint names or a regular expression matching the joint names.
+
+ These are converted to joint indices on initialization of the manager and passed to the term
+ function as a list of joint indices under :attr:`joint_ids`.
+ """
+
+ joint_ids:list[int]|slice=slice(None)
+"""The indices of the joints from the asset required by the term. Defaults to slice(None), which means
+ all the joints in the asset (if present).
+
+ If :attr:`joint_names` is specified, this is filled in automatically on initialization of the
+ manager.
+ """
+
+ fixed_tendon_names:str|list[str]|None=None
+"""The names of the fixed tendons from the scene entity. Defaults to None.
+
+ The names can be either joint names or a regular expression matching the joint names.
+
+ These are converted to fixed tendon indices on initialization of the manager and passed to the term
+ function as a list of fixed tendon indices under :attr:`fixed_tendon_ids`.
+ """
+
+ fixed_tendon_ids:list[int]|slice=slice(None)
+"""The indices of the fixed tendons from the asset required by the term. Defaults to slice(None), which means
+ all the fixed tendons in the asset (if present).
+
+ If :attr:`fixed_tendon_names` is specified, this is filled in automatically on initialization of the
+ manager.
+ """
+
+ body_names:str|list[str]|None=None
+"""The names of the bodies from the asset required by the term. Defaults to None.
+
+ The names can be either body names or a regular expression matching the body names.
+
+ These are converted to body indices on initialization of the manager and passed to the term
+ function as a list of body indices under :attr:`body_ids`.
+ """
+
+ body_ids:list[int]|slice=slice(None)
+"""The indices of the bodies from the asset required by the term. Defaults to slice(None), which means
+ all the bodies in the asset.
+
+ If :attr:`body_names` is specified, this is filled in automatically on initialization of the
+ manager.
+ """
+
+ preserve_order:bool=False
+"""Whether to preserve indices ordering to match with that in the specified joint or body names. Defaults to False.
+
+ If False, the ordering of the indices are sorted in ascending order (i.e. the ordering in the entity's joints
+ or bodies). Otherwise, the indices are preserved in the order of the specified joint and body names.
+
+ For more details, see the :meth:`omni.isaac.lab.utils.string.resolve_matching_names` function.
+
+ .. note::
+ This attribute is only used when :attr:`joint_names` or :attr:`body_names` are specified.
+
+ """
+
+
[文档]defresolve(self,scene:InteractiveScene):
+"""Resolves the scene entity and converts the joint and body names to indices.
+
+ This function examines the scene entity from the :class:`InteractiveScene` and resolves the indices
+ and names of the joints and bodies. It is an expensive operation as it resolves regular expressions
+ and should be called only once.
+
+ Args:
+ scene: The interactive scene instance.
+
+ Raises:
+ ValueError: If the scene entity is not found.
+ ValueError: If both ``joint_names`` and ``joint_ids`` are specified and are not consistent.
+ ValueError: If both ``fixed_tendon_names`` and ``fixed_tendon_ids`` are specified and are not consistent.
+ ValueError: If both ``body_names`` and ``body_ids`` are specified and are not consistent.
+ """
+ # check if the entity is valid
+ ifself.namenotinscene.keys():
+ raiseValueError(f"The scene entity '{self.name}' does not exist. Available entities: {scene.keys()}.")
+
+ # convert joint names to indices based on regex
+ self._resolve_joint_names(scene)
+
+ # convert fixed tendon names to indices based on regex
+ self._resolve_fixed_tendon_names(scene)
+
+ # convert body names to indices based on regex
+ self._resolve_body_names(scene)
+
+ def_resolve_joint_names(self,scene:InteractiveScene):
+ # convert joint names to indices based on regex
+ ifself.joint_namesisnotNoneorself.joint_ids!=slice(None):
+ entity:Articulation=scene[self.name]
+ # -- if both are not their default values, check if they are valid
+ ifself.joint_namesisnotNoneandself.joint_ids!=slice(None):
+ ifisinstance(self.joint_names,str):
+ self.joint_names=[self.joint_names]
+ ifisinstance(self.joint_ids,int):
+ self.joint_ids=[self.joint_ids]
+ joint_ids,_=entity.find_joints(self.joint_names,preserve_order=self.preserve_order)
+ joint_names=[entity.joint_names[i]foriinself.joint_ids]
+ ifjoint_ids!=self.joint_idsorjoint_names!=self.joint_names:
+ raiseValueError(
+ "Both 'joint_names' and 'joint_ids' are specified, and are not consistent."
+ f"\n\tfrom joint names: {self.joint_names} [{joint_ids}]"
+ f"\n\tfrom joint ids: {joint_names} [{self.joint_ids}]"
+ "\nHint: Use either 'joint_names' or 'joint_ids' to avoid confusion."
+ )
+ # -- from joint names to joint indices
+ elifself.joint_namesisnotNone:
+ ifisinstance(self.joint_names,str):
+ self.joint_names=[self.joint_names]
+ self.joint_ids,_=entity.find_joints(self.joint_names,preserve_order=self.preserve_order)
+ # performance optimization (slice offers faster indexing than list of indices)
+ # only all joint in the entity order are selected
+ iflen(self.joint_ids)==entity.num_jointsandself.joint_names==entity.joint_names:
+ self.joint_ids=slice(None)
+ # -- from joint indices to joint names
+ elifself.joint_ids!=slice(None):
+ ifisinstance(self.joint_ids,int):
+ self.joint_ids=[self.joint_ids]
+ self.joint_names=[entity.joint_names[i]foriinself.joint_ids]
+
+ def_resolve_fixed_tendon_names(self,scene:InteractiveScene):
+ # convert tendon names to indices based on regex
+ ifself.fixed_tendon_namesisnotNoneorself.fixed_tendon_ids!=slice(None):
+ entity:Articulation=scene[self.name]
+ # -- if both are not their default values, check if they are valid
+ ifself.fixed_tendon_namesisnotNoneandself.fixed_tendon_ids!=slice(None):
+ ifisinstance(self.fixed_tendon_names,str):
+ self.fixed_tendon_names=[self.fixed_tendon_names]
+ ifisinstance(self.fixed_tendon_ids,int):
+ self.fixed_tendon_ids=[self.fixed_tendon_ids]
+ fixed_tendon_ids,_=entity.find_fixed_tendons(
+ self.fixed_tendon_names,preserve_order=self.preserve_order
+ )
+ fixed_tendon_names=[entity.fixed_tendon_names[i]foriinself.fixed_tendon_ids]
+ iffixed_tendon_ids!=self.fixed_tendon_idsorfixed_tendon_names!=self.fixed_tendon_names:
+ raiseValueError(
+ "Both 'fixed_tendon_names' and 'fixed_tendon_ids' are specified, and are not consistent."
+ f"\n\tfrom joint names: {self.fixed_tendon_names} [{fixed_tendon_ids}]"
+ f"\n\tfrom joint ids: {fixed_tendon_names} [{self.fixed_tendon_ids}]"
+ "\nHint: Use either 'fixed_tendon_names' or 'fixed_tendon_ids' to avoid confusion."
+ )
+ # -- from fixed tendon names to fixed tendon indices
+ elifself.fixed_tendon_namesisnotNone:
+ ifisinstance(self.fixed_tendon_names,str):
+ self.fixed_tendon_names=[self.fixed_tendon_names]
+ self.fixed_tendon_ids,_=entity.find_fixed_tendons(
+ self.fixed_tendon_names,preserve_order=self.preserve_order
+ )
+ # performance optimization (slice offers faster indexing than list of indices)
+ # only all fixed tendon in the entity order are selected
+ if(
+ len(self.fixed_tendon_ids)==entity.num_fixed_tendons
+ andself.fixed_tendon_names==entity.fixed_tendon_names
+ ):
+ self.fixed_tendon_ids=slice(None)
+ # -- from fixed tendon indices to fixed tendon names
+ elifself.fixed_tendon_ids!=slice(None):
+ ifisinstance(self.fixed_tendon_ids,int):
+ self.fixed_tendon_ids=[self.fixed_tendon_ids]
+ self.fixed_tendon_names=[entity.fixed_tendon_names[i]foriinself.fixed_tendon_ids]
+
+ def_resolve_body_names(self,scene:InteractiveScene):
+ # convert body names to indices based on regex
+ ifself.body_namesisnotNoneorself.body_ids!=slice(None):
+ entity:RigidObject=scene[self.name]
+ # -- if both are not their default values, check if they are valid
+ ifself.body_namesisnotNoneandself.body_ids!=slice(None):
+ ifisinstance(self.body_names,str):
+ self.body_names=[self.body_names]
+ ifisinstance(self.body_ids,int):
+ self.body_ids=[self.body_ids]
+ body_ids,_=entity.find_bodies(self.body_names,preserve_order=self.preserve_order)
+ body_names=[entity.body_names[i]foriinself.body_ids]
+ ifbody_ids!=self.body_idsorbody_names!=self.body_names:
+ raiseValueError(
+ "Both 'body_names' and 'body_ids' are specified, and are not consistent."
+ f"\n\tfrom body names: {self.body_names} [{body_ids}]"
+ f"\n\tfrom body ids: {body_names} [{self.body_ids}]"
+ "\nHint: Use either 'body_names' or 'body_ids' to avoid confusion."
+ )
+ # -- from body names to body indices
+ elifself.body_namesisnotNone:
+ ifisinstance(self.body_names,str):
+ self.body_names=[self.body_names]
+ self.body_ids,_=entity.find_bodies(self.body_names,preserve_order=self.preserve_order)
+ # performance optimization (slice offers faster indexing than list of indices)
+ # only all bodies in the entity order are selected
+ iflen(self.body_ids)==entity.num_bodiesandself.body_names==entity.body_names:
+ self.body_ids=slice(None)
+ # -- from body indices to body names
+ elifself.body_ids!=slice(None):
+ ifisinstance(self.body_ids,int):
+ self.body_ids=[self.body_ids]
+ self.body_names=[entity.body_names[i]foriinself.body_ids]
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Termination manager for computing done signals for a given world."""
+
+from__future__importannotations
+
+importtorch
+fromcollections.abcimportSequence
+fromprettytableimportPrettyTable
+fromtypingimportTYPE_CHECKING
+
+from.manager_baseimportManagerBase,ManagerTermBase
+from.manager_term_cfgimportTerminationTermCfg
+
+ifTYPE_CHECKING:
+ fromomni.isaac.lab.envsimportManagerBasedRLEnv
+
+
+
[文档]classTerminationManager(ManagerBase):
+"""Manager for computing done signals for a given world.
+
+ The termination manager computes the termination signal (also called dones) as a combination
+ of termination terms. Each termination term is a function which takes the environment as an
+ argument and returns a boolean tensor of shape (num_envs,). The termination manager
+ computes the termination signal as the union (logical or) of all the termination terms.
+
+ Following the `Gymnasium API <https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/>`_,
+ the termination signal is computed as the logical OR of the following signals:
+
+ * **Time-out**: This signal is set to true if the environment has ended after an externally defined condition
+ (that is outside the scope of a MDP). For example, the environment may be terminated if the episode has
+ timed out (i.e. reached max episode length).
+ * **Terminated**: This signal is set to true if the environment has reached a terminal state defined by the
+ environment. This state may correspond to task success, task failure, robot falling, etc.
+
+ These signals can be individually accessed using the :attr:`time_outs` and :attr:`terminated` properties.
+
+ The termination terms are parsed from a config class containing the manager's settings and each term's
+ parameters. Each termination term should instantiate the :class:`TerminationTermCfg` class. The term's
+ configuration :attr:`TerminationTermCfg.time_out` decides whether the term is a timeout or a termination term.
+ """
+
+ _env:ManagerBasedRLEnv
+"""The environment instance."""
+
+
[文档]def__init__(self,cfg:object,env:ManagerBasedRLEnv):
+"""Initializes the termination manager.
+
+ Args:
+ cfg: The configuration object or dictionary (``dict[str, TerminationTermCfg]``).
+ env: An environment object.
+ """
+ super().__init__(cfg,env)
+ # prepare extra info to store individual termination term information
+ self._term_dones=dict()
+ forterm_nameinself._term_names:
+ self._term_dones[term_name]=torch.zeros(self.num_envs,device=self.device,dtype=torch.bool)
+ # create buffer for managing termination per environment
+ self._truncated_buf=torch.zeros(self.num_envs,device=self.device,dtype=torch.bool)
+ self._terminated_buf=torch.zeros_like(self._truncated_buf)
+
+ def__str__(self)->str:
+"""Returns: A string representation for termination manager."""
+ msg=f"<TerminationManager> contains {len(self._term_names)} active terms.\n"
+
+ # create table for term information
+ table=PrettyTable()
+ table.title="Active Termination Terms"
+ table.field_names=["Index","Name","Time Out"]
+ # set alignment of table columns
+ table.align["Name"]="l"
+ # add info on each term
+ forindex,(name,term_cfg)inenumerate(zip(self._term_names,self._term_cfgs)):
+ table.add_row([index,name,term_cfg.time_out])
+ # convert table to string
+ msg+=table.get_string()
+ msg+="\n"
+
+ returnmsg
+
+"""
+ Properties.
+ """
+
+ @property
+ defactive_terms(self)->list[str]:
+"""Name of active termination terms."""
+ returnself._term_names
+
+ @property
+ defdones(self)->torch.Tensor:
+"""The net termination signal. Shape is (num_envs,)."""
+ returnself._truncated_buf|self._terminated_buf
+
+ @property
+ deftime_outs(self)->torch.Tensor:
+"""The timeout signal (reaching max episode length). Shape is (num_envs,).
+
+ This signal is set to true if the environment has ended after an externally defined condition
+ (that is outside the scope of a MDP). For example, the environment may be terminated if the episode has
+ timed out (i.e. reached max episode length).
+ """
+ returnself._truncated_buf
+
+ @property
+ defterminated(self)->torch.Tensor:
+"""The terminated signal (reaching a terminal state). Shape is (num_envs,).
+
+ This signal is set to true if the environment has reached a terminal state defined by the environment.
+ This state may correspond to task success, task failure, robot falling, etc.
+ """
+ returnself._terminated_buf
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None)->dict[str,torch.Tensor]:
+"""Returns the episodic counts of individual termination terms.
+
+ Args:
+ env_ids: The environment ids. Defaults to None, in which case
+ all environments are considered.
+
+ Returns:
+ Dictionary of episodic sum of individual reward terms.
+ """
+ # resolve environment ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # add to episode dict
+ extras={}
+ forkeyinself._term_dones.keys():
+ # store information
+ extras["Episode_Termination/"+key]=torch.count_nonzero(self._term_dones[key][env_ids]).item()
+ # reset all the reward terms
+ forterm_cfginself._class_term_cfgs:
+ term_cfg.func.reset(env_ids=env_ids)
+ # return logged information
+ returnextras
+
+
[文档]defcompute(self)->torch.Tensor:
+"""Computes the termination signal as union of individual terms.
+
+ This function calls each termination term managed by the class and performs a logical OR operation
+ to compute the net termination signal.
+
+ Returns:
+ The combined termination signal of shape (num_envs,).
+ """
+ # reset computation
+ self._truncated_buf[:]=False
+ self._terminated_buf[:]=False
+ # iterate over all the termination terms
+ forname,term_cfginzip(self._term_names,self._term_cfgs):
+ value=term_cfg.func(self._env,**term_cfg.params)
+ # store timeout signal separately
+ ifterm_cfg.time_out:
+ self._truncated_buf|=value
+ else:
+ self._terminated_buf|=value
+ # add to episode dones
+ self._term_dones[name][:]=value
+ # return combined termination signal
+ returnself._truncated_buf|self._terminated_buf
+
+
[文档]defget_term(self,name:str)->torch.Tensor:
+"""Returns the termination term with the specified name.
+
+ Args:
+ name: The name of the termination term.
+
+ Returns:
+ The corresponding termination term value. Shape is (num_envs,).
+ """
+ returnself._term_dones[name]
+
+"""
+ Operations - Term settings.
+ """
+
+
[文档]defset_term_cfg(self,term_name:str,cfg:TerminationTermCfg):
+"""Sets the configuration of the specified term into the manager.
+
+ Args:
+ term_name: The name of the termination term.
+ cfg: The configuration for the termination term.
+
+ Raises:
+ ValueError: If the term name is not found.
+ """
+ ifterm_namenotinself._term_names:
+ raiseValueError(f"Termination term '{term_name}' not found.")
+ # set the configuration
+ self._term_cfgs[self._term_names.index(term_name)]=cfg
+
+
[文档]defget_term_cfg(self,term_name:str)->TerminationTermCfg:
+"""Gets the configuration for the specified term.
+
+ Args:
+ term_name: The name of the termination term.
+
+ Returns:
+ The configuration of the termination term.
+
+ Raises:
+ ValueError: If the term name is not found.
+ """
+ ifterm_namenotinself._term_names:
+ raiseValueError(f"Termination term '{term_name}' not found.")
+ # return the configuration
+ returnself._term_cfgs[self._term_names.index(term_name)]
+
+"""
+ Helper functions.
+ """
+
+ def_prepare_terms(self):
+"""Prepares a list of termination functions."""
+ # parse remaining termination terms and decimate their information
+ self._term_names:list[str]=list()
+ self._term_cfgs:list[TerminationTermCfg]=list()
+ self._class_term_cfgs:list[TerminationTermCfg]=list()
+
+ # check if config is dict already
+ ifisinstance(self.cfg,dict):
+ cfg_items=self.cfg.items()
+ else:
+ cfg_items=self.cfg.__dict__.items()
+ # iterate over all the terms
+ forterm_name,term_cfgincfg_items:
+ # check for non config
+ ifterm_cfgisNone:
+ continue
+ # check for valid config type
+ ifnotisinstance(term_cfg,TerminationTermCfg):
+ raiseTypeError(
+ f"Configuration for the term '{term_name}' is not of type TerminationTermCfg."
+ f" Received: '{type(term_cfg)}'."
+ )
+ # resolve common parameters
+ self._resolve_common_term_cfg(term_name,term_cfg,min_argc=1)
+ # add function to list
+ self._term_names.append(term_name)
+ self._term_cfgs.append(term_cfg)
+ # check if the term is a class
+ ifisinstance(term_cfg.func,ManagerTermBase):
+ self._class_term_cfgs.append(term_cfg)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""A class to coordinate groups of visual markers (such as spheres, frames or arrows)
+using `UsdGeom.PointInstancer`_ class.
+
+The class :class:`VisualizationMarkers` is used to create a group of visual markers and
+visualize them in the viewport. The markers are represented as :class:`UsdGeom.PointInstancer` prims
+in the USD stage. The markers are created as prototypes in the :class:`UsdGeom.PointInstancer` prim
+and are instanced in the :class:`UsdGeom.PointInstancer` prim. The markers can be visualized by
+passing the indices of the marker prototypes and their translations, orientations and scales.
+The marker prototypes can be configured with the :class:`VisualizationMarkersCfg` class.
+
+.. _UsdGeom.PointInstancer: https://graphics.pixar.com/usd/dev/api/class_usd_geom_point_instancer.html
+"""
+
+# needed to import for allowing type-hinting: np.ndarray | torch.Tensor | None
+from__future__importannotations
+
+importnumpyasnp
+importtorch
+fromdataclassesimportMISSING
+
+importomni.isaac.core.utils.stageasstage_utils
+importomni.kit.commands
+importomni.physx.scripts.utilsasphysx_utils
+frompxrimportGf,PhysxSchema,Sdf,Usd,UsdGeom,UsdPhysics,Vt
+
+importomni.isaac.lab.simassim_utils
+fromomni.isaac.lab.sim.spawnersimportSpawnerCfg
+fromomni.isaac.lab.utils.configclassimportconfigclass
+fromomni.isaac.lab.utils.mathimportconvert_quat
+
+
+
[文档]@configclass
+classVisualizationMarkersCfg:
+"""A class to configure a :class:`VisualizationMarkers`."""
+
+ prim_path:str=MISSING
+"""The prim path where the :class:`UsdGeom.PointInstancer` will be created."""
+
+ markers:dict[str,SpawnerCfg]=MISSING
+"""The dictionary of marker configurations.
+
+ The key is the name of the marker, and the value is the configuration of the marker.
+ The key is used to identify the marker in the class.
+ """
+
+
+
[文档]classVisualizationMarkers:
+"""A class to coordinate groups of visual markers (loaded from USD).
+
+ This class allows visualization of different UI markers in the scene, such as points and frames.
+ The class wraps around the `UsdGeom.PointInstancer`_ for efficient handling of objects
+ in the stage via instancing the created marker prototype prims.
+
+ A marker prototype prim is a reusable template prim used for defining variations of objects
+ in the scene. For example, a sphere prim can be used as a marker prototype prim to create
+ multiple sphere prims in the scene at different locations. Thus, prototype prims are useful
+ for creating multiple instances of the same prim in the scene.
+
+ The class parses the configuration to create different the marker prototypes into the stage. Each marker
+ prototype prim is created as a child of the :class:`UsdGeom.PointInstancer` prim. The prim path for the
+ the marker prim is resolved using the key of the marker in the :attr:`VisualizationMarkersCfg.markers`
+ dictionary. The marker prototypes are created using the :meth:`omni.isaac.core.utils.create_prim`
+ function, and then then instanced using :class:`UsdGeom.PointInstancer` prim to allow creating multiple
+ instances of the marker prims.
+
+ Switching between different marker prototypes is possible by calling the :meth:`visualize` method with
+ the prototype indices corresponding to the marker prototype. The prototype indices are based on the order
+ in the :attr:`VisualizationMarkersCfg.markers` dictionary. For example, if the dictionary has two markers,
+ "marker1" and "marker2", then their prototype indices are 0 and 1 respectively. The prototype indices
+ can be passed as a list or array of integers.
+
+ Usage:
+ The following snippet shows how to create 24 sphere markers with a radius of 1.0 at random translations
+ within the range [-1.0, 1.0]. The first 12 markers will be colored red and the rest will be colored green.
+
+ .. code-block:: python
+
+ import omni.isaac.lab.sim as sim_utils
+ from omni.isaac.lab.markers import VisualizationMarkersCfg, VisualizationMarkers
+
+ # Create the markers configuration
+ # This creates two marker prototypes, "marker1" and "marker2" which are spheres with a radius of 1.0.
+ # The color of "marker1" is red and the color of "marker2" is green.
+ cfg = VisualizationMarkersCfg(
+ prim_path="/World/Visuals/testMarkers",
+ markers={
+ "marker1": sim_utils.SphereCfg(
+ radius=1.0,
+ visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(1.0, 0.0, 0.0)),
+ ),
+ "marker2": VisualizationMarkersCfg.SphereCfg(
+ radius=1.0,
+ visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.0, 1.0, 0.0)),
+ ),
+ }
+ )
+ # Create the markers instance
+ # This will create a UsdGeom.PointInstancer prim at the given path along with the marker prototypes.
+ marker = VisualizationMarkers(cfg)
+
+ # Set position of the marker
+ # -- randomly sample translations between -1.0 and 1.0
+ marker_translations = np.random.uniform(-1.0, 1.0, (24, 3))
+ # -- this will create 24 markers at the given translations
+ # note: the markers will all be `marker1` since the marker indices are not given
+ marker.visualize(translations=marker_translations)
+
+ # alter the markers based on their prototypes indices
+ # first 12 markers will be marker1 and the rest will be marker2
+ # 0 -> marker1, 1 -> marker2
+ marker_indices = [0] * 12 + [1] * 12
+ # this will change the marker prototypes at the given indices
+ # note: the translations of the markers will not be changed from the previous call
+ # since the translations are not given.
+ marker.visualize(marker_indices=marker_indices)
+
+ # alter the markers based on their prototypes indices and translations
+ marker.visualize(marker_indices=marker_indices, translations=marker_translations)
+
+ .. _UsdGeom.PointInstancer: https://graphics.pixar.com/usd/dev/api/class_usd_geom_point_instancer.html
+
+ """
+
+
[文档]def__init__(self,cfg:VisualizationMarkersCfg):
+"""Initialize the class.
+
+ When the class is initialized, the :class:`UsdGeom.PointInstancer` is created into the stage
+ and the marker prims are registered into it.
+
+ .. note::
+ If a prim already exists at the given path, the function will find the next free path
+ and create the :class:`UsdGeom.PointInstancer` prim there.
+
+ Args:
+ cfg: The configuration for the markers.
+
+ Raises:
+ ValueError: When no markers are provided in the :obj:`cfg`.
+ """
+ # get next free path for the prim
+ prim_path=stage_utils.get_next_free_path(cfg.prim_path)
+ # create a new prim
+ stage=stage_utils.get_current_stage()
+ self._instancer_manager=UsdGeom.PointInstancer.Define(stage,prim_path)
+ # store inputs
+ self.prim_path=prim_path
+ self.cfg=cfg
+ # check if any markers is provided
+ iflen(self.cfg.markers)==0:
+ raiseValueError(f"The `cfg.markers` cannot be empty. Received: {self.cfg.markers}")
+
+ # create a child prim for the marker
+ self._add_markers_prototypes(self.cfg.markers)
+ # Note: We need to do this the first time to initialize the instancer.
+ # Otherwise, the instancer will not be "created" and the function `GetInstanceIndices()` will fail.
+ self._instancer_manager.GetProtoIndicesAttr().Set(list(range(self.num_prototypes)))
+ self._instancer_manager.GetPositionsAttr().Set([Gf.Vec3f(0.0)]*self.num_prototypes)
+ self._count=self.num_prototypes
+
+ def__str__(self)->str:
+"""Return: A string representation of the class."""
+ msg=f"VisualizationMarkers(prim_path={self.prim_path})"
+ msg+=f"\n\tCount: {self.count}"
+ msg+=f"\n\tNumber of prototypes: {self.num_prototypes}"
+ msg+="\n\tMarkers Prototypes:"
+ forindex,(name,marker)inenumerate(self.cfg.markers.items()):
+ msg+=f"\n\t\t[Index: {index}]: {name}: {marker.to_dict()}"
+ returnmsg
+
+"""
+ Properties.
+ """
+
+ @property
+ defnum_prototypes(self)->int:
+"""The number of marker prototypes available."""
+ returnlen(self.cfg.markers)
+
+ @property
+ defcount(self)->int:
+"""The total number of marker instances."""
+ # TODO: Update this when the USD API is available (Isaac Sim 2023.1)
+ # return self._instancer_manager.GetInstanceCount()
+ returnself._count
+
+"""
+ Operations.
+ """
+
+
[文档]defset_visibility(self,visible:bool):
+"""Sets the visibility of the markers.
+
+ The method does this through the USD API.
+
+ Args:
+ visible: flag to set the visibility.
+ """
+ imageable=UsdGeom.Imageable(self._instancer_manager)
+ ifvisible:
+ imageable.MakeVisible()
+ else:
+ imageable.MakeInvisible()
+
+
[文档]defis_visible(self)->bool:
+"""Checks the visibility of the markers.
+
+ Returns:
+ True if the markers are visible, False otherwise.
+ """
+ returnself._instancer_manager.GetVisibilityAttr().Get()!=UsdGeom.Tokens.invisible
+
+
[文档]defvisualize(
+ self,
+ translations:np.ndarray|torch.Tensor|None=None,
+ orientations:np.ndarray|torch.Tensor|None=None,
+ scales:np.ndarray|torch.Tensor|None=None,
+ marker_indices:list[int]|np.ndarray|torch.Tensor|None=None,
+ ):
+"""Update markers in the viewport.
+
+ .. note::
+ If the prim `PointInstancer` is hidden in the stage, the function will simply return
+ without updating the markers. This helps in unnecessary computation when the markers
+ are not visible.
+
+ Whenever updating the markers, the input arrays must have the same number of elements
+ in the first dimension. If the number of elements is different, the `UsdGeom.PointInstancer`
+ will raise an error complaining about the mismatch.
+
+ Additionally, the function supports dynamic update of the markers. This means that the
+ number of markers can change between calls. For example, if you have 24 points that you
+ want to visualize, you can pass 24 translations, orientations, and scales. If you want to
+ visualize only 12 points, you can pass 12 translations, orientations, and scales. The
+ function will automatically update the number of markers in the scene.
+
+ The function will also update the marker prototypes based on their prototype indices. For instance,
+ if you have two marker prototypes, and you pass the following marker indices: [0, 1, 0, 1], the function
+ will update the first and third markers with the first prototype, and the second and fourth markers
+ with the second prototype. This is useful when you want to visualize different markers in the same
+ scene. The list of marker indices must have the same number of elements as the translations, orientations,
+ or scales. If the number of elements is different, the function will raise an error.
+
+ .. caution::
+ This function will update all the markers instanced from the prototypes. That means
+ if you have 24 markers, you will need to pass 24 translations, orientations, and scales.
+
+ If you want to update only a subset of the markers, you will need to handle the indices
+ yourself and pass the complete arrays to this function.
+
+ Args:
+ translations: Translations w.r.t. parent prim frame. Shape is (M, 3).
+ Defaults to None, which means left unchanged.
+ orientations: Quaternion orientations (w, x, y, z) w.r.t. parent prim frame. Shape is (M, 4).
+ Defaults to None, which means left unchanged.
+ scales: Scale applied before any rotation is applied. Shape is (M, 3).
+ Defaults to None, which means left unchanged.
+ marker_indices: Decides which marker prototype to visualize. Shape is (M).
+ Defaults to None, which means left unchanged provided that the total number of markers
+ is the same as the previous call. If the number of markers is different, the function
+ will update the number of markers in the scene.
+
+ Raises:
+ ValueError: When input arrays do not follow the expected shapes.
+ ValueError: When the function is called with all None arguments.
+ """
+ # check if it is visible (if not then let's not waste time)
+ ifnotself.is_visible():
+ return
+ # check if we have any markers to visualize
+ num_markers=0
+ # resolve inputs
+ # -- position
+ iftranslationsisnotNone:
+ ifisinstance(translations,torch.Tensor):
+ translations=translations.detach().cpu().numpy()
+ # check that shape is correct
+ iftranslations.shape[1]!=3orlen(translations.shape)!=2:
+ raiseValueError(f"Expected `translations` to have shape (M, 3). Received: {translations.shape}.")
+ # apply translations
+ self._instancer_manager.GetPositionsAttr().Set(Vt.Vec3fArray.FromNumpy(translations))
+ # update number of markers
+ num_markers=translations.shape[0]
+ # -- orientation
+ iforientationsisnotNone:
+ ifisinstance(orientations,torch.Tensor):
+ orientations=orientations.detach().cpu().numpy()
+ # check that shape is correct
+ iforientations.shape[1]!=4orlen(orientations.shape)!=2:
+ raiseValueError(f"Expected `orientations` to have shape (M, 4). Received: {orientations.shape}.")
+ # roll orientations from (w, x, y, z) to (x, y, z, w)
+ # internally USD expects (x, y, z, w)
+ orientations=convert_quat(orientations,to="xyzw")
+ # apply orientations
+ self._instancer_manager.GetOrientationsAttr().Set(Vt.QuathArray.FromNumpy(orientations))
+ # update number of markers
+ num_markers=orientations.shape[0]
+ # -- scales
+ ifscalesisnotNone:
+ ifisinstance(scales,torch.Tensor):
+ scales=scales.detach().cpu().numpy()
+ # check that shape is correct
+ ifscales.shape[1]!=3orlen(scales.shape)!=2:
+ raiseValueError(f"Expected `scales` to have shape (M, 3). Received: {scales.shape}.")
+ # apply scales
+ self._instancer_manager.GetScalesAttr().Set(Vt.Vec3fArray.FromNumpy(scales))
+ # update number of markers
+ num_markers=scales.shape[0]
+ # -- status
+ ifmarker_indicesisnotNoneornum_markers!=self._count:
+ # apply marker indices
+ ifmarker_indicesisnotNone:
+ ifisinstance(marker_indices,torch.Tensor):
+ marker_indices=marker_indices.detach().cpu().numpy()
+ elifisinstance(marker_indices,list):
+ marker_indices=np.array(marker_indices)
+ # check that shape is correct
+ iflen(marker_indices.shape)!=1:
+ raiseValueError(f"Expected `marker_indices` to have shape (M,). Received: {marker_indices.shape}.")
+ # apply proto indices
+ self._instancer_manager.GetProtoIndicesAttr().Set(Vt.IntArray.FromNumpy(marker_indices))
+ # update number of markers
+ num_markers=marker_indices.shape[0]
+ else:
+ # check that number of markers is not zero
+ ifnum_markers==0:
+ raiseValueError("Number of markers cannot be zero! Hint: The function was called with no inputs?")
+ # set all markers to be the first prototype
+ self._instancer_manager.GetProtoIndicesAttr().Set([0]*num_markers)
+ # set number of markers
+ self._count=num_markers
+
+"""
+ Helper functions.
+ """
+
+ def_add_markers_prototypes(self,markers_cfg:dict[str,sim_utils.SpawnerCfg]):
+"""Adds markers prototypes to the scene and sets the markers instancer to use them."""
+ # add markers based on config
+ forname,cfginmarkers_cfg.items():
+ # resolve prim path
+ marker_prim_path=f"{self.prim_path}/{name}"
+ # create a child prim for the marker
+ marker_prim=cfg.func(prim_path=marker_prim_path,cfg=cfg)
+ # make the asset uninstanceable (in case it is)
+ # point instancer defines its own prototypes so if an asset is already instanced, this doesn't work.
+ self._process_prototype_prim(marker_prim)
+ # add child reference to point instancer
+ self._instancer_manager.GetPrototypesRel().AddTarget(marker_prim_path)
+ # check that we loaded all the prototypes
+ prototypes=self._instancer_manager.GetPrototypesRel().GetTargets()
+ iflen(prototypes)!=len(markers_cfg):
+ raiseRuntimeError(
+ f"Failed to load all the prototypes. Expected: {len(markers_cfg)}. Received: {len(prototypes)}."
+ )
+
+ def_process_prototype_prim(self,prim:Usd.Prim):
+"""Process a prim and its descendants to make them suitable for defining prototypes.
+
+ Point instancer defines its own prototypes so if an asset is already instanced, this doesn't work.
+ This function checks if the prim at the specified prim path and its descendants are instanced.
+ If so, it makes the respective prim uninstanceable by disabling instancing on the prim.
+
+ Additionally, it makes the prim invisible to secondary rays. This is useful when we do not want
+ to see the marker prims on camera images.
+
+ Args:
+ prim_path: The prim path to check.
+ stage: The stage where the prim exists.
+ Defaults to None, in which case the current stage is used.
+ """
+ # check if prim is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim at path '{prim.GetPrimAtPath()}' is not valid.")
+ # iterate over all prims under prim-path
+ all_prims=[prim]
+ whilelen(all_prims)>0:
+ # get current prim
+ child_prim=all_prims.pop(0)
+ # check if it is physics body -> if so, remove it
+ ifchild_prim.HasAPI(UsdPhysics.ArticulationRootAPI):
+ child_prim.RemoveAPI(UsdPhysics.ArticulationRootAPI)
+ child_prim.RemoveAPI(PhysxSchema.PhysxArticulationAPI)
+ ifchild_prim.HasAPI(UsdPhysics.RigidBodyAPI):
+ child_prim.RemoveAPI(UsdPhysics.RigidBodyAPI)
+ child_prim.RemoveAPI(PhysxSchema.PhysxRigidBodyAPI)
+ ifchild_prim.IsA(UsdPhysics.Joint):
+ child_prim.GetAttribute("physics:jointEnabled").Set(False)
+ # check if prim is instanced -> if so, make it uninstanceable
+ ifchild_prim.IsInstance():
+ child_prim.SetInstanceable(False)
+ # check if prim is a mesh -> if so, make it invisible to secondary rays
+ ifchild_prim.IsA(UsdGeom.Gprim):
+ # invisible to secondary rays such as depth images
+ omni.kit.commands.execute(
+ "ChangePropertyCommand",
+ prop_path=Sdf.Path(f"{child_prim.GetPrimPath().pathString}.primvars:invisibleToSecondaryRays"),
+ value=True,
+ prev=None,
+ type_to_create_if_not_exist=Sdf.ValueTypeNames.Bool,
+ )
+ # add children to list
+ all_prims+=child_prim.GetChildren()
+
+ # remove any physics on the markers because they are only for visualization!
+ physx_utils.removeRigidBodySubtree(prim)
[文档]classInteractiveScene:
+"""A scene that contains entities added to the simulation.
+
+ The interactive scene parses the :class:`InteractiveSceneCfg` class to create the scene.
+ Based on the specified number of environments, it clones the entities and groups them into different
+ categories (e.g., articulations, sensors, etc.).
+
+ Cloning can be performed in two ways:
+
+ * For tasks where all environments contain the same assets, a more performant cloning paradigm
+ can be used to allow for faster environment creation. This is specified by the ``replicate_physics`` flag.
+
+ .. code-block:: python
+
+ scene = InteractiveScene(cfg=InteractiveSceneCfg(replicate_physics=True))
+
+ * For tasks that require having separate assets in the environments, ``replicate_physics`` would have to
+ be set to False, which will add some costs to the overall startup time.
+
+ .. code-block:: python
+
+ scene = InteractiveScene(cfg=InteractiveSceneCfg(replicate_physics=False))
+
+ Each entity is registered to scene based on its name in the configuration class. For example, if the user
+ specifies a robot in the configuration class as follows:
+
+ .. code-block:: python
+
+ from omni.isaac.lab.scene import InteractiveSceneCfg
+ from omni.isaac.lab.utils import configclass
+
+ from omni.isaac.lab_assets.anymal import ANYMAL_C_CFG
+
+ @configclass
+ class MySceneCfg(InteractiveSceneCfg):
+
+ robot = ANYMAL_C_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")
+
+ Then the robot can be accessed from the scene as follows:
+
+ .. code-block:: python
+
+ from omni.isaac.lab.scene import InteractiveScene
+
+ # create 128 environments
+ scene = InteractiveScene(cfg=MySceneCfg(num_envs=128))
+
+ # access the robot from the scene
+ robot = scene["robot"]
+ # access the robot based on its type
+ robot = scene.articulations["robot"]
+
+ If the :class:`InteractiveSceneCfg` class does not include asset entities, the cloning process
+ can still be triggered if assets were added to the stage outside of the :class:`InteractiveScene` class:
+
+ .. code-block:: python
+
+ scene = InteractiveScene(cfg=InteractiveSceneCfg(num_envs=128, replicate_physics=True))
+ scene.clone_environments()
+
+ .. note::
+ It is important to note that the scene only performs common operations on the entities. For example,
+ resetting the internal buffers, writing the buffers to the simulation and updating the buffers from the
+ simulation. The scene does not perform any task specific to the entity. For example, it does not apply
+ actions to the robot or compute observations from the robot. These tasks are handled by different
+ modules called "managers" in the framework. Please refer to the :mod:`omni.isaac.lab.managers` sub-package
+ for more details.
+ """
+
+
[文档]def__init__(self,cfg:InteractiveSceneCfg):
+"""Initializes the scene.
+
+ Args:
+ cfg: The configuration class for the scene.
+ """
+ # store inputs
+ self.cfg=cfg
+ # initialize scene elements
+ self._terrain=None
+ self._articulations=dict()
+ self._deformable_objects=dict()
+ self._rigid_objects=dict()
+ self._sensors=dict()
+ self._extras=dict()
+ # obtain the current stage
+ self.stage=omni.usd.get_context().get_stage()
+ # physics scene path
+ self._physics_scene_path=None
+ # prepare cloner for environment replication
+ self.cloner=GridCloner(spacing=self.cfg.env_spacing)
+ self.cloner.define_base_env(self.env_ns)
+ self.env_prim_paths=self.cloner.generate_paths(f"{self.env_ns}/env",self.cfg.num_envs)
+ # create source prim
+ self.stage.DefinePrim(self.env_prim_paths[0],"Xform")
+
+ # when replicate_physics=False, we assume heterogeneous environments and clone the xforms first.
+ # this triggers per-object level cloning in the spawner.
+ ifnotself.cfg.replicate_physics:
+ # clone the env xform
+ env_origins=self.cloner.clone(
+ source_prim_path=self.env_prim_paths[0],
+ prim_paths=self.env_prim_paths,
+ replicate_physics=False,
+ copy_from_source=True,
+ )
+ self._default_env_origins=torch.tensor(env_origins,device=self.device,dtype=torch.float32)
+ else:
+ # otherwise, environment origins will be initialized during cloning at the end of environment creation
+ self._default_env_origins=None
+
+ self._global_prim_paths=list()
+ ifself._is_scene_setup_from_cfg():
+ # add entities from config
+ self._add_entities_from_cfg()
+ # clone environments on a global scope if environment is homogeneous
+ ifself.cfg.replicate_physics:
+ self.clone_environments(copy_from_source=False)
+ # replicate physics if we have more than one environment
+ # this is done to make scene initialization faster at play time
+ ifself.cfg.replicate_physicsandself.cfg.num_envs>1:
+ self.cloner.replicate_physics(
+ source_prim_path=self.env_prim_paths[0],
+ prim_paths=self.env_prim_paths,
+ base_env_path=self.env_ns,
+ root_path=self.env_regex_ns.replace(".*",""),
+ )
+
+ self.filter_collisions(self._global_prim_paths)
+
+
[文档]defclone_environments(self,copy_from_source:bool=False):
+"""Creates clones of the environment ``/World/envs/env_0``.
+
+ Args:
+ copy_from_source: (bool): If set to False, clones inherit from /World/envs/env_0 and mirror its changes.
+ If True, clones are independent copies of the source prim and won't reflect its changes (start-up time
+ may increase). Defaults to False.
+ """
+ env_origins=self.cloner.clone(
+ source_prim_path=self.env_prim_paths[0],
+ prim_paths=self.env_prim_paths,
+ replicate_physics=self.cfg.replicate_physics,
+ copy_from_source=copy_from_source,
+ )
+
+ # in case of heterogeneous cloning, the env origins is specified at init
+ ifself._default_env_originsisNone:
+ self._default_env_origins=torch.tensor(env_origins,device=self.device,dtype=torch.float32)
+
+
[文档]deffilter_collisions(self,global_prim_paths:list[str]|None=None):
+"""Filter environments collisions.
+
+ Disables collisions between the environments in ``/World/envs/env_.*`` and enables collisions with the prims
+ in global prim paths (e.g. ground plane).
+
+ Args:
+ global_prim_paths: A list of global prim paths to enable collisions with.
+ Defaults to None, in which case no global prim paths are considered.
+ """
+ # obtain the current physics scene
+ physics_scene_prim_path=self.physics_scene_path
+
+ # validate paths in global prim paths
+ ifglobal_prim_pathsisNone:
+ global_prim_paths=[]
+ else:
+ # remove duplicates in paths
+ global_prim_paths=list(set(global_prim_paths))
+
+ # set global prim paths list if not previously defined
+ iflen(self._global_prim_paths)<1:
+ self._global_prim_paths+=global_prim_paths
+
+ # filter collisions within each environment instance
+ self.cloner.filter_collisions(
+ physics_scene_prim_path,
+ "/World/collisions",
+ self.env_prim_paths,
+ global_paths=self._global_prim_paths,
+ )
+
+ def__str__(self)->str:
+"""Returns a string representation of the scene."""
+ msg=f"<class {self.__class__.__name__}>\n"
+ msg+=f"\tNumber of environments: {self.cfg.num_envs}\n"
+ msg+=f"\tEnvironment spacing : {self.cfg.env_spacing}\n"
+ msg+=f"\tSource prim name : {self.env_prim_paths[0]}\n"
+ msg+=f"\tGlobal prim paths : {self._global_prim_paths}\n"
+ msg+=f"\tReplicate physics : {self.cfg.replicate_physics}"
+ returnmsg
+
+"""
+ Properties.
+ """
+
+ @property
+ defphysics_scene_path(self):
+"""Search the stage for the physics scene"""
+ ifself._physics_scene_pathisNone:
+ forpriminself.stage.Traverse():
+ ifprim.HasAPI(PhysxSchema.PhysxSceneAPI):
+ self._physics_scene_path=prim.GetPrimPath()
+ carb.log_info(f"Physics scene prim path: {self._physics_scene_path}")
+ break
+ returnself._physics_scene_path
+
+ @property
+ defphysics_dt(self)->float:
+"""The physics timestep of the scene."""
+ returnsim_utils.SimulationContext.instance().get_physics_dt()# pyright: ignore [reportOptionalMemberAccess]
+
+ @property
+ defdevice(self)->str:
+"""The device on which the scene is created."""
+ returnsim_utils.SimulationContext.instance().device# pyright: ignore [reportOptionalMemberAccess]
+
+ @property
+ defenv_ns(self)->str:
+"""The namespace ``/World/envs`` in which all environments created.
+
+ The environments are present w.r.t. this namespace under "env_{N}" prim,
+ where N is a natural number.
+ """
+ return"/World/envs"
+
+ @property
+ defenv_regex_ns(self)->str:
+"""The namespace ``/World/envs/env_.*`` in which all environments created."""
+ returnf"{self.env_ns}/env_.*"
+
+ @property
+ defnum_envs(self)->int:
+"""The number of environments handled by the scene."""
+ returnself.cfg.num_envs
+
+ @property
+ defenv_origins(self)->torch.Tensor:
+"""The origins of the environments in the scene. Shape is (num_envs, 3)."""
+ ifself._terrainisnotNone:
+ returnself._terrain.env_origins
+ else:
+ returnself._default_env_origins
+
+ @property
+ defterrain(self)->TerrainImporter|None:
+"""The terrain in the scene. If None, then the scene has no terrain.
+
+ Note:
+ We treat terrain separate from :attr:`extras` since terrains define environment origins and are
+ handled differently from other miscellaneous entities.
+ """
+ returnself._terrain
+
+ @property
+ defarticulations(self)->dict[str,Articulation]:
+"""A dictionary of articulations in the scene."""
+ returnself._articulations
+
+ @property
+ defdeformable_objects(self)->dict[str,DeformableObject]:
+"""A dictionary of deformable objects in the scene."""
+ returnself._deformable_objects
+
+ @property
+ defrigid_objects(self)->dict[str,RigidObject]:
+"""A dictionary of rigid objects in the scene."""
+ returnself._rigid_objects
+
+ @property
+ defsensors(self)->dict[str,SensorBase]:
+"""A dictionary of the sensors in the scene, such as cameras and contact reporters."""
+ returnself._sensors
+
+ @property
+ defextras(self)->dict[str,XFormPrimView]:
+"""A dictionary of miscellaneous simulation objects that neither inherit from assets nor sensors.
+
+ The keys are the names of the miscellaneous objects, and the values are the `XFormPrimView`_
+ of the corresponding prims.
+
+ As an example, lights or other props in the scene that do not have any attributes or properties that you
+ want to alter at runtime can be added to this dictionary.
+
+ Note:
+ These are not reset or updated by the scene. They are mainly other prims that are not necessarily
+ handled by the interactive scene, but are useful to be accessed by the user.
+
+ .. _XFormPrimView: https://docs.omniverse.nvidia.com/py/isaacsim/source/extensions/omni.isaac.core/docs/index.html#omni.isaac.core.prims.XFormPrimView
+
+ """
+ returnself._extras
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+"""Resets the scene entities.
+
+ Args:
+ env_ids: The indices of the environments to reset.
+ Defaults to None (all instances).
+ """
+ # -- assets
+ forarticulationinself._articulations.values():
+ articulation.reset(env_ids)
+ fordeformable_objectinself._deformable_objects.values():
+ deformable_object.reset(env_ids)
+ forrigid_objectinself._rigid_objects.values():
+ rigid_object.reset(env_ids)
+ # -- sensors
+ forsensorinself._sensors.values():
+ sensor.reset(env_ids)
+
+
[文档]defwrite_data_to_sim(self):
+"""Writes the data of the scene entities to the simulation."""
+ # -- assets
+ forarticulationinself._articulations.values():
+ articulation.write_data_to_sim()
+ fordeformable_objectinself._deformable_objects.values():
+ deformable_object.write_data_to_sim()
+ forrigid_objectinself._rigid_objects.values():
+ rigid_object.write_data_to_sim()
+
+
[文档]defupdate(self,dt:float)->None:
+"""Update the scene entities.
+
+ Args:
+ dt: The amount of time passed from last :meth:`update` call.
+ """
+ # -- assets
+ forarticulationinself._articulations.values():
+ articulation.update(dt)
+ fordeformable_objectinself._deformable_objects.values():
+ deformable_object.update(dt)
+ forrigid_objectinself._rigid_objects.values():
+ rigid_object.update(dt)
+ # -- sensors
+ forsensorinself._sensors.values():
+ sensor.update(dt,force_recompute=notself.cfg.lazy_sensor_update)
+
+"""
+ Operations: Iteration.
+ """
+
+
[文档]defkeys(self)->list[str]:
+"""Returns the keys of the scene entities.
+
+ Returns:
+ The keys of the scene entities.
+ """
+ all_keys=["terrain"]
+ forasset_familyin[
+ self._articulations,
+ self._deformable_objects,
+ self._rigid_objects,
+ self._sensors,
+ self._extras,
+ ]:
+ all_keys+=list(asset_family.keys())
+ returnall_keys
+
+ def__getitem__(self,key:str)->Any:
+"""Returns the scene entity with the given key.
+
+ Args:
+ key: The key of the scene entity.
+
+ Returns:
+ The scene entity.
+ """
+ # check if it is a terrain
+ ifkey=="terrain":
+ returnself._terrain
+
+ all_keys=["terrain"]
+ # check if it is in other dictionaries
+ forasset_familyin[
+ self._articulations,
+ self._deformable_objects,
+ self._rigid_objects,
+ self._sensors,
+ self._extras,
+ ]:
+ out=asset_family.get(key)
+ # if found, return
+ ifoutisnotNone:
+ returnout
+ all_keys+=list(asset_family.keys())
+ # if not found, raise error
+ raiseKeyError(f"Scene entity with key '{key}' not found. Available Entities: '{all_keys}'")
+
+"""
+ Internal methods.
+ """
+
+ def_is_scene_setup_from_cfg(self):
+ returnany(
+ not(asset_nameinInteractiveSceneCfg.__dataclass_fields__orasset_cfgisNone)
+ forasset_name,asset_cfginself.cfg.__dict__.items()
+ )
+
+ def_add_entities_from_cfg(self):
+"""Add scene entities from the config."""
+ # store paths that are in global collision filter
+ self._global_prim_paths=list()
+ # parse the entire scene config and resolve regex
+ forasset_name,asset_cfginself.cfg.__dict__.items():
+ # skip keywords
+ # note: easier than writing a list of keywords: [num_envs, env_spacing, lazy_sensor_update]
+ ifasset_nameinInteractiveSceneCfg.__dataclass_fields__orasset_cfgisNone:
+ continue
+ # resolve regex
+ asset_cfg.prim_path=asset_cfg.prim_path.format(ENV_REGEX_NS=self.env_regex_ns)
+ # create asset
+ ifisinstance(asset_cfg,TerrainImporterCfg):
+ # terrains are special entities since they define environment origins
+ asset_cfg.num_envs=self.cfg.num_envs
+ asset_cfg.env_spacing=self.cfg.env_spacing
+ self._terrain=asset_cfg.class_type(asset_cfg)
+ elifisinstance(asset_cfg,ArticulationCfg):
+ self._articulations[asset_name]=asset_cfg.class_type(asset_cfg)
+ elifisinstance(asset_cfg,DeformableObjectCfg):
+ self._deformable_objects[asset_name]=asset_cfg.class_type(asset_cfg)
+ elifisinstance(asset_cfg,RigidObjectCfg):
+ self._rigid_objects[asset_name]=asset_cfg.class_type(asset_cfg)
+ elifisinstance(asset_cfg,SensorBaseCfg):
+ # Update target frame path(s)' regex name space for FrameTransformer
+ ifisinstance(asset_cfg,FrameTransformerCfg):
+ updated_target_frames=[]
+ fortarget_frameinasset_cfg.target_frames:
+ target_frame.prim_path=target_frame.prim_path.format(ENV_REGEX_NS=self.env_regex_ns)
+ updated_target_frames.append(target_frame)
+ asset_cfg.target_frames=updated_target_frames
+ elifisinstance(asset_cfg,ContactSensorCfg):
+ updated_filter_prim_paths_expr=[]
+ forfilter_prim_pathinasset_cfg.filter_prim_paths_expr:
+ updated_filter_prim_paths_expr.append(filter_prim_path.format(ENV_REGEX_NS=self.env_regex_ns))
+ asset_cfg.filter_prim_paths_expr=updated_filter_prim_paths_expr
+
+ self._sensors[asset_name]=asset_cfg.class_type(asset_cfg)
+ elifisinstance(asset_cfg,AssetBaseCfg):
+ # manually spawn asset
+ ifasset_cfg.spawnisnotNone:
+ asset_cfg.spawn.func(
+ asset_cfg.prim_path,
+ asset_cfg.spawn,
+ translation=asset_cfg.init_state.pos,
+ orientation=asset_cfg.init_state.rot,
+ )
+ # store xform prim view corresponding to this asset
+ # all prims in the scene are Xform prims (i.e. have a transform component)
+ self._extras[asset_name]=XFormPrimView(asset_cfg.prim_path,reset_xform_properties=False)
+ else:
+ raiseValueError(f"Unknown asset config type for {asset_name}: {asset_cfg}")
+ # store global collision paths
+ ifhasattr(asset_cfg,"collision_group")andasset_cfg.collision_group==-1:
+ asset_paths=sim_utils.find_matching_prim_paths(asset_cfg.prim_path)
+ self._global_prim_paths+=asset_paths
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.utils.configclassimportconfigclass
+
+
+
[文档]@configclass
+classInteractiveSceneCfg:
+"""Configuration for the interactive scene.
+
+ The users can inherit from this class to add entities to their scene. This is then parsed by the
+ :class:`InteractiveScene` class to create the scene.
+
+ .. note::
+ The adding of entities to the scene is sensitive to the order of the attributes in the configuration.
+ Please make sure to add the entities in the order you want them to be added to the scene.
+ The recommended order of specification is terrain, physics-related assets (articulations and rigid bodies),
+ sensors and non-physics-related assets (lights).
+
+ For example, to add a robot to the scene, the user can create a configuration class as follows:
+
+ .. code-block:: python
+
+ import omni.isaac.lab.sim as sim_utils
+ from omni.isaac.lab.assets import AssetBaseCfg
+ from omni.isaac.lab.scene import InteractiveSceneCfg
+ from omni.isaac.lab.sensors.ray_caster import GridPatternCfg, RayCasterCfg
+ from omni.isaac.lab.utils import configclass
+
+ from omni.isaac.lab_assets.anymal import ANYMAL_C_CFG
+
+ @configclass
+ class MySceneCfg(InteractiveSceneCfg):
+
+ # terrain - flat terrain plane
+ terrain = TerrainImporterCfg(
+ prim_path="/World/ground",
+ terrain_type="plane",
+ )
+
+ # articulation - robot 1
+ robot_1 = ANYMAL_C_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot_1")
+ # articulation - robot 2
+ robot_2 = ANYMAL_C_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot_2")
+ robot_2.init_state.pos = (0.0, 1.0, 0.6)
+
+ # sensor - ray caster attached to the base of robot 1 that scans the ground
+ height_scanner = RayCasterCfg(
+ prim_path="{ENV_REGEX_NS}/Robot_1/base",
+ offset=RayCasterCfg.OffsetCfg(pos=(0.0, 0.0, 20.0)),
+ attach_yaw_only=True,
+ pattern_cfg=GridPatternCfg(resolution=0.1, size=[1.6, 1.0]),
+ debug_vis=True,
+ mesh_prim_paths=["/World/ground"],
+ )
+
+ # extras - light
+ light = AssetBaseCfg(
+ prim_path="/World/light",
+ spawn=sim_utils.DistantLightCfg(intensity=3000.0, color=(0.75, 0.75, 0.75)),
+ init_state=AssetBaseCfg.InitialStateCfg(pos=(0.0, 0.0, 500.0)),
+ )
+
+ """
+
+ num_envs:int=MISSING
+"""Number of environment instances handled by the scene."""
+
+ env_spacing:float=MISSING
+"""Spacing between environments.
+
+ This is the default distance between environment origins in the scene. Used only when the
+ number of environments is greater than one.
+ """
+
+ lazy_sensor_update:bool=True
+"""Whether to update sensors only when they are accessed. Default is True.
+
+ If true, the sensor data is only updated when their attribute ``data`` is accessed. Otherwise, the sensor
+ data is updated every time sensors are updated.
+ """
+
+ replicate_physics:bool=True
+"""Enable/disable replication of physics schemas when using the Cloner APIs. Default is True.
+
+ If True, the simulation will have the same asset instances (USD prims) in all the cloned environments.
+ Internally, this ensures optimization in setting up the scene and parsing it via the physics stage parser.
+
+ If False, the simulation allows having separate asset instances (USD prims) in each environment.
+ This flexibility comes at a cost of slowdowns in setting up and parsing the scene.
+
+ .. note::
+ Optimized parsing of certain prim types (such as deformable objects) is not currently supported
+ by the physics engine. In these cases, this flag needs to be set to False.
+ """
[文档]classCamera(SensorBase):
+r"""The camera sensor for acquiring visual data.
+
+ This class wraps over the `UsdGeom Camera`_ for providing a consistent API for acquiring visual data.
+ It ensures that the camera follows the ROS convention for the coordinate system.
+
+ Summarizing from the `replicator extension`_, the following sensor types are supported:
+
+ - ``"rgb"``: A 3-channel rendered color image.
+ - ``"rgba"``: A 4-channel rendered color image with alpha channel.
+ - ``"distance_to_camera"``: An image containing the distance to camera optical center.
+ - ``"distance_to_image_plane"``: An image containing distances of 3D points from camera plane along camera's z-axis.
+ - ``"depth"``: The same as ``"distance_to_image_plane"``.
+ - ``"normals"``: An image containing the local surface normal vectors at each pixel.
+ - ``"motion_vectors"``: An image containing the motion vector data at each pixel.
+ - ``"semantic_segmentation"``: The semantic segmentation data.
+ - ``"instance_segmentation_fast"``: The instance segmentation data.
+ - ``"instance_id_segmentation_fast"``: The instance id segmentation data.
+
+ .. note::
+ Currently the following sensor types are not supported in a "view" format:
+
+ - ``"instance_segmentation"``: The instance segmentation data. Please use the fast counterparts instead.
+ - ``"instance_id_segmentation"``: The instance id segmentation data. Please use the fast counterparts instead.
+ - ``"bounding_box_2d_tight"``: The tight 2D bounding box data (only contains non-occluded regions).
+ - ``"bounding_box_2d_tight_fast"``: The tight 2D bounding box data (only contains non-occluded regions).
+ - ``"bounding_box_2d_loose"``: The loose 2D bounding box data (contains occluded regions).
+ - ``"bounding_box_2d_loose_fast"``: The loose 2D bounding box data (contains occluded regions).
+ - ``"bounding_box_3d"``: The 3D view space bounding box data.
+ - ``"bounding_box_3d_fast"``: The 3D view space bounding box data.
+
+ .. _replicator extension: https://docs.omniverse.nvidia.com/extensions/latest/ext_replicator/annotators_details.html#annotator-output
+ .. _USDGeom Camera: https://graphics.pixar.com/usd/docs/api/class_usd_geom_camera.html
+
+ """
+
+ cfg:CameraCfg
+"""The configuration parameters."""
+
+ UNSUPPORTED_TYPES:set[str]={
+ "instance_id_segmentation",
+ "instance_segmentation",
+ "bounding_box_2d_tight",
+ "bounding_box_2d_loose",
+ "bounding_box_3d",
+ "bounding_box_2d_tight_fast",
+ "bounding_box_2d_loose_fast",
+ "bounding_box_3d_fast",
+ }
+"""The set of sensor types that are not supported by the camera class."""
+
+
[文档]def__init__(self,cfg:CameraCfg):
+"""Initializes the camera sensor.
+
+ Args:
+ cfg: The configuration parameters.
+
+ Raises:
+ RuntimeError: If no camera prim is found at the given path.
+ ValueError: If the provided data types are not supported by the camera.
+ """
+ # check if sensor path is valid
+ # note: currently we do not handle environment indices if there is a regex pattern in the leaf
+ # For example, if the prim path is "/World/Sensor_[1,2]".
+ sensor_path=cfg.prim_path.split("/")[-1]
+ sensor_path_is_regex=re.match(r"^[a-zA-Z0-9/_]+$",sensor_path)isNone
+ ifsensor_path_is_regex:
+ raiseRuntimeError(
+ f"Invalid prim path for the camera sensor: {self.cfg.prim_path}."
+ "\n\tHint: Please ensure that the prim path does not contain any regex patterns in the leaf."
+ )
+ # perform check on supported data types
+ self._check_supported_data_types(cfg)
+ # initialize base class
+ super().__init__(cfg)
+
+ # toggle rendering of rtx sensors as True
+ # this flag is read by SimulationContext to determine if rtx sensors should be rendered
+ carb_settings_iface=carb.settings.get_settings()
+ carb_settings_iface.set_bool("/isaaclab/render/rtx_sensors",True)
+
+ # spawn the asset
+ ifself.cfg.spawnisnotNone:
+ # compute the rotation offset
+ rot=torch.tensor(self.cfg.offset.rot,dtype=torch.float32).unsqueeze(0)
+ rot_offset=convert_orientation_convention(rot,origin=self.cfg.offset.convention,target="opengl")
+ rot_offset=rot_offset.squeeze(0).numpy()
+ # ensure vertical aperture is set, otherwise replace with default for squared pixels
+ ifself.cfg.spawn.vertical_apertureisNone:
+ self.cfg.spawn.vertical_aperture=self.cfg.spawn.horizontal_aperture*self.cfg.height/self.cfg.width
+ # spawn the asset
+ self.cfg.spawn.func(
+ self.cfg.prim_path,self.cfg.spawn,translation=self.cfg.offset.pos,orientation=rot_offset
+ )
+ # check that spawn was successful
+ matching_prims=sim_utils.find_matching_prims(self.cfg.prim_path)
+ iflen(matching_prims)==0:
+ raiseRuntimeError(f"Could not find prim with path {self.cfg.prim_path}.")
+
+ # UsdGeom Camera prim for the sensor
+ self._sensor_prims:list[UsdGeom.Camera]=list()
+ # Create empty variables for storing output data
+ self._data=CameraData()
+
+ def__del__(self):
+"""Unsubscribes from callbacks and detach from the replicator registry."""
+ # unsubscribe callbacks
+ super().__del__()
+ # delete from replicator registry
+ for_,annotatorsinself._rep_registry.items():
+ forannotator,render_product_pathinzip(annotators,self._render_product_paths):
+ annotator.detach([render_product_path])
+ annotator=None
+
+ def__str__(self)->str:
+"""Returns: A string containing information about the instance."""
+ # message for class
+ return(
+ f"Camera @ '{self.cfg.prim_path}': \n"
+ f"\tdata types : {self.data.output.sorted_keys}\n"
+ f"\tsemantic filter : {self.cfg.semantic_filter}\n"
+ f"\tcolorize semantic segm. : {self.cfg.colorize_semantic_segmentation}\n"
+ f"\tcolorize instance segm. : {self.cfg.colorize_instance_segmentation}\n"
+ f"\tcolorize instance id segm.: {self.cfg.colorize_instance_id_segmentation}\n"
+ f"\tupdate period (s): {self.cfg.update_period}\n"
+ f"\tshape : {self.image_shape}\n"
+ f"\tnumber of sensors : {self._view.count}"
+ )
+
+"""
+ Properties
+ """
+
+ @property
+ defnum_instances(self)->int:
+ returnself._view.count
+
+ @property
+ defdata(self)->CameraData:
+ # update sensors if needed
+ self._update_outdated_buffers()
+ # return the data
+ returnself._data
+
+ @property
+ defframe(self)->torch.tensor:
+"""Frame number when the measurement took place."""
+ returnself._frame
+
+ @property
+ defrender_product_paths(self)->list[str]:
+"""The path of the render products for the cameras.
+
+ This can be used via replicator interfaces to attach to writes or external annotator registry.
+ """
+ returnself._render_product_paths
+
+ @property
+ defimage_shape(self)->tuple[int,int]:
+"""A tuple containing (height, width) of the camera sensor."""
+ return(self.cfg.height,self.cfg.width)
+
+"""
+ Configuration
+ """
+
+
[文档]defset_intrinsic_matrices(
+ self,matrices:torch.Tensor,focal_length:float=1.0,env_ids:Sequence[int]|None=None
+ ):
+"""Set parameters of the USD camera from its intrinsic matrix.
+
+ The intrinsic matrix and focal length are used to set the following parameters to the USD camera:
+
+ - ``focal_length``: The focal length of the camera.
+ - ``horizontal_aperture``: The horizontal aperture of the camera.
+ - ``vertical_aperture``: The vertical aperture of the camera.
+ - ``horizontal_aperture_offset``: The horizontal offset of the camera.
+ - ``vertical_aperture_offset``: The vertical offset of the camera.
+
+ .. warning::
+
+ Due to limitations of Omniverse camera, we need to assume that the camera is a spherical lens,
+ i.e. has square pixels, and the optical center is centered at the camera eye. If this assumption
+ is not true in the input intrinsic matrix, then the camera will not set up correctly.
+
+ Args:
+ matrices: The intrinsic matrices for the camera. Shape is (N, 3, 3).
+ focal_length: Focal length to use when computing aperture values (in cm). Defaults to 1.0.
+ env_ids: A sensor ids to manipulate. Defaults to None, which means all sensor indices.
+ """
+ # resolve env_ids
+ ifenv_idsisNone:
+ env_ids=self._ALL_INDICES
+ # convert matrices to numpy tensors
+ ifisinstance(matrices,torch.Tensor):
+ matrices=matrices.cpu().numpy()
+ else:
+ matrices=np.asarray(matrices,dtype=float)
+ # iterate over env_ids
+ fori,intrinsic_matrixinzip(env_ids,matrices):
+ # extract parameters from matrix
+ f_x=intrinsic_matrix[0,0]
+ c_x=intrinsic_matrix[0,2]
+ f_y=intrinsic_matrix[1,1]
+ c_y=intrinsic_matrix[1,2]
+ # get viewport parameters
+ height,width=self.image_shape
+ height,width=float(height),float(width)
+ # resolve parameters for usd camera
+ params={
+ "focal_length":focal_length,
+ "horizontal_aperture":width*focal_length/f_x,
+ "vertical_aperture":height*focal_length/f_y,
+ "horizontal_aperture_offset":(c_x-width/2)/f_x,
+ "vertical_aperture_offset":(c_y-height/2)/f_y,
+ }
+
+ # TODO: Adjust to handle aperture offsets once supported by omniverse
+ # Internal ticket from rendering team: OM-42611
+ ifparams["horizontal_aperture_offset"]>1e-4orparams["vertical_aperture_offset"]>1e-4:
+ carb.log_warn("Camera aperture offsets are not supported by Omniverse. These parameters are ignored.")
+
+ # change data for corresponding camera index
+ sensor_prim=self._sensor_prims[i]
+ # set parameters for camera
+ forparam_name,param_valueinparams.items():
+ # convert to camel case (CC)
+ param_name=to_camel_case(param_name,to="CC")
+ # get attribute from the class
+ param_attr=getattr(sensor_prim,f"Get{param_name}Attr")
+ # set value
+ # note: We have to do it this way because the camera might be on a different
+ # layer (default cameras are on session layer), and this is the simplest
+ # way to set the property on the right layer.
+ omni.usd.set_prop_val(param_attr(),param_value)
+ # update the internal buffers
+ self._update_intrinsic_matrices(env_ids)
+
+"""
+ Operations - Set pose.
+ """
+
+
[文档]defset_world_poses(
+ self,
+ positions:torch.Tensor|None=None,
+ orientations:torch.Tensor|None=None,
+ env_ids:Sequence[int]|None=None,
+ convention:Literal["opengl","ros","world"]="ros",
+ ):
+r"""Set the pose of the camera w.r.t. the world frame using specified convention.
+
+ Since different fields use different conventions for camera orientations, the method allows users to
+ set the camera poses in the specified convention. Possible conventions are:
+
+ - :obj:`"opengl"` - forward axis: -Z - up axis +Y - Offset is applied in the OpenGL (Usd.Camera) convention
+ - :obj:`"ros"` - forward axis: +Z - up axis -Y - Offset is applied in the ROS convention
+ - :obj:`"world"` - forward axis: +X - up axis +Z - Offset is applied in the World Frame convention
+
+ See :meth:`omni.isaac.lab.sensors.camera.utils.convert_orientation_convention` for more details
+ on the conventions.
+
+ Args:
+ positions: The cartesian coordinates (in meters). Shape is (N, 3).
+ Defaults to None, in which case the camera position in not changed.
+ orientations: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
+ Defaults to None, in which case the camera orientation in not changed.
+ env_ids: A sensor ids to manipulate. Defaults to None, which means all sensor indices.
+ convention: The convention in which the poses are fed. Defaults to "ros".
+
+ Raises:
+ RuntimeError: If the camera prim is not set. Need to call :meth:`initialize` method first.
+ """
+ # resolve env_ids
+ ifenv_idsisNone:
+ env_ids=self._ALL_INDICES
+ # convert to backend tensor
+ ifpositionsisnotNone:
+ ifisinstance(positions,np.ndarray):
+ positions=torch.from_numpy(positions).to(device=self._device)
+ elifnotisinstance(positions,torch.Tensor):
+ positions=torch.tensor(positions,device=self._device)
+ # convert rotation matrix from input convention to OpenGL
+ iforientationsisnotNone:
+ ifisinstance(orientations,np.ndarray):
+ orientations=torch.from_numpy(orientations).to(device=self._device)
+ elifnotisinstance(orientations,torch.Tensor):
+ orientations=torch.tensor(orientations,device=self._device)
+ orientations=convert_orientation_convention(orientations,origin=convention,target="opengl")
+ # set the pose
+ self._view.set_world_poses(positions,orientations,env_ids)
+
+
[文档]defset_world_poses_from_view(
+ self,eyes:torch.Tensor,targets:torch.Tensor,env_ids:Sequence[int]|None=None
+ ):
+"""Set the poses of the camera from the eye position and look-at target position.
+
+ Args:
+ eyes: The positions of the camera's eye. Shape is (N, 3).
+ targets: The target locations to look at. Shape is (N, 3).
+ env_ids: A sensor ids to manipulate. Defaults to None, which means all sensor indices.
+
+ Raises:
+ RuntimeError: If the camera prim is not set. Need to call :meth:`initialize` method first.
+ NotImplementedError: If the stage up-axis is not "Y" or "Z".
+ """
+ # resolve env_ids
+ ifenv_idsisNone:
+ env_ids=self._ALL_INDICES
+ # set camera poses using the view
+ orientations=quat_from_matrix(create_rotation_matrix_from_view(eyes,targets,device=self._device))
+ self._view.set_world_poses(eyes,orientations,env_ids)
+
+"""
+ Operations
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+ ifnotself._is_initialized:
+ raiseRuntimeError(
+ "Camera could not be initialized. Please ensure --enable_cameras is used to enable rendering."
+ )
+ # reset the timestamps
+ super().reset(env_ids)
+ # resolve None
+ # note: cannot do smart indexing here since we do a for loop over data.
+ ifenv_idsisNone:
+ env_ids=self._ALL_INDICES
+ # reset the data
+ # note: this recomputation is useful if one performs events such as randomizations on the camera poses.
+ self._update_poses(env_ids)
+ # Reset the frame count
+ self._frame[env_ids]=0
+
+"""
+ Implementation.
+ """
+
+ def_initialize_impl(self):
+"""Initializes the sensor handles and internal buffers.
+
+ This function creates handles and registers the provided data types with the replicator registry to
+ be able to access the data from the sensor. It also initializes the internal buffers to store the data.
+
+ Raises:
+ RuntimeError: If the number of camera prims in the view does not match the number of environments.
+ RuntimeError: If replicator was not found.
+ """
+ carb_settings_iface=carb.settings.get_settings()
+ ifnotcarb_settings_iface.get("/isaaclab/cameras_enabled"):
+ raiseRuntimeError(
+ "A camera was spawned without the --enable_cameras flag. Please use --enable_cameras to enable"
+ " rendering."
+ )
+
+ importomni.replicator.coreasrep
+ fromomni.syntheticdata.scripts.SyntheticDataimportSyntheticData
+
+ # Initialize parent class
+ super()._initialize_impl()
+ # Create a view for the sensor
+ self._view=XFormPrimView(self.cfg.prim_path,reset_xform_properties=False)
+ self._view.initialize()
+ # Check that sizes are correct
+ ifself._view.count!=self._num_envs:
+ raiseRuntimeError(
+ f"Number of camera prims in the view ({self._view.count}) does not match"
+ f" the number of environments ({self._num_envs})."
+ )
+
+ # WAR: use DLAA antialiasing to avoid frame offset issue at small resolutions
+ ifself.cfg.width<265orself.cfg.height<265:
+ rep.settings.set_render_rtx_realtime(antialiasing="DLAA")
+
+ # Create all env_ids buffer
+ self._ALL_INDICES=torch.arange(self._view.count,device=self._device,dtype=torch.long)
+ # Create frame count buffer
+ self._frame=torch.zeros(self._view.count,device=self._device,dtype=torch.long)
+
+ # Attach the sensor data types to render node
+ self._render_product_paths:list[str]=list()
+ self._rep_registry:dict[str,list[rep.annotators.Annotator]]={name:list()fornameinself.cfg.data_types}
+
+ # Obtain current stage
+ stage=omni.usd.get_context().get_stage()
+ # Convert all encapsulated prims to Camera
+ forcam_prim_pathinself._view.prim_paths:
+ # Get camera prim
+ cam_prim=stage.GetPrimAtPath(cam_prim_path)
+ # Check if prim is a camera
+ ifnotcam_prim.IsA(UsdGeom.Camera):
+ raiseRuntimeError(f"Prim at path '{cam_prim_path}' is not a Camera.")
+ # Add to list
+ sensor_prim=UsdGeom.Camera(cam_prim)
+ self._sensor_prims.append(sensor_prim)
+
+ # Get render product
+ # From Isaac Sim 2023.1 onwards, render product is a HydraTexture so we need to extract the path
+ render_prod_path=rep.create.render_product(cam_prim_path,resolution=(self.cfg.width,self.cfg.height))
+ ifnotisinstance(render_prod_path,str):
+ render_prod_path=render_prod_path.path
+ self._render_product_paths.append(render_prod_path)
+
+ # Check if semantic types or semantic filter predicate is provided
+ ifisinstance(self.cfg.semantic_filter,list):
+ semantic_filter_predicate=":*; ".join(self.cfg.semantic_filter)+":*"
+ elifisinstance(self.cfg.semantic_filter,str):
+ semantic_filter_predicate=self.cfg.semantic_filter
+ else:
+ raiseValueError(f"Semantic types must be a list or a string. Received: {self.cfg.semantic_filter}.")
+ # set the semantic filter predicate
+ # copied from rep.scripts.writes_default.basic_writer.py
+ SyntheticData.Get().set_instance_mapping_semantic_filter(semantic_filter_predicate)
+
+ # Iterate over each data type and create annotator
+ # TODO: This will move out of the loop once Replicator supports multiple render products within a single
+ # annotator, i.e.: rep_annotator.attach(self._render_product_paths)
+ fornameinself.cfg.data_types:
+ # note: we are verbose here to make it easier to understand the code.
+ # if colorize is true, the data is mapped to colors and a uint8 4 channel image is returned.
+ # if colorize is false, the data is returned as a uint32 image with ids as values.
+ ifname=="semantic_segmentation":
+ init_params={"colorize":self.cfg.colorize_semantic_segmentation}
+ elifname=="instance_segmentation_fast":
+ init_params={"colorize":self.cfg.colorize_instance_segmentation}
+ elifname=="instance_id_segmentation_fast":
+ init_params={"colorize":self.cfg.colorize_instance_id_segmentation}
+ else:
+ init_params=None
+
+ # Resolve device name
+ if"cuda"inself._device:
+ device_name=self._device.split(":")[0]
+ else:
+ device_name="cpu"
+
+ # Map special cases to their corresponding annotator names
+ special_cases={"rgba":"rgb","depth":"distance_to_image_plane"}
+ # Get the annotator name, falling back to the original name if not a special case
+ annotator_name=special_cases.get(name,name)
+ # Create the annotator node
+ rep_annotator=rep.AnnotatorRegistry.get_annotator(annotator_name,init_params,device=device_name)
+
+ # attach annotator to render product
+ rep_annotator.attach(render_prod_path)
+ # add to registry
+ self._rep_registry[name].append(rep_annotator)
+
+ # Create internal buffers
+ self._create_buffers()
+ self._update_intrinsic_matrices(self._ALL_INDICES)
+
+ def_update_buffers_impl(self,env_ids:Sequence[int]):
+ # Increment frame count
+ self._frame[env_ids]+=1
+ # -- pose
+ self._update_poses(env_ids)
+ # -- read the data from annotator registry
+ # check if buffer is called for the first time. If so then, allocate the memory
+ iflen(self._data.output.sorted_keys)==0:
+ # this is the first time buffer is called
+ # it allocates memory for all the sensors
+ self._create_annotator_data()
+ else:
+ # iterate over all the data types
+ forname,annotatorsinself._rep_registry.items():
+ # iterate over all the annotators
+ forindexinenv_ids:
+ # get the output
+ output=annotators[index].get_data()
+ # process the output
+ data,info=self._process_annotator_output(name,output)
+ # add data to output
+ self._data.output[name][index]=data
+ # add info to output
+ self._data.info[index][name]=info
+
+"""
+ Private Helpers
+ """
+
+ def_check_supported_data_types(self,cfg:CameraCfg):
+"""Checks if the data types are supported by the ray-caster camera."""
+ # check if there is any intersection in unsupported types
+ # reason: these use np structured data types which we can't yet convert to torch tensor
+ common_elements=set(cfg.data_types)&Camera.UNSUPPORTED_TYPES
+ ifcommon_elements:
+ # provide alternative fast counterparts
+ fast_common_elements=[]
+ foritemincommon_elements:
+ if"instance_segmentation"initemor"instance_id_segmentation"initem:
+ fast_common_elements.append(item+"_fast")
+ # raise error
+ raiseValueError(
+ f"Camera class does not support the following sensor types: {common_elements}."
+ "\n\tThis is because these sensor types output numpy structured data types which"
+ "can't be converted to torch tensors easily."
+ "\n\tHint: If you need to work with these sensor types, we recommend using their fast counterparts."
+ f"\n\t\tFast counterparts: {fast_common_elements}"
+ )
+
+ def_create_buffers(self):
+"""Create buffers for storing data."""
+ # create the data object
+ # -- pose of the cameras
+ self._data.pos_w=torch.zeros((self._view.count,3),device=self._device)
+ self._data.quat_w_world=torch.zeros((self._view.count,4),device=self._device)
+ # -- intrinsic matrix
+ self._data.intrinsic_matrices=torch.zeros((self._view.count,3,3),device=self._device)
+ self._data.image_shape=self.image_shape
+ # -- output data
+ # lazy allocation of data dictionary
+ # since the size of the output data is not known in advance, we leave it as None
+ # the memory will be allocated when the buffer() function is called for the first time.
+ self._data.output=TensorDict({},batch_size=self._view.count,device=self.device)
+ self._data.info=[{name:Nonefornameinself.cfg.data_types}for_inrange(self._view.count)]
+
+ def_update_intrinsic_matrices(self,env_ids:Sequence[int]):
+"""Compute camera's matrix of intrinsic parameters.
+
+ Also called calibration matrix. This matrix works for linear depth images. We assume square pixels.
+
+ Note:
+ The calibration matrix projects points in the 3D scene onto an imaginary screen of the camera.
+ The coordinates of points on the image plane are in the homogeneous representation.
+ """
+ # iterate over all cameras
+ foriinenv_ids:
+ # Get corresponding sensor prim
+ sensor_prim=self._sensor_prims[i]
+ # get camera parameters
+ focal_length=sensor_prim.GetFocalLengthAttr().Get()
+ horiz_aperture=sensor_prim.GetHorizontalApertureAttr().Get()
+ vert_aperture=sensor_prim.GetVerticalApertureAttr().Get()
+ horiz_aperture_offset=sensor_prim.GetHorizontalApertureOffsetAttr().Get()
+ vert_aperture_offset=sensor_prim.GetVerticalApertureOffsetAttr().Get()
+ # get viewport parameters
+ height,width=self.image_shape
+ # extract intrinsic parameters
+ f_x=(width*focal_length)/horiz_aperture
+ f_y=(height*focal_length)/vert_aperture
+ c_x=width*0.5+horiz_aperture_offset*f_x
+ c_y=height*0.5+vert_aperture_offset*f_y
+ # create intrinsic matrix for depth linear
+ self._data.intrinsic_matrices[i,0,0]=f_x
+ self._data.intrinsic_matrices[i,0,2]=c_x
+ self._data.intrinsic_matrices[i,1,1]=f_y
+ self._data.intrinsic_matrices[i,1,2]=c_y
+ self._data.intrinsic_matrices[i,2,2]=1
+
+ def_update_poses(self,env_ids:Sequence[int]):
+"""Computes the pose of the camera in the world frame with ROS convention.
+
+ This methods uses the ROS convention to resolve the input pose. In this convention,
+ we assume that the camera front-axis is +Z-axis and up-axis is -Y-axis.
+
+ Returns:
+ A tuple of the position (in meters) and quaternion (w, x, y, z).
+ """
+ # check camera prim exists
+ iflen(self._sensor_prims)==0:
+ raiseRuntimeError("Camera prim is None. Please call 'sim.play()' first.")
+
+ # get the poses from the view
+ poses,quat=self._view.get_world_poses(env_ids)
+ self._data.pos_w[env_ids]=poses
+ self._data.quat_w_world[env_ids]=convert_orientation_convention(quat,origin="opengl",target="world")
+
+ def_create_annotator_data(self):
+"""Create the buffers to store the annotator data.
+
+ We create a buffer for each annotator and store the data in a dictionary. Since the data
+ shape is not known beforehand, we create a list of buffers and concatenate them later.
+
+ This is an expensive operation and should be called only once.
+ """
+ # add data from the annotators
+ forname,annotatorsinself._rep_registry.items():
+ # create a list to store the data for each annotator
+ data_all_cameras=list()
+ # iterate over all the annotators
+ forindexinself._ALL_INDICES:
+ # get the output
+ output=annotators[index].get_data()
+ # process the output
+ data,info=self._process_annotator_output(name,output)
+ # append the data
+ data_all_cameras.append(data)
+ # store the info
+ self._data.info[index][name]=info
+ # concatenate the data along the batch dimension
+ self._data.output[name]=torch.stack(data_all_cameras,dim=0)
+
+ def_process_annotator_output(self,name:str,output:Any)->tuple[torch.tensor,dict|None]:
+"""Process the annotator output.
+
+ This function is called after the data has been collected from all the cameras.
+ """
+ # extract info and data from the output
+ ifisinstance(output,dict):
+ data=output["data"]
+ info=output["info"]
+ else:
+ data=output
+ info=None
+ # convert data into torch tensor
+ data=convert_to_torch(data,device=self.device)
+
+ # process data for different segmentation types
+ # Note: Replicator returns raw buffers of dtype int32 for segmentation types
+ # so we need to convert them to uint8 4 channel images for colorized types
+ height,width=self.image_shape
+ ifname=="semantic_segmentation":
+ ifself.cfg.colorize_semantic_segmentation:
+ data=data.view(torch.uint8).reshape(height,width,-1)
+ else:
+ data=data.view(height,width,1)
+ elifname=="instance_segmentation_fast":
+ ifself.cfg.colorize_instance_segmentation:
+ data=data.view(torch.uint8).reshape(height,width,-1)
+ else:
+ data=data.view(height,width,1)
+ elifname=="instance_id_segmentation_fast":
+ ifself.cfg.colorize_instance_id_segmentation:
+ data=data.view(torch.uint8).reshape(height,width,-1)
+ else:
+ data=data.view(height,width,1)
+ # make sure buffer dimensions are consistent as (H, W, C)
+ elifname=="distance_to_camera"orname=="distance_to_image_plane"orname=="depth":
+ data=data.view(height,width,1)
+ # we only return the RGB channels from the RGBA output if rgb is required
+ # normals return (x, y, z) in first 3 channels, 4th channel is unused
+ elifname=="rgb"orname=="normals":
+ data=data[...,:3]
+ # motion vectors return (x, y) in first 2 channels, 3rd and 4th channels are unused
+ elifname=="motion_vectors":
+ data=data[...,:2]
+
+ # return the data and info
+ returndata,info
+
+"""
+ Internal simulation callbacks.
+ """
+
+ def_invalidate_initialize_callback(self,event):
+"""Invalidates the scene elements."""
+ # call parent
+ super()._invalidate_initialize_callback(event)
+ # set all existing views to None to invalidate them
+ self._view=None
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.simimportFisheyeCameraCfg,PinholeCameraCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+from..sensor_base_cfgimportSensorBaseCfg
+from.cameraimportCamera
+
+
+
[文档]@configclass
+classCameraCfg(SensorBaseCfg):
+"""Configuration for a camera sensor."""
+
+
[文档]@configclass
+ classOffsetCfg:
+"""The offset pose of the sensor's frame from the sensor's parent frame."""
+
+ pos:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Translation w.r.t. the parent frame. Defaults to (0.0, 0.0, 0.0)."""
+
+ rot:tuple[float,float,float,float]=(1.0,0.0,0.0,0.0)
+"""Quaternion rotation (w, x, y, z) w.r.t. the parent frame. Defaults to (1.0, 0.0, 0.0, 0.0)."""
+
+ convention:Literal["opengl","ros","world"]="ros"
+"""The convention in which the frame offset is applied. Defaults to "ros".
+
+ - ``"opengl"`` - forward axis: ``-Z`` - up axis: ``+Y`` - Offset is applied in the OpenGL (Usd.Camera) convention.
+ - ``"ros"`` - forward axis: ``+Z`` - up axis: ``-Y`` - Offset is applied in the ROS convention.
+ - ``"world"`` - forward axis: ``+X`` - up axis: ``+Z`` - Offset is applied in the World Frame convention.
+
+ """
+
+ class_type:type=Camera
+
+ offset:OffsetCfg=OffsetCfg()
+"""The offset pose of the sensor's frame from the sensor's parent frame. Defaults to identity.
+
+ Note:
+ The parent frame is the frame the sensor attaches to. For example, the parent frame of a
+ camera at path ``/World/envs/env_0/Robot/Camera`` is ``/World/envs/env_0/Robot``.
+ """
+
+ spawn:PinholeCameraCfg|FisheyeCameraCfg|None=MISSING
+"""Spawn configuration for the asset.
+
+ If None, then the prim is not spawned by the asset. Instead, it is assumed that the
+ asset is already present in the scene.
+ """
+
+ data_types:list[str]=["rgb"]
+"""List of sensor names/types to enable for the camera. Defaults to ["rgb"].
+
+ Please refer to the :class:`Camera` class for a list of available data types.
+ """
+
+ width:int=MISSING
+"""Width of the image in pixels."""
+
+ height:int=MISSING
+"""Height of the image in pixels."""
+
+ semantic_filter:str|list[str]="*:*"
+"""A string or a list specifying a semantic filter predicate. Defaults to ``"*:*"``.
+
+ If a string, it should be a disjunctive normal form of (semantic type, labels). For examples:
+
+ * ``"typeA : labelA & !labelB | labelC , typeB: labelA ; typeC: labelE"``:
+ All prims with semantic type "typeA" and label "labelA" but not "labelB" or with label "labelC".
+ Also, all prims with semantic type "typeB" and label "labelA", or with semantic type "typeC" and label "labelE".
+ * ``"typeA : * ; * : labelA"``: All prims with semantic type "typeA" or with label "labelA"
+
+ If a list of strings, each string should be a semantic type. The segmentation for prims with
+ semantics of the specified types will be retrieved. For example, if the list is ["class"], only
+ the segmentation for prims with semantics of type "class" will be retrieved.
+
+ .. seealso::
+
+ For more information on the semantics filter, see the documentation on `Replicator Semantics Schema Editor`_.
+
+ .. _Replicator Semantics Schema Editor: https://docs.omniverse.nvidia.com/extensions/latest/ext_replicator/semantics_schema_editor.html#semantics-filtering
+ """
+
+ colorize_semantic_segmentation:bool=True
+"""Whether to colorize the semantic segmentation images. Defaults to True.
+
+ If True, semantic segmentation is converted to an image where semantic IDs are mapped to colors
+ and returned as a ``uint8`` 4-channel array. If False, the output is returned as a ``int32`` array.
+ """
+
+ colorize_instance_id_segmentation:bool=True
+"""Whether to colorize the instance ID segmentation images. Defaults to True.
+
+ If True, instance id segmentation is converted to an image where instance IDs are mapped to colors.
+ and returned as a ``uint8`` 4-channel array. If False, the output is returned as a ``int32`` array.
+ """
+
+ colorize_instance_segmentation:bool=True
+"""Whether to colorize the instance ID segmentation images. Defaults to True.
+
+ If True, instance segmentation is converted to an image where instance IDs are mapped to colors.
+ and returned as a ``uint8`` 4-channel array. If False, the output is returned as a ``int32`` array.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+fromdataclassesimportdataclass
+fromtensordictimportTensorDict
+fromtypingimportAny
+
+from.utilsimportconvert_orientation_convention
+
+
+
[文档]@dataclass
+classCameraData:
+"""Data container for the camera sensor."""
+
+ ##
+ # Frame state.
+ ##
+
+ pos_w:torch.Tensor=None
+"""Position of the sensor origin in world frame, following ROS convention.
+
+ Shape is (N, 3) where N is the number of sensors.
+ """
+
+ quat_w_world:torch.Tensor=None
+"""Quaternion orientation `(w, x, y, z)` of the sensor origin in world frame, following the world coordinate frame
+
+ .. note::
+ World frame convention follows the camera aligned with forward axis +X and up axis +Z.
+
+ Shape is (N, 4) where N is the number of sensors.
+ """
+
+ ##
+ # Camera data
+ ##
+
+ image_shape:tuple[int,int]=None
+"""A tuple containing (height, width) of the camera sensor."""
+
+ intrinsic_matrices:torch.Tensor=None
+"""The intrinsic matrices for the camera.
+
+ Shape is (N, 3, 3) where N is the number of sensors.
+ """
+
+ output:TensorDict=None
+"""The retrieved sensor data with sensor types as key.
+
+ The format of the data is available in the `Replicator Documentation`_. For semantic-based data,
+ this corresponds to the ``"data"`` key in the output of the sensor.
+
+ .. _Replicator Documentation: https://docs.omniverse.nvidia.com/prod_extensions/prod_extensions/ext_replicator/annotators_details.html#annotator-output
+ """
+
+ info:list[dict[str,Any]]=None
+"""The retrieved sensor info with sensor types as key.
+
+ This contains extra information provided by the sensor such as semantic segmentation label mapping, prim paths.
+ For semantic-based data, this corresponds to the ``"info"`` key in the output of the sensor. For other sensor
+ types, the info is empty.
+ """
+
+ ##
+ # Additional Frame orientation conventions
+ ##
+
+ @property
+ defquat_w_ros(self)->torch.Tensor:
+"""Quaternion orientation `(w, x, y, z)` of the sensor origin in the world frame, following ROS convention.
+
+ .. note::
+ ROS convention follows the camera aligned with forward axis +Z and up axis -Y.
+
+ Shape is (N, 4) where N is the number of sensors.
+ """
+ returnconvert_orientation_convention(self.quat_w_world,origin="world",target="ros")
+
+ @property
+ defquat_w_opengl(self)->torch.Tensor:
+"""Quaternion orientation `(w, x, y, z)` of the sensor origin in the world frame, following
+ Opengl / USD Camera convention.
+
+ .. note::
+ OpenGL convention follows the camera aligned with forward axis -Z and up axis +Y.
+
+ Shape is (N, 4) where N is the number of sensors.
+ """
+ returnconvert_orientation_convention(self.quat_w_world,origin="world",target="opengl")
[文档]classTiledCamera(Camera):
+r"""The tiled rendering based camera sensor for acquiring the same data as the Camera class.
+
+ This class inherits from the :class:`Camera` class but uses the tiled-rendering API to acquire
+ the visual data. Tiled-rendering concatenates the rendered images from multiple cameras into a single image.
+ This allows for rendering multiple cameras in parallel and is useful for rendering large scenes with multiple
+ cameras efficiently.
+
+ The following sensor types are supported:
+
+ - ``"rgb"``: A 3-channel rendered color image.
+ - ``"rgba"``: A 4-channel rendered color image with alpha channel.
+ - ``"distance_to_camera"``: An image containing the distance to camera optical center.
+ - ``"distance_to_image_plane"``: An image containing distances of 3D points from camera plane along camera's z-axis.
+ - ``"depth"``: Alias for ``"distance_to_image_plane"``.
+ - ``"normals"``: An image containing the local surface normal vectors at each pixel.
+ - ``"motion_vectors"``: An image containing the motion vector data at each pixel.
+ - ``"semantic_segmentation"``: The semantic segmentation data.
+ - ``"instance_segmentation_fast"``: The instance segmentation data.
+ - ``"instance_id_segmentation_fast"``: The instance id segmentation data.
+
+ .. note::
+ Currently the following sensor types are not supported in a "view" format:
+
+ - ``"instance_segmentation"``: The instance segmentation data. Please use the fast counterparts instead.
+ - ``"instance_id_segmentation"``: The instance id segmentation data. Please use the fast counterparts instead.
+ - ``"bounding_box_2d_tight"``: The tight 2D bounding box data (only contains non-occluded regions).
+ - ``"bounding_box_2d_tight_fast"``: The tight 2D bounding box data (only contains non-occluded regions).
+ - ``"bounding_box_2d_loose"``: The loose 2D bounding box data (contains occluded regions).
+ - ``"bounding_box_2d_loose_fast"``: The loose 2D bounding box data (contains occluded regions).
+ - ``"bounding_box_3d"``: The 3D view space bounding box data.
+ - ``"bounding_box_3d_fast"``: The 3D view space bounding box data.
+
+ .. _replicator extension: https://docs.omniverse.nvidia.com/extensions/latest/ext_replicator/annotators_details.html#annotator-output
+ .. _USDGeom Camera: https://graphics.pixar.com/usd/docs/api/class_usd_geom_camera.html
+
+ .. versionadded:: v1.0.0
+
+ This feature is available starting from Isaac Sim 4.2. Before this version, the tiled rendering APIs
+ were not available.
+
+ """
+
+ cfg:TiledCameraCfg
+"""The configuration parameters."""
+
+
[文档]def__init__(self,cfg:TiledCameraCfg):
+"""Initializes the tiled camera sensor.
+
+ Args:
+ cfg: The configuration parameters.
+
+ Raises:
+ RuntimeError: If no camera prim is found at the given path.
+ RuntimeError: If Isaac Sim version < 4.2
+ ValueError: If the provided data types are not supported by the camera.
+ """
+ isaac_sim_version=float(".".join(get_version()[2:4]))
+ ifisaac_sim_version<4.2:
+ raiseRuntimeError(
+ f"TiledCamera is only available from Isaac Sim 4.2.0. Current version is {isaac_sim_version}. Please"
+ " update to Isaac Sim 4.2.0"
+ )
+ super().__init__(cfg)
+
+ def__del__(self):
+"""Unsubscribes from callbacks and detach from the replicator registry."""
+ # unsubscribe from callbacks
+ SensorBase.__del__(self)
+ # detach from the replicator registry
+ forannotatorinself._annotators.values():
+ annotator.detach(self.render_product_paths)
+
+ def__str__(self)->str:
+"""Returns: A string containing information about the instance."""
+ # message for class
+ return(
+ f"Tiled Camera @ '{self.cfg.prim_path}': \n"
+ f"\tdata types : {self.data.output.sorted_keys}\n"
+ f"\tsemantic filter : {self.cfg.semantic_filter}\n"
+ f"\tcolorize semantic segm. : {self.cfg.colorize_semantic_segmentation}\n"
+ f"\tcolorize instance segm. : {self.cfg.colorize_instance_segmentation}\n"
+ f"\tcolorize instance id segm.: {self.cfg.colorize_instance_id_segmentation}\n"
+ f"\tupdate period (s): {self.cfg.update_period}\n"
+ f"\tshape : {self.image_shape}\n"
+ f"\tnumber of sensors : {self._view.count}"
+ )
+
+"""
+ Operations
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+ ifnotself._is_initialized:
+ raiseRuntimeError(
+ "TiledCamera could not be initialized. Please ensure --enable_cameras is used to enable rendering."
+ )
+ # reset the timestamps
+ SensorBase.reset(self,env_ids)
+ # resolve None
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # reset the frame count
+ self._frame[env_ids]=0
+
+"""
+ Implementation.
+ """
+
+ def_initialize_impl(self):
+"""Initializes the sensor handles and internal buffers.
+
+ This function creates handles and registers the provided data types with the replicator registry to
+ be able to access the data from the sensor. It also initializes the internal buffers to store the data.
+
+ Raises:
+ RuntimeError: If the number of camera prims in the view does not match the number of environments.
+ RuntimeError: If replicator was not found.
+ """
+ carb_settings_iface=carb.settings.get_settings()
+ ifnotcarb_settings_iface.get("/isaaclab/cameras_enabled"):
+ raiseRuntimeError(
+ "A camera was spawned without the --enable_cameras flag. Please use --enable_cameras to enable"
+ " rendering."
+ )
+
+ importomni.replicator.coreasrep
+
+ # Initialize parent class
+ SensorBase._initialize_impl(self)
+ # Create a view for the sensor
+ self._view=XFormPrimView(self.cfg.prim_path,reset_xform_properties=False)
+ self._view.initialize()
+ # Check that sizes are correct
+ ifself._view.count!=self._num_envs:
+ raiseRuntimeError(
+ f"Number of camera prims in the view ({self._view.count}) does not match"
+ f" the number of environments ({self._num_envs})."
+ )
+
+ # Create all env_ids buffer
+ self._ALL_INDICES=torch.arange(self._view.count,device=self._device,dtype=torch.long)
+ # Create frame count buffer
+ self._frame=torch.zeros(self._view.count,device=self._device,dtype=torch.long)
+
+ # Obtain current stage
+ stage=omni.usd.get_context().get_stage()
+ # Convert all encapsulated prims to Camera
+ forcam_prim_pathinself._view.prim_paths:
+ # Get camera prim
+ cam_prim=stage.GetPrimAtPath(cam_prim_path)
+ # Check if prim is a camera
+ ifnotcam_prim.IsA(UsdGeom.Camera):
+ raiseRuntimeError(f"Prim at path '{cam_prim_path}' is not a Camera.")
+ # Add to list
+ sensor_prim=UsdGeom.Camera(cam_prim)
+ self._sensor_prims.append(sensor_prim)
+
+ # Create replicator tiled render product
+ rp=rep.create.render_product_tiled(
+ cameras=self._view.prim_paths,tile_resolution=(self.cfg.width,self.cfg.height)
+ )
+ self._render_product_paths=[rp.path]
+
+ # WAR: use DLAA antialiasing to avoid frame offset issue at small resolutions
+ ifself._tiling_grid_shape()[0]*self.cfg.width<265orself._tiling_grid_shape()[1]*self.cfg.height<265:
+ rep.settings.set_render_rtx_realtime(antialiasing="DLAA")
+
+ # Define the annotators based on requested data types
+ self._annotators=dict()
+ forannotator_typeinself.cfg.data_types:
+ ifannotator_type=="rgba"orannotator_type=="rgb":
+ annotator=rep.AnnotatorRegistry.get_annotator("rgb",device=self.device,do_array_copy=False)
+ self._annotators["rgba"]=annotator
+ elifannotator_type=="depth"orannotator_type=="distance_to_image_plane":
+ # keep depth for backwards compatibility
+ annotator=rep.AnnotatorRegistry.get_annotator(
+ "distance_to_image_plane",device=self.device,do_array_copy=False
+ )
+ self._annotators[annotator_type]=annotator
+ # note: we are verbose here to make it easier to understand the code.
+ # if colorize is true, the data is mapped to colors and a uint8 4 channel image is returned.
+ # if colorize is false, the data is returned as a uint32 image with ids as values.
+ else:
+ init_params=None
+ ifannotator_type=="semantic_segmentation":
+ init_params={"colorize":self.cfg.colorize_semantic_segmentation}
+ elifannotator_type=="instance_segmentation_fast":
+ init_params={"colorize":self.cfg.colorize_instance_segmentation}
+ elifannotator_type=="instance_id_segmentation_fast":
+ init_params={"colorize":self.cfg.colorize_instance_id_segmentation}
+
+ annotator=rep.AnnotatorRegistry.get_annotator(
+ annotator_type,init_params,device=self.device,do_array_copy=False
+ )
+ self._annotators[annotator_type]=annotator
+
+ # Attach the annotator to the render product
+ forannotatorinself._annotators.values():
+ annotator.attach(self._render_product_paths)
+
+ # Create internal buffers
+ self._create_buffers()
+
+ def_update_buffers_impl(self,env_ids:Sequence[int]):
+ # Increment frame count
+ self._frame[env_ids]+=1
+
+ # Extract the flattened image buffer
+ fordata_type,annotatorinself._annotators.items():
+ # check whether returned data is a dict (used for segmentation)
+ output=annotator.get_data()
+ ifisinstance(output,dict):
+ tiled_data_buffer=output["data"]
+ self._data.info[data_type]=output["info"]
+ else:
+ tiled_data_buffer=output
+
+ # convert data buffer to warp array
+ ifisinstance(tiled_data_buffer,np.ndarray):
+ tiled_data_buffer=wp.array(tiled_data_buffer,device=self.device,dtype=wp.uint8)
+ else:
+ tiled_data_buffer=tiled_data_buffer.to(device=self.device)
+
+ # process data for different segmentation types
+ # Note: Replicator returns raw buffers of dtype uint32 for segmentation types
+ # so we need to convert them to uint8 4 channel images for colorized types
+ if(
+ (data_type=="semantic_segmentation"andself.cfg.colorize_semantic_segmentation)
+ or(data_type=="instance_segmentation_fast"andself.cfg.colorize_instance_segmentation)
+ or(data_type=="instance_id_segmentation_fast"andself.cfg.colorize_instance_id_segmentation)
+ ):
+ tiled_data_buffer=wp.array(
+ ptr=tiled_data_buffer.ptr,shape=(*tiled_data_buffer.shape,4),dtype=wp.uint8,device=self.device
+ )
+
+ wp.launch(
+ kernel=reshape_tiled_image,
+ dim=(self._view.count,self.cfg.height,self.cfg.width),
+ inputs=[
+ tiled_data_buffer.flatten(),
+ wp.from_torch(self._data.output[data_type]),# zero-copy alias
+ *list(self._data.output[data_type].shape[1:]),# height, width, num_channels
+ self._tiling_grid_shape()[0],# num_tiles_x
+ ],
+ device=self.device,
+ )
+
+ # alias rgb as first 3 channels of rgba
+ ifdata_type=="rgba"and"rgb"inself.cfg.data_types:
+ self._data.output["rgb"]=self._data.output["rgba"][...,:3]
+
+"""
+ Private Helpers
+ """
+
+ def_check_supported_data_types(self,cfg:TiledCameraCfg):
+"""Checks if the data types are supported by the ray-caster camera."""
+ # check if there is any intersection in unsupported types
+ # reason: these use np structured data types which we can't yet convert to torch tensor
+ common_elements=set(cfg.data_types)&Camera.UNSUPPORTED_TYPES
+ ifcommon_elements:
+ # provide alternative fast counterparts
+ fast_common_elements=[]
+ foritemincommon_elements:
+ if"instance_segmentation"initemor"instance_id_segmentation"initem:
+ fast_common_elements.append(item+"_fast")
+ # raise error
+ raiseValueError(
+ f"TiledCamera class does not support the following sensor types: {common_elements}."
+ "\n\tThis is because these sensor types output numpy structured data types which"
+ "can't be converted to torch tensors easily."
+ "\n\tHint: If you need to work with these sensor types, we recommend using their fast counterparts."
+ f"\n\t\tFast counterparts: {fast_common_elements}"
+ )
+
+ def_create_buffers(self):
+"""Create buffers for storing data."""
+ # create the data object
+ # -- pose of the cameras
+ self._data.pos_w=torch.zeros((self._view.count,3),device=self._device)
+ self._data.quat_w_world=torch.zeros((self._view.count,4),device=self._device)
+ self._update_poses(self._ALL_INDICES)
+ # -- intrinsic matrix
+ self._data.intrinsic_matrices=torch.zeros((self._view.count,3,3),device=self._device)
+ self._update_intrinsic_matrices(self._ALL_INDICES)
+ self._data.image_shape=self.image_shape
+ # -- output data
+ data_dict=dict()
+ if"rgba"inself.cfg.data_typesor"rgb"inself.cfg.data_types:
+ data_dict["rgba"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,4),device=self.device,dtype=torch.uint8
+ ).contiguous()
+ if"rgb"inself.cfg.data_types:
+ # RGB is the first 3 channels of RGBA
+ data_dict["rgb"]=data_dict["rgba"][...,:3]
+ if"distance_to_image_plane"inself.cfg.data_types:
+ data_dict["distance_to_image_plane"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,1),device=self.device,dtype=torch.float32
+ ).contiguous()
+ if"depth"inself.cfg.data_types:
+ data_dict["depth"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,1),device=self.device,dtype=torch.float32
+ ).contiguous()
+ if"distance_to_camera"inself.cfg.data_types:
+ data_dict["distance_to_camera"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,1),device=self.device,dtype=torch.float32
+ ).contiguous()
+ if"normals"inself.cfg.data_types:
+ data_dict["normals"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,3),device=self.device,dtype=torch.float32
+ ).contiguous()
+ if"motion_vectors"inself.cfg.data_types:
+ data_dict["motion_vectors"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,2),device=self.device,dtype=torch.float32
+ ).contiguous()
+ if"semantic_segmentation"inself.cfg.data_types:
+ ifself.cfg.colorize_semantic_segmentation:
+ data_dict["semantic_segmentation"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,4),device=self.device,dtype=torch.uint8
+ ).contiguous()
+ else:
+ data_dict["semantic_segmentation"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,1),device=self.device,dtype=torch.int32
+ ).contiguous()
+ if"instance_segmentation_fast"inself.cfg.data_types:
+ ifself.cfg.colorize_instance_segmentation:
+ data_dict["instance_segmentation_fast"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,4),device=self.device,dtype=torch.uint8
+ ).contiguous()
+ else:
+ data_dict["instance_segmentation_fast"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,1),device=self.device,dtype=torch.int32
+ ).contiguous()
+ if"instance_id_segmentation_fast"inself.cfg.data_types:
+ ifself.cfg.colorize_instance_id_segmentation:
+ data_dict["instance_id_segmentation_fast"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,4),device=self.device,dtype=torch.uint8
+ ).contiguous()
+ else:
+ data_dict["instance_id_segmentation_fast"]=torch.zeros(
+ (self._view.count,self.cfg.height,self.cfg.width,1),device=self.device,dtype=torch.int32
+ ).contiguous()
+
+ self._data.output=TensorDict(data_dict,batch_size=self._view.count,device=self.device)
+ self._data.info=dict()
+
+ def_tiled_image_shape(self)->tuple[int,int]:
+"""Returns a tuple containing the dimension of the tiled image."""
+ cols,rows=self._tiling_grid_shape()
+ return(self.cfg.width*cols,self.cfg.height*rows)
+
+ def_tiling_grid_shape(self)->tuple[int,int]:
+"""Returns a tuple containing the tiling grid dimension."""
+ cols=math.ceil(math.sqrt(self._view.count))
+ rows=math.ceil(self._view.count/cols)
+ return(cols,rows)
+
+ def_create_annotator_data(self):
+ # we do not need to create annotator data for the tiled camera sensor
+ raiseRuntimeError("This function should not be called for the tiled camera sensor.")
+
+ def_process_annotator_output(self,name:str,output:Any)->tuple[torch.tensor,dict|None]:
+ # we do not need to process annotator output for the tiled camera sensor
+ raiseRuntimeError("This function should not be called for the tiled camera sensor.")
+
+"""
+ Internal simulation callbacks.
+ """
+
+ def_invalidate_initialize_callback(self,event):
+"""Invalidates the scene elements."""
+ # call parent
+ super()._invalidate_initialize_callback(event)
+ # set all existing views to None to invalidate them
+ self._view=None
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.camera_cfgimportCameraCfg
+from.tiled_cameraimportTiledCamera
+
+
+
[文档]@configclass
+classTiledCameraCfg(CameraCfg):
+"""Configuration for a tiled rendering-based camera sensor."""
+
+ class_type:type=TiledCamera
+
+ return_latest_camera_pose:bool=False
+"""Whether to return the latest camera pose when fetching the camera's data. Defaults to False.
+
+ If True, the latest camera pose is returned in the camera's data which will slow down performance
+ due to the use of :class:`XformPrimView`.
+ If False, the pose of the camera during initialization is returned.
+ """
[文档]classContactSensor(SensorBase):
+"""A contact reporting sensor.
+
+ The contact sensor reports the normal contact forces on a rigid body in the world frame.
+ It relies on the `PhysX ContactReporter`_ API to be activated on the rigid bodies.
+
+ To enable the contact reporter on a rigid body, please make sure to enable the
+ :attr:`omni.isaac.lab.sim.spawner.RigidObjectSpawnerCfg.activate_contact_sensors` on your
+ asset spawner configuration. This will enable the contact reporter on all the rigid bodies
+ in the asset.
+
+ The sensor can be configured to report the contact forces on a set of bodies with a given
+ filter pattern using the :attr:`ContactSensorCfg.filter_prim_paths_expr`. This is useful
+ when you want to report the contact forces between the sensor bodies and a specific set of
+ bodies in the scene. The data can be accessed using the :attr:`ContactSensorData.force_matrix_w`.
+ Please check the documentation on `RigidContactView`_ for more details.
+
+ The reporting of the filtered contact forces is only possible as one-to-many. This means that only one
+ sensor body in an environment can be filtered against multiple bodies in that environment. If you need to
+ filter multiple sensor bodies against multiple bodies, you need to create separate sensors for each sensor
+ body.
+
+ As an example, suppose you want to report the contact forces for all the feet of a robot against an object
+ exclusively. In that case, setting the :attr:`ContactSensorCfg.prim_path` and
+ :attr:`ContactSensorCfg.filter_prim_paths_expr` with ``{ENV_REGEX_NS}/Robot/.*_FOOT`` and ``{ENV_REGEX_NS}/Object``
+ respectively will not work. Instead, you need to create a separate sensor for each foot and filter
+ it against the object.
+
+ .. _PhysX ContactReporter: https://docs.omniverse.nvidia.com/kit/docs/omni_usd_schema_physics/104.2/class_physx_schema_physx_contact_report_a_p_i.html
+ .. _RigidContactView: https://docs.omniverse.nvidia.com/py/isaacsim/source/extensions/omni.isaac.core/docs/index.html#omni.isaac.core.prims.RigidContactView
+ """
+
+ cfg:ContactSensorCfg
+"""The configuration parameters."""
+
+
[文档]def__init__(self,cfg:ContactSensorCfg):
+"""Initializes the contact sensor object.
+
+ Args:
+ cfg: The configuration parameters.
+ """
+ # initialize base class
+ super().__init__(cfg)
+ # Create empty variables for storing output data
+ self._data:ContactSensorData=ContactSensorData()
+ # initialize self._body_physx_view for running in extension mode
+ self._body_physx_view=None
+
+ def__str__(self)->str:
+"""Returns: A string containing information about the instance."""
+ return(
+ f"Contact sensor @ '{self.cfg.prim_path}': \n"
+ f"\tview type : {self.body_physx_view.__class__}\n"
+ f"\tupdate period (s) : {self.cfg.update_period}\n"
+ f"\tnumber of bodies : {self.num_bodies}\n"
+ f"\tbody names : {self.body_names}\n"
+ )
+
+"""
+ Properties
+ """
+
+ @property
+ defnum_instances(self)->int:
+ returnself.body_physx_view.count
+
+ @property
+ defdata(self)->ContactSensorData:
+ # update sensors if needed
+ self._update_outdated_buffers()
+ # return the data
+ returnself._data
+
+ @property
+ defnum_bodies(self)->int:
+"""Number of bodies with contact sensors attached."""
+ returnself._num_bodies
+
+ @property
+ defbody_names(self)->list[str]:
+"""Ordered names of bodies with contact sensors attached."""
+ prim_paths=self.body_physx_view.prim_paths[:self.num_bodies]
+ return[path.split("/")[-1]forpathinprim_paths]
+
+ @property
+ defbody_physx_view(self)->physx.RigidBodyView:
+"""View for the rigid bodies captured (PhysX).
+
+ Note:
+ Use this view with caution. It requires handling of tensors in a specific way.
+ """
+ returnself._body_physx_view
+
+ @property
+ defcontact_physx_view(self)->physx.RigidContactView:
+"""Contact reporter view for the bodies (PhysX).
+
+ Note:
+ Use this view with caution. It requires handling of tensors in a specific way.
+ """
+ returnself._contact_physx_view
+
+"""
+ Operations
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+ # reset the timers and counters
+ super().reset(env_ids)
+ # resolve None
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # reset accumulative data buffers
+ self._data.net_forces_w[env_ids]=0.0
+ self._data.net_forces_w_history[env_ids]=0.0
+ ifself.cfg.history_length>0:
+ self._data.net_forces_w_history[env_ids]=0.0
+ # reset force matrix
+ iflen(self.cfg.filter_prim_paths_expr)!=0:
+ self._data.force_matrix_w[env_ids]=0.0
+ # reset the current air time
+ ifself.cfg.track_air_time:
+ self._data.current_air_time[env_ids]=0.0
+ self._data.last_air_time[env_ids]=0.0
+ self._data.current_contact_time[env_ids]=0.0
+ self._data.last_contact_time[env_ids]=0.0
+
+
[文档]deffind_bodies(self,name_keys:str|Sequence[str],preserve_order:bool=False)->tuple[list[int],list[str]]:
+"""Find bodies in the articulation based on the name keys.
+
+ Args:
+ name_keys: A regular expression or a list of regular expressions to match the body names.
+ preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False.
+
+ Returns:
+ A tuple of lists containing the body indices and names.
+ """
+ returnstring_utils.resolve_matching_names(name_keys,self.body_names,preserve_order)
+
+
[文档]defcompute_first_contact(self,dt:float,abs_tol:float=1.0e-8)->torch.Tensor:
+"""Checks if bodies that have established contact within the last :attr:`dt` seconds.
+
+ This function checks if the bodies have established contact within the last :attr:`dt` seconds
+ by comparing the current contact time with the given time period. If the contact time is less
+ than the given time period, then the bodies are considered to be in contact.
+
+ Note:
+ The function assumes that :attr:`dt` is a factor of the sensor update time-step. In other
+ words :math:`dt / dt_sensor = n`, where :math:`n` is a natural number. This is always true
+ if the sensor is updated by the physics or the environment stepping time-step and the sensor
+ is read by the environment stepping time-step.
+
+ Args:
+ dt: The time period since the contact was established.
+ abs_tol: The absolute tolerance for the comparison.
+
+ Returns:
+ A boolean tensor indicating the bodies that have established contact within the last
+ :attr:`dt` seconds. Shape is (N, B), where N is the number of sensors and B is the
+ number of bodies in each sensor.
+
+ Raises:
+ RuntimeError: If the sensor is not configured to track contact time.
+ """
+ # check if the sensor is configured to track contact time
+ ifnotself.cfg.track_air_time:
+ raiseRuntimeError(
+ "The contact sensor is not configured to track contact time."
+ "Please enable the 'track_air_time' in the sensor configuration."
+ )
+ # check if the bodies are in contact
+ currently_in_contact=self.data.current_contact_time>0.0
+ less_than_dt_in_contact=self.data.current_contact_time<(dt+abs_tol)
+ returncurrently_in_contact*less_than_dt_in_contact
+
+
[文档]defcompute_first_air(self,dt:float,abs_tol:float=1.0e-8)->torch.Tensor:
+"""Checks if bodies that have broken contact within the last :attr:`dt` seconds.
+
+ This function checks if the bodies have broken contact within the last :attr:`dt` seconds
+ by comparing the current air time with the given time period. If the air time is less
+ than the given time period, then the bodies are considered to not be in contact.
+
+ Note:
+ It assumes that :attr:`dt` is a factor of the sensor update time-step. In other words,
+ :math:`dt / dt_sensor = n`, where :math:`n` is a natural number. This is always true if
+ the sensor is updated by the physics or the environment stepping time-step and the sensor
+ is read by the environment stepping time-step.
+
+ Args:
+ dt: The time period since the contract is broken.
+ abs_tol: The absolute tolerance for the comparison.
+
+ Returns:
+ A boolean tensor indicating the bodies that have broken contact within the last :attr:`dt` seconds.
+ Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
+
+ Raises:
+ RuntimeError: If the sensor is not configured to track contact time.
+ """
+ # check if the sensor is configured to track contact time
+ ifnotself.cfg.track_air_time:
+ raiseRuntimeError(
+ "The contact sensor is not configured to track contact time."
+ "Please enable the 'track_air_time' in the sensor configuration."
+ )
+ # check if the sensor is configured to track contact time
+ currently_detached=self.data.current_air_time>0.0
+ less_than_dt_detached=self.data.current_air_time<(dt+abs_tol)
+ returncurrently_detached*less_than_dt_detached
+
+"""
+ Implementation.
+ """
+
+ def_initialize_impl(self):
+ super()._initialize_impl()
+ # create simulation view
+ self._physics_sim_view=physx.create_simulation_view(self._backend)
+ self._physics_sim_view.set_subspace_roots("/")
+ # check that only rigid bodies are selected
+ leaf_pattern=self.cfg.prim_path.rsplit("/",1)[-1]
+ template_prim_path=self._parent_prims[0].GetPath().pathString
+ body_names=list()
+ forpriminsim_utils.find_matching_prims(template_prim_path+"/"+leaf_pattern):
+ # check if prim has contact reporter API
+ ifprim.HasAPI(PhysxSchema.PhysxContactReportAPI):
+ prim_path=prim.GetPath().pathString
+ body_names.append(prim_path.rsplit("/",1)[-1])
+ # check that there is at least one body with contact reporter API
+ ifnotbody_names:
+ raiseRuntimeError(
+ f"Sensor at path '{self.cfg.prim_path}' could not find any bodies with contact reporter API."
+ "\nHINT: Make sure to enable 'activate_contact_sensors' in the corresponding asset spawn configuration."
+ )
+
+ # construct regex expression for the body names
+ body_names_regex=r"("+"|".join(body_names)+r")"
+ body_names_regex=f"{self.cfg.prim_path.rsplit('/',1)[0]}/{body_names_regex}"
+ # convert regex expressions to glob expressions for PhysX
+ body_names_glob=body_names_regex.replace(".*","*")
+ filter_prim_paths_glob=[expr.replace(".*","*")forexprinself.cfg.filter_prim_paths_expr]
+
+ # create a rigid prim view for the sensor
+ self._body_physx_view=self._physics_sim_view.create_rigid_body_view(body_names_glob)
+ self._contact_physx_view=self._physics_sim_view.create_rigid_contact_view(
+ body_names_glob,filter_patterns=filter_prim_paths_glob
+ )
+ # resolve the true count of bodies
+ self._num_bodies=self.body_physx_view.count//self._num_envs
+ # check that contact reporter succeeded
+ ifself._num_bodies!=len(body_names):
+ raiseRuntimeError(
+ "Failed to initialize contact reporter for specified bodies."
+ f"\n\tInput prim path : {self.cfg.prim_path}"
+ f"\n\tResolved prim paths: {body_names_regex}"
+ )
+
+ # prepare data buffers
+ self._data.net_forces_w=torch.zeros(self._num_envs,self._num_bodies,3,device=self._device)
+ # optional buffers
+ # -- history of net forces
+ ifself.cfg.history_length>0:
+ self._data.net_forces_w_history=torch.zeros(
+ self._num_envs,self.cfg.history_length,self._num_bodies,3,device=self._device
+ )
+ else:
+ self._data.net_forces_w_history=self._data.net_forces_w.unsqueeze(1)
+ # -- pose of sensor origins
+ ifself.cfg.track_pose:
+ self._data.pos_w=torch.zeros(self._num_envs,self._num_bodies,3,device=self._device)
+ self._data.quat_w=torch.zeros(self._num_envs,self._num_bodies,4,device=self._device)
+ # -- air/contact time between contacts
+ ifself.cfg.track_air_time:
+ self._data.last_air_time=torch.zeros(self._num_envs,self._num_bodies,device=self._device)
+ self._data.current_air_time=torch.zeros(self._num_envs,self._num_bodies,device=self._device)
+ self._data.last_contact_time=torch.zeros(self._num_envs,self._num_bodies,device=self._device)
+ self._data.current_contact_time=torch.zeros(self._num_envs,self._num_bodies,device=self._device)
+ # force matrix: (num_envs, num_bodies, num_filter_shapes, 3)
+ iflen(self.cfg.filter_prim_paths_expr)!=0:
+ num_filters=self.contact_physx_view.filter_count
+ self._data.force_matrix_w=torch.zeros(
+ self._num_envs,self._num_bodies,num_filters,3,device=self._device
+ )
+
+ def_update_buffers_impl(self,env_ids:Sequence[int]):
+"""Fills the buffers of the sensor data."""
+ # default to all sensors
+ iflen(env_ids)==self._num_envs:
+ env_ids=slice(None)
+
+ # obtain the contact forces
+ # TODO: We are handling the indexing ourself because of the shape; (N, B) vs expected (N * B).
+ # This isn't the most efficient way to do this, but it's the easiest to implement.
+ net_forces_w=self.contact_physx_view.get_net_contact_forces(dt=self._sim_physics_dt)
+ self._data.net_forces_w[env_ids,:,:]=net_forces_w.view(-1,self._num_bodies,3)[env_ids]
+ # update contact force history
+ ifself.cfg.history_length>0:
+ self._data.net_forces_w_history[env_ids,1:]=self._data.net_forces_w_history[env_ids,:-1].clone()
+ self._data.net_forces_w_history[env_ids,0]=self._data.net_forces_w[env_ids]
+
+ # obtain the contact force matrix
+ iflen(self.cfg.filter_prim_paths_expr)!=0:
+ # shape of the filtering matrix: (num_envs, num_bodies, num_filter_shapes, 3)
+ num_filters=self.contact_physx_view.filter_count
+ # acquire and shape the force matrix
+ force_matrix_w=self.contact_physx_view.get_contact_force_matrix(dt=self._sim_physics_dt)
+ force_matrix_w=force_matrix_w.view(-1,self._num_bodies,num_filters,3)
+ self._data.force_matrix_w[env_ids]=force_matrix_w[env_ids]
+
+ # obtain the pose of the sensor origin
+ ifself.cfg.track_pose:
+ pose=self.body_physx_view.get_transforms().view(-1,self._num_bodies,7)[env_ids]
+ pose[...,3:]=convert_quat(pose[...,3:],to="wxyz")
+ self._data.pos_w[env_ids],self._data.quat_w[env_ids]=pose.split([3,4],dim=-1)
+
+ # obtain the air time
+ ifself.cfg.track_air_time:
+ # -- time elapsed since last update
+ # since this function is called every frame, we can use the difference to get the elapsed time
+ elapsed_time=self._timestamp[env_ids]-self._timestamp_last_update[env_ids]
+ # -- check contact state of bodies
+ is_contact=torch.norm(self._data.net_forces_w[env_ids,:,:],dim=-1)>self.cfg.force_threshold
+ is_first_contact=(self._data.current_air_time[env_ids]>0)*is_contact
+ is_first_detached=(self._data.current_contact_time[env_ids]>0)*~is_contact
+ # -- update the last contact time if body has just become in contact
+ self._data.last_air_time[env_ids]=torch.where(
+ is_first_contact,
+ self._data.current_air_time[env_ids]+elapsed_time.unsqueeze(-1),
+ self._data.last_air_time[env_ids],
+ )
+ # -- increment time for bodies that are not in contact
+ self._data.current_air_time[env_ids]=torch.where(
+ ~is_contact,self._data.current_air_time[env_ids]+elapsed_time.unsqueeze(-1),0.0
+ )
+ # -- update the last contact time if body has just detached
+ self._data.last_contact_time[env_ids]=torch.where(
+ is_first_detached,
+ self._data.current_contact_time[env_ids]+elapsed_time.unsqueeze(-1),
+ self._data.last_contact_time[env_ids],
+ )
+ # -- increment time for bodies that are in contact
+ self._data.current_contact_time[env_ids]=torch.where(
+ is_contact,self._data.current_contact_time[env_ids]+elapsed_time.unsqueeze(-1),0.0
+ )
+
+ def_set_debug_vis_impl(self,debug_vis:bool):
+ # set visibility of markers
+ # note: parent only deals with callbacks. not their visibility
+ ifdebug_vis:
+ # create markers if necessary for the first tome
+ ifnothasattr(self,"contact_visualizer"):
+ self.contact_visualizer=VisualizationMarkers(self.cfg.visualizer_cfg)
+ # set their visibility to true
+ self.contact_visualizer.set_visibility(True)
+ else:
+ ifhasattr(self,"contact_visualizer"):
+ self.contact_visualizer.set_visibility(False)
+
+ def_debug_vis_callback(self,event):
+ # safely return if view becomes invalid
+ # note: this invalidity happens because of isaac sim view callbacks
+ ifself.body_physx_viewisNone:
+ return
+ # marker indices
+ # 0: contact, 1: no contact
+ net_contact_force_w=torch.norm(self._data.net_forces_w,dim=-1)
+ marker_indices=torch.where(net_contact_force_w>self.cfg.force_threshold,0,1)
+ # check if prim is visualized
+ ifself.cfg.track_pose:
+ frame_origins:torch.Tensor=self._data.pos_w
+ else:
+ pose=self.body_physx_view.get_transforms()
+ frame_origins=pose.view(-1,self._num_bodies,7)[:,:,:3]
+ # visualize
+ self.contact_visualizer.visualize(frame_origins.view(-1,3),marker_indices=marker_indices.view(-1))
+
+"""
+ Internal simulation callbacks.
+ """
+
+ def_invalidate_initialize_callback(self,event):
+"""Invalidates the scene elements."""
+ # call parent
+ super()._invalidate_initialize_callback(event)
+ # set all existing views to None to invalidate them
+ self._physics_sim_view=None
+ self._body_physx_view=None
+ self._contact_physx_view=None
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromomni.isaac.lab.markersimportVisualizationMarkersCfg
+fromomni.isaac.lab.markers.configimportCONTACT_SENSOR_MARKER_CFG
+fromomni.isaac.lab.utilsimportconfigclass
+
+from..sensor_base_cfgimportSensorBaseCfg
+from.contact_sensorimportContactSensor
+
+
+
[文档]@configclass
+classContactSensorCfg(SensorBaseCfg):
+"""Configuration for the contact sensor."""
+
+ class_type:type=ContactSensor
+
+ track_pose:bool=False
+"""Whether to track the pose of the sensor's origin. Defaults to False."""
+
+ track_air_time:bool=False
+"""Whether to track the air/contact time of the bodies (time between contacts). Defaults to False."""
+
+ force_threshold:float=1.0
+"""The threshold on the norm of the contact force that determines whether two bodies are in collision or not.
+
+ This value is only used for tracking the mode duration (the time in contact or in air),
+ if :attr:`track_air_time` is True.
+ """
+
+ filter_prim_paths_expr:list[str]=list()
+"""The list of primitive paths (or expressions) to filter contacts with. Defaults to an empty list, in which case
+ no filtering is applied.
+
+ The contact sensor allows reporting contacts between the primitive specified with :attr:`prim_path` and
+ other primitives in the scene. For instance, in a scene containing a robot, a ground plane and an object,
+ you can obtain individual contact reports of the base of the robot with the ground plane and the object.
+
+ .. note::
+ The expression in the list can contain the environment namespace regex ``{ENV_REGEX_NS}`` which
+ will be replaced with the environment namespace.
+
+ Example: ``{ENV_REGEX_NS}/Object`` will be replaced with ``/World/envs/env_.*/Object``.
+
+ .. attention::
+ The reporting of filtered contacts only works when the sensor primitive :attr:`prim_path` corresponds to a
+ single primitive in that environment. If the sensor primitive corresponds to multiple primitives, the
+ filtering will not work as expected. Please check :class:`~omni.isaac.lab.sensors.contact_sensor.ContactSensor`
+ for more details.
+ """
+
+ visualizer_cfg:VisualizationMarkersCfg=CONTACT_SENSOR_MARKER_CFG.replace(prim_path="/Visuals/ContactSensor")
+"""The configuration object for the visualization markers. Defaults to CONTACT_SENSOR_MARKER_CFG.
+
+ .. note::
+ This attribute is only used when debug visualization is enabled.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+# needed to import for allowing type-hinting: torch.Tensor | None
+from__future__importannotations
+
+importtorch
+fromdataclassesimportdataclass
+
+
+
[文档]@dataclass
+classContactSensorData:
+"""Data container for the contact reporting sensor."""
+
+ pos_w:torch.Tensor|None=None
+"""Position of the sensor origin in world frame.
+
+ Shape is (N, 3), where N is the number of sensors.
+
+ Note:
+ If the :attr:`ContactSensorCfg.track_pose` is False, then this quantity is None.
+ """
+
+ quat_w:torch.Tensor|None=None
+"""Orientation of the sensor origin in quaternion (w, x, y, z) in world frame.
+
+ Shape is (N, 4), where N is the number of sensors.
+
+ Note:
+ If the :attr:`ContactSensorCfg.track_pose` is False, then this quantity is None.
+ """
+
+ net_forces_w:torch.Tensor|None=None
+"""The net normal contact forces in world frame.
+
+ Shape is (N, B, 3), where N is the number of sensors and B is the number of bodies in each sensor.
+
+ Note:
+ This quantity is the sum of the normal contact forces acting on the sensor bodies. It must not be confused
+ with the total contact forces acting on the sensor bodies (which also includes the tangential forces).
+ """
+
+ net_forces_w_history:torch.Tensor|None=None
+"""The net normal contact forces in world frame.
+
+ Shape is (N, T, B, 3), where N is the number of sensors, T is the configured history length
+ and B is the number of bodies in each sensor.
+
+ In the history dimension, the first index is the most recent and the last index is the oldest.
+
+ Note:
+ This quantity is the sum of the normal contact forces acting on the sensor bodies. It must not be confused
+ with the total contact forces acting on the sensor bodies (which also includes the tangential forces).
+ """
+
+ force_matrix_w:torch.Tensor|None=None
+"""The normal contact forces filtered between the sensor bodies and filtered bodies in world frame.
+
+ Shape is (N, B, M, 3), where N is the number of sensors, B is number of bodies in each sensor
+ and ``M`` is the number of filtered bodies.
+
+ Note:
+ If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None.
+ """
+
+ last_air_time:torch.Tensor|None=None
+"""Time spent (in s) in the air before the last contact.
+
+ Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
+
+ Note:
+ If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
+ """
+
+ current_air_time:torch.Tensor|None=None
+"""Time spent (in s) in the air since the last detach.
+
+ Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
+
+ Note:
+ If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
+ """
+
+ last_contact_time:torch.Tensor|None=None
+"""Time spent (in s) in contact before the last detach.
+
+ Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
+
+ Note:
+ If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
+ """
+
+ current_contact_time:torch.Tensor|None=None
+"""Time spent (in s) in contact since the last contact.
+
+ Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
+
+ Note:
+ If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
+ """
[文档]classFrameTransformer(SensorBase):
+"""A sensor for reporting frame transforms.
+
+ This class provides an interface for reporting the transform of one or more frames (target frames)
+ with respect to another frame (source frame). The source frame is specified by the user as a prim path
+ (:attr:`FrameTransformerCfg.prim_path`) and the target frames are specified by the user as a list of
+ prim paths (:attr:`FrameTransformerCfg.target_frames`).
+
+ The source frame and target frames are assumed to be rigid bodies. The transform of the target frames
+ with respect to the source frame is computed by first extracting the transform of the source frame
+ and target frames from the physics engine and then computing the relative transform between the two.
+
+ Additionally, the user can specify an offset for the source frame and each target frame. This is useful
+ for specifying the transform of the desired frame with respect to the body's center of mass, for instance.
+
+ A common example of using this sensor is to track the position and orientation of the end effector of a
+ robotic manipulator. In this case, the source frame would be the body corresponding to the base frame of the
+ manipulator, and the target frame would be the body corresponding to the end effector. Since the end-effector is
+ typically a fictitious body, the user may need to specify an offset from the end-effector to the body of the
+ manipulator.
+
+ .. note::
+
+ Currently, this implementation only handles frames within an articulation. This is because the frame
+ regex expressions are resolved based on their parent prim path. This can be extended to handle
+ frames outside of articulation by using the frame prim path instead. However, this would require
+ additional checks to ensure that the user-specified frames are valid which is not currently implemented.
+
+ .. warning::
+
+ The implementation assumes that the parent body of a target frame is not the same as that
+ of the source frame (i.e. :attr:`FrameTransformerCfg.prim_path`). While a corner case, this can occur
+ if the user specifies the same prim path for both the source frame and target frame. In this case,
+ the target frame will be ignored and not reported. This is a limitation of the current implementation
+ and will be fixed in a future release.
+
+ """
+
+ cfg:FrameTransformerCfg
+"""The configuration parameters."""
+
+
[文档]def__init__(self,cfg:FrameTransformerCfg):
+"""Initializes the frame transformer object.
+
+ Args:
+ cfg: The configuration parameters.
+ """
+ # initialize base class
+ super().__init__(cfg)
+ # Create empty variables for storing output data
+ self._data:FrameTransformerData=FrameTransformerData()
+
+ def__str__(self)->str:
+"""Returns: A string containing information about the instance."""
+ return(
+ f"FrameTransformer @ '{self.cfg.prim_path}': \n"
+ f"\ttracked body frames: {[self._source_frame_body_name]+self._target_frame_body_names}\n"
+ f"\tnumber of envs: {self._num_envs}\n"
+ f"\tsource body frame: {self._source_frame_body_name}\n"
+ f"\ttarget frames (count: {self._target_frame_names}): {len(self._target_frame_names)}\n"
+ )
+
+"""
+ Properties
+ """
+
+ @property
+ defdata(self)->FrameTransformerData:
+ # update sensors if needed
+ self._update_outdated_buffers()
+ # return the data
+ returnself._data
+
+"""
+ Operations
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+ # reset the timers and counters
+ super().reset(env_ids)
+ # resolve None
+ ifenv_idsisNone:
+ env_ids=...
+
+"""
+ Implementation.
+ """
+
+ def_initialize_impl(self):
+ super()._initialize_impl()
+
+ # resolve source frame offset
+ source_frame_offset_pos=torch.tensor(self.cfg.source_frame_offset.pos,device=self.device)
+ source_frame_offset_quat=torch.tensor(self.cfg.source_frame_offset.rot,device=self.device)
+ # Only need to perform offsetting of source frame if the position offsets is non-zero and rotation offset is
+ # not the identity quaternion for efficiency in _update_buffer_impl
+ self._apply_source_frame_offset=True
+ # Handle source frame offsets
+ ifis_identity_pose(source_frame_offset_pos,source_frame_offset_quat):
+ carb.log_verbose(f"No offset application needed for source frame as it is identity: {self.cfg.prim_path}")
+ self._apply_source_frame_offset=False
+ else:
+ carb.log_verbose(f"Applying offset to source frame as it is not identity: {self.cfg.prim_path}")
+ # Store offsets as tensors (duplicating each env's offsets for ease of multiplication later)
+ self._source_frame_offset_pos=source_frame_offset_pos.unsqueeze(0).repeat(self._num_envs,1)
+ self._source_frame_offset_quat=source_frame_offset_quat.unsqueeze(0).repeat(self._num_envs,1)
+
+ # Keep track of mapping from the rigid body name to the desired frame, as there may be multiple frames
+ # based upon the same body name and we don't want to create unnecessary views
+ body_names_to_frames:dict[str,set[str]]={}
+ # The offsets associated with each target frame
+ target_offsets:dict[str,dict[str,torch.Tensor]]={}
+ # The frames whose offsets are not identity
+ non_identity_offset_frames:list[str]=[]
+
+ # Only need to perform offsetting of target frame if any of the position offsets are non-zero or any of the
+ # rotation offsets are not the identity quaternion for efficiency in _update_buffer_impl
+ self._apply_target_frame_offset=False
+
+ # Collect all target frames, their associated body prim paths and their offsets so that we can extract
+ # the prim, check that it has the appropriate rigid body API in a single loop.
+ # First element is None because user can't specify source frame name
+ frames=[None]+[target_frame.namefortarget_frameinself.cfg.target_frames]
+ frame_prim_paths=[self.cfg.prim_path]+[target_frame.prim_pathfortarget_frameinself.cfg.target_frames]
+ # First element is None because source frame offset is handled separately
+ frame_offsets=[None]+[target_frame.offsetfortarget_frameinself.cfg.target_frames]
+ forframe,prim_path,offsetinzip(frames,frame_prim_paths,frame_offsets):
+ # Find correct prim
+ matching_prims=sim_utils.find_matching_prims(prim_path)
+ iflen(matching_prims)==0:
+ raiseValueError(
+ f"Failed to create frame transformer for frame '{frame}' with path '{prim_path}'."
+ " No matching prims were found."
+ )
+ forpriminmatching_prims:
+ # Get the prim path of the matching prim
+ matching_prim_path=prim.GetPath().pathString
+ # Check if it is a rigid prim
+ ifnotprim.HasAPI(UsdPhysics.RigidBodyAPI):
+ raiseValueError(
+ f"While resolving expression '{prim_path}' found a prim '{matching_prim_path}' which is not a"
+ " rigid body. The class only supports transformations between rigid bodies."
+ )
+
+ # Get the name of the body
+ body_name=matching_prim_path.rsplit("/",1)[-1]
+ # Use body name if frame isn't specified by user
+ frame_name=frameifframeisnotNoneelsebody_name
+
+ # Keep track of which frames are associated with which bodies
+ ifbody_nameinbody_names_to_frames:
+ body_names_to_frames[body_name].add(frame_name)
+ else:
+ body_names_to_frames[body_name]={frame_name}
+
+ ifoffsetisnotNone:
+ offset_pos=torch.tensor(offset.pos,device=self.device)
+ offset_quat=torch.tensor(offset.rot,device=self.device)
+ # Check if we need to apply offsets (optimized code path in _update_buffer_impl)
+ ifnotis_identity_pose(offset_pos,offset_quat):
+ non_identity_offset_frames.append(frame_name)
+ self._apply_target_frame_offset=True
+
+ target_offsets[frame_name]={"pos":offset_pos,"quat":offset_quat}
+
+ ifnotself._apply_target_frame_offset:
+ carb.log_info(
+ f"No offsets application needed from '{self.cfg.prim_path}' to target frames as all"
+ f" are identity: {frames[1:]}"
+ )
+ else:
+ carb.log_info(
+ f"Offsets application needed from '{self.cfg.prim_path}' to the following target frames:"
+ f" {non_identity_offset_frames}"
+ )
+
+ # The names of bodies that RigidPrimView will be tracking to later extract transforms from
+ tracked_body_names=list(body_names_to_frames.keys())
+ # Construct regex expression for the body names
+ body_names_regex=r"("+"|".join(tracked_body_names)+r")"
+ body_names_regex=f"{self.cfg.prim_path.rsplit('/',1)[0]}/{body_names_regex}"
+ # Create simulation view
+ self._physics_sim_view=physx.create_simulation_view(self._backend)
+ self._physics_sim_view.set_subspace_roots("/")
+ # Create a prim view for all frames and initialize it
+ # order of transforms coming out of view will be source frame followed by target frame(s)
+ self._frame_physx_view=self._physics_sim_view.create_rigid_body_view(body_names_regex.replace(".*","*"))
+
+ # Determine the order in which regex evaluated body names so we can later index into frame transforms
+ # by frame name correctly
+ all_prim_paths=self._frame_physx_view.prim_paths
+
+ # Only need first env as the names and their ordering are the same across environments
+ first_env_prim_paths=all_prim_paths[0:len(tracked_body_names)]
+ first_env_body_names=[first_env_prim_path.split("/")[-1]forfirst_env_prim_pathinfirst_env_prim_paths]
+
+ # Re-parse the list as it may have moved when resolving regex above
+ # -- source frame
+ self._source_frame_body_name=self.cfg.prim_path.split("/")[-1]
+ source_frame_index=first_env_body_names.index(self._source_frame_body_name)
+ # -- target frames
+ self._target_frame_body_names=first_env_body_names[:]
+ self._target_frame_body_names.remove(self._source_frame_body_name)
+
+ # Determine indices into all tracked body frames for both source and target frames
+ all_ids=torch.arange(self._num_envs*len(tracked_body_names))
+ self._source_frame_body_ids=torch.arange(self._num_envs)*len(tracked_body_names)+source_frame_index
+ self._target_frame_body_ids=all_ids[~torch.isin(all_ids,self._source_frame_body_ids)]
+
+ # The name of each of the target frame(s) - either user specified or defaulted to the body name
+ self._target_frame_names:list[str]=[]
+ # The position and rotation components of target frame offsets
+ target_frame_offset_pos=[]
+ target_frame_offset_quat=[]
+ # Stores the indices of bodies that need to be duplicated. For instance, if body "LF_SHANK" is needed
+ # for 2 frames, this list enables us to duplicate the body to both frames when doing the calculations
+ # when updating sensor in _update_buffers_impl
+ duplicate_frame_indices=[]
+
+ # Go through each body name and determine the number of duplicates we need for that frame
+ # and extract the offsets. This is all done to handles the case where multiple frames
+ # reference the same body, but have different names and/or offsets
+ fori,body_nameinenumerate(self._target_frame_body_names):
+ forframeinbody_names_to_frames[body_name]:
+ target_frame_offset_pos.append(target_offsets[frame]["pos"])
+ target_frame_offset_quat.append(target_offsets[frame]["quat"])
+ self._target_frame_names.append(frame)
+ duplicate_frame_indices.append(i)
+
+ # To handle multiple environments, need to expand so [0, 1, 1, 2] with 2 environments becomes
+ # [0, 1, 1, 2, 3, 4, 4, 5]. Again, this is a optimization to make _update_buffer_impl more efficient
+ duplicate_frame_indices=torch.tensor(duplicate_frame_indices,device=self.device)
+ num_target_body_frames=len(tracked_body_names)-1
+ self._duplicate_frame_indices=torch.cat(
+ [duplicate_frame_indices+num_target_body_frames*env_numforenv_numinrange(self._num_envs)]
+ )
+
+ # Stack up all the frame offsets for shape (num_envs, num_frames, 3) and (num_envs, num_frames, 4)
+ self._target_frame_offset_pos=torch.stack(target_frame_offset_pos).repeat(self._num_envs,1)
+ self._target_frame_offset_quat=torch.stack(target_frame_offset_quat).repeat(self._num_envs,1)
+
+ # fill the data buffer
+ self._data.target_frame_names=self._target_frame_names
+ self._data.source_pos_w=torch.zeros(self._num_envs,3,device=self._device)
+ self._data.source_quat_w=torch.zeros(self._num_envs,4,device=self._device)
+ self._data.target_pos_w=torch.zeros(self._num_envs,len(duplicate_frame_indices),3,device=self._device)
+ self._data.target_quat_w=torch.zeros(self._num_envs,len(duplicate_frame_indices),4,device=self._device)
+ self._data.target_pos_source=torch.zeros_like(self._data.target_pos_w)
+ self._data.target_quat_source=torch.zeros_like(self._data.target_quat_w)
+
+ def_update_buffers_impl(self,env_ids:Sequence[int]):
+"""Fills the buffers of the sensor data."""
+ # default to all sensors
+ iflen(env_ids)==self._num_envs:
+ env_ids=...
+
+ # Extract transforms from view - shape is:
+ # (the total number of source and target body frames being tracked * self._num_envs, 7)
+ transforms=self._frame_physx_view.get_transforms()
+ # Convert quaternions as PhysX uses xyzw form
+ transforms[:,3:]=convert_quat(transforms[:,3:],to="wxyz")
+
+ # Process source frame transform
+ source_frames=transforms[self._source_frame_body_ids]
+ # Only apply offset if the offsets will result in a coordinate frame transform
+ ifself._apply_source_frame_offset:
+ source_pos_w,source_quat_w=combine_frame_transforms(
+ source_frames[:,:3],
+ source_frames[:,3:],
+ self._source_frame_offset_pos,
+ self._source_frame_offset_quat,
+ )
+ else:
+ source_pos_w=source_frames[:,:3]
+ source_quat_w=source_frames[:,3:]
+
+ # Process target frame transforms
+ target_frames=transforms[self._target_frame_body_ids]
+ duplicated_target_frame_pos_w=target_frames[self._duplicate_frame_indices,:3]
+ duplicated_target_frame_quat_w=target_frames[self._duplicate_frame_indices,3:]
+ # Only apply offset if the offsets will result in a coordinate frame transform
+ ifself._apply_target_frame_offset:
+ target_pos_w,target_quat_w=combine_frame_transforms(
+ duplicated_target_frame_pos_w,
+ duplicated_target_frame_quat_w,
+ self._target_frame_offset_pos,
+ self._target_frame_offset_quat,
+ )
+ else:
+ target_pos_w=duplicated_target_frame_pos_w
+ target_quat_w=duplicated_target_frame_quat_w
+
+ # Compute the transform of the target frame with respect to the source frame
+ total_num_frames=len(self._target_frame_names)
+ target_pos_source,target_quat_source=subtract_frame_transforms(
+ source_pos_w.unsqueeze(1).expand(-1,total_num_frames,-1).reshape(-1,3),
+ source_quat_w.unsqueeze(1).expand(-1,total_num_frames,-1).reshape(-1,4),
+ target_pos_w,
+ target_quat_w,
+ )
+
+ # Update buffers
+ # note: The frame names / ordering don't change so no need to update them after initialization
+ self._data.source_pos_w[:]=source_pos_w.view(-1,3)
+ self._data.source_quat_w[:]=source_quat_w.view(-1,4)
+ self._data.target_pos_w[:]=target_pos_w.view(-1,total_num_frames,3)
+ self._data.target_quat_w[:]=target_quat_w.view(-1,total_num_frames,4)
+ self._data.target_pos_source[:]=target_pos_source.view(-1,total_num_frames,3)
+ self._data.target_quat_source[:]=target_quat_source.view(-1,total_num_frames,4)
+
+ def_set_debug_vis_impl(self,debug_vis:bool):
+ # set visibility of markers
+ # note: parent only deals with callbacks. not their visibility
+ ifdebug_vis:
+ ifnothasattr(self,"frame_visualizer"):
+ self.frame_visualizer=VisualizationMarkers(self.cfg.visualizer_cfg)
+ # set their visibility to true
+ self.frame_visualizer.set_visibility(True)
+ else:
+ ifhasattr(self,"frame_visualizer"):
+ self.frame_visualizer.set_visibility(False)
+
+ def_debug_vis_callback(self,event):
+ # Update the visualized markers
+ ifself.frame_visualizerisnotNone:
+ self.frame_visualizer.visualize(self._data.target_pos_w.view(-1,3),self._data.target_quat_w.view(-1,4))
+
+"""
+ Internal simulation callbacks.
+ """
+
+ def_invalidate_initialize_callback(self,event):
+"""Invalidates the scene elements."""
+ # call parent
+ super()._invalidate_initialize_callback(event)
+ # set all existing views to None to invalidate them
+ self._physics_sim_view=None
+ self._frame_physx_view=None
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.markers.configimportFRAME_MARKER_CFG,VisualizationMarkersCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+from..sensor_base_cfgimportSensorBaseCfg
+from.frame_transformerimportFrameTransformer
+
+
+
[文档]@configclass
+classOffsetCfg:
+"""The offset pose of one frame relative to another frame."""
+
+ pos:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Translation w.r.t. the parent frame. Defaults to (0.0, 0.0, 0.0)."""
+ rot:tuple[float,float,float,float]=(1.0,0.0,0.0,0.0)
+"""Quaternion rotation (w, x, y, z) w.r.t. the parent frame. Defaults to (1.0, 0.0, 0.0, 0.0)."""
+
+
+
[文档]@configclass
+classFrameTransformerCfg(SensorBaseCfg):
+"""Configuration for the frame transformer sensor."""
+
+
[文档]@configclass
+ classFrameCfg:
+"""Information specific to a coordinate frame."""
+
+ prim_path:str=MISSING
+"""The prim path corresponding to the parent rigid body.
+
+ This prim should be part of the same articulation as :attr:`FrameTransformerCfg.prim_path`.
+ """
+ name:str|None=None
+"""User-defined name for the new coordinate frame. Defaults to None.
+
+ If None, then the name is extracted from the leaf of the prim path.
+ """
+
+ offset:OffsetCfg=OffsetCfg()
+"""The pose offset from the parent prim frame."""
+
+ class_type:type=FrameTransformer
+
+ prim_path:str=MISSING
+"""The prim path of the body to transform from (source frame)."""
+
+ source_frame_offset:OffsetCfg=OffsetCfg()
+"""The pose offset from the source prim frame."""
+
+ target_frames:list[FrameCfg]=MISSING
+"""A list of the target frames.
+
+ This allows a single FrameTransformer to handle multiple target prims. For example, in a quadruped,
+ we can use a single FrameTransformer to track each foot's position and orientation in the body
+ frame using four frame offsets.
+ """
+
+ visualizer_cfg:VisualizationMarkersCfg=FRAME_MARKER_CFG.replace(prim_path="/Visuals/FrameTransformer")
+"""The configuration object for the visualization markers. Defaults to FRAME_MARKER_CFG.
+
+ Note:
+ This attribute is only used when debug visualization is enabled.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+fromdataclassesimportdataclass
+
+
+
[文档]@dataclass
+classFrameTransformerData:
+"""Data container for the frame transformer sensor."""
+
+ target_frame_names:list[str]=None
+"""Target frame names (this denotes the order in which that frame data is ordered).
+
+ The frame names are resolved from the :attr:`FrameTransformerCfg.FrameCfg.name` field.
+ This usually follows the order in which the frames are defined in the config. However, in
+ the case of regex matching, the order may be different.
+ """
+
+ target_pos_source:torch.Tensor=None
+"""Position of the target frame(s) relative to source frame.
+
+ Shape is (N, M, 3), where N is the number of environments, and M is the number of target frames.
+ """
+
+ target_quat_source:torch.Tensor=None
+"""Orientation of the target frame(s) relative to source frame quaternion (w, x, y, z).
+
+ Shape is (N, M, 4), where N is the number of environments, and M is the number of target frames.
+ """
+
+ target_pos_w:torch.Tensor=None
+"""Position of the target frame(s) after offset (in world frame).
+
+ Shape is (N, M, 3), where N is the number of environments, and M is the number of target frames.
+ """
+
+ target_quat_w:torch.Tensor=None
+"""Orientation of the target frame(s) after offset (in world frame) quaternion (w, x, y, z).
+
+ Shape is (N, M, 4), where N is the number of environments, and M is the number of target frames.
+ """
+
+ source_pos_w:torch.Tensor=None
+"""Position of the source frame after offset (in world frame).
+
+ Shape is (N, 3), where N is the number of environments.
+ """
+
+ source_quat_w:torch.Tensor=None
+"""Orientation of the source frame after offset (in world frame) quaternion (w, x, y, z).
+
+ Shape is (N, 4), where N is the number of environments.
+ """
[文档]classRayCaster(SensorBase):
+"""A ray-casting sensor.
+
+ The ray-caster uses a set of rays to detect collisions with meshes in the scene. The rays are
+ defined in the sensor's local coordinate frame. The sensor can be configured to ray-cast against
+ a set of meshes with a given ray pattern.
+
+ The meshes are parsed from the list of primitive paths provided in the configuration. These are then
+ converted to warp meshes and stored in the `warp_meshes` list. The ray-caster then ray-casts against
+ these warp meshes using the ray pattern provided in the configuration.
+
+ .. note::
+ Currently, only static meshes are supported. Extending the warp mesh to support dynamic meshes
+ is a work in progress.
+ """
+
+ cfg:RayCasterCfg
+"""The configuration parameters."""
+ meshes:ClassVar[dict[str,wp.Mesh]]={}
+"""The warp meshes available for raycasting.
+
+ The keys correspond to the prim path for the meshes, and values are the corresponding warp Mesh objects.
+
+ Note:
+ We store a global dictionary of all warp meshes to prevent re-loading the mesh for different ray-cast sensor instances.
+ """
+
+
[文档]def__init__(self,cfg:RayCasterCfg):
+"""Initializes the ray-caster object.
+
+ Args:
+ cfg: The configuration parameters.
+ """
+ # check if sensor path is valid
+ # note: currently we do not handle environment indices if there is a regex pattern in the leaf
+ # For example, if the prim path is "/World/Sensor_[1,2]".
+ sensor_path=cfg.prim_path.split("/")[-1]
+ sensor_path_is_regex=re.match(r"^[a-zA-Z0-9/_]+$",sensor_path)isNone
+ ifsensor_path_is_regex:
+ raiseRuntimeError(
+ f"Invalid prim path for the ray-caster sensor: {self.cfg.prim_path}."
+ "\n\tHint: Please ensure that the prim path does not contain any regex patterns in the leaf."
+ )
+ # Initialize base class
+ super().__init__(cfg)
+ # Create empty variables for storing output data
+ self._data=RayCasterData()
+
+ def__str__(self)->str:
+"""Returns: A string containing information about the instance."""
+ return(
+ f"Ray-caster @ '{self.cfg.prim_path}': \n"
+ f"\tview type : {self._view.__class__}\n"
+ f"\tupdate period (s) : {self.cfg.update_period}\n"
+ f"\tnumber of meshes : {len(RayCaster.meshes)}\n"
+ f"\tnumber of sensors : {self._view.count}\n"
+ f"\tnumber of rays/sensor: {self.num_rays}\n"
+ f"\ttotal number of rays : {self.num_rays*self._view.count}"
+ )
+
+"""
+ Properties
+ """
+
+ @property
+ defnum_instances(self)->int:
+ returnself._view.count
+
+ @property
+ defdata(self)->RayCasterData:
+ # update sensors if needed
+ self._update_outdated_buffers()
+ # return the data
+ returnself._data
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+ # reset the timers and counters
+ super().reset(env_ids)
+ # resolve None
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # resample the drift
+ self.drift[env_ids].uniform_(*self.cfg.drift_range)
+
+"""
+ Implementation.
+ """
+
+ def_initialize_impl(self):
+ super()._initialize_impl()
+ # create simulation view
+ self._physics_sim_view=physx.create_simulation_view(self._backend)
+ self._physics_sim_view.set_subspace_roots("/")
+ # check if the prim at path is an articulated or rigid prim
+ # we do this since for physics-based view classes we can access their data directly
+ # otherwise we need to use the xform view class which is slower
+ found_supported_prim_class=False
+ prim=sim_utils.find_first_matching_prim(self.cfg.prim_path)
+ ifprimisNone:
+ raiseRuntimeError(f"Failed to find a prim at path expression: {self.cfg.prim_path}")
+ # create view based on the type of prim
+ ifprim.HasAPI(UsdPhysics.ArticulationRootAPI):
+ self._view=self._physics_sim_view.create_articulation_view(self.cfg.prim_path.replace(".*","*"))
+ found_supported_prim_class=True
+ elifprim.HasAPI(UsdPhysics.RigidBodyAPI):
+ self._view=self._physics_sim_view.create_rigid_body_view(self.cfg.prim_path.replace(".*","*"))
+ found_supported_prim_class=True
+ else:
+ self._view=XFormPrimView(self.cfg.prim_path,reset_xform_properties=False)
+ found_supported_prim_class=True
+ carb.log_warn(f"The prim at path {prim.GetPath().pathString} is not a physics prim! Using XFormPrimView.")
+ # check if prim view class is found
+ ifnotfound_supported_prim_class:
+ raiseRuntimeError(f"Failed to find a valid prim view class for the prim paths: {self.cfg.prim_path}")
+
+ # load the meshes by parsing the stage
+ self._initialize_warp_meshes()
+ # initialize the ray start and directions
+ self._initialize_rays_impl()
+
+ def_initialize_warp_meshes(self):
+ # check number of mesh prims provided
+ iflen(self.cfg.mesh_prim_paths)!=1:
+ raiseNotImplementedError(
+ f"RayCaster currently only supports one mesh prim. Received: {len(self.cfg.mesh_prim_paths)}"
+ )
+
+ # read prims to ray-cast
+ formesh_prim_pathinself.cfg.mesh_prim_paths:
+ # check if mesh already casted into warp mesh
+ ifmesh_prim_pathinRayCaster.meshes:
+ continue
+
+ # check if the prim is a plane - handle PhysX plane as a special case
+ # if a plane exists then we need to create an infinite mesh that is a plane
+ mesh_prim=sim_utils.get_first_matching_child_prim(
+ mesh_prim_path,lambdaprim:prim.GetTypeName()=="Plane"
+ )
+ # if we did not find a plane then we need to read the mesh
+ ifmesh_primisNone:
+ # obtain the mesh prim
+ mesh_prim=sim_utils.get_first_matching_child_prim(
+ mesh_prim_path,lambdaprim:prim.GetTypeName()=="Mesh"
+ )
+ # check if valid
+ ifmesh_primisNoneornotmesh_prim.IsValid():
+ raiseRuntimeError(f"Invalid mesh prim path: {mesh_prim_path}")
+ # cast into UsdGeomMesh
+ mesh_prim=UsdGeom.Mesh(mesh_prim)
+ # read the vertices and faces
+ points=np.asarray(mesh_prim.GetPointsAttr().Get())
+ indices=np.asarray(mesh_prim.GetFaceVertexIndicesAttr().Get())
+ wp_mesh=convert_to_warp_mesh(points,indices,device=self.device)
+ # print info
+ carb.log_info(
+ f"Read mesh prim: {mesh_prim.GetPath()} with {len(points)} vertices and {len(indices)} faces."
+ )
+ else:
+ mesh=make_plane(size=(2e6,2e6),height=0.0,center_zero=True)
+ wp_mesh=convert_to_warp_mesh(mesh.vertices,mesh.faces,device=self.device)
+ # print info
+ carb.log_info(f"Created infinite plane mesh prim: {mesh_prim.GetPath()}.")
+ # add the warp mesh to the list
+ RayCaster.meshes[mesh_prim_path]=wp_mesh
+
+ # throw an error if no meshes are found
+ ifall([mesh_prim_pathnotinRayCaster.meshesformesh_prim_pathinself.cfg.mesh_prim_paths]):
+ raiseRuntimeError(
+ f"No meshes found for ray-casting! Please check the mesh prim paths: {self.cfg.mesh_prim_paths}"
+ )
+
+ def_initialize_rays_impl(self):
+ # compute ray stars and directions
+ self.ray_starts,self.ray_directions=self.cfg.pattern_cfg.func(self.cfg.pattern_cfg,self._device)
+ self.num_rays=len(self.ray_directions)
+ # apply offset transformation to the rays
+ offset_pos=torch.tensor(list(self.cfg.offset.pos),device=self._device)
+ offset_quat=torch.tensor(list(self.cfg.offset.rot),device=self._device)
+ self.ray_directions=quat_apply(offset_quat.repeat(len(self.ray_directions),1),self.ray_directions)
+ self.ray_starts+=offset_pos
+ # repeat the rays for each sensor
+ self.ray_starts=self.ray_starts.repeat(self._view.count,1,1)
+ self.ray_directions=self.ray_directions.repeat(self._view.count,1,1)
+ # prepare drift
+ self.drift=torch.zeros(self._view.count,3,device=self.device)
+ # fill the data buffer
+ self._data.pos_w=torch.zeros(self._view.count,3,device=self._device)
+ self._data.quat_w=torch.zeros(self._view.count,4,device=self._device)
+ self._data.ray_hits_w=torch.zeros(self._view.count,self.num_rays,3,device=self._device)
+
+ def_update_buffers_impl(self,env_ids:Sequence[int]):
+"""Fills the buffers of the sensor data."""
+ # obtain the poses of the sensors
+ ifisinstance(self._view,XFormPrimView):
+ pos_w,quat_w=self._view.get_world_poses(env_ids)
+ elifisinstance(self._view,physx.ArticulationView):
+ pos_w,quat_w=self._view.get_root_transforms()[env_ids].split([3,4],dim=-1)
+ quat_w=convert_quat(quat_w,to="wxyz")
+ elifisinstance(self._view,physx.RigidBodyView):
+ pos_w,quat_w=self._view.get_transforms()[env_ids].split([3,4],dim=-1)
+ quat_w=convert_quat(quat_w,to="wxyz")
+ else:
+ raiseRuntimeError(f"Unsupported view type: {type(self._view)}")
+ # note: we clone here because we are read-only operations
+ pos_w=pos_w.clone()
+ quat_w=quat_w.clone()
+ # apply drift
+ pos_w+=self.drift[env_ids]
+ # store the poses
+ self._data.pos_w[env_ids]=pos_w
+ self._data.quat_w[env_ids]=quat_w
+
+ # ray cast based on the sensor poses
+ ifself.cfg.attach_yaw_only:
+ # only yaw orientation is considered and directions are not rotated
+ ray_starts_w=quat_apply_yaw(quat_w.repeat(1,self.num_rays),self.ray_starts[env_ids])
+ ray_starts_w+=pos_w.unsqueeze(1)
+ ray_directions_w=self.ray_directions[env_ids]
+ else:
+ # full orientation is considered
+ ray_starts_w=quat_apply(quat_w.repeat(1,self.num_rays),self.ray_starts[env_ids])
+ ray_starts_w+=pos_w.unsqueeze(1)
+ ray_directions_w=quat_apply(quat_w.repeat(1,self.num_rays),self.ray_directions[env_ids])
+ # ray cast and store the hits
+ # TODO: Make this work for multiple meshes?
+ self._data.ray_hits_w[env_ids]=raycast_mesh(
+ ray_starts_w,
+ ray_directions_w,
+ max_dist=self.cfg.max_distance,
+ mesh=RayCaster.meshes[self.cfg.mesh_prim_paths[0]],
+ )[0]
+
+ def_set_debug_vis_impl(self,debug_vis:bool):
+ # set visibility of markers
+ # note: parent only deals with callbacks. not their visibility
+ ifdebug_vis:
+ ifnothasattr(self,"ray_visualizer"):
+ self.ray_visualizer=VisualizationMarkers(self.cfg.visualizer_cfg)
+ # set their visibility to true
+ self.ray_visualizer.set_visibility(True)
+ else:
+ ifhasattr(self,"ray_visualizer"):
+ self.ray_visualizer.set_visibility(False)
+
+ def_debug_vis_callback(self,event):
+ # show ray hit positions
+ self.ray_visualizer.visualize(self._data.ray_hits_w.view(-1,3))
+
+"""
+ Internal simulation callbacks.
+ """
+
+ def_invalidate_initialize_callback(self,event):
+"""Invalidates the scene elements."""
+ # call parent
+ super()._invalidate_initialize_callback(event)
+ # set all existing views to None to invalidate them
+ self._physics_sim_view=None
+ self._view=None
[文档]classRayCasterCamera(RayCaster):
+"""A ray-casting camera sensor.
+
+ The ray-caster camera uses a set of rays to get the distances to meshes in the scene. The rays are
+ defined in the sensor's local coordinate frame. The sensor has the same interface as the
+ :class:`omni.isaac.lab.sensors.Camera` that implements the camera class through USD camera prims.
+ However, this class provides a faster image generation. The sensor converts meshes from the list of
+ primitive paths provided in the configuration to Warp meshes. The camera then ray-casts against these
+ Warp meshes only.
+
+ Currently, only the following annotators are supported:
+
+ - ``"distance_to_camera"``: An image containing the distance to camera optical center.
+ - ``"distance_to_image_plane"``: An image containing distances of 3D points from camera plane along camera's z-axis.
+ - ``"normals"``: An image containing the local surface normal vectors at each pixel.
+
+ .. note::
+ Currently, only static meshes are supported. Extending the warp mesh to support dynamic meshes
+ is a work in progress.
+ """
+
+ cfg:RayCasterCameraCfg
+"""The configuration parameters."""
+ UNSUPPORTED_TYPES:ClassVar[set[str]]={
+ "rgb",
+ "instance_id_segmentation",
+ "instance_id_segmentation_fast",
+ "instance_segmentation",
+ "instance_segmentation_fast",
+ "semantic_segmentation",
+ "skeleton_data",
+ "motion_vectors",
+ "bounding_box_2d_tight",
+ "bounding_box_2d_tight_fast",
+ "bounding_box_2d_loose",
+ "bounding_box_2d_loose_fast",
+ "bounding_box_3d",
+ "bounding_box_3d_fast",
+ }
+"""A set of sensor types that are not supported by the ray-caster camera."""
+
+
[文档]def__init__(self,cfg:RayCasterCameraCfg):
+"""Initializes the camera object.
+
+ Args:
+ cfg: The configuration parameters.
+
+ Raises:
+ ValueError: If the provided data types are not supported by the ray-caster camera.
+ """
+ # perform check on supported data types
+ self._check_supported_data_types(cfg)
+ # initialize base class
+ super().__init__(cfg)
+ # create empty variables for storing output data
+ self._data=CameraData()
+
+ def__str__(self)->str:
+"""Returns: A string containing information about the instance."""
+ return(
+ f"Ray-Caster-Camera @ '{self.cfg.prim_path}': \n"
+ f"\tview type : {self._view.__class__}\n"
+ f"\tupdate period (s) : {self.cfg.update_period}\n"
+ f"\tnumber of meshes : {len(RayCaster.meshes)}\n"
+ f"\tnumber of sensors : {self._view.count}\n"
+ f"\tnumber of rays/sensor: {self.num_rays}\n"
+ f"\ttotal number of rays : {self.num_rays*self._view.count}\n"
+ f"\timage shape : {self.image_shape}"
+ )
+
+"""
+ Properties
+ """
+
+ @property
+ defdata(self)->CameraData:
+ # update sensors if needed
+ self._update_outdated_buffers()
+ # return the data
+ returnself._data
+
+ @property
+ defimage_shape(self)->tuple[int,int]:
+"""A tuple containing (height, width) of the camera sensor."""
+ return(self.cfg.pattern_cfg.height,self.cfg.pattern_cfg.width)
+
+ @property
+ defframe(self)->torch.tensor:
+"""Frame number when the measurement took place."""
+ returnself._frame
+
+"""
+ Operations.
+ """
+
+
[文档]defset_intrinsic_matrices(
+ self,matrices:torch.Tensor,focal_length:float=1.0,env_ids:Sequence[int]|None=None
+ ):
+"""Set the intrinsic matrix of the camera.
+
+ Args:
+ matrices: The intrinsic matrices for the camera. Shape is (N, 3, 3).
+ focal_length: Focal length to use when computing aperture values (in cm). Defaults to 1.0.
+ env_ids: A sensor ids to manipulate. Defaults to None, which means all sensor indices.
+ """
+ # resolve env_ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # save new intrinsic matrices and focal length
+ self._data.intrinsic_matrices[env_ids]=matrices.to(self._device)
+ self._focal_length=focal_length
+ # recompute ray directions
+ self.ray_starts[env_ids],self.ray_directions[env_ids]=self.cfg.pattern_cfg.func(
+ self.cfg.pattern_cfg,self._data.intrinsic_matrices[env_ids],self._device
+ )
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+ # reset the timestamps
+ super().reset(env_ids)
+ # resolve None
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # reset the data
+ # note: this recomputation is useful if one performs events such as randomizations on the camera poses.
+ pos_w,quat_w=self._compute_camera_world_poses(env_ids)
+ self._data.pos_w[env_ids]=pos_w
+ self._data.quat_w_world[env_ids]=quat_w
+ # Reset the frame count
+ self._frame[env_ids]=0
+
+
[文档]defset_world_poses(
+ self,
+ positions:torch.Tensor|None=None,
+ orientations:torch.Tensor|None=None,
+ env_ids:Sequence[int]|None=None,
+ convention:Literal["opengl","ros","world"]="ros",
+ ):
+"""Set the pose of the camera w.r.t. the world frame using specified convention.
+
+ Since different fields use different conventions for camera orientations, the method allows users to
+ set the camera poses in the specified convention. Possible conventions are:
+
+ - :obj:`"opengl"` - forward axis: -Z - up axis +Y - Offset is applied in the OpenGL (Usd.Camera) convention
+ - :obj:`"ros"` - forward axis: +Z - up axis -Y - Offset is applied in the ROS convention
+ - :obj:`"world"` - forward axis: +X - up axis +Z - Offset is applied in the World Frame convention
+
+ See :meth:`omni.isaac.lab.sensors.camera.utils.convert_orientation_convention` for more details
+ on the conventions.
+
+ Args:
+ positions: The cartesian coordinates (in meters). Shape is (N, 3).
+ Defaults to None, in which case the camera position in not changed.
+ orientations: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
+ Defaults to None, in which case the camera orientation in not changed.
+ env_ids: A sensor ids to manipulate. Defaults to None, which means all sensor indices.
+ convention: The convention in which the poses are fed. Defaults to "ros".
+
+ Raises:
+ RuntimeError: If the camera prim is not set. Need to call :meth:`initialize` method first.
+ """
+ # resolve env_ids
+ ifenv_idsisNone:
+ env_ids=self._ALL_INDICES
+
+ # get current positions
+ pos_w,quat_w=self._compute_view_world_poses(env_ids)
+ ifpositionsisnotNone:
+ # transform to camera frame
+ pos_offset_world_frame=positions-pos_w
+ self._offset_pos[env_ids]=math_utils.quat_apply(math_utils.quat_inv(quat_w),pos_offset_world_frame)
+ iforientationsisnotNone:
+ # convert rotation matrix from input convention to world
+ quat_w_set=convert_orientation_convention(orientations,origin=convention,target="world")
+ self._offset_quat[env_ids]=math_utils.quat_mul(math_utils.quat_inv(quat_w),quat_w_set)
+
+ # update the data
+ pos_w,quat_w=self._compute_camera_world_poses(env_ids)
+ self._data.pos_w[env_ids]=pos_w
+ self._data.quat_w_world[env_ids]=quat_w
+
+
[文档]defset_world_poses_from_view(
+ self,eyes:torch.Tensor,targets:torch.Tensor,env_ids:Sequence[int]|None=None
+ ):
+"""Set the poses of the camera from the eye position and look-at target position.
+
+ Args:
+ eyes: The positions of the camera's eye. Shape is N, 3).
+ targets: The target locations to look at. Shape is (N, 3).
+ env_ids: A sensor ids to manipulate. Defaults to None, which means all sensor indices.
+
+ Raises:
+ RuntimeError: If the camera prim is not set. Need to call :meth:`initialize` method first.
+ NotImplementedError: If the stage up-axis is not "Y" or "Z".
+ """
+ # camera position and rotation in opengl convention
+ orientations=math_utils.quat_from_matrix(create_rotation_matrix_from_view(eyes,targets,device=self._device))
+ self.set_world_poses(eyes,orientations,env_ids,convention="opengl")
+
+"""
+ Implementation.
+ """
+
+ def_initialize_rays_impl(self):
+ # Create all indices buffer
+ self._ALL_INDICES=torch.arange(self._view.count,device=self._device,dtype=torch.long)
+ # Create frame count buffer
+ self._frame=torch.zeros(self._view.count,device=self._device,dtype=torch.long)
+ # create buffers
+ self._create_buffers()
+ # compute intrinsic matrices
+ self._compute_intrinsic_matrices()
+ # compute ray stars and directions
+ self.ray_starts,self.ray_directions=self.cfg.pattern_cfg.func(
+ self.cfg.pattern_cfg,self._data.intrinsic_matrices,self._device
+ )
+ self.num_rays=self.ray_directions.shape[1]
+ # create buffer to store ray hits
+ self.ray_hits_w=torch.zeros(self._view.count,self.num_rays,3,device=self._device)
+ # set offsets
+ quat_w=convert_orientation_convention(
+ torch.tensor([self.cfg.offset.rot],device=self._device),origin=self.cfg.offset.convention,target="world"
+ )
+ self._offset_quat=quat_w.repeat(self._view.count,1)
+ self._offset_pos=torch.tensor(list(self.cfg.offset.pos),device=self._device).repeat(self._view.count,1)
+
+ def_update_buffers_impl(self,env_ids:Sequence[int]):
+"""Fills the buffers of the sensor data."""
+ # increment frame count
+ self._frame[env_ids]+=1
+
+ # compute poses from current view
+ pos_w,quat_w=self._compute_camera_world_poses(env_ids)
+ # update the data
+ self._data.pos_w[env_ids]=pos_w
+ self._data.quat_w_world[env_ids]=quat_w
+
+ # note: full orientation is considered
+ ray_starts_w=math_utils.quat_apply(quat_w.repeat(1,self.num_rays),self.ray_starts[env_ids])
+ ray_starts_w+=pos_w.unsqueeze(1)
+ ray_directions_w=math_utils.quat_apply(quat_w.repeat(1,self.num_rays),self.ray_directions[env_ids])
+
+ # ray cast and store the hits
+ # note: we set max distance to 1e6 during the ray-casting. THis is because we clip the distance
+ # to the image plane and distance to the camera to the maximum distance afterwards in-order to
+ # match the USD camera behavior.
+
+ # TODO: Make ray-casting work for multiple meshes?
+ # necessary for regular dictionaries.
+ self.ray_hits_w,ray_depth,ray_normal,_=raycast_mesh(
+ ray_starts_w,
+ ray_directions_w,
+ mesh=RayCasterCamera.meshes[self.cfg.mesh_prim_paths[0]],
+ max_dist=1e6,
+ return_distance=any(
+ [nameinself.cfg.data_typesfornamein["distance_to_image_plane","distance_to_camera"]]
+ ),
+ return_normal="normals"inself.cfg.data_types,
+ )
+ # update output buffers
+ if"distance_to_image_plane"inself.cfg.data_types:
+ # note: data is in camera frame so we only take the first component (z-axis of camera frame)
+ distance_to_image_plane=(
+ math_utils.quat_apply(
+ math_utils.quat_inv(quat_w).repeat(1,self.num_rays),
+ (ray_depth[:,:,None]*ray_directions_w),
+ )
+ )[:,:,0]
+ # apply the maximum distance after the transformation
+ distance_to_image_plane=torch.clip(distance_to_image_plane,max=self.cfg.max_distance)
+ self._data.output["distance_to_image_plane"][env_ids]=distance_to_image_plane.view(
+ -1,*self.image_shape,1
+ )
+ if"distance_to_camera"inself.cfg.data_types:
+ self._data.output["distance_to_camera"][env_ids]=torch.clip(
+ ray_depth.view(-1,*self.image_shape,1),max=self.cfg.max_distance
+ )
+ if"normals"inself.cfg.data_types:
+ self._data.output["normals"][env_ids]=ray_normal.view(-1,*self.image_shape,3)
+
+ def_debug_vis_callback(self,event):
+ # in case it crashes be safe
+ ifnothasattr(self,"ray_hits_w"):
+ return
+ # show ray hit positions
+ self.ray_visualizer.visualize(self.ray_hits_w.view(-1,3))
+
+"""
+ Private Helpers
+ """
+
+ def_check_supported_data_types(self,cfg:RayCasterCameraCfg):
+"""Checks if the data types are supported by the ray-caster camera."""
+ # check if there is any intersection in unsupported types
+ # reason: we cannot obtain this data from simplified warp-based ray caster
+ common_elements=set(cfg.data_types)&RayCasterCamera.UNSUPPORTED_TYPES
+ ifcommon_elements:
+ raiseValueError(
+ f"RayCasterCamera class does not support the following sensor types: {common_elements}."
+ "\n\tThis is because these sensor types cannot be obtained in a fast way using ''warp''."
+ "\n\tHint: If you need to work with these sensor types, we recommend using the USD camera"
+ " interface from the omni.isaac.lab.sensors.camera module."
+ )
+
+ def_create_buffers(self):
+"""Create buffers for storing data."""
+ # prepare drift
+ self.drift=torch.zeros(self._view.count,3,device=self.device)
+ # create the data object
+ # -- pose of the cameras
+ self._data.pos_w=torch.zeros((self._view.count,3),device=self._device)
+ self._data.quat_w_world=torch.zeros((self._view.count,4),device=self._device)
+ # -- intrinsic matrix
+ self._data.intrinsic_matrices=torch.zeros((self._view.count,3,3),device=self._device)
+ self._data.intrinsic_matrices[:,2,2]=1.0
+ self._data.image_shape=self.image_shape
+ # -- output data
+ # create the buffers to store the annotator data.
+ self._data.output=TensorDict({},batch_size=self._view.count,device=self.device)
+ self._data.info=[{name:Nonefornameinself.cfg.data_types}]*self._view.count
+ fornameinself.cfg.data_types:
+ ifnamein["distance_to_image_plane","distance_to_camera"]:
+ shape=(self.cfg.pattern_cfg.height,self.cfg.pattern_cfg.width,1)
+ elifnamein["normals"]:
+ shape=(self.cfg.pattern_cfg.height,self.cfg.pattern_cfg.width,3)
+ else:
+ raiseValueError(f"Received unknown data type: {name}. Please check the configuration.")
+ # allocate tensor to store the data
+ self._data.output[name]=torch.zeros((self._view.count,*shape),device=self._device)
+
+ def_compute_intrinsic_matrices(self):
+"""Computes the intrinsic matrices for the camera based on the config provided."""
+ # get the sensor properties
+ pattern_cfg=self.cfg.pattern_cfg
+
+ # check if vertical aperture is provided
+ # if not then it is auto-computed based on the aspect ratio to preserve squared pixels
+ ifpattern_cfg.vertical_apertureisNone:
+ pattern_cfg.vertical_aperture=pattern_cfg.horizontal_aperture*pattern_cfg.height/pattern_cfg.width
+
+ # compute the intrinsic matrix
+ f_x=pattern_cfg.width*pattern_cfg.focal_length/pattern_cfg.horizontal_aperture
+ f_y=pattern_cfg.height*pattern_cfg.focal_length/pattern_cfg.vertical_aperture
+ c_x=pattern_cfg.horizontal_aperture_offset*f_x+pattern_cfg.width/2
+ c_y=pattern_cfg.vertical_aperture_offset*f_y+pattern_cfg.height/2
+ # allocate the intrinsic matrices
+ self._data.intrinsic_matrices[:,0,0]=f_x
+ self._data.intrinsic_matrices[:,0,2]=c_x
+ self._data.intrinsic_matrices[:,1,1]=f_y
+ self._data.intrinsic_matrices[:,1,2]=c_y
+
+ # save focal length
+ self._focal_length=pattern_cfg.focal_length
+
+ def_compute_view_world_poses(self,env_ids:Sequence[int])->tuple[torch.Tensor,torch.Tensor]:
+"""Obtains the pose of the view the camera is attached to in the world frame.
+
+ Returns:
+ A tuple of the position (in meters) and quaternion (w, x, y, z).
+ """
+ # obtain the poses of the sensors
+ # note: clone arg doesn't exist for xform prim view so we need to do this manually
+ ifisinstance(self._view,XFormPrimView):
+ pos_w,quat_w=self._view.get_world_poses(env_ids)
+ elifisinstance(self._view,physx.ArticulationView):
+ pos_w,quat_w=self._view.get_root_transforms()[env_ids].split([3,4],dim=-1)
+ quat_w=math_utils.convert_quat(quat_w,to="wxyz")
+ elifisinstance(self._view,physx.RigidBodyView):
+ pos_w,quat_w=self._view.get_transforms()[env_ids].split([3,4],dim=-1)
+ quat_w=math_utils.convert_quat(quat_w,to="wxyz")
+ else:
+ raiseRuntimeError(f"Unsupported view type: {type(self._view)}")
+ # return the pose
+ returnpos_w.clone(),quat_w.clone()
+
+ def_compute_camera_world_poses(self,env_ids:Sequence[int])->tuple[torch.Tensor,torch.Tensor]:
+"""Computes the pose of the camera in the world frame.
+
+ This function applies the offset pose to the pose of the view the camera is attached to.
+
+ Returns:
+ A tuple of the position (in meters) and quaternion (w, x, y, z) in "world" convention.
+ """
+ # get the pose of the view the camera is attached to
+ pos_w,quat_w=self._compute_view_world_poses(env_ids)
+ # apply offsets
+ # need to apply quat because offset relative to parent frame
+ pos_w+=math_utils.quat_apply(quat_w,self._offset_pos[env_ids])
+ quat_w=math_utils.quat_mul(quat_w,self._offset_quat[env_ids])
+
+ returnpos_w,quat_w
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Configuration for the ray-cast camera sensor."""
+
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.patternsimportPinholeCameraPatternCfg
+from.ray_caster_cameraimportRayCasterCamera
+from.ray_caster_cfgimportRayCasterCfg
+
+
+
[文档]@configclass
+classRayCasterCameraCfg(RayCasterCfg):
+"""Configuration for the ray-cast sensor."""
+
+
[文档]@configclass
+ classOffsetCfg:
+"""The offset pose of the sensor's frame from the sensor's parent frame."""
+
+ pos:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Translation w.r.t. the parent frame. Defaults to (0.0, 0.0, 0.0)."""
+
+ rot:tuple[float,float,float,float]=(1.0,0.0,0.0,0.0)
+"""Quaternion rotation (w, x, y, z) w.r.t. the parent frame. Defaults to (1.0, 0.0, 0.0, 0.0)."""
+
+ convention:Literal["opengl","ros","world"]="ros"
+"""The convention in which the frame offset is applied. Defaults to "ros".
+
+ - ``"opengl"`` - forward axis: ``-Z`` - up axis: ``+Y`` - Offset is applied in the OpenGL (Usd.Camera) convention.
+ - ``"ros"`` - forward axis: ``+Z`` - up axis: ``-Y`` - Offset is applied in the ROS convention.
+ - ``"world"`` - forward axis: ``+X`` - up axis: ``+Z`` - Offset is applied in the World Frame convention.
+
+ """
+
+ class_type:type=RayCasterCamera
+
+ offset:OffsetCfg=OffsetCfg()
+"""The offset pose of the sensor's frame from the sensor's parent frame. Defaults to identity."""
+
+ data_types:list[str]=["distance_to_image_plane"]
+"""List of sensor names/types to enable for the camera. Defaults to ["distance_to_image_plane"]."""
+
+ pattern_cfg:PinholeCameraPatternCfg=MISSING
+"""The pattern that defines the local ray starting positions and directions in a pinhole camera pattern."""
+
+ def__post_init__(self):
+ # for cameras, this quantity should be False always.
+ self.attach_yaw_only=False
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Configuration for the ray-cast sensor."""
+
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.markersimportVisualizationMarkersCfg
+fromomni.isaac.lab.markers.configimportRAY_CASTER_MARKER_CFG
+fromomni.isaac.lab.utilsimportconfigclass
+
+from..sensor_base_cfgimportSensorBaseCfg
+from.patterns.patterns_cfgimportPatternBaseCfg
+from.ray_casterimportRayCaster
+
+
+
[文档]@configclass
+classRayCasterCfg(SensorBaseCfg):
+"""Configuration for the ray-cast sensor."""
+
+
[文档]@configclass
+ classOffsetCfg:
+"""The offset pose of the sensor's frame from the sensor's parent frame."""
+
+ pos:tuple[float,float,float]=(0.0,0.0,0.0)
+"""Translation w.r.t. the parent frame. Defaults to (0.0, 0.0, 0.0)."""
+ rot:tuple[float,float,float,float]=(1.0,0.0,0.0,0.0)
+"""Quaternion rotation (w, x, y, z) w.r.t. the parent frame. Defaults to (1.0, 0.0, 0.0, 0.0)."""
+
+ class_type:type=RayCaster
+
+ mesh_prim_paths:list[str]=MISSING
+"""The list of mesh primitive paths to ray cast against.
+
+ Note:
+ Currently, only a single static mesh is supported. We are working on supporting multiple
+ static meshes and dynamic meshes.
+ """
+
+ offset:OffsetCfg=OffsetCfg()
+"""The offset pose of the sensor's frame from the sensor's parent frame. Defaults to identity."""
+
+ attach_yaw_only:bool=MISSING
+"""Whether the rays' starting positions and directions only track the yaw orientation.
+
+ This is useful for ray-casting height maps, where only yaw rotation is needed.
+ """
+
+ pattern_cfg:PatternBaseCfg=MISSING
+"""The pattern that defines the local ray starting positions and directions."""
+
+ max_distance:float=1e6
+"""Maximum distance (in meters) from the sensor to ray cast to. Defaults to 1e6."""
+
+ drift_range:tuple[float,float]=(0.0,0.0)
+"""The range of drift (in meters) to add to the ray starting positions (xyz). Defaults to (0.0, 0.0).
+
+ For floating base robots, this is useful for simulating drift in the robot's pose estimation.
+ """
+
+ visualizer_cfg:VisualizationMarkersCfg=RAY_CASTER_MARKER_CFG.replace(prim_path="/Visuals/RayCaster")
+"""The configuration object for the visualization markers. Defaults to RAY_CASTER_MARKER_CFG.
+
+ Note:
+ This attribute is only used when debug visualization is enabled.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+fromdataclassesimportdataclass
+
+
+
[文档]@dataclass
+classRayCasterData:
+"""Data container for the ray-cast sensor."""
+
+ pos_w:torch.Tensor=None
+"""Position of the sensor origin in world frame.
+
+ Shape is (N, 3), where N is the number of sensors.
+ """
+ quat_w:torch.Tensor=None
+"""Orientation of the sensor origin in quaternion (w, x, y, z) in world frame.
+
+ Shape is (N, 4), where N is the number of sensors.
+ """
+ ray_hits_w:torch.Tensor=None
+"""The ray hit positions in the world frame.
+
+ Shape is (N, B, 3), where N is the number of sensors, B is the number of rays
+ in the scan pattern per sensor.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Base class for sensors.
+
+This class defines an interface for sensors similar to how the :class:`omni.isaac.lab.assets.AssetBase` class works.
+Each sensor class should inherit from this class and implement the abstract methods.
+"""
+
+from__future__importannotations
+
+importinspect
+importtorch
+importweakref
+fromabcimportABC,abstractmethod
+fromcollections.abcimportSequence
+fromtypingimportTYPE_CHECKING,Any
+
+importomni.kit.app
+importomni.timeline
+
+importomni.isaac.lab.simassim_utils
+
+ifTYPE_CHECKING:
+ from.sensor_base_cfgimportSensorBaseCfg
+
+
+
[文档]classSensorBase(ABC):
+"""The base class for implementing a sensor.
+
+ The implementation is based on lazy evaluation. The sensor data is only updated when the user
+ tries accessing the data through the :attr:`data` property or sets ``force_compute=True`` in
+ the :meth:`update` method. This is done to avoid unnecessary computation when the sensor data
+ is not used.
+
+ The sensor is updated at the specified update period. If the update period is zero, then the
+ sensor is updated at every simulation step.
+ """
+
+
[文档]def__init__(self,cfg:SensorBaseCfg):
+"""Initialize the sensor class.
+
+ Args:
+ cfg: The configuration parameters for the sensor.
+ """
+ # check that config is valid
+ ifcfg.history_length<0:
+ raiseValueError(f"History length must be greater than 0! Received: {cfg.history_length}")
+ # store inputs
+ self.cfg=cfg
+ # flag for whether the sensor is initialized
+ self._is_initialized=False
+ # flag for whether the sensor is in visualization mode
+ self._is_visualizing=False
+
+ # note: Use weakref on callbacks to ensure that this object can be deleted when its destructor is called.
+ # add callbacks for stage play/stop
+ # The order is set to 10 which is arbitrary but should be lower priority than the default order of 0
+ timeline_event_stream=omni.timeline.get_timeline_interface().get_timeline_event_stream()
+ self._initialize_handle=timeline_event_stream.create_subscription_to_pop_by_type(
+ int(omni.timeline.TimelineEventType.PLAY),
+ lambdaevent,obj=weakref.proxy(self):obj._initialize_callback(event),
+ order=10,
+ )
+ self._invalidate_initialize_handle=timeline_event_stream.create_subscription_to_pop_by_type(
+ int(omni.timeline.TimelineEventType.STOP),
+ lambdaevent,obj=weakref.proxy(self):obj._invalidate_initialize_callback(event),
+ order=10,
+ )
+ # add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
+ self._debug_vis_handle=None
+ # set initial state of debug visualization
+ self.set_debug_vis(self.cfg.debug_vis)
+
+ def__del__(self):
+"""Unsubscribe from the callbacks."""
+ # clear physics events handles
+ ifself._initialize_handle:
+ self._initialize_handle.unsubscribe()
+ self._initialize_handle=None
+ ifself._invalidate_initialize_handle:
+ self._invalidate_initialize_handle.unsubscribe()
+ self._invalidate_initialize_handle=None
+ # clear debug visualization
+ ifself._debug_vis_handle:
+ self._debug_vis_handle.unsubscribe()
+ self._debug_vis_handle=None
+
+"""
+ Properties
+ """
+
+ @property
+ defis_initialized(self)->bool:
+"""Whether the sensor is initialized.
+
+ Returns True if the sensor is initialized, False otherwise.
+ """
+ returnself._is_initialized
+
+ @property
+ defnum_instances(self)->int:
+"""Number of instances of the sensor.
+
+ This is equal to the number of sensors per environment multiplied by the number of environments.
+ """
+ returnself._num_envs
+
+ @property
+ defdevice(self)->str:
+"""Memory device for computation."""
+ returnself._device
+
+ @property
+ @abstractmethod
+ defdata(self)->Any:
+"""Data from the sensor.
+
+ This property is only updated when the user tries to access the data. This is done to avoid
+ unnecessary computation when the sensor data is not used.
+
+ For updating the sensor when this property is accessed, you can use the following
+ code snippet in your sensor implementation:
+
+ .. code-block:: python
+
+ # update sensors if needed
+ self._update_outdated_buffers()
+ # return the data (where `_data` is the data for the sensor)
+ return self._data
+ """
+ raiseNotImplementedError
+
+ @property
+ defhas_debug_vis_implementation(self)->bool:
+"""Whether the sensor has a debug visualization implemented."""
+ # check if function raises NotImplementedError
+ source_code=inspect.getsource(self._set_debug_vis_impl)
+ return"NotImplementedError"notinsource_code
+
+"""
+ Operations
+ """
+
+
[文档]defset_debug_vis(self,debug_vis:bool)->bool:
+"""Sets whether to visualize the sensor data.
+
+ Args:
+ debug_vis: Whether to visualize the sensor data.
+
+ Returns:
+ Whether the debug visualization was successfully set. False if the sensor
+ does not support debug visualization.
+ """
+ # check if debug visualization is supported
+ ifnotself.has_debug_vis_implementation:
+ returnFalse
+ # toggle debug visualization objects
+ self._set_debug_vis_impl(debug_vis)
+ # toggle debug visualization flag
+ self._is_visualizing=debug_vis
+ # toggle debug visualization handles
+ ifdebug_vis:
+ # create a subscriber for the post update event if it doesn't exist
+ ifself._debug_vis_handleisNone:
+ app_interface=omni.kit.app.get_app_interface()
+ self._debug_vis_handle=app_interface.get_post_update_event_stream().create_subscription_to_pop(
+ lambdaevent,obj=weakref.proxy(self):obj._debug_vis_callback(event)
+ )
+ else:
+ # remove the subscriber if it exists
+ ifself._debug_vis_handleisnotNone:
+ self._debug_vis_handle.unsubscribe()
+ self._debug_vis_handle=None
+ # return success
+ returnTrue
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+"""Resets the sensor internals.
+
+ Args:
+ env_ids: The sensor ids to reset. Defaults to None.
+ """
+ # Resolve sensor ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # Reset the timestamp for the sensors
+ self._timestamp[env_ids]=0.0
+ self._timestamp_last_update[env_ids]=0.0
+ # Set all reset sensors to outdated so that they are updated when data is called the next time.
+ self._is_outdated[env_ids]=True
+
+ defupdate(self,dt:float,force_recompute:bool=False):
+ # Update the timestamp for the sensors
+ self._timestamp+=dt
+ self._is_outdated|=self._timestamp-self._timestamp_last_update+1e-6>=self.cfg.update_period
+ # Update the buffers
+ # TODO (from @mayank): Why is there a history length here when it doesn't mean anything in the sensor base?!?
+ # It is only for the contact sensor but there we should redefine the update function IMO.
+ ifforce_recomputeorself._is_visualizingor(self.cfg.history_length>0):
+ self._update_outdated_buffers()
+
+"""
+ Implementation specific.
+ """
+
+ @abstractmethod
+ def_initialize_impl(self):
+"""Initializes the sensor-related handles and internal buffers."""
+ # Obtain Simulation Context
+ sim=sim_utils.SimulationContext.instance()
+ ifsimisNone:
+ raiseRuntimeError("Simulation Context is not initialized!")
+ # Obtain device and backend
+ self._device=sim.device
+ self._backend=sim.backend
+ self._sim_physics_dt=sim.get_physics_dt()
+ # Count number of environments
+ env_prim_path_expr=self.cfg.prim_path.rsplit("/",1)[0]
+ self._parent_prims=sim_utils.find_matching_prims(env_prim_path_expr)
+ self._num_envs=len(self._parent_prims)
+ # Boolean tensor indicating whether the sensor data has to be refreshed
+ self._is_outdated=torch.ones(self._num_envs,dtype=torch.bool,device=self._device)
+ # Current timestamp (in seconds)
+ self._timestamp=torch.zeros(self._num_envs,device=self._device)
+ # Timestamp from last update
+ self._timestamp_last_update=torch.zeros_like(self._timestamp)
+
+ @abstractmethod
+ def_update_buffers_impl(self,env_ids:Sequence[int]):
+"""Fills the sensor data for provided environment ids.
+
+ This function does not perform any time-based checks and directly fills the data into the
+ data container.
+
+ Args:
+ env_ids: The indices of the sensors that are ready to capture.
+ """
+ raiseNotImplementedError
+
+ def_set_debug_vis_impl(self,debug_vis:bool):
+"""Set debug visualization into visualization objects.
+
+ This function is responsible for creating the visualization objects if they don't exist
+ and input ``debug_vis`` is True. If the visualization objects exist, the function should
+ set their visibility into the stage.
+ """
+ raiseNotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
+
+ def_debug_vis_callback(self,event):
+"""Callback for debug visualization.
+
+ This function calls the visualization objects and sets the data to visualize into them.
+ """
+ raiseNotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
+
+"""
+ Internal simulation callbacks.
+ """
+
+ def_initialize_callback(self,event):
+"""Initializes the scene elements.
+
+ Note:
+ PhysX handles are only enabled once the simulator starts playing. Hence, this function needs to be
+ called whenever the simulator "plays" from a "stop" state.
+ """
+ ifnotself._is_initialized:
+ self._initialize_impl()
+ self._is_initialized=True
+
+ def_invalidate_initialize_callback(self,event):
+"""Invalidates the scene elements."""
+ self._is_initialized=False
+
+"""
+ Helper functions.
+ """
+
+ def_update_outdated_buffers(self):
+"""Fills the sensor data for the outdated sensors."""
+ outdated_env_ids=self._is_outdated.nonzero().squeeze(-1)
+ iflen(outdated_env_ids)>0:
+ # obtain new data
+ self._update_buffers_impl(outdated_env_ids)
+ # update the timestamp from last update
+ self._timestamp_last_update[outdated_env_ids]=self._timestamp[outdated_env_ids]
+ # set outdated flag to false for the updated sensors
+ self._is_outdated[outdated_env_ids]=False
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.sensor_baseimportSensorBase
+
+
+
[文档]@configclass
+classSensorBaseCfg:
+"""Configuration parameters for a sensor."""
+
+ class_type:type[SensorBase]=MISSING
+"""The associated sensor class.
+
+ The class should inherit from :class:`omni.isaac.lab.sensors.sensor_base.SensorBase`.
+ """
+
+ prim_path:str=MISSING
+"""Prim path (or expression) to the sensor.
+
+ .. note::
+ The expression can contain the environment namespace regex ``{ENV_REGEX_NS}`` which
+ will be replaced with the environment namespace.
+
+ Example: ``{ENV_REGEX_NS}/Robot/sensor`` will be replaced with ``/World/envs/env_.*/Robot/sensor``.
+
+ """
+
+ update_period:float=0.0
+"""Update period of the sensor buffers (in seconds). Defaults to 0.0 (update every step)."""
+
+ history_length:int=0
+"""Number of past frames to store in the sensor buffers. Defaults to 0, which means that only
+ the current data is stored (no history)."""
+
+ debug_vis:bool=False
+"""Whether to visualize the sensor. Defaults to False."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importabc
+importhashlib
+importjson
+importos
+importpathlib
+importrandom
+fromdatetimeimportdatetime
+
+fromomni.isaac.lab.sim.converters.asset_converter_base_cfgimportAssetConverterBaseCfg
+fromomni.isaac.lab.utils.assetsimportcheck_file_path
+fromomni.isaac.lab.utils.ioimportdump_yaml
+
+
+
[文档]classAssetConverterBase(abc.ABC):
+"""Base class for converting an asset file from different formats into USD format.
+
+ This class provides a common interface for converting an asset file into USD. It does not
+ provide any implementation for the conversion. The derived classes must implement the
+ :meth:`_convert_asset` method to provide the actual conversion.
+
+ The file conversion is lazy if the output directory (:obj:`AssetConverterBaseCfg.usd_dir`) is provided.
+ In the lazy conversion, the USD file is re-generated only if:
+
+ * The asset file is modified.
+ * The configuration parameters are modified.
+ * The USD file does not exist.
+
+ To override this behavior to force conversion, the flag :obj:`AssetConverterBaseCfg.force_usd_conversion`
+ can be set to True.
+
+ When no output directory is defined, lazy conversion is deactivated and the generated USD file is
+ stored in folder ``/tmp/IsaacLab/usd_{date}_{time}_{random}``, where the parameters in braces are generated
+ at runtime. The random identifiers help avoid a race condition where two simultaneously triggered conversions
+ try to use the same directory for reading/writing the generated files.
+
+ .. note::
+ Changes to the parameters :obj:`AssetConverterBaseCfg.asset_path`, :obj:`AssetConverterBaseCfg.usd_dir`, and
+ :obj:`AssetConverterBaseCfg.usd_file_name` are not considered as modifications in the configuration instance that
+ trigger USD file re-generation.
+
+ """
+
+
[文档]def__init__(self,cfg:AssetConverterBaseCfg):
+"""Initializes the class.
+
+ Args:
+ cfg: The configuration instance for converting an asset file to USD format.
+
+ Raises:
+ ValueError: When provided asset file does not exist.
+ """
+ # check if the asset file exists
+ ifnotcheck_file_path(cfg.asset_path):
+ raiseValueError(f"The asset path does not exist: {cfg.asset_path}")
+ # save the inputs
+ self.cfg=cfg
+
+ # resolve USD directory name
+ ifcfg.usd_dirisNone:
+ # a folder in "/tmp/IsaacLab" by the name: usd_{date}_{time}_{random}
+ time_tag=datetime.now().strftime("%Y%m%d_%H%M%S")
+ self._usd_dir=f"/tmp/IsaacLab/usd_{time_tag}_{random.randrange(10000)}"
+ else:
+ self._usd_dir=cfg.usd_dir
+
+ # resolve the file name from asset file name if not provided
+ ifcfg.usd_file_nameisNone:
+ usd_file_name=pathlib.PurePath(cfg.asset_path).stem
+ else:
+ usd_file_name=cfg.usd_file_name
+ # add USD extension if not provided
+ ifnot(usd_file_name.endswith(".usd")orusd_file_name.endswith(".usda")):
+ self._usd_file_name=usd_file_name+".usd"
+ else:
+ self._usd_file_name=usd_file_name
+
+ # create the USD directory
+ os.makedirs(self.usd_dir,exist_ok=True)
+ # check if usd files exist
+ self._usd_file_exists=os.path.isfile(self.usd_path)
+ # path to read/write asset hash file
+ self._dest_hash_path=os.path.join(self.usd_dir,".asset_hash")
+ # create asset hash to check if the asset has changed
+ self._asset_hash=self._config_to_hash(cfg)
+ # read the saved hash
+ try:
+ withopen(self._dest_hash_path)asf:
+ existing_asset_hash=f.readline()
+ self._is_same_asset=existing_asset_hash==self._asset_hash
+ exceptFileNotFoundError:
+ self._is_same_asset=False
+
+ # convert the asset to USD if the hash is different or USD file does not exist
+ ifcfg.force_usd_conversionornotself._usd_file_existsornotself._is_same_asset:
+ # write the updated hash
+ withopen(self._dest_hash_path,"w")asf:
+ f.write(self._asset_hash)
+ # convert the asset to USD
+ self._convert_asset(cfg)
+ # dump the configuration to a file
+ dump_yaml(os.path.join(self.usd_dir,"config.yaml"),cfg.to_dict())
+ # add comment to top of the saved config file with information about the converter
+ current_date=datetime.now().strftime("%Y-%m-%d")
+ current_time=datetime.now().strftime("%H:%M:%S")
+ generation_comment=(
+ f"##\n# Generated by {self.__class__.__name__} on {current_date} at {current_time}.\n##\n"
+ )
+ withopen(os.path.join(self.usd_dir,"config.yaml"),"a")asf:
+ f.write(generation_comment)
+
+"""
+ Properties.
+ """
+
+ @property
+ defusd_dir(self)->str:
+"""The absolute path to the directory where the generated USD files are stored."""
+ returnself._usd_dir
+
+ @property
+ defusd_file_name(self)->str:
+"""The file name of the generated USD file."""
+ returnself._usd_file_name
+
+ @property
+ defusd_path(self)->str:
+"""The absolute path to the generated USD file."""
+ returnos.path.join(self.usd_dir,self.usd_file_name)
+
+ @property
+ defusd_instanceable_meshes_path(self)->str:
+"""The relative path to the USD file with meshes.
+
+ The path is with respect to the USD directory :attr:`usd_dir`. This is to ensure that the
+ mesh references in the generated USD file are resolved relatively. Otherwise, it becomes
+ difficult to move the USD asset to a different location.
+ """
+ returnos.path.join(".","Props","instanceable_meshes.usd")
+
+"""
+ Implementation specifics.
+ """
+
+ @abc.abstractmethod
+ def_convert_asset(self,cfg:AssetConverterBaseCfg):
+"""Converts the asset file to USD.
+
+ Args:
+ cfg: The configuration instance for the input asset to USD conversion.
+ """
+ raiseNotImplementedError()
+
+"""
+ Private helpers.
+ """
+
+ @staticmethod
+ def_config_to_hash(cfg:AssetConverterBaseCfg)->str:
+"""Converts the configuration object and asset file to an MD5 hash of a string.
+
+ .. warning::
+ It only checks the main asset file (:attr:`cfg.asset_path`).
+
+ Args:
+ config : The asset converter configuration object.
+
+ Returns:
+ An MD5 hash of a string.
+ """
+
+ # convert to dict and remove path related info
+ config_dic=cfg.to_dict()
+ _=config_dic.pop("asset_path")
+ _=config_dic.pop("usd_dir")
+ _=config_dic.pop("usd_file_name")
+ # convert config dic to bytes
+ config_bytes=json.dumps(config_dic).encode()
+ # hash config
+ md5=hashlib.md5()
+ md5.update(config_bytes)
+
+ # read the asset file to observe changes
+ withopen(cfg.asset_path,"rb")asf:
+ whileTrue:
+ # read 64kb chunks to avoid memory issues for the large files!
+ data=f.read(65536)
+ ifnotdata:
+ break
+ md5.update(data)
+ # return the hash
+ returnmd5.hexdigest()
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+
+
[文档]@configclass
+classAssetConverterBaseCfg:
+"""The base configuration class for asset converters."""
+
+ asset_path:str=MISSING
+"""The absolute path to the asset file to convert into USD."""
+
+ usd_dir:str|None=None
+"""The output directory path to store the generated USD file. Defaults to None.
+
+ If None, it is resolved as ``/tmp/IsaacLab/usd_{date}_{time}_{random}``, where
+ the parameters in braces are runtime generated.
+ """
+
+ usd_file_name:str|None=None
+"""The name of the generated usd file. Defaults to None.
+
+ If None, it is resolved from the asset file name. For example, if the asset file
+ name is ``"my_asset.urdf"``, then the generated USD file name is ``"my_asset.usd"``.
+
+ If the providing file name does not end with ".usd" or ".usda", then the extension
+ ".usd" is appended to the file name.
+ """
+
+ force_usd_conversion:bool=False
+"""Force the conversion of the asset file to usd. Defaults to False.
+
+ If True, then the USD file is always generated. It will overwrite the existing USD file if it exists.
+ """
+
+ make_instanceable:bool=True
+"""Make the generated USD file instanceable. Defaults to True.
+
+ Note:
+ Instancing helps reduce the memory footprint of the asset when multiple copies of the asset are
+ used in the scene. For more information, please check the USD documentation on
+ `scene-graph instancing <https://openusd.org/dev/api/_usd__page__scenegraph_instancing.html>`_.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importasyncio
+importos
+
+importomni
+importomni.kit.commands
+importomni.usd
+fromomni.isaac.core.utils.extensionsimportenable_extension
+frompxrimportUsd,UsdGeom,UsdPhysics,UsdUtils
+
+fromomni.isaac.lab.sim.converters.asset_converter_baseimportAssetConverterBase
+fromomni.isaac.lab.sim.converters.mesh_converter_cfgimportMeshConverterCfg
+fromomni.isaac.lab.sim.schemasimportschemas
+fromomni.isaac.lab.sim.utilsimportexport_prim_to_file
+
+
+
[文档]classMeshConverter(AssetConverterBase):
+"""Converter for a mesh file in OBJ / STL / FBX format to a USD file.
+
+ This class wraps around the `omni.kit.asset_converter`_ extension to provide a lazy implementation
+ for mesh to USD conversion. It stores the output USD file in an instanceable format since that is
+ what is typically used in all learning related applications.
+
+ To make the asset instanceable, we must follow a certain structure dictated by how USD scene-graph
+ instancing and physics work. The rigid body component must be added to each instance and not the
+ referenced asset (i.e. the prototype prim itself). This is because the rigid body component defines
+ properties that are specific to each instance and cannot be shared under the referenced asset. For
+ more information, please check the `documentation <https://docs.omniverse.nvidia.com/extensions/latest/ext_physics/rigid-bodies.html#instancing-rigid-bodies>`_.
+
+ Due to the above, we follow the following structure:
+
+ * ``{prim_path}`` - The root prim that is an Xform with the rigid body and mass APIs if configured.
+ * ``{prim_path}/geometry`` - The prim that contains the mesh and optionally the materials if configured.
+ If instancing is enabled, this prim will be an instanceable reference to the prototype prim.
+
+ .. _omni.kit.asset_converter: https://docs.omniverse.nvidia.com/extensions/latest/ext_asset-converter.html
+
+ .. caution::
+ When converting STL files, Z-up convention is assumed, even though this is not the default for many CAD
+ export programs. Asset orientation convention can either be modified directly in the CAD program's export
+ process or an offset can be added within the config in Isaac Lab.
+
+ """
+
+ cfg:MeshConverterCfg
+"""The configuration instance for mesh to USD conversion."""
+
+
[文档]def__init__(self,cfg:MeshConverterCfg):
+"""Initializes the class.
+
+ Args:
+ cfg: The configuration instance for mesh to USD conversion.
+ """
+ super().__init__(cfg=cfg)
+
+"""
+ Implementation specific methods.
+ """
+
+ def_convert_asset(self,cfg:MeshConverterCfg):
+"""Generate USD from OBJ, STL or FBX.
+
+ It stores the asset in the following format:
+
+ /file_name (default prim)
+ |- /geometry <- Made instanceable if requested
+ |- /Looks
+ |- /mesh
+
+ Args:
+ cfg: The configuration for conversion of mesh to USD.
+
+ Raises:
+ RuntimeError: If the conversion using the Omniverse asset converter fails.
+ """
+ # resolve mesh name and format
+ mesh_file_basename,mesh_file_format=os.path.basename(cfg.asset_path).split(".")
+ mesh_file_format=mesh_file_format.lower()
+
+ # Convert USD
+ asyncio.get_event_loop().run_until_complete(
+ self._convert_mesh_to_usd(
+ in_file=cfg.asset_path,out_file=self.usd_path,prim_path=f"/{mesh_file_basename}"
+ )
+ )
+ # Open converted USD stage
+ # note: This opens a new stage and does not use the stage created earlier by the user
+ # create a new stage
+ stage=Usd.Stage.Open(self.usd_path)
+ # add USD to stage cache
+ stage_id=UsdUtils.StageCache.Get().Insert(stage)
+ # Get the default prim (which is the root prim) -- "/{mesh_file_basename}"
+ xform_prim=stage.GetDefaultPrim()
+ geom_prim=stage.GetPrimAtPath(f"/{mesh_file_basename}/geometry")
+ # Move all meshes to underneath new Xform
+ forchild_mesh_primingeom_prim.GetChildren():
+ ifchild_mesh_prim.GetTypeName()=="Mesh":
+ # Apply collider properties to mesh
+ ifcfg.collision_propsisnotNone:
+ # -- Collision approximation to mesh
+ # TODO: Move this to a new Schema: https://github.com/isaac-orbit/IsaacLab/issues/163
+ mesh_collision_api=UsdPhysics.MeshCollisionAPI.Apply(child_mesh_prim)
+ mesh_collision_api.GetApproximationAttr().Set(cfg.collision_approximation)
+ # -- Collider properties such as offset, scale, etc.
+ schemas.define_collision_properties(
+ prim_path=child_mesh_prim.GetPath(),cfg=cfg.collision_props,stage=stage
+ )
+ # Delete the old Xform and make the new Xform the default prim
+ stage.SetDefaultPrim(xform_prim)
+ # Handle instanceable
+ # Create a new Xform prim that will be the prototype prim
+ ifcfg.make_instanceable:
+ # Export Xform to a file so we can reference it from all instances
+ export_prim_to_file(
+ path=os.path.join(self.usd_dir,self.usd_instanceable_meshes_path),
+ source_prim_path=geom_prim.GetPath(),
+ stage=stage,
+ )
+ # Delete the original prim that will now be a reference
+ geom_prim_path=geom_prim.GetPath().pathString
+ omni.kit.commands.execute("DeletePrims",paths=[geom_prim_path],stage=stage)
+ # Update references to exported Xform and make it instanceable
+ geom_undef_prim=stage.DefinePrim(geom_prim_path)
+ geom_undef_prim.GetReferences().AddReference(self.usd_instanceable_meshes_path,primPath=geom_prim_path)
+ geom_undef_prim.SetInstanceable(True)
+
+ # Apply mass and rigid body properties after everything else
+ # Properties are applied to the top level prim to avoid the case where all instances of this
+ # asset unintentionally share the same rigid body properties
+ # apply mass properties
+ ifcfg.mass_propsisnotNone:
+ schemas.define_mass_properties(prim_path=xform_prim.GetPath(),cfg=cfg.mass_props,stage=stage)
+ # apply rigid body properties
+ ifcfg.rigid_propsisnotNone:
+ schemas.define_rigid_body_properties(prim_path=xform_prim.GetPath(),cfg=cfg.rigid_props,stage=stage)
+
+ # Save changes to USD stage
+ stage.Save()
+ ifstage_idisnotNone:
+ UsdUtils.StageCache.Get().Erase(stage_id)
+
+"""
+ Helper methods.
+ """
+
+ @staticmethod
+ asyncdef_convert_mesh_to_usd(
+ in_file:str,out_file:str,prim_path:str="/World",load_materials:bool=True
+ )->bool:
+"""Convert mesh from supported file types to USD.
+
+ This function uses the Omniverse Asset Converter extension to convert a mesh file to USD.
+ It is an asynchronous function and should be called using `asyncio.get_event_loop().run_until_complete()`.
+
+ The converted asset is stored in the USD format in the specified output file.
+ The USD file has Y-up axis and is scaled to meters.
+
+ The asset hierarchy is arranged as follows:
+
+ .. code-block:: none
+ prim_path (default prim)
+ |- /geometry/Looks
+ |- /geometry/mesh
+
+ Args:
+ in_file: The file to convert.
+ out_file: The path to store the output file.
+ prim_path: The prim path of the mesh.
+ load_materials: Set to True to enable attaching materials defined in the input file
+ to the generated USD mesh. Defaults to True.
+
+ Returns:
+ True if the conversion succeeds.
+ """
+ enable_extension("omni.kit.asset_converter")
+ enable_extension("omni.usd.metrics.assembler")
+
+ importomni.kit.asset_converter
+ importomni.usd
+ fromomni.metrics.assembler.coreimportget_metrics_assembler_interface
+
+ # Create converter context
+ converter_context=omni.kit.asset_converter.AssetConverterContext()
+ # Set up converter settings
+ # Don't import/export materials
+ converter_context.ignore_materials=notload_materials
+ converter_context.ignore_animations=True
+ converter_context.ignore_camera=True
+ converter_context.ignore_light=True
+ # Merge all meshes into one
+ converter_context.merge_all_meshes=True
+ # Sets world units to meters, this will also scale asset if it's centimeters model.
+ # This does not work right now :(, so we need to scale the mesh manually
+ converter_context.use_meter_as_world_unit=True
+ converter_context.baking_scales=True
+ # Uses double precision for all transform ops.
+ converter_context.use_double_precision_to_usd_transform_op=True
+
+ # Create converter task
+ instance=omni.kit.asset_converter.get_instance()
+ out_file_non_metric=out_file.replace(".usd","_non_metric.usd")
+ task=instance.create_converter_task(in_file,out_file_non_metric,None,converter_context)
+ # Start conversion task and wait for it to finish
+ success=True
+ whileTrue:
+ success=awaittask.wait_until_finished()
+ ifnotsuccess:
+ awaitasyncio.sleep(0.1)
+ else:
+ break
+
+ temp_stage=Usd.Stage.CreateInMemory()
+ UsdGeom.SetStageUpAxis(temp_stage,UsdGeom.Tokens.z)
+ UsdGeom.SetStageMetersPerUnit(temp_stage,1.0)
+ UsdPhysics.SetStageKilogramsPerUnit(temp_stage,1.0)
+
+ base_prim=temp_stage.DefinePrim(prim_path,"Xform")
+ prim=temp_stage.DefinePrim(f"{prim_path}/geometry","Xform")
+ prim.GetReferences().AddReference(out_file_non_metric)
+ cache=UsdUtils.StageCache.Get()
+ cache.Insert(temp_stage)
+ stage_id=cache.GetId(temp_stage).ToLongInt()
+ get_metrics_assembler_interface().resolve_stage(stage_id)
+ temp_stage.SetDefaultPrim(base_prim)
+ temp_stage.Export(out_file)
+ returnsuccess
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromomni.isaac.lab.sim.converters.asset_converter_base_cfgimportAssetConverterBaseCfg
+fromomni.isaac.lab.sim.schemasimportschemas_cfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+
+
[文档]@configclass
+classMeshConverterCfg(AssetConverterBaseCfg):
+"""The configuration class for MeshConverter."""
+
+ mass_props:schemas_cfg.MassPropertiesCfg=None
+"""Mass properties to apply to the USD. Defaults to None.
+
+ Note:
+ If None, then no mass properties will be added.
+ """
+
+ rigid_props:schemas_cfg.RigidBodyPropertiesCfg=None
+"""Rigid body properties to apply to the USD. Defaults to None.
+
+ Note:
+ If None, then no rigid body properties will be added.
+ """
+
+ collision_props:schemas_cfg.CollisionPropertiesCfg=None
+"""Collision properties to apply to the USD. Defaults to None.
+
+ Note:
+ If None, then no collision properties will be added.
+ """
+
+ collision_approximation:str="convexDecomposition"
+"""Collision approximation method to use. Defaults to "convexDecomposition".
+
+ Valid options are:
+ "convexDecomposition", "convexHull", "boundingCube",
+ "boundingSphere", "meshSimplification", or "none"
+
+ "none" causes no collision mesh to be added.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importos
+
+importomni.kit.commands
+importomni.usd
+fromomni.isaac.core.utils.extensionsimportenable_extension
+fromomni.isaac.versionimportget_version
+frompxrimportUsd
+
+from.asset_converter_baseimportAssetConverterBase
+from.urdf_converter_cfgimportUrdfConverterCfg
+
+_DRIVE_TYPE={
+ "none":0,
+ "position":1,
+ "velocity":2,
+}
+"""Mapping from drive type name to URDF importer drive number."""
+
+_NORMALS_DIVISION={
+ "catmullClark":0,
+ "loop":1,
+ "bilinear":2,
+ "none":3,
+}
+"""Mapping from normals division name to urdf importer normals division number."""
+
+
+
[文档]classUrdfConverter(AssetConverterBase):
+"""Converter for a URDF description file to a USD file.
+
+ This class wraps around the `omni.isaac.urdf_importer`_ extension to provide a lazy implementation
+ for URDF to USD conversion. It stores the output USD file in an instanceable format since that is
+ what is typically used in all learning related applications.
+
+ .. caution::
+ The current lazy conversion implementation does not automatically trigger USD generation if
+ only the mesh files used by the URDF are modified. To force generation, either set
+ :obj:`AssetConverterBaseCfg.force_usd_conversion` to True or delete the output directory.
+
+ .. note::
+ From Isaac Sim 2023.1 onwards, the extension name changed from ``omni.isaac.urdf`` to
+ ``omni.importer.urdf``. This converter class automatically detects the version of Isaac Sim
+ and uses the appropriate extension.
+
+ The new extension supports a custom XML tag``"dont_collapse"`` for joints. Setting this parameter
+ to true in the URDF joint tag prevents the child link from collapsing when the associated joint type
+ is "fixed".
+
+ .. _omni.isaac.urdf_importer: https://docs.omniverse.nvidia.com/isaacsim/latest/ext_omni_isaac_urdf.html
+ """
+
+ cfg:UrdfConverterCfg
+"""The configuration instance for URDF to USD conversion."""
+
+
[文档]def__init__(self,cfg:UrdfConverterCfg):
+"""Initializes the class.
+
+ Args:
+ cfg: The configuration instance for URDF to USD conversion.
+ """
+ super().__init__(cfg=cfg)
+
+"""
+ Implementation specific methods.
+ """
+
+ def_convert_asset(self,cfg:UrdfConverterCfg):
+"""Calls underlying Omniverse command to convert URDF to USD.
+
+ Args:
+ cfg: The URDF conversion configuration.
+ """
+ import_config=self._get_urdf_import_config(cfg)
+ omni.kit.commands.execute(
+ "URDFParseAndImportFile",
+ urdf_path=cfg.asset_path,
+ import_config=import_config,
+ dest_path=self.usd_path,
+ )
+ # fix the issue that material paths are not relative
+ ifself.cfg.make_instanceable:
+ instanced_usd_path=os.path.join(self.usd_dir,self.usd_instanceable_meshes_path)
+ stage=Usd.Stage.Open(instanced_usd_path)
+ # resolve all paths relative to layer path
+ source_layer=stage.GetRootLayer()
+ omni.usd.resolve_paths(source_layer.identifier,source_layer.identifier)
+ stage.Save()
+
+ # fix the issue that material paths are not relative
+ # note: This issue seems to have popped up in Isaac Sim 2023.1.1
+ stage=Usd.Stage.Open(self.usd_path)
+ # resolve all paths relative to layer path
+ source_layer=stage.GetRootLayer()
+ omni.usd.resolve_paths(source_layer.identifier,source_layer.identifier)
+ stage.Save()
+
+"""
+ Helper methods.
+ """
+
+ def_get_urdf_import_config(self,cfg:UrdfConverterCfg)->omni.importer.urdf.ImportConfig:
+"""Create and fill URDF ImportConfig with desired settings
+
+ Args:
+ cfg: The URDF conversion configuration.
+
+ Returns:
+ The constructed ``ImportConfig`` object containing the desired settings.
+ """
+ # Enable urdf extension
+ enable_extension("omni.importer.urdf")
+
+ fromomni.importer.urdfimport_urdfasomni_urdf
+
+ import_config=omni_urdf.ImportConfig()
+
+ # set the unit scaling factor, 1.0 means meters, 100.0 means cm
+ import_config.set_distance_scale(1.0)
+ # set imported robot as default prim
+ import_config.set_make_default_prim(True)
+ # add a physics scene to the stage on import if none exists
+ import_config.set_create_physics_scene(False)
+
+ # -- instancing settings
+ # meshes will be placed in a separate usd file
+ import_config.set_make_instanceable(cfg.make_instanceable)
+ import_config.set_instanceable_usd_path(self.usd_instanceable_meshes_path)
+
+ # -- asset settings
+ # default density used for links, use 0 to auto-compute
+ import_config.set_density(cfg.link_density)
+ # import inertia tensor from urdf, if it is not specified in urdf it will import as identity
+ import_config.set_import_inertia_tensor(cfg.import_inertia_tensor)
+ # decompose a convex mesh into smaller pieces for a closer fit
+ import_config.set_convex_decomp(cfg.convex_decompose_mesh)
+ import_config.set_subdivision_scheme(_NORMALS_DIVISION["bilinear"])
+
+ # -- physics settings
+ # create fix joint for base link
+ import_config.set_fix_base(cfg.fix_base)
+ # consolidating links that are connected by fixed joints
+ import_config.set_merge_fixed_joints(cfg.merge_fixed_joints)
+ # self collisions between links in the articulation
+ import_config.set_self_collision(cfg.self_collision)
+
+ # default drive type used for joints
+ import_config.set_default_drive_type(_DRIVE_TYPE[cfg.default_drive_type])
+ # default proportional gains
+ import_config.set_default_drive_strength(cfg.default_drive_stiffness)
+ # default derivative gains
+ import_config.set_default_position_drive_damping(cfg.default_drive_damping)
+ ifget_version()[2]=="4":
+ # override joint dynamics parsed from urdf
+ import_config.set_override_joint_dynamics(cfg.override_joint_dynamics)
+
+ returnimport_config
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.sim.converters.asset_converter_base_cfgimportAssetConverterBaseCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+
+
[文档]@configclass
+classUrdfConverterCfg(AssetConverterBaseCfg):
+"""The configuration class for UrdfConverter."""
+
+ link_density=0.0
+"""Default density used for links. Defaults to 0.
+
+ This setting is only effective if ``"inertial"`` properties are missing in the URDF.
+ """
+
+ import_inertia_tensor:bool=True
+"""Import the inertia tensor from urdf. Defaults to True.
+
+ If the ``"inertial"`` tag is missing, then it is imported as an identity.
+ """
+
+ convex_decompose_mesh=False
+"""Decompose a convex mesh into smaller pieces for a closer fit. Defaults to False."""
+
+ fix_base:bool=MISSING
+"""Create a fix joint to the root/base link. Defaults to True."""
+
+ merge_fixed_joints:bool=False
+"""Consolidate links that are connected by fixed joints. Defaults to False."""
+
+ self_collision:bool=False
+"""Activate self-collisions between links of the articulation. Defaults to False."""
+
+ default_drive_type:Literal["none","position","velocity"]="none"
+"""The drive type used for joints. Defaults to ``"none"``.
+
+ The drive type dictates the loaded joint PD gains and USD attributes for joint control:
+
+ * ``"none"``: The joint stiffness and damping are set to 0.0.
+ * ``"position"``: The joint stiff and damping are set based on the URDF file or provided configuration.
+ * ``"velocity"``: The joint stiff is set to zero and damping is based on the URDF file or provided configuration.
+ """
+
+ override_joint_dynamics:bool=False
+"""Override the joint dynamics parsed from the URDF file. Defaults to False."""
+
+ default_drive_stiffness:float=0.0
+"""The default stiffness of the joint drive. Defaults to 0.0."""
+
+ default_drive_damping:float=0.0
+"""The default damping of the joint drive. Defaults to 0.0.
+
+ Note:
+ If ``override_joint_dynamics`` is True, the values parsed from the URDF joint tag ``"<dynamics><damping>"`` are used.
+ Otherwise, it is overridden by the configured value.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+# needed to import for allowing type-hinting: Usd.Stage | None
+from__future__importannotations
+
+importcarb
+importomni.isaac.core.utils.stageasstage_utils
+importomni.physx.scripts.utilsasphysx_utils
+fromomni.physx.scriptsimportdeformableUtilsasdeformable_utils
+frompxrimportPhysxSchema,Usd,UsdPhysics
+
+from..utilsimport(
+ apply_nested,
+ find_global_fixed_joint_prim,
+ get_all_matching_child_prims,
+ safe_set_attribute_on_usd_schema,
+)
+from.importschemas_cfg
+
+"""
+Articulation root properties.
+"""
+
+
+
[文档]defdefine_articulation_root_properties(
+ prim_path:str,cfg:schemas_cfg.ArticulationRootPropertiesCfg,stage:Usd.Stage|None=None
+):
+"""Apply the articulation root schema on the input prim and set its properties.
+
+ See :func:`modify_articulation_root_properties` for more details on how the properties are set.
+
+ Args:
+ prim_path: The prim path where to apply the articulation root schema.
+ cfg: The configuration for the articulation root.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Raises:
+ ValueError: When the prim path is not valid.
+ TypeError: When the prim already has conflicting API schemas.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get articulation USD prim
+ prim=stage.GetPrimAtPath(prim_path)
+ # check if prim path is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim path '{prim_path}' is not valid.")
+ # check if prim has articulation applied on it
+ ifnotUsdPhysics.ArticulationRootAPI(prim):
+ UsdPhysics.ArticulationRootAPI.Apply(prim)
+ # set articulation root properties
+ modify_articulation_root_properties(prim_path,cfg,stage)
+
+
+
[文档]@apply_nested
+defmodify_articulation_root_properties(
+ prim_path:str,cfg:schemas_cfg.ArticulationRootPropertiesCfg,stage:Usd.Stage|None=None
+)->bool:
+"""Modify PhysX parameters for an articulation root prim.
+
+ The `articulation root`_ marks the root of an articulation tree. For floating articulations, this should be on
+ the root body. For fixed articulations, this API can be on a direct or indirect parent of the root joint
+ which is fixed to the world.
+
+ The schema comprises of attributes that belong to the `ArticulationRootAPI`_ and `PhysxArticulationAPI`_.
+ schemas. The latter contains the PhysX parameters for the articulation root.
+
+ The properties are applied to the articulation root prim. The common properties (such as solver position
+ and velocity iteration counts, sleep threshold, stabilization threshold) take precedence over those specified
+ in the rigid body schemas for all the rigid bodies in the articulation.
+
+ .. caution::
+ When the attribute :attr:`schemas_cfg.ArticulationRootPropertiesCfg.fix_root_link` is set to True,
+ a fixed joint is created between the root link and the world frame (if it does not already exist). However,
+ to deal with physics parser limitations, the articulation root schema needs to be applied to the parent of
+ the root link.
+
+ .. note::
+ This function is decorated with :func:`apply_nested` that set the properties to all the prims
+ (that have the schema applied on them) under the input prim path.
+
+ .. _articulation root: https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/docs/Articulations.html
+ .. _ArticulationRootAPI: https://openusd.org/dev/api/class_usd_physics_articulation_root_a_p_i.html
+ .. _PhysxArticulationAPI: https://docs.omniverse.nvidia.com/kit/docs/omni_usd_schema_physics/104.2/class_physx_schema_physx_articulation_a_p_i.html
+
+ Args:
+ prim_path: The prim path to the articulation root.
+ cfg: The configuration for the articulation root.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Returns:
+ True if the properties were successfully set, False otherwise.
+
+ Raises:
+ NotImplementedError: When the root prim is not a rigid body and a fixed joint is to be created.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get articulation USD prim
+ articulation_prim=stage.GetPrimAtPath(prim_path)
+ # check if prim has articulation applied on it
+ ifnotUsdPhysics.ArticulationRootAPI(articulation_prim):
+ returnFalse
+ # retrieve the articulation api
+ physx_articulation_api=PhysxSchema.PhysxArticulationAPI(articulation_prim)
+ ifnotphysx_articulation_api:
+ physx_articulation_api=PhysxSchema.PhysxArticulationAPI.Apply(articulation_prim)
+
+ # convert to dict
+ cfg=cfg.to_dict()
+ # extract non-USD properties
+ fix_root_link=cfg.pop("fix_root_link",None)
+
+ # set into physx api
+ forattr_name,valueincfg.items():
+ safe_set_attribute_on_usd_schema(physx_articulation_api,attr_name,value,camel_case=True)
+
+ # fix root link based on input
+ # we do the fixed joint processing later to not interfere with setting other properties
+ iffix_root_linkisnotNone:
+ # check if a global fixed joint exists under the root prim
+ existing_fixed_joint_prim=find_global_fixed_joint_prim(prim_path)
+
+ # if we found a fixed joint, enable/disable it based on the input
+ # otherwise, create a fixed joint between the world and the root link
+ ifexisting_fixed_joint_primisnotNone:
+ carb.log_info(
+ f"Found an existing fixed joint for the articulation: '{prim_path}'. Setting it to: {fix_root_link}."
+ )
+ existing_fixed_joint_prim.GetJointEnabledAttr().Set(fix_root_link)
+ eliffix_root_link:
+ carb.log_info(f"Creating a fixed joint for the articulation: '{prim_path}'.")
+
+ # note: we have to assume that the root prim is a rigid body,
+ # i.e. we don't handle the case where the root prim is not a rigid body but has articulation api on it
+ # Currently, there is no obvious way to get first rigid body link identified by the PhysX parser
+ ifnotarticulation_prim.HasAPI(UsdPhysics.RigidBodyAPI):
+ raiseNotImplementedError(
+ f"The articulation prim '{prim_path}' does not have the RigidBodyAPI applied."
+ " To create a fixed joint, we need to determine the first rigid body link in"
+ " the articulation tree. However, this is not implemented yet."
+ )
+
+ # create a fixed joint between the root link and the world frame
+ physx_utils.createJoint(stage=stage,joint_type="Fixed",from_prim=None,to_prim=articulation_prim)
+
+ # Having a fixed joint on a rigid body is not treated as "fixed base articulation".
+ # instead, it is treated as a part of the maximal coordinate tree.
+ # Moving the articulation root to the parent solves this issue. This is a limitation of the PhysX parser.
+ # get parent prim
+ parent_prim=articulation_prim.GetParent()
+ # apply api to parent
+ UsdPhysics.ArticulationRootAPI.Apply(parent_prim)
+ PhysxSchema.PhysxArticulationAPI.Apply(parent_prim)
+
+ # copy the attributes
+ # -- usd attributes
+ usd_articulation_api=UsdPhysics.ArticulationRootAPI(articulation_prim)
+ forattr_nameinusd_articulation_api.GetSchemaAttributeNames():
+ attr=articulation_prim.GetAttribute(attr_name)
+ parent_prim.GetAttribute(attr_name).Set(attr.Get())
+ # -- physx attributes
+ physx_articulation_api=PhysxSchema.PhysxArticulationAPI(articulation_prim)
+ forattr_nameinphysx_articulation_api.GetSchemaAttributeNames():
+ attr=articulation_prim.GetAttribute(attr_name)
+ parent_prim.GetAttribute(attr_name).Set(attr.Get())
+
+ # remove api from root
+ articulation_prim.RemoveAPI(UsdPhysics.ArticulationRootAPI)
+ articulation_prim.RemoveAPI(PhysxSchema.PhysxArticulationAPI)
+
+ # success
+ returnTrue
+
+
+"""
+Rigid body properties.
+"""
+
+
+
[文档]defdefine_rigid_body_properties(
+ prim_path:str,cfg:schemas_cfg.RigidBodyPropertiesCfg,stage:Usd.Stage|None=None
+):
+"""Apply the rigid body schema on the input prim and set its properties.
+
+ See :func:`modify_rigid_body_properties` for more details on how the properties are set.
+
+ Args:
+ prim_path: The prim path where to apply the rigid body schema.
+ cfg: The configuration for the rigid body.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Raises:
+ ValueError: When the prim path is not valid.
+ TypeError: When the prim already has conflicting API schemas.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get USD prim
+ prim=stage.GetPrimAtPath(prim_path)
+ # check if prim path is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim path '{prim_path}' is not valid.")
+ # check if prim has rigid body applied on it
+ ifnotUsdPhysics.RigidBodyAPI(prim):
+ UsdPhysics.RigidBodyAPI.Apply(prim)
+ # set rigid body properties
+ modify_rigid_body_properties(prim_path,cfg,stage)
+
+
+
[文档]@apply_nested
+defmodify_rigid_body_properties(
+ prim_path:str,cfg:schemas_cfg.RigidBodyPropertiesCfg,stage:Usd.Stage|None=None
+)->bool:
+"""Modify PhysX parameters for a rigid body prim.
+
+ A `rigid body`_ is a single body that can be simulated by PhysX. It can be either dynamic or kinematic.
+ A dynamic body responds to forces and collisions. A `kinematic body`_ can be moved by the user, but does not
+ respond to forces. They are similar to having static bodies that can be moved around.
+
+ The schema comprises of attributes that belong to the `RigidBodyAPI`_ and `PhysxRigidBodyAPI`_.
+ schemas. The latter contains the PhysX parameters for the rigid body.
+
+ .. note::
+ This function is decorated with :func:`apply_nested` that sets the properties to all the prims
+ (that have the schema applied on them) under the input prim path.
+
+ .. _rigid body: https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/docs/RigidBodyOverview.html
+ .. _kinematic body: https://openusd.org/release/wp_rigid_body_physics.html#kinematic-bodies
+ .. _RigidBodyAPI: https://openusd.org/dev/api/class_usd_physics_rigid_body_a_p_i.html
+ .. _PhysxRigidBodyAPI: https://docs.omniverse.nvidia.com/kit/docs/omni_usd_schema_physics/104.2/class_physx_schema_physx_rigid_body_a_p_i.html
+
+ Args:
+ prim_path: The prim path to the rigid body.
+ cfg: The configuration for the rigid body.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Returns:
+ True if the properties were successfully set, False otherwise.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get rigid-body USD prim
+ rigid_body_prim=stage.GetPrimAtPath(prim_path)
+ # check if prim has rigid-body applied on it
+ ifnotUsdPhysics.RigidBodyAPI(rigid_body_prim):
+ returnFalse
+ # retrieve the USD rigid-body api
+ usd_rigid_body_api=UsdPhysics.RigidBodyAPI(rigid_body_prim)
+ # retrieve the physx rigid-body api
+ physx_rigid_body_api=PhysxSchema.PhysxRigidBodyAPI(rigid_body_prim)
+ ifnotphysx_rigid_body_api:
+ physx_rigid_body_api=PhysxSchema.PhysxRigidBodyAPI.Apply(rigid_body_prim)
+
+ # convert to dict
+ cfg=cfg.to_dict()
+ # set into USD API
+ forattr_namein["rigid_body_enabled","kinematic_enabled"]:
+ value=cfg.pop(attr_name,None)
+ safe_set_attribute_on_usd_schema(usd_rigid_body_api,attr_name,value,camel_case=True)
+ # set into PhysX API
+ forattr_name,valueincfg.items():
+ safe_set_attribute_on_usd_schema(physx_rigid_body_api,attr_name,value,camel_case=True)
+ # success
+ returnTrue
+
+
+"""
+Collision properties.
+"""
+
+
+
[文档]defdefine_collision_properties(
+ prim_path:str,cfg:schemas_cfg.CollisionPropertiesCfg,stage:Usd.Stage|None=None
+):
+"""Apply the collision schema on the input prim and set its properties.
+
+ See :func:`modify_collision_properties` for more details on how the properties are set.
+
+ Args:
+ prim_path: The prim path where to apply the rigid body schema.
+ cfg: The configuration for the collider.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Raises:
+ ValueError: When the prim path is not valid.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get USD prim
+ prim=stage.GetPrimAtPath(prim_path)
+ # check if prim path is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim path '{prim_path}' is not valid.")
+ # check if prim has collision applied on it
+ ifnotUsdPhysics.CollisionAPI(prim):
+ UsdPhysics.CollisionAPI.Apply(prim)
+ # set collision properties
+ modify_collision_properties(prim_path,cfg,stage)
+
+
+
[文档]@apply_nested
+defmodify_collision_properties(
+ prim_path:str,cfg:schemas_cfg.CollisionPropertiesCfg,stage:Usd.Stage|None=None
+)->bool:
+"""Modify PhysX properties of collider prim.
+
+ These properties are based on the `UsdPhysics.CollisionAPI`_ and `PhysxSchema.PhysxCollisionAPI`_ schemas.
+ For more information on the properties, please refer to the official documentation.
+
+ Tuning these parameters influence the contact behavior of the rigid body. For more information on
+ tune them and their effect on the simulation, please refer to the
+ `PhysX documentation <https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/docs/AdvancedCollisionDetection.html>`__.
+
+ .. note::
+ This function is decorated with :func:`apply_nested` that sets the properties to all the prims
+ (that have the schema applied on them) under the input prim path.
+
+ .. _UsdPhysics.CollisionAPI: https://openusd.org/dev/api/class_usd_physics_collision_a_p_i.html
+ .. _PhysxSchema.PhysxCollisionAPI: https://docs.omniverse.nvidia.com/kit/docs/omni_usd_schema_physics/104.2/class_physx_schema_physx_collision_a_p_i.html
+
+ Args:
+ prim_path: The prim path of parent.
+ cfg: The configuration for the collider.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Returns:
+ True if the properties were successfully set, False otherwise.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get USD prim
+ collider_prim=stage.GetPrimAtPath(prim_path)
+ # check if prim has collision applied on it
+ ifnotUsdPhysics.CollisionAPI(collider_prim):
+ returnFalse
+ # retrieve the USD collision api
+ usd_collision_api=UsdPhysics.CollisionAPI(collider_prim)
+ # retrieve the collision api
+ physx_collision_api=PhysxSchema.PhysxCollisionAPI(collider_prim)
+ ifnotphysx_collision_api:
+ physx_collision_api=PhysxSchema.PhysxCollisionAPI.Apply(collider_prim)
+
+ # convert to dict
+ cfg=cfg.to_dict()
+ # set into USD API
+ forattr_namein["collision_enabled"]:
+ value=cfg.pop(attr_name,None)
+ safe_set_attribute_on_usd_schema(usd_collision_api,attr_name,value,camel_case=True)
+ # set into PhysX API
+ forattr_name,valueincfg.items():
+ safe_set_attribute_on_usd_schema(physx_collision_api,attr_name,value,camel_case=True)
+ # success
+ returnTrue
+
+
+"""
+Mass properties.
+"""
+
+
+
[文档]defdefine_mass_properties(prim_path:str,cfg:schemas_cfg.MassPropertiesCfg,stage:Usd.Stage|None=None):
+"""Apply the mass schema on the input prim and set its properties.
+
+ See :func:`modify_mass_properties` for more details on how the properties are set.
+
+ Args:
+ prim_path: The prim path where to apply the rigid body schema.
+ cfg: The configuration for the mass properties.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Raises:
+ ValueError: When the prim path is not valid.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get USD prim
+ prim=stage.GetPrimAtPath(prim_path)
+ # check if prim path is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim path '{prim_path}' is not valid.")
+ # check if prim has mass applied on it
+ ifnotUsdPhysics.MassAPI(prim):
+ UsdPhysics.MassAPI.Apply(prim)
+ # set mass properties
+ modify_mass_properties(prim_path,cfg,stage)
+
+
+
[文档]@apply_nested
+defmodify_mass_properties(prim_path:str,cfg:schemas_cfg.MassPropertiesCfg,stage:Usd.Stage|None=None)->bool:
+"""Set properties for the mass of a rigid body prim.
+
+ These properties are based on the `UsdPhysics.MassAPI` schema. If the mass is not defined, the density is used
+ to compute the mass. However, in that case, a collision approximation of the rigid body is used to
+ compute the density. For more information on the properties, please refer to the
+ `documentation <https://openusd.org/release/wp_rigid_body_physics.html#body-mass-properties>`__.
+
+ .. caution::
+
+ The mass of an object can be specified in multiple ways and have several conflicting settings
+ that are resolved based on precedence. Please make sure to understand the precedence rules
+ before using this property.
+
+ .. note::
+ This function is decorated with :func:`apply_nested` that sets the properties to all the prims
+ (that have the schema applied on them) under the input prim path.
+
+ .. UsdPhysics.MassAPI: https://openusd.org/dev/api/class_usd_physics_mass_a_p_i.html
+
+ Args:
+ prim_path: The prim path of the rigid body.
+ cfg: The configuration for the mass properties.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Returns:
+ True if the properties were successfully set, False otherwise.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get USD prim
+ rigid_prim=stage.GetPrimAtPath(prim_path)
+ # check if prim has mass API applied on it
+ ifnotUsdPhysics.MassAPI(rigid_prim):
+ returnFalse
+ # retrieve the USD mass api
+ usd_physics_mass_api=UsdPhysics.MassAPI(rigid_prim)
+
+ # convert to dict
+ cfg=cfg.to_dict()
+ # set into USD API
+ forattr_namein["mass","density"]:
+ value=cfg.pop(attr_name,None)
+ safe_set_attribute_on_usd_schema(usd_physics_mass_api,attr_name,value,camel_case=True)
+ # success
+ returnTrue
+
+
+"""
+Contact sensor.
+"""
+
+
+
[文档]defactivate_contact_sensors(prim_path:str,threshold:float=0.0,stage:Usd.Stage=None):
+"""Activate the contact sensor on all rigid bodies under a specified prim path.
+
+ This function adds the PhysX contact report API to all rigid bodies under the specified prim path.
+ It also sets the force threshold beyond which the contact sensor reports the contact. The contact
+ reporting API can only be added to rigid bodies.
+
+ Args:
+ prim_path: The prim path under which to search and prepare contact sensors.
+ threshold: The threshold for the contact sensor. Defaults to 0.0.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Raises:
+ ValueError: If the input prim path is not valid.
+ ValueError: If there are no rigid bodies under the prim path.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get prim
+ prim:Usd.Prim=stage.GetPrimAtPath(prim_path)
+ # check if prim is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim path '{prim_path}' is not valid.")
+ # iterate over all children
+ num_contact_sensors=0
+ all_prims=[prim]
+ whilelen(all_prims)>0:
+ # get current prim
+ child_prim=all_prims.pop(0)
+ # check if prim is a rigid body
+ # nested rigid bodies are not allowed by SDK so we can safely assume that
+ # if a prim has a rigid body API, it is a rigid body and we don't need to
+ # check its children
+ ifchild_prim.HasAPI(UsdPhysics.RigidBodyAPI):
+ # set sleep threshold to zero
+ rb=PhysxSchema.PhysxRigidBodyAPI.Get(stage,prim.GetPrimPath())
+ rb.CreateSleepThresholdAttr().Set(0.0)
+ # add contact report API with threshold of zero
+ ifnotchild_prim.HasAPI(PhysxSchema.PhysxContactReportAPI):
+ carb.log_verbose(f"Adding contact report API to prim: '{child_prim.GetPrimPath()}'")
+ cr_api=PhysxSchema.PhysxContactReportAPI.Apply(child_prim)
+ else:
+ carb.log_verbose(f"Contact report API already exists on prim: '{child_prim.GetPrimPath()}'")
+ cr_api=PhysxSchema.PhysxContactReportAPI.Get(stage,child_prim.GetPrimPath())
+ # set threshold to zero
+ cr_api.CreateThresholdAttr().Set(threshold)
+ # increment number of contact sensors
+ num_contact_sensors+=1
+ else:
+ # add all children to tree
+ all_prims+=child_prim.GetChildren()
+ # check if no contact sensors were found
+ ifnum_contact_sensors==0:
+ raiseValueError(
+ f"No contact sensors added to the prim: '{prim_path}'. This means that no rigid bodies"
+ " are present under this prim. Please check the prim path."
+ )
+ # success
+ returnTrue
+
+
+"""
+Joint drive properties.
+"""
+
+
+
[文档]@apply_nested
+defmodify_joint_drive_properties(
+ prim_path:str,drive_props:schemas_cfg.JointDrivePropertiesCfg,stage:Usd.Stage|None=None
+)->bool:
+"""Modify PhysX parameters for a joint prim.
+
+ This function checks if the input prim is a prismatic or revolute joint and applies the joint drive schema
+ on it. If the joint is a tendon (i.e., it has the `PhysxTendonAxisAPI`_ schema applied on it), then the joint
+ drive schema is not applied.
+
+ Based on the configuration, this method modifies the properties of the joint drive. These properties are
+ based on the `UsdPhysics.DriveAPI`_ schema. For more information on the properties, please refer to the
+ official documentation.
+
+ .. caution::
+
+ We highly recommend modifying joint properties of articulations through the functionalities in the
+ :mod:`omni.isaac.lab.actuators` module. The methods here are for setting simulation low-level
+ properties only.
+
+ .. _UsdPhysics.DriveAPI: https://openusd.org/dev/api/class_usd_physics_drive_a_p_i.html
+ .. _PhysxTendonAxisAPI: https://docs.omniverse.nvidia.com/kit/docs/omni_usd_schema_physics/104.2/class_physx_schema_physx_tendon_axis_a_p_i.html
+
+ Args:
+ prim_path: The prim path where to apply the joint drive schema.
+ drive_props: The configuration for the joint drive.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Returns:
+ True if the properties were successfully set, False otherwise.
+
+ Raises:
+ ValueError: If the input prim path is not valid.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get USD prim
+ prim=stage.GetPrimAtPath(prim_path)
+ # check if prim path is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim path '{prim_path}' is not valid.")
+
+ # check if prim has joint drive applied on it
+ ifprim.IsA(UsdPhysics.RevoluteJoint):
+ drive_api_name="angular"
+ elifprim.IsA(UsdPhysics.PrismaticJoint):
+ drive_api_name="linear"
+ else:
+ returnFalse
+ # check that prim is not a tendon child prim
+ # note: root prim is what "controls" the tendon so we still want to apply the drive to it
+ ifprim.HasAPI(PhysxSchema.PhysxTendonAxisAPI)andnotprim.HasAPI(PhysxSchema.PhysxTendonAxisRootAPI):
+ returnFalse
+
+ # check if prim has joint drive applied on it
+ usd_drive_api=UsdPhysics.DriveAPI(prim,drive_api_name)
+ ifnotusd_drive_api:
+ usd_drive_api=UsdPhysics.DriveAPI.Apply(prim,drive_api_name)
+
+ # change the drive type to input
+ ifdrive_props.drive_typeisnotNone:
+ usd_drive_api.CreateTypeAttr().Set(drive_props.drive_type)
+
+ returnTrue
+
+
+"""
+Fixed tendon properties.
+"""
+
+
+
[文档]@apply_nested
+defmodify_fixed_tendon_properties(
+ prim_path:str,cfg:schemas_cfg.FixedTendonPropertiesCfg,stage:Usd.Stage|None=None
+)->bool:
+"""Modify PhysX parameters for a fixed tendon attachment prim.
+
+ A `fixed tendon`_ can be used to link multiple degrees of freedom of articulation joints
+ through length and limit constraints. For instance, it can be used to set up an equality constraint
+ between a driven and passive revolute joints.
+
+ The schema comprises of attributes that belong to the `PhysxTendonAxisRootAPI`_ schema.
+
+ .. note::
+ This function is decorated with :func:`apply_nested` that sets the properties to all the prims
+ (that have the schema applied on them) under the input prim path.
+
+ .. _fixed tendon: https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/_api_build/classPxArticulationFixedTendon.html
+ .. _PhysxTendonAxisRootAPI: https://docs.omniverse.nvidia.com/kit/docs/omni_usd_schema_physics/104.2/class_physx_schema_physx_tendon_axis_root_a_p_i.html
+
+ Args:
+ prim_path: The prim path to the tendon attachment.
+ cfg: The configuration for the tendon attachment.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Returns:
+ True if the properties were successfully set, False otherwise.
+
+ Raises:
+ ValueError: If the input prim path is not valid.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get USD prim
+ tendon_prim=stage.GetPrimAtPath(prim_path)
+ # check if prim has fixed tendon applied on it
+ has_root_fixed_tendon=tendon_prim.HasAPI(PhysxSchema.PhysxTendonAxisRootAPI)
+ ifnothas_root_fixed_tendon:
+ returnFalse
+
+ # resolve all available instances of the schema since it is multi-instance
+ forschema_nameintendon_prim.GetAppliedSchemas():
+ # only consider the fixed tendon schema
+ if"PhysxTendonAxisRootAPI"notinschema_name:
+ continue
+ # retrieve the USD tendon api
+ instance_name=schema_name.split(":")[-1]
+ physx_tendon_axis_api=PhysxSchema.PhysxTendonAxisRootAPI(tendon_prim,instance_name)
+
+ # convert to dict
+ cfg=cfg.to_dict()
+ # set into PhysX API
+ forattr_name,valueincfg.items():
+ safe_set_attribute_on_usd_schema(physx_tendon_axis_api,attr_name,value,camel_case=True)
+ # success
+ returnTrue
+
+
+"""
+Deformable body properties.
+"""
+
+
+
[文档]defdefine_deformable_body_properties(
+ prim_path:str,cfg:schemas_cfg.DeformableBodyPropertiesCfg,stage:Usd.Stage|None=None
+):
+"""Apply the deformable body schema on the input prim and set its properties.
+
+ See :func:`modify_deformable_body_properties` for more details on how the properties are set.
+
+ .. note::
+ If the input prim is not a mesh, this function will traverse the prim and find the first mesh
+ under it. If no mesh or multiple meshes are found, an error is raised. This is because the deformable
+ body schema can only be applied to a single mesh.
+
+ Args:
+ prim_path: The prim path where to apply the deformable body schema.
+ cfg: The configuration for the deformable body.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Raises:
+ ValueError: When the prim path is not valid.
+ ValueError: When the prim has no mesh or multiple meshes.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get USD prim
+ prim=stage.GetPrimAtPath(prim_path)
+ # check if prim path is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim path '{prim_path}' is not valid.")
+
+ # traverse the prim and get the mesh
+ matching_prims=get_all_matching_child_prims(prim_path,lambdap:p.GetTypeName()=="Mesh")
+ # check if the mesh is valid
+ iflen(matching_prims)==0:
+ raiseValueError(f"Could not find any mesh in '{prim_path}'. Please check asset.")
+ iflen(matching_prims)>1:
+ # get list of all meshes found
+ mesh_paths=[p.GetPrimPath()forpinmatching_prims]
+ raiseValueError(
+ f"Found multiple meshes in '{prim_path}': {mesh_paths}."
+ " Deformable body schema can only be applied to one mesh."
+ )
+
+ # get deformable-body USD prim
+ mesh_prim=matching_prims[0]
+ # check if prim has deformable-body applied on it
+ ifnotPhysxSchema.PhysxDeformableBodyAPI(mesh_prim):
+ PhysxSchema.PhysxDeformableBodyAPI.Apply(mesh_prim)
+ # set deformable body properties
+ modify_deformable_body_properties(mesh_prim.GetPrimPath(),cfg,stage)
+
+
+
[文档]@apply_nested
+defmodify_deformable_body_properties(
+ prim_path:str,cfg:schemas_cfg.DeformableBodyPropertiesCfg,stage:Usd.Stage|None=None
+):
+"""Modify PhysX parameters for a deformable body prim.
+
+ A `deformable body`_ is a single body that can be simulated by PhysX. Unlike rigid bodies, deformable bodies
+ support relative motion of the nodes in the mesh. Consequently, they can be used to simulate deformations
+ under applied forces.
+
+ PhysX soft body simulation employs Finite Element Analysis (FEA) to simulate the deformations of the mesh.
+ It uses two tetrahedral meshes to represent the deformable body:
+
+ 1. **Simulation mesh**: This mesh is used for the simulation and is the one that is deformed by the solver.
+ 2. **Collision mesh**: This mesh only needs to match the surface of the simulation mesh and is used for
+ collision detection.
+
+ For most applications, we assume that the above two meshes are computed from the "render mesh" of the deformable
+ body. The render mesh is the mesh that is visible in the scene and is used for rendering purposes. It is composed
+ of triangles and is the one that is used to compute the above meshes based on PhysX cookings.
+
+ The schema comprises of attributes that belong to the `PhysxDeformableBodyAPI`_. schemas containing the PhysX
+ parameters for the deformable body.
+
+ .. caution::
+ The deformable body schema is still under development by the Omniverse team. The current implementation
+ works with the PhysX schemas shipped with Isaac Sim 4.0.0 onwards. It may change in future releases.
+
+ .. note::
+ This function is decorated with :func:`apply_nested` that sets the properties to all the prims
+ (that have the schema applied on them) under the input prim path.
+
+ .. _deformable body: https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/docs/SoftBodies.html
+ .. _PhysxDeformableBodyAPI: https://docs.omniverse.nvidia.com/kit/docs/omni_usd_schema_physics/104.2/class_physx_schema_physx_deformable_a_p_i.html
+
+ Args:
+ prim_path: The prim path to the deformable body.
+ cfg: The configuration for the deformable body.
+ stage: The stage where to find the prim. Defaults to None, in which case the
+ current stage is used.
+
+ Returns:
+ True if the properties were successfully set, False otherwise.
+ """
+ # obtain stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+
+ # get deformable-body USD prim
+ deformable_body_prim=stage.GetPrimAtPath(prim_path)
+
+ # check if the prim is valid and has the deformable-body API
+ ifnotdeformable_body_prim.IsValid()ornotPhysxSchema.PhysxDeformableBodyAPI(deformable_body_prim):
+ returnFalse
+
+ # retrieve the physx deformable-body api
+ physx_deformable_body_api=PhysxSchema.PhysxDeformableBodyAPI(deformable_body_prim)
+ # retrieve the physx deformable api
+ physx_deformable_api=PhysxSchema.PhysxDeformableAPI(physx_deformable_body_api)
+
+ # convert to dict
+ cfg=cfg.to_dict()
+ # set into deformable body API
+ attr_kwargs={
+ attr_name:cfg.pop(attr_name)
+ forattr_namein[
+ "kinematic_enabled",
+ "collision_simplification",
+ "collision_simplification_remeshing",
+ "collision_simplification_remeshing_resolution",
+ "collision_simplification_target_triangle_count",
+ "collision_simplification_force_conforming",
+ "simulation_hexahedral_resolution",
+ "solver_position_iteration_count",
+ "vertex_velocity_damping",
+ "sleep_damping",
+ "sleep_threshold",
+ "settling_threshold",
+ "self_collision",
+ "self_collision_filter_distance",
+ ]
+ }
+ status=deformable_utils.add_physx_deformable_body(stage,prim_path=prim_path,**attr_kwargs)
+ # check if the deformable body was successfully added
+ ifnotstatus:
+ returnFalse
+
+ # obtain the PhysX collision API (this is set when the deformable body is added)
+ physx_collision_api=PhysxSchema.PhysxCollisionAPI(deformable_body_prim)
+
+ # set into PhysX API
+ forattr_name,valueincfg.items():
+ ifattr_namein["rest_offset","contact_offset"]:
+ safe_set_attribute_on_usd_schema(physx_collision_api,attr_name,value,camel_case=True)
+ else:
+ safe_set_attribute_on_usd_schema(physx_deformable_api,attr_name,value,camel_case=True)
+
+ # success
+ returnTrue
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromtypingimportLiteral
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+
+
[文档]@configclass
+classArticulationRootPropertiesCfg:
+"""Properties to apply to the root of an articulation.
+
+ See :meth:`modify_articulation_root_properties` for more information.
+
+ .. note::
+ If the values are None, they are not modified. This is useful when you want to set only a subset of
+ the properties and leave the rest as-is.
+ """
+
+ articulation_enabled:bool|None=None
+"""Whether to enable or disable articulation."""
+
+ enabled_self_collisions:bool|None=None
+"""Whether to enable or disable self-collisions."""
+
+ solver_position_iteration_count:int|None=None
+"""Solver position iteration counts for the body."""
+
+ solver_velocity_iteration_count:int|None=None
+"""Solver position iteration counts for the body."""
+
+ sleep_threshold:float|None=None
+"""Mass-normalized kinetic energy threshold below which an actor may go to sleep."""
+
+ stabilization_threshold:float|None=None
+"""The mass-normalized kinetic energy threshold below which an articulation may participate in stabilization."""
+
+ fix_root_link:bool|None=None
+"""Whether to fix the root link of the articulation.
+
+ * If set to None, the root link is not modified.
+ * If the articulation already has a fixed root link, this flag will enable or disable the fixed joint.
+ * If the articulation does not have a fixed root link, this flag will create a fixed joint between the world
+ frame and the root link. The joint is created with the name "FixedJoint" under the articulation prim.
+
+ .. note::
+ This is a non-USD schema property. It is handled by the :meth:`modify_articulation_root_properties` function.
+
+ """
+
+
+
[文档]@configclass
+classRigidBodyPropertiesCfg:
+"""Properties to apply to a rigid body.
+
+ See :meth:`modify_rigid_body_properties` for more information.
+
+ .. note::
+ If the values are None, they are not modified. This is useful when you want to set only a subset of
+ the properties and leave the rest as-is.
+ """
+
+ rigid_body_enabled:bool|None=None
+"""Whether to enable or disable the rigid body."""
+
+ kinematic_enabled:bool|None=None
+"""Determines whether the body is kinematic or not.
+
+ A kinematic body is a body that is moved through animated poses or through user defined poses. The simulation
+ still derives velocities for the kinematic body based on the external motion.
+
+ For more information on kinematic bodies, please refer to the `documentation <https://openusd.org/release/wp_rigid_body_physics.html#kinematic-bodies>`_.
+ """
+
+ disable_gravity:bool|None=None
+"""Disable gravity for the actor."""
+
+ linear_damping:float|None=None
+"""Linear damping for the body."""
+
+ angular_damping:float|None=None
+"""Angular damping for the body."""
+
+ max_linear_velocity:float|None=None
+"""Maximum linear velocity for rigid bodies (in m/s)."""
+
+ max_angular_velocity:float|None=None
+"""Maximum angular velocity for rigid bodies (in deg/s)."""
+
+ max_depenetration_velocity:float|None=None
+"""Maximum depenetration velocity permitted to be introduced by the solver (in m/s)."""
+
+ max_contact_impulse:float|None=None
+"""The limit on the impulse that may be applied at a contact."""
+
+ enable_gyroscopic_forces:bool|None=None
+"""Enables computation of gyroscopic forces on the rigid body."""
+
+ retain_accelerations:bool|None=None
+"""Carries over forces/accelerations over sub-steps."""
+
+ solver_position_iteration_count:int|None=None
+"""Solver position iteration counts for the body."""
+
+ solver_velocity_iteration_count:int|None=None
+"""Solver position iteration counts for the body."""
+
+ sleep_threshold:float|None=None
+"""Mass-normalized kinetic energy threshold below which an actor may go to sleep."""
+
+ stabilization_threshold:float|None=None
+"""The mass-normalized kinetic energy threshold below which an actor may participate in stabilization."""
+
+
+
[文档]@configclass
+classCollisionPropertiesCfg:
+"""Properties to apply to colliders in a rigid body.
+
+ See :meth:`modify_collision_properties` for more information.
+
+ .. note::
+ If the values are None, they are not modified. This is useful when you want to set only a subset of
+ the properties and leave the rest as-is.
+ """
+
+ collision_enabled:bool|None=None
+"""Whether to enable or disable collisions."""
+
+ contact_offset:float|None=None
+"""Contact offset for the collision shape (in m).
+
+ The collision detector generates contact points as soon as two shapes get closer than the sum of their
+ contact offsets. This quantity should be non-negative which means that contact generation can potentially start
+ before the shapes actually penetrate.
+ """
+
+ rest_offset:float|None=None
+"""Rest offset for the collision shape (in m).
+
+ The rest offset quantifies how close a shape gets to others at rest, At rest, the distance between two
+ vertically stacked objects is the sum of their rest offsets. If a pair of shapes have a positive rest
+ offset, the shapes will be separated at rest by an air gap.
+ """
+
+ torsional_patch_radius:float|None=None
+"""Radius of the contact patch for applying torsional friction (in m).
+
+ It is used to approximate rotational friction introduced by the compression of contacting surfaces.
+ If the radius is zero, no torsional friction is applied.
+ """
+
+ min_torsional_patch_radius:float|None=None
+"""Minimum radius of the contact patch for applying torsional friction (in m)."""
+
+
+
[文档]@configclass
+classMassPropertiesCfg:
+"""Properties to define explicit mass properties of a rigid body.
+
+ See :meth:`modify_mass_properties` for more information.
+
+ .. note::
+ If the values are None, they are not modified. This is useful when you want to set only a subset of
+ the properties and leave the rest as-is.
+ """
+
+ mass:float|None=None
+"""The mass of the rigid body (in kg).
+
+ Note:
+ If non-zero, the mass is ignored and the density is used to compute the mass.
+ """
+
+ density:float|None=None
+"""The density of the rigid body (in kg/m^3).
+
+ The density indirectly defines the mass of the rigid body. It is generally computed using the collision
+ approximation of the body.
+ """
+
+
+
[文档]@configclass
+classJointDrivePropertiesCfg:
+"""Properties to define the drive mechanism of a joint.
+
+ See :meth:`modify_joint_drive_properties` for more information.
+
+ .. note::
+ If the values are None, they are not modified. This is useful when you want to set only a subset of
+ the properties and leave the rest as-is.
+ """
+
+ drive_type:Literal["force","acceleration"]|None=None
+"""Joint drive type to apply.
+
+ If the drive type is "force", then the joint is driven by a force. If the drive type is "acceleration",
+ then the joint is driven by an acceleration (usually used for kinematic joints).
+ """
+
+
+
[文档]@configclass
+classFixedTendonPropertiesCfg:
+"""Properties to define fixed tendons of an articulation.
+
+ See :meth:`modify_fixed_tendon_properties` for more information.
+
+ .. note::
+ If the values are None, they are not modified. This is useful when you want to set only a subset of
+ the properties and leave the rest as-is.
+ """
+
+ tendon_enabled:bool|None=None
+"""Whether to enable or disable the tendon."""
+
+ stiffness:float|None=None
+"""Spring stiffness term acting on the tendon's length."""
+
+ damping:float|None=None
+"""The damping term acting on both the tendon length and the tendon-length limits."""
+
+ limit_stiffness:float|None=None
+"""Limit stiffness term acting on the tendon's length limits."""
+
+ offset:float|None=None
+"""Length offset term for the tendon.
+
+ It defines an amount to be added to the accumulated length computed for the tendon. This allows the application
+ to actuate the tendon by shortening or lengthening it.
+ """
+
+ rest_length:float|None=None
+"""Spring rest length of the tendon."""
+
+
+
[文档]@configclass
+classDeformableBodyPropertiesCfg:
+"""Properties to apply to a deformable body.
+
+ A deformable body is a body that can deform under forces. The configuration allows users to specify
+ the properties of the deformable body, such as the solver iteration counts, damping, and self-collision.
+
+ An FEM-based deformable body is created by providing a collision mesh and simulation mesh. The collision mesh
+ is used for collision detection and the simulation mesh is used for simulation. The collision mesh is usually
+ a simplified version of the simulation mesh.
+
+ Based on the above, the PhysX team provides APIs to either set the simulation and collision mesh directly
+ (by specifying the points) or to simplify the collision mesh based on the simulation mesh. The simplification
+ process involves remeshing the collision mesh and simplifying it based on the target triangle count.
+
+ Since specifying the collision mesh points directly is not a common use case, we only expose the parameters
+ to simplify the collision mesh based on the simulation mesh. If you want to provide the collision mesh points,
+ please open an issue on the repository and we can add support for it.
+
+ See :meth:`modify_deformable_body_properties` for more information.
+
+ .. note::
+ If the values are :obj:`None`, they are not modified. This is useful when you want to set only a subset of
+ the properties and leave the rest as-is.
+ """
+
+ deformable_enabled:bool|None=None
+"""Enables deformable body."""
+
+ kinematic_enabled:bool=False
+"""Enables kinematic body. Defaults to False, which means that the body is not kinematic.
+
+ Similar to rigid bodies, this allows setting user-driven motion for the deformable body. For more information,
+ please refer to the `documentation <https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/docs/SoftBodies.html#kinematic-soft-bodies>`__.
+ """
+
+ self_collision:bool|None=None
+"""Whether to enable or disable self-collisions for the deformable body based on the rest position distances."""
+
+ self_collision_filter_distance:float|None=None
+"""Penetration value that needs to get exceeded before contacts for self collision are generated.
+
+ This parameter must be greater than of equal to twice the :attr:`rest_offset` value.
+
+ This value has an effect only if :attr:`self_collision` is enabled.
+ """
+
+ settling_threshold:float|None=None
+"""Threshold vertex velocity (in m/s) under which sleep damping is applied in addition to velocity damping."""
+
+ sleep_damping:float|None=None
+"""Coefficient for the additional damping term if fertex velocity drops below setting threshold."""
+
+ sleep_threshold:float|None=None
+"""The velocity threshold (in m/s) under which the vertex becomes a candidate for sleeping in the next step."""
+
+ solver_position_iteration_count:int|None=None
+"""Number of the solver positional iterations per step. Range is [1,255]"""
+
+ vertex_velocity_damping:float|None=None
+"""Coefficient for artificial damping on the vertex velocity.
+
+ This parameter can be used to approximate the effect of air drag on the deformable body.
+ """
+
+ simulation_hexahedral_resolution:int=10
+"""The target resolution for the hexahedral mesh used for simulation. Defaults to 10.
+
+ Note:
+ This value is ignored if the user provides the simulation mesh points directly. However, we assume that
+ most users will not provide the simulation mesh points directly. If you want to provide the simulation mesh
+ directly, please set this value to :obj:`None`.
+ """
+
+ collision_simplification:bool=True
+"""Whether or not to simplify the collision mesh before creating a soft body out of it. Defaults to True.
+
+ Note:
+ This flag is ignored if the user provides the simulation mesh points directly. However, we assume that
+ most users will not provide the simulation mesh points directly. Hence, this flag is enabled by default.
+
+ If you want to provide the simulation mesh points directly, please set this flag to False.
+ """
+
+ collision_simplification_remeshing:bool=True
+"""Whether or not the collision mesh should be remeshed before simplification. Defaults to True.
+
+ This parameter is ignored if :attr:`collision_simplification` is False.
+ """
+
+ collision_simplification_remeshing_resolution:int=0
+"""The resolution used for remeshing. Defaults to 0, which means that a heuristic is used to determine the
+ resolution.
+
+ This parameter is ignored if :attr:`collision_simplification_remeshing` is False.
+ """
+
+ collision_simplification_target_triangle_count:int=0
+"""The target triangle count used for the simplification. Defaults to 0, which means that a heuristic based on
+ the :attr:`simulation_hexahedral_resolution` is used to determine the target count.
+
+ This parameter is ignored if :attr:`collision_simplification` is False.
+ """
+
+ collision_simplification_force_conforming:bool=True
+"""Whether or not the simplification should force the output mesh to conform to the input mesh. Defaults to True.
+
+ The flag indicates that the tretrahedralizer used to generate the collision mesh should produce tetrahedra
+ that conform to the triangle mesh. If False, the simplifier uses the output from the tretrahedralizer used.
+
+ This parameter is ignored if :attr:`collision_simplification` is False.
+ """
+
+ contact_offset:float|None=None
+"""Contact offset for the collision shape (in m).
+
+ The collision detector generates contact points as soon as two shapes get closer than the sum of their
+ contact offsets. This quantity should be non-negative which means that contact generation can potentially start
+ before the shapes actually penetrate.
+ """
+
+ rest_offset:float|None=None
+"""Rest offset for the collision shape (in m).
+
+ The rest offset quantifies how close a shape gets to others at rest, At rest, the distance between two
+ vertically stacked objects is the sum of their rest offsets. If a pair of shapes have a positive rest
+ offset, the shapes will be separated at rest by an air gap.
+ """
+
+ max_depenetration_velocity:float|None=None
+"""Maximum depenetration velocity permitted to be introduced by the solver (in m/s)."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Base configuration of the environment.
+
+This module defines the general configuration of the environment. It includes parameters for
+configuring the environment instances, viewer settings, and simulation parameters.
+"""
+
+fromtypingimportLiteral
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.spawners.materialsimportRigidBodyMaterialCfg
+
+
+
[文档]@configclass
+classPhysxCfg:
+"""Configuration for PhysX solver-related parameters.
+
+ These parameters are used to configure the PhysX solver. For more information, see the `PhysX 5 SDK
+ documentation`_.
+
+ PhysX 5 supports GPU-accelerated physics simulation. This is enabled by default, but can be disabled
+ by setting the :attr:`~SimulationCfg.device` to ``cpu`` in :class:`SimulationCfg`. Unlike CPU PhysX, the GPU
+ simulation feature is unable to dynamically grow all the buffers. Therefore, it is necessary to provide
+ a reasonable estimate of the buffer sizes for GPU features. If insufficient buffer sizes are provided, the
+ simulation will fail with errors and lead to adverse behaviors. The buffer sizes can be adjusted through the
+ ``gpu_*`` parameters.
+
+ .. _PhysX 5 SDK documentation: https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/_api_build/classPxSceneDesc.html
+
+ """
+
+ solver_type:Literal[0,1]=1
+"""The type of solver to use.Default is 1 (TGS).
+
+ Available solvers:
+
+ * :obj:`0`: PGS (Projective Gauss-Seidel)
+ * :obj:`1`: TGS (Truncated Gauss-Seidel)
+ """
+
+ min_position_iteration_count:int=1
+"""Minimum number of solver position iterations (rigid bodies, cloth, particles etc.). Default is 1.
+
+ .. note::
+
+ Each physics actor in Omniverse specifies its own solver iteration count. The solver takes
+ the number of iterations specified by the actor with the highest iteration and clamps it to
+ the range ``[min_position_iteration_count, max_position_iteration_count]``.
+ """
+
+ max_position_iteration_count:int=255
+"""Maximum number of solver position iterations (rigid bodies, cloth, particles etc.). Default is 255.
+
+ .. note::
+
+ Each physics actor in Omniverse specifies its own solver iteration count. The solver takes
+ the number of iterations specified by the actor with the highest iteration and clamps it to
+ the range ``[min_position_iteration_count, max_position_iteration_count]``.
+ """
+
+ min_velocity_iteration_count:int=0
+"""Minimum number of solver velocity iterations (rigid bodies, cloth, particles etc.). Default is 0.
+
+ .. note::
+
+ Each physics actor in Omniverse specifies its own solver iteration count. The solver takes
+ the number of iterations specified by the actor with the highest iteration and clamps it to
+ the range ``[min_velocity_iteration_count, max_velocity_iteration_count]``.
+ """
+
+ max_velocity_iteration_count:int=255
+"""Maximum number of solver velocity iterations (rigid bodies, cloth, particles etc.). Default is 255.
+
+ .. note::
+
+ Each physics actor in Omniverse specifies its own solver iteration count. The solver takes
+ the number of iterations specified by the actor with the highest iteration and clamps it to
+ the range ``[min_velocity_iteration_count, max_velocity_iteration_count]``.
+ """
+
+ enable_ccd:bool=False
+"""Enable a second broad-phase pass that makes it possible to prevent objects from tunneling through each other.
+ Default is False."""
+
+ enable_stabilization:bool=True
+"""Enable/disable additional stabilization pass in solver. Default is True."""
+
+ enable_enhanced_determinism:bool=False
+"""Enable/disable improved determinism at the expense of performance. Defaults to False.
+
+ For more information on PhysX determinism, please check `here`_.
+
+ .. _here: https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/docs/RigidBodyDynamics.html#enhanced-determinism
+ """
+
+ bounce_threshold_velocity:float=0.5
+"""Relative velocity threshold for contacts to bounce (in m/s). Default is 0.5 m/s."""
+
+ friction_offset_threshold:float=0.04
+"""Threshold for contact point to experience friction force (in m). Default is 0.04 m."""
+
+ friction_correlation_distance:float=0.025
+"""Distance threshold for merging contacts into a single friction anchor point (in m). Default is 0.025 m."""
+
+ gpu_max_rigid_contact_count:int=2**23
+"""Size of rigid contact stream buffer allocated in pinned host memory. Default is 2 ** 23."""
+
+ gpu_max_rigid_patch_count:int=5*2**15
+"""Size of the rigid contact patch stream buffer allocated in pinned host memory. Default is 5 * 2 ** 15."""
+
+ gpu_found_lost_pairs_capacity:int=2**21
+"""Capacity of found and lost buffers allocated in GPU global memory. Default is 2 ** 21.
+
+ This is used for the found/lost pair reports in the BP.
+ """
+
+ gpu_found_lost_aggregate_pairs_capacity:int=2**25
+"""Capacity of found and lost buffers in aggregate system allocated in GPU global memory.
+ Default is 2 ** 25.
+
+ This is used for the found/lost pair reports in AABB manager.
+ """
+
+ gpu_total_aggregate_pairs_capacity:int=2**21
+"""Capacity of total number of aggregate pairs allocated in GPU global memory. Default is 2 ** 21."""
+
+ gpu_collision_stack_size:int=2**26
+"""Size of the collision stack buffer allocated in pinned host memory. Default is 2 ** 26."""
+
+ gpu_heap_capacity:int=2**26
+"""Initial capacity of the GPU and pinned host memory heaps. Additional memory will be allocated
+ if more memory is required. Default is 2 ** 26."""
+
+ gpu_temp_buffer_capacity:int=2**24
+"""Capacity of temp buffer allocated in pinned host memory. Default is 2 ** 24."""
+
+ gpu_max_num_partitions:int=8
+"""Limitation for the partitions in the GPU dynamics pipeline. Default is 8.
+
+ This variable must be power of 2. A value greater than 32 is currently not supported. Range: (1, 32)
+ """
+
+ gpu_max_soft_body_contacts:int=2**20
+"""Size of soft body contacts stream buffer allocated in pinned host memory. Default is 2 ** 20."""
+
+ gpu_max_particle_contacts:int=2**20
+"""Size of particle contacts stream buffer allocated in pinned host memory. Default is 2 ** 20."""
+
+
+
[文档]@configclass
+classSimulationCfg:
+"""Configuration for simulation physics."""
+
+ physics_prim_path:str="/physicsScene"
+"""The prim path where the USD PhysicsScene is created. Default is "/physicsScene"."""
+
+ device:str="cuda:0"
+"""The device to run the simulation on. Default is ``"cuda:0"``.
+
+ Valid options are:
+
+ - ``"cpu"``: Use CPU.
+ - ``"cuda"``: Use GPU, where the device ID is inferred from :class:`~omni.isaac.lab.app.AppLauncher`'s config.
+ - ``"cuda:N"``: Use GPU, where N is the device ID. For example, "cuda:0".
+ """
+
+ dt:float=1.0/60.0
+"""The physics simulation time-step (in seconds). Default is 0.0167 seconds."""
+
+ render_interval:int=1
+"""The number of physics simulation steps per rendering step. Default is 1."""
+
+ gravity:tuple[float,float,float]=(0.0,0.0,-9.81)
+"""The gravity vector (in m/s^2). Default is (0.0, 0.0, -9.81).
+
+ If set to (0.0, 0.0, 0.0), gravity is disabled.
+ """
+
+ enable_scene_query_support:bool=False
+"""Enable/disable scene query support for collision shapes. Default is False.
+
+ This flag allows performing collision queries (raycasts, sweeps, and overlaps) on actors and
+ attached shapes in the scene. This is useful for implementing custom collision detection logic
+ outside of the physics engine.
+
+ If set to False, the physics engine does not create the scene query manager and the scene query
+ functionality will not be available. However, this provides some performance speed-up.
+
+ Note:
+ This flag is overridden to True inside the :class:`SimulationContext` class when running the simulation
+ with the GUI enabled. This is to allow certain GUI features to work properly.
+ """
+
+ use_fabric:bool=True
+"""Enable/disable reading of physics buffers directly. Default is True.
+
+ When running the simulation, updates in the states in the scene is normally synchronized with USD.
+ This leads to an overhead in reading the data and does not scale well with massive parallelization.
+ This flag allows disabling the synchronization and reading the data directly from the physics buffers.
+
+ It is recommended to set this flag to :obj:`True` when running the simulation with a large number
+ of primitives in the scene.
+
+ Note:
+ When enabled, the GUI will not update the physics parameters in real-time. To enable real-time
+ updates, please set this flag to :obj:`False`.
+ """
+
+ disable_contact_processing:bool=False
+"""Enable/disable contact processing. Default is False.
+
+ By default, the physics engine processes all the contacts in the scene. However, reporting this contact
+ information can be expensive due to its combinatorial complexity. This flag allows disabling the contact
+ processing and querying the contacts manually by the user over a limited set of primitives in the scene.
+
+ .. note::
+
+ It is required to set this flag to :obj:`True` when using the TensorAPIs for contact reporting.
+ """
+
+ physx:PhysxCfg=PhysxCfg()
+"""PhysX solver settings. Default is PhysxCfg()."""
+
+ physics_material:RigidBodyMaterialCfg=RigidBodyMaterialCfg()
+"""Default physics material settings for rigid bodies. Default is RigidBodyMaterialCfg().
+
+ The physics engine defaults to this physics material for all the rigid body prims that do not have any
+ physics material specified on them.
+
+ The material is created at the path: ``{physics_prim_path}/defaultMaterial``.
+ """
[文档]classSimulationContext(_SimulationContext):
+"""A class to control simulation-related events such as physics stepping and rendering.
+
+ The simulation context helps control various simulation aspects. This includes:
+
+ * configure the simulator with different settings such as the physics time-step, the number of physics substeps,
+ and the physics solver parameters (for more information, see :class:`omni.isaac.lab.sim.SimulationCfg`)
+ * playing, pausing, stepping and stopping the simulation
+ * adding and removing callbacks to different simulation events such as physics stepping, rendering, etc.
+
+ This class inherits from the :class:`omni.isaac.core.simulation_context.SimulationContext` class and
+ adds additional functionalities such as setting up the simulation context with a configuration object,
+ exposing other commonly used simulator-related functions, and performing version checks of Isaac Sim
+ to ensure compatibility between releases.
+
+ The simulation context is a singleton object. This means that there can only be one instance
+ of the simulation context at any given time. This is enforced by the parent class. Therefore, it is
+ not possible to create multiple instances of the simulation context. Instead, the simulation context
+ can be accessed using the ``instance()`` method.
+
+ .. attention::
+ Since we only support the `PyTorch <https://pytorch.org/>`_ backend for simulation, the
+ simulation context is configured to use the ``torch`` backend by default. This means that
+ all the data structures used in the simulation are ``torch.Tensor`` objects.
+
+ The simulation context can be used in two different modes of operations:
+
+ 1. **Standalone python script**: In this mode, the user has full control over the simulation and
+ can trigger stepping events synchronously (i.e. as a blocking call). In this case the user
+ has to manually call :meth:`step` step the physics simulation and :meth:`render` to
+ render the scene.
+ 2. **Omniverse extension**: In this mode, the user has limited control over the simulation stepping
+ and all the simulation events are triggered asynchronously (i.e. as a non-blocking call). In this
+ case, the user can only trigger the simulation to start, pause, and stop. The simulation takes
+ care of stepping the physics simulation and rendering the scene.
+
+ Based on above, for most functions in this class there is an equivalent function that is suffixed
+ with ``_async``. The ``_async`` functions are used in the Omniverse extension mode and
+ the non-``_async`` functions are used in the standalone python script mode.
+ """
+
+
[文档]classRenderMode(enum.IntEnum):
+"""Different rendering modes for the simulation.
+
+ Render modes correspond to how the viewport and other UI elements (such as listeners to keyboard or mouse
+ events) are updated. There are three main components that can be updated when the simulation is rendered:
+
+ 1. **UI elements and other extensions**: These are UI elements (such as buttons, sliders, etc.) and other
+ extensions that are running in the background that need to be updated when the simulation is running.
+ 2. **Cameras**: These are typically based on Hydra textures and are used to render the scene from different
+ viewpoints. They can be attached to a viewport or be used independently to render the scene.
+ 3. **Viewports**: These are windows where you can see the rendered scene.
+
+ Updating each of the above components has a different overhead. For example, updating the viewports is
+ computationally expensive compared to updating the UI elements. Therefore, it is useful to be able to
+ control what is updated when the simulation is rendered. This is where the render mode comes in. There are
+ four different render modes:
+
+ * :attr:`NO_GUI_OR_RENDERING`: The simulation is running without a GUI and off-screen rendering flag is disabled,
+ so none of the above are updated.
+ * :attr:`NO_RENDERING`: No rendering, where only 1 is updated at a lower rate.
+ * :attr:`PARTIAL_RENDERING`: Partial rendering, where only 1 and 2 are updated.
+ * :attr:`FULL_RENDERING`: Full rendering, where everything (1, 2, 3) is updated.
+
+ .. _Viewports: https://docs.omniverse.nvidia.com/extensions/latest/ext_viewport.html
+ """
+
+ NO_GUI_OR_RENDERING=-1
+"""The simulation is running without a GUI and off-screen rendering is disabled."""
+ NO_RENDERING=0
+"""No rendering, where only other UI elements are updated at a lower rate."""
+ PARTIAL_RENDERING=1
+"""Partial rendering, where the simulation cameras and UI elements are updated."""
+ FULL_RENDERING=2
+"""Full rendering, where all the simulation viewports, cameras and UI elements are updated."""
+
+
[文档]def__init__(self,cfg:SimulationCfg|None=None):
+"""Creates a simulation context to control the simulator.
+
+ Args:
+ cfg: The configuration of the simulation. Defaults to None,
+ in which case the default configuration is used.
+ """
+ # store input
+ ifcfgisNone:
+ cfg=SimulationCfg()
+ self.cfg=cfg
+ # check that simulation is running
+ ifstage_utils.get_current_stage()isNone:
+ raiseRuntimeError("The stage has not been created. Did you run the simulator?")
+
+ # set flags for simulator
+ # acquire settings interface
+ carb_settings_iface=carb.settings.get_settings()
+ # enable hydra scene-graph instancing
+ # note: this allows rendering of instanceable assets on the GUI
+ carb_settings_iface.set_bool("/persistent/omnihydra/useSceneGraphInstancing",True)
+ # change dispatcher to use the default dispatcher in PhysX SDK instead of carb tasking
+ # note: dispatcher handles how threads are launched for multi-threaded physics
+ carb_settings_iface.set_bool("/physics/physxDispatcher",True)
+ # disable contact processing in omni.physx if requested
+ # note: helpful when creating contact reporting over limited number of objects in the scene
+ ifself.cfg.disable_contact_processing:
+ carb_settings_iface.set_bool("/physics/disableContactProcessing",True)
+ # enable custom geometry for cylinder and cone collision shapes to allow contact reporting for them
+ # reason: cylinders and cones aren't natively supported by PhysX so we need to use custom geometry flags
+ # reference: https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/docs/Geometry.html?highlight=capsule#geometry
+ carb_settings_iface.set_bool("/physics/collisionConeCustomGeometry",False)
+ carb_settings_iface.set_bool("/physics/collisionCylinderCustomGeometry",False)
+ # note: we read this once since it is not expected to change during runtime
+ # read flag for whether a local GUI is enabled
+ self._local_gui=carb_settings_iface.get("/app/window/enabled")
+ # read flag for whether livestreaming GUI is enabled
+ self._livestream_gui=carb_settings_iface.get("/app/livestream/enabled")
+
+ # read flag for whether the Isaac Lab viewport capture pipeline will be used,
+ # casting None to False if the flag doesn't exist
+ # this flag is set from the AppLauncher class
+ self._offscreen_render=bool(carb_settings_iface.get("/isaaclab/render/offscreen"))
+ # read flag for whether the default viewport should be enabled
+ self._render_viewport=bool(carb_settings_iface.get("/isaaclab/render/active_viewport"))
+ # flag for whether any GUI will be rendered (local, livestreamed or viewport)
+ self._has_gui=self._local_guiorself._livestream_gui
+
+ # store the default render mode
+ ifnotself._has_guiandnotself._offscreen_render:
+ # set default render mode
+ # note: this is the terminal state: cannot exit from this render mode
+ self.render_mode=self.RenderMode.NO_GUI_OR_RENDERING
+ # set viewport context to None
+ self._viewport_context=None
+ self._viewport_window=None
+ elifnotself._has_guiandself._offscreen_render:
+ # set default render mode
+ # note: this is the terminal state: cannot exit from this render mode
+ self.render_mode=self.RenderMode.PARTIAL_RENDERING
+ # set viewport context to None
+ self._viewport_context=None
+ self._viewport_window=None
+ else:
+ # note: need to import here in case the UI is not available (ex. headless mode)
+ importomni.uiasui
+ fromomni.kit.viewport.utilityimportget_active_viewport
+
+ # set default render mode
+ # note: this can be changed by calling the `set_render_mode` function
+ self.render_mode=self.RenderMode.FULL_RENDERING
+ # acquire viewport context
+ self._viewport_context=get_active_viewport()
+ self._viewport_context.updates_enabled=True# pyright: ignore [reportOptionalMemberAccess]
+ # acquire viewport window
+ # TODO @mayank: Why not just use get_active_viewport_and_window() directly?
+ self._viewport_window=ui.Workspace.get_window("Viewport")
+ # counter for periodic rendering
+ self._render_throttle_counter=0
+ # rendering frequency in terms of number of render calls
+ self._render_throttle_period=5
+
+ # check the case where we don't need to render the viewport
+ # since render_viewport can only be False in headless mode, we only need to check for offscreen_render
+ ifnotself._render_viewportandself._offscreen_render:
+ # disable the viewport if offscreen_render is enabled
+ fromomni.kit.viewport.utilityimportget_active_viewport
+
+ get_active_viewport().updates_enabled=False
+
+ # override enable scene querying if rendering is enabled
+ # this is needed for some GUI features
+ ifself._has_gui:
+ self.cfg.enable_scene_query_support=True
+ # set up flatcache/fabric interface (default is None)
+ # this is needed to flush the flatcache data into Hydra manually when calling `render()`
+ # ref: https://docs.omniverse.nvidia.com/prod_extensions/prod_extensions/ext_physics.html
+ # note: need to do this here because super().__init__ calls render and this variable is needed
+ self._fabric_iface=None
+ # read isaac sim version (this includes build tag, release tag etc.)
+ # note: we do it once here because it reads the VERSION file from disk and is not expected to change.
+ self._isaacsim_version=get_version()
+
+ # create a tensor for gravity
+ # note: this line is needed to create a "tensor" in the device to avoid issues with torch 2.1 onwards.
+ # the issue is with some heap memory corruption when torch tensor is created inside the asset class.
+ # you can reproduce the issue by commenting out this line and running the test `test_articulation.py`.
+ self._gravity_tensor=torch.tensor(self.cfg.gravity,dtype=torch.float32,device=self.cfg.device)
+
+ # add callback to deal the simulation app when simulation is stopped.
+ # this is needed because physics views go invalid once we stop the simulation
+ ifnotbuiltins.ISAAC_LAUNCHED_FROM_TERMINAL:
+ timeline_event_stream=omni.timeline.get_timeline_interface().get_timeline_event_stream()
+ self._app_control_on_stop_handle=timeline_event_stream.create_subscription_to_pop_by_type(
+ int(omni.timeline.TimelineEventType.STOP),
+ lambda*args,obj=weakref.proxy(self):obj._app_control_on_stop_callback(*args),
+ order=15,
+ )
+ else:
+ self._app_control_on_stop_handle=None
+
+ # flatten out the simulation dictionary
+ sim_params=self.cfg.to_dict()
+ ifsim_paramsisnotNone:
+ if"physx"insim_params:
+ physx_params=sim_params.pop("physx")
+ sim_params.update(physx_params)
+ # create a simulation context to control the simulator
+ super().__init__(
+ stage_units_in_meters=1.0,
+ physics_dt=self.cfg.dt,
+ rendering_dt=self.cfg.dt*self.cfg.render_interval,
+ backend="torch",
+ sim_params=sim_params,
+ physics_prim_path=self.cfg.physics_prim_path,
+ device=self.cfg.device,
+ )
+
+"""
+ Operations - New.
+ """
+
+
[文档]defhas_gui(self)->bool:
+"""Returns whether the simulation has a GUI enabled.
+
+ True if the simulation has a GUI enabled either locally or live-streamed.
+ """
+ returnself._has_gui
+
+
[文档]defhas_rtx_sensors(self)->bool:
+"""Returns whether the simulation has any RTX-rendering related sensors.
+
+ This function returns the value of the simulation parameter ``"/isaaclab/render/rtx_sensors"``.
+ The parameter is set to True when instances of RTX-related sensors (cameras or LiDARs) are
+ created using Isaac Lab's sensor classes.
+
+ True if the simulation has RTX sensors (such as USD Cameras or LiDARs).
+
+ For more information, please check `NVIDIA RTX documentation`_.
+
+ .. _NVIDIA RTX documentation: https://developer.nvidia.com/rendering-technologies
+ """
+ returnself._settings.get_as_bool("/isaaclab/render/rtx_sensors")
+
+
[文档]defis_fabric_enabled(self)->bool:
+"""Returns whether the fabric interface is enabled.
+
+ When fabric interface is enabled, USD read/write operations are disabled. Instead all applications
+ read and write the simulation state directly from the fabric interface. This reduces a lot of overhead
+ that occurs during USD read/write operations.
+
+ For more information, please check `Fabric documentation`_.
+
+ .. _Fabric documentation: https://docs.omniverse.nvidia.com/kit/docs/usdrt/latest/docs/usd_fabric_usdrt.html
+ """
+ returnself._fabric_ifaceisnotNone
+
+
[文档]defget_version(self)->tuple[int,int,int]:
+"""Returns the version of the simulator.
+
+ This is a wrapper around the ``omni.isaac.version.get_version()`` function.
+
+ The returned tuple contains the following information:
+
+ * Major version (int): This is the year of the release (e.g. 2022).
+ * Minor version (int): This is the half-year of the release (e.g. 1 or 2).
+ * Patch version (int): This is the patch number of the release (e.g. 0).
+ """
+ returnint(self._isaacsim_version[2]),int(self._isaacsim_version[3]),int(self._isaacsim_version[4])
+
+"""
+ Operations - New utilities.
+ """
+
+
[文档]@staticmethod
+ defset_camera_view(
+ eye:tuple[float,float,float],
+ target:tuple[float,float,float],
+ camera_prim_path:str="/OmniverseKit_Persp",
+ ):
+"""Set the location and target of the viewport camera in the stage.
+
+ Note:
+ This is a wrapper around the :math:`omni.isaac.core.utils.viewports.set_camera_view` function.
+ It is provided here for convenience to reduce the amount of imports needed.
+
+ Args:
+ eye: The location of the camera eye.
+ target: The location of the camera target.
+ camera_prim_path: The path to the camera primitive in the stage. Defaults to
+ "/OmniverseKit_Persp".
+ """
+ set_camera_view(eye,target,camera_prim_path)
+
+
[文档]defset_render_mode(self,mode:RenderMode):
+"""Change the current render mode of the simulation.
+
+ Please see :class:`RenderMode` for more information on the different render modes.
+
+ .. note::
+ When no GUI is available (locally or livestreamed), we do not need to choose whether the viewport
+ needs to render or not (since there is no GUI). Thus, in this case, calling the function will not
+ change the render mode.
+
+ Args:
+ mode (RenderMode): The rendering mode. If different than SimulationContext's rendering mode,
+ SimulationContext's mode is changed to the new mode.
+
+ Raises:
+ ValueError: If the input mode is not supported.
+ """
+ # check if mode change is possible -- not possible when no GUI is available
+ ifnotself._has_gui:
+ carb.log_warn(
+ f"Cannot change render mode when GUI is disabled. Using the default render mode: {self.render_mode}."
+ )
+ return
+ # check if there is a mode change
+ # note: this is mostly needed for GUI when we want to switch between full rendering and no rendering.
+ ifmode!=self.render_mode:
+ ifmode==self.RenderMode.FULL_RENDERING:
+ # display the viewport and enable updates
+ self._viewport_context.updates_enabled=True# pyright: ignore [reportOptionalMemberAccess]
+ self._viewport_window.visible=True# pyright: ignore [reportOptionalMemberAccess]
+ elifmode==self.RenderMode.PARTIAL_RENDERING:
+ # hide the viewport and disable updates
+ self._viewport_context.updates_enabled=False# pyright: ignore [reportOptionalMemberAccess]
+ self._viewport_window.visible=False# pyright: ignore [reportOptionalMemberAccess]
+ elifmode==self.RenderMode.NO_RENDERING:
+ # hide the viewport and disable updates
+ ifself._viewport_contextisnotNone:
+ self._viewport_context.updates_enabled=False# pyright: ignore [reportOptionalMemberAccess]
+ self._viewport_window.visible=False# pyright: ignore [reportOptionalMemberAccess]
+ # reset the throttle counter
+ self._render_throttle_counter=0
+ else:
+ raiseValueError(f"Unsupported render mode: {mode}! Please check `RenderMode` for details.")
+ # update render mode
+ self.render_mode=mode
+
+
[文档]defset_setting(self,name:str,value:Any):
+"""Set simulation settings using the Carbonite SDK.
+
+ .. note::
+ If the input setting name does not exist, it will be created. If it does exist, the value will be
+ overwritten. Please make sure to use the correct setting name.
+
+ To understand the settings interface, please refer to the
+ `Carbonite SDK <https://docs.omniverse.nvidia.com/dev-guide/latest/programmer_ref/settings.html>`_
+ documentation.
+
+ Args:
+ name: The name of the setting.
+ value: The value of the setting.
+ """
+ self._settings.set(name,value)
+
+
[文档]defget_setting(self,name:str)->Any:
+"""Read the simulation setting using the Carbonite SDK.
+
+ Args:
+ name: The name of the setting.
+
+ Returns:
+ The value of the setting.
+ """
+ returnself._settings.get(name)
+
+"""
+ Operations - Override (standalone)
+ """
+
+ defreset(self,soft:bool=False):
+ super().reset(soft=soft)
+ # perform additional rendering steps to warm up replicator buffers
+ # this is only needed for the first time we set the simulation
+ ifnotsoft:
+ for_inrange(2):
+ self.render()
+
+
[文档]defstep(self,render:bool=True):
+"""Steps the simulation.
+
+ .. note::
+ This function blocks if the timeline is paused. It only returns when the timeline is playing.
+
+ Args:
+ render: Whether to render the scene after stepping the physics simulation.
+ If set to False, the scene is not rendered and only the physics simulation is stepped.
+ """
+ # check if the simulation timeline is paused. in that case keep stepping until it is playing
+ ifnotself.is_playing():
+ # step the simulator (but not the physics) to have UI still active
+ whilenotself.is_playing():
+ self.render()
+ # meantime if someone stops, break out of the loop
+ ifself.is_stopped():
+ break
+ # need to do one step to refresh the app
+ # reason: physics has to parse the scene again and inform other extensions like hydra-delegate.
+ # without this the app becomes unresponsive.
+ # FIXME: This steps physics as well, which we is not good in general.
+ self.app.update()
+
+ # step the simulation
+ super().step(render=render)
+
+
[文档]defrender(self,mode:RenderMode|None=None):
+"""Refreshes the rendering components including UI elements and view-ports depending on the render mode.
+
+ This function is used to refresh the rendering components of the simulation. This includes updating the
+ view-ports, UI elements, and other extensions (besides physics simulation) that are running in the
+ background. The rendering components are refreshed based on the render mode.
+
+ Please see :class:`RenderMode` for more information on the different render modes.
+
+ Args:
+ mode: The rendering mode. Defaults to None, in which case the current rendering mode is used.
+ """
+ # check if we need to change the render mode
+ ifmodeisnotNone:
+ self.set_render_mode(mode)
+ # render based on the render mode
+ ifself.render_mode==self.RenderMode.NO_GUI_OR_RENDERING:
+ # we never want to render anything here (this is for complete headless mode)
+ pass
+ elifself.render_mode==self.RenderMode.NO_RENDERING:
+ # throttle the rendering frequency to keep the UI responsive
+ self._render_throttle_counter+=1
+ ifself._render_throttle_counter%self._render_throttle_period==0:
+ self._render_throttle_counter=0
+ # here we don't render viewport so don't need to flush fabric data
+ # note: we don't call super().render() anymore because they do flush the fabric data
+ self.set_setting("/app/player/playSimulations",False)
+ self._app.update()
+ self.set_setting("/app/player/playSimulations",True)
+ else:
+ # manually flush the fabric data to update Hydra textures
+ ifself._fabric_ifaceisnotNone:
+ ifself.physics_sim_viewisnotNoneandself.is_playing():
+ # Update the articulations' link's poses before rendering
+ self.physics_sim_view.update_articulations_kinematic()
+ self._update_fabric(0.0,0.0)
+ # render the simulation
+ # note: we don't call super().render() anymore because they do above operation inside
+ # and we don't want to do it twice. We may remove it once we drop support for Isaac Sim 2022.2.
+ self.set_setting("/app/player/playSimulations",False)
+ self._app.update()
+ self.set_setting("/app/player/playSimulations",True)
+
+"""
+ Operations - Override (extension)
+ """
+
+ asyncdefreset_async(self,soft:bool=False):
+ # need to load all "physics" information from the USD file
+ ifnotsoft:
+ omni.physx.acquire_physx_interface().force_load_physics_from_usd()
+ # play the simulation
+ awaitsuper().reset_async(soft=soft)
+
+"""
+ Initialization/Destruction - Override.
+ """
+
+ def_init_stage(self,*args,**kwargs)->Usd.Stage:
+ _=super()._init_stage(*args,**kwargs)
+ # a stage update here is needed for the case when physics_dt != rendering_dt, otherwise the app crashes
+ # when in headless mode
+ self.set_setting("/app/player/playSimulations",False)
+ self._app.update()
+ self.set_setting("/app/player/playSimulations",True)
+ # set additional physx parameters and bind material
+ self._set_additional_physx_params()
+ # load flatcache/fabric interface
+ self._load_fabric_interface()
+ # return the stage
+ returnself.stage
+
+ asyncdef_initialize_stage_async(self,*args,**kwargs)->Usd.Stage:
+ awaitsuper()._initialize_stage_async(*args,**kwargs)
+ # set additional physx parameters and bind material
+ self._set_additional_physx_params()
+ # load flatcache/fabric interface
+ self._load_fabric_interface()
+ # return the stage
+ returnself.stage
+
+ @classmethod
+ defclear_instance(cls):
+ # clear the callback
+ ifcls._instanceisnotNone:
+ ifcls._instance._app_control_on_stop_handleisnotNone:
+ cls._instance._app_control_on_stop_handle.unsubscribe()
+ cls._instance._app_control_on_stop_handle=None
+ # call parent to clear the instance
+ super().clear_instance()
+
+"""
+ Helper Functions
+ """
+
+ def_set_additional_physx_params(self):
+"""Sets additional PhysX parameters that are not directly supported by the parent class."""
+ # obtain the physics scene api
+ physics_scene:UsdPhysics.Scene=self._physics_context._physics_scene
+ physx_scene_api:PhysxSchema.PhysxSceneAPI=self._physics_context._physx_scene_api
+ # assert that scene api is not None
+ ifphysx_scene_apiisNone:
+ raiseRuntimeError("Physics scene API is None! Please create the scene first.")
+ # set parameters not directly supported by the constructor
+ # -- Continuous Collision Detection (CCD)
+ # ref: https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/docs/AdvancedCollisionDetection.html?highlight=ccd#continuous-collision-detection
+ self._physics_context.enable_ccd(self.cfg.physx.enable_ccd)
+ # -- GPU collision stack size
+ physx_scene_api.CreateGpuCollisionStackSizeAttr(self.cfg.physx.gpu_collision_stack_size)
+ # -- Improved determinism by PhysX
+ physx_scene_api.CreateEnableEnhancedDeterminismAttr(self.cfg.physx.enable_enhanced_determinism)
+
+ # -- Gravity
+ # note: Isaac sim only takes the "up-axis" as the gravity direction. But physics allows any direction so we
+ # need to convert the gravity vector to a direction and magnitude pair explicitly.
+ gravity=np.asarray(self.cfg.gravity)
+ gravity_magnitude=np.linalg.norm(gravity)
+
+ # Avoid division by zero
+ ifgravity_magnitude!=0.0:
+ gravity_direction=gravity/gravity_magnitude
+ else:
+ gravity_direction=gravity
+
+ physics_scene.CreateGravityDirectionAttr(Gf.Vec3f(*gravity_direction))
+ physics_scene.CreateGravityMagnitudeAttr(gravity_magnitude)
+
+ # position iteration count
+ physx_scene_api.CreateMinPositionIterationCountAttr(self.cfg.physx.min_position_iteration_count)
+ physx_scene_api.CreateMaxPositionIterationCountAttr(self.cfg.physx.max_position_iteration_count)
+ # velocity iteration count
+ physx_scene_api.CreateMinVelocityIterationCountAttr(self.cfg.physx.min_velocity_iteration_count)
+ physx_scene_api.CreateMaxVelocityIterationCountAttr(self.cfg.physx.max_velocity_iteration_count)
+
+ # create the default physics material
+ # this material is used when no material is specified for a primitive
+ # check: https://docs.omniverse.nvidia.com/extensions/latest/ext_physics/simulation-control/physics-settings.html#physics-materials
+ material_path=f"{self.cfg.physics_prim_path}/defaultMaterial"
+ self.cfg.physics_material.func(material_path,self.cfg.physics_material)
+ # bind the physics material to the scene
+ bind_physics_material(self.cfg.physics_prim_path,material_path)
+
+ def_load_fabric_interface(self):
+"""Loads the fabric interface if enabled."""
+ ifself.cfg.use_fabric:
+ fromomni.physxfabricimportget_physx_fabric_interface
+
+ # acquire fabric interface
+ self._fabric_iface=get_physx_fabric_interface()
+ ifhasattr(self._fabric_iface,"force_update"):
+ # The update method in the fabric interface only performs an update if a physics step has occurred.
+ # However, for rendering, we need to force an update since any element of the scene might have been
+ # modified in a reset (which occurs after the physics step) and we want the renderer to be aware of
+ # these changes.
+ self._update_fabric=self._fabric_iface.force_update
+ else:
+ # Needed for backward compatibility with older Isaac Sim versions
+ self._update_fabric=self._fabric_iface.update
+
+"""
+ Callbacks.
+ """
+
+ def_app_control_on_stop_callback(self,event:carb.events.IEvent):
+"""Callback to deal with the app when the simulation is stopped.
+
+ Once the simulation is stopped, the physics handles go invalid. After that, it is not possible to
+ resume the simulation from the last state. This leaves the app in an inconsistent state, where
+ two possible actions can be taken:
+
+ 1. **Keep the app rendering**: In this case, the simulation is kept running and the app is not shutdown.
+ However, the physics is not updated and the script cannot be resumed from the last state. The
+ user has to manually close the app to stop the simulation.
+ 2. **Shutdown the app**: This is the default behavior. In this case, the app is shutdown and
+ the simulation is stopped.
+
+ Note:
+ This callback is used only when running the simulation in a standalone python script. In an extension,
+ it is expected that the user handles the extension shutdown.
+ """
+ # check if the simulation is stopped
+ ifevent.type==int(omni.timeline.TimelineEventType.STOP):
+ # keep running the simulator when configured to not shutdown the app
+ ifself._has_guiandsys.exc_info()[0]isNone:
+ self.app.print_and_log(
+ "Simulation is stopped. The app will keep running with physics disabled."
+ " Press Ctrl+C or close the window to exit the app."
+ )
+ whileself.app.is_running():
+ self.render()
+
+ # Note: For the following code:
+ # The method is an exact copy of the implementation in the `omni.isaac.kit.SimulationApp` class.
+ # We need to remove this method once the SimulationApp class becomes a singleton.
+
+ # make sure that any replicator workflows finish rendering/writing
+ try:
+ importomni.replicator.coreasrep
+
+ rep_status=rep.orchestrator.get_status()
+ ifrep_statusnotin[rep.orchestrator.Status.STOPPED,rep.orchestrator.Status.STOPPING]:
+ rep.orchestrator.stop()
+ ifrep_status!=rep.orchestrator.Status.STOPPED:
+ rep.orchestrator.wait_until_complete()
+
+ # Disable capture on play to avoid replicator engaging on any new timeline events
+ rep.orchestrator.set_capture_on_play(False)
+ exceptException:
+ pass
+
+ # clear the instance and all callbacks
+ # note: clearing callbacks is important to prevent memory leaks
+ self.clear_all_callbacks()
+
+ # workaround for exit issues, clean the stage first:
+ ifomni.usd.get_context().can_close_stage():
+ omni.usd.get_context().close_stage()
+
+ # print logging information
+ self.app.print_and_log("Simulation is stopped. Shutting down the app...")
+
+ # Cleanup any running tracy instances so data is not lost
+ try:
+ profiler_tracy=carb.profiler.acquire_profiler_interface(plugin_name="carb.profiler-tracy.plugin")
+ ifprofiler_tracy:
+ profiler_tracy.set_capture_mask(0)
+ profiler_tracy.end(0)
+ profiler_tracy.shutdown()
+ exceptRuntimeError:
+ # Tracy plugin was not loaded, so profiler never started - skip checks.
+ pass
+
+ # Disable logging before shutdown to keep the log clean
+ # Warnings at this point don't matter as the python process is about to be terminated
+ logging=carb.logging.acquire_logging()
+ logging.set_level_threshold(carb.logging.LEVEL_ERROR)
+
+ # App shutdown is disabled to prevent crashes on shutdown. Terminating carb is faster
+ # self._app.shutdown()
+ self._framework.unload_all_plugins()
+
+
+@contextmanager
+defbuild_simulation_context(
+ create_new_stage:bool=True,
+ gravity_enabled:bool=True,
+ device:str="cuda:0",
+ dt:float=0.01,
+ sim_cfg:SimulationCfg|None=None,
+ add_ground_plane:bool=False,
+ add_lighting:bool=False,
+ auto_add_lighting:bool=False,
+)->Iterator[SimulationContext]:
+"""Context manager to build a simulation context with the provided settings.
+
+ This function facilitates the creation of a simulation context and provides flexibility in configuring various
+ aspects of the simulation, such as time step, gravity, device, and scene elements like ground plane and
+ lighting.
+
+ If :attr:`sim_cfg` is None, then an instance of :class:`SimulationCfg` is created with default settings, with parameters
+ overwritten based on arguments to the function.
+
+ An example usage of the context manager function:
+
+ .. code-block:: python
+
+ with build_simulation_context() as sim:
+ # Design the scene
+
+ # Play the simulation
+ sim.reset()
+ while sim.is_playing():
+ sim.step()
+
+ Args:
+ create_new_stage: Whether to create a new stage. Defaults to True.
+ gravity_enabled: Whether to enable gravity in the simulation. Defaults to True.
+ device: Device to run the simulation on. Defaults to "cuda:0".
+ dt: Time step for the simulation: Defaults to 0.01.
+ sim_cfg: :class:`omni.isaac.lab.sim.SimulationCfg` to use for the simulation. Defaults to None.
+ add_ground_plane: Whether to add a ground plane to the simulation. Defaults to False.
+ add_lighting: Whether to add a dome light to the simulation. Defaults to False.
+ auto_add_lighting: Whether to automatically add a dome light to the simulation if the simulation has a GUI.
+ Defaults to False. This is useful for debugging tests in the GUI.
+
+ Yields:
+ The simulation context to use for the simulation.
+
+ """
+ try:
+ ifcreate_new_stage:
+ stage_utils.create_new_stage()
+
+ ifsim_cfgisNone:
+ # Construct one and overwrite the dt, gravity, and device
+ sim_cfg=SimulationCfg(dt=dt)
+
+ # Set up gravity
+ ifgravity_enabled:
+ sim_cfg.gravity=(0.0,0.0,-9.81)
+ else:
+ sim_cfg.gravity=(0.0,0.0,0.0)
+
+ # Set device
+ sim_cfg.device=device
+
+ # Construct simulation context
+ sim=SimulationContext(sim_cfg)
+
+ ifadd_ground_plane:
+ # Ground-plane
+ cfg=GroundPlaneCfg()
+ cfg.func("/World/defaultGroundPlane",cfg)
+
+ ifadd_lightingor(auto_add_lightingandsim.has_gui()):
+ # Lighting
+ cfg=DomeLightCfg(
+ color=(0.1,0.1,0.1),
+ enable_color_temperature=True,
+ color_temperature=5500,
+ intensity=10000,
+ )
+ # Dome light named specifically to avoid conflicts
+ cfg.func(prim_path="/World/defaultDomeLight",cfg=cfg,translation=(0.0,0.0,10.0))
+
+ yieldsim
+
+ exceptException:
+ carb.log_error(traceback.format_exc())
+ raise
+ finally:
+ ifnotsim.has_gui():
+ # Stop simulation only if we aren't rendering otherwise the app will hang indefinitely
+ sim.stop()
+
+ # Clear the stage
+ sim.clear_all_callbacks()
+ sim.clear_instance()
+
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromtypingimportTYPE_CHECKING
+
+importcarb
+importomni.isaac.core.utils.primsasprim_utils
+importomni.isaac.core.utils.stageasstage_utils
+importomni.kit.commands
+frompxrimportGf,Sdf,Usd
+
+fromomni.isaac.lab.simimportconverters,schemas
+fromomni.isaac.lab.sim.utilsimportbind_physics_material,bind_visual_material,clone,select_usd_variants
+
+ifTYPE_CHECKING:
+ from.importfrom_files_cfg
+
+
+
[文档]@clone
+defspawn_from_usd(
+ prim_path:str,
+ cfg:from_files_cfg.UsdFileCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Spawn an asset from a USD file and override the settings with the given config.
+
+ In the case of a USD file, the asset is spawned at the default prim specified in the USD file.
+ If a default prim is not specified, then the asset is spawned at the root prim.
+
+ In case a prim already exists at the given prim path, then the function does not create a new prim
+ or throw an error that the prim already exists. Instead, it just takes the existing prim and overrides
+ the settings with the given config.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which
+ case the translation specified in the USD file is used.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case the orientation specified in the USD file is used.
+
+ Returns:
+ The prim of the spawned asset.
+
+ Raises:
+ FileNotFoundError: If the USD file does not exist at the given path.
+ """
+ # spawn asset from the given usd file
+ return_spawn_from_usd_file(prim_path,cfg.usd_path,cfg,translation,orientation)
+
+
+
[文档]@clone
+defspawn_from_urdf(
+ prim_path:str,
+ cfg:from_files_cfg.UrdfFileCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Spawn an asset from a URDF file and override the settings with the given config.
+
+ It uses the :class:`UrdfConverter` class to create a USD file from URDF. This file is then imported
+ at the specified prim path.
+
+ In case a prim already exists at the given prim path, then the function does not create a new prim
+ or throw an error that the prim already exists. Instead, it just takes the existing prim and overrides
+ the settings with the given config.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which
+ case the translation specified in the generated USD file is used.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case the orientation specified in the generated USD file is used.
+
+ Returns:
+ The prim of the spawned asset.
+
+ Raises:
+ FileNotFoundError: If the URDF file does not exist at the given path.
+ """
+ # urdf loader to convert urdf to usd
+ urdf_loader=converters.UrdfConverter(cfg)
+ # spawn asset from the generated usd file
+ return_spawn_from_usd_file(prim_path,urdf_loader.usd_path,cfg,translation,orientation)
+
+
+
[文档]defspawn_ground_plane(
+ prim_path:str,
+ cfg:from_files_cfg.GroundPlaneCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Spawns a ground plane into the scene.
+
+ This function loads the USD file containing the grid plane asset from Isaac Sim. It may
+ not work with other assets for ground planes. In those cases, please use the `spawn_from_usd`
+ function.
+
+ Note:
+ This function takes keyword arguments to be compatible with other spawners. However, it does not
+ use any of the kwargs.
+
+ Args:
+ prim_path: The path to spawn the asset at.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which
+ case the translation specified in the USD file is used.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case the orientation specified in the USD file is used.
+
+ Returns:
+ The prim of the spawned asset.
+
+ Raises:
+ ValueError: If the prim path already exists.
+ """
+ # Spawn Ground-plane
+ ifnotprim_utils.is_prim_path_valid(prim_path):
+ prim_utils.create_prim(prim_path,usd_path=cfg.usd_path,translation=translation,orientation=orientation)
+ else:
+ raiseValueError(f"A prim already exists at path: '{prim_path}'.")
+
+ # Create physics material
+ ifcfg.physics_materialisnotNone:
+ cfg.physics_material.func(f"{prim_path}/physicsMaterial",cfg.physics_material)
+ # Apply physics material to ground plane
+ collision_prim_path=prim_utils.get_prim_path(
+ prim_utils.get_first_matching_child_prim(
+ prim_path,predicate=lambdax:prim_utils.get_prim_type_name(x)=="Plane"
+ )
+ )
+ bind_physics_material(collision_prim_path,f"{prim_path}/physicsMaterial")
+
+ # Scale only the mesh
+ # Warning: This is specific to the default grid plane asset.
+ ifprim_utils.is_prim_path_valid(f"{prim_path}/Enviroment"):
+ # compute scale from size
+ scale=(cfg.size[0]/100.0,cfg.size[1]/100.0,1.0)
+ # apply scale to the mesh
+ omni.kit.commands.execute(
+ "ChangeProperty",
+ prop_path=Sdf.Path(f"{prim_path}/Enviroment.xformOp:scale"),
+ value=scale,
+ prev=None,
+ )
+
+ # Change the color of the plane
+ # Warning: This is specific to the default grid plane asset.
+ ifcfg.colorisnotNone:
+ prop_path=f"{prim_path}/Looks/theGrid/Shader.inputs:diffuse_tint"
+ # change the color
+ omni.kit.commands.execute(
+ "ChangePropertyCommand",
+ prop_path=Sdf.Path(prop_path),
+ value=Gf.Vec3f(*cfg.color),
+ prev=None,
+ type_to_create_if_not_exist=Sdf.ValueTypeNames.Color3f,
+ )
+ # Remove the light from the ground plane
+ # It isn't bright enough and messes up with the user's lighting settings
+ omni.kit.commands.execute("ToggleVisibilitySelectedPrims",selected_paths=[f"{prim_path}/SphereLight"])
+
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+"""
+Helper functions.
+"""
+
+
+def_spawn_from_usd_file(
+ prim_path:str,
+ usd_path:str,
+ cfg:from_files_cfg.FileCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Spawn an asset from a USD file and override the settings with the given config.
+
+ In case a prim already exists at the given prim path, then the function does not create a new prim
+ or throw an error that the prim already exists. Instead, it just takes the existing prim and overrides
+ the settings with the given config.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ usd_path: The path to the USD file to spawn the asset from.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which
+ case the translation specified in the generated USD file is used.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case the orientation specified in the generated USD file is used.
+
+ Returns:
+ The prim of the spawned asset.
+
+ Raises:
+ FileNotFoundError: If the USD file does not exist at the given path.
+ """
+ # check file path exists
+ stage:Usd.Stage=stage_utils.get_current_stage()
+ ifnotstage.ResolveIdentifierToEditTarget(usd_path):
+ raiseFileNotFoundError(f"USD file not found at path: '{usd_path}'.")
+ # spawn asset if it doesn't exist.
+ ifnotprim_utils.is_prim_path_valid(prim_path):
+ # add prim as reference to stage
+ prim_utils.create_prim(
+ prim_path,
+ usd_path=usd_path,
+ translation=translation,
+ orientation=orientation,
+ scale=cfg.scale,
+ )
+ else:
+ carb.log_warn(f"A prim already exists at prim path: '{prim_path}'.")
+
+ # modify variants
+ ifhasattr(cfg,"variants")andcfg.variantsisnotNone:
+ select_usd_variants(prim_path,cfg.variants)
+
+ # modify rigid body properties
+ ifcfg.rigid_propsisnotNone:
+ schemas.modify_rigid_body_properties(prim_path,cfg.rigid_props)
+ # modify collision properties
+ ifcfg.collision_propsisnotNone:
+ schemas.modify_collision_properties(prim_path,cfg.collision_props)
+ # modify mass properties
+ ifcfg.mass_propsisnotNone:
+ schemas.modify_mass_properties(prim_path,cfg.mass_props)
+
+ # modify articulation root properties
+ ifcfg.articulation_propsisnotNone:
+ schemas.modify_articulation_root_properties(prim_path,cfg.articulation_props)
+ # modify tendon properties
+ ifcfg.fixed_tendons_propsisnotNone:
+ schemas.modify_fixed_tendon_properties(prim_path,cfg.fixed_tendons_props)
+ # define drive API on the joints
+ # note: these are only for setting low-level simulation properties. all others should be set or are
+ # and overridden by the articulation/actuator properties.
+ ifcfg.joint_drive_propsisnotNone:
+ schemas.modify_joint_drive_properties(prim_path,cfg.joint_drive_props)
+
+ # modify deformable body properties
+ ifcfg.deformable_propsisnotNone:
+ schemas.modify_deformable_body_properties(prim_path,cfg.deformable_props)
+
+ # apply visual material
+ ifcfg.visual_materialisnotNone:
+ ifnotcfg.visual_material_path.startswith("/"):
+ material_path=f"{prim_path}/{cfg.visual_material_path}"
+ else:
+ material_path=cfg.visual_material_path
+ # create material
+ cfg.visual_material.func(material_path,cfg.visual_material)
+ # apply material
+ bind_visual_material(prim_path,material_path)
+
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.simimportconverters,schemas
+fromomni.isaac.lab.sim.spawnersimportmaterials
+fromomni.isaac.lab.sim.spawners.spawner_cfgimportDeformableObjectSpawnerCfg,RigidObjectSpawnerCfg,SpawnerCfg
+fromomni.isaac.lab.utilsimportconfigclass
+fromomni.isaac.lab.utils.assetsimportISAAC_NUCLEUS_DIR
+
+from.importfrom_files
+
+
+@configclass
+classFileCfg(RigidObjectSpawnerCfg,DeformableObjectSpawnerCfg):
+"""Configuration parameters for spawning an asset from a file.
+
+ This class is a base class for spawning assets from files. It includes the common parameters
+ for spawning assets from files, such as the path to the file and the function to use for spawning
+ the asset.
+
+ Note:
+ By default, all properties are set to None. This means that no properties will be added or modified
+ to the prim outside of the properties available by default when spawning the prim.
+
+ If they are set to a value, then the properties are modified on the spawned prim in a nested manner.
+ This is done by calling the respective function with the specified properties.
+ """
+
+ scale:tuple[float,float,float]|None=None
+"""Scale of the asset. Defaults to None, in which case the scale is not modified."""
+
+ articulation_props:schemas.ArticulationRootPropertiesCfg|None=None
+"""Properties to apply to the articulation root."""
+
+ fixed_tendons_props:schemas.FixedTendonsPropertiesCfg|None=None
+"""Properties to apply to the fixed tendons (if any)."""
+
+ joint_drive_props:schemas.JointDrivePropertiesCfg|None=None
+"""Properties to apply to a joint."""
+
+ visual_material_path:str="material"
+"""Path to the visual material to use for the prim. Defaults to "material".
+
+ If the path is relative, then it will be relative to the prim's path.
+ This parameter is ignored if `visual_material` is not None.
+ """
+
+ visual_material:materials.VisualMaterialCfg|None=None
+"""Visual material properties to override the visual material properties in the URDF file.
+
+ Note:
+ If None, then no visual material will be added.
+ """
+
+
+
[文档]@configclass
+classUsdFileCfg(FileCfg):
+"""USD file to spawn asset from.
+
+ USD files are imported directly into the scene. However, given their complexity, there are various different
+ operations that can be performed on them. For example, selecting variants, applying materials, or modifying
+ existing properties.
+
+ To prevent the explosion of configuration parameters, the available operations are limited to the most common
+ ones. These include:
+
+ - **Selecting variants**: This is done by specifying the :attr:`variants` parameter.
+ - **Creating and applying materials**: This is done by specifying the :attr:`visual_material` and
+ :attr:`physics_material` parameters.
+ - **Modifying existing properties**: This is done by specifying the respective properties in the configuration
+ class. For instance, to modify the scale of the imported prim, set the :attr:`scale` parameter.
+
+ See :meth:`spawn_from_usd` for more information.
+
+ .. note::
+ The configuration parameters include various properties. If not `None`, these properties
+ are modified on the spawned prim in a nested manner.
+
+ If they are set to a value, then the properties are modified on the spawned prim in a nested manner.
+ This is done by calling the respective function with the specified properties.
+ """
+
+ func:Callable=from_files.spawn_from_usd
+
+ usd_path:str=MISSING
+"""Path to the USD file to spawn asset from."""
+
+ variants:object|dict[str,str]|None=None
+"""Variants to select from in the input USD file. Defaults to None, in which case no variants are applied.
+
+ This can either be a configclass object, in which case each attribute is used as a variant set name and its specified value,
+ or a dictionary mapping between the two. Please check the :meth:`~omni.isaac.lab.sim.utils.select_usd_variants` function
+ for more information.
+ """
+
+
+
[文档]@configclass
+classUrdfFileCfg(FileCfg,converters.UrdfConverterCfg):
+"""URDF file to spawn asset from.
+
+ It uses the :class:`UrdfConverter` class to create a USD file from URDF and spawns the imported
+ USD file. Similar to the :class:`UsdFileCfg`, the generated USD file can be modified by specifying
+ the respective properties in the configuration class.
+
+ See :meth:`spawn_from_urdf` for more information.
+
+ .. note::
+ The configuration parameters include various properties. If not `None`, these properties
+ are modified on the spawned prim in a nested manner.
+
+ If they are set to a value, then the properties are modified on the spawned prim in a nested manner.
+ This is done by calling the respective function with the specified properties.
+
+ """
+
+ func:Callable=from_files.spawn_from_urdf
+
+
+"""
+Spawning ground plane.
+"""
+
+
+
[文档]@configclass
+classGroundPlaneCfg(SpawnerCfg):
+"""Create a ground plane prim.
+
+ This uses the USD for the standard grid-world ground plane from Isaac Sim by default.
+ """
+
+ func:Callable=from_files.spawn_ground_plane
+
+ usd_path:str=f"{ISAAC_NUCLEUS_DIR}/Environments/Grid/default_environment.usd"
+"""Path to the USD file to spawn asset from. Defaults to the grid-world ground plane."""
+
+ color:tuple[float,float,float]|None=(0.0,0.0,0.0)
+"""The color of the ground plane. Defaults to (0.0, 0.0, 0.0).
+
+ If None, then the color remains unchanged.
+ """
+
+ size:tuple[float,float]=(100.0,100.0)
+"""The size of the ground plane. Defaults to 100 m x 100 m."""
+
+ physics_material:materials.RigidBodyMaterialCfg=materials.RigidBodyMaterialCfg()
+"""Physics material properties. Defaults to the default rigid body material."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromtypingimportTYPE_CHECKING
+
+importomni.isaac.core.utils.primsasprim_utils
+frompxrimportUsd,UsdLux
+
+fromomni.isaac.lab.sim.utilsimportclone,safe_set_attribute_on_usd_prim
+
+ifTYPE_CHECKING:
+ from.importlights_cfg
+
+
+
[文档]@clone
+defspawn_light(
+ prim_path:str,
+ cfg:lights_cfg.LightCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a light prim at the specified prim path with the specified configuration.
+
+ The created prim is based on the `USD.LuxLight <https://openusd.org/dev/api/class_usd_lux_light_a_p_i.html>`_ API.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration for the light source.
+ translation: The translation of the prim. Defaults to None, in which case this is set to the origin.
+ orientation: The orientation of the prim as (w, x, y, z). Defaults to None, in which case this
+ is set to identity.
+
+ Raises:
+ ValueError: When a prim already exists at the specified prim path.
+ """
+ # check if prim already exists
+ ifprim_utils.is_prim_path_valid(prim_path):
+ raiseValueError(f"A prim already exists at path: '{prim_path}'.")
+ # create the prim
+ prim=prim_utils.create_prim(prim_path,prim_type=cfg.prim_type,translation=translation,orientation=orientation)
+
+ # convert to dict
+ cfg=cfg.to_dict()
+ # delete spawner func specific parameters
+ delcfg["prim_type"]
+ # delete custom attributes in the config that are not USD parameters
+ non_usd_cfg_param_names=["func","copy_from_source","visible","semantic_tags"]
+ forparam_nameinnon_usd_cfg_param_names:
+ delcfg[param_name]
+ # set into USD API
+ forattr_name,valueincfg.items():
+ # special operation for texture properties
+ # note: this is only used for dome light
+ if"texture"inattr_name:
+ light_prim=UsdLux.DomeLight(prim)
+ ifattr_name=="texture_file":
+ light_prim.CreateTextureFileAttr(value)
+ elifattr_name=="texture_format":
+ light_prim.CreateTextureFormatAttr(value)
+ else:
+ raiseValueError(f"Unsupported texture attribute: '{attr_name}'.")
+ else:
+ ifattr_name=="visible_in_primary_ray":
+ prim_prop_name=attr_name
+ else:
+ prim_prop_name=f"inputs:{attr_name}"
+ # set the attribute
+ safe_set_attribute_on_usd_prim(prim,prim_prop_name,value,camel_case=True)
+ # return the prim
+ returnprim
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.sim.spawners.spawner_cfgimportSpawnerCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.importlights
+
+
+
[文档]@configclass
+classLightCfg(SpawnerCfg):
+"""Configuration parameters for creating a light in the scene.
+
+ Please refer to the documentation on `USD LuxLight <https://openusd.org/dev/api/class_usd_lux_light_a_p_i.html>`_
+ for more information.
+
+ .. note::
+ The default values for the attributes are those specified in the their official documentation.
+ """
+
+ func:Callable=lights.spawn_light
+
+ prim_type:str=MISSING
+"""The prim type name for the light prim."""
+
+ color:tuple[float,float,float]=(1.0,1.0,1.0)
+"""The color of emitted light, in energy-linear terms. Defaults to white."""
+
+ enable_color_temperature:bool=False
+"""Enables color temperature. Defaults to false."""
+
+ color_temperature:float=6500.0
+"""Color temperature (in Kelvin) representing the white point. The valid range is [1000, 10000]. Defaults to 6500K.
+
+ The `color temperature <https://en.wikipedia.org/wiki/Color_temperature>`_ corresponds to the warmth
+ or coolness of light. Warmer light has a lower color temperature, while cooler light has a higher
+ color temperature.
+
+ Note:
+ It only takes effect when :attr:`enable_color_temperature` is true.
+ """
+
+ normalize:bool=False
+"""Normalizes power by the surface area of the light. Defaults to false.
+
+ This makes it easier to independently adjust the power and shape of the light, by causing the power
+ to not vary with the area or angular size of the light.
+ """
+
+ exposure:float=0.0
+"""Scales the power of the light exponentially as a power of 2. Defaults to 0.0.
+
+ The result is multiplied against the intensity.
+ """
+
+ intensity:float=1.0
+"""Scales the power of the light linearly. Defaults to 1.0."""
+
+
+
[文档]@configclass
+classDiskLightCfg(LightCfg):
+"""Configuration parameters for creating a disk light in the scene.
+
+ A disk light is a light source that emits light from a disk. It is useful for simulating
+ fluorescent lights. For more information, please refer to the documentation on
+ `USDLux DiskLight <https://openusd.org/dev/api/class_usd_lux_disk_light.html>`_.
+
+ .. note::
+ The default values for the attributes are those specified in the their official documentation.
+ """
+
+ prim_type="DiskLight"
+
+ radius:float=0.5
+"""Radius of the disk (in m). Defaults to 0.5m."""
+
+
+
[文档]@configclass
+classDistantLightCfg(LightCfg):
+"""Configuration parameters for creating a distant light in the scene.
+
+ A distant light is a light source that is infinitely far away, and emits parallel rays of light.
+ It is useful for simulating sun/moon light. For more information, please refer to the documentation on
+ `USDLux DistantLight <https://openusd.org/dev/api/class_usd_lux_distant_light.html>`_.
+
+ .. note::
+ The default values for the attributes are those specified in the their official documentation.
+ """
+
+ prim_type="DistantLight"
+
+ angle:float=0.53
+"""Angular size of the light (in degrees). Defaults to 0.53 degrees.
+
+ As an example, the Sun is approximately 0.53 degrees as seen from Earth.
+ Higher values broaden the light and therefore soften shadow edges.
+ """
+
+
+
[文档]@configclass
+classDomeLightCfg(LightCfg):
+"""Configuration parameters for creating a dome light in the scene.
+
+ A dome light is a light source that emits light inwards from all directions. It is also possible to
+ attach a texture to the dome light, which will be used to emit light. For more information, please refer
+ to the documentation on `USDLux DomeLight <https://openusd.org/dev/api/class_usd_lux_dome_light.html>`_.
+
+ .. note::
+ The default values for the attributes are those specified in the their official documentation.
+ """
+
+ prim_type="DomeLight"
+
+ texture_file:str|None=None
+"""A color texture to use on the dome, such as an HDR (high dynamic range) texture intended
+ for IBL (image based lighting). Defaults to None.
+
+ If None, the dome will emit a uniform color.
+ """
+
+ texture_format:Literal["automatic","latlong","mirroredBall","angular","cubeMapVerticalCross"]="automatic"
+"""The parametrization format of the color map file. Defaults to "automatic".
+
+ Valid values are:
+
+ * ``"automatic"``: Tries to determine the layout from the file itself. For example, Renderman texture files embed an explicit parameterization.
+ * ``"latlong"``: Latitude as X, longitude as Y.
+ * ``"mirroredBall"``: An image of the environment reflected in a sphere, using an implicitly orthogonal projection.
+ * ``"angular"``: Similar to mirroredBall but the radial dimension is mapped linearly to the angle, providing better sampling at the edges.
+ * ``"cubeMapVerticalCross"``: A cube map with faces laid out as a vertical cross.
+ """
+
+ visible_in_primary_ray:bool=True
+"""Whether the dome light is visible in the primary ray. Defaults to True.
+
+ If true, the texture in the sky is visible, otherwise the sky is black.
+ """
+
+
+
[文档]@configclass
+classCylinderLightCfg(LightCfg):
+"""Configuration parameters for creating a cylinder light in the scene.
+
+ A cylinder light is a light source that emits light from a cylinder. It is useful for simulating
+ fluorescent lights. For more information, please refer to the documentation on
+ `USDLux CylinderLight <https://openusd.org/dev/api/class_usd_lux_cylinder_light.html>`_.
+
+ .. note::
+ The default values for the attributes are those specified in the their official documentation.
+ """
+
+ prim_type="CylinderLight"
+
+ length:float=1.0
+"""Length of the cylinder (in m). Defaults to 1.0m."""
+
+ radius:float=0.5
+"""Radius of the cylinder (in m). Defaults to 0.5m."""
+
+ treat_as_line:bool=False
+"""Treats the cylinder as a line source, i.e. a zero-radius cylinder. Defaults to false."""
+
+
+
[文档]@configclass
+classSphereLightCfg(LightCfg):
+"""Configuration parameters for creating a sphere light in the scene.
+
+ A sphere light is a light source that emits light outward from a sphere. For more information,
+ please refer to the documentation on
+ `USDLux SphereLight <https://openusd.org/dev/api/class_usd_lux_sphere_light.html>`_.
+
+ .. note::
+ The default values for the attributes are those specified in the their official documentation.
+ """
+
+ prim_type="SphereLight"
+
+ radius:float=0.5
+"""Radius of the sphere. Defaults to 0.5m."""
+
+ treat_as_point:bool=False
+"""Treats the sphere as a point source, i.e. a zero-radius sphere. Defaults to false."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromtypingimportTYPE_CHECKING
+
+importomni.isaac.core.utils.primsasprim_utils
+importomni.isaac.core.utils.stageasstage_utils
+frompxrimportPhysxSchema,Usd,UsdPhysics,UsdShade
+
+fromomni.isaac.lab.sim.utilsimportclone,safe_set_attribute_on_usd_schema
+
+ifTYPE_CHECKING:
+ from.importphysics_materials_cfg
+
+
+
[文档]@clone
+defspawn_rigid_body_material(prim_path:str,cfg:physics_materials_cfg.RigidBodyMaterialCfg)->Usd.Prim:
+"""Create material with rigid-body physics properties.
+
+ Rigid body materials are used to define the physical properties to meshes of a rigid body. These
+ include the friction, restitution, and their respective combination modes. For more information on
+ rigid body material, please refer to the `documentation on PxMaterial <https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/_api_build/classPxBaseMaterial.html>`_.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration for the physics material.
+
+ Returns:
+ The spawned rigid body material prim.
+
+ Raises:
+ ValueError: When a prim already exists at the specified prim path and is not a material.
+ """
+ # create material prim if no prim exists
+ ifnotprim_utils.is_prim_path_valid(prim_path):
+ _=UsdShade.Material.Define(stage_utils.get_current_stage(),prim_path)
+
+ # obtain prim
+ prim=prim_utils.get_prim_at_path(prim_path)
+ # check if prim is a material
+ ifnotprim.IsA(UsdShade.Material):
+ raiseValueError(f"A prim already exists at path: '{prim_path}' but is not a material.")
+ # retrieve the USD rigid-body api
+ usd_physics_material_api=UsdPhysics.MaterialAPI(prim)
+ ifnotusd_physics_material_api:
+ usd_physics_material_api=UsdPhysics.MaterialAPI.Apply(prim)
+ # retrieve the collision api
+ physx_material_api=PhysxSchema.PhysxMaterialAPI(prim)
+ ifnotphysx_material_api:
+ physx_material_api=PhysxSchema.PhysxMaterialAPI.Apply(prim)
+
+ # convert to dict
+ cfg=cfg.to_dict()
+ delcfg["func"]
+ # set into USD API
+ forattr_namein["static_friction","dynamic_friction","restitution"]:
+ value=cfg.pop(attr_name,None)
+ safe_set_attribute_on_usd_schema(usd_physics_material_api,attr_name,value,camel_case=True)
+ # set into PhysX API
+ forattr_name,valueincfg.items():
+ safe_set_attribute_on_usd_schema(physx_material_api,attr_name,value,camel_case=True)
+ # return the prim
+ returnprim
+
+
+
[文档]@clone
+defspawn_deformable_body_material(prim_path:str,cfg:physics_materials_cfg.DeformableBodyMaterialCfg)->Usd.Prim:
+"""Create material with deformable-body physics properties.
+
+ Deformable body materials are used to define the physical properties to meshes of a deformable body. These
+ include the friction and deformable body properties. For more information on deformable body material,
+ please refer to the documentation on `PxFEMSoftBodyMaterial`_.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration for the physics material.
+
+ Returns:
+ The spawned deformable body material prim.
+
+ Raises:
+ ValueError: When a prim already exists at the specified prim path and is not a material.
+
+ .. _PxFEMSoftBodyMaterial: https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/_api_build/structPxFEMSoftBodyMaterialModel.html
+ """
+ # create material prim if no prim exists
+ ifnotprim_utils.is_prim_path_valid(prim_path):
+ _=UsdShade.Material.Define(stage_utils.get_current_stage(),prim_path)
+
+ # obtain prim
+ prim=prim_utils.get_prim_at_path(prim_path)
+ # check if prim is a material
+ ifnotprim.IsA(UsdShade.Material):
+ raiseValueError(f"A prim already exists at path: '{prim_path}' but is not a material.")
+ # retrieve the deformable-body api
+ physx_deformable_body_material_api=PhysxSchema.PhysxDeformableBodyMaterialAPI(prim)
+ ifnotphysx_deformable_body_material_api:
+ physx_deformable_body_material_api=PhysxSchema.PhysxDeformableBodyMaterialAPI.Apply(prim)
+
+ # convert to dict
+ cfg=cfg.to_dict()
+ delcfg["func"]
+ # set into PhysX API
+ forattr_name,valueincfg.items():
+ safe_set_attribute_on_usd_schema(physx_deformable_body_material_api,attr_name,value,camel_case=True)
+ # return the prim
+ returnprim
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.importphysics_materials
+
+
+
[文档]@configclass
+classPhysicsMaterialCfg:
+"""Configuration parameters for creating a physics material.
+
+ Physics material are PhysX schemas that can be applied to a USD material prim to define the
+ physical properties related to the material. For example, the friction coefficient, restitution
+ coefficient, etc. For more information on physics material, please refer to the
+ `PhysX documentation <https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/_api_build/classPxBaseMaterial.html>`__.
+ """
+
+ func:Callable=MISSING
+"""Function to use for creating the material."""
+
+
+
[文档]@configclass
+classRigidBodyMaterialCfg(PhysicsMaterialCfg):
+"""Physics material parameters for rigid bodies.
+
+ See :meth:`spawn_rigid_body_material` for more information.
+
+ Note:
+ The default values are the `default values used by PhysX 5
+ <https://docs.omniverse.nvidia.com/extensions/latest/ext_physics/rigid-bodies.html#rigid-body-materials>`__.
+ """
+
+ func:Callable=physics_materials.spawn_rigid_body_material
+
+ static_friction:float=0.5
+"""The static friction coefficient. Defaults to 0.5."""
+
+ dynamic_friction:float=0.5
+"""The dynamic friction coefficient. Defaults to 0.5."""
+
+ restitution:float=0.0
+"""The restitution coefficient. Defaults to 0.0."""
+
+ improve_patch_friction:bool=True
+"""Whether to enable patch friction. Defaults to True."""
+
+ friction_combine_mode:Literal["average","min","multiply","max"]="average"
+"""Determines the way friction will be combined during collisions. Defaults to `"average"`.
+
+ .. attention::
+
+ When two physics materials with different combine modes collide, the combine mode with the higher
+ priority will be used. The priority order is provided `here
+ <https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/_api_build/structPxCombineMode.html>`__.
+ """
+
+ restitution_combine_mode:Literal["average","min","multiply","max"]="average"
+"""Determines the way restitution coefficient will be combined during collisions. Defaults to `"average"`.
+
+ .. attention::
+
+ When two physics materials with different combine modes collide, the combine mode with the higher
+ priority will be used. The priority order is provided `here
+ <https://nvidia-omniverse.github.io/PhysX/physx/5.4.1/_api_build/structPxCombineMode.html>`__.
+ """
+
+ compliant_contact_stiffness:float=0.0
+"""Spring stiffness for a compliant contact model using implicit springs. Defaults to 0.0.
+
+ A higher stiffness results in behavior closer to a rigid contact. The compliant contact model is only enabled
+ if the stiffness is larger than 0.
+ """
+
+ compliant_contact_damping:float=0.0
+"""Damping coefficient for a compliant contact model using implicit springs. Defaults to 0.0.
+
+ Irrelevant if compliant contacts are disabled when :obj:`compliant_contact_stiffness` is set to zero and
+ rigid contacts are active.
+ """
+
+
+
[文档]@configclass
+classDeformableBodyMaterialCfg(PhysicsMaterialCfg):
+"""Physics material parameters for deformable bodies.
+
+ See :meth:`spawn_deformable_body_material` for more information.
+
+ Note:
+ The default values are the `default values used by PhysX 5
+ <https://docs.omniverse.nvidia.com/extensions/latest/ext_physics/deformable-bodies.html#deformable-body-material>`__.
+ """
+
+ func:Callable=physics_materials.spawn_deformable_body_material
+
+ density:float|None=None
+"""The material density. Defaults to None, in which case the simulation decides the default density."""
+
+ dynamic_friction:float=0.25
+"""The dynamic friction. Defaults to 0.25."""
+
+ youngs_modulus:float=50000000.0
+"""The Young's modulus, which defines the body's stiffness. Defaults to 50000000.0.
+
+ The Young's modulus is a measure of the material's ability to deform under stress. It is measured in Pascals (Pa).
+ """
+
+ poissons_ratio:float=0.45
+"""The Poisson's ratio which defines the body's volume preservation. Defaults to 0.45.
+
+ The Poisson's ratio is a measure of the material's ability to expand in the lateral direction when compressed
+ in the axial direction. It is a dimensionless number between 0 and 0.5. Using a value of 0.5 will make the
+ material incompressible.
+ """
+
+ elasticity_damping:float=0.005
+"""The elasticity damping for the deformable material. Defaults to 0.005."""
+
+ damping_scale:float=1.0
+"""The damping scale for the deformable material. Defaults to 1.0.
+
+ A scale of 1 corresponds to default damping. A value of 0 will only apply damping to certain motions leading
+ to special effects that look similar to water filled soft bodies.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromtypingimportTYPE_CHECKING
+
+importomni.isaac.core.utils.primsasprim_utils
+importomni.kit.commands
+frompxrimportUsd
+
+fromomni.isaac.lab.sim.utilsimportclone,safe_set_attribute_on_usd_prim
+fromomni.isaac.lab.utils.assetsimportNVIDIA_NUCLEUS_DIR
+
+ifTYPE_CHECKING:
+ from.importvisual_materials_cfg
+
+
+
[文档]@clone
+defspawn_preview_surface(prim_path:str,cfg:visual_materials_cfg.PreviewSurfaceCfg)->Usd.Prim:
+"""Create a preview surface prim and override the settings with the given config.
+
+ A preview surface is a physically-based surface that handles simple shaders while supporting
+ both *specular* and *metallic* workflows. All color inputs are in linear color space (RGB).
+ For more information, see the `documentation <https://openusd.org/release/spec_usdpreviewsurface.html>`__.
+
+ The function calls the USD command `CreatePreviewSurfaceMaterialPrim`_ to create the prim.
+
+ .. _CreatePreviewSurfaceMaterialPrim: https://docs.omniverse.nvidia.com/kit/docs/omni.usd/latest/omni.usd.commands/omni.usd.commands.CreatePreviewSurfaceMaterialPrimCommand.html
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # spawn material if it doesn't exist.
+ ifnotprim_utils.is_prim_path_valid(prim_path):
+ omni.kit.commands.execute("CreatePreviewSurfaceMaterialPrim",mtl_path=prim_path,select_new_prim=False)
+ else:
+ raiseValueError(f"A prim already exists at path: '{prim_path}'.")
+ # obtain prim
+ prim=prim_utils.get_prim_at_path(f"{prim_path}/Shader")
+ # apply properties
+ cfg=cfg.to_dict()
+ delcfg["func"]
+ forattr_name,attr_valueincfg.items():
+ safe_set_attribute_on_usd_prim(prim,f"inputs:{attr_name}",attr_value,camel_case=True)
+ # return prim
+ returnprim
+
+
+
[文档]@clone
+defspawn_from_mdl_file(prim_path:str,cfg:visual_materials_cfg.MdlMaterialCfg)->Usd.Prim:
+"""Load a material from its MDL file and override the settings with the given config.
+
+ NVIDIA's `Material Definition Language (MDL) <https://www.nvidia.com/en-us/design-visualization/technologies/material-definition-language/>`__
+ is a language for defining physically-based materials. The MDL file format is a binary format
+ that can be loaded by Omniverse and other applications such as Adobe Substance Designer.
+ To learn more about MDL, see the `documentation <https://docs.omniverse.nvidia.com/materials-and-rendering/latest/materials.html>`_.
+
+ The function calls the USD command `CreateMdlMaterialPrim`_ to create the prim.
+
+ .. _CreateMdlMaterialPrim: https://docs.omniverse.nvidia.com/kit/docs/omni.usd/latest/omni.usd.commands/omni.usd.commands.CreateMdlMaterialPrimCommand.html
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # spawn material if it doesn't exist.
+ ifnotprim_utils.is_prim_path_valid(prim_path):
+ # extract material name from path
+ material_name=cfg.mdl_path.split("/")[-1].split(".")[0]
+ omni.kit.commands.execute(
+ "CreateMdlMaterialPrim",
+ mtl_url=cfg.mdl_path.format(NVIDIA_NUCLEUS_DIR=NVIDIA_NUCLEUS_DIR),
+ mtl_name=material_name,
+ mtl_path=prim_path,
+ select_new_prim=False,
+ )
+ else:
+ raiseValueError(f"A prim already exists at path: '{prim_path}'.")
+ # obtain prim
+ prim=prim_utils.get_prim_at_path(f"{prim_path}/Shader")
+ # apply properties
+ cfg=cfg.to_dict()
+ delcfg["func"]
+ delcfg["mdl_path"]
+ forattr_name,attr_valueincfg.items():
+ safe_set_attribute_on_usd_prim(prim,f"inputs:{attr_name}",attr_value,camel_case=False)
+ # return prim
+ returnprim
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.importvisual_materials
+
+
+
[文档]@configclass
+classVisualMaterialCfg:
+"""Configuration parameters for creating a visual material."""
+
+ func:Callable=MISSING
+"""The function to use for creating the material."""
+
+
+
[文档]@configclass
+classPreviewSurfaceCfg(VisualMaterialCfg):
+"""Configuration parameters for creating a preview surface.
+
+ See :meth:`spawn_preview_surface` for more information.
+ """
+
+ func:Callable=visual_materials.spawn_preview_surface
+
+ diffuse_color:tuple[float,float,float]=(0.18,0.18,0.18)
+"""The RGB diffusion color. This is the base color of the surface. Defaults to a dark gray."""
+ emissive_color:tuple[float,float,float]=(0.0,0.0,0.0)
+"""The RGB emission component of the surface. Defaults to black."""
+ roughness:float=0.5
+"""The roughness for specular lobe. Ranges from 0 (smooth) to 1 (rough). Defaults to 0.5."""
+ metallic:float=0.0
+"""The metallic component. Ranges from 0 (dielectric) to 1 (metal). Defaults to 0."""
+ opacity:float=1.0
+"""The opacity of the surface. Ranges from 0 (transparent) to 1 (opaque). Defaults to 1.
+
+ Note:
+ Opacity only affects the surface's appearance during interactive rendering.
+ """
+
+
+
[文档]@configclass
+classMdlFileCfg(VisualMaterialCfg):
+"""Configuration parameters for loading an MDL material from a file.
+
+ See :meth:`spawn_from_mdl_file` for more information.
+ """
+
+ func:Callable=visual_materials.spawn_from_mdl_file
+
+ mdl_path:str=MISSING
+"""The path to the MDL material.
+
+ NVIDIA Omniverse provides various MDL materials in the NVIDIA Nucleus.
+ To use these materials, you can set the path of the material in the nucleus directory
+ using the ``{NVIDIA_NUCLEUS_DIR}`` variable. This is internally resolved to the path of the
+ NVIDIA Nucleus directory on the host machine through the attribute
+ :attr:`omni.isaac.lab.utils.assets.NVIDIA_NUCLEUS_DIR`.
+
+ For example, to use the "Aluminum_Anodized" material, you can set the path to:
+ ``{NVIDIA_NUCLEUS_DIR}/Materials/Base/Metals/Aluminum_Anodized.mdl``.
+ """
+ project_uvw:bool|None=None
+"""Whether to project the UVW coordinates of the material. Defaults to None.
+
+ If None, then the default setting in the MDL material will be used.
+ """
+ albedo_brightness:float|None=None
+"""Multiplier for the diffuse color of the material. Defaults to None.
+
+ If None, then the default setting in the MDL material will be used.
+ """
+ texture_scale:tuple[float,float]|None=None
+"""The scale of the texture. Defaults to None.
+
+ If None, then the default setting in the MDL material will be used.
+ """
+
+
+
[文档]@configclass
+classGlassMdlCfg(VisualMaterialCfg):
+"""Configuration parameters for loading a glass MDL material.
+
+ This is a convenience class for loading a glass MDL material. For more information on
+ glass materials, see the `documentation <https://docs.omniverse.nvidia.com/materials-and-rendering/latest/materials.html#omniglass>`__.
+
+ .. note::
+ The default values are taken from the glass material in the NVIDIA Nucleus.
+ """
+
+ func:Callable=visual_materials.spawn_from_mdl_file
+
+ mdl_path:str="OmniGlass.mdl"
+"""The path to the MDL material. Defaults to the glass material in the NVIDIA Nucleus."""
+ glass_color:tuple[float,float,float]=(1.0,1.0,1.0)
+"""The RGB color or tint of the glass. Defaults to white."""
+ frosting_roughness:float=0.0
+"""The amount of reflectivity of the surface. Ranges from 0 (perfectly clear) to 1 (frosted).
+ Defaults to 0."""
+ thin_walled:bool=False
+"""Whether to perform thin-walled refraction. Defaults to False."""
+ glass_ior:float=1.491
+"""The incidence of refraction to control how much light is bent when passing through the glass.
+ Defaults to 1.491, which is the IOR of glass.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importnumpyasnp
+importtrimesh
+importtrimesh.transformations
+fromtypingimportTYPE_CHECKING
+
+importomni.isaac.core.utils.primsasprim_utils
+frompxrimportUsd,UsdPhysics
+
+fromomni.isaac.lab.simimportschemas
+fromomni.isaac.lab.sim.utilsimportbind_physics_material,bind_visual_material,clone
+
+from..materialsimportDeformableBodyMaterialCfg,RigidBodyMaterialCfg
+
+ifTYPE_CHECKING:
+ from.importmeshes_cfg
+
+
+
[文档]@clone
+defspawn_mesh_sphere(
+ prim_path:str,
+ cfg:meshes_cfg.MeshSphereCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USD-Mesh sphere prim with the given attributes.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # create a trimesh sphere
+ sphere=trimesh.creation.uv_sphere(radius=cfg.radius)
+ # spawn the sphere as a mesh
+ _spawn_mesh_geom_from_mesh(prim_path,cfg,sphere,translation,orientation)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+
[文档]@clone
+defspawn_mesh_cuboid(
+ prim_path:str,
+ cfg:meshes_cfg.MeshCuboidCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USD-Mesh cuboid prim with the given attributes.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """# create a trimesh box
+ box=trimesh.creation.box(cfg.size)
+ # spawn the cuboid as a mesh
+ _spawn_mesh_geom_from_mesh(prim_path,cfg,box,translation,orientation,None)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+
[文档]@clone
+defspawn_mesh_cylinder(
+ prim_path:str,
+ cfg:meshes_cfg.MeshCylinderCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USD-Mesh cylinder prim with the given attributes.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # align axis from "Z" to input by rotating the cylinder
+ axis=cfg.axis.upper()
+ ifaxis=="X":
+ transform=trimesh.transformations.rotation_matrix(np.pi/2,[0,1,0])
+ elifaxis=="Y":
+ transform=trimesh.transformations.rotation_matrix(-np.pi/2,[1,0,0])
+ else:
+ transform=None
+ # create a trimesh cylinder
+ cylinder=trimesh.creation.cylinder(radius=cfg.radius,height=cfg.height,transform=transform)
+ # spawn the cylinder as a mesh
+ _spawn_mesh_geom_from_mesh(prim_path,cfg,cylinder,translation,orientation)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+
[文档]@clone
+defspawn_mesh_capsule(
+ prim_path:str,
+ cfg:meshes_cfg.MeshCapsuleCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USD-Mesh capsule prim with the given attributes.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # align axis from "Z" to input by rotating the cylinder
+ axis=cfg.axis.upper()
+ ifaxis=="X":
+ transform=trimesh.transformations.rotation_matrix(np.pi/2,[0,1,0])
+ elifaxis=="Y":
+ transform=trimesh.transformations.rotation_matrix(-np.pi/2,[1,0,0])
+ else:
+ transform=None
+ # create a trimesh capsule
+ capsule=trimesh.creation.capsule(radius=cfg.radius,height=cfg.height,transform=transform)
+ # spawn capsule if it doesn't exist.
+ _spawn_mesh_geom_from_mesh(prim_path,cfg,capsule,translation,orientation)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+
[文档]@clone
+defspawn_mesh_cone(
+ prim_path:str,
+ cfg:meshes_cfg.MeshConeCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USD-Mesh cone prim with the given attributes.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # align axis from "Z" to input by rotating the cylinder
+ axis=cfg.axis.upper()
+ ifaxis=="X":
+ transform=trimesh.transformations.rotation_matrix(np.pi/2,[0,1,0])
+ elifaxis=="Y":
+ transform=trimesh.transformations.rotation_matrix(-np.pi/2,[1,0,0])
+ else:
+ transform=None
+ # create a trimesh cone
+ cone=trimesh.creation.cone(radius=cfg.radius,height=cfg.height,transform=transform)
+ # spawn cone if it doesn't exist.
+ _spawn_mesh_geom_from_mesh(prim_path,cfg,cone,translation,orientation)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+"""
+Helper functions.
+"""
+
+
+def_spawn_mesh_geom_from_mesh(
+ prim_path:str,
+ cfg:meshes_cfg.MeshCfg,
+ mesh:trimesh.Trimesh,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+ scale:tuple[float,float,float]|None=None,
+):
+"""Create a `USDGeomMesh`_ prim from the given mesh.
+
+ This function is similar to :func:`shapes._spawn_geom_from_prim_type` but spawns the prim from a given mesh.
+ In case of the mesh, it is spawned as a USDGeomMesh prim with the given vertices and faces.
+
+ There is a difference in how the properties are applied to the prim based on the type of object:
+
+ - Deformable body properties: The properties are applied to the mesh prim: ``{prim_path}/geometry/mesh``.
+ - Collision properties: The properties are applied to the mesh prim: ``{prim_path}/geometry/mesh``.
+ - Rigid body properties: The properties are applied to the parent prim: ``{prim_path}``.
+
+ Args:
+ prim_path: The prim path to spawn the asset at.
+ cfg: The config containing the properties to apply.
+ mesh: The mesh to spawn the prim from.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+ scale: The scale to apply to the prim. Defaults to None, in which case this is set to identity.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ ValueError: If both deformable and rigid properties are used.
+ ValueError: If both deformable and collision properties are used.
+ ValueError: If the physics material is not of the correct type. Deformable properties require a deformable
+ physics material, and rigid properties require a rigid physics material.
+
+ .. _USDGeomMesh: https://openusd.org/dev/api/class_usd_geom_mesh.html
+ """
+ # spawn geometry if it doesn't exist.
+ ifnotprim_utils.is_prim_path_valid(prim_path):
+ prim_utils.create_prim(prim_path,prim_type="Xform",translation=translation,orientation=orientation)
+ else:
+ raiseValueError(f"A prim already exists at path: '{prim_path}'.")
+
+ # check that invalid schema types are not used
+ ifcfg.deformable_propsisnotNoneandcfg.rigid_propsisnotNone:
+ raiseValueError("Cannot use both deformable and rigid properties at the same time.")
+ ifcfg.deformable_propsisnotNoneandcfg.collision_propsisnotNone:
+ raiseValueError("Cannot use both deformable and collision properties at the same time.")
+ # check material types are correct
+ ifcfg.deformable_propsisnotNoneandcfg.physics_materialisnotNone:
+ ifnotisinstance(cfg.physics_material,DeformableBodyMaterialCfg):
+ raiseValueError("Deformable properties require a deformable physics material.")
+ ifcfg.rigid_propsisnotNoneandcfg.physics_materialisnotNone:
+ ifnotisinstance(cfg.physics_material,RigidBodyMaterialCfg):
+ raiseValueError("Rigid properties require a rigid physics material.")
+
+ # create all the paths we need for clarity
+ geom_prim_path=prim_path+"/geometry"
+ mesh_prim_path=geom_prim_path+"/mesh"
+
+ # create the mesh prim
+ mesh_prim=prim_utils.create_prim(
+ mesh_prim_path,
+ prim_type="Mesh",
+ scale=scale,
+ attributes={
+ "points":mesh.vertices,
+ "faceVertexIndices":mesh.faces.flatten(),
+ "faceVertexCounts":np.asarray([3]*len(mesh.faces)),
+ "subdivisionScheme":"bilinear",
+ },
+ )
+
+ # note: in case of deformable objects, we need to apply the deformable properties to the mesh prim.
+ # this is different from rigid objects where we apply the properties to the parent prim.
+ ifcfg.deformable_propsisnotNone:
+ # apply mass properties
+ ifcfg.mass_propsisnotNone:
+ schemas.define_mass_properties(mesh_prim_path,cfg.mass_props)
+ # apply deformable body properties
+ schemas.define_deformable_body_properties(mesh_prim_path,cfg.deformable_props)
+ elifcfg.collision_propsisnotNone:
+ # decide on type of collision approximation based on the mesh
+ ifcfg.__class__.__name__=="MeshSphereCfg":
+ collision_approximation="boundingSphere"
+ elifcfg.__class__.__name__=="MeshCuboidCfg":
+ collision_approximation="boundingCube"
+ else:
+ # for: MeshCylinderCfg, MeshCapsuleCfg, MeshConeCfg
+ collision_approximation="convexHull"
+ # apply collision approximation to mesh
+ # note: for primitives, we use the convex hull approximation -- this should be sufficient for most cases.
+ mesh_collision_api=UsdPhysics.MeshCollisionAPI.Apply(mesh_prim)
+ mesh_collision_api.GetApproximationAttr().Set(collision_approximation)
+ # apply collision properties
+ schemas.define_collision_properties(mesh_prim_path,cfg.collision_props)
+
+ # apply visual material
+ ifcfg.visual_materialisnotNone:
+ ifnotcfg.visual_material_path.startswith("/"):
+ material_path=f"{geom_prim_path}/{cfg.visual_material_path}"
+ else:
+ material_path=cfg.visual_material_path
+ # create material
+ cfg.visual_material.func(material_path,cfg.visual_material)
+ # apply material
+ bind_visual_material(mesh_prim_path,material_path)
+
+ # apply physics material
+ ifcfg.physics_materialisnotNone:
+ ifnotcfg.physics_material_path.startswith("/"):
+ material_path=f"{geom_prim_path}/{cfg.physics_material_path}"
+ else:
+ material_path=cfg.physics_material_path
+ # create material
+ cfg.physics_material.func(material_path,cfg.physics_material)
+ # apply material
+ bind_physics_material(mesh_prim_path,material_path)
+
+ # note: we apply the rigid properties to the parent prim in case of rigid objects.
+ ifcfg.rigid_propsisnotNone:
+ # apply mass properties
+ ifcfg.mass_propsisnotNone:
+ schemas.define_mass_properties(prim_path,cfg.mass_props)
+ # apply rigid properties
+ schemas.define_rigid_body_properties(prim_path,cfg.rigid_props)
+
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.sim.spawnersimportmaterials
+fromomni.isaac.lab.sim.spawners.spawner_cfgimportDeformableObjectSpawnerCfg,RigidObjectSpawnerCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.importmeshes
+
+
+
[文档]@configclass
+classMeshCfg(RigidObjectSpawnerCfg,DeformableObjectSpawnerCfg):
+"""Configuration parameters for a USD Geometry or Geom prim.
+
+ This class is similar to :class:`ShapeCfg` but is specifically for meshes.
+
+ Meshes support both rigid and deformable properties. However, their schemas are applied at
+ different levels in the USD hierarchy based on the type of the object. These are described below:
+
+ - Deformable body properties: Applied to the mesh prim: ``{prim_path}/geometry/mesh``.
+ - Collision properties: Applied to the mesh prim: ``{prim_path}/geometry/mesh``.
+ - Rigid body properties: Applied to the parent prim: ``{prim_path}``.
+
+ where ``{prim_path}`` is the path to the prim in the USD stage and ``{prim_path}/geometry/mesh``
+ is the path to the mesh prim.
+
+ .. note::
+ There are mututally exclusive parameters for rigid and deformable properties. If both are set,
+ then an error will be raised. This also holds if collision and deformable properties are set together.
+
+ """
+
+ visual_material_path:str="material"
+"""Path to the visual material to use for the prim. Defaults to "material".
+
+ If the path is relative, then it will be relative to the prim's path.
+ This parameter is ignored if `visual_material` is not None.
+ """
+
+ visual_material:materials.VisualMaterialCfg|None=None
+"""Visual material properties.
+
+ Note:
+ If None, then no visual material will be added.
+ """
+
+ physics_material_path:str="material"
+"""Path to the physics material to use for the prim. Defaults to "material".
+
+ If the path is relative, then it will be relative to the prim's path.
+ This parameter is ignored if `physics_material` is not None.
+ """
+
+ physics_material:materials.PhysicsMaterialCfg|None=None
+"""Physics material properties.
+
+ Note:
+ If None, then no physics material will be added.
+ """
+
+
+
[文档]@configclass
+classMeshSphereCfg(MeshCfg):
+"""Configuration parameters for a sphere mesh prim with deformable properties.
+
+ See :meth:`spawn_mesh_sphere` for more information.
+ """
+
+ func:Callable=meshes.spawn_mesh_sphere
+
+ radius:float=MISSING
+"""Radius of the sphere (in m)."""
+
+
+
[文档]@configclass
+classMeshCuboidCfg(MeshCfg):
+"""Configuration parameters for a cuboid mesh prim with deformable properties.
+
+ See :meth:`spawn_mesh_cuboid` for more information.
+ """
+
+ func:Callable=meshes.spawn_mesh_cuboid
+
+ size:tuple[float,float,float]=MISSING
+"""Size of the cuboid (in m)."""
+
+
+
[文档]@configclass
+classMeshCylinderCfg(MeshCfg):
+"""Configuration parameters for a cylinder mesh prim with deformable properties.
+
+ See :meth:`spawn_cylinder` for more information.
+ """
+
+ func:Callable=meshes.spawn_mesh_cylinder
+
+ radius:float=MISSING
+"""Radius of the cylinder (in m)."""
+ height:float=MISSING
+"""Height of the cylinder (in m)."""
+ axis:Literal["X","Y","Z"]="Z"
+"""Axis of the cylinder. Defaults to "Z"."""
+
+
+
[文档]@configclass
+classMeshCapsuleCfg(MeshCfg):
+"""Configuration parameters for a capsule mesh prim.
+
+ See :meth:`spawn_capsule` for more information.
+ """
+
+ func:Callable=meshes.spawn_mesh_capsule
+
+ radius:float=MISSING
+"""Radius of the capsule (in m)."""
+ height:float=MISSING
+"""Height of the capsule (in m)."""
+ axis:Literal["X","Y","Z"]="Z"
+"""Axis of the capsule. Defaults to "Z"."""
+
+
+
[文档]@configclass
+classMeshConeCfg(MeshCfg):
+"""Configuration parameters for a cone mesh prim.
+
+ See :meth:`spawn_cone` for more information.
+ """
+
+ func:Callable=meshes.spawn_mesh_cone
+
+ radius:float=MISSING
+"""Radius of the cone (in m)."""
+ height:float=MISSING
+"""Height of the v (in m)."""
+ axis:Literal["X","Y","Z"]="Z"
+"""Axis of the cone. Defaults to "Z"."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromtypingimportTYPE_CHECKING
+
+importcarb
+importomni.isaac.core.utils.primsasprim_utils
+importomni.kit.commands
+frompxrimportSdf,Usd
+
+fromomni.isaac.lab.sim.utilsimportclone
+fromomni.isaac.lab.utilsimportto_camel_case
+
+ifTYPE_CHECKING:
+ from.importsensors_cfg
+
+
+CUSTOM_PINHOLE_CAMERA_ATTRIBUTES={
+ "projection_type":("cameraProjectionType",Sdf.ValueTypeNames.Token),
+}
+"""Custom attributes for pinhole camera model.
+
+The dictionary maps the attribute name in the configuration to the attribute name in the USD prim.
+"""
+
+
+CUSTOM_FISHEYE_CAMERA_ATTRIBUTES={
+ "projection_type":("cameraProjectionType",Sdf.ValueTypeNames.Token),
+ "fisheye_nominal_width":("fthetaWidth",Sdf.ValueTypeNames.Float),
+ "fisheye_nominal_height":("fthetaHeight",Sdf.ValueTypeNames.Float),
+ "fisheye_optical_centre_x":("fthetaCx",Sdf.ValueTypeNames.Float),
+ "fisheye_optical_centre_y":("fthetaCy",Sdf.ValueTypeNames.Float),
+ "fisheye_max_fov":("fthetaMaxFov",Sdf.ValueTypeNames.Float),
+ "fisheye_polynomial_a":("fthetaPolyA",Sdf.ValueTypeNames.Float),
+ "fisheye_polynomial_b":("fthetaPolyB",Sdf.ValueTypeNames.Float),
+ "fisheye_polynomial_c":("fthetaPolyC",Sdf.ValueTypeNames.Float),
+ "fisheye_polynomial_d":("fthetaPolyD",Sdf.ValueTypeNames.Float),
+ "fisheye_polynomial_e":("fthetaPolyE",Sdf.ValueTypeNames.Float),
+ "fisheye_polynomial_f":("fthetaPolyF",Sdf.ValueTypeNames.Float),
+}
+"""Custom attributes for fisheye camera model.
+
+The dictionary maps the attribute name in the configuration to the attribute name in the USD prim.
+"""
+
+
+
[文档]@clone
+defspawn_camera(
+ prim_path:str,
+ cfg:sensors_cfg.PinholeCameraCfg|sensors_cfg.FisheyeCameraCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USD camera prim with given projection type.
+
+ The function creates various attributes on the camera prim that specify the camera's properties.
+ These are later used by ``omni.replicator.core`` to render the scene with the given camera.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # spawn camera if it doesn't exist.
+ ifnotprim_utils.is_prim_path_valid(prim_path):
+ prim_utils.create_prim(prim_path,"Camera",translation=translation,orientation=orientation)
+ else:
+ raiseValueError(f"A prim already exists at path: '{prim_path}'.")
+
+ # lock camera from viewport (this disables viewport movement for camera)
+ ifcfg.lock_camera:
+ omni.kit.commands.execute(
+ "ChangePropertyCommand",
+ prop_path=Sdf.Path(f"{prim_path}.omni:kit:cameraLock"),
+ value=True,
+ prev=None,
+ type_to_create_if_not_exist=Sdf.ValueTypeNames.Bool,
+ )
+ # decide the custom attributes to add
+ ifcfg.projection_type=="pinhole":
+ attribute_types=CUSTOM_PINHOLE_CAMERA_ATTRIBUTES
+ else:
+ attribute_types=CUSTOM_FISHEYE_CAMERA_ATTRIBUTES
+
+ # TODO: Adjust to handle aperture offsets once supported by omniverse
+ # Internal ticket from rendering team: OM-42611
+ ifcfg.horizontal_aperture_offset>1e-4orcfg.vertical_aperture_offset>1e-4:
+ carb.log_warn("Camera aperture offsets are not supported by Omniverse. These parameters will be ignored.")
+
+ # custom attributes in the config that are not USD Camera parameters
+ non_usd_cfg_param_names=[
+ "func",
+ "copy_from_source",
+ "lock_camera",
+ "visible",
+ "semantic_tags",
+ "from_intrinsic_matrix",
+ ]
+ # get camera prim
+ prim=prim_utils.get_prim_at_path(prim_path)
+ # create attributes for the fisheye camera model
+ # note: for pinhole those are already part of the USD camera prim
+ forattr_name,attr_typeinattribute_types.values():
+ # check if attribute does not exist
+ ifprim.GetAttribute(attr_name).Get()isNone:
+ # create attribute based on type
+ prim.CreateAttribute(attr_name,attr_type)
+ # set attribute values
+ forparam_name,param_valueincfg.__dict__.items():
+ # check if value is valid
+ ifparam_valueisNoneorparam_nameinnon_usd_cfg_param_names:
+ continue
+ # obtain prim property name
+ ifparam_nameinattribute_types:
+ # check custom attributes
+ prim_prop_name=attribute_types[param_name][0]
+ else:
+ # convert attribute name in prim to cfg name
+ prim_prop_name=to_camel_case(param_name,to="cC")
+ # get attribute from the class
+ prim.GetAttribute(prim_prop_name).Set(param_value)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromcollections.abcimportCallable
+fromtypingimportLiteral
+
+fromomni.isaac.lab.sim.spawners.spawner_cfgimportSpawnerCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.importsensors
+
+
+
[文档]@configclass
+classPinholeCameraCfg(SpawnerCfg):
+"""Configuration parameters for a USD camera prim with pinhole camera settings.
+
+ For more information on the parameters, please refer to the `camera documentation <https://docs.omniverse.nvidia.com/materials-and-rendering/latest/cameras.html>`__.
+
+ ..note ::
+ Focal length as well as the aperture sizes and offsets are set as a tenth of the world unit. In our case, the
+ world unit is Meter s.t. all of these values are set in cm.
+
+ .. note::
+ The default values are taken from the `Replicator camera <https://docs.omniverse.nvidia.com/py/replicator/1.9.8/source/extensions/omni.replicator.core/docs/API.html#omni.replicator.core.create.camera>`__
+ function.
+ """
+
+ func:Callable=sensors.spawn_camera
+
+ projection_type:str="pinhole"
+"""Type of projection to use for the camera. Defaults to "pinhole".
+
+ Note:
+ Currently only "pinhole" is supported.
+ """
+
+ clipping_range:tuple[float,float]=(0.01,1e6)
+"""Near and far clipping distances (in m). Defaults to (0.01, 1e6).
+
+ The minimum clipping range will shift the camera forward by the specified distance. Don't set it too high to
+ avoid issues for distance related data types (e.g., ``distance_to_image_plane``).
+ """
+
+ focal_length:float=24.0
+"""Perspective focal length (in cm). Defaults to 24.0cm.
+
+ Longer lens lengths narrower FOV, shorter lens lengths wider FOV.
+ """
+
+ focus_distance:float=400.0
+"""Distance from the camera to the focus plane (in m). Defaults to 400.0.
+
+ The distance at which perfect sharpness is achieved.
+ """
+
+ f_stop:float=0.0
+"""Lens aperture. Defaults to 0.0, which turns off focusing.
+
+ Controls Distance Blurring. Lower Numbers decrease focus range, larger numbers increase it.
+ """
+
+ horizontal_aperture:float=20.955
+"""Horizontal aperture (in cm). Defaults to 20.955 cm.
+
+ Emulates sensor/film width on a camera.
+
+ Note:
+ The default value is the horizontal aperture of a 35 mm spherical projector.
+ """
+
+ vertical_aperture:float|None=None
+r"""Vertical aperture (in mm). Defaults to None.
+
+ Emulates sensor/film height on a camera. If None, then the vertical aperture is calculated based on the
+ horizontal aperture and the aspect ratio of the image to maintain squared pixels. This is calculated as:
+
+ .. math::
+ \text{vertical aperture} = \text{horizontal aperture} \times \frac{\text{height}}{\text{width}}
+ """
+
+ horizontal_aperture_offset:float=0.0
+"""Offsets Resolution/Film gate horizontally. Defaults to 0.0."""
+
+ vertical_aperture_offset:float=0.0
+"""Offsets Resolution/Film gate vertically. Defaults to 0.0."""
+
+ lock_camera:bool=True
+"""Locks the camera in the Omniverse viewport. Defaults to True.
+
+ If True, then the camera remains fixed at its configured transform. This is useful when wanting to view
+ the camera output on the GUI and not accidentally moving the camera through the GUI interactions.
+ """
+
+
[文档]@classmethod
+ deffrom_intrinsic_matrix(
+ cls,
+ intrinsic_matrix:list[float],
+ width:int,
+ height:int,
+ clipping_range:tuple[float,float]=(0.01,1e6),
+ focal_length:float=24.0,
+ focus_distance:float=400.0,
+ f_stop:float=0.0,
+ projection_type:str="pinhole",
+ lock_camera:bool=True,
+ )->PinholeCameraCfg:
+r"""Create a :class:`PinholeCameraCfg` class instance from an intrinsic matrix.
+
+ The intrinsic matrix is a 3x3 matrix that defines the mapping between the 3D world coordinates and
+ the 2D image. The matrix is defined as:
+
+ .. math::
+ I_{cam} = \begin{bmatrix}
+ f_x & 0 & c_x \\
+ 0 & f_y & c_y \\
+ 0 & 0 & 1
+ \\end{bmatrix},
+
+ where :math:`f_x` and :math:`f_y` are the focal length along x and y direction, while :math:`c_x` and :math:`c_y` are the
+ principle point offsets along x and y direction respectively.
+
+ Args:
+ intrinsic_matrix: Intrinsic matrix of the camera in row-major format.
+ The matrix is defined as [f_x, 0, c_x, 0, f_y, c_y, 0, 0, 1]. Shape is (9,).
+ width: Width of the image (in pixels).
+ height: Height of the image (in pixels).
+ clipping_range: Near and far clipping distances (in m). Defaults to (0.01, 1e6).
+ focal_length: Perspective focal length (in cm). Defaults to 24.0 cm.
+ focus_distance: Distance from the camera to the focus plane (in m). Defaults to 400.0 m.
+ f_stop: Lens aperture. Defaults to 0.0, which turns off focusing.
+ projection_type: Type of projection to use for the camera. Defaults to "pinhole".
+ lock_camera: Locks the camera in the Omniverse viewport. Defaults to True.
+
+ Returns:
+ An instance of the :class:`PinholeCameraCfg` class.
+ """
+ # raise not implemented error is projection type is not pinhole
+ ifprojection_type!="pinhole":
+ raiseNotImplementedError("Only pinhole projection type is supported.")
+
+ # extract parameters from matrix
+ f_x=intrinsic_matrix[0]
+ c_x=intrinsic_matrix[2]
+ f_y=intrinsic_matrix[4]
+ c_y=intrinsic_matrix[5]
+ # resolve parameters for usd camera
+ horizontal_aperture=width*focal_length/f_x
+ vertical_aperture=height*focal_length/f_y
+ horizontal_aperture_offset=(c_x-width/2)/f_x
+ vertical_aperture_offset=(c_y-height/2)/f_y
+
+ returncls(
+ projection_type=projection_type,
+ clipping_range=clipping_range,
+ focal_length=focal_length,
+ focus_distance=focus_distance,
+ f_stop=f_stop,
+ horizontal_aperture=horizontal_aperture,
+ vertical_aperture=vertical_aperture,
+ horizontal_aperture_offset=horizontal_aperture_offset,
+ vertical_aperture_offset=vertical_aperture_offset,
+ lock_camera=lock_camera,
+ )
+
+
+
[文档]@configclass
+classFisheyeCameraCfg(PinholeCameraCfg):
+"""Configuration parameters for a USD camera prim with `fish-eye camera`_ settings.
+
+ For more information on the parameters, please refer to the
+ `camera documentation <https://docs.omniverse.nvidia.com/materials-and-rendering/latest/cameras.html#fisheye-properties>`__.
+
+ .. note::
+ The default values are taken from the `Replicator camera <https://docs.omniverse.nvidia.com/py/replicator/1.9.8/source/extensions/omni.replicator.core/docs/API.html#omni.replicator.core.create.camera>`__
+ function.
+
+ .. _fish-eye camera: https://en.wikipedia.org/wiki/Fisheye_lens
+ """
+
+ func:Callable=sensors.spawn_camera
+
+ projection_type:Literal[
+ "fisheye_orthographic","fisheye_equidistant","fisheye_equisolid","fisheye_polynomial","fisheye_spherical"
+ ]="fisheye_polynomial"
+r"""Type of projection to use for the camera. Defaults to "fisheye_polynomial".
+
+ Available options:
+
+ - ``"fisheye_orthographic"``: Fisheye camera model using orthographic correction.
+ - ``"fisheye_equidistant"``: Fisheye camera model using equidistant correction.
+ - ``"fisheye_equisolid"``: Fisheye camera model using equisolid correction.
+ - ``"fisheye_polynomial"``: Fisheye camera model with :math:`360^{\circ}` spherical projection.
+ - ``"fisheye_spherical"``: Fisheye camera model with :math:`360^{\circ}` full-frame projection.
+ """
+
+ fisheye_nominal_width:float=1936.0
+"""Nominal width of fisheye lens model (in pixels). Defaults to 1936.0."""
+
+ fisheye_nominal_height:float=1216.0
+"""Nominal height of fisheye lens model (in pixels). Defaults to 1216.0."""
+
+ fisheye_optical_centre_x:float=970.94244
+"""Horizontal optical centre position of fisheye lens model (in pixels). Defaults to 970.94244."""
+
+ fisheye_optical_centre_y:float=600.37482
+"""Vertical optical centre position of fisheye lens model (in pixels). Defaults to 600.37482."""
+
+ fisheye_max_fov:float=200.0
+"""Maximum field of view of fisheye lens model (in degrees). Defaults to 200.0 degrees."""
+
+ fisheye_polynomial_a:float=0.0
+"""First component of fisheye polynomial. Defaults to 0.0."""
+
+ fisheye_polynomial_b:float=0.00245
+"""Second component of fisheye polynomial. Defaults to 0.00245."""
+
+ fisheye_polynomial_c:float=0.0
+"""Third component of fisheye polynomial. Defaults to 0.0."""
+
+ fisheye_polynomial_d:float=0.0
+"""Fourth component of fisheye polynomial. Defaults to 0.0."""
+
+ fisheye_polynomial_e:float=0.0
+"""Fifth component of fisheye polynomial. Defaults to 0.0."""
+
+ fisheye_polynomial_f:float=0.0
+"""Sixth component of fisheye polynomial. Defaults to 0.0."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromtypingimportTYPE_CHECKING
+
+importomni.isaac.core.utils.primsasprim_utils
+frompxrimportUsd
+
+fromomni.isaac.lab.simimportschemas
+fromomni.isaac.lab.sim.utilsimportbind_physics_material,bind_visual_material,clone
+
+ifTYPE_CHECKING:
+ from.importshapes_cfg
+
+
+
[文档]@clone
+defspawn_sphere(
+ prim_path:str,
+ cfg:shapes_cfg.SphereCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USDGeom-based sphere prim with the given attributes.
+
+ For more information, see `USDGeomSphere <https://openusd.org/dev/api/class_usd_geom_sphere.html>`_.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # spawn sphere if it doesn't exist.
+ attributes={"radius":cfg.radius}
+ _spawn_geom_from_prim_type(prim_path,cfg,"Sphere",attributes,translation,orientation)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+
[文档]@clone
+defspawn_cuboid(
+ prim_path:str,
+ cfg:shapes_cfg.CuboidCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USDGeom-based cuboid prim with the given attributes.
+
+ For more information, see `USDGeomCube <https://openusd.org/dev/api/class_usd_geom_cube.html>`_.
+
+ Note:
+ Since USD only supports cubes, we set the size of the cube to the minimum of the given size and
+ scale the cube accordingly.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ If a prim already exists at the given path.
+ """
+ # resolve the scale
+ size=min(cfg.size)
+ scale=[dim/sizefordimincfg.size]
+ # spawn cuboid if it doesn't exist.
+ attributes={"size":size}
+ _spawn_geom_from_prim_type(prim_path,cfg,"Cube",attributes,translation,orientation,scale)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+
[文档]@clone
+defspawn_cylinder(
+ prim_path:str,
+ cfg:shapes_cfg.CylinderCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USDGeom-based cylinder prim with the given attributes.
+
+ For more information, see `USDGeomCylinder <https://openusd.org/dev/api/class_usd_geom_cylinder.html>`_.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # spawn cylinder if it doesn't exist.
+ attributes={"radius":cfg.radius,"height":cfg.height,"axis":cfg.axis.upper()}
+ _spawn_geom_from_prim_type(prim_path,cfg,"Cylinder",attributes,translation,orientation)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+
[文档]@clone
+defspawn_capsule(
+ prim_path:str,
+ cfg:shapes_cfg.CapsuleCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USDGeom-based capsule prim with the given attributes.
+
+ For more information, see `USDGeomCapsule <https://openusd.org/dev/api/class_usd_geom_capsule.html>`_.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # spawn capsule if it doesn't exist.
+ attributes={"radius":cfg.radius,"height":cfg.height,"axis":cfg.axis.upper()}
+ _spawn_geom_from_prim_type(prim_path,cfg,"Capsule",attributes,translation,orientation)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+
[文档]@clone
+defspawn_cone(
+ prim_path:str,
+ cfg:shapes_cfg.ConeCfg,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+)->Usd.Prim:
+"""Create a USDGeom-based cone prim with the given attributes.
+
+ For more information, see `USDGeomCone <https://openusd.org/dev/api/class_usd_geom_cone.html>`_.
+
+ .. note::
+ This function is decorated with :func:`clone` that resolves prim path into list of paths
+ if the input prim path is a regex pattern. This is done to support spawning multiple assets
+ from a single and cloning the USD prim at the given path expression.
+
+ Args:
+ prim_path: The prim path or pattern to spawn the asset at. If the prim path is a regex pattern,
+ then the asset is spawned at all the matching prim paths.
+ cfg: The configuration instance.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+
+ Returns:
+ The created prim.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # spawn cone if it doesn't exist.
+ attributes={"radius":cfg.radius,"height":cfg.height,"axis":cfg.axis.upper()}
+ _spawn_geom_from_prim_type(prim_path,cfg,"Cone",attributes,translation,orientation)
+ # return the prim
+ returnprim_utils.get_prim_at_path(prim_path)
+
+
+"""
+Helper functions.
+"""
+
+
+def_spawn_geom_from_prim_type(
+ prim_path:str,
+ cfg:shapes_cfg.ShapeCfg,
+ prim_type:str,
+ attributes:dict,
+ translation:tuple[float,float,float]|None=None,
+ orientation:tuple[float,float,float,float]|None=None,
+ scale:tuple[float,float,float]|None=None,
+):
+"""Create a USDGeom-based prim with the given attributes.
+
+ To make the asset instanceable, we must follow a certain structure dictated by how USD scene-graph
+ instancing and physics work. The rigid body component must be added to each instance and not the
+ referenced asset (i.e. the prototype prim itself). This is because the rigid body component defines
+ properties that are specific to each instance and cannot be shared under the referenced asset. For
+ more information, please check the `documentation <https://docs.omniverse.nvidia.com/extensions/latest/ext_physics/rigid-bodies.html#instancing-rigid-bodies>`_.
+
+ Due to the above, we follow the following structure:
+
+ * ``{prim_path}`` - The root prim that is an Xform with the rigid body and mass APIs if configured.
+ * ``{prim_path}/geometry`` - The prim that contains the mesh and optionally the materials if configured.
+ If instancing is enabled, this prim will be an instanceable reference to the prototype prim.
+
+ Args:
+ prim_path: The prim path to spawn the asset at.
+ cfg: The config containing the properties to apply.
+ prim_type: The type of prim to create.
+ attributes: The attributes to apply to the prim.
+ translation: The translation to apply to the prim w.r.t. its parent prim. Defaults to None, in which case
+ this is set to the origin.
+ orientation: The orientation in (w, x, y, z) to apply to the prim w.r.t. its parent prim. Defaults to None,
+ in which case this is set to identity.
+ scale: The scale to apply to the prim. Defaults to None, in which case this is set to identity.
+
+ Raises:
+ ValueError: If a prim already exists at the given path.
+ """
+ # spawn geometry if it doesn't exist.
+ ifnotprim_utils.is_prim_path_valid(prim_path):
+ prim_utils.create_prim(prim_path,prim_type="Xform",translation=translation,orientation=orientation)
+ else:
+ raiseValueError(f"A prim already exists at path: '{prim_path}'.")
+
+ # create all the paths we need for clarity
+ geom_prim_path=prim_path+"/geometry"
+ mesh_prim_path=geom_prim_path+"/mesh"
+
+ # create the geometry prim
+ prim_utils.create_prim(mesh_prim_path,prim_type,scale=scale,attributes=attributes)
+ # apply collision properties
+ ifcfg.collision_propsisnotNone:
+ schemas.define_collision_properties(mesh_prim_path,cfg.collision_props)
+ # apply visual material
+ ifcfg.visual_materialisnotNone:
+ ifnotcfg.visual_material_path.startswith("/"):
+ material_path=f"{geom_prim_path}/{cfg.visual_material_path}"
+ else:
+ material_path=cfg.visual_material_path
+ # create material
+ cfg.visual_material.func(material_path,cfg.visual_material)
+ # apply material
+ bind_visual_material(mesh_prim_path,material_path)
+ # apply physics material
+ ifcfg.physics_materialisnotNone:
+ ifnotcfg.physics_material_path.startswith("/"):
+ material_path=f"{geom_prim_path}/{cfg.physics_material_path}"
+ else:
+ material_path=cfg.physics_material_path
+ # create material
+ cfg.physics_material.func(material_path,cfg.physics_material)
+ # apply material
+ bind_physics_material(mesh_prim_path,material_path)
+
+ # note: we apply rigid properties in the end to later make the instanceable prim
+ # apply mass properties
+ ifcfg.mass_propsisnotNone:
+ schemas.define_mass_properties(prim_path,cfg.mass_props)
+ # apply rigid body properties
+ ifcfg.rigid_propsisnotNone:
+ schemas.define_rigid_body_properties(prim_path,cfg.rigid_props)
+
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.sim.spawnersimportmaterials
+fromomni.isaac.lab.sim.spawners.spawner_cfgimportRigidObjectSpawnerCfg
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.importshapes
+
+
+
[文档]@configclass
+classShapeCfg(RigidObjectSpawnerCfg):
+"""Configuration parameters for a USD Geometry or Geom prim."""
+
+ visual_material_path:str="material"
+"""Path to the visual material to use for the prim. Defaults to "material".
+
+ If the path is relative, then it will be relative to the prim's path.
+ This parameter is ignored if `visual_material` is not None.
+ """
+ visual_material:materials.VisualMaterialCfg|None=None
+"""Visual material properties.
+
+ Note:
+ If None, then no visual material will be added.
+ """
+
+ physics_material_path:str="material"
+"""Path to the physics material to use for the prim. Defaults to "material".
+
+ If the path is relative, then it will be relative to the prim's path.
+ This parameter is ignored if `physics_material` is not None.
+ """
+ physics_material:materials.PhysicsMaterialCfg|None=None
+"""Physics material properties.
+
+ Note:
+ If None, then no physics material will be added.
+ """
+
+
+
[文档]@configclass
+classSphereCfg(ShapeCfg):
+"""Configuration parameters for a sphere prim.
+
+ See :meth:`spawn_sphere` for more information.
+ """
+
+ func:Callable=shapes.spawn_sphere
+
+ radius:float=MISSING
+"""Radius of the sphere (in m)."""
+
+
+
[文档]@configclass
+classCuboidCfg(ShapeCfg):
+"""Configuration parameters for a cuboid prim.
+
+ See :meth:`spawn_cuboid` for more information.
+ """
+
+ func:Callable=shapes.spawn_cuboid
+
+ size:tuple[float,float,float]=MISSING
+"""Size of the cuboid."""
+
+
+
[文档]@configclass
+classCylinderCfg(ShapeCfg):
+"""Configuration parameters for a cylinder prim.
+
+ See :meth:`spawn_cylinder` for more information.
+ """
+
+ func:Callable=shapes.spawn_cylinder
+
+ radius:float=MISSING
+"""Radius of the cylinder (in m)."""
+ height:float=MISSING
+"""Height of the cylinder (in m)."""
+ axis:Literal["X","Y","Z"]="Z"
+"""Axis of the cylinder. Defaults to "Z"."""
+
+
+
[文档]@configclass
+classCapsuleCfg(ShapeCfg):
+"""Configuration parameters for a capsule prim.
+
+ See :meth:`spawn_capsule` for more information.
+ """
+
+ func:Callable=shapes.spawn_capsule
+
+ radius:float=MISSING
+"""Radius of the capsule (in m)."""
+ height:float=MISSING
+"""Height of the capsule (in m)."""
+ axis:Literal["X","Y","Z"]="Z"
+"""Axis of the capsule. Defaults to "Z"."""
+
+
+
[文档]@configclass
+classConeCfg(ShapeCfg):
+"""Configuration parameters for a cone prim.
+
+ See :meth:`spawn_cone` for more information.
+ """
+
+ func:Callable=shapes.spawn_cone
+
+ radius:float=MISSING
+"""Radius of the cone (in m)."""
+ height:float=MISSING
+"""Height of the v (in m)."""
+ axis:Literal["X","Y","Z"]="Z"
+"""Axis of the cone. Defaults to "Z"."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+
+frompxrimportUsd
+
+fromomni.isaac.lab.simimportschemas
+fromomni.isaac.lab.utilsimportconfigclass
+
+
+
[文档]@configclass
+classSpawnerCfg:
+"""Configuration parameters for spawning an asset.
+
+ Spawning an asset is done by calling the :attr:`func` function. The function takes in the
+ prim path to spawn the asset at, the configuration instance and transformation, and returns the
+ prim path of the spawned asset.
+
+ The function is typically decorated with :func:`omni.isaac.lab.sim.spawner.utils.clone` decorator
+ that checks if input prim path is a regex expression and spawns the asset at all matching prims.
+ For this, the decorator uses the Cloner API from Isaac Sim and handles the :attr:`copy_from_source`
+ parameter.
+ """
+
+ func:Callable[...,Usd.Prim]=MISSING
+"""Function to use for spawning the asset.
+
+ The function takes in the prim path (or expression) to spawn the asset at, the configuration instance
+ and transformation, and returns the source prim spawned.
+ """
+
+ visible:bool=True
+"""Whether the spawned asset should be visible. Defaults to True."""
+
+ semantic_tags:list[tuple[str,str]]|None=None
+"""List of semantic tags to add to the spawned asset. Defaults to None,
+ which means no semantic tags will be added.
+
+ The semantic tags follow the `Replicator Semantic` tagging system. Each tag is a tuple of the
+ form ``(type, data)``, where ``type`` is the type of the tag and ``data`` is the semantic label
+ associated with the tag. For example, to annotate a spawned asset in the class avocado, the semantic
+ tag would be ``[("class", "avocado")]``.
+
+ You can specify multiple semantic tags by passing in a list of tags. For example, to annotate a
+ spawned asset in the class avocado and the color green, the semantic tags would be
+ ``[("class", "avocado"), ("color", "green")]``.
+
+ .. seealso::
+
+ For more information on the semantics filter, see the documentation for the `semantics schema editor`_.
+
+ .. _semantics schema editor: https://docs.omniverse.nvidia.com/extensions/latest/ext_replicator/semantics_schema_editor.html#semantics-filtering
+
+ """
+
+ copy_from_source:bool=True
+"""Whether to copy the asset from the source prim or inherit it. Defaults to True.
+
+ This parameter is only used when cloning prims. If False, then the asset will be inherited from
+ the source prim, i.e. all USD changes to the source prim will be reflected in the cloned prims.
+
+ .. versionadded:: 2023.1
+
+ This parameter is only supported from Isaac Sim 2023.1 onwards. If you are using an older
+ version of Isaac Sim, this parameter will be ignored.
+ """
+
+
+
[文档]@configclass
+classRigidObjectSpawnerCfg(SpawnerCfg):
+"""Configuration parameters for spawning a rigid asset.
+
+ Note:
+ By default, all properties are set to None. This means that no properties will be added or modified
+ to the prim outside of the properties available by default when spawning the prim.
+ """
+
+ mass_props:schemas.MassPropertiesCfg|None=None
+"""Mass properties."""
+
+ rigid_props:schemas.RigidBodyPropertiesCfg|None=None
+"""Rigid body properties.
+
+ For making a rigid object static, set the :attr:`schemas.RigidBodyPropertiesCfg.kinematic_enabled`
+ as True. This will make the object static and will not be affected by gravity or other forces.
+ """
+
+ collision_props:schemas.CollisionPropertiesCfg|None=None
+"""Properties to apply to all collision meshes."""
+
+ activate_contact_sensors:bool=False
+"""Activate contact reporting on all rigid bodies. Defaults to False.
+
+ This adds the PhysxContactReporter API to all the rigid bodies in the given prim path and its children.
+ """
+
+
+
[文档]@configclass
+classDeformableObjectSpawnerCfg(SpawnerCfg):
+"""Configuration parameters for spawning a deformable asset.
+
+ Unlike rigid objects, deformable objects are affected by forces and can deform when subjected to
+ external forces. This class is used to configure the properties of the deformable object.
+
+ Deformable bodies don't have a separate collision mesh. The collision mesh is the same as the visual mesh.
+ The collision properties such as rest and collision offsets are specified in the :attr:`deformable_props`.
+
+ Note:
+ By default, all properties are set to None. This means that no properties will be added or modified
+ to the prim outside of the properties available by default when spawning the prim.
+ """
+
+ mass_props:schemas.MassPropertiesCfg|None=None
+"""Mass properties."""
+
+ deformable_props:schemas.DeformableBodyPropertiesCfg|None=None
+"""Deformable body properties."""
[文档]defsafe_set_attribute_on_usd_schema(schema_api:Usd.APISchemaBase,name:str,value:Any,camel_case:bool):
+"""Set the value of an attribute on its USD schema if it exists.
+
+ A USD API schema serves as an interface or API for authoring and extracting a set of attributes.
+ They typically derive from the :class:`pxr.Usd.SchemaBase` class. This function checks if the
+ attribute exists on the schema and sets the value of the attribute if it exists.
+
+ Args:
+ schema_api: The USD schema to set the attribute on.
+ name: The name of the attribute.
+ value: The value to set the attribute to.
+ camel_case: Whether to convert the attribute name to camel case.
+
+ Raises:
+ TypeError: When the input attribute name does not exist on the provided schema API.
+ """
+ # if value is None, do nothing
+ ifvalueisNone:
+ return
+ # convert attribute name to camel case
+ ifcamel_case:
+ attr_name=to_camel_case(name,to="CC")
+ else:
+ attr_name=name
+ # retrieve the attribute
+ # reference: https://openusd.org/dev/api/_usd__page__common_idioms.html#Usd_Create_Or_Get_Property
+ attr=getattr(schema_api,f"Create{attr_name}Attr",None)
+ # check if attribute exists
+ ifattrisnotNone:
+ attr().Set(value)
+ else:
+ # think: do we ever need to create the attribute if it doesn't exist?
+ # currently, we are not doing this since the schemas are already created with some defaults.
+ carb.log_error(f"Attribute '{attr_name}' does not exist on prim '{schema_api.GetPath()}'.")
+ raiseTypeError(f"Attribute '{attr_name}' does not exist on prim '{schema_api.GetPath()}'.")
+
+
+
[文档]defsafe_set_attribute_on_usd_prim(prim:Usd.Prim,attr_name:str,value:Any,camel_case:bool):
+"""Set the value of a attribute on its USD prim.
+
+ The function creates a new attribute if it does not exist on the prim. This is because in some cases (such
+ as with shaders), their attributes are not exposed as USD prim properties that can be altered. This function
+ allows us to set the value of the attributes in these cases.
+
+ Args:
+ prim: The USD prim to set the attribute on.
+ attr_name: The name of the attribute.
+ value: The value to set the attribute to.
+ camel_case: Whether to convert the attribute name to camel case.
+ """
+ # if value is None, do nothing
+ ifvalueisNone:
+ return
+ # convert attribute name to camel case
+ ifcamel_case:
+ attr_name=to_camel_case(attr_name,to="cC")
+ # resolve sdf type based on value
+ ifisinstance(value,bool):
+ sdf_type=Sdf.ValueTypeNames.Bool
+ elifisinstance(value,int):
+ sdf_type=Sdf.ValueTypeNames.Int
+ elifisinstance(value,float):
+ sdf_type=Sdf.ValueTypeNames.Float
+ elifisinstance(value,(tuple,list))andlen(value)==3andany(isinstance(v,float)forvinvalue):
+ sdf_type=Sdf.ValueTypeNames.Float3
+ elifisinstance(value,(tuple,list))andlen(value)==2andany(isinstance(v,float)forvinvalue):
+ sdf_type=Sdf.ValueTypeNames.Float2
+ else:
+ raiseNotImplementedError(
+ f"Cannot set attribute '{attr_name}' with value '{value}'. Please modify the code to support this type."
+ )
+ # change property
+ omni.kit.commands.execute(
+ "ChangePropertyCommand",
+ prop_path=Sdf.Path(f"{prim.GetPath()}.{attr_name}"),
+ value=value,
+ prev=None,
+ type_to_create_if_not_exist=sdf_type,
+ usd_context_name=prim.GetStage(),
+ )
+
+
+"""
+Decorators.
+"""
+
+
+
[文档]defapply_nested(func:Callable)->Callable:
+"""Decorator to apply a function to all prims under a specified prim-path.
+
+ The function iterates over the provided prim path and all its children to apply input function
+ to all prims under the specified prim path.
+
+ If the function succeeds to apply to a prim, it will not look at the children of that prim.
+ This is based on the physics behavior that nested schemas are not allowed. For example, a parent prim
+ and its child prim cannot both have a rigid-body schema applied on them, or it is not possible to
+ have nested articulations.
+
+ While traversing the prims under the specified prim path, the function will throw a warning if it
+ does not succeed to apply the function to any prim. This is because the user may have intended to
+ apply the function to a prim that does not have valid attributes, or the prim may be an instanced prim.
+
+ Args:
+ func: The function to apply to all prims under a specified prim-path. The function
+ must take the prim-path and other arguments. It should return a boolean indicating whether
+ the function succeeded or not.
+
+ Returns:
+ The wrapped function that applies the function to all prims under a specified prim-path.
+
+ Raises:
+ ValueError: If the prim-path does not exist on the stage.
+ """
+
+ @functools.wraps(func)
+ defwrapper(prim_path:str|Sdf.Path,*args,**kwargs):
+ # map args and kwargs to function signature so we can get the stage
+ # note: we do this to check if stage is given in arg or kwarg
+ sig=inspect.signature(func)
+ bound_args=sig.bind(prim_path,*args,**kwargs)
+ # get current stage
+ stage=bound_args.arguments.get("stage")
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get USD prim
+ prim:Usd.Prim=stage.GetPrimAtPath(prim_path)
+ # check if prim is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim at path '{prim_path}' is not valid.")
+ # add iterable to check if property was applied on any of the prims
+ count_success=0
+ instanced_prim_paths=[]
+ # iterate over all prims under prim-path
+ all_prims=[prim]
+ whilelen(all_prims)>0:
+ # get current prim
+ child_prim=all_prims.pop(0)
+ child_prim_path=child_prim.GetPath().pathString# type: ignore
+ # check if prim is a prototype
+ ifchild_prim.IsInstance():
+ instanced_prim_paths.append(child_prim_path)
+ continue
+ # set properties
+ success=func(child_prim_path,*args,**kwargs)
+ # if successful, do not look at children
+ # this is based on the physics behavior that nested schemas are not allowed
+ ifnotsuccess:
+ all_prims+=child_prim.GetChildren()
+ else:
+ count_success+=1
+ # check if we were successful in applying the function to any prim
+ ifcount_success==0:
+ carb.log_warn(
+ f"Could not perform '{func.__name__}' on any prims under: '{prim_path}'."
+ " This might be because of the following reasons:"
+ "\n\t(1) The desired attribute does not exist on any of the prims."
+ "\n\t(2) The desired attribute exists on an instanced prim."
+ f"\n\t\tDiscovered list of instanced prim paths: {instanced_prim_paths}"
+ )
+
+ returnwrapper
+
+
+
[文档]defclone(func:Callable)->Callable:
+"""Decorator for cloning a prim based on matching prim paths of the prim's parent.
+
+ The decorator checks if the parent prim path matches any prim paths in the stage. If so, it clones the
+ spawned prim at each matching prim path. For example, if the input prim path is: ``/World/Table_[0-9]/Bottle``,
+ the decorator will clone the prim at each matching prim path of the parent prim: ``/World/Table_0/Bottle``,
+ ``/World/Table_1/Bottle``, etc.
+
+ Note:
+ For matching prim paths, the decorator assumes that valid prims exist for all matching prim paths.
+ In case no matching prim paths are found, the decorator raises a ``RuntimeError``.
+
+ Args:
+ func: The function to decorate.
+
+ Returns:
+ The decorated function that spawns the prim and clones it at each matching prim path.
+ It returns the spawned source prim, i.e., the first prim in the list of matching prim paths.
+ """
+
+ @functools.wraps(func)
+ defwrapper(prim_path:str|Sdf.Path,cfg:SpawnerCfg,*args,**kwargs):
+ # cast prim_path to str type in case its an Sdf.Path
+ prim_path=str(prim_path)
+ # check prim path is global
+ ifnotprim_path.startswith("/"):
+ raiseValueError(f"Prim path '{prim_path}' is not global. It must start with '/'.")
+ # resolve: {SPAWN_NS}/AssetName
+ # note: this assumes that the spawn namespace already exists in the stage
+ root_path,asset_path=prim_path.rsplit("/",1)
+ # check if input is a regex expression
+ # note: a valid prim path can only contain alphanumeric characters, underscores, and forward slashes
+ is_regex_expression=re.match(r"^[a-zA-Z0-9/_]+$",root_path)isNone
+
+ # resolve matching prims for source prim path expression
+ ifis_regex_expressionandroot_path!="":
+ source_prim_paths=find_matching_prim_paths(root_path)
+ # if no matching prims are found, raise an error
+ iflen(source_prim_paths)==0:
+ raiseRuntimeError(
+ f"Unable to find source prim path: '{root_path}'. Please create the prim before spawning."
+ )
+ else:
+ source_prim_paths=[root_path]
+
+ # resolve prim paths for spawning and cloning
+ prim_paths=[f"{source_prim_path}/{asset_path}"forsource_prim_pathinsource_prim_paths]
+ # spawn single instance
+ prim=func(prim_paths[0],cfg,*args,**kwargs)
+ # set the prim visibility
+ ifhasattr(cfg,"visible"):
+ imageable=UsdGeom.Imageable(prim)
+ ifcfg.visible:
+ imageable.MakeVisible()
+ else:
+ imageable.MakeInvisible()
+ # set the semantic annotations
+ ifhasattr(cfg,"semantic_tags")andcfg.semantic_tagsisnotNone:
+ # note: taken from replicator scripts.utils.utils.py
+ forsemantic_type,semantic_valueincfg.semantic_tags:
+ # deal with spaces by replacing them with underscores
+ semantic_type_sanitized=semantic_type.replace(" ","_")
+ semantic_value_sanitized=semantic_value.replace(" ","_")
+ # set the semantic API for the instance
+ instance_name=f"{semantic_type_sanitized}_{semantic_value_sanitized}"
+ sem=Semantics.SemanticsAPI.Apply(prim,instance_name)
+ # create semantic type and data attributes
+ sem.CreateSemanticTypeAttr()
+ sem.CreateSemanticDataAttr()
+ sem.GetSemanticTypeAttr().Set(semantic_type)
+ sem.GetSemanticDataAttr().Set(semantic_value)
+ # activate rigid body contact sensors
+ ifhasattr(cfg,"activate_contact_sensors")andcfg.activate_contact_sensors:
+ schemas.activate_contact_sensors(prim_paths[0],cfg.activate_contact_sensors)
+ # clone asset using cloner API
+ iflen(prim_paths)>1:
+ cloner=Cloner()
+ # clone the prim
+ cloner.clone(prim_paths[0],prim_paths[1:],replicate_physics=False,copy_from_source=cfg.copy_from_source)
+ # return the source prim
+ returnprim
+
+ returnwrapper
+
+
+"""
+Material bindings.
+"""
+
+
+
[文档]@apply_nested
+defbind_visual_material(
+ prim_path:str|Sdf.Path,
+ material_path:str|Sdf.Path,
+ stage:Usd.Stage|None=None,
+ stronger_than_descendants:bool=True,
+):
+"""Bind a visual material to a prim.
+
+ This function is a wrapper around the USD command `BindMaterialCommand`_.
+
+ .. note::
+ The function is decorated with :meth:`apply_nested` to allow applying the function to a prim path
+ and all its descendants.
+
+ .. _BindMaterialCommand: https://docs.omniverse.nvidia.com/kit/docs/omni.usd/latest/omni.usd.commands/omni.usd.commands.BindMaterialCommand.html
+
+ Args:
+ prim_path: The prim path where to apply the material.
+ material_path: The prim path of the material to apply.
+ stage: The stage where the prim and material exist.
+ Defaults to None, in which case the current stage is used.
+ stronger_than_descendants: Whether the material should override the material of its descendants.
+ Defaults to True.
+
+ Raises:
+ ValueError: If the provided prim paths do not exist on stage.
+ """
+ # resolve stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # check if prim and material exists
+ ifnotstage.GetPrimAtPath(prim_path).IsValid():
+ raiseValueError(f"Target prim '{material_path}' does not exist.")
+ ifnotstage.GetPrimAtPath(material_path).IsValid():
+ raiseValueError(f"Visual material '{material_path}' does not exist.")
+
+ # resolve token for weaker than descendants
+ ifstronger_than_descendants:
+ binding_strength="strongerThanDescendants"
+ else:
+ binding_strength="weakerThanDescendants"
+ # obtain material binding API
+ # note: we prefer using the command here as it is more robust than the USD API
+ success,_=omni.kit.commands.execute(
+ "BindMaterialCommand",
+ prim_path=prim_path,
+ material_path=material_path,
+ strength=binding_strength,
+ stage=stage,
+ )
+ # return success
+ returnsuccess
+
+
+
[文档]@apply_nested
+defbind_physics_material(
+ prim_path:str|Sdf.Path,
+ material_path:str|Sdf.Path,
+ stage:Usd.Stage|None=None,
+ stronger_than_descendants:bool=True,
+):
+"""Bind a physics material to a prim.
+
+ `Physics material`_ can be applied only to a prim with physics-enabled on them. This includes having
+ collision APIs, or deformable body APIs, or being a particle system. In case the prim does not have
+ any of these APIs, the function will not apply the material and return False.
+
+ .. note::
+ The function is decorated with :meth:`apply_nested` to allow applying the function to a prim path
+ and all its descendants.
+
+ .. _Physics material: https://docs.omniverse.nvidia.com/extensions/latest/ext_physics/simulation-control/physics-settings.html#physics-materials
+
+ Args:
+ prim_path: The prim path where to apply the material.
+ material_path: The prim path of the material to apply.
+ stage: The stage where the prim and material exist.
+ Defaults to None, in which case the current stage is used.
+ stronger_than_descendants: Whether the material should override the material of its descendants.
+ Defaults to True.
+
+ Raises:
+ ValueError: If the provided prim paths do not exist on stage.
+ """
+ # resolve stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # check if prim and material exists
+ ifnotstage.GetPrimAtPath(prim_path).IsValid():
+ raiseValueError(f"Target prim '{material_path}' does not exist.")
+ ifnotstage.GetPrimAtPath(material_path).IsValid():
+ raiseValueError(f"Physics material '{material_path}' does not exist.")
+ # get USD prim
+ prim=stage.GetPrimAtPath(prim_path)
+ # check if prim has collision applied on it
+ has_physics_scene_api=prim.HasAPI(PhysxSchema.PhysxSceneAPI)
+ has_collider=prim.HasAPI(UsdPhysics.CollisionAPI)
+ has_deformable_body=prim.HasAPI(PhysxSchema.PhysxDeformableBodyAPI)
+ has_particle_system=prim.IsA(PhysxSchema.PhysxParticleSystem)
+ ifnot(has_physics_scene_apiorhas_colliderorhas_deformable_bodyorhas_particle_system):
+ carb.log_verbose(
+ f"Cannot apply physics material '{material_path}' on prim '{prim_path}'. It is neither a"
+ " PhysX scene, collider, a deformable body, nor a particle system."
+ )
+ returnFalse
+
+ # obtain material binding API
+ ifprim.HasAPI(UsdShade.MaterialBindingAPI):
+ material_binding_api=UsdShade.MaterialBindingAPI(prim)
+ else:
+ material_binding_api=UsdShade.MaterialBindingAPI.Apply(prim)
+ # obtain the material prim
+ material=UsdShade.Material(stage.GetPrimAtPath(material_path))
+ # resolve token for weaker than descendants
+ ifstronger_than_descendants:
+ binding_strength=UsdShade.Tokens.strongerThanDescendants
+ else:
+ binding_strength=UsdShade.Tokens.weakerThanDescendants
+ # apply the material
+ material_binding_api.Bind(material,bindingStrength=binding_strength,materialPurpose="physics")# type: ignore
+ # return success
+ returnTrue
+
+
+"""
+Exporting.
+"""
+
+
+
[文档]defexport_prim_to_file(
+ path:str|Sdf.Path,
+ source_prim_path:str|Sdf.Path,
+ target_prim_path:str|Sdf.Path|None=None,
+ stage:Usd.Stage|None=None,
+):
+"""Exports a prim from a given stage to a USD file.
+
+ The function creates a new layer at the provided path and copies the prim to the layer.
+ It sets the copied prim as the default prim in the target layer. Additionally, it updates
+ the stage up-axis and meters-per-unit to match the current stage.
+
+ Args:
+ path: The filepath path to export the prim to.
+ source_prim_path: The prim path to export.
+ target_prim_path: The prim path to set as the default prim in the target layer.
+ Defaults to None, in which case the source prim path is used.
+ stage: The stage where the prim exists. Defaults to None, in which case the
+ current stage is used.
+
+ Raises:
+ ValueError: If the prim paths are not global (i.e: do not start with '/').
+ """
+ # automatically casting to str in case args
+ # are path types
+ path=str(path)
+ source_prim_path=str(source_prim_path)
+ iftarget_prim_pathisnotNone:
+ target_prim_path=str(target_prim_path)
+
+ ifnotsource_prim_path.startswith("/"):
+ raiseValueError(f"Source prim path '{source_prim_path}' is not global. It must start with '/'.")
+ iftarget_prim_pathisnotNoneandnottarget_prim_path.startswith("/"):
+ raiseValueError(f"Target prim path '{target_prim_path}' is not global. It must start with '/'.")
+ # get current stage
+ ifstageisNone:
+ stage:Usd.Stage=omni.usd.get_context().get_stage()
+ # get root layer
+ source_layer=stage.GetRootLayer()
+
+ # only create a new layer if it doesn't exist already
+ target_layer=Sdf.Find(path)
+ iftarget_layerisNone:
+ target_layer=Sdf.Layer.CreateNew(path)
+ # open the target stage
+ target_stage=Usd.Stage.Open(target_layer)
+
+ # update stage data
+ UsdGeom.SetStageUpAxis(target_stage,UsdGeom.GetStageUpAxis(stage))
+ UsdGeom.SetStageMetersPerUnit(target_stage,UsdGeom.GetStageMetersPerUnit(stage))
+
+ # specify the prim to copy
+ source_prim_path=Sdf.Path(source_prim_path)
+ iftarget_prim_pathisNone:
+ target_prim_path=source_prim_path
+
+ # copy the prim
+ Sdf.CreatePrimInLayer(target_layer,target_prim_path)
+ Sdf.CopySpec(source_layer,source_prim_path,target_layer,target_prim_path)
+ # set the default prim
+ target_layer.defaultPrim=Sdf.Path(target_prim_path).name
+ # resolve all paths relative to layer path
+ omni.usd.resolve_paths(source_layer.identifier,target_layer.identifier)
+ # save the stage
+ target_layer.Save()
+
+
+"""
+USD Prim properties.
+"""
+
+
+
[文档]defmake_uninstanceable(prim_path:str|Sdf.Path,stage:Usd.Stage|None=None):
+"""Check if a prim and its descendants are instanced and make them uninstanceable.
+
+ This function checks if the prim at the specified prim path and its descendants are instanced.
+ If so, it makes the respective prim uninstanceable by disabling instancing on the prim.
+
+ This is useful when we want to modify the properties of a prim that is instanced. For example, if we
+ want to apply a different material on an instanced prim, we need to make the prim uninstanceable first.
+
+ Args:
+ prim_path: The prim path to check.
+ stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
+
+ Raises:
+ ValueError: If the prim path is not global (i.e: does not start with '/').
+ """
+ # make paths str type if they aren't already
+ prim_path=str(prim_path)
+ # check if prim path is global
+ ifnotprim_path.startswith("/"):
+ raiseValueError(f"Prim path '{prim_path}' is not global. It must start with '/'.")
+ # get current stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get prim
+ prim:Usd.Prim=stage.GetPrimAtPath(prim_path)
+ # check if prim is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim at path '{prim_path}' is not valid.")
+ # iterate over all prims under prim-path
+ all_prims=[prim]
+ whilelen(all_prims)>0:
+ # get current prim
+ child_prim=all_prims.pop(0)
+ # check if prim is instanced
+ ifchild_prim.IsInstance():
+ # make the prim uninstanceable
+ child_prim.SetInstanceable(False)
+ # add children to list
+ all_prims+=child_prim.GetChildren()
+
+
+"""
+USD Stage traversal.
+"""
+
+
+
[文档]defget_first_matching_child_prim(
+ prim_path:str|Sdf.Path,predicate:Callable[[Usd.Prim],bool],stage:Usd.Stage|None=None
+)->Usd.Prim|None:
+"""Recursively get the first USD Prim at the path string that passes the predicate function
+
+ Args:
+ prim_path: The path of the prim in the stage.
+ predicate: The function to test the prims against. It takes a prim as input and returns a boolean.
+ stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
+
+ Returns:
+ The first prim on the path that passes the predicate. If no prim passes the predicate, it returns None.
+
+ Raises:
+ ValueError: If the prim path is not global (i.e: does not start with '/').
+ """
+ # make paths str type if they aren't already
+ prim_path=str(prim_path)
+ # check if prim path is global
+ ifnotprim_path.startswith("/"):
+ raiseValueError(f"Prim path '{prim_path}' is not global. It must start with '/'.")
+ # get current stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get prim
+ prim=stage.GetPrimAtPath(prim_path)
+ # check if prim is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim at path '{prim_path}' is not valid.")
+ # iterate over all prims under prim-path
+ all_prims=[prim]
+ whilelen(all_prims)>0:
+ # get current prim
+ child_prim=all_prims.pop(0)
+ # check if prim passes predicate
+ ifpredicate(child_prim):
+ returnchild_prim
+ # add children to list
+ all_prims+=child_prim.GetChildren()
+ returnNone
+
+
+
[文档]defget_all_matching_child_prims(
+ prim_path:str|Sdf.Path,
+ predicate:Callable[[Usd.Prim],bool]=lambda_:True,
+ depth:int|None=None,
+ stage:Usd.Stage|None=None,
+)->list[Usd.Prim]:
+"""Performs a search starting from the root and returns all the prims matching the predicate.
+
+ Args:
+ prim_path: The root prim path to start the search from.
+ predicate: The predicate that checks if the prim matches the desired criteria. It takes a prim as input
+ and returns a boolean. Defaults to a function that always returns True.
+ depth: The maximum depth for traversal, should be bigger than zero if specified.
+ Defaults to None (i.e: traversal happens till the end of the tree).
+ stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
+
+ Returns:
+ A list containing all the prims matching the predicate.
+
+ Raises:
+ ValueError: If the prim path is not global (i.e: does not start with '/').
+ """
+ # make paths str type if they aren't already
+ prim_path=str(prim_path)
+ # check if prim path is global
+ ifnotprim_path.startswith("/"):
+ raiseValueError(f"Prim path '{prim_path}' is not global. It must start with '/'.")
+ # get current stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # get prim
+ prim=stage.GetPrimAtPath(prim_path)
+ # check if prim is valid
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim at path '{prim_path}' is not valid.")
+ # check if depth is valid
+ ifdepthisnotNoneanddepth<=0:
+ raiseValueError(f"Depth must be bigger than zero, got {depth}.")
+
+ # iterate over all prims under prim-path
+ # list of tuples (prim, current_depth)
+ all_prims_queue=[(prim,0)]
+ output_prims=[]
+ whilelen(all_prims_queue)>0:
+ # get current prim
+ child_prim,current_depth=all_prims_queue.pop(0)
+ # check if prim passes predicate
+ ifpredicate(child_prim):
+ output_prims.append(child_prim)
+ # add children to list
+ ifdepthisNoneorcurrent_depth<depth:
+ all_prims_queue+=[(child,current_depth+1)forchildinchild_prim.GetChildren()]
+
+ returnoutput_prims
+
+
+
[文档]deffind_first_matching_prim(prim_path_regex:str,stage:Usd.Stage|None=None)->Usd.Prim|None:
+"""Find the first matching prim in the stage based on input regex expression.
+
+ Args:
+ prim_path_regex: The regex expression for prim path.
+ stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
+
+ Returns:
+ The first prim that matches input expression. If no prim matches, returns None.
+
+ Raises:
+ ValueError: If the prim path is not global (i.e: does not start with '/').
+ """
+ # check prim path is global
+ ifnotprim_path_regex.startswith("/"):
+ raiseValueError(f"Prim path '{prim_path_regex}' is not global. It must start with '/'.")
+ # get current stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # need to wrap the token patterns in '^' and '$' to prevent matching anywhere in the string
+ pattern=f"^{prim_path_regex}$"
+ compiled_pattern=re.compile(pattern)
+ # obtain matching prim (depth-first search)
+ forpriminstage.Traverse():
+ # check if prim passes predicate
+ ifcompiled_pattern.match(prim.GetPath().pathString)isnotNone:
+ returnprim
+ returnNone
+
+
+
[文档]deffind_matching_prims(prim_path_regex:str,stage:Usd.Stage|None=None)->list[Usd.Prim]:
+"""Find all the matching prims in the stage based on input regex expression.
+
+ Args:
+ prim_path_regex: The regex expression for prim path.
+ stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
+
+ Returns:
+ A list of prims that match input expression.
+
+ Raises:
+ ValueError: If the prim path is not global (i.e: does not start with '/').
+ """
+ # check prim path is global
+ ifnotprim_path_regex.startswith("/"):
+ raiseValueError(f"Prim path '{prim_path_regex}' is not global. It must start with '/'.")
+ # get current stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # need to wrap the token patterns in '^' and '$' to prevent matching anywhere in the string
+ tokens=prim_path_regex.split("/")[1:]
+ tokens=[f"^{token}$"fortokenintokens]
+ # iterate over all prims in stage (breath-first search)
+ all_prims=[stage.GetPseudoRoot()]
+ output_prims=[]
+ forindex,tokeninenumerate(tokens):
+ token_compiled=re.compile(token)
+ forpriminall_prims:
+ forchildinprim.GetAllChildren():
+ iftoken_compiled.match(child.GetName())isnotNone:
+ output_prims.append(child)
+ ifindex<len(tokens)-1:
+ all_prims=output_prims
+ output_prims=[]
+ returnoutput_prims
+
+
+
[文档]deffind_matching_prim_paths(prim_path_regex:str,stage:Usd.Stage|None=None)->list[str]:
+"""Find all the matching prim paths in the stage based on input regex expression.
+
+ Args:
+ prim_path_regex: The regex expression for prim path.
+ stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
+
+ Returns:
+ A list of prim paths that match input expression.
+
+ Raises:
+ ValueError: If the prim path is not global (i.e: does not start with '/').
+ """
+ # obtain matching prims
+ output_prims=find_matching_prims(prim_path_regex,stage)
+ # convert prims to prim paths
+ output_prim_paths=[]
+ forpriminoutput_prims:
+ output_prim_paths.append(prim.GetPath().pathString)
+ returnoutput_prim_paths
+
+
+
[文档]deffind_global_fixed_joint_prim(
+ prim_path:str|Sdf.Path,check_enabled_only:bool=False,stage:Usd.Stage|None=None
+)->UsdPhysics.Joint|None:
+"""Find the fixed joint prim under the specified prim path that connects the target to the simulation world.
+
+ A joint is a connection between two bodies. A fixed joint is a joint that does not allow relative motion
+ between the two bodies. When a fixed joint has only one target body, it is considered to attach the body
+ to the simulation world.
+
+ This function finds the fixed joint prim that has only one target under the specified prim path. If no such
+ fixed joint prim exists, it returns None.
+
+ Args:
+ prim_path: The prim path to search for the fixed joint prim.
+ check_enabled_only: Whether to consider only enabled fixed joints. Defaults to False.
+ If False, then all joints (enabled or disabled) are considered.
+ stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
+
+ Returns:
+ The fixed joint prim that has only one target. If no such fixed joint prim exists, it returns None.
+
+ Raises:
+ ValueError: If the prim path is not global (i.e: does not start with '/').
+ ValueError: If the prim path does not exist on the stage.
+ """
+ # check prim path is global
+ ifnotprim_path.startswith("/"):
+ raiseValueError(f"Prim path '{prim_path}' is not global. It must start with '/'.")
+ # get current stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+
+ # check if prim exists
+ prim=stage.GetPrimAtPath(prim_path)
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim at path '{prim_path}' is not valid.")
+
+ fixed_joint_prim=None
+ # we check all joints under the root prim and classify the asset as fixed base if there exists
+ # a fixed joint that has only one target (i.e. the root link).
+ forpriminUsd.PrimRange(prim):
+ # note: ideally checking if it is FixedJoint would have been enough, but some assets use "Joint" as the
+ # schema name which makes it difficult to distinguish between the two.
+ joint_prim=UsdPhysics.Joint(prim)
+ ifjoint_prim:
+ # if check_enabled_only is True, we only consider enabled joints
+ ifcheck_enabled_onlyandnotjoint_prim.GetJointEnabledAttr().Get():
+ continue
+ # check body 0 and body 1 exist
+ body_0_exist=joint_prim.GetBody0Rel().GetTargets()!=[]
+ body_1_exist=joint_prim.GetBody1Rel().GetTargets()!=[]
+ # if either body 0 or body 1 does not exist, we have a fixed joint that connects to the world
+ ifnot(body_0_existandbody_1_exist):
+ fixed_joint_prim=joint_prim
+ break
+
+ returnfixed_joint_prim
+
+
+"""
+USD Variants.
+"""
+
+
+
[文档]defselect_usd_variants(prim_path:str,variants:object|dict[str,str],stage:Usd.Stage|None=None):
+"""Sets the variant selections from the specified variant sets on a USD prim.
+
+ `USD Variants`_ are a very powerful tool in USD composition that allows prims to have different options on
+ a single asset. This can be done by modifying variations of the same prim parameters per variant option in a set.
+ This function acts as a script-based utility to set the variant selections for the specified variant sets on a
+ USD prim.
+
+ The function takes a dictionary or a config class mapping variant set names to variant selections. For instance,
+ if we have a prim at ``"/World/Table"`` with two variant sets: "color" and "size", we can set the variant
+ selections as follows:
+
+ .. code-block:: python
+
+ select_usd_variants(
+ prim_path="/World/Table",
+ variants={
+ "color": "red",
+ "size": "large",
+ },
+ )
+
+ Alternatively, we can use a config class to define the variant selections:
+
+ .. code-block:: python
+
+ @configclass
+ class TableVariants:
+ color: Literal["blue", "red"] = "red"
+ size: Literal["small", "large"] = "large"
+
+ select_usd_variants(
+ prim_path="/World/Table",
+ variants=TableVariants(),
+ )
+
+ Args:
+ prim_path: The path of the USD prim.
+ variants: A dictionary or config class mapping variant set names to variant selections.
+ stage: The USD stage. Defaults to None, in which case, the current stage is used.
+
+ Raises:
+ ValueError: If the prim at the specified path is not valid.
+
+ .. _USD Variants: https://graphics.pixar.com/usd/docs/USD-Glossary.html#USDGlossary-Variant
+ """
+ # Resolve stage
+ ifstageisNone:
+ stage=stage_utils.get_current_stage()
+ # Obtain prim
+ prim=stage.GetPrimAtPath(prim_path)
+ ifnotprim.IsValid():
+ raiseValueError(f"Prim at path '{prim_path}' is not valid.")
+ # Convert to dict if we have a configclass object.
+ ifnotisinstance(variants,dict):
+ variants=variants.to_dict()
+
+ existing_variant_sets=prim.GetVariantSets()
+ forvariant_set_name,variant_selectioninvariants.items():
+ # Check if the variant set exists on the prim.
+ ifnotexisting_variant_sets.HasVariantSet(variant_set_name):
+ carb.log_warn(f"Variant set '{variant_set_name}' does not exist on prim '{prim_path}'.")
+ continue
+
+ variant_set=existing_variant_sets.GetVariantSet(variant_set_name)
+ # Only set the variant selection if it is different from the current selection.
+ ifvariant_set.GetVariantSelection()!=variant_selection:
+ variant_set.SetVariantSelection(variant_selection)
+ carb.log_info(
+ f"Setting variant selection '{variant_selection}' for variant set '{variant_set_name}' on"
+ f" prim '{prim_path}'."
+ )
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Functions to generate height fields for different terrains."""
+
+from__future__importannotations
+
+importnumpyasnp
+importscipy.interpolateasinterpolate
+fromtypingimportTYPE_CHECKING
+
+from.utilsimportheight_field_to_mesh
+
+ifTYPE_CHECKING:
+ from.importhf_terrains_cfg
+
+
+
[文档]@height_field_to_mesh
+defrandom_uniform_terrain(difficulty:float,cfg:hf_terrains_cfg.HfRandomUniformTerrainCfg)->np.ndarray:
+"""Generate a terrain with height sampled uniformly from a specified range.
+
+ .. image:: ../../_static/terrains/height_field/random_uniform_terrain.jpg
+ :width: 40%
+ :align: center
+
+ Note:
+ The :obj:`difficulty` parameter is ignored for this terrain.
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ The height field of the terrain as a 2D numpy array with discretized heights.
+ The shape of the array is (width, length), where width and length are the number of points
+ along the x and y axis, respectively.
+
+ Raises:
+ ValueError: When the downsampled scale is smaller than the horizontal scale.
+ """
+ # check parameters
+ # -- horizontal scale
+ ifcfg.downsampled_scaleisNone:
+ cfg.downsampled_scale=cfg.horizontal_scale
+ elifcfg.downsampled_scale<cfg.horizontal_scale:
+ raiseValueError(
+ "Downsampled scale must be larger than or equal to the horizontal scale:"
+ f" {cfg.downsampled_scale} < {cfg.horizontal_scale}."
+ )
+
+ # switch parameters to discrete units
+ # -- horizontal scale
+ width_pixels=int(cfg.size[0]/cfg.horizontal_scale)
+ length_pixels=int(cfg.size[1]/cfg.horizontal_scale)
+ # -- downsampled scale
+ width_downsampled=int(cfg.size[0]/cfg.downsampled_scale)
+ length_downsampled=int(cfg.size[1]/cfg.downsampled_scale)
+ # -- height
+ height_min=int(cfg.noise_range[0]/cfg.vertical_scale)
+ height_max=int(cfg.noise_range[1]/cfg.vertical_scale)
+ height_step=int(cfg.noise_step/cfg.vertical_scale)
+
+ # create range of heights possible
+ height_range=np.arange(height_min,height_max+height_step,height_step)
+ # sample heights randomly from the range along a grid
+ height_field_downsampled=np.random.choice(height_range,size=(width_downsampled,length_downsampled))
+ # create interpolation function for the sampled heights
+ x=np.linspace(0,cfg.size[0]*cfg.horizontal_scale,width_downsampled)
+ y=np.linspace(0,cfg.size[1]*cfg.horizontal_scale,length_downsampled)
+ func=interpolate.RectBivariateSpline(x,y,height_field_downsampled)
+
+ # interpolate the sampled heights to obtain the height field
+ x_upsampled=np.linspace(0,cfg.size[0]*cfg.horizontal_scale,width_pixels)
+ y_upsampled=np.linspace(0,cfg.size[1]*cfg.horizontal_scale,length_pixels)
+ z_upsampled=func(x_upsampled,y_upsampled)
+ # round off the interpolated heights to the nearest vertical step
+ returnnp.rint(z_upsampled).astype(np.int16)
+
+
+
[文档]@height_field_to_mesh
+defpyramid_sloped_terrain(difficulty:float,cfg:hf_terrains_cfg.HfPyramidSlopedTerrainCfg)->np.ndarray:
+"""Generate a terrain with a truncated pyramid structure.
+
+ The terrain is a pyramid-shaped sloped surface with a slope of :obj:`slope` that trims into a flat platform
+ at the center. The slope is defined as the ratio of the height change along the x axis to the width along the
+ x axis. For example, a slope of 1.0 means that the height changes by 1 unit for every 1 unit of width.
+
+ If the :obj:`cfg.inverted` flag is set to :obj:`True`, the terrain is inverted such that
+ the platform is at the bottom.
+
+ .. image:: ../../_static/terrains/height_field/pyramid_sloped_terrain.jpg
+ :width: 40%
+
+ .. image:: ../../_static/terrains/height_field/inverted_pyramid_sloped_terrain.jpg
+ :width: 40%
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ The height field of the terrain as a 2D numpy array with discretized heights.
+ The shape of the array is (width, length), where width and length are the number of points
+ along the x and y axis, respectively.
+ """
+ # resolve terrain configuration
+ ifcfg.inverted:
+ slope=-cfg.slope_range[0]-difficulty*(cfg.slope_range[1]-cfg.slope_range[0])
+ else:
+ slope=cfg.slope_range[0]+difficulty*(cfg.slope_range[1]-cfg.slope_range[0])
+
+ # switch parameters to discrete units
+ # -- horizontal scale
+ width_pixels=int(cfg.size[0]/cfg.horizontal_scale)
+ length_pixels=int(cfg.size[1]/cfg.horizontal_scale)
+ # -- height
+ # we want the height to be 1/2 of the width since the terrain is a pyramid
+ height_max=int(slope*cfg.size[0]/2/cfg.vertical_scale)
+ # -- center of the terrain
+ center_x=int(width_pixels/2)
+ center_y=int(length_pixels/2)
+
+ # create a meshgrid of the terrain
+ x=np.arange(0,width_pixels)
+ y=np.arange(0,length_pixels)
+ xx,yy=np.meshgrid(x,y,sparse=True)
+ # offset the meshgrid to the center of the terrain
+ xx=(center_x-np.abs(center_x-xx))/center_x
+ yy=(center_y-np.abs(center_y-yy))/center_y
+ # reshape the meshgrid to be 2D
+ xx=xx.reshape(width_pixels,1)
+ yy=yy.reshape(1,length_pixels)
+ # create a sloped surface
+ hf_raw=np.zeros((width_pixels,length_pixels))
+ hf_raw=height_max*xx*yy
+
+ # create a flat platform at the center of the terrain
+ platform_width=int(cfg.platform_width/cfg.horizontal_scale/2)
+ # get the height of the platform at the corner of the platform
+ x_pf=width_pixels//2-platform_width
+ y_pf=length_pixels//2-platform_width
+ z_pf=hf_raw[x_pf,y_pf]
+ hf_raw=np.clip(hf_raw,min(0,z_pf),max(0,z_pf))
+
+ # round off the heights to the nearest vertical step
+ returnnp.rint(hf_raw).astype(np.int16)
+
+
+
[文档]@height_field_to_mesh
+defpyramid_stairs_terrain(difficulty:float,cfg:hf_terrains_cfg.HfPyramidStairsTerrainCfg)->np.ndarray:
+"""Generate a terrain with a pyramid stair pattern.
+
+ The terrain is a pyramid stair pattern which trims to a flat platform at the center of the terrain.
+
+ If the :obj:`cfg.inverted` flag is set to :obj:`True`, the terrain is inverted such that
+ the platform is at the bottom.
+
+ .. image:: ../../_static/terrains/height_field/pyramid_stairs_terrain.jpg
+ :width: 40%
+
+ .. image:: ../../_static/terrains/height_field/inverted_pyramid_stairs_terrain.jpg
+ :width: 40%
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ The height field of the terrain as a 2D numpy array with discretized heights.
+ The shape of the array is (width, length), where width and length are the number of points
+ along the x and y axis, respectively.
+ """
+ # resolve terrain configuration
+ step_height=cfg.step_height_range[0]+difficulty*(cfg.step_height_range[1]-cfg.step_height_range[0])
+ ifcfg.inverted:
+ step_height*=-1
+ # switch parameters to discrete units
+ # -- terrain
+ width_pixels=int(cfg.size[0]/cfg.horizontal_scale)
+ length_pixels=int(cfg.size[1]/cfg.horizontal_scale)
+ # -- stairs
+ step_width=int(cfg.step_width/cfg.horizontal_scale)
+ step_height=int(step_height/cfg.vertical_scale)
+ # -- platform
+ platform_width=int(cfg.platform_width/cfg.horizontal_scale)
+
+ # create a terrain with a flat platform at the center
+ hf_raw=np.zeros((width_pixels,length_pixels))
+ # add the steps
+ current_step_height=0
+ start_x,start_y=0,0
+ stop_x,stop_y=width_pixels,length_pixels
+ while(stop_x-start_x)>platform_widthand(stop_y-start_y)>platform_width:
+ # increment position
+ # -- x
+ start_x+=step_width
+ stop_x-=step_width
+ # -- y
+ start_y+=step_width
+ stop_y-=step_width
+ # increment height
+ current_step_height+=step_height
+ # add the step
+ hf_raw[start_x:stop_x,start_y:stop_y]=current_step_height
+
+ # round off the heights to the nearest vertical step
+ returnnp.rint(hf_raw).astype(np.int16)
+
+
+
[文档]@height_field_to_mesh
+defdiscrete_obstacles_terrain(difficulty:float,cfg:hf_terrains_cfg.HfDiscreteObstaclesTerrainCfg)->np.ndarray:
+"""Generate a terrain with randomly generated obstacles as pillars with positive and negative heights.
+
+ The terrain is a flat platform at the center of the terrain with randomly generated obstacles as pillars
+ with positive and negative height. The obstacles are randomly generated cuboids with a random width and
+ height. They are placed randomly on the terrain with a minimum distance of :obj:`cfg.platform_width`
+ from the center of the terrain.
+
+ .. image:: ../../_static/terrains/height_field/discrete_obstacles_terrain.jpg
+ :width: 40%
+ :align: center
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ The height field of the terrain as a 2D numpy array with discretized heights.
+ The shape of the array is (width, length), where width and length are the number of points
+ along the x and y axis, respectively.
+ """
+ # resolve terrain configuration
+ obs_height=cfg.obstacle_height_range[0]+difficulty*(
+ cfg.obstacle_height_range[1]-cfg.obstacle_height_range[0]
+ )
+
+ # switch parameters to discrete units
+ # -- terrain
+ width_pixels=int(cfg.size[0]/cfg.horizontal_scale)
+ length_pixels=int(cfg.size[1]/cfg.horizontal_scale)
+ # -- obstacles
+ obs_height=int(obs_height/cfg.vertical_scale)
+ obs_width_min=int(cfg.obstacle_width_range[0]/cfg.horizontal_scale)
+ obs_width_max=int(cfg.obstacle_width_range[1]/cfg.horizontal_scale)
+ # -- center of the terrain
+ platform_width=int(cfg.platform_width/cfg.horizontal_scale)
+
+ # create discrete ranges for the obstacles
+ # -- shape
+ obs_width_range=np.arange(obs_width_min,obs_width_max,4)
+ obs_length_range=np.arange(obs_width_min,obs_width_max,4)
+ # -- position
+ obs_x_range=np.arange(0,width_pixels,4)
+ obs_y_range=np.arange(0,length_pixels,4)
+
+ # create a terrain with a flat platform at the center
+ hf_raw=np.zeros((width_pixels,length_pixels))
+ # generate the obstacles
+ for_inrange(cfg.num_obstacles):
+ # sample size
+ ifcfg.obstacle_height_mode=="choice":
+ height=np.random.choice([-obs_height,-obs_height//2,obs_height//2,obs_height])
+ elifcfg.obstacle_height_mode=="fixed":
+ height=obs_height
+ else:
+ raiseValueError(f"Unknown obstacle height mode '{cfg.obstacle_height_mode}'. Must be 'choice' or 'fixed'.")
+ width=int(np.random.choice(obs_width_range))
+ length=int(np.random.choice(obs_length_range))
+ # sample position
+ x_start=int(np.random.choice(obs_x_range))
+ y_start=int(np.random.choice(obs_y_range))
+ # clip start position to the terrain
+ ifx_start+width>width_pixels:
+ x_start=width_pixels-width
+ ify_start+length>length_pixels:
+ y_start=length_pixels-length
+ # add to terrain
+ hf_raw[x_start:x_start+width,y_start:y_start+length]=height
+ # clip the terrain to the platform
+ x1=(width_pixels-platform_width)//2
+ x2=(width_pixels+platform_width)//2
+ y1=(length_pixels-platform_width)//2
+ y2=(length_pixels+platform_width)//2
+ hf_raw[x1:x2,y1:y2]=0
+ # round off the heights to the nearest vertical step
+ returnnp.rint(hf_raw).astype(np.int16)
+
+
+
[文档]@height_field_to_mesh
+defwave_terrain(difficulty:float,cfg:hf_terrains_cfg.HfWaveTerrainCfg)->np.ndarray:
+r"""Generate a terrain with a wave pattern.
+
+ The terrain is a flat platform at the center of the terrain with a wave pattern. The wave pattern
+ is generated by adding sinusoidal waves based on the number of waves and the amplitude of the waves.
+
+ The height of the terrain at a point :math:`(x, y)` is given by:
+
+ .. math::
+
+ h(x, y) = A \left(\sin\left(\frac{2 \pi x}{\lambda}\right) + \cos\left(\frac{2 \pi y}{\lambda}\right) \right)
+
+ where :math:`A` is the amplitude of the waves, :math:`\lambda` is the wavelength of the waves.
+
+ .. image:: ../../_static/terrains/height_field/wave_terrain.jpg
+ :width: 40%
+ :align: center
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ The height field of the terrain as a 2D numpy array with discretized heights.
+ The shape of the array is (width, length), where width and length are the number of points
+ along the x and y axis, respectively.
+
+ Raises:
+ ValueError: When the number of waves is non-positive.
+ """
+ # check number of waves
+ ifcfg.num_waves<0:
+ raiseValueError(f"Number of waves must be a positive integer. Got: {cfg.num_waves}.")
+
+ # resolve terrain configuration
+ amplitude=cfg.amplitude_range[0]+difficulty*(cfg.amplitude_range[1]-cfg.amplitude_range[0])
+ # switch parameters to discrete units
+ # -- terrain
+ width_pixels=int(cfg.size[0]/cfg.horizontal_scale)
+ length_pixels=int(cfg.size[1]/cfg.horizontal_scale)
+ amplitude_pixels=int(0.5*amplitude/cfg.vertical_scale)
+
+ # compute the wave number: nu = 2 * pi / lambda
+ wave_length=length_pixels/cfg.num_waves
+ wave_number=2*np.pi/wave_length
+ # create meshgrid for the terrain
+ x=np.arange(0,width_pixels)
+ y=np.arange(0,length_pixels)
+ xx,yy=np.meshgrid(x,y,sparse=True)
+ xx=xx.reshape(width_pixels,1)
+ yy=yy.reshape(1,length_pixels)
+
+ # create a terrain with a flat platform at the center
+ hf_raw=np.zeros((width_pixels,length_pixels))
+ # add the waves
+ hf_raw+=amplitude_pixels*(np.cos(yy*wave_number)+np.sin(xx*wave_number))
+ # round off the heights to the nearest vertical step
+ returnnp.rint(hf_raw).astype(np.int16)
+
+
+
[文档]@height_field_to_mesh
+defstepping_stones_terrain(difficulty:float,cfg:hf_terrains_cfg.HfSteppingStonesTerrainCfg)->np.ndarray:
+"""Generate a terrain with a stepping stones pattern.
+
+ The terrain is a stepping stones pattern which trims to a flat platform at the center of the terrain.
+
+ .. image:: ../../_static/terrains/height_field/stepping_stones_terrain.jpg
+ :width: 40%
+ :align: center
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ The height field of the terrain as a 2D numpy array with discretized heights.
+ The shape of the array is (width, length), where width and length are the number of points
+ along the x and y axis, respectively.
+ """
+ # resolve terrain configuration
+ stone_width=cfg.stone_width_range[1]-difficulty*(cfg.stone_width_range[1]-cfg.stone_width_range[0])
+ stone_distance=cfg.stone_distance_range[0]+difficulty*(
+ cfg.stone_distance_range[1]-cfg.stone_distance_range[0]
+ )
+
+ # switch parameters to discrete units
+ # -- terrain
+ width_pixels=int(cfg.size[0]/cfg.horizontal_scale)
+ length_pixels=int(cfg.size[1]/cfg.horizontal_scale)
+ # -- stones
+ stone_distance=int(stone_distance/cfg.horizontal_scale)
+ stone_width=int(stone_width/cfg.horizontal_scale)
+ stone_height_max=int(cfg.stone_height_max/cfg.vertical_scale)
+ # -- holes
+ holes_depth=int(cfg.holes_depth/cfg.vertical_scale)
+ # -- platform
+ platform_width=int(cfg.platform_width/cfg.horizontal_scale)
+ # create range of heights
+ stone_height_range=np.arange(-stone_height_max-1,stone_height_max,step=1)
+
+ # create a terrain with a flat platform at the center
+ hf_raw=np.full((width_pixels,length_pixels),holes_depth)
+ # add the stones
+ start_x,start_y=0,0
+ # -- if the terrain is longer than it is wide then fill the terrain column by column
+ iflength_pixels>=width_pixels:
+ whilestart_y<length_pixels:
+ # ensure that stone stops along y-axis
+ stop_y=min(length_pixels,start_y+stone_width)
+ # randomly sample x-position
+ start_x=np.random.randint(0,stone_width)
+ stop_x=max(0,start_x-stone_distance)
+ # fill first stone
+ hf_raw[0:stop_x,start_y:stop_y]=np.random.choice(stone_height_range)
+ # fill row with stones
+ whilestart_x<width_pixels:
+ stop_x=min(width_pixels,start_x+stone_width)
+ hf_raw[start_x:stop_x,start_y:stop_y]=np.random.choice(stone_height_range)
+ start_x+=stone_width+stone_distance
+ # update y-position
+ start_y+=stone_width+stone_distance
+ elifwidth_pixels>length_pixels:
+ whilestart_x<width_pixels:
+ # ensure that stone stops along x-axis
+ stop_x=min(width_pixels,start_x+stone_width)
+ # randomly sample y-position
+ start_y=np.random.randint(0,stone_width)
+ stop_y=max(0,start_y-stone_distance)
+ # fill first stone
+ hf_raw[start_x:stop_x,0:stop_y]=np.random.choice(stone_height_range)
+ # fill column with stones
+ whilestart_y<length_pixels:
+ stop_y=min(length_pixels,start_y+stone_width)
+ hf_raw[start_x:stop_x,start_y:stop_y]=np.random.choice(stone_height_range)
+ start_y+=stone_width+stone_distance
+ # update x-position
+ start_x+=stone_width+stone_distance
+ # add the platform in the center
+ x1=(width_pixels-platform_width)//2
+ x2=(width_pixels+platform_width)//2
+ y1=(length_pixels-platform_width)//2
+ y2=(length_pixels+platform_width)//2
+ hf_raw[x1:x2,y1:y2]=0
+ # round off the heights to the nearest vertical step
+ returnnp.rint(hf_raw).astype(np.int16)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from..terrain_generator_cfgimportSubTerrainBaseCfg
+from.importhf_terrains
+
+
+
[文档]@configclass
+classHfTerrainBaseCfg(SubTerrainBaseCfg):
+"""The base configuration for height field terrains."""
+
+ border_width:float=0.0
+"""The width of the border/padding around the terrain (in m). Defaults to 0.0.
+
+ The border width is subtracted from the :obj:`size` of the terrain. If non-zero, it must be
+ greater than or equal to the :obj:`horizontal scale`.
+ """
+ horizontal_scale:float=0.1
+"""The discretization of the terrain along the x and y axes (in m). Defaults to 0.1."""
+ vertical_scale:float=0.005
+"""The discretization of the terrain along the z axis (in m). Defaults to 0.005."""
+ slope_threshold:float|None=None
+"""The slope threshold above which surfaces are made vertical. Defaults to None,
+ in which case no correction is applied."""
[文档]@configclass
+classHfRandomUniformTerrainCfg(HfTerrainBaseCfg):
+"""Configuration for a random uniform height field terrain."""
+
+ function=hf_terrains.random_uniform_terrain
+
+ noise_range:tuple[float,float]=MISSING
+"""The minimum and maximum height noise (i.e. along z) of the terrain (in m)."""
+ noise_step:float=MISSING
+"""The minimum height (in m) change between two points."""
+ downsampled_scale:float|None=None
+"""The distance between two randomly sampled points on the terrain. Defaults to None,
+ in which case the :obj:`horizontal scale` is used.
+
+ The heights are sampled at this resolution and interpolation is performed for intermediate points.
+ This must be larger than or equal to the :obj:`horizontal scale`.
+ """
+
+
+
[文档]@configclass
+classHfPyramidSlopedTerrainCfg(HfTerrainBaseCfg):
+"""Configuration for a pyramid sloped height field terrain."""
+
+ function=hf_terrains.pyramid_sloped_terrain
+
+ slope_range:tuple[float,float]=MISSING
+"""The slope of the terrain (in radians)."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+ inverted:bool=False
+"""Whether the pyramid is inverted. Defaults to False.
+
+ If True, the terrain is inverted such that the platform is at the bottom and the slopes are upwards.
+ """
+
+
+
[文档]@configclass
+classHfInvertedPyramidSlopedTerrainCfg(HfPyramidSlopedTerrainCfg):
+"""Configuration for an inverted pyramid sloped height field terrain.
+
+ Note:
+ This is a subclass of :class:`HfPyramidSlopedTerrainCfg` with :obj:`inverted` set to True.
+ We make it as a separate class to make it easier to distinguish between the two and match
+ the naming convention of the other terrains.
+ """
+
+ inverted:bool=True
+
+
+
[文档]@configclass
+classHfPyramidStairsTerrainCfg(HfTerrainBaseCfg):
+"""Configuration for a pyramid stairs height field terrain."""
+
+ function=hf_terrains.pyramid_stairs_terrain
+
+ step_height_range:tuple[float,float]=MISSING
+"""The minimum and maximum height of the steps (in m)."""
+ step_width:float=MISSING
+"""The width of the steps (in m)."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+ inverted:bool=False
+"""Whether the pyramid stairs is inverted. Defaults to False.
+
+ If True, the terrain is inverted such that the platform is at the bottom and the stairs are upwards.
+ """
+
+
+
[文档]@configclass
+classHfInvertedPyramidStairsTerrainCfg(HfPyramidStairsTerrainCfg):
+"""Configuration for an inverted pyramid stairs height field terrain.
+
+ Note:
+ This is a subclass of :class:`HfPyramidStairsTerrainCfg` with :obj:`inverted` set to True.
+ We make it as a separate class to make it easier to distinguish between the two and match
+ the naming convention of the other terrains.
+ """
+
+ inverted:bool=True
+
+
+
[文档]@configclass
+classHfDiscreteObstaclesTerrainCfg(HfTerrainBaseCfg):
+"""Configuration for a discrete obstacles height field terrain."""
+
+ function=hf_terrains.discrete_obstacles_terrain
+
+ obstacle_height_mode:str="choice"
+"""The mode to use for the obstacle height. Defaults to "choice".
+
+ The following modes are supported: "choice", "fixed".
+ """
+ obstacle_width_range:tuple[float,float]=MISSING
+"""The minimum and maximum width of the obstacles (in m)."""
+ obstacle_height_range:tuple[float,float]=MISSING
+"""The minimum and maximum height of the obstacles (in m)."""
+ num_obstacles:int=MISSING
+"""The number of obstacles to generate."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+
+
+
[文档]@configclass
+classHfWaveTerrainCfg(HfTerrainBaseCfg):
+"""Configuration for a wave height field terrain."""
+
+ function=hf_terrains.wave_terrain
+
+ amplitude_range:tuple[float,float]=MISSING
+"""The minimum and maximum amplitude of the wave (in m)."""
+ num_waves:int=1.0
+"""The number of waves to generate. Defaults to 1.0."""
+
+
+
[文档]@configclass
+classHfSteppingStonesTerrainCfg(HfTerrainBaseCfg):
+"""Configuration for a stepping stones height field terrain."""
+
+ function=hf_terrains.stepping_stones_terrain
+
+ stone_height_max:float=MISSING
+"""The maximum height of the stones (in m)."""
+ stone_width_range:tuple[float,float]=MISSING
+"""The minimum and maximum width of the stones (in m)."""
+ stone_distance_range:tuple[float,float]=MISSING
+"""The minimum and maximum distance between stones (in m)."""
+ holes_depth:float=-10.0
+"""The depth of the holes (negative obstacles). Defaults to -10.0."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importnumpyasnp
+importos
+importtorch
+importtrimesh
+
+importcarb
+
+fromomni.isaac.lab.utils.dictimportdict_to_md5_hash
+fromomni.isaac.lab.utils.ioimportdump_yaml
+fromomni.isaac.lab.utils.timerimportTimer
+fromomni.isaac.lab.utils.warpimportconvert_to_warp_mesh
+
+from.height_fieldimportHfTerrainBaseCfg
+from.terrain_generator_cfgimportFlatPatchSamplingCfg,SubTerrainBaseCfg,TerrainGeneratorCfg
+from.trimesh.utilsimportmake_border
+from.utilsimportcolor_meshes_by_height,find_flat_patches
+
+
+
[文档]classTerrainGenerator:
+r"""Terrain generator to handle different terrain generation functions.
+
+ The terrains are represented as meshes. These are obtained either from height fields or by using the
+ `trimesh <https://trimsh.org/trimesh.html>`__ library. The height field representation is more
+ flexible, but it is less computationally and memory efficient than the trimesh representation.
+
+ All terrain generation functions take in the argument :obj:`difficulty` which determines the complexity
+ of the terrain. The difficulty is a number between 0 and 1, where 0 is the easiest and 1 is the hardest.
+ In most cases, the difficulty is used for linear interpolation between different terrain parameters.
+ For example, in a pyramid stairs terrain the step height is interpolated between the specified minimum
+ and maximum step height.
+
+ Each sub-terrain has a corresponding configuration class that can be used to specify the parameters
+ of the terrain. The configuration classes are inherited from the :class:`SubTerrainBaseCfg` class
+ which contains the common parameters for all terrains.
+
+ If a curriculum is used, the terrains are generated based on their difficulty parameter.
+ The difficulty is varied linearly over the number of rows (i.e. along x) with a small random value
+ added to the difficulty to ensure that the columns with the same sub-terrain type are not exactly
+ the same. The difficulty parameter for a sub-terrain at a given row is calculated as:
+
+ .. math::
+
+ \text{difficulty} = \frac{\text{row_id} + \eta}{\text{num_rows}} \times (\text{upper} - \text{lower}) + \text{lower}
+
+ where :math:`\eta\sim\mathcal{U}(0, 1)` is a random perturbation to the difficulty, and
+ :math:`(\text{lower}, \text{upper})` is the range of the difficulty parameter, specified using the
+ :attr:`~TerrainGeneratorCfg.difficulty_range` parameter.
+
+ If a curriculum is not used, the terrains are generated randomly. In this case, the difficulty parameter
+ is randomly sampled from the specified range, given by the :attr:`~TerrainGeneratorCfg.difficulty_range` parameter:
+
+ .. math::
+
+ \text{difficulty} \sim \mathcal{U}(\text{lower}, \text{upper})
+
+ If the :attr:`~TerrainGeneratorCfg.flat_patch_sampling` is specified for a sub-terrain, flat patches are sampled
+ on the terrain. These can be used for spawning robots, targets, etc. The sampled patches are stored
+ in the :obj:`flat_patches` dictionary. The key specifies the intention of the flat patches and the
+ value is a tensor containing the flat patches for each sub-terrain.
+
+ If the flag :attr:`~TerrainGeneratorCfg.use_cache` is set to True, the terrains are cached based on their
+ sub-terrain configurations. This means that if the same sub-terrain configuration is used
+ multiple times, the terrain is only generated once and then reused. This is useful when
+ generating complex sub-terrains that take a long time to generate.
+
+ .. attention::
+
+ The terrain generation has its own seed parameter. This is set using the :attr:`TerrainGeneratorCfg.seed`
+ parameter. If the seed is not set and the caching is disabled, the terrain generation may not be
+ completely reproducible.
+
+ """
+
+ terrain_mesh:trimesh.Trimesh
+"""A single trimesh.Trimesh object for all the generated sub-terrains."""
+ terrain_meshes:list[trimesh.Trimesh]
+"""List of trimesh.Trimesh objects for all the generated sub-terrains."""
+ terrain_origins:np.ndarray
+"""The origin of each sub-terrain. Shape is (num_rows, num_cols, 3)."""
+ flat_patches:dict[str,torch.Tensor]
+"""A dictionary of sampled valid (flat) patches for each sub-terrain.
+
+ The dictionary keys are the names of the flat patch sampling configurations. This maps to a
+ tensor containing the flat patches for each sub-terrain. The shape of the tensor is
+ (num_rows, num_cols, num_patches, 3).
+
+ For instance, the key "root_spawn" maps to a tensor containing the flat patches for spawning an asset.
+ Similarly, the key "target_spawn" maps to a tensor containing the flat patches for setting targets.
+ """
+
+
[文档]def__init__(self,cfg:TerrainGeneratorCfg,device:str="cpu"):
+"""Initialize the terrain generator.
+
+ Args:
+ cfg: Configuration for the terrain generator.
+ device: The device to use for the flat patches tensor.
+ """
+ # check inputs
+ iflen(cfg.sub_terrains)==0:
+ raiseValueError("No sub-terrains specified! Please add at least one sub-terrain.")
+ # store inputs
+ self.cfg=cfg
+ self.device=device
+
+ # set common values to all sub-terrains config
+ forsub_cfginself.cfg.sub_terrains.values():
+ # size of all terrains
+ sub_cfg.size=self.cfg.size
+ # params for height field terrains
+ ifisinstance(sub_cfg,HfTerrainBaseCfg):
+ sub_cfg.horizontal_scale=self.cfg.horizontal_scale
+ sub_cfg.vertical_scale=self.cfg.vertical_scale
+ sub_cfg.slope_threshold=self.cfg.slope_threshold
+
+ # throw a warning if the cache is enabled but the seed is not set
+ ifself.cfg.use_cacheandself.cfg.seedisNone:
+ carb.log_warn(
+ "Cache is enabled but the seed is not set. The terrain generation will not be reproducible."
+ " Please set the seed in the terrain generator configuration to make the generation reproducible."
+ )
+
+ # if the seed is not set, we assume there is a global seed set and use that.
+ # this ensures that the terrain is reproducible if the seed is set at the beginning of the program.
+ ifself.cfg.seedisnotNone:
+ seed=self.cfg.seed
+ else:
+ seed=np.random.get_state()[1][0]
+ # set the seed for reproducibility
+ # note: we create a new random number generator to avoid affecting the global state
+ # in the other places where random numbers are used.
+ self.np_rng=np.random.default_rng(seed)
+
+ # buffer for storing valid patches
+ self.flat_patches={}
+ # create a list of all sub-terrains
+ self.terrain_meshes=list()
+ self.terrain_origins=np.zeros((self.cfg.num_rows,self.cfg.num_cols,3))
+
+ # parse configuration and add sub-terrains
+ # create terrains based on curriculum or randomly
+ ifself.cfg.curriculum:
+ withTimer("[INFO] Generating terrains based on curriculum took"):
+ self._generate_curriculum_terrains()
+ else:
+ withTimer("[INFO] Generating terrains randomly took"):
+ self._generate_random_terrains()
+ # add a border around the terrains
+ self._add_terrain_border()
+ # combine all the sub-terrains into a single mesh
+ self.terrain_mesh=trimesh.util.concatenate(self.terrain_meshes)
+
+ # color the terrain mesh
+ ifself.cfg.color_scheme=="height":
+ self.terrain_mesh=color_meshes_by_height(self.terrain_mesh)
+ elifself.cfg.color_scheme=="random":
+ self.terrain_mesh.visual.vertex_colors=self.np_rng.choice(
+ range(256),size=(len(self.terrain_mesh.vertices),4)
+ )
+ elifself.cfg.color_scheme=="none":
+ pass
+ else:
+ raiseValueError(f"Invalid color scheme: {self.cfg.color_scheme}.")
+
+ # offset the entire terrain and origins so that it is centered
+ # -- terrain mesh
+ transform=np.eye(4)
+ transform[:2,-1]=-self.cfg.size[0]*self.cfg.num_rows*0.5,-self.cfg.size[1]*self.cfg.num_cols*0.5
+ self.terrain_mesh.apply_transform(transform)
+ # -- terrain origins
+ self.terrain_origins+=transform[:3,-1]
+ # -- valid patches
+ terrain_origins_torch=torch.tensor(self.terrain_origins,dtype=torch.float,device=self.device).unsqueeze(2)
+ forname,valueinself.flat_patches.items():
+ self.flat_patches[name]=value+terrain_origins_torch
+
+ def__str__(self):
+"""Return a string representation of the terrain generator."""
+ msg="Terrain Generator:"
+ msg+=f"\n\tSeed: {self.cfg.seed}"
+ msg+=f"\n\tNumber of rows: {self.cfg.num_rows}"
+ msg+=f"\n\tNumber of columns: {self.cfg.num_cols}"
+ msg+=f"\n\tSub-terrain size: {self.cfg.size}"
+ msg+=f"\n\tSub-terrain types: {list(self.cfg.sub_terrains.keys())}"
+ msg+=f"\n\tCurriculum: {self.cfg.curriculum}"
+ msg+=f"\n\tDifficulty range: {self.cfg.difficulty_range}"
+ msg+=f"\n\tColor scheme: {self.cfg.color_scheme}"
+ msg+=f"\n\tUse cache: {self.cfg.use_cache}"
+ ifself.cfg.use_cache:
+ msg+=f"\n\tCache directory: {self.cfg.cache_dir}"
+
+ returnmsg
+
+"""
+ Terrain generator functions.
+ """
+
+ def_generate_random_terrains(self):
+"""Add terrains based on randomly sampled difficulty parameter."""
+ # normalize the proportions of the sub-terrains
+ proportions=np.array([sub_cfg.proportionforsub_cfginself.cfg.sub_terrains.values()])
+ proportions/=np.sum(proportions)
+ # create a list of all terrain configs
+ sub_terrains_cfgs=list(self.cfg.sub_terrains.values())
+
+ # randomly sample sub-terrains
+ forindexinrange(self.cfg.num_rows*self.cfg.num_cols):
+ # coordinate index of the sub-terrain
+ (sub_row,sub_col)=np.unravel_index(index,(self.cfg.num_rows,self.cfg.num_cols))
+ # randomly sample terrain index
+ sub_index=self.np_rng.choice(len(proportions),p=proportions)
+ # randomly sample difficulty parameter
+ difficulty=self.np_rng.uniform(*self.cfg.difficulty_range)
+ # generate terrain
+ mesh,origin=self._get_terrain_mesh(difficulty,sub_terrains_cfgs[sub_index])
+ # add to sub-terrains
+ self._add_sub_terrain(mesh,origin,sub_row,sub_col,sub_terrains_cfgs[sub_index])
+
+ def_generate_curriculum_terrains(self):
+"""Add terrains based on the difficulty parameter."""
+ # normalize the proportions of the sub-terrains
+ proportions=np.array([sub_cfg.proportionforsub_cfginself.cfg.sub_terrains.values()])
+ proportions/=np.sum(proportions)
+
+ # find the sub-terrain index for each column
+ # we generate the terrains based on their proportion (not randomly sampled)
+ sub_indices=[]
+ forindexinrange(self.cfg.num_cols):
+ sub_index=np.min(np.where(index/self.cfg.num_cols+0.001<np.cumsum(proportions))[0])
+ sub_indices.append(sub_index)
+ sub_indices=np.array(sub_indices,dtype=np.int32)
+ # create a list of all terrain configs
+ sub_terrains_cfgs=list(self.cfg.sub_terrains.values())
+
+ # curriculum-based sub-terrains
+ forsub_colinrange(self.cfg.num_cols):
+ forsub_rowinrange(self.cfg.num_rows):
+ # vary the difficulty parameter linearly over the number of rows
+ # note: based on the proportion, multiple columns can have the same sub-terrain type.
+ # Thus to increase the diversity along the rows, we add a small random value to the difficulty.
+ # This ensures that the terrains are not exactly the same. For example, if the
+ # the row index is 2 and the number of rows is 10, the nominal difficulty is 0.2.
+ # We add a small random value to the difficulty to make it between 0.2 and 0.3.
+ lower,upper=self.cfg.difficulty_range
+ difficulty=(sub_row+self.np_rng.uniform())/self.cfg.num_rows
+ difficulty=lower+(upper-lower)*difficulty
+ # generate terrain
+ mesh,origin=self._get_terrain_mesh(difficulty,sub_terrains_cfgs[sub_indices[sub_col]])
+ # add to sub-terrains
+ self._add_sub_terrain(mesh,origin,sub_row,sub_col,sub_terrains_cfgs[sub_indices[sub_col]])
+
+"""
+ Internal helper functions.
+ """
+
+ def_add_terrain_border(self):
+"""Add a surrounding border over all the sub-terrains into the terrain meshes."""
+ # border parameters
+ border_size=(
+ self.cfg.num_rows*self.cfg.size[0]+2*self.cfg.border_width,
+ self.cfg.num_cols*self.cfg.size[1]+2*self.cfg.border_width,
+ )
+ inner_size=(self.cfg.num_rows*self.cfg.size[0],self.cfg.num_cols*self.cfg.size[1])
+ border_center=(
+ self.cfg.num_rows*self.cfg.size[0]/2,
+ self.cfg.num_cols*self.cfg.size[1]/2,
+ -self.cfg.border_height/2,
+ )
+ # border mesh
+ border_meshes=make_border(border_size,inner_size,height=self.cfg.border_height,position=border_center)
+ border=trimesh.util.concatenate(border_meshes)
+ # update the faces to have minimal triangles
+ selector=~(np.asarray(border.triangles)[:,:,2]<-0.1).any(1)
+ border.update_faces(selector)
+ # add the border to the list of meshes
+ self.terrain_meshes.append(border)
+
+ def_add_sub_terrain(
+ self,mesh:trimesh.Trimesh,origin:np.ndarray,row:int,col:int,sub_terrain_cfg:SubTerrainBaseCfg
+ ):
+"""Add input sub-terrain to the list of sub-terrains.
+
+ This function adds the input sub-terrain mesh to the list of sub-terrains and updates the origin
+ of the sub-terrain in the list of origins. It also samples flat patches if specified.
+
+ Args:
+ mesh: The mesh of the sub-terrain.
+ origin: The origin of the sub-terrain.
+ row: The row index of the sub-terrain.
+ col: The column index of the sub-terrain.
+ """
+ # sample flat patches if specified
+ ifsub_terrain_cfg.flat_patch_samplingisnotNone:
+ carb.log_info(f"Sampling flat patches for sub-terrain at (row, col): ({row}, {col})")
+ # convert the mesh to warp mesh
+ wp_mesh=convert_to_warp_mesh(mesh.vertices,mesh.faces,device=self.device)
+ # sample flat patches based on each patch configuration for that sub-terrain
+ forname,patch_cfginsub_terrain_cfg.flat_patch_sampling.items():
+ patch_cfg:FlatPatchSamplingCfg
+ # create the flat patches tensor (if not already created)
+ ifnamenotinself.flat_patches:
+ self.flat_patches[name]=torch.zeros(
+ (self.cfg.num_rows,self.cfg.num_cols,patch_cfg.num_patches,3),device=self.device
+ )
+ # add the flat patches to the tensor
+ self.flat_patches[name][row,col]=find_flat_patches(
+ wp_mesh=wp_mesh,
+ origin=origin,
+ num_patches=patch_cfg.num_patches,
+ patch_radius=patch_cfg.patch_radius,
+ x_range=patch_cfg.x_range,
+ y_range=patch_cfg.y_range,
+ z_range=patch_cfg.z_range,
+ max_height_diff=patch_cfg.max_height_diff,
+ )
+
+ # transform the mesh to the correct position
+ transform=np.eye(4)
+ transform[0:2,-1]=(row+0.5)*self.cfg.size[0],(col+0.5)*self.cfg.size[1]
+ mesh.apply_transform(transform)
+ # add mesh to the list
+ self.terrain_meshes.append(mesh)
+ # add origin to the list
+ self.terrain_origins[row,col]=origin+transform[:3,-1]
+
+ def_get_terrain_mesh(self,difficulty:float,cfg:SubTerrainBaseCfg)->tuple[trimesh.Trimesh,np.ndarray]:
+"""Generate a sub-terrain mesh based on the input difficulty parameter.
+
+ If caching is enabled, the sub-terrain is cached and loaded from the cache if it exists.
+ The cache is stored in the cache directory specified in the configuration.
+
+ .. Note:
+ This function centers the 2D center of the mesh and its specified origin such that the
+ 2D center becomes :math:`(0, 0)` instead of :math:`(size[0] / 2, size[1] / 2).
+
+ Args:
+ difficulty: The difficulty parameter.
+ cfg: The configuration of the sub-terrain.
+
+ Returns:
+ The sub-terrain mesh and origin.
+ """
+ # copy the configuration
+ cfg=cfg.copy()
+ # add other parameters to the sub-terrain configuration
+ cfg.difficulty=float(difficulty)
+ cfg.seed=self.cfg.seed
+ # generate hash for the sub-terrain
+ sub_terrain_hash=dict_to_md5_hash(cfg.to_dict())
+ # generate the file name
+ sub_terrain_cache_dir=os.path.join(self.cfg.cache_dir,sub_terrain_hash)
+ sub_terrain_obj_filename=os.path.join(sub_terrain_cache_dir,"mesh.obj")
+ sub_terrain_csv_filename=os.path.join(sub_terrain_cache_dir,"origin.csv")
+ sub_terrain_meta_filename=os.path.join(sub_terrain_cache_dir,"cfg.yaml")
+
+ # check if hash exists - if true, load the mesh and origin and return
+ ifself.cfg.use_cacheandos.path.exists(sub_terrain_obj_filename):
+ # load existing mesh
+ mesh=trimesh.load_mesh(sub_terrain_obj_filename,process=False)
+ origin=np.loadtxt(sub_terrain_csv_filename,delimiter=",")
+ # return the generated mesh
+ returnmesh,origin
+
+ # generate the terrain
+ meshes,origin=cfg.function(difficulty,cfg)
+ mesh=trimesh.util.concatenate(meshes)
+ # offset mesh such that they are in their center
+ transform=np.eye(4)
+ transform[0:2,-1]=-cfg.size[0]*0.5,-cfg.size[1]*0.5
+ mesh.apply_transform(transform)
+ # change origin to be in the center of the sub-terrain
+ origin+=transform[0:3,-1]
+
+ # if caching is enabled, save the mesh and origin
+ ifself.cfg.use_cache:
+ # create the cache directory
+ os.makedirs(sub_terrain_cache_dir,exist_ok=True)
+ # save the data
+ mesh.export(sub_terrain_obj_filename)
+ np.savetxt(sub_terrain_csv_filename,origin,delimiter=",",header="x,y,z")
+ dump_yaml(sub_terrain_meta_filename,cfg)
+ # return the generated mesh
+ returnmesh,origin
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""
+Configuration classes defining the different terrains available. Each configuration class must
+inherit from ``omni.isaac.lab.terrains.terrains_cfg.TerrainConfig`` and define the following attributes:
+
+- ``name``: Name of the terrain. This is used for the prim name in the USD stage.
+- ``function``: Function to generate the terrain. This function must take as input the terrain difficulty
+ and the configuration parameters and return a `tuple with the `trimesh`` mesh object and terrain origin.
+"""
+
+from__future__importannotations
+
+importnumpyasnp
+importtrimesh
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+
+@configclass
+classFlatPatchSamplingCfg:
+"""Configuration for sampling flat patches on the sub-terrain.
+
+ For a given sub-terrain, this configuration specifies how to sample flat patches on the terrain.
+ The sampled flat patches can be used for spawning robots, targets, etc.
+
+ Please check the function :meth:`~omni.isaac.lab.terrains.utils.find_flat_patches` for more details.
+ """
+
+ num_patches:int=MISSING
+"""Number of patches to sample."""
+
+ patch_radius:float|list[float]=MISSING
+"""Radius of the patches.
+
+ A list of radii can be provided to check for patches of different sizes. This is useful to deal with
+ cases where the terrain may have holes or obstacles in some areas.
+ """
+
+ x_range:tuple[float,float]=(-1e6,1e6)
+"""The range of x-coordinates to sample from. Defaults to (-1e6, 1e6).
+
+ This range is internally clamped to the size of the terrain mesh.
+ """
+
+ y_range:tuple[float,float]=(-1e6,1e6)
+"""The range of y-coordinates to sample from. Defaults to (-1e6, 1e6).
+
+ This range is internally clamped to the size of the terrain mesh.
+ """
+
+ z_range:tuple[float,float]=(-1e6,1e6)
+"""Allowed range of z-coordinates for the sampled patch. Defaults to (-1e6, 1e6)."""
+
+ max_height_diff:float=MISSING
+"""Maximum allowed height difference between the highest and lowest points on the patch."""
+
+
+
[文档]@configclass
+classSubTerrainBaseCfg:
+"""Base class for terrain configurations.
+
+ All the sub-terrain configurations must inherit from this class.
+
+ The :attr:`size` attribute is the size of the generated sub-terrain. Based on this, the terrain must
+ extend from :math:`(0, 0)` to :math:`(size[0], size[1])`.
+ """
+
+ function:Callable[[float,SubTerrainBaseCfg],tuple[list[trimesh.Trimesh],np.ndarray]]=MISSING
+"""Function to generate the terrain.
+
+ This function must take as input the terrain difficulty and the configuration parameters and
+ return a tuple with a list of ``trimesh`` mesh objects and the terrain origin.
+ """
+
+ proportion:float=1.0
+"""Proportion of the terrain to generate. Defaults to 1.0.
+
+ This is used to generate a mix of terrains. The proportion corresponds to the probability of sampling
+ the particular terrain. For example, if there are two terrains, A and B, with proportions 0.3 and 0.7,
+ respectively, then the probability of sampling terrain A is 0.3 and the probability of sampling terrain B
+ is 0.7.
+ """
+
+ size:tuple[float,float]=MISSING
+"""The width (along x) and length (along y) of the terrain (in m)."""
+
+ flat_patch_sampling:dict[str,FlatPatchSamplingCfg]|None=None
+"""Dictionary of configurations for sampling flat patches on the sub-terrain. Defaults to None,
+ in which case no flat patch sampling is performed.
+
+ The keys correspond to the name of the flat patch sampling configuration and the values are the
+ corresponding configurations.
+ """
+
+
+
[文档]@configclass
+classTerrainGeneratorCfg:
+"""Configuration for the terrain generator."""
+
+ seed:int|None=None
+"""The seed for the random number generator. Defaults to None, in which case the seed from the
+ current NumPy's random state is used.
+
+ When the seed is set, the random number generator is initialized with the given seed. This ensures
+ that the generated terrains are deterministic across different runs. If the seed is not set, the
+ seed from the current NumPy's random state is used. This assumes that the seed is set elsewhere in
+ the code.
+ """
+
+ curriculum:bool=False
+"""Whether to use the curriculum mode. Defaults to False.
+
+ If True, the terrains are generated based on their difficulty parameter. Otherwise,
+ they are randomly generated.
+ """
+
+ size:tuple[float,float]=MISSING
+"""The width (along x) and length (along y) of each sub-terrain (in m).
+
+ Note:
+ This value is passed on to all the sub-terrain configurations.
+ """
+
+ border_width:float=0.0
+"""The width of the border around the terrain (in m). Defaults to 0.0."""
+
+ border_height:float=1.0
+"""The height of the border around the terrain (in m). Defaults to 1.0."""
+
+ num_rows:int=1
+"""Number of rows of sub-terrains to generate. Defaults to 1."""
+
+ num_cols:int=1
+"""Number of columns of sub-terrains to generate. Defaults to 1."""
+
+ color_scheme:Literal["height","random","none"]="none"
+"""Color scheme to use for the terrain. Defaults to "none".
+
+ The available color schemes are:
+
+ - "height": Color based on the height of the terrain.
+ - "random": Random color scheme.
+ - "none": No color scheme.
+ """
+
+ horizontal_scale:float=0.1
+"""The discretization of the terrain along the x and y axes (in m). Defaults to 0.1.
+
+ This value is passed on to all the height field sub-terrain configurations.
+ """
+
+ vertical_scale:float=0.005
+"""The discretization of the terrain along the z axis (in m). Defaults to 0.005.
+
+ This value is passed on to all the height field sub-terrain configurations.
+ """
+
+ slope_threshold:float|None=0.75
+"""The slope threshold above which surfaces are made vertical. Defaults to 0.75.
+
+ If None no correction is applied.
+
+ This value is passed on to all the height field sub-terrain configurations.
+ """
+
+ sub_terrains:dict[str,SubTerrainBaseCfg]=MISSING
+"""Dictionary of sub-terrain configurations.
+
+ The keys correspond to the name of the sub-terrain configuration and the values are the corresponding
+ configurations.
+ """
+
+ difficulty_range:tuple[float,float]=(0.0,1.0)
+"""The range of difficulty values for the sub-terrains. Defaults to (0.0, 1.0).
+
+ If curriculum is enabled, the terrains will be generated based on this range in ascending order
+ of difficulty. Otherwise, the terrains will be generated based on this range in a random order.
+ """
+
+ use_cache:bool=False
+"""Whether to load the sub-terrain from cache if it exists. Defaults to True.
+
+ If enabled, the generated terrains are stored in the cache directory. When generating terrains, the cache
+ is checked to see if the terrain already exists. If it does, the terrain is loaded from the cache. Otherwise,
+ the terrain is generated and stored in the cache. Caching can be used to speed up terrain generation.
+ """
+
+ cache_dir:str="/tmp/isaaclab/terrains"
+"""The directory where the terrain cache is stored. Defaults to "/tmp/isaaclab/terrains"."""
[文档]classTerrainImporter:
+r"""A class to handle terrain meshes and import them into the simulator.
+
+ We assume that a terrain mesh comprises of sub-terrains that are arranged in a grid with
+ rows ``num_rows`` and columns ``num_cols``. The terrain origins are the positions of the sub-terrains
+ where the robot should be spawned.
+
+ Based on the configuration, the terrain importer handles computing the environment origins from the sub-terrain
+ origins. In a typical setup, the number of sub-terrains (:math:`num\_rows \times num\_cols`) is smaller than
+ the number of environments (:math:`num\_envs`). In this case, the environment origins are computed by
+ sampling the sub-terrain origins.
+
+ If a curriculum is used, it is possible to update the environment origins to terrain origins that correspond
+ to a harder difficulty. This is done by calling :func:`update_terrain_levels`. The idea comes from game-based
+ curriculum. For example, in a game, the player starts with easy levels and progresses to harder levels.
+ """
+
+ meshes:dict[str,trimesh.Trimesh]
+"""A dictionary containing the names of the meshes and their keys."""
+ warp_meshes:dict[str,warp.Mesh]
+"""A dictionary containing the names of the warp meshes and their keys."""
+ terrain_origins:torch.Tensor|None
+"""The origins of the sub-terrains in the added terrain mesh. Shape is (num_rows, num_cols, 3).
+
+ If None, then it is assumed no sub-terrains exist. The environment origins are computed in a grid.
+ """
+ env_origins:torch.Tensor
+"""The origins of the environments. Shape is (num_envs, 3)."""
+
+
[文档]def__init__(self,cfg:TerrainImporterCfg):
+"""Initialize the terrain importer.
+
+ Args:
+ cfg: The configuration for the terrain importer.
+
+ Raises:
+ ValueError: If input terrain type is not supported.
+ ValueError: If terrain type is 'generator' and no configuration provided for ``terrain_generator``.
+ ValueError: If terrain type is 'usd' and no configuration provided for ``usd_path``.
+ ValueError: If terrain type is 'usd' or 'plane' and no configuration provided for ``env_spacing``.
+ """
+ # store inputs
+ self.cfg=cfg
+ self.device=sim_utils.SimulationContext.instance().device# type: ignore
+
+ # create a dict of meshes
+ self.meshes=dict()
+ self.warp_meshes=dict()
+ self.env_origins=None
+ self.terrain_origins=None
+ # private variables
+ self._terrain_flat_patches=dict()
+
+ # auto-import the terrain based on the config
+ ifself.cfg.terrain_type=="generator":
+ # check config is provided
+ ifself.cfg.terrain_generatorisNone:
+ raiseValueError("Input terrain type is 'generator' but no value provided for 'terrain_generator'.")
+ # generate the terrain
+ terrain_generator=TerrainGenerator(cfg=self.cfg.terrain_generator,device=self.device)
+ self.import_mesh("terrain",terrain_generator.terrain_mesh)
+ # configure the terrain origins based on the terrain generator
+ self.configure_env_origins(terrain_generator.terrain_origins)
+ # refer to the flat patches
+ self._terrain_flat_patches=terrain_generator.flat_patches
+ elifself.cfg.terrain_type=="usd":
+ # check if config is provided
+ ifself.cfg.usd_pathisNone:
+ raiseValueError("Input terrain type is 'usd' but no value provided for 'usd_path'.")
+ # import the terrain
+ self.import_usd("terrain",self.cfg.usd_path)
+ # configure the origins in a grid
+ self.configure_env_origins()
+ elifself.cfg.terrain_type=="plane":
+ # load the plane
+ self.import_ground_plane("terrain")
+ # configure the origins in a grid
+ self.configure_env_origins()
+ else:
+ raiseValueError(f"Terrain type '{self.cfg.terrain_type}' not available.")
+
+ # set initial state of debug visualization
+ self.set_debug_vis(self.cfg.debug_vis)
+
+"""
+ Properties.
+ """
+
+ @property
+ defhas_debug_vis_implementation(self)->bool:
+"""Whether the terrain importer has a debug visualization implemented.
+
+ This always returns True.
+ """
+ returnTrue
+
+ @property
+ defflat_patches(self)->dict[str,torch.Tensor]:
+"""A dictionary containing the sampled valid (flat) patches for the terrain.
+
+ This is only available if the terrain type is 'generator'. For other terrain types, this feature
+ is not available and the function returns an empty dictionary.
+
+ Please refer to the :attr:`TerrainGenerator.flat_patches` for more information.
+ """
+ returnself._terrain_flat_patches
+
+"""
+ Operations - Visibility.
+ """
+
+
[文档]defset_debug_vis(self,debug_vis:bool)->bool:
+"""Set the debug visualization of the terrain importer.
+
+ Args:
+ debug_vis: Whether to visualize the terrain origins.
+
+ Returns:
+ Whether the debug visualization was successfully set. False if the terrain
+ importer does not support debug visualization.
+
+ Raises:
+ RuntimeError: If terrain origins are not configured.
+ """
+ # create a marker if necessary
+ ifdebug_vis:
+ ifnothasattr(self,"origin_visualizer"):
+ self.origin_visualizer=VisualizationMarkers(
+ cfg=FRAME_MARKER_CFG.replace(prim_path="/Visuals/TerrainOrigin")
+ )
+ ifself.terrain_originsisnotNone:
+ self.origin_visualizer.visualize(self.terrain_origins.reshape(-1,3))
+ elifself.env_originsisnotNone:
+ self.origin_visualizer.visualize(self.env_origins.reshape(-1,3))
+ else:
+ raiseRuntimeError("Terrain origins are not configured.")
+ # set visibility
+ self.origin_visualizer.set_visibility(True)
+ else:
+ ifhasattr(self,"origin_visualizer"):
+ self.origin_visualizer.set_visibility(False)
+ # report success
+ returnTrue
+
+"""
+ Operations - Import.
+ """
+
+
[文档]defimport_ground_plane(self,key:str,size:tuple[float,float]=(2.0e6,2.0e6)):
+"""Add a plane to the terrain importer.
+
+ Args:
+ key: The key to store the mesh.
+ size: The size of the plane. Defaults to (2.0e6, 2.0e6).
+
+ Raises:
+ ValueError: If a terrain with the same key already exists.
+ """
+ # check if key exists
+ ifkeyinself.meshes:
+ raiseValueError(f"Mesh with key {key} already exists. Existing keys: {self.meshes.keys()}.")
+ # create a plane
+ mesh=make_plane(size,height=0.0,center_zero=True)
+ # store the mesh
+ self.meshes[key]=mesh
+ # create a warp mesh
+ device="cuda"if"cuda"inself.deviceelse"cpu"
+ self.warp_meshes[key]=convert_to_warp_mesh(mesh.vertices,mesh.faces,device=device)
+
+ # get the mesh
+ ground_plane_cfg=sim_utils.GroundPlaneCfg(physics_material=self.cfg.physics_material,size=size)
+ ground_plane_cfg.func(self.cfg.prim_path,ground_plane_cfg)
+
+
[文档]defimport_mesh(self,key:str,mesh:trimesh.Trimesh):
+"""Import a mesh into the simulator.
+
+ The mesh is imported into the simulator under the prim path ``cfg.prim_path/{key}``. The created path
+ contains the mesh as a :class:`pxr.UsdGeom` instance along with visual or physics material prims.
+
+ Args:
+ key: The key to store the mesh.
+ mesh: The mesh to import.
+
+ Raises:
+ ValueError: If a terrain with the same key already exists.
+ """
+ # check if key exists
+ ifkeyinself.meshes:
+ raiseValueError(f"Mesh with key {key} already exists. Existing keys: {self.meshes.keys()}.")
+ # store the mesh
+ self.meshes[key]=mesh
+ # create a warp mesh
+ device="cuda"if"cuda"inself.deviceelse"cpu"
+ self.warp_meshes[key]=convert_to_warp_mesh(mesh.vertices,mesh.faces,device=device)
+
+ # get the mesh
+ mesh=self.meshes[key]
+ mesh_prim_path=self.cfg.prim_path+f"/{key}"
+ # import the mesh
+ create_prim_from_mesh(
+ mesh_prim_path,
+ mesh,
+ visual_material=self.cfg.visual_material,
+ physics_material=self.cfg.physics_material,
+ )
+
+
[文档]defimport_usd(self,key:str,usd_path:str):
+"""Import a mesh from a USD file.
+
+ We assume that the USD file contains a single mesh. If the USD file contains multiple meshes, then
+ the first mesh is used. The function mainly helps in registering the mesh into the warp meshes
+ and the meshes dictionary.
+
+ Note:
+ We do not apply any material properties to the mesh. The material properties should
+ be defined in the USD file.
+
+ Args:
+ key: The key to store the mesh.
+ usd_path: The path to the USD file.
+
+ Raises:
+ ValueError: If a terrain with the same key already exists.
+ """
+ # add mesh to the dict
+ ifkeyinself.meshes:
+ raiseValueError(f"Mesh with key {key} already exists. Existing keys: {self.meshes.keys()}.")
+ # add the prim path
+ cfg=sim_utils.UsdFileCfg(usd_path=usd_path)
+ cfg.func(self.cfg.prim_path+f"/{key}",cfg)
+
+ # traverse the prim and get the collision mesh
+ # THINK: Should the user specify the collision mesh?
+ mesh_prim=sim_utils.get_first_matching_child_prim(
+ self.cfg.prim_path+f"/{key}",lambdaprim:prim.GetTypeName()=="Mesh"
+ )
+ # check if the mesh is valid
+ ifmesh_primisNone:
+ raiseValueError(f"Could not find any collision mesh in {usd_path}. Please check asset.")
+ # cast into UsdGeomMesh
+ mesh_prim=UsdGeom.Mesh(mesh_prim)
+ # store the mesh
+ vertices=np.asarray(mesh_prim.GetPointsAttr().Get())
+ faces=np.asarray(mesh_prim.GetFaceVertexIndicesAttr().Get()).reshape(-1,3)
+ self.meshes[key]=trimesh.Trimesh(vertices=vertices,faces=faces)
+ # create a warp mesh
+ device="cuda"if"cuda"inself.deviceelse"cpu"
+ self.warp_meshes[key]=convert_to_warp_mesh(vertices,faces,device=device)
+
+"""
+ Operations - Origins.
+ """
+
+
[文档]defconfigure_env_origins(self,origins:np.ndarray|None=None):
+"""Configure the origins of the environments based on the added terrain.
+
+ Args:
+ origins: The origins of the sub-terrains. Shape is (num_rows, num_cols, 3).
+ """
+ # decide whether to compute origins in a grid or based on curriculum
+ iforiginsisnotNone:
+ # convert to numpy
+ ifisinstance(origins,np.ndarray):
+ origins=torch.from_numpy(origins)
+ # store the origins
+ self.terrain_origins=origins.to(self.device,dtype=torch.float)
+ # compute environment origins
+ self.env_origins=self._compute_env_origins_curriculum(self.cfg.num_envs,self.terrain_origins)
+ else:
+ self.terrain_origins=None
+ # check if env spacing is valid
+ ifself.cfg.env_spacingisNone:
+ raiseValueError("Environment spacing must be specified for configuring grid-like origins.")
+ # compute environment origins
+ self.env_origins=self._compute_env_origins_grid(self.cfg.num_envs,self.cfg.env_spacing)
+
+
[文档]defupdate_env_origins(self,env_ids:torch.Tensor,move_up:torch.Tensor,move_down:torch.Tensor):
+"""Update the environment origins based on the terrain levels."""
+ # check if grid-like spawning
+ ifself.terrain_originsisNone:
+ return
+ # update terrain level for the envs
+ self.terrain_levels[env_ids]+=1*move_up-1*move_down
+ # robots that solve the last level are sent to a random one
+ # the minimum level is zero
+ self.terrain_levels[env_ids]=torch.where(
+ self.terrain_levels[env_ids]>=self.max_terrain_level,
+ torch.randint_like(self.terrain_levels[env_ids],self.max_terrain_level),
+ torch.clip(self.terrain_levels[env_ids],0),
+ )
+ # update the env origins
+ self.env_origins[env_ids]=self.terrain_origins[self.terrain_levels[env_ids],self.terrain_types[env_ids]]
+
+"""
+ Internal helpers.
+ """
+
+ def_compute_env_origins_curriculum(self,num_envs:int,origins:torch.Tensor)->torch.Tensor:
+"""Compute the origins of the environments defined by the sub-terrains origins."""
+ # extract number of rows and cols
+ num_rows,num_cols=origins.shape[:2]
+ # maximum initial level possible for the terrains
+ ifself.cfg.max_init_terrain_levelisNone:
+ max_init_level=num_rows-1
+ else:
+ max_init_level=min(self.cfg.max_init_terrain_level,num_rows-1)
+ # store maximum terrain level possible
+ self.max_terrain_level=num_rows
+ # define all terrain levels and types available
+ self.terrain_levels=torch.randint(0,max_init_level+1,(num_envs,),device=self.device)
+ self.terrain_types=torch.div(
+ torch.arange(num_envs,device=self.device),
+ (num_envs/num_cols),
+ rounding_mode="floor",
+ ).to(torch.long)
+ # create tensor based on number of environments
+ env_origins=torch.zeros(num_envs,3,device=self.device)
+ env_origins[:]=origins[self.terrain_levels,self.terrain_types]
+ returnenv_origins
+
+ def_compute_env_origins_grid(self,num_envs:int,env_spacing:float)->torch.Tensor:
+"""Compute the origins of the environments in a grid based on configured spacing."""
+ # create tensor based on number of environments
+ env_origins=torch.zeros(num_envs,3,device=self.device)
+ # create a grid of origins
+ num_rows=np.ceil(num_envs/int(np.sqrt(num_envs)))
+ num_cols=np.ceil(num_envs/num_rows)
+ ii,jj=torch.meshgrid(
+ torch.arange(num_rows,device=self.device),torch.arange(num_cols,device=self.device),indexing="ij"
+ )
+ env_origins[:,0]=-(ii.flatten()[:num_envs]-(num_rows-1)/2)*env_spacing
+ env_origins[:,1]=(jj.flatten()[:num_envs]-(num_cols-1)/2)*env_spacing
+ env_origins[:,2]=0.0
+ returnenv_origins
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+fromdataclassesimportMISSING
+fromtypingimportTYPE_CHECKING,Literal
+
+importomni.isaac.lab.simassim_utils
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.terrain_importerimportTerrainImporter
+
+ifTYPE_CHECKING:
+ from.terrain_generator_cfgimportTerrainGeneratorCfg
+
+
+
[文档]@configclass
+classTerrainImporterCfg:
+"""Configuration for the terrain manager."""
+
+ class_type:type=TerrainImporter
+"""The class to use for the terrain importer.
+
+ Defaults to :class:`omni.isaac.lab.terrains.terrain_importer.TerrainImporter`.
+ """
+
+ collision_group:int=-1
+"""The collision group of the terrain. Defaults to -1."""
+
+ prim_path:str=MISSING
+"""The absolute path of the USD terrain prim.
+
+ All sub-terrains are imported relative to this prim path.
+ """
+
+ num_envs:int=MISSING
+"""The number of environment origins to consider."""
+
+ terrain_type:Literal["generator","plane","usd"]="generator"
+"""The type of terrain to generate. Defaults to "generator".
+
+ Available options are "plane", "usd", and "generator".
+ """
+
+ terrain_generator:TerrainGeneratorCfg|None=None
+"""The terrain generator configuration.
+
+ Only used if ``terrain_type`` is set to "generator".
+ """
+
+ usd_path:str|None=None
+"""The path to the USD file containing the terrain.
+
+ Only used if ``terrain_type`` is set to "usd".
+ """
+
+ env_spacing:float|None=None
+"""The spacing between environment origins when defined in a grid. Defaults to None.
+
+ Note:
+ This parameter is used only when the ``terrain_type`` is ``"plane"`` or ``"usd"``.
+ """
+
+ visual_material:sim_utils.VisualMaterialCfg|None=sim_utils.PreviewSurfaceCfg(
+ diffuse_color=(0.065,0.0725,0.080)
+ )
+"""The visual material of the terrain. Defaults to a dark gray color material.
+
+ The material is created at the path: ``{prim_path}/visualMaterial``. If `None`, then no material is created.
+
+ .. note::
+ This parameter is used only when the ``terrain_type`` is ``"generator"``.
+ """
+
+ physics_material:sim_utils.RigidBodyMaterialCfg=sim_utils.RigidBodyMaterialCfg()
+"""The physics material of the terrain. Defaults to a default physics material.
+
+ The material is created at the path: ``{prim_path}/physicsMaterial``.
+
+ .. note::
+ This parameter is used only when the ``terrain_type`` is ``"generator"`` or ``"plane"``.
+ """
+
+ max_init_terrain_level:int|None=None
+"""The maximum initial terrain level for defining environment origins. Defaults to None.
+
+ The terrain levels are specified by the number of rows in the grid arrangement of
+ sub-terrains. If None, then the initial terrain level is set to the maximum
+ terrain level available (``num_rows - 1``).
+
+ Note:
+ This parameter is used only when sub-terrain origins are defined.
+ """
+
+ debug_vis:bool=False
+"""Whether to enable visualization of terrain origins for the terrain. Defaults to False."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Functions to generate different terrains using the ``trimesh`` library."""
+
+from__future__importannotations
+
+importnumpyasnp
+importscipy.spatial.transformastf
+importtorch
+importtrimesh
+fromtypingimportTYPE_CHECKING
+
+from.utilsimport*# noqa: F401, F403
+from.utilsimportmake_border,make_plane
+
+ifTYPE_CHECKING:
+ from.importmesh_terrains_cfg
+
+
+
[文档]defflat_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshPlaneTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a flat terrain as a plane.
+
+ .. image:: ../../_static/terrains/trimesh/flat_terrain.jpg
+ :width: 45%
+ :align: center
+
+ Note:
+ The :obj:`difficulty` parameter is ignored for this terrain.
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+ """
+ # compute the position of the terrain
+ origin=(cfg.size[0]/2.0,cfg.size[1]/2.0,0.0)
+ # compute the vertices of the terrain
+ plane_mesh=make_plane(cfg.size,0.0,center_zero=False)
+ # return the tri-mesh and the position
+ return[plane_mesh],np.array(origin)
+
+
+
[文档]defpyramid_stairs_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshPyramidStairsTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a terrain with a pyramid stair pattern.
+
+ The terrain is a pyramid stair pattern which trims to a flat platform at the center of the terrain.
+
+ If :obj:`cfg.holes` is True, the terrain will have pyramid stairs of length or width
+ :obj:`cfg.platform_width` (depending on the direction) with no steps in the remaining area. Additionally,
+ no border will be added.
+
+ .. image:: ../../_static/terrains/trimesh/pyramid_stairs_terrain.jpg
+ :width: 45%
+
+ .. image:: ../../_static/terrains/trimesh/pyramid_stairs_terrain_with_holes.jpg
+ :width: 45%
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+ """
+ # resolve the terrain configuration
+ step_height=cfg.step_height_range[0]+difficulty*(cfg.step_height_range[1]-cfg.step_height_range[0])
+
+ # compute number of steps in x and y direction
+ num_steps_x=(cfg.size[0]-2*cfg.border_width-cfg.platform_width)//(2*cfg.step_width)+1
+ num_steps_y=(cfg.size[1]-2*cfg.border_width-cfg.platform_width)//(2*cfg.step_width)+1
+ # we take the minimum number of steps in x and y direction
+ num_steps=int(min(num_steps_x,num_steps_y))
+
+ # initialize list of meshes
+ meshes_list=list()
+
+ # generate the border if needed
+ ifcfg.border_width>0.0andnotcfg.holes:
+ # obtain a list of meshes for the border
+ border_center=[0.5*cfg.size[0],0.5*cfg.size[1],-step_height/2]
+ border_inner_size=(cfg.size[0]-2*cfg.border_width,cfg.size[1]-2*cfg.border_width)
+ make_borders=make_border(cfg.size,border_inner_size,step_height,border_center)
+ # add the border meshes to the list of meshes
+ meshes_list+=make_borders
+
+ # generate the terrain
+ # -- compute the position of the center of the terrain
+ terrain_center=[0.5*cfg.size[0],0.5*cfg.size[1],0.0]
+ terrain_size=(cfg.size[0]-2*cfg.border_width,cfg.size[1]-2*cfg.border_width)
+ # -- generate the stair pattern
+ forkinrange(num_steps):
+ # check if we need to add holes around the steps
+ ifcfg.holes:
+ box_size=(cfg.platform_width,cfg.platform_width)
+ else:
+ box_size=(terrain_size[0]-2*k*cfg.step_width,terrain_size[1]-2*k*cfg.step_width)
+ # compute the quantities of the box
+ # -- location
+ box_z=terrain_center[2]+k*step_height/2.0
+ box_offset=(k+0.5)*cfg.step_width
+ # -- dimensions
+ box_height=(k+2)*step_height
+ # generate the boxes
+ # top/bottom
+ box_dims=(box_size[0],cfg.step_width,box_height)
+ # -- top
+ box_pos=(terrain_center[0],terrain_center[1]+terrain_size[1]/2.0-box_offset,box_z)
+ box_top=trimesh.creation.box(box_dims,trimesh.transformations.translation_matrix(box_pos))
+ # -- bottom
+ box_pos=(terrain_center[0],terrain_center[1]-terrain_size[1]/2.0+box_offset,box_z)
+ box_bottom=trimesh.creation.box(box_dims,trimesh.transformations.translation_matrix(box_pos))
+ # right/left
+ ifcfg.holes:
+ box_dims=(cfg.step_width,box_size[1],box_height)
+ else:
+ box_dims=(cfg.step_width,box_size[1]-2*cfg.step_width,box_height)
+ # -- right
+ box_pos=(terrain_center[0]+terrain_size[0]/2.0-box_offset,terrain_center[1],box_z)
+ box_right=trimesh.creation.box(box_dims,trimesh.transformations.translation_matrix(box_pos))
+ # -- left
+ box_pos=(terrain_center[0]-terrain_size[0]/2.0+box_offset,terrain_center[1],box_z)
+ box_left=trimesh.creation.box(box_dims,trimesh.transformations.translation_matrix(box_pos))
+ # add the boxes to the list of meshes
+ meshes_list+=[box_top,box_bottom,box_right,box_left]
+
+ # generate final box for the middle of the terrain
+ box_dims=(
+ terrain_size[0]-2*num_steps*cfg.step_width,
+ terrain_size[1]-2*num_steps*cfg.step_width,
+ (num_steps+2)*step_height,
+ )
+ box_pos=(terrain_center[0],terrain_center[1],terrain_center[2]+num_steps*step_height/2)
+ box_middle=trimesh.creation.box(box_dims,trimesh.transformations.translation_matrix(box_pos))
+ meshes_list.append(box_middle)
+ # origin of the terrain
+ origin=np.array([terrain_center[0],terrain_center[1],(num_steps+1)*step_height])
+
+ returnmeshes_list,origin
+
+
+
[文档]definverted_pyramid_stairs_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshInvertedPyramidStairsTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a terrain with a inverted pyramid stair pattern.
+
+ The terrain is an inverted pyramid stair pattern which trims to a flat platform at the center of the terrain.
+
+ If :obj:`cfg.holes` is True, the terrain will have pyramid stairs of length or width
+ :obj:`cfg.platform_width` (depending on the direction) with no steps in the remaining area. Additionally,
+ no border will be added.
+
+ .. image:: ../../_static/terrains/trimesh/inverted_pyramid_stairs_terrain.jpg
+ :width: 45%
+
+ .. image:: ../../_static/terrains/trimesh/inverted_pyramid_stairs_terrain_with_holes.jpg
+ :width: 45%
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+ """
+ # resolve the terrain configuration
+ step_height=cfg.step_height_range[0]+difficulty*(cfg.step_height_range[1]-cfg.step_height_range[0])
+
+ # compute number of steps in x and y direction
+ num_steps_x=(cfg.size[0]-2*cfg.border_width-cfg.platform_width)//(2*cfg.step_width)+1
+ num_steps_y=(cfg.size[1]-2*cfg.border_width-cfg.platform_width)//(2*cfg.step_width)+1
+ # we take the minimum number of steps in x and y direction
+ num_steps=int(min(num_steps_x,num_steps_y))
+ # total height of the terrain
+ total_height=(num_steps+1)*step_height
+
+ # initialize list of meshes
+ meshes_list=list()
+
+ # generate the border if needed
+ ifcfg.border_width>0.0andnotcfg.holes:
+ # obtain a list of meshes for the border
+ border_center=[0.5*cfg.size[0],0.5*cfg.size[1],-0.5*step_height]
+ border_inner_size=(cfg.size[0]-2*cfg.border_width,cfg.size[1]-2*cfg.border_width)
+ make_borders=make_border(cfg.size,border_inner_size,step_height,border_center)
+ # add the border meshes to the list of meshes
+ meshes_list+=make_borders
+ # generate the terrain
+ # -- compute the position of the center of the terrain
+ terrain_center=[0.5*cfg.size[0],0.5*cfg.size[1],0.0]
+ terrain_size=(cfg.size[0]-2*cfg.border_width,cfg.size[1]-2*cfg.border_width)
+ # -- generate the stair pattern
+ forkinrange(num_steps):
+ # check if we need to add holes around the steps
+ ifcfg.holes:
+ box_size=(cfg.platform_width,cfg.platform_width)
+ else:
+ box_size=(terrain_size[0]-2*k*cfg.step_width,terrain_size[1]-2*k*cfg.step_width)
+ # compute the quantities of the box
+ # -- location
+ box_z=terrain_center[2]-total_height/2-(k+1)*step_height/2.0
+ box_offset=(k+0.5)*cfg.step_width
+ # -- dimensions
+ box_height=total_height-(k+1)*step_height
+ # generate the boxes
+ # top/bottom
+ box_dims=(box_size[0],cfg.step_width,box_height)
+ # -- top
+ box_pos=(terrain_center[0],terrain_center[1]+terrain_size[1]/2.0-box_offset,box_z)
+ box_top=trimesh.creation.box(box_dims,trimesh.transformations.translation_matrix(box_pos))
+ # -- bottom
+ box_pos=(terrain_center[0],terrain_center[1]-terrain_size[1]/2.0+box_offset,box_z)
+ box_bottom=trimesh.creation.box(box_dims,trimesh.transformations.translation_matrix(box_pos))
+ # right/left
+ ifcfg.holes:
+ box_dims=(cfg.step_width,box_size[1],box_height)
+ else:
+ box_dims=(cfg.step_width,box_size[1]-2*cfg.step_width,box_height)
+ # -- right
+ box_pos=(terrain_center[0]+terrain_size[0]/2.0-box_offset,terrain_center[1],box_z)
+ box_right=trimesh.creation.box(box_dims,trimesh.transformations.translation_matrix(box_pos))
+ # -- left
+ box_pos=(terrain_center[0]-terrain_size[0]/2.0+box_offset,terrain_center[1],box_z)
+ box_left=trimesh.creation.box(box_dims,trimesh.transformations.translation_matrix(box_pos))
+ # add the boxes to the list of meshes
+ meshes_list+=[box_top,box_bottom,box_right,box_left]
+ # generate final box for the middle of the terrain
+ box_dims=(
+ terrain_size[0]-2*num_steps*cfg.step_width,
+ terrain_size[1]-2*num_steps*cfg.step_width,
+ step_height,
+ )
+ box_pos=(terrain_center[0],terrain_center[1],terrain_center[2]-total_height-step_height/2)
+ box_middle=trimesh.creation.box(box_dims,trimesh.transformations.translation_matrix(box_pos))
+ meshes_list.append(box_middle)
+ # origin of the terrain
+ origin=np.array([terrain_center[0],terrain_center[1],-(num_steps+1)*step_height])
+
+ returnmeshes_list,origin
+
+
+
[文档]defrandom_grid_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshRandomGridTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a terrain with cells of random heights and fixed width.
+
+ The terrain is generated in the x-y plane and has a height of 1.0. It is then divided into a grid of the
+ specified size :obj:`cfg.grid_width`. Each grid cell is then randomly shifted in the z-direction by a value uniformly
+ sampled between :obj:`cfg.grid_height_range`. At the center of the terrain, a platform of the specified width
+ :obj:`cfg.platform_width` is generated.
+
+ If :obj:`cfg.holes` is True, the terrain will have randomized grid cells only along the plane extending
+ from the platform (like a plus sign). The remaining area remains empty and no border will be added.
+
+ .. image:: ../../_static/terrains/trimesh/random_grid_terrain.jpg
+ :width: 45%
+
+ .. image:: ../../_static/terrains/trimesh/random_grid_terrain_with_holes.jpg
+ :width: 45%
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+
+ Raises:
+ ValueError: If the terrain is not square. This method only supports square terrains.
+ RuntimeError: If the grid width is large such that the border width is negative.
+ """
+ # check to ensure square terrain
+ ifcfg.size[0]!=cfg.size[1]:
+ raiseValueError(f"The terrain must be square. Received size: {cfg.size}.")
+ # resolve the terrain configuration
+ grid_height=cfg.grid_height_range[0]+difficulty*(cfg.grid_height_range[1]-cfg.grid_height_range[0])
+
+ # initialize list of meshes
+ meshes_list=list()
+ # compute the number of boxes in each direction
+ num_boxes_x=int(cfg.size[0]/cfg.grid_width)
+ num_boxes_y=int(cfg.size[1]/cfg.grid_width)
+ # constant parameters
+ terrain_height=1.0
+ device=torch.device("cuda")iftorch.cuda.is_available()elsetorch.device("cpu")
+
+ # generate the border
+ border_width=cfg.size[0]-min(num_boxes_x,num_boxes_y)*cfg.grid_width
+ ifborder_width>0:
+ # compute parameters for the border
+ border_center=(0.5*cfg.size[0],0.5*cfg.size[1],-terrain_height/2)
+ border_inner_size=(cfg.size[0]-border_width,cfg.size[1]-border_width)
+ # create border meshes
+ make_borders=make_border(cfg.size,border_inner_size,terrain_height,border_center)
+ meshes_list+=make_borders
+ else:
+ raiseRuntimeError("Border width must be greater than 0! Adjust the parameter 'cfg.grid_width'.")
+
+ # create a template grid of terrain height
+ grid_dim=[cfg.grid_width,cfg.grid_width,terrain_height]
+ grid_position=[0.5*cfg.grid_width,0.5*cfg.grid_width,-terrain_height/2]
+ template_box=trimesh.creation.box(grid_dim,trimesh.transformations.translation_matrix(grid_position))
+ # extract vertices and faces of the box to create a template
+ template_vertices=template_box.vertices# (8, 3)
+ template_faces=template_box.faces
+
+ # repeat the template box vertices to span the terrain (num_boxes_x * num_boxes_y, 8, 3)
+ vertices=torch.tensor(template_vertices,device=device).repeat(num_boxes_x*num_boxes_y,1,1)
+ # create a meshgrid to offset the vertices
+ x=torch.arange(0,num_boxes_x,device=device)
+ y=torch.arange(0,num_boxes_y,device=device)
+ xx,yy=torch.meshgrid(x,y,indexing="ij")
+ xx=xx.flatten().view(-1,1)
+ yy=yy.flatten().view(-1,1)
+ xx_yy=torch.cat((xx,yy),dim=1)
+ # offset the vertices
+ offsets=cfg.grid_width*xx_yy+border_width/2
+ vertices[:,:,:2]+=offsets.unsqueeze(1)
+ # mask the vertices to create holes, s.t. only grids along the x and y axis are present
+ ifcfg.holes:
+ # -- x-axis
+ mask_x=torch.logical_and(
+ (vertices[:,:,0]>(cfg.size[0]-border_width-cfg.platform_width)/2).all(dim=1),
+ (vertices[:,:,0]<(cfg.size[0]+border_width+cfg.platform_width)/2).all(dim=1),
+ )
+ vertices_x=vertices[mask_x]
+ # -- y-axis
+ mask_y=torch.logical_and(
+ (vertices[:,:,1]>(cfg.size[1]-border_width-cfg.platform_width)/2).all(dim=1),
+ (vertices[:,:,1]<(cfg.size[1]+border_width+cfg.platform_width)/2).all(dim=1),
+ )
+ vertices_y=vertices[mask_y]
+ # -- combine these vertices
+ vertices=torch.cat((vertices_x,vertices_y))
+ # add noise to the vertices to have a random height over each grid cell
+ num_boxes=len(vertices)
+ # create noise for the z-axis
+ h_noise=torch.zeros((num_boxes,3),device=device)
+ h_noise[:,2].uniform_(-grid_height,grid_height)
+ # reshape noise to match the vertices (num_boxes, 4, 3)
+ # only the top vertices of the box are affected
+ vertices_noise=torch.zeros((num_boxes,4,3),device=device)
+ vertices_noise+=h_noise.unsqueeze(1)
+ # add height only to the top vertices of the box
+ vertices[vertices[:,:,2]==0]+=vertices_noise.view(-1,3)
+ # move to numpy
+ vertices=vertices.reshape(-1,3).cpu().numpy()
+
+ # create faces for boxes (num_boxes, 12, 3). Each box has 6 faces, each face has 2 triangles.
+ faces=torch.tensor(template_faces,device=device).repeat(num_boxes,1,1)
+ face_offsets=torch.arange(0,num_boxes,device=device).unsqueeze(1).repeat(1,12)*8
+ faces+=face_offsets.unsqueeze(2)
+ # move to numpy
+ faces=faces.view(-1,3).cpu().numpy()
+ # convert to trimesh
+ grid_mesh=trimesh.Trimesh(vertices=vertices,faces=faces)
+ meshes_list.append(grid_mesh)
+
+ # add a platform in the center of the terrain that is accessible from all sides
+ dim=(cfg.platform_width,cfg.platform_width,terrain_height+grid_height)
+ pos=(0.5*cfg.size[0],0.5*cfg.size[1],-terrain_height/2+grid_height/2)
+ box_platform=trimesh.creation.box(dim,trimesh.transformations.translation_matrix(pos))
+ meshes_list.append(box_platform)
+
+ # specify the origin of the terrain
+ origin=np.array([0.5*cfg.size[0],0.5*cfg.size[1],grid_height])
+
+ returnmeshes_list,origin
+
+
+
[文档]defrails_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshRailsTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a terrain with box rails as extrusions.
+
+ The terrain contains two sets of box rails created as extrusions. The first set (inner rails) is extruded from
+ the platform at the center of the terrain, and the second set is extruded between the first set of rails
+ and the terrain border. Each set of rails is extruded to the same height.
+
+ .. image:: ../../_static/terrains/trimesh/rails_terrain.jpg
+ :width: 40%
+ :align: center
+
+ Args:
+ difficulty: The difficulty of the terrain. this is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+ """
+ # resolve the terrain configuration
+ rail_height=cfg.rail_height_range[1]-difficulty*(cfg.rail_height_range[1]-cfg.rail_height_range[0])
+
+ # initialize list of meshes
+ meshes_list=list()
+ # extract quantities
+ rail_1_thickness,rail_2_thickness=cfg.rail_thickness_range
+ rail_center=(0.5*cfg.size[0],0.5*cfg.size[1],rail_height*0.5)
+ # constants for terrain generation
+ terrain_height=1.0
+ rail_2_ratio=0.6
+
+ # generate first set of rails
+ rail_1_inner_size=(cfg.platform_width,cfg.platform_width)
+ rail_1_outer_size=(cfg.platform_width+2.0*rail_1_thickness,cfg.platform_width+2.0*rail_1_thickness)
+ meshes_list+=make_border(rail_1_outer_size,rail_1_inner_size,rail_height,rail_center)
+ # generate second set of rails
+ rail_2_inner_x=cfg.platform_width+(cfg.size[0]-cfg.platform_width)*rail_2_ratio
+ rail_2_inner_y=cfg.platform_width+(cfg.size[1]-cfg.platform_width)*rail_2_ratio
+ rail_2_inner_size=(rail_2_inner_x,rail_2_inner_y)
+ rail_2_outer_size=(rail_2_inner_x+2.0*rail_2_thickness,rail_2_inner_y+2.0*rail_2_thickness)
+ meshes_list+=make_border(rail_2_outer_size,rail_2_inner_size,rail_height,rail_center)
+ # generate the ground
+ dim=(cfg.size[0],cfg.size[1],terrain_height)
+ pos=(0.5*cfg.size[0],0.5*cfg.size[1],-terrain_height/2)
+ ground_meshes=trimesh.creation.box(dim,trimesh.transformations.translation_matrix(pos))
+ meshes_list.append(ground_meshes)
+
+ # specify the origin of the terrain
+ origin=np.array([pos[0],pos[1],0.0])
+
+ returnmeshes_list,origin
+
+
+
[文档]defpit_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshPitTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a terrain with a pit with levels (stairs) leading out of the pit.
+
+ The terrain contains a platform at the center and a staircase leading out of the pit.
+ The staircase is a series of steps that are aligned along the x- and y- axis. The steps are
+ created by extruding a ring along the x- and y- axis. If :obj:`is_double_pit` is True, the pit
+ contains two levels.
+
+ .. image:: ../../_static/terrains/trimesh/pit_terrain.jpg
+ :width: 40%
+
+ .. image:: ../../_static/terrains/trimesh/pit_terrain_with_two_levels.jpg
+ :width: 40%
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+ """
+ # resolve the terrain configuration
+ pit_depth=cfg.pit_depth_range[0]+difficulty*(cfg.pit_depth_range[1]-cfg.pit_depth_range[0])
+
+ # initialize list of meshes
+ meshes_list=list()
+ # extract quantities
+ inner_pit_size=(cfg.platform_width,cfg.platform_width)
+ total_depth=pit_depth
+ # constants for terrain generation
+ terrain_height=1.0
+ ring_2_ratio=0.6
+
+ # if the pit is double, the inner ring is smaller to fit the second level
+ ifcfg.double_pit:
+ # increase the total height of the pit
+ total_depth*=2.0
+ # reduce the size of the inner ring
+ inner_pit_x=cfg.platform_width+(cfg.size[0]-cfg.platform_width)*ring_2_ratio
+ inner_pit_y=cfg.platform_width+(cfg.size[1]-cfg.platform_width)*ring_2_ratio
+ inner_pit_size=(inner_pit_x,inner_pit_y)
+
+ # generate the pit (outer ring)
+ pit_center=[0.5*cfg.size[0],0.5*cfg.size[1],-total_depth*0.5]
+ meshes_list+=make_border(cfg.size,inner_pit_size,total_depth,pit_center)
+ # generate the second level of the pit (inner ring)
+ ifcfg.double_pit:
+ pit_center[2]=-total_depth
+ meshes_list+=make_border(inner_pit_size,(cfg.platform_width,cfg.platform_width),total_depth,pit_center)
+ # generate the ground
+ dim=(cfg.size[0],cfg.size[1],terrain_height)
+ pos=(0.5*cfg.size[0],0.5*cfg.size[1],-total_depth-terrain_height/2)
+ ground_meshes=trimesh.creation.box(dim,trimesh.transformations.translation_matrix(pos))
+ meshes_list.append(ground_meshes)
+
+ # specify the origin of the terrain
+ origin=np.array([pos[0],pos[1],-total_depth])
+
+ returnmeshes_list,origin
+
+
+
[文档]defbox_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshBoxTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a terrain with boxes (similar to a pyramid).
+
+ The terrain has a ground with boxes on top of it that are stacked on top of each other.
+ The boxes are created by extruding a rectangle along the z-axis. If :obj:`double_box` is True,
+ then two boxes of height :obj:`box_height` are stacked on top of each other.
+
+ .. image:: ../../_static/terrains/trimesh/box_terrain.jpg
+ :width: 40%
+
+ .. image:: ../../_static/terrains/trimesh/box_terrain_with_two_boxes.jpg
+ :width: 40%
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+ """
+ # resolve the terrain configuration
+ box_height=cfg.box_height_range[0]+difficulty*(cfg.box_height_range[1]-cfg.box_height_range[0])
+
+ # initialize list of meshes
+ meshes_list=list()
+ # extract quantities
+ total_height=box_height
+ ifcfg.double_box:
+ total_height*=2.0
+ # constants for terrain generation
+ terrain_height=1.0
+ box_2_ratio=0.6
+
+ # Generate the top box
+ dim=(cfg.platform_width,cfg.platform_width,terrain_height+total_height)
+ pos=(0.5*cfg.size[0],0.5*cfg.size[1],(total_height-terrain_height)/2)
+ box_mesh=trimesh.creation.box(dim,trimesh.transformations.translation_matrix(pos))
+ meshes_list.append(box_mesh)
+ # Generate the lower box
+ ifcfg.double_box:
+ # calculate the size of the lower box
+ outer_box_x=cfg.platform_width+(cfg.size[0]-cfg.platform_width)*box_2_ratio
+ outer_box_y=cfg.platform_width+(cfg.size[1]-cfg.platform_width)*box_2_ratio
+ # create the lower box
+ dim=(outer_box_x,outer_box_y,terrain_height+total_height/2)
+ pos=(0.5*cfg.size[0],0.5*cfg.size[1],(total_height-terrain_height)/2-total_height/4)
+ box_mesh=trimesh.creation.box(dim,trimesh.transformations.translation_matrix(pos))
+ meshes_list.append(box_mesh)
+ # Generate the ground
+ pos=(0.5*cfg.size[0],0.5*cfg.size[1],-terrain_height/2)
+ dim=(cfg.size[0],cfg.size[1],terrain_height)
+ ground_mesh=trimesh.creation.box(dim,trimesh.transformations.translation_matrix(pos))
+ meshes_list.append(ground_mesh)
+
+ # specify the origin of the terrain
+ origin=np.array([pos[0],pos[1],total_height])
+
+ returnmeshes_list,origin
+
+
+
[文档]defgap_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshGapTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a terrain with a gap around the platform.
+
+ The terrain has a ground with a platform in the middle. The platform is surrounded by a gap
+ of width :obj:`gap_width` on all sides.
+
+ .. image:: ../../_static/terrains/trimesh/gap_terrain.jpg
+ :width: 40%
+ :align: center
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+ """
+ # resolve the terrain configuration
+ gap_width=cfg.gap_width_range[0]+difficulty*(cfg.gap_width_range[1]-cfg.gap_width_range[0])
+
+ # initialize list of meshes
+ meshes_list=list()
+ # constants for terrain generation
+ terrain_height=1.0
+ terrain_center=(0.5*cfg.size[0],0.5*cfg.size[1],-terrain_height/2)
+
+ # Generate the outer ring
+ inner_size=(cfg.platform_width+2*gap_width,cfg.platform_width+2*gap_width)
+ meshes_list+=make_border(cfg.size,inner_size,terrain_height,terrain_center)
+ # Generate the inner box
+ box_dim=(cfg.platform_width,cfg.platform_width,terrain_height)
+ box=trimesh.creation.box(box_dim,trimesh.transformations.translation_matrix(terrain_center))
+ meshes_list.append(box)
+
+ # specify the origin of the terrain
+ origin=np.array([terrain_center[0],terrain_center[1],0.0])
+
+ returnmeshes_list,origin
+
+
+
[文档]deffloating_ring_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshFloatingRingTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a terrain with a floating square ring.
+
+ The terrain has a ground with a floating ring in the middle. The ring extends from the center from
+ :obj:`platform_width` to :obj:`platform_width` + :obj:`ring_width` in the x and y directions.
+ The thickness of the ring is :obj:`ring_thickness` and the height of the ring from the terrain
+ is :obj:`ring_height`.
+
+ .. image:: ../../_static/terrains/trimesh/floating_ring_terrain.jpg
+ :width: 40%
+ :align: center
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+ """
+ # resolve the terrain configuration
+ ring_height=cfg.ring_height_range[1]-difficulty*(cfg.ring_height_range[1]-cfg.ring_height_range[0])
+ ring_width=cfg.ring_width_range[0]+difficulty*(cfg.ring_width_range[1]-cfg.ring_width_range[0])
+
+ # initialize list of meshes
+ meshes_list=list()
+ # constants for terrain generation
+ terrain_height=1.0
+
+ # Generate the floating ring
+ ring_center=(0.5*cfg.size[0],0.5*cfg.size[1],ring_height+0.5*cfg.ring_thickness)
+ ring_outer_size=(cfg.platform_width+2*ring_width,cfg.platform_width+2*ring_width)
+ ring_inner_size=(cfg.platform_width,cfg.platform_width)
+ meshes_list+=make_border(ring_outer_size,ring_inner_size,cfg.ring_thickness,ring_center)
+ # Generate the ground
+ dim=(cfg.size[0],cfg.size[1],terrain_height)
+ pos=(0.5*cfg.size[0],0.5*cfg.size[1],-terrain_height/2)
+ ground=trimesh.creation.box(dim,trimesh.transformations.translation_matrix(pos))
+ meshes_list.append(ground)
+
+ # specify the origin of the terrain
+ origin=np.asarray([pos[0],pos[1],0.0])
+
+ returnmeshes_list,origin
+
+
+
[文档]defstar_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshStarTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a terrain with a star.
+
+ The terrain has a ground with a cylinder in the middle. The star is made of :obj:`num_bars` bars
+ with a width of :obj:`bar_width` and a height of :obj:`bar_height`. The bars are evenly
+ spaced around the cylinder and connect to the peripheral of the terrain.
+
+ .. image:: ../../_static/terrains/trimesh/star_terrain.jpg
+ :width: 40%
+ :align: center
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+
+ Raises:
+ ValueError: If :obj:`num_bars` is less than 2.
+ """
+ # check the number of bars
+ ifcfg.num_bars<2:
+ raiseValueError(f"The number of bars in the star must be greater than 2. Received: {cfg.num_bars}")
+
+ # resolve the terrain configuration
+ bar_height=cfg.bar_height_range[0]+difficulty*(cfg.bar_height_range[1]-cfg.bar_height_range[0])
+ bar_width=cfg.bar_width_range[1]-difficulty*(cfg.bar_width_range[1]-cfg.bar_width_range[0])
+
+ # initialize list of meshes
+ meshes_list=list()
+ # Generate a platform in the middle
+ platform_center=(0.5*cfg.size[0],0.5*cfg.size[1],-bar_height/2)
+ platform_transform=trimesh.transformations.translation_matrix(platform_center)
+ platform=trimesh.creation.cylinder(
+ cfg.platform_width*0.5,bar_height,sections=2*cfg.num_bars,transform=platform_transform
+ )
+ meshes_list.append(platform)
+ # Generate bars to connect the platform to the terrain
+ transform=np.eye(4)
+ transform[:3,-1]=np.asarray(platform_center)
+ yaw=0.0
+ for_inrange(cfg.num_bars):
+ # compute the length of the bar based on the yaw
+ # length changes since the bar is connected to a square border
+ bar_length=cfg.size[0]
+ ifyaw<0.25*np.pi:
+ bar_length/=np.math.cos(yaw)
+ elifyaw<0.75*np.pi:
+ bar_length/=np.math.sin(yaw)
+ else:
+ bar_length/=np.math.cos(np.pi-yaw)
+ # compute the transform of the bar
+ transform[0:3,0:3]=tf.Rotation.from_euler("z",yaw).as_matrix()
+ # add the bar to the mesh
+ dim=[bar_length-bar_width,bar_width,bar_height]
+ bar=trimesh.creation.box(dim,transform)
+ meshes_list.append(bar)
+ # increment the yaw
+ yaw+=np.pi/cfg.num_bars
+ # Generate the exterior border
+ inner_size=(cfg.size[0]-2*bar_width,cfg.size[1]-2*bar_width)
+ meshes_list+=make_border(cfg.size,inner_size,bar_height,platform_center)
+ # Generate the ground
+ ground=make_plane(cfg.size,-bar_height,center_zero=False)
+ meshes_list.append(ground)
+ # specify the origin of the terrain
+ origin=np.asarray([0.5*cfg.size[0],0.5*cfg.size[1],0.0])
+
+ returnmeshes_list,origin
+
+
+
[文档]defrepeated_objects_terrain(
+ difficulty:float,cfg:mesh_terrains_cfg.MeshRepeatedObjectsTerrainCfg
+)->tuple[list[trimesh.Trimesh],np.ndarray]:
+"""Generate a terrain with a set of repeated objects.
+
+ The terrain has a ground with a platform in the middle. The objects are randomly placed on the
+ terrain s.t. they do not overlap with the platform.
+
+ Depending on the object type, the objects are generated with different parameters. The objects
+ The types of objects that can be generated are: ``"cylinder"``, ``"box"``, ``"cone"``.
+
+ The object parameters are specified in the configuration as curriculum parameters. The difficulty
+ is used to linearly interpolate between the minimum and maximum values of the parameters.
+
+ .. image:: ../../_static/terrains/trimesh/repeated_objects_cylinder_terrain.jpg
+ :width: 30%
+
+ .. image:: ../../_static/terrains/trimesh/repeated_objects_box_terrain.jpg
+ :width: 30%
+
+ .. image:: ../../_static/terrains/trimesh/repeated_objects_pyramid_terrain.jpg
+ :width: 30%
+
+ Args:
+ difficulty: The difficulty of the terrain. This is a value between 0 and 1.
+ cfg: The configuration for the terrain.
+
+ Returns:
+ A tuple containing the tri-mesh of the terrain and the origin of the terrain (in m).
+
+ Raises:
+ ValueError: If the object type is not supported. It must be either a string or a callable.
+ """
+ # import the object functions -- this is done here to avoid circular imports
+ from.mesh_terrains_cfgimport(
+ MeshRepeatedBoxesTerrainCfg,
+ MeshRepeatedCylindersTerrainCfg,
+ MeshRepeatedPyramidsTerrainCfg,
+ )
+
+ # if object type is a string, get the function: make_{object_type}
+ ifisinstance(cfg.object_type,str):
+ object_func=globals().get(f"make_{cfg.object_type}")
+ else:
+ object_func=cfg.object_type
+ ifnotcallable(object_func):
+ raiseValueError(f"The attribute 'object_type' must be a string or a callable. Received: {object_func}")
+
+ # Resolve the terrain configuration
+ # -- pass parameters to make calling simpler
+ cp_0=cfg.object_params_start
+ cp_1=cfg.object_params_end
+ # -- common parameters
+ num_objects=cp_0.num_objects+int(difficulty*(cp_1.num_objects-cp_0.num_objects))
+ height=cp_0.height+difficulty*(cp_1.height-cp_0.height)
+ # -- object specific parameters
+ # note: SIM114 requires duplicated logical blocks under a single body.
+ ifisinstance(cfg,MeshRepeatedBoxesTerrainCfg):
+ cp_0:MeshRepeatedBoxesTerrainCfg.ObjectCfg
+ cp_1:MeshRepeatedBoxesTerrainCfg.ObjectCfg
+ object_kwargs={
+ "length":cp_0.size[0]+difficulty*(cp_1.size[0]-cp_0.size[0]),
+ "width":cp_0.size[1]+difficulty*(cp_1.size[1]-cp_0.size[1]),
+ "max_yx_angle":cp_0.max_yx_angle+difficulty*(cp_1.max_yx_angle-cp_0.max_yx_angle),
+ "degrees":cp_0.degrees,
+ }
+ elifisinstance(cfg,MeshRepeatedPyramidsTerrainCfg):# noqa: SIM114
+ cp_0:MeshRepeatedPyramidsTerrainCfg.ObjectCfg
+ cp_1:MeshRepeatedPyramidsTerrainCfg.ObjectCfg
+ object_kwargs={
+ "radius":cp_0.radius+difficulty*(cp_1.radius-cp_0.radius),
+ "max_yx_angle":cp_0.max_yx_angle+difficulty*(cp_1.max_yx_angle-cp_0.max_yx_angle),
+ "degrees":cp_0.degrees,
+ }
+ elifisinstance(cfg,MeshRepeatedCylindersTerrainCfg):# noqa: SIM114
+ cp_0:MeshRepeatedCylindersTerrainCfg.ObjectCfg
+ cp_1:MeshRepeatedCylindersTerrainCfg.ObjectCfg
+ object_kwargs={
+ "radius":cp_0.radius+difficulty*(cp_1.radius-cp_0.radius),
+ "max_yx_angle":cp_0.max_yx_angle+difficulty*(cp_1.max_yx_angle-cp_0.max_yx_angle),
+ "degrees":cp_0.degrees,
+ }
+ else:
+ raiseValueError(f"Unknown terrain configuration: {cfg}")
+ # constants for the terrain
+ platform_clearance=0.1
+
+ # initialize list of meshes
+ meshes_list=list()
+ # compute quantities
+ origin=np.asarray((0.5*cfg.size[0],0.5*cfg.size[1],0.5*height))
+ platform_corners=np.asarray([
+ [origin[0]-cfg.platform_width/2,origin[1]-cfg.platform_width/2],
+ [origin[0]+cfg.platform_width/2,origin[1]+cfg.platform_width/2],
+ ])
+ platform_corners[0,:]*=1-platform_clearance
+ platform_corners[1,:]*=1+platform_clearance
+ # sample center for objects
+ whileTrue:
+ object_centers=np.zeros((num_objects,3))
+ object_centers[:,0]=np.random.uniform(0,cfg.size[0],num_objects)
+ object_centers[:,1]=np.random.uniform(0,cfg.size[1],num_objects)
+ # filter out the centers that are on the platform
+ is_within_platform_x=np.logical_and(
+ object_centers[:,0]>=platform_corners[0,0],object_centers[:,0]<=platform_corners[1,0]
+ )
+ is_within_platform_y=np.logical_and(
+ object_centers[:,1]>=platform_corners[0,1],object_centers[:,1]<=platform_corners[1,1]
+ )
+ masks=np.logical_and(is_within_platform_x,is_within_platform_y)
+ # if there are no objects on the platform, break
+ ifnotnp.any(masks):
+ break
+
+ # generate obstacles (but keep platform clean)
+ forindexinrange(len(object_centers)):
+ # randomize the height of the object
+ ob_height=height+np.random.uniform(-cfg.max_height_noise,cfg.max_height_noise)
+ ifob_height>0.0:
+ object_mesh=object_func(center=object_centers[index],height=ob_height,**object_kwargs)
+ meshes_list.append(object_mesh)
+
+ # generate a ground plane for the terrain
+ ground_plane=make_plane(cfg.size,height=0.0,center_zero=False)
+ meshes_list.append(ground_plane)
+ # generate a platform in the middle
+ dim=(cfg.platform_width,cfg.platform_width,0.5*height)
+ pos=(0.5*cfg.size[0],0.5*cfg.size[1],0.25*height)
+ platform=trimesh.creation.box(dim,trimesh.transformations.translation_matrix(pos))
+ meshes_list.append(platform)
+
+ returnmeshes_list,origin
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+importomni.isaac.lab.terrains.trimesh.mesh_terrainsasmesh_terrains
+importomni.isaac.lab.terrains.trimesh.utilsasmesh_utils_terrains
+fromomni.isaac.lab.utilsimportconfigclass
+
+from..terrain_generator_cfgimportSubTerrainBaseCfg
+
+"""
+Different trimesh terrain configurations.
+"""
+
+
+
[文档]@configclass
+classMeshPlaneTerrainCfg(SubTerrainBaseCfg):
+"""Configuration for a plane mesh terrain."""
+
+ function=mesh_terrains.flat_terrain
+
+
+
[文档]@configclass
+classMeshPyramidStairsTerrainCfg(SubTerrainBaseCfg):
+"""Configuration for a pyramid stair mesh terrain."""
+
+ function=mesh_terrains.pyramid_stairs_terrain
+
+ border_width:float=0.0
+"""The width of the border around the terrain (in m). Defaults to 0.0.
+
+ The border is a flat terrain with the same height as the terrain.
+ """
+ step_height_range:tuple[float,float]=MISSING
+"""The minimum and maximum height of the steps (in m)."""
+ step_width:float=MISSING
+"""The width of the steps (in m)."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+ holes:bool=False
+"""If True, the terrain will have holes in the steps. Defaults to False.
+
+ If :obj:`holes` is True, the terrain will have pyramid stairs of length or width
+ :obj:`platform_width` (depending on the direction) with no steps in the remaining area. Additionally,
+ no border will be added.
+ """
+
+
+
[文档]@configclass
+classMeshInvertedPyramidStairsTerrainCfg(MeshPyramidStairsTerrainCfg):
+"""Configuration for an inverted pyramid stair mesh terrain.
+
+ Note:
+ This is the same as :class:`MeshPyramidStairsTerrainCfg` except that the steps are inverted.
+ """
+
+ function=mesh_terrains.inverted_pyramid_stairs_terrain
+
+
+
[文档]@configclass
+classMeshRandomGridTerrainCfg(SubTerrainBaseCfg):
+"""Configuration for a random grid mesh terrain."""
+
+ function=mesh_terrains.random_grid_terrain
+
+ grid_width:float=MISSING
+"""The width of the grid cells (in m)."""
+ grid_height_range:tuple[float,float]=MISSING
+"""The minimum and maximum height of the grid cells (in m)."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+ holes:bool=False
+"""If True, the terrain will have holes in the steps. Defaults to False.
+
+ If :obj:`holes` is True, the terrain will have randomized grid cells only along the plane extending
+ from the platform (like a plus sign). The remaining area remains empty and no border will be added.
+ """
+
+
+
[文档]@configclass
+classMeshRailsTerrainCfg(SubTerrainBaseCfg):
+"""Configuration for a terrain with box rails as extrusions."""
+
+ function=mesh_terrains.rails_terrain
+
+ rail_thickness_range:tuple[float,float]=MISSING
+"""The thickness of the inner and outer rails (in m)."""
+ rail_height_range:tuple[float,float]=MISSING
+"""The minimum and maximum height of the rails (in m)."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+
+
+
[文档]@configclass
+classMeshPitTerrainCfg(SubTerrainBaseCfg):
+"""Configuration for a terrain with a pit that leads out of the pit."""
+
+ function=mesh_terrains.pit_terrain
+
+ pit_depth_range:tuple[float,float]=MISSING
+"""The minimum and maximum height of the pit (in m)."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+ double_pit:bool=False
+"""If True, the pit contains two levels of stairs. Defaults to False."""
+
+
+
[文档]@configclass
+classMeshBoxTerrainCfg(SubTerrainBaseCfg):
+"""Configuration for a terrain with boxes (similar to a pyramid)."""
+
+ function=mesh_terrains.box_terrain
+
+ box_height_range:tuple[float,float]=MISSING
+"""The minimum and maximum height of the box (in m)."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+ double_box:bool=False
+"""If True, the pit contains two levels of stairs/boxes. Defaults to False."""
+
+
+
[文档]@configclass
+classMeshGapTerrainCfg(SubTerrainBaseCfg):
+"""Configuration for a terrain with a gap around the platform."""
+
+ function=mesh_terrains.gap_terrain
+
+ gap_width_range:tuple[float,float]=MISSING
+"""The minimum and maximum width of the gap (in m)."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+
+
+
[文档]@configclass
+classMeshFloatingRingTerrainCfg(SubTerrainBaseCfg):
+"""Configuration for a terrain with a floating ring around the center."""
+
+ function=mesh_terrains.floating_ring_terrain
+
+ ring_width_range:tuple[float,float]=MISSING
+"""The minimum and maximum width of the ring (in m)."""
+ ring_height_range:tuple[float,float]=MISSING
+"""The minimum and maximum height of the ring (in m)."""
+ ring_thickness:float=MISSING
+"""The thickness (along z) of the ring (in m)."""
+ platform_width:float=1.0
+"""The width of the square platform at the center of the terrain. Defaults to 1.0."""
+
+
+
[文档]@configclass
+classMeshStarTerrainCfg(SubTerrainBaseCfg):
+"""Configuration for a terrain with a star pattern."""
+
+ function=mesh_terrains.star_terrain
+
+ num_bars:int=MISSING
+"""The number of bars per-side the star. Must be greater than 2."""
+ bar_width_range:tuple[float,float]=MISSING
+"""The minimum and maximum width of the bars in the star (in m)."""
+ bar_height_range:tuple[float,float]=MISSING
+"""The minimum and maximum height of the bars in the star (in m)."""
+ platform_width:float=1.0
+"""The width of the cylindrical platform at the center of the terrain. Defaults to 1.0."""
+
+
+
[文档]@configclass
+classMeshRepeatedObjectsTerrainCfg(SubTerrainBaseCfg):
+"""Base configuration for a terrain with repeated objects."""
+
+
[文档]@configclass
+ classObjectCfg:
+"""Configuration of repeated objects."""
+
+ num_objects:int=MISSING
+"""The number of objects to add to the terrain."""
+ height:float=MISSING
+"""The height (along z) of the object (in m)."""
+
+ function=mesh_terrains.repeated_objects_terrain
+
+ object_type:Literal["cylinder","box","cone"]|callable=MISSING
+"""The type of object to generate.
+
+ The type can be a string or a callable. If it is a string, the function will look for a function called
+ ``make_{object_type}`` in the current module scope. If it is a callable, the function will
+ use the callable to generate the object.
+ """
+ object_params_start:ObjectCfg=MISSING
+"""The object curriculum parameters at the start of the curriculum."""
+ object_params_end:ObjectCfg=MISSING
+"""The object curriculum parameters at the end of the curriculum."""
+
+ max_height_noise:float=0.0
+"""The maximum amount of noise to add to the height of the objects (in m). Defaults to 0.0."""
+ platform_width:float=1.0
+"""The width of the cylindrical platform at the center of the terrain. Defaults to 1.0."""
+
+
+
[文档]@configclass
+classMeshRepeatedPyramidsTerrainCfg(MeshRepeatedObjectsTerrainCfg):
+"""Configuration for a terrain with repeated pyramids."""
+
+
[文档]@configclass
+ classObjectCfg(MeshRepeatedObjectsTerrainCfg.ObjectCfg):
+"""Configuration for a curriculum of repeated pyramids."""
+
+ radius:float=MISSING
+"""The radius of the pyramids (in m)."""
+ max_yx_angle:float=0.0
+"""The maximum angle along the y and x axis. Defaults to 0.0."""
+ degrees:bool=True
+"""Whether the angle is in degrees. Defaults to True."""
+
+ object_type=mesh_utils_terrains.make_cone
+
+ object_params_start:ObjectCfg=MISSING
+"""The object curriculum parameters at the start of the curriculum."""
+ object_params_end:ObjectCfg=MISSING
+"""The object curriculum parameters at the end of the curriculum."""
+
+
+
[文档]@configclass
+classMeshRepeatedBoxesTerrainCfg(MeshRepeatedObjectsTerrainCfg):
+"""Configuration for a terrain with repeated boxes."""
+
+
[文档]@configclass
+ classObjectCfg(MeshRepeatedObjectsTerrainCfg.ObjectCfg):
+"""Configuration for repeated boxes."""
+
+ size:tuple[float,float]=MISSING
+"""The width (along x) and length (along y) of the box (in m)."""
+ max_yx_angle:float=0.0
+"""The maximum angle along the y and x axis. Defaults to 0.0."""
+ degrees:bool=True
+"""Whether the angle is in degrees. Defaults to True."""
+
+ object_type=mesh_utils_terrains.make_box
+
+ object_params_start:ObjectCfg=MISSING
+"""The box curriculum parameters at the start of the curriculum."""
+ object_params_end:ObjectCfg=MISSING
+"""The box curriculum parameters at the end of the curriculum."""
+
+
+
[文档]@configclass
+classMeshRepeatedCylindersTerrainCfg(MeshRepeatedObjectsTerrainCfg):
+"""Configuration for a terrain with repeated cylinders."""
+
+
[文档]@configclass
+ classObjectCfg(MeshRepeatedObjectsTerrainCfg.ObjectCfg):
+"""Configuration for repeated cylinder."""
+
+ radius:float=MISSING
+"""The radius of the pyramids (in m)."""
+ max_yx_angle:float=0.0
+"""The maximum angle along the y and x axis. Defaults to 0.0."""
+ degrees:bool=True
+"""Whether the angle is in degrees. Defaults to True."""
+
+ object_type=mesh_utils_terrains.make_cylinder
+
+ object_params_start:ObjectCfg=MISSING
+"""The box curriculum parameters at the start of the curriculum."""
+ object_params_end:ObjectCfg=MISSING
+"""The box curriculum parameters at the end of the curriculum."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+# needed to import for allowing type-hinting: np.ndarray | torch.Tensor | None
+from__future__importannotations
+
+importnumpyasnp
+importtorch
+importtrimesh
+
+importwarpaswp
+
+fromomni.isaac.lab.utils.warpimportraycast_mesh
+
+
+
[文档]defcolor_meshes_by_height(meshes:list[trimesh.Trimesh],**kwargs)->trimesh.Trimesh:
+"""
+ Color the vertices of a trimesh object based on the z-coordinate (height) of each vertex,
+ using the Turbo colormap. If the z-coordinates are all the same, the vertices will be colored
+ with a single color.
+
+ Args:
+ meshes: A list of trimesh objects.
+
+ Keyword Args:
+ color: A list of 3 integers in the range [0,255] representing the RGB
+ color of the mesh. Used when the z-coordinates of all vertices are the same.
+ Defaults to [172, 216, 230].
+ color_map: The name of the color map to be used. Defaults to "turbo".
+
+ Returns:
+ A trimesh object with the vertices colored based on the z-coordinate (height) of each vertex.
+ """
+ # Combine all meshes into a single mesh
+ mesh=trimesh.util.concatenate(meshes)
+ # Get the z-coordinates of each vertex
+ heights=mesh.vertices[:,2]
+ # Check if the z-coordinates are all the same
+ ifnp.max(heights)==np.min(heights):
+ # Obtain a single color: light blue
+ color=kwargs.pop("color",(172,216,230))
+ color=np.asarray(color,dtype=np.uint8)
+ # Set the color for all vertices
+ mesh.visual.vertex_colors=color
+ else:
+ # Normalize the heights to [0,1]
+ heights_normalized=(heights-np.min(heights))/(np.max(heights)-np.min(heights))
+ # clip lower and upper bounds to have better color mapping
+ heights_normalized=np.clip(heights_normalized,0.1,0.9)
+ # Get the color for each vertex based on the height
+ color_map=kwargs.pop("color_map","turbo")
+ colors=trimesh.visual.color.interpolate(heights_normalized,color_map=color_map)
+ # Set the vertex colors
+ mesh.visual.vertex_colors=colors
+ # Return the mesh
+ returnmesh
+
+
+
[文档]defcreate_prim_from_mesh(prim_path:str,mesh:trimesh.Trimesh,**kwargs):
+"""Create a USD prim with mesh defined from vertices and triangles.
+
+ The function creates a USD prim with a mesh defined from vertices and triangles. It performs the
+ following steps:
+
+ - Create a USD Xform prim at the path :obj:`prim_path`.
+ - Create a USD prim with a mesh defined from the input vertices and triangles at the path :obj:`{prim_path}/mesh`.
+ - Assign a physics material to the mesh at the path :obj:`{prim_path}/physicsMaterial`.
+ - Assign a visual material to the mesh at the path :obj:`{prim_path}/visualMaterial`.
+
+ Args:
+ prim_path: The path to the primitive to be created.
+ mesh: The mesh to be used for the primitive.
+
+ Keyword Args:
+ translation: The translation of the terrain. Defaults to None.
+ orientation: The orientation of the terrain. Defaults to None.
+ visual_material: The visual material to apply. Defaults to None.
+ physics_material: The physics material to apply. Defaults to None.
+ """
+ # need to import these here to prevent isaacsim launching when importing this module
+ importomni.isaac.core.utils.primsasprim_utils
+ frompxrimportUsdGeom
+
+ importomni.isaac.lab.simassim_utils
+
+ # create parent prim
+ prim_utils.create_prim(prim_path,"Xform")
+ # create mesh prim
+ prim=prim_utils.create_prim(
+ f"{prim_path}/mesh",
+ "Mesh",
+ translation=kwargs.get("translation"),
+ orientation=kwargs.get("orientation"),
+ attributes={
+ "points":mesh.vertices,
+ "faceVertexIndices":mesh.faces.flatten(),
+ "faceVertexCounts":np.asarray([3]*len(mesh.faces)),
+ "subdivisionScheme":"bilinear",
+ },
+ )
+ # apply collider properties
+ collider_cfg=sim_utils.CollisionPropertiesCfg(collision_enabled=True)
+ sim_utils.define_collision_properties(prim.GetPrimPath(),collider_cfg)
+ # add rgba color to the mesh primvars
+ ifmesh.visual.vertex_colorsisnotNone:
+ # obtain color from the mesh
+ rgba_colors=np.asarray(mesh.visual.vertex_colors).astype(np.float32)/255.0
+ # displayColor is a primvar attribute that is used to color the mesh
+ color_prim_attr=prim.GetAttribute("primvars:displayColor")
+ color_prim_var=UsdGeom.Primvar(color_prim_attr)
+ color_prim_var.SetInterpolation(UsdGeom.Tokens.vertex)
+ color_prim_attr.Set(rgba_colors[:,:3])
+ # displayOpacity is a primvar attribute that is used to set the opacity of the mesh
+ display_prim_attr=prim.GetAttribute("primvars:displayOpacity")
+ display_prim_var=UsdGeom.Primvar(display_prim_attr)
+ display_prim_var.SetInterpolation(UsdGeom.Tokens.vertex)
+ display_prim_var.Set(rgba_colors[:,3])
+
+ # create visual material
+ ifkwargs.get("visual_material")isnotNone:
+ visual_material_cfg:sim_utils.VisualMaterialCfg=kwargs.get("visual_material")
+ # spawn the material
+ visual_material_cfg.func(f"{prim_path}/visualMaterial",visual_material_cfg)
+ sim_utils.bind_visual_material(prim.GetPrimPath(),f"{prim_path}/visualMaterial")
+ # create physics material
+ ifkwargs.get("physics_material")isnotNone:
+ physics_material_cfg:sim_utils.RigidBodyMaterialCfg=kwargs.get("physics_material")
+ # spawn the material
+ physics_material_cfg.func(f"{prim_path}/physicsMaterial",physics_material_cfg)
+ sim_utils.bind_physics_material(prim.GetPrimPath(),f"{prim_path}/physicsMaterial")
+
+
+
[文档]deffind_flat_patches(
+ wp_mesh:wp.Mesh,
+ num_patches:int,
+ patch_radius:float|list[float],
+ origin:np.ndarray|torch.Tensor|tuple[float,float,float],
+ x_range:tuple[float,float],
+ y_range:tuple[float,float],
+ z_range:tuple[float,float],
+ max_height_diff:float,
+)->torch.Tensor:
+"""Finds flat patches of given radius in the input mesh.
+
+ The function finds flat patches of given radius based on the search space defined by the input ranges.
+ The search space is characterized by origin in the mesh frame, and the x, y, and z ranges. The x and y
+ ranges are used to sample points in the 2D region around the origin, and the z range is used to filter
+ patches based on the height of the points.
+
+ The function performs rejection sampling to find the patches based on the following steps:
+
+ 1. Sample patch locations in the 2D region around the origin.
+ 2. Define a ring of points around each patch location to query the height of the points using ray-casting.
+ 3. Reject patches that are outside the z range or have a height difference that is too large.
+ 4. Keep sampling until all patches are valid.
+
+ Args:
+ wp_mesh: The warp mesh to find patches in.
+ num_patches: The desired number of patches to find.
+ patch_radius: The radii used to form patches. If a list is provided, multiple patch sizes are checked.
+ This is useful to deal with holes or other artifacts in the mesh.
+ origin: The origin defining the center of the search space. This is specified in the mesh frame.
+ x_range: The range of X coordinates to sample from.
+ y_range: The range of Y coordinates to sample from.
+ z_range: The range of valid Z coordinates used for filtering patches.
+ max_height_diff: The maximum allowable distance between the lowest and highest points
+ on a patch to consider it as valid. If the difference is greater than this value,
+ the patch is rejected.
+
+ Returns:
+ A tensor of shape (num_patches, 3) containing the flat patches. The patches are defined in the mesh frame.
+
+ Raises:
+ RuntimeError: If the function fails to find valid patches. This can happen if the input parameters
+ are not suitable for finding valid patches and maximum number of iterations is reached.
+ """
+ # set device to warp mesh device
+ device=wp.device_to_torch(wp_mesh.device)
+
+ # resolve inputs to consistent type
+ # -- patch radii
+ ifisinstance(patch_radius,float):
+ patch_radius=[patch_radius]
+ # -- origin
+ ifisinstance(origin,np.ndarray):
+ origin=torch.from_numpy(origin).to(torch.float).to(device)
+ elifisinstance(origin,torch.Tensor):
+ origin=origin.to(device)
+ else:
+ origin=torch.tensor(origin,dtype=torch.float,device=device)
+
+ # create ranges for the x and y coordinates around the origin.
+ # The provided ranges are bounded by the mesh's bounding box.
+ x_range=(
+ max(x_range[0]+origin[0].item(),wp_mesh.points.numpy()[:,0].min()),
+ min(x_range[1]+origin[0].item(),wp_mesh.points.numpy()[:,0].max()),
+ )
+ y_range=(
+ max(y_range[0]+origin[1].item(),wp_mesh.points.numpy()[:,1].min()),
+ min(y_range[1]+origin[1].item(),wp_mesh.points.numpy()[:,1].max()),
+ )
+ z_range=(
+ z_range[0]+origin[2].item(),
+ z_range[1]+origin[2].item(),
+ )
+
+ # create a circle of points around (0, 0) to query validity of the patches
+ # the ring of points is uniformly distributed around the circle
+ angle=torch.linspace(0,2*np.pi,10,device=device)
+ query_x=[]
+ query_y=[]
+ forradiusinpatch_radius:
+ query_x.append(radius*torch.cos(angle))
+ query_y.append(radius*torch.sin(angle))
+ query_x=torch.cat(query_x).unsqueeze(1)# dim: (num_radii * 10, 1)
+ query_y=torch.cat(query_y).unsqueeze(1)# dim: (num_radii * 10, 1)
+ # dim: (num_radii * 10, 3)
+ query_points=torch.cat([query_x,query_y,torch.zeros_like(query_x)],dim=-1)
+
+ # create buffers
+ # -- a buffer to store indices of points that are not valid
+ points_ids=torch.arange(num_patches,device=device)
+ # -- a buffer to store the flat patches locations
+ flat_patches=torch.zeros(num_patches,3,device=device)
+
+ # sample points and raycast to find the height.
+ # 1. Reject points that are outside the z_range or have a height difference that is too large.
+ # 2. Keep sampling until all points are valid.
+ iter_count=0
+ whilelen(points_ids)>0anditer_count<10000:
+ # sample points in the 2D region around the origin
+ pos_x=torch.empty(len(points_ids),device=device).uniform_(*x_range)
+ pos_y=torch.empty(len(points_ids),device=device).uniform_(*y_range)
+ flat_patches[points_ids,:2]=torch.stack([pos_x,pos_y],dim=-1)
+
+ # define the query points to check validity of the patch
+ # dim: (num_patches, num_radii * 10, 3)
+ points=flat_patches[points_ids].unsqueeze(1)+query_points
+ points[...,2]=100.0
+ # ray-cast direction is downwards
+ dirs=torch.zeros_like(points)
+ dirs[...,2]=-1.0
+
+ # ray-cast to find the height of the patches
+ ray_hits=raycast_mesh(points.view(-1,3),dirs.view(-1,3),wp_mesh)[0]
+ heights=ray_hits.view(points.shape)[...,2]
+ # set the height of the patches
+ # note: for invalid patches, they would be overwritten in the next iteration
+ # so it's safe to set the height to the last value
+ flat_patches[points_ids,2]=heights[...,-1]
+
+ # check validity
+ # -- height is within the z range
+ not_valid=torch.any(torch.logical_or(heights<z_range[0],heights>z_range[1]),dim=1)
+ # -- height difference is within the max height difference
+ not_valid=torch.logical_or(not_valid,(heights.max(dim=1)[0]-heights.min(dim=1)[0])>max_height_diff)
+
+ # remove invalid patches indices
+ points_ids=points_ids[not_valid]
+ # increment count
+ iter_count+=1
+
+ # check all patches are valid
+ iflen(points_ids)>0:
+ raiseRuntimeError(
+ "Failed to find valid patches! Please check the input parameters."
+ f"\n\tMaximum number of iterations reached: {iter_count}"
+ f"\n\tNumber of invalid patches: {len(points_ids)}"
+ f"\n\tMaximum height difference: {max_height_diff}"
+ )
+
+ # return the flat patches (in the mesh frame)
+ returnflat_patches-origin
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-module containing utilities for working with different array backends."""
+
+# needed to import for allowing type-hinting: torch.device | str | None
+from__future__importannotations
+
+importnumpyasnp
+importtorch
+fromtypingimportUnion
+
+importwarpaswp
+
+TensorData=Union[np.ndarray,torch.Tensor,wp.array]
+"""Type definition for a tensor data.
+
+Union of numpy, torch, and warp arrays.
+"""
+
+TENSOR_TYPES={
+ "numpy":np.ndarray,
+ "torch":torch.Tensor,
+ "warp":wp.array,
+}
+"""A dictionary containing the types for each backend.
+
+The keys are the name of the backend ("numpy", "torch", "warp") and the values are the corresponding type
+(``np.ndarray``, ``torch.Tensor``, ``wp.array``).
+"""
+
+TENSOR_TYPE_CONVERSIONS={
+ "numpy":{wp.array:lambdax:x.numpy(),torch.Tensor:lambdax:x.detach().cpu().numpy()},
+ "torch":{wp.array:lambdax:wp.torch.to_torch(x),np.ndarray:lambdax:torch.from_numpy(x)},
+ "warp":{np.array:lambdax:wp.array(x),torch.Tensor:lambdax:wp.torch.from_torch(x)},
+}
+"""A nested dictionary containing the conversion functions for each backend.
+
+The keys of the outer dictionary are the name of target backend ("numpy", "torch", "warp"). The keys of the
+inner dictionary are the source backend (``np.ndarray``, ``torch.Tensor``, ``wp.array``).
+"""
+
+
+
[文档]defconvert_to_torch(
+ array:TensorData,
+ dtype:torch.dtype=None,
+ device:torch.device|str|None=None,
+)->torch.Tensor:
+"""Converts a given array into a torch tensor.
+
+ The function tries to convert the array to a torch tensor. If the array is a numpy/warp arrays, or python
+ list/tuples, it is converted to a torch tensor. If the array is already a torch tensor, it is returned
+ directly.
+
+ If ``device`` is None, then the function deduces the current device of the data. For numpy arrays,
+ this defaults to "cpu", for torch tensors it is "cpu" or "cuda", and for warp arrays it is "cuda".
+
+ Note:
+ Since PyTorch does not support unsigned integer types, unsigned integer arrays are converted to
+ signed integer arrays. This is done by casting the array to the corresponding signed integer type.
+
+ Args:
+ array: The input array. It can be a numpy array, warp array, python list/tuple, or torch tensor.
+ dtype: Target data-type for the tensor.
+ device: The target device for the tensor. Defaults to None.
+
+ Returns:
+ The converted array as torch tensor.
+ """
+ # Convert array to tensor
+ # if the datatype is not currently supported by torch we need to improvise
+ # supported types are: https://pytorch.org/docs/stable/tensors.html
+ ifisinstance(array,torch.Tensor):
+ tensor=array
+ elifisinstance(array,np.ndarray):
+ ifarray.dtype==np.uint32:
+ array=array.astype(np.int32)
+ # need to deal with object arrays (np.void) separately
+ tensor=torch.from_numpy(array)
+ elifisinstance(array,wp.array):
+ ifarray.dtype==wp.uint32:
+ array=array.view(wp.int32)
+ tensor=wp.to_torch(array)
+ else:
+ tensor=torch.Tensor(array)
+ # Convert tensor to the right device
+ ifdeviceisnotNoneandstr(tensor.device)!=str(device):
+ tensor=tensor.to(device)
+ # Convert dtype of tensor if requested
+ ifdtypeisnotNoneandtensor.dtype!=dtype:
+ tensor=tensor.type(dtype)
+
+ returntensor
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-module that defines the host-server where assets and resources are stored.
+
+By default, we use the Isaac Sim Nucleus Server for hosting assets and resources. This makes
+distribution of the assets easier and makes the repository smaller in size code-wise.
+
+For more information, please check information on `Omniverse Nucleus`_.
+
+.. _Omniverse Nucleus: https://docs.omniverse.nvidia.com/nucleus/latest/overview/overview.html
+"""
+
+importio
+importos
+importtempfile
+fromtypingimportLiteral
+
+importcarb
+importomni.client
+
+NUCLEUS_ASSET_ROOT_DIR=carb.settings.get_settings().get("/persistent/isaac/asset_root/cloud")
+"""Path to the root directory on the Nucleus Server."""
+
+NVIDIA_NUCLEUS_DIR=f"{NUCLEUS_ASSET_ROOT_DIR}/NVIDIA"
+"""Path to the root directory on the NVIDIA Nucleus Server."""
+
+ISAAC_NUCLEUS_DIR=f"{NUCLEUS_ASSET_ROOT_DIR}/Isaac"
+"""Path to the ``Isaac`` directory on the NVIDIA Nucleus Server."""
+
+ISAACLAB_NUCLEUS_DIR=f"{ISAAC_NUCLEUS_DIR}/IsaacLab"
+"""Path to the ``Isaac/IsaacLab`` directory on the NVIDIA Nucleus Server."""
+
+
+
[文档]defcheck_file_path(path:str)->Literal[0,1,2]:
+"""Checks if a file exists on the Nucleus Server or locally.
+
+ Args:
+ path: The path to the file.
+
+ Returns:
+ The status of the file. Possible values are listed below.
+
+ * :obj:`0` if the file does not exist
+ * :obj:`1` if the file exists locally
+ * :obj:`2` if the file exists on the Nucleus Server
+ """
+ ifos.path.isfile(path):
+ return1
+ elifomni.client.stat(path)[0]==omni.client.Result.OK:
+ return2
+ else:
+ return0
+
+
+
[文档]defretrieve_file_path(path:str,download_dir:str|None=None,force_download:bool=True)->str:
+"""Retrieves the path to a file on the Nucleus Server or locally.
+
+ If the file exists locally, then the absolute path to the file is returned.
+ If the file exists on the Nucleus Server, then the file is downloaded to the local machine
+ and the absolute path to the file is returned.
+
+ Args:
+ path: The path to the file.
+ download_dir: The directory where the file should be downloaded. Defaults to None, in which
+ case the file is downloaded to the system's temporary directory.
+ force_download: Whether to force download the file from the Nucleus Server. This will overwrite
+ the local file if it exists. Defaults to True.
+
+ Returns:
+ The path to the file on the local machine.
+
+ Raises:
+ FileNotFoundError: When the file not found locally or on Nucleus Server.
+ RuntimeError: When the file cannot be copied from the Nucleus Server to the local machine. This
+ can happen when the file already exists locally and :attr:`force_download` is set to False.
+ """
+ # check file status
+ file_status=check_file_path(path)
+ iffile_status==1:
+ returnos.path.abspath(path)
+ eliffile_status==2:
+ # resolve download directory
+ ifdownload_dirisNone:
+ download_dir=tempfile.gettempdir()
+ else:
+ download_dir=os.path.abspath(download_dir)
+ # create download directory if it does not exist
+ ifnotos.path.exists(download_dir):
+ os.makedirs(download_dir)
+ # download file in temp directory using os
+ file_name=os.path.basename(omni.client.break_url(path).path)
+ target_path=os.path.join(download_dir,file_name)
+ # check if file already exists locally
+ ifnotos.path.isfile(target_path)orforce_download:
+ # copy file to local machine
+ result=omni.client.copy(path,target_path)
+ ifresult!=omni.client.Result.OKandforce_download:
+ raiseRuntimeError(f"Unable to copy file: '{path}'. Is the Nucleus Server running?")
+ returnos.path.abspath(target_path)
+ else:
+ raiseFileNotFoundError(f"Unable to find the file: {path}")
+
+
+
[文档]defread_file(path:str)->io.BytesIO:
+"""Reads a file from the Nucleus Server or locally.
+
+ Args:
+ path: The path to the file.
+
+ Raises:
+ FileNotFoundError: When the file not found locally or on Nucleus Server.
+
+ Returns:
+ The content of the file.
+ """
+ # check file status
+ file_status=check_file_path(path)
+ iffile_status==1:
+ withopen(path,"rb")asf:
+ returnio.BytesIO(f.read())
+ eliffile_status==2:
+ file_content=omni.client.read_file(path)[2]
+ returnio.BytesIO(memoryview(file_content).tobytes())
+ else:
+ raiseFileNotFoundError(f"Unable to find the file: {path}")
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+fromcollections.abcimportSequence
+
+
+
[文档]classCircularBuffer:
+"""Circular buffer for storing a history of batched tensor data.
+
+ This class implements a circular buffer for storing a history of batched tensor data. The buffer is
+ initialized with a maximum length and a batch size. The data is stored in a circular fashion, and the
+ data can be retrieved in a LIFO (Last-In-First-Out) fashion. The buffer is designed to be used in
+ multi-environment settings, where each environment has its own data.
+
+ The shape of the appended data is expected to be (batch_size, ...), where the first dimension is the
+ batch dimension. Correspondingly, the shape of the ring buffer is (max_len, batch_size, ...).
+ """
+
+
[文档]def__init__(self,max_len:int,batch_size:int,device:str):
+"""Initialize the circular buffer.
+
+ Args:
+ max_len: The maximum length of the circular buffer. The minimum allowed value is 1.
+ batch_size: The batch dimension of the data.
+ device: The device used for processing.
+
+ Raises:
+ ValueError: If the buffer size is less than one.
+ """
+ ifmax_len<1:
+ raiseValueError(f"The buffer size should be greater than zero. However, it is set to {max_len}!")
+ # set the parameters
+ self._batch_size=batch_size
+ self._device=device
+ self._ALL_INDICES=torch.arange(batch_size,device=device)
+
+ # max length tensor for comparisons
+ self._max_len=torch.full((batch_size,),max_len,dtype=torch.int,device=device)
+ # number of data pushes passed since the last call to :meth:`reset`
+ self._num_pushes=torch.zeros(batch_size,dtype=torch.long,device=device)
+ # the pointer to the current head of the circular buffer (-1 means not initialized)
+ self._pointer:int=-1
+ # the actual buffer for data storage
+ # note: this is initialized on the first call to :meth:`append`
+ self._buffer:torch.Tensor=None# type: ignore
+
+"""
+ Properties.
+ """
+
+ @property
+ defbatch_size(self)->int:
+"""The batch size of the ring buffer."""
+ returnself._batch_size
+
+ @property
+ defdevice(self)->str:
+"""The device used for processing."""
+ returnself._device
+
+ @property
+ defmax_length(self)->int:
+"""The maximum length of the ring buffer."""
+ returnint(self._max_len[0].item())
+
+ @property
+ defcurrent_length(self)->torch.Tensor:
+"""The current length of the buffer. Shape is (batch_size,).
+
+ Since the buffer is circular, the current length is the minimum of the number of pushes
+ and the maximum length.
+ """
+ returntorch.minimum(self._num_pushes,self._max_len)
+
+"""
+ Operations.
+ """
+
+
[文档]defreset(self,batch_ids:Sequence[int]|None=None):
+"""Reset the circular buffer at the specified batch indices.
+
+ Args:
+ batch_ids: Elements to reset in the batch dimension. Default is None, which resets all the batch indices.
+ """
+ # resolve all indices
+ ifbatch_idsisNone:
+ batch_ids=slice(None)
+ # reset the number of pushes for the specified batch indices
+ # note: we don't need to reset the buffer since it will be overwritten. The pointer handles this.
+ self._num_pushes[batch_ids]=0
+
+
[文档]defappend(self,data:torch.Tensor):
+"""Append the data to the circular buffer.
+
+ Args:
+ data: The data to append to the circular buffer. The first dimension should be the batch dimension.
+ Shape is (batch_size, ...).
+
+ Raises:
+ ValueError: If the input data has a different batch size than the buffer.
+ """
+ # check the batch size
+ ifdata.shape[0]!=self.batch_size:
+ raiseValueError(f"The input data has {data.shape[0]} environments while expecting {self.batch_size}")
+
+ # at the fist call, initialize the buffer
+ ifself._bufferisNone:
+ self._pointer=-1
+ self._buffer=torch.empty((self.max_length,*data.shape),dtype=data.dtype,device=self._device)
+ # move the head to the next slot
+ self._pointer=(self._pointer+1)%self.max_length
+ # add the new data to the last layer
+ self._buffer[self._pointer]=data.to(self._device)
+ # increment number of number of pushes
+ self._num_pushes+=1
+
+ def__getitem__(self,key:torch.Tensor)->torch.Tensor:
+"""Retrieve the data from the circular buffer in last-in-first-out (LIFO) fashion.
+
+ If the requested index is larger than the number of pushes since the last call to :meth:`reset`,
+ the oldest stored data is returned.
+
+ Args:
+ key: The index to retrieve from the circular buffer. The index should be less than the number of pushes
+ since the last call to :meth:`reset`. Shape is (batch_size,).
+
+ Returns:
+ The data from the circular buffer. Shape is (batch_size, ...).
+
+ Raises:
+ ValueError: If the input key has a different batch size than the buffer.
+ RuntimeError: If the buffer is empty.
+ """
+ # check the batch size
+ iflen(key)!=self.batch_size:
+ raiseValueError(f"The argument 'key' has length {key.shape[0]}, while expecting {self.batch_size}")
+ # check if the buffer is empty
+ iftorch.any(self._num_pushes==0)orself._bufferisNone:
+ raiseRuntimeError("Attempting to retrieve data on an empty circular buffer. Please append data first.")
+
+ # admissible lag
+ valid_keys=torch.minimum(key,self._num_pushes-1)
+ # the index in the circular buffer (pointer points to the last+1 index)
+ index_in_buffer=torch.remainder(self._pointer-valid_keys,self.max_length)
+ # return output
+ returnself._buffer[index_in_buffer,self._ALL_INDICES]
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+# needed because we concatenate int and torch.Tensor in the type hints
+from__future__importannotations
+
+importtorch
+fromcollections.abcimportSequence
+
+from.circular_bufferimportCircularBuffer
+
+
+
[文档]classDelayBuffer:
+"""Delay buffer that allows retrieving stored data with delays.
+
+ This class uses a batched circular buffer to store input data. Different to a standard circular buffer,
+ which uses the LIFO (last-in-first-out) principle to retrieve the data, the delay buffer class allows
+ retrieving data based on the lag set by the user. For instance, if the delay set inside the buffer
+ is 1, then the second last entry from the stream is retrieved. If it is 2, then the third last entry
+ and so on.
+
+ The class supports storing a batched tensor data. This means that the shape of the appended data
+ is expected to be (batch_size, ...), where the first dimension is the batch dimension. Correspondingly,
+ the delay can be set separately for each batch index. If the requested delay is larger than the current
+ length of the underlying buffer, the most recent entry is returned.
+
+ .. note::
+ By default, the delay buffer has no delay, meaning that the data is returned as is.
+ """
+
+
[文档]def__init__(self,history_length:int,batch_size:int,device:str):
+"""Initialize the delay buffer.
+
+ Args:
+ history_length: The history of the buffer, i.e., the number of time steps in the past that the data
+ will be buffered. It is recommended to set this value equal to the maximum time-step lag that
+ is expected. The minimum acceptable value is zero, which means only the latest data is stored.
+ batch_size: The batch dimension of the data.
+ device: The device used for processing.
+ """
+ # set the parameters
+ self._history_length=max(0,history_length)
+
+ # the buffer size: current data plus the history length
+ self._circular_buffer=CircularBuffer(self._history_length+1,batch_size,device)
+
+ # the minimum and maximum lags across all environments.
+ self._min_time_lag=0
+ self._max_time_lag=0
+ # the lags for each environment.
+ self._time_lags=torch.zeros(batch_size,dtype=torch.int,device=device)
+
+"""
+ Properties.
+ """
+
+ @property
+ defbatch_size(self)->int:
+"""The batch size of the ring buffer."""
+ returnself._circular_buffer.batch_size
+
+ @property
+ defdevice(self)->str:
+"""The device used for processing."""
+ returnself._circular_buffer.device
+
+ @property
+ defhistory_length(self)->int:
+"""The history length of the delay buffer.
+
+ If zero, only the latest data is stored. If one, the latest and the previous data are stored, and so on.
+ """
+ returnself._history_length
+
+ @property
+ defmin_time_lag(self)->int:
+"""Minimum amount of time steps that can be delayed.
+
+ This value cannot be negative or larger than :attr:`max_time_lag`.
+ """
+ returnself._min_time_lag
+
+ @property
+ defmax_time_lag(self)->int:
+"""Maximum amount of time steps that can be delayed.
+
+ This value cannot be greater than :attr:`history_length`.
+ """
+ returnself._max_time_lag
+
+ @property
+ deftime_lags(self)->torch.Tensor:
+"""The time lag across each batch index.
+
+ The shape of the tensor is (batch_size, ). The value at each index represents the delay for that index.
+ This value is used to retrieve the data from the buffer.
+ """
+ returnself._time_lags
+
+"""
+ Operations.
+ """
+
+
[文档]defset_time_lag(self,time_lag:int|torch.Tensor,batch_ids:Sequence[int]|None=None):
+"""Sets the time lag for the delay buffer across the provided batch indices.
+
+ Args:
+ time_lag: The desired delay for the buffer.
+
+ * If an integer is provided, the same delay is set for the provided batch indices.
+ * If a tensor is provided, the delay is set for each batch index separately. The shape of the tensor
+ should be (len(batch_ids),).
+
+ batch_ids: The batch indices for which the time lag is set. Default is None, which sets the time lag
+ for all batch indices.
+
+ Raises:
+ TypeError: If the type of the :attr:`time_lag` is not int or integer tensor.
+ ValueError: If the minimum time lag is negative or the maximum time lag is larger than the history length.
+ """
+ # resolve batch indices
+ ifbatch_idsisNone:
+ batch_ids=slice(None)
+
+ # parse requested time_lag
+ ifisinstance(time_lag,int):
+ # set the time lags across provided batch indices
+ self._time_lags[batch_ids]=time_lag
+ elifisinstance(time_lag,torch.Tensor):
+ # check valid dtype for time_lag: must be int or long
+ iftime_lag.dtypenotin[torch.int,torch.long]:
+ raiseTypeError(f"Invalid dtype for time_lag: {time_lag.dtype}. Expected torch.int or torch.long.")
+ # set the time lags
+ self._time_lags[batch_ids]=time_lag.to(device=self.device)
+ else:
+ raiseTypeError(f"Invalid type for time_lag: {type(time_lag)}. Expected int or integer tensor.")
+
+ # compute the min and max time lag
+ self._min_time_lag=int(torch.min(self._time_lags).item())
+ self._max_time_lag=int(torch.max(self._time_lags).item())
+ # check that time_lag is feasible
+ ifself._min_time_lag<0:
+ raiseValueError(f"The minimum time lag cannot be negative. Received: {self._min_time_lag}")
+ ifself._max_time_lag>self._history_length:
+ raiseValueError(
+ f"The maximum time lag cannot be larger than the history length. Received: {self._max_time_lag}"
+ )
+
+
[文档]defreset(self,batch_ids:Sequence[int]|None=None):
+"""Reset the data in the delay buffer at the specified batch indices.
+
+ Args:
+ batch_ids: Elements to reset in the batch dimension. Default is None, which resets all the batch indices.
+ """
+ self._circular_buffer.reset(batch_ids)
+
+
[文档]defcompute(self,data:torch.Tensor)->torch.Tensor:
+"""Append the input data to the buffer and returns a stale version of the data based on time lag delay.
+
+ If the requested delay is larger than the number of buffered data points since the last reset,
+ the function returns the latest data. For instance, if the delay is set to 2 and only one data point
+ is stored in the buffer, the function will return the latest data. If the delay is set to 2 and three
+ data points are stored, the function will return the first data point.
+
+ Args:
+ data: The input data. Shape is (batch_size, ...).
+
+ Returns:
+ The delayed version of the data from the stored buffer. Shape is (batch_size, ...).
+ """
+ # add the new data to the last layer
+ self._circular_buffer.append(data)
+ # return output
+ delayed_data=self._circular_buffer[self._time_lags]
+ returndelayed_data.clone()
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+fromdataclassesimportdataclass
+
+
+
[文档]@dataclass
+classTimestampedBuffer:
+"""A buffer class containing data and its timestamp.
+
+ This class is a simple data container that stores a tensor and its timestamp. The timestamp is used to
+ track the last update of the buffer. The timestamp is set to -1.0 by default, indicating that the buffer
+ has not been updated yet. The timestamp should be updated whenever the data in the buffer is updated. This
+ way the buffer can be used to check whether the data is outdated and needs to be refreshed.
+
+ The buffer is useful for creating lazy buffers that only update the data when it is outdated. This can be
+ useful when the data is expensive to compute or retrieve. For example usage, refer to the data classes in
+ the :mod:`omni.isaac.lab.assets` module.
+ """
+
+ data:torch.Tensor=None# type: ignore
+"""The data stored in the buffer. Default is None, indicating that the buffer is empty."""
+
+ timestamp:float=-1.0
+"""Timestamp at the last update of the buffer. Default is -1.0, indicating that the buffer has not been updated."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-module that provides a wrapper around the Python 3.7 onwards ``dataclasses`` module."""
+
+importinspect
+importtypes
+fromcollections.abcimportCallable
+fromcopyimportdeepcopy
+fromdataclassesimportMISSING,Field,dataclass,field,replace
+fromtypingimportAny,ClassVar
+
+from.dictimportclass_to_dict,update_class_from_dict
+
+_CONFIGCLASS_METHODS=["to_dict","from_dict","replace","copy"]
+"""List of class methods added at runtime to dataclass."""
+
+"""
+Wrapper around dataclass.
+"""
+
+
+def__dataclass_transform__():
+"""Add annotations decorator for PyLance."""
+ returnlambdaa:a
+
+
+
[文档]@__dataclass_transform__()
+defconfigclass(cls,**kwargs):
+"""Wrapper around `dataclass` functionality to add extra checks and utilities.
+
+ As of Python 3.7, the standard dataclasses have two main issues which makes them non-generic for
+ configuration use-cases. These include:
+
+ 1. Requiring a type annotation for all its members.
+ 2. Requiring explicit usage of :meth:`field(default_factory=...)` to reinitialize mutable variables.
+
+ This function provides a decorator that wraps around Python's `dataclass`_ utility to deal with
+ the above two issues. It also provides additional helper functions for dictionary <-> class
+ conversion and easily copying class instances.
+
+ Usage:
+
+ .. code-block:: python
+
+ from dataclasses import MISSING
+
+ from omni.isaac.lab.utils.configclass import configclass
+
+
+ @configclass
+ class ViewerCfg:
+ eye: list = [7.5, 7.5, 7.5] # field missing on purpose
+ lookat: list = field(default_factory=[0.0, 0.0, 0.0])
+
+
+ @configclass
+ class EnvCfg:
+ num_envs: int = MISSING
+ episode_length: int = 2000
+ viewer: ViewerCfg = ViewerCfg()
+
+ # create configuration instance
+ env_cfg = EnvCfg(num_envs=24)
+
+ # print information as a dictionary
+ print(env_cfg.to_dict())
+
+ # create a copy of the configuration
+ env_cfg_copy = env_cfg.copy()
+
+ # replace arbitrary fields using keyword arguments
+ env_cfg_copy = env_cfg_copy.replace(num_envs=32)
+
+ Args:
+ cls: The class to wrap around.
+ **kwargs: Additional arguments to pass to :func:`dataclass`.
+
+ Returns:
+ The wrapped class.
+
+ .. _dataclass: https://docs.python.org/3/library/dataclasses.html
+ """
+ # add type annotations
+ _add_annotation_types(cls)
+ # add field factory
+ _process_mutable_types(cls)
+ # copy mutable members
+ # note: we check if user defined __post_init__ function exists and augment it with our own
+ ifhasattr(cls,"__post_init__"):
+ setattr(cls,"__post_init__",_combined_function(cls.__post_init__,_custom_post_init))
+ else:
+ setattr(cls,"__post_init__",_custom_post_init)
+ # add helper functions for dictionary conversion
+ setattr(cls,"to_dict",_class_to_dict)
+ setattr(cls,"from_dict",_update_class_from_dict)
+ setattr(cls,"replace",_replace_class_with_kwargs)
+ setattr(cls,"copy",_copy_class)
+ # wrap around dataclass
+ cls=dataclass(cls,**kwargs)
+ # return wrapped class
+ returncls
+
+
+"""
+Dictionary <-> Class operations.
+
+These are redefined here to add new docstrings.
+"""
+
+
+def_class_to_dict(obj:object)->dict[str,Any]:
+"""Convert an object into dictionary recursively.
+
+ Args:
+ obj: The object to convert.
+
+ Returns:
+ Converted dictionary mapping.
+ """
+ returnclass_to_dict(obj)
+
+
+def_update_class_from_dict(obj,data:dict[str,Any])->None:
+"""Reads a dictionary and sets object variables recursively.
+
+ This function performs in-place update of the class member attributes.
+
+ Args:
+ obj: The object to update.
+ data: Input (nested) dictionary to update from.
+
+ Raises:
+ TypeError: When input is not a dictionary.
+ ValueError: When dictionary has a value that does not match default config type.
+ KeyError: When dictionary has a key that does not exist in the default config type.
+ """
+ update_class_from_dict(obj,data,_ns="")
+
+
+def_replace_class_with_kwargs(obj:object,**kwargs)->object:
+"""Return a new object replacing specified fields with new values.
+
+ This is especially useful for frozen classes. Example usage:
+
+ .. code-block:: python
+
+ @configclass(frozen=True)
+ class C:
+ x: int
+ y: int
+
+ c = C(1, 2)
+ c1 = c.replace(x=3)
+ assert c1.x == 3 and c1.y == 2
+
+ Args:
+ obj: The object to replace.
+ **kwargs: The fields to replace and their new values.
+
+ Returns:
+ The new object.
+ """
+ returnreplace(obj,**kwargs)
+
+
+def_copy_class(obj:object)->object:
+"""Return a new object with the same fields as the original."""
+ returnreplace(obj)
+
+
+"""
+Private helper functions.
+"""
+
+
+def_add_annotation_types(cls):
+"""Add annotations to all elements in the dataclass.
+
+ By definition in Python, a field is defined as a class variable that has a type annotation.
+
+ In case type annotations are not provided, dataclass ignores those members when :func:`__dict__()` is called.
+ This function adds these annotations to the class variable to prevent any issues in case the user forgets to
+ specify the type annotation.
+
+ This makes the following a feasible operation:
+
+ @dataclass
+ class State:
+ pos = (0.0, 0.0, 0.0)
+ ^^
+ If the function is NOT used, the following type-error is returned:
+ TypeError: 'pos' is a field but has no type annotation
+ """
+ # get type hints
+ hints={}
+ # iterate over class inheritance
+ # we add annotations from base classes first
+ forbaseinreversed(cls.__mro__):
+ # check if base is object
+ ifbaseisobject:
+ continue
+ # get base class annotations
+ ann=base.__dict__.get("__annotations__",{})
+ # directly add all annotations from base class
+ hints.update(ann)
+ # iterate over base class members
+ # Note: Do not change this to dir(base) since it orders the members alphabetically.
+ # This is not desirable since the order of the members is important in some cases.
+ forkeyinbase.__dict__:
+ # get class member
+ value=getattr(base,key)
+ # skip members
+ if_skippable_class_member(key,value,hints):
+ continue
+ # add type annotations for members that don't have explicit type annotations
+ # for these, we deduce the type from the default value
+ ifnotisinstance(value,type):
+ ifkeynotinhints:
+ # check if var type is not MISSING
+ # we cannot deduce type from MISSING!
+ ifvalueisMISSING:
+ raiseTypeError(
+ f"Missing type annotation for '{key}' in class '{cls.__name__}'."
+ " Please add a type annotation or set a default value."
+ )
+ # add type annotation
+ hints[key]=type(value)
+ elifkey!=value.__name__:
+ # note: we don't want to add type annotations for nested configclass. Thus, we check if
+ # the name of the type matches the name of the variable.
+ # since Python 3.10, type hints are stored as strings
+ hints[key]=f"type[{value.__name__}]"
+
+ # Note: Do not change this line. `cls.__dict__.get("__annotations__", {})` is different from
+ # `cls.__annotations__` because of inheritance.
+ cls.__annotations__=cls.__dict__.get("__annotations__",{})
+ cls.__annotations__=hints
+
+
+def_process_mutable_types(cls):
+"""Initialize all mutable elements through :obj:`dataclasses.Field` to avoid unnecessary complaints.
+
+ By default, dataclass requires usage of :obj:`field(default_factory=...)` to reinitialize mutable objects every time a new
+ class instance is created. If a member has a mutable type and it is created without specifying the `field(default_factory=...)`,
+ then Python throws an error requiring the usage of `default_factory`.
+
+ Additionally, Python only explicitly checks for field specification when the type is a list, set or dict. This misses the
+ use-case where the type is class itself. Thus, the code silently carries a bug with it which can lead to undesirable effects.
+
+ This function deals with this issue
+
+ This makes the following a feasible operation:
+
+ @dataclass
+ class State:
+ pos: list = [0.0, 0.0, 0.0]
+ ^^
+ If the function is NOT used, the following value-error is returned:
+ ValueError: mutable default <class 'list'> for field pos is not allowed: use default_factory
+ """
+ # note: Need to set this up in the same order as annotations. Otherwise, it
+ # complains about missing positional arguments.
+ ann=cls.__dict__.get("__annotations__",{})
+
+ # iterate over all class members and store them in a dictionary
+ class_members={}
+ forbaseinreversed(cls.__mro__):
+ # check if base is object
+ ifbaseisobject:
+ continue
+ # iterate over base class members
+ forkeyinbase.__dict__:
+ # get class member
+ f=getattr(base,key)
+ # skip members
+ if_skippable_class_member(key,f):
+ continue
+ # store class member if it is not a type or if it is already present in annotations
+ ifnotisinstance(f,type)orkeyinann:
+ class_members[key]=f
+ # iterate over base class data fields
+ # in previous call, things that became a dataclass field were removed from class members
+ # so we need to add them back here as a dataclass field directly
+ forkey,finbase.__dict__.get("__dataclass_fields__",{}).items():
+ # store class member
+ ifnotisinstance(f,type):
+ class_members[key]=f
+
+ # check that all annotations are present in class members
+ # note: mainly for debugging purposes
+ iflen(class_members)!=len(ann):
+ raiseValueError(
+ f"In class '{cls.__name__}', number of annotations ({len(ann)}) does not match number of class members"
+ f" ({len(class_members)}). Please check that all class members have type annotations and/or a default"
+ " value. If you don't want to specify a default value, please use the literal `dataclasses.MISSING`."
+ )
+ # iterate over annotations and add field factory for mutable types
+ forkeyinann:
+ # find matching field in class
+ value=class_members.get(key,MISSING)
+ # check if key belongs to ClassVar
+ # in that case, we cannot use default_factory!
+ origin=getattr(ann[key],"__origin__",None)
+ iforiginisClassVar:
+ continue
+ # check if f is MISSING
+ # note: commented out for now since it causes issue with inheritance
+ # of dataclasses when parent have some positional and some keyword arguments.
+ # Ref: https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
+ # TODO: check if this is fixed in Python 3.10
+ # if f is MISSING:
+ # continue
+ ifisinstance(value,Field):
+ setattr(cls,key,value)
+ elifnotisinstance(value,type):
+ # create field factory for mutable types
+ value=field(default_factory=_return_f(value))
+ setattr(cls,key,value)
+
+
+def_custom_post_init(obj):
+"""Deepcopy all elements to avoid shared memory issues for mutable objects in dataclasses initialization.
+
+ This function is called explicitly instead of as a part of :func:`_process_mutable_types()` to prevent mapping
+ proxy type i.e. a read only proxy for mapping objects. The error is thrown when using hierarchical data-classes
+ for configuration.
+ """
+ forkeyindir(obj):
+ # skip dunder members
+ ifkey.startswith("__"):
+ continue
+ # get data member
+ value=getattr(obj,key)
+ # check annotation
+ ann=obj.__class__.__dict__.get(key)
+ # duplicate data members that are mutable
+ ifnotcallable(value)andnotisinstance(ann,property):
+ setattr(obj,key,deepcopy(value))
+
+
+def_combined_function(f1:Callable,f2:Callable)->Callable:
+"""Combine two functions into one.
+
+ Args:
+ f1: The first function.
+ f2: The second function.
+
+ Returns:
+ The combined function.
+ """
+
+ def_combined(*args,**kwargs):
+ # call both functions
+ f1(*args,**kwargs)
+ f2(*args,**kwargs)
+
+ return_combined
+
+
+"""
+Helper functions
+"""
+
+
+def_skippable_class_member(key:str,value:Any,hints:dict|None=None)->bool:
+"""Check if the class member should be skipped in configclass processing.
+
+ The following members are skipped:
+
+ * Dunder members: ``__name__``, ``__module__``, ``__qualname__``, ``__annotations__``, ``__dict__``.
+ * Manually-added special class functions: From :obj:`_CONFIGCLASS_METHODS`.
+ * Members that are already present in the type annotations.
+ * Functions bounded to class object or class.
+ * Properties bounded to class object.
+
+ Args:
+ key: The class member name.
+ value: The class member value.
+ hints: The type hints for the class. Defaults to None, in which case, the
+ members existence in type hints are not checked.
+
+ Returns:
+ True if the class member should be skipped, False otherwise.
+ """
+ # skip dunder members
+ ifkey.startswith("__"):
+ returnTrue
+ # skip manually-added special class functions
+ ifkeyin_CONFIGCLASS_METHODS:
+ returnTrue
+ # check if key is already present
+ ifhintsisnotNoneandkeyinhints:
+ returnTrue
+ # skip functions bounded to class
+ ifcallable(value):
+ # FIXME: This doesn't yet work for static methods because they are essentially seen as function types.
+ # check for class methods
+ ifisinstance(value,types.MethodType):
+ returnTrue
+ # check for instance methods
+ signature=inspect.signature(value)
+ if"self"insignature.parametersor"cls"insignature.parameters:
+ returnTrue
+ # skip property methods
+ ifisinstance(value,property):
+ returnTrue
+ # Otherwise, don't skip
+ returnFalse
+
+
+def_return_f(f:Any)->Callable[[],Any]:
+"""Returns default factory function for creating mutable/immutable variables.
+
+ This function should be used to create default factory functions for variables.
+
+ Example:
+
+ .. code-block:: python
+
+ value = field(default_factory=_return_f(value))
+ setattr(cls, key, value)
+ """
+
+ def_wrap():
+ ifisinstance(f,Field):
+ iff.default_factoryisMISSING:
+ returndeepcopy(f.default)
+ else:
+ returnf.default_factory
+ else:
+ returndeepcopy(f)
+
+ return_wrap
+
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-module for utilities for working with dictionaries."""
+
+importcollections.abc
+importhashlib
+importjson
+fromcollections.abcimportIterable,Mapping
+fromtypingimportAny
+
+from.arrayimportTENSOR_TYPE_CONVERSIONS,TENSOR_TYPES
+from.stringimportcallable_to_string,string_to_callable,string_to_slice
+
+"""
+Dictionary <-> Class operations.
+"""
+
+
+
[文档]defclass_to_dict(obj:object)->dict[str,Any]:
+"""Convert an object into dictionary recursively.
+
+ Note:
+ Ignores all names starting with "__" (i.e. built-in methods).
+
+ Args:
+ obj: An instance of a class to convert.
+
+ Raises:
+ ValueError: When input argument is not an object.
+
+ Returns:
+ Converted dictionary mapping.
+ """
+ # check that input data is class instance
+ ifnothasattr(obj,"__class__"):
+ raiseValueError(f"Expected a class instance. Received: {type(obj)}.")
+ # convert object to dictionary
+ ifisinstance(obj,dict):
+ obj_dict=obj
+ else:
+ obj_dict=obj.__dict__
+
+ # convert to dictionary
+ data=dict()
+ forkey,valueinobj_dict.items():
+ # disregard builtin attributes
+ ifkey.startswith("__"):
+ continue
+ # check if attribute is callable -- function
+ ifcallable(value):
+ data[key]=callable_to_string(value)
+ # check if attribute is a dictionary
+ elifhasattr(value,"__dict__")orisinstance(value,dict):
+ data[key]=class_to_dict(value)
+ else:
+ data[key]=value
+ returndata
+
+
+
[文档]defupdate_class_from_dict(obj,data:dict[str,Any],_ns:str="")->None:
+"""Reads a dictionary and sets object variables recursively.
+
+ This function performs in-place update of the class member attributes.
+
+ Args:
+ obj: An instance of a class to update.
+ data: Input dictionary to update from.
+ _ns: Namespace of the current object. This is useful for nested configuration
+ classes or dictionaries. Defaults to "".
+
+ Raises:
+ TypeError: When input is not a dictionary.
+ ValueError: When dictionary has a value that does not match default config type.
+ KeyError: When dictionary has a key that does not exist in the default config type.
+ """
+ forkey,valueindata.items():
+ # key_ns is the full namespace of the key
+ key_ns=_ns+"/"+key
+ # check if key is present in the object
+ ifhasattr(obj,key)orisinstance(obj,dict):
+ obj_mem=obj[key]ifisinstance(obj,dict)elsegetattr(obj,key)
+ ifisinstance(value,Mapping):
+ # recursively call if it is a dictionary
+ update_class_from_dict(obj_mem,value,_ns=key_ns)
+ continue
+ ifisinstance(value,Iterable)andnotisinstance(value,str):
+ # check length of value to be safe
+ iflen(obj_mem)!=len(value)andobj_memisnotNone:
+ raiseValueError(
+ f"[Config]: Incorrect length under namespace: {key_ns}."
+ f" Expected: {len(obj_mem)}, Received: {len(value)}."
+ )
+ ifisinstance(obj_mem,tuple):
+ value=tuple(value)
+ else:
+ set_obj=True
+ # recursively call if iterable contains dictionaries
+ foriinrange(len(obj_mem)):
+ ifisinstance(value[i],dict):
+ update_class_from_dict(obj_mem[i],value[i],_ns=key_ns)
+ set_obj=False
+ # do not set value to obj, otherwise it overwrites the cfg class with the dict
+ ifnotset_obj:
+ continue
+ elifcallable(obj_mem):
+ # update function name
+ value=string_to_callable(value)
+ elifisinstance(value,type(obj_mem))orvalueisNone:
+ pass
+ else:
+ raiseValueError(
+ f"[Config]: Incorrect type under namespace: {key_ns}."
+ f" Expected: {type(obj_mem)}, Received: {type(value)}."
+ )
+ # set value
+ ifisinstance(obj,dict):
+ obj[key]=value
+ else:
+ setattr(obj,key,value)
+ else:
+ raiseKeyError(f"[Config]: Key not found under namespace: {key_ns}.")
[文档]defdict_to_md5_hash(data:object)->str:
+"""Convert a dictionary into a hashable key using MD5 hash.
+
+ Args:
+ data: Input dictionary or configuration object to convert.
+
+ Returns:
+ A string object of double length containing only hexadecimal digits.
+ """
+ # convert to dictionary
+ ifisinstance(data,dict):
+ encoded_buffer=json.dumps(data,sort_keys=True).encode()
+ else:
+ encoded_buffer=json.dumps(class_to_dict(data),sort_keys=True).encode()
+ # compute hash using MD5
+ data_hash=hashlib.md5()
+ data_hash.update(encoded_buffer)
+ # return the hash key
+ returndata_hash.hexdigest()
+
+
+"""
+Dictionary operations.
+"""
+
+
+
[文档]defconvert_dict_to_backend(
+ data:dict,backend:str="numpy",array_types:Iterable[str]=("numpy","torch","warp")
+)->dict:
+"""Convert all arrays or tensors in a dictionary to a given backend.
+
+ This function iterates over the dictionary, converts all arrays or tensors with the given types to
+ the desired backend, and stores them in a new dictionary. It also works with nested dictionaries.
+
+ Currently supported backends are "numpy", "torch", and "warp".
+
+ Note:
+ This function only converts arrays or tensors. Other types of data are left unchanged. Mutable types
+ (e.g. lists) are referenced by the new dictionary, so they are not copied.
+
+ Args:
+ data: An input dict containing array or tensor data as values.
+ backend: The backend ("numpy", "torch", "warp") to which arrays in this dict should be converted.
+ Defaults to "numpy".
+ array_types: A list containing the types of arrays that should be converted to
+ the desired backend. Defaults to ("numpy", "torch", "warp").
+
+ Raises:
+ ValueError: If the specified ``backend`` or ``array_types`` are unknown, i.e. not in the list of supported
+ backends ("numpy", "torch", "warp").
+
+ Returns:
+ The updated dict with the data converted to the desired backend.
+ """
+ # THINK: Should we also support converting to a specific device, e.g. "cuda:0"?
+ # Check the backend is valid.
+ ifbackendnotinTENSOR_TYPE_CONVERSIONS:
+ raiseValueError(f"Unknown backend '{backend}'. Supported backends are 'numpy', 'torch', and 'warp'.")
+ # Define the conversion functions for each backend.
+ tensor_type_conversions=TENSOR_TYPE_CONVERSIONS[backend]
+
+ # Parse the array types and convert them to the corresponding types: "numpy" -> np.ndarray, etc.
+ parsed_types=list()
+ fortinarray_types:
+ # Check type is valid.
+ iftnotinTENSOR_TYPES:
+ raiseValueError(f"Unknown array type: '{t}'. Supported array types are 'numpy', 'torch', and 'warp'.")
+ # Exclude types that match the backend, since we do not need to convert these.
+ ift==backend:
+ continue
+ # Convert the string types to the corresponding types.
+ parsed_types.append(TENSOR_TYPES[t])
+
+ # Convert the data to the desired backend.
+ output_dict=dict()
+ forkey,valueindata.items():
+ # Obtain the data type of the current value.
+ data_type=type(value)
+ # -- arrays
+ ifdata_typeinparsed_types:
+ # check if we have a known conversion.
+ ifdata_typenotintensor_type_conversions:
+ raiseValueError(f"No registered conversion for data type: {data_type} to {backend}!")
+ # convert the data to the desired backend.
+ output_dict[key]=tensor_type_conversions[data_type](value)
+ # -- nested dictionaries
+ elifisinstance(data[key],dict):
+ output_dict[key]=convert_dict_to_backend(value)
+ # -- everything else
+ else:
+ output_dict[key]=value
+
+ returnoutput_dict
+
+
+
[文档]defupdate_dict(orig_dict:dict,new_dict:collections.abc.Mapping)->dict:
+"""Updates existing dictionary with values from a new dictionary.
+
+ This function mimics the dict.update() function. However, it works for
+ nested dictionaries as well.
+
+ Args:
+ orig_dict: The original dictionary to insert items to.
+ new_dict: The new dictionary to insert items from.
+
+ Returns:
+ The updated dictionary.
+ """
+ forkeyname,valueinnew_dict.items():
+ ifisinstance(value,collections.abc.Mapping):
+ orig_dict[keyname]=update_dict(orig_dict.get(keyname,{}),value)
+ else:
+ orig_dict[keyname]=value
+ returnorig_dict
+
+
+
[文档]defreplace_slices_with_strings(data:dict)->dict:
+"""Replace slice objects with their string representations in a dictionary.
+
+ Args:
+ data: The dictionary to process.
+
+ Returns:
+ The dictionary with slice objects replaced by their string representations.
+ """
+ ifisinstance(data,dict):
+ return{k:replace_slices_with_strings(v)fork,vindata.items()}
+ elifisinstance(data,slice):
+ returnf"slice({data.start},{data.stop},{data.step})"
+ else:
+ returndata
+
+
+
[文档]defreplace_strings_with_slices(data:dict)->dict:
+"""Replace string representations of slices with slice objects in a dictionary.
+
+ Args:
+ data: The dictionary to process.
+
+ Returns:
+ The dictionary with string representations of slices replaced by slice objects.
+ """
+ ifisinstance(data,dict):
+ return{k:replace_strings_with_slices(v)fork,vindata.items()}
+ elifisinstance(data,str)anddata.startswith("slice("):
+ returnstring_to_slice(data)
+ else:
+ returndata
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+
+
+
[文档]classLinearInterpolation:
+"""Linearly interpolates a sampled scalar function for arbitrary query points.
+
+ This class implements a linear interpolation for a scalar function. The function maps from real values, x, to
+ real values, y. It expects a set of samples from the function's domain, x, and the corresponding values, y.
+ The class allows querying the function's values at any arbitrary point.
+
+ The interpolation is done by finding the two closest points in x to the query point and then linearly
+ interpolating between the corresponding y values. For the query points that are outside the input points,
+ the class does a zero-order-hold extrapolation based on the boundary values. This means that the class
+ returns the value of the closest point in x.
+ """
+
+
[文档]def__init__(self,x:torch.Tensor,y:torch.Tensor,device:str):
+"""Initializes the linear interpolation.
+
+ The scalar function maps from real values, x, to real values, y. The input to the class is a set of samples
+ from the function's domain, x, and the corresponding values, y.
+
+ Note:
+ The input tensor x should be sorted in ascending order.
+
+ Args:
+ x: An vector of samples from the function's domain. The values should be sorted in ascending order.
+ Shape is (num_samples,)
+ y: The function's values associated to the input x. Shape is (num_samples,)
+ device: The device used for processing.
+
+ Raises:
+ ValueError: If the input tensors are empty or have different sizes.
+ ValueError: If the input tensor x is not sorted in ascending order.
+ """
+ # make sure that input tensors are 1D of size (num_samples,)
+ self._x=x.view(-1).clone().to(device=device)
+ self._y=y.view(-1).clone().to(device=device)
+
+ # make sure sizes are correct
+ ifself._x.numel()==0:
+ raiseValueError("Input tensor x is empty!")
+ ifself._x.numel()!=self._y.numel():
+ raiseValueError(f"Input tensors x and y have different sizes: {self._x.numel()} != {self._y.numel()}")
+ # make sure that x is sorted
+ iftorch.any(self._x[1:]<self._x[:-1]):
+ raiseValueError("Input tensor x is not sorted in ascending order!")
+
+
[文档]defcompute(self,q:torch.Tensor)->torch.Tensor:
+"""Calculates a linearly interpolated values for the query points.
+
+ Args:
+ q: The query points. It can have any arbitrary shape.
+
+ Returns:
+ The interpolated values at query points. It has the same shape as the input tensor.
+ """
+ # serialized q
+ q_1d=q.view(-1)
+ # Number of elements in the x that are strictly smaller than query points (use int32 instead of int64)
+ num_smaller_elements=torch.sum(self._x.unsqueeze(1)<q_1d.unsqueeze(0),dim=0,dtype=torch.int)
+
+ # The index pointing to the first element in x such that x[lower_bound_i] < q_i
+ # If a point is smaller that all x elements, it will assign 0
+ lower_bound=torch.clamp(num_smaller_elements-1,min=0)
+ # The index pointing to the first element in x such that x[upper_bound_i] >= q_i
+ # If a point is greater than all x elements, it will assign the last elements' index
+ upper_bound=torch.clamp(num_smaller_elements,max=self._x.numel()-1)
+
+ # compute the weight as: (q_i - x_lb) / (x_ub - x_lb)
+ weight=(q_1d-self._x[lower_bound])/(self._x[upper_bound]-self._x[lower_bound])
+ # If a point is out of bounds assign weight 0.0
+ weight[upper_bound==lower_bound]=0.0
+
+ # Perform linear interpolation
+ fq=self._y[lower_bound]+weight*(self._y[upper_bound]-self._y[lower_bound])
+
+ # deserialized fq
+ fq=fq.view(q.shape)
+ returnfq
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Utilities for file I/O with pickle."""
+
+importos
+importpickle
+fromtypingimportAny
+
+
+
[文档]defload_pickle(filename:str)->Any:
+"""Loads an input PKL file safely.
+
+ Args:
+ filename: The path to pickled file.
+
+ Raises:
+ FileNotFoundError: When the specified file does not exist.
+
+ Returns:
+ The data read from the input file.
+ """
+ ifnotos.path.exists(filename):
+ raiseFileNotFoundError(f"File not found: {filename}")
+ withopen(filename,"rb")asf:
+ data=pickle.load(f)
+ returndata
+
+
+
[文档]defdump_pickle(filename:str,data:Any):
+"""Saves data into a pickle file safely.
+
+ Note:
+ The function creates any missing directory along the file's path.
+
+ Args:
+ filename: The path to save the file at.
+ data: The data to save.
+ """
+ # check ending
+ ifnotfilename.endswith("pkl"):
+ filename+=".pkl"
+ # create directory
+ ifnotos.path.exists(os.path.dirname(filename)):
+ os.makedirs(os.path.dirname(filename),exist_ok=True)
+ # save data
+ withopen(filename,"wb")asf:
+ pickle.dump(data,f)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Utilities for file I/O with yaml."""
+
+importos
+importyaml
+
+fromomni.isaac.lab.utilsimportclass_to_dict
+
+
+
[文档]defload_yaml(filename:str)->dict:
+"""Loads an input PKL file safely.
+
+ Args:
+ filename: The path to pickled file.
+
+ Raises:
+ FileNotFoundError: When the specified file does not exist.
+
+ Returns:
+ The data read from the input file.
+ """
+ ifnotos.path.exists(filename):
+ raiseFileNotFoundError(f"File not found: {filename}")
+ withopen(filename)asf:
+ data=yaml.full_load(f)
+ returndata
+
+
+
[文档]defdump_yaml(filename:str,data:dict|object,sort_keys:bool=False):
+"""Saves data into a YAML file safely.
+
+ Note:
+ The function creates any missing directory along the file's path.
+
+ Args:
+ filename: The path to save the file at.
+ data: The data to save either a dictionary or class object.
+ sort_keys: Whether to sort the keys in the output file. Defaults to False.
+ """
+ # check ending
+ ifnotfilename.endswith("yaml"):
+ filename+=".yaml"
+ # create directory
+ ifnotos.path.exists(os.path.dirname(filename)):
+ os.makedirs(os.path.dirname(filename),exist_ok=True)
+ # convert data into dictionary
+ ifnotisinstance(data,dict):
+ data=class_to_dict(data)
+ # save data
+ withopen(filename,"w")asf:
+ yaml.dump(data,f,default_flow_style=False,sort_keys=sort_keys)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-module containing utilities for various math operations."""
+
+# needed to import for allowing type-hinting: torch.Tensor | np.ndarray
+from__future__importannotations
+
+importnumpyasnp
+importtorch
+importtorch.nn.functional
+fromtypingimportLiteral
+
+"""
+General
+"""
+
+
+@torch.jit.script
+defscale_transform(x:torch.Tensor,lower:torch.Tensor,upper:torch.Tensor)->torch.Tensor:
+"""Normalizes a given input tensor to a range of [-1, 1].
+
+ .. note::
+ It uses pytorch broadcasting functionality to deal with batched input.
+
+ Args:
+ x: Input tensor of shape (N, dims).
+ lower: The minimum value of the tensor. Shape is (N, dims) or (dims,).
+ upper: The maximum value of the tensor. Shape is (N, dims) or (dims,).
+
+ Returns:
+ Normalized transform of the tensor. Shape is (N, dims).
+ """
+ # default value of center
+ offset=(lower+upper)*0.5
+ # return normalized tensor
+ return2*(x-offset)/(upper-lower)
+
+
+@torch.jit.script
+defunscale_transform(x:torch.Tensor,lower:torch.Tensor,upper:torch.Tensor)->torch.Tensor:
+"""De-normalizes a given input tensor from range of [-1, 1] to (lower, upper).
+
+ .. note::
+ It uses pytorch broadcasting functionality to deal with batched input.
+
+ Args:
+ x: Input tensor of shape (N, dims).
+ lower: The minimum value of the tensor. Shape is (N, dims) or (dims,).
+ upper: The maximum value of the tensor. Shape is (N, dims) or (dims,).
+
+ Returns:
+ De-normalized transform of the tensor. Shape is (N, dims).
+ """
+ # default value of center
+ offset=(lower+upper)*0.5
+ # return normalized tensor
+ returnx*(upper-lower)*0.5+offset
+
+
+@torch.jit.script
+defsaturate(x:torch.Tensor,lower:torch.Tensor,upper:torch.Tensor)->torch.Tensor:
+"""Clamps a given input tensor to (lower, upper).
+
+ It uses pytorch broadcasting functionality to deal with batched input.
+
+ Args:
+ x: Input tensor of shape (N, dims).
+ lower: The minimum value of the tensor. Shape is (N, dims) or (dims,).
+ upper: The maximum value of the tensor. Shape is (N, dims) or (dims,).
+
+ Returns:
+ Clamped transform of the tensor. Shape is (N, dims).
+ """
+ returntorch.max(torch.min(x,upper),lower)
+
+
+@torch.jit.script
+defnormalize(x:torch.Tensor,eps:float=1e-9)->torch.Tensor:
+"""Normalizes a given input tensor to unit length.
+
+ Args:
+ x: Input tensor of shape (N, dims).
+ eps: A small value to avoid division by zero. Defaults to 1e-9.
+
+ Returns:
+ Normalized tensor of shape (N, dims).
+ """
+ returnx/x.norm(p=2,dim=-1).clamp(min=eps,max=None).unsqueeze(-1)
+
+
+@torch.jit.script
+defwrap_to_pi(angles:torch.Tensor)->torch.Tensor:
+r"""Wraps input angles (in radians) to the range :math:`[-\pi, \pi]`.
+
+ This function wraps angles in radians to the range :math:`[-\pi, \pi]`, such that
+ :math:`\pi` maps to :math:`\pi`, and :math:`-\pi` maps to :math:`-\pi`. In general,
+ odd positive multiples of :math:`\pi` are mapped to :math:`\pi`, and odd negative
+ multiples of :math:`\pi` are mapped to :math:`-\pi`.
+
+ The function behaves similar to MATLAB's `wrapToPi <https://www.mathworks.com/help/map/ref/wraptopi.html>`_
+ function.
+
+ Args:
+ angles: Input angles of any shape.
+
+ Returns:
+ Angles in the range :math:`[-\pi, \pi]`.
+ """
+ # wrap to [0, 2*pi)
+ wrapped_angle=(angles+torch.pi)%(2*torch.pi)
+ # map to [-pi, pi]
+ # we check for zero in wrapped angle to make it go to pi when input angle is odd multiple of pi
+ returntorch.where((wrapped_angle==0)&(angles>0),torch.pi,wrapped_angle-torch.pi)
+
+
+@torch.jit.script
+defcopysign(mag:float,other:torch.Tensor)->torch.Tensor:
+"""Create a new floating-point tensor with the magnitude of input and the sign of other, element-wise.
+
+ Note:
+ The implementation follows from `torch.copysign`. The function allows a scalar magnitude.
+
+ Args:
+ mag: The magnitude scalar.
+ other: The tensor containing values whose signbits are applied to magnitude.
+
+ Returns:
+ The output tensor.
+ """
+ mag_torch=torch.tensor(mag,device=other.device,dtype=torch.float).repeat(other.shape[0])
+ returntorch.abs(mag_torch)*torch.sign(other)
+
+
+"""
+Rotation
+"""
+
+
+@torch.jit.script
+defmatrix_from_quat(quaternions:torch.Tensor)->torch.Tensor:
+"""Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
+
+ Returns:
+ Rotation matrices. The shape is (..., 3, 3).
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70
+ """
+ r,i,j,k=torch.unbind(quaternions,-1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s=2.0/(quaternions*quaternions).sum(-1)
+
+ o=torch.stack(
+ (
+ 1-two_s*(j*j+k*k),
+ two_s*(i*j-k*r),
+ two_s*(i*k+j*r),
+ two_s*(i*j+k*r),
+ 1-two_s*(i*i+k*k),
+ two_s*(j*k-i*r),
+ two_s*(i*k-j*r),
+ two_s*(j*k+i*r),
+ 1-two_s*(i*i+j*j),
+ ),
+ -1,
+ )
+ returno.reshape(quaternions.shape[:-1]+(3,3))
+
+
+
[文档]defconvert_quat(quat:torch.Tensor|np.ndarray,to:Literal["xyzw","wxyz"]="xyzw")->torch.Tensor|np.ndarray:
+"""Converts quaternion from one convention to another.
+
+ The convention to convert TO is specified as an optional argument. If to == 'xyzw',
+ then the input is in 'wxyz' format, and vice-versa.
+
+ Args:
+ quat: The quaternion of shape (..., 4).
+ to: Convention to convert quaternion to.. Defaults to "xyzw".
+
+ Returns:
+ The converted quaternion in specified convention.
+
+ Raises:
+ ValueError: Invalid input argument `to`, i.e. not "xyzw" or "wxyz".
+ ValueError: Invalid shape of input `quat`, i.e. not (..., 4,).
+ """
+ # check input is correct
+ ifquat.shape[-1]!=4:
+ msg=f"Expected input quaternion shape mismatch: {quat.shape} != (..., 4)."
+ raiseValueError(msg)
+ iftonotin["xyzw","wxyz"]:
+ msg=f"Expected input argument `to` to be 'xyzw' or 'wxyz'. Received: {to}."
+ raiseValueError(msg)
+ # check if input is numpy array (we support this backend since some classes use numpy)
+ ifisinstance(quat,np.ndarray):
+ # use numpy functions
+ ifto=="xyzw":
+ # wxyz -> xyzw
+ returnnp.roll(quat,-1,axis=-1)
+ else:
+ # xyzw -> wxyz
+ returnnp.roll(quat,1,axis=-1)
+ else:
+ # convert to torch (sanity check)
+ ifnotisinstance(quat,torch.Tensor):
+ quat=torch.tensor(quat,dtype=float)
+ # convert to specified quaternion type
+ ifto=="xyzw":
+ # wxyz -> xyzw
+ returnquat.roll(-1,dims=-1)
+ else:
+ # xyzw -> wxyz
+ returnquat.roll(1,dims=-1)
+
+
+@torch.jit.script
+defquat_conjugate(q:torch.Tensor)->torch.Tensor:
+"""Computes the conjugate of a quaternion.
+
+ Args:
+ q: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
+
+ Returns:
+ The conjugate quaternion in (w, x, y, z). Shape is (..., 4).
+ """
+ shape=q.shape
+ q=q.reshape(-1,4)
+ returntorch.cat((q[:,0:1],-q[:,1:]),dim=-1).view(shape)
+
+
+@torch.jit.script
+defquat_inv(q:torch.Tensor)->torch.Tensor:
+"""Compute the inverse of a quaternion.
+
+ Args:
+ q: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
+
+ Returns:
+ The inverse quaternion in (w, x, y, z). Shape is (N, 4).
+ """
+ returnnormalize(quat_conjugate(q))
+
+
+@torch.jit.script
+defquat_from_euler_xyz(roll:torch.Tensor,pitch:torch.Tensor,yaw:torch.Tensor)->torch.Tensor:
+"""Convert rotations given as Euler angles in radians to Quaternions.
+
+ Note:
+ The euler angles are assumed in XYZ convention.
+
+ Args:
+ roll: Rotation around x-axis (in radians). Shape is (N,).
+ pitch: Rotation around y-axis (in radians). Shape is (N,).
+ yaw: Rotation around z-axis (in radians). Shape is (N,).
+
+ Returns:
+ The quaternion in (w, x, y, z). Shape is (N, 4).
+ """
+ cy=torch.cos(yaw*0.5)
+ sy=torch.sin(yaw*0.5)
+ cr=torch.cos(roll*0.5)
+ sr=torch.sin(roll*0.5)
+ cp=torch.cos(pitch*0.5)
+ sp=torch.sin(pitch*0.5)
+ # compute quaternion
+ qw=cy*cr*cp+sy*sr*sp
+ qx=cy*sr*cp-sy*cr*sp
+ qy=cy*cr*sp+sy*sr*cp
+ qz=sy*cr*cp-cy*sr*sp
+
+ returntorch.stack([qw,qx,qy,qz],dim=-1)
+
+
+@torch.jit.script
+def_sqrt_positive_part(x:torch.Tensor)->torch.Tensor:
+"""Returns torch.sqrt(torch.max(0, x)) but with a zero sub-gradient where x is 0.
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L91-L99
+ """
+ ret=torch.zeros_like(x)
+ positive_mask=x>0
+ ret[positive_mask]=torch.sqrt(x[positive_mask])
+ returnret
+
+
+@torch.jit.script
+defquat_from_matrix(matrix:torch.Tensor)->torch.Tensor:
+"""Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: The rotation matrices. Shape is (..., 3, 3).
+
+ Returns:
+ The quaternion in (w, x, y, z). Shape is (..., 4).
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L102-L161
+ """
+ ifmatrix.size(-1)!=3ormatrix.size(-2)!=3:
+ raiseValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim=matrix.shape[:-2]
+ m00,m01,m02,m10,m11,m12,m20,m21,m22=torch.unbind(matrix.reshape(batch_dim+(9,)),dim=-1)
+
+ q_abs=_sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0+m00+m11+m22,
+ 1.0+m00-m11-m22,
+ 1.0-m00+m11-m22,
+ 1.0-m00-m11+m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk=torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+ torch.stack([q_abs[...,0]**2,m21-m12,m02-m20,m10-m01],dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+ torch.stack([m21-m12,q_abs[...,1]**2,m10+m01,m02+m20],dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+ torch.stack([m02-m20,m10+m01,q_abs[...,2]**2,m12+m21],dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+ torch.stack([m10-m01,m20+m02,m21+m12,q_abs[...,3]**2],dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr=torch.tensor(0.1).to(dtype=q_abs.dtype,device=q_abs.device)
+ quat_candidates=quat_by_rijk/(2.0*q_abs[...,None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ returnquat_candidates[torch.nn.functional.one_hot(q_abs.argmax(dim=-1),num_classes=4)>0.5,:].reshape(
+ batch_dim+(4,)
+ )
+
+
+def_axis_angle_rotation(axis:Literal["X","Y","Z"],angle:torch.Tensor)->torch.Tensor:
+"""Return the rotation matrices for one of the rotations about an axis of which Euler angles describe,
+ for each value of the angle given.
+
+ Args:
+ axis: Axis label "X" or "Y or "Z".
+ angle: Euler angles in radians of any shape.
+
+ Returns:
+ Rotation matrices. Shape is (..., 3, 3).
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L164-L191
+ """
+ cos=torch.cos(angle)
+ sin=torch.sin(angle)
+ one=torch.ones_like(angle)
+ zero=torch.zeros_like(angle)
+
+ ifaxis=="X":
+ R_flat=(one,zero,zero,zero,cos,-sin,zero,sin,cos)
+ elifaxis=="Y":
+ R_flat=(cos,zero,sin,zero,one,zero,-sin,zero,cos)
+ elifaxis=="Z":
+ R_flat=(cos,-sin,zero,sin,cos,zero,zero,zero,one)
+ else:
+ raiseValueError("letter must be either X, Y or Z.")
+
+ returntorch.stack(R_flat,-1).reshape(angle.shape+(3,3))
+
+
+
[文档]defmatrix_from_euler(euler_angles:torch.Tensor,convention:str)->torch.Tensor:
+"""
+ Convert rotations given as Euler angles in radians to rotation matrices.
+
+ Args:
+ euler_angles: Euler angles in radians. Shape is (..., 3).
+ convention: Convention string of three uppercase letters from {"X", "Y", and "Z"}.
+ For example, "XYZ" means that the rotations should be applied first about x,
+ then y, then z.
+
+ Returns:
+ Rotation matrices. Shape is (..., 3, 3).
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L194-L220
+ """
+ ifeuler_angles.dim()==0oreuler_angles.shape[-1]!=3:
+ raiseValueError("Invalid input euler angles.")
+ iflen(convention)!=3:
+ raiseValueError("Convention must have 3 letters.")
+ ifconvention[1]in(convention[0],convention[2]):
+ raiseValueError(f"Invalid convention {convention}.")
+ forletterinconvention:
+ ifletternotin("X","Y","Z"):
+ raiseValueError(f"Invalid letter {letter} in convention string.")
+ matrices=[_axis_angle_rotation(c,e)forc,einzip(convention,torch.unbind(euler_angles,-1))]
+ # return functools.reduce(torch.matmul, matrices)
+ returntorch.matmul(torch.matmul(matrices[0],matrices[1]),matrices[2])
+
+
+@torch.jit.script
+defeuler_xyz_from_quat(quat:torch.Tensor)->tuple[torch.Tensor,torch.Tensor,torch.Tensor]:
+"""Convert rotations given as quaternions to Euler angles in radians.
+
+ Note:
+ The euler angles are assumed in XYZ convention.
+
+ Args:
+ quat: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
+
+ Returns:
+ A tuple containing roll-pitch-yaw. Each element is a tensor of shape (N,).
+
+ Reference:
+ https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles
+ """
+ q_w,q_x,q_y,q_z=quat[:,0],quat[:,1],quat[:,2],quat[:,3]
+ # roll (x-axis rotation)
+ sin_roll=2.0*(q_w*q_x+q_y*q_z)
+ cos_roll=1-2*(q_x*q_x+q_y*q_y)
+ roll=torch.atan2(sin_roll,cos_roll)
+
+ # pitch (y-axis rotation)
+ sin_pitch=2.0*(q_w*q_y-q_z*q_x)
+ pitch=torch.where(torch.abs(sin_pitch)>=1,copysign(torch.pi/2.0,sin_pitch),torch.asin(sin_pitch))
+
+ # yaw (z-axis rotation)
+ sin_yaw=2.0*(q_w*q_z+q_x*q_y)
+ cos_yaw=1-2*(q_y*q_y+q_z*q_z)
+ yaw=torch.atan2(sin_yaw,cos_yaw)
+
+ returnroll%(2*torch.pi),pitch%(2*torch.pi),yaw%(2*torch.pi)# TODO: why not wrap_to_pi here ?
+
+
+@torch.jit.script
+defquat_unique(q:torch.Tensor)->torch.Tensor:
+"""Convert a unit quaternion to a standard form where the real part is non-negative.
+
+ Quaternion representations have a singularity since ``q`` and ``-q`` represent the same
+ rotation. This function ensures the real part of the quaternion is non-negative.
+
+ Args:
+ q: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
+
+ Returns:
+ Standardized quaternions. Shape is (..., 4).
+ """
+ returntorch.where(q[...,0:1]<0,-q,q)
+
+
+@torch.jit.script
+defquat_mul(q1:torch.Tensor,q2:torch.Tensor)->torch.Tensor:
+"""Multiply two quaternions together.
+
+ Args:
+ q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
+ q2: The second quaternion in (w, x, y, z). Shape is (..., 4).
+
+ Returns:
+ The product of the two quaternions in (w, x, y, z). Shape is (..., 4).
+
+ Raises:
+ ValueError: Input shapes of ``q1`` and ``q2`` are not matching.
+ """
+ # check input is correct
+ ifq1.shape!=q2.shape:
+ msg=f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}."
+ raiseValueError(msg)
+ # reshape to (N, 4) for multiplication
+ shape=q1.shape
+ q1=q1.reshape(-1,4)
+ q2=q2.reshape(-1,4)
+ # extract components from quaternions
+ w1,x1,y1,z1=q1[:,0],q1[:,1],q1[:,2],q1[:,3]
+ w2,x2,y2,z2=q2[:,0],q2[:,1],q2[:,2],q2[:,3]
+ # perform multiplication
+ ww=(z1+x1)*(x2+y2)
+ yy=(w1-y1)*(w2+z2)
+ zz=(w1+y1)*(w2-z2)
+ xx=ww+yy+zz
+ qq=0.5*(xx+(z1-x1)*(x2-y2))
+ w=qq-ww+(z1-y1)*(y2-z2)
+ x=qq-xx+(x1+w1)*(x2+w2)
+ y=qq-yy+(w1-x1)*(y2+z2)
+ z=qq-zz+(z1+y1)*(w2-x2)
+
+ returntorch.stack([w,x,y,z],dim=-1).view(shape)
+
+
+@torch.jit.script
+defquat_box_minus(q1:torch.Tensor,q2:torch.Tensor)->torch.Tensor:
+"""The box-minus operator (quaternion difference) between two quaternions.
+
+ Args:
+ q1: The first quaternion in (w, x, y, z). Shape is (N, 4).
+ q2: The second quaternion in (w, x, y, z). Shape is (N, 4).
+
+ Returns:
+ The difference between the two quaternions. Shape is (N, 3).
+ """
+ quat_diff=quat_mul(q1,quat_conjugate(q2))# q1 * q2^-1
+ re=quat_diff[:,0]# real part, q = [w, x, y, z] = [re, im]
+ im=quat_diff[:,1:]# imaginary part
+ norm_im=torch.norm(im,dim=1)
+ scale=2.0*torch.where(norm_im>1.0e-7,torch.atan2(norm_im,re)/norm_im,torch.sign(re))
+ returnscale.unsqueeze(-1)*im
+
+
+@torch.jit.script
+defyaw_quat(quat:torch.Tensor)->torch.Tensor:
+"""Extract the yaw component of a quaternion.
+
+ Args:
+ quat: The orientation in (w, x, y, z). Shape is (..., 4)
+
+ Returns:
+ A quaternion with only yaw component.
+ """
+ shape=quat.shape
+ quat_yaw=quat.clone().view(-1,4)
+ qw=quat_yaw[:,0]
+ qx=quat_yaw[:,1]
+ qy=quat_yaw[:,2]
+ qz=quat_yaw[:,3]
+ yaw=torch.atan2(2*(qw*qz+qx*qy),1-2*(qy*qy+qz*qz))
+ quat_yaw[:]=0.0
+ quat_yaw[:,3]=torch.sin(yaw/2)
+ quat_yaw[:,0]=torch.cos(yaw/2)
+ quat_yaw=normalize(quat_yaw)
+ returnquat_yaw.view(shape)
+
+
+@torch.jit.script
+defquat_apply(quat:torch.Tensor,vec:torch.Tensor)->torch.Tensor:
+"""Apply a quaternion rotation to a vector.
+
+ Args:
+ quat: The quaternion in (w, x, y, z). Shape is (..., 4).
+ vec: The vector in (x, y, z). Shape is (..., 3).
+
+ Returns:
+ The rotated vector in (x, y, z). Shape is (..., 3).
+ """
+ # store shape
+ shape=vec.shape
+ # reshape to (N, 3) for multiplication
+ quat=quat.reshape(-1,4)
+ vec=vec.reshape(-1,3)
+ # extract components from quaternions
+ xyz=quat[:,1:]
+ t=xyz.cross(vec,dim=-1)*2
+ return(vec+quat[:,0:1]*t+xyz.cross(t,dim=-1)).view(shape)
+
+
+@torch.jit.script
+defquat_apply_yaw(quat:torch.Tensor,vec:torch.Tensor)->torch.Tensor:
+"""Rotate a vector only around the yaw-direction.
+
+ Args:
+ quat: The orientation in (w, x, y, z). Shape is (N, 4).
+ vec: The vector in (x, y, z). Shape is (N, 3).
+
+ Returns:
+ The rotated vector in (x, y, z). Shape is (N, 3).
+ """
+ quat_yaw=yaw_quat(quat)
+ returnquat_apply(quat_yaw,vec)
+
+
+@torch.jit.script
+defquat_rotate(q:torch.Tensor,v:torch.Tensor)->torch.Tensor:
+"""Rotate a vector by a quaternion along the last dimension of q and v.
+
+ Args:
+ q: The quaternion in (w, x, y, z). Shape is (..., 4).
+ v: The vector in (x, y, z). Shape is (..., 3).
+
+ Returns:
+ The rotated vector in (x, y, z). Shape is (..., 3).
+ """
+ q_w=q[...,0]
+ q_vec=q[...,1:]
+ a=v*(2.0*q_w**2-1.0).unsqueeze(-1)
+ b=torch.cross(q_vec,v,dim=-1)*q_w.unsqueeze(-1)*2.0
+ # for two-dimensional tensors, bmm is faster than einsum
+ ifq_vec.dim()==2:
+ c=q_vec*torch.bmm(q_vec.view(q.shape[0],1,3),v.view(q.shape[0],3,1)).squeeze(-1)*2.0
+ else:
+ c=q_vec*torch.einsum("...i,...i->...",q_vec,v).unsqueeze(-1)*2.0
+ returna+b+c
+
+
+@torch.jit.script
+defquat_rotate_inverse(q:torch.Tensor,v:torch.Tensor)->torch.Tensor:
+"""Rotate a vector by the inverse of a quaternion along the last dimension of q and v.
+
+ Args:
+ q: The quaternion in (w, x, y, z). Shape is (..., 4).
+ v: The vector in (x, y, z). Shape is (..., 3).
+
+ Returns:
+ The rotated vector in (x, y, z). Shape is (..., 3).
+ """
+ q_w=q[...,0]
+ q_vec=q[...,1:]
+ a=v*(2.0*q_w**2-1.0).unsqueeze(-1)
+ b=torch.cross(q_vec,v,dim=-1)*q_w.unsqueeze(-1)*2.0
+ # for two-dimensional tensors, bmm is faster than einsum
+ ifq_vec.dim()==2:
+ c=q_vec*torch.bmm(q_vec.view(q.shape[0],1,3),v.view(q.shape[0],3,1)).squeeze(-1)*2.0
+ else:
+ c=q_vec*torch.einsum("...i,...i->...",q_vec,v).unsqueeze(-1)*2.0
+ returna-b+c
+
+
+@torch.jit.script
+defquat_from_angle_axis(angle:torch.Tensor,axis:torch.Tensor)->torch.Tensor:
+"""Convert rotations given as angle-axis to quaternions.
+
+ Args:
+ angle: The angle turned anti-clockwise in radians around the vector's direction. Shape is (N,).
+ axis: The axis of rotation. Shape is (N, 3).
+
+ Returns:
+ The quaternion in (w, x, y, z). Shape is (N, 4).
+ """
+ theta=(angle/2).unsqueeze(-1)
+ xyz=normalize(axis)*theta.sin()
+ w=theta.cos()
+ returnnormalize(torch.cat([w,xyz],dim=-1))
+
+
+@torch.jit.script
+defaxis_angle_from_quat(quat:torch.Tensor,eps:float=1.0e-6)->torch.Tensor:
+"""Convert rotations given as quaternions to axis/angle.
+
+ Args:
+ quat: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
+ eps: The tolerance for Taylor approximation. Defaults to 1.0e-6.
+
+ Returns:
+ Rotations given as a vector in axis angle form. Shape is (..., 3).
+ The vector's magnitude is the angle turned anti-clockwise in radians around the vector's direction.
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L526-L554
+ """
+ # Modified to take in quat as [q_w, q_x, q_y, q_z]
+ # Quaternion is [q_w, q_x, q_y, q_z] = [cos(theta/2), n_x * sin(theta/2), n_y * sin(theta/2), n_z * sin(theta/2)]
+ # Axis-angle is [a_x, a_y, a_z] = [theta * n_x, theta * n_y, theta * n_z]
+ # Thus, axis-angle is [q_x, q_y, q_z] / (sin(theta/2) / theta)
+ # When theta = 0, (sin(theta/2) / theta) is undefined
+ # However, as theta --> 0, we can use the Taylor approximation 1/2 - theta^2 / 48
+ quat=quat*(1.0-2.0*(quat[...,0:1]<0.0))
+ mag=torch.linalg.norm(quat[...,1:],dim=-1)
+ half_angle=torch.atan2(mag,quat[...,0])
+ angle=2.0*half_angle
+ # check whether to apply Taylor approximation
+ sin_half_angles_over_angles=torch.where(
+ angle.abs()>eps,torch.sin(half_angle)/angle,0.5-angle*angle/48
+ )
+ returnquat[...,1:4]/sin_half_angles_over_angles.unsqueeze(-1)
+
+
+@torch.jit.script
+defquat_error_magnitude(q1:torch.Tensor,q2:torch.Tensor)->torch.Tensor:
+"""Computes the rotation difference between two quaternions.
+
+ Args:
+ q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
+ q2: The second quaternion in (w, x, y, z). Shape is (..., 4).
+
+ Returns:
+ Angular error between input quaternions in radians.
+ """
+ quat_diff=quat_mul(q1,quat_conjugate(q2))
+ returntorch.norm(axis_angle_from_quat(quat_diff),dim=-1)
+
+
+@torch.jit.script
+defskew_symmetric_matrix(vec:torch.Tensor)->torch.Tensor:
+"""Computes the skew-symmetric matrix of a vector.
+
+ Args:
+ vec: The input vector. Shape is (3,) or (N, 3).
+
+ Returns:
+ The skew-symmetric matrix. Shape is (1, 3, 3) or (N, 3, 3).
+
+ Raises:
+ ValueError: If input tensor is not of shape (..., 3).
+ """
+ # check input is correct
+ ifvec.shape[-1]!=3:
+ raiseValueError(f"Expected input vector shape mismatch: {vec.shape} != (..., 3).")
+ # unsqueeze the last dimension
+ ifvec.ndim==1:
+ vec=vec.unsqueeze(0)
+ # create a skew-symmetric matrix
+ skew_sym_mat=torch.zeros(vec.shape[0],3,3,device=vec.device,dtype=vec.dtype)
+ skew_sym_mat[:,0,1]=-vec[:,2]
+ skew_sym_mat[:,0,2]=vec[:,1]
+ skew_sym_mat[:,1,2]=-vec[:,0]
+ skew_sym_mat[:,1,0]=vec[:,2]
+ skew_sym_mat[:,2,0]=-vec[:,1]
+ skew_sym_mat[:,2,1]=vec[:,0]
+
+ returnskew_sym_mat
+
+
+"""
+Transformations
+"""
+
+
+
[文档]defis_identity_pose(pos:torch.tensor,rot:torch.tensor)->bool:
+"""Checks if input poses are identity transforms.
+
+ The function checks if the input position and orientation are close to zero and
+ identity respectively using L2-norm. It does NOT check the error in the orientation.
+
+ Args:
+ pos: The cartesian position. Shape is (N, 3).
+ rot: The quaternion in (w, x, y, z). Shape is (N, 4).
+
+ Returns:
+ True if all the input poses result in identity transform. Otherwise, False.
+ """
+ # create identity transformations
+ pos_identity=torch.zeros_like(pos)
+ rot_identity=torch.zeros_like(rot)
+ rot_identity[...,0]=1
+ # compare input to identity
+ returntorch.allclose(pos,pos_identity)andtorch.allclose(rot,rot_identity)
+
+
+# @torch.jit.script
+
[文档]defcombine_frame_transforms(
+ t01:torch.Tensor,q01:torch.Tensor,t12:torch.Tensor|None=None,q12:torch.Tensor|None=None
+)->tuple[torch.Tensor,torch.Tensor]:
+r"""Combine transformations between two reference frames into a stationary frame.
+
+ It performs the following transformation operation: :math:`T_{02} = T_{01} \times T_{12}`,
+ where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B.
+
+ Args:
+ t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3).
+ q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
+ t12: Position of frame 2 w.r.t. frame 1. Shape is (N, 3).
+ Defaults to None, in which case the position is assumed to be zero.
+ q12: Quaternion orientation of frame 2 w.r.t. frame 1 in (w, x, y, z). Shape is (N, 4).
+ Defaults to None, in which case the orientation is assumed to be identity.
+
+ Returns:
+ A tuple containing the position and orientation of frame 2 w.r.t. frame 0.
+ Shape of the tensors are (N, 3) and (N, 4) respectively.
+ """
+ # compute orientation
+ ifq12isnotNone:
+ q02=quat_mul(q01,q12)
+ else:
+ q02=q01
+ # compute translation
+ ift12isnotNone:
+ t02=t01+quat_apply(q01,t12)
+ else:
+ t02=t01
+
+ returnt02,q02
+
+
+# @torch.jit.script
+
[文档]defsubtract_frame_transforms(
+ t01:torch.Tensor,q01:torch.Tensor,t02:torch.Tensor|None=None,q02:torch.Tensor|None=None
+)->tuple[torch.Tensor,torch.Tensor]:
+r"""Subtract transformations between two reference frames into a stationary frame.
+
+ It performs the following transformation operation: :math:`T_{12} = T_{01}^{-1} \times T_{02}`,
+ where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B.
+
+ Args:
+ t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3).
+ q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
+ t02: Position of frame 2 w.r.t. frame 0. Shape is (N, 3).
+ Defaults to None, in which case the position is assumed to be zero.
+ q02: Quaternion orientation of frame 2 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
+ Defaults to None, in which case the orientation is assumed to be identity.
+
+ Returns:
+ A tuple containing the position and orientation of frame 2 w.r.t. frame 1.
+ Shape of the tensors are (N, 3) and (N, 4) respectively.
+ """
+ # compute orientation
+ q10=quat_inv(q01)
+ ifq02isnotNone:
+ q12=quat_mul(q10,q02)
+ else:
+ q12=q10
+ # compute translation
+ ift02isnotNone:
+ t12=quat_apply(q10,t02-t01)
+ else:
+ t12=quat_apply(q10,-t01)
+ returnt12,q12
+
+
+# @torch.jit.script
+
[文档]defcompute_pose_error(
+ t01:torch.Tensor,
+ q01:torch.Tensor,
+ t02:torch.Tensor,
+ q02:torch.Tensor,
+ rot_error_type:Literal["quat","axis_angle"]="axis_angle",
+)->tuple[torch.Tensor,torch.Tensor]:
+"""Compute the position and orientation error between source and target frames.
+
+ Args:
+ t01: Position of source frame. Shape is (N, 3).
+ q01: Quaternion orientation of source frame in (w, x, y, z). Shape is (N, 4).
+ t02: Position of target frame. Shape is (N, 3).
+ q02: Quaternion orientation of target frame in (w, x, y, z). Shape is (N, 4).
+ rot_error_type: The rotation error type to return: "quat", "axis_angle".
+ Defaults to "axis_angle".
+
+ Returns:
+ A tuple containing position and orientation error. Shape of position error is (N, 3).
+ Shape of orientation error depends on the value of :attr:`rot_error_type`:
+
+ - If :attr:`rot_error_type` is "quat", the orientation error is returned
+ as a quaternion. Shape is (N, 4).
+ - If :attr:`rot_error_type` is "axis_angle", the orientation error is
+ returned as an axis-angle vector. Shape is (N, 3).
+
+ Raises:
+ ValueError: Invalid rotation error type.
+ """
+ # Compute quaternion error (i.e., difference quaternion)
+ # Reference: https://personal.utdallas.edu/~sxb027100/dock/quaternion.html
+ # q_current_norm = q_current * q_current_conj
+ source_quat_norm=quat_mul(q01,quat_conjugate(q01))[:,0]
+ # q_current_inv = q_current_conj / q_current_norm
+ source_quat_inv=quat_conjugate(q01)/source_quat_norm.unsqueeze(-1)
+ # q_error = q_target * q_current_inv
+ quat_error=quat_mul(q02,source_quat_inv)
+
+ # Compute position error
+ pos_error=t02-t01
+
+ # return error based on specified type
+ ifrot_error_type=="quat":
+ returnpos_error,quat_error
+ elifrot_error_type=="axis_angle":
+ # Convert to axis-angle error
+ axis_angle_error=axis_angle_from_quat(quat_error)
+ returnpos_error,axis_angle_error
+ else:
+ raiseValueError(f"Unsupported orientation error type: {rot_error_type}. Valid: 'quat', 'axis_angle'.")
+
+
+@torch.jit.script
+defapply_delta_pose(
+ source_pos:torch.Tensor,source_rot:torch.Tensor,delta_pose:torch.Tensor,eps:float=1.0e-6
+)->tuple[torch.Tensor,torch.Tensor]:
+"""Applies delta pose transformation on source pose.
+
+ The first three elements of `delta_pose` are interpreted as cartesian position displacement.
+ The remaining three elements of `delta_pose` are interpreted as orientation displacement
+ in the angle-axis format.
+
+ Args:
+ source_pos: Position of source frame. Shape is (N, 3).
+ source_rot: Quaternion orientation of source frame in (w, x, y, z). Shape is (N, 4)..
+ delta_pose: Position and orientation displacements. Shape is (N, 6).
+ eps: The tolerance to consider orientation displacement as zero. Defaults to 1.0e-6.
+
+ Returns:
+ A tuple containing the displaced position and orientation frames.
+ Shape of the tensors are (N, 3) and (N, 4) respectively.
+ """
+ # number of poses given
+ num_poses=source_pos.shape[0]
+ device=source_pos.device
+
+ # interpret delta_pose[:, 0:3] as target position displacements
+ target_pos=source_pos+delta_pose[:,0:3]
+ # interpret delta_pose[:, 3:6] as target rotation displacements
+ rot_actions=delta_pose[:,3:6]
+ angle=torch.linalg.vector_norm(rot_actions,dim=1)
+ axis=rot_actions/angle.unsqueeze(-1)
+ # change from axis-angle to quat convention
+ identity_quat=torch.tensor([1.0,0.0,0.0,0.0],device=device).repeat(num_poses,1)
+ rot_delta_quat=torch.where(
+ angle.unsqueeze(-1).repeat(1,4)>eps,quat_from_angle_axis(angle,axis),identity_quat
+ )
+ # TODO: Check if this is the correct order for this multiplication.
+ target_rot=quat_mul(rot_delta_quat,source_rot)
+
+ returntarget_pos,target_rot
+
+
+# @torch.jit.script
+
[文档]deftransform_points(
+ points:torch.Tensor,pos:torch.Tensor|None=None,quat:torch.Tensor|None=None
+)->torch.Tensor:
+r"""Transform input points in a given frame to a target frame.
+
+ This function transform points from a source frame to a target frame. The transformation is defined by the
+ position :math:`t` and orientation :math:`R` of the target frame in the source frame.
+
+ .. math::
+ p_{target} = R_{target} \times p_{source} + t_{target}
+
+ If the input `points` is a batch of points, the inputs `pos` and `quat` must be either a batch of
+ positions and quaternions or a single position and quaternion. If the inputs `pos` and `quat` are
+ a single position and quaternion, the same transformation is applied to all points in the batch.
+
+ If either the inputs :attr:`pos` and :attr:`quat` are None, the corresponding transformation is not applied.
+
+ Args:
+ points: Points to transform. Shape is (N, P, 3) or (P, 3).
+ pos: Position of the target frame. Shape is (N, 3) or (3,).
+ Defaults to None, in which case the position is assumed to be zero.
+ quat: Quaternion orientation of the target frame in (w, x, y, z). Shape is (N, 4) or (4,).
+ Defaults to None, in which case the orientation is assumed to be identity.
+
+ Returns:
+ Transformed points in the target frame. Shape is (N, P, 3) or (P, 3).
+
+ Raises:
+ ValueError: If the inputs `points` is not of shape (N, P, 3) or (P, 3).
+ ValueError: If the inputs `pos` is not of shape (N, 3) or (3,).
+ ValueError: If the inputs `quat` is not of shape (N, 4) or (4,).
+ """
+ points_batch=points.clone()
+ # check if inputs are batched
+ is_batched=points_batch.dim()==3
+ # -- check inputs
+ ifpoints_batch.dim()==2:
+ points_batch=points_batch[None]# (P, 3) -> (1, P, 3)
+ ifpoints_batch.dim()!=3:
+ raiseValueError(f"Expected points to have dim = 2 or dim = 3: got shape {points.shape}")
+ ifnot(posisNoneorpos.dim()==1orpos.dim()==2):
+ raiseValueError(f"Expected pos to have dim = 1 or dim = 2: got shape {pos.shape}")
+ ifnot(quatisNoneorquat.dim()==1orquat.dim()==2):
+ raiseValueError(f"Expected quat to have dim = 1 or dim = 2: got shape {quat.shape}")
+ # -- rotation
+ ifquatisnotNone:
+ # convert to batched rotation matrix
+ rot_mat=matrix_from_quat(quat)
+ ifrot_mat.dim()==2:
+ rot_mat=rot_mat[None]# (3, 3) -> (1, 3, 3)
+ # convert points to matching batch size (N, P, 3) -> (N, 3, P)
+ # and apply rotation
+ points_batch=torch.matmul(rot_mat,points_batch.transpose_(1,2))
+ # (N, 3, P) -> (N, P, 3)
+ points_batch=points_batch.transpose_(1,2)
+ # -- translation
+ ifposisnotNone:
+ # convert to batched translation vector
+ ifpos.dim()==1:
+ pos=pos[None,None,:]# (3,) -> (1, 1, 3)
+ else:
+ pos=pos[:,None,:]# (N, 3) -> (N, 1, 3)
+ # apply translation
+ points_batch+=pos
+ # -- return points in same shape as input
+ ifnotis_batched:
+ points_batch=points_batch.squeeze(0)# (1, P, 3) -> (P, 3)
+
+ returnpoints_batch
+
+
+"""
+Projection operations.
+"""
+
+
+@torch.jit.script
+defunproject_depth(depth:torch.Tensor,intrinsics:torch.Tensor)->torch.Tensor:
+r"""Unproject depth image into a pointcloud. This method assumes that depth
+ is provided orthogonally relative to the image plane, as opposed to absolutely relative to the camera's
+ principal point (perspective depth). To unproject a perspective depth image, use
+ :meth:`convert_perspective_depth_to_orthogonal_depth` to convert
+ to an orthogonal depth image prior to calling this method. Otherwise, the
+ created point cloud will be distorted, especially around the edges.
+
+ This function converts depth images into points given the calibration matrix of the camera.
+
+ .. math::
+ p_{3D} = K^{-1} \times [u, v, 1]^T \times d
+
+ where :math:`p_{3D}` is the 3D point, :math:`d` is the depth value, :math:`u` and :math:`v` are
+ the pixel coordinates and :math:`K` is the intrinsic matrix.
+
+ If `depth` is a batch of depth images and `intrinsics` is a single intrinsic matrix, the same
+ calibration matrix is applied to all depth images in the batch.
+
+ The function assumes that the width and height are both greater than 1. This makes the function
+ deal with many possible shapes of depth images and intrinsics matrices.
+
+ Args:
+ depth: The depth measurement. Shape is (H, W) or or (H, W, 1) or (N, H, W) or (N, H, W, 1).
+ intrinsics: A tensor providing camera's calibration matrix. Shape is (3, 3) or (N, 3, 3).
+
+ Returns:
+ The 3D coordinates of points. Shape is (P, 3) or (N, P, 3).
+
+ Raises:
+ ValueError: When depth is not of shape (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1).
+ ValueError: When intrinsics is not of shape (3, 3) or (N, 3, 3).
+ """
+ depth_batch=depth.clone()
+ intrinsics_batch=intrinsics.clone()
+ # check if inputs are batched
+ is_batched=depth_batch.dim()==4or(depth_batch.dim()==3anddepth_batch.shape[-1]!=1)
+ # make sure inputs are batched
+ ifdepth_batch.dim()==3anddepth_batch.shape[-1]==1:
+ depth_batch=depth_batch.squeeze(dim=2)# (H, W, 1) -> (H, W)
+ ifdepth_batch.dim()==2:
+ depth_batch=depth_batch[None]# (H, W) -> (1, H, W)
+ ifdepth_batch.dim()==4anddepth_batch.shape[-1]==1:
+ depth_batch=depth_batch.squeeze(dim=3)# (N, H, W, 1) -> (N, H, W)
+ ifintrinsics_batch.dim()==2:
+ intrinsics_batch=intrinsics_batch[None]# (3, 3) -> (1, 3, 3)
+ # check shape of inputs
+ ifdepth_batch.dim()!=3:
+ raiseValueError(f"Expected depth images to have dim = 2 or 3 or 4: got shape {depth.shape}")
+ ifintrinsics_batch.dim()!=3:
+ raiseValueError(f"Expected intrinsics to have shape (3, 3) or (N, 3, 3): got shape {intrinsics.shape}")
+
+ # get image height and width
+ im_height,im_width=depth_batch.shape[1:]
+ # create image points in homogeneous coordinates (3, H x W)
+ indices_u=torch.arange(im_width,device=depth.device,dtype=depth.dtype)
+ indices_v=torch.arange(im_height,device=depth.device,dtype=depth.dtype)
+ img_indices=torch.stack(torch.meshgrid([indices_u,indices_v],indexing="ij"),dim=0).reshape(2,-1)
+ pixels=torch.nn.functional.pad(img_indices,(0,0,0,1),mode="constant",value=1.0)
+ pixels=pixels.unsqueeze(0)# (3, H x W) -> (1, 3, H x W)
+
+ # unproject points into 3D space
+ points=torch.matmul(torch.inverse(intrinsics_batch),pixels)# (N, 3, H x W)
+ points=points/points[:,-1,:].unsqueeze(1)# normalize by last coordinate
+ # flatten depth image (N, H, W) -> (N, H x W)
+ depth_batch=depth_batch.transpose_(1,2).reshape(depth_batch.shape[0],-1).unsqueeze(2)
+ depth_batch=depth_batch.expand(-1,-1,3)
+ # scale points by depth
+ points_xyz=points.transpose_(1,2)*depth_batch# (N, H x W, 3)
+
+ # return points in same shape as input
+ ifnotis_batched:
+ points_xyz=points_xyz.squeeze(0)
+
+ returnpoints_xyz
+
+
+@torch.jit.script
+defconvert_perspective_depth_to_orthogonal_depth(
+ perspective_depth:torch.Tensor,intrinsics:torch.Tensor
+)->torch.Tensor:
+r"""Provided depth image(s) where depth is provided as the distance to the principal
+ point of the camera (perspective depth), this function converts it so that depth
+ is provided as the distance to the camera's image plane (orthogonal depth).
+
+ This is helpful because `unproject_depth` assumes that depth is expressed in
+ the orthogonal depth format.
+
+ If `perspective_depth` is a batch of depth images and `intrinsics` is a single intrinsic matrix,
+ the same calibration matrix is applied to all depth images in the batch.
+
+ The function assumes that the width and height are both greater than 1.
+
+ Args:
+ perspective_depth: The depth measurement obtained with the distance_to_camera replicator.
+ Shape is (H, W) or or (H, W, 1) or (N, H, W) or (N, H, W, 1).
+ intrinsics: A tensor providing camera's calibration matrix. Shape is (3, 3) or (N, 3, 3).
+
+ Returns:
+ The depth image as if obtained by the distance_to_image_plane replicator. Shape
+ matches the input shape of depth
+
+ Raises:
+ ValueError: When depth is not of shape (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1).
+ ValueError: When intrinsics is not of shape (3, 3) or (N, 3, 3).
+ """
+
+ # Clone inputs to avoid in-place modifications
+ perspective_depth_batch=perspective_depth.clone()
+ intrinsics_batch=intrinsics.clone()
+
+ # Check if inputs are batched
+ is_batched=perspective_depth_batch.dim()==4or(
+ perspective_depth_batch.dim()==3andperspective_depth_batch.shape[-1]!=1
+ )
+
+ # Track whether the last dimension was singleton
+ add_last_dim=False
+ ifperspective_depth_batch.dim()==4andperspective_depth_batch.shape[-1]==1:
+ add_last_dim=True
+ perspective_depth_batch=perspective_depth_batch.squeeze(dim=3)# (N, H, W, 1) -> (N, H, W)
+ ifperspective_depth_batch.dim()==3andperspective_depth_batch.shape[-1]==1:
+ add_last_dim=True
+ perspective_depth_batch=perspective_depth_batch.squeeze(dim=2)# (H, W, 1) -> (H, W)
+
+ ifperspective_depth_batch.dim()==2:
+ perspective_depth_batch=perspective_depth_batch[None]# (H, W) -> (1, H, W)
+
+ ifintrinsics_batch.dim()==2:
+ intrinsics_batch=intrinsics_batch[None]# (3, 3) -> (1, 3, 3)
+
+ ifis_batchedandintrinsics_batch.shape[0]==1:
+ intrinsics_batch=intrinsics_batch.expand(perspective_depth_batch.shape[0],-1,-1)# (1, 3, 3) -> (N, 3, 3)
+
+ # Validate input shapes
+ ifperspective_depth_batch.dim()!=3:
+ raiseValueError(f"Expected perspective_depth to have 2, 3, or 4 dimensions; got {perspective_depth.shape}.")
+ ifintrinsics_batch.dim()!=3:
+ raiseValueError(f"Expected intrinsics to have shape (3, 3) or (N, 3, 3); got {intrinsics.shape}.")
+
+ # Image dimensions
+ im_height,im_width=perspective_depth_batch.shape[1:]
+
+ # Get the intrinsics parameters
+ fx=intrinsics_batch[:,0,0].view(-1,1,1)
+ fy=intrinsics_batch[:,1,1].view(-1,1,1)
+ cx=intrinsics_batch[:,0,2].view(-1,1,1)
+ cy=intrinsics_batch[:,1,2].view(-1,1,1)
+
+ # Create meshgrid of pixel coordinates
+ u_grid=torch.arange(im_width,device=perspective_depth.device,dtype=perspective_depth.dtype)
+ v_grid=torch.arange(im_height,device=perspective_depth.device,dtype=perspective_depth.dtype)
+ u_grid,v_grid=torch.meshgrid(u_grid,v_grid,indexing="xy")
+
+ # Expand the grids for batch processing
+ u_grid=u_grid.unsqueeze(0).expand(perspective_depth_batch.shape[0],-1,-1)
+ v_grid=v_grid.unsqueeze(0).expand(perspective_depth_batch.shape[0],-1,-1)
+
+ # Compute the squared terms for efficiency
+ x_term=((u_grid-cx)/fx)**2
+ y_term=((v_grid-cy)/fy)**2
+
+ # Calculate the orthogonal (normal) depth
+ normal_depth=perspective_depth_batch/torch.sqrt(1+x_term+y_term)
+
+ # Restore the last dimension if it was present in the input
+ ifadd_last_dim:
+ normal_depth=normal_depth.unsqueeze(-1)
+
+ # Return to original shape if input was not batched
+ ifnotis_batched:
+ normal_depth=normal_depth.squeeze(0)
+
+ returnnormal_depth
+
+
+@torch.jit.script
+defproject_points(points:torch.Tensor,intrinsics:torch.Tensor)->torch.Tensor:
+r"""Projects 3D points into 2D image plane.
+
+ This project 3D points into a 2D image plane. The transformation is defined by the intrinsic
+ matrix of the camera.
+
+ .. math::
+
+ \begin{align}
+ p &= K \times p_{3D} = \\
+ p_{2D} &= \begin{pmatrix} u \\ v \\ d \end{pmatrix}
+ = \begin{pmatrix} p[0] / p[2] \\ p[1] / p[2] \\ Z \end{pmatrix}
+ \end{align}
+
+ where :math:`p_{2D} = (u, v, d)` is the projected 3D point, :math:`p_{3D} = (X, Y, Z)` is the
+ 3D point and :math:`K \in \mathbb{R}^{3 \times 3}` is the intrinsic matrix.
+
+ If `points` is a batch of 3D points and `intrinsics` is a single intrinsic matrix, the same
+ calibration matrix is applied to all points in the batch.
+
+ Args:
+ points: The 3D coordinates of points. Shape is (P, 3) or (N, P, 3).
+ intrinsics: Camera's calibration matrix. Shape is (3, 3) or (N, 3, 3).
+
+ Returns:
+ Projected 3D coordinates of points. Shape is (P, 3) or (N, P, 3).
+ """
+ points_batch=points.clone()
+ intrinsics_batch=intrinsics.clone()
+ # check if inputs are batched
+ is_batched=points_batch.dim()==2
+ # make sure inputs are batched
+ ifpoints_batch.dim()==2:
+ points_batch=points_batch[None]# (P, 3) -> (1, P, 3)
+ ifintrinsics_batch.dim()==2:
+ intrinsics_batch=intrinsics_batch[None]# (3, 3) -> (1, 3, 3)
+ # check shape of inputs
+ ifpoints_batch.dim()!=3:
+ raiseValueError(f"Expected points to have dim = 3: got shape {points.shape}.")
+ ifintrinsics_batch.dim()!=3:
+ raiseValueError(f"Expected intrinsics to have shape (3, 3) or (N, 3, 3): got shape {intrinsics.shape}.")
+ # project points into 2D image plane
+ points_2d=torch.matmul(intrinsics_batch,points_batch.transpose(1,2))
+ points_2d=points_2d/points_2d[:,-1,:].unsqueeze(1)# normalize by last coordinate
+ points_2d=points_2d.transpose_(1,2)# (N, 3, P) -> (N, P, 3)
+ # replace last coordinate with depth
+ points_2d[:,:,-1]=points_batch[:,:,-1]
+ # return points in same shape as input
+ ifnotis_batched:
+ points_2d=points_2d.squeeze(0)# (1, 3, P) -> (3, P)
+
+ returnpoints_2d
+
+
+"""
+Sampling
+"""
+
+
+@torch.jit.script
+defdefault_orientation(num:int,device:str)->torch.Tensor:
+"""Returns identity rotation transform.
+
+ Args:
+ num: The number of rotations to sample.
+ device: Device to create tensor on.
+
+ Returns:
+ Identity quaternion in (w, x, y, z). Shape is (num, 4).
+ """
+ quat=torch.zeros((num,4),dtype=torch.float,device=device)
+ quat[...,0]=1.0
+
+ returnquat
+
+
+@torch.jit.script
+defrandom_orientation(num:int,device:str)->torch.Tensor:
+"""Returns sampled rotation in 3D as quaternion.
+
+ Args:
+ num: The number of rotations to sample.
+ device: Device to create tensor on.
+
+ Returns:
+ Sampled quaternion in (w, x, y, z). Shape is (num, 4).
+
+ Reference:
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.random.html
+ """
+ # sample random orientation from normal distribution
+ quat=torch.randn((num,4),dtype=torch.float,device=device)
+ # normalize the quaternion
+ returntorch.nn.functional.normalize(quat,p=2.0,dim=-1,eps=1e-12)
+
+
+@torch.jit.script
+defrandom_yaw_orientation(num:int,device:str)->torch.Tensor:
+"""Returns sampled rotation around z-axis.
+
+ Args:
+ num: The number of rotations to sample.
+ device: Device to create tensor on.
+
+ Returns:
+ Sampled quaternion in (w, x, y, z). Shape is (num, 4).
+ """
+ roll=torch.zeros(num,dtype=torch.float,device=device)
+ pitch=torch.zeros(num,dtype=torch.float,device=device)
+ yaw=2*torch.pi*torch.rand(num,dtype=torch.float,device=device)
+
+ returnquat_from_euler_xyz(roll,pitch,yaw)
+
+
+
[文档]defsample_triangle(lower:float,upper:float,size:int|tuple[int,...],device:str)->torch.Tensor:
+"""Randomly samples tensor from a triangular distribution.
+
+ Args:
+ lower: The lower range of the sampled tensor.
+ upper: The upper range of the sampled tensor.
+ size: The shape of the tensor.
+ device: Device to create tensor on.
+
+ Returns:
+ Sampled tensor. Shape is based on :attr:`size`.
+ """
+ # convert to tuple
+ ifisinstance(size,int):
+ size=(size,)
+ # create random tensor in the range [-1, 1]
+ r=2*torch.rand(*size,device=device)-1
+ # convert to triangular distribution
+ r=torch.where(r<0.0,-torch.sqrt(-r),torch.sqrt(r))
+ # rescale back to [0, 1]
+ r=(r+1.0)/2.0
+ # rescale to range [lower, upper]
+ return(upper-lower)*r+lower
+
+
+
[文档]defsample_uniform(
+ lower:torch.Tensor|float,upper:torch.Tensor|float,size:int|tuple[int,...],device:str
+)->torch.Tensor:
+"""Sample uniformly within a range.
+
+ Args:
+ lower: Lower bound of uniform range.
+ upper: Upper bound of uniform range.
+ size: The shape of the tensor.
+ device: Device to create tensor on.
+
+ Returns:
+ Sampled tensor. Shape is based on :attr:`size`.
+ """
+ # convert to tuple
+ ifisinstance(size,int):
+ size=(size,)
+ # return tensor
+ returntorch.rand(*size,device=device)*(upper-lower)+lower
+
+
+
[文档]defsample_log_uniform(
+ lower:torch.Tensor|float,upper:torch.Tensor|float,size:int|tuple[int,...],device:str
+)->torch.Tensor:
+r"""Sample using log-uniform distribution within a range.
+
+ The log-uniform distribution is defined as a uniform distribution in the log-space. It
+ is useful for sampling values that span several orders of magnitude. The sampled values
+ are uniformly distributed in the log-space and then exponentiated to get the final values.
+
+ .. math::
+
+ x = \exp(\text{uniform}(\log(\text{lower}), \log(\text{upper})))
+
+ Args:
+ lower: Lower bound of uniform range.
+ upper: Upper bound of uniform range.
+ size: The shape of the tensor.
+ device: Device to create tensor on.
+
+ Returns:
+ Sampled tensor. Shape is based on :attr:`size`.
+ """
+ # cast to tensor if not already
+ ifnotisinstance(lower,torch.Tensor):
+ lower=torch.tensor(lower,dtype=torch.float,device=device)
+ ifnotisinstance(upper,torch.Tensor):
+ upper=torch.tensor(upper,dtype=torch.float,device=device)
+ # sample in log-space and exponentiate
+ returntorch.exp(sample_uniform(torch.log(lower),torch.log(upper),size,device))
+
+
+
[文档]defsample_gaussian(
+ mean:torch.Tensor|float,std:torch.Tensor|float,size:int|tuple[int,...],device:str
+)->torch.Tensor:
+"""Sample using gaussian distribution.
+
+ Args:
+ mean: Mean of the gaussian.
+ std: Std of the gaussian.
+ size: The shape of the tensor.
+ device: Device to create tensor on.
+
+ Returns:
+ Sampled tensor.
+ """
+ ifisinstance(mean,float):
+ ifisinstance(size,int):
+ size=(size,)
+ returntorch.normal(mean=mean,std=std,size=size).to(device=device)
+ else:
+ returntorch.normal(mean=mean,std=std).to(device=device)
+
+
+
[文档]defsample_cylinder(
+ radius:float,h_range:tuple[float,float],size:int|tuple[int,...],device:str
+)->torch.Tensor:
+"""Sample 3D points uniformly on a cylinder's surface.
+
+ The cylinder is centered at the origin and aligned with the z-axis. The height of the cylinder is
+ sampled uniformly from the range :obj:`h_range`, while the radius is fixed to :obj:`radius`.
+
+ The sampled points are returned as a tensor of shape :obj:`(*size, 3)`, i.e. the last dimension
+ contains the x, y, and z coordinates of the sampled points.
+
+ Args:
+ radius: The radius of the cylinder.
+ h_range: The minimum and maximum height of the cylinder.
+ size: The shape of the tensor.
+ device: Device to create tensor on.
+
+ Returns:
+ Sampled tensor. Shape is :obj:`(*size, 3)`.
+ """
+ # sample angles
+ angles=(torch.rand(size,device=device)*2-1)*torch.pi
+ h_min,h_max=h_range
+ # add shape
+ ifisinstance(size,int):
+ size=(size,3)
+ else:
+ size+=(3,)
+ # allocate a tensor
+ xyz=torch.zeros(size,device=device)
+ xyz[...,0]=radius*torch.cos(angles)
+ xyz[...,1]=radius*torch.sin(angles)
+ xyz[...,2].uniform_(h_min,h_max)
+ # return positions
+ returnxyz
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importtorch
+fromcollections.abcimportSequence
+fromtypingimportTYPE_CHECKING
+
+from.modifier_baseimportModifierBase
+
+ifTYPE_CHECKING:
+ from.importmodifier_cfg
+
+##
+# Modifiers as functions
+##
+
+
+
[文档]defscale(data:torch.Tensor,multiplier:float)->torch.Tensor:
+"""Scales input data by a multiplier.
+
+ Args:
+ data: The data to apply the scale to.
+ multiplier: Value to scale input by.
+
+ Returns:
+ Scaled data. Shape is the same as data.
+ """
+ returndata*multiplier
+
+
+
[文档]defclip(data:torch.Tensor,bounds:tuple[float|None,float|None])->torch.Tensor:
+"""Clips the data to a minimum and maximum value.
+
+ Args:
+ data: The data to apply the clip to.
+ bounds: A tuple containing the minimum and maximum values to clip data to.
+ If the value is None, that bound is not applied.
+
+ Returns:
+ Clipped data. Shape is the same as data.
+ """
+ returndata.clip(min=bounds[0],max=bounds[1])
+
+
+
[文档]defbias(data:torch.Tensor,value:float)->torch.Tensor:
+"""Adds a uniform bias to the data.
+
+ Args:
+ data: The data to add bias to.
+ value: Value of bias to add to data.
+
+ Returns:
+ Biased data. Shape is the same as data.
+ """
+ returndata+value
+
+
+##
+# Sample of class based modifiers
+##
+
+
+
[文档]classDigitalFilter(ModifierBase):
+r"""Modifier used to apply digital filtering to the input data.
+
+ `Digital filters <https://en.wikipedia.org/wiki/Digital_filter>`_ are used to process discrete-time
+ signals to extract useful parts of the signal, such as smoothing, noise reduction, or frequency separation.
+
+ The filter can be implemented as a linear difference equation in the time domain. This equation
+ can be used to calculate the output at each time-step based on the current and previous inputs and outputs.
+
+ .. math::
+ y_{i} = X B - Y A = \sum_{j=0}^{N} b_j x_{i-j} - \sum_{j=1}^{M} a_j y_{i-j}
+
+ where :math:`y_{i}` is the current output of the filter. The array :math:`Y` contains previous
+ outputs from the filter :math:`\{y_{i-j}\}_{j=1}^M` for :math:`M` previous time-steps. The array
+ :math:`X` contains current :math:`x_{i}` and previous inputs to the filter
+ :math:`\{x_{i-j}\}_{j=1}^N` for :math:`N` previous time-steps respectively.
+ The filter coefficients :math:`A` and :math:`B` are used to design the filter. They are column vectors of
+ length :math:`M` and :math:`N + 1` respectively.
+
+ Different types of filters can be implemented by choosing different values for :math:`A` and :math:`B`.
+ We provide some examples below.
+
+ Examples
+ ^^^^^^^^
+
+ **Unit Delay Filter**
+
+ A filter that delays the input signal by a single time-step simply outputs the previous input value.
+
+ .. math:: y_{i} = x_{i-1}
+
+ This can be implemented as a digital filter with the coefficients :math:`A = [0.0]` and :math:`B = [0.0, 1.0]`.
+
+ **Moving Average Filter**
+
+ A moving average filter is used to smooth out noise in a signal. It is similar to a low-pass filter
+ but has a finite impulse response (FIR) and is non-recursive.
+
+ The filter calculates the average of the input signal over a window of time-steps. The linear difference
+ equation for a moving average filter is:
+
+ .. math:: y_{i} = \frac{1}{N} \sum_{j=0}^{N} x_{i-j}
+
+ This can be implemented as a digital filter with the coefficients :math:`A = [0.0]` and
+ :math:`B = [1/N, 1/N, \cdots, 1/N]`.
+
+ **First-order recursive low-pass filter**
+
+ A recursive low-pass filter is used to smooth out high-frequency noise in a signal. It is a first-order
+ infinite impulse response (IIR) filter which means it has a recursive component (previous output) in the
+ linear difference equation.
+
+ A first-order low-pass IIR filter has the difference equation:
+
+ .. math:: y_{i} = \alpha y_{i-1} + (1-\alpha)x_{i}
+
+ where :math:`\alpha` is a smoothing parameter between 0 and 1. Typically, the value of :math:`\alpha` is
+ chosen based on the desired cut-off frequency of the filter.
+
+ This filter can be implemented as a digital filter with the coefficients :math:`A = [\alpha]` and
+ :math:`B = [1 - \alpha]`.
+ """
+
+ def__init__(self,cfg:modifier_cfg.DigitalFilterCfg,data_dim:tuple[int,...],device:str)->None:
+"""Initializes digital filter.
+
+ Args:
+ cfg: Configuration parameters.
+ data_dim: The dimensions of the data to be modified. First element is the batch size
+ which usually corresponds to number of environments in the simulation.
+ device: The device to run the modifier on.
+
+ Raises:
+ ValueError: If filter coefficients are None.
+ """
+ # check that filter coefficients are not None
+ ifcfg.AisNoneorcfg.BisNone:
+ raiseValueError("Digital filter coefficients A and B must not be None. Please provide valid coefficients.")
+
+ # initialize parent class
+ super().__init__(cfg,data_dim,device)
+
+ # assign filter coefficients and make sure they are column vectors
+ self.A=torch.tensor(self._cfg.A,device=self._device).unsqueeze(1)
+ self.B=torch.tensor(self._cfg.B,device=self._device).unsqueeze(1)
+
+ # create buffer for input and output history
+ self.x_n=torch.zeros(self._data_dim+(self.B.shape[0],),device=self._device)
+ self.y_n=torch.zeros(self._data_dim+(self.A.shape[0],),device=self._device)
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+"""Resets digital filter history.
+
+ Args:
+ env_ids: The environment ids. Defaults to None, in which case
+ all environments are considered.
+ """
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # reset history buffers
+ self.x_n[env_ids]=0.0
+ self.y_n[env_ids]=0.0
+
+
[文档]def__call__(self,data:torch.Tensor)->torch.Tensor:
+"""Applies digital filter modification with a rolling history window inputs and outputs.
+
+ Args:
+ data: The data to apply filter to.
+
+ Returns:
+ Filtered data. Shape is the same as data.
+ """
+ # move history window for input
+ self.x_n=torch.roll(self.x_n,shifts=1,dims=-1)
+ self.x_n[...,0]=data
+
+ # calculate current filter value: y[i] = Y*A - X*B
+ y_i=torch.matmul(self.x_n,self.B)-torch.matmul(self.y_n,self.A)
+ y_i.squeeze_(-1)
+
+ # move history window for output and add current filter value to history
+ self.y_n=torch.roll(self.y_n,shifts=1,dims=-1)
+ self.y_n[...,0]=y_i
+
+ returny_i
+
+
+
[文档]classIntegrator(ModifierBase):
+r"""Modifier that applies a numerical forward integration based on a middle Reimann sum.
+
+ An integrator is used to calculate the integral of a signal over time. The integral of a signal
+ is the area under the curve of the signal. The integral can be approximated using numerical methods
+ such as the `Riemann sum <https://en.wikipedia.org/wiki/Riemann_sum>`_.
+
+ The middle Riemann sum is a method to approximate the integral of a function by dividing the area
+ under the curve into rectangles. The height of each rectangle is the value of the function at the
+ midpoint of the interval. The area of each rectangle is the width of the interval multiplied by the
+ height of the rectangle.
+
+ This integral method is useful for signals that are sampled at regular intervals. The integral
+ can be written as:
+
+ .. math::
+ \int_{t_0}^{t_n} f(t) dt & \approx \int_{t_0}^{t_{n-1}} f(t) dt + \frac{f(t_{n-1}) + f(t_n)}{2} \Delta t
+
+ where :math:`f(t)` is the signal to integrate, :math:`t_i` is the time at the i-th sample, and
+ :math:`\Delta t` is the time step between samples.
+ """
+
+ def__init__(self,cfg:modifier_cfg.IntegratorCfg,data_dim:tuple[int,...],device:str):
+"""Initializes the integrator configuration and state.
+
+ Args:
+ cfg: Integral parameters.
+ data_dim: The dimensions of the data to be modified. First element is the batch size
+ which usually corresponds to number of environments in the simulation.
+ device: The device to run the modifier on.
+ """
+ # initialize parent class
+ super().__init__(cfg,data_dim,device)
+
+ # assign buffer for integral and previous value
+ self.integral=torch.zeros(self._data_dim,device=self._device)
+ self.y_prev=torch.zeros(self._data_dim,device=self._device)
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+"""Resets integrator state to zero.
+
+ Args:
+ env_ids: The environment ids. Defaults to None, in which case
+ all environments are considered.
+ """
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # reset history buffers
+ self.integral[env_ids]=0.0
+ self.y_prev[env_ids]=0.0
+
+
[文档]def__call__(self,data:torch.Tensor)->torch.Tensor:
+"""Applies integral modification to input data.
+
+ Args:
+ data: The data to integrate.
+
+ Returns:
+ Integral of input signal. Shape is the same as data.
+ """
+ # integrate using middle Riemann sum
+ self.integral+=(data+self.y_prev)/2*self._cfg.dt
+ # update previous value
+ self.y_prev[:]=data
+
+ returnself.integral
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importtorch
+fromabcimportABC,abstractmethod
+fromcollections.abcimportSequence
+fromtypingimportTYPE_CHECKING
+
+ifTYPE_CHECKING:
+ from.modifier_cfgimportModifierCfg
+
+
+
[文档]classModifierBase(ABC):
+"""Base class for modifiers implemented as classes.
+
+ Modifiers implementations can be functions or classes. If a modifier is a class, it should
+ inherit from this class and implement the required methods.
+
+ A class implementation of a modifier can be used to store state information between calls.
+ This is useful for modifiers that require stateful operations, such as rolling averages
+ or delays or decaying filters.
+
+ Example pseudo-code to create and use the class:
+
+ .. code-block:: python
+
+ from omni.isaac.lab.utils import modifiers
+
+ # define custom keyword arguments to pass to ModifierCfg
+ kwarg_dict = {"arg_1" : VAL_1, "arg_2" : VAL_2}
+
+ # create modifier configuration object
+ # func is the class name of the modifier and params is the dictionary of arguments
+ modifier_config = modifiers.ModifierCfg(func=modifiers.ModifierBase, params=kwarg_dict)
+
+ # define modifier instance
+ my_modifier = modifiers.ModifierBase(cfg=modifier_config)
+
+ """
+
+ def__init__(self,cfg:ModifierCfg,data_dim:tuple[int,...],device:str)->None:
+"""Initializes the modifier class.
+
+ Args:
+ cfg: Configuration parameters.
+ data_dim: The dimensions of the data to be modified. First element is the batch size
+ which usually corresponds to number of environments in the simulation.
+ device: The device to run the modifier on.
+ """
+ self._cfg=cfg
+ self._data_dim=data_dim
+ self._device=device
+
+
[文档]@abstractmethod
+ defreset(self,env_ids:Sequence[int]|None=None):
+"""Resets the Modifier.
+
+ Args:
+ env_ids: The environment ids. Defaults to None, in which case
+ all environments are considered.
+ """
+ raiseNotImplementedError
+
+
[文档]@abstractmethod
+ def__call__(self,data:torch.Tensor)->torch.Tensor:
+"""Abstract method for defining the modification function.
+
+ Args:
+ data: The data to be modified. Shape should match the data_dim passed during initialization.
+
+ Returns:
+ Modified data. Shape is the same as the input data.
+ """
+ raiseNotImplementedError
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importtorch
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+fromtypingimportAny
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.importmodifier
+
+
+
[文档]@configclass
+classModifierCfg:
+"""Configuration parameters modifiers"""
+
+ func:Callable[...,torch.Tensor]=MISSING
+"""Function or callable class used by modifier.
+
+ The function must take a torch tensor as the first argument. The remaining arguments are specified
+ in the :attr:`params` attribute.
+
+ It also supports `callable classes <https://docs.python.org/3/reference/datamodel.html#object.__call__>`_,
+ i.e. classes that implement the ``__call__()`` method. In this case, the class should inherit from the
+ :class:`ModifierBase` class and implement the required methods.
+ """
+
+ params:dict[str,Any]=dict()
+"""The parameters to be passed to the function or callable class as keyword arguments. Defaults to
+ an empty dictionary."""
+
+
+
[文档]@configclass
+classDigitalFilterCfg(ModifierCfg):
+"""Configuration parameters for a digital filter modifier.
+
+ For more information, please check the :class:`DigitalFilter` class.
+ """
+
+ func:type[modifier.DigitalFilter]=modifier.DigitalFilter
+"""The digital filter function to be called for applying the filter."""
+
+ A:list[float]=MISSING
+"""The coefficients corresponding the the filter's response to past outputs.
+
+ These correspond to the weights of the past outputs of the filter. The first element is the coefficient
+ for the output at the previous time step, the second element is the coefficient for the output at two
+ time steps ago, and so on.
+
+ It is the denominator coefficients of the transfer function of the filter.
+ """
+
+ B:list[float]=MISSING
+"""The coefficients corresponding the the filter's response to current and past inputs.
+
+ These correspond to the weights of the current and past inputs of the filter. The first element is the
+ coefficient for the current input, the second element is the coefficient for the input at the previous
+ time step, and so on.
+
+ It is the numerator coefficients of the transfer function of the filter.
+ """
+
+
+
[文档]@configclass
+classIntegratorCfg(ModifierCfg):
+"""Configuration parameters for an integrator modifier.
+
+ For more information, please check the :class:`Integrator` class.
+ """
+
+ func:type[modifier.Integrator]=modifier.Integrator
+"""The integrator function to be called for applying the integrator."""
+
+ dt:float=MISSING
+"""The time step of the integrator."""
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from__future__importannotations
+
+importtorch
+fromcollections.abcimportCallable
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+from.importnoise_model
+
+
+
[文档]@configclass
+classNoiseCfg:
+"""Base configuration for a noise term."""
+
+ func:Callable[[torch.Tensor,NoiseCfg],torch.Tensor]=MISSING
+"""The function to be called for applying the noise.
+
+ Note:
+ The shape of the input and output tensors must be the same.
+ """
+ operation:Literal["add","scale","abs"]="add"
+"""The operation to apply the noise on the data. Defaults to "add"."""
+
+
+
[文档]@configclass
+classConstantNoiseCfg(NoiseCfg):
+"""Configuration for an additive constant noise term."""
+
+ func=noise_model.constant_noise
+
+ bias:torch.Tensor|float=0.0
+"""The bias to add. Defaults to 0.0."""
+
+
+
[文档]@configclass
+classUniformNoiseCfg(NoiseCfg):
+"""Configuration for a additive uniform noise term."""
+
+ func=noise_model.uniform_noise
+
+ n_min:torch.Tensor|float=-1.0
+"""The minimum value of the noise. Defaults to -1.0."""
+ n_max:torch.Tensor|float=1.0
+"""The maximum value of the noise. Defaults to 1.0."""
+
+
+
[文档]@configclass
+classGaussianNoiseCfg(NoiseCfg):
+"""Configuration for an additive gaussian noise term."""
+
+ func=noise_model.gaussian_noise
+
+ mean:torch.Tensor|float=0.0
+"""The mean of the noise. Defaults to 0.0."""
+ std:torch.Tensor|float=1.0
+"""The standard deviation of the noise. Defaults to 1.0."""
+
+
+##
+# Noise models
+##
+
+
+
[文档]@configclass
+classNoiseModelCfg:
+"""Configuration for a noise model."""
+
+ class_type:type=noise_model.NoiseModel
+"""The class type of the noise model."""
+
+ noise_cfg:NoiseCfg=MISSING
+"""The noise configuration to use."""
+
+
+
[文档]@configclass
+classNoiseModelWithAdditiveBiasCfg(NoiseModelCfg):
+"""Configuration for an additive gaussian noise with bias model."""
+
+ class_type:type=noise_model.NoiseModelWithAdditiveBias
+
+ bias_noise_cfg:NoiseCfg=MISSING
+"""The noise configuration for the bias.
+
+ Based on this configuration, the bias is sampled at every reset of the noise model.
+ """
[文档]classNoiseModel:
+"""Base class for noise models."""
+
+ def__init__(self,noise_model_cfg:noise_cfg.NoiseModelCfg,num_envs:int,device:str):
+"""Initialize the noise model.
+
+ Args:
+ noise_model_cfg: The noise configuration to use.
+ num_envs: The number of environments.
+ device: The device to use for the noise model.
+ """
+ self._noise_model_cfg=noise_model_cfg
+ self._num_envs=num_envs
+ self._device=device
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+"""Reset the noise model.
+
+ This method can be implemented by derived classes to reset the noise model.
+ This is useful when implementing temporal noise models such as random walk.
+
+ Args:
+ env_ids: The environment ids to reset the noise model for. Defaults to None,
+ in which case all environments are considered.
+ """
+ pass
+
+
[文档]defapply(self,data:torch.Tensor)->torch.Tensor:
+"""Apply the noise to the data.
+
+ Args:
+ data: The data to apply the noise to. Shape is (num_envs, ...).
+
+ Returns:
+ The data with the noise applied. Shape is the same as the input data.
+ """
+ returnself._noise_model_cfg.noise_cfg.func(data,self._noise_model_cfg.noise_cfg)
+
+
+
[文档]classNoiseModelWithAdditiveBias(NoiseModel):
+"""Noise model with an additive bias.
+
+ The bias term is sampled from a the specified distribution on reset.
+ """
+
+ def__init__(self,noise_model_cfg:noise_cfg.NoiseModelWithAdditiveBiasCfg,num_envs:int,device:str):
+ # initialize parent class
+ super().__init__(noise_model_cfg,num_envs,device)
+ # store the bias noise configuration
+ self._bias_noise_cfg=noise_model_cfg.bias_noise_cfg
+ self._bias=torch.zeros((num_envs,1),device=self._device)
+
+
[文档]defreset(self,env_ids:Sequence[int]|None=None):
+"""Reset the noise model.
+
+ This method resets the bias term for the specified environments.
+
+ Args:
+ env_ids: The environment ids to reset the noise model for. Defaults to None,
+ in which case all environments are considered.
+ """
+ # resolve the environment ids
+ ifenv_idsisNone:
+ env_ids=slice(None)
+ # reset the bias term
+ self._bias[env_ids]=self._bias_noise_cfg.func(self._bias[env_ids],self._bias_noise_cfg)
+
+
[文档]defapply(self,data:torch.Tensor)->torch.Tensor:
+"""Apply bias noise to the data.
+
+ Args:
+ data: The data to apply the noise to. Shape is (num_envs, ...).
+
+ Returns:
+ The data with the noise applied. Shape is the same as the input data.
+ """
+ returnsuper().apply(data)+self._bias
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-module containing utilities for transforming strings and regular expressions."""
+
+importast
+importimportlib
+importinspect
+importre
+fromcollections.abcimportCallable,Sequence
+fromtypingimportAny
+
+"""
+String formatting.
+"""
+
+
+
[文档]defto_camel_case(snake_str:str,to:str="cC")->str:
+"""Converts a string from snake case to camel case.
+
+ Args:
+ snake_str: A string in snake case (i.e. with '_')
+ to: Convention to convert string to. Defaults to "cC".
+
+ Raises:
+ ValueError: Invalid input argument `to`, i.e. not "cC" or "CC".
+
+ Returns:
+ A string in camel-case format.
+ """
+ # check input is correct
+ iftonotin["cC","CC"]:
+ msg="to_camel_case(): Choose a valid `to` argument (CC or cC)"
+ raiseValueError(msg)
+ # convert string to lower case and split
+ components=snake_str.lower().split("_")
+ ifto=="cC":
+ # We capitalize the first letter of each component except the first one
+ # with the 'title' method and join them together.
+ returncomponents[0]+"".join(x.title()forxincomponents[1:])
+ else:
+ # Capitalize first letter in all the components
+ return"".join(x.title()forxincomponents)
+
+
+
[文档]defto_snake_case(camel_str:str)->str:
+"""Converts a string from camel case to snake case.
+
+ Args:
+ camel_str: A string in camel case.
+
+ Returns:
+ A string in snake case (i.e. with '_')
+ """
+ camel_str=re.sub("(.)([A-Z][a-z]+)",r"\1_\2",camel_str)
+ returnre.sub("([a-z0-9])([A-Z])",r"\1_\2",camel_str).lower()
+
+
+
[文档]defstring_to_slice(s:str):
+"""Convert a string representation of a slice to a slice object.
+
+ Args:
+ s: The string representation of the slice.
+
+ Returns:
+ The slice object.
+ """
+ # extract the content inside the slice()
+ match=re.match(r"slice\((.*),(.*),(.*)\)",s)
+ ifnotmatch:
+ raiseValueError(f"Invalid slice string format: {s}")
+
+ # extract start, stop, and step values
+ start_str,stop_str,step_str=match.groups()
+
+ # convert 'None' to None and other strings to integers
+ start=Noneifstart_str=="None"elseint(start_str)
+ stop=Noneifstop_str=="None"elseint(stop_str)
+ step=Noneifstep_str=="None"elseint(step_str)
+
+ # create and return the slice object
+ returnslice(start,stop,step)
[文档]defis_lambda_expression(name:str)->bool:
+"""Checks if the input string is a lambda expression.
+
+ Args:
+ name: The input string.
+
+ Returns:
+ Whether the input string is a lambda expression.
+ """
+ try:
+ ast.parse(name)
+ returnisinstance(ast.parse(name).body[0],ast.Expr)andisinstance(ast.parse(name).body[0].value,ast.Lambda)
+ exceptSyntaxError:
+ returnFalse
+
+
+
[文档]defcallable_to_string(value:Callable)->str:
+"""Converts a callable object to a string.
+
+ Args:
+ value: A callable object.
+
+ Raises:
+ ValueError: When the input argument is not a callable object.
+
+ Returns:
+ A string representation of the callable object.
+ """
+ # check if callable
+ ifnotcallable(value):
+ raiseValueError(f"The input argument is not callable: {value}.")
+ # check if lambda function
+ ifvalue.__name__=="<lambda>":
+ # we resolve the lambda expression by checking the source code and extracting the line with lambda expression
+ # we also remove any comments from the line
+ lambda_line=inspect.getsourcelines(value)[0][0].strip().split("lambda")[1].strip().split(",")[0]
+ lambda_line=re.sub(r"#.*$","",lambda_line).rstrip()
+ returnf"lambda {lambda_line}"
+ else:
+ # get the module and function name
+ module_name=value.__module__
+ function_name=value.__name__
+ # return the string
+ returnf"{module_name}:{function_name}"
+
+
+
[文档]defstring_to_callable(name:str)->Callable:
+"""Resolves the module and function names to return the function.
+
+ Args:
+ name: The function name. The format should be 'module:attribute_name' or a
+ lambda expression of format: 'lambda x: x'.
+
+ Raises:
+ ValueError: When the resolved attribute is not a function.
+ ValueError: When the module cannot be found.
+
+ Returns:
+ Callable: The function loaded from the module.
+ """
+ try:
+ ifis_lambda_expression(name):
+ callable_object=eval(name)
+ else:
+ mod_name,attr_name=name.split(":")
+ mod=importlib.import_module(mod_name)
+ callable_object=getattr(mod,attr_name)
+ # check if attribute is callable
+ ifcallable(callable_object):
+ returncallable_object
+ else:
+ raiseAttributeError(f"The imported object is not callable: '{name}'")
+ except(ValueError,ModuleNotFoundError)ase:
+ msg=(
+ f"Could not resolve the input string '{name}' into callable object."
+ " The format of input should be 'module:attribute_name'.\n"
+ f"Received the error:\n{e}."
+ )
+ raiseValueError(msg)
+
+
+"""
+Regex operations.
+"""
+
+
+
[文档]defresolve_matching_names(
+ keys:str|Sequence[str],list_of_strings:Sequence[str],preserve_order:bool=False
+)->tuple[list[int],list[str]]:
+"""Match a list of query regular expressions against a list of strings and return the matched indices and names.
+
+ When a list of query regular expressions is provided, the function checks each target string against each
+ query regular expression and returns the indices of the matched strings and the matched strings.
+
+ If the :attr:`preserve_order` is True, the ordering of the matched indices and names is the same as the order
+ of the provided list of strings. This means that the ordering is dictated by the order of the target strings
+ and not the order of the query regular expressions.
+
+ If the :attr:`preserve_order` is False, the ordering of the matched indices and names is the same as the order
+ of the provided list of query regular expressions.
+
+ For example, consider the list of strings is ['a', 'b', 'c', 'd', 'e'] and the regular expressions are ['a|c', 'b'].
+ If :attr:`preserve_order` is False, then the function will return the indices of the matched strings and the
+ strings as: ([0, 1, 2], ['a', 'b', 'c']). When :attr:`preserve_order` is True, it will return them as:
+ ([0, 2, 1], ['a', 'c', 'b']).
+
+ Note:
+ The function does not sort the indices. It returns the indices in the order they are found.
+
+ Args:
+ keys: A regular expression or a list of regular expressions to match the strings in the list.
+ list_of_strings: A list of strings to match.
+ preserve_order: Whether to preserve the order of the query keys in the returned values. Defaults to False.
+
+ Returns:
+ A tuple of lists containing the matched indices and names.
+
+ Raises:
+ ValueError: When multiple matches are found for a string in the list.
+ ValueError: When not all regular expressions are matched.
+ """
+ # resolve name keys
+ ifisinstance(keys,str):
+ keys=[keys]
+ # find matching patterns
+ index_list=[]
+ names_list=[]
+ key_idx_list=[]
+ # book-keeping to check that we always have a one-to-one mapping
+ # i.e. each target string should match only one regular expression
+ target_strings_match_found=[Nonefor_inrange(len(list_of_strings))]
+ keys_match_found=[[]for_inrange(len(keys))]
+ # loop over all target strings
+ fortarget_index,potential_match_stringinenumerate(list_of_strings):
+ forkey_index,re_keyinenumerate(keys):
+ ifre.fullmatch(re_key,potential_match_string):
+ # check if match already found
+ iftarget_strings_match_found[target_index]:
+ raiseValueError(
+ f"Multiple matches for '{potential_match_string}':"
+ f" '{target_strings_match_found[target_index]}' and '{re_key}'!"
+ )
+ # add to list
+ target_strings_match_found[target_index]=re_key
+ index_list.append(target_index)
+ names_list.append(potential_match_string)
+ key_idx_list.append(key_index)
+ # add for regex key
+ keys_match_found[key_index].append(potential_match_string)
+ # reorder keys if they should be returned in order of the query keys
+ ifpreserve_order:
+ reordered_index_list=[None]*len(index_list)
+ global_index=0
+ forkey_indexinrange(len(keys)):
+ forkey_idx_position,key_idx_entryinenumerate(key_idx_list):
+ ifkey_idx_entry==key_index:
+ reordered_index_list[key_idx_position]=global_index
+ global_index+=1
+ # reorder index and names list
+ index_list_reorder=[None]*len(index_list)
+ names_list_reorder=[None]*len(index_list)
+ foridx,reorder_idxinenumerate(reordered_index_list):
+ index_list_reorder[reorder_idx]=index_list[idx]
+ names_list_reorder[reorder_idx]=names_list[idx]
+ # update
+ index_list=index_list_reorder
+ names_list=names_list_reorder
+ # check that all regular expressions are matched
+ ifnotall(keys_match_found):
+ # make this print nicely aligned for debugging
+ msg="\n"
+ forkey,valueinzip(keys,keys_match_found):
+ msg+=f"\t{key}: {value}\n"
+ msg+=f"Available strings: {list_of_strings}\n"
+ # raise error
+ raiseValueError(
+ f"Not all regular expressions are matched! Please check that the regular expressions are correct: {msg}"
+ )
+ # return
+ returnindex_list,names_list
+
+
+
[文档]defresolve_matching_names_values(
+ data:dict[str,Any],list_of_strings:Sequence[str],preserve_order:bool=False
+)->tuple[list[int],list[str],list[Any]]:
+"""Match a list of regular expressions in a dictionary against a list of strings and return
+ the matched indices, names, and values.
+
+ If the :attr:`preserve_order` is True, the ordering of the matched indices and names is the same as the order
+ of the provided list of strings. This means that the ordering is dictated by the order of the target strings
+ and not the order of the query regular expressions.
+
+ If the :attr:`preserve_order` is False, the ordering of the matched indices and names is the same as the order
+ of the provided list of query regular expressions.
+
+ For example, consider the dictionary is {"a|d|e": 1, "b|c": 2}, the list of strings is ['a', 'b', 'c', 'd', 'e'].
+ If :attr:`preserve_order` is False, then the function will return the indices of the matched strings, the
+ matched strings, and the values as: ([0, 1, 2, 3, 4], ['a', 'b', 'c', 'd', 'e'], [1, 2, 2, 1, 1]). When
+ :attr:`preserve_order` is True, it will return them as: ([0, 3, 4, 1, 2], ['a', 'd', 'e', 'b', 'c'], [1, 1, 1, 2, 2]).
+
+ Args:
+ data: A dictionary of regular expressions and values to match the strings in the list.
+ list_of_strings: A list of strings to match.
+ preserve_order: Whether to preserve the order of the query keys in the returned values. Defaults to False.
+
+ Returns:
+ A tuple of lists containing the matched indices, names, and values.
+
+ Raises:
+ TypeError: When the input argument :attr:`data` is not a dictionary.
+ ValueError: When multiple matches are found for a string in the dictionary.
+ ValueError: When not all regular expressions in the data keys are matched.
+ """
+ # check valid input
+ ifnotisinstance(data,dict):
+ raiseTypeError(f"Input argument `data` should be a dictionary. Received: {data}")
+ # find matching patterns
+ index_list=[]
+ names_list=[]
+ values_list=[]
+ key_idx_list=[]
+ # book-keeping to check that we always have a one-to-one mapping
+ # i.e. each target string should match only one regular expression
+ target_strings_match_found=[Nonefor_inrange(len(list_of_strings))]
+ keys_match_found=[[]for_inrange(len(data))]
+ # loop over all target strings
+ fortarget_index,potential_match_stringinenumerate(list_of_strings):
+ forkey_index,(re_key,value)inenumerate(data.items()):
+ ifre.fullmatch(re_key,potential_match_string):
+ # check if match already found
+ iftarget_strings_match_found[target_index]:
+ raiseValueError(
+ f"Multiple matches for '{potential_match_string}':"
+ f" '{target_strings_match_found[target_index]}' and '{re_key}'!"
+ )
+ # add to list
+ target_strings_match_found[target_index]=re_key
+ index_list.append(target_index)
+ names_list.append(potential_match_string)
+ values_list.append(value)
+ key_idx_list.append(key_index)
+ # add for regex key
+ keys_match_found[key_index].append(potential_match_string)
+ # reorder keys if they should be returned in order of the query keys
+ ifpreserve_order:
+ reordered_index_list=[None]*len(index_list)
+ global_index=0
+ forkey_indexinrange(len(data)):
+ forkey_idx_position,key_idx_entryinenumerate(key_idx_list):
+ ifkey_idx_entry==key_index:
+ reordered_index_list[key_idx_position]=global_index
+ global_index+=1
+ # reorder index and names list
+ index_list_reorder=[None]*len(index_list)
+ names_list_reorder=[None]*len(index_list)
+ values_list_reorder=[None]*len(index_list)
+ foridx,reorder_idxinenumerate(reordered_index_list):
+ index_list_reorder[reorder_idx]=index_list[idx]
+ names_list_reorder[reorder_idx]=names_list[idx]
+ values_list_reorder[reorder_idx]=values_list[idx]
+ # update
+ index_list=index_list_reorder
+ names_list=names_list_reorder
+ values_list=values_list_reorder
+ # check that all regular expressions are matched
+ ifnotall(keys_match_found):
+ # make this print nicely aligned for debugging
+ msg="\n"
+ forkey,valueinzip(data.keys(),keys_match_found):
+ msg+=f"\t{key}: {value}\n"
+ msg+=f"Available strings: {list_of_strings}\n"
+ # raise error
+ raiseValueError(
+ f"Not all regular expressions are matched! Please check that the regular expressions are correct: {msg}"
+ )
+ # return
+ returnindex_list,names_list,values_list
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-module for a timer class that can be used for performance measurements."""
+
+from__future__importannotations
+
+importtime
+fromcontextlibimportContextDecorator
+fromtypingimportAny,ClassVar
+
+
+
[文档]classTimerError(Exception):
+"""A custom exception used to report errors in use of :class:`Timer` class."""
+
+ pass
+
+
+
[文档]classTimer(ContextDecorator):
+"""A timer for performance measurements.
+
+ A class to keep track of time for performance measurement.
+ It allows timing via context managers and decorators as well.
+
+ It uses the `time.perf_counter` function to measure time. This function
+ returns the number of seconds since the epoch as a float. It has the
+ highest resolution available on the system.
+
+ As a regular object:
+
+ .. code-block:: python
+
+ import time
+
+ from omni.isaac.lab.utils.timer import Timer
+
+ timer = Timer()
+ timer.start()
+ time.sleep(1)
+ print(1 <= timer.time_elapsed <= 2) # Output: True
+
+ time.sleep(1)
+ timer.stop()
+ print(2 <= stopwatch.total_run_time) # Output: True
+
+ As a context manager:
+
+ .. code-block:: python
+
+ import time
+
+ from omni.isaac.lab.utils.timer import Timer
+
+ with Timer() as timer:
+ time.sleep(1)
+ print(1 <= timer.time_elapsed <= 2) # Output: True
+
+ Reference: https://gist.github.com/sumeet/1123871
+ """
+
+ timing_info:ClassVar[dict[str,float]]=dict()
+"""Dictionary for storing the elapsed time per timer instances globally.
+
+ This dictionary logs the timer information. The keys are the names given to the timer class
+ at its initialization. If no :attr:`name` is passed to the constructor, no time
+ is recorded in the dictionary.
+ """
+
+
[文档]def__init__(self,msg:str|None=None,name:str|None=None):
+"""Initializes the timer.
+
+ Args:
+ msg: The message to display when using the timer
+ class in a context manager. Defaults to None.
+ name: The name to use for logging times in a global
+ dictionary. Defaults to None.
+ """
+ self._msg=msg
+ self._name=name
+ self._start_time=None
+ self._stop_time=None
+ self._elapsed_time=None
+
+ def__str__(self)->str:
+"""A string representation of the class object.
+
+ Returns:
+ A string containing the elapsed time.
+ """
+ returnf"{self.time_elapsed:0.6f} seconds"
+
+"""
+ Properties
+ """
+
+ @property
+ deftime_elapsed(self)->float:
+"""The number of seconds that have elapsed since this timer started timing.
+
+ Note:
+ This is used for checking how much time has elapsed while the timer is still running.
+ """
+ returntime.perf_counter()-self._start_time
+
+ @property
+ deftotal_run_time(self)->float:
+"""The number of seconds that elapsed from when the timer started to when it ended."""
+ returnself._elapsed_time
+
+"""
+ Operations
+ """
+
+
[文档]defstart(self):
+"""Start timing."""
+ ifself._start_timeisnotNone:
+ raiseTimerError("Timer is running. Use .stop() to stop it")
+
+ self._start_time=time.perf_counter()
+
+
[文档]defstop(self):
+"""Stop timing."""
+ ifself._start_timeisNone:
+ raiseTimerError("Timer is not running. Use .start() to start it")
+
+ self._stop_time=time.perf_counter()
+ self._elapsed_time=self._stop_time-self._start_time
+ self._start_time=None
+
+ ifself._name:
+ Timer.timing_info[self._name]=self._elapsed_time
[文档]@staticmethod
+ defget_timer_info(name:str)->float:
+"""Retrieves the time logged in the global dictionary
+ based on name.
+
+ Args:
+ name: Name of the the entry to be retrieved.
+
+ Raises:
+ TimerError: If name doesn't exist in the log.
+
+ Returns:
+ A float containing the time logged if the name exists.
+ """
+ ifnamenotinTimer.timing_info:
+ raiseTimerError(f"Timer {name} does not exist")
+ returnTimer.timing_info.get(name)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Wrapping around warp kernels for compatibility with torch tensors."""
+
+# needed to import for allowing type-hinting: torch.Tensor | None
+from__future__importannotations
+
+importnumpyasnp
+importtorch
+
+importwarpaswp
+
+wp.init()
+
+from.importkernels
+
+
+
[文档]defraycast_mesh(
+ ray_starts:torch.Tensor,
+ ray_directions:torch.Tensor,
+ mesh:wp.Mesh,
+ max_dist:float=1e6,
+ return_distance:bool=False,
+ return_normal:bool=False,
+ return_face_id:bool=False,
+)->tuple[torch.Tensor,torch.Tensor|None,torch.Tensor|None,torch.Tensor|None]:
+"""Performs ray-casting against a mesh.
+
+ Note that the `ray_starts` and `ray_directions`, and `ray_hits` should have compatible shapes
+ and data types to ensure proper execution. Additionally, they all must be in the same frame.
+
+ Args:
+ ray_starts: The starting position of the rays. Shape (N, 3).
+ ray_directions: The ray directions for each ray. Shape (N, 3).
+ mesh: The warp mesh to ray-cast against.
+ max_dist: The maximum distance to ray-cast. Defaults to 1e6.
+ return_distance: Whether to return the distance of the ray until it hits the mesh. Defaults to False.
+ return_normal: Whether to return the normal of the mesh face the ray hits. Defaults to False.
+ return_face_id: Whether to return the face id of the mesh face the ray hits. Defaults to False.
+
+ Returns:
+ The ray hit position. Shape (N, 3).
+ The returned tensor contains :obj:`float('inf')` for missed hits.
+ The ray hit distance. Shape (N,).
+ Will only return if :attr:`return_distance` is True, else returns None.
+ The returned tensor contains :obj:`float('inf')` for missed hits.
+ The ray hit normal. Shape (N, 3).
+ Will only return if :attr:`return_normal` is True else returns None.
+ The returned tensor contains :obj:`float('inf')` for missed hits.
+ The ray hit face id. Shape (N,).
+ Will only return if :attr:`return_face_id` is True else returns None.
+ The returned tensor contains :obj:`int(-1)` for missed hits.
+ """
+ # extract device and shape information
+ shape=ray_starts.shape
+ device=ray_starts.device
+ # device of the mesh
+ torch_device=wp.device_to_torch(mesh.device)
+ # reshape the tensors
+ ray_starts=ray_starts.to(torch_device).view(-1,3).contiguous()
+ ray_directions=ray_directions.to(torch_device).view(-1,3).contiguous()
+ num_rays=ray_starts.shape[0]
+ # create output tensor for the ray hits
+ ray_hits=torch.full((num_rays,3),float("inf"),device=torch_device).contiguous()
+
+ # map the memory to warp arrays
+ ray_starts_wp=wp.from_torch(ray_starts,dtype=wp.vec3)
+ ray_directions_wp=wp.from_torch(ray_directions,dtype=wp.vec3)
+ ray_hits_wp=wp.from_torch(ray_hits,dtype=wp.vec3)
+
+ ifreturn_distance:
+ ray_distance=torch.full((num_rays,),float("inf"),device=torch_device).contiguous()
+ ray_distance_wp=wp.from_torch(ray_distance,dtype=wp.float32)
+ else:
+ ray_distance=None
+ ray_distance_wp=wp.empty((1,),dtype=wp.float32,device=torch_device)
+
+ ifreturn_normal:
+ ray_normal=torch.full((num_rays,3),float("inf"),device=torch_device).contiguous()
+ ray_normal_wp=wp.from_torch(ray_normal,dtype=wp.vec3)
+ else:
+ ray_normal=None
+ ray_normal_wp=wp.empty((1,),dtype=wp.vec3,device=torch_device)
+
+ ifreturn_face_id:
+ ray_face_id=torch.ones((num_rays,),dtype=torch.int32,device=torch_device).contiguous()*(-1)
+ ray_face_id_wp=wp.from_torch(ray_face_id,dtype=wp.int32)
+ else:
+ ray_face_id=None
+ ray_face_id_wp=wp.empty((1,),dtype=wp.int32,device=torch_device)
+
+ # launch the warp kernel
+ wp.launch(
+ kernel=kernels.raycast_mesh_kernel,
+ dim=num_rays,
+ inputs=[
+ mesh.id,
+ ray_starts_wp,
+ ray_directions_wp,
+ ray_hits_wp,
+ ray_distance_wp,
+ ray_normal_wp,
+ ray_face_id_wp,
+ float(max_dist),
+ int(return_distance),
+ int(return_normal),
+ int(return_face_id),
+ ],
+ device=mesh.device,
+ )
+ # NOTE: Synchronize is not needed anymore, but we keep it for now. Check with @dhoeller.
+ wp.synchronize()
+
+ ifreturn_distance:
+ ray_distance=ray_distance.to(device).view(shape[0],shape[1])
+ ifreturn_normal:
+ ray_normal=ray_normal.to(device).view(shape)
+ ifreturn_face_id:
+ ray_face_id=ray_face_id.to(device).view(shape[0],shape[1])
+
+ returnray_hits.to(device).view(shape),ray_distance,ray_normal,ray_face_id
+
+
+
[文档]defconvert_to_warp_mesh(points:np.ndarray,indices:np.ndarray,device:str)->wp.Mesh:
+"""Create a warp mesh object with a mesh defined from vertices and triangles.
+
+ Args:
+ points: The vertices of the mesh. Shape is (N, 3), where N is the number of vertices.
+ indices: The triangles of the mesh as references to vertices for each triangle.
+ Shape is (M, 3), where M is the number of triangles / faces.
+ device: The device to use for the mesh.
+
+ Returns:
+ The warp mesh object.
+ """
+ returnwp.Mesh(
+ points=wp.array(points.astype(np.float32),dtype=wp.vec3,device=device),
+ indices=wp.array(indices.astype(np.int32).flatten(),dtype=wp.int32,device=device),
+ )
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Interface to collect and store data from the environment using format from `robomimic`."""
+
+# needed to import for allowing type-hinting: np.ndarray | torch.Tensor
+from__future__importannotations
+
+importh5py
+importjson
+importnumpyasnp
+importos
+importtorch
+fromcollections.abcimportIterable
+
+importcarb
+
+
+
[文档]classRobomimicDataCollector:
+"""Data collection interface for robomimic.
+
+ This class implements a data collector interface for saving simulation states to disk.
+ The data is stored in `HDF5`_ binary data format. The class is useful for collecting
+ demonstrations. The collected data follows the `structure`_ from robomimic.
+
+ All datasets in `robomimic` require the observations and next observations obtained
+ from before and after the environment step. These are stored as a dictionary of
+ observations in the keys "obs" and "next_obs" respectively.
+
+ For certain agents in `robomimic`, the episode data should have the following
+ additional keys: "actions", "rewards", "dones". This behavior can be altered by changing
+ the dataset keys required in the training configuration for the respective learning agent.
+
+ For reference on datasets, please check the robomimic `documentation`.
+
+ .. _HDF5: https://www.h5py.org/
+ .. _structure: https://robomimic.github.io/docs/datasets/overview.html#dataset-structure
+ .. _documentation: https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/config/base_config.py#L167-L173
+ """
+
+
[文档]def__init__(
+ self,
+ env_name:str,
+ directory_path:str,
+ filename:str="test",
+ num_demos:int=1,
+ flush_freq:int=1,
+ env_config:dict|None=None,
+ ):
+"""Initializes the data collection wrapper.
+
+ Args:
+ env_name: The name of the environment.
+ directory_path: The path to store collected data.
+ filename: The basename of the saved file. Defaults to "test".
+ num_demos: Number of demonstrations to record until stopping. Defaults to 1.
+ flush_freq: Frequency to dump data to disk. Defaults to 1.
+ env_config: The configuration for the environment. Defaults to None.
+ """
+ # save input arguments
+ self._env_name=env_name
+ self._env_config=env_config
+ self._directory=os.path.abspath(directory_path)
+ self._filename=filename
+ self._num_demos=num_demos
+ self._flush_freq=flush_freq
+ # print info
+ print(self.__str__())
+
+ # create directory it doesn't exist
+ ifnotos.path.isdir(self._directory):
+ os.makedirs(self._directory)
+
+ # placeholder for current hdf5 file object
+ self._h5_file_stream=None
+ self._h5_data_group=None
+ self._h5_episode_group=None
+
+ # store count of demos within episode
+ self._demo_count=0
+ # flags for setting up
+ self._is_first_interaction=True
+ self._is_stop=False
+ # create buffers to store data
+ self._dataset=dict()
+
+ def__del__(self):
+"""Destructor for data collector."""
+ ifnotself._is_stop:
+ self.close()
+
+ def__str__(self)->str:
+"""Represents the data collector as a string."""
+ msg="Dataset collector <class RobomimicDataCollector> object"
+ msg+=f"\tStoring trajectories in directory: {self._directory}\n"
+ msg+=f"\tNumber of demos for collection : {self._num_demos}\n"
+ msg+=f"\tFrequency for saving data to disk: {self._flush_freq}\n"
+
+ returnmsg
+
+"""
+ Properties
+ """
+
+ @property
+ defdemo_count(self)->int:
+"""The number of demos collected so far."""
+ returnself._demo_count
+
+"""
+ Operations.
+ """
+
+
[文档]defis_stopped(self)->bool:
+"""Whether data collection is stopped or not.
+
+ Returns:
+ True if data collection has stopped.
+ """
+ returnself._is_stop
+
+
[文档]defreset(self):
+"""Reset the internals of data logger."""
+ # setup the file to store data in
+ ifself._is_first_interaction:
+ self._demo_count=0
+ self._create_new_file(self._filename)
+ self._is_first_interaction=False
+ # clear out existing buffers
+ self._dataset=dict()
+
+
[文档]defadd(self,key:str,value:np.ndarray|torch.Tensor):
+"""Add a key-value pair to the dataset.
+
+ The key can be nested by using the "/" character. For example:
+ "obs/joint_pos". Currently only two-level nesting is supported.
+
+ Args:
+ key: The key name.
+ value: The corresponding value
+ of shape (N, ...), where `N` is number of environments.
+
+ Raises:
+ ValueError: When provided key has sub-keys more than 2. Example: "obs/joints/pos", instead
+ of "obs/joint_pos".
+ """
+ # check if data should be recorded
+ ifself._is_first_interaction:
+ carb.log_warn("Please call reset before adding new data. Calling reset...")
+ self.reset()
+ ifself._is_stop:
+ carb.log_warn(f"Desired number of demonstrations collected: {self._demo_count} >= {self._num_demos}.")
+ return
+ # check datatype
+ ifisinstance(value,torch.Tensor):
+ value=value.cpu().numpy()
+ else:
+ value=np.asarray(value)
+ # check if there are sub-keys
+ sub_keys=key.split("/")
+ num_sub_keys=len(sub_keys)
+ iflen(sub_keys)>2:
+ raiseValueError(f"Input key '{key}' has elements {num_sub_keys} which is more than two.")
+ # add key to dictionary if it doesn't exist
+ foriinrange(value.shape[0]):
+ # demo index
+ iff"env_{i}"notinself._dataset:
+ self._dataset[f"env_{i}"]=dict()
+ # key index
+ ifnum_sub_keys==2:
+ # create keys
+ ifsub_keys[0]notinself._dataset[f"env_{i}"]:
+ self._dataset[f"env_{i}"][sub_keys[0]]=dict()
+ ifsub_keys[1]notinself._dataset[f"env_{i}"][sub_keys[0]]:
+ self._dataset[f"env_{i}"][sub_keys[0]][sub_keys[1]]=list()
+ # add data to key
+ self._dataset[f"env_{i}"][sub_keys[0]][sub_keys[1]].append(value[i])
+ else:
+ # create keys
+ ifsub_keys[0]notinself._dataset[f"env_{i}"]:
+ self._dataset[f"env_{i}"][sub_keys[0]]=list()
+ # add data to key
+ self._dataset[f"env_{i}"][sub_keys[0]].append(value[i])
+
+
[文档]defflush(self,env_ids:Iterable[int]=(0,)):
+"""Flush the episode data based on environment indices.
+
+ Args:
+ env_ids: Environment indices to write data for. Defaults to (0).
+ """
+ # check that data is being recorded
+ ifself._h5_file_streamisNoneorself._h5_data_groupisNone:
+ carb.log_error("No file stream has been opened. Please call reset before flushing data.")
+ return
+
+ # iterate over each environment and add their data
+ forindexinenv_ids:
+ # data corresponding to demo
+ env_dataset=self._dataset[f"env_{index}"]
+
+ # create episode group based on demo count
+ h5_episode_group=self._h5_data_group.create_group(f"demo_{self._demo_count}")
+ # store number of steps taken
+ h5_episode_group.attrs["num_samples"]=len(env_dataset["actions"])
+ # store other data from dictionary
+ forkey,valueinenv_dataset.items():
+ ifisinstance(value,dict):
+ # create group
+ key_group=h5_episode_group.create_group(key)
+ # add sub-keys values
+ forsub_key,sub_valueinvalue.items():
+ key_group.create_dataset(sub_key,data=np.array(sub_value))
+ else:
+ h5_episode_group.create_dataset(key,data=np.array(value))
+ # increment total step counts
+ self._h5_data_group.attrs["total"]+=h5_episode_group.attrs["num_samples"]
+
+ # increment total demo counts
+ self._demo_count+=1
+ # reset buffer for environment
+ self._dataset[f"env_{index}"]=dict()
+
+ # dump at desired frequency
+ ifself._demo_count%self._flush_freq==0:
+ self._h5_file_stream.flush()
+ print(f">>> Flushing data to disk. Collected demos: {self._demo_count} / {self._num_demos}")
+
+ # if demos collected then stop
+ ifself._demo_count>=self._num_demos:
+ print(f">>> Desired number of demonstrations collected: {self._demo_count} >= {self._num_demos}.")
+ self.close()
+ # break out of loop
+ break
+
+
[文档]defclose(self):
+"""Stop recording and save the file at its current state."""
+ ifnotself._is_stop:
+ print(f">>> Closing recording of data. Collected demos: {self._demo_count} / {self._num_demos}")
+ # close the file safely
+ ifself._h5_file_streamisnotNone:
+ self._h5_file_stream.close()
+ # mark that data collection is stopped
+ self._is_stop=True
+
+"""
+ Helper functions.
+ """
+
+ def_create_new_file(self,fname:str):
+"""Create a new HDF5 file for writing episode info into.
+
+ Reference:
+ https://robomimic.github.io/docs/datasets/overview.html
+
+ Args:
+ fname: The base name of the file.
+ """
+ ifnotfname.endswith(".hdf5"):
+ fname+=".hdf5"
+ # define path to file
+ hdf5_path=os.path.join(self._directory,fname)
+ # construct the stream object
+ self._h5_file_stream=h5py.File(hdf5_path,"w")
+ # create group to store data
+ self._h5_data_group=self._h5_file_stream.create_group("data")
+ # stores total number of samples accumulated across demonstrations
+ self._h5_data_group.attrs["total"]=0
+ # store the environment meta-info
+ # -- we use gym environment type
+ # Ref: https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/envs/env_base.py#L15
+ env_type=2
+ # -- check if env config provided
+ ifself._env_configisNone:
+ self._env_config=dict()
+ # -- add info
+ self._h5_data_group.attrs["env_args"]=json.dumps({
+ "env_name":self._env_name,
+ "type":env_type,
+ "env_kwargs":self._env_config,
+ })
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-module with utility for importing all modules in a package recursively."""
+
+from__future__importannotations
+
+importimportlib
+importpkgutil
+importsys
+
+
+
[文档]defimport_packages(package_name:str,blacklist_pkgs:list[str]|None=None):
+"""Import all sub-packages in a package recursively.
+
+ It is easier to use this function to import all sub-packages in a package recursively
+ than to manually import each sub-package.
+
+ It replaces the need of the following code snippet on the top of each package's ``__init__.py`` file:
+
+ .. code-block:: python
+
+ import .locomotion.velocity
+ import .manipulation.reach
+ import .manipulation.lift
+
+ Args:
+ package_name: The package name.
+ blacklist_pkgs: The list of blacklisted packages to skip. Defaults to None,
+ which means no packages are blacklisted.
+ """
+ # Default blacklist
+ ifblacklist_pkgsisNone:
+ blacklist_pkgs=[]
+ # Import the package itself
+ package=importlib.import_module(package_name)
+ # Import all Python files
+ for_in_walk_packages(package.__path__,package.__name__+".",blacklist_pkgs=blacklist_pkgs):
+ pass
+
+
+def_walk_packages(
+ path:str|None=None,
+ prefix:str="",
+ onerror:callable|None=None,
+ blacklist_pkgs:list[str]|None=None,
+):
+"""Yields ModuleInfo for all modules recursively on path, or, if path is None, all accessible modules.
+
+ Note:
+ This function is a modified version of the original ``pkgutil.walk_packages`` function. It adds
+ the `blacklist_pkgs` argument to skip blacklisted packages. Please refer to the original
+ ``pkgutil.walk_packages`` function for more details.
+ """
+ ifblacklist_pkgsisNone:
+ blacklist_pkgs=[]
+
+ defseen(p,m={}):
+ ifpinm:
+ returnTrue
+ m[p]=True# noqa: R503
+
+ forinfoinpkgutil.iter_modules(path,prefix):
+ # check blacklisted
+ ifany([black_pkg_nameininfo.nameforblack_pkg_nameinblacklist_pkgs]):
+ continue
+
+ # yield the module info
+ yieldinfo
+
+ ifinfo.ispkg:
+ try:
+ __import__(info.name)
+ exceptException:
+ ifonerrorisnotNone:
+ onerror(info.name)
+ else:
+ raise
+ else:
+ path=getattr(sys.modules[info.name],"__path__",None)or[]
+
+ # don't traverse path items we've seen before
+ path=[pforpinpathifnotseen(p)]
+
+ yield from_walk_packages(path,info.name+".",onerror,blacklist_pkgs)
+
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-module with utilities for parsing and loading configurations."""
+
+
+importgymnasiumasgym
+importimportlib
+importinspect
+importos
+importre
+importyaml
+
+fromomni.isaac.lab.envsimportDirectRLEnvCfg,ManagerBasedRLEnvCfg
+
+
+
[文档]defload_cfg_from_registry(task_name:str,entry_point_key:str)->dict|object:
+"""Load default configuration given its entry point from the gym registry.
+
+ This function loads the configuration object from the gym registry for the given task name.
+ It supports both YAML and Python configuration files.
+
+ It expects the configuration to be registered in the gym registry as:
+
+ .. code-block:: python
+
+ gym.register(
+ id="My-Awesome-Task-v0",
+ ...
+ kwargs={"env_entry_point_cfg": "path.to.config:ConfigClass"},
+ )
+
+ The parsed configuration object for above example can be obtained as:
+
+ .. code-block:: python
+
+ from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry
+
+ cfg = load_cfg_from_registry("My-Awesome-Task-v0", "env_entry_point_cfg")
+
+ Args:
+ task_name: The name of the environment.
+ entry_point_key: The entry point key to resolve the configuration file.
+
+ Returns:
+ The parsed configuration object. If the entry point is a YAML file, it is parsed into a dictionary.
+ If the entry point is a Python class, it is instantiated and returned.
+
+ Raises:
+ ValueError: If the entry point key is not available in the gym registry for the task.
+ """
+ # obtain the configuration entry point
+ cfg_entry_point=gym.spec(task_name).kwargs.get(entry_point_key)
+ # check if entry point exists
+ ifcfg_entry_pointisNone:
+ raiseValueError(
+ f"Could not find configuration for the environment: '{task_name}'."
+ f" Please check that the gym registry has the entry point: '{entry_point_key}'."
+ )
+ # parse the default config file
+ ifisinstance(cfg_entry_point,str)andcfg_entry_point.endswith(".yaml"):
+ ifos.path.exists(cfg_entry_point):
+ # absolute path for the config file
+ config_file=cfg_entry_point
+ else:
+ # resolve path to the module location
+ mod_name,file_name=cfg_entry_point.split(":")
+ mod_path=os.path.dirname(importlib.import_module(mod_name).__file__)
+ # obtain the configuration file path
+ config_file=os.path.join(mod_path,file_name)
+ # load the configuration
+ print(f"[INFO]: Parsing configuration from: {config_file}")
+ withopen(config_file,encoding="utf-8")asf:
+ cfg=yaml.full_load(f)
+ else:
+ ifcallable(cfg_entry_point):
+ # resolve path to the module location
+ mod_path=inspect.getfile(cfg_entry_point)
+ # load the configuration
+ cfg_cls=cfg_entry_point()
+ elifisinstance(cfg_entry_point,str):
+ # resolve path to the module location
+ mod_name,attr_name=cfg_entry_point.split(":")
+ mod=importlib.import_module(mod_name)
+ cfg_cls=getattr(mod,attr_name)
+ else:
+ cfg_cls=cfg_entry_point
+ # load the configuration
+ print(f"[INFO]: Parsing configuration from: {cfg_entry_point}")
+ ifcallable(cfg_cls):
+ cfg=cfg_cls()
+ else:
+ cfg=cfg_cls
+ returncfg
+
+
+
[文档]defparse_env_cfg(
+ task_name:str,device:str="cuda:0",num_envs:int|None=None,use_fabric:bool|None=None
+)->ManagerBasedRLEnvCfg|DirectRLEnvCfg:
+"""Parse configuration for an environment and override based on inputs.
+
+ Args:
+ task_name: The name of the environment.
+ device: The device to run the simulation on. Defaults to "cuda:0".
+ num_envs: Number of environments to create. Defaults to None, in which case it is left unchanged.
+ use_fabric: Whether to enable/disable fabric interface. If false, all read/write operations go through USD.
+ This slows down the simulation but allows seeing the changes in the USD through the USD stage.
+ Defaults to None, in which case it is left unchanged.
+
+ Returns:
+ The parsed configuration object.
+
+ Raises:
+ RuntimeError: If the configuration for the task is not a class. We assume users always use a class for the
+ environment configuration.
+ """
+ # load the default configuration
+ cfg=load_cfg_from_registry(task_name,"env_cfg_entry_point")
+
+ # check that it is not a dict
+ # we assume users always use a class for the configuration
+ ifisinstance(cfg,dict):
+ raiseRuntimeError(f"Configuration for the task: '{task_name}' is not a class. Please provide a class.")
+
+ # simulation device
+ cfg.sim.device=device
+ # disable fabric to read/write through USD
+ ifuse_fabricisnotNone:
+ cfg.sim.use_fabric=use_fabric
+ # number of environments
+ ifnum_envsisnotNone:
+ cfg.scene.num_envs=num_envs
+
+ returncfg
+
+
+
[文档]defget_checkpoint_path(
+ log_path:str,run_dir:str=".*",checkpoint:str=".*",other_dirs:list[str]=None,sort_alpha:bool=True
+)->str:
+"""Get path to the model checkpoint in input directory.
+
+ The checkpoint file is resolved as: ``<log_path>/<run_dir>/<*other_dirs>/<checkpoint>``, where the
+ :attr:`other_dirs` are intermediate folder names to concatenate. These cannot be regex expressions.
+
+ If :attr:`run_dir` and :attr:`checkpoint` are regex expressions then the most recent (highest alphabetical order)
+ run and checkpoint are selected. To disable this behavior, set the flag :attr:`sort_alpha` to False.
+
+ Args:
+ log_path: The log directory path to find models in.
+ run_dir: The regex expression for the name of the directory containing the run. Defaults to the most
+ recent directory created inside :attr:`log_path`.
+ other_dirs: The intermediate directories between the run directory and the checkpoint file. Defaults to
+ None, which implies that checkpoint file is directly under the run directory.
+ checkpoint: The regex expression for the model checkpoint file. Defaults to the most recent
+ torch-model saved in the :attr:`run_dir` directory.
+ sort_alpha: Whether to sort the runs by alphabetical order. Defaults to True.
+ If False, the folders in :attr:`run_dir` are sorted by the last modified time.
+
+ Returns:
+ The path to the model checkpoint.
+
+ Raises:
+ ValueError: When no runs are found in the input directory.
+ ValueError: When no checkpoints are found in the input directory.
+
+ """
+ # check if runs present in directory
+ try:
+ # find all runs in the directory that math the regex expression
+ runs=[
+ os.path.join(log_path,run)forruninos.scandir(log_path)ifrun.is_dir()andre.match(run_dir,run.name)
+ ]
+ # sort matched runs by alphabetical order (latest run should be last)
+ ifsort_alpha:
+ runs.sort()
+ else:
+ runs=sorted(runs,key=os.path.getmtime)
+ # create last run file path
+ ifother_dirsisnotNone:
+ run_path=os.path.join(runs[-1],*other_dirs)
+ else:
+ run_path=runs[-1]
+ exceptIndexError:
+ raiseValueError(f"No runs present in the directory: '{log_path}' match: '{run_dir}'.")
+
+ # list all model checkpoints in the directory
+ model_checkpoints=[fforfinos.listdir(run_path)ifre.match(checkpoint,f)]
+ # check if any checkpoints are present
+ iflen(model_checkpoints)==0:
+ raiseValueError(f"No checkpoints in the directory: '{run_path}' match '{checkpoint}'.")
+ # sort alphabetically while ensuring that *_10 comes after *_9
+ model_checkpoints.sort(key=lambdam:f"{m:0>15}")
+ # get latest matched checkpoint file
+ checkpoint_file=model_checkpoints[-1]
+
+ returnos.path.join(run_path,checkpoint_file)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Wrapper to configure a :class:`ManagerBasedRLEnv` or :class:`DirectRlEnv` instance to RL-Games vectorized environment.
+
+The following example shows how to wrap an environment for RL-Games and register the environment construction
+for RL-Games :class:`Runner` class:
+
+.. code-block:: python
+
+ from rl_games.common import env_configurations, vecenv
+
+ from omni.isaac.lab_tasks.utils.wrappers.rl_games import RlGamesGpuEnv, RlGamesVecEnvWrapper
+
+ # configuration parameters
+ rl_device = "cuda:0"
+ clip_obs = 10.0
+ clip_actions = 1.0
+
+ # wrap around environment for rl-games
+ env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
+
+ # register the environment to rl-games registry
+ # note: in agents configuration: environment name must be "rlgpu"
+ vecenv.register(
+ "IsaacRlgWrapper", lambda config_name, num_actors, **kwargs: RlGamesGpuEnv(config_name, num_actors, **kwargs)
+ )
+ env_configurations.register("rlgpu", {"vecenv_type": "IsaacRlgWrapper", "env_creator": lambda **kwargs: env})
+
+"""
+
+# needed to import for allowing type-hinting:gym.spaces.Box | None
+from__future__importannotations
+
+importgym.spaces# needed for rl-games incompatibility: https://github.com/Denys88/rl_games/issues/261
+importgymnasium
+importtorch
+
+fromrl_games.commonimportenv_configurations
+fromrl_games.common.vecenvimportIVecEnv
+
+fromomni.isaac.lab.envsimportDirectRLEnv,ManagerBasedRLEnv,VecEnvObs
+
+"""
+Vectorized environment wrapper.
+"""
+
+
+
[文档]classRlGamesVecEnvWrapper(IVecEnv):
+"""Wraps around Isaac Lab environment for RL-Games.
+
+ This class wraps around the Isaac Lab environment. Since RL-Games works directly on
+ GPU buffers, the wrapper handles moving of buffers from the simulation environment
+ to the same device as the learning agent. Additionally, it performs clipping of
+ observations and actions.
+
+ For algorithms like asymmetric actor-critic, RL-Games expects a dictionary for
+ observations. This dictionary contains "obs" and "states" which typically correspond
+ to the actor and critic observations respectively.
+
+ To use asymmetric actor-critic, the environment observations from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`
+ must have the key or group name "critic". The observation group is used to set the
+ :attr:`num_states` (int) and :attr:`state_space` (:obj:`gym.spaces.Box`). These are
+ used by the learning agent in RL-Games to allocate buffers in the trajectory memory.
+ Since this is optional for some environments, the wrapper checks if these attributes exist.
+ If they don't then the wrapper defaults to zero as number of privileged observations.
+
+ .. caution::
+
+ This class must be the last wrapper in the wrapper chain. This is because the wrapper does not follow
+ the :class:`gym.Wrapper` interface. Any subsequent wrappers will need to be modified to work with this
+ wrapper.
+
+
+ Reference:
+ https://github.com/Denys88/rl_games/blob/master/rl_games/common/ivecenv.py
+ https://github.com/NVIDIA-Omniverse/IsaacGymEnvs
+ """
+
+
[文档]def__init__(self,env:ManagerBasedRLEnv|DirectRLEnv,rl_device:str,clip_obs:float,clip_actions:float):
+"""Initializes the wrapper instance.
+
+ Args:
+ env: The environment to wrap around.
+ rl_device: The device on which agent computations are performed.
+ clip_obs: The clipping value for observations.
+ clip_actions: The clipping value for actions.
+
+ Raises:
+ ValueError: The environment is not inherited from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
+ ValueError: If specified, the privileged observations (critic) are not of type :obj:`gym.spaces.Box`.
+ """
+ # check that input is valid
+ ifnotisinstance(env.unwrapped,ManagerBasedRLEnv)andnotisinstance(env.unwrapped,DirectRLEnv):
+ raiseValueError(
+ "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:"
+ f" {type(env)}"
+ )
+ # initialize the wrapper
+ self.env=env
+ # store provided arguments
+ self._rl_device=rl_device
+ self._clip_obs=clip_obs
+ self._clip_actions=clip_actions
+ self._sim_device=env.unwrapped.device
+ # information for privileged observations
+ ifself.state_spaceisNone:
+ self.rlg_num_states=0
+ else:
+ self.rlg_num_states=self.state_space.shape[0]
+
+ def__str__(self):
+"""Returns the wrapper name and the :attr:`env` representation string."""
+ return(
+ f"<{type(self).__name__}{self.env}>"
+ f"\n\tObservations clipping: {self._clip_obs}"
+ f"\n\tActions clipping : {self._clip_actions}"
+ f"\n\tAgent device : {self._rl_device}"
+ f"\n\tAsymmetric-learning : {self.rlg_num_states!=0}"
+ )
+
+ def__repr__(self):
+"""Returns the string representation of the wrapper."""
+ returnstr(self)
+
+"""
+ Properties -- Gym.Wrapper
+ """
+
+ @property
+ defrender_mode(self)->str|None:
+"""Returns the :attr:`Env` :attr:`render_mode`."""
+ returnself.env.render_mode
+
+ @property
+ defobservation_space(self)->gym.spaces.Box:
+"""Returns the :attr:`Env` :attr:`observation_space`."""
+ # note: rl-games only wants single observation space
+ policy_obs_space=self.unwrapped.single_observation_space["policy"]
+ ifnotisinstance(policy_obs_space,gymnasium.spaces.Box):
+ raiseNotImplementedError(
+ f"The RL-Games wrapper does not currently support observation space: '{type(policy_obs_space)}'."
+ f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
+ " and if you are nice, please send a merge-request."
+ )
+ # note: maybe should check if we are a sub-set of the actual space. don't do it right now since
+ # in ManagerBasedRLEnv we are setting action space as (-inf, inf).
+ returngym.spaces.Box(-self._clip_obs,self._clip_obs,policy_obs_space.shape)
+
+ @property
+ defaction_space(self)->gym.Space:
+"""Returns the :attr:`Env` :attr:`action_space`."""
+ # note: rl-games only wants single action space
+ action_space=self.unwrapped.single_action_space
+ ifnotisinstance(action_space,gymnasium.spaces.Box):
+ raiseNotImplementedError(
+ f"The RL-Games wrapper does not currently support action space: '{type(action_space)}'."
+ f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
+ " and if you are nice, please send a merge-request."
+ )
+ # return casted space in gym.spaces.Box (OpenAI Gym)
+ # note: maybe should check if we are a sub-set of the actual space. don't do it right now since
+ # in ManagerBasedRLEnv we are setting action space as (-inf, inf).
+ returngym.spaces.Box(-self._clip_actions,self._clip_actions,action_space.shape)
+
+
[文档]@classmethod
+ defclass_name(cls)->str:
+"""Returns the class name of the wrapper."""
+ returncls.__name__
+
+ @property
+ defunwrapped(self)->ManagerBasedRLEnv|DirectRLEnv:
+"""Returns the base environment of the wrapper.
+
+ This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
+ """
+ returnself.env.unwrapped
+
+"""
+ Properties
+ """
+
+ @property
+ defnum_envs(self)->int:
+"""Returns the number of sub-environment instances."""
+ returnself.unwrapped.num_envs
+
+ @property
+ defdevice(self)->str:
+"""Returns the base environment simulation device."""
+ returnself.unwrapped.device
+
+ @property
+ defstate_space(self)->gym.spaces.Box|None:
+"""Returns the :attr:`Env` :attr:`observation_space`."""
+ # note: rl-games only wants single observation space
+ critic_obs_space=self.unwrapped.single_observation_space.get("critic")
+ # check if we even have a critic obs
+ ifcritic_obs_spaceisNone:
+ returnNone
+ elifnotisinstance(critic_obs_space,gymnasium.spaces.Box):
+ raiseNotImplementedError(
+ f"The RL-Games wrapper does not currently support state space: '{type(critic_obs_space)}'."
+ f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
+ " and if you are nice, please send a merge-request."
+ )
+ # return casted space in gym.spaces.Box (OpenAI Gym)
+ # note: maybe should check if we are a sub-set of the actual space. don't do it right now since
+ # in ManagerBasedRLEnv we are setting action space as (-inf, inf).
+ returngym.spaces.Box(-self._clip_obs,self._clip_obs,critic_obs_space.shape)
+
+
[文档]defget_number_of_agents(self)->int:
+"""Returns number of actors in the environment."""
+ returngetattr(self,"num_agents",1)
+
+
[文档]defget_env_info(self)->dict:
+"""Returns the Gym spaces for the environment."""
+ return{
+ "observation_space":self.observation_space,
+ "action_space":self.action_space,
+ "state_space":self.state_space,
+ }
+
+"""
+ Operations - MDP
+ """
+
+ defseed(self,seed:int=-1)->int:# noqa: D102
+ returnself.unwrapped.seed(seed)
+
+ defreset(self):# noqa: D102
+ obs_dict,_=self.env.reset()
+ # process observations and states
+ returnself._process_obs(obs_dict)
+
+ defstep(self,actions):# noqa: D102
+ # move actions to sim-device
+ actions=actions.detach().clone().to(device=self._sim_device)
+ # clip the actions
+ actions=torch.clamp(actions,-self._clip_actions,self._clip_actions)
+ # perform environment step
+ obs_dict,rew,terminated,truncated,extras=self.env.step(actions)
+
+ # move time out information to the extras dict
+ # this is only needed for infinite horizon tasks
+ # note: only useful when `value_bootstrap` is True in the agent configuration
+ ifnotself.unwrapped.cfg.is_finite_horizon:
+ extras["time_outs"]=truncated.to(device=self._rl_device)
+ # process observations and states
+ obs_and_states=self._process_obs(obs_dict)
+ # move buffers to rl-device
+ # note: we perform clone to prevent issues when rl-device and sim-device are the same.
+ rew=rew.to(device=self._rl_device)
+ dones=(terminated|truncated).to(device=self._rl_device)
+ extras={
+ k:v.to(device=self._rl_device,non_blocking=True)ifhasattr(v,"to")elsevfork,vinextras.items()
+ }
+ # remap extras from "log" to "episode"
+ if"log"inextras:
+ extras["episode"]=extras.pop("log")
+
+ returnobs_and_states,rew,dones,extras
+
+ defclose(self):# noqa: D102
+ returnself.env.close()
+
+"""
+ Helper functions
+ """
+
+ def_process_obs(self,obs_dict:VecEnvObs)->torch.Tensor|dict[str,torch.Tensor]:
+"""Processing of the observations and states from the environment.
+
+ Note:
+ States typically refers to privileged observations for the critic function. It is typically used in
+ asymmetric actor-critic algorithms.
+
+ Args:
+ obs_dict: The current observations from environment.
+
+ Returns:
+ If environment provides states, then a dictionary containing the observations and states is returned.
+ Otherwise just the observations tensor is returned.
+ """
+ # process policy obs
+ obs=obs_dict["policy"]
+ # clip the observations
+ obs=torch.clamp(obs,-self._clip_obs,self._clip_obs)
+ # move the buffer to rl-device
+ obs=obs.to(device=self._rl_device).clone()
+
+ # check if asymmetric actor-critic or not
+ ifself.rlg_num_states>0:
+ # acquire states from the environment if it exists
+ try:
+ states=obs_dict["critic"]
+ exceptAttributeError:
+ raiseNotImplementedError("Environment does not define key 'critic' for privileged observations.")
+ # clip the states
+ states=torch.clamp(states,-self._clip_obs,self._clip_obs)
+ # move buffers to rl-device
+ states=states.to(self._rl_device).clone()
+ # convert to dictionary
+ return{"obs":obs,"states":states}
+ else:
+ returnobs
+
+
+"""
+Environment Handler.
+"""
+
+
+
[文档]classRlGamesGpuEnv(IVecEnv):
+"""Thin wrapper to create instance of the environment to fit RL-Games runner."""
+
+ # TODO: Adding this for now but do we really need this?
+
+
[文档]def__init__(self,config_name:str,num_actors:int,**kwargs):
+"""Initialize the environment.
+
+ Args:
+ config_name: The name of the environment configuration.
+ num_actors: The number of actors in the environment. This is not used in this wrapper.
+ """
+ self.env:RlGamesVecEnvWrapper=env_configurations.configurations[config_name]["env_creator"](**kwargs)
[文档]defget_number_of_agents(self)->int:
+"""Get number of agents in the environment.
+
+ Returns:
+ The number of agents in the environment.
+ """
+ returnself.env.get_number_of_agents()
+
+
[文档]defget_env_info(self)->dict:
+"""Get the Gym spaces for the environment.
+
+ Returns:
+ The Gym spaces for the environment.
+ """
+ returnself.env.get_env_info()
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+importcopy
+importos
+importtorch
+
+
+
[文档]defexport_policy_as_jit(actor_critic:object,normalizer:object|None,path:str,filename="policy.pt"):
+"""Export policy into a Torch JIT file.
+
+ Args:
+ actor_critic: The actor-critic torch module.
+ normalizer: The empirical normalizer module. If None, Identity is used.
+ path: The path to the saving directory.
+ filename: The name of exported JIT file. Defaults to "policy.pt".
+ """
+ policy_exporter=_TorchPolicyExporter(actor_critic,normalizer)
+ policy_exporter.export(path,filename)
+
+
+
[文档]defexport_policy_as_onnx(
+ actor_critic:object,path:str,normalizer:object|None=None,filename="policy.onnx",verbose=False
+):
+"""Export policy into a Torch ONNX file.
+
+ Args:
+ actor_critic: The actor-critic torch module.
+ normalizer: The empirical normalizer module. If None, Identity is used.
+ path: The path to the saving directory.
+ filename: The name of exported ONNX file. Defaults to "policy.onnx".
+ verbose: Whether to print the model summary. Defaults to False.
+ """
+ ifnotos.path.exists(path):
+ os.makedirs(path,exist_ok=True)
+ policy_exporter=_OnnxPolicyExporter(actor_critic,normalizer,verbose)
+ policy_exporter.export(path,filename)
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+fromdataclassesimportMISSING
+fromtypingimportLiteral
+
+fromomni.isaac.lab.utilsimportconfigclass
+
+
+
[文档]@configclass
+classRslRlPpoActorCriticCfg:
+"""Configuration for the PPO actor-critic networks."""
+
+ class_name:str="ActorCritic"
+"""The policy class name. Default is ActorCritic."""
+
+ init_noise_std:float=MISSING
+"""The initial noise standard deviation for the policy."""
+
+ actor_hidden_dims:list[int]=MISSING
+"""The hidden dimensions of the actor network."""
+
+ critic_hidden_dims:list[int]=MISSING
+"""The hidden dimensions of the critic network."""
+
+ activation:str=MISSING
+"""The activation function for the actor and critic networks."""
+
+
+
[文档]@configclass
+classRslRlPpoAlgorithmCfg:
+"""Configuration for the PPO algorithm."""
+
+ class_name:str="PPO"
+"""The algorithm class name. Default is PPO."""
+
+ value_loss_coef:float=MISSING
+"""The coefficient for the value loss."""
+
+ use_clipped_value_loss:bool=MISSING
+"""Whether to use clipped value loss."""
+
+ clip_param:float=MISSING
+"""The clipping parameter for the policy."""
+
+ entropy_coef:float=MISSING
+"""The coefficient for the entropy loss."""
+
+ num_learning_epochs:int=MISSING
+"""The number of learning epochs per update."""
+
+ num_mini_batches:int=MISSING
+"""The number of mini-batches per update."""
+
+ learning_rate:float=MISSING
+"""The learning rate for the policy."""
+
+ schedule:str=MISSING
+"""The learning rate schedule."""
+
+ gamma:float=MISSING
+"""The discount factor."""
+
+ lam:float=MISSING
+"""The lambda parameter for Generalized Advantage Estimation (GAE)."""
+
+ desired_kl:float=MISSING
+"""The desired KL divergence."""
+
+ max_grad_norm:float=MISSING
+"""The maximum gradient norm."""
+
+
+
[文档]@configclass
+classRslRlOnPolicyRunnerCfg:
+"""Configuration of the runner for on-policy algorithms."""
+
+ seed:int=42
+"""The seed for the experiment. Default is 42."""
+
+ device:str="cuda:0"
+"""The device for the rl-agent. Default is cuda:0."""
+
+ num_steps_per_env:int=MISSING
+"""The number of steps per environment per update."""
+
+ max_iterations:int=MISSING
+"""The maximum number of iterations."""
+
+ empirical_normalization:bool=MISSING
+"""Whether to use empirical normalization."""
+
+ policy:RslRlPpoActorCriticCfg=MISSING
+"""The policy configuration."""
+
+ algorithm:RslRlPpoAlgorithmCfg=MISSING
+"""The algorithm configuration."""
+
+ ##
+ # Checkpointing parameters
+ ##
+
+ save_interval:int=MISSING
+"""The number of iterations between saves."""
+
+ experiment_name:str=MISSING
+"""The experiment name."""
+
+ run_name:str=""
+"""The run name. Default is empty string.
+
+ The name of the run directory is typically the time-stamp at execution. If the run name is not empty,
+ then it is appended to the run directory's name, i.e. the logging directory's name will become
+ ``{time-stamp}_{run_name}``.
+ """
+
+ ##
+ # Logging parameters
+ ##
+
+ logger:Literal["tensorboard","neptune","wandb"]="tensorboard"
+"""The logger to use. Default is tensorboard."""
+
+ neptune_project:str="isaaclab"
+"""The neptune project name. Default is "isaaclab"."""
+
+ wandb_project:str="isaaclab"
+"""The wandb project name. Default is "isaaclab"."""
+
+ ##
+ # Loading parameters
+ ##
+
+ resume:bool=False
+"""Whether to resume. Default is False."""
+
+ load_run:str=".*"
+"""The run directory to load. Default is ".*" (all).
+
+ If regex expression, the latest (alphabetical order) matching run will be loaded.
+ """
+
+ load_checkpoint:str="model_.*.pt"
+"""The checkpoint file to load. Default is ``"model_.*.pt"`` (all).
+
+ If regex expression, the latest (alphabetical order) matching file will be loaded.
+ """
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Wrapper to configure a :class:`ManagerBasedRLEnv` or :class:`DirectRlEnv` instance to RSL-RL vectorized environment.
+
+The following example shows how to wrap an environment for RSL-RL:
+
+.. code-block:: python
+
+ from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import RslRlVecEnvWrapper
+
+ env = RslRlVecEnvWrapper(env)
+
+"""
+
+
+importgymnasiumasgym
+importtorch
+
+fromrsl_rl.envimportVecEnv
+
+fromomni.isaac.lab.envsimportDirectRLEnv,ManagerBasedRLEnv
+
+
+
[文档]classRslRlVecEnvWrapper(VecEnv):
+"""Wraps around Isaac Lab environment for RSL-RL library
+
+ To use asymmetric actor-critic, the environment instance must have the attributes :attr:`num_privileged_obs` (int).
+ This is used by the learning agent to allocate buffers in the trajectory memory. Additionally, the returned
+ observations should have the key "critic" which corresponds to the privileged observations. Since this is
+ optional for some environments, the wrapper checks if these attributes exist. If they don't then the wrapper
+ defaults to zero as number of privileged observations.
+
+ .. caution::
+
+ This class must be the last wrapper in the wrapper chain. This is because the wrapper does not follow
+ the :class:`gym.Wrapper` interface. Any subsequent wrappers will need to be modified to work with this
+ wrapper.
+
+ Reference:
+ https://github.com/leggedrobotics/rsl_rl/blob/master/rsl_rl/env/vec_env.py
+ """
+
+
[文档]def__init__(self,env:ManagerBasedRLEnv|DirectRLEnv):
+"""Initializes the wrapper.
+
+ Note:
+ The wrapper calls :meth:`reset` at the start since the RSL-RL runner does not call reset.
+
+ Args:
+ env: The environment to wrap around.
+
+ Raises:
+ ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
+ """
+ # check that input is valid
+ ifnotisinstance(env.unwrapped,ManagerBasedRLEnv)andnotisinstance(env.unwrapped,DirectRLEnv):
+ raiseValueError(
+ "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:"
+ f" {type(env)}"
+ )
+ # initialize the wrapper
+ self.env=env
+ # store information required by wrapper
+ self.num_envs=self.unwrapped.num_envs
+ self.device=self.unwrapped.device
+ self.max_episode_length=self.unwrapped.max_episode_length
+ ifhasattr(self.unwrapped,"action_manager"):
+ self.num_actions=self.unwrapped.action_manager.total_action_dim
+ else:
+ self.num_actions=self.unwrapped.num_actions
+ ifhasattr(self.unwrapped,"observation_manager"):
+ self.num_obs=self.unwrapped.observation_manager.group_obs_dim["policy"][0]
+ else:
+ self.num_obs=self.unwrapped.num_observations
+ # -- privileged observations
+ if(
+ hasattr(self.unwrapped,"observation_manager")
+ and"critic"inself.unwrapped.observation_manager.group_obs_dim
+ ):
+ self.num_privileged_obs=self.unwrapped.observation_manager.group_obs_dim["critic"][0]
+ elifhasattr(self.unwrapped,"num_states"):
+ self.num_privileged_obs=self.unwrapped.num_states
+ else:
+ self.num_privileged_obs=0
+ # reset at the start since the RSL-RL runner does not call reset
+ self.env.reset()
+
+ def__str__(self):
+"""Returns the wrapper name and the :attr:`env` representation string."""
+ returnf"<{type(self).__name__}{self.env}>"
+
+ def__repr__(self):
+"""Returns the string representation of the wrapper."""
+ returnstr(self)
+
+"""
+ Properties -- Gym.Wrapper
+ """
+
+ @property
+ defcfg(self)->object:
+"""Returns the configuration class instance of the environment."""
+ returnself.unwrapped.cfg
+
+ @property
+ defrender_mode(self)->str|None:
+"""Returns the :attr:`Env` :attr:`render_mode`."""
+ returnself.env.render_mode
+
+ @property
+ defobservation_space(self)->gym.Space:
+"""Returns the :attr:`Env` :attr:`observation_space`."""
+ returnself.env.observation_space
+
+ @property
+ defaction_space(self)->gym.Space:
+"""Returns the :attr:`Env` :attr:`action_space`."""
+ returnself.env.action_space
+
+
[文档]@classmethod
+ defclass_name(cls)->str:
+"""Returns the class name of the wrapper."""
+ returncls.__name__
+
+ @property
+ defunwrapped(self)->ManagerBasedRLEnv|DirectRLEnv:
+"""Returns the base environment of the wrapper.
+
+ This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
+ """
+ returnself.env.unwrapped
+
+"""
+ Properties
+ """
+
+
[文档]defget_observations(self)->tuple[torch.Tensor,dict]:
+"""Returns the current observations of the environment."""
+ ifhasattr(self.unwrapped,"observation_manager"):
+ obs_dict=self.unwrapped.observation_manager.compute()
+ else:
+ obs_dict=self.unwrapped._get_observations()
+ returnobs_dict["policy"],{"observations":obs_dict}
+
+ @property
+ defepisode_length_buf(self)->torch.Tensor:
+"""The episode length buffer."""
+ returnself.unwrapped.episode_length_buf
+
+ @episode_length_buf.setter
+ defepisode_length_buf(self,value:torch.Tensor):
+"""Set the episode length buffer.
+
+ Note:
+ This is needed to perform random initialization of episode lengths in RSL-RL.
+ """
+ self.unwrapped.episode_length_buf=value
+
+"""
+ Operations - MDP
+ """
+
+ defseed(self,seed:int=-1)->int:# noqa: D102
+ returnself.unwrapped.seed(seed)
+
+ defreset(self)->tuple[torch.Tensor,dict]:# noqa: D102
+ # reset the environment
+ obs_dict,_=self.env.reset()
+ # return observations
+ returnobs_dict["policy"],{"observations":obs_dict}
+
+ defstep(self,actions:torch.Tensor)->tuple[torch.Tensor,torch.Tensor,torch.Tensor,dict]:
+ # record step information
+ obs_dict,rew,terminated,truncated,extras=self.env.step(actions)
+ # compute dones for compatibility with RSL-RL
+ dones=(terminated|truncated).to(dtype=torch.long)
+ # move extra observations to the extras dict
+ obs=obs_dict["policy"]
+ extras["observations"]=obs_dict
+ # move time out information to the extras dict
+ # this is only needed for infinite horizon tasks
+ ifnotself.unwrapped.cfg.is_finite_horizon:
+ extras["time_outs"]=truncated
+
+ # return the step information
+ returnobs,rew,dones,extras
+
+ defclose(self):# noqa: D102
+ returnself.env.close()
[文档]classSb3VecEnvWrapper(VecEnv):
+"""Wraps around Isaac Lab environment for Stable Baselines3.
+
+ Isaac Sim internally implements a vectorized environment. However, since it is
+ still considered a single environment instance, Stable Baselines tries to wrap
+ around it using the :class:`DummyVecEnv`. This is only done if the environment
+ is not inheriting from their :class:`VecEnv`. Thus, this class thinly wraps
+ over the environment from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
+
+ Note:
+ While Stable-Baselines3 supports Gym 0.26+ API, their vectorized environment
+ still uses the old API (i.e. it is closer to Gym 0.21). Thus, we implement
+ the old API for the vectorized environment.
+
+ We also add monitoring functionality that computes the un-discounted episode
+ return and length. This information is added to the info dicts under key `episode`.
+
+ In contrast to the Isaac Lab environment, stable-baselines expect the following:
+
+ 1. numpy datatype for MDP signals
+ 2. a list of info dicts for each sub-environment (instead of a dict)
+ 3. when environment has terminated, the observations from the environment should correspond
+ to the one after reset. The "real" final observation is passed using the info dicts
+ under the key ``terminal_observation``.
+
+ .. warning::
+
+ By the nature of physics stepping in Isaac Sim, it is not possible to forward the
+ simulation buffers without performing a physics step. Thus, reset is performed
+ inside the :meth:`step()` function after the actual physics step is taken.
+ Thus, the returned observations for terminated environments is the one after the reset.
+
+ .. caution::
+
+ This class must be the last wrapper in the wrapper chain. This is because the wrapper does not follow
+ the :class:`gym.Wrapper` interface. Any subsequent wrappers will need to be modified to work with this
+ wrapper.
+
+ Reference:
+
+ 1. https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html
+ 2. https://stable-baselines3.readthedocs.io/en/master/common/monitor.html
+
+ """
+
+
[文档]def__init__(self,env:ManagerBasedRLEnv|DirectRLEnv):
+"""Initialize the wrapper.
+
+ Args:
+ env: The environment to wrap around.
+
+ Raises:
+ ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
+ """
+ # check that input is valid
+ ifnotisinstance(env.unwrapped,ManagerBasedRLEnv)andnotisinstance(env.unwrapped,DirectRLEnv):
+ raiseValueError(
+ "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:"
+ f" {type(env)}"
+ )
+ # initialize the wrapper
+ self.env=env
+ # collect common information
+ self.num_envs=self.unwrapped.num_envs
+ self.sim_device=self.unwrapped.device
+ self.render_mode=self.unwrapped.render_mode
+
+ # obtain gym spaces
+ # note: stable-baselines3 does not like when we have unbounded action space so
+ # we set it to some high value here. Maybe this is not general but something to think about.
+ observation_space=self.unwrapped.single_observation_space["policy"]
+ action_space=self.unwrapped.single_action_space
+ ifisinstance(action_space,gym.spaces.Box)andnotaction_space.is_bounded("both"):
+ action_space=gym.spaces.Box(low=-100,high=100,shape=action_space.shape)
+
+ # initialize vec-env
+ VecEnv.__init__(self,self.num_envs,observation_space,action_space)
+ # add buffer for logging episodic information
+ self._ep_rew_buf=torch.zeros(self.num_envs,device=self.sim_device)
+ self._ep_len_buf=torch.zeros(self.num_envs,device=self.sim_device)
+
+ def__str__(self):
+"""Returns the wrapper name and the :attr:`env` representation string."""
+ returnf"<{type(self).__name__}{self.env}>"
+
+ def__repr__(self):
+"""Returns the string representation of the wrapper."""
+ returnstr(self)
+
+"""
+ Properties -- Gym.Wrapper
+ """
+
+
[文档]@classmethod
+ defclass_name(cls)->str:
+"""Returns the class name of the wrapper."""
+ returncls.__name__
+
+ @property
+ defunwrapped(self)->ManagerBasedRLEnv|DirectRLEnv:
+"""Returns the base environment of the wrapper.
+
+ This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
+ """
+ returnself.env.unwrapped
+
+"""
+ Properties
+ """
+
+
[文档]defget_episode_rewards(self)->list[float]:
+"""Returns the rewards of all the episodes."""
+ returnself._ep_rew_buf.cpu().tolist()
+
+
[文档]defget_episode_lengths(self)->list[int]:
+"""Returns the number of time-steps of all the episodes."""
+ returnself._ep_len_buf.cpu().tolist()
+
+"""
+ Operations - MDP
+ """
+
+ defseed(self,seed:int|None=None)->list[int|None]:# noqa: D102
+ return[self.unwrapped.seed(seed)]*self.unwrapped.num_envs
+
+ defreset(self)->VecEnvObs:# noqa: D102
+ obs_dict,_=self.env.reset()
+ # reset episodic information buffers
+ self._ep_rew_buf.zero_()
+ self._ep_len_buf.zero_()
+ # convert data types to numpy depending on backend
+ returnself._process_obs(obs_dict)
+
+ defstep_async(self,actions):# noqa: D102
+ # convert input to numpy array
+ ifnotisinstance(actions,torch.Tensor):
+ actions=np.asarray(actions)
+ actions=torch.from_numpy(actions).to(device=self.sim_device,dtype=torch.float32)
+ else:
+ actions=actions.to(device=self.sim_device,dtype=torch.float32)
+ # convert to tensor
+ self._async_actions=actions
+
+ defstep_wait(self)->VecEnvStepReturn:# noqa: D102
+ # record step information
+ obs_dict,rew,terminated,truncated,extras=self.env.step(self._async_actions)
+ # update episode un-discounted return and length
+ self._ep_rew_buf+=rew
+ self._ep_len_buf+=1
+ # compute reset ids
+ dones=terminated|truncated
+ reset_ids=(dones>0).nonzero(as_tuple=False)
+
+ # convert data types to numpy depending on backend
+ # note: ManagerBasedRLEnv uses torch backend (by default).
+ obs=self._process_obs(obs_dict)
+ rew=rew.detach().cpu().numpy()
+ terminated=terminated.detach().cpu().numpy()
+ truncated=truncated.detach().cpu().numpy()
+ dones=dones.detach().cpu().numpy()
+ # convert extra information to list of dicts
+ infos=self._process_extras(obs,terminated,truncated,extras,reset_ids)
+
+ # reset info for terminated environments
+ self._ep_rew_buf[reset_ids]=0
+ self._ep_len_buf[reset_ids]=0
+
+ returnobs,rew,dones,infos
+
+ defclose(self):# noqa: D102
+ self.env.close()
+
+ defget_attr(self,attr_name,indices=None):# noqa: D102
+ # resolve indices
+ ifindicesisNone:
+ indices=slice(None)
+ num_indices=self.num_envs
+ else:
+ num_indices=len(indices)
+ # obtain attribute value
+ attr_val=getattr(self.env,attr_name)
+ # return the value
+ ifnotisinstance(attr_val,torch.Tensor):
+ return[attr_val]*num_indices
+ else:
+ returnattr_val[indices].detach().cpu().numpy()
+
+ defset_attr(self,attr_name,value,indices=None):# noqa: D102
+ raiseNotImplementedError("Setting attributes is not supported.")
+
+ defenv_method(self,method_name:str,*method_args,indices=None,**method_kwargs):# noqa: D102
+ ifmethod_name=="render":
+ # gymnasium does not support changing render mode at runtime
+ returnself.env.render()
+ else:
+ # this isn't properly implemented but it is not necessary.
+ # mostly done for completeness.
+ env_method=getattr(self.env,method_name)
+ returnenv_method(*method_args,indices=indices,**method_kwargs)
+
+ defenv_is_wrapped(self,wrapper_class,indices=None):# noqa: D102
+ raiseNotImplementedError("Checking if environment is wrapped is not supported.")
+
+ defget_images(self):# noqa: D102
+ raiseNotImplementedError("Getting images is not supported.")
+
+"""
+ Helper functions.
+ """
+
+ def_process_obs(self,obs_dict:torch.Tensor|dict[str,torch.Tensor])->np.ndarray|dict[str,np.ndarray]:
+"""Convert observations into NumPy data type."""
+ # Sb3 doesn't support asymmetric observation spaces, so we only use "policy"
+ obs=obs_dict["policy"]
+ # note: ManagerBasedRLEnv uses torch backend (by default).
+ ifisinstance(obs,dict):
+ forkey,valueinobs.items():
+ obs[key]=value.detach().cpu().numpy()
+ elifisinstance(obs,torch.Tensor):
+ obs=obs.detach().cpu().numpy()
+ else:
+ raiseNotImplementedError(f"Unsupported data type: {type(obs)}")
+ returnobs
+
+ def_process_extras(
+ self,obs:np.ndarray,terminated:np.ndarray,truncated:np.ndarray,extras:dict,reset_ids:np.ndarray
+ )->list[dict[str,Any]]:
+"""Convert miscellaneous information into dictionary for each sub-environment."""
+ # create empty list of dictionaries to fill
+ infos:list[dict[str,Any]]=[dict.fromkeys(extras.keys())for_inrange(self.num_envs)]
+ # fill-in information for each sub-environment
+ # note: This loop becomes slow when number of environments is large.
+ foridxinrange(self.num_envs):
+ # fill-in episode monitoring info
+ ifidxinreset_ids:
+ infos[idx]["episode"]=dict()
+ infos[idx]["episode"]["r"]=float(self._ep_rew_buf[idx])
+ infos[idx]["episode"]["l"]=float(self._ep_len_buf[idx])
+ else:
+ infos[idx]["episode"]=None
+ # fill-in bootstrap information
+ infos[idx]["TimeLimit.truncated"]=truncated[idx]andnotterminated[idx]
+ # fill-in information from extras
+ forkey,valueinextras.items():
+ # 1. remap extra episodes information safely
+ # 2. for others just store their values
+ ifkey=="log":
+ # only log this data for episodes that are terminated
+ ifinfos[idx]["episode"]isnotNone:
+ forsub_key,sub_valueinvalue.items():
+ infos[idx]["episode"][sub_key]=sub_value
+ else:
+ infos[idx][key]=value[idx]
+ # add information about terminal observation separately
+ ifidxinreset_ids:
+ # extract terminal observations
+ ifisinstance(obs,dict):
+ terminal_obs=dict.fromkeys(obs.keys())
+ forkey,valueinobs.items():
+ terminal_obs[key]=value[idx]
+ else:
+ terminal_obs=obs[idx]
+ # add info to dict
+ infos[idx]["terminal_observation"]=terminal_obs
+ else:
+ infos[idx]["terminal_observation"]=None
+ # return list of dictionaries
+ returninfos
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Wrapper to configure an Isaac Lab environment instance to skrl environment.
+
+The following example shows how to wrap an environment for skrl:
+
+.. code-block:: python
+
+ from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper
+
+ env = SkrlVecEnvWrapper(env, ml_framework="torch") # or ml_framework="jax"
+
+Or, equivalently, by directly calling the skrl library API as follows:
+
+.. code-block:: python
+
+ from skrl.envs.torch.wrappers import wrap_env # for PyTorch, or...
+ from skrl.envs.jax.wrappers import wrap_env # for JAX
+
+ env = wrap_env(env, wrapper="isaaclab")
+
+"""
+
+# needed to import for type hinting: Agent | list[Agent]
+from__future__importannotations
+
+fromtypingimportLiteral
+
+fromomni.isaac.lab.envsimportDirectMARLEnv,DirectRLEnv,ManagerBasedRLEnv
+
+"""
+Vectorized environment wrapper.
+"""
+
+
+
[文档]defSkrlVecEnvWrapper(
+ env:ManagerBasedRLEnv|DirectRLEnv|DirectMARLEnv,
+ ml_framework:Literal["torch","jax","jax-numpy"]="torch",
+ wrapper:Literal["auto","isaaclab","isaaclab-single-agent","isaaclab-multi-agent"]="isaaclab",
+):
+"""Wraps around Isaac Lab environment for skrl.
+
+ This function wraps around the Isaac Lab environment. Since the wrapping
+ functionality is defined within the skrl library itself, this implementation
+ is maintained for compatibility with the structure of the extension that contains it.
+ Internally it calls the :func:`wrap_env` from the skrl library API.
+
+ Args:
+ env: The environment to wrap around.
+ ml_framework: The ML framework to use for the wrapper. Defaults to "torch".
+ wrapper: The wrapper to use. Defaults to "isaaclab": leave it to skrl to determine if the environment
+ will be wrapped as single-agent or multi-agent.
+
+ Raises:
+ ValueError: When the environment is not an instance of any Isaac Lab environment interface.
+ ValueError: If the specified ML framework is not valid.
+
+ Reference:
+ https://skrl.readthedocs.io/en/latest/api/envs/wrapping.html
+ """
+ # check that input is valid
+ if(
+ notisinstance(env.unwrapped,ManagerBasedRLEnv)
+ andnotisinstance(env.unwrapped,DirectRLEnv)
+ andnotisinstance(env.unwrapped,DirectMARLEnv)
+ ):
+ raiseValueError(
+ "The environment must be inherited from ManagerBasedRLEnv, DirectRLEnv or DirectMARLEnv. Environment type:"
+ f" {type(env)}"
+ )
+
+ # import statements according to the ML framework
+ ifml_framework.startswith("torch"):
+ fromskrl.envs.wrappers.torchimportwrap_env
+ elifml_framework.startswith("jax"):
+ fromskrl.envs.wrappers.jaximportwrap_env
+ else:
+ ValueError(
+ f"Invalid ML framework for skrl: {ml_framework}. Available options are: 'torch', 'jax' or 'jax-numpy'"
+ )
+
+ # wrap and return the environment
+ returnwrap_env(env,wrapper)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/_sources/index.rst b/_sources/index.rst
new file mode 100644
index 0000000000..19db6e45dd
--- /dev/null
+++ b/_sources/index.rst
@@ -0,0 +1,158 @@
+Overview
+========
+
+.. figure:: source/_static/isaaclab.jpg
+ :width: 100%
+ :alt: H1 Humanoid example using Isaac Lab
+
+
+**注意:** 本翻译项目不属于 NVIDIA 或 IsaacLab 官方文档,由范子琦提供中文翻译,仅供学习交流使用,禁止转载或用于商业用途。详情请查看 `关于翻译