Source code for factory.alchemy

# Copyright: See the LICENSE file.

from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound

from . import base, errors

# Accepted values for the ``sqlalchemy_session_persistence`` factory option.
# 'commit' makes _save() call session.commit() after adding the object,
# 'flush' makes it call session.flush(); with None neither call is made.
SESSION_PERSISTENCE_COMMIT = 'commit'
SESSION_PERSISTENCE_FLUSH = 'flush'

# Kept as a list: it is interpolated verbatim into the TypeError message
# raised for an invalid persistence setting.
VALID_SESSION_PERSISTENCE_TYPES = [None, SESSION_PERSISTENCE_COMMIT, SESSION_PERSISTENCE_FLUSH]


class SQLAlchemyOptions(base.FactoryOptions):
    """Factory options extended with the SQLAlchemy-specific ``Meta`` settings."""

    def _check_sqlalchemy_session_persistence(self, meta, value):
        """Reject a ``sqlalchemy_session_persistence`` value that is not allowed.

        Raises:
            TypeError: if ``value`` is not in VALID_SESSION_PERSISTENCE_TYPES.
        """
        if value in VALID_SESSION_PERSISTENCE_TYPES:
            return

        message = "%s.sqlalchemy_session_persistence must be one of %s, got %r" % (
            meta,
            VALID_SESSION_PERSISTENCE_TYPES,
            value,
        )
        raise TypeError(message)

    @staticmethod
    def _check_has_sqlalchemy_session_set(meta, value):
        """Reject a session factory when ``meta`` already carries a session.

        Raises:
            RuntimeError: if ``value`` (the session factory) is not None and
                ``meta.sqlalchemy_session`` exists and is not None.
        """
        if value is None:
            # No factory configured, so there is nothing to conflict with.
            return

        configured_session = getattr(meta, "sqlalchemy_session", None)
        if configured_session is not None:
            raise RuntimeError("Provide either a sqlalchemy_session or a sqlalchemy_session_factory, not both")

    def _build_default_options(self):
        """Return the inherited option defaults plus the SQLAlchemy options."""
        inherited_defaults = super()._build_default_options()
        sqlalchemy_defaults = [
            base.OptionDefault('sqlalchemy_get_or_create', (), inherit=True),
            base.OptionDefault('sqlalchemy_session', None, inherit=True),
            base.OptionDefault(
                'sqlalchemy_session_factory',
                None,
                inherit=True,
                checker=self._check_has_sqlalchemy_session_set,
            ),
            base.OptionDefault(
                'sqlalchemy_session_persistence',
                None,
                inherit=True,
                checker=self._check_sqlalchemy_session_persistence,
            ),
        ]
        return inherited_defaults + sqlalchemy_defaults
class SQLAlchemyModelFactory(base.Factory):
    """Factory for SQLAlchemy models."""

    _options_class = SQLAlchemyOptions
    # Params of the most recent _generate() call, stored on the class so that
    # _get_or_create() can retry its lookup after an IntegrityError.
    _original_params = None

    class Meta:
        abstract = True

    @classmethod
    def _generate(cls, strategy, params):
        """Record ``params`` on the class, then delegate to the parent."""
        # Original params are used in _get_or_create if it cannot build an
        # object initially due to an IntegrityError being raised
        cls._original_params = params
        return super()._generate(strategy, params)

    @classmethod
    def _get_or_create(cls, model_class, session, args, kwargs):
        """Return the object matching the get-or-create fields, creating it if absent.

        Every field named in ``Meta.sqlalchemy_get_or_create`` is popped from
        ``kwargs`` and used as the lookup filter. When the lookup finds
        nothing, a new object is saved through ``_save``. If that save raises
        IntegrityError, the session is rolled back and the lookup is retried
        with the get-or-create entries of the params recorded by ``_generate``.

        Raises:
            errors.FactoryError: a get-or-create field has no value in ``kwargs``.
            IntegrityError: the save failed and the fallback lookup was not
                possible or found no row.
        """
        key_fields = {}
        for field in cls._meta.sqlalchemy_get_or_create:
            if field not in kwargs:
                raise errors.FactoryError(
                    "sqlalchemy_get_or_create - "
                    "Unable to find initialization value for '%s' in factory %s" %
                    (field, cls.__name__))
            key_fields[field] = kwargs.pop(field)

        # NOTE(review): positional ``args`` are forwarded to filter_by(), which
        # SQLAlchemy documents as keyword-only -- confirm behaviour when
        # ``args`` is non-empty.
        obj = session.query(model_class).filter_by(
            *args, **key_fields).one_or_none()

        # one_or_none() signals "no row" with None. Compare by identity so a
        # model instance that happens to be falsy (custom __bool__/__len__) is
        # not mistaken for a miss, which would trigger a spurious insert and
        # a session rollback.
        if obj is not None:
            return obj

        try:
            return cls._save(model_class, session, args, {**key_fields, **kwargs})
        except IntegrityError as e:
            session.rollback()

            if cls._original_params is None:
                raise

            get_or_create_params = {
                lookup: value
                for lookup, value in cls._original_params.items()
                if lookup in cls._meta.sqlalchemy_get_or_create
            }
            if not get_or_create_params:
                raise

            try:
                return session.query(model_class).filter_by(
                    **get_or_create_params).one()
            except NoResultFound:
                # Original params are not a valid lookup and triggered a create(),
                # that resulted in an IntegrityError.
                raise e

    @classmethod
    def _create(cls, model_class, *args, **kwargs):
        """Create an instance of the model, and save it to the database.

        Raises:
            RuntimeError: if no session is available on the factory's Meta.
        """
        session_factory = cls._meta.sqlalchemy_session_factory
        if session_factory:
            # The session produced by the factory replaces the one stored on Meta.
            cls._meta.sqlalchemy_session = session_factory()

        session = cls._meta.sqlalchemy_session

        if session is None:
            raise RuntimeError("No session provided.")
        if cls._meta.sqlalchemy_get_or_create:
            return cls._get_or_create(model_class, session, args, kwargs)
        return cls._save(model_class, session, args, kwargs)

    @classmethod
    def _save(cls, model_class, session, args, kwargs):
        """Instantiate ``model_class``, add it to ``session`` and persist it.

        Depending on ``Meta.sqlalchemy_session_persistence`` the session is
        flushed, committed, or left untouched after the object is added.
        """
        session_persistence = cls._meta.sqlalchemy_session_persistence

        obj = model_class(*args, **kwargs)
        session.add(obj)
        if session_persistence == SESSION_PERSISTENCE_FLUSH:
            session.flush()
        elif session_persistence == SESSION_PERSISTENCE_COMMIT:
            session.commit()
        return obj