Skip to content

Commit

Permalink
Big changes to detection
Browse files Browse the repository at this point in the history
  • Loading branch information
pjreddie committed Mar 4, 2015
1 parent 5f4a5f5 commit fb9e0fe
Show file tree
Hide file tree
Showing 17 changed files with 298 additions and 151 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
*.dSYM
*.csv
*.out
*.png
*.sh
mnist/
data/
caffe/
grasp/
images/
opencv/
convnet/
decaf/
submission/
cfg/
darknet

# OS Generated #
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ OBJDIR=./obj/
CC=gcc
NVCC=nvcc
OPTS=-O3
LDFLAGS=`pkg-config --libs opencv` -lm -pthread
LDFLAGS=`pkg-config --libs opencv` -lm -pthread -lstdc++
COMMON=`pkg-config --cflags opencv` -I/usr/local/cuda/include/
CFLAGS=-Wall -Wfatal-errors

Expand All @@ -25,7 +25,7 @@ CFLAGS+=-DGPU
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas
endif

OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o detection_layer.o
ifeq ($(GPU), 1)
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
endif
Expand Down
33 changes: 20 additions & 13 deletions src/cost_layer.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
COST_TYPE get_cost_type(char *s)
{
if (strcmp(s, "sse")==0) return SSE;
if (strcmp(s, "detection")==0) return DETECTION;
fprintf(stderr, "Couldn't find activation function %s, going with SSE\n", s);
return SSE;
}
Expand All @@ -20,8 +19,6 @@ char *get_cost_string(COST_TYPE a)
switch(a){
case SSE:
return "sse";
case DETECTION:
return "detection";
}
return "sse";
}
Expand All @@ -41,17 +38,20 @@ cost_layer *make_cost_layer(int batch, int inputs, COST_TYPE type)
return layer;
}

void pull_cost_layer(cost_layer layer)
{
cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
}
void push_cost_layer(cost_layer layer)
{
cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
}

void forward_cost_layer(cost_layer layer, float *input, float *truth)
{
if (!truth) return;
copy_cpu(layer.batch*layer.inputs, truth, 1, layer.delta, 1);
axpy_cpu(layer.batch*layer.inputs, -1, input, 1, layer.delta, 1);
if(layer.type == DETECTION){
int i;
for(i = 0; i < layer.batch*layer.inputs; ++i){
if((i%25) && !truth[(i/25)*25]) layer.delta[i] = 0;
}
}
*(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1);
//printf("cost: %f\n", *layer.output);
}
Expand All @@ -66,14 +66,21 @@ void backward_cost_layer(const cost_layer layer, float *input, float *delta)
void forward_cost_layer_gpu(cost_layer layer, float * input, float * truth)
{
if (!truth) return;

/*
float *in = calloc(layer.inputs*layer.batch, sizeof(float));
float *t = calloc(layer.inputs*layer.batch, sizeof(float));
cuda_pull_array(input, in, layer.batch*layer.inputs);
cuda_pull_array(truth, t, layer.batch*layer.inputs);
forward_cost_layer(layer, in, t);
cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
free(in);
free(t);
*/

copy_ongpu(layer.batch*layer.inputs, truth, 1, layer.delta_gpu, 1);
axpy_ongpu(layer.batch*layer.inputs, -1, input, 1, layer.delta_gpu, 1);

if(layer.type==DETECTION){
mask_ongpu(layer.inputs*layer.batch, layer.delta_gpu, truth, 25);
}

cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
*(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1);
//printf("cost: %f\n", *layer.output);
Expand Down
4 changes: 3 additions & 1 deletion src/cost_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
#define COST_LAYER_H

typedef enum{
SSE, DETECTION
SSE
} COST_TYPE;

typedef struct {
int inputs;
int batch;
int coords;
int classes;
float *delta;
float *output;
COST_TYPE type;
Expand Down
2 changes: 2 additions & 0 deletions src/cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ int gpu_index = 0;
#include "cuda.h"
#include "utils.h"
#include "blas.h"
#include "assert.h"
#include <stdlib.h>


Expand All @@ -15,6 +16,7 @@ void check_error(cudaError_t status)
const char *s = cudaGetErrorString(status);
char buffer[256];
printf("CUDA Error: %s\n", s);
assert(0);
snprintf(buffer, 256, "CUDA Error: %s", s);
error(buffer);
}
Expand Down
82 changes: 35 additions & 47 deletions src/darknet.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,42 +36,30 @@ char *class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
void draw_detection(image im, float *box, int side)
{
int classes = 20;
int elems = 4+classes+1;
int elems = 4+classes;
int j;
int r, c;
float amount[AMNT] = {0};
for(r = 0; r < side*side; ++r){
float val = box[r*elems];
for(j = 0; j < AMNT; ++j){
if(val > amount[j]) {
float swap = val;
val = amount[j];
amount[j] = swap;
}
}
}
float smallest = amount[AMNT-1];

for(r = 0; r < side; ++r){
for(c = 0; c < side; ++c){
j = (r*side + c) * elems;
//printf("%d\n", j);
//printf("Prob: %f\n", box[j]);
if(box[j] >= smallest){
int class = max_index(box+j+1, classes);
int z;
for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+1+z], class_names[z]);
printf("%f %s\n", box[j+1+class], class_names[class]);
int class = max_index(box+j, classes);
if(box[j+class] > .02 || 1){
//int z;
//for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]);
printf("%f %s\n", box[j+class], class_names[class]);
float red = get_color(0,class,classes);
float green = get_color(1,class,classes);
float blue = get_color(2,class,classes);

j += classes;
int d = im.w/side;
int y = r*d+box[j+1]*d;
int x = c*d+box[j+2]*d;
int h = box[j+3]*im.h;
int w = box[j+4]*im.w;
int y = r*d+box[j]*d;
int x = c*d+box[j+1]*d;
int h = box[j+2]*im.h;
int w = box[j+3]*im.w;
draw_box(im, x-w/2, y-h/2, x+w/2, y+h/2,red,green,blue);
}
}
Expand Down Expand Up @@ -117,29 +105,30 @@ void train_detection_net(char *cfgfile, char *weightfile)
data train, buffer;
int im_dim = 512;
int jitter = 64;
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, 20, im_dim, im_dim, 7, 7, jitter, &buffer);
int classes = 21;
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, jitter, &buffer);
clock_t time;
while(1){
i += 1;
time=clock();
pthread_join(load_thread, 0);
train = buffer;
load_thread = load_data_detection_thread(imgs, paths, plist->size, 20, im_dim, im_dim, 7, 7, jitter, &buffer);
load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, jitter, &buffer);

/*
image im = float_to_image(im_dim - jitter, im_dim-jitter, 3, train.X.vals[0]);
draw_detection(im, train.y.vals[0], 7);
show_image(im, "truth");
cvWaitKey(0);
*/
/*
image im = float_to_image(im_dim - jitter, im_dim-jitter, 3, train.X.vals[0]);
draw_detection(im, train.y.vals[0], 7);
show_image(im, "truth");
cvWaitKey(0);
*/

printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
float loss = train_network(net, train);
net.seen += imgs;
avg_loss = avg_loss*.9 + loss*.1;
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
if(i%800==0){
if(i%100==0){
char buff[256];
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
save_weights(net, buff);
Expand All @@ -161,7 +150,7 @@ void validate_detection_net(char *cfgfile, char *weightfile)
char **paths = (char **)list_to_array(plist);
int num_output = 1225;
int im_size = 448;
int classes = 20;
int classes = 21;

int m = plist->size;
int i = 0;
Expand All @@ -180,30 +169,29 @@ void validate_detection_net(char *cfgfile, char *weightfile)
num = (i+1)*m/splits - i*m/splits;
char **part = paths+(i*m/splits);
if(i != splits) load_thread = load_data_thread(part, num, 0, 0, num_output, im_size, im_size, &buffer);

fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time));
matrix pred = network_predict_data(net, val);
int j, k, class;
for(j = 0; j < pred.rows; ++j){
for(k = 0; k < pred.cols; k += classes+4+1){
for(k = 0; k < pred.cols; k += classes+4){

/*
int z;
for(z = 0; z < 25; ++z) printf("%f, ", pred.vals[j][k+z]);
printf("\n");
*/
int z;
for(z = 0; z < 25; ++z) printf("%f, ", pred.vals[j][k+z]);
printf("\n");
*/

float p = pred.vals[j][k];
//if (pred.vals[j][k] > .001){
for(class = 0; class < classes; ++class){
int index = (k)/(classes+4+1);
for(class = 0; class < classes-1; ++class){
int index = (k)/(classes+4);
int r = index/7;
int c = index%7;
float y = (r + pred.vals[j][k+1+classes])/7.;
float x = (c + pred.vals[j][k+2+classes])/7.;
float h = pred.vals[j][k+3+classes];
float w = pred.vals[j][k+4+classes];
printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, p*pred.vals[j][k+class+1], y, x, h, w);
float y = (r + pred.vals[j][k+0+classes])/7.;
float x = (c + pred.vals[j][k+1+classes])/7.;
float h = pred.vals[j][k+2+classes];
float w = pred.vals[j][k+3+classes];
printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, pred.vals[j][k+class], y, x, h, w);
}
//}
}
Expand Down Expand Up @@ -462,7 +450,7 @@ void test_detection(char *cfgfile, char *weightfile)
if(weightfile){
load_weights(&net, weightfile);
}
int im_size = 224;
int im_size = 448;
set_batch_network(&net, 1);
srand(2222222);
clock_t time;
Expand Down
11 changes: 8 additions & 3 deletions src/data.c
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,20 @@ void fill_truth_detection(char *path, float *truth, int classes, int height, int
float dw = (x - i*box_width)/box_width;
float dh = (y - j*box_height)/box_height;
//printf("%d %d %d %f %f\n", id, i, j, dh, dw);
int index = (i+j*num_width)*(4+classes+1);
truth[index++] = 1;
int index = (i+j*num_width)*(4+classes);
truth[index+id] = 1;
index += classes;
truth[index++] = dh;
truth[index++] = dw;
truth[index++] = h*(height+jitter)/height;
truth[index++] = w*(width+jitter)/width;
}
int i, j;
for(i = 0; i < num_height*num_width*(4+classes); i += 4+classes){
int background = 1;
for(j = i; j < i+classes; ++j) if (truth[j]) background = 0;
truth[i+classes-1] = background;
}
fclose(file);
}

Expand Down Expand Up @@ -209,7 +214,7 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int classes,
data d;
d.shallow = 0;
d.X = load_image_paths(random_paths, n, h, w);
int k = nh*nw*(4+classes+1);
int k = nh*nw*(4+classes);
d.y = make_matrix(n, k);
for(i = 0; i < n; ++i){
int dx = rand()%jitter;
Expand Down
Loading

0 comments on commit fb9e0fe

Please sign in to comment.