Source code for indra.explanation.model_checker

from __future__ import print_function, unicode_literals, absolute_import
from builtins import dict, str
from future.utils import python_2_unicode_compatible
import logging
import numbers
import networkx as nx
import itertools
import numpy as np
import scipy.stats
from copy import deepcopy
from collections import deque, defaultdict, namedtuple
import kappy
from pysb import WILD, export, Observable, ComponentSet
from pysb.core import as_complex_pattern, ComponentDuplicateNameError
from indra.statements import *
from indra.assemblers import pysb_assembler as pa
from indra.tools.expand_families import _agent_from_uri
from indra.explanation import paths_graph as pg
from collections import Counter
from indra.util.kappa_util import im_json_to_graph

# Module-level logger. Note that it is registered under the fixed name
# 'model_checker' (not __name__), so logging configuration for this module
# must target that name.
logger = logging.getLogger('model_checker')


@python_2_unicode_compatible
class PathMetric(object):
    """Describes results of simple path search (path existence).

    Parameters
    ----------
    source_node : str
        The upstream node of the path (e.g., a rule name in the influence
        map).
    target_node : str
        The downstream node of the path (e.g., an observable name).
    polarity : int
        Overall polarity of the path.
    length : int
        Length of the path. It must be an integer because ``__str__``
        formats it with ``%d``.
    """
    # The python_2_unicode_compatible decorator belongs on the class (not on
    # __str__): under Python 2 it installs __unicode__/__str__ on the class,
    # under Python 3 it is a no-op.
    def __init__(self, source_node, target_node, polarity, length):
        self.source_node = source_node
        self.target_node = target_node
        self.polarity = polarity
        self.length = length

    def __repr__(self):
        return str(self)

    def __str__(self):
        return ('source_node: %s, target_node: %s, polarity: %s, length: %d' %
                (self.source_node, self.target_node, self.polarity,
                 self.length))
@python_2_unicode_compatible
class PathResult(object):
    """Describes results of running the ModelChecker on a single Statement.

    Parameters
    ----------
    path_found : bool
        Whether a path satisfying the Statement was found.
    result_code : string
        One of:

        - STATEMENT_TYPE_NOT_HANDLED
        - SUBJECT_MONOMERS_NOT_FOUND
        - OBSERVABLES_NOT_FOUND
        - NO_PATHS_FOUND
        - MAX_PATH_LENGTH_EXCEEDED
        - PATHS_FOUND
        - INPUT_RULES_NOT_FOUND
        - MAX_PATHS_ZERO
    max_paths : int
        The maximum number of specific paths that was requested.
    max_path_length : int
        The maximum path length that was requested.

    Attributes
    ----------
    path_found : boolean
    result_code : string
    path_metrics : list of PathMetric
    paths : list of paths
    max_paths : int
    max_path_length : int
    """
    # The python_2_unicode_compatible decorator must decorate the class (it is
    # a no-op on Python 3); applied to the __str__ method it had no effect.
    def __init__(self, path_found, result_code, max_paths, max_path_length):
        self.path_found = path_found
        self.result_code = result_code
        self.max_paths = max_paths
        self.max_path_length = max_path_length
        self.path_metrics = []
        self.paths = []

    def add_path(self, path):
        self.paths.append(path)

    def add_metric(self, path_metric):
        self.path_metrics.append(path_metric)

    def __str__(self):
        # textwrap is not imported at module level in this file, so import it
        # here to avoid a NameError when formatting.
        import textwrap
        summary = textwrap.dedent("""
            PathResult:
                path_found: {path_found}
                result_code: {result_code}
                path_metrics: {path_metrics}
                paths: {paths}
                max_paths: {max_paths}
                max_path_length: {max_path_length}""")
        ws = '\n        '
        # String representation of path metrics (path_metrics may also be
        # None, which is rendered via str())
        if not self.path_metrics:
            pm_str = str(self.path_metrics)
        else:
            pm_str = ws + ws.join(['%d: %s' % (pm_ix, pm) for pm_ix, pm in
                                   enumerate(self.path_metrics)])

        def format_path(path, num_spaces=11):
            path_ws = '\n' + (' ' * num_spaces)
            return path_ws.join([str(p) for p in path])

        # String representation of paths
        if not self.paths:
            path_str = str(self.paths)
        else:
            path_str = ws + ws.join(['%d: %s' % (p_ix, format_path(p))
                                     for p_ix, p in enumerate(self.paths)])
        return summary.format(path_found=self.path_found,
                              result_code=self.result_code,
                              max_paths=self.max_paths,
                              max_path_length=self.max_path_length,
                              path_metrics=pm_str, paths=path_str)

    def __repr__(self):
        return str(self)
class ModelChecker(object):
    """Check a PySB model against a set of INDRA statements.

    Parameters
    ----------
    model : pysb.Model
        A PySB model to check.
    statements : Optional[list[indra.statements.Statement]]
        A list of INDRA Statements to check the model against. The list is
        copied, so later calls to add_statements do not modify it.
    agent_obs: Optional[list[indra.statements.Agent]]
        A list of INDRA Agents in a given state to be observed. The list is
        copied.
    do_sampling : bool
        Whether to use breadth-first search or weighted sampling to
        generate paths. Default is False (breadth-first search).
    seed : int
        Random seed for sampling (optional, default is None). This seeds
        NumPy's global random number generator.
    """
    def __init__(self, model, statements=None, agent_obs=None,
                 do_sampling=False, seed=None):
        self.model = model
        # Copy the inputs: add_statements extends self.statements in place,
        # which would otherwise silently mutate the caller's list.
        self.statements = list(statements) if statements else []
        self.agent_obs = list(agent_obs) if agent_obs else []
        if seed is not None:
            np.random.seed(seed)
        # Whether to do sampling
        self.do_sampling = do_sampling
        # Influence map
        self._im = None
        # Map from statements to associated observables
        self.stmt_to_obs = {}
        # Map from agents to associated observables
        self.agent_to_obs = {}
        # Map between rules and downstream observables
        self.rule_obs_dict = {}

    def add_statements(self, stmts):
        """Add to the list of statements to check against the model.

        Parameters
        ----------
        stmts : list[indra.statements.Statement]
            The list of Statements to be added for checking.
        """
        self.statements += stmts

    def generate_im(self, model):
        """Return a graph representing the influence map generated by Kappa

        Parameters
        ----------
        model : pysb.Model
            The PySB model whose influence map is to be generated

        Returns
        -------
        graph : networkx.MultiDiGraph
            A MultiDiGraph representing the influence map
        """
        kappa = kappy.KappaStd()
        model_str = export.export(model, 'kappa')
        kappa.add_model_string(model_str)
        kappa.project_parse()
        imap = kappa.analyses_influence_map()
        graph = im_json_to_graph(imap)
        return graph

    def get_im(self, force_update=False):
        """Get the influence map for the model, generating it if necessary.

        Generating the map also (re)creates an Observable for each checked
        Statement and each observed Agent. It fills ``stmt_to_obs``,
        ``agent_to_obs`` and ``rule_obs_dict``.

        Parameters
        ----------
        force_update : bool
            Whether to generate the influence map when the function is called.
            If False, returns the previously generated influence map if
            available. Defaults to False.

        Returns
        -------
        networkx MultiDiGraph object containing the influence map.
            The influence map can be rendered as a pdf using the dot layout
            program as follows::

                im_agraph = nx.nx_agraph.to_agraph(influence_map)
                im_agraph.draw('influence_map.pdf', prog='dot')

        Raises
        ------
        Exception
            If the influence map has to be generated but there is no model.
        """
        # Compare against None explicitly: a graph's truthiness is its node
        # count, so an already-generated but empty map would otherwise be
        # regenerated on every call.
        if self._im is not None and not force_update:
            return self._im
        if not self.model:
            raise Exception("Cannot get influence map if there is no model.")

        def add_obs_for_agent(agent):
            """Add observables for agent; return their names or None."""
            obj_mps = list(pa.grounded_monomer_patterns(self.model, agent))
            if not obj_mps:
                logger.debug('No monomer patterns found in model for agent %s, '
                             'skipping' % agent)
                return None
            obs_list = []
            for obj_mp in obj_mps:
                obs_name = _monomer_pattern_label(obj_mp) + '_obs'
                # Add the observable
                obj_obs = Observable(obs_name, obj_mp, _export=False)
                obs_list.append(obs_name)
                try:
                    self.model.add_component(obj_obs)
                except ComponentDuplicateNameError:
                    # An observable with this name was already added (e.g.,
                    # for another statement); the name is still returned.
                    pass
            return obs_list

        # Create observables for all statements to check, and add to model
        # Remove any existing observables in the model
        self.model.observables = ComponentSet([])
        for stmt in self.statements:
            # Generate observables for Modification statements
            if isinstance(stmt, Modification):
                mod_condition_name = modclass_to_modtype[stmt.__class__]
                if isinstance(stmt, RemoveModification):
                    mod_condition_name = modtype_to_inverse[mod_condition_name]
                # Add modification to substrate agent
                modified_sub = _add_modification_to_agent(stmt.sub,
                                                          mod_condition_name,
                                                          stmt.residue,
                                                          stmt.position)
                obs_list = add_obs_for_agent(modified_sub)
                # Associate this statement with this observable
                self.stmt_to_obs[stmt] = obs_list
            # Generate observables for Activation/Inhibition statements
            elif isinstance(stmt, RegulateActivity):
                regulated_obj, _ = \
                    _add_activity_to_agent(stmt.obj, stmt.obj_activity,
                                           stmt.is_activation)
                obs_list = add_obs_for_agent(regulated_obj)
                # Associate this statement with this observable
                self.stmt_to_obs[stmt] = obs_list
            elif isinstance(stmt, RegulateAmount):
                obs_list = add_obs_for_agent(stmt.obj)
                self.stmt_to_obs[stmt] = obs_list
        # Add observables for each agent
        for ag in self.agent_obs:
            obs_list = add_obs_for_agent(ag)
            self.agent_to_obs[ag] = obs_list

        logger.info("Generating influence map")
        self._im = self.generate_im(self.model)
        # Now, for every rule in the model, check if there are any observables
        # downstream; alternatively, for every observable in the model, get a
        # list of rules.
        # We'll need the dictionary to check if nodes are observables
        node_attributes = nx.get_node_attributes(self._im, 'node_type')
        for rule in self.model.rules:
            obs_list = []
            # Get successors of the rule node
            for neighb in self._im.neighbors(rule.name):
                # Check if the node is an observable
                if node_attributes[neighb] != 'variable':
                    continue
                # Get the edge and check the polarity
                edge_sign = _get_edge_sign(self._im, (rule.name, neighb))
                obs_list.append((neighb, edge_sign))
            self.rule_obs_dict[rule.name] = obs_list
        return self._im

    def check_model(self, max_paths=1, max_path_length=5):
        """Check all the statements added to the ModelChecker.

        Parameters
        ----------
        max_paths : Optional[int]
            The maximum number of specific paths to return for each Statement
            to be explained. Default: 1
        max_path_length : Optional[int]
            The maximum length of specific paths to return. Default: 5

        Returns
        -------
        list of (Statement, PathResult)
            Each tuple contains the Statement checked against the model and
            a PathResult object describing the results of model checking.
        """
        return [(stmt, self.check_statement(stmt, max_paths, max_path_length))
                for stmt in self.statements]

    def check_statement(self, stmt, max_paths=1, max_path_length=5):
        """Check a single Statement against the model.

        Parameters
        ----------
        stmt : indra.statements.Statement
            The Statement to check.
        max_paths : Optional[int]
            The maximum number of specific paths to return for each Statement
            to be explained. Default: 1
        max_path_length : Optional[int]
            The maximum length of specific paths to return. Default: 5

        Returns
        -------
        PathResult
            Its ``path_found`` attribute is True if the model satisfies the
            Statement; ``result_code`` and ``paths`` give details.
        """
        # Make sure the influence map is initialized
        self.get_im()
        if isinstance(stmt, Modification):
            return self._check_modification(stmt, max_paths, max_path_length)
        elif isinstance(stmt, RegulateActivity):
            return self._check_regulate_activity(stmt, max_paths,
                                                 max_path_length)
        elif isinstance(stmt, RegulateAmount):
            return self._check_regulate_amount(stmt, max_paths,
                                               max_path_length)
        else:
            return PathResult(False, 'STATEMENT_TYPE_NOT_HANDLED',
                              max_paths, max_path_length)

    def _check_regulate_activity(self, stmt, max_paths, max_path_length):
        """Check a RegulateActivity statement."""
        logger.info('Checking stmt: %s' % stmt)
        # FIXME Currently this will match rules with the corresponding monomer
        # pattern from the Activation/Inhibition statement, which will nearly
        # always have no state conditions on it. In future, this statement foo
        # should also match rules in which 1) the agent is in its active form,
        # or 2) the agent is tagged as the enzyme in a rule of the appropriate
        # activity (e.g., a phosphorylation rule) FIXME
        subj_mp = pa.get_monomer_pattern(self.model, stmt.subj)
        target_polarity = 1 if stmt.is_activation else -1
        # This may fail, since there may be no rule in the model activating the
        # object, and the object may not have an "active" site of the
        # appropriate type
        obs_names = self.stmt_to_obs[stmt]
        return self._find_paths_to_any_obs(subj_mp, obs_names,
                                           target_polarity, max_paths,
                                           max_path_length)

    def _check_regulate_amount(self, stmt, max_paths, max_path_length):
        """Check a RegulateAmount statement."""
        logger.info('Checking stmt: %s' % stmt)
        subj_mp = pa.get_monomer_pattern(self.model, stmt.subj)
        target_polarity = 1 if isinstance(stmt, IncreaseAmount) else -1
        obs_names = self.stmt_to_obs[stmt]
        return self._find_paths_to_any_obs(subj_mp, obs_names,
                                           target_polarity, max_paths,
                                           max_path_length)

    def _find_paths_to_any_obs(self, subj_mp, obs_names, target_polarity,
                               max_paths, max_path_length):
        """Search for paths from subj_mp to any of the given observables.

        Returns an OBSERVABLES_NOT_FOUND result if obs_names is empty or
        None. Otherwise returns the first result that found a path. If no
        observable has a path, returns the result for the first observable.
        """
        if not obs_names:
            logger.debug("No observables found, returning False")
            return PathResult(False, 'OBSERVABLES_NOT_FOUND', max_paths,
                              max_path_length)
        first_result = None
        for obs_name in obs_names:
            result = self._find_im_paths(subj_mp, obs_name, target_polarity,
                                         max_paths, max_path_length)
            if result.path_found:
                return result
            if first_result is None:
                first_result = result
        return first_result

    def _check_modification(self, stmt, max_paths, max_path_length):
        """Check a Modification statement."""
        # Identify the observable we're looking for in the model, which
        # may not exist!
        # The observable is the modified form of the substrate
        logger.info('Checking stmt: %s' % stmt)
        # Look for an agent with the appropriate grounding in the model
        if stmt.enz is not None:
            enz_mps = list(pa.grounded_monomer_patterns(self.model, stmt.enz))
            if not enz_mps:
                logger.debug('No monomers found corresponding to agent %s' %
                             stmt.enz)
                return PathResult(False, 'SUBJECT_MONOMERS_NOT_FOUND',
                                  max_paths, max_path_length)
        else:
            enz_mps = [None]
        # Get target polarity
        target_polarity = -1 if isinstance(stmt, RemoveModification) else 1
        obs_names = self.stmt_to_obs[stmt]
        if not obs_names:
            logger.debug("No observables for stmt %s, returning False" % stmt)
            return PathResult(False, 'OBSERVABLES_NOT_FOUND', max_paths,
                              max_path_length)
        for enz_mp, obs_name in itertools.product(enz_mps, obs_names):
            # FIXME Returns on the path found for the first enz_mp/obs combo
            result = self._find_im_paths(enz_mp, obs_name, target_polarity,
                                         max_paths, max_path_length)
            # If result for this observable is not False, then we return it;
            # otherwise, that means there was no path for this observable, so
            # we have to try the next one
            if result.path_found:
                return result
        # If we got here, then there was no path for any observable
        return PathResult(False, 'NO_PATHS_FOUND', max_paths,
                          max_path_length)

    def _get_input_rules(self, subj_mp):
        """Return names of rules with subj_mp on the LHS and as subject."""
        if subj_mp is None:
            raise ValueError("Cannot take None as an argument for subj_mp.")
        input_rules = _match_lhs(subj_mp, self.model.rules)
        logger.debug('Found %s input rules matching %s' %
                     (len(input_rules), str(subj_mp)))
        # Filter to include only rules where the subj_mp is actually the
        # subject (i.e., don't pick up upstream rules where the subject
        # is itself a substrate/object)
        # FIXME: Note that this will eliminate rules where the subject
        # being checked is included on the left hand side as
        # a bound condition rather than as an enzyme.
        subj_rules = pa.rules_with_annotation(self.model,
                                              subj_mp.monomer.name,
                                              'rule_has_subject')
        logger.debug('%d rules with %s as subject' %
                     (len(subj_rules), subj_mp.monomer.name))
        input_rule_set = {r.name for r in input_rules}.intersection(
            {r.name for r in subj_rules})
        logger.debug('Final input rule set contains %d rules' %
                     len(input_rule_set))
        return input_rule_set

    def _sample_paths(self, input_rule_set, obs_name, target_polarity,
                      max_paths=1, max_path_length=5):
        """Sample paths from input rules to an observable.

        Steps:

        1. Build a signed graph from the influence map.
        2. Connect a dummy SOURCE_NODE to every input rule.
        3. Build cycle-free paths graphs for every length up to
           max_path_length.
        4. Weight edges by the initial amounts of the rules' objects.
        5. Sample max_paths paths.
        """
        if max_paths == 0:
            raise ValueError("max_paths cannot be 0 for path sampling.")

        # Convert path polarity representation from 0/1 (as used for the
        # signed graph here: 0 = positive, 1 = negative) to 1/-1, matching
        # the polarities of breadth-first paths.
        def convert_polarities(path_list):
            return [tuple((n[0], 1 if n[1] == 0 else -1) for n in path)
                    for path in path_list]

        pg_polarity = 0 if target_polarity > 0 else 1
        nx_graph = _im_to_signed_digraph(self.get_im())
        # Add edges from dummy node to input rules
        source_node = 'SOURCE_NODE'
        for rule in input_rule_set:
            nx_graph.add_edge(source_node, rule, attr_dict={'sign': 0})
        # -------------------------------------------------
        # Create combined paths_graph
        f_level, b_level = pg.get_reachable_sets(nx_graph, source_node,
                                                 obs_name, max_path_length,
                                                 signed=True)
        pg_list = []
        for path_length in range(1, max_path_length + 1):
            cfpg = pg.CFPG.from_graph(
                nx_graph, source_node, obs_name, path_length, f_level,
                b_level, signed=True, target_polarity=pg_polarity)
            pg_list.append(cfpg)
        combined_pg = pg.CombinedCFPG(pg_list)
        # Make sure the combined paths graph is not empty
        if not combined_pg.graph:
            pr = PathResult(False, 'NO_PATHS_FOUND', max_paths,
                            max_path_length)
            pr.path_metrics = None
            pr.paths = []
            return pr

        # Get a dict of rule objects
        rule_obj_dict = {}
        for ann in self.model.annotations:
            if ann.predicate == 'rule_has_object':
                rule_obj_dict[ann.subject] = ann.object

        # Get monomer initial conditions
        ic_dict = {}
        for mon in self.model.monomers:
            # FIXME: A hack that depends on the _0 convention
            ic_name = '%s_0' % mon.name
            # TODO: Wrap this in try/except?
            ic_param = self.model.parameters[ic_name]
            ic_dict[mon.name] = ic_param.value

        # Set weights in PG based on model initial conditions
        for cur_node in combined_pg.graph.nodes():
            edge_weights = {}
            rule_obj_list = []
            edge_weights_by_gene = {}
            for u, v in combined_pg.graph.out_edges(cur_node):
                v_rule = v[1][0]
                # Get the object of the rule (a monomer name)
                rule_obj = rule_obj_dict.get(v_rule)
                if rule_obj:
                    # Add to list so we can count instances by gene
                    rule_obj_list.append(rule_obj)
                    # Get the abundance of rule object from the initial
                    # conditions
                    # TODO: Wrap in try/except?
                    ic_value = ic_dict[rule_obj]
                else:
                    ic_value = 1.0
                edge_weights[(u, v)] = ic_value
                edge_weights_by_gene[rule_obj] = ic_value
            # Get frequency of different rule objects
            rule_obj_ctr = Counter(rule_obj_list)
            # Normalize results by weight sum and gene frequency at this level
            edge_weight_sum = sum(edge_weights_by_gene.values())
            edge_weights_norm = {}
            for e, weight in edge_weights.items():
                v_rule = e[1][1][0]
                rule_obj = rule_obj_dict.get(v_rule)
                if rule_obj:
                    rule_obj_count = rule_obj_ctr[rule_obj]
                else:
                    rule_obj_count = 1
                edge_weights_norm[e] = ((weight / float(edge_weight_sum)) /
                                        float(rule_obj_count))
            # Add edge weights to paths graph
            nx.set_edge_attributes(combined_pg.graph, 'weight',
                                   edge_weights_norm)

        # Sample from the combined CFPG
        paths = combined_pg.sample_paths(max_paths)
        # -------------------------------------------------
        if paths:
            pr = PathResult(True, 'PATHS_FOUND', max_paths, max_path_length)
            pr.path_metrics = None
            # Convert path polarity representation from 0/1 to 1/-1
            pr.paths = convert_polarities(paths)
            # Strip off the SOURCE_NODE prefix
            pr.paths = [p[1:] for p in pr.paths]
        else:
            # Sampling a non-empty paths graph is expected to yield paths;
            # report "no paths" instead of failing an assertion (which would
            # also be silently skipped under python -O).
            logger.warning('No paths sampled from non-empty paths graph '
                           'for %s' % obs_name)
            pr = PathResult(False, 'NO_PATHS_FOUND', max_paths,
                            max_path_length)
            pr.path_metrics = None
            pr.paths = []
        return pr

    def _find_im_paths(self, subj_mp, obs_name, target_polarity,
                       max_paths=1, max_path_length=5):
        """Check for a source/target path in the influence map.

        Parameters
        ----------
        subj_mp : pysb.MonomerPattern
            MonomerPattern corresponding to the subject of the Statement
            being checked. If None, any upstream node is accepted as a
            source.
        obs_name : string
            Name of the PySB model Observable corresponding to the
            object/target of the Statement being checked.
        target_polarity : 1 or -1
            Whether the influence in the Statement is positive (1) or negative
            (-1).
        max_paths : int
            Maximum number of specific paths to collect.
        max_path_length : int
            Maximum length (number of edges) of the paths collected.

        Returns
        -------
        PathResult
            Describes whether there is a path from a rule matching the subject
            MonomerPattern to the object Observable with the appropriate
            polarity, along with path metrics and specific paths.
        """
        logger.info(('Running path finding with max_paths=%d,'
                     ' max_path_length=%d') % (max_paths, max_path_length))
        # Find rules in the model corresponding to the input
        if subj_mp is None:
            input_rule_set = None
        else:
            input_rule_set = self._get_input_rules(subj_mp)
            if not input_rule_set:
                return PathResult(False, 'INPUT_RULES_NOT_FOUND',
                                  max_paths, max_path_length)

        logger.info('Finding paths between %s and %s with polarity %s' %
                    (subj_mp, obs_name, target_polarity))

        # -- Route to the path sampling function --
        if self.do_sampling:
            return self._sample_paths(input_rule_set, obs_name,
                                      target_polarity, max_paths,
                                      max_path_length)

        # -- Do Breadth-First Enumeration --
        # Generate the predecessors to our observable and count the paths
        path_lengths = []
        path_metrics = []
        for source, polarity, path_length in \
                _find_sources(self.get_im(), obs_name, input_rule_set,
                              target_polarity):
            pm = PathMetric(source, obs_name, polarity, path_length)
            path_metrics.append(pm)
            path_lengths.append(path_length)

        # Now, look for paths
        if path_metrics and max_paths == 0:
            pr = PathResult(True, 'MAX_PATHS_ZERO', max_paths,
                            max_path_length)
            pr.path_metrics = path_metrics
            return pr
        elif path_metrics:
            if min(path_lengths) <= max_path_length:
                pr = PathResult(True, 'PATHS_FOUND', max_paths,
                                max_path_length)
                pr.path_metrics = path_metrics
                # Paths are generated breadth-first, i.e., in order of
                # non-decreasing length, so once one exceeds max_path_length
                # all subsequent ones do too.
                for path in _find_sources_with_paths(self.get_im(), obs_name,
                                                     input_rule_set,
                                                     target_polarity):
                    # A path lists nodes starting from the target observable,
                    # so its length in edges is one less than its node count.
                    if len(path) - 1 > max_path_length:
                        break
                    pr.add_path(_flip(self.get_im(), path))
                    if len(pr.paths) >= max_paths:
                        break
                return pr
            # There are no paths shorter than the max path length, so we
            # don't bother trying to get them
            else:
                pr = PathResult(True, 'MAX_PATH_LENGTH_EXCEEDED', max_paths,
                                max_path_length)
                pr.path_metrics = path_metrics
                return pr
        else:
            return PathResult(False, 'NO_PATHS_FOUND', max_paths,
                              max_path_length)

    def score_paths(self, paths, agents_values, loss_of_function=False,
                    sigma=0.15, include_final_node=False):
        """Return scores associated with a given set of paths.

        Parameters
        ----------
        paths : list[list[tuple[str, int]]]
            A list of paths obtained from path finding. Each path is a list
            of tuples (which are edges in the path), with the first element
            of the tuple the name of a rule, and the second element its
            polarity in the path.
        agents_values : dict[indra.statements.Agent, float]
            A dictionary of INDRA Agents and their corresponding measured
            value in a given experimental condition. Agents without
            registered observables are ignored.
        loss_of_function : Optional[boolean]
            If True, flip the polarity of the path. For instance, if the effect
            of an inhibitory drug is explained, set this to True.
            Default: False
        sigma : Optional[float]
            The estimated standard deviation for the normally distributed
            measurement error in the observation model used to score paths
            with respect to data. Default: 0.15
        include_final_node : Optional[boolean]
            Determines whether the final node of the path is included in the
            score. Default: False

        Returns
        -------
        list of (path, float)
            Each path paired with its log-probability score, sorted by
            decreasing score; ties keep increasing path length order.
        """
        def obs_model(x):
            return scipy.stats.norm(x, sigma)

        # Build up dict mapping observables to values
        obs_dict = {}
        for ag, val in agents_values.items():
            # Agents with no (or no registered) observables are skipped
            obs_list = self.agent_to_obs.get(ag)
            if obs_list is not None:
                for obs in obs_list:
                    obs_dict[obs] = val

        flip_polarity = -1 if loss_of_function else 1
        last_path_node_index = -1 if include_final_node else -2
        # For every path...
        path_scores = []
        for path in paths:
            logger.info('------')
            logger.info("Scoring path:")
            logger.info(path)
            # Look at every node in the path, excluding the final
            # observable...
            path_score = 0
            for node, sign in path[:last_path_node_index]:
                # ...and for each node check the sign to see if it matches the
                # data. So the first thing is to look at what's downstream
                # of the rule: rule_obs_dict maps a rule to a list of
                # (observable name, edge sign) pairs.
                for affected_obs, rule_obs_sign in self.rule_obs_dict[node]:
                    pred_sign = sign * rule_obs_sign * flip_polarity
                    # Check to see if this observable is in the data
                    logger.info('%s %s: effect %s %s' %
                                (node, sign, affected_obs, pred_sign))
                    measured_val = obs_dict.get(affected_obs)
                    # Compare against None so that a measured value of 0 is
                    # scored rather than treated as missing.
                    if measured_val is not None:
                        # For negative predictions use CDF (prob that given
                        # measured value, true value lies below 0)
                        if pred_sign <= 0:
                            prob_correct = obs_model(measured_val).logcdf(0)
                        # For positive predictions, use log survival function
                        # (SF = 1 - CDF, i.e., prob that true value is
                        # above 0)
                        else:
                            prob_correct = obs_model(measured_val).logsf(0)
                        logger.info('Actual: %s, Log Probability: %s' %
                                    (measured_val, prob_correct))
                        path_score += prob_correct
                if not self.rule_obs_dict[node]:
                    logger.info('%s %s' % (node, sign))
                    prob_correct = obs_model(0).logcdf(0)
                    logger.info('Unmeasured node, Log Probability: %s' %
                                (prob_correct))
                    path_score += prob_correct
            logger.info("Path score: %s" % path_score)
            path_scores.append(path_score)
        path_tuples = list(zip(paths, path_scores))
        # Sort first by path length
        sorted_by_length = sorted(path_tuples, key=lambda x: len(x[0]))
        # Sort by probability; sort in reverse order so large values
        # (higher probabilities) are ranked higher (the sort is stable, so
        # shorter paths stay first among equal scores)
        scored_paths = sorted(sorted_by_length, key=lambda x: x[1],
                              reverse=True)
        return scored_paths

    def prune_influence_map(self):
        """Remove edges between rules causing problematic non-transitivity.

        First, all self-loops are removed. After this initial step, edges are
        removed between rules when they share *all* child nodes except for
        each other; that is, they have a mutual relationship with each other
        and share all of the same children.

        Note that edges must be removed in batch at the end to prevent edge
        removal from affecting the lists of rule children during the
        comparison process.
        """
        logger.info('Removing self loops')
        im = self.get_im()
        # First, remove all self-loops. Iterate over a snapshot of the edges
        # because the graph is modified inside the loop.
        for e in list(im.edges()):
            if e[0] == e[1]:
                logger.info('Removing self loop: %s', e)
                im.remove_edge(e[0], e[1])

        # Now compare nodes pairwise and look for overlap between child nodes
        edges_to_remove = set()
        remove_im_params(self.model, im)
        logger.info('Get successorts of each node')
        succ_dict = {node: set(im.successors(node)) for node in im.nodes()}
        logger.info('Compare combinations of successors')
        for p1, p2 in itertools.combinations(im.nodes(), 2):
            # Children are identical except for mutual relationship
            if succ_dict[p1].difference(succ_dict[p2]) == {p2} and \
                    succ_dict[p2].difference(succ_dict[p1]) == {p1}:
                for u, v in ((p1, p2), (p2, p1)):
                    edge = (u, v)
                    edges_to_remove.add(edge)
                    edge_sign = _get_edge_sign(im, edge)
                    logger.debug('Will remove edge (%s, %s) with polarity %s',
                                 u, v, edge_sign)
        # Remove in batch. Iterating over all (snapshotted) edges rather than
        # over edges_to_remove also removes every parallel edge between the
        # same pair of nodes in a multigraph.
        for edge in list(im.edges()):
            if edge in edges_to_remove:
                im.remove_edge(edge[0], edge[1])


def _find_sources_sample(im, target, sources, polarity, rule_obs_dict,
                         agent_to_obs, agents_values):
    # NOTE(review): this helper looks unfinished. It samples predecessors of
    # the target but discards them and returns None. The sources and polarity
    # arguments are unused.
    # Build up dict mapping observables to values
    obs_dict = {}
    for ag, val in agents_values.items():
        obs_list = agent_to_obs[ag]
        for obs in obs_list:
            obs_dict[obs] = val

    sigma = 0.2

    def obs_model(x):
        return scipy.stats.norm(x, sigma)

    def _sample_pred(im, target, rule_obs_dict, obs_model):
        preds = list(_get_signed_predecessors(im, target, 1))
        if not preds:
            return None
        pred_scores = []
        for pred, sign in preds:
            pred_score = 0
            for affected_obs, rule_obs_sign in rule_obs_dict[pred]:
                pred_sign = sign * rule_obs_sign
                # Check to see if this observable is in the data
                logger.info('%s %s: effect %s %s' %
                            (pred, sign, affected_obs, pred_sign))
                measured_val = obs_dict.get(affected_obs)
                if measured_val:
                    logger.info('Actual: %s' % measured_val)
                    # NOTE(review): cdf(1) is the probability that the true
                    # value lies *below* 1; confirm the intended direction.
                    tail_prob = obs_model(measured_val).cdf(1)
                    pred_score += (tail_prob if pred_sign == 1
                                   else 1 - tail_prob)
            pred_scores.append(pred_score)
        # Normalize scores
        pred_scores = np.array(pred_scores) / np.sum(pred_scores)
        pred_idx = np.random.choice(range(len(preds)), p=pred_scores)
        pred = preds[pred_idx]
        return pred

    preds = []
    for _ in range(100):
        pred = _sample_pred(im, target, rule_obs_dict, obs_model)
        # No predecessors: every further sample would also be None
        if pred is None:
            break
        preds.append(pred[0])


def _find_sources_with_paths(im, target, sources, polarity):
    """Get the subset of source nodes with paths to the target.

    Given a target, a list of sources, and a path polarity, perform a
    breadth-first search upstream from the target to find paths to any of the
    upstream sources.

    Parameters
    ----------
    im : networkx.MultiDiGraph
        Graph containing the influence map.
    target : string
        The node (rule name) in the influence map to start looking upstream
        for matching sources.
    sources : collection of strings
        The nodes (rules) corresponding to the subject or upstream influence
        being checked. If None, any upstream node is accepted.
    polarity : int
        Required polarity of the path between source and target.

    Returns
    -------
    generator of path
        Yields paths as tuples of (node, sign) pairs. Each path starts at the
        target and runs upstream. Paths come in order of non-decreasing
        length. If there are no paths to any of the given source nodes, the
        generator is empty.
    """
    # Adapted from
    # http://stackoverflow.com/questions/8922060/
    #                       how-to-trace-the-path-in-a-breadth-first-search
    # FIXME: the sign information for the target should be associated with
    # the observable itself
    queue = deque([[(target, 1)]])
    while queue:
        # Get the first path in the queue
        path = queue.popleft()
        node, node_sign = path[-1]
        # The sign stored with each node is already the overall polarity of
        # the path from that node to the target.
        # Don't allow trivial paths consisting only of the target observable
        if (sources is None or node in sources) and node_sign == polarity \
                and len(path) > 1:
            logger.debug('Found path: %s' % str(_flip(im, path)))
            yield tuple(path)
        for predecessor, sign in _get_signed_predecessors(im, node, node_sign):
            # Only add predecessors to the path if it's not already in the
            # path--prevents loops
            if (predecessor, sign) in path:
                continue
            # Otherwise, the new path is a copy of the old one plus the new
            # predecessor
            queue.append(path + [(predecessor, sign)])
[docs]def remove_im_params(model, im): """Remove parameter nodes from the influence map. Parameters ---------- model : pysb.core.Model PySB model. im : networkx.MultiDiGraph Influence map. Returns ------- networkx.MultiDiGraph Influence map with the parameter nodes removed. """ for param in model.parameters: # If the node doesn't exist e.g., it may have already been removed), # skip over the parameter without error try: im.remove_node(param.name) except:
pass def _find_sources(im, target, sources, polarity): """Get the subset of source nodes with paths to the target. Given a target, a list of sources, and a path polarity, perform a breadth-first search upstream from the target to determine whether any of the queried sources have paths to the target with the appropriate polarity. For efficiency, does not return the full path, but identifies the upstream sources and the length of the path. Parameters ---------- im : networkx.MultiDiGraph Graph containing the influence map. target : string The node (rule name) in the influence map to start looking upstream for marching sources. sources : list of strings The nodes (rules) corresponding to the subject or upstream influence being checked. polarity : int Required polarity of the path between source and target. Returns ------- generator of (source, polarity, path_length) Yields tuples of source node (string), polarity (int) and path length (int). If there are no paths to any of the given source nodes, the generator isignempty. """ # First, create a list of visited nodes # Adapted from # networkx.algorithms.traversal.breadth_first_search.bfs_edges visited = set([(target, 1)]) # Generate list of predecessor nodes with a sign updated according to the # sign of the target node target_tuple = (target, 1) # The queue holds tuples of "parents" (in this case downstream nodes) and # their "children" (in this case their upstream influencers) queue = deque([(target_tuple, _get_signed_predecessors(im, target, 1), 0)]) while queue: parent, children, path_length = queue[0] try: # Get the next child in the list (child, sign) = next(children) # Is this child one of the source nodes we're looking for? If so, # yield it along with path length. 
if (sources is None or child in sources) and sign == polarity: logger.debug("Found path to %s from %s with desired sign %s " "with length %d" % (target, child, polarity, path_length+1)) yield (child, sign, path_length+1) # Check this child against the visited list. If we haven't visited # it already (accounting for the path to the node), then add it # to the queue. if (child, sign) not in visited: visited.add((child, sign)) queue.append(((child, sign), _get_signed_predecessors(im, child, sign), path_length + 1)) # Once we've finished iterating over the children of the current node, # pop the node off and go to the next one in the queue except StopIteration: queue.popleft() # There was no path; this will produce an empty generator return def _get_signed_predecessors(im, node, polarity): """Get upstream nodes in the influence map. Return the upstream nodes along with the overall polarity of the path to that node by account for the polarity of the path to the given node and the polarity of the edge between the given node and its immediate predecessors. Parameters ---------- im : networkx.MultiDiGraph Graph containing the influence map. node : string The node (rule name) in the influence map to get predecessors (upstream nodes) for. polarity : int Polarity of the overall path to the given node. Returns ------- generator of tuples, (node, polarity) Each tuple returned contains two elements, a node (string) and the polarity of the overall path (int) to that node. 
""" signed_pred_list = [] predecessors = im.predecessors_iter for pred in predecessors(node): pred_edge = (pred, node) yield (pred, _get_edge_sign(im, pred_edge) * polarity) def _get_edge_sign(im, edge): """Get the polarity of the influence by examining the edge sign.""" edge_data = im[edge[0]][edge[1]] # Handle possible multiple edges between nodes signs = list(set([v['sign'] for v in edge_data.values() if v.get('sign')])) if len(signs) > 1: logger.warning("Edge %s has conflicting polarities; choosing " "positive polarity by default" % str(edge)) sign = 1 else: sign = signs[0] if sign is None: raise Exception('No sign attribute for edge.') elif abs(sign) == 1: return sign else: raise Exception('Unexpected edge sign: %s' % edge.attr['sign']) def _add_modification_to_agent(agent, mod_type, residue, position): """Add a modification condition to an Agent.""" new_mod = ModCondition(mod_type, residue, position) # Check if this modification already exists for old_mod in agent.mods: if old_mod.equals(new_mod): return agent new_agent = deepcopy(agent) new_agent.mods.append(new_mod) return new_agent def _add_activity_to_agent(agent, act_type, is_active): # Default to active, and return polarity if it's an inhibition new_act = ActivityCondition(act_type, True) # Check if this state already exists if agent.activity is not None and agent.activity.equals(new_act): return agent new_agent = deepcopy(agent) new_agent.activity = new_act polarity = 1 if is_active else -1 return (new_agent, polarity) def _match_lhs(cp, rules): """Get rules with a left-hand side matching the given ComplexPattern.""" rule_matches = [] for rule in rules: reactant_pattern = rule.rule_expression.reactant_pattern for rule_cp in reactant_pattern.complex_patterns: if _cp_embeds_into(rule_cp, cp): rule_matches.append(rule) break return rule_matches def _cp_embeds_into(cp1, cp2): """Check that any state in ComplexPattern2 is matched in ComplexPattern1. 
""" # Check that any state in cp2 is matched in cp1 # If the thing we're matching to is just a monomer pattern, that makes # things easier--we just need to find the corresponding monomer pattern # in cp1 if cp1 is None or cp2 is None: return False cp1 = as_complex_pattern(cp1) cp2 = as_complex_pattern(cp2) if len(cp2.monomer_patterns) == 1: mp2 = cp2.monomer_patterns[0] # Iterate over the monomer patterns in cp1 and see if there is one # that has the same name for mp1 in cp1.monomer_patterns: if _mp_embeds_into(mp1, mp2): return True return False def _mp_embeds_into(mp1, mp2): """Check that conditions in MonomerPattern2 are met in MonomerPattern1.""" sc_matches = [] if mp1.monomer.name != mp2.monomer.name: return False # Check that all conditions in mp2 are met in mp1 for site_name, site_state in mp2.site_conditions.items(): if site_name not in mp1.site_conditions or \ site_state != mp1.site_conditions[site_name]: return False return True """ # NOTE: This code is currently "deprecated" because it has been replaced by the # use of Observables for the Statement objects. 
def match_rhs(cp, rules): rule_matches = [] for rule in rules: product_pattern = rule.rule_expression.product_pattern for rule_cp in product_pattern.complex_patterns: if _cp_embeds_into(rule_cp, cp): rule_matches.append(rule) break return rule_matches def find_production_rules(cp, rules): # Find rules where the CP matches the left hand side lhs_rule_set = set(_match_lhs(cp, rules)) # Now find rules where the CP matches the right hand side rhs_rule_set = set(match_rhs(cp, rules)) # Production rules are rules where there is a match on the right hand # side but not on the left hand side prod_rules = list(rhs_rule_set.difference(lhs_rule_set)) return prod_rules def find_consumption_rules(cp, rules): # Find rules where the CP matches the left hand side lhs_rule_set = set(_match_lhs(cp, rules)) # Now find rules where the CP matches the right hand side rhs_rule_set = set(match_rhs(cp, rules)) # Consumption rules are rules where there is a match on the left hand # side but not on the right hand side cons_rules = list(lhs_rule_set.difference(rhs_rule_set)) return cons_rules """ def _flip(im, path): # Reverse the path and the polarities associated with each node rev = tuple(reversed(path)) return _path_with_polarities(im, rev) def _path_with_polarities(im, path): # This doesn't address the effect of the rules themselves on the # observables of interest--just the effects of the rules on each other edge_polarities = [] path_list = list(path) edges = zip(path_list[0:-1], path_list[1:]) for from_tup, to_tup in edges: from_rule = from_tup[0] to_rule = to_tup[0] edge = (from_rule, to_rule) edge_polarities.append(_get_edge_sign(im, edge)) # Compute and return the overall path polarity #path_polarity = np.prod(edge_polarities) # Calculate left product of edge polarities return polarities_lprod = [1] for ep_ix, ep in enumerate(edge_polarities): polarities_lprod.append(polarities_lprod[-1] * ep) assert len(path) == len(polarities_lprod) return tuple(zip([node for node, sign in path], 
polarities_lprod)) #assert path_polarity == 1 or path_polarity == -1 #return True if path_polarity == 1 else False #return path_polarity
[docs]def stmt_from_rule(rule_name, model, stmts): """Return the source INDRA Statement corresponding to a rule in a model. Parameters ---------- rule_name : str The name of a rule in the given PySB model. model : pysb.core.Model A PySB model which contains the given rule. stmts : list[indra.statements.Statement] A list of INDRA Statements from which the model was assembled. Returns ------- stmt : indra.statements.Statement The Statement from which the given rule in the model was obtained. """ stmt_uuid = None for ann in model.annotations: if ann.subject == rule_name: if ann.predicate == 'from_indra_statement': stmt_uuid = ann.object break if stmt_uuid: for stmt in stmts: if stmt.uuid == stmt_uuid:
return stmt def _monomer_pattern_label(mp): """Return a string label for a MonomerPattern.""" site_strs = [] for site, cond in mp.site_conditions.items(): if isinstance(cond, tuple) or isinstance(cond, list): assert len(cond) == 2 if cond[1] == WILD: site_str = '%s_%s' % (site, cond[0]) else: site_str = '%s_%s%s' % (site, cond[0], cond[1]) elif isinstance(cond, numbers.Real): continue else: site_str = '%s_%s' % (site, cond) site_strs.append(site_str) return '%s_%s' % (mp.monomer.name, '_'.join(site_strs)) def _im_to_signed_digraph(im): edges = [] for e in im.edges(): edge_sign = _get_edge_sign(im, e) polarity = 0 if edge_sign > 0 else 1 edges.append((e[0], e[1], dict([('sign', polarity)]))) dg = nx.DiGraph() dg.add_edges_from(edges) return dg