From 295531593698885902b37b210668233cacdd15e8 Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Mon, 26 Feb 2024 15:01:31 +0800 Subject: [PATCH] add conditional placement of feature column on cpu --- easy_rec/python/compat/feature_column/feature_column.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/easy_rec/python/compat/feature_column/feature_column.py b/easy_rec/python/compat/feature_column/feature_column.py index 3af58f949..c7ce6199e 100644 --- a/easy_rec/python/compat/feature_column/feature_column.py +++ b/easy_rec/python/compat/feature_column/feature_column.py @@ -171,6 +171,7 @@ from tensorflow.python.util import nest from easy_rec.python.compat.feature_column import utils as fc_utils +from easy_rec.python.utils import conditional from easy_rec.python.utils import constant from easy_rec.python.utils import embedding_utils @@ -634,7 +635,7 @@ def _get_var_type(column): if embedding_utils.is_embedding_parallel(): return _get_logits_embedding_parallel() else: - with ops.device('/CPU:0'): + with conditional(embedding_utils.embedding_on_cpu(), '/cpu:0'): return _get_logits()