loki2.cl.train_projection_wds

Train ProjectionCL using WebDataset shards.

Module Contents

loki2.cl.train_projection_wds.LOGGER_NAME = 'projection_cl.train'
loki2.cl.train_projection_wds.setup_logging(log_path: pathlib.Path | None) → logging.Logger

Set up logging configuration for training.

Parameters:

log_path – Optional path to a log file. If None, messages are logged to the console only.

Returns:

Configured logger instance.

Return type:

logging.Logger
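
As an illustration of the behaviour described above, a minimal sketch follows. It assumes a console handler is always attached and a file handler is added only when log_path is given; the function name setup_logging_sketch, the log level, and the message format are assumptions, not the module's actual choices.

```python
import logging
import sys
from pathlib import Path
from typing import Optional

LOGGER_NAME = "projection_cl.train"  # value of the module constant documented above


def setup_logging_sketch(log_path: Optional[Path]) -> logging.Logger:
    """Hypothetical re-implementation; the real handlers and format are unknown."""
    logger = logging.getLogger(LOGGER_NAME)
    logger.setLevel(logging.INFO)   # assumed level
    logger.handlers.clear()         # avoid duplicate handlers on repeated calls
    logger.propagate = False        # keep messages out of the root logger

    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")

    # Console logging is always enabled.
    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(formatter)
    logger.addHandler(console)

    # File logging is enabled only when a path is supplied.
    if log_path is not None:
        log_path.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
```

Calling setup_logging_sketch(None) in this sketch yields a logger with a single console handler; passing a path adds a second, file-backed handler.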

loki2.cl.train_projection_wds.parse_args(argv: Iterable[str] | None = None) → argparse.Namespace
loki2.cl.train_projection_wds.load_keys(meta_path: pathlib.Path, limit: int | None) → List[str]

Load sample keys from a metadata CSV file.

Parameters:
  • meta_path – Path to the metadata CSV file.

  • limit – Optional maximum number of keys to load. If None, no limit is applied.

Returns:

List of sample keys.

Return type:

List[str]

Raises:

ValueError – If the CSV is missing the ‘key’ column or contains no samples.
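
The contract above (a ‘key’ column, an optional limit, and a ValueError for a missing column or an empty result) could be met by something like the following sketch. It uses only the standard csv module; the name load_keys_sketch and the error messages are hypothetical, and the real function may read the file differently.

```python
import csv
from pathlib import Path
from typing import List, Optional


def load_keys_sketch(meta_path: Path, limit: Optional[int]) -> List[str]:
    """Hypothetical sketch: read the 'key' column of a metadata CSV."""
    with meta_path.open(newline="") as handle:
        reader = csv.DictReader(handle)
        # The header row must contain a 'key' column.
        if reader.fieldnames is None or "key" not in reader.fieldnames:
            raise ValueError(f"{meta_path} is missing the 'key' column")
        keys: List[str] = []
        for row in reader:
            # Stop early once the optional limit is reached.
            if limit is not None and len(keys) >= limit:
                break
            keys.append(row["key"])
    if not keys:
        raise ValueError(f"{meta_path} contains no samples")
    return keys
```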

loki2.cl.train_projection_wds.select_keys(keys: Sequence[str]) → Callable[[Dict[str, Any]], bool]

Create a filter function for WebDataset samples.

Parameters:

keys – Sequence of allowed sample keys.

Returns:

Filter function that returns True if sample key is in the allowed set.

Return type:

Callable[[Dict[str, Any]], bool]
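
A filter of this kind is typically a closure over a set of allowed keys. The sketch below assumes the sample key is stored under "__key__", which is WebDataset's usual convention; whether the real function reads that field is not stated here, and the name select_keys_sketch is hypothetical.

```python
from typing import Any, Callable, Dict, Sequence


def select_keys_sketch(keys: Sequence[str]) -> Callable[[Dict[str, Any]], bool]:
    """Hypothetical sketch: build a predicate that keeps only allowed samples."""
    allowed = frozenset(keys)  # built once; membership tests are O(1) on average

    def _keep(sample: Dict[str, Any]) -> bool:
        # Assumption: the sample key lives under "__key__".
        return sample.get("__key__") in allowed

    return _keep


keep = select_keys_sketch(["cell_001", "cell_002"])
print(keep({"__key__": "cell_001"}))  # True
print(keep({"__key__": "cell_999"}))  # False
```

Converting the sequence to a set up front matters when the key list is long, because the predicate runs once per sample in the stream.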

loki2.cl.train_projection_wds.pop_tensor(sample: Dict[str, Any], stem: str) → torch.Tensor

Extract and remove a tensor from a sample dictionary.

Parameters:
  • sample – Sample dictionary containing tensor data.

  • stem – Base name for the tensor (e.g., “morph” for “morph.npy” or “morph.pt”).

Returns:

Extracted tensor as float32.

Return type:

torch.Tensor

Raises:

KeyError – If neither .npy nor .pt version of the tensor is found.
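
The lookup logic described above can be sketched without PyTorch. The example below only covers finding and removing the entry; the conversion to a float32 torch.Tensor is deliberately left out so that the sketch runs with the standard library alone. The name pop_raw_sketch and the order in which the two extensions are tried are assumptions.

```python
from typing import Any, Dict


def pop_raw_sketch(sample: Dict[str, Any], stem: str) -> Any:
    """Hypothetical sketch: pop '<stem>.npy' or '<stem>.pt' from a sample dict."""
    # Assumed preference order: .npy first, then .pt.
    for extension in ("npy", "pt"):
        name = f"{stem}.{extension}"
        if name in sample:
            # pop() returns the value and removes the entry from the sample.
            return sample.pop(name)
    raise KeyError(f"sample has neither '{stem}.npy' nor '{stem}.pt'")
```

In the real function the popped value would then be converted to a float32 tensor, as stated in the Returns entry.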

loki2.cl.train_projection_wds.decode_sample(sample: Dict[str, Any]) → Tuple[torch.Tensor, torch.Tensor]

Decode a WebDataset sample into morphological and transcription tensors.

Parameters:

sample – Sample dictionary from WebDataset.

Returns:

  • Morphological embedding tensor.

  • Transcription embedding tensor.

Return type:

Tuple[torch.Tensor, torch.Tensor]

loki2.cl.train_projection_wds.collate_batch(batch: List[Tuple[torch.Tensor, torch.Tensor]]) → Tuple[torch.Tensor, torch.Tensor]

Collate a batch of (morph, trans) tuples into stacked tensors.

Parameters:

batch – List of (morphological_tensor, transcription_tensor) tuples.

Returns:

  • Stacked morphological embeddings.

  • Stacked transcription embeddings.

Return type:

Tuple[torch.Tensor, torch.Tensor]

loki2.cl.train_projection_wds.count_shards(shards: Sequence[str]) → int

Count the total number of shards after expanding URLs.

Parameters:

shards – Sequence of shard URLs or patterns.

Returns:

Total number of shards.

Return type:

int
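
To make "expanding URLs" concrete, the sketch below counts shards for patterns that use numeric brace ranges such as {000..003}. It is a simplified, standard-library stand-in: the real function may rely on a brace-expansion library and may support more pattern forms (comma lists, for example). Both function names are hypothetical.

```python
import re
from typing import List, Sequence

# Matches a numeric brace range such as {000..127}.
_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_numeric_range(pattern: str) -> List[str]:
    """Expand every {start..end} range in a pattern (simplified sketch)."""
    match = _RANGE.search(pattern)
    if match is None:
        return [pattern]  # nothing left to expand
    start_text, end_text = match.group(1), match.group(2)
    width = len(start_text)  # keep the zero padding of the start value
    expanded: List[str] = []
    for index in range(int(start_text), int(end_text) + 1):
        replaced = (
            pattern[: match.start()] + str(index).zfill(width) + pattern[match.end():]
        )
        # Recurse so that patterns with several ranges are fully expanded.
        expanded.extend(expand_numeric_range(replaced))
    return expanded


def count_shards_sketch(shards: Sequence[str]) -> int:
    """Total number of shard URLs after expanding each pattern."""
    return sum(len(expand_numeric_range(pattern)) for pattern in shards)


print(count_shards_sketch(["train-{000..003}.tar", "extra.tar"]))  # 5
```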

loki2.cl.train_projection_wds.resolve_num_workers(shards: Sequence[str], requested_workers: int) → int

Resolve the number of workers based on shard count.

Adjusts the number of workers to not exceed the number of shards.

Parameters:
  • shards – Sequence of shard URLs or patterns.

  • requested_workers – Requested number of workers.

Returns:

Adjusted number of workers (at least 1, at most min(requested, shard_count)).

Return type:

int
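
The bounds in the Returns entry amount to a clamp. The sketch below takes an already computed shard count instead of the shard patterns, to stay self-contained; the name clamp_workers_sketch is hypothetical.

```python
def clamp_workers_sketch(shard_count: int, requested_workers: int) -> int:
    """Hypothetical sketch: at least 1, at most min(requested, shard_count)."""
    return max(1, min(requested_workers, shard_count))


print(clamp_workers_sketch(shard_count=4, requested_workers=8))   # 4
print(clamp_workers_sketch(shard_count=16, requested_workers=8))  # 8
print(clamp_workers_sketch(shard_count=0, requested_workers=8))   # 1
```

The likely motivation is that WebDataset distributes whole shards across workers, so a worker without a shard would have nothing to read.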

loki2.cl.train_projection_wds.create_loader(shards: Sequence[str], keys: Sequence[str], batch_size: int, shuffle_buffer: int, num_workers: int, *, training: bool) → webdataset.WebLoader

Create a WebDataset DataLoader for training or validation.

Parameters:
  • shards – Sequence of shard URLs or patterns.

  • keys – Sequence of sample keys to include.

  • batch_size – Batch size for the loader.

  • shuffle_buffer – Size of shuffle buffer (only used if training=True).

  • num_workers – Number of worker processes.

  • training – Whether this is for training (enables shuffling).

Returns:

Configured WebDataset loader.

Return type:

webdataset.WebLoader

loki2.cl.train_projection_wds.train(args: argparse.Namespace) → None

Train a ProjectionCL model using WebDataset shards.

Parameters:

args – Parsed command-line arguments containing training configuration.

Raises:

ValueError – If no training samples are found.

loki2.cl.train_projection_wds.main(argv: Iterable[str] | None = None) → None

Main entry point for training a ProjectionCL model.

Parameters:

argv – Optional command-line arguments. If None, uses sys.argv.

Raises:

FileNotFoundError – If the training metadata CSV is not found.
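
The argv handling and the existence check could look roughly like the sketch below. The flag name --train-meta and both function names are hypothetical, since the actual command-line options are not listed in this reference, and the sketch returns the parsed arguments only so that the behaviour can be inspected (the documented main returns None).

```python
import argparse
from pathlib import Path
from typing import Iterable, Optional


def parse_args_sketch(argv: Optional[Iterable[str]] = None) -> argparse.Namespace:
    """Hypothetical parser with a single, made-up option."""
    parser = argparse.ArgumentParser(description="Sketch of a training CLI")
    # Hypothetical flag name; the real CLI options are not documented here.
    parser.add_argument("--train-meta", type=Path, required=True)
    # argparse falls back to sys.argv[1:] when it is given None.
    return parser.parse_args(None if argv is None else list(argv))


def main_sketch(argv: Optional[Iterable[str]] = None) -> argparse.Namespace:
    """Hypothetical entry point: validate inputs before training starts."""
    args = parse_args_sketch(argv)
    if not args.train_meta.is_file():
        raise FileNotFoundError(f"train metadata CSV not found: {args.train_meta}")
    # The real main presumably calls train(args) at this point.
    return args
```

Accepting argv as a parameter, rather than always reading sys.argv, makes the entry point callable from tests and notebooks.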