diff options
Diffstat (limited to 'src/lib/src/neuralnet_impl.h')
-rw-r--r-- | src/lib/src/neuralnet_impl.h | 35 |
1 files changed, 21 insertions, 14 deletions
diff --git a/src/lib/src/neuralnet_impl.h b/src/lib/src/neuralnet_impl.h index f5a9c63..935c5ea 100644 --- a/src/lib/src/neuralnet_impl.h +++ b/src/lib/src/neuralnet_impl.h | |||
@@ -2,22 +2,29 @@ | |||
2 | 2 | ||
3 | #include <neuralnet/matrix.h> | 3 | #include <neuralnet/matrix.h> |
4 | 4 | ||
5 | #include <stdbool.h> | ||
6 | |||
7 | /// Linear layer parameters. | ||
8 | typedef struct nnLinearImpl { | ||
9 | nnMatrix weights; | ||
10 | nnMatrix biases; | ||
11 | bool owned; /// Whether the library owns the weights and biases. | ||
12 | } nnLinearImpl; | ||
13 | |||
14 | /// Neural network layer. | ||
15 | typedef struct nnLayerImpl { | ||
16 | nnLayerType type; | ||
17 | int input_size; | ||
18 | int output_size; | ||
19 | union { | ||
20 | nnLinearImpl linear; | ||
21 | }; | ||
22 | } nnLayerImpl; | ||
23 | |||
5 | /// Neural network object. | 24 | /// Neural network object. |
6 | /// | ||
7 | /// We store the transposes of the weight matrices so that we can do forward | ||
8 | /// passes with a minimal amount of work. That is, if in paper we write: | ||
9 | /// | ||
10 | /// [w11 w21] | ||
11 | /// [w12 w22] | ||
12 | /// | ||
13 | /// then the weight matrix in memory is stored as the following array: | ||
14 | /// | ||
15 | /// w11 w12 w21 w22 | ||
16 | typedef struct nnNeuralNetwork { | 25 | typedef struct nnNeuralNetwork { |
17 | int num_layers; // Number of non-input layers (hidden + output). | 26 | int num_layers; // Number of non-input layers (hidden + output). |
18 | nnMatrix* weights; // One matrix per non-input layer. | 27 | nnLayerImpl* layers; // One per non-input layer. |
19 | nnMatrix* biases; // One vector per non-input layer. | ||
20 | nnActivation* activations; // One per non-input layer. | ||
21 | } nnNeuralNetwork; | 28 | } nnNeuralNetwork; |
22 | 29 | ||
23 | /// A query object that holds all the memory necessary to query a network. | 30 | /// A query object that holds all the memory necessary to query a network. |