Add a New Task/Dataset

This tutorial walks through how to add a new downstream task in the WavesFM codebase. The core idea is:

  1. Preprocess raw data into an .h5 cache file (the training code does not parse raw datasets).
  2. Load that cache via a small Dataset class.
  3. Register the task in data.py so main_finetune.py can build datasets, pick the right model adapter, and choose losses/metrics.

If you follow the existing conventions, you usually only touch preprocessing/ and data.py. Changes to models_vit.py are optional and mostly relevant when you add a new modality or need different tokenization limits.

Relevant WavesFM source files

These are the main “touch points” you’ll reference while implementing a new task:


Step 0 — Decide the task “shape”

WavesFM routes tasks using two fields in data.py:TaskInfo:

  • modality: vision or iq
    • vision covers image-like tensors (including spectrograms).
    • iq covers time-domain IQ tensors.
  • target_type: classification, position, or regression
    • classification → cross-entropy + accuracy/per-class accuracy metrics.
    • position → MSE on normalized coordinates (expects targets scaled to [-1, 1]) + distance error metrics.
    • regression → MSE + MAE/RMSE metrics.

That decision determines what your .h5 should contain and what you’ll return from Dataset.__getitem__().
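As an illustrative aid (this is not WavesFM code, just the contract described above written down), the (modality, target_type) pair implies the following item shapes:

```python
def expected_item(modality: str, target_type: str):
    """Illustrative only -- summarizes the task-shape contract above.
    Returns (sample shape, target description)."""
    sample_shape = "(C, H, W)" if modality == "vision" else "(2, C, T)"
    if target_type == "classification":
        target = "int64 scalar class id"
    elif target_type == "position":
        target = "float32 coords normalized to [-1, 1]"
    elif target_type == "regression":
        target = "float32 value(s)"
    else:
        raise ValueError(f"unknown target_type: {target_type}")
    return sample_shape, target
```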


Step 1 — Preprocess into an HDF5 cache (.h5)

All training runs read from preprocessed .h5 files. At minimum, your cache needs:

  • A dataset for samples (e.g. image, csi, sample, …).
  • A dataset for targets (usually label).

You can look at existing preprocessing scripts under preprocessing/ for conventions and dataset-specific quirks.

Common conventions already used in WavesFM:

  • Vision-style samples: stored per example as (C, H, W) (float32).
  • IQ-style samples: stored per example as (2, C, T) (float32), where 2 is real/imag.
  • Classification labels: stored as an integer scalar per example (int64).
  • Optional HDF5 attributes:
    • labels: JSON list of class names (used for display/debugging).
    • class_weights: numpy array of class weights (used when training with --class-weights).
    • For position tasks: store coord_* metadata (see Step 3).

Minimal write pattern (pseudo-code):

import h5py
import numpy as np

with h5py.File("data/mytask.h5", "w") as h5:
    h5.create_dataset("image", data=samples.astype(np.float32))      # (N, C, H, W)
    h5.create_dataset("label", data=labels.astype(np.int64))         # (N,)
    h5.attrs["labels"] = '["class0","class1", ...]'                  # optional
    h5.attrs["class_weights"] = np.asarray([...], dtype=np.float32)  # optional

Position tasks: coordinate normalization metadata

The built-in pos task reads coordinate bounds from HDF5 attributes named coord_nominal_min and coord_nominal_max (strings that parse as JSON arrays). You can follow the same convention:

import json

h5.attrs["coord_nominal_min"] = json.dumps(coord_min.tolist())
h5.attrs["coord_nominal_max"] = json.dumps(coord_max.tolist())

Those bounds are used at evaluation time to denormalize model outputs/targets and report distance errors.
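If you store those attributes, your preprocessing script also needs to write normalized targets. Here is a sketch of the normalization and its inverse; the helper names are hypothetical, but the [-1, 1] convention matches the description above:

```python
import numpy as np

def normalize_coords(coords, coord_min, coord_max):
    """Map absolute coordinates into [-1, 1] per axis (hypothetical helper,
    matching the convention the position head expects)."""
    coords = np.asarray(coords, dtype=np.float32)
    lo = np.asarray(coord_min, dtype=np.float32)
    hi = np.asarray(coord_max, dtype=np.float32)
    return 2.0 * (coords - lo) / (hi - lo) - 1.0

def denormalize_coords(norm, coord_min, coord_max):
    """Inverse map, as used at evaluation time to report distance errors."""
    lo = np.asarray(coord_min, dtype=np.float32)
    hi = np.asarray(coord_max, dtype=np.float32)
    return (np.asarray(norm, dtype=np.float32) + 1.0) / 2.0 * (hi - lo) + lo
```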


Step 2 — Load the cache (dataset_classes/)

WavesFM already includes minimal HDF5-backed dataset helpers:

  • dataset_classes/base.py:ImageDataset (for vision-style caches)
  • dataset_classes/base.py:IQDataset (for IQ-style caches)

If your .h5 uses standard keys, you can reuse them directly from data.py:

  • Vision example: ImageDataset(path, sample_key="image", label_key="label")
  • IQ example: IQDataset(path, sample_key="sample", label_key="label")

The “contract” is simple (from dataset_classes/base.py):

sample = torch.as_tensor(h5[self.sample_key][idx])
label = torch.as_tensor(h5[self.label_key][idx], dtype=self.label_dtype)
return sample, label

If you need dataset-specific behavior (e.g. a different label mapping, extra metadata, or special shapes), add a new file under dataset_classes/ and export it in dataset_classes/__init__.py.

Example: a thin wrapper around ImageDataset that renames keys and exposes coordinate bounds:

from dataset_classes.base import ImageDataset
import h5py
import json
import torch

class MyTask(ImageDataset):
    def __init__(self, h5_path: str):
        super().__init__(h5_path, sample_key="image", label_key="label", label_dtype=torch.float32)
        with h5py.File(self.h5_path, "r") as h5:
            self.coord_min = torch.tensor(json.loads(h5.attrs["coord_nominal_min"]), dtype=torch.float32)
            self.coord_max = torch.tensor(json.loads(h5.attrs["coord_nominal_max"]), dtype=torch.float32)

Return type requirement:

  • __getitem__ must return at least (sample, label).
  • If you want to use --stratified-split, label must be a scalar class id (not a vector).

Step 3 — Register the task in data.py

data.py is the registry that connects task ids (CLI strings) to datasets and metadata.

If you haven’t opened it yet, start here: data.py.

3.1 Add the task id

Add your task name to SUPPORTED_TASKS in data.py (this also becomes a valid --task choice in main_finetune.py).

SUPPORTED_TASKS = (
    # ...
    "mytask",
)

3.2 Map the task to a dataset loader

Add a branch in _dataset_factory() that returns a callable producing a torch.utils.data.Dataset.

Example for a new vision classification task:

if task == "mytask":
    return lambda p: ImageDataset(p, sample_key="image", label_key="label")

3.3 Describe the task for the training loop

Add a branch in _infer_task_info() that returns TaskInfo.

For classification you must set:

  • modality (vision or iq)
  • target_type="classification"
  • num_outputs (number of classes)
  • in_chans for vision tasks (channels in your stored tensor)

Example:

if task == "mytask":
    num_classes = 12
    return TaskInfo(
        name=task,
        modality="vision",
        target_type="classification",
        num_outputs=num_classes,
        in_chans=1,  # channels of your stored (C, H, W)
    )

For position tasks, WavesFM expects targets normalized to [-1, 1] and uses coord_min/coord_max to denormalize when computing metrics. You can supply them either:

  • from HDF5 attributes (see how pos reads coord_nominal_min/max), or
  • from fields on your dataset object (see UWBIndoor.loc_min/loc_max).

Position example (reading bounds from your dataset wrapper):

if task == "mytask-pos":
    sample, label = dataset[0]
    return TaskInfo(
        name=task,
        modality="vision",
        target_type="position",
        num_outputs=int(label.numel()),
        in_chans=int(sample.shape[0]),
        coord_min=dataset.coord_min,
        coord_max=dataset.coord_max,
    )

3.4 (Optional) Class weights

build_datasets() will attach train_ds.class_weights if your .h5 includes a class_weights attribute. In training, pass --class-weights to enable weighted cross-entropy.
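WavesFM does not prescribe how the class_weights attribute is computed; inverse class frequency is one common choice (an assumption here, not a codebase convention), written so that a balanced dataset yields all-ones weights:

```python
import numpy as np

def inverse_frequency_weights(labels, num_classes):
    """One common (assumed) recipe for the class_weights attribute:
    inverse class frequency, scaled so balanced classes give weight 1.0."""
    counts = np.bincount(np.asarray(labels), minlength=num_classes).astype(np.float32)
    # guard against empty classes to avoid division by zero
    weights = counts.sum() / (num_classes * np.maximum(counts, 1))
    return weights.astype(np.float32)
```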


Step 4 — Understand when to modify models_vit.py

Most new tasks do not require changes here.

models_vit.py defines ModalityAdapterViT, which chooses an input adapter at init time:

  • modality="vision" → patch embedding over (C,H,W)
  • modality="iq" → segment-based tokenization over (2,C,T)

That routing happens in the model’s forward() (from models_vit.py):

def forward(self, x: torch.Tensor, time_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    if self.modality == "vision":
        tok, token_mask = self._tokens_from_vision(x)         # expects (N, C, H, W)
    else:
        tok, token_mask = self._tokens_from_iq(x, time_mask)  # expects (N, 2, C, T)
    z, token_mask = self.forward_features(tok, token_mask)
    return self.forward_head(z, token_mask)

main_finetune.py constructs the model by passing:

  • modality and num_outputs from TaskInfo
  • vis_in_chans_actual from TaskInfo.in_chans (vision tasks)
  • IQ tokenization knobs from CLI (--iq-segment-len, --iq-target-len, --iq-downsample)

Vision channel mismatches are handled by an optional 1×1 “channel adapter” inside the model. In practice, that means:

  • you set TaskInfo.in_chans to match what you stored in HDF5 (e.g. 2 for real/imag CSI grids), and
  • the model will map that to the pretrained patch embed’s expected channels when needed.
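Conceptually, a 1×1 channel adapter is just a learned per-pixel linear map over channels. A numpy sketch of the operation (not the model's actual implementation, which learns `weight` during fine-tuning):

```python
import numpy as np

def channel_adapt(x, weight):
    """Conceptual 1x1 'channel adapter': map C_in stored channels to the
    C_out channels the pretrained patch embed expects, independently at
    each spatial position."""
    # x: (N, C_in, H, W), weight: (C_out, C_in)
    return np.einsum("oc,nchw->nohw", weight, x)
```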

You’d typically modify models_vit.py only if:

  • You are adding a new modality (beyond vision/iq).
  • Your IQ data has more antennas than the model config supports (see iq_max_antennas in the vit_multi_* builders).
  • Your IQ token count would exceed iq_max_tokens (the tokenizer truncates tokens past that limit).
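To estimate whether a given input length will hit iq_max_tokens, a back-of-envelope count like the following helps; the exact formula is an assumption here, so check the tokenizer in models_vit.py for the real one:

```python
import math

def iq_token_count(T, segment_len, downsample=1, max_tokens=None):
    """Rough (assumed) token count for one IQ stream of length T:
    downsample first, split into segments, then truncate at max_tokens."""
    n = math.ceil((T // downsample) / segment_len)
    return min(n, max_tokens) if max_tokens is not None else n
```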

Step 5 — Extend the CLI/training loop (main_finetune.py) when needed

For standard tasks, you don’t need to change main_finetune.py: it already calls build_datasets() and uses TaskInfo to pick:

  • model adapter (vision vs iq)
  • loss function (classification vs position/regression)
  • evaluation path (generic classification, positioning metrics, etc.)

The key call chain is:

  • main_finetune.py parses args and calls build_datasets() from data.py
  • build_datasets() returns (train_ds, val_ds, task_info)
  • build_model() uses task_info to build the correct ModalityAdapterViT

You can see it directly in main_finetune.py:

train_ds, val_ds, task_info = build_datasets(
    args.task,
    args.train_path,
    val_path=args.val_path,
    val_split=args.val_split,
    stratified_split=args.stratified_split,
    seed=args.seed,
    deepmimo_n_beams=args.deepmimo_n_beams,
)
model = build_model(args, task_info)

You do edit main_finetune.py when your dataset needs a task-specific argument that affects dataset construction. Example pattern already in the codebase: --deepmimo-n-beams gets passed into build_datasets(..., deepmimo_n_beams=...).

If your new task needs custom metrics, add a task-specific branch in engine.py:evaluate() (optional, but common for multi-metric datasets).

Example pattern (from engine.py):

if task_type == "classification" and task_name == "radcom":
    return evaluate_radcom(...)
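A task-specific branch usually just computes extra metrics on the collected predictions. As an example of what such a helper might contain (the name and signature are hypothetical), per-class accuracy in numpy:

```python
import numpy as np

def per_class_accuracy(preds, targets, num_classes):
    """Hypothetical metric helper for a custom evaluate() branch:
    accuracy computed separately for each class id (NaN if unseen)."""
    preds = np.asarray(preds)
    targets = np.asarray(targets)
    accs = []
    for c in range(num_classes):
        mask = targets == c
        accs.append(float((preds[mask] == c).mean()) if mask.any() else float("nan"))
    return accs
```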

Smoke-test checklist

Before running full sweeps, verify the plumbing in this order:

  1. h5 opens and has expected keys (image/label etc.).
  2. Your dataset returns the right shapes/dtypes for a single item.
  3. python main_finetune.py --task <your_task> --train-data <cache.h5> --val-split 0.2 --epochs 1 --batch-size 2 runs end-to-end.
  4. For classification, confirm labels are contiguous class ids in [0, num_outputs-1].
  5. For position, confirm targets are normalized to [-1, 1] and coord_min/max are correct.
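Items 1, 2, and 4 of the checklist can be scripted. A sketch that works on an open h5py.File or any mapping of arrays (the key names are just the defaults assumed earlier; adapt them to your cache):

```python
import numpy as np

def smoke_check(h5, sample_key="image", label_key="label"):
    """Verify cache keys, dtypes, and contiguous class ids.
    Returns the inferred number of classes (num_outputs)."""
    assert sample_key in h5 and label_key in h5, "missing dataset keys"
    sample = np.asarray(h5[sample_key][0])
    labels = np.asarray(h5[label_key][:])
    assert sample.dtype == np.float32, "samples should be float32"
    assert np.issubdtype(labels.dtype, np.integer), "labels should be integer ids"
    assert labels.min() == 0, "class ids should start at 0"
    num_classes = int(labels.max()) + 1
    assert set(np.unique(labels)) == set(range(num_classes)), "class ids must be contiguous"
    return num_classes
```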

Tip: for a fast “does it run” check, use the small model config and CPU:

python main_finetune.py \
    --task <your_task> \
    --train-data <cache.h5> \
    --val-split 0.2 \
    --model vit_multi_micro \
    --device cpu \
    --epochs 1 \
    --batch-size 2