# Source code for tidymut.core.dataset

# tidymut/core/dataset.py
from __future__ import annotations

import pickle
import json
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from typing import cast, Any, Dict, TYPE_CHECKING

from .mutation import (
    MutationSet,
    AminoAcidMutationSet,
    CodonMutationSet,
    AminoAcidMutation,
    CodonMutation,
)
from .sequence import (
    DNASequence,
    ProteinSequence,
    RNASequence,
    load_sequences_from_fasta,
)

if TYPE_CHECKING:
    from typing import List, Literal, Optional, Sequence, Type, Union

    from .mutation import BaseMutation
    from .sequence import BaseSequence

__all__ = ["MutationDataset"]


def __dir__() -> List[str]:
    """Restrict the names advertised by dir() on this module to the public API."""
    return __all__


# Map the "sequence_type" labels found in saved dataset metadata to concrete
# sequence classes.  Both full class names (e.g. "DNASequence") and short
# aliases (e.g. "DNA") are accepted; used by load_by_reference().
SEQUENCE_TYPE_MAP = {
    "ProteinSequence": ProteinSequence,
    "DNASequence": DNASequence,
    "RNASequence": RNASequence,
    "DNA": DNASequence,
    "RNA": RNASequence,
    "Protein": ProteinSequence,
}


class MutationDataset:
    """
    Dataset container for cleaned mutation data with multiple reference sequences.

    Every mutation set stored here must be linked to one of the registered
    reference sequences.  This keeps the data internally consistent and makes
    validation and downstream analysis possible.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Create an empty MutationDataset.

        Parameters:
            name: Optional name for the dataset

        Note:
            Register reference sequences via add_reference_sequence() first,
            then add mutation sets with add_mutation_set(reference_id=...).
        """
        self.name = name
        # sequence_id -> sequence object
        self.reference_sequences: Dict[str, BaseSequence] = {}
        self.mutation_sets: List[MutationSet] = []
        # mutation_set index -> sequence_id it is linked to
        self.mutation_set_references: Dict[int, str] = {}
        # mutation_set index -> user-supplied label (semantics defined by caller)
        self.mutation_set_labels: Dict[int, Any] = {}
        self.metadata: Dict[str, Any] = {}
        # DataFrame cache built lazily by to_dataframe(); reset on any edit
        self._df: Optional[pd.DataFrame] = None

    def __len__(self) -> int:
        """Number of mutation sets in the dataset."""
        return len(self.mutation_sets)

    def __iter__(self):
        """
        Iterate over (mutation_set, reference_id) pairs.

        Yields:
            Tuple[MutationSet, str]: each stored mutation set together with
            the ID of the reference sequence it is linked to.
        """
        for index in range(len(self.mutation_sets)):
            yield self.mutation_sets[index], self.mutation_set_references[index]

    def __str__(self) -> str:
        stats = self.get_statistics()
        n_refs = stats["num_reference_sequences"]
        if n_refs > 0:
            ref_info = f" ({n_refs} reference sequences)"
        else:
            ref_info = " (no references)"
        summary = (
            f"MutationDataset({self.name}){ref_info}: "
            f"{stats['total_mutation_sets']} mutation sets, "
            f"{stats['total_mutations']} mutations"
        )
        return summary
[docs] def add_reference_sequence(self, sequence_id: str, sequence: BaseSequence): """Add a reference sequence with a unique identifier""" if sequence_id in self.reference_sequences: raise ValueError( f"Reference sequence with ID '{sequence_id}' already exists" ) self.reference_sequences[sequence_id] = sequence self._df = None # Reset cached DataFrame
[docs] def remove_reference_sequence(self, sequence_id: str): """Remove a reference sequence""" if sequence_id not in self.reference_sequences: raise ValueError(f"Reference sequence with ID '{sequence_id}' not found") # Check if any mutation sets reference this sequence referencing_sets = [ idx for idx, ref_id in self.mutation_set_references.items() if ref_id == sequence_id ] if referencing_sets: raise ValueError( f"Cannot remove sequence '{sequence_id}' as it is referenced by " f"{len(referencing_sets)} mutation sets. Remove the mutation sets first." ) del self.reference_sequences[sequence_id] self._df = None
[docs] def get_reference_sequence(self, sequence_id: str) -> BaseSequence: """Get a reference sequence by ID""" if sequence_id not in self.reference_sequences: raise ValueError(f"Reference sequence with ID '{sequence_id}' not found") return self.reference_sequences[sequence_id]
[docs] def list_reference_sequences(self) -> List[str]: """Get list of all reference sequence IDs""" return list(self.reference_sequences.keys())
[docs] def add_mutation_set( self, mutation_set: MutationSet, reference_id: str, label: Optional[float] = None, ): """Add a mutation set to the dataset, linking to a reference sequence""" if reference_id not in self.reference_sequences: raise ValueError(f"Reference sequence with ID '{reference_id}' not found") mutation_set_index = len(self.mutation_sets) self.mutation_sets.append(mutation_set) self.mutation_set_references[mutation_set_index] = reference_id self.mutation_set_labels[mutation_set_index] = label self._df = None # Reset cached DataFrame
[docs] def add_mutation_sets( self, mutation_sets: Sequence[MutationSet], reference_ids: Sequence[str], labels: Optional[Sequence[float]] = None, ): """Add multiple mutation sets to the dataset""" if len(reference_ids) != len(mutation_sets): raise ValueError( "Number of reference_ids must match number of mutation_sets" ) if labels is not None and len(labels) != len(mutation_sets): raise ValueError("Number of labels must match number of mutation_sets") for i, (mutation_set, ref_id) in enumerate(zip(mutation_sets, reference_ids)): label = labels[i] if labels is not None else None self.add_mutation_set(mutation_set, ref_id, label)
[docs] def set_mutation_set_reference(self, mutation_set_index: int, reference_id: str): """Set the reference sequence for a specific mutation set""" if mutation_set_index >= len(self.mutation_sets): raise ValueError(f"Mutation set index {mutation_set_index} out of range") if reference_id not in self.reference_sequences: raise ValueError(f"Reference sequence with ID '{reference_id}' not found") self.mutation_set_references[mutation_set_index] = reference_id self._df = None
[docs] def get_mutation_set_reference(self, mutation_set_index: int) -> str: """Get the reference sequence ID for a specific mutation set""" if mutation_set_index >= len(self.mutation_sets): raise ValueError(f"Mutation set index {mutation_set_index} out of range") return self.mutation_set_references[mutation_set_index]
[docs] def set_mutation_set_label(self, mutation_set_index: int, label: float): """Set the label for a specific mutation set""" if mutation_set_index >= len(self.mutation_sets): raise ValueError(f"Mutation set index {mutation_set_index} out of range") self.mutation_set_labels[mutation_set_index] = label self._df = None
[docs] def get_mutation_set_label(self, mutation_set_index: int) -> Any: """Get the label for a specific mutation set""" if mutation_set_index >= len(self.mutation_sets): raise ValueError(f"Mutation set index {mutation_set_index} out of range") return self.mutation_set_labels.get(mutation_set_index)
[docs] def remove_mutation_set(self, mutation_set_index: int): """Remove a mutation set from the dataset""" if mutation_set_index >= len(self.mutation_sets): raise ValueError(f"Mutation set index {mutation_set_index} out of range") # Remove the mutation set del self.mutation_sets[mutation_set_index] # Update the reference mapping (shift indices) new_references = {} new_labels = {} for idx, ref_id in self.mutation_set_references.items(): if idx < mutation_set_index: new_references[idx] = ref_id new_labels[idx] = self.mutation_set_labels.get(idx) elif idx > mutation_set_index: new_references[idx - 1] = ref_id new_labels[idx - 1] = self.mutation_set_labels.get(idx) # Skip the removed index self.mutation_set_references = new_references self.mutation_set_labels = new_labels self._df = None
    def validate_against_references(self) -> Dict[str, Any]:
        """
        Validate mutations against their reference sequences.

        For every stored mutation set this checks:
          * that each mutation's position lies within its reference sequence,
          * for amino-acid mutations on ProteinSequence references, that the
            recorded wild-type residue matches the reference at that position,
          * for codon mutations on DNA/RNA references, that the recorded
            wild-type codon matches the reference codon (the mutation position
            is treated as a codon index, i.e. nucleotide offset = position * 3).

        Returns:
            Dict with keys "valid_mutation_sets", "invalid_mutation_sets" and
            "position_mismatches", each a list of per-issue dicts.  Note that a
            wild-type mismatch is reported under "position_mismatches" but does
            NOT mark the set invalid; only bounds/parse errors do.
        """
        validation_results = {
            "valid_mutation_sets": [],
            "invalid_mutation_sets": [],
            "position_mismatches": [],
        }

        for i, mutation_set in enumerate(self.mutation_sets):
            set_name = mutation_set.name or f"MutationSet_{i}"

            # All mutation sets must have a reference sequence
            reference_id = self.mutation_set_references[i]  # This should always exist
            reference_sequence = self.reference_sequences[reference_id]

            set_valid = True

            for mutation in mutation_set.mutations:
                # Check if mutation position is within sequence bounds
                if mutation.position >= len(reference_sequence):
                    validation_results["invalid_mutation_sets"].append(
                        {
                            "mutation_set": set_name,
                            "reference_id": reference_id,
                            "mutation": str(mutation),
                            "error": f"Position {mutation.position} exceeds sequence length (0-indexed)",
                        }
                    )
                    set_valid = False
                    continue

                # Check if wild type matches reference for amino acid mutations
                if isinstance(mutation, AminoAcidMutation) and isinstance(
                    reference_sequence, ProteinSequence
                ):
                    try:
                        ref_residue = reference_sequence.get_residue(mutation.position)
                        if ref_residue != mutation.wild_amino_acid:
                            # Recorded as a mismatch only; does not flip set_valid.
                            validation_results["position_mismatches"].append(
                                {
                                    "mutation_set": set_name,
                                    "reference_id": reference_id,
                                    "mutation": str(mutation),
                                    "expected": ref_residue,
                                    "found": mutation.wild_amino_acid,
                                    "position": mutation.position,
                                }
                            )
                    except IndexError:
                        validation_results["invalid_mutation_sets"].append(
                            {
                                "mutation_set": set_name,
                                "reference_id": reference_id,
                                "mutation": str(mutation),
                                "error": f"Position {mutation.position} out of range",
                            }
                        )
                        set_valid = False

                # Check codon mutations for nucleotide sequences
                elif isinstance(mutation, CodonMutation) and isinstance(
                    reference_sequence, (DNASequence, RNASequence)
                ):
                    try:
                        # Assuming position is codon position, get the codon at this position
                        start_pos = mutation.position * 3
                        if start_pos + 3 <= len(reference_sequence):
                            ref_codon = str(
                                reference_sequence[start_pos : start_pos + 3]
                            ).upper()
                            if ref_codon != mutation.wild_codon:
                                validation_results["position_mismatches"].append(
                                    {
                                        "mutation_set": set_name,
                                        "reference_id": reference_id,
                                        "mutation": str(mutation),
                                        "expected": ref_codon,
                                        "found": mutation.wild_codon,
                                        "position": mutation.position,
                                    }
                                )
                        else:
                            validation_results["invalid_mutation_sets"].append(
                                {
                                    "mutation_set": set_name,
                                    "reference_id": reference_id,
                                    "mutation": str(mutation),
                                    "error": f"Codon position {mutation.position} exceeds sequence bounds",
                                }
                            )
                            set_valid = False
                    except Exception as e:
                        validation_results["invalid_mutation_sets"].append(
                            {
                                "mutation_set": set_name,
                                "reference_id": reference_id,
                                "mutation": str(mutation),
                                "error": f"Error validating codon: {str(e)}",
                            }
                        )
                        set_valid = False

            if set_valid:
                validation_results["valid_mutation_sets"].append(
                    {
                        "mutation_set": set_name,
                        "reference_id": reference_id,
                    }
                )

        return validation_results
    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert dataset to pandas DataFrame.

        One row is emitted per individual mutation; columns describing the
        enclosing mutation set (reference info, label, set metadata with a
        ``set_`` prefix) are repeated on every row of that set.  The result is
        cached in ``self._df`` and reused until the dataset is modified.
        """
        if self._df is None:
            data = []
            for i, mutation_set in tqdm(
                enumerate(self.mutation_sets), desc="Converting dataset to DataFrame: "
            ):
                reference_id = self.mutation_set_references[
                    i
                ]  # This should always exist
                reference_sequence = self.reference_sequences[reference_id]
                label = self.mutation_set_labels.get(i)

                # Per-set columns shared by every mutation row of this set.
                base_data = {
                    "mutation_set_id": i,
                    "mutation_set_name": mutation_set.name,
                    "reference_id": reference_id,
                    "reference_sequence_name": reference_sequence.name,
                    "reference_sequence_length": len(reference_sequence),
                    "reference_sequence_type": type(reference_sequence).__name__,
                    "num_mutations": len(mutation_set),
                    "is_single_mutation": mutation_set.is_single_mutation(),
                    "is_valid": mutation_set.validate_all(),
                    "positions": ",".join(map(str, mutation_set.get_positions())),
                    "mutation_subtype": mutation_set.mutation_subtype,
                    "label": label,
                }

                # Add mutation set metadata
                for key, value in mutation_set.metadata.items():
                    base_data[f"set_{key}"] = value

                # Add mutation information
                for j, mutation in enumerate(mutation_set.mutations):
                    mutation_data = base_data.copy()
                    mutation_data.update(
                        {
                            "mutation_id": j,
                            "mutation_type": mutation.type,
                            "mutation_string": str(mutation),
                            "position": mutation.position,
                            "mutation_category": mutation.get_mutation_category(),
                        }
                    )

                    # Add amino acid mutation-specific data
                    if isinstance(mutation, AminoAcidMutation):
                        mutation_data.update(
                            {
                                "wild_amino_acid": mutation.wild_amino_acid,
                                "mutant_amino_acid": mutation.mutant_amino_acid,
                                "effect_type": mutation.effect_type,
                                "is_synonymous": mutation.is_synonymous(),
                                "is_nonsense": mutation.is_nonsense(),
                                "is_missense": mutation.is_missense(),
                            }
                        )

                        # Add reference residue if available
                        if isinstance(reference_sequence, ProteinSequence):
                            try:
                                ref_residue = reference_sequence.get_residue(
                                    mutation.position
                                )
                                mutation_data["reference_residue"] = ref_residue
                                mutation_data["wild_type_matches_reference"] = (
                                    ref_residue == mutation.wild_amino_acid
                                )
                            except IndexError:
                                # Out-of-range position: record the failure rather
                                # than aborting the whole conversion.
                                mutation_data["reference_residue"] = None
                                mutation_data["wild_type_matches_reference"] = False

                    # Add codon mutation-specific data
                    elif isinstance(mutation, CodonMutation):
                        mutation_data.update(
                            {
                                "wild_codon": mutation.wild_codon,
                                "mutant_codon": mutation.mutant_codon,
                                "seq_type": mutation.seq_type,
                            }
                        )

                        # Add reference codon if available
                        if isinstance(reference_sequence, (DNASequence, RNASequence)):
                            try:
                                # Get the codon at this position (assuming position is codon position)
                                start_pos = mutation.position * 3
                                if start_pos + 3 <= len(reference_sequence):
                                    ref_codon = str(
                                        reference_sequence[start_pos : start_pos + 3]
                                    )
                                    mutation_data["reference_codon"] = ref_codon
                                    mutation_data["wild_codon_matches_reference"] = (
                                        ref_codon.upper() == mutation.wild_codon
                                    )
                                else:
                                    mutation_data["reference_codon"] = None
                                    mutation_data["wild_codon_matches_reference"] = (
                                        False
                                    )
                            except Exception:
                                mutation_data["reference_codon"] = None
                                mutation_data["wild_codon_matches_reference"] = False

                    # Add mutation metadata
                    for key, value in mutation.metadata.items():
                        mutation_data[f"mutation_{key}"] = value

                    data.append(mutation_data)

            self._df = pd.DataFrame(data)

        return self._df
[docs] def filter_by_reference(self, reference_id: str) -> "MutationDataset": """Filter dataset to only include mutation sets from a specific reference sequence""" if reference_id not in self.reference_sequences: raise ValueError(f"Reference sequence with ID '{reference_id}' not found") filtered_sets = [] filtered_references = [] filtered_labels = [] for i, mutation_set in tqdm( enumerate(self.mutation_sets), desc="Filtering by reference: " ): if self.mutation_set_references[i] == reference_id: filtered_sets.append(mutation_set) filtered_references.append(reference_id) filtered_labels.append(self.mutation_set_labels.get(i)) filtered_dataset = MutationDataset( name=f"{self.name}_{reference_id}" if self.name else reference_id ) filtered_dataset.add_reference_sequence( reference_id, self.reference_sequences[reference_id] ) filtered_dataset.add_mutation_sets( filtered_sets, filtered_references, filtered_labels ) return filtered_dataset
    def filter_by_mutation_type(
        self, mutation_type: Type[BaseMutation]
    ) -> "MutationDataset":
        """
        Filter dataset by mutation type.

        Keeps, from every mutation set, only the mutations that are instances
        of *mutation_type*; sets left with no matching mutations are dropped.
        Surviving sets are rewrapped in the matching set class
        (AminoAcidMutationSet / CodonMutationSet / generic MutationSet) and a
        new dataset is returned that contains only the still-needed reference
        sequences, along with the original labels.
        """
        filtered_sets = []
        filtered_references = []
        filtered_labels = []

        for i, mutation_set in tqdm(
            enumerate(self.mutation_sets), desc="Filtering by mutation type: "
        ):
            # Filter mutations by type
            filtered_mutations = [
                m for m in mutation_set.mutations if isinstance(m, mutation_type)
            ]

            if filtered_mutations:
                # Create new mutation set with filtered mutations
                if mutation_type == AminoAcidMutation:
                    new_set = AminoAcidMutationSet(
                        mutations=filtered_mutations,  # type: ignore
                        name=(
                            f"{mutation_set.name}_filtered"
                            if mutation_set.name
                            else "filtered"
                        ),
                        metadata=mutation_set.metadata.copy(),
                    )
                elif mutation_type == CodonMutation:
                    new_set = CodonMutationSet(
                        mutations=filtered_mutations,  # type: ignore
                        name=(
                            f"{mutation_set.name}_filtered"
                            if mutation_set.name
                            else "filtered"
                        ),
                        metadata=mutation_set.metadata.copy(),
                    )
                else:
                    # Generic fallback keeps the requested type on the set itself.
                    new_set = MutationSet(
                        mutations=filtered_mutations,
                        mutation_type=mutation_type,
                        name=(
                            f"{mutation_set.name}_filtered"
                            if mutation_set.name
                            else "filtered"
                        ),
                        metadata=mutation_set.metadata.copy(),
                    )

                filtered_sets.append(new_set)
                ref_id = self.mutation_set_references[i]
                filtered_references.append(ref_id)
                filtered_labels.append(self.mutation_set_labels.get(i))

        filtered_dataset = MutationDataset(
            name=f"{self.name}_filtered" if self.name else "filtered"
        )

        # Copy all reference sequences that are still needed
        needed_refs = set(filtered_references)
        for ref_id in needed_refs:
            if ref_id is not None:
                filtered_dataset.add_reference_sequence(
                    ref_id, self.reference_sequences[ref_id]
                )

        filtered_dataset.add_mutation_sets(
            filtered_sets, filtered_references, filtered_labels
        )
        return filtered_dataset
[docs] def filter_by_effect_type(self, effect_type: str) -> "MutationDataset": """Filter dataset by amino acid mutation effect type (synonymous, missense, nonsense)""" filtered_sets = [] filtered_references = [] filtered_labels = [] for i, mutation_set in tqdm( enumerate(self.mutation_sets), desc="Filtering by effect type: " ): # Filter amino acid mutations by effect type filtered_mutations = [] for mutation in mutation_set.mutations: if isinstance(mutation, AminoAcidMutation): if mutation.effect_type == effect_type: filtered_mutations.append(mutation) if filtered_mutations: new_set = AminoAcidMutationSet( mutations=filtered_mutations, name=( f"{mutation_set.name}_{effect_type}" if mutation_set.name else effect_type ), metadata=mutation_set.metadata.copy(), ) filtered_sets.append(new_set) ref_id = self.mutation_set_references[i] # Always exists now filtered_references.append(ref_id) filtered_labels.append(self.mutation_set_labels.get(i)) filtered_dataset = MutationDataset( name=f"{self.name}_{effect_type}" if self.name else effect_type ) # Copy all reference sequences that are still needed needed_refs = set(filtered_references) for ref_id in needed_refs: if ref_id is not None: filtered_dataset.add_reference_sequence( ref_id, self.reference_sequences[ref_id] ) filtered_dataset.add_mutation_sets( filtered_sets, filtered_references, filtered_labels ) return filtered_dataset
    def get_statistics(self) -> Dict[str, Any]:
        """
        Get basic statistics about the dataset.

        Returns:
            Dict with overall counts (mutation sets, mutations, single vs.
            multi-mutation sets, average mutations per set), frequency tables
            for mutation types / categories / amino-acid effect types, and a
            per-reference-sequence breakdown under "reference_sequences".
        """
        total_sets = len(self.mutation_sets)
        total_mutations = sum(len(ms) for ms in self.mutation_sets)
        single_mutation_sets = sum(
            1 for ms in self.mutation_sets if ms.is_single_mutation()
        )
        multiple_mutation_sets = total_sets - single_mutation_sets

        mutation_types = {}
        mutation_categories = {}
        effect_types = {}
        reference_stats = {}

        # Statistics by reference sequence
        for ref_id, sequence in tqdm(
            self.reference_sequences.items(), desc="Statistics - ref seq: "
        ):
            reference_stats[ref_id] = {
                "sequence_name": sequence.name,
                "sequence_length": len(sequence),
                "sequence_type": type(sequence).__name__,
                "mutation_sets": 0,
                "mutations": 0,
            }

        for i, mutation_set in tqdm(
            enumerate(self.mutation_sets), desc="Statistics - mutation sets: "
        ):
            ref_id = self.mutation_set_references[i]  # Always exists now
            if ref_id in reference_stats:
                reference_stats[ref_id]["mutation_sets"] += 1
                reference_stats[ref_id]["mutations"] += len(mutation_set)

            for mutation in mutation_set.mutations:
                # Count mutation types
                mut_type = mutation.type
                mutation_types[mut_type] = mutation_types.get(mut_type, 0) + 1

                # Count mutation categories
                category = mutation.get_mutation_category()
                mutation_categories[category] = mutation_categories.get(category, 0) + 1

                # Count effect types for amino acid mutations
                if isinstance(mutation, AminoAcidMutation):
                    effect = mutation.effect_type
                    effect_types[effect] = effect_types.get(effect, 0) + 1

        stats = {
            "total_mutation_sets": total_sets,
            "total_mutations": total_mutations,
            "single_mutation_sets": single_mutation_sets,
            "multiple_mutation_sets": multiple_mutation_sets,
            "mutation_types": mutation_types,
            "mutation_categories": mutation_categories,
            "effect_types": effect_types,
            "average_mutations_per_set": (
                total_mutations / total_sets if total_sets > 0 else 0
            ),
            "reference_sequences": reference_stats,
            "num_reference_sequences": len(self.reference_sequences),
        }

        return stats
[docs] def get_position_coverage( self, reference_id: Optional[str] = None ) -> Dict[str, Any]: """Get statistics about position coverage across reference sequences""" if reference_id is not None: if reference_id not in self.reference_sequences: raise ValueError( f"Reference sequence with ID '{reference_id}' not found" ) return self._get_single_sequence_coverage(reference_id) else: # Get coverage for all sequences coverage_stats = {} for ref_id in self.reference_sequences: coverage_stats[ref_id] = self._get_single_sequence_coverage(ref_id) return coverage_stats
def _get_single_sequence_coverage(self, reference_id: str) -> Dict[str, Any]: """Get position coverage for a single reference sequence""" sequence = self.reference_sequences[reference_id] all_positions = set() for i, mutation_set in enumerate(self.mutation_sets): if self.mutation_set_references[i] == reference_id: # Always exists now all_positions.update(mutation_set.get_positions()) seq_length = len(sequence) covered_positions = len(all_positions) coverage_percentage = ( (covered_positions / seq_length) * 100 if seq_length > 0 else 0 ) return { "reference_id": reference_id, "sequence_name": sequence.name, "sequence_length": seq_length, "sequence_type": type(sequence).__name__, "covered_positions": covered_positions, "uncovered_positions": seq_length - covered_positions, "coverage_percentage": coverage_percentage, "position_list": sorted(list(all_positions)), }
[docs] def convert_codon_to_amino_acid_sets( self, convert_labels: bool = False ) -> "MutationDataset": """ Convert all codon mutation sets to amino acid mutation sets Parameters: convert_labels: Whether to save the labels with the mutation sets (default: False) """ converted_sets = [] converted_references = [] converted_labels = [] for i, mutation_set in enumerate(self.mutation_sets): if isinstance(mutation_set, CodonMutationSet): aa_set = mutation_set.to_amino_acid_mutation_set() converted_sets.append(aa_set) else: converted_sets.append(mutation_set) ref_id = self.mutation_set_references[i] converted_references.append(ref_id) converted_labels.append(self.get_mutation_set_label(i)) converted_dataset = MutationDataset( name=f"{self.name}_aa_converted" if self.name else "aa_converted" ) # Copy all reference sequences for ref_id, sequence in self.reference_sequences.items(): converted_dataset.add_reference_sequence(ref_id, sequence) if not convert_labels: converted_dataset.add_mutation_sets(converted_sets, converted_references) else: converted_dataset.add_mutation_sets( converted_sets, converted_references, converted_labels ) return converted_dataset
@staticmethod def _sanitize_filename(name: str) -> str: """Sanitize a string to be used as a filename or directory name""" import re # Replace invalid characters with underscores sanitized = re.sub(r'[<>:"/\\|?*]', "_", name) # Remove leading/trailing whitespace and dots sanitized = sanitized.strip(". ") # Ensure it's not empty if not sanitized: sanitized = "unnamed" # Limit length to avoid filesystem issues if len(sanitized) > 200: sanitized = sanitized[:200] return sanitized
    def save_by_reference(self, base_dir: Union[str, Path]) -> None:
        """
        Save dataset by reference_id, creating separate folders for each reference.

        Parameters:
            base_dir: Base directory to create reference folders in

        For each reference_id, creates:
        - {base_dir}/{reference_id}/data.csv: mutation data with columns
          [mutation_name, mutated_sequence, label]
        - {base_dir}/{reference_id}/wt.fasta: wild-type reference sequence
        - {base_dir}/{reference_id}/metadata.json: statistics and metadata
          for this reference

        Note: reference IDs are sanitized for use as directory names; the
        original ID is preserved in metadata.json under "reference_id".
        """
        base_path = Path(base_dir)
        base_path.mkdir(parents=True, exist_ok=True)

        tqdm.write(f"Saving dataset by reference to: {base_path}")

        # Group mutation sets by reference_id
        ref_groups = {}
        for i, mutation_set in tqdm(
            enumerate(self.mutation_sets), desc="Grouping mutation sets by reference_id"
        ):
            ref_id = self.mutation_set_references[i]
            if ref_id not in ref_groups:
                ref_groups[ref_id] = []
            ref_groups[ref_id].append((i, mutation_set))

        # Process each reference
        for ref_id, mutation_set_list in tqdm(
            ref_groups.items(), desc="Processing references"
        ):
            # Create sanitized directory name
            sanitized_ref_id = self._sanitize_filename(ref_id)
            ref_dir = base_path / sanitized_ref_id
            ref_dir.mkdir(exist_ok=True)

            # Get reference sequence
            ref_sequence = self.reference_sequences[ref_id]

            # Prepare data for CSV
            csv_data = []
            for set_index, mutation_set in mutation_set_list:
                # Get mutation name (string representation of mutation set)
                mutation_name = str(mutation_set)

                # Apply mutations to get mutated sequence
                try:
                    mutated_sequence = ref_sequence.apply_mutation(mutation_set)
                    mutated_seq_str = str(mutated_sequence)
                except Exception as e:
                    # Best-effort: record a sentinel instead of aborting the save.
                    tqdm.write(
                        f"Warning: Could not apply mutations for {mutation_name}: {e}"
                    )
                    mutated_seq_str = "ERROR_APPLYING_MUTATION"

                # Get label
                label = self.mutation_set_labels.get(set_index, "")

                csv_data.append(
                    {
                        "mutation_name": mutation_name,
                        "mutated_sequence": mutated_seq_str,
                        "label": label,
                    }
                )

            # Save data.csv
            csv_path = ref_dir / "data.csv"
            df_ref = pd.DataFrame(csv_data)
            df_ref.to_csv(csv_path, index=False)

            # Save wt.fasta
            fasta_path = ref_dir / "wt.fasta"
            with open(fasta_path, "w") as f:
                # Use sequence name if available, otherwise use ref_id
                seq_name = ref_sequence.name if ref_sequence.name else ref_id
                f.write(f">{seq_name}\n")
                f.write(f"{str(ref_sequence)}\n")

            # Prepare metadata
            ref_stats = self._get_single_sequence_coverage(ref_id)
            metadata = {
                "reference_id": ref_id,
                "sequence_name": ref_sequence.name,
                "sequence_type": type(ref_sequence).__name__,
                "sequence_length": len(ref_sequence),
                "num_mutation_sets": len(mutation_set_list),
                "total_mutations": sum(len(ms) for _, ms in mutation_set_list),
                "coverage_stats": ref_stats,
                "dataset_name": self.name,
                "sanitized_directory_name": sanitized_ref_id,
            }

            # Add label distribution
            label_counts = {}
            for set_index, _ in mutation_set_list:
                label = self.mutation_set_labels.get(set_index, "unlabeled")
                label_counts[str(label)] = label_counts.get(str(label), 0) + 1
            metadata["label_distribution"] = label_counts

            # Save metadata.json
            metadata_path = ref_dir / "metadata.json"
            with open(metadata_path, "w") as f:
                # default=str stringifies any non-JSON-serializable values.
                json.dump(metadata, f, indent=2, default=str)
    def save(
        self,
        filepath: str,
        save_type: Optional[Literal["tidymut", "pickle", "dataframe"]] = "tidymut",
    ):
        """
        Save the dataset to files.

        Parameters:
            filepath: Base filepath (without extension)
            save_type: Type of save format ("tidymut", "dataframe" or "pickle")

        For save_type="dataframe":
        - Saves mutations as {filepath}.csv
        - Saves reference sequences as {filepath}_refs.pkl
        - Saves metadata as {filepath}_meta.json

        For save_type="pickle":
        - Saves entire dataset as {filepath}.pkl

        For save_type="tidymut":
        - Treats {filepath} as a directory and writes one folder per
          reference (see save_by_reference)

        Raises:
            ValueError: If save_type is unsupported, or if save_type="tidymut"
                and filepath carries a file extension.

        Example:
            dataset.save("my_study", "dataframe")
            # Creates: my_study.csv, my_study_refs.pkl, my_study_meta.json
        """
        base_path = Path(filepath)

        if save_type == "dataframe":
            # Save mutations as CSV
            df = self.to_dataframe()
            csv_path = base_path.with_suffix(".csv")
            df.to_csv(csv_path, index=False)

            # Save reference sequences as pickle
            refs_path = base_path.with_suffix("").with_name(
                f"{base_path.name}_refs.pkl"
            )
            with open(refs_path, "wb") as f:
                pickle.dump(self.reference_sequences, f)

            # Save dataset metadata as JSON
            meta_path = base_path.with_suffix("").with_name(
                f"{base_path.name}_meta.json"
            )
            dataset_meta = {
                "name": self.name,
                "metadata": self.metadata,
                "save_type": save_type,
                "num_mutation_sets": len(self.mutation_sets),
                "num_reference_sequences": len(self.reference_sequences),
            }
            with open(meta_path, "w") as f:
                json.dump(dataset_meta, f, indent=2)

            tqdm.write(f"Dataset saved to:")
            tqdm.write(f" Mutations: {csv_path}")
            tqdm.write(f" References: {refs_path}")
            tqdm.write(f" Metadata: {meta_path}")

        elif save_type == "pickle":
            # Save entire dataset as pickle
            pkl_path = base_path.with_suffix(".pkl")
            with open(pkl_path, "wb") as f:
                pickle.dump(self, f)

            tqdm.write(f"Dataset saved to: {pkl_path}")

        elif save_type == "tidymut":
            # Save as TidyMut format; the target must be a directory path.
            if base_path.suffix != "":
                raise ValueError(
                    f"Invalid TidyMut save format. Expected folder but got {base_path.suffix}."
                )
            self.save_by_reference(base_path)

        else:
            raise ValueError(
                f"Unsupported save_type: {save_type}. Use 'tidymut', 'dataframe' or 'pickle'"
            )
# ====== load ======
[docs] @classmethod def load_by_reference( cls, base_dir: Union[str, Path], dataset_name: Optional[str] = None, is_zero_based: bool = True, ) -> "MutationDataset": """ Load a dataset from tidymut reference-based format. Parameters ---------- base_dir : Union[str, Path] Base directory containing reference folders dataset_name : Optional[str], default=None Optional name for the loaded dataset is_zero_based : bool, default=True Whether origin mutation positions are zero-based Returns ------- MutationDataset instance Expected directory structure: base_dir/ ├── reference_id_1/ │ ├── data.csv │ ├── wt.fasta │ └── metadata.json ├── reference_id_2/ │ ├── data.csv │ ├── wt.fasta │ └── metadata.json └── ... """ import json base_path = Path(base_dir) if not base_path.exists(): raise FileNotFoundError(f"Base directory not found: {base_path}") if not base_path.is_dir(): raise ValueError(f"Path is not a directory: {base_path}") tqdm.write(f"Loading dataset from: {base_path}") # Find all reference directories ref_dirs = [d for d in base_path.iterdir() if d.is_dir()] if not ref_dirs: raise ValueError(f"No reference directories found in {base_path}") # Create new dataset dataset = cls(name=dataset_name) # Process each reference directory for ref_dir in tqdm(ref_dirs, desc="Loading references"): # Required files data_path = ref_dir / "data.csv" fasta_path = ref_dir / "wt.fasta" metadata_path = ref_dir / "metadata.json" # Check required files exist if not data_path.exists(): tqdm.write(f"Warning: Skipping {ref_dir.name} - data.csv not found") continue if not fasta_path.exists(): tqdm.write(f"Warning: Skipping {ref_dir.name} - wt.fasta not found") continue # Load metadata to get original reference_id original_ref_id = ref_dir.name # Default to directory name if metadata_path.exists(): try: with open(metadata_path, "r") as f: metadata = json.load(f) original_ref_id = metadata.get("reference_id", ref_dir.name) except Exception as e: tqdm.write( f"Warning: Could not load metadata for 
{ref_dir.name}: {e}" ) # Load reference sequence from FASTA sequence_type = SEQUENCE_TYPE_MAP.get( metadata.get("sequence_type", "ProteinSequence"), ProteinSequence ) ref_sequence = list( load_sequences_from_fasta( fasta_path, sequence_type, header_func=lambda x: (x, "") ).values() )[ 0 ] # get first sequence when mutli sequences dataset.add_reference_sequence(original_ref_id, ref_sequence) # Load mutation data try: df_ref = pd.read_csv(data_path) required_cols = ["mutation_name", "mutated_sequence", "label"] missing_cols = [ col for col in required_cols if col not in df_ref.columns ] if missing_cols: tqdm.write( f"Warning: Skipping {ref_dir.name} - missing columns: {missing_cols}" ) continue # Process each mutation set for _, row in df_ref.iterrows(): mutation_name = row["mutation_name"] label = row["label"] # Parse mutation from mutation_name try: mutation_set = MutationSet.from_string( mutation_name, sep=",", is_zero_based=is_zero_based ) dataset.add_mutation_set(mutation_set, original_ref_id, label) except Exception as e: tqdm.write( f"Warning: Could not parse mutation '{mutation_name}': {e}" ) continue tqdm.write(f" Loaded {original_ref_id}: {len(df_ref)} mutation sets") except Exception as e: tqdm.write(f"Warning: Could not load data for {ref_dir.name}: {e}") continue if len(dataset) == 0: raise ValueError("No valid mutation sets were loaded") tqdm.write(f"Successfully loaded dataset with {len(dataset)} mutation sets") return dataset
@classmethod
def load(cls, filepath: str, load_type: Optional[str] = None) -> "MutationDataset":
    """
    Load a dataset from files.

    Parameters
    ----------
    filepath : str
        Base filepath (with or without extension)
    load_type : Optional[str], default=None
        Type of load format ("tidymut", "dataframe" or "pickle").
        If None, auto-detect from file extension.

    Returns
    -------
    MutationDataset instance

    Raises
    ------
    FileNotFoundError
        If a required data file does not exist.
    ValueError
        If ``load_type`` is not one of the supported formats.

    Example
    -------
    >>> # Auto-detect from extension
    >>> dataset = MutationDataset.load("my_study.csv")
    >>> dataset = MutationDataset.load("my_study.pkl")
    >>> # Explicit type
    >>> dataset = MutationDataset.load("my_study", "dataframe")
    """
    base_path = Path(filepath)

    # Auto-detect load type from extension if not specified
    if load_type is None:
        if base_path.suffix == ".csv":
            load_type = "dataframe"
        elif base_path.suffix == ".pkl":
            load_type = "pickle"
        elif base_path.suffix == "":
            load_type = "tidymut"
        else:
            # Unknown extension: try dataframe format first
            load_type = "dataframe"
            base_path = base_path.with_suffix("")  # Remove any extension

    if load_type == "dataframe":
        # Remove extension to get base path
        if base_path.suffix == ".csv":
            base_path = base_path.with_suffix("")

        csv_path = base_path.with_suffix(".csv")
        refs_path = base_path.with_suffix("").with_name(f"{base_path.name}_refs.pkl")
        meta_path = base_path.with_suffix("").with_name(f"{base_path.name}_meta.json")

        # Check if files exist
        if not csv_path.exists():
            raise FileNotFoundError(f"CSV file not found: {csv_path}")
        if not refs_path.exists():
            raise FileNotFoundError(f"References file not found: {refs_path}")

        # Load mutations DataFrame
        df = pd.read_csv(csv_path)

        # Load reference sequences
        # NOTE(review): pickle.load is only safe for trusted files
        with open(refs_path, "rb") as f:
            reference_sequences = pickle.load(f)

        # Load metadata if available
        dataset_name = None
        dataset_metadata = {}
        if meta_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            dataset_name = meta.get("name")
            dataset_metadata = meta.get("metadata", {})

        # Create dataset using `from_dataframe`
        dataset = cls.from_dataframe(df, reference_sequences, dataset_name)
        dataset.metadata = dataset_metadata

        tqdm.write("Dataset loaded from:")
        tqdm.write(f"  Mutations: {csv_path}")
        tqdm.write(f"  References: {refs_path}")
        if meta_path.exists():
            tqdm.write(f"  Metadata: {meta_path}")

        return dataset

    elif load_type == "pickle":
        pkl_path = base_path.with_suffix(".pkl")
        if not pkl_path.exists():
            raise FileNotFoundError(f"Pickle file not found: {pkl_path}")
        # NOTE(review): pickle.load is only safe for trusted files
        with open(pkl_path, "rb") as f:
            dataset = pickle.load(f)
        tqdm.write(f"Dataset loaded from: {pkl_path}")
        return dataset

    elif load_type == "tidymut":
        return cls.load_by_reference(base_path)

    else:
        # Fix: the previous message omitted the supported 'tidymut' format
        raise ValueError(
            f"Unsupported load_type: {load_type}. "
            f"Use 'tidymut', 'dataframe' or 'pickle'"
        )
@classmethod
def from_dataframe(
    cls,
    df: pd.DataFrame,
    reference_sequences: Dict[str, BaseSequence],
    name: Optional[str] = None,
    specific_mutation_type: Optional[Type[BaseMutation]] = None,
) -> "MutationDataset":
    """
    Create a MutationDataset from a DataFrame containing mutation data.

    This method reconstructs a MutationDataset from a flattened DataFrame
    representation, typically used for loading saved mutation datasets from
    files. The DataFrame should contain mutation information with each row
    representing a single mutation within mutation sets.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing mutation data with the following required columns:
        - 'mutation_set_id': Identifier for grouping mutations into sets
        - 'reference_id': Identifier for the reference sequence
        - 'mutation_string': String representation of the mutation
        - 'position': Position of the mutation in the sequence
        - 'mutation_type': Type of mutation ('amino_acid', 'codon_dna', 'codon_rna')

        Optional columns include:
        - 'mutation_set_name': Name of the mutation set
        - 'label': Label associated with the mutation set
        - 'wild_amino_acid': Wild-type amino acid (for amino acid mutations)
        - 'mutant_amino_acid': Mutant amino acid (for amino acid mutations)
        - 'wild_codon': Wild-type codon (for codon mutations)
        - 'mutant_codon': Mutant codon (for codon mutations)
        - 'set_*': Columns with 'set_' prefix for mutation set metadata
        - 'mutation_*': Columns with 'mutation_' prefix for individual mutation metadata
    reference_sequences : Dict[str, BaseSequence]
        Dictionary mapping reference sequence IDs to their corresponding
        BaseSequence objects. Must contain all reference sequences referenced
        in the DataFrame.
    name : Optional[str], default=None
        Optional name for the created MutationDataset.
    specific_mutation_type : Optional[Type[BaseMutation]], default=None
        The type of mutations to create. If None, will infer from the first
        mutation; must be provided when the mutation type is neither
        'amino_acid' nor any 'codon_*' type.

    Returns
    -------
    MutationDataset
        A new MutationDataset instance populated with the mutation sets and
        reference sequences from the DataFrame.

    Raises
    ------
    ValueError
        If the DataFrame is empty, missing required columns, or references
        sequences not provided in reference_sequences dict.

    Notes
    -----
    - Mutations are grouped by 'mutation_set_id' to reconstruct mutation sets
    - The method automatically determines the appropriate mutation set type
      (AminoAcidMutationSet, CodonMutationSet, or generic MutationSet) based
      on the mutation types within each set
    - Metadata is extracted from columns with 'set_' and 'mutation_' prefixes
    - Only reference sequences that are actually used in the DataFrame are
      added to the dataset
    - Missing (NaN) 'mutation_set_name' / 'label' values are normalized to
      None so a CSV round-trip does not store NaN floats as names/labels

    Examples
    --------
    >>> import pandas as pd
    >>> from sequences import ProteinSequence
    >>>
    >>> # Create sample DataFrame
    >>> df = pd.DataFrame({
    ...     'mutation_set_id': ['set1', 'set1', 'set2'],
    ...     'reference_id': ['prot1', 'prot1', 'prot2'],
    ...     'mutation_string': ['A1V', 'L2P', 'G5R'],
    ...     'position': [1, 2, 5],
    ...     'mutation_type': ['amino_acid', 'amino_acid', 'amino_acid'],
    ...     'wild_amino_acid': ['A', 'L', 'G'],
    ...     'mutant_amino_acid': ['V', 'P', 'R'],
    ...     'mutation_set_name': ['variant1', 'variant1', 'variant2'],
    ...     'label': ['pathogenic', 'pathogenic', 'benign']
    ... })
    >>>
    >>> # Define reference sequences
    >>> ref_seqs = {
    ...     'prot1': ProteinSequence('ALDEFG', name='protein1'),
    ...     'prot2': ProteinSequence('MKGLRK', name='protein2')
    ... }
    >>>
    >>> # Create MutationDataset
    >>> dataset = MutationDataset.from_dataframe(df, ref_seqs, name="my_dataset")
    >>> print(len(dataset.mutation_sets))
    2
    """
    if df.empty:
        raise ValueError("DataFrame cannot be empty")

    # Validate required columns
    required_cols = [
        "mutation_set_id",
        "reference_id",
        "mutation_string",
        "position",
        "mutation_type",
    ]
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"DataFrame missing required columns: {missing_cols}")

    # Validate that all referenced sequences are provided
    df_ref_ids = set(df["reference_id"].dropna().unique())
    provided_ref_ids = set(reference_sequences.keys())
    missing_refs = df_ref_ids - provided_ref_ids
    if missing_refs:
        raise ValueError(f"Missing reference sequences for IDs: {missing_refs}")

    # Create new dataset
    dataset = cls(name=name)

    # Add reference sequences (only those actually used in the DataFrame)
    for ref_id, sequence in tqdm(
        reference_sequences.items(), desc="Adding reference sequences"
    ):
        if ref_id in df_ref_ids:
            dataset.add_reference_sequence(ref_id, sequence)

    # Recognize metadata columns by their prefixes
    set_metadata_cols = [
        col for col in df.columns if col.startswith("set_") and col != "set_metadata"
    ]
    mutation_metadata_cols = [
        col
        for col in df.columns
        if col.startswith("mutation_")
        and col
        not in {
            "mutation_id",
            "mutation_type",
            "mutation_string",
            "mutation_category",
        }
    ]

    # Group by mutation set to rebuild mutation sets
    grouped = df.groupby("mutation_set_id", sort=False)
    for _, group in tqdm(grouped, desc="Reconstructing mutation sets"):
        # Get mutation set info from the first row of the group
        set_info = group.iloc[0]
        set_name = set_info.get("mutation_set_name")
        # Fix: after a CSV round-trip a missing name/label is NaN, not None;
        # normalize so NaN floats are never stored as names/labels
        if set_name is not None and pd.isna(set_name):
            set_name = None
        reference_id = set_info["reference_id"]
        label = set_info.get("label")
        if label is not None and pd.isna(label):
            label = None

        # Extract set metadata (strip the "set_" prefix, skip NaN values)
        set_metadata = {
            col[4:]: value
            for col in set_metadata_cols
            if pd.notna(value := set_info[col])
        }

        # Create mutations from group rows
        mutations = []
        columns = list(group.columns)
        for values in group.values:
            row_dict = dict(zip(columns, values))
            mutation = cls._create_mutation_from_dict(
                row_dict, mutation_metadata_cols, specific_mutation_type
            )
            mutations.append(mutation)

        # Create appropriate mutation set type based on the first mutation
        if mutations:
            mutation_type = type(mutations[0])
            if mutation_type == AminoAcidMutation:
                mutation_set = AminoAcidMutationSet(
                    mutations=mutations, name=set_name, metadata=set_metadata  # type: ignore
                )
            elif mutation_type == CodonMutation:
                mutation_set = CodonMutationSet(
                    mutations=mutations, name=set_name, metadata=set_metadata  # type: ignore
                )
            else:
                mutation_set = MutationSet(
                    mutations=mutations,
                    mutation_type=mutation_type,
                    name=set_name,
                    metadata=set_metadata,
                )

            # Add to dataset
            dataset.add_mutation_set(mutation_set, reference_id, label)

    return dataset
@staticmethod
def _create_mutation_from_dict(
    row_dict: Dict[str, Any],
    metadata_cols: List[str],
    specific_mutation_type: Optional[Type[BaseMutation]] = None,
) -> BaseMutation:
    """
    Create a mutation object from a DataFrame row dict

    Used by the from_dataframe() method

    Parameters
    ----------
    row_dict : Dict[str, Any]
        A pandas row dict containing mutation data with the following
        required fields:
        - 'mutation_type': Type of mutation ('amino_acid', 'codon_*', or other)
        - 'position': Position of the mutation (converted to int)

        For amino acid mutations, also requires:
        - 'wild_amino_acid': Wild-type amino acid
        - 'mutant_amino_acid': Mutant amino acid

        For codon mutations, also requires:
        - 'wild_codon': Wild-type codon
        - 'mutant_codon': Mutant codon

        For other mutations:
        - 'mutation_string': String representation of the mutation

        Optional fields:
        - Any column starting with 'mutation_' (excluding specific reserved
          names) will be added as metadata to the mutation object
    metadata_cols : List[str]
        Metadata columns to be extracted and added to the mutation object
    specific_mutation_type : Optional[Type[BaseMutation]], default=None
        Only used when the mutation type is neither 'amino_acid' nor any
        'codon_*' type.

    Returns
    -------
    BaseMutation
        An instance of the appropriate mutation subclass:
        - AminoAcidMutation for 'amino_acid' type
        - CodonMutation for types starting with 'codon_'
        - Inferred mutation type for other types (parsed from mutation_string)

    Raises
    ------
    ValueError
        If the mutation type is unsupported and no specific_mutation_type is
        given, or if the fallback parse of 'mutation_string' fails.
    """
    mutation_type = row_dict["mutation_type"]
    position = int(row_dict["position"])

    # Filter metadata: strip the "mutation_" prefix, drop NaN values
    mutation_metadata = {
        key[9:]: row_dict[key]  # Remove "mutation_" prefix
        for key in metadata_cols
        if not pd.isna(row_dict[key])
    }

    if mutation_type == "amino_acid":
        return AminoAcidMutation(
            wild_type=row_dict["wild_amino_acid"],
            position=position,
            mutant_type=row_dict["mutant_amino_acid"],
            metadata=mutation_metadata,
        )
    elif mutation_type.startswith("codon_"):
        return CodonMutation(
            wild_type=row_dict["wild_codon"],
            position=position,
            mutant_type=row_dict["mutant_codon"],
            metadata=mutation_metadata,
        )
    else:
        # FIXME: need to handle other mutation types
        # Try to parse from mutation string as fallback
        if specific_mutation_type is None:
            raise ValueError(
                f"Unsupported mutation type: {mutation_type}, "
                f"you must provide a specific mutation type"
            )
        mutation_string = row_dict["mutation_string"]
        try:
            return MutationSet._create_mutation(
                mutation_string,
                mutation_type=specific_mutation_type,
                is_zero_based=True,
            )
        except Exception as e:
            # Fix: chain the original exception so the root cause survives
            raise ValueError(f"Cannot create mutation from row: {e}") from e