|
OCR Project
|
Entry point for the CNN training binary. More...
#include "src/cnn/cnn.h"
#include "src/cnn/dataset.h"
#include "src/cnn/model.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <sys/stat.h>
Classes | |
| struct | TrainArgs |
| Parsed command-line options for the train binary. More... | |
Macros | |
| #define | DEFAULT_DATA_DIR "training_data/" |
| #define | DEFAULT_MODEL_DIR "models/" |
| #define | MAX_EPOCHS 50 |
| #define | VAL_SPLIT 0.15f /* fraction of data kept for validation */ |
| #define | ES_PATIENCE 5 /* epochs without improvement before stopping */ |
| #define | ES_MIN_DELTA 1e-4f /* minimum val-loss improvement to count */ |
Functions | |
| static void | usage (const char *prog) |
| Print usage information to stderr. | |
| static int | parse_args (int argc, char **argv, TrainArgs *args) |
| Parse argv into a TrainArgs structure. | |
| static double | eval_loss (CNN *net, const Sample *samples, size_t n, double *out_acc) |
| Compute loss and accuracy on a slice of samples (no gradient update). | |
| static void | train_loop (CNN *net, Dataset *ds, int n_epochs) |
| Run the training loop with early stopping. | |
| int | main (int argc, char **argv) |
Entry point for the CNN training binary.
Usage:
Options:
  --data <dir>     Training data root directory (default: training_data/).
  --output <file>  Output model filename (default: models/model_<ts>.bin).
  -j<N>            Number of loader threads (default: 1).
Example:
Exit codes:
  0  Success — model saved to the output path.
  1  Argument error.
  2  Dataset loading error.
  3  Model save error.
| #define DEFAULT_DATA_DIR "training_data/" |
| #define DEFAULT_MODEL_DIR "models/" |
| #define ES_MIN_DELTA 1e-4f /* minimum val-loss improvement to count */ |
| #define ES_PATIENCE 5 /* epochs without improvement before stopping */ |
| #define MAX_EPOCHS 50 |
| #define VAL_SPLIT 0.15f /* fraction of data kept for validation */ |
Compute loss and accuracy on a slice of samples (no gradient update).
| net | CNN to evaluate (its weights are not updated). |
| samples | Pointer to first sample. |
| n | Number of samples. |
| out_acc | Filled with accuracy in [0, 1]. |


int main(int argc, char **argv)

static int parse_args(int argc, char **argv, TrainArgs *args)
Parse argv into a TrainArgs structure.
| argc | Argument count. |
| argv | Argument vector. |
| args | Output structure (caller-allocated). |

Run the training loop with early stopping.
The dataset is split into a training portion (1 - VAL_SPLIT) and a validation portion (VAL_SPLIT) after an initial shuffle.
Early stopping: training halts when validation loss has not improved by more than ES_MIN_DELTA for ES_PATIENCE consecutive epochs. The best weights observed so far are restored before returning.
| net | Initialised CNN. |
| ds | Full dataset (shuffled in place). |
| n_epochs | Maximum number of epochs. |


static void usage(const char *prog)
Print usage information to stderr.
| prog | Program name (argv[0]). |
