Skip to content

Commit

Permalink
don't use cpu state_dict for model unpatching when executing on cpu (#…
Browse files Browse the repository at this point in the history
…6631)

Co-authored-by: Lincoln Stein <[email protected]>
  • Loading branch information
lstein and Lincoln Stein committed Jul 18, 2024
1 parent 0583101 commit 97a7f51
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def put(
size = calc_model_size_by_data(self.logger, model)
self.make_room(size)

state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
running_on_cpu = self.execution_device == torch.device("cpu")
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
Expand Down

0 comments on commit 97a7f51

Please sign in to comment.