-
Notifications
You must be signed in to change notification settings - Fork 4
/
tensorflow.h
80 lines (68 loc) · 2.43 KB
/
tensorflow.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#ifndef TENSORFLOW_H
#define TENSORFLOW_H
#include <memory>
#include <string>
#include <vector>

#include <QImage>
#include <QList>
#include <QRectF>
#include <QString>
#include <QStringList>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/public/session.h"
// Wrapper class that owns a tensorflow::Session, accepts QImage inputs and
// hands results back as Qt containers (QStringList / QList).
//
// Only declarations live in this header. Comments marked "NOTE(review)" are
// inferred from names alone and must be confirmed against the implementation.
class Tensorflow
{
public:
    // Network-kind identifiers.
    // NOTE(review): presumably the values reported by getKindNetwork() - confirm.
    static const int knIMAGE_CLASSIFIER = 1;
    static const int knOBJECT_DETECTION = 2;

    Tensorflow();

    // NOTE(review): presumably prepares the session for images of the given
    // size and returns false on failure - confirm against the implementation.
    bool init(int imgHeight, int imgWidth);
    void initInput(int imgHeight, int imgWidth);

    // Accessors for the threshold value.
    // NOTE(review): presumably a minimum confidence for reported results - confirm.
    double getThreshold() const;
    void setThreshold(double value);

    // Result accessors; the return types match the rCaption / rConfidence /
    // rBox members declared under "Results" below.
    QStringList getResults();
    QList<double> getConfidence();
    QList<QRectF> getBoxes();

    int getKindNetwork();

    // NOTE(review): presumably processes `img` and returns true on success - confirm.
    bool run(QImage img);

    // Accessors for the model and labels file names.
    QString getModelFilename() const;
    void setModelFilename(const QString &value);
    QString getLabelsFilename() const;
    void setLabelsFilename(const QString &value);

    int getImgHeight() const;
    int getImgWidth() const;

    // NOTE(review): presumably the duration of the last inference; the unit is
    // not visible from this header - confirm.
    double getInfTime() const;

private:
    // Fixed image size for image classification
    const int fixed_width = 224;
    // (sic) spelling kept: the implementation file may reference this name.
    const int fixed_heigth = 224;

    // Output names
    const QString num_detections = "num_detections";
    const QString detection_classes = "detection_classes";
    const QString detection_scores = "detection_scores";
    const QString detection_boxes = "detection_boxes";

    // Output lists
    // listOutputsObjDet is built from the QString constants above. Members are
    // initialized in declaration order, so those constants must stay declared
    // before this member.
    const std::vector<std::string> listOutputsObjDet = {num_detections.toStdString(),
                                                        detection_classes.toStdString(),
                                                        detection_scores.toStdString(),
                                                        detection_boxes.toStdString()};
    const std::vector<std::string> listOutputsImgCla = {"MobilenetV2/Predictions/Reshape_1"};

    // Scalar state carries default member initializers so that no member is
    // ever read with an indeterminate value. Any value the constructor sets
    // still takes precedence over these defaults.
    bool initialized = false;
    double threshold = 0.0;

    // Results
    QStringList rCaption;
    QList<double> rConfidence;
    QList<QRectF> rBox;
    double infTime = 0.0;
    // 0 matches neither knIMAGE_CLASSIFIER nor knOBJECT_DETECTION.
    int kind_network = 0;

    std::unique_ptr<tensorflow::Session> session;
    std::vector<tensorflow::Tensor> outputs;

    bool inference();
    bool setInputs(QImage image);
    bool getClassfierOutputs(int &index, double &score);
    bool getObjectOutputs(QStringList &captions, QList<double> &confidences, QList<QRectF> &locations);
    bool readLabels();

    QString input_name;
    tensorflow::DataType input_dtype = tensorflow::DT_INVALID;
    std::unique_ptr<tensorflow::Tensor> input_tensor;

    QString modelFilename;
    QString labelsFilename;
    QStringList labels;
    QString getLabel(int index);

    int img_height = 0;
    int img_width = 0;
    int img_channels = 0;

    const QImage::Format format = QImage::Format_RGB888;
    const int numChannels = 3;
};
#endif // TENSORFLOW_H