@@ -20,6 +20,16 @@ def __init__(self):
20
20
super ().__init__ ()
21
21
self ._forward_method = self .dispatch_forward ()
22
22
23
+ @classmethod
24
+ def set_foward_method (cls , method ):
25
+ """Provide a way to register a custom forward method for a specific
26
+ backend."""
27
+ if getattr (cls , f"forward_{ current_platform .device_name } " , None ):
28
+ raise ValueError (
29
+ f"Custom op { cls .__class__ .__name__ } already has a "
30
+ f"forward_{ current_platform .device_name } method" )
31
+ setattr (cls , f"forward_{ current_platform .device_name } " , method )
32
+
23
33
def forward (self , * args , ** kwargs ):
24
34
return self ._forward_method (* args , ** kwargs )
25
35
@@ -72,18 +82,15 @@ def dispatch_forward(self):
72
82
if not enabled :
73
83
return self .forward_native
74
84
75
- if current_platform .is_rocm ():
76
- return self .forward_hip
77
- elif current_platform .is_cpu ():
78
- return self .forward_cpu
79
- elif current_platform .is_hpu ():
80
- return self .forward_hpu
81
- elif current_platform .is_tpu ():
82
- return self .forward_tpu
83
- elif current_platform .is_xpu ():
84
- return self .forward_xpu
85
- else :
86
- return self .forward_cuda
85
+ custom_forward_func = \
86
+ getattr (self , f"forward_{ current_platform .device_name } " , None )
87
+ if not custom_forward_func :
88
+ logger .warning (
89
+ "Custom op %s is not supported on %s, falling back "
90
+ "to native." , self .__class__ .__name__ ,
91
+ current_platform .device_name )
92
+ return self .forward_native
93
+ return custom_forward_func
87
94
88
95
@classmethod
89
96
def enabled (cls ) -> bool :
0 commit comments