Skip to content

Commit

Permalink
Merge pull request #84 from sony/feature/20241211-padim-on-cpu
Browse files Browse the repository at this point in the history
CPU support for PaDiM and PatchCore
  • Loading branch information
YukioOobuchi authored Dec 13, 2024
2 parents 227aa6f + f665a34 commit f20512d
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
4 changes: 2 additions & 2 deletions plugins/_Training/padim.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@


def get_model(model, batch_size, feature_ratio):
nn.set_default_context(get_extension_context('cudnn'))
try:
nn.set_default_context(get_extension_context('cudnn'))
x = nn.Variable()
F.relu(x)
except:
Expand Down Expand Up @@ -265,7 +265,7 @@ def func(args):
# Create anomary detection model
logger.log(99, 'Saving anomary detection model...')
contents = {
'global_config': {'default_context': get_extension_context('cudnn')},
'global_config': {'default_context': nn.get_current_context()},
'networks': [
{'name': 'network',
'batch_size': batch_size,
Expand Down
4 changes: 2 additions & 2 deletions plugins/_Training/padim_c1.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@


def get_model(model, batch_size, feature_ratio):
nn.set_default_context(get_extension_context('cudnn'))
try:
nn.set_default_context(get_extension_context('cudnn'))
x = nn.Variable()
F.relu(x)
except:
Expand Down Expand Up @@ -282,7 +282,7 @@ def func(args):
# Create anomary detection model
logger.log(99, 'Saving anomary detection model...')
contents = {
'global_config': {'default_context': get_extension_context('cudnn')},
'global_config': {'default_context': nn.get_current_context()},
'networks': [
{'name': 'network',
'batch_size': batch_size,
Expand Down
4 changes: 2 additions & 2 deletions plugins/_Training/patchcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@


def get_model(model, batch_size, feature_ratio):
nn.set_default_context(get_extension_context('cudnn'))
try:
nn.set_default_context(get_extension_context('cudnn'))
x = nn.Variable()
F.relu(x)
except:
Expand Down Expand Up @@ -177,7 +177,7 @@ def func(args):
# Create anomary detection model
logger.log(99, 'Saving anomary detection model...')
contents = {
'global_config': {'default_context': get_extension_context('cudnn')},
'global_config': {'default_context': nn.get_current_context()},
'networks': [
{'name': 'network',
'batch_size': 1,
Expand Down
6 changes: 4 additions & 2 deletions plugins/_Training/patchcore_c1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@
from nnabla import logger
import nnabla_ext.cpu
from nnabla.ext_utils import get_extension_context
from nnabla.models.utils import get_model_home
from nnabla.utils import nnabla_pb2
from nnabla.utils.data_iterator import data_iterator_csv_dataset
from nnabla.utils.download import download
from nnabla.utils.progress import configure_progress, progress
from nnabla.utils.save import save


def get_model(model, batch_size, feature_ratio):
nn.set_default_context(get_extension_context('cudnn'))
try:
nn.set_default_context(get_extension_context('cudnn'))
x = nn.Variable()
F.relu(x)
except:
Expand Down Expand Up @@ -192,7 +194,7 @@ def func(args):
# Create anomary detection model
logger.log(99, 'Saving anomary detection model...')
contents = {
'global_config': {'default_context': get_extension_context('cudnn')},
'global_config': {'default_context': nn.get_current_context()},
'networks': [
{'name': 'network',
'batch_size': 1,
Expand Down

0 comments on commit f20512d

Please sign in to comment.