Source code for indra.db.database_manager

from __future__ import absolute_import, print_function, unicode_literals
from builtins import dict, str

# Public names exported by ``from indra.db.database_manager import *``.
__all__ = ['sqltypes', 'texttypes', 'formats', 'DatabaseManager',
           'IndraDatabaseError', 'sql_expressions']

import re
import logging
from io import BytesIO
from os import path
from numbers import Number
from datetime import datetime

from sqlalchemy.sql import expression as sql_expressions
from sqlalchemy.schema import DropTable
from sqlalchemy.sql.expression import Delete, Update
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, UniqueConstraint, ForeignKey,\
    TIMESTAMP, create_engine, inspect, LargeBinary, Boolean, DateTime, func
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.dialects.postgresql import BYTEA


# Module-level logger. Note the fixed name 'db_manager' (not __name__): log
# configuration must target that name to control this module's output.
logger = logging.getLogger('db_manager')


# On PostgreSQL, append CASCADE to every compiled DROP TABLE so that objects
# depending on the table (e.g. foreign-key constraints) do not block the drop.
# See: https://stackoverflow.com/questions/38678336/sqlalchemy-how-to-implement-drop-table-cascade
@compiles(DropTable, "postgresql")
def _compile_drop_table(element, compiler, **kwargs):
    """Render a PostgreSQL ``DROP TABLE`` statement with a trailing CASCADE."""
    drop_sql = compiler.visit_drop_table(element)
    return "{} CASCADE".format(drop_sql)


# Solution to fix deletes with constraints from multiple tables.
# See: https://groups.google.com/forum/#!topic/sqlalchemy/cIvgH2y01_o
@compiles(Delete)
def compile_delete(element, compiler, **kw):
    """Compile a DELETE, adding a ``USING`` clause for extra FROM tables.

    The extra tables referenced by the delete's WHERE clause are gathered via
    `Update._extra_froms` (borrowed and bound to the Delete element). When
    there are any, they are spliced in right after the delete's own target,
    giving ``DELETE FROM <target> USING <extra, ...> WHERE ...``.
    """
    text = compiler.visit_delete(element, **kw)
    extra_froms = Update._extra_froms.__get__(element)
    if extra_froms:
        using = ", ".join(
            compiler.process(fr, asfrom=True, **kw) for fr in extra_froms
        )
        # Only the first "FROM <name>" is the DELETE's own target. Any later
        # match lives inside a subquery in the WHERE clause, where appending
        # USING would produce invalid SQL, so restrict to one substitution.
        text = re.sub(
            r"(FROM \S+)",
            lambda m: "%s USING %s" % (m.group(1), using),
            text,
            count=1
        )
    return text


# pgcopy is optional: it only speeds up bulk loads into PostgreSQL. Default to
# "unavailable" and upgrade the flags when the import succeeds.
CopyManager = None
CAN_COPY = False
try:
    from pgcopy import CopyManager
except ImportError:
    print("WARNING: pgcopy unavailable. Bulk copies will be slow.")
else:
    CAN_COPY = True


def _isiterable(obj):
    "Bool determines if an object is an iterable (not a string)"
    return hasattr(obj, '__iter__') and not isinstance(obj, str)


class _map_class(object):
    @classmethod
    def _getattrs(self):
        return {
            k: v for k, v in self.__dict__.items() if not k.startswith('_')
            }

    @classmethod
    def items(self):
        return self._getattrs().items()

    @classmethod
    def values(self):
        return self._getattrs().values()

    @classmethod
    def keys(self):
        return self._getattrs().keys()


class sqltypes(_map_class):
    """SQL backend names accepted by `DatabaseManager` as `sqltype`."""
    POSTGRESQL = 'postgresql'  # the DatabaseManager default
    SQLITE = 'sqlite'


class texttypes(_map_class):
    """Text type labels, e.g. for the `text_type` column of `text_content`."""
    FULLTEXT = 'fulltext'
    ABSTRACT = 'abstract'


class formats(_map_class):
    """Content format labels (presumably for the `format` columns)."""
    XML = 'xml'
    TEXT = 'text'
    JSON = 'json'


class IndraDatabaseError(Exception):
    """Error raised by `DatabaseManager` for database-management failures."""
class DatabaseManager(object):
    """An object used to access INDRA's database.

    This object can be used to access and manage indra's database. It
    includes both basic methods and some useful, more high-level methods.
    It is designed to be used with postgresql, or sqlite.

    This object is primarily built around sqlalchemy, which is a required
    package for its use. It also optionally makes use of the pgcopy package
    for large data transfers.

    If you wish to access the primary database, you can simply use the
    `get_primary_db` to get an instance of this object using the default
    settings.

    Parameters
    ----------
    host : str
        The database URL to connect to; it is passed directly to
        `sqlalchemy.create_engine`.
    sqltype : OPTIONAL[str]
        The type of sql library used. Use one of the sql types provided by
        `sqltypes`. Default is `sqltypes.POSTGRESQL`.
    label : OPTIONAL[str]
        A short string to indicate the purpose of the db instance. Set as
        primary when initialized by `get_primary_db`.

    Example
    -------
    If you wish to access the primary database and find the metadata for a
    particular pmid, 1234567:

    >> from indra.db import get_primary_db
    >> db = get_primary_db()
    >> res = db.select_all(db.TextRef, db.TextRef.pmid == '1234567')

    You will get a list of objects whose attributes give the metadata
    contained in the columns of the table.

    For more sophisticated examples, several use cases can be found in
    `indra.tests.test_db`.
    """

    def __init__(self, host, sqltype=sqltypes.POSTGRESQL, label=None):
        self.host = host
        self.session = None
        self.Base = declarative_base()
        self.sqltype = sqltype
        self.label = label

        # Compare by value: `is` on strings only works when both happen to be
        # the same object, and fails for an equal string built at runtime
        # (e.g. read from a configuration file).
        if sqltype == sqltypes.POSTGRESQL:
            Bytea = BYTEA
        else:
            Bytea = LargeBinary

        class TextRef(self.Base):
            __tablename__ = 'text_ref'
            id = Column(Integer, primary_key=True)
            pmid = Column(String(20))
            pmcid = Column(String(20))
            doi = Column(String(100))
            pii = Column(String(250))
            url = Column(String(250), unique=True)  # Maybe longer?
            manuscript_id = Column(String(100), unique=True)
            create_date = Column(DateTime, default=func.now())
            last_updated = Column(DateTime, onupdate=func.now())
            __table_args__ = (
                UniqueConstraint('pmid', 'doi'),
                UniqueConstraint('pmcid', 'doi')
                )

        class SourceFile(self.Base):
            __tablename__ = 'source_file'
            id = Column(Integer, primary_key=True)
            source = Column(String(250), nullable=False)
            name = Column(String(250), nullable=False)
            load_date = Column(DateTime, default=func.now())
            __table_args__ = (
                UniqueConstraint('source', 'name'),
                )

        class Updates(self.Base):
            __tablename__ = 'updates'
            id = Column(Integer, primary_key=True)
            init_upload = Column(Boolean, nullable=False)
            source = Column(String(250), nullable=False)
            unresolved_conflicts_file = Column(Bytea)
            datetime = Column(DateTime, default=func.now())

        class TextContent(self.Base):
            __tablename__ = 'text_content'
            id = Column(Integer, primary_key=True)
            text_ref_id = Column(Integer,
                                 ForeignKey('text_ref.id'),
                                 nullable=False)
            text_ref = relationship(TextRef)
            source = Column(String(250), nullable=False)
            format = Column(String(250), nullable=False)
            text_type = Column(String(250), nullable=False)
            content = Column(Bytea, nullable=False)
            insert_date = Column(DateTime, default=func.now())
            last_updated = Column(DateTime, onupdate=func.now())
            __table_args__ = (
                UniqueConstraint(
                    'text_ref_id', 'source', 'format', 'text_type'
                    ),
                )

        class Readings(self.Base):
            __tablename__ = 'readings'
            id = Column(Integer, primary_key=True)
            text_content_id = Column(Integer,
                                     ForeignKey('text_content.id'),
                                     nullable=False)
            text_content = relationship(TextContent)
            reader = Column(String(20), nullable=False)
            reader_version = Column(String(20), nullable=False)
            format = Column(String(20), nullable=False)  # xml, json, etc.
            bytes = Column(Bytea, nullable=False)
            create_date = Column(DateTime, default=func.now())
            last_updated = Column(DateTime, onupdate=func.now())
            __table_args__ = (
                UniqueConstraint(
                    'text_content_id', 'reader', 'reader_version'
                    ),
                )

        class DBInfo(self.Base):
            __tablename__ = 'db_info'
            id = Column(Integer, primary_key=True)
            db_name = Column(String(20), nullable=False)
            create_date = Column(DateTime, default=func.now())
            last_updated = Column(DateTime, onupdate=func.now())

        class Statements(self.Base):
            __tablename__ = 'statements'
            id = Column(Integer, primary_key=True)
            uuid = Column(String(40), unique=True, nullable=False)
            db_ref = Column(Integer, ForeignKey('db_info.id'))
            db_info = relationship(DBInfo)
            reader_ref = Column(Integer, ForeignKey('readings.id'))
            readings = relationship(Readings)
            type = Column(String(100), nullable=False)
            indra_version = Column(String(100), nullable=False)
            json = Column(Bytea, nullable=False)
            create_date = Column(DateTime, default=func.now())

        class Agents(self.Base):
            __tablename__ = 'agents'
            id = Column(Integer, primary_key=True)
            stmt_id = Column(Integer,
                             ForeignKey('statements.id'),
                             nullable=False)
            statements = relationship(Statements)
            db_name = Column(String(40), nullable=False)
            db_id = Column(String, nullable=False)
            role = Column(String(20), nullable=False)

        # Register each model by table name, and expose it as an attribute
        # named after its class (e.g. ``self.TextRef``).
        self.tables = {}
        for tbl in [TextRef, TextContent, Readings, SourceFile, Updates,
                    DBInfo, Statements, Agents]:
            self.tables[tbl.__tablename__] = tbl
            setattr(self, tbl.__name__, tbl)
        self.engine = create_engine(host)

    def __del__(self, *args, **kwargs):
        """Roll back any uncommitted work when the manager is collected."""
        # getattr: __init__ may have failed before `session` was assigned.
        session = getattr(self, 'session', None)
        if session is None:
            return
        try:
            # Roll back the existing session itself. Grabbing a fresh session
            # here would discard an inactive one without rolling it back.
            session.rollback()
        except Exception:
            print("Failed to execute rollback of database upon deletion.")

    def _create_table(self, tbl_name):
        """Create the table named `tbl_name` unless it already exists."""
        logger.debug("Creating %s..." % tbl_name)
        table = self.tables[tbl_name].__table__
        if not table.exists(self.engine):
            table.create(bind=self.engine)
            logger.debug("Table created.")
        else:
            logger.debug("Table already existed.")

    def create_tables(self, tbl_list=None):
        """Create the tables for INDRA database.

        Parameters
        ----------
        tbl_list : OPTIONAL[list]
            Table names (str) or model classes to create. If None, every
            table in the metadata is created. Tables that already exist are
            skipped.
        """
        if tbl_list is None:
            logger.debug("Creating all tables...")
            self.Base.metadata.create_all(self.engine)
            logger.debug("Created all tables.")
            return

        tbl_name_list = [tbl if isinstance(tbl, str) else tbl.__tablename__
                         for tbl in tbl_list]

        # These tables must be created in this order (each references the
        # previous ones through foreign keys).
        for tbl_name in ['text_ref', 'text_content', 'readings', 'db_info',
                         'statements', 'agents']:
            if tbl_name in tbl_name_list:
                tbl_name_list.remove(tbl_name)
                self._create_table(tbl_name)

        # The rest can be started any time.
        for tbl_name in tbl_name_list:
            self._create_table(tbl_name)
        return

    def drop_tables(self, tbl_list=None, force=False):
        """Drop the tables for INDRA database given in tbl_list.

        If tbl_list is None, all tables will be dropped. Note that if `force`
        is False, a warning prompt will be raised asking for confirmation, as
        this action will remove all data from that table.

        Parameters
        ----------
        tbl_list : OPTIONAL[list]
            Table names (str) or model classes to drop.
        force : bool
            Skip the interactive confirmation prompt.

        Returns
        -------
        bool
            False if the user declined the prompt, otherwise True.
        """
        # Resolve names to model classes once; both the prompt and the drop
        # loop below work on these objects.
        tbl_objs = None
        if tbl_list is not None:
            tbl_objs = [self.tables[tbl] if isinstance(tbl, str) else tbl
                        for tbl in tbl_list]

        if not force:
            # Build the message
            if tbl_objs is None:
                msg = ("Do you really want to clear the primary database? "
                       "[y/N]: ")
            else:
                msg = "You are going to clear the following tables:\n"
                msg += str([tbl.__tablename__ for tbl in tbl_objs]) + '\n'
                msg += "Do you really want to clear these tables? [y/N]: "
            # Check to make sure. (raw_input only exists on Python 2.)
            try:
                resp = raw_input(msg)
            except NameError:
                resp = input(msg)
            if resp != 'y' and resp != 'yes':
                logger.info('Aborting clear.')
                return False

        if tbl_objs is None:
            logger.info("Removing all tables...")
            self.Base.metadata.drop_all(self.engine)
            logger.debug("All tables removed.")
        else:
            for tbl in tbl_objs:
                logger.info("Removing %s..." % tbl.__tablename__)
                if tbl.__table__.exists(self.engine):
                    tbl.__table__.drop(self.engine)
                    logger.debug("Table removed.")
                else:
                    logger.debug("Table doesn't exist.")
        return True

    def _clear(self, tbl_list=None, force=False):
        "Brutal clearing of all tables in tbl_list, or all tables."
        # This is intended for testing purposes, not general use.
        # Use with care.
        self.grab_session()
        logger.debug("Rolling back before clear...")
        self.session.rollback()
        logger.debug("Rolled back.")
        if self.drop_tables(tbl_list, force=force):
            self.create_tables(tbl_list)
            return True
        else:
            return False

    def grab_session(self):
        """Get an active session with the database.

        A new session is created when there is none yet or the current one
        is no longer active.
        """
        if self.session is None or not self.session.is_active:
            logger.debug('Attempting to get session...')
            DBSession = sessionmaker(bind=self.engine)
            logger.debug('Got session.')
            self.session = DBSession()
            if self.session is None:
                raise IndraDatabaseError("Failed to grab session.")

    def get_tables(self):
        "Get a list of available tables."
        return list(self.tables.keys())

    def show_tables(self):
        "Print a list of all the available tables."
        print(self.get_tables())

    def get_active_tables(self):
        "Get the tables currently active in the database."
        return inspect(self.engine).get_table_names()

    def get_column_names(self, tbl_name):
        "Get a list of the column labels for a table."
        return self.get_column_objects(tbl_name).keys()

    def get_column_objects(self, table):
        """Get the column objects for the given table.

        `table` may be a table name or a model class.
        """
        if isinstance(table, type(self.Base)):
            table = table.__tablename__
        return self.Base.metadata.tables[table].columns

    def commit(self, err_msg):
        """Commit, and give useful info if there is an exception.

        On failure the session is rolled back, `err_msg` is logged, and the
        original exception is re-raised.
        """
        try:
            logger.debug('Attempting to commit...')
            self.session.commit()
            logger.debug('Message committed.')
        except Exception as e:
            if self.session is not None:
                logger.error('Got exception in commit, rolling back...')
                self.session.rollback()
                logger.debug('Rolled back.')
            logger.exception(e)
            logger.error(err_msg)
            raise

    def get_values(self, entry_list, col_names=None, keyed=False):
        """Get the column values from the entries in entry_list.

        Parameters
        ----------
        entry_list : list
            Model instances to read from.
        col_names : OPTIONAL[str or iterable of str]
            A single column name yields one value per entry; an iterable
            yields a list (or dict, if `keyed`) per entry. Defaults to all
            columns of the first entry's table.
        keyed : bool
            When `col_names` is iterable, return dicts keyed by column name.
        """
        if col_names is None and len(entry_list) > 0:  # Get everything.
            col_names = self.get_column_names(entry_list[0].__tablename__)
        ret = []
        for entry in entry_list:
            if _isiterable(col_names):
                if not keyed:
                    ret.append([getattr(entry, col) for col in col_names])
                else:
                    ret.append({col: getattr(entry, col)
                                for col in col_names})
            else:
                ret.append(getattr(entry, col_names))
        return ret

    def insert(self, tbl_name, ret_info='id', **input_dict):
        "Insert an entry into specified table, and return id."
        self.grab_session()
        inputs = dict.fromkeys(self.get_column_names(tbl_name))
        inputs.update(input_dict)
        new_entry = self.tables[tbl_name](**inputs)
        self.session.add(new_entry)
        self.commit("Excepted while trying to insert %s into %s" %
                    (inputs, tbl_name))
        return self.get_values([new_entry], ret_info)[0]

    def insert_many(self, tbl_name, input_dict_list, ret_info='id'):
        """Insert many records into the table given by tbl_name.

        Columns missing from an input dict are set to None.
        """
        self.grab_session()
        col_names = self.get_column_names(tbl_name)
        entry_list = []
        for input_dict in input_dict_list:
            # Fresh all-None template per record, then overlay the inputs.
            inputs = dict.fromkeys(col_names)
            inputs.update(input_dict)
            entry_list.append(self.tables[tbl_name](**inputs))
        self.session.add_all(entry_list)
        self.commit("Excepted while trying to insert:\n%s,\ninto %s" %
                    (input_dict_list, tbl_name))
        return self.get_values(entry_list, ret_info)

    def delete_all(self, entry_list):
        "Remove the given records from the given table."
        self.grab_session()
        for entry in entry_list:
            self.session.delete(entry)
        self.commit("Could not remove %d records from the database." %
                    len(entry_list))
        return

    @staticmethod
    def _prep_copy_row(entry):
        """Convert one row for pgcopy: str -> utf8 bytes, others unchanged.

        Raises IndraDatabaseError for values of unsupported types.
        """
        new_entry = []
        for element in entry:
            if isinstance(element, str):
                new_entry.append(element.encode('utf8'))
            elif (isinstance(element, bytes) or element is None
                  or isinstance(element, Number)
                  or isinstance(element, datetime)):
                new_entry.append(element)
            else:
                raise IndraDatabaseError(
                    "Don't know what to do with element of type %s. "
                    "Should be str, bytes, datetime, None, or a "
                    "number." % type(element)
                    )
        return tuple(new_entry)

    def copy(self, tbl_name, data, cols=None):
        """Use pg_copy to copy over a large amount of data.

        Parameters
        ----------
        tbl_name : str
            Name of the destination table.
        data : list
            Rows to load, each a sequence of values ordered like `cols`.
        cols : OPTIONAL[sequence of str]
            Column names for the values in each row; defaults to every
            column of the table. The caller's sequence is not modified.

        Falls back to `insert_many` when not on postgresql or when pgcopy is
        unavailable.
        """
        logger.info("Received request to copy %d entries into %s." %
                    (len(data), tbl_name))
        if len(data) == 0:
            return  # Nothing to do....

        # If cols is not specified, use all the cols in the table, else check
        # to make sure the names are valid.
        if cols is None:
            cols = self.get_column_names(tbl_name)
        else:
            db_cols = self.get_column_names(tbl_name)
            assert all([col in db_cols for col in cols]),\
                "Do not recognize one of the columns in %s for table %s." % \
                (cols, tbl_name)
        # Work on a private tuple: timestamp columns may be appended below,
        # and `+=` on a caller-supplied list would modify it in place.
        cols = tuple(cols)

        # Do the copy. Use pgcopy if available.
        if self.sqltype == sqltypes.POSTGRESQL and CAN_COPY:
            # Check for automatic timestamps which won't be applied by the
            # database when using copy, and manually insert them.
            auto_timestamp_type = type(func.now())
            for col in self.get_column_objects(tbl_name):
                if col.default is not None:
                    if isinstance(col.default.arg, auto_timestamp_type) \
                            and col.name not in cols:
                        logger.info("Applying timestamps to %s." % col.name)
                        now = datetime.utcnow()
                        cols += (col.name,)
                        data = [tuple(datum) + (now,) for datum in data]

            # Now actually do the copy. Always release the raw connection,
            # even if a row fails to convert or the copy itself raises.
            conn = self.engine.raw_connection()
            try:
                mngr = CopyManager(conn, tbl_name, cols)
                data_bts = [self._prep_copy_row(entry) for entry in data]
                mngr.copy(data_bts, BytesIO)
                conn.commit()
            finally:
                conn.close()
        else:
            # TODO: use bulk insert mappings?
            logger.warning("You are not using postgresql or do not have "
                           "pgcopy, so this will likely be very slow.")
            self.insert_many(tbl_name, [dict(zip(cols, ro)) for ro in data])

    def filter_query(self, tbls, *args):
        """Query a table (or tables) and filter results.

        `tbls` may be a model class, a table name, or a list of either
        (all of the same kind). `args` are sqlalchemy filter expressions.
        """
        self.grab_session()
        if _isiterable(tbls) and not isinstance(tbls, dict):
            if isinstance(tbls[0], type(self.Base)):
                query_args = tbls
            elif isinstance(tbls[0], str):
                query_args = [self.tables[tbl] for tbl in tbls]
            else:
                raise IndraDatabaseError(
                    'Unrecognized table specification type: %s.' %
                    type(tbls[0])
                    )
        else:
            if isinstance(tbls, type(self.Base)):
                query_args = [tbls]
            elif isinstance(tbls, str):
                query_args = [self.tables[tbls]]
            else:
                raise IndraDatabaseError(
                    'Unrecognized table specification type: %s.' %
                    type(tbls)
                    )
        return self.session.query(*query_args).filter(*args)

    def select_one(self, tbls, *args):
        """Select the first value that matches requirements.

        Requirements are given as sqlalchemy filter expressions in `args`,
        applied to the table(s) indicated by `tbls`. See *select_all*.

        Note that if your specification yields multiple results, this method
        will just return the first result without exception.
        """
        return self.filter_query(tbls, *args).first()

    def select_all(self, tbls, *args):
        """Select any and all entries from table given by tbls.

        The results will be filtered by the expressions given in `args`.
        For example if you want to get a text ref with pmid '10532205',
        you would call:

        .. code-block:: python

            db.select_all('text_ref', db.TextRef.pmid == '10532205')

        Note that double equals are required, not a single equal.
        Equivalently you could call:

        .. code-block:: python

            db.select_all(db.TextRef, db.TextRef.pmid == '10532205')

        For a more complicated example, suppose you want to get all text refs
        that have full text from pmc oa, you could select:

        .. code-block:: python

            db.select_all(
                [db.TextRef, db.TextContent],
                db.TextContent.text_ref_id == db.TextRef.id,
                db.TextContent.source == 'pmc_oa',
                db.TextContent.text_type == 'fulltext'
                )
        """
        return self.filter_query(tbls, *args).all()

    def has_entry(self, tbls, *args):
        "Check whether an entry/entries matching given specs live in the db."
        q = self.filter_query(tbls, *args)
        return self.session.query(q.exists()).first()[0]