-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathhard_tanh_transformer.h
28 lines (24 loc) · 1.03 KB
/
hard_tanh_transformer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#ifndef SYRENN_SYRENN_SERVER_HARD_TANH_TRANSFORMER_H_
#define SYRENN_SYRENN_SERVER_HARD_TANH_TRANSFORMER_H_
#include <memory>
#include <string>
#include "syrenn_proto/syrenn.grpc.pb.h"
#include "syrenn_server/segmented_line.h"
#include "syrenn_server/pwl_transformer.h"
// Transformer for Hard Tanh layers.
//
// HardTanh is handled through the generic PWLTransformer machinery: this
// class supplies only the layer-specific hooks (how many piece faces exist,
// where a segment crosses a face, and which side of a face a point is on).
// NOTE(review): HardTanh is conventionally clamp(x, -1, 1); the exact
// thresholds used are defined in the .cc file -- confirm there.
class HardTanhTransformer : public PWLTransformer {
 public:
  // Builds a HardTanhTransformer from its serialized `layer` description and
  // returns ownership of it through the LayerTransformer interface.
  static std::unique_ptr<LayerTransformer> Deserialize(
      const syrenn_server::Layer &layer);

  // Applies the layer to `inout`, overwriting it with the result. Safe to do
  // in place because the output has the same size as the input (see
  // out_size below).
  void Compute(RMMatrixXf *inout) const override;

  // Identifier for this layer kind.
  std::string layer_type() const override { return "HardTanh"; }

  // HardTanh is size-preserving: the output dimension equals the input
  // dimension.
  size_t out_size(size_t in_size) const override { return in_size; }

 protected:
  // Number of faces (boundaries between linear pieces) for a `dims`-
  // dimensional input.
  // NOTE(review): presumably two faces per dimension for an elementwise
  // clamp -- confirm in the .cc file.
  size_t n_piece_faces(size_t dims) const override;

  // Position along the segment `from` -> `to` at which it crosses face
  // `face`, expressed as a ratio of the segment length.
  // NOTE(review): presumably in [0, 1]; check PWLTransformer's contract.
  double CrossingRatio(Eigen::Ref<const RMVectorXf> from,
                       Eigen::Ref<const RMVectorXf> to,
                       size_t face) const override;

  // Which side of face `face` `point` lies on.
  // NOTE(review): presumably -1 / 0 / +1 with 0 meaning "on the face";
  // check PWLTransformer's contract.
  int PointSign(Eigen::Ref<const RMVectorXf> point,
                size_t face) const override;
};
#endif // SYRENN_SYRENN_SERVER_HARD_TANH_TRANSFORMER_H_