Supports Intel XPU
DDXDB committed Jan 17, 2025
1 parent cb62fc9 commit a4d7bbe
Showing 13 changed files with 69 additions and 71 deletions.
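
In short, the commit swaps torch.cuda calls for their torch.xpu counterparts across device selection, memory cleanup, test gating, and the default configs. A minimal sketch of the selection order the modified get_device() helpers follow (an illustration only, assuming a PyTorch build recent enough to ship the torch.xpu backend, roughly 2.4 or newer):

import torch

def get_device() -> str:
    # Prefer an Intel XPU when the build exposes one, then Apple MPS, then CPU,
    # mirroring the get_device() helpers changed in this commit.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(get_device())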
2 changes: 1 addition & 1 deletion backend/configs/config.yaml
@@ -10,7 +10,7 @@ bgm_separation:
   # Whether to offload the model after the inference. Should be true if your setup has a VRAM less than <16GB
   enable_offload: true
   # Device to load BGM separation model
-  device: cuda
+  device: xpu
 
 # Settings that apply to the `cache' directory. The output files for `/bgm-separation` are stored in the `cache' directory,
 # (You can check out the actual generated files by testing `/bgm-separation`.)
2 changes: 1 addition & 1 deletion backend/tests/test_backend_bgm_separation.py
@@ -12,7 +12,7 @@
 )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip the test because CUDA is not available")
+@pytest.mark.skipif(not torch.xpu.is_available(), reason="Skip the test because XPU is not available")
 @pytest.mark.parametrize(
     "bgm_separation_params",
     [
47 changes: 23 additions & 24 deletions configs/default_parameters.yaml
@@ -1,64 +1,63 @@
 whisper:
-  model_size: "large-v2"
-  file_format: "SRT"
-  lang: "Automatic Detection"
+  model_size: large-v2
+  lang: chinese
   is_translate: false
   beam_size: 5
-  log_prob_threshold: -1
+  log_prob_threshold: -1.0
   no_speech_threshold: 0.6
   compute_type: float32
   best_of: 5
-  patience: 1
+  patience: 1.0
   condition_on_previous_text: true
   prompt_reset_on_temperature: 0.5
   initial_prompt: null
-  temperature: 0
+  temperature: 0.0
   compression_ratio_threshold: 2.4
-  chunk_length: 30
-  batch_size: 24
-  length_penalty: 1
-  repetition_penalty: 1
+  length_penalty: 1.0
+  repetition_penalty: 1.0
   no_repeat_ngram_size: 0
   prefix: null
   suppress_blank: true
-  suppress_tokens: "[-1]"
-  max_initial_timestamp: 1
+  suppress_tokens: '[-1]'
+  max_initial_timestamp: 1.0
   word_timestamps: false
-  prepend_punctuations: "\"'“¿([{-"
-  append_punctuations: "\"'.。,,!!??::”)]}、"
+  prepend_punctuations: '"''“¿([{-'
+  append_punctuations: '"''.。,,!!??::”)]}、'
   max_new_tokens: null
+  chunk_length: 30
   hallucination_silence_threshold: null
   hotwords: null
   language_detection_threshold: 0.5
   language_detection_segments: 1
+  batch_size: 24
   add_timestamp: true
 
+  file_format: SRT
 vad:
   vad_filter: false
   threshold: 0.5
   min_speech_duration_ms: 250
   max_speech_duration_s: 9999
   min_silence_duration_ms: 1000
   speech_pad_ms: 2000
 
 diarization:
   is_diarize: false
-  hf_token: ""
 
+  diarization_device: xpu
+  hf_token: ''
 bgm_separation:
   is_separate_bgm: false
-  uvr_model_size: "UVR-MDX-NET-Inst_HQ_4"
+  uvr_model_size: UVR-MDX-NET-Inst_HQ_4
+  uvr_device: xpu
   segment_size: 256
   save_file: false
   enable_offload: true
 
 translation:
   deepl:
-    api_key: ""
+    api_key: ''
     is_pro: false
-    source_lang: "Automatic Detection"
-    target_lang: "English"
+    source_lang: Automatic Detection
+    target_lang: English
   nllb:
-    model_size: "facebook/nllb-200-1.3B"
+    model_size: facebook/nllb-200-1.3B
     source_lang: null
     target_lang: null
     max_length: 200
14 changes: 7 additions & 7 deletions modules/diarize/diarizer.py
@@ -129,15 +129,15 @@ def offload(self):
         if self.pipe is not None:
             del self.pipe
             self.pipe = None
-        if self.device == "cuda":
-            torch.cuda.empty_cache()
-            torch.cuda.reset_max_memory_allocated()
+        if self.device == "xpu":
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()
         gc.collect()
 
     @staticmethod
     def get_device():
-        if torch.cuda.is_available():
-            return "cuda"
+        if torch.xpu.is_available():
+            return "xpu"
         elif torch.backends.mps.is_available():
             return "mps"
         else:
@@ -146,8 +146,8 @@ def get_device():
     @staticmethod
     def get_available_device():
         devices = ["cpu"]
-        if torch.cuda.is_available():
-            devices.append("cuda")
+        if torch.xpu.is_available():
+            devices.append("xpu")
         elif torch.backends.mps.is_available():
             devices.append("mps")
         return devices
18 changes: 9 additions & 9 deletions modules/translation/translation_base.py
@@ -127,31 +127,31 @@ def translate_file(self,
             print(f"Error translating file: {e}")
             raise
         finally:
-            self.release_cuda_memory()
+            self.release_xpu_memory()
 
     def offload(self):
         """Offload the model and free up the memory"""
         if self.model is not None:
             del self.model
             self.model = None
-        if self.device == "cuda":
-            self.release_cuda_memory()
+        if self.device == "xpu":
+            self.release_xpu_memory()
         gc.collect()
 
     @staticmethod
     def get_device():
-        if torch.cuda.is_available():
-            return "cuda"
+        if torch.xpu.is_available():
+            return "xpu"
         elif torch.backends.mps.is_available():
             return "mps"
         else:
             return "cpu"
 
     @staticmethod
-    def release_cuda_memory():
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.reset_max_memory_allocated()
+    def release_xpu_memory():
+        if torch.xpu.is_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()
 
     @staticmethod
     def remove_input_files(file_paths: List[str]):
8 changes: 4 additions & 4 deletions modules/uvr/music_separator.py
@@ -20,7 +20,7 @@ def __init__(self,
                  output_dir: Optional[str] = UVR_OUTPUT_DIR):
         self.model = None
         self.device = self.get_device()
-        self.available_devices = ["cpu", "cuda"]
+        self.available_devices = ["cpu", "xpu"]
         self.model_dir = model_dir
         self.output_dir = output_dir
         instrumental_output_dir = os.path.join(self.output_dir, "instrumental")
@@ -159,15 +159,15 @@ def separate_files(self,
     @staticmethod
     def get_device():
         """Get device for the model"""
-        return "cuda" if torch.cuda.is_available() else "cpu"
+        return "xpu" if torch.xpu.is_available() else "cpu"
 
     def offload(self):
         """Offload the model and free up the memory"""
         if self.model is not None:
             del self.model
             self.model = None
-        if self.device == "cuda":
-            torch.cuda.empty_cache()
+        if self.device == "xpu":
+            torch.xpu.empty_cache()
         gc.collect()
         self.audio_info = None
22 changes: 11 additions & 11 deletions modules/whisper/base_transcription_pipeline.py
@@ -265,7 +265,7 @@ def transcribe_file(self,
         except Exception as e:
             raise RuntimeError(f"Error transcribing file: {e}") from e
         finally:
-            self.release_cuda_memory()
+            self.release_xpu_memory()
 
     def transcribe_mic(self,
                        mic_audio: str,
@@ -328,7 +328,7 @@ def transcribe_mic(self,
         except Exception as e:
             raise RuntimeError(f"Error transcribing mic: {e}") from e
         finally:
-            self.release_cuda_memory()
+            self.release_xpu_memory()
 
     def transcribe_youtube(self,
                            youtube_link: str,
@@ -400,7 +400,7 @@ def transcribe_youtube(self,
         except Exception as e:
             raise RuntimeError(f"Error transcribing youtube: {e}") from e
         finally:
-            self.release_cuda_memory()
+            self.release_xpu_memory()
 
     def get_compute_type(self):
         if "float16" in self.available_compute_types:
@@ -421,8 +421,8 @@ def offload(self):
         if self.model is not None:
             del self.model
             self.model = None
-        if self.device == "cuda":
-            self.release_cuda_memory()
+        if self.device == "xpu":
+            self.release_xpu_memory()
         gc.collect()
 
     @staticmethod
@@ -454,8 +454,8 @@ def format_time(elapsed_time: float) -> str:
 
     @staticmethod
     def get_device():
-        if torch.cuda.is_available():
-            return "cuda"
+        if torch.xpu.is_available():
+            return "xpu"
         elif torch.backends.mps.is_available():
             if not BaseTranscriptionPipeline.is_sparse_api_supported():
                 # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
@@ -482,11 +482,11 @@ def is_sparse_api_supported():
         return False
 
     @staticmethod
-    def release_cuda_memory():
+    def release_xpu_memory():
         """Release memory"""
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.reset_max_memory_allocated()
+        if torch.xpu.is_available():
+            torch.xpu.empty_cache()
+            torch.xpu.max_memory_allocated()
 
     @staticmethod
     def remove_input_files(file_paths: List[str]):
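
The release_xpu_memory() helper above mirrors the old CUDA version; a more defensive variant would only call the memory-stats APIs that the installed PyTorch build actually exposes. A hypothetical sketch (not part of the commit; the probed reset_peak_memory_stats name is an assumption about newer torch.xpu builds):

import gc
import torch

def release_xpu_memory() -> None:
    # Free cached allocations on the Intel XPU when one is present.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
        # Peak-memory reset helpers vary between PyTorch releases; probe before calling.
        reset = getattr(torch.xpu, "reset_peak_memory_stats", None)
        if callable(reset):
            reset()
    gc.collect()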
8 changes: 4 additions & 4 deletions modules/whisper/data_classes.py
@@ -156,7 +156,7 @@ def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components
 class DiarizationParams(BaseParams):
     """Speaker diarization parameters"""
     is_diarize: bool = Field(default=False, description="Enable speaker diarization")
-    diarization_device: str = Field(default="cuda", description="Device to run Diarization model.")
+    diarization_device: str = Field(default="xpu", description="Device to run Diarization model.")
     hf_token: str = Field(
         default="",
         description="Hugging Face token for downloading diarization models"
@@ -174,7 +174,7 @@ def to_gradio_inputs(cls,
             ),
             gr.Dropdown(
                 label=_("Device"),
-                choices=["cpu", "cuda"] if available_devices is None else available_devices,
+                choices=["cpu", "xpu"] if available_devices is None else available_devices,
                 value=defaults.get("device", device),
             ),
             gr.Textbox(
@@ -192,7 +192,7 @@ class BGMSeparationParams(BaseParams):
         default="UVR-MDX-NET-Inst_HQ_4",
         description="UVR model size"
     )
-    uvr_device: str = Field(default="cuda", description="Device to run UVR model.")
+    uvr_device: str = Field(default="xpu", description="Device to run UVR model.")
     segment_size: int = Field(
         default=256,
         gt=0,
@@ -228,7 +228,7 @@ def to_gradio_input(cls,
             ),
             gr.Dropdown(
                 label=_("Device"),
-                choices=["cpu", "cuda"] if available_devices is None else available_devices,
+                choices=["cpu", "xpu"] if available_devices is None else available_devices,
                 value=defaults.get("device", device),
             ),
             gr.Number(
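
Because the Pydantic defaults above now point at "xpu", constructing these params without an explicit device targets the Intel GPU. A small usage sketch (assuming the repository root is on sys.path; the token value is a placeholder):

from modules.whisper.data_classes import DiarizationParams, BGMSeparationParams

diarize_params = DiarizationParams(hf_token="hf_...")  # diarization_device defaults to "xpu"
bgm_params = BGMSeparationParams()                      # uvr_device defaults to "xpu"
print(diarize_params.diarization_device, bgm_params.uvr_device)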
4 changes: 2 additions & 2 deletions modules/whisper/faster_whisper_inference.py
@@ -183,8 +183,8 @@ def get_model_paths(self):
 
     @staticmethod
     def get_device():
-        if torch.cuda.is_available():
-            return "cuda"
+        if torch.xpu.is_available():
+            return "xpu"
         else:
             return "auto"
 
5 changes: 2 additions & 3 deletions requirements.txt
@@ -5,10 +5,9 @@
 --extra-index-url https://download.pytorch.org/whl/cu124
 
 
-torch
-torchaudio
 
 git+https://github.com/jhj0517/jhj0517-whisper.git
-faster-whisper==1.1.1
+#faster-whisper==1.1.1
 transformers
 gradio
 gradio-i18n
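
With torch and torchaudio removed from requirements.txt, an XPU-capable PyTorch build presumably has to be installed separately (the cu124 index above serves CUDA wheels). A quick sanity check that the installed build actually exposes the XPU backend (a sketch, assuming PyTorch 2.4+ with Intel GPU support):

import torch

print("torch:", torch.__version__)
has_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
print("XPU available:", has_xpu)
if has_xpu:
    print("device:", torch.xpu.get_device_name(0))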
4 changes: 2 additions & 2 deletions tests/test_bgm_separation.py
@@ -11,7 +11,7 @@
 
 
 @pytest.mark.skipif(
-    not is_cuda_available(),
+    not is_xpu_available(),
     reason="Skipping because the test only works on GPU"
 )
 @pytest.mark.parametrize(
@@ -32,7 +32,7 @@ def test_bgm_separation_pipeline(
 
 
 @pytest.mark.skipif(
-    not is_cuda_available(),
+    not is_xpu_available(),
     reason="Skipping because the test only works on GPU"
 )
 @pytest.mark.parametrize(
4 changes: 2 additions & 2 deletions tests/test_config.py
@@ -21,8 +21,8 @@
 
 
 @functools.lru_cache
-def is_cuda_available():
-    return torch.cuda.is_available()
+def is_xpu_available():
+    return torch.xpu.is_available()
 
 
 @functools.lru_cache
2 changes: 1 addition & 1 deletion tests/test_diarization.py
@@ -10,7 +10,7 @@
 
 
 @pytest.mark.skipif(
-    not is_cuda_available(),
+    not is_xpu_available(),
     reason="Skipping because the test only works on GPU"
 )
 @pytest.mark.parametrize(
