Skip to content

Commit

Permalink
[Refactor] Fix ruff rule E721: type-comparison (ray-project#49919)
Browse files Browse the repository at this point in the history
Unlike a direct type comparison, isinstance will also check if an object is an instance of a class or a subclass thereof.
If you want to check for an exact type match, use is or is not.

Ref: https://docs.astral.sh/ruff/rules/type-comparison/

Signed-off-by: win5923 <[email protected]>
Signed-off-by: Anson Qian <[email protected]>
  • Loading branch information
win5923 authored and anson627 committed Jan 31, 2025
1 parent 2edc3d9 commit 96d64fa
Show file tree
Hide file tree
Showing 25 changed files with 55 additions and 56 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ ignore = [
# TODO(MortalHappiness): Remove the following rules from the ignore list
# The above are rules ignored originally in flake8
# The following are rules ignored in ruff
"E721",
"F841",
"B018",
"B023",
Expand Down
2 changes: 1 addition & 1 deletion python/ray/_private/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -1627,7 +1627,7 @@ def start_raylet(
Returns:
ProcessInfo for the process that was started.
"""
assert node_manager_port is not None and type(node_manager_port) == int
assert node_manager_port is not None and type(node_manager_port) is int

if use_valgrind and use_profiler:
raise ValueError("Cannot use valgrind and profiler at the same time.")
Expand Down
2 changes: 1 addition & 1 deletion python/ray/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ def wait_until_succeeded_without_exception(
Return:
Whether exception occurs within a timeout.
"""
if type(exceptions) != tuple:
if isinstance(type(exceptions), tuple):
raise Exception("exceptions arguments should be given as a tuple")

time_elapsed = 0
Expand Down
12 changes: 6 additions & 6 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ def _preprocess(self) -> None:
"the driver cannot participate in the NCCL group"
)

if type(dag_node.type_hint) == ChannelOutputType:
if type(dag_node.type_hint) is ChannelOutputType:
# No type hint specified by the user. Replace
# with the default type hint for this DAG.
dag_node.with_type_hint(self._default_type_hint)
Expand Down Expand Up @@ -2593,16 +2593,16 @@ def get_channel_details(
if channel in self._channel_dict and self._channel_dict[channel] != channel:
channel = self._channel_dict[channel]
channel_details += f"\n{type(channel).__name__}"
if type(channel) == CachedChannel:
if type(channel) is CachedChannel:
channel_details += f", {channel._channel_id[:6]}..."
# get inner channel
if (
type(channel) == CompositeChannel
type(channel) is CompositeChannel
and downstream_actor_id in channel._channel_dict
):
inner_channel = channel._channel_dict[downstream_actor_id]
channel_details += f"\n{type(inner_channel).__name__}"
if type(inner_channel) == IntraProcessChannel:
if type(inner_channel) is IntraProcessChannel:
channel_details += f", {inner_channel._channel_id[:6]}..."
return channel_details

Expand Down Expand Up @@ -2766,7 +2766,7 @@ def visualize(
task.output_channels[0],
(
downstream_node._get_actor_handle()._actor_id.hex()
if type(downstream_node) == ClassMethodNode
if type(downstream_node) is ClassMethodNode
else self._proxy_actor._actor_id.hex()
),
)
Expand All @@ -2784,7 +2784,7 @@ def visualize(
task.dag_node._get_actor_handle()._actor_id.hex(),
)
dot.edge(str(idx), str(downstream_idx), label=edge_label)
if type(task.dag_node) == InputAttributeNode:
if type(task.dag_node) is InputAttributeNode:
# Add an edge from the InputAttributeNode to the InputNode
dot.edge(str(self.input_task_idx), str(idx))
dot.render(filename, view=view)
Expand Down
18 changes: 9 additions & 9 deletions python/ray/dashboard/tests/test_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def test_immutable_types():
d["list"][0] = {str(i): i for i in range(1000)}
d["dict"] = {str(i): i for i in range(1000)}
immutable_dict = dashboard_utils.make_immutable(d)
assert type(immutable_dict) == dashboard_utils.ImmutableDict
assert type(immutable_dict) is dashboard_utils.ImmutableDict
assert immutable_dict == dashboard_utils.ImmutableDict(d)
assert immutable_dict == d
assert dashboard_utils.ImmutableDict(immutable_dict) == immutable_dict
Expand All @@ -799,14 +799,14 @@ def test_immutable_types():
assert "512" in d["dict"]

# Test type conversion
assert type(dict(immutable_dict)["list"]) == dashboard_utils.ImmutableList
assert type(list(immutable_dict["list"])[0]) == dashboard_utils.ImmutableDict
assert type(dict(immutable_dict)["list"]) is dashboard_utils.ImmutableList
assert type(list(immutable_dict["list"])[0]) is dashboard_utils.ImmutableDict

# Test json dumps / loads
json_str = json.dumps(immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
deserialized_immutable_dict = json.loads(json_str)
assert type(deserialized_immutable_dict) == dict
assert type(deserialized_immutable_dict["list"]) == list
assert type(deserialized_immutable_dict) is dict
assert type(deserialized_immutable_dict["list"]) is list
assert immutable_dict.mutable() == deserialized_immutable_dict
dashboard_optional_utils.rest_response(True, "OK", data=immutable_dict)
dashboard_optional_utils.rest_response(True, "OK", **immutable_dict)
Expand All @@ -819,12 +819,12 @@ def test_immutable_types():

# Test get default immutable
immutable_default_value = immutable_dict.get("not exist list", [1, 2])
assert type(immutable_default_value) == dashboard_utils.ImmutableList
assert type(immutable_default_value) is dashboard_utils.ImmutableList

# Test recursive immutable
assert type(immutable_dict["list"]) == dashboard_utils.ImmutableList
assert type(immutable_dict["dict"]) == dashboard_utils.ImmutableDict
assert type(immutable_dict["list"][0]) == dashboard_utils.ImmutableDict
assert type(immutable_dict["list"]) is dashboard_utils.ImmutableList
assert type(immutable_dict["dict"]) is dashboard_utils.ImmutableDict
assert type(immutable_dict["list"][0]) is dashboard_utils.ImmutableDict

# Test exception
with pytest.raises(TypeError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def _cast_large_list_to_list(batch: pyarrow.Table):

for column_name in old_schema.names:
field_type = old_schema.field(column_name).type
if type(field_type) == pyarrow.lib.LargeListType:
if type(field_type) is pyarrow.lib.LargeListType:
value_type = field_type.value_type

if value_type == pyarrow.large_binary():
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,7 +1442,7 @@ def empty_pandas(batch):
block_refs = _ref_bundles_iterator_to_block_refs_list(bundles)

assert len(block_refs) == 1
assert type(ray.get(block_refs[0])) == pd.DataFrame
assert type(ray.get(block_refs[0])) is pd.DataFrame


def test_map_with_objects_and_tensors(ray_start_regular_shared):
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_numpy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def do_map_batches(data):


def assert_structure_equals(a, b):
assert type(a) == type(b), (type(a), type(b))
assert type(a) is type(b), (type(a), type(b))
assert type(a[0]) == type(b[0]), (type(a[0]), type(b[0])) # noqa: E721
assert a.dtype == b.dtype
assert a.shape == b.shape
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/multiplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ async def load_model(self, model_id: str) -> Any:
The user-constructed model object.
"""

if type(model_id) != str:
if type(model_id) is not str:
raise TypeError("The model ID must be a string.")

if not model_id:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tests/gcp/test_gcp_tpu_command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def test_max_active_connections_env_var():
cmd_runner = TPUCommandRunner(**args)
os.environ[ray_constants.RAY_TPU_MAX_CONCURRENT_CONNECTIONS_ENV_VAR] = "1"
num_connections = cmd_runner.num_connections
assert type(num_connections) == int
assert type(num_connections) is int
assert num_connections == 1


Expand Down
2 changes: 1 addition & 1 deletion python/ray/tests/modin/modin_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def df_equals(df1, df2):
if isinstance(df1, pandas.DataFrame) and isinstance(df2, pandas.DataFrame):
if (df1.empty and not df2.empty) or (df2.empty and not df1.empty):
assert False, "One of the passed frames is empty, when other isn't"
elif df1.empty and df2.empty and type(df1) != type(df2):
elif df1.empty and df2.empty and type(df1) is not type(df2):
assert (
False
), f"Empty frames have different types: {type(df1)} != {type(df2)}"
Expand Down
6 changes: 3 additions & 3 deletions python/ray/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,9 +866,9 @@ def temp():
assert ray.get(f.remote(s)) == s

# Test types.
assert ray.get(f.remote(int)) == int
assert ray.get(f.remote(float)) == float
assert ray.get(f.remote(str)) == str
assert ray.get(f.remote(int)) is int
assert ray.get(f.remote(float)) is float
assert ray.get(f.remote(str)) is str

class Foo:
def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tests/test_client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_client(address):
if address in ("local", None):
assert isinstance(builder, client_builder._LocalClientBuilder)
else:
assert type(builder) == client_builder.ClientBuilder
assert type(builder) is client_builder.ClientBuilder
assert builder.address == address.replace("ray://", "")


Expand Down
2 changes: 1 addition & 1 deletion python/ray/tests/test_joblib.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_ray_backend(shutdown_only):
from ray.util.joblib.ray_backend import RayBackend

with joblib.parallel_backend("ray"):
assert type(joblib.parallel.get_active_backend()[0]) == RayBackend
assert type(joblib.parallel.get_active_backend()[0]) is RayBackend


def test_svm_single_node(shutdown_only):
Expand Down
8 changes: 4 additions & 4 deletions python/ray/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
def is_named_tuple(cls):
"""Return True if cls is a namedtuple and False otherwise."""
b = cls.__bases__
if len(b) != 1 or b[0] != tuple:
if len(b) != 1 or b[0] is not tuple:
return False
f = getattr(cls, "_fields", None)
if not isinstance(f, tuple):
return False
return all(type(n) == str for n in f)
return all(type(n) is str for n in f)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -95,8 +95,8 @@ def f(x):
# TODO(rkn): The numpy dtypes currently come back as regular integers
# or floats.
if type(obj).__module__ != "numpy":
assert type(obj) == type(new_obj_1)
assert type(obj) == type(new_obj_2)
assert type(obj) is type(new_obj_1)
assert type(obj) is type(new_obj_2)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tune/tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def assertDictAlmostEqual(a, b):
assert k in b, f"Key {k} not found in {b}"
w = b[k]

assert type(v) == type(w), f"Type {type(v)} is not {type(w)}"
assert type(v) is type(w), f"Type {type(v)} is not {type(w)}"

if isinstance(v, dict):
assert assertDictAlmostEqual(v, w), f"Subdict {v} != {w}"
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/action/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def to_state(self):
@staticmethod
def from_state(ctx: ConnectorContext, params: Any):
assert (
type(params) == list
type(params) is list
), "ActionConnectorPipeline takes a list of connector params."
connectors = []
for state in params:
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/clip_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, ctx: ConnectorContext, sign=False, limit=None):
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
d = ac_data.data
assert (
type(d) == dict
type(d) is dict
), "Single agent data must be of type Dict[str, TensorStructType]"

if SampleBatch.REWARDS not in d:
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/mean_std_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
d = ac_data.data
assert (
type(d) == dict
type(d) is dict
), "Single agent data must be of type Dict[str, TensorStructType]"
if SampleBatch.OBS in d:
d[SampleBatch.OBS] = self.filter(
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/obs_preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def is_identity(self):

def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
d = ac_data.data
assert type(d) == dict, (
assert type(d) is dict, (
"Single agent data must be of type Dict[str, TensorStructType] but is of "
"type {}".format(type(d))
)
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def to_state(self):
@staticmethod
def from_state(ctx: ConnectorContext, params: List[Any]):
assert (
type(params) == list
type(params) is list
), "AgentConnectorPipeline takes a list of connector params."
connectors = []
for state in params:
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_agent_connectors_from_config(
clip_rewards = __clip_rewards(config)
if clip_rewards is True:
connectors.append(ClipRewardAgentConnector(ctx, sign=True))
elif type(clip_rewards) == float:
elif type(clip_rewards) is float:
connectors.append(ClipRewardAgentConnector(ctx, limit=abs(clip_rewards)))

if __preprocessing_enabled(config):
Expand Down
4 changes: 2 additions & 2 deletions rllib/env/wrappers/dm_control_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def _spec_to_box(spec):
def extract_min_max(s):
assert s.dtype == np.float64 or s.dtype == np.float32
dim = np.int_(np.prod(s.shape))
if type(s) == specs.Array:
if type(s) is specs.Array:
bound = np.inf * np.ones(dim, dtype=np.float32)
return -bound, bound
elif type(s) == specs.BoundedArray:
elif type(s) is specs.BoundedArray:
zeros = np.zeros(dim, dtype=np.float32)
return s.minimum + zeros, s.maximum + zeros

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_mixin_sampling_episodes(self):
for _ in range(20):
buffer.add(batch)
sample = buffer.sample(2)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
# One sample in the episode does not belong the the episode on thus
# gets dropped. Full episodes are of length two.
Expand All @@ -88,7 +88,7 @@ def test_mixin_sampling_sequences(self):
for _ in range(400):
buffer.add(batch)
sample = buffer.sample(10)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 2 * len(batch), delta=0.1)

Expand All @@ -113,7 +113,7 @@ def test_mixin_sampling_timesteps(self):
buffer.add(batch)
buffer.add(batch)
sample = buffer.sample(3)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 3.0, delta=0.2)

Expand All @@ -125,7 +125,7 @@ def test_mixin_sampling_timesteps(self):
for _ in range(100):
buffer.add(batch)
sample = buffer.sample(5)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 1.5, delta=0.2)

Expand All @@ -142,7 +142,7 @@ def test_mixin_sampling_timesteps(self):
for _ in range(100):
buffer.add(batch)
sample = buffer.sample(10)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 10.0, delta=0.2)

Expand All @@ -156,12 +156,12 @@ def test_mixin_sampling_timesteps(self):
buffer.add(batch)
# Expect exactly 1 batch to be returned.
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
self.assertTrue(len(sample) == 1)
# Expect exactly 0 sample to be returned (nothing new to be returned;
# no replay allowed (replay_ratio=0.0)).
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
assert len(sample.policy_batches) == 0
# If we insert and replay n times, expect roughly return batches of
# len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and 0 old samples
Expand All @@ -170,7 +170,7 @@ def test_mixin_sampling_timesteps(self):
for _ in range(100):
buffer.add(batch)
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 1.0, delta=0.2)

Expand All @@ -187,19 +187,19 @@ def test_mixin_sampling_timesteps(self):
buffer.add(batch)
# Expect exactly 1 sample to be returned (the new batch).
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
self.assertTrue(len(sample) == 1)
# Another replay -> Expect exactly 1 sample to be returned.
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
self.assertTrue(len(sample) == 1)
# If we replay n times, expect roughly return batches of
# len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old samples
# on average in each returned value).
results = []
for _ in range(100):
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 1.0)

Expand Down
Loading

0 comments on commit 96d64fa

Please sign in to comment.