Skip to content

Commit

Permalink
Merge pull request BVLC#1083 from longjon/fix-solver-gpu-init
Browse files Browse the repository at this point in the history
Fix solver GPU initialization order (e.g., training with cuDNN on non-default device)
  • Loading branch information
shelhamer committed Sep 15, 2014
2 parents 2da6bc9 + bbd166e commit 1f4e039
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
5 changes: 0 additions & 5 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
LOG(INFO) << "Initializing solver from parameters: " << std::endl
<< param.DebugString();
param_ = param;
if (param_.solver_mode() == SolverParameter_SolverMode_GPU &&
param_.has_device_id()) {
Caffe::SetDevice(param_.device_id());
}
Caffe::set_mode(Caffe::Brew(param_.solver_mode()));
if (param_.random_seed() >= 0) {
Caffe::set_random_seed(param_.random_seed());
}
Expand Down
15 changes: 11 additions & 4 deletions tools/caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,27 @@ int train() {
caffe::SolverParameter solver_param;
caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param);

LOG(INFO) << "Starting Optimization";
shared_ptr<caffe::Solver<float> >
solver(caffe::GetSolver<float>(solver_param));
// If the gpu flag is not provided, allow the mode and device to be set
// in the solver prototxt.
if (FLAGS_gpu < 0
&& solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) {
FLAGS_gpu = solver_param.device_id();
}

// Set device id and mode
if (FLAGS_gpu >= 0) {
LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu;
Caffe::SetDevice(FLAGS_gpu);
Caffe::set_mode(Caffe::GPU);
} else if (!solver_param.has_solver_mode()) {
} else {
LOG(INFO) << "Use CPU.";
Caffe::set_mode(Caffe::CPU);
}

LOG(INFO) << "Starting Optimization";
shared_ptr<caffe::Solver<float> >
solver(caffe::GetSolver<float>(solver_param));

if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Solve(FLAGS_snapshot);
Expand Down

0 comments on commit 1f4e039

Please sign in to comment.