OCR Project
Loading...
Searching...
No Matches
cnn.h
Go to the documentation of this file.
1
18
19#ifndef CNN_H
20#define CNN_H
21
22#include <stddef.h>
23
/* -------------------------------------------------------------------------
 * Architecture constants
 *
 * Every size below is derived from the image, kernel and pool dimensions,
 * so changing one of those updates the rest. The trailing numeric comments
 * show the values for the current configuration and must be kept in sync.
 * ---------------------------------------------------------------------- */

/** Height of the input image (rows of the network's input buffer). */
#define CNN_IMG_H 56
/** Width of the input image (columns of the network's input buffer). */
#define CNN_IMG_W 56

/** Number of filters (output feature maps) in the convolution layer. */
#define CNN_N_FILTERS 16
/** Height of each convolution kernel. */
#define CNN_KERNEL_H 3
/** Width of each convolution kernel. */
#define CNN_KERNEL_W 3

/**
 * Convolution output size. The formula IMG - KERNEL + 1 corresponds to a
 * "valid" convolution (no padding, stride 1).
 */
#define CNN_CONV_H (CNN_IMG_H - CNN_KERNEL_H + 1) /* 54 */
#define CNN_CONV_W (CNN_IMG_W - CNN_KERNEL_W + 1) /* 54 */

/**
 * Pooling window size. The output-size formula below (CONV / POOL)
 * implies non-overlapping windows, i.e. stride == window size.
 */
#define CNN_POOL_H 4
#define CNN_POOL_W 4

/**
 * Pooled feature-map size. Integer division truncates: with the current
 * values 54 / 4 == 13, so 54 - 13 * 4 == 2 trailing rows and columns of
 * each convolution map are not covered by any pooling window (assuming
 * stride == window size; confirm against the pooling loop in cnn.c).
 */
#define CNN_POOL_OUT_H (CNN_CONV_H / CNN_POOL_H) /* 13 */
#define CNN_POOL_OUT_W (CNN_CONV_W / CNN_POOL_W) /* 13 */

/* Reject configurations that would give zero-length pooled maps (and
 * therefore zero-length activation and weight arrays). */
#if CNN_CONV_H < CNN_POOL_H || CNN_CONV_W < CNN_POOL_W
#error "cnn.h: convolution output is smaller than one pooling window"
#endif

/** Length of the flattened pooled features fed to the hidden layer. */
#define CNN_FLAT_SIZE (CNN_N_FILTERS * CNN_POOL_OUT_H * CNN_POOL_OUT_W) /* 2704 */

/** Number of units in the fully connected hidden layer. */
#define CNN_HIDDEN_SIZE 128

/** Number of output classes (length of the network's output vector). */
#define CNN_N_CLASSES 26
57
/* -------------------------------------------------------------------------
 * Hyper-parameters
 *
 * Training knobs. NOTE(review): presumably consumed by cnn_update() (the
 * SGD-with-momentum step) and the training loop; confirm in cnn.c.
 * ---------------------------------------------------------------------- */

#define CNN_LR         1e-3f /* learning rate (step size)          */
#define CNN_MOMENTUM   0.9f  /* momentum coefficient for velocity  */
#define CNN_BATCH_SIZE 32    /* samples per mini-batch             */
70
71/* -------------------------------------------------------------------------
72 * Data structures
73 * ---------------------------------------------------------------------- */
74
97
133
144
145/* -------------------------------------------------------------------------
146 * Public API
147 * ---------------------------------------------------------------------- */
148
158CNN *cnn_create(void);
159
165void cnn_free(CNN *net);
166
176void cnn_forward(CNN *net, const float *image);
177
188void cnn_backward(CNN *net, int label);
189
199void cnn_update(CNN *net);
200
206void cnn_zero_grads(CNN *net);
207
217int cnn_predict(CNN *net, const float *image);
218
228float cnn_loss(const CNN *net, int label);
229
230#endif /* CNN_H */
#define CNN_POOL_OUT_H
Definition cnn.h:46
#define CNN_FLAT_SIZE
Definition cnn.h:50
void cnn_backward(CNN *net, int label)
Compute gradients via backpropagation.
Definition cnn.c:294
#define CNN_POOL_OUT_W
Definition cnn.h:47
#define CNN_KERNEL_W
Definition cnn.h:35
int cnn_predict(CNN *net, const float *image)
Predict the most likely class for a single image.
Definition cnn.c:449
#define CNN_HIDDEN_SIZE
Definition cnn.h:53
#define CNN_IMG_W
Definition cnn.h:30
#define CNN_IMG_H
Definition cnn.h:29
void cnn_zero_grads(CNN *net)
Zero all gradient accumulators in net->grads.
Definition cnn.c:385
float cnn_loss(const CNN *net, int label)
Compute cross-entropy loss for the current forward-pass output.
Definition cnn.c:283
void cnn_forward(CNN *net, const float *image)
Run a full forward pass and populate net->act.
Definition cnn.c:266
#define CNN_N_CLASSES
Definition cnn.h:56
void cnn_update(CNN *net)
Apply one SGD-with-momentum update step to the weights.
Definition cnn.c:412
void cnn_free(CNN *net)
Free all memory associated with a CNN.
Definition cnn.c:123
CNN * cnn_create(void)
Allocate and initialise a CNN with He-initialised weights.
Definition cnn.c:90
#define CNN_KERNEL_H
Definition cnn.h:34
#define CNN_CONV_H
Definition cnn.h:38
#define CNN_N_FILTERS
Definition cnn.h:33
#define CNN_CONV_W
Definition cnn.h:39
Intermediate activations produced by the forward pass.
Definition cnn.h:103
int pool_max_c[CNN_N_FILTERS][CNN_POOL_OUT_H][CNN_POOL_OUT_W]
Definition cnn.h:118
float z1[CNN_HIDDEN_SIZE]
Definition cnn.h:124
float output[CNN_N_CLASSES]
Definition cnn.h:131
float h1[CNN_HIDDEN_SIZE]
Definition cnn.h:126
float pool_out[CNN_N_FILTERS][CNN_POOL_OUT_H][CNN_POOL_OUT_W]
Definition cnn.h:111
int pool_max_r[CNN_N_FILTERS][CNN_POOL_OUT_H][CNN_POOL_OUT_W]
Definition cnn.h:117
float flat[CNN_FLAT_SIZE]
Definition cnn.h:121
float z2[CNN_N_CLASSES]
Definition cnn.h:129
float conv_out[CNN_N_FILTERS][CNN_CONV_H][CNN_CONV_W]
Definition cnn.h:108
float input[CNN_IMG_H][CNN_IMG_W]
Definition cnn.h:105
Learnable parameters of the CNN.
Definition cnn.h:81
float kernels[CNN_N_FILTERS][CNN_KERNEL_H][CNN_KERNEL_W]
Definition cnn.h:83
float W1[CNN_HIDDEN_SIZE][CNN_FLAT_SIZE]
Definition cnn.h:88
float b1[CNN_HIDDEN_SIZE]
Definition cnn.h:90
float conv_bias[CNN_N_FILTERS]
Definition cnn.h:85
float b2[CNN_N_CLASSES]
Definition cnn.h:95
float W2[CNN_N_CLASSES][CNN_HIDDEN_SIZE]
Definition cnn.h:93
Full CNN state: weights, gradients, momentum buffers, activations.
Definition cnn.h:137
CNNActivations act
Definition cnn.h:141
CNNWeights weights
Definition cnn.h:138
CNNWeights velocity
Definition cnn.h:140
CNNWeights grads
Definition cnn.h:139
int batch_count
Definition cnn.h:142