# Copyright (c) 2022-2024, The Cyber Nachos.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import numpy as np
import torch
from collections . abc import Sequence
import omni . isaac . lab . sim as sim_utils
from omni . isaac . lab . assets import Articulation , RigidObject
from omni . isaac . lab . envs import DirectRLEnv
from omni . isaac . lab . markers import VisualizationMarkers
from omni . isaac . lab . sim . spawners . from_files import GroundPlaneCfg , spawn_ground_plane
from omni . isaac . lab . utils . math import quat_conjugate , quat_from_angle_axis , quat_mul , sample_uniform , saturate
from isaacLab . manipulation . tasks . Direct_hand . franka_allegro import FrankaAllegroEnvCfg
import omni . kit . commands # type: ignore
from pxr import Usd , UsdGeom , Sdf , Gf , Vt , PhysxSchema # type: ignore
from omni . physx . scripts import physicsUtils , particleUtils # type: ignore
class PourWaterEnv(DirectRLEnv):
    """Direct-workflow RL environment with an articulated hand, a rigid object and a particle fluid.

    Every environment owns one particle set (``/World/particles<i>``); all sets share a single
    PhysX particle system. An episode terminates when the lowest particle of an environment's
    fluid drops below a height threshold, and a new goal position is sampled whenever the object
    gets within ``cfg.success_tolerance`` of the current goal.
    """

    cfg: FrankaAllegroEnvCfg

    # Created per instance in ``__init__`` (not as shared class-level containers).
    particle_init_states: list
    prim_cache: dict

    def __init__(self, cfg: FrankaAllegroEnvCfg, render_mode: str | None = None, **kwargs):
        """Build the scene through the base class and allocate all per-environment buffers.

        Args:
            cfg: Environment configuration.
            render_mode: Forwarded unchanged to ``DirectRLEnv``.
            **kwargs: Forwarded unchanged to ``DirectRLEnv``.
        """
        # ``_setup_scene`` fills these while the base constructor runs, so they have to exist
        # before ``super().__init__`` is called.
        self.stage = omni.usd.get_context().get_stage()
        self.particle_init_states = []  # per env: (initial positions, initial velocities)
        self.prim_cache = {}  # prim path string -> UsdGeom.Points
        super().__init__(cfg, render_mode, **kwargs)

        self.num_hand_dofs = self.hand.num_joints

        # buffers for position targets
        target_shape = (self.num_envs, self.num_hand_dofs)
        self.hand_dof_targets = torch.zeros(target_shape, dtype=torch.float, device=self.device)
        self.prev_targets = torch.zeros(target_shape, dtype=torch.float, device=self.device)
        self.cur_targets = torch.zeros(target_shape, dtype=torch.float, device=self.device)

        # indices of the actuated joints, in ascending order
        self.actuated_dof_indices = sorted(
            self.hand.joint_names.index(joint_name) for joint_name in cfg.actuated_joint_names
        )

        # indices of the fingertip bodies, in ascending order
        self.finger_bodies = sorted(
            self.hand.body_names.index(body_name) for body_name in self.cfg.fingertip_body_names
        )
        self.num_fingertips = len(self.finger_bodies)

        # joint limits
        joint_pos_limits = self.hand.root_physx_view.get_dof_limits().to(self.device)
        self.hand_dof_lower_limits = joint_pos_limits[..., 0]
        self.hand_dof_upper_limits = joint_pos_limits[..., 1]

        # track goal resets
        self.reset_goal_buf = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)

        # default goal pose (first quaternion component set to 1, the rest 0)
        self.goal_rot = torch.zeros((self.num_envs, 4), dtype=torch.float, device=self.device)
        self.goal_rot[:, 0] = 1.0
        self.goal_pos = torch.zeros((self.num_envs, 3), dtype=torch.float, device=self.device)
        self.goal_pos[:, :] = torch.tensor([-0.2, -0.45, 0.68], device=self.device)

        # initialize goal marker
        self.goal_markers = VisualizationMarkers(self.cfg.goal_object_cfg)

        # track successes
        self.successes = torch.zeros(self.num_envs, dtype=torch.float, device=self.device)
        self.consecutive_successes = torch.zeros(1, dtype=torch.float, device=self.device)

        # unit tensors
        self.x_unit_tensor = torch.tensor([1, 0, 0], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))
        self.y_unit_tensor = torch.tensor([0, 1, 0], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))
        self.z_unit_tensor = torch.tensor([0, 0, 1], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))

    def _setup_scene(self):
        """Spawn hand, object, ground plane, light and one fluid particle set per environment."""
        # add hand and object
        self.hand = Articulation(self.cfg.robot_cfg)
        self.object = RigidObject(self.cfg.object_cfg)
        # add ground plane
        spawn_ground_plane(prim_path="/World/ground", cfg=GroundPlaneCfg())
        # clone and replicate (no need to filter for this environment)
        self.scene.clone_environments(copy_from_source=False)
        # add articulation to scene - we must register to scene to randomize with EventManager
        self.scene.articulations["robot"] = self.hand
        self.scene.rigid_objects["object"] = self.object
        # add lights
        light_cfg = sim_utils.DomeLightCfg(intensity=2000.0, color=(0.75, 0.75, 0.75))
        light_cfg.func("/World/Light", light_cfg)

        default_prim: Usd.Prim = UsdGeom.Xform.Define(self.stage, Sdf.Path("/World")).GetPrim()
        self.stage.SetDefaultPrim(default_prim)

        # one particle grid per environment, placed relative to the environment origin
        fluid_offset = torch.tensor([-0.02, -0.03, 0.15]).to(self.device)
        for env_idx, origin in enumerate(self.scene.env_origins):
            self._create_fluid(env_idx, self.stage, Sdf.Path("/World"), (origin + fluid_offset).tolist(), [4, 4, 8])
            # remember the spawn state so the fluid can be restored on reset
            self.particle_init_states.append(self._get_particles_state(env_idx))
        print("Create complete")

    def _create_fluid(self, num_system, stage, default_prim_path, lower_pos, nums):
        """Create the shared particle system (once) and a fluid particle set for one environment.

        Args:
            num_system: Index appended to the particle set prim name (``particles<num_system>``).
            stage: USD stage to author on.
            default_prim_path: ``Sdf.Path`` under which the particle prims are created.
            lower_pos: Lower corner of the particle grid, passed to ``create_particles_grid``.
            nums: Number of particles along the three grid axes.
        """
        numParticlesX, numParticlesY, numParticlesZ = nums
        particleSystemPath = default_prim_path.AppendChild("particleSystem")
        particle_system = stage.GetPrimAtPath(particleSystemPath)

        particleSpacing = 0.01
        restOffset = particleSpacing * 0.9
        fluidRestOffset = restOffset * 0.6
        particleContactOffset = restOffset + 0.001

        # The particle system, its material and its rendering options are authored only the
        # first time; later calls find the prim and just add another particle set to it.
        if not particle_system.IsValid():
            # particle params
            particle_system = particleUtils.add_physx_particle_system(
                stage=stage,
                particle_system_path=particleSystemPath,
                simulation_owner="physicsScene",
                contact_offset=restOffset * 1.5 + 0.005,
                rest_offset=restOffset * 1.5,
                particle_contact_offset=particleContactOffset,
                solid_rest_offset=0.0,
                fluid_rest_offset=fluidRestOffset,
                solver_position_iterations=16,
            )

            omni.kit.commands.execute(
                "CreateMdlMaterialPrim",
                mtl_url="OmniSurfacePresets.mdl",
                mtl_name="OmniSurface_DeepWater",
                mtl_path='/World/Looks/OmniSurface_DeepWater',
                select_new_prim=False,
            )
            pbd_particle_material_path = '/World/Looks/OmniSurface_DeepWater'
            omni.kit.commands.execute(
                "BindMaterial", prim_path=particleSystemPath, material_path=pbd_particle_material_path
            )

            # Create a pbd particle material and set it on the particle system
            particleUtils.add_pbd_particle_material(
                stage,
                pbd_particle_material_path,
                cohesion=10,
                viscosity=0.91,
                surface_tension=0.74,
                friction=0.1,
            )
            physicsUtils.add_physics_material_to_prim(stage, particle_system.GetPrim(), pbd_particle_material_path)

            particle_system.CreateMaxVelocityAttr().Set(200)

            # add particle anisotropy
            anisotropyAPI = PhysxSchema.PhysxParticleAnisotropyAPI.Apply(particle_system.GetPrim())
            anisotropyAPI.CreateParticleAnisotropyEnabledAttr().Set(True)
            aniso_scale = 5.0
            anisotropyAPI.CreateScaleAttr().Set(aniso_scale)
            anisotropyAPI.CreateMinAttr().Set(1.0)
            anisotropyAPI.CreateMaxAttr().Set(2.0)

            # add particle smoothing
            smoothingAPI = PhysxSchema.PhysxParticleSmoothingAPI.Apply(particle_system.GetPrim())
            smoothingAPI.CreateParticleSmoothingEnabledAttr().Set(True)
            smoothingAPI.CreateStrengthAttr().Set(0.5)

            # apply isosurface params
            isosurfaceAPI = PhysxSchema.PhysxParticleIsosurfaceAPI.Apply(particle_system.GetPrim())
            isosurfaceAPI.CreateIsosurfaceEnabledAttr().Set(True)
            isosurfaceAPI.CreateMaxVerticesAttr().Set(1024 * 1024)
            isosurfaceAPI.CreateMaxTrianglesAttr().Set(2 * 1024 * 1024)
            isosurfaceAPI.CreateMaxSubgridsAttr().Set(1024 * 4)
            isosurfaceAPI.CreateGridSpacingAttr().Set(fluidRestOffset * 1.5)
            isosurfaceAPI.CreateSurfaceDistanceAttr().Set(fluidRestOffset * 1.6)
            isosurfaceAPI.CreateGridFilteringPassesAttr().Set("")
            isosurfaceAPI.CreateGridSmoothingRadiusAttr().Set(fluidRestOffset * 2)
            isosurfaceAPI.CreateNumMeshSmoothingPassesAttr().Set(1)

            primVarsApi = UsdGeom.PrimvarsAPI(particle_system)
            primVarsApi.CreatePrimvar("doNotCastShadows", Sdf.ValueTypeNames.Bool).Set(True)

            stage.SetInterpolationType(Usd.InterpolationTypeHeld)

        particlesPath = default_prim_path.AppendChild("particles" + str(num_system))
        positions, velocities = particleUtils.create_particles_grid(
            lower_pos, particleSpacing + 0.01, numParticlesX, numParticlesY, numParticlesZ
        )

        # Jitter the regular grid. A single seeded generator is used for the whole set:
        # re-creating the generator for every draw would return the same number each time and
        # merely translate the grid instead of perturbing it.
        uniform_range = particleSpacing * 0.2
        rng = np.random.default_rng(45)
        jitter = rng.uniform(-uniform_range, uniform_range, size=(len(positions), 3))
        for point, offset in zip(positions, jitter):
            for axis in range(3):
                point[axis] += float(offset[axis])

        widths = [particleSpacing] * len(positions)
        particleUtils.add_physx_particleset_points(
            stage=stage,
            path=particlesPath,
            positions_list=Vt.Vec3fArray(positions),
            velocities_list=Vt.Vec3fArray(velocities),
            widths_list=widths,
            particle_system_path=particleSystemPath,
            self_collision=True,
            fluid=True,
            particle_group=0,
            particle_mass=1,
            density=1000,
        )

    def _pre_physics_step(self, actions: torch.Tensor) -> None:
        """Store a copy of the incoming actions for use in ``_apply_action``."""
        self.actions = actions.clone()

    def _apply_action(self) -> None:
        """Turn the stored actions into smoothed, limit-clamped joint position targets."""
        joint_ids = self.actuated_dof_indices
        lower = self.hand_dof_lower_limits[:, joint_ids]
        upper = self.hand_dof_upper_limits[:, joint_ids]

        # map actions onto the joint range
        self.cur_targets[:, joint_ids] = scale(self.actions, lower, upper)
        # blend with the previous targets (moving average)
        self.cur_targets[:, joint_ids] = (
            self.cfg.act_moving_average * self.cur_targets[:, joint_ids]
            + (1.0 - self.cfg.act_moving_average) * self.prev_targets[:, joint_ids]
        )
        # keep the targets inside the joint limits
        self.cur_targets[:, joint_ids] = saturate(self.cur_targets[:, joint_ids], lower, upper)

        self.prev_targets[:, joint_ids] = self.cur_targets[:, joint_ids]

        self.hand.set_joint_position_target(self.cur_targets[:, joint_ids], joint_ids=joint_ids)

    @staticmethod
    def _particles_prim_path(env_id) -> str:
        """Return the prim path of an environment's particle set.

        ``env_id`` is converted with ``int`` so that 0-dim tensors produce ``particles3`` rather
        than their ``tensor(3, ...)`` representation.
        """
        return f"/World/particles{int(env_id)}"

    def _get_particles_state(self, env_id) -> tuple[Vt.Vec3fArray, Vt.Vec3fArray]:
        """Return the (positions, velocities) attribute values of an environment's particle set."""
        particles = UsdGeom.Points(self.stage.GetPrimAtPath(Sdf.Path(self._particles_prim_path(env_id))))
        particles_pos = particles.GetPointsAttr().Get()
        particles_vel = particles.GetVelocitiesAttr().Get()
        return particles_pos, particles_vel

    def _get_particles_position(self, env_id) -> Vt.Vec3fArray:
        """Return the particle positions of an environment, caching the ``UsdGeom.Points`` wrapper."""
        prim_path = self._particles_prim_path(env_id)
        particles = self.prim_cache.get(prim_path, None)
        if particles is None:
            particles = UsdGeom.Points(self.stage.GetPrimAtPath(Sdf.Path(prim_path)))
            self.prim_cache[prim_path] = particles
        return particles.GetPointsAttr().Get()

    def _set_particles_state(self, particles_pos: Vt.Vec3fArray, particles_vel: Vt.Vec3fArray, env_id: int):
        """Write the given positions and velocities to an environment's particle set."""
        particles = UsdGeom.Points(self.stage.GetPrimAtPath(Sdf.Path(self._particles_prim_path(env_id))))
        particles.GetPointsAttr().Set(particles_pos)
        particles.GetVelocitiesAttr().Set(particles_vel)

    def _get_observations(self) -> dict:
        """Assemble the policy observations (and critic states when ``cfg.asymmetric_obs`` is set).

        Raises:
            ValueError: If ``cfg.obs_type`` is neither ``"openai"`` nor ``"full"``.
        """
        if self.cfg.asymmetric_obs:
            self.fingertip_force_sensors = self.hand.root_physx_view.get_link_incoming_joint_force()[
                :, self.finger_bodies
            ]

        if self.cfg.obs_type == "openai":
            obs = self.compute_reduced_observations()
        elif self.cfg.obs_type == "full":
            obs = self.compute_full_observations()
        else:
            # Fail with a clear message instead of an unbound-variable error further down.
            raise ValueError(f"Unknown observations type: {self.cfg.obs_type!r}")

        if self.cfg.asymmetric_obs:
            return {"policy": obs, "critic": self.compute_full_state()}
        return {"policy": obs}

    def _get_rewards(self) -> torch.Tensor:
        """Compute rewards, update the success counters and re-sample goals that were reached."""
        (
            total_reward,
            self.reset_goal_buf,
            self.successes[:],
            self.consecutive_successes[:],
        ) = compute_rewards(
            self.reset_buf,
            self.reset_goal_buf,
            self.successes,
            self.consecutive_successes,
            self.max_episode_length,
            # NOTE(review): body index 10 is hard-coded as the hand reference point - confirm.
            self.hand.data.body_pos_w[:, 10, :] - self.scene.env_origins,
            self.object_pos,
            self.object_rot,
            self.goal_pos,
            self.goal_rot,
            self.particle_min,
            self.cfg.dist_reward_temperature,
            self.cfg.rot_reward_temperature,
            self.cfg.hand_reward_temperature,
            self.cfg.rot_eps,
            self.actions,
            self.cfg.action_penalty_scale,
            self.cfg.success_tolerance,
            self.cfg.reach_goal_bonus,
            self.cfg.fall_dist,
            self.cfg.fall_penalty,
            self.cfg.av_factor,
        )

        if "log" not in self.extras:
            self.extras["log"] = dict()
        self.extras["log"]["consecutive_successes"] = self.consecutive_successes.mean()

        # reset goals if the goal has been reached
        goal_env_ids = self.reset_goal_buf.nonzero(as_tuple=False).squeeze(-1)
        if len(goal_env_ids) > 0:
            self._reset_target_pose(goal_env_ids)

        return total_reward

    def _get_dones(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (terminated, time_out) flags for every environment."""
        self._compute_intermediate_values()

        # reset when liquid out of cup
        # NOTE(review): this threshold is hard-coded, whereas the reward uses cfg.fall_dist.
        out_of_cup = self.particle_min < 0.015

        track_successes = self.cfg.max_consecutive_success > 0
        if track_successes:
            # Reset progress (episode length buf) on goal envs if max_consecutive_success > 0.
            # The distance is taken per environment (dim=-1); a norm over the whole batch
            # would reset either every environment or none.
            pos_dist = torch.norm(self.object_pos - self.goal_pos, p=2, dim=-1)
            self.episode_length_buf = torch.where(
                pos_dist <= self.cfg.success_tolerance,
                torch.zeros_like(self.episode_length_buf),
                self.episode_length_buf,
            )

        time_out = self.episode_length_buf >= self.max_episode_length - 1
        if track_successes:
            time_out = time_out | (self.successes >= self.cfg.max_consecutive_success)
        return out_of_cup, time_out

    def _sample_offsets(self, count: int, z_factor: float) -> torch.Tensor:
        """Sample ``count`` random offsets used for object and goal placement.

        x and y end up with magnitude in [0.6, 0.7] and a random sign; z equals
        ``z_factor * (|noise| + 0.6)``. The sign is derived without a division so that a noise
        sample of exactly zero cannot produce NaN.
        """
        noise = torch.rand((count, 3), device=self.device) * 0.2 - 0.1
        signs = torch.where(noise >= 0, torch.ones_like(noise), -torch.ones_like(noise))
        signs[:, 2] = z_factor
        return torch.abs(noise) * signs + signs * 0.6

    def _reset_idx(self, env_ids: Sequence[int] | None):
        """Reset goal, object, fluid and hand state of the given environments.

        Args:
            env_ids: Environments to reset; ``None`` resets all of them.
        """
        if env_ids is None:
            env_ids = self.hand._ALL_INDICES
        # resets articulation and rigid body attributes
        super()._reset_idx(env_ids)

        # reset goals
        self._reset_target_pose(env_ids)

        # reset object
        object_default_state = self.object.data.default_root_state.clone()[env_ids]
        object_offsets = self._sample_offsets(len(env_ids), 0)
        # global object positions
        object_default_state[:, 0:3] = object_offsets + self.scene.env_origins[env_ids]
        object_default_state[:, 7:] = torch.zeros_like(self.object.data.default_root_state[env_ids, 7:])
        self.object.write_root_state_to_sim(object_default_state, env_ids)

        # reset liquid: restore the spawn state, shifted by the same offset as the object.
        # Plain ints are required here because the ids end up in prim path strings.
        if isinstance(env_ids, torch.Tensor):
            env_id_list = env_ids.tolist()
        else:
            env_id_list = [int(env_id) for env_id in env_ids]
        for offset, env_id in zip(object_offsets.tolist(), env_id_list):
            init_pos, init_vel = self.particle_init_states[env_id]
            self._set_particles_state(init_pos + Gf.Vec3f(*offset), init_vel, env_id)

        # reset hand
        default_joint_pos = self.hand.data.default_joint_pos[env_ids]
        delta_max = self.hand_dof_upper_limits[env_ids] - default_joint_pos
        delta_min = self.hand_dof_lower_limits[env_ids] - default_joint_pos

        dof_pos_noise = sample_uniform(-1.0, 1.0, (len(env_ids), self.num_hand_dofs), device=self.device)
        rand_delta = delta_min + (delta_max - delta_min) * 0.5 * dof_pos_noise
        dof_pos = default_joint_pos + self.cfg.reset_dof_pos_noise * rand_delta

        dof_vel_noise = sample_uniform(-1.0, 1.0, (len(env_ids), self.num_hand_dofs), device=self.device)
        dof_vel = self.hand.data.default_joint_vel[env_ids] + self.cfg.reset_dof_vel_noise * dof_vel_noise

        self.prev_targets[env_ids] = dof_pos
        self.cur_targets[env_ids] = dof_pos
        self.hand_dof_targets[env_ids] = dof_pos

        self.hand.set_joint_position_target(dof_pos, env_ids=env_ids)
        self.hand.write_joint_state_to_sim(dof_pos, dof_vel, env_ids=env_ids)

        self.successes[env_ids] = 0
        self._compute_intermediate_values()

    def _reset_target_pose(self, env_ids):
        """Sample a new goal position for the given environments and refresh the goal markers."""
        # reset goal position (the goal rotation is left unchanged)
        self.goal_pos[env_ids] = self._sample_offsets(len(env_ids), 0.7)
        # Draw the markers of all environments so that positions and orientations always have
        # the same number of rows.
        self.goal_markers.visualize(self.scene.env_origins + self.goal_pos, self.goal_rot)
        self.reset_goal_buf[env_ids] = 0

    def _compute_intermediate_values(self):
        """Refresh the cached hand, object and fluid quantities used by observations and rewards."""
        # data for hand
        self.fingertip_pos = self.hand.data.body_pos_w[:, self.finger_bodies]
        self.fingertip_rot = self.hand.data.body_quat_w[:, self.finger_bodies]
        # express fingertip positions relative to each environment origin
        self.fingertip_pos -= self.scene.env_origins.unsqueeze(1)
        self.fingertip_velocities = self.hand.data.body_vel_w[:, self.finger_bodies]

        self.hand_dof_pos = self.hand.data.joint_pos
        self.hand_dof_vel = self.hand.data.joint_vel

        # data for object
        self.object_pos = self.object.data.root_pos_w - self.scene.env_origins
        self.object_rot = self.object.data.root_quat_w
        self.object_velocities = self.object.data.root_vel_w
        self.object_linvel = self.object.data.root_lin_vel_w
        self.object_angvel = self.object.data.root_ang_vel_w

        # data for liquid: smallest third coordinate among each environment's particles
        # (presumably the height of the lowest particle - confirm the up axis)
        lowest = [
            np.min(np.array(self._get_particles_position(env_idx))[:, 2]) for env_idx in range(self.num_envs)
        ]
        self.particle_min = torch.tensor(lowest, device=self.device)

    def compute_reduced_observations(self):
        """Return the reduced observation vector.

        Per https://arxiv.org/pdf/1808.00177.pdf Table 2: fingertip positions, object position
        (but not orientation) and relative target orientation, followed by the actions.
        """
        return torch.cat(
            (
                self.fingertip_pos.view(self.num_envs, self.num_fingertips * 3),
                self.object_pos,
                quat_mul(self.object_rot, quat_conjugate(self.goal_rot)),
                self.actions,
            ),
            dim=-1,
        )

    def _full_observation_terms(self) -> list[torch.Tensor]:
        """Return the terms shared by the full observation and the full (critic) state."""
        return [
            # hand
            unscale(self.hand_dof_pos, self.hand_dof_lower_limits, self.hand_dof_upper_limits),
            self.cfg.vel_obs_scale * self.hand_dof_vel,
            # object
            self.object_pos,
            self.object_rot,
            self.object_linvel,
            self.cfg.vel_obs_scale * self.object_angvel,
            # goal
            self.goal_pos,
            self.goal_rot,
            quat_mul(self.object_rot, quat_conjugate(self.goal_rot)),
            # fingertips
            self.fingertip_pos.view(self.num_envs, self.num_fingertips * 3),
            self.fingertip_rot.view(self.num_envs, self.num_fingertips * 4),
            self.fingertip_velocities.view(self.num_envs, self.num_fingertips * 6),
        ]

    def compute_full_observations(self):
        """Return the full observation vector: hand, object, goal and fingertip terms plus actions."""
        return torch.cat((*self._full_observation_terms(), self.actions), dim=-1)

    def compute_full_state(self):
        """Return the full state: the full observation terms, fingertip joint forces and actions."""
        fingertip_forces = self.cfg.force_torque_obs_scale * self.fingertip_force_sensors.view(
            self.num_envs, self.num_fingertips * 6
        )
        return torch.cat((*self._full_observation_terms(), fingertip_forces, self.actions), dim=-1)
@torch.jit.script
def scale(x, lower, upper):
    """Map values from the normalized range [-1, 1] linearly onto [lower, upper]."""
    # shift [-1, 1] to [0, 1], then stretch to the target span
    unit_fraction = 0.5 * (x + 1.0)
    span = upper - lower
    return unit_fraction * span + lower
@torch.jit.script
def unscale(x, lower, upper):
    """Map values from [lower, upper] linearly onto the normalized range [-1, 1]."""
    centered_twice = 2.0 * x - upper - lower
    span = upper - lower
    return centered_twice / span
@torch.jit.script
def randomize_rotation(rand0, rand1, x_unit_tensor, y_unit_tensor):
    """Compose a rotation of ``rand0 * pi`` about ``x_unit_tensor`` with one of ``rand1 * pi`` about ``y_unit_tensor``."""
    rot_about_x = quat_from_angle_axis(rand0 * np.pi, x_unit_tensor)
    rot_about_y = quat_from_angle_axis(rand1 * np.pi, y_unit_tensor)
    return quat_mul(rot_about_x, rot_about_y)
@torch.jit.script
def rotation_distance(object_rot, target_rot):
    """Return the rotation angle between ``object_rot`` and ``target_rot`` for every row."""
    # quaternion that takes the target orientation to the object orientation
    relative_quat = quat_mul(object_rot, quat_conjugate(target_rot))
    # components 1..3 are used as the vector part (changed quat convention)
    vector_norm = torch.norm(relative_quat[:, 1:4], p=2, dim=-1)
    # clamp guards asin against norms marginally above 1
    return 2.0 * torch.asin(torch.clamp(vector_norm, max=1.0))
@torch.jit.script
def compute_rewards(
    reset_buf: torch.Tensor,
    reset_goal_buf: torch.Tensor,
    successes: torch.Tensor,
    consecutive_successes: torch.Tensor,
    max_episode_length: float,
    hand_pos: torch.Tensor,
    object_pos: torch.Tensor,
    object_rot: torch.Tensor,
    target_pos: torch.Tensor,
    target_rot: torch.Tensor,
    particle_min: torch.Tensor,
    dist_reward_temperature: float,
    rot_reward_temperature: float,
    hand_reward_temperature: float,
    rot_eps: float,
    actions: torch.Tensor,
    action_penalty_scale: float,
    success_tolerance: float,
    reach_goal_bonus: float,
    fall_dist: float,
    fall_penalty: float,
    av_factor: float,
):
    """Compute the per-environment reward and the updated success bookkeeping.

    ``reset_buf``, ``max_episode_length`` and ``rot_eps`` are accepted but not used.

    Returns:
        A tuple ``(reward, goal_resets, successes, cons_successes)``: the total reward, the
        flags of environments whose goal must be re-sampled, the updated success counts and
        the updated running average of successes.
    """
    # Shaping terms: object-to-goal distance, object-to-goal rotation, object-to-hand distance
    goal_dist = torch.norm(object_pos - target_pos, p=2, dim=-1)
    dist_rew = torch.exp(-dist_reward_temperature * goal_dist)

    rot_dist = rotation_distance(object_rot, target_rot)
    rot_rew = torch.exp(-rot_reward_temperature * rot_dist)

    hand_dis = torch.norm(object_pos - hand_pos, p=2, dim=-1)
    hand_rew = torch.exp(-hand_reward_temperature * hand_dis)

    # Action regularization: squared action magnitude, weighted by `action_penalty_scale`
    action_penalty = torch.sum(actions**2, dim=-1)

    shaped_reward = dist_rew + rot_rew + hand_rew + action_penalty * action_penalty_scale

    # Environments whose object is within `success_tolerance` of the goal position are flagged
    # for a goal reset and counted as a success
    reached_goal = torch.abs(goal_dist) <= success_tolerance
    goal_resets = torch.where(reached_goal, torch.ones_like(reset_goal_buf), reset_goal_buf)
    successes = successes + goal_resets

    # Success bonus for every flagged environment
    bonus_reward = torch.where(goal_resets == 1, shaped_reward + reach_goal_bonus, shaped_reward)

    # Spill penalty: the lowest fluid particle is below `fall_dist`
    spilled = particle_min < fall_dist
    reward = torch.where(spilled, bonus_reward + fall_penalty, bonus_reward)

    # Running average of the success counts over the environments that spilled
    num_resets = torch.sum(spilled)
    finished_cons_successes = torch.sum(successes * spilled.float())
    blended_successes = av_factor * finished_cons_successes / num_resets + (1.0 - av_factor) * consecutive_successes
    cons_successes = torch.where(num_resets > 0, blended_successes, consecutive_successes)

    return reward, goal_resets, successes, cons_successes