aboutsummaryrefslogtreecommitdiff
path: root/src/lib/include
diff options
context:
space:
mode:
author3gg <3gg@shellblade.net>2023-12-16 10:21:16 -0800
committer3gg <3gg@shellblade.net>2023-12-16 10:21:16 -0800
commit653e98e029a0d0f110b0ac599e50406060bb0f87 (patch)
tree6f909215218f6720266bde1b3f49aeddad8b1da3 /src/lib/include
parent3df7b6fb0c65295eed4590e6f166d60e89b3c68e (diff)
Decouple activations from linear layer.
Diffstat (limited to 'src/lib/include')
-rw-r--r--src/lib/include/neuralnet/matrix.h3
-rw-r--r--src/lib/include/neuralnet/neuralnet.h51
2 files changed, 35 insertions, 19 deletions
diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h
index b7281bf..f80b985 100644
--- a/src/lib/include/neuralnet/matrix.h
+++ b/src/lib/include/neuralnet/matrix.h
@@ -17,6 +17,9 @@ nnMatrix nnMatrixMake(int rows, int cols);
17/// Delete a matrix and free its internal memory. 17/// Delete a matrix and free its internal memory.
18void nnMatrixDel(nnMatrix*); 18void nnMatrixDel(nnMatrix*);
19 19
20/// Construct a matrix from an array of values.
21nnMatrix nnMatrixFromArray(int rows, int cols, const R values[]);
22
20/// Move a matrix. 23/// Move a matrix.
21/// 24///
22/// |in| is an empty matrix after the move. 25/// |in| is an empty matrix after the move.
diff --git a/src/lib/include/neuralnet/neuralnet.h b/src/lib/include/neuralnet/neuralnet.h
index 05c9406..f122c2a 100644
--- a/src/lib/include/neuralnet/neuralnet.h
+++ b/src/lib/include/neuralnet/neuralnet.h
@@ -1,32 +1,45 @@
1#pragma once 1#pragma once
2 2
3#include <neuralnet/matrix.h>
3#include <neuralnet/types.h> 4#include <neuralnet/types.h>
4 5
5typedef struct nnMatrix nnMatrix;
6
7typedef struct nnNeuralNetwork nnNeuralNetwork; 6typedef struct nnNeuralNetwork nnNeuralNetwork;
8typedef struct nnQueryObject nnQueryObject; 7typedef struct nnQueryObject nnQueryObject;
9 8
10/// Neuron activation. 9/// Linear layer parameters.
11typedef enum nnActivation { 10///
12 nnIdentity, 11/// Either one of the following must be set:
12/// a) Training: input and output sizes.
13/// b) Inference: weights + biases.
14typedef struct nnLinearParams {
15 int input_size;
16 int output_size;
17 nnMatrix weights;
18 nnMatrix biases;
19} nnLinearParams;
20
21/// Layer type.
22typedef enum nnLayerType {
23 nnLinear,
13 nnSigmoid, 24 nnSigmoid,
14 nnRelu, 25 nnRelu,
15} nnActivation; 26} nnLayerType;
27
28/// Neural network layer.
29typedef struct nnLayer {
30 nnLayerType type;
31 union {
32 nnLinearParams linear;
33 };
34} nnLayer;
16 35
17/// Create a network. 36/// Create a network.
18nnNeuralNetwork* nnMakeNet( 37nnNeuralNetwork* nnMakeNet(
19 int num_layers, const int* layer_sizes, const nnActivation* activations); 38 const nnLayer* layers, int num_layers, int input_size);
20 39
21/// Delete the network and free its internal memory. 40/// Delete the network and free its internal memory.
22void nnDeleteNet(nnNeuralNetwork**); 41void nnDeleteNet(nnNeuralNetwork**);
23 42
24/// Set the network's weights.
25void nnSetWeights(nnNeuralNetwork*, const R* weights);
26
27/// Set the network's biases.
28void nnSetBiases(nnNeuralNetwork*, const R* biases);
29
30/// Query the network. 43/// Query the network.
31/// 44///
32/// |input| is a matrix of inputs, one row per input and as many columns as the 45/// |input| is a matrix of inputs, one row per input and as many columns as the
@@ -42,10 +55,10 @@ void nnQueryArray(
42 55
43/// Create a query object. 56/// Create a query object.
44/// 57///
45/// The query object holds all the internal memory required to query a network. 58/// The query object holds all the internal memory required to query a network
46/// Query objects allocate all memory up front so that network queries can run 59/// with batches of the given size. Memory is allocated up front so that network
47/// without additional memory allocation. 60/// queries can run without additional memory allocation.
48nnQueryObject* nnMakeQueryObject(const nnNeuralNetwork*, int num_inputs); 61nnQueryObject* nnMakeQueryObject(const nnNeuralNetwork*, int batch_size);
49 62
50/// Delete the query object and free its internal memory. 63/// Delete the query object and free its internal memory.
51void nnDeleteQueryObject(nnQueryObject**); 64void nnDeleteQueryObject(nnQueryObject**);
@@ -60,7 +73,7 @@ int nnNetInputSize(const nnNeuralNetwork*);
60int nnNetOutputSize(const nnNeuralNetwork*); 73int nnNetOutputSize(const nnNeuralNetwork*);
61 74
62/// Return the layer's input size. 75/// Return the layer's input size.
63int nnLayerInputSize(const nnMatrix* weights); 76int nnLayerInputSize(const nnNeuralNetwork*, int layer);
64 77
65/// Return the layer's output size. 78/// Return the layer's output size.
66int nnLayerOutputSize(const nnMatrix* weights); 79int nnLayerOutputSize(const nnNeuralNetwork*, int layer);