Skip to content

Commit

Permalink
Update extract plugin VRAM amounts for pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Mar 3, 2024
1 parent 53b89e3 commit 9e68532
Show file tree
Hide file tree
Showing 9 changed files with 16 additions and 28 deletions.
4 changes: 0 additions & 4 deletions plugins/extract/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,6 @@ class Extractor():
vram: int
Approximate VRAM used by the model at :attr:`input_size`. Used to calculate the
:attr:`batchsize`. Be conservative to avoid OOM.
vram_warnings: int
Approximate VRAM used by the model at :attr:`input_size` that will still run, but generates
warnings. Used to calculate the :attr:`batchsize`. Be conservative to avoid OOM.
vram_per_batch: int
Approximate additional VRAM used by the model for each additional batch. Used to calculate
the :attr:`batchsize`. Be conservative to avoid OOM.
Expand Down Expand Up @@ -173,7 +170,6 @@ def __init__(self,
self.input_size = 0
self.color_format: T.Literal["BGR", "RGB", "GRAY"] = "BGR"
self.vram = 0
self.vram_warnings = 0 # Will run at this with warnings
self.vram_per_batch = 0

# << THE FOLLOWING ARE SET IN self.initialize METHOD >> #
Expand Down
5 changes: 2 additions & 3 deletions plugins/extract/align/fan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ def __init__(self, **kwargs) -> None:
self.name = "FAN"
self.input_size = 256
self.color_format = "RGB"
self.vram = 2240
self.vram_warnings = 512 # Will run at this with warnings
self.vram_per_batch = 64
self.vram = 896 # 810 in testing
self.vram_per_batch = 768 # ~720 in testing
self.realign_centering = "head"
self.batchsize: int = self.config["batch-size"]
self.reference_scale = 200. / 195.
Expand Down
5 changes: 2 additions & 3 deletions plugins/extract/detect/mtcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ def __init__(self, **kwargs) -> None:
super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
self.name = "MTCNN"
self.input_size = 640
self.vram = 320 if not self.config["cpu"] else 0
self.vram_warnings = 64 if not self.config["cpu"] else 0 # Will run at this with warnings
self.vram_per_batch = 32 if not self.config["cpu"] else 0
self.vram = 128 if not self.config["cpu"] else 0 # 66 in testing
self.vram_per_batch = 64 if not self.config["cpu"] else 0 # ~50 in testing
self.batchsize = self.config["batch-size"]
self.kwargs = self._validate_kwargs()
self.color_format = "RGB"
Expand Down
5 changes: 2 additions & 3 deletions plugins/extract/detect/s3fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ def __init__(self, **kwargs) -> None:
super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
self.name = "S3FD"
self.input_size = 640
self.vram = 4112
self.vram_warnings = 1024 # Will run at this with warnings
self.vram_per_batch = 208
self.vram = 1088 # 1034 in testing
self.vram_per_batch = 960 # 922 in testing
self.batchsize = self.config["batch-size"]

def init_model(self) -> None:
Expand Down
5 changes: 2 additions & 3 deletions plugins/extract/mask/bisenet_fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ def __init__(self, **kwargs) -> None:
self.name = "BiSeNet - Face Parsing"
self.input_size = 512
self.color_format = "RGB"
self.vram = 2304 if not self.config["cpu"] else 0
self.vram_warnings = 256 if not self.config["cpu"] else 0
self.vram_per_batch = 64 if not self.config["cpu"] else 0
self.vram = 384 if not self.config["cpu"] else 0 # 378 in testing
self.vram_per_batch = 384 if not self.config["cpu"] else 0 # ~328 in testing
self.batchsize = self.config["batch-size"]

self._segment_indices = self._get_segment_indices()
Expand Down
5 changes: 2 additions & 3 deletions plugins/extract/mask/unet_dfl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ def __init__(self, **kwargs) -> None:
self.model: KSession
self.name = "U-Net"
self.input_size = 256
self.vram = 3424
self.vram_warnings = 256
self.vram_per_batch = 80
self.vram = 320 # 276 in testing
self.vram_per_batch = 256 # ~215 in testing
self.batchsize = self.config["batch-size"]
self._storage_centering = "legacy"

Expand Down
5 changes: 2 additions & 3 deletions plugins/extract/mask/vgg_clear.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ def __init__(self, **kwargs) -> None:
self.model: KSession
self.name = "VGG Clear"
self.input_size = 300
self.vram = 2944
self.vram_warnings = 1088 # at BS 1. OOMs at higher batch sizes
self.vram_per_batch = 400
self.vram = 1344 # 1308 in testing
self.vram_per_batch = 448 # ~402 in testing
self.batchsize = self.config["batch-size"]

def init_model(self) -> None:
Expand Down
5 changes: 2 additions & 3 deletions plugins/extract/mask/vgg_obstructed.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ def __init__(self, **kwargs) -> None:
self.model: KSession
self.name = "VGG Obstructed"
self.input_size = 500
self.vram = 3936
self.vram_warnings = 1088 # at BS 1. OOMs at higher batch sizes
self.vram_per_batch = 304
self.vram = 1728 # 1710 in testing
self.vram_per_batch = 896 # ~886 in testing
self.batchsize = self.config["batch-size"]

def init_model(self) -> None:
Expand Down
5 changes: 2 additions & 3 deletions plugins/extract/recognition/vgg_face2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ def __init__(self, *args, **kwargs) -> None: # pylint:disable=unused-argument
self.input_size = 224
self.color_format = "BGR"

self.vram = 2468 if not self.config["cpu"] else 0
self.vram_warnings = 192 if not self.config["cpu"] else 0
self.vram_per_batch = 32 if not self.config["cpu"] else 0
self.vram = 384 if not self.config["cpu"] else 0 # 334 in testing
self.vram_per_batch = 192 if not self.config["cpu"] else 0 # ~155 in testing
self.batchsize = self.config["batch-size"]

# Average image provided in https://github.com/ox-vgg/vgg_face2
Expand Down

0 comments on commit 9e68532

Please sign in to comment.