Add a New Task/Dataset
This tutorial walks through how to add a new downstream task in the WavesFM codebase. The core idea is:
- Preprocess raw data into an `.h5` cache file (the training code does not parse raw datasets).
- Load that cache via a small `Dataset` class.
- Register the task in `data.py` so `main_finetune.py` can build datasets, pick the right model adapter, and choose losses/metrics.
If you follow the existing conventions, you usually only touch `preprocessing/` and `data.py`. Changes to `models_vit.py` are optional and mostly relevant when you add a new modality or need different tokenization limits.
Relevant WavesFM source files
These are the main “touch points” you’ll reference while implementing a new task:
- Task registry + dataset wiring: `data.py`
- Training CLI: `main_finetune.py`
- Model (vision + IQ adapters): `models_vit.py`
- Dataset loaders: `dataset_classes/` (especially `dataset_classes/base.py`)
- Metrics/eval routing: `engine.py`
- Preprocessing scripts: `preprocessing/`
Step 0 — Decide the task “shape”
WavesFM routes tasks using two fields in `data.py:TaskInfo`:

- `modality`: `vision` or `iq`. `vision` covers image-like tensors (including spectrograms); `iq` covers time-domain IQ tensors.
- `target_type`: `classification`, `position`, or `regression`.
  - `classification` → cross-entropy + accuracy/per-class accuracy metrics.
  - `position` → MSE on normalized coordinates (expects targets scaled to `[-1, 1]`) + distance error metrics.
  - `regression` → MSE + MAE/RMSE metrics.
That decision determines what your `.h5` should contain and what you’ll return from `Dataset.__getitem__()`.
Step 1 — Preprocess into an HDF5 cache (.h5)
All training runs read from preprocessed `.h5` files. At minimum, your cache needs:

- A dataset for samples (e.g. `image`, `csi`, `sample`, …).
- A dataset for targets (usually `label`).
You can look at existing preprocessing scripts under `preprocessing/` for conventions and dataset-specific quirks.
Common conventions already used in WavesFM:
- Vision-style samples: stored per example as `(C, H, W)` (float32).
- IQ-style samples: stored per example as `(2, C, T)` (float32), where `2` is real/imag.
- Classification labels: stored as an integer scalar per example (int64).
- Optional HDF5 attributes:
  - `labels`: JSON list of class names (used for display/debugging).
  - `class_weights`: numpy array of class weights (used when training with `--class-weights`).
- For `position` tasks: store `coord_*` metadata (see Step 3).
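If your raw capture is a complex array, one way to produce the `(2, C, T)` real/imag layout above is the following minimal numpy sketch (the variable names and the random toy data are illustrative only):

```python
import numpy as np

# Toy capture: C = 4 antenna streams, T = 1024 complex samples.
rng = np.random.default_rng(0)
iq_complex = rng.standard_normal((4, 1024)) + 1j * rng.standard_normal((4, 1024))

# Stack real and imaginary parts along a new leading axis -> (2, C, T) float32.
iq = np.stack([iq_complex.real, iq_complex.imag]).astype(np.float32)
assert iq.shape == (2, 4, 1024)
```

The same `np.stack` pattern works per example if you store samples one at a time.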
Minimal write pattern (pseudo-code):
```python
import h5py
import numpy as np

with h5py.File("data/mytask.h5", "w") as h5:
    h5.create_dataset("image", data=samples.astype(np.float32))      # (N, C, H, W)
    h5.create_dataset("label", data=labels.astype(np.int64))         # (N,)
    h5.attrs["labels"] = '["class0","class1", ...]'                  # optional
    h5.attrs["class_weights"] = np.asarray([...], dtype=np.float32)  # optional
```
Position tasks: coordinate normalization metadata
The built-in `pos` task reads coordinate bounds from HDF5 attributes named `coord_nominal_min` and `coord_nominal_max` (strings that parse as JSON arrays). You can follow the same convention:
```python
import json

h5.attrs["coord_nominal_min"] = json.dumps(coord_min.tolist())
h5.attrs["coord_nominal_max"] = json.dumps(coord_max.tolist())
```
Those bounds are used at evaluation time to denormalize model outputs/targets and report distance errors.
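The `[-1, 1]` convention implies a simple affine map per axis. The helpers below are a hypothetical sketch of that math (they are not WavesFM functions; check the actual evaluation code for the exact denormalization it applies):

```python
import numpy as np

def normalize_coords(coords, coord_min, coord_max):
    """Map absolute coordinates into [-1, 1] per axis."""
    coords = np.asarray(coords, dtype=np.float32)
    return 2.0 * (coords - coord_min) / (coord_max - coord_min) - 1.0

def denormalize_coords(norm, coord_min, coord_max):
    """Invert the mapping (roughly what evaluation does before distance errors)."""
    norm = np.asarray(norm, dtype=np.float32)
    return (norm + 1.0) / 2.0 * (coord_max - coord_min) + coord_min
```

Applying `normalize_coords` at preprocessing time and storing the bounds as attributes keeps the cache and the metrics consistent.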
Step 2 — Load the cache (dataset_classes/)
WavesFM already includes minimal HDF5-backed dataset helpers:

- `dataset_classes/base.py:ImageDataset` (for vision-style caches)
- `dataset_classes/base.py:IQDataset` (for IQ-style caches)
If your `.h5` uses standard keys, you can reuse them directly from `data.py`:

- Vision example: `ImageDataset(path, sample_key="image", label_key="label")`
- IQ example: `IQDataset(path, sample_key="sample", label_key="label")`
The “contract” is simple (from `dataset_classes/base.py`):

```python
sample = torch.as_tensor(h5[self.sample_key][idx])
label = torch.as_tensor(h5[self.label_key][idx], dtype=self.label_dtype)
return sample, label
```
If you need dataset-specific behavior (e.g. a different label mapping, extra metadata, or special shapes), add a new file under `dataset_classes/` and export it in `dataset_classes/__init__.py`.
Example: a thin wrapper around `ImageDataset` that renames keys and exposes coordinate bounds:

```python
from dataset_classes.base import ImageDataset
import h5py
import json
import torch


class MyTask(ImageDataset):
    def __init__(self, h5_path: str):
        super().__init__(h5_path, sample_key="image", label_key="label", label_dtype=torch.float32)
        with h5py.File(self.h5_path, "r") as h5:
            self.coord_min = torch.tensor(json.loads(h5.attrs["coord_nominal_min"]), dtype=torch.float32)
            self.coord_max = torch.tensor(json.loads(h5.attrs["coord_nominal_max"]), dtype=torch.float32)
```
Return type requirements:

- `__getitem__` must return at least `(sample, label)`.
- If you want to use `--stratified-split`, `label` must be a scalar class id (not a vector).
Step 3 — Register the task in data.py
`data.py` is the registry that connects task ids (CLI strings) to datasets and metadata. If you haven’t opened it yet, start with `data.py`.
3.1 Add the task id
Add your task name to `SUPPORTED_TASKS` in `data.py` (this also becomes a valid `--task` choice in `main_finetune.py`).
```python
SUPPORTED_TASKS = (
    # ...
    "mytask",
)
```
3.2 Map the task to a dataset loader
Add a branch in `_dataset_factory()` that returns a callable producing a `torch.utils.data.Dataset`.
Example for a new vision classification task:
```python
if task == "mytask":
    return lambda p: ImageDataset(p, sample_key="image", label_key="label")
```
3.3 Describe the task for the training loop
Add a branch in `_infer_task_info()` that returns `TaskInfo`. For classification you must set:

- `modality` (`vision` or `iq`)
- `target_type="classification"`
- `num_outputs` (number of classes)
- `in_chans` for vision tasks (channels in your stored tensor)
Example:
```python
if task == "mytask":
    num_classes = 12
    return TaskInfo(
        name=task,
        modality="vision",
        target_type="classification",
        num_outputs=num_classes,
        in_chans=1,  # channels of your stored (C, H, W)
    )
```
For position tasks, WavesFM expects targets normalized to `[-1, 1]` and uses `coord_min`/`coord_max` to denormalize when computing metrics. You can supply them either:

- from HDF5 attributes (see how `pos` reads `coord_nominal_min`/`max`), or
- from fields on your dataset object (see `UWBIndoor.loc_min`/`loc_max`).
Position example (reading bounds from your dataset wrapper):
```python
if task == "mytask-pos":
    sample, label = dataset[0]
    return TaskInfo(
        name=task,
        modality="vision",
        target_type="position",
        num_outputs=int(label.numel()),
        in_chans=int(sample.shape[0]),
        coord_min=dataset.coord_min,
        coord_max=dataset.coord_max,
    )
```
3.4 (Optional) Class weights
`build_datasets()` will attach `train_ds.class_weights` if your `.h5` includes a `class_weights` attribute. In training, pass `--class-weights` to enable weighted cross-entropy.
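To populate that attribute, one common inverse-frequency recipe is sketched below. The formula is an assumption on my part (WavesFM does not mandate a specific weighting scheme); match whatever convention your loss expects:

```python
import numpy as np

def inverse_frequency_weights(labels, num_classes):
    """weight_c = N / (num_classes * count_c): rare classes get larger weights."""
    counts = np.bincount(np.asarray(labels, dtype=np.int64), minlength=num_classes)
    total = counts.sum()
    counts = np.maximum(counts, 1)  # guard against classes with no examples
    return (total / (num_classes * counts)).astype(np.float32)
```

You would then store the result via `h5.attrs["class_weights"] = inverse_frequency_weights(labels, num_classes)`.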
Step 4 — Understand when to modify models_vit.py
Most new tasks do not require changes here.
`models_vit.py` defines `ModalityAdapterViT`, which chooses an input adapter at init time:

- `modality="vision"` → patch embedding over `(C, H, W)`
- `modality="iq"` → segment-based tokenization over `(2, C, T)`
That routing happens in the model’s `forward()` (from `models_vit.py`):
```python
def forward(self, x: torch.Tensor, time_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    if self.modality == "vision":
        tok, token_mask = self._tokens_from_vision(x)         # expects (N, C, H, W)
    else:
        tok, token_mask = self._tokens_from_iq(x, time_mask)  # expects (N, 2, C, T)
    z, token_mask = self.forward_features(tok, token_mask)
    return self.forward_head(z, token_mask)
```
`main_finetune.py` constructs the model by passing:

- `modality` and `num_outputs` from `TaskInfo`
- `vis_in_chans_actual` from `TaskInfo.in_chans` (vision tasks)
- IQ tokenization knobs from CLI (`--iq-segment-len`, `--iq-target-len`, `--iq-downsample`)
Vision channel mismatches are handled by an optional 1×1 “channel adapter” inside the model. In practice, that means:

- you set `TaskInfo.in_chans` to match what you stored in HDF5 (e.g. `2` for real/imag CSI grids), and
- the model will map that to the pretrained patch embed’s expected channels when needed.
You’d typically modify `models_vit.py` only if:

- You are adding a new modality (beyond `vision`/`iq`).
- Your IQ data has more antennas than the model config supports (see `iq_max_antennas` in the `vit_multi_*` builders).
- Your IQ token count would exceed `iq_max_tokens` (the tokenizer truncates tokens past that limit).
Step 5 — Extend the CLI/training loop (main_finetune.py) when needed
For standard tasks, you don’t need to change `main_finetune.py`: it already calls `build_datasets()` and uses `TaskInfo` to pick:

- the model adapter (`vision` vs `iq`),
- the loss function (`classification` vs `position`/`regression`), and
- the evaluation path (generic classification, positioning metrics, etc.).
The key call chain is:

- `main_finetune.py` parses args and calls `build_datasets()` from `data.py`
- `build_datasets()` returns `(train_ds, val_ds, task_info)`
- `build_model()` uses `task_info` to build the correct `ModalityAdapterViT`
You can see it directly in `main_finetune.py`:
```python
train_ds, val_ds, task_info = build_datasets(
    args.task,
    args.train_path,
    val_path=args.val_path,
    val_split=args.val_split,
    stratified_split=args.stratified_split,
    seed=args.seed,
    deepmimo_n_beams=args.deepmimo_n_beams,
)
model = build_model(args, task_info)
```
You do edit `main_finetune.py` when your dataset needs a task-specific argument that affects dataset construction. An example pattern already in the codebase: `--deepmimo-n-beams` gets passed into `build_datasets(..., deepmimo_n_beams=...)`.
If your new task needs custom metrics, add a task-specific branch in `engine.py:evaluate()` (optional, but common for multi-metric datasets).
Example pattern (from `engine.py`):

```python
if task_type == "classification" and task_name == "radcom":
    return evaluate_radcom(...)
```
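As a concrete (and deliberately simple) example of what such a branch might compute, here is a sketch of a balanced-accuracy metric. `balanced_accuracy` is a hypothetical helper, not something that exists in `engine.py`:

```python
import numpy as np

def balanced_accuracy(preds, targets, num_classes):
    """Mean of per-class accuracies over the classes that appear in targets."""
    preds = np.asarray(preds)
    targets = np.asarray(targets)
    accs = []
    for c in range(num_classes):
        mask = targets == c
        if mask.any():
            accs.append(float((preds[mask] == c).mean()))
    return float(np.mean(accs))
```

A task-specific evaluate function would typically collect `preds`/`targets` over the validation loader and return a dict of such scalars.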
Smoke-test checklist
Before running full sweeps, verify the plumbing in this order:
1. The `.h5` opens and has the expected keys (`image`/`label`, etc.).
2. Your dataset returns the right shapes/dtypes for a single item.
3. `python main_finetune.py --task <your_task> --train-data <cache.h5> --val-split 0.2 --epochs 1 --batch-size 2` runs end-to-end.
4. For classification, confirm labels are contiguous class ids in `[0, num_outputs - 1]`.
5. For `position`, confirm targets are normalized to `[-1, 1]` and `coord_min`/`coord_max` are correct.
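The label check in the list above is easy to script. This is a minimal sketch (`check_class_ids` is a hypothetical helper, not part of WavesFM):

```python
import numpy as np

def check_class_ids(labels, num_outputs):
    """Verify classification labels are integer class ids in [0, num_outputs - 1]."""
    labels = np.asarray(labels)
    assert np.issubdtype(labels.dtype, np.integer), "labels must be integer class ids"
    assert labels.min() >= 0 and labels.max() < num_outputs, "label out of range"
    missing = set(range(num_outputs)) - set(np.unique(labels).tolist())
    if missing:
        print(f"warning: classes never observed: {sorted(missing)}")
```

Run it against the `label` dataset of your cache before the first training attempt.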
Tip: for a fast “does it run” check, use the small model config and CPU:
```shell
python main_finetune.py \
    --task <your_task> \
    --train-data <cache.h5> \
    --val-split 0.2 \
    --model vit_multi_micro \
    --device cpu \
    --epochs 1 \
    --batch-size 2
```