# Source code for tidymut.utils.dataset_builders
# tidymut/utils/dataset_builders.py
"""
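Helper functions used by
tidymut.cleaners.basic_cleaners.convert_to_mutation_dataset_format().

Two input layouts are supported (examples below): format 1 marks wild-type
rows with 'WT' in the mutation column, and format 2 carries the reference
sequence in its own column.
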
>>> # format 1:
>>> pd.DataFrame({
... 'name': ['prot1', 'prot1', 'prot1', 'prot2', 'prot2'],
... 'mut_info': ['A0S,Q1D', 'C2D', 'WT', 'E0F', 'WT'],
... 'mut_seq': ['SDCDEF', 'AQDDEF', 'AQCDEF', 'FGHIGHK', 'EGHIGHK'],
... 'score': [1.5, 2.0, 0.0, 3.0, 0.0]
... })
>>>
>>> # format 2:
>>> df2 = pd.DataFrame({
... 'name': ['prot1', 'prot1', 'prot2'],
... 'sequence': ['AKCDEF', 'AKCDEF', 'FEGHIS'],
... 'mut_info': ['A0K,C2D', 'K1P', 'E1F'],
... 'score': [1.5, 2.0, 3.0],
... 'mut_seq': ['KKDDEF', 'APCDEF', 'FFGHIS']
... })
"""
from __future__ import annotations
import pandas as pd
from tqdm import tqdm
from typing import TYPE_CHECKING
from ..core.mutation import MutationSet
if TYPE_CHECKING:
from typing import Any, Dict, List, Optional, Type, Tuple, Union
from ..core.sequence import ProteinSequence, DNASequence, RNASequence
__all__ = ["convert_format_1", "convert_format_2"]
def __dir__() -> List[str]:
    """Restrict ``dir()`` on this module to the names listed in ``__all__``."""
    public_names = __all__
    return public_names
[docs]
def convert_format_1(
    df: pd.DataFrame,
    name_column: str,
    mutation_column: str,
    mutated_sequence_column: str,
    score_column: str,
    include_wild_type: bool,
    mutation_set_prefix: str,
    is_zero_based: bool,
    additional_metadata: Optional[Dict[str, Any]],
    sequence_class: Type[Union[ProteinSequence, DNASequence, RNASequence]],
) -> Tuple[
    pd.DataFrame, Dict[str, Union[ProteinSequence, DNASequence, RNASequence]]
]:
    """Convert Format 1 (with WT rows) to mutation dataset format.

    In this layout a reference sequence is given by a row whose mutation
    column is exactly ``"WT"``; the mutated-sequence column of that row holds
    the wild-type sequence.

    Parameters
    ----------
    df : pd.DataFrame
        Input table. It is copied, so the caller's frame is not modified.
    name_column : str
        Column holding the protein/sequence name.
    mutation_column : str
        Column holding the mutation string, or ``"WT"`` for wild-type rows.
    mutated_sequence_column : str
        Column holding the sequence of each row; only read for WT rows.
    score_column : str
        Column holding the score passed on to every output row.
    include_wild_type : bool
        If False, WT rows are removed before the remaining rows are numbered.
        WT rows never produce output rows either way; if True they still
        occupy a position in the row numbering.
    mutation_set_prefix : str
        Prefix of the generated ``mutation_set_id`` (``"<prefix>_<n>"``, where
        ``n`` is the 1-based position of the row among the processed rows).
    is_zero_based : bool
        Whether positions in the mutation strings are zero-based; forwarded
        to the mutation parser.
    additional_metadata : Optional[Dict[str, Any]]
        Extra metadata forwarded unchanged to every output row.
    sequence_class : Type[Union[ProteinSequence, DNASequence, RNASequence]]
        Called with each wild-type sequence to build the reference object.

    Returns
    -------
    Tuple[pd.DataFrame, Dict[str, Union[ProteinSequence, DNASequence, RNASequence]]]
        The output DataFrame (one row per individual mutation) and a mapping
        from name to the ``sequence_class`` instance built from its WT row.
        If several WT rows share a name, the last one wins.

    Raises
    ------
    ValueError
        If no WT row exists, if no rows remain after filtering, or if a
        mutation string cannot be parsed.
    """
    input_df = df.copy()

    # Extract reference sequences from WT rows
    wt_rows = input_df[input_df[mutation_column] == "WT"]
    if wt_rows.empty:
        raise ValueError("No wild-type (WT) entries found in the dataset")

    # For WT rows the "mutated" sequence is the wild-type sequence itself.
    reference_sequences = {
        name: sequence_class(sequence)
        for name, sequence in zip(
            wt_rows[name_column], wt_rows[mutated_sequence_column]
        )
    }

    # Filter out wild-type entries if requested
    if not include_wild_type:
        input_df = input_df[input_df[mutation_column] != "WT"].copy()

    if input_df.empty:
        raise ValueError("No mutation data remaining after filtering")

    # Iterate over the needed columns directly instead of itertuples():
    # itertuples() renames fields whose column names are not valid Python
    # identifiers (or start with an underscore), which breaks attribute
    # lookup by column name.
    row_values = zip(
        input_df[mutation_column],
        input_df[name_column],
        input_df[score_column],
    )

    output_rows = []
    total_rows = len(input_df)
    for idx, (mut_info, name, score) in tqdm(
        enumerate(row_values), total=total_rows
    ):
        # WT rows carry no mutation; they are only the reference source.
        if mut_info == "WT":
            continue

        # Parse mutations (single or multiple)
        try:
            mutation_data_list = _parse_mutations_string(mut_info, is_zero_based)
        except ValueError as e:
            # Re-raise with row context while keeping the original cause.
            raise ValueError(
                f"Cannot parse mutation '{mut_info}' in row {idx}: {e}"
            ) from e

        # One output row per individual mutation within the set
        mutation_set_id = f"{mutation_set_prefix}_{idx + 1}"
        mutation_set_name = f"{name}_{mut_info}"
        output_rows.extend(
            _create_output_row_from_mutation_data(
                mutation_set_id,
                mutation_set_name,
                mut_info,
                name,
                score,
                mutation_data,
                additional_metadata,
            )
            for mutation_data in mutation_data_list
        )

    output_df = pd.DataFrame(output_rows)
    return output_df, reference_sequences
[docs]
def convert_format_2(
    df: pd.DataFrame,
    name_column: str,
    mutation_column: str,
    sequence_column: str,
    score_column: str,
    mutation_set_prefix: str,
    is_zero_based: bool,
    additional_metadata: Optional[Dict[str, Any]],
    sequence_class: Type[Union[ProteinSequence, DNASequence, RNASequence]],
) -> Tuple[
    pd.DataFrame, Dict[str, Union[ProteinSequence, DNASequence, RNASequence]]
]:
    """Convert Format 2 (with sequence column) to mutation dataset format.

    In this layout every row carries the reference sequence of its protein in
    ``sequence_column``; all rows sharing a name must agree on that sequence.

    Parameters
    ----------
    df : pd.DataFrame
        Input table. It is copied, so the caller's frame is not modified.
    name_column : str
        Column holding the protein/sequence name; rows are grouped by it.
    mutation_column : str
        Column holding the mutation string of each row.
    sequence_column : str
        Column holding the reference sequence.
    score_column : str
        Column holding the score passed on to every output row.
    mutation_set_prefix : str
        Prefix of the generated ``mutation_set_id`` (``"<prefix>_<n>"``, where
        ``n`` is the 1-based position of the row in the input).
    is_zero_based : bool
        Whether positions in the mutation strings are zero-based; forwarded
        to the mutation parser.
    additional_metadata : Optional[Dict[str, Any]]
        Extra metadata forwarded unchanged to every output row.
    sequence_class : Type[Union[ProteinSequence, DNASequence, RNASequence]]
        Called with each reference sequence to build the reference object.

    Returns
    -------
    Tuple[pd.DataFrame, Dict[str, Union[ProteinSequence, DNASequence, RNASequence]]]
        The output DataFrame (one row per individual mutation) and a mapping
        from name to the ``sequence_class`` instance of its reference.

    Raises
    ------
    ValueError
        If one name is associated with more than one distinct sequence, or
        if a mutation string cannot be parsed.
    """
    input_df = df.copy()

    # Extract reference sequences from sequence column
    reference_sequences = {}
    for name, group in tqdm(input_df.groupby(name_column)):
        sequences = group[sequence_column].unique()
        if len(sequences) > 1:
            raise ValueError(
                f"Multiple different sequences found for protein '{name}': {sequences}"
            )
        reference_sequences[name] = sequence_class(sequences[0])

    # Iterate over the needed columns directly instead of itertuples():
    # itertuples() renames fields whose column names are not valid Python
    # identifiers (or start with an underscore), which breaks attribute
    # lookup by column name.
    row_values = zip(
        input_df[mutation_column],
        input_df[name_column],
        input_df[score_column],
    )

    output_rows = []
    total_rows = len(input_df)
    for idx, (mut_info, name, score) in tqdm(
        enumerate(row_values), total=total_rows
    ):
        # Parse mutations (single or multiple)
        try:
            mutation_data_list = _parse_mutations_string(mut_info, is_zero_based)
        except ValueError as e:
            # Re-raise with row context while keeping the original cause.
            raise ValueError(
                f"Cannot parse mutation '{mut_info}' in row {idx}: {e}"
            ) from e

        # One output row per individual mutation within the set
        mutation_set_id = f"{mutation_set_prefix}_{idx + 1}"
        mutation_set_name = f"{name}_{mut_info}"
        output_rows.extend(
            _create_output_row_from_mutation_data(
                mutation_set_id,
                mutation_set_name,
                mut_info,
                name,
                score,
                mutation_data,
                additional_metadata,
            )
            for mutation_data in mutation_data_list
        )

    output_df = pd.DataFrame(output_rows)
    return output_df, reference_sequences
def _create_output_row_from_mutation_data(
    mutation_set_id: str,
    mutation_set_name: str,
    original_mutation_string: str,
    name: str,
    score: float,
    mutation_data: Dict[str, Any],
    additional_metadata: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    """
    Build the output-row dictionary for one mutation of a mutation set.

    Parameters
    ----------
    mutation_set_id : str
        Identifier of the mutation set the mutation belongs to
    mutation_set_name : str
        Human-readable name of the mutation set
    original_mutation_string : str
        The full mutation string as given in the input (may describe
        several mutations); stored under ``set_original_mutation_string``
    name : str
        Protein/sequence name, stored as ``reference_id``
    score : float
        Score of the mutation set, stored as ``label``
    mutation_data : Dict[str, Any]
        One parsed mutation; must provide the keys ``mutation_string``,
        ``position``, ``wild_aa`` and ``mutant_aa``
    additional_metadata : Optional[Dict[str, Any]]
        Optional extra values; each key is written as ``set_<key>``

    Returns
    -------
    Dict[str, Any]
        A single row for the output DataFrame
    """
    row = {
        "mutation_set_id": mutation_set_id,
        "reference_id": name,
        # The individual mutation, not the whole set string
        "mutation_string": mutation_data["mutation_string"],
        "position": mutation_data["position"],
        "mutation_type": "amino_acid",
        "wild_amino_acid": mutation_data["wild_aa"],
        "mutant_amino_acid": mutation_data["mutant_aa"],
        "mutation_set_name": mutation_set_name,
        "label": score,
        # The set-level string exactly as it appeared in the input
        "set_original_mutation_string": original_mutation_string,
    }

    # Nothing more to add when no (or empty) metadata was supplied.
    if not additional_metadata:
        return row

    row.update(
        (f"set_{key}", value) for key, value in additional_metadata.items()
    )
    return row
def _parse_mutations_string(
    mutation_string: str, is_zero_based: bool
) -> list[Dict[str, Any]]:
    """
    Parse a mutation string that may contain single or multiple mutations.

    Parsing is delegated entirely to ``MutationSet.from_string``; the accepted
    syntax (e.g. single mutations such as 'A0S' or multi-mutation strings such
    as 'A0S,Q1D') is whatever that method supports.

    Parameters
    ----------
    mutation_string : str
        Mutation string(s) to parse; surrounding whitespace is ignored
    is_zero_based : bool
        Whether origin mutation positions are zero-based

    Returns
    -------
    list[Dict[str, Any]]
        List of mutation data dictionaries, each containing:
        - 'wild_aa': ``wild_type`` of the parsed mutation
        - 'position': ``position`` of the parsed mutation
          (NOTE(review): presumably 0-based after parsing -- confirm against
          ``MutationSet.from_string``)
        - 'mutant_aa': ``mutant_type`` of the parsed mutation
        - 'mutation_string': ``str()`` of the individual mutation

    Raises
    ------
    ValueError
        If the input is not a string, is blank, yields no mutations, or
        yields a mutation object without the expected attributes. Errors
        raised by ``MutationSet.from_string`` propagate unchanged.
    """
    # Missing cells read from a DataFrame arrive as NaN/None. Reject them with
    # the documented ValueError so callers, which catch ValueError to add row
    # context, do not see an AttributeError from ``.strip()`` instead.
    if not isinstance(mutation_string, str):
        raise ValueError(
            "Mutation string must be a str, got "
            f"{type(mutation_string).__name__}: {mutation_string!r}"
        )

    mutation_string = mutation_string.strip()
    if not mutation_string:
        raise ValueError("Mutation string is empty")

    mutation_set = MutationSet.from_string(mutation_string, is_zero_based=is_zero_based)

    required_attributes = ("wild_type", "position", "mutant_type")
    mutation_data_list = []
    for mutation in mutation_set.mutations:
        # Only mutations exposing wild type, position and mutant type can be
        # flattened into an output row.
        if not all(hasattr(mutation, attr) for attr in required_attributes):
            raise ValueError(
                f"Mutation object does not have expected attributes: {mutation}"
            )
        mutation_data_list.append(
            {
                "wild_aa": mutation.wild_type,
                "position": mutation.position,
                "mutant_aa": mutation.mutant_type,
                "mutation_string": str(mutation),  # Individual mutation string
            }
        )

    if not mutation_data_list:
        raise ValueError("No valid mutations found in mutation set")

    return mutation_data_list