Skip to content

Commit d65a89a

Browse files
committed
Rebase
1 parent 5dd987a commit d65a89a

File tree

3 files changed

+15
-21
lines changed

3 files changed

+15
-21
lines changed

python/tvm/relax/backend/adreno/clml.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,6 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
134134
if weight_dtype not in ["float32", "float16"]:
135135
return False
136136

137-
if "pad_val" in context.annotated_expr:
138-
pad_expr = context.annotated_expr["pad_val"]
139-
if 0 != pad_expr.data.numpy():
140-
return False
141-
142137
return True
143138

144139
def populate_patterns(patterns, name, op, annotations, *args):
@@ -159,7 +154,6 @@ def conv_pattern():
159154
bn_bias = is_const()
160155
bn_mean = is_const()
161156
bn_var = is_const()
162-
pad_val = is_const()
163157

164158
annotations = {
165159
"data": data,
@@ -173,9 +167,8 @@ def conv_pattern():
173167
}
174168

175169
pad_annotations = annotations.copy()
176-
pad_annotations.update({"pad_val": pad_val})
177170
patterns["pad.nn.conv2d"] = {
178-
"pattern": is_op("relax.nn.conv2d")(is_op("relax.nn.pad")(data, pad_val), weight),
171+
"pattern": is_op("relax.nn.conv2d")(is_op("relax.nn.pad")(data), weight),
179172
"annotation": pad_annotations,
180173
}
181174

python/tvm/relax/backend/adreno/pipeline.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,22 @@
2323
def library_dispatch_passes(target: tvm.target.Target): # pylint: disable=unused-argument
2424
"""The default library dispatch passes for Adreno GPU backend."""
2525
if "clml" in target.keys:
26-
return [tvm.relax.backend.adreno.clml.OpenCLMLOffLoad()]
26+
return [relax.backend.adreno.clml.OpenCLMLOffLoad()]
2727
else:
2828
return []
2929

3030

3131
def legalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument
3232
"""The default legalization passes for Adreno GPU backend."""
3333
return [
34-
tvm.relax.transform.DecomposeOpsForInference(),
35-
tvm.relax.transform.FoldConstant(),
36-
tvm.relax.transform.LegalizeOps(),
37-
tvm.relax.transform.AnnotateTIROpPattern(),
38-
tvm.relax.transform.FoldConstant(),
39-
tvm.relax.transform.FuseOps(),
40-
tvm.relax.transform.FuseTIR(),
41-
tvm.relax.transform.DeadCodeElimination(),
34+
relax.transform.DecomposeOpsForInference(),
35+
relax.transform.FoldConstant(),
36+
relax.transform.LegalizeOps(),
37+
relax.transform.AnnotateTIROpPattern(),
38+
relax.transform.FoldConstant(),
39+
relax.transform.FuseOps(),
40+
relax.transform.FuseTIR(),
41+
relax.transform.DeadCodeElimination(),
4242
dl.ApplyDefaultSchedule(
4343
dl.gpu.Reduction(),
4444
dl.gpu.GeneralReduction(),
@@ -49,12 +49,12 @@ def legalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume
4949

5050
def dataflow_lower_passes(target: tvm.target.Target): # pylint: disable=unused-argument
5151
"""The default dataflow lowering passes for Adreno GPU backend."""
52-
return tvm.relax.backend.gpu_generic.library_dispatch_passes(target)
52+
return relax.backend.gpu_generic.library_dispatch_passes(target)
5353

5454

5555
def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument
5656
"""The default finalization passes for Adreno GPU backend."""
57-
return tvm.relax.backend.gpu_generic.finalize_passes(target)
57+
return relax.backend.gpu_generic.finalize_passes(target)
5858

5959

6060
def get_default_pipeline(target: tvm.target.Target):

tests/python/relax/backend/clml/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import sys
2020
import tvm
2121
import pytest
22-
from tvm.autotvm.measure import request_remote
22+
from tvm import rpc as _rpc
2323

2424

2525
@pytest.fixture(scope="session")
@@ -33,6 +33,7 @@ def rpc():
3333
target_host = "llvm -mtriple=aarch64-linux-gnu"
3434
device_key = os.getenv("RPC_DEVICE_KEY", "android")
3535
cross_compile = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++")
36-
return request_remote(device_key, host, port, timeout=1000)
36+
tracker = _rpc.connect_tracker(host, port)
37+
return tracker.request(device_key, priority=1, session_timeout=1000)
3738
else:
3839
return None

0 commit comments

Comments
 (0)