
Training

For field-by-field trainer parameter meanings, see Configuration Reference.

Unified YAML entrypoint

All example runs go through:

python examples/trainer.py --config <path-to-config.yaml>

Each config is split into:

  • dataset
  • model
  • encoder
  • decoder
  • trainer

dataset, model, encoder, and decoder each use a name + config structure; trainer is a flat block of settings.
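A minimal sketch of that layout after YAML parsing, with a checker for the documented shape. The function check_config_layout and the section names "glove", "ae", and "mlp" are hypothetical, chosen for illustration; the real loader may validate differently or not at all.

```python
# The four sections that wrap their settings as {name: ..., config: {...}}.
NAMED_SECTIONS = ("dataset", "model", "encoder", "decoder")


def check_config_layout(config: dict) -> None:
    """Raise ValueError if the top-level layout differs from the documented shape."""
    missing = [s for s in (*NAMED_SECTIONS, "trainer") if s not in config]
    if missing:
        raise ValueError(f"missing sections: {missing}")
    for section in NAMED_SECTIONS:
        block = config[section]
        if not isinstance(block, dict) or "name" not in block or "config" not in block:
            raise ValueError(f"{section} must contain 'name' and 'config' keys")
    # trainer holds its settings directly, with no name/config wrapper.
    if not isinstance(config["trainer"], dict):
        raise ValueError("trainer must be a mapping of settings")


# Hypothetical parsed config; the name values are placeholders, not confirmed keys.
example = {
    "dataset": {"name": "glove", "config": {"max_vectors": 5000}},
    "model": {"name": "ae", "config": {}},
    "encoder": {"name": "mlp", "config": {}},
    "decoder": {"name": "mlp", "config": {}},
    "trainer": {"output_dir": "artifacts/glove/ae", "epochs": 5, "batch_size": 256},
}
check_config_layout(example)  # passes silently
```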

Runtime placeholders

Example configs contain placeholders of the form ${name:default}$ that can be overridden at runtime, for example:

max_vectors: ${max_vectors:5000}$
epochs: ${epoch:5}$
learning_rate: ${lr:0.001}$

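Passing --name on the command line replaces the default of the matching placeholder. For example, this sets the epoch count to 10 and the learning rate to 0.0005: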

python examples/trainer.py --config examples/configs/glove/ae.yaml --epoch 10 --lr 0.0005
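One way such placeholders could be resolved is sketched below, assuming each ${name:default}$ takes its value from a --name flag when present and from default otherwise, with numeric strings converted to int or float. The helpers parse_overrides and resolve_placeholder are hypothetical; examples/trainer.py may parse flags and coerce types differently.

```python
import re

# Matches placeholders written as ${name:default}$ in the example configs.
PLACEHOLDER = re.compile(r"\$\{(\w+):([^}]*)\}\$")


def parse_overrides(argv: list) -> dict:
    """Collect --name value pairs from the command line, ignoring --config."""
    overrides = {}
    tokens = iter(argv)
    for token in tokens:
        if token.startswith("--"):
            overrides[token[2:]] = next(tokens, None)
    overrides.pop("config", None)
    return overrides


def _coerce(text: str):
    """Return text as int or float when it parses as one, else unchanged."""
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def resolve_placeholder(value: str, overrides: dict):
    """Use the override for a placeholder's name if given, otherwise its default."""
    match = PLACEHOLDER.fullmatch(value.strip())
    if match is None:
        return value  # a literal value, not a placeholder
    name, default = match.groups()
    raw = overrides.get(name)
    return _coerce(default if raw is None else raw)


argv = ["--config", "examples/configs/glove/ae.yaml", "--epoch", "10", "--lr", "0.0005"]
overrides = parse_overrides(argv)
print(resolve_placeholder("${epoch:5}$", overrides))           # 10
print(resolve_placeholder("${lr:0.001}$", overrides))          # 0.0005
print(resolve_placeholder("${max_vectors:5000}$", overrides))  # 5000
```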

Trainer optimization options

TrainingConfig supports these optimization options:

  • optimizer_name: adam, adamw, sgd, rmsprop, adagrad
  • weight_decay
  • lr_scheduler_type: none, constant, linear, cosine
  • warmup_epochs
  • grad_clip_norm
  • log_memory
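To show how lr_scheduler_type and warmup_epochs could interact, here is a per-epoch learning-rate sketch. Its behavior is an assumption, not taken from the trainer: "none" ignores warmup, warmup ramps linearly up to learning_rate, "constant" holds the rate after warmup, and "linear" and "cosine" decay toward 0 over the remaining epochs. The function lr_at_epoch is hypothetical.

```python
import math


def lr_at_epoch(epoch, base_lr, total_epochs, scheduler="cosine", warmup_epochs=0):
    """Hypothetical learning rate for a 0-based epoch under the assumptions above."""
    if scheduler == "none":
        return base_lr
    if epoch < warmup_epochs:
        # Linear warmup: base_lr / warmup_epochs, 2 * base_lr / warmup_epochs, ...
        return base_lr * (epoch + 1) / warmup_epochs
    if scheduler == "constant":
        return base_lr
    decay_epochs = max(1, total_epochs - warmup_epochs)
    progress = (epoch - warmup_epochs) / decay_epochs  # 0 at the first post-warmup epoch
    if scheduler == "linear":
        return base_lr * (1.0 - progress)
    if scheduler == "cosine":
        return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown lr_scheduler_type: {scheduler!r}")


# epochs=5, learning_rate=0.001, cosine schedule with one warmup epoch.
schedule = [lr_at_epoch(e, 0.001, 5, "cosine", warmup_epochs=1) for e in range(5)]
print(schedule)
```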

Example:

trainer:
  output_dir: artifacts/glove/ae
  epochs: ${epoch:5}$
  batch_size: 256
  learning_rate: ${lr:0.001}$
  optimizer_name: adamw
  weight_decay: ${wd:0.01}$
  lr_scheduler_type: cosine
  warmup_epochs: ${warmup:1}$
  grad_clip_norm: ${clip:1.0}$
  log_memory: false
  device: auto
  seed: 42
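Assuming grad_clip_norm is a limit on the global L2 norm of all gradients (the behavior of PyTorch's torch.nn.utils.clip_grad_norm_, which also adds a small epsilon), the clipping step reduces to the pure-Python sketch below. The function clip_by_global_norm is hypothetical and works on a flat list of floats instead of parameter tensors.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by one factor so their combined L2 norm is at most max_norm.

    Returns the (possibly scaled) gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm


clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)  # grad_clip_norm: 1.0
print(norm)  # 5.0
```

Because every gradient is scaled by the same factor, clipping changes the step size but not its direction.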

Set log_memory: true to have the trainer print MEM snapshots after the training, validation, and final test passes. This helps when debugging memory growth on CUDA or MPS during long runs.

Evaluation metric list

Trainer logs and the save_best_by option use these short metric names:

  • loss
  • recon
  • binary
  • bal
  • decor
  • sparse
  • topk
  • kl-sparse
  • contract
  • mmd
  • adv
  • disc
  • commit
  • book
  • kl
  • free-kl
  • codes
  • usage
  • ppl
  • dead
  • coll

coll is the collision rate for quantized models and is shown by default.
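A collision rate can be computed in several ways. The sketch below uses one common definition for quantized codes, 1 minus the ratio of distinct code tuples to items; it is an assumption that coll follows this formula. The function collision_rate is hypothetical.

```python
def collision_rate(codes):
    """Fraction of items whose code tuple duplicates an earlier item's code.

    Equivalent to 1 - (number of distinct codes / number of items).
    """
    if not codes:
        return 0.0
    return 1.0 - len(set(codes)) / len(codes)


# Four items quantized into (level-1, level-2) code tuples; two share (1, 2).
codes = [(1, 2), (1, 2), (3, 4), (5, 6)]
print(collision_rate(codes))  # 0.25
```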

Display metric switches

Use trainer.display_metrics to hide specific metrics from terminal logs. Metrics not listed default to true (shown).

trainer:
  output_dir: artifacts/glove/rqvae
  epochs: ${epoch:5}$
  batch_size: 256
  learning_rate: ${lr:0.001}$
  display_metrics:
    commit: false
    book: false
    coll: true
  save_best_by: [loss, coll]

display_metrics and save_best_by both use the same short names listed above.
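The display rule (unlisted metrics default to shown) and a check against the short names above could look like the sketch below. The helpers validate_metric_names and format_log_line, and the name=value log format, are hypothetical; the real trainer's log layout may differ.

```python
# Short metric names listed above.
METRIC_NAMES = {
    "loss", "recon", "binary", "bal", "decor", "sparse", "topk", "kl-sparse",
    "contract", "mmd", "adv", "disc", "commit", "book", "kl", "free-kl",
    "codes", "usage", "ppl", "dead", "coll",
}


def validate_metric_names(names):
    """Raise ValueError for any name that is not a known short metric name."""
    unknown = sorted(set(names) - METRIC_NAMES)
    if unknown:
        raise ValueError(f"unknown metric names: {unknown}")


def format_log_line(metrics: dict, display_metrics: dict) -> str:
    """Keep metrics whose display flag is true; unlisted metrics default to shown."""
    validate_metric_names(display_metrics)
    return " ".join(
        f"{name}={value:.4f}"
        for name, value in metrics.items()
        if display_metrics.get(name, True)
    )


display = {"commit": False, "book": False, "coll": True}
validate_metric_names(["loss", "coll"])  # save_best_by from the example config
print(format_log_line({"loss": 0.5, "commit": 0.1, "book": 0.2, "coll": 0.03}, display))
# loss=0.5000 coll=0.0300
```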

Family-specific trainers

  • AETrainer
  • VAETrainer
  • VQTrainer
  • FactorVAETrainer
  • AdversarialAutoencoderTrainer

The unified entrypoint selects the correct trainer from model.name.
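A minimal registry sketch of that selection, assuming a lookup keyed by model.name. The keys "ae", "vae", "rqvae", "factorvae", and "aae" are hypothetical (only ae and rqvae appear, as config paths), the classes are empty stand-ins for the real trainers, and select_trainer is not the entrypoint's actual function.

```python
# Empty stand-ins for the real trainer classes.
class AETrainer: ...
class VAETrainer: ...
class VQTrainer: ...
class FactorVAETrainer: ...
class AdversarialAutoencoderTrainer: ...


# Hypothetical model.name keys; the real entrypoint's names may differ.
TRAINER_BY_MODEL = {
    "ae": AETrainer,
    "vae": VAETrainer,
    "rqvae": VQTrainer,
    "factorvae": FactorVAETrainer,
    "aae": AdversarialAutoencoderTrainer,
}


def select_trainer(config: dict) -> type:
    """Return the trainer class registered for config["model"]["name"]."""
    name = config["model"]["name"]
    try:
        return TRAINER_BY_MODEL[name.lower()]
    except KeyError:
        known = ", ".join(sorted(TRAINER_BY_MODEL))
        raise ValueError(f"no trainer registered for model {name!r}; known: {known}") from None


print(select_trainer({"model": {"name": "rqvae", "config": {}}}).__name__)  # VQTrainer
```

A dictionary registry keeps the entrypoint free of per-family branching: supporting a new family means adding one entry.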