aboutsummaryrefslogtreecommitdiff
path: root/src/lib/src/neuralnet_impl.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/src/neuralnet_impl.h')
-rw-r--r--src/lib/src/neuralnet_impl.h35
1 files changed, 21 insertions, 14 deletions
diff --git a/src/lib/src/neuralnet_impl.h b/src/lib/src/neuralnet_impl.h
index f5a9c63..935c5ea 100644
--- a/src/lib/src/neuralnet_impl.h
+++ b/src/lib/src/neuralnet_impl.h
@@ -2,22 +2,29 @@
2 2
3#include <neuralnet/matrix.h> 3#include <neuralnet/matrix.h>
4 4
5#include <stdbool.h>
6
7/// Linear layer parameters.
8typedef struct nnLinearImpl {
9 nnMatrix weights;
10 nnMatrix biases;
11 bool owned; /// Whether the library owns the weights and biases.
12} nnLinearImpl;
13
14/// Neural network layer.
15typedef struct nnLayerImpl {
16 nnLayerType type;
17 int input_size;
18 int output_size;
19 union {
20 nnLinearImpl linear;
21 };
22} nnLayerImpl;
23
5/// Neural network object. 24/// Neural network object.
6///
7/// We store the transposes of the weight matrices so that we can do forward
8/// passes with a minimal amount of work. That is, if in paper we write:
9///
10/// [w11 w21]
11/// [w12 w22]
12///
13/// then the weight matrix in memory is stored as the following array:
14///
15/// w11 w12 w21 w22
16typedef struct nnNeuralNetwork { 25typedef struct nnNeuralNetwork {
17 int num_layers; // Number of non-input layers (hidden + output). 26 int num_layers; // Number of non-input layers (hidden + output).
18 nnMatrix* weights; // One matrix per non-input layer. 27 nnLayerImpl* layers; // One per non-input layer.
19 nnMatrix* biases; // One vector per non-input layer.
20 nnActivation* activations; // One per non-input layer.
21} nnNeuralNetwork; 28} nnNeuralNetwork;
22 29
23/// A query object that holds all the memory necessary to query a network. 30/// A query object that holds all the memory necessary to query a network.