diff --git a/op_builder/gds.py b/op_builder/gds.py index 01c2d5a245d1..e024674e01d8 100644 --- a/op_builder/gds.py +++ b/op_builder/gds.py @@ -36,13 +36,7 @@ def extra_ldflags(self): return super().extra_ldflags() + ['-lcufile'] def is_compatible(self, verbose=False): - try: - import torch.utils.cpp_extension - except ImportError: - if verbose: - self.warning("Please install torch if trying to pre-compile GDS") - return False - + import torch.utils.cpp_extension CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64") gds_compatible = self.has_function(funcname="cuFileDriverOpen",