# tidymut/cleaners/basic_cleaners.py
from __future__ import annotations
import numpy as np
import pandas as pd
from functools import partial
from joblib import Parallel, delayed
from pathlib import Path
from tqdm import tqdm
from typing import cast, TYPE_CHECKING
from ..core.alphabet import ProteinAlphabet, DNAAlphabet, RNAAlphabet
from ..core.pipeline import pipeline_step, multiout_step
from ..core.sequence import ProteinSequence, DNASequence, RNASequence
from ..utils.cleaner_workers import (
valid_single_mutation,
apply_single_mutation,
infer_wt_sequence_grouped,
)
from ..utils.dataset_builders import convert_format_1, convert_format_2
from ..utils.type_converter import (
convert_data_types as _convert_data_types,
convert_data_types_batch as _convert_data_types_batch,
)
if TYPE_CHECKING:
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
__all__ = [
"read_dataset",
"merge_columns",
"extract_and_rename_columns",
"filter_and_clean_data",
"convert_data_types",
"validate_mutations",
"apply_mutations_to_sequences",
"infer_wildtype_sequences",
"convert_to_mutation_dataset_format",
]
def __dir__() -> List[str]:
return __all__
@pipeline_step
def read_dataset(
file_path: Union[str, Path], file_format: Optional[str] = None, **kwargs
) -> pd.DataFrame:
"""
Read dataset from specified file format and return as a pandas DataFrame.
Parameters
----------
file_path : Union[str, Path]
Path to the dataset file
    file_format : Optional[str], default=None
        Format of the dataset file ("csv", "tsv", "xlsx", etc.).
        If None, the format is inferred from the file extension
    **kwargs
        Additional keyword arguments passed to the underlying pandas reader
Returns
-------
pd.DataFrame
Dataset loaded from the specified file
Example
-------
>>> # Specify file_format parameter
>>> df = read_dataset("data.csv", "csv")
>>>
>>> # Detect file_format automatically
>>> df = read_dataset("data.csv")
"""
if file_format is None:
file_format = Path(file_path).suffix.lstrip(".").lower()
    readers = {
        "csv": pd.read_csv,
        "tsv": partial(pd.read_csv, sep="\t"),
        "xlsx": pd.read_excel,
    }
tqdm.write(f"Reading dataset from {file_path}...")
try:
return readers[file_format](file_path, **kwargs)
    except KeyError:
        raise ValueError(f"Unsupported file format: {file_format}") from None
@pipeline_step
def merge_columns(
dataset: pd.DataFrame,
columns_to_merge: List[str],
new_column_name: str,
separator: str = "_",
drop_original: bool = False,
na_rep: Optional[str] = None,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
custom_formatter: Optional[Callable[[pd.Series], str]] = None,
) -> pd.DataFrame:
"""Merge multiple columns into a single column using a separator
This function combines values from multiple columns into a new column,
with flexible formatting options.
Parameters
----------
dataset : pd.DataFrame
Input dataset
columns_to_merge : List[str]
List of column names to merge
new_column_name : str
Name for the new merged column
separator : str, default='_'
Separator to use between values
drop_original : bool, default=False
Whether to drop the original columns after merging
na_rep : Optional[str], default=None
String representation of NaN values. If None, NaN values are skipped.
prefix : Optional[str], default=None
Prefix to add to the merged value
suffix : Optional[str], default=None
Suffix to add to the merged value
    custom_formatter : Optional[Callable], default=None
        Custom function to format each row. Takes a pd.Series and returns a string.
        If provided, the separator, na_rep, prefix, and suffix parameters are ignored.
Returns
-------
pd.DataFrame
Dataset with the new merged column
Examples
--------
Basic usage:
>>> df = pd.DataFrame({
... 'gene': ['BRCA1', 'TP53', 'EGFR'],
... 'position': [100, 200, 300],
... 'mutation': ['A', 'T', 'G']
... })
>>> result = merge_columns(df, ['gene', 'position', 'mutation'], 'mutation_id', separator='_')
>>> print(result['mutation_id'])
0 BRCA1_100_A
1 TP53_200_T
2 EGFR_300_G
With prefix and suffix:
>>> result = merge_columns(
... df, ['gene', 'position'], 'gene_pos',
... separator=':', prefix='[', suffix=']'
... )
>>> print(result['gene_pos'])
0 [BRCA1:100]
1 [TP53:200]
2 [EGFR:300]
Handling NaN values:
>>> df_with_nan = pd.DataFrame({
... 'col1': ['A', 'B', None],
... 'col2': ['X', None, 'Z'],
... 'col3': [1, 2, 3]
... })
>>> result = merge_columns(
... df_with_nan, ['col1', 'col2', 'col3'], 'merged',
... separator='-', na_rep='NA'
... )
>>> print(result['merged'])
0 A-X-1
1 B-NA-2
2 NA-Z-3
Custom formatter:
>>> def format_mutation(row):
... return f"{row['gene']}:{row['position']}{row['mutation']}"
>>> result = merge_columns(
... df, ['gene', 'position', 'mutation'], 'hgvs',
... custom_formatter=format_mutation
... )
>>> print(result['hgvs'])
0 BRCA1:100A
1 TP53:200T
2 EGFR:300G
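    Dropping the source columns after merging:
    >>> result = merge_columns(
    ...     df, ['gene', 'position'], 'gene_pos',
    ...     separator=':', drop_original=True
    ... )
    >>> print(result.columns.tolist())
    ['mutation', 'gene_pos']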
"""
tqdm.write(f"Merging columns {columns_to_merge} into '{new_column_name}'...")
# Validate columns exist
missing_cols = [col for col in columns_to_merge if col not in dataset.columns]
if missing_cols:
raise ValueError(f"Columns not found in dataset: {missing_cols}")
# Create a copy to avoid modifying original
result = dataset.copy()
if custom_formatter is not None:
# Use custom formatter
tqdm.write("Using custom formatter...")
tqdm.pandas()
result[new_column_name] = result.progress_apply(custom_formatter, axis=1) # type: ignore
else:
        # Standard merging with separator
        df_to_merge = result[columns_to_merge].copy()
        if na_rep is not None:
            # Replace NaN with na_rep, then do a plain vectorized join
            df_to_merge = df_to_merge.fillna(na_rep).astype(str)
            merged = df_to_merge.agg(separator.join, axis=1)
        else:
            # Convert to string and blank out NaN cells
            df_to_merge = df_to_merge.astype(str)
            mask = result[columns_to_merge].isna()
            df_to_merge = df_to_merge.mask(mask, "")
            # Join only non-empty cells so skipped NaNs do not leave
            # stray separators (e.g. 'B-2' rather than 'B--2')
            merged = df_to_merge.apply(
                lambda row: separator.join(v for v in row if v != ""), axis=1
            )
            # Rows that are entirely NaN stay NaN
            all_na = result[columns_to_merge].isna().all(axis=1)
            merged[all_na] = np.nan
# Add prefix and suffix if specified
if prefix is not None or suffix is not None:
# Add prefix and suffix to non-NaN values
non_na_mask = merged.notna()
if prefix is not None:
merged[non_na_mask] = prefix + merged[non_na_mask]
if suffix is not None:
merged[non_na_mask] = merged[non_na_mask] + suffix
result[new_column_name] = merged
# Drop original columns if requested
if drop_original:
result = result.drop(columns=columns_to_merge)
tqdm.write(f"Dropped original columns: {columns_to_merge}")
tqdm.write(f"Successfully created merged column '{new_column_name}'")
return result
@pipeline_step
def extract_and_rename_columns(
dataset: pd.DataFrame,
column_mapping: Dict[str, str],
required_columns: Optional[Sequence[str]] = None,
) -> pd.DataFrame:
"""
Extract useful columns and rename them to standard format.
This function extracts specified columns from the input dataset and renames them
according to the provided mapping. It helps standardize column names across
different datasets.
Parameters
----------
dataset : pd.DataFrame
Input dataset containing the data to be processed
column_mapping : Dict[str, str]
Column name mapping from original names to new names
Format: {original_column_name: new_column_name}
    required_columns : Optional[Sequence[str]], default=None
        New (renamed) column names to keep. If None, all mapped columns are extracted
Returns
-------
pd.DataFrame
Dataset with extracted and renamed columns
Raises
------
ValueError
If required columns are missing from the input dataset
Example
-------
>>> import pandas as pd
>>> df = pd.DataFrame({
... 'uniprot_ID': ['P12345', 'Q67890'],
... 'mutation_type': ['A123B', 'C456D'],
... 'score_value': [1.5, -2.3],
... 'extra_col': ['x', 'y']
... })
>>> mapping = {
... 'uniprot_ID': 'name',
... 'mutation_type': 'mut_info',
... 'score_value': 'label'
... }
>>> result = extract_and_rename_columns(df, mapping)
>>> print(result.columns.tolist())
['name', 'mut_info', 'label']
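    Keeping only a subset (note that ``required_columns`` uses the *new* names):
    >>> result = extract_and_rename_columns(df, mapping, required_columns=['name', 'label'])
    >>> print(result.columns.tolist())
    ['name', 'label']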
"""
tqdm.write("Extracting and renaming columns...")
# Check if required columns exist
missing_cols = [col for col in column_mapping.keys() if col not in dataset.columns]
if missing_cols:
raise ValueError(f"Missing required columns: {missing_cols}")
# Extract and rename columns
if required_columns:
# Only extract specified columns
extract_cols = [
col
for col in column_mapping.keys()
if column_mapping[col] in required_columns
]
extracted_dataset = dataset[extract_cols].copy()
else:
# Extract all mapped columns
extracted_dataset = dataset[list(column_mapping.keys())].copy()
# Rename columns
extracted_dataset = extracted_dataset.rename(columns=column_mapping)
tqdm.write(
f"Extracted {len(extracted_dataset.columns)} columns: {list(extracted_dataset.columns)}"
)
return extracted_dataset
@pipeline_step
def filter_and_clean_data(
dataset: pd.DataFrame,
filters: Optional[Dict[str, Union[Any, Callable[[pd.Series], pd.Series]]]] = None,
exclude_patterns: Optional[Dict[str, Union[str, List[str]]]] = None,
drop_na_columns: Optional[List[str]] = None,
) -> pd.DataFrame:
"""
Filter and clean data based on specified conditions.
This function provides flexible data filtering and cleaning capabilities,
including value-based filtering, pattern exclusion, and null value removal.
Parameters
----------
dataset : pd.DataFrame
Input dataset to be filtered and cleaned
filters : Optional[Dict[str, Union[Any, Callable[[pd.Series], pd.Series]]]], default=None
Filter conditions in format {column_name: condition_value_or_function}
If value is callable, it will be applied to the column
exclude_patterns : Optional[Dict[str, Union[str, List[str]]]], default=None
Exclusion patterns in format {column_name: regex_pattern_or_list}
Rows matching these patterns will be excluded
drop_na_columns : Optional[List[str]], default=None
List of column names where null values should be dropped
Returns
-------
pd.DataFrame
Filtered and cleaned dataset
Example
-------
>>> import pandas as pd
>>> df = pd.DataFrame({
... 'mut_type': ['A123B', 'wt', 'C456D', 'insert', 'E789F'],
... 'score': [1.5, 2.0, '-', 3.2, 4.1],
... 'quality': ['good', 'bad', 'good', 'good', None]
... })
>>> filters = {'score': lambda x: x != '-'}
>>> exclude_patterns = {'mut_type': ['wt', 'insert']}
>>> drop_na_columns = ['quality']
>>> result = filter_and_clean_data(df, filters, exclude_patterns, drop_na_columns)
    >>> print(len(result))  # Only A123B survives: 'wt'/'insert' excluded, '-' filtered out, missing quality dropped
    1
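    Exclusion patterns are regular expressions, so anchors work too:
    >>> result = filter_and_clean_data(df, exclude_patterns={'mut_type': '^A'})
    >>> print(len(result))  # Only the A123B row is excluded
    4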
"""
tqdm.write("Filtering and cleaning data...")
original_len = len(dataset)
# Collect all filter conditions to avoid dataframe copy
filter_masks = []
available_columns = set(dataset.columns)
# Apply filter conditions
if filters:
for col, condition in filters.items():
if col not in available_columns:
tqdm.write(f"Warning: Column '{col}' not found for filtering")
continue
if callable(condition):
mask = condition(dataset[col])
filter_masks.append(mask)
else:
mask = dataset[col] == condition
filter_masks.append(mask)
# Exclude specific patterns
if exclude_patterns:
for col, patterns in exclude_patterns.items():
if col not in available_columns:
tqdm.write(f"Warning: Column '{col}' not found for pattern exclusion")
continue
if isinstance(patterns, str):
patterns = [patterns]
# Combine patterns into a single regex pattern
if len(patterns) == 1:
combined_pattern = patterns[0]
else:
combined_pattern = "|".join(f"({pattern})" for pattern in patterns)
mask = ~dataset[col].str.contains(combined_pattern, na=False, regex=True)
filter_masks.append(mask)
# Drop null values for specified columns
if drop_na_columns:
for col in drop_na_columns:
if col in available_columns:
mask = dataset[col].notna()
filter_masks.append(mask)
# Apply combined filter conditions
if filter_masks:
combined_mask = filter_masks[0]
for mask in filter_masks[1:]:
combined_mask &= mask
result = dataset.loc[combined_mask].copy()
else:
result = dataset.copy()
    retained = len(result) / original_len * 100 if original_len else 0.0
    tqdm.write(
        f"Filtered data: {original_len} -> {len(result)} rows "
        f"({retained:.1f}% retained)"
    )
return result
@pipeline_step
def convert_data_types(
dataset: pd.DataFrame,
type_conversions: Dict[str, Union[str, Type, np.dtype]],
handle_errors: str = "coerce",
optimize_memory: bool = True,
use_batch_processing: bool = False,
chunk_size: int = 10000,
) -> pd.DataFrame:
"""
Convert data types for specified columns.
This function provides unified data type conversion with error handling options.
Supports pandas, numpy, and Python built-in types with memory optimization.
Parameters
----------
dataset : pd.DataFrame
Input dataset with columns to be converted
type_conversions : Dict[str, Union[str, Type, np.dtype]]
Type conversion mapping in format {column_name: target_type}
Supported formats:
- String types: 'float', 'int', 'str', 'category', 'bool', 'datetime'
- Numpy types: np.float32, np.float64, np.int32, np.int64, etc.
- Pandas types: 'Int64', 'Float64', 'string', 'boolean'
- Python types: float, int, str, bool
handle_errors : str, default='coerce'
Error handling strategy: 'raise', 'coerce', or 'ignore'
optimize_memory : bool, default=True
Whether to automatically optimize memory usage by choosing smaller dtypes
use_batch_processing : bool, default=False
Whether to use batch processing for large datasets
chunk_size : int, default=10000
Chunk size when using batch processing
Returns
-------
pd.DataFrame
Dataset with converted data types
Example
-------
>>> import pandas as pd
>>> import numpy as np
>>> df = pd.DataFrame({
... 'score': ['1.5', '2.3', '3.7'],
... 'count': ['10', '20', '30'],
... 'name': [123, 456, 789],
... 'flag': ['True', 'False', 'True']
... })
>>> conversions = {
... 'score': np.float32,
... 'count': 'Int64',
... 'name': 'string',
... 'flag': 'boolean'
... }
>>> result = convert_data_types(df, conversions)
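    >>>
    >>> # For large frames, conversion can run chunk by chunk
    >>> result = convert_data_types(
    ...     df, conversions, use_batch_processing=True, chunk_size=2
    ... )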
"""
tqdm.write("Converting data types...")
if use_batch_processing:
return _convert_data_types_batch(
dataset, type_conversions, handle_errors, optimize_memory, chunk_size
)
else:
return _convert_data_types(
dataset, type_conversions, handle_errors, optimize_memory
)
@multiout_step(main="success", failed="failed")
def validate_mutations(
dataset: pd.DataFrame,
mutation_column: str = "mut_info",
format_mutations: bool = True,
mutation_sep: str = ",",
is_zero_based: bool = False,
cache_results: bool = True,
num_workers: int = 4,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""
Validate and format mutation information.
This function validates mutation strings, optionally formats them to a standard
representation, and separates valid and invalid mutations into different datasets.
It supports caching for improved performance on datasets with repeated mutations.
Parameters
----------
dataset : pd.DataFrame
Input dataset containing mutation information
mutation_column : str, default='mut_info'
Name of the column containing mutation information
format_mutations : bool, default=True
Whether to format mutations to standard representation
mutation_sep : str, default=','
Separator used to split multiple mutations in a single string (e.g., 'A123B,C456D')
    is_zero_based : bool, default=False
        Whether the original mutation positions are zero-based
cache_results : bool, default=True
Whether to cache formatting results for performance
num_workers : int, default=4
Number of parallel workers for processing, set to -1 for all available CPUs
Returns
-------
Tuple[pd.DataFrame, pd.DataFrame]
(successful_dataset, failed_dataset) - datasets with valid and invalid mutations
Example
-------
>>> import pandas as pd
>>> df = pd.DataFrame({
... 'name': ['protein1', 'protein1', 'protein2'],
... 'mut_info': ['A123S', 'C456D,E789F', 'InvalidMut'],
... 'score': [1.5, 2.3, 3.7]
... })
>>> successful, failed = validate_mutations(df, mutation_column='mut_info', mutation_sep=',')
>>> print(len(successful)) # Should be 2 (valid mutations)
2
>>> print(successful['mut_info'].tolist()) # Formatted mutations
['A123S', 'C456D,E789F']
>>> print(len(failed)) # Should be 1 (invalid mutation)
1
    >>> print(failed['error_message'].iloc[0])  # Error message for failed mutation
    'ValueError: No valid mutations could be parsed...'
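    >>>
    >>> # The Manager-backed cache pays off on large frames with many repeated
    >>> # mutations; for small frames it is cheaper to disable it
    >>> successful, failed = validate_mutations(df, cache_results=False)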
"""
tqdm.write("Validating and formatting mutations...")
if mutation_column not in dataset.columns:
raise ValueError(f"Mutation column '{mutation_column}' not found")
result = dataset.copy()
original_len = len(result)
    # Optional cross-process cache (a Manager proxy shared between workers)
if cache_results:
from multiprocessing import Manager
manager = Manager()
cache = manager.dict()
else:
cache = None
# Prepare arguments for parallel processing
mutation_values = result[mutation_column].tolist()
args_list = [
(
mut_info,
format_mutations,
mutation_sep,
is_zero_based,
cache if cache_results else None,
)
for mut_info in mutation_values
]
# Parallel processing
results = Parallel(n_jobs=num_workers, backend="loky")(
delayed(valid_single_mutation)(args)
for args in tqdm(args_list, desc="Processing mutations")
)
# Separate formatted mutations and error messages
formatted_mutations, error_messages = map(list, zip(*results))
    # Add results to dataset ('result' is already a copy of the input)
    result_dataset = result
result_dataset["formatted_" + mutation_column] = formatted_mutations
result_dataset["error_message"] = error_messages
# Create success mask based on whether formatted mutation is available
success_mask = pd.notnull(result_dataset["formatted_" + mutation_column])
# Create successful dataset
successful_dataset = result_dataset[success_mask].copy()
    if format_mutations:
        # Replace original mutation column with formatted version
        successful_dataset[mutation_column] = successful_dataset[
            "formatted_" + mutation_column
        ]
    # Drop helper columns from the successful dataset in either case
    successful_dataset = successful_dataset.drop(
        columns=["formatted_" + mutation_column, "error_message"]
    )
# Create failed dataset
failed_dataset = result_dataset[~success_mask].copy()
failed_dataset = failed_dataset.drop(columns=["formatted_" + mutation_column])
    valid_pct = len(successful_dataset) / original_len * 100 if original_len else 0.0
    tqdm.write(
        f"Mutation validation: {len(successful_dataset)} successful, {len(failed_dataset)} failed "
        f"(out of {original_len} total, {valid_pct:.1f}% valid)"
    )
return successful_dataset, failed_dataset
@multiout_step(main="success", failed="failed")
def apply_mutations_to_sequences(
dataset: pd.DataFrame,
sequence_column: str = "sequence",
name_column: str = "name",
mutation_column: str = "mut_info",
position_columns: Optional[Dict[str, str]] = None,
mutation_sep: str = ",",
is_zero_based: bool = True,
sequence_type: str = "protein",
num_workers: int = 4,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""
Apply mutations to sequences to generate mutated sequences.
This function takes mutation information and applies it to wild-type sequences
to generate the corresponding mutated sequences. It supports parallel processing
and can handle position-based sequence extraction.
Parameters
----------
dataset : pd.DataFrame
Input dataset containing mutation information and sequence data
sequence_column : str, default='sequence'
Column name containing wild-type sequences
name_column : str, default='name'
Column name containing protein identifiers
mutation_column : str, default='mut_info'
Column name containing mutation information
position_columns : Optional[Dict[str, str]], default=None
Position column mapping {"start": "start_col", "end": "end_col"}
Used for extracting sequence regions
mutation_sep : str, default=','
Separator used to split multiple mutations in a single string
    is_zero_based : bool, default=True
        Whether the original mutation positions are zero-based
sequence_type : str, default='protein'
Type of sequence ('protein', 'dna', 'rna')
num_workers : int, default=4
Number of parallel workers for processing, set to -1 for all available CPUs
Returns
-------
Tuple[pd.DataFrame, pd.DataFrame]
(successful_dataset, failed_dataset) - datasets with and without errors
Example
-------
>>> import pandas as pd
>>> df = pd.DataFrame({
... 'name': ['prot1', 'prot1', 'prot2'],
... 'sequence': ['AKCDEF', 'AKCDEF', 'FEGHIS'],
... 'mut_info': ['A0K', 'C2D', 'E1F'],
... 'score': [1.0, 2.0, 3.0]
... })
>>> successful, failed = apply_mutations_to_sequences(df)
>>> print(successful['mut_seq'].tolist())
['KKCDEF', 'AKDDEF', 'FFGHIS']
>>> print(len(failed)) # Should be 0 if all mutations are valid
0
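    >>>
    >>> # DNA input works the same way via sequence_type (illustrative frame):
    >>> dna_df = pd.DataFrame({
    ...     'name': ['gene1'], 'sequence': ['ATGC'], 'mut_info': ['A0T']
    ... })
    >>> successful, failed = apply_mutations_to_sequences(dna_df, sequence_type='dna')
    >>> print(successful['mut_seq'].tolist())
    ['TTGC']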
"""
tqdm.write("Applying mutations to sequences...")
# Validate required columns exist
required_columns = [sequence_column, name_column, mutation_column]
missing_columns = [col for col in required_columns if col not in dataset.columns]
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
# Select appropriate sequence class based on sequence_type
sequence_type = sequence_type.lower()
if sequence_type == "protein":
SequenceClass = ProteinSequence
elif sequence_type == "dna":
SequenceClass = DNASequence
elif sequence_type == "rna":
SequenceClass = RNASequence
else:
raise ValueError(
f"Unsupported sequence type: {sequence_type}. Must be 'protein', 'dna', or 'rna'"
)
_apply_single_mutation = partial(
apply_single_mutation,
dataset_columns=dataset.columns,
sequence_column=sequence_column,
name_column=name_column,
mutation_column=mutation_column,
position_columns=position_columns,
mutation_sep=mutation_sep,
is_zero_based=is_zero_based,
sequence_class=SequenceClass,
)
# Parallel processing
rows = dataset.itertuples(index=False, name=None)
results = Parallel(n_jobs=num_workers, backend="loky")(
delayed(_apply_single_mutation)(row)
for row in tqdm(rows, total=len(dataset), desc="Applying mutations")
)
# Separate successful and failed results
mutated_seqs, error_messages = map(list, zip(*results))
result_dataset = dataset.copy()
result_dataset["mut_seq"] = mutated_seqs
result_dataset["error_message"] = error_messages
success_mask = pd.notnull(result_dataset["mut_seq"])
successful_dataset = result_dataset[success_mask].drop(columns=["error_message"])
failed_dataset = result_dataset[~success_mask].drop(columns=["mut_seq"])
tqdm.write(
f"Mutation application: {len(successful_dataset)} successful, {len(failed_dataset)} failed"
)
return successful_dataset, failed_dataset
@multiout_step(main="successful", failed="failed")
def infer_wildtype_sequences(
dataset: pd.DataFrame,
name_column: str = "name",
mutation_column: str = "mut_info",
sequence_column: str = "mut_seq",
label_columns: Optional[List[str]] = None,
wt_label: float = 0.0,
mutation_sep: str = ",",
is_zero_based: bool = False,
sequence_type: Literal["protein", "dna", "rna"] = "protein",
handle_multiple_wt: Literal["error", "separate", "first"] = "error",
num_workers: int = 4,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""
Infer wild-type sequences from mutated sequences and add WT rows.
This function takes mutated sequences and their corresponding mutations to
infer the original wild-type sequences. For each protein, it adds WT row(s)
to the dataset with the inferred wild-type sequence.
Parameters
----------
dataset : pd.DataFrame
Input dataset containing mutated sequences and mutation information
name_column : str, default='name'
Column name containing protein identifiers
mutation_column : str, default='mut_info'
Column name containing mutation information
sequence_column : str, default='mut_seq'
Column name containing mutated sequences
label_columns : Optional[List[str]], default=None
List of label column names to preserve
    wt_label : float, default=0.0
        Label value assigned to the added wild-type rows
    mutation_sep : str, default=','
        Separator used to split multiple mutations in a single string
    is_zero_based : bool, default=False
        Whether the original mutation positions are zero-based
    sequence_type : Literal["protein", "dna", "rna"], default='protein'
        Type of sequence ('protein', 'dna', 'rna')
    handle_multiple_wt : Literal["error", "separate", "first"], default='error'
        How to handle multiple inferred wild-type sequences: 'error', 'separate', or 'first'
num_workers : int, default=4
Number of parallel workers for processing, set to -1 for all available CPUs
Returns
-------
    Tuple[pd.DataFrame, pd.DataFrame]
        (successful_dataset, failed_dataset) - successful rows (including the
        added WT rows) and problematic rows
Example
-------
>>> import pandas as pd
>>> df = pd.DataFrame({
... 'name': ['prot1', 'prot1', 'prot2'],
... 'mut_info': ['A0S', 'C2D', 'E0F'],
... 'mut_seq': ['SQCDEF', 'AQDDEF', 'FGHIGHK'],
... 'score': [1.0, 2.0, 3.0]
... })
    >>> success, failed = infer_wildtype_sequences(
    ...     df, label_columns=['score'], is_zero_based=True
    ... )
>>> print(len(success)) # Should have original rows + WT rows
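    >>>
    >>> # When replicate rows imply conflicting WT sequences, pick a
    >>> # resolution policy explicitly (here: keep the first inferred WT)
    >>> success, failed = infer_wildtype_sequences(
    ...     df, label_columns=['score'], is_zero_based=True,
    ...     handle_multiple_wt='first'
    ... )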
"""
tqdm.write("Inferring wildtype sequences...")
if label_columns is None:
label_columns = [col for col in dataset.columns if col.startswith("label_")]
# Select appropriate sequence class based on sequence_type
if sequence_type.lower() == "protein":
SequenceClass = ProteinSequence
AlphabetClass = ProteinAlphabet
elif sequence_type.lower() == "dna":
SequenceClass = DNASequence
AlphabetClass = DNAAlphabet
elif sequence_type.lower() == "rna":
SequenceClass = RNASequence
AlphabetClass = RNAAlphabet
else:
raise ValueError(
f"Unsupported sequence type: {sequence_type.lower()}. Must be 'protein', 'dna', or 'rna'"
)
_process_protein_group = partial(
infer_wt_sequence_grouped,
name_column=name_column,
mutation_column=mutation_column,
sequence_column=sequence_column,
label_columns=label_columns,
wt_label=wt_label,
mutation_sep=mutation_sep,
is_zero_based=is_zero_based,
handle_multiple_wt=handle_multiple_wt,
sequence_class=SequenceClass,
alphabet_class=AlphabetClass,
)
# Group by protein and process in parallel
grouped = list(dataset.groupby(name_column, sort=False))
try:
results = Parallel(n_jobs=num_workers, backend="loky")(
delayed(_process_protein_group)(group_data)
for group_data in tqdm(grouped, desc="Processing proteins")
)
except Exception as e:
tqdm.write(
f"Warning: Parallel processing failed, falling back to sequential: {e}"
)
# Fallback to sequential processing
results = []
for group_data in tqdm(grouped, desc="Processing proteins (sequential)"):
try:
result = _process_protein_group(group_data)
results.append(result)
except Exception as group_e:
# Create error entry for this specific group
protein_name = group_data[0]
error_row = {
name_column: str(protein_name),
"error_message": f"Sequential processing error: {type(group_e).__name__}: {str(group_e)}",
}
results.append(([error_row], "failed"))
# Filter out None results and validate structure
valid_results = []
invalid_count = 0
for i, result in enumerate(results):
if result is None:
invalid_count += 1
tqdm.write(f"Warning: Result {i} is None, skipping")
continue
if not isinstance(result, tuple) or len(result) != 2:
invalid_count += 1
tqdm.write(f"Warning: Result {i} has invalid format, skipping: {result}")
continue
rows_list, category = result
if category not in ("success", "failed"):
invalid_count += 1
tqdm.write(
f"Warning: Result {i} has invalid category '{category}', skipping"
)
continue
if not isinstance(rows_list, list):
invalid_count += 1
tqdm.write(f"Warning: Result {i} has invalid rows format, skipping")
continue
valid_results.append(result)
if invalid_count > 0:
tqdm.write(f"Warning: {invalid_count} invalid results were skipped")
# Collect all rows
successful_rows = []
failed_rows = []
for rows_list, category in valid_results:
if category == "success":
successful_rows.extend(rows_list)
else:
failed_rows.extend(rows_list)
# Convert to DataFrame format
successful_df = pd.DataFrame(successful_rows) if successful_rows else pd.DataFrame()
failed_df = pd.DataFrame(failed_rows) if failed_rows else pd.DataFrame()
    tqdm.write(
        f"Wildtype inference: {len(successful_df)} successful rows, "
        f"{len(failed_df)} failed rows (WT rows added per protein)"
    )
return successful_df, failed_df