API documentation

align

This module aligns a read (query) sequence to a reference sequence using Parasail. It also provides functions to generate alignment strings and chunked output for pretty printing.

Author: Adnan M. Niazi Date: 2024-02-28

PairwiseAlignment dataclass

Pairwise alignment with semi-global alignment allowing for gaps at the start and end of the query sequence.

Source code in src/capfinder/align.py
@dataclass
class PairwiseAlignment:
    """
    Pairwise alignment with semi-global alignment allowing for gaps at the
    start and end of the query sequence.
    """

    ref_start: int
    ref_end: int
    query_start: int
    query_end: int
    cigar_pysam: CigarTuplesPySam
    cigar_sam: CigarTuplesSam

    def __init__(
        self,
        ref_start: int,
        ref_end: int,
        query_start: int,
        query_end: int,
        cigar_pysam: CigarTuplesPySam,
        cigar_sam: CigarTuplesSam,
    ):
        """
        Initializes a PairwiseAlignment object.

        Args:
            ref_start (int): The starting position of the alignment in the reference sequence.
            ref_end (int): The ending position of the alignment in the reference sequence.
            query_start (int): The starting position of the alignment in the query sequence.
            query_end (int): The ending position of the alignment in the query sequence.
            cigar_pysam (CigarTuplesPySam): A list of tuples representing the CIGAR string in the Pysam format.
            cigar_sam (CigarTuplesSam): A list of tuples representing the CIGAR string in the SAM format.
        """
        self.ref_start = ref_start
        self.ref_end = ref_end
        self.query_start = query_start
        self.query_end = query_end
        self.cigar_pysam = cigar_pysam
        self.cigar_sam = cigar_sam

__init__(ref_start: int, ref_end: int, query_start: int, query_end: int, cigar_pysam: CigarTuplesPySam, cigar_sam: CigarTuplesSam)

Initializes a PairwiseAlignment object.

Parameters:

    ref_start (int): The starting position of the alignment in the reference sequence. Required.
    ref_end (int): The ending position of the alignment in the reference sequence. Required.
    query_start (int): The starting position of the alignment in the query sequence. Required.
    query_end (int): The ending position of the alignment in the query sequence. Required.
    cigar_pysam (CigarTuplesPySam): A list of tuples representing the CIGAR string in the Pysam format. Required.
    cigar_sam (CigarTuplesSam): A list of tuples representing the CIGAR string in the SAM format. Required.

Source code in src/capfinder/align.py
def __init__(
    self,
    ref_start: int,
    ref_end: int,
    query_start: int,
    query_end: int,
    cigar_pysam: CigarTuplesPySam,
    cigar_sam: CigarTuplesSam,
):
    """
    Initializes a PairwiseAlignment object.

    Args:
        ref_start (int): The starting position of the alignment in the reference sequence.
        ref_end (int): The ending position of the alignment in the reference sequence.
        query_start (int): The starting position of the alignment in the query sequence.
        query_end (int): The ending position of the alignment in the query sequence.
        cigar_pysam (CigarTuplesPySam): A list of tuples representing the CIGAR string in the Pysam format.
        cigar_sam (CigarTuplesSam): A list of tuples representing the CIGAR string in the SAM format.
    """
    self.ref_start = ref_start
    self.ref_end = ref_end
    self.query_start = query_start
    self.query_end = query_end
    self.cigar_pysam = cigar_pysam
    self.cigar_sam = cigar_sam

align(query_seq: str, target_seq: str, pretty_print_alns: bool) -> Tuple[str, str, str, int]

Main function call to align two sequences and print the alignment.

Parameters:

    query_seq (str): The query sequence. Required.
    target_seq (str): The target/reference sequence. Required.
    pretty_print_alns (bool): Whether to print the alignment in a pretty format. Required.

Returns:

    Tuple[str, str, str, int]: A tuple containing four elements:
        1. The aligned query sequence with gaps.
        2. The visual representation of the alignment, with '|' for matches, '/' for mismatches, and ' ' for gaps or insertions.
        3. The aligned target sequence with gaps.
        4. The alignment score.

Source code in src/capfinder/align.py
def align(
    query_seq: str, target_seq: str, pretty_print_alns: bool
) -> Tuple[str, str, str, int]:
    """
    Main function call to align two sequences and print the alignment.

    Args:
        query_seq (str): The query sequence.
        target_seq (str): The target/reference sequence.
        pretty_print_alns (bool): Whether to print the alignment in a pretty format.

    Returns:
        Tuple[str, str, str, int]: A tuple containing four elements:
            1. The aligned query sequence with gaps.
            2. The visual representation of the alignment with '|' for matches, '/' for mismatches,
                and ' ' for gaps or insertions.
            3. The aligned target sequence with gaps.
            4. The alignment score.

    """
    # Perform the alignment
    alignment = parasail_align(query=query_seq, ref=target_seq)
    alignment_score = alignment.score
    alignment = trim_parasail_alignment(alignment)
    # Generate the aligned strings
    aln_query, aln, aln_target = make_alignment_strings(
        query_seq, target_seq, alignment
    )

    # Print the alignment in a pretty format if required
    if pretty_print_alns:
        print("Alignment score:", alignment_score)
        chunked_aln_str = make_alignment_chunks(
            aln_target, aln_query, aln, chunk_size=40
        )
        print(chunked_aln_str)
        return (
            "",
            "",
            chunked_aln_str,
            alignment_score,
        )
    else:
        return aln_query, aln, aln_target, alignment_score
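
A minimal usage sketch (the import path follows the source location shown above; the toy sequences and the exact aligned output are illustrative):

from capfinder.align import align

query = "ACGTACGTAC"       # read sequence (toy example)
target = "GGACGTACGTACGG"  # reference sequence (toy example)

# With pretty printing disabled, the aligned strings are returned directly.
aln_query, aln, aln_target, score = align(
    query_seq=query, target_seq=target, pretty_print_alns=False
)
print(score)

# With pretty printing enabled, the chunked alignment is printed and returned as the
# third element of the tuple; the first two elements are empty strings.
_, _, chunked, score = align(query_seq=query, target_seq=target, pretty_print_alns=True)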

cigartuples_from_string(cigarstring: str) -> CigarTuplesPySam

Returns pysam-style list of (op, count) tuples from a cigarstring.

Source code in src/capfinder/align.py
def cigartuples_from_string(cigarstring: str) -> CigarTuplesPySam:
    """
    Returns pysam-style list of (op, count) tuples from a cigarstring.
    """
    return [
        (CODE_TO_OP[m.group(2)], int(m.group(1)))
        for m in re.finditer(CIGAR_STRING_PATTERN, cigarstring)
    ]
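
CODE_TO_OP and CIGAR_STRING_PATTERN are module-level constants that are not shown on this page. A self-contained sketch of the same idea, using the standard pysam numeric operation codes (the module's actual constants may differ slightly):

import re

# Standard pysam numeric codes for CIGAR operations (assumed mapping)
CODE_TO_OP = {"M": 0, "I": 1, "D": 2, "N": 3, "S": 4, "H": 5, "P": 6, "=": 7, "X": 8}
CIGAR_STRING_PATTERN = re.compile(r"(\d+)([MIDNSHP=X])")

def cigartuples_from_string(cigarstring: str):
    # (op_code, length) tuples, pysam-style
    return [
        (CODE_TO_OP[m.group(2)], int(m.group(1)))
        for m in CIGAR_STRING_PATTERN.finditer(cigarstring)
    ]

print(cigartuples_from_string("10=1X5=2D7="))
# [(7, 10), (8, 1), (7, 5), (2, 2), (7, 7)]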

make_alignment_chunks(target: str, query: str, alignment: str, chunk_size: int) -> str

Divide three strings (target, query, and alignment) into chunks of the specified length and format them as QRY/ALN/REF triplets with a one-line gap between each triplet.

Parameters:

    target (str): The target/reference string. Required.
    query (str): The query string. Required.
    alignment (str): The alignment string. Required.
    chunk_size (int): The desired chunk size. Required.

Returns:

    aln_string (str): The aligned strings in chunks with the QRY/ALN/REF prefixes.

Source code in src/capfinder/align.py
def make_alignment_chunks(
    target: str, query: str, alignment: str, chunk_size: int
) -> str:
    """
    Divide three strings (target, query, and alignment) into chunks of the specified length
    and print them as triplets with the specified prefixes and a one-line gap between each triplet.

    Args:
        target (str): The target/reference string.
        query (str): The query string.
        alignment (str): The alignment string.
        chunk_size (int): The desired chunk size.

    Returns:
        aln_string (str): The aligned strings in chunks with the specified prefix.
    """
    # Check if chunk size is valid
    if chunk_size <= 0:
        raise ValueError("Chunk size must be greater than zero")

    # Divide the strings into chunks
    target_chunks = [
        target[i : i + chunk_size] for i in range(0, len(target), chunk_size)
    ]
    query_chunks = [query[i : i + chunk_size] for i in range(0, len(query), chunk_size)]
    alignment_chunks = [
        alignment[i : i + chunk_size] for i in range(0, len(alignment), chunk_size)
    ]

    # Iterate over the triplets and print them
    aln_string = ""
    for t_chunk, q_chunk, a_chunk in zip(target_chunks, query_chunks, alignment_chunks):
        aln_string += f"QRY: {q_chunk}\n"
        aln_string += f"ALN: {a_chunk}\n"
        aln_string += f"REF: {t_chunk}\n\n"

    return aln_string
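
For example, calling the function on short pre-aligned toy strings (in practice the inputs come from make_alignment_strings):

from capfinder.align import make_alignment_chunks

target = "ACGTACGTAC"
query = "ACGTACGTAC"
aln = "||||||||||"

print(make_alignment_chunks(target=target, query=query, alignment=aln, chunk_size=4))
# QRY: ACGT
# ALN: ||||
# REF: ACGT
#
# QRY: ACGT
# ALN: ||||
# REF: ACGT
#
# QRY: AC
# ALN: ||
# REF: AC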

make_alignment_strings(query: str, target: str, alignment: PairwiseAlignment) -> Tuple[str, str, str]

Generate alignment strings for the given query and target sequences based on a PairwiseAlignment object.

Parameters:

    query (str): The query sequence. Required.
    target (str): The target/reference sequence. Required.
    alignment (PairwiseAlignment): An object representing the alignment between query and target sequences. Required.

Returns:

    Tuple[str, str, str]: A tuple containing three strings:
        1. The aligned query sequence with gaps.
        2. The visual representation of the alignment, with '|' for matches, '/' for mismatches, and ' ' for gaps or insertions.
        3. The aligned target sequence with gaps.

Source code in src/capfinder/align.py
def make_alignment_strings(
    query: str, target: str, alignment: PairwiseAlignment
) -> Tuple[str, str, str]:
    """
    Generate alignment strings for the given query and target sequences based on a PairwiseAlignment object.

    Args:
        query (str): The query sequence.
        target (str): The target/reference sequence.
        alignment (PairwiseAlignment): An object representing the alignment between query and target sequences.

    Returns:
        Tuple[str, str, str]: A tuple containing three strings:
            1. The aligned query sequence with gaps.
            2. The visual representation of the alignment with '|' for matches, '/' for mismatches,
               and ' ' for gaps or insertions.
            3. The aligned target sequence with gaps.
    """
    ref_start = alignment.ref_start
    ref_end = alignment.ref_end
    query_start = alignment.query_start
    cigar_sam = alignment.cigar_sam

    # Initialize the strings
    aln_query = ""
    aln_target = ""
    aln = ""
    target_count = 0
    query_count = 0

    # Handle the start
    if query_start != 0:
        aln_target += "-" * query_start
        aln_query += query[:query_start]
        aln += " " * query_start
        query_count += query_start

    if ref_start != 0:
        aln_target += target[:ref_start]
        target_count = ref_start
        aln_query += "-" * ref_start
        aln += " " * ref_start

    # Handle the middle
    for operation, length in cigar_sam:
        # Match: advance both target and query counts
        if operation in ("=", "M"):
            aln_target += target[target_count : target_count + length]
            aln_query += query[query_count : query_count + length]
            aln += "|" * length
            target_count += length
            query_count += length

        # Insertion: advance query count only
        elif operation == "I":
            aln_target += "-" * length
            aln_query += query[query_count : query_count + length]
            aln += " " * length
            query_count += length

        # Deletion or gaps: advance target count only
        # see: https://jef.works/blog/2017/03/28/CIGAR-strings-for-dummies/
        elif operation in ("D", "N"):
            aln_target += target[target_count : target_count + length]
            aln_query += "-" * length
            aln += " " * length
            target_count += length

        # Mismatch: advance both target and query counts
        elif operation == "X":
            aln_target += target[target_count : target_count + length]
            aln_query += query[query_count : query_count + length]
            aln += "/" * length
            target_count += length
            query_count += length

    # Handle the end
    ql = len(query)
    tl = len(target)
    target_remainder = tl - ref_end
    if target_remainder:
        aln_target += target[target_count:]
        aln_query += target_remainder * "-"
        aln += target_remainder * " "

    end_dash_len = ql - query_count
    if end_dash_len:
        aln_target += "-" * end_dash_len
        aln_query += query[query_count:]
        aln += " " * end_dash_len
        query_count += end_dash_len

    return aln_query, aln, aln_target
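
A small, hand-constructed sketch (the PairwiseAlignment below is built manually for illustration; in practice it comes from trim_parasail_alignment):

from capfinder.align import PairwiseAlignment, make_alignment_strings

query = "ACGT"
target = "ACCT"

# 2 matches, 1 mismatch, 1 match: "2=1X1=" in SAM terms
alignment = PairwiseAlignment(
    ref_start=0,
    ref_end=4,
    query_start=0,
    query_end=4,
    cigar_pysam=[(7, 2), (8, 1), (7, 1)],
    cigar_sam=[("=", 2), ("X", 1), ("=", 1)],
)

aln_query, aln, aln_target = make_alignment_strings(query, target, alignment)
print(aln_query)   # ACGT
print(aln)         # ||/|
print(aln_target)  # ACCT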

parasail_align(*, query: str, ref: str) -> Any

Semi-global alignment allowing for gaps at the start and end of the query sequence.

Parameters:

    query (str): The query sequence.
    ref (str): The reference sequence.

Returns:

    The raw Parasail alignment result (later converted to a PairwiseAlignment by trim_parasail_alignment).

Source code in src/capfinder/align.py
def parasail_align(*, query: str, ref: str) -> Any:
    """
    Semi-global alignment allowing for gaps at the start and end of the query
    sequence.

    :param query: str
    :param ref: str
    :return: Parasail alignment result (converted to a PairwiseAlignment by trim_parasail_alignment)
    """
    alignment_result = parasail.sg_trace_scan_32(query, ref, 10, 2, parasail.dnafull)
    return alignment_result
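
The returned object is Parasail's result, not a PairwiseAlignment; it exposes, among other things, a score and a traceback CIGAR that trim_parasail_alignment then consumes. A rough sketch (the sequences are toy inputs, and the exact score and CIGAR depend on the Parasail version and the dnafull matrix):

from capfinder.align import parasail_align

result = parasail_align(query="ACGTACGT", ref="GGACGTACGTGG")
print(result.score)                  # alignment score
print(result.cigar.decode.decode())  # CIGAR string, e.g. something like "2D8=2D"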

trim_parasail_alignment(alignment_result: Any) -> PairwiseAlignment

Trim the alignment result to remove leading and trailing gaps.

Source code in src/capfinder/align.py
def trim_parasail_alignment(alignment_result: Any) -> PairwiseAlignment:
    """
    Trim the alignment result to remove leading and trailing gaps.
    """

    try:
        ref_start = 0
        ref_end = alignment_result.len_ref
        query_start = 0
        query_end = alignment_result.len_query
        fixed_start = False
        fixed_end = False

        cigar_string = alignment_result.cigar.decode.decode()
        cigar_tuples = deque(cigartuples_from_string(cigar_string))

        while not (fixed_start and fixed_end):
            first_op, first_length = cigar_tuples[0]
            if first_op in (1, 4):  # insert, soft-clip, increment query start
                query_start += first_length
                cigar_tuples.popleft()
            elif first_op == 2:  # delete, increment reference start
                ref_start += first_length
                cigar_tuples.popleft()
            else:
                fixed_start = True

            last_op, last_length = cigar_tuples[-1]
            if last_op in (1, 4):  # decrement the query end
                query_end -= last_length
                cigar_tuples.pop()
            elif last_op == 2:  # decrement the ref_end
                ref_end -= last_length
                cigar_tuples.pop()
            else:
                fixed_end = True

        cigar_pysam = list(cigar_tuples)
        cigar_sam = [(OP_TO_CODE[str(k)], v) for k, v in cigar_pysam]

        return PairwiseAlignment(
            ref_start=ref_start,
            ref_end=ref_end,
            query_start=query_start,
            query_end=query_end,
            cigar_pysam=cigar_pysam,
            cigar_sam=cigar_sam,
        )
    except IndexError as e:
        raise RuntimeError(
            "failed to find match operations in pairwise alignment"
        ) from e

attention_cnnlstm_model

CapfinderHyperModel

Bases: HyperModel

Hypermodel for the Capfinder CNN-LSTM with Attention architecture.

This model is designed for time series classification tasks, specifically for identifying RNA cap types. It combines Convolutional Neural Networks (CNNs) for local feature extraction, Long Short-Term Memory (LSTM) networks for sequence processing, and an attention mechanism to focus on the most relevant parts of the input sequence.

The architecture is flexible and allows for hyperparameter tuning of the number of layers, units, and other key parameters.

Attributes:

    input_shape (Tuple[int, ...]): The shape of the input data.
    n_classes (int): The number of classes for classification.
    encoder_model (Optional[Model]): Placeholder for a potential encoder model.

Methods:

    build(hp): Constructs and returns a Keras model based on the provided hyperparameters.

Source code in src/capfinder/attention_cnnlstm_model.py
class CapfinderHyperModel(HyperModel):
    """
    Hypermodel for the Capfinder CNN-LSTM with Attention architecture.

    This model is designed for time series classification tasks, specifically for
    identifying RNA cap types. It combines Convolutional Neural Networks (CNNs) for
    local feature extraction, Long Short-Term Memory (LSTM) networks for sequence
    processing, and an attention mechanism to focus on the most relevant parts of
    the input sequence.

    The architecture is flexible and allows for hyperparameter tuning of the number
    of layers, units, and other key parameters.

    Attributes:
        input_shape (Tuple[int, ...]): The shape of the input data.
        n_classes (int): The number of classes for classification.
        encoder_model (Optional[Model]): Placeholder for a potential encoder model.

    Methods:
        build(hp): Constructs and returns a Keras model based on the provided
                   hyperparameters.
    """

    def __init__(self, input_shape: Tuple[int, int], n_classes: int) -> None:
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.encoder_model = None

    def build(self, hp: Any) -> Model:
        inputs = Input(shape=self.input_shape)
        x = inputs

        # Calculate the maximum number of conv layers based on input size
        max_conv_layers = min(
            int(math.log2(self.input_shape[0])) - 1, 5
        )  # Limit to 5 layers max
        conv_layers = hp.Int("conv_layers", 1, max_conv_layers)

        # Convolutional layers
        for i in range(conv_layers):
            # Dynamically adjust the range for filters based on the layer depth
            max_filters = min(256, 32 * (2 ** (i + 1)))
            filters = hp.Int(f"filters_{i}", 32, max_filters, step=32)

            # Dynamically adjust the kernel size based on the current feature map size
            current_size = x.shape[1]
            max_kernel_size = min(7, current_size)
            kernel_size = hp.Choice(
                f"kernel_size_{i}", list(range(3, max_kernel_size + 1, 2))
            )

            x = Conv1D(
                filters=filters,
                kernel_size=kernel_size,
                activation="relu",
                padding="same",
            )(x)

            # Only apply MaxPooling if the current size is greater than 2
            if current_size > 2:
                x = MaxPooling1D(pool_size=2)(x)

            x = Dropout(hp.Float(f"dropout_{i}", 0.1, 0.5, step=0.1))(x)
            x = BatchNormalization()(x)

        # Calculate the maximum number of LSTM layers based on remaining sequence length
        current_seq_length = x.shape[1]
        max_lstm_layers = min(
            int(math.log2(current_seq_length)) + 1, 3
        )  # Limit to 3 LSTM layers max
        lstm_layers = hp.Int("lstm_layers", 1, max_lstm_layers)

        # LSTM layers
        for i in range(lstm_layers):
            return_sequences = i < lstm_layers - 1
            max_lstm_units = min(256, 32 * (2 ** (i + 1)))
            lstm_units = hp.Int(f"lstm_units_{i}", 32, max_lstm_units, step=32)
            x = LSTM(
                units=lstm_units,
                return_sequences=return_sequences or i == lstm_layers - 1,
            )(x)
            x = Dropout(hp.Float(f"lstm_dropout_{i}", 0.1, 0.5, step=0.1))(x)
            x = BatchNormalization()(x)

        # Attention layer
        x = AttentionLayer()(x)

        # Fully connected layer
        max_dense_units = min(256, x.shape[-1] * 2)
        dense_units = hp.Int("dense_units", 16, max_dense_units, step=16)
        x = Dense(units=dense_units, activation="relu")(x)
        x = Dropout(hp.Float("dense_dropout", 0.1, 0.5, step=0.1))(x)
        outputs = Dense(self.n_classes, activation="softmax")(x)

        model = Model(inputs=inputs, outputs=outputs)

        model.compile(
            optimizer=Adam(
                learning_rate=hp.Choice("learning_rate", [1e-2, 1e-3, 1e-4])
            ),
            loss="sparse_categorical_crossentropy",
            metrics=["sparse_categorical_accuracy"],
        )

        return model
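
The hypermodel is intended to be driven by a tuner such as KerasTuner's Hyperband. A minimal sketch, assuming keras_tuner is installed, the import path matches the source location shown above, and the input shape/class count below are illustrative:

import keras_tuner as kt

from capfinder.attention_cnnlstm_model import CapfinderHyperModel

# Hypothetical inputs: 500 signal points, 1 feature per time step, 4 cap classes
hypermodel = CapfinderHyperModel(input_shape=(500, 1), n_classes=4)

tuner = kt.Hyperband(
    hypermodel,
    objective="val_sparse_categorical_accuracy",
    max_epochs=3,
    factor=2,
    directory="tuner_dir",
    project_name="capfinder_tune",
)

# train_ds and val_ds are assumed to be prepared elsewhere (e.g. by the ETL pipeline):
# tuner.search(train_ds, validation_data=val_ds)
# best_model = tuner.get_best_models(num_models=1)[0]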

bam

BAM records can only be read sequentially; pysam does not provide random access to individual BAM records. This module prepares and yields the BAM record information for each read.

Author: Adnan M. Niazi Date: 2024-02-28

generate_bam_records(bam_filepath: str) -> Generator[pysam.AlignedSegment, None, None]

Yield each record from a BAM file. Also creates an index (.bai) file if one does not exist already.

Parameters:

    bam_filepath (str): Path to the BAM file. Required.

Yields:

    record (pysam.AlignedSegment): A BAM record.

Source code in src/capfinder/bam.py
def generate_bam_records(
    bam_filepath: str,
) -> Generator[pysam.AlignedSegment, None, None]:
    """Yield each record from a BAM file. Also creates an index (.bai)
    file if one does not exist already.

    Params:
        bam_filepath: str
            Path to the BAM file.

    Yields:
        record: pysam.AlignedSegment
            A BAM record.
    """
    index_filepath = f"{bam_filepath}.bai"

    if not os.path.exists(index_filepath):
        pysam.index(bam_filepath)  # type: ignore

    with pysam.AlignmentFile(bam_filepath, "rb") as bam_file:
        yield from bam_file
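
A short sketch of iterating over the records (the path is illustrative):

from capfinder.bam import generate_bam_records

bam_path = "/path/to/sorted.bam"  # illustrative path

for record in generate_bam_records(bam_path):
    # Each record is a pysam.AlignedSegment
    print(record.query_name, record.mapping_quality)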

get_signal_info(record: pysam.AlignedSegment) -> Dict[str, Any]

Returns the signal info from a BAM record.

Parameters:

    record (pysam.AlignedSegment): A BAM record. Required.

Returns:

    signal_info (Dict[str, Any]): Dictionary containing signal info for a read.

Source code in src/capfinder/bam.py
def get_signal_info(record: pysam.AlignedSegment) -> Dict[str, Any]:
    """Returns the signal info from a BAM record.

    Params:
        record: pysam.AlignedSegment
            A BAM record.

    Returns:
        signal_info: Dict[str, Any]
            Dictionary containing signal info for a read.
    """
    signal_info = {}
    tags_dict = dict(record.tags)  # type: ignore
    moves_table = tags_dict["mv"]
    moves_step = moves_table.pop(0)
    signal_info["moves_table"] = moves_table
    signal_info["moves_step"] = moves_step
    signal_info["read_id"] = record.query_name
    signal_info["start_sample"] = tags_dict["ts"]
    signal_info["num_samples"] = tags_dict["ns"]
    signal_info["quality_score"] = tags_dict["qs"]
    signal_info["channel"] = tags_dict["ch"]
    signal_info["signal_mean"] = tags_dict["sm"]
    signal_info["signal_sd"] = tags_dict["sd"]
    signal_info["is_qcfail"] = record.is_qcfail
    signal_info["is_reverse"] = record.is_reverse
    signal_info["is_forward"] = record.is_forward
    signal_info["is_mapped"] = record.is_mapped
    signal_info["is_supplementary"] = record.is_supplementary
    signal_info["is_secondary"] = record.is_secondary
    signal_info["read_quality"] = record.qual  # type: ignore
    signal_info["read_fasta"] = record.query_sequence
    signal_info["mapping_quality"] = record.mapping_quality
    signal_info["parent_read_id"] = tags_dict.get("pi", "")
    signal_info["split_point"] = tags_dict.get("sp", 0)
    signal_info["time_stamp"] = tags_dict.get("st")
    signal_info["pod5_filename"] = tags_dict.get("fn")
    (
        signal_info["num_left_clipped_bases"],
        signal_info["num_right_clipped_bases"],
    ) = find_hard_clipped_bases(str(record.cigarstring))
    return signal_info

get_total_records(bam_filepath: str) -> int

Returns the total number of records in a BAM file.

Parameters:

    bam_filepath (str): Path to the BAM file. Required.

Returns:

    total_records (int): Total number of records in the BAM file.

Source code in src/capfinder/bam.py
def get_total_records(bam_filepath: str) -> int:
    """Returns the total number of records in a BAM file.

    Params:
        bam_filepath: str
            Path to the BAM file.

    Returns:
        total_records: int
            Total number of records in the BAM file.
    """
    bam_file = pysam.AlignmentFile(bam_filepath)
    total_records = sum(1 for _ in bam_file)
    bam_file.close()
    return total_records

process_bam_records(bam_filepath: str) -> Generator[Dict[str, Any], None, None]

Top level function to process a BAM file. Yields signal info for each read in the BAM file.

Parameters:

    bam_filepath (str): Path to the BAM file to process. Required.

Yields:

    signal_info (Dict[str, Any]): Dictionary containing signal info for a read.

Source code in src/capfinder/bam.py
def process_bam_records(bam_filepath: str) -> Generator[Dict[str, Any], None, None]:
    """Top level function to process a BAM file.
    Yields signal info for each read in the BAM file.

    Params:
        bam_filepath: str
            Path to the BAM file to process.

    Yields:
        signal_info: Generator[Dict[str, Any], None, None]
            Dictionary containing signal info for a read.
    """
    for record in generate_bam_records(bam_filepath):
        yield get_signal_info(record)
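
A sketch that combines process_bam_records with get_total_records to report progress (the path is illustrative):

from capfinder.bam import get_total_records, process_bam_records

bam_path = "/path/to/sorted.bam"  # illustrative path

total = get_total_records(bam_path)
for i, signal_info in enumerate(process_bam_records(bam_path), start=1):
    read_id = signal_info["read_id"]
    moves = signal_info["moves_table"]
    print(f"[{i}/{total}] {read_id}: {len(moves)} moves")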

cli

add_cap(cap_int: int, cap_name: str) -> None

Add a new cap mapping or update an existing one.

Source code in src/capfinder/cli.py
@caps_app.command("add")
def add_cap(cap_int: int, cap_name: str) -> None:
    """Add a new cap mapping or update an existing one."""
    global CAP_MAPPING

    next_available = get_next_available_cap_number()

    # Check if the cap name is unique
    existing_cap_int = is_cap_name_unique(cap_name)
    if existing_cap_int is not None and existing_cap_int != cap_int:
        typer.echo(
            f"Error: The cap name '{cap_name}' is already used for cap number {existing_cap_int}."
        )
        typer.echo("Please use a unique name for each cap.")
        return

    if cap_int in CAP_MAPPING:
        update_cap_mapping({cap_int: cap_name})
        typer.echo(f"Updated existing mapping: {cap_int} -> {cap_name}")
    elif cap_int == next_available:
        update_cap_mapping({cap_int: cap_name})
        typer.echo(f"Added new mapping: {cap_int} -> {cap_name}")
    else:
        typer.echo(f"Error: The next available cap number is {next_available}.")
        typer.echo(
            f"Please use {next_available} as the cap number to maintain continuity."
        )
        return
    typer.echo(f"Custom mappings saved to: {CUSTOM_MAPPING_PATH}")

cap_help() -> None

Display help information about cap mapping management.

Source code in src/capfinder/cli.py
@caps_app.command("help")
def cap_help() -> None:
    """Display help information about cap mapping management."""
    typer.echo("Cap Mapping Management Help")
    typer.echo("----------------------------")
    typer.echo(
        "Capfinder allows you to customize cap mappings. These mappings persist across runs."
    )
    typer.echo(f"\nYour custom mappings are stored in: {CUSTOM_MAPPING_PATH}")
    typer.echo("\nAvailable commands:")
    typer.echo("  capfinder capmap add <int> <name>  : Add or update a cap mapping")
    typer.echo("  capfinder capmap remove <int>      : Remove a cap mapping")
    typer.echo("  capfinder capmap list              : List all current cap mappings")
    typer.echo("  capfinder capmap reset             : Reset cap mappings to default")
    typer.echo(
        "  capfinder capmap config            : Show the location of the configuration file"
    )
    typer.echo("\nExamples:")
    typer.echo("  capfinder capmap add 7 new_cap_type")
    typer.echo("  capfinder capmap remove 7")
    typer.echo("  capfinder capmap list")
    typer.echo(
        "\nNote: Changes to cap mappings are immediately saved and will persist across runs."
    )
    typer.echo(
        "When adding a new cap, you must use the next available number in the sequence."
    )

create_train_config(file_path: str = '') -> None

Creates a dummy JSON configuration file at the specified path. Edit it to suit your needs.

Source code in src/capfinder/cli.py
@app.command()
def create_train_config(
    file_path: Annotated[
        str,
        typer.Option(
            "--file_path", "-f", help="File path to save the JSON configuration file"
        ),
    ] = "",
) -> None:
    """Creates a dummy JSON configuration file at the specified path. Edit it to suit your needs."""
    config = {
        "etl_params": {
            "use_remote_dataset_version": "latest",  # Version of the remote dataset to use, e.g., "latest", "1.0.0", etc. If set to "", then a local dataset will be used/made and/or uploaded to the remote dataset
            "caps_data_dir": "/dir/",  # Directory containing cap signal data files for all cap classes in the model
            "examples_per_class": 100000,  # Maximum number of examples to use per class
            "comet_project_name": "dataset",  # Name of the Comet ML project for dataset logging
        },
        "tune_params": {
            "comet_project_name": "capfinder_tune",  # Name of the Comet ML project for hyperparameter tuning
            "patience": 0,  # Number of epochs with no improvement after which training will be stopped
            "max_epochs_hpt": 3,  # Maximum number of epochs for each trial during hyperparameter tuning
            "max_trials": 5,  # Maximum number of trials for hyperparameter search
            "factor": 2,  # Reduction factor for Hyperband algorithm
            "seed": 42,  # Random seed for reproducibility
            "tuning_strategy": "hyperband",  # Options: "hyperband", "random_search", "bayesian_optimization"
            "overwrite": False,  # Whether to overwrite previous tuning results. All hyperparameter tuning results will be lost if set to True
        },  # Added comma here
        "train_params": {
            "comet_project_name": "capfinder_train",  # Name of the Comet ML project for model training
            "patience": 120,  # Number of epochs with no improvement after which training will be stopped
            "max_epochs_final_model": 300,  # Maximum number of epochs for training the final model
        },
        "shared_params": {
            "num_classes": 4,  # Number of classes in the dataset
            "model_type": "cnn_lstm",  # Options: "attention_cnn_lstm", "cnn_lstm", "encoder", "resnet"
            "batch_size": 32,  # Batch size for training
            "target_length": 500,  # Target length for input sequences
            "dtype": "float16",  # Data type for model parameters. Options: "float16", "float32", "float64"
            "train_test_fraction": 0.95,  # Fraction of total data to use for training (vs. testing)
            "train_val_fraction": 0.8,  # Fraction of training data to use for training (vs. validation)
            "use_augmentation": False,  # Whether to include time warped versions of original training examples in the dataset
            "output_dir": "/dir/",  # Directory to save output files
        },
        "lr_scheduler_params": {
            "type": "reduce_lr_on_plateau",  # Options: "reduce_lr_on_plateau", "cyclic_lr", "sgdr"
            "reduce_lr_on_plateau": {
                "factor": 0.5,  # Factor by which the learning rate will be reduced
                "patience": 5,  # Number of epochs with no improvement after which learning rate will be reduced
                "min_lr": 1e-6,  # Lower bound on the learning rate
            },
            "cyclic_lr": {
                "base_lr": 1e-3,  # Initial learning rate which is the lower boundary in the cycle
                "max_lr": 5e-2,  # Upper boundary in the cycle for learning rate
                "step_size_factor": 8,  # Number of training iterations in the increasing half of a cycle
                "mode": "triangular2",  # One of {triangular, triangular2, exp_range}
            },
            "sgdr": {
                "min_lr": 1e-3,  # Minimum learning rate
                "max_lr": 2e-3,  # Maximum learning rate
                "lr_decay": 0.9,  # Decay factor for learning rate
                "cycle_length": 5,  # Number of epochs in a cycle
                "mult_factor": 1.5,  # Multiplication factor for cycle length after each restart
            },
        },
        "debug_code": False,  # Whether to run in debug mode
    }
    import json

    from capfinder.logger_config import configure_logger, configure_prefect_logging
    from capfinder.utils import log_header, log_output

    log_filepath = configure_logger(
        os.path.join(os.path.dirname(file_path), "logs"), show_location=False
    )
    configure_prefect_logging(show_location=False)
    version_info = version("capfinder")
    log_header(f"Using Capfinder v{version_info}")

    with open(file_path, "w") as file:
        json.dump(config, file, indent=4)

    grey = "\033[90m"
    reset = "\033[0m"
    log_output(
        f"The training config JSON file has been saved to:\n {grey}{file_path}{reset}\nThe log file has been saved to:\n {grey}{log_filepath}{reset}"
    )
    log_header("Processing finished!")

extract_cap_signal(bam_filepath: str = '', pod5_dir: str = '', reference: str = 'GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT', cap_class: int = -99, cap_n1_pos0: int = 52, train_or_test: str = 'test', output_dir: str = '', n_workers: int = 1, plot_signal: Optional[bool] = None, debug_code: bool = False) -> None

Extracts signal corresponding to the RNA cap type using BAM and POD5 files. Also, generates plots if required.

Example command (for training data):

    capfinder extract-cap-signal \
        --bam_filepath /path/to/sorted.bam \
        --pod5_dir /path/to/pod5_dir \
        --reference GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGTNNNNNNCGATGTAACTGGGACATGGTGAGCAATCAGGGAAAAAAAAAAAAAAA \
        --cap_class 0 \
        --cap_n1_pos0 52 \
        --train_or_test train \
        --output_dir /path/to/output_dir \
        --n_workers 10 \
        --no-plot-signal \
        --no-debug

Example command (for testing data):

    capfinder extract-cap-signal \
        --bam_filepath /path/to/sorted.bam \
        --pod5_dir /path/to/pod5_dir \
        --reference GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT \
        --cap_class -99 \
        --cap_n1_pos0 52 \
        --train_or_test test \
        --output_dir /path/to/output_dir \
        --n_workers 10 \
        --no-plot-signal \
        --no-debug

Source code in src/capfinder/cli.py
@app.command()
def extract_cap_signal(
    bam_filepath: Annotated[
        str, typer.Option("--bam_filepath", "-b", help="Path to the BAM file")
    ] = "",
    pod5_dir: Annotated[
        str,
        typer.Option(
            "--pod5_dir", "-p", help="Path to directory containing POD5 files"
        ),
    ] = "",
    reference: Annotated[
        str,
        typer.Option("--reference", "-r", help="Reference Sequence (5' -> 3')"),
    ] = "GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT",
    cap_class: Annotated[
        int,
        typer.Option(
            "--cap_class",
            "-c",
            help="""\n
    Integer-based class label for the RNA cap type. \n
    - -99 represents an unknown cap(s). \n
    - 0 represents Cap_0 \n
    - 1 represents Cap 1 \n
    - 2 represents Cap 2 \n
    - 3 represents Cap2-1 \n
    You can use the capmap command to manage cap mappings and use additional integer labels for additional caps. \n
    """,
        ),
    ] = -99,
    cap_n1_pos0: Annotated[
        int,
        typer.Option(
            "--cap_n1_pos0",
            "-p",
            help="0-based index of 1st nucleotide (N1) of cap in the reference",
        ),
    ] = 52,
    train_or_test: Annotated[
        str,
        typer.Option(
            "--train_or_test",
            "-t",
            help="set to train or test depending on whether it is training or testing data",
        ),
    ] = "test",
    output_dir: Annotated[
        str,
        typer.Option(
            "--output_dir",
            "-o",
            help=textwrap.dedent(
                """
        Path to the output directory which will contain: \n
            ├── A CSV file (data__cap_x.csv) containing the extracted ROI signal data.\n
            ├── A CSV file (metadata__cap_x.csv) containing the complete metadata information.\n
            ├── A log file (capfinder_vXYZ_datatime.log) containing the logs of the program.\n
            └── (Optional) plots directory containing cap signal plots, if --plot-signal is used.\n
            \u200B    ├── good_reads: Directory that contains the plots for the good reads.\n
            \u200B    ├── bad_reads: Directory that contains the plots for the bad reads.\n
            \u200B    └── plotpaths.csv: CSV file containing the paths to the plots based on the read ID.\n"""
            ),
        ),
    ] = "",
    n_workers: Annotated[
        int,
        typer.Option(
            "--n_workers", "-n", help="Number of CPUs to use for parallel processing"
        ),
    ] = 1,
    plot_signal: Annotated[
        Optional[bool],
        typer.Option(
            "--plot-signal/--no-plot-signal",
            help="Whether to plot extracted cap signal or not",
        ),
    ] = None,
    debug_code: Annotated[
        bool,
        typer.Option(
            "--debug/--no-debug",
            help="Enable debug mode for more detailed logging",
        ),
    ] = False,
) -> None:
    """
    Extracts signal corresponding to the RNA cap type using BAM and POD5 files. Also, generates plots if required.

    Example command (for training data):
    capfinder extract-cap-signal \\
        --bam_filepath /path/to/sorted.bam \\
        --pod5_dir /path/to/pod5_dir \\
        --reference GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGTNNNNNNCGATGTAACTGGGACATGGTGAGCAATCAGGGAAAAAAAAAAAAAAA \\
        --cap_class 0 \\
        --cap_n1_pos0 52 \\
        --train_or_test train \\
        --output_dir /path/to/output_dir \\
        --n_workers 10 \\
        --no-plot-signal \\
        --no-debug

    Example command (for testing data):
    capfinder extract-cap-signal \\
        --bam_filepath /path/to/sorted.bam \\
        --pod5_dir /path/to/pod5_dir \\
        --reference GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT \\
        --cap_class -99 \\
        --cap_n1_pos0 52 \\
        --train_or_test test \\
        --output_dir /path/to/output_dir \\
        --n_workers 10 \\
        --no-plot-signal \\
        --no-debug
    """
    from capfinder.collate import collate_bam_pod5_wrapper

    ps = False
    if plot_signal is None:
        ps = False
    elif plot_signal:
        ps = True
    else:
        ps = False

    global formatted_command_global

    collate_bam_pod5_wrapper(
        bam_filepath=bam_filepath,
        pod5_dir=pod5_dir,
        num_processes=n_workers,
        reference=reference,
        cap_class=cap_class,
        cap0_pos=cap_n1_pos0,
        train_or_test=train_or_test,
        plot_signal=ps,
        output_dir=output_dir,
        debug_code=debug_code,
        formatted_command=formatted_command_global,
    )

list_caps() -> None

List all current cap mappings.

Source code in src/capfinder/cli.py
@caps_app.command("list")
def list_caps() -> None:
    """List all current cap mappings."""
    load_custom_mapping()  # Reload the mappings from the file
    global CAP_MAPPING

    if not CAP_MAPPING:
        typer.echo("No cap mappings found. Using default mappings:")
        for cap_int, cap_name in sorted(DEFAULT_CAP_MAPPING.items()):
            typer.echo(f"{cap_int}: {cap_name}")
    else:
        typer.echo("Current cap mappings:")
        for cap_int, cap_name in sorted(CAP_MAPPING.items()):
            typer.echo(f"{cap_int}: {cap_name}")

    next_available = get_next_available_cap_number()
    typer.echo(f"\nNext available cap number: {next_available}")
    typer.echo(f"\nCustom mappings file location: {CUSTOM_MAPPING_PATH}")

make_train_dataset(caps_data_dir: str = '', output_dir: str = '', target_length: int = 500, dtype: str = 'float16', examples_per_class: int = 1000, train_test_fraction: float = 0.95, train_val_fraction: float = 0.8, num_classes: int = 4, batch_size: int = 1024, comet_project_name: str = 'dataset', use_remote_dataset_version: str = '', use_augmentation: bool = False) -> None

Prepares the dataset for training the ML model. This command can be run on its own, or it is invoked automatically by the train-model command.

This command processes cap signal data files, applies necessary transformations, and prepares a dataset suitable for training machine learning models. It supports both local data processing and fetching from a remote dataset.

Example command:

    capfinder make-train-dataset \
        --caps_data_dir /path/to/caps_data \
        --output_dir /path/to/output \
        --target_length 500 \
        --dtype float16 \
        --examples_per_class 1000 \
        --train_test_fraction 0.95 \
        --train_val_fraction 0.8 \
        --num_classes 4 \
        --batch_size 32 \
        --comet_project_name my-capfinder-project \
        --use_remote_dataset_version latest \
        --use-augmentation

Source code in src/capfinder/cli.py
@app.command()
def make_train_dataset(
    caps_data_dir: Annotated[
        str,
        typer.Option(
            "--caps_data_dir",
            "-c",
            help="Directory containing all the cap signal data files (data__cap_x.csv)",
        ),
    ] = "",
    output_dir: Annotated[
        str,
        typer.Option(
            "--output_dir",
            "-o",
            help="A dataset directory will be created inside this directory automatically and the dataset will be saved there as CSV files.",
        ),
    ] = "",
    target_length: Annotated[
        int,
        typer.Option(
            "--target_length",
            "-t",
            help="Number of signal points in cap signal to consider. If the signal is shorter, it will be padded with zeros. If the signal is longer, it will be truncated.",
        ),
    ] = 500,
    dtype: Annotated[
        str,
        typer.Option(
            "--dtype",
            "-d",
            help="Data type to transform the dataset to. Valid values are 'float16', 'float32', or 'float64'.",
        ),
    ] = "float16",
    examples_per_class: Annotated[
        int,
        typer.Option(
            "--examples_per_class",
            "-e",
            help="Number of examples to include per class in the dataset",
        ),
    ] = 1000,
    train_test_fraction: Annotated[
        float,
        typer.Option(
            "--train_test_fraction",
            "-tt",
            help="Fraction of data out of all data to use for training (0.0 to 1.0)",
        ),
    ] = 0.95,
    train_val_fraction: Annotated[
        float,
        typer.Option(
            "--train_val_fraction",
            "-tv",
            help="Fraction of data out all the training split to use for validation (0.0 to 1.0)",
        ),
    ] = 0.8,
    num_classes: Annotated[
        int,
        typer.Option(
            "--num_classes",
            help="Number of classes in the dataset",
        ),
    ] = 4,
    batch_size: Annotated[
        int,
        typer.Option(
            "--batch_size",
            "-b",
            help="Batch size for processing data",
        ),
    ] = 1024,
    comet_project_name: Annotated[
        str,
        typer.Option(
            "--comet_project_name",
            help="Name of the Comet ML project for logging",
        ),
    ] = "dataset",
    use_remote_dataset_version: Annotated[
        str,
        typer.Option(
            "--use_remote_dataset_version",
            help="Version of the remote dataset to use. If not provided at all, the local dataset will be used/made and/or uploaded",
        ),
    ] = "",
    use_augmentation: Annotated[
        bool,
        typer.Option(
            "--use-augmentation/--no-use-augmentation",
            help="Whether to augment original data with time warped data",
        ),
    ] = False,
) -> None:
    """
    Prepares dataset for training the ML model. This command can be run independently
    from here or is automatically invoked by the `train-model` command.

    This command processes cap signal data files, applies necessary transformations,
    and prepares a dataset suitable for training machine learning models. It supports
    both local data processing and fetching from a remote dataset.

    Example command:
    capfinder make-train-dataset \\
        --caps_data_dir /path/to/caps_data \\
        --output_dir /path/to/output \\
        --target_length 500 \\
        --dtype float16 \\
        --examples_per_class 1000 \\
        --train_test_fraction 0.95 \\
        --train_val_fraction 0.8 \\
        --num_classes 4 \\
        --batch_size 32 \\
        --comet_project_name my-capfinder-project \\
        --use_remote_dataset_version latest
        --use-augmentation

    """
    from typing import cast

    from capfinder.logger_config import configure_logger, configure_prefect_logging
    from capfinder.train_etl import DtypeLiteral, train_etl
    from capfinder.utils import log_header, log_output

    global formatted_command_global

    dataset_dir = os.path.join(output_dir, "dataset")
    if not os.path.exists(dataset_dir):
        os.makedirs(dataset_dir)
    log_filepath = configure_logger(
        os.path.join(dataset_dir, "logs"), show_location=False
    )
    configure_prefect_logging(show_location=False)
    version_info = version("capfinder")
    log_header(f"Using Capfinder v{version_info}")
    logger.info(formatted_command_global)

    dt: DtypeLiteral = "float32"
    if dtype in {"float16", "float32", "float64"}:
        dt = cast(DtypeLiteral, dtype)
    else:
        logger.warning(
            f"Invalid dtype literal: {dtype}. Allowed values are 'float16', 'float32', 'float64'. Using 'float32' as default."
        )

    train_etl(
        caps_data_dir=caps_data_dir,
        dataset_dir=dataset_dir,
        target_length=target_length,
        dtype=dt,
        examples_per_class=examples_per_class,
        train_test_fraction=train_test_fraction,
        train_val_fraction=train_val_fraction,
        num_classes=num_classes,
        batch_size=batch_size,
        comet_project_name=comet_project_name,
        use_remote_dataset_version=use_remote_dataset_version,
        use_augmentation=use_augmentation,
    )

    grey = "\033[90m"
    reset = "\033[0m"
    log_output(f"The log file has been saved to:\n {grey}{log_filepath}{reset}")
    log_header("Processing finished!")

predict_cap_types(bam_filepath: str = '', pod5_dir: str = '', output_dir: str = '', n_cpus: int = 1, dtype: str = 'float16', batch_size: int = 128, custom_model_path: Optional[str] = None, plot_signal: bool = False, debug_code: bool = False, refresh_cache: bool = False) -> None

Predicts RNA cap types using BAM and POD5 files.

Example command:

    capfinder predict-cap-types \
        --bam_filepath /path/to/sorted.bam \
        --pod5_dir /path/to/pod5_dir \
        --output_dir /path/to/output_dir \
        --n_cpus 10 \
        --dtype float16 \
        --batch_size 256 \
        --no-plot-signal \
        --no-debug \
        --no-refresh-cache

Source code in src/capfinder/cli.py
@app.command()
def predict_cap_types(
    bam_filepath: Annotated[
        str, typer.Option("--bam_filepath", "-b", help="Path to the BAM file")
    ] = "",
    pod5_dir: Annotated[
        str,
        typer.Option(
            "--pod5_dir", "-p", help="Path to directory containing POD5 files"
        ),
    ] = "",
    output_dir: Annotated[
        str,
        typer.Option(
            "--output_dir",
            "-o",
            help="Path to the output directory for prediction results and logs",
        ),
    ] = "",
    # reference: Annotated[
    #     str,
    #     typer.Option("--reference", "-r", help="Reference Sequence (5' -> 3')"),
    # ] = "GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT",
    # cap_n1_pos0: Annotated[
    #     int,
    #     typer.Option(
    #         "--cap_n1_pos0",
    #         "-p",
    #         help="0-based index of 1st nucleotide (N1) of cap in the reference",
    #     ),
    # ] = 52,
    n_cpus: Annotated[
        int,
        typer.Option(
            "--n_cpus",
            "-n",
            help=textwrap.dedent(
                """\
            Number of CPUs to use for parallel processing.
            We use multiple CPUs during processing for POD5 file and BAM data (Step 1/5).
            For faster processing of this data (POD5 & BAM), increase the number of CPUs.
            For inference (Step 4/5), only a single CPU is used no matter how many CPUs you have specified.
            For faster inference, have a GPU available (it will be detected automatically) and set dtype to 'float16'."""
            ),
        ),
    ] = 1,
    dtype: Annotated[
        str,
        typer.Option(
            "--dtype",
            "-d",
            help=textwrap.dedent(
                """\
            Data type for model input. Valid values are 'float16', 'float32', or 'float64'.
            If you do not have a GPU, use 'float32' or 'float64' for better performance.
            If you have a GPU, use 'float16' for faster inference."""
            ),
        ),
    ] = "float16",
    # target_length: Annotated[
    #     int,
    #     typer.Option(
    #         "--target_length",
    #         "-t",
    #         help="Number of signal points in cap signal to consider",
    #     ),
    # ] = 500,
    batch_size: Annotated[
        int,
        typer.Option(
            "--batch_size",
            "-bs",
            help=textwrap.dedent(
                """\
            Batch size for model inference.
            Larger batch sizes can speed up inference but require more memory."""
            ),
        ),
    ] = 128,
    custom_model_path: Annotated[
        Optional[str],
        typer.Option(
            "--custom_model_path",
            "-m",
            help="Path to a custom model (.keras) file. If not provided, the default pre-packaged model will be used.",
        ),
    ] = None,
    plot_signal: Annotated[
        bool,
        typer.Option(
            "--plot-signal/--no-plot-signal",
            help=textwrap.dedent(
                """\
                "Whether to plot extracted cap signal or not.
                Saving plots can help you plot the read's signal, and plot the signal for cap and flanking bases(&#177;5)."""
            ),
        ),
    ] = False,
    debug_code: Annotated[
        bool,
        typer.Option(
            "--debug/--no-debug",
            help="Enable debug mode for more detailed logging",
        ),
    ] = False,
    refresh_cache: Annotated[
        bool,
        typer.Option(
            "--refresh-cache/--no-refresh-cache",
            help="Refresh the cache for intermediate results",
        ),
    ] = False,
) -> None:
    """
    Predicts RNA cap types using BAM and POD5 files.

    Example command:
        capfinder predict-cap-types \\
        --bam_filepath /path/to/sorted.bam \\
        --pod5_dir /path/to/pod5_dir \\
        --output_dir /path/to/output_dir \\
        --n_cpus 10 \\
        --dtype float16 \\
        --batch_size 256 \\
        --no-plot-signal \\
        --no-debug \\
        --no-refresh-cache
    """
    from typing import cast

    from capfinder.inference import predict_cap_types
    from capfinder.train_etl import DtypeLiteral

    dt: DtypeLiteral = "float16"
    if dtype in {"float16", "float32", "float64"}:
        dt = cast(
            DtypeLiteral, dtype
    )  # This is safe because dtype is guaranteed to be one of the Literal values
    else:
        logger.warning(
            f"Invalid dtype literal: {dtype}. Allowed values are 'float16', 'float32', 'float64'. Using 'float16' as default."
        )

    global formatted_command_global

    predict_cap_types(
        bam_filepath=bam_filepath,
        pod5_dir=pod5_dir,
        num_cpus=n_cpus,
        output_dir=output_dir,
        dtype=dt,
        reference="GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT",
        cap0_pos=52,
        train_or_test="test",
        plot_signal=plot_signal,
        cap_class=-99,
        target_length=500,
        batch_size=batch_size,
        custom_model_path=custom_model_path,
        debug_code=debug_code,
        refresh_cache=refresh_cache,
        formatted_command=formatted_command_global,
    )
    logger.success("Finished predicting cap types!")
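
Example: the dtype-narrowing pattern used above, as a standalone sketch. The helper name narrow_dtype is illustrative and not part of the package; the point is that a plain string is first validated against the allowed values and only then narrowed to the Literal type with typing.cast.

from typing import Literal, cast

DtypeLiteral = Literal["float16", "float32", "float64"]


def narrow_dtype(dtype: str, default: DtypeLiteral = "float16") -> DtypeLiteral:
    """Return dtype as a DtypeLiteral if valid, otherwise fall back to the default."""
    if dtype in {"float16", "float32", "float64"}:
        # Safe: membership in the set guarantees dtype is one of the Literal values.
        return cast(DtypeLiteral, dtype)
    print(f"Invalid dtype '{dtype}'; falling back to '{default}'.")
    return default


print(narrow_dtype("float32"))  # float32
print(narrow_dtype("int8"))     # float16 (fallback)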

remove_cap(cap_int: int) -> None

Remove a cap mapping.

Source code in src/capfinder/cli.py
@caps_app.command("remove")
def remove_cap(cap_int: int) -> None:
    """Remove a cap mapping."""
    global CAP_MAPPING

    if cap_int in CAP_MAPPING:
        del CAP_MAPPING[cap_int]
        save_custom_mapping(CAP_MAPPING)
        typer.echo(f"Removed mapping for cap integer: {cap_int}")
        typer.echo(f"Custom mappings saved to: {CUSTOM_MAPPING_PATH}")
    else:
        typer.echo(f"No mapping found for cap integer: {cap_int}")

reset_caps() -> None

Reset cap mappings to default.

Source code in src/capfinder/cli.py
@caps_app.command("reset")
def reset_caps() -> None:
    """Reset cap mappings to default."""
    global CAP_MAPPING
    CAP_MAPPING = DEFAULT_CAP_MAPPING.copy()
    save_custom_mapping(CAP_MAPPING)
    typer.echo("Cap mappings reset to default.")
    typer.echo(f"Default mappings saved to: {CUSTOM_MAPPING_PATH}")

show_config() -> None

Show the location of the configuration file.

Source code in src/capfinder/cli.py
@caps_app.command("config")
def show_config() -> None:
    """Show the location of the configuration file."""
    typer.echo(f"Custom mappings file location: {CUSTOM_MAPPING_PATH}")
    if CUSTOM_MAPPING_PATH.exists():
        typer.echo("The file exists and contains custom mappings.")
    else:
        typer.echo(
            "The file does not exist yet. It will be created when you add a custom mapping."
        )
        logger.warning(f"Config file does not exist at {CUSTOM_MAPPING_PATH}")

train_model(config_file: Annotated[str, typer.Option(--config_file, -c, help='Path to the JSON configuration file containing the parameters for the training pipeline.')] = '') -> None

Trains the model using the parameters in the JSON configuration file.

Source code in src/capfinder/cli.py
@app.command()
def train_model(
    config_file: Annotated[
        str,
        typer.Option(
            "--config_file",
            "-c",
            help="Path to the JSON configuration file containing the parameters for the training pipeline.",
        ),
    ] = "",
) -> None:
    """Trains the model using the parameters in the JSON configuration file."""
    import json

    from capfinder.training import run_training_pipeline

    # Load the configuration file
    with open(config_file) as file:
        config = json.load(file)

    etl_params = config["etl_params"]
    tune_params = config["tune_params"]
    train_params = config["train_params"]
    shared_params = config["shared_params"]
    lr_scheduler_params = config["lr_scheduler_params"]
    debug_code = config.get("debug_code", False)

    # Create a formatted command string with all parameters
    formatted_command = f"capfinder train-model --config_file {config_file}\n\n"
    formatted_command += "Configuration:\n"
    formatted_command += json.dumps(config, indent=2)

    # Run the training pipeline with the loaded parameters
    run_training_pipeline(
        etl_params=etl_params,
        tune_params=tune_params,
        train_params=train_params,
        shared_params=shared_params,
        lr_scheduler_params=lr_scheduler_params,
        debug_code=debug_code,
        formatted_command=formatted_command,
    )
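
Example: a skeleton configuration file containing the top-level keys read above. Only these keys come from the loader; the nested values are left as empty placeholders here, because the accepted parameters depend on the ETL, tuning, training, and learning-rate-scheduler pipelines.

import json

config = {
    "etl_params": {},           # parameters for the ETL step (placeholders)
    "tune_params": {},          # parameters for hyperparameter tuning
    "train_params": {},         # parameters for model training
    "shared_params": {},        # parameters shared across steps
    "lr_scheduler_params": {},  # parameters for the learning-rate scheduler
    "debug_code": False,        # optional; defaults to False when omitted
}

with open("train_config.json", "w") as fh:
    json.dump(config, fh, indent=2)

# Then run: capfinder train-model --config_file train_config.json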

collate

The main workhorse, which collates information from the BAM file and the POD5 files, aligns the OTE to the read, and extracts the signal for the region of interest (ROI) for training or testing purposes. It also plots the ROI signal if requested.

Author: Adnan M. Niazi Date: 2024-02-28

DatabaseHandler

Source code in src/capfinder/collate.py
class DatabaseHandler:
    def __init__(
        self,
        cap_class: int,
        num_processes: int,
        database_path: str,
        plots_csv_filepath: Union[str, None],
        output_dir: str,
    ) -> None:
        """Initializes the index database handler"""
        self.cap_class = cap_class
        self.database_path = database_path
        self.plots_csv_filepath = plots_csv_filepath
        self.num_processes = num_processes
        self.output_dir = output_dir

        # Open the plots CSV file in append mode
        if self.plots_csv_filepath:
            self.csvfile = open(self.plots_csv_filepath, "a", newline="")

    def init_func(self, worker_id: int, worker_state: Dict[str, Any]) -> None:
        """Opens the database connection and CSV files"""

        # 1. Open the database connection and cursor
        worker_state["db_connection"], worker_state["db_cursor"] = open_database(
            self.database_path
        )

        # 2. Write the header row to the plots CSV file
        if self.plots_csv_filepath:
            csv_writer = csv.writer(self.csvfile)
            csv_writer.writerow(["read_id", "plot_filepath"])
            worker_state["csv_writer"] = csv_writer
            worker_state["csvfile"] = self.csvfile  # Store csvfile in worker_state

        # Define paths to data and metadata CSV files
        data_file_path = os.path.join(self.output_dir, f"data_tmp_{worker_id}.csv")
        metadata_file_path = os.path.join(
            self.output_dir, f"metadata_tmp_{worker_id}.csv"
        )

        # 3. Open data_file_path in append mode and write the header row if the file is empty
        data_file = open(data_file_path, "a", newline="")
        data_writer = csv.writer(data_file)
        if data_file.tell() == 0:  # Check if the file is empty
            data_writer.writerow(
                ["read_id", "cap_class", "timeseries"]
            )  # Replace with your actual header

        # Save the data file path and writer to worker_state
        worker_state["data_file"] = data_file
        worker_state["data_writer"] = data_writer

        # 4. Open metadata_file_path in append mode and write the header row if the file is empty
        metadata_file = open(metadata_file_path, "a", newline="")
        metadata_writer = csv.writer(metadata_file)
        if metadata_file.tell() == 0:  # Check if the file is empty
            metadata_writer.writerow(
                [
                    "read_id",
                    "parent_read_id",
                    "pod5_file",
                    "read_type",
                    "roi_fasta",
                    "roi_start",
                    "roi_end",
                    "fasta_length",
                    "fasta",
                ]
            )  # Replace with your actual header

        # Save the metadata file path and writer to worker_state
        worker_state["metadata_file"] = metadata_file
        worker_state["metadata_writer"] = metadata_writer

    def exit_func(self, worker_id: int, worker_state: Dict[str, Any]) -> None:
        """Closes the database connection and the CSV files."""
        conn = worker_state.get("db_connection")
        if conn:
            conn.close()

        # Close the plots csv file
        csvfile = worker_state.get("csvfile")
        if self.plots_csv_filepath and csvfile:
            csvfile.close()

        # Close the data file
        worker_state["data_file"].close()

        # Close the metadata file
        worker_state["metadata_file"].close()

    def merge_data(self) -> Tuple[str, str]:
        """Merges the data and metadata CSV files."""
        data_path = self._merge_csv_files(data_or_metadata="data")
        metadata_path = self._merge_csv_files(data_or_metadata="metadata")
        return data_path, metadata_path

    def _merge_csv_files(self, data_or_metadata: str) -> str:
        """Merges the data and metadata CSV files.

        Args:
            data_or_metadata (str): Whether to merge data or metadata CSV files.

        Returns:
            str: Path to the merged CSV file.
        """
        cap_name = map_cap_int_to_name(self.cap_class)
        data_path = os.path.join(self.output_dir, f"{data_or_metadata}__{cap_name}.csv")
        # delete if the file already exists
        if os.path.exists(data_path):
            logger.info(f"Overwriting existing {data_or_metadata} CSV file.")
            os.remove(data_path)
        with open(data_path, "w", newline="") as output_csv:
            writer = csv.writer(output_csv)
            for i in range(self.num_processes):
                ind_csv_file = os.path.join(
                    self.output_dir, f"{data_or_metadata}_tmp_{i}.csv"
                )
                # Open each CSV file and read its contents
                with open(ind_csv_file) as input_csv:
                    reader = csv.reader(input_csv)

                    # If it's the first file, write the header to the output file
                    if i == 0:
                        header = next(reader)
                        writer.writerow(header)
                    else:
                        next(reader)

                    # Write the remaining rows to the output file
                    for row in reader:
                        writer.writerow(row)
                os.remove(ind_csv_file)
        logger.info(f"Successfully merged {data_or_metadata} CSV file.")
        return data_path

__init__(cap_class: int, num_processes: int, database_path: str, plots_csv_filepath: Union[str, None], output_dir: str) -> None

Initializes the index database handler

Source code in src/capfinder/collate.py
def __init__(
    self,
    cap_class: int,
    num_processes: int,
    database_path: str,
    plots_csv_filepath: Union[str, None],
    output_dir: str,
) -> None:
    """Initializes the index database handler"""
    self.cap_class = cap_class
    self.database_path = database_path
    self.plots_csv_filepath = plots_csv_filepath
    self.num_processes = num_processes
    self.output_dir = output_dir

    # Open the plots CSV file in append mode
    if self.plots_csv_filepath:
        self.csvfile = open(self.plots_csv_filepath, "a", newline="")

exit_func(worker_id: int, worker_state: Dict[str, Any]) -> None

Closes the database connection and the CSV files.

Source code in src/capfinder/collate.py
def exit_func(self, worker_id: int, worker_state: Dict[str, Any]) -> None:
    """Closes the database connection and the CSV files."""
    conn = worker_state.get("db_connection")
    if conn:
        conn.close()

    # Close the plots csv file
    csvfile = worker_state.get("csvfile")
    if self.plots_csv_filepath and csvfile:
        csvfile.close()

    # Close the data file
    worker_state["data_file"].close()

    # Close the metadata file
    worker_state["metadata_file"].close()

init_func(worker_id: int, worker_state: Dict[str, Any]) -> None

Opens the database connection and CSV files

Source code in src/capfinder/collate.py
def init_func(self, worker_id: int, worker_state: Dict[str, Any]) -> None:
    """Opens the database connection and CSV files"""

    # 1. Open the database connection and cursor
    worker_state["db_connection"], worker_state["db_cursor"] = open_database(
        self.database_path
    )

    # 2. Write the header row to the plots CSV file
    if self.plots_csv_filepath:
        csv_writer = csv.writer(self.csvfile)
        csv_writer.writerow(["read_id", "plot_filepath"])
        worker_state["csv_writer"] = csv_writer
        worker_state["csvfile"] = self.csvfile  # Store csvfile in worker_state

    # Define paths to data and metadata CSV files
    data_file_path = os.path.join(self.output_dir, f"data_tmp_{worker_id}.csv")
    metadata_file_path = os.path.join(
        self.output_dir, f"metadata_tmp_{worker_id}.csv"
    )

    # 3. Open data_file_path in append mode and write the header row if the file is empty
    data_file = open(data_file_path, "a", newline="")
    data_writer = csv.writer(data_file)
    if data_file.tell() == 0:  # Check if the file is empty
        data_writer.writerow(
            ["read_id", "cap_class", "timeseries"]
        )  # Replace with your actual header

    # Save the data file path and writer to worker_state
    worker_state["data_file"] = data_file
    worker_state["data_writer"] = data_writer

    # 4. Open metadata_file_path in append mode and write the header row if the file is empty
    metadata_file = open(metadata_file_path, "a", newline="")
    metadata_writer = csv.writer(metadata_file)
    if metadata_file.tell() == 0:  # Check if the file is empty
        metadata_writer.writerow(
            [
                "read_id",
                "parent_read_id",
                "pod5_file",
                "read_type",
                "roi_fasta",
                "roi_start",
                "roi_end",
                "fasta_length",
                "fasta",
            ]
        )  # Replace with your actual header

    # Save the metadata file path and writer to worker_state
    worker_state["metadata_file"] = metadata_file
    worker_state["metadata_writer"] = metadata_writer

merge_data() -> Tuple[str, str]

Merges the data and metadata CSV files.

Source code in src/capfinder/collate.py
def merge_data(self) -> Tuple[str, str]:
    """Merges the data and metadata CSV files."""
    data_path = self._merge_csv_files(data_or_metadata="data")
    metadata_path = self._merge_csv_files(data_or_metadata="metadata")
    return data_path, metadata_path
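
Example: DatabaseHandler is designed to be wired into an mpire WorkerPool. init_func and exit_func serve as the pool's worker_init/worker_exit hooks, so each worker gets its own database connection and per-worker CSV writers, and merge_data is called once at the end to combine the per-worker files. The sketch below only illustrates that wiring, with placeholder values and an empty job list; the complete, real invocation appears in collate_bam_pod5 further down.

from mpire import WorkerPool

from capfinder.collate import DatabaseHandler, collate_bam_pod5_worker

num_processes = 4  # placeholder
db_handler = DatabaseHandler(
    cap_class=0,                         # placeholder cap class
    num_processes=num_processes,
    database_path="output/database.db",  # placeholder path
    plots_csv_filepath=None,             # no plot bookkeeping in this sketch
    output_dir="output",
)

job_args = []  # in practice: pickled BAM records zipped with per-read arguments

with WorkerPool(
    n_jobs=num_processes, use_worker_state=True, pass_worker_id=True
) as pool:
    for _ in pool.imap_unordered(
        collate_bam_pod5_worker,
        job_args,
        worker_init=db_handler.init_func,
        worker_exit=db_handler.exit_func,
    ):
        pass

# Finally, merge the per-worker CSV files into one data and one metadata file.
data_csv, metadata_csv = db_handler.merge_data()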

FASTQRecord dataclass

Simulates a FASTQ record object.

Attributes:

Name Type Description
id str

Read ID.

seq str

Read sequence.

Example

record = FASTQRecord(id="read1", seq="ATCG")

Source code in src/capfinder/collate.py
@dataclass
class FASTQRecord:
    """
    Simulates a FASTQ record object.

    Attributes:
        id (str): Read ID.
        seq (str): Read sequence.

    Example:
        >>> record = FASTQRecord(id="read1", seq="ATCG")
    """

    id: str
    seq: str

collate_bam_pod5(bam_filepath: str, pod5_dir: str, num_processes: int, reference: str, cap_class: int, cap0_pos: int, train_or_test: str, plot_signal: bool, output_dir: str) -> Tuple[str, str]

Collates information from the BAM file and the POD5 files, aligns the OTE to the read, and extracts the signal for the region of interest (ROI) for training or testing purposes. It also plots the ROI signal if requested.

Parameters:

Name Type Description Default
bam_filepath str

Path to the BAM file.

required
pod5_dir str

Path to the directory containing the POD5 files.

required
num_processes int

Number of processes to use for parallel processing.

required
reference str

Reference sequence.

required
cap_class int

Class label for the RNA cap.

required
cap0_pos int

Position of the cap N1 base in the reference sequence (0-based).

required
train_or_test str

Whether to extract ROI for training or testing.

required
plot_signal bool

Whether to plot the ROI signal.

required
output_dir str

Path to the output directory.

required

Returns:

Type Description
Tuple[str, str]

Tuple[str, str]: Paths to the data and metadata CSV files.

Source code in src/capfinder/collate.py
def collate_bam_pod5(
    bam_filepath: str,
    pod5_dir: str,
    num_processes: int,
    reference: str,
    cap_class: int,
    cap0_pos: int,
    train_or_test: str,
    plot_signal: bool,
    output_dir: str,
) -> Tuple[str, str]:
    """
    Collates information from the BAM file and the POD5 files,
    aligns the OTE to the read, and extracts the signal for the
    region of interest (ROI) for training or testing purposes.
    It also plots the ROI signal if requested.

    Args:
        bam_filepath (str): Path to the BAM file.
        pod5_dir (str): Path to the directory containing the POD5 files.
        num_processes (int): Number of processes to use for parallel processing.
        reference (str): Reference sequence.
        cap_class (int): Class label for the RNA cap.
        cap0_pos (int): Position of the cap N1 base in the reference sequence (0-based).
        train_or_test (str): Whether to extract ROI for training or testing.
        plot_signal (bool): Whether to plot the ROI signal.
        output_dir (str): Path to the output directory.

    Returns:
        Tuple[str, str]: Paths to the data and metadata CSV files.
    """
    os.makedirs(output_dir, exist_ok=True)

    logger.info("Computing BAM total records...")
    num_bam_records = get_total_records(bam_filepath)
    logger.info(f"Found {num_bam_records} BAM records!")

    # Make index database
    database_path = os.path.join(output_dir, "database.db")
    index(pod5_dir, output_dir)

    # If plots are requested, create the CSV file and the directories
    plots_csv_filepath = None
    if plot_signal:
        good_reads_plots_dir = os.path.join(output_dir, "plots", "good_reads", "0")
        bad_reads_plots_dir = os.path.join(output_dir, "plots", "bad_reads", "0")
        os.makedirs(good_reads_plots_dir, exist_ok=True)
        os.makedirs(bad_reads_plots_dir, exist_ok=True)
        plots_csv_filepath = os.path.join(output_dir, "plots", "plotpaths.csv")

    # Initialize the database handler
    db_handler = DatabaseHandler(
        cap_class, num_processes, database_path, plots_csv_filepath, output_dir
    )

    try:
        logger.info("Processing BAM file using multiple processes...")
        with WorkerPool(
            n_jobs=num_processes, use_worker_state=True, pass_worker_id=True
        ) as pool:
            iterator = zip(
                generate_pickled_bam_records(bam_filepath),
                repeat(reference),
                repeat(cap_class),
                repeat(cap0_pos),
                repeat(train_or_test),
                repeat(plot_signal),
                repeat(output_dir),
            )
            for _ in pool.imap_unordered(
                collate_bam_pod5_worker,
                iterator,
                worker_init=db_handler.init_func,
                worker_exit=db_handler.exit_func,
                progress_bar=True,
                iterable_len=num_bam_records,
            ):
                pass  # We don't need to do anything with the results

    except Exception as e:
        logger.error(f"An error occurred: {e}")
    finally:
        if plots_csv_filepath:
            csvfile = db_handler.csvfile
            if csvfile:
                csvfile.close()

        # Merge the data and metadata CSV files
        data_path, metadata_path = db_handler.merge_data()
        logger.info("Cap signal data extracted successfully!")
    return data_path, metadata_path
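
Example: calling collate_bam_pod5 directly for a training run. The file paths and the cap class are placeholders; the reference sequence and cap0_pos mirror the values used by the predict-cap-types command shown earlier.

from capfinder.collate import collate_bam_pod5

data_csv, metadata_csv = collate_bam_pod5(
    bam_filepath="/path/to/sorted.bam",  # placeholder
    pod5_dir="/path/to/pod5_dir",        # placeholder
    num_processes=8,
    reference="GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT",
    cap_class=0,                         # placeholder class label
    cap0_pos=52,
    train_or_test="train",
    plot_signal=False,
    output_dir="/path/to/output_dir",    # placeholder
)
print(data_csv, metadata_csv)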

collate_bam_pod5_worker(worker_id: int, worker_state: Dict[str, Any], pickled_bam_data: bytes, reference: str, cap_class: int, cap0_pos: int, train_or_test: str, plot_signal: bool, output_dir: str) -> None

Worker function that collates information from the POD5 and BAM files, finds the FASTA coordinates of the region of interest (ROI), and extracts its signal.

Parameters:

Name Type Description Default
worker_id int

int Worker ID.

required
worker_state Dict[str, Any]

dict Dictionary containing the database connection and cursor.

required
pickled_bam_data bytes

bytes Pickled dictionary containing the BAM record information.

required
reference str

str Reference sequence.

required
cap_class int

int Class label for the RNA cap

required
cap0_pos int

int Position of the cap0 base in the reference sequence.

required
train_or_test str

str Whether to extract ROI for training or testing.

required
plot_signal bool

bool Whether to plot the ROI signal.

required
output_dir str

str Path to the output directory.

required

Returns:

Type Description
None

None

Source code in src/capfinder/collate.py
def collate_bam_pod5_worker(
    worker_id: int,
    worker_state: Dict[str, Any],
    pickled_bam_data: bytes,
    reference: str,
    cap_class: int,
    cap0_pos: int,
    train_or_test: str,
    plot_signal: bool,
    output_dir: str,
) -> None:
    """Worker function that collates information from POD5 and BAM file, finds the
    FASTA coordinates of  region of interest (ROI) and and extracts its signal.

    Params:
        worker_id: int
            Worker ID.
        worker_state: dict
            Dictionary containing the database connection and cursor.
        pickled_bam_data: bytes
            Pickled dictionary containing the BAM record information.
        reference: str
            Reference sequence.
        cap_class: int
            Class label for the RNA cap
        cap0_pos: int
            Position of the cap0 base in the reference sequence.
        train_or_test: str
            Whether to extract ROI for training or testing.
        plot_signal: bool
            Whether to plot the ROI signal.
        output_dir: str
            Path to the output directory.

    Returns:
        None
    """
    # 1. Get read info from bam record
    bam_data = pickle.loads(pickled_bam_data)
    read_id = bam_data["read_id"]
    pod5_filename = bam_data["pod5_filename"]
    parent_read_id = bam_data["parent_read_id"]
    # 2. Find the pod5 filepath corresponding to the pod5_filename in the database
    pod5_filepath = fetch_filepath_using_filename(
        worker_state["db_connection"], worker_state["db_cursor"], pod5_filename
    )

    # 3. Pull the read data from the multi-pod5 file
    # If the read is a split read, pull the parent read data
    if parent_read_id == "":
        pod5_data = pull_read_from_pod5(read_id, pod5_filepath)
    else:
        pod5_data = pull_read_from_pod5(parent_read_id, pod5_filepath)

    # 4. Extract the locations of each new base in signal coordinates
    base_locs_in_signal = find_base_locs_in_signal(bam_data)

    # 5. Get alignment of OTE with the read
    # Simulate a FASTQ record object
    read_fasta = bam_data["read_fasta"]

    # Check that the read is not empty
    if read_fasta is None:
        logger.warning(f"Read {read_id} has empty FASTA. Skipping the read.")
        return None

    fastq_record = FASTQRecord(read_id, read_fasta)
    if train_or_test.lower() == "train":
        aln_res = extract_roi_coords_train(
            record=fastq_record, reference=reference, cap0_pos=cap0_pos
        )
    elif train_or_test.lower() == "test":
        aln_res = extract_roi_coords_test(
            record=fastq_record, reference=reference, cap0_pos=cap0_pos
        )
    else:
        logger.warning(
            "Invalid train_or_test argument. Must be either 'train' or 'test'."
        )
        return None

    # 6. Extract signal data for the ROI
    start_base_idx_in_fasta = aln_res["left_flanking_region_start_fastq_pos"]
    end_base_idx_in_fasta = aln_res["right_flanking_region_start_fastq_pos"]

    roi_data = extract_roi_signal(
        signal=pod5_data["signal_pa"],
        base_locs_in_signal=base_locs_in_signal,
        fasta=read_fasta,
        experiment_type=pod5_data["experiment_type"],
        start_base_idx_in_fasta=start_base_idx_in_fasta,
        end_base_idx_in_fasta=end_base_idx_in_fasta,
        num_left_clipped_bases=bam_data["num_left_clipped_bases"],
    )

    # 7. Add additional information to the ROI data
    roi_data["start_base_idx_in_fasta"] = start_base_idx_in_fasta
    roi_data["end_base_idx_in_fasta"] = end_base_idx_in_fasta
    roi_data["read_id"] = read_id

    # 8. Find if a read is good or bad
    read_type = (
        "bad_reads"
        if start_base_idx_in_fasta is None and end_base_idx_in_fasta is None
        else "good_reads"
    )

    # 9. Save the train/test and metadata information
    # We need to store train/test data only for the good reads
    precision = 8

    # Define a vectorized function for formatting (if applicable)
    def format_value(x: float) -> str:
        return f"{x:.{precision}f}"

    vectorized_formatter = np.vectorize(format_value)

    if read_type == "good_reads":
        roi_signal: np.ndarray = roi_data["roi_signal"]
        if roi_signal.size == 0:
            read_type = "bad_reads"
        else:
            timeseries_str = ",".join(vectorized_formatter(roi_data["roi_signal"]))
            worker_state["data_writer"].writerow([read_id, cap_class, timeseries_str])

    # We need to store metadata for all reads (good and bad)
    if read_fasta is not None:
        read_length = len(read_fasta)
    else:
        read_length = 0

    worker_state["metadata_writer"].writerow(
        [
            read_id,
            parent_read_id,
            pod5_filepath,
            read_type.rstrip("s"),
            roi_data["roi_fasta"],
            roi_data["start_base_idx_in_fasta"],
            roi_data["end_base_idx_in_fasta"],
            read_length,
            read_fasta,
        ]
    )

    # 10. Plot the ROI signal if requested
    # Save plot in directories of 100 plots each separated into
    # good and bad categories. Good reads mean those that have
    # the OTE in them and bad reads mean those that do not.
    if plot_signal:
        count_key = f"{read_type}_count"
        dir_key = f"{read_type}_dir"
        with lock:
            shared_dict[count_key] = shared_dict.get(count_key, 0) + 1
            if shared_dict[count_key] > 100:
                worker_state[
                    "csvfile"
                ].flush()  # write the rows in the buffer to the csv file
                shared_dict[dir_key] = shared_dict.get(dir_key, 0) + 1
                shared_dict[count_key] = 1
                os.makedirs(
                    os.path.join(
                        output_dir, "plots", read_type, str(shared_dict[dir_key])
                    ),
                    exist_ok=True,
                )
            # Get the current timestamp
            # We append the timestamp to the name of the plot file
            # so that we can handle multiple plots for the same read
            # due to multiple alignments (secondary/supp.) in SAM files
            timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            plot_filepath = os.path.join(
                output_dir,
                "plots",
                read_type,
                str(shared_dict[dir_key]),
                f"{read_id}__{timestamp}.html",
            )
            # Write the data for this plot
            worker_state["csv_writer"].writerow([read_id, plot_filepath])
        # Suppress the output of align() function
        with contextlib.redirect_stdout(None):
            _, _, chunked_aln_str, alignment_score = align(
                query_seq=read_fasta, target_seq=reference, pretty_print_alns=True
            )
        plot_roi_signal(
            pod5_data,
            bam_data,
            roi_data,
            start_base_idx_in_fasta,
            end_base_idx_in_fasta,
            plot_filepath,
            chunked_aln_str,
            alignment_score,
        )

    return None

collate_bam_pod5_wrapper(bam_filepath: str, pod5_dir: str, num_processes: int, reference: str, cap_class: int, cap0_pos: int, train_or_test: str, plot_signal: bool, output_dir: str, debug_code: bool, formatted_command: Optional[str]) -> None

Wrapper function for collate_bam_pod5 that sets up logging and handles output.

Parameters:

Name Type Description Default
bam_filepath str

Path to the BAM file.

required
pod5_dir str

Path to the directory containing the POD5 files.

required
num_processes int

Number of processes to use for parallel processing.

required
reference str

Reference sequence.

required
cap_class int

Class label for the RNA cap.

required
cap0_pos int

Position of the cap N1 base in the reference sequence (0-based).

required
train_or_test str

Whether to extract ROI for training or testing.

required
plot_signal bool

Whether to plot the ROI signal.

required
output_dir str

Path to the output directory.

required
debug_code bool

Whether to show debug information in logs.

required
formatted_command Optional[str]

Formatted command string for logging.

required
Source code in src/capfinder/collate.py
def collate_bam_pod5_wrapper(
    bam_filepath: str,
    pod5_dir: str,
    num_processes: int,
    reference: str,
    cap_class: int,
    cap0_pos: int,
    train_or_test: str,
    plot_signal: bool,
    output_dir: str,
    debug_code: bool,
    formatted_command: Optional[str],
) -> None:
    """
    Wrapper function for collate_bam_pod5 that sets up logging and handles output.

    Args:
        bam_filepath (str): Path to the BAM file.
        pod5_dir (str): Path to the directory containing the POD5 files.
        num_processes (int): Number of processes to use for parallel processing.
        reference (str): Reference sequence.
        cap_class (int): Class label for the RNA cap.
        cap0_pos (int): Position of the cap N1 base in the reference sequence (0-based).
        train_or_test (str): Whether to extract ROI for training or testing.
        plot_signal (bool): Whether to plot the ROI signal.
        output_dir (str): Path to the output directory.
        debug_code (bool): Whether to show debug information in logs.
        formatted_command (Optional[str]): Formatted command string for logging.
    """
    log_filepath = configure_logger(
        os.path.join(output_dir, "logs"), show_location=debug_code
    )
    configure_prefect_logging(show_location=debug_code)
    version_info = version("capfinder")
    log_header(f"Using Capfinder v{version_info}")
    logger.info(formatted_command)

    data_path, metadata_path = collate_bam_pod5(
        bam_filepath,
        pod5_dir,
        num_processes,
        reference,
        cap_class,
        cap0_pos,
        train_or_test,
        plot_signal,
        output_dir,
    )
    grey = "\033[90m"
    reset = "\033[0m"
    log_output(
        f"Cap data has been saved to the following path:\n {grey}{data_path}{reset}\nCap metadata have been saved to the following path:\n {grey}{metadata_path}{reset}\nThe log file has been saved to:\n {grey}{log_filepath}{reset}"
    )
    log_header("Processing finished!")

generate_pickled_bam_records(bam_filepath: str) -> Generator[bytes, None, None]

Generate pickled BAM records from a BAM file.

Parameters:

Name Type Description Default
bam_filepath str

Path to the BAM file.

required

Yields:

Name Type Description
bytes bytes

Pickled BAM record.

Source code in src/capfinder/collate.py
def generate_pickled_bam_records(bam_filepath: str) -> Generator[bytes, None, None]:
    """
    Generate pickled BAM records from a BAM file.

    Args:
        bam_filepath (str): Path to the BAM file.

    Yields:
        bytes: Pickled BAM record.
    """
    with pysam.AlignmentFile(bam_filepath, "rb") as bam_file:
        for record in bam_file:
            yield pickle.dumps(get_signal_info(record))
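
Example: each yielded item is a pickled dictionary produced by get_signal_info; consumers unpickle it to access per-read fields (the worker above uses keys such as read_id, pod5_filename, parent_read_id, and read_fasta). A minimal sketch with a placeholder BAM path:

import pickle

from capfinder.collate import generate_pickled_bam_records

for pickled_record in generate_pickled_bam_records("/path/to/sorted.bam"):
    bam_data = pickle.loads(pickled_record)
    print(bam_data["read_id"])
    break  # inspect only the first record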

constants

The module contains constants used in the capfinder package.

Author: Adnan M. Niazi Date: 2024-02-28

cyclic_learing_rate

CometLRLogger

Bases: Callback

A callback to log the learning rate to Comet.ml during training.

This callback logs the learning rate at the beginning of each epoch and at the end of each batch to a Comet.ml experiment.

Attributes:

Name Type Description
experiment Experiment

The Comet.ml experiment to log to.

Source code in src/capfinder/cyclic_learing_rate.py
class CometLRLogger(Callback):
    """
    A callback to log the learning rate to Comet.ml during training.

    This callback logs the learning rate at the beginning of each epoch
    and at the end of each batch to a Comet.ml experiment.

    Attributes:
        experiment (Experiment): The Comet.ml experiment to log to.
    """

    def __init__(self, experiment: Experiment) -> None:
        """
        Initialize the CometLRLogger.

        Args:
            experiment (Experiment): The Comet.ml experiment to log to.
        """
        super().__init__()
        self.experiment: Experiment = experiment

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict[str, Any]] = None) -> None:
        """
        Log the learning rate at the beginning of each epoch.

        Args:
            epoch (int): The current epoch number.
            logs (Optional[Dict[str, Any]]): The logs dictionary.
        """
        lr: Union[float, np.ndarray] = self.model.optimizer.learning_rate
        if hasattr(lr, "numpy"):
            lr = lr.numpy()
        self.experiment.log_metric("learning_rate", lr, step=epoch)

    def on_batch_end(self, batch: int, logs: Optional[Dict[str, Any]] = None) -> None:
        """
        Log the learning rate at the end of each batch.

        Args:
            batch (int): The current batch number.
            logs (Optional[Dict[str, Any]]): The logs dictionary.
        """
        lr: Union[float, np.ndarray] = self.model.optimizer.learning_rate
        if hasattr(lr, "numpy"):
            lr = lr.numpy()
        self.experiment.log_metric(
            "learning_rate", lr, step=self.model.optimizer.iterations.numpy()
        )

__init__(experiment: Experiment) -> None

Initialize the CometLRLogger.

Parameters:

Name Type Description Default
experiment Experiment

The Comet.ml experiment to log to.

required
Source code in src/capfinder/cyclic_learing_rate.py
def __init__(self, experiment: Experiment) -> None:
    """
    Initialize the CometLRLogger.

    Args:
        experiment (Experiment): The Comet.ml experiment to log to.
    """
    super().__init__()
    self.experiment: Experiment = experiment

on_batch_end(batch: int, logs: Optional[Dict[str, Any]] = None) -> None

Log the learning rate at the end of each batch.

Parameters:

Name Type Description Default
batch int

The current batch number.

required
logs Optional[Dict[str, Any]]

The logs dictionary.

None
Source code in src/capfinder/cyclic_learing_rate.py
def on_batch_end(self, batch: int, logs: Optional[Dict[str, Any]] = None) -> None:
    """
    Log the learning rate at the end of each batch.

    Args:
        batch (int): The current batch number.
        logs (Optional[Dict[str, Any]]): The logs dictionary.
    """
    lr: Union[float, np.ndarray] = self.model.optimizer.learning_rate
    if hasattr(lr, "numpy"):
        lr = lr.numpy()
    self.experiment.log_metric(
        "learning_rate", lr, step=self.model.optimizer.iterations.numpy()
    )

on_epoch_begin(epoch: int, logs: Optional[Dict[str, Any]] = None) -> None

Log the learning rate at the beginning of each epoch.

Parameters:

Name Type Description Default
epoch int

The current epoch number.

required
logs Optional[Dict[str, Any]]

The logs dictionary.

None
Source code in src/capfinder/cyclic_learing_rate.py
def on_epoch_begin(self, epoch: int, logs: Optional[Dict[str, Any]] = None) -> None:
    """
    Log the learning rate at the beginning of each epoch.

    Args:
        epoch (int): The current epoch number.
        logs (Optional[Dict[str, Any]]): The logs dictionary.
    """
    lr: Union[float, np.ndarray] = self.model.optimizer.learning_rate
    if hasattr(lr, "numpy"):
        lr = lr.numpy()
    self.experiment.log_metric("learning_rate", lr, step=epoch)
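
Example: a minimal construction sketch, assuming a Comet.ml account whose API key is available through the usual Comet configuration (the project name is a placeholder). The logger is then passed in the callbacks list of model.fit; see the CyclicLR example below for a complete fit call.

from comet_ml import Experiment

from capfinder.cyclic_learing_rate import CometLRLogger

experiment = Experiment(project_name="capfinder-demo")  # placeholder project
lr_logger = CometLRLogger(experiment)
# model.fit(..., callbacks=[lr_logger])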

CustomProgressCallback

Bases: Callback

A custom callback to print the learning rate at the end of each epoch.

This callback prints the current learning rate after Keras' built-in progress bar for each epoch.

Source code in src/capfinder/cyclic_learing_rate.py
class CustomProgressCallback(keras.callbacks.Callback):
    """
    A custom callback to print the learning rate at the end of each epoch.

    This callback prints the current learning rate after Keras' built-in
    progress bar for each epoch.
    """

    def __init__(self) -> None:
        """Initialize the CustomProgressCallback."""
        super().__init__()

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, Any]] = None) -> None:
        """
        Print the learning rate at the end of each epoch.

        Args:
            epoch (int): The current epoch number.
            logs (Optional[Dict[str, Any]]): The logs dictionary.
        """
        lr: Union[float, np.ndarray] = self.model.optimizer.learning_rate
        if hasattr(lr, "numpy"):
            lr = lr.numpy()
        print(f"\nLearning rate: {lr:.6f}")

__init__() -> None

Initialize the CustomProgressCallback.

Source code in src/capfinder/cyclic_learing_rate.py
def __init__(self) -> None:
    """Initialize the CustomProgressCallback."""
    super().__init__()

on_epoch_end(epoch: int, logs: Optional[Dict[str, Any]] = None) -> None

Print the learning rate at the end of each epoch.

Parameters:

Name Type Description Default
epoch int

The current epoch number.

required
logs Optional[Dict[str, Any]]

The logs dictionary.

None
Source code in src/capfinder/cyclic_learing_rate.py
def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, Any]] = None) -> None:
    """
    Print the learning rate at the end of each epoch.

    Args:
        epoch (int): The current epoch number.
        logs (Optional[Dict[str, Any]]): The logs dictionary.
    """
    lr: Union[float, np.ndarray] = self.model.optimizer.learning_rate
    if hasattr(lr, "numpy"):
        lr = lr.numpy()
    print(f"\nLearning rate: {lr:.6f}")

CyclicLR

Bases: Callback

This callback implements a cyclical learning rate policy (CLR). The method cycles the learning rate between two boundaries with some constant frequency.

Arguments

base_lr: initial learning rate which is the
    lower boundary in the cycle.
max_lr: upper boundary in the cycle. Functionally,
    it defines the cycle amplitude (max_lr - base_lr).
    The lr at any cycle is the sum of base_lr
    and some scaling of the amplitude; therefore
    max_lr may not actually be reached depending on
    scaling function.
step_size: number of training iterations per
    half cycle. Authors suggest setting step_size
    2-8 x training iterations in epoch.
mode: one of {triangular, triangular2, exp_range}.
    Default 'triangular'.
    Values correspond to policies detailed above.
    If scale_fn is not None, this argument is ignored.
gamma: constant in 'exp_range' scaling function:
    gamma**(cycle iterations)
scale_fn: Custom scaling policy defined by a single
    argument lambda function, where
    0 <= scale_fn(x) <= 1 for all x >= 0.
    mode parameter is ignored
scale_mode: {'cycle', 'iterations'}.
    Defines whether scale_fn is evaluated on
    cycle number or cycle iterations (training
    iterations since start of cycle). Default is 'cycle'.
Source code in src/capfinder/cyclic_learing_rate.py
class CyclicLR(Callback):
    """
    This callback implements a cyclical learning rate policy (CLR).
    The method cycles the learning rate between two boundaries with
    some constant frequency.

    # Arguments
        base_lr: initial learning rate which is the
            lower boundary in the cycle.
        max_lr: upper boundary in the cycle. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            scaling function.
        step_size: number of training iterations per
            half cycle. Authors suggest setting step_size
            2-8 x training iterations in epoch.
        mode: one of {triangular, triangular2, exp_range}.
            Default 'triangular'.
            Values correspond to policies detailed above.
            If scale_fn is not None, this argument is ignored.
        gamma: constant in 'exp_range' scaling function:
            gamma**(cycle iterations)
        scale_fn: Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            mode parameter is ignored
        scale_mode: {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle). Default is 'cycle'.
    """

    def __init__(
        self,
        base_lr: float = 0.001,
        max_lr: float = 0.006,
        step_size: float = 2000.0,
        mode: str = "triangular",
        gamma: float = 1.0,
        scale_fn: Optional[Callable[[float], float]] = None,
        scale_mode: str = "cycle",
    ) -> None:
        super().__init__()

        self.base_lr: float = base_lr
        self.max_lr: float = max_lr
        self.step_size: float = step_size
        self.mode: str = mode
        self.gamma: float = gamma

        if scale_fn is None:
            if self.mode == "triangular":
                self.scale_fn = lambda x: 1.0
                self.scale_mode = "cycle"
            elif self.mode == "triangular2":
                self.scale_fn = lambda x: 1 / (2.0 ** (x - 1))
                self.scale_mode = "cycle"
            elif self.mode == "exp_range":
                self.scale_fn = lambda x: gamma**x
                self.scale_mode = "iterations"
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode

        self.clr_iterations: float = 0.0
        self.trn_iterations: float = 0.0
        self.history: Dict[str, list] = {}

        self._reset()

    def _reset(
        self,
        new_base_lr: Optional[float] = None,
        new_max_lr: Optional[float] = None,
        new_step_size: Optional[float] = None,
    ) -> None:
        """Resets cycle iterations.
        Optional boundary/step size adjustment.
        """
        if new_base_lr is not None:
            self.base_lr = new_base_lr
        if new_max_lr is not None:
            self.max_lr = new_max_lr
        if new_step_size is not None:
            self.step_size = new_step_size
        self.clr_iterations = 0.0

    def clr(self) -> Union[float, NDArray[np.float64]]:
        cycle: float = np.floor(1 + self.clr_iterations / (2 * self.step_size))
        x: float = np.abs(self.clr_iterations / self.step_size - 2 * cycle + 1)
        clr_value: float = (
            self.base_lr
            + (self.max_lr - self.base_lr)
            * np.maximum(0, (1 - x))
            * self.scale_fn(cycle)
            if self.scale_mode == "cycle"
            else self.scale_fn(self.clr_iterations)
        )
        return (
            float(clr_value)
            if isinstance(clr_value, float)
            else np.array(clr_value, dtype=np.float64)
        )

    def on_train_begin(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Initialize the learning rate to the base learning rate."""
        logs = logs or {}

        if self.clr_iterations == 0:
            self.model.optimizer.learning_rate.assign(self.base_lr)
        else:
            self.model.optimizer.learning_rate.assign(self.clr())

    def on_batch_end(self, batch: int, logs: Optional[Dict[str, Any]] = None) -> None:
        """Record previous batch statistics and update the learning rate."""
        logs = logs or {}
        self.trn_iterations += 1
        self.clr_iterations += 1

        self.history.setdefault("lr", []).append(
            self.model.optimizer.learning_rate.numpy()
        )
        self.history.setdefault("iterations", []).append(self.trn_iterations)

        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.model.optimizer.learning_rate.assign(self.clr())

on_batch_end(batch: int, logs: Optional[Dict[str, Any]] = None) -> None

Record previous batch statistics and update the learning rate.

Source code in src/capfinder/cyclic_learing_rate.py
def on_batch_end(self, batch: int, logs: Optional[Dict[str, Any]] = None) -> None:
    """Record previous batch statistics and update the learning rate."""
    logs = logs or {}
    self.trn_iterations += 1
    self.clr_iterations += 1

    self.history.setdefault("lr", []).append(
        self.model.optimizer.learning_rate.numpy()
    )
    self.history.setdefault("iterations", []).append(self.trn_iterations)

    for k, v in logs.items():
        self.history.setdefault(k, []).append(v)

    self.model.optimizer.learning_rate.assign(self.clr())

on_train_begin(logs: Optional[Dict[str, Any]] = None) -> None

Initialize the learning rate to the base learning rate.

Source code in src/capfinder/cyclic_learing_rate.py
def on_train_begin(self, logs: Optional[Dict[str, Any]] = None) -> None:
    """Initialize the learning rate to the base learning rate."""
    logs = logs or {}

    if self.clr_iterations == 0:
        self.model.optimizer.learning_rate.assign(self.base_lr)
    else:
        self.model.optimizer.learning_rate.assign(self.clr())
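
Example: an end-to-end sketch with a toy Keras model and random data (not capfinder data), showing where CyclicLR and CustomProgressCallback plug into model.fit. The step_size follows the docstring's suggestion of 2-8x the training iterations per epoch; all other values are placeholders.

import keras
import numpy as np

from capfinder.cyclic_learing_rate import CustomProgressCallback, CyclicLR

# Toy model and data, used only to show where the callbacks plug in.
model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(256, 8)
y = np.random.rand(256, 1)

batch_size = 32
steps_per_epoch = len(x) // batch_size  # 8 mini-batches per epoch

clr = CyclicLR(
    base_lr=1e-4,
    max_lr=6e-3,
    step_size=4 * steps_per_epoch,  # 2-8x steps_per_epoch, per the docstring
    mode="triangular2",
)

model.fit(
    x,
    y,
    batch_size=batch_size,
    epochs=4,
    callbacks=[clr, CustomProgressCallback()],
)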

SGDRScheduler

Bases: Callback

Cosine annealing learning rate scheduler with periodic restarts.

Parameters:

Name Type Description Default
min_lr float

The lower bound of the learning rate range for the experiment.

required
max_lr float

The upper bound of the learning rate range for the experiment.

required
steps_per_epoch int

Number of mini-batches in the dataset.

required
lr_decay float

Reduce the max_lr after the completion of each cycle.

1.0
cycle_length int

Initial number of epochs in a cycle.

10
mult_factor float

Scale epochs_to_restart after each full cycle completion.

2.0
Source code in src/capfinder/cyclic_learing_rate.py
class SGDRScheduler(Callback):
    """
    Cosine annealing learning rate scheduler with periodic restarts.

    Args:
        min_lr: The lower bound of the learning rate range for the experiment.
        max_lr: The upper bound of the learning rate range for the experiment.
        steps_per_epoch: Number of mini-batches in the dataset.
        lr_decay: Reduce the max_lr after the completion of each cycle.
        cycle_length: Initial number of epochs in a cycle.
        mult_factor: Scale epochs_to_restart after each full cycle completion.
    """

    def __init__(
        self,
        min_lr: float,
        max_lr: float,
        steps_per_epoch: int,
        lr_decay: float = 1.0,
        cycle_length: int = 10,
        mult_factor: float = 2.0,
    ) -> None:
        super().__init__()
        self.min_lr: float = min_lr
        self.max_lr: float = max_lr
        self.lr_decay: float = lr_decay
        self.batch_since_restart: int = 0
        self.next_restart: int = cycle_length
        self.steps_per_epoch: int = steps_per_epoch
        self.cycle_length: float = cycle_length
        self.mult_factor: float = mult_factor
        self.history: Dict[str, list] = {}
        self.best_weights: Optional[list] = None

    def clr(self) -> float:
        """Calculate the learning rate."""
        fraction_to_restart: float = self.batch_since_restart / (
            self.steps_per_epoch * self.cycle_length
        )
        lr: float = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
            1 + np.cos(fraction_to_restart * np.pi)
        )
        return float(lr)

    def set_lr(self, lr: float) -> None:
        """Set the learning rate for the optimizer."""
        self.model.optimizer.learning_rate.assign(lr)

    def get_lr(self) -> float:
        """Get the current learning rate."""
        return float(self.model.optimizer.learning_rate.value)

    def on_train_begin(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Initialize the learning rate to the maximum value at the start of training."""
        logs = logs or {}
        self.set_lr(self.max_lr)

    def on_batch_end(self, batch: int, logs: Optional[Dict[str, Any]] = None) -> None:
        """Record previous batch statistics and update the learning rate."""
        logs = logs or {}
        self.history.setdefault("lr", []).append(self.get_lr())
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.batch_since_restart += 1
        new_lr: float = self.clr()
        self.set_lr(new_lr)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, Any]] = None) -> None:
        """Check for end of current cycle, apply restarts when necessary."""
        if epoch + 1 == self.next_restart:
            self.batch_since_restart = 0
            self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
            self.next_restart += int(self.cycle_length)
            self.max_lr *= self.lr_decay
            self.best_weights = self.model.get_weights()

    def on_train_end(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Set weights to the values from the end of the most recent cycle for best performance."""
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)

clr() -> float

Calculate the learning rate.

Source code in src/capfinder/cyclic_learing_rate.py
def clr(self) -> float:
    """Calculate the learning rate."""
    fraction_to_restart: float = self.batch_since_restart / (
        self.steps_per_epoch * self.cycle_length
    )
    lr: float = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
        1 + np.cos(fraction_to_restart * np.pi)
    )
    return float(lr)

get_lr() -> float

Get the current learning rate.

Source code in src/capfinder/cyclic_learing_rate.py
def get_lr(self) -> float:
    """Get the current learning rate."""
    return float(self.model.optimizer.learning_rate.value)

on_batch_end(batch: int, logs: Optional[Dict[str, Any]] = None) -> None

Record previous batch statistics and update the learning rate.

Source code in src/capfinder/cyclic_learing_rate.py
def on_batch_end(self, batch: int, logs: Optional[Dict[str, Any]] = None) -> None:
    """Record previous batch statistics and update the learning rate."""
    logs = logs or {}
    self.history.setdefault("lr", []).append(self.get_lr())
    for k, v in logs.items():
        self.history.setdefault(k, []).append(v)

    self.batch_since_restart += 1
    new_lr: float = self.clr()
    self.set_lr(new_lr)

on_epoch_end(epoch: int, logs: Optional[Dict[str, Any]] = None) -> None

Check for end of current cycle, apply restarts when necessary.

Source code in src/capfinder/cyclic_learing_rate.py
def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, Any]] = None) -> None:
    """Check for end of current cycle, apply restarts when necessary."""
    if epoch + 1 == self.next_restart:
        self.batch_since_restart = 0
        self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
        self.next_restart += int(self.cycle_length)
        self.max_lr *= self.lr_decay
        self.best_weights = self.model.get_weights()

on_train_begin(logs: Optional[Dict[str, Any]] = None) -> None

Initialize the learning rate to the maximum value at the start of training.

Source code in src/capfinder/cyclic_learing_rate.py
def on_train_begin(self, logs: Optional[Dict[str, Any]] = None) -> None:
    """Initialize the learning rate to the maximum value at the start of training."""
    logs = logs or {}
    self.set_lr(self.max_lr)

on_train_end(logs: Optional[Dict[str, Any]] = None) -> None

Set weights to the values from the end of the most recent cycle for best performance.

Source code in src/capfinder/cyclic_learing_rate.py
def on_train_end(self, logs: Optional[Dict[str, Any]] = None) -> None:
    """Set weights to the values from the end of the most recent cycle for best performance."""
    if self.best_weights is not None:
        self.model.set_weights(self.best_weights)

set_lr(lr: float) -> None

Set the learning rate for the optimizer.

Source code in src/capfinder/cyclic_learing_rate.py
def set_lr(self, lr: float) -> None:
    """Set the learning rate for the optimizer."""
    self.model.optimizer.learning_rate.assign(lr)
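
Example: a construction-only sketch; the sizes are placeholders used to derive steps_per_epoch. Pass the scheduler in the callbacks list of model.fit, exactly as in the CyclicLR example above.

from capfinder.cyclic_learing_rate import SGDRScheduler

n_examples, batch_size = 512, 64  # placeholder dataset and batch sizes

scheduler = SGDRScheduler(
    min_lr=1e-5,
    max_lr=1e-2,
    steps_per_epoch=n_examples // batch_size,  # mini-batches per epoch
    lr_decay=0.9,     # shrink max_lr by 10% after each restart
    cycle_length=5,   # the first cycle lasts 5 epochs
    mult_factor=1.5,  # each subsequent cycle is 1.5x longer
)
# model.fit(..., callbacks=[scheduler])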

data_loader

combine_datasets(features_dataset: tf.data.Dataset, labels_dataset: tf.data.Dataset, batch_size: int, num_timesteps: int) -> tf.data.Dataset

Combine feature and label datasets with padded batching.

Parameters:

features_dataset : tf.data.Dataset
    The dataset containing features.
labels_dataset : tf.data.Dataset
    The dataset containing labels.
batch_size : int
    The size of each batch.
num_timesteps : int
    The number of time steps in each time series.

Returns:

tf.data.Dataset A combined dataset with features and labels, padded and batched.

Source code in src/capfinder/data_loader.py
def combine_datasets(
    features_dataset: tf.data.Dataset,
    labels_dataset: tf.data.Dataset,
    batch_size: int,
    num_timesteps: int,
) -> tf.data.Dataset:
    """Combine feature and label datasets with padded batching.

    Parameters:
    -----------
    features_dataset : tf.data.Dataset
        The dataset containing features.
    labels_dataset : tf.data.Dataset
        The dataset containing labels.
    batch_size : int
        The size of each batch.
    num_timesteps : int
        The number of time steps in each time series.

    Returns:
    --------
    tf.data.Dataset
        A combined dataset with features and labels, padded and batched.
    """
    dataset = tf.data.Dataset.zip((features_dataset, labels_dataset))
    dataset = dataset.padded_batch(
        batch_size, padded_shapes=([num_timesteps, 1], []), drop_remainder=True
    )
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

load_datasets(train_x_path: str, train_y_path: str, val_x_path: str, val_y_path: str, batch_size: int, num_timesteps: int) -> Tuple[tf.data.Dataset, tf.data.Dataset]

Load and combine train and validation datasets.

Parameters:

train_x_path : str
    Path to the CSV file containing training features.
train_y_path : str
    Path to the CSV file containing training labels.
val_x_path : str
    Path to the CSV file containing validation features.
val_y_path : str
    Path to the CSV file containing validation labels.
batch_size : int
    The size of each batch.
num_timesteps : int
    The number of time steps in each time series.

Returns:

Tuple[tf.data.Dataset, tf.data.Dataset] A tuple containing the combined training dataset and validation dataset.

Source code in src/capfinder/data_loader.py
def load_datasets(
    train_x_path: str,
    train_y_path: str,
    val_x_path: str,
    val_y_path: str,
    batch_size: int,
    num_timesteps: int,
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Load and combine train and validation datasets.

    Parameters:
    -----------
    train_x_path : str
        Path to the CSV file containing training features.
    train_y_path : str
        Path to the CSV file containing training labels.
    val_x_path : str
        Path to the CSV file containing validation features.
    val_y_path : str
        Path to the CSV file containing validation labels.
    batch_size : int
        The size of each batch.
    num_timesteps : int
        The number of time steps in each time series.

    Returns:
    --------
    Tuple[tf.data.Dataset, tf.data.Dataset]
        A tuple containing the combined training dataset and validation dataset.
    """
    train_features_dataset = load_feature_dataset(train_x_path, num_timesteps)
    train_labels_dataset = load_label_dataset(train_y_path)
    val_features_dataset = load_feature_dataset(val_x_path, num_timesteps)
    val_labels_dataset = load_label_dataset(val_y_path)

    train_dataset = combine_datasets(
        train_features_dataset, train_labels_dataset, batch_size, num_timesteps
    )
    val_dataset = combine_datasets(
        val_features_dataset, val_labels_dataset, batch_size, num_timesteps
    )

    return train_dataset, val_dataset
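
Example: calling load_datasets with placeholder CSV paths. As implemented above, each feature CSV is expected to contain one time series per row (one column per timestep) and each label CSV a single integer label per row, both with a header line; num_timesteps=500 mirrors the cap-signal target length used elsewhere in the package.

from capfinder.data_loader import load_datasets

train_ds, val_ds = load_datasets(
    train_x_path="train_x.csv",  # placeholder paths
    train_y_path="train_y.csv",
    val_x_path="val_x.csv",
    val_y_path="val_y.csv",
    batch_size=32,
    num_timesteps=500,
)

# Each element is a (features, labels) pair with shapes (32, 500, 1) and (32,).
for features, labels in train_ds.take(1):
    print(features.shape, labels.shape)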

load_feature_dataset(file_path: str, num_timesteps: int) -> tf.data.Dataset

Load feature dataset from a CSV file.

Parameters:

file_path : str
    The path to the CSV file containing features.
num_timesteps : int
    The number of time steps in each time series.

Returns:

tf.data.Dataset A TensorFlow dataset containing the parsed features.

Source code in src/capfinder/data_loader.py
def load_feature_dataset(file_path: str, num_timesteps: int) -> tf.data.Dataset:
    """Load feature dataset from a CSV file.

    Parameters:
    -----------
    file_path : str
        The path to the CSV file containing features.
    num_timesteps : int
        The number of time steps in each time series.

    Returns:
    --------
    tf.data.Dataset
        A TensorFlow dataset containing the parsed features.
    """
    dataset = tf.data.TextLineDataset(file_path)
    dataset = dataset.skip(1)  # Skip header row
    dataset = dataset.map(
        lambda x: parse_features(x, num_timesteps), num_parallel_calls=tf.data.AUTOTUNE
    )
    return dataset

load_label_dataset(file_path: str) -> tf.data.Dataset

Load label dataset from a CSV file.

Parameters:

file_path : str The path to the CSV file containing labels.

Returns:

tf.data.Dataset A TensorFlow dataset containing the parsed labels.

Source code in src/capfinder/data_loader.py
def load_label_dataset(file_path: str) -> tf.data.Dataset:
    """Load label dataset from a CSV file.

    Parameters:
    -----------
    file_path : str
        The path to the CSV file containing labels.

    Returns:
    --------
    tf.data.Dataset
        A TensorFlow dataset containing the parsed labels.
    """
    dataset = tf.data.TextLineDataset(file_path)
    dataset = dataset.skip(1)  # Skip header row
    dataset = dataset.map(parse_labels, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset

parse_features(line: tf.Tensor, num_timesteps: int) -> tf.Tensor

Parse features from a CSV line and reshape them.

Parameters:

line : tf.Tensor
    A tensor representing a single line from the CSV file.
num_timesteps : int
    The number of time steps in each time series.

Returns:

tf.Tensor A tensor of shape (num_timesteps, 1) containing the parsed features.

Source code in src/capfinder/data_loader.py
def parse_features(line: tf.Tensor, num_timesteps: int) -> tf.Tensor:
    """Parse features from a CSV line and reshape them.

    Parameters:
    -----------
    line : tf.Tensor
        A tensor representing a single line from the CSV file.
    num_timesteps : int
        The number of time steps in each time series.

    Returns:
    --------
    tf.Tensor
        A tensor of shape (num_timesteps, 1) containing the parsed features.
    """
    column_defaults = [[0.0]] * num_timesteps
    fields = tf.io.decode_csv(line, record_defaults=column_defaults)
    features = tf.reshape(fields, (num_timesteps, 1))  # Reshape to (timesteps, 1)
    return features
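
For example, a single CSV row with three feature columns is reshaped as follows (toy values; the import path is inferred from the source location above).

import tensorflow as tf

from capfinder.data_loader import parse_features  # import path inferred from the source location

line = tf.constant("0.10,0.25,0.40")        # one CSV row with 3 feature columns
features = parse_features(line, num_timesteps=3)
print(features.shape)  # (3, 1)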

parse_labels(line: tf.Tensor) -> tf.Tensor

Parse labels from a CSV line.

Parameters:

line : tf.Tensor A tensor representing a single line from the CSV file.

Returns:

tf.Tensor A tensor containing the parsed label.

Source code in src/capfinder/data_loader.py
def parse_labels(line: tf.Tensor) -> tf.Tensor:
    """Parse labels from a CSV line.

    Parameters:
    -----------
    line : tf.Tensor
        A tensor representing a single line from the CSV file.

    Returns:
    --------
    tf.Tensor
        A tensor containing the parsed label.
    """
    label = tf.io.decode_csv(line, record_defaults=[[0]])
    return label[0]

download_model

create_version_info_file(output_dir: str, version: str) -> None

Create a file to store the version information. If any file with a name starting with "v" already exists in the output directory, delete it before creating a new one.

Parameters:

output_dir (str): The directory where the version file will be created.
version (str): The version string to be written to the file.

Returns: None

Source code in src/capfinder/download_model.py
def create_version_info_file(output_dir: str, version: str) -> None:
    """
    Create a file to store the version information. If any file with a name starting
    with "v" already exists in the output directory, delete it before creating a new one.

    Parameters:
    output_dir (str): The directory where the version file will be created.
    version (str): The version string to be written to the file.

    Returns:
    None
    """
    version_file = os.path.join(output_dir, f"v{version}")

    # Find and delete any existing version file
    existing_files = glob.glob(os.path.join(output_dir, "v*"))
    for file in existing_files:
        os.remove(file)

    # Create a new version file
    with open(version_file, "w") as f:
        f.write(version)

download_comet_model(workspace: str, model_name: str, version: str, output_dir: str = './', force_download: bool = False) -> None

Download a model from Comet ML using the official API.

Parameters:

workspace (str): The Comet ML workspace name.
model_name (str): The name of the model.
version (str): The version of the model to download (use "latest" for the most recent version).
output_dir (str): The local directory to save the downloaded model (default is the current directory).
force_download (bool): If True, download the model even if it already exists locally.

Returns: None

Source code in src/capfinder/download_model.py
def download_comet_model(
    workspace: str,
    model_name: str,
    version: str,
    output_dir: str = "./",
    force_download: bool = False,
) -> None:
    """
    Download a model from Comet ML using the official API.

    Parameters:
    workspace (str): The Comet ML workspace name
    model_name (str): The name of the model
    version (str): The version of the model to download (use "latest" for the most recent version)
    output_dir (str): The local directory to save the downloaded model (default is current directory)
    force_download (bool): If True, download the model even if it already exists locally

    Returns:
    None
    """

    os.makedirs(output_dir, exist_ok=True)
    api = API()
    model = api.get_model(workspace, model_name)
    model.download(version, output_dir, expand=True)
    orig_model_name = model._get_assets(version)[0]["fileName"]
    rename_downloaded_model(output_dir, orig_model_name, f"{model_name}.keras")
    create_version_info_file(output_dir, version)
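
A usage sketch; the workspace and model names are hypothetical, and a valid Comet ML API key must be configured in the environment for the download to succeed.

from capfinder.download_model import download_comet_model  # import path inferred from the source location

download_comet_model(
    workspace="my-workspace",          # hypothetical workspace
    model_name="cap-classifier",       # hypothetical registered model name
    version="latest",                  # or a specific version string
    output_dir="./model",
    force_download=False,
)
# Per the source above, the model is renamed to ./model/cap-classifier.keras
# and a v<version> marker file is written alongside it.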

rename_downloaded_model(output_dir: str, orig_model_name: str, new_model_name: str) -> None

Renames the downloaded model file to a new name.

Parameters:

output_dir (str): The directory where the model file is located.
orig_model_name (str): The original name of the model file.
new_model_name (str): The new name to rename the model file to.

Returns: None

Source code in src/capfinder/download_model.py
def rename_downloaded_model(
    output_dir: str, orig_model_name: str, new_model_name: str
) -> None:
    """
    Renames the downloaded model file to a new name.

    Parameters:
    output_dir (str): The directory where the model file is located.
    orig_model_name (str): The original name of the model file.
    new_model_name (str): The new name to rename the model file to.

    Returns:
    None
    """
    # Construct the new full path
    orig_path = os.path.join(output_dir, orig_model_name)
    new_path = os.path.join(output_dir, new_model_name)
    os.rename(orig_path, new_path)

encoder_model

CapfinderHyperModel

Bases: HyperModel

Custom HyperModel class to wrap the model building function for Capfinder.

This class defines the hyperparameter search space and builds the model based on the selected hyperparameters, including a variable number of MLP layers.

Attributes:

input_shape : Tuple[int, int]
    The shape of the input data.
n_classes : int
    The number of output classes for the classification task.
encoder_model : Optional[keras.Model]
    Stores the encoder part of the model, initialized during the build process.

Source code in src/capfinder/encoder_model.py
class CapfinderHyperModel(HyperModel):
    """
    Custom HyperModel class to wrap the model building function for Capfinder.

    This class defines the hyperparameter search space and builds the model
    based on the selected hyperparameters, including a variable number of MLP layers.

    Attributes:
    ----------
    input_shape : Tuple[int, int]
        The shape of the input data.
    n_classes : int
        The number of output classes for the classification task.
    encoder_model : Optional[keras.Model]
        Stores the encoder part of the model, initialized during the build process.
    """

    def __init__(self, input_shape: Tuple[int, int], n_classes: int):
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.encoder_model: Optional[keras.Model] = None

    def build(self, hp: HyperParameters) -> Model:
        """
        Build and compile the model based on the hyperparameters.

        Parameters:
        ----------
        hp : HyperParameters
            The hyperparameters to use for building the model.

        Returns:
        -------
        Model
            The compiled Keras model.
        """
        # Hyperparameter for number of MLP layers
        num_mlp_layers = hp.Int("num_mlp_layers", min_value=1, max_value=3, step=1)

        # Create a list of MLP units for each layer
        mlp_units = [
            hp.Int(f"mlp_units_{i}", min_value=32, max_value=256, step=32)
            for i in range(num_mlp_layers)
        ]

        # Call the model builder function and obtain the full model and encoder model
        model, encoder_model = build_model(
            input_shape=self.input_shape,
            head_size=hp.Int("head_size", min_value=32, max_value=256, step=32),
            num_heads=hp.Int("num_heads", min_value=1, max_value=8, step=1),
            ff_dim=hp.Int("ff_dim", min_value=32, max_value=512, step=32),
            num_transformer_blocks=hp.Int(
                "num_transformer_blocks", min_value=1, max_value=8, step=1
            ),
            mlp_units=mlp_units,  # Now passing the list of MLP units
            n_classes=self.n_classes,
            mlp_dropout=hp.Float("mlp_dropout", min_value=0.1, max_value=0.5, step=0.1),
            dropout=hp.Float("dropout", min_value=0.1, max_value=0.5, step=0.1),
        )

        # Store the encoder model as an instance attribute for later access
        self.encoder_model = encoder_model

        # Define learning rate as a hyperparameter
        learning_rate = hp.Float(
            "learning_rate", min_value=1e-4, max_value=1e-2, sampling="log"
        )

        # Compile the model
        model.compile(
            loss="sparse_categorical_crossentropy",
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            metrics=["sparse_categorical_accuracy"],
        )

        # Return only the full model to Keras Tuner
        return model

build(hp: HyperParameters) -> Model

Build and compile the model based on the hyperparameters.

Parameters:

hp : HyperParameters The hyperparameters to use for building the model.

Returns:

Model The compiled Keras model.

Source code in src/capfinder/encoder_model.py
def build(self, hp: HyperParameters) -> Model:
    """
    Build and compile the model based on the hyperparameters.

    Parameters:
    ----------
    hp : HyperParameters
        The hyperparameters to use for building the model.

    Returns:
    -------
    Model
        The compiled Keras model.
    """
    # Hyperparameter for number of MLP layers
    num_mlp_layers = hp.Int("num_mlp_layers", min_value=1, max_value=3, step=1)

    # Create a list of MLP units for each layer
    mlp_units = [
        hp.Int(f"mlp_units_{i}", min_value=32, max_value=256, step=32)
        for i in range(num_mlp_layers)
    ]

    # Call the model builder function and obtain the full model and encoder model
    model, encoder_model = build_model(
        input_shape=self.input_shape,
        head_size=hp.Int("head_size", min_value=32, max_value=256, step=32),
        num_heads=hp.Int("num_heads", min_value=1, max_value=8, step=1),
        ff_dim=hp.Int("ff_dim", min_value=32, max_value=512, step=32),
        num_transformer_blocks=hp.Int(
            "num_transformer_blocks", min_value=1, max_value=8, step=1
        ),
        mlp_units=mlp_units,  # Now passing the list of MLP units
        n_classes=self.n_classes,
        mlp_dropout=hp.Float("mlp_dropout", min_value=0.1, max_value=0.5, step=0.1),
        dropout=hp.Float("dropout", min_value=0.1, max_value=0.5, step=0.1),
    )

    # Store the encoder model as an instance attribute for later access
    self.encoder_model = encoder_model

    # Define learning rate as a hyperparameter
    learning_rate = hp.Float(
        "learning_rate", min_value=1e-4, max_value=1e-2, sampling="log"
    )

    # Compile the model
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=["sparse_categorical_accuracy"],
    )

    # Return only the full model to Keras Tuner
    return model
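
A sketch of how the hypermodel might be driven by a Keras Tuner search; the input shape, class count, trial budget, and the train_ds/val_ds datasets (for example from load_datasets) are illustrative assumptions.

import keras_tuner as kt

from capfinder.encoder_model import CapfinderHyperModel  # import path inferred from the source location

hypermodel = CapfinderHyperModel(input_shape=(500, 1), n_classes=4)
tuner = kt.RandomSearch(
    hypermodel,
    objective="val_sparse_categorical_accuracy",
    max_trials=10,
    directory="tuner_runs",
    project_name="capfinder",
)
tuner.search(train_ds, validation_data=val_ds, epochs=5)
best_model = tuner.get_best_models(num_models=1)[0]
best_encoder = hypermodel.encoder_model  # encoder from the most recently built trial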

build_model(input_shape: Tuple[int, int], head_size: int, num_heads: int, ff_dim: int, num_transformer_blocks: int, mlp_units: List[int], n_classes: int, dropout: float = 0.0, mlp_dropout: float = 0.0) -> Tuple[keras.Model, keras.Model]

Build a transformer-based neural network model and return the encoder output.

Parameters:

input_shape : Tuple[int, int]
    The shape of the input data.
head_size : int
    The size of the attention heads in the transformer encoder.
num_heads : int
    The number of attention heads in the transformer encoder.
ff_dim : int
    The dimensionality of the feed-forward network in the transformer encoder.
num_transformer_blocks : int
    The number of transformer encoder blocks in the model.
mlp_units : List[int]
    A list containing the number of units for each layer in the MLP.
n_classes : int
    The number of output classes (for classification tasks).
dropout : float, optional
    The dropout rate applied in the transformer encoder.
mlp_dropout : float, optional
    The dropout rate applied in the MLP.

Returns:

Tuple[keras.Model, keras.Model]
    A tuple containing the full model and the encoder model.

Source code in src/capfinder/encoder_model.py
def build_model(
    input_shape: Tuple[int, int],
    head_size: int,
    num_heads: int,
    ff_dim: int,
    num_transformer_blocks: int,
    mlp_units: List[int],
    n_classes: int,
    dropout: float = 0.0,
    mlp_dropout: float = 0.0,
) -> Tuple[keras.Model, keras.Model]:
    """
    Build a transformer-based neural network model and return the encoder output.

    Parameters:
    input_shape : Tuple[int, int]
        The shape of the input data.
    head_size : int
        The size of the attention heads in the transformer encoder.
    num_heads : int
        The number of attention heads in the transformer encoder.
    ff_dim : int
        The dimensionality of the feed-forward network in the transformer encoder.
    num_transformer_blocks : int
        The number of transformer encoder blocks in the model.
    mlp_units : List[int]
        A list containing the number of units for each layer in the MLP.
    n_classes : int
        The number of output classes (for classification tasks).
    dropout : float, optional
        The dropout rate applied in the transformer encoder.
    mlp_dropout : float, optional
        The dropout rate applied in the MLP.

    Returns:
    Tuple[keras.Model, keras.Model]:
        A tuple containing the full model and the encoder model.
    """
    # Input layer
    inputs = keras.Input(shape=input_shape)

    # Apply transformer encoder blocks and save the output of the encoder
    x = inputs
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)

    # Apply global average pooling
    x = layers.GlobalAveragePooling1D(data_format="channels_last")(x)

    # Save the encoder output
    encoder_output = x

    # Add multi-layer perceptron (MLP) layers
    for dim in mlp_units:
        x = layers.Dense(dim, activation="relu")(x)
        x = layers.Dropout(mlp_dropout)(x)

    # Add softmax output layer
    outputs = layers.Dense(n_classes, activation="softmax")(x)

    # Construct the full model
    model = keras.Model(inputs, outputs)

    # Create a model that produces only the encoder output
    encoder_model = keras.Model(inputs, encoder_output)

    # Return the full model and the encoder model
    return model, encoder_model
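
Calling the builder directly with fixed, purely illustrative hyperparameters returns both models.

from capfinder.encoder_model import build_model  # import path inferred from the source location

model, encoder = build_model(
    input_shape=(500, 1),
    head_size=64,
    num_heads=4,
    ff_dim=128,
    num_transformer_blocks=2,
    mlp_units=[128, 64],
    n_classes=4,
    dropout=0.1,
    mlp_dropout=0.2,
)
model.summary()    # full classifier: encoder blocks + MLP head + softmax
encoder.summary()  # encoder only: output of the global average pooling layer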

transformer_encoder(inputs: keras.layers.Layer, head_size: int, num_heads: int, ff_dim: int, dropout: Optional[float] = 0.0) -> keras.layers.Layer

Create a transformer encoder block.

The transformer encoder block consists of a multi-head attention layer followed by layer normalization and a feed-forward network.

Parameters:

inputs : keras.layers.Layer
    The input layer or tensor for the encoder block.
head_size : int
    The size of the attention heads.
num_heads : int
    The number of attention heads.
ff_dim : int
    The dimensionality of the feed-forward network.
dropout : float, optional
    The dropout rate applied after the attention layer and within the feed-forward network. Default is 0.0.

Returns:

keras.layers.Layer The output layer of the encoder block, which can be used as input for the next layer in a neural network.

Source code in src/capfinder/encoder_model.py
def transformer_encoder(
    inputs: keras.layers.Layer,
    head_size: int,
    num_heads: int,
    ff_dim: int,
    dropout: Optional[float] = 0.0,
) -> keras.layers.Layer:
    """
    Create a transformer encoder block.

    The transformer encoder block consists of a multi-head attention layer
    followed by layer normalization and a feed-forward network.

    Parameters:
    ----------
    inputs : keras.layers.Layer
        The input layer or tensor for the encoder block.
    head_size : int
        The size of the attention heads.
    num_heads : int
        The number of attention heads.
    ff_dim : int
        The dimensionality of the feed-forward network.
    dropout : float, optional
        The dropout rate applied after the attention layer and within the feed-forward network. Default is 0.0.

    Returns:
    -------
    keras.layers.Layer
        The output layer of the encoder block, which can be used as input for the next layer in a neural network.
    """
    # Multi-head attention layer with dropout and layer normalization
    x = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(inputs, inputs)
    x = layers.Dropout(dropout)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    res = x + inputs

    # Feed-forward network with convolutional layers, dropout, and layer normalization
    x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res)
    x = layers.Dropout(dropout)(x)
    x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)

    # Masking for zero padding
    mask = layers.Lambda(
        lambda x: keras.ops.cast(keras.ops.equal(x, 0.0), dtype="float32"),
        output_shape=(inputs.shape[1], 1),  # Specify output shape
    )(inputs)
    x = (x + res) * mask  # Apply mask to the summed output (including residual)

    return x
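
A quick shape check, under assumed dimensions, shows that the block returns a tensor of the same shape as its input, which is why blocks can be stacked in build_model.

import keras

from capfinder.encoder_model import transformer_encoder  # import path inferred from the source location

inputs = keras.Input(shape=(500, 1))
outputs = transformer_encoder(inputs, head_size=64, num_heads=4, ff_dim=128, dropout=0.1)
print(outputs.shape)  # (None, 500, 1)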

find_ote_test

The module contains the code to find the OTE sequence in test data -- where only the context to the left of the NNNNNN region is known -- and to locate it with high confidence. The module can process one read at a time, or all reads in a FASTQ file or a folder of FASTQ files.

Author: Adnan M. Niazi Date: 2024-02-28

cnt_match_mismatch_gaps(aln_str: str) -> Tuple[int, int, int]

Takes an alignment string and counts the number of matches, mismatches, and gaps.

Parameters:

Name Type Description Default
aln_str str

The alignment string.

required

Returns:

Name Type Description
match_cnt int

The number of matches in the alignment string.

mismatch_cnt int

The number of mismatches in the alignment string.

gap_cnt int

The number of gaps in the alignment string.

Source code in src/capfinder/find_ote_test.py
def cnt_match_mismatch_gaps(aln_str: str) -> Tuple[int, int, int]:
    """
    Takes an alignment string and counts the number of matches, mismatches, and gaps.

    Args:
        aln_str (str): The alignment string.

    Returns:
        match_cnt (int): The number of matches in the alignment string.
        mismatch_cnt (int): The number of mismatches in the alignment string.
        gap_cnt (int): The number of gaps in the alignment string.
    """
    match_cnt = 0
    mismatch_cnt = 0
    gap_cnt = 0
    for aln_chr in aln_str:
        if aln_chr == "|":
            match_cnt += 1
        elif aln_chr == "/":
            mismatch_cnt += 1
        elif aln_chr == " ":
            gap_cnt += 1
    return match_cnt, mismatch_cnt, gap_cnt
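
For instance, with '|' marking matches, '/' mismatches, and ' ' gaps in the alignment string:

from capfinder.find_ote_test import cnt_match_mismatch_gaps  # import path inferred from the source location

match_cnt, mismatch_cnt, gap_cnt = cnt_match_mismatch_gaps("||/| |")
print(match_cnt, mismatch_cnt, gap_cnt)  # 4 1 1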

dispatcher(input_path: str, reference: str, cap0_pos: int, num_processes: int, output_folder: str) -> None

Check if the input path is a file or folder, and call the appropriate function to process the input.

Parameters:

Name Type Description Default
input_path str

The path to the FASTQ file or folder.

required
reference str

The reference sequence to align the read to.

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where worker output files will be stored.

required

Returns:

Type Description
None

None

Source code in src/capfinder/find_ote_test.py
def dispatcher(
    input_path: str,
    reference: str,
    cap0_pos: int,
    num_processes: int,
    output_folder: str,
) -> None:
    """
    Check if the input path is a file or folder, and call the appropriate function to process the input.

    Args:
        input_path (str): The path to the FASTQ file or folder.
        reference (str): The reference sequence to align the read to.
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where worker output files will be stored.

    Returns:
        None
    """
    if os.path.isfile(input_path):
        process_fastq_file(
            input_path, reference, cap0_pos, num_processes, output_folder
        )
    elif os.path.isdir(input_path):
        process_fastq_folder(
            input_path, reference, cap0_pos, num_processes, output_folder
        )
    else:
        raise ValueError("Error! Invalid path type. Path must be a file or folder.")

find_ote_test(input_path: str, reference: str, cap0_pos: int, num_processes: int, output_folder: str) -> None

Main function to process a FASTQ file or folder of FASTQ files to find OTEs in the reads. The function is suitable only for testing data where only the OTE sequence is known and the N1N2 cap bases and any bases 3' of them are unknown.

Parameters:

Name Type Description Default
input_path str

The path to the FASTQ file or folder.

required
reference str

The reference sequence to align the read to.

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where worker output files will be stored.

required

Returns: None

Source code in src/capfinder/find_ote_test.py
def find_ote_test(
    input_path: str,
    reference: str,
    cap0_pos: int,
    num_processes: int,
    output_folder: str,
) -> None:
    """
    Main function to process a FASTQ file or folder of FASTQ files to find OTEs
    in the reads. The function is suitable only for testing data where only the OTE
    sequence is known and the N1N2 cap bases and any bases 3' of them are unknown.

    Args:
        input_path (str): The path to the FASTQ file or folder.
        reference (str): The reference sequence to align the read to.
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where worker output files will be stored.
    Returns:
        None
    """
    dispatcher(input_path, reference, cap0_pos, num_processes, output_folder)
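
A usage sketch; the paths, process count, and cap position are placeholders, and ote_reference stands in for the actual OTE reference sequence used in the experiment.

from capfinder.find_ote_test import find_ote_test  # import path inferred from the source location

find_ote_test(
    input_path="/data/run1/fastq_pass",   # a FASTQ file or a folder of FASTQ files
    reference=ote_reference,              # OTE sequence with known context 5' of the NNNNNN region
    cap0_pos=52,                          # 0-indexed position of the first cap base (N1) in the reference
    num_processes=8,
    output_folder="/data/run1/ote_test_results",
)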

has_good_aln_in_5prime_flanking_region(match_cnt: int, mismatch_cnt: int, gap_cnt: int) -> bool

Checks if the alignment in the flanking region before the cap is good.

Parameters:

Name Type Description Default
match_cnt int

The number of matches in the flanking region.

required
mismatch_cnt int

The number of mismatches in the flanking region.

required
gap_cnt int

The number of gaps in the flanking region.

required

Returns:

Name Type Description
bool bool

True if the alignment in the flanking region is good, False otherwise.

Source code in src/capfinder/find_ote_test.py
def has_good_aln_in_5prime_flanking_region(
    match_cnt: int, mismatch_cnt: int, gap_cnt: int
) -> bool:
    """
    Checks if the alignment in the flanking region before the cap is good.

    Args:
        match_cnt (int): The number of matches in the flanking region.
        mismatch_cnt (int): The number of mismatches in the flanking region.
        gap_cnt (int): The number of gaps in the flanking region.

    Returns:
        bool: True if the alignment in the flanking region is good, False otherwise.
    """
    if (mismatch_cnt > match_cnt) or (gap_cnt > match_cnt):
        return False
    else:
        return True

has_good_aln_in_n_region(match_cnt: int, mismatch_cnt: int, gap_cnt: int) -> bool

Checks if the alignment in the NNNNNN region is good.

Parameters:

Name Type Description Default
match_cnt int

The number of matches in the NNNNNN region.

required
mismatch_cnt int

The number of mismatches in the NNNNNN region.

required
gap_cnt int

The number of gaps in the NNNNNN region.

required

Returns:

Name Type Description
bool bool

True if the alignment in the NNNNNN region is good, False otherwise.

Source code in src/capfinder/find_ote_test.py
def has_good_aln_in_n_region(match_cnt: int, mismatch_cnt: int, gap_cnt: int) -> bool:
    """
    Checks if the alignment in the NNNNNN region is good.

    Args:
        match_cnt (int): The number of matches in the NNNNNN region.
        mismatch_cnt (int): The number of mismatches in the NNNNNN region.
        gap_cnt (int): The number of gaps in the NNNNNN region.

    Returns:
        bool: True if the alignment in the NNNNNN region is good, False otherwise.
    """
    # For a good alignment in the NNNNNN region, the number of mismatches should be
    # greater than or equal to the number of gaps
    if mismatch_cnt >= gap_cnt:
        return True
    else:
        return False

make_coordinates(aln_str: str, ref_str: str) -> List[int]

Walk along the alignment string and make an incrementing index where there is a match, mismatch, or deletion. For gaps in the alignment string, it outputs a -1 in the index list.

Parameters:

Name Type Description Default
aln_str str

The alignment string.

required
ref_str str

The reference string.

required

Returns:

Name Type Description
coord_list list

A list of indices corresponding to the alignment string.

Source code in src/capfinder/find_ote_test.py
def make_coordinates(aln_str: str, ref_str: str) -> List[int]:
    """
    Walk along the alignment string and make an incrementing index
    where there is a match, mismatch, or deletion. For gaps in
    the alignment string, it outputs a -1 in the index list.

    Args:
        aln_str (str): The alignment string.
        ref_str (str): The reference string.

    Returns:
        coord_list (list): A list of indices corresponding to the alignment string.
    """
    # Make index coordinates along the alignment string
    coord_list = []
    cnt = 0
    for idx, aln_chr in enumerate(aln_str):
        if aln_chr != " ":
            coord_list.append(cnt)
            cnt += 1
        else:
            if ref_str[idx] != "-":  # handle deletions
                coord_list.append(cnt)
                cnt += 1
            else:
                coord_list.append(-1)

    # Go in reverse in the coord_list, and put -1 of all the places
    # where there is a gap in the alignment string. Break out when the
    # first non-gap character is encountered.
    for idx in range(len(coord_list) - 1, -1, -1):
        if aln_str[idx] == " ":
            coord_list[idx] = -1
        else:
            break
    return coord_list
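
A toy illustration of the coordinate list, where '-' in the reference string marks an insertion in the read and trailing gaps are masked with -1.

from capfinder.find_ote_test import make_coordinates  # import path inferred from the source location

aln_str = "||| |  "   # three matches, a gap opposite the reference '-', a match, then two trailing gaps
ref_str = "ACG-TAA"
print(make_coordinates(aln_str, ref_str))
# [0, 1, 2, -1, 3, -1, -1]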

process_fastq_file(fastq_filepath: str, reference: str, cap0_pos: int, num_processes: int, output_folder: str) -> None

Process a single FASTQ file. The function reads the FASTQ file, and processes each read in parallel.

Parameters:

Name Type Description Default
fastq_filepath str

The path to the FASTQ file.

required
reference str

The reference sequence to align the read to.

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where worker output files will be stored.

required

Returns:

Type Description
None

None

Source code in src/capfinder/find_ote_test.py
def process_fastq_file(
    fastq_filepath: str,
    reference: str,
    cap0_pos: int,
    num_processes: int,
    output_folder: str,
) -> None:
    """
    Process a single FASTQ file. The function reads the FASTQ file, and processes each read in parallel.

    Args:
        fastq_filepath (str): The path to the FASTQ file.
        reference (str): The reference sequence to align the read to.
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where worker output files will be stored.

    Returns:
        None
    """

    # Make output file name
    directory, filename = os.path.split(fastq_filepath)
    filename_no_extension, extension = os.path.splitext(filename)
    os.path.join(output_folder, f"{filename_no_extension}.txt")

    with file_opener(fastq_filepath) as fastq_file:
        records = list(SeqIO.parse(fastq_file, "fastq"))
        total_records = len(records)

        with WorkerPool(n_jobs=num_processes) as pool:
            results = pool.map(
                process_read,
                [(record, reference, cap0_pos) for record in records],
                iterable_len=total_records,
                progress_bar=True,
            )
            write_csv(
                results,
                output_filepath=os.path.join(
                    output_folder,
                    filename_no_extension + "_test_ote_search_results.csv",
                ),
            )

process_fastq_folder(folder_path: str, reference: str, cap0_pos: int, num_processes: int, output_folder: str) -> None

Process all FASTQ files in a folder. The function reads all FASTQ files in a folder and feeds one FASTQ file at a time to a processing function that processes the reads in that file in parallel.

Parameters:

Name Type Description Default
folder_path str

The path to the folder containing FASTQ files.

required
reference str

The reference sequence to align the read to.

required
cap0_pos int

The position of the first cap base (N1) in the reference sequence (0-indexed).

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where worker output files will be stored.

required

Returns:

Type Description
None

None

Source code in src/capfinder/find_ote_test.py
def process_fastq_folder(
    folder_path: str,
    reference: str,
    cap0_pos: int,
    num_processes: int,
    output_folder: str,
) -> None:
    """
    Process all FASTQ files in a folder. The function reads all FASTQ files in a folder
    and feeds one FASTQ file at a time to a processing function that processes the reads
    in that file in parallel.

    Args:
        folder_path (str): The path to the folder containing FASTQ files.
        reference (str): The reference sequence to align the read to.
        cap0_pos (int): The position of the first cap base (N1) in the reference sequence (0-indexed).
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where worker output files will be stored.

    Returns:
        None
    """
    # List all files in the folder
    for root, _, files in os.walk(folder_path):
        for file_name in files:
            if file_name.endswith((".fastq", ".fastq.gz")):
                file_path = os.path.join(root, file_name)
                process_fastq_file(
                    file_path, reference, cap0_pos, num_processes, output_folder
                )

process_read(record: Any, reference: str, cap0_pos: int) -> Dict[str, Any]

Process a single read from a FASTQ file. The function aligns the read to the reference, and checks if the alignment in the NNNNNN region and the flanking regions is good. If the alignment is good, then the function returns the read ID, alignment score, and the positions of the left flanking region, cap0 base, and the right flanking region in the read's FASTQ sequence. If the alignment is bad, then the function returns the read ID, alignment score, and the reason why the alignment is bad.

Parameters:

Name Type Description Default
record SeqRecord

A single read from a FASTQ file.

required
reference str

The reference sequence to align the read to.

required
cap0_pos int

The position of the first cap base in the reference sequence (0-indexed).

required

Returns:

Name Type Description
out_ds dict

A dictionary containing the following keys:
read_id (str): The identifier of the sequence read.
read_type (str): The type of the read, which can be 'good' or 'bad'.
reason (str or None): The reason for the failed alignment, if available.
alignment_score (float): The alignment score for the read.
left_flanking_region_start_fastq_pos (int or None): The starting position of the left flanking region in the FASTQ file, if available.
cap_n1_minus_1_read_fastq_pos (int or None): The position of the cap's N1 base in the FASTQ file (0-indexed), if available.
right_flanking_region_start_fastq_pos (int or None): The starting position of the right flanking region in the FASTQ file, if available.

Source code in src/capfinder/find_ote_test.py
def process_read(record: Any, reference: str, cap0_pos: int) -> Dict[str, Any]:
    """
    Process a single read from a FASTQ file. The function aligns the read to the reference,
    and checks if the alignment in the NNNNNN region and the flanking regions is good. If the
    alignment is good, then the function returns the read ID, alignment score, and the
    positions of the left flanking region, cap0 base, and the right flanking region in the
    read's FASTQ sequence. If the alignment is bad, then the function returns the read ID,
    alignment score, and the reason why the alignment is bad.

    Args:
        record (SeqRecord): A single read from a FASTQ file.
        reference (str): The reference sequence to align the read to.
        cap0_pos (int): The position of the first cap base in the reference sequence (0-indexed).

    Returns:
        out_ds (dict): A dictionary containing the following keys:
            read_id (str): The identifier of the sequence read.
            read_type (str): The type of the read, which can be 'good' or 'bad'
            reason (str or None): The reason for the failed alignment, if available.
            alignment_score (float): The alignment score for the read.
            left_flanking_region_start_fastq_pos (int or None): The starting position of the left flanking region
            in the FASTQ file, if available.
            cap_n1_minus_1_read_fastq_pos (int or None): The position of the caps N1 base in the FASTQ file (0-indexed), if available.
            right_flanking_region_start_fastq_pos (int or None): The starting position of the right flanking region
            in the FASTQ file, if available.
    """
    # Get alignment
    sequence = str(record.seq)
    fasta_length = len(sequence)

    with contextlib.redirect_stdout(None):
        qry_str, aln_str, ref_str, aln_score = align(
            query_seq=sequence, target_seq=reference, pretty_print_alns=False
        )

    # define a data structure to return when the read OTE is not found
    out_ds_failed = {
        "read_id": record.id,
        "read_type": "bad",
        "reason": None,
        "alignment_score": aln_score,
        "left_flanking_region_start_fastq_pos": None,
        "cap_n1_minus_1_read_fastq_pos": None,
        "right_flanking_region_start_fastq_pos": None,
        "roi_fasta": None,
        "fasta_length": fasta_length,
    }

    # For low quality alignments, return None
    if aln_score < 20:
        out_ds_failed["reason"] = "low_aln_score"
        return out_ds_failed

    # Make index coordinates along the reference
    coord_list = make_coordinates(aln_str, ref_str)

    # Check if the first base 5 prime of the cap N1 base is in
    # the coordinates list. If not then the alignment did not
    # even reach the cap, so it is a bad read then.
    try:
        cap_n1_minus_1_idx = coord_list.index(cap0_pos - 1)
    except Exception:
        out_ds_failed["reason"] = "aln_does_not_reach_the_cap_base"
        return out_ds_failed

    # 2. Define regions in which to check for good alignment
    before_nnn_region = (
        cap_n1_minus_1_idx - BEFORE_N_REGION_WINDOW_LEN,
        cap_n1_minus_1_idx + 1,
    )

    # 3. Extract alignment strings for each region
    aln_str_before_nnn_region = aln_str[before_nnn_region[0] : before_nnn_region[1]]

    # 4. Count matches, mismatches, and gaps in each region
    bn_match_cnt, bn_mismatch_cnt, bn_gap_cnt = cnt_match_mismatch_gaps(
        aln_str_before_nnn_region
    )

    # 5. Is there a good alignment region flanking 5' of the cap?
    has_good_aln_before_n_region = has_good_aln_in_5prime_flanking_region(
        bn_match_cnt, bn_mismatch_cnt, bn_gap_cnt
    )
    if not (has_good_aln_before_n_region):
        out_ds_failed["reason"] = "bad_alignment_before_the_cap"
        return out_ds_failed

    # Find the position of cap N1-1 base in read's sequence (0-based indexing)
    cap_n1_minus_1_read_fastq_pos = (
        qry_str[:cap_n1_minus_1_idx].replace("-", "").count("")
    )

    # To reach the 5' end of the left flanking region, we need to walk
    # back over the alignment string, counting the matches and mismatches
    # but not the gaps
    idx = cap_n1_minus_1_idx
    cnt = 0
    while True:
        if cnt == NUM_CAP_FLANKING_BASES:
            break
        else:
            if aln_str[idx] != " ":
                cnt += 1
                left_flanking_region_start_idx = idx
        idx -= 1

    left_flanking_region_start_fastq_pos = (
        qry_str[:left_flanking_region_start_idx].replace("-", "").count("") - 1
    )
    right_flanking_region_start_fastq_pos = (
        cap_n1_minus_1_read_fastq_pos + NUM_CAP_FLANKING_BASES + 1
    )

    roi_fasta = sequence[
        left_flanking_region_start_fastq_pos:right_flanking_region_start_fastq_pos
    ]

    out_ds_passed = {
        "read_id": record.id,
        "read_type": "good",
        "reason": "good_alignment_in_cap-flanking_regions",
        "alignment_score": aln_score,
        "left_flanking_region_start_fastq_pos": left_flanking_region_start_fastq_pos,
        "cap_n1_minus_1_read_fastq_pos": cap_n1_minus_1_read_fastq_pos,
        "right_flanking_region_start_fastq_pos": right_flanking_region_start_fastq_pos,
        "roi_fasta": roi_fasta,
        "fasta_length": fasta_length,
    }

    # A fix to avoid outputting blank ROI when the read is shorter than
    # computed ROI coordinates
    sl = len(sequence)
    if (left_flanking_region_start_fastq_pos) > sl or (
        right_flanking_region_start_fastq_pos
    ) > sl:
        output = out_ds_failed
    else:
        output = out_ds_passed

    return output

write_csv(resutls_list: List[dict], output_filepath: str) -> None

Take a list of dictionaries and write them to a CSV file.

Parameters:

Name Type Description Default
resutls_list list

A list of dictionaries.

required
output_filepath str

The path to the output CSV file.

required

Returns:

Type Description
None

None

Source code in src/capfinder/find_ote_test.py
def write_csv(resutls_list: List[dict], output_filepath: str) -> None:
    """
    Take a list of dictionaries and write them to a CSV file.

    Args:
        resutls_list (list): A list of dictionaries.
        output_filepath (str): The path to the output CSV file.

    Returns:
        None
    """
    # Specify the CSV column headers based on the dictionary keys
    fieldnames = resutls_list[0].keys()

    # Create and write to the CSV file
    with open(output_filepath, "w", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

        # Write the header row
        writer.writeheader()

        # Write the data rows
        writer.writerows(resutls_list)

find_ote_train

The module contains the code to find the OTE sequence in training data -- where both the left and right context of the NNNNNN region are known -- and to locate it with high confidence. The module can process one read at a time, or all reads in a FASTQ file or a folder of FASTQ files.

Author: Adnan M. Niazi Date: 2024-02-28

cnt_match_mismatch_gaps(aln_str: str) -> Tuple[int, int, int]

Takes an alignment string and counts the number of matches, mismatches, and gaps.

Parameters:

Name Type Description Default
aln_str str

The alignment string.

required

Returns:

Name Type Description
match_cnt int

The number of matches in the alignment string.

mismatch_cnt int

The number of mismatches in the alignment string.

gap_cnt int

The number of gaps in the alignment string.

Source code in src/capfinder/find_ote_train.py
def cnt_match_mismatch_gaps(aln_str: str) -> Tuple[int, int, int]:
    """
    Takes an alignment string and counts the number of matches, mismatches, and gaps.

    Args:
        aln_str (str): The alignment string.

    Returns:
        match_cnt (int): The number of matches in the alignment string.
        mismatch_cnt (int): The number of mismatches in the alignment string.
        gap_cnt (int): The number of gaps in the alignment string.
    """
    match_cnt = 0
    mismatch_cnt = 0
    gap_cnt = 0
    for aln_chr in aln_str:
        if aln_chr == "|":
            match_cnt += 1
        elif aln_chr == "/":
            mismatch_cnt += 1
        elif aln_chr == " ":
            gap_cnt += 1
    return match_cnt, mismatch_cnt, gap_cnt

dispatcher(input_path: str, reference: str, cap0_pos: int, num_processes: int, output_folder: str) -> None

Check if the input path is a file or folder, and call the appropriate function to process the input.

Parameters:

Name Type Description Default
input_path str

The path to the FASTQ file or folder.

required
reference str

The reference sequence to align the read to.

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where worker output files will be stored.

required

Returns:

Type Description
None

None

Source code in src/capfinder/find_ote_train.py
def dispatcher(
    input_path: str,
    reference: str,
    cap0_pos: int,
    num_processes: int,
    output_folder: str,
) -> None:
    """
    Check if the input path is a file or folder, and call the appropriate function to process the input.

    Args:
        input_path (str): The path to the FASTQ file or folder.
        reference (str): The reference sequence to align the read to.
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where worker output files will be stored.

    Returns:
        None
    """
    if os.path.isfile(input_path):
        process_fastq_file(
            input_path, reference, cap0_pos, num_processes, output_folder
        )
    elif os.path.isdir(input_path):
        process_fastq_folder(
            input_path, reference, cap0_pos, num_processes, output_folder
        )
    else:
        raise ValueError("Error! Invalid path type. Path must be a file or folder.")

find_ote_train(input_path: str, reference: str, cap0_pos: int, num_processes: int, output_folder: str) -> None

Main function to process a FASTQ file or folder of FASTQ files to find OTEs in the reads.

Parameters:

Name Type Description Default
input_path str

The path to the FASTQ file or folder.

required
reference str

The reference sequence to align the read to.

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where worker output files will be stored.

required

Returns: None

Source code in src/capfinder/find_ote_train.py
def find_ote_train(
    input_path: str,
    reference: str,
    cap0_pos: int,
    num_processes: int,
    output_folder: str,
) -> None:
    """
    Main function to process a FASTQ file or folder of FASTQ files to find OTEs
    in the reads.

    Args:
        input_path (str): The path to the FASTQ file or folder.
        reference (str): The reference sequence to align the read to.
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where worker output files will be stored.
    Returns:
        None
    """
    dispatcher(input_path, reference, cap0_pos, num_processes, output_folder)
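
The training-data counterpart is called the same way; here the reference is expected to contain known context on both sides of the NNNNNN region. Paths, process count, and cap position are placeholders, and ote_reference stands in for the actual reference sequence.

from capfinder.find_ote_train import find_ote_train  # import path inferred from the source location

find_ote_train(
    input_path="/data/train_run/fastq_pass",
    reference=ote_reference,              # OTE with known 5' and 3' context around the NNNNNN region
    cap0_pos=52,                          # 0-indexed position of the first cap base (N1) in the reference
    num_processes=8,
    output_folder="/data/train_run/ote_train_results",
)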

has_good_aln_in_n_region(match_cnt: int, mismatch_cnt: int, gap_cnt: int) -> bool

Checks if the alignment in the NNNNNN region is good.

Parameters:

Name Type Description Default
match_cnt int

The number of matches in the NNNNNN region.

required
mismatch_cnt int

The number of mismatches in the NNNNNN region.

required
gap_cnt int

The number of gaps in the NNNNNN region.

required

Returns:

Name Type Description
bool bool

True if the alignment in the NNNNNN region is good, False otherwise.

Source code in src/capfinder/find_ote_train.py
def has_good_aln_in_n_region(match_cnt: int, mismatch_cnt: int, gap_cnt: int) -> bool:
    """
    Checks if the alignment in the NNNNNN region is good.

    Args:
        match_cnt (int): The number of matches in the NNNNNN region.
        mismatch_cnt (int): The number of mismatches in the NNNNNN region.
        gap_cnt (int): The number of gaps in the NNNNNN region.

    Returns:
        bool: True if the alignment in the NNNNNN region is good, False otherwise.
    """
    # For a good alignment in the NNNNNN region, the number of mismatches should be
    # greater than or equal to the number of gaps
    if mismatch_cnt >= gap_cnt:
        return True
    else:
        return False

has_good_aln_ns_flanking_region(match_cnt: int, mismatch_cnt: int, gap_cnt: int) -> bool

Checks if the alignment in the flanking region before or after the NNNNNN region is good.

Parameters:

Name Type Description Default
match_cnt int

The number of matches in the flanking region.

required
mismatch_cnt int

The number of mismatches in the flanking region.

required
gap_cnt int

The number of gaps in the flanking region.

required

Returns:

Name Type Description
bool bool

True if the alignment in the flanking region is good, False otherwise.

Source code in src/capfinder/find_ote_train.py
def has_good_aln_ns_flanking_region(
    match_cnt: int, mismatch_cnt: int, gap_cnt: int
) -> bool:
    """
    Checks if the alignment in the flanking region before or after the NNNNNN region is good.

    Args:
        match_cnt (int): The number of matches in the flanking region.
        mismatch_cnt (int): The number of mismatches in the flanking region.
        gap_cnt (int): The number of gaps in the flanking region.

    Returns:
        bool: True if the alignment in the flanking region is good, False otherwise.
    """
    if (mismatch_cnt > match_cnt) or (gap_cnt > match_cnt):
        return False
    else:
        return True

make_coordinates(aln_str: str, ref_str: str) -> List[int]

Walk along the alignment string and make an incrementing index where there is a match, mismatch, or deletion. For gaps in the alignment string, it outputs a -1 in the index list.

Parameters:

Name Type Description Default
aln_str str

The alignment string.

required
ref_str str

The reference string.

required

Returns:

Name Type Description
coord_list list

A list of indices corresponding to the alignment string.

Source code in src/capfinder/find_ote_train.py
def make_coordinates(aln_str: str, ref_str: str) -> List[int]:
    """
    Walk along the alignment string and make an incrementing index
    where there is a match, mismatch, or deletion. For gaps in
    the alignment string, it outputs a -1 in the index list.

    Args:
        aln_str (str): The alignment string.
        ref_str (str): The reference string.

    Returns:
        coord_list (list): A list of indices corresponding to the alignment string.
    """
    # Make index coordinates along the alignment string
    coord_list = []
    cnt = 0
    for idx, aln_chr in enumerate(aln_str):
        if aln_chr != " ":
            coord_list.append(cnt)
            cnt += 1
        else:
            if ref_str[idx] != "-":  # handle deletions
                coord_list.append(cnt)
                cnt += 1
            else:
                coord_list.append(-1)

    # Go in reverse in the coord_list, and put -1 of all the places
    # where there is a gap in the alignment string. Break out when the
    # first non-gap character is encountered.
    for idx in range(len(coord_list) - 1, -1, -1):
        if aln_str[idx] == " ":
            coord_list[idx] = -1
        else:
            break
    return coord_list

process_fastq_file(fastq_filepath: str, reference: str, cap0_pos: int, num_processes: int, output_folder: str) -> None

Process a single FASTQ file. The function reads the FASTQ file, and processes each read in parallel.

Parameters:

Name Type Description Default
fastq_filepath str

The path to the FASTQ file.

required
reference str

The reference sequence to align the read to.

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where worker output files will be stored.

required

Returns:

Type Description
None

None

Source code in src/capfinder/find_ote_train.py
def process_fastq_file(
    fastq_filepath: str,
    reference: str,
    cap0_pos: int,
    num_processes: int,
    output_folder: str,
) -> None:
    """
    Process a single FASTQ file. The function reads the FASTQ file, and processes each read in parallel.

    Args:
        fastq_filepath (str): The path to the FASTQ file.
        reference (str): The reference sequence to align the read to.
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where worker output files will be stored.

    Returns:
        None
    """

    # Make output file name
    directory, filename = os.path.split(fastq_filepath)
    filename_no_extension, extension = os.path.splitext(filename)
    os.path.join(output_folder, f"{filename_no_extension}.txt")

    with file_opener(fastq_filepath) as fastq_file:
        records = list(SeqIO.parse(fastq_file, "fastq"))
        total_records = len(records)

        with WorkerPool(n_jobs=num_processes) as pool:
            results = pool.map(
                process_read,
                [(record, reference, cap0_pos) for record in records],
                iterable_len=total_records,
                progress_bar=True,
            )
            write_csv(
                results,
                output_filepath=os.path.join(
                    output_folder,
                    filename_no_extension + "_train_ote_search_results.csv",
                ),
            )

process_fastq_folder(folder_path: str, reference: str, cap0_pos: int, num_processes: int, output_folder: str) -> None

Process all FASTQ files in a folder. The function reads all FASTQ files in a folder and feeds one FASTQ file at a time to a processing function that processes the reads in that file in parallel.

Parameters:

Name Type Description Default
folder_path str

The path to the folder containing FASTQ files.

required
reference str

The reference sequence to align the read to.

required
cap0_pos int

The position of the first cap base (N1) in the reference sequence (0-indexed).

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where worker output files will be stored.

required

Returns:

Type Description
None

None

Source code in src/capfinder/find_ote_train.py
def process_fastq_folder(
    folder_path: str,
    reference: str,
    cap0_pos: int,
    num_processes: int,
    output_folder: str,
) -> None:
    """
    Process all FASTQ files in a folder. The function reads all FASTQ files in a folder
    and feeds one FASTQ file at a time to a processing function that processes the reads
    in that file in parallel.

    Args:
        folder_path (str): The path to the folder containing FASTQ files.
        reference (str): The reference sequence to align the read to.
        cap0_pos (int): The position of the first cap base (N1) in the reference sequence (0-indexed).
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where worker output files will be stored.

    Returns:
        None
    """
    # List all files in the folder
    for root, _, files in os.walk(folder_path):
        for file_name in files:
            if file_name.endswith((".fastq", ".fastq.gz")):
                file_path = os.path.join(root, file_name)
                process_fastq_file(
                    file_path, reference, cap0_pos, num_processes, output_folder
                )

process_read(record: Any, reference: str, cap0_pos: int) -> Dict[str, Any]

Process a single read from a FASTQ file. The function aligns the read to the reference, and checks if the alignment in the NNNNNN region and the flanking regions is good. If the alignment is good, then the function returns the read ID, alignment score, and the positions of the left flanking region, cap0 base, and the right flanking region in the read's FASTQ sequence. If the alignment is bad, then the function returns the read ID, alignment score, and the reason why the alignment is bad.

Parameters:

Name Type Description Default
record SeqRecord

A single read from a FASTQ file.

required
reference str

The reference sequence to align the read to.

required
cap0_pos int

The position of the first cap base in the reference sequence (0-indexed).

required

Returns:

Name Type Description
out_ds dict

A dictionary containing the following keys:
read_id (str): The identifier of the sequence read.
read_type (str): The type of the read, which can be 'good' or 'bad'.
reason (str or None): The reason for the failed alignment, if available.
alignment_score (float): The alignment score for the read.
left_flanking_region_start_fastq_pos (int or None): The starting position of the left flanking region in the FASTQ file, if available.
cap0_read_fastq_pos (int or None): The position of the cap's N1 base in the FASTQ file (0-indexed), if available.
right_flanking_region_start_fastq_pos (int or None): The starting position of the right flanking region in the FASTQ file, if available.

Source code in src/capfinder/find_ote_train.py
def process_read(record: Any, reference: str, cap0_pos: int) -> Dict[str, Any]:
    """
    Process a single read from a FASTQ file. The function aligns the read to the reference,
    and checks if the alignment in the NNNNNN region and the flanking regions is good. If the
    alignment is good, then the function returns the read ID, alignment score, and the
    positions of the left flanking region, cap0 base, and the right flanking region in the
    read's FASTQ sequence. If the alignment is bad, then the function returns the read ID,
    alignment score, and the reason why the alignment is bad.

    Args:
        record (SeqRecord): A single read from a FASTQ file.
        reference (str): The reference sequence to align the read to.
        cap0_pos (int): The position of the first cap base in the reference sequence (0-indexed).

    Returns:
        out_ds (dict): A dictionary containing the following keys:
            read_id (str): The identifier of the sequence read.
            read_type (str): The type of the read, which can be 'good' or 'bad'
            reason (str or None): The reason for the failed alignment, if available.
            alignment_score (float): The alignment score for the read.
            left_flanking_region_start_fastq_pos (int or None): The starting position of the left flanking region
            in the FASTQ file, if available.
            cap0_read_fastq_pos (int or None): The position of the caps N1 base in the FASTQ file (0-indexed), if available.
            right_flanking_region_start_fastq_pos (int or None): The starting position of the right flanking region
            in the FASTQ file, if available.
    """
    # Get alignment
    sequence = str(record.seq)
    with contextlib.redirect_stdout(None):
        qry_str, aln_str, ref_str, aln_score = align(
            query_seq=sequence, target_seq=reference, pretty_print_alns=False
        )

    # define a data structure to return when the read OTE is not found
    out_ds_failed = {
        "read_id": record.id,
        "read_type": "bad",
        "reason": None,
        "alignment_score": aln_score,
        "left_flanking_region_start_fastq_pos": None,
        "cap0_read_fastq_pos": None,
        "right_flanking_region_start_fastq_pos": None,
        "roi_fasta": None,
    }

    # For low quality alignments, return None
    if aln_score < 20:
        out_ds_failed["reason"] = "low_aln_score"
        return out_ds_failed

    # Make index coordinates along the reference
    coord_list = make_coordinates(aln_str, ref_str)

    # Check if the first cap base is in the coordinates list. If not then
    # the alignment did not even reach the cap, so the read is bad.
    try:
        cap0_idx = coord_list.index(
            cap0_pos
        )  # cap0 position in the reference with gaps
    except Exception:
        out_ds_failed["reason"] = "aln_does_not_reach_the_cap_base"
        return out_ds_failed

    # Check if the NNNNNN region in the reference has matches in it
    # 1. First find the end index of the NNNNNN region
    try:
        n_region_end_idx = coord_list.index(cap0_pos + N_REGION_LEN - 1)
    except Exception:
        out_ds_failed["reason"] = "aln_does_not_reach_nnnnnn_region"
        return out_ds_failed

    # 2. Define regions in which to check for good alignment
    nnn_region = (cap0_idx, n_region_end_idx + 1)
    before_nnn_region = (cap0_idx - BEFORE_N_REGION_WINDOW_LEN, cap0_idx)
    after_nnn_region = (
        n_region_end_idx + 1,
        n_region_end_idx + 1 + AFTER_N_REGION_WINDOW_LEN,
    )

    # 3. Extract alignment strings for each region
    aln_str_nnn_region = aln_str[nnn_region[0] : nnn_region[1]]
    aln_str_before_nnn_region = aln_str[before_nnn_region[0] : before_nnn_region[1]]
    aln_str_after_nnn_region = aln_str[after_nnn_region[0] : after_nnn_region[1]]

    # 4. Count matches, mismatches, and gaps in each region
    n_match_cnt, n_mismatch_cnt, n_gap_cnt = cnt_match_mismatch_gaps(aln_str_nnn_region)
    bn_match_cnt, bn_mismatch_cnt, bn_gap_cnt = cnt_match_mismatch_gaps(
        aln_str_before_nnn_region
    )
    an_match_cnt, an_mismatch_cnt, an_gap_cnt = cnt_match_mismatch_gaps(
        aln_str_after_nnn_region
    )

    # 5. Are there good alignments in the NNN region and the regions flanking it?
    has_good_aln_in_nnn_region = has_good_aln_in_n_region(
        n_match_cnt, n_mismatch_cnt, n_gap_cnt
    )
    has_good_aln_before_n_region = has_good_aln_ns_flanking_region(
        bn_match_cnt, bn_mismatch_cnt, bn_gap_cnt
    )
    has_good_aln_after_n_region = has_good_aln_ns_flanking_region(
        an_match_cnt, an_mismatch_cnt, an_gap_cnt
    )

    # 6. If all three alignments are good, then the read has a good and reliable OTE
    if not (
        has_good_aln_before_n_region
        and has_good_aln_in_nnn_region
        and has_good_aln_after_n_region
    ):
        out_ds_failed["reason"] = "111"  # 111 means all three regions are good
        if not (has_good_aln_before_n_region):
            reason_list = list(out_ds_failed["reason"])
            reason_list[0] = "0"
            out_ds_failed["reason"] = "".join(reason_list)
        if not (has_good_aln_in_nnn_region):
            reason_list = list(out_ds_failed["reason"])
            reason_list[1] = "0"
            out_ds_failed["reason"] = "".join(reason_list)
        if not (has_good_aln_after_n_region):
            reason_list = list(out_ds_failed["reason"])
            reason_list[2] = "0"
            out_ds_failed["reason"] = "".join(reason_list)
        return out_ds_failed

    # Find the position of the cap N1 base in the read's sequence (0-based indexing).
    # Note: str.count("") returns len(str) + 1, so this expression equals the number
    # of read bases (gaps removed) preceding the cap column in the alignment.
    cap0_read_fastq_pos = qry_str[:cap0_idx].replace("-", "").count("") - 1

    # Find the index of first base of the left flanking region
    left_flanking_region_start_ref_idx = cap0_idx - NUM_CAP_FLANKING_BASES
    left_flanking_region_start_fastq_pos = (
        qry_str[:left_flanking_region_start_ref_idx].replace("-", "").count("") - 1
    )
    right_flanking_region_end_ref_idx = cap0_idx + 1 + NUM_CAP_FLANKING_BASES
    right_flanking_region_start_fastq_pos = (
        qry_str[:right_flanking_region_end_ref_idx].replace("-", "").count("") - 1
    )
    roi_fasta = sequence[
        left_flanking_region_start_fastq_pos:right_flanking_region_start_fastq_pos
    ]

    out_ds_passed = {
        "read_id": record.id,
        "read_type": "good",
        "reason": "111",
        "alignment_score": aln_score,
        "left_flanking_region_start_fastq_pos": left_flanking_region_start_fastq_pos,
        "cap0_read_fastq_pos": cap0_read_fastq_pos,
        "right_flanking_region_start_fastq_pos": right_flanking_region_start_fastq_pos,
        "roi_fasta": roi_fasta,
    }

    return out_ds_passed
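
A minimal usage sketch (not part of the package): iterate a FASTQ file with Biopython's SeqIO and screen each read with process_read. The FASTQ path, reference string, and cap0_pos value below are placeholders.

from Bio import SeqIO  # assumed dependency for FASTQ parsing

reference = "GCTT...NNNNNN...TGGT"  # placeholder OTE reference containing the NNNNNN cap region
cap0_pos = 52  # placeholder 0-based index of the first N base in the reference

good_reads = []
for record in SeqIO.parse("reads.fastq", "fastq"):  # placeholder FASTQ path
    result = process_read(record, reference, cap0_pos)
    if result["read_type"] == "good":
        good_reads.append(result)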

write_csv(resutls_list: List[dict], output_filepath: str) -> None

Take a list of dictionaries and write them to a CSV file.

Parameters:

Name Type Description Default
resutls_list list

A list of dictionaries.

required
output_filepath str

The path to the output CSV file.

required

Returns:

Type Description
None

None

Source code in src/capfinder/find_ote_train.py
def write_csv(resutls_list: List[dict], output_filepath: str) -> None:
    """
    Take a list of dictionaries and write them to a CSV file.

    Args:
        resutls_list (list): A list of dictionaries.
        output_filepath (str): The path to the output CSV file.

    Returns:
        None
    """
    # Specify the CSV column headers based on the dictionary keys
    fieldnames = resutls_list[0].keys()

    # Create and write to the CSV file
    with open(output_filepath, "w", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

        # Write the header row
        writer.writeheader()

        # Write the data rows
        writer.writerows(resutls_list)
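
Continuing the process_read sketch above, the per-read dictionaries can be written out directly; the output path is a placeholder. Note that the list must be non-empty, since the CSV header is taken from the keys of the first dictionary.

results = [process_read(rec, reference, cap0_pos) for rec in SeqIO.parse("reads.fastq", "fastq")]
if results:
    write_csv(results, "ote_search_results.csv")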

index

We cannot randomly access a record in a BAM file; we can only iterate through it. That is our starting point. For each record in the BAM file, we need to find the corresponding record in a POD5 file, and for that we need a mapping between POD5 files and read_ids. This is why we build an index of POD5 files. This module builds such an index and stores it in a SQLite database.

Author: Adnan M. Niazi Date: 2024-02-28

fetch_filepath_using_filename(conn: sqlite3.Connection, cursor: sqlite3.Cursor, pod5_filename: str) -> Any

Retrieve the pod5_filepath based on pod5_filename from the database.

Parameters:

Name Type Description Default
conn Connection

Connection object for the database.

required
cursor Cursor

Cursor object for the database.

required
pod5_filename str

The pod5_filename to be searched for.

required

Returns:

Name Type Description
pod5_filepath Any

The corresponding pod5_filepath if found, else None.

Source code in src/capfinder/index.py
def fetch_filepath_using_filename(
    conn: sqlite3.Connection, cursor: sqlite3.Cursor, pod5_filename: str
) -> Any:
    """
    Retrieve the pod5_filepath based on pod5_filename from the database.

    Params:
        conn (sqlite3.Connection): Connection object for the database.
        cursor (sqlite3.Cursor): Cursor object for the database.
        pod5_filename (str): The pod5_filename to be searched for.

    Returns:
        pod5_filepath (Any): The corresponding pod5_filepath if found, else None.
    """
    try:
        # Execute the SQL query to retrieve the pod5_filepath based on pod5_filename
        cursor.execute(
            "SELECT pod5_filepath FROM pod5_index WHERE pod5_filename = ?",
            (pod5_filename,),
        )
        result = cursor.fetchone()

        # Return the result (pod5_filepath) or None if not found
        return result[0] if result else None

    except sqlite3.Error as e:
        logger.error(f"Error: {e}")
        return None

find_database_size(database_path: str) -> Any

Find the number of records in the database.

Parameters:

Name Type Description Default
database_path str

Path to the database.

required

Returns:

Name Type Description
size Any

Number of records in the database.

Source code in src/capfinder/index.py
def find_database_size(database_path: str) -> Any:
    """
    Find the number of records in the database.

    Params:
        database_path (str): Path to the database.

    Returns:
        size (Any): Number of records in the database.
    """
    conn = sqlite3.connect(database_path)
    cursor = conn.cursor()
    cursor.execute("SELECT COUNT(*) FROM pod5_index")
    result = cursor.fetchone()
    size = result[0] if result is not None else 0
    return size

generate_pod5_path_and_name(pod5_path: str) -> Generator[Tuple[str, str], None, None]

Traverse the directory and yield all the names+extension and fullpaths of the pod5 files.

Parameters:

Name Type Description Default
pod5_path str

Path to a POD5 file/directory of POD5 files.

required

Yields:

Type Description
str

Tuple[str, str]: Tuple containing the name+extension and full path of a POD5 file.

Source code in src/capfinder/index.py
def generate_pod5_path_and_name(
    pod5_path: str,
) -> Generator[Tuple[str, str], None, None]:
    """Traverse the directory and yield all the names+extension and
    fullpaths of the pod5 files.

    Params:
        pod5_path (str): Path to a POD5 file/directory of POD5 files.

    Yields:
        Tuple[str, str]: Tuple containing the name+extension and full path of a POD5 file.
    """

    if os.path.isdir(pod5_path):
        for root, _dirs, files in os.walk(pod5_path):
            for file in files:
                if file.endswith(".pod5"):
                    yield (file, os.path.join(root, file))
    elif os.path.isfile(pod5_path) and pod5_path.endswith(".pod5"):
        root = os.path.dirname(pod5_path)
        file = os.path.basename(pod5_path)
        yield (file, os.path.join(root, file))

index(pod5_path: str, output_dir: str) -> None

Builds an index mapping read_ids to POD5 file paths.

Parameters:

Name Type Description Default
pod5_path str

Path to a POD5 file or directory of POD5 files.

required
output_dir str

Path where database.db file is written to.

required

Returns:

Type Description
None

None

Source code in src/capfinder/index.py
def index(pod5_path: str, output_dir: str) -> None:
    """
    Builds an index mapping read_ids to POD5 file paths.

    Params:
        pod5_path (str): Path to a POD5 file or directory of POD5 files.
        output_dir (str): Path where database.db file is written to.

    Returns:
        None
    """

    database_path = os.path.join(output_dir, "database.db")
    cursor, conn = initialize_database(database_path)
    total_files = sum(1 for _ in generate_pod5_path_and_name(pod5_path))
    logger.info(f"Indexing {total_files} POD5 files")
    for data in tqdm(
        generate_pod5_path_and_name(pod5_path),
        total=total_files,
        desc="",
        unit="files",
    ):
        write_database(data, cursor, conn)
    logger.info("Indexing complete")
    conn.close()

initialize_database(database_path: str) -> Tuple[sqlite3.Cursor, sqlite3.Connection]

Initializes the database connection based on the database path.

Parameters:

Name Type Description Default
database_path str

Path to the database.

required

Returns:

Name Type Description
cursor Cursor

Cursor object for the database.

conn Connection

Connection object for the database.

Source code in src/capfinder/index.py
def initialize_database(
    database_path: str,
) -> Tuple[sqlite3.Cursor, sqlite3.Connection]:
    """
    Initializes the database connection based on the database path.

    Params:
        database_path (str): Path to the database.

    Returns:
        cursor (sqlite3.Cursor): Cursor object for the database.
        conn (sqlite3.Connection): Connection object for the database.
    """
    conn = sqlite3.connect(database_path)
    cursor = conn.cursor()
    cursor.execute(
        """CREATE TABLE IF NOT EXISTS pod5_index (pod5_filename TEXT PRIMARY KEY, pod5_filepath TEXT)"""
    )
    return cursor, conn

write_database(data: Tuple[str, str], cursor: sqlite3.Cursor, conn: sqlite3.Connection) -> None

Write the index to a database.

Parameters:

Name Type Description Default
data Tuple[str, str]

A (pod5_filename, pod5_filepath) tuple.

required
cursor Cursor

Cursor object for the database.

required
conn Connection

Connection object for the database.

required

Returns:

Type Description
None

None

Source code in src/capfinder/index.py
def write_database(
    data: Tuple[str, str], cursor: sqlite3.Cursor, conn: sqlite3.Connection
) -> None:
    """
    Write the index to a database.

    Params:
        data (Tuple[str, str]): A (pod5_filename, pod5_filepath) tuple.
        cursor (sqlite3.Cursor): Cursor object for the database.
        conn (sqlite3.Connection): Connection object for the database.

    Returns:
        None
    """
    cursor.execute("INSERT or REPLACE INTO pod5_index VALUES (?, ?)", data)
    conn.commit()
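
A minimal usage sketch, assuming a directory of POD5 files; all paths and file names are placeholders. It builds the index once and then resolves a POD5 filename to its full path.

index(pod5_path="/data/run1/pod5", output_dir="/data/run1")

cursor, conn = initialize_database("/data/run1/database.db")
pod5_filepath = fetch_filepath_using_filename(conn, cursor, "batch_0.pod5")
print(pod5_filepath)  # full path of batch_0.pod5, or None if it is not in the index
conn.close()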

inference

batched_inference(dataset: tf.data.Dataset, model: keras.Model, output_dir: str, csv_file_path: str) -> str

Perform batched inference on a dataset using a given model and save predictions to a CSV file.

Parameters:

Name Type Description Default
dataset Dataset

The input dataset to perform inference on.

required
model Model

The Keras model to use for making predictions.

required
output_dir str

The directory where the output CSV file will be saved.

required
csv_file_path str

Path to the original CSV file used to create the dataset.

required

Returns:

Name Type Description
str str

The path to the output CSV file containing the predictions.

Source code in src/capfinder/inference.py
@task(cache_key_fn=task_input_hash)
def batched_inference(
    dataset: tf.data.Dataset,
    model: keras.Model,
    output_dir: str,
    csv_file_path: str,
) -> str:
    """
    Perform batched inference on a dataset using a given model and save predictions to a CSV file.

    Args:
        dataset (tf.data.Dataset): The input dataset to perform inference on.
        model (keras.Model): The Keras model to use for making predictions.
        output_dir (str): The directory where the output CSV file will be saved.
        csv_file_path (str): Path to the original CSV file used to create the dataset.

    Returns:
        str: The path to the output CSV file containing the predictions.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_csv_path = os.path.join(output_dir, "predictions.csv")

    total_reads = count_csv_rows(csv_file_path)
    logger.info(f"Total reads to perform cap predictions on: {total_reads}")

    with open(output_csv_path, "w", newline="") as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(["read_id", "predicted_cap"])

        pbar = tqdm(total=total_reads, unit="reads")

        processed_reads = 0
        try:
            for batch in dataset:
                x, _, read_id = batch
                preds = model.predict(x, verbose=0)
                batch_pred_classes = tf.argmax(preds, axis=1).numpy()

                rows_to_write = [
                    [rid.decode("utf-8"), map_cap_int_to_name(pred_class)]
                    for rid, pred_class in zip(read_id.numpy(), batch_pred_classes)
                ]

                csvwriter.writerows(rows_to_write)

                batch_size = len(read_id)
                processed_reads += batch_size
                pbar.update(batch_size)

                if processed_reads >= total_reads:
                    break
        except tf.errors.OutOfRangeError:
            logger.warning(
                "Dataset iterator exhausted before processing all expected reads."
            )
        finally:
            pbar.close()

    logger.info(f"Batched inference completed! Processed {processed_reads} reads.")
    if processed_reads != total_reads:
        logger.warning(
            f"Number of processed samples ({processed_reads}) "
            f"differs from expected total ({total_reads})."
        )

    return output_csv_path
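
An illustrative sketch of running this step by hand, outside the Prefect flow, by calling the task's underlying function via .fn (a Prefect 2 convenience); normally batched_inference is invoked from prepare_inference_data. The CSV path and output directory are placeholders.

dataset = create_dataset("cap_signal_data.csv", target_length=500, batch_size=32, dtype="float16")
model = get_model()  # bundled pre-trained model
predictions_csv = batched_inference.fn(
    dataset, model, output_dir="predictions", csv_file_path="cap_signal_data.csv"
)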

collate_bam_pod5_wrapper(bam_filepath: str, pod5_dir: str, num_cpus: int, reference: str, cap_class: int, cap0_pos: int, train_or_test: str, plot_signal: bool, output_dir: str) -> tuple[str, str]

Wrapper for collating BAM and POD5 files.

Parameters:

Name Type Description Default
bam_filepath str

Path to the BAM file.

required
pod5_dir str

Directory containing POD5 files.

required
num_cpus int

Number of CPUs to use for processing.

required
reference str

Reference sequence.

required
cap_class int

CAP class identifier.

required
cap0_pos int

Position of CAP0.

required
train_or_test str

Indicates whether data is for training or testing.

required
plot_signal bool

Flag to plot the signal.

required
output_dir str

Directory where output files will be saved.

required

Returns:

Type Description
tuple[str, str]

tuple[str, str]: Paths to the data and metadata files.

Source code in src/capfinder/inference.py
@task(cache_key_fn=task_input_hash)
def collate_bam_pod5_wrapper(
    bam_filepath: str,
    pod5_dir: str,
    num_cpus: int,
    reference: str,
    cap_class: int,
    cap0_pos: int,
    train_or_test: str,
    plot_signal: bool,
    output_dir: str,
) -> tuple[str, str]:
    """
    Wrapper for collating BAM and POD5 files.

    Args:
        bam_filepath (str): Path to the BAM file.
        pod5_dir (str): Directory containing POD5 files.
        num_cpus (int): Number of CPUs to use for processing.
        reference (str): Reference sequence.
        cap_class (int): CAP class identifier.
        cap0_pos (int): Position of CAP0.
        train_or_test (str): Indicates whether data is for training or testing.
        plot_signal (bool): Flag to plot the signal.
        output_dir (str): Directory where output files will be saved.

    Returns:
        tuple[str, str]: Paths to the data and metadata files.
    """
    data_path, metadata_path = collate_bam_pod5(
        bam_filepath=bam_filepath,
        pod5_dir=pod5_dir,
        num_processes=num_cpus,
        reference=reference,
        cap_class=cap_class,
        cap0_pos=cap0_pos,
        train_or_test=train_or_test,
        plot_signal=plot_signal,
        output_dir=output_dir,
    )
    return data_path, metadata_path

count_csv_rows(file_path: str) -> int

Quickly count the number of rows in a CSV file.

Parameters:

Name Type Description Default
file_path str

Path to the CSV file.

required

Returns:

Name Type Description
int int

Number of rows in the CSV file (excluding the header).

Source code in src/capfinder/inference.py
def count_csv_rows(file_path: str) -> int:
    """
    Quickly count the number of rows in a CSV file.

    Args:
        file_path (str): Path to the CSV file.

    Returns:
        int: Number of rows in the CSV file (excluding the header).
    """
    with Path(file_path).open() as f:
        return sum(1 for _ in f) - 1  # Subtract 1 to exclude the header

custom_cache_key_fn(context: TaskRunContext, parameters: dict) -> str

Generate a custom cache key based on input parameters.

Parameters:

Name Type Description Default
context TaskRunContext

Prefect context (unused in this function).

required
parameters dict

Dictionary of parameters used for cache key generation.

required

Returns:

Name Type Description
str str

The generated cache key.

Source code in src/capfinder/inference.py
def custom_cache_key_fn(context: TaskRunContext, parameters: dict) -> str:
    """
    Generate a custom cache key based on input parameters.

    Args:
        context (TaskRunContext): Prefect context (unused in this function).
        parameters (dict): Dictionary of parameters used for cache key generation.

    Returns:
        str: The generated cache key.
    """
    dataset_hash = hashlib.md5(str(parameters["dataset"]).encode()).hexdigest()
    model_hash = hashlib.md5(str(parameters["model"]).encode()).hexdigest()
    output_dir_hash = hashlib.md5(parameters["output_dir"].encode()).hexdigest()
    combined_hash = hashlib.md5(
        f"{dataset_hash}_{model_hash}_{output_dir_hash}".encode()
    ).hexdigest()
    return combined_hash

generate_report_wrapper(metadata_file: str, predictions_file: str, output_csv: str, output_html: str) -> None

Wrapper for generating the report.

Parameters:

Name Type Description Default
metadata_file str

Path to the metadata file.

required
predictions_file str

Path to the predictions file.

required
output_csv str

Path to save the output CSV.

required
output_html str

Path to save the output HTML.

required
Source code in src/capfinder/inference.py
@task(cache_key_fn=task_input_hash)
def generate_report_wrapper(
    metadata_file: str, predictions_file: str, output_csv: str, output_html: str
) -> None:
    """
    Wrapper for generating the report.

    Args:
        metadata_file (str): Path to the metadata file.
        predictions_file (str): Path to the predictions file.
        output_csv (str): Path to save the output CSV.
        output_html (str): Path to save the output HTML.
    """
    generate_report(
        metadata_file,
        predictions_file,
        output_csv,
        output_html,
    )
    os.remove(predictions_file)

get_model(model_path: Optional[str] = None, load_optimizer: bool = False) -> keras.Model

Load and return a model from the given model path or use the default model.

Parameters:

Name Type Description Default
model_path Optional[str]

Path to the custom model file. If None, use the default model.

None
load_optimizer bool

Whether to load the optimizer with the model.

False

Returns:

Type Description
Model

keras.Model: The loaded Keras model.

Source code in src/capfinder/inference.py
def get_model(
    model_path: Optional[str] = None, load_optimizer: bool = False
) -> keras.Model:
    """
    Load and return a model from the given model path or use the default model.

    Args:
        model_path (Optional[str]): Path to the custom model file. If None, use the default model.
        load_optimizer (bool): Whether to load the optimizer with the model.

    Returns:
        keras.Model: The loaded Keras model.
    """
    if model_path is None:
        model_file = resources.files(model_module).joinpath("cnn_lstm-classifier.keras")
        with resources.as_file(model_file) as default_model_path:
            model = keras.models.load_model(default_model_path, compile=False)
    else:
        model = keras.models.load_model(model_path, compile=False)

    logger.info(
        f"Model loaded successfully from {'default path' if model_path is None else model_path}"
    )
    return model
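
For example (the custom model path is a placeholder):

model = get_model()                       # load the bundled classifier
custom = get_model("my_retrained.keras")  # or a custom .keras model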

predict_cap_types(bam_filepath: str, pod5_dir: str, num_cpus: int, output_dir: str, dtype: DtypeLiteral, reference: str = 'GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT', cap0_pos: int = 52, train_or_test: str = 'test', plot_signal: bool = False, cap_class: int = -99, target_length: int = 500, batch_size: int = 32, custom_model_path: Optional[str] = None, debug_code: bool = False, refresh_cache: bool = False, formatted_command: Optional[str] = None) -> None

Predict CAP types by preparing the inference data and running the prediction workflow.

Parameters:

Name Type Description Default
bam_filepath str

Path to the BAM file.

required
pod5_dir str

Directory containing POD5 files.

required
num_cpus int

Number of CPUs to use for processing.

required
output_dir str

Directory where output files will be saved.

required
dtype DtypeLiteral

Data type for the features.

required
reference str

Reference sequence.

'GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT'
cap0_pos int

Position of CAP0.

52
train_or_test str

Indicates whether data is for training or testing.

'test'
plot_signal bool

Flag to plot the signal.

False
cap_class int

CAP class identifier.

-99
target_length int

Length of the target sequence.

500
batch_size int

Size of the data batches.

32
custom_model_path Optional[str]

Path to a custom model file. If None, use the default model.

None
debug_code bool

Flag to enable debugging information in logs.

False
refresh_cache bool

Flag to refresh cached data.

False
formatted_command Optional[str]

The formatted command string to be logged.

None
Source code in src/capfinder/inference.py
def predict_cap_types(
    bam_filepath: str,
    pod5_dir: str,
    num_cpus: int,
    output_dir: str,
    dtype: DtypeLiteral,
    reference: str = "GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT",
    cap0_pos: int = 52,
    train_or_test: str = "test",
    plot_signal: bool = False,
    cap_class: int = -99,
    target_length: int = 500,
    batch_size: int = 32,
    custom_model_path: Optional[str] = None,
    debug_code: bool = False,
    refresh_cache: bool = False,
    formatted_command: Optional[str] = None,
) -> None:
    """
    Predict CAP types by preparing the inference data and running the prediction workflow.

    Args:
        bam_filepath (str): Path to the BAM file.
        pod5_dir (str): Directory containing POD5 files.
        num_cpus (int): Number of CPUs to use for processing.
        output_dir (str): Directory where output files will be saved.
        dtype (DtypeLiteral): Data type for the features.
        reference (str): Reference sequence.
        cap0_pos (int): Position of CAP0.
        train_or_test (str): Indicates whether data is for training or testing.
        plot_signal (bool): Flag to plot the signal.
        cap_class (int): CAP class identifier.
        target_length (int): Length of the target sequence.
        batch_size (int): Size of the data batches.
        custom_model_path (Optional[str]): Path to a custom model file. If None, use the default model.
        debug_code (bool): Flag to enable debugging information in logs.
        refresh_cache (bool): Flag to refresh cached data.
        formatted_command (Optional[str]): The formatted command string to be logged.
    """
    log_filepath = configure_logger(
        os.path.join(output_dir, "logs"), show_location=debug_code
    )
    configure_prefect_logging(show_location=debug_code)
    version_info = version("capfinder")
    log_header(f"Using Capfinder v{version_info}")
    logger.info(formatted_command)
    output_csv_path, output_html_path = prepare_inference_data(
        bam_filepath,
        pod5_dir,
        num_cpus,
        output_dir,
        dtype,
        reference,
        cap0_pos,
        train_or_test,
        plot_signal,
        cap_class,
        target_length,
        batch_size,
        custom_model_path,
        debug_code,
        refresh_cache,
    )
    grey = "\033[90m"
    reset = "\033[0m"
    log_output(
        f"Cap predictions have been saved to the following path:\n {grey}{output_csv_path}{reset}\nThe log file has been saved to:\n {grey}{log_filepath}{reset}\nThe analysis report has been saved to:\n {grey}{output_html_path}{reset}"
    )
    log_header("Processing finished!")
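
A minimal end-to-end sketch using the default reference, cap position, and bundled model; the BAM path, POD5 directory, and output directory are placeholders.

predict_cap_types(
    bam_filepath="aligned_reads.bam",
    pod5_dir="pod5/",
    num_cpus=8,
    output_dir="capfinder_output/",
    dtype="float16",
)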

prepare_inference_data(bam_filepath: str, pod5_dir: str, num_cpus: int, output_dir: str, dtype: DtypeLiteral, reference: str = 'GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT', cap0_pos: int = 52, train_or_test: str = 'test', plot_signal: bool = False, cap_class: int = -99, target_length: int = 500, batch_size: int = 32, custom_model_path: Optional[str] = None, debug_code: bool = False, refresh_cache: bool = False) -> tuple[str, str]

Prepare inference data by processing BAM and POD5 files, and generate features for the model.

Parameters:

Name Type Description Default
bam_filepath str

Path to the BAM file.

required
pod5_dir str

Directory containing POD5 files.

required
num_cpus int

Number of CPUs to use for processing.

required
output_dir str

Directory where output files will be saved.

required
dtype DtypeLiteral

Data type for the features.

required
reference str

Reference sequence.

'GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT'
cap0_pos int

Position of CAP0.

52
train_or_test str

Indicates whether data is for training or testing.

'test'
plot_signal bool

Flag to plot the signal.

False
cap_class int

CAP class identifier.

-99
target_length int

Length of the target sequence.

500
batch_size int

Size of the data batches.

32
custom_model_path Optional[str]

Path to a custom model file. If None, use the default model.

None
debug_code bool

Flag to enable debugging information in logs.

False
refresh_cache bool

Flag to refresh cached data.

False

Returns:

Type Description
tuple[str, str]

tuple[str, str]: Paths to the output CSV and HTML files.

Source code in src/capfinder/inference.py
@flow(name="prepare-inference-data")
def prepare_inference_data(
    bam_filepath: str,
    pod5_dir: str,
    num_cpus: int,
    output_dir: str,
    dtype: DtypeLiteral,
    reference: str = "GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT",
    cap0_pos: int = 52,
    train_or_test: str = "test",
    plot_signal: bool = False,
    cap_class: int = -99,
    target_length: int = 500,
    batch_size: int = 32,
    custom_model_path: Optional[str] = None,
    debug_code: bool = False,
    refresh_cache: bool = False,
) -> tuple[str, str]:
    """
    Prepare inference data by processing BAM and POD5 files, and generate features for the model.

    Args:
        bam_filepath (str): Path to the BAM file.
        pod5_dir (str): Directory containing POD5 files.
        num_cpus (int): Number of CPUs to use for processing.
        output_dir (str): Directory where output files will be saved.
        dtype (DtypeLiteral): Data type for the features.
        reference (str): Reference sequence.
        cap0_pos (int): Position of CAP0.
        train_or_test (str): Indicates whether data is for training or testing.
        plot_signal (bool): Flag to plot the signal.
        cap_class (int): CAP class identifier.
        target_length (int): Length of the target sequence.
        batch_size (int): Size of the data batches.
        custom_model_path (Optional[str]): Path to a custom model file. If None, use the default model.
        debug_code (bool): Flag to enable debugging information in logs.
        refresh_cache (bool): Flag to refresh cached data.

    Returns:
        tuple[str, str]: Paths to the output CSV and HTML files.
    """
    configure_prefect_logging(show_location=debug_code)
    os.makedirs(output_dir, exist_ok=True)

    log_step(1, 5, "Extracting Cap Signal by collating BAM and POD5 files")
    data_path, metadata_path = collate_bam_pod5_wrapper.with_options(
        refresh_cache=refresh_cache
    )(
        bam_filepath=bam_filepath,
        pod5_dir=pod5_dir,
        num_cpus=num_cpus,
        reference=reference,
        cap_class=cap_class,
        cap0_pos=cap0_pos,
        train_or_test=train_or_test,
        plot_signal=plot_signal,
        output_dir=os.path.join(output_dir, "0_raw_cap_signal_data"),
    )

    log_step(2, 5, "Creating TensorFlow dataset")
    dataset = create_dataset(data_path, target_length, batch_size, dtype)

    log_step(
        3, 5, f"Loading the {'custom' if custom_model_path else 'pre-trained'} model"
    )
    model = get_model(custom_model_path)

    log_step(4, 5, "Performing batch inference for cap type prediction")
    predictions_csv_path = batched_inference.with_options(refresh_cache=refresh_cache)(
        dataset,
        model,
        output_dir=os.path.join(output_dir, "1_cap_predictions"),
        csv_file_path=data_path,
    )

    log_step(5, 5, "Generating report")
    output_csv_path = os.path.join(
        output_dir, "1_cap_predictions", "cap_predictions.csv"
    )
    output_html_path = os.path.join(
        output_dir, "1_cap_predictions", "cap_analysis_report.html"
    )
    generate_report_wrapper.with_options(refresh_cache=refresh_cache)(
        metadata_file=metadata_path,
        predictions_file=predictions_csv_path,
        output_csv=output_csv_path,
        output_html=output_html_path,
    )
    return output_csv_path, output_html_path

reconfigure_logging_task(output_dir: str, debug_code: bool) -> None

Reconfigure logging settings for both application and Prefect.

Parameters:

Name Type Description Default
output_dir str

Directory where logs will be saved.

required
debug_code bool

Flag to determine if code locations should be shown in logs.

required
Source code in src/capfinder/inference.py
@task(cache_key_fn=custom_cache_key_fn)
def reconfigure_logging_task(output_dir: str, debug_code: bool) -> None:
    """
    Reconfigure logging settings for both application and Prefect.

    Args:
        output_dir (str): Directory where logs will be saved.
        debug_code (bool): Flag to determine if code locations should be shown in logs.
    """
    configure_logger(output_dir, show_location=debug_code)
    configure_prefect_logging(show_location=debug_code)

inference_data_loader

create_dataset(file_path: str, target_length: int, batch_size: int, dtype: DtypeLiteral) -> tf.data.Dataset

Create a TensorFlow dataset from a CSV file.

Parameters:

Name Type Description Default
file_path str

Path to the CSV file.

required
target_length int

The desired length of the timeseries tensor.

required
batch_size int

The number of samples per batch.

required
dtype DtypeLiteral

The desired data type for the timeseries tensor as a string.

required

Returns:

Type Description
Dataset

tf.data.Dataset: A TensorFlow dataset that yields batches of parsed and formatted data.

Source code in src/capfinder/inference_data_loader.py
def create_dataset(
    file_path: str, target_length: int, batch_size: int, dtype: DtypeLiteral
) -> tf.data.Dataset:
    """
    Create a TensorFlow dataset from a CSV file.

    Args:
        file_path (str): Path to the CSV file.
        target_length (int): The desired length of the timeseries tensor.
        batch_size (int): The number of samples per batch.
        dtype (DtypeLiteral): The desired data type for the timeseries tensor as a string.

    Returns:
        tf.data.Dataset: A TensorFlow dataset that yields batches of parsed and formatted data.
    """
    tf_dtype = get_dtype(dtype)

    dataset = tf.data.Dataset.from_generator(
        lambda: csv_generator(file_path),
        output_signature=(
            tf.TensorSpec(shape=(), dtype=tf.string),
            tf.TensorSpec(shape=(), dtype=tf.string),
            tf.TensorSpec(shape=(), dtype=tf.string),
        ),
    )

    dataset = dataset.map(
        lambda x, y, z: parse_row((x, y, z), target_length, tf_dtype),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    logger.info("Dataset created successfully.")
    return dataset
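
A small usage sketch; the CSV path is a placeholder and is assumed to contain read_id, cap_class, and timeseries columns in that order, as produced by the signal-extraction step.

dataset = create_dataset("cap_signal_data.csv", target_length=500, batch_size=32, dtype="float32")
for timeseries, cap_class, read_id in dataset.take(1):
    print(timeseries.shape)  # (up to batch_size, 500, 1)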

csv_generator(file_path: str, chunk_size: int = 10000) -> Generator[Tuple[str, str, str], None, None]

Generate rows from a CSV file in chunks.

Parameters:

Name Type Description Default
file_path str

Path to the CSV file.

required
chunk_size int

Number of rows to process in each chunk. Defaults to 10000.

10000

Yields:

Type Description
str

Tuple[str, str, str]: A tuple containing read_id, cap_class, and timeseries as strings.

Source code in src/capfinder/inference_data_loader.py
def csv_generator(
    file_path: str, chunk_size: int = 10000
) -> Generator[Tuple[str, str, str], None, None]:
    """
    Generate rows from a CSV file in chunks.

    Args:
        file_path (str): Path to the CSV file.
        chunk_size (int, optional): Number of rows to process in each chunk. Defaults to 10000.

    Yields:
        Tuple[str, str, str]: A tuple containing read_id, cap_class, and timeseries as strings.
    """
    df = pl.scan_csv(file_path)
    total_rows = df.select(pl.count()).collect().item()

    for start in range(0, total_rows, chunk_size):
        # df.slice() clamps to the end of the frame, so no explicit end index is needed
        chunk = df.slice(start, chunk_size).collect()
        for row in chunk.iter_rows():
            yield (str(row[0]), str(row[1]), str(row[2]))

get_dtype(dtype: str) -> tf.DType

Convert a string dtype to its corresponding TensorFlow data type.

Parameters:

Name Type Description Default
dtype str

A string representing the desired data type.

required

Returns:

Type Description
DType

tf.DType: The corresponding TensorFlow data type.

Note:

If an invalid dtype string is provided, a warning is logged and tf.float32 is used as the default.

Source code in src/capfinder/inference_data_loader.py
def get_dtype(dtype: str) -> tf.DType:
    """
    Convert a string dtype to its corresponding TensorFlow data type.

    Args:
        dtype (str): A string representing the desired data type.

    Returns:
        tf.DType: The corresponding TensorFlow data type.

    Note:
        If an invalid dtype string is provided, a warning is logged and
        tf.float32 is returned as the default.
    """
    valid_dtypes = {
        "float16": float16,
        "float32": float32,
        "float64": float64,
    }

    if dtype in valid_dtypes:
        return valid_dtypes[dtype]
    else:
        logger.warning('You provided an invalid dtype. Using "float32" as default.')
        return float32

parse_row(row: Tuple[str, str, str], target_length: int, dtype: tf.DType) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

Parse a row of data and convert it to the appropriate tensor format. Padding and truncation are performed equally on both sides of the time series.

Parameters:

Name Type Description Default
row Tuple[str, str, str]

A tuple containing read_id, cap_class, and timeseries as strings.

required
target_length int

The desired length of the timeseries tensor.

required
dtype DType

The desired data type for the timeseries tensor.

required

Returns:

Type Description
Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

A tuple containing the parsed and formatted tensors for timeseries, cap_class, and read_id.

Source code in src/capfinder/inference_data_loader.py
def parse_row(
    row: Tuple[str, str, str], target_length: int, dtype: tf.DType
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """
    Parse a row of data and convert it to the appropriate tensor format.
    Padding and truncation are performed equally on both sides of the time series.

    Args:
        row (Tuple[str, str, str]): A tuple containing read_id, cap_class, and timeseries as strings.
        target_length (int): The desired length of the timeseries tensor.
        dtype (tf.DType): The desired data type for the timeseries tensor.

    Returns:
        Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: A tuple containing the parsed and formatted tensors for
        timeseries, cap_class, and read_id.
    """
    read_id, cap_class, timeseries = row
    cap_class = tf.strings.to_number(cap_class, out_type=tf.int32)

    # Split the timeseries string and convert to float
    timeseries = tf.strings.split(timeseries, sep=",")
    timeseries = tf.strings.to_number(timeseries, out_type=tf.float32)

    # Get the current length of the timeseries
    current_length = tf.shape(timeseries)[0]

    # Function to pad the timeseries
    def pad_timeseries() -> tf.Tensor:
        pad_amount = target_length - current_length
        pad_left = pad_amount // 2
        pad_right = pad_amount - pad_left
        return tf.pad(
            timeseries,
            [[pad_left, pad_right]],
            constant_values=0.0,
        )

    # Function to truncate the timeseries
    def truncate_timeseries() -> tf.Tensor:
        truncate_amount = current_length - target_length
        truncate_left = truncate_amount // 2
        truncate_right = current_length - (truncate_amount - truncate_left)
        return timeseries[truncate_left:truncate_right]

    # Pad or truncate the timeseries to the target length
    padded = tf.cond(
        current_length >= target_length, truncate_timeseries, pad_timeseries
    )

    padded = tf.reshape(padded, (target_length, 1))

    # Cast to the desired dtype
    if dtype != tf.float32:
        padded = tf.cast(padded, dtype)

    return padded, cap_class, read_id
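
A worked example of the centred padding (run eagerly): a 5-point series padded to target_length=8 receives one zero on the left and two on the right, because pad_left = 3 // 2 = 1.

import tensorflow as tf

padded, cap_class, read_id = parse_row(
    ("read_1", "0", "1.0,2.0,3.0,4.0,5.0"), target_length=8, dtype=tf.float32
)
print(tf.squeeze(padded).numpy())  # [0. 1. 2. 3. 4. 5. 0. 0.]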

logger_config

PrefectHandler

Bases: Handler

A custom logging handler for Prefect that filters and formats log messages.

This handler integrates with Loguru, applies custom formatting, and prevents duplicate log messages.

Source code in src/capfinder/logger_config.py
class PrefectHandler(logging.Handler):
    """
    A custom logging handler for Prefect that filters and formats log messages.

    This handler integrates with Loguru, applies custom formatting, and prevents duplicate log messages.
    """

    def __init__(self, loguru_logger: loguru.Logger, show_location: bool) -> None:
        """
        Initialize the PrefectHandler.

        Args:
            loguru_logger (Logger): The Loguru logger instance to use for logging.
            show_location (bool): Whether to show the source location in log messages.
        """
        super().__init__()
        self.loguru_logger = loguru_logger
        self.show_location = show_location
        self.logged_messages: set[str] = set()
        self.prefix_pattern: re.Pattern = re.compile(
            r"(logging:handle:\d+ - )(\w+\.\w+)"
        )

    def emit(self, record: logging.LogRecord) -> None:
        """
        Emit a log record.

        This method formats the log record, applies custom styling, and logs it using Loguru.
        It also filters out duplicate messages and HTTP status messages.

        Args:
            record (logging.LogRecord): The log record to emit.
        """
        try:
            # Filter out HTTP status messages
            if "HTTP Request:" in record.getMessage():
                return

            level: str = record.levelname
            message: str = self.format(record)
            name: str = record.name
            function: str = record.funcName
            line: int = record.lineno

            # Color the part after "logging:handle:XXXX - " cyan
            colored_message: str = self.prefix_pattern.sub(
                r"\1<cyan>\2</cyan>", message
            )

            # Handle progress bar messages
            if "|" in colored_message and (
                "%" in colored_message or "it/s" in colored_message
            ):
                formatted_message: str = f"Progress: {colored_message}"
            else:
                if self.show_location:
                    formatted_message = f"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - {colored_message}"
                else:
                    formatted_message = colored_message

            # Create a unique identifier for this log message
            message_id: str = message

            # Only log if we haven't seen this message before
            if message_id not in self.logged_messages:
                self.logged_messages.add(message_id)
                self.loguru_logger.opt(depth=1, colors=True).log(
                    level, formatted_message
                )
        except Exception:
            self.handleError(record)

__init__(loguru_logger: loguru.Logger, show_location: bool) -> None

Initialize the PrefectHandler.

Parameters:

Name Type Description Default
loguru_logger Logger

The Loguru logger instance to use for logging.

required
show_location bool

Whether to show the source location in log messages.

required
Source code in src/capfinder/logger_config.py
def __init__(self, loguru_logger: loguru.Logger, show_location: bool) -> None:
    """
    Initialize the PrefectHandler.

    Args:
        loguru_logger (Logger): The Loguru logger instance to use for logging.
        show_location (bool): Whether to show the source location in log messages.
    """
    super().__init__()
    self.loguru_logger = loguru_logger
    self.show_location = show_location
    self.logged_messages: set[str] = set()
    self.prefix_pattern: re.Pattern = re.compile(
        r"(logging:handle:\d+ - )(\w+\.\w+)"
    )

emit(record: logging.LogRecord) -> None

Emit a log record.

This method formats the log record, applies custom styling, and logs it using Loguru. It also filters out duplicate messages and HTTP status messages.

Parameters:

Name Type Description Default
record LogRecord

The log record to emit.

required
Source code in src/capfinder/logger_config.py
def emit(self, record: logging.LogRecord) -> None:
    """
    Emit a log record.

    This method formats the log record, applies custom styling, and logs it using Loguru.
    It also filters out duplicate messages and HTTP status messages.

    Args:
        record (logging.LogRecord): The log record to emit.
    """
    try:
        # Filter out HTTP status messages
        if "HTTP Request:" in record.getMessage():
            return

        level: str = record.levelname
        message: str = self.format(record)
        name: str = record.name
        function: str = record.funcName
        line: int = record.lineno

        # Color the part after "logging:handle:XXXX - " cyan
        colored_message: str = self.prefix_pattern.sub(
            r"\1<cyan>\2</cyan>", message
        )

        # Handle progress bar messages
        if "|" in colored_message and (
            "%" in colored_message or "it/s" in colored_message
        ):
            formatted_message: str = f"Progress: {colored_message}"
        else:
            if self.show_location:
                formatted_message = f"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - {colored_message}"
            else:
                formatted_message = colored_message

        # Create a unique identifier for this log message
        message_id: str = message

        # Only log if we haven't seen this message before
        if message_id not in self.logged_messages:
            self.logged_messages.add(message_id)
            self.loguru_logger.opt(depth=1, colors=True).log(
                level, formatted_message
            )
    except Exception:
        self.handleError(record)

configure_logger(new_log_directory: str = '', show_location: bool = True) -> str

Configure the logger to log to a file in the specified directory.

Source code in src/capfinder/logger_config.py
def configure_logger(new_log_directory: str = "", show_location: bool = True) -> str:
    """Configure the logger to log to a file in the specified directory."""
    global log_directory
    log_directory = new_log_directory if new_log_directory else log_directory

    # Ensure log directory exists
    os.makedirs(log_directory, exist_ok=True)

    # Get current date and time
    now: datetime = datetime.now()
    timestamp: str = now.strftime("%Y-%m-%d_%H-%M-%S")
    app_version: str = get_version()

    # Use the timestamp in the log file name
    log_filename: str = f"capfinder_v{app_version}_{timestamp}.log"
    log_filepath: str = os.path.join(log_directory, log_filename)

    # Remove default handler
    logger.remove()

    # Configure logger to log to both file and console with the same format
    log_format: str = (
        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level:<8}</level> | "
    )
    if show_location:
        log_format += (
            "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
        )
    log_format += "<level>{message}</level>"

    logger.add(log_filepath, format=log_format, colorize=True)
    logger.add(sys.stdout, format=log_format, colorize=True)

    return log_filepath

configure_prefect_logging(show_location: bool = True) -> None

Configure Prefect logging with custom handler and settings.

This function sets up a custom PrefectHandler for all Prefect loggers, configures the root logger, and adjusts logging levels.

Parameters:

Name Type Description Default
show_location bool

Whether to show source location in log messages. Defaults to True.

True
Source code in src/capfinder/logger_config.py
def configure_prefect_logging(show_location: bool = True) -> None:
    """
    Configure Prefect logging with custom handler and settings.

    This function sets up a custom PrefectHandler for all Prefect loggers,
    configures the root logger, and adjusts logging levels.

    Args:
        show_location (bool, optional): Whether to show source location in log messages. Defaults to True.
    """
    # Create a single PrefectHandler instance
    handler = PrefectHandler(logger, show_location)

    # Configure the root logger
    root_logger = logging.getLogger()
    for h in root_logger.handlers[:]:
        root_logger.removeHandler(h)
    root_logger.addHandler(handler)
    root_logger.setLevel(logging.INFO)

    # Configure all Prefect loggers
    prefect_loggers = [
        logging.getLogger(name)
        for name in logging.root.manager.loggerDict
        if name.startswith("prefect")
    ]
    for prefect_logger in prefect_loggers:
        prefect_logger.handlers = [handler]
        prefect_logger.propagate = False
        prefect_logger.setLevel(logging.INFO)

    # Disable httpx logging
    logging.getLogger("httpx").setLevel(logging.WARNING)

get_version() -> str

Get the version of the app from pyproject.toml.

Source code in src/capfinder/logger_config.py
def get_version() -> str:
    """Get the version of the app from pyproject.toml."""
    app_version: str = version("capfinder")
    return app_version

plot

The module helps in plotting the entire read signal, the signal for the ROI, and the base annotations. It also prints alignments. All this information is useful in understanding whether the OTE-finding algorithm is homing in on the correct region of interest (ROI).

The plot is saved as an HTML file.

Author: Adnan M. Niazi Date: 2024-02-28

append_dummy_sequence(fasta_sequence: str, num_left_clipped_bases: int, num_right_clipped_bases: int) -> str

Append/prepend 'H' to the left/right of the FASTA sequence based on soft-clipping counts

Parameters:

Name Type Description Default
fasta_sequence str

FASTA sequence

required
num_left_clipped_bases int

Number of bases soft-clipped from the left

required
num_right_clipped_bases int

Number of bases soft-clipped from the right

required

Returns:

Name Type Description
modified_sequence str

FASTA sequence with 'H' appended/prepended to the left/right

Source code in src/capfinder/plot.py
def append_dummy_sequence(
    fasta_sequence: str, num_left_clipped_bases: int, num_right_clipped_bases: int
) -> str:
    """Append/prepend 'H' to the left/right of the FASTA sequence based on soft-clipping counts

    Args:
        fasta_sequence (str): FASTA sequence
        num_left_clipped_bases (int): Number of bases soft-clipped from the left
        num_right_clipped_bases (int): Number of bases soft-clipped from the right

    Returns:
        modified_sequence (str): FASTA sequence with 'H' appended/prepended to the left/right

    """
    modified_sequence = (
        "H" * num_left_clipped_bases + fasta_sequence + "H" * num_right_clipped_bases
    )
    return modified_sequence
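
For example, a read whose alignment soft-clipped two bases on the left and three on the right:

padded = append_dummy_sequence("ACGT", num_left_clipped_bases=2, num_right_clipped_bases=3)
print(padded)  # "HHACGTHHH"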

process_pod5

Given a read_id and a POD5 file path, this module preprocesses the signal data and extracts the signal for a region of interest (ROI).

Author: Adnan M. Niazi Date: 2024-02-28

clip_extreme_values(z_normalized_data: npt.NDArray[np.float64], num_std_dev: float = 4.0) -> npt.NDArray[np.float64]

Clip extreme values in the Z-score normalized data.

Clips values outside the specified number of standard deviations from the mean. This function takes Z-score normalized data as input, along with an optional parameter to set the number of standard deviations.

Parameters:

Name Type Description Default
z_normalized_data NDArray[float64]

Z-score normalized data.

required
num_std_dev float

Number of standard deviations to use as the limit. Defaults to 4.0.

4.0

Returns:

Type Description
NDArray[float64]

npt.NDArray[np.float64]: Clipped data within the specified range.

Example

z_normalized_data = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
clipped_data = clip_extreme_values(z_normalized_data, num_std_dev=3.0)

Source code in src/capfinder/process_pod5.py
def clip_extreme_values(
    z_normalized_data: npt.NDArray[np.float64], num_std_dev: float = 4.0
) -> npt.NDArray[np.float64]:
    """Clip extreme values in the Z-score normalized data.

    Clips values outside the specified number of standard deviations from
    the mean. This function takes Z-score normalized data as input, along
    with an optional parameter to set the number of standard deviations.

    Params:
        z_normalized_data (npt.NDArray[np.float64]): Z-score normalized data.
        num_std_dev (float, optional): Number of standard deviations to use
            as the limit. Defaults to 4.0.

    Returns:
        npt.NDArray[np.float64]: Clipped data within the specified range.

    Example:
        >>> z_normalized_data = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
        >>> clipped_data = clip_extreme_values(z_normalized_data, num_std_dev=3.0)
    """

    lower_limit = -num_std_dev
    upper_limit = num_std_dev
    clipped_data: npt.NDArray[np.float64] = np.clip(
        z_normalized_data, lower_limit, upper_limit
    )
    return clipped_data

extract_roi_signal(signal: np.ndarray, base_locs_in_signal: npt.NDArray[np.int32], fasta: str, experiment_type: str, start_base_idx_in_fasta: int, end_base_idx_in_fasta: int, num_left_clipped_bases: int) -> ROIData

Extracts the signal data for a region of interest (ROI).

Parameters:

Name Type Description Default
signal ndarray

Signal data.

required
base_locs_in_signal NDArray[int32]

Array of locations of each new base in the signal.

required
fasta str

Fasta sequence of the read.

required
experiment_type str

Type of experiment (rna or dna).

required
start_base_idx_in_fasta int

Index of the first base in the ROI.

required
end_base_idx_in_fasta int

Index of the last base in the ROI.

required
num_left_clipped_bases int

Number of bases clipped from the left.

required

Returns:

Name Type Description
ROIData ROIData

Dictionary containing the ROI signal and fasta sequence.

Source code in src/capfinder/process_pod5.py
def extract_roi_signal(
    signal: np.ndarray,
    base_locs_in_signal: npt.NDArray[np.int32],
    fasta: str,
    experiment_type: str,
    start_base_idx_in_fasta: int,
    end_base_idx_in_fasta: int,
    num_left_clipped_bases: int,
) -> ROIData:
    """
    Extracts the signal data for a region of interest (ROI).

    Params:
        signal (np.ndarray): Signal data.
        base_locs_in_signal (npt.NDArray[np.int32]): Array of locations of each new base in the signal.
        fasta (str): Fasta sequence of the read.
        experiment_type (str): Type of experiment (rna or dna).
        start_base_idx_in_fasta (int): Index of the first base in the ROI.
        end_base_idx_in_fasta (int): Index of the last base in the ROI.
        num_left_clipped_bases (int): Number of bases clipped from the left.

    Returns:
        ROIData: Dictionary containing the ROI signal and fasta sequence.
    """
    signal = preprocess_signal_data(signal)
    roi_data: ROIData = {
        "roi_fasta": None,
        "roi_signal": np.array([], dtype=np.float64),
        "signal_start": None,
        "signal_end": None,
        "plot_signal": signal,  # Assuming signal is defined somewhere
        "roi_signal_for_plot": None,
        "base_locs_in_signal": base_locs_in_signal,  # Assuming base_locs_in_signal is defined somewhere
        "start_base_idx_in_fasta": None,
        "end_base_idx_in_fasta": None,
        "read_id": None,
    }

    # Check for valid inputs
    if end_base_idx_in_fasta is None and start_base_idx_in_fasta is None:
        return roi_data

    if end_base_idx_in_fasta > len(fasta) or start_base_idx_in_fasta < 0:
        return roi_data

    if experiment_type not in ("rna", "dna"):
        return roi_data

    start_base_idx_in_fasta += num_left_clipped_bases
    end_base_idx_in_fasta += num_left_clipped_bases
    if experiment_type == "rna":
        rev_base_locs_in_signal = base_locs_in_signal[::-1]
        signal_end = rev_base_locs_in_signal[start_base_idx_in_fasta - 1]
        signal_start = rev_base_locs_in_signal[end_base_idx_in_fasta - 1]
        roi_data["roi_fasta"] = fasta[
            start_base_idx_in_fasta
            - num_left_clipped_bases : end_base_idx_in_fasta
            - num_left_clipped_bases
        ]
    else:
        # TODO: THE LOGIC IS NOT TESTED
        signal_start = base_locs_in_signal[
            start_base_idx_in_fasta - 1
        ]  # TODO: Confirm -1
        signal_end = base_locs_in_signal[end_base_idx_in_fasta - 1]  # TODO: Confirm -1
        roi_data["roi_fasta"] = fasta[start_base_idx_in_fasta:end_base_idx_in_fasta]

    # Signal data is 3'-> 5' for RNA 5' -> 3' for DNA
    # The ROI FASTA is always 5' -> 3' irrespective of the experiment type
    roi_data["roi_signal"] = signal[signal_start:signal_end]

    # Make roi signal for plot and pad it with NaNs outside the ROI
    plot_signal = np.copy(signal)
    plot_signal[:signal_start] = np.nan
    plot_signal[signal_end:] = np.nan
    roi_data["signal_start"] = signal_start
    roi_data["signal_end"] = signal_end
    roi_data["roi_signal_for_plot"] = plot_signal
    roi_data["base_locs_in_signal"] = base_locs_in_signal

    return roi_data

find_base_locs_in_signal(bam_data: dict) -> npt.NDArray[np.int32]

Finds the locations of each new base in the signal.

Parameters:

Name Type Description Default
bam_data dict

Dictionary containing information from the BAM file.

required

Returns:

Type Description
NDArray[int32]

npt.NDArray[np.int32]: Array of locations of each new base in the signal.

Source code in src/capfinder/process_pod5.py
def find_base_locs_in_signal(bam_data: dict) -> npt.NDArray[np.int32]:
    """
    Finds the locations of each new base in the signal.

    Params:
        bam_data (dict): Dictionary containing information from the BAM file.

    Returns:
        npt.NDArray[np.int32]: Array of locations of each new base in the signal.
    """
    start_sample = bam_data["start_sample"]
    split_point = bam_data["split_point"]

    # we map the moves from 3' to 5' to the signal
    # and start from the start sample or its sum with the split point
    # if the read is split
    start_sample = start_sample + split_point

    moves_step = bam_data["moves_step"]
    moves_table = np.array(bam_data["moves_table"])

    # Where do moves occur in the signal coordinates?
    moves_indices = np.arange(
        start_sample, start_sample + moves_step * len(moves_table), moves_step
    )

    # We only need locs where a move of 1 occurs
    base_locs_in_signal: npt.NDArray[np.int32] = moves_indices[moves_table != 0]

    return base_locs_in_signal
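
To make the moves-to-signal mapping concrete, here is a small standalone sketch that mirrors the arithmetic above with made-up numbers (it does not call the function itself):

import numpy as np

bam_data = {
    "start_sample": 100,                 # first signal sample of the basecalled region
    "split_point": 0,                    # non-zero only for split reads
    "moves_step": 5,                     # signal samples per entry in the moves table
    "moves_table": [1, 0, 1, 1, 0, 1],   # 1 = a new base starts at this move
}

start = bam_data["start_sample"] + bam_data["split_point"]
moves = np.array(bam_data["moves_table"])
moves_indices = np.arange(
    start, start + bam_data["moves_step"] * len(moves), bam_data["moves_step"]
)
base_locs = moves_indices[moves != 0]
print(base_locs)  # [100 110 115 125]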

preprocess_signal_data(signal: np.ndarray) -> npt.NDArray[np.float64]

Preprocesses the signal data.

Parameters:

Name Type Description Default
signal ndarray

Signal data.

required

Returns:

Name Type Description
signal NDArray[float64]

Preprocessed signal data.

Source code in src/capfinder/process_pod5.py
def preprocess_signal_data(signal: np.ndarray) -> npt.NDArray[np.float64]:
    """
    Preprocesses the signal data.

    Params:
        signal (np.ndarray): Signal data.

    Returns:
        signal (npt.NDArray[np.float64]): Preprocessed signal data.
    """
    signal = z_normalize(signal)
    signal = clip_extreme_values(signal)
    return signal

pull_read_from_pod5(read_id: str, pod5_filepath: str) -> Dict[str, Any]

Returns a single read from a pod5 file.

Parameters:

Name Type Description Default
read_id str

str The read_id of the read to be extracted.

required
pod5_filepath str

str Path to the pod5 file.

required

Returns:

Name Type Description
dict Dict[str, Any]

Dictionary containing information about the extracted read.
- 'sample_rate': Sample rate of the read.
- 'sequencing_kit': Sequencing kit used.
- 'experiment_type': Experiment type.
- 'local_basecalling': Local basecalling information.
- 'signal': Signal data.
- 'signal_pa': Signal data in picoamperes (pA).
- 'end_reason': Reason for the end of the read.
- 'sample_count': Number of samples in the read.
- 'channel': Pore channel information.
- 'well': Pore well information.
- 'pore_type': Pore type.
- 'writing_software': Software used for writing.
- 'scale': Scaling factor for the signal.
- 'shift': Shift factor for the signal.

Source code in src/capfinder/process_pod5.py
def pull_read_from_pod5(read_id: str, pod5_filepath: str) -> Dict[str, Any]:
    """Returns a single read from a pod5 file.

    Params:
        read_id: str
            The read_id of the read to be extracted.
        pod5_filepath: str
            Path to the pod5 file.

    Returns:
        dict: Dictionary containing information about the extracted read.
            - 'sample_rate': Sample rate of the read.
            - 'sequencing_kit': Sequencing kit used.
            - 'experiment_type': Experiment type.
            - 'local_basecalling': Local basecalling information.
            - 'signal': Signal data.
            - 'signal_pa': Signal data in picoamperes (pA).
            - 'end_reason': Reason for the end of the read.
            - 'sample_count': Number of samples in the read.
            - 'channel': Pore channel information.
            - 'well': Pore well information.
            - 'pore_type': Pore type.
            - 'writing_software': Software used for writing.
            - 'scale': Scaling factor for the signal.
            - 'shift': Shift factor for the signal.

    """
    signal_dict = {}
    with p5.Reader(pod5_filepath) as reader:
        read = next(reader.reads(selection=[read_id]))
        # Get the signal data and sample rate
        signal_dict["sample_rate"] = read.run_info.sample_rate
        signal_dict["sequencing_kit"] = read.run_info.sequencing_kit
        signal_dict["experiment_type"] = read.run_info.context_tags["experiment_type"]
        signal_dict["local_basecalling"] = read.run_info.context_tags[
            "local_basecalling"
        ]
        signal_dict["signal"] = read.signal
        signal_dict["signal_pa"] = read.signal_pa
        signal_dict["end_reason"] = read.end_reason.reason.name
        signal_dict["sample_count"] = read.sample_count
        signal_dict["channel"] = read.pore.channel
        signal_dict["well"] = read.pore.well
        signal_dict["pore_type"] = read.pore.pore_type
        signal_dict["writing_software"] = reader.writing_software
        signal_dict["scale"] = read.tracked_scaling.scale
        signal_dict["shift"] = read.tracked_scaling.shift
    return signal_dict
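
A hedged usage sketch: the read ID and file path below are placeholders, and the import path is assumed from the source location shown above.

from capfinder.process_pod5 import pull_read_from_pod5  # assumed import path

read_info = pull_read_from_pod5(
    read_id="00000000-0000-0000-0000-000000000000",  # placeholder read ID
    pod5_filepath="/path/to/reads.pod5",             # placeholder path
)
print(read_info["sample_rate"], read_info["end_reason"], read_info["sample_count"])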

z_normalize(data: np.ndarray) -> npt.NDArray[np.float64]

Normalize the input data using Z-score normalization.

Parameters:

Name Type Description Default
data ndarray

Input data to be Z-score normalized.

required

Returns:

Type Description
NDArray[float64]

npt.NDArray[np.float64]: Z-score normalized data.

Note

Z-score normalization (or Z normalization) transforms the data to have a mean of 0 and a standard deviation of 1.

Source code in src/capfinder/process_pod5.py
def z_normalize(data: np.ndarray) -> npt.NDArray[np.float64]:
    """Normalize the input data using Z-score normalization.

    Params:
        data (np.ndarray): Input data to be Z-score normalized.

    Returns:
        npt.NDArray[np.float64]: Z-score normalized data.

    Note:
        Z-score normalization (or Z normalization) transforms the data
        to have a mean of 0 and a standard deviation of 1.
    """
    mean = np.mean(data)
    std_dev = np.std(data)
    z_normalized_data: npt.NDArray[np.float64] = (data - mean) / std_dev
    return z_normalized_data
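
A quick standalone check of the transformation in plain NumPy (mirroring the function rather than importing it):

import numpy as np

data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
z = (data - data.mean()) / data.std()
print(float(z.mean()), float(z.std()))  # approximately 0.0 and 1.0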

report

count_csv_rows(csv_file: str) -> int

Count the number of rows in a CSV file.

Source code in src/capfinder/report.py
def count_csv_rows(csv_file: str) -> int:
    """Count the number of rows in a CSV file."""
    with open(csv_file) as f:
        return sum(1 for _ in f) - 1  # Subtract 1 to account for header

create_database(db_path: str) -> sqlite3.Connection

Create a new SQLite database and return the connection.

Source code in src/capfinder/report.py
def create_database(db_path: str) -> sqlite3.Connection:
    """Create a new SQLite database and return the connection."""
    conn = sqlite3.connect(db_path)
    return conn

create_table(conn: sqlite3.Connection, table_name: str, columns: List[str]) -> None

Create a table in the SQLite database.

Source code in src/capfinder/report.py
def create_table(conn: sqlite3.Connection, table_name: str, columns: List[str]) -> None:
    """Create a table in the SQLite database."""
    cursor = conn.cursor()
    columns_str = ", ".join(columns)
    cursor.execute(f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_str})")
    conn.commit()

csv_to_sqlite(csv_file: str, db_conn: sqlite3.Connection, table_name: str, chunk_size: int = 100000) -> None

Import CSV data into SQLite database in chunks with progress bar.

Source code in src/capfinder/report.py
def csv_to_sqlite(
    csv_file: str,
    db_conn: sqlite3.Connection,
    table_name: str,
    chunk_size: int = 100000,
) -> None:
    """Import CSV data into SQLite database in chunks with progress bar."""
    create_table(
        db_conn,
        table_name,
        ["read_id TEXT PRIMARY KEY", "pod5_file TEXT", "predicted_cap TEXT"],
    )

    cursor = db_conn.cursor()
    total_rows = count_csv_rows(csv_file)

    with open(csv_file) as f:
        csv_reader = csv.DictReader(f)
        chunk = []
        with tqdm(total=total_rows, unit="reads") as pbar:
            for rw in csv_reader:
                chunk.append(
                    (
                        rw["read_id"],
                        rw.get("pod5_file", ""),
                        rw.get("predicted_cap", ""),
                    )
                )
                if len(chunk) >= chunk_size:
                    cursor.executemany(
                        f"INSERT OR REPLACE INTO {table_name} (read_id, pod5_file, predicted_cap) VALUES (?, ?, ?)",
                        chunk,
                    )
                    db_conn.commit()
                    pbar.update(len(chunk))
                    chunk = []

            if chunk:  # Insert any remaining rows
                cursor.executemany(
                    f"INSERT OR REPLACE INTO {table_name} (read_id, pod5_file, predicted_cap) VALUES (?, ?, ?)",
                    chunk,
                )
                db_conn.commit()
                pbar.update(len(chunk))

get_cap_type_counts(conn: sqlite3.Connection) -> DefaultDict[str, int]

Get cap type counts from the joined data.

Source code in src/capfinder/report.py
def get_cap_type_counts(conn: sqlite3.Connection) -> DefaultDict[str, int]:
    """Get cap type counts from the joined data."""
    query = """
    SELECT COALESCE(predicted_cap, 'OTE_not_found') as cap, COUNT(*) as count
    FROM (
        SELECT m.read_id, COALESCE(p.predicted_cap, 'OTE_not_found') as predicted_cap
        FROM metadata m
        LEFT JOIN predictions p ON m.read_id = p.read_id
    )
    GROUP BY cap
    """
    cursor = conn.cursor()
    cursor.execute(query)

    results = defaultdict(int)
    for cap, count in cursor:
        results[cap] = count

    return results

join_tables(conn: sqlite3.Connection, output_csv: str, chunk_size: int = 100000) -> None

Join metadata and predictions tables and save to CSV in chunks with progress bar.

Source code in src/capfinder/report.py
def join_tables(
    conn: sqlite3.Connection, output_csv: str, chunk_size: int = 100000
) -> None:
    """Join metadata and predictions tables and save to CSV in chunks with progress bar."""
    query = """
    SELECT m.read_id, m.pod5_file, COALESCE(p.predicted_cap, 'OTE_not_found') as predicted_cap
    FROM metadata m
    LEFT JOIN predictions p ON m.read_id = p.read_id
    """

    cursor = conn.cursor()
    cursor.execute("SELECT COUNT(*) FROM metadata")
    total_rows = cursor.fetchone()[0]

    cursor.execute(query)

    with open(output_csv, "w", newline="") as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(["read_id", "pod5_file", "predicted_cap"])  # Write header

        with tqdm(total=total_rows, unit="reads") as pbar:
            while True:
                results = cursor.fetchmany(chunk_size)
                if not results:
                    break
                csv_writer.writerows(results)
                pbar.update(len(results))
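
A self-contained sketch of the report workflow above, using an in-memory SQLite database and toy rows instead of real capfinder output; it reproduces the LEFT JOIN / COALESCE pattern used by get_cap_type_counts and join_tables.

import sqlite3
from collections import defaultdict

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE metadata (read_id TEXT PRIMARY KEY, pod5_file TEXT)")
cur.execute("CREATE TABLE predictions (read_id TEXT PRIMARY KEY, predicted_cap TEXT)")
cur.executemany("INSERT INTO metadata VALUES (?, ?)",
                [("r1", "a.pod5"), ("r2", "a.pod5"), ("r3", "b.pod5")])
cur.executemany("INSERT INTO predictions VALUES (?, ?)",
                [("r1", "cap_0"), ("r2", "cap_1")])  # r3 has no prediction
conn.commit()

# Reads without a prediction are counted under 'OTE_not_found', as above.
query = """
SELECT COALESCE(p.predicted_cap, 'OTE_not_found') AS cap, COUNT(*) AS n
FROM metadata m
LEFT JOIN predictions p ON m.read_id = p.read_id
GROUP BY cap
"""
counts = defaultdict(int)
for cap, n in cur.execute(query):
    counts[cap] = n
print(dict(counts))  # e.g. {'OTE_not_found': 1, 'cap_0': 1, 'cap_1': 1}
conn.close()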

resnet_model

ResNetTimeSeriesHyper

Bases: HyperModel

A HyperModel class for building a ResNet-style neural network for time series classification.

This class defines a tunable ResNet architecture that can be optimized using Keras Tuner. It creates a model with an initial convolutional layer, followed by a variable number of ResNet blocks, and ends with global average pooling and dense layers.

Attributes:

Name Type Description
input_shape Tuple[int, int]

The shape of the input data (timesteps, features).

n_classes int

The number of classes for classification.

Methods:

Name Description
build

Builds and returns a compiled Keras model based on the provided hyperparameters.

Source code in src/capfinder/resnet_model.py
class ResNetTimeSeriesHyper(HyperModel):
    """
    A HyperModel class for building a ResNet-style neural network for time series classification.

    This class defines a tunable ResNet architecture that can be optimized using Keras Tuner.
    It creates a model with an initial convolutional layer, followed by a variable number of
    ResNet blocks, and ends with global average pooling and dense layers.

    Attributes:
        input_shape (Tuple[int, int]): The shape of the input data (timesteps, features).
        n_classes (int): The number of classes for classification.

    Methods:
        build(hp): Builds and returns a compiled Keras model based on the provided hyperparameters.
    """

    def __init__(self, input_shape: Tuple[int, int], n_classes: int):
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.encoder_model = None

    def build(self, hp: HyperParameters) -> Model:
        """
        Build and compile a ResNet model based on the provided hyperparameters.

        This method constructs a ResNet architecture with tunable hyperparameters including
        the number of filters, kernel sizes, number of ResNet blocks, dense layer units,
        dropout rate, and learning rate.

        Args:
            hp (hp.HyperParameters): A HyperParameters object used to define the search space.

        Returns:
            Model: A compiled Keras model ready for training.
        """
        inputs = keras.Input(shape=self.input_shape)

        # Initial convolution
        initial_filters = hp.Int(
            "initial_filters", min_value=32, max_value=128, step=32
        )
        x = layers.Conv1D(
            initial_filters,
            kernel_size=hp.Choice("initial_kernel", values=[3, 5, 7]),
            padding="same",
        )(inputs)
        x = layers.BatchNormalization()(x)
        x = keras.activations.relu(x)
        x = layers.MaxPooling1D(pool_size=3, strides=2, padding="same")(x)

        # ResNet blocks
        num_blocks_per_stage = hp.Int("num_blocks_per_stage", min_value=2, max_value=4)
        num_stages = hp.Int("num_stages", min_value=2, max_value=4)

        for stage in range(num_stages):
            filters = hp.Int(
                f"filters_stage_{stage}", min_value=64, max_value=256, step=64
            )
            for block in range(num_blocks_per_stage):
                kernel_size = hp.Choice(
                    f"kernel_stage_{stage}_block_{block}", values=[3, 5, 7]
                )
                strides = 2 if block == 0 and stage > 0 else 1
                x = ResNetBlockHyper(filters, kernel_size, strides)(x)

        # Global pooling and output
        x = layers.GlobalAveragePooling1D()(x)
        x = layers.Dense(
            hp.Int("dense_units", min_value=32, max_value=256, step=32),
            activation="relu",
        )(x)
        x = layers.Dropout(hp.Float("dropout", min_value=0.0, max_value=0.5, step=0.1))(
            x
        )
        outputs = layers.Dense(self.n_classes, activation="softmax")(x)

        model = Model(inputs, outputs)

        model.compile(
            optimizer=keras.optimizers.Adam(
                hp.Float(
                    "learning_rate", min_value=1e-4, max_value=1e-2, sampling="log"
                )
            ),
            loss="sparse_categorical_crossentropy",
            metrics=["sparse_categorical_accuracy"],
        )

        return model

build(hp: HyperParameters) -> Model

Build and compile a ResNet model based on the provided hyperparameters.

This method constructs a ResNet architecture with tunable hyperparameters including the number of filters, kernel sizes, number of ResNet blocks, dense layer units, dropout rate, and learning rate.

Parameters:

Name Type Description Default
hp HyperParameters

A HyperParameters object used to define the search space.

required

Returns:

Name Type Description
Model Model

A compiled Keras model ready for training.

Source code in src/capfinder/resnet_model.py
def build(self, hp: HyperParameters) -> Model:
    """
    Build and compile a ResNet model based on the provided hyperparameters.

    This method constructs a ResNet architecture with tunable hyperparameters including
    the number of filters, kernel sizes, number of ResNet blocks, dense layer units,
    dropout rate, and learning rate.

    Args:
        hp (hp.HyperParameters): A HyperParameters object used to define the search space.

    Returns:
        Model: A compiled Keras model ready for training.
    """
    inputs = keras.Input(shape=self.input_shape)

    # Initial convolution
    initial_filters = hp.Int(
        "initial_filters", min_value=32, max_value=128, step=32
    )
    x = layers.Conv1D(
        initial_filters,
        kernel_size=hp.Choice("initial_kernel", values=[3, 5, 7]),
        padding="same",
    )(inputs)
    x = layers.BatchNormalization()(x)
    x = keras.activations.relu(x)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding="same")(x)

    # ResNet blocks
    num_blocks_per_stage = hp.Int("num_blocks_per_stage", min_value=2, max_value=4)
    num_stages = hp.Int("num_stages", min_value=2, max_value=4)

    for stage in range(num_stages):
        filters = hp.Int(
            f"filters_stage_{stage}", min_value=64, max_value=256, step=64
        )
        for block in range(num_blocks_per_stage):
            kernel_size = hp.Choice(
                f"kernel_stage_{stage}_block_{block}", values=[3, 5, 7]
            )
            strides = 2 if block == 0 and stage > 0 else 1
            x = ResNetBlockHyper(filters, kernel_size, strides)(x)

    # Global pooling and output
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dense(
        hp.Int("dense_units", min_value=32, max_value=256, step=32),
        activation="relu",
    )(x)
    x = layers.Dropout(hp.Float("dropout", min_value=0.0, max_value=0.5, step=0.1))(
        x
    )
    outputs = layers.Dense(self.n_classes, activation="softmax")(x)

    model = Model(inputs, outputs)

    model.compile(
        optimizer=keras.optimizers.Adam(
            hp.Float(
                "learning_rate", min_value=1e-4, max_value=1e-2, sampling="log"
            )
        ),
        loss="sparse_categorical_crossentropy",
        metrics=["sparse_categorical_accuracy"],
    )

    return model
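
A hedged hyperparameter-search sketch on random data: it assumes keras_tuner is installed and that ResNetTimeSeriesHyper is importable from capfinder.resnet_model; the shapes, class count, and trial budget are illustrative only.

import numpy as np
import keras_tuner as kt
from capfinder.resnet_model import ResNetTimeSeriesHyper  # assumed import path

x = np.random.randn(64, 500, 1).astype("float32")  # 64 toy series of length 500
y = np.random.randint(0, 4, size=64)               # 4 toy classes

tuner = kt.RandomSearch(
    ResNetTimeSeriesHyper(input_shape=(500, 1), n_classes=4),
    objective="val_sparse_categorical_accuracy",
    max_trials=2,
    overwrite=True,
    directory="tuner_runs",       # placeholder output directory
    project_name="resnet_demo",   # placeholder project name
)
tuner.search(x, y, validation_split=0.2, epochs=1, batch_size=16)
best_model = tuner.get_best_models(num_models=1)[0]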

train_etl

augment_example(x: tf.Tensor, y: tf.Tensor, dtype: tf.DType) -> tf.data.Dataset

Augment a single example by creating warped versions and combining them with the original.

Parameters:

Name Type Description Default
x Tensor

The input tensor to be augmented.

required
y Tensor

The corresponding label tensor.

required
dtype DType

The desired data type for the augmented tensors.

required

Returns:

Type Description
Dataset

tf.data.Dataset: A dataset containing the original and augmented examples with their labels.

Source code in src/capfinder/train_etl.py
def augment_example(x: tf.Tensor, y: tf.Tensor, dtype: tf.DType) -> tf.data.Dataset:
    """
    Augment a single example by creating warped versions and combining them with the original.

    Args:
        x (tf.Tensor): The input tensor to be augmented.
        y (tf.Tensor): The corresponding label tensor.
        dtype (tf.DType): The desired data type for the augmented tensors.

    Returns:
        tf.data.Dataset: A dataset containing the original and augmented examples with their labels.
    """
    # Apply augmentation to each example in the batch
    squished, expanded = create_warped_examples(x, 0.2, dtype=dtype)

    # Ensure all tensors have the same data type
    x = tf.cast(x, dtype)
    squished = tf.cast(squished, dtype)
    expanded = tf.cast(expanded, dtype)

    # Create a list of augmented examples
    augmented_x = [x, squished, expanded]
    augmented_y = [y, y, y]

    return tf.data.Dataset.from_tensor_slices((augmented_x, augmented_y))

calculate_sizes(total_examples: int, train_fraction: float, batch_size: int) -> Tuple[int, int]

Compute the train and validation sizes based on the total number of examples.

Parameters:

Name Type Description Default
total_examples int

Total number of examples in the dataset.

required
train_fraction float

Fraction of data to use for training.

required
batch_size int

Size of each batch.

required

Returns:

Type Description
Tuple[int, int]

Tuple[int, int]: Train size and validation size, both divisible by batch_size.

Source code in src/capfinder/train_etl.py
def calculate_sizes(
    total_examples: int, train_fraction: float, batch_size: int
) -> Tuple[int, int]:
    """
    Compute the train and validation sizes based on the total number of examples.

    Args:
        total_examples (int): Total number of examples in the dataset.
        train_fraction (float): Fraction of data to use for training.
        batch_size (int): Size of each batch.

    Returns:
        Tuple[int, int]: Train size and validation size, both divisible by batch_size.
    """
    train_size = int(total_examples * train_fraction)
    val_size = total_examples - train_size

    train_size = (train_size // batch_size) * batch_size
    val_size = (val_size // batch_size) * batch_size

    while train_size + val_size > total_examples:
        if train_size > val_size:
            train_size -= batch_size
        else:
            val_size -= batch_size

    return train_size, val_size
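
A tiny worked example, assuming the function is importable from capfinder.train_etl: with 1000 examples, a 0.8 train fraction, and batches of 32, both splits are rounded down to multiples of the batch size.

from capfinder.train_etl import calculate_sizes  # assumed import path

train_size, val_size = calculate_sizes(total_examples=1000, train_fraction=0.8, batch_size=32)
print(train_size, val_size)  # 800 192 (both multiples of 32, and 800 + 192 <= 1000)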

combine_datasets(datasets: List[tf.data.Dataset]) -> tf.data.Dataset

Combine datasets from different classes.

Parameters:

Name Type Description Default
datasets List[Dataset]

List of datasets to combine.

required

Returns:

Type Description
Dataset

tf.data.Dataset: A combined dataset.

Source code in src/capfinder/train_etl.py
def combine_datasets(datasets: List[tf.data.Dataset]) -> tf.data.Dataset:
    """
    Combine datasets from different classes.

    Args:
        datasets (List[tf.data.Dataset]): List of datasets to combine.

    Returns:
        tf.data.Dataset: A combined dataset.
    """
    combined_dataset = datasets[0]
    for dataset in datasets[1:]:
        combined_dataset = combined_dataset.concatenate(dataset)
    return combined_dataset.shuffle(buffer_size=10000)

count_examples_fast(file_path: str) -> int

Count lines in a file using fast bash utilities, falling back to Python if necessary.

Parameters:

Name Type Description Default
file_path str

Path to the file to count lines in.

required

Returns:

Name Type Description
int int

Number of lines in the file (excluding header).

Source code in src/capfinder/train_etl.py
def count_examples_fast(file_path: str) -> int:
    """
    Count lines in a file using fast bash utilities, falling back to Python if necessary.

    Args:
        file_path (str): Path to the file to count lines in.

    Returns:
        int: Number of lines in the file (excluding header).
    """
    try:
        # Try using wc -l command (fast)
        result = subprocess.run(["wc", "-l", file_path], capture_output=True, text=True)
        count = int(result.stdout.split()[0]) - 1  # Subtract 1 for header
        return count
    except (subprocess.SubprocessError, FileNotFoundError, ValueError):
        try:
            # Fallback to using sed and wc (slightly slower, but still fast)
            result = subprocess.run(
                f"sed '1d' {file_path} | wc -l",
                shell=True,
                capture_output=True,
                text=True,
            )
            return int(result.stdout.strip())
        except (subprocess.SubprocessError, ValueError):
            # If bash methods fail, fall back to Python method
            return count_examples_python(file_path)

count_examples_python(file_path: str) -> int

Count lines in a file using Python (slower but portable).

Parameters:

Name Type Description Default
file_path str

Path to the file to count lines in.

required

Returns:

Name Type Description
int int

Number of lines in the file (excluding header).

Source code in src/capfinder/train_etl.py
def count_examples_python(file_path: str) -> int:
    """
    Count lines in a file using Python (slower but portable).

    Args:
        file_path (str): Path to the file to count lines in.

    Returns:
        int: Number of lines in the file (excluding header).
    """
    with open(file_path) as f:
        return sum(1 for _ in f) - 1  # Subtract 1 for header

create_class_dataset(file_paths: List[str], target_length: int, dtype: DtypeLiteral, examples_per_class: int, train_test_fraction: float) -> Tuple[tf.data.Dataset, tf.data.Dataset]

Create a dataset for a single class from multiple files.

Parameters:

Name Type Description Default
file_paths List[str]

List of file paths for a single class.

required
target_length int

The desired length of the timeseries tensor.

required
dtype DtypeLiteral

The desired data type for the timeseries tensor as a string.

required
examples_per_class int

Number of examples to take per class.

required
train_test_fraction float

Fraction of data to use for training.

required

Returns:

Type Description
Tuple[Dataset, Dataset]

Tuple[tf.data.Dataset, tf.data.Dataset]: Train and test datasets for the given class.

Source code in src/capfinder/train_etl.py
def create_class_dataset(
    file_paths: List[str],
    target_length: int,
    dtype: DtypeLiteral,
    examples_per_class: int,
    train_test_fraction: float,
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    """
    Create a dataset for a single class from multiple files.

    Args:
        file_paths (List[str]): List of file paths for a single class.
        target_length (int): The desired length of the timeseries tensor.
        dtype (DtypeLiteral): The desired data type for the timeseries tensor as a string.
        examples_per_class (int): Number of examples to take per class.
        train_test_fraction (float): Fraction of data to use for training.

    Returns:
        Tuple[tf.data.Dataset, tf.data.Dataset]: Train and test datasets for the given class.
    """
    class_dataset: Optional[tf.data.Dataset] = None

    for file_path in file_paths:
        dataset = create_dataset(file_path, target_length, dtype)

        if class_dataset is None:
            class_dataset = dataset
        else:
            class_dataset = class_dataset.concatenate(dataset)

    if class_dataset is None:
        raise ValueError("No valid datasets were created.")

    # Shuffle and take examples after concatenating all files
    class_dataset = class_dataset.shuffle(buffer_size=10000).take(examples_per_class)

    # Split into train and test
    train_size = int(train_test_fraction * examples_per_class)
    train_dataset = class_dataset.take(train_size)
    test_dataset = class_dataset.skip(train_size)
    return train_dataset, test_dataset

create_dataset(file_path: str, target_length: int, dtype: DtypeLiteral) -> tf.data.Dataset

Create a TensorFlow dataset for a single class CSV file.

Parameters:

Name Type Description Default
file_path str

Path to the CSV file.

required
target_length int

The desired length of the timeseries tensor.

required
dtype DtypeLiteral

The desired data type for the timeseries tensor as a string.

required

Returns:

Type Description
Dataset

tf.data.Dataset: A dataset for the given class.

Source code in src/capfinder/train_etl.py
def create_dataset(
    file_path: str,
    target_length: int,
    dtype: DtypeLiteral,
) -> tf.data.Dataset:
    """
    Create a TensorFlow dataset for a single class CSV file.

    Args:
        file_path (str): Path to the CSV file.
        target_length (int): The desired length of the timeseries tensor.
        dtype (DtypeLiteral): The desired data type for the timeseries tensor as a string.

    Returns:
        tf.data.Dataset: A dataset for the given class.
    """
    tf_dtype = get_dtype(dtype)
    dataset = tf.data.Dataset.from_generator(
        lambda: csv_generator(file_path),
        output_signature=(
            tf.TensorSpec(shape=(), dtype=tf.string),
            tf.TensorSpec(shape=(), dtype=tf.string),
            tf.TensorSpec(shape=(), dtype=tf.string),
        ),
    )
    dataset = dataset.map(
        lambda x, y, z: parse_row((x, y, z), target_length, tf_dtype),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return dataset

create_train_val_test_datasets_from_train_test_csvs(dataset_dir: str, batch_size: int, target_length: int, dtype: tf.DType, train_val_fraction: float, use_augmentation: bool = False) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int, int]

Load ready-made train, validation, and test datasets from CSV files.

Parameters:

Name Type Description Default
dataset_dir str

Directory containing the CSV files.

required
batch_size int

Size of each batch.

required
target_length int

Target length of each time series.

required
dtype DType

Data type for the features.

required
train_val_fraction float

Fraction of training data to use for validation.

required
use_augmentation bool

Whether to augment original training examples with warped versions

False

Returns:

Type Description
Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int, int]

Train dataset, validation dataset, test dataset, steps per epoch, and validation steps.

Source code in src/capfinder/train_etl.py
def create_train_val_test_datasets_from_train_test_csvs(
    dataset_dir: str,
    batch_size: int,
    target_length: int,
    dtype: tf.DType,
    train_val_fraction: float,
    use_augmentation: bool = False,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int, int]:
    """
    Load ready-made train, validation, and test datasets from CSV files.

    Args:
        dataset_dir (str): Directory containing the CSV files.
        batch_size (int): Size of each batch.
        target_length (int): Target length of each time series.
        dtype (tf.DType): Data type for the features.
        train_val_fraction (float): Fraction of training data to use for validation.
        use_augmentation (bool): Whether to augment original training examples with warped versions

    Returns:
        Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int, int]:
        Train dataset, validation dataset, test dataset, steps per epoch, and validation steps.
    """
    logger.info("Loading train, val splits...")

    train_dataset, val_dataset, steps_per_epoch, validation_steps = (
        load_train_dataset_from_csvs(
            x_file_path=os.path.join(dataset_dir, "train_x.csv"),
            y_file_path=os.path.join(dataset_dir, "train_y.csv"),
            batch_size=batch_size,
            target_length=target_length,
            dtype=dtype,
            train_val_fraction=train_val_fraction,
            use_augmentation=use_augmentation,
        )
    )
    logger.info("Loading test split ...")

    test_dataset = load_test_dataset_from_csvs(
        x_file_path=os.path.join(dataset_dir, "test_x.csv"),
        y_file_path=os.path.join(dataset_dir, "test_y.csv"),
        batch_size=batch_size,
        target_length=target_length,
        dtype=dtype,
    )
    return train_dataset, val_dataset, test_dataset, steps_per_epoch, validation_steps

create_warped_examples(signal: tf.Tensor, max_warp_factor: float = 0.3, dtype: tf.DType = tf.float32) -> Tuple[tf.Tensor, tf.Tensor]

Create warped versions (squished and expanded) of the input signal.

Parameters:

Name Type Description Default
signal Tensor

The input signal to be warped.

required
max_warp_factor float

The maximum factor by which the signal can be warped. Defaults to 0.3.

0.3
dtype DType

The desired data type for the output tensors. Defaults to tf.float32.

float32

Returns:

Type Description
Tuple[Tensor, Tensor]

Tuple[tf.Tensor, tf.Tensor]: A tuple containing the squished and expanded versions of the input signal.

Source code in src/capfinder/train_etl.py
def create_warped_examples(
    signal: tf.Tensor, max_warp_factor: float = 0.3, dtype: tf.DType = tf.float32
) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Create warped versions (squished and expanded) of the input signal.

    Args:
        signal (tf.Tensor): The input signal to be warped.
        max_warp_factor (float): The maximum factor by which the signal can be warped. Defaults to 0.3.
        dtype (tf.DType): The desired data type for the output tensors. Defaults to tf.float32.

    Returns:
        Tuple[tf.Tensor, tf.Tensor]: A tuple containing the squished and expanded versions of the input signal.
    """
    original_dtype = signal.dtype
    signal = tf.cast(signal, tf.float32)  # Convert to float32 for internal calculations

    time_steps = tf.shape(signal)[0]

    # Create squished version
    squish_factor = 1 - tf.random.uniform((), 0, max_warp_factor, seed=43)
    squished_length = tf.cast(tf.cast(time_steps, tf.float32) * squish_factor, tf.int32)
    squished = tf.image.resize(tf.expand_dims(signal, -1), (squished_length, 1))[
        :, :, 0
    ]
    pad_total = time_steps - squished_length
    pad_left = pad_total // 2
    pad_right = pad_total - pad_left
    padding = [[pad_left, pad_right], [0, 0]]
    squished = tf.pad(squished, padding)

    # Create expanded version
    expand_factor = 1 + tf.random.uniform((), 0, max_warp_factor, seed=43)
    expanded_length = tf.cast(tf.cast(time_steps, tf.float32) * expand_factor, tf.int32)
    expanded = tf.image.resize(tf.expand_dims(signal, -1), (expanded_length, 1))[
        :, :, 0
    ]
    trim_total = expanded_length - time_steps
    trim_left = trim_total // 2
    trim_right = expanded_length - (trim_total - trim_left)
    expanded = expanded[trim_left:trim_right]

    # Cast back to original dtype
    squished = tf.cast(squished, original_dtype)
    expanded = tf.cast(expanded, original_dtype)

    return squished, expanded
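
A toy call on a synthetic trace, assuming the function is importable from capfinder.train_etl; note that the input is expected in (length, 1) shape, as produced by parse_row, and both warped outputs keep that shape.

import tensorflow as tf
from capfinder.train_etl import create_warped_examples  # assumed import path

signal = tf.reshape(tf.sin(tf.linspace(0.0, 6.28, 200)), (200, 1))  # synthetic trace
squished, expanded = create_warped_examples(signal, max_warp_factor=0.3)
print(squished.shape, expanded.shape)  # (200, 1) (200, 1)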

csv_generator(file_path: str) -> Generator[Tuple[str, str, str], None, None]

Generates rows from a CSV file one at a time.

Parameters:

Name Type Description Default
file_path str

Path to the CSV file.

required

Yields:

Type Description
str

Tuple[str, str, str]: A tuple containing read_id, cap_class, and timeseries as strings.

Source code in src/capfinder/train_etl.py
def csv_generator(file_path: str) -> Generator[Tuple[str, str, str], None, None]:
    """
    Generates rows from a CSV file one at a time.

    Args:
        file_path (str): Path to the CSV file.

    Yields:
        Tuple[str, str, str]: A tuple containing read_id, cap_class, and timeseries as strings.
    """
    with open(file_path) as csvfile:
        reader = csv.reader(csvfile)
        # Skip the header row
        next(reader)
        for row in reader:
            yield (str(row[0]), str(row[1]), str(row[2]))

get_class_from_file(file_path: str) -> int

Read the first data row from a CSV file and return the class ID.

Parameters:

Name Type Description Default
file_path str

Path to the CSV file.

required

Returns:

Name Type Description
int int

Class ID from the first data row.

Source code in src/capfinder/train_etl.py
def get_class_from_file(file_path: str) -> int:
    """
    Read the first data row from a CSV file and return the class ID.

    Args:
        file_path (str): Path to the CSV file.

    Returns:
        int: Class ID from the first data row.
    """
    with open(file_path) as f:
        csv_reader = csv.reader(f)
        next(csv_reader)  # Skip header
        first_row = next(csv_reader)
        return int(first_row[1])  # Assuming cap_class is the second column

get_local_dataset_version(dataset_dir: str) -> Optional[str]

Get the version of the local dataset.

Parameters:

Name Type Description Default
dataset_dir str

The directory containing the dataset.

required

Returns:

Type Description
Optional[str]

Optional[str]: The version of the local dataset, or None if not found.

Source code in src/capfinder/train_etl.py
def get_local_dataset_version(dataset_dir: str) -> Optional[str]:
    """
    Get the version of the local dataset.

    Args:
        dataset_dir (str): The directory containing the dataset.

    Returns:
        Optional[str]: The version of the local dataset, or None if not found.
    """
    stored_version = None
    version_file = os.path.join(dataset_dir, "artifact_version.txt")
    train_x_file = os.path.join(dataset_dir, "train_x.csv")
    train_y_file = os.path.join(dataset_dir, "train_y.csv")
    test_x_file = os.path.join(dataset_dir, "test_x.csv")
    test_y_file = os.path.join(dataset_dir, "test_y.csv")
    train_exists = os.path.exists(train_x_file) and os.path.exists(train_y_file)
    test_exists = os.path.exists(test_x_file) and os.path.exists(test_y_file)
    version_file_exists = os.path.exists(version_file)

    if train_exists and test_exists and version_file_exists:
        with open(version_file) as f:
            stored_version = f.read().strip()
    return stored_version

group_files_by_class(caps_data_dir: str) -> Dict[int, List[str]]

Group CSV files in the directory by their class ID.

Parameters:

Name Type Description Default
caps_data_dir str

Directory containing the CSV files.

required

Returns:

Type Description
Dict[int, List[str]]

Dict[int, List[str]]: Dictionary mapping class IDs to lists of file paths.

Source code in src/capfinder/train_etl.py
def group_files_by_class(caps_data_dir: str) -> Dict[int, List[str]]:
    """
    Group CSV files in the directory by their class ID.

    Args:
        caps_data_dir (str): Directory containing the CSV files.

    Returns:
        Dict[int, List[str]]: Dictionary mapping class IDs to lists of file paths.
    """
    class_files: Dict[int, List[str]] = defaultdict(list)
    for file in os.listdir(caps_data_dir):
        if file.endswith(".csv"):
            file_path = os.path.join(caps_data_dir, file)
            try:
                class_id = get_class_from_file(file_path)
                class_files[class_id].append(file_path)
            except Exception as e:
                logger.warning(
                    f"Couldn't determine class for file {file}. Error: {str(e)}"
                )
    return class_files

interleave_class_datasets(class_datasets: List[tf.data.Dataset], num_classes: int) -> tf.data.Dataset

Interleave datasets from different classes to ensure class balance.

Parameters:

Name Type Description Default
class_datasets List[Dataset]

List of datasets, one for each class.

required
num_classes int

The number of classes in the dataset.

required

Returns:

Type Description
Dataset

tf.data.Dataset: An interleaved dataset with balanced class representation.

Source code in src/capfinder/train_etl.py
def interleave_class_datasets(
    class_datasets: List[tf.data.Dataset],
    num_classes: int,
) -> tf.data.Dataset:
    """
    Interleave datasets from different classes to ensure class balance.

    Args:
        class_datasets (List[tf.data.Dataset]): List of datasets, one for each class.
        num_classes (int): The number of classes in the dataset.

    Returns:
        tf.data.Dataset: An interleaved dataset with balanced class representation.
    """
    # Ensure we have the correct number of datasets
    assert (
        len(class_datasets) == num_classes
    ), "Number of datasets should match number of classes"

    def interleave_map_fn(dataset: tf.data.Dataset) -> tf.data.Dataset:
        return dataset.map(lambda x, y, z: (x, y, z))

    # Use the interleave operation to balance the classes
    interleaved_dataset: tf.data.Dataset = tf.data.Dataset.from_tensor_slices(
        class_datasets
    ).interleave(
        interleave_map_fn,
        cycle_length=num_classes,
        block_length=1,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return interleaved_dataset

load_test_dataset_from_csvs(x_file_path: str, y_file_path: str, batch_size: int, target_length: int, dtype: DtypeLiteral) -> tf.data.Dataset

Load test dataset from CSV files.

Parameters:

Name Type Description Default
x_file_path str

Path to the features CSV file.

required
y_file_path str

Path to the labels CSV file.

required
batch_size int

Size of each batch.

required
target_length int

Target length of each time series.

required
dtype DtypeLiteral

Data type for the features as a string.

required

Returns:

Type Description
Dataset

tf.data.Dataset: Test dataset.

Source code in src/capfinder/train_etl.py
def load_test_dataset_from_csvs(
    x_file_path: str,
    y_file_path: str,
    batch_size: int,
    target_length: int,
    dtype: DtypeLiteral,
) -> tf.data.Dataset:
    """
    Load test dataset from CSV files.

    Args:
        x_file_path (str): Path to the features CSV file.
        y_file_path (str): Path to the labels CSV file.
        batch_size (int): Size of each batch.
        target_length (int): Target length of each time series.
        dtype (DtypeLiteral): Data type for the features as a string.

    Returns:
        tf.data.Dataset: Test dataset.
    """
    tf_dtype = get_dtype(dtype)

    def parse_fn(x: tf.Tensor, y: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        x = tf.io.decode_csv(x, record_defaults=[[0.0]] * target_length)
        y = tf.io.decode_csv(y, record_defaults=[[0]])
        return tf.reshape(tf.stack(x), (target_length, 1)), y[0]

    dataset = tf.data.Dataset.zip(
        (
            tf.data.TextLineDataset(x_file_path).skip(1),
            tf.data.TextLineDataset(y_file_path).skip(1),
        )
    )

    return (
        dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size, drop_remainder=True)
        .map(lambda x, y: (tf.cast(x, tf_dtype), y))
        .prefetch(tf.data.AUTOTUNE)
    )

load_train_dataset_from_csvs(x_file_path: str, y_file_path: str, batch_size: int, target_length: int, dtype: tf.DType, train_val_fraction: float = 0.8, use_augmentation: bool = False) -> Tuple[tf.data.Dataset, tf.data.Dataset, int, int]

Load training dataset from CSV files and split into train and validation sets.

Parameters:

Name Type Description Default
x_file_path str

Path to the features CSV file.

required
y_file_path str

Path to the labels CSV file.

required
batch_size int

Size of each batch.

required
target_length int

Target length of each time series.

required
dtype DType

Data type for the features.

required
train_val_fraction float

Fraction of data to use for training. Defaults to 0.8.

0.8
use_augmentation bool

Whether to augment original training examples with warped versions

False

Returns:

Type Description
Tuple[tf.data.Dataset, tf.data.Dataset, int, int]

Train dataset, validation dataset, steps per epoch, and validation steps.

Source code in src/capfinder/train_etl.py
def load_train_dataset_from_csvs(
    x_file_path: str,
    y_file_path: str,
    batch_size: int,
    target_length: int,
    dtype: tf.DType,
    train_val_fraction: float = 0.8,
    use_augmentation: bool = False,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, int, int]:
    """
    Load training dataset from CSV files and split into train and validation sets.

    Args:
        x_file_path (str): Path to the features CSV file.
        y_file_path (str): Path to the labels CSV file.
        batch_size (int): Size of each batch.
        target_length (int): Target length of each time series.
        dtype (tf.DType): Data type for the features.
        train_val_fraction (float, optional): Fraction of data to use for training. Defaults to 0.8.
        use_augmentation (bool): Whether to augment original training examples with warped versions

    Returns:
        Tuple[tf.data.Dataset, tf.data.Dataset, int, int]: Train dataset, validation dataset,
        steps per epoch, and validation steps.
    """

    def parse_fn(x: tf.Tensor, y: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        x = tf.io.decode_csv(x, record_defaults=[[0.0]] * target_length)
        y = tf.io.decode_csv(y, record_defaults=[[0]])
        return tf.reshape(tf.stack(x), (target_length, 1)), y[0]

    dataset = tf.data.Dataset.zip(
        (
            tf.data.TextLineDataset(x_file_path).skip(1),
            tf.data.TextLineDataset(y_file_path).skip(1),
        )
    )

    # Count total examples
    total_examples = count_examples_fast(x_file_path)
    # Calculate train and validation sizes
    train_size, val_size = calculate_sizes(
        total_examples, train_val_fraction, batch_size
    )

    # Split dataset into train and validation
    train_dataset = dataset.take(train_size)
    val_dataset = dataset.skip(train_size).take(val_size)

    # Process and augment the training dataset
    train_dataset = train_dataset.map(
        parse_fn, num_parallel_calls=tf.data.AUTOTUNE
    ).map(lambda x, y: (tf.cast(x, dtype), y))

    if use_augmentation:
        train_dataset = train_dataset.map(
            lambda x, y: augment_example(x, y, dtype)
        ).flat_map(
            lambda x: x
        )  # Flatten the dataset of datasets

    train_dataset = train_dataset.batch(batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )

    # Process the validation dataset (no augmentation)
    val_dataset = (
        val_dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size, drop_remainder=True)
        .map(lambda x, y: (tf.cast(x, dtype), y))
        .prefetch(tf.data.AUTOTUNE)
    )

    # Recalculate steps per epoch
    steps_per_epoch = (train_size * (3 if use_augmentation else 1)) // batch_size
    validation_steps = val_size // batch_size

    return (
        train_dataset,
        val_dataset,
        steps_per_epoch,
        validation_steps,
    )
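
A minimal end-to-end sketch with tiny synthetic CSV files, to make the expected layout explicit (a header row in each file, one comma-separated series of target_length floats per row in the x file, and one integer label per row in the y file); the import path is assumed from the source location above.

import os
import tempfile
import numpy as np
import tensorflow as tf
from capfinder.train_etl import load_train_dataset_from_csvs  # assumed import path

target_length, batch_size = 4, 2
tmp_dir = tempfile.mkdtemp()
x_path = os.path.join(tmp_dir, "train_x.csv")
y_path = os.path.join(tmp_dir, "train_y.csv")

with open(x_path, "w") as fx, open(y_path, "w") as fy:
    fx.write(",".join(f"f{i}" for i in range(target_length)) + "\n")  # header
    fy.write("cap_class\n")                                           # header
    for _ in range(8):                                                # 8 toy examples
        fx.write(",".join(f"{v:.3f}" for v in np.random.rand(target_length)) + "\n")
        fy.write(f"{np.random.randint(0, 4)}\n")

train_ds, val_ds, steps, val_steps = load_train_dataset_from_csvs(
    x_path, y_path, batch_size=batch_size, target_length=target_length,
    dtype=tf.float32, train_val_fraction=0.5,
)
print(steps, val_steps)  # 2 2 (4 training and 4 validation examples)
for x_batch, y_batch in train_ds.take(1):
    print(x_batch.shape, y_batch.shape)  # (2, 4, 1) (2,)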

parse_row(row: Tuple[str, str, str], target_length: int, dtype: tf.DType) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

Parse a row of data and convert it to the appropriate tensor format. Padding and truncation are performed equally on both sides of the time series.

Parameters:

Name Type Description Default
row Tuple[str, str, str]

A tuple containing read_id, cap_class, and timeseries as strings.

required
target_length int

The desired length of the timeseries tensor.

required
dtype DType

The desired data type for the timeseries tensor.

required

Returns:

Type Description
Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

A tuple containing the parsed and formatted tensors for timeseries, cap_class, and read_id.

Source code in src/capfinder/train_etl.py
def parse_row(
    row: Tuple[str, str, str], target_length: int, dtype: tf.DType
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """
    Parse a row of data and convert it to the appropriate tensor format.
    Padding and truncation are performed equally on both sides of the time series.

    Args:
        row (Tuple[str, str, str]): A tuple containing read_id, cap_class, and timeseries as strings.
        target_length (int): The desired length of the timeseries tensor.
        dtype (tf.DType): The desired data type for the timeseries tensor.

    Returns:
        Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: A tuple containing the parsed and formatted tensors for
        timeseries, cap_class, and read_id.
    """
    read_id, cap_class, timeseries = row
    cap_class = tf.strings.to_number(cap_class, out_type=tf.int32)

    # Split the timeseries string and convert to float
    timeseries = tf.strings.split(timeseries, sep=",")
    timeseries = tf.strings.to_number(timeseries, out_type=tf.float32)

    # Get the current length of the timeseries
    current_length = tf.shape(timeseries)[0]

    # Function to pad the timeseries
    def pad_timeseries() -> tf.Tensor:
        pad_amount = target_length - current_length
        pad_left = pad_amount // 2
        pad_right = pad_amount - pad_left
        return tf.pad(
            timeseries,
            [[pad_left, pad_right]],
            constant_values=0.0,
        )

    # Function to truncate the timeseries
    def truncate_timeseries() -> tf.Tensor:
        truncate_amount = current_length - target_length
        truncate_left = truncate_amount // 2
        truncate_right = current_length - (truncate_amount - truncate_left)
        return timeseries[truncate_left:truncate_right]

    # Pad or truncate the timeseries to the target length
    padded = tf.cond(
        current_length >= target_length, truncate_timeseries, pad_timeseries
    )

    padded = tf.reshape(padded, (target_length, 1))

    # Cast to the desired dtype
    if dtype != tf.float32:
        padded = tf.cast(padded, dtype)

    return padded, cap_class, read_id
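
A hedged eager-mode call on one toy row (import path assumed from the source location above): a 4-point series is centre-padded with zeros up to target_length and cast to the requested dtype.

import tensorflow as tf
from capfinder.train_etl import parse_row  # assumed import path

row = ("read_001", "2", "0.1,0.2,0.3,0.4")  # read_id, cap_class, timeseries
timeseries, cap_class, read_id = parse_row(row, target_length=8, dtype=tf.float16)
print(timeseries.shape, timeseries.dtype, int(cap_class))  # (8, 1) <dtype: 'float16'> 2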

read_dataset_version_info(dataset_dir: str) -> Optional[str]

Read the dataset version information from a file.

Parameters:

Name Type Description Default
dataset_dir str

Directory containing the dataset version file.

required

Returns:

Type Description
Optional[str]

Optional[str]: The dataset version if found, None otherwise.

Source code in src/capfinder/train_etl.py
def read_dataset_version_info(dataset_dir: str) -> Optional[str]:
    """
    Read the dataset version information from a file.

    Args:
        dataset_dir (str): Directory containing the dataset version file.

    Returns:
        Optional[str]: The dataset version if found, None otherwise.
    """
    version_file = os.path.join(dataset_dir, "artifact_version.txt")
    if os.path.exists(version_file):
        with open(version_file) as f:
            return f.read().strip()
    return None

train_etl(caps_data_dir: str, dataset_dir: str, target_length: int, dtype: DtypeLiteral, examples_per_class: int, train_test_fraction: float, train_val_fraction: float, num_classes: int, batch_size: int, comet_project_name: str, use_remote_dataset_version: str = '', use_augmentation: bool = False) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int, int, str]

Process the data from multiple class files, create balanced datasets, perform train-test split, and upload to Comet ML.

Parameters:

Name Type Description Default
caps_data_dir str

Directory containing the class CSV files.

required
dataset_dir str

Directory to save the processed dataset.

required
target_length int

The desired length of each time series.

required
dtype DtypeLiteral

The desired data type for the timeseries tensor as a string.

required
examples_per_class int

Number of samples to use per class.

required
train_test_fraction float

Fraction of data to use for training.

required
train_val_fraction float

Fraction of training data to use for validation.

required
num_classes int

Number of classes in the dataset.

required
batch_size int

The number of samples per batch.

required
comet_project_name str

Name of the Comet ML project.

required
use_remote_dataset_version str

Version of the remote dataset to use, if any.

''
use_augmentation bool

Whether to augment original training examples with warped versions

False

Returns:

Type Description
Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int, int, str]

The train, validation, and test datasets, steps per epoch, validation steps, and the dataset version.

Source code in src/capfinder/train_etl.py
def train_etl(
    caps_data_dir: str,
    dataset_dir: str,
    target_length: int,
    dtype: DtypeLiteral,
    examples_per_class: int,
    train_test_fraction: float,
    train_val_fraction: float,
    num_classes: int,
    batch_size: int,
    comet_project_name: str,
    use_remote_dataset_version: str = "",
    use_augmentation: bool = False,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int, int, str]:
    """
    Process the data from multiple class files, create balanced datasets,
    perform train-test split, and upload to Comet ML.

    Args:
        caps_data_dir (str): Directory containing the class CSV files.
        dataset_dir (str): Directory to save the processed dataset.
        target_length (int): The desired length of each time series.
        dtype (DtypeLiteral): The desired data type for the timeseries tensor as a string.
        examples_per_class (int): Number of samples to use per class.
        train_test_fraction (float): Fraction of data to use for training.
        train_val_fraction (float): Fraction of training data to use for validation.
        num_classes (int): Number of classes in the dataset.
        batch_size (int): The number of samples per batch.
        comet_project_name (str): Name of the Comet ML project.
        use_remote_dataset_version (str): Version of the remote dataset to use, if any.
        use_augmentation (bool): Whether to augment original training examples with warped versions
    Returns:
        Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int, int, str]:
        The train, validation, and test datasets, steps per epoch, validation steps, and the dataset version.
    """
    comet_obj = CometArtifactManager(
        project_name=comet_project_name, dataset_dir=dataset_dir
    )
    current_local_version = get_local_dataset_version(dataset_dir)
    reprocess_dataset = False

    # Check if the remote dataset version is different from the local version
    # If yes, download the remote dataset version and load it
    # If no, then load the local dataset
    if use_remote_dataset_version != "":
        if use_remote_dataset_version != current_local_version:
            logger.info(
                f"Downloading remote dataset version: v{use_remote_dataset_version}.. "
            )
            comet_obj.download_remote_dataset(use_remote_dataset_version)
        else:
            logger.info(
                "Remote version is the same as the local version. Loading local dataset..."
            )
        train_dataset, val_dataset, test_dataset, steps_per_epoch, validation_steps = (
            create_train_val_test_datasets_from_train_test_csvs(
                dataset_dir,
                batch_size,
                target_length,
                dtype,
                train_val_fraction,
                use_augmentation,
            )
        )
        write_dataset_version_info(dataset_dir, version=use_remote_dataset_version)
        return (
            train_dataset,
            val_dataset,
            test_dataset,
            steps_per_epoch,
            validation_steps,
            use_remote_dataset_version,
        )

    if current_local_version and use_remote_dataset_version == "":
        logger.info(
            f"A dataset v{current_local_version} was found locally in:\n{dataset_dir}"
        )
        reprocess = input(
            "Do you want to overwrite this local dataset by reprocess the data, and creating a new dataset version? (y/n): "
        )
        if reprocess.lower() == "y" or reprocess.lower() == "yes":
            logger.info("You chose to reprocess the data. Reprocessing data...")
            reprocess_dataset = True
        else:
            logger.info("You chose not to reprocess the data.Loading local dataset..")
            (
                train_dataset,
                val_dataset,
                test_dataset,
                steps_per_epoch,
                validation_steps,
            ) = create_train_val_test_datasets_from_train_test_csvs(
                dataset_dir,
                batch_size,
                target_length,
                dtype,
                train_val_fraction,
                use_augmentation,
            )
            logger.info(f"Local dataset v{current_local_version} loaded successfully.")
            return (
                train_dataset,
                val_dataset,
                test_dataset,
                steps_per_epoch,
                validation_steps,
                current_local_version,
            )

    if reprocess_dataset or (
        not current_local_version and use_remote_dataset_version == ""
    ):
        class_files = group_files_by_class(caps_data_dir)
        min_class, min_rows = find_class_with_least_rows(class_files)
        if examples_per_class is None:
            examples_per_class = min_rows
        else:
            examples_per_class = min(examples_per_class, min_rows)
        logger.info(
            f"Each class in the dataset will have {examples_per_class} examples"
        )

        train_datasets = []
        test_datasets = []
        for class_id, file_paths in class_files.items():
            train_ds, test_ds = create_class_dataset(
                file_paths,
                target_length,
                dtype,
                examples_per_class,
                train_test_fraction,
            )
            train_datasets.append(train_ds)
            test_datasets.append(test_ds)
            logger.info(f"Processed class {map_cap_int_to_name(class_id)}!")

        logger.info("Combining class datasets...")
        train_dataset = combine_datasets(train_datasets)
        test_dataset = combine_datasets(test_datasets)

        logger.info("Interleaving classes for ensuring class balance in each batch...")
        train_dataset = interleave_class_datasets(
            train_datasets, num_classes=num_classes
        )
        test_dataset = interleave_class_datasets(test_datasets, num_classes=num_classes)

        # Calculate total dataset size
        logger.info("Calculating dataset size...")
        total_samples = examples_per_class * num_classes
        train_samples = int(train_test_fraction * total_samples)
        test_samples = total_samples - train_samples

        # Batch the datasets
        logger.info("Batching dataset...")
        train_dataset = train_dataset.batch(batch_size)
        test_dataset = test_dataset.batch(batch_size)

        # Prefetch for performance
        logger.info("Prefetching dataset...")
        train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
        test_dataset = test_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

        logger.info("Saving train/test splits to CSV files...")
        write_dataset_to_csv(train_dataset, dataset_dir, "train")
        write_dataset_to_csv(test_dataset, dataset_dir, "test")
        logger.info(
            f"Train/test splits to CSV files in the following directory:\n{dataset_dir}"
        )

        # Log dataset information to Comet ML
        comet_obj.experiment.log_parameter("target_length", target_length)
        comet_obj.experiment.log_parameter("dtype", dtype)
        comet_obj.experiment.log_parameter("examples_per_class", examples_per_class)
        comet_obj.experiment.log_parameter("train_test_fraction", train_test_fraction)
        comet_obj.experiment.log_parameter("train_val_fraction", train_val_fraction)
        comet_obj.experiment.log_parameter("batch_size", batch_size)
        comet_obj.experiment.log_parameter("num_classes", len(class_files))
        comet_obj.experiment.log_parameter("total_samples", total_samples)
        comet_obj.experiment.log_parameter("train_samples", train_samples)
        comet_obj.experiment.log_parameter("test_samples", test_samples)

        logger.info("Making Comet ML dataset artifacts for uploading...")
        version = upload_dataset_to_comet(dataset_dir, comet_project_name)

        comet_obj.end_comet_experiment()
        logger.info(
            f"Data processed and resulting dataset {version} uploaded to Comet ML successfully."
        )

        logger.info(
            "Creating train, validation, and test datasets from dataset CSV files..."
        )
        train_dataset, val_dataset, test_dataset, steps_per_epoch, validation_steps = (
            create_train_val_test_datasets_from_train_test_csvs(
                dataset_dir,
                batch_size,
                target_length,
                dtype,
                train_val_fraction=train_val_fraction,
                use_augmentation=use_augmentation,
            )
        )

        return (
            train_dataset,
            val_dataset,
            test_dataset,
            steps_per_epoch,
            validation_steps,
            version,
        )

    raise RuntimeError("No valid dataset could be processed. Please check your inputs.")

write_dataset_to_csv(dataset: tf.data.Dataset, dataset_dir: str, train_test: str) -> None

Write a dataset to CSV files.

Parameters:

Name Type Description Default
dataset Dataset

The dataset to write.

required
dataset_dir str

The directory to write the CSV files to.

required
train_test str

Either 'train' or 'test' to indicate the dataset type.

required

Returns:

Type Description
None

None

Source code in src/capfinder/train_etl.py
def write_dataset_to_csv(
    dataset: tf.data.Dataset, dataset_dir: str, train_test: str
) -> None:
    """
    Write a dataset to CSV files.

    Args:
        dataset (tf.data.Dataset): The dataset to write.
        dataset_dir (str): The directory to write the CSV files to.
        train_test (str): Either 'train' or 'test' to indicate the dataset type.

    Returns:
        None
    """
    if not os.path.exists(dataset_dir):
        os.makedirs(dataset_dir)

    x_filename = os.path.join(dataset_dir, f"{train_test}_x.csv")
    y_filename = os.path.join(dataset_dir, f"{train_test}_y.csv")
    read_id_filename = os.path.join(dataset_dir, f"{train_test}_read_id.csv")

    with (
        open(x_filename, "w", newline="") as x_file,
        open(y_filename, "w", newline="") as y_file,
        open(read_id_filename, "w", newline="") as read_id_file,
    ):
        x_writer = csv.writer(x_file)
        y_writer = csv.writer(y_file)
        read_id_writer = csv.writer(read_id_file)

        # Write headers
        x_writer.writerow(
            [f"feature_{i}" for i in range(dataset.element_spec[0].shape[1])]
        )
        y_writer.writerow(["cap_class"])
        read_id_writer.writerow(["read_id"])

        pbar = tqdm(dataset, desc="Processing batches")

        for batch_num, (x, y, read_id) in enumerate(pbar):
            # Convert tensors to numpy arrays
            x_numpy = x.numpy()
            y_numpy = y.numpy()
            read_id_numpy = read_id.numpy()

            # Write x data (features)
            x_writer.writerows(x_numpy.reshape(x_numpy.shape[0], -1))

            # Write y data (labels)
            y_writer.writerows(y_numpy.reshape(-1, 1))

            # Write read_id data
            read_id_writer.writerows(
                [
                    [rid.decode("utf-8") if isinstance(rid, bytes) else rid]
                    for rid in read_id_numpy
                ]
            )
            pbar.set_description(f"Processed {batch_num + 1} batches")
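
A minimal usage sketch with a toy batched dataset of (features, label, read_id) tuples; the shapes, values, and output directory below are illustrative assumptions, and the import path follows the source location shown above:

import tensorflow as tf

from capfinder.train_etl import write_dataset_to_csv

# Toy data: 4 examples, each a signal of length 8 with a single channel
x = tf.random.normal((4, 8, 1))
y = tf.constant([0, 1, 0, 1])
read_ids = tf.constant(["read_a", "read_b", "read_c", "read_d"])

toy_dataset = tf.data.Dataset.from_tensor_slices((x, y, read_ids)).batch(2)

# Writes train_x.csv, train_y.csv, and train_read_id.csv into "toy_dataset_dir"
write_dataset_to_csv(toy_dataset, "toy_dataset_dir", "train")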

write_dataset_version_info(dataset_dir: str, version: str) -> None

Write the dataset version information to a file.

Parameters:

Name Type Description Default
dataset_dir str

Directory to write the version file.

required
version str

Version information to write.

required
Source code in src/capfinder/train_etl.py
def write_dataset_version_info(dataset_dir: str, version: str) -> None:
    """
    Write the dataset version information to a file.

    Args:
        dataset_dir (str): Directory to write the version file.
        version (str): Version information to write.
    """
    version_file = os.path.join(dataset_dir, "artifact_version.txt")
    with open(version_file, "w") as f:
        f.write(version)

training

InterruptCallback

Bases: Callback

Callback to interrupt training based on a global flag.

Source code in src/capfinder/training.py
class InterruptCallback(keras.callbacks.Callback):
    """
    Callback to interrupt training based on a global flag.
    """

    def on_train_batch_end(
        self, batch: int, logs: Optional[Dict[str, float]] = None
    ) -> None:
        """
        Checks the global `stop_training` flag at the end of each batch.
        If True, interrupts training and logs a message.

        Args:
            batch: The current batch index (integer).
            logs: Optional dictionary of training metrics at the end of the batch (default: None).

        Returns:
            None
        """
        global stop_training
        if stop_training:
            logger.info("Training interrupted by user during batch.")
            self.model.stop_training = True

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        """
        Checks the global `stop_training` flag at the end of each epoch.
        If True, interrupts training and logs a message.

        Args:
            epoch: The current epoch index (integer).
            logs: Optional dictionary of training metrics at the end of the epoch (default: None).

        Returns:
            None
        """
        global stop_training
        if stop_training:
            te = epoch + 1
            logger.info(f"Training interrupted by user at the end of epoch {te}")
            self.model.stop_training = True

on_epoch_end(epoch: int, logs: Optional[Dict[str, float]] = None) -> None

Checks the global stop_training flag at the end of each epoch. If True, interrupts training and logs a message.

Parameters:

Name Type Description Default
epoch int

The current epoch index (integer).

required
logs Optional[Dict[str, float]]

Optional dictionary of training metrics at the end of the epoch (default: None).

None

Returns:

Type Description
None

None

Source code in src/capfinder/training.py
def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
    """
    Checks the global `stop_training` flag at the end of each epoch.
    If True, interrupts training and logs a message.

    Args:
        epoch: The current epoch index (integer).
        logs: Optional dictionary of training metrics at the end of the epoch (default: None).

    Returns:
        None
    """
    global stop_training
    if stop_training:
        te = epoch + 1
        logger.info(f"Training interrupted by user at the end of epoch {te}")
        self.model.stop_training = True

on_train_batch_end(batch: int, logs: Optional[Dict[str, float]] = None) -> None

Checks the global stop_training flag at the end of each batch. If True, interrupts training and logs a message.

Parameters:

Name Type Description Default
batch int

The current batch index (integer).

required
logs Optional[Dict[str, float]]

Optional dictionary of training metrics at the end of the batch (default: None).

None

Returns:

Type Description
None

None

Source code in src/capfinder/training.py
def on_train_batch_end(
    self, batch: int, logs: Optional[Dict[str, float]] = None
) -> None:
    """
    Checks the global `stop_training` flag at the end of each batch.
    If True, interrupts training and logs a message.

    Args:
        batch: The current batch index (integer).
        logs: Optional dictionary of training metrics at the end of the batch (default: None).

    Returns:
        None
    """
    global stop_training
    if stop_training:
        logger.info("Training interrupted by user during batch.")
        self.model.stop_training = True

count_batches(dataset: tf.data.Dataset, dataset_name: str) -> int

Count the number of batches in a dataset.

Args: dataset (tf.data.Dataset): The batched dataset to count batches from. dataset_name (str): The name of the dataset.

Returns: int: The number of batches in the dataset.

Source code in src/capfinder/training.py
def count_batches(dataset: tf.data.Dataset, dataset_name: str) -> int:
    """
    Count the number of individual examples in a dataset.

    Args:
    dataset (tf.data.Dataset): The dataset to count examples from.
    dataset_name (str): The name of the dataset.

    Returns:
    int: The number of examples in the dataset.
    """
    count = sum(
        1 for _ in tqdm(dataset, desc=f"Batches in {dataset_name}", unit="batches")
    )
    return count

count_examples(dataset: tf.data.Dataset, dataset_name: str) -> int

Count the number of individual examples in a dataset.

Args: dataset (tf.data.Dataset): The dataset to count examples from. dataset_name (str): The name of the dataset.

Returns: int: The number of examples in the dataset.

Source code in src/capfinder/training.py
def count_examples(dataset: tf.data.Dataset, dataset_name: str) -> int:
    """
    Count the number of individual examples in a dataset.

    Args:
    dataset (tf.data.Dataset): The dataset to count examples from.
    dataset_name (str): The name of the dataset.

    Returns:
    int: The number of examples in the dataset.
    """
    count = sum(
        1 for _ in tqdm(dataset, desc=f"Examples in {dataset_name}", unit="examples")
    )
    return count
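
A quick sketch of how the two counters differ on a toy dataset (illustrative only; the import path follows the source location shown above):

import tensorflow as tf

from capfinder.training import count_batches, count_examples

# 10 examples, batched into groups of 4 -> 3 batches
examples = tf.data.Dataset.range(10)
batched = examples.batch(4)

n_examples = count_examples(examples, "toy (unbatched)")  # 10
n_batches = count_batches(batched, "toy (batched)")       # 3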

generate_unique_name(base_name: str, extension: str) -> str

Generate a unique filename with a datetime suffix.

Parameters:

base_name: str The base name of the file. extension: str The file extension.

Returns:

str The unique filename with the datetime suffix.

Source code in src/capfinder/training.py
def generate_unique_name(base_name: str, extension: str) -> str:
    """Generate a unique filename with a datetime suffix.

    Parameters:
    -----------
    base_name: str
        The base name of the file.
    extension: str
        The file extension.

    Returns:
    --------
    str
        The unique filename with the datetime suffix.
    """
    # Get the current date and time
    current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Append the date and time to the base name
    unique_filename = f"{base_name}_{current_datetime}{extension}"
    return unique_filename
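
Usage sketch (the timestamp in the result depends on when the call is made):

from capfinder.training import generate_unique_name

filename = generate_unique_name("capfinder_model", ".keras")
# e.g. "capfinder_model_20240228_131415.keras"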

handle_interrupt(signum: Optional[int] = None, frame: Optional[object] = None) -> None

Handles interrupt signals (e.g., Ctrl+C) by setting a global flag to stop training.

Parameters:

Name Type Description Default
signum Optional[int]

The signal number (optional).

None
frame Optional[object]

The current stack frame (optional).

None

Returns:

Type Description
None

None

Source code in src/capfinder/training.py
def handle_interrupt(
    signum: Optional[int] = None, frame: Optional[object] = None
) -> None:
    """
    Handles interrupt signals (e.g., Ctrl+C) by setting a global flag to stop training.

    Args:
        signum: The signal number (optional).
        frame: The current stack frame (optional).

    Returns:
        None
    """
    global stop_training
    stop_training = True
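
A minimal sketch of wiring the interrupt handling together: register handle_interrupt for SIGINT and pass an InterruptCallback to fit so that Ctrl+C stops training gracefully. The tiny model and dataset below are stand-ins for illustration only:

import signal

import keras
import tensorflow as tf

from capfinder.training import InterruptCallback, handle_interrupt

# Set the global stop_training flag when the user presses Ctrl+C
signal.signal(signal.SIGINT, handle_interrupt)

# Stand-in model and dataset, purely for illustration
model = keras.Sequential([keras.layers.Dense(2, activation="softmax")])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

x = tf.random.normal((32, 4))
y = tf.random.uniform((32,), maxval=2, dtype=tf.int32)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

# The callback checks stop_training at the end of every batch and epoch
model.fit(train_dataset, epochs=5, callbacks=[InterruptCallback()])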

initialize_tuner(hyper_model: CNNLSTMModel | EncoderModel, tune_params: dict, model_save_dir: str, model_type: ModelType) -> Union[Hyperband, BayesianOptimization, RandomSearch]

Initialize a Keras Tuner object based on the specified tuning strategy.

Parameters:

hyper_model: CapfinderHyperModel An instance of the CapfinderHyperModel class. tune_params: dict A dictionary containing the hyperparameters for tuning. model_save_dir: str The directory where the model should be saved. model_type: ModelType Type of the model to be trained.

Returns:

Union[Hyperband, BayesianOptimization, RandomSearch]: An instance of the Keras Tuner class based on the specified tuning strategy.

Source code in src/capfinder/training.py
def initialize_tuner(
    hyper_model: "CNNLSTMModel | EncoderModel",
    tune_params: dict,
    model_save_dir: str,
    model_type: ModelType,
) -> Union[Hyperband, BayesianOptimization, RandomSearch]:
    """Initialize a Keras Tuner object based on the specified tuning strategy.

    Parameters:
    -----------
    hyper_model: CapfinderHyperModel
        An instance of the CapfinderHyperModel class.
    tune_params: dict
        A dictionary containing the hyperparameters for tuning.
    model_save_dir: str
        The directory where the model should be saved.
    model_type: ModelType
        Type of the model to be trained.

    Returns:
    --------
    Union[Hyperband, BayesianOptimization, RandomSearch]:
        An instance of the Keras Tuner class based on the specified tuning strategy.
    """

    tuning_strategy = tune_params["tuning_strategy"].lower()
    if tuning_strategy not in ["random_search", "bayesian_optimization", "hyperband"]:
        tuning_strategy = "hyperband"
        logger.warning(
            "Invalid tuning strategy. Using Hyperband. Valid options are: 'random_search', 'bayesian_optimization', and 'hyperband'"
        )

    if tuning_strategy == "hyperband":
        logger.info("Using Hyperband tuning strategy...")
        tuner = Hyperband(
            hypermodel=hyper_model.build,
            objective=Objective("val_sparse_categorical_accuracy", direction="max"),
            max_epochs=tune_params["max_epochs_hpt"],
            factor=tune_params["factor"],
            overwrite=tune_params["overwrite"],
            directory=model_save_dir,
            seed=tune_params["seed"],
            project_name=tune_params["comet_project_name"],
        )
    elif tuning_strategy == "bayesian_optimization":
        logger.info("Using Bayesian Optimization tuning strategy...")
        tuner = BayesianOptimization(
            hypermodel=hyper_model.build,
            objective=Objective("val_sparse_categorical_accuracy", direction="max"),
            max_trials=tune_params["max_trials"],
            overwrite=tune_params["overwrite"],
            directory=model_save_dir,
            seed=tune_params["seed"],
            project_name=tune_params["comet_project_name"],
        )
    elif tuning_strategy == "random_search":
        logger.info("Using Random Search tuning strategy...")
        tuner = RandomSearch(
            hypermodel=hyper_model.build,
            objective=Objective("val_sparse_categorical_accuracy", direction="max"),
            max_trials=tune_params["max_trials"],
            overwrite=tune_params["overwrite"],
            directory=model_save_dir,
            seed=tune_params["seed"],
            project_name=tune_params["comet_project_name"],
        )
    return tuner
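
A sketch of a tune_params dictionary containing the keys this function reads, and the corresponding call. The numeric values are illustrative, hyper_model stands in for a hypermodel instance (for example CapfinderHyperModel) that exposes a build method, and the model_type value is a placeholder for a valid ModelType:

from capfinder.training import initialize_tuner

tune_params = {
    "tuning_strategy": "hyperband",
    "max_epochs_hpt": 10,
    "factor": 3,
    "max_trials": 20,  # used by random_search and bayesian_optimization
    "overwrite": True,
    "seed": 42,
    "comet_project_name": "capfinder_tuning",
}

tuner = initialize_tuner(
    hyper_model=hyper_model,      # placeholder hypermodel instance
    tune_params=tune_params,
    model_save_dir="tuner_results/",
    model_type="cnn_lstm",        # placeholder; substitute a valid ModelType value
)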

kill_gpu_processes() -> None

Terminates Python processes running on the NVIDIA GPU.

This function checks whether the nvidia-smi command exists and, if found, attempts to terminate all Python processes utilizing the GPU. If no NVIDIA GPU is found, the termination step is skipped.

Returns:

Type Description
None

None

Source code in src/capfinder/training.py
def kill_gpu_processes() -> None:
    """
    Terminates Python processes running on the NVIDIA GPU.

    This function checks whether the `nvidia-smi` command exists and, if found,
    attempts to terminate all Python processes utilizing the GPU. If no NVIDIA GPU
    is found, the termination step is skipped.

    Returns:
        None
    """
    if shutil.which("nvidia-smi") is None:
        logger.info("No NVIDIA GPU found. Skipping GPU process termination.")
        return

    try:
        # Get the list of GPU processes
        result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
        lines = result.stdout.split("\n")

        # Parse the lines to find PIDs of processes using the GPU
        for line in lines:
            if "python" in line:  # Adjust this if other processes need to be terminated
                parts = line.split()
                pid = parts[4]
                print(f"Terminating process with PID: {pid}")
                subprocess.run(["kill", "-9", pid])
    except Exception as e:
        logger.warning(f"Error occurred while terminating GPU processes: {str(e)}")

save_model(model: keras.Model, base_name: str, extension: str, save_dir: str) -> str

Save the given model to a specified directory.

Parameters:

model: keras.Model The model to be saved. base_name: str The base name for the saved model file. extension: str The file extension for the saved model file. save_dir: str The directory where the model should be saved.

Returns:

str The full path where the model was saved.

Source code in src/capfinder/training.py
def save_model(
    model: keras.Model, base_name: str, extension: str, save_dir: str
) -> str:
    """
    Save the given model to a specified directory.

    Parameters:
    -----------
    model: keras.Model
        The model to be saved.
    base_name: str
        The base name for the saved model file.
    extension: str
        The file extension for the saved model file.
    save_dir: str
        The directory where the model should be saved.

    Returns:
    --------
    str
        The full path where the model was saved.
    """
    # Generate a unique filename for the model
    model_filename = generate_unique_name(base_name, extension)

    # Construct the full path where the model should be saved
    model_save_path = os.path.join(save_dir, model_filename)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)

    # Save the model to the specified path
    model.save(model_save_path)
    logger.info(f"Best model saved to:{model_save_path}")

    # Return the save path
    return model_save_path
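
Usage sketch with a toy model (the base name, extension, and directory are placeholders):

import keras

from capfinder.training import save_model

toy_model = keras.Sequential([keras.layers.Dense(1)])
toy_model.build(input_shape=(None, 4))

saved_path = save_model(
    model=toy_model,
    base_name="capfinder_model",
    extension=".keras",
    save_dir="saved_models/",
)
# saved_path -> e.g. "saved_models/capfinder_model_20240228_131415.keras"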

select_lr_scheduler(lr_scheduler_params: dict, train_size: int) -> Union[keras.callbacks.ReduceLROnPlateau, CyclicLR, SGDRScheduler]

Selects and configures the learning rate scheduler based on the provided parameters.

Parameters:

Name Type Description Default
lr_scheduler_params dict

Configuration parameters for the learning rate scheduler.

required
train_size int

Number of training examples, used for step size calculations.

required

Returns:

Type Description
Union[ReduceLROnPlateau, CyclicLR, SGDRScheduler]

Union[keras.callbacks.ReduceLROnPlateau, CyclicLR, SGDRScheduler]: The selected learning rate scheduler.

Source code in src/capfinder/training.py
def select_lr_scheduler(
    lr_scheduler_params: dict, train_size: int
) -> Union[keras.callbacks.ReduceLROnPlateau, CyclicLR, SGDRScheduler]:
    """
    Selects and configures the learning rate scheduler based on the provided parameters.

    Args:
        lr_scheduler_params (dict): Configuration parameters for the learning rate scheduler.
        train_size (int): Number of training examples, used for step size calculations.

    Returns:
        Union[keras.callbacks.ReduceLROnPlateau, CyclicLR, SGDRScheduler]: The selected learning rate scheduler.
    """
    scheduler_type = lr_scheduler_params["type"]

    if scheduler_type == "reduce_lr_on_plateau":
        rlr_params = lr_scheduler_params["reduce_lr_on_plateau"]
        return keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=rlr_params["factor"],
            patience=rlr_params["patience"],
            verbose=1,
            mode="min",
            min_lr=rlr_params["min_lr"],
        )

    elif scheduler_type == "cyclic_lr":
        clr_params = lr_scheduler_params["cyclic_lr"]
        return CyclicLR(
            base_lr=clr_params["base_lr"],
            max_lr=clr_params["max_lr"],
            step_size=train_size * clr_params["step_size_factor"],
            mode=clr_params["mode"],
        )

    elif scheduler_type == "sgdr":
        sgdr_params = lr_scheduler_params["sgdr"]
        return SGDRScheduler(
            min_lr=sgdr_params["min_lr"],
            max_lr=sgdr_params["max_lr"],
            steps_per_epoch=train_size,
            lr_decay=sgdr_params["lr_decay"],
            cycle_length=sgdr_params["cycle_length"],
            mult_factor=sgdr_params["mult_factor"],
        )

    else:
        logger.warning(
            f"Unknown scheduler type: {scheduler_type}. Using ReduceLROnPlateau as default."
        )
        return keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.5,
            patience=5,
            verbose=1,
            mode="min",
            min_lr=1e-6,
        )
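
A sketch of the expected lr_scheduler_params layout, using the keys read by this function; the numeric values and the mode string are illustrative only:

from capfinder.training import select_lr_scheduler

lr_scheduler_params = {
    "type": "cyclic_lr",
    "reduce_lr_on_plateau": {"factor": 0.5, "patience": 5, "min_lr": 1e-6},
    "cyclic_lr": {
        "base_lr": 1e-4,
        "max_lr": 1e-2,
        "step_size_factor": 4,  # step_size = train_size * step_size_factor
        "mode": "triangular",
    },
    "sgdr": {
        "min_lr": 1e-5,
        "max_lr": 1e-2,
        "lr_decay": 0.9,
        "cycle_length": 5,
        "mult_factor": 1.5,
    },
}

scheduler = select_lr_scheduler(lr_scheduler_params, train_size=1000)
# The returned scheduler can be passed to model.fit(..., callbacks=[scheduler])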

set_data_distributed_training() -> None

Set JAX as the backend for Keras training, with distributed training if multiple CUDA devices are available.

This function checks for available CUDA devices and sets up distributed training only if more than one is found.

Returns:

None

Source code in src/capfinder/training.py
def set_data_distributed_training() -> None:
    """
    Set JAX as the backend for Keras training, with distributed training if multiple CUDA devices are available.

    This function checks for available CUDA devices and sets up distributed training only if more than one is found.

    Returns:
    --------
    None
    """
    # Set the Keras backend to JAX
    logger.info(f"Backend for training: {keras.backend.backend()}")

    # Retrieve available devices
    all_devices = jax.devices()
    cuda_devices = [d for d in all_devices if d.platform == "gpu"]

    # Log available devices
    for device in all_devices:
        logger.info(f"Device available: {device}, Type: {device.platform}")

    if len(cuda_devices) > 1:
        keras.mixed_precision.set_global_policy("mixed_float16")
        logger.info(
            f"({len(cuda_devices)}) CUDA devices detected. Setting up data distributed training."
        )

        # Define a 1D device mesh for data parallelism using only CUDA devices
        mesh_1d = keras.distribution.DeviceMesh(
            shape=(len(cuda_devices),), axis_names=["data"], devices=cuda_devices
        )

        # Create a DataParallel distribution
        data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)

        # Set the global distribution
        keras.distribution.set_distribution(data_parallel)

        logger.info("Distributed training setup complete.")
    elif len(cuda_devices) == 1:
        keras.mixed_precision.set_global_policy("mixed_float16")
        logger.info(
            "Single CUDA device detected. Using standard (non-distributed) training."
        )
    else:
        logger.info("No CUDA devices detected. Training will proceed on CPU.")
        keras.mixed_precision.set_global_policy("float32")

upload_download

CometArtifactManager

Manages the creation, uploading, and downloading of dataset artifacts using Comet ML.

Source code in src/capfinder/upload_download.py
class CometArtifactManager:
    """
    Manages the creation, uploading, and downloading of dataset artifacts using Comet ML.
    """

    def __init__(self, project_name: str, dataset_dir: str) -> None:
        """
        Initialize the CometArtifactManager.

        Args:
            project_name (str): The name of the Comet ML project.
            dataset_dir (str): The directory containing the dataset.
        """
        self.project_name = project_name
        self.dataset_dir = dataset_dir
        self.artifact_name = "cap_data"
        self.experiment = self.initialize_comet_ml_experiment()
        self.artifact: Optional[comet_ml.Artifact] = None
        self.tmp_dir: Optional[str] = None
        self.chunk_files: List[str] = []
        self.upload_lock = threading.Lock()
        self.upload_threads: List[threading.Thread] = []

    def initialize_comet_ml_experiment(self) -> comet_ml.Experiment:
        """
        Initialize and return a Comet ML experiment.

        Returns:
            comet_ml.Experiment: The initialized Comet ML experiment.

        Raises:
            ValueError: If the COMET_API_KEY environment variable is not set.
        """
        logger.info(f"Initializing CometML experiment for project: {self.project_name}")
        comet_api_key = os.getenv("COMET_API_KEY")
        if not comet_api_key:
            raise ValueError("COMET_API_KEY environment variable is not set.")
        return comet_ml.Experiment(
            api_key=comet_api_key,
            project_name=self.project_name,
            display_summary_level=0,
        )

    def create_artifact(self) -> comet_ml.Artifact:
        """
        Create and return a Comet ML artifact.

        Returns:
            comet_ml.Artifact: The created Comet ML artifact.
        """
        logger.info(f"Creating CometML artifact: {self.artifact_name}")
        self.artifact = comet_ml.Artifact(
            name=self.artifact_name,
            artifact_type="dataset",
            metadata={"task": "RNA caps classification"},
        )
        return self.artifact

    def upload_chunk(
        self, chunk_file: str, chunk_number: int, total_chunks: int
    ) -> None:
        """
        Upload a chunk of the dataset to the Comet ML artifact.

        Args:
            chunk_file (str): The path to the chunk file.
            chunk_number (int): The number of the current chunk.
            total_chunks (int): The total number of chunks.
        """
        with self.upload_lock:
            if self.artifact is None:
                logger.error(
                    "Artifact is not initialized. Call create_artifact() first."
                )
                return
            self.artifact.add(
                local_path_or_data=chunk_file,
                logical_path=os.path.basename(chunk_file),
                metadata={"chunk": chunk_number, "total_chunks": total_chunks},
            )
        logger.info(f"Added chunk to artifact: {os.path.basename(chunk_file)}")

    def create_targz_chunks(
        self, chunk_size: int = 200 * 1024 * 1024
    ) -> Tuple[List[str], str, int]:
        """
        Create tar.gz chunks of the dataset.

        Args:
            chunk_size (int, optional): The size of each chunk in bytes. Defaults to 200 MB.

        Returns:
            Tuple[List[str], str, int]: A tuple containing the list of chunk files,
                                        the temporary directory path, and the total number of chunks.
        """
        logger.info("Creating tar.gz chunks of the dataset...")
        self.tmp_dir = tempfile.mkdtemp()
        logger.info(f"Temporary directory created at: {self.tmp_dir}")

        # Create a single tar.gz file of the entire dataset
        tar_path = os.path.join(self.tmp_dir, "dataset.tar.gz")
        with tarfile.open(tar_path, "w:gz") as tar:
            for root, _, files in os.walk(self.dataset_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, self.dataset_dir)
                    tar.add(file_path, arcname=arcname)

        # Split the tar.gz file into chunks
        chunk_number = 0
        with open(tar_path, "rb") as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    break
                chunk_file = os.path.join(
                    self.tmp_dir, f"dataset.tar.gz.{chunk_number:03d}"
                )
                with open(chunk_file, "wb") as chunk_f:
                    chunk_f.write(chunk)
                self.chunk_files.append(chunk_file)
                chunk_number += 1

        total_chunks = chunk_number
        logger.info(f"Created {total_chunks} tar.gz chunks")

        # Calculate hash of the original tar.gz file
        tar_hash = calculate_file_hash(tar_path)

        # Store tar hash
        hash_file_path = os.path.join(self.tmp_dir, "tar_hash.json")
        with open(hash_file_path, "w") as f:
            json.dump({"tar_hash": tar_hash}, f)
        logger.info(f"Tar hash stored in: {hash_file_path}")

        return self.chunk_files, self.tmp_dir, total_chunks

    def make_comet_artifacts(self) -> None:
        """
        Create and upload Comet ML artifacts.
        """
        self.create_artifact()
        self.chunk_files, self.tmp_dir, total_chunks = self.create_targz_chunks()

        # Upload chunks
        for i, chunk_file in enumerate(self.chunk_files):
            upload_thread = threading.Thread(
                target=self.upload_chunk, args=(chunk_file, i, total_chunks)
            )
            upload_thread.start()
            self.upload_threads.append(upload_thread)

        # Wait for all upload threads to complete
        for thread in self.upload_threads:
            thread.join()

        # Add tar hash to artifact
        hash_file_path = os.path.join(self.tmp_dir, "tar_hash.json")
        self.artifact.add(  # type: ignore
            local_path_or_data=hash_file_path,
            logical_path="tar_hash.json",
            metadata={"content": "Tar hash for integrity check"},
        )
        logger.info("Added tar hash to artifact")

    def log_artifacts_to_comet(self) -> Optional[str]:
        """
        Log artifacts to Comet ML.

        Returns:
            Optional[str]: The version of the logged artifact, or None if logging failed.
        """
        if self.experiment is not None and self.artifact is not None:
            logger.info("Logging artifact to CometML...")
            art = self.experiment.log_artifact(self.artifact)
            version = f"{art.version.major}.{art.version.minor}.{art.version.patch}"
            logger.info(f"Artifact logged successfully. Version: {version}")
            self.store_artifact_version_to_file(version)

            logger.info(
                "Artifact upload initiated. It will continue in the background."
            )

            # Clean up the temporary directory
            shutil.rmtree(self.tmp_dir)  # type: ignore
            logger.info(f"Temporary directory cleaned up: {self.tmp_dir}")

            return version
        return None

    def store_artifact_version_to_file(self, version: str) -> None:
        """
        Store the artifact version in a file.

        Args:
            version (str): The version of the artifact to store.
        """
        version_file = os.path.join(self.dataset_dir, "artifact_version.txt")
        with open(version_file, "w") as f:
            f.write(version)
        logger.info(f"Artifact version {version} written to {version_file}")

    def download_remote_dataset(self, version: str, max_retries: int = 3) -> None:
        """
        Download a remote dataset from Comet ML.

        Args:
            version (str): The version of the dataset to download.
            max_retries (int, optional): The maximum number of download attempts. Defaults to 3.

        Raises:
            Exception: If the download fails after the maximum number of retries.
        """
        logger.info(f"Downloading remote dataset v{version}...")

        for attempt in range(max_retries):
            try:
                art = self.experiment.get_artifact(
                    artifact_name=self.artifact_name, version_or_alias=version
                )

                tmp_dir = tempfile.mkdtemp()
                logger.info(f"Temporary directory for download created at: {tmp_dir}")
                art.download(tmp_dir)

                # Combine all chunks back into a single tar.gz file
                tar_path = os.path.join(tmp_dir, "dataset.tar.gz")
                with open(tar_path, "wb") as tar_file:
                    chunk_files = sorted(
                        [
                            f
                            for f in os.listdir(tmp_dir)
                            if f.startswith("dataset.tar.gz.")
                        ]
                    )
                    for chunk_file in chunk_files:
                        with open(os.path.join(tmp_dir, chunk_file), "rb") as chunk:
                            tar_file.write(chunk.read())

                # Verify tar.gz integrity
                with open(os.path.join(tmp_dir, "tar_hash.json")) as f:
                    original_hash = json.load(f)["tar_hash"]
                current_hash = calculate_file_hash(tar_path)
                if current_hash != original_hash:
                    raise ValueError("Tar file integrity check failed")  # noqa: TRY301

                # Extract the tar.gz file
                with tarfile.open(tar_path, "r:gz") as tar:
                    tar.extractall(path=self.dataset_dir)

                logger.info(
                    "Remote dataset downloaded, verified, and extracted successfully."
                )
                return

            except Exception as e:
                logger.error(f"Attempt {attempt + 1} failed: {str(e)}")  # noqa: G003
                if attempt < max_retries - 1:
                    wait_time = (2**attempt) + random.uniform(
                        0, 1
                    )  # Exponential backoff with jitter
                    logger.info(f"Retrying in {wait_time:.2f} seconds...")
                    time.sleep(wait_time)
                else:
                    logger.error("Max retries reached. Download failed.")
                    raise

            finally:
                # Clean up
                if "tmp_dir" in locals():
                    shutil.rmtree(tmp_dir)
                    logger.info(f"Temporary directory cleaned up: {tmp_dir}")

        raise Exception(  # noqa: TRY002
            "Failed to download and extract the dataset after maximum retries."
        )

    def end_comet_experiment(self) -> None:
        """
        End the Comet ML experiment.
        """
        logger.info("Ending CometML experiment...")
        self.experiment.end()
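
A minimal end-to-end sketch of using this class. It requires the COMET_API_KEY environment variable to be set; the project name and dataset directory below are placeholders:

from capfinder.upload_download import CometArtifactManager

manager = CometArtifactManager(
    project_name="capfinder",        # placeholder project name
    dataset_dir="/path/to/dataset",  # placeholder dataset directory
)

# Package the dataset into tar.gz chunks and attach them to a Comet artifact
manager.make_comet_artifacts()

# Log the artifact, record its version next to the dataset, and finish up
version = manager.log_artifacts_to_comet()
manager.end_comet_experiment()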

__init__(project_name: str, dataset_dir: str) -> None

Initialize the CometArtifactManager.

Parameters:

Name Type Description Default
project_name str

The name of the Comet ML project.

required
dataset_dir str

The directory containing the dataset.

required
Source code in src/capfinder/upload_download.py
def __init__(self, project_name: str, dataset_dir: str) -> None:
    """
    Initialize the CometArtifactManager.

    Args:
        project_name (str): The name of the Comet ML project.
        dataset_dir (str): The directory containing the dataset.
    """
    self.project_name = project_name
    self.dataset_dir = dataset_dir
    self.artifact_name = "cap_data"
    self.experiment = self.initialize_comet_ml_experiment()
    self.artifact: Optional[comet_ml.Artifact] = None
    self.tmp_dir: Optional[str] = None
    self.chunk_files: List[str] = []
    self.upload_lock = threading.Lock()
    self.upload_threads: List[threading.Thread] = []

create_artifact() -> comet_ml.Artifact

Create and return a Comet ML artifact.

Returns:

Type Description
Artifact

comet_ml.Artifact: The created Comet ML artifact.

Source code in src/capfinder/upload_download.py
def create_artifact(self) -> comet_ml.Artifact:
    """
    Create and return a Comet ML artifact.

    Returns:
        comet_ml.Artifact: The created Comet ML artifact.
    """
    logger.info(f"Creating CometML artifact: {self.artifact_name}")
    self.artifact = comet_ml.Artifact(
        name=self.artifact_name,
        artifact_type="dataset",
        metadata={"task": "RNA caps classification"},
    )
    return self.artifact

create_targz_chunks(chunk_size: int = 200 * 1024 * 1024) -> Tuple[List[str], str, int]

Create tar.gz chunks of the dataset.

Parameters:

Name Type Description Default
chunk_size int

The size of each chunk in bytes. Defaults to 200 MB.

200 * 1024 * 1024

Returns:

Type Description
Tuple[List[str], str, int]

Tuple[List[str], str, int]: A tuple containing the list of chunk files, the temporary directory path, and the total number of chunks.

Source code in src/capfinder/upload_download.py
def create_targz_chunks(
    self, chunk_size: int = 200 * 1024 * 1024
) -> Tuple[List[str], str, int]:
    """
    Create tar.gz chunks of the dataset.

    Args:
        chunk_size (int, optional): The size of each chunk in bytes. Defaults to 200 MB.

    Returns:
        Tuple[List[str], str, int]: A tuple containing the list of chunk files,
                                    the temporary directory path, and the total number of chunks.
    """
    logger.info("Creating tar.gz chunks of the dataset...")
    self.tmp_dir = tempfile.mkdtemp()
    logger.info(f"Temporary directory created at: {self.tmp_dir}")

    # Create a single tar.gz file of the entire dataset
    tar_path = os.path.join(self.tmp_dir, "dataset.tar.gz")
    with tarfile.open(tar_path, "w:gz") as tar:
        for root, _, files in os.walk(self.dataset_dir):
            for file in files:
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, self.dataset_dir)
                tar.add(file_path, arcname=arcname)

    # Split the tar.gz file into chunks
    chunk_number = 0
    with open(tar_path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            chunk_file = os.path.join(
                self.tmp_dir, f"dataset.tar.gz.{chunk_number:03d}"
            )
            with open(chunk_file, "wb") as chunk_f:
                chunk_f.write(chunk)
            self.chunk_files.append(chunk_file)
            chunk_number += 1

    total_chunks = chunk_number
    logger.info(f"Created {total_chunks} tar.gz chunks")

    # Calculate hash of the original tar.gz file
    tar_hash = calculate_file_hash(tar_path)

    # Store tar hash
    hash_file_path = os.path.join(self.tmp_dir, "tar_hash.json")
    with open(hash_file_path, "w") as f:
        json.dump({"tar_hash": tar_hash}, f)
    logger.info(f"Tar hash stored in: {hash_file_path}")

    return self.chunk_files, self.tmp_dir, total_chunks

download_remote_dataset(version: str, max_retries: int = 3) -> None

Download a remote dataset from Comet ML.

Parameters:

Name Type Description Default
version str

The version of the dataset to download.

required
max_retries int

The maximum number of download attempts. Defaults to 3.

3

Raises:

Type Description
Exception

If the download fails after the maximum number of retries.

Source code in src/capfinder/upload_download.py
def download_remote_dataset(self, version: str, max_retries: int = 3) -> None:
    """
    Download a remote dataset from Comet ML.

    Args:
        version (str): The version of the dataset to download.
        max_retries (int, optional): The maximum number of download attempts. Defaults to 3.

    Raises:
        Exception: If the download fails after the maximum number of retries.
    """
    logger.info(f"Downloading remote dataset v{version}...")

    for attempt in range(max_retries):
        try:
            art = self.experiment.get_artifact(
                artifact_name=self.artifact_name, version_or_alias=version
            )

            tmp_dir = tempfile.mkdtemp()
            logger.info(f"Temporary directory for download created at: {tmp_dir}")
            art.download(tmp_dir)

            # Combine all chunks back into a single tar.gz file
            tar_path = os.path.join(tmp_dir, "dataset.tar.gz")
            with open(tar_path, "wb") as tar_file:
                chunk_files = sorted(
                    [
                        f
                        for f in os.listdir(tmp_dir)
                        if f.startswith("dataset.tar.gz.")
                    ]
                )
                for chunk_file in chunk_files:
                    with open(os.path.join(tmp_dir, chunk_file), "rb") as chunk:
                        tar_file.write(chunk.read())

            # Verify tar.gz integrity
            with open(os.path.join(tmp_dir, "tar_hash.json")) as f:
                original_hash = json.load(f)["tar_hash"]
            current_hash = calculate_file_hash(tar_path)
            if current_hash != original_hash:
                raise ValueError("Tar file integrity check failed")  # noqa: TRY301

            # Extract the tar.gz file
            with tarfile.open(tar_path, "r:gz") as tar:
                tar.extractall(path=self.dataset_dir)

            logger.info(
                "Remote dataset downloaded, verified, and extracted successfully."
            )
            return

        except Exception as e:
            logger.error(f"Attempt {attempt + 1} failed: {str(e)}")  # noqa: G003
            if attempt < max_retries - 1:
                wait_time = (2**attempt) + random.uniform(
                    0, 1
                )  # Exponential backoff with jitter
                logger.info(f"Retrying in {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            else:
                logger.error("Max retries reached. Download failed.")
                raise

        finally:
            # Clean up
            if "tmp_dir" in locals():
                shutil.rmtree(tmp_dir)
                logger.info(f"Temporary directory cleaned up: {tmp_dir}")

    raise Exception(  # noqa: TRY002
        "Failed to download and extract the dataset after maximum retries."
    )

end_comet_experiment() -> None

End the Comet ML experiment.

Source code in src/capfinder/upload_download.py
def end_comet_experiment(self) -> None:
    """
    End the Comet ML experiment.
    """
    logger.info("Ending CometML experiment...")
    self.experiment.end()

initialize_comet_ml_experiment() -> comet_ml.Experiment

Initialize and return a Comet ML experiment.

Returns:

Type Description
Experiment

comet_ml.Experiment: The initialized Comet ML experiment.

Raises:

Type Description
ValueError

If the COMET_API_KEY environment variable is not set.

Source code in src/capfinder/upload_download.py
def initialize_comet_ml_experiment(self) -> comet_ml.Experiment:
    """
    Initialize and return a Comet ML experiment.

    Returns:
        comet_ml.Experiment: The initialized Comet ML experiment.

    Raises:
        ValueError: If the COMET_API_KEY environment variable is not set.
    """
    logger.info(f"Initializing CometML experiment for project: {self.project_name}")
    comet_api_key = os.getenv("COMET_API_KEY")
    if not comet_api_key:
        raise ValueError("COMET_API_KEY environment variable is not set.")
    return comet_ml.Experiment(
        api_key=comet_api_key,
        project_name=self.project_name,
        display_summary_level=0,
    )

log_artifacts_to_comet() -> Optional[str]

Log artifacts to Comet ML.

Returns:

Type Description
Optional[str]

Optional[str]: The version of the logged artifact, or None if logging failed.

Source code in src/capfinder/upload_download.py
def log_artifacts_to_comet(self) -> Optional[str]:
    """
    Log artifacts to Comet ML.

    Returns:
        Optional[str]: The version of the logged artifact, or None if logging failed.
    """
    if self.experiment is not None and self.artifact is not None:
        logger.info("Logging artifact to CometML...")
        art = self.experiment.log_artifact(self.artifact)
        version = f"{art.version.major}.{art.version.minor}.{art.version.patch}"
        logger.info(f"Artifact logged successfully. Version: {version}")
        self.store_artifact_version_to_file(version)

        logger.info(
            "Artifact upload initiated. It will continue in the background."
        )

        # Clean up the temporary directory
        shutil.rmtree(self.tmp_dir)  # type: ignore
        logger.info(f"Temporary directory cleaned up: {self.tmp_dir}")

        return version
    return None

make_comet_artifacts() -> None

Create and upload Comet ML artifacts.

Source code in src/capfinder/upload_download.py
def make_comet_artifacts(self) -> None:
    """
    Create and upload Comet ML artifacts.
    """
    self.create_artifact()
    self.chunk_files, self.tmp_dir, total_chunks = self.create_targz_chunks()

    # Upload chunks
    for i, chunk_file in enumerate(self.chunk_files):
        upload_thread = threading.Thread(
            target=self.upload_chunk, args=(chunk_file, i, total_chunks)
        )
        upload_thread.start()
        self.upload_threads.append(upload_thread)

    # Wait for all upload threads to complete
    for thread in self.upload_threads:
        thread.join()

    # Add tar hash to artifact
    hash_file_path = os.path.join(self.tmp_dir, "tar_hash.json")
    self.artifact.add(  # type: ignore
        local_path_or_data=hash_file_path,
        logical_path="tar_hash.json",
        metadata={"content": "Tar hash for integrity check"},
    )
    logger.info("Added tar hash to artifact")

store_artifact_version_to_file(version: str) -> None

Store the artifact version in a file.

Parameters:

Name Type Description Default
version str

The version of the artifact to store.

required
Source code in src/capfinder/upload_download.py
def store_artifact_version_to_file(self, version: str) -> None:
    """
    Store the artifact version in a file.

    Args:
        version (str): The version of the artifact to store.
    """
    version_file = os.path.join(self.dataset_dir, "artifact_version.txt")
    with open(version_file, "w") as f:
        f.write(version)
    logger.info(f"Artifact version {version} written to {version_file}")

upload_chunk(chunk_file: str, chunk_number: int, total_chunks: int) -> None

Upload a chunk of the dataset to the Comet ML artifact.

Parameters:

Name Type Description Default
chunk_file str

The path to the chunk file.

required
chunk_number int

The number of the current chunk.

required
total_chunks int

The total number of chunks.

required
Source code in src/capfinder/upload_download.py
def upload_chunk(
    self, chunk_file: str, chunk_number: int, total_chunks: int
) -> None:
    """
    Upload a chunk of the dataset to the Comet ML artifact.

    Args:
        chunk_file (str): The path to the chunk file.
        chunk_number (int): The number of the current chunk.
        total_chunks (int): The total number of chunks.
    """
    with self.upload_lock:
        if self.artifact is None:
            logger.error(
                "Artifact is not initialized. Call create_artifact() first."
            )
            return
        self.artifact.add(
            local_path_or_data=chunk_file,
            logical_path=os.path.basename(chunk_file),
            metadata={"chunk": chunk_number, "total_chunks": total_chunks},
        )
    logger.info(f"Added chunk to artifact: {os.path.basename(chunk_file)}")

calculate_file_hash(file_path: str) -> str

Calculate the SHA256 hash of a file.

Parameters:

Name Type Description Default
file_path str

The path to the file.

required

Returns:

Name Type Description
str str

The hexadecimal representation of the file's SHA256 hash.

Source code in src/capfinder/upload_download.py
def calculate_file_hash(file_path: str) -> str:
    """
    Calculate the SHA256 hash of a file.

    Args:
        file_path (str): The path to the file.

    Returns:
        str: The hexadecimal representation of the file's SHA256 hash.
    """
    sha256_hash = hashlib.sha256()
    with open(file_path, "rb") as f:
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(byte_block)
    return sha256_hash.hexdigest()
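
Usage sketch (the file path is a placeholder):

from capfinder.upload_download import calculate_file_hash

digest = calculate_file_hash("dataset.tar.gz")
# digest is a 64-character hexadecimal SHA256 string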

download_dataset_from_comet(dataset_dir: str, project_name: str, version: str) -> None

Download a dataset from Comet ML.

Parameters:

Name Type Description Default
dataset_dir str

The directory to download the dataset to.

required
project_name str

The name of the Comet ML project.

required
version str

The version of the dataset to download.

required
Source code in src/capfinder/upload_download.py
def download_dataset_from_comet(
    dataset_dir: str, project_name: str, version: str
) -> None:
    """
    Download a dataset from Comet ML.

    Args:
        dataset_dir (str): The directory to download the dataset to.
        project_name (str): The name of the Comet ML project.
        version (str): The version of the dataset to download.
    """
    comet_obj = CometArtifactManager(project_name=project_name, dataset_dir=dataset_dir)
    comet_obj.download_remote_dataset(version)
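
Usage sketch (directory, project name, and version are placeholders; COMET_API_KEY must be set):

from capfinder.upload_download import download_dataset_from_comet

download_dataset_from_comet(
    dataset_dir="/path/to/local/dataset",
    project_name="capfinder",
    version="1.0.0",
)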

upload_dataset_to_comet(dataset_dir: str, project_name: str) -> str

Upload a dataset to Comet ML.

Parameters:

Name Type Description Default
dataset_dir str

The directory containing the dataset to upload.

required
project_name str

The name of the Comet ML project.

required

Returns:

Name Type Description
str str

The version of the uploaded dataset, or an empty string if the upload failed.

Source code in src/capfinder/upload_download.py
def upload_dataset_to_comet(dataset_dir: str, project_name: str) -> str:
    """
    Upload a dataset to Comet ML.

    Args:
        dataset_dir (str): The directory containing the dataset to upload.
        project_name (str): The name of the Comet ML project.

    Returns:
        str: The version of the uploaded dataset, or an empty string if the upload failed.
    """
    comet_obj = CometArtifactManager(project_name=project_name, dataset_dir=dataset_dir)

    logger.info("Making Comet ML dataset artifacts for uploading...")
    comet_obj.make_comet_artifacts()

    logger.info("Logging artifacts to Comet ML...")
    version = comet_obj.log_artifacts_to_comet()

    if version:
        logger.info(
            f"Dataset version {version} logged to Comet ML successfully. Upload will continue in the background."
        )
        return version
    else:
        logger.error("Failed to log dataset to Comet ML.")
        return ""

utils

The module contains some common utility functions used in the capfinder package.

Author: Adnan M. Niazi Date: 2024-02-28

ensure_config_dir() -> None

Ensure the configuration directory exists.

Source code in src/capfinder/utils.py
def ensure_config_dir() -> None:
    """Ensure the configuration directory exists."""
    CONFIG_DIR.mkdir(parents=True, exist_ok=True)

file_opener(filename: str) -> Union[IO[str], IO[bytes]]

Open a file for reading. If the file is compressed, use gzip to open it.

Parameters:

Name Type Description Default
filename str

The path to the file to open.

required

Returns:

Type Description
Union[IO[str], IO[bytes]]

file object: A file object that can be used for reading.

Source code in src/capfinder/utils.py
def file_opener(filename: str) -> Union[IO[str], IO[bytes]]:
    """
    Open a file for reading. If the file is compressed, use gzip to open it.

    Args:
        filename (str): The path to the file to open.

    Returns:
        file object: A file object that can be used for reading.
    """
    if filename.endswith(".gz"):
        # Compressed FASTQ file (gzip)
        return gzip.open(filename, "rt")
    else:
        # Uncompressed FASTQ file
        return open(filename)
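
Usage sketch; the call is identical for plain and gzip-compressed files (filenames are placeholders):

from capfinder.utils import file_opener

with file_opener("reads.fastq.gz") as fh:
    first_record_header = fh.readline()

with file_opener("reads.fastq") as fh:
    first_record_header = fh.readline()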

get_dtype(dtype: str) -> Type[np.floating]

Returns the numpy floating type corresponding to the provided dtype string.

If the provided dtype string is not valid, a warning is logged and np.float32 is returned as default.

Parameters: dtype (str): The dtype string to convert to a numpy floating type.

Returns: Type[np.floating]: The corresponding numpy floating type.

Source code in src/capfinder/utils.py
def get_dtype(dtype: str) -> Type[np.floating]:
    """
    Returns the numpy floating type corresponding to the provided dtype string.

    If the provided dtype string is not valid, a warning is logged and np.float32 is returned as default.

    Parameters:
    dtype (str): The dtype string to convert to a numpy floating type.

    Returns:
    Type[np.floating]: The corresponding numpy floating type.
    """
    valid_dtypes = {
        "float16": np.float16,
        "float32": np.float32,
        "float64": np.float64,
    }

    if dtype in valid_dtypes:
        dt = valid_dtypes[dtype]
    else:
        logger.warning('You provided an invalid dtype. Using "float32" as default.')
        dt = np.float32

    return cast(Type[np.floating], dt)  # Cast dt to the expected type
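
Usage sketch:

import numpy as np

from capfinder.utils import get_dtype

dt = get_dtype("float16")         # -> np.float16
arr = np.zeros(4, dtype=dt)

fallback = get_dtype("bfloat16")  # not a valid option: logs a warning, returns np.float32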

get_next_available_cap_number() -> int

Find the next available cap number in the sequence.

Returns: int: The next available cap number.

Source code in src/capfinder/utils.py
def get_next_available_cap_number() -> int:
    """
    Find the next available cap number in the sequence.

    Returns:
    int: The next available cap number.
    """
    global CAP_MAPPING

    existing_caps = set(CAP_MAPPING.keys())
    existing_caps.discard(-99)  # Remove the special 'unknown' cap
    if not existing_caps:
        return 0
    max_cap = max(existing_caps)
    next_cap = max_cap + 1
    return next_cap

get_terminal_width() -> int

Get the width of the terminal.

Returns:

Name Type Description
int int

The width of the terminal in columns. Defaults to 80 if not available.

Source code in src/capfinder/utils.py
def get_terminal_width() -> int:
    """
    Get the width of the terminal.

    Returns:
        int: The width of the terminal in columns. Defaults to 80 if not available.
    """
    return shutil.get_terminal_size((80, 20)).columns

initialize_cap_mapping() -> None

Initialize the cap mapping file if it doesn't exist.

Source code in src/capfinder/utils.py
def initialize_cap_mapping() -> None:
    """Initialize the cap mapping file if it doesn't exist."""
    global CAP_MAPPING
    ensure_config_dir()
    if not CUSTOM_MAPPING_PATH.exists() or CUSTOM_MAPPING_PATH.stat().st_size == 0:
        save_custom_mapping(DEFAULT_CAP_MAPPING)
    load_custom_mapping()

initialize_comet_ml_experiment(project_name: str) -> Experiment

Initialize a CometML experiment for logging.

This function creates a CometML Experiment instance using the provided project name and the COMET_API_KEY environment variable.

Parameters:

project_name: str The name of the CometML project.

Returns:

Experiment: An instance of the CometML Experiment class.

Raises:

ValueError: If the project_name is empty or None, or if the COMET_API_KEY is not set. RuntimeError: If there's an error initializing the experiment.

Source code in src/capfinder/utils.py
def initialize_comet_ml_experiment(project_name: str) -> Experiment:
    """
    Initialize a CometML experiment for logging.

    This function creates a CometML Experiment instance using the provided
    project name and the COMET_API_KEY environment variable.

    Parameters:
    -----------
    project_name: str
        The name of the CometML project.

    Returns:
    --------
    Experiment:
        An instance of the CometML Experiment class.

    Raises:
    -------
    ValueError:
        If the project_name is empty or None, or if the COMET_API_KEY is not set.
    RuntimeError:
        If there's an error initializing the experiment.
    """
    if not project_name:
        raise ValueError("Project name cannot be empty or None")

    comet_api_key = os.getenv("COMET_API_KEY")

    if not comet_api_key:
        logger.error(
            "CometML API key is not set. Please set the COMET_API_KEY environment variable."
        )
        logger.info("Example: export COMET_API_KEY='YOUR_API_KEY'")
        raise ValueError("COMET_API_KEY environment variable is not set")

    try:
        experiment = Experiment(
            api_key=comet_api_key,
            project_name=project_name,
            auto_output_logging="native",
            auto_histogram_weight_logging=True,
            auto_histogram_gradient_logging=False,
            auto_histogram_activation_logging=False,
            display_summary_level=0,
        )
        logger.info(
            f"Successfully initialized CometML experiment for project: {project_name}"
        )
        return experiment
    except Exception as e:
        logger.error(f"Failed to initialize CometML experiment: {str(e)}")
        raise RuntimeError(f"Failed to initialize CometML experiment: {str(e)}") from e
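
A minimal sketch, assuming you have a valid Comet API key; the key and project name below are placeholders:

import os
from capfinder.utils import initialize_comet_ml_experiment

os.environ["COMET_API_KEY"] = "YOUR_API_KEY"  # placeholder; use your real Comet API key
experiment = initialize_comet_ml_experiment("capfinder-demo")  # hypothetical project name
experiment.log_parameter("dtype", "float16")
experiment.end()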

is_cap_name_unique(new_cap_name: str) -> Optional[int]

Check if the given cap name is unique among existing cap mappings.

Args: new_cap_name (str): The new cap name to check for uniqueness.

Returns: Optional[int]: The integer label of the existing cap with the same name, if any. None otherwise.

Source code in src/capfinder/utils.py
def is_cap_name_unique(new_cap_name: str) -> Optional[int]:
    """
    Check if the given cap name is unique among existing cap mappings.

    Args:
    new_cap_name (str): The new cap name to check for uniqueness.

    Returns:
    Optional[int]: The integer label of the existing cap with the same name, if any. None otherwise.
    """
    global CAP_MAPPING
    for cap_int, cap_name in CAP_MAPPING.items():
        if cap_name.lower() == new_cap_name.lower():
            return cap_int
    return None
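
A short sketch of both outcomes (the mapping shown is hypothetical):

from capfinder import utils

utils.CAP_MAPPING = {0: "cap_0", 1: "cap_1"}  # hypothetical mapping
print(utils.is_cap_name_unique("CAP_1"))  # -> 1 (comparison is case-insensitive)
print(utils.is_cap_name_unique("cap_2"))  # -> None (the name is not taken)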

load_custom_mapping() -> None

Load custom mapping from JSON file if it exists.

Source code in src/capfinder/utils.py
def load_custom_mapping() -> None:
    """Load custom mapping from JSON file if it exists."""
    global CAP_MAPPING
    try:
        if CUSTOM_MAPPING_PATH.exists():
            with CUSTOM_MAPPING_PATH.open("r") as f:
                loaded_mapping = json.load(f)
            # Convert string keys back to integers
            CAP_MAPPING = {int(k): v for k, v in loaded_mapping.items()}
        else:
            CAP_MAPPING = DEFAULT_CAP_MAPPING.copy()
    except json.JSONDecodeError:
        logger.error(
            "Failed to decode JSON from custom mapping file. Using default mapping."
        )
        CAP_MAPPING = DEFAULT_CAP_MAPPING.copy()
    except Exception as e:
        logger.error(
            f"Unexpected error loading custom mapping: {e}. Using default mapping."
        )
        CAP_MAPPING = DEFAULT_CAP_MAPPING.copy()

    if not CAP_MAPPING:
        logger.warning("Loaded mapping is empty. Using default mapping.")
        CAP_MAPPING = DEFAULT_CAP_MAPPING.copy()

log_header(text: str) -> None

Log a centered header surrounded by '=' characters.

Parameters:

Name Type Description Default
text str

The text to be displayed in the header.

required

Returns:

Type Description
None

None

Source code in src/capfinder/utils.py
def log_header(text: str) -> None:
    """
    Log a centered header surrounded by '=' characters.

    Args:
        text (str): The text to be displayed in the header.

    Returns:
        None
    """
    width = get_terminal_width()
    header = f"\n{'=' * width}\n{text.center(width)}\n{'=' * width}"
    logger.info(header)

log_output(description: str) -> None

Log a step in a multi-step process.

Parameters:

Name Type Description Default
description str

A description of the current step.

required

Returns:

Type Description
None

None

Source code in src/capfinder/utils.py
def log_output(description: str) -> None:
    """
    Log a step in a multi-step process.

    Args:
        description (str): A description of the current step.

    Returns:
        None
    """
    width = get_terminal_width()
    text = f"\n{'-' * width}\n{description}"
    logger.info(text)

log_step(step_num: int, total_steps: int, description: str) -> None

Log a step in a multi-step process.

Parameters:

Name Type Description Default
step_num int

The current step number.

required
total_steps int

The total number of steps.

required
description str

A description of the current step.

required

Returns:

Type Description
None

None

Source code in src/capfinder/utils.py
def log_step(step_num: int, total_steps: int, description: str) -> None:
    """
    Log a step in a multi-step process.

    Args:
        step_num (int): The current step number.
        total_steps (int): The total number of steps.
        description (str): A description of the current step.

    Returns:
        None
    """
    width = get_terminal_width()
    step = (
        f"\n{'-' * width}\nStep {step_num}/{total_steps}: {description}\n{'-' * width}"
    )
    logger.info(step)

log_subheader(text: str) -> None

Log a centered subheader surrounded by '-' characters.

Parameters:

Name Type Description Default
text str

The text to be displayed in the header.

required

Returns:

Type Description
None

None

Source code in src/capfinder/utils.py
def log_subheader(text: str) -> None:
    """
    Log a centered subheader surrounded by '-' characters.

    Args:
        text (str): The text to be displayed in the header.

    Returns:
        None
    """
    width = get_terminal_width()
    header = f"\n{'-' * width}\n{text.center(width)}\n{'-' * width}"
    logger.info(header)

log_substep(text: str) -> None

Log a substep or bullet point.

Parameters:

Name Type Description Default
text str

The text of the substep to be logged.

required

Returns:

Type Description
None

None

Source code in src/capfinder/utils.py
def log_substep(text: str) -> None:
    """
    Log a substep or bullet point.

    Args:
        text (str): The text of the substep to be logged.

    Returns:
        None
    """
    logger.info(f"  • {text}")

map_cap_int_to_name(cap_class: int) -> str

Map the integer representation of the CAP class to the CAP name.

Source code in src/capfinder/utils.py
def map_cap_int_to_name(cap_class: int) -> str:
    """Map the integer representation of the CAP class to the CAP name."""
    global CAP_MAPPING

    return CAP_MAPPING.get(cap_class, f"Unknown cap: {cap_class}")
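
A minimal sketch (the mapping is hypothetical; unknown integers fall back to a descriptive string):

from capfinder import utils

utils.CAP_MAPPING = {0: "cap_0", 1: "cap_1"}  # hypothetical mapping
print(utils.map_cap_int_to_name(1))  # -> 'cap_1'
print(utils.map_cap_int_to_name(7))  # -> 'Unknown cap: 7'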

open_database(database_path: str) -> Tuple[sqlite3.Connection, sqlite3.Cursor]

Open the database connection based on the database path.

Parameters:

Name Type Description Default
database_path str

Path to the database.

required

Returns:

Name Type Description
conn Connection

Connection object for the database.

cursor Cursor

Cursor object for the database.

Source code in src/capfinder/utils.py
def open_database(
    database_path: str,
) -> Tuple[sqlite3.Connection, sqlite3.Cursor]:
    """
    Open the database connection based on the database path.

    Args:
        database_path (str): Path to the database.

    Returns:
        conn (sqlite3.Connection): Connection object for the database.
        cursor (sqlite3.Cursor): Cursor object for the database.
    """
    conn = sqlite3.connect(database_path)
    cursor = conn.cursor()
    return conn, cursor
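
A minimal sketch using an in-memory SQLite database so the example leaves no file behind; the table and values are made up for illustration:

from capfinder.utils import open_database

conn, cursor = open_database(":memory:")
cursor.execute("CREATE TABLE reads (read_id TEXT, cap_class INTEGER)")
cursor.execute("INSERT INTO reads VALUES (?, ?)", ("read_1", 0))
conn.commit()
print(cursor.execute("SELECT * FROM reads").fetchall())  # -> [('read_1', 0)]
conn.close()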

save_custom_mapping(mapping: Dict[int, str]) -> None

Save the given mapping to JSON file.

Source code in src/capfinder/utils.py
def save_custom_mapping(mapping: Dict[int, str]) -> None:
    """Save the given mapping to JSON file."""
    ensure_config_dir()
    try:
        with CUSTOM_MAPPING_PATH.open("w") as f:
            json.dump(mapping, f, indent=2)
    except Exception as e:
        logger.error(f"Failed to save custom mapping: {e}")
        raise

update_cap_mapping(new_mapping: Dict[int, str]) -> None

Update the CAP_MAPPING with new entries.

Source code in src/capfinder/utils.py
def update_cap_mapping(new_mapping: Dict[int, str]) -> None:
    """Update the CAP_MAPPING with new entries."""
    global CAP_MAPPING
    CAP_MAPPING.update(new_mapping)
    save_custom_mapping(CAP_MAPPING)
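
A minimal sketch of registering a new cap label (the cap name is hypothetical; note that this persists the updated mapping to the user's mapping file on disk):

from capfinder import utils

utils.load_custom_mapping()  # populate CAP_MAPPING from disk or the defaults
next_label = utils.get_next_available_cap_number()
utils.update_cap_mapping({next_label: "my_new_cap"})  # hypothetical cap name
print(utils.map_cap_int_to_name(next_label))  # -> 'my_new_cap'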

visualize_alns

This module helps us to visualize the alignments of reads to a reference sequence. The module reads a FASTQ file or folder of FASTQ files, processes each read in parallel, and writes the output to a file. The output file contains the read ID, average quality, sequence, alignment score, and alignment string.

This module is useful for understanding the output of Parasail alignments.

Author: Adnan M. Niazi Date: 2024-02-28

calculate_average_quality(quality_scores: Sequence[Union[int, float]]) -> float

Calculate the average quality score for a read.

Parameters: quality_scores (Sequence[Union[int, float]]): A list of quality scores for a read.

Returns: average_quality (float): The average quality score for a read.

Source code in src/capfinder/visualize_alns.py
def calculate_average_quality(quality_scores: Sequence[Union[int, float]]) -> float:
    """
    Calculate the average quality score for a read.
    Args:
        quality_scores (Sequence[Union[int, float]]): A list of quality scores for a read.
    Returns:
        average_quality (float): The average quality score for a read.
    """
    average_quality = sum(quality_scores) / len(quality_scores)
    return average_quality
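
A minimal sketch with made-up Phred scores:

from capfinder.visualize_alns import calculate_average_quality

print(calculate_average_quality([30, 20, 40, 10]))  # -> 25.0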

process_fastq_file(fastq_filepath: str, reference: str, num_processes: int, output_folder: str) -> None

Process a single FASTQ file. The function reads the FASTQ file and processes each read in parallel. The output is a file containing the read ID, average quality, sequence, alignment score, and alignment string.

Parameters:

Name Type Description Default
fastq_filepath str

The path to the FASTQ file.

required
reference str

The reference sequence to align the read to.

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where the output file will be stored.

required

Returns:

Type Description
None

None

Source code in src/capfinder/visualize_alns.py
def process_fastq_file(
    fastq_filepath: str, reference: str, num_processes: int, output_folder: str
) -> None:
    """
    Process a single FASTQ file. The function reads the FASTQ file and processes each read in parallel.
    The output is a file containing the read ID, average quality, sequence, alignment score, and alignment string.

    Args:
        fastq_filepath (str): The path to the FASTQ file.
        reference (str): The reference sequence to align the read to.
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where the output file will be stored.

    Returns:
        None
    """

    # Create the output folder if it does not already exist
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Build the output file path from the input FASTQ file name
    directory, filename = os.path.split(fastq_filepath)
    filename_no_extension, extension = os.path.splitext(filename)
    output_filepath = os.path.join(
        output_folder, f"{filename_no_extension}_alignments.txt"
    )

    with file_opener(fastq_filepath) as fastq_file:
        records = list(SeqIO.parse(fastq_file, "fastq"))
        total_records = len(records)

        with WorkerPool(n_jobs=num_processes) as pool:
            results = pool.map(
                process_read,
                [(item, reference) for item in records],
                iterable_len=total_records,
                progress_bar=True,
            )
            write_ouput(results, output_filepath)

process_fastq_folder(folder_path: str, reference: str, num_processes: int, output_folder: str) -> None

Process all FASTQ files in a folder. The function reads all FASTQ files in a folder and processes each read in parallel.

Parameters:

Name Type Description Default
folder_path str

The path to the folder containing FASTQ files.

required
reference str

The reference sequence to align the read to.

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where the output file will be stored.

required

Returns:

Type Description
None

None

Source code in src/capfinder/visualize_alns.py
def process_fastq_folder(
    folder_path: str, reference: str, num_processes: int, output_folder: str
) -> None:
    """
    Process all FASTQ files in a folder. The function reads all FASTQ files in a folder and processes each read in parallel.

    Args:
        folder_path (str): The path to the folder containing FASTQ files.
        reference (str): The reference sequence to align the read to.
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where the output file will be stored.

    Returns:
        None
    """
    # List all files in the folder
    for root, _, files in os.walk(folder_path):
        for file_name in files:
            if file_name.endswith((".fastq", ".fastq.gz")):
                file_path = os.path.join(root, file_name)
                process_fastq_file(file_path, reference, num_processes, output_folder)

process_fastq_path(path: str, reference: str, num_processes: int, output_folder: str) -> None

Process a FASTQ file or folder of FASTQ files based on the provided path.

Parameters:

Name Type Description Default
path str

The path to the FASTQ file or folder.

required
reference str

The reference sequence to align the read to.

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where the output file will be stored.

required

Returns:

Type Description
None

None

Source code in src/capfinder/visualize_alns.py
def process_fastq_path(
    path: str, reference: str, num_processes: int, output_folder: str
) -> None:
    """
    Process a FASTQ file or folder of FASTQ files based on the provided path.

    Args:
        path (str): The path to the FASTQ file or folder.
        reference (str): The reference sequence to align the read to.
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where the output file will be stored.

    Returns:
        None
    """
    if os.path.isfile(path):
        # Process a single FASTQ file
        process_fastq_file(path, reference, num_processes, output_folder)
    elif os.path.isdir(path):
        # Process all FASTQ files in a folder
        process_fastq_folder(path, reference, num_processes, output_folder)
    else:
        print("Invalid path. Please provide a valid FASTQ file or folder path.")

process_read(record: Any, reference: str) -> str

Process a single read from a FASTQ file. The function calculates average read quality, alignment score, and alignment string. The output is a string that can be written to a file.

Parameters:

Name Type Description Default
record Any

A single read from a FASTQ file.

required
reference str

The reference sequence to align the read to.

required

Returns: output_string (str): A string containing the read ID, average quality, sequence, alignment score, and alignment string.

Source code in src/capfinder/visualize_alns.py
def process_read(record: Any, reference: str) -> str:
    """
    Process a single read from a FASTQ file. The function calculates average read quality,
    alignment score, and alignment string. The output is a string that can be written to a file.

    Args:
        record (Any): A single read from a FASTQ file.
        reference (str): The reference sequence to align the read to.
    Returns:
        output_string (str): A string containing the read ID, average quality, sequence,
                            alignment score, and alignment string.
    """
    read_id = record.id
    quality_scores = record.letter_annotations["phred_quality"]
    average_quality = round(calculate_average_quality(quality_scores))
    sequence = str(record.seq)
    with contextlib.redirect_stdout(None):
        _, _, chunked_aln_str, alignment_score = align(
            query_seq=sequence, target_seq=reference, pretty_print_alns=True
        )

    output_string = f">{read_id} {average_quality:.0f}\n{sequence}\n\n"
    output_string += f"Alignment Score: {alignment_score}\n"
    output_string += f"{chunked_aln_str}\n"

    return output_string

visualize_alns(path: str, reference: str, num_processes: int, output_folder: str) -> None

Main function to visualize alignments.

Parameters:

Name Type Description Default
path str

The path to the FASTQ file or folder.

required
reference str

The reference sequence to align the read to.

required
num_processes int

The number of processes to use for parallel processing.

required
output_folder str

The folder where the output file will be stored.

required

Returns:

Type Description
None

None

Source code in src/capfinder/visualize_alns.py
def visualize_alns(
    path: str, reference: str, num_processes: int, output_folder: str
) -> None:
    """
    Main function to visualize alignments.

    Args:
        path (str): The path to the FASTQ file or folder.
        reference (str): The reference sequence to align the read to.
        num_processes (int): The number of processes to use for parallel processing.
        output_folder (str): The folder where the output file will be stored.

    Returns:
        None
    """
    process_fastq_path(path, reference, num_processes, output_folder)
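
A minimal sketch of the typical entry point; the paths and the reference sequence below are placeholders:

from capfinder.visualize_alns import visualize_alns

visualize_alns(
    path="reads.fastq",  # a single FASTQ file or a folder of FASTQ files
    reference="CCTGGTAACTGGGACACACGTTTCAACT",  # hypothetical reference sequence
    num_processes=4,
    output_folder="alignment_viz",
)
# Writes reads_alignments.txt into alignment_viz/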

write_ouput(output_list: List[str], output_filepath: str) -> None

Write a list of strings to a file.

Parameters:

Name Type Description Default
output_list list

A list of strings to write to a file.

required
output_filepath str

The path to the output file.

required

Returns:

Type Description
None

None

Source code in src/capfinder/visualize_alns.py
def write_ouput(output_list: List[str], output_filepath: str) -> None:
    """
    Write a list of strings to a file.

    Args:
        output_list (list): A list of strings to write to a file.
        output_filepath (str): The path to the output file.

    Returns:
        None
    """
    if os.path.exists(output_filepath):
        os.remove(output_filepath)
    with open(output_filepath, "a") as f:
        f.writelines("\n".join(output_list))