OCR Project
train_main.c File Reference

Entry point for the CNN training binary.

#include "src/cnn/cnn.h"
#include "src/cnn/dataset.h"
#include "src/cnn/model.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <sys/stat.h>

Classes

struct  TrainArgs
 Parsed command-line options for the train binary.

Macros

#define DEFAULT_DATA_DIR   "training_data/"
#define DEFAULT_MODEL_DIR   "models/"
#define MAX_EPOCHS   50
#define VAL_SPLIT   0.15f /* fraction of data kept for validation */
#define ES_PATIENCE   5 /* epochs without improvement before stopping */
#define ES_MIN_DELTA   1e-4f /* minimum val-loss improvement to count */

Functions

static void usage (const char *prog)
 Print usage information to stderr.
static int parse_args (int argc, char **argv, TrainArgs *args)
 Parse argv into a TrainArgs structure.
static double eval_loss (CNN *net, const Sample *samples, size_t n, double *out_acc)
 Compute loss and accuracy on a slice of samples (no gradient update).
static void train_loop (CNN *net, Dataset *ds, int n_epochs)
 Run the training loop with early stopping.
int main (int argc, char **argv)

Detailed Description

Entry point for the CNN training binary.

Usage:

./train [--data <dir>] [--output <file>] [-j<N>]

Options:

--data <dir>      Training data root directory (default: training_data/).
--output <file>   Output model filename (default: models/model_<ts>.bin).
-j<N>             Number of loader threads (default: 1).
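
The default output name above embeds a timestamp. As a minimal sketch of how such a name could be built, the helper below formats a time_t into models/model_<ts>.bin. The function name sketch_default_output and the YYYYMMDD_HHMMSS stamp format are assumptions made for illustration; this page does not say which timestamp format train_main.c actually uses.

```c
#include <assert.h>
#include <stdio.h>
#include <time.h>

/*
 * Hypothetical helper: writes "models/model_<ts>.bin" into buf.
 * The YYYYMMDD_HHMMSS stamp is an assumed format, not taken from train_main.c.
 * Returns 0 on success, -1 if the time cannot be formatted or buf is too small.
 */
int sketch_default_output(char *buf, size_t cap, time_t now)
{
    assert(buf != NULL);

    struct tm *tm_now = localtime(&now);
    if (tm_now == NULL)
        return -1;

    char ts[32];
    if (strftime(ts, sizeof ts, "%Y%m%d_%H%M%S", tm_now) == 0)
        return -1;

    /* snprintf reports the length it wanted to write, so truncation is detectable. */
    int n = snprintf(buf, cap, "models/model_%s.bin", ts);
    return (n < 0 || (size_t)n >= cap) ? -1 : 0;
}
```

A caller would typically pass time(NULL) and fall back to this name only when --output was not given.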

Example:

./train --data training_data/ --output my_model.bin -j4

Exit codes:

0   Success: model saved to the output path.
1   Argument error.
2   Dataset loading error.
3   Model save error.

Macro Definition Documentation

◆ DEFAULT_DATA_DIR

#define DEFAULT_DATA_DIR   "training_data/"

◆ DEFAULT_MODEL_DIR

#define DEFAULT_MODEL_DIR   "models/"

◆ ES_MIN_DELTA

#define ES_MIN_DELTA   1e-4f /* minimum val-loss improvement to count */

◆ ES_PATIENCE

#define ES_PATIENCE   5 /* epochs without improvement before stopping */

◆ MAX_EPOCHS

#define MAX_EPOCHS   50

◆ VAL_SPLIT

#define VAL_SPLIT   0.15f /* fraction of data kept for validation */

Function Documentation

◆ eval_loss()

double eval_loss ( CNN * net,
const Sample * samples,
size_t n,
double * out_acc )
static

Compute loss and accuracy on a slice of samples (no gradient update).

Parameters
net       Trained CNN.
samples   Pointer to first sample.
n         Number of samples.
out_acc   Filled with accuracy in [0, 1].
Returns
Average cross-entropy loss.
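
For reference, the two returned quantities can be written out as follows. This assumes single-label classification, where p_i(c) is the probability the network assigns to class c for sample i and y_i is that sample's true label. The notation is introduced here for illustration and does not come from the source file.

```latex
\mathcal{L} = -\frac{1}{n}\sum_{i=1}^{n} \log p_i(y_i),
\qquad
\mathrm{acc} = \frac{1}{n}\sum_{i=1}^{n} \mathbf{1}\!\left[\arg\max_{c}\, p_i(c) = y_i\right]
```

Both are averages over the n samples passed in, so acc lies in [0, 1] as documented for out_acc.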

◆ main()

int main ( int argc,
char ** argv )

◆ parse_args()

int parse_args ( int argc,
char ** argv,
TrainArgs * args )
static

Parse argv into a TrainArgs structure.

Parameters
argc   Argument count.
argv   Argument vector.
args   Output structure (caller-allocated).
Returns
0 on success, -1 on parse error.
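
A minimal sketch of a parser that honours this contract for the three documented options is shown below. The struct SketchArgs, its field names, the function name sketch_parse_args, and the upper bound of 1024 threads are all hypothetical; the real TrainArgs layout and validation rules are not shown on this page.

```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical layout: the real TrainArgs fields are not documented here. */
typedef struct {
    const char *data_dir;  /* --data <dir>, defaults to "training_data/" */
    const char *output;    /* --output <file>, NULL means "choose a default later" */
    int n_threads;         /* -j<N>, defaults to 1 */
} SketchArgs;

/* Returns 0 on success, -1 on parse error, mirroring the documented contract. */
int sketch_parse_args(int argc, char **argv, SketchArgs *args)
{
    assert(args != NULL);

    args->data_dir = "training_data/";
    args->output = NULL;
    args->n_threads = 1;

    for (int i = 1; i < argc; i++) {
        if (strcmp(argv[i], "--data") == 0) {
            if (++i >= argc)
                return -1;               /* option given without a value */
            args->data_dir = argv[i];
        } else if (strcmp(argv[i], "--output") == 0) {
            if (++i >= argc)
                return -1;
            args->output = argv[i];
        } else if (strncmp(argv[i], "-j", 2) == 0) {
            char *end = NULL;
            long n = strtol(argv[i] + 2, &end, 10);
            /* Reject "-j", "-jx", "-j0" and counts above an assumed cap. */
            if (end == argv[i] + 2 || *end != '\0' || n < 1 || n > 1024)
                return -1;
            args->n_threads = (int)n;
        } else {
            return -1;                   /* unknown option */
        }
    }
    return 0;
}
```

On a -1 return, main would presumably call usage() and exit with code 1 (argument error).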

◆ train_loop()

void train_loop ( CNN * net,
Dataset * ds,
int n_epochs )
static

Run the training loop with early stopping.

The dataset is split into a training portion (1 - VAL_SPLIT) and a validation portion (VAL_SPLIT) after an initial shuffle.
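
The shuffle-then-split step can be sketched as below, using a Fisher-Yates shuffle over an index array and a truncating split. The names sketch_shuffle and sketch_train_count, the use of rand(), and the choice to truncate (rather than round) the validation count are assumptions; the page does not describe how train_loop() implements either step.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>

#define SKETCH_VAL_SPLIT 0.15f  /* same value as the documented VAL_SPLIT */

/* In-place Fisher-Yates shuffle of an index array, driven by rand(). */
void sketch_shuffle(size_t *idx, size_t n)
{
    for (size_t i = n; i > 1; i--) {
        size_t j = (size_t)rand() % i;  /* slight modulo bias, acceptable in a sketch */
        size_t tmp = idx[i - 1];
        idx[i - 1] = idx[j];
        idx[j] = tmp;
    }
}

/*
 * Number of samples kept for training. The remaining n - result samples
 * form the validation portion. Truncation is an assumed rounding rule.
 */
size_t sketch_train_count(size_t n)
{
    size_t n_val = (size_t)((double)n * SKETCH_VAL_SPLIT);
    assert(n_val <= n);
    return n - n_val;
}
```

With this rule, the first sketch_train_count(n) shuffled entries would serve as the training portion and the rest as validation.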

Early stopping: training halts when validation loss has not improved by more than ES_MIN_DELTA for ES_PATIENCE consecutive epochs. The best weights observed so far are restored before returning.

Parameters
net        Initialised CNN.
ds         Full dataset (will be shuffled in-place).
n_epochs   Maximum number of epochs.

◆ usage()

void usage ( const char * prog)
static

Print usage information to stderr.

Parameters
prog   Program name (argv[0]).