loki2.cl.train_projection_wds
=============================

.. py:module:: loki2.cl.train_projection_wds

.. autoapi-nested-parse::

   Train ProjectionCL using WebDataset shards.


Module Contents
---------------

.. py:data:: LOGGER_NAME
   :value: 'projection_cl.train'


.. py:function:: setup_logging(log_path: Optional[pathlib.Path]) -> logging.Logger

   Set up logging configuration for training.

   :param log_path: Optional path to a log file. If None, logs go only to the console.

   :returns: Configured logger instance.
   :rtype: logging.Logger


.. py:function:: parse_args(argv: Iterable[str] | None = None) -> argparse.Namespace


.. py:function:: load_keys(meta_path: pathlib.Path, limit: int | None) -> List[str]

   Load sample keys from a metadata CSV file.

   :param meta_path: Path to the metadata CSV file.
   :param limit: Optional maximum number of keys to load.

   :returns: List of sample keys.
   :rtype: List[str]

   :raises ValueError: If the CSV is missing the ``key`` column or contains no samples.


.. py:function:: select_keys(keys: Sequence[str]) -> Callable[[Dict[str, Any]], bool]

   Create a filter function for WebDataset samples.

   :param keys: Sequence of allowed sample keys.

   :returns: Filter function that returns True if the sample's key is in the allowed set.
   :rtype: Callable[[Dict[str, Any]], bool]


.. py:function:: pop_tensor(sample: Dict[str, Any], stem: str) -> torch.Tensor

   Extract and remove a tensor from a sample dictionary.

   :param sample: Sample dictionary containing tensor data.
   :param stem: Base name of the tensor entry (e.g., ``"morph"`` for ``morph.npy`` or ``morph.pt``).

   :returns: Extracted tensor, converted to float32.
   :rtype: torch.Tensor

   :raises KeyError: If neither the ``.npy`` nor the ``.pt`` entry for the tensor is found.


.. py:function:: decode_sample(sample: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]

   Decode a WebDataset sample into morphological and transcription tensors.

   :param sample: Sample dictionary from WebDataset.

   :returns: A tuple of the morphological embedding tensor and the transcription embedding tensor, in that order.
   :rtype: Tuple[torch.Tensor, torch.Tensor]


.. py:function:: collate_batch(batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]

   Collate a batch of ``(morph, trans)`` tuples into stacked tensors.

   :param batch: List of ``(morphological_tensor, transcription_tensor)`` tuples.

   :returns: A tuple of the stacked morphological embeddings and the stacked transcription embeddings, in that order.
   :rtype: Tuple[torch.Tensor, torch.Tensor]


.. py:function:: count_shards(shards: Sequence[str]) -> int

   Count the total number of shards after expanding shard URL patterns.

   :param shards: Sequence of shard URLs or patterns.

   :returns: Total number of shards.
   :rtype: int


.. py:function:: resolve_num_workers(shards: Sequence[str], requested_workers: int) -> int

   Resolve the number of worker processes from the shard count.

   Caps the number of workers at the number of shards, so that no worker is left without a shard to read.

   :param shards: Sequence of shard URLs or patterns.
   :param requested_workers: Requested number of workers.

   :returns: Adjusted number of workers, ``max(1, min(requested_workers, shard_count))``.
   :rtype: int


.. py:function:: create_loader(shards: Sequence[str], keys: Sequence[str], batch_size: int, shuffle_buffer: int, num_workers: int, *, training: bool) -> webdataset.WebLoader

   Create a WebDataset DataLoader for training or validation.

   :param shards: Sequence of shard URLs or patterns.
   :param keys: Sequence of sample keys to include.
   :param batch_size: Batch size for the loader.
   :param shuffle_buffer: Size of the shuffle buffer; used only when ``training`` is True.
   :param num_workers: Number of worker processes.
   :param training: Whether the loader is for training; if True, shuffling is enabled.

   :returns: Configured WebDataset loader.
   :rtype: webdataset.WebLoader


.. py:function:: train(args: argparse.Namespace) -> None

   Train a ProjectionCL model using WebDataset shards.

   :param args: Parsed command-line arguments containing the training configuration.

   :raises ValueError: If no training samples are found.


.. py:function:: main(argv: Iterable[str] | None = None) -> None

   Command-line entry point for training a ProjectionCL model.

   :param argv: Optional command-line arguments. If None, arguments are read from ``sys.argv``.

   :raises FileNotFoundError: If the training metadata CSV is not found.
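
The filtering rule described for ``select_keys`` and the worker cap described for ``resolve_num_workers`` can be illustrated with a minimal, dependency-free sketch. It assumes WebDataset-style samples, which are dicts carrying their key under ``"__key__"``, and it takes a precomputed shard count instead of shard URLs. The ``*_sketch`` names are hypothetical; this is not the module's actual implementation.

.. code-block:: python

   # Minimal sketch, NOT the module's implementation. Assumes samples are
   # dicts with the sample key stored under "__key__" (WebDataset convention).
   # select_keys_sketch / resolve_num_workers_sketch are hypothetical names.
   from typing import Any, Callable, Dict, Sequence


   def select_keys_sketch(keys: Sequence[str]) -> Callable[[Dict[str, Any]], bool]:
       """Return a predicate that keeps samples whose key is in ``keys``."""
       allowed = frozenset(keys)  # set membership is O(1) per sample

       def _keep(sample: Dict[str, Any]) -> bool:
           # Samples without a "__key__" entry are rejected.
           return sample.get("__key__") in allowed

       return _keep


   def resolve_num_workers_sketch(shard_count: int, requested_workers: int) -> int:
       """Cap workers at the shard count, but never return fewer than 1."""
       # Assumption: takes an already-expanded shard count, not shard URLs.
       return max(1, min(requested_workers, shard_count))


   keep = select_keys_sketch(["a", "c"])
   samples = [{"__key__": k} for k in ("a", "b", "c")]
   print([s["__key__"] for s in samples if keep(s)])  # ['a', 'c']
   print(resolve_num_workers_sketch(4, 8))   # 4
   print(resolve_num_workers_sketch(16, 8))  # 8
   print(resolve_num_workers_sketch(4, 0))   # 1

Freezing the allowed keys into a set keeps the per-sample check constant-time, which matters because the predicate runs on every sample streamed from the shards.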