diff options
| -rw-r--r-- | src/lib/src/train.c | 13 | 
1 files changed, 6 insertions, 7 deletions
diff --git a/src/lib/src/train.c b/src/lib/src/train.c index fe9f598..7559ece 100644 --- a/src/lib/src/train.c +++ b/src/lib/src/train.c  | |||
| @@ -239,17 +239,16 @@ void nnTrain( | |||
| 239 | 239 | ||
| 240 | // Compute this layer's gradient. | 240 | // Compute this layer's gradient. | 
| 241 | // | 241 | // | 
| 242 | // By "gradient" we mean the expression common to the weights and bias | 242 | // By 'gradient' we mean the subexpression common to all the gradients | 
| 243 | // gradients. This is the part of the expression that does not contain | 243 | // for this layer. | 
| 244 | // this layer's input. | 244 | // For linear layers, this is the subexpression common to both the | 
| 245 | // weights and bias gradients. | ||
| 245 | // | 246 | // | 
| 246 | // Linear: G = id | 247 | // Linear: G = id | 
| 247 | // Relu: G = (output_k > 0 ? 1 : 0) | 248 | // Relu: G = (output_k > 0 ? 1 : 0) | 
| 248 | // Sigmoid: G = output_k * (1 - output_k) | 249 | // Sigmoid: G = output_k * (1 - output_k) | 
| 249 | switch (layer->type) { | 250 | switch (layer->type) { | 
| 250 | case nnLinear: { | 251 | case nnLinear: { | 
| 251 | // TODO: Just copy the pointer? | ||
| 252 | *gradient = nnMatrixBorrow(&errors[l]); | ||
| 253 | break; | 252 | break; | 
| 254 | } | 253 | } | 
| 255 | case nnRelu: | 254 | case nnRelu: | 
| @@ -294,7 +293,7 @@ void nnTrain( | |||
| 294 | nnMatrix* layer_biases = &linear->biases; | 293 | nnMatrix* layer_biases = &linear->biases; | 
| 295 | 294 | ||
| 296 | // Outer product to compute the weight deltas. | 295 | // Outer product to compute the weight deltas. | 
| 297 | nnMatrixMulOuter(layer_input, gradient, &weight_deltas[l]); | 296 | nnMatrixMulOuter(layer_input, &errors[l], &weight_deltas[l]); | 
| 298 | 297 | ||
| 299 | // Update weights. | 298 | // Update weights. | 
| 300 | nnMatrixScale(&weight_deltas[l], params->learning_rate); | 299 | nnMatrixScale(&weight_deltas[l], params->learning_rate); | 
| @@ -304,7 +303,7 @@ void nnTrain( | |||
| 304 | // This is the same formula as for weights, except that the o_j term | 303 | // This is the same formula as for weights, except that the o_j term | 
| 305 | // is just 1. | 304 | // is just 1. | 
| 306 | nnMatrixMulSub( | 305 | nnMatrixMulSub( | 
| 307 | layer_biases, gradient, params->learning_rate, layer_biases); | 306 | layer_biases, &errors[l], params->learning_rate, layer_biases); | 
| 308 | } | 307 | } | 
| 309 | } | 308 | } | 
| 310 | 309 | ||
